From 509274c47e0919c586293a3b3134324b8eecf13e Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Wed, 9 Jul 2025 11:21:43 +0800 Subject: [PATCH 1/5] add code for zero-bubble implementation --- .../ColossalChat/coati/distributed/comm.py | 117 +++- .../coati/distributed/launch_zero_bubble.py | 306 ++++++++++ .../coati/distributed/zero_bubble/__init__.py | 0 .../coati/distributed/zero_bubble/consumer.py | 347 ++++++++++++ .../distributed/zero_bubble/distributor.py | 108 ++++ .../distributed/zero_bubble/grpo_consumer.py | 498 ++++++++++++++++ .../coati/distributed/zero_bubble/producer.py | 533 ++++++++++++++++++ .../ColossalChat/rl_example_zero_bubble.py | 369 ++++++++++++ 8 files changed, 2267 insertions(+), 11 deletions(-) create mode 100644 applications/ColossalChat/coati/distributed/launch_zero_bubble.py create mode 100644 applications/ColossalChat/coati/distributed/zero_bubble/__init__.py create mode 100644 applications/ColossalChat/coati/distributed/zero_bubble/consumer.py create mode 100644 applications/ColossalChat/coati/distributed/zero_bubble/distributor.py create mode 100644 applications/ColossalChat/coati/distributed/zero_bubble/grpo_consumer.py create mode 100644 applications/ColossalChat/coati/distributed/zero_bubble/producer.py create mode 100644 applications/ColossalChat/rl_example_zero_bubble.py diff --git a/applications/ColossalChat/coati/distributed/comm.py b/applications/ColossalChat/coati/distributed/comm.py index 3824303f55bd..0a724d53bc18 100644 --- a/applications/ColossalChat/coati/distributed/comm.py +++ b/applications/ColossalChat/coati/distributed/comm.py @@ -1,5 +1,6 @@ from typing import Any, Dict - +import copy +import ray import ray.util.collective as cc import torch import torch.distributed.distributed_c10d as c10d @@ -30,11 +31,18 @@ def ray_broadcast_object(obj: Any, src: int = 0, device=None, group_name: str = obj = c10d._tensor_to_object(obj, size_tensor.item()) return obj - def ray_broadcast_tensor_dict( - tensor_dict: Dict[str, torch.Tensor], src: int = 0, device=None, group_name: str = "default" + tensor_dict: Dict[str, torch.Tensor], + src: int = 0, + device=None, + group_name: str = "default", + backend: str = "nccl", + offload_to_cpu: bool = False, + pin_memory: bool = False, ) -> Dict[str, torch.Tensor]: rank = cc.get_rank(group_name) + if tensor_dict is None: + tensor_dict = {} if rank == src: metadata = [] for k, v in tensor_dict.items(): @@ -42,16 +50,103 @@ def ray_broadcast_tensor_dict( else: metadata = None metadata = ray_broadcast_object(metadata, src, device, group_name) - if rank != src: - out_dict = {} for k, shape, dtype in metadata: if rank == src: - tensor = tensor_dict[k] + if offload_to_cpu: + tensor = tensor_dict[k].to(device) + else: + tensor = tensor_dict[k] else: - tensor = torch.empty(shape, dtype=dtype, device=device) + tensor = tensor_dict.get(k, torch.zeros(shape, dtype=dtype, device=device, pin_memory=pin_memory)) + if backend == "gloo" and dtype == torch.bfloat16: + # Gloo does not support bfloat16, convert to float16 + tensor = tensor.view(torch.float16) cc.broadcast(tensor, src, group_name) + if backend == "gloo" and dtype == torch.bfloat16: + # Convert back to bfloat16 if it was converted to float16 + tensor = tensor.view(torch.bfloat16) if rank != src: - out_dict[k] = tensor - if rank == src: - out_dict = tensor_dict - return out_dict + if offload_to_cpu: + tensor_dict[k] = tensor.cpu() + else: + tensor_dict[k] = tensor + return tensor_dict + + +@ray.remote +class SharedVariableActor: + def __init__(self, number_of_readers: int = 0, buffer_size_limit: int = 1000): + self.data_queue = [] + self.data_uid = 0 + self.number_of_readers = number_of_readers + self.queue_size = 0 + self.signals = {} + self.process_locks = {} + self.signal_procs_meet_count = {} + self.buffer_size_limit = buffer_size_limit + + def pickup_rollout_task(self, num_tasks: int): + """ + use queue size to control whether producers should generating new rollouts or wait + for consumer to consumer more data. if queue size is less than threshold, + it means consumer is consuming data fast enough, so producers can generate new rollouts. + if queue size is greater than threshold, it means consumer is consuming data slowly, + so producers should wait for consumer to consume more data. + + Any free producer can pick up the task to generate rollout then increase the queued_data_size + to prevent other producer to pick up the task redundantly, Note it is not the real + queue length as data may still be generating + """ + ret = False + if self.queue_size < self.buffer_size_limit: + ret = True + self.queue_size += num_tasks + return ret + + def append_data(self, data): + self.data_queue.append([self.data_uid, data, 0]) # [data_uid, data, access_count] + self.data_uid += 1 + return True + + def get_data(self, data_uid: int): + # for multi-process data reading + if not self.data_queue: + # no data in the queue, return None + return None + to_pop_index = None + ret = None + for i, (uid, data, access_count) in enumerate(self.data_queue): + if uid == data_uid: + # found the data with the given uid + self.data_queue[i][2] += 1 + ret = copy.deepcopy(data) + if self.data_queue[i][2] == self.number_of_readers: + to_pop_index = i + break + if to_pop_index is not None: + # remove the data from the queue if it has been accessed by all readers + self.data_queue.pop(to_pop_index) + self.queue_size -= data["input_ids"].size(0) + return ret + + def acquire_process_lock(self, key: str): + # atomic lock for process + if key not in self.process_locks: + self.process_locks[key] = 1 # locked + return 0 + if self.process_locks[key] == 0: + self.process_locks[key] = 1 # lock the process + return 0 + else: + return 1 + + def release_process_lock(self, key: str): + # atomic unlock for process + assert self.process_locks.get(key, 0) == 1, f"Releasing a process lock {key} that is not locked." + self.process_locks[key] = 0 + + def set_signal(self, key: str, signal: str): + self.signals[key] = signal + + def get_signal(self): + return self.signals diff --git a/applications/ColossalChat/coati/distributed/launch_zero_bubble.py b/applications/ColossalChat/coati/distributed/launch_zero_bubble.py new file mode 100644 index 000000000000..635493b7d957 --- /dev/null +++ b/applications/ColossalChat/coati/distributed/launch_zero_bubble.py @@ -0,0 +1,306 @@ +import copy +import os +import uuid +from typing import Any, Dict, Optional + +import ray + +from .comm import SharedVariableActor +from .zero_bubble.distributor import Distributor +from .zero_bubble.grpo_consumer import GRPOConsumer +from .zero_bubble.producer import SimpleProducer + +ALGO_MAP = {"GRPO": GRPOConsumer, "DAPO": GRPOConsumer} + + +def get_jsonl_size_fast(path: str) -> int: + with open(path) as f: + lines = f.readlines() + lines = [line for line in lines if line.strip()] + return len(lines) + + +def get_dp_size_fast(n_procs: int, plugin_config: Dict[str, Any]) -> int: + tp_size = plugin_config.get("tp_size", 1) + pp_size = plugin_config.get("pp_size", 1) + ep_size = plugin_config.get("ep_size", 1) + sp_size = plugin_config.get("sp_size", 1) + return n_procs // (tp_size * pp_size * ep_size * sp_size) + + +def launch_distributed( + num_producers: int, + num_proc_per_producer: int, + num_consumer_procs: int, + num_episodes: int, + inference_batch_size: int, + inference_microbatch_size: int, + train_batch_size: int, + train_minibatch_size: int, + train_dataset_config: Dict[str, Any], + inference_model_config: Dict[str, Any], + generate_config: Dict[str, Any], + train_model_config: Dict[str, Any], + grpo_config: Dict[str, Any], + plugin_config: Dict[str, Any], + tokenizer_config: Optional[Dict[str, Any]] = None, + inference_backend: str = "transformers", + num_generations: int = 8, + master_addr: str = "localhost", + master_port: int = 29500, + core_algo: str = "GRPO", + project_name: Optional[str] = None, + save_interval: int = 100, + save_dir: str = "./model", + eval_dataset_config: Optional[Dict[str, Any]] = None, + eval_interval: int = 100, + eval_save_dir: Optional[str] = None, + eval_generation_config: Optional[Dict[str, Any]] = None, + log_rollout_interval: int = 20, + rollout_save_dir: str = "./rollout", + enable_profiling: bool = False, + data_actor_buffer_size_limit: int = 0, +): + if core_algo not in ALGO_MAP: + raise NotImplementedError(f"{core_algo} is not supported yet.") + else: + core_consumer = ALGO_MAP.get(core_algo, GRPOConsumer) + + train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config) + assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0 + if data_actor_buffer_size_limit <= 0: + # use 2 times the train_minibatch_size as the default buffer size limit + data_actor_buffer_size_limit = train_minibatch_size * train_dp_size * 2 + + dataset_path = train_dataset_config["path"] + train_dataset_size = get_jsonl_size_fast(dataset_path) + global_inference_batch_size = inference_batch_size * num_producers + train_dataset_size = (train_dataset_size // global_inference_batch_size) * global_inference_batch_size + + run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}" + wandb_group_name = str(uuid.uuid4()) + rollout_log_file = os.path.join( + rollout_save_dir, + f"{project_name.replace(' ','_')}_run_{wandb_group_name}.jsonl", + ) + + # Attention: Ray use complex schedualing method that consider various factors including load-balancing. + # when requesting resources, it is not guaranteed that the resource comes from a node with lower node it + # this go against the design principle of our implementation, and we need to manually force the schedualing, + # allocating the producer to nodes with lower node id and the consumer to the resouces from nodes with higher + # node id. See the reference here: https://docs.ray.io/en/latest/ray-core/scheduling/index.html#nodeaffinityschedulingstrategy + nodes = ray.nodes() + + # every producer is associated with a data worker, data worker is responsible for moving data from the producer to all consumer + shared_sync_data_actor = SharedVariableActor.remote(num_consumer_procs, data_actor_buffer_size_limit) + # all producer and the consumer 0 share the same model actor, model actor only provide signal for model synchronization + shared_signal_actor = SharedVariableActor.remote() + + node_info = { + node["NodeID"]: { + "num_gpus": node["Resources"].get("GPU", 0), + "address": node["NodeManagerAddress"], + } # Default to 0 if no GPUs are available + for node in nodes + } + gpu_to_node_id = [] + gpu_to_ip_address = [] + for node_id in node_info: + for idx in range(int(node_info[node_id]["num_gpus"])): + gpu_to_node_id.append(node_id) + gpu_to_ip_address.append(node_info[node_id]["address"]) + print(node_info) + + producer_procs = [] + for i in range(num_producers): + node_id = gpu_to_node_id[0] + producer_ip_address = gpu_to_ip_address[0] + for _ in range(num_proc_per_producer): + gpu_to_node_id.pop(0) + gpu_to_ip_address.pop(0) + print(f"Schedual Producer P[{i}] which requires {num_proc_per_producer} GPUs on node {producer_ip_address}") + producer = SimpleProducer.options(num_gpus=num_proc_per_producer, num_cpus=4).remote( + shared_sync_data_actor=shared_sync_data_actor, + shared_signal_actor=shared_signal_actor, + producer_idx=i, + num_producers=num_producers, + num_consumer_procs=num_consumer_procs, + num_episodes=num_episodes, + batch_size=inference_batch_size, + train_dataset_config=train_dataset_config, + model_config=inference_model_config, + generate_config=generate_config, + tokenizer_config=tokenizer_config, + microbatch_size=inference_microbatch_size, + backend=inference_backend, + num_generations=num_generations, + consumer_plugin_config=plugin_config, + eval_dataset_config=eval_dataset_config, + eval_interval=eval_interval, + grpo_config=grpo_config, + eval_save_dir=eval_save_dir, + eval_generation_config=eval_generation_config, + project_name=project_name, + run_name=run_name, + wandb_group_name=wandb_group_name, + log_rollout_interval=log_rollout_interval, + rollout_log_file=rollout_log_file, + enable_profiling=enable_profiling, + ) + producer_procs.append(producer) + # ray.get([p.setup.remote() for p in producer_procs]) + generate_config_consumer = copy.deepcopy(generate_config) + generate_config_consumer.update( + dict( + backend=inference_backend, + ) + ) + consumer_master_ip_address = gpu_to_ip_address[0] + print(f"Use {consumer_master_ip_address} as master address for torch DDP.") + consumer_procs = [] + if num_consumer_procs <= 1: + raise ValueError("Number of consumer processes should be greater than 1 for async rl training.") + for i in range(num_consumer_procs): + node_id = gpu_to_node_id[0] + consumer_ip_address = gpu_to_ip_address[0] + gpu_to_node_id.pop(0) + gpu_to_ip_address.pop(0) + print(f"Schedual Consumer T[{i}] which requires 1 GPUs on node {consumer_ip_address}") + consumer = core_consumer.options(num_gpus=1, num_cpus=4).remote( + shared_sync_data_actor=shared_sync_data_actor, + shared_signal_actor=shared_signal_actor, + num_producers=num_producers, + num_episodes=num_episodes, + rank=i, + world_size=num_consumer_procs, + master_addr=consumer_master_ip_address, + master_port=master_port, + train_dataset_size=train_dataset_size, + batch_size=train_batch_size, + model_config=train_model_config, + plugin_config=plugin_config, + minibatch_size=train_minibatch_size, + generate_config=generate_config_consumer, + grpo_config=grpo_config, + num_generations=num_generations, + save_interval=save_interval, + save_dir=save_dir, + project_name=project_name, + run_name=run_name, + wandb_group_name=wandb_group_name, + enable_profiling=enable_profiling, + ) + consumer_procs.append(consumer) + + distributor_procs = [] + for i in range(num_producers): + distributor_procs.append( + Distributor.options(num_cpus=2).remote( + i, + plugin_config.get("pp_size", 1), + num_producers, + shared_signal_actor, + enable_profiling=enable_profiling, + ) + ) + print("=================== All processes are created, starting setup torch DDP ===================", flush=True) + ray.get([p.setup.remote() for p in consumer_procs]) + print( + "=================== All processes are setup, starting initialize communication groups ===================", + flush=True, + ) + remote_refs = [] + # Initialize consumer communication group + for i, p in enumerate(consumer_procs): + remote_refs.append(p.init_collective_group.remote(num_consumer_procs, i, "gloo", f"consumer_pg")) + ray.get(remote_refs) + remote_refs = [] + # Initialize producer communication group + for i, p in enumerate(producer_procs): + remote_refs.append(p.init_collective_group.remote(num_producers, i, "nccl", f"producer_pg")) + ray.get(remote_refs) + remote_refs = [] + # Initialize distributor communication group + for i, p in enumerate(distributor_procs): + remote_refs.append(p.init_collective_group.remote(num_producers, i, "gloo", f"distributor_pg")) + ray.get(remote_refs) + remote_refs = [] + # Initialize sync model communication group between consumer and sync model actor + # As per tested, gloo do not support nested initialization, so we need to initialize all participants in the same group in the same ray.get call. + consumer_pp = plugin_config.get("pp_size", 1) + for i, p in enumerate(consumer_procs): + consumer_ddp_config = ray.get(p.get_ddp_config.remote()) + if consumer_pp > 1: + if consumer_ddp_config["tp_rank"] == 0 and consumer_ddp_config["dp_rank"] == 0: + pp_rank = consumer_ddp_config["pp_rank"] + remote_refs.append( + p.init_collective_group.remote( + num_producers + 1, + 0, + backend="gloo", + group_name=f"sync_model_consumer_pp_{pp_rank}", + gloo_timeout=3000000, + ) + ) + for distributor_id, p_distributor in enumerate(distributor_procs): + remote_refs.append( + p_distributor.init_collective_group.remote( + num_producers + 1, + 1 + distributor_id, + backend="gloo", + group_name=f"sync_model_consumer_pp_{pp_rank}", + gloo_timeout=3000000, + ) + ) + ray.get(remote_refs) + remote_refs = [] + else: + if i == 0: + remote_refs.append( + p.init_collective_group.remote( + num_producers + 1, 0, backend="gloo", group_name=f"sync_model_consumer", gloo_timeout=3000000 + ) + ) + for distributor_id, p_distributor in enumerate(distributor_procs): + remote_refs.append( + p_distributor.init_collective_group.remote( + num_producers + 1, + 1 + distributor_id, + backend="gloo", + group_name=f"sync_model_consumer", + gloo_timeout=3000000, + ) + ) + ray.get(remote_refs) + remote_refs = [] + # Initialize sync model communication group between producer and sync model actor + for i, p in enumerate(producer_procs): + if consumer_pp > 1: + for pp_rank in range(consumer_pp): + remote_refs.append( + p.init_collective_group.remote( + 2, 0, backend="gloo", group_name=f"sync_model_producer_{i}_pp_{pp_rank}", gloo_timeout=3000000 + ) + ) + remote_refs.append( + distributor_procs[i].init_collective_group.remote( + 2, 1, backend="gloo", group_name=f"sync_model_producer_{i}_pp_{pp_rank}", gloo_timeout=3000000 + ) + ) + ray.get(remote_refs) + remote_refs = [] + else: + remote_refs.append( + p.init_collective_group.remote( + 2, 0, backend="gloo", group_name=f"sync_model_producer_{i}", gloo_timeout=3000000 + ) + ) + remote_refs.append( + distributor_procs[i].init_collective_group.remote( + 2, 1, backend="gloo", group_name=f"sync_model_producer_{i}", gloo_timeout=3000000 + ) + ) + ray.get(remote_refs) + remote_refs = [] + print("=================== All processes are set up, starting loop ===================", flush=True) + ray.get([p.loop.remote() for p in (producer_procs + consumer_procs + distributor_procs)]) diff --git a/applications/ColossalChat/coati/distributed/zero_bubble/__init__.py b/applications/ColossalChat/coati/distributed/zero_bubble/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/applications/ColossalChat/coati/distributed/zero_bubble/consumer.py b/applications/ColossalChat/coati/distributed/zero_bubble/consumer.py new file mode 100644 index 000000000000..82242e874308 --- /dev/null +++ b/applications/ColossalChat/coati/distributed/zero_bubble/consumer.py @@ -0,0 +1,347 @@ +import os +import threading +import time +from typing import Any, Dict, Optional + +import ray +import ray.util.collective as cc +import torch +import torch.distributed as dist +from coati.distributed.profiling_utils import CustomProfiler +from tqdm import tqdm + +from colossalai.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.initialize import launch +from colossalai.utils import get_current_device + +from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict +from coati.distributed.utils import bind_batch, post_recv, unbind_batch + + +class BaseConsumer: + def __init__( + self, + shared_sync_data_actor: SharedVariableActor, + shared_signal_actor: SharedVariableActor, + num_producers: int, + num_episodes: int, + rank: int, + world_size: int, + master_addr: str, + master_port: int, + train_dataset_size: int, + batch_size: int, + model_config: Dict[str, Any], + plugin_config: Dict[str, Any], + minibatch_size: int = 1, + save_interval: int = 100, + save_dir: str = "./model", + enable_profiling: bool = False, + ): + self.num_producers = num_producers + self.num_episodes = num_episodes + self.rank = rank + self.world_size = world_size + self.master_addr = master_addr + self.master_port = master_port + self.train_dataset_size = train_dataset_size + self.received_prompts = 0 + self.batch_size = batch_size + self.minibatch_size = minibatch_size + self.save_interval = save_interval + self.save_dir = save_dir + self.enable_profiling = enable_profiling + assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size" + self.num_microbatches = batch_size // minibatch_size + self.data_uid = 0 + self.sync_model_thread_started = False + + self.model_config = model_config + self.plugin_config = plugin_config + + self.device = get_current_device() + self.lr_scheduler = None + + self.shared_sync_data_actor = shared_sync_data_actor + self.shared_signal_actor = shared_signal_actor + self.state_dict_cpu = {} + + def setup(self) -> None: + launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0) + + plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2) + if ( + self.plugin_config.get("pp_size", 1) > 1 + and "num_microbatches" not in self.plugin_config + and "microbatch_size" not in self.plugin_config + ): + plugin_config["microbatch_size"] = max(1, self.minibatch_size // plugin_config.get("pp_size", 1)) + plugin_config.update(self.plugin_config) + self.plugin = HybridParallelPlugin(**plugin_config) + self.booster = Booster(plugin=self.plugin) + self.dp_rank = dist.get_rank(self.plugin.dp_group) + self.tp_rank = dist.get_rank(self.plugin.tp_group) + self.pp_rank = dist.get_rank(self.plugin.pp_group) + + self.dp_size = dist.get_world_size(self.plugin.dp_group) + self.tp_size = dist.get_world_size(self.plugin.tp_group) + self.pp_size = dist.get_world_size(self.plugin.pp_group) + + self.buffer = [] + self.recv_cnt = 0 + self.profiler = CustomProfiler(f"C{self.rank}", disabled=not self.enable_profiling) + + def get_ddp_config(self) -> Dict[str, Any]: + """ + Get the DDP configuration for the consumer. + This method is used to get the DDP configuration for the consumer. + """ + return { + "dp_size": self.dp_size, + "tp_size": self.tp_size, + "pp_size": self.pp_size, + "dp_rank": self.dp_rank, + "tp_rank": self.tp_rank, + "pp_rank": self.pp_rank, + "world_size": self.world_size, + "rank": self.rank, + } + + def init_collective_group( + self, + world_size: int, + rank: int, + backend: str = "nccl", + group_name: str = "default", + gloo_timeout: int = 3000000, + ): + cc.init_collective_group( + world_size=world_size, rank=rank, backend=backend, group_name=group_name, gloo_timeout=gloo_timeout + ) + print(f"[C{self.rank}] Initialized {group_name} collective group", flush=True) + + def state_dict(self) -> Dict[str, torch.Tensor]: + raise NotImplementedError + + def step(self, **kwargs) -> Optional[float]: + raise NotImplementedError + + def prepare_mini_batch(self, effective_group_to_raw_group_mapping: Dict[int, int]) -> Dict[str, torch.Tensor]: + """ + Prepare a mini-batch from the effective group to raw group mapping. + This method is used to create a mini-batch for training. + """ + batches = [ + self.buffer[effective_group_to_raw_group_mapping[i]] + for i in range(self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size) + ] + # every dp_rank will receive a complete mini-batch, no need to sync within step() later + # each mini-batch use the first self.dp_size * minibatch_size effective samples + raw_mini_batches = self.buffer[ + : effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 + ] # include the last effective sample + raw_mini_batches_metric_dict = { + "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches], + "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches], + "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches], + "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches], + } + batch = bind_batch([t[0] for t in batches]) + batch = post_recv(batch) + return batch, raw_mini_batches_metric_dict + + def calculate_effective_group_to_raw_group_mapping(self): + effective_group_to_raw_group_mapping = {} + for buffer_idx in range(len(self.buffer)): + if self.buffer[buffer_idx][0] is not None: + effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx + return effective_group_to_raw_group_mapping + + def loop(self) -> None: + print(f"Consumer{self.rank}, nmb: {self.num_microbatches}") + for episode in range(self.num_episodes): + with tqdm( + range(self.train_dataset_size), + desc=f"Episode {episode} with rollout step(s)", + disable=self.rank != 0, + ) as pbar: + while self.received_prompts < self.train_dataset_size: + torch.cuda.reset_peak_memory_stats() + effective_group_to_raw_group_mapping = {} + self.profiler.enter(f"recv_data") + while len(effective_group_to_raw_group_mapping) < self.dp_size * self.minibatch_size: + # receive data from producers + raw_batch = ray.get( + self.shared_sync_data_actor.get_data.remote(self.data_uid) + ) # get the first queued data + while raw_batch is None: + self.profiler.log(f"No data received by consumer {self.rank}, skipping") + print( + f"[T{dist.get_rank()}] No data received by consumer {self.rank}, skipping. Consider increasing the data actor buffer limit" + ) + time.sleep(1) + raw_batch = ray.get(self.shared_sync_data_actor.get_data.remote(self.data_uid)) + continue + self.data_uid += 1 + raw_batch = {k: v.to(self.device) for k, v in raw_batch.items()} + # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete), + # we need to calculate the metrics before filtering here for logging + # [batch_size, num_generations] -> [batch_size] + reward = raw_batch["reward"][:, :, 0] + format_acc = raw_batch["format_acc"][:, :, 0] + ans_acc = raw_batch["ans_acc"][:, :, 0] + response_len = ( + raw_batch["response_idx"][:, :, 1] - raw_batch["response_idx"][:, :, 0] + 1 + ).type(torch.float32) + effective_group_mask = None + if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True): + # filter the group based on the reward and accuracy + group_ans_acc_mean = ans_acc.mean(dim=1) + effective_group_mask = torch.logical_and( + group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1] + ) + + raw_batch = unbind_batch(raw_batch) # List[Dict[str, torch.Tensor]] + self.received_prompts += len(raw_batch) + pbar.update(len(raw_batch)) + for group_idx, group_with_reward in enumerate(raw_batch): + self.buffer.append( + [ + ( + group_with_reward + if effective_group_mask is None or effective_group_mask[group_idx] + else None + ), + reward[group_idx], + format_acc[group_idx], + ans_acc[group_idx], + response_len[group_idx], + ] + ) + if effective_group_mask is not None: + print( + f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups" + ) + # mapping the effective group to the raw group for indexing + effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping() + print( + f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}" + ) + self.profiler.exit(f"recv_data") + need_sync_model = False + while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size: + # after we have enough effective groups, we can start training + # on each dp_rank, we use minibatch_size effective samples to form a batch + batch, raw_mini_batches_metric_dict = self.prepare_mini_batch( + effective_group_to_raw_group_mapping + ) + self.profiler.enter("step") + loss = self.step(pbar, **batch, **raw_mini_batches_metric_dict) + self.profiler.exit("step") + self.buffer = self.buffer[ + effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 : + ] + # recalculate the effective group to raw group mapping + effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping) + effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping() + assert ( + len(effective_group_to_raw_group_mapping) + == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size + ) + # cc.barrier(group_name="consumer_pg") + if loss is not None: + pbar.set_postfix({"loss": loss}) + need_sync_model = True + ray.get(self.shared_signal_actor.set_signal.remote("global_step", self.global_step + 1)) + if need_sync_model and ( + (self.global_step + 1) % self.save_interval == 0 + or self.received_prompts >= self.train_dataset_size + ): + if self.rank == 0: + print(f"Start saving policy model at step {self.global_step + 1}.") + save_path = os.path.join( + self.save_dir, f"modeling-episode-{episode}-step-{self.global_step + 1}" + ) + self.booster.save_model(self.policy_model, save_path, shard=True) + if self.rank == 0: + print(f"Saved model checkpoint at step {self.global_step + 1} in folder {save_path}") + + if need_sync_model and ( + episode != self.num_episodes - 1 or self.received_prompts != self.train_dataset_size + ): + + def sync_model_thread(): + # sync model weights to all producers, if no model update or it is the last training step, skip syncing + if self.pp_size > 1: + print( + f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {self.global_step}" + ) + else: + print(f"[T{dist.get_rank()}] Sync model episode {episode} step {self.global_step}") + torch.cuda.empty_cache() + if self.pp_size > 1: + if self.tp_rank == 0 and self.dp_rank == 0: + self.profiler.enter("sync_model") + ray.get( + self.shared_signal_actor.set_signal.remote( + f"consumer_pp_{self.pp_rank}", "ready_sync_model" + ) + ) + print( + f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {self.global_step}" + ) + ray_broadcast_tensor_dict( + self.state_dict_cpu, + src=0, + device=torch.device("cpu"), + group_name=f"sync_model_consumer_pp_{self.pp_rank}", + backend="gloo", + ) + self.profiler.exit("sync_model") + else: + if self.rank == 0: + self.profiler.enter("sync_model") + ray.get(self.shared_signal_actor.set_signal.remote("consumer", "ready_sync_model")) + print(f"[T{dist.get_rank()}] Sync model episode {episode} step {self.global_step}") + ray_broadcast_tensor_dict( + self.state_dict_cpu, + src=0, + device=torch.device("cpu"), + group_name="sync_model_consumer", + backend="gloo", + ) + self.profiler.exit("sync_model") + + if not self.sync_model_thread_started: + # only sync model when the thread is not started and no other thread is broadcasting + self.sync_model_thread_started = True + state_dict_ = self.state_dict() + if (self.pp_size > 1 and self.tp_rank == 0 and self.dp_rank == 0) or ( + self.pp_size == 1 and self.rank == 0 + ): + if len(self.state_dict_cpu) == 0: + # use pinned memory to speed up the transfer + self.state_dict_cpu = {k: v.cpu().pin_memory() for k, v in state_dict_.items()} + torch.cuda.synchronize() + for k, v in state_dict_.items(): + self.state_dict_cpu[k].copy_(v, non_blocking=True) + torch.cuda.synchronize() + cc.barrier( + group_name="consumer_pg" + ) # to make sure all ranks have state dict offloaded to CPU before starting the thread + time_before_starting_thread = time.time() + threading.Thread(target=sync_model_thread).start() + # sync_model_thread() + self.profiler.log( + f"Sync model, took {time.time() - time_before_starting_thread:.2f} seconds" + ) + self.sync_model_thread_started = False + # ray.get(self.shared_signal_actor.release_process_lock.remote("broadcasting_lock")) + self.profiler.log(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB") + self.received_prompts = 0 + ray.get(self.shared_signal_actor.set_signal.remote("consumer", "terminate")) + + def __del__(self): + if hasattr(self, "profiler"): + self.profiler.close() diff --git a/applications/ColossalChat/coati/distributed/zero_bubble/distributor.py b/applications/ColossalChat/coati/distributed/zero_bubble/distributor.py new file mode 100644 index 000000000000..b16f4b67ef39 --- /dev/null +++ b/applications/ColossalChat/coati/distributed/zero_bubble/distributor.py @@ -0,0 +1,108 @@ +import time + +import ray +import ray.util.collective as cc +import torch +from coati.distributed.profiling_utils import CustomProfiler + +from colossalai.utils import get_current_device + +from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict + + +@ray.remote +class Distributor: + def __init__( + self, + distributor_id, + consumer_pp_size, + num_producers, + shared_signal_actor: SharedVariableActor, + enable_profiling: bool = True, + ): + self.distributor_id = distributor_id + self.consumer_pp_size = consumer_pp_size + self.state_dict_cpu = {} + self.num_producers = num_producers + self.shared_signal_actor = shared_signal_actor + self.device = get_current_device() + self.profiler = CustomProfiler(f"D{self.distributor_id}", disabled=not enable_profiling) + + def init_collective_group( + self, + world_size: int, + rank: int, + backend: str = "nccl", + group_name: str = "default", + gloo_timeout: int = 3000000, + ): + cc.init_collective_group( + world_size=world_size, rank=rank, backend=backend, group_name=group_name, gloo_timeout=gloo_timeout + ) + print(f"[D] Initialized {group_name} collective group", flush=True) + + def loop(self): + while True: + time.sleep(1) + signal = ray.get(self.shared_signal_actor.get_signal.remote()) + if self.consumer_pp_size > 1: + for i in range(self.consumer_pp_size): + if signal.get(f"consumer_pp_{i}", None) == "ready_sync_model": + self.profiler.enter(f"sync_model_consumer_pp_{i}") + cc.barrier(group_name="distributor_pg") + ray.get(self.shared_signal_actor.set_signal.remote(f"consumer_pp_{i}", "not_ready_sync_model")) + # Broadcast the model state dict from consumer to shared variable actor + self.state_dict_cpu[i] = ray_broadcast_tensor_dict( + None, + 0, + device=torch.device("cpu"), + group_name=f"sync_model_consumer_pp_{i}", + backend="gloo", + ) + self.profiler.exit(f"sync_model_consumer_pp_{i}") + for i in range(self.consumer_pp_size): + if signal.get(f"producer_{self.distributor_id}_pp_{i}", None) == "ready_sync_model": + self.profiler.enter(f"sync_model_producer_{self.distributor_id}_pp_{i}") + # Broadcast the model state dict to all producers + ray.get( + self.shared_signal_actor.set_signal.remote( + f"producer_{self.distributor_id}_pp_{i}", "not_ready_sync_model" + ) + ) + ray_broadcast_tensor_dict( + self.state_dict_cpu[i], + 1, + device=torch.device("cpu"), + group_name=f"sync_model_producer_{self.distributor_id}_pp_{i}", + backend="gloo", + ) + self.profiler.exit(f"sync_model_producer_{self.distributor_id}_pp_{i}") + else: + if signal.get("consumer", None) == "ready_sync_model": + self.profiler.enter("sync_model_consumer") + cc.barrier(group_name="distributor_pg") + ray.get(self.shared_signal_actor.set_signal.remote("consumer", "not_ready_sync_model")) + # Broadcast the model state dict from consumer to shared variable actor + self.state_dict_cpu = ray_broadcast_tensor_dict( + None, 0, device=torch.device("cpu"), group_name="sync_model_consumer", backend="gloo" + ) + self.profiler.exit("sync_model_consumer") + if signal.get(f"producer_{self.distributor_id}", None) == "ready_sync_model": + self.profiler.enter(f"sync_model_producer_{self.distributor_id}") + # Broadcast the model state dict to all producers + ray.get( + self.shared_signal_actor.set_signal.remote( + f"producer_{self.distributor_id}", "not_ready_sync_model" + ) + ) + ray_broadcast_tensor_dict( + self.state_dict_cpu, + 1, + device=torch.device("cpu"), + group_name=f"sync_model_producer_{self.distributor_id}", + backend="gloo", + ) + self.profiler.exit(f"sync_model_producer_{self.distributor_id}") + if signal.get("consumer", None) == "terminate": + self.profiler.log("terminate sync model worker") + break diff --git a/applications/ColossalChat/coati/distributed/zero_bubble/grpo_consumer.py b/applications/ColossalChat/coati/distributed/zero_bubble/grpo_consumer.py new file mode 100644 index 000000000000..c07385b97c21 --- /dev/null +++ b/applications/ColossalChat/coati/distributed/zero_bubble/grpo_consumer.py @@ -0,0 +1,498 @@ +from contextlib import nullcontext +from typing import Any, Optional + +import ray +import torch +import wandb +from coati.distributed.comm import SharedVariableActor +from coati.distributed.zero_bubble.consumer import BaseConsumer +from coati.distributed.loss import PolicyLoss +from coati.distributed.utils import memory_efficient_logprob +from coati.trainer.utils import all_reduce_mean, all_reduce_sum +from transformers import AutoModelForCausalLM, AutoTokenizer + +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam + + +@ray.remote +class GRPOConsumer(BaseConsumer): + def __init__( + self, + shared_sync_data_actor: SharedVariableActor, + shared_signal_actor: SharedVariableActor, + num_producers, + num_episodes, + rank, + world_size, + master_addr, + master_port, + train_dataset_size, + batch_size, + model_config, + plugin_config, + minibatch_size=1, + num_generations=8, + generate_config=None, + grpo_config={}, + save_interval: int = 100, + save_dir="./model", + project_name: str = None, + run_name: str = None, + wandb_group_name: str = None, + enable_profiling: bool = False, + ): + print(f"Using GRPO config: {grpo_config}") + if ( + plugin_config.get("pp_size", 1) > 1 + and "num_microbatches" not in plugin_config + and "microbatch_size" not in plugin_config + ): + plugin_config["microbatch_size"] = max( + 1, grpo_config.get("train_microbatch_size") // plugin_config.get("pp_size", 1) + ) + super().__init__( + shared_sync_data_actor, + shared_signal_actor, + num_producers, + num_episodes, + rank, + world_size, + master_addr, + master_port, + train_dataset_size, + batch_size, + model_config, + plugin_config, + minibatch_size, + save_interval=save_interval, + save_dir=save_dir, + enable_profiling=enable_profiling, + ) + path = model_config.pop("path") + self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config) + self.policy_model.train() + self.policy_model.gradient_checkpointing_enable() + self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6)) + self.accum_loss = torch.zeros(1, device=self.device) + self.accum_kl = torch.zeros(1, device=self.device) + self.accum_advantages = torch.zeros(1, device=self.device) + self.raw_train_batch_reward = [] + self.raw_train_batch_format_acc = [] + self.raw_train_batch_ans_acc = [] + self.raw_train_batch_response_len = [] + self.accum_count = 0 + self.generate_config = generate_config + self.grpo_config = grpo_config + self.project_name = project_name + self.effective_sample_count = 0 + self.effective_prompt_count = 0 + self.project_name = project_name + self.run_name = run_name + self.wandb_group_name = wandb_group_name + + self.policy_loss_fn = PolicyLoss( + clip_eps_low=grpo_config.get("clip_eps_low", 0.2), + clip_eps_high=grpo_config.get("clip_eps_high", 0.2), + beta=grpo_config.get("beta", 0.01), + loss_variation=grpo_config.get("loss_variation", "sample_level"), + ) + + # Reference model is initialized from policy model. + if self.policy_loss_fn.beta > 0: + self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config) + self.reference_model.eval() + + self.tokenizer = AutoTokenizer.from_pretrained(path) + self.pad_token_id = self.tokenizer.pad_token_id + self.num_generations = num_generations + self.filter_range = grpo_config.get("filter_range", None) + if self.filter_range is not None: + assert len(self.filter_range) == 2, "Filter range should have 2 values." + + self.filter_truncated_response = grpo_config.get("filter_truncated_response", False) + if self.filter_truncated_response: + self.max_length = 0 + if "max_tokens" in self.generate_config: + self.max_length = self.generate_config["max_tokens"] + elif "max_new_tokens" in self.generate_config: + self.max_length = self.generate_config["max_new_tokens"] + else: + raise ValueError( + "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config." + ) + # Initialize verifiable reward. + grpo_config.get("response_format_tags", None) + self.global_step = 0 + + def setup(self): + super().setup() + if (not self.plugin.pp_size > 1 and self.rank == 0) or ( + self.plugin.pp_size > 1 + and self.booster.plugin.stage_manager.is_last_stage() + and self.tp_rank == 0 + and self.dp_rank == 0 + ): + self.wandb_run = wandb.init( + project=self.project_name, + sync_tensorboard=False, + dir="./wandb", + name=self.run_name, + group=self.wandb_group_name, + ) + + self.lr_scheduler = CosineAnnealingWarmupLR( + optimizer=self.optimizer, + total_steps=min(self.num_episodes, 4) * self.train_dataset_size // (self.batch_size * self.dp_size), + warmup_steps=0, + eta_min=0.1 * self.grpo_config.get("lr", 1e-6), + ) + + self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost( + self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler + ) + if self.policy_loss_fn.beta > 0: + self.reference_model, *_ = self.booster.boost(self.reference_model) + self.plugin.logger.set_level("ERROR") + + def step(self, pbar: Any, **kwargs) -> Optional[float]: + """ + Step data from policy model: + [{ + "input_ids": torch.Tensor, + "attention_mask": torch.Tensor, + "action_mask": torch.Tensor, + "action_log_probs": torch.Tensor, + }, + ...] + Format: + [minibatch_size, num_of_generation, prompt_length + response_length] --- ............. + """ + # Reshape to [minibatch_size x num_of_generation, prompt_length + response_length] + data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items() if "raw_train_mini_batch_" not in k} + self.raw_train_batch_reward.extend(kwargs["raw_train_mini_batch_reward"]) + self.raw_train_batch_format_acc.extend(kwargs["raw_train_mini_batch_format_acc"]) + self.raw_train_batch_ans_acc.extend(kwargs["raw_train_mini_batch_ans_acc"]) + self.raw_train_batch_response_len.extend(kwargs["raw_train_mini_batch_response_len"]) + action_mask = data["action_mask"] + num_action = action_mask.shape[1] + old_action_log_probs = data["action_log_probs"] + response_length = torch.sum(action_mask, dim=1).to(torch.float32) + train_microbatch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0)) + + reward = data["reward"].view((-1)) + format_acc = data["format_acc"].view((-1)) + ans_acc = data["ans_acc"].view((-1)) + + # [minibatch_size, num_generations] + + group_reward = reward.view(-1, self.num_generations) + reward_mean = group_reward.mean(dim=1) + # [minibatch_size x num_generations] + reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0) + + reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0) + # [minibatch_size x num_generations] + advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1) + + # [minibatch_size x num_of_generation] + loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool() + + # filter out overlength samples + if self.filter_truncated_response and action_mask.size(1) == self.max_length: + loss_mask = torch.logical_and( + loss_mask, + action_mask[:, -1] == False, + ) + if self.filter_range is not None and self.grpo_config.get("dynamic_batching", False) == False: + # filter out samples with reward outside the range + # if dynamic batching is enabled, we filter out out of range groups before training + group_ans_acc_mean = ( + ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=-1) + ) + loss_mask = torch.logical_and( + loss_mask, + torch.logical_and( + group_ans_acc_mean > self.filter_range[0], + group_ans_acc_mean < self.filter_range[1], + ), + ) + self.effective_prompt_count += ( + group_reward.size(0) * self.dp_size + ) # all prompts in the batch are effective as we filtered out the bad ones before step. + + mean_kl, mean_loss = [], [] + + need_update = self.effective_prompt_count >= self.batch_size * self.dp_size + + effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin) + effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask + total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin) + self.effective_sample_count += effective_samples.item() + pbar.set_postfix( + { + "Global Step": self.global_step, + "Gradient Accumulation on": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size} effective prompts, {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations} effective samples", + } + ) + + # Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500 + ctx = ( + nullcontext() + if need_update or self.booster.plugin.zero_stage == 2 + else self.booster.no_sync(self.policy_model, self.optimizer) + ) + with ctx: + for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size): + input_ids_forward_micro_batch = data["input_ids"][ + forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size + ] + attention_mask_forward_micro_batch = data["attention_mask"][ + forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size + ] + action_mask_forward_micro_batch = action_mask[ + forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size + ] + loss_mask_forward_micro_batch = ( + loss_mask[forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size] + if loss_mask is not None + else None + ) + advantages_forward_micro_batch = advantages[ + forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size + ] + + if self.plugin.pp_size > 1: + # Support training with PP. + if self.policy_loss_fn.beta > 0: + with torch.no_grad(): + reference_model_outputs = self.booster.execute_pipeline( + iter( + [ + { + "input_ids": input_ids_forward_micro_batch, + "attention_mask": attention_mask_forward_micro_batch, + } + ] + ), + self.reference_model, + criterion=lambda outputs, inputs: torch.tensor( + [0.0], device=action_mask.device + ), # dummy criterion + optimizer=None, + return_loss=False, + return_outputs=True, + ) + + if self.booster.plugin.stage_manager.is_last_stage(): + reference_action_log_probs = memory_efficient_logprob( + reference_model_outputs["outputs"]["logits"] / self.generate_config["temperature"], + input_ids_forward_micro_batch, + num_action, + shard_config=self.plugin.shard_config, + ) + else: + # Dummy reference logprobs for data iterator. + reference_action_log_probs = None + else: + reference_action_log_probs = None + + data_policy_forward = { + "input_ids": input_ids_forward_micro_batch, + "attention_mask": attention_mask_forward_micro_batch, + "action_mask": action_mask_forward_micro_batch, + "advantages": advantages_forward_micro_batch, + "loss_mask": loss_mask_forward_micro_batch, + "source": self.rank, + } + if reference_action_log_probs is not None: + data_policy_forward["reference_action_log_probs"] = reference_action_log_probs + + kl = [] + + def _criterion(outputs, inputs): + action_logits = outputs.logits + action_log_probs = memory_efficient_logprob( + action_logits / self.generate_config["temperature"], + inputs["input_ids"], + num_action, + shard_config=self.plugin.shard_config, + ) + if "reference_action_log_probs" in inputs: + per_token_kl = ( + torch.exp(inputs["reference_action_log_probs"] - action_log_probs) + - (inputs["reference_action_log_probs"] - action_log_probs) + - 1 + ) + appox_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum( + inputs["action_mask"], dim=-1 + ) + kl.append(appox_kl.mean()) + else: + per_token_kl = 0.0 + kl.append(torch.tensor(0.0)) + + loss, _ = self.policy_loss_fn( + action_log_probs, + action_log_probs, + inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1), + per_token_kl, + inputs["action_mask"], + loss_mask=inputs["loss_mask"], + total_effective_tokens_in_batch=total_effective_tokens_count, + ) + return loss + + policy_model_outputs = self.booster.execute_pipeline( + iter([data_policy_forward]), + self.policy_model, + criterion=_criterion, + optimizer=self.optimizer, + return_loss=True, + return_outputs=False, + ) + loss = policy_model_outputs["loss"] + + if self.booster.plugin.stage_manager.is_last_stage(): + if len(kl) > 0: + kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data + mean_kl.append(kl) + mean_loss.append(all_reduce_mean(loss, self.plugin).data) + else: + policy_model_logits = self.policy_model( + input_ids=input_ids_forward_micro_batch, + attention_mask=attention_mask_forward_micro_batch, + ).logits + action_log_probs = memory_efficient_logprob( + policy_model_logits / self.generate_config["temperature"], + input_ids_forward_micro_batch, + num_action, + shard_config=self.plugin.shard_config, + ) + + if self.policy_loss_fn.beta > 0: + with torch.no_grad(): + reference_model_logits = self.reference_model( + input_ids=input_ids_forward_micro_batch, + attention_mask=attention_mask_forward_micro_batch, + ).logits + reference_action_log_probs = memory_efficient_logprob( + reference_model_logits / self.generate_config["temperature"], + input_ids_forward_micro_batch, + num_action, + shard_config=self.plugin.shard_config, + ) + per_token_kl = ( + torch.exp(reference_action_log_probs - action_log_probs) + - (reference_action_log_probs - action_log_probs) + - 1 + ) + kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum( + action_mask_forward_micro_batch, dim=-1 + ) + else: + per_token_kl = 0.0 + kl = None + + loss, _ = self.policy_loss_fn( + action_log_probs, + old_action_log_probs, + advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1), + per_token_kl, + action_mask_forward_micro_batch, + loss_mask=loss_mask_forward_micro_batch, + total_effective_tokens_in_batch=total_effective_tokens_count, + ) + + self.booster.backward(loss, self.optimizer) + loss = all_reduce_mean(loss, self.plugin) + # Calculate accumulate value. + if kl is not None: + kl = all_reduce_mean(kl.mean(), self.plugin) + mean_kl.append(kl.data) + mean_loss.append(loss.data) + if not self.plugin.pp_size > 1 or ( + self.plugin.pp_size > 1 + and self.booster.plugin.stage_manager.is_last_stage() + and self.tp_rank == 0 + and self.dp_rank == 0 + ): + reward = all_reduce_mean(reward.mean(), self.plugin) + format_acc = all_reduce_mean(format_acc.mean(), self.plugin) + ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin) + advantages = all_reduce_mean(advantages.mean(), self.plugin) + response_length = all_reduce_mean(response_length.mean(), self.plugin) + self.accum_loss.add_(sum(mean_loss) / len(mean_loss)) + if self.policy_loss_fn.beta > 0: + self.accum_kl.add_(sum(mean_kl) / len(mean_kl)) + self.accum_advantages.add_(advantages.data) + self.accum_count += 1 + if need_update: + self.optimizer.step() + self.optimizer.zero_grad() + self.global_step += 1 + if self.lr_scheduler is not None: + self.lr_scheduler.step() + # no need to run all reduce as raw_train_batch_* are not splited across dp rank + sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations + self.effective_prompt_count = 0 + self.effective_sample_count = 0 + loss_scalar = self.accum_loss.item() + if not self.plugin.pp_size > 1 or ( + self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 + ): + if (not self.plugin.pp_size > 1 and self.rank == 0) or ( + self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 + ): + raw_batch_reward_mean = torch.cat(self.raw_train_batch_reward, dim=0).mean().cpu().item() + raw_batch_format_acc_mean = torch.cat(self.raw_train_batch_format_acc, dim=0).mean().cpu().item() + raw_batch_ans_acc_mean = torch.cat(self.raw_train_batch_ans_acc, dim=0).mean().cpu().item() + raw_batch_response_len = torch.cat(self.raw_train_batch_response_len, dim=0) + raw_batch_response_len_mean = raw_batch_response_len.mean().cpu().item() + overlength_samples_ratio = ( + (raw_batch_response_len >= action_mask.size(-1)).to(float).mean().cpu().item() + ) # not an exact figure, but a close estimate + self.raw_train_batch_reward = [] + self.raw_train_batch_format_acc = [] + self.raw_train_batch_ans_acc = [] + self.raw_train_batch_response_len = [] + to_log_msg = [ + f"Loss: {self.accum_loss.item() / self.accum_count:.4f}", + f"Reward: {raw_batch_reward_mean:.4f}", + f"format Reward: {raw_batch_format_acc_mean:.4f}", + f"Acc Reward: {raw_batch_ans_acc_mean:.4f}", + f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}", + f"Response Length: {raw_batch_response_len_mean:.4f}", + f"Sample_utilization: {sample_utilization:.4f}", + f"Overlength samples ratio: {overlength_samples_ratio:.4f}", + ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else []) + print("\n".join(to_log_msg)) + metrics = { + "metrics/reward": raw_batch_reward_mean, + "metrics/format_acc": raw_batch_format_acc_mean, + "metrics/ans_acc": raw_batch_ans_acc_mean, + "metrics/response_length": raw_batch_response_len_mean, + "train/loss": self.accum_loss.item() / self.accum_count, + "train/advantages": self.accum_advantages.item() / self.accum_count, + "train/learning_rate": self.lr_scheduler.get_last_lr()[0], + "train/sample_utilization": sample_utilization, + "train/overlength_samples_ratio": overlength_samples_ratio, + "rollout/temperature": data["temperature"].cpu().numpy()[0][0], + } + if self.policy_loss_fn.beta > 0: + metrics["train/kl"] = self.accum_kl.item() / self.accum_count + if self.wandb_run is not None: + self.wandb_run.log(metrics) + self.accum_loss.zero_() + self.accum_kl.zero_() + self.accum_advantages.zero_() + self.accum_count = 0 + return loss_scalar + else: + return None + + def state_dict(self): + self.policy_model._force_wait_all_gather() + model = self.policy_model.unwrap() + state_dict = model.state_dict() + return state_dict diff --git a/applications/ColossalChat/coati/distributed/zero_bubble/producer.py b/applications/ColossalChat/coati/distributed/zero_bubble/producer.py new file mode 100644 index 000000000000..9e57914c4e04 --- /dev/null +++ b/applications/ColossalChat/coati/distributed/zero_bubble/producer.py @@ -0,0 +1,533 @@ +import copy +import json +import os +import threading +import time +from typing import Any, Dict, Optional + +import ray +import ray.util.collective as cc +import torch +import tqdm +import wandb +from coati.dataset.loader import RawConversationDataset, collate_fn_grpo +from coati.distributed.profiling_utils import CustomProfiler +from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn +from coati.distributed.reward.verifiable_reward import VerifiableReward +from ray.util.collective import allreduce +from ray.util.collective.types import ReduceOp +from torch.utils.data import DataLoader, DistributedSampler +from transformers import AutoTokenizer + +from colossalai.utils import get_current_device + +from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict +from coati.distributed.inference_backend import BACKEND_MAP +from coati.distributed.utils import pre_send, safe_append_to_jsonl_file + +try: + from vllm import SamplingParams +except ImportError: + LLM = None + + +class BaseProducer: + def __init__( + self, + shared_sync_data_actor: SharedVariableActor, + shared_signal_actor: SharedVariableActor, + producer_idx: int, + num_producers: int, + num_consumer_procs: int, + num_episodes: int, + batch_size: int, + train_dataset_config: Dict[str, Any], + model_config: Dict[str, Any], + generate_config: Dict[str, Any], + tokenizer_config: Optional[Dict[str, Any]] = None, + microbatch_size: int = 1, + backend: str = "transformers", + consumer_plugin_config: Dict[str, Any] = None, + eval_dataset_config=None, + eval_interval=-1, # disable evaluation + grpo_config: Dict[str, Any] = None, + eval_save_dir: str = "./eval", + project_name: str = None, + run_name: str = None, + wandb_group_name: str = None, + log_rollout_interval: int = 20, + rollout_log_file: str = "./rollout_log.jsonl", + enable_profiling: bool = False, + ): + self.producer_idx = producer_idx + self.num_producers = num_producers + self.num_consumer_procs = num_consumer_procs + self.num_episodes = num_episodes + self.batch_size = batch_size + self.microbatch_size = microbatch_size + assert batch_size % microbatch_size == 0 + self.num_microbatches = batch_size // microbatch_size + self.latest_eval_step = -1 + self.profiler = CustomProfiler(f"P{self.producer_idx}", disabled=not enable_profiling) + + # for async data and model sync + self.shared_sync_data_actor = shared_sync_data_actor + self.shared_signal_actor = shared_signal_actor + self.sync_model_thread_started = False + + self.train_dataset_config = train_dataset_config + self.model_config = model_config + self.generate_config = generate_config + self.tokenizer_config = tokenizer_config + self.consumer_plugin_config = consumer_plugin_config + self.eval_interval = eval_interval + self.eval_save_dir = eval_save_dir + self.consumer_global_step = 0 + self.producer_weight_version = 0 + self.eval_mode = False + self.log_rollout_interval = log_rollout_interval + self.latest_rollout_log_step = -1 + self.grpo_config = grpo_config + reward_model_kwargs = { + k: v + for k, v in grpo_config.items() + if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"] + } + self.response_format_tags = grpo_config.get("response_format_tags", None) + if producer_idx == 0: + if os.path.exists(rollout_log_file): + raise ValueError( + f"Rollout log file {rollout_log_file} already exists. Please delete it or change the name." + ) + else: + os.makedirs(os.path.dirname(rollout_log_file), exist_ok=True) + self.rollout_log_file = open(rollout_log_file, "w", encoding="utf8") + if self.producer_idx == 0: + self.wandb_run = wandb.init( + project=project_name, + sync_tensorboard=False, + dir="./wandb", + name=run_name + "_eval", + group=wandb_group_name, + ) + + if os.path.exists(self.eval_save_dir) and self.eval_interval > 0: + raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.") + + # init tokenizer + if tokenizer_config is None: + tokenizer_path = model_config["path"] + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + else: + tokenizer_path = tokenizer_config.pop("path") + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_config) + self.tokenizer.padding_side = "left" + + # init dataloader + train_dataset_path = train_dataset_config.pop("path") + self.train_dataset = RawConversationDataset(self.tokenizer, train_dataset_path, **train_dataset_config) + self.train_dataloader = DataLoader( + self.train_dataset, + batch_size=microbatch_size, + sampler=DistributedSampler( + self.train_dataset, + num_replicas=num_producers, + rank=producer_idx, + shuffle=True, + drop_last=True, + seed=42, + ), + num_workers=4, + drop_last=True, + collate_fn=collate_fn_grpo, + ) + if grpo_config["reward_fn_type"] == "think_answer_tags": + self.evaluation_function = math_reward_fn + elif grpo_config["reward_fn_type"] == "boxed": + self.evaluation_function = boxed_math_reward_fn + elif grpo_config["reward_fn_type"] == "code": + self.evaluation_function = code_reward_fn + else: + raise ValueError(f"Unknown evaluation function type {grpo_config['reward_fn_type']}") + + self.eval_dataset_config = eval_dataset_config + if self.eval_dataset_config is not None: + self.eval_dataloaders = {} + for eval_task_name in self.eval_dataset_config: + eval_dataset_path = eval_dataset_config[eval_task_name].pop("path") + eval_dataset = RawConversationDataset( + self.tokenizer, eval_dataset_path, **eval_dataset_config[eval_task_name] + ) + print(f"[P{self.producer_idx}] eval dataset {eval_task_name} size: {len(eval_dataset)}") + self.eval_dataloaders[eval_task_name] = DataLoader( + eval_dataset, + batch_size=microbatch_size, + sampler=DistributedSampler( + eval_dataset, + num_replicas=num_producers, + rank=producer_idx, + shuffle=False, + drop_last=False, + seed=42, + ), + collate_fn=collate_fn_grpo, + ) + else: + print("No eval dataset provided, skip eval") + self.device = get_current_device() + self.reward_model = VerifiableReward( + reward_fns=[self.evaluation_function], # multiple reward functions can be added here + tokenizer=self.tokenizer, + tags=self.response_format_tags, + **reward_model_kwargs, + ) + + # init backend + if backend in BACKEND_MAP: + self.backend_cls = BACKEND_MAP[backend] + else: + raise ValueError(f"Unexpected backend {backend}") + + self.consumer_pp_size = consumer_plugin_config.get("pp_size", 1) # consumer pp size + self.state_dict_cpu = {i: None for i in range(self.consumer_pp_size)} + + def init_collective_group( + self, + world_size: int, + rank: int, + backend: str = "nccl", + group_name: str = "default", + gloo_timeout: int = 3000000, + ): + cc.init_collective_group( + world_size=world_size, rank=rank, backend=backend, group_name=group_name, gloo_timeout=gloo_timeout + ) + print(f"[P{self.producer_idx}] Initialized {group_name} collective group", flush=True) + + def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: + raise NotImplementedError + + def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None: + raise NotImplementedError + + def loop(self) -> None: + num_update_per_episode = len(self.train_dataloader) // self.num_microbatches + num_valid_microbatches = num_update_per_episode * self.num_microbatches + + print( + f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.train_dataloader)}" + ) + for episode in range(self.num_episodes): + self.train_dataloader.sampler.set_epoch(episode) + for i, batch in enumerate(self.train_dataloader): + self.profiler.log(f"train episode {episode} batch {i}") + if i >= num_valid_microbatches: + break + + self.consumer_global_step = ray.get(self.shared_signal_actor.get_signal.remote()).get("global_step", 0) + # sync model first, as the model syncing runs in a separate thread, will not block the main thread + # sync model during inference, which takes less than 10s, so that the model can be updated immediately after inference + if episode != self.num_episodes - 1 or i != num_valid_microbatches - 1: + # don't sync model for last iteration + if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get( + "enable_sleep_mode", False + ): + self.model.llm.sleep() # revict KV_cache to avoid OOM + torch.cuda.empty_cache() + + # sync model thread function + def sync_model_thread(): + if self.consumer_pp_size > 1: + self.profiler.enter("sync_model") + for pp_idx in range(self.consumer_pp_size): + ray.get( + self.shared_signal_actor.set_signal.remote( + f"producer_{self.producer_idx}_pp_{pp_idx}", "ready_sync_model" + ) + ) + print( + f"[P{self.producer_idx}] Sync model PP stage {pp_idx} episode {episode} step {(i + 1) // self.num_microbatches - 1}" + ) + self.state_dict_cpu[pp_idx] = ray_broadcast_tensor_dict( + self.state_dict_cpu[pp_idx], + 1, + device=torch.device("cpu"), + group_name=f"sync_model_producer_{self.producer_idx}_pp_{pp_idx}", + backend="gloo", # use gloo for CPU communication + pin_memory=True, + ) + self.profiler.exit("sync_model") + else: + self.profiler.enter("sync_model") + ray.get( + self.shared_signal_actor.set_signal.remote( + f"producer_{self.producer_idx}", "ready_sync_model" + ) + ) + print( + f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}" + ) + time0 = time.time() + self.state_dict_cpu[0] = ray_broadcast_tensor_dict( + self.state_dict_cpu[0], + 1, + device=torch.device("cpu"), + group_name=f"sync_model_producer_{self.producer_idx}", + backend="gloo", # use gloo for CPU communication + pin_memory=True, + ) + self.profiler.log(f"Broadcast model state dict took {time.time() - time0:.2f} seconds") + self.profiler.exit("sync_model") + self.sync_model_thread_started = False + + if not self.sync_model_thread_started and self.consumer_global_step != self.producer_weight_version: + # only sync model when the thread is not started and global step is changed + self.sync_model_thread_started = True + self.sync_model_thread = threading.Thread(target=sync_model_thread) + self.producer_weight_version = self.consumer_global_step + self.sync_model_thread.start() + torch.cuda.empty_cache() + if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get( + "enable_sleep_mode", False + ): + self.model.llm.wake_up() + + if self.eval_interval > 0 and self.eval_dataset_config is not None: + if ( + self.consumer_global_step - self.latest_eval_step >= self.eval_interval + and self.consumer_global_step > self.latest_eval_step + ) or self.latest_eval_step == -1: + to_log_msg = {} + self.eval_mode = True + for eval_task_name in self.eval_dataloaders: + if self.producer_idx == 0: + print( + f"[P{self.producer_idx}] Evaluate model at training step {self.consumer_global_step} on task {eval_task_name}" + ) + eval_results = [] + eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device) + for eval_batch in tqdm.tqdm( + self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0 + ): + eval_outputs = self.rollout(**eval_batch, sample_params=self.eval_sample_params) + eval_results = eval_results + [ + self.evaluation_function( + eval_outputs["input_ids"][m][n], + eval_outputs[ + ( + "test_cases" + if self.grpo_config["reward_fn_type"] == "code" + else "gt_answer" + ) + ][m], + eval_outputs["response_idx"][m][n], + tokenizer=self.tokenizer, + eval_mode=True, + tags=self.response_format_tags, + ) + for m in range(eval_outputs["input_ids"].size(0)) + for n in range(eval_outputs["input_ids"].size(1)) + ] + eval_statistics_tensor[0] += len([res for res in eval_results if res["ans_valid"] == 1]) + eval_statistics_tensor[1] += len(eval_results) + allreduce(eval_statistics_tensor, op=ReduceOp.SUM, group_name="producer_pg") + to_log_msg[f"eval/{eval_task_name}"] = ( + eval_statistics_tensor[0].item() / eval_statistics_tensor[1].item() + ) + if self.producer_idx == 0: + print( + f"[P{self.producer_idx}]: Accuracy on {eval_task_name}: {to_log_msg[f'eval/{eval_task_name}']}" + ) + # save eval results + safe_append_to_jsonl_file( + os.path.join( + self.eval_save_dir, + f"{eval_task_name}_training_step_{self.consumer_global_step}.jsonl", + ), + eval_results, + ) + + if self.producer_idx == 0: + self.wandb_run.log(to_log_msg, step=self.consumer_global_step) + self.eval_mode = False + self.latest_eval_step = self.consumer_global_step + self.profiler.enter("sleep") + while not (ray.get(self.shared_sync_data_actor.pickup_rollout_task.remote(self.microbatch_size))): + time.sleep(1) + self.profiler.exit("sleep") + self.profiler.enter("rollout") + self.profiler.log(f"rollout batch {i} episode {episode}") + # time.sleep(30) # simulate long inference time + outputs = self.rollout(**batch) + self.profiler.exit("rollout") + outputs["temperature"] = torch.tensor( + [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0) + ).to(outputs["input_ids"].device) + bs, num_gen = outputs["input_ids"].size(0), outputs["input_ids"].size(1) + self.profiler.enter("calculate_reward") + if self.grpo_config["reward_fn_type"] == "code": + test_cases = [] + for prompt_id in range(bs): + test_cases.extend([outputs["test_cases"][prompt_id]] * num_gen) + reward_model_output = self.reward_model( + outputs["input_ids"].view((-1, outputs["input_ids"].size(-1))), + test_cases=test_cases, + response_idx=outputs["response_idx"].view((-1, 2)), + ) + else: + gt_answer = [] + for prompt_id in range(bs): + gt_answer.extend([outputs["gt_answer"][prompt_id]] * num_gen) + reward_model_output = self.reward_model( + outputs["input_ids"].view((-1, outputs["input_ids"].size(-1))), + gt_answer=gt_answer, + response_idx=outputs["response_idx"].view((-1, 2)), + ) + outputs["reward"] = ( + torch.tensor([value[0] for value in reward_model_output]) + .to(outputs["input_ids"].device) + .view((bs, num_gen, 1)) + ) + outputs["format_acc"] = ( + torch.tensor([value[1] for value in reward_model_output]) + .to(outputs["input_ids"].device) + .view((bs, num_gen, 1)) + ) + outputs["ans_acc"] = ( + torch.tensor([value[2] for value in reward_model_output]) + .to(outputs["input_ids"].device) + .view((bs, num_gen, 1)) + ) + if "gt_answer" in outputs: + outputs.pop("gt_answer") + if "test_cases" in outputs: + outputs.pop("test_cases") + self.profiler.exit("calculate_reward") + + print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}") + outputs = pre_send(outputs) + outputs = {k: v.cpu() for k, v in outputs.items()} + self.profiler.enter("send_data") + + ray.get(self.shared_sync_data_actor.append_data.remote(outputs)) + self.profiler.exit("send_data") + + if (i + 1) % self.num_microbatches == 0 and ( + episode != self.num_episodes - 1 or i != num_valid_microbatches - 1 + ): + if not self.sync_model_thread_started: + # load state dict, note this should be done in the main thread to avoid race condition + for pp_idx in range(self.consumer_pp_size): + if self.state_dict_cpu[pp_idx] is not None and self.state_dict_cpu[pp_idx] != {}: + self.load_state_dict(self.state_dict_cpu[pp_idx]) + + # linear annealing for 1 episode, temperature from initial to 0.9 + if episode <= 0: + ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader) + self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[ + "temperature" + ] + ratio * 0.9 + if isinstance(self.model, BACKEND_MAP["vllm"]): + self.model.sample_params.temperature = (1 - ratio) * self.generate_config[ + "temperature" + ] + ratio * 0.9 + + def __del__(self): + self.profiler.close() + + +@ray.remote +class SimpleProducer(BaseProducer): + def __init__( + self, + shared_sync_data_actor: SharedVariableActor, + shared_signal_actor: SharedVariableActor, + producer_idx, + num_producers, + num_consumer_procs, + num_episodes, + batch_size, + train_dataset_config, + model_config, + generate_config, + tokenizer_config=None, + microbatch_size=1, + backend="transformers", + num_generations: int = 8, + consumer_plugin_config=None, + eval_dataset_config=None, + eval_interval=-1, # disable evaluation + grpo_config: Dict[str, Any] = None, + eval_save_dir: str = "./eval", + eval_generation_config={}, + project_name: str = None, + run_name: str = None, + wandb_group_name: str = None, + log_rollout_interval: int = 20, + rollout_log_file: str = "./rollout_log.jsonl", + enable_profiling: bool = False, + ): + super().__init__( + shared_sync_data_actor, + shared_signal_actor, + producer_idx, + num_producers, + num_consumer_procs, + num_episodes, + batch_size, + train_dataset_config, + model_config, + generate_config, + tokenizer_config, + microbatch_size, + backend, + consumer_plugin_config, + eval_dataset_config=eval_dataset_config, + eval_interval=eval_interval, + grpo_config=grpo_config, + eval_save_dir=eval_save_dir, + project_name=project_name, + run_name=run_name, + wandb_group_name=wandb_group_name, + log_rollout_interval=log_rollout_interval, + rollout_log_file=rollout_log_file, + enable_profiling=enable_profiling, + ) + self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations) + self.eval_generation_config = copy.deepcopy(self.model.generate_config) + self.eval_generation_config["n"] = 1 # use 1 generation for evaluation + self.eval_generation_config.update(eval_generation_config) + self.eval_sample_params = SamplingParams(**self.eval_generation_config) + + @torch.no_grad() + def rollout(self, input_ids, attention_mask, **kwargs): + rollouts = self.model.generate(input_ids, attention_mask, **kwargs) + if self.producer_idx == 0 and not self.eval_mode: + if ( + self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval + or self.latest_rollout_log_step == -1 + ): + new_record = ( + json.dumps( + { + "train_step": self.consumer_global_step, + "rollout": self.tokenizer.batch_decode( + rollouts["input_ids"][:, 0], skip_special_tokens=True + ), + } + ) + + "\n" + ) + self.rollout_log_file.write(new_record) + self.rollout_log_file.flush() + self.latest_rollout_log_step = self.consumer_global_step + return rollouts + + def __del__(self): + if self.producer_idx == 0: + self.wandb_run.finish() + if hasattr(self, "rollout_log_file"): + self.rollout_log_file.close() + + def load_state_dict(self, state_dict): + self.model.load_state_dict(state_dict) diff --git a/applications/ColossalChat/rl_example_zero_bubble.py b/applications/ColossalChat/rl_example_zero_bubble.py new file mode 100644 index 000000000000..e4f149653f39 --- /dev/null +++ b/applications/ColossalChat/rl_example_zero_bubble.py @@ -0,0 +1,369 @@ +import argparse +import json +import os + +import ray +import torch +from coati.distributed.launch_zero_bubble import launch_distributed + +DEFAUT_SYSTEM_PROMPT = { + "think_answer_tags": "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the tags, i.e., 123 .\n\n", + "boxed": "Please reason step by step, and put your final answer within \\boxed{}.", + "code": "You are a helpful assistant.", +} + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B") + parser.add_argument("-d", "--dataset", type=str, default="data.jsonl") + parser.add_argument( + "-ed", + "--eval-dataset", + type=str, + default=None, + help="Evaluation dataset for each task, please use json format to specify the dataset for each task. \ + For example: {'task1':'data_eval_task1.jsonl', 'task2':'data_eval_task2.jsonl'}, the jsonl file should be in the same format as the training dataset. \ + The key is the task name, and the value is the path to the jsonl file", + ) + parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.") + parser.add_argument("-e", "--num-episodes", type=int, default=1, help="Number of episodes to train.") + + # Distributed training parameters + parser.add_argument("-t", "--num-trainers", type=int, default=2) + parser.add_argument("-i", "--num-inferencer", type=int, default=2) + parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.") + parser.add_argument( + "-ibs", + "--inference-batch-size", + type=int, + default=64, + help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.", + ) + parser.add_argument( + "-imbs", + "--inference-microbatch-size", + type=int, + default=8, + help="Effective batch size for the inference backend to run generation. Please select based on memory constraint.", + ) + parser.add_argument( + "-tbs", + "--train-batch-size", + type=int, + default=32, + help="Number of unique prompts to update policy model per step per dp group. Gradient is accumulated across tbs * dp_size unique prompts, equivalently tbs * g * dp_size samples", + ) + parser.add_argument( + "-tMbs", + "--train-minibatch-size", + type=int, + default=8, + help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs", + ) + parser.add_argument( + "-tmbs", + "--train-microbatch-size", + type=int, + default=2, + help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.", + ) + parser.add_argument( + "-tp", + "--tensor-parallel-size", + type=int, + default=1, + help="Tensor parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.", + ) + parser.add_argument( + "-pp", + "--pipeline-parallel-size", + type=int, + default=1, + help="Pipeline parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.", + ) + parser.add_argument( + "-zero", + "--zero-stage", + type=int, + default=0, + help="Zero stage for the trainer (consumer). Please check the generation arguments documentation for your backend.", + ) + parser.add_argument( + "--ray_dir", type=str, default=None, help="Custom temperary directory for storing ray cluster data, Optional" + ) + parser.add_argument( + "--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional" + ) + parser.add_argument( + "--master_port", type=int, default=29506, help="Master port for multi-node distributed training, Optional" + ) + + # Sampling parameters + parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"]) + parser.add_argument("-temp", "--temperature", type=float, default=1.0, help="Temperature for sampling.") + parser.add_argument( + "-topk", + "--top-k", + type=int, + default=None, + help="Top k for sampling. Please check the generation arguments documentation for your backend.", + ) + parser.add_argument( + "-topp", + "--top-p", + type=float, + default=1.0, + help="Top p for sampling. Please check the generation arguments documentation for your backend.", + ) + parser.add_argument("-s", "--system-prompt", type=str, default=None, help="System prompt for data construction.") + parser.add_argument("-mnt", "--max-new-tokens", type=int, default=1024 * 4 - 512, help="Max length for generation.") + parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.") + parser.add_argument( + "-ptp", + "--producer-tensor-parallel-size", + type=int, + default=1, + help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.", + ) + + # GRPO parameters + parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"]) + parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.") + parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.") + parser.add_argument( + "-rt", + "--reward-type", + type=str, + default="think_answer_tags", + choices=["think_answer_tags", "boxed", "code"], + help="Reward type for GRPO.", + ) + parser.add_argument( + "-ei", + "--eval-interval", + type=int, + default=100, + help="Interval for evaluation. Evaluate every ei training steps.", + ) + parser.add_argument( + "-cbsl", + "--data_actor_buffer_size_limit", + type=int, + default=-1, + help="The approximate number of samples to keep in the consumer buffer. After this limit is reached, the producer will stop generating new samples and prioritize model sync until the consumer has processed some samples", + ) + + # Logging/Checkpointing parameters + parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.") + parser.add_argument("-sd", "--save-dir", type=str, default="./model", help="Directory for saving checkpoints.") + parser.add_argument( + "-esd", "--eval-save-dir", type=str, default="./eval", help="Directory for saving evaluation results." + ) + parser.add_argument( + "-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings." + ) + parser.add_argument( + "--enable_profiling", action="store_true", default=False, help="Enable profiling for the training process." + ) + args = parser.parse_args() + + if args.train_minibatch_size is None: + # Default settings: Using train batch size as mini batch size + args.train_minibatch_size = args.train_batch_size + if args.inference_batch_size is None: + # Default settings: Using train batch size as inference batch size, sync every inference model every train step + args.inference_batch_size = args.train_batch_size + assert ( + args.train_minibatch_size * args.num_generations >= args.train_microbatch_size + and args.train_microbatch_size > 0 + ), "Train micro batch size must be greater than 0 less than train mini batch size * num generations" + assert ( + args.train_minibatch_size <= args.train_batch_size and args.train_batch_size % args.train_minibatch_size == 0 + ), "Train mini batch size must be less than or equals to train batch size and train batch size must be divisible by train mini batch size" + + if args.master_address is None: + # Default settings: Using single machine + ray.init( + address="local", + namespace="ray-example", + runtime_env={ + "env_vars": { + # "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray + "TOKENIZERS_PARALLELISM": "false" + }, + }, + ) + else: + # For ray distributed multi-machine training, Please change _node_ip_address to your IP address of your master node + ray.init( + _node_ip_address=args.master_address, + namespace="ray-example", + _temp_dir=args.ray_dir, + runtime_env={ + "env_vars": { + # "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray + "TOKENIZERS_PARALLELISM": "false" + }, + }, + ) + + if args.top_k is None: + if args.backend == "transformers": + args.top_k = 50 + elif args.backend == "vllm": + args.top_k = -1 + + os.environ["TOKENIZERS_PARALLELISM"] = "false" # Disable tokenizers parallelism to avoid deadlock + + inference_model_config = dict(path=args.model) + train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False) + generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature) + + if args.backend == "transformers": + inference_model_config.update( + dict( + use_flash_attention_2=True, + torch_dtype=torch.bfloat16, + ) + ) + generate_config.update( + dict( + max_length=args.max_new_tokens + args.max_prompt_tokens, + do_sample=True, + max_new_tokens=None, + early_stopping=False if args.reward_type == "think_answer_tags" else True, + stop_strings=[""] if args.reward_type == "think_answer_tags" else None, + ) + ) + eval_generation_config = {"temperature": 0.6} # used to update generation config for evaluation + elif args.backend == "vllm": + inference_model_config.update( + dict( + gpu_memory_utilization=0.7, + enforce_eager=True, + enable_chunked_prefill=True, + max_model_len=args.max_new_tokens + args.max_prompt_tokens, + tensor_parallel_size=args.producer_tensor_parallel_size, + ) + ) + generate_config.update( + dict( + max_tokens=args.max_new_tokens, # max new tokens + ignore_eos=True if args.reward_type == "think_answer_tags" else False, + include_stop_str_in_output=True, + stop=[""] if args.reward_type == "think_answer_tags" else None, + ) + ) + eval_generation_config = {"temperature": 0.6} # used to update generation config for evaluation + else: + raise ValueError(f"Unsupported backend: {args.backend}") + + if args.algo == "GRPO": + # Default Settings + grpo_config = { + "lr": args.learning_rate, + "train_microbatch_size": args.train_microbatch_size, + "num_minibatch_during_rollout": 1, # number of mini batches to pop out from buffer and used for training during rollout of the producer after it syncs the model. Hint, set to a proper value close to the number of mini batches for training that takes roughly the same time as the rollout of the producer. A value that is too large or too small will cause bubble time on the trainer or the producer. + "beta": args.kl_coeff, # KL penalty coefficient + "loss_variation": "sample_level", + "reward_fn_type": args.reward_type, + "max_length": args.max_new_tokens + args.max_prompt_tokens, + "max_new_tokens": args.max_new_tokens, + "response_format_tags": ( + { + "think_start": {"text": "", "num_occur": 1}, + "think_end": {"text": "", "num_occur": 1}, + "answer_start": {"text": "", "num_occur": 1}, + "answer_end": {"text": "", "num_occur": 1}, + } + if args.reward_type == "think_answer_tags" + else None + ), + } + elif args.algo == "DAPO": + # DAPO variant settings + grpo_config = { + "filter_range": [0.01, 0.99], # only filter out all zero batch and all one batch + "lr": args.learning_rate, + "train_microbatch_size": args.train_microbatch_size, + "dynamic_batching": True, + "clip_eps_low": 0.2, + "clip_eps_high": 0.28, + "skip_threshold": 20.0, + "beta": 0, # no KL penalty for DAPO + "loss_variation": "token_level", + "soft_over_length_punishment": True, + "max_length": args.max_new_tokens + args.max_prompt_tokens, + "max_new_tokens": args.max_new_tokens, + "cache_length": min(1024, int(args.max_new_tokens / 4)), + "filter_truncated_response": True, + "reward_fn_type": args.reward_type, + "response_format_tags": ( + { + "think_start": {"text": "", "num_occur": 1}, + "think_end": {"text": "", "num_occur": 1}, + "answer_start": {"text": "", "num_occur": 1}, + "answer_end": {"text": "", "num_occur": 1}, + } + if args.reward_type == "think_answer_tags" + else None + ), + } + else: + raise ValueError(f"Unsupported algorithm: {args.algo}") + + if args.system_prompt is None: + # Default system prompt + args.system_prompt = DEFAUT_SYSTEM_PROMPT[args.reward_type] + + launch_distributed( + num_producers=args.num_inferencer, + num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.producer_tensor_parallel_size), + num_consumer_procs=args.num_trainers, + num_episodes=args.num_episodes, + inference_batch_size=args.inference_batch_size, + inference_microbatch_size=args.inference_microbatch_size, + train_batch_size=args.train_batch_size, + train_minibatch_size=args.train_minibatch_size, + train_dataset_config={ + "path": args.dataset, + "max_length": args.max_prompt_tokens, + "system_prompt": args.system_prompt, + }, + inference_model_config=inference_model_config, + generate_config=generate_config, + num_generations=args.num_generations, + train_model_config=train_model_config, + grpo_config=grpo_config, + plugin_config={ + "tp_size": args.tensor_parallel_size, + "pp_size": args.pipeline_parallel_size, + "microbatch_size": max( + 1, args.train_microbatch_size // args.pipeline_parallel_size + ), # microbatch size should be set to train_microbatch_size // pp_size + "zero_stage": args.zero_stage, + "max_norm": 1.0, + }, # for pp, tp + inference_backend=args.backend, + master_addr="localhost", + master_port=args.master_port, + core_algo=args.algo, + project_name=args.project, + save_interval=args.save_interval, + save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")), + eval_dataset_config=( + { + k: {"path": v, "max_length": args.max_prompt_tokens, "system_prompt": args.system_prompt} + for k, v in json.loads(args.eval_dataset).items() + } + if args.eval_dataset + else None + ), + eval_interval=args.eval_interval, + eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")), + eval_generation_config=eval_generation_config, + log_rollout_interval=20, + rollout_save_dir=args.rollout_save_dir, + enable_profiling=args.enable_profiling, + data_actor_buffer_size_limit=args.data_actor_buffer_size_limit, + ) From c5e97f4e25b68371cd4c3e52693eb517440b712a Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Mon, 14 Jul 2025 16:25:03 +0800 Subject: [PATCH 2/5] fix code evaluation --- .../coati/distributed/producer.py | 4 +- .../reward/code_reward/testing_util.py | 19 ++++++--- .../distributed/reward/code_reward/utils.py | 40 ++++++++++++------- .../coati/distributed/reward/reward_fn.py | 17 ++++++-- applications/ColossalChat/rl_example.py | 19 +++++++-- .../ColossalChat/start_code_verifier.py | 35 ++++++++++++++++ 6 files changed, 104 insertions(+), 30 deletions(-) create mode 100644 applications/ColossalChat/start_code_verifier.py diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 2a37463913ee..11fb5d3aa3a9 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -83,7 +83,7 @@ def __init__( reward_model_kwargs = { k: v for k, v in grpo_config.items() - if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"] + if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length", "code_verifier_api_url"] } self.response_format_tags = grpo_config.get("response_format_tags", None) if producer_idx == 0: @@ -250,7 +250,7 @@ def loop(self) -> None: for m in range(eval_outputs["input_ids"].size(0)) for n in range(eval_outputs["input_ids"].size(1)) ] - eval_statistics_tensor[0] += len([res for res in eval_results if res["ans_valid"] == 1]) + eval_statistics_tensor[0] += sum([max(0, res["ans_valid"]) for res in eval_results]) eval_statistics_tensor[1] += len(eval_results) allreduce(eval_statistics_tensor, op=ReduceOp.SUM, group_name="producer_group") to_log_msg[f"eval/{eval_task_name}"] = ( diff --git a/applications/ColossalChat/coati/distributed/reward/code_reward/testing_util.py b/applications/ColossalChat/coati/distributed/reward/code_reward/testing_util.py index c6d6d8fad06e..12973337e2a6 100644 --- a/applications/ColossalChat/coati/distributed/reward/code_reward/testing_util.py +++ b/applications/ColossalChat/coati/distributed/reward/code_reward/testing_util.py @@ -89,7 +89,7 @@ def clean_traceback(error_traceback): return error_traceback -def run_test(in_outs, test=None, debug=False, timeout=15): +def run_test(in_outs, test=None, debug=False, timeout=15, run_all_tests=False): """ if test(generated_code) is not None it'll try to run the code. otherwise it'll just return an input and output pair. @@ -180,8 +180,8 @@ def run_test(in_outs, test=None, debug=False, timeout=15): tmp_test = new_test sol += tmp_test - if debug: - print(f"sol = {sol}") + # if debug: + # print(f"sol = {sol}") method_name = "code" signal.alarm(timeout) try: @@ -202,8 +202,7 @@ def run_test(in_outs, test=None, debug=False, timeout=15): } signal.alarm(0) if debug: - print(f"get method = {datetime.now().time()}") - + print(f"get method {method_name} = {datetime.now().time()}") try: method = getattr(tmp, method_name) # get_attr second arg must be str except Exception: @@ -329,6 +328,9 @@ def run_test(in_outs, test=None, debug=False, timeout=15): error_traceback = traceback.format_exc() print(f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}") results.append(-1) + signal.alarm(0) + if run_all_tests: + continue return results, { "error": repr(e), "traceback": clean_traceback(error_traceback), @@ -519,6 +521,10 @@ def run_test(in_outs, test=None, debug=False, timeout=15): results.append(tmp_result) if tmp_result is not True: + if debug: + print("final result:", results) + if run_all_tests: + continue return results, { "output": raw_true_output_copy, "expected": raw_outputs, @@ -539,7 +545,8 @@ def run_test(in_outs, test=None, debug=False, timeout=15): ) print(f"results = {results}") - + if debug: + print("final results", results) return results, {} diff --git a/applications/ColossalChat/coati/distributed/reward/code_reward/utils.py b/applications/ColossalChat/coati/distributed/reward/code_reward/utils.py index e19e9e387b71..dbb3b2149780 100644 --- a/applications/ColossalChat/coati/distributed/reward/code_reward/utils.py +++ b/applications/ColossalChat/coati/distributed/reward/code_reward/utils.py @@ -16,27 +16,24 @@ # limitations under the License. import multiprocessing -import os -import sys import traceback from typing import Optional +import requests + from .testing_util import run_test def _temp_run(sample, generation, debug, result, metadata_list, timeout): - with open(os.devnull, "w") as devnull: - sys.stdout = devnull - sys.stderr = devnull - try: - res, metadata = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout) - result.append(res) - metadata_list.append(metadata) - except Exception: - # print(e) # some tracebacks are extremely long. - traceback.print_exc(10) - result.append([-1 for i in range(len(sample["inputs"]))]) - metadata_list.append({}) + try: + res, metadata = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout) + result.append(res) + metadata_list.append(metadata) + except Exception: + # print(e) # some tracebacks are extremely long. + traceback.print_exc(10) + result.append([-1 for i in range(len(sample["inputs"]))]) + metadata_list.append({}) def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=True): @@ -49,7 +46,7 @@ def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=Tru metadata_list = manager.list() p = multiprocessing.Process(target=_temp_run, args=(in_outs, generation, debug, result, metadata_list, timeout)) p.start() - p.join(timeout=timeout + 1) + p.join(timeout=600) # Global timeout of 10 minutes that's for all test cases combined if p.is_alive(): p.kill() # p.terminate() @@ -59,3 +56,16 @@ def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=Tru if debug: print("global timeout") return result[0], metadata_list + + +def check_correctness_code_api( + in_outs: Optional[dict], generation, timeout=10, debug=True, url="http://localhost:8000/check_correctness" +): + payload = {"in_outs": in_outs, "generation": generation, "timeout": timeout, "debug": debug} + response = requests.post(url, json=payload) + if response.status_code == 200: + results = response.json() + return results["result"], results["metadata"] + else: + print(f"Error: {response.status_code} - {response.text}") + return [-1 for i in range(len(in_outs["inputs"]))], {} diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py index 5c8810fdf4cc..f7a2fb89cadb 100644 --- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py +++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py @@ -24,7 +24,7 @@ from latex2sympy2_extended import NormalizationConfig from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify -from .code_reward.utils import check_correctness as check_correctness_code +from .code_reward.utils import check_correctness_code_api as check_correctness_code from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure CANNOT_PARSE_GT_ANSWER = -1 @@ -223,6 +223,7 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): def code_reward_fn(input_ids, test_cases, response_idx, **kwargs): + url = kwargs.get("url", "http://localhost:8000/check_correctness") tokenizer = kwargs["tokenizer"] eval_mode = kwargs.get("eval_mode", False) soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False) @@ -255,6 +256,9 @@ def code_reward_fn(input_ids, test_cases, response_idx, **kwargs): if format_valid: format_acc += 1 + res = [] + metadata = [] + try: try: if not isinstance(test_cases, dict): @@ -264,15 +268,18 @@ def code_reward_fn(input_ids, test_cases, response_idx, **kwargs): raise e # Complete check on all in-out pairs first. If there is no failure, per-sample test can be skipped. try: - res, metadata = check_correctness_code(in_outs=test_cases, generation=solution, timeout=10, debug=True) + res, metadata = check_correctness_code( + in_outs=test_cases, generation=solution, timeout=10, debug=False, url=url + ) metadata = dict(enumerate(metadata))[0] - success = all(map(lambda x: x is True, res)) + success = all(map(lambda x: x == 1, res)) if success: ans_acc += 1 if eval_mode or format_valid: reward += acc_score if not eval_mode: reward = reward + length_reward + except Exception: pass @@ -288,7 +295,9 @@ def code_reward_fn(input_ids, test_cases, response_idx, **kwargs): return { "prompt": prompt, "prediction": decoded_final_answer, - "gold": test_cases["outputs"], + "test_cases": test_cases, + "test_results": res, + "test_metadata": metadata, "parsed": solution, "format_valid": format_acc.item(), "ans_valid": ans_acc.item(), diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index d7b7c2a5d0fc..08814f9f1e61 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -12,6 +12,9 @@ "code": "You are a helpful assistant.", } +# bypass the proxy for local addresses +os.environ["no_proxy"] = "127.0.0.1,localhost" + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B") @@ -138,6 +141,13 @@ choices=["think_answer_tags", "boxed", "code"], help="Reward type for GRPO.", ) + parser.add_argument( + "-cv", + "--code-verifier-api-url", + type=str, + default=None, + help="API URL for code verifier. If not provided, the code verifier will be disabled.", + ) parser.add_argument( "-ei", "--eval-interval", @@ -165,6 +175,7 @@ parser.add_argument( "--enable_profiling", action="store_true", default=False, help="Enable profiling for the training process." ) + args = parser.parse_args() if args.train_minibatch_size is None: @@ -188,7 +199,7 @@ namespace="ray-example", runtime_env={ "env_vars": { - # "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray + # "RAY_DEBUG_POST_MORTEM": "1", # enable post-mortem debugging with ray "TOKENIZERS_PARALLELISM": "false" }, }, @@ -201,7 +212,7 @@ _temp_dir=args.ray_dir, runtime_env={ "env_vars": { - # "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray + # "RAY_DEBUG_POST_MORTEM": "1", # enable post-mortem debugging with ray "TOKENIZERS_PARALLELISM": "false" }, }, @@ -321,7 +332,9 @@ } else: raise ValueError(f"Unsupported algorithm: {args.algo}") - + if args.reward_type == "code": + assert args.code_verifier_api_url is not None, "Please provide a code verifier API URL for code reward type." + grpo_config.update({"code_verifier_api_url": args.code_verifier_api_url}) if args.system_prompt is None: # Default system prompt args.system_prompt = DEFAUT_SYSTEM_PROMPT[args.reward_type] diff --git a/applications/ColossalChat/start_code_verifier.py b/applications/ColossalChat/start_code_verifier.py new file mode 100644 index 000000000000..d1924f610698 --- /dev/null +++ b/applications/ColossalChat/start_code_verifier.py @@ -0,0 +1,35 @@ +from typing import List, Optional + +from coati.distributed.reward.code_reward.utils import check_correctness # Assuming utils.py is in the same directory +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel + +app = FastAPI() + + +class CheckCorrectnessRequest(BaseModel): + in_outs: Optional[dict] + generation: str + timeout: int = 10 + debug: bool = True + eval_mode: bool = False + + +class CheckCorrectnessResponse(BaseModel): + result: List[int] + metadata: List[dict] + + +@app.post("/check_correctness", response_model=CheckCorrectnessResponse) +def check_correctness_api(request: CheckCorrectnessRequest): + try: + result, metadata = check_correctness( + in_outs=request.in_outs, + generation=request.generation, + timeout=request.timeout, + debug=request.debug, + eval_mode=request.eval_mode, + ) + return CheckCorrectnessResponse(result=result, metadata=metadata) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) From f54ae56f12282a129a7ae16466af3dcd9d9fee4f Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Wed, 16 Jul 2025 16:44:23 +0800 Subject: [PATCH 3/5] add entropy --- .../coati/distributed/grpo_consumer.py | 39 ++++++++++++++++++- .../ColossalChat/coati/distributed/utils.py | 10 +++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index f8ce1afde4b4..754f780979a9 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -6,7 +6,7 @@ import wandb from coati.distributed.consumer import BaseConsumer from coati.distributed.loss import PolicyLoss -from coati.distributed.utils import memory_efficient_logprob +from coati.distributed.utils import entropy_from_logits, memory_efficient_logprob from coati.trainer.utils import all_reduce_mean, all_reduce_sum from transformers import AutoModelForCausalLM, AutoTokenizer @@ -75,6 +75,7 @@ def __init__( self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6)) self.accum_loss = torch.zeros(1, device=self.device) self.accum_kl = torch.zeros(1, device=self.device) + self.accum_entropy = torch.zeros(1, device=self.device) self.accum_advantages = torch.zeros(1, device=self.device) self.raw_train_batch_reward = [] self.raw_train_batch_format_acc = [] @@ -244,6 +245,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]: else self.booster.no_sync(self.policy_model, self.optimizer) ) with ctx: + mini_batch_entropies = [] for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size): input_ids_forward_micro_batch = data["input_ids"][ forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size @@ -310,9 +312,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]: data_policy_forward["reference_action_log_probs"] = reference_action_log_probs kl = [] + policy_model_logits = torch.empty_like(input_ids_forward_micro_batch, device=self.device) def _criterion(outputs, inputs): action_logits = outputs.logits + policy_model_logits.copy_(action_logits) action_log_probs = memory_efficient_logprob( action_logits / self.generate_config["temperature"], inputs["input_ids"], @@ -359,6 +363,20 @@ def _criterion(outputs, inputs): kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data mean_kl.append(kl) mean_loss.append(all_reduce_mean(loss, self.plugin).data) + mini_batch_entropies.append( + all_reduce_mean( + ( + ( + ( + entropy_from_logits(policy_model_logits[:, -num_action:]) + * action_mask_forward_micro_batch + ).sum(-1) + ) + / action_mask_forward_micro_batch.sum(-1) + ).detach(), + self.plugin, + ) + ) else: policy_model_logits = self.policy_model( input_ids=input_ids_forward_micro_batch, @@ -412,6 +430,20 @@ def _criterion(outputs, inputs): kl = all_reduce_mean(kl.mean(), self.plugin) mean_kl.append(kl.data) mean_loss.append(loss.data) + mini_batch_entropies.append( + all_reduce_mean( + ( + ( + ( + entropy_from_logits(policy_model_logits[:, -num_action:]) + * action_mask_forward_micro_batch + ).sum(-1) + ) + / action_mask_forward_micro_batch.sum(-1) + ).detach(), + self.plugin, + ) + ) if not self.plugin.pp_size > 1 or ( self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() @@ -423,7 +455,9 @@ def _criterion(outputs, inputs): ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin) advantages = all_reduce_mean(advantages.mean(), self.plugin) response_length = all_reduce_mean(response_length.mean(), self.plugin) + entropy = torch.cat(mini_batch_entropies, dim=0).mean() self.accum_loss.add_(sum(mean_loss) / len(mean_loss)) + self.accum_entropy.add_(entropy.data) if self.policy_loss_fn.beta > 0: self.accum_kl.add_(sum(mean_kl) / len(mean_kl)) self.accum_advantages.add_(advantages.data) @@ -464,6 +498,7 @@ def _criterion(outputs, inputs): f"Response Length: {raw_batch_response_len_mean:.4f}", f"Sample_utilization: {sample_utilization:.4f}", f"Overlength samples ratio: {overlength_samples_ratio:.4f}", + f"Entropy: {self.accum_entropy.item() / self.accum_count:.4f}", ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else []) print("\n".join(to_log_msg)) metrics = { @@ -475,6 +510,7 @@ def _criterion(outputs, inputs): "train/advantages": self.accum_advantages.item() / self.accum_count, "train/learning_rate": self.lr_scheduler.get_last_lr()[0], "train/sample_utilization": sample_utilization, + "train/entropy": self.accum_entropy.item() / self.accum_count, "train/overlength_samples_ratio": overlength_samples_ratio, "rollout/temperature": data["temperature"].cpu().numpy()[0][0], } @@ -484,6 +520,7 @@ def _criterion(outputs, inputs): self.wandb_run.log(metrics) self.accum_loss.zero_() self.accum_kl.zero_() + self.accum_entropy.zero_() self.accum_advantages.zero_() self.accum_count = 0 return loss_scalar diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py index d46243114eea..466914cc0d4d 100644 --- a/applications/ColossalChat/coati/distributed/utils.py +++ b/applications/ColossalChat/coati/distributed/utils.py @@ -110,6 +110,16 @@ def memory_efficient_logprob( return action_log_probs +def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor: + """ + Calculate entropy + Reference: https://github.com/volcengine/verl/blob/96b730bbed80292a439f0c0057d3920ab8b28d52/verl/utils/torch_functional.py#L145 + """ + p = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(p * logits, dim=-1) + return entropy + + def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor: """ Compute the masked mean of a tensor along a specified dimension. From e774edeb8073bd6be0070ecfff0a317e572dea3e Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Mon, 21 Jul 2025 17:21:07 +0800 Subject: [PATCH 4/5] fix racing condition --- .../ColossalChat/coati/distributed/comm.py | 6 ++- .../coati/distributed/grpo_consumer.py | 18 +++++-- .../coati/distributed/inference_backend.py | 6 ++- .../coati/distributed/launch_zero_bubble.py | 5 +- .../ColossalChat/coati/distributed/loss.py | 4 +- .../coati/distributed/zero_bubble/consumer.py | 8 +-- .../distributed/zero_bubble/distributor.py | 21 ++++++-- .../distributed/zero_bubble/grpo_consumer.py | 49 ++++++++++++++++--- .../coati/distributed/zero_bubble/producer.py | 22 ++++++--- .../ColossalChat/rl_example_zero_bubble.py | 11 ++++- 10 files changed, 113 insertions(+), 37 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/comm.py b/applications/ColossalChat/coati/distributed/comm.py index 0a724d53bc18..21e6c7d90c79 100644 --- a/applications/ColossalChat/coati/distributed/comm.py +++ b/applications/ColossalChat/coati/distributed/comm.py @@ -1,5 +1,6 @@ -from typing import Any, Dict import copy +from typing import Any, Dict + import ray import ray.util.collective as cc import torch @@ -31,6 +32,7 @@ def ray_broadcast_object(obj: Any, src: int = 0, device=None, group_name: str = obj = c10d._tensor_to_object(obj, size_tensor.item()) return obj + def ray_broadcast_tensor_dict( tensor_dict: Dict[str, torch.Tensor], src: int = 0, @@ -98,7 +100,7 @@ def pickup_rollout_task(self, num_tasks: int): queue length as data may still be generating """ ret = False - if self.queue_size < self.buffer_size_limit: + if self.queue_size < (self.buffer_size_limit / max(0.1, self.signals.get("sample_utilization", 1.0))): ret = True self.queue_size += num_tasks return ret diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 754f780979a9..7b4843b08965 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -250,6 +250,9 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]: input_ids_forward_micro_batch = data["input_ids"][ forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size ] + old_action_log_probs_micro_batch = old_action_log_probs[ + forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size + ] attention_mask_forward_micro_batch = data["attention_mask"][ forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size ] @@ -306,17 +309,22 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]: "action_mask": action_mask_forward_micro_batch, "advantages": advantages_forward_micro_batch, "loss_mask": loss_mask_forward_micro_batch, + "old_action_log_probs": old_action_log_probs_micro_batch, "source": self.rank, } if reference_action_log_probs is not None: data_policy_forward["reference_action_log_probs"] = reference_action_log_probs kl = [] - policy_model_logits = torch.empty_like(input_ids_forward_micro_batch, device=self.device) def _criterion(outputs, inputs): action_logits = outputs.logits - policy_model_logits.copy_(action_logits) + mini_batch_entropies.append( + ( + ((entropy_from_logits(action_logits[:, -num_action:]) * inputs["action_mask"]).sum(-1)) + / inputs["action_mask"].sum(-1) + ).detach() + ) action_log_probs = memory_efficient_logprob( action_logits / self.generate_config["temperature"], inputs["input_ids"], @@ -339,7 +347,7 @@ def _criterion(outputs, inputs): loss, _ = self.policy_loss_fn( action_log_probs, - action_log_probs, + inputs["old_action_log_probs"], inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1), per_token_kl, inputs["action_mask"], @@ -415,7 +423,7 @@ def _criterion(outputs, inputs): loss, _ = self.policy_loss_fn( action_log_probs, - old_action_log_probs, + old_action_log_probs_micro_batch, advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1), per_token_kl, action_mask_forward_micro_batch, @@ -455,7 +463,7 @@ def _criterion(outputs, inputs): ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin) advantages = all_reduce_mean(advantages.mean(), self.plugin) response_length = all_reduce_mean(response_length.mean(), self.plugin) - entropy = torch.cat(mini_batch_entropies, dim=0).mean() + entropy = all_reduce_mean(torch.cat(mini_batch_entropies, dim=0).mean(), self.plugin) self.accum_loss.add_(sum(mean_loss) / len(mean_loss)) self.accum_entropy.add_(entropy.data) if self.policy_loss_fn.beta > 0: diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py index 34827e4e2cf9..331f8d7b6a01 100644 --- a/applications/ColossalChat/coati/distributed/inference_backend.py +++ b/applications/ColossalChat/coati/distributed/inference_backend.py @@ -59,6 +59,7 @@ def __init__( generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer, num_generations: int = 8, + tokenizer_config: Dict[str, Any] = None, ): model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG) model_config.update(self.FORCE_MODEL_CONFIG) @@ -132,6 +133,7 @@ def __init__( generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer, num_generations: int = 8, + tokenizer_config: Dict[str, Any] = None, ): if sgl is None: raise ImportError("sglang is not installed") @@ -196,12 +198,14 @@ def __init__( generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer, num_generations: int = 8, + tokenizer_config: Dict[str, Any] = None, ): if LLM is None: raise ImportError("vllm is not installed") model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG) path = model_config.pop("path") - self.llm = LLM(model=path, **model_config) + tokenizer_path = tokenizer_config.get("path", None) if tokenizer_config is not None else None + self.llm = LLM(model=path, tokenizer=tokenizer_path, **model_config) generate_config = generate_config.copy() generate_config.update(self.FORCE_GENERATE_CONFIG) generate_config.update({"n": num_generations}) diff --git a/applications/ColossalChat/coati/distributed/launch_zero_bubble.py b/applications/ColossalChat/coati/distributed/launch_zero_bubble.py index 635493b7d957..de5b6135360b 100644 --- a/applications/ColossalChat/coati/distributed/launch_zero_bubble.py +++ b/applications/ColossalChat/coati/distributed/launch_zero_bubble.py @@ -130,7 +130,7 @@ def launch_distributed( train_dataset_config=train_dataset_config, model_config=inference_model_config, generate_config=generate_config, - tokenizer_config=tokenizer_config, + tokenizer_config=copy.deepcopy(tokenizer_config), microbatch_size=inference_microbatch_size, backend=inference_backend, num_generations=num_generations, @@ -158,8 +158,6 @@ def launch_distributed( consumer_master_ip_address = gpu_to_ip_address[0] print(f"Use {consumer_master_ip_address} as master address for torch DDP.") consumer_procs = [] - if num_consumer_procs <= 1: - raise ValueError("Number of consumer processes should be greater than 1 for async rl training.") for i in range(num_consumer_procs): node_id = gpu_to_node_id[0] consumer_ip_address = gpu_to_ip_address[0] @@ -180,6 +178,7 @@ def launch_distributed( model_config=train_model_config, plugin_config=plugin_config, minibatch_size=train_minibatch_size, + tokenizer_config=copy.deepcopy(tokenizer_config), generate_config=generate_config_consumer, grpo_config=grpo_config, num_generations=num_generations, diff --git a/applications/ColossalChat/coati/distributed/loss.py b/applications/ColossalChat/coati/distributed/loss.py index 36057b24faf5..ea4d0dd11c7e 100644 --- a/applications/ColossalChat/coati/distributed/loss.py +++ b/applications/ColossalChat/coati/distributed/loss.py @@ -35,9 +35,9 @@ def forward( total_effective_tokens_in_batch: torch.Tensor = None, ) -> torch.Tensor: if action_mask is None: - ratio = (log_probs - log_probs.detach()).exp() + ratio = (log_probs - old_log_probs.detach()).exp() else: - ratio = ((log_probs - log_probs.detach()) * action_mask).exp() + ratio = ((log_probs - old_log_probs.detach()) * action_mask).exp() surr1 = ratio * advantages surr2 = ratio.clamp(1 - self.clip_eps_low, 1 + self.clip_eps_high) * advantages diff --git a/applications/ColossalChat/coati/distributed/zero_bubble/consumer.py b/applications/ColossalChat/coati/distributed/zero_bubble/consumer.py index 82242e874308..2b4790884eff 100644 --- a/applications/ColossalChat/coati/distributed/zero_bubble/consumer.py +++ b/applications/ColossalChat/coati/distributed/zero_bubble/consumer.py @@ -7,7 +7,9 @@ import ray.util.collective as cc import torch import torch.distributed as dist +from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict from coati.distributed.profiling_utils import CustomProfiler +from coati.distributed.utils import bind_batch, post_recv, unbind_batch from tqdm import tqdm from colossalai.booster import Booster @@ -15,9 +17,6 @@ from colossalai.initialize import launch from colossalai.utils import get_current_device -from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict -from coati.distributed.utils import bind_batch, post_recv, unbind_batch - class BaseConsumer: def __init__( @@ -175,14 +174,15 @@ def loop(self) -> None: raw_batch = ray.get( self.shared_sync_data_actor.get_data.remote(self.data_uid) ) # get the first queued data + self.profiler.log(f"enter sleep") while raw_batch is None: - self.profiler.log(f"No data received by consumer {self.rank}, skipping") print( f"[T{dist.get_rank()}] No data received by consumer {self.rank}, skipping. Consider increasing the data actor buffer limit" ) time.sleep(1) raw_batch = ray.get(self.shared_sync_data_actor.get_data.remote(self.data_uid)) continue + self.profiler.log(f"exit sleep") self.data_uid += 1 raw_batch = {k: v.to(self.device) for k, v in raw_batch.items()} # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete), diff --git a/applications/ColossalChat/coati/distributed/zero_bubble/distributor.py b/applications/ColossalChat/coati/distributed/zero_bubble/distributor.py index b16f4b67ef39..ea04ae13cdce 100644 --- a/applications/ColossalChat/coati/distributed/zero_bubble/distributor.py +++ b/applications/ColossalChat/coati/distributed/zero_bubble/distributor.py @@ -3,12 +3,11 @@ import ray import ray.util.collective as cc import torch +from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict from coati.distributed.profiling_utils import CustomProfiler from colossalai.utils import get_current_device -from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict - @ray.remote class Distributor: @@ -21,6 +20,7 @@ def __init__( enable_profiling: bool = True, ): self.distributor_id = distributor_id + self.weight_version = [0] * consumer_pp_size self.consumer_pp_size = consumer_pp_size self.state_dict_cpu = {} self.num_producers = num_producers @@ -42,14 +42,17 @@ def init_collective_group( print(f"[D] Initialized {group_name} collective group", flush=True) def loop(self): + last_weight_version = self.get_weight_version() while True: time.sleep(1) signal = ray.get(self.shared_signal_actor.get_signal.remote()) if self.consumer_pp_size > 1: - for i in range(self.consumer_pp_size): - if signal.get(f"consumer_pp_{i}", None) == "ready_sync_model": + if all( + [signal.get(f"consumer_pp_{i}", None) == "ready_sync_model" for i in range(self.consumer_pp_size)] + ): + cc.barrier(group_name="distributor_pg") + for i in range(self.consumer_pp_size): self.profiler.enter(f"sync_model_consumer_pp_{i}") - cc.barrier(group_name="distributor_pg") ray.get(self.shared_signal_actor.set_signal.remote(f"consumer_pp_{i}", "not_ready_sync_model")) # Broadcast the model state dict from consumer to shared variable actor self.state_dict_cpu[i] = ray_broadcast_tensor_dict( @@ -60,6 +63,7 @@ def loop(self): backend="gloo", ) self.profiler.exit(f"sync_model_consumer_pp_{i}") + self.weight_version[i] += 1 for i in range(self.consumer_pp_size): if signal.get(f"producer_{self.distributor_id}_pp_{i}", None) == "ready_sync_model": self.profiler.enter(f"sync_model_producer_{self.distributor_id}_pp_{i}") @@ -87,6 +91,7 @@ def loop(self): None, 0, device=torch.device("cpu"), group_name="sync_model_consumer", backend="gloo" ) self.profiler.exit("sync_model_consumer") + self.weight_version[0] += 1 if signal.get(f"producer_{self.distributor_id}", None) == "ready_sync_model": self.profiler.enter(f"sync_model_producer_{self.distributor_id}") # Broadcast the model state dict to all producers @@ -106,3 +111,9 @@ def loop(self): if signal.get("consumer", None) == "terminate": self.profiler.log("terminate sync model worker") break + if last_weight_version != self.get_weight_version(): + last_weight_version = self.get_weight_version() + ray.get(self.shared_signal_actor.set_signal.remote("distributor_weight_version", last_weight_version)) + + def get_weight_version(self): + return min(self.weight_version) diff --git a/applications/ColossalChat/coati/distributed/zero_bubble/grpo_consumer.py b/applications/ColossalChat/coati/distributed/zero_bubble/grpo_consumer.py index c07385b97c21..f047852715d1 100644 --- a/applications/ColossalChat/coati/distributed/zero_bubble/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/zero_bubble/grpo_consumer.py @@ -5,9 +5,9 @@ import torch import wandb from coati.distributed.comm import SharedVariableActor -from coati.distributed.zero_bubble.consumer import BaseConsumer from coati.distributed.loss import PolicyLoss -from coati.distributed.utils import memory_efficient_logprob +from coati.distributed.utils import entropy_from_logits, memory_efficient_logprob +from coati.distributed.zero_bubble.consumer import BaseConsumer from coati.trainer.utils import all_reduce_mean, all_reduce_sum from transformers import AutoModelForCausalLM, AutoTokenizer @@ -33,6 +33,7 @@ def __init__( plugin_config, minibatch_size=1, num_generations=8, + tokenizer_config=None, generate_config=None, grpo_config={}, save_interval: int = 100, @@ -73,9 +74,11 @@ def __init__( self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config) self.policy_model.train() self.policy_model.gradient_checkpointing_enable() + self.vocab_size = self.policy_model.config.vocab_size self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6)) self.accum_loss = torch.zeros(1, device=self.device) self.accum_kl = torch.zeros(1, device=self.device) + self.accum_entropy = torch.zeros(1, device=self.device) self.accum_advantages = torch.zeros(1, device=self.device) self.raw_train_batch_reward = [] self.raw_train_batch_format_acc = [] @@ -102,8 +105,11 @@ def __init__( if self.policy_loss_fn.beta > 0: self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config) self.reference_model.eval() - - self.tokenizer = AutoTokenizer.from_pretrained(path) + if tokenizer_config is not None: + path = tokenizer_config.pop("path", None) + self.tokenizer = AutoTokenizer.from_pretrained(path, **tokenizer_config) + else: + self.tokenizer = AutoTokenizer.from_pretrained(path) self.pad_token_id = self.tokenizer.pad_token_id self.num_generations = num_generations self.filter_range = grpo_config.get("filter_range", None) @@ -243,10 +249,14 @@ def step(self, pbar: Any, **kwargs) -> Optional[float]: else self.booster.no_sync(self.policy_model, self.optimizer) ) with ctx: + mini_batch_entropies = [] for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size): input_ids_forward_micro_batch = data["input_ids"][ forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size ] + old_action_log_probs_micro_batch = old_action_log_probs[ + forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size + ] attention_mask_forward_micro_batch = data["attention_mask"][ forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size ] @@ -303,6 +313,7 @@ def step(self, pbar: Any, **kwargs) -> Optional[float]: "action_mask": action_mask_forward_micro_batch, "advantages": advantages_forward_micro_batch, "loss_mask": loss_mask_forward_micro_batch, + "old_action_log_probs": old_action_log_probs_micro_batch, "source": self.rank, } if reference_action_log_probs is not None: @@ -312,6 +323,12 @@ def step(self, pbar: Any, **kwargs) -> Optional[float]: def _criterion(outputs, inputs): action_logits = outputs.logits + mini_batch_entropies.append( + ( + ((entropy_from_logits(action_logits[:, -num_action:]) * inputs["action_mask"]).sum(-1)) + / inputs["action_mask"].sum(-1) + ).detach() + ) action_log_probs = memory_efficient_logprob( action_logits / self.generate_config["temperature"], inputs["input_ids"], @@ -334,7 +351,7 @@ def _criterion(outputs, inputs): loss, _ = self.policy_loss_fn( action_log_probs, - action_log_probs, + inputs["old_action_log_probs"], inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1), per_token_kl, inputs["action_mask"], @@ -396,7 +413,7 @@ def _criterion(outputs, inputs): loss, _ = self.policy_loss_fn( action_log_probs, - old_action_log_probs, + old_action_log_probs_micro_batch, advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1), per_token_kl, action_mask_forward_micro_batch, @@ -411,6 +428,20 @@ def _criterion(outputs, inputs): kl = all_reduce_mean(kl.mean(), self.plugin) mean_kl.append(kl.data) mean_loss.append(loss.data) + mini_batch_entropies.append( + all_reduce_mean( + ( + ( + ( + entropy_from_logits(policy_model_logits[:, -num_action:]) + * action_mask_forward_micro_batch + ).sum(-1) + ) + / action_mask_forward_micro_batch.sum(-1) + ).detach(), + self.plugin, + ) + ) if not self.plugin.pp_size > 1 or ( self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() @@ -422,7 +453,9 @@ def _criterion(outputs, inputs): ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin) advantages = all_reduce_mean(advantages.mean(), self.plugin) response_length = all_reduce_mean(response_length.mean(), self.plugin) + entropy = all_reduce_mean(torch.cat(mini_batch_entropies, dim=0).mean(), self.plugin) self.accum_loss.add_(sum(mean_loss) / len(mean_loss)) + self.accum_entropy.add_(entropy.data) if self.policy_loss_fn.beta > 0: self.accum_kl.add_(sum(mean_kl) / len(mean_kl)) self.accum_advantages.add_(advantages.data) @@ -465,6 +498,7 @@ def _criterion(outputs, inputs): f"Response Length: {raw_batch_response_len_mean:.4f}", f"Sample_utilization: {sample_utilization:.4f}", f"Overlength samples ratio: {overlength_samples_ratio:.4f}", + f"Entropy: {self.accum_entropy.item() / self.accum_count:.4f}", ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else []) print("\n".join(to_log_msg)) metrics = { @@ -476,6 +510,7 @@ def _criterion(outputs, inputs): "train/advantages": self.accum_advantages.item() / self.accum_count, "train/learning_rate": self.lr_scheduler.get_last_lr()[0], "train/sample_utilization": sample_utilization, + "train/entropy": self.accum_entropy.item() / self.accum_count, "train/overlength_samples_ratio": overlength_samples_ratio, "rollout/temperature": data["temperature"].cpu().numpy()[0][0], } @@ -483,8 +518,10 @@ def _criterion(outputs, inputs): metrics["train/kl"] = self.accum_kl.item() / self.accum_count if self.wandb_run is not None: self.wandb_run.log(metrics) + ray.get(self.shared_signal_actor.set_signal.remote("sample_utilization", sample_utilization)) self.accum_loss.zero_() self.accum_kl.zero_() + self.accum_entropy.zero_() self.accum_advantages.zero_() self.accum_count = 0 return loss_scalar diff --git a/applications/ColossalChat/coati/distributed/zero_bubble/producer.py b/applications/ColossalChat/coati/distributed/zero_bubble/producer.py index 9e57914c4e04..31c314dd51f8 100644 --- a/applications/ColossalChat/coati/distributed/zero_bubble/producer.py +++ b/applications/ColossalChat/coati/distributed/zero_bubble/producer.py @@ -11,9 +11,12 @@ import tqdm import wandb from coati.dataset.loader import RawConversationDataset, collate_fn_grpo +from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict +from coati.distributed.inference_backend import BACKEND_MAP from coati.distributed.profiling_utils import CustomProfiler from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn from coati.distributed.reward.verifiable_reward import VerifiableReward +from coati.distributed.utils import pre_send, safe_append_to_jsonl_file from ray.util.collective import allreduce from ray.util.collective.types import ReduceOp from torch.utils.data import DataLoader, DistributedSampler @@ -21,10 +24,6 @@ from colossalai.utils import get_current_device -from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict -from coati.distributed.inference_backend import BACKEND_MAP -from coati.distributed.utils import pre_send, safe_append_to_jsonl_file - try: from vllm import SamplingParams except ImportError: @@ -280,11 +279,17 @@ def sync_model_thread(): self.profiler.exit("sync_model") self.sync_model_thread_started = False - if not self.sync_model_thread_started and self.consumer_global_step != self.producer_weight_version: + distributor_weight_version = ray.get(self.shared_signal_actor.get_signal.remote()).get( + f"distributor_weight_version", 0 + ) + if ( + not self.sync_model_thread_started + and distributor_weight_version != self.producer_weight_version + ): # only sync model when the thread is not started and global step is changed self.sync_model_thread_started = True self.sync_model_thread = threading.Thread(target=sync_model_thread) - self.producer_weight_version = self.consumer_global_step + self.producer_weight_version = distributor_weight_version self.sync_model_thread.start() torch.cuda.empty_cache() if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get( @@ -478,7 +483,7 @@ def __init__( train_dataset_config, model_config, generate_config, - tokenizer_config, + copy.deepcopy(tokenizer_config), microbatch_size, backend, consumer_plugin_config, @@ -493,7 +498,8 @@ def __init__( rollout_log_file=rollout_log_file, enable_profiling=enable_profiling, ) - self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations) + print("tokenizer_config", tokenizer_config) + self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations, tokenizer_config) self.eval_generation_config = copy.deepcopy(self.model.generate_config) self.eval_generation_config["n"] = 1 # use 1 generation for evaluation self.eval_generation_config.update(eval_generation_config) diff --git a/applications/ColossalChat/rl_example_zero_bubble.py b/applications/ColossalChat/rl_example_zero_bubble.py index e4f149653f39..89270d753a09 100644 --- a/applications/ColossalChat/rl_example_zero_bubble.py +++ b/applications/ColossalChat/rl_example_zero_bubble.py @@ -15,6 +15,12 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B") + parser.add_argument( + "--tokenizer-path", + type=str, + default=None, + help="Path to the tokenizer. If not provided, will use the model path.", + ) parser.add_argument("-d", "--dataset", type=str, default="data.jsonl") parser.add_argument( "-ed", @@ -166,6 +172,7 @@ "--enable_profiling", action="store_true", default=False, help="Enable profiling for the training process." ) args = parser.parse_args() + print(args) if args.train_minibatch_size is None: # Default settings: Using train batch size as mini batch size @@ -283,7 +290,7 @@ elif args.algo == "DAPO": # DAPO variant settings grpo_config = { - "filter_range": [0.01, 0.99], # only filter out all zero batch and all one batch + "filter_range": [0.01, 0.7], # only filter out all zero batch and all one batch "lr": args.learning_rate, "train_microbatch_size": args.train_microbatch_size, "dynamic_batching": True, @@ -343,7 +350,9 @@ ), # microbatch size should be set to train_microbatch_size // pp_size "zero_stage": args.zero_stage, "max_norm": 1.0, + "num_layers_per_stage": [18, 10], }, # for pp, tp + tokenizer_config={"path": args.tokenizer_path} if args.tokenizer_path else {"path": args.model}, inference_backend=args.backend, master_addr="localhost", master_port=args.master_port, From 5c5cb1863b86d92f56ef11d883a0cdeedddc5448 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Mon, 21 Jul 2025 18:04:20 +0800 Subject: [PATCH 5/5] hotfix --- .../coati/distributed/grpo_consumer.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 7b4843b08965..a3f1a1cbbbb2 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -371,20 +371,6 @@ def _criterion(outputs, inputs): kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data mean_kl.append(kl) mean_loss.append(all_reduce_mean(loss, self.plugin).data) - mini_batch_entropies.append( - all_reduce_mean( - ( - ( - ( - entropy_from_logits(policy_model_logits[:, -num_action:]) - * action_mask_forward_micro_batch - ).sum(-1) - ) - / action_mask_forward_micro_batch.sum(-1) - ).detach(), - self.plugin, - ) - ) else: policy_model_logits = self.policy_model( input_ids=input_ids_forward_micro_batch,