1
1
< << << << HEAD
2
+ < << << << HEAD
3
+ == == == =
4
+ >> >> >> > cb0c201 (latest ZeroSumNormal code , pymc3 v3 , random seed for sampling )
2
5
from typing import List
3
6
4
7
try :
5
8
import aesara .tensor as aet
6
9
except ImportError :
7
10
import theano .tensor as aet
8
11
12
+ << << << < HEAD
9
13
import numpy as np
10
14
import pymc3 as pm
11
15
from scipy import stats
@@ -141,68 +145,48 @@ def logcdf(self, value):
141
145
raise NotImplementedError ()
142
146
== == == =
143
147
import pymc3 as pm
148
+ == == == =
149
+ >> >> >> > cb0c201 (latest ZeroSumNormal code , pymc3 v3 , random seed for sampling )
144
150
import numpy as np
145
- import pandas as pd
146
- from typing import *
147
- import aesara
148
- import aesara .tensor as aet
149
-
150
-
151
- def ZeroSumNormal (
152
- name : str ,
153
- sigma : Optional [float ] = None ,
154
- * ,
155
- dims : Union [str , Tuple [str ]],
156
- model : Optional [pm .Model ] = None ,
157
- ):
158
- """
159
- Multivariate normal, such that sum(x, axis=-1) = 0.
160
-
161
- Parameters
162
- ----------
163
- name: str
164
- String name representation of the PyMC variable.
165
- sigma: Optional[float], defaults to None
166
- Scale for the Normal distribution. If ``None``, a standard Normal is used.
167
- dims: Union[str, Tuple[str]]
168
- Dimension names for the shape of the distribution.
169
- See https://docs.pymc.io/pymc-examples/examples/pymc3_howto/data_container.html for an example.
170
- model: Optional[pm.Model], defaults to None
171
- PyMC model instance. If ``None``, a model instance is created.
172
- """
173
- if isinstance (dims , str ):
174
- dims = (dims ,)
151
+ import pymc3 as pm
152
+ from scipy import stats
153
+ from pymc3 .distributions .distribution import generate_samples , draw_values
175
154
176
- model = pm .modelcontext (model )
177
- * dims_pre , dim = dims
178
- dim_trunc = f"{ dim } _truncated_"
179
- (shape ,) = model .shape_from_dims ((dim ,))
180
- assert shape >= 1
155
+ def extend_axis_aet (array , axis ):
156
+ n = array .shape [axis ] + 1
157
+ sum_vals = array .sum (axis , keepdims = True )
158
+ norm = sum_vals / (np .sqrt (n ) + n )
159
+ fill_val = norm - sum_vals / np .sqrt (n )
160
+
161
+ out = aet .concatenate ([array , fill_val .astype (str (array .dtype ))], axis = axis )
162
+ return out - norm .astype (str (array .dtype ))
181
163
182
- model .add_coords ({f"{ dim } _truncated_" : pd .RangeIndex (shape - 1 )})
183
- raw = pm .Normal (f"{ name } _truncated_" , dims = tuple (dims_pre ) + (dim_trunc ,), sigma = sigma )
184
- Q = make_sum_zero_hh (shape )
185
- draws = aet .dot (raw , Q [:, 1 :].T )
186
164
187
- #if sigma is not None:
188
- # draws = sigma * draws
165
+ def extend_axis_rev_aet (array : np .ndarray , axis : int ):
166
+ if axis < 0 :
167
+ axis = axis % array .ndim
168
+ assert axis >= 0 and axis < array .ndim
189
169
190
- return pm .Deterministic (name , draws , dims = dims )
170
+ n = array .shape [axis ]
171
+ last = aet .take (array , [- 1 ], axis = axis )
172
+
173
+ sum_vals = - last * np .sqrt (n )
174
+ norm = sum_vals / (np .sqrt (n ) + n )
175
+ slice_before = (slice (None , None ),) * axis
176
+ return array [slice_before + (slice (None , - 1 ),)] + norm .astype (str (array .dtype ))
191
177
192
178
179
+ def extend_axis (array , axis ):
180
+ n = array .shape [axis ] + 1
181
+ sum_vals = array .sum (axis , keepdims = True )
182
+ norm = sum_vals / (np .sqrt (n ) + n )
183
+ fill_val = norm - sum_vals / np .sqrt (n )
184
+
185
+ out = np .concatenate ([array , fill_val .astype (str (array .dtype ))], axis = axis )
186
+ return out - norm .astype (str (array .dtype ))
193
187
194
- def make_sum_zero_hh (N : int ) -> np .ndarray :
195
- """
196
- Build a householder transformation matrix that maps e_1 to a vector of all 1s.
197
- """
198
- e_1 = np .zeros (N )
199
- e_1 [0 ] = 1
200
- a = np .ones (N )
201
- a /= np .sqrt (a @ a )
202
- v = a + e_1
203
- v /= np .sqrt (v @ v )
204
- return np .eye (N ) - 2 * np .outer (v , v )
205
188
189
+ < << << << HEAD
206
190
def make_sum_zero_hh (N : int ) -> np .ndarray :
207
191
"""
208
192
Build a householder transformation matrix that maps e_1 to a vector of all 1s.
@@ -215,3 +199,99 @@ def make_sum_zero_hh(N: int) -> np.ndarray:
215
199
v /= np .sqrt (v @ v )
216
200
return np .eye (N ) - 2 * np .outer (v , v )
217
201
> >> >> >> 2 da3052 (ZeroSumNormal : initial commit )
202
+ == == == =
203
+ def extend_axis_rev (array , axis ):
204
+ n = array .shape [axis ]
205
+ last = np .take (array , [- 1 ], axis = axis )
206
+
207
+ sum_vals = - last * np .sqrt (n )
208
+ norm = sum_vals / (np .sqrt (n ) + n )
209
+ slice_before = (slice (None , None ),) * len (array .shape [:axis ])
210
+ return array [slice_before + (slice (None , - 1 ),)] + norm .astype (str (array .dtype ))
211
+
212
+
213
+ class ZeroSumTransform (pm .distributions .transforms .Transform ):
214
+ name = "zerosum"
215
+
216
+ _active_dims : List [int ]
217
+
218
+ def __init__ (self , active_dims ):
219
+ self ._active_dims = active_dims
220
+
221
+ def forward (self , x ):
222
+ for axis in self ._active_dims :
223
+ x = extend_axis_rev_aet (x , axis = axis )
224
+ return x
225
+
226
+ def forward_val (self , x , point = None ):
227
+ for axis in self ._active_dims :
228
+ x = extend_axis_rev (x , axis = axis )
229
+ return x
230
+
231
+ def backward (self , z ):
232
+ z = aet .as_tensor_variable (z )
233
+ for axis in self ._active_dims :
234
+ z = extend_axis_aet (z , axis = axis )
235
+ return z
236
+
237
+ def jacobian_det (self , x ):
238
+ return aet .constant (0. )
239
+
240
+
241
+ class ZeroSumNormal (pm .Continuous ):
242
+ def __init__ (self , sigma = 1 , * , active_dims = None , active_axes = None , ** kwargs ):
243
+ shape = kwargs .get ("shape" , ())
244
+ dims = kwargs .get ("dims" , None )
245
+ if isinstance (shape , int ):
246
+ shape = (shape ,)
247
+
248
+ if isinstance (dims , str ):
249
+ dims = (dims ,)
250
+
251
+ self .mu = self .median = self .mode = aet .zeros (shape )
252
+ self .sigma = aet .as_tensor_variable (sigma )
253
+
254
+ if active_dims is None and active_axes is None :
255
+ if shape :
256
+ active_axes = (- 1 ,)
257
+ else :
258
+ active_axes = ()
259
+
260
+ if isinstance (active_axes , int ):
261
+ active_axes = (active_axes ,)
262
+
263
+ if isinstance (active_dims , str ):
264
+ active_dims = (active_dims ,)
265
+
266
+ if active_axes is not None and active_dims is not None :
267
+ raise ValueError ("Only one of active_axes and active_dims can be specified." )
268
+
269
+ if active_dims is not None :
270
+ model = pm .modelcontext (None )
271
+ print (model .RV_dims )
272
+ if dims is None :
273
+ raise ValueError ("active_dims can only be used with the dims kwargs." )
274
+ active_axes = []
275
+ for dim in active_dims :
276
+ active_axes .append (dims .index (dim ))
277
+
278
+ super ().__init__ (** kwargs , transform = ZeroSumTransform (active_axes ))
279
+
280
+ def logp (self , x ):
281
+ return pm .Normal .dist (sigma = self .sigma ).logp (x )
282
+
283
+ @staticmethod
284
+ def _random (scale , size ):
285
+ samples = stats .norm .rvs (loc = 0 , scale = scale , size = size )
286
+ return samples - np .mean (samples , axis = - 1 , keepdims = True )
287
+
288
+ def random (self , point = None , size = None ):
289
+ sigma , = draw_values ([self .sigma ], point = point , size = size )
290
+ return generate_samples (self ._random , scale = sigma , dist_shape = self .shape , size = size )
291
+
292
+ def _distr_parameters_for_repr (self ):
293
+ return ["sigma" ]
294
+
295
+ def logcdf (self , value ):
296
+ raise NotImplementedError ()
297
+ > >> >> >> cb0c201 (latest ZeroSumNormal code , pymc3 v3 , random seed for sampling )
0 commit comments