diff --git a/CHANGELOG.md b/CHANGELOG.md index 80bea8d07..67e86568c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ ### Changed - Move magic methods (`__radd__`, `__sub__`, `__rsub__`, `__rmul__`, `__richcmp__`, `__neg__`, and `__rtruediv__`) to `ExprLike` base class (#1204) - Speed up `Expr.__add__` and `Expr.__iadd__` via the C-level API +- Replace Python math with C-level math functions and refactor unary expressions. ### Removed ## 6.2.1 - 2026.05.16 diff --git a/src/pyscipopt/expr.pxi b/src/pyscipopt/expr.pxi index 62f0c880d..4e83099b6 100644 --- a/src/pyscipopt/expr.pxi +++ b/src/pyscipopt/expr.pxi @@ -42,7 +42,6 @@ # which should, in princple, modify the expr. However, since we do not implement __isub__, __sub__ # gets called (I guess) and so a copy is returned. # Modifying the expression directly would be a bug, given that the expression might be re-used by the user. -import math from typing import TYPE_CHECKING, Literal, Union import numpy as np @@ -54,6 +53,12 @@ from cpython.number cimport PyNumber_Check from cpython.object cimport Py_LE, Py_EQ, Py_GE, Py_TYPE from cpython.ref cimport PyObject from cpython.tuple cimport PyTuple_GET_ITEM +from libc.math cimport cos as c_cos +from libc.math cimport exp as c_exp +from libc.math cimport fabs as c_fabs +from libc.math cimport log as c_log +from libc.math cimport sqrt as c_sqrt +from libc.math cimport sin as c_sin cimport numpy as cnp from pyscipopt.scip cimport Variable, Solution @@ -278,23 +283,23 @@ cdef class ExprLike: def __neg__(self, /) -> Union[Expr, GenExpr]: return self * -1.0 - def __abs__(self) -> GenExpr: - return UnaryExpr(Operator.fabs, buildGenExprObj(self)) + def __abs__(self, /) -> AbsExpr: + return AbsExpr(Operator.fabs, buildGenExprObj(self)) - def exp(self) -> GenExpr: - return UnaryExpr(Operator.exp, buildGenExprObj(self)) + def exp(self, /) -> ExpExpr: + return ExpExpr(Operator.exp, buildGenExprObj(self)) - def log(self) -> GenExpr: - return UnaryExpr(Operator.log, buildGenExprObj(self)) + def log(self, /) -> LogExpr: + return LogExpr(Operator.log, buildGenExprObj(self)) - def sqrt(self) -> GenExpr: - return UnaryExpr(Operator.sqrt, buildGenExprObj(self)) + def sqrt(self, /) -> SqrtExpr: + return SqrtExpr(Operator.sqrt, buildGenExprObj(self)) - def sin(self) -> GenExpr: - return UnaryExpr(Operator.sin, buildGenExprObj(self)) + def sin(self, /) -> SinExpr: + return SinExpr(Operator.sin, buildGenExprObj(self)) - def cos(self) -> GenExpr: - return UnaryExpr(Operator.cos, buildGenExprObj(self)) + def cos(self, /) -> CosExpr: + return CosExpr(Operator.cos, buildGenExprObj(self)) ##@details Polynomial expressions of variables with operator overloading. \n @@ -799,24 +804,54 @@ cdef class PowExpr(GenExpr): return (self.children[0])._evaluate(sol) ** self.expo -# Exp, Log, Sqrt, Sin, Cos Expressions cdef class UnaryExpr(GenExpr): + def __init__(self, op, expr): self.children = [] self.children.append(expr) self._op = op - def __abs__(self) -> UnaryExpr: - if self._op == "abs": - return self.copy() - return UnaryExpr(Operator.fabs, self) - - def __repr__(self): + def __repr__(self) -> str: return self._op + "(" + self.children[0].__repr__() + ")" + +cdef class AbsExpr(UnaryExpr): + + def __abs__(self) -> AbsExpr: + return self.copy() + + cpdef double _evaluate(self, Solution sol) except *: + return c_fabs((self.children[0])._evaluate(sol)) + + +cdef class ExpExpr(UnaryExpr): + + cpdef double _evaluate(self, Solution sol) except *: + return c_exp((self.children[0])._evaluate(sol)) + + +cdef class LogExpr(UnaryExpr): + + cpdef double _evaluate(self, Solution sol) except *: + return c_log((self.children[0])._evaluate(sol)) + + +cdef class SqrtExpr(UnaryExpr): + + cpdef double _evaluate(self, Solution sol) except *: + return c_sqrt((self.children[0])._evaluate(sol)) + + +cdef class SinExpr(UnaryExpr): + + cpdef double _evaluate(self, Solution sol) except *: + return c_sin((self.children[0])._evaluate(sol)) + + +cdef class CosExpr(UnaryExpr): + cpdef double _evaluate(self, Solution sol) except *: - cdef double res = (self.children[0])._evaluate(sol) - return math.fabs(res) if self._op == "abs" else getattr(math, self._op)(res) + return c_cos((self.children[0])._evaluate(sol)) # class for constant expressions diff --git a/src/pyscipopt/scip.pyi b/src/pyscipopt/scip.pyi index 86196cfc1..16b41bc2a 100644 --- a/src/pyscipopt/scip.pyi +++ b/src/pyscipopt/scip.pyi @@ -338,12 +338,12 @@ class ExprLike: def __rmul__(self, other: object, /) -> Incomplete: ... def __rtruediv__(self, other: object, /) -> GenExpr: ... def __neg__(self, /) -> Union[Expr, GenExpr]: ... - def __abs__(self) -> GenExpr: ... - def exp(self) -> GenExpr: ... - def log(self) -> GenExpr: ... - def sqrt(self) -> GenExpr: ... - def sin(self) -> GenExpr: ... - def cos(self) -> GenExpr: ... + def __abs__(self, /) -> AbsExpr: ... + def exp(self, /) -> ExpExpr: ... + def log(self, /) -> LogExpr: ... + def sqrt(self, /) -> SqrtExpr: ... + def sin(self, /) -> SinExpr: ... + def cos(self, /) -> CosExpr: ... @disjoint_base class Expr(ExprLike): @@ -2262,7 +2262,13 @@ class Term: class UnaryExpr(GenExpr): def __init__(self, *args: Incomplete, **kwargs: Incomplete) -> None: ... - def __abs__(self) -> GenExpr: ... + +class AbsExpr(UnaryExpr): ... +class ExpExpr(UnaryExpr): ... +class LogExpr(UnaryExpr): ... +class SqrtExpr(UnaryExpr): ... +class SinExpr(UnaryExpr): ... +class CosExpr(UnaryExpr): ... @disjoint_base class VarExpr(GenExpr): diff --git a/tests/test_expr.py b/tests/test_expr.py index f35096f73..5a3db9cdf 100644 --- a/tests/test_expr.py +++ b/tests/test_expr.py @@ -212,8 +212,16 @@ def test_getVal_with_GenExpr(): assert m.getVal(y / x) == 2 # test "**(prod(1.0,**(sum(0.0,prod(1.0,x)),-1)),2)" assert m.getVal((1 / x) ** 2) == 1 - # test "sin(sum(0.0,prod(1.0,x)))" + + # test C-level math functions + assert m.getVal(abs(x)) == 1 + assert m.getVal(abs(-x)) == 1 + assert m.getVal(abs(abs(-x))) == 1 + assert round(m.getVal(exp(x)), 6) == round(math.exp(1), 6) + assert round(m.getVal(log(x)), 6) == round(math.log(1), 6) + assert round(m.getVal(sqrt(x)), 6) == round(math.sqrt(1), 6) assert round(m.getVal(sin(x)), 6) == round(math.sin(1), 6) + assert round(m.getVal(cos(x)), 6) == round(math.cos(1), 6) with pytest.raises(TypeError): m.getVal(1)