1
+ import pymc3 as pm
2
+ import numpy as np
3
+ import pandas as pd
4
+ from typing import *
5
+ import aesara
6
+ import aesara .tensor as aet
7
+
8
+
9
+ def ZeroSumNormal (
10
+ name : str ,
11
+ sigma : Optional [float ] = None ,
12
+ * ,
13
+ dims : Union [str , Tuple [str ]],
14
+ model : Optional [pm .Model ] = None ,
15
+ ):
16
+ """
17
+ Multivariate normal, such that sum(x, axis=-1) = 0.
18
+
19
+ Parameters
20
+ ----------
21
+ name: str
22
+ String name representation of the PyMC variable.
23
+ sigma: Optional[float], defaults to None
24
+ Scale for the Normal distribution. If ``None``, a standard Normal is used.
25
+ dims: Union[str, Tuple[str]]
26
+ Dimension names for the shape of the distribution.
27
+ See https://docs.pymc.io/pymc-examples/examples/pymc3_howto/data_container.html for an example.
28
+ model: Optional[pm.Model], defaults to None
29
+ PyMC model instance. If ``None``, a model instance is created.
30
+ """
31
+ if isinstance (dims , str ):
32
+ dims = (dims ,)
33
+
34
+ model = pm .modelcontext (model )
35
+ * dims_pre , dim = dims
36
+ dim_trunc = f"{ dim } _truncated_"
37
+ (shape ,) = model .shape_from_dims ((dim ,))
38
+ assert shape >= 1
39
+
40
+ model .add_coords ({f"{ dim } _truncated_" : pd .RangeIndex (shape - 1 )})
41
+ raw = pm .Normal (f"{ name } _truncated_" , dims = tuple (dims_pre ) + (dim_trunc ,), sigma = sigma )
42
+ Q = make_sum_zero_hh (shape )
43
+ draws = aet .dot (raw , Q [:, 1 :].T )
44
+
45
+ #if sigma is not None:
46
+ # draws = sigma * draws
47
+
48
+ return pm .Deterministic (name , draws , dims = dims )
49
+
50
+
51
+
52
+ def make_sum_zero_hh (N : int ) -> np .ndarray :
53
+ """
54
+ Build a householder transformation matrix that maps e_1 to a vector of all 1s.
55
+ """
56
+ e_1 = np .zeros (N )
57
+ e_1 [0 ] = 1
58
+ a = np .ones (N )
59
+ a /= np .sqrt (a @ a )
60
+ v = a + e_1
61
+ v /= np .sqrt (v @ v )
62
+ return np .eye (N ) - 2 * np .outer (v , v )
63
+
64
+ def make_sum_zero_hh (N : int ) -> np .ndarray :
65
+ """
66
+ Build a householder transformation matrix that maps e_1 to a vector of all 1s.
67
+ """
68
+ e_1 = np .zeros (N )
69
+ e_1 [0 ] = 1
70
+ a = np .ones (N )
71
+ a /= np .sqrt (a @ a )
72
+ v = a + e_1
73
+ v /= np .sqrt (v @ v )
74
+ return np .eye (N ) - 2 * np .outer (v , v )
0 commit comments