Skip to content

Commit db0a789

Browse files
committed
* final clean up
* remove conflicts
* all tests passed
1 parent 267baec commit db0a789

File tree

6 files changed

+38
-26
lines changed

6 files changed

+38
-26
lines changed

fastNLP/core/dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,10 @@ def __getitem__(self, idx):
9898
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))
9999

100100
def __getattr__(self, item):
101+
# Not tested. Don't use !!
101102
if item == "field_arrays":
102103
raise AttributeError
103-
# TODO dataset.x
104-
if item in self.field_arrays:
104+
if isinstance(item, str) and item in self.field_arrays:
105105
return self.field_arrays[item]
106106
try:
107107
reader = DataLoaderRegister.get_reader(item)

fastNLP/core/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch
8585
if metric_key is not None:
8686
self.increase_better = False if metric_key[0] == "-" else True
8787
self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
88-
elif metrics is not None:
88+
elif len(metrics) > 0:
8989
self.metric_key = metrics[0].__class__.__name__.lower().strip('metric')
9090

9191
# prepare loss

fastNLP/io/base_loader.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,22 +31,6 @@ def load_with_cache(cls, data_path, cache_path):
3131
return obj
3232

3333

34-
class ToyLoader0(BaseLoader):
35-
"""
36-
For CharLM
37-
"""
38-
39-
def __init__(self, data_path):
40-
super(ToyLoader0, self).__init__(data_path)
41-
42-
def load(self):
43-
with open(self.data_path, 'r') as f:
44-
corpus = f.read().lower()
45-
import re
46-
corpus = re.sub(r"<unk>", "unk", corpus)
47-
return corpus.split()
48-
49-
5034
class DataLoaderRegister:
5135
"""Register for data sets."""
5236
_readers = {}

fastNLP/io/dataset_loader.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ def convert(self, data):
7575
raise NotImplementedError
7676

7777

78-
@DataSet.set_reader("read_naive")
7978
class NativeDataSetLoader(DataSetLoader):
8079
def __init__(self):
8180
super(NativeDataSetLoader, self).__init__()
@@ -87,7 +86,9 @@ def load(self, path):
8786
return ds
8887

8988

90-
@DataSet.set_reader('read_raw')
89+
DataLoaderRegister.set_reader(NativeDataSetLoader, 'read_naive')
90+
91+
9192
class RawDataSetLoader(DataSetLoader):
9293
def __init__(self):
9394
super(RawDataSetLoader, self).__init__()
@@ -101,6 +102,8 @@ def load(self, data_path, split=None):
101102

102103
def convert(self, data):
103104
return convert_seq_dataset(data)
105+
106+
104107
DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata')
105108

106109

@@ -171,6 +174,8 @@ def convert(self, data):
171174
"""Convert lists of strings into Instances with Fields.
172175
"""
173176
return convert_seq2seq_dataset(data)
177+
178+
174179
DataLoaderRegister.set_reader(POSDataSetLoader, 'read_pos')
175180

176181

@@ -348,7 +353,6 @@ def convert(self, data):
348353
pass
349354

350355

351-
@DataSet.set_reader('read_people_daily')
352356
class PeopleDailyCorpusLoader(DataSetLoader):
353357
"""
354358
People Daily Corpus: Chinese word segmentation, POS tag, NER

test/core/test_dataset.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,20 @@ def test_get_field(self):
178178
self.assertTrue(isinstance(ans, FieldArray))
179179
self.assertEqual(ans.content, [[5, 6]] * 10)
180180

181+
def test_reader(self):
182+
# Smoke test: it is enough that the readers run without raising.
183+
ds = DataSet().read_naive("test/data_for_tests/tutorial_sample_dataset.csv")
184+
self.assertTrue(isinstance(ds, DataSet))
185+
self.assertTrue(len(ds) > 0)
186+
187+
ds = DataSet().read_rawdata("test/data_for_tests/people_daily_raw.txt")
188+
self.assertTrue(isinstance(ds, DataSet))
189+
self.assertTrue(len(ds) > 0)
190+
191+
ds = DataSet().read_pos("test/data_for_tests/people.txt")
192+
self.assertTrue(isinstance(ds, DataSet))
193+
self.assertTrue(len(ds) > 0)
194+
181195

182196
class TestDataSetIter(unittest.TestCase):
183197
def test__repr__(self):

test/core/test_optimizer.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
class TestOptim(unittest.TestCase):
99
def test_SGD(self):
10-
optim = SGD(torch.nn.Linear(10, 3).parameters())
10+
optim = SGD(model_params=torch.nn.Linear(10, 3).parameters())
1111
self.assertTrue("lr" in optim.__dict__["settings"])
1212
self.assertTrue("momentum" in optim.__dict__["settings"])
1313
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
@@ -22,13 +22,18 @@ def test_SGD(self):
2222
self.assertEqual(optim.__dict__["settings"]["lr"], 0.002)
2323
self.assertEqual(optim.__dict__["settings"]["momentum"], 0.989)
2424

25-
with self.assertRaises(RuntimeError):
25+
optim = SGD(0.001)
26+
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
27+
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
28+
self.assertTrue(isinstance(res, torch.optim.SGD))
29+
30+
with self.assertRaises(TypeError):
2631
_ = SGD("???")
27-
with self.assertRaises(RuntimeError):
32+
with self.assertRaises(TypeError):
2833
_ = SGD(0.001, lr=0.002)
2934

3035
def test_Adam(self):
31-
optim = Adam(torch.nn.Linear(10, 3).parameters())
36+
optim = Adam(model_params=torch.nn.Linear(10, 3).parameters())
3237
self.assertTrue("lr" in optim.__dict__["settings"])
3338
self.assertTrue("weight_decay" in optim.__dict__["settings"])
3439
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
@@ -42,3 +47,8 @@ def test_Adam(self):
4247
optim = Adam(lr=0.002, weight_decay=0.989)
4348
self.assertEqual(optim.__dict__["settings"]["lr"], 0.002)
4449
self.assertEqual(optim.__dict__["settings"]["weight_decay"], 0.989)
50+
51+
optim = Adam(0.001)
52+
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
53+
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
54+
self.assertTrue(isinstance(res, torch.optim.Adam))

0 commit comments

Comments
 (0)