Skip to content

Commit eb7e010

Browse files
beariceclaude
andcommitted
Improve WebSocket detection with proper header parsing
- Extract WebSocket upgrade detection into dedicated function - Replace substring matching with proper token parsing for Connection header - Use exact matching for Upgrade header instead of contains() - Add comprehensive test coverage for edge cases - Improve RFC 6455 compliance and prevent false positives 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]>
1 parent 7c5c236 commit eb7e010

File tree

1 file changed

+65
-8
lines changed

1 file changed

+65
-8
lines changed

src/common/http_proxy.rs

Lines changed: 65 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,20 @@ use crate::{
1919

2020
use super::frames::{FrameIO, frames_from_stream};
2121

22+
// Helper function to check if a request is a WebSocket upgrade
23+
fn is_websocket_upgrade(request: &HttpRequest) -> bool {
24+
let connection = request.header("Connection", "").to_lowercase();
25+
let upgrade = request.header("Upgrade", "").to_lowercase();
26+
27+
// Check if Connection header contains "upgrade" as a separate token
28+
let has_upgrade_connection = connection.split(',').any(|token| token.trim() == "upgrade");
29+
30+
// Check if Upgrade header is exactly "websocket"
31+
let has_websocket_upgrade = upgrade == "websocket";
32+
33+
has_upgrade_connection && has_websocket_upgrade
34+
}
35+
2236
// Helper function to send error response to client
2337
async fn send_error_response(
2438
client_stream: &mut IOBufStream,
@@ -296,14 +310,7 @@ impl ContextCallback for HttpForwardCallback {
296310
let mut request = request.unwrap().as_ref().clone();
297311

298312
// Only add Connection: close for regular HTTP requests, not WebSocket upgrades
299-
let is_websocket_upgrade = request
300-
.header("Connection", "")
301-
.to_lowercase()
302-
.contains("upgrade")
303-
&& request
304-
.header("Upgrade", "")
305-
.to_lowercase()
306-
.contains("websocket");
313+
let is_websocket_upgrade = is_websocket_upgrade(&request);
307314

308315
if !is_websocket_upgrade {
309316
request = request.with_header("Connection", "close");
@@ -400,3 +407,53 @@ impl ContextCallback for FrameChannelCallback {
400407
}
401408
}
402409
}
410+
411+
#[cfg(test)]
412+
mod tests {
413+
use super::*;
414+
415+
#[test]
416+
fn test_websocket_detection() {
417+
// Valid WebSocket upgrade request
418+
let ws_request = HttpRequest::new("GET", "/")
419+
.with_header("Connection", "upgrade")
420+
.with_header("Upgrade", "websocket");
421+
assert!(is_websocket_upgrade(&ws_request));
422+
423+
// Valid WebSocket upgrade with multiple Connection header values
424+
let ws_request_multi = HttpRequest::new("GET", "/")
425+
.with_header("Connection", "keep-alive, upgrade")
426+
.with_header("Upgrade", "websocket");
427+
assert!(is_websocket_upgrade(&ws_request_multi));
428+
429+
// Valid WebSocket upgrade with different casing
430+
let ws_request_case = HttpRequest::new("GET", "/")
431+
.with_header("Connection", "Upgrade")
432+
.with_header("Upgrade", "WebSocket");
433+
assert!(is_websocket_upgrade(&ws_request_case));
434+
435+
// Invalid: contains "upgrade" but not as separate token
436+
let invalid_contains = HttpRequest::new("GET", "/")
437+
.with_header("Connection", "keep-alive-upgrade")
438+
.with_header("Upgrade", "websocket");
439+
assert!(!is_websocket_upgrade(&invalid_contains));
440+
441+
// Invalid: Upgrade header contains websocket but not exactly
442+
let invalid_upgrade = HttpRequest::new("GET", "/")
443+
.with_header("Connection", "upgrade")
444+
.with_header("Upgrade", "websocket-custom");
445+
assert!(!is_websocket_upgrade(&invalid_upgrade));
446+
447+
// Invalid: Regular HTTP request
448+
let http_request = HttpRequest::new("GET", "/").with_header("Connection", "keep-alive");
449+
assert!(!is_websocket_upgrade(&http_request));
450+
451+
// Invalid: Missing Connection header
452+
let no_connection = HttpRequest::new("GET", "/").with_header("Upgrade", "websocket");
453+
assert!(!is_websocket_upgrade(&no_connection));
454+
455+
// Invalid: Missing Upgrade header
456+
let no_upgrade = HttpRequest::new("GET", "/").with_header("Connection", "upgrade");
457+
assert!(!is_websocket_upgrade(&no_upgrade));
458+
}
459+
}

0 commit comments

Comments
 (0)