
Commit 6e88797

POC named tensors

1 parent 5ffe17a · commit 6e88797

File tree

15 files changed (+1073, -0 lines)


pytensor/xtensor/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
import warnings

import pytensor.xtensor.rewriting
from pytensor.xtensor.type import (
    XTensorType,
    as_xtensor,
    as_xtensor_variable,
    xtensor,
    xtensor_constant,
)


warnings.warn("xtensor module is experimental and full of bugs")
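
A minimal usage sketch of the exported API, using the same dtype/dims/shape keywords that xtensor is called with elsewhere in this diff; the "a"/"b" dimension names are illustrative:

from pytensor.xtensor import xtensor  # importing emits the experimental-module warning

# A 2D variable with named dims; a None shape entry means the length is unknown
x = xtensor(dtype="float64", dims=("a", "b"), shape=(None, 3))
print(x.type.dims)   # ("a", "b")
print(x.type.shape)  # (None, 3)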

pytensor/xtensor/basic.py

Lines changed: 199 additions & 0 deletions
@@ -0,0 +1,199 @@
from itertools import chain

import pytensor.scalar as ps
from pytensor.graph import Apply, Op
from pytensor.tensor import TensorType, tensor
from pytensor.tensor.utils import _parse_gufunc_signature
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor


class XOp(Op):
    """A base class for XOps that shouldn't be materialized."""

    def perform(self, node, inputs, outputs):
        raise NotImplementedError(
            "xtensor operations must be rewritten as tensor operations"
        )


class XViewOp(Op):
    # Make this a View Op with C-implementation
    view_map = {0: [0]}

    def perform(self, node, inputs, output_storage):
        output_storage[0][0] = inputs[0]


class TensorFromXTensor(XViewOp):
    __props__ = ()

    def make_node(self, x) -> Apply:
        if not isinstance(x.type, XTensorType):
            raise TypeError(f"x must have an XTensorType, got {type(x.type)}")
        output = TensorType(x.type.dtype, shape=x.type.shape)()
        return Apply(self, [x], [output])


tensor_from_xtensor = TensorFromXTensor()


class XTensorFromTensor(XViewOp):
    __props__ = ("dims",)

    def __init__(self, dims):
        super().__init__()
        self.dims = dims

    def make_node(self, x) -> Apply:
        if not isinstance(x.type, TensorType):
            raise TypeError(f"x must have a TensorType, got {type(x.type)}")
        output = xtensor(dtype=x.type.dtype, dims=self.dims, shape=x.type.shape)
        return Apply(self, [x], [output])


def xtensor_from_tensor(x, dims):
    return XTensorFromTensor(dims=dims)(x)


class Rename(XViewOp):
    __props__ = ("new_dims",)

    def __init__(self, new_dims: tuple[str, ...]):
        super().__init__()
        self.new_dims = new_dims

    def make_node(self, x):
        x = as_xtensor(x)
        output = x.type.clone(dims=self.new_dims)()
        return Apply(self, [x], [output])


def rename(x, name_dict: dict[str, str] | None = None, **names: str):
    if name_dict is not None:
        if names:
            raise ValueError("Cannot use both a name_dict and keyword names in rename")
        names = name_dict

    x = as_xtensor(x)
    old_names = x.type.dims
    new_names = list(old_names)
    for old_name, new_name in names.items():
        try:
            new_names[old_names.index(old_name)] = new_name
        except ValueError:  # raised by .index when old_name is missing
            raise ValueError(
                f"Cannot rename {old_name} to {new_name}: {old_name} not in {old_names}"
            )

    return Rename(tuple(new_names))(x)


class XElemwise(XOp):
    __props__ = ("scalar_op",)

    def __init__(self, scalar_op):
        super().__init__()
        self.scalar_op = scalar_op

    def make_node(self, *inputs):
        inputs = [as_xtensor(inp) for inp in inputs]
        if (self.scalar_op.nin != -1) and (len(inputs) != self.scalar_op.nin):
            raise ValueError(
                f"Wrong number of inputs, expected {self.scalar_op.nin}, got {len(inputs)}"
            )

        dims_and_shape: dict[str, int | None] = {}
        for inp in inputs:
            for dim, dim_length in zip(inp.type.dims, inp.type.shape):
                if dim not in dims_and_shape:
                    dims_and_shape[dim] = dim_length
                elif dim_length is not None:
                    # Check for conflicting shapes
                    if (dims_and_shape[dim] is not None) and (
                        dims_and_shape[dim] != dim_length
                    ):
                        raise ValueError(f"Dimension {dim} has conflicting shapes")
                    # Keep the non-None shape
                    dims_and_shape[dim] = dim_length

        output_dims, output_shape = zip(*dims_and_shape.items())

        dummy_scalars = [ps.get_scalar_type(inp.type.dtype)() for inp in inputs]
        output_dtypes = [
            out.type.dtype for out in self.scalar_op.make_node(*dummy_scalars).outputs
        ]
        outputs = [
            xtensor(dtype=output_dtype, dims=output_dims, shape=output_shape)
            for output_dtype in output_dtypes
        ]
        return Apply(self, inputs, outputs)


class XBlockwise(XOp):
    __props__ = ("core_op", "signature", "core_dims")

    def __init__(
        self,
        core_op: Op,
        signature: str,
        core_dims: tuple[tuple[tuple[str, ...], ...], tuple[tuple[str, ...], ...]],
    ):
        super().__init__()
        self.core_op = core_op
        self.signature = signature
        self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
        self.core_dims = core_dims

    def make_node(self, *inputs):
        inputs = [as_xtensor(i) for i in inputs]
        if len(inputs) != len(self.inputs_sig):
            raise ValueError(
                f"Wrong number of inputs, expected {len(self.inputs_sig)}, got {len(inputs)}"
            )

        dims_and_shape: dict[str, int | None] = {}
        for inp in inputs:
            for dim, dim_length in zip(inp.type.dims, inp.type.shape):
                if dim not in dims_and_shape:
                    dims_and_shape[dim] = dim_length
                elif dim_length is not None:
                    # Check for conflicting shapes
                    if (dims_and_shape[dim] is not None) and (
                        dims_and_shape[dim] != dim_length
                    ):
                        raise ValueError(f"Dimension {dim} has conflicting shapes")
                    # Keep the non-None shape
                    dims_and_shape[dim] = dim_length

        core_inputs_dims, core_outputs_dims = self.core_dims
        # TODO: Avoid intermediate dict
        core_dims = set(chain.from_iterable(core_inputs_dims))
        batched_dims_and_shape = {
            k: v for k, v in dims_and_shape.items() if k not in core_dims
        }
        batch_dims, batch_shape = zip(*batched_dims_and_shape.items())

        dummy_core_inputs = []
        for inp, core_inp_dims in zip(inputs, core_inputs_dims):
            try:
                core_static_shape = [
                    inp.type.shape[inp.type.dims.index(d)] for d in core_inp_dims
                ]
            except ValueError:  # raised by .index when a core dim is missing
                raise ValueError(
                    f"At least one core dim={core_inp_dims} missing from input {inp} with dims={inp.type.dims}"
                )
            dummy_core_inputs.append(
                tensor(dtype=inp.type.dtype, shape=core_static_shape)
            )
        core_node = self.core_op.make_node(*dummy_core_inputs)

        outputs = [
            xtensor(
                dtype=core_out.type.dtype,
                shape=batch_shape + core_out.type.shape,
                dims=batch_dims + core_out_dims,
            )
            for core_out, core_out_dims in zip(core_node.outputs, core_outputs_dims)
        ]
        return Apply(self, inputs, outputs)
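
To illustrate how the conversion and rename helpers above compose, here is a minimal sketch; the tensor variable and the "row"/"col"/"obs" dimension names are illustrative, not part of the commit:

import pytensor.tensor as pt
from pytensor.xtensor.basic import rename, tensor_from_xtensor, xtensor_from_tensor

t = pt.matrix("t")                               # plain tensor, no dim names
x = xtensor_from_tensor(t, dims=("row", "col"))  # attach a name to each axis
y = rename(x, row="obs")                         # rename "row" -> "obs"
assert y.type.dims == ("obs", "col")
back = tensor_from_xtensor(y)                    # drop the names again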

pytensor/xtensor/linalg.py

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
from collections.abc import Sequence
from typing import Literal

from pytensor.tensor.slinalg import Cholesky, Solve
from pytensor.xtensor import as_xtensor
from pytensor.xtensor.basic import XBlockwise


def cholesky(
    x,
    lower: bool = True,
    *,
    check_finite: bool = False,
    overwrite_a: bool = False,
    on_error: Literal["raise", "nan"] = "raise",
    dims: Sequence[str],
):
    if len(dims) != 2:
        raise ValueError(f"Cholesky needs two dims, got {len(dims)}")

    core_op = Cholesky(
        lower=lower,
        check_finite=check_finite,
        overwrite_a=overwrite_a,
        on_error=on_error,
    )
    core_dims = (
        ((dims[0], dims[1]),),
        ((dims[0], dims[1]),),
    )
    x_op = XBlockwise(core_op, signature=core_op.gufunc_signature, core_dims=core_dims)
    return x_op(x)


def solve(
    a,
    b,
    dims: Sequence[str],
    assume_a="gen",
    lower: bool = False,
    check_finite: bool = False,
):
    a, b = as_xtensor(a), as_xtensor(b)
    if len(dims) == 2:
        b_ndim = 1
        [m1_dim] = [dim for dim in dims if dim not in b.type.dims]
        m2_dim = dims[0] if dims[0] != m1_dim else dims[1]
        input_core_dims = ((m1_dim, m2_dim), (m2_dim,))
        output_core_dims = ((m2_dim,),)
    elif len(dims) == 3:
        b_ndim = 2
        [n_dim] = [dim for dim in dims if dim not in a.type.dims]
        [m1_dim, m2_dim] = [dim for dim in dims if dim != n_dim]
        input_core_dims = ((m1_dim, m2_dim), (m2_dim, n_dim))
        output_core_dims = ((m2_dim, n_dim),)
    else:
        raise ValueError("Solve dims must have length 2 or 3")

    core_op = Solve(
        b_ndim=b_ndim, assume_a=assume_a, lower=lower, check_finite=check_finite
    )
    x_op = XBlockwise(
        core_op,
        signature=core_op.gufunc_signature,
        core_dims=(input_core_dims, output_core_dims),
    )
    return x_op(a, b)
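
A hedged sketch of calling cholesky on a batched input; the variable and the "batch"/"rows"/"cols" names are illustrative. Dims not listed in the dims argument are treated as batch dims by XBlockwise:

from pytensor.xtensor import xtensor
from pytensor.xtensor.linalg import cholesky

# One 3x3 matrix per (unknown-length) batch entry
cov = xtensor(dtype="float64", dims=("batch", "rows", "cols"), shape=(None, 3, 3))
chol = cholesky(cov, dims=("rows", "cols"))
assert chol.type.dims == ("batch", "rows", "cols")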

pytensor/xtensor/math.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
import inspect
import sys

import pytensor.scalar as ps
from pytensor.scalar import ScalarOp
from pytensor.xtensor.basic import XElemwise


this_module = sys.modules[__name__]


def get_all_scalar_ops():
    """
    Find all scalar operations in the pytensor.scalar module that can be wrapped with XElemwise.

    Returns:
        dict: A dictionary mapping operation names to XElemwise instances
    """
    result = {}

    # Get all module members
    for name, obj in inspect.getmembers(ps):
        # Keep concrete ScalarOp instances; each one can be wrapped elementwise
        if isinstance(obj, ScalarOp):
            result[name] = XElemwise(obj)

    return result


for name, op in get_all_scalar_ops().items():
    setattr(this_module, name, op)
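
Because the wrappers are generated at import time, the module's names mirror the ScalarOp instances found in pytensor.scalar. A sketch, assuming exp is among the discovered ops (it is a ScalarOp instance in pytensor.scalar):

import pytensor.xtensor.math as xmath
from pytensor.xtensor import xtensor

x = xtensor(dtype="float64", dims=("a",), shape=(None,))
y = xmath.exp(x)  # an XElemwise wrapping the scalar exp op
assert y.type.dims == ("a",)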
pytensor/xtensor/rewriting/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
import pytensor.xtensor.rewriting.basic
import pytensor.xtensor.rewriting.shape
