@@ -4,26 +4,14 @@
 try:
     import boto3
 except ImportError:
-    raise ImportError(
-        "The 'boto3' library is required. Please install it using 'pip install boto3'."
-    )
+    raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.")
 
 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.base import LLMBase
 
 
 class AWSBedrockLLM(LLMBase):
-    """
-    A wrapper for AWS Bedrock's language models, integrating them with the LLMBase class.
-    """
-
     def __init__(self, config: Optional[BaseLlmConfig] = None):
-        """
-        Initializes the AWS Bedrock LLM with the provided configuration.
-
-        Args:
-            config (Optional[BaseLlmConfig]): Configuration object for the model.
-        """
         super().__init__(config)
 
         if not self.config.model:
@@ -37,29 +25,49 @@ def __init__(self, config: Optional[BaseLlmConfig] = None):
 
     def _format_messages(self, messages: List[Dict[str, str]]) -> str:
         """
-        Formats a list of messages into a structured prompt for the model.
+        Formats a list of messages into the required prompt structure for the model.
 
         Args:
-            messages (List[Dict[str, str]]): A list of dictionaries containing 'role' and 'content'.
+            messages (List[Dict[str, str]]): A list of dictionaries where each dictionary represents a message.
+                Each dictionary contains 'role' and 'content' keys.
 
         Returns:
             str: A formatted string combining all messages, structured with roles capitalized and separated by newlines.
         """
-        formatted_messages = [
-            f"\n\n{msg['role'].capitalize()}: {msg['content']}" for msg in messages
-        ]
+        formatted_messages = []
+        for message in messages:
+            role = message["role"].capitalize()
+            content = message["content"]
+            formatted_messages.append(f"\n\n{role}: {content}")
+
         return "".join(formatted_messages) + "\n\nAssistant:"
 
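As a quick illustration of what `_format_messages` produces (the message values below are invented for illustration, not taken from the commit):

# Hypothetical input/output for _format_messages:
messages = [
    {"role": "user", "content": "Hi, I'm Alex."},
    {"role": "assistant", "content": "Hello Alex! How can I help?"},
]
# _format_messages(messages) returns:
# "\n\nUser: Hi, I'm Alex.\n\nAssistant: Hello Alex! How can I help?\n\nAssistant:"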
-    def _parse_response(self, response) -> str:
+    def _parse_response(self, response, tools) -> str:
         """
-        Extracts the generated response from the API response.
+        Processes the response based on whether tools are used or not.
 
         Args:
-            response: The raw response from the AWS Bedrock API.
+            response: The raw response from the API.
+            tools: The list of tools provided in the request.
 
         Returns:
-            str: The generated response text.
+            str or dict: The processed response.
         """
+        if tools:
+            processed_response = {"tool_calls": []}
+
+            if response["output"]["message"]["content"]:
+                for item in response["output"]["message"]["content"]:
+                    if "toolUse" in item:
+                        processed_response["tool_calls"].append(
+                            {
+                                "name": item["toolUse"]["name"],
+                                "arguments": item["toolUse"]["input"],
+                            }
+                        )
+
+            return processed_response
+
         response_body = json.loads(response["body"].read().decode())
         return response_body.get("completion", "")
 
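For reference, a minimal sketch of the response shape the new tools branch of `_parse_response` consumes (field values are illustrative assumptions; only the keys the method actually reads are shown):

# Assumed response shape for the tools branch:
response = {
    "output": {
        "message": {
            "content": [
                {"toolUse": {"name": "get_weather", "input": {"city": "Paris"}}},
            ]
        }
    }
}
# _parse_response(response, tools=[...]) returns:
# {"tool_calls": [{"name": "get_weather", "arguments": {"city": "Paris"}}]}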
@@ -68,21 +76,22 @@ def _prepare_input(
         provider: str,
         model: str,
         prompt: str,
-        model_kwargs: Optional[Dict[str, Any]] = None,
+        model_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Dict[str, Any]:
         """
-        Prepares the input dictionary for the specified provider's model.
+        Prepares the input dictionary for the specified provider's model by mapping and renaming
+        keys in the input based on the provider's requirements.
 
         Args:
-            provider (str): The model provider (e.g., "meta", "ai21", "mistral", "cohere", "amazon").
-            model (str): The model identifier.
-            prompt (str): The input prompt.
-            model_kwargs (Optional[Dict[str, Any]]): Additional model parameters.
+            provider (str): The name of the service provider (e.g., "meta", "ai21", "mistral", "cohere", "amazon").
+            model (str): The name or identifier of the model being used.
+            prompt (str): The text prompt to be processed by the model.
+            model_kwargs (Dict[str, Any]): Additional keyword arguments specific to the model's requirements.
 
         Returns:
-            Dict[str, Any]: The prepared input dictionary.
+            Dict[str, Any]: The prepared input dictionary with the correct keys and values for the specified provider.
         """
-        model_kwargs = model_kwargs or {}
+
         input_body = {"prompt": prompt, **model_kwargs}
 
         provider_mappings = {
@@ -110,35 +119,102 @@ def _prepare_input(
             },
         }
         input_body["textGenerationConfig"] = {
-            k: v
-            for k, v in input_body["textGenerationConfig"].items()
-            if v is not None
+            k: v for k, v in input_body["textGenerationConfig"].items() if v is not None
         }
 
         return input_body
 
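A hedged usage sketch of `_prepare_input`, given an `AWSBedrockLLM` instance `llm` (the exact renamed keys depend on `provider_mappings`, which the diff only partially shows; the values below are assumptions):

# Hypothetical call; key renaming depends on provider_mappings:
input_body = llm._prepare_input(
    provider="amazon",
    model="amazon.titan-text-express-v1",
    prompt="\n\nUser: Hello\n\nAssistant:",
    model_kwargs={"temperature": 0.1, "top_p": 0.9, "max_tokens_to_sample": 200},
)
# For "amazon", the kwargs are renamed into input_body["textGenerationConfig"],
# and the comprehension above drops any entries whose value is None.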
-    def generate_response(self, messages: List[Dict[str, str]]) -> str:
+    def _convert_tool_format(self, original_tools):
         """
-        Generates a response using AWS Bedrock based on the provided messages.
+        Converts a list of tools from their original format to a new standardized format.
 
         Args:
-            messages (List[Dict[str, str]]): List of message dictionaries containing 'role' and 'content'.
+            original_tools (list): A list of dictionaries representing the original tools, each containing a 'type' key and corresponding details.
 
         Returns:
-            str: The generated response text.
+            list: A list of dictionaries representing the tools in the new standardized format.
         """
-        prompt = self._format_messages(messages)
-        provider = self.config.model.split(".")[0]
-        input_body = self._prepare_input(
-            provider, self.config.model, prompt, self.model_kwargs
-        )
-        body = json.dumps(input_body)
-
-        response = self.client.invoke_model(
-            body=body,
-            modelId=self.config.model,
-            accept="application/json",
-            contentType="application/json",
-        )
-
-        return self._parse_response(response)
+        new_tools = []
+
+        for tool in original_tools:
+            if tool["type"] == "function":
+                function = tool["function"]
+                new_tool = {
+                    "toolSpec": {
+                        "name": function["name"],
+                        "description": function["description"],
+                        "inputSchema": {
+                            "json": {
+                                "type": "object",
+                                "properties": {},
+                                "required": function["parameters"].get("required", []),
+                            }
+                        },
+                    }
+                }
+
+                for prop, details in function["parameters"].get("properties", {}).items():
+                    new_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop] = {
+                        "type": details.get("type", "string"),
+                        "description": details.get("description", ""),
+                    }
+
+                new_tools.append(new_tool)
+
+        return new_tools
+
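Concretely, `_convert_tool_format` maps an OpenAI-style function tool to a Bedrock `toolSpec`; a small worked example (the tool definition is invented for illustration):

# Example input:
original_tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string", "description": "City name"}},
                "required": ["city"],
            },
        },
    }
]
# _convert_tool_format(original_tools) returns:
# [{"toolSpec": {"name": "get_weather",
#                "description": "Get the current weather for a city.",
#                "inputSchema": {"json": {"type": "object",
#                                         "properties": {"city": {"type": "string", "description": "City name"}},
#                                         "required": ["city"]}}}}]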
+    def generate_response(
+        self,
+        messages: List[Dict[str, str]],
+        response_format=None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
+    ):
+        """
+        Generate a response based on the given messages using AWS Bedrock.
+
+        Args:
+            messages (list): List of message dicts containing 'role' and 'content'.
+            tools (list, optional): List of tools that the model can call. Defaults to None.
+            tool_choice (str, optional): Tool choice method. Defaults to "auto".
+
+        Returns:
+            str: The generated response.
+        """
+
+        if tools:
+            # Use the converse method when tools are provided
+            messages = [
+                {
+                    "role": "user",
+                    "content": [{"text": message["content"]} for message in messages],
+                }
+            ]
+            inference_config = {
+                "temperature": self.model_kwargs["temperature"],
+                "maxTokens": self.model_kwargs["max_tokens_to_sample"],
+                "topP": self.model_kwargs["top_p"],
+            }
+            tools_config = {"tools": self._convert_tool_format(tools)}
+
+            response = self.client.converse(
+                modelId=self.config.model,
+                messages=messages,
+                inferenceConfig=inference_config,
+                toolConfig=tools_config,
+            )
+        else:
+            # Use the invoke_model method when no tools are provided
+            prompt = self._format_messages(messages)
+            provider = self.config.model.split(".")[0]
+            input_body = self._prepare_input(provider, self.config.model, prompt, self.model_kwargs)
+            body = json.dumps(input_body)
+
+            response = self.client.invoke_model(
+                body=body,
+                modelId=self.config.model,
+                accept="application/json",
+                contentType="application/json",
+            )
+
+        return self._parse_response(response, tools)
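End to end, the two paths might be exercised like this (a sketch assuming `BaseLlmConfig` accepts a `model` argument and that `model_kwargs` carries `temperature`, `top_p`, and `max_tokens_to_sample`; the model id is an example, not from the commit):

# Hedged usage sketch; values are assumptions:
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.aws_bedrock import AWSBedrockLLM

llm = AWSBedrockLLM(BaseLlmConfig(model="anthropic.claude-v2"))

# No tools -> invoke_model path, returns the completion string:
text = llm.generate_response([{"role": "user", "content": "Hello!"}])

# Tools -> converse path, returns {"tool_calls": [...]}:
calls = llm.generate_response(
    [{"role": "user", "content": "Weather in Paris?"}],
    tools=original_tools,
)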