vpto fix trowexpand op for col major tile #828
Codex Review (this comment is updated automatically by the review bot)
Summary: Review failed at stage Findings. No structured findings were generated because the review process failed early.
Code Review
This pull request introduces support for column-major layouts in trowexpanddiv, trowexpandmul, and trowexpandsub templates by utilizing an unaligned load pipeline (vldas + vldus) to prevent unaligned memory access errors. Additionally, the lowering logic in tilelang_dsl/lowering.py is updated to materialize tile pointers directly via pto.tile_buf_addr instead of using the broken memref-to-subview path. The code reviewer identified critical issues where pto.pset_b32(pto.PAT.ALL) was incorrectly hardcoded for 16-bit float operations in trowexpanddiv_f16 and for generic types in trowexpandmul and trowexpandsub, which would lead to compilation or masking errors. They suggested using pto.pset_b16 for 16-bit types and dynamically resolving the mask size based on the element type in generic templates.
    for row in range(0, valid_rows, 1):
        align_src1 = pto.vldas(src1[row, :])
        scalar_vec, _ = pto.vldus(src1[row, :], align_src1)
        broadcasted = pto.vdup(scalar_vec, pto.pset_b32(pto.PAT.ALL))
In template_trowexpanddiv_f16, the element type is f16 (16-bit). Therefore, the mask used for broadcasting should be pto.pset_b16(pto.PAT.ALL) instead of pto.pset_b32(pto.PAT.ALL). Using pset_b32 on f16 data will lead to incorrect masking or compilation errors.
Suggested change:
-       broadcasted = pto.vdup(scalar_vec, pto.pset_b32(pto.PAT.ALL))
+       broadcasted = pto.vdup(scalar_vec, pto.pset_b16(pto.PAT.ALL))
    # vldas+vldus once per row, broadcast across all col iterations
    align_src1 = pto.vldas(src1[row, :])
    scalar_vec, _ = pto.vldus(src1[row, :], align_src1)
    broadcasted = pto.vdup(scalar_vec, pto.pset_b32(pto.PAT.ALL))
Since template_trowexpandmul is generic and supports multiple data types (including f16, i16, i8), hardcoding pto.pset_b32(pto.PAT.ALL) will cause incorrect masking or compilation failures for non-32-bit types. We should dynamically select the mask based on the element size of dtype using pto.constexpr and pto.get_lanes(dtype).
Suggested change:
-   broadcasted = pto.vdup(scalar_vec, pto.pset_b32(pto.PAT.ALL))
+   if pto.constexpr(pto.get_lanes(dtype) == 8):
+       mask_all = pto.pset_b32(pto.PAT.ALL)
+   elif pto.constexpr(pto.get_lanes(dtype) == 16):
+       mask_all = pto.pset_b16(pto.PAT.ALL)
+   else:
+       mask_all = pto.pset_b8(pto.PAT.ALL)
+   broadcasted = pto.vdup(scalar_vec, mask_all)
    # vldas+vldus once per row, broadcast across all col iterations
    align_src1 = pto.vldas(src1[row, :])
    scalar_vec, _ = pto.vldus(src1[row, :], align_src1)
    broadcasted = pto.vdup(scalar_vec, pto.pset_b32(pto.PAT.ALL))
Since template_trowexpandsub is generic and supports multiple data types (including f16, i16, i8), hardcoding pto.pset_b32(pto.PAT.ALL) will cause incorrect masking or compilation failures for non-32-bit types. We should dynamically select the mask based on the element size of dtype using pto.constexpr and pto.get_lanes(dtype).
Suggested change:
-   broadcasted = pto.vdup(scalar_vec, pto.pset_b32(pto.PAT.ALL))
+   if pto.constexpr(pto.get_lanes(dtype) == 8):
+       mask_all = pto.pset_b32(pto.PAT.ALL)
+   elif pto.constexpr(pto.get_lanes(dtype) == 16):
+       mask_all = pto.pset_b16(pto.PAT.ALL)
+   else:
+       mask_all = pto.pset_b8(pto.PAT.ALL)
+   broadcasted = pto.vdup(scalar_vec, mask_all)
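The review comments all come down to one rule: the full-lane mask constructor should match the element width of the data being broadcast. The plain-Python sketch below only models that rule. The helper `mask_ctor_for` and the `ELEMENT_BITS` table are hypothetical names introduced for illustration, and the code does not call the real pto API; only the `pset_b8` / `pset_b16` / `pset_b32` names are taken from the comments above.

```python
# Illustrative model only: no real pto calls are made here.
# ELEMENT_BITS and mask_ctor_for are hypothetical helpers; the dtype list is
# limited to the types mentioned in this PR discussion.
ELEMENT_BITS = {"f32": 32, "f16": 16, "i16": 16, "i8": 8}


def mask_ctor_for(dtype: str) -> str:
    """Return the name of the mask constructor matching the element width."""
    bits = ELEMENT_BITS[dtype]  # raises KeyError for dtypes not modeled here
    return f"pset_b{bits}"


if __name__ == "__main__":
    for name in ("f32", "f16", "i16", "i8"):
        print(name, "->", mask_ctor_for(name))
```

Keying the choice on element width rather than on a per-template constant means a generic template would pick the matching mask for each dtype it is instantiated with.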
Fixes #826
/run a5
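Received
The page refreshes automatically, so you can check the current stage, the queue status, and the latest results directly.
A5 board test failed
Failed cases
A5 board test failure details: PR #828
test_tmov_row_major_1x16_control_a5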
test_tmov_col_major_16x1_align_a5
test_dynamic_valid_shape
test_barrier_sync
test_auto_sync_tail_hint
rmsnorm_incore_0
rar_optimization_test
nested_loop_confliect
matmul
decode_projection_incore_0
compensation_test
add_double_dynamic
rowexpandsub
rems
rem
rope_kv_cache
rmsnorm
qwen3_decode_incore_7
qwen3_decode_incore_6
qwen3_decode_incore_5
qwen3_decode_incore_4
qwen3_decode_incore_2
qwen3_decode_incore_1
qwen3_decode_incore_12
qwen3_decode_incore_11
qwen3_decode_incore_10
post_rmsnorm
vector_example_dag_kernel_mul
vector_example_dag_kernel_add_scalar
vector_example_dag_kernel_add
paged_attention_example_kernel_softmax_prepare
paged_attention_example_kernel_qk_matmul
paged_attention_example_kernel_pv_matmul
paged_attention_example_kernel_online_update
paged_attention_example_kernel_init_inplace
orchestration_example_kernel_mul
orchestration_example_kernel_add_scalar
orchestration_example_kernel_add
prelu
plan_memory_reuse_sequential
plan_memory_peak_exact_capacity
plan_memory_peak_8_overlapping
plan_memory_no_reuse_overlap
plan_memory_nested_loops
plan_memory_loop_no_reuse_outer_live
plan_memory_loop_in_if
plan_memory_if_yield
plan_memory_if_in_loop
plan_memory_fragmentation_two_holes
plan_memory_fragmentation_hole_fit
plan_memory_for_iter_args_yield
plan_memory_bind_tile_alias_liveness
partition_view_verify_valid
partition_view_verify_rank_mismatch_valid
partition5d_dynamic_a5
partition5d_a5
tensor_view_layout_dn
fillpad
sparse_attn_test_incore_7
decode_swa_test_incore_40
decode_hca_test_incore_54
decode_csa_test_incore_81
attention_swa_test_incore_40
attention_hca_test_incore_54
attention_csa_test_refresh_incore_81
tgather_root_binding
cmps
cmp
    constraints=[_constraint_trowexpanddiv_row_major],
)
def template_trowexpanddiv_f32(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile):
    """Template for pto.trowexpanddiv with f32 dtype and optional high-precision mode."""
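Problem summary
The TileLang templates for the trowexpand family of ops and the DSL lowering layer have two kinds of defects when src1 is a col major [M, 1] tile:
- Template layer: the three templates trowexpandmul / trowexpandsub / trowexpanddiv implement only the row_major path and load src1[row, :] with the aligned vlds. When src1 is col major, the tile slice address does not meet the 512B alignment requirement, which triggers A5 error 340.
- Lowering layer: in the IR generation path for the vldas/vldus instructions, the tile is first converted to a memref and the pointer is then obtained through subview + castptr. This approach is structurally flawed: vldas accepts only the !pto.ptr type, the memref intermediate representation introduces unnecessary type conversions, and the address computation for col major tiles is incorrect.
Changes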
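1. Template layer: add an unaligned access path for col major
Each of the three templates gains a col_major branch. The path is selected at compile time via pto.constexpr(src1.config.b_layout == pto.BLayout.COL_MAJOR):
- col major path: uses the unaligned load pipeline vldas + vldus in place of vlds. vldas/vldus are called only once per row; the result is broadcast under a full mask and reused for every column iteration of that row.
- row major path: keeps the existing aligned vlds access logic unchanged.
This change covers all precision and mode combinations: trowexpandmul (f16/f32), trowexpandsub (f16/f32), and trowexpanddiv (f16/f32 and the high_precision variants).
2. Lowering layer: convert the tile directly to a ptr and fix the address computation
In _AuthoringRenderer:
- Added the _materialize_tile_ptr() method, which converts a tile directly to the !pto.ptr<...> type via pto.tile_buf_addr and skips the memref intermediate representation. Results are cached in _tile_ptr_cache.
- Fixed the 2D tile address computation for vldas and vldus: the memref path through _materialize_rank2_tile_subview + _materialize_copy_buffer_ptr is no longer used. The element offset is now computed directly with arith.muli + arith.addi (offset = row * shape[1] + col), and the target address is then obtained via pto.addptr. The formula holds for both row major [M, N] and col major [M, 1].
The lowering logic of the two instructions is kept consistent: tile → ptr → offset computation → addptr → vldas/vldus.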
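As a rough check of the offset formula and of the alignment problem from the summary, the sketch below evaluates offset = row * shape[1] + col in plain Python and tests each row's start address against the 512B requirement. The helper names, the base address of 0, and the 4-byte element size are assumptions made for illustration; this is not the lowering code itself.

```python
from typing import List, Tuple

ALIGN_BYTES = 512  # alignment requirement quoted in the problem summary


def element_offset(row: int, col: int, shape: Tuple[int, int]) -> int:
    """Element offset as stated in the PR: row * shape[1] + col."""
    return row * shape[1] + col


def row_byte_address(base: int, row: int, shape: Tuple[int, int],
                     elem_bytes: int) -> int:
    """Byte address of the first element of `row` (hypothetical helper)."""
    return base + element_offset(row, 0, shape) * elem_bytes


def aligned_rows(base: int, shape: Tuple[int, int],
                 elem_bytes: int) -> List[int]:
    """Rows whose start address is a multiple of ALIGN_BYTES."""
    return [r for r in range(shape[0])
            if row_byte_address(base, r, shape, elem_bytes) % ALIGN_BYTES == 0]


if __name__ == "__main__":
    # [16, 1] tile of 4-byte elements: consecutive rows are 4 bytes apart.
    print(aligned_rows(0, (16, 1), 4))
    # [4, 128] tile of 4-byte elements: each row starts 512 bytes after the last.
    print(aligned_rows(0, (4, 128), 4))
```

Under these assumptions only row 0 of the [16, 1] tile starts on a 512B boundary, while every row of the [4, 128] tile does, which is consistent with the aligned vlds working on the row major path and the [M, 1] case needing vldas + vldus.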
3. Files involved
- lib/TileOps/trowexpanddiv_template.py
- lib/TileOps/trowexpandmul_template.py
- lib/TileOps/trowexpandsub_template.py
- tilelang-dsl/python/tilelang_dsl/lowering.py: _materialize_tile_ptr(); fixes the 2D tile address computation for vldas/vldus
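For reviewers who want a reference for the load-once-per-row scheme described in section 1, the following plain-Python model may help. It assumes the row-expand semantics dst[r][c] = src0[r][c] * src1[r][0], which is inferred from the description rather than confirmed by it. The function name `row_expand_mul` and the `lanes` parameter are hypothetical, and Python lists stand in for vector registers.

```python
def row_expand_mul(src0, src1_col, lanes=4):
    """Model of a row-expand multiply with src1 shaped [M, 1].

    The per-row scalar is read once, broadcast to a `lanes`-wide list, and
    reused for every column chunk of that row.
    """
    dst = []
    for row, values in enumerate(src0):
        scalar = src1_col[row][0]        # one load per row
        broadcasted = [scalar] * lanes   # stands in for the broadcast vector
        out_row = []
        for start in range(0, len(values), lanes):
            chunk = values[start:start + lanes]
            # zip stops at the shorter input, so a partial tail chunk is fine
            out_row.extend(a * b for a, b in zip(chunk, broadcasted))
        dst.append(out_row)
    return dst


if __name__ == "__main__":
    print(row_expand_mul([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]], [[2], [3]]))
```

Hoisting the scalar load and broadcast out of the column loop is the point of the scheme: the unaligned load cost is paid once per row instead of once per column chunk.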