Skip to content

Commit 5b0c472

Browse files
committed
. POC named tensors
1 parent 4cc13bc commit 5b0c472

File tree

7 files changed

+641
-0
lines changed

7 files changed

+641
-0
lines changed

pytensor/xtensor/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
import warnings

# Importing the rewriting package registers the xtensor rewrites with the
# global optimizer database as an import side effect.
import pytensor.xtensor.rewriting

# Re-export the public xtensor API at package level.
from pytensor.xtensor.variable import XTensorVariable, XTensorConstant, as_xtensor, as_xtensor_variable
from pytensor.xtensor.type import XTensorType


# Emitted once at import time: this module is a proof of concept.
warnings.warn("xtensor module is experimental and full of bugs")

pytensor/xtensor/basic.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from itertools import chain
2+
3+
import pytensor.scalar as ps
4+
from pytensor.graph import Apply, Op
5+
import pytensor.xtensor as px
6+
from pytensor.tensor import TensorType
7+
8+
9+
class TensorFromXTensor(Op):
    """Strip dimension labels: convert an XTensorVariable into a plain TensorVariable.

    The underlying data is passed through (copied); only the type changes.
    """

    # No parameters: all instances are interchangeable. Declared explicitly
    # for consistency with XTensorFromTensor below.
    __props__ = ()

    def make_node(self, x) -> Apply:
        if not isinstance(x.type, px.XTensorType):
            # Fixed grammar of the original message ("must be have")
            raise TypeError(f"x must have an XTensorType, got {type(x.type)}")
        output = TensorType(x.type.dtype, shape=x.type.shape)()
        return Apply(self, [x], [output])

    def perform(self, node, inputs, output_storage) -> None:
        [x] = inputs
        # Copy so downstream in-place operations cannot mutate the input's data
        output_storage[0][0] = x.copy()


tensor_from_xtensor = TensorFromXTensor()
23+
24+
25+
class XTensorFromTensor(Op):
    """Attach dimension labels: convert a TensorVariable into an XTensorVariable."""

    __props__ = ("dims",)

    def __init__(self, dims):
        super().__init__()
        # Normalize to a tuple: `dims` participates in `__props__`, so it must
        # be hashable for Op equality/caching. A list here would break __hash__.
        self.dims = tuple(dims)

    def make_node(self, x) -> Apply:
        if not isinstance(x.type, TensorType):
            # Fixed grammar of the original message ("an TensorType type")
            raise TypeError(f"x must be a TensorType, got {type(x.type)}")
        output = px.XTensorType(x.type.dtype, dims=self.dims, shape=x.type.shape)()
        return Apply(self, [x], [output])

    def perform(self, node, inputs, output_storage) -> None:
        [x] = inputs
        # Copy so downstream in-place operations cannot mutate the input's data
        output_storage[0][0] = x.copy()
42+
43+
44+
def xtensor_from_tensor(x, dims):
    """Wrap tensor `x` in an XTensorVariable whose axes are labelled `dims`."""
    op = XTensorFromTensor(dims=dims)
    return op(x)
46+
47+
48+
class XElemwise(Op):
    """Elementwise operation on XTensorVariables, broadcasting by dimension name.

    The output dims are the sorted union of the input dims. This Op is
    abstract: it cannot be evaluated directly and must first be rewritten
    into a regular `Elemwise` (see pytensor.xtensor.rewriting.basic).
    """

    __props__ = ("scalar_op",)

    def __init__(self, scalar_op):
        super().__init__()
        self.scalar_op = scalar_op

    def make_node(self, *inputs):
        # TODO: Check dim lengths match
        inputs = [px.as_xtensor_variable(inp) for inp in inputs]
        # TODO: This ordering is different than what xarray does
        unique_dims = sorted(set(chain.from_iterable(inp.type.dims for inp in inputs)))
        # Output dtype follows the standard upcasting of the input dtypes
        # (previously hard-coded to "float64"). NOTE(review): ops whose output
        # dtype is not the upcast of the inputs (e.g. comparisons returning
        # bool) would need scalar_op.output_types instead -- revisit when such
        # ops are added.
        out_dtype = ps.upcast(*[inp.type.dtype for inp in inputs])
        output_type = px.XTensorType(
            out_dtype, dims=unique_dims, shape=(None,) * len(unique_dims)
        )
        outputs = [output_type() for _ in range(self.scalar_op.nout)]
        return Apply(self, inputs, outputs)

    def perform(self, *args, **kwargs) -> None:
        # Never executed: the canonicalization pass must lower this Op first.
        raise NotImplementedError("xtensor operations must be rewritten as tensor operations")


add = XElemwise(ps.add)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
import pytensor.xtensor.rewriting.basic

pytensor/xtensor/rewriting/basic.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from pytensor.graph import node_rewriter
2+
from pytensor.tensor import expand_dims
3+
from pytensor.tensor.elemwise import Elemwise
4+
from pytensor.xtensor.basic import tensor_from_xtensor, XElemwise, xtensor_from_tensor
5+
from pytensor.xtensor.rewriting.utils import register_xcanonicalize
6+
7+
8+
@register_xcanonicalize
@node_rewriter(tracks=[XElemwise])
def xelemwise_to_elemwise(fgraph, node):
    """Lower an `XElemwise` node to a regular `Elemwise` on plain tensors.

    Each input is unwrapped to a TensorVariable and its axes aligned to the
    output dim order, then the scalar op is applied elementwise and the
    results are wrapped back into XTensorVariables.
    """
    output_dims = node.outputs[0].type.dims

    tensor_inputs = []
    for inp in node.inputs:
        inp_dims = inp.type.dims
        # Build a dimshuffle pattern that maps each output dim to the input's
        # axis carrying that dim, inserting broadcastable axes ("x") for dims
        # the input lacks. BUG FIX: the previous expand_dims-only approach
        # assumed input dims already appeared in output order and silently
        # misaligned axes when they did not (e.g. input dims ("b", "a") vs
        # sorted output ("a", "b")); dimshuffle transposes AND expands.
        pattern = [
            inp_dims.index(dim) if dim in inp_dims else "x" for dim in output_dims
        ]
        tensor_inp = tensor_from_xtensor(inp).dimshuffle(pattern)
        tensor_inputs.append(tensor_inp)

    tensor_outs = Elemwise(scalar_op=node.op.scalar_op)(*tensor_inputs, return_list=True)

    # TODO: copy_stack_trace
    new_outs = [
        xtensor_from_tensor(tensor_out, dims=output_dims) for tensor_out in tensor_outs
    ]
    return new_outs

pytensor/xtensor/rewriting/utils.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from typing import Union
2+
3+
from pytensor.compile import optdb
4+
from pytensor.graph.rewriting.basic import NodeRewriter
5+
from pytensor.graph.rewriting.db import RewriteDatabase, EquilibriumDB
6+
7+
8+
# Register an equilibrium database for xtensor canonicalization rewrites with
# the main optimizer database. position=0 runs it before all standard rewrite
# phases, so xtensor ops are lowered to regular tensor ops first.
optdb.register(
    "xcanonicalize",
    EquilibriumDB(ignore_newtrees=False),
    "fast_run",
    "fast_compile",
    "xtensor",
    position=0,
)
16+
17+
18+
def register_xcanonicalize(
    node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs
):
    """Register a rewriter in the xtensor canonicalization database.

    Supports two call forms: directly as ``@register_xcanonicalize``, or as a
    decorator factory ``@register_xcanonicalize("extra_tag", ...)`` where the
    string becomes an additional tag.
    """
    if isinstance(node_rewriter, str):
        # Named form: the first argument is a tag and the real rewriter
        # arrives when the returned decorator is applied.
        def decorator(inner_rewriter: Union[RewriteDatabase, NodeRewriter]):
            return register_xcanonicalize(inner_rewriter, node_rewriter, *tags, **kwargs)

        return decorator

    rewrite_name = kwargs.pop("name", None) or node_rewriter.__name__
    optdb["xtensor"].register(
        rewrite_name, node_rewriter, "fast_run", "fast_compile", *tags, **kwargs
    )
    return node_rewriter

pytensor/xtensor/type.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
from typing import Iterable, Optional, Union, Sequence, TypeVar
2+
3+
import numpy as np
4+
5+
import pytensor
6+
from pytensor import scalar as aes
7+
from pytensor.graph.basic import Variable
8+
from pytensor.graph.type import HasDataType
9+
from pytensor.tensor.type import TensorType
10+
11+
12+
# Type variable bound to TensorType, for annotating helpers generic over the
# concrete (X)TensorType subclass. Currently unused in this file.
_XTensorTypeType = TypeVar("_XTensorTypeType", bound=TensorType)
13+
14+
15+
class XTensorType(TensorType, HasDataType):
    """A `Type` for tensors with named dimensions.

    Extends `TensorType` with a tuple of dimension names (`dims`), one per
    axis. (The previous docstring said "sparse tensors" -- a copy-paste
    leftover from `SparseTensorType`, along with unreachable sparse-specific
    code in `filter`/`convert_variable`/`is_super`, removed here.)
    """

    __props__ = ("dtype", "shape", "dims")

    def __init__(
        self,
        dtype: Union[str, np.dtype],
        *,
        dims: Sequence[str],
        shape: Optional[Iterable[Optional[Union[bool, int]]]] = None,
        name: Optional[str] = None,
    ):
        super().__init__(dtype, shape=shape, name=name)
        if not isinstance(dims, (list, tuple)):
            raise TypeError("dims must be a list or tuple")
        # Store as tuple: dims participates in __props__/__hash__
        self.dims = tuple(dims)

    def clone(
        self,
        dtype=None,
        dims=None,
        shape=None,
        **kwargs,
    ):
        """Return a copy of this type with any given properties overridden."""
        if dtype is None:
            dtype = self.dtype
        if dims is None:
            dims = self.dims
        if shape is None:
            shape = self.shape
        # BUG FIX: previously passed the *builtin* `format` function as the
        # positional dtype argument (another sparse-type leftover), which
        # would raise on any clone() call.
        return type(self)(dtype, dims=dims, shape=shape, **kwargs)

    def filter(self, value, strict=False, allow_downcast=None):
        # TODO: Implement validation/conversion of `value` (dtype check,
        # downcast handling). Currently accepts anything unchanged.
        return value

    def convert_variable(self, var):
        # TODO: Implement. Intended logic: defer to super(), then reject
        # results whose type or dims do not match ours. Currently a no-op.
        return var

    def __hash__(self):
        return super().__hash__() ^ hash(self.dims)

    def __repr__(self):
        # TODO: Add `?` for unknown shapes like `TensorType` does
        return f"XTensorType({self.dtype}, {self.dims}, {self.shape})"

    def __eq__(self, other):
        res = super().__eq__(other)

        # super() may return NotImplemented (non-bool); only refine a
        # definitive bool answer with the dims comparison.
        if isinstance(res, bool):
            return res and other.dims == self.dims

        return res

    def is_super(self, otype):
        # TODO: Implement properly. Intended logic: super().is_super(otype)
        # and matching dims. Currently every type is accepted.
        return True
135+
136+
137+
# TODO: Implement creater helper xtensor

# Let ViewOp pass XTensorType values through by reference at the C level:
# the output simply takes over the input's PyObject (with refcount bookkeeping).
pytensor.compile.register_view_op_c_code(
    XTensorType,
    """
    Py_XDECREF(%(oname)s);
    %(oname)s = %(iname)s;
    Py_XINCREF(%(oname)s);
    """,
    1,  # version of this C code snippet, for the C-code cache
)

0 commit comments

Comments
 (0)