@@ -19,6 +19,20 @@ use crate::{
19
19
20
20
use super :: frames:: { FrameIO , frames_from_stream} ;
21
21
22
+ // Helper function to check if a request is a WebSocket upgrade
23
+ fn is_websocket_upgrade ( request : & HttpRequest ) -> bool {
24
+ let connection = request. header ( "Connection" , "" ) . to_lowercase ( ) ;
25
+ let upgrade = request. header ( "Upgrade" , "" ) . to_lowercase ( ) ;
26
+
27
+ // Check if Connection header contains "upgrade" as a separate token
28
+ let has_upgrade_connection = connection. split ( ',' ) . any ( |token| token. trim ( ) == "upgrade" ) ;
29
+
30
+ // Check if Upgrade header is exactly "websocket"
31
+ let has_websocket_upgrade = upgrade == "websocket" ;
32
+
33
+ has_upgrade_connection && has_websocket_upgrade
34
+ }
35
+
22
36
// Helper function to send error response to client
23
37
async fn send_error_response (
24
38
client_stream : & mut IOBufStream ,
@@ -296,14 +310,7 @@ impl ContextCallback for HttpForwardCallback {
296
310
let mut request = request. unwrap ( ) . as_ref ( ) . clone ( ) ;
297
311
298
312
// Only add Connection: close for regular HTTP requests, not WebSocket upgrades
299
- let is_websocket_upgrade = request
300
- . header ( "Connection" , "" )
301
- . to_lowercase ( )
302
- . contains ( "upgrade" )
303
- && request
304
- . header ( "Upgrade" , "" )
305
- . to_lowercase ( )
306
- . contains ( "websocket" ) ;
313
+ let is_websocket_upgrade = is_websocket_upgrade ( & request) ;
307
314
308
315
if !is_websocket_upgrade {
309
316
request = request. with_header ( "Connection" , "close" ) ;
@@ -400,3 +407,53 @@ impl ContextCallback for FrameChannelCallback {
400
407
}
401
408
}
402
409
}
410
+
411
+ #[ cfg( test) ]
412
+ mod tests {
413
+ use super :: * ;
414
+
415
+ #[ test]
416
+ fn test_websocket_detection ( ) {
417
+ // Valid WebSocket upgrade request
418
+ let ws_request = HttpRequest :: new ( "GET" , "/" )
419
+ . with_header ( "Connection" , "upgrade" )
420
+ . with_header ( "Upgrade" , "websocket" ) ;
421
+ assert ! ( is_websocket_upgrade( & ws_request) ) ;
422
+
423
+ // Valid WebSocket upgrade with multiple Connection header values
424
+ let ws_request_multi = HttpRequest :: new ( "GET" , "/" )
425
+ . with_header ( "Connection" , "keep-alive, upgrade" )
426
+ . with_header ( "Upgrade" , "websocket" ) ;
427
+ assert ! ( is_websocket_upgrade( & ws_request_multi) ) ;
428
+
429
+ // Valid WebSocket upgrade with different casing
430
+ let ws_request_case = HttpRequest :: new ( "GET" , "/" )
431
+ . with_header ( "Connection" , "Upgrade" )
432
+ . with_header ( "Upgrade" , "WebSocket" ) ;
433
+ assert ! ( is_websocket_upgrade( & ws_request_case) ) ;
434
+
435
+ // Invalid: contains "upgrade" but not as separate token
436
+ let invalid_contains = HttpRequest :: new ( "GET" , "/" )
437
+ . with_header ( "Connection" , "keep-alive-upgrade" )
438
+ . with_header ( "Upgrade" , "websocket" ) ;
439
+ assert ! ( !is_websocket_upgrade( & invalid_contains) ) ;
440
+
441
+ // Invalid: Upgrade header contains websocket but not exactly
442
+ let invalid_upgrade = HttpRequest :: new ( "GET" , "/" )
443
+ . with_header ( "Connection" , "upgrade" )
444
+ . with_header ( "Upgrade" , "websocket-custom" ) ;
445
+ assert ! ( !is_websocket_upgrade( & invalid_upgrade) ) ;
446
+
447
+ // Invalid: Regular HTTP request
448
+ let http_request = HttpRequest :: new ( "GET" , "/" ) . with_header ( "Connection" , "keep-alive" ) ;
449
+ assert ! ( !is_websocket_upgrade( & http_request) ) ;
450
+
451
+ // Invalid: Missing Connection header
452
+ let no_connection = HttpRequest :: new ( "GET" , "/" ) . with_header ( "Upgrade" , "websocket" ) ;
453
+ assert ! ( !is_websocket_upgrade( & no_connection) ) ;
454
+
455
+ // Invalid: Missing Upgrade header
456
+ let no_upgrade = HttpRequest :: new ( "GET" , "/" ) . with_header ( "Connection" , "upgrade" ) ;
457
+ assert ! ( !is_websocket_upgrade( & no_upgrade) ) ;
458
+ }
459
+ }
0 commit comments