Skip to content

Commit 746e330

Browse files
committed
[DSL] Add mulitple operators
Signed-off-by: Herklos <herklos@drakkar.software>
1 parent 8f36a4d commit 746e330

2 files changed

Lines changed: 265 additions & 13 deletions

File tree

packages/commons/octobot_commons/dsl_interpreter/interpreter.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -237,19 +237,28 @@ def _visit_node(self, node: typing.Optional[ast.AST]) -> typing.Union[
237237

238238
if isinstance(node, ast.Compare):
239239
# Comparison: left op right
240-
if len(node.ops) == 1 and len(node.comparators) == 1:
241-
op_name = type(node.ops[0]).__name__
242-
if op_name in self.operators_by_name:
243-
operator_class = self.operators_by_name[op_name]
244-
left = self._visit_node(node.left)
245-
right = self._visit_node(node.comparators[0])
246-
return operator_class(left, right)
240+
# Handles both single comparisons (a < b) and chained comparisons (a < b <= c)
241+
# Chained comparisons are decomposed into: (a < b) And (b <= c)
242+
comparisons = []
243+
left = self._visit_node(node.left)
244+
for op, comparator in zip(node.ops, node.comparators):
245+
op_name = type(op).__name__
246+
if op_name not in self.operators_by_name:
247+
raise octobot_commons.errors.UnsupportedOperatorError(
248+
f"Unknown comparison operator: {op_name}"
249+
)
250+
operator_class = self.operators_by_name[op_name]
251+
right = self._visit_node(comparator)
252+
comparisons.append(operator_class(left, right))
253+
left = right
254+
if len(comparisons) == 1:
255+
return comparisons[0]
256+
and_op_name = ast.And.__name__
257+
if and_op_name not in self.operators_by_name:
247258
raise octobot_commons.errors.UnsupportedOperatorError(
248-
f"Unknown comparison operator: {op_name}"
259+
f"Chained comparisons require the '{and_op_name}' operator"
249260
)
250-
raise octobot_commons.errors.UnsupportedOperatorError(
251-
"Multiple comparisons not supported"
252-
)
261+
return self.operators_by_name[and_op_name](*comparisons)
253262

254263
if isinstance(node, (ast.Constant)):
255264
# Literal values: numbers, strings, booleans, None

packages/commons/tests/dsl_interpreter/test_custom_operators.py

Lines changed: 245 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,118 @@ def compute(self) -> dsl_interpreter.ComputedOperatorParameterType:
9292
left, right = self.get_computed_left_and_right_parameters()
9393
return left + right
9494

95+
96+
class SubOperator(dsl_interpreter.BinaryOperator):
97+
@staticmethod
98+
def get_name() -> str:
99+
return ast.Sub.__name__
100+
101+
def compute(self) -> dsl_interpreter.ComputedOperatorParameterType:
102+
left, right = self.get_computed_left_and_right_parameters()
103+
return left - right
104+
105+
106+
class LtOperator(dsl_interpreter.CompareOperator):
107+
@staticmethod
108+
def get_name() -> str:
109+
return ast.Lt.__name__
110+
111+
def compute(self) -> dsl_interpreter.ComputedOperatorParameterType:
112+
left, right = self.get_computed_left_and_right_parameters()
113+
return left < right
114+
115+
116+
class LtEOperator(dsl_interpreter.CompareOperator):
117+
@staticmethod
118+
def get_name() -> str:
119+
return ast.LtE.__name__
120+
121+
def compute(self) -> dsl_interpreter.ComputedOperatorParameterType:
122+
left, right = self.get_computed_left_and_right_parameters()
123+
return left <= right
124+
125+
126+
class GtOperator(dsl_interpreter.CompareOperator):
127+
@staticmethod
128+
def get_name() -> str:
129+
return ast.Gt.__name__
130+
131+
def compute(self) -> dsl_interpreter.ComputedOperatorParameterType:
132+
left, right = self.get_computed_left_and_right_parameters()
133+
return left > right
134+
135+
136+
class GtEOperator(dsl_interpreter.CompareOperator):
137+
@staticmethod
138+
def get_name() -> str:
139+
return ast.GtE.__name__
140+
141+
def compute(self) -> dsl_interpreter.ComputedOperatorParameterType:
142+
left, right = self.get_computed_left_and_right_parameters()
143+
return left >= right
144+
145+
146+
class EqOperator(dsl_interpreter.CompareOperator):
147+
@staticmethod
148+
def get_name() -> str:
149+
return ast.Eq.__name__
150+
151+
def compute(self) -> dsl_interpreter.ComputedOperatorParameterType:
152+
left, right = self.get_computed_left_and_right_parameters()
153+
return left == right
154+
155+
156+
class NotEqOperator(dsl_interpreter.CompareOperator):
157+
@staticmethod
158+
def get_name() -> str:
159+
return ast.NotEq.__name__
160+
161+
def compute(self) -> dsl_interpreter.ComputedOperatorParameterType:
162+
left, right = self.get_computed_left_and_right_parameters()
163+
return left != right
164+
165+
166+
class IsOperator(dsl_interpreter.CompareOperator):
167+
@staticmethod
168+
def get_name() -> str:
169+
return ast.Is.__name__
170+
171+
def compute(self) -> dsl_interpreter.ComputedOperatorParameterType:
172+
left, right = self.get_computed_left_and_right_parameters()
173+
return left is right
174+
175+
176+
class IsNotOperator(dsl_interpreter.CompareOperator):
177+
@staticmethod
178+
def get_name() -> str:
179+
return ast.IsNot.__name__
180+
181+
def compute(self) -> dsl_interpreter.ComputedOperatorParameterType:
182+
left, right = self.get_computed_left_and_right_parameters()
183+
return left is not right
184+
185+
186+
class AndOperator(dsl_interpreter.NaryOperator):
187+
MIN_PARAMS = 1
188+
189+
@staticmethod
190+
def get_name() -> str:
191+
return ast.And.__name__
192+
193+
def compute(self) -> dsl_interpreter.ComputedOperatorParameterType:
194+
return all(self.get_computed_parameters())
195+
196+
197+
class OrOperator(dsl_interpreter.NaryOperator):
198+
MIN_PARAMS = 1
199+
200+
@staticmethod
201+
def get_name() -> str:
202+
return ast.Or.__name__
203+
204+
def compute(self) -> dsl_interpreter.ComputedOperatorParameterType:
205+
return any(self.get_computed_parameters())
206+
95207
class Add2Operator(dsl_interpreter.CallOperator):
96208
@staticmethod
97209
def get_name() -> str:
@@ -193,7 +305,11 @@ def compute(self) -> dsl_interpreter.ComputedOperatorParameterType:
193305
def interpreter():
194306
return dsl_interpreter.Interpreter(
195307
dsl_interpreter.get_all_operators() + [
196-
SumPlusXOperatorWithoutInit, SumPlusXOperatorWithPreCompute, TimeFrameToSecondsOperator, AddOperator, Add2Operator, PreComputeSumOperator, CallWithDefaultParametersOperator, NestedDictSumOperator, ParamMerger
308+
SumPlusXOperatorWithoutInit, SumPlusXOperatorWithPreCompute, TimeFrameToSecondsOperator,
309+
AddOperator, SubOperator, Add2Operator, PreComputeSumOperator, CallWithDefaultParametersOperator,
310+
NestedDictSumOperator, ParamMerger,
311+
LtOperator, LtEOperator, GtOperator, GtEOperator, EqOperator, NotEqOperator,
312+
IsOperator, IsNotOperator, AndOperator, OrOperator
197313
]
198314
)
199315

@@ -518,4 +634,131 @@ def test_get_docs_to_json():
518634
assert json_data["parameters"][1]["name"] == "y"
519635
assert json_data["parameters"][1]["description"] == "second parameter"
520636
assert json_data["parameters"][1]["required"] is False
521-
assert json_data["parameters"][1]["type"] == "int"
637+
assert json_data["parameters"][1]["type"] == "int"
638+
639+
640+
@pytest.mark.asyncio
641+
async def test_single_comparison(interpreter):
642+
assert await interpreter.interprete("1 < 2") is True
643+
assert await interpreter.interprete("2 < 1") is False
644+
assert await interpreter.interprete("5 <= 5") is True
645+
assert await interpreter.interprete("5 >= 5") is True
646+
assert await interpreter.interprete("5 > 3") is True
647+
assert await interpreter.interprete("3 == 3") is True
648+
assert await interpreter.interprete("3 != 4") is True
649+
650+
651+
@pytest.mark.asyncio
652+
async def test_chained_comparison_two_ops(interpreter):
653+
# 0 < 5 <= 10 => (0 < 5) and (5 <= 10) => True
654+
assert await interpreter.interprete("0 < 5 <= 10") is True
655+
# 0 < 10 <= 10 => (0 < 10) and (10 <= 10) => True
656+
assert await interpreter.interprete("0 < 10 <= 10") is True
657+
# 0 < 15 <= 10 => (0 < 15) and (15 <= 10) => False (second fails)
658+
assert await interpreter.interprete("0 < 15 <= 10") is False
659+
# 5 < 3 <= 10 => (5 < 3) and (3 <= 10) => False (first fails)
660+
assert await interpreter.interprete("5 < 3 <= 10") is False
661+
# both fail: 10 < 5 <= 3
662+
assert await interpreter.interprete("10 < 5 <= 3") is False
663+
664+
665+
@pytest.mark.asyncio
666+
async def test_chained_comparison_three_ops(interpreter):
667+
# 1 < 2 < 3 < 4 => all True
668+
assert await interpreter.interprete("1 < 2 < 3 < 4") is True
669+
# 1 < 2 < 3 < 3 => last fails (3 < 3 is False)
670+
assert await interpreter.interprete("1 < 2 < 3 < 3") is False
671+
# 1 <= 1 <= 1 <= 1 => all True
672+
assert await interpreter.interprete("1 <= 1 <= 1 <= 1") is True
673+
674+
675+
@pytest.mark.asyncio
676+
async def test_chained_comparison_mixed_operators(interpreter):
677+
# 0 < 5 >= 3 => (0 < 5) and (5 >= 3) => True
678+
assert await interpreter.interprete("0 < 5 >= 3") is True
679+
# 0 < 5 >= 6 => (0 < 5) and (5 >= 6) => False
680+
assert await interpreter.interprete("0 < 5 >= 6") is False
681+
# 1 <= 2 > 1 => (1 <= 2) and (2 > 1) => True
682+
assert await interpreter.interprete("1 <= 2 > 1") is True
683+
# 1 != 2 < 3 => (1 != 2) and (2 < 3) => True
684+
assert await interpreter.interprete("1 != 2 < 3") is True
685+
# 1 == 1 < 2 => (1 == 1) and (1 < 2) => True
686+
assert await interpreter.interprete("1 == 1 < 2") is True
687+
# 1 == 1 < 0 => (1 == 1) and (1 < 0) => False
688+
assert await interpreter.interprete("1 == 1 < 0") is False
689+
690+
691+
@pytest.mark.asyncio
692+
async def test_chained_comparison_with_expressions(interpreter):
693+
# chained comparison where operands are arithmetic expressions
694+
# 0 < (2 + 3) <= 10 => 0 < 5 <= 10 => True
695+
assert await interpreter.interprete("0 < 2 + 3 <= 10") is True
696+
# 0 < (10 - 3) <= 5 => 0 < 7 <= 5 => False
697+
assert await interpreter.interprete("0 < 10 - 3 <= 5") is False
698+
699+
700+
@pytest.mark.asyncio
701+
async def test_chained_comparison_with_function_calls(interpreter):
702+
# plus_42() returns 42 => 0 < 42 <= 100 => True
703+
assert await interpreter.interprete("0 < plus_42() <= 100") is True
704+
# 0 < 42 <= 41 => False
705+
assert await interpreter.interprete("0 < plus_42() <= 41") is False
706+
# 40 < 42 < 50 => True
707+
assert await interpreter.interprete("40 < plus_42() < 50") is True
708+
# middle operand shared: 0 < plus_42() <= plus_42() => 0 < 42 <= 42 => True
709+
assert await interpreter.interprete("0 < plus_42() <= plus_42()") is True
710+
711+
712+
@pytest.mark.asyncio
713+
async def test_chained_comparison_with_is_not(interpreter):
714+
# x is not None and 0 < x <= 300000
715+
# This mirrors the original failing DSL script pattern
716+
assert await interpreter.interprete("100 is not None") is True
717+
assert await interpreter.interprete("None is None") is True
718+
assert await interpreter.interprete("None is not None") is False
719+
720+
721+
@pytest.mark.asyncio
722+
async def test_chained_comparison_in_if_expression(interpreter):
723+
# "body if test else orelse" with chained comparison as test
724+
# 1 if 0 < 5 <= 10 else 0 => 1
725+
assert await interpreter.interprete("1 if 0 < 5 <= 10 else 0") == 1
726+
# 1 if 0 < 15 <= 10 else 0 => 0
727+
assert await interpreter.interprete("1 if 0 < 15 <= 10 else 0") == 0
728+
729+
730+
@pytest.mark.asyncio
731+
async def test_chained_comparison_combined_with_bool_ops(interpreter):
732+
# chained comparison inside a boolean expression
733+
# (0 < 5 <= 10) and (1 < 2) => True and True => True
734+
assert await interpreter.interprete("0 < 5 <= 10 and 1 < 2") is True
735+
# (0 < 5 <= 10) and (2 < 1) => True and False => False
736+
assert await interpreter.interprete("0 < 5 <= 10 and 2 < 1") is False
737+
# (0 < 15 <= 10) or (1 < 2) => False or True => True
738+
assert await interpreter.interprete("0 < 15 <= 10 or 1 < 2") is True
739+
740+
741+
@pytest.mark.asyncio
742+
async def test_chained_comparison_boundary_values(interpreter):
743+
# exact boundary: 0 < 0 <= 10 => (0 < 0) is False
744+
assert await interpreter.interprete("0 < 0 <= 10") is False
745+
# exact boundary: 0 < 10 <= 10 => True
746+
assert await interpreter.interprete("0 < 10 <= 10") is True
747+
# negative values: -5 < 0 < 5 => True
748+
assert await interpreter.interprete("-5 < 0 < 5") is True # parsed as (-5) < 0 < 5
749+
# float boundaries
750+
assert await interpreter.interprete("0.0 < 0.5 <= 1.0") is True
751+
assert await interpreter.interprete("0.0 < 1.0 <= 0.5") is False
752+
753+
754+
@pytest.mark.asyncio
755+
async def test_chained_comparison_without_and_operator_raises(interpreter):
756+
# create an interpreter without the And operator to verify the error message
757+
interpreter_no_and = dsl_interpreter.Interpreter([
758+
LtOperator, LtEOperator,
759+
])
760+
# single comparison still works
761+
assert await interpreter_no_and.interprete("1 < 2") is True
762+
# chained comparison requires And and should raise
763+
with pytest.raises(commons_errors.UnsupportedOperatorError, match="Chained comparisons require the 'And' operator"):
764+
interpreter_no_and.prepare("0 < 5 <= 10")

0 commit comments

Comments
 (0)