Skip to content

Commit 48fd73c

Browse files
committed
descript: use a tree structure to represent a loop nest
1 parent a0f495f commit 48fd73c

1 file changed

Lines changed: 127 additions & 73 deletions

File tree

src/xtc/schedules/descript.py

Lines changed: 127 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,6 @@ class ScheduleParser:
100100

101101
_SPLIT_PATTERN = re.compile(r"^(.*)\[(-\d+|\d*)?:(-\d+|\d*)?\]$")
102102

103-
def __init__(self, abstract_axis: list[str]):
104-
self.abstract_axis = abstract_axis
105-
106103
def parse(self, spec: dict[str, Any]) -> ScheduleSpec:
107104
"""Parse a schedule specification dict into an AST."""
108105
items: list[ScheduleItem] = []
@@ -224,31 +221,33 @@ def __init__(self, abstract_axis: list[str]):
224221

225222
def interpret(self, spec: ScheduleSpec, root: str) -> LoopNest:
226223
"""Interpret a schedule specification into a LoopNest."""
227-
return self._interpret_spec(spec, root, head=[])
228-
229-
def _interpret_spec(
230-
self, spec: ScheduleSpec, root: str, head: list[str]
231-
) -> LoopNest:
232-
"""Interpret a schedule spec recursively."""
233224
loop_nest = LoopNest(abstract_dims=self.abstract_axis)
234-
slice = loop_nest.build_slice(root)
225+
root_node = loop_nest.build_root_node(root)
226+
self._interpret_spec_into_node(spec, root_node, root, head=[])
227+
return loop_nest
235228

229+
def _interpret_spec_into_node(
230+
self,
231+
spec: ScheduleSpec,
232+
node: LoopNestNode,
233+
root: str,
234+
head: list[str],
235+
) -> None:
236+
"""Interpret a schedule spec into an existing node (mutates node)."""
236237
# Track state during interpretation
237238
sizes: dict[str, int] = {}
238239
previous_cut: dict[str, int | None] = {a: 0 for a in self.abstract_axis}
239240
interchange: list[str] = list(head)
240241

241242
for item in spec.items:
242243
if isinstance(item, SplitDecl):
243-
self._interpret_split(
244-
item, slice, loop_nest, root, interchange, previous_cut
245-
)
244+
self._interpret_split(item, node, root, interchange, previous_cut)
246245
elif isinstance(item, TileDecl):
247-
loop_name = self._interpret_tile(item, slice, interchange, sizes)
248-
self._apply_annotations(item.annotations, loop_name, sizes, slice)
246+
loop_name = self._interpret_tile(item, node, interchange, sizes)
247+
self._apply_annotations(item.annotations, loop_name, sizes, node)
249248
elif isinstance(item, AxisDecl):
250249
loop_name = self._interpret_axis(item, interchange)
251-
self._apply_annotations(item.annotations, loop_name, sizes, slice)
250+
self._apply_annotations(item.annotations, loop_name, sizes, node)
252251

253252
# Check that all splits are complete
254253
for axis, cut in previous_cut.items():
@@ -257,14 +256,12 @@ def _interpret_spec(
257256
f"Splitting of {axis} unachieved (stops at {cut})."
258257
)
259258

260-
slice.interchange = interchange
261-
return loop_nest
259+
node.interchange = interchange
262260

263261
def _interpret_split(
264262
self,
265263
item: SplitDecl,
266-
slice: LoopNestSlice,
267-
loop_nest: LoopNest,
264+
node: LoopNestNode,
268265
root: str,
269266
interchange: list[str],
270267
previous_cut: dict[str, int | None],
@@ -279,7 +276,7 @@ def _interpret_split(
279276
# last one, so it cannot be the previous one.
280277
cut = previous_cut[axis_name]
281278

282-
# When x (the starting point of the slice) is not specified,
279+
# When x (the starting point of the split) is not specified,
283280
# it is the previous cut
284281
if x is None:
285282
x = cut
@@ -291,34 +288,44 @@ def _interpret_split(
291288
previous_cut[axis_name] = y
292289

293290
# Save the cutting points of the new dimensions
294-
if axis_name not in slice.splits:
295-
slice.splits[axis_name] = {}
296-
new_dim_index = len(slice.splits[axis_name])
291+
if axis_name not in node.splits:
292+
node.splits[axis_name] = {}
293+
new_dim_index = len(node.splits[axis_name])
297294
new_dim_name = f"{axis_name}[{new_dim_index}]"
298295
new_root_name = f"{root}/{new_dim_name}"
299-
slice.splits[axis_name][new_dim_name] = x
296+
node.splits[axis_name][new_dim_name] = x
300297
interchange.append(new_dim_name)
301298

302-
# Recursively interpret the nested schedule
303-
inner_nest = self._interpret_spec(item.body, new_root_name, head=[axis_name])
304-
loop_nest.slices += inner_nest.slices
299+
# Create a child node for the nested schedule
300+
child_node = LoopNestNode(
301+
root=new_root_name, tiles={a: {} for a in self.abstract_axis}
302+
)
303+
304+
# Create and attach the split child
305+
split_child = SplitChild(axis=axis_name, start=x, end=y, node=child_node)
306+
node.children.append(split_child)
307+
308+
# Recursively interpret the nested schedule into the child node
309+
self._interpret_spec_into_node(
310+
item.body, child_node, new_root_name, head=[axis_name]
311+
)
305312

306313
def _interpret_tile(
307314
self,
308315
item: TileDecl,
309-
slice: LoopNestSlice,
316+
node: LoopNestNode,
310317
interchange: list[str],
311318
sizes: dict[str, int],
312319
) -> str:
313320
"""Interpret a tile declaration. Returns the loop name."""
314321
self._check_axis_existence(item.axis)
315-
tile_num = len(slice.tiles[item.axis])
322+
tile_num = len(node.tiles[item.axis])
316323
loop_name = f"{item.axis}{tile_num}"
317324
if item.size <= 0:
318325
raise ScheduleInterpretError(
319326
f"`{item}`: tile sizes should be strictly positive."
320327
)
321-
slice.tiles[item.axis][loop_name] = item.size
328+
node.tiles[item.axis][loop_name] = item.size
322329
sizes[loop_name] = item.size
323330
interchange.append(loop_name)
324331

@@ -355,9 +362,9 @@ def _apply_annotations(
355362
annotations: Annotations,
356363
loop_name: str,
357364
sizes: dict[str, int],
358-
slice: LoopNestSlice,
365+
node: LoopNestNode,
359366
) -> None:
360-
"""Apply annotations to a loop in the slice."""
367+
"""Apply annotations to a loop in the node."""
361368
if annotations.unroll_specified:
362369
unroll_factor = annotations.unroll_factor
363370
if unroll_factor is None:
@@ -371,13 +378,13 @@ def _apply_annotations(
371378
raise ScheduleInterpretError(
372379
f'`{{"unroll" = {unroll_factor}}}`: unroll parameter should be strictly positive.'
373380
)
374-
slice.unroll[loop_name] = unroll_factor
381+
node.unroll[loop_name] = unroll_factor
375382

376383
if annotations.vectorize:
377-
slice.vectorize.append(loop_name)
384+
node.vectorize.append(loop_name)
378385

379386
if annotations.parallelize:
380-
slice.parallelize.append(loop_name)
387+
node.parallelize.append(loop_name)
381388

382389
def _check_splitting_intervals(
383390
self,
@@ -431,17 +438,17 @@ def loops_to_axis(self) -> dict[str, str]:
431438
return loops_to_axis
432439

433440
@staticmethod
434-
def build_from_slices(slices: list["LoopNestSlice"]) -> "LoopsDimsMapper":
441+
def build_from_nodes(nodes: list["LoopNestNode"]) -> "LoopsDimsMapper":
435442
tiles_to_axis = {}
436443
splits_to_axis = {}
437444
dims = set()
438-
for slice in slices:
439-
tiles_to_axis.update(LoopsDimsMapper._get_subloops_to_axis(slice.tiles))
440-
splits_to_axis.update(LoopsDimsMapper._get_subloops_to_axis(slice.splits))
445+
for node in nodes:
446+
tiles_to_axis.update(LoopsDimsMapper._get_subloops_to_axis(node.tiles))
447+
splits_to_axis.update(LoopsDimsMapper._get_subloops_to_axis(node.splits))
441448
refined_loops = list(tiles_to_axis) + list(splits_to_axis)
442-
for slice in slices:
449+
for node in nodes:
443450
dims.update(
444-
[loop for loop in slice.interchange if loop not in refined_loops]
451+
[loop for loop in node.interchange if loop not in refined_loops]
445452
)
446453
dims.update(tiles_to_axis.values())
447454
dims.update(splits_to_axis.values())
@@ -457,14 +464,32 @@ def _get_subloops_to_axis(subloops: dict[str, dict[str, Any]]) -> dict[str, str]
457464

458465

459466
@dataclass
460-
class LoopNestSlice:
461-
"""Represents a single slice of a loop nest with its transformations.
467+
class SplitChild:
468+
"""Represents a child node in the loop nest tree created by a split.
462469
463-
A slice describes the loops attached to a single root and
470+
Attributes:
471+
axis: The axis that was split to create this child.
472+
start: The starting position of the split (inclusive), or None if unbounded.
473+
end: The ending position of the split (exclusive), or None if unbounded.
474+
node: The child LoopNestNode containing the nested schedule.
475+
"""
476+
477+
axis: str
478+
start: int | None
479+
end: int | None
480+
node: "LoopNestNode"
481+
482+
483+
@dataclass
484+
class LoopNestNode:
485+
"""Represents a node in the loop nest tree with its transformations.
486+
487+
A node describes the loops attached to a single root and
464488
contains all the scheduling transformations applied to these loops.
489+
Splits create child nodes, forming an explicit tree structure.
465490
466491
Attributes:
467-
root: Identifier of the the slice (either the base operation or
492+
root: Identifier of the node (either the base operation or
468493
the content of a split).
469494
tiles: Tiling configuration per axis. Maps axis names to dicts of
470495
tile loop names and their sizes.
@@ -474,6 +499,7 @@ class LoopNestSlice:
474499
vectorize: List of loops to vectorize.
475500
parallelize: List of loops to parallelize.
476501
unroll: Maps loop names to their unroll factors.
502+
children: List of SplitChild objects representing child nodes from splits.
477503
"""
478504

479505
root: str
@@ -483,6 +509,7 @@ class LoopNestSlice:
483509
vectorize: list[str] = field(default_factory=list)
484510
parallelize: list[str] = field(default_factory=list)
485511
unroll: dict[str, int] = field(default_factory=dict)
512+
children: list[SplitChild] = field(default_factory=list)
486513

487514
@property
488515
def splits_to_sizes(self) -> dict[str, int]:
@@ -509,25 +536,44 @@ def tiles_to_sizes(self) -> dict[str, int]:
509536
class LoopNest:
510537
"""Represents a complete loop nest structure for scheduling.
511538
512-
A loop nest contains abstract dimensions and a collection of slices/
513-
It provides validation to ensure consistency across all slices.
539+
A loop nest contains abstract dimensions and a tree of nodes representing
540+
the schedule. Splits create child nodes, forming an explicit tree structure.
514541
515542
Attributes:
516543
abstract_dims: List of abstract dimension names for the loop nest.
517-
slices: List of LoopNestSlice objects, one per scheduled operation.
544+
root_node: The root node of the loop nest tree, or None if empty.
518545
"""
519546

520547
abstract_dims: list[str]
521-
slices: list[LoopNestSlice] = field(default_factory=list)
548+
root_node: LoopNestNode | None = None
522549

523550
@property
524-
def empty(self):
525-
return not self.slices
551+
def empty(self) -> bool:
552+
return self.root_node is None
526553

527-
def build_slice(self, root: str) -> LoopNestSlice:
528-
slice = LoopNestSlice(root=root, tiles={a: {} for a in self.abstract_dims})
529-
self.slices = [slice] + self.slices
530-
return slice
554+
@property
555+
def nodes(self) -> list[LoopNestNode]:
556+
"""Flatten the tree into a list of nodes.
557+
558+
Returns nodes in depth-first order, with the root node first,
559+
followed by children in the order they were created.
560+
"""
561+
if self.root_node is None:
562+
return []
563+
return self._collect_nodes_dfs(self.root_node)
564+
565+
def _collect_nodes_dfs(self, node: LoopNestNode) -> list[LoopNestNode]:
566+
"""Collect all nodes in the tree in depth-first order."""
567+
result = [node]
568+
for child in node.children:
569+
result.extend(self._collect_nodes_dfs(child.node))
570+
return result
571+
572+
def build_root_node(self, root: str) -> LoopNestNode:
573+
"""Build and set the root node of the loop nest tree."""
574+
node = LoopNestNode(root=root, tiles={a: {} for a in self.abstract_dims})
575+
self.root_node = node
576+
return node
531577

532578
def check(self):
533579
self._check_use_defined_dims()
@@ -536,13 +582,13 @@ def check(self):
536582
self._check_sizes()
537583

538584
def _check_use_defined_dims(self):
539-
mapper = LoopsDimsMapper.build_from_slices(self.slices)
585+
mapper = LoopsDimsMapper.build_from_nodes(self.nodes)
540586
for dim in self.abstract_dims:
541587
if dim not in mapper.dims:
542588
raise ScheduleValidationError(f"{dim} defined but never used")
543589

544590
def _check_vectorization_consistency(self):
545-
for sched in self.slices:
591+
for sched in self.nodes:
546592
vect_above = False
547593
for loop_name in sched.interchange:
548594
if loop_name in sched.vectorize:
@@ -553,9 +599,9 @@ def _check_vectorization_consistency(self):
553599
)
554600

555601
def _check_tiling_consistency(self) -> None:
556-
mapper = LoopsDimsMapper.build_from_slices(self.slices)
602+
mapper = LoopsDimsMapper.build_from_nodes(self.nodes)
557603
seen_axes: dict[str, int | None] = {}
558-
for sched in self.slices:
604+
for sched in self.nodes:
559605
for loop_name in sched.interchange:
560606
if loop_name in mapper.dims:
561607
seen_axes[loop_name] = None
@@ -571,9 +617,9 @@ def _check_tiling_consistency(self) -> None:
571617
seen_axes[axis] = sched.tiles[axis][loop_name]
572618

573619
def _check_sizes(self):
574-
mapper = LoopsDimsMapper.build_from_slices(self.slices)
620+
mapper = LoopsDimsMapper.build_from_nodes(self.nodes)
575621
current_size_of_split: dict[str, int | None] = {}
576-
for sched in self.slices:
622+
for sched in self.nodes:
577623
current_size_of_tile: dict[str, int] = {}
578624

579625
for loop_name in sched.interchange:
@@ -677,7 +723,7 @@ def apply(self, node_name: str, spec: dict[str, dict[str, Any]]) -> None:
677723
ScheduleValidationError: If the resulting schedule is invalid.
678724
"""
679725
# Parse the specification into an AST
680-
parser = ScheduleParser(self.abstract_axis)
726+
parser = ScheduleParser()
681727
ast = parser.parse(spec)
682728

683729
# Interpret the AST into a LoopNest
@@ -694,16 +740,24 @@ def _apply_loop_nest(self, loop_nest: LoopNest) -> None:
694740
"""Apply a LoopNest to the scheduler."""
695741
self.scheduler.set_dims(self.abstract_axis)
696742

697-
for slice in loop_nest.slices:
698-
root = slice.root
743+
if loop_nest.root_node is not None:
744+
self._apply_node(loop_nest.root_node)
745+
746+
def _apply_node(self, node: LoopNestNode) -> None:
747+
"""Recursively apply a LoopNestNode and its children to the scheduler."""
748+
root = node.root
749+
750+
for d, s in node.splits.items():
751+
self.scheduler.split(d, s, root=root)
699752

700-
for d, s in slice.splits.items():
701-
self.scheduler.split(d, s, root=root)
753+
for d, s in node.tiles.items():
754+
self.scheduler.tile(d, s, root=root)
702755

703-
for d, s in slice.tiles.items():
704-
self.scheduler.tile(d, s, root=root)
756+
self.scheduler.interchange(node.interchange, root=root)
757+
self.scheduler.vectorize(node.vectorize, root=root)
758+
self.scheduler.parallelize(node.parallelize, root=root)
759+
self.scheduler.unroll(node.unroll, root=root)
705760

706-
self.scheduler.interchange(slice.interchange, root=root)
707-
self.scheduler.vectorize(slice.vectorize, root=root)
708-
self.scheduler.parallelize(slice.parallelize, root=root)
709-
self.scheduler.unroll(slice.unroll, root=root)
761+
# Recursively apply children
762+
for child in node.children:
763+
self._apply_node(child.node)

0 commit comments

Comments
 (0)