23
23
24
24
import pymc3 as pm
25
25
26
- from pymc3 .aesaraf import floatX
26
+ from pymc3 .aesaraf import floatX , rvs_to_value_vars
27
27
from pymc3 .blocking import DictToArrayBijection , RaveledVars
28
28
from pymc3 .step_methods .arraystep import (
29
29
ArrayStep ,
@@ -408,8 +408,8 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None):
408
408
# transition probabilities
409
409
self .transit_p = transit_p
410
410
411
- # XXX: This needs to be refactored
412
- self .dim = None # sum(v.dsize for v in vars)
411
+ initial_point = model . initial_point
412
+ self .dim = sum (initial_point [ v . name ]. size for v in vars )
413
413
414
414
if order == "random" :
415
415
self .shuffle_dims = True
@@ -491,29 +491,35 @@ class CategoricalGibbsMetropolis(ArrayStep):
491
491
def __init__ (self , vars , proposal = "uniform" , order = "random" , model = None ):
492
492
493
493
model = pm .modelcontext (model )
494
+
494
495
vars = pm .inputvars (vars )
495
496
497
+ initial_point = model .initial_point
498
+
496
499
dimcats = []
497
500
# The above variable is a list of pairs (aggregate dimension, number
498
501
# of categories). For example, if vars = [x, y] with x being a 2-D
499
502
# variable with M categories and y being a 3-D variable with N
500
503
# categories, we will have dimcats = [(0, M), (1, M), (2, N), (3, N), (4, N)].
501
504
for v in vars :
502
505
503
- distr = getattr (v .owner , "op" , None )
506
+ v_init_val = initial_point [v .name ]
507
+
508
+ rv_var = model .values_to_rvs [v ]
509
+ distr = getattr (rv_var .owner , "op" , None )
504
510
505
511
if isinstance (distr , CategoricalRV ):
506
- # XXX: This needs to be refactored
507
- k = None # draw_values([distr.k])[0]
508
- elif isinstance (distr , pm .Bernoulli ) or (v .dtype in pm .bool_types ):
512
+ k_graph = rv_var .owner .inputs [3 ].shape [- 1 ]
513
+ (k_graph ,), _ = rvs_to_value_vars ((k_graph ,), apply_transforms = True )
514
+ k = model .fn (k_graph )(initial_point )
515
+ elif isinstance (distr , BernoulliRV ):
509
516
k = 2
510
517
else :
511
518
raise ValueError (
512
519
"All variables must be categorical or binary" + "for CategoricalGibbsMetropolis"
513
520
)
514
521
start = len (dimcats )
515
- # XXX: This needs to be refactored
516
- dimcats += None # [(dim, k) for dim in range(start, start + v.dsize)]
522
+ dimcats += [(dim , k ) for dim in range (start , start + v_init_val .size )]
517
523
518
524
if order == "random" :
519
525
self .shuffle_dims = True
@@ -543,18 +549,16 @@ def astep_unif(self, q0: RaveledVars, logp) -> RaveledVars:
543
549
if self .shuffle_dims :
544
550
nr .shuffle (dimcats )
545
551
546
- q = np .copy (q0 )
552
+ q = RaveledVars ( np .copy (q0 ), point_map_info )
547
553
logp_curr = logp (q )
548
554
549
555
for dim , k in dimcats :
550
- curr_val , q [dim ] = q [dim ], sample_except (k , q [dim ])
556
+ curr_val , q . data [dim ] = q . data [dim ], sample_except (k , q . data [dim ])
551
557
logp_prop = logp (q )
552
- q [dim ], accepted = metrop_select (logp_prop - logp_curr , q [dim ], curr_val )
558
+ q . data [dim ], accepted = metrop_select (logp_prop - logp_curr , q . data [dim ], curr_val )
553
559
if accepted :
554
560
logp_curr = logp_prop
555
561
556
- q = RaveledVars (q , point_map_info )
557
-
558
562
return q
559
563
560
564
def astep_prop (self , q0 : RaveledVars , logp ) -> RaveledVars :
@@ -566,34 +570,32 @@ def astep_prop(self, q0: RaveledVars, logp) -> RaveledVars:
566
570
if self .shuffle_dims :
567
571
nr .shuffle (dimcats )
568
572
569
- q = np .copy (q0 )
573
+ q = RaveledVars ( np .copy (q0 ), point_map_info )
570
574
logp_curr = logp (q )
571
575
572
576
for dim , k in dimcats :
573
577
logp_curr = self .metropolis_proportional (q , logp , logp_curr , dim , k )
574
578
575
- q = RaveledVars (q , point_map_info )
576
-
577
579
return q
578
580
579
581
def metropolis_proportional (self , q , logp , logp_curr , dim , k ):
580
- given_cat = int (q [dim ])
582
+ given_cat = int (q . data [dim ])
581
583
log_probs = np .zeros (k )
582
584
log_probs [given_cat ] = logp_curr
583
585
candidates = list (range (k ))
584
586
for candidate_cat in candidates :
585
587
if candidate_cat != given_cat :
586
- q [dim ] = candidate_cat
588
+ q . data [dim ] = candidate_cat
587
589
log_probs [candidate_cat ] = logp (q )
588
590
probs = softmax (log_probs )
589
591
prob_curr , probs [given_cat ] = probs [given_cat ], 0.0
590
592
probs /= 1.0 - prob_curr
591
593
proposed_cat = nr .choice (candidates , p = probs )
592
594
accept_ratio = (1.0 - prob_curr ) / (1.0 - probs [proposed_cat ])
593
595
if not np .isfinite (accept_ratio ) or nr .uniform () >= accept_ratio :
594
- q [dim ] = given_cat
596
+ q . data [dim ] = given_cat
595
597
return logp_curr
596
- q [dim ] = proposed_cat
598
+ q . data [dim ] = proposed_cat
597
599
return log_probs [proposed_cat ]
598
600
599
601
@staticmethod
@@ -744,7 +746,7 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
744
746
r1 = DictToArrayBijection .map (self .population [ir1 ])
745
747
r2 = DictToArrayBijection .map (self .population [ir2 ])
746
748
# propose a jump
747
- q = floatX (q0 + self .lamb * (r1 - r2 ) + epsilon )
749
+ q = floatX (q0 + self .lamb * (r1 . data - r2 . data ) + epsilon )
748
750
749
751
accept = self .delta_logp (q , q0 )
750
752
q_new , accepted = metrop_select (accept , q , q0 )
0 commit comments