Labeled tensors #1411

Open · wants to merge 4 commits into main

Conversation

@ricardoV94 (Member) commented May 22, 2025

import numpy as np

from pytensor import function
from pytensor.xtensor.basic import add, exp
from pytensor.xtensor.type import xtensor

x = xtensor("x", dims=("city",), shape=(None,))
y = xtensor("y", dims=("country",), shape=(4,))
z = add(exp(x), exp(y))
assert z.type.dims == ("city", "country")
assert z.type.shape == (None, 4)

fn = function([x, y], z)
fn.dprint(print_type=True)
# XTensorFromTensor{dims=('city', 'country')} [id A] <XTensorType{dtype='float64', shape=(None, 4), dims=('city', 'country')}> 7
#  └─ Add [id B] <Matrix(float64, shape=(?, 4))> 6
#     ├─ Exp [id C] <Matrix(float64, shape=(?, 1))> 5
#     │  └─ ExpandDims{axis=1} [id D] <Matrix(float64, shape=(?, 1))> 3
#     │     └─ TensorFromXTensor [id E] <Vector(float64, shape=(?,))> 1
#     │        └─ x [id F] <XTensorType{dtype='float64', shape=(None,), dims=('city',)}>
#     └─ Exp [id G] <Matrix(float64, shape=(1, 4))> 4
#        └─ ExpandDims{axis=0} [id H] <Matrix(float64, shape=(1, 4))> 2
#           └─ TensorFromXTensor [id I] <Vector(float64, shape=(4,))> 0
#              └─ y [id J] <XTensorType{dtype='float64', shape=(4,), dims=('country',)}>

np.testing.assert_allclose(
    fn(x=np.zeros(3), y=np.zeros(4)),
    np.full((3, 4), 2.0),
)

Strategy

We implement xarray-like dummy Ops that respect / propagate dims semantics, and lower them to regular PyTensor graphs with rewrites.

Note that in the example above the dummy TensorFromXTensor and XTensorFromTensor Ops remain in the final graph. If we had instead created a function with Tensor inputs and outputs, which are only then converted (symbolically) to and from xtensor respectively, the final graph would show no sign of the dimension operations, other than in how it was constructed.
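
For illustration, a minimal sketch of that alternative construction. It assumes helper functions tensor_from_xtensor / xtensor_from_tensor (hypothetical names wrapping the TensorFromXTensor / XTensorFromTensor Ops seen in the dprint) are exposed in pytensor.xtensor.basic; the actual names in this branch may differ:

import pytensor.tensor as pt
from pytensor import function
from pytensor.xtensor.basic import add, exp, tensor_from_xtensor, xtensor_from_tensor  # helper names are assumptions

# Plain TensorVariable inputs
x_t = pt.vector("x")
y_t = pt.vector("y", shape=(4,))

# Lift to labeled tensors symbolically, build the dims-aware graph, then drop back
x = xtensor_from_tensor(x_t, dims=("city",))
y = xtensor_from_tensor(y_t, dims=("country",))
z = add(exp(x), exp(y))
z_t = tensor_from_xtensor(z)

# After the lowering rewrites, this graph should contain only regular tensor Ops,
# with no TensorFromXTensor/XTensorFromTensor left at the boundaries
fn = function([x_t, y_t], z_t)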

I suggest registering those rewrites in an xtensor_lowering database.

Coordinates

For now I'm playing with how far we can get without coordinates. This means the graphs produced by an xarray-like syntax are much more amenable to the numpy-like backend of PyTensor. Otherwise it involves a lot of pandas-like machinery (e.g., MultiIndex) that we don't really have. It may be feasible, especially if nothing is symbolic, but I fear a rabbit hole of edge cases.

Gradients

These Ops are currently not differentiable, but one can lower the graph and then call the gradient on the lowered version. I do want to try the lazy grad approach from #788
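
A hypothetical sketch of that workflow, continuing from the tensor-in/tensor-out sketch in the Strategy section above and assuming the lowering rewrites end up registered under the xtensor_lowering name suggested there (names and registration details are not settled):

import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

# Lower the dims-aware graph to plain tensor Ops, then differentiate as usual
z_lowered = rewrite_graph(z_t, include=("xtensor_lowering",))
grads = pt.grad(z_lowered.sum(), wrt=[x_t, y_t])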

Help implementing more Ops is welcome, so we have an MVP to try out with PyMC next. We need the Ops listed below.

Open a PR on top of this branch and I'll try to merge quickly! Try to keep it clean (one commit per Op, unless it's a factory of related Ops).

Implementing an Op means (a minimal sketch follows this list):

  1. Create a dummy Op.
  2. Create a rewrite that lowers the dummy Op to real tensor operations.
     The rewrites "box" the lowered tensor operations between TensorFromXTensor and XTensorFromTensor calls, so that the replacements are valid in terms of types. There are rewrites to remove chains of useless TensorFromXTensor/XTensorFromTensor that should clean up everything in the middle of the graph.
  3. Add a test that compares with xarray/xarray_einstats and proves it's correct.
  4. If you really want, test the error checks (I haven't been doing that).
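
A minimal sketch of steps 1-2 for a hypothetical unary XLog Op. It assumes XTensorType is importable from pytensor.xtensor.type and that the boxing helpers tensor_from_xtensor / xtensor_from_tensor exist in pytensor.xtensor.basic; the exact names and signatures in this branch may differ:

import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor  # assumed helpers
from pytensor.xtensor.type import XTensorType  # assumed location


class XLog(Op):
    """Dummy Op that only propagates dims/shape; it is never executed."""

    def make_node(self, x):
        out_type = XTensorType(x.type.dtype, dims=x.type.dims, shape=x.type.shape)
        return Apply(self, [x], [out_type()])

    def perform(self, node, inputs, output_storage):
        raise NotImplementedError("xtensor Ops are lowered away before compilation")


@node_rewriter([XLog])
def lower_xlog(fgraph, node):
    [x] = node.inputs
    # "Box" the real tensor computation between the conversion Ops so the
    # replacement has the same XTensorType as the node it replaces
    out_tensor = pt.log(tensor_from_xtensor(x))
    return [xtensor_from_tensor(out_tensor, dims=node.outputs[0].type.dims)]

The rewrite would then be registered in the proposed xtensor_lowering database so it runs before compilation.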

Interplay between XTensorTypes and TensorTypes / weakly typed inputs

  • Symbolic conversion to and from XTensor and Tensor
  • Make sure MetaOps accept non-XTensorType scalar inputs
  • Make MetaOps "cast" regular numpy/TensorVariable inputs to XTensorVariable to behave like xarray does (dims are matched positionally; see the xarray reference example after this list).
  • Operators as methods (__add__ and the like so you can do x + x)
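
For reference, a small example of the xarray behavior being mimicked (using xarray itself, not this PR): a plain numpy operand is matched to the DataArray's dims positionally.

import numpy as np
import xarray as xr

x = xr.DataArray(np.zeros((2, 3)), dims=("city", "year"))
y = np.ones((2, 3))  # unlabeled operand

# The numpy array's axes are aligned positionally to ("city", "year")
z = x + y
assert z.dims == ("city", "year")
assert z.shape == (2, 3)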

Meta Ops

  • Elemwise (automatically generated, some may fail)
  • Blockwise (each Op needs manual curation though)
  • where (dual behavior based on the number of inputs; quite central in xarray-like workflows?)
  • CAReduce (Sum, All, Mean, ...)
  • Einsum (probably low priority)
  • Scan (what a joke)
  • OpFromGraph (Should just work?, gotta test it)

Math stuff

  • Cast (it's a parametrized ScalarOp so the general XElemwise logic won't suffice)
  • Dot
  • Mean/Std/Variance (there's no CAReduce Op corresponding to those)
  • Everything that is a blockwise in vanilla pytensor (like all of linalg)

Shape stuff

  • Rename
  • Transpose (test already in place)
  • ExpandDims
  • Squeeze
  • Stack
  • Unstack (will need to specify shapes to work without coordinates, similar to einops rearrange; may consider not
    overloading the name to avoid confusion) (@OriolAbril working on it in first pass at unstack #1412)
  • Concat
  • Broadcast_arrays (chaining an Elemwise second should achieve this)

Array creation stuff

  • ZerosLike / OnesLike
  • Is there anything else?

Indexing stuff

  • __getitem__ + isel with non-XTensor indices (in progress, missing tests and lowering)
  • __getitem__ + isel with XTensor indices (can be multi-dimensional; not sure of xarray's rules).
  • Indexing update with non-XTensor indices (set and inc)
  • Indexing update with XTensor indices
    It probably makes sense to convert the non-XTensor indices to XTensor indices if they can be rendered equivalent, to reduce the logic needed.

RandomVariables

This is quite important, as we'll need those for PyMC models! They are a mix of blockwise + a size argument (which may or may not be redundant).


📚 Documentation preview 📚: https://pytensor--1411.org.readthedocs.build/en/1411/
