From 49fa7f00a0059cddc6c39aad3bc1bdf3588d7d0d Mon Sep 17 00:00:00 2001
From: Varun Thumbe
Date: Mon, 4 May 2026 15:40:44 +0000
Subject: [PATCH] CPU optimizations for TE autocast

Signed-off-by: Varun Thumbe
---
 transformer_engine/common/recipe/__init__.py | 73 ++++++++++++++--
 transformer_engine/pytorch/quantization.py   | 90 +++++++++++++-------
 2 files changed, 123 insertions(+), 40 deletions(-)

diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py
index 67b6f87067..4f60c97847 100644
--- a/transformer_engine/common/recipe/__init__.py
+++ b/transformer_engine/common/recipe/__init__.py
@@ -60,6 +60,14 @@ class MMParams:
 
     use_split_accumulator: bool = True
 
+    def __repr__(self) -> str:
+        cached = self.__dict__.get("_cached_repr")
+        if cached is not None:
+            return cached
+        result = f"MMParams(use_split_accumulator={self.use_split_accumulator})"
+        object.__setattr__(self, "_cached_repr", result)
+        return result
+
 
 @dataclass(frozen=True)
 class QParams:
@@ -77,13 +85,18 @@ class QParams:
     fp4_2d_quantization: bool = False
 
     def __repr__(self) -> str:
-        return (
+        cached = self.__dict__.get("_cached_repr")
+        if cached is not None:
+            return cached
+        result = (
             f"Qparams(\npower_2_scale={self.power_2_scale},\n"
             f"amax_epsilon={self.amax_epsilon},\n"
             f"random_hadamard_transform={self.random_hadamard_transform},\n"
             f"stochastic_rounding={self.stochastic_rounding},\n"
             f"fp4_2d_quantization={self.fp4_2d_quantization}\n)"
         )
+        object.__setattr__(self, "_cached_repr", result)
+        return result
 
 
 class Recipe:
@@ -91,6 +104,22 @@ class Recipe:
     Base recipe class.
     """
 
+    # Cached string representation. Lazily populated by ``__repr__`` in
+    # subclasses and invalidated by ``__setattr__`` whenever any attribute
+    # changes. This makes repeated ``str(recipe)`` calls (e.g. on the hot
+    # path in ``FP8GlobalStateManager.get_unique_autocast_key``) essentially
+    # free after the first call.
+    _cached_repr: Optional[str] = None
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        # Invalidate the cached repr on any attribute mutation. We avoid
+        # recursion by checking the name and always routing the actual
+        # assignment through ``object.__setattr__`` (which also works for
+        # frozen dataclasses that override ``__setattr__``).
+        if name != "_cached_repr":
+            object.__setattr__(self, "_cached_repr", None)
+        object.__setattr__(self, name, value)
+
     @classmethod
     def nvfp4(cls):
         """Whether the given recipe is NVFP4 1D block scaling."""
@@ -228,7 +257,10 @@ def __post_init__(self) -> None:
         ), "Delayed scaling only supports backward_override=None."
 
     def __repr__(self) -> str:
-        return (
+        cached = self.__dict__.get("_cached_repr")
+        if cached is not None:
+            return cached
+        result = (
             f"recipe_type={self.__class__.__name__}, "
             f"margin={self.margin}, "
             f"format={str(self.fp8_format).split('.')[1]}, "
@@ -238,6 +270,8 @@ def __repr__(self) -> str:
             f"fp8_mha={self.fp8_mha}, "
             f"backward_override={self.backward_override}"
         )
+        object.__setattr__(self, "_cached_repr", result)
+        return result
 
 
 @dataclass()
@@ -276,7 +310,10 @@ def __post_init__(self) -> None:
         ), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'."
 
     def __repr__(self) -> str:
-        return (
+        cached = self.__dict__.get("_cached_repr")
+        if cached is not None:
+            return cached
+        result = (
             f"recipe_type={self.__class__.__name__}, "
             f"format={str(self.fp8_format).split('.')[1]}, "
             f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, "
@@ -289,6 +326,8 @@ def __repr__(self) -> str:
             f"fp8_mha={self.fp8_mha}, "
             f"backward_override={self.backward_override}"
         )
+        object.__setattr__(self, "_cached_repr", result)
+        return result
 
 
 @dataclass()
@@ -334,12 +373,17 @@ def __post_init__(self) -> None:
         ), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'."
 
     def __repr__(self) -> str:
-        return (
+        cached = self.__dict__.get("_cached_repr")
+        if cached is not None:
+            return cached
+        result = (
             f"recipe_type={self.__class__.__name__}, "
             f"margin={self.margin}, "
             f"format={str(self.fp8_format).split('.')[1]}, "
             f"backward_override={self.backward_override}"
         )
+        object.__setattr__(self, "_cached_repr", result)
+        return result
 
 
 @dataclass()
@@ -415,7 +459,10 @@ def __post_init__(self) -> None:
         ), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'."
 
     def __repr__(self) -> str:
-        return (
+        cached = self.__dict__.get("_cached_repr")
+        if cached is not None:
+            return cached
+        result = (
             f"recipe_type={self.__class__.__name__}, "
             f"format={str(self.fp8_format).split('.')[1]}, "
             f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, "
@@ -431,6 +478,8 @@ def __repr__(self) -> str:
             f"fp8_mha={self.fp8_mha}, "
             f"backward_override={self.backward_override}"
         )
+        object.__setattr__(self, "_cached_repr", result)
+        return result
 
 
 @dataclass()
@@ -527,7 +576,10 @@ def __post_init__(self) -> None:
         )
 
     def __repr__(self) -> str:
-        return (
+        cached = self.__dict__.get("_cached_repr")
+        if cached is not None:
+            return cached
+        result = (
             f"recipe_type={self.__class__.__name__}, "
             f"fp4_format={str(self.fp4_format).split('.')[1]}, "
             f"fp8_format={str(self.fp8_format).split('.')[1]}, "
@@ -538,6 +590,8 @@ def __repr__(self) -> str:
             f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, "
             f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, "
         )
+        object.__setattr__(self, "_cached_repr", result)
+        return result
 
 
 @dataclass()
@@ -584,8 +638,13 @@ def __post_init__(self) -> None:
         ), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'."
 
     def __repr__(self) -> str:
-        return (
+        cached = self.__dict__.get("_cached_repr")
+        if cached is not None:
+            return cached
+        result = (
             f"recipe_type={self.__class__.__name__}, "
             f"qfactory={self.qfactory}, "
             f"backward_override={self.backward_override}"
         )
+        object.__setattr__(self, "_cached_repr", result)
+        return result
diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py
index 9956fb77ec..40b9a592b7 100644
--- a/transformer_engine/pytorch/quantization.py
+++ b/transformer_engine/pytorch/quantization.py
@@ -580,9 +580,8 @@ def reduce_and_update_fp8_tensors(
             amax_history, scale, get_fp8_max(recipe, forward), recipe
         )
 
-    @classmethod
+    @staticmethod
     def get_unique_autocast_key(
-        cls,
        recipe: Optional[Recipe] = None,
         group: Optional[dist_group_type] = None,
     ):
@@ -591,7 +590,13 @@ def get_unique_autocast_key(
         Object identity is sufficient since autocast contexts never outlive
         a single training session.
         """
-        return str((str(recipe), id(group) if group is not None else None))
+        # Directly getting the cached repr is about 40 ns faster than str(recipe)
+        # on Grace systems.
+        recipe_repr = recipe.__dict__.get("_cached_repr") if recipe is not None else None
+        if recipe_repr is None:
+            recipe_repr = str(recipe)
+        group_id = id(group) if group is not None else 0
+        return f"{recipe_repr}|{group_id}"
 
     @classmethod
     def autocast_enter(
@@ -805,14 +810,13 @@ def quantized_model_init(
         qstate.high_precision_init_val = _high_precision_init_val
 
 
-@contextmanager
 def fp8_autocast(
     enabled: bool = True,
     calibrating: bool = False,
     fp8_recipe: Optional[Recipe] = None,
     fp8_group: Optional[dist_group_type] = None,
     _graph: bool = False,
-) -> None:
+) -> "autocast":
     """
     .. warning::
@@ -828,25 +832,16 @@ def fp8_autocast(
             stacklevel=2,
         )
 
-    # Call new implementation.
-    with autocast(
+    return autocast(
         enabled=enabled,
         calibrating=calibrating,
         recipe=fp8_recipe,
         amax_reduction_group=fp8_group,
         _graph=_graph,
-    ):
-        yield
+    )
 
 
-@contextmanager
-def autocast(
-    enabled: bool = True,
-    calibrating: bool = False,
-    recipe: Optional["Recipe"] = None,
-    amax_reduction_group: Optional["dist_group_type"] = None,
-    _graph: bool = False,
-) -> None:
+class autocast:
     """
     Context manager for quantization schemes like FP8 or FP4.
 
@@ -885,24 +880,53 @@
     are reduced at the end of each training step.
     """
 
-    if enabled:
-        check_recipe_support(recipe)
+    # Class-based context manager (instead of ``@contextmanager`` from contextlib)
+    # to avoid the ~0.5us / invocation overhead of contextlib's generator-driven
+    # ``GeneratorContextManager``. ``__slots__`` further avoids per-instance
+    # dict allocation.
+    __slots__ = (
+        "_enabled",
+        "_calibrating",
+        "_recipe",
+        "_amax_reduction_group",
+        "_graph",
+        "_fp8_state",
+    )
 
-    # Save current state so we always restore it on exit.
-    fp8_state = FP8GlobalStateManager.get_autocast_state()
+    def __init__(
+        self,
+        enabled: bool = True,
+        calibrating: bool = False,
+        recipe: Optional["Recipe"] = None,
+        amax_reduction_group: Optional["dist_group_type"] = None,
+        _graph: bool = False,
+    ) -> None:
+        self._enabled = enabled
+        self._calibrating = calibrating
+        self._recipe = recipe
+        self._amax_reduction_group = amax_reduction_group
+        self._graph = _graph
+        self._fp8_state = None
+
+    def __enter__(self) -> "autocast":
+        if self._enabled:
+            check_recipe_support(self._recipe)
+        # Save current state so we always restore it on exit.
+        self._fp8_state = FP8GlobalStateManager.get_autocast_state()
+        FP8GlobalStateManager.autocast_enter(
+            enabled=self._enabled,
+            calibrating=self._calibrating,
+            fp8_recipe=self._recipe,
+            fp8_group=self._amax_reduction_group,
+            _graph=self._graph,
+        )
+        return self
 
-    FP8GlobalStateManager.autocast_enter(
-        enabled=enabled,
-        calibrating=calibrating,
-        fp8_recipe=recipe,
-        fp8_group=amax_reduction_group,
-        _graph=_graph,
-    )
-    try:
-        yield
-    finally:
-        FP8GlobalStateManager.set_autocast_state(fp8_state)
-        FP8GlobalStateManager.autocast_exit(enabled, _graph=_graph)
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
+        FP8GlobalStateManager.set_autocast_state(self._fp8_state)
+        FP8GlobalStateManager.autocast_exit(self._enabled, _graph=self._graph)
+        # Do not suppress exceptions.
+        return None
 
 
 def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor:
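Illustration (not part of the diff above): a minimal, standalone sketch of the
memoized-``__repr__``-with-invalidation pattern that the recipe classes adopt in
this patch. The class and field names below are made up for the example; only
the ``_cached_repr`` mechanics mirror the patch.

    class CachedReprBase:
        _cached_repr = None

        def __setattr__(self, name, value):
            # Any mutation other than the cache slot itself invalidates the cache.
            if name != "_cached_repr":
                object.__setattr__(self, "_cached_repr", None)
            object.__setattr__(self, name, value)

    class ExampleRecipe(CachedReprBase):
        def __init__(self, margin=0):
            self.margin = margin

        def __repr__(self):
            cached = self.__dict__.get("_cached_repr")
            if cached is not None:
                return cached
            result = f"recipe_type={type(self).__name__}, margin={self.margin}"
            # Bypass __setattr__ so caching does not invalidate itself.
            object.__setattr__(self, "_cached_repr", result)
            return result

    r = ExampleRecipe()
    assert repr(r) is repr(r)      # second call returns the cached string
    r.margin = 2                   # mutation drops the cache
    assert "margin=2" in repr(r)   # repr is rebuilt with the new value

The same reasoning motivates the class-based ``autocast``: a plain
``__enter__``/``__exit__`` pair on a ``__slots__`` class skips contextlib's
generator machinery on every ``with`` entry.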