
Commit 1b393de

fix ppo trainer, remove gpu dependencies
1 parent b7381ab · commit 1b393de

6 files changed: +197, -117 lines changed


Dockerfile

Lines changed: 2 additions & 6 deletions
@@ -4,13 +4,9 @@ ENV DEBIAN_FRONTEND=noninteractive

 WORKDIR /workspace

-RUN pip uninstall verl -y || true
+RUN git clone https://github.com/rllm-org/rllm.git rllm

-RUN git clone --recurse-submodules https://github.com/rllm-org/rllm.git rllm
-
-RUN cd rllm && \
-    pip install -e ./verl && \
-    pip install -e .
+RUN cd rllm && pip install -e .

 RUN pip install playwright && \
     playwright install chromium && \

README.md

Lines changed: 0 additions & 2 deletions
@@ -56,8 +56,6 @@ bash scripts/install_verl.sh # (or follow the instructions at https://verl.readt

 # Install rllm
 pip install -e .
-
-**Note:** On macOS, GPU features (flash-attn, deepspeed, vllm) are automatically excluded for compatibility. For GPU support on macOS, you can install with: `pip install -e .[gpu]`
 ```

 ### Installation with Docker 🐳

examples/solver_judge/train_solver_judge_flow.py

Lines changed: 1 addition & 1 deletion
@@ -1,8 +1,8 @@
 import hydra

-from examples.countdown.countdown_reward import countdown_reward_fn
 from examples.solver_judge.solver_judge_flow import SolverJudgeWorkflow
 from rllm.data.dataset import DatasetRegistry
+from rllm.rewards.countdown_reward import countdown_reward_fn
 from rllm.trainer.agent_trainer import AgentTrainer



pyproject.toml

Lines changed: 0 additions & 5 deletions
@@ -89,11 +89,6 @@ dependencies = [
 ]

 [project.optional-dependencies]
-gpu = [
-    "flash-attn>=2.7.4.post1; sys_platform != 'darwin'",
-    "vllm>=0.8.3; sys_platform != 'darwin'",
-    "sglang>=0.4.6.post1; sys_platform != 'darwin'",
-]

 smolagents = [
     "smolagents==1.20.0",

rllm/trainer/agent_trainer.py

Lines changed: 6 additions & 3 deletions
@@ -1,9 +1,10 @@
 from typing import Any

 import ray
+from verl.trainer.constants_ppo import get_ppo_ray_runtime_env

 from rllm.data import Dataset
-from rllm.trainer.verl.train_agent_ppo import train_agent
+from rllm.trainer.verl.train_agent_ppo import TaskRunner


 class AgentTrainer:
@@ -55,10 +56,12 @@ def __init__(

     def train(self):
         if not ray.is_initialized():
-            ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
+            ray.init(runtime_env=get_ppo_ray_runtime_env(), num_cpus=self.config.ray_init.num_cpus)
+
+        runner = TaskRunner.remote()

         ray.get(
-            train_agent.remote(
+            runner.run.remote(
                 config=self.config,
                 workflow_class=self.workflow_class,
                 workflow_args=self.workflow_args,
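
With this change, `AgentTrainer.train()` creates a single `TaskRunner` Ray actor and blocks on its `run()` method instead of calling a remote function. Below is a minimal, runnable sketch of that call pattern; the `TaskRunner` body, config dict, and workflow name are toy stand-ins for illustration, not the real rllm/verl implementations.

```python
# Toy illustration of the actor-based call pattern used by the new train() method.
# This TaskRunner is a stand-in, not rllm's; the config and workflow name are fake.
import ray


@ray.remote(num_cpus=1)
class TaskRunner:
    """Stand-in remote task runner: one actor, one long-running run() call."""

    def run(self, config, workflow_class=None, workflow_args=None):
        # The real run() builds workers, datasets, and reward functions, then
        # drives trainer.fit_agent(); this stand-in only echoes its arguments.
        return f"would train {workflow_class} with config={config}"


if __name__ == "__main__":
    if not ray.is_initialized():
        # Stand-in for ray.init(runtime_env=get_ppo_ray_runtime_env(), num_cpus=...)
        ray.init(num_cpus=2)

    runner = TaskRunner.remote()  # create the actor once
    # Block until the remote run() completes, exactly as the new train() does.
    result = ray.get(runner.run.remote(config={"trainer": {"nnodes": 1}}, workflow_class="SolverJudgeWorkflow"))
    print(result)
    ray.shutdown()
```

Creating one long-lived actor and awaiting a single method keeps all trainer state inside one worker process, which is the shape of verl's TaskRunner pattern that this commit adopts.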

rllm/trainer/verl/train_agent_ppo.py

Lines changed: 188 additions & 100 deletions
@@ -2,9 +2,12 @@
 """
 Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
 """
+import os
+import socket

 import hydra
 import ray
+from omegaconf import OmegaConf
 from verl.trainer.ppo.reward import load_reward_manager

 from rllm.trainer.env_agent_mappings import AGENT_CLASS_MAPPING, ENV_CLASS_MAPPING, WORKFLOW_CLASS_MAPPING
@@ -20,121 +23,206 @@ def main(config):


 def run_ppo_agent(config):
+    # Check if Ray is not initialized
     if not ray.is_initialized():
-        # this is for local ray cluster
-        ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
+        # Initialize Ray with a local cluster configuration
+        # Set environment variables in the runtime environment to control tokenizer parallelism,
+        # NCCL debug level, VLLM logging level, and allow runtime LoRA updating
+        # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration
+        ray.init(
+            runtime_env=get_ppo_ray_runtime_env(),
+            num_cpus=config.ray_init.num_cpus,
+        )
+
+    # Create a remote instance of the TaskRunner class, and
+    # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete
+    if (
+        is_cuda_available
+        and config.trainer.get("profile_steps") is not None
+        and len(config.trainer.get("profile_steps", [])) > 0
+    ):
+        nsight_options = OmegaConf.to_container(config.trainer.controller_nsight_options)
+        runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote()
+    else:
+        runner = TaskRunner.remote()
+    ray.get(runner.run.remote(config))

-    ray.get(train_agent.remote(config))
+    # [Optional] get the path of the timeline trace file from the configuration, default to None
+    # This file is used for performance analysis
+    timeline_json_file = config.ray_init.get("timeline_json_file", None)
+    if timeline_json_file:
+        ray.timeline(filename=timeline_json_file)


 @ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head
-def train_agent(config, workflow_class=None, workflow_args=None, agent_class=None, env_class=None, agent_args=None, env_args=None):
-    # print initial config
-    from pprint import pprint
+class TaskRunner:
+    """Ray remote class for executing distributed PPO training tasks.

-    from omegaconf import OmegaConf
-    from verl.utils.fs import copy_local_path_from_hdfs
+    This class encapsulates the main training logic and runs as a Ray remote actor
+    to enable distributed execution across multiple nodes and GPUs.
+    """

-    OmegaConf.register_new_resolver("mul", lambda x, y: int(x) * int(y))
-    OmegaConf.resolve(config)
-    pprint(OmegaConf.to_container(config))
+    def run(self, config, workflow_class=None, workflow_args=None, agent_class=None, env_class=None, agent_args=None, env_args=None):
+        """Execute the main PPO training workflow.

-    # download the checkpoint from hdfs
-    local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)
+        This method sets up the distributed training environment, initializes
+        workers, datasets, and reward functions, then starts the training process.

-    # instantiate tokenizer
-    from verl.utils import hf_tokenizer
+        Args:
+            config: Training configuration object containing all parameters needed
+                for setting up and running the PPO training process.
+        """
+        # Print the initial configuration. `resolve=True` will evaluate symbolic values.
+        from pprint import pprint

-    trust_remote_code = config.data.get("trust_remote_code", False)
-    tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
-    # processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none
+        from omegaconf import OmegaConf

-    if config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]:
-        assert config.critic.strategy in ["fsdp", "fsdp2"]
-        from verl.single_controller.ray import RayWorkerGroup
-        from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker
+        from verl.utils.fs import copy_to_local

-        actor_rollout_cls = AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker
-        ray_worker_group_cls = RayWorkerGroup
-    else:
-        raise NotImplementedError
-
-    from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
-
-    role_worker_mapping = {
-        Role.ActorRollout: ray.remote(max_concurrency=2048)(actor_rollout_cls),
-        Role.Critic: ray.remote(CriticWorker),
-    }
-
-    global_pool_id = "global_pool"
-    resource_pool_spec = {
-        global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
-    }
-    mapping = {
-        Role.ActorRollout: global_pool_id,
-        Role.Critic: global_pool_id,
-    }
-
-    if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
-        role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
-        mapping[Role.RefPolicy] = global_pool_id
-
-    reward_fn = load_reward_manager(config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}))
-    val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1)
-    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
-
-    if config.rllm.workflow.use_workflow:
-        if workflow_class is None:
-            workflow_class = WORKFLOW_CLASS_MAPPING[config.rllm.workflow.name]
-        workflow_args = workflow_args or {}
-        if config.rllm.workflow.get("workflow_args") is not None:
-            workflow_args.update(config.rllm.workflow.get("workflow_args"))
-
-        trainer = AgentWorkflowPPOTrainer(
-            config=config,
-            tokenizer=tokenizer,
-            role_worker_mapping=role_worker_mapping,
-            resource_pool_manager=resource_pool_manager,
-            ray_worker_group_cls=ray_worker_group_cls,
-            reward_fn=reward_fn,
-            val_reward_fn=val_reward_fn,
-            workflow_class=workflow_class,
-            workflow_args=workflow_args,
-        )
+        print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")
+        OmegaConf.register_new_resolver("mul", lambda x, y: int(x) * int(y))
+        OmegaConf.resolve(config)
+        pprint(OmegaConf.to_container(config))

-    else:
-        if env_class is None:
-            env_class = ENV_CLASS_MAPPING[config.rllm.env.name]
-        if agent_class is None:
-            agent_class = AGENT_CLASS_MAPPING[config.rllm.agent.name]
-
-        env_args = env_args or {}
-        agent_args = agent_args or {}
-        if config.rllm.env.get("env_args") is not None:
-            env_args.update(config.rllm.env.get("env_args"))
-        if config.rllm.agent.get("agent_args") is not None:
-            agent_args.update(config.rllm.agent.get("agent_args"))
-
-        trainer = AgentPPOTrainer(
-            config=config,
-            tokenizer=tokenizer,
-            role_worker_mapping=role_worker_mapping,
-            resource_pool_manager=resource_pool_manager,
-            ray_worker_group_cls=ray_worker_group_cls,
-            reward_fn=reward_fn,
-            val_reward_fn=val_reward_fn,
-            env_class=env_class,
-            agent_class=agent_class,
-            env_args=env_args,
-            agent_args=agent_args,
+        # Download the checkpoint from HDFS to the local machine.
+        # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on
+        local_path = copy_to_local(
+            config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False)
        )

-    trainer.init_workers()
-    try:
-        trainer.fit_agent()
-    finally:
-        trainer.shutdown()
-
+        # Instantiate the tokenizer and processor.
+        from verl.utils import hf_tokenizer
+
+        trust_remote_code = config.data.get("trust_remote_code", False)
+        tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
+        # Used for multimodal LLM, could be None
+        # processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
+
+        # Define worker classes based on the actor strategy.
+        if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
+            assert config.critic.strategy in {"fsdp", "fsdp2"}
+            from verl.single_controller.ray import RayWorkerGroup
+            from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
+
+            use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
+            if use_legacy_worker_impl in ["auto", "enable"]:
+                # import warnings
+                # warnings.warn(f"Legacy worker impl is going to be deprecated, will be removed in the future. \
+                #     Please set trainer.use_legacy_worker_impl = false to switch to the new worker implementation.")
+                from verl.workers.fsdp_workers import CriticWorker
+            elif use_legacy_worker_impl == "disable":
+                from verl.workers.roles import CriticWorker
+
+                print("Using new worker implementation")
+            else:
+                raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")
+
+            actor_rollout_cls = (
+                AsyncActorRolloutRefWorker
+                if config.actor_rollout_ref.rollout.mode == "async"
+                else ActorRolloutRefWorker
+            )
+            ray_worker_group_cls = RayWorkerGroup
+
+        elif config.actor_rollout_ref.actor.strategy == "megatron":
+            assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
+            from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
+            from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker
+
+            actor_rollout_cls = (
+                AsyncActorRolloutRefWorker
+                if config.actor_rollout_ref.rollout.mode == "async"
+                else ActorRolloutRefWorker
+            )
+            ray_worker_group_cls = NVMegatronRayWorkerGroup
+
+        else:
+            raise NotImplementedError
+
+        from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
+
+        # Map roles to their corresponding remote worker classes.
+        role_worker_mapping = {
+            Role.ActorRollout: ray.remote(actor_rollout_cls),
+            Role.Critic: ray.remote(CriticWorker),
+        }
+
+        # Define the resource pool specification.
+        # Map roles to the resource pool.
+        global_pool_id = "global_pool"
+        resource_pool_spec = {
+            global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
+        }
+        mapping = {
+            Role.ActorRollout: global_pool_id,
+            Role.Critic: global_pool_id,
+        }
+
+        # Add a reference policy worker if KL loss or KL reward is used.
+        if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
+            role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
+            mapping[Role.RefPolicy] = global_pool_id
+
+        # Load the reward manager for training and validation.
+        reward_fn = load_reward_manager(
+            config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})
+        )
+        val_reward_fn = load_reward_manager(
+            config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {})
+        )
+        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
+
+        if config.rllm.workflow.use_workflow:
+            if workflow_class is None:
+                workflow_class = WORKFLOW_CLASS_MAPPING[config.rllm.workflow.name]
+            workflow_args = workflow_args or {}
+            if config.rllm.workflow.get("workflow_args") is not None:
+                workflow_args.update(config.rllm.workflow.get("workflow_args"))
+
+            trainer = AgentWorkflowPPOTrainer(
+                config=config,
+                tokenizer=tokenizer,
+                role_worker_mapping=role_worker_mapping,
+                resource_pool_manager=resource_pool_manager,
+                ray_worker_group_cls=ray_worker_group_cls,
+                reward_fn=reward_fn,
+                val_reward_fn=val_reward_fn,
+                workflow_class=workflow_class,
+                workflow_args=workflow_args,
+            )
+
+        else:
+            if env_class is None:
+                env_class = ENV_CLASS_MAPPING[config.rllm.env.name]
+            if agent_class is None:
+                agent_class = AGENT_CLASS_MAPPING[config.rllm.agent.name]
+
+            env_args = env_args or {}
+            agent_args = agent_args or {}
+            if config.rllm.env.get("env_args") is not None:
+                env_args.update(config.rllm.env.get("env_args"))
+            if config.rllm.agent.get("agent_args") is not None:
+                agent_args.update(config.rllm.agent.get("agent_args"))
+
+            trainer = AgentPPOTrainer(
+                config=config,
+                tokenizer=tokenizer,
+                role_worker_mapping=role_worker_mapping,
+                resource_pool_manager=resource_pool_manager,
+                ray_worker_group_cls=ray_worker_group_cls,
+                reward_fn=reward_fn,
+                val_reward_fn=val_reward_fn,
+                env_class=env_class,
+                agent_class=agent_class,
+                env_args=env_args,
+                agent_args=agent_args,
+            )
+
+        trainer.init_workers()
+        try:
+            trainer.fit_agent()
+        finally:
+            trainer.shutdown()

 if __name__ == "__main__":
     main()
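
The new `TaskRunner.run()` above assigns every role (actor rollout, critic, and an optional reference policy) to a single shared resource pool. The sketch below reproduces only that bookkeeping with plain-Python stand-ins for verl's `Role` enum and Ray worker classes, to make the shape of `role_worker_mapping`, `resource_pool_spec`, and `mapping` explicit; the pool size and the demo values are assumptions for illustration.

```python
# Toy reconstruction of the role -> resource-pool bookkeeping done in TaskRunner.run().
# Uses a local Enum instead of verl's Role and strings instead of Ray worker classes.
from enum import Enum, auto


class Role(Enum):
    ActorRollout = auto()
    Critic = auto()
    RefPolicy = auto()


def build_pools(n_gpus_per_node: int, nnodes: int, use_kl: bool):
    # Every role shares one global pool sized n_gpus_per_node per node.
    global_pool_id = "global_pool"
    resource_pool_spec = {global_pool_id: [n_gpus_per_node] * nnodes}

    # Stand-in worker "classes" (strings here; Ray remote classes in the real code).
    role_worker_mapping = {Role.ActorRollout: "ActorRolloutRefWorker", Role.Critic: "CriticWorker"}
    mapping = {Role.ActorRollout: global_pool_id, Role.Critic: global_pool_id}

    # A reference policy worker is only added when a KL term is used.
    if use_kl:
        role_worker_mapping[Role.RefPolicy] = "ActorRolloutRefWorker"
        mapping[Role.RefPolicy] = global_pool_id

    return resource_pool_spec, role_worker_mapping, mapping


if __name__ == "__main__":
    spec, workers, pool_of = build_pools(n_gpus_per_node=8, nnodes=2, use_kl=True)
    print(spec)      # {'global_pool': [8, 8]}
    print(workers)   # which worker class serves each role
    print(pool_of)   # which pool each role is scheduled on
```

In the real code these dictionaries are handed to `ResourcePoolManager` and the trainer, which is why adding the reference policy is just two extra dictionary entries.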
