|
| 1 | +""" |
| 2 | +Parallel regression tests for the SingleModuleStepper with NoiseConditionedSFNO. |
| 3 | +
|
| 4 | +These tests verify that the forward pass and loss computation produce identical |
| 5 | +results regardless of spatial decomposition (nproc=1 vs model-parallel). |
| 6 | +""" |
| 7 | + |
| 8 | +import dataclasses |
| 9 | +import datetime |
| 10 | +import os |
| 11 | +from collections.abc import Mapping |
| 12 | + |
| 13 | +import numpy as np |
| 14 | +import pytest |
| 15 | +import torch |
| 16 | +import xarray as xr |
| 17 | + |
| 18 | +from fme.ace.data_loading.batch_data import BatchData, PrognosticState |
| 19 | +from fme.ace.registry.stochastic_sfno import NoiseConditionedSFNOBuilder |
| 20 | +from fme.ace.stepper.single_module import ( |
| 21 | + StepperConfig, |
| 22 | + TrainOutput, |
| 23 | + TrainStepper, |
| 24 | + TrainStepperConfig, |
| 25 | +) |
| 26 | +from fme.core.coordinates import HybridSigmaPressureCoordinate, LatLonCoordinates |
| 27 | +from fme.core.dataset_info import DatasetInfo |
| 28 | +from fme.core.device import get_device |
| 29 | +from fme.core.distributed.distributed import Distributed |
| 30 | +from fme.core.loss import StepLossConfig |
| 31 | +from fme.core.normalizer import NetworkAndLossNormalizationConfig, NormalizationConfig |
| 32 | +from fme.core.optimization import NullOptimization |
| 33 | +from fme.core.registry.module import ModuleSelector |
| 34 | +from fme.core.step import SingleModuleStepConfig, StepSelector |
| 35 | +from fme.core.testing.regression import validate_tensor_dict |
| 36 | +from fme.core.typing_ import EnsembleTensorDict |
| 37 | + |
| 38 | +DIR = os.path.abspath(os.path.dirname(__file__)) |
| 39 | +TIMESTEP = datetime.timedelta(hours=6) |
| 40 | + |
| 41 | + |
def get_dataset_info(
    img_shape=(5, 5),
) -> DatasetInfo:
    """Build a DatasetInfo with zeroed lat/lon coordinates of the given shape.

    Args:
        img_shape: (n_lat, n_lon) spatial extent of the synthetic grid.
    """
    n_lat, n_lon = img_shape[-2], img_shape[-1]
    return DatasetInfo(
        horizontal_coordinates=LatLonCoordinates(
            lat=torch.zeros(n_lat),
            lon=torch.zeros(n_lon),
        ),
        vertical_coordinate=HybridSigmaPressureCoordinate(
            ak=torch.arange(7),
            bk=torch.arange(7),
        ),
        timestep=TIMESTEP,
    )
| 57 | + |
| 58 | + |
def _get_train_stepper(
    stepper_config: StepperConfig,
    dataset_info: DatasetInfo,
    **train_config_kwargs,
) -> TrainStepper:
    """Construct a TrainStepper from a stepper config plus train-config kwargs."""
    return TrainStepperConfig(**train_config_kwargs).get_train_stepper(
        stepper_config, dataset_info
    )
| 66 | + |
| 67 | + |
def get_regression_stepper_and_data() -> (
    tuple[TrainStepper, BatchData, tuple[int, int]]
):
    """Build the NoiseConditionedSFNO stepper and synthetic batch for the tests.

    Returns:
        The train stepper, a batch already scattered across spatial ranks,
        and the global image shape.
    """
    in_names = ["a", "b"]
    out_names = ["b", "c"]
    n_forward_steps = 2
    n_samples = 3
    img_shape = (9, 18)
    device = get_device()

    all_names = list(set(in_names + out_names))

    loss = StepLossConfig(type="AreaWeightedMSE")

    builder_config = dataclasses.asdict(
        NoiseConditionedSFNOBuilder(
            embed_dim=16,
            num_layers=2,
            noise_embed_dim=16,
            noise_type="isotropic",
        )
    )
    step_config = SingleModuleStepConfig(
        builder=ModuleSelector(
            type="NoiseConditionedSFNO",
            config=builder_config,
        ),
        in_names=in_names,
        out_names=out_names,
        normalization=NetworkAndLossNormalizationConfig(
            network=NormalizationConfig(
                means={name: 0.1 for name in all_names},
                stds={name: 1.1 for name in all_names},
            ),
        ),
        ocean=None,
    )
    config = StepperConfig(
        step=StepSelector(
            type="single_module",
            config=dataclasses.asdict(step_config),
        ),
    )

    dataset_info = get_dataset_info(img_shape=img_shape)
    train_stepper = _get_train_stepper(config, dataset_info, loss=loss)
    # Generate fields in "a", "b", "c" order so RNG consumption matches
    # the stored regression baseline.
    fields = {
        name: torch.randn(n_samples, n_forward_steps + 1, *img_shape).to(device)
        for name in ["a", "b", "c"]
    }
    data = BatchData.new_on_device(
        data=fields,
        time=xr.DataArray(
            np.zeros((n_samples, n_forward_steps + 1)),
            dims=["sample", "time"],
        ),
        labels=None,
        epoch=0,
        horizontal_dims=["lat", "lon"],
    )
    return train_stepper, data.scatter_spatial(img_shape), img_shape
| 130 | + |
| 131 | + |
def flatten_dict(
    d: Mapping[str, Mapping[str, torch.Tensor]],
) -> dict[str, torch.Tensor]:
    """Flatten a two-level mapping into dotted ``"outer.inner"`` keys."""
    return {
        f"{outer}.{inner}": tensor
        for outer, sub in d.items()
        for inner, tensor in sub.items()
    }
| 140 | + |
| 141 | + |
def _get_train_output_tensor_dict(data: TrainOutput) -> dict[str, torch.Tensor]:
    """Collect a TrainOutput's metrics/gen/target tensors under prefixed keys."""
    flat: dict[str, torch.Tensor] = {}
    flat.update({f"metrics.{name}": value for name, value in data.metrics.items()})
    flat.update({f"gen_data.{name}": value for name, value in data.gen_data.items()})
    for name, value in data.target_data.items():
        # presumably dim 1 is a singleton ensemble/time axis — verify upstream
        assert value.shape[1] == 1
        flat[f"target_data.{name}"] = value
    return flat
| 152 | + |
| 153 | + |
def get_train_outputs_tensor_dict(
    step_1: TrainOutput, step_2: TrainOutput
) -> dict[str, torch.Tensor]:
    """Combine two TrainOutputs into one flat dict keyed by step label."""
    nested = {
        "step_1": _get_train_output_tensor_dict(step_1),
        "step_2": _get_train_output_tensor_dict(step_2),
    }
    return flatten_dict(nested)
| 163 | + |
| 164 | + |
def get_predict_output_tensor_dict(
    output: BatchData, next_state: PrognosticState
) -> dict[str, torch.Tensor]:
    """Flatten predicted output and the final prognostic state into one dict."""
    nested = {
        "output": output.data,
        "next_state": next_state.as_batch_data().data,
    }
    return flatten_dict(nested)
| 174 | + |
| 175 | + |
@pytest.mark.parallel
def test_stepper_train_on_batch_regression():
    """Two consecutive train steps must reproduce the stored regression data,
    independent of the spatial decomposition."""
    torch.manual_seed(0)
    train_stepper, data, img_shape = get_regression_stepper_and_data()
    no_op_optimization = NullOptimization()
    first = train_stepper.train_on_batch(data, no_op_optimization)
    second = train_stepper.train_on_batch(data, no_op_optimization)
    dist = Distributed.get_instance()
    # Gather full-domain tensors so every rank validates identical data.
    for result in (first, second):
        result.gen_data = EnsembleTensorDict(
            dist.gather_spatial(dict(result.gen_data), img_shape)
        )
        result.target_data = EnsembleTensorDict(
            dist.gather_spatial(dict(result.target_data), img_shape)
        )
    validate_tensor_dict(
        get_train_outputs_tensor_dict(first, second),
        os.path.join(DIR, "testdata/csfno_stepper_train_on_batch_regression.pt"),
        atol=1e-4,
        rtol=1e-4,
    )
| 201 | + |
| 202 | + |
@pytest.mark.parallel
def test_stepper_predict_regression():
    """Stepper.predict must reproduce the stored regression data,
    independent of the spatial decomposition."""
    torch.manual_seed(0)
    train_stepper, data, img_shape = get_regression_stepper_and_data()
    stepper = train_stepper._stepper
    initial_condition = data.get_start(
        prognostic_names=["b"],
        n_ic_timesteps=1,
    )
    output, next_state = stepper.predict(
        initial_condition, data, compute_derived_variables=True
    )
    dist = Distributed.get_instance()
    # Gather full-domain tensors before validating against the baseline.
    gathered = {
        "output": dist.gather_spatial(dict(output.data), img_shape),
        "next_state": dist.gather_spatial(
            dict(next_state.as_batch_data().data), img_shape
        ),
    }
    validate_tensor_dict(
        flatten_dict(gathered),
        os.path.join(DIR, "testdata/csfno_stepper_predict_regression.pt"),
        atol=1e-4,
        rtol=1e-4,
    )
0 commit comments