from fastNLP.core.dataset import DataSet
from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance
-
-DEFAULT_PADDING_LABEL = '<pad>'  # dict index = 0
-DEFAULT_UNKNOWN_LABEL = '<unk>'  # dict index = 1
-DEFAULT_RESERVED_LABEL = ['<reserved-2>',
-                          '<reserved-3>',
-                          '<reserved-4>']  # dict index = 2~4
-
-DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
-                         DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
-                         DEFAULT_RESERVED_LABEL[2]: 4}
+from fastNLP.core.vocabulary import Vocabulary


# the first vocab in dict with the index = 5
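The deleted DEFAULT_* constants are absorbed by the new Vocabulary class imported above. Its implementation is not part of this diff; the sketch below is a minimal stand-in inferred only from how the class is used in this commit (Vocabulary(), len(), update(), build_reverse_vocab()) and from the reserved entries the old dict provided. It is not fastNLP's actual code.

# Minimal Vocabulary sketch inferred from this diff; not fastNLP's real class.
class Vocabulary:
    def __init__(self):
        # Same reserved entries the deleted DEFAULT_WORD_TO_INDEX provided.
        self.word2idx = {'<pad>': 0, '<unk>': 1,
                         '<reserved-2>': 2, '<reserved-3>': 3, '<reserved-4>': 4}
        self.idx2word = None

    def update(self, word):
        # Accept a single token/label or a (possibly nested) list of them.
        if isinstance(word, list):
            for w in word:
                self.update(w)
        elif word not in self.word2idx:
            self.word2idx[word] = len(self.word2idx)

    def __getitem__(self, word):
        # Out-of-vocabulary tokens fall back to '<unk>'.
        return self.word2idx.get(word, self.word2idx['<unk>'])

    def __len__(self):
        return len(self.word2idx)

    def build_reverse_vocab(self):
        self.idx2word = {idx: w for w, idx in self.word2idx.items()}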
@@ -68,24 +59,22 @@ class BasePreprocess(object):

        - "word2id.pkl", a mapping from words (tokens) to indices
        - "id2word.pkl", a reversed dictionary
-        - "label2id.pkl", a dictionary on labels
-        - "id2label.pkl", a reversed dictionary on labels

    These four pickle files are expected to be saved in the given pickle directory once they are constructed.
    Preprocessors will check if those files are already in the directory and will reuse them in future calls.
    """

    def __init__(self):
-        self.word2index = None
-        self.label2index = None
+        self.data_vocab = Vocabulary()
+        self.label_vocab = Vocabulary()

    @property
    def vocab_size(self):
-        return len(self.word2index)
+        return len(self.data_vocab)

    @property
    def num_classes(self):
-        return len(self.label2index)
+        return len(self.label_vocab)

    def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=0, cross_val=False, n_fold=10):
        """Main pre-processing pipeline.
@@ -102,20 +91,14 @@ def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=
        """

        if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"):
-            self.word2index = load_pickle(pickle_path, "word2id.pkl")
-            self.label2index = load_pickle(pickle_path, "class2id.pkl")
+            self.data_vocab = load_pickle(pickle_path, "word2id.pkl")
+            self.label_vocab = load_pickle(pickle_path, "class2id.pkl")
        else:
-            self.word2index, self.label2index = self.build_dict(train_dev_data)
-            save_pickle(self.word2index, pickle_path, "word2id.pkl")
-            save_pickle(self.label2index, pickle_path, "class2id.pkl")
-
-            if not pickle_exist(pickle_path, "id2word.pkl"):
-                index2word = self.build_reverse_dict(self.word2index)
-                save_pickle(index2word, pickle_path, "id2word.pkl")
+            self.data_vocab, self.label_vocab = self.build_dict(train_dev_data)
+            save_pickle(self.data_vocab, pickle_path, "word2id.pkl")
+            save_pickle(self.label_vocab, pickle_path, "class2id.pkl")

-        if not pickle_exist(pickle_path, "id2class.pkl"):
-            index2label = self.build_reverse_dict(self.label2index)
-            save_pickle(index2label, pickle_path, "id2class.pkl")
+        self.build_reverse_dict()

        train_set = []
        dev_set = []
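A note on the simplification above: the old code rebuilt and pickled the reverse dictionaries (id2word.pkl, id2class.pkl) by hand; now a single build_reverse_dict() call asks each Vocabulary to maintain its own reverse mapping, and nothing extra is written to disk. The pickle helpers themselves (pickle_exist, load_pickle, save_pickle) are not shown in this diff. A plausible minimal implementation matching their call sites, offered as an assumption rather than fastNLP's actual code:

import os
import pickle

def pickle_exist(pickle_path, pickle_name):
    # True if the named pickle already sits in the cache directory.
    return os.path.exists(os.path.join(pickle_path, pickle_name))

def load_pickle(pickle_path, file_name):
    with open(os.path.join(pickle_path, file_name), "rb") as f:
        return pickle.load(f)

def save_pickle(obj, pickle_path, file_name):
    os.makedirs(pickle_path, exist_ok=True)
    with open(os.path.join(pickle_path, file_name), "wb") as f:
        pickle.dump(obj, f)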
@@ -125,13 +108,13 @@ def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=
                split = int(len(train_dev_data) * train_dev_split)
                data_dev = train_dev_data[:split]
                data_train = train_dev_data[split:]
-                train_set = self.convert_to_dataset(data_train, self.word2index, self.label2index)
-                dev_set = self.convert_to_dataset(data_dev, self.word2index, self.label2index)
+                train_set = self.convert_to_dataset(data_train, self.data_vocab, self.label_vocab)
+                dev_set = self.convert_to_dataset(data_dev, self.data_vocab, self.label_vocab)

                save_pickle(dev_set, pickle_path, "data_dev.pkl")
                print("{} of the training data is split for validation.".format(train_dev_split))
            else:
-                train_set = self.convert_to_dataset(train_dev_data, self.word2index, self.label2index)
+                train_set = self.convert_to_dataset(train_dev_data, self.data_vocab, self.label_vocab)
                save_pickle(train_set, pickle_path, "data_train.pkl")
        else:
            train_set = load_pickle(pickle_path, "data_train.pkl")
@@ -143,8 +126,8 @@ def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=
            # cross validation
            data_cv = self.cv_split(train_dev_data, n_fold)
            for i, (data_train_cv, data_dev_cv) in enumerate(data_cv):
-                data_train_cv = self.convert_to_dataset(data_train_cv, self.word2index, self.label2index)
-                data_dev_cv = self.convert_to_dataset(data_dev_cv, self.word2index, self.label2index)
+                data_train_cv = self.convert_to_dataset(data_train_cv, self.data_vocab, self.label_vocab)
+                data_dev_cv = self.convert_to_dataset(data_dev_cv, self.data_vocab, self.label_vocab)
                save_pickle(
                    data_train_cv, pickle_path,
                    "data_train_{}.pkl".format(i))
@@ -165,7 +148,7 @@ def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=
        test_set = []
        if test_data is not None:
            if not pickle_exist(pickle_path, "data_test.pkl"):
-                test_set = self.convert_to_dataset(test_data, self.word2index, self.label2index)
+                test_set = self.convert_to_dataset(test_data, self.data_vocab, self.label_vocab)
                save_pickle(test_set, pickle_path, "data_test.pkl")

        # return preprocessed results
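Every branch of run() funnels raw (words, label) pairs through convert_to_dataset, which this commit leaves untouched and the diff does not show. Judging from the imports at the top of the file (DataSet, TextField, LabelField, Instance), it plausibly looks like the sketch below; the field names and constructor arguments are assumptions, not fastNLP's documented API.

# Hypothetical reconstruction of convert_to_dataset; field names and
# constructor signatures are guesses based on the imports above.
def convert_to_dataset(self, data, vocab, label_vocab):
    dataset = DataSet()
    for word_seq, label in data:
        # Index tokens and labels through the vocabularies, then wrap
        # them in fields so padding/target handling stays uniform.
        x = TextField([vocab[w] for w in word_seq], is_target=False)
        y = LabelField(label_vocab[label], is_target=True)
        dataset.append(Instance(word_seq=x, label=y))
    return dataset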
@@ -180,28 +163,15 @@ def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=
        return tuple(results)

    def build_dict(self, data):
-        label2index = DEFAULT_WORD_TO_INDEX.copy()
-        word2index = DEFAULT_WORD_TO_INDEX.copy()
        for example in data:
-            for word in example[0]:
-                if word not in word2index:
-                    word2index[word] = len(word2index)
-            label = example[1]
-            if isinstance(label, str):
-                # label is a string
-                if label not in label2index:
-                    label2index[label] = len(label2index)
-            elif isinstance(label, list):
-                # label is a list of strings
-                for single_label in label:
-                    if single_label not in label2index:
-                        label2index[single_label] = len(label2index)
-        return word2index, label2index
-
-
-    def build_reverse_dict(self, word_dict):
-        id2word = {word_dict[w]: w for w in word_dict}
-        return id2word
+            word, label = example
+            self.data_vocab.update(word)
+            self.label_vocab.update(label)
+        return self.data_vocab, self.label_vocab
+
+    def build_reverse_dict(self):
+        self.data_vocab.build_reverse_vocab()
+        self.label_vocab.build_reverse_vocab()

    def data_split(self, data, train_dev_split):
        """Split data into train and dev set."""