2121 MCPBaseTransport ,
2222 SSETransport ,
2323 StdioTransport ,
24+ TimeoutConfig ,
2425)
2526
2627logger = get_logger ("chuk_tool_processor.mcp.stream_manager" )
@@ -38,15 +39,15 @@ class StreamManager:
3839 - HTTP Streamable (modern replacement for SSE, spec 2025-03-26) with graceful headers handling
3940 """
4041
41- def __init__ (self ) -> None :
42+ def __init__ (self , timeout_config : TimeoutConfig | None = None ) -> None :
4243 self .transports : dict [str , MCPBaseTransport ] = {}
4344 self .server_info : list [dict [str , Any ]] = []
4445 self .tool_to_server_map : dict [str , str ] = {}
4546 self .server_names : dict [int , str ] = {}
4647 self .all_tools : list [dict [str , Any ]] = []
4748 self ._lock = asyncio .Lock ()
4849 self ._closed = False # Track if we've been closed
49- self ._shutdown_timeout = 2.0 # Maximum time to spend on shutdown
50+ self .timeout_config = timeout_config or TimeoutConfig ()
5051
5152 # ------------------------------------------------------------------ #
5253 # factory helpers with enhanced error handling #
@@ -251,8 +252,12 @@ async def initialize(
251252 self .transports [server_name ] = transport
252253
253254 # Ping and get tools with timeout protection (use longer timeouts for slow servers)
254- status = "Up" if await asyncio .wait_for (transport .send_ping (), timeout = 30.0 ) else "Down"
255- tools = await asyncio .wait_for (transport .get_tools (), timeout = 30.0 )
255+ status = (
256+ "Up"
257+ if await asyncio .wait_for (transport .send_ping (), timeout = self .timeout_config .operation )
258+ else "Down"
259+ )
260+ tools = await asyncio .wait_for (transport .get_tools (), timeout = self .timeout_config .operation )
256261
257262 for t in tools :
258263 name = t .get ("name" )
@@ -333,8 +338,12 @@ async def initialize_with_sse(
333338
334339 self .transports [name ] = transport
335340 # Use longer timeouts for slow servers (ping can take time after initialization)
336- status = "Up" if await asyncio .wait_for (transport .send_ping (), timeout = 30.0 ) else "Down"
337- tools = await asyncio .wait_for (transport .get_tools (), timeout = 30.0 )
341+ status = (
342+ "Up"
343+ if await asyncio .wait_for (transport .send_ping (), timeout = self .timeout_config .operation )
344+ else "Down"
345+ )
346+ tools = await asyncio .wait_for (transport .get_tools (), timeout = self .timeout_config .operation )
338347
339348 for t in tools :
340349 tname = t .get ("name" )
@@ -415,8 +424,12 @@ async def initialize_with_http_streamable(
415424
416425 self .transports [name ] = transport
417426 # Use longer timeouts for slow servers (ping can take time after initialization)
418- status = "Up" if await asyncio .wait_for (transport .send_ping (), timeout = 30.0 ) else "Down"
419- tools = await asyncio .wait_for (transport .get_tools (), timeout = 30.0 )
427+ status = (
428+ "Up"
429+ if await asyncio .wait_for (transport .send_ping (), timeout = self .timeout_config .operation )
430+ else "Down"
431+ )
432+ tools = await asyncio .wait_for (transport .get_tools (), timeout = self .timeout_config .operation )
420433
421434 for t in tools :
422435 tname = t .get ("name" )
@@ -462,7 +475,7 @@ async def list_tools(self, server_name: str) -> list[dict[str, Any]]:
462475 transport = self .transports [server_name ]
463476
464477 try :
465- tools = await asyncio .wait_for (transport .get_tools (), timeout = 10.0 )
478+ tools = await asyncio .wait_for (transport .get_tools (), timeout = self . timeout_config . operation )
466479 logger .debug ("Found %d tools for server %s" , len (tools ), server_name )
467480 return tools
468481 except TimeoutError :
@@ -481,7 +494,7 @@ async def ping_servers(self) -> list[dict[str, Any]]:
481494
482495 async def _ping_one (name : str , tr : MCPBaseTransport ):
483496 try :
484- ok = await asyncio .wait_for (tr .send_ping (), timeout = 5.0 )
497+ ok = await asyncio .wait_for (tr .send_ping (), timeout = self . timeout_config . quick )
485498 except Exception :
486499 ok = False
487500 return {"server" : name , "ok" : ok }
@@ -496,7 +509,7 @@ async def list_resources(self) -> list[dict[str, Any]]:
496509
497510 async def _one (name : str , tr : MCPBaseTransport ):
498511 try :
499- res = await asyncio .wait_for (tr .list_resources (), timeout = 10.0 )
512+ res = await asyncio .wait_for (tr .list_resources (), timeout = self . timeout_config . operation )
500513 resources = res .get ("resources" , []) if isinstance (res , dict ) else res
501514 for item in resources :
502515 item = dict (item )
@@ -516,7 +529,7 @@ async def list_prompts(self) -> list[dict[str, Any]]:
516529
517530 async def _one (name : str , tr : MCPBaseTransport ):
518531 try :
519- res = await asyncio .wait_for (tr .list_prompts (), timeout = 10.0 )
532+ res = await asyncio .wait_for (tr .list_prompts (), timeout = self . timeout_config . operation )
520533 prompts = res .get ("prompts" , []) if isinstance (res , dict ) else res
521534 for item in prompts :
522535 item = dict (item )
@@ -643,7 +656,7 @@ async def _concurrent_close(self, transport_items: list[tuple[str, MCPBaseTransp
643656 try :
644657 results = await asyncio .wait_for (
645658 asyncio .gather (* [task for _ , task in close_tasks ], return_exceptions = True ),
646- timeout = self ._shutdown_timeout ,
659+ timeout = self .timeout_config . shutdown ,
647660 )
648661
649662 # Process results
@@ -666,7 +679,8 @@ async def _concurrent_close(self, transport_items: list[tuple[str, MCPBaseTransp
666679 # Brief wait for cancellations to complete
667680 with contextlib .suppress (TimeoutError ):
668681 await asyncio .wait_for (
669- asyncio .gather (* [task for _ , task in close_tasks ], return_exceptions = True ), timeout = 0.5
682+ asyncio .gather (* [task for _ , task in close_tasks ], return_exceptions = True ),
683+ timeout = self .timeout_config .shutdown ,
670684 )
671685
672686 async def _sequential_close (self , transport_items : list [tuple [str , MCPBaseTransport ]], close_results : list ) -> None :
@@ -675,7 +689,7 @@ async def _sequential_close(self, transport_items: list[tuple[str, MCPBaseTransp
675689 try :
676690 await asyncio .wait_for (
677691 self ._close_single_transport (name , transport ),
678- timeout = 0.5 , # Short timeout per transport
692+ timeout = self . timeout_config . shutdown ,
679693 )
680694 logger .debug ("Closed transport: %s" , name )
681695 close_results .append ((name , True , None ))
@@ -767,7 +781,7 @@ async def health_check(self) -> dict[str, Any]:
767781
768782 for name , transport in self .transports .items ():
769783 try :
770- ping_ok = await asyncio .wait_for (transport .send_ping (), timeout = 5.0 )
784+ ping_ok = await asyncio .wait_for (transport .send_ping (), timeout = self . timeout_config . quick )
771785 health_info ["transports" ][name ] = {
772786 "status" : "healthy" if ping_ok else "unhealthy" ,
773787 "ping_success" : ping_ok ,
0 commit comments