@@ -41,23 +41,28 @@ class Annotations:
4141 parallelize: True if parallelization was requested.
4242 """
4343
44- unroll_factor : int | None = None
44+ unroll_factor : int | str | None = None
4545 unroll_specified : bool = False
46- vectorize : bool = False
47- parallelize : bool = False
46+ vectorize : bool | str = False
47+ parallelize : bool | str = False
48+ partial : bool = False
49+ full : bool = False
4850
4951
5052@dataclass (frozen = True )
5153class SplitDecl :
5254 """AST Type: a split declaration like 'axis[start:end]'."""
5355
5456 axis : str
55- start : int | None
56- end : int | None
57+ start : int | str | None
58+ end : int | str | None
5759 body : ScheduleSpec
60+ size : int | str | None = None
5861
5962 @override
6063 def __str__ (self ) -> str :
64+ if self .size is not None :
65+ return f"{ self .axis } [:{ self .size } :]"
6166 start_str = "" if self .start is None else str (self .start )
6267 end_str = "" if self .end is None else str (self .end )
6368 decl = f"{ self .axis } [{ start_str } :{ end_str } ]"
@@ -69,7 +74,7 @@ class TileDecl:
6974 """AST Type: a tile declaration like 'axis#size'."""
7075
7176 axis : str
72- size : int
77+ size : int | str
7378 annotations : Annotations
7479
7580 @override
@@ -85,7 +90,36 @@ class AxisDecl:
8590 annotations : Annotations
8691
8792
88- ScheduleItem = SplitDecl | TileDecl | AxisDecl
93+ @dataclass (frozen = True )
94+ class FusionDecl :
95+ """AST Type: a fusion declaration"""
96+
97+
98+ @dataclass (frozen = True )
99+ class PackDecl :
100+ """AST Type: a packing declaration"""
101+
102+ param : str | bool
103+ input : str
104+ pad : str | bool
105+
106+
107+ @dataclass (frozen = True )
108+ class BufferDecl :
109+ """AST Type: a bufferisation declaration"""
110+
111+ param : str | bool
112+ pad : str
113+
114+
115+ @dataclass (frozen = True )
116+ class ExploreDecl :
117+ level : str
118+
119+
120+ ScheduleItem = (
121+ SplitDecl | TileDecl | AxisDecl | FusionDecl | PackDecl | BufferDecl | ExploreDecl
122+ )
89123
90124
91125@dataclass (frozen = True )
@@ -144,10 +178,12 @@ def _parse_tile(self, declaration: str, value: dict) -> TileDecl:
144178
145179 axis_name , size_str = parts
146180
147- try :
148- size = int (size_str )
149- except ValueError :
150- raise ScheduleParseError (f"`{ declaration } `: { size_str } is not an integer." )
181+ size = int (size_str ) if size_str .isnumeric () else size_str
182+
183+ # try:
184+ # size = int(size_str)
185+ # except ValueError:
186+ # raise ScheduleParseError(f"`{declaration}`: {size_str} is not an integer.")
151187
152188 annotations = self ._parse_annotations (value , declaration )
153189 return TileDecl (axis = axis_name , size = size , annotations = annotations )
@@ -234,8 +270,8 @@ def _interpret_spec(
234270 slice = loop_nest .build_slice (root )
235271
236272 # Track state during interpretation
237- sizes : dict [str , int ] = {}
238- previous_cut : dict [str , int | None ] = {a : 0 for a in self .abstract_axis }
273+ sizes : dict [str , int | str ] = {}
274+ previous_cut : dict [str , int | str | None ] = {a : 0 for a in self .abstract_axis }
239275 interchange : list [str ] = list (head )
240276
241277 for item in spec .items :
@@ -267,7 +303,7 @@ def _interpret_split(
267303 loop_nest : LoopNest ,
268304 root : str ,
269305 interchange : list [str ],
270- previous_cut : dict [str , int | None ],
306+ previous_cut : dict [str , int | str | None ],
271307 ) -> None :
272308 """Interpret a split declaration."""
273309 axis_name = item .axis
@@ -283,10 +319,8 @@ def _interpret_split(
283319 # it is the previous cut
284320 if x is None :
285321 x = cut
286- assert x is not None
287-
288322 self ._check_splitting_intervals (item , cut , x )
289-
323+ assert x is not None
290324 # Update the previous cut
291325 previous_cut [axis_name ] = y
292326
@@ -308,12 +342,15 @@ def _interpret_tile(
308342 item : TileDecl ,
309343 slice : LoopNestSlice ,
310344 interchange : list [str ],
311- sizes : dict [str , int ],
345+ sizes : dict [str , int | str ],
312346 ) -> str :
313347 """Interpret a tile declaration. Returns the loop name."""
314348 self ._check_axis_existence (item .axis )
315349 tile_num = len (slice .tiles [item .axis ])
316350 loop_name = f"{ item .axis } { tile_num } "
351+ if not isinstance (item .size , int ):
352+ raise ScheduleInterpretError (f"`{ item } `: { item .size } is not an integer." )
353+ assert isinstance (item .size , int )
317354 if item .size <= 0 :
318355 raise ScheduleInterpretError (
319356 f"`{ item } `: tile sizes should be strictly positive."
@@ -354,7 +391,7 @@ def _apply_annotations(
354391 self ,
355392 annotations : Annotations ,
356393 loop_name : str ,
357- sizes : dict [str , int ],
394+ sizes : dict [str , int | str ],
358395 slice : LoopNestSlice ,
359396 ) -> None :
360397 """Apply annotations to a loop in the slice."""
@@ -367,7 +404,7 @@ def _apply_annotations(
367404 f"{ loop_name } 's size being unknown, an unroll factor is needed."
368405 )
369406 unroll_factor = sizes [loop_name ]
370- elif unroll_factor <= 0 :
407+ elif isinstance ( unroll_factor , int ) and unroll_factor <= 0 :
371408 raise ScheduleInterpretError (
372409 f'`{{"unroll" = { unroll_factor } }}`: unroll parameter should be strictly positive.'
373410 )
@@ -382,27 +419,46 @@ def _apply_annotations(
382419 def _check_splitting_intervals (
383420 self ,
384421 item : SplitDecl ,
385- cut : int | None ,
386- x : int ,
387- ) -> None :
422+ cut : int | str | None ,
423+ x : int | str | None ,
424+ ) -> int | str | None :
388425 """Check that split intervals are valid and contiguous."""
389-
426+ y = item . end
390427 if cut is None :
391428 raise ScheduleInterpretError (f"{ item } : { item .axis } already covered." )
392429
393- if x > cut :
394- raise ScheduleInterpretError (
395- f"{ item } : splitting doesn't fully cover { item .axis } (jumps from { cut } to { x } )."
396- )
397- elif x < cut :
430+ if x is None :
398431 raise ScheduleInterpretError (
399- f"{ item } : the segment begins at { x } but the previous one ends at { cut } ."
432+ f"x is None, but cut: { cut } is not, this should be unreachable ."
400433 )
434+ if isinstance (x , int ) and isinstance (cut , int ):
435+ if x > cut :
436+ raise ScheduleInterpretError (
437+ f"{ item } : splitting doesn't fully cover { item .axis } (jumps from { cut } to { x } )."
438+ )
439+ elif x < cut :
440+ raise ScheduleInterpretError (
441+ f"{ item } : the segment begins at { x } but the previous one ends at { cut } ."
442+ )
443+ else :
444+ if x != cut :
445+ raise ScheduleInterpretError (
446+ f"{ item } : Splitting ends at { cut } and begins at { x } . These need to be the same."
447+ )
448+ if y is None :
449+ return None
401450
402- if item .end is not None and x >= item .end :
403- raise ScheduleInterpretError (
404- f"{ item } : the ending point should be greater than the starting point."
405- )
451+ if isinstance (x , int ):
452+ if isinstance (y , int ):
453+ if x >= y :
454+ raise ScheduleInterpretError (
455+ f"{ item } : the ending point should be greater than the starting point."
456+ )
457+ else :
458+ return y - x
459+ if x == 0 :
460+ return y
461+ return None
406462
407463
408464@dataclass
@@ -506,6 +562,34 @@ def tiles_to_sizes(self) -> dict[str, int]:
506562 tiles_to_sizes [loop ] = size
507563 return tiles_to_sizes
508564
565+ @property
566+ def int_tiles (self ) -> dict [str , dict [str , int ]]:
567+ return self ._int_dict (self .tiles )
568+
569+ @property
570+ def int_splits (self ) -> dict [str , dict [str , int ]]:
571+ return self ._int_dict (self .splits )
572+
573+ @property
574+ def int_unroll (self ) -> dict [str , int ]:
575+ out = {}
576+ for x , v in self .unroll .items ():
577+ if isinstance (v , str ) and v .isnumeric ():
578+ v = int (v )
579+ assert isinstance (v , int )
580+ out [x ] = v
581+ return out
582+
583+ def _int_dict (self , input : dict [str , dict [str , Any ]]) -> dict [str , dict [str , int ]]:
584+ out : dict [str , dict [str , int ]] = {}
585+ for x , v in input .items ():
586+ v_dict : dict [str , int ] = {}
587+ for x_v , v_v in v .items ():
588+ assert isinstance (v_v , int )
589+ v_dict [x_v ] = v_v
590+ out [x ] = v_dict
591+ return out
592+
509593
510594@dataclass
511595class LoopNest :
@@ -616,6 +700,7 @@ def _check_sizes(self):
616700
617701 if loop_name in sched .unroll :
618702 unroll_factor = sched .unroll [loop_name ]
703+ assert isinstance (unroll_factor , int )
619704 if loop_size and loop_size < unroll_factor :
620705 raise ScheduleValidationError (
621706 f'`{{"unroll" = { unroll_factor } }}`: unroll factor should be smaller than { loop_size } .'
@@ -650,19 +735,11 @@ def descript_scheduler(
650735 abstract_axis: The list of abstract axis names (e.g., ["m", "n", "k"]).
651736 spec: The schedule specification as a nested dict.
652737 """
653- descript = Descript (scheduler = scheduler , abstract_axis = abstract_axis )
654- descript .apply (node_name = node_name , spec = spec )
655-
656-
657- def correct_type (d : dict [str , int | str ]) -> dict [str , int ]:
658- out_d : dict [str , int ] = {}
659- for k , v in d .items ():
660- assert isinstance (v , int )
661- out_d [k ] = v
662- return out_d
738+ descript = Descript (abstract_axis = abstract_axis )
739+ descript .apply (scheduler = scheduler , node_name = node_name , spec = spec )
663740
664741
665- @dataclass (frozen = True )
742+ @dataclass (frozen = False )
666743class Descript :
667744 """Applies a parsed and interpreted schedule to a Scheduler.
668745
@@ -674,10 +751,11 @@ class Descript:
674751 4. Apply: LoopNest -> Scheduler
675752 """
676753
677- scheduler : Scheduler
678754 abstract_axis : list [str ]
679755
680- def apply (self , node_name : str , spec : dict [str , dict [str , Any ]]) -> None :
756+ def apply (
757+ self , node_name : str , spec : dict [str , dict [str , Any ]], scheduler : Scheduler
758+ ) -> None :
681759 """Parse, interpret, validate, and apply a schedule specification.
682760
683761 Args:
@@ -701,22 +779,22 @@ def apply(self, node_name: str, spec: dict[str, dict[str, Any]]) -> None:
701779 loop_nest .check ()
702780
703781 # Apply the schedule to the scheduler
704- self ._apply_loop_nest (loop_nest )
782+ self ._apply_loop_nest (loop_nest , scheduler )
705783
706- def _apply_loop_nest (self , loop_nest : LoopNest ) -> None :
784+ def _apply_loop_nest (self , loop_nest : LoopNest , scheduler : Scheduler ) -> None :
707785 """Apply a LoopNest to the scheduler."""
708- self . scheduler .set_dims (self .abstract_axis )
786+ scheduler .set_dims (self .abstract_axis )
709787
710788 for slice in loop_nest .slices :
711789 root = slice .root
712790
713- for d , s in slice .splits .items ():
714- self . scheduler .split (d , s , root = root )
791+ for d , s in slice .int_splits .items ():
792+ scheduler .split (d , s , root = root )
715793
716- for d , s in slice .tiles .items ():
717- self . scheduler .tile (d , s , root = root )
794+ for d , s in slice .int_tiles .items ():
795+ scheduler .tile (d , s , root = root )
718796
719- self . scheduler .interchange (slice .interchange , root = root )
720- self . scheduler .vectorize (slice .vectorize , root = root )
721- self . scheduler .parallelize (slice .parallelize , root = root )
722- self . scheduler .unroll (slice .unroll , root = root )
797+ scheduler .interchange (slice .interchange , root = root )
798+ scheduler .vectorize (slice .vectorize , root = root )
799+ scheduler .parallelize (slice .parallelize , root = root )
800+ scheduler .unroll (slice .int_unroll , root = root )
0 commit comments