Skip to content

Commit a54cb12

Browse files
Use ABC abstract methods in step method interface
1 parent d650afd commit a54cb12

File tree

4 files changed

+36
-20
lines changed

4 files changed

+36
-20
lines changed

pymc3/step_methods/arraystep.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,24 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from abc import ABC, abstractmethod
1516
from enum import IntEnum, unique
16-
from typing import Dict, List
17+
from typing import Dict, List, Tuple, TypeVar, Union
1718

1819
import numpy as np
1920

2021
from aesara.graph.basic import Variable
2122
from numpy.random import uniform
2223

23-
from pymc3.blocking import DictToArrayBijection, RaveledVars
24+
from pymc3.blocking import DictToArrayBijection, PointType, RaveledVars
2425
from pymc3.model import modelcontext
2526
from pymc3.step_methods.compound import CompoundStep
2627
from pymc3.util import get_var_name
2728

2829
__all__ = ["ArrayStep", "ArrayStepShared", "metrop_select", "Competence"]
2930

31+
StatsType = TypeVar("StatsType")
32+
3033

3134
@unique
3235
class Competence(IntEnum):
@@ -44,7 +47,7 @@ class Competence(IntEnum):
4447
IDEAL = 3
4548

4649

47-
class BlockedStep:
50+
class BlockedStep(ABC):
4851

4952
generates_stats = False
5053
stats_dtypes: List[Dict[str, np.dtype]] = []
@@ -99,6 +102,10 @@ def __new__(cls, *args, **kwargs):
99102
def __getnewargs_ex__(self):
100103
return self.__newargs
101104

105+
@abstractmethod
106+
def step(point: PointType, *args, **kwargs) -> Union[PointType, Tuple[PointType, StatsType]]:
107+
"""Perform a single step of the sampler."""
108+
102109
@staticmethod
103110
def competence(var, has_grad):
104111
return Competence.INCOMPATIBLE
@@ -139,7 +146,7 @@ def __init__(self, vars, fs, allvars=False, blocked=True):
139146
self.allvars = allvars
140147
self.blocked = blocked
141148

142-
def step(self, point: Dict[str, np.ndarray]):
149+
def step(self, point: PointType):
143150

144151
partial_funcs_and_point = [DictToArrayBijection.mapf(x, start_point=point) for x in self.fs]
145152
if self.allvars:
@@ -164,8 +171,11 @@ def step(self, point: Dict[str, np.ndarray]):
164171

165172
return point_new
166173

167-
def astep(self, apoint: RaveledVars, point: Dict[str, np.ndarray]):
168-
raise NotImplementedError()
174+
@abstractmethod
175+
def astep(
176+
self, apoint: RaveledVars, point: PointType, *args
177+
) -> Union[RaveledVars, Tuple[RaveledVars, StatsType]]:
178+
"""Perform a single sample step in a raveled and concatenated parameter space."""
169179

170180

171181
class ArrayStepShared(BlockedStep):
@@ -213,9 +223,6 @@ def step(self, point):
213223

214224
return new_point
215225

216-
def astep(self, apoint: RaveledVars):
217-
raise NotImplementedError()
218-
219226

220227
class PopulationArrayStepShared(ArrayStepShared):
221228
"""Version of ArrayStepShared that allows samplers to access the states
@@ -278,9 +285,6 @@ def step(self, point):
278285
self._logp_dlogp_func._extra_are_set = True
279286
return super().step(point)
280287

281-
def astep(self, apoint):
282-
raise NotImplementedError()
283-
284288

285289
def metrop_select(mr, q, q0):
286290
"""Perform rejection/acceptance step for Metropolis class samplers.
@@ -300,6 +304,8 @@ def metrop_select(mr, q, q0):
300304
q or q0
301305
"""
302306
# Compare acceptance ratio to uniform random number
307+
# TODO XXX: This `uniform` is not given a model-specific RNG state, which
308+
# means that sampler runs that use it will not be reproducible.
303309
if np.isfinite(mr) and np.log(uniform()) < mr:
304310
return q, True
305311
else:

pymc3/step_methods/hmc/base_hmc.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import logging
1616
import time
1717

18+
from abc import abstractmethod
1819
from collections import namedtuple
1920

2021
import numpy as np
@@ -24,7 +25,8 @@
2425
from pymc3.blocking import DictToArrayBijection, RaveledVars
2526
from pymc3.exceptions import SamplingError
2627
from pymc3.model import Point, modelcontext
27-
from pymc3.step_methods import arraystep, step_sizes
28+
from pymc3.step_methods import step_sizes
29+
from pymc3.step_methods.arraystep import GradientSharedStep
2830
from pymc3.step_methods.hmc import integration
2931
from pymc3.step_methods.hmc.quadpotential import QuadPotentialDiagAdapt, quad_potential
3032
from pymc3.tuning import guess_scaling
@@ -36,7 +38,7 @@
3638
DivergenceInfo = namedtuple("DivergenceInfo", "message, exec_info, state, state_div")
3739

3840

39-
class BaseHMC(arraystep.GradientSharedStep):
41+
class BaseHMC(GradientSharedStep):
4042
"""Superclass to implement Hamiltonian/hybrid monte carlo."""
4143

4244
default_blocked = True
@@ -85,8 +87,6 @@ def __init__(
8587
if vars is None:
8688
vars = self._model.cont_vars
8789

88-
# vars = inputvars(vars)
89-
9090
super().__init__(vars, blocked=blocked, model=self._model, dtype=dtype, **aesara_kwargs)
9191

9292
self.adapt_step_size = adapt_step_size
@@ -132,12 +132,12 @@ def __init__(
132132
self._samples_after_tune = 0
133133
self._num_divs_sample = 0
134134

135+
@abstractmethod
135136
def _hamiltonian_step(self, start, p0, step_size):
136137
"""Compute one hamiltonian trajectory and return the next state.
137138
138139
Subclasses must overwrite this method and return a `HMCStepData`.
139140
"""
140-
raise NotImplementedError("Abstract method")
141141

142142
def astep(self, q0):
143143
"""Perform a single HMC iteration."""

pymc3/step_methods/metropolis.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import Any, Dict, List, Tuple
14+
from typing import Any, Callable, Dict, List, Tuple
1515

1616
import aesara
1717
import numpy as np
@@ -425,7 +425,7 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None):
425425

426426
super().__init__(vars, [model.fastlogp])
427427

428-
def astep(self, q0: RaveledVars, logp) -> RaveledVars:
428+
def astep(self, q0: RaveledVars, logp: Callable[[RaveledVars], np.ndarray]) -> RaveledVars:
429429

430430
order = self.order
431431
if self.shuffle_dims:
@@ -475,6 +475,7 @@ def competence(var):
475475

476476
class CategoricalGibbsMetropolis(ArrayStep):
477477
"""A Metropolis-within-Gibbs step method optimized for categorical variables.
478+
478479
This step method works for Bernoulli variables as well, but it is not
479480
optimized for them, like BinaryGibbsMetropolis is. Step method supports
480481
two types of proposals: A uniform proposal and a proportional proposal,
@@ -574,6 +575,9 @@ def astep_prop(self, q0: RaveledVars, logp) -> RaveledVars:
574575

575576
return q
576577

578+
def astep(self, q0, logp):
579+
raise NotImplementedError()
580+
577581
def metropolis_proportional(self, q, logp, logp_curr, dim, k):
578582
given_cat = int(q.data[dim])
579583
log_probs = np.zeros(k)

pymc3/tests/test_hmc.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,13 @@ def test_leapfrog_reversible():
3232
start, model, _ = models.non_normal(n)
3333
size = sum(start[n.name].size for n in model.value_vars)
3434
scaling = floatX(np.random.rand(size))
35-
step = BaseHMC(vars=model.value_vars, model=model, scaling=scaling)
35+
36+
class HMC(BaseHMC):
37+
def _hamiltonian_step(self, *args, **kwargs):
38+
pass
39+
40+
step = HMC(vars=model.value_vars, model=model, scaling=scaling)
41+
3642
step.integrator._logp_dlogp_func.set_extra_values({})
3743
astart = DictToArrayBijection.map(start)
3844
p = RaveledVars(floatX(step.potential.random()), astart.point_map_info)

0 commit comments

Comments
 (0)