Skip to content

Commit 8fb9d9a

Browse files
authored
Merge pull request #5553 from rebecca-palmer/samplingdot_usmm_fixes
Fix invalid pointer casts and negative stride handling in sparse
2 parents 596daba + ff969f3 commit 8fb9d9a

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

theano/sparse/opt.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,10 @@ def c_code(self, node, name, inputs, outputs, sub):
831831
npy_intp Sptr = PyArray_STRIDES(%(x_ptr)s)[0] / PyArray_DESCR(%(x_ptr)s)->elsize;
832832
npy_intp Sy = PyArray_STRIDES(%(y)s)[1] / PyArray_DESCR(%(y)s)->elsize;
833833
834+
// blas expects ints; convert here (rather than just making N etc ints) to avoid potential overflow in the negative-stride correction
835+
int N32 = N;
836+
int Sy32 = Sy;
837+
int Szn32 = Szn;
834838
835839
if (!(%(inplace)s))
836840
{
@@ -860,7 +864,7 @@ def c_code(self, node, name, inputs, outputs, sub):
860864
if (Szn < 0)
861865
z_row += (N - 1) * Szn;
862866
863-
%(axpy)s((int*)&N, (%(conv_type)s*)&Amk, (%(conv_type)s*)y_row, (int*)&Sy, (%(conv_type)s*)z_row, (int*)&Szn);
867+
%(axpy)s(&N32, (%(conv_type)s*)&Amk, (%(conv_type)s*)y_row, &Sy32, (%(conv_type)s*)z_row, &Szn32);
864868
}
865869
}
866870
}
@@ -869,7 +873,7 @@ def c_code(self, node, name, inputs, outputs, sub):
869873
return rval
870874

871875
def c_code_cache_version(self):
872-
return (1, blas.blas_header_version())
876+
return (2, blas.blas_header_version())
873877
usmm_csc_dense = UsmmCscDense(inplace=False)
874878
usmm_csc_dense_inplace = UsmmCscDense(inplace=True)
875879

@@ -1749,7 +1753,7 @@ def make_node(self, x, y, p_data, p_ind, p_ptr, p_ncols):
17491753
])
17501754

17511755
def c_code_cache_version(self):
1752-
return (2, blas.blas_header_version())
1756+
return (3, blas.blas_header_version())
17531757

17541758
def c_support_code(self):
17551759
return blas.blas_header_text()
@@ -1892,15 +1896,27 @@ def c_code(self, node, name, inputs, outputs, sub):
18921896
memcpy(Dzi, Dpi, PyArray_DIMS(%(p_ind)s)[0]*sizeof(dtype_%(p_ind)s));
18931897
memcpy(Dzp, Dpp, PyArray_DIMS(%(p_ptr)s)[0]*sizeof(dtype_%(p_ptr)s));
18941898
1899+
// blas expects ints; convert here (rather than just making K etc ints) to avoid potential overflow in the negative-stride correction
1900+
int K32 = K;
1901+
int Sdx32 = Sdx;
1902+
int Sdy32 = Sdy;
1903+
18951904
for (npy_int32 m = 0; m < M; ++m) {
18961905
for (npy_int32 n_idx = Dpp[m * Sdpp]; n_idx < Dpp[(m+1)*Sdpp]; ++n_idx) {
18971906
const npy_int32 n = Dpi[n_idx * Sdpi]; // row index of non-null value for column K
18981907
18991908
const dtype_%(x)s* x_row = (dtype_%(x)s*)(PyArray_BYTES(%(x)s) + PyArray_STRIDES(%(x)s)[0] * m);
19001909
19011910
const dtype_%(y)s* y_col = (dtype_%(y)s*)(PyArray_BYTES(%(y)s) + PyArray_STRIDES(%(y)s)[0] * n);
1911+
// dot expects pointer to the beginning of memory arrays,
1912+
// so when the stride is negative, we need to get the
1913+
// last element
1914+
if (Sdx < 0)
1915+
x_row += (K - 1) * Sdx;
1916+
if (Sdy < 0)
1917+
y_col += (K - 1) * Sdy;
19021918
1903-
Dzd[n_idx * Sdzd] = Dpd[n_idx * Sdpd] * %(cdot)s((int*)&K, (const %(conv_type)s*)x_row, (int*)&Sdx, (const %(conv_type)s*)y_col, (int*)&Sdy);
1919+
Dzd[n_idx * Sdzd] = Dpd[n_idx * Sdpd] * %(cdot)s(&K32, (const %(conv_type)s*)x_row, &Sdx32, (const %(conv_type)s*)y_col, &Sdy32);
19041920
}
19051921
}
19061922
}

theano/sparse/tests/test_basic.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3172,6 +3172,20 @@ def test_op(self):
31723172
assert tested.format == 'csr'
31733173
assert tested.dtype == expected.dtype
31743174

3175+
def test_negative_stride(self):
3176+
f = theano.function(
3177+
self.x,
3178+
sampling_dot(*self.x))
3179+
3180+
a2 = [self.a[0][::-1,:], self.a[1][:,::-1], self.a[2]]
3181+
tested = f(*a2)
3182+
x, y, p = a2
3183+
expected = p.multiply(np.dot(x, y.T))
3184+
3185+
utt.assert_allclose(as_ndarray(expected), tested.toarray())
3186+
assert tested.format == 'csr'
3187+
assert tested.dtype == expected.dtype
3188+
31753189
def test_infer_shape(self):
31763190
self._compile_and_check(self.x,
31773191
[sampling_dot(*self.x)],

0 commit comments

Comments
 (0)