
Commit bf35003

jasont314, nazar-ospanov, zimo0110, and sanjay-adhikesaven committed
fix: preserve rebased PP/EP path compatibility and refresh artifacts
Co-authored-by: Nazar Ospanov <aimogenius@berkeley.edu>
Co-authored-by: Zoir Imomaliev <91550816+zimo0110@users.noreply.github.com>
Co-authored-by: Sanjay Adhikesaven <sanjay.adhikesaven1@gmail.com>
Signed-off-by: Jason Trinh <jasontrinh@berkeley.edu>
1 parent fa19503 commit bf35003

10 files changed

Lines changed: 248 additions & 84 deletions


checkpoints/optimized_training.jsonl

Lines changed: 12 additions & 50 deletions
Large diffs are not rendered by default.

examples/llm_finetune/nemotron/nemotron_nano_v3_pp_ep_squad.yaml

Lines changed: 3 additions & 1 deletion
@@ -110,7 +110,9 @@ validation_dataset:
   _target_: nemo_automodel.components.datasets.llm.squad.make_squad_dataset
   dataset_name: rajpurkar/squad
   split: validation
-  limit_dataset_samples: 64
+  # With dp=2 and local_batch_size=64, keep at least 128 samples so each DP rank
+  # gets a full local batch during validation (avoids PP microbatch shape mismatch).
+  limit_dataset_samples: 128
   seq_length: 1024
   padding: max_length
   truncation: true
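
The comment in this hunk encodes a general rule: validation needs at least dp_size × local_batch_size samples, or one DP rank receives a short batch and PP microbatch shapes stop matching across stages. A minimal sketch of the arithmetic (the function and argument names are illustrative, not from this repo):

    def min_validation_samples(dp_size: int, local_batch_size: int) -> int:
        # Smallest limit_dataset_samples that gives every DP rank a full local batch.
        return dp_size * local_batch_size

    # Config above: dp=2, local_batch_size=64 -> limit_dataset_samples >= 128.
    assert min_validation_samples(2, 64) == 128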

nemo_automodel/_transformers/auto_model.py

Lines changed: 11 additions & 2 deletions
@@ -39,12 +39,21 @@
 from transformers import (  # noqa: E402
     AutoModelForCausalLM,
     AutoModelForImageTextToText,
-    AutoModelForMultimodalLM,
     AutoModelForSequenceClassification,
     AutoModelForTextToWaveform,
     PreTrainedModel,
 )
-from transformers.initialization import no_init_weights  # noqa: E402
+try:  # noqa: E402
+    from transformers import AutoModelForMultimodalLM  # noqa: E402
+except ImportError:  # transformers<4.58
+    # Older transformers releases expose image-text multimodal auto-models
+    # under AutoModelForImageTextToText but not AutoModelForMultimodalLM.
+    AutoModelForMultimodalLM = AutoModelForImageTextToText
+try:  # noqa: E402
+    from transformers.initialization import no_init_weights  # noqa: E402
+except ImportError:  # transformers<4.58
+    from transformers.modeling_utils import no_init_weights  # noqa: E402
+
 
 from transformers.models.auto.auto_factory import _BaseAutoModelClass  # noqa: E402
 from transformers.utils import ContextManagers  # noqa: E402
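
Both guards use the same pattern: import the symbol from its new location, and on ImportError bind a compatible fallback to the same name so call sites stay version-agnostic. A standalone sketch of the idiom (the version boundary shown is the one the diff's own comments name):

    # Version-gated import with a compatibility alias. Downstream code can use
    # AutoModelForMultimodalLM unconditionally on either side of the boundary.
    try:
        from transformers import AutoModelForMultimodalLM
    except ImportError:  # transformers<4.58: no multimodal auto-class yet
        from transformers import AutoModelForImageTextToText as AutoModelForMultimodalLM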

nemo_automodel/_transformers/utils.py

Lines changed: 4 additions & 1 deletion
@@ -117,7 +117,10 @@ def _patch_special_tokens_pattern():
     lack CLS/SEP tokens end up with ``None`` IDs in the sequence, crashing
     ``pad()``.
     """
-    from transformers.tokenization_python import PreTrainedTokenizer
+    try:
+        from transformers.tokenization_python import PreTrainedTokenizer
+    except ModuleNotFoundError:  # transformers<5.x
+        from transformers.tokenization_utils import PreTrainedTokenizer
 
     _orig_init = PreTrainedTokenizer.__init__
 
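
With the right PreTrainedTokenizer located, the surrounding helper monkey-patches its __init__. A self-contained sketch of that wrap-and-replace idiom, using a stand-in class rather than the real tokenizer:

    import functools

    class Tokenizer:  # stand-in for PreTrainedTokenizer in this sketch
        def __init__(self, pad_token=None):
            self.pad_token = pad_token

    _orig_init = Tokenizer.__init__

    @functools.wraps(_orig_init)
    def _patched_init(self, *args, **kwargs):
        _orig_init(self, *args, **kwargs)
        # Post-init fixup: never leave an optional special token as None.
        if self.pad_token is None:
            self.pad_token = "<pad>"

    Tokenizer.__init__ = _patched_init
    assert Tokenizer().pad_token == "<pad>"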

nemo_automodel/components/distributed/pipelining/functional.py

Lines changed: 88 additions & 2 deletions
@@ -482,8 +482,94 @@ def split_model_into_stages(
     pp_rank = pp_mesh.get_local_rank()
     pp_size = pp_mesh.size()
     # Detect model structure
-    has_model_attr = hasattr(model, "model") and getattr(model, "model", None) is not None
-    has_backbone_attr = (not has_model_attr) and hasattr(model, "backbone") and getattr(model, "backbone", None) is not None
+    model_has_model_attr = hasattr(model, "model") and getattr(model, "model", None) is not None
+    model_has_backbone_attr = hasattr(model, "backbone") and getattr(model, "backbone", None) is not None
+
+    def _submodule_exists(module_root: nn.Module, module_fqn: str) -> bool:
+        if not module_fqn:
+            return True
+        try:
+            module_root.get_submodule(module_fqn)
+            return True
+        except Exception:
+            return False
+
+    def _normalize_stage_fqn_aliases(explicit_stages: list[list[str]]) -> list[list[str]]:
+        alias_suffixes = (
+            (".embeddings", ".embed_tokens"),
+            (".embed_tokens", ".embeddings"),
+            (".norm_f", ".norm"),
+            (".norm", ".norm_f"),
+        )
+        rewrites: list[tuple[str, str]] = []
+        normalized_stages: list[list[str]] = []
+        for stage_modules in explicit_stages:
+            normalized_stage: list[str] = []
+            for module_fqn in stage_modules:
+                normalized_fqn = module_fqn
+                if not _submodule_exists(model, normalized_fqn):
+                    for src_suffix, dst_suffix in alias_suffixes:
+                        if normalized_fqn.endswith(src_suffix):
+                            candidate = normalized_fqn[: -len(src_suffix)] + dst_suffix
+                            if _submodule_exists(model, candidate):
+                                rewrites.append((normalized_fqn, candidate))
+                                normalized_fqn = candidate
+                                break
+                normalized_stage.append(normalized_fqn)
+            normalized_stages.append(normalized_stage)
+
+        if rewrites:
+            # De-duplicate while preserving insertion order.
+            unique_rewrites = list(dict.fromkeys(rewrites))
+            logger.info(
+                "Rewriting pipeline stage FQN aliases for current model structure: %s",
+                ", ".join(f"{src}->{dst}" for src, dst in unique_rewrites),
+            )
+
+        return normalized_stages
+
+    # Normalize explicit stage FQNs to the model's actual root attribute.
+    if module_names_per_stage is not None:
+        uses_backbone_prefix = any(
+            module_fqn == "backbone" or module_fqn.startswith("backbone.")
+            for stage_modules in module_names_per_stage
+            for module_fqn in stage_modules
+        )
+        uses_model_prefix = any(
+            module_fqn == "model" or module_fqn.startswith("model.")
+            for stage_modules in module_names_per_stage
+            for module_fqn in stage_modules
+        )
+
+        if uses_backbone_prefix and not model_has_backbone_attr and model_has_model_attr:
+            logger.info("Rewriting pipeline stage FQNs from backbone.* to model.* for current model structure.")
+            module_names_per_stage = [
+                [
+                    ("model." + module_fqn[len("backbone.") :] if module_fqn.startswith("backbone.") else module_fqn)
+                    for module_fqn in stage_modules
+                ]
+                for stage_modules in module_names_per_stage
+            ]
+        elif uses_model_prefix and not model_has_model_attr and model_has_backbone_attr:
+            logger.info("Rewriting pipeline stage FQNs from model.* to backbone.* for current model structure.")
+            module_names_per_stage = [
+                [
+                    ("backbone." + module_fqn[len("model.") :] if module_fqn.startswith("model.") else module_fqn)
+                    for module_fqn in stage_modules
+                ]
+                for stage_modules in module_names_per_stage
+            ]
+
+        module_names_per_stage = _normalize_stage_fqn_aliases(module_names_per_stage)
+
+    prefer_backbone_attr = False
+    if module_names_per_stage is not None:
+        prefer_backbone_attr = any(
+            module_fqn.startswith("backbone.") for stage_modules in module_names_per_stage for module_fqn in stage_modules
+        )
+
+    has_backbone_attr = model_has_backbone_attr and (prefer_backbone_attr or not model_has_model_attr)
+    has_model_attr = model_has_model_attr and not has_backbone_attr
     if has_backbone_attr:
         text_model = model.backbone
         text_model_attr_name = ""
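
The alias table handles models whose stage plans and module trees disagree on leaf names (embeddings vs embed_tokens, norm_f vs norm). The probe is nn.Module.get_submodule, which raises AttributeError when an FQN does not resolve. A standalone sketch of the probe-and-alias step with a toy model:

    import torch.nn as nn

    ALIASES = ((".embeddings", ".embed_tokens"), (".norm_f", ".norm"))

    def resolve_fqn(root: nn.Module, fqn: str) -> str:
        # Return fqn if it resolves; otherwise the first alias that does.
        def exists(name: str) -> bool:
            try:
                root.get_submodule(name)
                return True
            except AttributeError:
                return False
        if exists(fqn):
            return fqn
        for src, dst in ALIASES:
            if fqn.endswith(src) and exists(fqn[: -len(src)] + dst):
                return fqn[: -len(src)] + dst
        return fqn  # leave unresolved names untouched

    class Text(nn.Module):
        def __init__(self):
            super().__init__()
            self.embed_tokens = nn.Embedding(8, 4)

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.model = Text()

    # A stage plan written against `.embeddings` still finds the real module:
    assert resolve_fqn(Model(), "model.embeddings") == "model.embed_tokens"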

nemo_automodel/components/distributed/pipelining/hf_utils.py

Lines changed: 26 additions & 8 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+import inspect
 import types
 from typing import TYPE_CHECKING, Callable, Optional, Union
 
@@ -25,7 +26,7 @@
 logger = logging.getLogger(__name__)
 
 # Constants for identifying text/language modules in multimodal models
-TEXT_MODULE_ATTRS = ("language_model", "text_model", "text_decoder")
+TEXT_MODULE_ATTRS = ("language_model", "text_model", "text_decoder", "backbone")
 MULTIMODAL_SUFFIXES = (
     "vision_tower",
     "visual",
@@ -127,7 +128,7 @@ def pipeline_forward(
        causal_mask = (
            self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
            if hasattr(self, "_update_causal_mask")
-            else attention_mask
+            else None
        )
        mamba_mask = (
            self._update_mamba_mask(attention_mask, cache_position)
@@ -142,12 +143,29 @@
            layer_mask = causal_mask if pp_needs_attention_mask else None
        else:
            layer_mask = None
-        hidden_states = mixer_block(
-            hidden_states,
-            cache_params=past_key_values,
-            cache_position=cache_position if pp_needs_cache_position else None,
-            attention_mask=layer_mask,
-        )
+        # Some NemotronH-like blocks (e.g., local NemotronV3Block) do not accept
+        # cache kwargs, while HF NemotronH blocks do. Use signature-aware dispatch.
+        signature_owner = getattr(mixer_block, "_checkpoint_wrapped_module", mixer_block)
+        supports_cache_params = getattr(signature_owner, "_nemo_pp_supports_cache_params", None)
+        supports_cache_position = getattr(signature_owner, "_nemo_pp_supports_cache_position", None)
+        if supports_cache_params is None or supports_cache_position is None:
+            try:
+                forward_params = inspect.signature(signature_owner.forward).parameters
+                supports_cache_params = "cache_params" in forward_params
+                supports_cache_position = "cache_position" in forward_params
+            except (TypeError, ValueError):
+                supports_cache_params = True
+                supports_cache_position = True
+            setattr(signature_owner, "_nemo_pp_supports_cache_params", supports_cache_params)
+            setattr(signature_owner, "_nemo_pp_supports_cache_position", supports_cache_position)
+
+        block_kwargs = {"attention_mask": layer_mask}
+        if supports_cache_params:
+            block_kwargs["cache_params"] = past_key_values
+        if supports_cache_position:
+            block_kwargs["cache_position"] = cache_position if pp_needs_cache_position else None
+
+        hidden_states = mixer_block(hidden_states, **block_kwargs)
    else:
        # Attention mask handling (compilation-friendly):
        # causal_mask_mapping should be precomputed in data pipeline via default_collater
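
The probe-and-cache dispatch generalizes: inspect the callee's signature once, remember the result, and only pass kwargs it accepts. A standalone sketch of the same technique (the helper name is illustrative, not from this repo):

    import inspect

    def call_with_supported_kwargs(fn, *args, **kwargs):
        # Drop kwargs the callee does not accept; cache the probe on the callee.
        if not hasattr(fn, "_supported_kwargs"):
            try:
                fn._supported_kwargs = frozenset(inspect.signature(fn).parameters)
            except (TypeError, ValueError):
                fn._supported_kwargs = None  # unintrospectable: pass everything
        allowed = fn._supported_kwargs
        if allowed is None:
            return fn(*args, **kwargs)
        return fn(*args, **{k: v for k, v in kwargs.items() if k in allowed})

    def new_block(x, attention_mask=None, cache_params=None):
        return x

    def old_block(x, attention_mask=None):  # predates the cache kwargs
        return x

    call_with_supported_kwargs(new_block, 0, attention_mask=None, cache_params=[])
    call_with_supported_kwargs(old_block, 0, attention_mask=None, cache_params=[])  # cache_params silently dropped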

nemo_automodel/components/models/nemotron_v3/model.py

Lines changed: 10 additions & 5 deletions
@@ -160,12 +160,15 @@ def initialize_weights(self, buffer_device: torch.device | None = None) -> None:
         """
         # Embedding weights: normal initialization
         with buffer_device:
-            nn.init.normal_(self.embed_tokens.weight, mean=0.0, std=self.config.initializer_range)
-            self.norm.reset_parameters()
+            if self.embed_tokens is not None and getattr(self.embed_tokens, "weight", None) is not None:
+                nn.init.normal_(self.embed_tokens.weight, mean=0.0, std=self.config.initializer_range)
+            if self.norm is not None and hasattr(self.norm, "reset_parameters"):
+                self.norm.reset_parameters()
 
         # Initialize all layers via delegation
         for block in self.layers.values():
-            block.init_weights(buffer_device=buffer_device)
+            if block is not None:
+                block.init_weights(buffer_device=buffer_device)
 
 
 class NemotronHForCausalLM(HFCheckpointingMixin, nn.Module, MoEFSDPSyncMixin):
@@ -307,8 +310,10 @@ def initialize_weights(
         """
         buffer_device = buffer_device or torch.device(f"cuda:{torch.cuda.current_device()}")
         with buffer_device:
-            self.model.initialize_weights(buffer_device=buffer_device)
-            nn.init.normal_(self.lm_head.weight, mean=0.0, std=self.config.initializer_range)
+            if self.model is not None:
+                self.model.initialize_weights(buffer_device=buffer_device)
+            if self.lm_head is not None and getattr(self.lm_head, "weight", None) is not None:
+                nn.init.normal_(self.lm_head.weight, mean=0.0, std=self.config.initializer_range)
 
         self.to(dtype)
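
These guards exist because pipeline parallelism prunes modules per stage: typically only the first stage keeps embed_tokens and only the last keeps lm_head, with None left behind elsewhere, so unconditional initialization would raise on most ranks. A toy sketch of the failure mode (a hypothetical two-stage split, not this repo's classes):

    import torch.nn as nn

    class Stage(nn.Module):
        # Toy PP stage: modules not owned by this stage are replaced with None.
        def __init__(self, is_first: bool, is_last: bool):
            super().__init__()
            self.embed_tokens = nn.Embedding(8, 4) if is_first else None
            self.layers = nn.ModuleList([nn.Linear(4, 4)])
            self.lm_head = nn.Linear(4, 8, bias=False) if is_last else None

        def initialize_weights(self):
            if self.embed_tokens is not None:
                nn.init.normal_(self.embed_tokens.weight, std=0.02)
            if self.lm_head is not None:
                nn.init.normal_(self.lm_head.weight, std=0.02)

    # The last stage owns no embeddings; the guard keeps this from raising.
    Stage(is_first=False, is_last=True).initialize_weights()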

nemo_automodel/components/moe/parallelizer.py

Lines changed: 65 additions & 12 deletions
@@ -47,6 +47,39 @@ def _get_cp_stream() -> torch.cuda.Stream:
     return _CP_STREAM
 
 
+def _resolve_text_layer_container(model: nn.Module) -> nn.Module:
+    """Return the module that owns transformer `layers` for block-wise iteration."""
+    if hasattr(model, "layers") and model.layers is not None:
+        return model
+
+    # Try common nested containers first.
+    for attr_name in ("backbone", "model", "language_model", "text_model", "text_decoder"):
+        if hasattr(model, attr_name):
+            nested = getattr(model, attr_name)
+            if nested is None:
+                continue
+            if hasattr(nested, "layers") and nested.layers is not None:
+                return nested
+            nested_text = get_text_module(nested)
+            if hasattr(nested_text, "layers") and nested_text.layers is not None:
+                return nested_text
+
+    # Fallback: search any nested submodule exposing `layers`.
+    for _, submod in model.named_modules():
+        if submod is model:
+            continue
+        if hasattr(submod, "layers") and submod.layers is not None:
+            return submod
+
+    child_names = list(model._modules.keys()) if hasattr(model, "_modules") else []
+    has_backbone = hasattr(model, "backbone") and getattr(model, "backbone") is not None
+    has_model_attr = hasattr(model, "model") and getattr(model, "model") is not None
+    raise AttributeError(
+        "Could not find a module with `layers` under "
+        f"{type(model).__name__} (children={child_names[:24]}, has_backbone={has_backbone}, has_model={has_model_attr})"
+    )
+
+
 class ExpertParallel(ParallelStyle):
     """
     ExpertParallel class is used to shard the MoE parameters on the EP mesh.
@@ -83,8 +116,9 @@ def apply_ep(model: nn.Module, ep_mesh: DeviceMesh, moe_mesh: DeviceMesh | None
     _model = model
     # Prefer nested text modules when present
     _model = get_text_module(_model)
+    _layer_container = _resolve_text_layer_container(_model)
 
-    for _, block in _model.layers.named_children():
+    for _, block in _layer_container.layers.named_children():
         moe_module = block.moe if hasattr(block, "moe") else block.mlp
         if isinstance(moe_module, MoE):
             # GroupedExpertsTEGroupedLinear uses TE's GroupedLinear which creates
@@ -145,15 +179,17 @@ def selective_checkpointing_context_fn():
         _model = model.model
     else:
         _model = model
-    for layer_id, block in _model.layers.named_children():
+    _model = get_text_module(_model)
+    _layer_container = _resolve_text_layer_container(_model)
+    for layer_id, block in _layer_container.layers.named_children():
         if ignore_router:
             block = ptd_checkpoint_wrapper(
                 block, preserve_rng_state=True, context_fn=selective_checkpointing_context_fn
             )
         else:
             block = ptd_checkpoint_wrapper(block, preserve_rng_state=True)
 
-        _model.layers.register_module(layer_id, block)
+        _layer_container.layers.register_module(layer_id, block)
 
 
 def apply_fsdp(
@@ -193,8 +229,9 @@ def apply_fsdp(
     _model = model
     # handle VLM
     _model = get_text_module(_model)
+    _layer_container = _resolve_text_layer_container(_model)
 
-    for _, block in _model.layers.named_children():
+    for _, block in _layer_container.layers.named_children():
         moe_module = block.moe if hasattr(block, "moe") else block.mlp
         if isinstance(moe_module, MoE) and ep_shard_enabled:
             # Apply FSDP on dim=1 for grouped experts since we may have more
@@ -217,8 +254,8 @@ def apply_fsdp(
 
         fully_shard_default(block, ignored_params=ignored_params)
 
-    if hasattr(_model, "embed_tokens") and _model.embed_tokens is not None:
-        fully_shard_default(_model.embed_tokens)
+    if hasattr(_layer_container, "embed_tokens") and _layer_container.embed_tokens is not None:
+        fully_shard_default(_layer_container.embed_tokens)
 
     lm_head = getattr(_model, "lm_head", None) or getattr(model, "lm_head", None)
     if lm_head is not None:
@@ -252,7 +289,7 @@ def apply_fsdp(
     else:
         logging.info("Skipping FSDP wrap for frozen visual tower")
 
-    fully_shard_default(_model)
+    fully_shard_default(_layer_container)
 
     # If model has a nested structure (outer model wrapping inner _model), wrap the outer model if requested
     if wrap_outer_model and model is not _model:
@@ -266,8 +303,10 @@ def apply_cp(model: torch.nn.Module, cp_mesh: DeviceMesh, cp_comm_type: str = "p
         _model = model.model
     else:
         _model = model
+    _model = get_text_module(_model)
+    _layer_container = _resolve_text_layer_container(_model)
 
-    for _, block in _model.layers.named_children():
+    for _, block in _layer_container.layers.named_children():
         attn_module = block.self_attn.attn_module
         assert isinstance(attn_module, DotProductAttention), (
             "Context parallelism is only supported for TransformerEngine's DotProductAttention"
@@ -307,10 +346,24 @@ def parallelize_model(
 
     ep_enabled = ep_axis_name is not None and moe_mesh is not None and moe_mesh[ep_axis_name].size() > 1
     if ep_enabled:
-        assert model.model.moe_config.n_routed_experts % moe_mesh[ep_axis_name].size() == 0, (
-            f"n_routed_experts {model.model.moe_config.n_routed_experts} must be divisible by "
-            f"expert_parallel_degree {moe_mesh[ep_axis_name].size()}"
-        )
+        _model = model.model if hasattr(model, "model") and model.model is not None else model
+        _model = get_text_module(_model)
+        n_routed_experts = None
+        if hasattr(_model, "moe_config") and _model.moe_config is not None:
+            n_routed_experts = getattr(_model.moe_config, "n_routed_experts", None)
+        if n_routed_experts is None and hasattr(_model, "config"):
+            for attr in ("n_routed_experts", "moe_num_experts", "num_experts"):
+                if hasattr(_model.config, attr):
+                    n_routed_experts = getattr(_model.config, attr)
+                    break
+
+        if n_routed_experts is not None:
+            assert n_routed_experts % moe_mesh[ep_axis_name].size() == 0, (
+                f"n_routed_experts {n_routed_experts} must be divisible by "
+                f"expert_parallel_degree {moe_mesh[ep_axis_name].size()}"
+            )
+        else:
+            logger.warning("Could not infer n_routed_experts; skipping EP divisibility assertion.")
 
         apply_ep(model, moe_mesh[ep_axis_name], moe_mesh=moe_mesh)
 
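
Most of this file now funnels through _resolve_text_layer_container, which makes the EP/FSDP/CP loops agnostic to whether blocks live at model.layers (HF decoder style) or model.backbone.layers (Mamba/NemotronH style). A compact standalone sketch of that search order with toy modules (the helper name here is illustrative):

    import torch.nn as nn

    def find_layer_container(model: nn.Module) -> nn.Module:
        # Check the root, then common nested roots, then any submodule.
        candidates = [model] + [
            getattr(model, name) for name in ("backbone", "model", "language_model")
            if getattr(model, name, None) is not None
        ]
        for cand in candidates:
            if getattr(cand, "layers", None) is not None:
                return cand
        for _, sub in model.named_modules():
            if sub is not model and getattr(sub, "layers", None) is not None:
                return sub
        raise AttributeError(f"No `layers` container under {type(model).__name__}")

    class Backbone(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList([nn.Linear(4, 4)])

    class MambaLM(nn.Module):
        def __init__(self):
            super().__init__()
            self.backbone = Backbone()

    m = MambaLM()
    assert find_layer_container(m) is m.backbone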
