41
41
from pytensor .graph .fg import FunctionGraph
42
42
from pytensor .graph .rewriting .basic import node_rewriter
43
43
from pytensor .tensor .math import Max
44
+ from pytensor .tensor .random .op import RandomVariable
45
+ from pytensor .tensor .sort import SortOp
44
46
from pytensor .tensor .variable import TensorVariable
45
47
46
48
from pymc .logprob .abstract import (
47
- MeasurableElemwise ,
48
49
MeasurableOp ,
49
50
_logcdf_helper ,
50
51
_logprob ,
51
52
_logprob_helper ,
52
53
)
53
54
from pymc .logprob .rewriting import measurable_ir_rewrites_db
54
- from pymc .logprob .utils import filter_measurable_variables
55
+ from pymc .logprob .utils import (
56
+ CheckParameterValue ,
57
+ check_potential_measurability ,
58
+ filter_measurable_variables ,
59
+ )
55
60
from pymc .math import logdiffexp
56
61
from pymc .pytensorf import constant_fold
57
62
58
63
64
+ def _underlying_iid_rv (variable ) -> TensorVariable | None :
65
+ # Check whether an IID base RV is connected to the variable through identical elemwise operations
66
+ from pymc .distributions .distribution import SymbolicRandomVariable
67
+ from pymc .logprob .transforms import MeasurableTransform
68
+
69
+ def iid_elemwise_root (var : TensorVariable ) -> TensorVariable | None :
70
+ node = var .owner
71
+ if isinstance (node .op , RandomVariable | SymbolicRandomVariable ):
72
+ return var
73
+ elif isinstance (node .op , MeasurableTransform ):
74
+ if len (node .inputs == 1 ):
75
+ return iid_elemwise_root (node .inputs [0 ])
76
+ else :
77
+ # If the non-measurable inputs are broadcasted, it is still an IID operation.
78
+ measurable_inp = node .op .measurable_input_idx
79
+ other_inputs = [inp for i , inp in node .inputs if i != measurable_inp ]
80
+ if all (all (other_inp .type .broadcastable ) for other_inp in other_inputs ):
81
+ return iid_elemwise_root (node .inputs [measurable_inp ])
82
+ return None
83
+
84
+ # Check that the root is a univariate distribution linked by only elemwise operations
85
+ latent_base_var = iid_elemwise_root (variable )
86
+
87
+ if latent_base_var is None :
88
+ return None
89
+
90
+ latent_op = latent_base_var .owner .op
91
+
92
+ if not (hasattr (latent_op , "dist_params" ) and getattr (latent_op , "ndim_supp" ) == 0 ):
93
+ return None
94
+
95
+ if not all (
96
+ all (params .type .broadcastable ) for params in latent_op .dist_params (latent_base_var .owner )
97
+ ):
98
+ return None
99
+
100
+ return cast (TensorVariable , latent_base_var )
101
+
102
+
59
103
class MeasurableMax (MeasurableOp , Max ):
60
104
"""A placeholder used to specify a log-likelihood for a max sub-graph."""
61
105
@@ -77,31 +121,12 @@ def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariab
77
121
if not filter_measurable_variables (node .inputs ):
78
122
return None
79
123
80
- # We allow Max of RandomVariables or Elemwise of univariate RandomVariables
81
- if isinstance (base_var .owner .op , MeasurableElemwise ):
82
- latent_base_vars = [
83
- var
84
- for var in base_var .owner .inputs
85
- if (var .owner and isinstance (var .owner .op , MeasurableOp ))
86
- ]
87
- if len (latent_base_vars ) != 1 :
88
- return None
89
- [latent_base_var ] = latent_base_vars
90
- else :
91
- latent_base_var = base_var
92
-
93
- latent_op = latent_base_var .owner .op
94
- if not (hasattr (latent_op , "dist_params" ) and getattr (latent_op , "ndim_supp" ) == 0 ):
95
- return None
124
+ # We allow Max of RandomVariables or IID Elemwise of univariate RandomVariables
125
+ latent_base_var = _underlying_iid_rv (base_var )
96
126
97
- # univariate i.i.d. test which also rules out other distributions
98
- if not all (
99
- all (params .type .broadcastable ) for params in latent_op .dist_params (latent_base_var .owner )
100
- ):
127
+ if not latent_base_var :
101
128
return None
102
129
103
- base_var = cast (TensorVariable , base_var )
104
-
105
130
if node .op .axis is None :
106
131
axis = tuple (range (base_var .ndim ))
107
132
else :
@@ -119,7 +144,7 @@ def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariab
119
144
120
145
121
146
measurable_ir_rewrites_db .register (
122
- " find_measurable_max" ,
147
+ find_measurable_max . __name__ ,
123
148
find_measurable_max ,
124
149
"basic" ,
125
150
"max" ,
@@ -158,3 +183,54 @@ def max_logprob_discrete(op, values, base_rv, **kwargs):
158
183
159
184
n = pt .prod (base_rv_shape )
160
185
return logdiffexp (n * logcdf , n * logcdf_prev )
186
+
187
+
188
+ class MeasurableSort (MeasurableOp , SortOp ):
189
+ """A placeholder used to specify a log-likelihood for a sort sub-graph."""
190
+
191
+
192
+ @_logprob .register (MeasurableSort )
193
+ def sort_logprob (op , values , base_rv , axis , ** kwargs ):
194
+ r"""Compute the log-likelihood graph for the `Sort` operation."""
195
+ (value ,) = values
196
+
197
+ logprob = _logprob_helper (base_rv , value ).sum (axis = - 1 )
198
+
199
+ base_rv_shape = constant_fold (tuple (base_rv .shape ), raise_not_constant = False )
200
+ n = pt .prod (base_rv_shape , axis = - 1 )
201
+ sorted_logp = pt .gammaln (n + 1 ) + logprob
202
+
203
+ # The sorted value is not really a parameter, but we include the check in
204
+ # `CheckParameterValue` to avoid costly sorting if `check_bounds=False` in a PyMC model
205
+ return CheckParameterValue ("value must be sorted" , can_be_replaced_by_ninf = True )(
206
+ sorted_logp , pt .eq (value , value .sort (axis = axis , kind = op .kind )).all ()
207
+ )
208
+
209
+
210
+ @node_rewriter (tracks = [SortOp ])
211
+ def find_measurable_sort (fgraph , node ):
212
+ if isinstance (node .op , MeasurableSort ):
213
+ return None
214
+
215
+ if not filter_measurable_variables (node .inputs ):
216
+ return None
217
+
218
+ [base_var , axis ] = node .inputs
219
+
220
+ # We allow Max of RandomVariables or IID Elemwise of univariate RandomVariables
221
+ if _underlying_iid_rv (base_var ) is None :
222
+ return None
223
+
224
+ # Check axis is not potentially measurable
225
+ if check_potential_measurability ([axis ]):
226
+ return None
227
+
228
+ return [MeasurableSort (** node .op ._props_dict ())(base_var , axis )]
229
+
230
+
231
+ measurable_ir_rewrites_db .register (
232
+ find_measurable_sort .__name__ ,
233
+ find_measurable_sort ,
234
+ "basic" ,
235
+ "sort" ,
236
+ )
0 commit comments