[Refactor] Unify TrainEngine by moving model-specific logic to model layer
Previously, we had two separate train engines:
- `TrainEngine` for regular models
- `VisionComposeTrainEngine` for vision-language models
This duplication led to:
- Code maintenance overhead (242 lines of duplicated logic)
- Tight coupling between engine and model-specific details
- Difficulty in extending to new model types
- **Remove** `VisionComposeTrainEngine` entirely (242 lines deleted)
- **Add** `pre_micro_batch_forward()` and `post_micro_batch_forward()` hooks to `BaseModel`
  - `pre_micro_batch_forward()`: compute data-batch statistics before the forward pass
  - `post_micro_batch_forward()`: aggregate micro-batch results and compute metrics
- **Unify** `TrainEngine` to handle all model types through the new hook system
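The hook pair can be sketched roughly as follows. The method and TypedDict names come from the PR, but the signatures, fields, and default behaviors shown here are assumptions, not the repository's actual code:

```python
# Hedged sketch of the BaseModel hook interface described above.
from typing import Any, Dict, List, TypedDict


class DataBatchInfo(TypedDict):
    """Statistics computed before the forward pass (fields assumed)."""
    step_consumed_tokens: int


class BatchForwardInfo(TypedDict):
    """Aggregated micro-batch results (fields assumed)."""
    total_loss: float
    extra_info: Dict[str, float]


class BaseModel:
    def pre_micro_batch_forward(self, data_batch: Dict[str, Any]) -> DataBatchInfo:
        # Default: count tokens across the variable-length sequences in the batch.
        n_tokens = sum(len(seq) for seq in data_batch["input_ids"])
        return DataBatchInfo(step_consumed_tokens=int(n_tokens))

    def post_micro_batch_forward(
        self, micro_batch_outputs: List[Dict[str, float]]
    ) -> BatchForwardInfo:
        # Default: sum per-micro-batch losses; subclasses add model-specific metrics.
        total = sum(o["loss"] for o in micro_batch_outputs)
        return BatchForwardInfo(total_loss=total, extra_info={})
```

Because both hooks have default implementations, only models with extra statistics (MoE, vision-language) need to override them.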
- **BaseModel**:
  - Add `DataBatchInfo` and `BatchForwardInfo` TypedDicts for the hook return types
  - Implement a default `pre_micro_batch_forward()` that computes token statistics
  - Implement a default `post_micro_batch_forward()` that aggregates losses and extra info
  - Add overload type hints for `__call__` to improve type inference
- **MoE Model**:
  - Override `post_micro_batch_forward()` to handle MoE-specific logic:
    - Compute `maxvio` (maximum load violation) for router load balancing
    - Update router bias based on expert load
  - Add a `need_update_bias` property for cleaner code
  - Scale `balancing_loss` and `z_loss` by `batch_size` correctly
- **ComposeModel**:
  - Override `pre_micro_batch_forward()` to compute image-token statistics
  - Add `ComposeDataBatchInfo` with a `step_consumed_img_tokens` field
- **TrainEngine**:
  - Simplify `train_step()` to delegate statistics gathering to the model hooks
  - Replace `LossLog` and `OtherLog` with a unified `TrainStepInfo`
  - Add `_get_total_loss()` to aggregate all losses (with a TODO for a future refactor)
  - Remove all model-specific branching logic
- **EngineConfig**:
  - Remove the conditional logic for `VisionComposeTrainEngine`
  - Use a single `TrainEngine.build()` path
- **Other**:
  - Update to use `TrainStepInfo` instead of the separate `LossLog` and `OtherLog`
  - Simplify hook signatures (from 2 params to 1)
  - Remove conditional engine instantiation logic
- **Tests**:
  - Replace `VisionComposeTrainEngine` imports with `TrainEngine`
  - Update test assertions to use the new `TrainStepInfo` structure
  - Remove `TypeAdapter` validation for deprecated types
Currently, `TrainEngine._get_total_loss()` aggregates losses by iterating
through ModelOutputs fields. This is pragmatic but not ideal:
- **Pros**: Avoids large-scale changes to model forward() logic
- **Cons**: Engine knows about loss field names (coupling)
- **Future**: Model should return total_loss directly (see TODO comment)
`loss_ctx.batch_size` represents the full gradient accumulation batch size,
not intra_layer_micro_batch. This is correctly set in `CELossContext.build_batches()`
and used for scaling balancing_loss and z_loss.
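As a toy numeric illustration of why the distinction matters (all numbers invented, and the normalization formula is an assumption):

```python
# With 4 micro-batches of 8 samples each, auxiliary losses should be
# normalized by 32 (the full gradient-accumulation batch), not by 8.
micro_batch_size = 8
grad_accum_steps = 4
batch_size = micro_batch_size * grad_accum_steps  # what loss_ctx.batch_size holds

per_micro_batch_z_loss = [0.4, 0.2, 0.3, 0.1]
z_loss = sum(per_micro_batch_z_loss) / batch_size  # divide by 32, not 8
```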
The pre/post hooks provide clean extension points:
- Subclasses can override to add model-specific logic
- Default implementations in BaseModel handle common cases
- No conditional logic needed in engine layer
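A minimal sketch of that extension pattern: the class and field names follow the PR text, but the image-token id, the data layout, and the counting logic are invented for illustration. Note there is nothing engine-side to change:

```python
# Hypothetical subclass override adding one vision-specific statistic.
from typing import Any, Dict, TypedDict


class ComposeDataBatchInfo(TypedDict):
    step_consumed_tokens: int
    step_consumed_img_tokens: int


class ComposeModel:  # would subclass BaseModel in the real code
    IMG_TOKEN_ID = 151655  # placeholder id; model-specific in practice

    def pre_micro_batch_forward(self, data_batch: Dict[str, Any]) -> ComposeDataBatchInfo:
        ids = [t for seq in data_batch["input_ids"] for t in seq]
        return ComposeDataBatchInfo(
            step_consumed_tokens=len(ids),
            step_consumed_img_tokens=sum(t == self.IMG_TOKEN_ID for t in ids),
        )
```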
1. **Code Reduction**: -242 lines (VisionComposeTrainEngine removed)
2. **Better Separation of Concerns**: Engine focuses on training orchestration, models handle their own statistics
3. **Extensibility**: New model types can override hooks without changing engine
4. **Type Safety**: Unified TrainStepInfo with clear field definitions
5. **Maintainability**: Single engine implementation to maintain
- Loss-reduction logic still needs clarification (minor issue; does not affect training)
- TODO added for future refactor: move total_loss aggregation to model layer
- All format issues (extra blank lines, class formatting) fixed
ghstack-source-id: 9ce752d
Pull-Request: #1518