
Add dataset type olmo_grain for AI2 OLMo numpy pretrain mixes#3749

Open
gagika wants to merge 1 commit into main from gagik-olmo-data

Conversation

Collaborator

@gagika gagika commented Apr 26, 2026

Description

  • Adds dataset_type=olmo_grain, a Grain-based input pipeline for AI2's
    pre-tokenized OLMo numpy mixes (e.g. OLMo-mix-0925-official.txt).
    Reads headerless .npy token streams from a gcsfuse mount, applies
    OLMo-core's repeated-n-gram filter, and yields the shapes the MaxText
    pretrain trainer expects.
  • Stateless sampler: record at step k is a pure function of
    (seed, shard, k). Resume reads the latest step from
    config.checkpoint_dir and shifts the sampler — no Grain iterator
    state in the checkpoint.
  • Ships two data tools (download_olmo_data_to_gcs.py with HTTP-Range
    resume; build_olmo_npy_index.py for header-scan indexing) and two
    launchers (run_olmo3_7b_grain_smoke.sh,
    run_olmo3_7b_grain_resume_test.sh).
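The stateless-sampler contract above ("record at step k is a pure function of (seed, shard, k)") could be sketched as follows. This is a hypothetical illustration, not the PR's code; the names `record_index`, `num_shards`, and `total_instances` are assumptions:

```python
# Sketch of a stateless sampler: the record served at step k is a pure
# function of (seed, shard, k). Hypothetical; not the PR's implementation.
import numpy as np

def record_index(seed: int, shard: int, step: int,
                 num_shards: int, total_instances: int) -> int:
  """Map (seed, shard, step) to a deterministic record index."""
  epoch_len = total_instances // num_shards        # records per shard per epoch
  epoch, pos = divmod(step, epoch_len)
  # Per-epoch permutation keyed only by (seed, epoch): reproducible on resume.
  rng = np.random.default_rng((seed, epoch))
  perm = rng.permutation(epoch_len)
  return shard * epoch_len + int(perm[pos])
```

Because the same inputs always yield the same record, resuming at step k needs no serialized iterator state, only k itself (recovered from the checkpoint step).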

Tests

  • Unit tests pass (tests/unit/input_pipeline/olmo_*)
  • Smoke train: 50 steps, loss 11.99 → 8.93 on v4-8 (4-layer bf16)
  • Resume test: Run B picks up at step 50 with loss continuity 8.931 → 8.930

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

authored-by: @aireenmei

@github-actions

🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

🤖 I'm sorry @gagika, but I was unable to process your request. Please see the logs for more details.

@codecov

codecov Bot commented Apr 26, 2026

Comment thread tests/unit/input_pipeline/olmo_data_grain_test.py Outdated
Comment thread src/maxtext/input_pipeline/olmo_data_grain.py Outdated
Comment thread src/maxtext/input_pipeline/olmo_data_grain.py Outdated
Comment thread src/maxtext/input_pipeline/olmo_data_grain.py
@gagika gagika force-pushed the gagik-olmo-data branch 2 times, most recently from 9de5321 to 9e3ff8f on May 4, 2026 14:22
@github-actions

github-actions Bot commented May 5, 2026

🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions Bot left a comment


## 📋 Review Summary

This Pull Request introduces a high-quality, Grain-based input pipeline for AI2's OLMo numpy datasets. The implementation is robust, well-documented, and includes a particularly clean approach to stateless resumption by deriving the data offset from the model checkpoint step.

🔍 General Feedback

  • Stateless Resume: The initial_step logic in the sampler is an excellent design choice that avoids the complexities of Grain iterator-state serialization.
  • N-gram Filtering: The integration of OLMo-core's repetition filter via a custom transform that masks instances in the loss is both efficient and sharding-friendly.
  • Testing and Validation: The inclusion of unit tests, smoke scripts, and end-to-end resume tests provides great confidence in the stability of the new pipeline.
  • Performance: While the in-memory permutation for shuffling is currently manageable, it's worth monitoring as dataset sizes scale further.
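The repeated-n-gram filtering praised above could work roughly like the following check. A minimal sketch, assuming illustrative names (`has_repeated_ngram`, `max_repeats`); this is not OLMo-core's actual API:

```python
# Hypothetical sketch of a repeated-n-gram check: flag a token sequence
# whose most frequent n-gram occurs more than a threshold number of times.
from collections import Counter

def has_repeated_ngram(tokens, n: int = 3, max_repeats: int = 32) -> bool:
  """True if any n-gram in `tokens` occurs more than `max_repeats` times."""
  if len(tokens) < n:
    return False
  counts = Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))
  return max(counts.values()) > max_repeats
```

As the review notes, a flagged instance would have its loss mask zeroed rather than be dropped, which keeps batch shapes static across hosts (the sharding-friendly property).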

Comment thread src/maxtext/input_pipeline/olmo_data.py
Comment thread src/maxtext/input_pipeline/olmo_data.py
Review context (sampler docstring):

    total_instances: ``index.total_instances`` from the OLMo index.
    seed: Base seed for the shuffle.
    shard_index: Zero-based index of this data-loading host. Typically
        ``jax.process_index()``.

🟡 For very large datasets (e.g., the 724M instance mix mentioned), allocating the full permutation in host memory (~5.8 GB) can be a significant spike, especially if many hosts are doing it simultaneously at an epoch boundary.

While acceptable for the current scope, consider implementing a lazy or on-disk permutation scheme if the dataset size grows further or if host memory becomes a constraint.
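One shape a lazy permutation scheme could take is a Feistel-style index bijection that computes each shuffled index on the fly in O(1) memory, instead of materializing the full ~5.8 GB array. A sketch only, with assumed names; not code from this PR:

```python
# Lazy O(1)-memory shuffle via a keyed Feistel network over index bits,
# with cycle-walking to stay inside [0, n). Hypothetical sketch.
import hashlib

def feistel_permute(i: int, n: int, seed: int, rounds: int = 4) -> int:
  """Bijectively map i in [0, n) to a shuffled index in [0, n)."""
  half = (n.bit_length() + 1) // 2   # half-width of the 2*half-bit domain
  mask = (1 << half) - 1

  def feistel(x: int) -> int:
    left, right = x >> half, x & mask
    for r in range(rounds):
      # Round function: keyed hash of the right half.
      digest = hashlib.blake2b(f"{seed}:{r}:{right}".encode(),
                               digest_size=4).digest()
      f = int.from_bytes(digest, "big") & mask
      left, right = right, left ^ f
    return (left << half) | right

  x = feistel(i)
  while x >= n:    # cycle-walk out-of-range values back into [0, n)
    x = feistel(x)
  return x
```

The trade-off is a few hash evaluations per index instead of one array allocation per epoch, which matters at epoch boundaries when many hosts shuffle simultaneously.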

Collaborator Author


At this scale it's fine; we can follow up with a chunked / Philox-keyed shuffle if we need to go larger.

Collaborator

@RissyRan RissyRan left a comment


Thanks for the change!

Have you had a chance to chat with @aireenmei about this yet? I'm wondering if we could design these features to directly leverage the existing grain_data_processing.py. For example, we could add things like n-gram filtering and pre-tokenized numpy mixes as features there to improve code reusability.

The main benefit I see is reducing maintenance overhead. We currently maintain both tfds and c4_tfds_mlperf, but the latter is rarely used and has some maintenance issues. Since Grain will be heavily used moving forward, it makes sense to build on top of it directly. Let me know your thoughts—happy to discuss!

Comment thread docs/guides/data_input_pipeline/olmo_grain.md Outdated
Comment thread src/maxtext/configs/types.py
Collaborator Author

gagika commented May 6, 2026

> Thanks for the change!
>
> Have you had a chance to chat with @aireenmei about this yet? I'm wondering if we could design these features to directly leverage the existing grain_data_processing.py. For example, we could add things like n-gram filtering and pre-tokenized numpy mixes as features there to improve code reusability.
>
> The main benefit I see is reducing maintenance overhead. We currently maintain both tfds and c4_tfds_mlperf, but the latter is rarely used and has some maintenance issues. Since Grain will be heavily used moving forward, it makes sense to build on top of it directly. Let me know your thoughts—happy to discuss!

I haven't chatted with @aireenmei yet; I will follow up.

The main reason I kept this separate is the stateless-resume contract — record at step k must be a pure function of (seed, shard, k) so we recover the data offset from the checkpoint step alone, with no Grain iterator state in the checkpoint. MapDataset.shuffle().repeat() has no step-offset hook, so reusing it as-is would drop that property.

The general pieces, e.g. the n-gram mask, can probably be lifted into grain_data_processing.py as a follow-up; happy to take that on after this lands, and I'll discuss it with Aireen.
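The resume flow described above (read the latest step from the checkpoint dir, then shift the sampler by that many batches) might be sketched like this. The `step_<N>` directory layout and the names `latest_checkpoint_step` and `per_host_batch` are illustrative assumptions, not the PR's actual code:

```python
# Hypothetical sketch of stateless resume: derive the data offset from the
# latest checkpoint step alone, with no Grain iterator state in the checkpoint.
import os
import re

def latest_checkpoint_step(checkpoint_dir: str) -> int:
  """Highest N among 'step_<N>' entries (assumed checkpoint layout)."""
  steps = [int(m.group(1))
           for name in os.listdir(checkpoint_dir)
           if (m := re.fullmatch(r"step_(\d+)", name))]
  return max(steps, default=0)

def resume_offset(checkpoint_dir: str, per_host_batch: int) -> int:
  """Records already consumed by this host: steps * per-host batch size."""
  return latest_checkpoint_step(checkpoint_dir) * per_host_batch
```

The sampler would then start serving from this offset, which is exactly the "shifts the sampler" step in the PR description.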

@gagika gagika force-pushed the gagik-olmo-data branch from 9e3ff8f to 587a4f3 on May 6, 2026 05:33
