@szrlee szrlee commented Nov 10, 2025

Overview

This PR fixes bugs, refactors configuration for semantic clarity, and adds batch normalization support to the rollout correction implementation introduced in PR #3984.


Bug Fixes

1. Metrics Computation Running in Wrong Mode ⚠️

Problem: Rollout correction metrics were computed in bypass mode instead of decoupled mode, making them meaningless.

Root Cause: Incorrect condition at ray_trainer.py:1177-1180

# BEFORE (incorrect - runs in bypass mode)
if rollout_corr_config is not None and "rollout_log_probs" in batch.batch:
    batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch)
# AFTER (correct - runs in decoupled mode only)
if (rollout_corr_config is not None
    and "rollout_log_probs" in batch.batch
    and not bypass_recomputing_logprobs):  # Only in decoupled mode
    batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config)

Impact:

  • IS weights and rejection sampling metrics are now computed only when meaningful (decoupled mode with 3 policies)
  • In bypass mode (2 policies), the actor now computes metrics by comparing the evolving π_θ with π_rollout

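Related Changes:

  • Added clarifying comments in ray_trainer.py:1104-1107 (operating mode selection)
  • Added clarifying comments in ray_trainer.py:1175-1177 (metrics behavior)
  • Fixed actor metrics computation in dp_actor.py and megatron_actor.py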


Configuration Refactor (Semantic Clarity)

2. Variable Renaming

Renamed config variables to accurately reflect their semantics:

| Old Name | New Name | Rationale |
|----------|----------|-----------|
| `bypass_old_logprob_for_rollout` | `bypass_mode` | Directly describes the operating mode (2-policy vs 3-policy) |
| `use_pure_rollout_correction` | `use_policy_gradient` | Reflects actual choice: policy gradient loss vs Q-function loss |

Before (algorithm.py @ 0ef0e05b):

bypass_old_logprob_for_rollout: bool = False  # Unclear what "bypass" means
use_pure_rollout_correction: bool = False     # "Pure" is vague

After (algorithm.py @ HEAD):

bypass_mode: bool = False           # Clear: bypass or decoupled mode
use_policy_gradient: bool = False   # Clear: PG or Q-function loss

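Files Updated:

  • Core config: algorithm.py, rollout_correction.yaml
  • Implementation: ray_trainer.py, rollout_corr_helper.py, core_algos.py
  • Examples: run_with_rollout_corr.sh, run_dapo_qwen2.5_32b_rollout_corr.sh
  • Generated configs: _generated_ppo_trainer.yaml, _generated_ppo_megatron_trainer.yaml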


New Feature: Batch Normalization

3. IS Weight Batch Normalization

Added: rollout_is_batch_normalize config parameter (algorithm.py:159)

rollout_is_batch_normalize: bool = False

Purpose:

  • Normalizes importance sampling weights to have mean=1.0 within each batch
  • Aligns normalization scope with IS aggregation level (token/sequence/geometric)
  • Helps stabilize training when policy drift is large

Behavior:

  • True: IS weights normalized so mean=1.0 per batch (reduces variance)
  • False: Raw truncated IS weights used (standard behavior, default)

Documentation:

  • Mathematical formulation: rollout_corr_math.md §3.4
  • Usage guide: rollout_corr.md


Documentation Overhaul

4. File Reorganization

Moved documentation to docs/algo/:

  • docs/advance/rollout_corr.md → docs/algo/rollout_corr.md (+439 additions)
  • docs/advance/rollout_corr_math.md → docs/algo/rollout_corr_math.md (+459 additions)

Deleted redundant file:

  • examples/rollout_correction/README.md (-253 lines)

Updated references:

  • docs/index.rst: Updated paths
  • docs/advance/fully_async.md: Updated cross-references

5. Preset Renaming for Clarity

Renamed presets to clearly indicate operating mode:

| Old Name | New Name | Operating Mode | Description |
|----------|----------|----------------|-------------|
| `token_is` | `decoupled_token_is` | Decoupled (3-policy) | Token-level IS weighting |
| `seq_is` | `decoupled_seq_is` | Decoupled (3-policy) | Sequence-level IS weighting |
| `geo_rs` | `decoupled_geo_rs` | Decoupled (3-policy) | Geometric rejection sampling |
| `ppo_is_bypass` | `ppo_is_bypass` | Bypass (2-policy) | PPO with IS (unchanged) |
| `pure_is` | `pg_is` | Bypass (2-policy) | Policy gradient + sequence IS |
| N/A | `pg_rs` | Bypass (2-policy) | Policy gradient + geometric RS (new) |

Naming Convention:

  • Decoupled mode presets: decoupled_* (requires old_log_prob computation)
  • Bypass mode presets: pg_* or ppo_* (skips old_log_prob computation)

6. Content Improvements

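Cross-References:

  • Added prominent links between rollout_corr.md (usage guide) and rollout_corr_math.md (mathematical foundations)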

Clarified Loss Formulations:

  • Changed examples from PPO to REINFORCE in rollout_corr_math.md §3.3
  • Rationale: Separates IS weight mechanics from PPO clipping for clarity
  • Added note that REINFORCE examples can be combined with PPO clipping

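New Sections:

  • Dedicated batch normalization section: rollout_corr_math.md §3.4
  • Improved operating mode explanations throughout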


Code Quality Improvements

7. Enhanced Comments and Documentation

Trainer Logic (ray_trainer.py):

  • Lines 1104-1107: Operating mode selection logic
  • Lines 1175-1177: Metrics computation behavior explanation

Policy Loss (core_algos.py):

  • Enhanced docstrings for compute_policy_loss_with_rollout_correction
  • Clarified when to use policy gradient vs Q-function loss

Actor Workers (dp_actor.py, megatron_actor.py):

  • Added comments explaining bypass mode metrics computation

8. Code Simplification

Removed Unused Logic (rollout_corr_helper.py):

  • Removed unnecessary config parameters from metrics computation
  • Removed unused IS weight processing logic
  • Simplified metrics calculation flow

Improved Variable Reuse:

  • Reused need_recomputation variable instead of redundant bypass mode checks
  • Reduced code duplication

Commit History

18 commits:
  1. 7c9e41da - fix(rollout_corr): compute metrics in actor for bypass mode and fix trainer bugs
  2. 96ae2be1 - docs(rollout_corr): move to algo/ and add pure_rs preset
  3. c0ea9bdc - feat(rollout_corr): add batch normalization option for IS weights
  4. 7de6c5f9 - docs(rollout_corr_math): use REINFORCE in aggregation loss examples for clarity
  5. 2b34cfee - refactor(rollout_corr): simplify metrics computation by removing unused config and IS weight logic
  6. 0c42f85a - docs(rollout_corr): add prominent cross-references between usage and math docs
  7. fef8a48f - docs(rollout_corr_math): add dedicated section for batch normalization
  8. 08cc9c7d - fix: docstring of compute_policy_loss_with_rollout_correction
  9. 437a4aba - feat: reuse need_recomputation instead of bypass_mode
  10. 5f9a53bf - feat: improve comments
  11. b2f63709 - feat: improve comments
  12. 79cdbf2f - feat: refactor bypass_recomputing_logprobs
  13. 62e32701 - feat(rollout_corr): align batch normalization with IS aggregation level
  14. b5c19ff7 - docs(rollout_corr): rename decoupled mode presets for clarity and update examples
  15. 11f9aa05 - fix(rollout_corr): correct metrics computation to run in decoupled mode only
  16. 58565cb0 - docs(rollout_corr): rename presets for clarity and consistency
  17. 8bb1a0e0 - refactor(rollout_corr): rename config vars for semantic clarity
  18. 6002c00c - refactor(rollout_corr): update implementation to use renamed config variables

Summary

This PR systematically improves the rollout correction implementation through three key areas:

  1. Bug Fixes: Corrected metrics computation to run in the appropriate mode
  2. Semantic Clarity: Renamed variables to accurately reflect their purpose (bypass_mode, use_policy_gradient)
  3. Feature Addition: Added batch normalization option for IS weights with comprehensive documentation

The config variable renames are a breaking API change for config-based users (see the rename commit below); the other changes are aimed at code clarity, correctness, and maintainability.

…rainer bugs

Fix three critical issues in rollout correction metrics computation:

1. Missing rollout_corr_config parameter in ray_trainer.py line 1178
   compute_rollout_correction_and_add_to_batch() call

2. Trainer computes meaningless metrics in bypass mode since old=rollout
   Results in KL≈0, weights≈1.0 that don't reflect actual drift

3. No metrics computed for bypass+non-pure mode during actor training
   Bypass+pure already computes metrics in pure loss function

Solution:
- Add compute_rollout_corr_metrics_from_logprobs() helper function to compute
  metrics using current policy vs rollout policy log probabilities
- Always pass rollout_correction config to actor in bypass mode for metrics
- Skip trainer metrics in bypass mode, compute meaningful metrics in actor
- Actor computes per-microbatch metrics showing drift as training progresses

Behavior by mode:
- Bypass+non-pure: Actor computes metrics (π_current vs π_rollout)
- Bypass+pure: Pure loss function computes metrics internally
- Decoupled: Trainer computes metrics (π_old vs π_rollout)
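
The behavior by mode listed above can be read as a small lookup. The sketch below is illustrative only: `metrics_source` is a hypothetical helper that does not exist in verl, and its two flags simply mirror the "bypass" and "pure" terms used in this commit message.

```python
def metrics_source(bypass: bool, pure: bool) -> str:
    """Return which component computes rollout correction metrics.

    Hypothetical helper for illustration; not part of verl.
    """
    if not bypass:
        # Decoupled mode: pi_old and pi_rollout are distinct policies,
        # so the trainer compares them (the `pure` flag is ignored here).
        return "trainer"
    if pure:
        # Bypass + pure: the pure loss function computes metrics internally.
        return "loss_function"
    # Bypass + non-pure: the actor computes per-microbatch metrics
    # from the current policy vs the rollout policy.
    return "actor"


for bypass, pure in [(False, False), (True, False), (True, True)]:
    print(f"bypass={bypass}, pure={pure} -> {metrics_source(bypass, pure)}")
```

Keeping the decision in one place makes it harder for the trainer and the actor to both compute (or both skip) the metrics for the same mode.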

Files changed:
- verl/trainer/ppo/rollout_corr_helper.py: Add metrics helper, always pass config
- verl/trainer/ppo/ray_trainer.py: Fix missing param, skip bypass metrics
- verl/workers/actor/dp_actor.py: Add rollout_log_probs selection, compute metrics
- verl/workers/actor/megatron_actor.py: Add rollout_log_probs selection, compute metrics
- verl/trainer/ppo/core_algos.py: Remove outdated documentation

Move rollout correction documentation from docs/advance/ to docs/algo/ to better
reflect its algorithmic nature. Add new pure_rs() preset for pure rejection sampling
in bypass mode (geometric RS without IS weights).

Changes:
1. Move docs/advance/rollout_corr.md → docs/algo/rollout_corr.md
2. Move docs/advance/rollout_corr_math.md → docs/algo/rollout_corr_math.md
3. Update docs/index.rst to reflect new locations
4. Update cross-reference in docs/advance/fully_async.md
5. Add RolloutCorrectionConfig.pure_rs() preset method
   - Pure rejection sampling (no IS weights)
   - Geometric aggregation level
   - Bypass mode (skips old_log_prob)
   - use_pure_rollout_correction=True
6. Update preset tables in both documents (7 total presets)

Add rollout_is_batch_normalize parameter to RolloutCorrectionConfig to enable
batch normalization of importance sampling weights.

When enabled, IS weights are normalized to have mean=1.0 within each batch,
which reduces variance by ensuring the average weight is always 1.0 per batch.

Changes:
1. Add rollout_is_batch_normalize field to RolloutCorrectionConfig (default: False)
2. Update compute_rollout_correction_weights() to support batch normalization
   - Normalize weights by dividing by masked mean after truncation
   - Add rollout_is_batch_norm_factor metric to track normalization factor
3. Thread parameter through all call sites:
   - compute_rollout_correction_and_rejection_mask()
   - compute_rollout_correction_and_add_to_batch()
   - compute_rollout_corr_metrics_from_logprobs()
   - compute_policy_loss_with_rollout_correction() (bypass-pure mode)
   - compute_policy_loss_rollout_correction_wrapper() (bypass-pure mode)
4. Update documentation in rollout_corr.md and rollout_corr_math.md

The batch normalization is applied AFTER truncation to preserve the
truncation semantics while ensuring mean=1.0 for variance reduction.
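
The truncate-then-normalize order described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not verl's implementation: the function name, the flat-list inputs, and the threshold value of 2.0 are all hypothetical, and real code would operate on tensors.

```python
import math


def truncate_then_normalize(log_ratios, mask, threshold=2.0):
    """Truncate IS weights at `threshold`, then rescale so the masked mean is 1.0.

    log_ratios: per-token log importance ratios (flat list).
    mask: 1 for tokens that count, 0 for padding.
    Returns (weights, norm_factor). Hypothetical sketch, not verl's code.
    """
    # Step 1: truncation (upper clamp) on the raw importance ratios.
    weights = [min(math.exp(lr), threshold) for lr in log_ratios]
    valid = [w for w, m in zip(weights, mask) if m]
    if not valid:
        return weights, 1.0
    # Step 2: masked mean of the already-truncated weights.
    norm_factor = sum(valid) / len(valid)
    return [w / norm_factor for w in weights], norm_factor


weights, factor = truncate_then_normalize(
    [0.0, math.log(4.0), math.log(0.5), 0.0], mask=[1, 1, 1, 0]
)
masked = [w for w, m in zip(weights, [1, 1, 1, 0]) if m]
print(factor, sum(masked) / len(masked))
```

The returned factor plays the role of the `rollout_is_batch_norm_factor` metric mentioned above: a value far from 1.0 signals that the truncated weights were, on average, far from 1.0 before normalization.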
…or clarity

Change loss function examples in Section 3.3 from PPO formulation to REINFORCE
for better clarity, and note that PPO clipping can be added.

Changes:
- Token-level (§3.3.1): Use REINFORCE+TIS instead of PPO+TIS
- Sequence-level (§3.3.2): Use REINFORCE+SeqIS instead of PPO+SeqIS
- Geometric-level (§3.3.3): Use REINFORCE+GeoRS instead of PPO+GeoRS
- Add note that all can be combined with PPO clipping
- Simplifies understanding by separating IS mechanics from PPO clipping

This makes it clearer that IS weighting is orthogonal to the choice of
policy gradient algorithm (REINFORCE vs PPO).
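
For orientation, the three REINFORCE variants named above can be written out in generic notation. This is a sketch under assumptions and may not match the notation of rollout_corr_math.md: \pi_{\text{prox}} is a placeholder for whichever policy supplies the IS numerator, C, C_low, and C_high are assumed threshold names, the expectation is over responses sampled from the rollout policy, and per-token averaging is omitted.

```latex
% Per-token importance ratio (treated as a constant, no gradient).
\rho_t = \frac{\pi_{\text{prox}}(y_t \mid x, y_{<t})}{\pi_{\text{rollout}}(y_t \mid x, y_{<t})}

% Token-level (REINFORCE + TIS): truncate each token's ratio at C.
L_{\text{token}}(\theta) = -\,\mathbb{E}\Big[\sum_{t=1}^{T} \min(\rho_t, C)\, A_t \,\log \pi_\theta(y_t \mid x, y_{<t})\Big]

% Sequence-level (REINFORCE + SeqIS): one truncated weight per sequence.
L_{\text{seq}}(\theta) = -\,\mathbb{E}\Big[\min\Big(\prod_{t=1}^{T} \rho_t,\; C\Big) \sum_{t=1}^{T} A_t \,\log \pi_\theta(y_t \mid x, y_{<t})\Big]

% Geometric-level (REINFORCE + GeoRS): keep a sequence only if its
% geometric-mean ratio lies inside [C_low, C_high].
L_{\text{geo}}(\theta) = -\,\mathbb{E}\Big[\mathbb{1}\Big[C_{\text{low}} \le \Big(\prod_{t=1}^{T} \rho_t\Big)^{1/T} \le C_{\text{high}}\Big] \sum_{t=1}^{T} A_t \,\log \pi_\theta(y_t \mid x, y_{<t})\Big]
```

In each case the IS weight or rejection mask multiplies a plain REINFORCE term, which is why replacing that term with a PPO clipped objective leaves the weighting unchanged.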
@szrlee szrlee changed the title [docs] Rollout Correction: Fix Metrics, Add Documentation, and Add Batch Normalization [doc,feat] Rollout Correction: Fix Metrics, Add Documentation, and Add Batch Normalization Nov 10, 2025
@szrlee szrlee changed the title [doc,feat] Rollout Correction: Fix Metrics, Add Documentation, and Add Batch Normalization [doc,algo] Rollout Correction: Fix Metrics, Add Documentation, and Add Batch Normalization Nov 10, 2025
@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request significantly enhances the rollout correction feature by addressing several key areas. It fixes bugs in metrics computation for bypass mode, ensuring accurate reflection of policy drift during training. The documentation has been reorganized for better clarity and includes a new pure rejection sampling preset. A valuable addition is the optional batch normalization for importance sampling weights, which helps reduce variance. Finally, the loss function documentation has been improved by using REINFORCE examples, clarifying the orthogonality of IS weighting to the choice of policy gradient algorithm. The changes are well-implemented and contribute positively to the robustness and clarity of the rollout correction framework.

@szrlee szrlee changed the title [doc,algo] Rollout Correction: Fix Metrics, Add Documentation, and Add Batch Normalization [doc,algo] feat: Rollout Correction - Fix Metrics, Add Documentation, and Add Batch Normalization Nov 10, 2025
@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces several valuable improvements to the rollout correction feature. The addition of batch normalization for importance sampling weights is a great enhancement for reducing variance. The bug fix for metrics computation in bypass mode is a crucial correction that improves monitoring capabilities. Furthermore, the documentation has been significantly reorganized and expanded, which greatly enhances clarity and usability for developers. I've identified one high-severity issue in a new helper function related to inconsistent default values and a potential crash, for which I've provided a detailed comment and a suggested fix. Overall, this is a high-quality contribution.

@tongyx361 tongyx361 self-assigned this Nov 10, 2025
szrlee and others added 8 commits November 10, 2025 20:41
…ed config and IS weight logic

Simplify compute_rollout_corr_metrics_from_logprobs() to only compute diagnostic
metrics (KL, PPL, χ²), removing unnecessary IS weight statistics and config handling.

Issues fixed:
- Dead code: unreachable else block for None config (call sites always check first)
- Inconsistent defaults: hardcoded "token" contradicts dataclass default "sequence"
- Unused parameter: rollout_corr_config passed but only used for IS weight computation
- Redundant check: rollout_log_prob presence already implies feature enabled

Changes:
- Remove rollout_corr_config parameter and all config normalization logic
- Remove IS weight computation (not needed for actor monitoring)
- Simplify call sites to only check rollout_log_prob availability
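
A config-free diagnostic of this kind needs only the two sets of log-probabilities. The sketch below shows one way to compute KL, PPL, and χ² estimates; the function name and the exact metric definitions are assumptions for illustration and may differ from what `compute_rollout_corr_metrics_from_logprobs()` actually returns.

```python
import math


def drift_metrics(current_log_probs, rollout_log_probs):
    """Estimate drift between the current and rollout policies.

    Both inputs are per-token log-probabilities of the same tokens, which
    were sampled from the rollout policy. The definitions below are
    illustrative assumptions, not verl's exact formulas.
    """
    n = len(rollout_log_probs)
    log_ratios = [c - r for c, r in zip(current_log_probs, rollout_log_probs)]
    # Monte Carlo estimate of KL(pi_rollout || pi_current) from rollout samples.
    kl = -sum(log_ratios) / n
    # Perplexity of each policy on the sampled tokens.
    ppl_current = math.exp(-sum(current_log_probs) / n)
    ppl_rollout = math.exp(-sum(rollout_log_probs) / n)
    # Chi-squared estimate: E_rollout[(pi_current / pi_rollout)^2] - 1.
    chi2 = sum(math.exp(2.0 * lr) for lr in log_ratios) / n - 1.0
    return {
        "kl": kl,
        "ppl_current": ppl_current,
        "ppl_rollout": ppl_rollout,
        "chi2": chi2,
    }


print(drift_metrics([-1.2, -0.7, -2.0], [-1.0, -0.8, -1.5]))
```

When both inputs are identical, the KL and χ² estimates are zero and the two perplexities coincide, so non-zero values only appear once the current policy has moved away from the rollout policy.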
…math docs

Add visible callout boxes at the top of both documentation files to help users
navigate between practical usage guide and mathematical theory.

Changes:
- Add 📖 Documentation Structure callout at beginning of both docs
- rollout_corr.md: Links to math doc for theoretical foundations
- rollout_corr_math.md: Links to usage guide for practical implementation
- Includes guidance on where to start based on user needs

This improves discoverability and helps users understand the relationship
between the two complementary documents.

Move batch normalization from a bullet point to its own section (§3.6) with
comprehensive explanation of theory, usage, and trade-offs.

Changes:
- Add new §3.6 Batch Normalization with full mathematical analysis
- Explain why it reduces variance (eliminates batch-level mean fluctuations)
- Document when to use and trade-offs
- Include normalization formula and gradient scaling analysis
- Renumber §3.6 → §3.7 (Combination Matrix)
- Renumber §3.7 → §3.8 (Common Implementation Mistake)
- Update §3.3.2 to reference new section instead of inline mention

This gives batch normalization the prominence it deserves as an important
variance reduction technique.

Make IS weight batch normalization aggregation-aware to prevent length bias
in sequence-level IS:

- Token-level: normalize over all token weights (unchanged)
- Sequence-level: normalize over sequence means (one per sequence)

Previously, sequence-level incorrectly normalized over all token positions,
giving longer sequences more weight. Now each sequence contributes equally.

Also update docs in rollout_corr.md and rollout_corr_math.md §3.6.
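
The length bias described above is easy to see on a toy batch. The following sketch is hypothetical (the function name and the nested-list input are assumptions, and it is not verl's tensor-based implementation); it contrasts averaging over all tokens with averaging over per-sequence means.

```python
def normalize_is_weights(seq_weights, level):
    """Rescale IS weights so their mean is 1.0 at the given aggregation level.

    seq_weights: list of sequences, each a list of per-token weights
                 (padding already removed).
    level: "token" averages over every token; "sequence" averages over
           per-sequence means, so each sequence counts once.
    Hypothetical sketch, not verl's implementation.
    """
    if level == "token":
        flat = [w for seq in seq_weights for w in seq]
        factor = sum(flat) / len(flat)
    elif level == "sequence":
        seq_means = [sum(seq) / len(seq) for seq in seq_weights]
        factor = sum(seq_means) / len(seq_means)
    else:
        raise ValueError(f"unknown level: {level}")
    return [[w / factor for w in seq] for seq in seq_weights], factor


# A long sequence with weight 2.0 and a short one with weight 0.5.
batch = [[2.0, 2.0, 2.0, 2.0], [0.5]]
_, token_factor = normalize_is_weights(batch, "token")
_, seq_factor = normalize_is_weights(batch, "sequence")
print(token_factor, seq_factor)
```

With token-level averaging the four-token sequence dominates the normalization factor (1.7), whereas sequence-level averaging treats both sequences equally (1.25), which is the behavior this commit switches to for sequence-level IS.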
…ate examples

Rename decoupled mode presets to include `decoupled_` prefix:
- token_is() → decoupled_token_is()
- seq_is() → decoupled_seq_is()
- seq_is_rs() → decoupled_seq_is_rs()
- geo_rs() → decoupled_geo_rs()

Bypass mode presets unchanged (already clear):
- ppo_is_bypass(), pure_is(), pure_rs()

This makes operating mode immediately apparent from method name.

Changes:
- docs: Update preset tables, quick start examples, and detailed sections
- docs: Add missing pure_rs to method summary table
- docs: Clarify pure_rs uses pure policy gradient (not PPO)
- examples: Update RLOO example to use pure_is mode with RLOO advantage
- examples: Update DAPO recipe with sequence-level IS and geometric RS
- examples: Remove redundant README (content in main docs)

…de only

Fixed critical bug where rollout correction metrics were computed in bypass
mode instead of decoupled mode. Changed condition from `bypass_recomputing_logprobs`
to `not bypass_recomputing_logprobs` at ray_trainer.py:1181.

Also clarified comments:
- Operating mode selection (ray_trainer.py:1104-1107)
- Metrics computation behavior (ray_trainer.py:1175-1177)
- Actor metrics comments (dp_actor.py, megatron_actor.py)

szrlee and others added 4 commits November 11, 2025 23:43
Standardize preset naming to clearly indicate mode and algorithm:

Decoupled mode (3 policies):
  token_is → decoupled_token_is
  seq_is → decoupled_seq_is
  seq_is_rs → decoupled_seq_is_rs
  geo_rs → decoupled_geo_rs

Policy gradient (bypass mode):
  pure_is → pg_is
  pure_rs → pg_rs

Changes:
- Remove aliases (token_tis, seq_mis, geo_mis)
- Add disabled() to tables
- Update docstrings and examples
- Improve YAML comments
- Add missing rollout_is_batch_normalize parameter

All 8 presets now consistent across code and docs.
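
Because the aliases were removed, scripts that still use an old preset name need to be updated. The renames listed in this commit message can be captured as a lookup table; `PRESET_RENAMES` and `new_preset_name` below are hypothetical names for illustration and are not part of verl.

```python
# Old preset name -> new preset name, as listed in this commit message.
PRESET_RENAMES = {
    "token_is": "decoupled_token_is",
    "seq_is": "decoupled_seq_is",
    "seq_is_rs": "decoupled_seq_is_rs",
    "geo_rs": "decoupled_geo_rs",
    "pure_is": "pg_is",
    "pure_rs": "pg_rs",
}


def new_preset_name(name: str) -> str:
    """Translate a pre-rename preset name; other names pass through unchanged."""
    return PRESET_RENAMES.get(name, name)


print(new_preset_name("seq_is"))         # renamed preset
print(new_preset_name("ppo_is_bypass"))  # unchanged preset
```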
Rename variables to match documentation terminology exactly:
  • bypass_old_logprob_for_rollout → bypass_mode
  • use_pure_rollout_correction → use_policy_gradient

Rationale:
- 'bypass_mode' matches docs ('Bypass mode' vs 'Decoupled mode')
- 'use_policy_gradient' matches docs ('Policy Gradient loss' vs 'PPO loss')
- Previous names were ambiguous and non-standard
- New names are self-documenting and semantically precise

Changes:
- Update all config files (algorithm.py, YAML, both docs)
- Update all preset method implementations
- Update all YAML examples in documentation
- Update configuration tables

This is a BREAKING API change for config-based users.
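
For config-based users affected by the rename, a small migration step can translate old keys before the config is loaded. This is a minimal sketch that assumes the rollout correction settings are available as a plain dict; `migrate_rollout_corr_config` is a hypothetical helper, not something verl provides.

```python
# Old config key -> new config key, per the rename described above.
KEY_RENAMES = {
    "bypass_old_logprob_for_rollout": "bypass_mode",
    "use_pure_rollout_correction": "use_policy_gradient",
}


def migrate_rollout_corr_config(cfg: dict) -> dict:
    """Return a copy of `cfg` with pre-rename keys replaced by the new names.

    Hypothetical helper. Raises if both the old and the new key are set,
    so conflicting values are never silently merged.
    """
    migrated = dict(cfg)
    for old, new in KEY_RENAMES.items():
        if old in migrated:
            if new in migrated:
                raise ValueError(f"both {old!r} and {new!r} are set")
            migrated[new] = migrated.pop(old)
    return migrated


old_cfg = {"bypass_old_logprob_for_rollout": True, "rollout_is_batch_normalize": False}
print(migrate_rollout_corr_config(old_cfg))
```

Failing loudly on a conflict is a deliberate choice here: a config that sets both spellings is ambiguous, and picking one silently could change the operating mode.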
…ariables

Complete variable renaming initiated in commit 8bb1a0e:
  • bypass_old_logprob_for_rollout → bypass_mode
  • use_pure_rollout_correction → use_policy_gradient

Changes:
- Update implementation files (rollout_corr_helper, ray_trainer, core_algos)
- Update example script (run_with_rollout_corr.sh)
- Improve docstring classifications ("Bypass + PPO/PG loss")
- Clarify documentation terminology (IS/RS independence)
- Update mode names ("Bypass + Policy Gradient mode")

All references to old variable names removed from codebase.
):
from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch

# Compute IS weights, apply rejection sampling, compute metrics

This comment is not about this PR but about the design of verl's ray_trainer. The fit function is now huge, so a recipe that needs to customize it has to copy all of these lines, and the copy soon becomes outdated. Syncing changes from the main fit function to individual recipes costs too much effort.

@wuxibin89 I think we should split the fit function into several modules, with rollout correction as one of them, especially for the newly refactored engine version.

@ISEEKYAN ISEEKYAN merged commit 2c6c65c into volcengine:main Nov 12, 2025
78 of 81 checks passed
chenjiaoAngel added a commit to chenjiaoAngel/verl that referenced this pull request Nov 14, 2025
… and Add Batch Normalization (volcengine#4070)

Co-authored-by: Shawn/Yuxuan Tong <[email protected]>