Skip to content

Commit 592c699

Browse files
committed
POC named tensors
1 parent 981688c commit 592c699

File tree

8 files changed

+415
-0
lines changed

8 files changed

+415
-0
lines changed

pytensor/xtensor/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import warnings
2+
3+
import pytensor.xtensor.rewriting
4+
from pytensor.xtensor.type import (
5+
XTensorType,
6+
as_xtensor,
7+
as_xtensor_variable,
8+
xtensor,
9+
xtensor_constant,
10+
)
11+
12+
13+
warnings.warn("xtensor module is experimental and full of bugs")

pytensor/xtensor/basic.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
from itertools import chain
2+
3+
import pytensor.scalar as ps
4+
import pytensor.xtensor as px
5+
from pytensor.graph import Apply, Op
6+
from pytensor.tensor import TensorType
7+
8+
9+
class TensorFromXTensor(Op):
10+
# TODO: May need mapping of named dims to positional dims?
11+
12+
def make_node(self, x) -> Apply:
13+
if not isinstance(x.type, px.XTensorType):
14+
raise TypeError(f"x must be have an XTensorType, got {type(x.type)}")
15+
output = TensorType(x.type.dtype, shape=x.type.shape)()
16+
return Apply(self, [x], [output])
17+
18+
def perform(self, node, inputs, output_storage) -> None:
19+
[x] = inputs
20+
output_storage[0][0] = x.copy()
21+
22+
23+
tensor_from_xtensor = TensorFromXTensor()
24+
25+
26+
class XTensorFromTensor(Op):
27+
__props__ = ("dims",)
28+
29+
def __init__(self, dims):
30+
super().__init__()
31+
self.dims = dims
32+
33+
def make_node(self, x) -> Apply:
34+
if not isinstance(x.type, TensorType):
35+
raise TypeError(f"x must be an TensorType type, got {type(x.type)}")
36+
output = px.XTensorType(x.type.dtype, dims=self.dims, shape=x.type.shape)()
37+
return Apply(self, [x], [output])
38+
39+
def perform(self, node, inputs, output_storage) -> None:
40+
[x] = inputs
41+
output_storage[0][0] = x.copy()
42+
43+
44+
def xtensor_from_tensor(x, dims):
45+
return XTensorFromTensor(dims=dims)(x)
46+
47+
48+
class XElemwise(Op):
49+
__props__ = ("scalar_op",)
50+
51+
def __init__(self, scalar_op):
52+
super().__init__()
53+
self.scalar_op = scalar_op
54+
55+
def make_node(self, *inputs):
56+
# TODO: Check dim lengths match
57+
inputs = [px.as_xtensor(inp) for inp in inputs]
58+
# TODO: This ordering is different than what xarray does
59+
unique_dims = sorted(set(chain.from_iterable(inp.type.dims for inp in inputs)))
60+
# TODO: Fix dtype
61+
output_type = px.XTensorType(
62+
"float64", dims=unique_dims, shape=(None,) * len(unique_dims)
63+
)
64+
outputs = [output_type() for _ in range(self.scalar_op.nout)]
65+
return Apply(self, inputs, outputs)
66+
67+
def perform(self, *args, **kwargs) -> None:
68+
raise NotImplementedError(
69+
"xtensor operations must be rewritten as tensor operations"
70+
)
71+
72+
73+
add = XElemwise(ps.add)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
import pytensor.xtensor.rewriting.basic

pytensor/xtensor/rewriting/basic.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from pytensor.graph import node_rewriter
2+
from pytensor.tensor import expand_dims
3+
from pytensor.tensor.elemwise import Elemwise
4+
from pytensor.xtensor.basic import XElemwise, tensor_from_xtensor, xtensor_from_tensor
5+
from pytensor.xtensor.rewriting.utils import register_xcanonicalize
6+
7+
8+
@register_xcanonicalize
9+
@node_rewriter(tracks=[XElemwise])
10+
def xelemwise_to_elemwise(fgraph, node):
11+
# Convert inputs to TensorVariables and add broadcastable dims
12+
output_dims = node.outputs[0].type.dims
13+
14+
tensor_inputs = []
15+
for inp in node.inputs:
16+
inp_dims = inp.type.dims
17+
axis = [i for i, dim in enumerate(output_dims) if dim not in inp_dims]
18+
tensor_inp = tensor_from_xtensor(inp)
19+
tensor_inp = expand_dims(tensor_inp, axis)
20+
tensor_inputs.append(tensor_inp)
21+
22+
tensor_outs = Elemwise(scalar_op=node.op.scalar_op)(
23+
*tensor_inputs, return_list=True
24+
)
25+
26+
# TODO: copy_stack_trace
27+
new_outs = [
28+
xtensor_from_tensor(tensor_out, dims=output_dims) for tensor_out in tensor_outs
29+
]
30+
return new_outs

pytensor/xtensor/rewriting/utils.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from pytensor.compile import optdb
2+
from pytensor.graph.rewriting.basic import NodeRewriter
3+
from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase
4+
5+
6+
optdb.register(
7+
"xcanonicalize",
8+
EquilibriumDB(ignore_newtrees=False),
9+
"fast_run",
10+
"fast_compile",
11+
"xtensor",
12+
position=0,
13+
)
14+
15+
16+
def register_xcanonicalize(
17+
node_rewriter: RewriteDatabase | NodeRewriter | str, *tags: str, **kwargs
18+
):
19+
if isinstance(node_rewriter, str):
20+
21+
def register(inner_rewriter: RewriteDatabase | NodeRewriter):
22+
return register_xcanonicalize(
23+
inner_rewriter, node_rewriter, *tags, **kwargs
24+
)
25+
26+
return register
27+
28+
else:
29+
name = kwargs.pop("name", None) or node_rewriter.__name__
30+
optdb["xtensor"].register(
31+
name, node_rewriter, "fast_run", "fast_compile", *tags, **kwargs
32+
)
33+
return node_rewriter

0 commit comments

Comments
 (0)