Skip to content

Commit 2f9f1cd

Browse files
author
pymc-devs
committed
Make PyMC depend on logprob
1 parent 10a020d commit 2f9f1cd

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

72 files changed

+1696
-4083
lines changed

.github/workflows/tests.yml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,19 @@ jobs:
8585
pymc/tests/ode/test_utils.py
8686
pymc/tests/step_methods/hmc/test_quadpotential.py
8787
88+
- |
89+
pymc/tests/logprob/test_abstract.py
90+
pymc/tests/logprob/test_censoring.py
91+
pymc/tests/logprob/test_composite_logprob.py
92+
pymc/tests/logprob/test_cumsum.py
93+
pymc/tests/logprob/test_joint_logprob.py
94+
pymc/tests/logprob/test_mixture.py
95+
pymc/tests/logprob/test_rewriting.py
96+
pymc/tests/logprob/test_scan.py
97+
pymc/tests/logprob/test_tensor.py
98+
pymc/tests/logprob/test_transforms.py
99+
pymc/tests/logprob/test_utils.py
100+
88101
fail-fast: false
89102
runs-on: ${{ matrix.os }}
90103
env:

.pre-commit-config.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ repos:
8585
(?x)(arviz-devs.github.io|
8686
python.arviz.org|
8787
aesara.readthedocs.io|
88-
aeppl.readthedocs.io|
8988
pymc-experimental.readthedocs.io|
9089
docs.pymc.io|
9190
www.pymc.io|

conda-envs/environment-dev.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ channels:
55
- defaults
66
dependencies:
77
# Base dependencies
8-
- aeppl=0.0.38
98
- aesara=2.8.8
109
- arviz>=0.13.0
1110
- blas
@@ -41,3 +40,4 @@ dependencies:
4140
- types-cachetools
4241
- pip:
4342
- git+https://github.com/pymc-devs/pymc-sphinx-theme
43+
- numdifftools>=0.9.40

conda-envs/environment-test.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ channels:
55
- defaults
66
dependencies:
77
# Base dependencies
8-
- aeppl=0.0.38
98
- aesara=2.8.8
109
- arviz>=0.13.0
1110
- blas
@@ -29,3 +28,5 @@ dependencies:
2928
- pytest>=3.0
3029
- mypy=0.990
3130
- types-cachetools
31+
- pip:
32+
- numdifftools>=0.9.40

conda-envs/windows-environment-dev.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ channels:
55
- defaults
66
dependencies:
77
# Base dependencies (see install guide for Windows)
8-
- aeppl=0.0.38
98
- aesara=2.8.8
109
- arviz>=0.13.0
1110
- blas
@@ -38,3 +37,4 @@ dependencies:
3837
- types-cachetools
3938
- pip:
4039
- git+https://github.com/pymc-devs/pymc-sphinx-theme
40+
- numdifftools>=0.9.40

conda-envs/windows-environment-test.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ channels:
55
- defaults
66
dependencies:
77
# Base dependencies (see install guide for Windows)
8-
- aeppl=0.0.38
98
- aesara=2.8.8
109
- arviz>=0.13.0
1110
- blas
@@ -30,3 +29,5 @@ dependencies:
3029
- pytest>=3.0
3130
- mypy=0.990
3231
- types-cachetools
32+
- pip:
33+
- numdifftools>=0.9.40

docs/source/conf.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,6 @@
188188
intersphinx_mapping = {
189189
"arviz": ("https://python.arviz.org/en/latest/", None),
190190
"aesara": ("https://aesara.readthedocs.io/en/latest/", None),
191-
"aeppl": ("https://aeppl.readthedocs.io/en/latest/", None),
192191
"home": ("https://www.pymc.io", None),
193192
"pmx": ("https://www.pymc.io/projects/experimental/en/latest", None),
194193
"numpy": ("https://numpy.org/doc/stable/", None),

docs/source/learn/core_notebooks/GLM_linear.ipynb

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -522,14 +522,15 @@
522522
"source": [
523523
"%load_ext watermark\n",
524524
"\n",
525-
"%watermark -n -u -v -iv -w -p aesara,aeppl"
525+
"%watermark -n -u -v -iv -w -p aesara"
526526
]
527527
}
528528
],
529529
"metadata": {
530530
"anaconda-cloud": {},
531+
"hide_input": false,
531532
"kernelspec": {
532-
"display_name": "Python 3 (ipykernel)",
533+
"display_name": "Python 3",
533534
"language": "python",
534535
"name": "python3"
535536
},
@@ -543,14 +544,27 @@
543544
"name": "python",
544545
"nbconvert_exporter": "python",
545546
"pygments_lexer": "ipython3",
546-
"version": "3.9.7"
547+
"version": "3.8.10"
547548
},
548549
"latex_envs": {
549550
"bibliofile": "biblio.bib",
550551
"cite_by": "apalike",
551552
"current_citInitial": 1,
552553
"eqLabelWithNumbers": true,
553554
"eqNumInitial": 0
555+
},
556+
"toc": {
557+
"base_numbering": 1,
558+
"nav_menu": {},
559+
"number_sections": true,
560+
"sideBar": true,
561+
"skip_h1_title": false,
562+
"title_cell": "Table of Contents",
563+
"title_sidebar": "Contents",
564+
"toc_cell": false,
565+
"toc_position": {},
566+
"toc_section_display": true,
567+
"toc_window_display": false
554568
}
555569
},
556570
"nbformat": 4,

docs/source/learn/core_notebooks/model_comparison.ipynb

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -536,16 +536,17 @@
536536
}
537537
],
538538
"source": [
539-
"%watermark -n -u -v -iv -w -p xarray,aesara,aeppl"
539+
"%watermark -n -u -v -iv -w -p xarray,aesara"
540540
]
541541
}
542542
],
543543
"metadata": {
544+
"hide_input": false,
544545
"interpreter": {
545546
"hash": "baf205d70af30bf8b721a304f5a44beb31bf8af014f6b7340f1a7ae004926653"
546547
},
547548
"kernelspec": {
548-
"display_name": "Python 3 (ipykernel)",
549+
"display_name": "Python 3",
549550
"language": "python",
550551
"name": "python3"
551552
},
@@ -559,7 +560,20 @@
559560
"name": "python",
560561
"nbconvert_exporter": "python",
561562
"pygments_lexer": "ipython3",
562-
"version": "3.9.7"
563+
"version": "3.8.10"
564+
},
565+
"toc": {
566+
"base_numbering": 1,
567+
"nav_menu": {},
568+
"number_sections": true,
569+
"sideBar": true,
570+
"skip_h1_title": false,
571+
"title_cell": "Table of Contents",
572+
"title_sidebar": "Contents",
573+
"toc_cell": false,
574+
"toc_position": {},
575+
"toc_section_display": true,
576+
"toc_window_display": false
563577
}
564578
},
565579
"nbformat": 4,

docs/source/learn/core_notebooks/posterior_predictive.ipynb

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4649,14 +4649,15 @@
46494649
],
46504650
"source": [
46514651
"%load_ext watermark\n",
4652-
"%watermark -n -u -v -iv -w -p aesara,aeppl"
4652+
"%watermark -n -u -v -iv -w -p aesara"
46534653
]
46544654
}
46554655
],
46564656
"metadata": {
46574657
"anaconda-cloud": {},
4658+
"hide_input": false,
46584659
"kernelspec": {
4659-
"display_name": "Python 3.9.13 ('pymc-dev-py39')",
4660+
"display_name": "Python 3",
46604661
"language": "python",
46614662
"name": "python3"
46624663
},
@@ -4670,7 +4671,20 @@
46704671
"name": "python",
46714672
"nbconvert_exporter": "python",
46724673
"pygments_lexer": "ipython3",
4673-
"version": "3.9.13"
4674+
"version": "3.8.10"
4675+
},
4676+
"toc": {
4677+
"base_numbering": 1,
4678+
"nav_menu": {},
4679+
"number_sections": true,
4680+
"sideBar": true,
4681+
"skip_h1_title": false,
4682+
"title_cell": "Table of Contents",
4683+
"title_sidebar": "Contents",
4684+
"toc_cell": false,
4685+
"toc_position": {},
4686+
"toc_section_display": true,
4687+
"toc_window_display": false
46744688
},
46754689
"vscode": {
46764690
"interpreter": {

docs/source/learn/core_notebooks/pymc_aesara.ipynb

Lines changed: 140 additions & 45 deletions
Large diffs are not rendered by default.

docs/source/learn/core_notebooks/pymc_overview.ipynb

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4340,14 +4340,15 @@
43404340
],
43414341
"source": [
43424342
"%load_ext watermark\n",
4343-
"%watermark -n -u -v -iv -w -p xarray,aeppl"
4343+
"%watermark -n -u -v -iv -w -p xarray"
43444344
]
43454345
}
43464346
],
43474347
"metadata": {
43484348
"anaconda-cloud": {},
4349+
"hide_input": false,
43494350
"kernelspec": {
4350-
"display_name": "Python 3 (ipykernel)",
4351+
"display_name": "Python 3",
43514352
"language": "python",
43524353
"name": "python3"
43534354
},
@@ -4361,7 +4362,20 @@
43614362
"name": "python",
43624363
"nbconvert_exporter": "python",
43634364
"pygments_lexer": "ipython3",
4364-
"version": "3.8.5"
4365+
"version": "3.8.10"
4366+
},
4367+
"toc": {
4368+
"base_numbering": 1,
4369+
"nav_menu": {},
4370+
"number_sections": true,
4371+
"sideBar": true,
4372+
"skip_h1_title": false,
4373+
"title_cell": "Table of Contents",
4374+
"title_sidebar": "Contents",
4375+
"toc_cell": false,
4376+
"toc_position": {},
4377+
"toc_section_display": true,
4378+
"toc_window_display": false
43654379
}
43664380
},
43674381
"nbformat": 4,

pymc/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __set_compiler_flags():
5454
from pymc.distributions import *
5555
from pymc.exceptions import *
5656
from pymc.func_utils import find_constrained_prior
57+
from pymc.logprob import *
5758
from pymc.math import (
5859
expand_packed_triangular,
5960
invlogit,

pymc/aesaraf.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@
3232
import pandas as pd
3333
import scipy.sparse as sps
3434

35-
from aeppl.logprob import CheckParameterValue
36-
from aeppl.transforms import RVTransform
3735
from aesara import scalar
3836
from aesara.compile import Function, Mode, get_mode
3937
from aesara.gradient import grad
@@ -65,6 +63,8 @@
6563
from aesara.tensor.var import TensorConstant, TensorVariable
6664

6765
from pymc.exceptions import NotConstantValueError
66+
from pymc.logprob.transforms import RVTransform
67+
from pymc.logprob.utils import CheckParameterValue
6868
from pymc.vartypes import continuous_types, isgenerator, typefilter
6969

7070
PotentialShapeType = Union[int, np.ndarray, Sequence[Union[int, Variable]], TensorVariable]
@@ -947,7 +947,7 @@ def largest_common_dtype(tensors):
947947

948948
@node_rewriter(tracks=[CheckParameterValue])
949949
def local_remove_check_parameter(fgraph, node):
950-
"""Rewrite that removes Aeppl's CheckParameterValue
950+
"""Rewrite that removes CheckParameterValue
951951
952952
This is used when compile_rv_inplace
953953
"""
@@ -1071,13 +1071,13 @@ def compile_pymc(
10711071
Ensures that compiled functions containing random variables will produce new
10721072
samples on each call.
10731073
local_check_parameter_to_ninf_switch
1074-
Replaces Aeppl's CheckParameterValue assertions is logp expressions with Switches
1074+
Replaces CheckParameterValue assertions is logp expressions with Switches
10751075
that return -inf in place of the assert.
10761076
10771077
Optional rewrites
10781078
-----------------
10791079
local_remove_check_parameter
1080-
Replaces Aeppl's CheckParameterValue assertions is logp expressions. This is used
1080+
Replaces CheckParameterValue assertions is logp expressions. This is used
10811081
as an alteranative to the default local_check_parameter_to_ninf_switch whenenver
10821082
this function is called within a model context and the model `check_bounds` flag
10831083
is set to False.

pymc/distributions/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from pymc.distributions.logprob import ( # isort:skip
1616
logcdf,
1717
logp,
18-
joint_logp,
1918
)
2019

2120
from pymc.distributions.bound import Bound
@@ -198,7 +197,6 @@
198197
"Censored",
199198
"CAR",
200199
"PolyaGamma",
201-
"joint_logp",
202200
"logp",
203201
"logcdf",
204202
]

pymc/distributions/censored.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
class CensoredRV(SymbolicRandomVariable):
3030
"""Censored random variable"""
3131

32-
inline_aeppl = True
32+
inline_logprob = True
3333
_print_name = ("Censored", "\\operatorname{Censored}")
3434

3535

pymc/distributions/continuous.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import aesara.tensor as at
2929
import numpy as np
3030

31-
from aeppl.logprob import _logprob, logcdf, logprob
3231
from aesara.graph.basic import Apply, Variable
3332
from aesara.graph.op import Op
3433
from aesara.raise_op import Assert
@@ -57,6 +56,8 @@
5756
from aesara.tensor.random.op import RandomVariable
5857
from aesara.tensor.var import TensorConstant
5958

59+
from pymc.logprob.abstract import _logprob, logcdf, logprob
60+
6061
try:
6162
from polyagamma import polyagamma_cdf, polyagamma_pdf, random_polyagamma
6263
except ImportError: # pragma: no cover
@@ -531,6 +532,9 @@ def logcdf(value, mu, sigma):
531532
msg="sigma > 0",
532533
)
533534

535+
def icdf(value, mu, sigma):
536+
return mu + sigma * -np.sqrt(2.0) * at.erfcinv(2 * value)
537+
534538

535539
class TruncatedNormalRV(RandomVariable):
536540
name = "truncated_normal"
@@ -1290,10 +1294,6 @@ class Exponential(PositiveContinuous):
12901294
Variance :math:`\dfrac{1}{\lambda^2}`
12911295
======== ============================
12921296
1293-
Notes
1294-
-----
1295-
Logp calculation is defined in `aeppl.logprob <https://github.com/aesara-devs/aeppl/blob/main/aeppl/logprob.py/>`_.
1296-
12971297
Parameters
12981298
----------
12991299
lam : tensor_like of float

pymc/distributions/discrete.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,9 @@ def logcdf(value, p):
819819
msg="0 <= p <= 1",
820820
)
821821

822+
def icdf(value, p):
823+
return at.ceil(at.log1p(-value) / at.log1p(-p)).astype("int64")
824+
822825

823826
class HyperGeometric(Discrete):
824827
R"""

pymc/distributions/dist_math.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import scipy.linalg
2828
import scipy.stats
2929

30-
from aeppl.logprob import CheckParameterValue
3130
from aesara.compile.builders import OpFromGraph
3231
from aesara.graph.basic import Apply, Variable
3332
from aesara.graph.op import Op
@@ -38,6 +37,7 @@
3837

3938
from pymc.aesaraf import floatX
4039
from pymc.distributions.shape_utils import to_tuple
40+
from pymc.logprob.utils import CheckParameterValue
4141

4242
solve_lower = SolveTriangular(lower=True)
4343
solve_upper = SolveTriangular(lower=False)

0 commit comments

Comments
 (0)