Skip to content

Commit 8e4fd56

Browse files
ehhuangEric Huang (AI Platform)
andauthored
feat: support client tool output metadata (#180)
Summary: Corresponding change on the python client for llamastack/llama-stack#1426 Test Plan: Tested in llamastack/llama-stack#1426 Co-authored-by: Eric Huang (AI Platform) <[email protected]>
1 parent e33aa4a commit 8e4fd56

File tree

2 files changed

+32
-21
lines changed

2 files changed

+32
-21
lines changed

src/llama_stack_client/lib/agents/agent.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from llama_stack_client import LlamaStackClient
99

10-
from llama_stack_client.types import ToolResponseMessage, UserMessage
10+
from llama_stack_client.types import ToolResponseMessage, ToolResponseParam, UserMessage
1111
from llama_stack_client.types.agent_create_params import AgentConfig
1212
from llama_stack_client.types.agents.turn import CompletionMessage, Turn
1313
from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
@@ -74,7 +74,7 @@ def _get_turn_id(self, chunk: AgentTurnResponseStreamChunk) -> Optional[str]:
7474

7575
return chunk.event.payload.turn.turn_id
7676

77-
def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage:
77+
def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseParam:
7878
assert len(tool_calls) == 1, "Only one tool call is supported"
7979
tool_call = tool_calls[0]
8080

@@ -101,20 +101,18 @@ def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage:
101101
tool_name=tool_call.tool_name,
102102
kwargs=tool_call.arguments,
103103
)
104-
tool_response_message = ToolResponseMessage(
104+
tool_response = ToolResponseParam(
105105
call_id=tool_call.call_id,
106106
tool_name=tool_call.tool_name,
107107
content=tool_result.content,
108-
role="tool",
109108
)
110-
return tool_response_message
109+
return tool_response
111110

112111
# cannot find tools
113-
return ToolResponseMessage(
112+
return ToolResponseParam(
114113
call_id=tool_call.call_id,
115114
tool_name=tool_call.tool_name,
116115
content=f"Unknown tool `{tool_call.tool_name}` was called.",
117-
role="tool",
118116
)
119117

120118
def create_turn(
@@ -176,14 +174,14 @@ def _create_turn_streaming(
176174
yield chunk
177175

178176
# run the tools
179-
tool_response_message = self._run_tool(tool_calls)
177+
tool_response = self._run_tool(tool_calls)
180178

181179
# pass it to next iteration
182180
turn_response = self.client.agents.turn.resume(
183181
agent_id=self.agent_id,
184182
session_id=session_id or self.session_id[-1],
185183
turn_id=turn_id,
186-
tool_responses=[tool_response_message],
184+
tool_responses=[tool_response],
187185
stream=True,
188186
)
189187
n_iter += 1

src/llama_stack_client/lib/agents/client_tool.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
import inspect
88
import json
99
from abc import abstractmethod
10-
from typing import Callable, Dict, get_args, get_origin, get_type_hints, List, TypeVar, Union
10+
from typing import Any, Callable, Dict, get_args, get_origin, get_type_hints, List, TypeVar, Union
1111

12-
from llama_stack_client.types import Message, ToolResponseMessage
12+
from llama_stack_client.types import Message, CompletionMessage, ToolResponse
1313
from llama_stack_client.types.tool_def_param import Parameter, ToolDefParam
1414

1515

@@ -63,28 +63,37 @@ def get_tool_definition(self) -> ToolDefParam:
6363
def run(
6464
self,
6565
message_history: List[Message],
66-
) -> ToolResponseMessage:
66+
) -> ToolResponse:
6767
# NOTE: we could override this method to use the entire message history for advanced tools
6868
last_message = message_history[-1]
69-
69+
assert isinstance(last_message, CompletionMessage), "Expected CompletionMessage"
7070
assert len(last_message.tool_calls) == 1, "Expected single tool call"
7171
tool_call = last_message.tool_calls[0]
7272

73+
metadata = {}
7374
try:
7475
response = self.run_impl(**tool_call.arguments)
75-
response_str = json.dumps(response, ensure_ascii=False)
76+
if isinstance(response, dict) and "content" in response:
77+
content = json.dumps(response["content"], ensure_ascii=False)
78+
metadata = response.get("metadata", {})
79+
else:
80+
content = json.dumps(response, ensure_ascii=False)
7681
except Exception as e:
77-
response_str = f"Error when running tool: {e}"
78-
79-
return ToolResponseMessage(
82+
content = f"Error when running tool: {e}"
83+
return ToolResponse(
8084
call_id=tool_call.call_id,
8185
tool_name=tool_call.tool_name,
82-
content=response_str,
83-
role="tool",
86+
content=content,
87+
metadata=metadata,
8488
)
8589

8690
@abstractmethod
87-
def run_impl(self, **kwargs):
91+
def run_impl(self, **kwargs) -> Any:
92+
"""
93+
Can return any json serializable object.
94+
To return metadata along with the response, return a dict with a "content" key, and a "metadata" key, where the "content" is the response that'll
95+
be serialized and passed to the model, and the "metadata" will be logged as metadata in the tool execution step within the Agent execution trace.
96+
"""
8897
raise NotImplementedError
8998

9099

@@ -107,6 +116,10 @@ def add(x: int, y: int) -> int:
107116
108117
Note that you must use RST-style docstrings with :param tags for each parameter. These will be used for prompting model to use tools correctly.
109118
:returns: tags in the docstring is optional as it would not be used for the tool's description.
119+
120+
Your function can return any json serializable object.
121+
To return metadata along with the response, return a dict with a "content" key, and a "metadata" key, where the "content" is the response that'll
122+
be serialized and passed to the model, and the "metadata" will be logged as metadata in the tool execution step within the Agent execution trace.
110123
"""
111124

112125
class _WrappedTool(ClientTool):
@@ -162,7 +175,7 @@ def get_params_definition(self) -> Dict[str, Parameter]:
162175

163176
return params
164177

165-
def run_impl(self, **kwargs):
178+
def run_impl(self, **kwargs) -> Any:
166179
return func(**kwargs)
167180

168181
return _WrappedTool()

0 commit comments

Comments
 (0)