
NNX migration prep (2/N): NNX utils and sharding utilities#3470

Draft
ecnal-cienet wants to merge 2 commits into main from feat/migrate-nnx-utils

Conversation


@ecnal-cienet ecnal-cienet commented Mar 20, 2026

Description

Note: This PR is part of a series of NNX migration PRs. Pure NNX training is not yet implemented; all NNX code paths currently raise NotImplementedError. This PR only introduces the structural scaffolding needed for subsequent patches to plug in NNX logic without modifying shared infrastructure.

  • NNX sharding utilities (maxtext_utils_nnx.py) — Functions to manipulate NNX model shardings using abstract model state: get_named_sharding_nnx, set_named_sharding_nnx, get_partition_spec_nnx, and memory movement helpers (move_memory_to_host / move_memory_to_device).
  • get_abstract_state NNX path — Added get_abstract_state_nnx to maxtext_utils.py, which uses nnx.get_abstract_model to return a flat nnx.State (rather than a full TrainStateNNX), and updated get_abstract_state to dispatch to it when pure_nnx=True.
  • maxtext_utils.get_mesh_from_config() — Extracted mesh creation into a standalone function with unit tests.
  • Unit tests — Added tests/unit/maxtext_utils_nnx_test.py and extended tests/unit/maxtext_utils_test.py to cover the new mesh and sharding utilities.
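The `pure_nnx` dispatch described above could be sketched roughly as follows. This is an illustrative simplification, not the actual MaxText code: the function signatures and stub bodies are assumptions, and only the `pure_nnx` flag, the `get_abstract_state` / `get_abstract_state_nnx` names, and the NotImplementedError behavior come from this PR's description.

```python
# Hypothetical sketch of the pure_nnx dispatch described above.
# Bodies are placeholders; in the PR, get_abstract_state_nnx would use
# the abstract-model utilities to return a flat nnx.State.

def get_abstract_state_nnx(model_fn, config):
  # NNX path: structural scaffolding only in this PR.
  raise NotImplementedError("Pure NNX training is not yet implemented.")

def get_abstract_state_linen(model_fn, config):
  # Existing Linen path (stubbed here for illustration).
  return {"path": "linen"}

def get_abstract_state(model_fn, config, pure_nnx=False):
  """Dispatch to the NNX path when pure_nnx=True, else the Linen path."""
  if pure_nnx:
    return get_abstract_state_nnx(model_fn, config)
  return get_abstract_state_linen(model_fn, config)
```

With `pure_nnx=False` the established Linen behavior is untouched, which is what lets later patches plug in NNX logic without modifying shared infrastructure.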

Note on Flax deprecation warnings:
Flax v0.12 emits DeprecationWarning for .value access and VariableState. These are intentionally left unaddressed because post-training currently requires Flax v0.11 compatibility.

Tests

pytest tests/unit/maxtext_utils_nnx_test.py tests/unit/maxtext_utils_test.py -v 

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

- pure_nnx: a flag to choose the pure NNX logic when NNX and Linen models
  co-exist.
- init_state_fn: a function to initialize the model state for training.
  It will be set to a different function for NNX and for Linen.
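The `init_state_fn` indirection described in the commit message can be illustrated with a minimal sketch. All names and bodies below are hypothetical stand-ins; the real MaxText initializers build actual model/optimizer state:

```python
# Illustrative sketch: the trainer takes an injected state-initialization
# callable instead of hard-coding Linen logic, so the NNX path can be
# swapped in later. Names and bodies are assumptions for illustration.

def linen_init_state_fn(rng_seed):
  # Stand-in for the existing Linen state initializer.
  return {"framework": "linen", "seed": rng_seed}

def nnx_init_state_fn(rng_seed):
  # NNX path is scaffolding only at this stage of the migration.
  raise NotImplementedError("Pure NNX training is not yet implemented.")

def setup_training(init_state_fn, rng_seed=0):
  """Initialize training state via whichever init_state_fn was injected."""
  return init_state_fn(rng_seed)
```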
@ecnal-cienet ecnal-cienet changed the title Feat/migrate nnx utils NNX migration prep (1/N): Migrate MaxText Utils Mar 20, 2026
@ecnal-cienet ecnal-cienet changed the title NNX migration prep (1/N): Migrate MaxText Utils NNX migration prep (2/N): Migrate MaxText Utils Mar 20, 2026
@ecnal-cienet ecnal-cienet force-pushed the feat/migrate-nnx-utils branch from 7669e8e to 4fc37b6 on March 20, 2026 at 21:54
- Add utils to manipulate the NNX shardings with abstract state of a
  model
  - also add unit tests for the utils
- Extract mesh creation function to maxtext_utils.get_mesh_from_config()
  - also add unit tests for this func

Note:
flax v0.12 has DeprecationWarning in multiple places:
  - DeprecationWarning: '.value' access is now deprecated. Use
    variable.get_value() or variable[...] (for [Array]).
  - DeprecationWarning: 'VariableState' was removed, this is just
    an alias to 'Variable'. Please use 'Variable' directly instead.
But since the code needs to work with post-training, which currently
requires flax v0.11, we didn't change code for these warnings.
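In spirit, the extracted `get_mesh_from_config()` might look like the sketch below. The signature and body are assumptions for illustration (the real helper reads axis names and sizes from the MaxText config object), but `jax.sharding.Mesh` construction itself is standard JAX:

```python
import numpy as np
import jax
from jax.sharding import Mesh

def get_mesh_from_config(axis_names, axis_shape):
  """Build a jax.sharding.Mesh from configured axis names and sizes.

  Hypothetical sketch of the helper named in the commit message;
  the product of axis_shape must equal the number of devices.
  """
  devices = np.array(jax.devices()).reshape(axis_shape)
  return Mesh(devices, axis_names)

# Example: put all local devices on a single "data" axis.
mesh = get_mesh_from_config(("data",), (len(jax.devices()),))
```

Factoring this out into a standalone function is what makes it unit-testable on CPU, as the added tests in tests/unit/maxtext_utils_test.py exercise.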
@ecnal-cienet ecnal-cienet force-pushed the feat/migrate-nnx-utils branch from 4fc37b6 to 722386f on March 21, 2026 at 00:57
@ecnal-cienet ecnal-cienet changed the title NNX migration prep (2/N): Migrate MaxText Utils NNX migration prep (2/N): NNX utils and sharding utilities Mar 21, 2026