Skip to content

Commit 0ee27f5

Browse files
committed
latest ZeroSumNormal code, pymc3 v3, random seed for sampling
1 parent dc2e3a8 commit 0ee27f5

File tree

2 files changed

+1488
-3006
lines changed

2 files changed

+1488
-3006
lines changed

examples/generalized_linear_models/GLM-ZeroSumNormal.ipynb

Lines changed: 1355 additions & 2953 deletions
Large diffs are not rendered by default.

examples/generalized_linear_models/ZeroSumNormal.py

Lines changed: 133 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
<<<<<<< HEAD
2+
<<<<<<< HEAD
3+
=======
4+
>>>>>>> cb0c201 (latest ZeroSumNormal code, pymc3 v3, random seed for sampling)
25
from typing import List
36

47
try:
58
import aesara.tensor as aet
69
except ImportError:
710
import theano.tensor as aet
811

12+
<<<<<<< HEAD
913
import numpy as np
1014
import pymc3 as pm
1115
from scipy import stats
@@ -141,68 +145,48 @@ def logcdf(self, value):
141145
raise NotImplementedError()
142146
=======
143147
import pymc3 as pm
148+
=======
149+
>>>>>>> cb0c201 (latest ZeroSumNormal code, pymc3 v3, random seed for sampling)
144150
import numpy as np
145-
import pandas as pd
146-
from typing import *
147-
import aesara
148-
import aesara.tensor as aet
149-
150-
151-
def ZeroSumNormal(
152-
name: str,
153-
sigma: Optional[float] = None,
154-
*,
155-
dims: Union[str, Tuple[str]],
156-
model: Optional[pm.Model] = None,
157-
):
158-
"""
159-
Multivariate normal, such that sum(x, axis=-1) = 0.
160-
161-
Parameters
162-
----------
163-
name: str
164-
String name representation of the PyMC variable.
165-
sigma: Optional[float], defaults to None
166-
Scale for the Normal distribution. If ``None``, a standard Normal is used.
167-
dims: Union[str, Tuple[str]]
168-
Dimension names for the shape of the distribution.
169-
See https://docs.pymc.io/pymc-examples/examples/pymc3_howto/data_container.html for an example.
170-
model: Optional[pm.Model], defaults to None
171-
PyMC model instance. If ``None``, a model instance is created.
172-
"""
173-
if isinstance(dims, str):
174-
dims = (dims,)
151+
import pymc3 as pm
152+
from scipy import stats
153+
from pymc3.distributions.distribution import generate_samples, draw_values
175154

176-
model = pm.modelcontext(model)
177-
*dims_pre, dim = dims
178-
dim_trunc = f"{dim}_truncated_"
179-
(shape,) = model.shape_from_dims((dim,))
180-
assert shape >= 1
155+
def extend_axis_aet(array, axis):
156+
n = array.shape[axis] + 1
157+
sum_vals = array.sum(axis, keepdims=True)
158+
norm = sum_vals / (np.sqrt(n) + n)
159+
fill_val = norm - sum_vals / np.sqrt(n)
160+
161+
out = aet.concatenate([array, fill_val.astype(str(array.dtype))], axis=axis)
162+
return out - norm.astype(str(array.dtype))
181163

182-
model.add_coords({f"{dim}_truncated_": pd.RangeIndex(shape - 1)})
183-
raw = pm.Normal(f"{name}_truncated_", dims=tuple(dims_pre) + (dim_trunc,), sigma=sigma)
184-
Q = make_sum_zero_hh(shape)
185-
draws = aet.dot(raw, Q[:, 1:].T)
186164

187-
#if sigma is not None:
188-
# draws = sigma * draws
165+
def extend_axis_rev_aet(array: np.ndarray, axis: int):
166+
if axis < 0:
167+
axis = axis % array.ndim
168+
assert axis >= 0 and axis < array.ndim
189169

190-
return pm.Deterministic(name, draws, dims=dims)
170+
n = array.shape[axis]
171+
last = aet.take(array, [-1], axis=axis)
172+
173+
sum_vals = -last * np.sqrt(n)
174+
norm = sum_vals / (np.sqrt(n) + n)
175+
slice_before = (slice(None, None),) * axis
176+
return array[slice_before + (slice(None, -1),)] + norm.astype(str(array.dtype))
191177

192178

179+
def extend_axis(array, axis):
180+
n = array.shape[axis] + 1
181+
sum_vals = array.sum(axis, keepdims=True)
182+
norm = sum_vals / (np.sqrt(n) + n)
183+
fill_val = norm - sum_vals / np.sqrt(n)
184+
185+
out = np.concatenate([array, fill_val.astype(str(array.dtype))], axis=axis)
186+
return out - norm.astype(str(array.dtype))
193187

194-
def make_sum_zero_hh(N: int) -> np.ndarray:
195-
"""
196-
Build a householder transformation matrix that maps e_1 to a vector of all 1s.
197-
"""
198-
e_1 = np.zeros(N)
199-
e_1[0] = 1
200-
a = np.ones(N)
201-
a /= np.sqrt(a @ a)
202-
v = a + e_1
203-
v /= np.sqrt(v @ v)
204-
return np.eye(N) - 2 * np.outer(v, v)
205188

189+
<<<<<<< HEAD
206190
def make_sum_zero_hh(N: int) -> np.ndarray:
207191
"""
208192
Build a householder transformation matrix that maps e_1 to a vector of all 1s.
@@ -215,3 +199,99 @@ def make_sum_zero_hh(N: int) -> np.ndarray:
215199
v /= np.sqrt(v @ v)
216200
return np.eye(N) - 2 * np.outer(v, v)
217201
>>>>>>> 2da3052 (ZeroSumNormal: initial commit)
202+
=======
203+
def extend_axis_rev(array, axis):
204+
n = array.shape[axis]
205+
last = np.take(array, [-1], axis=axis)
206+
207+
sum_vals = -last * np.sqrt(n)
208+
norm = sum_vals / (np.sqrt(n) + n)
209+
slice_before = (slice(None, None),) * len(array.shape[:axis])
210+
return array[slice_before + (slice(None, -1),)] + norm.astype(str(array.dtype))
211+
212+
213+
class ZeroSumTransform(pm.distributions.transforms.Transform):
214+
name = "zerosum"
215+
216+
_active_dims: List[int]
217+
218+
def __init__(self, active_dims):
219+
self._active_dims = active_dims
220+
221+
def forward(self, x):
222+
for axis in self._active_dims:
223+
x = extend_axis_rev_aet(x, axis=axis)
224+
return x
225+
226+
def forward_val(self, x, point=None):
227+
for axis in self._active_dims:
228+
x = extend_axis_rev(x, axis=axis)
229+
return x
230+
231+
def backward(self, z):
232+
z = aet.as_tensor_variable(z)
233+
for axis in self._active_dims:
234+
z = extend_axis_aet(z, axis=axis)
235+
return z
236+
237+
def jacobian_det(self, x):
238+
return aet.constant(0.)
239+
240+
241+
class ZeroSumNormal(pm.Continuous):
242+
def __init__(self, sigma=1, *, active_dims=None, active_axes=None, **kwargs):
243+
shape = kwargs.get("shape", ())
244+
dims = kwargs.get("dims", None)
245+
if isinstance(shape, int):
246+
shape = (shape,)
247+
248+
if isinstance(dims, str):
249+
dims = (dims,)
250+
251+
self.mu = self.median = self.mode = aet.zeros(shape)
252+
self.sigma = aet.as_tensor_variable(sigma)
253+
254+
if active_dims is None and active_axes is None:
255+
if shape:
256+
active_axes = (-1,)
257+
else:
258+
active_axes = ()
259+
260+
if isinstance(active_axes, int):
261+
active_axes = (active_axes,)
262+
263+
if isinstance(active_dims, str):
264+
active_dims = (active_dims,)
265+
266+
if active_axes is not None and active_dims is not None:
267+
raise ValueError("Only one of active_axes and active_dims can be specified.")
268+
269+
if active_dims is not None:
270+
model = pm.modelcontext(None)
271+
print(model.RV_dims)
272+
if dims is None:
273+
raise ValueError("active_dims can only be used with the dims kwargs.")
274+
active_axes = []
275+
for dim in active_dims:
276+
active_axes.append(dims.index(dim))
277+
278+
super().__init__(**kwargs, transform=ZeroSumTransform(active_axes))
279+
280+
def logp(self, x):
281+
return pm.Normal.dist(sigma=self.sigma).logp(x)
282+
283+
@staticmethod
284+
def _random(scale, size):
285+
samples = stats.norm.rvs(loc=0, scale=scale, size=size)
286+
return samples - np.mean(samples, axis=-1, keepdims=True)
287+
288+
def random(self, point=None, size=None):
289+
sigma, = draw_values([self.sigma], point=point, size=size)
290+
return generate_samples(self._random, scale=sigma, dist_shape=self.shape, size=size)
291+
292+
def _distr_parameters_for_repr(self):
293+
return ["sigma"]
294+
295+
def logcdf(self, value):
296+
raise NotImplementedError()
297+
>>>>>>> cb0c201 (latest ZeroSumNormal code, pymc3 v3, random seed for sampling)

0 commit comments

Comments
 (0)