Skip to content

Comments

schedule free adamw jax update with switching params to x for validation, y for training#16

Open
wyfEmma wants to merge 4 commits intomlcommons:mainfrom
wyfEmma:wyf_schedule_free
Open

schedule free adamw jax update with switching params to x for validation, y for training#16
wyfEmma wants to merge 4 commits intomlcommons:mainfrom
wyfEmma:wyf_schedule_free

Conversation

@wyfEmma
Copy link

@wyfEmma wyfEmma commented Feb 4, 2026

New Submission

Submission Information

Please fill out the following information about your submission within the quotation marks.

submission_name: "Schedule-Free AdamW fix"  # As it will appear on the leaderboard
submission_folder: "schedule_free"  # Name of folder within `submissions/external_tuning/` or `submissions/self_tuning/` (lowercase, no spaces)
authors: "Yifan (Emma)"  # List authors separated by commas
affiliations: "Google"  # List all affiliations of the authors, separated by commas
version: "1.0"  # Optional version number of your submission
ruleset: "self-tuning"  # Either "external" or "self-tuning"
framework: "JAX"  # Either "PyTorch" or "JAX"
description: "schedule free adamw jax update with switching params to x for validation, y for training to match the algorithm, this fix is based off of init-22's pr"  # A short, high-level description of the algorithm

Evidence for the Submission's Performance

If possible provide some evidence of your submission's performance. E.g. a link to a paper/pre-print, training logs, screenshots, etc. The working group will prioritize evaluating submissions with more convincing evidence.

Comments

Feel free to add any comments, descriptions, or questions here.

@wyfEmma wyfEmma requested a review from a team as a code owner February 4, 2026 04:19
@github-actions
Copy link

github-actions bot commented Feb 4, 2026

MLCommons CLA bot All contributors have signed the MLCommons CLA ✍️ ✅

(state, _), opt_update_fn = optimizer_state

# Calculate x = (y - (1 - b1) * z) / b1
params_for_eval = schedule_free_eval_params(state, current_param_container) # (current_param_container - (1 - state.b1) * state.z) / state.b1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is this implemented?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is implemented in the optax library for schedule free algorithm,
def schedule_free_eval_params(state: ScheduleFreeState, params: base.Params):
"""Params for evaluation of :func:optax.contrib.schedule_free."""
return jax.tree_util.tree_map(
lambda yi, zi: (yi - (1.0 - state.b1) * zi) / state.b1, params, state.z
)

lambda x_leaf, z_leaf: (1 - beta1) * z_leaf + beta1 * x_leaf,
current_param_container, # x
z
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we set this back to False after updating?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nvm, got it now

z
)

# Set up mesh and sharding
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move 190-219 out of this function

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants