
Commit 9733249

Merge pull request #82 from choosewhatulike/master
add Vocabulary
2 parents: 4d66bd6 + e8cc702

10 files changed: 203 additions and 73 deletions
fastNLP/core/predictor.py

Lines changed: 5 additions & 5 deletions
@@ -27,8 +27,8 @@ def __init__(self, pickle_path, task):
         self.batch_output = []
         self.pickle_path = pickle_path
         self._task = task  # one of ("seq_label", "text_classify")
-        self.index2label = load_pickle(self.pickle_path, "id2class.pkl")
-        self.word2index = load_pickle(self.pickle_path, "word2id.pkl")
+        self.label_vocab = load_pickle(self.pickle_path, "class2id.pkl")
+        self.word_vocab = load_pickle(self.pickle_path, "word2id.pkl")
 
     def predict(self, network, data):
         """Perform inference using the trained model.
@@ -82,7 +82,7 @@ def prepare_input(self, data):
         :return data_set: a DataSet instance.
         """
         assert isinstance(data, list)
-        return create_dataset_from_lists(data, self.word2index, has_target=False)
+        return create_dataset_from_lists(data, self.word_vocab, has_target=False)
 
     def prepare_output(self, data):
         """Transform list of batch outputs into strings."""
@@ -97,14 +97,14 @@ def _seq_label_prepare_output(self, batch_outputs):
         results = []
         for batch in batch_outputs:
             for example in np.array(batch):
-                results.append([self.index2label[int(x)] for x in example])
+                results.append([self.label_vocab.to_word(int(x)) for x in example])
         return results
 
     def _text_classify_prepare_output(self, batch_outputs):
         results = []
         for batch_out in batch_outputs:
             idx = np.argmax(batch_out.detach().numpy(), axis=-1)
-            results.extend([self.index2label[i] for i in idx])
+            results.extend([self.label_vocab.to_word(i) for i in idx])
         return results
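
The practical effect of this change: the Predictor no longer needs a separately pickled id2class reverse dict, because Vocabulary.to_word() derives the reverse mapping on demand. A minimal sketch of the decoding path (the tag set and indices below are made up for illustration):

    from fastNLP.core.vocabulary import Vocabulary

    label_vocab = Vocabulary()               # indices 0-4 hold the reserved defaults
    label_vocab.update(["B", "I", "O"])      # hypothetical tag set -> indices 5, 6, 7
    predicted = [5, 6, 7, 6]                 # e.g. argmax output for one sequence
    tags = [label_vocab.to_word(int(i)) for i in predicted]
    assert tags == ["B", "I", "O", "I"]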

fastNLP/core/preprocess.py

Lines changed: 25 additions & 55 deletions
@@ -6,16 +6,7 @@
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.field import TextField, LabelField
 from fastNLP.core.instance import Instance
-
-DEFAULT_PADDING_LABEL = '<pad>'  # dict index = 0
-DEFAULT_UNKNOWN_LABEL = '<unk>'  # dict index = 1
-DEFAULT_RESERVED_LABEL = ['<reserved-2>',
-                          '<reserved-3>',
-                          '<reserved-4>']  # dict index = 2~4
-
-DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
-                         DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
-                         DEFAULT_RESERVED_LABEL[2]: 4}
+from fastNLP.core.vocabulary import Vocabulary
 
 
 # the first vocab in dict with the index = 5
@@ -68,24 +59,22 @@ class BasePreprocess(object):
 
     - "word2id.pkl", a mapping from words(tokens) to indices
     - "id2word.pkl", a reversed dictionary
-    - "label2id.pkl", a dictionary on labels
-    - "id2label.pkl", a reversed dictionary on labels
 
     These four pickle files are expected to be saved in the given pickle directory once they are constructed.
     Preprocessors will check if those files are already in the directory and will reuse them in future calls.
     """
 
     def __init__(self):
-        self.word2index = None
-        self.label2index = None
+        self.data_vocab = Vocabulary()
+        self.label_vocab = Vocabulary()
 
     @property
     def vocab_size(self):
-        return len(self.word2index)
+        return len(self.data_vocab)
 
     @property
    def num_classes(self):
-        return len(self.label2index)
+        return len(self.label_vocab)
 
     def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=0, cross_val=False, n_fold=10):
         """Main pre-processing pipeline.
@@ -102,20 +91,14 @@ def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=
         """
 
         if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"):
-            self.word2index = load_pickle(pickle_path, "word2id.pkl")
-            self.label2index = load_pickle(pickle_path, "class2id.pkl")
+            self.data_vocab = load_pickle(pickle_path, "word2id.pkl")
+            self.label_vocab = load_pickle(pickle_path, "class2id.pkl")
         else:
-            self.word2index, self.label2index = self.build_dict(train_dev_data)
-            save_pickle(self.word2index, pickle_path, "word2id.pkl")
-            save_pickle(self.label2index, pickle_path, "class2id.pkl")
-
-        if not pickle_exist(pickle_path, "id2word.pkl"):
-            index2word = self.build_reverse_dict(self.word2index)
-            save_pickle(index2word, pickle_path, "id2word.pkl")
+            self.data_vocab, self.label_vocab = self.build_dict(train_dev_data)
+            save_pickle(self.data_vocab, pickle_path, "word2id.pkl")
+            save_pickle(self.label_vocab, pickle_path, "class2id.pkl")
 
-        if not pickle_exist(pickle_path, "id2class.pkl"):
-            index2label = self.build_reverse_dict(self.label2index)
-            save_pickle(index2label, pickle_path, "id2class.pkl")
+        self.build_reverse_dict()
 
         train_set = []
         dev_set = []
@@ -125,13 +108,13 @@ def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=
                 split = int(len(train_dev_data) * train_dev_split)
                 data_dev = train_dev_data[: split]
                 data_train = train_dev_data[split:]
-                train_set = self.convert_to_dataset(data_train, self.word2index, self.label2index)
-                dev_set = self.convert_to_dataset(data_dev, self.word2index, self.label2index)
+                train_set = self.convert_to_dataset(data_train, self.data_vocab, self.label_vocab)
+                dev_set = self.convert_to_dataset(data_dev, self.data_vocab, self.label_vocab)
 
                 save_pickle(dev_set, pickle_path, "data_dev.pkl")
                 print("{} of the training data is split for validation. ".format(train_dev_split))
             else:
-                train_set = self.convert_to_dataset(train_dev_data, self.word2index, self.label2index)
+                train_set = self.convert_to_dataset(train_dev_data, self.data_vocab, self.label_vocab)
                 save_pickle(train_set, pickle_path, "data_train.pkl")
         else:
             train_set = load_pickle(pickle_path, "data_train.pkl")
@@ -143,8 +126,8 @@ def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=
             # cross validation
             data_cv = self.cv_split(train_dev_data, n_fold)
             for i, (data_train_cv, data_dev_cv) in enumerate(data_cv):
-                data_train_cv = self.convert_to_dataset(data_train_cv, self.word2index, self.label2index)
-                data_dev_cv = self.convert_to_dataset(data_dev_cv, self.word2index, self.label2index)
+                data_train_cv = self.convert_to_dataset(data_train_cv, self.data_vocab, self.label_vocab)
+                data_dev_cv = self.convert_to_dataset(data_dev_cv, self.data_vocab, self.label_vocab)
                 save_pickle(
                     data_train_cv, pickle_path,
                     "data_train_{}.pkl".format(i))
@@ -165,7 +148,7 @@ def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=
         test_set = []
         if test_data is not None:
             if not pickle_exist(pickle_path, "data_test.pkl"):
-                test_set = self.convert_to_dataset(test_data, self.word2index, self.label2index)
+                test_set = self.convert_to_dataset(test_data, self.data_vocab, self.label_vocab)
                 save_pickle(test_set, pickle_path, "data_test.pkl")
 
         # return preprocessed results
@@ -180,28 +163,15 @@ def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=
         return tuple(results)
 
     def build_dict(self, data):
-        label2index = DEFAULT_WORD_TO_INDEX.copy()
-        word2index = DEFAULT_WORD_TO_INDEX.copy()
         for example in data:
-            for word in example[0]:
-                if word not in word2index:
-                    word2index[word] = len(word2index)
-            label = example[1]
-            if isinstance(label, str):
-                # label is a string
-                if label not in label2index:
-                    label2index[label] = len(label2index)
-            elif isinstance(label, list):
-                # label is a list of strings
-                for single_label in label:
-                    if single_label not in label2index:
-                        label2index[single_label] = len(label2index)
-        return word2index, label2index
-
-
-    def build_reverse_dict(self, word_dict):
-        id2word = {word_dict[w]: w for w in word_dict}
-        return id2word
+            word, label = example
+            self.data_vocab.update(word)
+            self.label_vocab.update(label)
+        return self.data_vocab, self.label_vocab
+
+    def build_reverse_dict(self):
+        self.data_vocab.build_reverse_vocab()
+        self.label_vocab.build_reverse_vocab()
 
     def data_split(self, data, train_dev_split):
         """Split data into train and dev set."""

fastNLP/core/vocabulary.py

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
+from copy import deepcopy
+
+DEFAULT_PADDING_LABEL = '<pad>'  # dict index = 0
+DEFAULT_UNKNOWN_LABEL = '<unk>'  # dict index = 1
+DEFAULT_RESERVED_LABEL = ['<reserved-2>',
+                          '<reserved-3>',
+                          '<reserved-4>']  # dict index = 2~4
+
+DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
+                         DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
+                         DEFAULT_RESERVED_LABEL[2]: 4}
+
+def isiterable(p_object):
+    try:
+        it = iter(p_object)
+    except TypeError:
+        return False
+    return True
+
+class Vocabulary(object):
+    """A one-to-one mapping between words and indices.
+
+    Example::
+
+        vocab = Vocabulary()
+        word_list = "this is a word list".split()
+        vocab.update(word_list)
+        vocab["word"]
+        vocab.to_word(5)
+    """
+    def __init__(self, need_default=True):
+        """
+        :param bool need_default: whether the Vocabulary reserves the default labels (<pad>, <unk>, reserved-2..4).
+        """
+        if need_default:
+            self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX)
+            self.padding_label = DEFAULT_PADDING_LABEL
+            self.unknown_label = DEFAULT_UNKNOWN_LABEL
+        else:
+            self.word2idx = {}
+            self.padding_label = None
+            self.unknown_label = None
+
+        self.has_default = need_default
+        self.idx2word = None
+
+    def __len__(self):
+        return len(self.word2idx)
+
+    def update(self, word):
+        """Add a word or a list of words to the Vocabulary.
+
+        :param word: a str, or a (possibly nested) list of str
+        """
+        if not isinstance(word, str) and isiterable(word):
+            # it's an iterable of words; add each one recursively
+            for w in word:
+                self.update(w)
+        else:
+            # it's a single word to be added
+            if word not in self.word2idx:
+                self.word2idx[word] = len(self)
+            # invalidate the cached reverse vocab; it is rebuilt on demand
+            if self.idx2word is not None:
+                self.idx2word = None
+
+    def __getitem__(self, w):
+        """To support usage like::
+
+            vocab[w]
+        """
+        if w in self.word2idx:
+            return self.word2idx[w]
+        else:
+            return self.word2idx[DEFAULT_UNKNOWN_LABEL]
+
+    def to_index(self, w):
+        """Turn a word into its index.
+        If w is not in the Vocabulary, return the index of the unknown label.
+
+        :param str w:
+        """
+        return self[w]
+
+    def unknown_idx(self):
+        if self.unknown_label is None:
+            return None
+        return self.word2idx[self.unknown_label]
+
+    def padding_idx(self):
+        if self.padding_label is None:
+            return None
+        return self.word2idx[self.padding_label]
+
+    def build_reverse_vocab(self):
+        """Build the 'index to word' dict from the 'word to index' dict."""
+        self.idx2word = {self.word2idx[w]: w for w in self.word2idx}
+
+    def to_word(self, idx):
+        """Given a word's index, return the word itself.
+
+        :param int idx:
+        """
+        if self.idx2word is None:
+            self.build_reverse_vocab()
+        return self.idx2word[idx]
+
+    def __getstate__(self):
+        """Prepare the instance's state for pickling."""
+        state = self.__dict__.copy()
+        # no need to pickle idx2word, as it can be rebuilt from word2idx
+        del state['idx2word']
+        return state
+
+    def __setstate__(self, state):
+        """Restore the instance's state from a pickle."""
+        self.__dict__.update(state)
+        self.idx2word = None
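
A short round-trip sketch of the new class, including the pickling behavior that __getstate__/__setstate__ implement (the reverse dict is dropped on save and lazily rebuilt):

    import pickle
    from fastNLP.core.vocabulary import Vocabulary

    vocab = Vocabulary()
    vocab.update("hello world hello".split())
    assert vocab.to_index("hello") == vocab["hello"] == 5   # indices 0-4 are reserved
    assert vocab["never-seen"] == vocab.unknown_idx() == 1  # OOV falls back to <unk>

    restored = pickle.loads(pickle.dumps(vocab))            # idx2word is not pickled...
    assert restored.to_word(vocab["world"]) == "world"      # ...and is rebuilt on demand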

fastNLP/fastnlp.py

Lines changed: 5 additions & 5 deletions
@@ -69,7 +69,7 @@ def __init__(self, model_dir="./"):
         :param model_dir: this directory should contain the following files:
             1. a pre-trained model
             2. a config file
-            3. "id2class.pkl"
+            3. "class2id.pkl"
             4. "word2id.pkl"
         """
         self.model_dir = model_dir
@@ -99,10 +99,10 @@ def load(self, model_name, config_file="config", section_name="model"):
         print("Restore model hyper-parameters {}".format(str(model_args.data)))
 
         # fetch dictionary size and number of labels from pickle files
-        word2index = load_pickle(self.model_dir, "word2id.pkl")
-        model_args["vocab_size"] = len(word2index)
-        index2label = load_pickle(self.model_dir, "id2class.pkl")
-        model_args["num_classes"] = len(index2label)
+        word_vocab = load_pickle(self.model_dir, "word2id.pkl")
+        model_args["vocab_size"] = len(word_vocab)
+        label_vocab = load_pickle(self.model_dir, "class2id.pkl")
+        model_args["num_classes"] = len(label_vocab)
 
         # Construct the model
         model = model_class(model_args)
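
The call sites stay almost unchanged because Vocabulary implements __len__, so len() on the loaded pickle still yields the sizes the model needs. One caveat worth noting, shown in a small sketch: a default Vocabulary counts its five reserved entries, so vocab_size includes them:

    from fastNLP.core.vocabulary import Vocabulary

    vocab = Vocabulary()             # <pad>, <unk>, <reserved-2..4> are preinstalled
    assert len(vocab) == 5
    vocab.update(["apple", "banana"])
    assert len(vocab) == 7           # the vocab_size handed to the model includes defaults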

reproduction/chinese_word_segment/run.py

Lines changed: 2 additions & 2 deletions
@@ -32,7 +32,7 @@ def infer():
     # fetch dictionary size and number of labels from pickle files
     word2index = load_pickle(pickle_path, "word2id.pkl")
     test_args["vocab_size"] = len(word2index)
-    index2label = load_pickle(pickle_path, "id2class.pkl")
+    index2label = load_pickle(pickle_path, "class2id.pkl")
     test_args["num_classes"] = len(index2label)
 
 
@@ -105,7 +105,7 @@ def test():
     # fetch dictionary size and number of labels from pickle files
     word2index = load_pickle(pickle_path, "word2id.pkl")
     test_args["vocab_size"] = len(word2index)
-    index2label = load_pickle(pickle_path, "id2class.pkl")
+    index2label = load_pickle(pickle_path, "class2id.pkl")
     test_args["num_classes"] = len(index2label)
 
     # load dev data

reproduction/pos_tag_model/train_pos_tag.py

Lines changed: 2 additions & 2 deletions
@@ -33,7 +33,7 @@ def infer():
     # fetch dictionary size and number of labels from pickle files
     word2index = load_pickle(pickle_path, "word2id.pkl")
     test_args["vocab_size"] = len(word2index)
-    index2label = load_pickle(pickle_path, "id2class.pkl")
+    index2label = load_pickle(pickle_path, "class2id.pkl")
     test_args["num_classes"] = len(index2label)
 
     # Define the same model
@@ -105,7 +105,7 @@ def test():
     # fetch dictionary size and number of labels from pickle files
     word2index = load_pickle(pickle_path, "word2id.pkl")
     test_args["vocab_size"] = len(word2index)
-    index2label = load_pickle(pickle_path, "id2class.pkl")
+    index2label = load_pickle(pickle_path, "class2id.pkl")
     test_args["num_classes"] = len(index2label)
 
     # load dev data

test/core/test_predictor.py

Lines changed: 7 additions & 2 deletions
@@ -4,6 +4,7 @@
 from fastNLP.core.predictor import Predictor
 from fastNLP.core.preprocess import save_pickle
 from fastNLP.models.sequence_modeling import SeqLabeling
+from fastNLP.core.vocabulary import Vocabulary
 
 
 class TestPredictor(unittest.TestCase):
@@ -23,10 +24,14 @@ def test_seq_label(self):
             ['a', 'b', 'c', 'd', '$'],
             ['!', 'b', 'c', 'd', 'e']
         ]
-        vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
+
+        vocab = Vocabulary()
+        vocab.word2idx = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
+        class_vocab = Vocabulary()
+        class_vocab.word2idx = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4}
 
         os.system("mkdir save")
-        save_pickle({0: "0", 1: "1", 2: "2", 3: "3", 4: "4"}, "./save/", "id2class.pkl")
+        save_pickle(class_vocab, "./save/", "class2id.pkl")
         save_pickle(vocab, "./save/", "word2id.pkl")
 
         model = SeqLabeling(model_args)
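
The test assigns word2idx directly so the indices stay byte-for-byte identical to the old dict fixture. An equivalent construction through the public API (a sketch, assuming need_default=False to skip the five reserved entries; not what the test does) would yield the same mapping, since update() assigns sequential indices in insertion order:

    from fastNLP.core.vocabulary import Vocabulary

    vocab = Vocabulary(need_default=False)   # start from an empty word2idx
    vocab.update(['a', 'b', 'c', 'd', 'e', '!', '@', '#', '$', '?'])
    assert vocab['a'] == 0 and vocab['?'] == 9   # same indices as the literal dict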
