Skip to content

Commit 4856e22

Browse files
committed
Cleanup root namespace
1 parent 419af06 commit 4856e22

File tree

12 files changed

+18
-91
lines changed

12 files changed

+18
-91
lines changed

pymc/__init__.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -46,36 +46,19 @@ def __set_compiler_flags():
4646

4747
__set_compiler_flags()
4848

49-
from pymc import _version, gp, ode, sampling
50-
from pymc.backends import *
51-
from pymc.blocking import *
49+
from pymc import _version, gp, ode, plots, sampling, stats
5250
from pymc.data import *
5351
from pymc.distributions import *
54-
from pymc.exceptions import *
5552
from pymc.func_utils import find_constrained_prior
5653
from pymc.logprob import *
57-
from pymc.math import (
58-
expand_packed_triangular,
59-
invlogit,
60-
invprobit,
61-
logaddexp,
62-
logit,
63-
logsumexp,
64-
probit,
65-
)
6654
from pymc.model.core import *
6755
from pymc.model.transform.conditioning import do, observe
6856
from pymc.model_graph import model_to_graphviz, model_to_networkx
69-
from pymc.plots import *
70-
from pymc.printing import *
7157
from pymc.pytensorf import *
7258
from pymc.sampling import *
7359
from pymc.smc import *
74-
from pymc.stats import *
7560
from pymc.step_methods import *
7661
from pymc.tuning import *
77-
from pymc.util import drop_warning_stat
7862
from pymc.variational import *
79-
from pymc.vartypes import *
8063

8164
__version__ = _version.get_versions()["version"]

pymc/blocking.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,6 @@
3737

3838
from typing_extensions import TypeAlias
3939

40-
__all__ = ["DictToArrayBijection"]
41-
42-
4340
T = TypeVar("T")
4441
PointType: TypeAlias = Dict[str, np.ndarray]
4542
StatsDict: TypeAlias = Dict[str, Any]

pymc/data.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040

4141
__all__ = [
4242
"get_data",
43-
"GeneratorAdapter",
4443
"Minibatch",
4544
"Data",
4645
"ConstantData",

pymc/exceptions.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414

1515
__all__ = [
1616
"SamplingError",
17-
"IncorrectArgumentsError",
18-
"TraceDirectoryError",
1917
"ImputationWarning",
2018
"ShapeWarning",
2119
"ShapeError",
@@ -26,16 +24,6 @@ class SamplingError(RuntimeError):
2624
pass
2725

2826

29-
class IncorrectArgumentsError(ValueError):
30-
pass
31-
32-
33-
class TraceDirectoryError(ValueError):
34-
"""Error from trying to load a trace from an incorrectly-structured directory,"""
35-
36-
pass
37-
38-
3927
class ImputationWarning(UserWarning):
4028
"""Warning that there are missing values that will be imputed."""
4129

pymc/plots/__init__.py

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
"""
2121
import functools
2222
import sys
23-
import warnings
2423

2524
import arviz as az
2625

@@ -29,40 +28,3 @@
2928
obj = getattr(az.plots, attr)
3029
if not attr.startswith("__"):
3130
setattr(sys.modules[__name__], attr, obj)
32-
33-
34-
def alias_deprecation(func, alias: str):
35-
original = func.__name__
36-
37-
@functools.wraps(func)
38-
def wrapped(*args, **kwargs):
39-
raise FutureWarning(
40-
f"The function `{alias}` from PyMC was an alias for `{original}` from ArviZ. "
41-
"It was removed in PyMC 4.0. "
42-
f"Switch to `pymc.{original}` or `arviz.{original}`."
43-
)
44-
45-
return wrapped
46-
47-
48-
# Aliases of ArviZ functions
49-
autocorrplot = alias_deprecation(az.plot_autocorr, alias="autocorrplot")
50-
forestplot = alias_deprecation(az.plot_forest, alias="forestplot")
51-
kdeplot = alias_deprecation(az.plot_kde, alias="kdeplot")
52-
energyplot = alias_deprecation(az.plot_energy, alias="energyplot")
53-
densityplot = alias_deprecation(az.plot_density, alias="densityplot")
54-
pairplot = alias_deprecation(az.plot_pair, alias="pairplot")
55-
traceplot = alias_deprecation(az.plot_trace, alias="traceplot")
56-
compareplot = alias_deprecation(az.plot_compare, alias="compareplot")
57-
58-
59-
__all__ = tuple(az.plots.__all__) + (
60-
"autocorrplot",
61-
"compareplot",
62-
"forestplot",
63-
"kdeplot",
64-
"traceplot",
65-
"energyplot",
66-
"densityplot",
67-
"pairplot",
68-
)

pymc/pytensorf.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,16 +76,9 @@
7676
"hessian",
7777
"hessian_diag",
7878
"inputvars",
79-
"cont_inputs",
8079
"floatX",
8180
"intX",
82-
"smartfloatX",
8381
"jacobian",
84-
"CallableTensor",
85-
"join_nonshared_inputs",
86-
"make_shared_replacements",
87-
"generator",
88-
"convert_observed_data",
8982
"compile_pymc",
9083
]
9184

pymc/sampling/forward.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,11 @@
5353
from pytensor.tensor.sharedvar import SharedVariable
5454
from typing_extensions import TypeAlias
5555

56-
import pymc as pm
57-
58-
from pymc.backends.arviz import _DefaultTrace
56+
from pymc.backends.arviz import (
57+
_DefaultTrace,
58+
predictions_to_inference_data,
59+
to_inference_data,
60+
)
5961
from pymc.backends.base import MultiTrace
6062
from pymc.blocking import PointType
6163
from pymc.model import Model, modelcontext
@@ -438,7 +440,7 @@ def sample_prior_predictive(
438440
ikwargs: Dict[str, Any] = dict(model=model)
439441
if idata_kwargs:
440442
ikwargs.update(idata_kwargs)
441-
return pm.to_inference_data(prior=prior, **ikwargs)
443+
return to_inference_data(prior=prior, **ikwargs)
442444

443445

444446
def sample_posterior_predictive(
@@ -669,8 +671,8 @@ def sample_posterior_predictive(
669671
if extend_inferencedata:
670672
ikwargs.setdefault("idata_orig", idata)
671673
ikwargs.setdefault("inplace", True)
672-
return pm.predictions_to_inference_data(ppc_trace, **ikwargs)
673-
idata_pp = pm.to_inference_data(posterior_predictive=ppc_trace, **ikwargs)
674+
return predictions_to_inference_data(ppc_trace, **ikwargs)
675+
idata_pp = to_inference_data(posterior_predictive=ppc_trace, **ikwargs)
674676

675677
if extend_inferencedata and idata is not None:
676678
idata.extend(idata_pp)

pymc/sampling/mcmc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
coords_and_dims_for_inferencedata,
5353
find_constants,
5454
find_observations,
55+
to_inference_data,
5556
)
5657
from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains
5758
from pymc.blocking import DictToArrayBijection
@@ -892,7 +893,7 @@ def _sample_return(
892893
if compute_convergence_checks or return_inferencedata:
893894
ikwargs: Dict[str, Any] = dict(model=model, save_warmup=not discard_tuned_samples)
894895
ikwargs.update(idata_kwargs)
895-
idata = pm.to_inference_data(mtrace, **ikwargs)
896+
idata = to_inference_data(mtrace, **ikwargs)
896897

897898
if compute_convergence_checks:
898899
warns = run_convergence_checks(idata, model)

pymc/stats/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,3 @@
2828
setattr(sys.modules[__name__], attr, obj)
2929

3030
from pymc.stats.log_likelihood import compute_log_likelihood
31-
32-
__all__ = ("compute_log_likelihood",) + tuple(az.stats.__all__)

pymc/step_methods/metropolis.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
compile_pymc,
3232
floatX,
3333
join_nonshared_inputs,
34+
make_shared_replacements,
3435
replace_rng_nodes,
3536
)
3637
from pymc.step_methods.arraystep import (
@@ -804,7 +805,7 @@ def __init__(
804805

805806
self.mode = mode
806807

807-
shared = pm.make_shared_replacements(initial_values, vars, model)
808+
shared = make_shared_replacements(initial_values, vars, model)
808809
self.delta_logp = delta_logp(initial_values, model.logp(), vars, shared)
809810
super().__init__(vars, shared)
810811

@@ -960,7 +961,7 @@ def __init__(
960961

961962
self.mode = mode
962963

963-
shared = pm.make_shared_replacements(initial_values, vars, model)
964+
shared = make_shared_replacements(initial_values, vars, model)
964965
self.delta_logp = delta_logp(initial_values, model.logp(), vars, shared)
965966
super().__init__(vars, shared)
966967

pymc/tuning/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,5 @@
1414

1515
from pymc.tuning.scaling import find_hessian, guess_scaling, trace_cov
1616
from pymc.tuning.starting import find_MAP
17+
18+
__all__ = ("find_MAP", "find_hessian")

pymc/variational/opvi.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464

6565
import pymc as pm
6666

67+
from pymc.backends import to_inference_data
6768
from pymc.backends.base import MultiTrace
6869
from pymc.backends.ndarray import NDArray
6970
from pymc.blocking import DictToArrayBijection
@@ -1578,7 +1579,7 @@ def sample(
15781579
if not return_inferencedata:
15791580
return multi_trace
15801581
else:
1581-
return pm.to_inference_data(multi_trace, model=self.model, **kwargs)
1582+
return to_inference_data(multi_trace, model=self.model, **kwargs)
15821583

15831584
@property
15841585
def ndim(self):

0 commit comments

Comments
 (0)