1
- from collections import Counter
2
-
3
1
import numpy as np
4
2
import torch
5
3
@@ -17,6 +15,56 @@ def convert_to_torch_tensor(data_list, use_cuda):
17
15
return data_list
18
16
19
17
18
class BaseSampler(object):
    """Common interface shared by every sampler.

    Concrete subclasses override __call__, which receives a DataSet
    object and returns a list of int — the indices in sampling order.
    """

    def __call__(self, *args, **kwargs):
        # Abstract hook: each sampler decides how indices are produced.
        raise NotImplementedError
27
+
28
+
29
class SequentialSampler(BaseSampler):
    """Yield sample indices in their original, unshuffled order."""

    def __call__(self, data_set):
        # Identity ordering: index i is visited i-th.
        total = len(data_set)
        return [i for i in range(total)]
36
+
37
+
38
class RandomSampler(BaseSampler):
    """Yield sample indices in a uniformly random permutation."""

    def __call__(self, data_set):
        # np.random.permutation draws a fresh permutation of range(n).
        order = np.random.permutation(len(data_set))
        return list(order)
45
+
46
+
47
def simple_sort_bucketing(lengths):
    """Order example indices by ascending sequence length.

    :param lengths: list of int, the lengths of all examples.
    :return: list of int, example indices sorted from shortest to longest
        example. Ties keep their original relative order (sorted is stable).

    NOTE: despite the name, this currently returns a single flat index list,
    not a 2-level bucket list.
    """
    # TODO: need to return buckets
    # Pair each index with its length, then sort on the length component.
    sorted_pairs = sorted(enumerate(lengths), key=lambda pair: pair[1])
    return [idx for idx, _ in sorted_pairs]
67
+
20
68
def k_means_1d (x , k , max_iter = 100 ):
21
69
"""Perform k-means on 1-D data.
22
70
@@ -46,18 +94,10 @@ def k_means_1d(x, k, max_iter=100):
46
94
return np .array (centroids ), assign
47
95
48
96
49
- def k_means_bucketing (all_inst , buckets ):
97
+ def k_means_bucketing (lengths , buckets ):
50
98
"""Assign all instances into possible buckets using k-means, such that instances in the same bucket have similar lengths.
51
99
52
- :param all_inst: 3-level list
53
- E.g. ::
54
-
55
- [
56
- [[word_11, word_12, word_13], [label_11. label_12]], # sample 1
57
- [[word_21, word_22, word_23], [label_21. label_22]], # sample 2
58
- ...
59
- ]
60
-
100
+ :param lengths: list of int, the length of all samples.
61
101
:param buckets: list of int. The length of the list is the number of buckets. Each integer is the maximum length
62
102
threshold for each bucket (This is usually None.).
63
103
:return data: 2-level list
@@ -72,7 +112,6 @@ def k_means_bucketing(all_inst, buckets):
72
112
"""
73
113
bucket_data = [[] for _ in buckets ]
74
114
num_buckets = len (buckets )
75
- lengths = np .array ([len (inst [0 ]) for inst in all_inst ])
76
115
_ , assignments = k_means_1d (lengths , num_buckets )
77
116
78
117
for idx , bucket_id in enumerate (assignments ):
@@ -81,102 +120,33 @@ def k_means_bucketing(all_inst, buckets):
81
120
return bucket_data
82
121
83
122
84
- class BaseSampler (object ):
85
- """The base class of all samplers.
86
-
87
- """
88
-
89
- def __call__ (self , * args , ** kwargs ):
90
- raise NotImplementedError
91
-
92
-
93
- class SequentialSampler (BaseSampler ):
94
- """Sample data in the original order.
95
-
96
- """
97
-
98
- def __call__ (self , data_set ):
99
- return list (range (len (data_set )))
100
-
101
-
102
- class RandomSampler (BaseSampler ):
103
- """Sample data in random permutation order.
104
-
105
- """
106
-
107
- def __call__ (self , data_set ):
108
- return list (np .random .permutation (len (data_set )))
109
-
110
-
111
-
112
- class Batchifier (object ):
113
- """Wrap random or sequential sampler to generate a mini-batch.
114
-
115
- """
116
-
117
- def __init__ (self , sampler , batch_size , drop_last = True ):
118
- """
119
-
120
- :param sampler: a Sampler object
121
- :param batch_size: int, the size of the mini-batch
122
- :param drop_last: bool, whether to drop the last examples that are not enough to make a mini-batch.
123
-
124
- """
125
- super (Batchifier , self ).__init__ ()
126
- self .sampler = sampler
127
- self .batch_size = batch_size
128
- self .drop_last = drop_last
129
-
130
- def __iter__ (self ):
131
- batch = []
132
- for example in self .sampler :
133
- batch .append (example )
134
- if len (batch ) == self .batch_size :
135
- yield batch
136
- batch = []
137
- if 0 < len (batch ) < self .batch_size and self .drop_last is False :
138
- yield batch
139
-
140
-
141
- class BucketBatchifier (Batchifier ):
123
class BucketSampler(BaseSampler):
    """Partition all samples into multiple buckets, each of which contains
    sentences of approximately the same length. In sampling, first randomly
    choose a bucket, then sample data from it.
    """

    def __call__(self, data_set, batch_size, num_buckets):
        return self._process(data_set, batch_size, num_buckets)

    def _process(self, data_set, batch_size, num_buckets, use_kmeans=False):
        """Produce a flat list of sampling indices, one batch at a time from
        a randomly chosen bucket.

        :param data_set: a DataSet object
        :param batch_size: int
        :param num_buckets: int, number of buckets for grouping these sequences.
        :param use_kmeans: bool, whether to use k-means to create buckets.
        :return: list of int, (len(data_set) // batch_size) * batch_size
            sampling indices.
        """
        if use_kmeans:
            # k_means_bucketing returns a 2-level list of index buckets.
            buckets = k_means_bucketing(data_set, [None] * num_buckets)
        else:
            # BUGFIX: simple_sort_bucketing returns a FLAT index list. Wrap it
            # as a single bucket so the selection loop below indexes a bucket,
            # not a lone int (the original crashed in np.random.shuffle).
            buckets = [simple_sort_bucketing(data_set)]
        index_list = []
        for _ in range(len(data_set) // batch_size):
            chosen_bucket = buckets[np.random.randint(0, len(buckets))]
            # In-place reshuffle so repeated picks of the same bucket vary.
            np.random.shuffle(chosen_bucket)
            index_list.extend(chosen_bucket[:batch_size])
        return index_list
0 commit comments