Skip to content

Commit a9e76ef

Browse files
authored
Merge branch 'master' into update-tests-part-2
2 parents db0094c + c4997cc commit a9e76ef

File tree

4 files changed

+68
-2
lines changed

4 files changed

+68
-2
lines changed

dpnp/backend/extensions/blas/gemm.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue &,
5555
const std::int64_t,
5656
char *,
5757
const std::int64_t,
58+
#if !defined(USE_ONEMKL_CUBLAS)
5859
const bool,
60+
#endif // !USE_ONEMKL_CUBLAS
5961
const std::vector<sycl::event> &);
6062

6163
static gemm_impl_fn_ptr_t gemm_dispatch_table[dpctl_td_ns::num_types]
@@ -74,7 +76,9 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
7476
const std::int64_t ldb,
7577
char *resultC,
7678
const std::int64_t ldc,
79+
#if !defined(USE_ONEMKL_CUBLAS)
7780
const bool is_row_major,
81+
#endif // !USE_ONEMKL_CUBLAS
7882
const std::vector<sycl::event> &depends)
7983
{
8084
type_utils::validate_type_for_device<Tab>(exec_q);
@@ -236,6 +240,7 @@ std::tuple<sycl::event, sycl::event, bool>
236240
std::int64_t lda;
237241
std::int64_t ldb;
238242

243+
// cuBLAS supports only column-major storage
239244
#if defined(USE_ONEMKL_CUBLAS)
240245
const bool is_row_major = false;
241246

@@ -315,9 +320,15 @@ std::tuple<sycl::event, sycl::event, bool>
315320
const char *b_typeless_ptr = matrixB.get_data();
316321
char *r_typeless_ptr = resultC.get_data();
317322

323+
#if defined(USE_ONEMKL_CUBLAS)
324+
sycl::event gemm_ev =
325+
gemm_fn(exec_q, transA, transB, m, n, k, a_typeless_ptr, lda,
326+
b_typeless_ptr, ldb, r_typeless_ptr, ldc, depends);
327+
#else
318328
sycl::event gemm_ev = gemm_fn(exec_q, transA, transB, m, n, k,
319329
a_typeless_ptr, lda, b_typeless_ptr, ldb,
320330
r_typeless_ptr, ldc, is_row_major, depends);
331+
#endif // USE_ONEMKL_CUBLAS
321332

322333
sycl::event args_ev = dpctl::utils::keep_args_alive(
323334
exec_q, {matrixA, matrixB, resultC}, {gemm_ev});

dpnp/backend/extensions/blas/gemm_batch.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ typedef sycl::event (*gemm_batch_impl_fn_ptr_t)(
6060
const char *,
6161
const char *,
6262
char *,
63+
#if !defined(USE_ONEMKL_CUBLAS)
6364
const bool,
65+
#endif // !USE_ONEMKL_CUBLAS
6466
const std::vector<sycl::event> &);
6567

6668
static gemm_batch_impl_fn_ptr_t
@@ -83,7 +85,9 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q,
8385
const char *matrixA,
8486
const char *matrixB,
8587
char *resultC,
88+
#if !defined(USE_ONEMKL_CUBLAS)
8689
const bool is_row_major,
90+
#endif // !USE_ONEMKL_CUBLAS
8791
const std::vector<sycl::event> &depends)
8892
{
8993
type_utils::validate_type_for_device<Tab>(exec_q);
@@ -311,6 +315,7 @@ std::tuple<sycl::event, sycl::event, bool>
311315
std::int64_t lda;
312316
std::int64_t ldb;
313317

318+
// cuBLAS supports only column-major storage
314319
#if defined(USE_ONEMKL_CUBLAS)
315320
const bool is_row_major = false;
316321

@@ -391,10 +396,17 @@ std::tuple<sycl::event, sycl::event, bool>
391396
const char *b_typeless_ptr = matrixB.get_data();
392397
char *r_typeless_ptr = resultC.get_data();
393398

399+
#if defined(USE_ONEMKL_CUBLAS)
400+
sycl::event gemm_batch_ev =
401+
gemm_batch_fn(exec_q, m, n, k, batch_size, lda, ldb, ldc, stridea,
402+
strideb, stridec, transA, transB, a_typeless_ptr,
403+
b_typeless_ptr, r_typeless_ptr, depends);
404+
#else
394405
sycl::event gemm_batch_ev =
395406
gemm_batch_fn(exec_q, m, n, k, batch_size, lda, ldb, ldc, stridea,
396407
strideb, stridec, transA, transB, a_typeless_ptr,
397408
b_typeless_ptr, r_typeless_ptr, is_row_major, depends);
409+
#endif // USE_ONEMKL_CUBLAS
398410

399411
sycl::event args_ev = dpctl::utils::keep_args_alive(
400412
exec_q, {matrixA, matrixB, resultC}, {gemm_batch_ev});

dpnp/backend/extensions/blas/gemv.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ typedef sycl::event (*gemv_impl_fn_ptr_t)(sycl::queue &,
5353
const std::int64_t,
5454
char *,
5555
const std::int64_t,
56+
#if !defined(USE_ONEMKL_CUBLAS)
5657
const bool,
58+
#endif // !USE_ONEMKL_CUBLAS
5759
const std::vector<sycl::event> &);
5860

5961
static gemv_impl_fn_ptr_t gemv_dispatch_vector[dpctl_td_ns::num_types];
@@ -69,7 +71,9 @@ static sycl::event gemv_impl(sycl::queue &exec_q,
6971
const std::int64_t incx,
7072
char *vectorY,
7173
const std::int64_t incy,
74+
#if !defined(USE_ONEMKL_CUBLAS)
7275
const bool is_row_major,
76+
#endif // !USE_ONEMKL_CUBLAS
7377
const std::vector<sycl::event> &depends)
7478
{
7579
type_utils::validate_type_for_device<T>(exec_q);
@@ -190,6 +194,7 @@ std::pair<sycl::event, sycl::event>
190194
oneapi::mkl::transpose transA;
191195
std::size_t src_nelems;
192196

197+
// cuBLAS supports only column-major storage
193198
#if defined(USE_ONEMKL_CUBLAS)
194199
const bool is_row_major = false;
195200
std::int64_t m;
@@ -299,9 +304,15 @@ std::pair<sycl::event, sycl::event>
299304
y_typeless_ptr -= (y_shape[0] - 1) * std::abs(incy) * y_elemsize;
300305
}
301306

307+
#if defined(USE_ONEMKL_CUBLAS)
308+
sycl::event gemv_ev =
309+
gemv_fn(exec_q, transA, m, n, a_typeless_ptr, lda, x_typeless_ptr, incx,
310+
y_typeless_ptr, incy, depends);
311+
#else
302312
sycl::event gemv_ev =
303313
gemv_fn(exec_q, transA, m, n, a_typeless_ptr, lda, x_typeless_ptr, incx,
304314
y_typeless_ptr, incy, is_row_major, depends);
315+
#endif // USE_ONEMKL_CUBLAS
305316

306317
sycl::event args_ev = dpctl::utils::keep_args_alive(
307318
exec_q, {matrixA, vectorX, vectorY}, {gemv_ev});

dpnp/tests/third_party/cupy/sorting_tests/test_search.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,22 @@ def test_nanargmin_zero_size_axis1(self, xp, dtype):
532532
a = testing.shaped_random((0, 1), xp, dtype)
533533
return xp.nanargmin(a, axis=1)
534534

535+
@testing.for_all_dtypes(no_complex=True)
@testing.numpy_cupy_allclose()
def test_nanargmin_out_float_dtype(self, xp, dtype):
    # nanargmin along axis 1 of a 1x1 float array, with the result written
    # into a preallocated one-element int64 array passed as ``out``.
    # The ``dtype`` argument is accepted but not referenced by the body.
    source = xp.array([[0.0]])
    result = xp.empty(1, dtype="int64")
    xp.nanargmin(source, axis=1, out=result)
    # The ``out`` array itself is returned, not nanargmin's return value.
    return result
542+
543+
@testing.for_all_dtypes(no_complex=True)
@testing.numpy_cupy_array_equal()
def test_nanargmin_out_int_dtype(self, xp, dtype):
    # nanargmin over a flat two-element integer array (no axis given), with
    # the result written into a preallocated zero-dimensional int64 array.
    # The ``dtype`` argument is accepted but not referenced by the body.
    values = xp.array([1, 0])
    target = xp.empty((), dtype="int64")
    xp.nanargmin(values, out=target)
    # The ``out`` array itself is returned, not nanargmin's return value.
    return target
550+
535551

536552
class TestNanArgMax:
537553

@@ -623,6 +639,22 @@ def test_nanargmax_zero_size_axis1(self, xp, dtype):
623639
a = testing.shaped_random((0, 1), xp, dtype)
624640
return xp.nanargmax(a, axis=1)
625641

642+
@testing.for_all_dtypes(no_complex=True)
@testing.numpy_cupy_allclose()
def test_nanargmax_out_float_dtype(self, xp, dtype):
    # nanargmax along axis 1 of a 1x1 float array, with the result written
    # into a preallocated one-element int64 array passed as ``out``.
    # The ``dtype`` argument is accepted but not referenced by the body.
    source = xp.array([[0.0]])
    result = xp.empty(1, dtype="int64")
    xp.nanargmax(source, axis=1, out=result)
    # The ``out`` array itself is returned, not nanargmax's return value.
    return result
649+
650+
@testing.for_all_dtypes(no_complex=True)
@testing.numpy_cupy_array_equal()
def test_nanargmax_out_int_dtype(self, xp, dtype):
    # nanargmax over a flat two-element integer array (no axis given), with
    # the result written into a preallocated zero-dimensional int64 array.
    # The ``dtype`` argument is accepted but not referenced by the body.
    values = xp.array([0, 1])
    target = xp.empty((), dtype="int64")
    xp.nanargmax(values, out=target)
    # The ``out`` array itself is returned, not nanargmax's return value.
    return target
657+
626658

627659
@testing.parameterize(
628660
*testing.product(
@@ -771,7 +803,7 @@ def test_invalid_sorter(self):
771803

772804
def test_nonint_sorter(self):
773805
for xp in (numpy, cupy):
774-
x = testing.shaped_arange((12,), xp, xp.float32)
806+
x = testing.shaped_arange((12,), xp, xp.float64)
775807
bins = xp.array([10, 4, 2, 1, 8])
776808
sorter = xp.array([], dtype=xp.float32)
777809
with pytest.raises((TypeError, ValueError)):
@@ -865,7 +897,7 @@ def test_invalid_sorter(self):
865897

866898
def test_nonint_sorter(self):
867899
for xp in (numpy, cupy):
868-
x = testing.shaped_arange((12,), xp, xp.float32)
900+
x = testing.shaped_arange((12,), xp, xp.float64)
869901
bins = xp.array([10, 4, 2, 1, 8])
870902
sorter = xp.array([], dtype=xp.float32)
871903
with pytest.raises((TypeError, ValueError)):

0 commit comments

Comments
 (0)