Skip to content

Commit 81d31c8

Browse files
List more math function in API docs (#7211)
Also removes deprecated functions and deprecates numpy helpers Co-authored-by: Ricardo Vieira <ricardo.vieira1994@gmail.com>
1 parent 89f6fcf commit 81d31c8

File tree

4 files changed

+111
-67
lines changed

4 files changed

+111
-67
lines changed

docs/source/api/math.rst

Lines changed: 74 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -19,56 +19,98 @@ Functions exposed in pymc namespace
1919
invlogit
2020
probit
2121
invprobit
22+
logaddexp
2223
logsumexp
2324

25+
2426
Functions exposed in pymc.math
2527
------------------------------
2628

2729
.. automodule:: pymc.math
2830
.. autosummary::
2931
:toctree: generated/
3032

31-
dot
32-
constant
33-
flatten
34-
zeros_like
35-
ones_like
36-
stack
37-
concatenate
38-
sum
33+
abs
3934
prod
40-
lt
41-
gt
42-
le
43-
ge
35+
dot
4436
eq
4537
neq
46-
switch
47-
clip
48-
where
49-
and_
50-
or_
51-
abs
38+
ge
39+
gt
40+
le
41+
lt
5242
exp
5343
log
54-
cos
44+
sgn
45+
sqr
46+
sqrt
47+
sum
48+
ceil
49+
floor
5550
sin
56-
tan
57-
cosh
5851
sinh
52+
arcsin
53+
arcsinh
54+
cos
55+
cosh
56+
arccos
57+
arccosh
58+
tan
5959
tanh
60-
sqr
61-
sqrt
62-
erf
63-
erfinv
64-
dot
60+
arctan
61+
arctanh
62+
cumprod
63+
cumsum
64+
matmul
65+
and_
66+
broadcast_to
67+
clip
68+
concatenate
69+
flatten
70+
or_
71+
stack
72+
switch
73+
where
74+
flatten_list
75+
constant
76+
max
6577
maximum
78+
mean
79+
min
6680
minimum
67-
sgn
68-
ceil
69-
floor
70-
matrix_inverse
71-
sigmoid
81+
round
82+
erf
83+
erfc
84+
erfcinv
85+
erfinv
86+
log1pexp
87+
log1mexp
88+
logaddexp
7289
logsumexp
73-
invlogit
90+
logdiffexp
7491
logit
92+
invlogit
93+
probit
94+
invprobit
95+
sigmoid
96+
softmax
97+
log_softmax
98+
logbern
99+
full
100+
full_like
101+
ones
102+
ones_like
103+
zeros
104+
zeros_like
105+
kronecker
106+
cartesian
107+
kron_dot
108+
kron_solve_lower
109+
kron_solve_upper
110+
kron_diag
111+
flat_outer
112+
expand_packed_triangular
113+
batched_diag
114+
block_diagonal
115+
matrix_inverse
116+
logdet

pymc/math.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
ones_like,
7474
or_,
7575
prod,
76+
round,
7677
sgn,
7778
sigmoid,
7879
sin,
@@ -178,6 +179,7 @@
178179
"expand_packed_triangular",
179180
"batched_diag",
180181
"block_diagonal",
182+
"round",
181183
]
182184

183185

@@ -272,27 +274,18 @@ def kron_diag(*diags):
272274
return reduce(flat_outer, diags)
273275

274276

275-
def round(*args, **kwargs):
276-
"""
277-
Temporary function to silence round warning in PyTensor. Please remove
278-
when the warning disappears.
279-
"""
280-
kwargs["mode"] = "half_to_even"
281-
return pt.round(*args, **kwargs)
282-
283-
284-
def tround(*args, **kwargs):
285-
warnings.warn("tround is deprecated. Use round instead.")
286-
return round(*args, **kwargs)
287-
288-
289277
def logdiffexp(a, b):
290278
"""log(exp(a) - exp(b))"""
291279
return a + pt.log1mexp(b - a)
292280

293281

294282
def logdiffexp_numpy(a, b):
295283
"""log(exp(a) - exp(b))"""
284+
warnings.warn(
285+
"pymc.math.logdiffexp_numpy is being deprecated.",
286+
FutureWarning,
287+
stacklevel=2,
288+
)
296289
return a + log1mexp_numpy(b - a, negative_input=True)
297290

298291

@@ -341,6 +334,11 @@ def log1mexp_numpy(x, *, negative_input=False):
341334
For details, see
342335
https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
343336
"""
337+
warnings.warn(
338+
"pymc.math.log1mexp_numpy is being deprecated.",
339+
FutureWarning,
340+
stacklevel=2,
341+
)
344342
x = np.asarray(x, dtype="float")
345343

346344
if not negative_input:

tests/distributions/test_continuous.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,8 +423,11 @@ def test_kumaraswamy(self):
423423
def scipy_log_pdf(value, a, b):
424424
return np.log(a) + np.log(b) + (a - 1) * np.log(value) + (b - 1) * np.log(1 - value**a)
425425

426+
def log1mexp(x):
427+
return np.log1p(-np.exp(x)) if x < np.log(0.5) else np.log(-np.expm1(x))
428+
426429
def scipy_log_cdf(value, a, b):
427-
return pm.math.log1mexp_numpy(b * np.log1p(-(value**a)), negative_input=True)
430+
return log1mexp(b * np.log1p(-(value**a)))
428431

429432
check_logp(
430433
pm.Kumaraswamy,

tests/test_math.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -145,45 +145,46 @@ def test_log1mexp():
145145
)
146146
actual = pt.log1mexp(-vals).eval()
147147
npt.assert_allclose(actual, expected)
148+
148149
with warnings.catch_warnings():
149150
warnings.filterwarnings("ignore", "divide by zero encountered in log", RuntimeWarning)
150151
warnings.filterwarnings("ignore", "invalid value encountered in log", RuntimeWarning)
151-
actual_ = log1mexp_numpy(-vals, negative_input=True)
152+
with pytest.warns(FutureWarning, match="deprecated"):
153+
actual_ = log1mexp_numpy(-vals, negative_input=True)
152154
npt.assert_allclose(actual_, expected)
153155
# Check that input was not changed in place
154156
npt.assert_allclose(vals, vals_)
155157

156158

159+
@pytest.mark.filterwarnings("error")
157160
def test_log1mexp_numpy_no_warning():
158161
"""Assert RuntimeWarning is not raised for very small numbers"""
159-
with warnings.catch_warnings():
160-
warnings.simplefilter("error")
162+
with pytest.warns(FutureWarning, match="deprecated"):
161163
log1mexp_numpy(-1e-25, negative_input=True)
162164

163165

164166
def test_log1mexp_numpy_integer_input():
165-
assert np.isclose(log1mexp_numpy(-2, negative_input=True), pt.log1mexp(-2).eval())
167+
with pytest.warns(FutureWarning, match="deprecated"):
168+
assert np.isclose(log1mexp_numpy(-2, negative_input=True), pt.log1mexp(-2).eval())
166169

167170

171+
@pytest.mark.filterwarnings("error")
168172
def test_log1mexp_deprecation_warnings():
169-
with pytest.warns(
170-
FutureWarning,
171-
match="pymc.math.log1mexp_numpy will expect a negative input",
172-
):
173-
res_pos = log1mexp_numpy(2)
173+
with pytest.warns(FutureWarning, match="deprecated"):
174+
with pytest.warns(
175+
FutureWarning,
176+
match="pymc.math.log1mexp_numpy will expect a negative input",
177+
):
178+
res_pos = log1mexp_numpy(2)
174179

175-
with warnings.catch_warnings():
176-
warnings.simplefilter("error")
177180
res_neg = log1mexp_numpy(-2, negative_input=True)
178181

179-
with pytest.warns(
180-
FutureWarning,
181-
match="pymc.math.log1mexp will expect a negative input",
182-
):
183-
res_pos_at = log1mexp(2).eval()
182+
with pytest.warns(
183+
FutureWarning,
184+
match="pymc.math.log1mexp will expect a negative input",
185+
):
186+
res_pos_at = log1mexp(2).eval()
184187

185-
with warnings.catch_warnings():
186-
warnings.simplefilter("error")
187188
res_neg_at = log1mexp(-2, negative_input=True).eval()
188189

189190
assert np.isclose(res_pos, res_neg)
@@ -196,8 +197,8 @@ def test_logdiffexp():
196197
with warnings.catch_warnings():
197198
warnings.filterwarnings("ignore", "divide by zero encountered in log", RuntimeWarning)
198199
b = np.log([0, 1, 2, 3])
199-
200-
assert np.allclose(logdiffexp_numpy(a, b), 0)
200+
with pytest.warns(FutureWarning, match="deprecated"):
201+
assert np.allclose(logdiffexp_numpy(a, b), 0)
201202
assert np.allclose(logdiffexp(a, b).eval(), 0)
202203

203204

0 commit comments

Comments
 (0)