Skip to content

Commit a392509

Browse files
authored
refactor: remove kwargs spread after agent call (#289)
* refactor: remove kwargs spread after agent call * fix: Add local method override * fix: fix unit tests
1 parent 91fd7f1 commit a392509

File tree

8 files changed

+230
-250
lines changed

8 files changed

+230
-250
lines changed

src/strands/agent/agent.py

Lines changed: 67 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import os
1515
import random
1616
from concurrent.futures import ThreadPoolExecutor
17-
from typing import Any, AsyncIterator, Callable, Generator, Mapping, Optional, Type, TypeVar, Union, cast
17+
from typing import Any, AsyncIterator, Callable, Generator, List, Mapping, Optional, Type, TypeVar, Union, cast
1818

1919
from opentelemetry import trace
2020
from pydantic import BaseModel
@@ -31,7 +31,7 @@
3131
from ..types.content import ContentBlock, Message, Messages
3232
from ..types.exceptions import ContextWindowOverflowException
3333
from ..types.models import Model
34-
from ..types.tools import ToolConfig
34+
from ..types.tools import ToolConfig, ToolResult, ToolUse
3535
from ..types.traces import AttributeValue
3636
from .agent_result import AgentResult
3737
from .conversation_manager import (
@@ -97,104 +97,56 @@ def __getattr__(self, name: str) -> Callable[..., Any]:
9797
AttributeError: If no tool with the given name exists or if multiple tools match the given name.
9898
"""
9999

100-
def find_normalized_tool_name() -> Optional[str]:
101-
"""Lookup the tool represented by name, replacing characters with underscores as necessary."""
102-
tool_registry = self._agent.tool_registry.registry
103-
104-
if tool_registry.get(name, None):
105-
return name
106-
107-
# If the desired name contains underscores, it might be a placeholder for characters that can't be
108-
# represented as python identifiers but are valid as tool names, such as dashes. In that case, find
109-
# all tools that can be represented with the normalized name
110-
if "_" in name:
111-
filtered_tools = [
112-
tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name
113-
]
114-
115-
# The registry itself defends against similar names, so we can just take the first match
116-
if filtered_tools:
117-
return filtered_tools[0]
118-
119-
raise AttributeError(f"Tool '{name}' not found")
120-
121-
def caller(**kwargs: Any) -> Any:
100+
def caller(
101+
user_message_override: Optional[str] = None,
102+
record_direct_tool_call: Optional[bool] = None,
103+
**kwargs: Any,
104+
) -> Any:
122105
"""Call a tool directly by name.
123106
124107
Args:
108+
user_message_override: Optional custom message to record instead of default
109+
record_direct_tool_call: Whether to record direct tool calls in message history. Overrides class
110+
attribute if provided.
125111
**kwargs: Keyword arguments to pass to the tool.
126112
127-
- user_message_override: Custom message to record instead of default
128-
- tool_execution_handler: Custom handler for tool execution
129-
- event_loop_metrics: Custom metrics collector
130-
- messages: Custom message history to use
131-
- tool_config: Custom tool configuration
132-
- callback_handler: Custom callback handler
133-
- record_direct_tool_call: Whether to record this call in history
134-
135113
Returns:
136114
The result returned by the tool.
137115
138116
Raises:
139117
AttributeError: If the tool doesn't exist.
140118
"""
141-
normalized_name = find_normalized_tool_name()
119+
normalized_name = self._find_normalized_tool_name(name)
142120

143121
# Create unique tool ID and set up the tool request
144122
tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}"
145-
tool_use = {
123+
tool_use: ToolUse = {
146124
"toolUseId": tool_id,
147125
"name": normalized_name,
148126
"input": kwargs.copy(),
149127
}
150128

151-
# Extract tool execution parameters
152-
user_message_override = kwargs.get("user_message_override", None)
153-
tool_execution_handler = kwargs.get("tool_execution_handler", self._agent.thread_pool_wrapper)
154-
event_loop_metrics = kwargs.get("event_loop_metrics", self._agent.event_loop_metrics)
155-
messages = kwargs.get("messages", self._agent.messages)
156-
tool_config = kwargs.get("tool_config", self._agent.tool_config)
157-
callback_handler = kwargs.get("callback_handler", self._agent.callback_handler)
158-
record_direct_tool_call = kwargs.get("record_direct_tool_call", self._agent.record_direct_tool_call)
159-
160-
# Process tool call
161-
handler_kwargs = {
162-
k: v
163-
for k, v in kwargs.items()
164-
if k
165-
not in [
166-
"tool_execution_handler",
167-
"event_loop_metrics",
168-
"messages",
169-
"tool_config",
170-
"callback_handler",
171-
"tool_handler",
172-
"system_prompt",
173-
"model",
174-
"model_id",
175-
"user_message_override",
176-
"agent",
177-
"record_direct_tool_call",
178-
]
179-
}
180-
181129
# Execute the tool
182130
tool_result = self._agent.tool_handler.process(
183131
tool=tool_use,
184132
model=self._agent.model,
185133
system_prompt=self._agent.system_prompt,
186-
messages=messages,
187-
tool_config=tool_config,
188-
callback_handler=callback_handler,
189-
tool_execution_handler=tool_execution_handler,
190-
event_loop_metrics=event_loop_metrics,
191-
agent=self._agent,
192-
**handler_kwargs,
134+
messages=self._agent.messages,
135+
tool_config=self._agent.tool_config,
136+
callback_handler=self._agent.callback_handler,
137+
kwargs=kwargs,
193138
)
194139

195-
if record_direct_tool_call:
140+
if record_direct_tool_call is not None:
141+
should_record_direct_tool_call = record_direct_tool_call
142+
else:
143+
should_record_direct_tool_call = self._agent.record_direct_tool_call
144+
145+
if should_record_direct_tool_call:
196146
# Create a record of this tool execution in the message history
197-
self._agent._record_tool_execution(tool_use, tool_result, user_message_override, messages)
147+
self._agent._record_tool_execution(
148+
tool_use, tool_result, user_message_override, self._agent.messages
149+
)
198150

199151
# Apply window management
200152
self._agent.conversation_manager.apply_management(self._agent)
@@ -203,6 +155,27 @@ def caller(**kwargs: Any) -> Any:
203155

204156
return caller
205157

158+
def _find_normalized_tool_name(self, name: str) -> str:
159+
"""Lookup the tool represented by name, replacing characters with underscores as necessary."""
160+
tool_registry = self._agent.tool_registry.registry
161+
162+
if tool_registry.get(name, None):
163+
return name
164+
165+
# If the desired name contains underscores, it might be a placeholder for characters that can't be
166+
# represented as python identifiers but are valid as tool names, such as dashes. In that case, find
167+
# all tools that can be represented with the normalized name
168+
if "_" in name:
169+
filtered_tools = [
170+
tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name
171+
]
172+
173+
# The registry itself defends against similar names, so we can just take the first match
174+
if filtered_tools:
175+
return filtered_tools[0]
176+
177+
raise AttributeError(f"Tool '{name}' not found")
178+
206179
def __init__(
207180
self,
208181
model: Union[Model, str, None] = None,
@@ -371,7 +344,7 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
371344
372345
Args:
373346
prompt: The natural language prompt from the user.
374-
**kwargs: Additional parameters to pass to the event loop.
347+
**kwargs: Additional parameters to pass through the event loop.
375348
376349
Returns:
377350
Result object containing:
@@ -514,44 +487,35 @@ def _execute_event_loop_cycle(
514487
Yields:
515488
Events of the loop cycle.
516489
"""
517-
# Extract parameters with fallbacks to instance values
518-
system_prompt = kwargs.pop("system_prompt", self.system_prompt)
519-
model = kwargs.pop("model", self.model)
520-
tool_execution_handler = kwargs.pop("tool_execution_handler", self.thread_pool_wrapper)
521-
event_loop_metrics = kwargs.pop("event_loop_metrics", self.event_loop_metrics)
522-
callback_handler_override = kwargs.pop("callback_handler", callback_handler)
523-
tool_handler = kwargs.pop("tool_handler", self.tool_handler)
524-
messages = kwargs.pop("messages", self.messages)
525-
tool_config = kwargs.pop("tool_config", self.tool_config)
526-
kwargs.pop("agent", None) # Remove agent to avoid conflicts
490+
# Add `Agent` to kwargs to keep backwards-compatibility
491+
kwargs["agent"] = self
527492

528493
try:
529494
# Execute the main event loop cycle
530495
yield from event_loop_cycle(
531-
model=model,
532-
system_prompt=system_prompt,
533-
messages=messages, # will be modified by event_loop_cycle
534-
tool_config=tool_config,
535-
callback_handler=callback_handler_override,
536-
tool_handler=tool_handler,
537-
tool_execution_handler=tool_execution_handler,
538-
event_loop_metrics=event_loop_metrics,
539-
agent=self,
496+
model=self.model,
497+
system_prompt=self.system_prompt,
498+
messages=self.messages, # will be modified by event_loop_cycle
499+
tool_config=self.tool_config,
500+
callback_handler=callback_handler,
501+
tool_handler=self.tool_handler,
502+
tool_execution_handler=self.thread_pool_wrapper,
503+
event_loop_metrics=self.event_loop_metrics,
540504
event_loop_parent_span=self.trace_span,
541-
**kwargs,
505+
kwargs=kwargs,
542506
)
543507

544508
except ContextWindowOverflowException as e:
545509
# Try reducing the context size and retrying
546510
self.conversation_manager.reduce_context(self, e=e)
547-
yield from self._execute_event_loop_cycle(callback_handler_override, kwargs)
511+
yield from self._execute_event_loop_cycle(callback_handler, kwargs)
548512

549513
def _record_tool_execution(
550514
self,
551-
tool: dict[str, Any],
552-
tool_result: dict[str, Any],
515+
tool: ToolUse,
516+
tool_result: ToolResult,
553517
user_message_override: Optional[str],
554-
messages: list[dict[str, Any]],
518+
messages: Messages,
555519
) -> None:
556520
"""Record a tool execution in the message history.
557521
@@ -569,7 +533,7 @@ def _record_tool_execution(
569533
messages: The message history to append to.
570534
"""
571535
# Create user message describing the tool call
572-
user_msg_content = [
536+
user_msg_content: List[ContentBlock] = [
573537
{"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {json.dumps(tool['input'])}\n")}
574538
]
575539

@@ -578,19 +542,19 @@ def _record_tool_execution(
578542
user_msg_content.insert(0, {"text": f"{user_message_override}\n"})
579543

580544
# Create the message sequence
581-
user_msg = {
545+
user_msg: Message = {
582546
"role": "user",
583547
"content": user_msg_content,
584548
}
585-
tool_use_msg = {
549+
tool_use_msg: Message = {
586550
"role": "assistant",
587551
"content": [{"toolUse": tool}],
588552
}
589-
tool_result_msg = {
553+
tool_result_msg: Message = {
590554
"role": "user",
591555
"content": [{"toolResult": tool_result}],
592556
}
593-
assistant_msg = {
557+
assistant_msg: Message = {
594558
"role": "assistant",
595559
"content": [{"text": f"agent.{tool['name']} was called"}],
596560
}

0 commit comments

Comments
 (0)