From 132795ff002f6ae9135938fc1cbf80b751eef1bf Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 11 Oct 2022 11:25:37 +0200 Subject: [PATCH 1/5] v1 --- megatron/arguments.py | 20 +++ megatron/model/fused_layer_norm.py | 17 ++- megatron/model/fused_softmax.py | 3 +- megatron/model/gpt_model.py | 183 +++++++++++++++++++++++++--- megatron/model/language_model.py | 23 +++- megatron/model/transformer.py | 188 +++++++++++++++++++++++++---- megatron/optimizer/__init__.py | 14 +++ 7 files changed, 403 insertions(+), 45 deletions(-) diff --git a/megatron/arguments.py b/megatron/arguments.py index c18235a78..2e7a4c109 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -242,6 +242,9 @@ def parse_args(extra_args_provider=None, defaults={}, # Checks. if args.ffn_hidden_size is None: args.ffn_hidden_size = 4 * args.hidden_size + + if args.student_ffn_hidden_size is None: + args.student_ffn_hidden_size = 4 * args.student_hidden_size if args.kv_channels is None: assert args.hidden_size % args.num_attention_heads == 0 @@ -353,9 +356,18 @@ def _add_network_size_args(parser): help='Number of transformer layers.') group.add_argument('--hidden-size', type=int, default=None, help='Tansformer hidden size.') + group.add_argument('--student-num-layers', type=int, default=None, + help='Number of student transformer layers.') + group.add_argument('--student-hidden-size', type=int, default=None, + help='Student Tansformer hidden size.') + group.add_argument('--student-num-attention-heads', type=int, default=None, + help='Number of student transformer attention heads.') group.add_argument('--ffn-hidden-size', type=int, default=None, help='Transformer Feed-Forward Network hidden size. ' 'This is set to 4*hidden-size if not provided') + group.add_argument('--student-ffn-hidden-size', type=int, default=None, + help='Transformer Feed-Forward Network hidden size. ' + 'This is set to 4*hidden-size if not provided') group.add_argument('--num-attention-heads', type=int, default=None, help='Number of transformer attention heads.') group.add_argument('--kv-channels', type=int, default=None, @@ -660,6 +672,10 @@ def _add_checkpointing_args(parser): help='Do not save current rng state.') group.add_argument('--load', type=str, default=None, help='Directory containing a model checkpoint.') + group.add_argument('--teacher-load', type=str, default=None, + help='Directory containing a model checkpoint.') + group.add_argument('--student-load', type=str, default=None, + help='Directory containing a model checkpoint.') group.add_argument('--no-load-optim', action='store_true', default=None, help='Do not load optimizer when loading checkpoint.') group.add_argument('--no-load-rng', action='store_true', default=None, @@ -715,8 +731,12 @@ def _add_distributed_args(parser): group.add_argument('--tensor-model-parallel-size', type=int, default=1, help='Degree of tensor model parallelism.') + group.add_argument('--student-tensor-model-parallel-size', type=int, default=1, + help='Degree of tensor model parallelism.') group.add_argument('--pipeline-model-parallel-size', type=int, default=1, help='Degree of pipeline model parallelism.') + group.add_argument('--student-pipeline-model-parallel-size', type=int, default=1, + help='Degree of pipeline model parallelism.') group.add_argument('--model-parallel-size', type=int, default=None, help='Old model parallel argument, do not use. 
Use ' '--tensor-model-parallel-size instead.') diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py index 55e9c9dd8..0372132a9 100644 --- a/megatron/model/fused_layer_norm.py +++ b/megatron/model/fused_layer_norm.py @@ -41,7 +41,10 @@ def forward(ctx, input, weight, bias, normalized_shape, eps): ctx.normalized_shape = normalized_shape ctx.eps = eps - input_ = input.contiguous() + if isinstance(input, tuple): + input_ = input[0].contiguous() + else: + input_ = input.contiguous() weight_ = weight.contiguous() bias_ = bias.contiguous() output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine( @@ -109,3 +112,15 @@ def forward(self, input): input, self.weight, self.bias, self.normalized_shape, self.eps) else: return F.layer_norm(input, self.normalized_shape, self.weight, self.bias) + +class MixedFusedLayerNormTeacher(MixedFusedLayerNorm): + + @torch.no_grad() + def forward(self, input): + input, original_input = input + return (super().forward(input), original_input) + +class MixedFusedLayerNormStudent(MixedFusedLayerNorm): + def forward(self, input): + input, logits_teacher = input + return (super().forward(input), logits_teacher) \ No newline at end of file diff --git a/megatron/model/fused_softmax.py b/megatron/model/fused_softmax.py index 07192e2bf..9813085e1 100644 --- a/megatron/model/fused_softmax.py +++ b/megatron/model/fused_softmax.py @@ -188,7 +188,7 @@ def forward_fused_softmax(self, input, mask): if self.attn_mask_type == AttnMaskType.causal: assert sq == sk, "causal mask is only for self attention" - assert mask is None, "Mask is silently ignored due to the use of a custom kernel" + # assert mask is None, "Mask is silently ignored due to the use of a custom kernel" # input is 3D tensor (attn_batches, sq, sk) input = input.view(-1, sq, sk) @@ -236,3 +236,4 @@ def get_batch_per_block(sq, sk, b, np): import scaled_masked_softmax_cuda return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np) + diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index a9e3e2604..bc837fa8d 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -30,9 +30,11 @@ from deepspeed.pipe import PipelineModule, LayerSpec, TiedLayerSpec from megatron.model.fused_layer_norm import MixedFusedLayerNorm as LayerNorm +from megatron.model.fused_layer_norm import MixedFusedLayerNormTeacher as LayerNormTeacher +from megatron.model.fused_layer_norm import MixedFusedLayerNormStudent as LayerNormStudent from megatron.model.module import float16_to_fp32 -from .language_model import EmbeddingPipe -from .transformer import ParallelTransformerLayerPipe +from .language_model import EmbeddingPipe, EmbeddingPipeTeacher, EmbeddingPipeStudent +from .transformer import ParallelTransformerLayerPipe, ParallelTransformerLayerPipeTeacher, ParallelTransformerLayerPipeStudent def post_language_model_processing(lm_output, labels, logit_weights, @@ -195,6 +197,57 @@ def CrossEntropy(output, labels): return CrossEntropy +def get_ts_loss(is_prefix: bool): + def TeacherStudentLoss(output, labels): + output, teacher_logits = output[0], output[1] + labels, loss_mask = labels[0], labels[1] + + args = get_args() + + losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(), labels) + + if is_prefix: + micro_batch_size, sequence_length = loss_mask.shape + average_tokens_per_sample: torch.Tensor + if args.loss_on_targets_only: + # HACK: This is useful when we obtain loss masks that are microbatch dependent. 
Consequently, if we want to + # preserve the notion that all tokens have the same impact on the loss, we can only normalise using a + # microbatch independent value. It should be expected weight over a microbatch. + # Here we still use `sequence_length`, that's batch size dependent, in order to be backwards compatible with + # current experiment on vanilla gpt. + if args.reweight_loss_based_on_position_frequency: + reweight = torch.arange( + sequence_length, 0, -1, dtype=torch.float, device=loss_mask.device + ) / (sequence_length + 1) * 2 + average_tokens_per_sample = reweight.flip(-1).cumsum(-1).mean() + else: + average_tokens_per_sample = (sequence_length + 1) / 2 + else: + average_tokens_per_sample = sequence_length + expected_number_of_tokens = average_tokens_per_sample * micro_batch_size + else: + expected_number_of_tokens = loss_mask.sum() + + loss_mask = loss_mask.view(-1) + loss = torch.sum(losses.view(-1) * loss_mask) / expected_number_of_tokens + + # TODO: check if the formula is correct + teacher_logits = teacher_logits.detach() + # First pass it on CPU - otherwise we get OOM errors + softmax_labels = torch.nn.Softmax(dim=-1)(teacher_logits) + softmax_labels = softmax_labels.permute(1, 0, 2) + + student_log_softax = -torch.nn.LogSoftmax(dim=-1)(output) + + # print(output.shape, teacher_logits.shape) + # print(student_log_softax.shape, softmax_labels.shape) + softmax_logits = student_log_softax * softmax_labels + logits_loss = softmax_logits.mean() + + return loss + logits_loss + return TeacherStudentLoss + + class GPTModelPipe(PipelineModule,MegatronModule): """GPT-2 Language model.""" @@ -223,7 +276,7 @@ def _to_float16(inputs): # Embedding layer self.specs.append(TiedLayerSpec('embed', - EmbeddingPipe, + EmbeddingPipeTeacher, args.hidden_size, args.padded_vocab_size, args.hidden_dropout, @@ -239,14 +292,14 @@ def _to_float16(inputs): self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous().float(), *x[1:])) else: if getattr(args, 'pretrain_causal_attention', False): - self.specs.append(lambda x: x.transpose(0, 1).contiguous()) + self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous(), x[1])) else: # EmbeddingPipe returns attention mask as well self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous(), *x[1:])) for layer_idx in range(args.num_layers): self.specs.append( - LayerSpec(ParallelTransformerLayerPipe, + LayerSpec(ParallelTransformerLayerPipeTeacher, init_method=init_method, output_layer_init_method=scaled_init_method_normal(args.init_method_std, args.num_layers), @@ -256,14 +309,16 @@ def _to_float16(inputs): # Undo data format change def undo(x): - if not getattr(args, 'pretrain_causal_attention', False): - x = x[0] + # if not getattr(args, 'pretrain_causal_attention', False): + # x = x[0] + if isinstance(x, tuple): + return (x[0][0].transpose(0, 1).contiguous(), x[1]) return x.transpose(0, 1).contiguous() self.specs.append(undo) # Final layernorm after transformer layers self.specs.append( - LayerSpec(LayerNorm, + LayerSpec(LayerNormTeacher, args.hidden_size, eps=args.layernorm_epsilon)) @@ -276,7 +331,7 @@ def _logits_helper(embedding, lm_output): self.specs.append( TiedLayerSpec('embed', - EmbeddingPipe, + EmbeddingPipeTeacher, args.hidden_size, args.padded_vocab_size, args.hidden_dropout, @@ -286,19 +341,17 @@ def _logits_helper(embedding, lm_output): tied_weight_attr='word_embeddings_weight') ) + # self.specs.append(lambda x: print(x[0])) # Convert to fp32 if needed - if args.fp16 or args.bf16: - self.specs.append(float16_to_fp32) + # if 
args.fp16 or args.bf16: + # self.specs.append(float16_to_fp32) if args.checkpoint_activations: interval = args.checkpoint_num_layers else: interval = 0 - from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology - topo = PipeModelDataParallelTopology(num_pp=mpu.get_pipeline_model_parallel_world_size(), - num_mp=mpu.get_tensor_model_parallel_world_size(), - num_dp=mpu.get_data_parallel_world_size()) + # here one can extend the regex to include more layers to be counted towards partitioning, # e.g. 'type:transformer|embedding' will add up all the transformer blocks and also the first @@ -306,14 +359,112 @@ def _logits_helper(embedding, lm_output): # balance you may want to use less transformer layers # # caveat emptor: the current implementation of PP fails unless each stage has at least one + + # Beginning student model + + init_method = init_method_normal(args.init_method_std) + + + def _to_float16(inputs): + if args.fp16: + return fp32_to_float16(inputs, lambda v: v.half()) + elif args.bf16: + return fp32_to_float16(inputs, lambda v: v.bfloat16()) + else: + return inputs + + # self.specs.append(_to_float16) + self.specs.append(lambda x: (x[0], x[1])) + + # Embedding layer + self.specs.append(TiedLayerSpec('embed_student', + EmbeddingPipeStudent, + args.student_hidden_size, + args.padded_vocab_size, + args.hidden_dropout, + init_method=init_method, + num_tokentypes=num_tokentypes, + tied_weight_attr='word_embeddings_weight')) + + if args.fp32_residual_connection: + if getattr(args, 'pretrain_causal_attention', False): + self.specs.append(lambda x: x.transpose(0, 1).contiguous().float()) + else: + # EmbeddingPipe returns attention mask as well + self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous().float(), *x[1:])) + else: + if getattr(args, 'pretrain_causal_attention', False): + self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous(), x[1])) + else: + # EmbeddingPipe returns attention mask as well + self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous(), *x[1:])) + + for layer_idx in range(args.student_num_layers): + self.specs.append( + LayerSpec(ParallelTransformerLayerPipeStudent, + init_method=init_method, + output_layer_init_method=scaled_init_method_normal(args.init_method_std, + args.student_num_layers), + layer_number=layer_idx, + # TODO: Change naming of class from GPT to something that encapsulate prefix lm. + self_attn_mask_type=attn_mask_type)) + + # Undo data format change + # def undo(x): + # if not getattr(args, 'pretrain_causal_attention', False): + # x = x[0] + # return x.transpose(0, 1).contiguous() + # self.specs.append(undo) + + # Final layernorm after transformer layers + self.specs.append( + LayerSpec(LayerNormStudent, + args.student_hidden_size, + eps=args.layernorm_epsilon)) + + def _logits_helper(embedding, lm_output): + """A wrapper to massage inputs/outputs from pipeline. 
""" + return parallel_lm_logits( + lm_output, + embedding.word_embeddings_weight, + self.parallel_output) + + self.specs.append( + TiedLayerSpec('embed_student', + EmbeddingPipeStudent, + args.student_hidden_size, + args.padded_vocab_size, + args.hidden_dropout, + init_method=init_method, + num_tokentypes=num_tokentypes, + forward_fn=_logits_helper, + tied_weight_attr='word_embeddings_weight') + ) + + # Convert to fp32 if needed + if args.fp16 or args.bf16: + self.specs.append(float16_to_fp32) + + if args.checkpoint_activations: + interval = args.checkpoint_num_layers + else: + interval = 0 + # transformer layer if args.pp_partition_method is not None: partition_method = args.pp_partition_method else: partition_method = 'type:transformer' + + from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology + topo = PipeModelDataParallelTopology(num_pp=mpu.get_pipeline_model_parallel_world_size(), + num_mp=mpu.get_tensor_model_parallel_world_size(), + num_dp=mpu.get_data_parallel_world_size()) + + super().__init__(layers=self.specs, - loss_fn=get_cross_entropy(is_prefix=attn_mask_type is AttnMaskType.prefix), + loss_fn=get_ts_loss(is_prefix=attn_mask_type is AttnMaskType.prefix), topology=topo, activation_checkpoint_interval=interval, partition_method=partition_method) diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py index fc284431a..e33b39ef3 100644 --- a/megatron/model/language_model.py +++ b/megatron/model/language_model.py @@ -15,6 +15,7 @@ """Transformer based language model.""" +from importlib import invalidate_caches import torch import torch.nn.functional as F @@ -29,6 +30,11 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None): """LM logits using word embedding weights.""" + if isinstance(input_, tuple): + original_inputs = input_[1] + input_ = input_[0] + else: + original_inputs = None # Parallel logits. input_parallel = mpu.copy_to_tensor_model_parallel_region(input_) # Matrix multiply. @@ -38,9 +44,9 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias) # Gather if needed. if parallel_output: - return logits_parallel + return original_inputs, logits_parallel - return mpu.gather_from_tensor_model_parallel_region(logits_parallel) + return original_inputs, mpu.gather_from_tensor_model_parallel_region(logits_parallel) def get_language_model(num_tokentypes, add_pooler, @@ -274,6 +280,9 @@ def forward(self, inputs, **kwargs): input_ids = inputs[0] position_ids = inputs[1] + if isinstance(input_ids, tuple): + # print(input_ids) + input_ids = input_ids[0] if getattr(self._args, 'pretrain_causal_attention', False): attention_mask = None else: @@ -298,6 +307,16 @@ def word_embeddings_weight(self): """Easy accessory for the DeepSpeed pipeline engine to tie embeddings across stages.""" return self.word_embeddings.weight +class EmbeddingPipeTeacher(EmbeddingPipe): + @torch.no_grad() + def forward(self, inputs, **kwargs): + return (super().forward(inputs, **kwargs), inputs) + +class EmbeddingPipeStudent(EmbeddingPipe): + def forward(self, inputs, **kwargs): + inputs, logits_teacher = inputs + return (super().forward(inputs, **kwargs), logits_teacher) + class TransformerLanguageModel(MegatronModule): """Transformer language model. 
diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 03e6faaec..4ec55302a 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -65,18 +65,27 @@ class ParallelMLP(MegatronModule): applied. """ - def __init__(self, init_method, output_layer_init_method): + def __init__(self, init_method, output_layer_init_method, student_=False): super(ParallelMLP, self).__init__() args = get_args() # Project to ffn_hidden_size - self.dense_h_to_4h = mpu.ColumnParallelLinear( - args.hidden_size, - # GLU is a special activation that divides the dimension by a factor 2. - 2 * args.ffn_hidden_size if args.glu_activation else args.ffn_hidden_size, - gather_output=False, - init_method=init_method, - skip_bias_add=True) + if not student_: + self.dense_h_to_4h = mpu.ColumnParallelLinear( + args.hidden_size, + # GLU is a special activation that divides the dimension by a factor 2. + 2 * args.ffn_hidden_size if args.glu_activation else args.ffn_hidden_size, + gather_output=False, + init_method=init_method, + skip_bias_add=True) + else: + self.dense_h_to_4h = mpu.ColumnParallelLinear( + args.student_hidden_size, + # GLU is a special activation that divides the dimension by a factor 2. + 2 * args.student_ffn_hidden_size if args.glu_activation else args.student_ffn_hidden_size, + gather_output=False, + init_method=init_method, + skip_bias_add=True) self.bias_gelu_fusion = args.bias_gelu_fusion self.activation_func = F.gelu @@ -88,12 +97,20 @@ def __init__(self, init_method, output_layer_init_method): self.activation_func = erf_gelu # Project back to h. - self.dense_4h_to_h = mpu.RowParallelLinear( - args.ffn_hidden_size, - args.hidden_size, - input_is_parallel=True, - init_method=output_layer_init_method, - skip_bias_add=True) + if not student_: + self.dense_4h_to_h = mpu.RowParallelLinear( + args.ffn_hidden_size, + args.hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + skip_bias_add=True) + else: + self.dense_4h_to_h = mpu.RowParallelLinear( + args.student_ffn_hidden_size, + args.student_hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + skip_bias_add=True) def forward(self, hidden_states): @@ -123,7 +140,8 @@ class ParallelAttention(MegatronModule): def __init__(self, init_method, output_layer_init_method, layer_number, attention_type=AttnType.self_attn, - attn_mask_type=AttnMaskType.padding): + attn_mask_type=AttnMaskType.padding, + student_=False): super(ParallelAttention, self).__init__() args = get_args() self.fp16 = args.fp16 @@ -145,27 +163,27 @@ def __init__(self, init_method, self.hidden_size_per_partition = mpu.divide(projection_size, world_size) self.hidden_size_per_attention_head = mpu.divide( - projection_size, args.num_attention_heads) + projection_size, args.num_attention_heads if not student_ else args.student_num_attention_heads) self.num_attention_heads_per_partition = mpu.divide( - args.num_attention_heads, world_size) + args.num_attention_heads if not student_ else args.student_num_attention_heads, world_size) # Strided linear layer. 
if attention_type == AttnType.self_attn: self.query_key_value = mpu.ColumnParallelLinear( - args.hidden_size, + args.hidden_size if not student_ else args.student_hidden_size, 3 * projection_size, gather_output=False, init_method=init_method) else: assert attention_type == AttnType.cross_attn self.query = mpu.ColumnParallelLinear( - args.hidden_size, + args.hidden_size if not student_ else args.student_hidden_size, projection_size, gather_output=False, init_method=init_method) self.key_value = mpu.ColumnParallelLinear( - args.hidden_size, + args.hidden_size if not student_ else args.student_hidden_size, 2 * projection_size, gather_output=False, init_method=init_method) @@ -192,7 +210,7 @@ def __init__(self, init_method, # Output. self.dense = mpu.RowParallelLinear( projection_size, - args.hidden_size, + args.hidden_size if not student_ else args.student_hidden_size, input_is_parallel=True, init_method=output_layer_init_method, skip_bias_add=True) @@ -439,7 +457,8 @@ class ParallelTransformerLayer(MegatronModule): def __init__(self, init_method, output_layer_init_method, layer_number, layer_type=LayerType.encoder, - self_attn_mask_type=AttnMaskType.padding): + self_attn_mask_type=AttnMaskType.padding, + student_=False): args = get_args() super(ParallelTransformerLayer, self).__init__() @@ -463,7 +482,8 @@ def __init__(self, init_method, output_layer_init_method, output_layer_init_method, layer_number, attention_type=AttnType.self_attn, - attn_mask_type=self_attn_mask_type) + attn_mask_type=self_attn_mask_type, + student_=student_) self.hidden_dropout = args.hidden_dropout self.bias_dropout_fusion = args.bias_dropout_fusion @@ -477,10 +497,11 @@ def __init__(self, init_method, output_layer_init_method, init_method, output_layer_init_method, layer_number, - attention_type=AttnType.cross_attn) + attention_type=AttnType.cross_attn, + student_=student_) # Layernorm on the attention output. self.post_inter_attention_layernorm = LayerNorm( - args.hidden_size, + args.hidden_size if not student_ else args.student_hidden_size, eps=args.layernorm_epsilon) # MLP @@ -504,6 +525,8 @@ def forward(self, hidden_states, attention_mask, # Layer norm at the beginning of the transformer layer. layernorm_output = self.input_layernorm(hidden_states) + if isinstance(layernorm_output, tuple): + layernorm_output, _ = layernorm_output # Self attention. attention_output, attention_bias = \ self.self_attention(layernorm_output, @@ -521,6 +544,12 @@ def forward(self, hidden_states, attention_mask, else: residual = hidden_states + if isinstance(residual, tuple): + if len(residual) > 1: + residual, _ = residual + else: + residual = residual[0] + # jit scripting for a nn.module (with dropout) is not # trigerring the fusion kernel. For now, we use two # different nn.functional routines to account for varying @@ -647,6 +676,115 @@ def forward(self, inputs, **kwargs): else: raise RuntimeError('Received more inputs than understood.') +class ParallelTransformerLayerPipeTeacher(ParallelTransformerLayerPipe): + """Extends ParallelTransformerLayer to forward attention_mask through the pipeline. + + Forward has two usages that affect attention mask communication: + + 1) forward((input, attn_mask) , **kwargs) -> (output, mask) + When the attention mask is provided as the second positional + argument, typical pipeline behavior is used and both the output + *and* mask are returned in a tuple. This tuple is then forwarded + to the next stage in the pipeline. + + This version is useful if masks are dynamic. 
+ + 2) forward(input, **kwargs) -> output + When the mask is static over all samples, it is advantageous to + cache the mask and avoid communicating it. + """ + @torch.no_grad() + def forward(self, inputs, **kwargs): + input_ids = inputs[-1] + if isinstance(input_ids, tuple): + # input_ids = input_ids[0] + input_ids = input_ids + # print(self.layer_number, input_ids) + return (super().forward(inputs, **kwargs), input_ids) + +class ParallelTransformerLayerPipeStudent(ParallelTransformerLayerPipe): + """Extends ParallelTransformerLayer to forward attention_mask through the pipeline. + + Forward has two usages that affect attention mask communication: + + 1) forward((input, attn_mask) , **kwargs) -> (output, mask) + When the attention mask is provided as the second positional + argument, typical pipeline behavior is used and both the output + *and* mask are returned in a tuple. This tuple is then forwarded + to the next stage in the pipeline. + + This version is useful if masks are dynamic. + + 2) forward(input, **kwargs) -> output + When the mask is static over all samples, it is advantageous to + cache the mask and avoid communicating it. + """ + def __init__(self, init_method, output_layer_init_method, + layer_number, layer_type=LayerType.encoder, + self_attn_mask_type=AttnMaskType.padding): + args = get_args() + + super(ParallelTransformerLayer, self).__init__() + self.layer_number = layer_number + self.layer_type = layer_type + + self.apply_residual_connection_post_layernorm \ + = args.apply_residual_connection_post_layernorm + + self.bf16 = args.bf16 + self.fp32_residual_connection = args.fp32_residual_connection + + # Layernorm on the input data. + self.input_layernorm = LayerNorm( + args.student_hidden_size, + eps=args.layernorm_epsilon) + + # Self attention. + self.self_attention = ParallelAttention( + init_method, + output_layer_init_method, + layer_number, + attention_type=AttnType.self_attn, + attn_mask_type=self_attn_mask_type, + student_=True) + self.hidden_dropout = args.hidden_dropout + self.bias_dropout_fusion = args.bias_dropout_fusion + + # Layernorm on the attention output + self.post_attention_layernorm = LayerNorm( + args.student_hidden_size, + eps=args.layernorm_epsilon) + + if self.layer_type == LayerType.decoder: + self.inter_attention = ParallelAttention( + init_method, + output_layer_init_method, + layer_number, + attention_type=AttnType.cross_attn, + student_=True) + # Layernorm on the attention output. 
+ self.post_inter_attention_layernorm = LayerNorm( + args.student_hidden_size, + eps=args.layernorm_epsilon) + + # MLP + self.mlp = ParallelMLP(init_method, + output_layer_init_method, student_=True) + + # Alibi + if args.position_embedding_type == PositionEmbeddingType.alibi: + self.alibi = self._build_alibi_tensor(args.seq_length, args.student_num_attention_heads, args.micro_batch_size).to(torch.cuda.current_device()) + if args.params_dtype == torch.float16: + self.alibi = self.alibi.to(torch.float16) + elif args.params_dtype == torch.bfloat16: + self.alibi = self.alibi.to(torch.bfloat16) + else: + self.alibi = None + @torch.no_grad() + def forward(self, inputs, **kwargs): + logits_teacher = inputs[-1] + inputs = inputs[:-1] + return (super().forward(inputs, **kwargs), logits_teacher) class ParallelTransformer(MegatronModule): """Transformer class.""" diff --git a/megatron/optimizer/__init__.py b/megatron/optimizer/__init__.py index 738717d55..fe4dd9002 100644 --- a/megatron/optimizer/__init__.py +++ b/megatron/optimizer/__init__.py @@ -22,6 +22,17 @@ from .grad_scaler import ConstantGradScaler, DynamicGradScaler from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer +def _filter_for_teacher_student(modules): + trainable_modules = [] + + for module in modules: + for module_ in module.modules(): + if "Student" in module_.__class__.__name__: + trainable_modules.append(module_) + + return trainable_modules + + def _get_params_for_weight_decay_optimization(modules): """Divide params into with-weight-decay and without-weight-decay groups. @@ -30,6 +41,9 @@ def _get_params_for_weight_decay_optimization(modules): weight_decay_params = {'params': []} no_weight_decay_params = {'params': [], 'weight_decay': 0.0} + + modules = _filter_for_teacher_student(modules) + for module in modules: for module_ in module.modules(): if isinstance(module_, LayerNorm): From 720178407ee5e1bffaa178975d252b97c78da2b8 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 11 Oct 2022 18:59:34 +0200 Subject: [PATCH 2/5] cleaner version --- megatron/model/fused_layer_norm.py | 5 +++-- megatron/model/gpt_model.py | 30 +++++++++++++----------------- megatron/model/language_model.py | 12 +++++++----- megatron/model/transformer.py | 4 ++-- megatron/optimizer/__init__.py | 8 ++++++-- 5 files changed, 31 insertions(+), 28 deletions(-) diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py index 0372132a9..16053e8e5 100644 --- a/megatron/model/fused_layer_norm.py +++ b/megatron/model/fused_layer_norm.py @@ -114,10 +114,11 @@ def forward(self, input): return F.layer_norm(input, self.normalized_shape, self.weight, self.bias) class MixedFusedLayerNormTeacher(MixedFusedLayerNorm): - - @torch.no_grad() + # @torch.no_grad() def forward(self, input): input, original_input = input + print("input", input.shape) + print("original_input", original_input[0].shape) return (super().forward(input), original_input) class MixedFusedLayerNormStudent(MixedFusedLayerNorm): diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index bc837fa8d..30335464f 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -199,12 +199,12 @@ def CrossEntropy(output, labels): def get_ts_loss(is_prefix: bool): def TeacherStudentLoss(output, labels): - output, teacher_logits = output[0], output[1] + student_logits, teacher_logits = output labels, loss_mask = labels[0], labels[1] args = get_args() - losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(), labels) + 
losses = mpu.vocab_parallel_cross_entropy(student_logits.contiguous().float(), labels) if is_prefix: micro_batch_size, sequence_length = loss_mask.shape @@ -235,9 +235,9 @@ def TeacherStudentLoss(output, labels): teacher_logits = teacher_logits.detach() # First pass it on CPU - otherwise we get OOM errors softmax_labels = torch.nn.Softmax(dim=-1)(teacher_logits) - softmax_labels = softmax_labels.permute(1, 0, 2) - - student_log_softax = -torch.nn.LogSoftmax(dim=-1)(output) + # softmax_labels = softmax_labels.permute(1, 0, 2) + + student_log_softax = -torch.nn.LogSoftmax(dim=-1)(student_logits) # print(output.shape, teacher_logits.shape) # print(student_log_softax.shape, softmax_labels.shape) @@ -373,8 +373,6 @@ def _to_float16(inputs): else: return inputs - # self.specs.append(_to_float16) - self.specs.append(lambda x: (x[0], x[1])) # Embedding layer self.specs.append(TiedLayerSpec('embed_student', @@ -410,11 +408,11 @@ def _to_float16(inputs): self_attn_mask_type=attn_mask_type)) # Undo data format change - # def undo(x): - # if not getattr(args, 'pretrain_causal_attention', False): - # x = x[0] - # return x.transpose(0, 1).contiguous() - # self.specs.append(undo) + def undo(x): + if isinstance(x, tuple): + return (x[0].transpose(0, 1).contiguous(), x[1]) + return x.transpose(0, 1).contiguous() + self.specs.append(undo) # Final layernorm after transformer layers self.specs.append( @@ -422,12 +420,12 @@ def _to_float16(inputs): args.student_hidden_size, eps=args.layernorm_epsilon)) - def _logits_helper(embedding, lm_output): + def _logits_helper_student(embedding, lm_output): """A wrapper to massage inputs/outputs from pipeline. """ return parallel_lm_logits( lm_output, embedding.word_embeddings_weight, - self.parallel_output) + self.parallel_output, permute_output=True) self.specs.append( TiedLayerSpec('embed_student', @@ -437,7 +435,7 @@ def _logits_helper(embedding, lm_output): args.hidden_dropout, init_method=init_method, num_tokentypes=num_tokentypes, - forward_fn=_logits_helper, + forward_fn=_logits_helper_student, tied_weight_attr='word_embeddings_weight') ) @@ -461,8 +459,6 @@ def _logits_helper(embedding, lm_output): num_mp=mpu.get_tensor_model_parallel_world_size(), num_dp=mpu.get_data_parallel_world_size()) - - super().__init__(layers=self.specs, loss_fn=get_ts_loss(is_prefix=attn_mask_type is AttnMaskType.prefix), topology=topo, diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py index e33b39ef3..bb743c1fb 100644 --- a/megatron/model/language_model.py +++ b/megatron/model/language_model.py @@ -28,9 +28,10 @@ from megatron.model.utils import init_method_normal, scaled_init_method_normal def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, - bias=None): + bias=None, permute_output=False): """LM logits using word embedding weights.""" if isinstance(input_, tuple): + # retrieve the input tensor from the tuple original_inputs = input_[1] input_ = input_[0] else: @@ -42,11 +43,12 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, logits_parallel = F.linear(input_parallel, word_embeddings_weight) else: logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias) + # Gather if needed. 
if parallel_output: - return original_inputs, logits_parallel - - return original_inputs, mpu.gather_from_tensor_model_parallel_region(logits_parallel) + return (logits_parallel, original_inputs) if permute_output else (original_inputs, logits_parallel) + + return (mpu.gather_from_tensor_model_parallel_region(logits_parallel), original_inputs) if permute_output else (original_inputs, mpu.gather_from_tensor_model_parallel_region(logits_parallel)) def get_language_model(num_tokentypes, add_pooler, @@ -308,7 +310,7 @@ def word_embeddings_weight(self): return self.word_embeddings.weight class EmbeddingPipeTeacher(EmbeddingPipe): - @torch.no_grad() + # @torch.no_grad() def forward(self, inputs, **kwargs): return (super().forward(inputs, **kwargs), inputs) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 4ec55302a..9a83bb6a7 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -693,7 +693,7 @@ class ParallelTransformerLayerPipeTeacher(ParallelTransformerLayerPipe): When the mask is static over all samples, it is advantageous to cache the mask and avoid communicating it. """ - @torch.no_grad() + # @torch.no_grad() def forward(self, inputs, **kwargs): input_ids = inputs[-1] if isinstance(input_ids, tuple): @@ -780,7 +780,7 @@ def __init__(self, init_method, output_layer_init_method, self.alibi = self.alibi.to(torch.bfloat16) else: self.alibi = None - @torch.no_grad() + def forward(self, inputs, **kwargs): logits_teacher = inputs[-1] inputs = inputs[:-1] diff --git a/megatron/optimizer/__init__.py b/megatron/optimizer/__init__.py index fe4dd9002..6b99c2697 100644 --- a/megatron/optimizer/__init__.py +++ b/megatron/optimizer/__init__.py @@ -18,6 +18,9 @@ from megatron import get_args from megatron.model.fused_layer_norm import MixedFusedLayerNorm as LayerNorm +from megatron.model.fused_layer_norm import MixedFusedLayerNormTeacher, MixedFusedLayerNormStudent +from megatron.model.transformer import ParallelTransformerLayerPipeStudent, ParallelTransformerLayerPipeTeacher +from megatron.model.language_model import EmbeddingPipeStudent, EmbeddingPipeTeacher from .grad_scaler import ConstantGradScaler, DynamicGradScaler from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer @@ -27,7 +30,8 @@ def _filter_for_teacher_student(modules): for module in modules: for module_ in module.modules(): - if "Student" in module_.__class__.__name__: + # if "Student" in module_.__class__.__name__: + if isinstance(module, (ParallelTransformerLayerPipeStudent, EmbeddingPipeStudent, MixedFusedLayerNormStudent)): trainable_modules.append(module_) return trainable_modules @@ -46,7 +50,7 @@ def _get_params_for_weight_decay_optimization(modules): for module in modules: for module_ in module.modules(): - if isinstance(module_, LayerNorm): + if isinstance(module_, MixedFusedLayerNormStudent): no_weight_decay_params['params'].extend( [p for p in list(module_._parameters.values()) if p is not None]) From 3aea1aaf419c57d515bb94ca50ac92cb4e900e8d Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 12 Oct 2022 12:46:25 +0200 Subject: [PATCH 3/5] add few modifs --- megatron/model/fused_layer_norm.py | 17 ++++++++++------- megatron/model/transformer.py | 8 +++++--- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py index 16053e8e5..ae261f950 100644 --- a/megatron/model/fused_layer_norm.py +++ b/megatron/model/fused_layer_norm.py @@ -94,7 +94,6 @@ def __init__(self, 
normalized_shape, eps=1e-5): or version.parse(torch.__version__) >= version.parse("1.11.0") # https://github.com/pytorch/pytorch/pull/66920 ) - def reset_parameters(self): init.ones_(self.weight) @@ -116,12 +115,16 @@ def forward(self, input): class MixedFusedLayerNormTeacher(MixedFusedLayerNorm): # @torch.no_grad() def forward(self, input): - input, original_input = input - print("input", input.shape) - print("original_input", original_input[0].shape) - return (super().forward(input), original_input) + if len(input) ==2: + input, original_input = input + return (super().forward(input), original_input) + else: + return super().forward(input) class MixedFusedLayerNormStudent(MixedFusedLayerNorm): def forward(self, input): - input, logits_teacher = input - return (super().forward(input), logits_teacher) \ No newline at end of file + if len(input) == 2: + input, logits_teacher = input + return (super().forward(input), logits_teacher) + else: + return super().forward(input) \ No newline at end of file diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 9a83bb6a7..2f10ea0a9 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -23,6 +23,8 @@ from megatron import mpu from .module import MegatronModule from megatron.enums import AttnMaskType, LayerType, AttnType, PositionEmbeddingType +from megatron.model.fused_layer_norm import MixedFusedLayerNormStudent as LayerNormStudent +# from megatron.model.fused_layer_norm import MixedFusedLayerNorm as LayerNorm from megatron.model.fused_layer_norm import MixedFusedLayerNorm as LayerNorm from megatron.model.fused_softmax import FusedScaleMaskSoftmax from megatron.model.fused_bias_gelu import bias_gelu_impl @@ -735,7 +737,7 @@ def __init__(self, init_method, output_layer_init_method, self.fp32_residual_connection = args.fp32_residual_connection # Layernorm on the input data. - self.input_layernorm = LayerNorm( + self.input_layernorm = LayerNormStudent( args.student_hidden_size, eps=args.layernorm_epsilon) @@ -751,7 +753,7 @@ def __init__(self, init_method, output_layer_init_method, self.bias_dropout_fusion = args.bias_dropout_fusion # Layernorm on the attention output - self.post_attention_layernorm = LayerNorm( + self.post_attention_layernorm = LayerNormStudent( args.student_hidden_size, eps=args.layernorm_epsilon) @@ -763,7 +765,7 @@ def __init__(self, init_method, output_layer_init_method, attention_type=AttnType.cross_attn, student_=True) # Layernorm on the attention output. 
- self.post_inter_attention_layernorm = LayerNorm( + self.post_inter_attention_layernorm = LayerNormStudent( args.student_hidden_size, eps=args.layernorm_epsilon) From 2d8bea2e087b283ab6d72deb05f5205d85553d75 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 12 Oct 2022 13:33:53 +0200 Subject: [PATCH 4/5] working v1 without pp --- megatron/model/fused_layer_norm.py | 17 +++++++++-------- megatron/model/gpt_model.py | 4 +--- megatron/model/language_model.py | 2 +- megatron/model/transformer.py | 16 ++++++---------- megatron/optimizer/__init__.py | 7 +++---- megatron/training.py | 2 ++ 6 files changed, 22 insertions(+), 26 deletions(-) diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py index ae261f950..4dac97ec5 100644 --- a/megatron/model/fused_layer_norm.py +++ b/megatron/model/fused_layer_norm.py @@ -106,25 +106,26 @@ def forward(self, input): torch.distributed.all_reduce(self.weight, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group()) torch.distributed.all_reduce(self.bias, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group()) - if self.use_meg_ds_fused_layer_norm: + # if self.use_meg_ds_fused_layer_norm: + if False: return FusedLayerNormAffineFunction.apply( input, self.weight, self.bias, self.normalized_shape, self.eps) else: return F.layer_norm(input, self.normalized_shape, self.weight, self.bias) class MixedFusedLayerNormTeacher(MixedFusedLayerNorm): - # @torch.no_grad() + @torch.no_grad() def forward(self, input): - if len(input) ==2: - input, original_input = input - return (super().forward(input), original_input) + if isinstance(input, tuple): + input, *original_input = input + return (super().forward(input), *original_input) else: return super().forward(input) class MixedFusedLayerNormStudent(MixedFusedLayerNorm): def forward(self, input): - if len(input) == 2: - input, logits_teacher = input - return (super().forward(input), logits_teacher) + if isinstance(input, tuple): + input, *logits_teacher = input + return (super().forward(input), *logits_teacher) else: return super().forward(input) \ No newline at end of file diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index 30335464f..5f61e8fc4 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -309,10 +309,8 @@ def _to_float16(inputs): # Undo data format change def undo(x): - # if not getattr(args, 'pretrain_causal_attention', False): - # x = x[0] if isinstance(x, tuple): - return (x[0][0].transpose(0, 1).contiguous(), x[1]) + return (x[0].transpose(0, 1).contiguous(), *x[1:]) return x.transpose(0, 1).contiguous() self.specs.append(undo) diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py index bb743c1fb..6e99e68ad 100644 --- a/megatron/model/language_model.py +++ b/megatron/model/language_model.py @@ -310,7 +310,7 @@ def word_embeddings_weight(self): return self.word_embeddings.weight class EmbeddingPipeTeacher(EmbeddingPipe): - # @torch.no_grad() + @torch.no_grad() def forward(self, inputs, **kwargs): return (super().forward(inputs, **kwargs), inputs) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 2f10ea0a9..e1d15e97f 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -24,7 +24,6 @@ from .module import MegatronModule from megatron.enums import AttnMaskType, LayerType, AttnType, PositionEmbeddingType from megatron.model.fused_layer_norm import MixedFusedLayerNormStudent as LayerNormStudent -# from 
megatron.model.fused_layer_norm import MixedFusedLayerNorm as LayerNorm from megatron.model.fused_layer_norm import MixedFusedLayerNorm as LayerNorm from megatron.model.fused_softmax import FusedScaleMaskSoftmax from megatron.model.fused_bias_gelu import bias_gelu_impl @@ -695,14 +694,11 @@ class ParallelTransformerLayerPipeTeacher(ParallelTransformerLayerPipe): When the mask is static over all samples, it is advantageous to cache the mask and avoid communicating it. """ - # @torch.no_grad() + @torch.no_grad() def forward(self, inputs, **kwargs): - input_ids = inputs[-1] - if isinstance(input_ids, tuple): - # input_ids = input_ids[0] - input_ids = input_ids + # input_ids = inputs[-1] # print(self.layer_number, input_ids) - return (super().forward(inputs, **kwargs), input_ids) + return (super().forward(inputs[0], **kwargs), *inputs[1:]) class ParallelTransformerLayerPipeStudent(ParallelTransformerLayerPipe): """Extends ParallelTransformerLayer to forward attention_mask through the pipeline. @@ -784,9 +780,9 @@ def __init__(self, init_method, output_layer_init_method, self.alibi = None def forward(self, inputs, **kwargs): - logits_teacher = inputs[-1] - inputs = inputs[:-1] - return (super().forward(inputs, **kwargs), logits_teacher) + # logits_teacher = inputs[-1] + # inputs = inputs[:-1] + return (super().forward(inputs[0], **kwargs), *inputs[1:]) class ParallelTransformer(MegatronModule): """Transformer class.""" diff --git a/megatron/optimizer/__init__.py b/megatron/optimizer/__init__.py index 6b99c2697..68c74fea9 100644 --- a/megatron/optimizer/__init__.py +++ b/megatron/optimizer/__init__.py @@ -30,11 +30,10 @@ def _filter_for_teacher_student(modules): for module in modules: for module_ in module.modules(): - # if "Student" in module_.__class__.__name__: - if isinstance(module, (ParallelTransformerLayerPipeStudent, EmbeddingPipeStudent, MixedFusedLayerNormStudent)): + # TODO: this is empty ??? 
+ if isinstance(module_, (ParallelTransformerLayerPipeStudent, EmbeddingPipeStudent, MixedFusedLayerNormStudent)): trainable_modules.append(module_) - - return trainable_modules + return modules diff --git a/megatron/training.py b/megatron/training.py index bd00bc77e..f6d470d37 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -953,6 +953,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler, params_norm = None if args.log_params_norm: params_norm = calc_params_l2_norm(model) + + # raise NotImplementedError(optimizer.param_groups) report_memory_flag = training_log(loss_dict, total_loss_dict, optimizer.param_groups[0]['lr'], iteration, loss_scale, From 96f97e12ae71a2d828a3aab8215270825642ad7d Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 19 Oct 2022 14:56:54 +0200 Subject: [PATCH 5/5] v2 --- megatron/model/fused_layer_norm.py | 28 ++++++++++++++-------------- megatron/model/gpt_model.py | 28 +++++++++++++++------------- megatron/model/language_model.py | 5 ++--- megatron/model/transformer.py | 11 +++++------ megatron/optimizer/__init__.py | 8 +++++--- megatron/training.py | 3 ++- 6 files changed, 43 insertions(+), 40 deletions(-) diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py index 4dac97ec5..13a108096 100644 --- a/megatron/model/fused_layer_norm.py +++ b/megatron/model/fused_layer_norm.py @@ -101,31 +101,31 @@ def reset_parameters(self): def forward(self, input): + if isinstance(input, tuple): + input = input[0] if self.layernorm_tp_auto_sync: torch.distributed.all_reduce(self.weight, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group()) torch.distributed.all_reduce(self.bias, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group()) - # if self.use_meg_ds_fused_layer_norm: - if False: + if self.use_meg_ds_fused_layer_norm: + #if False: return FusedLayerNormAffineFunction.apply( input, self.weight, self.bias, self.normalized_shape, self.eps) else: return F.layer_norm(input, self.normalized_shape, self.weight, self.bias) -class MixedFusedLayerNormTeacher(MixedFusedLayerNorm): - @torch.no_grad() +class MixedFusedLayerNormTeacher(MixedFusedLayerNorm): + # @torch.no_grad() def forward(self, input): - if isinstance(input, tuple): - input, *original_input = input - return (super().forward(input), *original_input) - else: - return super().forward(input) + # if isinstance(input, tuple): + input, *original_input = input + # return (super().forward(input), *original_input) + # else: + # return super().forward(input) + + return super().forward(input), *original_input class MixedFusedLayerNormStudent(MixedFusedLayerNorm): def forward(self, input): - if isinstance(input, tuple): - input, *logits_teacher = input - return (super().forward(input), *logits_teacher) - else: - return super().forward(input) \ No newline at end of file + return (super().forward(input), input[1]) \ No newline at end of file diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index 5f61e8fc4..730a7faed 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -200,6 +200,8 @@ def CrossEntropy(output, labels): def get_ts_loss(is_prefix: bool): def TeacherStudentLoss(output, labels): student_logits, teacher_logits = output + if isinstance(teacher_logits, tuple): + teacher_logits = teacher_logits[0] labels, loss_mask = labels[0], labels[1] args = get_args() @@ -232,15 +234,12 @@ def TeacherStudentLoss(output, labels): loss = torch.sum(losses.view(-1) * loss_mask) / 
expected_number_of_tokens # TODO: check if the formula is correct - teacher_logits = teacher_logits.detach() + # teacher_logits = teacher_logits.detach() # First pass it on CPU - otherwise we get OOM errors - softmax_labels = torch.nn.Softmax(dim=-1)(teacher_logits) - # softmax_labels = softmax_labels.permute(1, 0, 2) - - student_log_softax = -torch.nn.LogSoftmax(dim=-1)(student_logits) + # teacher_logits = teacher_logits.detach() + softmax_labels = torch.nn.Softmax(dim=-1)(teacher_logits.contiguous().float()) + student_log_softax = -torch.nn.LogSoftmax(dim=-1)(student_logits.contiguous().float()) - # print(output.shape, teacher_logits.shape) - # print(student_log_softax.shape, softmax_labels.shape) softmax_logits = student_log_softax * softmax_labels logits_loss = softmax_logits.mean() @@ -275,7 +274,7 @@ def _to_float16(inputs): self.specs.append(_to_float16) # Embedding layer - self.specs.append(TiedLayerSpec('embed', + self.specs.append(TiedLayerSpec('embed_teacher', EmbeddingPipeTeacher, args.hidden_size, args.padded_vocab_size, @@ -283,7 +282,7 @@ def _to_float16(inputs): init_method=init_method, num_tokentypes=num_tokentypes, tied_weight_attr='word_embeddings_weight')) - + if args.fp32_residual_connection: if getattr(args, 'pretrain_causal_attention', False): self.specs.append(lambda x: x.transpose(0, 1).contiguous().float()) @@ -292,7 +291,8 @@ def _to_float16(inputs): self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous().float(), *x[1:])) else: if getattr(args, 'pretrain_causal_attention', False): - self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous(), x[1])) + self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous(), *x[1:])) + # self.specs.append(lambda x: (x.transpose(0, 1).contiguous(), *x[1:])) else: # EmbeddingPipe returns attention mask as well self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous(), *x[1:])) @@ -310,7 +310,8 @@ def _to_float16(inputs): # Undo data format change def undo(x): if isinstance(x, tuple): - return (x[0].transpose(0, 1).contiguous(), *x[1:]) + return (x[0].transpose(0, 1).contiguous(), (x[1:])) + # return (x[0].transpose(0, 1).contiguous(), *x[1:]) return x.transpose(0, 1).contiguous() self.specs.append(undo) @@ -328,7 +329,7 @@ def _logits_helper(embedding, lm_output): self.parallel_output) self.specs.append( - TiedLayerSpec('embed', + TiedLayerSpec('embed_teacher', EmbeddingPipeTeacher, args.hidden_size, args.padded_vocab_size, @@ -408,7 +409,7 @@ def _to_float16(inputs): # Undo data format change def undo(x): if isinstance(x, tuple): - return (x[0].transpose(0, 1).contiguous(), x[1]) + return (x[0].transpose(0, 1).contiguous(), x[1:]) return x.transpose(0, 1).contiguous() self.specs.append(undo) @@ -424,6 +425,7 @@ def _logits_helper_student(embedding, lm_output): lm_output, embedding.word_embeddings_weight, self.parallel_output, permute_output=True) + self.specs.append( TiedLayerSpec('embed_student', diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py index 6e99e68ad..001d19709 100644 --- a/megatron/model/language_model.py +++ b/megatron/model/language_model.py @@ -279,7 +279,6 @@ class EmbeddingPipe(Embedding): def forward(self, inputs, **kwargs): if not hasattr(self, '_args'): self._args = get_args() - input_ids = inputs[0] position_ids = inputs[1] if isinstance(input_ids, tuple): @@ -310,9 +309,9 @@ def word_embeddings_weight(self): return self.word_embeddings.weight class EmbeddingPipeTeacher(EmbeddingPipe): - @torch.no_grad() + # @torch.no_grad() def forward(self, inputs, 
**kwargs): - return (super().forward(inputs, **kwargs), inputs) + return (super().forward(inputs, **kwargs), *inputs) class EmbeddingPipeStudent(EmbeddingPipe): def forward(self, inputs, **kwargs): diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index e1d15e97f..78400ad22 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -597,6 +597,9 @@ def forward(self, hidden_states, attention_mask, layernorm_output = self.post_inter_attention_layernorm(layernorm_input) # MLP. + # print("========", layernorm_output) + if isinstance(layernorm_output, tuple): + layernorm_output, _ = layernorm_output mlp_output, mlp_bias = self.mlp(layernorm_output) # Second residual connection. @@ -694,10 +697,8 @@ class ParallelTransformerLayerPipeTeacher(ParallelTransformerLayerPipe): When the mask is static over all samples, it is advantageous to cache the mask and avoid communicating it. """ - @torch.no_grad() + # @torch.no_grad() def forward(self, inputs, **kwargs): - # input_ids = inputs[-1] - # print(self.layer_number, input_ids) return (super().forward(inputs[0], **kwargs), *inputs[1:]) class ParallelTransformerLayerPipeStudent(ParallelTransformerLayerPipe): @@ -780,9 +781,7 @@ def __init__(self, init_method, output_layer_init_method, self.alibi = None def forward(self, inputs, **kwargs): - # logits_teacher = inputs[-1] - # inputs = inputs[:-1] - return (super().forward(inputs[0], **kwargs), *inputs[1:]) + return (super().forward(inputs[0], **kwargs), inputs[1]) class ParallelTransformer(MegatronModule): """Transformer class.""" diff --git a/megatron/optimizer/__init__.py b/megatron/optimizer/__init__.py index 68c74fea9..31f1e3a3f 100644 --- a/megatron/optimizer/__init__.py +++ b/megatron/optimizer/__init__.py @@ -29,11 +29,13 @@ def _filter_for_teacher_student(modules): trainable_modules = [] for module in modules: - for module_ in module.modules(): + # for module_ in module.modules(): + for module_ in module.children(): # TODO: this is empty ??? if isinstance(module_, (ParallelTransformerLayerPipeStudent, EmbeddingPipeStudent, MixedFusedLayerNormStudent)): trainable_modules.append(module_) - return modules + # return modules + return trainable_modules @@ -45,7 +47,7 @@ def _get_params_for_weight_decay_optimization(modules): weight_decay_params = {'params': []} no_weight_decay_params = {'params': [], 'weight_decay': 0.0} - modules = _filter_for_teacher_student(modules) + # modules = _filter_for_teacher_student(modules) for module in modules: for module_ in module.modules(): diff --git a/megatron/training.py b/megatron/training.py index f6d470d37..d45eabc3c 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -955,8 +955,9 @@ def train(forward_step_func, model, optimizer, lr_scheduler, params_norm = calc_params_l2_norm(model) # raise NotImplementedError(optimizer.param_groups) + report_memory_flag = training_log(loss_dict, total_loss_dict, - optimizer.param_groups[0]['lr'], + optimizer.param_groups[0]['lr'] if len(optimizer.param_groups) > 0 else 0.0, iteration, loss_scale, report_memory_flag, skipped_iter, grad_norm, params_norm, num_zeros_in_grad,
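
Notes on the series.

How the teacher logits reach the loss: each pipeline stage is duplicated into a *Teacher and a *Student variant (EmbeddingPipeTeacher/Student, ParallelTransformerLayerPipeTeacher/Student, MixedFusedLayerNormTeacher/Student), and each variant runs the original layer on the first element of the incoming tuple while passing the remaining elements straight through. The teacher stages carry the original embedding inputs forward so the student embedding can re-consume them, and the student stages carry the teacher logits forward so the loss function finally receives (student_logits, teacher_logits). A minimal sketch of that tuple-threading pattern, assuming a plain nn.Sequential with toy nn.Linear stages in place of DeepSpeed's PipelineModule and the real Megatron layers:

    # Illustration only: thread an extra tensor through sequential stages the
    # way the *Teacher/*Student pipe wrappers in this series do.
    import torch
    import torch.nn as nn

    class CarryThrough(nn.Module):
        """Run `inner` on the first tuple element, forward the rest untouched."""

        def __init__(self, inner: nn.Module):
            super().__init__()
            self.inner = inner

        def forward(self, inputs):
            hidden, *extras = inputs
            return (self.inner(hidden), *extras)

    if __name__ == "__main__":
        teacher = nn.Sequential(CarryThrough(nn.Linear(16, 16)),
                                CarryThrough(nn.Linear(16, 16)))
        student = nn.Sequential(CarryThrough(nn.Linear(16, 16)))

        x = torch.randn(4, 16)
        # Teacher stages carry the original input; the final teacher output
        # plays the role of the "teacher logits" handed to the student stages.
        teacher_out, _ = teacher((x, x))
        student_out, teacher_logits = student((x, teacher_out))
        print(student_out.shape, teacher_logits.shape)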
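What the resulting objective computes: get_ts_loss replaces get_cross_entropy as the pipeline loss and adds, on top of the usual masked language-model cross-entropy over the student logits, a soft-target term -softmax(teacher) * log_softmax(student) reduced with .mean(). The sketch below restates the patch-5 TeacherStudentLoss in plain PyTorch as a reference only, under stated assumptions: F.cross_entropy stands in for the tensor-parallel mpu.vocab_parallel_cross_entropy, the prefix-LM "expected number of tokens" normalisation is reduced to a plain mask sum, a [batch, seq, vocab] layout is assumed, and the teacher probabilities are detached here even though the final commit leaves that line commented out (the in-code "TODO: check if the formula is correct" still stands).

    # Reference sketch of the distillation objective assembled in get_ts_loss.
    import torch
    import torch.nn.functional as F

    def teacher_student_loss(student_logits: torch.Tensor,
                             teacher_logits: torch.Tensor,
                             labels: torch.Tensor,
                             loss_mask: torch.Tensor) -> torch.Tensor:
        """Masked LM cross-entropy on the student plus a soft-target term.

        Assumed shapes: logits [batch, seq, vocab]; labels, loss_mask [batch, seq].
        """
        # 1) Ordinary LM cross-entropy on the student, normalised by the number
        #    of tokens the mask keeps (the non-prefix branch of the patch).
        lm_losses = F.cross_entropy(
            student_logits.transpose(1, 2).float(),  # -> [batch, vocab, seq]
            labels,
            reduction="none",
        )                                             # -> [batch, seq]
        lm_loss = (lm_losses * loss_mask).sum() / loss_mask.sum()

        # 2) Soft-target term: -p_teacher * log p_student, averaged over every
        #    element exactly as the patch's .mean() does.
        teacher_probs = F.softmax(teacher_logits.float(), dim=-1).detach()
        student_log_probs = F.log_softmax(student_logits.float(), dim=-1)
        distill_loss = -(teacher_probs * student_log_probs).mean()

        return lm_loss + distill_loss

    if __name__ == "__main__":
        batch, seq, vocab = 2, 8, 50
        student = torch.randn(batch, seq, vocab, requires_grad=True)
        teacher = torch.randn(batch, seq, vocab)
        labels = torch.randint(0, vocab, (batch, seq))
        mask = torch.ones(batch, seq)
        loss = teacher_student_loss(student, teacher, labels, mask)
        loss.backward()
        print(loss.item())

Note that because the soft term is averaged over the vocabulary dimension as well, it equals the conventional per-token soft cross-entropy divided by the vocabulary size, and no temperature or weighting coefficient is applied to balance it against the hard cross-entropy.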