Skip to content

Commit 08caafa

Browse files
remove sample reshaping
1 parent d89f879 commit 08caafa

File tree

1 file changed

+1
-18
lines changed

1 file changed

+1
-18
lines changed

pymc3/distributions/distribution.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -983,6 +983,7 @@ def generate_samples(generator, *args, **kwargs):
983983
Any remaining args and kwargs are passed on to the generator function.
984984
"""
985985
dist_shape = kwargs.pop("dist_shape", ())
986+
# TODO: the following variable is no longer used !!
986987
one_d = _is_one_d(dist_shape)
987988
size = kwargs.pop("size", None)
988989
broadcast_shape = kwargs.pop("broadcast_shape", None)
@@ -1059,23 +1060,5 @@ def generate_samples(generator, *args, **kwargs):
10591060
samples = generator(size=dist_bcast_shape, *args, **kwargs)
10601061
else:
10611062
samples = generator(size=size_tup + dist_bcast_shape, *args, **kwargs)
1062-
samples = np.asarray(samples)
10631063

1064-
# reshape samples here
1065-
if samples.ndim > 0 and samples.shape[0] == 1 and size_tup == (1,):
1066-
if (
1067-
len(samples.shape) > len(dist_shape)
1068-
and samples.shape[-len(dist_shape) :] == dist_shape[-len(dist_shape) :]
1069-
):
1070-
raise ValueError(f"This SHOULD be unreachable code. DON'T MERGE UNTIL THIS ENTIRE BLOCK WAS REMOVED. {samples.shape}, {size_tup}")
1071-
samples = samples.reshape(samples.shape[1:])
1072-
1073-
if (
1074-
one_d
1075-
and samples.ndim > 0
1076-
and samples.shape[-1] == 1
1077-
and (samples.shape != size_tup or size_tup == tuple() or size_tup == (1,))
1078-
):
1079-
raise ValueError(f"This SHOULD be unreachable code. DON'T MERGE UNTIL THIS ENTIRE BLOCK WAS REMOVED. {samples.shape}, {size_tup}")
1080-
samples = samples.reshape(samples.shape[:-1])
10811064
return np.asarray(samples)

0 commit comments

Comments
 (0)