
Commit 8983ef7

mahf708, claude, and peterdschwartz committed
Enable training under spatial model parallelism (ModelTorchDistributed)
Two fixes in model_torch_distributed.py:

1. Set broadcast_buffers=False in DDP wrapping: DDP's default buffer broadcast
   modifies SHT/iSHT Legendre polynomial buffers in-place between forward
   calls, breaking autograd's version tracking.
2. Register spatial gradient hooks (_register_spatial_grad_hooks) that
   all-reduce parameter gradients across spatial ranks after backward, so
   every rank applies the same weight update. This is the spatial analogue of
   what DDP does for data parallelism.

Add test_single_module_csfno.py with parallel regression tests using
NoiseConditionedSFNO (which supports spatial parallelism). Tests use
AreaWeightedMSE loss which correctly reduces across spatial ranks via the
existing gridded_operations.area_weighted_mean → weighted_mean →
spatial_reduce_sum path.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: peterdschwartz <peterdschwartz83@gmail.com>
1 parent fca6446 commit 8983ef7
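As a standalone illustration of the second fix, the sketch below emulates two spatial ranks by chunking a grid: each half-grid gradient is only a partial contribution, and a cross-rank reduction is needed before every rank can apply the same update. This is a minimal sketch of the principle, not the repository's API; the module, data, and chunking here are made up for the example.

import torch

# Emulate two "spatial ranks" by chunking the grid points of a single tensor.
torch.manual_seed(0)
layer = torch.nn.Linear(4, 1, bias=False)
x = torch.randn(8, 4)  # full grid: 8 points, 4 features

# Gradient of the mean-squared output over the full grid.
layer.zero_grad()
layer(x).pow(2).mean().backward()
full_grad = layer.weight.grad.clone()

# Each "rank" sees only half the grid and gets a different partial gradient.
partial_grads = []
for shard in x.chunk(2, dim=0):
    layer.zero_grad()
    layer(shard).pow(2).mean().backward()
    partial_grads.append(layer.weight.grad.clone())

# Reducing (here: averaging) across ranks recovers the full-grid gradient,
# which is what the spatial gradient hooks arrange via all_reduce.
avg_grad = torch.stack(partial_grads).mean(dim=0)
assert not torch.allclose(partial_grads[0], partial_grads[1])
assert torch.allclose(avg_grad, full_grad, atol=1e-6)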

4 files changed: 295 additions & 5 deletions

test_single_module_csfno.py (new file)

Lines changed: 223 additions & 0 deletions
"""
Parallel regression tests for the SingleModuleStepper with NoiseConditionedSFNO.

These tests verify that the forward pass and loss computation produce identical
results regardless of spatial decomposition (nproc=1 vs model-parallel).
"""

import dataclasses
import datetime
import os
from collections.abc import Mapping

import numpy as np
import pytest
import torch
import xarray as xr

from fme.ace.data_loading.batch_data import BatchData, PrognosticState
from fme.ace.registry.stochastic_sfno import NoiseConditionedSFNOBuilder
from fme.ace.stepper.single_module import (
    StepperConfig,
    TrainOutput,
    TrainStepper,
    TrainStepperConfig,
)
from fme.core.coordinates import HybridSigmaPressureCoordinate, LatLonCoordinates
from fme.core.dataset_info import DatasetInfo
from fme.core.device import get_device
from fme.core.distributed.distributed import Distributed
from fme.core.loss import StepLossConfig
from fme.core.normalizer import NetworkAndLossNormalizationConfig, NormalizationConfig
from fme.core.optimization import NullOptimization
from fme.core.registry.module import ModuleSelector
from fme.core.step import SingleModuleStepConfig, StepSelector
from fme.core.testing.regression import validate_tensor_dict

DIR = os.path.abspath(os.path.dirname(__file__))
TIMESTEP = datetime.timedelta(hours=6)


def get_dataset_info(
    img_shape=(5, 5),
) -> DatasetInfo:
    horizontal_coordinate = LatLonCoordinates(
        lat=torch.zeros(img_shape[-2]),
        lon=torch.zeros(img_shape[-1]),
    )
    vertical_coordinate = HybridSigmaPressureCoordinate(
        ak=torch.arange(7), bk=torch.arange(7)
    )
    return DatasetInfo(
        horizontal_coordinates=horizontal_coordinate,
        vertical_coordinate=vertical_coordinate,
        timestep=TIMESTEP,
    )


def _get_train_stepper(
    stepper_config: StepperConfig,
    dataset_info: DatasetInfo,
    **train_config_kwargs,
) -> TrainStepper:
    train_config = TrainStepperConfig(**train_config_kwargs)
    return train_config.get_train_stepper(stepper_config, dataset_info)


def get_regression_stepper_and_data() -> tuple[
    TrainStepper, BatchData, tuple[int, int]
]:
    in_names = ["a", "b"]
    out_names = ["b", "c"]
    n_forward_steps = 2
    n_samples = 3
    img_shape = (9, 18)
    device = get_device()

    all_names = list(set(in_names + out_names))

    loss = StepLossConfig(type="AreaWeightedMSE")

    config = StepperConfig(
        step=StepSelector(
            type="single_module",
            config=dataclasses.asdict(
                SingleModuleStepConfig(
                    builder=ModuleSelector(
                        type="NoiseConditionedSFNO",
                        config=dataclasses.asdict(
                            NoiseConditionedSFNOBuilder(
                                embed_dim=16,
                                num_layers=2,
                                noise_embed_dim=16,
                                noise_type="isotropic",
                            )
                        ),
                    ),
                    in_names=in_names,
                    out_names=out_names,
                    normalization=NetworkAndLossNormalizationConfig(
                        network=NormalizationConfig(
                            means={n: 0.1 for n in all_names},
                            stds={n: 1.1 for n in all_names},
                        ),
                    ),
                    ocean=None,
                )
            ),
        ),
    )

    dataset_info = get_dataset_info(img_shape=img_shape)
    train_stepper = _get_train_stepper(config, dataset_info, loss=loss)
    data = BatchData.new_on_device(
        data={
            "a": torch.randn(n_samples, n_forward_steps + 1, *img_shape).to(device),
            "b": torch.randn(n_samples, n_forward_steps + 1, *img_shape).to(device),
            "c": torch.randn(n_samples, n_forward_steps + 1, *img_shape).to(device),
        },
        time=xr.DataArray(
            np.zeros((n_samples, n_forward_steps + 1)),
            dims=["sample", "time"],
        ),
        labels=None,
        epoch=0,
        horizontal_dims=["lat", "lon"],
    )
    data = data.scatter_spatial(img_shape)
    return train_stepper, data, img_shape


def flatten_dict(
    d: Mapping[str, Mapping[str, torch.Tensor]],
) -> dict[str, torch.Tensor]:
    return_dict = {}
    for k, v in d.items():
        for k2, v2 in v.items():
            return_dict[f"{k}.{k2}"] = v2
    return return_dict


def _get_train_output_tensor_dict(data: TrainOutput) -> dict[str, torch.Tensor]:
    return_dict = {}
    for k, v in data.metrics.items():
        return_dict[f"metrics.{k}"] = v
    for k, v in data.gen_data.items():
        return_dict[f"gen_data.{k}"] = v
    for k, v in data.target_data.items():
        assert v.shape[1] == 1
        return_dict[f"target_data.{k}"] = v
    return return_dict


def get_train_outputs_tensor_dict(
    step_1: TrainOutput, step_2: TrainOutput
) -> dict[str, torch.Tensor]:
    return flatten_dict(
        {
            "step_1": _get_train_output_tensor_dict(step_1),
            "step_2": _get_train_output_tensor_dict(step_2),
        }
    )


def get_predict_output_tensor_dict(
    output: BatchData, next_state: PrognosticState
) -> dict[str, torch.Tensor]:
    return flatten_dict(
        {
            "output": output.data,
            "next_state": next_state.as_batch_data().data,
        }
    )


@pytest.mark.parallel
def test_stepper_train_on_batch_regression():
    torch.manual_seed(0)
    train_stepper, data, img_shape = get_regression_stepper_and_data()
    optimization = NullOptimization()
    result1 = train_stepper.train_on_batch(data, optimization)
    result2 = train_stepper.train_on_batch(data, optimization)
    dist = Distributed.get_instance()
    for result in [result1, result2]:
        result.gen_data = dist.gather_spatial(dict(result.gen_data), img_shape)
        result.target_data = dist.gather_spatial(dict(result.target_data), img_shape)
    output_dict = get_train_outputs_tensor_dict(result1, result2)
    validate_tensor_dict(
        output_dict,
        os.path.join(
            DIR,
            "testdata/csfno_stepper_train_on_batch_regression.pt",
        ),
        atol=1e-4,
        rtol=1e-4,
    )


@pytest.mark.parallel
def test_stepper_predict_regression():
    torch.manual_seed(0)
    train_stepper, data, img_shape = get_regression_stepper_and_data()
    stepper = train_stepper._stepper
    initial_condition = data.get_start(
        prognostic_names=["b"],
        n_ic_timesteps=1,
    )
    output, next_state = stepper.predict(
        initial_condition, data, compute_derived_variables=True
    )
    dist = Distributed.get_instance()
    output_data = dist.gather_spatial(dict(output.data), img_shape)
    next_state_data = dist.gather_spatial(
        dict(next_state.as_batch_data().data), img_shape
    )
    output_dict = flatten_dict(
        {"output": output_data, "next_state": next_state_data}
    )
    validate_tensor_dict(
        output_dict,
        os.path.join(DIR, "testdata/csfno_stepper_predict_regression.pt"),
        atol=1e-4,
        rtol=1e-4,
    )
Two binary regression-baseline files under testdata/ (referenced by the tests above); binary content not shown (11.9 KB).
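For context, the tests above rely on scatter_spatial and gather_spatial to split each field across spatial ranks and to reassemble the full grid before comparison with the stored baselines. The sketch below is a single-process analogue of that round trip; the chunking axis and layout are illustrative assumptions, not the repository's exact decomposition.

import torch

# Split a (sample, time, lat, lon) field into two "spatial rank" shards and
# reassemble it; gathering must invert scattering exactly so parallel results
# can be compared elementwise with the nproc=1 baseline.
field = torch.randn(3, 3, 9, 18)
shards = list(field.chunk(2, dim=-1))    # emulate a 2-way split along lon
reassembled = torch.cat(shards, dim=-1)  # emulate gather_spatial
assert torch.equal(reassembled, field)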

fme/core/distributed/model_torch_distributed.py

Lines changed: 72 additions & 5 deletions
@@ -24,6 +24,7 @@
 import torch
 import torch.distributed
 import torch.nn as nn
+from torch.amp import custom_bwd, custom_fwd
 import torch_harmonics.distributed as thd
 from torch.nn import SyncBatchNorm
 from torch.nn.parallel import DistributedDataParallel
@@ -42,6 +43,35 @@
 T = TypeVar("T")


+class _AutogradAllReduce(torch.autograd.Function):
+    """Autograd-aware all-reduce (sum) for spatial parallelism.
+    Forward: all-reduce (sum) the input across the given process group.
+    Backward: identity — gradients pass through without communication.
+    This makes ``spatial_reduce_sum`` differentiable so that gradients
+    flow correctly through the loss computation path::
+        AreaWeightedMSELoss → area_weighted_mean → weighted_mean
+            → spatial_reduce_sum (uses this function)
+    Without this, the raw ``torch.distributed.all_reduce`` would break
+    the autograd graph because it is an in-place, non-differentiable op.
+    """
+
+    @staticmethod
+    @custom_fwd(device_type="cuda")
+    def forward(
+        ctx,
+        input: torch.Tensor,
+        group: torch.distributed.ProcessGroup,
+    ) -> torch.Tensor:
+        output = input.clone()
+        torch.distributed.all_reduce(output, group=group)
+        return output
+
+    @staticmethod
+    @custom_bwd(device_type="cuda")
+    def backward(ctx, grad_output: torch.Tensor):
+        return grad_output.clone(), None
+
+
 class ModelTorchDistributed(DistributedBackend):
     """Distributed backend with spatial model parallelism.
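A single-process check of the gradient convention used by _AutogradAllReduce: when every rank's output is the same global sum, the true gradient with respect to each local input equals the incoming gradient, so an identity backward with no extra communication is correct. The list of tensors below stands in for the process group; this is a sketch of the math, not the distributed code path.

import torch

# Four "ranks" each hold a local tensor; the all-reduce gives every rank the
# same sum s = x_0 + x_1 + x_2 + x_3, so dL/dx_i = dL/ds for every i.
xs = [torch.randn(3, requires_grad=True) for _ in range(4)]
s = torch.stack(xs).sum(dim=0)  # stand-in for the all-reduced value
loss = (s ** 2).sum()
loss.backward()
expected = 2 * s.detach()       # dL/ds for this loss
for x in xs:
    # Identical on every rank: passing grad_output straight through is exact.
    assert torch.allclose(x.grad, expected)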
@@ -307,31 +337,68 @@ def _device_ids(self) -> list[int] | None:
     def wrap_module(self, module: torch.nn.Module) -> torch.nn.Module:
         """Wrap with DDP over the **data** process group.

-        For now, we assume spatial communication is expected to be handled
-        inside the model layers themselves. If we need to change course, we
-        can revisit...
+        Spatial model parallelism is handled by:
+        - Forward: communication inside model layers (distributed SHT/iSHT)
+        - Backward: gradient hooks registered here that all-reduce across
+          spatial ranks, so every rank sees the global-mean gradient.
+
+        ``broadcast_buffers=False`` is required because the SHT/iSHT layers
+        store precomputed Legendre polynomial buffers. DDP's default
+        buffer broadcast modifies these in-place between forward calls,
+        which breaks autograd's tensor-version tracking.
         """
         if any(p.requires_grad for p in module.parameters()):
             if using_gpu():
                 output_device = [self._device_id]
             else:
                 output_device = None
-            return DistributedDataParallel(
+            wrapped = DistributedDataParallel(
                 SyncBatchNorm.convert_sync_batchnorm(module),
                 device_ids=self._device_ids,
                 output_device=output_device,
                 process_group=self._data_group,
+                broadcast_buffers=False,
             )
+            self._register_spatial_grad_hooks(wrapped)
+            return wrapped
         return DummyWrapper(module)

+    def _register_spatial_grad_hooks(self, module: torch.nn.Module) -> None:
+        """All-reduce gradients across spatial ranks after each backward.
+
+        Each spatial rank only sees its local slice of the input, so its
+        gradient is a partial sum. This hook averages those partials so
+        that every rank applies the same weight update.
+
+        The hook fires via ``register_post_accumulate_grad_hook`` which
+        runs when ``.grad`` is written — before DDP's data-parallel
+        all-reduce. The two reductions commute (orthogonal groups), so
+        ordering does not matter.
+        """
+        if self._h_size <= 1 and self._w_size <= 1:
+            return
+        spatial_group = self._spatial_group
+        n_spatial = self._h_size * self._w_size
+
+        def _hook(param: torch.nn.Parameter) -> None:
+            if param.grad is not None:
+                torch.distributed.all_reduce(
+                    param.grad, group=spatial_group
+                )
+                param.grad /= n_spatial
+
+        for param in module.parameters():
+            if param.requires_grad:
+                param.register_post_accumulate_grad_hook(_hook)
+
     def barrier(self):
         """Global barrier across all ranks."""
         logger.debug("Barrier on rank %d", self._rank)
         torch.distributed.barrier(device_ids=self._device_ids)

     def spatial_reduce_sum(self, tensor: torch.Tensor) -> torch.Tensor:
         if self._h_size > 1 or self._w_size > 1:
-            torch.distributed.all_reduce(tensor, group=self._spatial_group)
+            return _AutogradAllReduce.apply(tensor, self._spatial_group)
         return tensor

     def weighted_mean(

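To see why making spatial_reduce_sum differentiable is enough for the AreaWeightedMSE path, note that a global area-weighted mean is a ratio of two sums, each of which can be accumulated locally and then summed across spatial ranks. The sketch below is a single-process analogue under that assumption; the chunking and variable names are illustrative, not the library's exact call sequence.

import torch

# Global weighted mean expressed through per-rank partial sums.
torch.manual_seed(0)
x = torch.randn(6, 8, requires_grad=True)  # "global" field
w = torch.rand(6, 8)                       # area weights
reference = (w * x).sum() / w.sum()

# Pretend the grid is split between two spatial ranks along the last dim.
num = sum((wi * xi).sum() for wi, xi in zip(w.chunk(2, dim=-1), x.chunk(2, dim=-1)))
den = sum(wi.sum() for wi in w.chunk(2, dim=-1))
weighted_mean = num / den                  # what the reduced loss path computes
assert torch.allclose(weighted_mean, reference)

# The numerator is built from differentiable sums, so backward reaches x on
# every rank; this is the role _AutogradAllReduce plays for the real sum.
weighted_mean.backward()
assert x.grad is not None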