@@ -1,4 +1,5 @@
 import os
+import threading
 import time
 from typing import Any, Dict, Optional

@@ -54,6 +55,7 @@ def __init__(
         assert batch_size % minibatch_size == 0, "batch_size should be divisible by minibatch_size"
         self.num_microbatches = batch_size // minibatch_size
         self.data_uid = 0
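+        # tracks whether a background weight-sync thread is currently being launched (see loop())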
+        self.sync_model_thread_started = False

         self.model_config = model_config
         self.plugin_config = plugin_config
@@ -64,7 +66,6 @@ def __init__(
         self.shared_sync_data_actor = shared_sync_data_actor
         self.shared_signal_actor = shared_signal_actor
         self.state_dict_cpu = {}
-        self.next_data_source = 0  # used to track which producer to get data from next

     def setup(self) -> None:
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
@@ -183,7 +184,6 @@ def loop(self) -> None:
                     raw_batch = ray.get(self.shared_sync_data_actor.get_data.remote(self.data_uid))
                     continue
                 self.data_uid += 1
-                self.next_data_source = (self.next_data_source + 1) % self.num_producers
                 raw_batch = {k: v.to(self.device) for k, v in raw_batch.items()}
                 # calculate group reward and related metrics before filtering. As only the filtered group will be used for training (which is incomplete),
                 # we need to calculate the metrics before filtering here for logging
@@ -253,6 +253,7 @@ def loop(self) -> None:
                 if loss is not None:
                     pbar.set_postfix({"loss": loss})
                     need_sync_model = True
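+                    # publish the upcoming global step on the shared signal actor so other actors (e.g. producers) can track training progress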
+                    ray.get(self.shared_signal_actor.set_signal.remote("global_step", self.global_step + 1))
                 if need_sync_model and (
                     (self.global_step + 1) % self.save_interval == 0
                     or self.received_prompts >= self.train_dataset_size
@@ -269,49 +270,76 @@ def loop(self) -> None:
                 if need_sync_model and (
                     episode != self.num_episodes - 1 or self.received_prompts != self.train_dataset_size
                 ):
-                    # sync model weights to all producers, if no model update or it is the last training step, skip syncing
-                    if self.pp_size > 1:
-                        print(
-                            f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {self.global_step}"
-                        )
-                    else:
-                        print(f"[T{dist.get_rank()}] Sync model episode {episode} step {self.global_step}")
-                    torch.cuda.empty_cache()
-                    self.state_dict_cpu = {k: v.cpu() for k, v in self.state_dict().items()}
-                    cc.barrier(group_name="consumer_pg")
-                    if self.pp_size > 1:
-                        if self.tp_rank == 0 and self.dp_rank == 0:
-                            self.profiler.enter("sync_model")
-                            ray.get(
-                                self.shared_signal_actor.set_signal.remote(
-                                    f"consumer_pp_{self.pp_rank}", "ready_sync_model"
-                                )
-                            )
+
+                    def sync_model_thread():
+                        # sync model weights to all producers; if there was no model update or this is the last training step, syncing is skipped
+                        if self.pp_size > 1:
                             print(
                                 f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {self.global_step}"
                             )
-                            ray_broadcast_tensor_dict(
-                                self.state_dict_cpu,
-                                src=0,
-                                device=torch.device("cpu"),
-                                group_name=f"sync_model_consumer_pp_{self.pp_rank}",
-                                backend="gloo",
-                            )
-                            self.profiler.exit("sync_model")
-                    else:
-                        if self.rank == 0:
-                            self.profiler.enter("sync_model")
-                            ray.get(self.shared_signal_actor.set_signal.remote("consumer", "ready_sync_model"))
+                        else:
                             print(f"[T{dist.get_rank()}] Sync model episode {episode} step {self.global_step}")
-                            ray_broadcast_tensor_dict(
-                                self.state_dict_cpu,
-                                src=0,
-                                device=torch.device("cpu"),
-                                group_name="sync_model_consumer",
-                                backend="gloo",
-                            )
-                            self.profiler.exit("sync_model")
+                        torch.cuda.empty_cache()
+                        if self.pp_size > 1:
+                            if self.tp_rank == 0 and self.dp_rank == 0:
+                                self.profiler.enter("sync_model")
+                                ray.get(
+                                    self.shared_signal_actor.set_signal.remote(
+                                        f"consumer_pp_{self.pp_rank}", "ready_sync_model"
+                                    )
+                                )
+                                print(
+                                    f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {self.global_step}"
+                                )
+                                ray_broadcast_tensor_dict(
+                                    self.state_dict_cpu,
+                                    src=0,
+                                    device=torch.device("cpu"),
+                                    group_name=f"sync_model_consumer_pp_{self.pp_rank}",
+                                    backend="gloo",
+                                )
+                                self.profiler.exit("sync_model")
+                        else:
+                            if self.rank == 0:
+                                self.profiler.enter("sync_model")
+                                ray.get(self.shared_signal_actor.set_signal.remote("consumer", "ready_sync_model"))
+                                print(f"[T{dist.get_rank()}] Sync model episode {episode} step {self.global_step}")
+                                ray_broadcast_tensor_dict(
+                                    self.state_dict_cpu,
+                                    src=0,
+                                    device=torch.device("cpu"),
+                                    group_name="sync_model_consumer",
+                                    backend="gloo",
+                                )
+                                self.profiler.exit("sync_model")
+
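+                    # launch sequence: the broadcasting rank snapshots its weights into pinned CPU buffers,
+                    # all ranks barrier so the snapshot is complete, and the blocking gloo broadcast is then
+                    # handed to a background thread so training can continue while weights stream to the producers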
+                    if not self.sync_model_thread_started:
+                        # only sync model when the thread is not started and no other thread is broadcasting
+                        self.sync_model_thread_started = True
+                        state_dict_ = self.state_dict()
+                        if (self.pp_size > 1 and self.tp_rank == 0 and self.dp_rank == 0) or (
+                            self.pp_size == 1 and self.rank == 0
+                        ):
+                            if len(self.state_dict_cpu) == 0:
+                                # use pinned memory to speed up the transfer
+                                self.state_dict_cpu = {k: v.cpu().pin_memory() for k, v in state_dict_.items()}
+                            torch.cuda.synchronize()
+                            for k, v in state_dict_.items():
+                                self.state_dict_cpu[k].copy_(v, non_blocking=True)
+                            torch.cuda.synchronize()
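+                            # the synchronize above ensures the non_blocking copies have landed in the pinned
+                            # buffers before the barrier below and before the background thread reads them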
+                        cc.barrier(
+                            group_name="consumer_pg"
+                        )  # to make sure all ranks have state dict offloaded to CPU before starting the thread
+                        time_before_starting_thread = time.time()
+                        threading.Thread(target=sync_model_thread).start()
+                        # sync_model_thread()
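+                        # note: the time logged below only covers launching the background thread;
+                        # the broadcast itself keeps running asynchronously after this point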
+                        self.profiler.log(
+                            f"Sync model, took {time.time() - time_before_starting_thread:.2f} seconds"
+                        )
+                        self.sync_model_thread_started = False
+                        # ray.get(self.shared_signal_actor.release_process_lock.remote("broadcasting_lock"))
             self.profiler.log(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
+            self.received_prompts = 0
         ray.get(self.shared_signal_actor.set_signal.remote("consumer", "terminate"))

     def __del__(self):