1
- import pymc3 as pm
1
+ from typing import List
2
+
3
+ try :
4
+ import aesara .tensor as aet
5
+ except ImportError :
6
+ import theano .tensor as aet
7
+
2
8
import numpy as np
3
- import pandas as pd
4
- from typing import *
5
- import aesara
6
- import aesara .tensor as aet
7
-
8
-
9
- def ZeroSumNormal (
10
- name : str ,
11
- sigma : Optional [float ] = None ,
12
- * ,
13
- dims : Union [str , Tuple [str ]],
14
- model : Optional [pm .Model ] = None ,
15
- ):
16
- """
17
- Multivariate normal, such that sum(x, axis=-1) = 0.
18
-
19
- Parameters
20
- ----------
21
- name: str
22
- String name representation of the PyMC variable.
23
- sigma: Optional[float], defaults to None
24
- Scale for the Normal distribution. If ``None``, a standard Normal is used.
25
- dims: Union[str, Tuple[str]]
26
- Dimension names for the shape of the distribution.
27
- See https://docs.pymc.io/pymc-examples/examples/pymc3_howto/data_container.html for an example.
28
- model: Optional[pm.Model], defaults to None
29
- PyMC model instance. If ``None``, a model instance is created.
30
- """
31
- if isinstance (dims , str ):
32
- dims = (dims ,)
33
-
34
- model = pm .modelcontext (model )
35
- * dims_pre , dim = dims
36
- dim_trunc = f"{ dim } _truncated_"
37
- (shape ,) = model .shape_from_dims ((dim ,))
38
- assert shape >= 1
39
-
40
- model .add_coords ({f"{ dim } _truncated_" : pd .RangeIndex (shape - 1 )})
41
- raw = pm .Normal (f"{ name } _truncated_" , dims = tuple (dims_pre ) + (dim_trunc ,), sigma = sigma )
42
- Q = make_sum_zero_hh (shape )
43
- draws = aet .dot (raw , Q [:, 1 :].T )
44
-
45
- #if sigma is not None:
46
- # draws = sigma * draws
47
-
48
- return pm .Deterministic (name , draws , dims = dims )
49
-
50
-
51
-
52
- def make_sum_zero_hh (N : int ) -> np .ndarray :
53
- """
54
- Build a householder transformation matrix that maps e_1 to a vector of all 1s.
55
- """
56
- e_1 = np .zeros (N )
57
- e_1 [0 ] = 1
58
- a = np .ones (N )
59
- a /= np .sqrt (a @ a )
60
- v = a + e_1
61
- v /= np .sqrt (v @ v )
62
- return np .eye (N ) - 2 * np .outer (v , v )
63
-
64
- def make_sum_zero_hh (N : int ) -> np .ndarray :
65
- """
66
- Build a householder transformation matrix that maps e_1 to a vector of all 1s.
67
- """
68
- e_1 = np .zeros (N )
69
- e_1 [0 ] = 1
70
- a = np .ones (N )
71
- a /= np .sqrt (a @ a )
72
- v = a + e_1
73
- v /= np .sqrt (v @ v )
74
- return np .eye (N ) - 2 * np .outer (v , v )
9
+ import pymc3 as pm
10
+ from scipy import stats
11
+ from pymc3 .distributions .distribution import generate_samples , draw_values
12
+
13
+ def extend_axis_aet (array , axis ):
14
+ n = array .shape [axis ] + 1
15
+ sum_vals = array .sum (axis , keepdims = True )
16
+ norm = sum_vals / (np .sqrt (n ) + n )
17
+ fill_val = norm - sum_vals / np .sqrt (n )
18
+
19
+ out = aet .concatenate ([array , fill_val .astype (str (array .dtype ))], axis = axis )
20
+ return out - norm .astype (str (array .dtype ))
21
+
22
+
23
+ def extend_axis_rev_aet (array : np .ndarray , axis : int ):
24
+ if axis < 0 :
25
+ axis = axis % array .ndim
26
+ assert axis >= 0 and axis < array .ndim
27
+
28
+ n = array .shape [axis ]
29
+ last = aet .take (array , [- 1 ], axis = axis )
30
+
31
+ sum_vals = - last * np .sqrt (n )
32
+ norm = sum_vals / (np .sqrt (n ) + n )
33
+ slice_before = (slice (None , None ),) * axis
34
+ return array [slice_before + (slice (None , - 1 ),)] + norm .astype (str (array .dtype ))
35
+
36
+
37
+ def extend_axis (array , axis ):
38
+ n = array .shape [axis ] + 1
39
+ sum_vals = array .sum (axis , keepdims = True )
40
+ norm = sum_vals / (np .sqrt (n ) + n )
41
+ fill_val = norm - sum_vals / np .sqrt (n )
42
+
43
+ out = np .concatenate ([array , fill_val .astype (str (array .dtype ))], axis = axis )
44
+ return out - norm .astype (str (array .dtype ))
45
+
46
+
47
+ def extend_axis_rev (array , axis ):
48
+ n = array .shape [axis ]
49
+ last = np .take (array , [- 1 ], axis = axis )
50
+
51
+ sum_vals = - last * np .sqrt (n )
52
+ norm = sum_vals / (np .sqrt (n ) + n )
53
+ slice_before = (slice (None , None ),) * len (array .shape [:axis ])
54
+ return array [slice_before + (slice (None , - 1 ),)] + norm .astype (str (array .dtype ))
55
+
56
+
57
+ class ZeroSumTransform (pm .distributions .transforms .Transform ):
58
+ name = "zerosum"
59
+
60
+ _active_dims : List [int ]
61
+
62
+ def __init__ (self , active_dims ):
63
+ self ._active_dims = active_dims
64
+
65
+ def forward (self , x ):
66
+ for axis in self ._active_dims :
67
+ x = extend_axis_rev_aet (x , axis = axis )
68
+ return x
69
+
70
+ def forward_val (self , x , point = None ):
71
+ for axis in self ._active_dims :
72
+ x = extend_axis_rev (x , axis = axis )
73
+ return x
74
+
75
+ def backward (self , z ):
76
+ z = aet .as_tensor_variable (z )
77
+ for axis in self ._active_dims :
78
+ z = extend_axis_aet (z , axis = axis )
79
+ return z
80
+
81
+ def jacobian_det (self , x ):
82
+ return aet .constant (0. )
83
+
84
+
85
+ class ZeroSumNormal (pm .Continuous ):
86
+ def __init__ (self , sigma = 1 , * , active_dims = None , active_axes = None , ** kwargs ):
87
+ shape = kwargs .get ("shape" , ())
88
+ dims = kwargs .get ("dims" , None )
89
+ if isinstance (shape , int ):
90
+ shape = (shape ,)
91
+
92
+ if isinstance (dims , str ):
93
+ dims = (dims ,)
94
+
95
+ self .mu = self .median = self .mode = aet .zeros (shape )
96
+ self .sigma = aet .as_tensor_variable (sigma )
97
+
98
+ if active_dims is None and active_axes is None :
99
+ if shape :
100
+ active_axes = (- 1 ,)
101
+ else :
102
+ active_axes = ()
103
+
104
+ if isinstance (active_axes , int ):
105
+ active_axes = (active_axes ,)
106
+
107
+ if isinstance (active_dims , str ):
108
+ active_dims = (active_dims ,)
109
+
110
+ if active_axes is not None and active_dims is not None :
111
+ raise ValueError ("Only one of active_axes and active_dims can be specified." )
112
+
113
+ if active_dims is not None :
114
+ model = pm .modelcontext (None )
115
+ print (model .RV_dims )
116
+ if dims is None :
117
+ raise ValueError ("active_dims can only be used with the dims kwargs." )
118
+ active_axes = []
119
+ for dim in active_dims :
120
+ active_axes .append (dims .index (dim ))
121
+
122
+ super ().__init__ (** kwargs , transform = ZeroSumTransform (active_axes ))
123
+
124
+ def logp (self , x ):
125
+ return pm .Normal .dist (sigma = self .sigma ).logp (x )
126
+
127
+ @staticmethod
128
+ def _random (scale , size ):
129
+ samples = stats .norm .rvs (loc = 0 , scale = scale , size = size )
130
+ return samples - np .mean (samples , axis = - 1 , keepdims = True )
131
+
132
+ def random (self , point = None , size = None ):
133
+ sigma , = draw_values ([self .sigma ], point = point , size = size )
134
+ return generate_samples (self ._random , scale = sigma , dist_shape = self .shape , size = size )
135
+
136
+ def _distr_parameters_for_repr (self ):
137
+ return ["sigma" ]
138
+
139
+ def logcdf (self , value ):
140
+ raise NotImplementedError ()
0 commit comments