## Feature or Model Request
### What problem are you trying to solve?
The current data pipeline (`TokenizeAndTrim`) simply truncates any text sequence that exceeds `max_target_length`. This causes all data past this limit in long documents (e.g., from large corpora like C4) to be completely lost during training, which is a highly inefficient use of data.
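For reference, a minimal sketch of the truncation pattern described above, assuming a `grain.MapTransform` subclass and a tokenizer that exposes `encode`; the class name, feature keys, and constructor are illustrative, not the actual MaxText implementation:

```python
import grain.python as grain


class TruncatingTokenizer(grain.MapTransform):
  """Illustrative 1:1 transform mirroring the truncation described above."""

  def __init__(self, tokenizer, max_target_length: int):
    self.tokenizer = tokenizer  # assumed to expose .encode(str) -> list[int]
    self.max_target_length = max_target_length

  def map(self, element: dict) -> dict:
    tokens = self.tokenizer.encode(element["text"])
    # Everything past max_target_length is silently dropped here.
    element["targets"] = tokens[: self.max_target_length]
    return element
```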
### Why is this problem important?
Data loss translates directly into wasted compute and storage. It also risks biasing the model by training only on the beginning of every document, which reduces the model's opportunity to learn from long-context information. It is far more efficient to utilize 100% of the purchased or preprocessed data.
### Describe your requested feature or solution.
- Implement a new class, `TokenizeAndChunk`, that inherits from `grain.experimental.FlatMapTransform` (which supports 1:N mappings), as an alternative to the existing `TokenizeAndTrim` (a `grain.MapTransform`).
- Instead of truncating the token list (i.e., `[:seq_len]`), the new transform should split the entire token sequence into multiple chunks of up to `max_target_length` tokens each and return them as a list (`[chunk_1, chunk_2, ...]`); see the sketch after this list.
- Add a boolean configuration flag, `use_truncation`, to control this behavior:
  - When `use_truncation = True`, the pipeline uses the existing `TokenizeAndTrim` (`MapTransform`) to truncate sequences.
  - When `use_truncation = False`, the pipeline uses the new `TokenizeAndChunk` (`FlatMapTransform`) to split sequences into multiple examples.
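A minimal sketch of the proposed transform, assuming grain's experimental `FlatMapTransform` interface (which, as I understand it, requires a `max_fan_out` bound on outputs per input); the field names, feature keys, and the `max_fan_out` value are assumptions:

```python
import dataclasses
from typing import Any

import grain.python as grain


@dataclasses.dataclass
class TokenizeAndChunk(grain.experimental.FlatMapTransform):
  """Sketch of the proposed 1:N transform; no tokens are discarded."""

  tokenizer: Any  # assumed to expose .encode(str) -> list[int]
  max_target_length: int
  max_fan_out: int = 128  # assumed upper bound on chunks per document

  def flat_map(self, element: dict) -> list[dict]:
    tokens = self.tokenizer.encode(element["text"])
    # Split into consecutive windows of max_target_length tokens; the
    # final chunk may be shorter ([chunk_1, chunk_2, ...]).
    return [
        {**element, "targets": tokens[start : start + self.max_target_length]}
        for start in range(0, len(tokens), self.max_target_length)
    ]
```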
### Describe alternatives you've considered (if any).
The only alternative is the current "truncation" behavior, which, as noted, is data-inefficient.
### Additional context or examples.
For example, if `max_target_length = 512` and an input document tokenizes to 1200 tokens:
- Current (`use_truncation = True`): produces one example of 512 tokens (688 tokens are lost).
- Proposed (`use_truncation = False`): produces three examples of 512, 512, and 176 tokens (0 tokens are lost).

This implementation requires using `dataset.apply(TokenizeAndChunk(...))` in the pipeline (when the flag is false) instead of `dataset.map(TokenizeAndTrim(...))`.
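Putting it together, hypothetical pipeline wiring for the flag, following the `dataset.map(...)` / `dataset.apply(...)` calls suggested above; the `config` object and the transform constructor signatures are assumptions for illustration:

```python
def add_tokenization(dataset, tokenizer, config):
  """Selects between the existing 1:1 trim and the proposed 1:N chunking."""
  if config.use_truncation:
    # Existing behavior: truncate each document to max_target_length.
    return dataset.map(TokenizeAndTrim(tokenizer, config.max_target_length))
  # Proposed behavior: split each document into multiple examples.
  return dataset.apply(TokenizeAndChunk(tokenizer, config.max_target_length))
```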
### Additional Context
- AI GDE and Kakao