From 2f9320066be2fe4def351d6630127b9c46686002 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Mon, 28 Mar 2022 13:10:54 -0700
Subject: [PATCH] add debug utils

---
 megatron/debug_utils.py       | 137 ++++++++++++++++++++++++++++++++++
 megatron/model/transformer.py |  10 ++-
 2 files changed, 144 insertions(+), 3 deletions(-)
 create mode 100644 megatron/debug_utils.py

diff --git a/megatron/debug_utils.py b/megatron/debug_utils.py
new file mode 100644
index 000000000..f544073bf
--- /dev/null
+++ b/megatron/debug_utils.py
@@ -0,0 +1,137 @@
+
+import torch.distributed as dist
+import torch
+import os
+import socket
+import fcntl
+from megatron import mpu  # required by the rank/stage queries in dump_weights/dump_emb below
+
+def printflock(*msgs):
+    """print under an exclusive file lock so that concurrent ranks don't interleave their output"""
+    with open(__file__, "r") as fh:
+        fcntl.flock(fh, fcntl.LOCK_EX)
+        try:
+            print(*msgs)
+        finally:
+            fcntl.flock(fh, fcntl.LOCK_UN)
+
+
+def get_fingerprint_header():
+    return f"{'min':^13} {'max':^13} {'mean':^13} {'l2 norm':^12} metadata"
+
+
+def get_fingerprint(p):
+    return f"{p.min():13.6e} {p.max():13.6e} {p.mean():13.6e} {p.norm():12.6e}"
+
+
+def dump_weights(preamble, iteration, model, optimizer, tensor=None):
+    return  # note: this early return disables the dump below; remove it to activate
+
+    tp_rank = mpu.get_tensor_model_parallel_rank()
+    pp_rank = mpu.get_pipeline_model_parallel_rank()
+    dp_rank = mpu.get_data_parallel_rank()
+    dp_size = mpu.get_data_parallel_world_size()
+    fn = f"debug-bf16-{iteration}-pp{pp_rank}-tp{tp_rank}-dp{dp_rank}-{preamble}.txt"
+
+    # only care for first and last pp stages and dp0 tp0
+    if not (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()):
+        return
+
+    if not (tp_rank == 0 and dp_rank == 0):
+        return
+
+    if tensor is not None:
+        orig_tensor = tensor
+        if hasattr(tensor, "_hp_param"):
+            numel = tensor._hp_param.numel()  # // dp_size
+            tensor = tensor.flatten().narrow(0, 0, numel)
+
+    #print(fn)
+    with open(fn, "w") as fh:
+        fh.write(f"{get_fingerprint_header()}\n")
+
+        if tensor is not None:
+            fh.write(f"{get_fingerprint(tensor)} tensor {tensor.shape}\n")
+        else:
+            for n, p in model[0].named_parameters():
+                fh.write(f"{get_fingerprint(p)} {n} {p.shape}\n")
+
+    # until we figure out how to dump the actual fp32 values don't do this
+    fn = f"debug-fp32-{iteration}-pp{pp_rank}-tp{tp_rank}-dp{dp_rank}-{preamble}.txt"
+    with open(fn, "w") as fh:
+        fh.write(f"{get_fingerprint_header()}\n")
+        if tensor is not None:
+            tensor = orig_tensor
+            if hasattr(tensor, "_hp_param"):
+                fh.write(f"{get_fingerprint(tensor._hp_param)} tensor {tensor._hp_param.shape}\n")
+                fh.write(f"{get_fingerprint(tensor._hp_grad)} tensor grad\n")
+            else:
+                fh.write(f"{get_fingerprint(tensor)} tensor {tensor.shape}\n")
+                fh.write(f"{get_fingerprint(tensor.grad)} tensor grad\n")
+
+        else:
+            if hasattr(model[0].module.tied_modules, "embed"):
+                p = model[0].module.tied_modules.embed.word_embeddings.weight._hp_param
+                fh.write(f"{get_fingerprint(p)} module.tied_modules.embed.word_embeddings.weight._hp_param {p.shape}\n")
+
+        # for i, param_group in enumerate(optimizer.param_groups):
+        #     fh.write(f"{get_fingerprint(optimizer.fp32_groups_flat_partition[i])} group={i}\n")
+        #     #fh.write(f"{i}={optimizer.fp32_groups_flat_partition[i]}\n")
+        # if mpu.is_pipeline_first_stage():
+        #     x = optimizer.fp32_groups_flat_partition[0]
+        #     fh.write(f"fp32={x[:402432]}\n")
+        # if mpu.is_pipeline_last_stage():
+        #     x = optimizer.fp32_groups_flat_partition[1]
+        #     fh.write(f"fp32={x[-402432:]}\n")
+
+    # import os
+    # import socket
+    # hostname = socket.gethostname()
+    # pid = os.getpid()
+    # global_rank = torch.distributed.get_rank()
+    #fn = f"debug-{iteration}-pp{pp_rank}-tp{tp_rank}-dp{dp_rank}-global{global_rank}-{preamble}-{pid}.txt"
+
+
+
+# compare before
+# perl -le 'print qx[diff -u debug-$_-pp0-tp0-dp0-global0-before-iteration.txt debug-$_-pp1-tp0-dp0-global1-before-iteration.txt] for 301..320'
+# compare after
+# perl -le 'print qx[diff -u debug-$_-pp0-tp0-dp0-global0-after-iteration.txt debug-$_-pp1-tp0-dp0-global1-after-iteration.txt] for 301..320'
+
+
+def dump_emb(preamble, iteration, model):
+    return  # note: this early return disables the dump below; remove it to activate
+
+    # torch.set_printoptions(
+    #     threshold=10000000000, # print all data (without ... skipping) - can be huge!
+    #     sci_mode=False, # print all data on the same scale of 1 (this disables scientific notation)
+    #     precision=6, # print X decimal points for floats (default 4)
+    # )
+
+    # only care for first and last pp stages and dp0 tp0
+    if not (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()):
+        return
+
+    #printflock(f"pp rank={pp_rank} {preamble} {model[0].module.tied_modules.embed.word_embeddings.weight}")
+
+    tp_rank = mpu.get_tensor_model_parallel_rank()
+    pp_rank = mpu.get_pipeline_model_parallel_rank()
+    dp_rank = mpu.get_data_parallel_rank()
+    #global_rank = torch.distributed.get_rank()
+
+    if not (tp_rank == 0 and dp_rank == 0):
+        return
+
+    # fn = f"debug-emb-bf16-{iteration}-pp{pp_rank}-tp{tp_rank}-dp{dp_rank}-{preamble}.zip"
+    # torch.save(model[0].module.tied_modules.embed.word_embeddings.weight, fn)
+
+    fn = f"debug-emb-bf16-{iteration}-pp{pp_rank}-tp{tp_rank}-dp{dp_rank}-{preamble}.txt"
+    #fn = f"debug-{iteration}-pp{pp_rank}-tp{tp_rank}-dp{dp_rank}-global{global_rank}-{preamble}.txt"
+    #print(fn)
+    with open(fn, "w") as fh:
+        fh.write(f"module.tied_modules.embed.word_embeddings.weight={model[0].module.tied_modules.embed.word_embeddings.weight.cpu()}\n")
+        # if pp_rank == 0:
+        #     fh.write(f"module.tied_modules.embed.word_embeddings.norm.weight={model[0].module.tied_modules.embed.word_embeddings.norm.weight.cpu()}\n")
+        #     fh.write(f"module.tied_modules.embed.word_embeddings.norm.bias={model[0].module.tied_modules.embed.word_embeddings.norm.bias.cpu()}\n")
diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index 48401a9f1..325a231ba 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -501,6 +501,10 @@ def forward(self, hidden_states, attention_mask, layer_past=None,
                 get_key_value=False):
         # hidden_states: [b, s, h]
 
+        from megatron.debug_utils import dump_weights
+        args = get_args()
+        dump_weights("before-input_layernorm", args.iteration, None, None, tensor=self.input_layernorm.weight)
+
         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)
         # Self attention.
@@ -608,12 +612,12 @@ def get_slopes_power_of_2(n):
         slopes = torch.Tensor(get_slopes(num_attention_heads))
         alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0).expand(
             num_attention_heads, -1, -1)
-        
+
         #Select the part of the tensor that corresponds to our tensor parallel index.
         tp_world_size = mpu.get_tensor_model_parallel_world_size()
         tp_index = mpu.get_tensor_model_parallel_rank()
         alibi = alibi.reshape((tp_world_size, -1, *alibi.shape[1:]))[tp_index]
-        
+
         alibi = alibi.repeat(batch_size, 1, 1)
         return alibi
 
@@ -629,7 +633,7 @@ class ParallelTransformerLayerPipe(ParallelTransformerLayer):
        to the next stage in the pipeline.
 
        This version is useful if masks are dynamic.
-    
+
     2) forward(input, **kwargs) -> output
        When the mask is static over all samples, it is advantageous to
        cache the mask and avoid communicating it.
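
Usage sketch (an assumption, not part of the patch): the perl one-liners above diff per-rank dump files whose names end in "before-iteration"/"after-iteration". A minimal way to produce them, assuming the leading `return` in dump_weights/dump_emb has been removed, is a small helper called from the training loop around the optimizer step; the helper name and its placement are illustrative only, while dump_weights/dump_emb, get_args, and args.iteration are used exactly as in the patch:

from megatron import get_args
from megatron.debug_utils import dump_emb, dump_weights

def fingerprint_iteration(preamble, model, optimizer):
    """Write weight/embedding fingerprint files for the current iteration.

    Call with preamble="before-iteration" just before the optimizer step and
    preamble="after-iteration" just after it; `model` is the list of model
    chunks as used elsewhere in megatron (dump_weights indexes model[0]).
    """
    args = get_args()
    dump_weights(preamble, args.iteration, model, optimizer)
    dump_emb(preamble, args.iteration, model)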