
Commit 5bea0b8

latest ZeroSumNormal code, pymc3 v3, random seed for sampling
1 parent 68e0c53 commit 5bea0b8

File tree

2 files changed: +336 -314 lines changed

examples/generalized_linear_models/GLM-ZeroSumNormal.ipynb

Lines changed: 197 additions & 241 deletions
Large diffs are not rendered by default.
Second changed file: 139 additions & 73 deletions
@@ -1,74 +1,140 @@
-import pymc3 as pm
+from typing import List
+
+try:
+    import aesara.tensor as aet
+except ImportError:
+    import theano.tensor as aet
+
 import numpy as np
-import pandas as pd
-from typing import *
-import aesara
-import aesara.tensor as aet
-
-
-def ZeroSumNormal(
-    name: str,
-    sigma: Optional[float] = None,
-    *,
-    dims: Union[str, Tuple[str]],
-    model: Optional[pm.Model] = None,
-):
-    """
-    Multivariate normal, such that sum(x, axis=-1) = 0.
-
-    Parameters
-    ----------
-    name: str
-        String name representation of the PyMC variable.
-    sigma: Optional[float], defaults to None
-        Scale for the Normal distribution. If ``None``, a standard Normal is used.
-    dims: Union[str, Tuple[str]]
-        Dimension names for the shape of the distribution.
-        See https://docs.pymc.io/pymc-examples/examples/pymc3_howto/data_container.html for an example.
-    model: Optional[pm.Model], defaults to None
-        PyMC model instance. If ``None``, a model instance is created.
-    """
-    if isinstance(dims, str):
-        dims = (dims,)
-
-    model = pm.modelcontext(model)
-    *dims_pre, dim = dims
-    dim_trunc = f"{dim}_truncated_"
-    (shape,) = model.shape_from_dims((dim,))
-    assert shape >= 1
-
-    model.add_coords({f"{dim}_truncated_": pd.RangeIndex(shape - 1)})
-    raw = pm.Normal(f"{name}_truncated_", dims=tuple(dims_pre) + (dim_trunc,), sigma=sigma)
-    Q = make_sum_zero_hh(shape)
-    draws = aet.dot(raw, Q[:, 1:].T)
-
-    #if sigma is not None:
-    #    draws = sigma * draws
-
-    return pm.Deterministic(name, draws, dims=dims)
-
-
-
-def make_sum_zero_hh(N: int) -> np.ndarray:
-    """
-    Build a householder transformation matrix that maps e_1 to a vector of all 1s.
-    """
-    e_1 = np.zeros(N)
-    e_1[0] = 1
-    a = np.ones(N)
-    a /= np.sqrt(a @ a)
-    v = a + e_1
-    v /= np.sqrt(v @ v)
-    return np.eye(N) - 2 * np.outer(v, v)
-
-def make_sum_zero_hh(N: int) -> np.ndarray:
-    """
-    Build a householder transformation matrix that maps e_1 to a vector of all 1s.
-    """
-    e_1 = np.zeros(N)
-    e_1[0] = 1
-    a = np.ones(N)
-    a /= np.sqrt(a @ a)
-    v = a + e_1
-    v /= np.sqrt(v @ v)
-    return np.eye(N) - 2 * np.outer(v, v)
+import pymc3 as pm
+from scipy import stats
+from pymc3.distributions.distribution import generate_samples, draw_values
+
+def extend_axis_aet(array, axis):
+    n = array.shape[axis] + 1
+    sum_vals = array.sum(axis, keepdims=True)
+    norm = sum_vals / (np.sqrt(n) + n)
+    fill_val = norm - sum_vals / np.sqrt(n)
+
+    out = aet.concatenate([array, fill_val.astype(str(array.dtype))], axis=axis)
+    return out - norm.astype(str(array.dtype))
+
+
+def extend_axis_rev_aet(array: np.ndarray, axis: int):
+    if axis < 0:
+        axis = axis % array.ndim
+    assert axis >= 0 and axis < array.ndim
+
+    n = array.shape[axis]
+    last = aet.take(array, [-1], axis=axis)
+
+    sum_vals = -last * np.sqrt(n)
+    norm = sum_vals / (np.sqrt(n) + n)
+    slice_before = (slice(None, None),) * axis
+    return array[slice_before + (slice(None, -1),)] + norm.astype(str(array.dtype))
+
+
+def extend_axis(array, axis):
+    n = array.shape[axis] + 1
+    sum_vals = array.sum(axis, keepdims=True)
+    norm = sum_vals / (np.sqrt(n) + n)
+    fill_val = norm - sum_vals / np.sqrt(n)
+
+    out = np.concatenate([array, fill_val.astype(str(array.dtype))], axis=axis)
+    return out - norm.astype(str(array.dtype))
+
+
+def extend_axis_rev(array, axis):
+    n = array.shape[axis]
+    last = np.take(array, [-1], axis=axis)
+
+    sum_vals = -last * np.sqrt(n)
+    norm = sum_vals / (np.sqrt(n) + n)
+    slice_before = (slice(None, None),) * len(array.shape[:axis])
+    return array[slice_before + (slice(None, -1),)] + norm.astype(str(array.dtype))
+
+
+class ZeroSumTransform(pm.distributions.transforms.Transform):
+    name = "zerosum"
+
+    _active_dims: List[int]
+
+    def __init__(self, active_dims):
+        self._active_dims = active_dims
+
+    def forward(self, x):
+        for axis in self._active_dims:
+            x = extend_axis_rev_aet(x, axis=axis)
+        return x
+
+    def forward_val(self, x, point=None):
+        for axis in self._active_dims:
+            x = extend_axis_rev(x, axis=axis)
+        return x
+
+    def backward(self, z):
+        z = aet.as_tensor_variable(z)
+        for axis in self._active_dims:
+            z = extend_axis_aet(z, axis=axis)
+        return z
+
+    def jacobian_det(self, x):
+        return aet.constant(0.)
+
+
+class ZeroSumNormal(pm.Continuous):
+    def __init__(self, sigma=1, *, active_dims=None, active_axes=None, **kwargs):
+        shape = kwargs.get("shape", ())
+        dims = kwargs.get("dims", None)
+        if isinstance(shape, int):
+            shape = (shape,)
+
+        if isinstance(dims, str):
+            dims = (dims,)
+
+        self.mu = self.median = self.mode = aet.zeros(shape)
+        self.sigma = aet.as_tensor_variable(sigma)
+
+        if active_dims is None and active_axes is None:
+            if shape:
+                active_axes = (-1,)
+            else:
+                active_axes = ()
+
+        if isinstance(active_axes, int):
+            active_axes = (active_axes,)
+
+        if isinstance(active_dims, str):
+            active_dims = (active_dims,)
+
+        if active_axes is not None and active_dims is not None:
+            raise ValueError("Only one of active_axes and active_dims can be specified.")
+
+        if active_dims is not None:
+            model = pm.modelcontext(None)
+            print(model.RV_dims)
+            if dims is None:
+                raise ValueError("active_dims can only be used with the dims kwargs.")
+            active_axes = []
+            for dim in active_dims:
+                active_axes.append(dims.index(dim))
+
+        super().__init__(**kwargs, transform=ZeroSumTransform(active_axes))
+
+    def logp(self, x):
+        return pm.Normal.dist(sigma=self.sigma).logp(x)
+
+    @staticmethod
+    def _random(scale, size):
+        samples = stats.norm.rvs(loc=0, scale=scale, size=size)
+        return samples - np.mean(samples, axis=-1, keepdims=True)
+
+    def random(self, point=None, size=None):
+        sigma, = draw_values([self.sigma], point=point, size=size)
+        return generate_samples(self._random, scale=sigma, dist_shape=self.shape, size=size)
+
+    def _distr_parameters_for_repr(self):
+        return ["sigma"]
+
+    def logcdf(self, value):
+        raise NotImplementedError()
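
As a quick, informal check (not part of the commit), the NumPy versions of the helpers above can be exercised directly: extending an unconstrained vector along an axis yields a vector that sums to (numerically) zero, and the reverse helper undoes the extension. The array shapes and seed below are illustrative assumptions.

import numpy as np

# Assumes extend_axis and extend_axis_rev from the diff above are in scope.
rng = np.random.default_rng(0)
raw = rng.normal(size=(3, 5))             # unconstrained values, 5 per row

full = extend_axis(raw, axis=-1)          # adds one column; each row now sums to zero
print(full.shape)                         # (3, 6)
print(np.allclose(full.sum(axis=-1), 0))  # True: the zero-sum constraint holds

back = extend_axis_rev(full, axis=-1)     # inverse map back to the unconstrained space
print(np.allclose(back, raw))             # True: the round trip recovers the input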

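A minimal modeling sketch of how the new class-based ZeroSumNormal might be used in a pymc3 v3 model with seeded sampling, in the spirit of the commit message; the data, group structure, priors, and seed are assumptions for illustration, not taken from the commit.

import numpy as np
import pymc3 as pm

# Hypothetical data: 4 groups, 50 observations each.
rng = np.random.default_rng(1234)
group_idx = np.repeat(np.arange(4), 50)
y = rng.normal(loc=np.array([0.5, -0.2, 0.1, -0.4])[group_idx], scale=1.0)

with pm.Model():
    # Group offsets constrained to sum to zero along the last axis
    # (the class defaults to active_axes=(-1,) when a shape is given).
    group_effect = ZeroSumNormal("group_effect", sigma=1.0, shape=4)
    intercept = pm.Normal("intercept", 0.0, 1.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("obs", mu=intercept + group_effect[group_idx], sigma=sigma, observed=y)

    # Fixed random seed for sampling, as mentioned in the commit message.
    trace = pm.sample(1000, tune=1000, random_seed=1234)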