Skip to content

Commit 5229feb

Browse files
committed
Implement C code for ExtractDiag and ARange
Set the view flag of ExtractDiag to True by default and respect it
1 parent d9a8471 commit 5229feb

File tree

1 file changed

+89
-35
lines changed

1 file changed

+89
-35
lines changed

pytensor/tensor/basic.py

Lines changed: 89 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3207,13 +3207,14 @@ def tile(
32073207
return A_replicated.reshape(tiled_shape)
32083208

32093209

3210-
class ARange(Op):
3210+
class ARange(COp):
32113211
"""Create an array containing evenly spaced values within a given interval.
32123212
32133213
Parameters and behaviour are the same as numpy.arange().
32143214
32153215
"""
32163216

3217+
# TODO: ARange should work with scalars as inputs, not arrays
32173218
__props__ = ("dtype",)
32183219

32193220
def __init__(self, dtype):
@@ -3293,13 +3294,30 @@ def upcast(var):
32933294
)
32943295
]
32953296

3296-
def perform(self, node, inp, out_):
3297-
start, stop, step = inp
3298-
(out,) = out_
3299-
start = start.item()
3300-
stop = stop.item()
3301-
step = step.item()
3302-
out[0] = np.arange(start, stop, step, dtype=self.dtype)
3297+
def perform(self, node, inputs, output_storage):
3298+
start, stop, step = inputs
3299+
output_storage[0][0] = np.arange(
3300+
start.item(), stop.item(), step.item(), dtype=self.dtype
3301+
)
3302+
3303+
def c_code(self, node, nodename, input_names, output_names, sub):
3304+
[start_name, stop_name, step_name] = input_names
3305+
[out_name] = output_names
3306+
typenum = np.dtype(self.dtype).num
3307+
return f"""
3308+
double start = ((dtype_{start_name}*)PyArray_DATA({start_name}))[0];
3309+
double stop = ((dtype_{stop_name}*)PyArray_DATA({stop_name}))[0];
3310+
double step = ((dtype_{step_name}*)PyArray_DATA({step_name}))[0];
3311+
//printf("start: %f, stop: %f, step: %f\\n", start, stop, step);
3312+
Py_XDECREF({out_name});
3313+
{out_name} = (PyArrayObject*) PyArray_Arange(start, stop, step, {typenum});
3314+
if (!{out_name}) {{
3315+
{sub["fail"]}
3316+
}}
3317+
"""
3318+
3319+
def c_code_cache_version(self):
3320+
return (0,)
33033321

33043322
def connection_pattern(self, node):
33053323
return [[True], [False], [True]]
@@ -3685,8 +3703,7 @@ def inverse_permutation(perm):
36853703
)
36863704

36873705

3688-
# TODO: optimization to insert ExtractDiag with view=True
3689-
class ExtractDiag(Op):
3706+
class ExtractDiag(COp):
36903707
"""
36913708
Return specified diagonals.
36923709
@@ -3742,7 +3759,7 @@ class ExtractDiag(Op):
37423759

37433760
__props__ = ("offset", "axis1", "axis2", "view")
37443761

3745-
def __init__(self, offset=0, axis1=0, axis2=1, view=False):
3762+
def __init__(self, offset=0, axis1=0, axis2=1, view=True):
37463763
self.view = view
37473764
if self.view:
37483765
self.view_map = {0: [0]}
@@ -3765,24 +3782,74 @@ def make_node(self, x):
37653782
if x.ndim < 2:
37663783
raise ValueError("ExtractDiag needs an input with 2 or more dimensions", x)
37673784

3768-
out_shape = [
3769-
st_dim
3770-
for i, st_dim in enumerate(x.type.shape)
3771-
if i not in (self.axis1, self.axis2)
3772-
] + [None]
3785+
if (dim1 := x.type.shape[self.axis1]) is not None and (
3786+
dim2 := x.type.shape[self.axis2]
3787+
) is not None:
3788+
offset = self.offset
3789+
if offset > 0:
3790+
diag_size = int(np.clip(dim2 - offset, 0, dim1))
3791+
elif offset < 0:
3792+
diag_size = int(np.clip(dim1 + offset, 0, dim2))
3793+
else:
3794+
diag_size = int(np.minimum(dim1, dim2))
3795+
else:
3796+
diag_size = None
3797+
3798+
out_shape = (
3799+
*(
3800+
dim
3801+
for i, dim in enumerate(x.type.shape)
3802+
if i not in (self.axis1, self.axis2)
3803+
),
3804+
diag_size,
3805+
)
37733806

37743807
return Apply(
37753808
self,
37763809
[x],
3777-
[x.type.clone(dtype=x.dtype, shape=tuple(out_shape))()],
3810+
[x.type.clone(dtype=x.dtype, shape=out_shape)()],
37783811
)
37793812

3780-
def perform(self, node, inputs, outputs):
3813+
def perform(self, node, inputs, output_storage):
37813814
(x,) = inputs
3782-
(z,) = outputs
3783-
z[0] = x.diagonal(self.offset, self.axis1, self.axis2)
3784-
if not self.view:
3785-
z[0] = z[0].copy()
3815+
out = x.diagonal(self.offset, self.axis1, self.axis2)
3816+
if self.view:
3817+
try:
3818+
out.flags.writeable = True
3819+
except ValueError:
3820+
# We can't make this array writable
3821+
out = out.copy()
3822+
else:
3823+
out = out.copy()
3824+
output_storage[0][0] = out
3825+
3826+
def c_code(self, node, nodename, input_names, output_names, sub):
3827+
[x_name] = input_names
3828+
[out_name] = output_names
3829+
return f"""
3830+
Py_XDECREF({out_name});
3831+
3832+
{out_name} = (PyArrayObject*) PyArray_Diagonal({x_name}, {self.offset}, {self.axis1}, {self.axis2});
3833+
if (!{out_name}) {{
3834+
{sub["fail"]} // Error already set by Numpy
3835+
}}
3836+
3837+
if ({int(self.view)} && PyArray_ISWRITEABLE({x_name})) {{
3838+
// Make output writeable if input was writeable
3839+
PyArray_ENABLEFLAGS({out_name}, NPY_ARRAY_WRITEABLE);
3840+
}} else {{
3841+
// Make a copy
3842+
PyArrayObject *{out_name}_copy = (PyArrayObject*) PyArray_Copy({out_name});
3843+
Py_DECREF({out_name});
3844+
if (!{out_name}_copy) {{
3845+
{sub['fail']}; // Error already set by Numpy
3846+
}}
3847+
{out_name} = {out_name}_copy;
3848+
}}
3849+
"""
3850+
3851+
def c_code_cache_version(self):
3852+
return (0,)
37863853

37873854
def grad(self, inputs, gout):
37883855
# Avoid circular import
@@ -3829,19 +3896,6 @@ def infer_shape(self, fgraph, node, shapes):
38293896
out_shape.append(diag_size)
38303897
return [tuple(out_shape)]
38313898

3832-
def __setstate__(self, state):
3833-
self.__dict__.update(state)
3834-
3835-
if self.view:
3836-
self.view_map = {0: [0]}
3837-
3838-
if "offset" not in state:
3839-
self.offset = 0
3840-
if "axis1" not in state:
3841-
self.axis1 = 0
3842-
if "axis2" not in state:
3843-
self.axis2 = 1
3844-
38453899

38463900
def extract_diag(x):
38473901
warnings.warn(

0 commit comments

Comments
 (0)