@@ -100,9 +100,6 @@ class ScheduleParser:
100100
101101 _SPLIT_PATTERN = re .compile (r"^(.*)\[(-\d+|\d*)?:(-\d+|\d*)?\]$" )
102102
103- def __init__ (self , abstract_axis : list [str ]):
104- self .abstract_axis = abstract_axis
105-
106103 def parse (self , spec : dict [str , Any ]) -> ScheduleSpec :
107104 """Parse a schedule specification dict into an AST."""
108105 items : list [ScheduleItem ] = []
@@ -224,31 +221,33 @@ def __init__(self, abstract_axis: list[str]):
224221
225222 def interpret (self , spec : ScheduleSpec , root : str ) -> LoopNest :
226223 """Interpret a schedule specification into a LoopNest."""
227- return self ._interpret_spec (spec , root , head = [])
228-
229- def _interpret_spec (
230- self , spec : ScheduleSpec , root : str , head : list [str ]
231- ) -> LoopNest :
232- """Interpret a schedule spec recursively."""
233224 loop_nest = LoopNest (abstract_dims = self .abstract_axis )
234- slice = loop_nest .build_slice (root )
225+ root_node = loop_nest .build_root_node (root )
226+ self ._interpret_spec_into_node (spec , root_node , root , head = [])
227+ return loop_nest
235228
229+ def _interpret_spec_into_node (
230+ self ,
231+ spec : ScheduleSpec ,
232+ node : LoopNestNode ,
233+ root : str ,
234+ head : list [str ],
235+ ) -> None :
236+ """Interpret a schedule spec into an existing node (mutates node)."""
236237 # Track state during interpretation
237238 sizes : dict [str , int ] = {}
238239 previous_cut : dict [str , int | None ] = {a : 0 for a in self .abstract_axis }
239240 interchange : list [str ] = list (head )
240241
241242 for item in spec .items :
242243 if isinstance (item , SplitDecl ):
243- self ._interpret_split (
244- item , slice , loop_nest , root , interchange , previous_cut
245- )
244+ self ._interpret_split (item , node , root , interchange , previous_cut )
246245 elif isinstance (item , TileDecl ):
247- loop_name = self ._interpret_tile (item , slice , interchange , sizes )
248- self ._apply_annotations (item .annotations , loop_name , sizes , slice )
246+ loop_name = self ._interpret_tile (item , node , interchange , sizes )
247+ self ._apply_annotations (item .annotations , loop_name , sizes , node )
249248 elif isinstance (item , AxisDecl ):
250249 loop_name = self ._interpret_axis (item , interchange )
251- self ._apply_annotations (item .annotations , loop_name , sizes , slice )
250+ self ._apply_annotations (item .annotations , loop_name , sizes , node )
252251
253252 # Check that all splits are complete
254253 for axis , cut in previous_cut .items ():
@@ -257,14 +256,12 @@ def _interpret_spec(
257256 f"Splitting of { axis } unachieved (stops at { cut } )."
258257 )
259258
260- slice .interchange = interchange
261- return loop_nest
259+ node .interchange = interchange
262260
263261 def _interpret_split (
264262 self ,
265263 item : SplitDecl ,
266- slice : LoopNestSlice ,
267- loop_nest : LoopNest ,
264+ node : LoopNestNode ,
268265 root : str ,
269266 interchange : list [str ],
270267 previous_cut : dict [str , int | None ],
@@ -279,7 +276,7 @@ def _interpret_split(
279276 # last one, so it cannot be the previous one.
280277 cut = previous_cut [axis_name ]
281278
282- # When x (the starting point of the slice ) is not specified,
279+ # When x (the starting point of the split ) is not specified,
283280 # it is the previous cut
284281 if x is None :
285282 x = cut
@@ -291,34 +288,44 @@ def _interpret_split(
291288 previous_cut [axis_name ] = y
292289
293290 # Save the cutting points of the new dimensions
294- if axis_name not in slice .splits :
295- slice .splits [axis_name ] = {}
296- new_dim_index = len (slice .splits [axis_name ])
291+ if axis_name not in node .splits :
292+ node .splits [axis_name ] = {}
293+ new_dim_index = len (node .splits [axis_name ])
297294 new_dim_name = f"{ axis_name } [{ new_dim_index } ]"
298295 new_root_name = f"{ root } /{ new_dim_name } "
299- slice .splits [axis_name ][new_dim_name ] = x
296+ node .splits [axis_name ][new_dim_name ] = x
300297 interchange .append (new_dim_name )
301298
302- # Recursively interpret the nested schedule
303- inner_nest = self ._interpret_spec (item .body , new_root_name , head = [axis_name ])
304- loop_nest .slices += inner_nest .slices
299+ # Create a child node for the nested schedule
300+ child_node = LoopNestNode (
301+ root = new_root_name , tiles = {a : {} for a in self .abstract_axis }
302+ )
303+
304+ # Create and attach the split child
305+ split_child = SplitChild (axis = axis_name , start = x , end = y , node = child_node )
306+ node .children .append (split_child )
307+
308+ # Recursively interpret the nested schedule into the child node
309+ self ._interpret_spec_into_node (
310+ item .body , child_node , new_root_name , head = [axis_name ]
311+ )
305312
306313 def _interpret_tile (
307314 self ,
308315 item : TileDecl ,
309- slice : LoopNestSlice ,
316+ node : LoopNestNode ,
310317 interchange : list [str ],
311318 sizes : dict [str , int ],
312319 ) -> str :
313320 """Interpret a tile declaration. Returns the loop name."""
314321 self ._check_axis_existence (item .axis )
315- tile_num = len (slice .tiles [item .axis ])
322+ tile_num = len (node .tiles [item .axis ])
316323 loop_name = f"{ item .axis } { tile_num } "
317324 if item .size <= 0 :
318325 raise ScheduleInterpretError (
319326 f"`{ item } `: tile sizes should be strictly positive."
320327 )
321- slice .tiles [item .axis ][loop_name ] = item .size
328+ node .tiles [item .axis ][loop_name ] = item .size
322329 sizes [loop_name ] = item .size
323330 interchange .append (loop_name )
324331
@@ -355,9 +362,9 @@ def _apply_annotations(
355362 annotations : Annotations ,
356363 loop_name : str ,
357364 sizes : dict [str , int ],
358- slice : LoopNestSlice ,
365+ node : LoopNestNode ,
359366 ) -> None :
360- """Apply annotations to a loop in the slice ."""
367+ """Apply annotations to a loop in the node ."""
361368 if annotations .unroll_specified :
362369 unroll_factor = annotations .unroll_factor
363370 if unroll_factor is None :
@@ -371,13 +378,13 @@ def _apply_annotations(
371378 raise ScheduleInterpretError (
372379 f'`{{"unroll" = { unroll_factor } }}`: unroll parameter should be strictly positive.'
373380 )
374- slice .unroll [loop_name ] = unroll_factor
381+ node .unroll [loop_name ] = unroll_factor
375382
376383 if annotations .vectorize :
377- slice .vectorize .append (loop_name )
384+ node .vectorize .append (loop_name )
378385
379386 if annotations .parallelize :
380- slice .parallelize .append (loop_name )
387+ node .parallelize .append (loop_name )
381388
382389 def _check_splitting_intervals (
383390 self ,
@@ -431,17 +438,17 @@ def loops_to_axis(self) -> dict[str, str]:
431438 return loops_to_axis
432439
433440 @staticmethod
434- def build_from_slices ( slices : list ["LoopNestSlice " ]) -> "LoopsDimsMapper" :
441+ def build_from_nodes ( nodes : list ["LoopNestNode " ]) -> "LoopsDimsMapper" :
435442 tiles_to_axis = {}
436443 splits_to_axis = {}
437444 dims = set ()
438- for slice in slices :
439- tiles_to_axis .update (LoopsDimsMapper ._get_subloops_to_axis (slice .tiles ))
440- splits_to_axis .update (LoopsDimsMapper ._get_subloops_to_axis (slice .splits ))
445+ for node in nodes :
446+ tiles_to_axis .update (LoopsDimsMapper ._get_subloops_to_axis (node .tiles ))
447+ splits_to_axis .update (LoopsDimsMapper ._get_subloops_to_axis (node .splits ))
441448 refined_loops = list (tiles_to_axis ) + list (splits_to_axis )
442- for slice in slices :
449+ for node in nodes :
443450 dims .update (
444- [loop for loop in slice .interchange if loop not in refined_loops ]
451+ [loop for loop in node .interchange if loop not in refined_loops ]
445452 )
446453 dims .update (tiles_to_axis .values ())
447454 dims .update (splits_to_axis .values ())
@@ -457,14 +464,32 @@ def _get_subloops_to_axis(subloops: dict[str, dict[str, Any]]) -> dict[str, str]
457464
458465
459466@dataclass
460- class LoopNestSlice :
461- """Represents a single slice of a loop nest with its transformations .
467+ class SplitChild :
468+ """Represents a child node in the loop nest tree created by a split .
462469
463- A slice describes the loops attached to a single root and
470+ Attributes:
471+ axis: The axis that was split to create this child.
472+ start: The starting position of the split (inclusive), or None if unbounded.
473+ end: The ending position of the split (exclusive), or None if unbounded.
474+ node: The child LoopNestNode containing the nested schedule.
475+ """
476+
477+ axis : str
478+ start : int | None
479+ end : int | None
480+ node : "LoopNestNode"
481+
482+
483+ @dataclass
484+ class LoopNestNode :
485+ """Represents a node in the loop nest tree with its transformations.
486+
487+ A node describes the loops attached to a single root and
464488 contains all the scheduling transformations applied to these loops.
489+ Splits create child nodes, forming an explicit tree structure.
465490
466491 Attributes:
467- root: Identifier of the the slice (either the base operation or
492+ root: Identifier of the node (either the base operation or
468493 the content of a split).
469494 tiles: Tiling configuration per axis. Maps axis names to dicts of
470495 tile loop names and their sizes.
@@ -474,6 +499,7 @@ class LoopNestSlice:
474499 vectorize: List of loops to vectorize.
475500 parallelize: List of loops to parallelize.
476501 unroll: Maps loop names to their unroll factors.
502+ children: List of SplitChild objects representing child nodes from splits.
477503 """
478504
479505 root : str
@@ -483,6 +509,7 @@ class LoopNestSlice:
483509 vectorize : list [str ] = field (default_factory = list )
484510 parallelize : list [str ] = field (default_factory = list )
485511 unroll : dict [str , int ] = field (default_factory = dict )
512+ children : list [SplitChild ] = field (default_factory = list )
486513
487514 @property
488515 def splits_to_sizes (self ) -> dict [str , int ]:
@@ -509,25 +536,44 @@ def tiles_to_sizes(self) -> dict[str, int]:
509536class LoopNest :
510537 """Represents a complete loop nest structure for scheduling.
511538
512- A loop nest contains abstract dimensions and a collection of slices/
513- It provides validation to ensure consistency across all slices .
539+ A loop nest contains abstract dimensions and a tree of nodes representing
540+ the schedule. Splits create child nodes, forming an explicit tree structure .
514541
515542 Attributes:
516543 abstract_dims: List of abstract dimension names for the loop nest.
517- slices: List of LoopNestSlice objects, one per scheduled operation .
544+ root_node: The root node of the loop nest tree, or None if empty .
518545 """
519546
520547 abstract_dims : list [str ]
521- slices : list [ LoopNestSlice ] = field ( default_factory = list )
548+ root_node : LoopNestNode | None = None
522549
523550 @property
524- def empty (self ):
525- return not self .slices
551+ def empty (self ) -> bool :
552+ return self .root_node is None
526553
527- def build_slice (self , root : str ) -> LoopNestSlice :
528- slice = LoopNestSlice (root = root , tiles = {a : {} for a in self .abstract_dims })
529- self .slices = [slice ] + self .slices
530- return slice
554+ @property
555+ def nodes (self ) -> list [LoopNestNode ]:
556+ """Flatten the tree into a list of nodes.
557+
558+ Returns nodes in depth-first order, with the root node first,
559+ followed by children in the order they were created.
560+ """
561+ if self .root_node is None :
562+ return []
563+ return self ._collect_nodes_dfs (self .root_node )
564+
565+ def _collect_nodes_dfs (self , node : LoopNestNode ) -> list [LoopNestNode ]:
566+ """Collect all nodes in the tree in depth-first order."""
567+ result = [node ]
568+ for child in node .children :
569+ result .extend (self ._collect_nodes_dfs (child .node ))
570+ return result
571+
572+ def build_root_node (self , root : str ) -> LoopNestNode :
573+ """Build and set the root node of the loop nest tree."""
574+ node = LoopNestNode (root = root , tiles = {a : {} for a in self .abstract_dims })
575+ self .root_node = node
576+ return node
531577
532578 def check (self ):
533579 self ._check_use_defined_dims ()
@@ -536,13 +582,13 @@ def check(self):
536582 self ._check_sizes ()
537583
538584 def _check_use_defined_dims (self ):
539- mapper = LoopsDimsMapper .build_from_slices (self .slices )
585+ mapper = LoopsDimsMapper .build_from_nodes (self .nodes )
540586 for dim in self .abstract_dims :
541587 if dim not in mapper .dims :
542588 raise ScheduleValidationError (f"{ dim } defined but never used" )
543589
544590 def _check_vectorization_consistency (self ):
545- for sched in self .slices :
591+ for sched in self .nodes :
546592 vect_above = False
547593 for loop_name in sched .interchange :
548594 if loop_name in sched .vectorize :
@@ -553,9 +599,9 @@ def _check_vectorization_consistency(self):
553599 )
554600
555601 def _check_tiling_consistency (self ) -> None :
556- mapper = LoopsDimsMapper .build_from_slices (self .slices )
602+ mapper = LoopsDimsMapper .build_from_nodes (self .nodes )
557603 seen_axes : dict [str , int | None ] = {}
558- for sched in self .slices :
604+ for sched in self .nodes :
559605 for loop_name in sched .interchange :
560606 if loop_name in mapper .dims :
561607 seen_axes [loop_name ] = None
@@ -571,9 +617,9 @@ def _check_tiling_consistency(self) -> None:
571617 seen_axes [axis ] = sched .tiles [axis ][loop_name ]
572618
573619 def _check_sizes (self ):
574- mapper = LoopsDimsMapper .build_from_slices (self .slices )
620+ mapper = LoopsDimsMapper .build_from_nodes (self .nodes )
575621 current_size_of_split : dict [str , int | None ] = {}
576- for sched in self .slices :
622+ for sched in self .nodes :
577623 current_size_of_tile : dict [str , int ] = {}
578624
579625 for loop_name in sched .interchange :
@@ -677,7 +723,7 @@ def apply(self, node_name: str, spec: dict[str, dict[str, Any]]) -> None:
677723 ScheduleValidationError: If the resulting schedule is invalid.
678724 """
679725 # Parse the specification into an AST
680- parser = ScheduleParser (self . abstract_axis )
726+ parser = ScheduleParser ()
681727 ast = parser .parse (spec )
682728
683729 # Interpret the AST into a LoopNest
@@ -694,16 +740,24 @@ def _apply_loop_nest(self, loop_nest: LoopNest) -> None:
694740 """Apply a LoopNest to the scheduler."""
695741 self .scheduler .set_dims (self .abstract_axis )
696742
697- for slice in loop_nest .slices :
698- root = slice .root
743+ if loop_nest .root_node is not None :
744+ self ._apply_node (loop_nest .root_node )
745+
746+ def _apply_node (self , node : LoopNestNode ) -> None :
747+ """Recursively apply a LoopNestNode and its children to the scheduler."""
748+ root = node .root
749+
750+ for d , s in node .splits .items ():
751+ self .scheduler .split (d , s , root = root )
699752
700- for d , s in slice . splits .items ():
701- self .scheduler .split (d , s , root = root )
753+ for d , s in node . tiles .items ():
754+ self .scheduler .tile (d , s , root = root )
702755
703- for d , s in slice .tiles .items ():
704- self .scheduler .tile (d , s , root = root )
756+ self .scheduler .interchange (node .interchange , root = root )
757+ self .scheduler .vectorize (node .vectorize , root = root )
758+ self .scheduler .parallelize (node .parallelize , root = root )
759+ self .scheduler .unroll (node .unroll , root = root )
705760
706- self .scheduler .interchange (slice .interchange , root = root )
707- self .scheduler .vectorize (slice .vectorize , root = root )
708- self .scheduler .parallelize (slice .parallelize , root = root )
709- self .scheduler .unroll (slice .unroll , root = root )
761+ # Recursively apply children
762+ for child in node .children :
763+ self ._apply_node (child .node )
0 commit comments