Skip to content

Commit 13c5678

Browse files
committed
Avoid copy of flipped A matrices in GEMV
1 parent b2365e0 commit 13c5678

File tree

2 files changed

+74
-19
lines changed

2 files changed

+74
-19
lines changed

pytensor/tensor/blas_c.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
423423
int Sz = PyArray_STRIDES(%(z)s)[0] / elemsize;
424424
int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize;
425425
426+
dtype_%(A)s* A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s);
426427
dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s);
427428
dtype_%(z)s* z_data = (dtype_%(z)s*) PyArray_DATA(%(z)s);
428429
// gemv expects pointers to the beginning of memory arrays,
@@ -435,17 +436,28 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
435436
436437
if (NA0 * NA1)
437438
{
438-
// If A is neither C- nor F-contiguous, we make a copy.
439-
// TODO:
440-
// - if one stride is equal to "- elemsize", we can still call
441-
// gemv on reversed matrix and vectors
442-
// - if the copy is too long, maybe call vector/vector dot on
443-
// each row instead
444-
if ((PyArray_STRIDES(%(A)s)[0] < 0)
445-
|| (PyArray_STRIDES(%(A)s)[1] < 0)
446-
|| ((PyArray_STRIDES(%(A)s)[0] != elemsize)
447-
&& (PyArray_STRIDES(%(A)s)[1] != elemsize)))
439+
if (((SA0 < 0) || (SA1 < 0))
440+
&& (abs(SA0) == 1 || (abs(SA1) == 1))
441+
)
448442
{
443+
// We can treat the array A as C-or F-contiguous by changing the order of iteration
444+
445+
if (SA0 < 0){
446+
A_data += (NA0 -1) * SA0; // Jump to first row
447+
SA0 = -SA0; // Pretend row strides is positive
448+
Sz = -Sz; // Iterate over y in reverse;
449+
}
450+
if (SA1 < 0){
451+
A_data += (NA1 -1) * SA1; // Jump to first column
452+
SA1 = -SA1; // Pretend column strides is positive
453+
Sx = -Sx; // Iterate over x in reverse;
454+
}
455+
456+
} else if ((SA0 < 0) || (SA1 < 0) || ((SA0 != 1) && (SA1 != 1))) {
457+
// Array isn't contiguous, we have to make a copy
458+
// - if the copy is too long, maybe call vector/vector dot on
459+
// each row instead
460+
// printf("GEMV: Making a copy SA0=%%d, SA1=%%d\\n", SA0, SA1);
449461
npy_intp dims[2];
450462
dims[0] = NA0;
451463
dims[1] = NA1;
@@ -458,16 +470,17 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
458470
%(A)s = A_copy;
459471
SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : (NA1 + 1);
460472
SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : (NA0 + 1);
473+
A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s);
461474
}
462475
463-
if (PyArray_STRIDES(%(A)s)[0] == elemsize)
476+
if (SA0 == 1)
464477
{
465478
if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
466479
{
467480
float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
468481
sgemv_(&NOTRANS, &NA0, &NA1,
469482
&alpha,
470-
(float*)(PyArray_DATA(%(A)s)), &SA1,
483+
(float*)(A_data), &SA1,
471484
(float*)x_data, &Sx,
472485
&fbeta,
473486
(float*)z_data, &Sz);
@@ -477,7 +490,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
477490
double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
478491
dgemv_(&NOTRANS, &NA0, &NA1,
479492
&alpha,
480-
(double*)(PyArray_DATA(%(A)s)), &SA1,
493+
(double*)(A_data), &SA1,
481494
(double*)x_data, &Sx,
482495
&dbeta,
483496
(double*)z_data, &Sz);
@@ -489,7 +502,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
489502
%(fail)s
490503
}
491504
}
492-
else if (PyArray_STRIDES(%(A)s)[1] == elemsize)
505+
else if (SA1 == 1)
493506
{
494507
if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
495508
{
@@ -506,14 +519,14 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
506519
z_data[0] = 0.f;
507520
}
508521
z_data[0] += alpha*sdot_(&NA1,
509-
(float*)(PyArray_DATA(%(A)s)), &SA1,
522+
(float*)(A_data), &SA1,
510523
(float*)x_data, &Sx);
511524
}
512525
else
513526
{
514527
sgemv_(&TRANS, &NA1, &NA0,
515528
&alpha,
516-
(float*)(PyArray_DATA(%(A)s)), &SA0,
529+
(float*)(A_data), &SA0,
517530
(float*)x_data, &Sx,
518531
&fbeta,
519532
(float*)z_data, &Sz);
@@ -534,14 +547,14 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
534547
z_data[0] = 0.;
535548
}
536549
z_data[0] += alpha*ddot_(&NA1,
537-
(double*)(PyArray_DATA(%(A)s)), &SA1,
550+
(double*)(A_data), &SA1,
538551
(double*)x_data, &Sx);
539552
}
540553
else
541554
{
542555
dgemv_(&TRANS, &NA1, &NA0,
543556
&alpha,
544-
(double*)(PyArray_DATA(%(A)s)), &SA0,
557+
(double*)(A_data), &SA0,
545558
(double*)x_data, &Sx,
546559
&dbeta,
547560
(double*)z_data, &Sz);
@@ -603,7 +616,7 @@ def c_code(self, node, name, inp, out, sub):
603616
return code
604617

605618
def c_code_cache_version(self):
606-
return (14, blas_header_version(), check_force_gemv_init())
619+
return (15, blas_header_version(), check_force_gemv_init())
607620

608621

609622
cgemv_inplace = CGemv(inplace=True)

tests/tensor/test_blas_c.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,3 +411,45 @@ class TestSdotNoFlags(TestCGemvNoFlags):
411411

412412
class TestBlasStridesC(TestBlasStrides):
413413
mode = mode_blas_opt
414+
415+
416+
@pytest.mark.parametrize(
    "neg_stride1", (True, False), ids=["neg_stride1", "pos_stride1"]
)
@pytest.mark.parametrize(
    "neg_stride0", (True, False), ids=["neg_stride0", "pos_stride0"]
)
@pytest.mark.parametrize("F_layout", (True, False), ids=["F_layout", "C_layout"])
def test_gemv_negative_strides_perf(neg_stride0, neg_stride1, F_layout, benchmark):
    """Check and benchmark ``CGemv`` on matrices with negative strides.

    A 512x512 matrix is optionally transposed (``F_layout``) and optionally
    reversed along its first and/or second axis (``neg_stride0`` /
    ``neg_stride1``), so the array handed to the compiled function has the
    requested stride signs. The result is compared against the same function
    evaluated on ``test_A.copy()``, and the call is then timed with the
    ``benchmark`` fixture.
    """
    A = pt.matrix("A", shape=(512, 512))
    # The vector multiplied by A is labelled "x" so that it is not confused
    # with the accumulator vector "y" in graph printouts or keyword lookup.
    x = pt.vector("x", shape=(A.type.shape[-1],))
    y = pt.vector("y", shape=(A.type.shape[0],))

    out = CGemv(inplace=False)(
        y,
        1.0,
        A,
        x,
        1.0,
    )
    fn = pytensor.function([A, x, y], out, trust_input=True)

    rng = np.random.default_rng(430)
    test_A = rng.normal(size=A.type.shape)
    test_x = rng.normal(size=x.type.shape)
    test_y = rng.normal(size=y.type.shape)

    if F_layout:
        test_A = test_A.T
    if neg_stride0:
        test_A = test_A[::-1]
    if neg_stride1:
        test_A = test_A[:, ::-1]
    # Sanity-check that the views built above have the stride signs
    # requested by the parametrization.
    assert (test_A.strides[0] < 0) == neg_stride0
    assert (test_A.strides[1] < 0) == neg_stride1

    # Check result is correct by using a copy of A with positive strides
    res = fn(test_A, test_x, test_y)
    np.testing.assert_allclose(res, fn(test_A.copy(), test_x, test_y))

    benchmark(fn, test_A, test_x, test_y)

0 commit comments

Comments
 (0)