Enable training under spatial model parallelism (ModelTorchDistributed)#16
Conversation
use graph-aware all_reduce inside spatial mean
Pull request overview
Enables correct training behavior when using spatial model parallelism with
ModelTorchDistributed, addressing DDP buffer broadcast issues and ensuring weight
updates are consistent across spatial ranks.
Changes:
- Update `ModelTorchDistributed.wrap_module()` to disable DDP buffer broadcasting and register spatial gradient all-reduce hooks.
- Make `spatial_reduce_sum()` autograd-safe via a custom autograd `Function`.
- Add parallel regression tests (and associated `.pt` artifacts) for `SingleModuleStepper` using `NoiseConditionedSFNO`.
Reviewed changes
Copilot reviewed 2 out of 4 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| `fme/core/distributed/model_torch_distributed.py` | Disables DDP buffer broadcast, adds differentiable spatial all-reduce and spatial grad synchronization hooks. |
| `fme/ace/stepper/test_single_module_csfno.py` | Adds parallel regression tests for training/predict with spatial decomposition. |
| `fme/ace/stepper/testdata/csfno_stepper_train_on_batch_regression.pt` | Regression artifact for `train_on_batch` outputs. |
| `fme/ace/stepper/testdata/csfno_stepper_predict_regression.pt` | Regression artifact for `predict` outputs. |
Force-pushed 8983ef7 to d920bd9
Pull request overview
This PR updates the distributed backend to support training under spatial model parallelism by preventing unsafe DDP buffer broadcasts and by ensuring gradients are synchronized across spatial ranks, and adds parallel regression tests for SingleModuleStepper using NoiseConditionedSFNO.
Changes:
- `fme.core.distributed.ModelTorchDistributed`: wrap DDP with `broadcast_buffers=False` and introduce an autograd-friendly spatial all-reduce for loss reductions.
- `fme.core.distributed.ModelTorchDistributed`: register spatial gradient hooks to all-reduce parameter gradients across spatial ranks.
- Add parallel regression tests + golden `.pt` artifacts for CSFNO stepper train/predict.
Reviewed changes
Copilot reviewed 2 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `fme/core/distributed/model_torch_distributed.py` | Adds autograd-aware spatial reduction, disables DDP buffer broadcasts, and introduces spatial grad all-reduce hooks. |
| `fme/ace/stepper/test_single_module_csfno.py` | Adds parallel regression tests for CSFNO stepper behavior under spatial decomposition. |
| `fme/ace/stepper/testdata/csfno_stepper_train_on_batch_regression.pt` | Golden outputs for CSFNO train-on-batch regression. |
| `fme/ace/stepper/testdata/csfno_stepper_predict_regression.pt` | Golden outputs for CSFNO predict regression. |
Force-pushed d920bd9 to ebcf3a8
Pull request overview
This PR aims to make training work correctly under spatial model parallelism
(ModelTorchDistributed) by preventing DDP buffer broadcasts from mutating SHT
buffers and by ensuring gradients are synchronized across spatial ranks. It also
adds parallel regression tests for SingleModuleStepper using
NoiseConditionedSFNO to validate consistent results across spatial
decompositions.
Changes:
- Update `ModelTorchDistributed.wrap_module` to disable DDP buffer broadcasting and introduce an autograd-safe spatial all-reduce used in the loss reduction path.
- Add parallel regression tests for training/prediction with `NoiseConditionedSFNO`, plus new `.pt` regression artifacts.
Reviewed changes
Copilot reviewed 2 out of 5 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| `fme/core/distributed/model_torch_distributed.py` | Adjusts DDP wrapping and introduces spatial reduction / (intended) spatial gradient synchronization logic. |
| `fme/ace/stepper/test_single_module_csfno.py` | Adds parallel regression tests covering train/predict behavior under spatial decomposition. |
| `fme/ace/stepper/testdata/csfno_stepper_train_on_batch_with_optimization_regression.pt` | Regression artifact for the optimized training path. |
| `fme/ace/stepper/testdata/csfno_stepper_predict_regression.pt` | Regression artifact for the predict path. |
```python
# reduced /= (self._h_size * self._w_size)
return reduced
```
Pull request overview
Enables stable training under the ModelTorchDistributed spatial model-parallel backend by preventing problematic DDP buffer broadcasts, making spatial reductions autograd-safe, and validating behavior with parallel regression tests using NoiseConditionedSFNO.
Changes:
- Disable DDP buffer broadcasting for model-parallel wrapping to avoid in-place buffer mutations that break autograd version tracking.
- Add an autograd-aware spatial all-reduce for `spatial_reduce_sum` and register spatial gradient hooks to all-reduce parameter grads across spatial ranks.
- Add parallel regression tests for `SingleModuleStepper` + `NoiseConditionedSFNO`, along with new `.pt` regression artifacts.
Reviewed changes
Copilot reviewed 2 out of 5 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `fme/core/distributed/model_torch_distributed.py` | Makes spatial reductions autograd-safe, disables DDP buffer broadcast, and adds spatial gradient all-reduce hooks. |
| `fme/ace/stepper/test_single_module_csfno.py` | Adds parallel regression tests validating consistent behavior across spatial decompositions. |
| `fme/ace/stepper/testdata/csfno_stepper_train_on_batch_with_optimization_regression.pt` | New regression baseline for training with optimization enabled. |
| `fme/ace/stepper/testdata/csfno_stepper_predict_regression.pt` | New regression baseline for prediction output consistency. |
```python
        output = input.clone()
        torch.distributed.all_reduce(output, group=group)
        return output

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output: torch.Tensor):
        return grad_output.clone(), None
```
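The backward rule quoted above treats the all-reduce as an identity for gradients: every rank contributes its tensor to a sum whose result is replicated on all ranks, and the incoming gradient passes back to each local input unchanged. A torch-free sketch of those forward/backward semantics, using plain Python lists to stand in for per-rank tensors (function names here are illustrative, not from the PR):

```python
def allreduce_forward(local_values):
    """Simulate all_reduce(SUM): every rank receives the global sum."""
    total = sum(local_values)
    return [total for _ in local_values]

def allreduce_backward(grad_outputs):
    """Backward of the sum all-reduce as implemented in the PR:
    each rank's input gradient is its incoming gradient, unchanged."""
    return [g for g in grad_outputs]

# Two "ranks" each hold a partial spatial sum.
outputs = allreduce_forward([3.0, 5.0])
assert outputs == [8.0, 8.0]

# Upstream gradient of 1.0 on each rank passes through unchanged.
grads = allreduce_backward([1.0, 1.0])
assert grads == [1.0, 1.0]
```

The pass-through backward is sound here because every rank's downstream loss computation is replicated, so no extra gradient reduction is needed inside this op.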
```python
for p in module.parameters():
    if p.requires_grad:
        p.register_hook(_hook)
```
```python
def get_predict_output_tensor_dict(
    output: BatchData, next_state: PrognosticState
) -> dict[str, torch.Tensor]:
    return flatten_dict(
        {
            "output": output.data,
            "next_state": next_state.as_batch_data().data,
        }
    )
```
Force-pushed 085a71a to 582bca4
Force-pushed 582bca4 to 08647d8
Add a test that verifies consistency between NonDistribute and ModelTorchDistributed for loss and gradient calculation, using simple SHT/iSHT transforms.
Two fixes in `model_torch_distributed.py`:
1. Set `broadcast_buffers=False` in the DDP wrapping: DDP's default buffer broadcast modifies the SHT/iSHT Legendre polynomial buffers in place between forward calls, breaking autograd's version tracking.
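The version-tracking failure can be illustrated without torch: autograd records a saved tensor's version at forward time and refuses to use it in backward if it was mutated in place afterwards. A toy model of that check (every class and name below is illustrative; the real mechanism lives inside `torch.autograd`):

```python
class Buffer:
    """Toy stand-in for a tensor with autograd-style version tracking."""
    def __init__(self, data):
        self.data = list(data)
        self._version = 0

    def copy_(self, other):
        # In-place mutation (e.g. DDP broadcasting buffers) bumps the version.
        self.data = list(other)
        self._version += 1

class SavedForBackward:
    """Records the buffer's version at forward time, checks it at backward."""
    def __init__(self, buf):
        self.buf = buf
        self.saved_version = buf._version

    def backward(self):
        if self.buf._version != self.saved_version:
            raise RuntimeError(
                "a variable needed for gradient computation "
                "has been modified by an inplace operation"
            )
        return self.buf.data

legendre = Buffer([0.1, 0.2, 0.3])
ctx = SavedForBackward(legendre)   # "forward" saves the buffer
legendre.copy_([0.1, 0.2, 0.3])    # buffer broadcast mutates it in place
try:
    ctx.backward()                 # raises: version mismatch
except RuntimeError as e:
    print("backward failed:", e)
```

Even a broadcast that writes back identical values trips the check, which is why disabling the broadcast (rather than tolerating it) is the fix.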
2. Register spatial gradient hooks (`_register_spatial_grad_hooks`) that all-reduce parameter gradients across spatial ranks after the backward pass, so every rank applies the same weight update. This is the spatial analogue of what DDP does for data parallelism.
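The effect of those hooks can be sketched without torch: each spatial rank sees only part of the domain, so its local gradient is partial; summing gradients across ranks before the optimizer step makes every rank's update identical. A minimal simulation under that assumption (the rank list stands in for separate processes; all names are illustrative):

```python
def spatial_grad_allreduce(per_rank_grads):
    """Sum parameter gradients across spatial ranks (all_reduce SUM)."""
    n_params = len(per_rank_grads[0])
    total = [sum(g[i] for g in per_rank_grads) for i in range(n_params)]
    return [list(total) for _ in per_rank_grads]

lr = 0.1
weights = [[1.0, -2.0] for _ in range(2)]   # identical initial weights per rank
local_grads = [[0.5, 0.0], [0.1, 0.4]]      # partial grads from each spatial shard

# All-reduce the gradients, then apply the same SGD step on every rank.
synced = spatial_grad_allreduce(local_grads)
for w, g in zip(weights, synced):
    for i in range(len(w)):
        w[i] -= lr * g[i]

# Every rank applied the same update, so weights stay in sync.
assert weights[0] == weights[1]
```

Without the all-reduce, each rank would step with its partial gradient and the replicated weights would drift apart after the first update.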
Add `test_single_module_csfno.py` with parallel regression tests using `NoiseConditionedSFNO` (which supports spatial parallelism). The tests use the `AreaWeightedMSE` loss, which correctly reduces across spatial ranks via the existing `gridded_operations.area_weighted_mean` → `weighted_mean` → `spatial_reduce_sum` path.
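That reduction path boils down to computing a weighted mean from per-rank partial sums: each rank sums `weight * value` and `weight` over its shard, both sums are all-reduced, and the ratio equals the area-weighted mean over the full grid. A small sketch of that identity (plain Python; the function name is illustrative, not the fme API):

```python
def area_weighted_mean_distributed(shards):
    """shards: list of (values, weights) tuples, one per spatial rank."""
    # Per-rank partial sums, combined as an all-reduce(SUM) would combine them.
    num = sum(sum(v * w for v, w in zip(vals, wts)) for vals, wts in shards)
    den = sum(sum(wts) for _, wts in shards)
    return num / den

full_vals = [1.0, 2.0, 3.0, 4.0]
full_wts = [0.4, 0.3, 0.2, 0.1]
# Split the grid between two spatial ranks.
shards = [(full_vals[:2], full_wts[:2]), (full_vals[2:], full_wts[2:])]

distributed = area_weighted_mean_distributed(shards)
single = sum(v * w for v, w in zip(full_vals, full_wts)) / sum(full_wts)
assert abs(distributed - single) < 1e-12
```

Because both the numerator and denominator are plain sums, the result is independent of how the grid is partitioned across ranks, which is exactly what the parallel regression tests check.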