import numpy as np
import scipy

+from aeppl.logprob import _logprob
from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.op import Op
from aesara.raise_op import Assert
from aesara.tensor.nlinalg import det, eigh, matrix_inverse, trace
from aesara.tensor.random.basic import dirichlet, multinomial, multivariate_normal
from aesara.tensor.random.op import RandomVariable, default_supp_shape_from_params
-from aesara.tensor.random.utils import broadcast_params, normalize_size_param
+from aesara.tensor.random.utils import broadcast_params
from aesara.tensor.slinalg import Cholesky, SolveTriangular
from aesara.tensor.type import TensorType
from scipy import linalg, stats
@@ ... @@
    logpow,
    multigammaln,
)
-from pymc.distributions.distribution import Continuous, Discrete, moment
+from pymc.distributions.distribution import (
+    Continuous,
+    Discrete,
+    Distribution,
+    SymbolicRandomVariable,
+    _moment,
+    moment,
+)
from pymc.distributions.logprob import ignore_logprob
from pymc.distributions.shape_utils import (
+    _change_dist_size,
    broadcast_dist_samples_to,
    change_dist_size,
    rv_size_is_none,
@@ -1097,12 +1106,12 @@ def _lkj_normalizing_constant(eta, n):
    return result


-class _LKJCholeskyCovRV(RandomVariable):
-    name = "_lkjcholeskycov"
+class _LKJCholeskyCovBaseRV(RandomVariable):
+    name = "_lkjcholeskycovbase"
    ndim_supp = 1
    ndims_params = [0, 0, 1]
    dtype = "floatX"
-    _print_name = ("_lkjcholeskycov", "\\operatorname{_lkjcholeskycov}")
+    _print_name = ("_lkjcholeskycovbase", "\\operatorname{_lkjcholeskycovbase}")

    def make_node(self, rng, size, dtype, n, eta, D):
        n = at.as_tensor_variable(n)
@@ -1115,35 +1124,19 @@ def make_node(self, rng, size, dtype, n, eta, D):
        D = at.as_tensor_variable(D)

-        # We resize the sd_dist `D` automatically so that it has (size x n) independent
-        # draws, which is what the `_LKJCholeskyCovRV.rng_fn` expects. This makes the
-        # random and logp methods equivalent, as the latter also assumes a unique value
-        # for each diagonal element.
-        # Since `eta` and `n` are forced to be scalars we don't need to worry about
-        # implied batched dimensions for the time being.
-        size = normalize_size_param(size)
-        if D.owner.op.ndim_supp == 0:
-            D = change_dist_size(D, at.concatenate((size, (n,))))
-        else:
-            # The support shape must be `n` but we have no way of controlling it
-            D = change_dist_size(D, size)
-
        return super().make_node(rng, size, dtype, n, eta, D)

-    def _infer_shape(self, size, dist_params, param_shapes=None):
+    def _supp_shape_from_params(self, dist_params, param_shapes):
        n = dist_params[0]
-        dist_shape = tuple(size) + ((n * (n + 1)) // 2,)
-        return dist_shape
+        return ((n * (n + 1)) // 2,)
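Note: `_supp_shape_from_params` returns the length of a packed Cholesky factor. A draw for `n` correlated variables is a flat vector holding the lower triangle of an `n x n` matrix, hence `n * (n + 1) / 2` entries. A quick standalone check (plain NumPy, not part of the diff):

```python
import numpy as np

n = 3
packed_len = (n * (n + 1)) // 2  # 3 diagonal + 3 strictly-lower entries = 6
rows, cols = np.tril_indices(n)
assert packed_len == len(rows)  # the packed vector is exactly the lower triangle
```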

    def rng_fn(self, rng, n, eta, D, size):
        # We flatten the size to make operations easier, and then rebuild it
        if size is None:
-            flat_size = 1
-        else:
-            flat_size = np.prod(size)
-
-        C = LKJCorrRV._random_corr_matrix(rng, n, eta, flat_size)
+            size = D.shape[:-1]
+        flat_size = np.prod(size).astype(int)

+        C = LKJCorrRV._random_corr_matrix(rng=rng, n=n, eta=eta, flat_size=flat_size)
        D = D.reshape(flat_size, n)
        C *= D[..., :, np.newaxis] * D[..., np.newaxis, :]
@@ -1159,23 +1152,30 @@ def rng_fn(self, rng, n, eta, D, size):
        return samples


-_ljk_cholesky_cov = _LKJCholeskyCovRV()
+_lkj_cholesky_cov_base = _LKJCholeskyCovBaseRV()
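For orientation, `rng_fn` above composes three steps: draw an LKJ correlation matrix, rescale it by the `sd_dist` draws `D` into a covariance matrix, then (in the part of the hunk not shown) take the Cholesky factor and pack its lower triangle. A minimal NumPy sketch of that pipeline, using a fixed correlation matrix in place of an actual LKJ draw; the exact packing order used by the Op is an assumption here:

```python
import numpy as np

def packed_chol_cov(corr, sds):
    # Rescale correlations into a covariance matrix, mirroring
    # `C *= D[..., :, np.newaxis] * D[..., np.newaxis, :]` in rng_fn.
    cov = corr * sds[:, None] * sds[None, :]
    chol = np.linalg.cholesky(cov)           # lower-triangular factor
    rows, cols = np.tril_indices(cov.shape[-1])
    return chol[rows, cols]                  # flat vector, length n * (n + 1) / 2

corr = np.array([[1.0, 0.5], [0.5, 1.0]])    # stand-in for an LKJ draw
print(packed_chol_cov(corr, sds=np.array([1.0, 2.0])))
```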


-class _LKJCholeskyCov(Continuous):
+# _LKJCholeskyCovBaseRV requires a properly shaped `D`, which means the variable can't
+# be safely resized. Because of this, we add the thin SymbolicRandomVariable wrapper.
+class _LKJCholeskyCovRV(SymbolicRandomVariable):
+    default_output = 1
+    _print_name = ("_lkjcholeskycov", "\\operatorname{_lkjcholeskycov}")
+
+    def update(self, node):
+        return {node.inputs[0]: node.outputs[0]}
+
+
+class _LKJCholeskyCov(Distribution):
    r"""Underlying class for covariance matrix with LKJ distributed correlations.
    See docs for LKJCholeskyCov function for more details on how to use it in models.
    """
-    rv_op = _ljk_cholesky_cov

-    def __new__(cls, name, eta, n, sd_dist, **kwargs):
-        check_dist_not_registered(sd_dist)
-        return super().__new__(cls, name, eta, n, sd_dist, **kwargs)
+    rv_type = _LKJCholeskyCovRV

    @classmethod
-    def dist(cls, eta, n, sd_dist, **kwargs):
-        eta = at.as_tensor_variable(floatX(eta))
+    def dist(cls, n, eta, sd_dist, **kwargs):
        n = at.as_tensor_variable(intX(n))
+        eta = at.as_tensor_variable(floatX(eta))

        if not (
            isinstance(sd_dist, Variable)
@@ -1185,75 +1185,105 @@ def dist(cls, eta, n, sd_dist, **kwargs):
        ):
            raise TypeError("sd_dist must be a scalar or vector distribution variable")

+        check_dist_not_registered(sd_dist)
        # sd_dist is part of the generative graph, but should be completely ignored
        # by the logp graph, since the LKJ logp explicitly includes these terms.
-        # TODO: Things could be simplified a bit if we managed to extract the
-        # sd_dist prior components from the logp expression.
        sd_dist = ignore_logprob(sd_dist)
-
        return super().dist([n, eta, sd_dist], **kwargs)

-    def moment(rv, size, n, eta, sd_dists):
-        diag_idxs = (at.cumsum(at.arange(1, n + 1)) - 1).astype("int32")
-        moment = at.zeros_like(rv)
-        moment = at.set_subtensor(moment[..., diag_idxs], 1)
-        return moment
+    @classmethod
+    def rv_op(cls, n, eta, sd_dist, size=None):
+        # We resize the sd_dist automatically so that it has (size x n) independent
+        # draws, which is what the `_LKJCholeskyCovBaseRV.rng_fn` expects. This makes
+        # the random and logp methods equivalent, as the latter also assumes a unique
+        # value for each diagonal element.
+        # Since `eta` and `n` are forced to be scalars we don't need to worry about
+        # implied batched dimensions from those for the time being.
+        if size is None:
+            size = sd_dist.shape[:-1]
+        shape = tuple(size) + (n,)
+        if sd_dist.owner.op.ndim_supp == 0:
+            sd_dist = change_dist_size(sd_dist, shape)
+        else:
+            # The support shape must be `n` but we have no way of controlling it
+            sd_dist = change_dist_size(sd_dist, shape[:-1])

-    def logp(value, n, eta, sd_dist):
-        """
-        Calculate log-probability of Covariance matrix with LKJ
-        distributed correlations at specified value.
+        # Create a new rng for the _lkjcholeskycov internal RV
+        rng = aesara.shared(np.random.default_rng())

-        Parameters
-        ----------
-        value: numeric
-            Value for which log-probability is calculated.
+        rng_, n_, eta_, sd_dist_ = rng.type(), n.type(), eta.type(), sd_dist.type()
+        next_rng_, lkjcov_ = _lkj_cholesky_cov_base(n_, eta_, sd_dist_, rng=rng_).owner.outputs

-        Returns
-        -------
-        TensorVariable
-        """
+        return _LKJCholeskyCovRV(
+            inputs=[rng_, n_, eta_, sd_dist_],
+            outputs=[next_rng_, lkjcov_],
+            ndim_supp=1,
+        )(rng, n, eta, sd_dist)
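`rv_op` builds the base RV on dummy inputs (`rng_`, `n_`, ...) and wraps the resulting graph in `_LKJCholeskyCovRV`, so later resizes recreate the variable through `rv_op` instead of mutating it in place; the `update` method above tells PyMC that the first output is the advanced state of the rng input. From the model side this stays invisible. A hedged usage sketch (the three-way `compute_corr` unpacking follows the current PyMC docs and is an assumption for this version):

```python
import numpy as np
import pymc as pm

with pm.Model():
    # sd_dist is a "prototype" distribution; rv_op resizes it to (size x n)
    # i.i.d. draws before passing it to the base RV.
    sd_dist = pm.Exponential.dist(1.0)
    chol, corr, sds = pm.LKJCholeskyCov(
        "chol_cov", n=3, eta=2.0, sd_dist=sd_dist, compute_corr=True
    )
    y = pm.MvNormal("y", mu=np.zeros(3), chol=chol, observed=np.random.randn(10, 3))
```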

-        if value.ndim > 1:
-            raise ValueError("LKJCholeskyCov logp is only implemented for vector values (ndim=1)")

-        diag_idxs = at.cumsum(at.arange(1, n + 1)) - 1
-        cumsum = at.cumsum(value**2)
-        variance = at.zeros(at.atleast_1d(n))
-        variance = at.inc_subtensor(variance[0], value[0] ** 2)
-        variance = at.inc_subtensor(variance[1:], cumsum[diag_idxs[1:]] - cumsum[diag_idxs[:-1]])
-        sd_vals = at.sqrt(variance)
+@_change_dist_size.register(_LKJCholeskyCovRV)
+def change_LKJCholeskyCovRV_size(op, dist, new_size, expand=False):
+    n, eta, sd_dist = dist.owner.inputs[1:]

-        logp_sd = pm.logp(sd_dist, sd_vals).sum()
-        corr_diag = value[diag_idxs] / sd_vals
+    if expand:
+        old_size = sd_dist.shape[:-1]
+        new_size = tuple(new_size) + tuple(old_size)

-        logp_lkj = (2 * eta - 3 + n - at.arange(n)) * at.log(corr_diag)
-        logp_lkj = at.sum(logp_lkj)
+    return _LKJCholeskyCov.rv_op(n, eta, sd_dist, size=new_size)
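Resizing therefore goes through `rv_op` again: the batch shape lives entirely in `sd_dist`, and `expand=True` prepends the new dims to the old batch shape. A sketch of the intended semantics using the internal `.dist` API (internal, so treat the exact import path and call as assumptions):

```python
import pymc as pm
from pymc.distributions.multivariate import _LKJCholeskyCov
from pymc.distributions.shape_utils import change_dist_size

sd_dist = pm.Exponential.dist(1.0)
packed = _LKJCholeskyCov.dist(n=3, eta=2.0, sd_dist=sd_dist)      # batch shape ()

resized = change_dist_size(packed, new_size=(5,))                 # batch shape (5,)
expanded = change_dist_size(resized, new_size=(2,), expand=True)  # batch shape (2, 5)
```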
1232
1233
1233
- # Compute the log det jacobian of the second transformation
1234
- # described in the docstring.
1235
- idx = at .arange (n )
1236
- det_invjac = at .log (corr_diag ) - idx * at .log (sd_vals )
1237
- det_invjac = det_invjac .sum ()
1238
1234
1239
- # TODO: _lkj_normalizing_constant currently requires `eta` and `n` to be constants
1240
- if not isinstance (n , Constant ):
1241
- raise NotImplementedError ("logp only implemented for constant `n`" )
1242
- n = int (n .data )
1235
+ @_moment .register (_LKJCholeskyCovRV )
1236
+ def _LKJCholeksyCovRV_moment (op , rv , rng , n , eta , sd_dist ):
1237
+ diag_idxs = (at .cumsum (at .arange (1 , n + 1 )) - 1 ).astype ("int32" )
1238
+ moment = at .zeros_like (rv )
1239
+ moment = at .set_subtensor (moment [..., diag_idxs ], 1 )
1240
+ return moment
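The moment is the packed identity matrix: ones at the positions where diagonal elements land in the packed vector, zeros elsewhere. For `n = 3` the index arithmetic works out as follows (standalone NumPy check):

```python
import numpy as np

n = 3
diag_idxs = np.cumsum(np.arange(1, n + 1)) - 1  # [0, 2, 5]

moment = np.zeros((n * (n + 1)) // 2)
moment[diag_idxs] = 1.0
# moment == [1, 0, 1, 0, 0, 1]: the row-wise packed 3x3 identity matrix
```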

-        if not isinstance(eta, Constant):
-            raise NotImplementedError("logp only implemented for constant `eta`")
-        eta = float(eta.data)

-        norm = _lkj_normalizing_constant(eta, n)
+@_default_transform.register(_LKJCholeskyCovRV)
+def _LKJCholeskyCovRV_default_transform(op, rv):
+    _, n, _, _ = rv.owner.inputs
+    return transforms.CholeskyCovPacked(n)

-        return norm + logp_lkj + logp_sd + det_invjac

+@_logprob.register(_LKJCholeskyCovRV)
+def _LKJCholeskyCovRV_logp(op, values, rng, n, eta, sd_dist, **kwargs):
+    (value,) = values

-@_default_transform.register(_LKJCholeskyCov)
-def lkjcholeskycov_default_transform(op, rv):
-    _, _, _, n, _, _ = rv.owner.inputs
-    return transforms.CholeskyCovPacked(n)
+    if value.ndim > 1:
+        raise ValueError("_LKJCholeskyCov logp is only implemented for vector values (ndim=1)")
+
+    diag_idxs = at.cumsum(at.arange(1, n + 1)) - 1
+    cumsum = at.cumsum(value**2)
+    variance = at.zeros(at.atleast_1d(n))
+    variance = at.inc_subtensor(variance[0], value[0] ** 2)
+    variance = at.inc_subtensor(variance[1:], cumsum[diag_idxs[1:]] - cumsum[diag_idxs[:-1]])
+    sd_vals = at.sqrt(variance)
+
+    logp_sd = pm.logp(sd_dist, sd_vals).sum()
+    corr_diag = value[diag_idxs] / sd_vals
+
+    logp_lkj = (2 * eta - 3 + n - at.arange(n)) * at.log(corr_diag)
+    logp_lkj = at.sum(logp_lkj)
+
+    # Compute the log det jacobian of the second transformation
+    # described in the docstring.
+    idx = at.arange(n)
+    det_invjac = at.log(corr_diag) - idx * at.log(sd_vals)
+    det_invjac = det_invjac.sum()
+
+    # TODO: _lkj_normalizing_constant currently requires `eta` and `n` to be constants
+    if not isinstance(n, Constant):
+        raise NotImplementedError("logp only implemented for constant `n`")
+    n = int(n.data)
+
+    if not isinstance(eta, Constant):
+        raise NotImplementedError("logp only implemented for constant `eta`")
+    eta = float(eta.data)
+
+    norm = _lkj_normalizing_constant(eta, n)
+
+    return norm + logp_lkj + logp_sd + det_invjac
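The cumsum manipulation at the top of the logp recovers the implied standard deviations from the packed value: the squared entries of row `i` of the Cholesky factor sum to the variance of variable `i`. A standalone NumPy check of that step, assuming row-wise packing of the lower triangle:

```python
import numpy as np

n = 3
L = np.array([[2.0, 0.0, 0.0],
              [0.6, 1.5, 0.0],
              [0.3, 0.2, 0.8]])
value = L[np.tril_indices(n)]  # packed row-wise: [2.0, 0.6, 1.5, 0.3, 0.2, 0.8]

diag_idxs = np.cumsum(np.arange(1, n + 1)) - 1
cumsum = np.cumsum(value**2)
variance = np.zeros(n)
variance[0] = value[0] ** 2
variance[1:] = cumsum[diag_idxs[1:]] - cumsum[diag_idxs[:-1]]

# Same as the row norms of L: sd_i = sqrt(sum_j L[i, j] ** 2)
assert np.allclose(np.sqrt(variance), np.sqrt((L**2).sum(axis=1)))
```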


class LKJCholeskyCov:
@@ -1462,7 +1492,7 @@ def rng_fn(cls, rng, n, eta, size):
        else:
            flat_size = np.prod(size)

-        C = cls._random_corr_matrix(rng, n, eta, flat_size)
+        C = cls._random_corr_matrix(rng=rng, n=n, eta=eta, flat_size=flat_size)

        triu_idx = np.triu_indices(n, k=1)
        samples = C[..., triu_idx[0], triu_idx[1]]