9
9
from strands .types .tools import ToolSpec
10
10
11
11
from rllm .agents .agent import Step , Trajectory
12
- from rllm .engine . rollout import ModelOutput , RolloutEngine
12
+ from rllm .engine import ModelOutput , RolloutEngine
13
13
14
14
T = TypeVar ("T" , bound = BaseModel )
15
15
16
16
17
17
class RLLMModel (Model ):
18
18
"""Model class that uses rLLM's RolloutEngine for inference."""
19
19
20
- def __init__ (self , rollout_engine : RolloutEngine , model_id : str = "gpt-4" , ** model_config ):
20
+ def __init__ (self , rollout_engine : RolloutEngine , ** kwargs ):
21
21
"""Initialize the RLLMModel.
22
22
23
23
Args:
24
24
rollout_engine: The rLLM RolloutEngine instance to use for inference
25
- model_id: The model ID to use
26
- **model_config: Additional model configuration
27
25
"""
28
26
self .rollout_engine = rollout_engine
29
- self .config = {"model_id" : model_id , "params" : model_config }
30
-
31
- def update_config (self , ** model_config : Any ) -> None :
32
- """Update the model configuration.
33
-
34
- Args:
35
- **model_config: Configuration overrides.
36
- """
37
- if "model_id" in model_config :
38
- self .config ["model_id" ] = model_config .pop ("model_id" )
39
-
40
- if "params" not in self .config :
41
- self .config ["params" ] = {}
42
- self .config ["params" ].update (model_config )
43
-
44
- def get_config (self ) -> dict [str , Any ]:
45
- """Get the model configuration.
46
-
47
- Returns:
48
- The model's configuration.
49
- """
50
- return self .config .copy ()
27
+ self .kwargs = kwargs
51
28
52
29
async def structured_output (self , output_model : type [T ], prompt : Messages , system_prompt : str | None = None , ** kwargs : Any ) -> AsyncGenerator [dict [str , T | Any ], None ]:
53
30
"""Get structured output from the model.
@@ -73,9 +50,7 @@ async def structured_output(self, output_model: type[T], prompt: Messages, syste
73
50
messages [- 1 ]["content" ] = f"{ original_content } \n \n Please respond with a JSON object that matches this schema: { output_model .model_json_schema ()} "
74
51
75
52
# Get response from rollout engine
76
- model_output : ModelOutput = await self .rollout_engine .get_model_response (messages , model = self .config ["model_id" ], ** self .config .get ("params" , {}), ** kwargs )
77
-
78
- response_text = model_output .text
53
+ response_text = (await self .rollout_engine .get_model_response (messages , ** kwargs )).text
79
54
80
55
try :
81
56
# Try to parse the response as JSON and convert to the output model
@@ -129,7 +104,7 @@ async def stream(
129
104
yield {"contentBlockStart" : {"start" : {}}}
130
105
131
106
# Get response from rollout engine
132
- model_output : ModelOutput = await self .rollout_engine .get_model_response (chat_messages , model = self . config [ "model_id" ], ** self . config . get ( "params" , {}), ** kwargs )
107
+ model_output : ModelOutput = await self .rollout_engine .get_model_response (chat_messages , ** kwargs )
133
108
134
109
# Extract text from ModelOutput
135
110
response_text = model_output .text
@@ -203,11 +178,10 @@ def _convert_messages_to_chat_format(self, messages: Messages, system_prompt: st
203
178
204
179
205
180
class StrandsAgent (Agent ):
206
- def __init__ (self , model : str , ** kwargs ):
181
+ def __init__ (self , model : RLLMModel , ** kwargs ):
207
182
"""Initialize StrandsAgent with trajectory tracking.
208
183
209
184
Args:
210
- model: The model to use (can be a string or Model instance)
211
185
**kwargs: Additional arguments to pass to the base Agent class
212
186
"""
213
187
0 commit comments