@@ -423,6 +423,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
423
423
int Sz = PyArray_STRIDES(%(z)s)[0] / elemsize;
424
424
int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize;
425
425
426
+ dtype_%(A)s* A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s);
426
427
dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s);
427
428
dtype_%(z)s* z_data = (dtype_%(z)s*) PyArray_DATA(%(z)s);
428
429
// gemv expects pointers to the beginning of memory arrays,
@@ -435,17 +436,28 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
435
436
436
437
if (NA0 * NA1)
437
438
{
438
- // If A is neither C- nor F-contiguous, we make a copy.
439
- // TODO:
440
- // - if one stride is equal to "- elemsize", we can still call
441
- // gemv on reversed matrix and vectors
442
- // - if the copy is too long, maybe call vector/vector dot on
443
- // each row instead
444
- if ((PyArray_STRIDES(%(A)s)[0] < 0)
445
- || (PyArray_STRIDES(%(A)s)[1] < 0)
446
- || ((PyArray_STRIDES(%(A)s)[0] != elemsize)
447
- && (PyArray_STRIDES(%(A)s)[1] != elemsize)))
439
+ if (((SA0 < 0) || (SA1 < 0))
440
+ && (abs(SA0) == 1 || (abs(SA1) == 1))
441
+ )
448
442
{
443
+ // We can treat the array A as C- or F-contiguous by changing the order of iteration
444
+
445
+ if (SA0 < 0){
446
+ A_data += (NA0 -1) * SA0; // Jump to the first row in memory (the last logical row, since SA0 < 0)
447
+ SA0 = -SA0; // Pretend the row stride is positive
448
+ Sz = -Sz; // Iterate over y in reverse;
449
+ }
450
+ if (SA1 < 0){
451
+ A_data += (NA1 -1) * SA1; // Jump to the first column in memory (the last logical column, since SA1 < 0)
452
+ SA1 = -SA1; // Pretend the column stride is positive
453
+ Sx = -Sx; // Iterate over x in reverse;
454
+ }
455
+
456
+ } else if ((SA0 < 0) || (SA1 < 0) || ((SA0 != 1) && (SA1 != 1))) {
457
+ // Array isn't contiguous, we have to make a copy
458
+ // - if the copy is too long, maybe call vector/vector dot on
459
+ // each row instead
460
+ // printf("GEMV: Making a copy SA0=%%d, SA1=%%d\\n", SA0, SA1);
449
461
npy_intp dims[2];
450
462
dims[0] = NA0;
451
463
dims[1] = NA1;
@@ -458,16 +470,17 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
458
470
%(A)s = A_copy;
459
471
SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : (NA1 + 1);
460
472
SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : (NA0 + 1);
473
+ A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s);
461
474
}
462
475
463
- if (PyArray_STRIDES(%(A)s)[0] == elemsize )
476
+ if (SA0 == 1 )
464
477
{
465
478
if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
466
479
{
467
480
float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
468
481
sgemv_(&NOTRANS, &NA0, &NA1,
469
482
&alpha,
470
- (float*)(PyArray_DATA(%(A)s) ), &SA1,
483
+ (float*)(A_data ), &SA1,
471
484
(float*)x_data, &Sx,
472
485
&fbeta,
473
486
(float*)z_data, &Sz);
@@ -477,7 +490,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
477
490
double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
478
491
dgemv_(&NOTRANS, &NA0, &NA1,
479
492
&alpha,
480
- (double*)(PyArray_DATA(%(A)s) ), &SA1,
493
+ (double*)(A_data ), &SA1,
481
494
(double*)x_data, &Sx,
482
495
&dbeta,
483
496
(double*)z_data, &Sz);
@@ -489,7 +502,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
489
502
%(fail)s
490
503
}
491
504
}
492
- else if (PyArray_STRIDES(%(A)s)[1] == elemsize )
505
+ else if (SA1 == 1 )
493
506
{
494
507
if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
495
508
{
@@ -506,14 +519,14 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
506
519
z_data[0] = 0.f;
507
520
}
508
521
z_data[0] += alpha*sdot_(&NA1,
509
- (float*)(PyArray_DATA(%(A)s) ), &SA1,
522
+ (float*)(A_data ), &SA1,
510
523
(float*)x_data, &Sx);
511
524
}
512
525
else
513
526
{
514
527
sgemv_(&TRANS, &NA1, &NA0,
515
528
&alpha,
516
- (float*)(PyArray_DATA(%(A)s) ), &SA0,
529
+ (float*)(A_data ), &SA0,
517
530
(float*)x_data, &Sx,
518
531
&fbeta,
519
532
(float*)z_data, &Sz);
@@ -534,14 +547,14 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
534
547
z_data[0] = 0.;
535
548
}
536
549
z_data[0] += alpha*ddot_(&NA1,
537
- (double*)(PyArray_DATA(%(A)s) ), &SA1,
550
+ (double*)(A_data ), &SA1,
538
551
(double*)x_data, &Sx);
539
552
}
540
553
else
541
554
{
542
555
dgemv_(&TRANS, &NA1, &NA0,
543
556
&alpha,
544
- (double*)(PyArray_DATA(%(A)s) ), &SA0,
557
+ (double*)(A_data ), &SA0,
545
558
(double*)x_data, &Sx,
546
559
&dbeta,
547
560
(double*)z_data, &Sz);
@@ -603,7 +616,7 @@ def c_code(self, node, name, inp, out, sub):
603
616
return code
604
617
605
618
def c_code_cache_version (self ):
606
- return (14 , blas_header_version (), check_force_gemv_init ())
619
+ return (15 , blas_header_version (), check_force_gemv_init ())
607
620
608
621
609
622
cgemv_inplace = CGemv (inplace = True )
0 commit comments