Skip to content

Commit 50e6dea

Browse files
committed
descript: expose pack_at
1 parent 5e29715 commit 50e6dea

6 files changed

Lines changed: 88 additions & 2 deletions

File tree

src/xtc/backends/tvm/TVMScheduler.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,11 @@ def get_loop_nest(self) -> LoopNest:
538538
# Build buffer_at mapping
539539
root_node.buffer_at = {axis: None for axis in self.write_caches}
540540

541+
# Build pack_at mapping
542+
root_node.pack_at = {
543+
axis: (input_idx, None, pad) for axis, input_idx, pad in self.read_buffers
544+
}
545+
541546
return loop_nest
542547

543548

src/xtc/schedules/descript.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,9 @@ def _apply_annotations(
213213
if annotations.buffer_specified:
214214
node.buffer_at[loop_name] = annotations.buffer
215215

216+
if annotations.pack_specified and annotations.pack is not None:
217+
node.pack_at[loop_name] = annotations.pack
218+
216219
def _check_splitting_intervals(
217220
self,
218221
item: SplitDecl,
@@ -305,6 +308,9 @@ def _apply_node(self, node: LoopNestNode) -> None:
305308
for axis, mtype in node.buffer_at.items():
306309
self.scheduler.buffer_at(axis, mtype=mtype, root=root)
307310

311+
for axis, (input_idx, mtype, pad) in node.pack_at.items():
312+
self.scheduler.pack_at(axis, input_idx, mtype=mtype, pad=pad, root=root)
313+
308314
# Recursively apply children
309315
for child in node.children:
310316
self._apply_node(child)

src/xtc/schedules/loop_nest.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ class LoopNestNode(Node["LoopNestNode"]):
9494
unroll: Maps loop names to their unroll factors.
9595
buffer_at: Buffer configuration per axis. Maps axis names to optional
9696
memory types (mtype). None means default memory type.
97+
pack_at: Pack configuration per axis. Maps axis names to tuples of
98+
(input_idx, mtype, pad). input_idx is the input buffer index,
99+
mtype is the memory type (None for default), pad enables padding.
97100
"""
98101

99102
root: str
@@ -104,6 +107,7 @@ class LoopNestNode(Node["LoopNestNode"]):
104107
parallelize: list[str] = field(default_factory=list)
105108
unroll: dict[str, int] = field(default_factory=dict)
106109
buffer_at: dict[str, str | None] = field(default_factory=dict)
110+
pack_at: dict[str, tuple[int, str | None, bool]] = field(default_factory=dict)
107111

108112
def pretty_print(self, indent: int = 0) -> str:
109113
"""Return a human-readable representation of the loop nest.
@@ -204,7 +208,7 @@ def pretty_print(self, indent: int = 0) -> str:
204208
return "\n".join(lines)
205209

206210
def _add_annotations(self, line: str, loop_name: str) -> str:
207-
"""Add annotations (parallelized, vectorized, unroll, buffer) to a loop line."""
211+
"""Add annotations (parallelized, vectorized, unroll, buffer, pack) to a loop line."""
208212
annotations: list[str] = []
209213
if loop_name in self.parallelize:
210214
annotations.append("parallelized")
@@ -218,6 +222,14 @@ def _add_annotations(self, line: str, loop_name: str) -> str:
218222
annotations.append(f"buffer({mtype})")
219223
else:
220224
annotations.append("buffer")
225+
if loop_name in self.pack_at:
226+
input_idx, mtype, pad = self.pack_at[loop_name]
227+
parts = [str(input_idx)]
228+
if mtype is not None:
229+
parts.append(mtype)
230+
if pad:
231+
parts.append("pad")
232+
annotations.append(f"pack({', '.join(parts)})")
221233
if annotations:
222234
line += " // " + ", ".join(annotations)
223235
return line

src/xtc/schedules/parsing.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ class Annotations:
2525
buffer: The memory type for the buffer. None means default memory type.
2626
Only meaningful when buffer_specified is True.
2727
buffer_specified: True if buffer was explicitly requested.
28+
pack: Pack configuration as (input_idx, mtype, pad). mtype is None for default.
29+
Only meaningful when pack_specified is True.
30+
pack_specified: True if pack was explicitly requested.
2831
"""
2932

3033
unroll_factor: int | None = None
@@ -33,6 +36,8 @@ class Annotations:
3336
parallelize: bool = False
3437
buffer: str | None = None
3538
buffer_specified: bool = False
39+
pack: tuple[int, str | None, bool] | None = None
40+
pack_specified: bool = False
3641

3742

3843
@dataclass(frozen=True)
@@ -152,6 +157,8 @@ def _parse_annotations(self, value: dict[str, Any], context: str) -> Annotations
152157
parallelize = False
153158
buffer: str | None = None
154159
buffer_specified = False
160+
pack: tuple[int, str | None, bool] | None = None
161+
pack_specified = False
155162

156163
for key, param in value.items():
157164
if key == "unroll":
@@ -186,6 +193,9 @@ def _parse_annotations(self, value: dict[str, Any], context: str) -> Annotations
186193
)
187194
buffer = None if param == "default" else param
188195
buffer_specified = True
196+
elif key == "pack":
197+
pack = self._parse_pack_param(param, context)
198+
pack_specified = True
189199
else:
190200
raise ScheduleParseError(f"Unknown annotation on {context}: {key}")
191201

@@ -196,8 +206,42 @@ def _parse_annotations(self, value: dict[str, Any], context: str) -> Annotations
196206
parallelize=parallelize,
197207
buffer=buffer,
198208
buffer_specified=buffer_specified,
209+
pack=pack,
210+
pack_specified=pack_specified,
199211
)
200212

213+
def _parse_pack_param(
214+
self, param: Any, context: str
215+
) -> tuple[int, str | None, bool]:
216+
"""Parse pack parameter into (input_idx, mtype, pad) tuple."""
217+
if not isinstance(param, (list, tuple)) or len(param) != 3:
218+
raise ScheduleParseError(
219+
f'`{{"pack" = {param}}}` on {context}: pack parameter should be a tuple (input_idx, mtype, pad).'
220+
)
221+
222+
input_idx, mtype, pad = param
223+
224+
if not isinstance(input_idx, int):
225+
raise ScheduleParseError(
226+
f'`{{"pack" = {param}}}` on {context}: input_idx should be an integer.'
227+
)
228+
229+
if mtype is not None and not isinstance(mtype, str):
230+
raise ScheduleParseError(
231+
f'`{{"pack" = {param}}}` on {context}: mtype should be a string or None.'
232+
)
233+
234+
if not isinstance(pad, bool):
235+
raise ScheduleParseError(
236+
f'`{{"pack" = {param}}}` on {context}: pad should be a boolean.'
237+
)
238+
239+
# Convert "default" to None for mtype
240+
if mtype == "default":
241+
mtype = None
242+
243+
return (input_idx, mtype, pad)
244+
201245
def _parse_split_syntax(
202246
self, declaration: str
203247
) -> tuple[str, int | None, int | None]:

tests/filecheck/schedules/test_descript_pretty_print.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# RUN: python %s --full 2>&1 | filecheck %s --check-prefix=CHECK-FULL
55
# RUN: python %s --split 2>&1 | filecheck %s --check-prefix=CHECK-SPLIT
66
# RUN: python %s --buffer 2>&1 | filecheck %s --check-prefix=CHECK-BUFFER
7+
# RUN: python %s --pack 2>&1 | filecheck %s --check-prefix=CHECK-PACK
78

89
import sys
910
from xtc.schedules.parsing import ScheduleParser
@@ -64,6 +65,17 @@
6465
loop_nest = interpreter.interpret(ast, root="C")
6566
print(loop_nest.root_node.pretty_print())
6667

68+
elif "--pack" in sys.argv:
69+
spec = {
70+
"i": {"parallelize": True},
71+
"k": {"pack": (0, "default", False)},
72+
"j": {"pack": (1, "shared", True)},
73+
"j#16": {"vectorize": True},
74+
}
75+
ast = parser.parse(spec)
76+
loop_nest = interpreter.interpret(ast, root="C")
77+
print(loop_nest.root_node.pretty_print())
78+
6779
# CHECK-SIMPLE: loop i
6880
# CHECK-SIMPLE-NEXT: loop k
6981
# CHECK-SIMPLE-NEXT: loop j
@@ -105,3 +117,9 @@
105117
# CHECK-BUFFER-NEXT: loop j // buffer(shared)
106118
# CHECK-BUFFER-NEXT: tile(j, 16) // vectorized
107119
# CHECK-BUFFER-NEXT: ...
120+
121+
# CHECK-PACK: loop i // parallelized
122+
# CHECK-PACK-NEXT: loop k // pack(0)
123+
# CHECK-PACK-NEXT: loop j // pack(1, shared, pad)
124+
# CHECK-PACK-NEXT: tile(j, 16) // vectorized
125+
# CHECK-PACK-NEXT: ...

tests/filecheck/schedules/test_get_descript.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
sch.vectorize(["J0"])
3434
if "--tvm" in sys.argv:
3535
sch.buffer_at("J")
36+
sch.pack_at("I", 0, pad=True)
3637

3738
loop_nest = sch.get_loop_nest()
3839
print(loop_nest.root_node.pretty_print())
@@ -45,7 +46,7 @@
4546
# CHECK-MLIR-NEXT: ...
4647

4748
# CHECK-TVM: loop K
48-
# CHECK-TVM-NEXT: loop I
49+
# CHECK-TVM-NEXT: loop I // pack(0, pad)
4950
# CHECK-TVM-NEXT: loop J // buffer
5051
# CHECK-TVM-NEXT: tile(I, 2) // unroll(2)
5152
# CHECK-TVM-NEXT: tile(J, 16) // vectorized

0 commit comments

Comments
 (0)