Fix swapped conditions in local_sqrt_sqr rewrite#1922
Fix swapped conditions in local_sqrt_sqr rewrite#1922WHOIM1205 wants to merge 3 commits intopymc-devs:mainfrom
Conversation
Signed-off-by: WHOIM1205 <rathourprateek8@gmail.com>
|
hey @ricardoV94 |
|
Can you point to the original PR here in the comments? Want to see what might have failed in the review process |
tests/tensor/rewriting/test_math.py
Outdated
| assert equal_computations([out], [expected]) | ||
|
|
||
| def test_sqr_sqrt_integer_upcast(self): | ||
| f = pytensor.function([x], sqr(sqrt(x)), mode=self.mode) |
There was a problem hiding this comment.
I still wouldn't compile and evaluate the functions, the assert_equal_computations shows what the function does already. An independent test would be to compare against the unoptimized function not an expected numerical value
There was a problem hiding this comment.
Thanks for the suggestion. I have updated the tests accordingly. The redundant numerical evaluation and function compilation have been removed and the test now relies on assert_equal_computations for verifying the rewrite behavior.
Signed-off-by: WHOIM1205 <rathourprateek8@gmail.com>
Signed-off-by: WHOIM1205 <rathourprateek8@gmail.com>
|
hey @ricardoV94 is there anything i can change |
Fix: Correct swapped logic in
local_sqrt_sqrrewriteSummary
Fixes a numerical correctness bug in
pytensor/tensor/rewriting/math.pywhere the rewrite rule
local_sqrt_sqrhad its conditions swapped.The previous implementation incorrectly transformed:
sqrt(sqr(x))→switch(x >= 0, x, nan)(should beabs(x))sqr(sqrt(x))→abs(x)(should beswitch(x >= 0, x, nan))This caused silent wrong numerical results, especially for negative inputs.
Root Cause
prev_op(inner op) andnode_op(outer op) checks were reversed:Sqr(Sqrt(x))returnedabs(x)Sqrt(Sqr(x))returnedswitch(...)The return values were correct — but attached to the wrong condition.
Fix
Swapped the two
isinstanceconditions so that:sqrt(sqr(x))→abs(x)sqr(sqrt(x))→switch(x >= 0, x, nan)This is a minimal two-line logical correction.
Tests Updated
Impact
nanpollution in common patterns likesqrt(x**2).