Skip to content

Commit 2ec8044

Browse files
authored
feat(agent): support plain function as client_tool (#187)
Summary: Test Plan: test_agents.py integration tests
1 parent fc9907c commit 2ec8044

File tree

1 file changed

+11
-14
lines changed
  • src/llama_stack_client/lib/agents

1 file changed

+11
-14
lines changed

src/llama_stack_client/lib/agents/agent.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
66
import logging
7-
from typing import AsyncIterator, Iterator, List, Optional, Tuple, Union
7+
from typing import Any, AsyncIterator, Callable, Iterator, List, Optional, Tuple, Union
88

99
from llama_stack_client import LlamaStackClient
10-
1110
from llama_stack_client.types import ToolResponseMessage, ToolResponseParam, UserMessage
1211
from llama_stack_client.types.agent_create_params import AgentConfig
1312
from llama_stack_client.types.agents.turn import CompletionMessage, Turn
@@ -18,7 +17,7 @@
1817
from llama_stack_client.types.shared_params.response_format import ResponseFormat
1918
from llama_stack_client.types.shared_params.sampling_params import SamplingParams
2019

21-
from .client_tool import ClientTool
20+
from .client_tool import ClientTool, client_tool
2221
from .tool_parser import ToolParser
2322

2423
DEFAULT_MAX_ITER = 10
@@ -28,10 +27,12 @@
2827

2928
class AgentUtils:
3029
@staticmethod
31-
def get_client_tools(tools: Optional[List[Union[Toolgroup, ClientTool]]]) -> List[ClientTool]:
30+
def get_client_tools(tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]]) -> List[ClientTool]:
3231
if not tools:
3332
return []
3433

34+
# Wrap any function in client_tool decorator
35+
tools = [client_tool(tool) if (callable(tool) and not isinstance(tool, ClientTool)) else tool for tool in tools]
3536
return [tool for tool in tools if isinstance(tool, ClientTool)]
3637

3738
@staticmethod
@@ -59,7 +60,7 @@ def get_turn_id(chunk: AgentTurnResponseStreamChunk) -> Optional[str]:
5960
def get_agent_config(
6061
model: Optional[str] = None,
6162
instructions: Optional[str] = None,
62-
tools: Optional[List[Union[Toolgroup, ClientTool]]] = None,
63+
tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]] = None,
6364
tool_config: Optional[ToolConfig] = None,
6465
sampling_params: Optional[SamplingParams] = None,
6566
max_infer_iters: Optional[int] = None,
@@ -96,16 +97,12 @@ def get_agent_config(
9697
agent_config["tool_config"] = tool_config
9798
if tools is not None:
9899
toolgroups: List[Toolgroup] = []
99-
client_tools: List[ClientTool] = []
100-
101100
for tool in tools:
102101
if isinstance(tool, str) or isinstance(tool, dict):
103102
toolgroups.append(tool)
104-
else:
105-
client_tools.append(tool)
106103

107104
agent_config["toolgroups"] = toolgroups
108-
agent_config["client_tools"] = [tool.get_tool_definition() for tool in client_tools]
105+
agent_config["client_tools"] = [tool.get_tool_definition() for tool in AgentUtils.get_client_tools(tools)]
109106

110107
agent_config = AgentConfig(**agent_config)
111108
return agent_config
@@ -122,7 +119,7 @@ def __init__(
122119
tool_parser: Optional[ToolParser] = None,
123120
model: Optional[str] = None,
124121
instructions: Optional[str] = None,
125-
tools: Optional[List[Union[Toolgroup, ClientTool]]] = None,
122+
tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]] = None,
126123
tool_config: Optional[ToolConfig] = None,
127124
sampling_params: Optional[SamplingParams] = None,
128125
max_infer_iters: Optional[int] = None,
@@ -143,7 +140,7 @@ def __init__(
143140
:param instructions: The instructions for the agent.
144141
:param tools: A list of tools for the agent. Values can be one of the following:
145142
- dict representing a toolgroup/tool with arguments: e.g. {"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": [123]}}
146-
- a python function decorated with @client_tool
143+
- a python function with a docstring. See @client_tool for more details.
147144
- str representing a tool within a toolgroup: e.g. "builtin::rag/knowledge_search"
148145
- str representing a toolgroup_id: e.g. "builtin::rag", "builtin::code_interpreter", where all tools in the toolgroup will be added to the agent
149146
- an instance of ClientTool: A client tool object.
@@ -332,7 +329,7 @@ def __init__(
332329
tool_parser: Optional[ToolParser] = None,
333330
model: Optional[str] = None,
334331
instructions: Optional[str] = None,
335-
tools: Optional[List[Union[Toolgroup, ClientTool]]] = None,
332+
tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]] = None,
336333
tool_config: Optional[ToolConfig] = None,
337334
sampling_params: Optional[SamplingParams] = None,
338335
max_infer_iters: Optional[int] = None,
@@ -353,7 +350,7 @@ def __init__(
353350
:param instructions: The instructions for the agent.
354351
:param tools: A list of tools for the agent. Values can be one of the following:
355352
- dict representing a toolgroup/tool with arguments: e.g. {"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": [123]}}
356-
- a python function decorated with @client_tool
353+
- a python function with a docstring. See @client_tool for more details.
357354
- str representing a tool within a toolgroup: e.g. "builtin::rag/knowledge_search"
358355
- str representing a toolgroup_id: e.g. "builtin::rag", "builtin::code_interpreter", where all tools in the toolgroup will be added to the agent
359356
- an instance of ClientTool: A client tool object.

0 commit comments

Comments
 (0)