Skip to content

Conversation

@tomlifu
Copy link
Contributor

@tomlifu tomlifu commented Feb 6, 2026

Description

This PR is needed to support vision encoder CUDA Graph.

Related MLM PR: NVIDIA/Megatron-LM#3293, NVIDIA/Megatron-LM#3294

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Lifu Zhang and others added 2 commits February 6, 2026 10:33
Signed-off-by: Lifu Zhang <lifuz@login-lyris02.lyris.clusters.nvidia.com>
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Feb 6, 2026

Greptile Overview

Greptile Summary

This PR adds None-safety checks throughout the CUDA Graph capture code to support vision encoder modules. The changes prevent AttributeError when module outputs contain None values by checking o is not None before accessing the requires_grad attribute.

  • Added None checks in 7 locations where outputs are filtered for gradient computation
  • Modified warmup phase to handle None outputs in torch.autograd.backward calls
  • Updated backward graph capture to safely create gradient tensors only for non-None outputs
  • Fixed forward graph replay to properly detach non-None outputs while preserving None values
  • Changes are defensive and backward-compatible - no impact when outputs don't contain None

The fix is minimal and surgical, adding safety checks without changing the underlying logic or control flow.

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • The changes are defensive None-safety checks that prevent AttributeError exceptions. The pattern o is not None and o.requires_grad is already used elsewhere in the codebase. All modifications follow the same pattern consistently across 7 locations. The fix is backward-compatible and doesn't change behavior when outputs don't contain None values. This unblocks CUDA Graph support for vision encoders without breaking existing functionality.
  • No files require special attention

Important Files Changed

Filename Overview
transformer_engine/pytorch/graph.py Added None checks before accessing requires_grad attribute on outputs to prevent AttributeError when outputs contain None values, enabling CUDA Graph support for vision encoders

Sequence Diagram

sequenceDiagram
    participant User
    participant make_graphed_callables
    participant _make_graphed_callables
    participant Forward Graph
    participant Backward Graph
    participant Module

    User->>make_graphed_callables: Call with modules & sample_args
    make_graphed_callables->>_make_graphed_callables: Pass callables & args
    
    Note over _make_graphed_callables: Warmup Phase
    _make_graphed_callables->>Module: Run warmup iterations
    Module-->>_make_graphed_callables: Return outputs (may contain None)
    
    Note over _make_graphed_callables: Graph Capture Phase
    _make_graphed_callables->>Forward Graph: Capture forward pass
    Module-->>Forward Graph: Store static outputs
    
    Note over _make_graphed_callables: Filter outputs with None check
    _make_graphed_callables->>_make_graphed_callables: Check "o is not None and o.requires_grad"
    
    _make_graphed_callables->>Backward Graph: Capture backward pass
    _make_graphed_callables->>Backward Graph: Create grad tensors for valid outputs
    
    Note over _make_graphed_callables: Graph Replay Phase
    _make_graphed_callables->>User: Return graphed callables
    User->>Forward Graph: Call graphed module
    Forward Graph->>Forward Graph: Replay captured graph
    Forward Graph-->>User: Return detached outputs (None-safe)
    User->>Backward Graph: Trigger backward
    Backward Graph->>Backward Graph: Replay captured backward
Loading

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, no comments

Edit Code Review Agent Settings | Greptile

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant