[#15565][fix] AutoDeploy: Fix Super MTP IMA introduced by checkpointing replay#15622
[#15565][fix] AutoDeploy: Fix Super MTP IMA introduced by checkpointing replay#15622galagam wants to merge 3 commits into
Conversation
📝 WalkthroughWalkthroughReplay work-items are zero-initialized, populated through a host-prepare hook before attention forward, and the related AutoDeploy tests, test list, and waiver entries are updated. A new Mamba replay regression test constructs replay buffers and runs the cached SSM path. ChangesReplay metadata and AutoDeploy coverage
Sequence Diagram(s)sequenceDiagram
participant ADEngine as ADEngine._prepare_inputs
participant CachedSequenceInterface
participant RegisterHook as register_host_prepare_for_attention_forward
participant PrepareReplay as prepare_replay_metadata()
ADEngine->>CachedSequenceInterface: info.nest_sequences(...)
CachedSequenceInterface->>RegisterHook: register_host_prepare_for_attention_forward(...)
RegisterHook->>PrepareReplay: prepare_replay_metadata()
PrepareReplay-->>CachedSequenceInterface: _replay_work_items populated
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
🧹 Nitpick comments (1)
tests/unittest/auto_deploy/singlegpu/custom_ops/mamba/test_flashinfer_mamba_cached_op.py (1)
346-468: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low valueGood deterministic regression guard.
The poison-then-prep-then-synchronize structure reliably reproduces the capture-time IMA, and confining the poison to
_replay_work_itemskeeps the failure deterministic. Coverage for the replay path is sufficient.Optionally, you could strengthen the assertion to prove prep actually overwrote the buffer (not merely that the kernel didn't fault), e.g. assert the poison sentinel is gone from the consumed rows:
🧪 Optional stronger assertion
torch.cuda.synchronize() + # Prep must have overwritten the poisoned rows the kernel consumes. + assert (replay_work_items[:num_extend] != 0x7FFFFFFF).all() assert out.shape == hidden_states.shape assert torch.isfinite(out).all()🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unittest/auto_deploy/singlegpu/custom_ops/mamba/test_flashinfer_mamba_cached_op.py` around lines 346 - 468, The replay regression guard is good, but it only checks that the kernel does not fault; add an assertion in test_extend_replay_no_ima that the production prep path actually overwrites the poisoned _replay_work_items buffer before flashinfer_cached_ssm consumes it. Use the existing setup around interface.info.set_capture_batch(), replay_work_items, and replay_n_writes to verify the sentinel value is gone from the relevant rows after metadata prep and before synchronization.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In
`@tests/unittest/auto_deploy/singlegpu/custom_ops/mamba/test_flashinfer_mamba_cached_op.py`:
- Around line 346-468: The replay regression guard is good, but it only checks
that the kernel does not fault; add an assertion in test_extend_replay_no_ima
that the production prep path actually overwrites the poisoned
_replay_work_items buffer before flashinfer_cached_ssm consumes it. Use the
existing setup around interface.info.set_capture_batch(), replay_work_items, and
replay_n_writes to verify the sentinel value is gone from the relevant rows
after metadata prep and before synchronization.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: ae99a81d-7b62-467a-ab74-94a77a15c7b7
📒 Files selected for processing (6)
tensorrt_llm/_torch/auto_deploy/shim/ad_executor.pytensorrt_llm/_torch/auto_deploy/shim/interface.pytests/integration/defs/accuracy/test_llm_api_autodeploy.pytests/integration/test_lists/test-db/l0_dgx_b200.ymltests/integration/test_lists/waives.txttests/unittest/auto_deploy/singlegpu/custom_ops/mamba/test_flashinfer_mamba_cached_op.py
💤 Files with no reviewable changes (2)
- tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
- tests/integration/test_lists/waives.txt
7728d94 to
a93dede
Compare
|
/bot run |
|
PR_Github #55767 [ run ] triggered by Bot. Commit: |
|
PR_Github #55767 [ run ] completed with state
|
…tems buffer The SuperV3 MTP replay path allocated the shared `replay_work_items` metadata buffer with `torch.empty` in `CachedSequenceInterface._create_mamba_hybrid_cache_manager`. During CUDA graph capture/warmup the replay SSM kernel runs before `prepare_replay_metadata()` has populated this buffer, so in `persistent_main` mode the kernel reads uninitialized cache-slot indices and performs an out-of-bounds SSM state-cache access -- an illegal memory access that crashed executor init (e.g. test_mtp[nvfp4_ws4_80gb-trtllm]). Allocate the buffer with `torch.zeros` so every work item maps to slot 0 / pnat 0 (in-bounds) until real metadata is written per step. This mirrors the PyTorch backend, which already zero-inits `replay_work_items` (mamba2_metadata.py). Add a regression test that drives the real replay op (flashinfer_cached_ssm -> replay_selective_state_update) at the production SuperV3 shape (persistent_main mode) with the buffer allocated through the production path and left unprepared (capture scenario), asserting no IMA. The test deterministically faults without the fix and passes with it. Unwaive the SuperV3 test_mtp ws4 cases (bf16 / fp8 / nvfp4) under nvbugs 6316981 and 6336682, now that the capture-time IMA is fixed. Register the previously-uncovered nvfp4_ws8 trtllm MTP case on DGX-B200 (existing 8-GPU post-merge section). Raise the SuperV3 test_mtp acceptance-rate gate to a single 0.50 for all dtypes (observed >=53% on bf16/fp8/nvfp4 ws4). Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
… during cudagraph capture Make replay-metadata preparation symmetric across runtime and cudagraph capture by registering prepare_replay_metadata as a host-prepare hook, so it runs inside nest_sequences() on every path -- exactly like all other AD attention metadata. Previously prepare_replay_metadata was a side call only in ADEngine._prepare_inputs (runtime). The cudagraph capture path (SequenceInfo.set_capture_batch) synthesizes every other metadata buffer (slot_idx, KV page metadata, in-graph prepare ops) but never invoked it, so replay_work_items was the lone metadata buffer left unprepared at capture -- the root cause of the SuperV3 MTP capture-time IMA. The hook is registered only when the replay buffers were actually allocated (ssm_replay enabled), so non-replay models pay no per-forward cost and never touch the unallocated buffers. The explicit runtime call is removed (the hook now covers it). The zero-init of replay_work_items is kept as defense-in-depth (reframed in comment), mirroring the PyTorch backend which also keeps the zeros guard alongside its per-forward prepare. Verified on 4xGB200: with the buffer deliberately poisoned at allocation, the hook overwrites it at capture and every runtime step -- test_mtp[nvfp4_ws4_80gb-trtllm] passes (acceptance 54.5%, no IMA), confirming zero-init is now optional. Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
a93dede to
27f1dd7
Compare
|
/bot run --post-merge |
|
PR_Github #55806 [ run ] triggered by Bot. Commit: |
Summary by CodeRabbit
Bug Fixes
Tests
Description
Make replay-metadata preparation symmetric across runtime and cudagraph
capture by registering prepare_replay_metadata as a host-prepare hook, so it
runs inside nest_sequences() on every path -- exactly like all other AD
attention metadata.
Previously prepare_replay_metadata was a side call only in
ADEngine._prepare_inputs (runtime). The cudagraph capture path
(SequenceInfo.set_capture_batch) synthesizes every other metadata buffer
(slot_idx, KV page metadata, in-graph prepare ops) but never invoked it, so
replay_work_items was the lone metadata buffer left unprepared at capture --
the root cause of the SuperV3 MTP capture-time IMA.
The hook is registered only when the replay buffers were actually allocated
(ssm_replay enabled), so non-replay models pay no per-forward cost and never
touch the unallocated buffers. The explicit runtime call is removed (the hook
now covers it).
In addition, zero-init replay_work_items as an extra defense mechanism.
Increase required acceptance rate by the integration test. Empirically, we're seeing acceptance rates of 53%-54% in quantized variants as well. Tighten the threshold from 40% to 50% to avoid future regressions.
Test Coverage
tests/unittest/auto_deploy/singlegpu/custom_ops/mamba/test_flashinfer_mamba_cached_op.py::test_extend_replay_init_buffersPR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
If PR introduces API changes, an appropriate PR label is added - either
api-compatibleorapi-breaking. Forapi-breaking, includeBREAKINGin the PR title.Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.