Commit 8fa061c

fix agents and envs
1 parent 3fa67af commit 8fa061c

Showing 6 changed files with 95 additions and 39 deletions.


rllm/agents/code_agent.py

Lines changed: 38 additions & 21 deletions
@@ -23,17 +23,17 @@ class CompetitionCodingAgent(BaseAgent):
     A code agent that iteratively writes code to solve a problem.
     """
 
-    def __init__(self, remove_thinking=False, max_tests=2, public_test_only=True):
+    def __init__(self, accumulate_thinking=False, max_tests=2, public_test_only=False):
         """
         Initialize the CodeAgent.
         """
         self.revise_instruction = "Here's the feedback from the previous attempt. Revise the code to fix the errors and improve the solution."
         self._trajectory = Trajectory()
         self.messages = []
-        self.remove_thinking = remove_thinking
+        self.accumulate_thinking = accumulate_thinking
+
         self.max_tests = max_tests
         self.public_test_only = public_test_only
-        self.current_observation = None
 
     def format_test_results(self, test_results: list[dict]) -> str:
         def normalize_string(s):
@@ -102,46 +102,63 @@ def update_from_env(self, observation: Any, reward: float, done: bool, info: dict
         else:
             formatted_observation = str(observation)
 
+        # Update reward on the latest step
+        if self.trajectory.steps:
+            cur_step = self.get_current_state()
+            cur_step.reward = reward
+            cur_step.done = done
+            cur_step.info = info
+
         if done:
             return
 
         self.messages.append({"role": "user", "content": formatted_observation})
-        self.current_observation = formatted_observation
+
+        new_step = Step(observation=formatted_observation)
+        self._trajectory.steps.append(new_step)
 
     def update_from_model(self, response: str, **kwargs) -> Action:
         """
         Updates the agent's internal state based on the model's response.
         """
-        content = response
-        action = response
-
-        # Handle thinking removal if needed
-        if self.remove_thinking and content.count("</think>") == 1:
-            thought, action = response.split("</think>")
-            thought += "</think>"
-            action = action.strip()
-            self.messages.append({"role": "assistant", "content": action})
+        self.messages.append({"role": "assistant", "content": response})
+
+        cur_step = self.get_current_state()
+        cur_step.chat_completions = self.chat_completions
+        cur_step.model_response = response
+
+        if response.count("</think>") == 1:
+            thought, sep, action = response.partition("</think>")
+            thought = thought + sep
+            action = Action(action.strip())
         else:
-            self.messages.append({"role": "assistant", "content": response})
+            thought = None
+            action = Action(response.strip())
 
-        # Create new step
-        new_step = Step(chat_completions=copy.deepcopy(self.chat_completions), action=action, model_response=response, observation=self.current_observation)
-        self._trajectory.steps.append(new_step)
+        cur_step.thought = thought
+        cur_step.action = action
 
-        return Action(action=action)
+        return action
 
     def reset(self):
         """
         Resets the agent's internal state for a new episode.
         """
         self._trajectory = Trajectory()
         self.messages = []
-        self.current_observation = None
 
     @property
     def chat_completions(self) -> list[dict[str, str]]:
-        """Returns the history of messages for chat completion."""
-        return self.messages
+        """Return conversation history for model interaction."""
+        # remove thinking from assistant messages if not accumulate_thinking except the last one
+        messages = copy.deepcopy(self.messages)
+        if not self.accumulate_thinking:
+            for msg in messages[:-1]:
+                if msg["role"] == "assistant":
+                    _, sep, after = msg["content"].partition("</think>")
+                    if sep:
+                        msg["content"] = after
+        return messages
 
     @property
     def trajectory(self) -> Trajectory:
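Note on the chat_completions change above: when accumulate_thinking is False, the agent now strips the <think>…</think> prefix from every assistant message except the most recent one before handing the history back to the model. Below is a minimal standalone sketch of that behavior; the helper name strip_thinking and the sample messages are illustrative only and are not part of the commit.

import copy

def strip_thinking(messages: list[dict], accumulate_thinking: bool = False) -> list[dict]:
    # Mirror of the new chat_completions logic: drop the thinking prefix from all
    # but the last assistant message unless accumulate_thinking is enabled.
    messages = copy.deepcopy(messages)
    if not accumulate_thinking:
        for msg in messages[:-1]:
            if msg["role"] == "assistant":
                _, sep, after = msg["content"].partition("</think>")
                if sep:
                    msg["content"] = after
    return messages

history = [
    {"role": "user", "content": "Reverse a string."},
    {"role": "assistant", "content": "<think>use slicing</think>def rev(s): return s[::-1]"},
    {"role": "user", "content": "Here's the feedback from the previous attempt..."},
    {"role": "assistant", "content": "<think>handle empty input</think>def rev(s): return s[::-1]"},
]

# The first assistant message loses its <think> block; the last one keeps it.
print(strip_thinking(history))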

rllm/agents/math_agent.py

Lines changed: 38 additions & 14 deletions
@@ -13,38 +13,62 @@ def __init__(self, accumulate_thinking=True):
         """
         Initialize the MathAgent.
         """
-        self.instruction = "Let's think step by step, and put your final answer within \\boxed{}."
         self._trajectory = Trajectory()
         self.messages = []
         self.accumulate_thinking = accumulate_thinking
 
     def update_from_env(self, observation: Any, reward: float, done: bool, info: dict, **kwargs):
         """Process environment feedback and update internal state."""
 
-        # Format observation based on whether it's the initial problem or subsequent feedback
-        if not self.trajectory.steps:
-            # Initial problem presentation
-            assert isinstance(observation, dict) and "question" in observation
-            question = observation["question"]
-            formatted_observation = f"{question} {self.instruction}"
+        # If observation is None, this is a reward update for the existing step
+        if observation is None:
+            if self.trajectory.steps:
+                cur_step = self.get_current_state()
+                cur_step.reward = reward
+                cur_step.done = done
+                cur_step.info = info
+            return
+
+        # This is a new observation, create a new step
+        if isinstance(observation, dict):
+            formatted_observation = observation["question"]
+        elif isinstance(observation, str):
+            formatted_observation = observation
         else:
-            # Follow-up correction prompt
-            formatted_observation = "Your previous answer may contain a mistake. Please review it carefully and answer again. Put your final answer within \\boxed{}."
+            raise ValueError(f"Invalid observation type: {type(observation)}")
 
         self.messages.append({"role": "user", "content": formatted_observation})
 
+        new_step = Step(observation=formatted_observation)
+        self._trajectory.steps.append(new_step)
+
     def update_from_model(self, response: str, **kwargs) -> Action:
         """
         Updates the agent's internal state based on the model's response.
         """
+
+        # Update the latest step
         self.messages.append({"role": "assistant", "content": response})
-        new_step = Step(chat_completions=copy.deepcopy(self.chat_completions))
-        self.trajectory.steps.append(new_step)
 
-        return Action(action=response)
+        cur_step = self.get_current_state()
+        cur_step.chat_completions = self.chat_completions
+        cur_step.model_response = response
+
+        if response.count("</think>") == 1:
+            thought, sep, action = response.partition("</think>")
+            thought = thought + sep
+            action = Action(action.strip())
+        else:
+            thought = None
+            action = Action(response.strip())
+
+        cur_step.thought = thought
+        cur_step.action = action
+
+        return action
 
-    def reset(self):
-        """Reset agent state for new episode."""
+    def reset(self) -> None:
+        """Reset agent state for new episode (wipes trajectory and messages)."""
         self._trajectory = Trajectory()
         self.messages = []
 
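Both agents now parse the model response the same way: a single </think> marker splits the thought from the action. A self-contained sketch of that parsing follows; the function name split_response and the example response are illustrative, and a plain tuple stands in for the repo's Action type.

def split_response(response: str) -> tuple[str | None, str]:
    # One "</think>" marker: everything before it (plus the marker) is the thought,
    # everything after it is the action. Otherwise the whole response is the action.
    if response.count("</think>") == 1:
        thought, sep, action = response.partition("</think>")
        return thought + sep, action.strip()
    return None, response.strip()

thought, answer = split_response("<think>2 + 2 = 4</think>\\boxed{4}")
assert thought == "<think>2 + 2 = 4</think>"
assert answer == "\\boxed{4}"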

rllm/environments/base/multi_turn_env.py

Lines changed: 6 additions & 2 deletions
@@ -25,9 +25,13 @@ def __init__(self, task: dict | None = None, max_turns: int = 3, **kwargs):
         self.max_turns = max_turns
         self.current_turn = 0
         self.done = False
-        self.history: list[Any] = []
+        self.history = []
+
+    def reset(self, task: dict | None = None):
+        # Use the provided task if available, otherwise use the default task
+        if task is not None:
+            self.task = task
 
-    def reset(self):
        self.done = False
        self.current_turn = 0
        self.history = []
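To illustrate the new reset(task=...) contract: a task passed to reset overrides the default task, otherwise the environment keeps the one it was constructed with. The class below is a minimal stand-in written for this example, not the actual MultiTurnEnvironment.

class _StubEnv:
    # Stand-in sketch that mirrors only the reset(task=...) behavior from the diff.
    def __init__(self, task: dict | None = None, max_turns: int = 3):
        self.task = task
        self.max_turns = max_turns
        self.current_turn = 0
        self.done = False
        self.history = []

    def reset(self, task: dict | None = None):
        # Use the provided task if available, otherwise keep the default task
        if task is not None:
            self.task = task
        self.done = False
        self.current_turn = 0
        self.history = []

env = _StubEnv(task={"question": "default question"})
env.reset()                                        # keeps the task passed at construction
assert env.task == {"question": "default question"}
env.reset(task={"question": "override question"})  # swaps in a per-episode task
assert env.task == {"question": "override question"}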

rllm/environments/code/competition_coding.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ def reset(self, task=None, seed=None):
         self.prev_reward = None
 
         # Return the first question
-        return {"question": self.task["question"]}, {}
+        return self.task, {}
 
     def step(self, action):
         """

rllm/rewards/countdown_reward.py

Lines changed: 5 additions & 1 deletion
@@ -1,6 +1,7 @@
 import random
 import re
 
+from rllm import Action
 from rllm.rewards.reward_types import RewardOutput
 
 
@@ -109,7 +110,7 @@ def compute_score(solution_str, ground_truth, method="strict", format_score=0.1,
         return format_score
 
 
-def countdown_reward_fn(task_info: dict, action: str) -> RewardOutput:
+def countdown_reward_fn(task_info: dict, action: str | Action) -> RewardOutput:
     """
     A specialized reward function for countdown tasks using the compute_score helper.
 
@@ -124,6 +125,9 @@ def countdown_reward_fn(task_info: dict, action: str) -> RewardOutput:
         RewardOutput with reward and metadata
     """
     try:
+        if isinstance(action, Action):
+            action = action.action
+
         # Extract basic info
         target = task_info.get("target")
         nums = task_info.get("nums", [])

rllm/rewards/reward_fn.py

Lines changed: 7 additions & 0 deletions
@@ -1,5 +1,6 @@
 from typing import Protocol, runtime_checkable
 
+from rllm.agents.agent import Action
 from rllm.rewards.code_reward import RewardCodeFn
 from rllm.rewards.math_reward import RewardMathFn
 from rllm.rewards.reward_types import RewardConfig, RewardInput, RewardOutput
@@ -53,6 +54,8 @@ def math_reward_fn(task_info: dict, action: str) -> RewardOutput:
     """
     reward_config = RewardConfig()
     reward_fn = RewardMathFn(reward_config)
+    if isinstance(action, Action):
+        action = action.action
     return reward_fn(task_info, action)
 
 
@@ -69,6 +72,8 @@ def search_reward_fn(task_info: dict, action: str) -> RewardOutput:
     """
     reward_config = RewardConfig()
     reward_fn = RewardSearchFn(reward_config)
+    if isinstance(action, Action):
+        action = action.action
 
     # Create RewardInput from task_info and action
     reward_input = RewardInput(task_info=task_info, action=action)
@@ -89,4 +94,6 @@ def code_reward_fn(task_info: dict, action: str) -> RewardOutput:
     """
     reward_config = RewardConfig()
     reward_fn = RewardCodeFn(reward_config)
+    if isinstance(action, Action):
+        action = action.action
     return reward_fn(task_info, action)
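The reward functions here, and countdown_reward_fn above, now accept either a raw string or an Action and unwrap the latter before scoring. A self-contained sketch of that pattern follows; the Action dataclass below is only a stand-in for rllm.agents.agent.Action, assuming it exposes its payload via an .action attribute as the diff suggests.

from dataclasses import dataclass

@dataclass
class Action:
    action: str  # stand-in for rllm.agents.agent.Action

def normalize_action(action) -> str:
    # Same unwrapping the reward functions perform before scoring.
    if isinstance(action, Action):
        action = action.action
    return action

assert normalize_action(Action("\\boxed{42}")) == "\\boxed{42}"
assert normalize_action("\\boxed{42}") == "\\boxed{42}"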
