[tx] Introduce optimization step metrics dataclass #1191

pcmoritz merged 3 commits into NovaSky-AI:main
Conversation
/gemini review
Code Review
The pull request introduces a new OptimStepMetrics dataclass to standardize the reporting of optimization metrics in the JAX backend. It also includes a safety fix for gradient averaging and updates tests to reflect these changes. The most significant change is in the optim_step logic, which now proceeds with an optimizer update even when no gradients are accumulated, potentially applying weight decay unexpectedly.
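For orientation, here is a minimal sketch of what such a metrics dataclass could look like. This is not the PR's actual implementation: the field names and the `to_dict()` helper are assumptions modeled on the metric keys visible in the diff snippets below.

```python
# Hypothetical sketch of an optimization-step metrics dataclass (not the PR's code).
from dataclasses import dataclass

import jax


@dataclass
class OptimStepMetrics:
    grad_norm: jax.Array      # global gradient norm after accumulation
    learning_rate: jax.Array  # learning rate used for this step

    def to_dict(self) -> dict[str, float]:
        # Convert device scalars to plain floats under the "skyrl.ai/" namespace,
        # matching the keys shown in the diff below.
        return {
            "skyrl.ai/grad_norm": self.grad_norm.item(),
            "skyrl.ai/learning_rate": self.learning_rate.item(),
        }
```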
| "skyrl.ai/grad_norm": self.grad_norm.item(), | ||
| "skyrl.ai/learning_rate": self.learning_rate.item(), |
```diff
  if self.accumulated_grads.counts[adapter_index] == 0:
-     logger.warning(f"No accumulated gradients for model {model_id}, skipping optimizer step")
-     return types.OptimStepOutput(metrics={"skyrl.ai/learning_rate": learning_rate})
+     logger.warning(f"No accumulated gradients for model {model_id}; applying step with zero gradients")
```
The removal of the early return when counts[adapter_index] == 0 changes the behavior of optim_step. Previously, the step was skipped entirely. Now, the code proceeds to call _compute_grads_and_update, which applies an optimizer update with zero gradients. For optimizers like AdamW, this will still apply weight decay to the parameters, which might be unintended if the user expects the step to be a no-op when no gradients are present. If the goal is to ensure metrics are always returned, consider restoring the early return but returning an OptimStepOutput with the expected metrics.
| logger.warning(f"No accumulated gradients for model {model_id}; applying step with zero gradients") | |
| logger.warning(f"No accumulated gradients for model {model_id}, skipping optimizer step") | |
| return types.OptimStepOutput(metrics={"skyrl.ai/grad_norm": 0.0, "skyrl.ai/learning_rate": learning_rate}) |
I don't think the semantics of an empty gradient step are well specified, but applying an empty optimizer step seems like a reasonable thing to do and is also in line with reporting a zero gradient norm. This case shouldn't be common, and handling everything uniformly (zero gradient norm, apply optimizer step) seems like good semantics to me.
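As a rough sketch of the uniform handling described here, the step could always average the accumulated gradients (guarding the division so an empty accumulator yields zeros rather than NaNs, matching the gradient-averaging safety fix mentioned in the review summary), report the resulting gradient norm (zero in the empty case), and apply the optimizer update unconditionally. The function and variable names below are assumptions for illustration, not the PR's actual code.

```python
# Hypothetical sketch of the uniform optim_step path (not the PR's implementation).
import jax
import optax


def averaged_grads(grad_sums, count):
    # Guard against division by zero: an empty accumulator yields all-zero
    # gradients instead of NaNs.
    return jax.tree_util.tree_map(lambda g: g / max(count, 1), grad_sums)


def optim_step(params, opt_state, optimizer, grad_sums, count, learning_rate):
    grads = averaged_grads(grad_sums, count)
    grad_norm = optax.global_norm(grads)  # zero when nothing was accumulated
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    metrics = {
        "skyrl.ai/grad_norm": float(grad_norm),
        "skyrl.ai/learning_rate": float(learning_rate),
    }
    return params, opt_state, metrics
```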
This is in preparation for merging #1008 and to make it easier to introduce metrics.