Skip to content

Commit 6fb515d

Browse files
committed
C implementation of Convolve1d
1 parent 6557682 commit 6fb515d

File tree

1 file changed

+84
-2
lines changed

1 file changed

+84
-2
lines changed

pytensor/tensor/signal/conv.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
from numpy import convolve as numpy_convolve
44

5-
from pytensor.graph import Apply, Op
5+
from pytensor.graph import Apply
6+
from pytensor.link.c.op import COp
67
from pytensor.scalar.basic import upcast
78
from pytensor.tensor.basic import as_tensor_variable, join, zeros
89
from pytensor.tensor.blockwise import Blockwise
@@ -15,7 +16,7 @@
1516
from pytensor.tensor import TensorLike
1617

1718

18-
class Convolve1d(Op):
19+
class Convolve1d(COp):
1920
__props__ = ("mode",)
2021
gufunc_signature = "(n),(k)->(o)"
2122

@@ -86,6 +87,87 @@ def L_op(self, inputs, outputs, output_grads):
8687

8788
return [in1_bar, in2_bar]
8889

90+
def c_code_cache_version(self):
    """Return the cache version tag for the C implementation.

    Bumped to 2 because the generated C source changed (reference-count
    fix on the ``PyArray_SetBaseObject`` failure path), so modules compiled
    from the previous source must not be reused.
    """
    return (2,)

def c_code(self, node, name, inputs, outputs, sub):
    """Generate the C source that computes a 1D convolution.

    The generated code builds a read-only reversed view of the kernel
    (second input) and hands it to ``PyArray_Correlate2``.

    Parameters
    ----------
    node
        Apply node being compiled; only the dtype of its second input
        (the kernel) is inspected.
    name
        Unique name for this node's C code (unused here).
    inputs
        C variable names of the two input arrays ``(in1, in2)``.
    outputs
        C variable name of the single output array.
    sub
        Substitution dictionary; ``sub['fail']`` is the error-exit snippet.

    Returns
    -------
    str
        The C code block for this node.

    Raises
    ------
    ValueError
        If ``self.mode`` is neither ``"full"`` nor ``"valid"``.
    NotImplementedError
        If the kernel has a complex dtype, so that the Python
        implementation is used instead.
    """
    in1, in2 = inputs
    [out] = outputs
    mode_str = self.mode

    # NumPy's integer codes for the correlation modes.
    np_mode_codes = {
        "full": 2,  # NPY_CONVOLVE_FULL
        "valid": 0,  # NPY_CONVOLVE_VALID
    }
    if mode_str not in np_mode_codes:
        # This case should ideally be prevented by __init__ or make_node
        raise ValueError(f"Unsupported mode {mode_str}")
    np_mode_val = np_mode_codes[mode_str]

    # PyArray_Correlate2 conjugates its second operand when it is complex,
    # which numpy.convolve does not do. Decline to generate C code for
    # complex kernels so the Python implementation is used instead.
    if str(node.inputs[1].type.dtype).startswith("complex"):
        raise NotImplementedError(
            "Convolve1d C implementation does not support complex kernels"
        )

    code = f"""
    {{
        PyArrayObject* in2_flipped_view = NULL;

        if (PyArray_NDIM({in1}) != 1 || PyArray_NDIM({in2}) != 1) {{
            PyErr_SetString(PyExc_ValueError, "Convolve1d C code expects 1D arrays.");
            {sub['fail']};
        }}

        npy_intp n_in2 = PyArray_DIM({in2}, 0);

        // Create a reversed view of in2
        if (n_in2 == 0) {{
            PyErr_SetString(PyExc_ValueError, "Convolve1d: second input (kernel) cannot be empty.");
            {sub['fail']};
        }} else {{
            npy_intp view_dims[1];
            view_dims[0] = n_in2;

            // Negative stride: walk the kernel buffer backwards
            npy_intp view_strides[1];
            view_strides[0] = -PyArray_STRIDES({in2})[0];

            // The view starts at the last element of in2
            void* view_data = (char*)PyArray_DATA({in2}) + (n_in2 - 1) * PyArray_STRIDES({in2})[0];

            // PyArray_NewFromDescr steals a reference to the descriptor
            Py_INCREF(PyArray_DESCR({in2}));
            in2_flipped_view = (PyArrayObject*)PyArray_NewFromDescr(
                Py_TYPE({in2}),
                PyArray_DESCR({in2}),
                1,  // ndim
                view_dims,
                view_strides,
                view_data,
                (PyArray_FLAGS({in2}) & ~NPY_ARRAY_WRITEABLE),
                NULL
            );

            if (!in2_flipped_view) {{
                PyErr_SetString(PyExc_RuntimeError, "Failed to create flipped kernel view for Convolve1d.");
                {sub['fail']};
            }}

            // PyArray_SetBaseObject steals the reference taken here, even
            // when it fails, so it must not be released again on failure.
            Py_INCREF({in2});
            if (PyArray_SetBaseObject(in2_flipped_view, (PyObject*){in2}) < 0) {{
                Py_DECREF(in2_flipped_view);
                in2_flipped_view = NULL;
                PyErr_SetString(PyExc_RuntimeError, "Failed to set base object for flipped kernel view in Convolve1d.");
                {sub['fail']};
            }}
            // The flags were copied from in2; recompute contiguity for the reversed strides
            PyArray_UpdateFlags(in2_flipped_view, (NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_F_CONTIGUOUS));
        }}

        // TODO: Use lower level implementation that allows reusing the output buffer
        Py_XDECREF({out});
        {out} = (PyArrayObject*) PyArray_Correlate2((PyObject*){in1}, (PyObject*)in2_flipped_view, {np_mode_val});
        // The view is no longer needed, whether or not the correlation succeeded
        Py_XDECREF(in2_flipped_view);
        if (!{out}) {{
            // PyArray_Correlate2 already set an error
            {sub['fail']};
        }}
    }}
    """
    return code
170+
89171

90172
def convolve1d(
91173
in1: "TensorLike",

0 commit comments

Comments
 (0)