Skip to content

Commit fd1ae2b

Browse files
committed
POC named tensors
1 parent 981688c commit fd1ae2b

File tree

8 files changed

+409
-0
lines changed

8 files changed

+409
-0
lines changed

pytensor/xtensor/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import warnings
2+
3+
import pytensor.xtensor.rewriting
4+
from pytensor.xtensor.type import (
5+
XTensorType,
6+
as_xtensor,
7+
as_xtensor_variable,
8+
xtensor,
9+
xtensor_constant,
10+
)
11+
12+
13+
warnings.warn("xtensor module is experimental and full of bugs")

pytensor/xtensor/basic.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from itertools import chain
2+
3+
import pytensor.scalar as ps
4+
import pytensor.xtensor as px
5+
from pytensor.graph import Apply, Op
6+
from pytensor.tensor import TensorType
7+
8+
9+
class TensorFromXTensor(Op):
10+
def make_node(self, x) -> Apply:
11+
if not isinstance(x.type, px.XTensorType):
12+
raise TypeError(f"x must be have an XTensorType, got {type(x.type)}")
13+
output = TensorType(x.type.dtype, shape=x.type.shape)()
14+
return Apply(self, [x], [output])
15+
16+
def perform(self, node, inputs, output_storage) -> None:
17+
[x] = inputs
18+
output_storage[0][0] = x.copy()
19+
20+
21+
tensor_from_xtensor = TensorFromXTensor()
22+
23+
24+
class XTensorFromTensor(Op):
25+
__props__ = ("dims",)
26+
27+
def __init__(self, dims):
28+
super().__init__()
29+
self.dims = dims
30+
31+
def make_node(self, x) -> Apply:
32+
if not isinstance(x.type, TensorType):
33+
raise TypeError(f"x must be an TensorType type, got {type(x.type)}")
34+
output = px.XTensorType(x.type.dtype, dims=self.dims, shape=x.type.shape)()
35+
return Apply(self, [x], [output])
36+
37+
def perform(self, node, inputs, output_storage) -> None:
38+
[x] = inputs
39+
output_storage[0][0] = x.copy()
40+
41+
42+
def xtensor_from_tensor(x, dims):
43+
return XTensorFromTensor(dims=dims)(x)
44+
45+
46+
class XElemwise(Op):
47+
__props__ = ("scalar_op",)
48+
49+
def __init__(self, scalar_op):
50+
super().__init__()
51+
self.scalar_op = scalar_op
52+
53+
def make_node(self, *inputs):
54+
# TODO: Check dim lengths match
55+
inputs = [px.as_xtensor(inp) for inp in inputs]
56+
# TODO: This ordering is different than what xarray does
57+
unique_dims = sorted(set(chain.from_iterable(inp.type.dims for inp in inputs)))
58+
# TODO: Fix dtype
59+
output_type = px.XTensorType(
60+
"float64", dims=unique_dims, shape=(None,) * len(unique_dims)
61+
)
62+
outputs = [output_type() for _ in range(self.scalar_op.nout)]
63+
return Apply(self, inputs, outputs)
64+
65+
def perform(self, *args, **kwargs) -> None:
66+
raise NotImplementedError(
67+
"xtensor operations must be rewritten as tensor operations"
68+
)
69+
70+
71+
add = XElemwise(ps.add)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
import pytensor.xtensor.rewriting.basic

pytensor/xtensor/rewriting/basic.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from pytensor.graph import node_rewriter
2+
from pytensor.tensor import expand_dims
3+
from pytensor.tensor.elemwise import Elemwise
4+
from pytensor.xtensor.basic import tensor_from_xtensor, XElemwise, xtensor_from_tensor
5+
from pytensor.xtensor.rewriting.utils import register_xcanonicalize
6+
7+
8+
@register_xcanonicalize
9+
@node_rewriter(tracks=[XElemwise])
10+
def xelemwise_to_elemwise(fgraph, node):
11+
# Convert inputs to TensorVariables and add broadcastable dims
12+
output_dims = node.outputs[0].type.dims
13+
14+
tensor_inputs = []
15+
for inp in node.inputs:
16+
inp_dims = inp.type.dims
17+
axis = [i for i, dim in enumerate(output_dims) if dim not in inp_dims]
18+
tensor_inp = tensor_from_xtensor(inp)
19+
tensor_inp = expand_dims(tensor_inp, axis)
20+
tensor_inputs.append(tensor_inp)
21+
22+
tensor_outs = Elemwise(scalar_op=node.op.scalar_op)(*tensor_inputs, return_list=True)
23+
24+
# TODO: copy_stack_trace
25+
new_outs = [xtensor_from_tensor(tensor_out, dims=output_dims) for tensor_out in tensor_outs]
26+
return new_outs

pytensor/xtensor/rewriting/utils.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from typing import Union
2+
3+
from pytensor.compile import optdb
4+
from pytensor.graph.rewriting.basic import NodeRewriter
5+
from pytensor.graph.rewriting.db import RewriteDatabase, EquilibriumDB
6+
7+
8+
optdb.register(
9+
"xcanonicalize",
10+
EquilibriumDB(ignore_newtrees=False),
11+
"fast_run",
12+
"fast_compile",
13+
"xtensor",
14+
position=0,
15+
)
16+
17+
18+
def register_xcanonicalize(
19+
node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs
20+
):
21+
if isinstance(node_rewriter, str):
22+
23+
def register(inner_rewriter: Union[RewriteDatabase, NodeRewriter]):
24+
return register_xcanonicalize(inner_rewriter, node_rewriter, *tags, **kwargs)
25+
26+
return register
27+
28+
else:
29+
name = kwargs.pop("name", None) or node_rewriter.__name__
30+
optdb["xtensor"].register(
31+
name, node_rewriter, "fast_run", "fast_compile", *tags, **kwargs
32+
)
33+
return node_rewriter

0 commit comments

Comments
 (0)