Skip to content

Commit ee199ef

Browse files
committed
Rebase again
1 parent c2fe1a3 commit ee199ef

File tree

3 files changed

+24
-1
lines changed

3 files changed

+24
-1
lines changed

pymc3/step_methods/hmc/base_hmc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010

1111
class BaseHMC(arraystep.GradientSharedStep):
12-
"""Superclass to implement Hamiltonian/hybrid monte carlo"""
12+
"""Superclass to implement Hamiltonian/hybrid monte carlo."""
13+
1314
default_blocked = True
1415

1516
def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False,

pymc3/step_methods/hmc/integration.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88

99

1010
class CpuLeapfrogIntegrator(object):
11+
"""Optimized leapfrog integration using numpy."""
12+
1113
def __init__(self, ndim, potential, logp_dlogp_func):
14+
"""Leapfrog integrator using CPU."""
1215
self._ndim = ndim
1316
self._potential = potential
1417
self._logp_dlogp_func = logp_dlogp_func
@@ -19,6 +22,7 @@ def __init__(self, ndim, potential, logp_dlogp_func):
1922
% (self._potential.dtype, self._dtype))
2023

2124
def compute_state(self, q, p):
25+
"""Compute Hamiltonian functions using a position and momentum."""
2226
if q.dtype != self._dtype or p.dtype != self._dtype:
2327
raise ValueError('Invalid dtype. Must be %s' % self._dtype)
2428
logp, dlogp = self._logp_dlogp_func(q)
@@ -28,6 +32,23 @@ def compute_state(self, q, p):
2832
return State(q, p, v, dlogp, energy)
2933

3034
def step(self, epsilon, state, out=None):
35+
"""Leapfrog integrator step.
36+
37+
Half a momentum update, full position update, half momentum update.
38+
39+
Parameters
40+
----------
41+
epsilon: float, > 0
42+
step scale
43+
state: State namedtuple,
44+
current position data
45+
out: (optional) State namedtuple,
46+
preallocated arrays to write to in place
47+
48+
Returns
49+
-------
50+
None if `out` is provided, else a State namedtuple
51+
"""
3152
pot = self._potential
3253
axpy = linalg.blas.get_blas_funcs('axpy', dtype=self._dtype)
3354

pymc3/step_methods/hmc/quadpotential.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ def current_mean(self):
262262

263263
class QuadPotentialDiag(QuadPotential):
264264
"""Quad potential using a diagonal covariance matrix."""
265+
265266
def __init__(self, v, dtype=None):
266267
"""Use a vector to represent a diagonal matrix for a covariance matrix.
267268

0 commit comments

Comments
 (0)