7
7
import inspect
8
8
import json
9
9
from abc import abstractmethod
10
- from typing import Callable , Dict , get_args , get_origin , get_type_hints , List , TypeVar , Union
10
+ from typing import Any , Callable , Dict , get_args , get_origin , get_type_hints , List , TypeVar , Union
11
11
12
- from llama_stack_client .types import Message , ToolResponseMessage
12
+ from llama_stack_client .types import Message , CompletionMessage , ToolResponse
13
13
from llama_stack_client .types .tool_def_param import Parameter , ToolDefParam
14
14
15
15
@@ -63,28 +63,37 @@ def get_tool_definition(self) -> ToolDefParam:
63
63
def run (
64
64
self ,
65
65
message_history : List [Message ],
66
- ) -> ToolResponseMessage :
66
+ ) -> ToolResponse :
67
67
# NOTE: we could override this method to use the entire message history for advanced tools
68
68
last_message = message_history [- 1 ]
69
-
69
+ assert isinstance ( last_message , CompletionMessage ), "Expected CompletionMessage"
70
70
assert len (last_message .tool_calls ) == 1 , "Expected single tool call"
71
71
tool_call = last_message .tool_calls [0 ]
72
72
73
+ metadata = {}
73
74
try :
74
75
response = self .run_impl (** tool_call .arguments )
75
- response_str = json .dumps (response , ensure_ascii = False )
76
+ if isinstance (response , dict ) and "content" in response :
77
+ content = json .dumps (response ["content" ], ensure_ascii = False )
78
+ metadata = response .get ("metadata" , {})
79
+ else :
80
+ content = json .dumps (response , ensure_ascii = False )
76
81
except Exception as e :
77
- response_str = f"Error when running tool: { e } "
78
-
79
- return ToolResponseMessage (
82
+ content = f"Error when running tool: { e } "
83
+ return ToolResponse (
80
84
call_id = tool_call .call_id ,
81
85
tool_name = tool_call .tool_name ,
82
- content = response_str ,
83
- role = "tool" ,
86
+ content = content ,
87
+ metadata = metadata ,
84
88
)
85
89
86
90
@abstractmethod
87
- def run_impl (self , ** kwargs ):
91
+ def run_impl (self , ** kwargs ) -> Any :
92
+ """
93
+ Can return any json serializable object.
94
+ To return metadata along with the response, return a dict with a "content" key, and a "metadata" key, where the "content" is the response that'll
95
+ be serialized and passed to the model, and the "metadata" will be logged as metadata in the tool execution step within the Agent execution trace.
96
+ """
88
97
raise NotImplementedError
89
98
90
99
@@ -107,6 +116,10 @@ def add(x: int, y: int) -> int:
107
116
108
117
Note that you must use RST-style docstrings with :param tags for each parameter. These will be used for prompting model to use tools correctly.
109
118
:returns: tags in the docstring is optional as it would not be used for the tool's description.
119
+
120
+ Your function can return any json serializable object.
121
+ To return metadata along with the response, return a dict with a "content" key, and a "metadata" key, where the "content" is the response that'll
122
+ be serialized and passed to the model, and the "metadata" will be logged as metadata in the tool execution step within the Agent execution trace.
110
123
"""
111
124
112
125
class _WrappedTool (ClientTool ):
@@ -162,7 +175,7 @@ def get_params_definition(self) -> Dict[str, Parameter]:
162
175
163
176
return params
164
177
165
- def run_impl (self , ** kwargs ):
178
+ def run_impl (self , ** kwargs ) -> Any :
166
179
return func (** kwargs )
167
180
168
181
return _WrappedTool ()
0 commit comments