
Commit 267baec (1 parent: 720a264)
Author: yunfan
Commit message: add dataloader register

File tree: 5 files changed (+147, -26 lines)


fastNLP/core/__init__.py

Lines changed: 4 additions & 2 deletions

@@ -1,5 +1,5 @@
 from .batch import Batch
-from .dataset import DataSet
+# from .dataset import DataSet
 from .fieldarray import FieldArray
 from .instance import Instance
 from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward
@@ -8,4 +8,6 @@
 from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler
 from .tester import Tester
 from .trainer import Trainer
-from .vocabulary import Vocabulary
+from .vocabulary import Vocabulary
+from ..io.dataset_loader import DataSet
+
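Note: `from fastNLP.core import DataSet` still works after this change, but the class is now re-exported through `fastNLP.io.dataset_loader` instead of being imported from `core.dataset` directly, presumably so that the loader module (which registers the readers below) is imported before `DataSet` is used. A minimal sanity check, assuming the layout above:

    from fastNLP.core import DataSet
    from fastNLP.core.dataset import DataSet as CoreDataSet

    # io.dataset_loader itself does `from fastNLP.core.dataset import DataSet`,
    # so both names should resolve to the same class object
    assert DataSet is CoreDataSet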

fastNLP/core/dataset.py

Lines changed: 22 additions & 12 deletions

@@ -5,8 +5,7 @@
 from fastNLP.core.fieldarray import FieldArray
 from fastNLP.core.instance import Instance
 from fastNLP.core.utils import get_func_signature
-
-_READERS = {}
+from fastNLP.io.base_loader import DataLoaderRegister


 class DataSet(object):
@@ -98,6 +97,24 @@ def __getitem__(self, idx):
         else:
             raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))

+    def __getattr__(self, item):
+        if item == "field_arrays":
+            raise AttributeError
+        # TODO dataset.x
+        if item in self.field_arrays:
+            return self.field_arrays[item]
+        try:
+            reader = DataLoaderRegister.get_reader(item)
+            return reader
+        except AttributeError:
+            raise
+
+    def __setstate__(self, state):
+        self.__dict__ = state
+
+    def __getstate__(self):
+        return self.__dict__
+
     def __len__(self):
         """Fetch the length of the dataset.

@@ -226,16 +243,6 @@ def get_target_name(self):
         """
         return [name for name, field in self.field_arrays.items() if field.is_target]

-    @classmethod
-    def set_reader(cls, method_name):
-        assert isinstance(method_name, str)
-
-        def wrapper(read_cls):
-            _READERS[method_name] = read_cls
-            return read_cls
-
-        return wrapper
-
     def apply(self, func, new_field_name=None, **kwargs):
         """Apply a function to every instance of the DataSet.

@@ -347,6 +354,9 @@ def read_csv(cls, csv_path, headers=None, sep=",", dropna=True):
                 _dict[header].append(content)
         return cls(_dict)

+    # def read_pos(self):
+    #     return DataLoaderRegister.get_reader('read_pos')
+
     def save(self, path):
         """Save the DataSet object as pickle.
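The new `__getattr__` serves two purposes: attribute-style field access (`ds.words` returns `field_arrays["words"]`) and dispatch to any reader registered with `DataLoaderRegister`. The explicit `AttributeError` for `"field_arrays"`, together with the `__getstate__`/`__setstate__` pair, keeps pickling from recursing into `__getattr__` before `__dict__` is restored. A small usage sketch (the field name is illustrative; `DataSet` accepts a dict of lists, as `read_csv` above relies on):

    import pickle
    from fastNLP.core.dataset import DataSet

    ds = DataSet({"words": [["a", "b"], ["c"]]})
    print(ds.words.content)  # __getattr__ -> field_arrays["words"], a FieldArray

    # round-trips through pickle without infinite recursion in __getattr__
    ds2 = pickle.loads(pickle.dumps(ds))
    print(len(ds2))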

fastNLP/core/trainer.py

Lines changed: 4 additions & 4 deletions

@@ -85,8 +85,8 @@ def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch
         if metric_key is not None:
             self.increase_better = False if metric_key[0] == "-" else True
             self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
-        else:
-            self.metric_key = None
+        elif metrics is not None:
+            self.metric_key = metrics[0].__class__.__name__.lower().strip('metric')

         # prepare loss
         losser = _prepare_losser(loss)
@@ -147,7 +147,7 @@ def train(self):

         self._mode(self.model, is_test=False)

-        self.start_time = str(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
+        self.start_time = str(datetime.now().strftime('%Y-%m-%d %H-%M-%S'))
         print("training epochs started " + self.start_time, flush=True)
         if self.save_path is None:
             class psudoSW:
@@ -260,7 +260,7 @@ def _do_validation(self):
                 self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
                                                 global_step=self.step)
         if self.save_path is not None and self._better_eval_result(res):
-            metric_key = self.metric_key if self.metric_key is not None else "None"
+            metric_key = self.metric_key if self.metric_key is not None else ""
             self._save_model(self.model,
                              "best_" + "_".join([self.model.__class__.__name__, metric_key, self.start_time]))
         return res
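Two details worth noting here. The timestamp format change (`%H:%M:%S` to `%H-%M-%S`) makes `start_time` safe to embed in file names, as `_save_model` above does (colons are not allowed in Windows paths). And `strip('metric')` removes leading/trailing characters from the set {m, e, t, r, i, c}, not the literal suffix "metric": it gives the intended result for names like `AccuracyMetric`, but can over-strip. Illustrative only, not part of the commit:

    print("accuracymetric".strip('metric'))  # 'accuracy' -- intended
    print("spanfmetric".strip('metric'))     # 'spanf'    -- still fine
    print("cermetric".strip('metric'))       # ''         -- every character is in the strip set

Also, with the `else` branch gone, `self.metric_key` is never assigned when both `metric_key` and `metrics` are None; that is presumably safe only because validation, its sole consumer, requires metrics.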

fastNLP/io/base_loader.py

Lines changed: 36 additions & 0 deletions

@@ -29,3 +29,39 @@ def load_with_cache(cls, data_path, cache_path):
         with open(cache_path, 'wb') as f:
             pickle.dump(obj, f)
         return obj
+
+
+class ToyLoader0(BaseLoader):
+    """
+    For CharLM
+    """
+
+    def __init__(self, data_path):
+        super(ToyLoader0, self).__init__(data_path)
+
+    def load(self):
+        with open(self.data_path, 'r') as f:
+            corpus = f.read().lower()
+        import re
+        corpus = re.sub(r"<unk>", "unk", corpus)
+        return corpus.split()
+
+
+class DataLoaderRegister:
+    """register for data sets"""
+    _readers = {}
+
+    @classmethod
+    def set_reader(cls, reader_cls, read_fn_name):
+        if read_fn_name in cls._readers:
+            raise KeyError('duplicate reader: {} and {} for read_func: {}'.format(
+                cls._readers[read_fn_name], reader_cls, read_fn_name))
+        if hasattr(reader_cls, 'load'):
+            cls._readers[read_fn_name] = reader_cls().load
+        return reader_cls
+
+    @classmethod
+    def get_reader(cls, read_fn_name):
+        if read_fn_name in cls._readers:
+            return cls._readers[read_fn_name]
+        raise AttributeError('no read function: {}'.format(read_fn_name))
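`set_reader` stores the bound `load` method of a freshly constructed loader under the given name, so a registered class must have a no-argument constructor and is instantiated eagerly at registration time. A hedged sketch of the round trip (the loader class and file name are made up):

    from fastNLP.io.base_loader import DataLoaderRegister

    class ToyReader:
        def load(self, path):
            return path.upper()  # stand-in for real parsing

    DataLoaderRegister.set_reader(ToyReader, 'read_toy')
    reader = DataLoaderRegister.get_reader('read_toy')  # the bound ToyReader().load
    print(reader('some_file.txt'))  # 'SOME_FILE.TXT'

Registering a name twice raises KeyError; looking up an unknown name raises AttributeError, which is exactly what lets `DataSet.__getattr__` fall through cleanly for ordinary missing attributes.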

fastNLP/io/dataset_loader.py

Lines changed: 81 additions & 8 deletions

@@ -2,7 +2,7 @@

 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
-from fastNLP.io.base_loader import BaseLoader
+from fastNLP.io.base_loader import DataLoaderRegister


 def convert_seq_dataset(data):
@@ -61,12 +61,9 @@ def convert_seq2seq_dataset(data):
     return dataset


-class DataSetLoader(BaseLoader):
+class DataSetLoader:
     """loader for data sets"""

-    def __init__(self):
-        super(DataSetLoader, self).__init__()
-
     def load(self, path):
         """ load data in `path` into a dataset
         """
@@ -104,9 +101,9 @@ def load(self, data_path, split=None):

     def convert(self, data):
         return convert_seq_dataset(data)
+DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata')


-@DataSet.set_reader('read_pos')
 class POSDataSetLoader(DataSetLoader):
     """Dataset Loader for POS Tag datasets.

@@ -174,9 +171,9 @@ def convert(self, data):
         """Convert lists of strings into Instances with Fields.
         """
         return convert_seq2seq_dataset(data)
+DataLoaderRegister.set_reader(POSDataSetLoader, 'read_pos')


-@DataSet.set_reader('read_tokenize')
 class TokenizeDataSetLoader(DataSetLoader):
     """
     Data set loader for tokenization data sets

@@ -236,7 +233,6 @@ def convert(self, data):
         return convert_seq2seq_dataset(data)


-@DataSet.set_reader('read_class')
 class ClassDataSetLoader(DataSetLoader):
     """Loader for classification data sets"""

@@ -275,6 +271,83 @@ def convert(self, data):
         return convert_seq2tag_dataset(data)


+class ConllLoader(DataSetLoader):
+    """loader for conll format files"""
+
+    def __init__(self):
+        super(ConllLoader, self).__init__()
+
+    def load(self, data_path):
+        """
+        :param str data_path: the path to the conll data set
+        :return: list lines: all lines in a conll file
+        """
+        with open(data_path, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+        data = self.parse(lines)
+        return self.convert(data)
+
+    @staticmethod
+    def parse(lines):
+        """
+        :param list lines: a list containing all lines in a conll file
+        :return: a 3D list
+        """
+        sentences = list()
+        tokens = list()
+        for line in lines:
+            if line[0] == "#":
+                # skip the comments
+                continue
+            if line == "\n":
+                sentences.append(tokens)
+                tokens = []
+                continue
+            tokens.append(line.split())
+        return sentences
+
+    def convert(self, data):
+        pass
+
+
+class LMDataSetLoader(DataSetLoader):
+    """Language Model Dataset Loader
+
+    This loader produces data for language model training in a supervised way.
+    That means it has X and Y.
+    """
+
+    def __init__(self):
+        super(LMDataSetLoader, self).__init__()
+
+    def load(self, data_path):
+        if not os.path.exists(data_path):
+            raise FileNotFoundError("file {} not found.".format(data_path))
+        with open(data_path, "r", encoding="utf-8") as f:
+            text = " ".join(f.readlines())
+        tokens = text.strip().split()
+        data = self.sentence_cut(tokens)
+        return self.convert(data)
+
+    def sentence_cut(self, tokens, sentence_length=15):
+        start_idx = 0
+        data_set = []
+        for idx in range(len(tokens) // sentence_length):
+            x = tokens[start_idx * idx: start_idx * idx + sentence_length]
+            y = tokens[start_idx * idx + 1: start_idx * idx + sentence_length + 1]
+            if start_idx * idx + sentence_length + 1 >= len(tokens):
+                # ad hoc
+                y.extend(["<unk>"])
+            data_set.append([x, y])
+        return data_set
+
+    def convert(self, data):
+        pass
+
+
 @DataSet.set_reader('read_people_daily')
 class PeopleDailyCorpusLoader(DataSetLoader):
     """
