From a4d57272dd5a2285f06a6c5c7f3b866c30501380 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 2 Aug 2023 14:59:22 +0200 Subject: [PATCH 1/4] POC named tensors Co-authored-by: Oriol Abril-Pla --- pytensor/xtensor/__init__.py | 13 ++ pytensor/xtensor/basic.py | 199 +++++++++++++++++++++ pytensor/xtensor/linalg.py | 72 ++++++++ pytensor/xtensor/math.py | 31 ++++ pytensor/xtensor/rewriting/__init__.py | 2 + pytensor/xtensor/rewriting/basic.py | 113 ++++++++++++ pytensor/xtensor/rewriting/shape.py | 29 +++ pytensor/xtensor/rewriting/utils.py | 33 ++++ pytensor/xtensor/shape.py | 71 ++++++++ pytensor/xtensor/type.py | 234 +++++++++++++++++++++++++ tests/xtensor/__init__.py | 0 tests/xtensor/test_linalg.py | 84 +++++++++ tests/xtensor/test_math.py | 85 +++++++++ tests/xtensor/test_shape.py | 104 +++++++++++ tests/xtensor/util.py | 37 ++++ 15 files changed, 1107 insertions(+) create mode 100644 pytensor/xtensor/__init__.py create mode 100644 pytensor/xtensor/basic.py create mode 100644 pytensor/xtensor/linalg.py create mode 100644 pytensor/xtensor/math.py create mode 100644 pytensor/xtensor/rewriting/__init__.py create mode 100644 pytensor/xtensor/rewriting/basic.py create mode 100644 pytensor/xtensor/rewriting/shape.py create mode 100644 pytensor/xtensor/rewriting/utils.py create mode 100644 pytensor/xtensor/shape.py create mode 100644 pytensor/xtensor/type.py create mode 100644 tests/xtensor/__init__.py create mode 100644 tests/xtensor/test_linalg.py create mode 100644 tests/xtensor/test_math.py create mode 100644 tests/xtensor/test_shape.py create mode 100644 tests/xtensor/util.py diff --git a/pytensor/xtensor/__init__.py b/pytensor/xtensor/__init__.py new file mode 100644 index 0000000000..add3e7599b --- /dev/null +++ b/pytensor/xtensor/__init__.py @@ -0,0 +1,13 @@ +import warnings + +import pytensor.xtensor.rewriting +from pytensor.xtensor.type import ( + XTensorType, + as_xtensor, + as_xtensor_variable, + xtensor, + xtensor_constant, +) + + +warnings.warn("xtensor module is experimental and full of bugs") diff --git a/pytensor/xtensor/basic.py b/pytensor/xtensor/basic.py new file mode 100644 index 0000000000..f3b0c8cf11 --- /dev/null +++ b/pytensor/xtensor/basic.py @@ -0,0 +1,199 @@ +from itertools import chain + +import pytensor.scalar as ps +from pytensor.graph import Apply, Op +from pytensor.tensor import TensorType, tensor +from pytensor.tensor.utils import _parse_gufunc_signature +from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor + + +class XOp(Op): + """A base class for XOps that shouldn't be materialized""" + + def perform(self, node, inputs, outputs): + raise NotImplementedError( + "xtensor operations must be rewritten as tensor operations" + ) + + +class XViewOp(Op): + # Make this a View Op with C-implementation + view_map = {0: [0]} + + def perform(self, node, inputs, output_storage): + output_storage[0][0] = inputs[0] + + +class TensorFromXTensor(XViewOp): + __props__ = () + + def make_node(self, x) -> Apply: + if not isinstance(x.type, XTensorType): + raise TypeError(f"x must be have an XTensorType, got {type(x.type)}") + output = TensorType(x.type.dtype, shape=x.type.shape)() + return Apply(self, [x], [output]) + + +tensor_from_xtensor = TensorFromXTensor() + + +class XTensorFromTensor(XViewOp): + __props__ = ("dims",) + + def __init__(self, dims): + super().__init__() + self.dims = dims + + def make_node(self, x) -> Apply: + if not isinstance(x.type, TensorType): + raise TypeError(f"x must be an TensorType type, got {type(x.type)}") + output = 
xtensor(dtype=x.type.dtype, dims=self.dims, shape=x.type.shape) + return Apply(self, [x], [output]) + + +def xtensor_from_tensor(x, dims): + return XTensorFromTensor(dims=dims)(x) + + +class Rename(XViewOp): + __props__ = ("new_dims",) + + def __init__(self, new_dims: tuple[str, ...]): + super().__init__() + self.new_dims = new_dims + + def make_node(self, x): + x = as_xtensor(x) + output = x.type.clone(dims=self.new_dims)() + return Apply(self, [x], [output]) + + +def rename(x, name_dict: dict[str, str] | None = None, **names: str): + if name_dict is not None: + if names: + raise ValueError("Cannot use both positional and keyword names in rename") + names = name_dict + + x = as_xtensor(x) + old_names = x.type.dims + new_names = list(old_names) + for old_name, new_name in names.items(): + try: + new_names[old_names.index(old_name)] = new_name + except IndexError: + raise ValueError( + f"Cannot rename {old_name} to {new_name}: {old_name} not in {old_names}" + ) + + return Rename(tuple(new_names))(x) + + +class XElemwise(XOp): + __props__ = ("scalar_op",) + + def __init__(self, scalar_op): + super().__init__() + self.scalar_op = scalar_op + + def make_node(self, *inputs): + inputs = [as_xtensor(inp) for inp in inputs] + if (self.scalar_op.nin != -1) and (len(inputs) != self.scalar_op.nin): + raise ValueError( + f"Wrong number of inputs, expected {self.scalar_op.nin}, got {len(inputs)}" + ) + + dims_and_shape: dict[str, int | None] = {} + for inp in inputs: + for dim, dim_length in zip(inp.type.dims, inp.type.shape): + if dim not in dims_and_shape: + dims_and_shape[dim] = dim_length + elif dim_length is not None: + # Check for conflicting shapes + if (dims_and_shape[dim] is not None) and ( + dims_and_shape[dim] != dim_length + ): + raise ValueError(f"Dimension {dim} has conflicting shapes") + # Keep the non-None shape + dims_and_shape[dim] = dim_length + + output_dims, output_shape = zip(*dims_and_shape.items()) + + dummy_scalars = [ps.get_scalar_type(inp.type.dtype)() for inp in inputs] + output_dtypes = [ + out.type.dtype for out in self.scalar_op.make_node(*dummy_scalars).outputs + ] + outputs = [ + xtensor(dtype=output_dtype, dims=output_dims, shape=output_shape) + for output_dtype in output_dtypes + ] + return Apply(self, inputs, outputs) + + +class XBlockwise(XOp): + __props__ = ("core_op", "signature", "core_dims") + + def __init__( + self, + core_op: Op, + signature: str, + core_dims: tuple[tuple[tuple[str, ...], ...], tuple[tuple[str, ...], ...]], + ): + super().__init__() + self.core_op = core_op + self.signature = signature + self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature) + self.core_dims = core_dims + + def make_node(self, *inputs): + inputs = [as_xtensor(i) for i in inputs] + if len(inputs) != len(self.inputs_sig): + raise ValueError( + f"Wrong number of inputs, expected {len(self.inputs_sig)}, got {len(inputs)}" + ) + + dims_and_shape: dict[str, int | None] = {} + for inp in inputs: + for dim, dim_length in zip(inp.type.dims, inp.type.shape): + if dim not in dims_and_shape: + dims_and_shape[dim] = dim_length + elif dim_length is not None: + # Check for conflicting shapes + if (dims_and_shape[dim] is not None) and ( + dims_and_shape[dim] != dim_length + ): + raise ValueError(f"Dimension {dim} has conflicting shapes") + # Keep the non-None shape + dims_and_shape[dim] = dim_length + + core_inputs_dims, core_outputs_dims = self.core_dims + # TODO: Avoid intermediate dict + core_dims = set(chain.from_iterable(core_inputs_dims)) + batched_dims_and_shape = { + 
k: v for k, v in dims_and_shape.items() if k not in core_dims + } + batch_dims, batch_shape = zip(*batched_dims_and_shape.items()) + + dummy_core_inputs = [] + for inp, core_inp_dims in zip(inputs, core_inputs_dims): + try: + core_static_shape = [ + inp.type.shape[inp.type.dims.index(d)] for d in core_inp_dims + ] + except IndexError: + raise ValueError( + f"At least one core dim={core_inp_dims} missing from input {inp} with dims={inp.type.dims}" + ) + dummy_core_inputs.append( + tensor(dtype=inp.type.dtype, shape=core_static_shape) + ) + core_node = self.core_op.make_node(*dummy_core_inputs) + + outputs = [ + xtensor( + dtype=core_out.type.dtype, + shape=batch_shape + core_out.type.shape, + dims=batch_dims + core_out_dims, + ) + for core_out, core_out_dims in zip(core_node.outputs, core_outputs_dims) + ] + return Apply(self, inputs, outputs) diff --git a/pytensor/xtensor/linalg.py b/pytensor/xtensor/linalg.py new file mode 100644 index 0000000000..7b75729cb7 --- /dev/null +++ b/pytensor/xtensor/linalg.py @@ -0,0 +1,72 @@ +from collections.abc import Sequence +from typing import Literal + +from pytensor.tensor.slinalg import Cholesky, Solve +from pytensor.xtensor import as_xtensor +from pytensor.xtensor.basic import XBlockwise + + +def cholesky( + x, + lower: bool = True, + *, + check_finite: bool = False, + overwrite_a: bool = False, + on_error: Literal["raise", "nan"] = "raise", + dims: Sequence[str], +): + if len(dims) != 2: + raise ValueError(f"Cholesky needs two dims, got {len(dims)}") + + core_op = Cholesky( + lower=lower, + check_finite=check_finite, + overwrite_a=overwrite_a, + on_error=on_error, + ) + core_dims = ( + ((dims[0], dims[1]),), + ((dims[0], dims[1]),), + ) + x_op = XBlockwise(core_op, signature=core_op.gufunc_signature, core_dims=core_dims) + return x_op(x) + + +def solve( + a, + b, + dims: Sequence[str], + assume_a="gen", + lower: bool = False, + check_finite: bool = False, +): + a, b = as_xtensor(a), as_xtensor(b) + if len(dims) == 2: + b_ndim = 1 + [m1_dim] = [dim for dim in dims if dim not in b.type.dims] + m2_dim = dims[0] if dims[0] != m1_dim else dims[1] + input_core_dims = ((m1_dim, m2_dim), (m2_dim,)) + output_core_dims = ((m2_dim,),) + elif len(dims) == 3: + b_ndim = 2 + [n_dim] = [dim for dim in dims if dim not in a.type.dims] + [m1_dim, m2_dim] = [dim for dim in dims if dim != n_dim] + input_core_dims = ((m1_dim, m2_dim), (m2_dim, n_dim)) + output_core_dims = ( + ( + m2_dim, + n_dim, + ), + ) + else: + raise ValueError("Solve dims must have length 2 or 3") + + core_op = Solve( + b_ndim=b_ndim, assume_a=assume_a, lower=lower, check_finite=check_finite + ) + x_op = XBlockwise( + core_op, + signature=core_op.gufunc_signature, + core_dims=(input_core_dims, output_core_dims), + ) + return x_op(a, b) diff --git a/pytensor/xtensor/math.py b/pytensor/xtensor/math.py new file mode 100644 index 0000000000..262032f297 --- /dev/null +++ b/pytensor/xtensor/math.py @@ -0,0 +1,31 @@ +import inspect +import sys + +import pytensor.scalar as ps +from pytensor.scalar import ScalarOp +from pytensor.xtensor.basic import XElemwise + + +this_module = sys.modules[__name__] + + +def get_all_scalar_ops(): + """ + Find all scalar operations in the pytensor.scalar module that can be wrapped with XElemwise. 
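+
+    The wrapped ops are exposed as module-level functions below. A minimal usage
+    sketch (``exp`` is one of the wrapped scalar ops, as exercised in the tests):
+
+    >>> from pytensor.xtensor.type import xtensor
+    >>> from pytensor.xtensor import math as xmath
+    >>> x = xtensor("x", dims=("a",), shape=(3,))
+    >>> y = xmath.exp(x)  # XElemwise wrapping pytensor.scalar.exp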
+ + Returns: + dict: A dictionary mapping operation names to XElemwise instances + """ + result = {} + + # Get all module members + for name, obj in inspect.getmembers(ps): + # Check if the object is a scalar op (has make_node method and is not an abstract class) + if isinstance(obj, ScalarOp): + result[name] = XElemwise(obj) + + return result + + +for name, op in get_all_scalar_ops().items(): + setattr(this_module, name, op) diff --git a/pytensor/xtensor/rewriting/__init__.py b/pytensor/xtensor/rewriting/__init__.py new file mode 100644 index 0000000000..d4bb32ad66 --- /dev/null +++ b/pytensor/xtensor/rewriting/__init__.py @@ -0,0 +1,2 @@ +import pytensor.xtensor.rewriting.basic +import pytensor.xtensor.rewriting.shape diff --git a/pytensor/xtensor/rewriting/basic.py b/pytensor/xtensor/rewriting/basic.py new file mode 100644 index 0000000000..777780e91e --- /dev/null +++ b/pytensor/xtensor/rewriting/basic.py @@ -0,0 +1,113 @@ +from pytensor.graph import node_rewriter +from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.elemwise import Elemwise +from pytensor.xtensor.basic import ( + Rename, + TensorFromXTensor, + XBlockwise, + XElemwise, + XTensorFromTensor, + tensor_from_xtensor, + xtensor_from_tensor, +) +from pytensor.xtensor.rewriting.utils import register_xcanonicalize + + +@register_xcanonicalize +@node_rewriter(tracks=[TensorFromXTensor]) +def useless_tensor_from_xtensor(fgraph, node): + """TensorFromXTensor(XTensorFromTensor(x)) -> x""" + [x] = node.inputs + if x.owner and isinstance(x.owner.op, XTensorFromTensor): + return [x.owner.inputs[0]] + + +@register_xcanonicalize +@node_rewriter(tracks=[XTensorFromTensor]) +def useless_xtensor_from_tensor(fgraph, node): + """XTensorFromTensor(TensorFromXTensor(x)) -> x""" + [x] = node.inputs + if x.owner and isinstance(x.owner.op, TensorFromXTensor): + return [x.owner.inputs[0]] + + +@register_xcanonicalize +@node_rewriter(tracks=[TensorFromXTensor]) +def useless_tensor_from_xtensor_of_rename(fgraph, node): + """TensorFromXTensor(Rename(x)) -> TensorFromXTensor(x)""" + [renamed_x] = node.inputs + if renamed_x.owner and isinstance(renamed_x.owner.op, Rename): + [x] = renamed_x.owner.inputs + return node.op(x, return_list=True) + + +@register_xcanonicalize +@node_rewriter(tracks=[Rename]) +def useless_rename(fgraph, node): + """ + + Rename(Rename(x, inner_dims), outer_dims) -> Rename(x, outer_dims) + Rename(X, XTensorFromTensor(x, inner_dims), outer_dims) -> XTensorFrom_tensor(x, outer_dims) + """ + [renamed_x] = node.inputs + if renamed_x.owner and isinstance(renamed_x.owner.op, Rename | XTensorFromTensor): + [x] = renamed_x.owner.inputs + return node.op(x, return_list=True) + + +@register_xcanonicalize +@node_rewriter(tracks=[XElemwise]) +def lower_elemwise(fgraph, node): + out_dims = node.outputs[0].type.dims + + # Convert input XTensors to Tensors and align batch dimensions + tensor_inputs = [] + for inp in node.inputs: + inp_dims = inp.type.dims + order = [ + inp_dims.index(out_dim) if out_dim in inp_dims else "x" + for out_dim in out_dims + ] + tensor_inp = tensor_from_xtensor(inp).dimshuffle(order) + tensor_inputs.append(tensor_inp) + + tensor_outs = Elemwise(scalar_op=node.op.scalar_op)( + *tensor_inputs, return_list=True + ) + + # Convert output Tensors to XTensors + new_outs = [ + xtensor_from_tensor(tensor_out, dims=out_dims) for tensor_out in tensor_outs + ] + return new_outs + + +@register_xcanonicalize +@node_rewriter(tracks=[XBlockwise]) +def lower_blockwise(fgraph, node): + op: XBlockwise = node.op + 
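# Infer batch dims from the first output: every leading dim that is not a core dim of the gufunc signature +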
batch_ndim = node.outputs[0].type.ndim - len(op.outputs_sig[0]) + batch_dims = node.outputs[0].type.dims[:batch_ndim] + + # Convert input Tensors to XTensors, align batch dimensions and place core dimension at the end + tensor_inputs = [] + for inp, core_dims in zip(node.inputs, op.core_dims[0]): + inp_dims = inp.type.dims + # Align the batch dims of the input, and place the core dims on the right + batch_order = [ + inp_dims.index(batch_dim) if batch_dim in inp_dims else "x" + for batch_dim in batch_dims + ] + core_order = [inp_dims.index(core_dim) for core_dim in core_dims] + tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order) + tensor_inputs.append(tensor_inp) + + tensor_op = Blockwise(core_op=node.op.core_op, signature=op.signature) + tensor_outs = tensor_op(*tensor_inputs, return_list=True) + + # Convert output Tensors to XTensors + new_outs = [ + xtensor_from_tensor(tensor_out, dims=old_out.type.dims) + for (tensor_out, old_out) in zip(tensor_outs, node.outputs, strict=True) + ] + return new_outs diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py new file mode 100644 index 0000000000..b2eabb5c8e --- /dev/null +++ b/pytensor/xtensor/rewriting/shape.py @@ -0,0 +1,29 @@ +from pytensor.graph import node_rewriter +from pytensor.tensor import moveaxis +from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor +from pytensor.xtensor.rewriting.basic import register_xcanonicalize +from pytensor.xtensor.shape import Stack + + +@register_xcanonicalize +@node_rewriter(tracks=[Stack]) +def lower_stack(fgraph, node): + [x] = node.inputs + batch_ndim = x.type.ndim - len(node.op.stacked_dims) + stacked_axes = [ + i for i, dim in enumerate(x.type.dims) if dim in node.op.stacked_dims + ] + end = tuple(range(-len(stacked_axes), 0)) + + x_tensor = tensor_from_xtensor(x) + x_tensor_transposed = moveaxis(x_tensor, source=stacked_axes, destination=end) + if batch_ndim == (x.type.ndim - 1): + # This happens when we stack a "single" dimension, in this case all we need is the transpose + # Note: If we have meaningful rewrites before lowering, consider canonicalizing this as a Transpose + Rename + final_tensor = x_tensor_transposed + else: + final_shape = (*tuple(x_tensor_transposed.shape)[:batch_ndim], -1) + final_tensor = x_tensor_transposed.reshape(final_shape) + + new_out = xtensor_from_tensor(final_tensor, dims=node.outputs[0].type.dims) + return [new_out] diff --git a/pytensor/xtensor/rewriting/utils.py b/pytensor/xtensor/rewriting/utils.py new file mode 100644 index 0000000000..03de2c67a9 --- /dev/null +++ b/pytensor/xtensor/rewriting/utils.py @@ -0,0 +1,33 @@ +from pytensor.compile import optdb +from pytensor.graph.rewriting.basic import NodeRewriter +from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase + + +optdb.register( + "xcanonicalize", + EquilibriumDB(ignore_newtrees=False), + "fast_run", + "fast_compile", + "xtensor", + position=0, +) + + +def register_xcanonicalize( + node_rewriter: RewriteDatabase | NodeRewriter | str, *tags: str, **kwargs +): + if isinstance(node_rewriter, str): + + def register(inner_rewriter: RewriteDatabase | NodeRewriter): + return register_xcanonicalize( + inner_rewriter, node_rewriter, *tags, **kwargs + ) + + return register + + else: + name = kwargs.pop("name", None) or node_rewriter.__name__ + optdb["xtensor"].register( + name, node_rewriter, "fast_run", "fast_compile", *tags, **kwargs + ) + return node_rewriter diff --git a/pytensor/xtensor/shape.py 
b/pytensor/xtensor/shape.py new file mode 100644 index 0000000000..8fa0f42630 --- /dev/null +++ b/pytensor/xtensor/shape.py @@ -0,0 +1,71 @@ +from collections.abc import Sequence + +from pytensor.graph import Apply +from pytensor.xtensor.basic import XOp +from pytensor.xtensor.type import as_xtensor, xtensor + + +class Stack(XOp): + __props__ = ("new_dim_name", "stacked_dims") + + def __init__(self, new_dim_name: str, stacked_dims: tuple[str, ...]): + super().__init__() + if new_dim_name in stacked_dims: + raise ValueError( + f"Stacking dim {new_dim_name} must not be in {stacked_dims}" + ) + if not stacked_dims: + raise ValueError(f"Stacking dims must not be empty: got {stacked_dims}") + self.new_dim_name = new_dim_name + self.stacked_dims = stacked_dims + + def make_node(self, x): + x = as_xtensor(x) + if not (set(self.stacked_dims) <= set(x.type.dims)): + raise ValueError( + f"Stacking dims {self.stacked_dims} must be a subset of {x.type.dims}" + ) + if self.new_dim_name in x.type.dims: + raise ValueError( + f"Stacking dim {self.new_dim_name} must not be in {x.type.dims}" + ) + if len(self.stacked_dims) == x.type.ndim: + batch_dims, batch_shape = (), () + else: + batch_dims, batch_shape = zip( + *( + (dim, shape) + for dim, shape in zip(x.type.dims, x.type.shape) + if dim not in self.stacked_dims + ) + ) + stack_shape = 1 + for dim, shape in zip(x.type.dims, x.type.shape): + if dim in self.stacked_dims: + if shape is None: + stack_shape = None + break + else: + stack_shape *= shape + output = xtensor( + dtype=x.type.dtype, + shape=(*batch_shape, stack_shape), + dims=(*batch_dims, self.new_dim_name), + ) + return Apply(self, [x], [output]) + + +def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]): + if dim is not None: + if dims: + raise ValueError("Cannot use both positional dim and keyword dims in stack") + dims = dim + + y = x + for new_dim_name, stacked_dims in dims.items(): + if isinstance(stacked_dims, str): + raise TypeError( + f"Stacking dims must be a sequence of strings, got a single string: {stacked_dims}" + ) + y = Stack(new_dim_name, tuple(stacked_dims))(y) + return y diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py new file mode 100644 index 0000000000..c206c911ec --- /dev/null +++ b/pytensor/xtensor/type.py @@ -0,0 +1,234 @@ +try: + import xarray as xr + + XARRAY_AVAILABLE = True +except ModuleNotFoundError: + XARRAY_AVAILABLE = False + +from collections.abc import Sequence +from typing import TypeVar + +import numpy as np + +from pytensor import _as_symbolic, config +from pytensor.graph import Apply, Constant +from pytensor.graph.basic import Variable, OptionalApplyType +from pytensor.graph.type import HasDataType, HasShape, Type +from pytensor.tensor.utils import hash_from_ndarray +from pytensor.utils import hash_from_code + + +class XTensorType(Type, HasDataType, HasShape): + """A `Type` for Xtensors (Xarray-like tensors with dims).""" + + __props__ = ("dtype", "shape", "dims") + + def __init__( + self, + dtype: str | np.dtype, + *, + dims: Sequence[str], + shape: Sequence[int | None] | None = None, + name: str | None = None, + ): + if dtype == "floatX": + self.dtype = config.floatX + else: + self.dtype = np.dtype(dtype).name + + self.dims = tuple(dims) + if shape is None: + self.shape = (None,) * len(self.dims) + else: + self.shape = tuple(shape) + self.ndim = len(self.dims) + self.name = name + + def clone( + self, + dtype=None, + dims=None, + shape=None, + **kwargs, + ): + if dtype is None: + dtype = self.dtype + if dims is 
None: + dims = self.dims + if shape is None: + shape = self.shape + return type(self)(dtype=dtype, shape=shape, dims=dims, **kwargs) + + def filter(self, value, strict=False, allow_downcast=None): + # TODO implement this + return value + + def convert_variable(self, var): + # TODO: Implement this + return var + + def __repr__(self): + return f"XTensorType({self.dtype}, {self.dims}, {self.shape})" + + def __hash__(self): + return hash((type(self), self.dtype, self.shape, self.dims)) + + def __eq__(self, other): + return ( + type(self) is type(other) + and self.dims == other.dims + and self.shape == other.shape + ) + + def is_super(self, otype): + # TODO: Implement this + return True + + +def xtensor( + name: str | None = None, + *, + dims: Sequence[str], + shape: Sequence[int | None] | None = None, + dtype: str | np.dtype = "floatX", +): + return XTensorType(dtype=dtype, dims=dims, shape=shape)(name=name) + + +_XTensorTypeType = TypeVar("_XTensorTypeType", bound=XTensorType) + + +class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): + # These can't work because Python requires native output types + def __bool__(self): + raise TypeError( + "XTensorVariable cannot be converted to Python boolean. " + "Call `.astype(bool)` for the symbolic equivalent." + ) + + def __index__(self): + raise TypeError( + "XTensorVariable cannot be converted to Python integer. " + "Call `.astype(int)` for the symbolic equivalent." + ) + + def __int__(self): + raise TypeError( + "XTensorVariable cannot be converted to Python integer. " + "Call `.astype(int)` for the symbolic equivalent." + ) + + def __float__(self): + raise TypeError( + "XTensorVariables cannot be converted to Python float. " + "Call `.astype(float)` for the symbolic equivalent." + ) + + def __complex__(self): + raise TypeError( + "XTensorVariables cannot be converted to Python complex number. " + "Call `.astype(complex)` for the symbolic equivalent." + ) + + def __setitem__(self, key, value): + raise TypeError( + "XTensorVariable does not support item assignment. Use the output of `x[idx].set` or `x[idx].inc` instead." 
+ ) + + def __getitem__(self, idx): + from pytensor.xtensor.indexing import index + + if isinstance(idx, dict): + return self.isel(idx) + + return index(self, *idx) + + +class XTensorVariable(Variable): + pass + + # def __str__(self): + # return f"{self.__class__.__name__}{{{self.format},{self.dtype}}}" + + # def __repr__(self): + # return str(self) + + +class XTensorConstantSignature(tuple): + def __eq__(self, other): + if type(self) is not type(other): + return False + + (t0, d0), (t1, d1) = self, other + if t0 != t1 or d0.shape != d1.shape: + return False + + return True + + def __ne__(self, other): + return not self == other + + def __hash__(self): + (a, b) = self + return hash(type(self)) ^ hash(a) ^ hash(type(b)) + + def pytensor_hash(self): + t, d = self + return "".join([hash_from_ndarray(d)] + [hash_from_code(dim) for dim in t.dims]) + + +class XTensorConstant(XTensorVariable, Constant[_XTensorTypeType]): + def __init__(self, type: _XTensorTypeType, data, name=None): + # TODO: Add checks that type and data are compatible + Constant.__init__(self, type, data, name) + + def signature(self): + assert self.data is not None + return XTensorConstantSignature((self.type, self.data)) + + +XTensorType.variable_type = XTensorVariable +XTensorType.constant_type = XTensorConstant + + +def xtensor_constant(x, name=None): + if not isinstance(x, xr.DataArray): + raise TypeError("xtensor.constant must be called on a Xarray DataArray") + try: + return XTensorConstant( + XTensorType(dtype=x.dtype, dims=x.dims, shape=x.shape), + x.values.copy(), + name=name, + ) + except TypeError: + raise TypeError(f"Could not convert {x} to XTensorType") + + +if XARRAY_AVAILABLE: + + @_as_symbolic.register(xr.DataArray) + def as_symbolic_xarray(x, **kwargs): + return xtensor_constant(x, **kwargs) + + +def as_xtensor_variable(x, name=None): + if isinstance(x, Apply): + if len(x.outputs) != 1: + raise ValueError( + "It is ambiguous which output of a " + "multi-output Op has to be fetched.", + x, + ) + else: + x = x.outputs[0] + if isinstance(x, Variable): + if not isinstance(x.type, XTensorType): + raise TypeError(f"Variable type field must be a XTensorType, got {x.type}") + return x + try: + return xtensor_constant(x, name=name) + except TypeError as err: + raise TypeError(f"Cannot convert {x} to XTensorType {type(x)}") from err + + +as_xtensor = as_xtensor_variable diff --git a/tests/xtensor/__init__.py b/tests/xtensor/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/xtensor/test_linalg.py b/tests/xtensor/test_linalg.py new file mode 100644 index 0000000000..a673683f57 --- /dev/null +++ b/tests/xtensor/test_linalg.py @@ -0,0 +1,84 @@ +# ruff: noqa: E402 + +import pytest + + +pytest.importorskip("xarray") +pytest.importorskip("xarray_einstats") + +import numpy as np +from xarray import DataArray +from xarray_einstats.linalg import ( + cholesky as xr_cholesky, +) +from xarray_einstats.linalg import ( + solve as xr_solve, +) + +from pytensor import function +from pytensor.xtensor.linalg import cholesky, solve +from pytensor.xtensor.type import xtensor + + +def test_cholesky(): + x = xtensor("x", dims=("a", "batch", "b"), shape=(4, 3, 4)) + y = cholesky(x, dims=["b", "a"]) + assert y.type.dims == ("batch", "b", "a") + assert y.type.shape == (3, 4, 4) + + fn = function([x], y) + rng = np.random.default_rng(25) + x_ = rng.random(size=(4, 3, 3)) + x_ = x_ @ x_.mT + x_test = DataArray(x_.transpose(1, 0, 2), dims=x.type.dims) + np.testing.assert_allclose( + fn(x_test.values), + 
xr_cholesky(x_test, dims=["b", "a"]).values, + ) + + +def test_solve_vector_b(): + a = xtensor("a", dims=("city", "country", "galaxy"), shape=(None, 4, 1)) + b = xtensor("b", dims=("city", "planet"), shape=(None, 2)) + x = solve(a, b, dims=["country", "city"]) + assert x.type.dims == ("galaxy", "planet", "city") + assert x.type.shape == ( + 1, + 2, + None, + ) # Core Solve doesn't make use of the fact A must be square in the static shape + + fn = function([a, b], x) + + rng = np.random.default_rng(25) + a_test = DataArray(rng.random(size=(4, 4, 1)), dims=a.type.dims) + b_test = DataArray(rng.random(size=(4, 2)), dims=b.type.dims) + + np.testing.assert_allclose( + fn(a_test.values, b_test.values), + xr_solve(a_test, b_test, dims=["country", "city"]).values, + ) + + +def test_solve_matrix_b(): + a = xtensor("a", dims=("city", "country", "galaxy"), shape=(None, 4, 1)) + b = xtensor("b", dims=("district", "city", "planet"), shape=(5, None, 2)) + x = solve(a, b, dims=["country", "city", "district"]) + assert x.type.dims == ("galaxy", "planet", "city", "district") + assert x.type.shape == ( + 1, + 2, + None, + 5, + ) # Core Solve doesn't make use of the fact A must be square in the static shape + + fn = function([a, b], x) + + rng = np.random.default_rng(25) + a_test = DataArray(rng.random(size=(4, 4, 1)), dims=a.type.dims) + b_test = DataArray(rng.random(size=(5, 4, 2)), dims=b.type.dims) + + np.testing.assert_allclose( + fn(a_test.values, b_test.values), + xr_solve(a_test, b_test, dims=["country", "city", "district"]).values, + ) diff --git a/tests/xtensor/test_math.py b/tests/xtensor/test_math.py new file mode 100644 index 0000000000..4d463cfabb --- /dev/null +++ b/tests/xtensor/test_math.py @@ -0,0 +1,85 @@ +import pytest + + +# ruff: noqa: E402 +pytest.importorskip("xarray") # + +import numpy as np +from xarray import DataArray + +from pytensor import function +from pytensor.xtensor.basic import rename +from pytensor.xtensor.math import add, exp +from pytensor.xtensor.type import xtensor +from tests.xtensor.util import xr_assert_allclose, xr_function + + +def test_dimension_alignment(): + x = xtensor("x", dims=("city", "country", "planet"), shape=(2, 3, 4)) + y = xtensor( + "y", + dims=("galaxy", "country", "city"), + shape=(5, 3, 2), + ) + z = xtensor("z", dims=("universe",), shape=(1,)) + out = add(x, y, z) + assert out.type.dims == ("city", "country", "planet", "galaxy", "universe") + + fn = function([x, y, z], out) + + rng = np.random.default_rng(41) + test_x, test_y, test_z = ( + DataArray(rng.normal(size=inp.type.shape), dims=inp.type.dims) + for inp in [x, y, z] + ) + np.testing.assert_allclose( + fn(test_x.values, test_y.values, test_z.values), + (test_x + test_y + test_z).values, + ) + + +def test_renamed_dimension_alignment(): + x = xtensor("x", dims=("a", "b1", "b2"), shape=(2, 3, 3)) + y = rename(x, b1="b2", b2="b1") + z = rename(x, b2="b3") + assert y.type.dims == ("a", "b2", "b1") + assert z.type.dims == ("a", "b1", "b3") + + out1 = add(x, x) # self addition + assert out1.type.dims == ("a", "b1", "b2") + out2 = add(x, y) # transposed addition + assert out2.type.dims == ("a", "b1", "b2") + out3 = add(x, z) # outer addition + assert out3.type.dims == ("a", "b1", "b2", "b3") + + fn = xr_function([x], [out1, out2, out3]) + x_test = DataArray( + np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), + dims=x.type.dims, + ) + results = fn(x_test) + expected_results = [ + x_test + x_test, + x_test + x_test.rename(b1="b2", b2="b1"), + x_test + 
x_test.rename(b2="b3"), + ] + for result, expected_result in zip(results, expected_results): + xr_assert_allclose(result, expected_result) + + +def test_chained_operations(): + x = xtensor("x", dims=("city",), shape=(None,)) + y = xtensor("y", dims=("country",), shape=(4,)) + z = add(exp(x), exp(y)) + assert z.type.dims == ("city", "country") + assert z.type.shape == (None, 4) + + fn = function([x, y], z) + + x_test = DataArray(np.zeros(3), dims="city") + y_test = DataArray(np.ones(4), dims="country") + + np.testing.assert_allclose( + fn(x_test.values, y_test.values), + (np.exp(x_test) + np.exp(y_test)).values, + ) diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py new file mode 100644 index 0000000000..25bdf68ee6 --- /dev/null +++ b/tests/xtensor/test_shape.py @@ -0,0 +1,104 @@ +# ruff: noqa: E402 +import pytest + + +pytest.importorskip("xarray") + +from itertools import chain, combinations + +import numpy as np +from xarray import DataArray + +from pytensor.xtensor.shape import stack +from pytensor.xtensor.type import xtensor +from tests.xtensor.util import xr_assert_allclose, xr_function + + +def powerset(iterable, min_group_size=0): + "Subsequences of the iterable from shortest to longest." + # powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3) + s = list(iterable) + return chain.from_iterable( + combinations(s, r) for r in range(min_group_size, len(s) + 1) + ) + + +@pytest.mark.xfail(reason="Not yet implemented") +def test_transpose(): + transpose = None + a, b, c, d, e = "abcde" + + x = xtensor("x", dims=(a, b, c, d, e), shape=(2, 3, 5, 7, 11)) + permutations = [ + (a, b, c, d, e), # identity + (e, d, c, b, a), # full tranpose + (), # eqivalent to full transpose + (a, b, c, e, d), # swap last two dims + (..., d, c), # equivalent to (a, b, e, d, c) + (b, a, ..., e, d), # equivalent to (b, a, c, d, e) + (c, a, ...), # equivalent to (c, a, b, d, e) + ] + outs = [transpose(x, *perm) for perm in permutations] + + fn = xr_function([x], outs) + x_test = DataArray( + np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), + dims=x.type.dims, + ) + res = fn(x_test) + expected_res = [x_test.transpose(*perm) for perm in permutations] + for outs_i, res_i, expected_res_i in zip(outs, res, expected_res): + xr_assert_allclose(res_i, expected_res_i) + + +def test_stack(): + dims = ("a", "b", "c", "d") + x = xtensor("x", dims=dims, shape=(2, 3, 5, 7)) + outs = [ + stack(x, new_dim=dims_to_stack) + for dims_to_stack in powerset(dims, min_group_size=2) + ] + + fn = xr_function([x], outs) + x_test = DataArray( + np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), + dims=x.type.dims, + ) + res = fn(x_test) + + expected_res = [ + x_test.stack(new_dim=dims_to_stack) + for dims_to_stack in powerset(dims, min_group_size=2) + ] + for outs_i, res_i, expected_res_i in zip(outs, res, expected_res): + xr_assert_allclose(res_i, expected_res_i) + + +def test_stack_single_dim(): + x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 5)) + out = stack(x, {"d": ["a"]}) + assert out.type.dims == ("b", "c", "d") + + fn = xr_function([x], out) + x_test = DataArray( + np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), + dims=x.type.dims, + ) + fn.fn.dprint(print_type=True) + res = fn(x_test) + expected_res = x_test.stack(d=["a"]) + xr_assert_allclose(res, expected_res) + + +def test_multiple_stacks(): + x = xtensor("x", dims=("a", "b", "c", "d"), shape=(2, 3, 5, 7)) + out = stack(x, new_dim1=("a", "b"), new_dim2=("c", 
"d")) + + fn = xr_function([x], [out]) + x_test = DataArray( + np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), + dims=x.type.dims, + ) + res = fn(x_test) + expected_res = x_test.stack(new_dim1=("a", "b"), new_dim2=("c", "d")) + xr_assert_allclose(res[0], expected_res) diff --git a/tests/xtensor/util.py b/tests/xtensor/util.py new file mode 100644 index 0000000000..b429adb794 --- /dev/null +++ b/tests/xtensor/util.py @@ -0,0 +1,37 @@ +from xarray import DataArray +from xarray.testing import assert_allclose + +from pytensor import function +from pytensor.xtensor.type import XTensorType + + +def xr_function(*args, **kwargs): + """Compile and wrap a PyTensor function to return xarray DataArrays.""" + fn = function(*args, **kwargs) + symbolic_outputs = fn.maker.fgraph.outputs + assert all( + isinstance(out.type, XTensorType) for out in symbolic_outputs + ), "All outputs must be xtensor" + + def xfn(*xr_inputs): + np_inputs = [ + inp.values if isinstance(inp, DataArray) else inp for inp in xr_inputs + ] + np_outputs = fn(*np_inputs) + if not isinstance(np_outputs, tuple | list): + return DataArray(np_outputs, dims=symbolic_outputs[0].type.dims) + else: + return tuple( + DataArray(res, dims=out.type.dims) + for res, out in zip(np_outputs, symbolic_outputs) + ) + + xfn.fn = fn + return xfn + + +def xr_assert_allclose(x, y, *args, **kwargs): + # Assert that two xarray DataArrays are close, ignoring coordinates + x = x.drop_vars(x.coords) + y = y.drop_vars(y.coords) + assert_allclose(x, y, *args, **kwargs) From a6d59fe98e5fddde6891c22bc83cd6ec4e9706a1 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 22 May 2025 15:51:43 +0200 Subject: [PATCH 2/4] Use DimShuffle instead of Reshape in `ix_` --- pytensor/tensor/basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index d1bc65172c..30fd85e9c1 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -4539,7 +4539,7 @@ def ix_(*args): new = as_tensor(new) if new.ndim != 1: raise ValueError("Cross index must be 1 dimensional") - new = new.reshape((1,) * k + (new.size,) + (1,) * (nd - k - 1)) + new = new.dimshuffle(*(("x",) * k), 0, *(("x",) * (nd - k - 1))) out.append(new) return tuple(out) From 84d30103767eb14515ad1a16269649403068b9b5 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Wed, 21 May 2025 19:11:02 +0200 Subject: [PATCH 3/4] WIP Implement index operations on XTensorTypes --- pytensor/xtensor/__init__.py | 1 - pytensor/xtensor/basic.py | 2 +- pytensor/xtensor/indexing.py | 142 +++++++++++++++++++++++++ pytensor/xtensor/rewriting/__init__.py | 1 + pytensor/xtensor/rewriting/indexing.py | 27 +++++ pytensor/xtensor/type.py | 71 +++++++++++-- tests/xtensor/test_indexing.py | 42 ++++++++ 7 files changed, 276 insertions(+), 10 deletions(-) create mode 100644 pytensor/xtensor/indexing.py create mode 100644 pytensor/xtensor/rewriting/indexing.py create mode 100644 tests/xtensor/test_indexing.py diff --git a/pytensor/xtensor/__init__.py b/pytensor/xtensor/__init__.py index add3e7599b..4a7b839aaf 100644 --- a/pytensor/xtensor/__init__.py +++ b/pytensor/xtensor/__init__.py @@ -2,7 +2,6 @@ import pytensor.xtensor.rewriting from pytensor.xtensor.type import ( - XTensorType, as_xtensor, as_xtensor_variable, xtensor, diff --git a/pytensor/xtensor/basic.py b/pytensor/xtensor/basic.py index f3b0c8cf11..ff4567a59c 100644 --- a/pytensor/xtensor/basic.py +++ b/pytensor/xtensor/basic.py @@ -12,7 +12,7 @@ class XOp(Op): def perform(self, node, 
inputs, outputs): raise NotImplementedError( - "xtensor operations must be rewritten as tensor operations" + f"xtensor operation {self} must be lowered to equivalent tensor operations" ) diff --git a/pytensor/xtensor/indexing.py b/pytensor/xtensor/indexing.py new file mode 100644 index 0000000000..c6b45e0817 --- /dev/null +++ b/pytensor/xtensor/indexing.py @@ -0,0 +1,142 @@ +# HERE LIE DRAGONS +# Uselful links to make sense of all the numpy/xarray complexity +# https://numpy.org/devdocs//user/basics.indexing.html +# https://numpy.org/neps/nep-0021-advanced-indexing.html +# https://docs.xarray.dev/en/latest/user-guide/indexing.html +# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html + +from pytensor.graph.basic import Apply, Constant, Variable +from pytensor.scalar.basic import discrete_dtypes +from pytensor.tensor.basic import as_tensor +from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice +from pytensor.xtensor.basic import XOp +from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor + + +def as_idx_variable(idx): + if idx is None or (isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT)): + raise TypeError( + "XTensors do not support indexing with None (np.newaxis), use expand_dims instead" + ) + if isinstance(idx, slice): + idx = make_slice(idx) + elif isinstance(idx, Variable) and isinstance(idx.type, SliceType): + pass + else: + # Must be integer indices, we already counted for None and slices + try: + idx = as_xtensor(idx) + except TypeError: + idx = as_tensor(idx) + if idx.type.dtype == "bool": + raise NotImplementedError("Boolean indexing not yet supported") + if idx.type.dtype not in discrete_dtypes: + raise TypeError("Numerical indices must be integers or boolean") + if idx.type.dtype == "bool" and idx.type.ndim == 0: + # This can't be triggered right now, but will once we lift the boolean restriction + raise NotImplementedError("Scalar boolean indices not supported") + return idx + + +def get_static_slice_length(slc: Variable, dim_length: None | int) -> int | None: + if dim_length is None: + return None + if isinstance(slc, Constant): + d = slc.data + start, stop, step = d.start, d.stop, d.step + elif slc.owner is None: + # It's a root variable no way of knowing what we're getting + return None + else: + # It's a MakeSliceOp + start, stop, step = slc.owner.inputs + if isinstance(start, Constant): + start = start.data + else: + return None + if isinstance(stop, Constant): + stop = stop.data + else: + return None + if isinstance(step, Constant): + step = step.data + else: + return None + return len(range(*slice(start, stop, step).indices(dim_length))) + + +class Index(XOp): + __props__ = () + + def make_node(self, x, *idxs): + x = as_xtensor(x) + idxs = [as_idx_variable(idx) for idx in idxs] + + x_ndim = x.type.ndim + x_dims = x.type.dims + x_shape = x.type.shape + out_dims = [] + out_shape = [] + has_unlabeled_vector_idx = False + has_labeled_vector_idx = False + for i, idx in enumerate(idxs): + if i == x_ndim: + raise IndexError("Too many indices") + if isinstance(idx.type, SliceType): + out_dims.append(x_dims[i]) + out_shape.append(get_static_slice_length(idx, x_shape[i])) + elif isinstance(idx.type, XTensorType): + if has_unlabeled_vector_idx: + raise NotImplementedError( + "Mixing of labeled and unlabeled vector indexing not implemented" + ) + has_labeled_vector_idx = True + idx_dims = idx.type.dims + for dim in idx_dims: + idx_dim_shape = idx.type.shape[idx_dims.index(dim)] + if dim in out_dims: + # Dim already 
introduced in output by a previous index + # Update static shape or raise if incompatible + out_dim_pos = out_dims.index(dim) + out_dim_shape = out_shape[out_dim_pos] + if out_dim_shape is None: + # We don't know the size of the dimension yet + out_shape[out_dim_pos] = idx_dim_shape + elif ( + idx_dim_shape is not None and idx_dim_shape != out_dim_shape + ): + raise IndexError( + f"Dimension of indexers mismatch for dim {dim}" + ) + else: + # New dimension + out_dims.append(dim) + out_shape.append(idx_dim_shape) + + else: # TensorType + if idx.type.ndim == 0: + # Scalar, dimension is dropped + pass + elif idx.type.ndim == 1: + if has_labeled_vector_idx: + raise NotImplementedError( + "Mixing of labeled and unlabeled vector indexing not implemented" + ) + has_unlabeled_vector_idx = True + out_dims.append(x_dims[i]) + out_shape.append(idx.type.shape[0]) + else: + # Same error that xarray raises + raise IndexError( + "Unlabeled multi-dimensional array cannot be used for indexing" + ) + for j in range(i + 1, x_ndim): + # Add any unindexed dimensions + out_dims.append(x_dims[j]) + out_shape.append(x_shape[j]) + + output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims) + return Apply(self, [x, *idxs], [output]) + + +index = Index() diff --git a/pytensor/xtensor/rewriting/__init__.py b/pytensor/xtensor/rewriting/__init__.py index d4bb32ad66..5c56f7dfd2 100644 --- a/pytensor/xtensor/rewriting/__init__.py +++ b/pytensor/xtensor/rewriting/__init__.py @@ -1,2 +1,3 @@ import pytensor.xtensor.rewriting.basic +import pytensor.xtensor.rewriting.indexing import pytensor.xtensor.rewriting.shape diff --git a/pytensor/xtensor/rewriting/indexing.py b/pytensor/xtensor/rewriting/indexing.py new file mode 100644 index 0000000000..556a6fd96f --- /dev/null +++ b/pytensor/xtensor/rewriting/indexing.py @@ -0,0 +1,27 @@ +from pytensor.graph import node_rewriter +from pytensor.tensor import TensorType +from pytensor.tensor.type_other import SliceType +from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor +from pytensor.xtensor.indexing import Index +from pytensor.xtensor.rewriting.utils import register_xcanonicalize + + +def is_basic_idx(idx): + return ( + isinstance(idx.type, SliceType) + or isinstance(idx.type, TensorType) + and idx.type.ndim == 0 + and idx.type.dtype != bool + ) + + +@register_xcanonicalize +@node_rewriter(tracks=[Index]) +def lower_index(fgraph, node): + x, *idxs = node.inputs + x_tensor = tensor_from_xtensor(x) + if all(is_basic_idx(idx) for idx in idxs): + # Simple case + x_tensor_indexed = x_tensor[tuple(idxs)] + new_out = xtensor_from_tensor(x_tensor_indexed, dims=node.outputs[0].type.dims) + return [new_out] diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index c206c911ec..7da673b906 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -1,3 +1,6 @@ +import warnings + + try: import xarray as xr @@ -6,13 +9,13 @@ XARRAY_AVAILABLE = False from collections.abc import Sequence -from typing import TypeVar +from typing import Any, Literal, TypeVar import numpy as np from pytensor import _as_symbolic, config from pytensor.graph import Apply, Constant -from pytensor.graph.basic import Variable, OptionalApplyType +from pytensor.graph.basic import OptionalApplyType, Variable from pytensor.graph.type import HasDataType, HasShape, Type from pytensor.tensor.utils import hash_from_ndarray from pytensor.utils import hash_from_code @@ -141,17 +144,69 @@ def __getitem__(self, idx): if isinstance(idx, dict): return self.isel(idx) + # Check 
for ellipsis not in the last position (last one is useless anyway) + if any(idx_item is Ellipsis for idx_item in idx): + if idx.count(Ellipsis) > 1: + raise IndexError("an index can only have a single ellipsis ('...')") + # Convert intermediate Ellipsis to slice(None) + ellipsis_loc = idx.index(Ellipsis) + n_implied_none_slices = self.type.ndim - (len(idx) - 1) + idx = ( + *idx[:ellipsis_loc], + *((slice(None),) * n_implied_none_slices), + *idx[ellipsis_loc + 1 :], + ) + return index(self, *idx) + def sel(self, *args, **kwargs): + raise NotImplementedError( + "sel not implemented for XTensorVariable, use isel instead" + ) -class XTensorVariable(Variable): - pass + def isel( + self, + indexers: dict[str, Any] | None = None, + drop: bool = False, # Unused by PyTensor + missing_dims: Literal["raise", "warn", "ignore"] = "raise", + **indexers_kwargs, + ): + from pytensor.xtensor.indexing import index + + if indexers_kwargs: + if indexers is not None: + raise ValueError( + "Cannot pass both indexers and indexers_kwargs to isel" + ) + indexers = indexers_kwargs - # def __str__(self): - # return f"{self.__class__.__name__}{{{self.format},{self.dtype}}}" + if missing_dims not in {"raise", "warn", "ignore"}: + raise ValueError( + f"Unrecognized options {missing_dims} for missing_dims argument" + ) - # def __repr__(self): - # return str(self) + # Sort indices and pass them to index + dims = self.type.dims + indices = [slice(None)] * self.type.ndim + for key, idx in indexers.items(): + if idx is Ellipsis: + # Xarray raises a less informative error, suggesting indices must be integer + # But slices are also fine + raise TypeError("Ellipsis (...) is an invalid labeled index") + try: + indices[dims.index(key)] = idx + except IndexError: + if missing_dims == "raise": + raise ValueError( + f"Dimension {key} does not exist. Expected one of {dims}" + ) + elif missing_dims == "warn": + warnings.warn( + UserWarning, + f"Dimension {key} does not exist. Expected one of {dims}", + ) + + return index(self, *indices) class XTensorConstantSignature(tuple): diff --git a/tests/xtensor/test_indexing.py b/tests/xtensor/test_indexing.py new file mode 100644 index 0000000000..5127e7601d --- /dev/null +++ b/tests/xtensor/test_indexing.py @@ -0,0 +1,42 @@ +import numpy as np +import pytest +from xarray import DataArray +from xtensor.util import xr_assert_allclose, xr_function + +from pytensor.xtensor import xtensor + + +@pytest.mark.parametrize( + "indices", + [ + (0,), + (slice(1, None),), + (slice(None, -1),), + (slice(None, None, -1),), + (0, slice(None), -1, slice(1, None)), + (..., 0, -1), + (0, ..., -1), + (0, -1, ...), + ], +) +@pytest.mark.parametrize("labeled", (False, True), ids=["unlabeled", "labeled"]) +def test_basic_indexing(labeled, indices): + if ... 
in indices and labeled: + pytest.skip("Ellipsis not supported with labeled indexing") + + dims = ("a", "b", "c", "d") + x = xtensor(dims=dims, shape=(2, 3, 5, 7)) + + if labeled: + shufled_dims = tuple(np.random.permutation(dims)) + indices = dict(zip(shufled_dims, indices, strict=False)) + out = x[indices] + + fn = xr_function([x], out) + x_test_values = np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape( + x.type.shape + ) + x_test = DataArray(x_test_values, dims=x.type.dims) + res = fn(x_test) + expected_res = x_test[indices] + xr_assert_allclose(res, expected_res) From f5d426ff8e397cb9c5ef744f737eb617138c3d70 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 22 May 2025 15:45:22 +0200 Subject: [PATCH 4/4] WIP add XTensorVariable properties and methods --- pytensor/xtensor/type.py | 177 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 169 insertions(+), 8 deletions(-) diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index 7da673b906..f206a89010 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -1,5 +1,7 @@ import warnings +from pytensor.tensor import TensorVariable, mul + try: import xarray as xr @@ -133,10 +135,99 @@ def __complex__(self): "Call `.astype(complex)` for the symbolic equivalent." ) + # DataArray-like attributes + # https://docs.xarray.dev/en/latest/api.html#id1 + @property + def values(self) -> TensorVariable: + from pytensor.xtensor.basic import tensor_from_xtensor + + return tensor_from_xtensor(self) + + data = values + + @property + def coords(self): + raise NotImplementedError("coords not implemented for XTensorVariable") + + @property + def dims(self) -> tuple[str]: + return self.type.dims + + @property + def sizes(self) -> dict[str, TensorVariable]: + return dict(zip(self.dims, self.shape)) + + @property + def as_numpy(self): + # No-op, since the underlying data is always a numpy array + return self + + # ndarray attributes + # https://docs.xarray.dev/en/latest/api.html#ndarray-attributes + @property + def ndim(self) -> int: + return self.type.ndim + + @property + def shape(self) -> tuple[TensorVariable]: + from pytensor.xtensor.basic import tensor_from_xtensor + + return tuple(tensor_from_xtensor(self).shape) + + @property + def size(self): + return mul(*self.shape) + + @property + def dtype(self): + return self.type.dtype + + # DataArray contents + # https://docs.xarray.dev/en/latest/api.html#dataarray-contents + def rename(self, new_name_or_name_dict, **names): + from pytensor.xtensor.basic import rename + + if isinstance(new_name_or_name_dict, str): + # TODO: Should we make a symbolic copy? + self.name = new_name_or_name_dict + name_dict = None + else: + name_dict = new_name_or_name_dict + return rename(name_dict, **names) + + # def swap_dims(self, *args, **kwargs): + # ... + # + # def expand_dims(self, *args, **kwargs): + # ... + # + # def squeeze(self): + # ... + + def copy(self): + from pytensor.xtensor.math import identity + + return identity(self) + + def astype(self, dtype): + from pytensor.xtensor.math import cast + + return cast(self, dtype) + + def item(self): + raise NotImplementedError("item not implemented for XTensorVariable") + + # Indexing + # https://docs.xarray.dev/en/latest/api.html#id2 def __setitem__(self, key, value): - raise TypeError( - "XTensorVariable does not support item assignment. Use the output of `x[idx].set` or `x[idx].inc` instead." 
- ) + raise TypeError("XTensorVariable does not support item assignment.") + + @property + def loc(self): + raise NotImplementedError("loc not implemented for XTensorVariable") + + def sel(self, *args, **kwargs): + raise NotImplementedError("sel not implemented for XTensorVariable") def __getitem__(self, idx): from pytensor.xtensor.indexing import index @@ -159,11 +250,6 @@ def __getitem__(self, idx): return index(self, *idx) - def sel(self, *args, **kwargs): - raise NotImplementedError( - "sel not implemented for XTensorVariable, use isel instead" - ) - def isel( self, indexers: dict[str, Any] | None = None, @@ -208,6 +294,81 @@ def isel( return index(self, *indices) + def _head_tail_or_thin( + self, + indexers: dict[str, Any] | int | None, + indexers_kwargs: dict[str, Any], + *, + kind: Literal["head", "tail", "thin"], + ): + if indexers_kwargs: + if indexers is not None: + raise ValueError( + "Cannot pass both indexers and indexers_kwargs to head" + ) + indexers = indexers_kwargs + + if indexers is None: + if kind == "thin": + raise TypeError( + "thin() indexers must be either dict-like or a single integer" + ) + else: + # Default to 5 for head and tail + indexers = {dim: 5 for dim in self.type.dims} + + elif not isinstance(indexers, dict): + indexers = {dim: indexers for dim in self.type.dims} + + if kind == "head": + indices = {dim: slice(None, value) for dim, value in indexers.items()} + elif kind == "tail": + sizes = self.sizes + # Can't use slice(-value, None), in case value is zero + indices = { + dim: slice(sizes[dim] - value, None) for dim, value in indexers.items() + } + elif kind == "thin": + indices = {dim: slice(None, None, value) for dim, value in indexers.items()} + return self.isel(indices) + + def head(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs): + return self._head_tail_or_thin(indexers, indexers_kwargs, kind="head") + + def tail(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs): + return self._head_tail_or_thin(indexers, indexers_kwargs, kind="tail") + + def thin(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs): + return self._head_tail_or_thin(indexers, indexers_kwargs, kind="thin") + + # ndarray methods + # https://docs.xarray.dev/en/latest/api.html#id7 + def clip(self, min, max): + from pytensor.xtensor.math import clip + + return clip(self, min, max) + + def conj(self): + from pytensor.xtensor.math import conj + + return conj(self) + + @property + def imag(self): + from pytensor.xtensor.math import imag + + return imag(self) + + @property + def real(self): + from pytensor.xtensor.math import real + + return real(self) + + # @property + # def T(self): + # ... + class XTensorConstantSignature(tuple): def __eq__(self, other):