Skip to content

Commit f3fe8ba

Browse files
Make sure shared variables have correct broadcasting in ValueGradFunction
1 parent ab03ed0 commit f3fe8ba

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

pymc3/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,9 @@ def __init__(
427427
givens = []
428428
self._extra_vars_shared = {}
429429
for var, value in extra_vars_and_values.items():
430-
shared = aesara.shared(value, var.name + "_shared__")
430+
shared = aesara.shared(
431+
value, var.name + "_shared__", broadcastable=[s == 1 for s in value.shape]
432+
)
431433
self._extra_vars_shared[var.name] = shared
432434
givens.append((var, shared))
433435

0 commit comments

Comments
 (0)