
ENH: Native support for dims in tensors and tensor operations #954

Open
@williambdean

Description


Before

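import pytensor.tensor as pt

# Need to track which axis is which by position and align shapes by hand
a = pt.vector("a", shape=(2,))
b = pt.vector("b", shape=(3,))

# a + b fails: shapes (2,) and (3,) cannot be broadcast together
# An explicit new axis is required, which also fixes the output dim order
result = a + b[:, None]  # shape (3, 2)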

After

# Proposed API: the dims argument is the feature being requested
a = pt.vector("a", dims="channel")
b = pt.vector("b", dims="geo")

result = a + b
# result.type: TensorType(float64, dims=("channel", "geo")), using xarray-like dim ordering
# + aligns (transposes/broadcasts) its operands by dims; the same would apply to other elementwise operations
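The alignment rule in the After example can be emulated with plain NumPy to make the intended semantics concrete. This is a minimal sketch, not a PyTensor API: `dims_binary_op` is a hypothetical helper, dims are assumed to be passed as tuples of names, and output dims follow xarray's convention of ordering by first appearance.

```python
import numpy as np


def dims_binary_op(op, x, x_dims, y, y_dims):
    """Hypothetical helper: apply an elementwise op to arrays labeled by dim names.

    x_dims and y_dims must be tuples of names, one per axis. Output dims are
    ordered xarray-style: x's dims first, then dims that appear only in y.
    """
    x_dims, y_dims = tuple(x_dims), tuple(y_dims)
    out_dims = x_dims + tuple(d for d in y_dims if d not in x_dims)

    def align(arr, dims):
        # Reorder this operand's existing axes to follow out_dims...
        order = [dims.index(d) for d in out_dims if d in dims]
        arr = np.transpose(arr, order)
        # ...then insert a length-1 axis for each output dim it lacks, so that
        # NumPy broadcasting matches axes by name rather than by position.
        missing = tuple(i for i, d in enumerate(out_dims) if d not in dims)
        return np.expand_dims(arr, missing)

    return op(align(x, x_dims), align(y, y_dims)), out_dims


a = np.array([1.0, 2.0])          # dims ("channel",)
b = np.array([10.0, 20.0, 30.0])  # dims ("geo",)
result, dims = dims_binary_op(np.add, a, ("channel",), b, ("geo",))
# result.shape == (2, 3) and dims == ("channel", "geo")
```

Because the op is passed in, the same helper covers `np.multiply`, `np.subtract`, and so on. This mirrors the point that dim-based alignment is not specific to `+`.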

Context for the issue:

The Prior class in PyMC-Marketing already aligns variables by dims, and the same behavior could be useful elsewhere, including directly in PyMC.

from pymc_marketing.prior import Prior

dist = Prior(
    "Normal",
    # Parameters are automatically transposed/broadcast to the distribution's dims before being passed to PyMC distributions
    mu=Prior("Normal", dims="geo"),
    sigma=Prior("HalfNormal", dims="geo"),
    dims=("geo", "channel"),
)
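In the Prior example, parameters with fewer dims (`mu` and `sigma` over `"geo"`) are expanded to the distribution's `dims` before they reach PyMC. As a rough NumPy illustration of that step, assuming a hypothetical `conform_to_dims` helper and a `sizes` mapping that stands in for model coords (neither is the linked PyMC-Marketing code):

```python
import numpy as np


def conform_to_dims(value, value_dims, target_dims, sizes):
    """Hypothetical helper: broadcast a labeled array to a declared set of dims.

    value_dims and target_dims are tuples of dim names; sizes maps each target
    dim to its length (similar in spirit to model coords). Illustration only,
    not PyMC-Marketing's actual Prior implementation.
    """
    value_dims, target_dims = tuple(value_dims), tuple(target_dims)
    unknown = [d for d in value_dims if d not in target_dims]
    if unknown:
        raise ValueError(f"dims {unknown} are not among target dims {target_dims}")
    # Move the value's axes into the order they appear in target_dims.
    order = [value_dims.index(d) for d in target_dims if d in value_dims]
    value = np.transpose(value, order)
    # Add length-1 axes for target dims the value lacks, then broadcast to full size.
    missing = tuple(i for i, d in enumerate(target_dims) if d not in value_dims)
    value = np.expand_dims(value, missing)
    return np.broadcast_to(value, tuple(sizes[d] for d in target_dims))


sizes = {"geo": 3, "channel": 2}
mu = conform_to_dims(np.array([0.0, 10.0, 20.0]), ("geo",), ("geo", "channel"), sizes)
sigma = conform_to_dims(np.array([1.0, 2.0, 3.0]), ("geo",), ("geo", "channel"), sizes)
draws = np.random.default_rng(0).normal(mu, sigma)  # dims ("geo", "channel")
```

Raising on unknown dims, rather than silently adding them, keeps a mislabeled parameter from producing an output with more dims than the distribution declares.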

References:
PyMC-Marketing auto-broadcasting handling: https://github.com/pymc-labs/pymc-marketing/blob/main/pymc_marketing/prior.py#L131-L168
PyMC Discussion: pymc-devs/pymc#7416
