Skip to content

Commit d59d050

Browse files
author
Leon Frenot
committed
Fixes after more rebasing
1 parent cd8f59a commit d59d050

20 files changed

Lines changed: 933 additions & 913 deletions

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ scikit-learn
1010
networkx
1111
sympy
1212
strictyaml
13+
types-PyYAML

src/xtc/cli/mlir_loop.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,11 +182,13 @@ def normalize_extend_schedule(
182182
assert isinstance(instr, str)
183183
if isinstance(param, builtin.UnitAttr):
184184
annotations[instr] = None
185-
elif isinstance(param, builtin.IntegerAttr):
185+
elif isinstance(param, builtin.IntegerAttr) or isinstance(
186+
param, builtin.StringAttr
187+
):
186188
annotations[instr] = param.value.data
187189
else:
188190
raise Exception(
189-
"Annotation parameter should be void or int."
191+
"Annotation parameter should be void, int, or str."
190192
)
191193

192194
elif not isinstance(val, builtin.UnitAttr):

src/xtc/schedules/descript.py

Lines changed: 136 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -41,23 +41,28 @@ class Annotations:
4141
parallelize: True if parallelization was requested.
4242
"""
4343

44-
unroll_factor: int | None = None
44+
unroll_factor: int | str | None = None
4545
unroll_specified: bool = False
46-
vectorize: bool = False
47-
parallelize: bool = False
46+
vectorize: bool | str = False
47+
parallelize: bool | str = False
48+
partial: bool = False
49+
full: bool = False
4850

4951

5052
@dataclass(frozen=True)
5153
class SplitDecl:
5254
"""AST Type: a split declaration like 'axis[start:end]'."""
5355

5456
axis: str
55-
start: int | None
56-
end: int | None
57+
start: int | str | None
58+
end: int | str | None
5759
body: ScheduleSpec
60+
size: int | str | None = None
5861

5962
@override
6063
def __str__(self) -> str:
64+
if self.size is not None:
65+
return f"{self.axis}[:{self.size}:]"
6166
start_str = "" if self.start is None else str(self.start)
6267
end_str = "" if self.end is None else str(self.end)
6368
decl = f"{self.axis}[{start_str}:{end_str}]"
@@ -69,7 +74,7 @@ class TileDecl:
6974
"""AST Type: a tile declaration like 'axis#size'."""
7075

7176
axis: str
72-
size: int
77+
size: int | str
7378
annotations: Annotations
7479

7580
@override
@@ -85,7 +90,36 @@ class AxisDecl:
8590
annotations: Annotations
8691

8792

88-
ScheduleItem = SplitDecl | TileDecl | AxisDecl
93+
@dataclass(frozen=True)
94+
class FusionDecl:
95+
"""AST Type: a fusion declaration"""
96+
97+
98+
@dataclass(frozen=True)
99+
class PackDecl:
100+
"""AST Type: a packing declaration"""
101+
102+
param: str | bool
103+
input: str
104+
pad: str | bool
105+
106+
107+
@dataclass(frozen=True)
108+
class BufferDecl:
109+
"""AST Type: a bufferisation declaration"""
110+
111+
param: str | bool
112+
pad: str
113+
114+
115+
@dataclass(frozen=True)
116+
class ExploreDecl:
117+
level: str
118+
119+
120+
ScheduleItem = (
121+
SplitDecl | TileDecl | AxisDecl | FusionDecl | PackDecl | BufferDecl | ExploreDecl
122+
)
89123

90124

91125
@dataclass(frozen=True)
@@ -144,10 +178,12 @@ def _parse_tile(self, declaration: str, value: dict) -> TileDecl:
144178

145179
axis_name, size_str = parts
146180

147-
try:
148-
size = int(size_str)
149-
except ValueError:
150-
raise ScheduleParseError(f"`{declaration}`: {size_str} is not an integer.")
181+
size = int(size_str) if size_str.isnumeric() else size_str
182+
183+
# try:
184+
# size = int(size_str)
185+
# except ValueError:
186+
# raise ScheduleParseError(f"`{declaration}`: {size_str} is not an integer.")
151187

152188
annotations = self._parse_annotations(value, declaration)
153189
return TileDecl(axis=axis_name, size=size, annotations=annotations)
@@ -234,8 +270,8 @@ def _interpret_spec(
234270
slice = loop_nest.build_slice(root)
235271

236272
# Track state during interpretation
237-
sizes: dict[str, int] = {}
238-
previous_cut: dict[str, int | None] = {a: 0 for a in self.abstract_axis}
273+
sizes: dict[str, int | str] = {}
274+
previous_cut: dict[str, int | str | None] = {a: 0 for a in self.abstract_axis}
239275
interchange: list[str] = list(head)
240276

241277
for item in spec.items:
@@ -267,7 +303,7 @@ def _interpret_split(
267303
loop_nest: LoopNest,
268304
root: str,
269305
interchange: list[str],
270-
previous_cut: dict[str, int | None],
306+
previous_cut: dict[str, int | str | None],
271307
) -> None:
272308
"""Interpret a split declaration."""
273309
axis_name = item.axis
@@ -283,10 +319,8 @@ def _interpret_split(
283319
# it is the previous cut
284320
if x is None:
285321
x = cut
286-
assert x is not None
287-
288322
self._check_splitting_intervals(item, cut, x)
289-
323+
assert x is not None
290324
# Update the previous cut
291325
previous_cut[axis_name] = y
292326

@@ -308,12 +342,15 @@ def _interpret_tile(
308342
item: TileDecl,
309343
slice: LoopNestSlice,
310344
interchange: list[str],
311-
sizes: dict[str, int],
345+
sizes: dict[str, int | str],
312346
) -> str:
313347
"""Interpret a tile declaration. Returns the loop name."""
314348
self._check_axis_existence(item.axis)
315349
tile_num = len(slice.tiles[item.axis])
316350
loop_name = f"{item.axis}{tile_num}"
351+
if not isinstance(item.size, int):
352+
raise ScheduleInterpretError(f"`{item}`: {item.size} is not an integer.")
353+
assert isinstance(item.size, int)
317354
if item.size <= 0:
318355
raise ScheduleInterpretError(
319356
f"`{item}`: tile sizes should be strictly positive."
@@ -354,7 +391,7 @@ def _apply_annotations(
354391
self,
355392
annotations: Annotations,
356393
loop_name: str,
357-
sizes: dict[str, int],
394+
sizes: dict[str, int | str],
358395
slice: LoopNestSlice,
359396
) -> None:
360397
"""Apply annotations to a loop in the slice."""
@@ -367,7 +404,7 @@ def _apply_annotations(
367404
f"{loop_name}'s size being unknown, an unroll factor is needed."
368405
)
369406
unroll_factor = sizes[loop_name]
370-
elif unroll_factor <= 0:
407+
elif isinstance(unroll_factor, int) and unroll_factor <= 0:
371408
raise ScheduleInterpretError(
372409
f'`{{"unroll" = {unroll_factor}}}`: unroll parameter should be strictly positive.'
373410
)
@@ -382,27 +419,46 @@ def _apply_annotations(
382419
def _check_splitting_intervals(
383420
self,
384421
item: SplitDecl,
385-
cut: int | None,
386-
x: int,
387-
) -> None:
422+
cut: int | str | None,
423+
x: int | str | None,
424+
) -> int | str | None:
388425
"""Check that split intervals are valid and contiguous."""
389-
426+
y = item.end
390427
if cut is None:
391428
raise ScheduleInterpretError(f"{item}: {item.axis} already covered.")
392429

393-
if x > cut:
394-
raise ScheduleInterpretError(
395-
f"{item}: splitting doesn't fully cover {item.axis} (jumps from {cut} to {x})."
396-
)
397-
elif x < cut:
430+
if x is None:
398431
raise ScheduleInterpretError(
399-
f"{item}: the segment begins at {x} but the previous one ends at {cut}."
432+
f"x is None, but cut: {cut} is not, this should be unreachable."
400433
)
434+
if isinstance(x, int) and isinstance(cut, int):
435+
if x > cut:
436+
raise ScheduleInterpretError(
437+
f"{item}: splitting doesn't fully cover {item.axis} (jumps from {cut} to {x})."
438+
)
439+
elif x < cut:
440+
raise ScheduleInterpretError(
441+
f"{item}: the segment begins at {x} but the previous one ends at {cut}."
442+
)
443+
else:
444+
if x != cut:
445+
raise ScheduleInterpretError(
446+
f"{item}: Splitting ends at {cut} and begins at {x}. These need to be the same."
447+
)
448+
if y is None:
449+
return None
401450

402-
if item.end is not None and x >= item.end:
403-
raise ScheduleInterpretError(
404-
f"{item}: the ending point should be greater than the starting point."
405-
)
451+
if isinstance(x, int):
452+
if isinstance(y, int):
453+
if x >= y:
454+
raise ScheduleInterpretError(
455+
f"{item}: the ending point should be greater than the starting point."
456+
)
457+
else:
458+
return y - x
459+
if x == 0:
460+
return y
461+
return None
406462

407463

408464
@dataclass
@@ -506,6 +562,34 @@ def tiles_to_sizes(self) -> dict[str, int]:
506562
tiles_to_sizes[loop] = size
507563
return tiles_to_sizes
508564

565+
@property
566+
def int_tiles(self) -> dict[str, dict[str, int]]:
567+
return self._int_dict(self.tiles)
568+
569+
@property
570+
def int_splits(self) -> dict[str, dict[str, int]]:
571+
return self._int_dict(self.splits)
572+
573+
@property
574+
def int_unroll(self) -> dict[str, int]:
575+
out = {}
576+
for x, v in self.unroll.items():
577+
if isinstance(v, str) and v.isnumeric():
578+
v = int(v)
579+
assert isinstance(v, int)
580+
out[x] = v
581+
return out
582+
583+
def _int_dict(self, input: dict[str, dict[str, Any]]) -> dict[str, dict[str, int]]:
584+
out: dict[str, dict[str, int]] = {}
585+
for x, v in input.items():
586+
v_dict: dict[str, int] = {}
587+
for x_v, v_v in v.items():
588+
assert isinstance(v_v, int)
589+
v_dict[x_v] = v_v
590+
out[x] = v_dict
591+
return out
592+
509593

510594
@dataclass
511595
class LoopNest:
@@ -616,6 +700,7 @@ def _check_sizes(self):
616700

617701
if loop_name in sched.unroll:
618702
unroll_factor = sched.unroll[loop_name]
703+
assert isinstance(unroll_factor, int)
619704
if loop_size and loop_size < unroll_factor:
620705
raise ScheduleValidationError(
621706
f'`{{"unroll" = {unroll_factor}}}`: unroll factor should be smaller than {loop_size}.'
@@ -650,19 +735,11 @@ def descript_scheduler(
650735
abstract_axis: The list of abstract axis names (e.g., ["m", "n", "k"]).
651736
spec: The schedule specification as a nested dict.
652737
"""
653-
descript = Descript(scheduler=scheduler, abstract_axis=abstract_axis)
654-
descript.apply(node_name=node_name, spec=spec)
655-
656-
657-
def correct_type(d: dict[str, int | str]) -> dict[str, int]:
658-
out_d: dict[str, int] = {}
659-
for k, v in d.items():
660-
assert isinstance(v, int)
661-
out_d[k] = v
662-
return out_d
738+
descript = Descript(abstract_axis=abstract_axis)
739+
descript.apply(scheduler=scheduler, node_name=node_name, spec=spec)
663740

664741

665-
@dataclass(frozen=True)
742+
@dataclass(frozen=False)
666743
class Descript:
667744
"""Applies a parsed and interpreted schedule to a Scheduler.
668745
@@ -674,10 +751,11 @@ class Descript:
674751
4. Apply: LoopNest -> Scheduler
675752
"""
676753

677-
scheduler: Scheduler
678754
abstract_axis: list[str]
679755

680-
def apply(self, node_name: str, spec: dict[str, dict[str, Any]]) -> None:
756+
def apply(
757+
self, node_name: str, spec: dict[str, dict[str, Any]], scheduler: Scheduler
758+
) -> None:
681759
"""Parse, interpret, validate, and apply a schedule specification.
682760
683761
Args:
@@ -701,22 +779,22 @@ def apply(self, node_name: str, spec: dict[str, dict[str, Any]]) -> None:
701779
loop_nest.check()
702780

703781
# Apply the schedule to the scheduler
704-
self._apply_loop_nest(loop_nest)
782+
self._apply_loop_nest(loop_nest, scheduler)
705783

706-
def _apply_loop_nest(self, loop_nest: LoopNest) -> None:
784+
def _apply_loop_nest(self, loop_nest: LoopNest, scheduler: Scheduler) -> None:
707785
"""Apply a LoopNest to the scheduler."""
708-
self.scheduler.set_dims(self.abstract_axis)
786+
scheduler.set_dims(self.abstract_axis)
709787

710788
for slice in loop_nest.slices:
711789
root = slice.root
712790

713-
for d, s in slice.splits.items():
714-
self.scheduler.split(d, s, root=root)
791+
for d, s in slice.int_splits.items():
792+
scheduler.split(d, s, root=root)
715793

716-
for d, s in slice.tiles.items():
717-
self.scheduler.tile(d, s, root=root)
794+
for d, s in slice.int_tiles.items():
795+
scheduler.tile(d, s, root=root)
718796

719-
self.scheduler.interchange(slice.interchange, root=root)
720-
self.scheduler.vectorize(slice.vectorize, root=root)
721-
self.scheduler.parallelize(slice.parallelize, root=root)
722-
self.scheduler.unroll(slice.unroll, root=root)
797+
scheduler.interchange(slice.interchange, root=root)
798+
scheduler.vectorize(slice.vectorize, root=root)
799+
scheduler.parallelize(slice.parallelize, root=root)
800+
scheduler.unroll(slice.int_unroll, root=root)

0 commit comments

Comments
 (0)