Skip to content

Conversation

@jberchtold-nvidia
Copy link
Collaborator

@jberchtold-nvidia jberchtold-nvidia commented Feb 4, 2026

Description

Given jax.debug.print/callback is currently broken (issue), this PR introduces an experimental alternative for use for our own internal debugging. This new debugging API allows us to inspect tensors but will be experimental and may have breaking changes without a deprecation process.

Usage:

     x = <some logic to compute x>

      from transformer_engine.jax.debug.experimental import inspect_array as te_inspect_array
      x= te_inspect_array(x, "some_name")

     <something that consumes  x>

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Introduces a new debugging tool for dumping binary blobs of tensor values for multi-GPU inspection.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia marked this pull request as draft February 4, 2026 22:51
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Feb 4, 2026

Greptile Overview

Greptile Summary

This PR introduces an experimental debugging utility for JAX that dumps tensor data to binary files with metadata, working around broken jax.debug.print/callback functionality. The implementation consists of a C++ FFI handler that performs device-to-host copies and file I/O, plus Python primitives that compute tensor statistics and integrate with JAX's custom VJP system.

Key changes:

  • Added InspectFFI handler that writes tensor binary data and JSON metadata to files
  • Implemented InspectPrimitive with proper forward/backward rules for gradient compatibility
  • Exported inspect_array() and load_array_dump() utilities in experimental namespace
  • Removed unused iostream header from amax.cpp

Issues found:

  • Hardcoded filename my_tensor_gpu prevents distinguishing different tensors - the name parameter from Python isn't passed to C++, causing overwrites when inspecting multiple tensors
  • File I/O failures are silently ignored without error reporting
  • Unconditional printf output on every execution could spam logs in production use
  • Missing outer_primitive assertion guard before bind (standard pattern in this codebase)

Confidence Score: 3/5

  • This PR is functional but has implementation issues that limit usability and could cause problems in production
  • The core functionality works but has notable limitations: hardcoded filenames prevent multi-tensor inspection, missing error handling could hide failures, and unconditional I/O/printf will impact performance. These are fixable issues but affect the utility's practical usability.
  • Pay close attention to transformer_engine/jax/csrc/extensions/inspect.cpp (hardcoded filenames, missing error handling) and transformer_engine/jax/debug/experimental/inspect.py (missing primitive guard)

Important Files Changed

Filename Overview
transformer_engine/jax/csrc/extensions/inspect.cpp New FFI implementation for tensor inspection with hardcoded filenames, unconditional I/O, and missing error handling
transformer_engine/jax/debug/experimental/inspect.py Primitive implementation for tensor inspection with unused name parameter and missing assertion guard

Sequence Diagram

sequenceDiagram
    participant User
    participant inspect_array
    participant _inspect
    participant _inspect_array_inner
    participant InspectPrimitive
    participant InspectFFI
    participant CUDA
    participant FileSystem

    User->>inspect_array: x, name
    inspect_array->>_inspect: x
    _inspect->>_inspect_fwd_rule: x
    _inspect_fwd_rule->>_inspect_array_inner: x
    _inspect_array_inner->>_inspect_array_inner: compute min(x), max(x), mean(x), std(x)
    _inspect_array_inner->>InspectPrimitive: bind(x, min, max, mean, std)
    InspectPrimitive->>InspectFFI: FFI call
    InspectFFI->>CUDA: cudaMemcpyAsync (input + stats to host)
    InspectFFI->>CUDA: cudaStreamSynchronize
    InspectFFI->>FileSystem: write my_tensor_gpu{device}.bin
    InspectFFI->>FileSystem: write my_tensor_gpu{device}_meta.json
    InspectFFI->>InspectFFI: printf metadata
    InspectFFI-->>InspectPrimitive: return x (aliased to input)
    InspectPrimitive-->>_inspect_fwd_rule: x
    _inspect_fwd_rule-->>_inspect: x, ctx
    _inspect-->>inspect_array: x
    inspect_array-->>User: x (passthrough)
Loading

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

5 files reviewed, 4 comments

Edit Code Review Agent Settings | Greptile

Comment on lines 100 to 110
NVTE_CHECK(input_buf.untyped_data() != nullptr, "Input must be provided for inspect operation");
NVTE_CHECK(output_buf->untyped_data() != nullptr,
"Output must be provided for inspect operation");
NVTE_CHECK(input_buf.untyped_data() == output_buf->untyped_data(),
"Input and output must point to the same buffer for inspect operation");

printf("JTEST: Hello\n");

return ffi_with_cuda_error_check();
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Debug printf in FFI

InspectFFI prints "JTEST: Hello" unconditionally. This will spam stdout/stderr on every execution (including under jax.jit where it may execute many times / across devices) and is not an acceptable side effect for a library primitive. Please remove the printf or gate it behind an explicit debug flag that defaults to off.

Comment on lines 103 to 111
def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray:
"""Utility function to inspect JAX arrays by printing their name, shape, dtype, and statistics.

Args:
x (jnp.ndarray): The JAX array to inspect.
name (str): The name of the array for identification in the output.
"""
# TODO: Handle the name of the tensor in the primitive and output files
return _inspect(x)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused name argument

inspect_array(x, name) documents printing the array name/shape/dtype/stats, but name is unused and the function currently just returns _inspect(x) (which itself is a no-op). This is misleading API surface and will confuse callers expecting output. Either implement passing/using name (and the actual inspection behavior) or drop the name parameter and update the docstring/export accordingly.

Comment on lines 82 to 91
def _inspect_fwd_rule(
x,
):
""""""
ctx = ()
x = InspectPrimitive.outer_primitive.bind(x)
return x, ctx


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing outer primitive assertion

_inspect_fwd_rule calls InspectPrimitive.outer_primitive.bind(x) without asserting outer_primitive is initialized. If registration fails (or this module is imported before primitives are set up), this will raise an attribute error at runtime. Other primitives in this repo typically guard with assert ... is not None before binding; please add the same guard here.

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Feb 4, 2026

Additional Comments (1)

transformer_engine/jax/cpp_extensions/gemm.py
Disabled cuBLAS alignment check

assert_cublas_requirements no longer enforces contracting_size % alignment == 0 for quantized GEMM (the assert is commented out). This will allow invalid shapes through to the cuBLAS custom call and can trigger runtime failures or incorrect behavior when using FP8/NVFP4 inputs. Please restore the check or replace it with an equivalent validation (and only relax it if the backend truly supports unaligned contracting sizes).

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
pre-commit-ci bot and others added 5 commits February 4, 2026 23:28
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia marked this pull request as ready for review February 9, 2026 21:52
jberchtold-nvidia and others added 2 commits February 9, 2026 13:52
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Copy link
Collaborator Author

/te-ci jax

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 files reviewed, 4 comments

Edit Code Review Agent Settings | Greptile

cudaGetDevice(&device);

// Write the tensor data to a file as a binary blob
std::string filename = "my_tensor_gpu" + std::to_string(device) + ".bin";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hardcoded filename my_tensor_gpu prevents distinguishing between different tensors. The name parameter from Python is not passed through to C++, making all dumps use the same base filename. This will overwrite data when inspecting multiple tensors.

Comment on lines +44 to +48
std::ofstream file(filename, std::ios::binary);
if (file.is_open()) {
file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size());
file.close();
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

file write failures are silently ignored. If file.is_open() fails (e.g., permission denied, disk full), the function continues without indication. Add error handling or logging.

Comment on lines +72 to +81
// Log the tensor metadata to the console
printf("Tensor data written to %s (shape: [", filename.c_str());
for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
printf("%zu", static_cast<size_t>(input_buf.dimensions()[i]));
if (i < input_buf.dimensions().size() - 1) {
printf(", ");
}
}
printf("], dtype: %d", static_cast<int>(input_buf.element_type()));
printf(", min: %f, max: %f, mean: %f, std: %f)\n", min_val, max_val, mean_val, std_val);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unconditional printf will spam output every time this primitive executes (including in jitted code on every device). For a debugging utility, consider gating behind an environment variable or adding a way to disable verbose output.



def _inspect_array_inner(x: jnp.ndarray) -> jnp.ndarray:
return InspectPrimitive.outer_primitive.bind(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing assert InspectPrimitive.outer_primitive is not None before bind. Other primitives in this codebase guard this to prevent AttributeError if registration fails (see activation.py:351, amax.py:381, etc.).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant