Skip to content

Commit d207389

Browse files
committed
fixed some assertions
1 parent cf13268 commit d207389

3 files changed

Lines changed: 70 additions & 28 deletions

File tree

pytensor/tensor/rewriting/math.py

Lines changed: 13 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -506,7 +506,9 @@ def local_exp_log(fgraph, node):
506506
# Case for exp(softplus(x)) aka exp(log1pexp) -> 1 + exp(x)
507507
if isinstance(prev_op, ps_math.Softplus) and isinstance(node_op, ps.Exp):
508508
x = x.owner.inputs[0]
509-
return [add(1, exp(x))]
509+
old_out = node.outputs[0]
510+
one_cast = np.asarray(1, dtype=old_out.dtype)
511+
return [add(one_cast, exp(x))]
510512

511513
# Case for expm1(softplus(x)) aka expm1(log1pexp) -> exp(x)
512514
if isinstance(prev_op, ps_math.Softplus) and isinstance(node_op, ps.Expm1):
@@ -591,14 +593,16 @@ def local_exp_log_nan_switch(fgraph, node):
591593
if isinstance(prev_op, ps.Log1p) and isinstance(node_op, ps.Exp):
592594
x = x.owner.inputs[0]
593595
old_out = node.outputs[0]
594-
new_out = switch(ge(x, -1), add(1, x), np.asarray(np.nan, old_out.dtype))
596+
one_cast = np.asarray(1, dtype=old_out.dtype)
597+
new_out = switch(ge(x, -1), add(one_cast, x), np.asarray(np.nan, old_out.dtype))
595598
return [new_out]
596599

597600
# Case for expm1(log(x)) -> x - 1
598601
if isinstance(prev_op, ps.Log) and isinstance(node_op, ps.Expm1):
599602
x = x.owner.inputs[0]
600603
old_out = node.outputs[0]
601-
new_out = switch(ge(x, 0), sub(x, 1), np.asarray(np.nan, old_out.dtype))
604+
one_cast = np.asarray(1, dtype=old_out.dtype)
605+
new_out = switch(ge(x, 0), sub(x, one_cast), np.asarray(np.nan, old_out.dtype))
602606
return [new_out]
603607

604608
# Case for expm1(log1p(x)) -> x
@@ -612,7 +616,8 @@ def local_exp_log_nan_switch(fgraph, node):
612616
if isinstance(prev_op, ps_math.Log1mexp) and isinstance(node_op, ps.Exp):
613617
x = x.owner.inputs[0]
614618
old_out = node.outputs[0]
615-
new_out = switch(le(x, 0), sub(1, exp(x)), np.asarray(np.nan, old_out.dtype))
619+
one_cast = np.asarray(1, dtype=old_out.dtype)
620+
new_out = switch(le(x, 0), sub(one_cast, exp(x)), np.asarray(np.nan, old_out.dtype))
616621
return [new_out]
617622

618623
# Case for expm1(log1mexp(x)) -> -exp(x)
@@ -3393,12 +3398,15 @@ def local_exp_over_1_plus_exp(fgraph, node):
33933398
copy_stack_trace(num, new_num)
33943399

33953400
if len(denom_rest) == 0:
3396-
return [new_num]
3401+
out = new_num
33973402
elif len(denom_rest) == 1:
33983403
out = new_num / denom_rest[0]
33993404
else:
34003405
out = new_num / mul(*denom_rest)
34013406

3407+
if out.dtype != node.outputs[0].dtype:
3408+
out = cast(out, node.outputs[0].dtype)
3409+
34023410
copy_stack_trace(node.outputs[0], out)
34033411
return [out]
34043412

tests/tensor/rewriting/test_elemwise.py

Lines changed: 39 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -656,7 +656,11 @@ def test_expansion_order(self):
656656
(fxv, fyv),
657657
1,
658658
fxv - (fyv / 2),
659-
"float32",
659+
{
660+
"custom": "float32",
661+
"numpy+floatX": "float64",
662+
"numpy": "float64",
663+
},
660664
),
661665
(
662666
fx - true_div(fy, fz),
@@ -673,12 +677,23 @@ def test_expansion_order(self):
673677
1,
674678
fxv - ((ixv * 100) // (iyv * 1000)),
675679
{
676-
"custom": "float64",
677-
"numpy + floatX": config.floatX,
680+
"custom": config.floatX,
681+
"numpy+floatX": "float64",
678682
"numpy": "float64",
679683
},
680684
), # 40
681-
(fx - (fy / 2), (fx, fy), (fxv, fyv), 1, fxv - (fyv / 2), "float32"),
685+
(
686+
fx - (fy / 2),
687+
(fx, fy),
688+
(fxv, fyv),
689+
1,
690+
fxv - (fyv / 2),
691+
{
692+
"custom": "float32",
693+
"numpy+floatX": "float64",
694+
"numpy": "float64",
695+
},
696+
),
682697
(
683698
fx - (fy % fz),
684699
(fx, fy, fz),
@@ -823,8 +838,8 @@ def test_expansion_order(self):
823838
1,
824839
fxv - (iyv | izv),
825840
{
826-
"custom": "float64",
827-
"numpy + floatX": config.floatX,
841+
"custom": config.floatX,
842+
"numpy+floatX": "float64",
828843
"numpy": "float64",
829844
},
830845
),
@@ -835,8 +850,8 @@ def test_expansion_order(self):
835850
1,
836851
fxv - (iyv ^ izv),
837852
{
838-
"custom": "float64",
839-
"numpy + floatX": config.floatX,
853+
"custom": config.floatX,
854+
"numpy+floatX": "float64",
840855
"numpy": "float64",
841856
},
842857
), # 60
@@ -847,8 +862,8 @@ def test_expansion_order(self):
847862
1,
848863
fxv - (iyv & izv),
849864
{
850-
"custom": "float64",
851-
"numpy + floatX": config.floatX,
865+
"custom": config.floatX,
866+
"numpy+floatX": "float64",
852867
"numpy": "float64",
853868
},
854869
),
@@ -859,8 +874,8 @@ def test_expansion_order(self):
859874
1,
860875
fxv - (~iyv),
861876
{
862-
"custom": "float64",
863-
"numpy + floatX": config.floatX,
877+
"custom": config.floatX,
878+
"numpy+floatX": "float64",
864879
"numpy": "float64",
865880
},
866881
),
@@ -957,7 +972,11 @@ def test_expansion_order(self):
957972
np.sum(-((fxv - fyv) ** 2) / 2),
958973
-(fxv - fyv),
959974
),
960-
("float32", "float32"),
975+
{
976+
"custom": ("float32", "float32"),
977+
"numpy+floatX": ("float64", "float32"),
978+
"numpy": ("float64", "float32"),
979+
},
961980
),
962981
# Two Composite graphs that share the same input, but are split by
963982
# a non-elemwise operation (Assert)
@@ -995,7 +1014,11 @@ def test_expansion_order(self):
9951014
(fxv,),
9961015
4,
9971016
(np.sum(fxv + 5) * np.exp(fxv) / (fxv + 5),),
998-
("float32",),
1017+
{
1018+
"custom": ("float32",),
1019+
"numpy+floatX": ("float64",),
1020+
"numpy": ("float64",),
1021+
},
9991022
),
10001023
pytest.param(
10011024
(
@@ -1038,7 +1061,8 @@ def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True):
10381061
self._shared(np.zeros((5,) * g_.ndim, dtype=od), "out")
10391062
for g_, od in zip(g, out_dtype, strict=True)
10401063
]
1041-
assert all(o.dtype == g_.dtype for o, g_ in zip(out, g, strict=True))
1064+
for o, g_ in zip(out, g, strict=True):
1065+
assert o.dtype == g_.dtype, f"Mismatch! Expected {o.dtype}, but graph evaluated to {g_.dtype}"
10421066
f = function(
10431067
sym_inputs, [], updates=list(zip(out, g, strict=True)), mode=self.mode
10441068
)

tests/tensor/rewriting/test_math.py

Lines changed: 18 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1168,7 +1168,7 @@ def test_log1p():
11681168
# should work for int
11691169
z = imatrix()
11701170
f = function([z], log(1 + (z)), mode=m)
1171-
assert [node.op for node in f.maker.fgraph.toposort()] == [log1p]
1171+
assert ps.log1p in [getattr(node.op, "scalar_op", None) for node in f.maker.fgraph.toposort()]
11721172

11731173

11741174
def test_local_log_add_exp():
@@ -1353,7 +1353,15 @@ def assert_eqs_const(self, f, val, op=deep_copy_op):
13531353
def assert_identity(self, f):
13541354
topo = f.maker.fgraph.toposort()
13551355
assert len(topo) == 1
1356-
assert topo[0].op == deep_copy_op
1356+
1357+
# If the operation mathematically forced an upcast, the identity is a Cast
1358+
in_dtype = f.maker.fgraph.inputs[0].type.dtype
1359+
out_dtype = f.maker.fgraph.outputs[0].type.dtype
1360+
1361+
if in_dtype != out_dtype:
1362+
assert "Cast" in str(topo[0].op), f"Expected Cast, got {topo[0].op}"
1363+
else:
1364+
assert topo[0].op == deep_copy_op
13571365
if f.outputs[0].variable.dtype == "bool":
13581366
x_vals = [0, 1]
13591367
else:
@@ -1982,7 +1990,7 @@ def test_log1pexp_log(self):
19821990
f.maker.fgraph.outputs,
19831991
[
19841992
pt.switch(
1985-
x >= np.array([[0]], dtype=np.int8),
1993+
x >= np.array([[0]], dtype=np.int64),
19861994
pt.log1p(x),
19871995
np.array([[np.nan]], dtype=np.float32),
19881996
)
@@ -2006,7 +2014,7 @@ def test_log1mexp_log(self):
20062014
f.maker.fgraph.outputs,
20072015
[
20082016
pt.switch(
2009-
x >= np.array([[0]], dtype=np.int8),
2017+
x >= np.array([[0]], dtype=np.int64),
20102018
pt.log1p(-x),
20112019
np.array([[np.nan]], dtype=np.float32),
20122020
)
@@ -2029,7 +2037,7 @@ def test_log1mexp_log1mexp(self):
20292037
f.maker.fgraph.outputs,
20302038
[
20312039
pt.switch(
2032-
x <= np.array([[0]], dtype=np.int8),
2040+
x <= np.array([[0]], dtype=np.int64),
20332041
x,
20342042
np.array([[np.nan]], dtype=np.float32),
20352043
)
@@ -2095,7 +2103,7 @@ def test_sqrt_sqr(self):
20952103
out = rewrite_graph(out, include=["canonicalize", "specialize", "stabilize"])
20962104

20972105
expected = switch(
2098-
ge(x, np.zeros((1, 1), dtype="int8")),
2106+
ge(x, np.zeros((1, 1), dtype="int64")),
20992107
x,
21002108
np.full((1, 1), np.nan, dtype=x.type.dtype),
21012109
)
@@ -4176,16 +4184,18 @@ def test_local_1msigmoid(self):
41764184
m = self.get_mode(excluding=["fusion", "inplace"])
41774185
x = fscalar()
41784186
xd = dscalar()
4187+
4188+
one = pt.constant(1.0, dtype="float32")
41794189

41804190
# Test `exp_over_1_plus_exp`
4181-
f = pytensor.function([x], 1 - exp(x) / (1 + exp(x)), mode=m)
4191+
f = pytensor.function([x], one - exp(x) / (one + exp(x)), mode=m)
41824192
# FIXME: PatternNodeRewriter does not copy stack trace
41834193
# (see https://github.com/Theano/Theano/issues/4581)
41844194
# assert check_stack_trace(f, ops_to_check=[neg, sigmoid])
41854195
assert equal_computations(f.maker.fgraph.outputs, [sigmoid(-x)])
41864196

41874197
# Test `inv_1_plus_exp`
4188-
f = pytensor.function([x], 1 - pt.fill(x, 1.0) / (1 + exp(-x)), mode=m)
4198+
f = pytensor.function([x], one - pt.fill(x, one) / (one + exp(-x)), mode=m)
41894199
# assert check_stack_trace(f, ops_to_check=[neg, sigmoid])
41904200
assert equal_computations(f.maker.fgraph.outputs, [sigmoid(-x)])
41914201

0 commit comments

Comments (0)