# RUN: python %s 2>&1 | filecheck %s
# REQUIRES: module_mlir_mppa
# REQUIRES: mlir-target=mppa
#
# Lit test: build a 16x16x64 float32 matmul graph, schedule it with a
# 1x1 memory/processor mesh plus 8x8 tiling and vectorization, compile it
# for the MPPA accelerator, execute with validation, and FileCheck the
# three IR dumps (source, transformed, lowered) plus the printed graph
# and the final exit code.

import xtc.graphs.xtc.op as O
from xtc.backends.mlir.MlirGraphBackend import MlirGraphBackend as Backend

from xtc.runtimes.accelerator.mppa import MppaDevice

# Problem sizes: C[I, J] = A[I, K] @ B[K, J].
I, J, K, dtype = 16, 16, 64, "float32"
a = O.tensor((I, K), dtype, name="A")
b = O.tensor((K, J), dtype, name="B")

with O.graph(name="matmul") as gb:
    O.matmul(a, b, name="C")

graph = gb.graph
print(graph)  # printed last by the dump pipeline; matched at the end of CHECKs

impl = Backend(graph)

# Schedule: unit meshes, 8x8 tiles on i/j, interchange, vectorize inner loops.
sch = impl.get_scheduler()
sch.define_memory_mesh(axes={"mx": 1, "my": 1})
sch.define_processor_mesh(axes={"px": 1, "py": 1, "psx": 1, "psy": 1})
sch.tile("i", {"i1": 8})
sch.tile("j", {"j1": 8})
sch.interchange(["i", "j", "i1", "j1", "k"])
sch.vectorize(["i1", "j1", "k"])
#sch.pack_at("i1", 1)
sched = sch.schedule()

# Create mppa device
mppa = MppaDevice()

# Compiler dumps all three IR stages so FileCheck can pin them below.
comp = impl.get_compiler(
    target=mppa,
    shared_lib=True,
    dump_file="matmul_mlir_mppa",
    print_source_ir=True,
    print_transformed_ir=True,
    print_lowered_ir=True,
)
module = comp.compile(sched)
executor = module.get_executor(validate=True)
res = executor.execute()
print(f"CODE: {res}")
# CHECK: // -----// IR Dump Before transform //----- //
# CHECK-NEXT: module attributes {transform.with_named_sequence} {
# CHECK-NEXT: func.func @matmul(%arg0: memref<16x64xf32> {llvm.noalias}, %arg1: memref<64x16xf32> {llvm.noalias}, %arg2: memref<16x16xf32> {llvm.noalias}) {
# CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32
# CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst : f32) outs(%arg2 : memref<16x16xf32>)
# CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%arg0, %arg1 : memref<16x64xf32>, memref<64x16xf32>) outs(%arg2 : memref<16x16xf32>)
# CHECK-NEXT: return
# CHECK-NEXT: }
# CHECK-NEXT: transform.named_sequence @_vecto(%arg0: !transform.any_op {transform.consumed}) {
# CHECK-NEXT: transform.structured.vectorize %arg0 : !transform.any_op
# CHECK-NEXT: transform.yield
# CHECK-NEXT: }
# CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
# CHECK-NEXT: %0 = transform.sdist.create_memory_mesh %arg0 "memory_mesh" = <["mx"=1, "my"=1]> : !transform.any_op -> !transform.any_op
# CHECK-NEXT: %1 = transform.sdist.create_processor_mesh %arg0 "processor_mesh" = <["px"=1, "py"=1, "psx"=1, "psy"=1]> from "memory_mesh" : !transform.any_op -> !transform.any_op
# CHECK-NEXT: %2 = transform.structured.match attributes {__xtc_id_C_0_} in %arg0 : (!transform.any_op) -> !transform.any_op
# CHECK-NEXT: %tiled_linalg_op, %loops = transform.structured.tile_using_for %2 tile_sizes [1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
# CHECK-NEXT: transform.annotate %loops "./i" : !transform.any_op
# CHECK-NEXT: %tiled_linalg_op_0, %loops_1 = transform.structured.tile_using_for %tiled_linalg_op tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
# CHECK-NEXT: transform.annotate %loops_1 "./j" : !transform.any_op
# CHECK-NEXT: %3 = transform.structured.match attributes {__xtc_id_C_} in %arg0 : (!transform.any_op) -> !transform.any_op
# CHECK-NEXT: %tiled_linalg_op_2, %loops_3 = transform.structured.tile_using_for %3 tile_sizes [8, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
# CHECK-NEXT: transform.annotate %loops_3 "./i" : !transform.any_op
# CHECK-NEXT: %tiled_linalg_op_4, %loops_5 = transform.structured.tile_using_for %tiled_linalg_op_2 tile_sizes [0, 8, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
# CHECK-NEXT: transform.annotate %loops_5 "./j" : !transform.any_op
# CHECK-NEXT: transform.annotate %tiled_linalg_op_4 "xtc.request_vectorization" : !transform.any_op
# CHECK-NEXT: %4 = transform.get_parent_op %loops_3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
# CHECK-NEXT: transform.apply_patterns to %4 {
# CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract
# CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns
# CHECK-NEXT: } : !transform.any_op
# CHECK-NEXT: transform.apply_patterns to %4 {
# CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct
# CHECK-NEXT: transform.apply_patterns.vector.lower_contraction
# CHECK-NEXT: } : !transform.any_op
# CHECK-NEXT: transform.yield
# CHECK-NEXT: }
# CHECK-NEXT: }
# CHECK-NEXT:
# CHECK-NEXT: // -----// IR Dump After transform //----- //
# CHECK-NEXT: module attributes {transform.with_named_sequence} {
# CHECK-NEXT: sdist.processor_mesh @processor_mesh from @memory_mesh = <["px"=1, "py"=1, "psx"=1, "psy"=1]>
# CHECK-NEXT: sdist.memory_mesh @memory_mesh = <["mx"=1, "my"=1]>
# CHECK-NEXT: func.func @matmul(%arg0: memref<16x64xf32> {llvm.noalias}, %arg1: memref<64x16xf32> {llvm.noalias}, %arg2: memref<16x16xf32> {llvm.noalias}) {
# CHECK-NEXT: %c8 = arith.constant 8 : index
# CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32
# CHECK-NEXT: %c0 = arith.constant 0 : index
# CHECK-NEXT: %c16 = arith.constant 16 : index
# CHECK-NEXT: %c1 = arith.constant 1 : index
# CHECK-NEXT: scf.for %arg3 = %c0 to %c16 step %c1 {
# CHECK-NEXT: %subview = memref.subview %arg2[%arg3, 0] [1, 16] [1, 1] : memref<16x16xf32> to memref<1x16xf32, strided<[16, 1], offset: ?>>
# CHECK-NEXT: scf.for %arg4 = %c0 to %c16 step %c1 {
# CHECK-NEXT: %subview_0 = memref.subview %subview[0, %arg4] [1, 1] [1, 1] : memref<1x16xf32, strided<[16, 1], offset: ?>> to memref<1x1xf32, strided<[16, 1], offset: ?>>
# CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst : f32) outs(%subview_0 : memref<1x1xf32, strided<[16, 1], offset: ?>>)
# CHECK-NEXT: } {"./j"}
# CHECK-NEXT: } {"./i"}
# CHECK-NEXT: scf.for %arg3 = %c0 to %c16 step %c8 {
# CHECK-NEXT: %subview = memref.subview %arg0[%arg3, 0] [8, 64] [1, 1] : memref<16x64xf32> to memref<8x64xf32, strided<[64, 1], offset: ?>>
# CHECK-NEXT: %subview_0 = memref.subview %arg1[0, 0] [64, 16] [1, 1] : memref<64x16xf32> to memref<64x16xf32, strided<[16, 1]>>
# CHECK-NEXT: %subview_1 = memref.subview %arg2[%arg3, 0] [8, 16] [1, 1] : memref<16x16xf32> to memref<8x16xf32, strided<[16, 1], offset: ?>>
# CHECK-NEXT: scf.for %arg4 = %c0 to %c16 step %c8 {
# CHECK-NEXT: %subview_2 = memref.subview %subview_0[0, %arg4] [64, 8] [1, 1] : memref<64x16xf32, strided<[16, 1]>> to memref<64x8xf32, strided<[16, 1], offset: ?>>
# CHECK-NEXT: %subview_3 = memref.subview %subview_1[0, %arg4] [8, 8] [1, 1] : memref<8x16xf32, strided<[16, 1], offset: ?>> to memref<8x8xf32, strided<[16, 1], offset: ?>>
# CHECK-NEXT: linalg.matmul {__xtc_id_C_, xtc.request_vectorization} ins(%subview, %subview_2 : memref<8x64xf32, strided<[64, 1], offset: ?>>, memref<64x8xf32, strided<[16, 1], offset: ?>>) outs(%subview_3 : memref<8x8xf32, strided<[16, 1], offset: ?>>)
# CHECK-NEXT: } {"./j"}
# CHECK-NEXT: } {"./i"}
# CHECK-NEXT: return
# CHECK-NEXT: }
# CHECK-NEXT: }
# CHECK-NEXT:
# CHECK-NEXT: // -----// IR Dump After MLIR Opt //----- //
# CHECK-NEXT: #map = affine_map<(d0, d1, d2) -> (d0, d2)>
# CHECK-NEXT: #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
# CHECK-NEXT: #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
# CHECK-NEXT: "builtin.module"() ({
# CHECK-NEXT: "func.func"() <{arg_attrs = [{llvm.noalias}, {llvm.noalias}, {llvm.noalias}], function_type = (memref<16x64xf32>, memref<64x16xf32>, memref<16x16xf32>) -> (), sym_name = "matmul"}> ({
# CHECK-NEXT: ^bb0(%arg0: memref<16x64xf32>, %arg1: memref<64x16xf32>, %arg2: memref<16x16xf32>):
# CHECK-NEXT: "mppa.launch"() ({
# CHECK-NEXT: "kvxcluster.launch"() ({
# CHECK-NEXT: ^bb0(%arg3: index):
# CHECK-NEXT: %0 = "arith.constant"() <{value = 1 : index}> : () -> index
# CHECK-NEXT: %1 = "arith.constant"() <{value = 16 : index}> : () -> index
# CHECK-NEXT: %2 = "arith.constant"() <{value = 0 : index}> : () -> index
# CHECK-NEXT: %3 = "arith.constant"() <{value = 0.000000e+00 : f32}> : () -> f32
# CHECK-NEXT: %4 = "arith.constant"() <{value = 8 : index}> : () -> index
# CHECK-NEXT: "scf.for"(%2, %1, %0) ({
# CHECK-NEXT: ^bb0(%arg9: index):
# CHECK-NEXT: %11 = "memref.subview"(%arg2, %arg9) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: 1, 16>, static_strides = array<i64: 1, 1>}> : (memref<16x16xf32>, index) -> memref<1x16xf32, strided<[16, 1], offset: ?>>
# CHECK-NEXT: "scf.for"(%2, %1, %0) ({
# CHECK-NEXT: ^bb0(%arg10: index):
# CHECK-NEXT: %12 = "memref.subview"(%11, %arg10) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: 0, -9223372036854775808>, static_sizes = array<i64: 1, 1>, static_strides = array<i64: 1, 1>}> : (memref<1x16xf32, strided<[16, 1], offset: ?>>, index) -> memref<1x1xf32, strided<[16, 1], offset: ?>>
# CHECK-NEXT: "linalg.fill"(%3, %12) <{operandSegmentSizes = array<i32: 1, 1>}> ({
# CHECK-NEXT: ^bb0(%arg11: f32, %arg12: f32):
# CHECK-NEXT: "linalg.yield"(%arg11) : (f32) -> ()
# CHECK-NEXT: }) {__xtc_id_C_0_} : (f32, memref<1x1xf32, strided<[16, 1], offset: ?>>) -> ()
# CHECK-NEXT: "scf.yield"() : () -> ()
# CHECK-NEXT: }) {"./j"} : (index, index, index) -> ()
# CHECK-NEXT: "scf.yield"() : () -> ()
# CHECK-NEXT: }) {"./i"} : (index, index, index) -> ()
# CHECK-NEXT: "scf.for"(%2, %1, %4) ({
# CHECK-NEXT: ^bb0(%arg4: index):
# CHECK-NEXT: %5 = "memref.subview"(%arg0, %arg4) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: 8, 64>, static_strides = array<i64: 1, 1>}> : (memref<16x64xf32>, index) -> memref<8x64xf32, strided<[64, 1], offset: ?>>
# CHECK-NEXT: %6 = "memref.subview"(%arg2, %arg4) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: 8, 16>, static_strides = array<i64: 1, 1>}> : (memref<16x16xf32>, index) -> memref<8x16xf32, strided<[16, 1], offset: ?>>
# CHECK-NEXT: "scf.for"(%2, %1, %4) ({
# CHECK-NEXT: ^bb0(%arg5: index):
# CHECK-NEXT: %7 = "memref.subview"(%arg1, %arg5) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: 0, -9223372036854775808>, static_sizes = array<i64: 64, 8>, static_strides = array<i64: 1, 1>}> : (memref<64x16xf32>, index) -> memref<64x8xf32, strided<[16, 1], offset: ?>>
# CHECK-NEXT: %8 = "memref.subview"(%6, %arg5) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: 0, -9223372036854775808>, static_sizes = array<i64: 8, 8>, static_strides = array<i64: 1, 1>}> : (memref<8x16xf32, strided<[16, 1], offset: ?>>, index) -> memref<8x8xf32, strided<[16, 1], offset: ?>>
# CHECK-NEXT: "linalg.matmul"(%5, %7, %8) <{indexing_maps = [#map, #map1, #map2], operandSegmentSizes = array<i32: 2, 1>}> ({
# CHECK-NEXT: ^bb0(%arg6: f32, %arg7: f32, %arg8: f32):
# CHECK-NEXT: %9 = "arith.mulf"(%arg6, %arg7) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
# CHECK-NEXT: %10 = "arith.addf"(%arg8, %9) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
# CHECK-NEXT: "linalg.yield"(%10) : (f32) -> ()
# CHECK-NEXT: }) {__xtc_id_C_, xtc.request_vectorization} : (memref<8x64xf32, strided<[64, 1], offset: ?>>, memref<64x8xf32, strided<[16, 1], offset: ?>>, memref<8x8xf32, strided<[16, 1], offset: ?>>) -> ()
# CHECK-NEXT: "scf.yield"() : () -> ()
# CHECK-NEXT: }) {"./j"} : (index, index, index) -> ()
# CHECK-NEXT: "scf.yield"() : () -> ()
# CHECK-NEXT: }) {"./i"} : (index, index, index) -> ()
# CHECK-NEXT: "kvxcluster.launch_terminator"() : () -> ()
# CHECK-NEXT: }) {mask = 1 : i32, nclusters = 1 : i32} : () -> ()
# CHECK-NEXT: "kvxcluster.await_all"() : () -> ()
# CHECK-NEXT: "mppa.yield"() : () -> ()
# CHECK-NEXT: }) {device = 1 : i32} : () -> ()
# CHECK-NEXT: "func.return"() : () -> ()
# CHECK-NEXT: }) : () -> ()
# CHECK-NEXT: }) {transform.with_named_sequence} : () -> ()
# CHECK-NEXT:
# CHECK-NEXT: // -----// IR Dump After MPPA Opt //----- //
# CHECK-NEXT: module attributes {transform.with_named_sequence} {
# CHECK-NEXT: func.func @kvxcluster_launch_0_kernel_cc_0(%arg0: memref<16x16xf32, 2>, %arg1: memref<16x64xf32, 2>, %arg2: memref<64x16xf32, 2>) attributes {kernel_for_cluster_id = 0 : index} {
# CHECK-NEXT: %c64 = arith.constant 64 : index
# CHECK-NEXT: %c8 = arith.constant 8 : index
# CHECK-NEXT: %c1 = arith.constant 1 : index
# CHECK-NEXT: %c0 = arith.constant 0 : index
# CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32
# CHECK-NEXT: %c16 = arith.constant 16 : index
# CHECK-NEXT: scf.for %arg3 = %c0 to %c16 step %c1 {
# CHECK-NEXT: scf.for %arg4 = %c0 to %c16 step %c1 {
# CHECK-NEXT: %0 = arith.muli %arg3, %c16 overflow<nsw> : index
# CHECK-NEXT: %1 = arith.addi %0, %arg4 : index
# CHECK-NEXT: %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%1], sizes: [1, 1], strides: [16, 1] : memref<16x16xf32, 2> to memref<1x1xf32, strided<[16, 1], offset: ?>, 2>
# CHECK-NEXT: kvxpe.launch %arg5 (npes=1) {
# CHECK-NEXT: memref.store %cst, %reinterpret_cast[%c0, %c0] : memref<1x1xf32, strided<[16, 1], offset: ?>, 2>
# CHECK-NEXT: kvxpe.launch_terminator
# CHECK-NEXT: }
# CHECK-NEXT: kvxpe.await_all
# CHECK-NEXT: } {"./j"}
# CHECK-NEXT: } {"./i"}
# CHECK-NEXT: scf.for %arg3 = %c0 to %c16 step %c8 {
# CHECK-NEXT: %0 = arith.muli %arg3, %c64 overflow<nsw> : index
# CHECK-NEXT: %reinterpret_cast = memref.reinterpret_cast %arg1 to offset: [%0], sizes: [8, 64], strides: [64, 1] : memref<16x64xf32, 2> to memref<8x64xf32, strided<[64, 1], offset: ?>, 2>
# CHECK-NEXT: scf.for %arg4 = %c0 to %c16 step %c8 {
# CHECK-NEXT: %reinterpret_cast_0 = memref.reinterpret_cast %arg2 to offset: [%arg4], sizes: [64, 8], strides: [16, 1] : memref<64x16xf32, 2> to memref<64x8xf32, strided<[16, 1], offset: ?>, 2>
# CHECK-NEXT: %1 = arith.muli %arg3, %c16 overflow<nsw> : index
# CHECK-NEXT: %2 = arith.addi %1, %arg4 : index
# CHECK-NEXT: %reinterpret_cast_1 = memref.reinterpret_cast %arg0 to offset: [%2], sizes: [8, 8], strides: [16, 1] : memref<16x16xf32, 2> to memref<8x8xf32, strided<[16, 1], offset: ?>, 2>
# CHECK-NEXT: kvxpe.launch %arg5 (npes=1) {
# CHECK-NEXT: kvxuks.mma_8x8xf32 %reinterpret_cast, %reinterpret_cast_0 -> %reinterpret_cast_1 : memref<8x64xf32, strided<[64, 1], offset: ?>, 2>, memref<64x8xf32, strided<[16, 1], offset: ?>, 2>, memref<8x8xf32, strided<[16, 1], offset: ?>, 2>
# CHECK-NEXT: kvxpe.launch_terminator
# CHECK-NEXT: }
# CHECK-NEXT: kvxpe.await_all
# CHECK-NEXT: } {"./j"}
# CHECK-NEXT: } {"./i"}
# CHECK-NEXT: return
# CHECK-NEXT: }
# CHECK-NEXT: func.func @matmul(%arg0: memref<16x64xf32> {llvm.noalias}, %arg1: memref<64x16xf32> {llvm.noalias}, %arg2: memref<16x16xf32> {llvm.noalias}) {
# CHECK-NEXT: mppa.launch(k300) {
# CHECK-NEXT: %0 = mppa.alloc : memref<16x16xf32, 2>
# CHECK-NEXT: mppa.copy %arg2, %0 : memref<16x16xf32> to memref<16x16xf32, 2>
# CHECK-NEXT: %1 = mppa.alloc : memref<16x64xf32, 2>
# CHECK-NEXT: mppa.copy %arg0, %1 : memref<16x64xf32> to memref<16x64xf32, 2>
# CHECK-NEXT: %2 = mppa.alloc : memref<64x16xf32, 2>
# CHECK-NEXT: mppa.copy %arg1, %2 : memref<64x16xf32> to memref<64x16xf32, 2>
# CHECK-NEXT: kvxcluster.launch (nclusters=1, mask=1)
# CHECK-NEXT: 0 -> @kvxcluster_launch_0_kernel_cc_0
# CHECK-NEXT: with (%0, %1, %2) : memref<16x16xf32, 2>, memref<16x64xf32, 2>, memref<64x16xf32, 2>
# CHECK-NEXT: kvxcluster.await_all
# CHECK-NEXT: mppa.dealloc %2 : memref<64x16xf32, 2>
# CHECK-NEXT: mppa.dealloc %1 : memref<16x64xf32, 2>
# CHECK-NEXT: mppa.copy %0, %arg2 : memref<16x16xf32, 2> to memref<16x16xf32>
# CHECK-NEXT: mppa.dealloc %0 : memref<16x16xf32, 2>
# CHECK-NEXT: kvxcluster.await_all
# CHECK-NEXT: }
# CHECK-NEXT: return
# CHECK-NEXT: }
# CHECK-NEXT: }
# CHECK-NEXT:
# CHECK-NEXT:
# CHECK-NEXT: graph:
# CHECK-NEXT: name: matmul
# CHECK-NEXT: inputs:
# CHECK-NEXT: - %0 : 16x64xfloat32
# CHECK-NEXT: - %1 : 64x16xfloat32
# CHECK-NEXT: outputs:
# CHECK-NEXT: - %2 : 16x16xfloat32
# CHECK-NEXT: nodes:
# CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [16x64xfloat32, 64x16xfloat32] -> [16x16xfloat32]
# CHECK-NEXT:
# CHECK-NEXT: CODE: 0