Skip to content

Commit 0372528

Browse files
univariate_ordered and multivariate_ordered
1 parent 2d2bab4 commit 0372528

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

pymc/distributions/transforms.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -336,10 +336,17 @@ def extend_axis_rev(array, axis):
336336
Instantiation of :class:`pymc.distributions.transforms.LogExpM1`
337337
for use in the ``transform`` argument of a random variable."""
338338

339-
ordered = Ordered()
340-
ordered.__doc__ = """
339+
univariate_ordered = Ordered(ndim_supp=0)
340+
univariate_ordered.__doc__ = """
341341
Instantiation of :class:`pymc.distributions.transforms.Ordered`
342-
for use in the ``transform`` argument of a random variable."""
342+
for use in the ``transform`` argument of a univariate random variable."""
343+
344+
multivariate_ordered = Ordered(ndim_supp=1)
345+
multivariate_ordered.__doc__ = """
346+
Instantiation of :class:`pymc.distributions.transforms.Ordered`
347+
for use in the ``transform`` argument of a multivariate random variable."""
348+
349+
ordered = univariate_ordered
343350

344351
log = LogTransform()
345352
log.__doc__ = """

pymc/tests/distributions/test_transform.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def check_vectortransform_elementwise_logp(self, model):
327327
jacob_det = transform.log_jac_det(test_array_transf, *x.owner.inputs)
328328
# Original distribution is univariate
329329
if x.owner.op.ndim_supp == 0:
330-
assert model.logp(x, sum=False)[0].ndim == x.ndim == (jacob_det.ndim + 1)
330+
assert model.logp(x, sum=False)[0].ndim == x.ndim == jacob_det.ndim
331331
# Original distribution is multivariate
332332
else:
333333
assert model.logp(x, sum=False)[0].ndim == (x.ndim - 1) == jacob_det.ndim
@@ -569,7 +569,11 @@ def test_uniform_other(self, lower, upper, size, transform):
569569
def test_mvnormal_ordered(self, mu, cov, size, shape):
570570
initval = np.sort(np.random.randn(*shape))
571571
model = self.build_model(
572-
pm.MvNormal, {"mu": mu, "cov": cov}, size=size, initval=initval, transform=tr.ordered
572+
pm.MvNormal,
573+
{"mu": mu, "cov": cov},
574+
size=size,
575+
initval=initval,
576+
transform=tr.multivariate_ordered,
573577
)
574578
self.check_vectortransform_elementwise_logp(model)
575579

0 commit comments

Comments
 (0)