Skip to content

Commit 6fcf318

Browse files
committed
Refactors Log operation
1 parent cf94b38 commit 6fcf318

2 files changed

Lines changed: 35 additions & 10 deletions

File tree

sasdata/quantities/quantity.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -713,27 +713,52 @@ def _clean_ab(self, a, b):
713713
return Div(a, b)
714714

715715

716-
class Log(BinaryOperation):
716+
class Log(Operation):
717717

718718
serialisation_name = "log"
719719

720+
def __init__(self, a: Operation, base: float):
721+
self.a = a
722+
self.base = base
723+
720724
def evaluate(self, variables: dict[int, T]) -> Operation:
721-
return log(self.a.evaluate(variables), self.b.evaluate(variables))
725+
return log(self.a.evaluate(variables), self.base)
722726

723727
def _derivative(self, hash_value: int) -> Operation:
724-
return Inv(Mul(self.a, Ln(self.b)))
728+
return Inv(Mul(self.a, Ln(Constant(self.base))))
729+
730+
def _clean_ab(self) -> Operation:
731+
a = self.a._clean()
725732

726-
def _clean_ab(self, a, b):
727733
if isinstance(a, MultiplicativeIdentity):
728734
# Convert log(1) to 0
729735
return AdditiveIdentity()
730736

731-
elif a == b:
732-
# Convert log(b) to 1
737+
elif a == self.base:
738+
# Convert log(base) to 1
733739
return MultiplicativeIdentity()
734740

735741
else:
736-
return Log(a, b)
742+
return Log(a, self.base)
743+
744+
def _serialise_parameters(self) -> dict[str, Any]:
745+
return {"a": Operation._serialise_json(self.a),
746+
"base": self.base}
747+
748+
@staticmethod
749+
def _deserialise(parameters: dict) -> "Operation":
750+
return Log(Operation.deserialise_json(parameters["a"]), parameters["base"])
751+
752+
def summary(self, indent_amount: int=0, indent=" "):
753+
return (f"{indent_amount*indent}Log(\n" +
754+
self.a.summary(indent_amount+1, indent) + "\n" +
755+
f"{(indent_amount+1)*indent}{self.base}\n" +
756+
f"{indent_amount*indent})")
757+
758+
def __eq__(self, other):
759+
if isinstance(other, Log):
760+
return self.a == other.a and self.base == other.base
761+
return False
737762

738763

739764
class Pow(Operation):

test/quantities/utest_operations.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
Add(
3232
Neg(Inv(MultiplicativeIdentity())),
3333
Ln(Transpose(Variable("x")))),
34-
Log(Constant(7), Constant(2))),
34+
Log(Constant(7), 2)),
3535
AdditiveIdentity()),
3636
2),
3737
Variable("y"))
@@ -69,7 +69,7 @@ def test_unary_summary(op):
6969

7070
@pytest.mark.parametrize("op", [Add, Div, Dot, Log, MatMul, Mul, Pow, Sub])
7171
def test_binary_summary(op):
72-
f = op(Constant(1), 1 if op == Pow else Constant(1))
72+
f = op(Constant(1), 1 if op == Log or op == Pow else Constant(1))
7373
assert f.summary() == f"{op.__name__}(\n 1\n 1\n)"
7474

7575

@@ -114,7 +114,7 @@ def test_unary_evaluation(op, a, result):
114114
(Log, 100, 10, 2),
115115
(Log, 256, 2, 8)])
116116
def test_binary_evaluation(op, a, b, result):
117-
f = op(Constant(a), b if op == Pow else Constant(b))
117+
f = op(Constant(a), b if op == Log or op == Pow else Constant(b))
118118
assert f.evaluate({}) == result
119119

120120
@pytest.mark.parametrize("op, a, b, result", [

0 commit comments

Comments
 (0)