Skip to content

Commit a3db755

Browse files
committed
POC named tensors
1 parent 981688c commit a3db755

File tree

8 files changed

+457
-0
lines changed

8 files changed

+457
-0
lines changed

pytensor/xtensor/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# PyTensor xtensor: experimental named-tensor (xarray-like) support.

import warnings

# Imported for its side effect: registering the xtensor rewrites with optdb.
import pytensor.xtensor.rewriting
from pytensor.xtensor.type import (
    XTensorType,
    as_xtensor,
    as_xtensor_variable,
    xtensor,
    xtensor_constant,
)


# Declare the re-exported public API explicitly so linters don't flag the
# names above as unused imports.
__all__ = [
    "XTensorType",
    "as_xtensor",
    "as_xtensor_variable",
    "xtensor",
    "xtensor_constant",
]


warnings.warn("xtensor module is experimental and full of bugs")

pytensor/xtensor/basic.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import pytensor.scalar as ps
2+
import pytensor.xtensor as px
3+
from pytensor.graph import Apply, Op
4+
from pytensor.tensor import TensorType
5+
6+
7+
class TensorFromXTensor(Op):
    """Convert an XTensorVariable into an equivalent plain TensorVariable.

    The output keeps the input's dtype and static shape; the named dims are
    dropped (axes keep the positional order of ``x.type.dims``).
    """

    # TODO: May need mapping of named dims to positional dims?

    def make_node(self, x) -> Apply:
        """Validate that ``x`` has an XTensorType and build the Apply node."""
        if not isinstance(x.type, px.XTensorType):
            # Fixed grammar of the original message ("must be have").
            raise TypeError(f"x must have an XTensorType, got {type(x.type)}")
        output = TensorType(x.type.dtype, shape=x.type.shape)()
        return Apply(self, [x], [output])

    def perform(self, node, inputs, output_storage) -> None:
        # The underlying data is unchanged; copy so the output owns its buffer.
        [x] = inputs
        output_storage[0][0] = x.copy()


# Singleton instance: the op carries no parameters, so one instance suffices.
tensor_from_xtensor = TensorFromXTensor()
22+
23+
24+
class XTensorFromTensor(Op):
    """Attach the dimension names ``dims`` to a plain TensorVariable."""

    __props__ = ("dims",)

    def __init__(self, dims):
        super().__init__()
        # Sequence of dimension names, one per axis of the input tensor.
        self.dims = dims

    def make_node(self, x) -> Apply:
        """Validate that ``x`` has a TensorType and build the Apply node."""
        if not isinstance(x.type, TensorType):
            # Fixed grammar of the original message ("an TensorType type").
            raise TypeError(f"x must be a TensorType, got {type(x.type)}")
        output = px.XTensorType(x.type.dtype, dims=self.dims, shape=x.type.shape)()
        return Apply(self, [x], [output])

    def perform(self, node, inputs, output_storage) -> None:
        # The underlying data is unchanged; copy so the output owns its buffer.
        [x] = inputs
        output_storage[0][0] = x.copy()
40+
41+
42+
def xtensor_from_tensor(x, dims):
    """Return ``x`` wrapped as an XTensorVariable with the given dim names."""
    op = XTensorFromTensor(dims=dims)
    return op(x)
44+
45+
46+
class XElemwise(Op):
    """Elementwise scalar operation on XTensorVariables.

    Inputs are aligned by dimension *name* (an input broadcasts over any
    output dim it lacks), unlike ``Elemwise``, which aligns positionally.
    This op is symbolic-only: rewrites must lower it to ``Elemwise``.
    """

    __props__ = ("scalar_op",)

    def __init__(self, scalar_op):
        super().__init__()
        self.scalar_op = scalar_op

    def make_node(self, *inputs):
        inputs = [px.as_xtensor(inp) for inp in inputs]

        # Merge the dims of all inputs, keeping the most specific known
        # static length and rejecting contradictory ones.
        # TODO: This ordering is different than what xarray does
        unique_dims: dict[str, int | None] = {}
        for inp in inputs:
            for dim, dim_length in zip(inp.type.dims, inp.type.shape):
                known_length = unique_dims.get(dim)
                if dim_length is None:
                    # Record the dim, but keep any length already known.
                    unique_dims.setdefault(dim, None)
                elif known_length is None:
                    unique_dims[dim] = dim_length
                elif known_length != dim_length:
                    raise ValueError(f"Dimension {dim} has conflicting shapes")

        dims, shape = zip(*sorted(unique_dims.items()))

        # Derive the output dtypes from the scalar op itself (as Elemwise
        # does), instead of hard-coding float64 for every output.
        dummy_scalars = [ps.get_scalar_type(inp.type.dtype)() for inp in inputs]
        scalar_node = self.scalar_op.make_node(*dummy_scalars)
        outputs = [
            px.XTensorType(out.type.dtype, dims=dims, shape=shape)()
            for out in scalar_node.outputs
        ]
        return Apply(self, inputs, outputs)

    def perform(self, *args, **kwargs) -> None:
        # Intentionally unimplemented: a canonicalization rewrite replaces
        # XElemwise with a plain Elemwise before execution.
        raise NotImplementedError(
            "xtensor operations must be rewritten as tensor operations"
        )


add = XElemwise(ps.add)
exp = XElemwise(ps.exp)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
import pytensor.xtensor.rewriting.basic

pytensor/xtensor/rewriting/basic.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from pytensor.graph import node_rewriter
2+
from pytensor.tensor import expand_dims
3+
from pytensor.tensor.elemwise import Elemwise
4+
from pytensor.xtensor.basic import (
5+
TensorFromXTensor,
6+
XElemwise,
7+
XTensorFromTensor,
8+
tensor_from_xtensor,
9+
xtensor_from_tensor,
10+
)
11+
from pytensor.xtensor.rewriting.utils import register_xcanonicalize
12+
13+
14+
@register_xcanonicalize
@node_rewriter(tracks=[TensorFromXTensor])
def useless_tensor_from_xtensor(fgraph, node):
    """TensorFromXTensor(XTensorFromTensor(x)) -> x"""
    [xtensor_inp] = node.inputs
    producer = xtensor_inp.owner
    if producer is None:
        return None
    if not isinstance(producer.op, XTensorFromTensor):
        return None
    # The round-trip is the identity on the underlying tensor.
    return [producer.inputs[0]]
21+
22+
23+
@register_xcanonicalize
@node_rewriter(tracks=[XTensorFromTensor])
def useless_xtensor_from_tensor(fgraph, node):
    """XTensorFromTensor(TensorFromXTensor(x)) -> x"""
    [tensor_inp] = node.inputs
    producer = tensor_inp.owner
    if producer is None:
        return None
    if not isinstance(producer.op, TensorFromXTensor):
        return None
    # The round-trip is the identity on the original xtensor.
    return [producer.inputs[0]]
30+
31+
32+
@register_xcanonicalize
@node_rewriter(tracks=[XElemwise])
def xelemwise_to_elemwise(fgraph, node):
    """Lower an XElemwise node to a plain Elemwise on the underlying tensors."""
    output_dims = node.outputs[0].type.dims

    # Strip the named dims from each input and insert length-1 axes at the
    # positions of every output dim the input lacks, so Elemwise broadcasts.
    # NOTE(review): this assumes each input's dims occur in the same relative
    # order as in output_dims — confirm; otherwise a transpose would be needed.
    tensor_inputs = []
    for inp in node.inputs:
        missing_axes = [
            axis
            for axis, dim in enumerate(output_dims)
            if dim not in inp.type.dims
        ]
        tensor_inputs.append(expand_dims(tensor_from_xtensor(inp), missing_axes))

    tensor_outs = Elemwise(scalar_op=node.op.scalar_op)(
        *tensor_inputs, return_list=True
    )

    # TODO: copy_stack_trace
    # Re-attach the output dim names to each lowered tensor output.
    return [
        xtensor_from_tensor(tensor_out, dims=output_dims)
        for tensor_out in tensor_outs
    ]

pytensor/xtensor/rewriting/utils.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from pytensor.compile import optdb
2+
from pytensor.graph.rewriting.basic import NodeRewriter
3+
from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase
4+
5+
6+
# Create the "xcanonicalize" pass: an equilibrium database run at the very
# start of compilation (position=0), tagged so it participates in the
# standard fast_run / fast_compile modes and can be targeted via "xtensor".
xtensor_rewrites_db = EquilibriumDB(ignore_newtrees=False)
optdb.register(
    "xcanonicalize",
    xtensor_rewrites_db,
    "fast_run",
    "fast_compile",
    "xtensor",
    position=0,
)
14+
15+
16+
def register_xcanonicalize(
    node_rewriter: RewriteDatabase | NodeRewriter | str, *tags: str, **kwargs
):
    """Register a rewrite in the "xcanonicalize" database.

    Usable directly (``register_xcanonicalize(rewrite)``) or as a decorator
    factory carrying extra tags (``@register_xcanonicalize("some_tag")``).
    Returns the rewriter (or the decorator) so it composes with other
    decorators.
    """
    if isinstance(node_rewriter, str):
        # Called with a tag string: return a decorator that registers the
        # wrapped rewriter with that tag prepended.
        def register(inner_rewriter: RewriteDatabase | NodeRewriter):
            return register_xcanonicalize(
                inner_rewriter, node_rewriter, *tags, **kwargs
            )

        return register

    else:
        name = kwargs.pop("name", None) or node_rewriter.__name__
        # Look the database up under the name it was registered with
        # ("xcanonicalize") rather than via the "xtensor" tag: name lookup
        # stays unambiguous even if other databases adopt the same tag.
        optdb["xcanonicalize"].register(
            name, node_rewriter, "fast_run", "fast_compile", *tags, **kwargs
        )
        return node_rewriter

0 commit comments

Comments
 (0)