Skip to content

Commit fbb6fe2

Browse files
only treat () as scalar shapes
closes #4206
1 parent 8092eed commit fbb6fe2

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

pymc3/distributions/distribution.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -943,11 +943,9 @@ def _draw_value(param, point=None, givens=None, size=None):
943943

944944

945945
def _is_one_d(dist_shape):
946-
if hasattr(dist_shape, "dshape") and dist_shape.dshape in ((), (0,), (1,)):
946+
if hasattr(dist_shape, "dshape") and dist_shape.dshape in { (), }:
947947
return True
948-
elif hasattr(dist_shape, "shape") and dist_shape.shape in ((), (0,), (1,)):
949-
return True
950-
elif to_tuple(dist_shape) == ():
948+
elif hasattr(dist_shape, "shape") and dist_shape.shape in { (), }:
951949
return True
952950
return False
953951

@@ -1069,6 +1067,7 @@ def generate_samples(generator, *args, **kwargs):
10691067
len(samples.shape) > len(dist_shape)
10701068
and samples.shape[-len(dist_shape) :] == dist_shape[-len(dist_shape) :]
10711069
):
1070+
raise ValueError(f"This SHOULD be unreachable code. DON'T MERGE UNTIL THIS ENTIRE BLOCK WAS REMOVED. {samples.shape}, {size_tup}")
10721071
samples = samples.reshape(samples.shape[1:])
10731072

10741073
if (
@@ -1077,5 +1076,6 @@ def generate_samples(generator, *args, **kwargs):
10771076
and samples.shape[-1] == 1
10781077
and (samples.shape != size_tup or size_tup == tuple() or size_tup == (1,))
10791078
):
1079+
raise ValueError(f"This SHOULD be unreachable code. DON'T MERGE UNTIL THIS ENTIRE BLOCK WAS REMOVED. {samples.shape}, {size_tup}")
10801080
samples = samples.reshape(samples.shape[:-1])
10811081
return np.asarray(samples)

0 commit comments

Comments
 (0)