Skip to content

Commit d35bf1b

Browse files
committed
descript: validate low-level schedules
1 parent 0a2c9a2 commit d35bf1b

2 files changed

Lines changed: 103 additions & 11 deletions

File tree

src/xtc/backends/mlir/MlirScheduler.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import xtc.itf as itf
1010
import xtc.backends.mlir as backend
1111

12-
from .MlirNodeScheduler import MlirNodeScheduler, MlirNodeSchedule
12+
from .MlirNodeScheduler import MlirNodeScheduler, MlirNodeSchedule, basename
1313

1414
__all__ = [
1515
"MlirScheduler",
@@ -212,36 +212,41 @@ def get_loop_nest(self) -> LoopNest:
212212
loop_nest = LoopNest(abstract_dims=dims)
213213
root_node = loop_nest.build_root_node(node_sched.node_name)
214214

215-
# Assign splits to root_node first
215+
# Assign splits to root_node first, stripping the root prefix from names
216216
for axis, axis_splits in node_sched.splits.items():
217-
root_node.splits[axis] = dict(axis_splits)
217+
root_node.splits[axis] = {basename(k): v for k, v in axis_splits.items()}
218218

219219
# Build mapper to get splits_info
220220
mapper = LoopInfo.build_from_node(root_node)
221221

222222
def populate_node(node: LoopNestNode, perm: list[str]) -> None:
223223
"""Populate node with data for loops in its permutation."""
224-
node.interchange = list(perm)
225224
perm_set = set(perm)
225+
node.interchange = [basename(n) for n in perm]
226226
for axis, axis_tiles in node_sched.tiles.items():
227227
for tile_name, size in axis_tiles.items():
228228
if tile_name in perm_set:
229229
if axis not in node.tiles:
230230
node.tiles[axis] = {}
231-
node.tiles[axis][tile_name] = size
232-
node.vectorize = [v for v in node_sched.vectorization if v in perm_set]
233-
node.parallelize = [p for p in node_sched.parallelization if p in perm_set]
231+
node.tiles[axis][basename(tile_name)] = size
232+
node.vectorize = [
233+
basename(v) for v in node_sched.vectorization if v in perm_set
234+
]
235+
node.parallelize = [
236+
basename(p) for p in node_sched.parallelization if p in perm_set
237+
]
234238
node.unroll = {
235-
k: v for k, v in node_sched.unrolling.items() if k in perm_set
239+
basename(k): v for k, v in node_sched.unrolling.items() if k in perm_set
236240
}
237241

238242
# Process each root in permutation
239243
for root, perm in node_sched.permutation.items():
240-
if root in mapper.splits_info:
244+
root_name = basename(root)
245+
if root_name in mapper.splits_info:
241246
# This root is a split - create child node
242-
axis, start, end = mapper.splits_info[root]
247+
axis, start, end = mapper.splits_info[root_name]
243248
child = LoopNestNode(
244-
root=root,
249+
root=root_name,
245250
tiles={d: {} for d in dims},
246251
split_origin=SplitOrigin(axis=axis, start=start, end=end),
247252
)
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# RUN: python %s 2>&1 | filecheck %s --check-prefix=CHECK-VALID
2+
# RUN: not python %s --unused-dim 2>&1 | filecheck %s --check-prefix=CHECK-UNUSED-DIM
3+
# RUN: not python %s --vect-inconsistency 2>&1 | filecheck %s --check-prefix=CHECK-VECT
4+
# RUN: not python %s --tile-before-axis 2>&1 | filecheck %s --check-prefix=CHECK-ORDER
5+
# RUN: not python %s --tile-too-large 2>&1 | filecheck %s --check-prefix=CHECK-SIZE
6+
7+
import sys
8+
import xtc.graphs.xtc.op as O
9+
from xtc.backends.mlir import Backend
10+
11+
I, J, K, dtype = 4, 32, 512, "float32"
12+
a = O.tensor((I, K), dtype, name="A")
13+
b = O.tensor((K, J), dtype, name="B")
14+
15+
with O.graph(name="matmul") as gb:
16+
O.matmul(a, b, name="C")
17+
18+
graph = gb.graph
19+
20+
21+
def make_scheduler():
22+
impl = Backend(graph)
23+
return impl.get_scheduler()
24+
25+
26+
if len(sys.argv) == 1:
27+
sch = make_scheduler()
28+
sch.set_dims(["I", "J", "K"])
29+
sch.tile("I", {"I0": 2})
30+
sch.tile("J", {"J0": 16})
31+
sch.interchange(["K", "I", "J", "I0", "J0"])
32+
sch.vectorize(["J0"])
33+
34+
loop_nest = sch.get_loop_nest()
35+
loop_nest.check()
36+
print("ok")
37+
38+
# CHECK-VALID: ok
39+
40+
elif "--unused-dim" in sys.argv:
41+
sch = make_scheduler()
42+
sch.set_dims(["I", "J", "K"])
43+
sch.tile("I", {"I0": 2})
44+
sch.interchange(["I", "J", "I0"])
45+
46+
loop_nest = sch.get_loop_nest()
47+
loop_nest.check()
48+
49+
# CHECK-UNUSED-DIM: K defined but never used
50+
51+
elif "--vect-inconsistency" in sys.argv:
52+
sch = make_scheduler()
53+
sch.set_dims(["I", "J", "K"])
54+
sch.tile("I", {"I0": 2})
55+
sch.tile("J", {"J0": 16})
56+
sch.interchange(["K", "I", "J", "J0", "I0"])
57+
sch.vectorize(["J0"])
58+
59+
loop_nest = sch.get_loop_nest()
60+
loop_nest.check()
61+
62+
# CHECK-VECT: Inner loop I0 isn't vectorized but an outer one is.
63+
64+
elif "--tile-before-axis" in sys.argv:
65+
sch = make_scheduler()
66+
sch.set_dims(["I", "J", "K"])
67+
sch.tile("I", {"I0": 2})
68+
sch.tile("J", {"J0": 16})
69+
sch.interchange(["K", "I0", "I", "J", "J0"])
70+
71+
loop_nest = sch.get_loop_nest()
72+
loop_nest.check()
73+
74+
# CHECK-ORDER: `I#2`: I has not been materialized yet.
75+
76+
elif "--tile-too-large" in sys.argv:
77+
sch = make_scheduler()
78+
sch.set_dims(["I", "J", "K"])
79+
sch.tile("I", {"I0": 4})
80+
sch.tile("I", {"I00": 8})
81+
sch.tile("J", {"J0": 16})
82+
sch.interchange(["K", "I", "J", "I0", "I00", "J0"])
83+
84+
loop_nest = sch.get_loop_nest()
85+
loop_nest.check()
86+
87+
# CHECK-SIZE: Inner loop I00 on axis I must be smaller than outer loop.

0 commit comments

Comments
 (0)