Skip to content

Commit 9ee5d84

Browse files
add ICARRV and ICAR (#6831)
* add ICARRV and ICAR * reorganized API so W is required, added tests on logp,rng,checks * adjusted tests and api * add ICAR to lists in distributions/__init__.py * fix check_pymc_params_match_rv_op test * deleted unnecessary pytest.mark.parametize * fix check for square matrix + check_pymc_match_rv_op * fix moment test * removed asserts, added tests on sizes * Implementing Bill's comments * add _supp_shape_from_params()
1 parent 470d474 commit 9ee5d84

File tree

3 files changed

+241
-0
lines changed

3 files changed

+241
-0
lines changed

pymc/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
)
8787
from pymc.distributions.multivariate import (
8888
CAR,
89+
ICAR,
8990
Dirichlet,
9091
DirichletMultinomial,
9192
KroneckerNormal,
@@ -198,6 +199,7 @@
198199
"Truncated",
199200
"Censored",
200201
"CAR",
202+
"ICAR",
201203
"PolyaGamma",
202204
"HurdleGamma",
203205
"HurdleLogNormal",

pymc/distributions/multivariate.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@
8989
"MatrixNormal",
9090
"KroneckerNormal",
9191
"CAR",
92+
"ICAR",
9293
"StickBreakingWeights",
9394
]
9495

@@ -2256,6 +2257,172 @@ def logp(value, mu, W, alpha, tau):
22562257
)
22572258

22582259

2260+
class ICARRV(RandomVariable):
2261+
name = "icar"
2262+
ndim_supp = 1
2263+
ndims_params = [2, 1, 1, 0, 0, 0]
2264+
dtype = "floatX"
2265+
_print_name = ("ICAR", "\\operatorname{ICAR}")
2266+
2267+
def __call__(self, W, node1, node2, N, sigma, zero_sum_stdev, size=None, **kwargs):
2268+
return super().__call__(W, node1, node2, N, sigma, zero_sum_stdev, size=size, **kwargs)
2269+
2270+
def _supp_shape_from_params(self, dist_params, param_shapes=None):
2271+
return supp_shape_from_ref_param_shape(
2272+
ndim_supp=self.ndim_supp,
2273+
dist_params=dist_params,
2274+
param_shapes=param_shapes,
2275+
ref_param_idx=0,
2276+
)
2277+
2278+
@classmethod
2279+
def rng_fn(cls, rng, size, W, node1, node2, N, sigma, zero_sum_stdev):
2280+
raise NotImplementedError("Cannot sample from ICAR prior")
2281+
2282+
2283+
icar = ICARRV()
2284+
2285+
2286+
class ICAR(Continuous):
2287+
r"""
2288+
The intrinsic conditional autoregressive prior. It is primarily used to model
2289+
covariance between neighboring areas. It is a special case
2290+
of the :class:`~pymc.CAR` distribution where alpha is set to 1.
2291+
2292+
The log probability density function is
2293+
2294+
.. math::
2295+
f(\phi| W,\sigma) =
2296+
-\frac{1}{2\sigma^{2}} \sum_{i\sim j} (\phi_{i} - \phi_{j})^2 -
2297+
\frac{1}{2}*\frac{\sum_{i}{\phi_{i}}}{0.001N}^{2} - \ln{\sqrt{2\\pi}} -
2298+
\ln{0.001N}
2299+
2300+
The first term represents the spatial covariance component. Each $\\phi_{i}$ is penalized
2301+
based on the square distance from each of its neighbors. The notation $i\\sim j$
2302+
indicates a sum over all the neighbors of $\\phi_{i}$. The last three terms are the
2303+
Normal log density function where the mean is zero and the standard deviation is
2304+
$N * 0.001$ (where N is the length of the vector $\\phi$). This component imposes
2305+
a zero-sum constraint by finding the sum of the vector $\\phi$ and penalizing based
2306+
on its distance from zero.
2307+
2308+
Parameters
2309+
----------
2310+
W : ndarray of int
2311+
Symmetric adjacency matrix of 1s and 0s indicating adjacency between elements.
2312+
2313+
sigma : scalar, default 1
2314+
Standard deviation of the vector of phi's. Putting a prior on sigma
2315+
will result in a centered parameterization. In most cases, it is
2316+
preferable to use a non-centered parameterization by using the default
2317+
value and multiplying the resulting phi's by sigma. See the example below.
2318+
2319+
zero_sum_stdev : scalar, default 0.001
2320+
Controls how strongly to enforce the zero-sum constraint. The sum of
2321+
phi is normally distributed with a mean of zero and small standard deviation.
2322+
This parameter sets the standard deviation of a normal density function with
2323+
mean zero.
2324+
2325+
2326+
Examples
2327+
--------
2328+
This example illustrates how to switch between centered and non-centered
2329+
parameterizations.
2330+
2331+
.. code-block:: python
2332+
2333+
import numpy as np
2334+
import pymc as pm
2335+
2336+
# 4x4 adjacency matrix
2337+
# arranged in a square lattice
2338+
2339+
W = np.array([
2340+
[0,1,0,1],
2341+
[1,0,1,0],
2342+
[0,1,0,1],
2343+
[1,0,1,0]
2344+
])
2345+
2346+
# centered parameterization
2347+
with pm.Model():
2348+
sigma = pm.Exponential('sigma', 1)
2349+
phi = pm.ICAR('phi', W=W, sigma=sigma)
2350+
mu = phi
2351+
2352+
# non-centered parameterization
2353+
with pm.Model():
2354+
sigma = pm.Exponential('sigma', 1)
2355+
phi = pm.ICAR('phi', W=W)
2356+
mu = sigma * phi
2357+
2358+
References
2359+
----------
2360+
.. Mitzi, M., Wheeler-Martin, K., Simpson, D., Mooney, J. S.,
2361+
Gelman, A., Dimaggio, C.
2362+
"Bayesian hierarchical spatial models: Implementing the Besag York
2363+
Mollié model in stan"
2364+
Spatial and Spatio-temporal Epidemiology, Vol. 31, (Aug., 2019),
2365+
pp 1-18
2366+
.. Banerjee, S., Carlin, B., Gelfand, A. Hierarchical Modeling
2367+
and Analysis for Spatial Data. Second edition. CRC press. (2015)
2368+
2369+
"""
2370+
2371+
rv_op = icar
2372+
2373+
@classmethod
2374+
def dist(cls, W, sigma=1, zero_sum_stdev=0.001, **kwargs):
2375+
if not W.ndim == 2:
2376+
raise ValueError("W must be matrix with ndim=2")
2377+
2378+
if not W.shape[0] == W.shape[1]:
2379+
raise ValueError("W must be a square matrix")
2380+
2381+
if not np.allclose(W.T, W):
2382+
raise ValueError("W must be a symmetric matrix")
2383+
2384+
if np.any((W != 0) & (W != 1)):
2385+
raise ValueError("W must be composed of only 1s and 0s")
2386+
2387+
# convert adjacency matrix to edgelist representation
2388+
# An edgelist is a pair of lists.
2389+
# If node i and node j are connected then one list
2390+
# will contain i and the other will contain j at the same
2391+
# index value.
2392+
# We only use the lower triangle here because adjacency
2393+
# is a undirected connection.
2394+
2395+
node1, node2 = np.where(np.tril(W) == 1)
2396+
2397+
node1 = pt.as_tensor_variable(node1, dtype=int)
2398+
node2 = pt.as_tensor_variable(node2, dtype=int)
2399+
2400+
W = pt.as_tensor_variable(W, dtype=int)
2401+
2402+
N = pt.shape(W)[0]
2403+
N = pt.as_tensor_variable(N)
2404+
2405+
sigma = pt.as_tensor_variable(floatX(sigma))
2406+
zero_sum_stdev = pt.as_tensor_variable(floatX(zero_sum_stdev))
2407+
2408+
return super().dist([W, node1, node2, N, sigma, zero_sum_stdev], **kwargs)
2409+
2410+
def moment(rv, size, W, node1, node2, N, sigma, zero_sum_stdev):
2411+
return pt.zeros(N)
2412+
2413+
def logp(value, W, node1, node2, N, sigma, zero_sum_stdev):
2414+
pairwise_difference = (-1 / (2 * sigma**2)) * pt.sum(
2415+
pt.square(value[node1] - value[node2])
2416+
)
2417+
zero_sum = (
2418+
-0.5 * pt.pow(pt.sum(value) / (zero_sum_stdev * N), 2)
2419+
- pt.log(pt.sqrt(2.0 * np.pi))
2420+
- pt.log(zero_sum_stdev * N)
2421+
)
2422+
2423+
return check_parameters(pairwise_difference + zero_sum, sigma > 0, msg="sigma > 0")
2424+
2425+
22592426
class StickBreakingWeightsRV(RandomVariable):
22602427
name = "stick_breaking_weights"
22612428
ndim_supp = 1

tests/distributions/test_multivariate.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,6 +1081,18 @@ def test_car_moment(self, mu, size, expected):
10811081
pm.CAR("x", mu=mu, W=W, alpha=alpha, tau=tau, size=size)
10821082
assert_moment_is_expected(model, expected)
10831083

1084+
@pytest.mark.parametrize(
1085+
"W, expected",
1086+
[
1087+
(np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]]), np.array([0, 0, 0])),
1088+
(np.array([[0, 1], [1, 0]]), np.array([0, 0])),
1089+
],
1090+
)
1091+
def test_icar_moment(self, W, expected):
1092+
with pm.Model() as model:
1093+
RV = pm.ICAR("x", W=W)
1094+
assert_moment_is_expected(model, expected)
1095+
10841096
@pytest.mark.parametrize(
10851097
"nu, mu, cov, size, expected",
10861098
[
@@ -2087,6 +2099,66 @@ def check_draws_match_expected(self):
20872099
assert np.all(np.abs(draw(x, random_seed=rng) - np.array([0.5, 0, 2.0])) < 0.01)
20882100

20892101

2102+
class TestICAR(BaseTestDistributionRandom):
2103+
pymc_dist = pm.ICAR
2104+
pymc_dist_params = {"W": np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]]), "sigma": 2}
2105+
expected_rv_op_params = {
2106+
"W": np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]]),
2107+
"node1": np.array([1, 2, 2]),
2108+
"node2": np.array([0, 0, 1]),
2109+
"N": 3,
2110+
"sigma": 2,
2111+
"zero_sum_strength": 0.001,
2112+
}
2113+
checks_to_run = ["check_pymc_params_match_rv_op", "check_rv_inferred_size"]
2114+
2115+
def check_rv_inferred_size(self):
2116+
sizes_to_check = [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)]
2117+
sizes_expected = [(3,), (3,), (1, 3), (1, 3), (5, 3), (4, 5, 3), (2, 4, 2, 3)]
2118+
for size, expected in zip(sizes_to_check, sizes_expected):
2119+
pymc_rv = self.pymc_dist.dist(**self.pymc_dist_params, size=size)
2120+
expected_symbolic = tuple(pymc_rv.shape.eval())
2121+
assert expected_symbolic == expected
2122+
2123+
def test_icar_logp(self):
2124+
W = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]])
2125+
2126+
with pm.Model() as m:
2127+
RV = pm.ICAR("phi", W=W)
2128+
2129+
assert pt.isclose(
2130+
pm.logp(RV, np.array([0.01, -0.03, 0.02, 0.00])).eval(), np.array(4.60022238)
2131+
).eval(), "logp inaccuracy"
2132+
2133+
def test_icar_rng_fn(self):
2134+
W = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]])
2135+
2136+
RV = pm.ICAR.dist(W=W)
2137+
2138+
with pytest.raises(NotImplementedError, match="Cannot sample from ICAR prior"):
2139+
pm.draw(RV)
2140+
2141+
@pytest.mark.parametrize(
2142+
"W,msg",
2143+
[
2144+
(np.array([0, 1, 0, 0]), "W must be matrix with ndim=2"),
2145+
(np.array([[0, 1, 0, 0], [1, 0, 0, 1], [1, 0, 0, 1]]), "W must be a square matrix"),
2146+
(
2147+
np.array([[0, 1, 0, 0], [1, 0, 0, 1], [1, 0, 0, 1], [0, 1, 1, 0]]),
2148+
"W must be a symmetric matrix",
2149+
),
2150+
(
2151+
np.array([[0, 1, 1, 0], [1, 0, 0, 0.5], [1, 0, 0, 1], [0, 0.5, 1, 0]]),
2152+
"W must be composed of only 1s and 0s",
2153+
),
2154+
],
2155+
)
2156+
def test_icar_matrix_checks(self, W, msg):
2157+
with pytest.raises(ValueError, match=msg):
2158+
with pm.Model():
2159+
pm.ICAR("phi", W=W)
2160+
2161+
20902162
@pytest.mark.parametrize("sparse", [True, False])
20912163
def test_car_rng_fn(sparse):
20922164
delta = 0.05 # limit for KS p-value

0 commit comments

Comments
 (0)