Skip to content

Commit 0cb0ee8

Browse files
committed
Fix mypy errors
1 parent 05f8874 commit 0cb0ee8

File tree

5 files changed

+64
-38
lines changed

5 files changed

+64
-38
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ test-integ = [
179179
prepare = [
180180
"hatch fmt --linter",
181181
"hatch fmt --formatter",
182+
"hatch run test-lint",
182183
"hatch test --all"
183184
]
184185

src/strands/tools/decorator.py

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,17 @@ def my_tool(param1: str, param2: int = 42) -> dict:
5151
ParamSpec,
5252
Type,
5353
TypeVar,
54+
Union,
5455
cast,
5556
get_type_hints,
5657
overload,
57-
override,
5858
)
5959

6060
import docstring_parser
6161
from pydantic import BaseModel, Field, create_model
62+
from typing_extensions import override
6263

63-
from strands.types.tools import AgentTool, ToolResult, ToolSpec, ToolUse
64+
from strands.types.tools import AgentTool, JSONSchema, ToolResult, ToolSpec, ToolUse
6465

6566
# Type for wrapped function
6667
T = TypeVar("T", bound=Callable[..., Any])
@@ -253,7 +254,7 @@ def validate_input(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
253254
R = TypeVar("R") # Return type
254255

255256

256-
class DecoratedFunctionTool(Generic[P, R], AgentTool):
257+
class DecoratedFunctionTool(AgentTool, Generic[P, R]):
257258
"""An AgentTool that wraps a function that was decorated with @tool.
258259
259260
This class adapts Python functions decorated with @tool to the AgentTool interface. It handles both direct
@@ -292,7 +293,7 @@ def __init__(
292293

293294
functools.update_wrapper(wrapper=self, wrapped=self.original_function)
294295

295-
def __get__(self, instance: Any, obj_type=None) -> "DecoratedFunctionTool[P, R]":
296+
def __get__(self, instance: Any, obj_type: Optional[Type] = None) -> "DecoratedFunctionTool[P, R]":
296297
"""Descriptor protocol implementation for proper method binding.
297298
298299
This method enables the decorated function to work correctly when used as a class method.
@@ -396,7 +397,8 @@ def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolRes
396397
if "agent" in kwargs and "agent" in self._metadata.signature.parameters:
397398
validated_input["agent"] = kwargs.get("agent")
398399

399-
result = self.original_function(**validated_input)
400+
# We get "too few arguments here" but because that's because fof the way we're calling it
401+
result = self.original_function(**validated_input) # type: ignore
400402

401403
# FORMAT THE RESULT for Strands Agent
402404
if isinstance(result, dict) and "status" in result and "content" in result:
@@ -456,8 +458,19 @@ def get_display_properties(self) -> dict[str, str]:
456458
def tool(__func: Callable[P, R]) -> DecoratedFunctionTool[P, R]: ...
457459
# Handle @decorator()
458460
@overload
459-
def tool(**tool_kwargs: Any) -> Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]: ...
460-
def tool(func: Optional[Callable[P, R]] = None, **tool_kwargs: Any) -> DecoratedFunctionTool[P, R]:
461+
def tool(
462+
description: Optional[str] = None,
463+
inputSchema: Optional[JSONSchema] = None,
464+
name: Optional[str] = None,
465+
) -> Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]: ...
466+
# Suppressing the type error because we want callers to be able to use both `tool` and `tool()` at the
467+
# call site, but the actual implementation handles that and it's not representable via the type-system
468+
def tool( # type: ignore
469+
func: Optional[Callable[P, R]] = None,
470+
description: Optional[str] = None,
471+
inputSchema: Optional[JSONSchema] = None,
472+
name: Optional[str] = None,
473+
) -> Union[DecoratedFunctionTool[P, R], Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]]:
461474
"""Decorator that transforms a Python function into a Strands tool.
462475
463476
This decorator seamlessly enables a function to be called both as a regular Python function and as a Strands tool.
@@ -472,25 +485,31 @@ def tool(func: Optional[Callable[P, R]] = None, **tool_kwargs: Any) -> Decorated
472485
4. Formats return values according to the expected Strands tool result format
473486
5. Provides automatic error handling and reporting
474487
488+
The decorator can be used in two ways:
489+
- As a simple decorator: `@tool`
490+
- With parameters: `@tool(name="custom_name", description="Custom description")`
491+
475492
Args:
476-
func: The function to decorate.
477-
**tool_kwargs: Additional tool specification options to override extracted values.
478-
E.g., `name="custom_name", description="Custom description"`.
493+
func: The function to decorate. When used as a simple decorator, this is the function being decorated.
494+
When used with parameters, this will be None.
495+
description: Optional custom description to override the function's docstring.
496+
inputSchema: Optional custom JSON schema to override the automatically generated schema.
497+
name: Optional custom name to override the function's name.
479498
480499
Returns:
481-
The decorated function with attached tool specifications.
500+
An AgentTool that also mimics the original function when invoked
482501
483502
Example:
484503
```python
485504
@tool
486505
def my_tool(name: str, count: int = 1) -> str:
487506
'''Does something useful with the provided parameters.
488507
489-
"Args:
508+
Args:
490509
name: The name to process
491510
count: Number of times to process (default: 1)
492511
493-
"Returns:
512+
Returns:
494513
A message with the result
495514
'''
496515
return f"Processed {name} {count} times"
@@ -503,13 +522,25 @@ def my_tool(name: str, count: int = 1) -> str:
503522
# "content": [{"text": "Processed example 3 times"}]
504523
# }
505524
```
525+
526+
Example with parameters:
527+
```python
528+
@tool(name="custom_tool", description="A tool with a custom name and description")
529+
def my_tool(name: str, count: int = 1) -> str:
530+
return f"Processed {name} {count} times"
531+
```
506532
"""
507533

508534
def decorator(f: T) -> "DecoratedFunctionTool[P, R]":
509535
# Create function tool metadata
510536
tool_meta = FunctionToolMetadata(f)
511537
tool_spec = tool_meta.extract_metadata()
512-
tool_spec.update(tool_kwargs)
538+
if name is not None:
539+
tool_spec["name"] = name
540+
if description is not None:
541+
tool_spec["description"] = description
542+
if inputSchema is not None:
543+
tool_spec["inputSchema"] = inputSchema
513544

514545
tool_name = tool_spec.get("name", f.__name__)
515546

@@ -520,6 +551,8 @@ def decorator(f: T) -> "DecoratedFunctionTool[P, R]":
520551

521552
# Handle both @tool and @tool() syntax
522553
if func is None:
554+
# Need to ignore type-checking here since it's hard to represent the support
555+
# for both flows using the type system
523556
return decorator
524557

525558
return decorator(func)

src/strands/tools/loader.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import os
77
import sys
88
from pathlib import Path
9-
from typing import Any, Dict, List, Optional
9+
from typing import Any, Dict, List, Optional, cast
1010

1111
from ..types.tools import AgentTool
1212
from .decorator import DecoratedFunctionTool
@@ -141,7 +141,8 @@ def load_python_tool(tool_path: str, tool_name: str) -> AgentTool:
141141
logger.debug(
142142
"tool_name=<%s>, module_path=<%s> | found function-based tool", function_name, module_path
143143
)
144-
return func
144+
# mypy has problems converting between DecoratedFunctionTool <-> AgentTool
145+
return cast(AgentTool, func)
145146
else:
146147
raise ValueError(
147148
f"Function {function_name} in {module_path} is not a valid tool (missing @tool decorator)"
@@ -174,8 +175,8 @@ def load_python_tool(tool_path: str, tool_name: str) -> AgentTool:
174175
logger.debug(
175176
"tool_name=<%s>, tool_path=<%s> | found function-based tool in path", attr_name, tool_path
176177
)
177-
# Return as DecoratedFunctionTool
178-
return attr
178+
# mypy has problems converting between DecoratedFunctionTool <-> AgentTool
179+
return cast(AgentTool, attr)
179180

180181
# If no function-based tools found, fall back to traditional module-level tool
181182
tool_spec = getattr(module, "TOOL_SPEC", None)

src/strands/tools/registry.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@
1111
from importlib import import_module, util
1212
from os.path import expanduser
1313
from pathlib import Path
14-
from typing import Any, Dict, List, Optional
14+
from typing import Any, Dict, List, Optional, overload
1515

1616
from typing_extensions import TypedDict, cast
1717

1818
from ..types.tools import AgentTool, Tool, ToolChoice, ToolChoiceAuto, ToolConfig, ToolSpec
19+
from .decorator import DecoratedFunctionTool
1920
from .loader import scan_module_for_tools
20-
from .tools import FunctionTool, PythonAgentTool, normalize_schema, normalize_tool_spec
21+
from .tools import PythonAgentTool, normalize_schema, normalize_tool_spec
2122

2223
logger = logging.getLogger(__name__)
2324

@@ -92,15 +93,7 @@ def process_tools(self, tools: List[Any]) -> List[str]:
9293
if not function_tools:
9394
logger.warning("tool_name=<%s>, module_path=<%s> | invalid agent tool", tool_name, module_path)
9495

95-
# Case 5: Function decorated with @tool
96-
elif inspect.isfunction(tool) and hasattr(tool, "TOOL_SPEC"):
97-
try:
98-
function_tool = FunctionTool(tool)
99-
logger.debug("tool_name=<%s> | registering function tool", function_tool.tool_name)
100-
self.register_tool(function_tool)
101-
tool_names.append(function_tool.tool_name)
102-
except Exception as e:
103-
logger.warning("tool_name=<%s> | failed to register function tool | %s", tool.__name__, e)
96+
# Case 5: AgentTools (which also covers @tool)
10497
elif isinstance(tool, AgentTool):
10598
self.register_tool(tool)
10699
tool_names.append(tool.tool_name)
@@ -176,7 +169,12 @@ def get_all_tools_config(self) -> Dict[str, Any]:
176169
logger.debug("tool_count=<%s> | tools configured", len(tool_config))
177170
return tool_config
178171

179-
def register_tool(self, tool: AgentTool) -> None:
172+
# mypy has problems converting between DecoratedFunctionTool <-> AgentTool
173+
@overload
174+
def register_tool(self, tool: DecoratedFunctionTool) -> None: ...
175+
@overload
176+
def register_tool(self, tool: AgentTool) -> None: ...
177+
def register_tool(self, tool: AgentTool) -> None: # type: ignore
180178
"""Register a tool function with the given name.
181179
182180
Args:

tests/strands/tools/test_decorator.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,7 @@ def test_tool(param1: str, param2: int) -> str:
5858

5959
# Make sure these are set properly
6060
assert test_tool.__wrapped__ is not None
61-
assert test_tool.__doc__ == (
62-
"Test tool function.\n"
63-
"\n"
64-
" Args:\n"
65-
" param1: First parameter\n"
66-
" param2: Second parameter\n"
67-
" "
68-
)
61+
assert test_tool.__doc__ == test_tool.original_function.__doc__
6962

7063

7164
def test_tool_with_custom_name_description():

0 commit comments

Comments
 (0)