Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions packages/commons/octobot_commons/dsl_interpreter/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,19 +237,28 @@ def _visit_node(self, node: typing.Optional[ast.AST]) -> typing.Union[

if isinstance(node, ast.Compare):
# Comparison: left op right
if len(node.ops) == 1 and len(node.comparators) == 1:
op_name = type(node.ops[0]).__name__
if op_name in self.operators_by_name:
operator_class = self.operators_by_name[op_name]
left = self._visit_node(node.left)
right = self._visit_node(node.comparators[0])
return operator_class(left, right)
# Handles both single comparisons (a < b) and chained comparisons (a < b <= c)
# Chained comparisons are decomposed into: (a < b) And (b <= c)
comparisons = []
left = self._visit_node(node.left)
for op, comparator in zip(node.ops, node.comparators):
op_name = type(op).__name__
if op_name not in self.operators_by_name:
raise octobot_commons.errors.UnsupportedOperatorError(
f"Unknown comparison operator: {op_name}"
)
operator_class = self.operators_by_name[op_name]
right = self._visit_node(comparator)
comparisons.append(operator_class(left, right))
left = right
if len(comparisons) == 1:
return comparisons[0]
and_op_name = ast.And.__name__
if and_op_name not in self.operators_by_name:
raise octobot_commons.errors.UnsupportedOperatorError(
f"Unknown comparison operator: {op_name}"
f"Chained comparisons require the '{and_op_name}' operator"
)
raise octobot_commons.errors.UnsupportedOperatorError(
"Multiple comparisons not supported"
)
return self.operators_by_name[and_op_name](*comparisons)

if isinstance(node, (ast.Constant)):
# Literal values: numbers, strings, booleans, None
Expand Down
218 changes: 216 additions & 2 deletions packages/commons/tests/dsl_interpreter/test_custom_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,118 @@ def compute(self) -> dsl_interpreter.ComputedOperatorParameterType:
left, right = self.get_computed_left_and_right_parameters()
return left + right


class SubOperator(dsl_interpreter.BinaryOperator):
@staticmethod
def get_name() -> str:
return ast.Sub.__name__

def compute(self) -> dsl_interpreter.ComputedOperatorParameterType:
left, right = self.get_computed_left_and_right_parameters()
return left - right


class LtOperator(dsl_interpreter.CompareOperator):
@staticmethod
def get_name() -> str:
return ast.Lt.__name__

def compute(self) -> dsl_interpreter.ComputedOperatorParameterType:
left, right = self.get_computed_left_and_right_parameters()
return left < right


class LtEOperator(dsl_interpreter.CompareOperator):
@staticmethod
def get_name() -> str:
return ast.LtE.__name__

def compute(self) -> dsl_interpreter.ComputedOperatorParameterType:
left, right = self.get_computed_left_and_right_parameters()
return left <= right


class GtOperator(dsl_interpreter.CompareOperator):
@staticmethod
def get_name() -> str:
return ast.Gt.__name__

def compute(self) -> dsl_interpreter.ComputedOperatorParameterType:
left, right = self.get_computed_left_and_right_parameters()
return left > right


class GtEOperator(dsl_interpreter.CompareOperator):
@staticmethod
def get_name() -> str:
return ast.GtE.__name__

def compute(self) -> dsl_interpreter.ComputedOperatorParameterType:
left, right = self.get_computed_left_and_right_parameters()
return left >= right


class EqOperator(dsl_interpreter.CompareOperator):
@staticmethod
def get_name() -> str:
return ast.Eq.__name__

def compute(self) -> dsl_interpreter.ComputedOperatorParameterType:
left, right = self.get_computed_left_and_right_parameters()
return left == right


class NotEqOperator(dsl_interpreter.CompareOperator):
@staticmethod
def get_name() -> str:
return ast.NotEq.__name__

def compute(self) -> dsl_interpreter.ComputedOperatorParameterType:
left, right = self.get_computed_left_and_right_parameters()
return left != right


class IsOperator(dsl_interpreter.CompareOperator):
@staticmethod
def get_name() -> str:
return ast.Is.__name__

def compute(self) -> dsl_interpreter.ComputedOperatorParameterType:
left, right = self.get_computed_left_and_right_parameters()
return left is right


class IsNotOperator(dsl_interpreter.CompareOperator):
@staticmethod
def get_name() -> str:
return ast.IsNot.__name__

def compute(self) -> dsl_interpreter.ComputedOperatorParameterType:
left, right = self.get_computed_left_and_right_parameters()
return left is not right


class AndOperator(dsl_interpreter.NaryOperator):
MIN_PARAMS = 1

@staticmethod
def get_name() -> str:
return ast.And.__name__

def compute(self) -> dsl_interpreter.ComputedOperatorParameterType:
return all(self.get_computed_parameters())


class OrOperator(dsl_interpreter.NaryOperator):
MIN_PARAMS = 1

@staticmethod
def get_name() -> str:
return ast.Or.__name__

def compute(self) -> dsl_interpreter.ComputedOperatorParameterType:
return any(self.get_computed_parameters())

class Add2Operator(dsl_interpreter.CallOperator):
@staticmethod
def get_name() -> str:
Expand Down Expand Up @@ -193,7 +305,11 @@ def compute(self) -> dsl_interpreter.ComputedOperatorParameterType:
def interpreter():
return dsl_interpreter.Interpreter(
dsl_interpreter.get_all_operators() + [
SumPlusXOperatorWithoutInit, SumPlusXOperatorWithPreCompute, TimeFrameToSecondsOperator, AddOperator, Add2Operator, PreComputeSumOperator, CallWithDefaultParametersOperator, NestedDictSumOperator, ParamMerger
SumPlusXOperatorWithoutInit, SumPlusXOperatorWithPreCompute, TimeFrameToSecondsOperator,
AddOperator, SubOperator, Add2Operator, PreComputeSumOperator, CallWithDefaultParametersOperator,
NestedDictSumOperator, ParamMerger,
LtOperator, LtEOperator, GtOperator, GtEOperator, EqOperator, NotEqOperator,
IsOperator, IsNotOperator, AndOperator, OrOperator
]
)

Expand Down Expand Up @@ -518,4 +634,102 @@ def test_get_docs_to_json():
assert json_data["parameters"][1]["name"] == "y"
assert json_data["parameters"][1]["description"] == "second parameter"
assert json_data["parameters"][1]["required"] is False
assert json_data["parameters"][1]["type"] == "int"
assert json_data["parameters"][1]["type"] == "int"


@pytest.mark.asyncio
async def test_chained_comparison_two_ops(interpreter):
# 0 < 5 <= 10 => (0 < 5) and (5 <= 10) => True
assert await interpreter.interprete("0 < 5 <= 10") is True
# 0 < 10 <= 10 => (0 < 10) and (10 <= 10) => True
assert await interpreter.interprete("0 < 10 <= 10") is True
# 0 < 15 <= 10 => (0 < 15) and (15 <= 10) => False (second fails)
assert await interpreter.interprete("0 < 15 <= 10") is False
# 5 < 3 <= 10 => (5 < 3) and (3 <= 10) => False (first fails)
assert await interpreter.interprete("5 < 3 <= 10") is False
# both fail: 10 < 5 <= 3
assert await interpreter.interprete("10 < 5 <= 3") is False


@pytest.mark.asyncio
async def test_chained_comparison_three_ops(interpreter):
# 1 < 2 < 3 < 4 => all True
assert await interpreter.interprete("1 < 2 < 3 < 4") is True
# 1 < 2 < 3 < 3 => last fails (3 < 3 is False)
assert await interpreter.interprete("1 < 2 < 3 < 3") is False
# 1 <= 1 <= 1 <= 1 => all True
assert await interpreter.interprete("1 <= 1 <= 1 <= 1") is True


@pytest.mark.asyncio
async def test_chained_comparison_mixed_operators(interpreter):
# 0 < 5 >= 3 => (0 < 5) and (5 >= 3) => True
assert await interpreter.interprete("0 < 5 >= 3") is True
# 0 < 5 >= 6 => (0 < 5) and (5 >= 6) => False
assert await interpreter.interprete("0 < 5 >= 6") is False
# 1 <= 2 > 1 => (1 <= 2) and (2 > 1) => True
assert await interpreter.interprete("1 <= 2 > 1") is True
# 1 != 2 < 3 => (1 != 2) and (2 < 3) => True
assert await interpreter.interprete("1 != 2 < 3") is True
# 1 == 1 < 2 => (1 == 1) and (1 < 2) => True
assert await interpreter.interprete("1 == 1 < 2") is True
# 1 == 1 < 0 => (1 == 1) and (1 < 0) => False
assert await interpreter.interprete("1 == 1 < 0") is False


@pytest.mark.asyncio
async def test_chained_comparison_with_expressions(interpreter):
# chained comparison where operands are arithmetic expressions
# 0 < (2 + 3) <= 10 => 0 < 5 <= 10 => True
assert await interpreter.interprete("0 < 2 + 3 <= 10") is True
# 0 < (10 - 3) <= 5 => 0 < 7 <= 5 => False
assert await interpreter.interprete("0 < 10 - 3 <= 5") is False


@pytest.mark.asyncio
async def test_chained_comparison_with_function_calls(interpreter):
# plus_42() returns 42 => 0 < 42 <= 100 => True
assert await interpreter.interprete("0 < plus_42() <= 100") is True
# 0 < 42 <= 41 => False
assert await interpreter.interprete("0 < plus_42() <= 41") is False
# 40 < 42 < 50 => True
assert await interpreter.interprete("40 < plus_42() < 50") is True
# middle operand shared: 0 < plus_42() <= plus_42() => 0 < 42 <= 42 => True
assert await interpreter.interprete("0 < plus_42() <= plus_42()") is True


@pytest.mark.asyncio
async def test_chained_comparison_in_bool_context(interpreter):
# chained comparison as part of a larger boolean expression
# (0 < 5 <= 10) and (1 < 2) => True and True => True
assert await interpreter.interprete("0 < 5 <= 10 and 1 < 2") is True
# (0 < 15 <= 10) and (1 < 2) => False and True => False
assert await interpreter.interprete("0 < 15 <= 10 and 1 < 2") is False
# (0 < 15 <= 10) or (1 < 2) => False or True => True
assert await interpreter.interprete("0 < 15 <= 10 or 1 < 2") is True


@pytest.mark.asyncio
async def test_chained_comparison_boundary_values(interpreter):
# exact boundary: 0 < 0 <= 10 => (0 < 0) is False
assert await interpreter.interprete("0 < 0 <= 10") is False
# exact boundary: 0 < 10 <= 10 => True
assert await interpreter.interprete("0 < 10 <= 10") is True
# negative values via expression: (0 - 5) < 0 < 5 => True
assert await interpreter.interprete("0 - 5 < 0 < 5") is True
# float boundaries
assert await interpreter.interprete("0.0 < 0.5 <= 1.0") is True
assert await interpreter.interprete("0.0 < 1.0 <= 0.5") is False


@pytest.mark.asyncio
async def test_chained_comparison_without_and_operator_raises(interpreter):
# create an interpreter without the And operator to verify the error message
interpreter_no_and = dsl_interpreter.Interpreter([
LtOperator, LtEOperator,
])
# single comparison still works
assert await interpreter_no_and.interprete("1 < 2") is True
# chained comparison requires And and should raise
with pytest.raises(commons_errors.UnsupportedOperatorError, match="Chained comparisons require the 'And' operator"):
interpreter_no_and.prepare("0 < 5 <= 10")
Loading