[JAX] Debugging inspect utility #2651
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Greptile Overview
Greptile Summary
This PR introduces an experimental debugging utility for JAX that dumps tensor data to binary files with metadata, working around the currently broken jax.debug.print/callback.
Key changes:
Issues found:
Confidence Score: 3/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User
participant inspect_array
participant _inspect
participant _inspect_array_inner
participant InspectPrimitive
participant InspectFFI
participant CUDA
participant FileSystem
User->>inspect_array: x, name
inspect_array->>_inspect: x
_inspect->>_inspect_fwd_rule: x
_inspect_fwd_rule->>_inspect_array_inner: x
_inspect_array_inner->>_inspect_array_inner: compute min(x), max(x), mean(x), std(x)
_inspect_array_inner->>InspectPrimitive: bind(x, min, max, mean, std)
InspectPrimitive->>InspectFFI: FFI call
InspectFFI->>CUDA: cudaMemcpyAsync (input + stats to host)
InspectFFI->>CUDA: cudaStreamSynchronize
InspectFFI->>FileSystem: write my_tensor_gpu{device}.bin
InspectFFI->>FileSystem: write my_tensor_gpu{device}_meta.json
InspectFFI->>InspectFFI: printf metadata
InspectFFI-->>InspectPrimitive: return x (aliased to input)
InspectPrimitive-->>_inspect_fwd_rule: x
_inspect_fwd_rule-->>_inspect: x, ctx
_inspect-->>inspect_array: x
inspect_array-->>User: x (passthrough)
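The flow in the diagram can be approximated on the host side in plain Python. This is a minimal sketch under stated assumptions, not the PR's implementation: `dump_tensor` and `load_tensor` are hypothetical helpers, the payload is assumed to be little-endian float32, the standard deviation is assumed to be the population one, and the JSON key names are invented for illustration. Only the `my_tensor_gpu{device}.bin` / `_meta.json` naming pattern and the passthrough return come from the diagram.

```python
import json
import math
import struct
import tempfile
from pathlib import Path


def dump_tensor(values, shape, device, out_dir):
    """Write values as a raw float32 blob plus a JSON metadata file.

    Hypothetical helper mirroring the diagram: compute statistics,
    write the .bin file, write the _meta.json file, return the input.
    """
    n = len(values)
    assert n > 0 and n == math.prod(shape), "shape does not match values"
    mean = sum(values) / n
    # Population standard deviation (an assumption about the statistic used).
    std = math.sqrt(sum((v - mean) ** 2 for v in values) / n)
    meta = {  # key names are assumptions, not the PR's actual schema
        "shape": list(shape),
        "dtype": "float32",
        "min": min(values),
        "max": max(values),
        "mean": mean,
        "std": std,
    }
    stem = f"my_tensor_gpu{device}"
    (Path(out_dir) / f"{stem}.bin").write_bytes(struct.pack(f"<{n}f", *values))
    (Path(out_dir) / f"{stem}_meta.json").write_text(json.dumps(meta))
    return values  # passthrough, like the primitive's aliased output


def load_tensor(device, in_dir):
    """Read back a dump produced by dump_tensor."""
    stem = f"my_tensor_gpu{device}"
    meta = json.loads((Path(in_dir) / f"{stem}_meta.json").read_text())
    raw = (Path(in_dir) / f"{stem}.bin").read_bytes()
    count = math.prod(meta["shape"])
    return list(struct.unpack(f"<{count}f", raw)), meta


# Round trip through a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    dump_tensor([1.0, 2.0, 3.0, 4.0], (2, 2), device=0, out_dir=tmp)
    data, meta = load_tensor(0, tmp)
```

Keeping the metadata in a sidecar file lets a reader recover the shape without guessing it from the blob size, which is presumably why the diagram writes two files per tensor.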
5 files reviewed, 4 comments
  NVTE_CHECK(input_buf.untyped_data() != nullptr, "Input must be provided for inspect operation");
  NVTE_CHECK(output_buf->untyped_data() != nullptr,
             "Output must be provided for inspect operation");
  NVTE_CHECK(input_buf.untyped_data() == output_buf->untyped_data(),
             "Input and output must point to the same buffer for inspect operation");

  printf("JTEST: Hello\n");

  return ffi_with_cuda_error_check();
}
Debug printf in FFI
InspectFFI prints "JTEST: Hello" unconditionally. This will spam stdout/stderr on every execution (including under jax.jit where it may execute many times / across devices) and is not an acceptable side effect for a library primitive. Please remove the printf or gate it behind an explicit debug flag that defaults to off.
def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray:
    """Utility function to inspect JAX arrays by printing their name, shape, dtype, and statistics.

    Args:
        x (jnp.ndarray): The JAX array to inspect.
        name (str): The name of the array for identification in the output.
    """
    # TODO: Handle the name of the tensor in the primitive and output files
    return _inspect(x)
Unused name argument
inspect_array(x, name) documents printing the array name/shape/dtype/stats, but name is unused and the function currently just returns _inspect(x) (which itself is a no-op). This is misleading API surface and will confuse callers expecting output. Either implement passing/using name (and the actual inspection behavior) or drop the name parameter and update the docstring/export accordingly.
def _inspect_fwd_rule(
    x,
):
    """"""
    ctx = ()
    x = InspectPrimitive.outer_primitive.bind(x)
    return x, ctx
Missing outer primitive assertion
_inspect_fwd_rule calls InspectPrimitive.outer_primitive.bind(x) without asserting outer_primitive is initialized. If registration fails (or this module is imported before primitives are set up), this will raise an attribute error at runtime. Other primitives in this repo typically guard with assert ... is not None before binding; please add the same guard here.
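The requested guard could look roughly like the sketch below. It does not use the repository's real classes: `_StubPrimitive` and `InspectPrimitiveSketch` are stand-ins invented here so the example runs without JAX, and the assertion message is an assumption.

```python
class _StubPrimitive:
    """Stand-in for a registered primitive; only `bind` is modeled."""

    def bind(self, x):
        return x


class InspectPrimitiveSketch:
    # In the real code base this would be set during registration;
    # it stays None until then.
    outer_primitive = None


def inspect_fwd(x):
    """Forward rule with an explicit registration check before bind."""
    assert (
        InspectPrimitiveSketch.outer_primitive is not None
    ), "InspectPrimitiveSketch is not registered"
    ctx = ()
    return InspectPrimitiveSketch.outer_primitive.bind(x), ctx


# Before registration the guard fires with a clear message instead of
# "'NoneType' object has no attribute 'bind'".
try:
    inspect_fwd(1.0)
    guarded = False
except AssertionError:
    guarded = True

# After registration the call passes the value through.
InspectPrimitiveSketch.outer_primitive = _StubPrimitive()
out, ctx = inspect_fwd(1.0)
```

The assertion does not change behavior on the happy path; it only replaces an opaque attribute error with a message that names the unregistered primitive.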
Additional Comments (1)
/te-ci jax
2 files reviewed, 4 comments
  cudaGetDevice(&device);

  // Write the tensor data to a file as a binary blob
  std::string filename = "my_tensor_gpu" + std::to_string(device) + ".bin";
The hardcoded filename prefix my_tensor_gpu makes it impossible to distinguish dumps of different tensors. The name parameter from Python is not passed through to C++, so every dump uses the same base filename, and inspecting multiple tensors overwrites earlier data.
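One way to address this, once the name reaches the file-writing code, is to derive the base filename from it. The sketch below is only an illustration: `dump_basename`, the sanitization rule, and the `tensor` fallback are assumptions, and the `_gpu{device}` suffix is kept from the existing pattern.

```python
import re


def dump_basename(name: str, device: int) -> str:
    """Build a per-tensor, per-device base filename (hypothetical scheme)."""
    # Replace anything outside a conservative character set so the
    # user-supplied name is safe to use as a file name.
    safe = re.sub(r"[^A-Za-z0-9_.-]", "_", name) or "tensor"
    return f"{safe}_gpu{device}"


print(dump_basename("layer0/attn.qkv", 1))  # layer0_attn.qkv_gpu1
```

Two tensors with different names then map to different files on the same device, although repeated inspections of the same name would still overwrite each other unless a step counter or timestamp is added.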
  std::ofstream file(filename, std::ios::binary);
  if (file.is_open()) {
    file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size());
    file.close();
  }
File write failures are silently ignored. If file.is_open() returns false (e.g., permission denied) or the write itself fails (e.g., disk full), the function continues without any indication. Add error handling or logging.
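The intended behavior, failing loudly with the path in the message, can be illustrated in Python. This is a sketch of the idea rather than a port of the C++ code: `write_blob` and the error text are hypothetical.

```python
import os
import tempfile


def write_blob(path, payload: bytes) -> None:
    """Write payload to path, raising a descriptive error on failure."""
    try:
        with open(path, "wb") as f:
            f.write(payload)
    except OSError as exc:
        # Surface the failure instead of continuing silently.
        raise RuntimeError(f"inspect: failed to write {path}: {exc}") from exc


with tempfile.TemporaryDirectory() as tmp:
    # A normal write succeeds.
    good = os.path.join(tmp, "x.bin")
    write_blob(good, b"\x00\x01")
    wrote_ok = os.path.getsize(good) == 2

    # Writing into a directory that does not exist is reported.
    try:
        write_blob(os.path.join(tmp, "no_such_dir", "x.bin"), b"\x00")
        failed_loudly = False
    except RuntimeError:
        failed_loudly = True
```

In the C++ code the equivalent would be to check the stream state after opening and after writing and to report a failure through the existing error-checking macros.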
  // Log the tensor metadata to the console
  printf("Tensor data written to %s (shape: [", filename.c_str());
  for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
    printf("%zu", static_cast<size_t>(input_buf.dimensions()[i]));
    if (i < input_buf.dimensions().size() - 1) {
      printf(", ");
    }
  }
  printf("], dtype: %d", static_cast<int>(input_buf.element_type()));
  printf(", min: %f, max: %f, mean: %f, std: %f)\n", min_val, max_val, mean_val, std_val);
The unconditional printf will spam output every time this primitive executes (including in jitted code, on every device). For a debugging utility, consider gating it behind an environment variable or adding another way to disable verbose output.
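A minimal sketch of such a gate, assuming an opt-in environment variable that defaults to off. The variable name `INSPECT_DEBUG_VERBOSE` and both helpers are hypothetical; the PR does not define them.

```python
import os

# Hypothetical variable name; not defined by the PR.
VERBOSE_ENV_VAR = "INSPECT_DEBUG_VERBOSE"


def verbose_enabled(environ=os.environ) -> bool:
    """Return True only when the flag is explicitly set to a truthy value."""
    value = environ.get(VERBOSE_ENV_VAR, "0")
    return value.strip().lower() in {"1", "true", "yes"}


def log_metadata(message: str, environ=os.environ) -> bool:
    """Print message if verbose output is enabled; report whether it printed."""
    if not verbose_enabled(environ):
        return False
    print(message)
    return True


# Off by default, on only when requested.
silent = log_metadata("Tensor data written", environ={})
loud = log_metadata("Tensor data written", environ={VERBOSE_ENV_VAR: "1"})
```

Passing the environment as a parameter keeps the helper easy to test; on the C++ side the same check would typically be evaluated once and cached rather than re-read on every call.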
def _inspect_array_inner(x: jnp.ndarray) -> jnp.ndarray:
    return InspectPrimitive.outer_primitive.bind(
Missing assert InspectPrimitive.outer_primitive is not None before bind. Other primitives in this codebase add this guard to prevent an AttributeError if registration fails (see activation.py:351, amax.py:381, etc.).
Description
Given that jax.debug.print/callback is currently broken (issue), this PR introduces an experimental alternative for our own internal debugging. This new debugging API allows us to inspect tensors, but it is experimental and may have breaking changes without a deprecation process.
Usage:
Type of change
Changes
Checklist: