Skip to content

BUG: MvNormal with minibatch ADVI #6461

Open
@sina-mansour

Description

@sina-mansour

Describe the issue:

So I am trying to implement a MvNormal where the covariance itself is hierarchically dependent on other priors. I have also sort of described the issue in the discourse

However, trying to run a minibatch advi on the MvNormal likelihood produces an error.

I have provided a code that could reproduce the error bellow:

Reproduceable code example:

import pymc as pm
import aesara.tensor as at

c1 = np.random.uniform(size = 1000)

mus = np.array([1, 2])
sigmas = np.diag([2, 3])

vs = []
for x in range(1000):
    c = c1[x]/2
    rho = np.array([[1, c], [c, 1]])
    cov = sigmas.dot(rho.dot(sigmas))
    vs.append(np.random.multivariate_normal(mus, cov, size=1))

vs = np.concatenate(vs)

v1 = vs[:, 0]
v2 = vs[:, 1]

# first, standardize all variables to center the distributions around zero
v1_standardized = (v1 - v1.mean()) / v1.std()
v2_standardized = (v2 - v2.mean()) / v2.std()
vs_standardized = np.concatenate([v1_standardized[:, np.newaxis], v2_standardized[:, np.newaxis]], axis=1)
c1_standardized = (c1 - c1.mean()) / c1.std()

# Splines to model nonlinear effects of c1
# number of spline nuts (could be tuned)
num_knots = 3
knot_list = np.quantile(c1_standardized, np.linspace(0, 1, num_knots))
# create b spline basis for regression using patsy
B_spline_c1 = patsy.dmatrix(
    "bs(c1_standardized, knots=knots, degree=3, include_intercept=True) - 1",
    {"c1_standardized": c1_standardized, "knots": knot_list[1:-1]},
)

coords = {
    "splines": np.arange(B_spline_age.shape[1]),
    "obs_id": np.arange(len(v1_standardized)),
}

advi_model_cov = pm.Model(coords=coords)

with advi_model_cov:
    # minibatch variables
    c1_standardized_t = pm.Minibatch(c1_standardized, 100,)
    B_spline_c1_t = pm.Minibatch(B_spline_c1, 100)
    vs_standardized_t = pm.Minibatch(vs_standardized, 100)

    # Priors (for covariance)

    # c1 splines
    w_c1_rho = pm.Normal("w_c1_cov", mu=0, sigma=10, size=B_spline_c1.shape[1], dims="splines")

    # Estimated covariance
    rho_est = pm.Deterministic("cov_est", 2 * pm.math.sigmoid(pm.math.dot(B_spline_c1_t, w_c1_rho.T)) - 1 )

    # Priors (constant priors, already given)

    # Estimated means
    est_v1 = pm.MutableData("est_v1", np.repeat(mus[0], 100))
    est_v2 = pm.MutableData("est_v2", np.repeat(mus[1], 100))

    # Model variance

    # Variance estimate
    var_v1 = pm.MutableData("var_v1", np.repeat(sigmas[0, 0], 100))
    var_v2 = pm.MutableData("var_v2", np.repeat(sigmas[1, 1], 100))

    # Construct the mean vector and covariance matrix for MvNormal to fit a bivariate normal
    bivariate_mu = pm.Deterministic("bivariate_mu", at.as_tensor_variable([est_v1, est_v2]).T)
    cholesky_decomposition = pm.Deterministic("cholesky_decomposition", at.as_tensor_variable([est_v1, at.math.mul(est_v2, rho_est), at.math.mul(est_v2, at.math.sqrt(1 - rho_est**2)),]).T)

    # Likelihood estimation from a bivariate normal with known mean and variance, but unknown covariance
    likelihood = pm.MvNormal(
        "likelihood",
        mu=bivariate_mu,
        chol=cholesky_decomposition,
        observed=vs_standardized_t,
        total_size=len(v1_standardized),
    )
    
    # run ADVI with minibatch
    approx_cov = pm.fit(100000, callbacks=[pm.callbacks.CheckParametersConvergence(tolerance=1e-4)])

    # sample from trace
    advi_model_idata_cov = pm.sample_prior_predictive()
    advi_model_idata_cov.extend(approx_cov.sample(2000))
    pm.sample_posterior_predictive(advi_model_idata_cov, extend_inferencedata=True)

Error message:

<details>
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
File /mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py:971, in Function.__call__(self, *args, **kwargs)
    969 try:
    970     outputs = (
--> 971         self.vm()
    972         if output_subset is None
    973         else self.vm(output_subset=output_subset)
    974     )
    975 except Exception:

AssertionError: Could not broadcast dimensions

During handling of the above exception, another exception occurred:

AssertionError                            Traceback (most recent call last)
Cell In [236], line 84
     75 likelihood = pm.MvNormal(
     76     "likelihood",
     77     mu=bivariate_mu,
   (...)
     80     total_size=len(v1_standardized),
     81 )
     83 # run ADVI with minibatch
---> 84 approx_cov = pm.fit(100000, callbacks=[pm.callbacks.CheckParametersConvergence(tolerance=1e-4)])
     86 # sample from trace
     87 advi_model_idata_cov = pm.sample_prior_predictive()

File /mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/pymc/variational/inference.py:753, in fit(n, method, model, random_seed, start, start_sigma, inf_kwargs, **kwargs)
    751 else:
    752     raise TypeError(f"method should be one of {set(_select.keys())} or Inference instance")
--> 753 return inference.fit(n, **kwargs)

File /mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/pymc/variational/inference.py:144, in Inference.fit(self, n, score, callbacks, progressbar, **kwargs)
    142     progress = range(n)
    143 if score:
--> 144     state = self._iterate_with_loss(0, n, step_func, progress, callbacks)
    145 else:
    146     state = self._iterate_without_loss(0, n, step_func, progress, callbacks)

File /mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/pymc/variational/inference.py:204, in Inference._iterate_with_loss(self, s, n, step_func, progress, callbacks)
    202 try:
    203     for i in progress:
--> 204         e = step_func()
    205         if np.isnan(e):
    206             scores = scores[:i]

File /mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py:984, in Function.__call__(self, *args, **kwargs)
    982     if hasattr(self.vm, "thunks"):
    983         thunk = self.vm.thunks[self.vm.position_of_error]
--> 984     raise_with_op(
    985         self.maker.fgraph,
    986         node=self.vm.nodes[self.vm.position_of_error],
    987         thunk=thunk,
    988         storage_map=getattr(self.vm, "storage_map", None),
    989     )
    990 else:
    991     # old-style linkers raise their own exceptions
    992     raise

File /mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/link/utils.py:534, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    529     warnings.warn(
    530         f"{exc_type} error does not allow us to add an extra error message"
    531     )
    532     # Some exception need extra parameter in inputs. So forget the
    533     # extra long error message in that case.
--> 534 raise exc_value.with_traceback(exc_trace)

File /mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py:971, in Function.__call__(self, *args, **kwargs)
    968 t0_fn = time.time()
    969 try:
    970     outputs = (
--> 971         self.vm()
    972         if output_subset is None
    973         else self.vm(output_subset=output_subset)
    974     )
    975 except Exception:
    976     restore_defaults()

AssertionError: Could not broadcast dimensions
Apply node that caused the error: Assert{msg=Could not broadcast dimensions}(Abs.0, AND.0)
Toposort index: 109
Inputs types: [ScalarType(int64), ScalarType(bool)]
Inputs shapes: [(), ()]
Inputs strides: [(), ()]
Inputs values: [100, False]
Outputs clients: [[TensorFromScalar(Assert{msg=Could not broadcast dimensions}.0)]]

Backtrace when the node is created (use Aesara flag traceback__limit=N to make it longer):
  File "/mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/pymc/distributions/distribution.py", line 290, in __new__
    rv_out = cls.dist(*args, **kwargs)
  File "/mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/pymc/distributions/multivariate.py", line 264, in dist
    mu = at.broadcast_arrays(mu, cov[..., -1])[0]
  File "/mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/tensor/extra_ops.py", line 1772, in broadcast_arrays
    return tuple(broadcast_to(a, broadcast_shape(*args)) for a in args)
  File "/mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/tensor/extra_ops.py", line 1772, in <genexpr>
    return tuple(broadcast_to(a, broadcast_shape(*args)) for a in args)
  File "/mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/tensor/extra_ops.py", line 1459, in broadcast_shape
    return broadcast_shape_iter(arrays, **kwargs)
  File "/mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/tensor/extra_ops.py", line 1596, in broadcast_shape_iter
    bcast_dim = assert_dim(dim_max, assert_cond)
  File "/mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/graph/op.py", line 297, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/raise_op.py", line 92, in make_node
    [value.type()],

HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
</details>

PyMC version information:

4.2.2

Context for the issue:

Further description of what I want to do is here

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions