Skip to content

Commit 50e18f8

Browse files
authored
Merge pull request #117 from FengZiYjun/fix-doc
[doc] Improve Documentation
2 parents dfb62ec + 07c2b87 commit 50e18f8

File tree

15 files changed

+209
-173
lines changed

15 files changed

+209
-173
lines changed

docs/source/conf.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,16 @@
1616
import sys
1717
sys.path.insert(0, os.path.abspath('../../'))
1818

19-
import sphinx_rtd_theme
20-
2119
# -- Project information -----------------------------------------------------
2220

2321
project = 'fastNLP'
2422
copyright = '2018, xpqiu'
2523
author = 'xpqiu'
2624

2725
# The short X.Y version
28-
version = ''
26+
version = '0.2'
2927
# The full version, including alpha/beta/rc tags
30-
release = '1.0'
28+
release = '0.2'
3129

3230

3331
# -- General configuration ---------------------------------------------------
511 Bytes
Loading

fastNLP/core/batch.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,19 @@
55
class Batch(object):
66
"""Batch is an iterable object which iterates over mini-batches.
77
8-
::
9-
for batch_x, batch_y in Batch(data_set, batch_size=16, sampler=SequentialSampler()):
8+
Example::
109
10+
for batch_x, batch_y in Batch(data_set, batch_size=16, sampler=SequentialSampler()):
11+
# ...
12+
13+
:param dataset: a DataSet object
14+
:param batch_size: int, the size of the batch
15+
:param sampler: a Sampler object
16+
:param as_numpy: bool. If True, return NumPy arrays. Otherwise, return torch tensors.
1117
1218
"""
1319

1420
def __init__(self, dataset, batch_size, sampler, as_numpy=False):
15-
"""
16-
17-
:param dataset: a DataSet object
18-
:param batch_size: int, the size of the batch
19-
:param sampler: a Sampler object
20-
:param as_numpy: bool. If True, return Numpy array. Otherwise, return torch tensors.
21-
22-
"""
2321
self.dataset = dataset
2422
self.batch_size = batch_size
2523
self.sampler = sampler

fastNLP/core/dataset.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def __getstate__(self):
118118
def __len__(self):
119119
"""Fetch the length of the dataset.
120120
121-
:return int length:
121+
:return length:
122122
"""
123123
if len(self.field_arrays) == 0:
124124
return 0
@@ -170,7 +170,7 @@ def add_field(self, name, fields, padding_val=0, is_input=False, is_target=False
170170
def delete_field(self, name):
171171
"""Delete a field based on the field name.
172172
173-
:param str name: the name of the field to be deleted.
173+
:param name: the name of the field to be deleted.
174174
"""
175175
self.field_arrays.pop(name)
176176

@@ -182,14 +182,14 @@ def get_field(self, field_name):
182182
def get_all_fields(self):
183183
"""Return all the fields with their names.
184184
185-
:return dict field_arrays: the internal data structure of DataSet.
185+
:return field_arrays: the internal data structure of DataSet.
186186
"""
187187
return self.field_arrays
188188

189189
def get_length(self):
190190
"""Fetch the length of the dataset.
191191
192-
:return int length:
192+
:return length:
193193
"""
194194
return len(self)
195195

@@ -232,14 +232,14 @@ def set_input(self, *field_name, flag=True):
232232
def get_input_name(self):
233233
"""Get all field names with `is_input` as True.
234234
235-
:return list field_names: a list of str
235+
:return field_names: a list of str
236236
"""
237237
return [name for name, field in self.field_arrays.items() if field.is_input]
238238

239239
def get_target_name(self):
240240
"""Get all field names with `is_target` as True.
241241
242-
:return list field_names: a list of str
242+
:return field_names: a list of str
243243
"""
244244
return [name for name, field in self.field_arrays.items() if field.is_target]
245245

@@ -294,8 +294,9 @@ def split(self, dev_ratio):
294294
"""Split the dataset into training and development(validation) set.
295295
296296
:param float dev_ratio: the ratio of the development set in all data.
297-
:return DataSet train_set: the training set
298-
DataSet dev_set: the development set
297+
:return (train_set, dev_set):
298+
train_set: the training set
299+
dev_set: the development set
299300
"""
300301
assert isinstance(dev_ratio, float)
301302
assert 0 < dev_ratio < 1
@@ -326,7 +327,7 @@ def read_csv(cls, csv_path, headers=None, sep=",", dropna=True):
326327
:param List[str] or Tuple[str] headers: headers of the CSV file
327328
:param str sep: delimiter in CSV file. Default: ","
328329
:param bool dropna: If True, drop rows that have fewer entries than headers.
329-
:return DataSet dataset:
330+
:return dataset: the read data set
330331
331332
"""
332333
with open(csv_path, "r") as f:
@@ -370,7 +371,7 @@ def load(path):
370371
"""Load a DataSet object from pickle.
371372
372373
:param str path: the path to the pickle
373-
:return DataSet data_set:
374+
:return data_set:
374375
"""
375376
with open(path, 'rb') as f:
376377
return pickle.load(f)

fastNLP/core/fieldarray.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,18 @@
22

33

44
class FieldArray(object):
5-
"""FieldArray is the collection of Instances of the same Field.
6-
It is the basic element of DataSet class.
5+
"""``FieldArray`` is the collection of ``Instance`` objects of the same field.
6+
It is the basic element of ``DataSet`` class.
7+
8+
:param str name: the name of the FieldArray
9+
:param list content: a list of int, float, str or np.ndarray, or a list of list of one, or a np.ndarray.
10+
:param int padding_val: the integer for padding. Default: 0.
11+
:param bool is_target: If True, this FieldArray is used to compute loss.
12+
:param bool is_input: If True, this FieldArray is used as the model input.
713
814
"""
915

1016
def __init__(self, name, content, padding_val=0, is_target=None, is_input=None):
11-
"""
12-
13-
:param str name: the name of the FieldArray
14-
:param list content: a list of int, float, str or np.ndarray, or a list of list of one, or a np.ndarray.
15-
:param int padding_val: the integer for padding. Default: 0.
16-
:param bool is_target: If True, this FieldArray is used to compute loss.
17-
:param bool is_input: If True, this FieldArray is used to the model input.
18-
"""
1917
self.name = name
2018
if isinstance(content, list):
2119
content = content

fastNLP/core/instance.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
11
class Instance(object):
2-
"""An Instance is an example of data. It is the collection of Fields.
2+
"""An Instance is an example of data.
3+
Example::
4+
ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2])
5+
ins["field_1"]
6+
>>[1, 1, 1]
7+
ins.add_field("field_3", [3, 3, 3])
38
4-
::
5-
Instance(field_1=[1, 1, 1], field_2=[2, 2, 2])
9+
:param fields: a dict of (str: list).
610
711
"""
812

913
def __init__(self, **fields):
10-
"""
11-
12-
:param fields: a dict of (str: list).
13-
"""
1414
self.fields = fields
1515

1616
def add_field(self, field_name, field):
1717
"""Add a new field to the instance.
1818
1919
:param field_name: str, the name of the field.
20-
:param field:
2120
"""
2221
self.fields[field_name] = field
2322

fastNLP/core/losses.py

Lines changed: 25 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313

1414

1515
class LossBase(object):
16+
"""Base class for all losses.
17+
18+
"""
1619
def __init__(self):
1720
self.param_map = {}
1821
self._checked = False
@@ -68,10 +71,9 @@ def _init_param_map(self, key_map=None, **kwargs):
6871
# f"positional argument.).")
6972

7073
def _fast_param_map(self, pred_dict, target_dict):
71-
"""
72-
73-
Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map.
74+
"""Only used as an inner function. When pred_dict and target_dict are unequivocal, users do not need to pass key_map,
7475
for example when pred_dict has one element and target_dict has one element.
76+
7577
:param pred_dict:
7678
:param target_dict:
7779
:return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping.
@@ -265,27 +267,22 @@ def _prepare_losser(losser):
265267

266268

267269
def squash(predict, truth, **kwargs):
268-
"""To reshape tensors in order to fit loss functions in pytorch
269-
270-
:param predict : Tensor, model output
271-
:param truth : Tensor, truth from dataset
272-
:param **kwargs : extra arguments
270+
"""To reshape tensors in order to fit loss functions in PyTorch.
273271
272+
:param predict: Tensor, model output
273+
:param truth: Tensor, truth from dataset
274+
:param **kwargs: extra arguments
274275
:return predict , truth: predict & truth after processing
275276
"""
276277
return predict.view(-1, predict.size()[-1]), truth.view(-1, )
277278

278279

279280
def unpad(predict, truth, **kwargs):
280-
"""To process padded sequence output to get true loss
281-
Using pack_padded_sequence() method
282-
This method contains squash()
281+
"""To process padded sequence output to get true loss.
283282
284-
:param predict : Tensor, [batch_size , max_len , tag_size]
285-
:param truth : Tensor, [batch_size , max_len]
286-
:param **kwargs : extra arguments, kwargs["lens"] is expected to be exsist
287-
kwargs["lens"] : list or LongTensor, [batch_size]
288-
the i-th element is true lengths of i-th sequence
283+
:param predict: Tensor, [batch_size , max_len , tag_size]
284+
:param truth: Tensor, [batch_size , max_len]
285+
:param kwargs: kwargs["lens"] is a list or LongTensor, with size [batch_size]. The i-th element is the true length of the i-th sequence.
289286
290287
:return predict , truth: predict & truth after processing
291288
"""
@@ -299,15 +296,11 @@ def unpad(predict, truth, **kwargs):
299296

300297

301298
def unpad_mask(predict, truth, **kwargs):
302-
"""To process padded sequence output to get true loss
303-
Using mask() method
304-
This method contains squash()
299+
"""To process padded sequence output to get true loss.
305300
306-
:param predict : Tensor, [batch_size , max_len , tag_size]
307-
:param truth : Tensor, [batch_size , max_len]
308-
:param **kwargs : extra arguments, kwargs["lens"] is expected to be exsist
309-
kwargs["lens"] : list or LongTensor, [batch_size]
310-
the i-th element is true lengths of i-th sequence
301+
:param predict: Tensor, [batch_size , max_len , tag_size]
302+
:param truth: Tensor, [batch_size , max_len]
303+
:param kwargs: kwargs["lens"] is a list or LongTensor, with size [batch_size]. The i-th element is the true length of the i-th sequence.
311304
312305
:return predict , truth: predict & truth after processing
313306
"""
@@ -318,14 +311,11 @@ def unpad_mask(predict, truth, **kwargs):
318311

319312

320313
def mask(predict, truth, **kwargs):
321-
"""To select specific elements from Tensor
322-
This method contains squash()
314+
"""To select specific elements from Tensor. This method calls ``squash()``.
323315
324-
:param predict : Tensor, [batch_size , max_len , tag_size]
325-
:param truth : Tensor, [batch_size , max_len]
326-
:param **kwargs : extra arguments, kwargs["mask"] is expected to be exsist
327-
kwargs["mask"] : ByteTensor, [batch_size , max_len]
328-
the mask Tensor , the position that is 1 will be selected
316+
:param predict: Tensor, [batch_size , max_len , tag_size]
317+
:param truth: Tensor, [batch_size , max_len]
318+
:param **kwargs: extra arguments, kwargs["mask"]: ByteTensor, [batch_size , max_len], the mask Tensor. The position that is 1 will be selected.
329319
330320
:return predict , truth: predict & truth after processing
331321
"""
@@ -343,13 +333,11 @@ def mask(predict, truth, **kwargs):
343333

344334

345335
def make_mask(lens, tar_len):
346-
"""to generate a mask that select [:lens[i]] for i-th element
347-
embezzle from fastNLP.models.sequence_modeling.seq_mask
348-
349-
:param lens : list or LongTensor, [batch_size]
350-
:param tar_len : int
336+
"""To generate a mask over a sequence.
351337
352-
:return mask : ByteTensor
338+
:param lens: list or LongTensor, [batch_size]
339+
:param tar_len: int
340+
:return mask: ByteTensor
353341
"""
354342
lens = torch.LongTensor(lens)
355343
mask = [torch.ge(lens, i + 1) for i in range(tar_len)]

0 commit comments

Comments
 (0)