diff --git a/examples/fairseq/tasks/data/utils.py b/examples/fairseq/tasks/data/utils.py
index 9511061b..d8c0f0c0 100644
--- a/examples/fairseq/tasks/data/utils.py
+++ b/examples/fairseq/tasks/data/utils.py
@@ -62,7 +62,7 @@ def close(self):
         pass
 
 
-class WeightIterator(object):
+class WeightIterator:
     def __init__(self, weights, seed):
         self.weights = weights
         self.seed = seed
diff --git a/torchscale/architecture/config.py b/torchscale/architecture/config.py
index 0d2e9bee..74fa3f92 100644
--- a/torchscale/architecture/config.py
+++ b/torchscale/architecture/config.py
@@ -2,14 +2,8 @@
 # Licensed under The MIT License [see LICENSE for details]
 
 
-class EncoderConfig(object):
+class Config:
     def __init__(self, **kwargs):
-        self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
-        self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12)
-        self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
-        self.encoder_layers = kwargs.pop("encoder_layers", 12)
-        self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
-        self.normalize_output = kwargs.pop("normalize_output", True)
         self.activation_fn = kwargs.pop("activation_fn", "gelu")
         self.dropout = kwargs.pop("dropout", 0.0)
         self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
@@ -35,31 +29,15 @@ def __init__(self, **kwargs):
         self.subln = kwargs.pop("subln", True)
         self.bert_init = kwargs.pop("bert_init", False)
         self.multiway = kwargs.pop("multiway", False)
-        self.share_encoder_input_output_embed = kwargs.pop(
-            "share_encoder_input_output_embed", False
-        )
-        self.max_source_positions = kwargs.pop("max_source_positions", 1024)
         self.no_output_layer = kwargs.pop("no_output_layer", False)
         self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
-        # Text
         self.vocab_size = kwargs.pop("vocab_size", -1)
-        # Vision
-        self.img_size = kwargs.pop("img_size", 224)
-        self.patch_size = kwargs.pop("patch_size", 16)
-        self.in_chans = kwargs.pop("in_chans", 3)
-        # Fairscale
         self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
         self.fsdp = kwargs.pop("fsdp", False)
         self.ddp_rank = kwargs.pop("ddp_rank", 0)
         self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
         self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
-        if self.deepnorm:
-            self.encoder_normalize_before = False
-            self.subln = False
-        if self.subln:
-            self.encoder_normalize_before = True
-            self.deepnorm = False
         if self.use_xmoe:
             self.moe_normalize_gate_prob_before_dropping = True
             self.moe_second_expert_policy = "random"
             assert self.moe_freq > 0 and self.moe_expert_count > 0
@@ -71,138 +49,58 @@ def override(self, args):
                 self.__dict__[hp] = getattr(args, hp, None)
 
 
-class DecoderConfig(object):
+class EncoderConfig(Config):
     def __init__(self, **kwargs):
-        self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
-        self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12)
-        self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072)
-        self.decoder_layers = kwargs.pop("decoder_layers", 12)
-        self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
-        self.activation_fn = kwargs.pop("activation_fn", "gelu")
-        self.dropout = kwargs.pop("dropout", 0.0)
-        self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
-        self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
-        self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
-        self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
-        self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
-        self.moe_freq = kwargs.pop("moe_freq", 0)
-        self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
-        self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
-        self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
-        self.moe_eval_capacity_token_fraction = kwargs.pop(
-            "moe_eval_capacity_token_fraction", 0.25
-        )
-        self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
-        self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
-            "moe_normalize_gate_prob_before_dropping", False
-        )
-        self.use_xmoe = kwargs.pop("use_xmoe", False)
+        super().__init__(**kwargs)
+        self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
+        self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12)
+        self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
+        self.encoder_layers = kwargs.pop("encoder_layers", 12)
+        self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
+        self.normalize_output = kwargs.pop("normalize_output", True)
         self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
-        self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
-        self.deepnorm = kwargs.pop("deepnorm", False)
-        self.subln = kwargs.pop("subln", True)
-        self.bert_init = kwargs.pop("bert_init", False)
-        self.multiway = kwargs.pop("multiway", False)
-        self.share_decoder_input_output_embed = kwargs.pop(
-            "share_decoder_input_output_embed", False
+        self.share_encoder_input_output_embed = kwargs.pop(
+            "share_encoder_input_output_embed", False
         )
-        self.max_target_positions = kwargs.pop("max_target_positions", 1024)
-        self.no_output_layer = kwargs.pop("no_output_layer", False)
-        self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
-        # Text
-        self.vocab_size = kwargs.pop("vocab_size", -1)
-        # Fairscale
-        self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
-        self.fsdp = kwargs.pop("fsdp", False)
-        self.ddp_rank = kwargs.pop("ddp_rank", 0)
-        self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
-        self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
+        self.max_source_positions = kwargs.pop("max_source_positions", 1024)
+        # Vision
+        self.img_size = kwargs.pop("img_size", 224)
+        self.patch_size = kwargs.pop("patch_size", 16)
+        self.in_chans = kwargs.pop("in_chans", 3)
+
         if self.deepnorm:
-            self.decoder_normalize_before = False
+            self.encoder_normalize_before = False
             self.subln = False
         if self.subln:
-            self.decoder_normalize_before = True
+            self.encoder_normalize_before = True
             self.deepnorm = False
-        if self.use_xmoe:
-            self.moe_normalize_gate_prob_before_dropping = True
-            self.moe_second_expert_policy = "random"
-            assert self.moe_freq > 0 and self.moe_expert_count > 0
+
 
-    def override(self, args):
-        for hp in self.__dict__.keys():
-            if getattr(args, hp, None) is not None:
-                self.__dict__[hp] = getattr(args, hp, None)
-
-
-class EncoderDecoderConfig(object):
+class DecoderConfig(Config):
     def __init__(self, **kwargs):
-        self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
-        self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12)
-        self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
-        self.encoder_layers = kwargs.pop("encoder_layers", 12)
-        self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
+        super().__init__(**kwargs)
         self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
         self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12)
         self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072)
         self.decoder_layers = kwargs.pop("decoder_layers", 12)
         self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
-        self.activation_fn = kwargs.pop("activation_fn", "gelu")
-        self.dropout = kwargs.pop("dropout", 0.0)
-        self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
-        self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
-        self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
-        self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
-        self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
-        self.moe_freq = kwargs.pop("moe_freq", 0)
-        self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
-        self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
-        self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
-        self.moe_eval_capacity_token_fraction = kwargs.pop(
-            "moe_eval_capacity_token_fraction", 0.25
-        )
-        self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
-        self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
-            "moe_normalize_gate_prob_before_dropping", False
-        )
-        self.use_xmoe = kwargs.pop("use_xmoe", False)
         self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
-        self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
-        self.deepnorm = kwargs.pop("deepnorm", False)
-        self.subln = kwargs.pop("subln", True)
-        self.bert_init = kwargs.pop("bert_init", False)
-        self.multiway = kwargs.pop("multiway", False)
-        self.share_all_embeddings = kwargs.pop("share_all_embeddings", False)
         self.share_decoder_input_output_embed = kwargs.pop(
             "share_decoder_input_output_embed", False
         )
-        self.max_source_positions = kwargs.pop("max_source_positions", 1024)
         self.max_target_positions = kwargs.pop("max_target_positions", 1024)
-        self.no_output_layer = kwargs.pop("no_output_layer", False)
-        self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
-        # Text
-        self.vocab_size = kwargs.pop("vocab_size", -1)
-        # Fairscale
-        self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
-        self.fsdp = kwargs.pop("fsdp", False)
-        self.ddp_rank = kwargs.pop("ddp_rank", 0)
-        self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
-        self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
         if self.deepnorm:
-            self.encoder_normalize_before = False
             self.decoder_normalize_before = False
             self.subln = False
         if self.subln:
-            self.encoder_normalize_before = True
             self.decoder_normalize_before = True
             self.deepnorm = False
-        if self.use_xmoe:
-            self.moe_normalize_gate_prob_before_dropping = True
-            self.moe_second_expert_policy = "random"
-            assert self.moe_freq > 0 and self.moe_expert_count > 0
 
-    def override(self, args):
-        for hp in self.__dict__.keys():
-            if getattr(args, hp, None) is not None:
-                self.__dict__[hp] = getattr(args, hp, None)
+
+class EncoderDecoderConfig(EncoderConfig, DecoderConfig):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.share_all_embeddings = kwargs.pop("share_all_embeddings", False)
+
diff --git a/torchscale/architecture/decoder.py b/torchscale/architecture/decoder.py
index ed407b06..e3288347 100644
--- a/torchscale/architecture/decoder.py
+++ b/torchscale/architecture/decoder.py
@@ -9,12 +9,14 @@
 from fairscale.nn import checkpoint_wrapper, wrap
 
 from torchscale.architecture.utils import init_bert_params
+from torchscale.architecture.config import DecoderConfig, EncoderConfig
 from torchscale.component.droppath import DropPath
 from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
 from torchscale.component.multihead_attention import MultiheadAttention
 from torchscale.component.relative_position_bias import RelativePositionBias
 from torchscale.component.xmoe.moe_layer import MOELayer
 from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
+
 try:
     from apex.normalization import FusedLayerNorm as LayerNorm
 except ModuleNotFoundError:
@@ -23,7 +25,7 @@
 class DecoderLayer(nn.Module):
     def __init__(
         self,
-        args,
+        args: DecoderConfig,
         depth,
         is_moe_layer=False,
         is_encoder_decoder=False,
@@ -209,7 +211,7 @@ def forward(
 class Decoder(nn.Module):
     def __init__(
         self,
-        args,
+        args: DecoderConfig,
         embed_tokens=None,
         embed_positions=None,
         output_projection=None,
diff --git a/torchscale/architecture/encoder.py b/torchscale/architecture/encoder.py
index 62ab174f..c20fa65c 100644
--- a/torchscale/architecture/encoder.py
+++ b/torchscale/architecture/encoder.py
@@ -13,6 +13,7 @@
 from torch.nn import LayerNorm
 
 from torchscale.architecture.utils import init_bert_params
+from torchscale.architecture.config import EncoderConfig
 from torchscale.component.droppath import DropPath
 from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
 from torchscale.component.multihead_attention import MultiheadAttention
@@ -23,7 +24,11 @@
 
 
 class EncoderLayer(nn.Module):
-    def __init__(self, args, depth, is_moe_layer=False, is_encoder_decoder=False):
+    def __init__(self,
+                 args: EncoderConfig,
+                 depth,
+                 is_moe_layer: bool = False,
+                 is_encoder_decoder: bool = False):
         super().__init__()
         self.args = args
         self.embed_dim = args.encoder_embed_dim
@@ -165,11 +170,11 @@ def forward(self, x, encoder_padding_mask, attn_mask=None, rel_pos=None, multiwa
 class Encoder(nn.Module):
     def __init__(
         self,
-        args,
+        args: EncoderConfig,
         embed_tokens=None,
         embed_positions=None,
         output_projection=None,
-        is_encoder_decoder=False,
+        is_encoder_decoder: bool = False,
         **kwargs
     ):
         self.args = args
diff --git a/torchscale/architecture/encoder_decoder.py b/torchscale/architecture/encoder_decoder.py
index 91a906ec..ed64641c 100644
--- a/torchscale/architecture/encoder_decoder.py
+++ b/torchscale/architecture/encoder_decoder.py
@@ -3,6 +3,7 @@
 
 import torch.nn as nn
 
+from torchscale.architecture.config import EncoderDecoderConfig
 from torchscale.architecture.decoder import Decoder
 from torchscale.architecture.encoder import Encoder
 
@@ -10,7 +11,7 @@
 class EncoderDecoder(nn.Module):
     def __init__(
         self,
-        args,
+        args: EncoderDecoderConfig,
         encoder_embed_tokens=None,
         encoder_embed_positions=None,
         decoder_embed_tokens=None,
diff --git a/torchscale/component/feedforward_network.py b/torchscale/component/feedforward_network.py
index cc187a8a..8e970d28 100644
--- a/torchscale/component/feedforward_network.py
+++ b/torchscale/component/feedforward_network.py
@@ -13,7 +13,7 @@
 from .xmoe.global_groups import get_moe_group
 
 
-class set_torch_seed(object):
+class set_torch_seed:
     def __init__(self, seed):
         assert isinstance(seed, int)
         self.rng_state = self.get_rng_state()
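
Note (not part of the patch): a minimal sanity-check sketch of how the refactored config hierarchy behaves once the hunks above are applied. The class and field names come from the diff; the concrete values and the variable names `enc`, `dec`, and `both` are illustrative assumptions.

    from torchscale.architecture.config import (
        DecoderConfig,
        EncoderConfig,
        EncoderDecoderConfig,
    )

    # Shared hyperparameters (dropout, MoE settings, etc.) are popped in
    # Config; encoder-specific ones stay in the subclass.
    enc = EncoderConfig(encoder_layers=6, vocab_size=32000)
    assert enc.encoder_layers == 6
    assert enc.encoder_embed_dim == 768  # subclass default
    assert enc.dropout == 0.0            # default inherited from Config

    # The deepnorm/subln mutual exclusion now lives in each subclass and
    # acts on that subclass's own *_normalize_before flag.
    dec = DecoderConfig(deepnorm=True)
    assert dec.decoder_normalize_before is False and dec.subln is False

    # EncoderDecoderConfig uses cooperative super().__init__ along the MRO
    # EncoderDecoderConfig -> EncoderConfig -> DecoderConfig -> Config, so a
    # single constructor call populates encoder, decoder, and shared fields.
    both = EncoderDecoderConfig(share_all_embeddings=True)
    assert both.encoder_layers == 12 and both.decoder_layers == 12

Each __init__ receives its own copy of **kwargs, so popping a key in Config does not hide it from the subclass; a key consumed in two places (e.g. rel_pos_buckets) just sets the same attribute twice with identical semantics.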