10
10
from scipy import stats
11
11
from pymc3 .distributions .distribution import generate_samples , draw_values
12
12
13
+
13
14
def extend_axis_aet (array , axis ):
14
15
n = array .shape [axis ] + 1
15
16
sum_vals = array .sum (axis , keepdims = True )
16
17
norm = sum_vals / (np .sqrt (n ) + n )
17
18
fill_val = norm - sum_vals / np .sqrt (n )
18
-
19
+
19
20
out = aet .concatenate ([array , fill_val .astype (str (array .dtype ))], axis = axis )
20
21
return out - norm .astype (str (array .dtype ))
21
22
@@ -27,7 +28,7 @@ def extend_axis_rev_aet(array: np.ndarray, axis: int):
27
28
28
29
n = array .shape [axis ]
29
30
last = aet .take (array , [- 1 ], axis = axis )
30
-
31
+
31
32
sum_vals = - last * np .sqrt (n )
32
33
norm = sum_vals / (np .sqrt (n ) + n )
33
34
slice_before = (slice (None , None ),) * axis
@@ -39,15 +40,15 @@ def extend_axis(array, axis):
39
40
sum_vals = array .sum (axis , keepdims = True )
40
41
norm = sum_vals / (np .sqrt (n ) + n )
41
42
fill_val = norm - sum_vals / np .sqrt (n )
42
-
43
+
43
44
out = np .concatenate ([array , fill_val .astype (str (array .dtype ))], axis = axis )
44
45
return out - norm .astype (str (array .dtype ))
45
46
46
47
47
48
def extend_axis_rev (array , axis ):
48
49
n = array .shape [axis ]
49
50
last = np .take (array , [- 1 ], axis = axis )
50
-
51
+
51
52
sum_vals = - last * np .sqrt (n )
52
53
norm = sum_vals / (np .sqrt (n ) + n )
53
54
slice_before = (slice (None , None ),) * len (array .shape [:axis ])
@@ -56,60 +57,60 @@ def extend_axis_rev(array, axis):
56
57
57
58
class ZeroSumTransform (pm .distributions .transforms .Transform ):
58
59
name = "zerosum"
59
-
60
+
60
61
_active_dims : List [int ]
61
-
62
+
62
63
def __init__ (self , active_dims ):
63
64
self ._active_dims = active_dims
64
-
65
+
65
66
def forward (self , x ):
66
67
for axis in self ._active_dims :
67
68
x = extend_axis_rev_aet (x , axis = axis )
68
69
return x
69
-
70
+
70
71
def forward_val (self , x , point = None ):
71
72
for axis in self ._active_dims :
72
73
x = extend_axis_rev (x , axis = axis )
73
74
return x
74
-
75
+
75
76
def backward (self , z ):
76
77
z = aet .as_tensor_variable (z )
77
78
for axis in self ._active_dims :
78
79
z = extend_axis_aet (z , axis = axis )
79
80
return z
80
-
81
+
81
82
def jacobian_det (self , x ):
82
- return aet .constant (0. )
83
-
84
-
83
+ return aet .constant (0.0 )
84
+
85
+
85
86
class ZeroSumNormal (pm .Continuous ):
86
87
def __init__ (self , sigma = 1 , * , active_dims = None , active_axes = None , ** kwargs ):
87
88
shape = kwargs .get ("shape" , ())
88
89
dims = kwargs .get ("dims" , None )
89
90
if isinstance (shape , int ):
90
91
shape = (shape ,)
91
-
92
+
92
93
if isinstance (dims , str ):
93
94
dims = (dims ,)
94
95
95
96
self .mu = self .median = self .mode = aet .zeros (shape )
96
97
self .sigma = aet .as_tensor_variable (sigma )
97
-
98
+
98
99
if active_dims is None and active_axes is None :
99
100
if shape :
100
101
active_axes = (- 1 ,)
101
102
else :
102
103
active_axes = ()
103
-
104
+
104
105
if isinstance (active_axes , int ):
105
106
active_axes = (active_axes ,)
106
-
107
+
107
108
if isinstance (active_dims , str ):
108
109
active_dims = (active_dims ,)
109
-
110
+
110
111
if active_axes is not None and active_dims is not None :
111
112
raise ValueError ("Only one of active_axes and active_dims can be specified." )
112
-
113
+
113
114
if active_dims is not None :
114
115
model = pm .modelcontext (None )
115
116
print (model .RV_dims )
@@ -118,19 +119,19 @@ def __init__(self, sigma=1, *, active_dims=None, active_axes=None, **kwargs):
118
119
active_axes = []
119
120
for dim in active_dims :
120
121
active_axes .append (dims .index (dim ))
121
-
122
+
122
123
super ().__init__ (** kwargs , transform = ZeroSumTransform (active_axes ))
123
124
124
125
def logp (self , x ):
125
126
return pm .Normal .dist (sigma = self .sigma ).logp (x )
126
-
127
+
127
128
@staticmethod
128
129
def _random (scale , size ):
129
130
samples = stats .norm .rvs (loc = 0 , scale = scale , size = size )
130
131
return samples - np .mean (samples , axis = - 1 , keepdims = True )
131
-
132
+
132
133
def random (self , point = None , size = None ):
133
- sigma , = draw_values ([self .sigma ], point = point , size = size )
134
+ ( sigma ,) = draw_values ([self .sigma ], point = point , size = size )
134
135
return generate_samples (self ._random , scale = sigma , dist_shape = self .shape , size = size )
135
136
136
137
def _distr_parameters_for_repr (self ):
0 commit comments