from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
from pymc.distributions.transforms import Chain
from pymc.logprob.abstract import _logprob
- from pymc.logprob.basic import conditional_logp
+ from pymc.logprob.basic import conditional_logp, logp
from pymc.logprob.transforms import IntervalTransform
from pymc.model import Model
from pymc.pytensorf import compile_pymc, constant_fold, inputvars
from pymc.util import _get_seeds_per_chain, dataset_to_point_list, treedict
- from pytensor import Mode
+ from pytensor import Mode, scan
from pytensor.compile import SharedVariable
from pytensor.compile.builders import OpFromGraph
- from pytensor.graph import (
-     Constant,
-     FunctionGraph,
-     ancestors,
-     clone_replace,
-     vectorize_graph,
- )
+ from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace
+ from pytensor.graph.replace import vectorize_graph
from pytensor.scan import map as scan_map
from pytensor.tensor import TensorVariable
from pytensor.tensor.elemwise import Elemwise

__all__ = ["MarginalModel"]

+ from pymc_experimental.distributions import DiscreteMarkovChain
+

class MarginalModel(Model):
    """Subclass of PyMC Model that implements functionality for automatic
@@ -245,16 +242,25 @@ def marginalize(
            self[var] if isinstance(var, str) else var for var in rvs_to_marginalize
        ]

-         supported_dists = (Bernoulli, Categorical, DiscreteUniform)
        for rv_to_marginalize in rvs_to_marginalize:
            if rv_to_marginalize not in self.free_RVs:
                raise ValueError(
                    f"Marginalized RV {rv_to_marginalize} is not a free RV in the model"
                )
-             if not isinstance(rv_to_marginalize.owner.op, supported_dists):
+
+             rv_op = rv_to_marginalize.owner.op
+             if isinstance(rv_op, DiscreteMarkovChain):
+                 if rv_op.n_lags > 1:
+                     raise NotImplementedError(
+                         "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
+                     )
+                 if rv_to_marginalize.owner.inputs[0].type.ndim > 2:
+                     raise NotImplementedError(
+                         "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
+                     )
+             elif not isinstance(rv_op, (Bernoulli, Categorical, DiscreteUniform)):
                raise NotImplementedError(
-                     f"RV with distribution {rv_to_marginalize.owner.op} cannot be marginalized. "
-                     f"Supported distribution include {supported_dists}"
+                     f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
                )

        if rv_to_marginalize.name in self.named_vars_to_dims:
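For context, a minimal usage sketch (not part of this diff) of what the new branch enables: marginalizing a DiscreteMarkovChain with n_lags=1 and a matrix-valued transition probability. The model, data, and import path for MarginalModel are illustrative assumptions, not code from the PR.

    import numpy as np
    import pymc as pm
    from pymc_experimental import MarginalModel  # assumed package-level export
    from pymc_experimental.distributions import DiscreteMarkovChain

    data = np.random.normal(size=(100,))  # made-up emission data

    with MarginalModel() as m:
        # 2x2 row-stochastic transition matrix and initial state distribution
        P = [[0.5, 0.5], [0.3, 0.7]]
        init_dist = pm.Categorical.dist(p=[0.5, 0.5])
        chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, shape=(100,))
        # Emissions centered at -1 or 1 depending on the hidden state
        pm.Normal("emission", mu=chain * 2 - 1, sigma=0.5, observed=data)

        # The hidden chain is summed out of the model logp
        m.marginalize([chain])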
@@ -493,6 +499,10 @@ class FiniteDiscreteMarginalRV(MarginalRV):
    """Base class for Finite Discrete Marginalized RVs"""


+ class DiscreteMarginalMarkovChainRV(MarginalRV):
+     """Base class for Discrete Marginal Markov Chain RVs"""
+
+
def static_shape_ancestors(vars):
    """Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph)."""
    return [
@@ -621,11 +631,17 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
    replace_inputs.update({input_rv: input_rv.type() for input_rv in input_rvs})
    cloned_outputs = clone_replace(outputs, replace=replace_inputs)

-     marginalization_op = FiniteDiscreteMarginalRV(
+     if isinstance(rv_to_marginalize.owner.op, DiscreteMarkovChain):
+         marginalize_constructor = DiscreteMarginalMarkovChainRV
+     else:
+         marginalize_constructor = FiniteDiscreteMarginalRV
+
+     marginalization_op = marginalize_constructor(
        inputs=list(replace_inputs.values()),
        outputs=cloned_outputs,
        ndim_supp=ndim_supp,
    )
+
    marginalized_rvs = marginalization_op(*replace_inputs.keys())
    fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))
    return rvs_to_marginalize, marginalized_rvs
@@ -641,10 +657,26 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
    elif isinstance(op, DiscreteUniform):
        lower, upper = constant_fold(rv.owner.inputs[3:])
        return tuple(range(lower, upper + 1))
+     elif isinstance(op, DiscreteMarkovChain):
+         p = rv.owner.inputs[0]
+         return tuple(range(pt.get_vector_length(p[-1])))

    raise NotImplementedError(f"Cannot compute domain for op {op}")


+ def _reduce_add_batch_dims_dependent_logps(marginalized_rv, vv_to_logps_dict):
+     """Reduce batch dimensions beyond a marginalized rv for the dependent variables in a logp dict"""
+     mbcast = marginalized_rv.type.broadcastable
+     reduced_logps = []
+     for value, logp in vv_to_logps_dict.items():
+         vbcast = value.type.broadcastable
+         dim_diff = len(vbcast) - len(mbcast)
+         mbcast_aligned = (True,) * dim_diff + mbcast
+         vbcast_axis = [i for i, (m, v) in enumerate(zip(mbcast_aligned, vbcast)) if m and not v]
+         reduced_logps.append(logp.sum(vbcast_axis))
+     return pt.add(*reduced_logps)
+
+
@_logprob.register(FiniteDiscreteMarginalRV)
def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
    # Clone the inner RV graph of the Marginalized RV
@@ -661,16 +693,13 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):

    # Reduce logp dimensions corresponding to broadcasted variables
    joint_logp = logps_dict[inner_rvs_to_values[marginalized_rv]]
-     for inner_rv, inner_value in inner_rvs_to_values.items():
-         if inner_rv is marginalized_rv:
-             continue
-         vbcast = inner_value.type.broadcastable
-         mbcast = marginalized_rv.type.broadcastable
-         mbcast = (True,) * (len(vbcast) - len(mbcast)) + mbcast
-         values_axis_bcast = [i for i, (m, v) in enumerate(zip(mbcast, vbcast)) if m != v]
-         joint_logp += logps_dict[inner_value].sum(values_axis_bcast, keepdims=True)
-
-     # Wrap the joint_logp graph in an OpFromGrah, so that we can evaluate it at different
+     joint_logp += _reduce_add_batch_dims_dependent_logps(
+         marginalized_rv,
+         # Exclude the first entry, corresponding to the marginalized RV
+         {v: logp for i, (v, logp) in enumerate(logps_dict.items()) if i > 0},
+     )
+
+     # Wrap the joint_logp graph in an OpFromGraph, so that we can evaluate it at different
    # values of the marginalized RV
    # Some inputs are not root inputs (such as transformed projections of value variables)
    # Or cannot be used as inputs to an OpFromGraph (shared variables and constants)
@@ -698,6 +727,7 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
    )

    # Arbitrary cutoff to switch to Scan implementation to keep graph size under control
+     # TODO: Try vectorize here
    if len(marginalized_rv_domain) <= 10:
        joint_logps = [
            joint_logp_op(marginalized_rv_domain_tensor[i], *values, *inputs)
@@ -719,3 +749,67 @@ def logp_fn(marginalized_rv_const, *non_sequences):

    # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
    return joint_logps, *(pt.constant(0),) * (len(values) - 1)
+
+
+ @_logprob.register(DiscreteMarginalMarkovChainRV)
+ def marginal_hmm_logp(op, values, *inputs, **kwargs):
+
+     marginalized_rvs_node = op.make_node(*inputs)
+     inner_rvs = clone_replace(
+         op.inner_outputs,
+         replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
+     )
+
+     chain_rv, *dependent_rvs = inner_rvs
+     P, n_steps_, init_dist_, rng = chain_rv.owner.inputs
+     domain = pt.arange(P.shape[-1], dtype="int32")
+
+     # Construct logp in two steps
+     # Step 1: Compute the probability of the data ("emissions") under every possible state (vec_logp_emission)
+
+     # First we need to vectorize the conditional logp graph of the data, in case there are batch dimensions floating
+     # around. To do this, we need to break the dependency between the chain and the init_dist_ random variable.
+     # Otherwise, PyMC will detect a random variable in the logp graph (init_dist_) that isn't relevant at this step.
+     chain_value = chain_rv.clone()
+     dependent_rvs = clone_replace(dependent_rvs, {chain_rv: chain_value})
+     logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values)))
+
+     # Reduce and add the batch dims beyond the chain dimension
+     reduced_logp_emissions = _reduce_add_batch_dims_dependent_logps(chain_rv, logp_emissions_dict)
+
+     # Add a batch dimension for the domain of the chain
+     chain_shape = constant_fold(tuple(chain_rv.shape))
+     batch_chain_value = pt.moveaxis(pt.full((*chain_shape, domain.size), domain), -1, 0)
+     batch_logp_emissions = vectorize_graph(reduced_logp_emissions, {chain_value: batch_chain_value})
+
+     # Step 2: Compute the transition probabilities
+     # This is the "forward algorithm", alpha_t = p(y | s_t) * sum_{s_{t-1}}(p(s_t | s_{t-1}) * alpha_{t-1})
+     # We do it entirely in logs, though.
+
+     # To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states) under
+     # the initial distribution. This is robust to everything the user can throw at it.
+     batch_logp_init_dist = pt.vectorize(lambda x: logp(init_dist_, x), "()->()")(
+         batch_chain_value[..., 0]
+     )
+     log_alpha_init = batch_logp_init_dist + batch_logp_emissions[..., 0]
+
+     def step_alpha(logp_emission, log_alpha, log_P):
+         step_log_prob = pt.logsumexp(log_alpha[:, None] + log_P, axis=0)
+         return logp_emission + step_log_prob
+
+     P_bcast_dims = (len(chain_shape) - 1) - (P.type.ndim - 2)
+     log_P = pt.shape_padright(pt.log(P), P_bcast_dims)
+     log_alpha_seq, _ = scan(
+         step_alpha,
+         non_sequences=[log_P],
+         outputs_info=[log_alpha_init],
+         # Scan needs the time dimension first, and we already consumed the 1st logp computing the initial value
+         sequences=pt.moveaxis(batch_logp_emissions[..., 1:], -1, 0),
+     )
+     # Final logp is just the sum of the last scan state
+     joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)
+
+     # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
+     # return is the joint probability of everything together, but PyMC still expects one logp for each one.
+     dummy_logps = (pt.constant(np.zeros(shape=())),) * (len(values) - 1)
+     return joint_logp, *dummy_logps
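For reference, a minimal standalone sketch of the log-space forward recursion that the scan above implements, written in plain NumPy/SciPy for a single chain with no batch dimensions (an illustrative aid under those assumptions, not code from the PR):

    import numpy as np
    from scipy.special import logsumexp

    def forward_logp(log_P, log_init, log_emissions):
        """Marginal log-likelihood of an HMM via the forward algorithm.

        log_P:         (S, S), log_P[i, j] = log p(s_t = j | s_{t-1} = i)
        log_init:      (S,), log prior over the initial state
        log_emissions: (T, S), log p(y_t | s_t = s) for every step and state
        """
        # alpha_1(s) = p(s_1 = s) * p(y_1 | s_1 = s)
        log_alpha = log_init + log_emissions[0]
        for log_em in log_emissions[1:]:
            # alpha_t(j) = p(y_t | s_t = j) * sum_i p(s_t = j | s_{t-1} = i) * alpha_{t-1}(i)
            log_alpha = log_em + logsumexp(log_alpha[:, None] + log_P, axis=0)
        # p(y_1, ..., y_T) = sum_s alpha_T(s)
        return logsumexp(log_alpha)

The update mirrors step_alpha: log_alpha[:, None] + log_P broadcasts the previous-state log-alphas against the log transition matrix, and the logsumexp over axis 0 sums out the previous state.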