@@ -55,7 +55,7 @@ def __init__(
         assert self.config.actor_rollout_ref.hybrid_engine, "Only hybrid engine is supported"
         assert self.config.actor_rollout_ref.rollout.mode == "async", "Only async rollout mode is supported"

-        if self.config.rllm.agent.stepwise_advantage.enable:
+        if self.config.rllm.stepwise_advantage.enable:
             print("Using step-level advantage, max_prompt_length and max_response_length will be applied step-wise")
         else:
             print("Using trajectory-level advantage, max_prompt_length and max_response_length will be applied episode-wise")
@@ -76,8 +76,8 @@ def init_workers(self):
             agent_args=self.agent_args,
             env_class=self.env_class,
             env_args=self.env_args,
-            enforce_max_prompt_length=self.config.rllm.agent.stepwise_advantage.enable,
-            trajectory_timeout=self.config.rllm.rllm.agent.trajectory_timeout,
+            enforce_max_prompt_length=self.config.rllm.stepwise_advantage.enable,
+            trajectory_timeout=self.config.rllm.agent.trajectory_timeout,
             overlong_filter=self.config.rllm.agent.get("overlong_filter", False),
             **self.config.rllm.agent.get("engine_args", {}),
         )
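The corrected paths imply a config layout in which stepwise_advantage hangs directly off rllm, while trajectory_timeout stays under rllm.agent. A minimal sketch of that layout with OmegaConf follows; the keys mirror the access paths in this diff, but the values are illustrative, not the project's actual defaults:

from omegaconf import OmegaConf

# Hypothetical config tree matching the corrected access paths above.
config = OmegaConf.create({
    "rllm": {
        "stepwise_advantage": {"enable": True, "mode": "broadcast"},
        "agent": {"trajectory_timeout": 600, "overlong_filter": False},
        "mask_truncated_samples": False,
    }
})
assert config.rllm.stepwise_advantage.enable
assert config.rllm.agent.trajectory_timeout == 600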
@@ -167,7 +167,7 @@ def fit_agent(self):
                 with marked_timer("step", timing_raw):
                     self.init_envs_and_agents(batch)

-                    if self.config.rllm.agent.stepwise_advantage.enable:
+                    if self.config.rllm.stepwise_advantage.enable:
                         final_gen_batch_output = self.generate_agent_steps(timing_raw=timing_raw, meta_info=batch.meta_info, uids=batch.non_tensor_batch["uid"])
                         repeat_counts = final_gen_batch_output.meta_info["repeat_counts"]
                         # need to repeat to make shape match
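The repeat_counts bookkeeping presumably expands per-trajectory metadata to one entry per generated step, so batch fields line up with the step outputs. A small sketch of that expansion; np.repeat and the variable names are an illustrative guess, not the trainer's exact code:

import numpy as np

# One uid per trajectory; each trajectory produced a different number of steps.
uids = np.array(["traj-0", "traj-1"])
repeat_counts = [3, 2]  # steps generated for each trajectory

# Repeat each uid once per step so per-trajectory fields match the step batch.
step_uids = np.repeat(uids, repeat_counts)
print(step_uids)  # ['traj-0' 'traj-0' 'traj-0' 'traj-1' 'traj-1']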
@@ -227,7 +227,7 @@ def fit_agent(self):
                     if self.config.rllm.rejection_sample.enable:
                         # log the actual complete training rewards before rejection sampling
                         token_level_rewards = None  # for metrics calculation
-                        if self.config.rllm.agent.stepwise_advantage.enable:
+                        if self.config.rllm.stepwise_advantage.enable:
                             is_pad_step = batch.non_tensor_batch["is_pad_step"]
                             non_pad_step_indices = np.where(is_pad_step == False)[0]
                             non_pad_steps = batch.select_idxs(non_pad_step_indices)
@@ -249,7 +249,7 @@ def fit_agent(self):
                         # Filter batch to keep only valid samples
                         batch = batch[valid_mask]

-                        if self.config.rllm.agent.stepwise_advantage.enable and self.config.rllm.rllm.stepwise_advantage.mode == "broadcast":
+                        if self.config.rllm.stepwise_advantage.enable and self.config.rllm.stepwise_advantage.mode == "broadcast":
                             # batch now only contains steps with valid uids
                             # filter out padding steps
                             is_pad_step = batch.non_tensor_batch["is_pad_step"]
@@ -325,23 +325,23 @@ def fit_agent(self):

                     batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]

-                    if self.config.rllm.agent.stepwise_advantage.enable:
-                        if self.config.rllm.rllm.stepwise_advantage.mode == "per_step":
+                    if self.config.rllm.stepwise_advantage.enable:
+                        if self.config.rllm.stepwise_advantage.mode == "per_step":
                             batch.batch["token_level_rewards"] = batch.batch["mc_returns"]
                             batch.non_tensor_batch["uid"] = batch.non_tensor_batch["step_ids"]

                             is_pad_step = batch.non_tensor_batch["is_pad_step"]
                             non_pad_step_indices = np.where(is_pad_step == False)[0]
                             batch = batch.select_idxs(non_pad_step_indices)  # This batch only has non_pad steps
-                        elif self.config.rllm.rllm.stepwise_advantage.mode == "broadcast":
+                        elif self.config.rllm.stepwise_advantage.mode == "broadcast":
                             # In case of step-wise advantage broadcast, we would split out the final steps, then merge again
                             is_last_step = batch.non_tensor_batch["is_last_step"]
                             last_step_indices = np.where(is_last_step == True)[0]
                             other_step_indices = np.where(is_last_step == False)[0]
                             other_step_batch = batch.select_idxs(other_step_indices)
                             batch = batch.select_idxs(last_step_indices)  # This batch only has last steps
                         else:
-                            raise ValueError(f"Stepwise advantage mode {self.config.rllm.rllm.stepwise_advantage.mode} not supported")
+                            raise ValueError(f"Stepwise advantage mode {self.config.rllm.stepwise_advantage.mode} not supported")

                     # compute advantages, executed on the driver process
                     batch = compute_advantage(
@@ -354,13 +354,17 @@ def fit_agent(self):
                         config=self.config.algorithm,
                     )

-                    if self.config.rllm.agent.stepwise_advantage.enable and self.config.rllm.rllm.stepwise_advantage.mode == "broadcast":
+                    if self.config.rllm.stepwise_advantage.enable and self.config.rllm.stepwise_advantage.mode == "broadcast":
                         # remove the padded last steps
                         # Merging the separated out steps using the advantage from last steps
                         self._stepwise_advantage_broadcast(batch, other_step_batch=other_step_batch)
                         # batch = batch.merge(other_step_batch)
                         batch = DataProto.concat([batch, other_step_batch])

+                    if self.config.rllm.mask_truncated_samples:
+                        mask = batch.batch["attention_mask"][:, -1] == 1
+                        batch = batch[~mask]
+
                     batch = self._pad_dataproto_to_world_size(batch=batch)
                     # balance the number of valid tokens on each dp rank.
                     # Note that this breaks the order of data inside the batch.
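The new mask_truncated_samples branch drops sequences whose attention_mask is still 1 at the final position, i.e. sequences that filled the maximum length and were likely cut off rather than finishing naturally. A self-contained sketch of the same test on plain tensors, assuming right-padded sequences and with DataProto indexing replaced by tensor indexing:

import torch

attention_mask = torch.tensor([
    [1, 1, 1, 0],  # ended before max length, right padding follows -> kept
    [1, 1, 1, 1],  # used every position, likely truncated          -> dropped
])
mask = attention_mask[:, -1] == 1  # tensor([False, True])
kept = attention_mask[~mask]      # keeps only the first row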
@@ -431,7 +435,7 @@ def _validate_agent(self):
         }
         self.init_envs_and_agents(test_batch)

-        if self.config.rllm.agent.stepwise_advantage.enable:
+        if self.config.rllm.stepwise_advantage.enable:
             test_output_gen_batch = self.generate_agent_steps(meta_info=test_batch.meta_info, uids=test_batch.non_tensor_batch["uid"])
             # for validation, we only need the last step
             is_last_step = test_output_gen_batch.non_tensor_batch["is_last_step"]
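For reference, the broadcast mode exercised throughout this diff computes advantages only on each trajectory's last step and then copies that value onto the trajectory's earlier steps before the two sub-batches are concatenated. A hedged sketch of that flow with plain numpy arrays standing in for DataProto fields; traj_id, the stand-in advantage values, and the dict-based broadcast are illustrative assumptions, not the trainer's internals:

import numpy as np

steps = {
    "traj_id":      np.array([0, 0, 0, 1, 1]),
    "is_last_step": np.array([False, False, True, False, True]),
    "advantage":    np.zeros(5),
}

last = np.where(steps["is_last_step"])[0]
other = np.where(~steps["is_last_step"])[0]

# Advantages are computed on last steps only (stand-in values here).
steps["advantage"][last] = np.array([0.7, -0.2])

# Broadcast each trajectory's last-step advantage back onto its earlier steps.
adv_by_traj = dict(zip(steps["traj_id"][last], steps["advantage"][last]))
steps["advantage"][other] = [adv_by_traj[t] for t in steps["traj_id"][other]]

print(steps["advantage"])  # [ 0.7  0.7  0.7 -0.2 -0.2]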