From 2a440dea2c3b7314645532939a5ad6c51e0a3947 Mon Sep 17 00:00:00 2001 From: Tyler Burch Date: Sat, 10 Sep 2022 20:44:25 -0400 Subject: [PATCH 1/2] Remove reshape_t function --- pymc/aesaraf.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/pymc/aesaraf.py b/pymc/aesaraf.py index 4aa5981f03..3c15912ed6 100644 --- a/pymc/aesaraf.py +++ b/pymc/aesaraf.py @@ -581,14 +581,6 @@ def join_nonshared_inputs( return xs_special, inarray -def reshape_t(x, shape): - """Work around fact that x.reshape(()) doesn't work""" - if shape != (): - return x.reshape(shape) - else: - return x[0] - - class PointFunc: """Wraps so a function so it takes a dict of arguments instead of arguments.""" From 52ad96681da0a619941315ca4bdfb1697179cfff Mon Sep 17 00:00:00 2001 From: Tyler Burch Date: Sat, 10 Sep 2022 21:00:38 -0400 Subject: [PATCH 2/2] Remove reshape_t reference --- pymc/aesaraf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/aesaraf.py b/pymc/aesaraf.py index 3c15912ed6..62217bb5f2 100644 --- a/pymc/aesaraf.py +++ b/pymc/aesaraf.py @@ -572,7 +572,7 @@ def join_nonshared_inputs( for var in vars: shape = point[var.name].shape arr_len = np.prod(shape, dtype=int) - replace[var] = reshape_t(inarray[last_idx : last_idx + arr_len], shape).astype(var.dtype) + replace[var] = inarray[last_idx : last_idx + arr_len].reshape(shape).astype(var.dtype) last_idx += arr_len replace.update(shared)