@@ -244,6 +244,7 @@ def _get_result_shape(x1, x2, out, func, np_flag):
244
244
x1 , x2 , x1_ndim , x2_ndim
245
245
)
246
246
else : # func == "vecdot"
247
+ assert func == "vecdot"
247
248
x1 , x2 , result_shape = _get_result_shape_vecdot (
248
249
x1 , x2 , x1_ndim , x2_ndim
249
250
)
@@ -466,11 +467,15 @@ def _gemm_matmul(exec_q, x1, x2, res):
466
467
467
468
468
469
def _shape_error (shape1 , shape2 , func , err_msg ):
470
+ """Validate the shapes of input and output arrays."""
469
471
470
472
if func == "matmul" :
471
473
signature = "(n?,k),(k,m?)->(n?,m?)"
472
- else : # func == "vecdot"
474
+ elif func == "vecdot" :
473
475
signature = "(n?,),(n?,)->()"
476
+ else :
477
+ # applicable when err_msg == 3
478
+ assert func is None
474
479
475
480
if err_msg == 0 :
476
481
raise ValueError (
@@ -485,7 +490,8 @@ def _shape_error(shape1, shape2, func, err_msg):
485
490
f"array has shape { shape2 } . "
486
491
f"These cannot be broadcast together for '{ func } ' function."
487
492
)
488
- elif err_msg == 2 :
493
+ else : # err_msg == 2:
494
+ assert err_msg == 2
489
495
raise ValueError (
490
496
f"Expected output array of shape { shape1 } , but got { shape2 } ."
491
497
)
@@ -557,6 +563,7 @@ def _validate_internal(axes, i, ndim):
557
563
x1_ndim = x1 .ndim
558
564
x2_ndim = x2 .ndim
559
565
else : # func == "vecdot"
566
+ assert func == "vecdot"
560
567
x1_ndim = x2_ndim = 1
561
568
562
569
axes [0 ] = _validate_internal (axes [0 ], 0 , x1_ndim )
@@ -573,6 +580,16 @@ def _validate_internal(axes, i, ndim):
573
580
return axes
574
581
575
582
583
+ def _validate_out_array (out , exec_q ):
584
+ """Validate out is supported array and has correct queue."""
585
+ if out is not None :
586
+ dpnp .check_supported_arrays_type (out )
587
+ if dpctl .utils .get_execution_queue ((exec_q , out .sycl_queue )) is None :
588
+ raise ExecutionPlacementError (
589
+ "Input and output allocation queues are not compatible"
590
+ )
591
+
592
+
576
593
def dpnp_cross (a , b , cp ):
577
594
"""Return the cross product of two (arrays of) vectors."""
578
595
@@ -660,13 +677,7 @@ def dpnp_dot(a, b, /, out=None, *, conjugate=False):
660
677
)
661
678
662
679
res_usm_type , exec_q = get_usm_allocations ([a , b ])
663
- if (
664
- out is not None
665
- and dpctl .utils .get_execution_queue ((exec_q , out .sycl_queue )) is None
666
- ):
667
- raise ExecutionPlacementError (
668
- "Input and output allocation queues are not compatible"
669
- )
680
+ _validate_out_array (out , exec_q )
670
681
671
682
# Determine the appropriate data types
672
683
dot_dtype , res_dtype = _compute_res_dtype (a , b , sycl_queue = exec_q )
@@ -755,19 +766,17 @@ def dpnp_matmul(
755
766
756
767
dpnp .check_supported_arrays_type (x1 , x2 )
757
768
res_usm_type , exec_q = get_usm_allocations ([x1 , x2 ])
758
- if out is not None :
759
- dpnp .check_supported_arrays_type (out )
760
- if dpctl .utils .get_execution_queue ((exec_q , out .sycl_queue )) is None :
761
- raise ExecutionPlacementError (
762
- "Input and output allocation queues are not compatible"
763
- )
769
+ _validate_out_array (out , exec_q )
764
770
765
- if order in [ "a" , "A" ] :
771
+ if order in "aA" :
766
772
if x1 .flags .fnc and x2 .flags .fnc :
767
773
order = "F"
768
774
else :
769
775
order = "C"
770
776
777
+ if order in "kK" :
778
+ order = "C"
779
+
771
780
x1_ndim = x1 .ndim
772
781
x2_ndim = x2 .ndim
773
782
if axes is not None :
@@ -938,6 +947,7 @@ def dpnp_matmul(
938
947
result ,
939
948
)
940
949
else : # call_flag == "gemm_batch"
950
+ assert call_flag == "gemm_batch"
941
951
result = _gemm_batch_matmul (
942
952
exec_q ,
943
953
x1 ,
@@ -962,14 +972,7 @@ def dpnp_matmul(
962
972
result = dpnp .moveaxis (result , (- 1 ,), axes_res )
963
973
return dpnp .ascontiguousarray (result )
964
974
965
- # If `order` was not passed as default
966
- # we need to update it to match the passed `order`.
967
- if order not in ["k" , "K" ]:
968
- return dpnp .asarray (result , order = order )
969
- # dpnp.ascontiguousarray changes 0-D array to 1-D array
970
- if result .ndim == 0 :
971
- return result
972
- return dpnp .ascontiguousarray (result )
975
+ return dpnp .asarray (result , order = order )
973
976
974
977
result = dpnp .get_result_array (result , out , casting = casting )
975
978
if axes is not None and out is result :
@@ -994,14 +997,9 @@ def dpnp_vecdot(
994
997
995
998
dpnp .check_supported_arrays_type (x1 , x2 )
996
999
res_usm_type , exec_q = get_usm_allocations ([x1 , x2 ])
997
- if out is not None :
998
- dpnp .check_supported_arrays_type (out )
999
- if dpctl .utils .get_execution_queue ((exec_q , out .sycl_queue )) is None :
1000
- raise ExecutionPlacementError (
1001
- "Input and output allocation queues are not compatible"
1002
- )
1000
+ _validate_out_array (out , exec_q )
1003
1001
1004
- if order in [ "a" , "A" ] :
1002
+ if order in "aAkK" :
1005
1003
if x1 .flags .fnc and x2 .flags .fnc :
1006
1004
order = "F"
1007
1005
else :
@@ -1048,7 +1046,7 @@ def dpnp_vecdot(
1048
1046
_ , x1_is_1D , _ = _define_dim_flags (x1 , axis = - 1 )
1049
1047
_ , x2_is_1D , _ = _define_dim_flags (x2 , axis = - 1 )
1050
1048
1051
- if numpy . prod ( result_shape ) == 0 or x1 .size == 0 or x2 .size == 0 :
1049
+ if x1 .size == 0 or x2 .size == 0 :
1052
1050
order = "C" if order in "kK" else order
1053
1051
result = _create_result_array (
1054
1052
x1 ,
@@ -1060,8 +1058,9 @@ def dpnp_vecdot(
1060
1058
sycl_queue = exec_q ,
1061
1059
order = order ,
1062
1060
)
1063
- if x1 .size == 0 or x2 .size == 0 :
1064
- result .fill (0 )
1061
+ if numpy .prod (result_shape ) == 0 :
1062
+ return result
1063
+ result .fill (0 )
1065
1064
return result
1066
1065
elif x1_is_1D and x2_is_1D :
1067
1066
call_flag = "dot"
@@ -1079,6 +1078,7 @@ def dpnp_vecdot(
1079
1078
else :
1080
1079
result = dpnp_dot (x1 , x2 , out = out , conjugate = True )
1081
1080
else : # call_flag == "vecdot"
1081
+ assert call_flag == "vecdot"
1082
1082
x1_usm = dpnp .get_usm_ndarray (x1 )
1083
1083
x2_usm = dpnp .get_usm_ndarray (x2 )
1084
1084
result = dpnp_array ._create_from_usm_ndarray (
@@ -1091,13 +1091,6 @@ def dpnp_vecdot(
1091
1091
result = dpnp .reshape (result , result_shape )
1092
1092
1093
1093
if out is None :
1094
- # If `order` was not passed as default
1095
- # we need to update it to match the passed `order`.
1096
- if order not in "kK" :
1097
- return dpnp .asarray (result , order = order )
1098
- # dpnp.ascontiguousarray changes 0-D array to 1-D array
1099
- if result .ndim == 0 :
1100
- return result
1101
- return dpnp .ascontiguousarray (result )
1094
+ return dpnp .asarray (result , order = order )
1102
1095
1103
1096
return dpnp .get_result_array (result , out , casting = casting )
0 commit comments