Skip to content

Commit 9d279db

Browse files
keep transforms.ordered for backward compatibility
1 parent 7ef9b2e commit 9d279db

File tree

2 files changed

+21
-7
lines changed

2 files changed

+21
-7
lines changed

pymc/distributions/transforms.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,14 +352,28 @@ def extend_axis_rev(array, axis):
352352
Instantiation of :class:`pymc.distributions.transforms.Ordered`
353353
for use in the ``transform`` argument of a multivariate random variable."""
354354

355-
ordered = univariate_ordered
355+
ordered = Ordered(ndim_supp=1)
356+
ordered.__doc__ = """
357+
Instantiation of :class:`pymc.distributions.transforms.Ordered`
358+
for use in the ``transform`` argument. """
359+
356360

357361
log = LogTransform()
358362
log.__doc__ = """
359363
Instantiation of :class:`aeppl.transforms.LogTransform`
360364
for use in the ``transform`` argument of a random variable."""
361365

362-
sum_to_1 = SumTo1()
366+
univariate_sum_to_1 = SumTo1(ndim_supp=0)
367+
univariate_sum_to_1.__doc__ = """
368+
Instantiation of :class:`pymc.distributions.transforms.SumTo1`
369+
for use in the ``transform`` argument of a univariate random variable."""
370+
371+
multivariate_sum_to_1 = SumTo1(ndim_supp=1)
372+
multivariate_sum_to_1.__doc__ = """
373+
Instantiation of :class:`pymc.distributions.transforms.SumTo1`
374+
for use in the ``transform`` argument of a multivariate random variable."""
375+
376+
sum_to_1 = SumTo1(ndim_supp=1)
363377
sum_to_1.__doc__ = """
364378
Instantiation of :class:`pymc.distributions.transforms.SumTo1`
365379
for use in the ``transform`` argument of a random variable."""

pymc/tests/distributions/test_transform.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def check_vectortransform_elementwise_logp(self, model):
327327
jacob_det = transform.log_jac_det(test_array_transf, *x.owner.inputs)
328328
# Original distribution is univariate
329329
if x.owner.op.ndim_supp == 0:
330-
assert model.logp(x, sum=False)[0].ndim == x.ndim == jacob_det.ndim
330+
assert model.logp(x, sum=False)[0].ndim == x.ndim == (jacob_det.ndim + 1)
331331
# Original distribution is multivariate
332332
else:
333333
assert model.logp(x, sum=False)[0].ndim == (x.ndim - 1) == jacob_det.ndim
@@ -573,7 +573,7 @@ def test_mvnormal_ordered(self, mu, cov, size, shape):
573573
{"mu": mu, "cov": cov},
574574
size=size,
575575
initval=initval,
576-
transform=tr.multivariate_ordered,
576+
transform=tr.ordered,
577577
)
578578
self.check_vectortransform_elementwise_logp(model)
579579

@@ -607,11 +607,11 @@ def test_discrete_trafo():
607607
def test_transforms_ordered():
608608
with pm.Model() as model:
609609
pm.Normal(
610-
"x",
610+
"x_univariate",
611611
mu=[-3, -1, 1, 2],
612612
sigma=1,
613613
size=(10, 4),
614-
transform=pm.distributions.transforms.ordered,
614+
transform=tr.univariate_ordered,
615615
)
616616

617617
log_prob = model.point_logps()
@@ -625,7 +625,7 @@ def test_transforms_sumto1():
625625
mu=[-3, -1, 1, 2],
626626
sigma=1,
627627
size=(10, 4),
628-
transform=pm.distributions.transforms.sum_to_1,
628+
transform=tr.univariate_sum_to_1,
629629
)
630630

631631
log_prob = model.point_logps()

0 commit comments

Comments
 (0)