Skip to content

Commit 82c21b2

Browse files
authored
Merge pull request #1 from allenporter/mcp-proxy-tests
fix: update mcp-proxy to only enable capabilities that are exported by the server
2 parents c07d479 + 48ea353 commit 82c21b2

File tree

5 files changed

+269
-61
lines changed

5 files changed

+269
-61
lines changed

pyproject.toml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,17 @@ build-backend = "setuptools.build_meta"
1111

1212
[project.scripts]
1313
mcp-proxy = "mcp_proxy.__main__:main"
14+
15+
[tool.uv]
16+
dev-dependencies = [
17+
"pytest>=8.3.3",
18+
"pytest-asyncio>=0.25.0",
19+
]
20+
21+
[tool.pytest.ini_options]
22+
pythonpath = "src"
23+
addopts = [
24+
"--import-mode=importlib",
25+
]
26+
asyncio_mode = "auto"
27+
asyncio_default_fixture_loop_scope = "function"

src/mcp_proxy/__init__.py

Lines changed: 70 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -7,79 +7,89 @@
77
logger = logging.getLogger(__name__)
88

99

10-
async def confugure_app(name: str, remote_app: ClientSession):
11-
app = server.Server(name)
10+
async def create_proxy_server(remote_app: ClientSession):
11+
"""Create a server instance from a remote app."""
1212

13-
async def _list_prompts(_: t.Any) -> types.ServerResult:
14-
result = await remote_app.list_prompts()
15-
return types.ServerResult(result)
13+
response = await remote_app.initialize()
14+
capabilities = response.capabilities
1615

17-
app.request_handlers[types.ListPromptsRequest] = _list_prompts
16+
app = server.Server(response.serverInfo.name)
1817

19-
async def _get_prompt(req: types.GetPromptRequest) -> types.ServerResult:
20-
result = await remote_app.get_prompt(req.params.name, req.params.arguments)
21-
return types.ServerResult(result)
18+
if capabilities.prompts:
19+
async def _list_prompts(_: t.Any) -> types.ServerResult:
20+
result = await remote_app.list_prompts()
21+
return types.ServerResult(result)
2222

23-
app.request_handlers[types.GetPromptRequest] = _get_prompt
23+
app.request_handlers[types.ListPromptsRequest] = _list_prompts
2424

25-
async def _list_resources(_: t.Any) -> types.ServerResult:
26-
result = await remote_app.list_resources()
27-
return types.ServerResult(result)
25+
async def _get_prompt(req: types.GetPromptRequest) -> types.ServerResult:
26+
result = await remote_app.get_prompt(req.params.name, req.params.arguments)
27+
return types.ServerResult(result)
2828

29-
app.request_handlers[types.ListResourcesRequest] = _list_resources
29+
app.request_handlers[types.GetPromptRequest] = _get_prompt
3030

31-
# list_resource_templates() is not implemented in the client
32-
# async def _list_resource_templates(_: t.Any) -> types.ServerResult:
33-
# result = await remote_app.list_resource_templates()
34-
# return types.ServerResult(result)
31+
if capabilities.resources:
32+
async def _list_resources(_: t.Any) -> types.ServerResult:
33+
result = await remote_app.list_resources()
34+
return types.ServerResult(result)
3535

36-
# app.request_handlers[types.ListResourceTemplatesRequest] = _list_resource_templates
36+
app.request_handlers[types.ListResourcesRequest] = _list_resources
3737

38-
async def _read_resource(req: types.ReadResourceRequest):
39-
result = await remote_app.read_resource(req.params.uri)
40-
return types.ServerResult(result)
38+
# list_resource_templates() is not implemented in the client
39+
# async def _list_resource_templates(_: t.Any) -> types.ServerResult:
40+
# result = await remote_app.list_resource_templates()
41+
# return types.ServerResult(result)
4142

42-
app.request_handlers[types.ReadResourceRequest] = _read_resource
43+
# app.request_handlers[types.ListResourceTemplatesRequest] = _list_resource_templates
4344

44-
async def _set_logging_level(req: types.SetLevelRequest):
45-
await remote_app.set_logging_level(req.params.level)
46-
return types.ServerResult(types.EmptyResult())
45+
async def _read_resource(req: types.ReadResourceRequest):
46+
result = await remote_app.read_resource(req.params.uri)
47+
return types.ServerResult(result)
4748

48-
app.request_handlers[types.SetLevelRequest] = _set_logging_level
49+
app.request_handlers[types.ReadResourceRequest] = _read_resource
4950

50-
async def _subscribe_resource(req: types.SubscribeRequest):
51-
await remote_app.subscribe_resource(req.params.uri)
52-
return types.ServerResult(types.EmptyResult())
51+
if capabilities.logging:
52+
async def _set_logging_level(req: types.SetLevelRequest):
53+
await remote_app.set_logging_level(req.params.level)
54+
return types.ServerResult(types.EmptyResult())
5355

54-
app.request_handlers[types.SubscribeRequest] = _subscribe_resource
56+
app.request_handlers[types.SetLevelRequest] = _set_logging_level
5557

56-
async def _unsubscribe_resource(req: types.UnsubscribeRequest):
57-
await remote_app.unsubscribe_resource(req.params.uri)
58-
return types.ServerResult(types.EmptyResult())
58+
if capabilities.resources:
59+
async def _subscribe_resource(req: types.SubscribeRequest):
60+
await remote_app.subscribe_resource(req.params.uri)
61+
return types.ServerResult(types.EmptyResult())
5962

60-
app.request_handlers[types.UnsubscribeRequest] = _unsubscribe_resource
63+
app.request_handlers[types.SubscribeRequest] = _subscribe_resource
6164

62-
async def _list_tools(_: t.Any):
63-
tools = await remote_app.list_tools()
64-
return types.ServerResult(tools)
65+
async def _unsubscribe_resource(req: types.UnsubscribeRequest):
66+
await remote_app.unsubscribe_resource(req.params.uri)
67+
return types.ServerResult(types.EmptyResult())
6568

66-
app.request_handlers[types.ListToolsRequest] = _list_tools
69+
app.request_handlers[types.UnsubscribeRequest] = _unsubscribe_resource
6770

68-
async def _call_tool(req: types.CallToolRequest) -> types.ServerResult:
69-
try:
70-
result = await remote_app.call_tool(
71-
req.params.name, (req.params.arguments or {})
72-
)
73-
return types.ServerResult(result)
74-
except Exception as e:
75-
return types.ServerResult(
76-
types.CallToolResult(
77-
content=[types.TextContent(type="text", text=str(e))],
78-
isError=True,
71+
if capabilities.tools:
72+
async def _list_tools(_: t.Any):
73+
tools = await remote_app.list_tools()
74+
return types.ServerResult(tools)
75+
76+
app.request_handlers[types.ListToolsRequest] = _list_tools
77+
78+
async def _call_tool(req: types.CallToolRequest) -> types.ServerResult:
79+
try:
80+
result = await remote_app.call_tool(
81+
req.params.name, (req.params.arguments or {})
82+
)
83+
return types.ServerResult(result)
84+
except Exception as e:
85+
return types.ServerResult(
86+
types.CallToolResult(
87+
content=[types.TextContent(type="text", text=str(e))],
88+
isError=True,
89+
)
7990
)
80-
)
8191

82-
app.request_handlers[types.CallToolRequest] = _call_tool
92+
app.request_handlers[types.CallToolRequest] = _call_tool
8393

8494
async def _send_progress_notification(req: types.ProgressNotification):
8595
await remote_app.send_progress_notification(
@@ -96,19 +106,18 @@ async def _complete(req: types.CompleteRequest):
96106

97107
app.request_handlers[types.CompleteRequest] = _complete
98108

99-
async with server.stdio_server() as (read_stream, write_stream):
100-
await app.run(
101-
read_stream,
102-
write_stream,
103-
app.create_initialization_options(),
104-
)
109+
return app
105110

106111

107112
async def run_sse_client(url: str):
108113
from mcp.client.sse import sse_client
109114

110115
async with sse_client(url=url) as (read_stream, write_stream):
111116
async with ClientSession(read_stream, write_stream) as session:
112-
response = await session.initialize()
113-
114-
await confugure_app(response.serverInfo.name, session)
117+
app = await create_proxy_server(session)
118+
async with server.stdio_server() as (read_stream, write_stream):
119+
await app.run(
120+
read_stream,
121+
write_stream,
122+
app.create_initialization_options(),
123+
)

tests/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Tests for mcp-proxy."""

tests/test_init.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
"""Tests for the mcp-proxy module.
2+
3+
Tests are running in two modes:
4+
- One where the server is exercised directly though an in memory client, just to
5+
set a baseline for the expected behavior.
6+
- Another where the server is exercised through a proxy server, which forwards
7+
the requests to the original server.
8+
9+
The same test code is run on both to ensure parity.
10+
"""
11+
12+
from typing import Any
13+
from collections.abc import AsyncGenerator, Callable
14+
from contextlib import asynccontextmanager, AbstractAsyncContextManager
15+
16+
import pytest
17+
18+
from mcp import types
19+
from mcp.client.session import ClientSession
20+
from mcp.server import Server
21+
from mcp.shared.exceptions import McpError
22+
from mcp.shared.memory import create_connected_server_and_client_session
23+
24+
from mcp_proxy import create_proxy_server
25+
26+
TOOL_INPUT_SCHEMA = {
27+
"type": "object",
28+
"properties": {
29+
"input1": {"type": "string"}
30+
}
31+
}
32+
33+
SessionContextManager = Callable[[Server], AbstractAsyncContextManager[ClientSession]]
34+
35+
# Direct server connection
36+
in_memory: SessionContextManager = create_connected_server_and_client_session
37+
38+
@asynccontextmanager
39+
async def proxy(server: Server) -> AsyncGenerator[ClientSession, None]:
40+
"""Create a connection to the server through the proxy server."""
41+
async with in_memory(server) as session:
42+
wrapped_server = await create_proxy_server(session)
43+
async with in_memory(wrapped_server) as wrapped_session:
44+
yield wrapped_session
45+
46+
47+
@pytest.fixture(params=["server", "proxy"], scope="function")
48+
def session_generator(request: Any) -> SessionContextManager:
49+
"""Fixture that returns a client creation strategy either direct or using the proxy."""
50+
if request.param == "server":
51+
return in_memory
52+
return proxy
53+
54+
55+
async def test_list_prompts(session_generator: SessionContextManager):
56+
"""Test list_prompts."""
57+
58+
server = Server("prompt-server")
59+
60+
@server.list_prompts()
61+
async def list_prompts() -> list[types.Prompt]:
62+
return [types.Prompt(name="prompt1")]
63+
64+
async with session_generator(server) as session:
65+
result = await session.initialize()
66+
assert result.serverInfo.name == "prompt-server"
67+
assert result.capabilities
68+
assert result.capabilities.prompts
69+
assert not result.capabilities.tools
70+
assert not result.capabilities.resources
71+
assert not result.capabilities.logging
72+
73+
result = await session.list_prompts()
74+
assert result.prompts == [types.Prompt(name="prompt1")]
75+
76+
with pytest.raises(McpError, match="Method not found"):
77+
await session.list_tools()
78+
79+
80+
async def test_list_tools(session_generator: SessionContextManager):
81+
"""Test list_tools."""
82+
83+
server = Server("tools-server")
84+
85+
@server.list_tools()
86+
async def list_tools() -> list[types.Tool]:
87+
return [types.Tool(
88+
name="tool-name",
89+
description="tool-description",
90+
inputSchema=TOOL_INPUT_SCHEMA
91+
)]
92+
93+
async with session_generator(server) as session:
94+
result = await session.initialize()
95+
assert result.serverInfo.name == "tools-server"
96+
assert result.capabilities
97+
assert result.capabilities.tools
98+
assert not result.capabilities.prompts
99+
assert not result.capabilities.resources
100+
assert not result.capabilities.logging
101+
102+
result = await session.list_tools()
103+
assert len(result.tools) == 1
104+
assert result.tools[0].name == "tool-name"
105+
assert result.tools[0].description == "tool-description"
106+
assert result.tools[0].inputSchema == TOOL_INPUT_SCHEMA
107+
108+
with pytest.raises(McpError, match="Method not found"):
109+
await session.list_prompts()

0 commit comments

Comments
 (0)