|
2 | 2 |
|
3 | 3 | from numpy import convolve as numpy_convolve
|
4 | 4 |
|
5 |
| -from pytensor.graph import Apply, Op |
| 5 | +from pytensor.graph import Apply |
| 6 | +from pytensor.link.c.op import COp |
6 | 7 | from pytensor.scalar.basic import upcast
|
7 | 8 | from pytensor.tensor.basic import as_tensor_variable, join, zeros
|
8 | 9 | from pytensor.tensor.blockwise import Blockwise
|
|
15 | 16 | from pytensor.tensor import TensorLike
|
16 | 17 |
|
17 | 18 |
|
18 |
| -class Convolve1d(Op): |
| 19 | +class Convolve1d(COp): |
19 | 20 | __props__ = ("mode",)
|
20 | 21 | gufunc_signature = "(n),(k)->(o)"
|
21 | 22 |
|
@@ -86,6 +87,87 @@ def L_op(self, inputs, outputs, output_grads):
|
86 | 87 |
|
87 | 88 | return [in1_bar, in2_bar]
|
88 | 89 |
|
def c_code_cache_version(self):
    """Return the version tag of the C implementation emitted by ``c_code``.

    The tag is 2 because the generated C source differs from the first
    revision: the ``PyArray_SetBaseObject`` failure path no longer releases
    the kernel input a second time.
    """
    return (2,)

def c_code(self, node, name, inputs, outputs, sub):
    """Build the C source that computes a 1D convolution for this op.

    The emitted code creates a read-only, reversed (negative-stride) view of
    the second input and passes it with the first input to
    ``PyArray_Correlate2``; the result replaces the output variable.

    Parameters
    ----------
    node, name
        Not used by this implementation.
    inputs
        Two C variable names: the first input and the kernel.
    outputs
        A single C variable name that receives the result array.
    sub
        Substitution mapping; only ``sub['fail']`` (the failure statement)
        is read.

    Returns
    -------
    str
        The C code block.

    Raises
    ------
    ValueError
        If ``self.mode`` is neither ``"full"`` nor ``"valid"``.
    """
    in1, in2 = inputs
    [out] = outputs
    mode_str = self.mode

    if mode_str == "full":
        np_mode_val = 2  # NPY_CONVOLVE_FULL
    elif mode_str == "valid":
        np_mode_val = 0  # NPY_CONVOLVE_VALID
    else:
        # This case should ideally be prevented by __init__ or make_node
        raise ValueError(f"Unsupported mode {mode_str}")

    fail = sub["fail"]

    # The whole snippet lives in its own C scope ({{ ... }}) so the local
    # declarations stay inside that block.
    code = f"""
    {{
        PyArrayObject* in2_flipped_view = NULL;

        if (PyArray_NDIM({in1}) != 1 || PyArray_NDIM({in2}) != 1) {{
            PyErr_SetString(PyExc_ValueError, "Convolve1d C code expects 1D arrays.");
            {fail};
        }}

        npy_intp n_in2 = PyArray_DIM({in2}, 0);

        // Create a reversed view of in2
        if (n_in2 == 0) {{
            PyErr_SetString(PyExc_ValueError, "Convolve1d: second input (kernel) cannot be empty.");
            {fail};
        }} else {{
            npy_intp view_dims[1];
            view_dims[0] = n_in2;

            // Negative stride: walk the kernel buffer backwards
            npy_intp view_strides[1];
            view_strides[0] = -PyArray_STRIDES({in2})[0];

            // The view starts at the last element of in2
            void* view_data = (char*)PyArray_DATA({in2}) + (n_in2 - 1) * PyArray_STRIDES({in2})[0];

            // PyArray_NewFromDescr steals a reference to the descriptor
            Py_INCREF(PyArray_DESCR({in2}));
            in2_flipped_view = (PyArrayObject*)PyArray_NewFromDescr(
                Py_TYPE({in2}),
                PyArray_DESCR({in2}),
                1,  // ndim
                view_dims,
                view_strides,
                view_data,
                (PyArray_FLAGS({in2}) & ~NPY_ARRAY_WRITEABLE),
                NULL
            );

            if (!in2_flipped_view) {{
                PyErr_SetString(PyExc_RuntimeError, "Failed to create flipped kernel view for Convolve1d.");
                {fail};
            }}

            // PyArray_SetBaseObject steals the reference taken here, and it
            // releases that reference itself when it fails, so the failure
            // branch must NOT decref in2 again.
            Py_INCREF({in2});
            if (PyArray_SetBaseObject(in2_flipped_view, (PyObject*){in2}) < 0) {{
                Py_DECREF(in2_flipped_view);
                in2_flipped_view = NULL;
                PyErr_SetString(PyExc_RuntimeError, "Failed to set base object for flipped kernel view in Convolve1d.");
                {fail};
            }}
            // The flags were copied from in2; recompute contiguity for the reversed strides
            PyArray_UpdateFlags(in2_flipped_view, (NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_F_CONTIGUOUS));
        }}

        // TODO: Use lower level implementation that allows reusing the output buffer
        Py_XDECREF({out});
        {out} = (PyArrayObject*) PyArray_Correlate2((PyObject*){in1}, (PyObject*)in2_flipped_view, {np_mode_val});
        Py_XDECREF(in2_flipped_view); // The temporary view is released whether or not correlate succeeded
        if (!{out}) {{
            // PyArray_Correlate2 already set an error
            {fail};
        }}
    }}
    """
    return code
| 170 | + |
89 | 171 |
|
90 | 172 | def convolve1d(
|
91 | 173 | in1: "TensorLike",
|
|
0 commit comments