4
4
# This source code is licensed under the terms described in the LICENSE file in
5
5
# the root directory of this source tree.
6
6
import logging
7
- from typing import AsyncIterator , Iterator , List , Optional , Tuple , Union
7
+ from typing import Any , AsyncIterator , Callable , Iterator , List , Optional , Tuple , Union
8
8
9
9
from llama_stack_client import LlamaStackClient
10
-
11
10
from llama_stack_client .types import ToolResponseMessage , ToolResponseParam , UserMessage
12
11
from llama_stack_client .types .agent_create_params import AgentConfig
13
12
from llama_stack_client .types .agents .turn import CompletionMessage , Turn
18
17
from llama_stack_client .types .shared_params .response_format import ResponseFormat
19
18
from llama_stack_client .types .shared_params .sampling_params import SamplingParams
20
19
21
- from .client_tool import ClientTool
20
+ from .client_tool import ClientTool , client_tool
22
21
from .tool_parser import ToolParser
23
22
24
23
DEFAULT_MAX_ITER = 10
28
27
29
28
class AgentUtils :
30
29
@staticmethod
31
- def get_client_tools (tools : Optional [List [Union [Toolgroup , ClientTool ]]]) -> List [ClientTool ]:
30
+ def get_client_tools (tools : Optional [List [Union [Toolgroup , ClientTool , Callable [..., Any ] ]]]) -> List [ClientTool ]:
32
31
if not tools :
33
32
return []
34
33
34
+ # Wrap any function in client_tool decorator
35
+ tools = [client_tool (tool ) if (callable (tool ) and not isinstance (tool , ClientTool )) else tool for tool in tools ]
35
36
return [tool for tool in tools if isinstance (tool , ClientTool )]
36
37
37
38
@staticmethod
@@ -59,7 +60,7 @@ def get_turn_id(chunk: AgentTurnResponseStreamChunk) -> Optional[str]:
59
60
def get_agent_config (
60
61
model : Optional [str ] = None ,
61
62
instructions : Optional [str ] = None ,
62
- tools : Optional [List [Union [Toolgroup , ClientTool ]]] = None ,
63
+ tools : Optional [List [Union [Toolgroup , ClientTool , Callable [..., Any ] ]]] = None ,
63
64
tool_config : Optional [ToolConfig ] = None ,
64
65
sampling_params : Optional [SamplingParams ] = None ,
65
66
max_infer_iters : Optional [int ] = None ,
@@ -96,16 +97,12 @@ def get_agent_config(
96
97
agent_config ["tool_config" ] = tool_config
97
98
if tools is not None :
98
99
toolgroups : List [Toolgroup ] = []
99
- client_tools : List [ClientTool ] = []
100
-
101
100
for tool in tools :
102
101
if isinstance (tool , str ) or isinstance (tool , dict ):
103
102
toolgroups .append (tool )
104
- else :
105
- client_tools .append (tool )
106
103
107
104
agent_config ["toolgroups" ] = toolgroups
108
- agent_config ["client_tools" ] = [tool .get_tool_definition () for tool in client_tools ]
105
+ agent_config ["client_tools" ] = [tool .get_tool_definition () for tool in AgentUtils . get_client_tools ( tools ) ]
109
106
110
107
agent_config = AgentConfig (** agent_config )
111
108
return agent_config
@@ -122,7 +119,7 @@ def __init__(
122
119
tool_parser : Optional [ToolParser ] = None ,
123
120
model : Optional [str ] = None ,
124
121
instructions : Optional [str ] = None ,
125
- tools : Optional [List [Union [Toolgroup , ClientTool ]]] = None ,
122
+ tools : Optional [List [Union [Toolgroup , ClientTool , Callable [..., Any ] ]]] = None ,
126
123
tool_config : Optional [ToolConfig ] = None ,
127
124
sampling_params : Optional [SamplingParams ] = None ,
128
125
max_infer_iters : Optional [int ] = None ,
@@ -143,7 +140,7 @@ def __init__(
143
140
:param instructions: The instructions for the agent.
144
141
:param tools: A list of tools for the agent. Values can be one of the following:
145
142
- dict representing a toolgroup/tool with arguments: e.g. {"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": [123]}}
146
- - a python function decorated with @client_tool
143
+ - a python function with a docstring. See @client_tool for more details.
147
144
- str representing a tool within a toolgroup: e.g. "builtin::rag/knowledge_search"
148
145
- str representing a toolgroup_id: e.g. "builtin::rag", "builtin::code_interpreter", where all tools in the toolgroup will be added to the agent
149
146
- an instance of ClientTool: A client tool object.
@@ -332,7 +329,7 @@ def __init__(
332
329
tool_parser : Optional [ToolParser ] = None ,
333
330
model : Optional [str ] = None ,
334
331
instructions : Optional [str ] = None ,
335
- tools : Optional [List [Union [Toolgroup , ClientTool ]]] = None ,
332
+ tools : Optional [List [Union [Toolgroup , ClientTool , Callable [..., Any ] ]]] = None ,
336
333
tool_config : Optional [ToolConfig ] = None ,
337
334
sampling_params : Optional [SamplingParams ] = None ,
338
335
max_infer_iters : Optional [int ] = None ,
@@ -353,7 +350,7 @@ def __init__(
353
350
:param instructions: The instructions for the agent.
354
351
:param tools: A list of tools for the agent. Values can be one of the following:
355
352
- dict representing a toolgroup/tool with arguments: e.g. {"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": [123]}}
356
- - a python function decorated with @client_tool
353
+ - a python function with a docstring. See @client_tool for more details.
357
354
- str representing a tool within a toolgroup: e.g. "builtin::rag/knowledge_search"
358
355
- str representing a toolgroup_id: e.g. "builtin::rag", "builtin::code_interpreter", where all tools in the toolgroup will be added to the agent
359
356
- an instance of ClientTool: A client tool object.
0 commit comments