Skip to content

Commit 54cba53

Browse files
committed
descript: refacto
1 parent 8104278 commit 54cba53

4 files changed

Lines changed: 15 additions & 19 deletions

File tree

src/xtc/cli/mlir_loop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def build_node_scheduler(
135135
descript_scheduler(
136136
scheduler=scheduler,
137137
node_name=node_name,
138-
abstract_axis=scheduler.backend.dims,
138+
abstract_dims=scheduler.backend.dims,
139139
spec=normal_schedule,
140140
)
141141
op.attributes.pop("loop.schedule", None)

src/xtc/schedules/descript.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
def descript_scheduler(
2222
scheduler: Scheduler,
2323
node_name: str,
24-
abstract_axis: list[str],
24+
abstract_dims: list[str],
2525
spec: dict[str, dict[str, Any]],
2626
) -> None:
2727
"""Apply a schedule specification to a scheduler.
@@ -31,22 +31,22 @@ def descript_scheduler(
3131
Args:
3232
scheduler: The scheduler to apply the schedule to.
3333
node_name: The name of the root node to schedule.
34-
abstract_axis: The list of abstract axis names (e.g., ["m", "n", "k"]).
34+
abstract_dims: The list of abstract axis names (e.g., ["m", "n", "k"]).
3535
spec: The schedule specification as a nested dict.
3636
"""
37-
descript = Descript(scheduler=scheduler, abstract_axis=abstract_axis)
37+
descript = Descript(scheduler=scheduler, abstract_dims=abstract_dims)
3838
descript.apply(node_name=node_name, spec=spec)
3939

4040

4141
class ScheduleInterpreter:
4242
"""Interprets a parsed ScheduleSpec AST into a LoopNest."""
4343

44-
def __init__(self, abstract_axis: list[str]):
45-
self.abstract_axis = abstract_axis
44+
def __init__(self, abstract_dims: list[str]):
45+
self.abstract_dims = abstract_dims
4646

4747
def interpret(self, spec: ScheduleSpec, root: str) -> LoopNest:
4848
"""Interpret a schedule specification into a LoopNest."""
49-
loop_nest = LoopNest(abstract_dims=self.abstract_axis)
49+
loop_nest = LoopNest(abstract_dims=self.abstract_dims)
5050
root_node = loop_nest.build_root_node(root)
5151
self._interpret_spec_into_node(spec, root_node, root, head=[])
5252
return loop_nest
@@ -61,7 +61,7 @@ def _interpret_spec_into_node(
6161
"""Interpret a schedule spec into an existing node (mutates node)."""
6262
# Track state during interpretation
6363
sizes: dict[str, int] = {}
64-
previous_cut: dict[str, int | None] = {a: 0 for a in self.abstract_axis}
64+
previous_cut: dict[str, int | None] = {a: 0 for a in self.abstract_dims}
6565
interchange: list[str] = list(head)
6666

6767
for item in spec.items:
@@ -124,7 +124,7 @@ def _interpret_split(
124124
# Create a child node for the nested schedule
125125
child_node = LoopNestNode(
126126
root=new_root_name,
127-
tiles={a: {} for a in self.abstract_axis},
127+
tiles={a: {} for a in self.abstract_dims},
128128
split_origin=SplitOrigin(axis=axis_name, start=x, end=y),
129129
)
130130
node.add_child(child_node)
@@ -176,9 +176,9 @@ def _interpret_axis(
176176

177177
def _check_axis_existence(self, axis: str) -> None:
178178
"""Check that an axis is defined."""
179-
if axis not in self.abstract_axis:
179+
if axis not in self.abstract_dims:
180180
raise ScheduleInterpretError(
181-
f"Axis {axis} is not a defined axis (defined axis: {self.abstract_axis})."
181+
f"Axis {axis} is not a defined axis (defined axis: {self.abstract_dims})."
182182
)
183183

184184
def _apply_annotations(
@@ -249,7 +249,7 @@ class Descript:
249249
"""
250250

251251
scheduler: Scheduler
252-
abstract_axis: list[str]
252+
abstract_dims: list[str]
253253

254254
def apply(self, node_name: str, spec: dict[str, dict[str, Any]]) -> None:
255255
"""Parse, interpret, validate, and apply a schedule specification.
@@ -268,7 +268,7 @@ def apply(self, node_name: str, spec: dict[str, dict[str, Any]]) -> None:
268268
ast = parser.parse(spec)
269269

270270
# Interpret the AST into a LoopNest
271-
interpreter = ScheduleInterpreter(self.abstract_axis)
271+
interpreter = ScheduleInterpreter(self.abstract_dims)
272272
loop_nest = interpreter.interpret(ast, root=node_name)
273273

274274
# Validate the loop nest
@@ -279,7 +279,7 @@ def apply(self, node_name: str, spec: dict[str, dict[str, Any]]) -> None:
279279

280280
def _apply_loop_nest(self, loop_nest: LoopNest) -> None:
281281
"""Apply a LoopNest to the scheduler."""
282-
self.scheduler.set_dims(self.abstract_axis)
282+
self.scheduler.set_dims(self.abstract_dims)
283283

284284
if loop_nest.root_node is not None:
285285
self._apply_node(loop_nest.root_node)

src/xtc/schedules/loop_nest.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -305,10 +305,6 @@ class LoopNest:
305305
abstract_dims: list[str]
306306
root_node: LoopNestNode | None = None
307307

308-
@property
309-
def empty(self) -> bool:
310-
return self.root_node is None
311-
312308
@property
313309
def nodes(self) -> list[LoopNestNode]:
314310
"""Flatten the tree into a list of nodes.

src/xtc/schedules/ttile/scheme_to_xtc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -723,7 +723,7 @@ def build_schedule_from_ttile(
723723
)
724724

725725
descript_scheduler(
726-
scheduler=sch, node_name=name_op, abstract_axis=ldims, spec=spec_schedule
726+
scheduler=sch, node_name=name_op, abstract_dims=ldims, spec=spec_schedule
727727
)
728728

729729
# And run it!

0 commit comments

Comments
 (0)