diff --git a/pymc/aesaraf.py b/pymc/aesaraf.py index 4aa5981f03..62217bb5f2 100644 --- a/pymc/aesaraf.py +++ b/pymc/aesaraf.py @@ -572,7 +572,7 @@ def join_nonshared_inputs( for var in vars: shape = point[var.name].shape arr_len = np.prod(shape, dtype=int) - replace[var] = reshape_t(inarray[last_idx : last_idx + arr_len], shape).astype(var.dtype) + replace[var] = inarray[last_idx : last_idx + arr_len].reshape(shape).astype(var.dtype) last_idx += arr_len replace.update(shared) @@ -581,14 +581,6 @@ def join_nonshared_inputs( return xs_special, inarray -def reshape_t(x, shape): - """Work around fact that x.reshape(()) doesn't work""" - if shape != (): - return x.reshape(shape) - else: - return x[0] - - class PointFunc: """Wraps so a function so it takes a dict of arguments instead of arguments."""