Description
Before
import pytensor.tensor as pt
# Need to align the axes manually
a = pt.vector("a", shape=(2, ))
b = pt.vector("b", shape=(3, ))
# a + b fails because shapes (2,) and (3,) do not broadcast
# An explicit new axis is required
result = a + b[:, None]  # shape (3, 2)
After
a = pt.vector("a", dims="channel")
b = pt.vector("b", dims="geo")
result = a + b
# result.type: TensorType(float64, dims=("channel", "geo")), with xarray-like ordering of dims
# The + operation aligns the inputs based on dims; the same would work for other elementwise operations
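To make the proposed behavior concrete, here is a minimal pure-Python sketch of the bookkeeping a dims-aware elementwise operation would need. The helper names `broadcast_dims` and `dimshuffle_pattern` are hypothetical and not part of PyTensor. The sketch assumes the output dims are the union of the input dims in order of first appearance, as suggested by the xarray-like ordering above.

```python
from typing import Sequence, Union

Pattern = tuple[Union[int, str], ...]


def broadcast_dims(*all_dims: Sequence[str]) -> tuple[str, ...]:
    """Union of dim names, ordered by first appearance (assumed xarray-like)."""
    out: list[str] = []
    for dims in all_dims:
        for dim in dims:
            if dim not in out:
                out.append(dim)
    return tuple(out)


def dimshuffle_pattern(dims: Sequence[str], target: Sequence[str]) -> Pattern:
    """Pattern that maps `dims` onto `target`.

    Each entry is the position of the target dim in `dims`, or "x" where a
    new broadcastable axis has to be inserted.
    """
    dims = tuple(dims)
    missing = [dim for dim in dims if dim not in target]
    if missing:
        raise ValueError(f"dims {missing} are not in target dims {tuple(target)}")
    return tuple(dims.index(dim) if dim in dims else "x" for dim in target)


# Hypothetical walk-through of the "After" example
out_dims = broadcast_dims(("channel",), ("geo",))
a_pattern = dimshuffle_pattern(("channel",), out_dims)
b_pattern = dimshuffle_pattern(("geo",), out_dims)
```

The patterns use the same convention as PyTensor's existing `dimshuffle` method (integers for existing axes, "x" for a new broadcastable axis), so an implementation could apply them as `a.dimshuffle(*a_pattern) + b.dimshuffle(*b_pattern)`.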
Context for the issue:
Use of the Prior class in PyMC-Marketing, and its potential usefulness elsewhere and in PyMC directly:
from pymc_marketing.prior import Prior

dist = Prior(
    "Normal",
    # Variables are automatically transposed before passing to PyMC distributions
    mu=Prior("Normal", dims="geo"),
    sigma=Prior("HalfNormal", dims="geo"),
    dims=("geo", "channel"),
)
References:
PyMC-Marketing auto-broadcasting handling: https://github.com/pymc-labs/pymc-marketing/blob/main/pymc_marketing/prior.py#L131-L168
PyMC Discussion: pymc-devs/pymc#7416