|
| 1 | +# RUN: python %s 2>&1 | filecheck %s |
| 2 | +# REQUIRES: mlir-target=nvgpu |
| 3 | +# Offload test: matmul with A on the host and B, C on the GPU device; the IR dumps and the result code printed by this script are matched by the filecheck expectations at the end of the file. |
| 4 | +import xtc.graphs.xtc.op as O |
| 5 | +from xtc.backends.mlir.MlirGraphBackend import MlirGraphBackend as Backend |
| 6 | + |
| 7 | +from xtc.runtimes.accelerator.gpu import GPUDevice |
| 8 | + |
| 9 | +# Create the GPU device; it is used below both for tensor placement and as the compilation target |
| 10 | +gpu = GPUDevice() |
| 11 | + |
| 12 | +I, J, K, dtype = 4, 32, 512, "float32" |
| 13 | +a = O.tensor((I, K), dtype, name="A") # A lives on the host (no device argument is given) |
| 14 | +b = O.tensor((K, J), dtype, name="B", device=gpu) # B lives on the accelerator |
| 15 | + |
| 16 | +with O.graph(name="matmul") as gb: |
| 17 | + O.matmul(a, b, name="C", device=gpu) # C must live on the accelerator |
| 18 | +# Print the graph. NOTE(review): the expectations list this dump after the IR dumps even though it is printed first, presumably because stdout is buffered while the IR dumps are not; confirm. |
| 19 | +graph = gb.graph |
| 20 | +print(graph) |
| 21 | + |
| 22 | +impl = Backend(graph) |
| 23 | +# Schedule: tile i by 2 (inner loop i1) and j by 16 (inner loop j1), unroll i1 by 2, and parallelize i (expected as an scf.forall in the transformed IR). |
| 24 | +sch = impl.get_scheduler() |
| 25 | +sch.tile("i", {"i1": 2}) |
| 26 | +sch.tile("j", {"j1": 16}) |
| 27 | +sch.unroll({"i1": 2}) |
| 28 | +sch.parallelize(["i"]) |
| 29 | +sched = sch.schedule() |
| 30 | +# Compile for the GPU target, dumping the IR before and after the transform, then execute and print the returned code. |
| 31 | +comp = impl.get_compiler( |
| 32 | + target=gpu, |
| 33 | + shared_lib=True, |
| 34 | + dump_file="gpu_matmul_mlir_offload_tensor", |
| 35 | + print_source_ir=True, |
| 36 | + print_transformed_ir=True, |
| 37 | +) |
| 38 | +module = comp.compile(sched) |
| 39 | +executor = module.get_executor(validate=True) |
| 40 | +res = executor.execute() |
| 41 | +print(f"CODE: {res}") |
| 42 | +# CHECK: // -----// IR Dump Before transform //----- // |
| 43 | +# CHECK-NEXT: module attributes {transform.with_named_sequence} { |
| 44 | +# CHECK-NEXT: func.func @matmul(%arg0: memref<4x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias, memref.on_device}, %arg2: memref<4x32xf32> {llvm.noalias, memref.on_device}) { |
| 45 | +# CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 |
| 46 | +# CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst : f32) outs(%arg2 : memref<4x32xf32>) |
| 47 | +# CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%arg0, %arg1 : memref<4x512xf32>, memref<512x32xf32>) outs(%arg2 : memref<4x32xf32>) |
| 48 | +# CHECK-NEXT: return |
| 49 | +# CHECK-NEXT: } |
| 50 | +# CHECK-NEXT: transform.named_sequence @_vecto(%arg0: !transform.any_op {transform.consumed}) { |
| 51 | +# CHECK-NEXT: transform.structured.vectorize %arg0 : !transform.any_op |
| 52 | +# CHECK-NEXT: transform.yield |
| 53 | +# CHECK-NEXT: } |
| 54 | +# CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { |
| 55 | +# CHECK-NEXT: %0 = transform.structured.match attributes {__xtc_id_C_0_} in %arg0 : (!transform.any_op) -> !transform.any_op |
| 56 | +# CHECK-NEXT: %tiled_linalg_op, %loops = transform.structured.tile_using_for %0 tile_sizes [1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) |
| 57 | +# CHECK-NEXT: transform.annotate %loops "./i" : !transform.any_op |
| 58 | +# CHECK-NEXT: %tiled_linalg_op_0, %loops_1 = transform.structured.tile_using_for %tiled_linalg_op tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) |
| 59 | +# CHECK-NEXT: transform.annotate %loops_1 "./j" : !transform.any_op |
| 60 | +# CHECK-NEXT: %1 = transform.structured.match attributes {__xtc_id_C_} in %arg0 : (!transform.any_op) -> !transform.any_op |
| 61 | +# CHECK-NEXT: %tiled_op, %forall_op = transform.structured.tile_using_forall %1 tile_sizes [2, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) |
| 62 | +# CHECK-NEXT: transform.annotate %forall_op "./i" : !transform.any_op |
| 63 | +# CHECK-NEXT: %tiled_linalg_op_2, %loops_3 = transform.structured.tile_using_for %tiled_op tile_sizes [0, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) |
| 64 | +# CHECK-NEXT: transform.annotate %loops_3 "./j" : !transform.any_op |
| 65 | +# CHECK-NEXT: %tiled_linalg_op_4, %loops_5 = transform.structured.tile_using_for %tiled_linalg_op_2 tile_sizes [0, 0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) |
| 66 | +# CHECK-NEXT: transform.annotate %loops_5 "./k" : !transform.any_op |
| 67 | +# CHECK-NEXT: %tiled_linalg_op_6, %loops_7 = transform.structured.tile_using_for %tiled_linalg_op_4 tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) |
| 68 | +# CHECK-NEXT: transform.annotate %loops_7 "./i1" : !transform.any_op |
| 69 | +# CHECK-NEXT: %tiled_linalg_op_8, %loops_9 = transform.structured.tile_using_for %tiled_linalg_op_6 tile_sizes [0, 1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) |
| 70 | +# CHECK-NEXT: transform.annotate %loops_9 "./j1" : !transform.any_op |
| 71 | +# CHECK-NEXT: transform.loop.unroll %loops_7 {factor = 2 : i64} : !transform.any_op |
| 72 | +# CHECK-NEXT: transform.yield |
| 73 | +# CHECK-NEXT: } |
| 74 | +# CHECK-NEXT: } |
| 75 | +# CHECK-NEXT: |
| 76 | +# CHECK-NEXT: // -----// IR Dump After transform //----- // |
| 77 | +# CHECK-NEXT: #map = affine_map<(d0) -> (d0 * 2)> |
| 78 | +# CHECK-NEXT: module attributes {transform.with_named_sequence} { |
| 79 | +# CHECK-NEXT: func.func @matmul(%arg0: memref<4x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias, memref.on_device}, %arg2: memref<4x32xf32> {llvm.noalias, memref.on_device}) { |
| 80 | +# CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 |
| 81 | +# CHECK-NEXT: %c0 = arith.constant 0 : index |
| 82 | +# CHECK-NEXT: %c4 = arith.constant 4 : index |
| 83 | +# CHECK-NEXT: %c1 = arith.constant 1 : index |
| 84 | +# CHECK-NEXT: scf.for %arg3 = %c0 to %c4 step %c1 { |
| 85 | +# CHECK-NEXT: %subview = memref.subview %arg2[%arg3, 0] [1, 32] [1, 1] : memref<4x32xf32> to memref<1x32xf32, strided<[32, 1], offset: ?>> |
| 86 | +# CHECK-NEXT: %c0_0 = arith.constant 0 : index |
| 87 | +# CHECK-NEXT: %c32 = arith.constant 32 : index |
| 88 | +# CHECK-NEXT: %c1_1 = arith.constant 1 : index |
| 89 | +# CHECK-NEXT: scf.for %arg4 = %c0_0 to %c32 step %c1_1 { |
| 90 | +# CHECK-NEXT: %subview_2 = memref.subview %subview[0, %arg4] [1, 1] [1, 1] : memref<1x32xf32, strided<[32, 1], offset: ?>> to memref<1x1xf32, strided<[32, 1], offset: ?>> |
| 91 | +# CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst : f32) outs(%subview_2 : memref<1x1xf32, strided<[32, 1], offset: ?>>) |
| 92 | +# CHECK-NEXT: } {"./j"} |
| 93 | +# CHECK-NEXT: } {"./i"} |
| 94 | +# CHECK-NEXT: scf.forall (%arg3) in (2) { |
| 95 | +# CHECK-NEXT: %0 = affine.apply #map(%arg3) |
| 96 | +# CHECK-NEXT: %subview = memref.subview %arg0[%0, 0] [2, 512] [1, 1] : memref<4x512xf32> to memref<2x512xf32, strided<[512, 1], offset: ?>> |
| 97 | +# CHECK-NEXT: %subview_0 = memref.subview %arg1[0, 0] [512, 32] [1, 1] : memref<512x32xf32> to memref<512x32xf32, strided<[32, 1]>> |
| 98 | +# CHECK-NEXT: %subview_1 = memref.subview %arg2[%0, 0] [2, 32] [1, 1] : memref<4x32xf32> to memref<2x32xf32, strided<[32, 1], offset: ?>> |
| 99 | +# CHECK-NEXT: %c0_2 = arith.constant 0 : index |
| 100 | +# CHECK-NEXT: %c32 = arith.constant 32 : index |
| 101 | +# CHECK-NEXT: %c16 = arith.constant 16 : index |
| 102 | +# CHECK-NEXT: scf.for %arg4 = %c0_2 to %c32 step %c16 { |
| 103 | +# CHECK-NEXT: %subview_3 = memref.subview %subview[0, 0] [2, 512] [1, 1] : memref<2x512xf32, strided<[512, 1], offset: ?>> to memref<2x512xf32, strided<[512, 1], offset: ?>> |
| 104 | +# CHECK-NEXT: %subview_4 = memref.subview %subview_0[0, %arg4] [512, 16] [1, 1] : memref<512x32xf32, strided<[32, 1]>> to memref<512x16xf32, strided<[32, 1], offset: ?>> |
| 105 | +# CHECK-NEXT: %subview_5 = memref.subview %subview_1[0, %arg4] [2, 16] [1, 1] : memref<2x32xf32, strided<[32, 1], offset: ?>> to memref<2x16xf32, strided<[32, 1], offset: ?>> |
| 106 | +# CHECK-NEXT: %c0_6 = arith.constant 0 : index |
| 107 | +# CHECK-NEXT: %c512 = arith.constant 512 : index |
| 108 | +# CHECK-NEXT: %c1_7 = arith.constant 1 : index |
| 109 | +# CHECK-NEXT: scf.for %arg5 = %c0_6 to %c512 step %c1_7 { |
| 110 | +# CHECK-NEXT: %subview_8 = memref.subview %subview_3[0, %arg5] [2, 1] [1, 1] : memref<2x512xf32, strided<[512, 1], offset: ?>> to memref<2x1xf32, strided<[512, 1], offset: ?>> |
| 111 | +# CHECK-NEXT: %subview_9 = memref.subview %subview_4[%arg5, 0] [1, 16] [1, 1] : memref<512x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> |
| 112 | +# CHECK-NEXT: %subview_10 = memref.subview %subview_5[0, 0] [2, 16] [1, 1] : memref<2x16xf32, strided<[32, 1], offset: ?>> to memref<2x16xf32, strided<[32, 1], offset: ?>> |
| 113 | +# CHECK-NEXT: %c0_11 = arith.constant 0 : index |
| 114 | +# CHECK-NEXT: %c2 = arith.constant 2 : index |
| 115 | +# CHECK-NEXT: %c1_12 = arith.constant 1 : index |
| 116 | +# CHECK-NEXT: %c2_13 = arith.constant 2 : index |
| 117 | +# CHECK-NEXT: %subview_14 = memref.subview %subview_8[%c0_11, 0] [1, 1] [1, 1] : memref<2x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> |
| 118 | +# CHECK-NEXT: %subview_15 = memref.subview %subview_9[0, 0] [1, 16] [1, 1] : memref<1x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> |
| 119 | +# CHECK-NEXT: %subview_16 = memref.subview %subview_10[%c0_11, 0] [1, 16] [1, 1] : memref<2x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> |
| 120 | +# CHECK-NEXT: %c0_17 = arith.constant 0 : index |
| 121 | +# CHECK-NEXT: %c16_18 = arith.constant 16 : index |
| 122 | +# CHECK-NEXT: %c1_19 = arith.constant 1 : index |
| 123 | +# CHECK-NEXT: scf.for %arg6 = %c0_17 to %c16_18 step %c1_19 { |
| 124 | +# CHECK-NEXT: %subview_27 = memref.subview %subview_14[0, 0] [1, 1] [1, 1] : memref<1x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> |
| 125 | +# CHECK-NEXT: %subview_28 = memref.subview %subview_15[0, %arg6] [1, 1] [1, 1] : memref<1x16xf32, strided<[32, 1], offset: ?>> to memref<1x1xf32, strided<[32, 1], offset: ?>> |
| 126 | +# CHECK-NEXT: %subview_29 = memref.subview %subview_16[0, %arg6] [1, 1] [1, 1] : memref<1x16xf32, strided<[32, 1], offset: ?>> to memref<1x1xf32, strided<[32, 1], offset: ?>> |
| 127 | +# CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%subview_27, %subview_28 : memref<1x1xf32, strided<[512, 1], offset: ?>>, memref<1x1xf32, strided<[32, 1], offset: ?>>) outs(%subview_29 : memref<1x1xf32, strided<[32, 1], offset: ?>>) |
| 128 | +# CHECK-NEXT: } {"./j1"} |
| 129 | +# CHECK-NEXT: %c1_20 = arith.constant 1 : index |
| 130 | +# CHECK-NEXT: %1 = arith.muli %c1_12, %c1_20 : index |
| 131 | +# CHECK-NEXT: %2 = arith.addi %c0_11, %1 : index |
| 132 | +# CHECK-NEXT: %subview_21 = memref.subview %subview_8[%2, 0] [1, 1] [1, 1] : memref<2x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> |
| 133 | +# CHECK-NEXT: %subview_22 = memref.subview %subview_9[0, 0] [1, 16] [1, 1] : memref<1x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> |
| 134 | +# CHECK-NEXT: %subview_23 = memref.subview %subview_10[%2, 0] [1, 16] [1, 1] : memref<2x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> |
| 135 | +# CHECK-NEXT: %c0_24 = arith.constant 0 : index |
| 136 | +# CHECK-NEXT: %c16_25 = arith.constant 16 : index |
| 137 | +# CHECK-NEXT: %c1_26 = arith.constant 1 : index |
| 138 | +# CHECK-NEXT: scf.for %arg6 = %c0_24 to %c16_25 step %c1_26 { |
| 139 | +# CHECK-NEXT: %subview_27 = memref.subview %subview_21[0, 0] [1, 1] [1, 1] : memref<1x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> |
| 140 | +# CHECK-NEXT: %subview_28 = memref.subview %subview_22[0, %arg6] [1, 1] [1, 1] : memref<1x16xf32, strided<[32, 1], offset: ?>> to memref<1x1xf32, strided<[32, 1], offset: ?>> |
| 141 | +# CHECK-NEXT: %subview_29 = memref.subview %subview_23[0, %arg6] [1, 1] [1, 1] : memref<1x16xf32, strided<[32, 1], offset: ?>> to memref<1x1xf32, strided<[32, 1], offset: ?>> |
| 142 | +# CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%subview_27, %subview_28 : memref<1x1xf32, strided<[512, 1], offset: ?>>, memref<1x1xf32, strided<[32, 1], offset: ?>>) outs(%subview_29 : memref<1x1xf32, strided<[32, 1], offset: ?>>) |
| 143 | +# CHECK-NEXT: } {"./j1"} |
| 144 | +# CHECK-NEXT: } {"./k"} |
| 145 | +# CHECK-NEXT: } {"./j"} |
| 146 | +# CHECK-NEXT: } {"./i"} |
| 147 | +# CHECK-NEXT: return |
| 148 | +# CHECK-NEXT: } |
| 149 | +# CHECK-NEXT: } |
| 150 | +# CHECK-NEXT: |
| 151 | +# CHECK-NEXT: graph: |
| 152 | +# CHECK-NEXT: name: matmul |
| 153 | +# CHECK-NEXT: inputs: |
| 154 | +# CHECK-NEXT: - %0 : 4x512xfloat32 |
| 155 | +# CHECK-NEXT: - %1 : 512x32xfloat32 |
| 156 | +# CHECK-NEXT: outputs: |
| 157 | +# CHECK-NEXT: - %2 : 4x32xfloat32 |
| 158 | +# CHECK-NEXT: nodes: |
| 159 | +# CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [4x512xfloat32, 512x32xfloat32] -> [4x32xfloat32] |
| 160 | +# CHECK-NEXT: |
| 161 | +# CHECK-NEXT: CODE: 0 |
0 commit comments