Skip to content

Commit fe9813b

Browse files
jgmelberclaude
andcommitted
Skip redundant decode FFN operators when swiglu_fused is active
When use_aie_ffn_swiglu_fused is True, the fused operator handles all of gate+up+silu+mul+down in one design. Skip creating the 3 separate decode GEMVs (aie_fc1_gemv, aie_fc2_gemv, aie_fc3_gemv) which would waste compilation time and device memory. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 563812f commit fe9813b

2 files changed

Lines changed: 9 additions & 2 deletions

File tree

iron/applications/llama_3.2_1b/src/block/feed_forward.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,12 @@ def __init__(
123123
cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False
124124
)
125125

126-
if self.cfg["use_kv_cache"] and self.cfg["use_aie_ffn_gemv"]:
126+
# Skip creating separate decode GEMVs when fused SwiGLU handles everything
127+
if (
128+
self.cfg["use_kv_cache"]
129+
and self.cfg["use_aie_ffn_gemv"]
130+
and not self.cfg.get("use_aie_ffn_swiglu_fused", False)
131+
):
127132
aie_gemv_config = {"num_aie_columns": 8, "is_mv": False}
128133
# FC1 and FC2: emb_dim -> hidden_dim
129134
self.aie_fc1_gemv = AIEGEMV(

iron/applications/llama_3.2_1b/src/model_with_json.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,9 @@ def format_option(name, value):
9494
"use_aie_ffn_mul",
9595
"use_aie_ffn_silu",
9696
}
97-
if not cfg.get("use_aie_ffn_swiglu_fused", False):
97+
if cfg.get("use_aie_ffn_swiglu_fused", False):
98+
dont_print |= {"use_aie_ffn_gemv"}
99+
else:
98100
dont_print |= {"use_aie_ffn_swiglu_fused"}
99101
else:
100102
dont_print |= {"use_aie_ffn_swiglu", "use_aie_ffn_swiglu_fused"}

0 commit comments

Comments
 (0)