Skip to content

PhysicsNeMo PEFT - LoRA#1691

Open
mnabian wants to merge 7 commits into
NVIDIA:mainfrom
mnabian:peft
Open

PhysicsNeMo PEFT - LoRA#1691
mnabian wants to merge 7 commits into
NVIDIA:mainfrom
mnabian:peft

Conversation

@mnabian
Copy link
Copy Markdown
Collaborator

@mnabian mnabian commented Jun 1, 2026

PhysicsNeMo Pull Request

Description

Adds a native, self-contained Low-Rank Adaptation (LoRA) subpackage for
parameter-efficient fine-tuning of PhysicsNeMo models, shipped under
physicsnemo.experimental.peft, plus an end-to-end example and a user-guide page.

LoRA freezes a pretrained model and trains only small low-rank adapter matrices
injected beside selected layers. This adapts a model to a new dataset at a
fraction of the cost of full fine-tuning, produces a tiny adapter checkpoint, lowers memory (frozen layers drop saved
activations), and reduces overfitting/forgetting in the small-data regime typical
of SciML.

Motivation

  • Small-data domain adaptation (the regime ISV end-users actually live in):
    adapt a foundation checkpoint to a new vehicle class / chemistry / operating
    condition from a handful of samples.
  • Deployment story: one frozen base + N swappable adapters at serve time.

What's included

Package — physicsnemo/experimental/peft/

File Contents
config.py LoRAConfig dataclass + validation
lora.py LoRALayer mixin, LoRALinear, LoRA_te_Linear, LoRA_te_LayerNormMLP, and the type→wrapper registry (register_lora_wrapper / get_wrapper_for)
apply.py apply_lora, resolve_targets, ApplyResult, freeze logic, guards
merge.py merge_lora
io.py save_adapter / load_adapter (adapter archive)
utils.py split_params_for_optimizer, print_trainable_parameters, set_adapter_enabled, compute_base_fingerprint

Public API:

from physicsnemo.experimental.peft import (
    LoRAConfig,                  # declare rank/alpha + which layers to adapt
    apply_lora,                  # inject adapters in place, freeze the base
    ApplyResult,                 # report returned by apply_lora
    split_params_for_optimizer, # route adapter params to AdamW
    print_trainable_parameters, # "trainable params: N (X% of M total)"
    save_adapter, load_adapter,  # adapter-only archive I/O
    merge_lora,                  # fold adapters into base for zero-overhead inference
    set_adapter_enabled,         # toggle adapters (base-vs-adapter comparison)
    register_lora_wrapper,       # extension seam for new layer types
)

End-to-end example — examples/cfd/external_aerodynamics/transformer_models/

A separate, runnable LoRA recipe living alongside the existing GeoTransolver
training example (train.py is not modified):

  • src/finetune.py — load a pretrained base → apply_lora → train only the
    adapters (AdamW, find_unused_parameters=True, DistributedSampler sharding)
    save_adapter.
  • src/deploy.pyload_adapter (adapter-swap) or merge_lora (fold in).
  • conf/finetune_lora.yaml — reuses the example's model/data/training config
    groups + a peft: block.
  • FINETUNE_LORA.md — full walkthrough; the main README cross-references it.

Documentation

  • physicsnemo-docs/docs/user-guide/peft.rst — user-guide page (overview, when to
    use, the LoRA math, quickstart, layer targeting, optimizer setup, save/load,
    merge, Transformer Engine support, extensibility, API reference), wired into the
    User Guide toctree.

Quickstart

import torch
from physicsnemo.experimental.peft import (
    LoRAConfig, apply_lora, split_params_for_optimizer, save_adapter,
)

model = build_model()                     # any torch.nn.Module
model.load_state_dict(pretrained)

apply_lora(model, LoRAConfig(rank=16, alpha=16,
           target_pattern=r"blocks\.\d+\.attn\.(q|k|v|out)_proj"))

groups = split_params_for_optimizer(model)
opt = torch.optim.AdamW(groups["lora"] + groups["extras"], lr=5e-4)
# ... train ...
save_adapter(model, "adapter.lora")

Layer targeting

A LoRAConfig sets exactly one selector, matched against fully-qualified
module names (leaf names are not unique):

  • target_modules — explicit list of names
  • target_pattern — regex (re.search)
  • target_filter — predicate (name, module) -> bool

Plus two modifiers: wrap_mlp (additively adapt the transformer feed-forward
sub-block) and extras_trainable (modules to train fully, not low-rank). Only
registered layer types are eligible; a selector matching zero wrappable layers
raises (no silent misses).

Key design decisions

  • In-place mutation, not a PeftModel wrapper. apply_lora swaps matched
    leaves for LoRA wrappers and freezes the base; the model keeps its class and
    identity, so existing .mdlus checkpoint/inference tooling still works.
  • Works on any torch.nn.Module — no dependency on physicsnemo.Module or
    the .mdlus format (a plain-PyTorch user can adapt their own model).
  • Transformer Engine: te.Linear adapts per-matrix; the fused
    te.LayerNormMLP (no addressable child Linears) adapts via a single rank-r
    residual across the FFN sub-block (kept un-mergeable). The te.LayerNormLinear
    output head is not wrapped in v1 (documented; future register_lora_wrapper).
  • Adapter archive: a plain ZIP (adapter_config.json + adapter_model.pt +
    metadata.json) holding only the trainable slice. Disambiguated from full model
    checkpoints by metadata.kind == "lora_adapter" and a structural
    base_fingerprint. Loaded only by load_adapter (with weights_only=True);
    recommended extension .lora (any extension works) since it is neither a
    torch.save file nor a Module checkpoint.
  • Optimizer routing: adapter params go to AdamW, never Muon (Newton–Schulz
    orthogonalization is degenerate on low-rank factors) — via
    split_params_for_optimizer.
  • Type→wrapper registry: the single extension seam — supporting a new layer
    type (equivariant, MoE, …) is one register_lora_wrapper(type, wrapper) call,
    with no changes to targeting / apply / save / merge.

Validation

Tests (test/experimental/peft/)

42 tests across 7 files, run in-container on GPU + Transformer Engine; the
TE/GeoTransolver tests skip cleanly on CPU-only CI.

File Covers
test_config.py LoRAConfig validation (selectors, rank, dropout, reserved init)
test_lora_linear.py B=0 init parity, grad-flow-to-LoRA-only, dtype/device inheritance, merge transpose, disable; TE te.Linear + fused te.LayerNormMLP residual
test_apply.py exact/regex/callable targeting, zero-wrap & double-apply guards, freezing, extras_trainable not unfreezing nested base, optimizer split, wrap_mlp expansion
test_merge.py post-merge forward preserved, idempotency
test_io.py round-trip, archive structure, any-extension, missing-parent-dir creation, kind check, fingerprint mismatch, weights_only rejection of unsafe pickles, save-after-merge guard
test_smoke.py full recipe via real training: apply → AdamW → loss drops → save → load → merge
test_geotransolver_lora.py apply_lora + wrap_mlp on a real TE GeoTransolver (GPU+TE)

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

@mnabian mnabian self-assigned this Jun 1, 2026
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented Jun 1, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@mnabian mnabian marked this pull request as draft June 1, 2026 18:26
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Jun 1, 2026

Greptile Summary

This PR adds physicsnemo.experimental.peft, a native LoRA PEFT module, along with a companion GeoTransolver fine-tuning/deploy example. The core library (apply_lora, save_adapter/load_adapter, merge_lora, adapter wrappers for nn.Linear, te.Linear, and fused te.LayerNormMLP) is well-designed and backed by thorough unit and integration tests.

  • New physicsnemo/experimental/peft/ package implementing LoRA injection, adapter serialization as a ZIP archive, weight merging, and optimizer param routing, with TE-aware wrappers that preserve fp8 fused kernels.
  • New examples/cfd/external_aerodynamics/transformer_models/src/finetune.py and deploy.py demonstrating the end-to-end workflow; finetune.py has a DistributedSampler initialization that passes the DataLoader instead of DataLoader.dataset (covered in a prior review comment).

Important Files Changed

Filename Overview
physicsnemo/experimental/peft/init.py New public API module; clean re-exports with explicit all.
physicsnemo/experimental/peft/apply.py In-place LoRA injection; fingerprinting, targeting, and freezing logic are correct.
physicsnemo/experimental/peft/config.py LoRAConfig dataclass with solid validation; effective_alpha/scaling properties correct.
physicsnemo/experimental/peft/io.py Save/load adapter ZIP; save_adapter silently writes empty adapter when called after merge_lora, producing an unloadable file.
physicsnemo/experimental/peft/lora.py LoRALayer mixin + concrete wrappers for nn.Linear, te.Linear, te.LayerNormMLP; math and MRO-based submodule registration are correct.
physicsnemo/experimental/peft/merge.py merge_lora folds mergeable adapters into base weights correctly; idempotent and handles non-mergeable TE fused layers.
physicsnemo/experimental/peft/utils.py Fingerprinting, optimizer param routing, and enable/disable utilities are correct.
examples/cfd/external_aerodynamics/transformer_models/src/finetune.py LoRA fine-tuning entry point; DistributedSampler is initialized with the DataLoader instead of its .dataset, causing corrupt sharding indices in multi-GPU runs.
examples/cfd/external_aerodynamics/transformer_models/src/deploy.py Deploy script correctly loads adapter, optionally merges, and guards against saving merged checkpoint when non-mergeable adapters remain.
test/experimental/peft/test_smoke.py End-to-end smoke test (apply → train → save → load → merge) exercises the full PEFT round-trip on a toy model.

Reviews (2): Last reviewed commit: "minor bug fixes" | Re-trigger Greptile

Comment thread physicsnemo/experimental/peft/io.py Outdated
Comment thread physicsnemo/experimental/peft/config.py Outdated
@mnabian mnabian marked this pull request as ready for review June 3, 2026 23:32
@mnabian mnabian requested a review from ys-teh June 3, 2026 23:32
Comment thread physicsnemo/experimental/peft/io.py
@ys-teh
Copy link
Copy Markdown
Collaborator

ys-teh commented Jun 4, 2026

The general implementation structure looks modular and good to me. register_lora_wrappers will be useful. You may want to consider putting lora-related items into a folder called lora in case there will be non-lora peft approaches in the future. Utils can likely stay outside.

The implementation differs in style compared to existing finetuning implementation in Alchemi. The closest one, for example, is patching module and it looks like this. This potentially needs to be resolved on the Alchemi side. In any case, your modular implementation should help.

@mnabian
Copy link
Copy Markdown
Collaborator Author

mnabian commented Jun 4, 2026

The general implementation structure looks modular and good to me. register_lora_wrappers will be useful. You may want to consider putting lora-related items into a folder called lora in case there will be non-lora peft approaches in the future. Utils can likely stay outside.

Thanks for your comment. I would hold the lora/ split until we add a second, non-variant method — at which point the clean factoring is method-agnostic core (targeting, I/O, registry, freeze, utils) vs. tuners/lora, tuners/, not just renaming files. Since the public API is the package __init__, that move does not break anything.

@mnabian mnabian changed the title PhysicsNeMo PEFT PhysicsNeMo PEFT - LoRA Jun 4, 2026
@mnabian mnabian requested a review from laserkelvin June 4, 2026 00:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants