From 138a2d4c06bc5b5a6532147f0875055a32b860cf Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Thu, 14 May 2026 17:21:21 -0700 Subject: [PATCH 1/3] add use_memory_efficient_lora knob Signed-off-by: Alexandros Koumparoulis --- nemo_automodel/components/_peft/lora.py | 104 +++++-- .../hf_peft/test_lora_memory_profile.py | 285 ++++++++++++++++++ tests/unit_tests/_peft/test_lora.py | 119 ++++++-- 3 files changed, 470 insertions(+), 38 deletions(-) create mode 100644 tests/functional_tests/hf_peft/test_lora_memory_profile.py diff --git a/nemo_automodel/components/_peft/lora.py b/nemo_automodel/components/_peft/lora.py index 663f8a1643..57bf669aef 100644 --- a/nemo_automodel/components/_peft/lora.py +++ b/nemo_automodel/components/_peft/lora.py @@ -53,6 +53,7 @@ class PeftConfig: dropout_position: Literal["pre", "post"] = "post" lora_A_init: str = "xavier" lora_dtype: Optional[torch.dtype] = None + use_memory_efficient_lora: bool = True use_triton: bool = False moe_rank_scaling: bool = False @@ -72,6 +73,7 @@ def from_dict(cls, d: dict[str, Any]): dropout_position=d.get("dropout_position", "post"), lora_A_init=d.get("lora_A_init", "xavier"), lora_dtype=d.get("lora_dtype", None), + use_memory_efficient_lora=d.get("use_memory_efficient_lora", True), use_triton=d.get("use_triton", False), moe_rank_scaling=d.get("moe_rank_scaling", False), ) @@ -102,6 +104,7 @@ def __init__( dropout_position="post", lora_A_init_method="xavier", lora_dtype=None, + use_memory_efficient_lora=True, ): """ LinearLora constructor. @@ -138,6 +141,7 @@ def __init__( dropout_position=dropout_position, lora_A_init_method=lora_A_init_method, lora_dtype=lora_dtype, + use_memory_efficient_lora=use_memory_efficient_lora, ) @torch.no_grad @@ -165,6 +169,7 @@ def _init_adapter( dropout_position="post", lora_A_init_method="xavier", lora_dtype=None, + use_memory_efficient_lora=True, ): """ Adds LoRA weights to obj. Obj is either a LinearLoRA or an nn.Module (when monkey-patching). @@ -182,6 +187,7 @@ def _init_adapter( obj.dim = dim obj.scale = alpha / dim obj.use_dora = bool(use_dora) + obj.use_memory_efficient_lora = bool(use_memory_efficient_lora) # Freezer device = obj.weight.device @@ -227,6 +233,22 @@ def _dora_weight_norm(self) -> torch.Tensor: weight_norm = torch.linalg.norm(weight + self.scale * delta_w, dim=1).to(weight.dtype) return weight_norm.detach() + def _should_use_memory_efficient_lora(self, x: torch.Tensor) -> bool: + """Return whether this LoRA branch can use the custom autograd path.""" + if not getattr(self, "use_memory_efficient_lora", False): + return False + if isinstance(x, DTensor): + return False + if isinstance(getattr(self.lora_A, "weight", None), DTensor): + return False + if isinstance(getattr(self.lora_B, "weight", None), DTensor): + return False + if torch.compiler.is_compiling(): + return False + if HAS_TE and isinstance(getattr(self, "lora_A", None), transformer_engine.pytorch.Linear): + return False + return True + def forward(self, x): """ Forward pass through the original linear layer augmented with the LoRA pathway. @@ -275,7 +297,12 @@ def forward(self, x): # Apply scale before lora_B to keep lora_res as a Partial tensor. # This allows both res and lora_res to remain Partial, so only one reduce-scatter is needed after addition. # Multiplying after lora_B would convert Partial to Replicate, causing an extra reduce-scatter operation. - lora_res = self.lora_B(self.lora_A(x) * self.scale) + if self._should_use_memory_efficient_lora(x): + lora_res = LoRATritonFunction.apply( + x, self.lora_A.weight, self.lora_B.weight, self.scale, x.dtype, False + ) + else: + lora_res = self.lora_B(self.lora_A(x) * self.scale) if self.dropout_position == "post": lora_res = F.dropout(lora_res, p=self.dropout_p, training=self.training) return res + lora_res @@ -357,7 +384,10 @@ def forward(self, x): if self.dropout_position == "pre": x = F.dropout(x, p=self.dropout_p, training=self.training) - lora_res = LoRATritonFunction.apply(x, self.lora_A.weight, self.lora_B.weight, self.scale, x.dtype) + if self.use_memory_efficient_lora: + lora_res = LoRATritonFunction.apply(x, self.lora_A.weight, self.lora_B.weight, self.scale, x.dtype, True) + else: + lora_res = self.lora_B(self.lora_A(x) * self.scale) if self.dropout_position == "post": lora_res = F.dropout(lora_res, p=self.dropout_p, training=self.training) @@ -373,6 +403,7 @@ def patch_linear_module( dropout_position="post", lora_A_init_method="xavier", lora_dtype=None, + use_memory_efficient_lora=True, use_triton=True, layer_name=None, ): @@ -396,8 +427,10 @@ def patch_linear_module( Defaults to 'post' (choices: 'pre', 'post'). lora_A_init_method (str, optional): lora_a init method. Defaults to 'xavier'. lora_dtype (_type_, optional): Lora weights' dtype. By default will use orig_linear's dtype - but orig_linear might use non-trainable dtype (e.g., 4bit), in which case the user must - specify the dtype manually. Defaults to None. + but orig_linear might use non-trainable dtype (e.g., 4bit), in which case the user must + specify the dtype manually. Defaults to None. + use_memory_efficient_lora (bool, optional): Use the custom autograd implementation for standard LoRA. + When Triton is enabled this uses Triton kernels; otherwise it uses PyTorch matmuls. Defaults to True. use_triton (bool, optional): By default we use the triton kernel LoRA implementation. Returns: @@ -428,6 +461,7 @@ def patch_linear_module( dropout_position=dropout_position, lora_A_init_method=lora_A_init_method, lora_dtype=lora_dtype, + use_memory_efficient_lora=use_memory_efficient_lora, ) cls = orig_linear.__class__ new_cls = type("PatchedLinearLoRA", (linear_lora_cls, cls), {}) @@ -605,6 +639,7 @@ def apply_lora_to_linear_modules( dropout_position=peft_config.dropout_position, lora_A_init_method=peft_config.lora_A_init, lora_dtype=lora_dtype, + use_memory_efficient_lora=getattr(peft_config, "use_memory_efficient_lora", True), use_triton=peft_config.use_triton, layer_name=name, ) @@ -614,7 +649,11 @@ def apply_lora_to_linear_modules( class LoRATritonFunction(torch.autograd.Function): """ - Autograd function that calls the triton kernel wrappers for the LoRA forward and backward passes. + Autograd function that avoids saving the LoRA A activation. + + The default path calls Triton kernel wrappers for forward and backward. Callers can pass + ``use_triton_kernel=False`` to use PyTorch matmuls while keeping the same memory-efficient + saved tensor behavior. """ @staticmethod @@ -622,39 +661,44 @@ def setup_context(ctx, inputs, output): """ Stores context for LoRA backward pass. """ - x, lora_A, lora_B, scale, _ = inputs + x, lora_A, lora_B, scale, dtype, *rest = inputs ctx.save_for_backward(x, lora_A, lora_B) ctx.scale = scale + ctx.dtype = dtype + ctx.use_triton_kernel = bool(rest[0]) if rest else True + ctx.num_inputs = len(inputs) @staticmethod - def forward(x, lora_A, lora_B, scale, dtype): + def forward(x, lora_A, lora_B, scale, dtype, use_triton_kernel=True): """ - Forward method for LoRATriton. + Forward method for memory-efficient LoRA. - Reshapes 3D tensors into 2D and then calls the triton kernel. + Reshapes 3D tensors into 2D and then calls either Triton kernels or PyTorch matmuls. """ reshape = x.dim() == 3 if reshape: bs, seq_len, d = x.shape x = x.reshape(-1, d) - lora_res = lora_forward_wrapper(x, lora_A.t(), lora_B.t(), res=None, scale=scale, dtype=dtype) + if use_triton_kernel: + lora_res = lora_forward_wrapper(x, lora_A.t(), lora_B.t(), res=None, scale=scale, dtype=dtype) + else: + lora_res = F.linear(F.linear(x, lora_A) * scale, lora_B) if reshape: return lora_res.view(bs, seq_len, -1) - else: - return lora_res + return lora_res @staticmethod def backward(ctx, d_y): """ - Backward method for LoRATriton. + Backward method for memory-efficient LoRA. - Reshapes 3D tensors into 2D and then calls the kernels to update d_lora_a, d_lora_b, and dx. + Reshapes 3D tensors into 2D and then updates d_lora_a, d_lora_b, and dx. The PyTorch matmul + path recomputes ``x @ lora_A.T`` here instead of saving it from forward. """ x, lora_A, lora_B = ctx.saved_tensors scale = ctx.scale - dtype = x.dtype reshape = x.dim() == 3 if reshape: @@ -662,9 +706,29 @@ def backward(ctx, d_y): d_y = d_y.reshape(-1, d_y.shape[-1]) x = x.reshape(-1, d) - d_lora_A, d_x = lora_da_dx_update_wrapper(x.t(), d_y, lora_B, lora_A, scale, dtype=dtype) - d_lora_B = lora_db_update_wrapper(lora_A, x.t(), d_y, scale, dtype) - - if reshape: + if ctx.use_triton_kernel: + d_lora_A, d_x = lora_da_dx_update_wrapper(x.t(), d_y, lora_B, lora_A, scale, dtype=ctx.dtype) + d_lora_B = lora_db_update_wrapper(lora_A, x.t(), d_y, scale, ctx.dtype) + d_lora_A = d_lora_A.t() + else: + d_x = d_lora_A = d_lora_B = None + needs_x, needs_lora_A, needs_lora_B = ctx.needs_input_grad[:3] + if needs_x or needs_lora_A: + d_y_lora_B = torch.matmul(d_y, lora_B) + if needs_x: + d_x = torch.empty_like(x) + d_x.addmm_(d_y_lora_B, lora_A, beta=0, alpha=scale) + if needs_lora_A: + d_lora_A = torch.matmul(d_y_lora_B.t(), x) * scale + + if needs_lora_B: + d_lora_B = torch.empty_like(lora_B) + d_lora_B.addmm_(d_y.t(), F.linear(x, lora_A), beta=0, alpha=scale) + + if reshape and d_x is not None: d_x = d_x.view(bs, seq_len, d) - return d_x, d_lora_A.t(), d_lora_B, None, None + + gradients = (d_x, d_lora_A, d_lora_B, None, None) + if ctx.num_inputs == 6: + return gradients + (None,) + return gradients diff --git a/tests/functional_tests/hf_peft/test_lora_memory_profile.py b/tests/functional_tests/hf_peft/test_lora_memory_profile.py new file mode 100644 index 0000000000..c977037bd4 --- /dev/null +++ b/tests/functional_tests/hf_peft/test_lora_memory_profile.py @@ -0,0 +1,285 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import gc +import json +import os +import subprocess +import sys +from dataclasses import dataclass +from pathlib import Path + +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.profiler import ProfilerActivity + +from nemo_automodel.components._peft.lora import patch_linear_module +from nemo_automodel.components._peft.lora_kernel import HAVE_TRITON + +_RESULT_PREFIX = "LORA_MEMORY_PROFILE_RESULT " +_GAIN_PREFIX = "LORA_MEMORY_PROFILE_GAIN " +_MIN_SAVED_BYTES = 1024 * 1024 + + +@dataclass +class ProfileResult: + """Memory profile result for one LoRA mode.""" + + peak_delta_bytes: int + profiler_memory_bytes: int + + +@dataclass +class ProfileConfig: + """Shape configuration for one memory profile run.""" + + seq_len: int + hidden_size: int + rank: int + layers: int + + +_PROFILE_CONFIG = ProfileConfig(seq_len=4096, hidden_size=512, rank=16, layers=16) +_LORA_IMPL_PARAMS = [ + pytest.param(False, id="torch"), + pytest.param(True, marks=pytest.mark.skipif(not HAVE_TRITON, reason="Triton is not available"), id="triton"), +] + + +class LoRAStack(nn.Module): + """Small stack of real LinearLoRA modules used by the profiler tests.""" + + def __init__( + self, *, use_memory_efficient_lora: bool, use_triton: bool, config: ProfileConfig, device: torch.device + ) -> None: + super().__init__() + self.layers = nn.ModuleList() + for _ in range(config.layers): + base = nn.Linear(config.hidden_size, config.hidden_size, bias=False, dtype=torch.bfloat16, device=device) + self.layers.append( + patch_linear_module( + base, + dim=config.rank, + alpha=config.rank, + dropout=0.0, + lora_dtype=torch.bfloat16, + use_memory_efficient_lora=use_memory_efficient_lora, + use_triton=use_triton, + ) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Run every LoRA layer so its forward state stays live until backward.""" + for layer in self.layers: + x = x + layer(x) + return x + + +def _run_step(model: nn.Module, x: torch.Tensor) -> None: + out = model(x) + loss = out.float().square().mean() + loss.backward() + torch.cuda.synchronize(x.device) + + +def _profiler_memory_bytes(prof: torch.profiler.profile) -> int: + return sum(max(0, getattr(event, "self_device_memory_usage", 0)) for event in prof.key_averages()) + + +def _profile_model( + *, use_memory_efficient_lora: bool, use_triton: bool, device: torch.device, use_fsdp: bool +) -> ProfileResult: + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats(device) + torch.manual_seed(1234) + + config = _PROFILE_CONFIG + model = LoRAStack( + use_memory_efficient_lora=use_memory_efficient_lora, use_triton=use_triton, config=config, device=device + ) + if use_fsdp: + from torch.distributed.device_mesh import init_device_mesh + from torch.distributed.fsdp import fully_shard + + mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("dp",)) + for layer in model.layers: + fully_shard(layer, mesh=mesh) + fully_shard(model, mesh=mesh) + + x = torch.randn(1, config.seq_len, config.hidden_size, dtype=torch.bfloat16, device=device, requires_grad=True) + + _run_step(model, x) + model.zero_grad(set_to_none=True) + x.grad = None + + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(device) + torch.cuda.reset_peak_memory_stats(device) + allocated_before = torch.cuda.memory_allocated(device) + + with torch.profiler.profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + profile_memory=True, + record_shapes=True, + ) as prof: + _run_step(model, x) + + peak_delta_bytes = torch.cuda.max_memory_allocated(device) - allocated_before + profiler_memory_bytes = _profiler_memory_bytes(prof) + + del model, x + gc.collect() + torch.cuda.empty_cache() + return ProfileResult(peak_delta_bytes=peak_delta_bytes, profiler_memory_bytes=profiler_memory_bytes) + + +def _profile_pair(*, use_triton: bool, device: torch.device, use_fsdp: bool) -> dict[str, int]: + before = _profile_model(use_memory_efficient_lora=False, use_triton=use_triton, device=device, use_fsdp=use_fsdp) + after = _profile_model(use_memory_efficient_lora=True, use_triton=use_triton, device=device, use_fsdp=use_fsdp) + saved_bytes = before.peak_delta_bytes - after.peak_delta_bytes + return { + "before_peak_delta_bytes": before.peak_delta_bytes, + "after_peak_delta_bytes": after.peak_delta_bytes, + "saved_bytes": saved_bytes, + "before_profiler_memory_bytes": before.profiler_memory_bytes, + "after_profiler_memory_bytes": after.profiler_memory_bytes, + "seq_len": _PROFILE_CONFIG.seq_len, + "hidden_size": _PROFILE_CONFIG.hidden_size, + "lora_rank": _PROFILE_CONFIG.rank, + "layers": _PROFILE_CONFIG.layers, + "use_triton": int(use_triton), + } + + +def _format_mib(num_bytes: int) -> str: + return f"{num_bytes / 1024**2:.2f} MiB" + + +def _dump_memory_gain(result: dict[str, int], *, mode: str, dist_rank: int | None = None) -> None: + before_peak = result["before_peak_delta_bytes"] + after_peak = result["after_peak_delta_bytes"] + saved_bytes = result["saved_bytes"] + saved_pct = 100.0 * saved_bytes / before_peak if before_peak > 0 else 0.0 + impl = "triton" if result["use_triton"] else "torch" + rank_part = f" dist_rank={dist_rank}" if dist_rank is not None else "" + print( + f"{_GAIN_PREFIX}mode={mode} impl={impl}{rank_part} " + f"seq_len={result['seq_len']} hidden_size={result['hidden_size']} " + f"lora_rank={result['lora_rank']} layers={result['layers']} " + f"before_peak={_format_mib(before_peak)} after_peak={_format_mib(after_peak)} " + f"saved={_format_mib(saved_bytes)} saved_pct={saved_pct:.2f}% " + f"before_profiler_memory={_format_mib(result['before_profiler_memory_bytes'])} " + f"after_profiler_memory={_format_mib(result['after_profiler_memory_bytes'])}", + flush=True, + ) + + +def _assert_profile_memory_gain(result: dict[str, int]) -> None: + assert result["before_profiler_memory_bytes"] > 0 + assert result["after_profiler_memory_bytes"] > 0 + assert result["saved_bytes"] >= _MIN_SAVED_BYTES, result + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +@pytest.mark.parametrize("use_triton", _LORA_IMPL_PARAMS) +def test_memory_efficient_lora_torch_profile_single_gpu(use_triton: bool): + """Memory-efficient LoRA should reduce profiled peak memory on one GPU.""" + result = _profile_pair(use_triton=use_triton, device=torch.device("cuda", 0), use_fsdp=False) + _dump_memory_gain(result, mode="single_gpu") + _assert_profile_memory_gain(result) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="requires at least 2 CUDA devices") +@pytest.mark.parametrize("use_triton", _LORA_IMPL_PARAMS) +def test_memory_efficient_lora_torch_profile_two_gpu_fsdp(use_triton: bool): + """Memory-efficient LoRA should reduce profiled peak memory with 2-GPU FSDP.""" + cmd = [ + sys.executable, + "-m", + "torch.distributed.run", + "--standalone", + "--nproc_per_node=2", + str(Path(__file__).resolve()), + "--fsdp-worker", + "--use-triton" if use_triton else "--no-use-triton", + ] + env = os.environ.copy() + repo_root = str(Path(__file__).resolve().parents[3]) + env["PYTHONPATH"] = repo_root + os.pathsep + env.get("PYTHONPATH", "") + completed = subprocess.run( + cmd, + cwd=repo_root, + env=env, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + timeout=180, + check=False, + ) + assert completed.returncode == 0, completed.stdout + + result_line = next( + (line for line in completed.stdout.splitlines() if line.startswith(_RESULT_PREFIX)), + None, + ) + assert result_line is not None, completed.stdout + for rank, rank_result in enumerate(json.loads(result_line.removeprefix(_RESULT_PREFIX))): + _dump_memory_gain(rank_result, mode="two_gpu_fsdp", dist_rank=rank) + _assert_profile_memory_gain(rank_result) + + +def _run_fsdp_worker(use_triton: bool) -> None: + dist.init_process_group("nccl") + try: + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + result = _profile_pair(use_triton=use_triton, device=device, use_fsdp=True) + + gathered = [None for _ in range(dist.get_world_size())] + dist.all_gather_object(gathered, result) + if dist.get_rank() == 0: + print(_RESULT_PREFIX + json.dumps(gathered), flush=True) + + ok = int( + result["before_profiler_memory_bytes"] > 0 + and result["after_profiler_memory_bytes"] > 0 + and result["saved_bytes"] >= _MIN_SAVED_BYTES + ) + ok_tensor = torch.tensor([ok], device=device) + dist.all_reduce(ok_tensor, op=dist.ReduceOp.MIN) + if ok_tensor.item() != 1: + raise AssertionError(result) + finally: + dist.destroy_process_group() + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--fsdp-worker", action="store_true") + parser.add_argument("--use-triton", dest="use_triton", action="store_true") + parser.add_argument("--no-use-triton", dest="use_triton", action="store_false") + parser.set_defaults(use_triton=False) + return parser.parse_args() + + +if __name__ == "__main__": + args = _parse_args() + if args.fsdp_worker: + _run_fsdp_worker(args.use_triton) diff --git a/tests/unit_tests/_peft/test_lora.py b/tests/unit_tests/_peft/test_lora.py index 4e3b1bd4df..3b580b53b7 100644 --- a/tests/unit_tests/_peft/test_lora.py +++ b/tests/unit_tests/_peft/test_lora.py @@ -15,9 +15,12 @@ import pytest import torch import torch.nn as nn +import torch.nn.functional as F +from torch.autograd.graph import saved_tensors_hooks from nemo_automodel.components._peft.lora import ( LinearLoRA, + LoRATritonFunction, PeftConfig, apply_lora_to_linear_modules, patch_linear_module, @@ -112,6 +115,15 @@ def test_lora_patch_applies_to_selected_module_with_str_dtype(model): assert not isinstance(model.linear2, LinearLoRA) +def test_peft_config_memory_efficient_lora_round_trip(): + """PeftConfig should default memory-efficient LoRA on and preserve explicit overrides.""" + assert PeftConfig().use_memory_efficient_lora is True + + cfg = PeftConfig.from_dict({"use_memory_efficient_lora": False}) + assert cfg.use_memory_efficient_lora is False + assert cfg.to_dict()["use_memory_efficient_lora"] is False + + def test_forward_output_consistency(dummy_input): """Verifies that model output shape remains the same after LoRA patching, but values change due to the added LoRA components. @@ -146,6 +158,89 @@ def test_backward_pass(dummy_input): assert all(torch.isfinite(g).all() for g in grads if g is not None), "Gradients should be finite" +@pytest.mark.parametrize("input_shape", [(5, 16), (2, 3, 16)]) +def test_memory_efficient_lora_matches_legacy_forward_and_backward(input_shape): + """Custom autograd LoRA should match the legacy two-linear implementation.""" + torch.manual_seed(1234) + scale = 2.0 + lora_dim = 4 + out_features = 12 + + x = torch.randn(*input_shape, requires_grad=True) + lora_A = torch.randn(lora_dim, input_shape[-1], requires_grad=True) + lora_B = torch.randn(out_features, lora_dim, requires_grad=True) + x_ref = x.detach().clone().requires_grad_(True) + lora_A_ref = lora_A.detach().clone().requires_grad_(True) + lora_B_ref = lora_B.detach().clone().requires_grad_(True) + + efficient = LoRATritonFunction.apply(x, lora_A, lora_B, scale, x.dtype, False) + legacy = F.linear(F.linear(x_ref, lora_A_ref) * scale, lora_B_ref) + + grad = torch.randn_like(legacy) + efficient.backward(grad) + legacy.backward(grad) + + assert torch.allclose(efficient, legacy) + assert torch.allclose(x.grad, x_ref.grad) + assert torch.allclose(lora_A.grad, lora_A_ref.grad) + assert torch.allclose(lora_B.grad, lora_B_ref.grad) + + +def test_memory_efficient_lora_saves_less_forward_state(): + """The custom autograd path should not save the intermediate x @ lora_A.T activation.""" + torch.manual_seed(1234) + x = torch.randn(8, 16, requires_grad=True) + lora_A = torch.randn(4, 16, requires_grad=True) + lora_B = torch.randn(12, 4, requires_grad=True) + scale = 2.0 + + def collect_saved_tensors(fn): + saved = [] + + def pack_hook(tensor): + saved.append(tuple(tensor.shape)) + return tensor + + with saved_tensors_hooks(pack_hook, lambda tensor: tensor): + fn() + return saved + + legacy_saved = collect_saved_tensors(lambda: F.linear(F.linear(x, lora_A) * scale, lora_B)) + efficient_saved = collect_saved_tensors(lambda: LoRATritonFunction.apply(x, lora_A, lora_B, scale, x.dtype, False)) + + assert (8, 4) in legacy_saved + assert (8, 4) not in efficient_saved + assert sum(torch.tensor(shape).prod().item() for shape in efficient_saved) < sum( + torch.tensor(shape).prod().item() for shape in legacy_saved + ) + + +def test_linear_lora_memory_efficient_flag_controls_saved_state(): + """LinearLoRA should use the memory-efficient autograd path when the flag is enabled.""" + torch.manual_seed(1234) + base = nn.Linear(16, 12, bias=False) + x = torch.randn(8, 16, requires_grad=True) + legacy = LinearLoRA(base, dim=4, alpha=8, use_memory_efficient_lora=False) + efficient = LinearLoRA(base, dim=4, alpha=8, use_memory_efficient_lora=True) + + def collect_saved_tensors(fn): + saved = [] + + def pack_hook(tensor): + saved.append(tuple(tensor.shape)) + return tensor + + with saved_tensors_hooks(pack_hook, lambda tensor: tensor): + fn() + return saved + + legacy_saved = collect_saved_tensors(lambda: legacy(x)) + efficient_saved = collect_saved_tensors(lambda: efficient(x)) + + assert (8, 4) in legacy_saved + assert (8, 4) not in efficient_saved + + def test_lora_layers_are_trainable(): """Ensures that LoRA layers are trainable while base weights remain frozen.""" base = nn.Linear(16, 16) @@ -236,9 +331,7 @@ def test_patch_sets_super_fwd(self): """patch_linear_module should set super_fwd for TE Linear.""" from transformer_engine.pytorch.module.linear import Linear as TELinear - te_linear = TELinear( - in_features=16, out_features=32, bias=False, params_dtype=torch.bfloat16 - ).cuda() + te_linear = TELinear(in_features=16, out_features=32, bias=False, params_dtype=torch.bfloat16).cuda() patched = patch_linear_module(te_linear, dim=4, alpha=8, use_triton=False) assert hasattr(patched, "super_fwd"), "super_fwd should be set for TE Linear" assert patched.super_fwd is not None @@ -248,24 +341,16 @@ def test_lora_adapters_are_te_linear(self): """lora_A and lora_B should be TE Linear when base module is TE Linear.""" from transformer_engine.pytorch.module.linear import Linear as TELinear - te_linear = TELinear( - in_features=16, out_features=32, bias=False, params_dtype=torch.bfloat16 - ).cuda() + te_linear = TELinear(in_features=16, out_features=32, bias=False, params_dtype=torch.bfloat16).cuda() patched = patch_linear_module(te_linear, dim=4, alpha=8, use_triton=False) - assert isinstance(patched.lora_A, TELinear), ( - f"lora_A should be TE Linear, got {type(patched.lora_A)}" - ) - assert isinstance(patched.lora_B, TELinear), ( - f"lora_B should be TE Linear, got {type(patched.lora_B)}" - ) + assert isinstance(patched.lora_A, TELinear), f"lora_A should be TE Linear, got {type(patched.lora_A)}" + assert isinstance(patched.lora_B, TELinear), f"lora_B should be TE Linear, got {type(patched.lora_B)}" def test_forward_pass(self): """Patched TE Linear should produce valid output.""" from transformer_engine.pytorch.module.linear import Linear as TELinear - te_linear = TELinear( - in_features=16, out_features=32, bias=False, params_dtype=torch.bfloat16 - ).cuda() + te_linear = TELinear(in_features=16, out_features=32, bias=False, params_dtype=torch.bfloat16).cuda() patched = patch_linear_module(te_linear, dim=4, alpha=8, use_triton=False) x = torch.randn(2, 16, device="cuda", dtype=torch.bfloat16) out = patched(x) @@ -276,9 +361,7 @@ def test_backward_pass(self): """Backward pass through patched TE Linear should produce gradients on LoRA params.""" from transformer_engine.pytorch.module.linear import Linear as TELinear - te_linear = TELinear( - in_features=16, out_features=32, bias=False, params_dtype=torch.bfloat16 - ).cuda() + te_linear = TELinear(in_features=16, out_features=32, bias=False, params_dtype=torch.bfloat16).cuda() patched = patch_linear_module(te_linear, dim=4, alpha=8, use_triton=False) x = torch.randn(2, 16, device="cuda", dtype=torch.bfloat16, requires_grad=True) out = patched(x) From cc75b630628d1640bf0b03cd2c765ed6dc426fbc Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Thu, 14 May 2026 21:12:02 -0700 Subject: [PATCH 2/3] add use_memory_efficient_lora Signed-off-by: Alexandros Koumparoulis --- .../functional_tests/checkpoint/test_peft.py | 1 + .../checkpoint/test_peft_vlm.py | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/tests/functional_tests/checkpoint/test_peft.py b/tests/functional_tests/checkpoint/test_peft.py index f4411b03e4..3146fc8192 100644 --- a/tests/functional_tests/checkpoint/test_peft.py +++ b/tests/functional_tests/checkpoint/test_peft.py @@ -274,6 +274,7 @@ def test_hf_peft_checkpoint(force_hf, use_triton): "moe_rank_scaling": False, "target_modules": [], "use_dora": False, + "use_memory_efficient_lora": True, "use_triton": False, } diff --git a/tests/functional_tests/checkpoint/test_peft_vlm.py b/tests/functional_tests/checkpoint/test_peft_vlm.py index f66082b900..d1c0771295 100644 --- a/tests/functional_tests/checkpoint/test_peft_vlm.py +++ b/tests/functional_tests/checkpoint/test_peft_vlm.py @@ -640,6 +640,7 @@ def test_hf_peft_checkpoint(): "moe_rank_scaling": False, "target_modules": [], "use_dora": False, + "use_memory_efficient_lora": True, "use_triton": True, } @@ -772,6 +773,7 @@ def test_hf_peft_checkpoint(): source_model = trainer.model_parts[0] from nemo_automodel.components.checkpoint.checkpointing import _load_hf_checkpoint_preserving_dtype + hf_model_path = cfg.get("model.pretrained_model_name_or_path") hf_state_dict = _load_hf_checkpoint_preserving_dtype(hf_model_path) or {} print(f"HF checkpoint loaded: {len(hf_state_dict)} keys from {hf_model_path}", flush=True) @@ -782,9 +784,7 @@ def test_hf_peft_checkpoint(): print(f"Model param keys (first 10, no lora): {model_keys_sorted[:10]}", flush=True) param_mismatches = [] buffer_mismatches = [] - for (sn, sp), (rn, rp) in zip( - source_model.named_parameters(), restored_model.named_parameters() - ): + for (sn, sp), (rn, rp) in zip(source_model.named_parameters(), restored_model.named_parameters()): assert sn == rn, f"Parameter name mismatch: {sn} vs {rn}" sp_full = sp.full_tensor() if hasattr(sp, "full_tensor") else sp rp_full = rp.full_tensor() if hasattr(rp, "full_tensor") else rp @@ -818,9 +818,7 @@ def test_hf_peft_checkpoint(): f"src_norm={sp_full.float().norm().item():.4f} rst_norm={rp_full.float().norm().item():.4f} " f"| {src_vs_hf} | {rst_vs_hf}" ) - for (sn, sb), (rn, rb) in zip( - source_model.named_buffers(), restored_model.named_buffers() - ): + for (sn, sb), (rn, rb) in zip(source_model.named_buffers(), restored_model.named_buffers()): assert sn == rn, f"Buffer name mismatch: {sn} vs {rn}" if sb.is_meta or rb.is_meta: buffer_mismatches.append(f" BUFFER {sn}: src_meta={sb.is_meta} rst_meta={rb.is_meta}") @@ -836,13 +834,16 @@ def test_hf_peft_checkpoint(): f"max_diff={diff.max().item():.6e} mean_diff={diff.mean().item():.6e}" ) if param_mismatches or buffer_mismatches: - print(f"\n{'='*80}", flush=True) - print(f"WEIGHT COMPARISON: {len(param_mismatches)} param mismatches, {len(buffer_mismatches)} buffer mismatches", flush=True) + print(f"\n{'=' * 80}", flush=True) + print( + f"WEIGHT COMPARISON: {len(param_mismatches)} param mismatches, {len(buffer_mismatches)} buffer mismatches", + flush=True, + ) for m in param_mismatches: print(m, flush=True) for m in buffer_mismatches: print(m, flush=True) - print(f"{'='*80}\n", flush=True) + print(f"{'=' * 80}\n", flush=True) else: print("WEIGHT COMPARISON: All parameters and buffers match exactly.", flush=True) From db99717610cd2bb88cd84bb9c5b1c38a9e15e1f9 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Wed, 27 May 2026 22:14:47 -0700 Subject: [PATCH 3/3] fix Signed-off-by: Alexandros Koumparoulis --- nemo_automodel/components/_peft/lora.py | 29 ++++++++- .../hf_peft/test_lora_memory_profile.py | 3 + tests/unit_tests/_peft/test_lora.py | 59 +++++++++++++++++++ 3 files changed, 88 insertions(+), 3 deletions(-) diff --git a/nemo_automodel/components/_peft/lora.py b/nemo_automodel/components/_peft/lora.py index 57bf669aef..14696fc728 100644 --- a/nemo_automodel/components/_peft/lora.py +++ b/nemo_automodel/components/_peft/lora.py @@ -297,7 +297,12 @@ def forward(self, x): # Apply scale before lora_B to keep lora_res as a Partial tensor. # This allows both res and lora_res to remain Partial, so only one reduce-scatter is needed after addition. # Multiplying after lora_B would convert Partial to Replicate, causing an extra reduce-scatter operation. - if self._should_use_memory_efficient_lora(x): + use_memory_efficient_lora = self._should_use_memory_efficient_lora(x) + if use_memory_efficient_lora: + if self.dropout_position == "pre" or not self.training or self.dropout_p == 0.0: + return LoRATritonFunction.apply( + x, self.lora_A.weight, self.lora_B.weight, self.scale, x.dtype, False, res + ) lora_res = LoRATritonFunction.apply( x, self.lora_A.weight, self.lora_B.weight, self.scale, x.dtype, False ) @@ -305,6 +310,8 @@ def forward(self, x): lora_res = self.lora_B(self.lora_A(x) * self.scale) if self.dropout_position == "post": lora_res = F.dropout(lora_res, p=self.dropout_p, training=self.training) + if use_memory_efficient_lora: + return lora_res.add_(res) return res + lora_res if getattr(self, "lora_magnitude", None) is None: @@ -385,11 +392,17 @@ def forward(self, x): if self.dropout_position == "pre": x = F.dropout(x, p=self.dropout_p, training=self.training) if self.use_memory_efficient_lora: + if self.dropout_position == "pre" or not self.training or self.dropout_p == 0.0: + return LoRATritonFunction.apply( + x, self.lora_A.weight, self.lora_B.weight, self.scale, x.dtype, True, res + ) lora_res = LoRATritonFunction.apply(x, self.lora_A.weight, self.lora_B.weight, self.scale, x.dtype, True) else: lora_res = self.lora_B(self.lora_A(x) * self.scale) if self.dropout_position == "post": lora_res = F.dropout(lora_res, p=self.dropout_p, training=self.training) + if self.use_memory_efficient_lora: + return lora_res.add_(res) return res + lora_res @@ -666,25 +679,32 @@ def setup_context(ctx, inputs, output): ctx.scale = scale ctx.dtype = dtype ctx.use_triton_kernel = bool(rest[0]) if rest else True + ctx.has_residual = len(rest) > 1 and rest[1] is not None ctx.num_inputs = len(inputs) @staticmethod - def forward(x, lora_A, lora_B, scale, dtype, use_triton_kernel=True): + def forward(x, lora_A, lora_B, scale, dtype, use_triton_kernel=True, res=None): """ Forward method for memory-efficient LoRA. - Reshapes 3D tensors into 2D and then calls either Triton kernels or PyTorch matmuls. + Reshapes 3D tensors into 2D and then calls either Triton kernels or PyTorch matmuls. When ``res`` is + provided, the residual is added in-place into the LoRA output to avoid allocating a separate add result. """ reshape = x.dim() == 3 if reshape: bs, seq_len, d = x.shape x = x.reshape(-1, d) + if res is not None: + res = res.reshape(-1, res.shape[-1]) if use_triton_kernel: lora_res = lora_forward_wrapper(x, lora_A.t(), lora_B.t(), res=None, scale=scale, dtype=dtype) else: lora_res = F.linear(F.linear(x, lora_A) * scale, lora_B) + if res is not None: + lora_res.add_(res) + if reshape: return lora_res.view(bs, seq_len, -1) return lora_res @@ -699,6 +719,7 @@ def backward(ctx, d_y): """ x, lora_A, lora_B = ctx.saved_tensors scale = ctx.scale + d_res = d_y if ctx.has_residual and ctx.needs_input_grad[6] else None reshape = x.dim() == 3 if reshape: @@ -729,6 +750,8 @@ def backward(ctx, d_y): d_x = d_x.view(bs, seq_len, d) gradients = (d_x, d_lora_A, d_lora_B, None, None) + if ctx.num_inputs == 7: + return gradients + (None, d_res) if ctx.num_inputs == 6: return gradients + (None,) return gradients diff --git a/tests/functional_tests/hf_peft/test_lora_memory_profile.py b/tests/functional_tests/hf_peft/test_lora_memory_profile.py index c977037bd4..c7bb13a2c4 100644 --- a/tests/functional_tests/hf_peft/test_lora_memory_profile.py +++ b/tests/functional_tests/hf_peft/test_lora_memory_profile.py @@ -103,6 +103,8 @@ def _profiler_memory_bytes(prof: torch.profiler.profile) -> int: def _profile_model( *, use_memory_efficient_lora: bool, use_triton: bool, device: torch.device, use_fsdp: bool ) -> ProfileResult: + torch.cuda.set_device(device) + torch.cuda.init() gc.collect() torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats(device) @@ -193,6 +195,7 @@ def _dump_memory_gain(result: dict[str, int], *, mode: str, dist_rank: int | Non def _assert_profile_memory_gain(result: dict[str, int]) -> None: assert result["before_profiler_memory_bytes"] > 0 assert result["after_profiler_memory_bytes"] > 0 + assert result["before_profiler_memory_bytes"] > result["after_profiler_memory_bytes"], result assert result["saved_bytes"] >= _MIN_SAVED_BYTES, result diff --git a/tests/unit_tests/_peft/test_lora.py b/tests/unit_tests/_peft/test_lora.py index 3b580b53b7..cab65f0a6f 100644 --- a/tests/unit_tests/_peft/test_lora.py +++ b/tests/unit_tests/_peft/test_lora.py @@ -186,6 +186,38 @@ def test_memory_efficient_lora_matches_legacy_forward_and_backward(input_shape): assert torch.allclose(lora_B.grad, lora_B_ref.grad) +@pytest.mark.parametrize("input_shape", [(5, 16), (2, 3, 16)]) +def test_memory_efficient_lora_with_residual_matches_legacy_forward_and_backward(input_shape): + """Custom autograd LoRA should fold residual addition without changing gradients.""" + torch.manual_seed(1234) + scale = 2.0 + lora_dim = 4 + out_features = 12 + output_shape = (*input_shape[:-1], out_features) + + x = torch.randn(*input_shape, requires_grad=True) + lora_A = torch.randn(lora_dim, input_shape[-1], requires_grad=True) + lora_B = torch.randn(out_features, lora_dim, requires_grad=True) + res = torch.randn(*output_shape, requires_grad=True) + x_ref = x.detach().clone().requires_grad_(True) + lora_A_ref = lora_A.detach().clone().requires_grad_(True) + lora_B_ref = lora_B.detach().clone().requires_grad_(True) + res_ref = res.detach().clone().requires_grad_(True) + + efficient = LoRATritonFunction.apply(x, lora_A, lora_B, scale, x.dtype, False, res) + legacy = res_ref + F.linear(F.linear(x_ref, lora_A_ref) * scale, lora_B_ref) + + grad = torch.randn_like(legacy) + efficient.backward(grad) + legacy.backward(grad) + + assert torch.allclose(efficient, legacy) + assert torch.allclose(x.grad, x_ref.grad) + assert torch.allclose(lora_A.grad, lora_A_ref.grad) + assert torch.allclose(lora_B.grad, lora_B_ref.grad) + assert torch.allclose(res.grad, res_ref.grad) + + def test_memory_efficient_lora_saves_less_forward_state(): """The custom autograd path should not save the intermediate x @ lora_A.T activation.""" torch.manual_seed(1234) @@ -241,6 +273,33 @@ def pack_hook(tensor): assert (8, 4) not in efficient_saved +def test_linear_lora_memory_efficient_matches_legacy_module_forward_and_backward(): + """LinearLoRA should preserve legacy module behavior when folding the residual add.""" + torch.manual_seed(1234) + base = nn.Linear(16, 12, bias=False) + x = torch.randn(8, 16, requires_grad=True) + x_ref = x.detach().clone().requires_grad_(True) + legacy = LinearLoRA(base, dim=4, alpha=8, use_memory_efficient_lora=False) + efficient = LinearLoRA(base, dim=4, alpha=8, use_memory_efficient_lora=True) + + with torch.no_grad(): + legacy.lora_A.weight.normal_() + legacy.lora_B.weight.normal_() + efficient.lora_A.weight.copy_(legacy.lora_A.weight) + efficient.lora_B.weight.copy_(legacy.lora_B.weight) + + efficient_out = efficient(x) + legacy_out = legacy(x_ref) + grad = torch.randn_like(legacy_out) + efficient_out.backward(grad) + legacy_out.backward(grad) + + assert torch.allclose(efficient_out, legacy_out) + assert torch.allclose(x.grad, x_ref.grad) + assert torch.allclose(efficient.lora_A.weight.grad, legacy.lora_A.weight.grad) + assert torch.allclose(efficient.lora_B.weight.grad, legacy.lora_B.weight.grad) + + def test_lora_layers_are_trainable(): """Ensures that LoRA layers are trainable while base weights remain frozen.""" base = nn.Linear(16, 16)