Skip to content

BUG: ZeroSumTransform fails with initvalues #7772

Open
@velochy

Description

@velochy

Describe the issue:

Setting `initvals` on variables transformed with `ZeroSumTransform` fails with an `AttributeError` raised during a type cast (`'int' object has no attribute 'astype'`).

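It seems to be caused by the input being a NumPy array rather than a PyTensor variable: for a NumPy array, `array.shape[axis]` is a plain Python `int`, so the `.astype("floatX")` call in `ZeroSumTransform.extend_axis` fails.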

The fix seems simple; I will post a PR for it next.

Reproducible code example:

import pymc as pm, numpy as np

with pm.Model() as model:
    pm.ZeroSumNormal('zsn',shape=(10,))
    pm.Normal('n', shape=(10,), transform=pm.distributions.transforms.ZeroSumTransform(zerosum_axes=[0]))
    mp = pm.find_MAP()

    pm.sample(initvals=mp)

Error message:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/home/velochy/salk/sandbox/sandy.ipynb Cell 1 line 8
      5 pm.Normal('n', shape=(10,), transform=pm.distributions.transforms.ZeroSumTransform(zerosum_axes=[0]))
      6 mp = pm.find_MAP()
----> 8 pm.sample(initvals=mp)

File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/mcmc.py:832, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
    830         [kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()]
    831     with joined_blas_limiter():
--> 832         initial_points, step = init_nuts(
    833             init=init,
    834             chains=chains,
    835             n_init=n_init,
    836             model=model,
    837             random_seed=random_seed_list,
    838             progressbar=progress_bool,
    839             jitter_max_retries=jitter_max_retries,
    840             tune=tune,
    841             initvals=initvals,
    842             compile_kwargs=compile_kwargs,
    843             **kwargs,
    844         )
    845 else:
    846     # Get initial points
    847     ipfns = make_initial_point_fns_per_chain(
    848         model=model,
    849         overrides=initvals,
    850         jitter_rvs=set(),
    851         chains=chains,
    852     )

File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/mcmc.py:1605, in init_nuts(init, chains, n_init, model, random_seed, progressbar, jitter_max_retries, tune, initvals, compile_kwargs, **kwargs)
   1602     q, _ = DictToArrayBijection.map(ip)
   1603     return logp_dlogp_func([q], extra_vars={})[0]
-> 1605 initial_points = _init_jitter(
   1606     model,
   1607     initvals,
   1608     seeds=random_seed_list,
   1609     jitter="jitter" in init,
   1610     jitter_max_retries=jitter_max_retries,
   1611     logp_fn=model_logp_fn,
   1612 )
   1614 apoints = [DictToArrayBijection.map(point) for point in initial_points]
   1615 apoints_data = [apoint.data for apoint in apoints]

File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/mcmc.py:1462, in _init_jitter(model, initvals, seeds, jitter, jitter_max_retries, logp_fn)
   1432 def _init_jitter(
   1433     model: Model,
   1434     initvals: StartDict | Sequence[StartDict | None] | None,
   (...)
   1438     logp_fn: Callable[[PointType], np.ndarray] | None = None,
   1439 ) -> list[PointType]:
   1440     """Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain.
   1441 
   1442     ``model.check_start_vals`` is used to test whether the jittered starting
   (...)
   1460         List of starting points for the sampler
   1461     """
-> 1462     ipfns = make_initial_point_fns_per_chain(
   1463         model=model,
   1464         overrides=initvals,
   1465         jitter_rvs=set(model.free_RVs) if jitter else set(),
   1466         chains=len(seeds),
   1467     )
   1469     if not jitter:
   1470         return [ipfn(seed) for ipfn, seed in zip(ipfns, seeds)]

File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/initial_point.py:101, in make_initial_point_fns_per_chain(model, overrides, jitter_rvs, chains)
     72 """Create an initial point function for each chain, as defined by initvals.
     73 
     74 If a single initval dictionary is passed, the function is replicated for each
   (...)
     95 
     96 """
     97 if isinstance(overrides, dict) or overrides is None:
     98     # One strategy for all chains
     99     # Only one function compilation is needed.
    100     ipfns = [
--> 101         make_initial_point_fn(
    102             model=model,
    103             overrides=overrides,
    104             jitter_rvs=jitter_rvs,
    105             return_transformed=True,
    106         )
    107     ] * chains
    108 elif len(overrides) == chains:
    109     ipfns = [
    110         make_initial_point_fn(
    111             model=model,
   (...)
    116         for chain_overrides in overrides
    117     ]

File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/initial_point.py:152, in make_initial_point_fn(model, overrides, jitter_rvs, default_strategy, return_transformed)
    126 def make_initial_point_fn(
    127     *,
    128     model,
   (...)
    132     return_transformed: bool = True,
    133 ) -> Callable[[SeedSequenceSeed], PointType]:
    134     """Create seeded function that computes initial values for all free model variables.
    135 
    136     Parameters
   (...)
    150     initial_point_fn : Callable[[SeedSequenceSeed], dict[str, np.ndarray]]
    151     """
--> 152     sdict_overrides = convert_str_to_rv_dict(model, overrides or {})
    153     initval_strats = {
    154         **model.rvs_to_initial_values,
    155         **sdict_overrides,
    156     }
    158     initial_values = make_initial_point_expression(
    159         free_rvs=model.free_RVs,
    160         rvs_to_transforms=model.rvs_to_transforms,
   (...)
    164         return_transformed=return_transformed,
    165     )

File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/initial_point.py:57, in convert_str_to_rv_dict(model, start)
     55 if is_transformed_name(key):
     56     rv = model[get_untransformed_name(key)]
---> 57     initvals[rv] = model.rvs_to_transforms[rv].backward(initval, *rv.owner.inputs)
     58 else:
     59     initvals[model[key]] = initval

File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/distributions/transforms.py:309, in ZeroSumTransform.backward(self, value, *rv_inputs)
    307 def backward(self, value, *rv_inputs):
    308     for axis in self.zerosum_axes:
--> 309         value = self.extend_axis(value, axis=axis)
    310     return value

File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/distributions/transforms.py:281, in ZeroSumTransform.extend_axis(array, axis)
    279 @staticmethod
    280 def extend_axis(array, axis):
--> 281     n = (array.shape[axis] + 1).astype("floatX")
    282     sum_vals = array.sum(axis, keepdims=True)
    283     norm = sum_vals / (pt.sqrt(n) + n)

AttributeError: 'int' object has no attribute 'astype'

PyMC version information:

pymc 5.22.0

Context for the issue:

I wanted to experiment with setting initvals from MAP and pathfinder, and ran into this issue.

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions