
Commit b45a944

fix non-workflow trainer/engine
1 parent: 62cc8b9 · commit: b45a944

File tree: 3 files changed (+23 additions, -17 deletions)


rllm/engine/agent_execution_engine.py

Lines changed: 6 additions & 4 deletions

@@ -137,14 +137,16 @@ async def get_model_response(self, prompt, application_id, **kwargs) -> str:
             NotImplementedError: If the engine type is not supported
         """

-        sampling_params = deepcopy(self.sampling_params).update(kwargs)
+        sampling_params = self.sampling_params.copy()
+        sampling_params.update(kwargs)

         if self.engine_name == "openai":
-            output = await self.rollout_engine.get_model_response(prompt, application_id, **sampling_params)
+            output = await self.rollout_engine.get_model_response(prompt, application_id=application_id, **sampling_params)
             return output.text
         elif self.engine_name == "verl":
-            validate = kwargs.get("meta_info", {}).get("validate", False)
-            output = await self.rollout_engine.get_model_response(prompt, application_id, validate=validate, **sampling_params)
+            meta_data = sampling_params.pop("meta_info", {})
+            validate = meta_data.get("validate", False)
+            output = await self.rollout_engine.get_model_response(prompt, application_id=application_id, validate=validate, **sampling_params)
             return output.text
         else:
             raise NotImplementedError(f"Engine type '{self.engine_name}' not supported")
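
The dropped line chained dict.update() onto a deepcopy; dict.update() mutates in place and returns None, so the old code assigned None to sampling_params and then tried to splat it into the engine call. A minimal standalone sketch of that failure mode and of the copy-then-update fix (the parameter names and values here are illustrative, not taken from the engine's defaults):

from copy import deepcopy

sampling_params = {"temperature": 0.7, "top_p": 0.95}  # illustrative defaults
kwargs = {"max_tokens": 256}                            # illustrative per-call override

broken = deepcopy(sampling_params).update(kwargs)  # update() returns None
assert broken is None                              # the old code passed this None onward

fixed = sampling_params.copy()  # copy first, as the new code does
fixed.update(kwargs)            # then merge the overrides in place
assert fixed == {"temperature": 0.7, "top_p": 0.95, "max_tokens": 256}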

rllm/trainer/config/agent_ppo_trainer.yaml

Lines changed: 1 addition & 1 deletion

@@ -22,7 +22,7 @@ rllm:
     name: miniwobagent
     max_steps: 20
     trajectory_timeout: null
-    # overlong_filter: False # TODO: refactor as compact_filtering
+    overlong_filter: False
     agent_args: {}
     engine_args: {}
   env:
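
The config change promotes overlong_filter from a commented-out TODO to a real key, matching the trainer's lookup self.config.rllm.agent.get("overlong_filter", False) in the next file. A small sketch of that lookup, assuming an OmegaConf/Hydra-style config object and showing only the keys relevant here:

from omegaconf import OmegaConf

cfg = OmegaConf.create({"rllm": {"agent": {"overlong_filter": False}}})

print(cfg.rllm.agent.get("overlong_filter", False))  # False, read from the YAML key
print(cfg.rllm.agent.get("some_missing_key", {}))    # {} -- the default applies when a key is absent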

rllm/trainer/verl/agent_ppo_trainer.py

Lines changed: 16 additions & 12 deletions

@@ -55,7 +55,7 @@ def __init__(
         assert self.config.actor_rollout_ref.hybrid_engine, "Only hybrid engine is supported"
         assert self.config.actor_rollout_ref.rollout.mode == "async", "Only async rollout mode is supported"

-        if self.config.rllm.agent.stepwise_advantage.enable:
+        if self.config.rllm.stepwise_advantage.enable:
             print("Using step-level advantage, max_prompt_length and max_response_length will be applied step-wise")
         else:
             print("Using trajectory-level advantage, max_prompt_length and max_response_length will be applied episode-wise")

@@ -76,8 +76,8 @@ def init_workers(self):
             agent_args=self.agent_args,
             env_class=self.env_class,
             env_args=self.env_args,
-            enforce_max_prompt_length=self.config.rllm.agent.stepwise_advantage.enable,
-            trajectory_timeout=self.config.rllm.rllm.agent.trajectory_timeout,
+            enforce_max_prompt_length=self.config.rllm.stepwise_advantage.enable,
+            trajectory_timeout=self.config.rllm.agent.trajectory_timeout,
             overlong_filter=self.config.rllm.agent.get("overlong_filter", False),
             **self.config.rllm.agent.get("engine_args", {}),
         )
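
These path fixes (and the ones in the hunks below) consistently read stepwise_advantage from config.rllm rather than config.rllm.agent, and collapse the doubled config.rllm.rllm prefix. A hedged sketch of the layout those corrected accessors imply; the nesting is inferred from this commit's diff and the YAML hunk above, the values are placeholders, and the real schema contains more keys than shown:

from omegaconf import OmegaConf

config = OmegaConf.create({
    "rllm": {
        "stepwise_advantage": {"enable": True, "mode": "broadcast"},
        "mask_truncated_samples": False,
        "rejection_sample": {"enable": False},
        "agent": {
            "name": "miniwobagent",
            "max_steps": 20,
            "trajectory_timeout": None,
            "overlong_filter": False,
            "agent_args": {},
            "engine_args": {},
        },
    }
})

# The corrected accessors resolve against this layout; the old ones
# (config.rllm.agent.stepwise_advantage, config.rllm.rllm.*) would not.
assert config.rllm.stepwise_advantage.mode == "broadcast"
assert config.rllm.agent.trajectory_timeout is None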
@@ -167,7 +167,7 @@ def fit_agent(self):
                 with marked_timer("step", timing_raw):
                     self.init_envs_and_agents(batch)

-                    if self.config.rllm.agent.stepwise_advantage.enable:
+                    if self.config.rllm.stepwise_advantage.enable:
                         final_gen_batch_output = self.generate_agent_steps(timing_raw=timing_raw, meta_info=batch.meta_info, uids=batch.non_tensor_batch["uid"])
                         repeat_counts = final_gen_batch_output.meta_info["repeat_counts"]
                         # need to repeat to make shape match

@@ -227,7 +227,7 @@ def fit_agent(self):
                     if self.config.rllm.rejection_sample.enable:
                         # log the actual complete training rewards before rejection sampling
                         token_level_rewards = None  # for metrics calculation
-                        if self.config.rllm.agent.stepwise_advantage.enable:
+                        if self.config.rllm.stepwise_advantage.enable:
                             is_pad_step = batch.non_tensor_batch["is_pad_step"]
                             non_pad_step_indices = np.where(is_pad_step == False)[0]
                             non_pad_steps = batch.select_idxs(non_pad_step_indices)

@@ -249,7 +249,7 @@ def fit_agent(self):
                         # Filter batch to keep only valid samples
                         batch = batch[valid_mask]

-                        if self.config.rllm.agent.stepwise_advantage.enable and self.config.rllm.rllm.stepwise_advantage.mode == "broadcast":
+                        if self.config.rllm.stepwise_advantage.enable and self.config.rllm.stepwise_advantage.mode == "broadcast":
                             # batch now only contains steps with valid uids
                             # filter out padding steps
                             is_pad_step = batch.non_tensor_batch["is_pad_step"]

@@ -325,23 +325,23 @@ def fit_agent(self):

                     batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]

-                    if self.config.rllm.agent.stepwise_advantage.enable:
-                        if self.config.rllm.rllm.stepwise_advantage.mode == "per_step":
+                    if self.config.rllm.stepwise_advantage.enable:
+                        if self.config.rllm.stepwise_advantage.mode == "per_step":
                             batch.batch["token_level_rewards"] = batch.batch["mc_returns"]
                             batch.non_tensor_batch["uid"] = batch.non_tensor_batch["step_ids"]

                             is_pad_step = batch.non_tensor_batch["is_pad_step"]
                             non_pad_step_indices = np.where(is_pad_step == False)[0]
                             batch = batch.select_idxs(non_pad_step_indices)  # This batch only has non_pad steps
-                        elif self.config.rllm.rllm.stepwise_advantage.mode == "broadcast":
+                        elif self.config.rllm.stepwise_advantage.mode == "broadcast":
                             # In case of step-wise advantage broadcast, we would split out the final steps, then merge again
                             is_last_step = batch.non_tensor_batch["is_last_step"]
                             last_step_indices = np.where(is_last_step == True)[0]
                             other_step_indices = np.where(is_last_step == False)[0]
                             other_step_batch = batch.select_idxs(other_step_indices)
                             batch = batch.select_idxs(last_step_indices)  # This batch only has last steps
                         else:
-                            raise ValueError(f"Stepwise advantage mode {self.config.rllm.rllm.stepwise_advantage.mode} not supported")
+                            raise ValueError(f"Stepwise advantage mode {self.config.rllm.stepwise_advantage.mode} not supported")

                     # compute advantages, executed on the driver process
                     batch = compute_advantage(

@@ -354,13 +354,17 @@ def fit_agent(self):
                         config=self.config.algorithm,
                     )

-                    if self.config.rllm.agent.stepwise_advantage.enable and self.config.rllm.rllm.stepwise_advantage.mode == "broadcast":
+                    if self.config.rllm.stepwise_advantage.enable and self.config.rllm.stepwise_advantage.mode == "broadcast":
                         # remove the padded last steps
                         # Merging the separated out steps using the advantage from last steps
                         self._stepwise_advantage_broadcast(batch, other_step_batch=other_step_batch)
                         # batch = batch.merge(other_step_batch)
                         batch = DataProto.concat([batch, other_step_batch])

+                    if self.config.rllm.mask_truncated_samples:
+                        mask = batch.batch["attention_mask"][:, -1] == 1
+                        batch = batch[~mask]
+
                     batch = self._pad_dataproto_to_world_size(batch=batch)
                     # balance the number of valid tokens on each dp rank.
                     # Note that this breaks the order of data inside the batch.
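
The newly added mask_truncated_samples block drops every sample whose attention mask is still 1 at the final position, i.e. a response that used the entire length budget and was presumably cut off rather than ending with EOS (this reading assumes right-padded sequences). A minimal tensor-level sketch of that filter:

import torch

# Two toy samples, right-padded to length 5 (1 = real token, 0 = padding).
attention_mask = torch.tensor([
    [1, 1, 1, 0, 0],   # ended before the limit -> kept
    [1, 1, 1, 1, 1],   # filled the whole budget -> treated as truncated
])

mask = attention_mask[:, -1] == 1   # same test as the new trainer code
kept = attention_mask[~mask]        # keep only the non-truncated rows
print(kept)                         # tensor([[1, 1, 1, 0, 0]])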
@@ -431,7 +435,7 @@ def _validate_agent(self):
         }
         self.init_envs_and_agents(test_batch)

-        if self.config.rllm.agent.stepwise_advantage.enable:
+        if self.config.rllm.stepwise_advantage.enable:
             test_output_gen_batch = self.generate_agent_steps(meta_info=test_batch.meta_info, uids=test_batch.non_tensor_batch["uid"])
             # for validation, we only need the last step
             is_last_step = test_output_gen_batch.non_tensor_batch["is_last_step"]
