Skip to content

Commit 3b5ca77

Browse files
committed
[Mlir] Support buffer_at with SDist
1 parent 0a2c9a2 commit 3b5ca77

21 files changed

Lines changed: 120 additions & 78 deletions

src/xtc/backends/mlir/MlirCompilerPasses.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,12 @@ def _generate_node_scheduling(
312312
schedule=schedule,
313313
sched_state=sched_state,
314314
)
315+
if loop_name in schedule.write_buffers:
316+
self._write_buffer(
317+
loop_name=loop_name,
318+
schedule=schedule,
319+
sched_state=sched_state,
320+
)
315321

316322
# Manage the strip-mining
317323
if loop_name in schedule.vectorization:
@@ -537,6 +543,30 @@ def _pack_buffer(
537543
input_idx=input_idx,
538544
)
539545

546+
def _write_buffer(
    self,
    loop_name: str,
    schedule: MlirNodeSchedule,
    sched_state: SchedulingState,
):
    """Emit the transform ops for a write buffer requested on a loop.

    The node backend registered under ``schedule.node_name`` is looked up
    and the number of its input specs is used as the buffer index
    (presumably the position of the node's output operand, right after
    its inputs -- TODO confirm against the SDist op definition).

    The memref alias folding patterns are applied on
    ``sched_state.handle``; then, only when the "sdist" extension is
    enabled on the program, an ``SDistLocalBufferAtOp`` is emitted on the
    same handle.

    ``loop_name`` is not read by this method.
    """
    # Imported locally, as in the original code (presumably to avoid an
    # import cycle between the backend modules -- TODO confirm).
    from .MlirGraphBackend import MlirGraphBackend
    from .MlirNodeBackend import MlirNodeBackend

    assert self._mlir_schedule is not None
    backend = self._mlir_schedule.scheduler.backend
    assert isinstance(backend, MlirGraphBackend)
    node = backend.nodes[schedule.node_name]
    assert isinstance(node, MlirNodeBackend)
    output_idx = len(node.np_inputs_spec())

    # Populate the patterns region of an apply_patterns op on the handle.
    apply_patterns = transform.ApplyPatternsOp(sched_state.handle)
    with InsertionPoint(apply_patterns.patterns):
        memref.ApplyFoldMemrefAliasOpsPatternsOp()

    # Without the SDist extension there is nothing more to emit.
    if "sdist" not in self._mlir_program.mlir_extensions:
        return
    assert sdist_transform is not None
    sdist_transform.SDistLocalBufferAtOp(
        target=sched_state.handle,
        input_idx=output_idx,
    )
569+
540570

541571
class MlirProgramApplyTransformPass:
542572
def __init__(

src/xtc/backends/mlir/MlirNodeScheduler.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class MlirNodeSchedule:
3030
parallelization: list[str]
3131
unrolling: dict[str, int]
3232
packed_buffers: dict[str, list[int]]
33+
write_buffers: list[str]
3334
memory_mesh: dict[str, int]
3435
processor_mesh: dict[str, int]
3536
distribution: dict[str, str]
@@ -90,6 +91,7 @@ def __init__(
9091
self.parallelization: list[str] = []
9192
self.unrolling: dict[str, int] = {}
9293
self.packed_buffers: dict[str, list[int]] = {}
94+
self.write_buffers: list[str] = []
9395
self.memory_mesh: dict[str, int] = {}
9496
self.processor_mesh: dict[str, int] = {}
9597
self.distribution: dict[str, str] = {}
@@ -112,6 +114,7 @@ def mlir_node_schedule(self) -> MlirNodeSchedule:
112114
unrolling=self.unrolling,
113115
memory_mesh=self.memory_mesh,
114116
packed_buffers=self.packed_buffers,
117+
write_buffers=self.write_buffers,
115118
processor_mesh=self.processor_mesh,
116119
distribution=self.distribution,
117120
distributed_buffers=self.distributed_buffers,
@@ -178,6 +181,10 @@ def pack_at(
178181
else:
179182
self.packed_buffers[axis_key].append(input_idx)
180183

184+
def buffer_at(
    self,
    axis: str,
    mtype: str | None = None,
    root: str = DEFAULT_ROOT,
):
    """Record a write-buffer request for ``axis`` under ``root``.

    The key ``"{root}{ROOT_SEP}{axis}"`` is appended to
    ``self.write_buffers``. ``mtype`` is accepted but not used here.
    """
    self.write_buffers.append(f"{root}{ROOT_SEP}{axis}")
187+
181188
def define_memory_mesh(self, axes: dict[str, int]):
182189
assert len(self.memory_mesh) == 0, "Memory mesh has already been defined"
183190
self.memory_mesh = axes

src/xtc/backends/mlir/MlirScheduler.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,14 @@ def interchange(self, permutation: list[str], root: str = DEFAULT_ROOT) -> None:
131131
def buffer_at(
    self, axis: str, mtype: str | None = None, root: str = DEFAULT_ROOT
) -> None:
    """Request a write buffer at ``axis`` and forward it to the node scheduler.

    ``mtype`` must be ``None``, ``"global"`` or ``"local"``. For ``None``
    and ``"global"`` the "sdist" extension is required with ``weak=True``;
    for any other value it is required without the weak flag. The request
    is then delegated to ``self._current_scheduler.buffer_at``.
    """
    # The current implementation relies exclusively on SDist, but the
    # upstream transform dialect may be used for some cases.
    assert mtype in (None, "global", "local")
    needs_sdist = mtype is not None and mtype != "global"
    if needs_sdist:
        self._require_extension("sdist")
    else:
        self._require_extension("sdist", weak=True)
    self._current_scheduler.buffer_at(axis, mtype, root=root)
137142

138143
@override
139144
def pack_at(
@@ -144,7 +149,7 @@ def pack_at(
144149
pad: bool = False,
145150
root: str = DEFAULT_ROOT,
146151
) -> None:
147-
# The current implemntation exclusively rely on SDist, but upstream
152+
# The current implementation relies exclusively on SDist, but upstream
148153
# transform dialect may be used for some cases.
149154
assert mtype is None or mtype == "global" or mtype == "local"
150155
if pad:

tests/filecheck/search/test_conv_oo.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
utils.print_exhaustive_samples(backend, strategy, 100)
1414

1515
# CHECK: schedule O0: [1, 1, 1, 1, 1, 1, 1]
16-
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 1}, 'h': {'./h1': 1}, 'w': {'./w1': 1}, 'f': {'./f1': 1}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 1}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 1, './w1': 1, './h1': 1, './c1': 1, './s1': 1, './r1': 1, './b1': 1}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
16+
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 1}, 'h': {'./h1': 1}, 'w': {'./w1': 1}, 'f': {'./f1': 1}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 1}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 1, './w1': 1, './h1': 1, './c1': 1, './s1': 1, './r1': 1, './b1': 1}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
1717
# CHECK-NEXT: schedule O1: [1, 1, 1, 1, 1, 1, 1]
18-
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 1}, 'h': {'./h1': 1}, 'w': {'./w1': 1}, 'f': {'./f1': 1}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 1}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 1, './w1': 1, './h1': 1, './c1': 1, './s1': 1, './r1': 1, './b1': 1}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
18+
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 1}, 'h': {'./h1': 1}, 'w': {'./w1': 1}, 'f': {'./f1': 1}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 1}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 1, './w1': 1, './h1': 1, './c1': 1, './s1': 1, './r1': 1, './b1': 1}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
1919
# CHECK-NEXT: schedule O2: [1, 1, 2, 16, 1, 1, 1]
20-
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 1}, 'h': {'./h1': 1}, 'w': {'./w1': 2}, 'f': {'./f1': 16}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 1}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 16, './w1': 2, './h1': 1, './c1': 1, './s1': 1, './r1': 1, './b1': 1}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
20+
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 1}, 'h': {'./h1': 1}, 'w': {'./w1': 2}, 'f': {'./f1': 16}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 1}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 16, './w1': 2, './h1': 1, './c1': 1, './s1': 1, './r1': 1, './b1': 1}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
2121
# CHECK-NEXT: schedule O3: [1, 1, 2, 16, 1, 1, 3]
22-
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 1}, 'h': {'./h1': 1}, 'w': {'./w1': 2}, 'f': {'./f1': 16}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 3}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 16, './w1': 2, './h1': 1, './c1': 3, './s1': 1, './r1': 1, './b1': 1}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
22+
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 1}, 'h': {'./h1': 1}, 'w': {'./w1': 2}, 'f': {'./f1': 16}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 3}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 16, './w1': 2, './h1': 1, './c1': 3, './s1': 1, './r1': 1, './b1': 1}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
2323
# CHECK-NEXT: sample 0: [1, 1, 1, 1, 1, 1, 1]
2424
# CHECK-NEXT: sample 1: [1, 1, 1, 1, 1, 1, 3]
2525
# CHECK-NEXT: sample 2: [1, 1, 1, 1, 1, 7, 1]
@@ -99,4 +99,4 @@
9999
# CHECK-NEXT: sample 76: [2, 2, 2, 8, 1, 1, 1]
100100
# CHECK-NEXT: sample 77: [2, 2, 2, 16, 1, 1, 1]
101101
# CHECK-NEXT: stats {'filtered': 78, 'all': 384}
102-
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 2}, 'h': {'./h1': 2}, 'w': {'./w1': 2}, 'f': {'./f1': 16}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 1}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 16, './w1': 2, './h1': 2, './c1': 1, './s1': 1, './r1': 1, './b1': 2}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
102+
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 2}, 'h': {'./h1': 2}, 'w': {'./w1': 2}, 'f': {'./f1': 16}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 1}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 16, './w1': 2, './h1': 2, './c1': 1, './s1': 1, './r1': 1, './b1': 2}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]

0 commit comments

Comments
 (0)