Skip to content

Commit 900c7f4

Browse files
authored
[MCP Gateway] Litellm mcp pre and during guardrails (#13188)
* add guardrail support * add guardrail support * guardrails for MCP * added changes * add mcp guardrails * added test * add ui * fix guardrail form * working with cursor * remove print * fix mcp server tests * fix mypy and remove console logs * fix mypy and remove console logs * fix mypy tests
1 parent c125ae4 commit 900c7f4

File tree

38 files changed

+1294
-104
lines changed

38 files changed

+1294
-104
lines changed

enterprise/enterprise_hooks/aporia_ai.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ async def async_moderation_hook(
173173
"moderation",
174174
"audio_transcription",
175175
"responses",
176+
"mcp_call",
176177
],
177178
):
178179
from litellm.proxy.common_utils.callback_utils import (

enterprise/enterprise_hooks/google_text_moderation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ async def async_moderation_hook(
9595
"moderation",
9696
"audio_transcription",
9797
"responses",
98+
"mcp_call",
9899
],
99100
):
100101
"""

enterprise/enterprise_hooks/openai_moderation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ async def async_moderation_hook(
4242
"moderation",
4343
"audio_transcription",
4444
"responses",
45+
"mcp_call",
4546
],
4647
):
4748
text = ""

enterprise/litellm_enterprise/enterprise_callbacks/llama_guard.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ async def async_moderation_hook(
105105
"moderation",
106106
"audio_transcription",
107107
"responses",
108+
"mcp_call",
108109
],
109110
):
110111
"""

enterprise/litellm_enterprise/enterprise_callbacks/llm_guard.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ async def async_moderation_hook(
127127
"moderation",
128128
"audio_transcription",
129129
"responses",
130+
"mcp_call",
130131
],
131132
):
132133
"""

enterprise/litellm_enterprise/enterprise_callbacks/pagerduty/pagerduty.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ async def async_pre_call_hook(
147147
"audio_transcription",
148148
"pass_through_endpoint",
149149
"rerank",
150+
"mcp_call",
150151
],
151152
) -> Optional[Union[Exception, str, dict]]:
152153
"""

enterprise/litellm_enterprise/proxy/hooks/managed_files.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ async def async_pre_call_hook(
290290
"aretrieve_fine_tuning_job",
291291
"alist_fine_tuning_jobs",
292292
"acancel_fine_tuning_job",
293+
"mcp_call",
293294
],
294295
) -> Union[Exception, str, Dict, None]:
295296
"""

litellm/integrations/custom_guardrail.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,6 @@ def should_run_guardrail(
234234
Returns True if the guardrail should be run on the event_type
235235
"""
236236
requested_guardrails = self.get_guardrail_from_metadata(data)
237-
238237
verbose_logger.debug(
239238
"inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s self.default_on= %s",
240239
self.guardrail_name,
@@ -243,7 +242,6 @@ def should_run_guardrail(
243242
requested_guardrails,
244243
self.default_on,
245244
)
246-
247245
if self.default_on is True:
248246
if self._event_hook_is_event_type(event_type):
249247
if isinstance(self.event_hook, Mode):
@@ -287,7 +285,6 @@ def should_run_guardrail(
287285
)
288286
if result is not None:
289287
return result
290-
291288
return True
292289

293290
def _event_hook_is_event_type(self, event_type: GuardrailEventHooks) -> bool:

litellm/integrations/custom_logger.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ async def async_pre_call_hook(
281281
"audio_transcription",
282282
"pass_through_endpoint",
283283
"rerank",
284+
"mcp_call",
284285
],
285286
) -> Optional[
286287
Union[Exception, str, dict]
@@ -327,6 +328,7 @@ async def async_moderation_hook(
327328
"moderation",
328329
"audio_transcription",
329330
"responses",
331+
"mcp_call",
330332
],
331333
) -> Any:
332334
pass

litellm/proxy/_experimental/mcp_server/mcp_server_manager.py

Lines changed: 47 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from mcp.types import Tool as MCPTool
1818

1919
from litellm._logging import verbose_logger
20+
from litellm.exceptions import BlockedPiiEntityError, GuardrailRaisedException
21+
from fastapi import HTTPException
2022
from litellm.experimental_mcp_client.client import MCPClient
2123
from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import (
2224
MCPRequestHandler,
@@ -592,22 +594,22 @@ async def call_tool(
592594
"server_name": server_name_from_prefix,
593595
"user_api_key_auth": user_api_key_auth,
594596
}
595-
pre_hook_result = await proxy_logging_obj.async_pre_mcp_tool_call_hook(
596-
kwargs=pre_hook_kwargs,
597-
request_obj=None, # Will be created in the hook
598-
start_time=start_time,
599-
end_time=start_time,
600-
)
601-
602-
if pre_hook_result:
603-
# Check if the call should proceed
604-
if not pre_hook_result.get("should_proceed", True):
605-
error_message = pre_hook_result.get("error_message", "Tool call rejected by pre-hook")
606-
raise ValueError(error_message)
597+
try:
598+
pre_hook_result = await proxy_logging_obj.async_pre_mcp_tool_call_hook(
599+
kwargs=pre_hook_kwargs,
600+
request_obj=None, # Will be created in the hook
601+
start_time=start_time,
602+
end_time=start_time,
603+
)
607604

608-
# Apply any argument modifications
609-
if pre_hook_result.get("modified_arguments"):
610-
arguments = pre_hook_result["modified_arguments"]
605+
if pre_hook_result:
606+
# Apply any argument modifications
607+
if pre_hook_result.get("modified_arguments"):
608+
arguments = pre_hook_result["modified_arguments"]
609+
except (BlockedPiiEntityError, GuardrailRaisedException, HTTPException) as e:
610+
# Re-raise guardrail exceptions to properly fail the MCP call
611+
verbose_logger.error(f"Guardrail blocked MCP tool call pre call: {str(e)}")
612+
raise e
611613

612614
# Get server-specific auth header if available
613615
server_auth_header = None
@@ -627,6 +629,7 @@ async def call_tool(
627629
)
628630

629631
async with client:
632+
630633
# Use the original tool name (without prefix) for the actual call
631634
call_tool_params = MCPCallToolRequestParams(
632635
name=original_tool_name,
@@ -635,40 +638,39 @@ async def call_tool(
635638

636639
# Initialize during_hook_task as None
637640
during_hook_task = None
638-
641+
tasks = []
639642
# Start during hook if proxy_logging_obj is available
640643
if proxy_logging_obj:
641-
try:
642-
during_hook_task = asyncio.create_task(
643-
proxy_logging_obj.async_during_mcp_tool_call_hook(
644-
kwargs={
645-
"name": name,
646-
"arguments": arguments,
647-
"server_name": server_name_from_prefix,
648-
},
649-
request_obj=None, # Will be created in the hook
650-
start_time=start_time,
651-
end_time=start_time,
652-
)
644+
during_hook_task = asyncio.create_task(
645+
proxy_logging_obj.async_during_mcp_tool_call_hook(
646+
kwargs={
647+
"name": name,
648+
"arguments": arguments,
649+
"server_name": server_name_from_prefix,
650+
},
651+
request_obj=None, # Will be created in the hook
652+
start_time=start_time,
653+
end_time=start_time,
653654
)
654-
except Exception as e:
655-
verbose_logger.warning(f"During hook error (non-blocking): {str(e)}")
656-
657-
result = await client.call_tool(call_tool_params)
658-
659-
#########################################################
660-
# Check during hook result if it completed
661-
#########################################################
662-
if proxy_logging_obj and during_hook_task is not None:
663-
try:
664-
during_hook_result = await during_hook_task
665-
if during_hook_result and not during_hook_result.get("should_continue", True):
666-
error_message = during_hook_result.get("error_message", "Tool call cancelled by during-hook")
667-
raise ValueError(error_message)
668-
except Exception as e:
669-
verbose_logger.warning(f"During hook error (non-blocking): {str(e)}")
655+
)
656+
tasks.append(during_hook_task)
670657

671-
return result
658+
659+
tasks.append(asyncio.create_task(client.call_tool(call_tool_params)))
660+
try:
661+
662+
mcp_responses = await asyncio.gather(*tasks)
663+
664+
# If proxy_logging_obj is None, the tool call result is at index 0
665+
# If proxy_logging_obj is not None, the tool call result is at index 1 (after the during hook task)
666+
result_index = 1 if proxy_logging_obj else 0
667+
result = mcp_responses[result_index]
668+
669+
return cast(CallToolResult, result)
670+
except (BlockedPiiEntityError, GuardrailRaisedException, HTTPException) as e:
671+
# Re-raise guardrail exceptions to properly fail the MCP call
672+
verbose_logger.error(f"Guardrail blocked MCP tool call during result check: {str(e)}")
673+
raise e
672674

673675
#########################################################
674676
# End of Methods that call the upstream MCP servers

0 commit comments

Comments
 (0)