
Fix optional-output guard in DecoderAttention/MultiHeadAttention shape inference #29243

Open
titaiwangms wants to merge 7 commits into microsoft:main from titaiwangms:fix/contrib-attn-shapeinf-oob


@titaiwangms titaiwangms commented Jun 24, 2026

Contributor

Summary

DecoderAttentionTypeAndShapeInference and MultiHeadAttentionTypeAndShapeInference (the latter shared by MultiHeadAttention and DecoderMaskedMultiHeadAttention) guarded population of the optional present/cache outputs with getNumOutputs() > 1, yet the guarded branch writes output index 2, which does not exist when a node declares exactly two outputs. These outputs are produced as a pair (both-or-neither), so the guard should require more than 2 outputs before populating outputs 1 and 2.

This changes the three guards from > 1 to > 2, matching the established pattern already used in BaseGroupQueryAttentionTypeAndShapeInference (>= 3), PagedAttentionTypeAndShapeInference, and the EmbedLayerNormalization fix in #28176.
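For illustration only, the guard change can be modeled with a small stand-in. `MockContext` and `PopulateCacheOutputs` are hypothetical names, not the onnxruntime types; the sketch assumes only what the description states, namely that the branch writes output indices 1 and 2 and must therefore see at least three declared outputs.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for an inference context: one slot per declared output.
struct MockContext {
  explicit MockContext(std::vector<std::string> slots) : outputs(std::move(slots)) {}
  std::size_t getNumOutputs() const { return outputs.size(); }
  std::vector<std::string> outputs;
};

// Populates the paired optional outputs and reports whether it did so.
// The branch writes indices 1 and 2, so the guard requires more than two
// declared outputs. With the old "> 1" guard, a two-output node would reach
// the write to index 2 and index past the end of the vector.
bool PopulateCacheOutputs(MockContext& ctx) {
  if (ctx.getNumOutputs() > 2) {
    ctx.outputs[1] = "present_key";
    ctx.outputs[2] = "present_value";
    return true;
  }
  return false;
}
```

With two declared outputs the function leaves every slot untouched; with three it fills both cache slots.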

Defensive bounds check

Also adds a bounds check in InferenceContextImpl::getOutputType (core/graph/graph.cc), consistent with the checked access already used by getInputType, so an out-of-range output index yields a clean type-inference failure instead of undefined behavior.
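A rough sketch of the kind of check described here, written against a mock rather than the real InferenceContextImpl. `MockInferenceContext`, `TypeSlot`, and `InferenceError` are hypothetical names, and the thrown exception only stands in for a type-inference failure; the actual change uses fail_type_inference, whose exact behavior is not reproduced here.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical error type standing in for a type-inference failure.
struct InferenceError : std::runtime_error {
  using std::runtime_error::runtime_error;
};

// Hypothetical placeholder for an output's type information.
struct TypeSlot {
  int elem_type = 0;
};

class MockInferenceContext {
 public:
  explicit MockInferenceContext(std::size_t num_outputs) : output_types_(num_outputs) {}

  // Validates the index before touching the vector, so an out-of-range
  // request becomes a reportable error instead of undefined behavior.
  TypeSlot* getOutputType(std::size_t index) {
    if (index >= output_types_.size()) {
      throw InferenceError("output index " + std::to_string(index) + " is out of range");
    }
    return &output_types_[index];
  }

 private:
  std::vector<TypeSlot> output_types_;
};
```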

Tests

New onnxruntime/test/contrib_ops/attention_optional_outputs_shape_inference_test.cc with 6 cases:

  • 3 cases that declare exactly 2 outputs (omitting the paired optional output) and confirm shape inference completes cleanly.
  • 3 cases that declare all present outputs and assert their element types are still inferred, proving the > 2 branch still performs legitimate KV-cache shape inference.

The tests run through graph resolution and are execution-provider-independent. They are throw-free and safe under ORT_NO_EXCEPTIONS builds.

Reference

Mirrors #28176.

titaiwangms and others added 5 commits June 23, 2026 23:49
DecoderAttention and MultiHeadAttention shape-inference functions guarded
population of present_key (output 1) and present_value (output 2) with
getNumOutputs() > 1, but write output index 2. present_key and present_value
are produced as a both-or-neither pair, so require all three outputs (> 2)
before populating them, matching BaseGroupQueryAttention (>= 3) and the
EmbedLayerNorm guard. Also add a bounds check in
InferenceContextImpl::getOutputType so an out-of-range output index fails
inference cleanly instead of indexing past the end, mirroring
DataPropagationContextImpl and getInputType.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…outputs omitted

Cover DecoderAttention, MultiHeadAttention and DecoderMaskedMultiHeadAttention
nodes declared with exactly two outputs (present_key kept, present_value
omitted). Each test builds the node and asserts Graph::Resolve() shape inference
completes cleanly. Tests are execution-provider independent and throw-free, so
they run on the default CPU build and in no-exception (ORT_NO_EXCEPTIONS) builds.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Extend the optional-present-output regression suite with cases that declare all
three outputs for DecoderAttention, MultiHeadAttention and
DecoderMaskedMultiHeadAttention and assert the present_key/present_value branch
still runs and infers their element types. Together with the two-output cases
this pins the output-count guard to exactly three.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Use each op's actual output names (DecoderAttention: new_key_cache /
new_value_cache; MultiHeadAttention: present_key / present_value), align the
three guards to a consistent '// has <names> outputs' phrasing, and note that
the two optional cache outputs are produced as a pair, so they are present only
when the node declares more than two outputs. Comment-only; no logic change.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
The MultiHeadAttention and DecoderMaskedMultiHeadAttention two-output cases only
passed a query input, so the present-output branch (which references output index
2) was never entered and the tests could not detect a regression there. Supply
shaped past_key / past_value (and past_sequence_length for MHA, buffer sharing for
DMMHA) so the branch is exercised while only two outputs are declared, matching the
DecoderAttention case which already reached that path.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

@tianleiwu tianleiwu left a comment

Contributor


Correct, minimal, and well-tested fix. The present/KV-cache branch previously entered under getNumOutputs() > 1 but unconditionally wrote both output index 1 and index 2, so a node declaring exactly two outputs hit an out-of-bounds write to index 2. Raising the three guards to > 2 matches the both-or-neither convention already used by BaseGroupQueryAttention (>= 3) and PagedAttention (which keeps > 1 but immediately enforces != 3, so it is already safe and was correctly left untouched). The getOutputType bounds check is good defense-in-depth.
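To make the comparison concrete, the two guard styles can be sketched as below. This is a simplified model built only from the description in this comment, not the actual BaseGroupQueryAttention or PagedAttention code; `GuardResult`, `StrictGuard`, and `EnterThenValidateGuard` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>

enum class GuardResult { kSkipped, kPopulated, kRejected };

// Style A: enter the cache branch only when all three outputs are declared.
GuardResult StrictGuard(std::size_t num_outputs) {
  return num_outputs > 2 ? GuardResult::kPopulated : GuardResult::kSkipped;
}

// Style B: enter when more than one output is declared, then immediately
// reject any count other than exactly three before writing anything.
GuardResult EnterThenValidateGuard(std::size_t num_outputs) {
  if (num_outputs > 1) {
    if (num_outputs != 3) {
      return GuardResult::kRejected;
    }
    return GuardResult::kPopulated;
  }
  return GuardResult::kSkipped;
}
```

In this model neither style writes index 2 for a two-output node; they differ only in whether that node is silently skipped or rejected.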

Tests are thorough: the three "omitted" cases pin the safe lower bound and the three "all-present" cases prove the > 2 branch still performs real KV-cache shape inference.

One housekeeping item: CI lintrunner / CLANGFORMAT already flagged formatting on the new test file — please run lintrunner -a before merge.

Non-blocking nit: the PR text says the new getOutputType check is "consistent with the checked access already used by getInputType", but getInputType uses vector::at() (throws std::out_of_range) while the new code uses fail_type_inference. The fail_type_inference variant is actually the nicer choice (clean type-inference error), so no change needed — just noting the mechanisms differ.
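The difference between the two mechanisms can be seen in a minimal standalone sketch. `InferenceFailure`, `GetWithAt`, and `GetWithExplicitCheck` are hypothetical names; the second function only approximates a fail_type_inference-style error with a custom exception type.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical domain-specific error, standing in for a type-inference failure.
struct InferenceFailure : std::runtime_error {
  using std::runtime_error::runtime_error;
};

// Checked access via vector::at(): an invalid index throws std::out_of_range.
int GetWithAt(const std::vector<int>& values, std::size_t index) {
  return values.at(index);
}

// Explicit bounds check: an invalid index raises the domain-specific error,
// so the caller sees an inference failure rather than a generic range error.
int GetWithExplicitCheck(const std::vector<int>& values, std::size_t index) {
  if (index >= values.size()) {
    throw InferenceFailure("output index is out of range");
  }
  return values[index];
}
```

Both reject the bad index; what changes is the error type the caller has to handle.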

Comment thread onnxruntime/core/graph/contrib_ops/bert_defs.cc
tianleiwu
tianleiwu previously approved these changes Jun 24, 2026
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…ng-convention note

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@titaiwangms titaiwangms enabled auto-merge (squash) June 25, 2026 00:59