Skip to content

Commit b0ce0b1

Browse files
committed
descript: pretty-printer
1 parent c66e5e5 commit b0ce0b1

2 files changed

Lines changed: 160 additions & 0 deletions

File tree

src/xtc/schedules/loop_nest.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,83 @@ def tiles_to_sizes(self) -> dict[str, int]:
126126
tiles_to_sizes[loop] = size
127127
return tiles_to_sizes
128128

129+
def pretty_print(self, indent: int = 0) -> str:
130+
"""Return a human-readable representation of the loop nest.
131+
132+
The output format resembles pseudocode:
133+
for i in ...:
134+
for j in ...:
135+
for j0 in tile(j, 16): // vectorized
136+
137+
Args:
138+
indent: The initial indentation level (number of spaces).
139+
140+
Returns:
141+
A multi-line string representing the loop nest structure.
142+
"""
143+
lines: list[str] = []
144+
145+
# Build mapping from tile loop name to (axis, size)
146+
tiles_info: dict[str, tuple[str, int]] = {}
147+
for axis, tile_loops in self.tiles.items():
148+
for loop_name, size in tile_loops.items():
149+
tiles_info[loop_name] = (axis, size)
150+
151+
# Map split loop names to their child nodes (which contain split_origin)
152+
split_to_child: dict[str, LoopNestNode] = {}
153+
for child in self.children:
154+
if child.split_origin is not None:
155+
axis = child.split_origin.axis
156+
if axis in self.splits:
157+
for loop_name, start in self.splits[axis].items():
158+
if start == child.split_origin.start:
159+
split_to_child[loop_name] = child
160+
break
161+
162+
current_indent = indent
163+
for loop_name in self.interchange:
164+
# Build the loop line based on loop type
165+
if loop_name in tiles_info:
166+
axis, size = tiles_info[loop_name]
167+
line = f"for {loop_name} in tile({axis}, {size}):"
168+
elif loop_name in split_to_child:
169+
child = split_to_child[loop_name]
170+
origin = child.split_origin
171+
assert origin is not None
172+
start = origin.start if origin.start is not None else "..."
173+
end = origin.end if origin.end is not None else "..."
174+
line = f"for {loop_name} in split({origin.axis}, {start}, {end}):"
175+
else:
176+
line = f"for {loop_name} in ...:"
177+
178+
# Add annotations as comments
179+
annotations: list[str] = []
180+
if loop_name in self.parallelize:
181+
annotations.append("parallelized")
182+
if loop_name in self.vectorize:
183+
annotations.append("vectorized")
184+
if loop_name in self.unroll:
185+
factor = self.unroll[loop_name]
186+
annotations.append(f"unroll({factor})")
187+
188+
if annotations:
189+
line += " // " + ", ".join(annotations)
190+
191+
lines.append(" " * current_indent + line)
192+
193+
# If this is a split loop with a child, recurse into the child
194+
# Split loops don't increase indent for subsequent siblings
195+
if loop_name in split_to_child:
196+
child_output = split_to_child[loop_name].pretty_print(
197+
current_indent + 2
198+
)
199+
lines.append(child_output)
200+
else:
201+
# Regular loops increase nesting depth
202+
current_indent += 2
203+
204+
return "\n".join(lines)
205+
129206

130207
@dataclass
131208
class LoopsDimsMapper:
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# RUN: python %s --simple 2>&1 | filecheck %s --check-prefix=CHECK-SIMPLE
2+
# RUN: python %s --tiled 2>&1 | filecheck %s --check-prefix=CHECK-TILED
3+
# RUN: python %s --vectorized 2>&1 | filecheck %s --check-prefix=CHECK-VECTORIZED
4+
# RUN: python %s --full 2>&1 | filecheck %s --check-prefix=CHECK-FULL
5+
# RUN: python %s --split 2>&1 | filecheck %s --check-prefix=CHECK-SPLIT
6+
7+
import sys
8+
from xtc.schedules.parsing import ScheduleParser
9+
from xtc.schedules.descript import ScheduleInterpreter
10+
11+
parser = ScheduleParser()
12+
abstract_axis = ["i", "j", "k"]
13+
interpreter = ScheduleInterpreter(abstract_axis)
14+
15+
if "--simple" in sys.argv:
16+
spec = {"i": {}, "k": {}, "j": {}}
17+
ast = parser.parse(spec)
18+
loop_nest = interpreter.interpret(ast, root="C")
19+
print(loop_nest.root_node.pretty_print())
20+
21+
elif "--tiled" in sys.argv:
22+
spec = {"i": {}, "k": {}, "j": {}, "j#16": {}}
23+
ast = parser.parse(spec)
24+
loop_nest = interpreter.interpret(ast, root="C")
25+
print(loop_nest.root_node.pretty_print())
26+
27+
elif "--vectorized" in sys.argv:
28+
spec = {"i": {}, "k": {}, "j": {}, "j#16": {"vectorize": True}}
29+
ast = parser.parse(spec)
30+
loop_nest = interpreter.interpret(ast, root="C")
31+
print(loop_nest.root_node.pretty_print())
32+
33+
elif "--full" in sys.argv:
34+
spec = {
35+
"i": {"parallelize": True},
36+
"k": {},
37+
"j": {},
38+
"j#32": {},
39+
"j#16": {"vectorize": True, "unroll": 4},
40+
}
41+
ast = parser.parse(spec)
42+
loop_nest = interpreter.interpret(ast, root="C")
43+
print(loop_nest.root_node.pretty_print())
44+
45+
elif "--split" in sys.argv:
46+
spec = {
47+
"i": {},
48+
"j[:128]": {"k": {}, "k#32": {}},
49+
"j[128:]": {"k": {}, "k#16": {"vectorize": True}},
50+
}
51+
ast = parser.parse(spec)
52+
loop_nest = interpreter.interpret(ast, root="C")
53+
print(loop_nest.root_node.pretty_print())
54+
55+
# CHECK-SIMPLE: for i in ...:
56+
# CHECK-SIMPLE-NEXT: for k in ...:
57+
# CHECK-SIMPLE-NEXT: for j in ...:
58+
59+
# CHECK-TILED: for i in ...:
60+
# CHECK-TILED-NEXT: for k in ...:
61+
# CHECK-TILED-NEXT: for j in ...:
62+
# CHECK-TILED-NEXT: for j0 in tile(j, 16):
63+
64+
# CHECK-VECTORIZED: for i in ...:
65+
# CHECK-VECTORIZED-NEXT: for k in ...:
66+
# CHECK-VECTORIZED-NEXT: for j in ...:
67+
# CHECK-VECTORIZED-NEXT: for j0 in tile(j, 16): // vectorized
68+
69+
# CHECK-FULL: for i in ...: // parallelized
70+
# CHECK-FULL-NEXT: for k in ...:
71+
# CHECK-FULL-NEXT: for j in ...:
72+
# CHECK-FULL-NEXT: for j0 in tile(j, 32):
73+
# CHECK-FULL-NEXT: for j1 in tile(j, 16): // vectorized, unroll(4)
74+
75+
# CHECK-SPLIT: for i in ...:
76+
# CHECK-SPLIT-NEXT: for j[0] in split(j, 0, 128):
77+
# CHECK-SPLIT-NEXT: for j in ...:
78+
# CHECK-SPLIT-NEXT: for k in ...:
79+
# CHECK-SPLIT-NEXT: for k0 in tile(k, 32):
80+
# CHECK-SPLIT-NEXT: for j[1] in split(j, 128, ...):
81+
# CHECK-SPLIT-NEXT: for j in ...:
82+
# CHECK-SPLIT-NEXT: for k in ...:
83+
# CHECK-SPLIT-NEXT: for k0 in tile(k, 16): // vectorized

0 commit comments

Comments
 (0)