diff --git a/transformer_engine/pytorch/tensor/_quantization_helpers.py b/transformer_engine/pytorch/tensor/_quantization_helpers.py index 56cf503630..1b08039dda 100644 --- a/transformer_engine/pytorch/tensor/_quantization_helpers.py +++ b/transformer_engine/pytorch/tensor/_quantization_helpers.py @@ -83,3 +83,42 @@ def _stride_from_shape(shape: list[int]): for d in reversed(shape[1:]): rstride.append(rstride[-1] * d) return list(reversed(rstride)) + + +def safe_quantized_repr(obj, cls_name, extras=None, error=None): + """Metadata-only repr fallback for quantized tensors whose data cannot be + materialized for any reason. + + Each attribute access is guarded so that ``__repr__`` never raises. + + Parameters + ---------- + extras : dict, optional + Additional plain-Python (non-tensor) attributes to include, e.g. + ``{"is_2D_scaled": self._is_2D_scaled}``. Values are inserted after + ``fp8_dtype`` and before ``shape``. + error : BaseException, optional + The exception that triggered the fallback. When given, its type and + message are included in the ``data=`` field so that it is visible *why* + the data could not be materialized. + """ + parts = [] + fp8_dtype = getattr(obj, "_fp8_dtype", None) + if fp8_dtype is not None: + parts.append(f"fp8_dtype={fp8_dtype}") + if extras: + for key, value in extras.items(): + parts.append(f"{key}={value}") + try: + parts.append(f"shape={tuple(obj.shape)}") + except Exception: # pylint: disable=broad-except + pass + try: + parts.append(f"dtype={obj.dtype}") + except Exception: # pylint: disable=broad-except + pass + if error is not None: + parts.append(f"data=") + else: + parts.append("data=") + return f"{cls_name}({', '.join(parts)})" diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index ba46508d74..d2d28aecfb 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -14,7 +14,7 @@ from transformer_engine.common.recipe import Float8BlockScaling, Recipe from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ..quantized_tensor import QuantizedTensor, Quantizer -from ._quantization_helpers import _IdentityFunc +from ._quantization_helpers import _IdentityFunc, safe_quantized_repr from ..constants import DType from ..utils import devices_match, round_up_to_nearest_multiple @@ -267,11 +267,19 @@ def __new__( return instance def __repr__(self, *, tensor_contents=None): - return ( - f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype}," - f" is_2D_scaled={self._is_2D_scaled}," - f" data={self.dequantize()})" - ) + try: + return ( + f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype}," + f" is_2D_scaled={self._is_2D_scaled}," + f" data={self.dequantize()})" + ) + except Exception as exc: # pylint: disable=broad-except + return safe_quantized_repr( + self, + "Float8BlockwiseQTensor", + extras={"is_2D_scaled": self._is_2D_scaled}, + error=exc, + ) def quantize_( self, diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 4de8d82217..17ba87201b 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -18,7 +18,7 @@ from ..utils import canonicalize_process_group, devices_match from .storage.float8_tensor_storage import Float8TensorStorage, _FromFloat8Func from ..quantized_tensor import QuantizedTensor, Quantizer -from ._quantization_helpers import _IdentityFunc +from ._quantization_helpers import _IdentityFunc, safe_quantized_repr from ..constants import dist_group_type, DType aten = torch.ops.aten @@ -412,13 +412,16 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor): """ def __repr__(self, *, tensor_contents=None): - return ( - "Float8Tensor(" - f"fp8_dtype={self._fp8_dtype}, " - f"scale_inv={self._scale_inv.item()}, " - f"data={self.dequantize()}" - ")" - ) + try: + return ( + "Float8Tensor(" + f"fp8_dtype={self._fp8_dtype}, " + f"scale_inv={self._scale_inv.item()}, " + f"data={self.dequantize()}" + ")" + ) + except Exception as exc: # pylint: disable=broad-except + return safe_quantized_repr(self, "Float8Tensor", error=exc) def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """ diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index d759aaf5c4..33db63d059 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -18,7 +18,7 @@ from ..utils import devices_match, round_up_to_nearest_multiple from .storage.mxfp8_tensor_storage import MXFP8TensorStorage, _FromMXFP8Func from ..quantized_tensor import QuantizedTensor, Quantizer -from ._quantization_helpers import _IdentityFunc +from ._quantization_helpers import _IdentityFunc, safe_quantized_repr aten = torch.ops.aten @@ -233,7 +233,10 @@ def __new__( ) def __repr__(self, *, tensor_contents=None): - return f"MXFP8Tensor(fp8_dtype={self._fp8_dtype}, data={self.dequantize()})" + try: + return f"MXFP8Tensor(fp8_dtype={self._fp8_dtype}, data={self.dequantize()})" + except Exception as exc: # pylint: disable=broad-except + return safe_quantized_repr(self, "MXFP8Tensor", error=exc) def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """ diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 5a2765b9f5..f131615e72 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -23,7 +23,7 @@ from .storage.nvfp4_tensor_storage import NVFP4TensorStorage, _FromNVFP4Func from ..quantized_tensor import QuantizedTensor, Quantizer -from ._quantization_helpers import _IdentityFunc +from ._quantization_helpers import _IdentityFunc, safe_quantized_repr aten = torch.ops.aten @@ -398,7 +398,10 @@ def __new__( return instance def __repr__(self, *, tensor_contents=None): - return f"NVFP4Tensor, data={self.dequantize()})" + try: + return f"NVFP4Tensor, data={self.dequantize()})" + except Exception as exc: # pylint: disable=broad-except + return safe_quantized_repr(self, "NVFP4Tensor", error=exc) def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """ diff --git a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py index f7a3dae70b..993ead42ee 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py @@ -12,6 +12,7 @@ import transformer_engine_torch as tex from ...quantized_tensor import QuantizedTensorStorage, Quantizer +from .._quantization_helpers import safe_quantized_repr from ...constants import TE_DType_To_Torch, DType @@ -354,17 +355,25 @@ def _transpose_columnwise_data(self): del _old_data def __repr__(self): - if self._rowwise_data is not None: - data = self.dequantize() - descriptor = "rowwise" - else: - data = self.dequantize() - descriptor = "columnwise" - return ( - "Float8BlockwiseQTensorStorage(" - f"fp8_dtype={self._fp8_dtype}, " - f"{descriptor}_scaled_data={data})" - ) + try: + if self._rowwise_data is not None: + data = self.dequantize() + descriptor = "rowwise" + else: + data = self.dequantize() + descriptor = "columnwise" + return ( + "Float8BlockwiseQTensorStorage(" + f"fp8_dtype={self._fp8_dtype}, " + f"{descriptor}_scaled_data={data})" + ) + except Exception as exc: # pylint: disable=broad-except + return safe_quantized_repr( + self, + "Float8BlockwiseQTensorStorage", + extras={"is_2D_scaled": self._is_2D_scaled}, + error=exc, + ) def update_usage( self, rowwise_usage: Optional[bool] = None, columnwise_usage: Optional[bool] = None diff --git a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py index a97162f91c..374d0e1e72 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py @@ -12,6 +12,7 @@ import transformer_engine_torch as tex from ...quantized_tensor import QuantizedTensorStorage, Quantizer +from .._quantization_helpers import safe_quantized_repr from ...constants import TE_DType as torch_to_transformer_engine_dtype, TE_DType_To_Torch, DType @@ -209,13 +210,16 @@ def view(self, shape: torch.Size): ) def __repr__(self): - return ( - "Float8TensorStorage(" - f"fp8_dtype={self._fp8_dtype}, " - f"scale_inv={self._scale_inv.item()}, " - f"data={self.dequantize()}" - ")" - ) + try: + return ( + "Float8TensorStorage(" + f"fp8_dtype={self._fp8_dtype}, " + f"scale_inv={self._scale_inv.item()}, " + f"data={self.dequantize()}" + ")" + ) + except Exception as exc: # pylint: disable=broad-except + return safe_quantized_repr(self, "Float8TensorStorage", error=exc) def _create_transpose(self): """Update FP8 transpose cache""" diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index ea592cd989..606ac9e74b 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -13,6 +13,7 @@ import transformer_engine_torch as tex from ...quantized_tensor import QuantizedTensorStorage, Quantizer +from .._quantization_helpers import safe_quantized_repr from ...constants import TE_DType as torch_to_transformer_engine_dtype, DType @@ -257,15 +258,18 @@ def view(self, shape: torch.Size): ) def __repr__(self): - data_rowwise = self.dequantize() - - return ( - "MXFP8TensorStorage(" - f"fp8_dtype={self._fp8_dtype}, " - f"rowwise_scaled_data={data_rowwise}" - f"rowwise_scale_inv={self._rowwise_scale_inv}, " - ")" - ) + try: + data_rowwise = self.dequantize() + + return ( + "MXFP8TensorStorage(" + f"fp8_dtype={self._fp8_dtype}, " + f"rowwise_scaled_data={data_rowwise}" + f"rowwise_scale_inv={self._rowwise_scale_inv}, " + ")" + ) + except Exception as exc: # pylint: disable=broad-except + return safe_quantized_repr(self, "MXFP8TensorStorage", error=exc) def update_usage( self, diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index 53bb5e7c11..09f040ba67 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -16,6 +16,7 @@ import transformer_engine_torch as tex from ...quantized_tensor import QuantizedTensorStorage, Quantizer +from .._quantization_helpers import safe_quantized_repr from ...constants import TE_DType as torch_to_transformer_engine_dtype, DType from ...utils import _empty_tensor @@ -340,16 +341,19 @@ def view(self, shape: torch.Size): ) def __repr__(self): - data_rowwise = self.dequantize() - - return ( - "NVFP4TensorStorage(" - f"rowwise_scaled_data={data_rowwise}," - f"rowwise_scale_inv={self._rowwise_scale_inv}," - f"amax_rowwise={self._amax_rowwise}," - f"amax_columnwise={self._amax_columnwise}," - ")" - ) + try: + data_rowwise = self.dequantize() + + return ( + "NVFP4TensorStorage(" + f"rowwise_scaled_data={data_rowwise}," + f"rowwise_scale_inv={self._rowwise_scale_inv}," + f"amax_rowwise={self._amax_rowwise}," + f"amax_columnwise={self._amax_columnwise}," + ")" + ) + except Exception as exc: # pylint: disable=broad-except + return safe_quantized_repr(self, "NVFP4TensorStorage", error=exc) def update_usage( self,