Skip to content

Commit e987950

Browse files
ferrine, michaelosthege, twiecki, and ricardoV94
authored
Make VI work on v4 (#4582)
* resolve merge conflicts * start fixing things * make a simple test pass * fix some more tests * fix some more tests * add scaling for VI * add shape check * aet -> at * use rvs_to_values from the model in opi.py * refactor cloning routines (fix pymc references) * Run pre-commit and include VI tests in pytest workflow (rebase) * Run pre-commit and include VI tests in pytest workflow * seems like Grouped inference not working * spot an error in a simple test case * fix the test case with grouping * fix sampling with changed shape * remove not implemented error for local inference * support inferencedata * get rid of shape error for batched mvnormal * do not support AEVB with an error message * fix some meore tests * fix some more tests * fix full rank test * fix tests * test vi * fix conversion function * propagate model * fix * fix elbo * fix elbo full rank * Fixing broken scaling with float32 * ignore a nasty test * xfail one test with float 32 * fix pre commit * fix import * fix import.1 * Update pymc/variational/opvi.py Co-authored-by: Thomas Wiecki <thomas.wiecki@gmail.com> * fix docstrings * fix error with nans * remove TODO comments * print statements to logging * revert bart test * fix pylint issues * fix test bart * fix interence_data in init * ignore pickling problems * fix aevb test * fix name error * xfail test ramdom fn * mark xfail * refactor test * xfail fix * fix xfail syntax * pytest * test fixed * 5090 fixed * do not test local flows * change model.logpt not to return float * add a test for the replacenent in the graph * fix sample node functionality * Fix test with var replacement * add uncommited changes * resolve @ricardoV94's comment about initial point * restore test_bart.py as in main branch * resolve duplicated _get_scaling function * change job order * use commit initial point in the test file * use compute initial point in the opvi.py * remove unnessesary pattern broadcast * mark test as xfail before aesara release * Do not mark anything 
but just wait for the new release * use compute_initial_point * Update pymc/variational/opvi.py Co-authored-by: Thomas Wiecki <thomas.wiecki@gmail.com> * run upgraded pre-commit * move pipe back * Update pymc/variational/opvi.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> * Update pymc/variational/opvi.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> * Update pymc/variational/opvi.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> * Add removed newline * Use compile_pymc instead of aesara.function * Replace None by empty list in output * Apply suggestions from code review Co-authored-by: Michael Osthege <michael.osthege@outlook.com> Co-authored-by: Michael Osthege <m.osthege@fz-juelich.de> Co-authored-by: Thomas Wiecki <thomas.wiecki@gmail.com> Co-authored-by: Michael Osthege <michael.osthege@outlook.com> Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> Co-authored-by: Ricardo <ricardo.vieira1994@gmail.com>
1 parent ac2b82e commit e987950

File tree

11 files changed

+237
-171
lines changed

11 files changed

+237
-171
lines changed

.github/workflows/pytest.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ jobs:
4848
--ignore=pymc/tests/test_step.py
4949
--ignore=pymc/tests/test_tuning.py
5050
--ignore=pymc/tests/test_transforms.py
51-
--ignore=pymc/tests/test_variational_inference.py
5251
--ignore=pymc/tests/test_sampling_jax.py
5352
--ignore=pymc/tests/test_dist_math.py
5453
--ignore=pymc/tests/test_minibatches.py
@@ -169,6 +168,7 @@ jobs:
169168
pymc/tests/test_distributions_random.py
170169
pymc/tests/test_distributions_moments.py
171170
pymc/tests/test_distributions_timeseries.py
171+
pymc/tests/test_variational_inference.py
172172
- |
173173
pymc/tests/test_parallel_sampling.py
174174
pymc/tests/test_sampling.py

pymc/backends/arviz.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,7 @@ def is_data(name, var) -> bool:
478478
and var not in self.model.observed_RVs
479479
and var not in self.model.free_RVs
480480
and var not in self.model.potentials
481+
and var not in self.model.value_vars
481482
and (self.observations is None or name not in self.observations)
482483
and isinstance(var, (Constant, SharedVariable))
483484
)

pymc/distributions/logprob.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414

1515
from collections.abc import Mapping
1616
from functools import singledispatch
17-
from typing import Dict, List, Optional, Union
17+
from typing import Dict, List, Optional, Sequence, Union
1818

19+
import aesara
1920
import aesara.tensor as at
2021
import numpy as np
2122

@@ -43,15 +44,17 @@ def logp_transform(op: Op):
4344
return None
4445

4546

46-
def _get_scaling(total_size, shape, ndim):
47+
def _get_scaling(total_size: Optional[Union[int, Sequence[int]]], shape, ndim: int):
4748
"""
48-
Gets scaling constant for logp
49+
Gets scaling constant for logp.
4950
5051
Parameters
5152
----------
52-
total_size: int or list[int]
53+
total_size: Optional[int|List[int]]
54+
size of a fully observed data without minibatching,
55+
`None` means data is fully observed
5356
shape: shape
54-
shape to scale
57+
shape of an observed data
5558
ndim: int
5659
ndim hint
5760
@@ -60,7 +63,7 @@ def _get_scaling(total_size, shape, ndim):
6063
scalar
6164
"""
6265
if total_size is None:
63-
coef = floatX(1)
66+
coef = 1.0
6467
elif isinstance(total_size, int):
6568
if ndim >= 1:
6669
denom = shape[0]
@@ -90,21 +93,23 @@ def _get_scaling(total_size, shape, ndim):
9093
"number of scalings is bigger that ndim, got %r" % total_size
9194
)
9295
elif (len(begin) + len(end)) == 0:
93-
return floatX(1)
96+
coef = 1.0
9497
if len(end) > 0:
9598
shp_end = shape[-len(end) :]
9699
else:
97100
shp_end = np.asarray([])
98101
shp_begin = shape[: len(begin)]
99-
begin_coef = [floatX(t) / shp_begin[i] for i, t in enumerate(begin) if t is not None]
100-
end_coef = [floatX(t) / shp_end[i] for i, t in enumerate(end) if t is not None]
102+
begin_coef = [
103+
floatX(t) / floatX(shp_begin[i]) for i, t in enumerate(begin) if t is not None
104+
]
105+
end_coef = [floatX(t) / floatX(shp_end[i]) for i, t in enumerate(end) if t is not None]
101106
coefs = begin_coef + end_coef
102107
coef = at.prod(coefs)
103108
else:
104109
raise TypeError(
105110
"Unrecognized `total_size` type, expected int or list of ints, got %r" % total_size
106111
)
107-
return at.as_tensor(floatX(coef))
112+
return at.as_tensor(coef, dtype=aesara.config.floatX)
108113

109114

110115
subtensor_types = (

pymc/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
from pymc.blocking import DictToArrayBijection, RaveledVars
5959
from pymc.data import GenTensorVariable, Minibatch
6060
from pymc.distributions import joint_logpt, logp_transform
61+
from pymc.distributions.logprob import _get_scaling
6162
from pymc.exceptions import ImputationWarning, SamplingError, ShapeError
6263
from pymc.initial_point import make_initial_point_fn
6364
from pymc.math import flatten_list
@@ -1238,6 +1239,7 @@ def register_rv(
12381239
name = self.name_for(name)
12391240
rv_var.name = name
12401241
rv_var.tag.total_size = total_size
1242+
rv_var.tag.scaling = _get_scaling(total_size, shape=rv_var.shape, ndim=rv_var.ndim)
12411243

12421244
# Associate previously unknown dimension names with
12431245
# the length of the corresponding RV dimension.
@@ -1870,7 +1872,7 @@ def Potential(name, var, model=None):
18701872
"""
18711873
model = modelcontext(model)
18721874
var.name = model.name_for(name)
1873-
var.tag.scaling = None
1875+
var.tag.scaling = 1.0
18741876
model.potentials.append(var)
18751877
model.add_random_variable(var)
18761878

pymc/sampling.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2385,7 +2385,7 @@ def init_nuts(
23852385
progressbar=progressbar,
23862386
obj_optimizer=pm.adagrad_window,
23872387
)
2388-
initial_points = list(approx.sample(draws=chains))
2388+
initial_points = list(approx.sample(draws=chains, return_inferencedata=False))
23892389
std_apoint = approx.std.eval()
23902390
cov = std_apoint**2
23912391
mean = approx.mean.get_value()
@@ -2402,7 +2402,7 @@ def init_nuts(
24022402
progressbar=progressbar,
24032403
obj_optimizer=pm.adagrad_window,
24042404
)
2405-
initial_points = list(approx.sample(draws=chains))
2405+
initial_points = list(approx.sample(draws=chains, return_inferencedata=False))
24062406
cov = approx.std.eval() ** 2
24072407
potential = quadpotential.QuadPotentialDiag(cov)
24082408
elif init == "advi_map":
@@ -2416,7 +2416,7 @@ def init_nuts(
24162416
progressbar=progressbar,
24172417
obj_optimizer=pm.adagrad_window,
24182418
)
2419-
initial_points = list(approx.sample(draws=chains))
2419+
initial_points = list(approx.sample(draws=chains, return_inferencedata=False))
24202420
cov = approx.std.eval() ** 2
24212421
potential = quadpotential.QuadPotentialDiag(cov)
24222422
elif init == "map":

0 commit comments

Comments
 (0)