Fix optional-output guard in DecoderAttention/MultiHeadAttention shape inference #29243
titaiwangms wants to merge 7 commits into
DecoderAttention and MultiHeadAttention shape-inference functions guarded population of present_key (output 1) and present_value (output 2) with getNumOutputs() > 1, but wrote to output index 2. present_key and present_value are produced as a both-or-neither pair, so require all three outputs (> 2) before populating them, matching BaseGroupQueryAttention (>= 3) and the EmbedLayerNorm guard. Also add a bounds check in InferenceContextImpl::getOutputType so an out-of-range output index fails inference cleanly instead of indexing past the end, mirroring DataPropagationContextImpl and getInputType. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
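The guard change described in this commit can be illustrated with a minimal, self-contained sketch. MockContext and PopulatePresentOutputs are hypothetical stand-ins invented for illustration; they are not the ONNX Runtime types or functions, and the "float" element type is an arbitrary placeholder.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for an inference context; not the ONNX Runtime type.
struct MockContext {
  std::vector<std::string> output_types;  // one slot per declared output
  std::size_t getNumOutputs() const { return output_types.size(); }
};

// Populates the two optional cache outputs only when both slots exist.
// Returns true if the present/cache branch ran.
bool PopulatePresentOutputs(MockContext& ctx) {
  // With the old guard (> 1), a node declaring exactly two outputs would
  // reach the write to index 2 below, which is past the end.
  if (ctx.getNumOutputs() > 2) {  // has present_key and present_value outputs
    ctx.output_types[1] = "float";
    ctx.output_types[2] = "float";
    return true;
  }
  return false;
}
```

Under these assumptions, a context with two declared outputs skips the branch entirely, while a context with three outputs has both cache slots filled.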
…outputs omitted Cover DecoderAttention, MultiHeadAttention and DecoderMaskedMultiHeadAttention nodes declared with exactly two outputs (present_key kept, present_value omitted). Each test builds the node and asserts Graph::Resolve() shape inference completes cleanly. Tests are execution-provider independent and throw-free, so they run on the default CPU build and in no-exception (ORT_NO_EXCEPTIONS) builds. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Extend the optional-present-output regression suite with cases that declare all three outputs for DecoderAttention, MultiHeadAttention and DecoderMaskedMultiHeadAttention and assert the present_key/present_value branch still runs and infers their element types. Together with the two-output cases this pins the output-count guard to exactly three. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Use each op's actual output names (DecoderAttention: new_key_cache / new_value_cache; MultiHeadAttention: present_key / present_value), align the three guards to a consistent '// has <names> outputs' phrasing, and note that the two optional cache outputs are produced as a pair, so they are present only when the node declares more than two outputs. Comment-only; no logic change. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
The MultiHeadAttention and DecoderMaskedMultiHeadAttention two-output cases only passed a query input, so the present-output branch (which references output index 2) was never entered and the tests could not detect a regression there. Supply shaped past_key / past_value (and past_sequence_length for MHA, buffer sharing for DMMHA) so the branch is exercised while only two outputs are declared, matching the DecoderAttention case which already reached that path. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
tianleiwu left a comment
Correct, minimal, and well-tested fix. The present/KV-cache branch previously entered under getNumOutputs() > 1 but unconditionally wrote both output index 1 and index 2, so a node declaring exactly two outputs hit an out-of-bounds write to index 2. Raising the three guards to > 2 matches the both-or-neither convention already used by BaseGroupQueryAttention (>= 3) and PagedAttention (which keeps > 1 but immediately enforces != 3, so it is already safe and was correctly left untouched). The getOutputType bounds check is good defense-in-depth.
Tests are thorough: the three "omitted" cases pin the safe lower bound and the three "all-present" cases prove the > 2 branch still performs real KV-cache shape inference.
One housekeeping item: CI lintrunner / CLANGFORMAT already flagged formatting on the new test file — please run lintrunner -a before merge.
Non-blocking nit: the PR text says the new getOutputType check is "consistent with the checked access already used by getInputType", but getInputType uses vector::at() (throws std::out_of_range) while the new code uses fail_type_inference. The fail_type_inference variant is actually the nicer choice (clean type-inference error), so no change needed — just noting the mechanisms differ.
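The difference between the two mechanisms can be sketched as follows. This is an illustration only: MockContext and InferenceError are hypothetical names, and InferenceError merely stands in for whatever fail_type_inference reports in the real code.

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical error type standing in for a type-inference failure.
struct InferenceError : std::runtime_error {
  using std::runtime_error::runtime_error;
};

// Hypothetical context; not the ONNX Runtime InferenceContextImpl.
struct MockContext {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;

  // Checked access via vector::at(): a bad index throws std::out_of_range.
  const std::string& getInputType(std::size_t i) const { return inputs.at(i); }

  // Explicit bounds check: a bad index reports a domain-specific failure.
  std::string* getOutputType(std::size_t i) {
    if (i >= outputs.size()) {
      throw InferenceError("output index " + std::to_string(i) +
                           " is out of bounds");
    }
    return &outputs[i];
  }
};
```

Both accessors reject an out-of-range index, but only the explicit check surfaces the failure as an inference error with a message naming the offending index.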
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…ng-convention note Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Summary
DecoderAttentionTypeAndShapeInference and MultiHeadAttentionTypeAndShapeInference (the latter shared by MultiHeadAttention and DecoderMaskedMultiHeadAttention) guarded population of the optional present/cache outputs with getNumOutputs() > 1, but they populate output index 2. These outputs are produced as a pair (both-or-neither), so the guard should require more than 2 outputs before populating outputs 1 and 2.
This changes the three guards from > 1 to > 2, matching the established pattern already used in BaseGroupQueryAttentionTypeAndShapeInference (>= 3), PagedAttentionTypeAndShapeInference, and the EmbedLayerNormalization fix in #28176.
Defensive bounds check
Also adds a bounds check in InferenceContextImpl::getOutputType (core/graph/graph.cc), consistent with the checked access already used by getInputType, so an out-of-range output index yields a clean type-inference failure instead of undefined behavior.
Tests
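New onnxruntime/test/contrib_ops/attention_optional_outputs_shape_inference_test.cc with 6 cases: three two-output cases (present_key kept, present_value omitted) for DecoderAttention, MultiHeadAttention and DecoderMaskedMultiHeadAttention, and three all-outputs cases confirming that the > 2 branch still performs legitimate KV-cache shape inference.
The tests run through graph resolution and are execution-provider-independent. They are throw-free and safe under ORT_NO_EXCEPTIONS builds.
Reference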
Mirrors #28176.