Skip to content

Commit b1332b2

Browse files
committed
Reintroduce test_log1mexp_stabilization accidentally removed in ff1a3a9
1 parent b8e939e commit b1332b2

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

tests/tensor/rewriting/test_math.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4115,3 +4115,24 @@ def test_local_add_neg_to_sub_const():
41154115

41164116
x_test = np.array([3, 4], dtype=config.floatX)
41174117
assert np.allclose(f(x_test), x_test + (-const))
4118+
4119+
4120+
def test_log1mexp_stabilization():
4121+
mode = Mode("py").including("stabilize")
4122+
4123+
x = vector()
4124+
f = function([x], log(1 - exp(x)), mode=mode)
4125+
4126+
nodes = [node.op for node in f.maker.fgraph.toposort()]
4127+
assert nodes == [at.log1mexp]
4128+
4129+
# Check values that would under or overflow without rewriting
4130+
assert f([-(2.0**-55)]) != -np.inf
4131+
overflow_value = -500.0 if config.floatX == "float64" else -100.0
4132+
assert f([overflow_value]) < 0
4133+
4134+
# Check values around the switch point np.log(0.5)
4135+
assert np.allclose(
4136+
f(np.array([-0.8, -0.6], dtype=config.floatX)),
4137+
np.log(1 - np.exp([-0.8, -0.6])),
4138+
)

0 commit comments

Comments
 (0)