Skip to content

Commit 40f26e4

Browse files
committed
make tests pass - formatting of ZeroSumNormal.py
1 parent a2370d5 commit 40f26e4

File tree

1 file changed

+24
-23
lines changed

1 file changed

+24
-23
lines changed

examples/generalized_linear_models/ZeroSumNormal.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,13 @@
1010
from scipy import stats
1111
from pymc3.distributions.distribution import generate_samples, draw_values
1212

13+
1314
def extend_axis_aet(array, axis):
1415
n = array.shape[axis] + 1
1516
sum_vals = array.sum(axis, keepdims=True)
1617
norm = sum_vals / (np.sqrt(n) + n)
1718
fill_val = norm - sum_vals / np.sqrt(n)
18-
19+
1920
out = aet.concatenate([array, fill_val.astype(str(array.dtype))], axis=axis)
2021
return out - norm.astype(str(array.dtype))
2122

@@ -27,7 +28,7 @@ def extend_axis_rev_aet(array: np.ndarray, axis: int):
2728

2829
n = array.shape[axis]
2930
last = aet.take(array, [-1], axis=axis)
30-
31+
3132
sum_vals = -last * np.sqrt(n)
3233
norm = sum_vals / (np.sqrt(n) + n)
3334
slice_before = (slice(None, None),) * axis
@@ -39,15 +40,15 @@ def extend_axis(array, axis):
3940
sum_vals = array.sum(axis, keepdims=True)
4041
norm = sum_vals / (np.sqrt(n) + n)
4142
fill_val = norm - sum_vals / np.sqrt(n)
42-
43+
4344
out = np.concatenate([array, fill_val.astype(str(array.dtype))], axis=axis)
4445
return out - norm.astype(str(array.dtype))
4546

4647

4748
def extend_axis_rev(array, axis):
4849
n = array.shape[axis]
4950
last = np.take(array, [-1], axis=axis)
50-
51+
5152
sum_vals = -last * np.sqrt(n)
5253
norm = sum_vals / (np.sqrt(n) + n)
5354
slice_before = (slice(None, None),) * len(array.shape[:axis])
@@ -56,60 +57,60 @@ def extend_axis_rev(array, axis):
5657

5758
class ZeroSumTransform(pm.distributions.transforms.Transform):
5859
name = "zerosum"
59-
60+
6061
_active_dims: List[int]
61-
62+
6263
def __init__(self, active_dims):
6364
self._active_dims = active_dims
64-
65+
6566
def forward(self, x):
6667
for axis in self._active_dims:
6768
x = extend_axis_rev_aet(x, axis=axis)
6869
return x
69-
70+
7071
def forward_val(self, x, point=None):
7172
for axis in self._active_dims:
7273
x = extend_axis_rev(x, axis=axis)
7374
return x
74-
75+
7576
def backward(self, z):
7677
z = aet.as_tensor_variable(z)
7778
for axis in self._active_dims:
7879
z = extend_axis_aet(z, axis=axis)
7980
return z
80-
81+
8182
def jacobian_det(self, x):
82-
return aet.constant(0.)
83-
84-
83+
return aet.constant(0.0)
84+
85+
8586
class ZeroSumNormal(pm.Continuous):
8687
def __init__(self, sigma=1, *, active_dims=None, active_axes=None, **kwargs):
8788
shape = kwargs.get("shape", ())
8889
dims = kwargs.get("dims", None)
8990
if isinstance(shape, int):
9091
shape = (shape,)
91-
92+
9293
if isinstance(dims, str):
9394
dims = (dims,)
9495

9596
self.mu = self.median = self.mode = aet.zeros(shape)
9697
self.sigma = aet.as_tensor_variable(sigma)
97-
98+
9899
if active_dims is None and active_axes is None:
99100
if shape:
100101
active_axes = (-1,)
101102
else:
102103
active_axes = ()
103-
104+
104105
if isinstance(active_axes, int):
105106
active_axes = (active_axes,)
106-
107+
107108
if isinstance(active_dims, str):
108109
active_dims = (active_dims,)
109-
110+
110111
if active_axes is not None and active_dims is not None:
111112
raise ValueError("Only one of active_axes and active_dims can be specified.")
112-
113+
113114
if active_dims is not None:
114115
model = pm.modelcontext(None)
115116
print(model.RV_dims)
@@ -118,19 +119,19 @@ def __init__(self, sigma=1, *, active_dims=None, active_axes=None, **kwargs):
118119
active_axes = []
119120
for dim in active_dims:
120121
active_axes.append(dims.index(dim))
121-
122+
122123
super().__init__(**kwargs, transform=ZeroSumTransform(active_axes))
123124

124125
def logp(self, x):
125126
return pm.Normal.dist(sigma=self.sigma).logp(x)
126-
127+
127128
@staticmethod
128129
def _random(scale, size):
129130
samples = stats.norm.rvs(loc=0, scale=scale, size=size)
130131
return samples - np.mean(samples, axis=-1, keepdims=True)
131-
132+
132133
def random(self, point=None, size=None):
133-
sigma, = draw_values([self.sigma], point=point, size=size)
134+
(sigma,) = draw_values([self.sigma], point=point, size=size)
134135
return generate_samples(self._random, scale=sigma, dist_shape=self.shape, size=size)
135136

136137
def _distr_parameters_for_repr(self):

0 commit comments

Comments
 (0)