Skip to content

Commit d9b7ef4

Browse files
committed
Implement Ordered distribution factory
1 parent 3f3aeb9 commit d9b7ef4

File tree

4 files changed

+238
-29
lines changed

4 files changed

+238
-29
lines changed

pymc/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@
103103
WishartBartlett,
104104
ZeroSumNormal,
105105
)
106+
from pymc.distributions.ordered import Ordered
106107
from pymc.distributions.simulator import Simulator
107108
from pymc.distributions.timeseries import (
108109
AR,
@@ -178,6 +179,7 @@
178179
"NegativeBinomial",
179180
"Normal",
180181
"NormalMixture",
182+
"Ordered",
181183
"OrderedLogistic",
182184
"OrderedMultinomial",
183185
"OrderedProbit",

pymc/distributions/discrete.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,8 +1239,7 @@ class OrderedLogistic:
12391239
12401240
# Ordered logistic regression
12411241
with pm.Model() as model:
1242-
cutpoints = pm.Normal("cutpoints", mu=[-1,1], sigma=10, shape=2,
1243-
transform=pm.distributions.transforms.ordered)
1242+
cutpoints = pm.Ordered("cutpoints", dist=pm.Normal.dist(mu=0, sigma=10), shape=2)
12441243
y_ = pm.OrderedLogistic("y", cutpoints=cutpoints, eta=x, observed=y)
12451244
idata = pm.sample()
12461245
@@ -1343,8 +1342,7 @@ class OrderedProbit:
13431342
13441343
# Ordered probit regression
13451344
with pm.Model() as model:
1346-
cutpoints = pm.Normal("cutpoints", mu=[-1,1], sigma=10, shape=2,
1347-
transform=pm.distributions.transforms.ordered)
1345+
cutpoints = pm.Ordered("cutpoints", dist=pm.Normal.dist(mu=0, sigma=10), shape=2)
13481346
y_ = pm.OrderedProbit("y", cutpoints=cutpoints, eta=x, observed=y)
13491347
idata = pm.sample()
13501348

pymc/distributions/ordered.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
# Copyright 2024 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import pytensor.tensor as pt
15+
16+
from pytensor.tensor.random.op import RandomVariable
17+
from pytensor.tensor.random.utils import normalize_size_param
18+
from pytensor.tensor.variable import TensorVariable
19+
20+
from pymc.distributions.distribution import (
21+
Distribution,
22+
SymbolicRandomVariable,
23+
_support_point,
24+
)
25+
from pymc.distributions.shape_utils import change_dist_size, get_support_shape_1d, rv_size_is_none
26+
from pymc.distributions.transforms import _default_transform, ordered
27+
28+
29+
class OrderedRV(SymbolicRandomVariable):
30+
inline_logprob = True
31+
extended_signature = "(x)->(x)"
32+
_print_name = ("Ordered", "\\operatorname{Ordered}")
33+
34+
@classmethod
35+
def rv_op(cls, dist, *, size=None):
36+
# We don't allow passing `rng` because we don't fully control the rng of the components!
37+
38+
size = normalize_size_param(size)
39+
40+
if not rv_size_is_none(size):
41+
core_shape = tuple(dist.shape)[-1]
42+
shape = (*tuple(size), core_shape)
43+
dist = change_dist_size(dist, shape)
44+
45+
sorted_rv = pt.sort(dist, axis=-1)
46+
47+
return OrderedRV(
48+
inputs=[dist],
49+
outputs=[sorted_rv],
50+
)(dist)
51+
52+
53+
class Ordered(Distribution):
54+
r"""Univariate IID Ordered distribution.
55+
56+
The pdf of the oredered distribution is
57+
58+
.. math::
59+
f(x_1, ..., x_n) = n!\prod_{i=1}^n f(x_{(i)}),
60+
where x_1 <= x2 <= ... <= x_n
61+
62+
Parameters
63+
----------
64+
dist: unnamed_distribution
65+
Univariate IID distribution which will be sorted.
66+
67+
.. warning:: dist will be cloned, rendering it independent of the one passade as input
68+
69+
Examples
70+
--------
71+
.. code-block:: python
72+
import pymc as pm
73+
74+
with pm.Model():
75+
x = pm.Normal.dist(mu=0, sigma=1) # Must be IID
76+
ordered_x = pm.Ordered("ordered_x", dist=x, shape=(3,))
77+
78+
pm.draw(ordered_x, random_seed=52) # array([0.05172346, 0.43970706, 0.91500416])
79+
"""
80+
81+
rv_type = OrderedRV
82+
rv_op = OrderedRV.rv_op
83+
84+
def __new__(cls, name, dist, *, support_shape=None, **kwargs):
85+
support_shape = get_support_shape_1d(
86+
support_shape=support_shape,
87+
shape=None, # shape will be checked in `cls.dist`
88+
dims=kwargs.get("dims", None),
89+
observed=kwargs.get("observed", None),
90+
)
91+
return super().__new__(cls, name, dist, support_shape=support_shape, **kwargs)
92+
93+
@classmethod
94+
def dist(cls, dist, *, support_shape=None, **kwargs):
95+
if not isinstance(dist, TensorVariable) or not isinstance(
96+
dist.owner.op, RandomVariable | SymbolicRandomVariable
97+
):
98+
raise ValueError(
99+
f"Ordered dist must be a distribution created via the `.dist()` API, got {type(dist)}"
100+
)
101+
if dist.owner.op.ndim_supp > 0:
102+
raise NotImplementedError("Ordering of multivariate distributions not supported")
103+
if not all(
104+
all(param.type.broadcastable) for param in dist.owner.op.dist_params(dist.owner)
105+
):
106+
raise ValueError("Ordered dist must be an IID variable")
107+
108+
support_shape = get_support_shape_1d(
109+
support_shape=support_shape,
110+
shape=kwargs.get("shape", None),
111+
)
112+
if support_shape is not None:
113+
dist = change_dist_size(dist, support_shape)
114+
115+
dist = pt.atleast_1d(dist)
116+
117+
return super().dist([dist], **kwargs)
118+
119+
120+
@_default_transform.register(OrderedRV)
121+
def default_transform_ordered(op, rv):
122+
if rv.type.dtype.startswith("float"):
123+
return ordered
124+
else:
125+
return None
126+
127+
128+
@_support_point.register(OrderedRV)
129+
def support_point_ordered(op, rv, dist):
130+
# FIXME: This does not work with the default ordered transform
131+
# which maps [0, 0, 0] to [0, -inf, -inf].
132+
# return support_point(dist)
133+
return rv # Draw from the prior

pymc/logprob/order.py

Lines changed: 101 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -41,21 +41,65 @@
4141
from pytensor.graph.fg import FunctionGraph
4242
from pytensor.graph.rewriting.basic import node_rewriter
4343
from pytensor.tensor.math import Max
44+
from pytensor.tensor.random.op import RandomVariable
45+
from pytensor.tensor.sort import SortOp
4446
from pytensor.tensor.variable import TensorVariable
4547

4648
from pymc.logprob.abstract import (
47-
MeasurableElemwise,
4849
MeasurableOp,
4950
_logcdf_helper,
5051
_logprob,
5152
_logprob_helper,
5253
)
5354
from pymc.logprob.rewriting import measurable_ir_rewrites_db
54-
from pymc.logprob.utils import filter_measurable_variables
55+
from pymc.logprob.utils import (
56+
CheckParameterValue,
57+
check_potential_measurability,
58+
filter_measurable_variables,
59+
)
5560
from pymc.math import logdiffexp
5661
from pymc.pytensorf import constant_fold
5762

5863

64+
def _underlying_iid_rv(variable) -> TensorVariable | None:
65+
# Check whether an IID base RV is connected to the variable through identical elemwise operations
66+
from pymc.distributions.distribution import SymbolicRandomVariable
67+
from pymc.logprob.transforms import MeasurableTransform
68+
69+
def iid_elemwise_root(var: TensorVariable) -> TensorVariable | None:
70+
node = var.owner
71+
if isinstance(node.op, RandomVariable | SymbolicRandomVariable):
72+
return var
73+
elif isinstance(node.op, MeasurableTransform):
74+
if len(node.inputs == 1):
75+
return iid_elemwise_root(node.inputs[0])
76+
else:
77+
# If the non-measurable inputs are broadcasted, it is still an IID operation.
78+
measurable_inp = node.op.measurable_input_idx
79+
other_inputs = [inp for i, inp in node.inputs if i != measurable_inp]
80+
if all(all(other_inp.type.broadcastable) for other_inp in other_inputs):
81+
return iid_elemwise_root(node.inputs[measurable_inp])
82+
return None
83+
84+
# Check that the root is a univariate distribution linked by only elemwise operations
85+
latent_base_var = iid_elemwise_root(variable)
86+
87+
if latent_base_var is None:
88+
return None
89+
90+
latent_op = latent_base_var.owner.op
91+
92+
if not (hasattr(latent_op, "dist_params") and getattr(latent_op, "ndim_supp") == 0):
93+
return None
94+
95+
if not all(
96+
all(params.type.broadcastable) for params in latent_op.dist_params(latent_base_var.owner)
97+
):
98+
return None
99+
100+
return cast(TensorVariable, latent_base_var)
101+
102+
59103
class MeasurableMax(MeasurableOp, Max):
60104
"""A placeholder used to specify a log-likelihood for a max sub-graph."""
61105

@@ -77,31 +121,12 @@ def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariab
77121
if not filter_measurable_variables(node.inputs):
78122
return None
79123

80-
# We allow Max of RandomVariables or Elemwise of univariate RandomVariables
81-
if isinstance(base_var.owner.op, MeasurableElemwise):
82-
latent_base_vars = [
83-
var
84-
for var in base_var.owner.inputs
85-
if (var.owner and isinstance(var.owner.op, MeasurableOp))
86-
]
87-
if len(latent_base_vars) != 1:
88-
return None
89-
[latent_base_var] = latent_base_vars
90-
else:
91-
latent_base_var = base_var
92-
93-
latent_op = latent_base_var.owner.op
94-
if not (hasattr(latent_op, "dist_params") and getattr(latent_op, "ndim_supp") == 0):
95-
return None
124+
# We allow Max of RandomVariables or IID Elemwise of univariate RandomVariables
125+
latent_base_var = _underlying_iid_rv(base_var)
96126

97-
# univariate i.i.d. test which also rules out other distributions
98-
if not all(
99-
all(params.type.broadcastable) for params in latent_op.dist_params(latent_base_var.owner)
100-
):
127+
if not latent_base_var:
101128
return None
102129

103-
base_var = cast(TensorVariable, base_var)
104-
105130
if node.op.axis is None:
106131
axis = tuple(range(base_var.ndim))
107132
else:
@@ -119,7 +144,7 @@ def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariab
119144

120145

121146
measurable_ir_rewrites_db.register(
122-
"find_measurable_max",
147+
find_measurable_max.__name__,
123148
find_measurable_max,
124149
"basic",
125150
"max",
@@ -158,3 +183,54 @@ def max_logprob_discrete(op, values, base_rv, **kwargs):
158183

159184
n = pt.prod(base_rv_shape)
160185
return logdiffexp(n * logcdf, n * logcdf_prev)
186+
187+
188+
class MeasurableSort(MeasurableOp, SortOp):
189+
"""A placeholder used to specify a log-likelihood for a sort sub-graph."""
190+
191+
192+
@_logprob.register(MeasurableSort)
193+
def sort_logprob(op, values, base_rv, axis, **kwargs):
194+
r"""Compute the log-likelihood graph for the `Sort` operation."""
195+
(value,) = values
196+
197+
logprob = _logprob_helper(base_rv, value).sum(axis=-1)
198+
199+
base_rv_shape = constant_fold(tuple(base_rv.shape), raise_not_constant=False)
200+
n = pt.prod(base_rv_shape, axis=-1)
201+
sorted_logp = pt.gammaln(n + 1) + logprob
202+
203+
# The sorted value is not really a parameter, but we include the check in
204+
# `CheckParameterValue` to avoid costly sorting if `check_bounds=False` in a PyMC model
205+
return CheckParameterValue("value must be sorted", can_be_replaced_by_ninf=True)(
206+
sorted_logp, pt.eq(value, value.sort(axis=axis, kind=op.kind)).all()
207+
)
208+
209+
210+
@node_rewriter(tracks=[SortOp])
211+
def find_measurable_sort(fgraph, node):
212+
if isinstance(node.op, MeasurableSort):
213+
return None
214+
215+
if not filter_measurable_variables(node.inputs):
216+
return None
217+
218+
[base_var, axis] = node.inputs
219+
220+
# We allow Max of RandomVariables or IID Elemwise of univariate RandomVariables
221+
if _underlying_iid_rv(base_var) is None:
222+
return None
223+
224+
# Check axis is not potentially measurable
225+
if check_potential_measurability([axis]):
226+
return None
227+
228+
return [MeasurableSort(**node.op._props_dict())(base_var, axis)]
229+
230+
231+
measurable_ir_rewrites_db.register(
232+
find_measurable_sort.__name__,
233+
find_measurable_sort,
234+
"basic",
235+
"sort",
236+
)

0 commit comments

Comments
 (0)