
Commit ab9683f

Add SpanBERT module (#300)
* Add SpanBERT module
* Clean the code
* Fix CI
* Fix CI
* Resolve comments
1 parent 931ead9 commit ab9683f

4 files changed: +199 -11 lines changed

texar/torch/data/tokenizers/bert_tokenizer.py

Lines changed: 22 additions & 0 deletions
@@ -26,6 +26,7 @@
 from texar.torch.data.tokenizers.tokenizer_base import TokenizerBase
 from texar.torch.data.tokenizers.bert_tokenizer_utils import \
     load_vocab, BasicTokenizer, WordpieceTokenizer
+from texar.torch.hyperparams import HParams
 from texar.torch.utils.utils import truncate_seq_pair
 
 __all__ = [
@@ -74,6 +75,10 @@ class BERTTokenizer(PretrainedBERTMixin, TokenizerBase):
         'scibert-scivocab-cased': 512,
         'scibert-basevocab-uncased': 512,
         'scibert-basevocab-cased': 512,
+
+        # SpanBERT
+        'spanbert-base-cased': 512,
+        'spanbert-large-cased': 512,
     }
     _VOCAB_FILE_NAMES = {'vocab_file': 'vocab.txt'}
     _VOCAB_FILE_MAP = {
@@ -98,13 +103,30 @@ class BERTTokenizer(PretrainedBERTMixin, TokenizerBase):
             'scibert-scivocab-cased': 'vocab.txt',
             'scibert-basevocab-uncased': 'vocab.txt',
             'scibert-basevocab-cased': 'vocab.txt',
+
+            # SpanBERT
+            'spanbert-base-cased': 'vocab.txt',
+            'spanbert-large-cased': 'vocab.txt',
         }
     }
 
     def __init__(self,
                  pretrained_model_name: Optional[str] = None,
                  cache_dir: Optional[str] = None,
                  hparams=None):
+
+        # SpanBERT checkpoints do not include a vocabulary file; fall back to
+        # the standard BERT vocabulary when a pre-trained SpanBERT is used.
+        if pretrained_model_name is not None:
+            if pretrained_model_name.startswith('spanbert'):
+                pretrained_model_name = pretrained_model_name.lstrip('span')
+        elif hparams is not None:
+            hparams = HParams(hparams, None)
+            if hparams.pretrained_model_name is not None and \
+                    hparams.pretrained_model_name.startswith('spanbert'):
+                pretrained_model_name = \
+                    hparams.pretrained_model_name.lstrip('span')
+
         self.load_pretrained_config(pretrained_model_name, cache_dir, hparams)
 
         super().__init__(hparams=None)
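
The tokenizer change above only redirects the vocabulary lookup; nothing else about tokenization differs for SpanBERT. A minimal usage sketch of the two construction paths touched by this hunk (assuming `BERTTokenizer` is re-exported from `texar.torch.data`; otherwise import it from `texar.torch.data.tokenizers.bert_tokenizer`):

from texar.torch.data import BERTTokenizer

# Both constructions fall back to the standard cased BERT vocabulary,
# because SpanBERT checkpoints ship without a vocab.txt of their own.
tok_base = BERTTokenizer(pretrained_model_name='spanbert-base-cased')
tok_large = BERTTokenizer(
    hparams={'pretrained_model_name': 'spanbert-large-cased'})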

texar/torch/modules/encoders/bert_encoder.py

Lines changed: 12 additions & 4 deletions
@@ -75,9 +75,14 @@ def __init__(self,
         # Segment embedding for each type of tokens
         self.segment_embedder = None
         if self._hparams.get('type_vocab_size', 0) > 0:
-            self.segment_embedder = WordEmbedder(
-                vocab_size=self._hparams.type_vocab_size,
-                hparams=self._hparams.segment_embed)
+            if self.pretrained_model_name is not None and \
+                    self.pretrained_model_name.startswith('spanbert'):
+                # Do not construct segment_embedder for SpanBERT
+                pass
+            else:
+                self.segment_embedder = WordEmbedder(
+                    vocab_size=self._hparams.type_vocab_size,
+                    hparams=self._hparams.segment_embed)
 
         # Position embedding
         self.position_embedder = PositionEmbedder(
@@ -289,7 +294,10 @@ def forward(self, # type: ignore
                 inputs: Union[torch.Tensor, torch.LongTensor],
                 sequence_length: Optional[torch.LongTensor] = None,
                 segment_ids: Optional[torch.LongTensor] = None):
-        r"""Encodes the inputs.
+        r"""Encodes the inputs. Note that the SpanBERT model does not use
+        segment embeddings. As a result, SpanBERT does not require
+        `segment_ids` as an input when using pre-trained SpanBERT checkpoint
+        files.
 
         Args:
             inputs: Either a **2D Tensor** of shape `[batch_size, max_time]`,
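
Because SpanBERT drops the segment (token-type) embedding, the encoder can be driven without `segment_ids`. A hedged sketch of the call pattern enabled by this hunk, assuming the standard Texar `BERTEncoder` return of `(outputs, pooled_output)` and the 28996-token cased vocabulary reported in the test below (the first run downloads the pre-trained checkpoint):

import torch
from texar.torch.modules import BERTEncoder

encoder = BERTEncoder(pretrained_model_name='spanbert-base-cased')
inputs = torch.randint(0, 28996, (2, 16))  # [batch_size, max_time] token ids
lengths = torch.tensor([16, 12])

# No segment_ids are passed: SpanBERT has no segment embedder.
outputs, pooled_output = encoder(inputs=inputs, sequence_length=lengths)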

texar/torch/modules/pretrained/bert.py

Lines changed: 96 additions & 7 deletions
@@ -34,6 +34,7 @@
 _BIOBERT_PATH = "https://github.com/naver/biobert-pretrained/releases/download/"
 _SCIBERT_PATH = "https://s3-us-west-2.amazonaws.com/ai2-s2-research/" \
                 "scibert/tensorflow_models/"
+_SPANBERT_PATH = "https://dl.fbaipublicfiles.com/fairseq/models/"
 
 
 class PretrainedBERTMixin(PretrainedMixin, ABC):
@@ -97,6 +98,21 @@ class PretrainedBERTMixin(PretrainedMixin, ABC):
       * ``scibert-basevocab-cased``: Cased version of the model trained on
         the original BERT vocabulary.
 
+    * **SpanBERT**: proposed in (`Joshi et al`. 2019)
+      `SpanBERT: Improving Pre-training by Representing and Predicting Spans`_.
+      A variant of the standard BERT model, SpanBERT extends BERT by
+      (1) masking contiguous random spans rather than random tokens, and
+      (2) training the span boundary representations to predict the entire
+      content of the masked span, without relying on the individual token
+      representations within it. Unlike standard BERT, the SpanBERT
+      model does not use segment embeddings. Available model names
+      include:
+
+      * ``spanbert-base-cased``: SpanBERT using the BERT-base architecture,
+        12-layer, 768-hidden, 12-heads, 110M parameters.
+      * ``spanbert-large-cased``: SpanBERT using the BERT-large architecture,
+        24-layer, 1024-hidden, 16-heads, 340M parameters.
+
     We provide the following BERT classes:
 
       * :class:`~texar.torch.modules.BERTEncoder` for text encoding.
111127
112128
.. _`SciBERT: A Pretrained Language Model for Scientific Text`:
113129
https://arxiv.org/abs/1903.10676
130+
131+
.. _`SpanBERT: Improving Pre-training by Representing and Predicting Spans`:
132+
https://arxiv.org/abs/1907.10529
114133
"""
115134

116135
_MODEL_NAME = "BERT"
@@ -150,6 +169,12 @@ class PretrainedBERTMixin(PretrainedMixin, ABC):
150169
_SCIBERT_PATH + 'scibert_basevocab_uncased.tar.gz',
151170
'scibert-basevocab-cased':
152171
_SCIBERT_PATH + 'scibert_basevocab_cased.tar.gz',
172+
173+
# SpanBERT
174+
'spanbert-base-cased':
175+
_SPANBERT_PATH + "spanbert_hf_base.tar.gz",
176+
'spanbert-large-cased':
177+
_SPANBERT_PATH + "spanbert_hf.tar.gz",
153178
}
154179
_MODEL2CKPT = {
155180
# Standard BERT
@@ -172,6 +197,10 @@ class PretrainedBERTMixin(PretrainedMixin, ABC):
172197
'scibert-scivocab-cased': 'bert_model.ckpt',
173198
'scibert-basevocab-uncased': 'bert_model.ckpt',
174199
'scibert-basevocab-cased': 'bert_model.ckpt',
200+
201+
# SpanBERT
202+
'spanbert-base-cased': 'pytorch_model.bin',
203+
'spanbert-large-cased': 'pytorch_model.bin',
175204
}
176205

177206
@classmethod
@@ -182,13 +211,14 @@ def _transform_config(cls, pretrained_model_name: str,
182211
config_path = None
183212

184213
for file in files:
185-
if file == 'bert_config.json':
214+
if file in ('bert_config.json', 'config.json'):
186215
config_path = os.path.join(root, file)
187216
with open(config_path) as f:
188217
config_ckpt = json.loads(f.read())
189218
hidden_dim = config_ckpt['hidden_size']
190219
vocab_size = config_ckpt['vocab_size']
191-
type_vocab_size = config_ckpt['type_vocab_size']
220+
if not pretrained_model_name.startswith('spanbert'):
221+
type_vocab_size = config_ckpt['type_vocab_size']
192222
position_size = config_ckpt['max_position_embeddings']
193223
embedding_dropout = config_ckpt['hidden_dropout_prob']
194224
num_blocks = config_ckpt['num_hidden_layers']
@@ -208,11 +238,6 @@ def _transform_config(cls, pretrained_model_name: str,
208238
'dim': hidden_dim
209239
},
210240
'vocab_size': vocab_size,
211-
'segment_embed': {
212-
'name': 'token_type_embeddings',
213-
'dim': hidden_dim
214-
},
215-
'type_vocab_size': type_vocab_size,
216241
'position_embed': {
217242
'name': 'position_embeddings',
218243
'dim': hidden_dim
@@ -256,10 +281,74 @@ def _transform_config(cls, pretrained_model_name: str,
256281
}
257282
}
258283

284+
if not pretrained_model_name.startswith('spanbert'):
285+
configs.update({
286+
'segment_embed': {
287+
'name': 'token_type_embeddings',
288+
'dim': hidden_dim},
289+
'type_vocab_size': type_vocab_size,
290+
})
291+
259292
return configs
260293

261294
def _init_from_checkpoint(self, pretrained_model_name: str,
262295
cache_dir: str, **kwargs):
296+
if pretrained_model_name.startswith('spanbert'):
297+
global_tensor_map = {
298+
'bert.embeddings.word_embeddings.weight':
299+
'word_embedder._embedding',
300+
'bert.embeddings.position_embeddings.weight':
301+
'position_embedder._embedding',
302+
'bert.embeddings.LayerNorm.weight':
303+
'encoder.input_normalizer.weight',
304+
'bert.embeddings.LayerNorm.bias':
305+
'encoder.input_normalizer.bias',
306+
}
307+
308+
attention_tensor_map = {
309+
"attention.self.key.bias": "self_attns.{}.K_dense.bias",
310+
"attention.self.query.bias": "self_attns.{}.Q_dense.bias",
311+
"attention.self.value.bias": "self_attns.{}.V_dense.bias",
312+
"attention.output.dense.bias": "self_attns.{}.O_dense.bias",
313+
"attention.output.LayerNorm.weight":
314+
"poswise_layer_norm.{}.weight",
315+
"attention.output.LayerNorm.bias": "poswise_layer_norm.{}.bias",
316+
"intermediate.dense.bias": "poswise_networks.{}._layers.0.bias",
317+
"output.dense.bias": "poswise_networks.{}._layers.2.bias",
318+
"output.LayerNorm.weight": "output_layer_norm.{}.weight",
319+
"output.LayerNorm.bias": "output_layer_norm.{}.bias",
320+
"attention.self.key.weight": "self_attns.{}.K_dense.weight",
321+
"attention.self.query.weight": "self_attns.{}.Q_dense.weight",
322+
"attention.self.value.weight": "self_attns.{}.V_dense.weight",
323+
"attention.output.dense.weight": "self_attns.{}.O_dense.weight",
324+
"intermediate.dense.weight":
325+
"poswise_networks.{}._layers.0.weight",
326+
"output.dense.weight": "poswise_networks.{}._layers.2.weight",
327+
}
328+
checkpoint_path = os.path.abspath(os.path.join(
329+
cache_dir, self._MODEL2CKPT[pretrained_model_name]))
330+
331+
device = next(self.parameters()).device
332+
params = torch.load(checkpoint_path, map_location=device)
333+
334+
for name, tensor in params.items():
335+
if name in global_tensor_map:
336+
v_name = global_tensor_map[name]
337+
pointer = self._name_to_variable(v_name)
338+
assert pointer.shape == tensor.shape
339+
pointer.data = tensor.data.type(pointer.dtype)
340+
elif name.startswith('bert.encoder.layer.'):
341+
name = name.lstrip('bert.encoder.layer.')
342+
layer_num, layer_name = name.split('.', 1)
343+
if layer_name in attention_tensor_map:
344+
v_name = attention_tensor_map[layer_name]
345+
pointer = self._name_to_variable(
346+
'encoder.' + v_name.format(layer_num))
347+
assert pointer.shape == tensor.shape
348+
pointer.data = tensor.data.type(pointer.dtype)
349+
350+
return
351+
263352
try:
264353
import numpy as np
265354
import tensorflow as tf
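
For readers following `_init_from_checkpoint`, the per-layer mapping works by splitting off the layer index and formatting it into the Texar variable name. A standalone sketch of that step, shown here with plain prefix slicing (the diff itself uses `str.lstrip`, which strips a character set rather than a literal prefix but is safe here because layer indices are digits):

# Map one SpanBERT checkpoint parameter name to its Texar variable name.
attention_tensor_map = {
    'attention.self.query.weight': 'self_attns.{}.Q_dense.weight',
}

ckpt_name = 'bert.encoder.layer.3.attention.self.query.weight'
suffix = ckpt_name[len('bert.encoder.layer.'):]  # '3.attention.self.query.weight'
layer_num, layer_name = suffix.split('.', 1)
texar_name = 'encoder.' + attention_tensor_map[layer_name].format(layer_num)
print(texar_name)  # encoder.self_attns.3.Q_dense.weight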

texar/torch/modules/pretrained/bert_test.py

Lines changed: 69 additions & 0 deletions
@@ -102,6 +102,75 @@ def test_load_pretrained_bert_AND_transform_bert_to_texar_config(self):
 
         self.assertDictEqual(model_config, exp_config)
 
+    @pretrained_test
+    def test_load_spanbert_AND_transform_spanbert_to_texar_config(
+            self):
+        pretrained_model_dir = PretrainedBERTMixin.download_checkpoint(
+            pretrained_model_name="spanbert-base-cased")
+
+        info = list(os.walk(pretrained_model_dir))
+        _, _, files = info[0]
+        self.assertIn('config.json', files)
+        self.assertIn('pytorch_model.bin', files)
+
+        model_config = PretrainedBERTMixin._transform_config(
+            pretrained_model_name="spanbert-base-cased",
+            cache_dir=pretrained_model_dir)
+
+        exp_config = {
+            'hidden_size': 768,
+            'embed': {
+                'name': 'word_embeddings',
+                'dim': 768
+            },
+            'vocab_size': 28996,
+            'position_embed': {
+                'name': 'position_embeddings',
+                'dim': 768
+            },
+            'position_size': 512,
+            'encoder': {
+                'name': 'encoder',
+                'embedding_dropout': 0.1,
+                'num_blocks': 12,
+                'multihead_attention': {
+                    'use_bias': True,
+                    'num_units': 768,
+                    'num_heads': 12,
+                    'output_dim': 768,
+                    'dropout_rate': 0.1,
+                    'name': 'self'
+                },
+                'residual_dropout': 0.1,
+                'dim': 768,
+                'use_bert_config': True,
+                'eps': 1e-12,
+                'poswise_feedforward': {
+                    'layers': [
+                        {
+                            'type': 'Linear',
+                            'kwargs': {
+                                'in_features': 768,
+                                'out_features': 3072,
+                                'bias': True
+                            }
+                        },
+                        {'type': 'BertGELU'},
+                        {
+                            'type': 'Linear',
+                            'kwargs': {
+                                'in_features': 3072,
+                                'out_features': 768,
+                                'bias': True
+                            }
+                        }
+                    ]
+                }
+            }
+        }
+
+        self.assertDictEqual(model_config, exp_config)
+
 
 if __name__ == "__main__":
     unittest.main()
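
Beyond the config check above, a quick hedged sketch to confirm the encoder-side behaviour introduced in this commit (it downloads the pre-trained checkpoint; `segment_embedder` is the attribute set in `BERTEncoder.__init__` in the diff above):

from texar.torch.modules import BERTEncoder

encoder = BERTEncoder(pretrained_model_name='spanbert-base-cased')
# SpanBERT checkpoints carry no token-type table, so the encoder skips it.
assert encoder.segment_embedder is None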
