"""
Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
"""
+ import os
+ import socket

import hydra
import ray
+ from omegaconf import OmegaConf
from verl.trainer.ppo.reward import load_reward_manager

from rllm.trainer.env_agent_mappings import AGENT_CLASS_MAPPING, ENV_CLASS_MAPPING, WORKFLOW_CLASS_MAPPING
@@ -20,121 +23,206 @@ def main(config):


def run_ppo_agent(config):
+     # Check whether Ray has already been initialized.
    if not ray.is_initialized():
-         # this is for local ray cluster
-         ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
+         # Initialize Ray with a local cluster configuration.
+         # Environment variables in the runtime environment control tokenizer parallelism,
+         # NCCL debug level, vLLM logging level, and runtime LoRA updating.
+         # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration.
+         ray.init(
+             runtime_env=get_ppo_ray_runtime_env(),
+             num_cpus=config.ray_init.num_cpus,
+         )
+
+     # Create a remote instance of the TaskRunner class, then execute its `run`
+     # method remotely and wait for it to complete.
+     if (
+         is_cuda_available
+         and config.trainer.get("profile_steps") is not None
+         and len(config.trainer.get("profile_steps", [])) > 0
+     ):
+         nsight_options = OmegaConf.to_container(config.trainer.controller_nsight_options)
+         runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote()
+     else:
+         runner = TaskRunner.remote()
+     ray.get(runner.run.remote(config))

-     ray.get(train_agent.remote(config))
+     # [Optional] Get the path of the timeline trace file from the configuration; defaults to None.
+     # This file is used for performance analysis.
+     timeline_json_file = config.ray_init.get("timeline_json_file", None)
+     if timeline_json_file:
+         ray.timeline(filename=timeline_json_file)


@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head
- def train_agent(config, workflow_class=None, workflow_args=None, agent_class=None, env_class=None, agent_args=None, env_args=None):
-     # print initial config
-     from pprint import pprint
-
-     from omegaconf import OmegaConf
-     from verl.utils.fs import copy_local_path_from_hdfs
-
-     OmegaConf.register_new_resolver("mul", lambda x, y: int(x) * int(y))
-     OmegaConf.resolve(config)
-     pprint(OmegaConf.to_container(config))
-
-     # download the checkpoint from hdfs
-     local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)
-
-     # instantiate tokenizer
-     from verl.utils import hf_tokenizer
-
-     trust_remote_code = config.data.get("trust_remote_code", False)
-     tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
-     # processor = hf_processor(local_path, use_fast=True)  # used for multimodal LLM, could be None
-
-     if config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]:
-         assert config.critic.strategy in ["fsdp", "fsdp2"]
-         from verl.single_controller.ray import RayWorkerGroup
-         from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker
-
-         actor_rollout_cls = AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker
-         ray_worker_group_cls = RayWorkerGroup
-     else:
-         raise NotImplementedError
-
-     from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
-
-     role_worker_mapping = {
-         Role.ActorRollout: ray.remote(max_concurrency=2048)(actor_rollout_cls),
-         Role.Critic: ray.remote(CriticWorker),
-     }
-
-     global_pool_id = "global_pool"
-     resource_pool_spec = {
-         global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
-     }
-     mapping = {
-         Role.ActorRollout: global_pool_id,
-         Role.Critic: global_pool_id,
-     }
-
-     if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
-         role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
-         mapping[Role.RefPolicy] = global_pool_id
-
-     reward_fn = load_reward_manager(config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}))
-     val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1)
-     resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
-
-     if config.rllm.workflow.use_workflow:
-         if workflow_class is None:
-             workflow_class = WORKFLOW_CLASS_MAPPING[config.rllm.workflow.name]
-         workflow_args = workflow_args or {}
-         if config.rllm.workflow.get("workflow_args") is not None:
-             workflow_args.update(config.rllm.workflow.get("workflow_args"))
-
-         trainer = AgentWorkflowPPOTrainer(
-             config=config,
-             tokenizer=tokenizer,
-             role_worker_mapping=role_worker_mapping,
-             resource_pool_manager=resource_pool_manager,
-             ray_worker_group_cls=ray_worker_group_cls,
-             reward_fn=reward_fn,
-             val_reward_fn=val_reward_fn,
-             workflow_class=workflow_class,
-             workflow_args=workflow_args,
-         )
-
-     else:
-         if env_class is None:
-             env_class = ENV_CLASS_MAPPING[config.rllm.env.name]
-         if agent_class is None:
-             agent_class = AGENT_CLASS_MAPPING[config.rllm.agent.name]
-
-         env_args = env_args or {}
-         agent_args = agent_args or {}
-         if config.rllm.env.get("env_args") is not None:
-             env_args.update(config.rllm.env.get("env_args"))
-         if config.rllm.agent.get("agent_args") is not None:
-             agent_args.update(config.rllm.agent.get("agent_args"))
-
-         trainer = AgentPPOTrainer(
-             config=config,
-             tokenizer=tokenizer,
-             role_worker_mapping=role_worker_mapping,
-             resource_pool_manager=resource_pool_manager,
-             ray_worker_group_cls=ray_worker_group_cls,
-             reward_fn=reward_fn,
-             val_reward_fn=val_reward_fn,
-             env_class=env_class,
-             agent_class=agent_class,
-             env_args=env_args,
-             agent_args=agent_args,
-         )
-
-     trainer.init_workers()
-     try:
-         trainer.fit_agent()
-     finally:
-         trainer.shutdown()
+ class TaskRunner:
+     """Ray remote class for executing distributed PPO training tasks.
+
+     This class encapsulates the main training logic and runs as a Ray remote actor
+     to enable distributed execution across multiple nodes and GPUs.
+     """
+
+     def run(self, config, workflow_class=None, workflow_args=None, agent_class=None, env_class=None, agent_args=None, env_args=None):
+         """Execute the main PPO training workflow.
+
+         This method sets up the distributed training environment, initializes
+         workers, datasets, and reward functions, then starts the training process.
+
+         Args:
+             config: Training configuration object containing all parameters needed
+                 for setting up and running the PPO training process.
+         """
+         # Print the initial configuration; `OmegaConf.resolve` evaluates symbolic values in place.
+         from pprint import pprint
+
+         from omegaconf import OmegaConf
+
+         from verl.utils.fs import copy_to_local
+
+         print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")
+         OmegaConf.register_new_resolver("mul", lambda x, y: int(x) * int(y))
+         OmegaConf.resolve(config)
+         pprint(OmegaConf.to_container(config))
+
+         # Download the checkpoint from HDFS to the local machine.
+         # `use_shm` determines whether to use shared memory, which could speed up model loading.
+         local_path = copy_to_local(
+             config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False)
+         )
+
+         # Instantiate the tokenizer and processor.
+         from verl.utils import hf_tokenizer
+
+         trust_remote_code = config.data.get("trust_remote_code", False)
+         tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
+         # Used for multimodal LLMs; could be None.
+         # processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
+
+         # Define worker classes based on the actor strategy.
+         if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
+             assert config.critic.strategy in {"fsdp", "fsdp2"}
+             from verl.single_controller.ray import RayWorkerGroup
+             from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
+
+             use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
+             if use_legacy_worker_impl in ["auto", "enable"]:
+                 # import warnings
+                 # warnings.warn(f"Legacy worker impl is going to be deprecated, will be removed in the future. \
+                 #     Please set trainer.use_legacy_worker_impl = false to switch to the new worker implementation.")
+                 from verl.workers.fsdp_workers import CriticWorker
+             elif use_legacy_worker_impl == "disable":
+                 from verl.workers.roles import CriticWorker
+
+                 print("Using new worker implementation")
+             else:
+                 raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")
+
+             actor_rollout_cls = (
+                 AsyncActorRolloutRefWorker
+                 if config.actor_rollout_ref.rollout.mode == "async"
+                 else ActorRolloutRefWorker
+             )
+             ray_worker_group_cls = RayWorkerGroup
+
+         elif config.actor_rollout_ref.actor.strategy == "megatron":
+             assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
+             from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
+             from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker
+
+             actor_rollout_cls = (
+                 AsyncActorRolloutRefWorker
+                 if config.actor_rollout_ref.rollout.mode == "async"
+                 else ActorRolloutRefWorker
+             )
+             ray_worker_group_cls = NVMegatronRayWorkerGroup
+
+         else:
+             raise NotImplementedError
+
+         from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
+
+         # Map roles to their corresponding remote worker classes.
+         role_worker_mapping = {
+             Role.ActorRollout: ray.remote(actor_rollout_cls),
+             Role.Critic: ray.remote(CriticWorker),
+         }
+
+         # Define the resource pool specification and map roles to the resource pool.
+         global_pool_id = "global_pool"
+         resource_pool_spec = {
+             global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
+         }
+         mapping = {
+             Role.ActorRollout: global_pool_id,
+             Role.Critic: global_pool_id,
+         }
+
+         # Add a reference policy worker if KL loss or KL reward is used.
+         if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
+             role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
+             mapping[Role.RefPolicy] = global_pool_id
+
+         # Load the reward manager for training and validation.
+         reward_fn = load_reward_manager(
+             config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})
+         )
+         val_reward_fn = load_reward_manager(
+             config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {})
+         )
+         resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
+
+         if config.rllm.workflow.use_workflow:
+             if workflow_class is None:
+                 workflow_class = WORKFLOW_CLASS_MAPPING[config.rllm.workflow.name]
+             workflow_args = workflow_args or {}
+             if config.rllm.workflow.get("workflow_args") is not None:
+                 workflow_args.update(config.rllm.workflow.get("workflow_args"))
+
+             trainer = AgentWorkflowPPOTrainer(
+                 config=config,
+                 tokenizer=tokenizer,
+                 role_worker_mapping=role_worker_mapping,
+                 resource_pool_manager=resource_pool_manager,
+                 ray_worker_group_cls=ray_worker_group_cls,
+                 reward_fn=reward_fn,
+                 val_reward_fn=val_reward_fn,
+                 workflow_class=workflow_class,
+                 workflow_args=workflow_args,
+             )
+
+         else:
+             if env_class is None:
+                 env_class = ENV_CLASS_MAPPING[config.rllm.env.name]
+             if agent_class is None:
+                 agent_class = AGENT_CLASS_MAPPING[config.rllm.agent.name]
+
+             env_args = env_args or {}
+             agent_args = agent_args or {}
+             if config.rllm.env.get("env_args") is not None:
+                 env_args.update(config.rllm.env.get("env_args"))
+             if config.rllm.agent.get("agent_args") is not None:
+                 agent_args.update(config.rllm.agent.get("agent_args"))
+
+             trainer = AgentPPOTrainer(
+                 config=config,
+                 tokenizer=tokenizer,
+                 role_worker_mapping=role_worker_mapping,
+                 resource_pool_manager=resource_pool_manager,
+                 ray_worker_group_cls=ray_worker_group_cls,
+                 reward_fn=reward_fn,
+                 val_reward_fn=val_reward_fn,
+                 env_class=env_class,
+                 agent_class=agent_class,
+                 env_args=env_args,
+                 agent_args=agent_args,
+             )
+
+         trainer.init_workers()
+         try:
+             trainer.fit_agent()
+         finally:
+             trainer.shutdown()


if __name__ == "__main__":
    main()
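
The main structural change in this diff is replacing the `train_agent` remote function with a `TaskRunner` remote actor that the driver invokes via `.remote()` and waits on with `ray.get`. As a quick orientation to that Ray pattern, here is a minimal, self-contained sketch; the `Job` class and `payload` argument are illustrative names only, not part of verl or rLLM:

```python
import ray


@ray.remote(num_cpus=1)  # reserve one CPU for the actor, as the trainer does for TaskRunner
class Job:
    """Illustrative stand-in for TaskRunner; not a verl/rLLM class."""

    def run(self, payload):
        # The real TaskRunner.run sets up workers and launches PPO training here.
        return f"processed {payload}"


if __name__ == "__main__":
    if not ray.is_initialized():
        ray.init()  # the trainer additionally passes runtime_env and num_cpus from its config

    job = Job.remote()                          # create the remote actor
    result = ray.get(job.run.remote("config"))  # schedule run() and block until it finishes
    print(result)

    # Mirrors the trainer's optional ray.timeline(...) call: dump a Chrome-trace
    # timeline file for performance analysis.
    ray.timeline(filename="/tmp/ray_timeline.json")
```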
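
Both the removed `train_agent` and the new `TaskRunner.run` register a custom `mul` resolver before resolving the configuration. A small sketch of what that enables, using invented config keys (the real keys live in the trainer's Hydra YAML files):

```python
from omegaconf import OmegaConf

# Same resolver the trainer registers: `${mul:a,b}` multiplies two integers.
OmegaConf.register_new_resolver("mul", lambda x, y: int(x) * int(y))

# Invented example config, purely for illustration.
cfg = OmegaConf.create(
    {
        "trainer": {"n_gpus_per_node": 8, "nnodes": 2},
        # Interpolation evaluated by the custom resolver: 8 * 2 = 16.
        "total_gpus": "${mul:${trainer.n_gpus_per_node},${trainer.nnodes}}",
    }
)

OmegaConf.resolve(cfg)              # evaluate interpolations in place
print(OmegaConf.to_container(cfg))  # {'trainer': {...}, 'total_gpus': 16}
```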
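
The hunk header references `def main(config)`, whose body is outside this diff. Purely for orientation, a Hydra entry point that hands the composed config to `run_ppo_agent` typically looks like the sketch below; the `config_path` and `config_name` values are placeholders that assume a matching YAML file exists, not the repository's actual settings:

```python
import hydra
from omegaconf import DictConfig


@hydra.main(config_path="config", config_name="agent_ppo_trainer", version_base=None)
def main(config: DictConfig):
    # Hypothetical wiring: pass the Hydra-composed config to the launcher defined above.
    run_ppo_agent(config)


if __name__ == "__main__":
    main()
```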