diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h
index 3fd086e257..1c0bc52b88 100644
--- a/transformer_engine/jax/csrc/extensions.h
+++ b/transformer_engine/jax/csrc/extensions.h
@@ -143,6 +143,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);
 XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationInitializeHandler);
 XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationHandler);
 
+// Inspect
+XLA_FFI_DECLARE_HANDLER_SYMBOL(InspectHandler);
+
 // Cudnn helpers
 XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);
diff --git a/transformer_engine/jax/csrc/extensions/amax.cpp b/transformer_engine/jax/csrc/extensions/amax.cpp
index 5ffccaffb4..58c89cfd32 100644
--- a/transformer_engine/jax/csrc/extensions/amax.cpp
+++ b/transformer_engine/jax/csrc/extensions/amax.cpp
@@ -5,8 +5,6 @@
  ************************************************************************/
 #include <cuda_runtime.h>
 
-#include <fstream>
-
 #include "../extensions.h"
 #include "transformer_engine/cast.h"
 #include "transformer_engine/hadamard_transform.h"
diff --git a/transformer_engine/jax/csrc/extensions/inspect.cpp b/transformer_engine/jax/csrc/extensions/inspect.cpp
new file mode 100644
index 0000000000..a7110367b3
--- /dev/null
+++ b/transformer_engine/jax/csrc/extensions/inspect.cpp
@@ -0,0 +1,98 @@
+/*************************************************************************
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+#include <cuda_runtime.h>
+
+#include <cstdio>
+#include <fstream>
+#include <string>
+#include <vector>
+
+#include "../extensions.h"
+#include "xla/ffi/api/c_api.h"
+
+namespace transformer_engine {
+namespace jax {
+
+Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type min_buf,
+                      Buffer_Type max_buf, Buffer_Type mean_buf, Buffer_Type std_buf,
+                      Result_Type output_buf) {
+  NVTE_CHECK(input_buf.untyped_data() != nullptr, "Input must be provided for inspect operation");
+  NVTE_CHECK(output_buf->untyped_data() != nullptr,
+             "Output must be provided for inspect operation");
+  NVTE_CHECK(input_buf.untyped_data() == output_buf->untyped_data(),
+             "Input and output must point to the same buffer for inspect operation");
+
+  // Copy the raw tensor bytes and the precomputed statistics to the host.
+  std::vector<char> input_data(input_buf.size_bytes());
+  cudaMemcpyAsync(input_data.data(), input_buf.untyped_data(), input_buf.size_bytes(),
+                  cudaMemcpyDeviceToHost, stream);
+
+  float min_val{}, max_val{}, mean_val{}, std_val{};
+  cudaMemcpyAsync(&min_val, min_buf.untyped_data(), sizeof(float), cudaMemcpyDeviceToHost, stream);
+  cudaMemcpyAsync(&max_val, max_buf.untyped_data(), sizeof(float), cudaMemcpyDeviceToHost, stream);
+  cudaMemcpyAsync(&mean_val, mean_buf.untyped_data(), sizeof(float), cudaMemcpyDeviceToHost,
+                  stream);
+  cudaMemcpyAsync(&std_val, std_buf.untyped_data(), sizeof(float), cudaMemcpyDeviceToHost, stream);
+
+  cudaStreamSynchronize(stream);
+
+  int device;
+  cudaGetDevice(&device);
+
+  // Write the tensor data to a file as a binary blob
+  std::string filename = "my_tensor_gpu" + std::to_string(device) + ".bin";
+  std::ofstream file(filename, std::ios::binary);
+  if (file.is_open()) {
+    file.write(input_data.data(), input_data.size());
+    file.close();
+  }
+
+  // Write out a metadata file
+  std::string meta_filename = "my_tensor_gpu" + std::to_string(device) + "_meta.json";
+  std::ofstream meta_file(meta_filename);
+  if (meta_file.is_open()) {
+    meta_file << "{";
+    meta_file << "\"shape\": [";
+    for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
+      meta_file << input_buf.dimensions()[i];
+      if (i < input_buf.dimensions().size() - 1) {
+        meta_file << ", ";
+      }
+    }
+    meta_file << "], ";
+    meta_file << "\"dtype\": " << static_cast<int>(input_buf.element_type());
+    meta_file << ", \"min\": " << min_val;
+    meta_file << ", \"max\": " << max_val;
+    meta_file << ", \"mean\": " << mean_val;
+    meta_file << ", \"std\": " << std_val;
+    meta_file << "}";
+    meta_file.close();
+  }
+
+  // Log the tensor metadata to the console
+  printf("Tensor data written to %s (shape: [", filename.c_str());
+  for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
+    printf("%zu", static_cast<size_t>(input_buf.dimensions()[i]));
+    if (i < input_buf.dimensions().size() - 1) {
+      printf(", ");
+    }
+  }
+  printf("], dtype: %d", static_cast<int>(input_buf.element_type()));
+  printf(", min: %f, max: %f, mean: %f, std: %f)\n", min_val, max_val, mean_val, std_val);
+
+  return ffi_with_cuda_error_check();
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(InspectHandler, InspectFFI,
+                              FFI::Bind()
+                                  .Ctx<FFI_Stream_Type>()  // stream
+                                  .Arg<Buffer_Type>()      // input
+                                  .Arg<Buffer_Type>()      // min
+                                  .Arg<Buffer_Type>()      // max
+                                  .Arg<Buffer_Type>()      // mean
+                                  .Arg<Buffer_Type>()      // std
+                                  .Ret<Buffer_Type>()      // output
+);
+
+}  // namespace jax
+}  // namespace transformer_engine
diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp
index a5986404c9..5a8ee18f09 100644
--- a/transformer_engine/jax/csrc/extensions/pybind.cpp
+++ b/transformer_engine/jax/csrc/extensions/pybind.cpp
@@ -81,6 +81,9 @@ pybind11::dict Registrations() {
       pybind11::arg("initialize") = EncapsulateFFI(RHTAmaxCalculationInitializeHandler),
       pybind11::arg("execute") = EncapsulateFFI(RHTAmaxCalculationHandler));
 
+  dict["te_inspect_ffi"] =
+      pybind11::dict(pybind11::arg("execute") = EncapsulateFFI(InspectHandler));
+
   return dict;
 }
diff --git a/transformer_engine/jax/debug/experimental/__init__.py b/transformer_engine/jax/debug/experimental/__init__.py
new file mode 100644
index 0000000000..44a4847660
--- /dev/null
+++ b/transformer_engine/jax/debug/experimental/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+"""EXPERIMENTAL debugging utilities for Transformer Engine JAX.
+
+This API is experimental and may change or be removed without deprecation in future releases.
+"""
+
+from .inspect import inspect_array, load_array_dump
+
+__all__ = [
+    "inspect_array",
+    "load_array_dump",
+]
diff --git a/transformer_engine/jax/debug/experimental/inspect.py b/transformer_engine/jax/debug/experimental/inspect.py
new file mode 100644
index 0000000000..59ec98fd8c
--- /dev/null
+++ b/transformer_engine/jax/debug/experimental/inspect.py
@@ -0,0 +1,170 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+"""Experimental JAX array inspection utilities."""
+
+from functools import partial
+
+import jax
+import jax.numpy as jnp
+from jax import ffi
+
+from transformer_engine.jax.cpp_extensions.base import BasePrimitive, register_primitive
+
+__all__ = ["inspect_array", "load_array_dump"]
+
+
+class InspectPrimitive(BasePrimitive):
+    """
+    No-op primitive used to inspect array values.
+    """
+
+    name = "te_inspect_ffi"
+    multiple_results = False
+    impl_static_args = ()
+    inner_primitive = None
+    outer_primitive = None
+
+    @staticmethod
+    def abstract(
+        x_aval,
+        x_min_aval,
+        x_max_aval,
+        x_mean_aval,
+        x_std_aval,
+    ):
+        """
+        Abstract evaluation rule for the inspect primitive.
+        """
+        assert (
+            x_min_aval.shape == () and x_min_aval.dtype == jnp.float32
+        ), "x_min must be a scalar with dtype float32"
+        assert (
+            x_max_aval.shape == () and x_max_aval.dtype == jnp.float32
+        ), "x_max must be a scalar with dtype float32"
+        assert (
+            x_mean_aval.shape == () and x_mean_aval.dtype == jnp.float32
+        ), "x_mean must be a scalar with dtype float32"
+        assert (
+            x_std_aval.shape == () and x_std_aval.dtype == jnp.float32
+        ), "x_std must be a scalar with dtype float32"
+        return x_aval
+
+    @staticmethod
+    def lowering(
+        ctx,
+        x,
+        x_min,
+        x_max,
+        x_mean,
+        x_std,
+    ):
+        """
+        Lowering rule for the inspect primitive.
+        """
+        return ffi.ffi_lowering(
+            InspectPrimitive.name,
+            operand_output_aliases={0: 0},  # donate input buffer to output buffer
+        )(
+            ctx,
+            x,
+            x_min,
+            x_max,
+            x_mean,
+            x_std,
+        )
+
+    @staticmethod
+    def impl(
+        x,
+        x_min,
+        x_max,
+        x_mean,
+        x_std,
+    ):
+        """
+        Implementation rule for the inspect primitive.
+        """
+        assert InspectPrimitive.inner_primitive is not None
+        x = InspectPrimitive.inner_primitive.bind(
+            x,
+            x_min,
+            x_max,
+            x_mean,
+            x_std,
+        )
+        return x
+
+
+register_primitive(InspectPrimitive)
+
+
+def _inspect_array_inner(x: jnp.ndarray) -> jnp.ndarray:
+    # Statistics are computed on device in float32 and passed to the FFI handler.
+    return InspectPrimitive.outer_primitive.bind(
+        x,
+        jnp.min(x).astype(jnp.float32),
+        jnp.max(x).astype(jnp.float32),
+        jnp.mean(x.astype(jnp.float32)),
+        jnp.std(x.astype(jnp.float32)),
+    )
+
+
+@partial(jax.custom_vjp, nondiff_argnums=())
+def _inspect(
+    x,
+):
+    """Inspect ``x`` and return it unchanged."""
+    output, _ = _inspect_fwd_rule(
+        x,
+    )
+    return output
+
+
+def _inspect_fwd_rule(
+    x,
+):
+    """Forward rule: run the inspect primitive; no residuals are needed."""
+    ctx = ()
+    x = _inspect_array_inner(x)
+    return x, ctx
+
+
+def _inspect_bwd_rule(
+    ctx,
+    grad,
+):
+    """Backward rule: pass the incoming gradient through unchanged."""
+    del ctx
+    return (grad,)
+
+
+_inspect.defvjp(_inspect_fwd_rule, _inspect_bwd_rule)
+
+
+def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray:
+    """Inspect a JAX array by logging its shape, dtype, and statistics
+    (min/max/mean/std) and dumping its raw data to a per-GPU binary file.
+
+    The array is returned unchanged, so this can be inserted anywhere in a
+    computation, including under ``jax.jit``.
+
+    Args:
+        x (jnp.ndarray): The JAX array to inspect.
+        name (str): The name of the array for identification in the output.
+            Currently unused; see the TODO below.
+
+    Returns:
+        jnp.ndarray: The input array, unmodified.
+    """
+    # TODO: Handle the name of the tensor in the primitive and output files
+    return _inspect(x)
+
+
+def load_array_dump(filename: str, shape: tuple, dtype: jnp.dtype) -> jnp.ndarray:
+    """Load a JAX array from a binary file dumped by :func:`inspect_array`.
+
+    Args:
+        filename (str): The path to the binary file containing the array data.
+        shape (tuple): The shape of the array to be loaded.
+        dtype (jnp.dtype): The data type of the array to be loaded.
+
+    Returns:
+        jnp.ndarray: The loaded JAX array.
+    """
+    with open(filename, "rb") as f:
+        data = f.read()
+    array = jnp.frombuffer(data, dtype=dtype).reshape(shape)
+    return array
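For reference, a minimal usage sketch of the new API (not part of the diff). It assumes a single-process, single-GPU run, so the FFI handler writes `my_tensor_gpu0.bin` and `my_tensor_gpu0_meta.json` to the working directory; the `forward` function and its input are hypothetical.

```python
import jax
import jax.numpy as jnp

from transformer_engine.jax.debug.experimental import inspect_array, load_array_dump


@jax.jit
def forward(x):
    # inspect_array is an identity op: it dumps x and its statistics as a
    # side effect and returns x unchanged, so gradients flow straight through.
    x = inspect_array(x, name="activations")
    return jnp.sum(x * x)


x = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)
forward(x)  # writes my_tensor_gpu0.bin and my_tensor_gpu0_meta.json

# Reload the dump; shape and dtype must match the values recorded in the
# metadata file, since the binary blob itself is untyped.
dumped = load_array_dump("my_tensor_gpu0.bin", shape=(3, 4), dtype=jnp.float32)
assert jnp.array_equal(dumped, x)
```

Note that the inspected array's output must feed into the rest of the computation (as it does above); an unused result could be dead-code-eliminated by XLA, in which case the dump would never run.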