[TRTLLM-12982][chore] relocate torch_multi_arange #15416
/bot run --disable-fail-fast
PR_Github #54587 [ run ] triggered by Bot. Commit:
📝 Walkthrough
Changes
torch_multi_arange relocation and multi_item_part_lens prepare() refactor
Sequence Diagram(s)
sequenceDiagram
participant encode as llm.encode()
participant model_engine as _prepare_encoder_inputs
participant cuda_runner as EncoderCUDAGraphRunner
participant metadata as FlashInferAttentionMetadata
participant plan as FlashInferAttentionMetadata.plan()
encode->>encode: compute position_ids via torch_multi_arange
encode->>model_engine: inputs (multi_item_part_lens, position_ids)
model_engine->>cuda_runner: maybe_get_cuda_graph(inputs)
cuda_runner-->>model_engine: (None, None) — fallback to eager
model_engine->>metadata: prepare(multi_item_part_lens=...)
metadata->>metadata: _process_multi_item_part_lens() → _multi_item_params
model_engine->>plan: plan(...)
plan->>plan: PlanParams(multi_item_params=_multi_item_params)
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 6
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tensorrt_llm/_torch/attention_backend/flashinfer.py`:
- Around line 744-755: The code accesses req_part_lens[0] and req_part_lens[1:]
without validating that each req_part_lens in multi_item_part_lens has the
required structure, which can cause IndexError or ValueError when constructing
tensors for malformed entries like empty lists or lists with only a prefix_len.
Before constructing the prefix_len_ptr and max_item_len_ptr tensors, add
validation to ensure each req_part_lens has at least two elements (one for
prefix_len and at least one for scored items), and raise an API-level ValueError
with a descriptive message if any request part list fails this validation.
- Around line 762-770: The zip() call combining multi_item_part_lens and
token_pos_in_items_raw_lens should pass strict=True so that a length mismatch
between the two iterables raises instead of being silently truncated; this also
resolves the B905 lint finding. Additionally, replace the list concatenation in
the innermost for loop (req_part_lens[1:] + [token_pos_in_items_len -
token_pos_in_items_raw_len]) with iterable unpacking to resolve the RUF005 lint
finding.
In `@tensorrt_llm/_torch/utils.py`:
- Around line 574-580: The variable repeats is initialized as an alias to the
ends tensor, and when starts is None, this alias is never broken before the
in-place multiplication operation repeats *= steps.sign() on line 579. This
mutates the caller's ends tensor. Fix this by using out-of-place arithmetic for
the repeat count calculation: instead of the in-place multiplication repeats *=
steps.sign(), use repeats = repeats * steps.sign() to create a new tensor and
avoid mutating the input.
- Around line 584-602: The prev_range_ends calculation using range_ends.roll(1)
doesn't account for empty ranges where repeats == 0. When a range is empty, its
nominal end value should not be used as the previous range end for the next
range; instead, the end of the last non-empty range should be carried forward.
Modify the logic that computes prev_range_ends to propagate the previous
non-empty range's end value through empty ranges, ensuring that jumps
calculations correctly reflect transitions only between actual non-empty ranges.
- Around line 541-557: Replace the assert statements in the function that
validates dtype, shape, and device compatibility between ends, steps, and starts
parameters with explicit ValueError exceptions that include descriptive error
messages. Additionally, add validation at the function entry to ensure that all
input tensors (starts, ends, and steps) are 1-D tensors, raising ValueError if
they are not, since the implementation later uses unsqueeze and torch.cat
operations that expect 1-D inputs.
In `@tensorrt_llm/llmapi/llm.py`:
- Around line 904-932: The code does not sufficiently validate the structure of
multi_item_part_lens before constructing starts_cuda and ends_cuda, allowing
malformed inputs like [prefix_len] with no item lengths to pass through and fail
later in FlashInfer. Add validation before the torch.tensor calls that construct
starts_cuda and ends_cuda to ensure that each multi_item_part_lens in
batch_multi_item_part_lens has length greater than 1 (meaning at least one item
length in addition to the prefix length) and that all length values are
non-negative. Reject the inputs early with a clear error message if these
conditions are not met.
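The early validation requested above could be sketched as follows. The function name `validate_batch_multi_item_part_lens` and the error wording are hypothetical; the checks themselves (more than one element per request, no negative lengths) are the ones the comment asks for, and where exactly this would be called in `llm.py` is not shown.

```python
def validate_batch_multi_item_part_lens(batch_multi_item_part_lens):
    """Hypothetical check run before any tensors are built from the lengths."""
    for idx, part_lens in enumerate(batch_multi_item_part_lens):
        # Need a prefix length plus at least one item length.
        if len(part_lens) < 2:
            raise ValueError(
                f"multi_item_part_lens[{idx}] must contain a prefix length "
                f"and at least one item length, got {list(part_lens)}"
            )
        if any(n < 0 for n in part_lens):
            raise ValueError(
                f"multi_item_part_lens[{idx}] must be non-negative, "
                f"got {list(part_lens)}"
            )


# A well-formed batch passes silently.
validate_batch_multi_item_part_lens([[4, 2, 3], [5, 1]])
```

Rejecting malformed entries at the API boundary keeps the failure close to the caller's input rather than surfacing later inside the attention backend.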
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 89ea82b0-3e34-43c9-bc70-8761dbd903f9
📒 Files selected for processing (18)
- .pre-commit-config.yaml
- legacy-files.txt
- pyproject.toml
- ruff-legacy.toml
- tensorrt_llm/_torch/attention_backend/flashinfer.py
- tensorrt_llm/_torch/attention_backend/interface.py
- tensorrt_llm/_torch/attention_backend/star_flashinfer.py
- tensorrt_llm/_torch/attention_backend/trtllm.py
- tensorrt_llm/_torch/attention_backend/vanilla.py
- tensorrt_llm/_torch/modules/attention.py
- tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
- tensorrt_llm/_torch/pyexecutor/sampler.py
- tensorrt_llm/_torch/pyexecutor/sampling_utils.py
- tensorrt_llm/_torch/utils.py
- tensorrt_llm/llmapi/llm.py
- tests/integration/test_lists/test-db/l0_a10.yml
- tests/unittest/_torch/test_torch_multi_arange.py
💤 Files with no reviewable changes (2)
- tensorrt_llm/_torch/pyexecutor/sampling_utils.py
- tensorrt_llm/_torch/modules/attention.py
PR_Github #54587 [ run ] completed with state
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
faad7dc to 6cf0c06
/bot run
PR_Github #55558 [ run ] triggered by Bot. Commit:
PR_Github #55558 [ run ] completed with state
Description
Follow-up on #14693 (comment).
Commit 800c7ee is from #15413, which is to be merged before this PR.
Test Coverage
Covered by existing tests
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
If PR introduces API changes, an appropriate PR label is added - either `api-compatible` or `api-breaking`. For `api-breaking`, include `BREAKING` in the PR title.
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment `/bot help`.
Summary by CodeRabbit
Improvements
Refactoring
Chores