Skip to content

Commit 2ed0846

Browse files
committed
[Mppa] Delay vectorization to the lowering pipeline
Rely on the kvxuks-catch pass to catch micro-kernels, replacing the transform-dialect-based vectorization.
1 parent 11ec57c commit 2ed0846

2 files changed

Lines changed: 257 additions & 1 deletion

File tree

src/xtc/backends/mlir/MlirTarget/MlirMppaTarget.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
from ..MlirProgram import RawMlirProgram
2929

3030
from mlir.passmanager import PassManager
31+
from mlir.ir import OpResult
32+
from mlir.dialects import transform
3133

3234
__all__ = ["MlirMppaTarget"]
3335

@@ -159,6 +161,14 @@ def create_module(
159161
name, payload_name, file_name, file_type, mppa_config, graph, **kwargs
160162
)
161163

164+
@override
165+
def has_custom_vectorize(self) -> bool:
166+
return True
167+
168+
@override
169+
def apply_custom_vectorize(self, handle: OpResult) -> None:
170+
transform.AnnotateOp(handle, "xtc.request_vectorization")
171+
162172
def dump_ir(self, mlir_program: RawMlirProgram, title: str):
163173
print(f"// -----// {title} //----- //", file=sys.stderr)
164174
print(str(mlir_program.mlir_module), file=sys.stderr)
@@ -303,7 +313,9 @@ def _lowering_pipeline(self) -> str:
303313
passes.append("canonicalize")
304314
passes.append("func.func(kvxpe-scf-forall-distribute{num-pes=1})")
305315
passes.append("func.func(kvxpe-launch)")
306-
passes.append("func.func(kvxuks-catch)")
316+
passes.append(
317+
"func.func(kvxuks-catch{request-attribute=xtc.request_vectorization})"
318+
)
307319
passes.append("canonicalize")
308320
passes.append("convert-linalg-to-loops")
309321
passes.append("func.func(lower-affine)")
Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
# RUN: python %s 2>&1 | filecheck %s
2+
# REQUIRES: module_mlir_mppa
3+
# REQUIRES: mlir-target=mppa
4+
5+
import xtc.graphs.xtc.op as O
6+
from xtc.backends.mlir.MlirGraphBackend import MlirGraphBackend as Backend
7+
8+
from xtc.runtimes.accelerator.mppa import MppaDevice
9+
10+
I, J, K, dtype = 16, 16, 64, "float32"
11+
a = O.tensor((I, K), dtype, name="A")
12+
b = O.tensor((K, J), dtype, name="B")
13+
14+
with O.graph(name="matmul") as gb:
15+
O.matmul(a, b, name="C")
16+
17+
graph = gb.graph
18+
print(graph)
19+
20+
impl = Backend(graph)
21+
22+
sch = impl.get_scheduler()
23+
sch.define_memory_mesh(axes={"mx": 1, "my": 1})
24+
sch.define_processor_mesh(axes={"px": 1, "py": 1, "psx": 1, "psy": 1})
25+
sch.tile("i", {"i1": 8})
26+
sch.tile("j", {"j1": 8})
27+
sch.interchange(["i", "j", "i1", "j1", "k"])
28+
sch.vectorize(["i1", "j1", "k"])
29+
#sch.pack_at("i1", 1)
30+
sched = sch.schedule()
31+
32+
# Create mppa device
33+
mppa = MppaDevice()
34+
35+
comp = impl.get_compiler(
36+
target=mppa,
37+
shared_lib=True,
38+
dump_file="matmul_mlir_mppa",
39+
print_source_ir=True,
40+
print_transformed_ir=True,
41+
print_lowered_ir=True,
42+
)
43+
module = comp.compile(sched)
44+
executor = module.get_executor(validate=True)
45+
res = executor.execute()
46+
print(f"CODE: {res}")
47+
# CHECK: // -----// IR Dump Before transform //----- //
48+
# CHECK-NEXT: module attributes {transform.with_named_sequence} {
49+
# CHECK-NEXT: func.func @matmul(%arg0: memref<16x64xf32> {llvm.noalias}, %arg1: memref<64x16xf32> {llvm.noalias}, %arg2: memref<16x16xf32> {llvm.noalias}) {
50+
# CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32
51+
# CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst : f32) outs(%arg2 : memref<16x16xf32>)
52+
# CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%arg0, %arg1 : memref<16x64xf32>, memref<64x16xf32>) outs(%arg2 : memref<16x16xf32>)
53+
# CHECK-NEXT: return
54+
# CHECK-NEXT: }
55+
# CHECK-NEXT: transform.named_sequence @_vecto(%arg0: !transform.any_op {transform.consumed}) {
56+
# CHECK-NEXT: transform.structured.vectorize %arg0 : !transform.any_op
57+
# CHECK-NEXT: transform.yield
58+
# CHECK-NEXT: }
59+
# CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
60+
# CHECK-NEXT: %0 = transform.sdist.create_memory_mesh %arg0 "memory_mesh" = <["mx"=1, "my"=1]> : !transform.any_op -> !transform.any_op
61+
# CHECK-NEXT: %1 = transform.sdist.create_processor_mesh %arg0 "processor_mesh" = <["px"=1, "py"=1, "psx"=1, "psy"=1]> from "memory_mesh" : !transform.any_op -> !transform.any_op
62+
# CHECK-NEXT: %2 = transform.structured.match attributes {__xtc_id_C_0_} in %arg0 : (!transform.any_op) -> !transform.any_op
63+
# CHECK-NEXT: %tiled_linalg_op, %loops = transform.structured.tile_using_for %2 tile_sizes [1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
64+
# CHECK-NEXT: transform.annotate %loops "./i" : !transform.any_op
65+
# CHECK-NEXT: %tiled_linalg_op_0, %loops_1 = transform.structured.tile_using_for %tiled_linalg_op tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
66+
# CHECK-NEXT: transform.annotate %loops_1 "./j" : !transform.any_op
67+
# CHECK-NEXT: %3 = transform.structured.match attributes {__xtc_id_C_} in %arg0 : (!transform.any_op) -> !transform.any_op
68+
# CHECK-NEXT: %tiled_linalg_op_2, %loops_3 = transform.structured.tile_using_for %3 tile_sizes [8, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
69+
# CHECK-NEXT: transform.annotate %loops_3 "./i" : !transform.any_op
70+
# CHECK-NEXT: %tiled_linalg_op_4, %loops_5 = transform.structured.tile_using_for %tiled_linalg_op_2 tile_sizes [0, 8, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
71+
# CHECK-NEXT: transform.annotate %loops_5 "./j" : !transform.any_op
72+
# CHECK-NEXT: transform.annotate %tiled_linalg_op_4 "xtc.request_vectorization" : !transform.any_op
73+
# CHECK-NEXT: %4 = transform.get_parent_op %loops_3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
74+
# CHECK-NEXT: transform.apply_patterns to %4 {
75+
# CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract
76+
# CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns
77+
# CHECK-NEXT: } : !transform.any_op
78+
# CHECK-NEXT: transform.apply_patterns to %4 {
79+
# CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct
80+
# CHECK-NEXT: transform.apply_patterns.vector.lower_contraction
81+
# CHECK-NEXT: } : !transform.any_op
82+
# CHECK-NEXT: transform.yield
83+
# CHECK-NEXT: }
84+
# CHECK-NEXT: }
85+
# CHECK-NEXT:
86+
# CHECK-NEXT: // -----// IR Dump After transform //----- //
87+
# CHECK-NEXT: module attributes {transform.with_named_sequence} {
88+
# CHECK-NEXT: sdist.processor_mesh @processor_mesh from @memory_mesh = <["px"=1, "py"=1, "psx"=1, "psy"=1]>
89+
# CHECK-NEXT: sdist.memory_mesh @memory_mesh = <["mx"=1, "my"=1]>
90+
# CHECK-NEXT: func.func @matmul(%arg0: memref<16x64xf32> {llvm.noalias}, %arg1: memref<64x16xf32> {llvm.noalias}, %arg2: memref<16x16xf32> {llvm.noalias}) {
91+
# CHECK-NEXT: %c8 = arith.constant 8 : index
92+
# CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32
93+
# CHECK-NEXT: %c0 = arith.constant 0 : index
94+
# CHECK-NEXT: %c16 = arith.constant 16 : index
95+
# CHECK-NEXT: %c1 = arith.constant 1 : index
96+
# CHECK-NEXT: scf.for %arg3 = %c0 to %c16 step %c1 {
97+
# CHECK-NEXT: %subview = memref.subview %arg2[%arg3, 0] [1, 16] [1, 1] : memref<16x16xf32> to memref<1x16xf32, strided<[16, 1], offset: ?>>
98+
# CHECK-NEXT: scf.for %arg4 = %c0 to %c16 step %c1 {
99+
# CHECK-NEXT: %subview_0 = memref.subview %subview[0, %arg4] [1, 1] [1, 1] : memref<1x16xf32, strided<[16, 1], offset: ?>> to memref<1x1xf32, strided<[16, 1], offset: ?>>
100+
# CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst : f32) outs(%subview_0 : memref<1x1xf32, strided<[16, 1], offset: ?>>)
101+
# CHECK-NEXT: } {"./j"}
102+
# CHECK-NEXT: } {"./i"}
103+
# CHECK-NEXT: scf.for %arg3 = %c0 to %c16 step %c8 {
104+
# CHECK-NEXT: %subview = memref.subview %arg0[%arg3, 0] [8, 64] [1, 1] : memref<16x64xf32> to memref<8x64xf32, strided<[64, 1], offset: ?>>
105+
# CHECK-NEXT: %subview_0 = memref.subview %arg1[0, 0] [64, 16] [1, 1] : memref<64x16xf32> to memref<64x16xf32, strided<[16, 1]>>
106+
# CHECK-NEXT: %subview_1 = memref.subview %arg2[%arg3, 0] [8, 16] [1, 1] : memref<16x16xf32> to memref<8x16xf32, strided<[16, 1], offset: ?>>
107+
# CHECK-NEXT: scf.for %arg4 = %c0 to %c16 step %c8 {
108+
# CHECK-NEXT: %subview_2 = memref.subview %subview_0[0, %arg4] [64, 8] [1, 1] : memref<64x16xf32, strided<[16, 1]>> to memref<64x8xf32, strided<[16, 1], offset: ?>>
109+
# CHECK-NEXT: %subview_3 = memref.subview %subview_1[0, %arg4] [8, 8] [1, 1] : memref<8x16xf32, strided<[16, 1], offset: ?>> to memref<8x8xf32, strided<[16, 1], offset: ?>>
110+
# CHECK-NEXT: linalg.matmul {__xtc_id_C_, xtc.request_vectorization} ins(%subview, %subview_2 : memref<8x64xf32, strided<[64, 1], offset: ?>>, memref<64x8xf32, strided<[16, 1], offset: ?>>) outs(%subview_3 : memref<8x8xf32, strided<[16, 1], offset: ?>>)
111+
# CHECK-NEXT: } {"./j"}
112+
# CHECK-NEXT: } {"./i"}
113+
# CHECK-NEXT: return
114+
# CHECK-NEXT: }
115+
# CHECK-NEXT: }
116+
# CHECK-NEXT:
117+
# CHECK-NEXT: // -----// IR Dump After MLIR Opt //----- //
118+
# CHECK-NEXT: #map = affine_map<(d0, d1, d2) -> (d0, d2)>
119+
# CHECK-NEXT: #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
120+
# CHECK-NEXT: #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
121+
# CHECK-NEXT: "builtin.module"() ({
122+
# CHECK-NEXT: "func.func"() <{arg_attrs = [{llvm.noalias}, {llvm.noalias}, {llvm.noalias}], function_type = (memref<16x64xf32>, memref<64x16xf32>, memref<16x16xf32>) -> (), sym_name = "matmul"}> ({
123+
# CHECK-NEXT: ^bb0(%arg0: memref<16x64xf32>, %arg1: memref<64x16xf32>, %arg2: memref<16x16xf32>):
124+
# CHECK-NEXT: "mppa.launch"() ({
125+
# CHECK-NEXT: "kvxcluster.launch"() ({
126+
# CHECK-NEXT: ^bb0(%arg3: index):
127+
# CHECK-NEXT: %0 = "arith.constant"() <{value = 1 : index}> : () -> index
128+
# CHECK-NEXT: %1 = "arith.constant"() <{value = 16 : index}> : () -> index
129+
# CHECK-NEXT: %2 = "arith.constant"() <{value = 0 : index}> : () -> index
130+
# CHECK-NEXT: %3 = "arith.constant"() <{value = 0.000000e+00 : f32}> : () -> f32
131+
# CHECK-NEXT: %4 = "arith.constant"() <{value = 8 : index}> : () -> index
132+
# CHECK-NEXT: "scf.for"(%2, %1, %0) ({
133+
# CHECK-NEXT: ^bb0(%arg9: index):
134+
# CHECK-NEXT: %11 = "memref.subview"(%arg2, %arg9) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: 1, 16>, static_strides = array<i64: 1, 1>}> : (memref<16x16xf32>, index) -> memref<1x16xf32, strided<[16, 1], offset: ?>>
135+
# CHECK-NEXT: "scf.for"(%2, %1, %0) ({
136+
# CHECK-NEXT: ^bb0(%arg10: index):
137+
# CHECK-NEXT: %12 = "memref.subview"(%11, %arg10) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: 0, -9223372036854775808>, static_sizes = array<i64: 1, 1>, static_strides = array<i64: 1, 1>}> : (memref<1x16xf32, strided<[16, 1], offset: ?>>, index) -> memref<1x1xf32, strided<[16, 1], offset: ?>>
138+
# CHECK-NEXT: "linalg.fill"(%3, %12) <{operandSegmentSizes = array<i32: 1, 1>}> ({
139+
# CHECK-NEXT: ^bb0(%arg11: f32, %arg12: f32):
140+
# CHECK-NEXT: "linalg.yield"(%arg11) : (f32) -> ()
141+
# CHECK-NEXT: }) {__xtc_id_C_0_} : (f32, memref<1x1xf32, strided<[16, 1], offset: ?>>) -> ()
142+
# CHECK-NEXT: "scf.yield"() : () -> ()
143+
# CHECK-NEXT: }) {"./j"} : (index, index, index) -> ()
144+
# CHECK-NEXT: "scf.yield"() : () -> ()
145+
# CHECK-NEXT: }) {"./i"} : (index, index, index) -> ()
146+
# CHECK-NEXT: "scf.for"(%2, %1, %4) ({
147+
# CHECK-NEXT: ^bb0(%arg4: index):
148+
# CHECK-NEXT: %5 = "memref.subview"(%arg0, %arg4) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: 8, 64>, static_strides = array<i64: 1, 1>}> : (memref<16x64xf32>, index) -> memref<8x64xf32, strided<[64, 1], offset: ?>>
149+
# CHECK-NEXT: %6 = "memref.subview"(%arg2, %arg4) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: 8, 16>, static_strides = array<i64: 1, 1>}> : (memref<16x16xf32>, index) -> memref<8x16xf32, strided<[16, 1], offset: ?>>
150+
# CHECK-NEXT: "scf.for"(%2, %1, %4) ({
151+
# CHECK-NEXT: ^bb0(%arg5: index):
152+
# CHECK-NEXT: %7 = "memref.subview"(%arg1, %arg5) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: 0, -9223372036854775808>, static_sizes = array<i64: 64, 8>, static_strides = array<i64: 1, 1>}> : (memref<64x16xf32>, index) -> memref<64x8xf32, strided<[16, 1], offset: ?>>
153+
# CHECK-NEXT: %8 = "memref.subview"(%6, %arg5) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: 0, -9223372036854775808>, static_sizes = array<i64: 8, 8>, static_strides = array<i64: 1, 1>}> : (memref<8x16xf32, strided<[16, 1], offset: ?>>, index) -> memref<8x8xf32, strided<[16, 1], offset: ?>>
154+
# CHECK-NEXT: "linalg.matmul"(%5, %7, %8) <{indexing_maps = [#map, #map1, #map2], operandSegmentSizes = array<i32: 2, 1>}> ({
155+
# CHECK-NEXT: ^bb0(%arg6: f32, %arg7: f32, %arg8: f32):
156+
# CHECK-NEXT: %9 = "arith.mulf"(%arg6, %arg7) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
157+
# CHECK-NEXT: %10 = "arith.addf"(%arg8, %9) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
158+
# CHECK-NEXT: "linalg.yield"(%10) : (f32) -> ()
159+
# CHECK-NEXT: }) {__xtc_id_C_, xtc.request_vectorization} : (memref<8x64xf32, strided<[64, 1], offset: ?>>, memref<64x8xf32, strided<[16, 1], offset: ?>>, memref<8x8xf32, strided<[16, 1], offset: ?>>) -> ()
160+
# CHECK-NEXT: "scf.yield"() : () -> ()
161+
# CHECK-NEXT: }) {"./j"} : (index, index, index) -> ()
162+
# CHECK-NEXT: "scf.yield"() : () -> ()
163+
# CHECK-NEXT: }) {"./i"} : (index, index, index) -> ()
164+
# CHECK-NEXT: "kvxcluster.launch_terminator"() : () -> ()
165+
# CHECK-NEXT: }) {mask = 1 : i32, nclusters = 1 : i32} : () -> ()
166+
# CHECK-NEXT: "kvxcluster.await_all"() : () -> ()
167+
# CHECK-NEXT: "mppa.yield"() : () -> ()
168+
# CHECK-NEXT: }) {device = 1 : i32} : () -> ()
169+
# CHECK-NEXT: "func.return"() : () -> ()
170+
# CHECK-NEXT: }) : () -> ()
171+
# CHECK-NEXT: }) {transform.with_named_sequence} : () -> ()
172+
# CHECK-NEXT:
173+
# CHECK-NEXT: // -----// IR Dump After MPPA Opt //----- //
174+
# CHECK-NEXT: module attributes {transform.with_named_sequence} {
175+
# CHECK-NEXT: func.func @kvxcluster_launch_0_kernel_cc_0(%arg0: memref<16x16xf32, 2>, %arg1: memref<16x64xf32, 2>, %arg2: memref<64x16xf32, 2>) attributes {kernel_for_cluster_id = 0 : index} {
176+
# CHECK-NEXT: %c64 = arith.constant 64 : index
177+
# CHECK-NEXT: %c8 = arith.constant 8 : index
178+
# CHECK-NEXT: %c1 = arith.constant 1 : index
179+
# CHECK-NEXT: %c0 = arith.constant 0 : index
180+
# CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32
181+
# CHECK-NEXT: %c16 = arith.constant 16 : index
182+
# CHECK-NEXT: scf.for %arg3 = %c0 to %c16 step %c1 {
183+
# CHECK-NEXT: scf.for %arg4 = %c0 to %c16 step %c1 {
184+
# CHECK-NEXT: %0 = arith.muli %arg3, %c16 overflow<nsw> : index
185+
# CHECK-NEXT: %1 = arith.addi %0, %arg4 : index
186+
# CHECK-NEXT: %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%1], sizes: [1, 1], strides: [16, 1] : memref<16x16xf32, 2> to memref<1x1xf32, strided<[16, 1], offset: ?>, 2>
187+
# CHECK-NEXT: kvxpe.launch %arg5 (npes=1) {
188+
# CHECK-NEXT: memref.store %cst, %reinterpret_cast[%c0, %c0] : memref<1x1xf32, strided<[16, 1], offset: ?>, 2>
189+
# CHECK-NEXT: kvxpe.launch_terminator
190+
# CHECK-NEXT: }
191+
# CHECK-NEXT: kvxpe.await_all
192+
# CHECK-NEXT: } {"./j"}
193+
# CHECK-NEXT: } {"./i"}
194+
# CHECK-NEXT: scf.for %arg3 = %c0 to %c16 step %c8 {
195+
# CHECK-NEXT: %0 = arith.muli %arg3, %c64 overflow<nsw> : index
196+
# CHECK-NEXT: %reinterpret_cast = memref.reinterpret_cast %arg1 to offset: [%0], sizes: [8, 64], strides: [64, 1] : memref<16x64xf32, 2> to memref<8x64xf32, strided<[64, 1], offset: ?>, 2>
197+
# CHECK-NEXT: scf.for %arg4 = %c0 to %c16 step %c8 {
198+
# CHECK-NEXT: %reinterpret_cast_0 = memref.reinterpret_cast %arg2 to offset: [%arg4], sizes: [64, 8], strides: [16, 1] : memref<64x16xf32, 2> to memref<64x8xf32, strided<[16, 1], offset: ?>, 2>
199+
# CHECK-NEXT: %1 = arith.muli %arg3, %c16 overflow<nsw> : index
200+
# CHECK-NEXT: %2 = arith.addi %1, %arg4 : index
201+
# CHECK-NEXT: %reinterpret_cast_1 = memref.reinterpret_cast %arg0 to offset: [%2], sizes: [8, 8], strides: [16, 1] : memref<16x16xf32, 2> to memref<8x8xf32, strided<[16, 1], offset: ?>, 2>
202+
# CHECK-NEXT: kvxpe.launch %arg5 (npes=1) {
203+
# CHECK-NEXT: kvxuks.mma_8x8xf32 %reinterpret_cast, %reinterpret_cast_0 -> %reinterpret_cast_1 : memref<8x64xf32, strided<[64, 1], offset: ?>, 2>, memref<64x8xf32, strided<[16, 1], offset: ?>, 2>, memref<8x8xf32, strided<[16, 1], offset: ?>, 2>
204+
# CHECK-NEXT: kvxpe.launch_terminator
205+
# CHECK-NEXT: }
206+
# CHECK-NEXT: kvxpe.await_all
207+
# CHECK-NEXT: } {"./j"}
208+
# CHECK-NEXT: } {"./i"}
209+
# CHECK-NEXT: return
210+
# CHECK-NEXT: }
211+
# CHECK-NEXT: func.func @matmul(%arg0: memref<16x64xf32> {llvm.noalias}, %arg1: memref<64x16xf32> {llvm.noalias}, %arg2: memref<16x16xf32> {llvm.noalias}) {
212+
# CHECK-NEXT: mppa.launch(k300) {
213+
# CHECK-NEXT: %0 = mppa.alloc : memref<16x16xf32, 2>
214+
# CHECK-NEXT: mppa.copy %arg2, %0 : memref<16x16xf32> to memref<16x16xf32, 2>
215+
# CHECK-NEXT: %1 = mppa.alloc : memref<16x64xf32, 2>
216+
# CHECK-NEXT: mppa.copy %arg0, %1 : memref<16x64xf32> to memref<16x64xf32, 2>
217+
# CHECK-NEXT: %2 = mppa.alloc : memref<64x16xf32, 2>
218+
# CHECK-NEXT: mppa.copy %arg1, %2 : memref<64x16xf32> to memref<64x16xf32, 2>
219+
# CHECK-NEXT: kvxcluster.launch (nclusters=1, mask=1)
220+
# CHECK-NEXT: 0 -> @kvxcluster_launch_0_kernel_cc_0
221+
# CHECK-NEXT: with (%0, %1, %2) : memref<16x16xf32, 2>, memref<16x64xf32, 2>, memref<64x16xf32, 2>
222+
# CHECK-NEXT: kvxcluster.await_all
223+
# CHECK-NEXT: mppa.dealloc %2 : memref<64x16xf32, 2>
224+
# CHECK-NEXT: mppa.dealloc %1 : memref<16x64xf32, 2>
225+
# CHECK-NEXT: mppa.copy %0, %arg2 : memref<16x16xf32, 2> to memref<16x16xf32>
226+
# CHECK-NEXT: mppa.dealloc %0 : memref<16x16xf32, 2>
227+
# CHECK-NEXT: kvxcluster.await_all
228+
# CHECK-NEXT: }
229+
# CHECK-NEXT: return
230+
# CHECK-NEXT: }
231+
# CHECK-NEXT: }
232+
# CHECK-NEXT:
233+
# CHECK-NEXT:
234+
# CHECK-NEXT: graph:
235+
# CHECK-NEXT: name: matmul
236+
# CHECK-NEXT: inputs:
237+
# CHECK-NEXT: - %0 : 16x64xfloat32
238+
# CHECK-NEXT: - %1 : 64x16xfloat32
239+
# CHECK-NEXT: outputs:
240+
# CHECK-NEXT: - %2 : 16x16xfloat32
241+
# CHECK-NEXT: nodes:
242+
# CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [16x64xfloat32, 64x16xfloat32] -> [16x16xfloat32]
243+
# CHECK-NEXT:
244+
# CHECK-NEXT: CODE: 0

0 commit comments

Comments
 (0)