
Fix mixed precision call in group norm sharded. #1380

Closed
coreyjadams wants to merge 4 commits into NVIDIA:main from coreyjadams:hotfix-sharded-group-norm

Conversation

@coreyjadams
Collaborator

Also fix a math error in how variances are combined across GPUs.
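The fix replaces averaging of per-GPU variances with combining raw moments, since Var(X) = E[X²] − E[X]². A minimal single-process sketch (the helper name is hypothetical, not the PhysicsNeMo API):

```python
import torch

def combine_shard_variance(shards):
    # Hypothetical helper: combine per-shard statistics into the global
    # (biased) variance via Var(X) = E[X^2] - E[X]^2. In the sharded
    # setting, (count, sum, sum_sq) would travel in one fused all-reduce.
    count = sum(s.numel() for s in shards)
    total = sum(s.sum() for s in shards)
    total_sq = sum((s * s).sum() for s in shards)
    mean = total / count
    return total_sq / count - mean * mean

x = torch.randn(8, 16)
shards = x.chunk(4, dim=0)  # stand-in for the per-GPU partitions

# Matches the variance of the full, unsharded tensor; a naive mean of the
# per-shard variances would miss the spread of the shard means.
assert torch.allclose(combine_shard_variance(shards), x.var(unbiased=False), atol=1e-5)
```

By the law of total variance, averaging local variances drops the variance-of-the-shard-means term, which is exactly the error being corrected here.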

PhysicsNeMo Pull Request

Description

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

Comment on lines +124 to +128
if weight is not None:
weight = weight.to(input.dtype)
if bias is not None:
bias = bias.to(input.dtype)

Collaborator Author
This fixes a mixed precision crash.

This PR also adapts to upstream torch changes in the Python dispatch behavior and DTensor:

- adding more layers to handle select
- add more reliable handling of casting DTensor to ShardTensor.  In particular,
  the focus is on making sure we maintain proper autograd graphs.
- switch to a first-principles implementation of group norm.  It's more stable
  and simpler, and while it might be a little slower, the upcoming torch.compile
  work can address that.
- add a dedicated view handler at functional and dispatch level.  It's necessary
  at this point to wrap our own view implementation due to the differences with DTensor.
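The autograd-preserving promotion mentioned above can be sketched with a custom `torch.autograd.Function`: the forward rewraps the local tensor and the backward passes the gradient straight through, so the cast is transparent to autograd. This is a simplified stand-in, not the actual ShardTensor promotion code:

```python
import torch

class _Promote(torch.autograd.Function):
    """Hypothetical sketch of an autograd-preserving promotion."""

    @staticmethod
    def forward(ctx, local):
        # The real code would wrap `local` into a ShardTensor; a clone
        # stands in for the rewrap here.
        return local.clone()

    @staticmethod
    def backward(ctx, grad_out):
        # Identity gradient: the promotion does not alter values, so
        # gradients flow through unchanged.
        return grad_out

x = torch.randn(4, requires_grad=True)
y = _Promote.apply(x).sum()
y.backward()
assert torch.allclose(x.grad, torch.ones(4))
```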
@coreyjadams coreyjadams marked this pull request as ready for review February 11, 2026 21:28
@greptile-apps
Contributor

greptile-apps Bot commented Feb 11, 2026

Greptile Overview

Greptile Summary

This PR fixes two critical bugs in the sharded group normalization implementation and adds comprehensive AMP testing.

Key Changes:

  • Fixed variance calculation bug in normalization_patches.py: The old code incorrectly computed variance by inverting local rstd values and averaging them across GPUs (global_var = (1.0 / (rstd**2)) - eps). This is mathematically incorrect because you cannot average variances computed from different data partitions. The new implementation correctly computes global variance using Var(X) = E[X²] - E[X]² by reducing sums and sum-of-squares across GPUs.

  • Fixed mixed precision handling: Added explicit dtype casting for weight and bias parameters to match input dtype (lines 114-118), and added dtype conversion for gradients in backward pass (lines 210-212). This prevents dtype mismatches when running with AMP.

  • Complete rewrite of group normalization: Replaced reliance on aten.native_group_norm / aten.native_group_norm_backward with a from-scratch implementation that properly handles distributed statistics, reducing all-reduce calls from 3 in forward (mean, variance, separate reductions) to 1 (fused sum and sum-of-squares).

  • Added comprehensive AMP testing: Extended test suite with amp parameter to validate mixed precision behavior works correctly with the fixes.
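The dtype casting in the second bullet matters because under autocast the activations may run in half precision while module parameters stay float32. A minimal sketch of the pattern (the function name is hypothetical, not the PR's actual code):

```python
import torch

def group_norm_affine(x, weight, bias):
    # Cast parameters to the activation dtype, mirroring the PR's fix:
    # under AMP, x may be float16 while weight/bias remain float32.
    if weight is not None:
        weight = weight.to(x.dtype)
    if bias is not None:
        bias = bias.to(x.dtype)
    return x * weight + bias

x = torch.randn(4, 8).to(torch.float16)  # half-precision activations
w = torch.ones(8)                        # float32 parameters
b = torch.zeros(8)

out = group_norm_affine(x, w, b)
assert out.dtype == torch.float16  # result stays in the activation dtype
```

Without the cast, mixing dtypes either promotes the result back to float32 or, in fused kernels that require matching dtypes, raises an error.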

Additional Changes:

  • Added DTensor conversion utilities and autograd-preserving promotion functions in shard_tensor.py
  • New view_ops.py module with proper view/reshape operations for ShardTensor
  • Updated version check to handle torch 2.10.0a (alpha releases)

Important Files Changed

Filename — Overview
physicsnemo/domain_parallel/shard_utils/normalization_patches.py — Complete rewrite implementing group normalization from first principles; fixes critical variance calculation bug and adds proper mixed precision dtype handling
test/domain_parallel/ops/test_normalization.py — Added AMP testing parameters to validate mixed precision behavior in layer norm and group norm tests
test/domain_parallel/ops/utils.py — Added AMP support to numerical_shard_tensor_check with autocast context wrapping forward and backward passes
physicsnemo/domain_parallel/shard_tensor.py — Added DTensor conversion helpers and autograd-preserving promotion functions to support improved ShardTensor/DTensor interoperability

Contributor

@greptile-apps greptile-apps Bot left a comment


12 files reviewed, no comments



-if check_version_spec("torch", "2.10.0"):
+if check_version_spec("torch", "2.10.0a"):
Collaborator Author

This is to get pre-release versions. Fixes #1394
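`check_version_spec` is PhysicsNeMo's internal helper, but the underlying PEP 440 behavior can be illustrated with the `packaging` library, assuming the helper follows standard specifier semantics:

```python
from packaging.specifiers import SpecifierSet
from packaging.version import Version

alpha = Version("2.10.0a0")  # e.g. a torch alpha/nightly build

# A pre-release sorts *below* the final release, so ">=2.10.0" rejects it:
assert alpha < Version("2.10.0")
assert alpha not in SpecifierSet(">=2.10.0")

# Lowering the floor to the alpha admits pre-release builds (a specifier
# containing a pre-release version implicitly enables pre-release matching):
assert alpha in SpecifierSet(">=2.10.0a")
```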

@coreyjadams
Collaborator Author

This is getting broken up into smaller PRs for easier review.

@coreyjadams coreyjadams deleted the hotfix-sharded-group-norm branch March 5, 2026 15:07