Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ def local_sqrt_sqr(fgraph, node):
node_op = node.op.scalar_op

# Case for sqrt(sqr(x)) -> |x|
if isinstance(prev_op, ps.Sqrt) and isinstance(node_op, ps.Sqr):
if isinstance(prev_op, ps.Sqr) and isinstance(node_op, ps.Sqrt):
new_out = pt_abs(x.owner.inputs[0])
old_out = node.outputs[0]

Expand All @@ -537,7 +537,7 @@ def local_sqrt_sqr(fgraph, node):
return [new_out]

# Case for sqr(sqrt(x)) -> x
if isinstance(prev_op, ps.Sqr) and isinstance(node_op, ps.Sqrt):
if isinstance(prev_op, ps.Sqrt) and isinstance(node_op, ps.Sqr):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = switch(ge(x, 0), x, np.asarray(np.nan, old_out.dtype))
Expand Down
32 changes: 22 additions & 10 deletions tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2082,18 +2082,11 @@ def setup_method(self):
self.rng = np.random.default_rng()

def test_sqr_sqrt(self):
# sqrt(x) ** 2 -> x
# sqr(sqrt(x)) -> x for x >= 0, nan for x < 0
x = pt.tensor("x", shape=(None, None))
out = sqr(sqrt(x))
out = rewrite_graph(out, include=["canonicalize", "specialize", "stabilize"])

assert equal_computations([out], [pt_abs(x)])

def test_sqrt_sqr(self):
x = pt.tensor("x", shape=(None, None))
out = sqrt(sqr(x))
out = rewrite_graph(out, include=["canonicalize", "specialize", "stabilize"])

expected = switch(
ge(x, np.zeros((1, 1), dtype="int8")),
x,
Expand All @@ -2102,9 +2095,28 @@ def test_sqrt_sqr(self):

assert equal_computations([out], [expected])

def test_sqr_sqrt_integer_upcast(self):
f = pytensor.function([x], sqr(sqrt(x)), mode=self.mode)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still wouldn't compile and evaluate the functions, the assert_equal_computations shows what the function does already. An independent test would be to compare against the unoptimized function not an expected numerical value

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion. I have updated the tests accordingly. The redundant numerical evaluation and function compilation have been removed and the test now relies on assert_equal_computations for verifying the rewrite behavior.

test_val = np.array([[-3.0, -1.0, 0.0, 1.0, 3.0]])
result = f(test_val)
np.testing.assert_array_equal(np.isnan(result[0, :2]), True)
np.testing.assert_allclose(result[0, 2:], [0.0, 1.0, 3.0])

def test_sqrt_sqr(self):
# sqrt(sqr(x)) -> |x|
x = pt.tensor("x", shape=(None, None))
out = sqrt(sqr(x))
out = rewrite_graph(out, include=["canonicalize", "specialize", "stabilize"])

assert equal_computations([out], [pt_abs(x)])

f = pytensor.function([x], sqrt(sqr(x)), mode=self.mode)
test_val = np.array([[-3.0, -1.0, 0.0, 1.0, 3.0]])
result = f(test_val)
np.testing.assert_allclose(result, np.abs(test_val))

def test_sqrt_sqr_integer_upcast(self):
x = ivector("x")
out = sqr(sqrt(x))
out = sqrt(sqr(x))
dtype = out.type.dtype
out = rewrite_graph(out, include=["canonicalize", "specialize", "stabilize"])

Expand Down
Loading