Skip to content

Commit b7381ab

Browse files
fix integrations
1 parent 3390d96 commit b7381ab

File tree

3 files changed

Lines changed: 15 additions & 40 deletions

3 files changed

Lines changed: 15 additions & 40 deletions

rllm/engine/rollout/openai_engine.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@ def __init__(self, model: str, tokenizer=None, api_retries: int = 3, base_url: s
3434
async def chat_completion(self, messages: list[dict], **kwargs) -> ModelOutput:
3535
sampling_params = self.sampling_params.copy()
3636
sampling_params.update(kwargs)
37+
sampling_params.pop("model", None)
3738
retries = self.api_retries
3839
while retries > 0:
3940
try:

rllm/integrations/smolagents.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@ class SmolTool:
3838

3939
# Import BaseAgent from rLLM for wrapper classes
4040
from rllm.agents.agent import Step, Trajectory
41-
from rllm.engine.rollout.rollout_engine import ModelOutput, RolloutEngine
41+
from rllm.engine import ModelOutput
4242

4343
logger = logging.getLogger(__name__)
4444

@@ -94,7 +94,7 @@ async def arun(
9494

9595
self.logger.log_task(
9696
content=self.task.strip(),
97-
subtitle=f"{type(self.model).__name__} - {(self.model.model_id if hasattr(self.model, 'model_id') else '')}",
97+
subtitle=f"{type(self.model).__name__}",
9898
level=LogLevel.INFO,
9999
title=self.name if hasattr(self, "name") else None,
100100
)
@@ -128,7 +128,7 @@ async def arun(
128128
total_output_tokens = 0
129129
correct_token_usage = True
130130
for step in self.memory.steps:
131-
if isinstance(step, ActionStep | PlanningStep):
131+
if isinstance(step, ActionStep) or isinstance(step, PlanningStep):
132132
if step.token_usage is None:
133133
correct_token_usage = False
134134
break
@@ -312,15 +312,12 @@ class RLLMOpenAIModel(SmolModel):
312312
- Skips MessageRole.TOOL_CALL messages (handled as part of assistant messages)
313313
"""
314314

315-
def __init__(self, rollout_engine: RolloutEngine, **kwargs):
315+
def __init__(self, rollout_engine=None, **kwargs):
316316
"""
317317
Initialize the RLLM-integrated OpenAI model.
318318
319319
Args:
320320
rollout_engine: rLLM's RolloutEngine instance
321-
application_id: Unique identifier for the application
322-
model_id: The model identifier (for compatibility)
323-
sampling_params: Sampling parameters for generation
324321
**kwargs: Additional arguments (ignored, for compatibility)
325322
"""
326323
self.rollout_engine = rollout_engine
@@ -330,6 +327,9 @@ def __init__(self, rollout_engine: RolloutEngine, **kwargs):
330327
# Store kwargs for potential future use
331328
self.kwargs = kwargs
332329

330+
if not rollout_engine:
331+
raise ValueError("rollout_engine is required for RLLMOpenAIModel. Pass an instance of rLLM's RolloutEngine.")
332+
333333
async def generate_async(self, messages: list[dict[str, Any]], stop_sequences: list[str] | None = None, response_format: dict[str, str] | None = None, tools_to_call_from: list | None = None, **kwargs) -> Any:
334334
"""
335335
Async version of generate that can be called from async contexts.
@@ -342,7 +342,7 @@ async def generate_async(self, messages: list[dict[str, Any]], stop_sequences: l
342342
# Handle ChatMessage objects from SmolAgent using the helper method
343343
prompt = self._convert_smolagent_messages_to_openai(messages)
344344

345-
model_output: ModelOutput = await self.rollout_engine.get_model_response(prompt, **kwargs)
345+
model_output: ModelOutput = await self.rollout_engine.get_model_response(prompt, max_tokens=kwargs.pop("max_tokens", 4096), **kwargs)
346346

347347
# Extract text and token usage from ModelOutput
348348
response_text = model_output.text

rllm/integrations/strands.py

Lines changed: 6 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -9,45 +9,22 @@
99
from strands.types.tools import ToolSpec
1010

1111
from rllm.agents.agent import Step, Trajectory
12-
from rllm.engine.rollout import ModelOutput, RolloutEngine
12+
from rllm.engine import ModelOutput, RolloutEngine
1313

1414
T = TypeVar("T", bound=BaseModel)
1515

1616

1717
class RLLMModel(Model):
1818
"""Model class that uses rLLM's RolloutEngine for inference."""
1919

20-
def __init__(self, rollout_engine: RolloutEngine, model_id: str = "gpt-4", **model_config):
20+
def __init__(self, rollout_engine: RolloutEngine, **kwargs):
2121
"""Initialize the RLLMModel.
2222
2323
Args:
2424
rollout_engine: The rLLM RolloutEngine instance to use for inference
25-
model_id: The model ID to use
26-
**model_config: Additional model configuration
2725
"""
2826
self.rollout_engine = rollout_engine
29-
self.config = {"model_id": model_id, "params": model_config}
30-
31-
def update_config(self, **model_config: Any) -> None:
32-
"""Update the model configuration.
33-
34-
Args:
35-
**model_config: Configuration overrides.
36-
"""
37-
if "model_id" in model_config:
38-
self.config["model_id"] = model_config.pop("model_id")
39-
40-
if "params" not in self.config:
41-
self.config["params"] = {}
42-
self.config["params"].update(model_config)
43-
44-
def get_config(self) -> dict[str, Any]:
45-
"""Get the model configuration.
46-
47-
Returns:
48-
The model's configuration.
49-
"""
50-
return self.config.copy()
27+
self.kwargs = kwargs
5128

5229
async def structured_output(self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any) -> AsyncGenerator[dict[str, T | Any], None]:
5330
"""Get structured output from the model.
@@ -73,9 +50,7 @@ async def structured_output(self, output_model: type[T], prompt: Messages, syste
7350
messages[-1]["content"] = f"{original_content}\n\nPlease respond with a JSON object that matches this schema: {output_model.model_json_schema()}"
7451

7552
# Get response from rollout engine
76-
model_output: ModelOutput = await self.rollout_engine.get_model_response(messages, model=self.config["model_id"], **self.config.get("params", {}), **kwargs)
77-
78-
response_text = model_output.text
53+
response_text = (await self.rollout_engine.get_model_response(messages, **kwargs)).text
7954

8055
try:
8156
# Try to parse the response as JSON and convert to the output model
@@ -129,7 +104,7 @@ async def stream(
129104
yield {"contentBlockStart": {"start": {}}}
130105

131106
# Get response from rollout engine
132-
model_output: ModelOutput = await self.rollout_engine.get_model_response(chat_messages, model=self.config["model_id"], **self.config.get("params", {}), **kwargs)
107+
model_output: ModelOutput = await self.rollout_engine.get_model_response(chat_messages, **kwargs)
133108

134109
# Extract text from ModelOutput
135110
response_text = model_output.text
@@ -203,11 +178,10 @@ def _convert_messages_to_chat_format(self, messages: Messages, system_prompt: st
203178

204179

205180
class StrandsAgent(Agent):
206-
def __init__(self, model: str, **kwargs):
181+
def __init__(self, model: RLLMModel, **kwargs):
207182
"""Initialize StrandsAgent with trajectory tracking.
208183
209184
Args:
210-
model: The model to use (can be a string or Model instance)
211185
**kwargs: Additional arguments to pass to the base Agent class
212186
"""
213187

0 commit comments

Comments
 (0)