Skip to content

Commit 7515c73

Browse files
pre-commit run
1 parent 9d279db commit 7515c73

File tree

2 files changed

+41
-28
lines changed

2 files changed

+41
-28
lines changed

pymc/distributions/transforms.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,11 @@ class Ordered(RVTransform):
7575
name = "ordered"
7676

7777
def __init__(self, ndim_supp=0):
78+
if ndim_supp > 1:
79+
raise ValueError(
80+
f"For Ordered transformation number of core dimensions"
81+
f"(ndim_supp) must not exceed 1 but is {ndim_supp}"
82+
)
7883
self.ndim_supp = ndim_supp
7984

8085
def backward(self, value, *inputs):
@@ -105,6 +110,11 @@ class SumTo1(RVTransform):
105110
name = "sumto1"
106111

107112
def __init__(self, ndim_supp=0):
113+
if ndim_supp > 1:
114+
raise ValueError(
115+
f"For SumTo1 transformation number of core dimensions"
116+
f"(ndim_supp) must not exceed 1 but is {ndim_supp}"
117+
)
108118
self.ndim_supp = ndim_supp
109119

110120
def backward(self, value, *inputs):
@@ -352,11 +362,6 @@ def extend_axis_rev(array, axis):
352362
Instantiation of :class:`pymc.distributions.transforms.Ordered`
353363
for use in the ``transform`` argument of a multivariate random variable."""
354364

355-
ordered = Ordered(ndim_supp=1)
356-
ordered.__doc__ = """
357-
Instantiation of :class:`pymc.distributions.transforms.Ordered`
358-
for use in the ``transform`` argument. """
359-
360365

361366
log = LogTransform()
362367
log.__doc__ = """
@@ -373,11 +378,6 @@ def extend_axis_rev(array, axis):
373378
Instantiation of :class:`pymc.distributions.transforms.SumTo1`
374379
for use in the ``transform`` argument of a multivariate random variable."""
375380

376-
sum_to_1 = SumTo1(ndim_supp=1)
377-
sum_to_1.__doc__ = """
378-
Instantiation of :class:`pymc.distributions.transforms.SumTo1`
379-
for use in the ``transform`` argument of a random variable."""
380-
381381
circular = CircularTransform()
382382
circular.__doc__ = """
383383
Instantiation of :class:`aeppl.transforms.CircularTransform`

pymc/tests/distributions/test_transform.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414

1515

16+
from typing import Union
17+
1618
import aesara
1719
import aesara.tensor as at
1820
import numpy as np
@@ -139,10 +141,12 @@ def test_simplex_accuracy():
139141

140142

141143
def test_sum_to_1():
142-
check_vector_transform(tr.sum_to_1, Simplex(2))
143-
check_vector_transform(tr.sum_to_1, Simplex(4))
144+
check_vector_transform(tr.univariate_sum_to_1, Simplex(2))
145+
check_vector_transform(tr.univariate_sum_to_1, Simplex(4))
144146

145-
check_jacobian_det(tr.sum_to_1, Vector(Unit, 2), at.dvector, np.array([0, 0]), lambda x: x[:-1])
147+
check_jacobian_det(
148+
tr.univariate_sum_to_1, Vector(Unit, 2), at.dvector, np.array([0, 0]), lambda x: x[:-1]
149+
)
146150

147151

148152
def test_log():
@@ -241,28 +245,30 @@ def test_circular():
241245

242246

243247
def test_ordered():
244-
check_vector_transform(tr.ordered, SortedVector(6))
248+
check_vector_transform(tr.univariate_ordered, SortedVector(6))
245249

246-
check_jacobian_det(tr.ordered, Vector(R, 2), at.dvector, np.array([0, 0]), elemwise=False)
250+
check_jacobian_det(
251+
tr.univariate_ordered, Vector(R, 2), at.dvector, np.array([0, 0]), elemwise=False
252+
)
247253

248-
vals = get_values(tr.ordered, Vector(R, 3), at.dvector, np.zeros(3))
254+
vals = get_values(tr.univariate_ordered, Vector(R, 3), at.dvector, np.zeros(3))
249255
close_to_logical(np.diff(vals) >= 0, True, tol)
250256

251257

252258
def test_chain_values():
253-
chain_tranf = tr.Chain([tr.logodds, tr.ordered])
259+
chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered])
254260
vals = get_values(chain_tranf, Vector(R, 5), at.dvector, np.zeros(5))
255261
close_to_logical(np.diff(vals) >= 0, True, tol)
256262

257263

258264
def test_chain_vector_transform():
259-
chain_tranf = tr.Chain([tr.logodds, tr.ordered])
265+
chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered])
260266
check_vector_transform(chain_tranf, UnitSortedVector(3))
261267

262268

263269
@pytest.mark.xfail(reason="Fails due to precision issue. Values just close to expected.")
264270
def test_chain_jacob_det():
265-
chain_tranf = tr.Chain([tr.logodds, tr.ordered])
271+
chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered])
266272
check_jacobian_det(chain_tranf, Vector(R, 4), at.dvector, np.zeros(4), elemwise=False)
267273

268274

@@ -327,7 +333,14 @@ def check_vectortransform_elementwise_logp(self, model):
327333
jacob_det = transform.log_jac_det(test_array_transf, *x.owner.inputs)
328334
# Original distribution is univariate
329335
if x.owner.op.ndim_supp == 0:
330-
assert model.logp(x, sum=False)[0].ndim == x.ndim == (jacob_det.ndim + 1)
336+
tr_steps = getattr(transform, "transform_list", [transform])
337+
transform_keeps_dim = any(
338+
[isinstance(ts, Union[tr.SumTo1, tr.Ordered]) for ts in tr_steps]
339+
)
340+
if transform_keeps_dim:
341+
assert model.logp(x, sum=False)[0].ndim == x.ndim == jacob_det.ndim
342+
else:
343+
assert model.logp(x, sum=False)[0].ndim == x.ndim == (jacob_det.ndim + 1)
331344
# Original distribution is multivariate
332345
else:
333346
assert model.logp(x, sum=False)[0].ndim == (x.ndim - 1) == jacob_det.ndim
@@ -449,7 +462,7 @@ def test_normal_ordered(self):
449462
{"mu": 0.0, "sigma": 1.0},
450463
size=3,
451464
initval=np.asarray([-1.0, 1.0, 4.0]),
452-
transform=tr.ordered,
465+
transform=tr.univariate_ordered,
453466
)
454467
self.check_vectortransform_elementwise_logp(model)
455468

@@ -467,7 +480,7 @@ def test_half_normal_ordered(self, sigma, size):
467480
{"sigma": sigma},
468481
size=size,
469482
initval=initval,
470-
transform=tr.Chain([tr.log, tr.ordered]),
483+
transform=tr.Chain([tr.log, tr.univariate_ordered]),
471484
)
472485
self.check_vectortransform_elementwise_logp(model)
473486

@@ -479,7 +492,7 @@ def test_exponential_ordered(self, lam, size):
479492
{"lam": lam},
480493
size=size,
481494
initval=initval,
482-
transform=tr.Chain([tr.log, tr.ordered]),
495+
transform=tr.Chain([tr.log, tr.univariate_ordered]),
483496
)
484497
self.check_vectortransform_elementwise_logp(model)
485498

@@ -501,7 +514,7 @@ def test_beta_ordered(self, a, b, size):
501514
{"alpha": a, "beta": b},
502515
size=size,
503516
initval=initval,
504-
transform=tr.Chain([tr.logodds, tr.ordered]),
517+
transform=tr.Chain([tr.logodds, tr.univariate_ordered]),
505518
)
506519
self.check_vectortransform_elementwise_logp(model)
507520

@@ -524,7 +537,7 @@ def transform_params(*inputs):
524537
{"lower": lower, "upper": upper},
525538
size=size,
526539
initval=initval,
527-
transform=tr.Chain([interval, tr.ordered]),
540+
transform=tr.Chain([interval, tr.univariate_ordered]),
528541
)
529542
self.check_vectortransform_elementwise_logp(model)
530543

@@ -536,7 +549,7 @@ def test_vonmises_ordered(self, mu, kappa, size):
536549
{"mu": mu, "kappa": kappa},
537550
size=size,
538551
initval=initval,
539-
transform=tr.Chain([tr.circular, tr.ordered]),
552+
transform=tr.Chain([tr.circular, tr.univariate_ordered]),
540553
)
541554
self.check_vectortransform_elementwise_logp(model)
542555

@@ -545,7 +558,7 @@ def test_vonmises_ordered(self, mu, kappa, size):
545558
[
546559
(0.0, 1.0, (2,), tr.simplex),
547560
(0.5, 5.5, (2, 3), tr.simplex),
548-
(np.zeros(3), np.ones(3), (4, 3), tr.Chain([tr.sum_to_1, tr.logodds])),
561+
(np.zeros(3), np.ones(3), (4, 3), tr.Chain([tr.univariate_sum_to_1, tr.logodds])),
549562
],
550563
)
551564
def test_uniform_other(self, lower, upper, size, transform):
@@ -573,7 +586,7 @@ def test_mvnormal_ordered(self, mu, cov, size, shape):
573586
{"mu": mu, "cov": cov},
574587
size=size,
575588
initval=initval,
576-
transform=tr.ordered,
589+
transform=tr.multivariate_ordered,
577590
)
578591
self.check_vectortransform_elementwise_logp(model)
579592

0 commit comments

Comments
 (0)