Skip to content

Commit dcb2575

Browse files
Better parsing for toll direction
1 parent 5847dc2 commit dcb2575

1 file changed

Lines changed: 14 additions & 2 deletions

File tree

accelforge/frontend/arch/components.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -758,16 +758,28 @@ def _eval_tensor2bits(
758758
return {k2: v for k, v in result.items() for k2 in k}
759759

760760

761+
_VALID_DIRECTIONS = {"up", "down", "up_and_down"}
762+
763+
761764
def _eval_direction(toeval, symbol_table: dict[str, Any]) -> dict[str, str]:
762765
"""Evaluate a direction field. If a string, expand to all tensors. If a dict,
763766
resolve tensor expression keys."""
764767
if isinstance(toeval, str):
765-
# Single direction for all tensors
768+
if toeval not in _VALID_DIRECTIONS:
769+
raise EvaluationError(
770+
f'Invalid direction: "{toeval}". '
771+
f"Must be one of {sorted(_VALID_DIRECTIONS)}."
772+
)
766773
all_tensors = symbol_table["All"].instance
767774
return {t: toeval for t in all_tensors}
768775

769776
result = {}
770777
for key, value in toeval.items():
778+
if value not in _VALID_DIRECTIONS:
779+
raise EvaluationError(
780+
f'Invalid direction for {key}: "{value}". '
781+
f"Must be one of {sorted(_VALID_DIRECTIONS)}."
782+
)
771783
key_evaluated = eval_set_expression(
772784
expression=key,
773785
symbol_table=symbol_table,
@@ -1020,7 +1032,7 @@ class Toll(TensorHolder):
10201032
zero.
10211033
"""
10221034

1023-
direction: EvalsTo[dict[str, Literal["up", "down", "up_and_down"]]]
1035+
direction: TryEvalTo[dict]
10241036
"""
10251037
The direction in which data flows through this `Toll`. Can be:
10261038

0 commit comments

Comments
 (0)