Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
d6d51da
Write first solution with Claude
Feb 18, 2026
a6e4c25
Add test configs, works on santis
sophie-xhonneux Feb 19, 2026
e65241a
Merge branch 'develop' of github.com:ecmwf/WeatherGenerator into soph…
clessig Feb 24, 2026
66de4c0
Merge branch 'develop' into sophiex/dev/warm-and-frozen-teachers
clessig Feb 24, 2026
9fbe081
Merge branch 'develop' of github.com:ecmwf/WeatherGenerator into soph…
clessig Feb 27, 2026
af9a02c
Disabling rope; removing model config from finetuning since it needs …
clessig Feb 27, 2026
da23d8a
Merge branch 'develop' into sophiex/dev/warm-and-frozen-teachers
sophie-xhonneux Feb 27, 2026
1ccad2b
Add new JEPA config
sophie-xhonneux Feb 27, 2026
ac59fc3
Address comments on PR
sophie-xhonneux Mar 6, 2026
026e272
Merge branch 'develop' into sophiex/dev/warm-and-frozen-teachers
clessig Mar 10, 2026
c2b40b9
Merge branch 'develop' of https://github.com/ecmwf/WeatherGenerator i…
clessig Mar 11, 2026
8f19264
Linting
clessig Mar 11, 2026
9116f0d
Linting
clessig Mar 11, 2026
d9a83ca
Fixed some corner cases in handling of when batch samples are NaN and
clessig Mar 11, 2026
e163c52
Fixed handling of when batch valid is
clessig Mar 11, 2026
2a54039
Fixed path handling
clessig Mar 11, 2026
16dc9eb
Fixed problems with loading of teacher model
clessig Mar 11, 2026
2079049
Revert incorrect changes to default_config
clessig Mar 23, 2026
9eff5ac
Merge branch 'develop' of github.com:ecmwf/WeatherGenerator into soph…
clessig Mar 23, 2026
d2eccc5
Fix problem with missing run_id as dir in path for loading teacher model
clessig Mar 23, 2026
ccaf97c
Updated logging
clessig Mar 24, 2026
65ee9d3
Updated config
clessig Mar 24, 2026
3dba1eb
Merge branch 'develop' of github.com:ecmwf/WeatherGenerator into soph…
clessig Mar 24, 2026
8899a91
Address PR review
sophie-xhonneux Mar 27, 2026
a93b7e9
Lint
sophie-xhonneux Mar 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions config/default_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ embed_orientation: "channels"
embed_unembed_mode: "block"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revert changes to default config

embed_dropout_rate: 0.1

ae_local_dim_embed: 512 #1024
ae_local_dim_embed: 1024
ae_local_num_blocks: 2
ae_local_num_heads: 16
ae_local_dropout_rate: 0.1
Expand All @@ -25,7 +25,7 @@ ae_adapter_with_qk_lnorm: True
ae_adapter_with_residual: True
ae_adapter_dropout_rate: 0.1

ae_global_dim_embed: 512 #1024 #2048
ae_global_dim_embed: 2048
ae_global_num_blocks: 2
ae_global_num_heads: 32
ae_global_dropout_rate: 0.1
Expand All @@ -37,7 +37,7 @@ ae_global_block_factor: 64
ae_global_mlp_hidden_factor: 2
ae_global_trailing_layer_norm: False

ae_aggregation_num_blocks: 2
ae_aggregation_num_blocks: 0
ae_aggregation_num_heads: 32
ae_aggregation_dropout_rate: 0.1
ae_aggregation_with_qk_lnorm: True
Expand All @@ -50,8 +50,6 @@ pred_adapter_kv: False
pred_self_attention: True
pred_dyadic_dims: False
pred_mlp_adaln: True
num_class_tokens: 1
num_register_tokens: 7

# number of steps by which the first target window is offset; if set to zero and forecast_steps=0,
# the model is trained as an auto-encoder
Expand All @@ -64,6 +62,9 @@ fe_impute_latent_noise_std: 0.0 # 1e-4
# currently fixed to 1.0 (due to limitations with flex_attention and triton)
forecast_att_dense_rate: 1.0

num_class_tokens: 0
num_register_tokens: 0

healpix_level: 5

# Use 2D RoPE instead of traditional global positional encoding
Expand Down
4 changes: 3 additions & 1 deletion packages/common/src/weathergen/common/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,7 @@ def get_wg_private_path() -> Path:
path = _REPO_ROOT.parent / "WeatherGenerator-private"

path = path.resolve()
assert path.is_dir(), f"WeatherGenerator private repo path does not exist or is not a directory: {path}"
assert path.is_dir(), (
f"WeatherGenerator private repo path does not exist or is not a directory: {path}"
)
return path
3 changes: 1 addition & 2 deletions packages/evaluate/src/weathergen/evaluate/run_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,7 @@ def _process_stream(
)
metric_list_to_json(reader, stream, stream_computed_scores, regions_to_compute)
scores_dict = merge(stream_loaded_scores, stream_computed_scores)
return run_id, stream, scores_dict

return run_id, stream, scores_dict


# except Exception as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class DataReaderMesh(DataReaderTimestep):
- Robust Multi-Node/Worker support (Fork-safe, Dask-safe).
- Dynamic Patching (local) OR Global Sparse Sampling.
"""

def __init__(
self,
tw_handler: TimeWindowHandler,
Expand Down Expand Up @@ -123,7 +124,10 @@ def __init__(
self.roi_min_lon, self.roi_min_lat, self.roi_max_lon, self.roi_max_lat = self.roi
else:
self.roi_min_lon, self.roi_min_lat, self.roi_max_lon, self.roi_max_lat = (
-180.0, -90.0, 180.0, 90.0
-180.0,
-90.0,
180.0,
90.0,
)

self.available_channels = list(self.col_map.keys())
Expand All @@ -133,7 +137,7 @@ def __init__(
self.source_idx = self._select_channels("source")
self.target_idx = self._select_channels("target")
self.geoinfo_idx = []
self.geoinfo_channels =[]
self.geoinfo_channels = []

self.source_channels = [self.available_channels[i] for i in self.source_idx]
self.target_channels = [self.available_channels[i] for i in self.target_idx]
Expand All @@ -146,7 +150,7 @@ def _probe_file(self, filepath, is_source=True):
with xr.open_dataset(mapper, engine="zarr", chunks={}, consolidated=False) as ds:
if "time" not in ds.coords:
all_vars = list(ds.coords) + list(ds.data_vars)
time_candidates =[v for v in all_vars if "time" in v.lower()]
time_candidates = [v for v in all_vars if "time" in v.lower()]
if time_candidates:
target = time_candidates[0]
if target in ds.data_vars:
Expand Down Expand Up @@ -208,35 +212,25 @@ def _lazy_init(self):
if self._initialized:
return

self.mapper_src = fsspec.get_mapper("reference://",
fo=str(self.filename_source),
remote_protocol="file"
)
self.mapper_src = fsspec.get_mapper(
"reference://", fo=str(self.filename_source), remote_protocol="file"
)
import warnings

with warnings.catch_warnings():
warnings.filterwarnings("ignore", message=".*separate the stored chunks.*")
self.ds_source = xr.open_dataset(
self.mapper_src,
engine="zarr",
chunks={},
decode_times=True,
consolidated=False
self.mapper_src, engine="zarr", chunks={}, decode_times=True, consolidated=False
)

if self.filename_target != self.filename_source:
self.mapper_trg = fsspec.get_mapper(
"reference://",
fo=str(self.filename_target),
remote_protocol="file"
"reference://", fo=str(self.filename_target), remote_protocol="file"
)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message=".*separate the stored chunks.*")
self.ds_target = xr.open_dataset(
self.mapper_trg,
engine="zarr",
chunks={},
decode_times=True,
consolidated=False
self.mapper_trg, engine="zarr", chunks={}, decode_times=True, consolidated=False
)
else:
self.ds_target = self.ds_source
Expand Down Expand Up @@ -284,10 +278,10 @@ def _fetch_data(self, idx: TIndex, channels: list[str], is_source: bool) -> Read
if len(t_idxs) == 0 or not channels:
return ReaderData.empty(len(channels), 0)

channel_indices =[self.available_channels.index(c) for c in channels]
channel_indices = [self.available_channels.index(c) for c in channels]
start_t, end_t = t_idxs[0], t_idxs[-1] + 1
n_steps = len(t_idxs)

lats_ref = self.lats_src if is_source else self.lats_trg
spatial_indices_ref = self.spatial_indices_src if is_source else self.spatial_indices_trg
coords_ref = self.coords_src if is_source else self.coords_trg
Expand Down Expand Up @@ -322,12 +316,16 @@ def _fetch_data(self, idx: TIndex, channels: list[str], is_source: bool) -> Read
lon_0 = lon_0_candidates[attempts]

mask_src = (
(self.lats_src >= lat_0) & (self.lats_src < lat_0 + self.patch_size_deg) &
(self.lons_src >= lon_0) & (self.lons_src < lon_0 + self.patch_size_deg)
(self.lats_src >= lat_0)
& (self.lats_src < lat_0 + self.patch_size_deg)
& (self.lons_src >= lon_0)
& (self.lons_src < lon_0 + self.patch_size_deg)
)
mask_trg = (
(self.lats_trg >= lat_0) & (self.lats_trg < lat_0 + self.patch_size_deg) &
(self.lons_trg >= lon_0) & (self.lons_trg < lon_0 + self.patch_size_deg)
(self.lats_trg >= lat_0)
& (self.lats_trg < lat_0 + self.patch_size_deg)
& (self.lons_trg >= lon_0)
& (self.lons_trg < lon_0 + self.patch_size_deg)
)

pts_src = np.count_nonzero(mask_src)
Expand All @@ -344,11 +342,15 @@ def _fetch_data(self, idx: TIndex, channels: list[str], is_source: bool) -> Read
len(lats_ref), size=req_points, replace=False
)

patch_coords_base = self.coords_src[patch_indices_local] if is_source else (
self.coords_trg[patch_indices_local]
patch_coords_base = (
self.coords_src[patch_indices_local]
if is_source
else (self.coords_trg[patch_indices_local])
)
final_disk_indices = self.spatial_indices_src[patch_indices_local] if is_source else (
self.spatial_indices_trg[patch_indices_local]
final_disk_indices = (
self.spatial_indices_src[patch_indices_local]
if is_source
else (self.spatial_indices_trg[patch_indices_local])
)
use_contiguous_read = True

Expand All @@ -361,24 +363,25 @@ def _fetch_data(self, idx: TIndex, channels: list[str], is_source: bool) -> Read
disk_start, disk_stop = np.min(final_disk_indices), np.max(final_disk_indices) + 1
rel_indices = final_disk_indices - disk_start
data_block = self._load_block_from_ds(
ds_ref,
arr_cache,
channel_indices,
start_t,
end_t,
n_steps,
slice(disk_start, disk_stop),
rel_indices
ds_ref,
arr_cache,
channel_indices,
start_t,
end_t,
n_steps,
slice(disk_start, disk_stop),
rel_indices,
)
else:
data_block = self._load_block_from_ds(
ds_ref,
arr_cache,
channel_indices,
start_t, end_t,
n_steps,
final_disk_indices,
None
ds_ref,
arr_cache,
channel_indices,
start_t,
end_t,
n_steps,
final_disk_indices,
None,
)

if data_block.size > 0:
Expand All @@ -399,16 +402,8 @@ def _fetch_data(self, idx: TIndex, channels: list[str], is_source: bool) -> Read
return rdata

def _load_block_from_ds(
self,
ds,
arr_cache,
indices,
start_t,
end_t,
n_steps,
disk_indices,
rel_indices
) -> np.typing.NDArray:
self, ds, arr_cache, indices, start_t, end_t, n_steps, disk_indices, rel_indices
) -> np.typing.NDArray:
if rel_indices is not None:
num_points = len(rel_indices)
else:
Expand Down Expand Up @@ -462,7 +457,7 @@ def _load_block_from_ds(
if "time" in dims:
# Contiguous read: Apply raw disk bounds, then rel_indices
chunk = chunk[:, disk_indices]

# Safety check: an empty chunk after disk indexing is unexpected, so fail loudly
if chunk.shape[1] == 0:
assert False, "Empty chunk after disk indexing with time dimension"
Expand Down Expand Up @@ -510,8 +505,8 @@ def _parse_attr(self, attrs, key):

def _select_channels(self, type_key: str) -> list[int]:
select = self._stream_info.get(type_key)
exclude = self._stream_info.get(f"{type_key}_exclude",[])
return[
exclude = self._stream_info.get(f"{type_key}_exclude", [])
return [
i
for i, ch in enumerate(self.available_channels)
if (not select or any(s in ch for s in select)) and not any(e in ch for e in exclude)
Expand Down Expand Up @@ -546,21 +541,17 @@ def normalize_target_channels(self, target: np.typing.NDArray) -> np.typing.NDAr
def denormalize_source_channels(self, source):
if isinstance(source, torch.Tensor):
stdev = torch.tensor(
self.stdev[self.source_idx],
dtype=source.dtype,
device=source.device
self.stdev[self.source_idx], dtype=source.dtype, device=source.device
)
mean = torch.tensor(
self.mean[self.source_idx],
dtype=source.dtype,
device=source.device
self.mean[self.source_idx], dtype=source.dtype, device=source.device
)
land_mask = (source == 0.0)
land_mask = source == 0.0
denorm = (source * stdev) + mean
denorm[land_mask] = torch.nan
return denorm
land_mask = (source == 0.0)

land_mask = source == 0.0
denorm = (source * self.stdev[self.source_idx]) + self.mean[self.source_idx]
denorm[land_mask] = np.nan
return denorm
Expand All @@ -576,4 +567,4 @@ def denormalize_target_channels(self, data):
@override
def normalize_geoinfos(self, geoinfos: np.typing.NDArray) -> np.typing.NDArray:
norm = (geoinfos - self.mean_geoinfo) / self.stdev_geoinfo
return np.nan_to_num(norm, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
return np.nan_to_num(norm, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
21 changes: 20 additions & 1 deletion src/weathergen/model/model_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
from weathergen.model.model import Model, ModelParams
from weathergen.model.utils import apply_fct_to_blocks, freeze_weights
from weathergen.train.target_and_aux_module_base import PhysicalTargetAndAux
from weathergen.train.target_and_aux_ssl_teacher import EMATeacher
from weathergen.train.target_and_aux_ssl_teacher import EMATeacher, FrozenTeacher
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great!

from weathergen.train.teacher_utils import load_encoder_from_checkpoint, prepare_encoder_teacher
from weathergen.utils.distributed import is_root
from weathergen.utils.utils import get_dtype

Expand Down Expand Up @@ -315,6 +316,13 @@ def get_target_aux_calculator(
with_fsdp=False,
overrides=target_and_aux_calc_params.get("model_param_overrides", {}),
)

# Strip to encoder + create fresh heads
cf_overridden = merge_configs(
cf, target_and_aux_calc_params.get("model_param_overrides", {})
)
prepare_encoder_teacher(meta_ema_model, cf.training_config, cf_overridden)

ema_model = EMAModel(
model,
meta_ema_model,
Expand All @@ -326,6 +334,17 @@ def get_target_aux_calculator(
batch_size = cf.get("world_size_original", cf.get("world_size")) * batch_size_per_gpu
target_aux = EMATeacher(model, ema_model, batch_size, cf.training_config)

# Optional: warm start encoder from checkpoint
teacher_run_id = target_and_aux_calc_params.get("teacher_run_id")
if teacher_run_id is not None:
teacher_mini_epoch = target_and_aux_calc_params.get("teacher_mini_epoch", -1)
load_encoder_from_checkpoint(
ema_model.ema_model, cf, teacher_run_id, teacher_mini_epoch, device
)

elif target_and_aux_calc == "FrozenTeacher":
target_aux = FrozenTeacher.from_pretrained(cf, dataset, device, target_and_aux_calc_params)

else:
raise NotImplementedError(f"{target_and_aux_calc} is not implemented")

Expand Down
Loading
Loading