
Commit c4dbc7b

Merge pull request #86 from FengZiYjun/master
Name Changes & More Tests
2 parents 9733249 + 28a0683 commit c4dbc7b

31 files changed, +506 -854 lines changed

examples/readme_example.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
 from fastNLP.core.trainer import ClassificationTrainer
 from fastNLP.loader.dataset_loader import ClassDatasetLoader
 from fastNLP.models.base_model import BaseModel
-from fastNLP.modules import aggregation
+from fastNLP.modules import aggregator
 from fastNLP.modules import decoder
 from fastNLP.modules import encoder
 
@@ -21,7 +21,7 @@ def __init__(self, num_classes, vocab_size):
         self.emb = encoder.Embedding(nums=vocab_size, dims=300)
         self.enc = encoder.Conv(
             in_channels=300, out_channels=100, kernel_size=3)
-        self.agg = aggregation.MaxPool()
+        self.agg = aggregator.MaxPool()
         self.dec = decoder.MLP(size_layer=[100, num_classes])
 
     def forward(self, x):
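
A minimal sketch of the renamed module in use, mirroring the readme model above; the constructor arguments come from this diff, while the vocabulary size and class count are illustrative placeholders:

from fastNLP.modules import aggregator, decoder, encoder

emb = encoder.Embedding(nums=10000, dims=300)                          # vocab size is illustrative
enc = encoder.Conv(in_channels=300, out_channels=100, kernel_size=3)
agg = aggregator.MaxPool()                                             # was aggregation.MaxPool() before this commit
dec = decoder.MLP(size_layer=[100, 5])                                 # 5 output classes, illustrative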

fastNLP/core/batch.py

Lines changed: 8 additions & 47 deletions
@@ -2,10 +2,6 @@
 
 import torch
 
-from fastNLP.core.dataset import DataSet
-from fastNLP.core.field import TextField, LabelField
-from fastNLP.core.instance import Instance
-
 
 class Batch(object):
     """Batch is an iterable object which iterates over mini-batches.
@@ -16,6 +12,14 @@ class Batch(object):
     """
 
     def __init__(self, dataset, batch_size, sampler, use_cuda):
+        """
+
+        :param dataset: a DataSet object
+        :param batch_size: int, the size of the batch
+        :param sampler: a Sampler object
+        :param use_cuda: bool, whether to use GPU
+
+        """
         self.dataset = dataset
         self.batch_size = batch_size
         self.sampler = sampler
@@ -81,46 +85,3 @@ def __next__(self):
         self.curidx += endidx
         return batch_x, batch_y
 
-
-if __name__ == "__main__":
-    """simple running example
-    """
-    texts = ["i am a cat",
-             "this is a test of new batch",
-             "haha"
-             ]
-    labels = [0, 1, 0]
-
-    # prepare vocabulary
-    vocab = {}
-    for text in texts:
-        for tokens in text.split():
-            if tokens not in vocab:
-                vocab[tokens] = len(vocab)
-    print("vocabulary: ", vocab)
-
-    # prepare input dataset
-    data = DataSet()
-    for text, label in zip(texts, labels):
-        x = TextField(text.split(), False)
-        y = LabelField(label, is_target=True)
-        ins = Instance(text=x, label=y)
-        data.append(ins)
-
-    # use vocabulary to index data
-    data.index_field("text", vocab)
-
-
-    # define naive sampler for batch class
-    class SeqSampler:
-        def __call__(self, dataset):
-            return list(range(len(dataset)))
-
-
-    # use batch to iterate dataset
-    data_iterator = Batch(data, 2, SeqSampler(), False)
-    for epoch in range(1):
-        for batch_x, batch_y in data_iterator:
-            print(batch_x)
-            print(batch_y)
-            # do stuff
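
A hedged usage sketch of the new constructor documented above, following the example that was removed from __main__; it assumes an already-indexed DataSet named train_data and takes SequentialSampler from the relocated sampler module (see the fastNLP/core/predictor.py change below):

from fastNLP.core.batch import Batch
from fastNLP.core.sampler import SequentialSampler

# train_data is an assumed DataSet whose text/label fields are already indexed
data_iterator = Batch(dataset=train_data, batch_size=2,
                      sampler=SequentialSampler(), use_cuda=False)
for batch_x, batch_y in data_iterator:
    print(batch_x)  # input fields of this mini-batch
    print(batch_y)  # target fields of this mini-batch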

fastNLP/core/predictor.py

Lines changed: 6 additions & 2 deletions
@@ -1,10 +1,10 @@
 import numpy as np
 import torch
 
-from fastNLP.core.action import SequentialSampler
 from fastNLP.core.batch import Batch
 from fastNLP.core.dataset import create_dataset_from_lists
 from fastNLP.core.preprocess import load_pickle
+from fastNLP.core.sampler import SequentialSampler
 
 
 class Predictor(object):
 
@@ -62,9 +62,13 @@ def mode(self, network, test=True):
 
     def data_forward(self, network, x):
         """Forward through network."""
-        y = network(**x)
         if self._task == "seq_label":
+            y = network(x["word_seq"], x["word_seq_origin_len"])
             y = network.prediction(y)
+        elif self._task == "text_classify":
+            y = network(x["word_seq"])
+        else:
+            raise NotImplementedError("Unknown task type {}.".format(self._task))
         return y
 
     def prepare_input(self, data):
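
A hedged sketch of the input dict that the reworked data_forward expects for each task; the key names and task strings come from the hunk above, while the tensors and the predictor/network objects are assumed placeholders:

import torch

x_seq_label = {
    "word_seq": torch.LongTensor([[2, 4, 5, 0]]),    # padded token indices (illustrative)
    "word_seq_origin_len": torch.LongTensor([3]),    # sequence length before padding (illustrative)
}
x_text_classify = {
    "word_seq": torch.LongTensor([[2, 4, 5, 0]]),
}
# y = predictor.data_forward(network, x_seq_label)   # predictor and network are assumed objects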

fastNLP/core/preprocess.py

Lines changed: 19 additions & 12 deletions
@@ -52,21 +52,28 @@ def pickle_exist(pickle_path, pickle_name):
         return False
 
 
-class BasePreprocess(object):
-    """Base class of all preprocessors.
-    Preprocessors are responsible for converting data of strings into data of indices.
+class Preprocessor(object):
+    """Preprocessors are responsible for converting data of strings into data of indices.
     During the pre-processing, the following pickle files will be built:
 
-        - "word2id.pkl", a mapping from words(tokens) to indices
-        - "id2word.pkl", a reversed dictionary
+        - "word2id.pkl", a Vocabulary object, mapping words to indices.
+        - "class2id.pkl", a Vocabulary object, mapping labels to indices.
+        - "data_train.pkl", a DataSet object for training
+        - "data_dev.pkl", a DataSet object for validation, if train_dev_split > 0.
+        - "data_test.pkl", a DataSet object for testing, if test_data is not None.
 
     These four pickle files are expected to be saved in the given pickle directory once they are constructed.
     Preprocessors will check if those files are already in the directory and will reuse them in future calls.
     """
 
-    def __init__(self):
+    def __init__(self, label_is_seq=False):
+        """
+
+        :param label_is_seq: bool, whether label is a sequence. If True, label vocabulary will preserve
+                several special tokens for sequence processing.
+        """
         self.data_vocab = Vocabulary()
-        self.label_vocab = Vocabulary()
+        self.label_vocab = Vocabulary(need_default=label_is_seq)
 
     @property
     def vocab_size(self):
 
@@ -259,20 +266,20 @@ def convert_to_dataset(self, data, vocab, label_vocab):
         return data_set
 
 
-class SeqLabelPreprocess(BasePreprocess):
+class SeqLabelPreprocess(Preprocessor):
     def __init__(self):
-
+        print("[FastNLP warning] SeqLabelPreprocess is about to deprecate. Please use Preprocess directly.")
         super(SeqLabelPreprocess, self).__init__()
 
 
-
-class ClassPreprocess(BasePreprocess):
+class ClassPreprocess(Preprocessor):
     def __init__(self):
+        print("[FastNLP warning] ClassPreprocess is about to deprecate. Please use Preprocess directly.")
         super(ClassPreprocess, self).__init__()
 
 
 if __name__ == "__main__":
-    p = BasePreprocess()
+    p = Preprocessor()
     train_dev_data = [[["I", "am", "a", "good", "student", "."], "0"],
                       [["You", "are", "pretty", "."], "1"]
                       ]
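
A hedged sketch of the renamed class, following the __main__ example in the hunk above; only the constructor and its two vocabularies are shown, since the actual run/convert calls are not part of this diff:

from fastNLP.core.preprocess import Preprocessor

# label_is_seq=True would keep special tokens in the label vocabulary (e.g. for tagging tasks)
p = Preprocessor(label_is_seq=False)
train_dev_data = [[["I", "am", "a", "good", "student", "."], "0"],
                  [["You", "are", "pretty", "."], "1"]]
print(p.data_vocab, p.label_vocab)  # still-empty Vocabulary objects until preprocessing is run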
Lines changed: 70 additions & 100 deletions
@@ -1,5 +1,3 @@
-from collections import Counter
-
 import numpy as np
 import torch
 
@@ -17,6 +15,56 @@ def convert_to_torch_tensor(data_list, use_cuda):
     return data_list
 
 
+class BaseSampler(object):
+    """The base class of all samplers.
+
+    Sub-classes must implement the __call__ method.
+    __call__ takes a DataSet object and returns a list of int - the sampling indices.
+    """
+
+    def __call__(self, *args, **kwargs):
+        raise NotImplementedError
+
+
+class SequentialSampler(BaseSampler):
+    """Sample data in the original order.
+
+    """
+
+    def __call__(self, data_set):
+        return list(range(len(data_set)))
+
+
+class RandomSampler(BaseSampler):
+    """Sample data in random permutation order.
+
+    """
+
+    def __call__(self, data_set):
+        return list(np.random.permutation(len(data_set)))
+
+
+def simple_sort_bucketing(lengths):
+    """
+
+    :param lengths: list of int, the lengths of all examples.
+    :param buckets: list of int. The length of the list is the number of buckets. Each integer is the maximum length
+            threshold for each bucket (This is usually None.).
+    :return data: 2-level list
+            ::
+
+                [
+                    [index_11, index_12, ...],  # bucket 1
+                    [index_21, index_22, ...],  # bucket 2
+                    ...
+                ]
+
+    """
+    lengths_mapping = [(idx, length) for idx, length in enumerate(lengths)]
+    sorted_lengths = sorted(lengths_mapping, key=lambda x: x[1])
+    # TODO: need to return buckets
+    return [idx for idx, _ in sorted_lengths]
+
 def k_means_1d(x, k, max_iter=100):
     """Perform k-means on 1-D data.
 
@@ -46,18 +94,10 @@ def k_means_1d(x, k, max_iter=100):
     return np.array(centroids), assign
 
 
-def k_means_bucketing(all_inst, buckets):
+def k_means_bucketing(lengths, buckets):
     """Assign all instances into possible buckets using k-means, such that instances in the same bucket have similar lengths.
 
-    :param all_inst: 3-level list
-        E.g. ::
-
-            [
-                [[word_11, word_12, word_13], [label_11. label_12]],  # sample 1
-                [[word_21, word_22, word_23], [label_21. label_22]],  # sample 2
-                ...
-            ]
-
+    :param lengths: list of int, the length of all samples.
     :param buckets: list of int. The length of the list is the number of buckets. Each integer is the maximum length
            threshold for each bucket (This is usually None.).
     :return data: 2-level list
@@ -72,7 +112,6 @@ def k_means_bucketing(all_inst, buckets):
     """
     bucket_data = [[] for _ in buckets]
     num_buckets = len(buckets)
-    lengths = np.array([len(inst[0]) for inst in all_inst])
     _, assignments = k_means_1d(lengths, num_buckets)
 
     for idx, bucket_id in enumerate(assignments):
@@ -81,102 +120,33 @@ def k_means_bucketing(all_inst, buckets):
     return bucket_data
 
 
-class BaseSampler(object):
-    """The base class of all samplers.
-
-    """
-
-    def __call__(self, *args, **kwargs):
-        raise NotImplementedError
-
-
-class SequentialSampler(BaseSampler):
-    """Sample data in the original order.
-
-    """
-
-    def __call__(self, data_set):
-        return list(range(len(data_set)))
-
-
-class RandomSampler(BaseSampler):
-    """Sample data in random permutation order.
-
-    """
-
-    def __call__(self, data_set):
-        return list(np.random.permutation(len(data_set)))
-
-
-
-class Batchifier(object):
-    """Wrap random or sequential sampler to generate a mini-batch.
-
-    """
-
-    def __init__(self, sampler, batch_size, drop_last=True):
-        """
-
-        :param sampler: a Sampler object
-        :param batch_size: int, the size of the mini-batch
-        :param drop_last: bool, whether to drop the last examples that are not enough to make a mini-batch.
-
-        """
-        super(Batchifier, self).__init__()
-        self.sampler = sampler
-        self.batch_size = batch_size
-        self.drop_last = drop_last
-
-    def __iter__(self):
-        batch = []
-        for example in self.sampler:
-            batch.append(example)
-            if len(batch) == self.batch_size:
-                yield batch
-                batch = []
-        if 0 < len(batch) < self.batch_size and self.drop_last is False:
-            yield batch
-
-
-class BucketBatchifier(Batchifier):
+class BucketSampler(BaseSampler):
     """Partition all samples into multiple buckets, each of which contains sentences of approximately the same length.
     In sampling, first random choose a bucket. Then sample data from it.
     The number of buckets is decided dynamically by the variance of sentence lengths.
-    TODO: merge it into Batch
+
     """
 
-    def __init__(self, data_set, batch_size, num_buckets, drop_last=True, sampler=None):
+    def __call__(self, data_set, batch_size, num_buckets):
+        return self._process(data_set, batch_size, num_buckets)
+
+    def _process(self, data_set, batch_size, num_buckets, use_kmeans=False):
         """
 
-        :param data_set: three-level list, shape [num_samples, 2]
+        :param data_set: a DataSet object
         :param batch_size: int
         :param num_buckets: int, number of buckets for grouping these sequences.
-        :param drop_last: bool, useless currently.
-        :param sampler: Sampler, useless currently.
+        :param use_kmeans: bool, whether to use k-means to create buckets.
 
         """
-        super(BucketBatchifier, self).__init__(sampler, batch_size, drop_last)
         buckets = ([None] * num_buckets)
-        self.data = data_set
-        self.batch_size = batch_size
-        self.length_freq = dict(Counter([len(example) for example in data_set]))
-        self.buckets = k_means_bucketing(data_set, buckets)
-
-    def __iter__(self):
-        """Make a min-batch of data."""
-        for _ in range(len(self.data) // self.batch_size):
-            bucket_samples = self.buckets[np.random.randint(0, len(self.buckets))]
-            np.random.shuffle(bucket_samples)
-            yield [self.data[idx] for idx in bucket_samples[:batch_size]]
-
-
-if __name__ == "__main__":
-    import random
-
-    data = [[[y] * random.randint(0, 50), [y]] for y in range(500)]
-    batch_size = 8
-    iterator = iter(BucketBatchifier(data, batch_size, num_buckets=5))
-    for d in iterator:
-        print("\nbatch:")
-        for dd in d:
-            print(len(dd[0]), end=" ")
+        if use_kmeans is True:
+            buckets = k_means_bucketing(data_set, buckets)
+        else:
+            buckets = simple_sort_bucketing(data_set)
+        index_list = []
+        for _ in range(len(data_set) // batch_size):
+            chosen_bucket = buckets[np.random.randint(0, len(buckets))]
+            np.random.shuffle(chosen_bucket)
+            index_list += [idx for idx in chosen_bucket[:batch_size]]
+        return index_list
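
A hedged sketch of how the relocated samplers are called: each is a plain callable that maps a DataSet to a list of indices, which Batch (see fastNLP/core/batch.py above) then consumes. It assumes these classes live in fastNLP.core.sampler, as the predictor.py change above suggests, and that train_data is a prepared DataSet:

from fastNLP.core.sampler import RandomSampler, SequentialSampler

in_order = SequentialSampler()(train_data)   # [0, 1, 2, ..., len(train_data) - 1]
shuffled = RandomSampler()(train_data)       # a random permutation of the same indices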
