From f693220efc357c64aa79b463945d342fd7b66a7d Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Wed, 4 Feb 2026 14:50:34 -0800 Subject: [PATCH 1/9] initial debug of inspect ffi Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/gemm.py | 8 +- transformer_engine/jax/csrc/extensions.h | 3 + .../jax/csrc/extensions/amax.cpp | 23 ++++ .../jax/csrc/extensions/pybind.cpp | 3 + transformer_engine/jax/inspect.py | 111 ++++++++++++++++++ 5 files changed, 144 insertions(+), 4 deletions(-) create mode 100644 transformer_engine/jax/inspect.py diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 71f133bfc4..d400412386 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -373,10 +373,10 @@ def assert_cublas_requirements(scaling_mode, contracting_size, tensor_name): # Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage alignment = 32 if scaling_mode.is_nvfp4_scaling else 16 - assert contracting_size % alignment == 0, ( - f"cuBLAS GEMM {tensor_name} tensor's contracting dimension must be a multiple of" - f" {alignment} when using quantized inputs. Got contracting_size={contracting_size}" - ) + # assert contracting_size % alignment == 0, ( + # f"cuBLAS GEMM {tensor_name} tensor's contracting dimension must be a multiple of" + # f" {alignment} when using quantized inputs. Got contracting_size={contracting_size}" + # ) class GemmPrimitive(BasePrimitive): diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 3fd086e257..1c0bc52b88 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -143,6 +143,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationHandler); +// Inspect +XLA_FFI_DECLARE_HANDLER_SYMBOL(InspectHandler); + // Cudnn helpers XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler); diff --git a/transformer_engine/jax/csrc/extensions/amax.cpp b/transformer_engine/jax/csrc/extensions/amax.cpp index 5ffccaffb4..61cfa206c3 100644 --- a/transformer_engine/jax/csrc/extensions/amax.cpp +++ b/transformer_engine/jax/csrc/extensions/amax.cpp @@ -96,5 +96,28 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Attr("produce_regular_amax") // produce_regular_amax .Attr("flatten_axis")); // flatten_axis + +Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type output_buf) { + NVTE_CHECK(input_buf.untyped_data() != nullptr, + "Input must be provided for inspect operation"); + NVTE_CHECK(output_buf->untyped_data() != nullptr, + "Output must be provided for inspect operation"); + NVTE_CHECK(input_buf.untyped_data() == output_buf->untyped_data(), + "Input and output must point to the same buffer for inspect operation"); + + printf("JTEST: Hello\n"); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + InspectHandler, InspectFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Ret() // output + ); + + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index a5986404c9..3f05b57077 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -81,6 +81,9 @@ pybind11::dict Registrations() { pybind11::arg("initialize") = 
EncapsulateFFI(RHTAmaxCalculationInitializeHandler), pybind11::arg("execute") = EncapsulateFFI(RHTAmaxCalculationHandler)); + dict["te_inspect_ffi"] = pybind11::dict( + pybind11::arg("execute") = EncapsulateFFI(InspectHandler)); + return dict; } diff --git a/transformer_engine/jax/inspect.py b/transformer_engine/jax/inspect.py new file mode 100644 index 0000000000..849cfb4491 --- /dev/null +++ b/transformer_engine/jax/inspect.py @@ -0,0 +1,111 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""JAX array inspection utilities.""" + +from functools import partial + +import jax +import jax.numpy as jnp +from jax import ffi + +from .cpp_extensions.base import BasePrimitive, register_primitive + +__all__ = ["inspect_array"] + + +class InspectPrimitive(BasePrimitive): + """ + No-op used for inspect array values. + """ + + name = "te_inspect_ffi" + multiple_results = False + impl_static_args = () + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + x_aval, + ): + """ + inspect abstract + """ + return x_aval + + @staticmethod + def lowering( + ctx, + x, + ): + """ + inspect lowering rules + """ + + return ffi.ffi_lowering( + InspectPrimitive.name, + operand_output_aliases={0: 0}, # donate input buffer to output buffer + )( + ctx, + x, + ) + + @staticmethod + def impl( + x, + ): + """ + inspect implementation + """ + assert InspectPrimitive.inner_primitive is not None + ( + x + ) = InspectPrimitive.inner_primitive.bind( + x, + ) + return x + +register_primitive(InspectPrimitive) + +@partial(jax.custom_vjp, nondiff_argnums=()) +def _inspect( + x, +): + """ + """ + output, _ = _inspect_fwd_rule( + x, + ) + return output + + +def _inspect_fwd_rule( + x, +): + """""" + ctx = () + x = InspectPrimitive.outer_primitive.bind(x) + return x, ctx + + +def _inspect_bwd_rule( + ctx, + grad, +): + """""" + del ctx + return grad, + + +_inspect.defvjp(_inspect_fwd_rule, _inspect_bwd_rule) + +def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray: + """Utility function to inspect JAX arrays by printing their name, shape, dtype, and statistics. + + Args: + x (jnp.ndarray): The JAX array to inspect. + name (str): The name of the array for identification in the output. 
+ """ + # TODO: Handle the name of the tensor in the primitive and output files + return _inspect(x) \ No newline at end of file From f2d1629f3a4ce364dad7fba24688ae4a16dbb4c0 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Wed, 4 Feb 2026 15:16:33 -0800 Subject: [PATCH 2/9] writing binary dumps of tensors works Signed-off-by: Jeremy Berchtold --- .../jax/csrc/extensions/amax.cpp | 34 ++++++++++++++----- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/amax.cpp b/transformer_engine/jax/csrc/extensions/amax.cpp index 61cfa206c3..52ef9c47fb 100644 --- a/transformer_engine/jax/csrc/extensions/amax.cpp +++ b/transformer_engine/jax/csrc/extensions/amax.cpp @@ -6,6 +6,7 @@ #include #include +#include #include "../extensions.h" #include "transformer_engine/cast.h" @@ -98,16 +99,33 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type output_buf) { - NVTE_CHECK(input_buf.untyped_data() != nullptr, - "Input must be provided for inspect operation"); - NVTE_CHECK(output_buf->untyped_data() != nullptr, - "Output must be provided for inspect operation"); - NVTE_CHECK(input_buf.untyped_data() == output_buf->untyped_data(), - "Input and output must point to the same buffer for inspect operation"); + NVTE_CHECK(input_buf.untyped_data() != nullptr, + "Input must be provided for inspect operation"); + NVTE_CHECK(output_buf->untyped_data() != nullptr, + "Output must be provided for inspect operation"); + NVTE_CHECK(input_buf.untyped_data() == output_buf->untyped_data(), + "Input and output must point to the same buffer for inspect operation"); - printf("JTEST: Hello\n"); - return ffi_with_cuda_error_check(); + std::vector input_data(input_buf.size_bytes()); + cudaMemcpyAsync(input_data.data(), input_buf.untyped_data(), input_buf.size_bytes(), + cudaMemcpyDeviceToHost, stream); + cudaStreamSynchronize(stream); + + int device; + cudaGetDevice(&device); + + std::string filename = "my_tensor_gpu" + std::to_string(device) + ".bin"; + std::ofstream file(filename, std::ios::binary); + if (file.is_open()) { + file.write(reinterpret_cast(input_data.data()), input_data.size()); + file.close(); + } + printf("Tensor data written to %s\n", filename.c_str()); + + // TODO: make a metadata file with tensor shape and dtype? 
+ + return ffi_with_cuda_error_check(); } XLA_FFI_DEFINE_HANDLER_SYMBOL( From f56d8696bbcdc33649fab13198ec9eb557de9581 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Wed, 4 Feb 2026 15:27:10 -0800 Subject: [PATCH 3/9] loading works Signed-off-by: Jeremy Berchtold --- .../jax/csrc/extensions/amax.cpp | 10 ++++++++- transformer_engine/jax/inspect.py | 21 +++++++++++++++++-- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/amax.cpp b/transformer_engine/jax/csrc/extensions/amax.cpp index 52ef9c47fb..97728303a2 100644 --- a/transformer_engine/jax/csrc/extensions/amax.cpp +++ b/transformer_engine/jax/csrc/extensions/amax.cpp @@ -121,7 +121,15 @@ Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type ou file.write(reinterpret_cast(input_data.data()), input_data.size()); file.close(); } - printf("Tensor data written to %s\n", filename.c_str()); + printf("Tensor data written to %s (shape: [", filename.c_str()); + for (size_t i = 0; i < input_buf.dimensions().size(); ++i) { + printf("%ld", static_cast(input_buf.dimensions()[i])); + if (i < input_buf.dimensions().size() - 1) { + printf(", "); + } + } + printf("], dtype: %d)\n", static_cast(input_buf.element_type())); + // TODO: make a metadata file with tensor shape and dtype? diff --git a/transformer_engine/jax/inspect.py b/transformer_engine/jax/inspect.py index 849cfb4491..d9f8d70bc9 100644 --- a/transformer_engine/jax/inspect.py +++ b/transformer_engine/jax/inspect.py @@ -11,7 +11,7 @@ from .cpp_extensions.base import BasePrimitive, register_primitive -__all__ = ["inspect_array"] +__all__ = ["inspect_array", "load_array_dump"] class InspectPrimitive(BasePrimitive): @@ -108,4 +108,21 @@ def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray: name (str): The name of the array for identification in the output. """ # TODO: Handle the name of the tensor in the primitive and output files - return _inspect(x) \ No newline at end of file + return _inspect(x) + + +def load_array_dump(filename: str, shape: tuple, dtype: jnp.dtype) -> jnp.ndarray: + """Utility function to load a JAX array from a dumped binary file. + + Args: + filename (str): The path to the binary file containing the array data. + shape (tuple): The shape of the array to be loaded. + dtype (jnp.dtype): The data type of the array to be loaded. + + Returns: + jnp.ndarray: The loaded JAX array. 
+ """ + with open(filename, "rb") as f: + data = f.read() + array = jnp.frombuffer(data, dtype=dtype).reshape(shape) + return array \ No newline at end of file From 37a7dd5b059ee4e200cb7139e22fbe1b03f83912 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 23:28:53 +0000 Subject: [PATCH 4/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../jax/csrc/extensions/amax.cpp | 78 +++++++++---------- .../jax/csrc/extensions/pybind.cpp | 4 +- transformer_engine/jax/inspect.py | 16 ++-- 3 files changed, 46 insertions(+), 52 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/amax.cpp b/transformer_engine/jax/csrc/extensions/amax.cpp index 97728303a2..e18ee99d81 100644 --- a/transformer_engine/jax/csrc/extensions/amax.cpp +++ b/transformer_engine/jax/csrc/extensions/amax.cpp @@ -5,8 +5,8 @@ ************************************************************************/ #include -#include #include +#include #include "../extensions.h" #include "transformer_engine/cast.h" @@ -97,53 +97,47 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Attr("produce_regular_amax") // produce_regular_amax .Attr("flatten_axis")); // flatten_axis - Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type output_buf) { - NVTE_CHECK(input_buf.untyped_data() != nullptr, - "Input must be provided for inspect operation"); - NVTE_CHECK(output_buf->untyped_data() != nullptr, - "Output must be provided for inspect operation"); - NVTE_CHECK(input_buf.untyped_data() == output_buf->untyped_data(), - "Input and output must point to the same buffer for inspect operation"); - - - std::vector input_data(input_buf.size_bytes()); - cudaMemcpyAsync(input_data.data(), input_buf.untyped_data(), input_buf.size_bytes(), - cudaMemcpyDeviceToHost, stream); - cudaStreamSynchronize(stream); - - int device; - cudaGetDevice(&device); - - std::string filename = "my_tensor_gpu" + std::to_string(device) + ".bin"; - std::ofstream file(filename, std::ios::binary); - if (file.is_open()) { - file.write(reinterpret_cast(input_data.data()), input_data.size()); - file.close(); - } - printf("Tensor data written to %s (shape: [", filename.c_str()); - for (size_t i = 0; i < input_buf.dimensions().size(); ++i) { - printf("%ld", static_cast(input_buf.dimensions()[i])); - if (i < input_buf.dimensions().size() - 1) { - printf(", "); - } + NVTE_CHECK(input_buf.untyped_data() != nullptr, "Input must be provided for inspect operation"); + NVTE_CHECK(output_buf->untyped_data() != nullptr, + "Output must be provided for inspect operation"); + NVTE_CHECK(input_buf.untyped_data() == output_buf->untyped_data(), + "Input and output must point to the same buffer for inspect operation"); + + std::vector input_data(input_buf.size_bytes()); + cudaMemcpyAsync(input_data.data(), input_buf.untyped_data(), input_buf.size_bytes(), + cudaMemcpyDeviceToHost, stream); + cudaStreamSynchronize(stream); + + int device; + cudaGetDevice(&device); + + std::string filename = "my_tensor_gpu" + std::to_string(device) + ".bin"; + std::ofstream file(filename, std::ios::binary); + if (file.is_open()) { + file.write(reinterpret_cast(input_data.data()), input_data.size()); + file.close(); + } + printf("Tensor data written to %s (shape: [", filename.c_str()); + for (size_t i = 0; i < input_buf.dimensions().size(); ++i) { + printf("%ld", static_cast(input_buf.dimensions()[i])); + if (i < input_buf.dimensions().size() - 1) { + printf(", "); } - 
printf("], dtype: %d)\n", static_cast(input_buf.element_type())); - + } + printf("], dtype: %d)\n", static_cast(input_buf.element_type())); - // TODO: make a metadata file with tensor shape and dtype? + // TODO: make a metadata file with tensor shape and dtype? - return ffi_with_cuda_error_check(); + return ffi_with_cuda_error_check(); } -XLA_FFI_DEFINE_HANDLER_SYMBOL( - InspectHandler, InspectFFI, - FFI::Bind() - .Ctx() // stream - .Arg() // input - .Ret() // output - ); - +XLA_FFI_DEFINE_HANDLER_SYMBOL(InspectHandler, InspectFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Ret() // output +); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 3f05b57077..5a8ee18f09 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -81,8 +81,8 @@ pybind11::dict Registrations() { pybind11::arg("initialize") = EncapsulateFFI(RHTAmaxCalculationInitializeHandler), pybind11::arg("execute") = EncapsulateFFI(RHTAmaxCalculationHandler)); - dict["te_inspect_ffi"] = pybind11::dict( - pybind11::arg("execute") = EncapsulateFFI(InspectHandler)); + dict["te_inspect_ffi"] = + pybind11::dict(pybind11::arg("execute") = EncapsulateFFI(InspectHandler)); return dict; } diff --git a/transformer_engine/jax/inspect.py b/transformer_engine/jax/inspect.py index d9f8d70bc9..61bbaf8bb0 100644 --- a/transformer_engine/jax/inspect.py +++ b/transformer_engine/jax/inspect.py @@ -45,7 +45,7 @@ def lowering( return ffi.ffi_lowering( InspectPrimitive.name, - operand_output_aliases={0: 0}, # donate input buffer to output buffer + operand_output_aliases={0: 0}, # donate input buffer to output buffer )( ctx, x, @@ -59,21 +59,20 @@ def impl( inspect implementation """ assert InspectPrimitive.inner_primitive is not None - ( - x - ) = InspectPrimitive.inner_primitive.bind( + (x) = InspectPrimitive.inner_primitive.bind( x, ) return x + register_primitive(InspectPrimitive) + @partial(jax.custom_vjp, nondiff_argnums=()) def _inspect( x, ): - """ - """ + """ """ output, _ = _inspect_fwd_rule( x, ) @@ -95,11 +94,12 @@ def _inspect_bwd_rule( ): """""" del ctx - return grad, + return (grad,) _inspect.defvjp(_inspect_fwd_rule, _inspect_bwd_rule) + def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray: """Utility function to inspect JAX arrays by printing their name, shape, dtype, and statistics. 
@@ -125,4 +125,4 @@ def load_array_dump(filename: str, shape: tuple, dtype: jnp.dtype) -> jnp.ndarra with open(filename, "rb") as f: data = f.read() array = jnp.frombuffer(data, dtype=dtype).reshape(shape) - return array \ No newline at end of file + return array From c3fe902ecf442dc6949cb9a2e2fc78a6660e80e3 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Mon, 9 Feb 2026 12:00:47 -0800 Subject: [PATCH 5/9] refactor Signed-off-by: Jeremy Berchtold --- .../common/gemm/cublaslt_gemm.cu | 8 +- transformer_engine/jax/cpp_extensions/gemm.py | 8 +- .../jax/csrc/extensions/amax.cpp | 45 ----------- .../jax/csrc/extensions/inspect.cpp | 79 +++++++++++++++++++ .../jax/debug/experimental/__init__.py | 14 ++++ .../jax/{ => debug/experimental}/inspect.py | 4 +- 6 files changed, 103 insertions(+), 55 deletions(-) create mode 100644 transformer_engine/jax/csrc/extensions/inspect.cpp create mode 100644 transformer_engine/jax/debug/experimental/__init__.py rename transformer_engine/jax/{ => debug/experimental}/inspect.py (95%) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index c58c3cb47a..11543be501 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -154,8 +154,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla if (is_fp8_dtype(ret.Atype)) { // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage - NVTE_CHECK(ret.lda % 16 == 0, - "Leading dimension requirement on A for FP8 GEMM. Caller must pad."); + // NVTE_CHECK(ret.lda % 16 == 0, + // "Leading dimension requirement on A for FP8 GEMM. Caller must pad."); } } else if (nvfp4) { // NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe. @@ -245,8 +245,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla if (is_fp8_dtype(ret.Atype)) { // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage - NVTE_CHECK(ret.ldb % 16 == 0, - "Leading dimension requirement on B for FP8 GEMM. Caller must pad."); + // NVTE_CHECK(ret.ldb % 16 == 0, + // "Leading dimension requirement on B for FP8 GEMM. Caller must pad."); } } else if (nvfp4) { if (is_B_transposed) { diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index d400412386..71f133bfc4 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -373,10 +373,10 @@ def assert_cublas_requirements(scaling_mode, contracting_size, tensor_name): # Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage alignment = 32 if scaling_mode.is_nvfp4_scaling else 16 - # assert contracting_size % alignment == 0, ( - # f"cuBLAS GEMM {tensor_name} tensor's contracting dimension must be a multiple of" - # f" {alignment} when using quantized inputs. Got contracting_size={contracting_size}" - # ) + assert contracting_size % alignment == 0, ( + f"cuBLAS GEMM {tensor_name} tensor's contracting dimension must be a multiple of" + f" {alignment} when using quantized inputs. 
Got contracting_size={contracting_size}"
+    )
 
 
 class GemmPrimitive(BasePrimitive):
diff --git a/transformer_engine/jax/csrc/extensions/amax.cpp b/transformer_engine/jax/csrc/extensions/amax.cpp
index e18ee99d81..58c89cfd32 100644
--- a/transformer_engine/jax/csrc/extensions/amax.cpp
+++ b/transformer_engine/jax/csrc/extensions/amax.cpp
@@ -5,9 +5,6 @@
  ************************************************************************/
 
 #include
-#include
-#include
-
 #include "../extensions.h"
 #include "transformer_engine/cast.h"
 #include "transformer_engine/hadamard_transform.h"
@@ -97,47 +94,5 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
     .Attr("produce_regular_amax")  // produce_regular_amax
     .Attr("flatten_axis"));        // flatten_axis
 
-Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type output_buf) {
-  NVTE_CHECK(input_buf.untyped_data() != nullptr, "Input must be provided for inspect operation");
-  NVTE_CHECK(output_buf->untyped_data() != nullptr,
-             "Output must be provided for inspect operation");
-  NVTE_CHECK(input_buf.untyped_data() == output_buf->untyped_data(),
-             "Input and output must point to the same buffer for inspect operation");
-
-  std::vector input_data(input_buf.size_bytes());
-  cudaMemcpyAsync(input_data.data(), input_buf.untyped_data(), input_buf.size_bytes(),
-                  cudaMemcpyDeviceToHost, stream);
-  cudaStreamSynchronize(stream);
-
-  int device;
-  cudaGetDevice(&device);
-
-  std::string filename = "my_tensor_gpu" + std::to_string(device) + ".bin";
-  std::ofstream file(filename, std::ios::binary);
-  if (file.is_open()) {
-    file.write(reinterpret_cast(input_data.data()), input_data.size());
-    file.close();
-  }
-  printf("Tensor data written to %s (shape: [", filename.c_str());
-  for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
-    printf("%ld", static_cast(input_buf.dimensions()[i]));
-    if (i < input_buf.dimensions().size() - 1) {
-      printf(", ");
-    }
-  }
-  printf("], dtype: %d)\n", static_cast(input_buf.element_type()));
-
-  // TODO: make a metadata file with tensor shape and dtype?
-
-  return ffi_with_cuda_error_check();
-}
-
-XLA_FFI_DEFINE_HANDLER_SYMBOL(InspectHandler, InspectFFI,
-                              FFI::Bind()
-                                  .Ctx()  // stream
-                                  .Arg()  // input
-                                  .Ret()  // output
-);
-
 }  // namespace jax
 }  // namespace transformer_engine
diff --git a/transformer_engine/jax/csrc/extensions/inspect.cpp b/transformer_engine/jax/csrc/extensions/inspect.cpp
new file mode 100644
index 0000000000..4d4d96dee0
--- /dev/null
+++ b/transformer_engine/jax/csrc/extensions/inspect.cpp
@@ -0,0 +1,79 @@
+/*************************************************************************
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+#include
+
+#include
+#include
+
+#include "../extensions.h"
+#include "xla/ffi/api/c_api.h"
+
+namespace transformer_engine {
+namespace jax {
+
+Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type output_buf) {
+  NVTE_CHECK(input_buf.untyped_data() != nullptr, "Input must be provided for inspect operation");
+  NVTE_CHECK(output_buf->untyped_data() != nullptr,
+             "Output must be provided for inspect operation");
+  NVTE_CHECK(input_buf.untyped_data() == output_buf->untyped_data(),
+             "Input and output must point to the same buffer for inspect operation");
+
+  std::vector input_data(input_buf.size_bytes());
+  cudaMemcpyAsync(input_data.data(), input_buf.untyped_data(), input_buf.size_bytes(),
+                  cudaMemcpyDeviceToHost, stream);
+  cudaStreamSynchronize(stream);
+
+  int device;
+  cudaGetDevice(&device);
+
+  // Write the tensor data to a file as a binary blob
+  std::string filename = "my_tensor_gpu" + std::to_string(device) + ".bin";
+  std::ofstream file(filename, std::ios::binary);
+  if (file.is_open()) {
+    file.write(reinterpret_cast(input_data.data()), input_data.size());
+    file.close();
+  }
+
+  // Write out a metadata file
+  std::string meta_filename = "my_tensor_gpu" + std::to_string(device) + "_meta.json";
+  std::ofstream meta_file(meta_filename);
+  if (meta_file.is_open()) {
+    meta_file << "{";
+    meta_file << "\"shape\": [";
+    for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
+      meta_file << input_buf.dimensions()[i];
+      if (i < input_buf.dimensions().size() - 1) {
+        meta_file << ", ";
+      }
+    }
+    meta_file << "], ";
+    meta_file << "\"dtype\": " << static_cast(input_buf.element_type());
+    meta_file << "}";
+    meta_file.close();
+  }
+
+  // Log the tensor metadata to the console
+  printf("Tensor data written to %s (shape: [", filename.c_str());
+  for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
+    printf("%ld", static_cast(input_buf.dimensions()[i]));
+    if (i < input_buf.dimensions().size() - 1) {
+      printf(", ");
+    }
+  }
+  printf("], dtype: %d)\n", static_cast(input_buf.element_type()));
+
+  return ffi_with_cuda_error_check();
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(InspectHandler, InspectFFI,
+                              FFI::Bind()
+                                  .Ctx()  // stream
+                                  .Arg()  // input
+                                  .Ret()  // output
+);
+
+}  // namespace jax
+}  // namespace transformer_engine
diff --git a/transformer_engine/jax/debug/experimental/__init__.py b/transformer_engine/jax/debug/experimental/__init__.py
new file mode 100644
index 0000000000..44a4847660
--- /dev/null
+++ b/transformer_engine/jax/debug/experimental/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+"""EXPERIMENTAL debugging utilities for Transformer Engine JAX.
+
+This API is experimental and may change or be removed without deprecation in future releases.
+"""
+
+from .inspect import inspect_array, load_array_dump
+
+__all__ = [
+    "inspect_array",
+    "load_array_dump",
+]
diff --git a/transformer_engine/jax/inspect.py b/transformer_engine/jax/debug/experimental/inspect.py
similarity index 95%
rename from transformer_engine/jax/inspect.py
rename to transformer_engine/jax/debug/experimental/inspect.py
index 61bbaf8bb0..ddb1a5b069 100644
--- a/transformer_engine/jax/inspect.py
+++ b/transformer_engine/jax/debug/experimental/inspect.py
@@ -1,7 +1,7 @@
 # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
-"""JAX array inspection utilities.""" +"""Experimental JAX array inspection utilities.""" from functools import partial @@ -9,7 +9,7 @@ import jax.numpy as jnp from jax import ffi -from .cpp_extensions.base import BasePrimitive, register_primitive +from transformer_engine.jax.cpp_extensions.base import BasePrimitive, register_primitive __all__ = ["inspect_array", "load_array_dump"] From cf7be54fec5ddae347bc81b660b58d85744aab9a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Feb 2026 20:01:52 +0000 Subject: [PATCH 6/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/gemm/cublaslt_gemm.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 11543be501..241e30764a 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -155,7 +155,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla if (is_fp8_dtype(ret.Atype)) { // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage // NVTE_CHECK(ret.lda % 16 == 0, - // "Leading dimension requirement on A for FP8 GEMM. Caller must pad."); + // "Leading dimension requirement on A for FP8 GEMM. Caller must pad."); } } else if (nvfp4) { // NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe. From cdf53f516edc1e122c100f292784f49a52dcf7fe Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Mon, 9 Feb 2026 13:48:13 -0800 Subject: [PATCH 7/9] Add tensor statistics Signed-off-by: Jeremy Berchtold --- .../common/gemm/cublaslt_gemm.cu | 8 ++--- .../jax/csrc/extensions/inspect.cpp | 26 ++++++++++++-- .../jax/debug/experimental/inspect.py | 34 ++++++++++++++++++- 3 files changed, 61 insertions(+), 7 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 241e30764a..c58c3cb47a 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -154,8 +154,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla if (is_fp8_dtype(ret.Atype)) { // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage - // NVTE_CHECK(ret.lda % 16 == 0, - // "Leading dimension requirement on A for FP8 GEMM. Caller must pad."); + NVTE_CHECK(ret.lda % 16 == 0, + "Leading dimension requirement on A for FP8 GEMM. Caller must pad."); } } else if (nvfp4) { // NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe. @@ -245,8 +245,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla if (is_fp8_dtype(ret.Atype)) { // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage - // NVTE_CHECK(ret.ldb % 16 == 0, - // "Leading dimension requirement on B for FP8 GEMM. Caller must pad."); + NVTE_CHECK(ret.ldb % 16 == 0, + "Leading dimension requirement on B for FP8 GEMM. 
Caller must pad."); } } else if (nvfp4) { if (is_B_transposed) { diff --git a/transformer_engine/jax/csrc/extensions/inspect.cpp b/transformer_engine/jax/csrc/extensions/inspect.cpp index 4d4d96dee0..6a03407bca 100644 --- a/transformer_engine/jax/csrc/extensions/inspect.cpp +++ b/transformer_engine/jax/csrc/extensions/inspect.cpp @@ -14,7 +14,13 @@ namespace transformer_engine { namespace jax { -Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type output_buf) { +Error_Type InspectFFI(cudaStream_t stream, + Buffer_Type input_buf, + Buffer_Type min_buf, + Buffer_Type max_buf, + Buffer_Type mean_buf, + Buffer_Type std_buf, + Result_Type output_buf) { NVTE_CHECK(input_buf.untyped_data() != nullptr, "Input must be provided for inspect operation"); NVTE_CHECK(output_buf->untyped_data() != nullptr, "Output must be provided for inspect operation"); @@ -24,6 +30,13 @@ Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type ou std::vector input_data(input_buf.size_bytes()); cudaMemcpyAsync(input_data.data(), input_buf.untyped_data(), input_buf.size_bytes(), cudaMemcpyDeviceToHost, stream); + + float min_val{}, max_val{}, mean_val{}, std_val{}; + cudaMemcpyAsync(&min_val, min_buf.untyped_data(), sizeof(float), cudaMemcpyDeviceToHost, stream); + cudaMemcpyAsync(&max_val, max_buf.untyped_data(), sizeof(float), cudaMemcpyDeviceToHost, stream); + cudaMemcpyAsync(&mean_val, mean_buf.untyped_data(), sizeof(float), cudaMemcpyDeviceToHost, stream); + cudaMemcpyAsync(&std_val, std_buf.untyped_data(), sizeof(float), cudaMemcpyDeviceToHost, stream); + cudaStreamSynchronize(stream); int device; @@ -51,6 +64,10 @@ Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type ou } meta_file << "], "; meta_file << "\"dtype\": " << static_cast(input_buf.element_type()); + meta_file << ", \"min\": " << min_val; + meta_file << ", \"max\": " << max_val; + meta_file << ", \"mean\": " << mean_val; + meta_file << ", \"std\": " << std_val; meta_file << "}"; meta_file.close(); } @@ -63,7 +80,8 @@ Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type ou printf(", "); } } - printf("], dtype: %d)\n", static_cast(input_buf.element_type())); + printf("], dtype: %d", static_cast(input_buf.element_type())); + printf(", min: %f, max: %f, mean: %f, std: %f)\n", min_val, max_val, mean_val, std_val); return ffi_with_cuda_error_check(); } @@ -72,6 +90,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(InspectHandler, InspectFFI, FFI::Bind() .Ctx() // stream .Arg() // input + .Arg() // min + .Arg() // max + .Arg() // mean + .Arg() // std .Ret() // output ); diff --git a/transformer_engine/jax/debug/experimental/inspect.py b/transformer_engine/jax/debug/experimental/inspect.py index ddb1a5b069..c87d34285b 100644 --- a/transformer_engine/jax/debug/experimental/inspect.py +++ b/transformer_engine/jax/debug/experimental/inspect.py @@ -28,16 +28,28 @@ class InspectPrimitive(BasePrimitive): @staticmethod def abstract( x_aval, + x_min_aval, + x_max_aval, + x_mean_aval, + x_std_aval, ): """ inspect abstract """ + assert x_min_aval.shape == () and x_min_aval.dtype == jnp.float32, "x_min must be a scalar with dtype float32" + assert x_max_aval.shape == () and x_max_aval.dtype == jnp.float32, "x_max must be a scalar with dtype float32" + assert x_mean_aval.shape == () and x_mean_aval.dtype == jnp.float32, "x_mean must be a scalar with dtype float32" + assert x_std_aval.shape == () and x_std_aval.dtype == jnp.float32, "x_std must be a scalar with dtype float32" return x_aval 
@staticmethod def lowering( ctx, x, + x_min, + x_max, + x_mean, + x_std, ): """ inspect lowering rules @@ -49,11 +61,19 @@ def lowering( )( ctx, x, + x_min, + x_max, + x_mean, + x_std, ) @staticmethod def impl( x, + x_min, + x_max, + x_mean, + x_std, ): """ inspect implementation @@ -61,12 +81,24 @@ def impl( assert InspectPrimitive.inner_primitive is not None (x) = InspectPrimitive.inner_primitive.bind( x, + x_min, + x_max, + x_mean, + x_std, ) return x register_primitive(InspectPrimitive) +def _inspect_array_inner(x: jnp.ndarray) -> jnp.ndarray: + return InspectPrimitive.outer_primitive.bind( + x, + jnp.min(x).astype(jnp.float32), + jnp.max(x).astype(jnp.float32), + jnp.mean(x.astype(jnp.float32)), + jnp.std(x.astype(jnp.float32)), + ) @partial(jax.custom_vjp, nondiff_argnums=()) def _inspect( @@ -84,7 +116,7 @@ def _inspect_fwd_rule( ): """""" ctx = () - x = InspectPrimitive.outer_primitive.bind(x) + x = _inspect_array_inner(x) return x, ctx From 39a219498dfa3b655df5e289c91f1ed50f92ce15 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Feb 2026 21:49:13 +0000 Subject: [PATCH 8/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../jax/csrc/extensions/inspect.cpp | 13 +++++-------- .../jax/debug/experimental/inspect.py | 18 ++++++++++++++---- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/inspect.cpp b/transformer_engine/jax/csrc/extensions/inspect.cpp index 6a03407bca..56ac585126 100644 --- a/transformer_engine/jax/csrc/extensions/inspect.cpp +++ b/transformer_engine/jax/csrc/extensions/inspect.cpp @@ -14,13 +14,9 @@ namespace transformer_engine { namespace jax { -Error_Type InspectFFI(cudaStream_t stream, - Buffer_Type input_buf, - Buffer_Type min_buf, - Buffer_Type max_buf, - Buffer_Type mean_buf, - Buffer_Type std_buf, - Result_Type output_buf) { +Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type min_buf, + Buffer_Type max_buf, Buffer_Type mean_buf, Buffer_Type std_buf, + Result_Type output_buf) { NVTE_CHECK(input_buf.untyped_data() != nullptr, "Input must be provided for inspect operation"); NVTE_CHECK(output_buf->untyped_data() != nullptr, "Output must be provided for inspect operation"); @@ -34,7 +30,8 @@ Error_Type InspectFFI(cudaStream_t stream, float min_val{}, max_val{}, mean_val{}, std_val{}; cudaMemcpyAsync(&min_val, min_buf.untyped_data(), sizeof(float), cudaMemcpyDeviceToHost, stream); cudaMemcpyAsync(&max_val, max_buf.untyped_data(), sizeof(float), cudaMemcpyDeviceToHost, stream); - cudaMemcpyAsync(&mean_val, mean_buf.untyped_data(), sizeof(float), cudaMemcpyDeviceToHost, stream); + cudaMemcpyAsync(&mean_val, mean_buf.untyped_data(), sizeof(float), cudaMemcpyDeviceToHost, + stream); cudaMemcpyAsync(&std_val, std_buf.untyped_data(), sizeof(float), cudaMemcpyDeviceToHost, stream); cudaStreamSynchronize(stream); diff --git a/transformer_engine/jax/debug/experimental/inspect.py b/transformer_engine/jax/debug/experimental/inspect.py index c87d34285b..59ec98fd8c 100644 --- a/transformer_engine/jax/debug/experimental/inspect.py +++ b/transformer_engine/jax/debug/experimental/inspect.py @@ -36,10 +36,18 @@ def abstract( """ inspect abstract """ - assert x_min_aval.shape == () and x_min_aval.dtype == jnp.float32, "x_min must be a scalar with dtype float32" - assert x_max_aval.shape == () and x_max_aval.dtype == jnp.float32, "x_max must be a scalar with dtype float32" - 
assert x_mean_aval.shape == () and x_mean_aval.dtype == jnp.float32, "x_mean must be a scalar with dtype float32"
-        assert x_std_aval.shape == () and x_std_aval.dtype == jnp.float32, "x_std must be a scalar with dtype float32"
+        assert (
+            x_min_aval.shape == () and x_min_aval.dtype == jnp.float32
+        ), "x_min must be a scalar with dtype float32"
+        assert (
+            x_max_aval.shape == () and x_max_aval.dtype == jnp.float32
+        ), "x_max must be a scalar with dtype float32"
+        assert (
+            x_mean_aval.shape == () and x_mean_aval.dtype == jnp.float32
+        ), "x_mean must be a scalar with dtype float32"
+        assert (
+            x_std_aval.shape == () and x_std_aval.dtype == jnp.float32
+        ), "x_std must be a scalar with dtype float32"
         return x_aval
 
     @staticmethod
@@ -91,6 +99,7 @@ def impl(
 
 register_primitive(InspectPrimitive)
 
+
 def _inspect_array_inner(x: jnp.ndarray) -> jnp.ndarray:
     return InspectPrimitive.outer_primitive.bind(
         x,
@@ -100,6 +109,7 @@ def _inspect_array_inner(x: jnp.ndarray) -> jnp.ndarray:
         jnp.std(x.astype(jnp.float32)),
     )
 
+
 @partial(jax.custom_vjp, nondiff_argnums=())
 def _inspect(
     x,

From 378b4ecc1bea893797403f9f9a4307b67f9dfea7 Mon Sep 17 00:00:00 2001
From: Jeremy Berchtold
Date: Mon, 9 Feb 2026 13:56:13 -0800
Subject: [PATCH 9/9] lint

Signed-off-by: Jeremy Berchtold
---
 transformer_engine/jax/csrc/extensions/inspect.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/transformer_engine/jax/csrc/extensions/inspect.cpp b/transformer_engine/jax/csrc/extensions/inspect.cpp
index 56ac585126..a7110367b3 100644
--- a/transformer_engine/jax/csrc/extensions/inspect.cpp
+++ b/transformer_engine/jax/csrc/extensions/inspect.cpp
@@ -72,7 +72,7 @@ Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type mi
   // Log the tensor metadata to the console
   printf("Tensor data written to %s (shape: [", filename.c_str());
   for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
-    printf("%ld", static_cast(input_buf.dimensions()[i]));
+    printf("%zu", static_cast(input_buf.dimensions()[i]));
     if (i < input_buf.dimensions().size() - 1) {
       printf(", ");
     }