diff --git a/ckpt_converter/README.md b/ckpt_converter/README.md
index 844c3aeb..5a51b5bb 100644
--- a/ckpt_converter/README.md
+++ b/ckpt_converter/README.md
@@ -30,7 +30,8 @@ bash scripts/ckpt_converter/convert_to_hf_cpu.sh \
   /path/to/VENV_DIR \
   /path/to/MEGATRON_PATH \
   /path/to/HF_TOKENIZER_PATH \
-  /path/to/OUTPUT_DIR
+  /path/to/OUTPUT_DIR \
+  MODEL_PARALLEL_SIZE
 ```
 
 - `TASK_DIR`: Root of training run; must contain `checkpoints/iter_XXXXXXX`.
@@ -39,5 +40,6 @@ bash scripts/ckpt_converter/convert_to_hf_cpu.sh \
 - `MEGATRON_PATH`: Megatron-LM root directory (must contain `tools/checkpoint/convert.py`).
 - `HF_TOKENIZER_PATH`: Path to the Hugging Face tokenizer directory.
 - `OUTPUT_DIR`: Destination directory for converted HF checkpoints.
+- `MODEL_PARALLEL_SIZE`: Model parallel size used by the loader; only tensor parallelism is supported at this point.
 
-The script queues a single PBS job, passing these values to `qsub_convert_cpu.sh`.
\ No newline at end of file
+The script queues a single PBS job, passing these values to `qsub_convert_cpu.sh`.
diff --git a/ckpt_converter/compare_hf_ckpt.py b/ckpt_converter/compare_hf_ckpt.py
new file mode 100644
index 00000000..1c36dcba
--- /dev/null
+++ b/ckpt_converter/compare_hf_ckpt.py
@@ -0,0 +1,152 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import argparse
+import json
+import os
+import sys
+from typing import Dict, Tuple, List, Optional
+
+import torch
+from safetensors.torch import load_file as safe_load_file
+
+INDEX_NAME = "model.safetensors.index.json"
+
+
+def _read_json(path: str) -> dict:
+    with open(path, "r", encoding="utf-8") as f:
+        return json.load(f)
+
+
+def _find_index_file(dirpath: str) -> Optional[str]:
+    p = os.path.join(dirpath, INDEX_NAME)
+    return p if os.path.isfile(p) else None
+
+
+def _list_safetensors_files(dirpath: str) -> List[str]:
+    return sorted(
+        os.path.join(dirpath, f)
+        for f in os.listdir(dirpath)
+        if f.endswith(".safetensors")
+    )
+
+
+def _load_state_dict_from_dir(dirpath: str) -> Dict[str, torch.Tensor]:
+    index_path = _find_index_file(dirpath)
+    state: Dict[str, torch.Tensor] = {}
+    if index_path is not None:
+        index = _read_json(index_path)
+        weight_map: Dict[str, str] = index.get("weight_map") or {}
+        shards = sorted(set(weight_map.values()))
+        for shard in shards:
+            shard_path = os.path.join(dirpath, shard)
+            shard_sd = safe_load_file(shard_path, device="cpu")
+            for k, v in shard_sd.items():
+                if weight_map.get(k) == shard:
+                    state[k] = v.detach().cpu()
+        return state
+
+    files = _list_safetensors_files(dirpath)
+    if not files:
+        raise RuntimeError(f"No .safetensors found in: {dirpath}")
+    for wf in files:
+        shard_sd = safe_load_file(wf, device="cpu")
+        for k, v in shard_sd.items():
+            state[k] = v.detach().cpu()
+    return state
+
+
+def compare_checkpoints(dir_a: str, dir_b: str, max_diffs: int = 20) -> Tuple[bool, str]:
+    sd_a = _load_state_dict_from_dir(dir_a)
+    sd_b = _load_state_dict_from_dir(dir_b)
+
+    keys_a = set(sd_a.keys())
+    keys_b = set(sd_b.keys())
+
+    msgs: List[str] = []
+    missing_in_b = sorted(keys_a - keys_b)
+    extra_in_b = sorted(keys_b - keys_a)
+    if missing_in_b:
+        msgs.append(f"Missing in B ({len(missing_in_b)}): {missing_in_b[:max_diffs]}")
+    if extra_in_b:
+        msgs.append(f"Extra in B ({len(extra_in_b)}): {extra_in_b[:max_diffs]}")
+
+    common = sorted(keys_a & keys_b)
+    shape_mismatch = []
+    dtype_mismatch = []
+    value_mismatch = []
+
+    for name in common:
+        t1 = sd_a[name]
+        t2 = sd_b[name]
+        if t1.shape != t2.shape:
+            shape_mismatch.append((name, t1.shape, t2.shape))
+            if len(shape_mismatch) >= max_diffs:
+                break
+
+    for name, _, _ in shape_mismatch:
+        if name in common:
+            common.remove(name)
+
+    for name in common:
+        t1 = sd_a[name]
+        t2 = sd_b[name]
+        if t1.dtype != t2.dtype:
+            dtype_mismatch.append((name, str(t1.dtype), str(t2.dtype)))
+            if len(dtype_mismatch) >= max_diffs:
+                break
+    for name, _, _ in dtype_mismatch:
+        if name in common:
+            common.remove(name)
+
+    for name in common:
+        t1 = sd_a[name]
+        t2 = sd_b[name]
+        if not torch.equal(t1, t2):
+            value_mismatch.append(name)
+            if len(value_mismatch) >= max_diffs:
+                break
+
+    if shape_mismatch:
+        msgs.append(f"Shape mismatches ({len(shape_mismatch)}): {shape_mismatch[:max_diffs]}")
+    if dtype_mismatch:
+        msgs.append(f"Dtype mismatches ({len(dtype_mismatch)}): {dtype_mismatch[:max_diffs]}")
+    if value_mismatch:
+        msgs.append(f"Value mismatches ({len(value_mismatch)}): {value_mismatch[:max_diffs]}")
+
+    ok = not (missing_in_b or extra_in_b or shape_mismatch or dtype_mismatch or value_mismatch)
+    if ok:
+        return True, "All tensors identical."
+    summary = f"Summary: missing={len(missing_in_b)}, extra={len(extra_in_b)}, shape={len(shape_mismatch)}, dtype={len(dtype_mismatch)}, value={len(value_mismatch)}."
+    return False, "\n".join(msgs + [summary])
+
+
+def parse_args() -> argparse.Namespace:
+    p = argparse.ArgumentParser(description="Compare two HF safetensors checkpoints (strict equality).")
+    p.add_argument("ckpt_a", type=str)
+    p.add_argument("ckpt_b", type=str)
+    p.add_argument("--max-diffs", type=int, default=20)
+    return p.parse_args()
+
+
+def main() -> int:
+    args = parse_args()
+    a = os.path.abspath(args.ckpt_a)
+    b = os.path.abspath(args.ckpt_b)
+    if not os.path.isdir(a) or not os.path.isdir(b):
+        print("Both inputs must be directories.", file=sys.stderr)
+        return 2
+    try:
+        ok, report = compare_checkpoints(a, b, max_diffs=args.max_diffs)
+    except Exception as e:
+        print(f"[error] {e}", file=sys.stderr)
+        return 2
+    if ok:
+        print("OK: All tensors identical.")
+        return 0
+    print("NOT EQUAL:")
+    print(report)
+    return 1
+
+
+if __name__ == "__main__":
+    sys.exit(main())
\ No newline at end of file
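The comparer is meant for checking that a tensor-parallel conversion matches a reference conversion of the same iteration. Besides the CLI, it can be driven from Python; a minimal sketch, assuming `compare_hf_ckpt.py` is importable from the working directory and using placeholder paths:

```python
# Minimal usage sketch for compare_hf_ckpt.py; both paths are placeholders.
from compare_hf_ckpt import compare_checkpoints

ok, report = compare_checkpoints(
    "/path/to/hf_ckpt_reference",   # e.g. a known-good single-rank conversion
    "/path/to/hf_ckpt_candidate",   # e.g. the TP > 1 conversion under test
    max_diffs=50,
)
print("identical" if ok else report)
```

The comparison is strict (`torch.equal`), so any dtype or value difference is reported; the CLI exit codes follow `main()`: 0 for identical, 1 for a mismatch, 2 for usage or load errors.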
diff --git a/ckpt_converter/convert_to_hf_cpu.sh b/ckpt_converter/convert_to_hf_cpu.sh
index 01e28de4..fb23e694 100755
--- a/ckpt_converter/convert_to_hf_cpu.sh
+++ b/ckpt_converter/convert_to_hf_cpu.sh
@@ -6,8 +6,8 @@
 
 set -eu -o pipefail
 
-if [ "$#" -ne 6 ]; then
-  echo "Usage: $0 TASK_DIR ITER VENV_DIR MEGATRON_PATH HF_TOKENIZER_PATH OUTPUT_DIR" >&2
+if [ "$#" -ne 7 ]; then
+  echo "Usage: $0 TASK_DIR ITER VENV_DIR MEGATRON_PATH HF_TOKENIZER_PATH OUTPUT_DIR PARALLEL_SIZE" >&2
   exit 2
 fi
 
@@ -17,12 +17,14 @@ venv_dir=$3
 megatron_path=$4
 hf_tokenizer_path=$5
 output_dir=$6
+parallel_size=$7
 
 script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
 
 # Create env list in case cumulatively adding -v is not supported
 env_list="SCRIPT_DIR=${script_dir},TASK_DIR=${task_dir},ITER=${iter},"
 env_list+="VENV_DIR=${venv_dir},MEGATRON_PATH=${megatron_path},"
-env_list+="HF_TOKENIZER_PATH=${hf_tokenizer_path},OUTPUT_DIR=${output_dir}"
+env_list+="HF_TOKENIZER_PATH=${hf_tokenizer_path},OUTPUT_DIR=${output_dir},"
+env_list+="PARALLEL_SIZE=${parallel_size}"
 
 echo "Submitting job for iteration: ${iter}"
diff --git a/ckpt_converter/loader_mcore_cpu.py b/ckpt_converter/loader_mcore_cpu.py
index 3453a804..0612dadc 100644
--- a/ckpt_converter/loader_mcore_cpu.py
+++ b/ckpt_converter/loader_mcore_cpu.py
@@ -41,8 +41,10 @@ def _read_metadata_cpu(tracker_filename):
                 max_iter = iteration
     return max_iter, release
 
+
 ckpt.read_metadata = _read_metadata_cpu
 
+
 def add_arguments(parser):
     group = parser.add_argument_group(title='Megatron loader (CPU)')
     group.add_argument('--true-vocab-size', type=int, default=None,
@@ -60,6 +62,7 @@ def add_arguments(parser):
                        choices=['local', 'transformer_engine'],
                        help='Which Transformer implementation to use.')
 
+
 class MegatronCheckpointLoaderBase:
     def __init__(self, args, queue, build_tokenizer=False):
         self.args = args
@@ -159,17 +162,21 @@ def initialize_megatron_env(self):
             self.queue.put("exit")
             sys.exit(1)
 
-        import torch.distributed as dist
+        # Init torch.distributed once
         if dist.is_available() and not dist.is_initialized():
-            import tempfile
-            tmp = tempfile.NamedTemporaryFile(delete=False)
-            tmp.close()
-            dist.init_process_group(
-                backend='gloo',
-                init_method=f'file://{tmp.name}',
-                rank=0,
-                world_size=1,
-            )
+            # Prefer env:// when under torchrun; fallback to file:// (single-rank)
+            if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+                dist.init_process_group(backend='gloo', init_method='env://')
+            else:
+                import tempfile
+                tmp = tempfile.NamedTemporaryFile(delete=False)
+                tmp.close()
+                dist.init_process_group(
+                    backend='gloo',
+                    init_method=f'file://{tmp.name}',
+                    rank=0,
+                    world_size=1,
+                )
 
         # Init model parallel groups
         try:
@@ -248,7 +255,14 @@ def load_model_shards(self, model_provider, dtype):
 
         def get_models_for_pipeline_stage(count, dtype):
             local_models_for_stage = [[] for _ in range(vp_size)]
-            for tp_rank in range(count):
+
+            # Only load TP shard under multi-process; otherwise iterate all
+            if dist.is_initialized() and dist.get_world_size() > 1 and mpu.get_tensor_model_parallel_world_size() > 1:
+                tp_ranks = [mpu.get_tensor_model_parallel_rank()]
+            else:
+                tp_ranks = range(count)
+
+            for tp_rank in tp_ranks:
                 mpu.set_tensor_model_parallel_rank(tp_rank)
                 model_list = []
                 for i in range(vp_size):
@@ -294,11 +308,31 @@ def queue_put(self, name, msg):
         msg["name"] = name
         self.queue.put(msg)
 
+    def _tp_gather_cat_rank0(self, t: torch.Tensor, dim: int):
+        from megatron.core import mpu
+        if not (dist.is_initialized() and mpu.get_tensor_model_parallel_world_size() > 1):
+            return t
+        group = mpu.get_tensor_model_parallel_group()
+        ws = mpu.get_tensor_model_parallel_world_size()
+        parts = [torch.empty_like(t) for _ in range(ws)]
+        dist.all_gather(parts, t, group=group)
+        if mpu.get_tensor_model_parallel_rank() == 0:
+            return torch.cat(parts, dim=dim)
+        else:
+            return None
+
     def send_llm_over_queue(self, schema):
+        from megatron.core import mpu
+
         tp_size = self.margs.tensor_model_parallel_size
         pp_size = self.margs.pipeline_model_parallel_size
         vp_size = self.margs.virtual_pipeline_model_parallel_size or 1
 
+        # Only main TP rank sends to saver; other ranks only participate in communication
+        is_main_tp = (not dist.is_initialized()) or \
+            (mpu.get_tensor_model_parallel_world_size() == 1) or \
+            (mpu.get_tensor_model_parallel_rank() == 0)
+
         first_pipeline_models = self.all_models[0][0]
 
         # Embeddings
@@ -306,11 +340,15 @@ def send_llm_over_queue(self, schema):
         message = {
             "word embeddings": torch.cat([e["word"] for e in embeddings], dim=0)
         }
+        # Gather full embeddings across TP
+        message["word embeddings"] = self._tp_gather_cat_rank0(message["word embeddings"], dim=0)
+
         if self.md.position_embedding_type == 'learned_absolute':
            message["position embeddings"] = embeddings[0]["pos"]
         else:
            assert embeddings[0]["pos"] is None
-        if True:
-            pass
-        self.queue_put("embeddings", message)
+        if is_main_tp:
+            self.queue_put("embeddings", message)
 
         total_layer_num = 0
         for vp_rank in range(vp_size):
@@ -335,6 +373,7 @@ def send_llm_over_queue(self, schema):
             mlp_l0_weight, mlp_l0_bias = [], []
             mlp_l1_weight = []
 
+            # collect local TP-shard tensors (on this rank)
             for model_tp in models:
                 layer_p = schema.get_layer(model_tp, layer_idx)
                 qkv_weight.append(layer_p["self_attn_qkv_weight"])
@@ -346,8 +385,9 @@ def send_llm_over_queue(self, schema):
                 if self.md.linear_bias:
                     mlp_l0_bias.append(layer_p["mlp_fc1_bias"])
 
+            # Build message with local concat
             if self.md.swiglu:
-                for i in range(tp_size):
+                for i in range(len(mlp_l0_weight)):
                     mlp_l0_weight[i] = torch.chunk(mlp_l0_weight[i], 2, dim=0)
                 message["mlp l0 weight W"] = torch.cat([w[0] for w in mlp_l0_weight], dim=0)
                 message["mlp l0 weight V"] = torch.cat([w[1] for w in mlp_l0_weight], dim=0)
@@ -362,23 +402,53 @@ def send_llm_over_queue(self, schema):
                 message["qkv bias"] = torch.cat(qkv_bias, dim=0)
             if self.md.linear_bias:
                 if self.md.swiglu:
-                    for i in range(tp_size):
+                    for i in range(len(mlp_l0_bias)):
                         mlp_l0_bias[i] = torch.chunk(mlp_l0_bias[i], 2, dim=0)
                     message["mlp l0 bias W"] = torch.cat([b[0] for b in mlp_l0_bias], dim=0)
                     message["mlp l0 bias V"] = torch.cat([b[1] for b in mlp_l0_bias], dim=0)
                 else:
                     message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0)
 
-            self.queue_put(f"transformer layer {total_layer_num}", message)
+            # Now gather to full tensors across TP
+            qkv_w = self._tp_gather_cat_rank0(message["qkv weight"], dim=0)
+            dense = self._tp_gather_cat_rank0(message["dense weight"], dim=1)
+            mlp_l1 = self._tp_gather_cat_rank0(message["mlp l1 weight"], dim=1)
+            if self.md.qkv_bias:
+                qkv_b = self._tp_gather_cat_rank0(message["qkv bias"], dim=0)
+            if self.md.swiglu:
+                mlp0W = self._tp_gather_cat_rank0(message["mlp l0 weight W"], dim=0)
+                mlp0V = self._tp_gather_cat_rank0(message["mlp l0 weight V"], dim=0)
+                if self.md.linear_bias:
+                    mlp0bW = self._tp_gather_cat_rank0(message["mlp l0 bias W"], dim=0)
+                    mlp0bV = self._tp_gather_cat_rank0(message["mlp l0 bias V"], dim=0)
+            else:
+                mlp0 = self._tp_gather_cat_rank0(message["mlp l0 weight"], dim=0)
+                if self.md.linear_bias:
+                    mlp0b = self._tp_gather_cat_rank0(message["mlp l0 bias"], dim=0)
+
+            if is_main_tp:
+                message["qkv weight"] = qkv_w
+                message["dense weight"] = dense
+                message["mlp l1 weight"] = mlp_l1
+                if self.md.qkv_bias: message["qkv bias"] = qkv_b
+                if self.md.swiglu:
+                    message["mlp l0 weight W"] = mlp0W
+                    message["mlp l0 weight V"] = mlp0V
+                    if self.md.linear_bias:
+                        message["mlp l0 bias W"] = mlp0bW
+                        message["mlp l0 bias V"] = mlp0bV
+                else:
+                    message["mlp l0 weight"] = mlp0
+                    if self.md.linear_bias: message["mlp l0 bias"] = mlp0b
+
+                self.queue_put(f"transformer layer {total_layer_num}", message)
             total_layer_num += 1
 
         # Final norm
         models = self.all_models[0][0]
         final_norm = schema.get("final_norm", models[0])
-        message = {"weight": final_norm["weight"]}
-        if self.md.norm_has_bias:
-            message["bias"] = final_norm["bias"]
-        self.queue_put("final norm", message)
+        if is_main_tp:
+            self.queue_put("final norm", {"weight": final_norm["weight"], **({"bias": final_norm["bias"]} if self.md.norm_has_bias else {})})
 
         # Output layer
         if self.md.output_layer:
@@ -386,23 +456,10 @@ def send_llm_over_queue(self, schema):
             message = {
                 "weight": torch.cat([layer["weight"] for layer in output_layers], dim=0),
             }
-            self.queue_put("output layer", message)
-
-        # BERT-specific
-        if self.md.model_type == 'BERT':
-            pooler = schema.get("pooler", models[0])
-            message = {"weight": pooler["weight"], "bias": pooler["bias"]}
-            self.queue_put("pooler", message)
-
-            lm_head = schema.get("lm_head", models[0])
-            message = {
-                "dense weight": lm_head["dense_weight"],
"dense bias": lm_head["dense_bias"], - "norm weight": lm_head["norm_weight"], - } - if self.md.norm_has_bias: - message["norm bias"] = lm_head["norm_bias"] - self.queue_put("lm head", message) + w_all = self._tp_gather_cat_rank0(message["weight"], dim=0) + if is_main_tp: + message["weight"] = w_all + self.queue_put("output layer", message) def build_checkpoint_metadata(self, true_vocab_size): norm_has_bias = True @@ -460,6 +517,7 @@ def build_sys_argv(self): '--no-gradient-accumulation-fusion', '--no-gradient-reduce-div-fusion', '--tp-comm-bootstrap-backend', 'gloo', + '--finetune', ] def import_model_provider(self): @@ -493,7 +551,28 @@ def load(self): self.send_model_over_queue() def send_model_over_queue(self): - raise NotImplementedError + from megatron.core import mpu + is_main_tp = (not dist.is_initialized()) or (mpu.get_tensor_model_parallel_world_size() == 1) or (mpu.get_tensor_model_parallel_rank() == 0) + + self.send_metadata_over_queue() + schema = get_model_schema( + self.md.model_type, + self.margs.transformer_impl, + self.margs.num_experts, + self.margs.expert_model_parallel_size, + ) + self.send_llm_over_queue(schema) + + if is_main_tp: + self.queue.put("done") + else: + self.queue.put("exit") + + def send_metadata_over_queue(self): + self.md.consumed_train_samples = self.consumed_train_samples + self.md.consumed_valid_samples = self.consumed_valid_samples + self.queue.put(self.md) + class MegatronCheckpointLoaderLLM(MegatronCheckpointLoaderBase): def build_sys_argv(self): @@ -513,6 +592,9 @@ def import_model_provider(self): raise Exception(f"Unrecognized model type: {self.args.model_type}") def send_model_over_queue(self): + from megatron.core import mpu + is_main_tp = (not dist.is_initialized()) or (mpu.get_tensor_model_parallel_world_size() == 1) or (mpu.get_tensor_model_parallel_rank() == 0) + self.send_metadata_over_queue() schema = get_model_schema( self.md.model_type, @@ -521,13 +603,18 @@ def send_model_over_queue(self): self.margs.expert_model_parallel_size, ) self.send_llm_over_queue(schema) - self.queue.put("done") + + if is_main_tp: + self.queue.put("done") + else: + self.queue.put("exit") def send_metadata_over_queue(self): self.md.consumed_train_samples = self.consumed_train_samples self.md.consumed_valid_samples = self.consumed_valid_samples self.queue.put(self.md) + def load_checkpoint(queue, args): loader = MegatronCheckpointLoaderLLM(args, queue) try: diff --git a/ckpt_converter/qsub_convert_cpu.sh b/ckpt_converter/qsub_convert_cpu.sh index e95743e6..09107c91 100644 --- a/ckpt_converter/qsub_convert_cpu.sh +++ b/ckpt_converter/qsub_convert_cpu.sh @@ -3,8 +3,6 @@ #PBS -q R9920251000 #PBS -N 0208_convert #PBS -l select=1 -#PBS -o /dev/null -#PBS -e /dev/null #PBS -m n set -eu -o pipefail @@ -38,18 +36,25 @@ echo "CKPT_ROOT=${CKPT_ROOT}" echo "HF_TOKENIZER_PATH=${HF_TOKENIZER_PATH}" echo "OUTPUT_DIR=${OUTPUT_DIR}" echo "LOADER_SAVER_PATH=${LOADER_SAVER_PATH}" +echo "PARALLEL_SIZE=${PARALLEL_SIZE}" source ${VENV_DIR}/bin/activate +export CUDA_DEVICE_MAX_CONNECTIONS=1 + # Force CPU export CUDA_VISIBLE_DEVICES="" export NVIDIA_VISIBLE_DEVICES="" export OMP_NUM_THREADS=${OMP_NUM_THREADS:-8} # Logs -mkdir -p ${TASK_DIR}/logs -LOGFILE=${TASK_DIR}/logs/convert_${ITER}_${JOBID}.out -ERRFILE=${TASK_DIR}/logs/convert_${ITER}_${JOBID}.err +# mkdir -p ${TASK_DIR}/logs +# LOGFILE=${TASK_DIR}/logs/convert_${ITER}_${JOBID}.out +# ERRFILE=${TASK_DIR}/logs/convert_${ITER}_${JOBID}.err +# exec > "$LOGFILE" 2> "$ERRFILE" +mkdir -p ${OUTPUT_DIR}/logs 
diff --git a/ckpt_converter/qsub_convert_cpu.sh b/ckpt_converter/qsub_convert_cpu.sh
index e95743e6..09107c91 100644
--- a/ckpt_converter/qsub_convert_cpu.sh
+++ b/ckpt_converter/qsub_convert_cpu.sh
@@ -3,8 +3,6 @@
 #PBS -q R9920251000
 #PBS -N 0208_convert
 #PBS -l select=1
-#PBS -o /dev/null
-#PBS -e /dev/null
 #PBS -m n
 
 set -eu -o pipefail
@@ -38,18 +36,25 @@ echo "CKPT_ROOT=${CKPT_ROOT}"
 echo "HF_TOKENIZER_PATH=${HF_TOKENIZER_PATH}"
 echo "OUTPUT_DIR=${OUTPUT_DIR}"
 echo "LOADER_SAVER_PATH=${LOADER_SAVER_PATH}"
+echo "PARALLEL_SIZE=${PARALLEL_SIZE}"
 
 source ${VENV_DIR}/bin/activate
 
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+
 # Force CPU
 export CUDA_VISIBLE_DEVICES=""
 export NVIDIA_VISIBLE_DEVICES=""
 export OMP_NUM_THREADS=${OMP_NUM_THREADS:-8}
 
 # Logs
-mkdir -p ${TASK_DIR}/logs
-LOGFILE=${TASK_DIR}/logs/convert_${ITER}_${JOBID}.out
-ERRFILE=${TASK_DIR}/logs/convert_${ITER}_${JOBID}.err
+# mkdir -p ${TASK_DIR}/logs
+# LOGFILE=${TASK_DIR}/logs/convert_${ITER}_${JOBID}.out
+# ERRFILE=${TASK_DIR}/logs/convert_${ITER}_${JOBID}.err
+# exec > "$LOGFILE" 2> "$ERRFILE"
+mkdir -p ${OUTPUT_DIR}/logs
+LOGFILE=${OUTPUT_DIR}/logs/convert_${ITER}_${JOBID}.out
+ERRFILE=${OUTPUT_DIR}/logs/convert_${ITER}_${JOBID}.err
 exec > "$LOGFILE" 2> "$ERRFILE"
 
 # Sanity checks
@@ -73,7 +78,8 @@ ln -s "${CKPT_ROOT}/${ITER_NAME}" "${LOAD_ROOT}/${ITER_NAME}"
 
 echo "Converting torch_dist -> HF on CPU..."
 PYTHONPATH="${LOADER_SAVER_PATH}:${MEGATRON_PATH}:${PYTHONPATH:-}" \
-python "${MEGATRON_PATH}/tools/checkpoint/convert.py" \
+torchrun --standalone --nproc_per_node ${PARALLEL_SIZE} \
+"${MEGATRON_PATH}/tools/checkpoint/convert.py" \
 --model-type GPT \
 --loader mcore_cpu \
 --loader-transformer-impl local \
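Once the PBS job completes, a quick smoke test of the converted checkpoint, beyond the strict tensor diff above, is to load it with `transformers` and generate a few tokens; a sketch, assuming the tokenizer files sit alongside the converted weights and using a placeholder path:

```python
# Smoke test for the converted HF checkpoint; the path is a placeholder and
# should point at the directory containing config.json and the safetensors shards.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt_dir = "/path/to/OUTPUT_DIR"
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
model = AutoModelForCausalLM.from_pretrained(ckpt_dir, torch_dtype=torch.bfloat16)

inputs = tokenizer("Hello, world", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```

Combined with `compare_hf_ckpt.py` against a single-rank conversion, this gives both bit-level and functional confidence in the multi-rank path.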