Skip to content

Commit ec4407d

Browse files
committed
Recognize alternative form of sigmoid in logprob inference
1 parent 714b4a0 commit ec4407d

File tree

2 files changed

+24
-18
lines changed

2 files changed

+24
-18
lines changed

pymc/logprob/rewriting.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
from pytensor.tensor.elemwise import DimShuffle, Elemwise
7171
from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
7272
from pytensor.tensor.rewriting.basic import register_canonicalize
73+
from pytensor.tensor.rewriting.math import local_exp_over_1_plus_exp
7374
from pytensor.tensor.rewriting.shape import ShapeFeature
7475
from pytensor.tensor.rewriting.uncanonicalize import local_max_and_argmax
7576
from pytensor.tensor.subtensor import (
@@ -359,7 +360,12 @@ def incsubtensor_rv_replace(fgraph, node):
359360

360361
logprob_rewrites_db = SequenceDB()
361362
logprob_rewrites_db.name = "logprob_rewrites_db"
363+
# Introduce sigmoid. We do it before canonicalization so that useless mul are removed next
364+
logprob_rewrites_db.register(
365+
"local_exp_over_1_plus_exp", out2in(local_exp_over_1_plus_exp), "basic"
366+
)
362367
logprob_rewrites_db.register("pre-canonicalize", optdb.query("+canonicalize"), "basic")
368+
# Split max_and_argmax
363369
logprob_rewrites_db.register("local_max_and_argmax", out2in(local_max_and_argmax), "basic")
364370

365371
# These rewrites convert un-measurable variables into their measurable forms,

tests/logprob/test_transforms.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1127,31 +1127,31 @@ def test_cosh_rv_transform():
11271127
)
11281128

11291129

1130-
TRANSFORMATIONS = {
1131-
"log1p": (pt.log1p, lambda x: pt.log(1 + x)),
1132-
"softplus": (pt.softplus, lambda x: pt.log(1 + pt.exp(x))),
1133-
"log1mexp": (pt.log1mexp, lambda x: pt.log(1 - pt.exp(x))),
1134-
"log2": (pt.log2, lambda x: pt.log(x) / pt.log(2)),
1135-
"log10": (pt.log10, lambda x: pt.log(x) / pt.log(10)),
1136-
"exp2": (pt.exp2, lambda x: pt.exp(pt.log(2) * x)),
1137-
"expm1": (pt.expm1, lambda x: pt.exp(x) - 1),
1138-
"sigmoid": (pt.sigmoid, lambda x: 1 / (1 + pt.exp(-x))),
1139-
}
1140-
1141-
1142-
@pytest.mark.parametrize("transform", TRANSFORMATIONS.keys())
1143-
def test_special_log_exp_transforms(transform):
1130+
@pytest.mark.parametrize(
1131+
"canonical_func,raw_func",
1132+
[
1133+
(pt.log1p, lambda x: pt.log(1 + x)),
1134+
(pt.softplus, lambda x: pt.log(1 + pt.exp(x))),
1135+
(pt.log1mexp, lambda x: pt.log(1 - pt.exp(x))),
1136+
(pt.log2, lambda x: pt.log(x) / pt.log(2)),
1137+
(pt.log10, lambda x: pt.log(x) / pt.log(10)),
1138+
(pt.exp2, lambda x: pt.exp(pt.log(2) * x)),
1139+
(pt.expm1, lambda x: pt.exp(x) - 1),
1140+
(pt.sigmoid, lambda x: 1 / (1 + pt.exp(-x))),
1141+
(pt.sigmoid, lambda x: pt.exp(x) / (1 + pt.exp(x))),
1142+
],
1143+
)
1144+
def test_special_log_exp_transforms(canonical_func, raw_func):
11441145
base_rv = pt.random.normal(name="base_rv")
11451146
vv = pt.scalar("vv")
11461147

1147-
transform_func, ref_func = TRANSFORMATIONS[transform]
1148-
transformed_rv = transform_func(base_rv)
1149-
ref_transformed_rv = ref_func(base_rv)
1148+
transformed_rv = raw_func(base_rv)
1149+
ref_transformed_rv = canonical_func(base_rv)
11501150

11511151
logp_test = logp(transformed_rv, vv)
11521152
logp_ref = logp(ref_transformed_rv, vv)
11531153

1154-
if transform in ["log2", "log10"]:
1154+
if canonical_func in (pt.log2, pt.log10):
11551155
# in the cases of log2 and log10 floating point inprecision causes failure
11561156
# from equal_computations so evaluate logp and check all close instead
11571157
vv_test = np.array(0.25)

0 commit comments

Comments
 (0)