@@ -2,6 +2,7 @@ import { isIP } from 'node:net';
2
2
3
3
import { Injectable } from '@nestjs/common' ;
4
4
import type { Response } from 'express' ;
5
+ import { ClsService } from 'nestjs-cls' ;
5
6
6
7
import { Config } from '../config' ;
7
8
import { OnEvent } from '../event' ;
@@ -11,10 +12,13 @@ export class URLHelper {
11
12
redirectAllowHosts ! : string [ ] ;
12
13
13
14
origin ! : string ;
15
+ allowedOrigins ! : string [ ] ;
14
16
baseUrl ! : string ;
15
- home ! : string ;
16
17
17
- constructor ( private readonly config : Config ) {
18
+ constructor (
19
+ private readonly config : Config ,
20
+ private readonly cls ?: ClsService
21
+ ) {
18
22
this . init ( ) ;
19
23
}
20
24
@@ -34,19 +38,40 @@ export class URLHelper {
34
38
this . baseUrl =
35
39
externalUrl . origin + externalUrl . pathname . replace ( / \/ $ / , '' ) ;
36
40
} else {
37
- this . origin = [
38
- this . config . server . https ? 'https' : 'http' ,
39
- '://' ,
40
- this . config . server . host ,
41
- this . config . server . host === 'localhost' || isIP ( this . config . server . host )
42
- ? `:${ this . config . server . port } `
43
- : '' ,
44
- ] . join ( '' ) ;
41
+ this . origin = this . convertHostToOrigin ( this . config . server . host ) ;
45
42
this . baseUrl = this . origin + this . config . server . path ;
46
43
}
47
44
48
- this . home = this . baseUrl ;
49
45
this . redirectAllowHosts = [ this . baseUrl ] ;
46
+
47
+ this . allowedOrigins = [ this . origin ] ;
48
+ if ( this . config . server . hosts . length > 0 ) {
49
+ for ( const host of this . config . server . hosts ) {
50
+ this . allowedOrigins . push ( this . convertHostToOrigin ( host ) ) ;
51
+ }
52
+ }
53
+ }
54
+
55
+ get requestOrigin ( ) {
56
+ if ( this . config . server . hosts . length === 0 ) {
57
+ return this . origin ;
58
+ }
59
+
60
+ // support multiple hosts
61
+ const requestHost = this . cls ?. get < string | undefined > ( CLS_REQUEST_HOST ) ;
62
+ if ( ! requestHost || ! this . config . server . hosts . includes ( requestHost ) ) {
63
+ return this . origin ;
64
+ }
65
+
66
+ return this . convertHostToOrigin ( requestHost ) ;
67
+ }
68
+
69
+ get requestBaseUrl ( ) {
70
+ if ( this . config . server . hosts . length === 0 ) {
71
+ return this . baseUrl ;
72
+ }
73
+
74
+ return this . requestOrigin + this . config . server . path ;
50
75
}
51
76
52
77
stringify ( query : Record < string , any > ) {
@@ -72,7 +97,7 @@ export class URLHelper {
72
97
}
73
98
74
99
url ( path : string , query : Record < string , any > = { } ) {
75
- const url = new URL ( path , this . origin ) ;
100
+ const url = new URL ( path , this . requestOrigin ) ;
76
101
77
102
for ( const key in query ) {
78
103
url . searchParams . set ( key , query [ key ] ) ;
@@ -87,7 +112,7 @@ export class URLHelper {
87
112
88
113
safeRedirect ( res : Response , to : string ) {
89
114
try {
90
- const finalTo = new URL ( decodeURIComponent ( to ) , this . baseUrl ) ;
115
+ const finalTo = new URL ( decodeURIComponent ( to ) , this . requestBaseUrl ) ;
91
116
92
117
for ( const host of this . redirectAllowHosts ) {
93
118
const hostURL = new URL ( host ) ;
@@ -103,7 +128,7 @@ export class URLHelper {
103
128
}
104
129
105
130
// redirect to home if the url is invalid
106
- return res . redirect ( this . home ) ;
131
+ return res . redirect ( this . baseUrl ) ;
107
132
}
108
133
109
134
verify ( url : string | URL ) {
@@ -118,4 +143,13 @@ export class URLHelper {
118
143
return false ;
119
144
}
120
145
}
146
+
147
+ private convertHostToOrigin ( host : string ) {
148
+ return [
149
+ this . config . server . https ? 'https' : 'http' ,
150
+ '://' ,
151
+ host ,
152
+ host === 'localhost' || isIP ( host ) ? `:${ this . config . server . port } ` : '' ,
153
+ ] . join ( '' ) ;
154
+ }
121
155
}
0 commit comments