
Enable training under spatial model parallelism (ModelTorchDistributed) #16

Open
mahf708 wants to merge 5 commits into main from spatial-parallel-training

Conversation

@mahf708 (Collaborator) commented Mar 19, 2026

Two fixes in model_torch_distributed.py:

  1. Set broadcast_buffers=False in DDP wrapping — DDP's default buffer broadcast modifies SHT/iSHT Legendre polynomial buffers in-place between forward calls, breaking autograd's version tracking.

  2. Register spatial gradient hooks (_register_spatial_grad_hooks) that all-reduce parameter gradients across spatial ranks after backward, so every rank applies the same weight update. This is the spatial analogue of what DDP does for data parallelism.
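The effect of the spatial gradient hooks can be emulated without a process group: after backward, each spatial rank holds the gradient from its local shard, and an all-reduce SUM leaves every rank with an identical gradient. A minimal plain-Python stand-in (the real hooks call `torch.distributed.all_reduce`; shard contents here are illustrative):

```python
def emulate_spatial_grad_all_reduce(per_rank_grads):
    """Stand-in for the all-reduce SUM the spatial grad hooks perform:
    afterwards, every rank holds the elementwise sum of all ranks' grads."""
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    return [list(summed) for _ in per_rank_grads]

# Two spatial ranks, each with a partial gradient for the same parameter:
grads = [[0.5, 1.0], [1.5, -1.0]]
reduced = emulate_spatial_grad_all_reduce(grads)
# Both ranks now hold [2.0, 0.0], so both apply the same weight update.
```

With identical gradients on every spatial rank, identical optimizer state then produces identical weights, mirroring what DDP guarantees across data-parallel ranks.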

Add test_single_module_csfno.py with parallel regression tests using NoiseConditionedSFNO (which supports spatial parallelism). Tests use AreaWeightedMSE loss which correctly reduces across spatial ranks via the existing gridded_operations.area_weighted_mean → weighted_mean → spatial_reduce_sum path.
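That reduction path can be sketched single-process: each spatial rank contributes a partial weighted sum and a partial weight sum, both are summed across ranks (the role spatial_reduce_sum plays), and the quotient is the global area-weighted mean. The sketch below emulates the all-reduce with a plain sum; shard contents are illustrative, not the test's actual data:

```python
def global_area_weighted_mean(rank_values, rank_weights):
    # Per-rank partial sums; summing across ranks stands in for the
    # all-reduce that spatial_reduce_sum performs.
    num = sum(sum(w * x for x, w in zip(xs, ws))
              for xs, ws in zip(rank_values, rank_weights))
    den = sum(sum(ws) for ws in rank_weights)
    return num / den

# Two spatial shards of a field with uniform area weights:
mean = global_area_weighted_mean([[1.0, 3.0], [5.0, 7.0]],
                                 [[1.0, 1.0], [1.0, 1.0]])
# Identical to the mean over the unsharded field.
```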


Commit: use graph-aware all_reduce inside spatial mean

Copilot AI left a comment


Pull request overview

Enables correct training behavior when using spatial model parallelism with
ModelTorchDistributed, addressing DDP buffer broadcast issues and ensuring weight
updates are consistent across spatial ranks.

Changes:

  • Update ModelTorchDistributed.wrap_module() to disable DDP buffer broadcasting
    and register spatial gradient all-reduce hooks.
  • Make spatial_reduce_sum() autograd-safe via a custom autograd Function.
  • Add parallel regression tests (and associated .pt artifacts) for
    SingleModuleStepper using NoiseConditionedSFNO.

Reviewed changes

Copilot reviewed 2 out of 4 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| fme/core/distributed/model_torch_distributed.py | Disables DDP buffer broadcast, adds differentiable spatial all-reduce and spatial grad synchronization hooks. |
| fme/ace/stepper/test_single_module_csfno.py | Adds parallel regression tests for training/predict with spatial decomposition. |
| fme/ace/stepper/testdata/csfno_stepper_train_on_batch_regression.pt | Regression artifact for train_on_batch outputs. |
| fme/ace/stepper/testdata/csfno_stepper_predict_regression.pt | Regression artifact for predict outputs. |


@mahf708 mahf708 force-pushed the spatial-parallel-training branch from 8983ef7 to d920bd9 Compare March 19, 2026 14:42
@mahf708 mahf708 requested a review from Copilot March 19, 2026 14:43

Copilot AI left a comment


Pull request overview

This PR updates the distributed backend to support training under spatial model parallelism by preventing unsafe DDP buffer broadcasts and by ensuring gradients are synchronized across spatial ranks, and adds parallel regression tests for SingleModuleStepper using NoiseConditionedSFNO.

Changes:

  • fme.core.distributed.ModelTorchDistributed: wrap DDP with broadcast_buffers=False and introduce an autograd-friendly spatial all-reduce for loss reductions.
  • fme.core.distributed.ModelTorchDistributed: register spatial gradient hooks to all-reduce parameter gradients across spatial ranks.
  • Add parallel regression tests + golden .pt artifacts for CSFNO stepper train/predict.

Reviewed changes

Copilot reviewed 2 out of 4 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| fme/core/distributed/model_torch_distributed.py | Adds autograd-aware spatial reduction, disables DDP buffer broadcasts, and introduces spatial grad all-reduce hooks. |
| fme/ace/stepper/test_single_module_csfno.py | Adds parallel regression tests for CSFNO stepper behavior under spatial decomposition. |
| fme/ace/stepper/testdata/csfno_stepper_train_on_batch_regression.pt | Golden outputs for CSFNO train-on-batch regression. |
| fme/ace/stepper/testdata/csfno_stepper_predict_regression.pt | Golden outputs for CSFNO predict regression. |


@mahf708 mahf708 force-pushed the spatial-parallel-training branch from d920bd9 to ebcf3a8 Compare March 19, 2026 15:14
@mahf708 mahf708 requested a review from Copilot March 19, 2026 15:14

Copilot AI left a comment


Pull request overview

This PR aims to make training work correctly under spatial model parallelism
(ModelTorchDistributed) by preventing DDP buffer broadcasts from mutating SHT
buffers and by ensuring gradients are synchronized across spatial ranks. It also
adds parallel regression tests for SingleModuleStepper using
NoiseConditionedSFNO to validate consistent results across spatial
decompositions.

Changes:

  • Update ModelTorchDistributed.wrap_module to disable DDP buffer broadcasting
    and introduce an autograd-safe spatial all-reduce used in the loss reduction
    path.
  • Add parallel regression tests for training/prediction with
    NoiseConditionedSFNO, plus new .pt regression artifacts.

Reviewed changes

Copilot reviewed 2 out of 5 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| fme/core/distributed/model_torch_distributed.py | Adjusts DDP wrapping and introduces spatial reduction / (intended) spatial gradient synchronization logic. |
| fme/ace/stepper/test_single_module_csfno.py | Adds parallel regression tests covering train/predict behavior under spatial decomposition. |
| fme/ace/stepper/testdata/csfno_stepper_train_on_batch_with_optimization_regression.pt | Regression artifact for the optimized training path. |
| fme/ace/stepper/testdata/csfno_stepper_predict_regression.pt | Regression artifact for the predict path. |


Quoted snippet from the review discussion (tail of a reduction helper, with the per-rank normalization left commented out):

```python
    # reduced /= (self._h_size * self._w_size)

    return reduced
```

@mahf708 mahf708 requested a review from Copilot March 19, 2026 16:12

Copilot AI left a comment


Pull request overview

Enables stable training under the ModelTorchDistributed spatial model-parallel backend by preventing problematic DDP buffer broadcasts, making spatial reductions autograd-safe, and validating behavior with parallel regression tests using NoiseConditionedSFNO.

Changes:

  • Disable DDP buffer broadcasting for model-parallel wrapping to avoid in-place buffer mutations that break autograd version tracking.
  • Add an autograd-aware spatial all-reduce for spatial_reduce_sum and register spatial gradient hooks to all-reduce parameter grads across spatial ranks.
  • Add parallel regression tests for SingleModuleStepper + NoiseConditionedSFNO, along with new .pt regression artifacts.

Reviewed changes

Copilot reviewed 2 out of 5 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| fme/core/distributed/model_torch_distributed.py | Makes spatial reductions autograd-safe, disables DDP buffer broadcast, and adds spatial gradient all-reduce hooks. |
| fme/ace/stepper/test_single_module_csfno.py | Adds parallel regression tests validating consistent behavior across spatial decompositions. |
| fme/ace/stepper/testdata/csfno_stepper_train_on_batch_with_optimization_regression.pt | New regression baseline for training with optimization enabled. |
| fme/ace/stepper/testdata/csfno_stepper_predict_regression.pt | New regression baseline for prediction output consistency. |


Comment on lines +65 to +72:

```python
        output = input.clone()
        torch.distributed.all_reduce(output, group=group)
        return output

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output: torch.Tensor):
        return grad_output.clone(), None
```
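Why a pass-through backward is correct here: within one rank's autograd graph, the contributions the all-reduce adds from other ranks are constants, so the reduced output depends on the local input with coefficient 1 and the upstream gradient flows through unchanged. A finite-difference check of that claim (plain Python, other ranks' terms held fixed):

```python
def local_all_reduce_sum(x_local, x_others):
    # Other ranks' contributions are constants in the local graph,
    # so d(output)/d(x_local) = 1: the backward is an identity.
    return x_local + sum(x_others)

eps = 1e-6
x, others = 2.0, [3.0, 5.0]
grad = (local_all_reduce_sum(x + eps, others)
        - local_all_reduce_sum(x, others)) / eps
# grad is ~1.0 regardless of the other ranks' values
```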
Comment on lines +395 to +398:

```python
# (_hook is defined earlier in the diff; per the PR description it
# all-reduces the parameter gradient across spatial ranks)
for p in module.parameters():
    if p.requires_grad:
        p.register_hook(_hook)
```
Comment on lines +165 to +175:

```python
def get_predict_output_tensor_dict(
    output: BatchData, next_state: PrognosticState
) -> dict[str, torch.Tensor]:
    return flatten_dict(
        {
            "output": output.data,
            "next_state": next_state.as_batch_data().data,
        }
    )
```
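The flatten_dict call above merges the nested mapping into flat keys so tensors can be stored in and compared against the .pt regression artifacts. The helper below is a hypothetical sketch of that behavior; the separator and exact semantics are assumptions, not the repository's implementation:

```python
def flatten_dict(nested, sep="/"):
    # Hypothetical sketch: {"output": {"t2m": x}} -> {"output/t2m": x}.
    flat = {}
    for outer, inner in nested.items():
        for key, value in inner.items():
            flat[f"{outer}{sep}{key}"] = value
    return flat

flat = flatten_dict({"output": {"t2m": 1.0}, "next_state": {"t2m": 2.0}})
# -> {"output/t2m": 1.0, "next_state/t2m": 2.0}
```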


@mahf708 mahf708 force-pushed the spatial-parallel-training branch 2 times, most recently from 085a71a to 582bca4 Compare March 19, 2026 20:28
@mahf708 mahf708 force-pushed the spatial-parallel-training branch from 582bca4 to 08647d8 Compare March 19, 2026 20:30
Commits:

  • backwards step

  • Add test that verifies consistency between the non-distributed backend and ModelTorchDistributed for loss and gradient calculation using simple SHT/iSHT transforms
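The consistency check that commit describes can be emulated in a single process: compute an MSE over the full grid, then over spatial shards with summed partial reductions, and confirm they agree. A plain-Python sketch (shard contents illustrative):

```python
def mse(pred, target):
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

def sharded_mse(pred_shards, target_shards):
    # Each "rank" contributes a partial sum of squared errors; summing
    # across shards emulates the spatial all-reduce before dividing by
    # the global point count.
    sq = sum(sum((p - t) ** 2 for p, t in zip(ps, ts))
             for ps, ts in zip(pred_shards, target_shards))
    n = sum(len(ps) for ps in pred_shards)
    return sq / n

pred, target = [1.0, 2.0, 3.0, 4.0], [1.5, 2.0, 2.0, 4.0]
full = mse(pred, target)
split = sharded_mse([pred[:2], pred[2:]], [target[:2], target[2:]])
# full == split: the sharded loss matches the unsharded one
```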
