14
14
import os
15
15
import random
16
16
from concurrent .futures import ThreadPoolExecutor
17
- from typing import Any , AsyncIterator , Callable , Generator , Mapping , Optional , Type , TypeVar , Union , cast
17
+ from typing import Any , AsyncIterator , Callable , Generator , List , Mapping , Optional , Type , TypeVar , Union , cast
18
18
19
19
from opentelemetry import trace
20
20
from pydantic import BaseModel
31
31
from ..types .content import ContentBlock , Message , Messages
32
32
from ..types .exceptions import ContextWindowOverflowException
33
33
from ..types .models import Model
34
- from ..types .tools import ToolConfig
34
+ from ..types .tools import ToolConfig , ToolResult , ToolUse
35
35
from ..types .traces import AttributeValue
36
36
from .agent_result import AgentResult
37
37
from .conversation_manager import (
@@ -97,104 +97,56 @@ def __getattr__(self, name: str) -> Callable[..., Any]:
97
97
AttributeError: If no tool with the given name exists or if multiple tools match the given name.
98
98
"""
99
99
100
- def find_normalized_tool_name () -> Optional [str ]:
101
- """Lookup the tool represented by name, replacing characters with underscores as necessary."""
102
- tool_registry = self ._agent .tool_registry .registry
103
-
104
- if tool_registry .get (name , None ):
105
- return name
106
-
107
- # If the desired name contains underscores, it might be a placeholder for characters that can't be
108
- # represented as python identifiers but are valid as tool names, such as dashes. In that case, find
109
- # all tools that can be represented with the normalized name
110
- if "_" in name :
111
- filtered_tools = [
112
- tool_name for (tool_name , tool ) in tool_registry .items () if tool_name .replace ("-" , "_" ) == name
113
- ]
114
-
115
- # The registry itself defends against similar names, so we can just take the first match
116
- if filtered_tools :
117
- return filtered_tools [0 ]
118
-
119
- raise AttributeError (f"Tool '{ name } ' not found" )
120
-
121
- def caller (** kwargs : Any ) -> Any :
100
+ def caller (
101
+ user_message_override : Optional [str ] = None ,
102
+ record_direct_tool_call : Optional [bool ] = None ,
103
+ ** kwargs : Any ,
104
+ ) -> Any :
122
105
"""Call a tool directly by name.
123
106
124
107
Args:
108
+ user_message_override: Optional custom message to record instead of default
109
+ record_direct_tool_call: Whether to record direct tool calls in message history. Overrides class
110
+ attribute if provided.
125
111
**kwargs: Keyword arguments to pass to the tool.
126
112
127
- - user_message_override: Custom message to record instead of default
128
- - tool_execution_handler: Custom handler for tool execution
129
- - event_loop_metrics: Custom metrics collector
130
- - messages: Custom message history to use
131
- - tool_config: Custom tool configuration
132
- - callback_handler: Custom callback handler
133
- - record_direct_tool_call: Whether to record this call in history
134
-
135
113
Returns:
136
114
The result returned by the tool.
137
115
138
116
Raises:
139
117
AttributeError: If the tool doesn't exist.
140
118
"""
141
- normalized_name = find_normalized_tool_name ( )
119
+ normalized_name = self . _find_normalized_tool_name ( name )
142
120
143
121
# Create unique tool ID and set up the tool request
144
122
tool_id = f"tooluse_{ name } _{ random .randint (100000000 , 999999999 )} "
145
- tool_use = {
123
+ tool_use : ToolUse = {
146
124
"toolUseId" : tool_id ,
147
125
"name" : normalized_name ,
148
126
"input" : kwargs .copy (),
149
127
}
150
128
151
- # Extract tool execution parameters
152
- user_message_override = kwargs .get ("user_message_override" , None )
153
- tool_execution_handler = kwargs .get ("tool_execution_handler" , self ._agent .thread_pool_wrapper )
154
- event_loop_metrics = kwargs .get ("event_loop_metrics" , self ._agent .event_loop_metrics )
155
- messages = kwargs .get ("messages" , self ._agent .messages )
156
- tool_config = kwargs .get ("tool_config" , self ._agent .tool_config )
157
- callback_handler = kwargs .get ("callback_handler" , self ._agent .callback_handler )
158
- record_direct_tool_call = kwargs .get ("record_direct_tool_call" , self ._agent .record_direct_tool_call )
159
-
160
- # Process tool call
161
- handler_kwargs = {
162
- k : v
163
- for k , v in kwargs .items ()
164
- if k
165
- not in [
166
- "tool_execution_handler" ,
167
- "event_loop_metrics" ,
168
- "messages" ,
169
- "tool_config" ,
170
- "callback_handler" ,
171
- "tool_handler" ,
172
- "system_prompt" ,
173
- "model" ,
174
- "model_id" ,
175
- "user_message_override" ,
176
- "agent" ,
177
- "record_direct_tool_call" ,
178
- ]
179
- }
180
-
181
129
# Execute the tool
182
130
tool_result = self ._agent .tool_handler .process (
183
131
tool = tool_use ,
184
132
model = self ._agent .model ,
185
133
system_prompt = self ._agent .system_prompt ,
186
- messages = messages ,
187
- tool_config = tool_config ,
188
- callback_handler = callback_handler ,
189
- tool_execution_handler = tool_execution_handler ,
190
- event_loop_metrics = event_loop_metrics ,
191
- agent = self ._agent ,
192
- ** handler_kwargs ,
134
+ messages = self ._agent .messages ,
135
+ tool_config = self ._agent .tool_config ,
136
+ callback_handler = self ._agent .callback_handler ,
137
+ kwargs = kwargs ,
193
138
)
194
139
195
- if record_direct_tool_call :
140
+ if record_direct_tool_call is not None :
141
+ should_record_direct_tool_call = record_direct_tool_call
142
+ else :
143
+ should_record_direct_tool_call = self ._agent .record_direct_tool_call
144
+
145
+ if should_record_direct_tool_call :
196
146
# Create a record of this tool execution in the message history
197
- self ._agent ._record_tool_execution (tool_use , tool_result , user_message_override , messages )
147
+ self ._agent ._record_tool_execution (
148
+ tool_use , tool_result , user_message_override , self ._agent .messages
149
+ )
198
150
199
151
# Apply window management
200
152
self ._agent .conversation_manager .apply_management (self ._agent )
@@ -203,6 +155,27 @@ def caller(**kwargs: Any) -> Any:
203
155
204
156
return caller
205
157
158
+ def _find_normalized_tool_name (self , name : str ) -> str :
159
+ """Lookup the tool represented by name, replacing characters with underscores as necessary."""
160
+ tool_registry = self ._agent .tool_registry .registry
161
+
162
+ if tool_registry .get (name , None ):
163
+ return name
164
+
165
+ # If the desired name contains underscores, it might be a placeholder for characters that can't be
166
+ # represented as python identifiers but are valid as tool names, such as dashes. In that case, find
167
+ # all tools that can be represented with the normalized name
168
+ if "_" in name :
169
+ filtered_tools = [
170
+ tool_name for (tool_name , tool ) in tool_registry .items () if tool_name .replace ("-" , "_" ) == name
171
+ ]
172
+
173
+ # The registry itself defends against similar names, so we can just take the first match
174
+ if filtered_tools :
175
+ return filtered_tools [0 ]
176
+
177
+ raise AttributeError (f"Tool '{ name } ' not found" )
178
+
206
179
def __init__ (
207
180
self ,
208
181
model : Union [Model , str , None ] = None ,
@@ -371,7 +344,7 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
371
344
372
345
Args:
373
346
prompt: The natural language prompt from the user.
374
- **kwargs: Additional parameters to pass to the event loop.
347
+ **kwargs: Additional parameters to pass through the event loop.
375
348
376
349
Returns:
377
350
Result object containing:
@@ -514,44 +487,35 @@ def _execute_event_loop_cycle(
514
487
Yields:
515
488
Events of the loop cycle.
516
489
"""
517
- # Extract parameters with fallbacks to instance values
518
- system_prompt = kwargs .pop ("system_prompt" , self .system_prompt )
519
- model = kwargs .pop ("model" , self .model )
520
- tool_execution_handler = kwargs .pop ("tool_execution_handler" , self .thread_pool_wrapper )
521
- event_loop_metrics = kwargs .pop ("event_loop_metrics" , self .event_loop_metrics )
522
- callback_handler_override = kwargs .pop ("callback_handler" , callback_handler )
523
- tool_handler = kwargs .pop ("tool_handler" , self .tool_handler )
524
- messages = kwargs .pop ("messages" , self .messages )
525
- tool_config = kwargs .pop ("tool_config" , self .tool_config )
526
- kwargs .pop ("agent" , None ) # Remove agent to avoid conflicts
490
+ # Add `Agent` to kwargs to keep backwards-compatibility
491
+ kwargs ["agent" ] = self
527
492
528
493
try :
529
494
# Execute the main event loop cycle
530
495
yield from event_loop_cycle (
531
- model = model ,
532
- system_prompt = system_prompt ,
533
- messages = messages , # will be modified by event_loop_cycle
534
- tool_config = tool_config ,
535
- callback_handler = callback_handler_override ,
536
- tool_handler = tool_handler ,
537
- tool_execution_handler = tool_execution_handler ,
538
- event_loop_metrics = event_loop_metrics ,
539
- agent = self ,
496
+ model = self .model ,
497
+ system_prompt = self .system_prompt ,
498
+ messages = self .messages , # will be modified by event_loop_cycle
499
+ tool_config = self .tool_config ,
500
+ callback_handler = callback_handler ,
501
+ tool_handler = self .tool_handler ,
502
+ tool_execution_handler = self .thread_pool_wrapper ,
503
+ event_loop_metrics = self .event_loop_metrics ,
540
504
event_loop_parent_span = self .trace_span ,
541
- ** kwargs ,
505
+ kwargs = kwargs ,
542
506
)
543
507
544
508
except ContextWindowOverflowException as e :
545
509
# Try reducing the context size and retrying
546
510
self .conversation_manager .reduce_context (self , e = e )
547
- yield from self ._execute_event_loop_cycle (callback_handler_override , kwargs )
511
+ yield from self ._execute_event_loop_cycle (callback_handler , kwargs )
548
512
549
513
def _record_tool_execution (
550
514
self ,
551
- tool : dict [ str , Any ] ,
552
- tool_result : dict [ str , Any ] ,
515
+ tool : ToolUse ,
516
+ tool_result : ToolResult ,
553
517
user_message_override : Optional [str ],
554
- messages : list [ dict [ str , Any ]] ,
518
+ messages : Messages ,
555
519
) -> None :
556
520
"""Record a tool execution in the message history.
557
521
@@ -569,7 +533,7 @@ def _record_tool_execution(
569
533
messages: The message history to append to.
570
534
"""
571
535
# Create user message describing the tool call
572
- user_msg_content = [
536
+ user_msg_content : List [ ContentBlock ] = [
573
537
{"text" : (f"agent.tool.{ tool ['name' ]} direct tool call.\n Input parameters: { json .dumps (tool ['input' ])} \n " )}
574
538
]
575
539
@@ -578,19 +542,19 @@ def _record_tool_execution(
578
542
user_msg_content .insert (0 , {"text" : f"{ user_message_override } \n " })
579
543
580
544
# Create the message sequence
581
- user_msg = {
545
+ user_msg : Message = {
582
546
"role" : "user" ,
583
547
"content" : user_msg_content ,
584
548
}
585
- tool_use_msg = {
549
+ tool_use_msg : Message = {
586
550
"role" : "assistant" ,
587
551
"content" : [{"toolUse" : tool }],
588
552
}
589
- tool_result_msg = {
553
+ tool_result_msg : Message = {
590
554
"role" : "user" ,
591
555
"content" : [{"toolResult" : tool_result }],
592
556
}
593
- assistant_msg = {
557
+ assistant_msg : Message = {
594
558
"role" : "assistant" ,
595
559
"content" : [{"text" : f"agent.{ tool ['name' ]} was called" }],
596
560
}
0 commit comments