Skip to content

Throwaway: conc-64 gsm8k eval for DEP8+MTP3 dispatch token bug#1659

Open
Oseltamivir wants to merge 17 commits into
mainfrom
dsr1-dep8-mtp3-conc64-eval
Open

Throwaway: conc-64 gsm8k eval for DEP8+MTP3 dispatch token bug#1659
Oseltamivir wants to merge 17 commits into
mainfrom
dsr1-dep8-mtp3-conc64-eval

Conversation

@Oseltamivir
Copy link
Copy Markdown
Collaborator

@Oseltamivir Oseltamivir commented Jun 3, 2026

Summary

  • Reproduce SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK < 256 silent corruption on DEP8+MTP3 disagg config
  • Narrows dsr1-fp4-mi355x-sglang-disagg-8k1k-mtp search-space to a single conc-64 DEP8+MTP3 entry
  • With max(CONC_LIST)=64, server computes dispatch_tokens = 64/8*4 = 32 (below 256 threshold → broken All2All kernel)
  • Expected result: ~0% gsm8k (silent corruption), confirming the perf pareto's artificially inflated acceptance lengths at low concurrency

Background

The previous eval (#1644) passed with 0.9431 gsm8k because the conc-list was [64,128,256,512,640]max=640dispatch_tokens=320 (≥256, correct kernel). The actual perf benchmarks run each concurrency point solo, so at conc=64 the value drops to 32 → broken kernel → garbage tokens + inflated AL.

Not for merge — throwaway validation only.


Note

Medium Risk
Mutates installed sglang inside the container at server start and changes dispatch sizing for DP+EP decode; mis-patching could leave silent corruption, though the script is idempotent and fails loud on structure miss.

Overview
Adds a root-cause fix for low-concurrency MoRI EP dispatch corruption on disaggregated SGLang (valid tokens / high acceptance length but gsm8k ≈ 0 when per-rank dispatch tokens fall below ~256).

Harness / launcher: server_sglang.sh now runs apply_moriep_dispatch_floor.py after env.sh to in-place floor num_max_dispatch_tokens_per_rank to ≥256 in the image’s vendor moriep.py (avoids a full-file overlay that breaks MoriEPDispatcher / expert_mask_gpu). With DP+EP decode, MORI_MAX_DISPATCH_TOKENS_DECODE is set to BENCH_MAX_CONC_VALUE (per-DP-rank capacity), not max_conc / dp_ranks. Comments note the old harness env clamp was removed in favor of the moriep fix.

CI config (marked throwaway): dsr1-fp4-mi355x-sglang-disagg-8k1k-mtp 8k1k MTP search space is collapsed to a single conc-64, 1×DEP8 + 1×DEP8, MTP3 point so max(CONC_LIST)=64 yields dispatch 32 and reproduces the bug before the patch.

Docs / changelog: patches/README.md documents the bug and in-place approach; perf-changelog.yaml records the change for dsr1-fp4-mi355x-sglang-disagg-8k1k-mtp.

Reviewed by Cursor Bugbot for commit 1fdb89f. Bugbot is set up for automated code reviews on this repo. Configure here.

@Oseltamivir Oseltamivir requested a review from a team June 3, 2026 18:43
@Oseltamivir Oseltamivir added the non-canary-full-sweep-enabled Run the full sweep without the canary gate (full search space, no trim) label Jun 3, 2026
@Oseltamivir Oseltamivir added non-canary-full-sweep-enabled Run the full sweep without the canary gate (full search space, no trim) and removed non-canary-full-sweep-enabled Run the full sweep without the canary gate (full search space, no trim) labels Jun 3, 2026
@Oseltamivir
Copy link
Copy Markdown
Collaborator Author

/sweep

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jun 3, 2026

@Oseltamivir Kicking off a sweep.

Run: https://github.com/SemiAnalysisAI/InferenceX/actions/runs/26905833527
Command: ``
Pinned ref: 93c050e
Approval: not required (trusted collaborator).

…en corruption

Narrow dsr1-fp4-mi355x-sglang-disagg-8k1k-mtp search-space to a single
DEP8+MTP3 conc-64 entry. With max(CONC_LIST)=64, the server computes
SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK=32, which is below the 256
threshold that selects the correct All2All kernel. Expected: ~0% gsm8k
(silent corruption from the low-latency All2All variant).

Not for merge — throwaway validation of the dispatch token bug.
@Oseltamivir Oseltamivir force-pushed the dsr1-dep8-mtp3-conc64-eval branch from 93c050e to 45f69f5 Compare June 3, 2026 18:48
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jun 3, 2026

Clamp MORI_MAX_DISPATCH_TOKENS_DECODE to minimum 256 when DP+EP are
both enabled, preventing SGLang's low-latency All2All kernel from being
selected. That kernel silently corrupts outputs at small buffer sizes.

Run A of A/B test: benchmark + eval WITH clamp on conc-64 DEP8+MTP3.
Comment thread benchmarks/multi_node/amd_utils/server_sglang.sh Outdated
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jun 3, 2026

Run B of A/B test: benchmark + eval WITHOUT dispatch token clamp.
MORI_MAX_DISPATCH_TOKENS_DECODE will be 32 (<256 threshold).
Expected: corrupted output, inflated AL, ~0% gsm8k.
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jun 3, 2026

… benchmark+eval at conc-64

Validates Option A: instead of clamping the env var, patch the installed
SGLang moriep.py at runtime to enforce a minimum of 64 (AMD CDNA3/4
warpSize) on num_max_dispatch_tokens_per_rank before it reaches the MoRI
kernel config. If gsm8k recovers (like the 256 clamp did), this confirms
warpSize is the minimum viable buffer floor and scopes the upstream fix.
Comment thread benchmarks/multi_node/amd_utils/server_sglang.sh Outdated
…val at conc-64

Root cause: the MoRI All2All dispatch kernel (EpDispatchInterNodeV1Kernel /
IntraNode) writes dispatched tokens into warpSize-aligned receive slots
(destTokId = flagSlotId*warpSize + laneId, laneId 0..63), so each warp-chunk
spans 64 (CDNA3/4 wavefront) token slots. The per-rank receive region is sized
to maxNumInpTokenPerRank, which the harness derives as max(CONC_LIST)/TP*(MTP+1).
At low concurrency this collapses below 64 (conc-64/TP8/MTP3 -> 64/8*4 = 32), so
a single chunk overruns the 32-slot region -> silent out-of-bounds writes ->
semantically corrupt output (decodes fine, gsm8k=0).

Confirmed from the conc-64 Run B decode log: INTER_KERNEL_SWITCH=16 with
DISPATCH_TOKENS=32 selects the NON-LL InterNodeV1 kernel (32 > 16), yet output
was still corrupt -> the bug is the buffer-size floor, not the LL-vs-non-LL
kernel choice.

Fix: clamp MORI_MAX_DISPATCH_TOKENS_DECODE to >= 64 after the MTP multiply.
Only raises the value at low conc; adds a few MB of staging buffer but no
compute, so real throughput is unchanged (the ~3% edge of the corrupt run was
an artifact of dropping work). 64 is the principled minimum vs the proven-but-
larger 256.
@Oseltamivir Oseltamivir added non-canary-full-sweep-enabled Run the full sweep without the canary gate (full search space, no trim) and removed non-canary-full-sweep-enabled Run the full sweep without the canary gate (full search space, no trim) labels Jun 3, 2026
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jun 3, 2026

…ent)

The conc-64 run with the warpSize floor (64) still scored gsm8k=0.00
(run 26919517564), disproving the one-wavefront hypothesis. The per-rank
dispatch buffer must hold the routing fan-in (a receiving rank takes tokens
from all worldSize peers), not just one warp-chunk. Empirically on MI355X:
dispatch=32 -> 0.00, dispatch=64 -> 0.00, dispatch>=256 -> 0.94. Clamp to the
proven 256. Throughput is unchanged; the corrupt run's ~3% edge was dropped
work, not real speed.
@Oseltamivir Oseltamivir removed the non-canary-full-sweep-enabled Run the full sweep without the canary gate (full search space, no trim) label Jun 3, 2026
…ch (vendor image diverges from upstream)

The full-file moriep.py overlay crashed the scheduler at init:
  AttributeError: 'MoriEPDispatcher' object has no attribute 'expert_mask_gpu'
  RuntimeError: Rank 0 scheduler died during initialization

Root cause of the failure: the lmsysorg/sglang-rocm:v0.5.12.post1 image ships a
DOWNSTREAM-patched moriep.py (class MoriEPDispatcher, extra attrs like
expert_mask_gpu) that diverges from the upstream v0.5.12.post1 tag. The overlay
was byte-identical to the upstream tag (md5 ac626f5459...), so bind-mounting it
reverted the AMD additions -> AttributeError. (The overlay DID mount and the floor
DID fire -- "[MORI floor] num_max_dispatch_tokens_per_rank=32 < 256; clamping" --
so the fix value is right; only the delivery was wrong.)

Fix: replace the overlay with patches/apply_moriep_dispatch_floor.py, a surgical
in-place patch run by server_sglang.sh inside the container. It edits the image's
own moriep.py, injecting `num_max_dispatch_tokens_per_rank = max(..., 256)` after
the dispatch-token env read (line-based, balanced-paren end detection, class-
agnostic, idempotent, fail-loud-but-non-fatal with a diagnostic dump of the
image's actual source). This preserves all vendor downstream code.

The fix value (256) is unchanged and proven (env-clamp run gsm8k 0.94).
Upstream: sgl-project/sglang#27194, ROCm/mori#356.
@Oseltamivir Oseltamivir added non-canary-full-sweep-enabled Run the full sweep without the canary gate (full search space, no trim) and removed non-canary-full-sweep-enabled Run the full sweep without the canary gate (full search space, no trim) labels Jun 4, 2026
The vendor image installs sglang as a namespace package where
__file__ is None.  os.path.dirname(None) throws TypeError, so the
patcher crashed and the floor was never applied — eval ran unpatched.

Fall through __file__ → __path__ → importlib.util.find_spec() to
locate the package directory robustly.
@Oseltamivir Oseltamivir added non-canary-full-sweep-enabled Run the full sweep without the canary gate (full search space, no trim) and removed non-canary-full-sweep-enabled Run the full sweep without the canary gate (full search space, no trim) labels Jun 4, 2026
spec = importlib.util.find_spec("sglang")
if spec and spec.submodule_search_locations:
pkg_dir = list(spec.submodule_search_locations)[0]
except Exception:
@Oseltamivir Oseltamivir removed the non-canary-full-sweep-enabled Run the full sweep without the canary gate (full search space, no trim) label Jun 4, 2026
# See patches/apply_moriep_dispatch_floor.py and patches/README.md.
echo "[server_sglang] applying MoRI dispatch-floor patch to installed sglang moriep.py"
python3 "$SGLANG_WS_PATH/patches/apply_moriep_dispatch_floor.py" \
|| echo "[server_sglang] WARN: moriep dispatch-floor patch returned non-zero"
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shell warn never on patch fail

Medium Severity

server_sglang.sh only prints its WARN when the patch script exits non-zero, but apply_moriep_dispatch_floor.py returns 0 on import failure, missing moriep.py, regex miss, and write errors, so failed applies can look successful in logs.

Additional Locations (1)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit a51ec94. Configure here.

@Oseltamivir Oseltamivir added the non-canary-full-sweep-enabled Run the full sweep without the canary gate (full search space, no trim) label Jun 4, 2026
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jun 4, 2026

@Oseltamivir Oseltamivir force-pushed the dsr1-dep8-mtp3-conc64-eval branch from 66d701b to 9b50d69 Compare June 4, 2026 04:10
@SemiAnalysisAI SemiAnalysisAI deleted a comment from github-actions Bot Jun 4, 2026
@SemiAnalysisAI SemiAnalysisAI deleted a comment from github-actions Bot Jun 4, 2026
@SemiAnalysisAI SemiAnalysisAI deleted a comment from github-actions Bot Jun 4, 2026
@SemiAnalysisAI SemiAnalysisAI deleted a comment from github-actions Bot Jun 4, 2026
@SemiAnalysisAI SemiAnalysisAI deleted a comment from github-actions Bot Jun 4, 2026
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jun 4, 2026

The vendor image installs sglang at /sgl-workspace/sglang/ but the
actual Python package is under python/sglang/ within that tree.
When __path__ returns the repo root, the patcher couldn't find
moriep.py. Add candidate paths and a bounded walk fallback.
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jun 4, 2026

f"{indent} self.num_max_dispatch_tokens_per_rank, {FLOOR}\n"
f"{indent})\n"
)
lines.insert(end + 1, floor_block)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unbalanced parens corrupt moriep insert

Low Severity

If parenthesis balancing never reaches depth <= 0 before EOF, end stays at start and the floor block is inserted on the line after the opening get_int_env_var(, which can splice Python inside a multi-line call and break moriep.py without reporting failure.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 59dd6c3. Configure here.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jun 4, 2026

The formula divided BENCH_MAX_CONC_VALUE by decode_dp_ranks, assuming
--max-running-requests is a global limit split across DP ranks. It is
actually per-rank: each of the 8 DP schedulers independently allows up
to BENCH_MAX_CONC_VALUE requests. At conc-64/TP8/MTP3 the old formula
produced dispatch=32 (64/8*4), but each rank can hold 64*4=256 tokens,
causing 8x buffer overflow in MoRI's intra-node dispatch kernel (the
only guard is an assert compiled out under -DNDEBUG) and silent
corruption (gsm8k=0).
Copy link
Copy Markdown

@cursor cursor Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cursor Bugbot has reviewed your changes and found 1 potential issue.

There are 3 total unresolved issues (including 2 from previous reviews).

Fix All in Cursor

❌ Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.

Reviewed by Cursor Bugbot for commit 1fdb89f. Configure here.

MORI_MAX_DISPATCH_TOKENS_DECODE=$((BENCH_MAX_CONC_VALUE / decode_dp_ranks))
# --max-running-requests is PER DP RANK (not global); each rank can hold
# up to BENCH_MAX_CONC_VALUE requests, so dispatch tokens = that capacity.
MORI_MAX_DISPATCH_TOKENS_DECODE=$BENCH_MAX_CONC_VALUE
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MoE input overscaled after dispatch

High Severity

After MORI_MAX_DISPATCH_TOKENS_DECODE is set to per-rank BENCH_MAX_CONC_VALUE, MORI_MOE_MAX_INPUT_TOKENS_DECODE still multiplies by decode_dp_ranks. Previously dividing dispatch by decode_dp_ranks canceled that factor; decode MoE max input is now ~8× too large for DEP8, which can mis-size buffers or memory.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 1fdb89f. Configure here.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jun 4, 2026

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

non-canary-full-sweep-enabled Run the full sweep without the canary gate (full search space, no trim)

Projects

Status: No status

Development

Successfully merging this pull request may close these issues.

1 participant