
Commit d920bd9

mahf708, claude, and peterdschwartz committed
Enable training under spatial model parallelism (ModelTorchDistributed)
Two fixes in model_torch_distributed.py:

1. Set broadcast_buffers=False in the DDP wrapping: DDP's default buffer broadcast modifies the SHT/iSHT Legendre polynomial buffers in-place between forward calls, breaking autograd's version tracking.
2. Register spatial gradient hooks (_register_spatial_grad_hooks) that all-reduce parameter gradients across spatial ranks after backward, so every rank applies the same weight update. This is the spatial analogue of what DDP does for data parallelism.

Add test_single_module_csfno.py with parallel regression tests using NoiseConditionedSFNO (which supports spatial parallelism). The tests use the AreaWeightedMSE loss, which reduces correctly across spatial ranks via the existing gridded_operations.area_weighted_mean → weighted_mean → spatial_reduce_sum path.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: peterdschwartz <peterdschwartz83@gmail.com>
1 parent fca6446 commit d920bd9

4 files changed

Lines changed: 294 additions & 5 deletions
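The first fix in the commit message (broadcast_buffers=False) guards against a failure mode that is easy to reproduce in isolation. The sketch below is illustrative only and is not part of the commit: a tensor saved for backward is overwritten in place between forward and backward, which is what DDP's default buffer broadcast does to the precomputed SHT/iSHT Legendre buffers, and autograd's version counter then rejects the backward pass.

import torch

# Illustrative stand-in for a precomputed SHT Legendre buffer (non-trainable).
buf = torch.randn(4)
x = torch.randn(4, requires_grad=True)

# Forward: the multiplication saves `buf` for the backward pass.
y = (x * buf).sum()

# An in-place overwrite between forward and backward, analogous to DDP's
# default buffer broadcast, bumps the buffer's version counter.
buf.copy_(torch.randn(4))

# Backward now fails with:
#   RuntimeError: one of the variables needed for gradient computation has
#   been modified by an inplace operation
y.backward()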

test_single_module_csfno.py

Lines changed: 226 additions & 0 deletions
@@ -0,0 +1,226 @@
"""
Parallel regression tests for the SingleModuleStepper with NoiseConditionedSFNO.

These tests verify that the forward pass and loss computation produce identical
results regardless of spatial decomposition (nproc=1 vs model-parallel).
"""

import dataclasses
import datetime
import os
from collections.abc import Mapping

import numpy as np
import pytest
import torch
import xarray as xr

from fme.ace.data_loading.batch_data import BatchData, PrognosticState
from fme.ace.registry.stochastic_sfno import NoiseConditionedSFNOBuilder
from fme.ace.stepper.single_module import (
    StepperConfig,
    TrainOutput,
    TrainStepper,
    TrainStepperConfig,
)
from fme.core.coordinates import HybridSigmaPressureCoordinate, LatLonCoordinates
from fme.core.dataset_info import DatasetInfo
from fme.core.device import get_device
from fme.core.distributed.distributed import Distributed
from fme.core.loss import StepLossConfig
from fme.core.normalizer import NetworkAndLossNormalizationConfig, NormalizationConfig
from fme.core.optimization import NullOptimization
from fme.core.registry.module import ModuleSelector
from fme.core.step import SingleModuleStepConfig, StepSelector
from fme.core.testing.regression import validate_tensor_dict
from fme.core.typing_ import EnsembleTensorDict

DIR = os.path.abspath(os.path.dirname(__file__))
TIMESTEP = datetime.timedelta(hours=6)


def get_dataset_info(
    img_shape=(5, 5),
) -> DatasetInfo:
    horizontal_coordinate = LatLonCoordinates(
        lat=torch.zeros(img_shape[-2]),
        lon=torch.zeros(img_shape[-1]),
    )
    vertical_coordinate = HybridSigmaPressureCoordinate(
        ak=torch.arange(7), bk=torch.arange(7)
    )
    return DatasetInfo(
        horizontal_coordinates=horizontal_coordinate,
        vertical_coordinate=vertical_coordinate,
        timestep=TIMESTEP,
    )


def _get_train_stepper(
    stepper_config: StepperConfig,
    dataset_info: DatasetInfo,
    **train_config_kwargs,
) -> TrainStepper:
    train_config = TrainStepperConfig(**train_config_kwargs)
    return train_config.get_train_stepper(stepper_config, dataset_info)


def get_regression_stepper_and_data() -> (
    tuple[TrainStepper, BatchData, tuple[int, int]]
):
    in_names = ["a", "b"]
    out_names = ["b", "c"]
    n_forward_steps = 2
    n_samples = 3
    img_shape = (9, 18)
    device = get_device()

    all_names = list(set(in_names + out_names))

    loss = StepLossConfig(type="AreaWeightedMSE")

    config = StepperConfig(
        step=StepSelector(
            type="single_module",
            config=dataclasses.asdict(
                SingleModuleStepConfig(
                    builder=ModuleSelector(
                        type="NoiseConditionedSFNO",
                        config=dataclasses.asdict(
                            NoiseConditionedSFNOBuilder(
                                embed_dim=16,
                                num_layers=2,
                                noise_embed_dim=16,
                                noise_type="isotropic",
                            )
                        ),
                    ),
                    in_names=in_names,
                    out_names=out_names,
                    normalization=NetworkAndLossNormalizationConfig(
                        network=NormalizationConfig(
                            means={n: 0.1 for n in all_names},
                            stds={n: 1.1 for n in all_names},
                        ),
                    ),
                    ocean=None,
                )
            ),
        ),
    )

    dataset_info = get_dataset_info(img_shape=img_shape)
    train_stepper = _get_train_stepper(config, dataset_info, loss=loss)
    data = BatchData.new_on_device(
        data={
            "a": torch.randn(n_samples, n_forward_steps + 1, *img_shape).to(device),
            "b": torch.randn(n_samples, n_forward_steps + 1, *img_shape).to(device),
            "c": torch.randn(n_samples, n_forward_steps + 1, *img_shape).to(device),
        },
        time=xr.DataArray(
            np.zeros((n_samples, n_forward_steps + 1)),
            dims=["sample", "time"],
        ),
        labels=None,
        epoch=0,
        horizontal_dims=["lat", "lon"],
    )
    data = data.scatter_spatial(img_shape)
    return train_stepper, data, img_shape


def flatten_dict(
    d: Mapping[str, Mapping[str, torch.Tensor]],
) -> dict[str, torch.Tensor]:
    return_dict = {}
    for k, v in d.items():
        for k2, v2 in v.items():
            return_dict[f"{k}.{k2}"] = v2
    return return_dict


def _get_train_output_tensor_dict(data: TrainOutput) -> dict[str, torch.Tensor]:
    return_dict = {}
    for k, v in data.metrics.items():
        return_dict[f"metrics.{k}"] = v
    for k, v in data.gen_data.items():
        return_dict[f"gen_data.{k}"] = v
    for k, v in data.target_data.items():
        assert v.shape[1] == 1
        return_dict[f"target_data.{k}"] = v
    return return_dict


def get_train_outputs_tensor_dict(
    step_1: TrainOutput, step_2: TrainOutput
) -> dict[str, torch.Tensor]:
    return flatten_dict(
        {
            "step_1": _get_train_output_tensor_dict(step_1),
            "step_2": _get_train_output_tensor_dict(step_2),
        }
    )


def get_predict_output_tensor_dict(
    output: BatchData, next_state: PrognosticState
) -> dict[str, torch.Tensor]:
    return flatten_dict(
        {
            "output": output.data,
            "next_state": next_state.as_batch_data().data,
        }
    )


@pytest.mark.parallel
def test_stepper_train_on_batch_regression():
    torch.manual_seed(0)
    train_stepper, data, img_shape = get_regression_stepper_and_data()
    optimization = NullOptimization()
    result1 = train_stepper.train_on_batch(data, optimization)
    result2 = train_stepper.train_on_batch(data, optimization)
    dist = Distributed.get_instance()
    for result in [result1, result2]:
        result.gen_data = EnsembleTensorDict(
            dist.gather_spatial(dict(result.gen_data), img_shape)
        )
        result.target_data = EnsembleTensorDict(
            dist.gather_spatial(dict(result.target_data), img_shape)
        )
    output_dict = get_train_outputs_tensor_dict(result1, result2)
    validate_tensor_dict(
        output_dict,
        os.path.join(
            DIR,
            "testdata/csfno_stepper_train_on_batch_regression.pt",
        ),
        atol=1e-4,
        rtol=1e-4,
    )


@pytest.mark.parallel
def test_stepper_predict_regression():
    torch.manual_seed(0)
    train_stepper, data, img_shape = get_regression_stepper_and_data()
    stepper = train_stepper._stepper
    initial_condition = data.get_start(
        prognostic_names=["b"],
        n_ic_timesteps=1,
    )
    output, next_state = stepper.predict(
        initial_condition, data, compute_derived_variables=True
    )
    dist = Distributed.get_instance()
    output_data = dist.gather_spatial(dict(output.data), img_shape)
    next_state_data = dist.gather_spatial(
        dict(next_state.as_batch_data().data), img_shape
    )
    output_dict = flatten_dict({"output": output_data, "next_state": next_state_data})
    validate_tensor_dict(
        output_dict,
        os.path.join(DIR, "testdata/csfno_stepper_predict_regression.pt"),
        atol=1e-4,
        rtol=1e-4,
    )
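These tests rely on the AreaWeightedMSE loss producing the same value on every spatial rank. Below is a quick single-process sketch of the underlying arithmetic, illustrative only and assuming weighted_mean divides a spatially reduced weighted sum by the spatially reduced weights: splitting the grid into tiles and summing the per-tile partial sums reproduces the full-domain area-weighted mean.

import torch

torch.manual_seed(0)
err_sq = torch.rand(9, 18)  # squared error on the full 9x18 lat/lon grid
area = torch.rand(9, 18)    # area weights

# Full-domain area-weighted mean, as computed with nproc=1.
full = (area * err_sq).sum() / area.sum()

# Two "spatial ranks", each holding half of the longitudes.
tiles = [(slice(None), slice(0, 9)), (slice(None), slice(9, 18))]
num = sum((area[t] * err_sq[t]).sum() for t in tiles)  # reduce of weighted sums
den = sum(area[t].sum() for t in tiles)                # reduce of weights

assert torch.allclose(full, num / den)  # every rank gets the same loss value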
Two binary test-data files (testdata/csfno_stepper_train_on_batch_regression.pt, testdata/csfno_stepper_predict_regression.pt) added; binary contents not shown.

fme/core/distributed/model_torch_distributed.py

Lines changed: 68 additions & 5 deletions
@@ -25,6 +25,7 @@
 import torch.distributed
 import torch.nn as nn
 import torch_harmonics.distributed as thd
+from torch.amp import custom_bwd, custom_fwd
 from torch.nn import SyncBatchNorm
 from torch.nn.parallel import DistributedDataParallel

@@ -42,6 +43,35 @@
 T = TypeVar("T")


+class _AutogradAllReduce(torch.autograd.Function):
+    """Autograd-aware all-reduce (sum) for spatial parallelism.
+
+    Forward: all-reduce (sum) the input across the given process group.
+    Backward: identity — gradients pass through without communication.
+
+    This makes ``spatial_reduce_sum`` differentiable so that gradients
+    flow correctly through the loss computation path::
+
+        AreaWeightedMSELoss → area_weighted_mean → weighted_mean
+            → spatial_reduce_sum (uses this function)
+
+    Without this, the raw ``torch.distributed.all_reduce`` would break
+    the autograd graph because it is an in-place, non-differentiable op.
+    """
+
+    @staticmethod
+    @custom_fwd(device_type="cuda")
+    def forward(
+        ctx,
+        input: torch.Tensor,
+        group: torch.distributed.ProcessGroup,
+    ) -> torch.Tensor:
+        output = input.clone()
+        torch.distributed.all_reduce(output, group=group)
+        return output
+
+    @staticmethod
+    @custom_bwd(device_type="cuda")
+    def backward(ctx, grad_output: torch.Tensor):
+        return grad_output.clone(), None
+
+
 class ModelTorchDistributed(DistributedBackend):
     """Distributed backend with spatial model parallelism.
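A note on why the identity backward in _AutogradAllReduce is correct: after the forward all-reduce, every spatial rank holds the same global sum and computes the same scalar loss from it, so the grad_output arriving at backward is already the derivative of that shared loss with respect to the global sum, which equals the derivative with respect to the rank's local partial. The single-process sketch below is illustrative only; it emulates two spatial ranks with tensor slices and checks that this, combined with summing the per-rank parameter gradients as _register_spatial_grad_hooks does, reproduces the full-domain gradient.

import torch

torch.manual_seed(0)
x_full = torch.randn(6)             # full spatial field
x_local = [x_full[:3], x_full[3:]]  # two emulated "spatial ranks"

# Reference: full-domain gradient of a summed loss w.r.t. a shared weight.
w_ref = torch.ones(1, requires_grad=True)
((w_ref * x_full) ** 2).sum().backward()

# Emulated spatial parallelism: each rank has its own weight copy and its own
# slice; the partial sums are combined (the forward all-reduce), the combine
# passes gradients through unchanged (identity backward), and the per-rank
# weight gradients are summed afterwards (the spatial gradient hook).
w_ranks = [torch.ones(1, requires_grad=True) for _ in x_local]
partials = [((w * x) ** 2).sum() for w, x in zip(w_ranks, x_local)]
(partials[0] + partials[1]).backward()
grad_sum = w_ranks[0].grad + w_ranks[1].grad

assert torch.allclose(w_ref.grad, grad_sum)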
@@ -307,31 +337,64 @@ def _device_ids(self) -> list[int] | None:
     def wrap_module(self, module: torch.nn.Module) -> torch.nn.Module:
         """Wrap with DDP over the **data** process group.

-        For now, we assume spatial communication is expected to be handled
-        inside the model layers themselves. If we need to change course, we
-        can revisit...
+        Spatial model parallelism is handled by:
+        - Forward: communication inside model layers (distributed SHT/iSHT)
+        - Backward: gradient hooks registered here that all-reduce across
+          spatial ranks, so every rank sees the global-mean gradient.
+
+        ``broadcast_buffers=False`` is required because the SHT/iSHT layers
+        store precomputed Legendre polynomial buffers. DDP's default
+        buffer broadcast modifies these in-place between forward calls,
+        which breaks autograd's tensor-version tracking.
         """
         if any(p.requires_grad for p in module.parameters()):
             if using_gpu():
                 output_device = [self._device_id]
             else:
                 output_device = None
-            return DistributedDataParallel(
+            wrapped = DistributedDataParallel(
                 SyncBatchNorm.convert_sync_batchnorm(module),
                 device_ids=self._device_ids,
                 output_device=output_device,
                 process_group=self._data_group,
+                broadcast_buffers=False,
             )
+            self._register_spatial_grad_hooks(wrapped)
+            return wrapped
         return DummyWrapper(module)

+    def _register_spatial_grad_hooks(self, module: torch.nn.Module) -> None:
+        """All-reduce gradients across spatial ranks after each backward.
+
+        Each spatial rank only sees its local slice of the input, so its
+        gradient is a partial sum. This hook sums those partials so
+        that every rank applies the same weight update.
+
+        The hook fires via ``register_post_accumulate_grad_hook`` which
+        runs when ``.grad`` is written — before DDP's data-parallel
+        all-reduce. The two reductions commute (orthogonal groups), so
+        ordering does not matter.
+        """
+        if self._h_size <= 1 and self._w_size <= 1:
+            return
+        spatial_group = self._spatial_group
+
+        def _hook(param: torch.nn.Parameter) -> None:
+            if param.grad is not None:
+                torch.distributed.all_reduce(param.grad, group=spatial_group)
+
+        for param in module.parameters():
+            if param.requires_grad:
+                param.register_post_accumulate_grad_hook(_hook)
+
     def barrier(self):
         """Global barrier across all ranks."""
         logger.debug("Barrier on rank %d", self._rank)
         torch.distributed.barrier(device_ids=self._device_ids)

     def spatial_reduce_sum(self, tensor: torch.Tensor) -> torch.Tensor:
         if self._h_size > 1 or self._w_size > 1:
-            torch.distributed.all_reduce(tensor, group=self._spatial_group)
+            return _AutogradAllReduce.apply(tensor, self._spatial_group)
         return tensor

     def weighted_mean(
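For reference, register_post_accumulate_grad_hook, which the new spatial hooks use, is a standard PyTorch Tensor API: the hook runs once a parameter's .grad has been accumulated for the current backward pass. A minimal single-process sketch follows, with the distributed all-reduce replaced by a print since no process group exists here.

import torch

p = torch.nn.Parameter(torch.randn(3))

def _hook(param: torch.Tensor) -> None:
    # In the commit, this is where the spatial all-reduce happens:
    # torch.distributed.all_reduce(param.grad, group=spatial_group)
    print("grad accumulated:", param.grad)

p.register_post_accumulate_grad_hook(_hook)
(p * torch.arange(3.0)).sum().backward()  # hook fires once .grad is populated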

0 commit comments
