feat: make mesh accept meshcontext #2266
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
/ok to test 3dcadfb
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
/ok to test a8b2df6
/ok to test a4876ae
/ok to test d836169
jgerh
left a comment
Completed tech pubs review of docs/guides/gradient-checkpointing.md. No changes needed. LGTM.
/ok to test 010ddc8
Two leftover references to the old setup_distributed/dist_setup API were missed when the recipe was migrated to create_mesh_context_from_config:
- nemo_automodel/recipes/vlm/finetune.py:794 still read self.dist_setup.cp_size, which would AttributeError on any PP+CP VLM run.
- tests/unit_tests/recipes/test_finetune_vlm_cp_wiring.py monkeypatched the stale symbol "setup_distributed", causing three parametrizations of test_setup_skips_pp_media_prechunk_when_cp_preembeds_vlm_inputs to fail during pytest setup with AttributeError.
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
/ok to test e85de37
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
/ok to test fd99484
/ok to test 20415cf
| resolve_trust_remote_code, | ||
| ) | ||
| from nemo_automodel.recipes._dist_setup import setup_distributed | ||
| from nemo_automodel.recipes._dist_utils import create_distributed_setup_from_config |
should this be a private module if we are importing it in a public facing file?
it's not great, agreed, but we'll be refactoring recipes so for now I'm ok with it.
/claude review
| from nemo_automodel.components.distributed import FSDP2Config, create_mesh_context, initialize_distributed | ||
| from nemo_automodel._transformers.infrastructure import instantiate_infrastructure | ||
|
|
||
| # 1. Create strategy config | ||
| dist_env = initialize_distributed("nccl") | ||
| config = FSDP2Config(sequence_parallel=True, activation_checkpointing=True) | ||
|
|
||
| # 2. Create device mesh | ||
| device_mesh, moe_mesh = create_device_mesh( | ||
| config, tp_size=2, pp_size=1, cp_size=1, ep_size=1, world_size=8, | ||
| mesh = create_mesh_context( | ||
| config, tp_size=2, pp_size=1, cp_size=1, ep_size=1, world_size=dist_env.world_size, | ||
| ) | ||
|
|
||
| # 3. Build MeshContext | ||
| mesh = MeshContext.from_meshes( | ||
| device_mesh, moe_mesh, strategy_config=config, activation_checkpointing=True, | ||
| ) | ||
|
|
||
| # 4. Instantiate infrastructure | ||
| # 3. Instantiate infrastructure | ||
| model_wrapper, autopipeline, parallelize_fn, qat_quantizer = instantiate_infrastructure( | ||
| distributed_config=config, mesh=mesh, | ||
| ) |
Bug: create_mesh_context doesn't exist — not defined or exported anywhere after this PR. The from_pretrained example below (line 398) also uses the removed distributed_config= kwarg.
Both examples should use the new DistributedSetup API:
from nemo_automodel.components.distributed import (
    DistributedSetup, FSDP2Config, ParallelismSizes, initialize_distributed,
)
dist_env = initialize_distributed("nccl")
distributed_setup = DistributedSetup.build(
    strategy=FSDP2Config(sequence_parallel=True),
    parallelism_sizes=ParallelismSizes(tp_size=2),
    activation_checkpointing=True,
    world_size=dist_env.world_size,
)
model = NeMoAutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    distributed_setup=distributed_setup,
)
| components/distributed/device_mesh.py | ||
| create_device_mesh() -- routes to FSDP2/MegatronFSDP/DDP mesh creation | ||
| create_mesh_context() -- builds MeshContext from strategy + parallelism | ||
| _create_device_meshes() -- routes to FSDP2/MegatronFSDP/DDP raw mesh creation | ||
| _create_fsdp2_device_mesh() -- shape (pp, dp_replicate, dp_shard, cp, tp) + flattened submeshes |
Bug: device_mesh.py was deleted in this PR. create_mesh_context() doesn't exist. These helpers now live in mesh_utils.py, and MeshContext.build() is the public entry point in mesh.py.
| components/distributed/device_mesh.py | |
| create_device_mesh() -- routes to FSDP2/MegatronFSDP/DDP mesh creation | |
| create_mesh_context() -- builds MeshContext from strategy + parallelism | |
| _create_device_meshes() -- routes to FSDP2/MegatronFSDP/DDP raw mesh creation | |
| _create_fsdp2_device_mesh() -- shape (pp, dp_replicate, dp_shard, cp, tp) + flattened submeshes | |
| components/distributed/mesh_utils.py | |
| _create_device_meshes() -- routes to FSDP2/MegatronFSDP/DDP raw mesh creation | |
| _create_fsdp2_device_mesh() -- shape (pp, dp_replicate, dp_shard, cp, tp) + flattened submeshes | |
| _create_megatron_fsdp_device_mesh() -- shape (dp, cp, tp) |
Light Review Summary
The API refactoring looks clean — collapsing the scattered distributed kwargs into a single DistributedSetup is a clear win for the public API surface, and the backward-compatible device_mesh= path is well handled.
Documentation issues in SKILL.md (4 inline comments): The skill file has several references to functions and files that no longer exist after this PR (create_mesh_context_from_config, create_mesh_context, device_mesh.py). The programmatic API code examples also use removed kwargs (distributed_config=, tp_size=). Since this file is used as authoritative context by AI agents, these should be fixed before merge.
Latent behavior change (1 inline comment): FSDP2Config(mp_policy=None) now keeps None instead of defaulting to the standard MixedPrecisionPolicy, since __post_init__ was replaced by default_factory. Current callers always provide mp_policy explicitly so this is latent, but the _create_parallel_manager code has a trap where args.get("mp_policy", None) would trigger this.
Test coverage for the new API looks good.
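The latent mp_policy change noted in the summary can be illustrated with a minimal, self-contained sketch. The classes below (Policy, ConfigWithPostInit, ConfigWithFactory) are hypothetical stand-ins written for this illustration, not the actual FSDP2Config or MixedPrecisionPolicy definitions; they only show how the two dataclass patterns treat an explicit None differently.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass(frozen=True)
class Policy:
    """Hypothetical stand-in for a mixed-precision policy object."""
    param_dtype: str = "bfloat16"


@dataclass
class ConfigWithPostInit:
    """Old pattern (assumed): None is treated as 'use the default policy'."""
    mp_policy: Optional[Policy] = None

    def __post_init__(self):
        if self.mp_policy is None:
            self.mp_policy = Policy()


@dataclass
class ConfigWithFactory:
    """New pattern (assumed): the default applies only when the argument is omitted."""
    mp_policy: Optional[Policy] = field(default_factory=Policy)


omitted = ConfigWithFactory()                      # default_factory runs
explicit_none = ConfigWithFactory(mp_policy=None)  # None is kept as-is
legacy_none = ConfigWithPostInit(mp_policy=None)   # None is replaced by the default

# Mirrors the args.get("mp_policy", None) trap: a missing key turns into an explicit None.
args = {}
from_args = ConfigWithFactory(mp_policy=args.get("mp_policy", None))
```

Under these assumptions, a caller that forwards a missing key as None ends up with no policy at all in the default_factory version, whereas the __post_init__ version would have substituted the default.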
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
Co-authored-by: Adil <47084919+adil-a@users.noreply.github.com>
What does this PR do?
Refactors the distributed public API so topology and distributed policies are layered explicitly.
The main user-facing object is now DistributedSetup, which owns:
- mesh_context: runtime topology and DeviceMesh / MoE mesh access
- strategy_config: FSDP2 / Megatron FSDP / DDP strategy config
- pipeline_config: pipeline-parallel runtime config
- moe_parallel_config: MoE parallelization config
- activation_checkpointing: activation-checkpointing policy
MeshContext is narrowed to topology only. It no longer owns activation checkpointing or higher-level training policy.
Changelog
- DistributedSetup.build(...) as the component-layer entry point for constructing distributed setup from strategy, parallelism sizes, pipeline config, MoE config, and activation checkpointing.
- device_mesh compatibility in NeMoAutoModel*.from_pretrained by wrapping raw HF-style meshes into an internal topology-only DistributedSetup.
- device_mesh.py and move raw mesh construction/access helpers into mesh_utils.py.
- ParallelismSizes for dp/tp/pp/cp/ep sizing intent.
- MoEParallelizerConfig into distributed config, since it is part of distributed setup rather than model-only MoE config.
- DistributedSetup from YAML/programmatic config and fan out the derived runtime attributes consistently.
- device_mesh compatibility.
API shape
Python usage:
HF-compatible raw mesh usage is still allowed:
Future work
Currently FSDP2Config is not pure FSDP, but also includes options for TP/SP; those will be refactored in a follow-up PR to separate concerns.
Before your PR is "Ready for review"
Pre checks:
Validation:
- python -m ruff check ...
- python -m ruff format --check ...
- python -m py_compile ...
- pytest tests/unit_tests/recipes/test_dist_utils.py -q
Note: local full recipe test collection is blocked in my environment by an existing mlflow / cachetools.func.cached import mismatch. CI should be used for full CPU coverage.
Additional Information
This keeps the TorchTitan-like layering:
- ParallelismSizes
- MeshContext
- DistributedSetup
- create_distributed_setup_from_config
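As a rough illustration of this layering, here is a toy sketch using plain dataclasses. These are stand-ins that reuse the names from the description; they are not the nemo_automodel classes, and the field set, the build signatures, and the rule for inferring dp_size (world_size divided by tp * pp * cp) are all assumptions made for the example.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass(frozen=True)
class ParallelismSizes:
    """Sizing intent only. dp_size=None means 'infer from world size' (assumption)."""
    dp_size: Optional[int] = None
    tp_size: int = 1
    pp_size: int = 1
    cp_size: int = 1
    ep_size: int = 1


@dataclass(frozen=True)
class MeshContext:
    """Topology only: resolved sizes, no training policy."""
    dp_size: int
    tp_size: int
    pp_size: int
    cp_size: int
    ep_size: int

    @classmethod
    def build(cls, sizes: ParallelismSizes, world_size: int) -> "MeshContext":
        # Assumed rule for this toy: dp * tp * pp * cp must cover all ranks.
        model_parallel = sizes.tp_size * sizes.pp_size * sizes.cp_size
        if world_size % model_parallel != 0:
            raise ValueError("world_size must be divisible by tp * pp * cp")
        dp = sizes.dp_size if sizes.dp_size is not None else world_size // model_parallel
        if dp * model_parallel != world_size:
            raise ValueError("dp * tp * pp * cp must equal world_size")
        return cls(dp, sizes.tp_size, sizes.pp_size, sizes.cp_size, sizes.ep_size)


@dataclass(frozen=True)
class DistributedSetup:
    """Topology plus policy, bundled into one object for callers."""
    mesh_context: MeshContext
    strategy_config: Any = None
    activation_checkpointing: bool = False

    @classmethod
    def build(cls, strategy, parallelism_sizes, world_size, activation_checkpointing=False):
        mesh_context = MeshContext.build(parallelism_sizes, world_size)
        return cls(mesh_context, strategy, activation_checkpointing)


setup = DistributedSetup.build(
    strategy={"name": "fsdp2"},  # placeholder for a real strategy config object
    parallelism_sizes=ParallelismSizes(tp_size=2),
    world_size=8,
    activation_checkpointing=True,
)
```

The point of the sketch is the separation of concerns: sizing intent is resolved into topology once, and policy such as activation checkpointing lives on the outer object rather than on the mesh context.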