Skip to content

Commit 6ef8a6f

Browse files
committed
doc(model): minor comment clarifications
1 parent d1d82e0 commit 6ef8a6f

File tree

5 files changed

+45
-21
lines changed

notebooks/src/code/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ class SageMakerTrainingArguments(TrainingArguments):
7474
metadata={"help": "The evaluation strategy to use."},
7575
)
7676
save_strategy: IntervalStrategy = field(
77-
# We'd like some eval metrics by default, rather than the usual "no" strategy
77+
# Should match evaluation strategy for early stopping to work
7878
default="epoch",
7979
metadata={"help": "The model save strategy to use."},
8080
)

notebooks/src/code/data/base.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
@dataclass
3232
class TaskData:
33-
"""Base interface exposed by the different task types (MLM, NER, etc) to training scripts
33+
"""Base data interface exposed by the different task types (MLM, NER, etc) to training scripts
3434
3535
Each new task module should implement a method get_task(data_args, tokenizer) -> TaskData
3636
"""
@@ -70,7 +70,13 @@ def split(
7070

7171

7272
class NaiveExampleSplitter(ExampleSplitterBase):
73-
"""Split sequences by word, and pull final sequence start forward to fill max allowable length"""
73+
"""Split sequences by word, and pull final sequence start forward if it comes up <50% max len
74+
75+
This algorithm produces examples by splitting tokens on word boundaries, extending each sample
76+
until max_content_seq_len is filled. *IF* the final generated example is less than 50% of the
77+
maximum tokens, its start index will be pulled forward to consume as many words as will fit.
78+
Apart from this, there will be no overlap between examples.
79+
"""
7480

7581
@classmethod
7682
def n_examples(cls, n_tokens: int, max_content_seq_len: int) -> int:
@@ -143,7 +149,15 @@ def split(
143149

144150

145151
class TextractLayoutLMDatasetBase(Dataset):
146-
"""Base class for PyTorch/Hugging Face dataset using Amazon Textract for LayoutLM-based models"""
152+
"""Base class for PyTorch/Hugging Face dataset using Amazon Textract for LayoutLM-based models
153+
154+
The base dataset assumes fixed/known length, which typically requires analyzing the source data
155+
on init - but avoids the complications of shuffling iterable dataset samples in a multi-process
156+
environment, or introducing SageMaker Pipe Mode and RecordIO formats.
157+
158+
Source data is provided as a folder of Amazon Textract result JSONs, with an optional JSONLines
159+
manifest file annotating the documents in case the task is supervised.
160+
"""
147161

148162
def __init__(
149163
self,
@@ -286,8 +300,8 @@ def max_content_seq_len(self):
286300
class DummyDataCollator:
287301
"""Data collator that just stacks tensors from inputs.
288302
289-
For use with Dataset classes where the leg-work is already done and HF's default
290-
"DataCollatorWithPadding" should explicitly *not* be used.
303+
For use with Dataset classes where the tokenization and collation leg-work is already done and
304+
HF's default "DataCollatorWithPadding" should explicitly *not* be used.
291305
"""
292306

293307
def __call__(self, features):

notebooks/src/code/data/geometry.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
class AnnotationBoundingBox:
1515
"""Class to parse a bounding box annotated by SageMaker Ground Truth Object Detection
1616
17-
Calculates all box TLHWBR metrics (both absolute and relative) on init, for efficient and easy
18-
processing later.
17+
Pre-calculates all box TLHWBR metrics (both absolute and relative) on init, for efficient and
18+
easy processing later.
1919
"""
2020

2121
def __init__(self, manifest_box: dict, image_height: int, image_width: int):
@@ -90,6 +90,14 @@ class BoundingBoxAnnotationResult:
9090
"""Class to parse the result field saved by a SageMaker Ground Truth Object Detection job"""
9191

9292
def __init__(self, manifest_obj: dict):
93+
"""Initialize a BoundingBoxAnnotationResult
94+
95+
Arguments
96+
---------
97+
manifest_obj : dict
98+
The contents of the output field of a record in a SMGT Object Detection labelling job
99+
output manifest, or equivalent.
100+
"""
93101
try:
94102
image_size_spec = manifest_obj["image_size"][0]
95103
self._image_height = int(image_size_spec["height"])
@@ -101,9 +109,9 @@ def __init__(self, manifest_obj: dict):
101109
raise ValueError(
102110
"".join(
103111
(
104-
"manifest_obj must be a dictionary including 'image_size': a list of length 1 ",
105-
"whose first/only element is a dict with integer properties 'height' and ",
106-
f"'width', optionally also 'depth'. Got: {manifest_obj}",
112+
"manifest_obj must be a dictionary including 'image_size': a list of ",
113+
"length 1 whose first/only element is a dict with integer properties ",
114+
f"'height' and 'width', optionally also 'depth'. Got: {manifest_obj}",
107115
)
108116
)
109117
) from e

notebooks/src/code/data/mlm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ class TextractLayoutLMDataCollatorForLanguageModelling(DataCollatorForLanguageMo
3535
"""Collator to process (batches of) Examples into batched model inputs
3636
3737
For this case, tokenization can happen at the batch level which allows us to pad to the longest
38-
sample in batch rather than the overall model max_seq_len. Word splitting is already done by
39-
Textract, and some custom logic is required to feed through the bounding box inputs from
40-
Textract (at word level) to the model inputs (at token level).
38+
sample in batch rather than the overall model max_seq_len - for efficiency. Word splitting is
39+
already done by Textract, and some custom logic is required to feed through the bounding box
40+
inputs from Textract (at word level) to the model inputs (at token level).
4141
"""
4242

4343
bos_token_box: Tuple[int, int, int, int] = (0, 0, 0, 0)

notebooks/util/training.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@ def get_hf_metric_regex(metric_name: str) -> str:
1010
{'eval_loss': 0.3940396010875702, ..., 'epoch': 1.0}
1111
"""
1212
scientific_number_exp = r"(-?[0-9]+(\.[0-9]+)?(e[+\-][0-9]+)?)"
13-
return "".join((
14-
"'",
15-
metric_name,
16-
"': ",
17-
scientific_number_exp,
18-
"[,}]",
19-
))
13+
return "".join(
14+
(
15+
"'",
16+
metric_name,
17+
"': ",
18+
scientific_number_exp,
19+
"[,}]",
20+
)
21+
)

0 commit comments

Comments
 (0)