@@ -758,16 +758,28 @@ def _eval_tensor2bits(
758758 return {k2 : v for k , v in result .items () for k2 in k }
759759
760760
761+ _VALID_DIRECTIONS = {"up" , "down" , "up_and_down" }
762+
763+
761764def _eval_direction (toeval , symbol_table : dict [str , Any ]) -> dict [str , str ]:
762765 """Evaluate a direction field. If a string, expand to all tensors. If a dict,
763766 resolve tensor expression keys."""
764767 if isinstance (toeval , str ):
765- # Single direction for all tensors
768+ if toeval not in _VALID_DIRECTIONS :
769+ raise EvaluationError (
770+ f'Invalid direction: "{ toeval } ". '
771+ f"Must be one of { sorted (_VALID_DIRECTIONS )} ."
772+ )
766773 all_tensors = symbol_table ["All" ].instance
767774 return {t : toeval for t in all_tensors }
768775
769776 result = {}
770777 for key , value in toeval .items ():
778+ if value not in _VALID_DIRECTIONS :
779+ raise EvaluationError (
780+ f'Invalid direction for { key } : "{ value } ". '
781+ f"Must be one of { sorted (_VALID_DIRECTIONS )} ."
782+ )
771783 key_evaluated = eval_set_expression (
772784 expression = key ,
773785 symbol_table = symbol_table ,
@@ -1020,7 +1032,7 @@ class Toll(TensorHolder):
10201032 zero.
10211033 """
10221034
1023- direction : EvalsTo [dict [ str , Literal [ "up" , "down" , "up_and_down" ]] ]
1035+ direction : TryEvalTo [dict ]
10241036 """
10251037 The direction in which data flows through this `Toll`. Can be:
10261038
0 commit comments