Skip to content

Commit 3fa67af

Browse files
update workflow design
1 parent e6ffce8 commit 3fa67af

File tree

4 files changed

+176
-67
lines changed

4 files changed

+176
-67
lines changed

rllm/engine/agent_workflow_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ async def process_task_with_retry(task: dict, uid: str) -> Episode:
6868
try:
6969
for retry_attempt in range(1, self.retry_limit + 1):
7070
try:
71-
episode = await workflow(task=task, uid=uid, **kwargs)
71+
episode = await workflow.run_with_termination_handling(task=task, uid=uid, **kwargs)
7272
return episode
7373
except Exception as e:
7474
print(f"Rollout {uid} failed on attempt {retry_attempt}/{self.retry_limit}: {e}")
@@ -177,7 +177,7 @@ def _transform_results_for_verl(self, episodes: list[Episode], task_ids: np.ndar
177177

178178
episode_ids.extend([episode.id] * total_steps)
179179
is_correct.extend([episode.is_correct] * total_steps)
180-
termination_reasons.extend([episode.termination_reason if episode.termination_reason is not None else TerminationReason.ENV_DONE] * total_steps)
180+
termination_reasons.extend([episode.termination_reason if episode.termination_reason is not None else TerminationReason.UNKNOWN] * total_steps)
181181
metrics.extend([episode.metrics] * total_steps)
182182
repeat_counts.append(total_steps)
183183

rllm/workflows/multi_turn_workflow.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from rllm.agents.agent import Episode
2+
from rllm.workflows.workflow import TerminationEvent, TerminationReason, Workflow
3+
4+
5+
class MultiTurnWorkflow(Workflow):
6+
def __init__(
7+
self,
8+
agent_cls,
9+
env_cls,
10+
agent_args=None,
11+
env_args=None,
12+
max_steps=5,
13+
**kwargs,
14+
):
15+
super().__init__(**kwargs)
16+
17+
# Initialize mutable defaults
18+
agent_args = dict(agent_args) if agent_args is not None else {}
19+
env_args = dict(env_args) if env_args is not None else {}
20+
21+
self.agent = agent_cls(**agent_args)
22+
self.register_agent(self.agent)
23+
self.env = env_cls(**env_args)
24+
self.max_steps = max_steps
25+
26+
async def run(self, task: dict, uid: str, **kwargs) -> Episode | None:
27+
"""Execute a multi-step workflow"""
28+
29+
observation, info = await self.run_in_executor(self.reset, task=task, uid=uid) # returns observation and info from the environment
30+
31+
self.agent.update_from_env(observation, 0, False, info)
32+
33+
for _ in range(1, self.max_steps + 1):
34+
response = (await self.get_model_response(self.agent, **kwargs)).text
35+
action = self.agent.update_from_model(response)
36+
37+
next_obs, reward, done, info = await self.run_in_executor(self.env.step, action)
38+
self.agent.update_from_env(next_obs, reward, done, info)
39+
40+
if self._termination_buffer is not None:
41+
raise TerminationEvent(self._termination_buffer)
42+
43+
if done:
44+
raise TerminationReason.ENV_DONE
45+
46+
raise TerminationReason.MAX_TURNS_EXCEEDED
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from rllm.agents.agent import Episode
2+
from rllm.workflows.workflow import TerminationEvent, TerminationReason, Workflow
3+
4+
5+
class SingleTurnWorkflow(Workflow):
6+
def __init__(
7+
self,
8+
agent_cls,
9+
env_cls,
10+
agent_args=None,
11+
env_args=None,
12+
**kwargs,
13+
):
14+
super().__init__(**kwargs)
15+
16+
# Initialize mutable defaults
17+
agent_args = dict(agent_args) if agent_args is not None else {}
18+
env_args = dict(env_args) if env_args is not None else {}
19+
20+
self.agent = agent_cls(**agent_args)
21+
self.register_agent(self.agent)
22+
self.env = env_cls(**env_args)
23+
24+
async def run(self, task: dict, uid: str, **kwargs) -> Episode | None:
25+
observation, info = await self.run_in_executor(self.reset, task=task, uid=uid) # returns observation and info from the environment
26+
self.agent.update_from_env(observation, 0, False, info)
27+
28+
response = (await self.get_model_response(self.agent, **kwargs)).text
29+
action = self.agent.update_from_model(response)
30+
31+
next_obs, reward, done, info = await self.run_in_executor(self.env.step, action)
32+
self.agent.update_from_env(next_obs, reward, done, info)
33+
34+
if self._termination_buffer is not None:
35+
raise TerminationEvent(self._termination_buffer)
36+
37+
raise TerminationReason.ENV_DONE

0 commit comments

Comments
 (0)