13
13
# limitations under the License.
14
14
15
15
16
+ from typing import Union
17
+
16
18
import aesara
17
19
import aesara .tensor as at
18
20
import numpy as np
@@ -139,10 +141,12 @@ def test_simplex_accuracy():
139
141
140
142
141
143
def test_sum_to_1 ():
142
- check_vector_transform (tr .sum_to_1 , Simplex (2 ))
143
- check_vector_transform (tr .sum_to_1 , Simplex (4 ))
144
+ check_vector_transform (tr .univariate_sum_to_1 , Simplex (2 ))
145
+ check_vector_transform (tr .univariate_sum_to_1 , Simplex (4 ))
144
146
145
- check_jacobian_det (tr .sum_to_1 , Vector (Unit , 2 ), at .dvector , np .array ([0 , 0 ]), lambda x : x [:- 1 ])
147
+ check_jacobian_det (
148
+ tr .univariate_sum_to_1 , Vector (Unit , 2 ), at .dvector , np .array ([0 , 0 ]), lambda x : x [:- 1 ]
149
+ )
146
150
147
151
148
152
def test_log ():
@@ -241,28 +245,30 @@ def test_circular():
241
245
242
246
243
247
def test_ordered ():
244
- check_vector_transform (tr .ordered , SortedVector (6 ))
248
+ check_vector_transform (tr .univariate_ordered , SortedVector (6 ))
245
249
246
- check_jacobian_det (tr .ordered , Vector (R , 2 ), at .dvector , np .array ([0 , 0 ]), elemwise = False )
250
+ check_jacobian_det (
251
+ tr .univariate_ordered , Vector (R , 2 ), at .dvector , np .array ([0 , 0 ]), elemwise = False
252
+ )
247
253
248
- vals = get_values (tr .ordered , Vector (R , 3 ), at .dvector , np .zeros (3 ))
254
+ vals = get_values (tr .univariate_ordered , Vector (R , 3 ), at .dvector , np .zeros (3 ))
249
255
close_to_logical (np .diff (vals ) >= 0 , True , tol )
250
256
251
257
252
258
def test_chain_values ():
253
- chain_tranf = tr .Chain ([tr .logodds , tr .ordered ])
259
+ chain_tranf = tr .Chain ([tr .logodds , tr .univariate_ordered ])
254
260
vals = get_values (chain_tranf , Vector (R , 5 ), at .dvector , np .zeros (5 ))
255
261
close_to_logical (np .diff (vals ) >= 0 , True , tol )
256
262
257
263
258
264
def test_chain_vector_transform ():
259
- chain_tranf = tr .Chain ([tr .logodds , tr .ordered ])
265
+ chain_tranf = tr .Chain ([tr .logodds , tr .univariate_ordered ])
260
266
check_vector_transform (chain_tranf , UnitSortedVector (3 ))
261
267
262
268
263
269
@pytest .mark .xfail (reason = "Fails due to precision issue. Values just close to expected." )
264
270
def test_chain_jacob_det ():
265
- chain_tranf = tr .Chain ([tr .logodds , tr .ordered ])
271
+ chain_tranf = tr .Chain ([tr .logodds , tr .univariate_ordered ])
266
272
check_jacobian_det (chain_tranf , Vector (R , 4 ), at .dvector , np .zeros (4 ), elemwise = False )
267
273
268
274
@@ -327,7 +333,14 @@ def check_vectortransform_elementwise_logp(self, model):
327
333
jacob_det = transform .log_jac_det (test_array_transf , * x .owner .inputs )
328
334
# Original distribution is univariate
329
335
if x .owner .op .ndim_supp == 0 :
330
- assert model .logp (x , sum = False )[0 ].ndim == x .ndim == (jacob_det .ndim + 1 )
336
+ tr_steps = getattr (transform , "transform_list" , [transform ])
337
+ transform_keeps_dim = any (
338
+ [isinstance (ts , Union [tr .SumTo1 , tr .Ordered ]) for ts in tr_steps ]
339
+ )
340
+ if transform_keeps_dim :
341
+ assert model .logp (x , sum = False )[0 ].ndim == x .ndim == jacob_det .ndim
342
+ else :
343
+ assert model .logp (x , sum = False )[0 ].ndim == x .ndim == (jacob_det .ndim + 1 )
331
344
# Original distribution is multivariate
332
345
else :
333
346
assert model .logp (x , sum = False )[0 ].ndim == (x .ndim - 1 ) == jacob_det .ndim
@@ -449,7 +462,7 @@ def test_normal_ordered(self):
449
462
{"mu" : 0.0 , "sigma" : 1.0 },
450
463
size = 3 ,
451
464
initval = np .asarray ([- 1.0 , 1.0 , 4.0 ]),
452
- transform = tr .ordered ,
465
+ transform = tr .univariate_ordered ,
453
466
)
454
467
self .check_vectortransform_elementwise_logp (model )
455
468
@@ -467,7 +480,7 @@ def test_half_normal_ordered(self, sigma, size):
467
480
{"sigma" : sigma },
468
481
size = size ,
469
482
initval = initval ,
470
- transform = tr .Chain ([tr .log , tr .ordered ]),
483
+ transform = tr .Chain ([tr .log , tr .univariate_ordered ]),
471
484
)
472
485
self .check_vectortransform_elementwise_logp (model )
473
486
@@ -479,7 +492,7 @@ def test_exponential_ordered(self, lam, size):
479
492
{"lam" : lam },
480
493
size = size ,
481
494
initval = initval ,
482
- transform = tr .Chain ([tr .log , tr .ordered ]),
495
+ transform = tr .Chain ([tr .log , tr .univariate_ordered ]),
483
496
)
484
497
self .check_vectortransform_elementwise_logp (model )
485
498
@@ -501,7 +514,7 @@ def test_beta_ordered(self, a, b, size):
501
514
{"alpha" : a , "beta" : b },
502
515
size = size ,
503
516
initval = initval ,
504
- transform = tr .Chain ([tr .logodds , tr .ordered ]),
517
+ transform = tr .Chain ([tr .logodds , tr .univariate_ordered ]),
505
518
)
506
519
self .check_vectortransform_elementwise_logp (model )
507
520
@@ -524,7 +537,7 @@ def transform_params(*inputs):
524
537
{"lower" : lower , "upper" : upper },
525
538
size = size ,
526
539
initval = initval ,
527
- transform = tr .Chain ([interval , tr .ordered ]),
540
+ transform = tr .Chain ([interval , tr .univariate_ordered ]),
528
541
)
529
542
self .check_vectortransform_elementwise_logp (model )
530
543
@@ -536,7 +549,7 @@ def test_vonmises_ordered(self, mu, kappa, size):
536
549
{"mu" : mu , "kappa" : kappa },
537
550
size = size ,
538
551
initval = initval ,
539
- transform = tr .Chain ([tr .circular , tr .ordered ]),
552
+ transform = tr .Chain ([tr .circular , tr .univariate_ordered ]),
540
553
)
541
554
self .check_vectortransform_elementwise_logp (model )
542
555
@@ -545,7 +558,7 @@ def test_vonmises_ordered(self, mu, kappa, size):
545
558
[
546
559
(0.0 , 1.0 , (2 ,), tr .simplex ),
547
560
(0.5 , 5.5 , (2 , 3 ), tr .simplex ),
548
- (np .zeros (3 ), np .ones (3 ), (4 , 3 ), tr .Chain ([tr .sum_to_1 , tr .logodds ])),
561
+ (np .zeros (3 ), np .ones (3 ), (4 , 3 ), tr .Chain ([tr .univariate_sum_to_1 , tr .logodds ])),
549
562
],
550
563
)
551
564
def test_uniform_other (self , lower , upper , size , transform ):
@@ -573,7 +586,7 @@ def test_mvnormal_ordered(self, mu, cov, size, shape):
573
586
{"mu" : mu , "cov" : cov },
574
587
size = size ,
575
588
initval = initval ,
576
- transform = tr .ordered ,
589
+ transform = tr .multivariate_ordered ,
577
590
)
578
591
self .check_vectortransform_elementwise_logp (model )
579
592
0 commit comments