Skip to content

Commit 607bcdd

Browse files
committed
fixed support for oauth headers passthrough
1 parent 02e1e30 commit 607bcdd

20 files changed

+1418
-92
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "chuk-tool-processor"
7-
version = "0.6.14"
7+
version = "0.6.15"
88
description = "Async-native framework for registering, discovering, and executing tools referenced in LLM responses"
99
readme = "README.md"
1010
requires-python = ">=3.11"

src/chuk_tool_processor/mcp/stream_manager.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -369,12 +369,11 @@ async def initialize_with_http_streamable(
369369
"session_id": cfg.get("session_id"),
370370
}
371371

372-
# Handle headers if provided (for future HTTPStreamableTransport support)
372+
# Handle headers if provided
373373
headers = cfg.get("headers", {})
374374
if headers:
375-
logger.debug("HTTP Streamable %s: Headers provided but not yet supported in transport", name)
376-
# TODO: Add headers support when HTTPStreamableTransport is updated
377-
# transport_params['headers'] = headers
375+
transport_params["headers"] = headers
376+
logger.debug("HTTP Streamable %s: Custom headers configured: %s", name, list(headers.keys()))
378377

379378
transport = HTTPStreamableTransport(**transport_params)
380379

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
"""Tests for HTTP Streamable transport with custom headers support."""
2+
3+
from unittest.mock import AsyncMock, patch
4+
5+
import pytest
6+
7+
from chuk_tool_processor.mcp.stream_manager import StreamManager
8+
from chuk_tool_processor.mcp.transport.http_streamable_transport import HTTPStreamableTransport
9+
10+
11+
class TestHTTPStreamableTransportHeaders:
12+
"""Tests for HTTPStreamableTransport custom headers functionality."""
13+
14+
def test_init_with_headers(self):
15+
"""Test transport initialization with custom headers."""
16+
headers = {"Authorization": "Bearer test-token", "X-Custom-Header": "custom-value"}
17+
18+
transport = HTTPStreamableTransport(url="https://example.com/mcp", headers=headers)
19+
20+
assert transport.configured_headers == headers
21+
assert "Authorization" in transport.configured_headers
22+
assert transport.configured_headers["Authorization"] == "Bearer test-token"
23+
24+
def test_init_without_headers(self):
25+
"""Test transport initialization without custom headers."""
26+
transport = HTTPStreamableTransport(url="https://example.com/mcp")
27+
28+
assert transport.configured_headers == {}
29+
30+
def test_headers_with_api_key(self):
31+
"""Test that both headers and api_key can coexist."""
32+
headers = {"X-Custom": "value"}
33+
34+
transport = HTTPStreamableTransport(url="https://example.com/mcp", api_key="test-key", headers=headers)
35+
36+
assert transport.api_key == "test-key"
37+
assert transport.configured_headers == headers
38+
39+
@pytest.mark.asyncio
40+
async def test_headers_used_in_requests(self):
41+
"""Test that configured headers are used in HTTP requests."""
42+
headers = {"Authorization": "Bearer oauth-token"}
43+
44+
with patch("chuk_tool_processor.mcp.transport.http_streamable_transport.http_client") as mock_client:
45+
# Mock the http_client context manager
46+
mock_read = AsyncMock()
47+
mock_write = AsyncMock()
48+
mock_context = AsyncMock()
49+
mock_context.__aenter__ = AsyncMock(return_value=(mock_read, mock_write))
50+
mock_context.__aexit__ = AsyncMock(return_value=None)
51+
52+
# http_client is called as a function returning a context manager
53+
mock_client.return_value = mock_context
54+
55+
# Mock the send_initialize and send_ping functions
56+
with (
57+
patch("chuk_tool_processor.mcp.transport.http_streamable_transport.send_initialize") as mock_init,
58+
patch("chuk_tool_processor.mcp.transport.http_streamable_transport.send_ping") as mock_ping,
59+
):
60+
mock_init.return_value = None
61+
mock_ping.return_value = True
62+
63+
transport = HTTPStreamableTransport(url="https://example.com/mcp", headers=headers)
64+
65+
# Initialize should use the headers
66+
await transport.initialize()
67+
68+
# Verify http_client was called
69+
assert mock_client.called
70+
71+
# Get the call arguments
72+
call_args = mock_client.call_args
73+
74+
# The first argument should be StreamableHTTPParameters with headers
75+
http_params = call_args[0][0]
76+
assert hasattr(http_params, "headers")
77+
assert "Authorization" in http_params.headers
78+
assert http_params.headers["Authorization"] == "Bearer oauth-token"
79+
80+
81+
class TestStreamManagerHeadersPassthrough:
82+
"""Tests for StreamManager passing headers to HTTP transport."""
83+
84+
@pytest.mark.asyncio
85+
async def test_headers_passthrough_to_http_transport(self):
86+
"""Test that headers from server config are passed to HTTPStreamableTransport."""
87+
servers = [
88+
{
89+
"name": "test-server",
90+
"url": "https://example.com/mcp",
91+
"headers": {"Authorization": "Bearer test-token", "X-Custom": "value"},
92+
}
93+
]
94+
95+
with patch("chuk_tool_processor.mcp.stream_manager.HTTPStreamableTransport") as MockTransport:
96+
# Create mock transport instance
97+
mock_transport = AsyncMock()
98+
mock_transport.initialize = AsyncMock(return_value=True)
99+
mock_transport.send_ping = AsyncMock(return_value=True)
100+
mock_transport.get_tools = AsyncMock(return_value=[{"name": "test_tool", "description": "Test tool"}])
101+
102+
MockTransport.return_value = mock_transport
103+
104+
# Initialize stream manager
105+
stream_manager = StreamManager()
106+
await stream_manager.initialize_with_http_streamable(
107+
servers=servers, connection_timeout=5.0, default_timeout=5.0
108+
)
109+
110+
# Verify HTTPStreamableTransport was called with headers
111+
MockTransport.assert_called_once()
112+
call_kwargs = MockTransport.call_args.kwargs
113+
114+
assert "headers" in call_kwargs
115+
assert call_kwargs["headers"] == {"Authorization": "Bearer test-token", "X-Custom": "value"}
116+
117+
# Verify transport was initialized
118+
assert mock_transport.initialize.called
119+
120+
@pytest.mark.asyncio
121+
async def test_no_headers_still_works(self):
122+
"""Test that servers without headers still work."""
123+
servers = [{"name": "test-server", "url": "https://example.com/mcp"}]
124+
125+
with patch("chuk_tool_processor.mcp.stream_manager.HTTPStreamableTransport") as MockTransport:
126+
mock_transport = AsyncMock()
127+
mock_transport.initialize = AsyncMock(return_value=True)
128+
mock_transport.send_ping = AsyncMock(return_value=True)
129+
mock_transport.get_tools = AsyncMock(return_value=[])
130+
131+
MockTransport.return_value = mock_transport
132+
133+
stream_manager = StreamManager()
134+
await stream_manager.initialize_with_http_streamable(
135+
servers=servers, connection_timeout=5.0, default_timeout=5.0
136+
)
137+
138+
# Should be called without headers parameter
139+
MockTransport.assert_called_once()
140+
call_kwargs = MockTransport.call_args.kwargs
141+
142+
# Headers should not be in kwargs or should be empty dict
143+
assert call_kwargs.get("headers", {}) == {}
144+
145+
@pytest.mark.asyncio
146+
async def test_multiple_servers_with_different_headers(self):
147+
"""Test multiple servers with different header configurations."""
148+
servers = [
149+
{"name": "server1", "url": "https://server1.com/mcp", "headers": {"Authorization": "Bearer token1"}},
150+
{"name": "server2", "url": "https://server2.com/mcp", "headers": {"Authorization": "Bearer token2"}},
151+
{"name": "server3", "url": "https://server3.com/mcp"},
152+
]
153+
154+
with patch("chuk_tool_processor.mcp.stream_manager.HTTPStreamableTransport") as MockTransport:
155+
mock_transport = AsyncMock()
156+
mock_transport.initialize = AsyncMock(return_value=True)
157+
mock_transport.send_ping = AsyncMock(return_value=True)
158+
mock_transport.get_tools = AsyncMock(return_value=[])
159+
160+
MockTransport.return_value = mock_transport
161+
162+
stream_manager = StreamManager()
163+
await stream_manager.initialize_with_http_streamable(
164+
servers=servers, connection_timeout=5.0, default_timeout=5.0
165+
)
166+
167+
# Should be called 3 times
168+
assert MockTransport.call_count == 3
169+
170+
# Check each call
171+
calls = MockTransport.call_args_list
172+
173+
# Server 1 - with headers
174+
assert calls[0].kwargs.get("headers") == {"Authorization": "Bearer token1"}
175+
176+
# Server 2 - with different headers
177+
assert calls[1].kwargs.get("headers") == {"Authorization": "Bearer token2"}
178+
179+
# Server 3 - no headers
180+
assert calls[2].kwargs.get("headers", {}) == {}
181+
182+
183+
class TestOAuthHeadersIntegration:
184+
"""Integration tests for OAuth-style headers."""
185+
186+
@pytest.mark.asyncio
187+
async def test_oauth_bearer_token(self):
188+
"""Test OAuth bearer token in Authorization header."""
189+
servers = [
190+
{
191+
"name": "oauth-server",
192+
"url": "https://api.example.com/mcp",
193+
"headers": {"Authorization": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."},
194+
}
195+
]
196+
197+
with patch("chuk_tool_processor.mcp.stream_manager.HTTPStreamableTransport") as MockTransport:
198+
mock_transport = AsyncMock()
199+
mock_transport.initialize = AsyncMock(return_value=True)
200+
mock_transport.send_ping = AsyncMock(return_value=True)
201+
mock_transport.get_tools = AsyncMock(return_value=[])
202+
203+
MockTransport.return_value = mock_transport
204+
205+
stream_manager = StreamManager()
206+
await stream_manager.initialize_with_http_streamable(servers=servers)
207+
208+
# Verify OAuth header was passed
209+
call_kwargs = MockTransport.call_args.kwargs
210+
assert "Authorization" in call_kwargs.get("headers", {})
211+
assert call_kwargs["headers"]["Authorization"].startswith("Bearer ")
212+
213+
@pytest.mark.asyncio
214+
async def test_multiple_auth_headers(self):
215+
"""Test multiple authentication-related headers."""
216+
servers = [
217+
{
218+
"name": "multi-auth-server",
219+
"url": "https://api.example.com/mcp",
220+
"headers": {"Authorization": "Bearer token", "X-API-Key": "api-key-123", "X-Session-Id": "session-abc"},
221+
}
222+
]
223+
224+
with patch("chuk_tool_processor.mcp.stream_manager.HTTPStreamableTransport") as MockTransport:
225+
mock_transport = AsyncMock()
226+
mock_transport.initialize = AsyncMock(return_value=True)
227+
mock_transport.send_ping = AsyncMock(return_value=True)
228+
mock_transport.get_tools = AsyncMock(return_value=[])
229+
230+
MockTransport.return_value = mock_transport
231+
232+
stream_manager = StreamManager()
233+
await stream_manager.initialize_with_http_streamable(servers=servers)
234+
235+
# Verify all headers were passed
236+
call_kwargs = MockTransport.call_args.kwargs
237+
headers = call_kwargs.get("headers", {})
238+
239+
assert headers["Authorization"] == "Bearer token"
240+
assert headers["X-API-Key"] == "api-key-123"
241+
assert headers["X-Session-Id"] == "session-abc"

tests/mcp/test_mcp_tool.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,6 @@ async def test_execute_with_dict_arguments(self, simple_mcp_tool, mock_stream_ma
422422
assert result == "dict args work"
423423
mock_stream_manager.call_tool.assert_called_once()
424424

425-
426425
def test_serialization_preserves_recovery_config(self, simple_mcp_tool):
427426
"""Test serialization preserves recovery config."""
428427
from chuk_tool_processor.mcp.mcp_tool import RecoveryConfig

tests/mcp/test_register_mcp_tools.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,6 @@ def _patch_registry(self, mock_registry):
451451
@pytest.fixture
452452
def mock_registry_with_tools(self):
453453
"""Mock registry with some registered tools."""
454-
from chuk_tool_processor.mcp.register_mcp_tools import update_mcp_tools_stream_manager
455454

456455
reg = Mock(spec=ToolRegistryInterface)
457456
reg.list_tools = AsyncMock(return_value=[("mcp", "tool1"), ("mcp", "tool2"), ("other", "tool3")])

tests/mcp/test_setup_mcp_http_streamable.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,7 @@ async def test_setup_mcp_http_streamable_basic(self):
3838
{"name": "geocoding", "url": "http://geo.com"},
3939
]
4040

41-
processor, stream_manager = await setup_mcp_http_streamable(
42-
servers=servers, namespace="http_test"
43-
)
41+
processor, stream_manager = await setup_mcp_http_streamable(servers=servers, namespace="http_test")
4442

4543
assert processor == mock_processor
4644
assert stream_manager == mock_stream_manager
@@ -59,9 +57,7 @@ async def test_setup_with_custom_options(self):
5957
"chuk_tool_processor.mcp.setup_mcp_http_streamable.register_mcp_tools",
6058
AsyncMock(return_value=["tool1"]),
6159
) as mock_register,
62-
patch(
63-
"chuk_tool_processor.mcp.setup_mcp_http_streamable.ToolProcessor"
64-
) as mock_processor_class,
60+
patch("chuk_tool_processor.mcp.setup_mcp_http_streamable.ToolProcessor") as mock_processor_class,
6561
):
6662
servers = [{"name": "test", "url": "http://test.com", "api_key": "test-key"}]
6763
server_names = {0: "custom_name"}
@@ -195,9 +191,7 @@ async def test_setup_passes_all_processor_options(self):
195191
"chuk_tool_processor.mcp.setup_mcp_http_streamable.register_mcp_tools",
196192
AsyncMock(return_value=[]),
197193
),
198-
patch(
199-
"chuk_tool_processor.mcp.setup_mcp_http_streamable.ToolProcessor"
200-
) as mock_processor_class,
194+
patch("chuk_tool_processor.mcp.setup_mcp_http_streamable.ToolProcessor") as mock_processor_class,
201195
):
202196
servers = [{"name": "test", "url": "http://test.com"}]
203197

tests/mcp/test_stream_manager.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -390,9 +390,7 @@ async def test_create_with_sse_factory(self):
390390
mock_transport.get_tools = AsyncMock(return_value=[])
391391
mock_sse.return_value = mock_transport
392392

393-
stream_manager = await StreamManager.create_with_sse(
394-
servers=[{"name": "test", "url": "http://test.com"}]
395-
)
393+
stream_manager = await StreamManager.create_with_sse(servers=[{"name": "test", "url": "http://test.com"}])
396394

397395
assert stream_manager is not None
398396
assert isinstance(stream_manager, StreamManager)
@@ -762,9 +760,7 @@ async def test_initialize_with_http_streamable_init_failure(self, stream_manager
762760
mock_transport.initialize = AsyncMock(return_value=False)
763761
mock_http.return_value = mock_transport
764762

765-
await stream_manager.initialize_with_http_streamable(
766-
servers=[{"name": "test", "url": "http://test.com"}]
767-
)
763+
await stream_manager.initialize_with_http_streamable(servers=[{"name": "test", "url": "http://test.com"}])
768764

769765
assert "test" not in stream_manager.transports
770766

@@ -798,21 +794,30 @@ async def test_initialize_with_http_streamable_exception(self, stream_manager):
798794
assert "test" not in stream_manager.transports
799795

800796
@pytest.mark.asyncio
801-
async def test_initialize_with_http_streamable_headers_warning(self, stream_manager):
802-
"""Test initialize_with_http_streamable logs warning for headers."""
797+
async def test_initialize_with_http_streamable_headers_support(self, stream_manager):
798+
"""Test initialize_with_http_streamable passes headers correctly."""
803799
with patch("chuk_tool_processor.mcp.stream_manager.HTTPStreamableTransport") as mock_http:
804800
mock_transport = AsyncMock(spec=MCPBaseTransport)
805801
mock_transport.initialize = AsyncMock(return_value=True)
802+
mock_transport.send_ping = AsyncMock(return_value=True)
806803
mock_transport.get_tools = AsyncMock(return_value=[])
807804
mock_http.return_value = mock_transport
808805

809806
await stream_manager.initialize_with_http_streamable(
810-
servers=[{"name": "test", "url": "http://test.com", "headers": {"Custom": "Header"}}]
807+
servers=[
808+
{
809+
"name": "test",
810+
"url": "http://test.com",
811+
"headers": {"Custom": "Header", "Authorization": "Bearer token"},
812+
}
813+
]
811814
)
812815

813-
# Headers should not be passed (not supported yet)
816+
# Headers should now be passed
814817
call_kwargs = mock_http.call_args[1]
815-
assert "headers" not in call_kwargs
818+
assert "headers" in call_kwargs
819+
assert call_kwargs["headers"]["Custom"] == "Header"
820+
assert call_kwargs["headers"]["Authorization"] == "Bearer token"
816821

817822
# ------------------------------------------------------------------ #
818823
# Query methods tests #
@@ -1152,10 +1157,7 @@ def test_get_streams_when_closed(self, stream_manager):
11521157
def test_get_streams_fallback_to_attributes(self, stream_manager):
11531158
"""Test get_streams falls back to read/write stream attributes."""
11541159
# Create mock that doesn't have get_streams but has read/write_stream attributes
1155-
mock_transport = type('MockTransport', (), {
1156-
'read_stream': 'read',
1157-
'write_stream': 'write'
1158-
})()
1160+
mock_transport = type("MockTransport", (), {"read_stream": "read", "write_stream": "write"})()
11591161
stream_manager.transports["server1"] = mock_transport
11601162

11611163
streams = stream_manager.get_streams()

0 commit comments

Comments
 (0)