Closed
Changes from all commits
259 commits
c259c20
update loss calc config and rename files
Jubeku Nov 7, 2025
a19ee16
restructure loss modules
Jubeku Nov 11, 2025
bf3e128
add ModelOutput dataclass
Jubeku Nov 11, 2025
711f29b
First draft of diffusion model
MatKbauer Nov 11, 2025
81bd6eb
NOT WORKING: initial draft for index-based masking. Implemented for r…
clessig Nov 12, 2025
f367bb4
Minor modifications
MatKbauer Nov 12, 2025
1cc168c
Linter
MatKbauer Nov 12, 2025
48934c2
Copyright attribution to EDM
MatKbauer Nov 12, 2025
51f437f
NOT WORKING: Finished src, target still to be done.
clessig Nov 13, 2025
6046694
Adapt diffusion model to expected data structure
MatKbauer Nov 13, 2025
f66c9fa
Corrected data retrieval to only access model_samples and not target_…
MatKbauer Nov 13, 2025
7e48c39
Minor correction
MatKbauer Nov 13, 2025
e4a9cc0
Masking target is working in principle but errors when feeding data t…
clessig Nov 13, 2025
a581405
Working version for ERA5, NPP-ATMS. Problems with SYNOP with empty ce…
clessig Nov 13, 2025
9229e48
Minor cleanup
clessig Nov 13, 2025
db6f285
Fixed linting
clessig Nov 13, 2025
7866ff7
Restructuring and correcting forward pass during inference
MatKbauer Nov 14, 2025
ec38123
Fixed remaining problems that occured for NPP-ATMS and SYNOP.
clessig Nov 14, 2025
0634105
Enabled support for forecast. Cleaned up some bits and pieces.
clessig Nov 14, 2025
0fa60db
merge develop
Jubeku Nov 14, 2025
cab9fbe
mv streams_data declaration under if condition
Jubeku Nov 14, 2025
20da555
add weight to loss config, add toy loss class LossPhysicalTwo
Jubeku Nov 14, 2025
ce6c735
Removing centroids options for embedding that was unused and should n…
clessig Nov 14, 2025
8fa544d
Removed unused parameters
clessig Nov 14, 2025
d7b326b
fixed trainer for multiple terms in losses_all, still need to fix log…
Jubeku Nov 14, 2025
5d127bf
Inversion of target output ordering to match input one in forcast mod…
clessig Nov 16, 2025
b07aa3f
First steps to encode targets in latent space
MatKbauer Nov 16, 2025
3ffdc60
fix _log_terminal
Jubeku Nov 17, 2025
debbb8f
Changes to prepare_logging to apply index inversion
clessig Nov 17, 2025
ae5a2e6
added file with ModelBatch and SampleMetadata dataclasses
shmh40 Nov 17, 2025
7f3c718
Updating config to working version
clessig Nov 17, 2025
694d948
Encapsulated encoder and target encoding for latent diffusion model loss
MatKbauer Nov 17, 2025
beb4d6f
fix logging
Jubeku Nov 17, 2025
761e263
update ViewMetadata spec
shmh40 Nov 17, 2025
047b299
draft changes to allow global local view generation in masker and tok…
shmh40 Nov 17, 2025
7d5c300
draft of training_config in default_config
shmh40 Nov 17, 2025
c733280
change view_metadata to dict in ModelInput
shmh40 Nov 17, 2025
a934f97
NOT WORKING: updating class to handle multiple input steps and improv…
clessig Nov 18, 2025
ab9eecc
Merge branch 'shmh40/dev/1270-idx-global-local' of github.com:ecmwf/W…
clessig Nov 18, 2025
086aacb
Linter
MatKbauer Nov 18, 2025
c3b5c3b
Added basic support for multi-step sources.
clessig Nov 18, 2025
668912d
Partially enabled correct handling of multiple input steps.
clessig Nov 18, 2025
33394ff
initialize loss as torch tensor with grad
Jubeku Nov 18, 2025
bda52d8
remove level in hist losses dict
Jubeku Nov 18, 2025
053dddd
rename loss.py to loss_functions.py
Jubeku Nov 18, 2025
d094ad0
rename loss.py to loss_functions.py
Jubeku Nov 18, 2025
8b4cbef
return loss with grads seperately to trainer
Jubeku Nov 18, 2025
dd6f85a
Added mode and refactored get_sample_data into separate function.
clessig Nov 18, 2025
d0ef572
modify log names
Jubeku Nov 18, 2025
c6805c4
add loss_functions.py
Jubeku Nov 18, 2025
0ccce9e
merge develop
Jubeku Nov 18, 2025
7ac9e6b
rm loss_fcts in default config
Jubeku Nov 18, 2025
85fa139
Comments
clessig Nov 18, 2025
c1580c4
Renaming
clessig Nov 18, 2025
3c26ddc
updated default config training_config to allow student-teacher
shmh40 Nov 18, 2025
66cf9cd
added stream id to era5 config
shmh40 Nov 18, 2025
36ea287
slight restructure of ViewMetadata
shmh40 Nov 18, 2025
11ad4e6
basic if statement to yield the student and teacher views
shmh40 Nov 18, 2025
b3dfa2f
merge changes
shmh40 Nov 18, 2025
2536cec
correct imports with new batch.py
shmh40 Nov 18, 2025
31dc658
created function for _get_student_teacher_sample_data which returns t…
shmh40 Nov 19, 2025
a824bfc
Not working draft for restructuring
clessig Nov 19, 2025
dfc03f2
Merge branch 'shmh40/dev/1270-idx-global-local' of github.com:ecmwf/W…
clessig Nov 19, 2025
81cf929
Changes for better student teacher structure
clessig Nov 19, 2025
a9a83fd
Merge branch 'mk/develop/implement_diffusion_model_structure' into mk…
MatKbauer Nov 19, 2025
ec26f06
Merge branch 'develop' into mk/develop/implement_diffusion_model_stru…
MatKbauer Nov 19, 2025
46147d4
More refactoring
clessig Nov 19, 2025
8cf90d8
Update to latest develop
MatKbauer Nov 19, 2025
1e70f5c
More refactoring and cleanup
clessig Nov 19, 2025
1235aab
More refactoring. Code working again.
clessig Nov 19, 2025
4613f7a
Cleaned up parametrization
clessig Nov 19, 2025
9fe94f5
Changes necessary for spoofing flag per IOReaderData
clessig Nov 19, 2025
ed26c02
Changes to have spoofing on a per data reader sample
clessig Nov 19, 2025
6d685c0
Moved _get_student_teacher_masks() so that masks are generated for al…
clessig Nov 19, 2025
848880b
Renaming and minor clean up.
clessig Nov 19, 2025
1b1654c
Added basic support for use of ModelBatch class to define rough struc…
clessig Nov 19, 2025
c1d32fb
linting
clessig Nov 20, 2025
6a96065
Linting
clessig Nov 20, 2025
3bca490
linting
clessig Nov 20, 2025
5d5e999
Linting problems but removed unused ViewMetaData dependence
clessig Nov 20, 2025
e8ccb8d
Added required reflexivity between source and target samples to Batch
clessig Nov 20, 2025
d18cf86
Added todo
clessig Nov 20, 2025
a47b6ee
Step 1/3: Merge branch 'mk/develop/implement_diffusion_model_structur…
MatKbauer Nov 20, 2025
374e9cb
Step 2/3: Merge mk/develop/1249_diffusion_model_target into mk/develo…
MatKbauer Nov 20, 2025
cca76e3
Merge branch 'develop' into jk/develop/loss_calc_base
Jubeku Nov 20, 2025
9d0a817
Merge branch 'develop' into mk/develop/1300_assemble_diffusion_model
MatKbauer Nov 20, 2025
3fcb20f
merge loss_calc_base
Jubeku Nov 20, 2025
b2be982
fix typo in ModelBatch
shmh40 Nov 20, 2025
b34b6da
collect num_source_samples and num_target_samples, add loop over teac…
shmh40 Nov 20, 2025
87ad45f
add teacher num_views parameter to config
shmh40 Nov 20, 2025
9b702c5
Re-enabling inversion of targert ordering.
clessig Nov 20, 2025
1806ae5
tidy up, remove unused build_stream_views in tokenizer_masking
shmh40 Nov 20, 2025
647e4b2
multiple idxs for each teacher, need to confirm for not student case,…
shmh40 Nov 20, 2025
91c3d7a
add max_num_targets to era5
shmh40 Nov 21, 2025
1a418bf
add max_num_samples functionality to tokenizer_masking and pass throu…
shmh40 Nov 21, 2025
4df1788
Latent diffusion loss (#1322)
Jubeku Nov 21, 2025
63b2b63
Build latent diffusion forecast engine
MatKbauer Nov 21, 2025
dbffbea
fix training dataflow with diffusion FE
Jubeku Nov 21, 2025
f8c9369
update validation loop
Jubeku Nov 21, 2025
b6c2f7c
Merge branch 'mk/develop/1300_assemble_diffusion_model' of github.com…
MatKbauer Nov 21, 2025
ece1dd0
move build_views_for_stream into masker
shmh40 Nov 21, 2025
b9a60f3
tidy up, remove unused arguments, types
shmh40 Nov 21, 2025
2905cb0
fix masking for NPP-ATMS by correctly selecting final timestep mask a…
shmh40 Nov 22, 2025
af9a3c1
merge with develop, include trainer idx_inv_rt, merged default_config…
shmh40 Nov 24, 2025
b193a50
updated configs so code runs. Note default config to be overhauled still
shmh40 Nov 24, 2025
fa24fc1
very hacky first pass of full masking_strategy_config for the student…
shmh40 Nov 25, 2025
4f8f62b
instructions for sophie
shmh40 Nov 25, 2025
c0df0bf
Issue1279 noise conditioning (#1337)
moritzhauschulz Nov 26, 2025
c27156c
add SampleMetaData integration and functionality, and update masker t…
shmh40 Nov 26, 2025
e0d7346
remove prints, pdb
shmh40 Nov 26, 2025
35352ed
Merge branch 'develop' into mk/develop/1300_assemble_diffusion_model
Jubeku Nov 26, 2025
a09a737
linting
Jubeku Nov 26, 2025
705cb0a
fix ddp
Jubeku Nov 26, 2025
3e989c4
load encoder weights, fixed for multi-gpu
Jubeku Nov 26, 2025
5cb5f05
fix parameter counting in case of diff FE
Jubeku Nov 26, 2025
6d909d6
add mask to SampleMetaData and add forecast_dt to Sample so it is acc…
shmh40 Nov 27, 2025
26f7b5b
add diffusion forecast option for the data sampling, and with noise_l…
shmh40 Nov 27, 2025
b7cfb21
move diff parameters to config, add noise weight calc in latent loss
Jubeku Nov 27, 2025
7311c60
Merge branch 'shmh40/dev/1270-idx-global-local' into mk/develop/1300_…
Jubeku Nov 27, 2025
b47b0fa
Merge branch 'shmh40/dev/1270-idx-global-local' of github.com:ecmwf/W…
clessig Nov 28, 2025
5f803e5
Merge branch 'develop' of github.com:ecmwf/WeatherGenerator into shmh…
clessig Nov 28, 2025
3e4de7a
Linting
clessig Nov 28, 2025
8ef3a4c
Simplified and clarified handling of default target_aux_calcualtor
clessig Nov 28, 2025
d8998a9
Linting
clessig Nov 28, 2025
652500a
Linting
clessig Nov 28, 2025
03166a2
Linting
clessig Nov 28, 2025
e41a575
Linting
clessig Nov 28, 2025
0db8b62
Linting
clessig Nov 28, 2025
47750a5
Restoring masking as training_mode in default_config
clessig Nov 28, 2025
bc8d23e
More linting
clessig Nov 28, 2025
6289959
Removed duplicate lines due to mergeing
clessig Nov 28, 2025
d526dfc
Restored masking as training mode. Not working due to NaN in prediction
clessig Nov 28, 2025
657094a
Fixed problem in engines introduced in recent commits merging develop…
clessig Nov 28, 2025
1a37dd1
remove unused mask generation in diffusion_forecast
shmh40 Nov 28, 2025
3378a67
Merge branch 'shmh40/dev/1270-idx-global-local' into mk/develop/1300_…
Jubeku Nov 28, 2025
0d44f40
remove duplicate key in config
Jubeku Nov 28, 2025
caadb37
add back masking_rate dog
Jubeku Nov 28, 2025
680f577
update config with new training mode
Jubeku Nov 28, 2025
bb71731
forecast_diffusion running with new data batch
Jubeku Nov 28, 2025
6ea07e7
restore masking_strategy to random
shmh40 Nov 28, 2025
4281aff
restore loader_num_workers to 8
shmh40 Nov 28, 2025
950e5b4
set loader_num_workers to 8
Jubeku Nov 28, 2025
15b46e9
fix indentation of else: assert False in _get_sample msds
shmh40 Nov 28, 2025
76270aa
[1269] Noise generation in diffusion inference (#1374)
moritzhauschulz Nov 28, 2025
6fe8561
Merge branch 'develop' of github.com:ecmwf/WeatherGenerator into shmh…
clessig Nov 28, 2025
b662bf2
Made pre-trained encoder weights optional
MatKbauer Nov 28, 2025
3b55ef5
Update validation to new data structure
MatKbauer Dec 2, 2025
dc736e5
merge with dev
tjhunter Dec 2, 2025
2b2c977
linter warnings
tjhunter Dec 2, 2025
c8a2aad
commenting tests
tjhunter Dec 2, 2025
2599ec2
Restructured code so that mask generation and application is cleanly …
clessig Dec 2, 2025
c8a26d7
Commit
clessig Dec 2, 2025
23e0267
Update
clessig Dec 2, 2025
33d9d8d
Merge branch 'shmh40/dev/1270-idx-global-local' of github.com:ecmwf/W…
clessig Dec 2, 2025
9f5e49c
Fixed uv.lock
clessig Dec 2, 2025
3641e1f
Fix for integration test
clessig Dec 2, 2025
9a1a6a9
Re-enabled multi-source training
clessig Dec 3, 2025
402b8de
1390 - Adapt forward pass of new batch object (#1391)
Jubeku Dec 3, 2025
2cd3971
Completed migration to new batch class by removing reference to old l…
clessig Dec 3, 2025
51754fa
Fixed missing non_blocking=True in to_device()
clessig Dec 3, 2025
69b53a6
Removed old comments
clessig Dec 3, 2025
59510dd
Fixed problem with non_blocking=True
clessig Dec 3, 2025
b69b743
Cleaned up comments and return values a bit
clessig Dec 4, 2025
d36367a
Changed args to embedding
clessig Dec 4, 2025
3f52a8d
Changed core functions to take sample as arg
clessig Dec 4, 2025
9065219
Changed that model takes sample as input
clessig Dec 4, 2025
12bae15
Fixes for diffusion
clessig Dec 4, 2025
7745e47
Switched to lists of model / target stratgies
clessig Dec 4, 2025
bf17bfe
Updated config
clessig Dec 4, 2025
89f770e
Changed to per masking strategy loss terms
clessig Dec 5, 2025
a93fdb3
Removed old masking options. Still needs to be fully cleaned up
clessig Dec 5, 2025
454dffb
More robust handling of empty streams
clessig Dec 5, 2025
5cbbaa3
Fixed incorrect handling of empty target_coords_idx
clessig Dec 5, 2025
9c74741
Fixed problem when number of model and target samples is different
clessig Dec 5, 2025
085b55f
Example for config with non-trivial model and target inputs
clessig Dec 5, 2025
4dac76d
Fixed bug in total sample counting
clessig Dec 5, 2025
fe2f63a
Re-enabled missing healpix level
clessig Dec 5, 2025
b9195bb
Fixed incorrect handling of masking and student_teacher modes. Follow…
clessig Dec 6, 2025
43f9b01
An encoder formed by embedding + local assimilation + global assimila…
kctezcan Dec 6, 2025
4d27a95
Formatting
clessig Dec 6, 2025
9cf040e
Fix source-target matching problem.
clessig Dec 6, 2025
5fca790
Enabled multiple input steps. Fixed various robustness that arose thr…
clessig Dec 7, 2025
47e81fa
Linting
clessig Dec 7, 2025
e0f6cc4
Missing update to validation()
clessig Dec 9, 2025
8f097ec
Improved robustness through sanity checking of arguments
clessig Dec 9, 2025
6b64511
Improved handling of corner cases
clessig Dec 9, 2025
ed886e2
Merge branch 'develop' of github.com:ecmwf/WeatherGenerator into shmh…
clessig Dec 9, 2025
303f48a
- Fixed incorrect call to get_forecast_steps() in validation
clessig Dec 9, 2025
9638de8
[NOT WORKING] Merged current data-branch. TargetAuxCalculator argumen…
MatKbauer Dec 9, 2025
7299106
More fixed to validation
clessig Dec 9, 2025
45189a4
Adding stream_id
clessig Dec 9, 2025
50b0a89
[NOT WORKING] Added modifications from data branch
MatKbauer Dec 10, 2025
5bed792
Cleaned up ModelOutput class to have proper access functions and a be…
clessig Dec 10, 2025
06f2e06
Switched to use dict to internally represent streams_datasets
clessig Dec 10, 2025
ad5a19c
Improving robustness of interface of ModelOutput class
clessig Dec 10, 2025
4f8abbb
Re-enabling model output
clessig Dec 10, 2025
d36716c
Ruff
clessig Dec 11, 2025
b8d95b2
Minor clean-ups and additional comments
clessig Dec 11, 2025
081d90a
Minor cleanups
clessig Dec 11, 2025
6b8fe83
Cleaned up handling of masks and masking metadata
clessig Dec 11, 2025
5a8ad49
Resolved bugs when updating data structure
MatKbauer Dec 11, 2025
eedaa8a
Updated to new data output structure
MatKbauer Dec 11, 2025
f768046
Linter
MatKbauer Dec 11, 2025
ca9e605
Current working version of default_config
clessig Dec 11, 2025
f8b1ca6
Fixed problem with branches with old code and incomplete cleanup
clessig Dec 11, 2025
003b0cf
Updated to test convergence of integration test.
clessig Dec 11, 2025
f38e6d2
Updated settings
clessig Dec 11, 2025
7e7ff8e
Clessig/ypd/dev/1353 add tokens latent state finalization (#1452)
clessig Dec 12, 2025
31a0b96
Ruffed
clessig Dec 12, 2025
4fe90d7
Adding sanity check for register tokens
clessig Dec 12, 2025
46bd7a2
Merge branch 'develop' of github.com:ecmwf/WeatherGenerator into shmh…
clessig Dec 12, 2025
48dee1e
Update to latest data branch: latent_state dataclass
MatKbauer Dec 12, 2025
e2c09f2
Update to latest develop with new data structure
MatKbauer Dec 20, 2025
238e321
Merge branch 'develop' into mk/develop/1300_assemble_diffusion_model
Jubeku Jan 14, 2026
458e652
debug target_aux, loss_module, engines, etc
Jubeku Jan 14, 2026
61dce39
debug, diffusion_rn and batch.sample
Jubeku Jan 14, 2026
ea4d76c
Corrected latent token retrieval in loss calculation
MatKbauer Jan 15, 2026
b875734
working training loop on single sample
Jubeku Jan 15, 2026
c91d5c9
update config to fit forecast checkpoint
Jubeku Jan 15, 2026
3a8fead
Merge branch 'develop' into mk/develop/1300_assemble_diffusion_model
Jubeku Jan 19, 2026
d58032d
Merge branch 'develop' into mk/develop/1300_assemble_diffusion_model
Jubeku Jan 19, 2026
91d633b
reset default config
Jubeku Jan 19, 2026
bbdb3a1
modify default config for diffusion
Jubeku Jan 19, 2026
43b21c4
adding encoder loading to model interface
Jubeku Jan 19, 2026
52b6bb1
setting checkpoint to null temporarily
Jubeku Jan 20, 2026
0f7d4e5
rm activation checkpoint around diff forecast engine
Jubeku Jan 20, 2026
a51f706
[Diff] sbAsma/issue1279 noise conditioning (#1358)
sbAsma Jan 23, 2026
47566be
Correct forecast engine initialization
MatKbauer Jan 23, 2026
82a78f9
Merge branch 'develop' into 1300_assemble_diffusion_model_w_develop
moritzhauschulz Feb 8, 2026
3ce80f0
code runs...
moritzhauschulz Feb 8, 2026
a144867
remove some debugging code
moritzhauschulz Feb 18, 2026
e5cccbe
Merge branch 'develop' into mh/develop/1843_viz_denoised_image
moritzhauschulz Feb 18, 2026
63b3f78
adjusted diffusion config
moritzhauschulz Feb 18, 2026
83bb4c9
fixed inference
moritzhauschulz Feb 18, 2026
bb3bbe5
actually fiex inference (via config)
moritzhauschulz Feb 18, 2026
b5ee071
Plot maps during training at validation time
MatKbauer Feb 19, 2026
55b69c2
Intermediate state. Single sample overfitting works
MatKbauer Feb 20, 2026
a93b978
Intermediate multi-GPU error state
MatKbauer Feb 20, 2026
be6cb24
Successful single-sample overfitting on one GPU
MatKbauer Feb 20, 2026
2c63c7e
Minor config change
MatKbauer Feb 20, 2026
4414fe6
Adding missing reset() function for FSDP
clessig Feb 21, 2026
c917777
Linting
clessig Feb 21, 2026
4ae7c13
Linting
clessig Feb 21, 2026
268d34f
Linting
clessig Feb 21, 2026
6a487d9
Workding on FSDP
clessig Feb 21, 2026
351e8f9
Working on FSDP
clessig Feb 21, 2026
fbc7cd1
Linting
clessig Feb 21, 2026
7149866
Activating diffusion model
MatKbauer Feb 24, 2026
dbefecc
Merge branch 'mk/mh/1843_viz_denoised_image' of github.com:ecmwf/Weat…
MatKbauer Feb 24, 2026
72fb4ac
Mixture of physical and latent loss
MatKbauer Feb 27, 2026
40 changes: 40 additions & 0 deletions NOTICE
@@ -1,3 +1,43 @@
=======================================================================
NVLABS/EDM (Elucidating the Design of Diffusion Models)

This software incorporates code from the 'edm' repository.

Original Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

The source code is available at:
https://github.com/NVlabs/edm

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0

=======================================================================
google-deepmind/graphcast (several associated papers)

This software incorporates code from the 'google-deepmind/graphcast' repository, with adaptations.

Original Copyright 2024 DeepMind Technologies Limited.

The source code is available at:
https://github.com/google-deepmind/graphcast

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0

=======================================================================
facebookresearch/DiT (Scalable Diffusion Models with Transformers (DiT))

This software incorporates code from the 'facebookresearch/DiT' repository, with adaptations.

The source code is available at:
https://github.com/facebookresearch/DiT

The code and model weights are licensed under CC-BY-NC.
See https://raw.githubusercontent.com/facebookresearch/DiT/refs/heads/main/LICENSE.txt for details.
This project includes code derived from project "DINOv2: Learning Robust Visual Features without Supervision",
originally developed by Meta Platforms, Inc. and affiliates,
licensed under the Apache License, Version 2.0.
282 changes: 282 additions & 0 deletions config/config_diffusion.yml
@@ -0,0 +1,282 @@
# (C) Copyright 2025 WeatherGenerator contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

embed_orientation: "channels"
embed_unembed_mode: "block"
embed_dropout_rate: 0.1

ae_local_dim_embed: 2048
ae_local_num_blocks: 0
ae_local_num_heads: 16
ae_local_dropout_rate: 0.1
ae_local_with_qk_lnorm: True

ae_local_num_queries: 1
ae_local_queries_per_cell: False
ae_adapter_num_heads: 16
ae_adapter_embed: 128
ae_adapter_with_qk_lnorm: True
ae_adapter_with_residual: True
ae_adapter_dropout_rate: 0.1

ae_global_dim_embed: 2048
ae_global_num_blocks: 4
ae_global_num_heads: 32
ae_global_dropout_rate: 0.1
ae_global_with_qk_lnorm: True
# TODO: switching to < 1 triggers triton-related issues.
# See https://github.com/ecmwf/WeatherGenerator/issues/1050
ae_global_att_dense_rate: 1.0
ae_global_block_factor: 64
ae_global_mlp_hidden_factor: 2
ae_global_trailing_layer_norm: False

ae_aggregation_num_blocks: 0
ae_aggregation_num_heads: 32
ae_aggregation_dropout_rate: 0.1
ae_aggregation_with_qk_lnorm: True
ae_aggregation_att_dense_rate: 1.0
ae_aggregation_block_factor: 64
ae_aggregation_mlp_hidden_factor: 2

decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear
pred_adapter_kv: False
pred_self_attention: True
pred_dyadic_dims: False
pred_mlp_adaln: True
num_class_tokens: 0
num_register_tokens: 0

# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then
# one is training an auto-encoder
fe_num_blocks: 2
fe_num_heads: 16
fe_dropout_rate: 0.1
fe_with_qk_lnorm: True
fe_diffusion_model: True
fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer
fe_impute_latent_noise_std: 0.0 # 1e-4
# currently fixed to 1.0 (due to limitations with flex_attention and triton)
forecast_att_dense_rate: 1.0
with_step_conditioning: True # False
# Diffusion related parameters
frequency_embedding_dim: 256
embedding_dim: 512
sigma_min: 0.002
sigma_max: 50000
sigma_data: 0.5
rho: 7
p_mean: 0.0 # -1.2
p_std: 1.2 # 1.2
# Encoder weights (set to null to not load a pretrained encoder)
# chkpt_encoder_weights: "./models/dhb9q2yo/dhb9q2yo_chkpt00126.chkpt"
chkpt_encoder_weights: "dhb9q2yo"
chkpt_encoder_mini_epoch: 126
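The sigma_min/sigma_max/rho/p_mean/p_std fields above follow the EDM (Karras et al.) convention credited in the NOTICE file. As a hedged sketch (function names are illustrative, not the repository's API): rho interpolates the sampling noise schedule in sigma^(1/rho) space, and p_mean/p_std parameterize a log-normal training noise distribution.

```python
import math
import random

def karras_sigma_schedule(num_steps, sigma_min=0.002, sigma_max=50000.0, rho=7.0):
    """EDM sampling schedule: interpolate linearly in sigma^(1/rho) space."""
    lo, hi = sigma_min ** (1 / rho), sigma_max ** (1 / rho)
    return [
        (hi + i / max(num_steps - 1, 1) * (lo - hi)) ** rho
        for i in range(num_steps)
    ]

def sample_training_sigma(p_mean=0.0, p_std=1.2):
    """Training-time noise level: sigma = exp(p_mean + p_std * N(0, 1))."""
    return math.exp(p_mean + p_std * random.gauss(0.0, 1.0))
```

With the config values, the schedule runs from sigma_max=50000 down to sigma_min=0.002 over the chosen number of denoising steps.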

healpix_level: 5

with_mixed_precision: True
with_flash_attention: True
compile_model: False
with_fsdp: True
attention_dtype: bf16
mixed_precision_dtype: bf16
mlp_norm_eps: 1e-5
norm_eps: 1e-4

latent_noise_kl_weight: 0.0 # 1e-5
latent_noise_gamma: 2.0
latent_noise_saturate_encodings: 5
latent_noise_use_additive_noise: False
latent_noise_deterministic_latents: True


freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*"
# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*"
load_chkpt: {}
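freeze_modules is a regular expression matched against module/parameter names. A minimal sketch of how such a pattern could be applied (the helper name and matching details are assumptions, not the repository's actual mechanism):

```python
import re

def freeze_matching(named_parameters, pattern):
    """Disable gradients for every parameter whose name matches the regex."""
    rx = re.compile(pattern)
    frozen = []
    for name, param in named_parameters:
        if rx.match(name):
            param.requires_grad = False  # param would be a torch.nn.Parameter
            frozen.append(name)
    return frozen
```

With the pattern above, the pretrained encoder path (embedding, local and global assimilation engines, prediction heads) stays fixed while the diffusion forecast engine trains.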

norm_type: "LayerNorm"

#####################################

streams_directory: "./config/streams/era5_1deg/"
streams: ???

# type of zarr_store
zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore

general:

  # mutable parameters
  istep: 0
  rank: ???
  world_size: ???

  # local_rank,
  # with_ddp,
  # data_path_*,
  # model_path,
  # run_path,
  # path_shared_

  multiprocessing_method: "fork"

  desc: ""
  run_id: ???
  run_history: []

  # logging frequency in the training loop (in number of batches)
  train_log_freq:
    terminal: 10
    metrics: 20
    checkpoint: 250

# parameters for data loading
data_loading :

  num_workers: 12
  rng_seed: ???
  repeat_data_in_mini_epoch : True

  # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with
  # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error.
  # If this happens, you can disable the flag, but performance will drop on GH200.
  memory_pinning: True


# config for training
training_config:

  # training_mode: "masking", "student_teacher", "latent_loss"
  training_mode: ["masking", "student_teacher"] # ["student_teacher", "physical_loss"]

  num_mini_epochs: 150
  samples_per_mini_epoch: 66
  shuffle: True

  start_date: 2012-06-01T00:00
  end_date: 2012-06-01T18:00

  time_window_step: 06:00:00
  time_window_len: 06:00:00

  learning_rate_scheduling :
    lr_start: 5e-5 # 1e-6?
    lr_max: 1e-4 # 5e-5?
    lr_final_decay: 1e-6
    lr_final: 0.0
    num_steps_warmup: 64
    num_steps_cooldown: 512
    policy_warmup: "cosine"
    policy_decay: "constant"
    policy_cooldown: "linear"
    parallel_scaling_policy: "sqrt"
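A hedged sketch of the piecewise schedule these fields suggest: cosine warmup from lr_start to lr_max, a constant plateau, then a linear cooldown to lr_final. The plateau length and function name are assumptions; the repository's scheduler may differ.

```python
import math

def lr_at_step(step, lr_start=5e-5, lr_max=1e-4, lr_final=0.0,
               num_warmup=64, num_plateau=1024, num_cooldown=512):
    """Cosine warmup -> constant plateau -> linear cooldown."""
    if step < num_warmup:
        # cosine ramp from lr_start up to lr_max
        t = step / num_warmup
        return lr_start + (lr_max - lr_start) * 0.5 * (1.0 - math.cos(math.pi * t))
    if step < num_warmup + num_plateau:
        return lr_max  # policy_decay: "constant"
    # policy_cooldown: "linear" ramp down to lr_final
    t = min((step - num_warmup - num_plateau) / num_cooldown, 1.0)
    return lr_max + (lr_final - lr_max) * t
```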

  optimizer:
    grad_clip: 1.0
    weight_decay: 0.1
    log_grad_norms: False
    adamw :
      # parameters are scaled by number of DDP workers
      beta1 : 0.975
      beta2 : 0.9875
      eps : 2e-08

  losses : {
    "physical": {
      type: LossPhysical,
      weight: 0.1,
      loss_fcts: {
        "mse": {},
      },
      target_and_aux_calc: "Physical",
    },
    "latent_diff": {
      type: LossLatentDiffusion,
      weight: 0.9,
      target_and_aux_calc: DiffusionLatentTargetEncoder,
      loss_fcts: { "mse": { }, },
    }
  }
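The two entries above mix a physical-space reconstruction loss (weight 0.1) with the latent diffusion loss (weight 0.9). The weight fields imply a weighted sum along these lines (a sketch, not the trainer's actual aggregation code):

```python
def combine_losses(loss_values, loss_config):
    """Total loss: sum over named losses of weight * value."""
    return sum(loss_config[name]["weight"] * value
               for name, value in loss_values.items())
```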

  model_input: {
    "forecasting" : {
      # masking strategy: "random", "healpix", "forecast"
      masking_strategy: "forecast",
      masking_strategy_config: {diffusion_rn: True},
      num_samples: 3
    }
  }

  forecast :
    time_step: 06:00:00
    num_steps: 1
    offset: 0
    policy: "fixed"


# validation config; full validation config is merge of training and validation config
validation_config:

  samples_per_mini_epoch: 16
  shuffle: False

  start_date: 2012-06-01T00:00
  end_date: 2012-06-01T18:00

  # whether to track the exponential moving average of weights for validation
  validate_with_ema:
    enabled : True
    ema_ramp_up_ratio: 0.09
    ema_halflife_in_thousands: 1e-3
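EMA validation keeps a shadow copy of the weights. Parameterizing the per-step decay by a halflife, as ema_halflife_in_thousands suggests, could look like this sketch (names are illustrative, not the repository's API):

```python
def ema_decay_from_halflife(halflife_steps):
    """Per-step decay so an old contribution halves every halflife_steps."""
    return 0.5 ** (1.0 / halflife_steps)

def ema_update(ema_value, new_value, decay):
    """Standard EMA update: ema <- decay * ema + (1 - decay) * new."""
    return decay * ema_value + (1.0 - decay) * new_value
```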

  # parameters for validation samples that are written to disk
  output : {
    # number of samples that are written
    num_samples: 1,
    # write samples in normalized model space
    normalized_samples: False,
    # output streams to write; default all
    streams: null,
  }

  # run validation before training starts (mainly for model development)
  validate_before_training: True


# test config; full test config is merge of validation and test config
# test config is used by default when running inference

# Tags for experiment tracking
# These tags will be logged in MLFlow along with completed runs for train, eval, val
# The tags are free-form, with the following rules:
# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries
# - tags should not duplicate existing config entries.
# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags
# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future)
wgtags:
  # The name of the organization of the person running the experiment.
  # This may be autofilled in the future. Expected values are lowercase strings
  # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience"
  org: null
  # The Github issue corresponding to this run (number such as 1234)
  # Github issues are the central point when running experiments and contain
  # links to hedgedocs, code branches, pull requests etc.
  # It is recommended to associate a run with a Github issue.
  issue: null
  # The name of the experiment. This is a distinctive codename for the experiment campaign being run.
  # This is expected to be the primary tag for comparing experiments in MLFlow, along with the
  # issue number.
  # Expected values are lowercase strings with no spaces, just underscores:
  # Examples: "rollout_ablation_grid"
  exp: null
  # *** Experiment-specific tags ***
  # All extra tags (including lists, dictionaries, etc.) are treated
  # as strings by mlflow, so treat all extra tags as simple string key: value pairs.
  grid: null
3 changes: 2 additions & 1 deletion config/default_config.yml
@@ -59,6 +59,7 @@ fe_num_blocks: 6
fe_num_heads: 16
fe_dropout_rate: 0.1
fe_with_qk_lnorm: True
fe_diffusion_model: False
fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer
fe_impute_latent_noise_std: 0.0 # 1e-4
# currently fixed to 1.0 (due to limitations with flex_attention and triton)
@@ -255,4 +256,4 @@ wgtags:
# *** Experiment-specific tags ***
# All extra tags (including lists, dictionaries, etc.) are treated
# as strings by mlflow, so treat all extra tags as simple string key: value pairs.
grid: null
grid: null