diff --git a/pytensor/tensor/blas_c.py b/pytensor/tensor/blas_c.py index 46c8e884fc..6d1d830cf8 100644 --- a/pytensor/tensor/blas_c.py +++ b/pytensor/tensor/blas_c.py @@ -423,6 +423,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non int Sz = PyArray_STRIDES(%(z)s)[0] / elemsize; int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize; + dtype_%(A)s* A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s); dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s); dtype_%(z)s* z_data = (dtype_%(z)s*) PyArray_DATA(%(z)s); // gemv expects pointers to the beginning of memory arrays, @@ -435,17 +436,25 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non if (NA0 * NA1) { - // If A is neither C- nor F-contiguous, we make a copy. - // TODO: - // - if one stride is equal to "- elemsize", we can still call - // gemv on reversed matrix and vectors - // - if the copy is too long, maybe call vector/vector dot on - // each row instead - if ((PyArray_STRIDES(%(A)s)[0] < 0) - || (PyArray_STRIDES(%(A)s)[1] < 0) - || ((PyArray_STRIDES(%(A)s)[0] != elemsize) - && (PyArray_STRIDES(%(A)s)[1] != elemsize))) + if ( ((SA0 < 0) || (SA1 < 0)) && (abs(SA0) == 1 || (abs(SA1) == 1)) ) { + // We can treat the array A as C-or F-contiguous by changing the order of iteration + if (SA0 < 0){ + A_data += (NA0 -1) * SA0; // Jump to first row + SA0 = -SA0; // Iterate over rows in reverse + Sz = -Sz; // Iterate over y in reverse + } + if (SA1 < 0){ + A_data += (NA1 -1) * SA1; // Jump to first column + SA1 = -SA1; // Iterate over columns in reverse + Sx = -Sx; // Iterate over x in reverse + } + } else if ((SA0 < 0) || (SA1 < 0) || ((SA0 != 1) && (SA1 != 1))) + { + // Array isn't contiguous, we have to make a copy + // - if the copy is too long, maybe call vector/vector dot on + // each row instead + // printf("GEMV: Making a copy SA0=%%d, SA1=%%d\\n", SA0, SA1); npy_intp dims[2]; dims[0] = NA0; dims[1] = NA1; @@ -458,16 +467,17 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non %(A)s = A_copy; SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : (NA1 + 1); SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : (NA0 + 1); + A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s); } - if (PyArray_STRIDES(%(A)s)[0] == elemsize) + if (SA0 == 1) { if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT) { float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0]; sgemv_(&NOTRANS, &NA0, &NA1, &alpha, - (float*)(PyArray_DATA(%(A)s)), &SA1, + (float*)(A_data), &SA1, (float*)x_data, &Sx, &fbeta, (float*)z_data, &Sz); @@ -477,7 +487,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0]; dgemv_(&NOTRANS, &NA0, &NA1, &alpha, - (double*)(PyArray_DATA(%(A)s)), &SA1, + (double*)(A_data), &SA1, (double*)x_data, &Sx, &dbeta, (double*)z_data, &Sz); @@ -489,7 +499,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non %(fail)s } } - else if (PyArray_STRIDES(%(A)s)[1] == elemsize) + else if (SA1 == 1) { if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT) { @@ -506,14 +516,14 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non z_data[0] = 0.f; } z_data[0] += alpha*sdot_(&NA1, - (float*)(PyArray_DATA(%(A)s)), &SA1, + (float*)(A_data), &SA1, (float*)x_data, &Sx); } else { sgemv_(&TRANS, &NA1, &NA0, &alpha, - (float*)(PyArray_DATA(%(A)s)), &SA0, + (float*)(A_data), &SA0, (float*)x_data, &Sx, &fbeta, (float*)z_data, &Sz); @@ -534,14 +544,14 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non z_data[0] = 0.; } z_data[0] += alpha*ddot_(&NA1, - (double*)(PyArray_DATA(%(A)s)), &SA1, + (double*)(A_data), &SA1, (double*)x_data, &Sx); } else { dgemv_(&TRANS, &NA1, &NA0, &alpha, - (double*)(PyArray_DATA(%(A)s)), &SA0, + (double*)(A_data), &SA0, (double*)x_data, &Sx, &dbeta, (double*)z_data, &Sz); @@ -603,7 +613,7 @@ def c_code(self, node, name, inp, out, sub): return code def c_code_cache_version(self): - return (14, blas_header_version(), check_force_gemv_init()) + return (15, blas_header_version(), check_force_gemv_init()) cgemv_inplace = CGemv(inplace=True) diff --git a/tests/tensor/test_blas_c.py b/tests/tensor/test_blas_c.py index 8298cae5ba..26747d2199 100644 --- a/tests/tensor/test_blas_c.py +++ b/tests/tensor/test_blas_c.py @@ -411,3 +411,45 @@ class TestSdotNoFlags(TestCGemvNoFlags): class TestBlasStridesC(TestBlasStrides): mode = mode_blas_opt + + +@pytest.mark.parametrize( + "neg_stride1", (True, False), ids=["neg_stride1", "pos_stride1"] +) +@pytest.mark.parametrize( + "neg_stride0", (True, False), ids=["neg_stride0", "pos_stride0"] +) +@pytest.mark.parametrize("F_layout", (True, False), ids=["F_layout", "C_layout"]) +def test_gemv_negative_strides_perf(neg_stride0, neg_stride1, F_layout, benchmark): + A = pt.matrix("A", shape=(512, 512)) + x = pt.vector("x", shape=(A.type.shape[-1],)) + y = pt.vector("y", shape=(A.type.shape[0],)) + + out = CGemv(inplace=False)( + y, + 1.0, + A, + x, + 1.0, + ) + fn = pytensor.function([A, x, y], out, trust_input=True) + + rng = np.random.default_rng(430) + test_A = rng.normal(size=A.type.shape) + test_x = rng.normal(size=x.type.shape) + test_y = rng.normal(size=y.type.shape) + + if F_layout: + test_A = test_A.T + if neg_stride0: + test_A = test_A[::-1] + if neg_stride1: + test_A = test_A[:, ::-1] + assert (test_A.strides[0] < 0) == neg_stride0 + assert (test_A.strides[1] < 0) == neg_stride1 + + # Check result is correct by using a copy of A with positive strides + res = fn(test_A, test_x, test_y) + np.testing.assert_allclose(res, fn(test_A.copy(), test_x, test_y)) + + benchmark(fn, test_A, test_x, test_y)