Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 58 additions & 49 deletions ddtrace/contrib/internal/mcp/patch.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import os
import sys
from typing import Any
from typing import TYPE_CHECKING
from typing import Dict
from typing import Optional

import mcp


if TYPE_CHECKING:
from mcp.types import ClientRequest
from mcp.types import Request

from ddtrace import config
from ddtrace._trace.pin import Pin
from ddtrace._trace.span import Span
Expand All @@ -17,8 +22,7 @@
from ddtrace.internal.logger import get_logger
from ddtrace.internal.utils.formats import asbool
from ddtrace.llmobs._integrations.mcp import CLIENT_TOOL_CALL_OPERATION_NAME
from ddtrace.llmobs._integrations.mcp import REQUEST_RESPONDER_ENTER_OPERATION_NAME
from ddtrace.llmobs._integrations.mcp import REQUEST_RESPONDER_RESPOND_OPERATION_NAME
from ddtrace.llmobs._integrations.mcp import SERVER_REQUEST_OPERATION_NAME
from ddtrace.llmobs._integrations.mcp import SERVER_TOOL_CALL_OPERATION_NAME
from ddtrace.llmobs._integrations.mcp import MCPIntegration
from ddtrace.llmobs._utils import _get_attr
Expand All @@ -31,6 +35,7 @@
"mcp",
{
"distributed_tracing": asbool(os.getenv("DD_MCP_DISTRIBUTED_TRACING", default=True)),
"capture_intent": asbool(os.getenv("DD_MCP_CAPTURE_INTENT", default=False)),
},
)

Expand All @@ -45,7 +50,7 @@ def _supported_versions() -> Dict[str, str]:
return {"mcp": ">=1.10.0"}


def _set_distributed_headers_into_mcp_request(pin: Pin, request):
def _set_distributed_headers_into_mcp_request(pin: Pin, request: "ClientRequest") -> "ClientRequest":
"""Inject distributed tracing headers into MCP request metadata."""
span = pin.tracer.current_span()
if span is None:
Expand Down Expand Up @@ -85,19 +90,13 @@ def _set_distributed_headers_into_mcp_request(pin: Pin, request):
return request


def _extract_distributed_headers_from_mcp_request(kwargs: Dict[str, Any]) -> Optional[Dict[str, str]]:
if "context" not in kwargs:
return
context = kwargs.get("context")
if not context or not _get_attr(context, "request_context", None):
return
request_context = _get_attr(context, "request_context", None)
meta = _get_attr(request_context, "meta", None)
if not meta:
return
headers = _get_attr(meta, "_dd_trace_context", None)
if headers:
return headers
def _extract_distributed_headers_from_mcp_request(request_root: "Request") -> Optional[Dict[str, str]]:
"""Extract distributed tracing headers from MCP request params.meta field."""
request_params = _get_attr(request_root, "params", None)
meta = _get_attr(request_params, "meta", None) if request_params else None
meta_dict = meta.model_dump() if meta and hasattr(meta, "model_dump") else {}
headers = meta_dict.get("_dd_trace_context", {})
return headers if headers else None


@with_traced_module
Expand Down Expand Up @@ -142,30 +141,6 @@ async def traced_call_tool(mcp, pin: Pin, func, instance, args: tuple, kwargs: d
span.finish()


@with_traced_module
async def traced_tool_manager_call_tool(mcp, pin: Pin, func, instance, args: tuple, kwargs: dict):
integration = mcp._datadog_integration
if config.mcp.distributed_tracing:
activate_distributed_headers(pin.tracer, config.mcp, _extract_distributed_headers_from_mcp_request(kwargs))

span = integration.trace(pin, SERVER_TOOL_CALL_OPERATION_NAME, submit_to_llmobs=True)

try:
result = await func(*args, **kwargs)
integration.llmobs_set_tags(
span, args=args, kwargs=kwargs, response=result, operation=SERVER_TOOL_CALL_OPERATION_NAME
)
return result
except Exception:
integration.llmobs_set_tags(
span, args=args, kwargs=kwargs, response=None, operation=SERVER_TOOL_CALL_OPERATION_NAME
)
span.set_exc_info(*sys.exc_info())
raise
finally:
span.finish()


@with_traced_module
async def traced_client_session_initialize(mcp, pin: Pin, func, instance, args: tuple, kwargs: dict):
integration: MCPIntegration = mcp._datadog_integration
Expand Down Expand Up @@ -234,44 +209,82 @@ async def traced_client_session_aexit(mcp, pin: Pin, func, instance, args: tuple

@with_traced_module
def traced_request_responder_enter(mcp, pin: Pin, func, instance, args: tuple, kwargs: dict):
from mcp.types import CallToolRequest
from mcp.types import InitializeRequest

integration: MCPIntegration = mcp._datadog_integration
request_wrapper = _get_attr(instance, "request", None)
request_root = _get_attr(request_wrapper, "root", None)

# While this patch can trace all requests, we only trace this type right now
if not request_root or not isinstance(request_root, InitializeRequest):
# While this patch can trace all requests, we only trace these types right now
if not request_root or (
not isinstance(request_root, InitializeRequest) and not isinstance(request_root, CallToolRequest)
):
return func(*args, **kwargs)

# Activate distributed tracing if enabled for tool calls
if (
isinstance(request_root, CallToolRequest)
and config.mcp.distributed_tracing
and (headers := _extract_distributed_headers_from_mcp_request(request_root))
):
activate_distributed_headers(pin.tracer, config.mcp, headers)

operation_name = (
SERVER_TOOL_CALL_OPERATION_NAME if isinstance(request_root, CallToolRequest) else SERVER_REQUEST_OPERATION_NAME
)

span = integration.trace(
pin, REQUEST_RESPONDER_ENTER_OPERATION_NAME, submit_to_llmobs=True, span_name="mcp.initialize"
pin,
operation_name,
submit_to_llmobs=True,
span_name="mcp.{}".format(_get_attr(request_root, "method", "unknown")),
)
setattr(instance, "_dd_span", span)

if isinstance(request_root, CallToolRequest):
integration.process_ddtrace_argument(span, request_root)

return func(*args, **kwargs)


@with_traced_module
def traced_request_responder_exit(mcp, pin: Pin, func, instance, args: tuple, kwargs: dict):
span: Optional[Span] = getattr(instance, "_dd_span", None)
if span:
# Check if an exception occurred (__exit__ receives (exc_type, exc_val, exc_tb))
exc_type = args[0] if len(args) > 0 else None
exc_val = args[1] if len(args) > 1 else None
exc_tb = args[2] if len(args) > 2 else None

if exc_type is not None:
span.set_exc_info(exc_type, exc_val, exc_tb)

span.finish()
return func(*args, **kwargs)


@with_traced_module
async def traced_request_responder_respond(mcp, pin: Pin, func, instance, args: tuple, kwargs: dict):
from mcp.types import ListToolsResult

response_arg = args[0] if len(args) > 0 else None
response = getattr(response_arg, "root", None)
integration: MCPIntegration = mcp._datadog_integration
span: Optional[Span] = getattr(instance, "_dd_span", None)

if config.mcp.capture_intent and isinstance(response, ListToolsResult):
integration.inject_tools_list_response(response)

if span:
integration.llmobs_set_tags(
span,
args=args,
kwargs=dict(**kwargs, request_responder=instance),
response=None,
operation=REQUEST_RESPONDER_RESPOND_OPERATION_NAME,
operation=SERVER_REQUEST_OPERATION_NAME,
)

return await func(*args, **kwargs)


Expand All @@ -284,7 +297,6 @@ def patch():
mcp._datadog_integration = MCPIntegration(integration_config=config.mcp)

from mcp.client.session import ClientSession
from mcp.server.fastmcp.tools.tool_manager import ToolManager
from mcp.shared.session import BaseSession
from mcp.shared.session import RequestResponder

Expand All @@ -294,7 +306,6 @@ def patch():
wrap(ClientSession, "call_tool", traced_call_tool(mcp))
wrap(ClientSession, "list_tools", traced_client_session_list_tools(mcp))
wrap(ClientSession, "initialize", traced_client_session_initialize(mcp))
wrap(ToolManager, "call_tool", traced_tool_manager_call_tool(mcp))
wrap(RequestResponder, "__enter__", traced_request_responder_enter(mcp))
wrap(RequestResponder, "__exit__", traced_request_responder_exit(mcp))
wrap(RequestResponder, "respond", traced_request_responder_respond(mcp))
Expand All @@ -307,7 +318,6 @@ def unpatch():
mcp.__datadog_patch = False

from mcp.client.session import ClientSession
from mcp.server.fastmcp.tools.tool_manager import ToolManager
from mcp.shared.session import BaseSession
from mcp.shared.session import RequestResponder

Expand All @@ -317,7 +327,6 @@ def unpatch():
unwrap(ClientSession, "call_tool")
unwrap(ClientSession, "list_tools")
unwrap(ClientSession, "initialize")
unwrap(ToolManager, "call_tool")
unwrap(RequestResponder, "__enter__")
unwrap(RequestResponder, "__exit__")
unwrap(RequestResponder, "respond")
Expand Down
2 changes: 2 additions & 0 deletions ddtrace/llmobs/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
OUTPUT_MESSAGES = "_ml_obs.meta.output.messages"
OUTPUT_VALUE = "_ml_obs.meta.output.value"

INTENT = "_ml_obs.meta.intent"

SPAN_START_WHILE_DISABLED_WARNING = (
"Span started with LLMObs disabled."
" If using ddtrace-run, ensure DD_LLMOBS_ENABLED is set to 1. Else, use LLMObs.enable()."
Expand Down
Loading