
Comments

Sophiex/dev/pretrained frozen teacher#1824

Closed
sophie-xhonneux wants to merge 42 commits into develop from sophiex/dev/pretrained-frozen-teacher

Conversation

@sophie-xhonneux (Contributor) commented Feb 6, 2026

Description

The goal is to train against a frozen pre-trained teacher (e.g. one pre-trained with MAE).

Issue Number

Closes #1815

Is this PR a draft? Mark it as draft.

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a HedgeDoc in the GitHub issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the Mattermost channels and/or a design doc
    • for changes of dependencies: the Mattermost software development channel

Sophie Xhonneux and others added 27 commits February 2, 2026 21:53
Fix
The issue is we're passing cf.training_config (the current training config), but the teacher model's latent heads are defined by teacher_config. We need to pass the teacher's training config so the postprocessing keys match the teacher model's outputs.
The fix now:
  1. FrozenTeacher inspects the teacher model's actual latent_heads attribute to determine what postprocessing is needed
  2. Sets up JEPA/DINO/iBOT postprocessing based on which heads exist (using the identity transform for all, with warnings for DINO/iBOT since full centering isn't supported for frozen teachers)
  3. Tests updated to use models with latent_heads attributes
Summary of Changes

  Key insight from your feedback: The frozen teacher may have been pre-trained with any method
  (forecasting, MAE, etc.) and doesn't need to have SSL latent heads. We should:
  1. Use the student's training config to know which SSL losses are needed
  2. Add identity heads (LatentPredictionHeadIdentity) to the teacher if they don't exist
  3. Use identity postprocessing (JEPATargetProcessing) for all SSL losses

  Changes Made

  src/weathergen/train/target_and_aux_ssl_teacher.py:
  - Added import for LatentPredictionHeadIdentity
  - Rewrote FrozenTeacher.__init__ to:
    - Accept training_cfg (the student's config) to determine required SSL heads
    - Call _get_required_ssl_heads() to extract loss names from config
    - Call _ensure_identity_heads() to add missing heads to the teacher model
    - Set up identity postprocessing for all SSL losses
  - Added _get_required_ssl_heads(): extracts SSL loss names from training config, defaults to {"JEPA"} if
  none found
  - Added _ensure_identity_heads(): adds LatentPredictionHeadIdentity for any missing heads
  - Updated from_pretrained() to pass cf.training_config to constructor

  tests/test_encoder_teacher.py:
  - Added model_without_latent_heads fixture (simulates a forecasting-only teacher)
  - Added 5 new tests:
    - test_frozen_teacher_adds_identity_heads_when_missing
    - test_frozen_teacher_uses_training_cfg_for_heads
    - test_frozen_teacher_defaults_to_jepa_without_config
    - test_frozen_teacher_preserves_existing_heads
    - test_frozen_teacher_all_postprocessing_is_identity
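In the same spirit as the new tests, a minimal stand-alone version of the identity-head check might look like this (the classes below are simplified stubs, not the real weathergen types):

```python
class LatentPredictionHeadIdentity:
    """Stub of an identity latent head: returns its input unchanged."""

    def __call__(self, x):
        return x


def ensure_identity_heads(latent_heads: dict, required: set) -> dict:
    """Add an identity head for every required SSL loss the teacher lacks."""
    for name in required:
        latent_heads.setdefault(name, LatentPredictionHeadIdentity())
    return latent_heads


def test_adds_identity_heads_when_missing():
    heads = ensure_identity_heads({}, {"JEPA", "DINO"})
    assert set(heads) == {"JEPA", "DINO"}
    assert heads["JEPA"]([1, 2]) == [1, 2]  # identity behaviour
```

Note that setdefault leaves any pre-existing head untouched, which is the behaviour test_frozen_teacher_preserves_existing_heads appears to check.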
@clessig (Collaborator) left a comment:

Some high-level comments. The PR should be split up into the three independent contributions it currently contains.

from weathergen.model.norms import AdaLayerNorm, RMSNorm


class LayerScale(nn.Module):

Can we put this into a separate PR? It's not related to the frozen teacher.
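For context, LayerScale (introduced in the CaiT paper) scales each residual branch by a small learnable per-channel vector so the branch starts near an identity mapping. A minimal sketch of the pattern, independent of this PR's actual code:

```python
import torch
import torch.nn as nn


class LayerScale(nn.Module):
    """Per-channel learnable scaling of a residual branch output."""

    def __init__(self, dim: int, init_value: float = 1e-5):
        super().__init__()
        # small init so the branch contributes little at the start of training
        self.gamma = nn.Parameter(init_value * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.gamma
```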

return x * self.gamma


class StochasticDepth(nn.Module):

Can we put this into a separate PR? It's not related to the frozen teacher.
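Likewise for stochastic depth: during training, the entire residual branch is dropped per sample with probability p, and kept branches are rescaled to preserve the expected magnitude. A hedged sketch of the standard technique (not this PR's code):

```python
import torch
import torch.nn as nn


class StochasticDepth(nn.Module):
    """Drops an entire residual branch per sample with probability p (train only)."""

    def __init__(self, p: float = 0.1):
        super().__init__()
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.p == 0.0:
            return x
        keep = 1.0 - self.p
        # one Bernoulli draw per sample, broadcast over the remaining dims;
        # dividing by keep preserves the expected activation magnitude
        mask = x.new_empty((x.shape[0],) + (1,) * (x.ndim - 1)).bernoulli_(keep)
        return x * mask / keep
```

In eval mode the module is an exact identity, so inference is deterministic.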


return self.lr

def _set_param_group_lrs(self, base_lr: float):

Can we put this into a separate PR? It's not related to the frozen teacher.

lr_multiplier = g.get("lr_multiplier", 1.0)
g["lr"] = base_lr * lr_multiplier

def _apply_lr_multipliers(self):

Can we put this into a separate PR? It's not related to the frozen teacher.
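The param-group LR logic under discussion boils down to something like this free-function sketch (the real method lives on the trainer/optimizer class; names here are illustrative):

```python
def set_param_group_lrs(param_groups: list[dict], base_lr: float) -> None:
    """Scale each optimizer param group's lr by its optional multiplier."""
    for group in param_groups:
        # groups without an explicit multiplier use the base learning rate
        group["lr"] = base_lr * group.get("lr_multiplier", 1.0)
```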


# Initialize collapse monitor for SSL training
collapse_config = self.training_cfg.get("collapse_monitoring", {})
self.collapse_monitor = CollapseMonitor(collapse_config, None) # device set later in run()

devices[9] is available here

self.ema_model.update(self.cf.general.istep * batch_size_total, batch_size_total)

# Compute collapse monitoring metrics
if self.collapse_monitor.should_compute(self.cf.general.istep):

Move the if statement into the function that you call; it's better encapsulation.

if bidx % self.train_log_freq.metrics == 0:
self._log(TRAIN)
# Log collapse metrics
if self.collapse_monitor.should_log(self.cf.general.istep):

Move the if statement into the function that you call; it's better encapsulation.
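The suggested encapsulation would fold the frequency check into the monitor itself, so each call site shrinks to a single line. A sketch under that assumption (class and method names are hypothetical, not the PR's API):

```python
class CollapseMonitor:
    """Tracks representation-collapse metrics during SSL training."""

    def __init__(self, every_n_steps: int = 100):
        self.every_n_steps = every_n_steps
        self.last_computed_step = None

    def maybe_compute(self, istep: int, preds=None) -> bool:
        # the frequency check lives inside the method, per the review comment
        if istep % self.every_n_steps != 0:
            return False
        # ... compute collapse metrics from preds here ...
        self.last_computed_step = istep
        return True
```

The trainer then calls `self.collapse_monitor.maybe_compute(istep, preds)` unconditionally, with no `should_compute` branching in trainer.py.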

is_rank_zero = is_root()

# Handle CompositeOptimizer (Muon+AdamW) separately
if isinstance(self.optimizer, CompositeOptimizer):

If we have this encapsulation, we should also encapsulate the part for native AdamW into a separate function.
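Following that comment, both branches could become symmetric helpers behind a single dispatcher. All names below are hypothetical stand-ins for illustration, not the PR's actual API:

```python
class CompositeOptimizer:
    """Stand-in for the PR's Muon+AdamW wrapper."""

    def __init__(self, optimizers: list):
        self.optimizers = optimizers


def _composite_state_dict(opt: CompositeOptimizer, is_rank_zero: bool) -> dict:
    if not is_rank_zero:
        return {}
    return {"composite": [getattr(o, "state", {}) for o in opt.optimizers]}


def _adamw_state_dict(opt, is_rank_zero: bool) -> dict:
    return {"adamw": getattr(opt, "state", {})} if is_rank_zero else {}


def get_optimizer_state_dict(opt, is_rank_zero: bool) -> dict:
    # single dispatch point; each optimizer type gets its own helper
    if isinstance(opt, CompositeOptimizer):
        return _composite_state_dict(opt, is_rank_zero)
    return _adamw_state_dict(opt, is_rank_zero)
```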

else:
return {}

def _get_full_composite_optimizer_state_dict(self, is_rank_zero: bool):

This should go to optimizer.py


self.t_start = time.time()

def _compute_collapse_metrics(self, preds, targets_and_auxs) -> None:

This should not be in trainer.py


I think it would be best to have this as a separate class and then we could also split up the function.

Sophie Xhonneux and others added 11 commits February 16, 2026 15:14
* implemented

* remove eval in interface

* lint

* incorporate requested changes

* fix imports

* Fix corner case in inference where data window is empty

* Fix missing handling of missing load_chkpt argument in config

---------

Co-authored-by: moritzhauschulz <moritz.hauschulz@gmail.com>
…but also as entrypoints. (#1778)

* Clean up docstrings, separate cli parsing from running.

* remove unused argument stream_dir

* separate parser instantiation from adding args

* add unified parser with subparsers

* implement main function in run_train using subparsers.

* update integration tests

* remove redundant methods *_from_args (previously used by integration tests)

* Move entrypoints to the top of run_train.py

* fix typo in small_multi_stream_test.infer_multi_stream

* fix formatting

* Organize strings into enum.

* fix parser
* Implement best effort backward compatibility.

* use new `data_pathes` option to look up training data.

* fix integration tests

* linting

* correct spelling in config.py

* correct spelling in multi_stream_data_sampler.py

* fix typo "data_path_anmoi" -> "data_path_anemoi" in config.py

* Update test_config.py

* Add suggested comment.
* Fixed most parts of plot_train. Currently missing: handling of stage_configs when these are derived from an earlier stage.

* Removed outdated or unsupported options

* Fixed final problems with consolidated training/validation config. Required to move Stage to a more appropriate place

* Removed old, unused code
* nse_metric

* length

---------

Co-authored-by: Jesica Pinyon Rodriguez <jpinyonr@login07.leonardo.local>
@github-actions bot added the data, eval, infra, and model labels on Feb 19, 2026

Labels

data: Anything related to the datasets used in the project
eval: anything related to the model evaluation pipeline
infra: Issues related to infrastructure
model: Related to model training or definition (not generic infra)

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

Frozen pre-trained teacher

5 participants