
[feature request] Support chunking (splitting) long sequences instead of truncation during tokenization #2344

@bzantium

Description

Feature or Model Request

  • What problem are you trying to solve?
    The current data pipeline (TokenizeAndTrim) simply truncates any text sequence that exceeds max_target_length. This causes all data past this limit in long documents (e.g., from large corpora like C4) to be completely lost during training. This is a highly inefficient use of data.

  • Why is this problem important?
    Data loss leads directly to wasted compute resources and storage. Furthermore, it risks biasing the model by only training on the beginning of every document, which reduces the model's opportunity to learn from long-context information. It is far more efficient to utilize 100% of the purchased or preprocessed data.

  • Describe your requested feature or solution.

    1. Implement a new class, TokenizeAndChunk, that inherits from grain.experimental.FlatMapTransform (which supports 1:N mappings), as an alternative to the existing TokenizeAndTrim (a grain.MapTransform); a minimal sketch is included after this list.
    2. Instead of truncating the token list (i.e., [:seq_len]), this new transform should split the entire token sequence into multiple chunks, each of max_target_length, and return them as a list ([chunk_1, chunk_2, ...]).
    3. Add a boolean configuration flag, use_truncation, to control this behavior:
      • When use_truncation = True, the pipeline will use the existing TokenizeAndTrim (MapTransform) to truncate sequences.
      • When use_truncation = False, the pipeline will use the new TokenizeAndChunk (FlatMapTransform) to split sequences into multiple examples.
  • Describe alternatives you’ve considered (if any).
    The only alternative is the current "truncation" behavior, which, as noted, is data-inefficient.

  • Additional context or examples.
    For example, if max_target_length = 512 and an input document tokenizes to 1200 tokens:

    • Current (use_truncation = True): Produces one example of 512 tokens. (688 tokens are lost).
    • Proposed (use_truncation = False): Produces three examples: two of 512 tokens and one of 176 tokens. (0 tokens are lost.)

    This implementation requires using dataset.apply(TokenizeAndChunk(...)) in the pipeline (when the flag is false) instead of dataset.map(TokenizeAndTrim(...)); see the wiring sketch below.
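
A minimal sketch of what the proposed transform could look like. This is illustrative only: the constructor arguments (column, max_length, tokenizer) are placeholders that would need to match the existing TokenizeAndTrim signature, the tokenizer is assumed to expose an encode(str) -> list[int] method, and the max_fan_out bound is an assumption about what FlatMapTransform requires.

```python
import dataclasses
from typing import Any

import grain


@dataclasses.dataclass
class TokenizeAndChunk(grain.experimental.FlatMapTransform):
  """Tokenizes a text column and splits it into max_length-sized chunks (1:N)."""

  column: str          # name of the text feature to tokenize (placeholder)
  max_length: int      # chunk size, i.e. max_target_length
  tokenizer: Any       # assumed to expose encode(str) -> list[int]
  # Upper bound on how many chunks a single document may produce; assumed to be
  # required by FlatMapTransform, and the value here is only an example.
  max_fan_out: int = 128

  def flat_map(self, element: dict[str, Any]) -> list[dict[str, Any]]:
    token_ids = self.tokenizer.encode(element[self.column])
    chunks = []
    # Emit one output example per max_length-sized window; the last chunk may be shorter.
    for start in range(0, len(token_ids), self.max_length):
      chunk = dict(element)
      chunk[self.column] = token_ids[start:start + self.max_length]
      chunks.append(chunk)
    return chunks
```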
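
And a rough sketch of how the use_truncation flag could select between the two transforms in the input pipeline, following the dataset.map / dataset.apply usage described above. train_ds, config.use_truncation, and the constructor arguments are placeholders, not the exact names used in the repo.

```python
text_column = "text"  # placeholder feature name

if config.use_truncation:
  # Current behavior: 1:1 map that truncates each document to max_target_length.
  train_ds = train_ds.map(
      TokenizeAndTrim(column=text_column,
                      max_length=config.max_target_length,
                      tokenizer=tokenizer))
else:
  # Proposed behavior: 1:N flat map that keeps every token by emitting chunks.
  train_ds = train_ds.apply(
      TokenizeAndChunk(column=text_column,
                       max_length=config.max_target_length,
                       tokenizer=tokenizer))
```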

Additional Context

  • AI GDE and Kakao
