
Commit ee66e0c: Reverting the tools commit (#2404)
Parent: 1aed611

21 files changed: +990, -475 lines

mem0/llms/anthropic.py (15 additions, 18 deletions)
@@ -4,26 +4,14 @@
 try:
     import anthropic
 except ImportError:
-    raise ImportError(
-        "The 'anthropic' library is required. Please install it using 'pip install anthropic'."
-    )
+    raise ImportError("The 'anthropic' library is required. Please install it using 'pip install anthropic'.")
 
 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.base import LLMBase
 
 
 class AnthropicLLM(LLMBase):
-    """
-    A class for interacting with Anthropic's Claude models using the specified configuration.
-    """
-
     def __init__(self, config: Optional[BaseLlmConfig] = None):
-        """
-        Initializes the AnthropicLLM instance with the given configuration.
-
-        Args:
-            config (Optional[BaseLlmConfig]): Configuration settings for the language model.
-        """
         super().__init__(config)
 
         if not self.config.model:
@@ -35,17 +23,23 @@ def __init__(self, config: Optional[BaseLlmConfig] = None):
     def generate_response(
         self,
         messages: List[Dict[str, str]],
-    ) -> str:
+        response_format=None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
+    ):
         """
-        Generates a response using Anthropic's Claude model based on the provided messages.
+        Generate a response based on the given messages using Anthropic.
 
         Args:
-            messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
+            messages (list): List of message dicts containing 'role' and 'content'.
+            response_format (str or object, optional): Format of the response. Defaults to "text".
+            tools (list, optional): List of tools that the model can call. Defaults to None.
+            tool_choice (str, optional): Tool choice method. Defaults to "auto".
 
         Returns:
-            str: The generated response from the model.
+            str: The generated response.
         """
-        # Extract system message separately
+        # Separate system message from other messages
         system_message = ""
         filtered_messages = []
         for message in messages:
@@ -62,6 +56,9 @@ def generate_response(
             "max_tokens": self.config.max_tokens,
             "top_p": self.config.top_p,
         }
+        if tools:  # TODO: Remove tools if no issues found with new memory addition logic
+            params["tools"] = tools
+            params["tool_choice"] = tool_choice
 
         response = self.client.messages.create(**params)
         return response.content[0].text
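
Since the restored generate_response forwards tools and tool_choice verbatim into client.messages.create, a caller has to supply them in Anthropic's native shape. A minimal sketch, assuming an ANTHROPIC_API_KEY in the environment; the get_weather tool is hypothetical, and note that Anthropic's SDK expects tool_choice as a dict such as {"type": "auto"}, so the plain-string default in the signature likely would not validate if forwarded as-is:

    from mem0.llms.anthropic import AnthropicLLM

    llm = AnthropicLLM()  # falls back to the default model when config.model is unset

    # Hypothetical tool in Anthropic's native format, since the params dict is
    # passed straight through to client.messages.create().
    get_weather = {
        "name": "get_weather",
        "description": "Look up the current weather for a city.",
        "input_schema": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    }

    answer = llm.generate_response(
        messages=[{"role": "user", "content": "What's the weather in Berlin?"}],
        tools=[get_weather],
        tool_choice={"type": "auto"},  # dict form, per Anthropic's Messages API
    )
    print(answer)  # response.content[0].text assumes the first content block is text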

mem0/llms/aws_bedrock.py (128 additions, 52 deletions)
@@ -4,26 +4,14 @@
 try:
     import boto3
 except ImportError:
-    raise ImportError(
-        "The 'boto3' library is required. Please install it using 'pip install boto3'."
-    )
+    raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.")
 
 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.base import LLMBase
 
 
 class AWSBedrockLLM(LLMBase):
-    """
-    A wrapper for AWS Bedrock's language models, integrating them with the LLMBase class.
-    """
-
     def __init__(self, config: Optional[BaseLlmConfig] = None):
-        """
-        Initializes the AWS Bedrock LLM with the provided configuration.
-
-        Args:
-            config (Optional[BaseLlmConfig]): Configuration object for the model.
-        """
         super().__init__(config)
 
         if not self.config.model:
@@ -37,29 +25,49 @@ def __init__(self, config: Optional[BaseLlmConfig] = None):
 
     def _format_messages(self, messages: List[Dict[str, str]]) -> str:
         """
-        Formats a list of messages into a structured prompt for the model.
+        Formats a list of messages into the required prompt structure for the model.
 
         Args:
-            messages (List[Dict[str, str]]): A list of dictionaries containing 'role' and 'content'.
+            messages (List[Dict[str, str]]): A list of dictionaries where each dictionary represents a message.
+                Each dictionary contains 'role' and 'content' keys.
 
         Returns:
             str: A formatted string combining all messages, structured with roles capitalized and separated by newlines.
         """
-        formatted_messages = [
-            f"\n\n{msg['role'].capitalize()}: {msg['content']}" for msg in messages
-        ]
+        formatted_messages = []
+        for message in messages:
+            role = message["role"].capitalize()
+            content = message["content"]
+            formatted_messages.append(f"\n\n{role}: {content}")
+
         return "".join(formatted_messages) + "\n\nAssistant:"
 
-    def _parse_response(self, response) -> str:
+    def _parse_response(self, response, tools) -> str:
         """
-        Extracts the generated response from the API response.
+        Process the response based on whether tools are used or not.
 
         Args:
-            response: The raw response from the AWS Bedrock API.
+            response: The raw response from API.
+            tools: The list of tools provided in the request.
 
         Returns:
-            str: The generated response text.
+            str or dict: The processed response.
         """
+        if tools:
+            processed_response = {"tool_calls": []}
+
+            if response["output"]["message"]["content"]:
+                for item in response["output"]["message"]["content"]:
+                    if "toolUse" in item:
+                        processed_response["tool_calls"].append(
+                            {
+                                "name": item["toolUse"]["name"],
+                                "arguments": item["toolUse"]["input"],
+                            }
+                        )
+
+            return processed_response
+
         response_body = json.loads(response["body"].read().decode())
         return response_body.get("completion", "")
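
To make the tool branch of _parse_response concrete: a converse result containing one toolUse block (trimmed to just the fields the method reads; the tool name and input are hypothetical) is reduced to a plain dict of tool calls:

    # Hypothetical converse() output, trimmed to what _parse_response inspects:
    response = {
        "output": {
            "message": {
                "content": [
                    {"toolUse": {"name": "get_weather", "input": {"city": "Berlin"}}}
                ]
            }
        }
    }

    # _parse_response(response, tools=[...]) then returns:
    # {"tool_calls": [{"name": "get_weather", "arguments": {"city": "Berlin"}}]}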

@@ -68,21 +76,22 @@ def _prepare_input(
         provider: str,
         model: str,
         prompt: str,
-        model_kwargs: Optional[Dict[str, Any]] = None,
+        model_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Dict[str, Any]:
         """
-        Prepares the input dictionary for the specified provider's model.
+        Prepares the input dictionary for the specified provider's model by mapping and renaming
+        keys in the input based on the provider's requirements.
 
         Args:
-            provider (str): The model provider (e.g., "meta", "ai21", "mistral", "cohere", "amazon").
-            model (str): The model identifier.
-            prompt (str): The input prompt.
-            model_kwargs (Optional[Dict[str, Any]]): Additional model parameters.
+            provider (str): The name of the service provider (e.g., "meta", "ai21", "mistral", "cohere", "amazon").
+            model (str): The name or identifier of the model being used.
+            prompt (str): The text prompt to be processed by the model.
+            model_kwargs (Dict[str, Any]): Additional keyword arguments specific to the model's requirements.
 
         Returns:
-            Dict[str, Any]: The prepared input dictionary.
+            Dict[str, Any]: The prepared input dictionary with the correct keys and values for the specified provider.
         """
-        model_kwargs = model_kwargs or {}
+
         input_body = {"prompt": prompt, **model_kwargs}
 
         provider_mappings = {
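
A small sketch of the merge at the top of _prepare_input (the kwarg names and values below are hypothetical); provider-specific key renaming then happens against the provider_mappings table, which lies outside this hunk:

    prompt = "\n\nHuman: Hello\n\nAssistant:"
    model_kwargs = {"temperature": 0.7, "max_tokens_to_sample": 512}  # hypothetical

    input_body = {"prompt": prompt, **model_kwargs}
    # {"prompt": "\n\nHuman: Hello\n\nAssistant:", "temperature": 0.7, "max_tokens_to_sample": 512}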
@@ -110,35 +119,102 @@ def _prepare_input(
             },
         }
         input_body["textGenerationConfig"] = {
-            k: v
-            for k, v in input_body["textGenerationConfig"].items()
-            if v is not None
+            k: v for k, v in input_body["textGenerationConfig"].items() if v is not None
         }
 
         return input_body
 
-    def generate_response(self, messages: List[Dict[str, str]]) -> str:
+    def _convert_tool_format(self, original_tools):
         """
-        Generates a response using AWS Bedrock based on the provided messages.
+        Converts a list of tools from their original format to a new standardized format.
 
         Args:
-            messages (List[Dict[str, str]]): List of message dictionaries containing 'role' and 'content'.
+            original_tools (list): A list of dictionaries representing the original tools, each containing a 'type' key and corresponding details.
 
         Returns:
-            str: The generated response text.
+            list: A list of dictionaries representing the tools in the new standardized format.
         """
-        prompt = self._format_messages(messages)
-        provider = self.config.model.split(".")[0]
-        input_body = self._prepare_input(
-            provider, self.config.model, prompt, self.model_kwargs
-        )
-        body = json.dumps(input_body)
-
-        response = self.client.invoke_model(
-            body=body,
-            modelId=self.config.model,
-            accept="application/json",
-            contentType="application/json",
-        )
-
-        return self._parse_response(response)
+        new_tools = []
+
+        for tool in original_tools:
+            if tool["type"] == "function":
+                function = tool["function"]
+                new_tool = {
+                    "toolSpec": {
+                        "name": function["name"],
+                        "description": function["description"],
+                        "inputSchema": {
+                            "json": {
+                                "type": "object",
+                                "properties": {},
+                                "required": function["parameters"].get("required", []),
+                            }
+                        },
+                    }
+                }
+
+                for prop, details in function["parameters"].get("properties", {}).items():
+                    new_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop] = {
+                        "type": details.get("type", "string"),
+                        "description": details.get("description", ""),
+                    }
+
+                new_tools.append(new_tool)
+
+        return new_tools
+
+    def generate_response(
+        self,
+        messages: List[Dict[str, str]],
+        response_format=None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
+    ):
+        """
+        Generate a response based on the given messages using AWS Bedrock.
+
+        Args:
+            messages (list): List of message dicts containing 'role' and 'content'.
+            tools (list, optional): List of tools that the model can call. Defaults to None.
+            tool_choice (str, optional): Tool choice method. Defaults to "auto".
+
+        Returns:
+            str: The generated response.
+        """
+
+        if tools:
+            # Use converse method when tools are provided
+            messages = [
+                {
+                    "role": "user",
+                    "content": [{"text": message["content"]} for message in messages],
+                }
+            ]
+            inference_config = {
+                "temperature": self.model_kwargs["temperature"],
+                "maxTokens": self.model_kwargs["max_tokens_to_sample"],
+                "topP": self.model_kwargs["top_p"],
+            }
+            tools_config = {"tools": self._convert_tool_format(tools)}
+
+            response = self.client.converse(
+                modelId=self.config.model,
+                messages=messages,
+                inferenceConfig=inference_config,
+                toolConfig=tools_config,
+            )
+        else:
+            # Use invoke_model method when no tools are provided
+            prompt = self._format_messages(messages)
+            provider = self.model.split(".")[0]
+            input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs)
+            body = json.dumps(input_body)
+
+            response = self.client.invoke_model(
+                body=body,
+                modelId=self.model,
+                accept="application/json",
+                contentType="application/json",
+            )
+
+        return self._parse_response(response, tools)
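
For reference, a before/after sketch of what _convert_tool_format produces for a single OpenAI-style function tool (the get_weather tool is hypothetical; the output shape is the tool config consumed by Bedrock's converse call):

    # OpenAI-style input:
    original = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up the current weather for a city.",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string", "description": "City name"}},
                    "required": ["city"],
                },
            },
        }
    ]

    # _convert_tool_format(original) yields:
    converted = [
        {
            "toolSpec": {
                "name": "get_weather",
                "description": "Look up the current weather for a city.",
                "inputSchema": {
                    "json": {
                        "type": "object",
                        "properties": {"city": {"type": "string", "description": "City name"}},
                        "required": ["city"],
                    }
                },
            }
        }
    ]

Note that generate_response only routes through client.converse with the converted tools when tools are supplied; otherwise it falls back to invoke_model with the flat prompt string built by _format_messages.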
