Skip to content

Commit dd22aae

Browse files
committed
improve coverage
1 parent efe804b commit dd22aae

File tree

3 files changed

+66
-54
lines changed

3 files changed

+66
-54
lines changed

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 35 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ def _get_result_shape(x1, x2, out, func, np_flag):
244244
x1, x2, x1_ndim, x2_ndim
245245
)
246246
else: # func == "vecdot"
247+
assert func == "vecdot"
247248
x1, x2, result_shape = _get_result_shape_vecdot(
248249
x1, x2, x1_ndim, x2_ndim
249250
)
@@ -466,11 +467,15 @@ def _gemm_matmul(exec_q, x1, x2, res):
466467

467468

468469
def _shape_error(shape1, shape2, func, err_msg):
470+
"""Validate the shapes of input and output arrays."""
469471

470472
if func == "matmul":
471473
signature = "(n?,k),(k,m?)->(n?,m?)"
472-
else: # func == "vecdot"
474+
elif func == "vecdot":
473475
signature = "(n?,),(n?,)->()"
476+
else:
477+
# applicable when err_msg == 3
478+
assert func is None
474479

475480
if err_msg == 0:
476481
raise ValueError(
@@ -485,7 +490,8 @@ def _shape_error(shape1, shape2, func, err_msg):
485490
f"array has shape {shape2}. "
486491
f"These cannot be broadcast together for '{func}' function."
487492
)
488-
elif err_msg == 2:
493+
else: # err_msg == 2:
494+
assert err_msg == 2
489495
raise ValueError(
490496
f"Expected output array of shape {shape1}, but got {shape2}."
491497
)
@@ -557,6 +563,7 @@ def _validate_internal(axes, i, ndim):
557563
x1_ndim = x1.ndim
558564
x2_ndim = x2.ndim
559565
else: # func == "vecdot"
566+
assert func == "vecdot"
560567
x1_ndim = x2_ndim = 1
561568

562569
axes[0] = _validate_internal(axes[0], 0, x1_ndim)
@@ -573,6 +580,16 @@ def _validate_internal(axes, i, ndim):
573580
return axes
574581

575582

583+
def _validate_out_array(out, exec_q):
584+
"""Validate out is supported array and has correct queue."""
585+
if out is not None:
586+
dpnp.check_supported_arrays_type(out)
587+
if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None:
588+
raise ExecutionPlacementError(
589+
"Input and output allocation queues are not compatible"
590+
)
591+
592+
576593
def dpnp_cross(a, b, cp):
577594
"""Return the cross product of two (arrays of) vectors."""
578595

@@ -660,13 +677,7 @@ def dpnp_dot(a, b, /, out=None, *, conjugate=False):
660677
)
661678

662679
res_usm_type, exec_q = get_usm_allocations([a, b])
663-
if (
664-
out is not None
665-
and dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None
666-
):
667-
raise ExecutionPlacementError(
668-
"Input and output allocation queues are not compatible"
669-
)
680+
_validate_out_array(out, exec_q)
670681

671682
# Determine the appropriate data types
672683
dot_dtype, res_dtype = _compute_res_dtype(a, b, sycl_queue=exec_q)
@@ -755,19 +766,17 @@ def dpnp_matmul(
755766

756767
dpnp.check_supported_arrays_type(x1, x2)
757768
res_usm_type, exec_q = get_usm_allocations([x1, x2])
758-
if out is not None:
759-
dpnp.check_supported_arrays_type(out)
760-
if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None:
761-
raise ExecutionPlacementError(
762-
"Input and output allocation queues are not compatible"
763-
)
769+
_validate_out_array(out, exec_q)
764770

765-
if order in ["a", "A"]:
771+
if order in "aA":
766772
if x1.flags.fnc and x2.flags.fnc:
767773
order = "F"
768774
else:
769775
order = "C"
770776

777+
if order in "kK":
778+
order = "C"
779+
771780
x1_ndim = x1.ndim
772781
x2_ndim = x2.ndim
773782
if axes is not None:
@@ -938,6 +947,7 @@ def dpnp_matmul(
938947
result,
939948
)
940949
else: # call_flag == "gemm_batch"
950+
assert call_flag == "gemm_batch"
941951
result = _gemm_batch_matmul(
942952
exec_q,
943953
x1,
@@ -962,14 +972,7 @@ def dpnp_matmul(
962972
result = dpnp.moveaxis(result, (-1,), axes_res)
963973
return dpnp.ascontiguousarray(result)
964974

965-
# If `order` was not passed as default
966-
# we need to update it to match the passed `order`.
967-
if order not in ["k", "K"]:
968-
return dpnp.asarray(result, order=order)
969-
# dpnp.ascontiguousarray changes 0-D array to 1-D array
970-
if result.ndim == 0:
971-
return result
972-
return dpnp.ascontiguousarray(result)
975+
return dpnp.asarray(result, order=order)
973976

974977
result = dpnp.get_result_array(result, out, casting=casting)
975978
if axes is not None and out is result:
@@ -994,14 +997,9 @@ def dpnp_vecdot(
994997

995998
dpnp.check_supported_arrays_type(x1, x2)
996999
res_usm_type, exec_q = get_usm_allocations([x1, x2])
997-
if out is not None:
998-
dpnp.check_supported_arrays_type(out)
999-
if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None:
1000-
raise ExecutionPlacementError(
1001-
"Input and output allocation queues are not compatible"
1002-
)
1000+
_validate_out_array(out, exec_q)
10031001

1004-
if order in ["a", "A"]:
1002+
if order in "aAkK":
10051003
if x1.flags.fnc and x2.flags.fnc:
10061004
order = "F"
10071005
else:
@@ -1048,7 +1046,7 @@ def dpnp_vecdot(
10481046
_, x1_is_1D, _ = _define_dim_flags(x1, axis=-1)
10491047
_, x2_is_1D, _ = _define_dim_flags(x2, axis=-1)
10501048

1051-
if numpy.prod(result_shape) == 0 or x1.size == 0 or x2.size == 0:
1049+
if x1.size == 0 or x2.size == 0:
10521050
order = "C" if order in "kK" else order
10531051
result = _create_result_array(
10541052
x1,
@@ -1060,8 +1058,9 @@ def dpnp_vecdot(
10601058
sycl_queue=exec_q,
10611059
order=order,
10621060
)
1063-
if x1.size == 0 or x2.size == 0:
1064-
result.fill(0)
1061+
if numpy.prod(result_shape) == 0:
1062+
return result
1063+
result.fill(0)
10651064
return result
10661065
elif x1_is_1D and x2_is_1D:
10671066
call_flag = "dot"
@@ -1079,6 +1078,7 @@ def dpnp_vecdot(
10791078
else:
10801079
result = dpnp_dot(x1, x2, out=out, conjugate=True)
10811080
else: # call_flag == "vecdot"
1081+
assert call_flag == "vecdot"
10821082
x1_usm = dpnp.get_usm_ndarray(x1)
10831083
x2_usm = dpnp.get_usm_ndarray(x2)
10841084
result = dpnp_array._create_from_usm_ndarray(
@@ -1091,13 +1091,6 @@ def dpnp_vecdot(
10911091
result = dpnp.reshape(result, result_shape)
10921092

10931093
if out is None:
1094-
# If `order` was not passed as default
1095-
# we need to update it to match the passed `order`.
1096-
if order not in "kK":
1097-
return dpnp.asarray(result, order=order)
1098-
# dpnp.ascontiguousarray changes 0-D array to 1-D array
1099-
if result.ndim == 0:
1100-
return result
1101-
return dpnp.ascontiguousarray(result)
1094+
return dpnp.asarray(result, order=order)
11021095

11031096
return dpnp.get_result_array(result, out, casting=casting)

tests/test_mathematical.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3648,6 +3648,8 @@ def test_matmul_dtype_matrix_inputs(self, dtype1, dtype2, shape_pair):
36483648
expected = numpy.matmul(a1, a2)
36493649
assert_dtype_allclose(result, expected)
36503650

3651+
@pytest.mark.parametrize("order1", ["C", "F", "A"])
3652+
@pytest.mark.parametrize("order2", ["C", "F", "A"])
36513653
@pytest.mark.parametrize("order", ["C", "F", "K", "A"])
36523654
@pytest.mark.parametrize(
36533655
"shape_pair",
@@ -3662,17 +3664,26 @@ def test_matmul_dtype_matrix_inputs(self, dtype1, dtype2, shape_pair):
36623664
"((6, 7, 4, 3), (6, 7, 3, 5))",
36633665
],
36643666
)
3665-
def test_matmul_order(self, order, shape_pair):
3667+
def test_matmul_order(self, order1, order2, order, shape_pair):
36663668
shape1, shape2 = shape_pair
3667-
a1 = numpy.arange(numpy.prod(shape1)).reshape(shape1)
3668-
a2 = numpy.arange(numpy.prod(shape2)).reshape(shape2)
3669+
a1 = numpy.arange(numpy.prod(shape1)).reshape(shape1, order=order1)
3670+
a2 = numpy.arange(numpy.prod(shape2)).reshape(shape2, order=order2)
36693671

36703672
b1 = dpnp.asarray(a1)
36713673
b2 = dpnp.asarray(a2)
36723674

36733675
result = dpnp.matmul(b1, b2, order=order)
36743676
expected = numpy.matmul(a1, a2, order=order)
3675-
assert result.flags.c_contiguous == expected.flags.c_contiguous
3677+
# For the special case of shape_pair == ((6, 7, 4, 3), (6, 7, 3, 5))
3678+
# and order1 == "F" and order2 == "F", NumPy result is not c-contiguous
3679+
# nor f-contiguous, while dpnp (and cupy) results are c-contiguous
3680+
if not (
3681+
shape_pair == ((6, 7, 4, 3), (6, 7, 3, 5))
3682+
and order1 == "F"
3683+
and order2 == "F"
3684+
and order == "K"
3685+
):
3686+
assert result.flags.c_contiguous == expected.flags.c_contiguous
36763687
assert result.flags.f_contiguous == expected.flags.f_contiguous
36773688
assert_dtype_allclose(result, expected)
36783689

tests/test_product.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1343,6 +1343,7 @@ def setup_method(self):
13431343
"shape_pair",
13441344
[
13451345
((4,), (4,)), # call_flag: dot
1346+
((1, 1, 4), (1, 1, 4)), # call_flag: dot
13461347
((3, 1), (3, 1)),
13471348
((2, 0), (2, 0)), # zero-size inputs, 1D output
13481349
((3, 0, 4), (3, 0, 4)), # zero-size output
@@ -1353,6 +1354,7 @@ def setup_method(self):
13531354
((3, 4), (4,)),
13541355
((1, 4, 5), (3, 1, 5)),
13551356
((1, 1, 4, 5), (3, 1, 5)),
1357+
((1, 4, 5), (1, 3, 1, 5)),
13561358
],
13571359
)
13581360
def test_basic(self, dtype, shape_pair):
@@ -1375,6 +1377,7 @@ def test_basic(self, dtype, shape_pair):
13751377
"shape_pair",
13761378
[
13771379
((4,), (4,)), # call_flag: dot
1380+
((1, 1, 4), (1, 1, 4)), # call_flag: dot
13781381
((3, 1), (3, 1)),
13791382
((2, 0), (2, 0)), # zero-size inputs, 1D output
13801383
((3, 0, 4), (3, 0, 4)), # zero-size output
@@ -1385,6 +1388,7 @@ def test_basic(self, dtype, shape_pair):
13851388
((3, 4), (4,)),
13861389
((1, 4, 5), (3, 1, 5)),
13871390
((1, 1, 4, 5), (3, 1, 5)),
1391+
((1, 4, 5), (1, 3, 1, 5)),
13881392
],
13891393
)
13901394
def test_complex(self, dtype, shape_pair):
@@ -1501,18 +1505,22 @@ def test_input_dtype_matrix(self, dtype1, dtype2):
15011505
expected = numpy.vecdot(a, b)
15021506
assert_dtype_allclose(result, expected)
15031507

1508+
@pytest.mark.parametrize("order1", ["C", "F", "A"])
1509+
@pytest.mark.parametrize("order2", ["C", "F", "A"])
15041510
@pytest.mark.parametrize("order", ["C", "F", "K", "A"])
15051511
@pytest.mark.parametrize(
15061512
"shape",
1507-
[((4, 3)), ((4, 3, 5)), ((6, 7, 3, 5))],
1508-
ids=["((4, 3))", "((4, 3, 5))", "((6, 7, 3, 5))"],
1513+
[(4, 3), (4, 3, 5), (6, 7, 3, 5)],
1514+
ids=["(4, 3)", "(4, 3, 5)", "(6, 7, 3, 5)"],
15091515
)
1510-
def test_order(self, order, shape):
1511-
a = numpy.arange(numpy.prod(shape)).reshape(shape)
1512-
b = dpnp.asarray(a)
1513-
1514-
result = dpnp.vecdot(b, b, order=order)
1515-
expected = numpy.vecdot(a, a, order=order)
1516+
def test_order(self, order1, order2, order, shape):
1517+
a = numpy.arange(numpy.prod(shape)).reshape(shape, order=order1)
1518+
b = numpy.arange(numpy.prod(shape)).reshape(shape, order=order2)
1519+
a_dp = dpnp.asarray(a)
1520+
b_dp = dpnp.asarray(b)
1521+
1522+
result = dpnp.vecdot(a_dp, b_dp, order=order)
1523+
expected = numpy.vecdot(a, b, order=order)
15161524
assert result.flags.c_contiguous == expected.flags.c_contiguous
15171525
assert result.flags.f_contiguous == expected.flags.f_contiguous
15181526
assert_dtype_allclose(result, expected)

0 commit comments

Comments
 (0)