 _BIOBERT_PATH = "https://github.com/naver/biobert-pretrained/releases/download/"
 _SCIBERT_PATH = "https://s3-us-west-2.amazonaws.com/ai2-s2-research/" \
                 "scibert/tensorflow_models/"
+_SPANBERT_PATH = "https://dl.fbaipublicfiles.com/fairseq/models/"
 
 
 class PretrainedBERTMixin(PretrainedMixin, ABC):
@@ -97,6 +98,21 @@ class PretrainedBERTMixin(PretrainedMixin, ABC):
       * ``scibert-basevocab-cased``: Cased version of the model trained on
         the original BERT vocabulary.
 
+    * **SpanBERT**: proposed in (`Joshi et al.` 2019)
+      `SpanBERT: Improving Pre-training by Representing and Predicting Spans`_.
+      As a variant of the standard BERT model, SpanBERT extends BERT by
+      (1) masking contiguous random spans rather than random tokens, and
+      (2) training the span boundary representations to predict the entire
+      content of the masked span without relying on the individual token
+      representations within it. Unlike standard BERT, SpanBERT does not use
+      segment embeddings. Available model names include (a short usage
+      example follows this list):
+
+      * ``spanbert-base-cased``: SpanBERT using the BERT-base architecture,
+        12-layer, 768-hidden, 12-heads, 110M parameters.
+      * ``spanbert-large-cased``: SpanBERT using the BERT-large architecture,
+        24-layer, 1024-hidden, 16-heads, 340M parameters.
+
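+    Example (a minimal loading sketch, assuming the usual
+    :class:`~texar.torch.modules.BERTEncoder` call signature; the random
+    token ids below are purely illustrative):
+
+    .. code-block:: python
+
+        import torch
+        from texar.torch.modules import BERTEncoder
+
+        # Download and initialize from the SpanBERT-base checkpoint.
+        encoder = BERTEncoder(pretrained_model_name="spanbert-base-cased")
+
+        # SpanBERT has no segment embeddings, so no `segment_ids` are passed.
+        token_ids = torch.randint(0, 100, (2, 16))
+        outputs, pooled_output = encoder(token_ids)
+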
     We provide the following BERT classes:
 
     * :class:`~texar.torch.modules.BERTEncoder` for text encoding.
@@ -111,6 +127,9 @@ class PretrainedBERTMixin(PretrainedMixin, ABC):
 
     .. _`SciBERT: A Pretrained Language Model for Scientific Text`:
         https://arxiv.org/abs/1903.10676
+
+    .. _`SpanBERT: Improving Pre-training by Representing and Predicting Spans`:
+        https://arxiv.org/abs/1907.10529
     """
 
     _MODEL_NAME = "BERT"
@@ -150,6 +169,12 @@ class PretrainedBERTMixin(PretrainedMixin, ABC):
             _SCIBERT_PATH + 'scibert_basevocab_uncased.tar.gz',
         'scibert-basevocab-cased':
             _SCIBERT_PATH + 'scibert_basevocab_cased.tar.gz',
+
+        # SpanBERT
+        'spanbert-base-cased':
+            _SPANBERT_PATH + "spanbert_hf_base.tar.gz",
+        'spanbert-large-cased':
+            _SPANBERT_PATH + "spanbert_hf.tar.gz",
     }
     _MODEL2CKPT = {
         # Standard BERT
@@ -172,6 +197,10 @@ class PretrainedBERTMixin(PretrainedMixin, ABC):
         'scibert-scivocab-cased': 'bert_model.ckpt',
         'scibert-basevocab-uncased': 'bert_model.ckpt',
         'scibert-basevocab-cased': 'bert_model.ckpt',
+
+        # SpanBERT
+        'spanbert-base-cased': 'pytorch_model.bin',
+        'spanbert-large-cased': 'pytorch_model.bin',
     }
 
     @classmethod
@@ -182,13 +211,14 @@ def _transform_config(cls, pretrained_model_name: str,
         config_path = None
 
         for file in files:
-            if file == 'bert_config.json':
+            if file in ('bert_config.json', 'config.json'):
                 config_path = os.path.join(root, file)
                 with open(config_path) as f:
                     config_ckpt = json.loads(f.read())
                     hidden_dim = config_ckpt['hidden_size']
                     vocab_size = config_ckpt['vocab_size']
-                    type_vocab_size = config_ckpt['type_vocab_size']
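+                    # SpanBERT does not use segment (token type) embeddings,
+                    # so 'type_vocab_size' is only read for the other BERT
+                    # variants.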
+                    if not pretrained_model_name.startswith('spanbert'):
+                        type_vocab_size = config_ckpt['type_vocab_size']
                     position_size = config_ckpt['max_position_embeddings']
                     embedding_dropout = config_ckpt['hidden_dropout_prob']
                     num_blocks = config_ckpt['num_hidden_layers']
@@ -208,11 +238,6 @@ def _transform_config(cls, pretrained_model_name: str,
                 'dim': hidden_dim
             },
             'vocab_size': vocab_size,
-            'segment_embed': {
-                'name': 'token_type_embeddings',
-                'dim': hidden_dim
-            },
-            'type_vocab_size': type_vocab_size,
             'position_embed': {
                 'name': 'position_embeddings',
                 'dim': hidden_dim
@@ -256,10 +281,74 @@ def _transform_config(cls, pretrained_model_name: str,
             }
         }
 
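+        # SpanBERT has no segment (token type) embeddings, so the
+        # 'segment_embed' and 'type_vocab_size' entries are added for the
+        # other BERT variants only.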
+        if not pretrained_model_name.startswith('spanbert'):
+            configs.update({
+                'segment_embed': {
+                    'name': 'token_type_embeddings',
+                    'dim': hidden_dim},
+                'type_vocab_size': type_vocab_size,
+            })
+
         return configs
 
     def _init_from_checkpoint(self, pretrained_model_name: str,
                               cache_dir: str, **kwargs):
+        if pretrained_model_name.startswith('spanbert'):
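+            # SpanBERT checkpoints are PyTorch state dicts (pytorch_model.bin)
+            # rather than TensorFlow checkpoints, so their parameter names are
+            # mapped directly onto texar-pytorch variable names. Embedding and
+            # layer-norm parameters shared across the whole model: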
+            global_tensor_map = {
+                'bert.embeddings.word_embeddings.weight':
+                    'word_embedder._embedding',
+                'bert.embeddings.position_embeddings.weight':
+                    'position_embedder._embedding',
+                'bert.embeddings.LayerNorm.weight':
+                    'encoder.input_normalizer.weight',
+                'bert.embeddings.LayerNorm.bias':
+                    'encoder.input_normalizer.bias',
+            }
+
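+            # Per-layer transformer parameters; '{}' is filled in with the
+            # layer index when the texar variable name is looked up below.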
+            attention_tensor_map = {
+                "attention.self.key.bias": "self_attns.{}.K_dense.bias",
+                "attention.self.query.bias": "self_attns.{}.Q_dense.bias",
+                "attention.self.value.bias": "self_attns.{}.V_dense.bias",
+                "attention.output.dense.bias": "self_attns.{}.O_dense.bias",
+                "attention.output.LayerNorm.weight":
+                    "poswise_layer_norm.{}.weight",
+                "attention.output.LayerNorm.bias": "poswise_layer_norm.{}.bias",
+                "intermediate.dense.bias": "poswise_networks.{}._layers.0.bias",
+                "output.dense.bias": "poswise_networks.{}._layers.2.bias",
+                "output.LayerNorm.weight": "output_layer_norm.{}.weight",
+                "output.LayerNorm.bias": "output_layer_norm.{}.bias",
+                "attention.self.key.weight": "self_attns.{}.K_dense.weight",
+                "attention.self.query.weight": "self_attns.{}.Q_dense.weight",
+                "attention.self.value.weight": "self_attns.{}.V_dense.weight",
+                "attention.output.dense.weight": "self_attns.{}.O_dense.weight",
+                "intermediate.dense.weight":
+                    "poswise_networks.{}._layers.0.weight",
+                "output.dense.weight": "poswise_networks.{}._layers.2.weight",
+            }
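+            # Load the downloaded PyTorch checkpoint onto the same device as
+            # this module.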
+            checkpoint_path = os.path.abspath(os.path.join(
+                cache_dir, self._MODEL2CKPT[pretrained_model_name]))
+
+            device = next(self.parameters()).device
+            params = torch.load(checkpoint_path, map_location=device)
+
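+            # Copy every mapped checkpoint tensor into the matching texar
+            # parameter, checking that shapes agree; tensors that appear in
+            # neither map are skipped.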
+            for name, tensor in params.items():
+                if name in global_tensor_map:
+                    v_name = global_tensor_map[name]
+                    pointer = self._name_to_variable(v_name)
+                    assert pointer.shape == tensor.shape
+                    pointer.data = tensor.data.type(pointer.dtype)
+                elif name.startswith('bert.encoder.layer.'):
+                    name = name[len('bert.encoder.layer.'):]
+                    layer_num, layer_name = name.split('.', 1)
+                    if layer_name in attention_tensor_map:
+                        v_name = attention_tensor_map[layer_name]
+                        pointer = self._name_to_variable(
+                            'encoder.' + v_name.format(layer_num))
+                        assert pointer.shape == tensor.shape
+                        pointer.data = tensor.data.type(pointer.dtype)
+
+            return
+
         try:
             import numpy as np
             import tensorflow as tf