@@ -831,6 +831,10 @@ def c_code(self, node, name, inputs, outputs, sub):
831
831
npy_intp Sptr = PyArray_STRIDES(%(x_ptr)s)[0] / PyArray_DESCR(%(x_ptr)s)->elsize;
832
832
npy_intp Sy = PyArray_STRIDES(%(y)s)[1] / PyArray_DESCR(%(y)s)->elsize;
833
833
834
+ // blas expects ints; convert here (rather than just making N etc ints) to avoid potential overflow in the negative-stride correction
835
+ int N32 = N;
836
+ int Sy32 = Sy;
837
+ int Szn32 = Szn;
834
838
835
839
if (!(%(inplace)s))
836
840
{
@@ -860,7 +864,7 @@ def c_code(self, node, name, inputs, outputs, sub):
860
864
if (Szn < 0)
861
865
z_row += (N - 1) * Szn;
862
866
863
- %(axpy)s((int*)&N , (%(conv_type)s*)&Amk, (%(conv_type)s*)y_row, (int*)&Sy , (%(conv_type)s*)z_row, (int*)&Szn );
867
+ %(axpy)s(&N32 , (%(conv_type)s*)&Amk, (%(conv_type)s*)y_row, &Sy32 , (%(conv_type)s*)z_row, &Szn32 );
864
868
}
865
869
}
866
870
}
@@ -869,7 +873,7 @@ def c_code(self, node, name, inputs, outputs, sub):
869
873
return rval
870
874
871
875
def c_code_cache_version (self ):
872
- return (1 , blas .blas_header_version ())
876
+ return (2 , blas .blas_header_version ())
873
877
usmm_csc_dense = UsmmCscDense (inplace = False )
874
878
usmm_csc_dense_inplace = UsmmCscDense (inplace = True )
875
879
@@ -1749,7 +1753,7 @@ def make_node(self, x, y, p_data, p_ind, p_ptr, p_ncols):
1749
1753
])
1750
1754
1751
1755
def c_code_cache_version (self ):
1752
- return (2 , blas .blas_header_version ())
1756
+ return (3 , blas .blas_header_version ())
1753
1757
1754
1758
def c_support_code (self ):
1755
1759
return blas .blas_header_text ()
@@ -1892,15 +1896,27 @@ def c_code(self, node, name, inputs, outputs, sub):
1892
1896
memcpy(Dzi, Dpi, PyArray_DIMS(%(p_ind)s)[0]*sizeof(dtype_%(p_ind)s));
1893
1897
memcpy(Dzp, Dpp, PyArray_DIMS(%(p_ptr)s)[0]*sizeof(dtype_%(p_ptr)s));
1894
1898
1899
+ // blas expects ints; convert here (rather than just making K etc ints) to avoid potential overflow in the negative-stride correction
1900
+ int K32 = K;
1901
+ int Sdx32 = Sdx;
1902
+ int Sdy32 = Sdy;
1903
+
1895
1904
for (npy_int32 m = 0; m < M; ++m) {
1896
1905
for (npy_int32 n_idx = Dpp[m * Sdpp]; n_idx < Dpp[(m+1)*Sdpp]; ++n_idx) {
1897
1906
const npy_int32 n = Dpi[n_idx * Sdpi]; // row index of non-null value for column K
1898
1907
1899
1908
const dtype_%(x)s* x_row = (dtype_%(x)s*)(PyArray_BYTES(%(x)s) + PyArray_STRIDES(%(x)s)[0] * m);
1900
1909
1901
1910
const dtype_%(y)s* y_col = (dtype_%(y)s*)(PyArray_BYTES(%(y)s) + PyArray_STRIDES(%(y)s)[0] * n);
1911
+ // dot expects pointer to the beginning of memory arrays,
1912
+ // so when the stride is negative, we need to get the
1913
+ // last element
1914
+ if (Sdx < 0)
1915
+ x_row += (K - 1) * Sdx;
1916
+ if (Sdy < 0)
1917
+ y_col += (K - 1) * Sdy;
1902
1918
1903
- Dzd[n_idx * Sdzd] = Dpd[n_idx * Sdpd] * %(cdot)s((int*)&K , (const %(conv_type)s*)x_row, (int*)&Sdx , (const %(conv_type)s*)y_col, (int*)&Sdy );
1919
+ Dzd[n_idx * Sdzd] = Dpd[n_idx * Sdpd] * %(cdot)s(&K32 , (const %(conv_type)s*)x_row, &Sdx32 , (const %(conv_type)s*)y_col, &Sdy32 );
1904
1920
}
1905
1921
}
1906
1922
}
0 commit comments