Skip to content

DLA surgeries added as part of ModelOpt#1415

Open
mgohil-png wants to merge 1 commit intomainfrom
dla_surgeon_integration
Open

DLA surgeries added as part of ModelOpt#1415
mgohil-png wants to merge 1 commit intomainfrom
dla_surgeon_integration

Conversation

@mgohil-png
Copy link
Copy Markdown
Contributor

@mgohil-png mgohil-png commented May 8, 2026

What does this PR do?

Type of change: new feature

Added graph surgeries to transform models and make them DLA HW compatible.
There is single surgery make_dla_compatible which calls series of sub

Usage

from modelopt.onnx.graph_surgery import run_graph_surgery

run_graph_surgery(
    "make-dla-compatible",
    model_path="input_model.onnx",
    output_path="output_model.onnx",
)

Testing

Added multiple test cases for each sub-surgery and for make_dla_compatible surgery.
Following is the list of test cases added, each containing multiple cases.
tests/unit/dla_transforms/test_dla_5d_reshape_to_4d.py
tests/unit/dla_transforms/test_dla_cast_to_fp32.py
tests/unit/dla_transforms/test_dla_constants_to_initializers.py
tests/unit/dla_transforms/test_dla_convert_ops_to_4d.py
tests/unit/dla_transforms/test_dla_decompose_lstm.py
tests/unit/dla_transforms/test_dla_fix_instancenorm_channel_mismatch.py
tests/unit/dla_transforms/test_dla_graph_cleanup.py
tests/unit/dla_transforms/test_dla_greater.py
tests/unit/dla_transforms/test_dla_handle_qdq.py
tests/unit/dla_transforms/test_dla_matmul_to_transpose_conv_transpose.py
tests/unit/dla_transforms/test_dla_not.py
tests/unit/dla_transforms/test_dla_remove_deqlin.py
tests/unit/dla_transforms/test_dla_remove_qdq.py
tests/unit/dla_transforms/test_dla_topk.py
tests/unit/dla_transforms/test_dla_where.py

Summary by CodeRabbit

Release Notes

  • New Features

    • Added DLA (Deep Learning Accelerator) compatibility pipeline with 16-step graph transformation orchestration
    • Added unified graph surgery execution interface with support for external tensor data handling
    • Added comprehensive operator-specific transforms (reshape, LSTM decomposition, quantization handling, etc.)
    • Added shape inference and graph validation utilities
  • Documentation

    • Updated module-level documentation with DLA compatibility pipeline examples

Signed-off-by: mgohil-png <mgohil@nvidia.com>
@mgohil-png mgohil-png requested a review from hthadicherla May 8, 2026 09:30
@mgohil-png mgohil-png requested a review from a team as a code owner May 8, 2026 09:30
@mgohil-png mgohil-png requested a review from cjluo-nv May 8, 2026 09:30
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 8, 2026

Review Change Stack

📝 Walkthrough

Walkthrough

Adds shared ONNX utilities, numerous DLA-focused graph transforms, a centralized surgery runner, a multi-stage DLA compatibility pipeline, refactors to in-memory transforms with wrappers, and extensive pytest suites validating structural rewrites and ORT parity.

Changes

DLA Graph Surgery and Orchestration

Layer / File(s) Summary
Shared utilities and dtype helpers
modelopt/onnx/graph_surgery/dla_transforms/_common.py, .../_dla_graph_helpers.py, .../onnx_dtypes.py
GraphCache indices, shape inference, graph edit helpers, dtype maps/parsers.
Core DLA transforms
modelopt/onnx/graph_surgery/dla_transforms/*
Adds many transforms: reshape 5D→4D, cast fixes, constants→initializers, rank-4 conversions, LSTM decomposition, cleanup, QDQ handling/removal, MatMul→Conv, Not/Where/Greater rewrites, etc.
Orchestration, pipelines, and refactors
modelopt/onnx/graph_surgery/__init__.py, .../make_dla_compatible.py, .../dq_transpose.py, .../encoder_cross_kv.py, .../gqa_replacement.py, .../utils/dtype_conversion.py
Adds run_graph_surgery with registry, DLA pipeline orchestrator, and refactors to in-memory transforms with file I/O wrappers.
Tests / Validation
tests/unit/dla_transforms/*
Comprehensive pytest coverage ensuring structural rewrites and ORT parity for all new transforms and cleanup steps.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant CLI as User/CLI
  participant RG as run_graph_surgery
  participant REG as SurgeryRegistry
  participant TR as TransformFn
  participant IO as Loader/Saver
  CLI->>RG: request(surgery_name, model_path, output_path, kwargs)
  RG->>IO: load ONNX model (+external data)
  RG->>REG: resolve surgery_name→transform
  RG->>TR: call(model, verbose, kwargs)
  TR-->>RG: modified model
  RG->>IO: save model (+external data, infer shapes)
  IO-->>CLI: output_path
Loading

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch dla_surgeon_integration

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented May 8, 2026

PR Preview Action v1.8.1

QR code for preview link

🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1415/

Built to branch gh-pages at 2026-05-08 09:34 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note

Due to the large number of review comments, Critical severity comments were prioritized as inline comments.

🟠 Major comments (27)
modelopt/onnx/graph_surgery/dla_transforms/dla_remove_qdq.py-194-199 (1)

194-199: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Potential bug: only the first matching input is updated for multi-input consumers.

If a consumer node uses qdq_node_pair_input in multiple input slots (e.g., Add(x, x) or Concat with repeated inputs), the break statement causes only the first occurrence to be renamed to graph_output.name, leaving subsequent references dangling.

🐛 Proposed fix: remove the break to update all matching inputs
                     for j, input_name in enumerate(consumer.input):
                         if input_name == qdq_node_pair_input:
                             consumer.input[j] = graph_output.name
-                            break
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/onnx/graph_surgery/dla_transforms/dla_remove_qdq.py` around lines
194 - 199, The loop that updates consumer inputs only replaces the first
matching occurrence of qdq_node_pair_input because of the break inside the inner
loop; remove the break so all input slots that equal qdq_node_pair_input are
updated to graph_output.name. Locate the block using
consumer_node_map.get(qdq_node_pair_input, []), iterate over each consumer and
its consumer.input list, and replace every matching input_name with
graph_output.name (no early exit) to handle multi-input consumers like Add(x,x)
or repeated Concat entries.
modelopt/onnx/graph_surgery/dla_transforms/dla_constants_to_initializers.py-50-52 (1)

50-52: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

sparse_value path is using the wrong ONNX proto type and will crash conversion.

At line 51, numpy_helper.to_array() is called with attr.sparse_tensor (SparseTensorProto), but that API is documented to accept only TensorProto. Passing a SparseTensorProto will raise AttributeError or TypeError, which bypass your except ValueError guard (line 77), aborting the transform for any model using sparse_value constants.

🐛 Proposed fix
        if attr.name == "sparse_value":
-           arr = numpy_helper.to_array(attr.sparse_tensor)
-           return numpy_helper.from_array(arr, name=out_name)
+           sparse = attr.sparse_tensor
+           values = numpy_helper.to_array(sparse.values)
+           indices = numpy_helper.to_array(sparse.indices).astype(np.int64)
+           dense = np.zeros(tuple(sparse.dims), dtype=values.dtype)
+           if indices.ndim == 1:
+               dense.reshape(-1)[indices] = values
+           else:
+               dense[tuple(indices.T)] = values
+           return numpy_helper.from_array(dense, name=out_name)
-        except ValueError as exc:
+        except (ValueError, TypeError, AttributeError) as exc:
             if verbose:
                 logger.warning("%s", exc)
             continue

Also applies to: 77-80

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/onnx/graph_surgery/dla_transforms/dla_constants_to_initializers.py`
around lines 50 - 52, The sparse_value branch is incorrectly calling
numpy_helper.to_array on attr.sparse_tensor (a SparseTensorProto); instead
extract the SparseTensorProto fields, convert the sparse values and indices to
numpy arrays using numpy_helper.to_array(attr.sparse_tensor.values) and
numpy_helper.to_array(attr.sparse_tensor.indices), allocate a dense numpy array
of shape attr.sparse_tensor.dims with the proper dtype, scatter the values into
the dense array at the indices (e.g. dense[tuple(indices.T)] = values), and then
return numpy_helper.from_array(dense, name=out_name); apply the same pattern to
the similar code path around the later block (the code around the except/lines
77-80) that also handles sparse_value.
modelopt/onnx/graph_surgery/dla_transforms/dla_unsqueeze.py-28-49 (1)

28-49: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Potential IndexError when Unsqueeze has no consumers.

At line 46, node_input_pair[0][0] will raise an IndexError if get_consumers(node.output[0]) returns an empty list. This can happen if the Unsqueeze output is a graph output or is otherwise unused.

Suggested fix
 def _apply_unsqueeze(model):
     """Transform Unsqueeze node when data is of type int32 or int64."""
     cache = GraphCache(model.graph)
     nodes_to_add_at_pos = []
     for node in cache.nodes_by_op("Unsqueeze"):
         data_dtype = cache.get_dtype(node.input[0])
         if data_dtype in (TensorProto.INT32, TensorProto.INT64):
+            node_input_pair = cache.get_consumers(node.output[0])
+            if not node_input_pair:
+                # No consumers to rewire; skip this node
+                continue
             cast_output_name = f"{node.output[0]}_float32"
             cast_node = helper.make_node(
                 "Cast",
                 inputs=[node.output[0]],
                 outputs=[cast_output_name],
                 to=TensorProto.FLOAT,
                 name=f"{node.output[0]}_cast_to_float32",
             )
-            node_input_pair = cache.get_consumers(node.output[0])
             for consumer, idx in node_input_pair:
                 consumer.input[idx] = cast_output_name
             nodes_to_add_at_pos.append(([cast_node], node_input_pair[0][0]))
     for nodes_to_add, node_pos in nodes_to_add_at_pos:
         insert_nodes_at_position(model.graph, nodes_to_add, node_pos)
     return model
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/onnx/graph_surgery/dla_transforms/dla_unsqueeze.py` around lines 28
- 49, The Unsqueeze transform in _apply_unsqueeze assumes
cache.get_consumers(node.output[0]) returns at least one consumer, causing an
IndexError when the output has no consumers (e.g., graph output); change the
logic to handle an empty node_input_pair: after computing node_input_pair =
cache.get_consumers(...), only iterate and rewrite consumer.input when
node_input_pair is non-empty, and choose node_pos = node_input_pair[0][0] when
present but fall back to the Unsqueeze node itself (the original node variable)
when node_input_pair is empty before appending to nodes_to_add_at_pos and before
calling insert_nodes_at_position; this avoids indexing into an empty list and
ensures insertion still occurs in the correct place.
modelopt/onnx/graph_surgery/dla_transforms/_dla_graph_helpers.py-178-198 (1)

178-198: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Potential crash when x_zero_point is None.

At lines 185-190, x_zero_point.dtype is accessed to determine int_max and int_min before the None check at line 191. If get_initializer_by_name returns None, this will raise an AttributeError.

Suggested fix
 def calculate_clip_range(node, model):
     x_scale = get_initializer_by_name(model, node.input[1])
     if x_scale is None:
         x_scale = get_constant_by_name(model, node.input[1])
     x_zero_point = get_initializer_by_name(model, node.input[2])
     if x_scale is None:
         raise ValueError(f"{node.name} should have x_scale value")
+    if x_zero_point is None:
+        logger.info("x_zero_point is None!")
+        x_zero_point = np.array(0, dtype=np.int32)
+        int_max = np.int32(127)
+        int_min = np.int32(-128)
+    else:
-    int_max = np.int32(
-        65535 if x_zero_point.dtype == np.uint16 else 255 if x_zero_point.dtype == np.uint8 else 127
-    )
-    int_min = np.int32(
-        0 if x_zero_point.dtype == np.uint16 else 0 if x_zero_point.dtype == np.uint8 else -128
-    )
-    if x_zero_point is None:
-        logger.info("x_zero_point is None!")
-        x_zero_point = np.array(0, dtype=np.int32)
-    else:
+        int_max = np.int32(
+            65535 if x_zero_point.dtype == np.uint16 else 255 if x_zero_point.dtype == np.uint8 else 127
+        )
+        int_min = np.int32(
+            0 if x_zero_point.dtype == np.uint16 else 0 if x_zero_point.dtype == np.uint8 else -128
+        )
         x_zero_point = x_zero_point.astype(np.int32)
     clip_min = ((int_min - x_zero_point) * x_scale).astype(np.float32)
     clip_max = ((int_max - x_zero_point) * x_scale).astype(np.float32)
     return clip_min, clip_max
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/onnx/graph_surgery/dla_transforms/_dla_graph_helpers.py` around
lines 178 - 198, calculate_clip_range accesses x_zero_point.dtype before
checking for None which can crash; modify the function so the None check for
x_zero_point happens immediately after retrieving it (and before computing
int_max/int_min), e.g. if x_zero_point is None set it to np.array(0,
dtype=np.int32) (or otherwise ensure a valid dtype), then cast x_zero_point to
np.int32 and only then compute int_max and int_min and the final
clip_min/clip_max; keep references to the same symbols: calculate_clip_range,
x_scale, x_zero_point, int_max, int_min.
modelopt/onnx/graph_surgery/dla_transforms/dla_not.py-76-110 (1)

76-110: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Use a fallback id when node.name is empty.

node.name is optional in ONNX. For anonymous Not nodes, this generates the same *_clip_min, *_clip_max, *_one, and *_clipped tensor names on every rewrite, so multiple matches collide and can invalidate the graph.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/onnx/graph_surgery/dla_transforms/dla_not.py` around lines 76 - 110,
The code currently builds names using node.name which can be empty; change the
naming to use a fallback unique id when node.name is falsy (e.g., use
node.output[0] or a generated UUID/counter) before constructing min_name,
max_name, one_name, clip_out and the node.name used for clip_node/sub_node;
ensure add_unique_initializers, helper.make_node calls and the name used to
create value_info all use that fallback identifier so anonymous Not nodes
produce unique tensor and node names on each rewrite.
modelopt/onnx/graph_surgery/dla_transforms/dla_remove_reshapes.py-39-44 (1)

39-44: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Skip reshape-chain folding when endpoint shapes aren't concrete.

get_tensor_shape_by_name(...) can be unknown or symbolic on dynamic models. Converting that straight into np.array(..., dtype=np.int64) will raise before any rewrite happens. Guard both shape lookups and leave the chain untouched when either endpoint shape is not fully resolved.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/onnx/graph_surgery/dla_transforms/dla_remove_reshapes.py` around
lines 39 - 44, The code converts endpoint shapes from
get_tensor_shape_by_name(...) directly into np.array(dtype=np.int64), which will
raise on unknown/symbolic shapes; update the logic where reshape_chain endpoints
are fetched (use reshape_chain[0].input[0] and reshape_chain[-1].output[0] via
get_tensor_shape_by_name(model.graph, ...)) to first verify each returned shape
is concrete (not None and every dimension is an integer/fully resolved) and only
then build np.array(dtype=np.int64); if either endpoint shape is non‑concrete,
skip folding the reshape_chain and leave it untouched.
modelopt/onnx/graph_surgery/dla_transforms/dla_squeeze_unsqueeze_to_reshape.py-67-89 (1)

67-89: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Validate the axes initializer before reading it, and require an exact one-axis pattern.

numpy_helper.to_array(...) is called before the None check, so non-initializer axes inputs fail with an opaque exception. Also, only axes_array[0] is validated, which lets multi-axis cases like [0, 2] slip through and be rewritten incorrectly.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@modelopt/onnx/graph_surgery/dla_transforms/dla_squeeze_unsqueeze_to_reshape.py`
around lines 67 - 89, In transform_squeeze_unsqueeze_to_reshape, validate
initializer objects before calling numpy_helper.to_array and enforce a
single-axis pattern: first ensure squeeze_axes =
get_initializer_object_by_name(model, node.input[1]) is not None before
converting, then convert to array and require squeeze_axes_array.ndim == 1 and
squeeze_axes_array.size == 1 and squeeze_axes_array[0] in [0,1,2]; do the same
for unsqueeze_axes (ensure consumer.input[1] initializer exists, then to_array,
then require size == 1 and unsqueeze_axes_array[0] == -1). Use the existing
variable names (squeeze_axes, squeeze_axes_array, unsqueeze_axes,
unsqueeze_axes_array) and functions (get_initializer_object_by_name,
numpy_helper.to_array) to locate and apply the checks.
modelopt/onnx/graph_surgery/dla_transforms/dla_greater.py-79-89 (1)

79-89: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Deduplicate inserted Casts for shared tensors.

cast_out = f"{inp}_to_float32" is reused blindly. If the same non-float tensor feeds both inputs of one Greater or multiple Greater nodes, this will create duplicate Cast nodes/output names. Reuse an existing cast or uniquify the generated names.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/onnx/graph_surgery/dla_transforms/dla_greater.py` around lines 79 -
89, The code blindly creates a Cast node named via cast_out/ cast_node for any
non-float input, causing duplicate Casts when the same tensor is reused; before
creating a new Cast for inp in the Greater transform, search the existing graph
nodes for an output matching cast_out (or maintain a local map of
already-created casts keyed by inp) and if found reuse that output name instead
of creating a new node; only call insert_nodes_at_position(graph, [cast_node],
node) and assign node.input[idx] = cast_out when no existing cast was found,
otherwise set node.input[idx] to the existing cast's output; reference cast_out,
cast_node, insert_nodes_at_position, and node.input in the Greater transform.
modelopt/onnx/graph_surgery/dla_transforms/dla_not.py-60-74 (1)

60-74: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Avoid rewriting a shared Cast in place.

This changes the producer Cast to FLOAT for every matching Not. If that Cast output fans out to another bool consumer, the other branch now receives FLOAT instead of BOOL. Either require a single consumer here or clone the Cast for the rewritten path.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/onnx/graph_surgery/dla_transforms/dla_not.py` around lines 60 - 74,
The current loop over cache.nodes_by_op("Not") mutates the producer Cast node in
place (producer_node) by changing its "to" attribute and its output value_info
(cast_out_vi), which breaks other consumers; instead, for each Not node ensure
the Cast has a single consumer or create a cloned Cast node for this Not path:
if cache.get_consumers(producer_node.output[0]) returns exactly one (the Not),
you may safely change the existing producer_node.attribute 'to' and its
cast_out_vi.elem_type to TensorProto.FLOAT; otherwise, create a new Cast node
(clone of producer_node with 'to' set to FLOAT), insert it into the graph,
update the Not node.input to point to the clone.output, and add/update the
cloned output value_info via get_tensor_value_info_by_name(model,
clone.output[0]) to TensorProto.FLOAT so the original Cast and its other
consumers remain unchanged.
modelopt/onnx/graph_surgery/dla_transforms/dla_remove_deqlin.py-81-97 (1)

81-97: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Don't bypass the protected-weight skip rules on fan-out.

When dq_output has more than one consumer, _should_skip() returns False immediately. That means a DequantizeLinear feeding a protected weight use like Conv/MatMul plus any second consumer will still be folded, even though the logic below explicitly says that weight path must stay quantized. Check all consumers for protected uses before deciding to remove the DQ.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/onnx/graph_surgery/dla_transforms/dla_remove_deqlin.py` around lines
81 - 97, The current logic returns False immediately when dq_output has multiple
consumers, which lets a DequantizeLinear feeding a protected weight use slip
through; instead iterate over all entries in consumer_map.get(dq_output, []) and
for each consumer_node/input_idx apply the same protected-weight checks
(is_weight = input_idx == 1, op in {"Conv","ConvTranspose","MatMul","Gemm"} and
zp_dtype membership in _CONV_SKIP_ZP/_CONVT_SKIP_ZP/_MATMUL_SKIP_ZP). For
MatMul/Gemm call check_to_apply_transpose(consumer_node, model, shape_map)
inside a try/except and treat any consumer that matches as a reason to skip
(return True). Only if none of the consumers are protected should the function
continue returning False or proceed with folding.
modelopt/onnx/graph_surgery/dla_transforms/dla_handle_qdq.py-101-107 (1)

101-107: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Adjust the implicit Q/DQ axis too.

When axis is omitted, Q/DQ still uses the default channel axis of 1. After prepending singleton dims, that referenced dimension moves, but this code only shifts explicit attributes. Per-channel quantization that relies on the default axis will therefore point at the wrong dimension after wrapping.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/onnx/graph_surgery/dla_transforms/dla_handle_qdq.py` around lines
101 - 107, The code only shifts explicit "axis" attributes but ignores the
implicit default axis=1 used by Q/DQ; update the logic in the node attribute
handling (the has_axis/orig_axis/set_node_attr_i block) to account for the
omitted-attribute case by treating missing axis as 1 and writing back an
adjusted axis (use get_node_attr_i(node, "axis", 1) to read and
set_node_attr_i(node, "axis", orig_axis + delta) when orig_axis >= 0), i.e., if
has_axis is false still compute orig_axis = 1, add delta, and set the attribute
so per-channel quantization now references the moved dimension.
modelopt/onnx/graph_surgery/dla_transforms/dla_greater.py-63-68 (1)

63-68: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Convert any non-FLOAT initializer operand, not just INT32/INT64.

The docstring says every Greater input should end up FLOAT, but this branch skips FLOAT16, DOUBLE, INT8, etc. Those static operands remain non-float32 while dynamic tensors are cast, so the transform is incomplete for initializer-backed thresholds.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/onnx/graph_surgery/dla_transforms/dla_greater.py` around lines 63 -
68, The transform currently only casts initializers of integer types because of
the check "if dtype not in _INT_TYPES: continue"; change this to treat any
non-float32 initializer as eligible for casting: replace the dtype guard with
"if dtype == np.float32: continue" (or equivalent) so FLOAT16/DOUBLE/INT8/etc.
are not skipped, then call cache.get_init_array(inp), astype(np.float32) and
write the new array back into the initializer storage (use the cache's setter
method such as cache.set_init_array or the repository's initializer-replace
helper) so the initializer backing 'Greater' inputs is actually replaced with a
float32 tensor; keep the existing assert init_arr is not None and update any
variable names (init_arr/arr) accordingly.
modelopt/onnx/graph_surgery/dla_transforms/dla_fix_instancenorm_channel_mismatch.py-112-139 (1)

112-139: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Preserve the original tensor dtype in the injected value_info.

InstanceNormalization outputs the same element type as its input, but instrnorm_out and post_out are always recorded as FLOAT. On FP16 graphs that makes the rewritten metadata lie about the intermediate dtype and can confuse later passes that read value_info.

Suggested fix
-        add_value_info(graph, instrnorm_out, onnx.TensorProto.FLOAT, [n, c, d, 1])
+        add_value_info(graph, instrnorm_out, in_dtype, [n, c, d, 1])
...
-        add_value_info(graph, post_out, onnx.TensorProto.FLOAT, [1, n, c, d])
+        add_value_info(graph, post_out, in_dtype, [1, n, c, d])
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@modelopt/onnx/graph_surgery/dla_transforms/dla_fix_instancenorm_channel_mismatch.py`
around lines 112 - 139, The injected value_info for instrnorm_out and post_out
always uses onnx.TensorProto.FLOAT, which misreports dtype on FP16 graphs;
change both add_value_info calls to use the original input element type (the
in_dtype computed via cache.get_dtype(in_name) or equivalent) so
InstanceNormalization's output and the subsequent reshape preserve the true
dtype metadata (update the add_value_info calls for instrnorm_out and post_out
to pass in_dtype rather than onnx.TensorProto.FLOAT).
modelopt/onnx/graph_surgery/dla_transforms/dla_cast_to_fp32.py-77-81 (1)

77-81: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Update graph.output types when a rewritten Cast is a model output.

This only patches graph.value_info. If node.output[0] is also declared in graph.output, the model metadata still reports the old elem type even though the Cast now returns FLOAT.

Suggested fix
-                for vi in graph.value_info:
+                for vi in list(graph.value_info) + list(graph.output):
                     if vi.name == node.output[0]:
                         vi.type.tensor_type.elem_type = TensorProto.FLOAT
-                        break
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/onnx/graph_surgery/dla_transforms/dla_cast_to_fp32.py` around lines
77 - 81, The code updates graph.value_info but misses updating graph.output when
the Cast output is also a declared model output; modify the loop (or add a
second loop) to check graph.output entries and set their type to
TensorProto.FLOAT when vi.name == node.output[0] (i.e., update
graph.output[i].type.tensor_type.elem_type as well as graph.value_info) so the
model metadata reflects the rewritten Cast output; ensure you reference
node.output[0], graph.value_info, and graph.output and set elem_type =
TensorProto.FLOAT consistent with the existing change.
modelopt/onnx/graph_surgery/dla_transforms/dla_convert_ops_to_4d.py-247-272 (1)

247-272: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Handle Transpose nodes that rely on the default permutation.

ONNX treats a missing perm as reversing the axes. This handler returns None when the attribute is absent, so default-form Transpose nodes stay non-4D and escape the pass entirely.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/onnx/graph_surgery/dla_transforms/dla_convert_ops_to_4d.py` around
lines 247 - 272, The code returns None when a Transpose node lacks a "perm"
attribute, but ONNX semantics treat a missing perm as the reversed axes; update
the handler so if no attr with name "perm" is found you construct perm =
list(range(rank-1, -1, -1)) (i.e., default reverse permutation) and then apply
the existing rank-specific logic (using rank, orig_shape, idx_to_remove, etc.),
set perm_updated = True and write the computed new_perm back into attr.ints (or
create the attribute entry before writing) so default-form Transpose nodes are
handled like explicit perms.
modelopt/onnx/graph_surgery/dla_transforms/dla_decompose_lstm.py-773-915 (1)

773-915: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Fail fast on unsupported LSTM features instead of rewriting them with default semantics.

This pass hardcodes the default ONNX LSTM behavior, but it never validates layout, sequence_lens, P, clip, input_forget, or custom activations / activation params. Models using any of those features are currently decomposed as if they were absent, which silently changes results.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/onnx/graph_surgery/dla_transforms/dla_decompose_lstm.py` around
lines 773 - 915, The transform silently ignores unsupported LSTM features; add
explicit validation early in the per-node handling (before
transpose_inputs/split_direction_wise) to fail-fast when unsupported attributes
are present: inspect node.attribute and check for non-default/unsupported values
of "layout", "sequence_lens", "P" (presence), "clip" (non-zero), "input_forget"
(true), and custom "activations"/"activation_alpha"/"activation_beta" (anything
other than the default activation set); if any unsupported feature is found,
raise a clear ValueError referencing the node.name; implement this as a small
helper (e.g., validate_lstm_supported(node)) and call it right after computing
direction/shape_info_map so the rest of the code (transpose_inputs,
split_direction_wise, add_nodes_for_time_step) can assume only the supported
default semantics.
modelopt/onnx/graph_surgery/dla_transforms/dla_matmul_to_transpose_conv_transpose.py-76-151 (1)

76-151: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Preserve shared weights instead of reshaping them in place.

Both the eligible path and the fallback path rewrite the original initializer proto directly. If that weight is shared by another MatMul/Gemm or any unrelated op, the other consumer will now see Conv-shaped data or lose the expected tensor layout entirely. Please clone/repoint single-LSTM consumers instead of mutating shared initializers in place.

Also applies to: 187-197

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@modelopt/onnx/graph_surgery/dla_transforms/dla_matmul_to_transpose_conv_transpose.py`
around lines 76 - 151, The update_initializers function currently mutates the
original initializer proto in-place which breaks other consumers; instead detect
if the initializer (initializer.name / name) has multiple consumers in
graph.node and if so clone it: create a new unique tensor name, create a new
initializer proto from reshaped_arr via numpy_helper.from_array (do not call
initializer.CopyFrom on the original), add the new initializer to
graph.initializer, and repoint only the MatMul/Gemm/DequantizeLinear node inputs
(the node found by dq_node or the node being transformed) to the new name; when
the initializer has a single consumer you may keep the in-place replacement as
before. Apply the same cloning logic to the other similar block mentioned (lines
~187-197) so shared weights are never mutated in-place.
modelopt/onnx/graph_surgery/dla_transforms/dla_convert_ops_to_4d.py-989-997 (1)

989-997: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Pad Tile.repeats for every promoted rank, not just 3D.

This handler unsqueezes any non-4D input to 4D, but it only expands the repeats initializer when its length is 3. A 2D Tile will therefore get a 4D input with a 2-element repeats vector, which is invalid.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/onnx/graph_surgery/dla_transforms/dla_convert_ops_to_4d.py` around
lines 989 - 997, The repeats initializer padding only handles len==3; change the
logic in the Tile-promoting block (the code that checks node.input[1] in
init_names and iterates graph.initializer) to pad any repeats array with leading
ones until its length equals 4 instead of only handling 3D: read init via
numpy_helper.to_array(init) into arr, if arr.ndim==1 and len(arr) < 4 create
new_arr = np.concatenate([np.ones(4-len(arr), dtype=arr.dtype),
arr]).astype(arr.dtype) and then init.CopyFrom(numpy_helper.from_array(new_arr,
init.name)); keep existing break and other checks intact so 1D/2D/3D repeats are
all promoted correctly when promoting inputs to 4D.
modelopt/onnx/graph_surgery/dla_transforms/dla_matmul_to_transpose_conv_transpose.py-185-196 (1)

185-196: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Shift the DQ channel axis in the fallback 4D-promotion path.

This branch pads the quantized source initializer to 4D for a Const -> DequantizeLinear weight, but it never updates the matching DequantizeLinear.axis. Per-channel scales/zero-points will target the wrong dimension after promotion.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@modelopt/onnx/graph_surgery/dla_transforms/dla_matmul_to_transpose_conv_transpose.py`
around lines 185 - 196, The fallback 4D-promotion pads a const initializer but
doesn't adjust any DequantizeLinear.axis, so per-channel scales/zero-points will
point to the wrong dim; after you compute arr and new_shape in the branch that
reshapes the initializer, find all DequantizeLinear nodes in graph whose input
(the quantized input) equals init_name and if they have an 'axis' attribute
increment that axis by (len(new_shape) - arr.ndim) (i.e. the number of leading
dims you added) — update the attribute in-place on those nodes so the axis
refers to the same logical channel after promotion.
modelopt/onnx/graph_surgery/dla_transforms/dla_matmul_to_transpose_conv_transpose.py-248-345 (1)

248-345: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Reject non-default Gemm.alpha/beta or fold them into the rewrite.

The conversion only honors transA, transB, and optional bias. Gemm also applies alpha * A'B' + beta * C; any node with alpha != 1 or beta != 1 will change numerics after being replaced with Conv.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@modelopt/onnx/graph_surgery/dla_transforms/dla_matmul_to_transpose_conv_transpose.py`
around lines 248 - 345, The transform currently ignores Gemm attributes alpha
and beta; inspect node.attribute for "alpha" and "beta" near the top of this
rewrite (alongside transA/transB) and if either is not 1.0 either (a) reject the
rewrite by raising a ValueError referencing node.name, or (b) fold them into the
Conv inputs: multiply the weight initializer by alpha (use update_initializers
or modify the initializer named by init_name when is_const_dq is True; for
dynamic weights insert a Mul before Conv using conv_input1), and apply beta to
the bias/C input (if C is present and constant scale the bias tensor, otherwise
insert a Mul on the bias/C input before adding to conv). Ensure the chosen
solution handles both static and dynamic weight/bias paths and updates
conv_inputs and initializer updates consistently.
modelopt/onnx/graph_surgery/dla_transforms/dla_decompose_lstm.py-81-105 (1)

81-105: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Don't delete or rewrite initializers that may still be shared.

These helpers replace/remove the original initializer unconditionally after creating the LSTM-specific 4D/transposed version. If another node still references the original tensor, that consumer is left with a missing name or with a layout that was only valid for this decomposition.

Also applies to: 107-120, 304-349

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/onnx/graph_surgery/dla_transforms/dla_decompose_lstm.py` around
lines 81 - 105, The code creates a new 4D LSTM-specific initializer (via
_expand_array_to_rank4 and numpy_helper.from_array) and then unconditionally
removes the original initializer (graph.initializer.remove(initializer)) which
may still be referenced by other nodes; instead, after creating new_init and
adding it with add_unique_initializers, check whether any other graph.node
(excluding the consumer you just rewired, e.g., deq_node or the LSTM consumer in
the other ranges) still references initializer.name — if no other references
exist, remove it; otherwise leave the original initializer in place and only set
the specific consumer's input to new_init.name. Apply the same pattern for the
other affected code blocks (the ranges around 107-120 and 304-349) and use
functions/variables like cache.get_init, _expand_array_to_rank4,
add_unique_initializers, graph.initializer.remove, and the consumer node inputs
to locate the spots to change.
modelopt/onnx/graph_surgery/dla_transforms/dla_convert_ops_to_4d.py-439-475 (1)

439-475: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Shift implicit Slice axes after promoting the input to 4D.

When axes is omitted, ONNX interprets it as [0..len(starts)-1]. After the input is unsqueezed, those implied axes need the same delta as explicit axes; otherwise a 2D/3D Slice will slice the prepended singleton dims instead of the original data dims.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/onnx/graph_surgery/dla_transforms/dla_convert_ops_to_4d.py` around
lines 439 - 475, The Slice transform fails to shift implicit axes (when
node.input[3] is missing/empty) after unsqueezing; update the block handling
axes so if node.input[3] is absent or not in init_names you read the starts
initializer (node.input[1]) to get its length, build an axes array as
np.arange(len(starts)), apply delta = _axis_delta(rank) to each axis, create a
new initializer with numpy_helper.from_array (reusing cnt_ref and inits_to_add
like the explicit-axes path), set node.input[3] to the new init name, and
otherwise keep the existing explicit-axes code path (use same symbols: node,
node.input, graph.initializer, starts initializer, _axis_delta, cnt_ref,
inits_to_add, numpy_helper.from_array).
modelopt/onnx/graph_surgery/dla_transforms/dla_convert_ops_to_4d.py-552-615 (1)

552-615: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Treat axis-less Reduce as “reduce all dims” in the post-squeeze.

If a Reduce* node omits axes, adjusted_axes stays empty here. For keepdims=0, the fallback squeeze then removes only the promoted leading dims, so reduce-all cases come back as [1, ...] shapes instead of the original scalar / lower-rank output.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/onnx/graph_surgery/dla_transforms/dla_convert_ops_to_4d.py` around
lines 552 - 615, The post-squeeze currently misses the case where a Reduce node
has no axes (meaning "reduce all dims"), so when orig_keepdims==0 and
adjusted_axes is empty the code only squeezes promoted leading dims; to fix,
when axes are omitted set adjusted_axes = [i + delta for i in range(rank)] (use
the existing delta computed by delta = _axis_delta(rank)) before the Squeeze
logic so the squeeze will remove all original reduced dimensions; update the
block that computes adjusted_axes (the branch after checking node.input[1] and
the attribute branch) to populate adjusted_axes accordingly when no axes were
found so the later Squeeze/Sq_axes computation (variables adjusted_axes,
orig_keepdims, leading, sq_axes, temp_out) behaves correctly.
modelopt/onnx/graph_surgery/make_dla_compatible.py-178-220 (1)

178-220: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

dla_make_dla_compatible no longer matches the exported surgery API.

This helper is re-exported as make_dla_compatible, but unlike the other public surgeries it neither accepts output_path nor saves the result. The example in modelopt/onnx/graph_surgery/__init__.py calls it with output_path=..., which will currently raise TypeError. Please either make this a normal file-to-file wrapper or keep it private and steer callers to run_graph_surgery.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/onnx/graph_surgery/make_dla_compatible.py` around lines 178 - 220,
The public helper dla_make_dla_compatible currently only takes model_path and
returns a ModelProto, which mismatches the exported surgery API that expects an
output_path/file-to-file wrapper; update dla_make_dla_compatible to accept an
optional output_path: Optional[str] = None (and keep verbose) and after calling
_transform_make_dla_compatible(model, ...) save the transformed model to
output_path when provided (use onnx.save / appropriate external data flags) and
still return the ModelProto; reference the existing symbols model,
_transform_make_dla_compatible, and logger to locate where to load, transform,
optionally save, and return the model.
modelopt/onnx/graph_surgery/dla_transforms/dla_5d_reshape_to_4d.py-57-66 (1)

57-66: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Reject fan-out before rewriting the follower chain.

_parse_middleware_strict() currently takes cands[0] when cur has multiple first-input consumers. In a branched graph that means this pass rewrites/removes only one branch and leaves the others pointing at deleted tensors or nodes. Please require a single consumer here (or skip the candidate) instead of choosing one arbitrarily.

💡 Proposed fix
         cands = _consumers_first_input(graph, cur)
         if not cands:
             msg = (
                 f"{path_context}: no consumer for tensor {cur!r}; "
                 f"expected middleware (Clip / Q→DQ) or terminal in {sorted(stop_op_types)}"
             )
             raise ValueError(msg)
+        if len(cands) != 1:
+            msg = (
+                f"{path_context}: expected a single first-input consumer for tensor {cur!r}, "
+                f"got {[n.name for n in cands]}"
+            )
+            raise ValueError(msg)
         n = cands[0]
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/onnx/graph_surgery/dla_transforms/dla_5d_reshape_to_4d.py` around
lines 57 - 66, The _parse_middleware_strict() loop currently picks cands[0] from
_consumers_first_input(graph, cur) which allows fan-out to be silently ignored
and later leaves other branches pointing at removed tensors; change the logic to
require exactly one first-input consumer (len(cands) == 1) before treating it as
the next follower, otherwise skip this candidate (or treat it as non-middleware)
and continue/raise as appropriate so you do not arbitrarily pick cands[0];
update references to cands, cur and the stop_op_types check to only act when a
single consumer exists.
modelopt/onnx/graph_surgery/make_dla_compatible.py-162-171 (1)

162-171: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Don't continue after onnx.checker.ValidationError.

If the transformed graph fails ONNX validation, returning it as “done” makes the failure surface later during save/load/runtime instead of at the transform boundary. Please re-raise ValidationError here and only treat non-validation checker failures as skippable.

💡 Proposed fix
     try:
         onnx.checker.check_model(model)
     except onnx.checker.ValidationError as exc:
-        logger.warning(
-            "[DLA pipeline] ONNX check reported issues (model may still be usable): %s", exc
-        )
+        logger.error("[DLA pipeline] ONNX validation failed: %s", exc)
+        raise
     except Exception as exc:
         logger.debug(
             "[DLA pipeline] ONNX in-memory check skipped (model may be too large): %s", exc
         )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/onnx/graph_surgery/make_dla_compatible.py` around lines 162 - 171,
The current onnx.checker.check_model call swallows ValidationError and logs a
warning but continues; instead, when onnx.checker.ValidationError is caught
(from the check_model call), re-raise it so the transform fails fast (do not
return/continue). Keep the broader Exception catch for other non-validation
issues as skippable and logged with logger.debug; specifically adjust the except
onnx.checker.ValidationError branch in the block around onnx.checker.check_model
to re-raise the caught ValidationError (while preserving any logging if desired)
and only treat general Exception as the skippable case.
tests/unit/dla_transforms/test_dla_convert_ops_to_4d.py-161-181 (1)

161-181: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Compare non-floating outputs as well.

These helpers skip value checks for every non-float dtype, so the ArgMax cases currently pass as long as the shape matches. That leaves the rewritten indices completely unverified.

💡 Proposed fix
 def _assert_match(ref: dict, got: dict) -> None:
     assert ref.keys() == got.keys()
     for name in ref:
         a, b = ref[name], got[name]
         assert a.shape == b.shape, f"{name}: shape {a.shape} vs {b.shape}"
         if a.dtype.kind == "f":
             cos = _cosine_similarity(a, b)
             assert cos >= 1.0 - 1e-5, f"{name}: cosine {cos:.8f}"
             np.testing.assert_allclose(a, b, rtol=1e-4, atol=1e-4, err_msg=name)
+        else:
+            np.testing.assert_array_equal(a, b, err_msg=name)
 
 
 def _assert_values_match(ref: dict, got: dict) -> None:
     """Element-wise comparison ignoring shape (for ops whose output rank changes)."""
     assert ref.keys() == got.keys()
     for name in ref:
         a, b = ref[name].ravel(), got[name].ravel()
         assert a.shape == b.shape, f"{name}: element count {a.shape} vs {b.shape}"
         if a.dtype.kind == "f":
             cos = _cosine_similarity(a, b)
             assert cos >= 1.0 - 1e-5, f"{name}: cosine {cos:.8f}"
             np.testing.assert_allclose(a, b, rtol=1e-4, atol=1e-4, err_msg=name)
+        else:
+            np.testing.assert_array_equal(a, b, err_msg=name)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unit/dla_transforms/test_dla_convert_ops_to_4d.py` around lines 161 -
181, The helpers _assert_match and _assert_values_match currently only validate
floating dtypes' values and skip non-float outputs; update both functions to
also compare non-floating arrays element-wise (e.g., using exact
equality/assert_array_equal) after the dtype check so integer/bool/index outputs
like ArgMax are verified; keep existing cosine/approx checks for float types and
apply the element-wise comparison on the flattened arrays in
_assert_values_match and on full-shape arrays in _assert_match.
🟡 Minor comments (2)
tests/unit/dla_transforms/test_dla_remove_qdq.py-16-20 (1)

16-20: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Remove duplicate license header.

Lines 16-20 duplicate the SPDX license header already present at lines 1-14. This appears to be a copy-paste error.

🧹 Proposed fix
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
 # Unit tests for :mod:`modelopt.onnx.graph_surgery.dla_transforms.dla_remove_qdq`.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unit/dla_transforms/test_dla_remove_qdq.py` around lines 16 - 20,
Remove the duplicated SPDX license header block (the second occurrence
containing the SPDX-FileCopyrightText and SPDX-License-Identifier) so only the
original header at the top of the file remains; locate the repeated block (the
header shown in the diff) in test_dla_remove_qdq.py and delete that duplicate
lines 16-20, keeping the rest of the test file unchanged.
modelopt/onnx/graph_surgery/dla_transforms/dla_topk.py-43-50 (1)

43-50: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Docstring/code mismatch: axis condition.

The module docstring (line 47) states the condition is GatherElements axis == 2, but the code at line 132 checks gather_axis != 1. These are different conditions. Please reconcile the documentation with the implementation.

Also applies to: 131-133

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/onnx/graph_surgery/dla_transforms/dla_topk.py` around lines 43 - 50,
The docstring claims the GatherElements axis must be 2 but the implementation
checks gather_axis != 1; reconcile them by making the condition and
documentation match: either change the runtime check around gather_axis (the
variable used when inspecting the GatherElements node) to require axis == 2, or
update the module docstring to state axis == 1. Update the text in the
module-level docstring and the conditional that references gather_axis (and any
related comments) so they are identical, and run/adjust any unit tests or
assertions that rely on the GatherElements axis check.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: abe4dc0a-2f75-4eda-839e-a5b8801873b1

📥 Commits

Reviewing files that changed from the base of the PR and between e2d29c8 and 0b7c916.

📒 Files selected for processing (44)
  • modelopt/onnx/graph_surgery/__init__.py
  • modelopt/onnx/graph_surgery/dla_transforms/_common.py
  • modelopt/onnx/graph_surgery/dla_transforms/_dla_graph_helpers.py
  • modelopt/onnx/graph_surgery/dla_transforms/dla_5d_reshape_to_4d.py
  • modelopt/onnx/graph_surgery/dla_transforms/dla_cast_to_fp32.py
  • modelopt/onnx/graph_surgery/dla_transforms/dla_constants_to_initializers.py
  • modelopt/onnx/graph_surgery/dla_transforms/dla_convert_ops_to_4d.py
  • modelopt/onnx/graph_surgery/dla_transforms/dla_decompose_lstm.py
  • modelopt/onnx/graph_surgery/dla_transforms/dla_fix_instancenorm_channel_mismatch.py
  • modelopt/onnx/graph_surgery/dla_transforms/dla_graph_cleanup.py
  • modelopt/onnx/graph_surgery/dla_transforms/dla_greater.py
  • modelopt/onnx/graph_surgery/dla_transforms/dla_handle_qdq.py
  • modelopt/onnx/graph_surgery/dla_transforms/dla_matmul_to_transpose_conv_transpose.py
  • modelopt/onnx/graph_surgery/dla_transforms/dla_not.py
  • modelopt/onnx/graph_surgery/dla_transforms/dla_remove_deqlin.py
  • modelopt/onnx/graph_surgery/dla_transforms/dla_remove_intermediary_squeeze_and_unsqueeze.py
  • modelopt/onnx/graph_surgery/dla_transforms/dla_remove_qdq.py
  • modelopt/onnx/graph_surgery/dla_transforms/dla_remove_reshapes.py
  • modelopt/onnx/graph_surgery/dla_transforms/dla_remove_unused_initializers.py
  • modelopt/onnx/graph_surgery/dla_transforms/dla_squeeze_unsqueeze_to_reshape.py
  • modelopt/onnx/graph_surgery/dla_transforms/dla_topk.py
  • modelopt/onnx/graph_surgery/dla_transforms/dla_unsqueeze.py
  • modelopt/onnx/graph_surgery/dla_transforms/dla_where.py
  • modelopt/onnx/graph_surgery/dla_transforms/onnx_dtypes.py
  • modelopt/onnx/graph_surgery/dq_transpose.py
  • modelopt/onnx/graph_surgery/encoder_cross_kv.py
  • modelopt/onnx/graph_surgery/gqa_replacement.py
  • modelopt/onnx/graph_surgery/make_dla_compatible.py
  • modelopt/onnx/graph_surgery/utils/dtype_conversion.py
  • tests/unit/dla_transforms/test_dla_5d_reshape_to_4d.py
  • tests/unit/dla_transforms/test_dla_cast_to_fp32.py
  • tests/unit/dla_transforms/test_dla_constants_to_initializers.py
  • tests/unit/dla_transforms/test_dla_convert_ops_to_4d.py
  • tests/unit/dla_transforms/test_dla_decompose_lstm.py
  • tests/unit/dla_transforms/test_dla_fix_instancenorm_channel_mismatch.py
  • tests/unit/dla_transforms/test_dla_graph_cleanup.py
  • tests/unit/dla_transforms/test_dla_greater.py
  • tests/unit/dla_transforms/test_dla_handle_qdq.py
  • tests/unit/dla_transforms/test_dla_matmul_to_transpose_conv_transpose.py
  • tests/unit/dla_transforms/test_dla_not.py
  • tests/unit/dla_transforms/test_dla_remove_deqlin.py
  • tests/unit/dla_transforms/test_dla_remove_qdq.py
  • tests/unit/dla_transforms/test_dla_topk.py
  • tests/unit/dla_transforms/test_dla_where.py

@codecov
Copy link
Copy Markdown

codecov Bot commented May 8, 2026

Codecov Report

❌ Patch coverage is 73.79467% with 924 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.62%. Comparing base (e2d29c8) to head (0b7c916).

Files with missing lines Patch % Lines
...graph_surgery/dla_transforms/_dla_graph_helpers.py 26.40% 184 Missing ⚠️
...ph_surgery/dla_transforms/dla_convert_ops_to_4d.py 79.75% 182 Missing ⚠️
...elopt/onnx/graph_surgery/dla_transforms/_common.py 56.80% 108 Missing ⚠️
...graph_surgery/dla_transforms/dla_decompose_lstm.py 78.86% 86 Missing ⚠️
.../graph_surgery/dla_transforms/dla_graph_cleanup.py 78.90% 81 Missing ⚠️
...aph_surgery/dla_transforms/dla_5d_reshape_to_4d.py 80.28% 55 Missing ⚠️
...ansforms/dla_matmul_to_transpose_conv_transpose.py 78.94% 36 Missing ⚠️
modelopt/onnx/graph_surgery/make_dla_compatible.py 39.65% 35 Missing ⚠️
...nnx/graph_surgery/dla_transforms/dla_remove_qdq.py 75.59% 31 Missing ⚠️
...t/onnx/graph_surgery/dla_transforms/onnx_dtypes.py 62.50% 21 Missing ⚠️
... and 13 more
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1415      +/-   ##
==========================================
- Coverage   77.40%   76.62%   -0.78%     
==========================================
  Files         476      496      +20     
  Lines       51319    54823    +3504     
==========================================
+ Hits        39721    42008    +2287     
- Misses      11598    12815    +1217     
Flag Coverage Δ
examples 41.72% <0.00%> (-0.24%) ⬇️
gpu 59.72% <0.00%> (-0.72%) ⬇️
unit 53.93% <73.79%> (+1.37%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant