Open
Description
Describe the issue:
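Trying to set `initvals` on `ZeroSumTransform`ed variables fails with a type-casting error (`AttributeError: 'int' object has no attribute 'astype'`).
It seems to be caused by the input being a NumPy array rather than a PyTensor variable. For a NumPy array, `array.shape[axis] + 1` in `ZeroSumTransform.extend_axis` is a plain Python `int`, which has no `.astype` method.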
The fix seems simple; I will post a PR for it next.
Reproducible code example:
import numpy as np  # noqa: F401 -- kept from the original report; not used below
import pymc as pm

# Minimal reproduction: feeding find_MAP output back into pm.sample as
# `initvals` fails for variables that use ZeroSumTransform.
with pm.Model() as model:
    # ZeroSumNormal applies a zero-sum transform internally.
    pm.ZeroSumNormal("zsn", shape=(10,))
    # Same transform, attached explicitly to an ordinary Normal.
    pm.Normal(
        "n",
        shape=(10,),
        transform=pm.distributions.transforms.ZeroSumTransform(zerosum_axes=[0]),
    )
    # NOTE(review): the MAP dict presumably contains transformed-space keys,
    # which is what routes these values through ZeroSumTransform.backward
    # during initval conversion -- confirm against find_MAP's output.
    mp = pm.find_MAP()
    # Raises AttributeError: 'int' object has no attribute 'astype'
    # (from ZeroSumTransform.extend_axis receiving a NumPy array).
    pm.sample(initvals=mp)
Error message:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
/home/velochy/salk/sandbox/sandy.ipynb Cell 1 line 8
5 pm.Normal('n', shape=(10,), transform=pm.distributions.transforms.ZeroSumTransform(zerosum_axes=[0]))
6 mp = pm.find_MAP()
----> 8 pm.sample(initvals=mp)
File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/mcmc.py:832, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
830 [kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()]
831 with joined_blas_limiter():
--> 832 initial_points, step = init_nuts(
833 init=init,
834 chains=chains,
835 n_init=n_init,
836 model=model,
837 random_seed=random_seed_list,
838 progressbar=progress_bool,
839 jitter_max_retries=jitter_max_retries,
840 tune=tune,
841 initvals=initvals,
842 compile_kwargs=compile_kwargs,
843 **kwargs,
844 )
845 else:
846 # Get initial points
847 ipfns = make_initial_point_fns_per_chain(
848 model=model,
849 overrides=initvals,
850 jitter_rvs=set(),
851 chains=chains,
852 )
File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/mcmc.py:1605, in init_nuts(init, chains, n_init, model, random_seed, progressbar, jitter_max_retries, tune, initvals, compile_kwargs, **kwargs)
1602 q, _ = DictToArrayBijection.map(ip)
1603 return logp_dlogp_func([q], extra_vars={})[0]
-> 1605 initial_points = _init_jitter(
1606 model,
1607 initvals,
1608 seeds=random_seed_list,
1609 jitter="jitter" in init,
1610 jitter_max_retries=jitter_max_retries,
1611 logp_fn=model_logp_fn,
1612 )
1614 apoints = [DictToArrayBijection.map(point) for point in initial_points]
1615 apoints_data = [apoint.data for apoint in apoints]
File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/mcmc.py:1462, in _init_jitter(model, initvals, seeds, jitter, jitter_max_retries, logp_fn)
1432 def _init_jitter(
1433 model: Model,
1434 initvals: StartDict | Sequence[StartDict | None] | None,
(...)
1438 logp_fn: Callable[[PointType], np.ndarray] | None = None,
1439 ) -> list[PointType]:
1440 """Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain.
1441
1442 ``model.check_start_vals`` is used to test whether the jittered starting
(...)
1460 List of starting points for the sampler
1461 """
-> 1462 ipfns = make_initial_point_fns_per_chain(
1463 model=model,
1464 overrides=initvals,
1465 jitter_rvs=set(model.free_RVs) if jitter else set(),
1466 chains=len(seeds),
1467 )
1469 if not jitter:
1470 return [ipfn(seed) for ipfn, seed in zip(ipfns, seeds)]
File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/initial_point.py:101, in make_initial_point_fns_per_chain(model, overrides, jitter_rvs, chains)
72 """Create an initial point function for each chain, as defined by initvals.
73
74 If a single initval dictionary is passed, the function is replicated for each
(...)
95
96 """
97 if isinstance(overrides, dict) or overrides is None:
98 # One strategy for all chains
99 # Only one function compilation is needed.
100 ipfns = [
--> 101 make_initial_point_fn(
102 model=model,
103 overrides=overrides,
104 jitter_rvs=jitter_rvs,
105 return_transformed=True,
106 )
107 ] * chains
108 elif len(overrides) == chains:
109 ipfns = [
110 make_initial_point_fn(
111 model=model,
(...)
116 for chain_overrides in overrides
117 ]
File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/initial_point.py:152, in make_initial_point_fn(model, overrides, jitter_rvs, default_strategy, return_transformed)
126 def make_initial_point_fn(
127 *,
128 model,
(...)
132 return_transformed: bool = True,
133 ) -> Callable[[SeedSequenceSeed], PointType]:
134 """Create seeded function that computes initial values for all free model variables.
135
136 Parameters
(...)
150 initial_point_fn : Callable[[SeedSequenceSeed], dict[str, np.ndarray]]
151 """
--> 152 sdict_overrides = convert_str_to_rv_dict(model, overrides or {})
153 initval_strats = {
154 **model.rvs_to_initial_values,
155 **sdict_overrides,
156 }
158 initial_values = make_initial_point_expression(
159 free_rvs=model.free_RVs,
160 rvs_to_transforms=model.rvs_to_transforms,
(...)
164 return_transformed=return_transformed,
165 )
File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/initial_point.py:57, in convert_str_to_rv_dict(model, start)
55 if is_transformed_name(key):
56 rv = model[get_untransformed_name(key)]
---> 57 initvals[rv] = model.rvs_to_transforms[rv].backward(initval, *rv.owner.inputs)
58 else:
59 initvals[model[key]] = initval
File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/distributions/transforms.py:309, in ZeroSumTransform.backward(self, value, *rv_inputs)
307 def backward(self, value, *rv_inputs):
308 for axis in self.zerosum_axes:
--> 309 value = self.extend_axis(value, axis=axis)
310 return value
File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/distributions/transforms.py:281, in ZeroSumTransform.extend_axis(array, axis)
279 @staticmethod
280 def extend_axis(array, axis):
--> 281 n = (array.shape[axis] + 1).astype("floatX")
282 sum_vals = array.sum(axis, keepdims=True)
283 norm = sum_vals / (pt.sqrt(n) + n)
AttributeError: 'int' object has no attribute 'astype'
PyMC version information:
pymc 5.22.0
Context for the issue:
I wanted to experiment with setting initvals from MAP and Pathfinder results, and ran into this issue.