Skip to content

Commit 21de446

Browse files
adaptions for timeseries distributions
1 parent 08caafa commit 21de446

File tree

1 file changed

+32
-18
lines changed

1 file changed

+32
-18
lines changed

pymc3/tests/test_distributions_random.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,12 @@ def test_blocking_context(self):
198198
class BaseTestCases:
199199
class BaseTestCase(SeededTest):
200200
shape = 5
201+
# the following are the default values of the distribution that take effect
202+
# when the parametrized shape/size in the test case is None.
203+
# For every distribution that defaults to non-scalar shapes they must be
204+
# specified by the inheriting Test class. example: TestGaussianRandomWalk
205+
default_shape = ()
206+
default_size = ()
201207

202208
def setup_method(self, *args, **kwargs):
203209
super().setup_method(*args, **kwargs)
@@ -215,30 +221,39 @@ def get_random_variable(self, shape, with_vector_params=False, name=None):
215221
if name is None:
216222
name = self.distribution.__name__
217223
with self.model:
218-
if shape is None:
219-
return self.distribution(name, transform=None, **params)
220-
else:
221-
try:
224+
try:
225+
if shape is None:
226+
# in the test case parametrization "None" means "no specified (default)"
227+
return self.distribution(name, transform=None, **params)
228+
else:
222229
return self.distribution(name, shape=shape, transform=None, **params)
223-
except TypeError:
224-
if np.sum(np.atleast_1d(shape)) == 0:
225-
pytest.skip("Timeseries must have positive shape")
226-
raise
230+
except TypeError:
231+
if np.sum(np.atleast_1d(shape)) == 0:
232+
pytest.skip("Timeseries must have positive shape")
233+
raise
227234

228235
@staticmethod
229236
def sample_random_variable(random_variable, size):
230237
""" Draws samples from a RandomVariable using its .random() method. """
231238
try:
232-
return random_variable.random(size=size)
239+
if size is None:
240+
return random_variable.random()
241+
else:
242+
return random_variable.random(size=size)
233243
except AttributeError:
234-
return random_variable.distribution.random(size=size)
244+
if size is None:
245+
return random_variable.distribution.random()
246+
else:
247+
return random_variable.distribution.random(size=size)
235248

236249
@pytest.mark.parametrize("size", [None, (), 1, (1,), 5, (4, 5)], ids=str)
237250
@pytest.mark.parametrize("shape", [None, ()], ids=str)
238251
def test_scalar_distribution_shape(self, shape, size):
239252
""" Draws samples of different [size] from a scalar [shape] RV. """
240253
rv = self.get_random_variable(shape)
241-
expected = () if size in {None, ()} else tuple(np.atleast_1d(size))
254+
exp_shape = self.default_shape if shape is None else tuple(np.atleast_1d(shape))
255+
exp_size = self.default_size if size is None else tuple(np.atleast_1d(size))
256+
expected = exp_size + exp_shape
242257
actual = np.shape(self.sample_random_variable(rv, size))
243258
assert expected == actual, f"Sample size {size} from {shape}-shaped RV had shape {actual}. Expected: {expected}"
244259

@@ -247,7 +262,9 @@ def test_scalar_distribution_shape(self, shape, size):
247262
def test_scalar_sample_shape(self, shape, size):
248263
""" Draws samples of scalar [size] from a [shape] RV. """
249264
rv = self.get_random_variable(shape)
250-
expected = () if shape in {None, ()} else tuple(np.atleast_1d(shape))
265+
exp_shape = self.default_shape if shape is None else tuple(np.atleast_1d(shape))
266+
exp_size = self.default_size if size is None else tuple(np.atleast_1d(size))
267+
expected = exp_size + exp_shape
251268
actual = np.shape(self.sample_random_variable(rv, size))
252269
assert expected == actual, f"Sample size {size} from {shape}-shaped RV had shape {actual}. Expected: {expected}"
253270

@@ -256,8 +273,8 @@ def test_scalar_sample_shape(self, shape, size):
256273
def test_vector_params(self, shape, size):
257274
shape = self.shape
258275
rv = self.get_random_variable(shape, with_vector_params=True)
259-
exp_shape = () if shape in {None, ()} else tuple(np.atleast_1d(shape))
260-
exp_size = () if size in {None, ()} else tuple(np.atleast_1d(size))
276+
exp_shape = self.default_shape if shape is None else tuple(np.atleast_1d(shape))
277+
exp_size = self.default_size if size is None else tuple(np.atleast_1d(size))
261278
expected = exp_size + exp_shape
262279
actual = np.shape(self.sample_random_variable(rv, size))
263280
assert expected == actual, f"Sample size {size} from {shape}-shaped RV had shape {actual}. Expected: {expected}"
@@ -266,10 +283,7 @@ def test_vector_params(self, shape, size):
266283
class TestGaussianRandomWalk(BaseTestCases.BaseTestCase):
267284
distribution = pm.GaussianRandomWalk
268285
params = {"mu": 1.0, "sigma": 1.0}
269-
270-
@pytest.mark.xfail(reason="Supporting this makes a nasty API")
271-
def test_broadcast_shape(self):
272-
super().test_broadcast_shape()
286+
default_shape = (1,)
273287

274288

275289
class TestNormal(BaseTestCases.BaseTestCase):

0 commit comments

Comments
 (0)