Skip to content

feat: make mesh accept meshcontext#2266

Open
adil-a wants to merge 18 commits into
mainfrom
akoumpa/refactor_auto_class_public_api
Open

feat: make mesh accept meshcontext#2266
adil-a wants to merge 18 commits into
mainfrom
akoumpa/refactor_auto_class_public_api

Conversation

@adil-a
Copy link
Copy Markdown
Collaborator

@adil-a adil-a commented May 18, 2026

What does this PR do?

Refactors the distributed public API so topology and distributed policies are layered explicitly.

The main user-facing object is now DistributedSetup, which owns:

  • mesh_context: runtime topology and DeviceMesh / MoE mesh access
  • strategy_config: FSDP2 / Megatron FSDP / DDP strategy config
  • pipeline_config: pipeline-parallel runtime config
  • moe_parallel_config: MoE parallelization config
  • activation_checkpointing: activation-checkpointing policy

MeshContext is narrowed to topology only. It no longer owns activation checkpointing or higher-level training policy.

Changelog

  • Add DistributedSetup.build(...) as the component-layer entry point for constructing distributed setup from strategy, parallelism sizes, pipeline config, MoE config, and activation checkpointing.
  • Keep device_mesh compatibility in NeMoAutoModel*.from_pretrained by wrapping raw HF-style meshes into an internal topology-only DistributedSetup.
  • Remove legacy device_mesh.py and move raw mesh construction/access helpers into mesh_utils.py.
  • Introduce ParallelismSizes for dp/tp/pp/cp/ep sizing intent.
  • Move MoEParallelizerConfig into distributed config, since it is part of distributed setup rather than model-only MoE config.
  • Update recipes to build a single DistributedSetup from YAML/programmatic config and fan out the derived runtime attributes consistently.
  • Update diffusion, LLM, VLM, KD, retrieval, and sequence-classification callsites to use the new setup layering.
  • Update tests for the new layering and raw device_mesh compatibility.

API shape

Python usage:

from nemo_automodel.components.distributed import DistributedSetup, FSDP2Config, ParallelismSizes
from nemo_automodel import NeMoAutoModelForCausalLM

distributed_setup = DistributedSetup.build(
    strategy=FSDP2Config(sequence_parallel=True),
    parallelism_sizes=ParallelismSizes(tp_size=2, ep_size=8),
)

model = NeMoAutoModelForCausalLM.from_pretrained(
    "model/name",
    distributed_setup=distributed_setup,
)

HF-compatible raw mesh usage is still allowed:

model = NeMoAutoModelForCausalLM.from_pretrained(
    "model/name",
    device_mesh=device_mesh,
)

Future work

Currently FSDP2Config is not pure FSDP, but also includes options for TP/SP; those will be refactored in a follow-up PR to separate concerns.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?

Validation:

  • python -m ruff check ...
  • python -m ruff format --check ...
  • python -m py_compile ...
  • pytest tests/unit_tests/recipes/test_dist_utils.py -q

Note: local full recipe test collection is blocked in my environment by an existing mlflow / cachetools.func.cached import mismatch. CI should be used for full CPU coverage.

Additional Information

This keeps the TorchTitan-like layering:

  • sizes: ParallelismSizes
  • topology: MeshContext
  • distributed policies and topology bundle: DistributedSetup
  • recipe/YAML adapter: create_distributed_setup_from_config

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 18, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@adil-a
Copy link
Copy Markdown
Collaborator Author

adil-a commented May 18, 2026

/ok to test 3dcadfb

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
@akoumpa
Copy link
Copy Markdown
Contributor

akoumpa commented May 18, 2026

/ok to test a8b2df6

@akoumpa
Copy link
Copy Markdown
Contributor

akoumpa commented May 19, 2026

/ok to test a4876ae

@akoumpa
Copy link
Copy Markdown
Contributor

akoumpa commented May 20, 2026

/ok to test d836169

Copy link
Copy Markdown
Contributor

@jgerh jgerh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Completed tech pubs review of docs/guides/gradient-checkpointing.md. No changes needed. LGTM.

@akoumpa
Copy link
Copy Markdown
Contributor

akoumpa commented May 26, 2026

/ok to test 010ddc8

Two leftover references to the old setup_distributed/dist_setup API were
missed when the recipe was migrated to create_mesh_context_from_config:

- nemo_automodel/recipes/vlm/finetune.py:794 still read
  self.dist_setup.cp_size, which would AttributeError on any PP+CP VLM run.
- tests/unit_tests/recipes/test_finetune_vlm_cp_wiring.py monkeypatched
  the stale symbol "setup_distributed", causing three parametrizations of
  test_setup_skips_pp_media_prechunk_when_cp_preembeds_vlm_inputs to fail
  during pytest setup with AttributeError.

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
@akoumpa
Copy link
Copy Markdown
Contributor

akoumpa commented May 26, 2026

/ok to test e85de37

akoumpa added 4 commits May 26, 2026 15:17
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
@akoumpa
Copy link
Copy Markdown
Contributor

akoumpa commented May 27, 2026

/ok to test fd99484

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
@akoumpa
Copy link
Copy Markdown
Contributor

akoumpa commented May 27, 2026

/ok to test 20415cf

Comment thread nemo_automodel/components/distributed/config.py Outdated
resolve_trust_remote_code,
)
from nemo_automodel.recipes._dist_setup import setup_distributed
from nemo_automodel.recipes._dist_utils import create_distributed_setup_from_config
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be a private module if we are importing it in a public facing file?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's not great, agreed, but we'll be refactoring recipes so for now I'm ok with it.

@adil-a
Copy link
Copy Markdown
Collaborator Author

adil-a commented May 27, 2026

/claude review

Comment thread skills/distributed-training/SKILL.md Outdated
Comment on lines +377 to 390
from nemo_automodel.components.distributed import FSDP2Config, create_mesh_context, initialize_distributed
from nemo_automodel._transformers.infrastructure import instantiate_infrastructure

# 1. Create strategy config
dist_env = initialize_distributed("nccl")
config = FSDP2Config(sequence_parallel=True, activation_checkpointing=True)

# 2. Create device mesh
device_mesh, moe_mesh = create_device_mesh(
config, tp_size=2, pp_size=1, cp_size=1, ep_size=1, world_size=8,
mesh = create_mesh_context(
config, tp_size=2, pp_size=1, cp_size=1, ep_size=1, world_size=dist_env.world_size,
)

# 3. Build MeshContext
mesh = MeshContext.from_meshes(
device_mesh, moe_mesh, strategy_config=config, activation_checkpointing=True,
)

# 4. Instantiate infrastructure
# 3. Instantiate infrastructure
model_wrapper, autopipeline, parallelize_fn, qat_quantizer = instantiate_infrastructure(
distributed_config=config, mesh=mesh,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: create_mesh_context doesn't exist — not defined or exported anywhere after this PR. The from_pretrained example below (line 398) also uses the removed distributed_config= kwarg.

Both examples should use the new DistributedSetup API:

from nemo_automodel.components.distributed import (
    DistributedSetup, FSDP2Config, ParallelismSizes, initialize_distributed,
)

dist_env = initialize_distributed("nccl")
distributed_setup = DistributedSetup.build(
    strategy=FSDP2Config(sequence_parallel=True),
    parallelism_sizes=ParallelismSizes(tp_size=2),
    activation_checkpointing=True,
    world_size=dist_env.world_size,
)

model = NeMoAutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    distributed_setup=distributed_setup,
)

Comment on lines 428 to 431
components/distributed/device_mesh.py
create_device_mesh() -- routes to FSDP2/MegatronFSDP/DDP mesh creation
create_mesh_context() -- builds MeshContext from strategy + parallelism
_create_device_meshes() -- routes to FSDP2/MegatronFSDP/DDP raw mesh creation
_create_fsdp2_device_mesh() -- shape (pp, dp_replicate, dp_shard, cp, tp) + flattened submeshes
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: device_mesh.py was deleted in this PR. create_mesh_context() doesn't exist. These helpers now live in mesh_utils.py, and MeshContext.build() is the public entry point in mesh.py.

Suggested change
components/distributed/device_mesh.py
create_device_mesh() -- routes to FSDP2/MegatronFSDP/DDP mesh creation
create_mesh_context() -- builds MeshContext from strategy + parallelism
_create_device_meshes() -- routes to FSDP2/MegatronFSDP/DDP raw mesh creation
_create_fsdp2_device_mesh() -- shape (pp, dp_replicate, dp_shard, cp, tp) + flattened submeshes
components/distributed/mesh_utils.py
_create_device_meshes() -- routes to FSDP2/MegatronFSDP/DDP raw mesh creation
_create_fsdp2_device_mesh() -- shape (pp, dp_replicate, dp_shard, cp, tp) + flattened submeshes
_create_megatron_fsdp_device_mesh() -- shape (dp, cp, tp)

Comment thread skills/distributed-training/SKILL.md Outdated
Comment thread nemo_automodel/_diffusers/auto_diffusion_pipeline.py Outdated
Copy link
Copy Markdown
Contributor

@claude claude Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Light Review Summary

The API refactoring looks clean — collapsing the scattered distributed kwargs into a single DistributedSetup is a clear win for the public API surface, and the backward-compatible device_mesh= path is well handled.

Documentation issues in SKILL.md (4 inline comments): The skill file has several references to functions and files that no longer exist after this PR (create_mesh_context_from_config, create_mesh_context, device_mesh.py). The programmatic API code examples also use removed kwargs (distributed_config=, tp_size=). Since this file is used as authoritative context by AI agents, these should be fixed before merge.

Latent behavior change (1 inline comment): FSDP2Config(mp_policy=None) now keeps None instead of defaulting to the standard MixedPrecisionPolicy, since __post_init__ was replaced by default_factory. Current callers always provide mp_policy explicitly so this is latent, but the _create_parallel_manager code has a trap where args.get("mp_policy", None) would trigger this.

Test coverage for the new API looks good.

akoumpa and others added 4 commits May 27, 2026 14:29
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
Co-authored-by: Adil <47084919+adil-a@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants