diff --git a/dpctl/tensor/CMakeLists.txt b/dpctl/tensor/CMakeLists.txt index d2947aa772..e749b40c9d 100644 --- a/dpctl/tensor/CMakeLists.txt +++ b/dpctl/tensor/CMakeLists.txt @@ -156,6 +156,15 @@ set(_tensor_sorting_impl_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_sorting.cpp ${_sorting_sources} ) +set(_linalg_sources + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/elementwise_functions_type_utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linalg_functions/dot.cpp +) +set(_tensor_linalg_impl_sources + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_linalg.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/simplify_iteration_space.cpp + ${_linalg_sources} +) set(_py_trgts) @@ -179,6 +188,11 @@ pybind11_add_module(${python_module_name} MODULE ${_tensor_sorting_impl_sources} add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_sorting_impl_sources}) list(APPEND _py_trgts ${python_module_name}) +set(python_module_name _tensor_linalg_impl) +pybind11_add_module(${python_module_name} MODULE ${_tensor_linalg_impl_sources}) +add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_linalg_impl_sources}) +list(APPEND _py_trgts ${python_module_name}) + set(_clang_prefix "") if (WIN32) set(_clang_prefix "/clang:") @@ -193,6 +207,7 @@ list(APPEND _no_fast_math_sources ${_elementwise_sources} ${_reduction_sources} ${_sorting_sources} + ${_linalg_sources} ) foreach(_src_fn ${_no_fast_math_sources}) diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index 81fc152e7a..ef8e604952 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -60,7 +60,12 @@ from dpctl.tensor._device import Device from dpctl.tensor._dlpack import from_dlpack from dpctl.tensor._indexing_functions import extract, nonzero, place, put, take -from dpctl.tensor._linear_algebra_functions import matrix_transpose +from dpctl.tensor._linear_algebra_functions import ( + matmul, + matrix_transpose, + tensordot, + vecdot, +) from dpctl.tensor._manipulation_functions import ( broadcast_arrays, broadcast_to, @@ -356,4 +361,7 @@ "unique_counts", "unique_inverse", "unique_values", + "matmul", + "tensordot", + "vecdot", ] diff --git a/dpctl/tensor/_linear_algebra_functions.py b/dpctl/tensor/_linear_algebra_functions.py index fd2c58b08a..0894ac2077 100644 --- a/dpctl/tensor/_linear_algebra_functions.py +++ b/dpctl/tensor/_linear_algebra_functions.py @@ -14,7 +14,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +import operator + +from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple + +import dpctl import dpctl.tensor as dpt +import dpctl.tensor._tensor_elementwise_impl as tei +import dpctl.tensor._tensor_impl as ti +import dpctl.tensor._tensor_linalg_impl as tli +from dpctl.tensor._copy_utils import _empty_like_orderK, _empty_like_pair_orderK +from dpctl.tensor._manipulation_functions import _broadcast_shape_impl +from dpctl.tensor._type_utils import ( + _acceptance_fn_default_binary, + _find_buf_dtype2, + _to_device_supported_dtype, +) +from dpctl.utils import ExecutionPlacementError def matrix_transpose(x): @@ -46,3 +62,921 @@ def matrix_transpose(x): ) return x.mT + + +def tensordot(x1, x2, axes=2): + """tensordot(x1, x2, axes=2) + + Returns a tensor contraction of `x1` and `x2` over specific axes. + + Args: + x1 (usm_ndarray): + first input array, expected to have numeric data type. + x2 (usm_ndarray): + second input array, expected to have numeric data type. 
+ Corresponding contracted axes of `x1` and `x2` must be equal. + axes (Union[int, Tuple[Sequence[int], Sequence[int]]): + number of axes to contract or explicit sequences of axes for + `x1` and `x2`, respectively. If `axes` is an integer equal to `N`, + then the contraction is performed over last `N` axes of `x1` and + the first `N` axis of `x2` in order. The size of each corresponding + axis must match and must be non-negative. + * if `N` equals `0`, the result is the tensor outer product + * if `N` equals `1`, the result is the tensor dot product + * if `N` equals `2`, the result is the tensor double + contraction (default). + If `axes` is a tuple of two sequences `(x1_axes, x2_axes)`, the + first sequence applies to `x1` and the second sequence applies + to `x2`. Both sequences must have equal length, and each axis + `x1_axes[i]` for `x1` must have the same size as the respective + axis `x2_axes[i]` for `x2`. Each sequence must consist of unique + non-negative integers that specify valid axes for each respective + array. + Returns: + usm_ndarray: + an array containing the tensor contraction whose shape consists of + the non-contracted axes of the first array `x1`, followed by the + non-contracted axes of the second array `x2`. The returned array + must have a data type determined by Type Promotion Rules. + """ + if not isinstance(x1, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x1)}") + if not isinstance(x2, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x2)}") + q1, x1_usm_type = x1.sycl_queue, x1.usm_type + q2, x2_usm_type = x2.sycl_queue, x2.usm_type + exec_q = dpctl.utils.get_execution_queue((q1, q2)) + if exec_q is None: + raise ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments." + ) + res_usm_type = dpctl.utils.get_coerced_usm_type( + ( + x1_usm_type, + x2_usm_type, + ) + ) + dpctl.utils.validate_usm_type(res_usm_type, allow_none=False) + # handle axes and shapes validation + x1_nd = x1.ndim + x2_nd = x2.ndim + x1_shape = x1.shape + x2_shape = x2.shape + if isinstance(axes, int): + if axes < 0: + raise ValueError("`axes` integer is expected to be non-negative") + n_axes1 = axes + n_axes2 = axes + axes1 = normalize_axis_tuple(tuple(range(-axes, 0)), x1_nd) + axes2 = tuple(range(0, axes)) + elif isinstance(axes, tuple): + if len(axes) != 2: + raise ValueError( + "`axes` tuple is expected to contain two sequences" + ) + axes1 = tuple(axes[0]) + axes2 = tuple(axes[1]) + n_axes1 = len(axes1) + n_axes2 = len(axes2) + else: + raise TypeError("`axes` must be an integer or a tuple of sequences") + if n_axes1 != n_axes2: + raise ValueError( + "number of axes contracted must be the same for each array" + ) + if n_axes1 == 0: + arr1 = x1[..., dpt.newaxis] + arr2 = x2[dpt.newaxis, ...] 
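As a usage sketch of the `axes` handling documented above, the snippet below contracts the same pair of arrays once with an integer `axes` and once with the equivalent tuple of axis sequences. The shapes and values are made up for illustration, and a default SYCL device is assumed to be available.

```python
import dpctl.tensor as dpt

# x1 has shape (3, 4, 5), x2 has shape (4, 5, 6)
x1 = dpt.reshape(dpt.arange(60, dtype="f4"), (3, 4, 5))
x2 = dpt.reshape(dpt.arange(120, dtype="f4"), (4, 5, 6))

# axes=2 contracts the last two axes of x1 with the first two axes of x2
r_int = dpt.tensordot(x1, x2, axes=2)                 # shape (3, 6)

# the same contraction written as an explicit pair of axis sequences
r_tup = dpt.tensordot(x1, x2, axes=((1, 2), (0, 1)))  # shape (3, 6)

assert r_int.shape == r_tup.shape == (3, 6)
```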
+ n_axes1 = 1 + n_axes2 = 1 + else: + same_shapes = True + for i in range(n_axes1): + axis1 = axes1[i] + if axis1 < 0: + raise ValueError("`axes` must be non-negative") + axis2 = axes2[i] + if axis2 < 0: + raise ValueError("`axes` must be non-negative") + same_shapes = same_shapes and (x1_shape[axis1] == x2_shape[axis2]) + if not same_shapes: + raise ValueError("shape mismatch in contracted `tensordot` axes") + axes1 = normalize_axis_tuple(axes1, x1_nd) + axes2 = normalize_axis_tuple(axes2, x2_nd) + perm1 = [i for i in range(x1_nd) if i not in axes1] + list(axes1) + perm2 = list(axes2) + [i for i in range(x2_nd) if i not in axes2] + arr1 = dpt.permute_dims(x1, perm1) + arr2 = dpt.permute_dims(x2, perm2) + arr1_outer_nd = arr1.ndim - n_axes1 + arr2_outer_nd = arr2.ndim - n_axes2 + res_shape = arr1.shape[:arr1_outer_nd] + arr2.shape[n_axes2:] + # type validation + sycl_dev = exec_q.sycl_device + x1_dtype = x1.dtype + x2_dtype = x2.dtype + buf1_dt, buf2_dt, res_dt = _find_buf_dtype2( + x1_dtype, + x2_dtype, + tli._dot_result_type, + sycl_dev, + acceptance_fn=_acceptance_fn_default_binary, + ) + if res_dt is None: + raise TypeError( + "function 'tensordot' does not support input types " + f"({x1_dtype}, {x2_dtype}), " + "and the inputs could not be safely coerced to any " + "supported types according to the casting rule ''safe''." + ) + + if buf1_dt is None and buf2_dt is None: + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order="C", + ) + ht_dot_ev, _ = tli._dot( + x1=arr1, + x2=arr2, + batch_dims=0, + x1_outer_dims=arr1_outer_nd, + x2_outer_dims=arr2_outer_nd, + inner_dims=n_axes1, + dst=out, + sycl_queue=exec_q, + ) + ht_dot_ev.wait() + + return out + + elif buf1_dt is None: + buf2 = _empty_like_orderK(arr2, buf2_dt) + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=arr2, dst=buf2, sycl_queue=exec_q + ) + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order="C", + ) + ht_dot_ev, _ = tli._dot( + x1=arr1, + x2=buf2, + batch_dims=0, + x1_outer_dims=arr1_outer_nd, + x2_outer_dims=arr2_outer_nd, + inner_dims=n_axes1, + dst=out, + sycl_queue=exec_q, + depends=[copy_ev], + ) + ht_copy_ev.wait() + ht_dot_ev.wait() + + return out + + elif buf2_dt is None: + buf1 = _empty_like_orderK(arr1, buf1_dt) + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=arr1, dst=buf1, sycl_queue=exec_q + ) + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order="C", + ) + ht_dot_ev, _ = tli._dot( + x1=buf1, + x2=arr2, + batch_dims=0, + x1_outer_dims=arr1_outer_nd, + x2_outer_dims=arr2_outer_nd, + inner_dims=n_axes1, + dst=out, + sycl_queue=exec_q, + depends=[copy_ev], + ) + ht_copy_ev.wait() + ht_dot_ev.wait() + + return out + + buf1 = _empty_like_orderK(arr1, buf1_dt) + ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=arr1, dst=buf1, sycl_queue=exec_q + ) + buf2 = _empty_like_orderK(arr2, buf2_dt) + ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=arr2, dst=buf2, sycl_queue=exec_q + ) + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order="C", + ) + ht_, _ = tli._dot( + x1=buf1, + x2=buf2, + batch_dims=0, + x1_outer_dims=arr1_outer_nd, + x2_outer_dims=arr2_outer_nd, + inner_dims=n_axes1, + dst=out, + sycl_queue=exec_q, + depends=[copy1_ev, copy2_ev], + ) + dpctl.SyclEvent.wait_for([ht_copy1_ev, ht_copy2_ev, ht_]) + + return out + + +def vecdot(x1, x2, 
axis=-1): + """vecdot(x1, x2, axis=-1) + + Computes the (vector) dot product of two arrays. + + Args: + x1 (usm_ndarray): + first input array. + x2 (usm_ndarray): + second input array. Input arrays must have compatible + shapes along non-contract axes according to broadcasting + rules, and must have the same size along the contracted + axis. Input arrays should be of numeric type. + axis (Optional[int]): + axis over which to compute the dot product. The axis must + be an integer on the interval `[-N, N)`, where `N` is the + array rank of input arrays after broadcasting rules are + applied. If specified as a negative integer, the axis along + which dot product is performed is counted backward from + the last axes (that is `-1` refers to the last axis). By + default, dot product is computed over the last axis. + Default: `-1`. + + Returns: + usm_ndarray: + if `x1` and `x2` are both one-dimensional arrays, a + zero-dimensional array containing the dot product value + is returned; otherwise, a non-zero-dimensional array containing + the dot products and having rank `N-1`, where `N` is the rank + of the shape of input arrays after broadcasting rules are applied + to non-contracted axes. + """ + if not isinstance(x1, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x1)}") + if not isinstance(x2, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x2)}") + q1, x1_usm_type = x1.sycl_queue, x1.usm_type + q2, x2_usm_type = x2.sycl_queue, x2.usm_type + exec_q = dpctl.utils.get_execution_queue((q1, q2)) + if exec_q is None: + raise ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments." + ) + res_usm_type = dpctl.utils.get_coerced_usm_type( + ( + x1_usm_type, + x2_usm_type, + ) + ) + dpctl.utils.validate_usm_type(res_usm_type, allow_none=False) + # axis and shape validation + x1_nd = x1.ndim + x2_nd = x2.ndim + x1_shape = x1.shape + x2_shape = x2.shape + if x1_nd > x2_nd: + x2_shape = (1,) * (x1_nd - x2_nd) + x2_shape + x2_nd = len(x2_shape) + elif x2_nd > x1_nd: + x1_shape = (1,) * (x2_nd - x1_nd) + x1_shape + x1_nd = len(x1_shape) + axis = normalize_axis_index(operator.index(axis), x1_nd) + if x1_shape[axis] != x2_shape[axis]: + raise ValueError( + "given axis must have the same shape for `x1` and `x2`" + ) + try: + broadcast_sh = _broadcast_shape_impl( + [ + x1_shape, + x2_shape, + ] + ) + except ValueError: + raise ValueError("mismatch in `vecdot` dimensions") + res_sh = tuple( + [broadcast_sh[i] for i in range(len(broadcast_sh)) if i != axis] + ) + # type validation + sycl_dev = exec_q.sycl_device + x1_dtype = x1.dtype + x2_dtype = x2.dtype + buf1_dt, buf2_dt, res_dt = _find_buf_dtype2( + x1_dtype, + x2_dtype, + tli._dot_result_type, + sycl_dev, + acceptance_fn=_acceptance_fn_default_binary, + ) + if res_dt is None: + raise TypeError( + "function 'vecdot' does not support input types " + f"({x1_dtype}, {x2_dtype}), " + "and the inputs could not be safely coerced to any " + "supported types according to the casting rule ''safe''." 
+ ) + + ht_list = [] + deps = [] + if buf1_dt is None and buf2_dt is None: + if x1.dtype.kind == "c": + x1_tmp = _empty_like_orderK(x1, x1.dtype) + ht_conj_ev, conj_ev = tei._conj( + src=x1, + dst=x1_tmp, + sycl_queue=exec_q, + ) + ht_list.append(ht_conj_ev) + deps.append(conj_ev) + x1 = x1_tmp + if x1.shape != broadcast_sh: + x1 = dpt.broadcast_to(x1, broadcast_sh) + if x2.shape != broadcast_sh: + x2 = dpt.broadcast_to(x2, broadcast_sh) + x1 = dpt.moveaxis(x1, axis, -1) + x2 = dpt.moveaxis(x2, axis, -1) + + out = dpt.empty( + res_sh, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order="C", + ) + ht_dot_ev, _ = tli._dot( + x1=x1, + x2=x2, + batch_dims=len(x1.shape[:-1]), + x1_outer_dims=0, + x2_outer_dims=0, + inner_dims=1, + dst=out, + sycl_queue=exec_q, + depends=deps, + ) + ht_list.append(ht_dot_ev) + dpctl.SyclEvent.wait_for(ht_list) + + return dpt.reshape(out, res_sh) + + elif buf1_dt is None: + if x1.dtype.kind == "c": + x1_tmp = _empty_like_orderK(x1, x1.dtype) + ht_conj_ev, conj_e = tei._conj( + src=x1, dst=x1_tmp, sycl_queue=exec_q + ) + ht_list.append(ht_conj_ev) + deps.append(conj_e) + x1 = x1_tmp + buf2 = _empty_like_orderK(x2, buf2_dt) + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x2, dst=buf2, sycl_queue=exec_q + ) + ht_list.append(ht_copy_ev) + deps.append(copy_ev) + if x1.shape != broadcast_sh: + x1 = dpt.broadcast_to(x1, broadcast_sh) + if buf2.shape != broadcast_sh: + buf2 = dpt.broadcast_to(buf2, broadcast_sh) + x1 = dpt.moveaxis(x1, axis, -1) + buf2 = dpt.moveaxis(buf2, axis, -1) + out = dpt.empty( + res_sh, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order="C", + ) + ht_dot_ev, _ = tli._dot( + x1=x1, + x2=buf2, + batch_dims=len(x1.shape[:-1]), + x1_outer_dims=0, + x2_outer_dims=0, + inner_dims=1, + dst=out, + sycl_queue=exec_q, + depends=deps, + ) + ht_list.append(ht_dot_ev) + dpctl.SyclEvent.wait_for(ht_list) + + return dpt.reshape(out, res_sh) + + elif buf2_dt is None: + buf1 = _empty_like_orderK(x1, buf1_dt) + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x1, dst=buf1, sycl_queue=exec_q + ) + ht_list.append(ht_copy_ev) + deps.append(copy_ev) + if buf1.dtype.kind == "c": + ht_conj_ev, conj_ev = tei._conj( + src=buf1, dst=buf1, sycl_queue=exec_q, depends=[copy_ev] + ) + ht_list.append(ht_conj_ev) + deps.append(conj_ev) + if buf1.shape != broadcast_sh: + buf1 = dpt.broadcast_to(buf1, broadcast_sh) + if x2.shape != broadcast_sh: + x2 = dpt.broadcast_to(x2, broadcast_sh) + buf1 = dpt.moveaxis(buf1, axis, -1) + x2 = dpt.moveaxis(x2, axis, -1) + out = dpt.empty( + res_sh, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order="C", + ) + ht_dot_ev, _ = tli._dot( + x1=buf1, + x2=x2, + batch_dims=len(x1.shape[:-1]), + x1_outer_dims=0, + x2_outer_dims=0, + inner_dims=1, + dst=out, + sycl_queue=exec_q, + depends=deps, + ) + ht_list.append(ht_dot_ev) + dpctl.SyclEvent.wait_for(ht_list) + + return dpt.reshape(out, res_sh) + + buf1 = _empty_like_orderK(x1, buf1_dt) + ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x1, dst=buf1, sycl_queue=exec_q + ) + ht_list.append(ht_copy1_ev) + deps.append(copy1_ev) + if buf1.dtype.kind == "c": + ht_conj_ev, conj_ev = tei._conj( + src=buf1, dst=buf1, sycl_queue=exec_q, depends=[copy1_ev] + ) + ht_list.append(ht_conj_ev) + deps.append(conj_ev) + buf2 = _empty_like_orderK(x2, buf2_dt) + ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x2, dst=buf2, sycl_queue=exec_q + ) + ht_list.append(ht_copy2_ev) + 
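For reference, a small usage sketch of `vecdot` matching the implementation above: the first argument is conjugated when it is complex, and the non-contracted axes broadcast. The inputs are made up and a default SYCL device is assumed.

```python
import dpctl.tensor as dpt

# batch of 4 complex vectors dotted against one length-3 vector;
# vecdot conjugates its first argument, i.e. it computes
# sum(conj(a[i, :]) * b) for each row i
a = dpt.astype(dpt.reshape(dpt.arange(12, dtype="f4"), (4, 3)), "c8")
b = dpt.asarray([1j, 2j, 3j], dtype="c8")

r = dpt.vecdot(a, b, axis=-1)   # b broadcasts to shape (4, 3)
assert r.shape == (4,)
```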
deps.append(copy2_ev) + if buf1.shape != broadcast_sh: + buf1 = dpt.broadcast_to(buf1, broadcast_sh) + if buf2.shape != broadcast_sh: + buf2 = dpt.broadcast_to(buf2, broadcast_sh) + buf1 = dpt.moveaxis(buf1, axis, -1) + buf2 = dpt.moveaxis(buf2, axis, -1) + out = dpt.empty( + res_sh, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order="C", + ) + ht_dot_ev, _ = tli._dot( + x1=buf1, + x2=buf2, + batch_dims=len(x1.shape[:-1]), + x1_outer_dims=0, + x2_outer_dims=0, + inner_dims=1, + dst=out, + sycl_queue=exec_q, + depends=deps, + ) + ht_list.append(ht_dot_ev) + dpctl.SyclEvent.wait_for(ht_list) + + return out + + +def matmul(x1, x2, out=None, dtype=None, order="K"): + """matmul(x1, x2, out=None, order="K") + + Computes the matrix product. Implements the same semantics + as the built-in operator `@`. + + Args: + x1 (usm_ndarray): + first input array. Expected to have numeric data type, and + at least one dimension. If `x1` is one-dimensional having + shape `(M,)`, and `x2` has more than one dimension, `x1` is + effectively treated as a two-dimensional array with shape `(1, M)`, + although the prepended dimension is removed from the output array. + If `x1` has shape `(..., M, K)`, the innermost two dimensions form + matrices on which to perform matrix multiplication. + x2 (usm_ndarray): + second input array. Expected to have numeric data type, and + at least one dimension. If `x2` is one-dimensional having + shape `(N,)`, and `x1` has more than one dimension, `x2` is + effectively treated as a two-dimensional array with shape `(N, 1)`, + although the appended dimension is removed from the output array. + If `x2` has shape `(..., K, N)`, the innermost two dimensions form + matrices on which to perform matrix multiplication. + out (Optional[usm_ndarray]): + the array into which the result of the matrix product is written. + If `None` then a new array is returned. + order (["K", "C", "F", "A"]): + memory layout of the output array, if `out` is `None`, otherwise + the `order` parameter value is not used. + + Returns: + usm_ndarray: + * if both `x1` and `x2` are one-dimensional arrays with shape + `(N,)`, returned array is a zero-dimensional array containing + inner product as its only element. + * if `x1` is two-dimensional array with shape `(M, K)` and `x2` is + a two-dimensional array with shape `(K, N)`, returned array is a + two-dimensional array with shape `(M, N)` and contains the + conventional matrix product. + * if `x1` is a one-dimensinal array with shape `(K,)` and `x2` is an + array with shape `(..., K, N)`, returned array contains the + conventional matrix product and has shape `(..., N)`. + * if `x1` is an array with shape `(..., M, K)` and `x2` is a + one-dimensional array with shape `(K,)`, returned array has shape + `(..., M)` and contains the conventional matrix product. + * if `x1` is a two-dimensional array with shape `(M, K)` and `x2` + is an array with shape `(..., K, N)`, returned array contains + conventional matrix product for each stacked matrix and has shape + `(..., M, N)`. + * if `x1` has shape `(..., M, K)` and `x2` is a two-dimensional + array with shape `(K, N)`, returned array contains conventional + matrix product for each stacked matrix and has shape + `(..., M, N)`. + * if both `x1` and `x2` have more than two dimensions, returned + array contains conventional matrix product for each stacked + matrix and has shape determined by broadcasting rules for + `x1.shape[:-2]` and `x2.shape[:-2]`. 
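A shape-only sketch of the return-shape rules listed above; the all-ones inputs are purely illustrative and a default SYCL device is assumed.

```python
import dpctl.tensor as dpt

a = dpt.ones((2, 1, 4, 5), dtype="f4")   # stack of 4x5 matrices
b = dpt.ones((3, 5, 6), dtype="f4")      # stack of 5x6 matrices
v = dpt.ones(5, dtype="f4")              # 1-D operand

# batch dimensions (2, 1) and (3,) broadcast to (2, 3)
assert dpt.matmul(a, b).shape == (2, 3, 4, 6)

# a 1-D second operand is treated as shape (5, 1); the appended axis is dropped
assert dpt.matmul(a, v).shape == (2, 1, 4)
```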
+ + The data type of the returned array is determined by the Type + Promotion Rules. If either `x1` or `x2` has a complex floating + point type, neither argument is complex conjugated or transposed. + """ + if not isinstance(x1, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x1)}") + if not isinstance(x2, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x2)}") + if order not in ["K", "C", "F", "A"]: + order = "K" + q1, x1_usm_type = x1.sycl_queue, x1.usm_type + q2, x2_usm_type = x2.sycl_queue, x2.usm_type + exec_q = dpctl.utils.get_execution_queue((q1, q2)) + if exec_q is None: + raise ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments." + ) + res_usm_type = dpctl.utils.get_coerced_usm_type( + ( + x1_usm_type, + x2_usm_type, + ) + ) + dpctl.utils.validate_usm_type(res_usm_type, allow_none=False) + + x1_nd = x1.ndim + x2_nd = x2.ndim + if x1_nd == 0 or x2_nd == 0: + raise ValueError("one or more operands to `matmul` is 0 dimensional") + x1_shape = x1.shape + x2_shape = x2.shape + appended_axes = [] + if x1_nd == 1: + x1 = x1[dpt.newaxis, :] + x1_shape = x1.shape + appended_axes.append(-2) + if x2_nd == 1: + x2 = x2[:, dpt.newaxis] + x2_shape = x2.shape + appended_axes.append(-1) + if x1_shape[-1] != x2_shape[-2]: + raise ValueError("mismatch in `matmul` inner dimension") + x1_outer_sh = x1_shape[:-2] + x2_outer_sh = x2_shape[:-2] + try: + res_outer_sh = _broadcast_shape_impl( + [ + x1_outer_sh, + x2_outer_sh, + ] + ) + except ValueError: + raise ValueError("mismatch in `matmul` batching dimensions") + x1_broadcast_shape = res_outer_sh + x1_shape[-2:] + x2_broadcast_shape = res_outer_sh + x2_shape[-2:] + res_shape = res_outer_sh + x1_shape[-2:-1] + x2_shape[-1:] + + sycl_dev = exec_q.sycl_device + x1_dtype = x1.dtype + x2_dtype = x2.dtype + if dtype is None: + buf1_dt, buf2_dt, res_dt = _find_buf_dtype2( + x1_dtype, + x2_dtype, + tli._dot_result_type, + sycl_dev, + acceptance_fn=_acceptance_fn_default_binary, + ) + if res_dt is None: + raise ValueError( + "function 'matmul' does not support input types " + f"({x1_dtype}, {x2_dtype}), " + "and the inputs could not be safely coerced to any " + "supported types according to the casting rule ''safe''." + ) + else: + res_dt = dpt.dtype(dtype) + res_dt = _to_device_supported_dtype(res_dt, sycl_dev) + buf1_dt, buf2_dt = None, None + if x1_dtype != res_dt: + if dpt.can_cast(x1_dtype, res_dt, casting="same_kind"): + buf1_dt = res_dt + else: + raise ValueError( + f"`matmul` input `x1` cannot be cast from {x1_dtype} to " + f"requested type {res_dt} according to the casting rule " + "''same_kind''." + ) + if x2_dtype != res_dt: + if dpt.can_cast(x2_dtype, res_dt, casting="same_kind"): + buf2_dt = res_dt + else: + raise ValueError( + f"`matmul` input `x2` cannot be cast from {x2_dtype} to " + f"requested type {res_dt} according to the casting rule " + "''same_kind''." + ) + + orig_out = out + if out is not None: + if not isinstance(out, dpt.usm_ndarray): + raise TypeError( + f"output array must be of usm_ndarray type, got {type(out)}" + ) + + if out.shape != res_shape: + raise ValueError( + "The shape of input and output arrays are inconsistent. 
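A sketch of how the `out=` and `dtype=` keywords handled here can be used. It assumes a default SYCL device, and the float64 request additionally assumes the device supports double precision.

```python
import dpctl.tensor as dpt

a = dpt.ones((4, 3), dtype="f4")
b = dpt.ones((3, 2), dtype="f4")

# writing into a preallocated array: `out` must have the result shape and
# dtype, and be allocated on a queue compatible with the inputs
res = dpt.empty((4, 2), dtype="f4", sycl_queue=a.sycl_queue)
dpt.matmul(a, b, out=res)

# requesting a result dtype: inputs are cast following the "same_kind" rule
# (float64 support on the device is assumed here)
r = dpt.matmul(a, b, dtype="f8")
assert r.dtype == dpt.float64
```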
" + f"Expected output shape is {res_shape}, got {out.shape}" + ) + + if res_dt != out.dtype: + raise ValueError( + f"Output array of type {res_dt} is needed," f"got {out.dtype}" + ) + + if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None: + raise ExecutionPlacementError( + "Input and output allocation queues are not compatible" + ) + + if ti._array_overlap(x1, out) and buf1_dt is None: + out = dpt.empty_like(out) + + if ti._array_overlap(x2, out) and buf2_dt is None: + # should not reach if out is reallocated + # after being checked against x1 + out = dpt.empty_like(out) + + if buf1_dt is None and buf2_dt is None: + if out is None: + if order == "K": + out = _empty_like_pair_orderK( + x1, x2, res_dt, res_shape, res_usm_type, exec_q + ) + else: + if order == "A": + order = ( + "F" + if all( + arr.flags.f_contiguous + for arr in ( + x1, + x2, + ) + ) + else "C" + ) + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) + if x1.shape != x1_broadcast_shape: + x1 = dpt.broadcast_to(x1, x1_broadcast_shape) + if x2.shape != x2_broadcast_shape: + x2 = dpt.broadcast_to(x2, x2_broadcast_shape) + ht_dot_ev, dot_ev = tli._dot( + x1=x1, + x2=x2, + batch_dims=len(res_shape[:-2]), + x1_outer_dims=1, + x2_outer_dims=1, + inner_dims=1, + dst=out, + sycl_queue=exec_q, + ) + if not (orig_out is None or orig_out is out): + # Copy the out data from temporary buffer to original memory + ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, + dst=orig_out, + sycl_queue=exec_q, + depends=[dot_ev], + ) + ht_copy_out_ev.wait() + out = orig_out + ht_dot_ev.wait() + if appended_axes: + out = dpt.squeeze(out, tuple(appended_axes)) + return out + elif buf1_dt is None: + if order == "K": + buf2 = _empty_like_orderK(x2, buf2_dt) + else: + if order == "A": + order = "F" if x1.flags.f_contiguous else "C" + buf2 = dpt.empty_like(x2, dtype=buf2_dt, order=order) + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x2, dst=buf2, sycl_queue=exec_q + ) + if out is None: + if order == "K": + out = _empty_like_pair_orderK( + x1, buf2, res_dt, res_shape, res_usm_type, exec_q + ) + else: + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) + + if x1.shape != x1_broadcast_shape: + x1 = dpt.broadcast_to(x1, x1_broadcast_shape) + if buf2.shape != x2_broadcast_shape: + buf2 = dpt.broadcast_to(buf2, x2_broadcast_shape) + ht_dot_ev, dot_ev = tli._dot( + x1=x1, + x2=buf2, + batch_dims=len(res_shape[:-2]), + x1_outer_dims=1, + x2_outer_dims=1, + inner_dims=1, + dst=out, + sycl_queue=exec_q, + depends=[copy_ev], + ) + if not (orig_out is None or orig_out is out): + # Copy the out data from temporary buffer to original memory + ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, + dst=orig_out, + sycl_queue=exec_q, + depends=[dot_ev], + ) + ht_copy_out_ev.wait() + out = orig_out + ht_copy_ev.wait() + ht_dot_ev.wait() + if appended_axes: + out = dpt.squeeze(out, tuple(appended_axes)) + return out + + elif buf2_dt is None: + if order == "K": + buf1 = _empty_like_orderK(x1, buf1_dt) + else: + if order == "A": + order = "F" if x1.flags.f_contiguous else "C" + buf1 = dpt.empty_like(x1, dtype=buf1_dt, order=order) + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x1, dst=buf1, sycl_queue=exec_q + ) + if out is None: + if order == "K": + out = _empty_like_pair_orderK( + buf1, x2, res_dt, res_shape, res_usm_type, exec_q + ) + else: + out = dpt.empty( 
+ res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) + + if buf1.shape != x1_broadcast_shape: + buf1 = dpt.broadcast_to(buf1, x1_broadcast_shape) + if x2.shape != x2_broadcast_shape: + x2 = dpt.broadcast_to(x2, x2_broadcast_shape) + ht_dot_ev, dot_ev = tli._dot( + x1=buf1, + x2=x2, + batch_dims=len(res_shape[:-2]), + x1_outer_dims=1, + x2_outer_dims=1, + inner_dims=1, + dst=out, + sycl_queue=exec_q, + depends=[copy_ev], + ) + if not (orig_out is None or orig_out is out): + # Copy the out data from temporary buffer to original memory + ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, + dst=orig_out, + sycl_queue=exec_q, + depends=[dot_ev], + ) + ht_copy_out_ev.wait() + out = orig_out + ht_copy_ev.wait() + ht_dot_ev.wait() + if appended_axes: + out = dpt.squeeze(out, tuple(appended_axes)) + return out + + if order in ["K", "A"]: + if x1.flags.f_contiguous and x2.flags.f_contiguous: + order = "F" + elif x1.flags.c_contiguous and x2.flags.c_contiguous: + order = "C" + else: + order = "C" if order == "A" else "K" + if order == "K": + buf1 = _empty_like_orderK(x1, buf1_dt) + else: + buf1 = dpt.empty_like(x1, dtype=buf1_dt, order=order) + ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x1, dst=buf1, sycl_queue=exec_q + ) + if order == "K": + buf2 = _empty_like_orderK(x2, buf2_dt) + else: + buf2 = dpt.empty_like(x2, dtype=buf2_dt, order=order) + ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x2, dst=buf2, sycl_queue=exec_q + ) + if out is None: + if order == "K": + out = _empty_like_pair_orderK( + buf1, buf2, res_dt, res_shape, res_usm_type, exec_q + ) + else: + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) + + if buf1.shape != x1_broadcast_shape: + buf1 = dpt.broadcast_to(buf1, x1_broadcast_shape) + if buf2.shape != x2_broadcast_shape: + buf2 = dpt.broadcast_to(buf2, x2_broadcast_shape) + ht_, _ = tli._dot( + x1=buf1, + x2=buf2, + batch_dims=len(res_shape[:-2]), + x1_outer_dims=1, + x2_outer_dims=1, + inner_dims=1, + dst=out, + sycl_queue=exec_q, + depends=[copy1_ev, copy2_ev], + ) + dpctl.SyclEvent.wait_for([ht_copy1_ev, ht_copy2_ev, ht_]) + if appended_axes: + out = dpt.squeeze(out, tuple(appended_axes)) + return out diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx index 284de1cbe1..67e144f798 100644 --- a/dpctl/tensor/_usmarray.pyx +++ b/dpctl/tensor/_usmarray.pyx @@ -907,15 +907,15 @@ cdef class usm_ndarray: def __abs__(self): return dpctl.tensor.abs(self) - def __add__(first, other): + def __add__(self, other): """ Implementation for operator.add """ - return dpctl.tensor.add(first, other) + return dpctl.tensor.add(self, other) - def __and__(first, other): + def __and__(self, other): "Implementation for operator.and" - return dpctl.tensor.bitwise_and(first, other) + return dpctl.tensor.bitwise_and(self, other) def __dlpack__(self, stream=None): """ @@ -963,8 +963,8 @@ cdef class usm_ndarray: def __eq__(self, other): return dpctl.tensor.equal(self, other) - def __floordiv__(first, other): - return dpctl.tensor.floor_divide(first, other) + def __floordiv__(self, other): + return dpctl.tensor.floor_divide(self, other) def __ge__(self, other): return dpctl.tensor.greater_equal(self, other) @@ -984,21 +984,20 @@ cdef class usm_ndarray: else: raise TypeError("len() of unsized object") - def __lshift__(first, other): - "See comment in __add__" - return dpctl.tensor.bitwise_left_shift(first, other) + def 
__lshift__(self, other): + return dpctl.tensor.bitwise_left_shift(self, other) def __lt__(self, other): return dpctl.tensor.less(self, other) - def __matmul__(first, other): - return NotImplemented + def __matmul__(self, other): + return dpctl.tensor.matmul(self, other) - def __mod__(first, other): - return dpctl.tensor.remainder(first, other) + def __mod__(self, other): + return dpctl.tensor.remainder(self, other) - def __mul__(first, other): - return dpctl.tensor.multiply(first, other) + def __mul__(self, other): + return dpctl.tensor.multiply(self, other) def __ne__(self, other): return dpctl.tensor.not_equal(self, other) @@ -1006,20 +1005,17 @@ cdef class usm_ndarray: def __neg__(self): return dpctl.tensor.negative(self) - def __or__(first, other): - return dpctl.tensor.bitwise_or(first, other) + def __or__(self, other): + return dpctl.tensor.bitwise_or(self, other) def __pos__(self): return dpctl.tensor.positive(self) - def __pow__(first, other, mod): - if mod is None: - return dpctl.tensor.pow(first, other) - else: - return NotImplemented + def __pow__(self, other): + return dpctl.tensor.pow(self, other) - def __rshift__(first, other): - return dpctl.tensor.bitwise_right_shift(first, other) + def __rshift__(self, other): + return dpctl.tensor.bitwise_right_shift(self, other) def __setitem__(self, key, rhs): cdef tuple _meta @@ -1109,14 +1105,14 @@ cdef class usm_ndarray: return - def __sub__(first, other): - return dpctl.tensor.subtract(first, other) + def __sub__(self, other): + return dpctl.tensor.subtract(self, other) - def __truediv__(first, other): - return dpctl.tensor.divide(first, other) + def __truediv__(self, other): + return dpctl.tensor.divide(self, other) - def __xor__(first, other): - return dpctl.tensor.bitwise_xor(first, other) + def __xor__(self, other): + return dpctl.tensor.bitwise_xor(self, other) def __radd__(self, other): return dpctl.tensor.add(other, self) @@ -1131,7 +1127,7 @@ cdef class usm_ndarray: return dpctl.tensor.bitwise_left_shift(other, self) def __rmatmul__(self, other): - return NotImplemented + return dpctl.tensor.matmul(other, self) def __rmod__(self, other): return dpctl.tensor.remainder(other, self) @@ -1170,11 +1166,7 @@ cdef class usm_ndarray: return dpctl.tensor.bitwise_left_shift(self, other, out=self) def __imatmul__(self, other): - res = self.__matmul__(other) - if res is NotImplemented: - return res - self.__setitem__(Ellipsis, res) - return self + return dpctl.tensor.matmul(self, other, out=self) def __imod__(self, other): return dpctl.tensor.remainder(self, other, out=self) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp new file mode 100644 index 0000000000..15e5e35d67 --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp @@ -0,0 +1,1137 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "kernels/reductions.hpp" +#include "pybind11/pybind11.h" +#include "utils/offset_utils.hpp" +#include "utils/sycl_utils.hpp" +#include "utils/type_utils.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace kernels +{ + +template +struct SequentialDotProduct +{ +private: + const lhsT *lhs_ = nullptr; + const rhsT *rhs_ = nullptr; + outT *out_ = nullptr; + BatchIndexerT batch_indexer_; + RedIndexerT reduced_dims_indexer_; + size_t reduction_max_gid_ = 0; + +public: + SequentialDotProduct(const lhsT *lhs, + const rhsT *rhs, + outT *out, + BatchIndexerT 
batch_indexer, + RedIndexerT reduced_dims_indexer, + size_t reduction_size) + : lhs_(lhs), rhs_(rhs), out_(out), batch_indexer_(batch_indexer), + reduced_dims_indexer_(reduced_dims_indexer), + reduction_max_gid_(reduction_size) + { + } + + void operator()(sycl::id<1> id) const + { + + auto const &batch_offsets = batch_indexer_(id[0]); + const py::ssize_t &lhs_batch_offset = batch_offsets.get_first_offset(); + const py::ssize_t &rhs_batch_offset = batch_offsets.get_second_offset(); + const py::ssize_t &out_batch_offset = batch_offsets.get_third_offset(); + + outT red_val(0); + for (size_t m = 0; m < reduction_max_gid_; ++m) { + auto reduction_offsets = reduced_dims_indexer_(m); + auto lhs_reduction_offset = reduction_offsets.get_first_offset(); + auto rhs_reduction_offset = reduction_offsets.get_second_offset(); + + using dpctl::tensor::type_utils::convert_impl; + red_val += convert_impl( + lhs_[lhs_batch_offset + lhs_reduction_offset]) * + convert_impl( + rhs_[rhs_batch_offset + rhs_reduction_offset]); + } + + out_[out_batch_offset] = red_val; + } +}; + +template +struct DotProductFunctor +{ +private: + const lhsT *lhs_ = nullptr; + const rhsT *rhs_ = nullptr; + outT *out_ = nullptr; + BatchIndexerT batch_indexer_; + RedIndexerT reduced_dims_indexer_; + size_t reduction_max_gid_ = 0; + size_t batches_ = 1; + size_t reductions_per_wi = 16; + +public: + DotProductFunctor(const lhsT *lhs, + const rhsT *rhs, + outT *res, + BatchIndexerT batch_indexer, + RedIndexerT arg_reduced_dims_indexer, + size_t reduction_size, + size_t iteration_size, + size_t reduction_size_per_wi) + : lhs_(lhs), rhs_(rhs), out_(res), batch_indexer_(batch_indexer), + reduced_dims_indexer_(arg_reduced_dims_indexer), + reduction_max_gid_(reduction_size), batches_(iteration_size), + reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const size_t batch_id = it.get_group(0) % batches_; + const size_t reduction_batch_id = it.get_group(0) / batches_; + + const size_t reduction_lid = it.get_local_id(0); + const size_t wg = it.get_local_range(0); // 0 <= reduction_lid < wg + + // work-items operate over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + // for each input + + auto batch_offsets_ = batch_indexer_(batch_id); + const auto &lhs_batch_offset = batch_offsets_.get_first_offset(); + const auto &rhs_batch_offset = batch_offsets_.get_second_offset(); + const auto &out_batch_offset = batch_offsets_.get_third_offset(); + + outT local_red_val(0); + size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + size_t arg_reduce_gid_max = std::min( + reduction_max_gid_, arg_reduce_gid0 + reductions_per_wi * wg); + + for (size_t arg_reduce_gid = arg_reduce_gid0; + arg_reduce_gid < arg_reduce_gid_max; arg_reduce_gid += wg) + { + auto reduction_offsets_ = reduced_dims_indexer_(arg_reduce_gid); + const auto &lhs_reduction_offset = + reduction_offsets_.get_first_offset(); + const auto &rhs_reduction_offset = + reduction_offsets_.get_second_offset(); + + using dpctl::tensor::type_utils::convert_impl; + outT val = convert_impl( + lhs_[lhs_batch_offset + lhs_reduction_offset]) * + convert_impl( + rhs_[rhs_batch_offset + rhs_reduction_offset]); + + local_red_val += val; + } + + auto work_group = it.get_group(); + outT red_val_over_wg = sycl::reduce_over_group( + work_group, local_red_val, outT(0), sycl::plus()); + + if (work_group.leader()) { + sycl::atomic_ref + 
res_ref(out_[out_batch_offset]); + res_ref += red_val_over_wg; + } + } +}; + +template +class dot_product_seq_krn; + +template class dot_product_init_krn; + +template +class dot_product_krn; + +typedef sycl::event (*dot_product_impl_fn_ptr_t)( + sycl::queue &, + size_t, + size_t, + const char *, + const char *, + char *, + int, + const py::ssize_t *, + py::ssize_t, + py::ssize_t, + py::ssize_t, + int, + const py::ssize_t *, + py::ssize_t, + py::ssize_t, + const std::vector &); + +template +sycl::event dot_product_impl(sycl::queue &exec_q, + size_t batches, + size_t reduction_nelems, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + int batch_nd, + const py::ssize_t *batch_shape_and_strides, + py::ssize_t batch_lhs_offset, + py::ssize_t batch_rhs_offset, + py::ssize_t batch_res_offset, + int red_nd, + const py::ssize_t *reduction_shape_stride, + py::ssize_t reduction_lhs_offset, + py::ssize_t reduction_rhs_offset, + const std::vector &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + + InputOutputBatchIndexerT in_out_batch_indexer{ + batch_nd, batch_lhs_offset, batch_rhs_offset, batch_res_offset, + batch_shape_and_strides}; + ReductionIndexerT reduction_indexer{red_nd, reduction_lhs_offset, + reduction_rhs_offset, + reduction_shape_stride}; + + cgh.parallel_for>( + sycl::range<1>(batches), + SequentialDotProduct( + lhs_tp, rhs_tp, res_tp, in_out_batch_indexer, + reduction_indexer, reduction_nelems)); + }); + + return dot_ev; + } + else { + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + using IndexerT = + dpctl::tensor::offset_utils::UnpackedStridedIndexer; + + const py::ssize_t *const &res_shape = batch_shape_and_strides; + const py::ssize_t *const &res_strides = + batch_shape_and_strides + 3 * batch_nd; + IndexerT res_indexer(batch_nd, batch_res_offset, res_shape, + res_strides); + using InitKernelName = + class dot_product_init_krn; + cgh.depends_on(depends); + + cgh.parallel_for( + sycl::range<1>(batches), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = 0; + }); + }); + + sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); + + using BatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + + BatchIndexerT batch_indexer{batch_nd, batch_lhs_offset, + batch_rhs_offset, batch_res_offset, + batch_shape_and_strides}; + ReductionIndexerT reduction_indexer{red_nd, reduction_lhs_offset, + reduction_rhs_offset, + reduction_shape_stride}; + + constexpr size_t preferred_reductions_per_wi = + 4; // determined experimentally + size_t reductions_per_wi = + (reduction_nelems < preferred_reductions_per_wi * wg) + ? 
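To make the index arithmetic in the work-item comment above concrete, here is a pure-Python check, with arbitrary hypothetical sizes, that the tiling visited by each work-item covers every reduction index exactly once.

```python
# each work-item of reduction group g with local id lid visits
#   g * wg * reductions_per_wi + m * wg + lid,  0 <= m < reductions_per_wi,
# clipped to the reduction length
reduction_nelems, wg, reductions_per_wi = 1000, 64, 4
n_groups = -(-reduction_nelems // (wg * reductions_per_wi))  # ceil division

visited = []
for g in range(n_groups):
    for lid in range(wg):
        start = lid + g * wg * reductions_per_wi
        stop = min(reduction_nelems, start + reductions_per_wi * wg)
        visited.extend(range(start, stop, wg))

assert sorted(visited) == list(range(reduction_nelems))
```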
std::max(1, (reduction_nelems + wg - 1) / wg) + : preferred_reductions_per_wi; + + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + + auto globalRange = sycl::range<1>{batches * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = + class dot_product_krn; + + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + DotProductFunctor( + lhs_tp, rhs_tp, res_tp, batch_indexer, reduction_indexer, + reduction_nelems, batches, reductions_per_wi)); + }); + return dot_ev; + } +} + +typedef sycl::event (*dot_product_contig_impl_fn_ptr_t)( + sycl::queue &, + size_t, + size_t, + const char *, + const char *, + char *, + py::ssize_t, + py::ssize_t, + py::ssize_t, + py::ssize_t, + py::ssize_t, + const std::vector &); + +template +sycl::event +dot_product_contig_impl(sycl::queue &exec_q, + size_t batches, + size_t reduction_nelems, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + py::ssize_t batch_lhs_offset, + py::ssize_t batch_rhs_offset, + py::ssize_t batch_res_offset, + py::ssize_t reduction_lhs_offset, + py::ssize_t reduction_rhs_offset, + const std::vector &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp) + + batch_lhs_offset + reduction_lhs_offset; + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp) + + batch_rhs_offset + reduction_rhs_offset; + resTy *res_tp = reinterpret_cast(res_cp) + batch_res_offset; + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputBatchIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer< + InputBatchIndexerT, InputBatchIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + + InputBatchIndexerT inp_batch_indexer{ + 0, static_cast(reduction_nelems), + static_cast(batches)}; + InputOutputBatchIndexerT inp_out_batch_indexer{ + inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{NoOpIndexerT{}, NoOpIndexerT{}}; + + cgh.parallel_for>( + sycl::range<1>(batches), + SequentialDotProduct( + lhs_tp, rhs_tp, res_tp, inp_out_batch_indexer, + reduction_indexer, reduction_nelems)); + }); + + return dot_ev; + } + else { + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + cgh.fill(res_tp, resTy(0), batches); + }); + + sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); + + using InputBatchIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer< + InputBatchIndexerT, InputBatchIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + + InputBatchIndexerT inp_batch_indexer{ + 0, static_cast(reduction_nelems), + static_cast(batches)}; + InputOutputBatchIndexerT inp_out_batch_indexer{ + inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{NoOpIndexerT{}, 
NoOpIndexerT{}}; + + constexpr size_t preferred_reductions_per_wi = + 4; // determined experimentally + size_t reductions_per_wi = + (reduction_nelems < preferred_reductions_per_wi * wg) + ? std::max(1, (reduction_nelems + wg - 1) / wg) + : preferred_reductions_per_wi; + + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + + auto globalRange = sycl::range<1>{batches * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class dot_product_krn; + + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + DotProductFunctor( + lhs_tp, rhs_tp, res_tp, inp_out_batch_indexer, + reduction_indexer, reduction_nelems, batches, + reductions_per_wi)); + }); + return dot_ev; + } +} + +template +struct DotProductNoAtomicFunctor +{ +private: + const lhsT *lhs_ = nullptr; + const rhsT *rhs_ = nullptr; + outT *out_ = nullptr; + BatchIndexerT batch_indexer_; + RedIndexerT reduced_dims_indexer_; + size_t reduction_max_gid_ = 0; + size_t batches_ = 1; + size_t reductions_per_wi = 16; + +public: + DotProductNoAtomicFunctor(const lhsT *lhs, + const rhsT *rhs, + outT *res, + BatchIndexerT batch_indexer, + RedIndexerT arg_reduced_dims_indexer, + size_t reduction_size, + size_t iteration_size, + size_t reduction_size_per_wi) + : lhs_(lhs), rhs_(rhs), out_(res), batch_indexer_(batch_indexer), + reduced_dims_indexer_(arg_reduced_dims_indexer), + reduction_max_gid_(reduction_size), batches_(iteration_size), + reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const size_t reduction_lid = it.get_local_id(0); + const size_t wg = it.get_local_range(0); // 0 <= reduction_lid < wg + + const size_t batch_id = it.get_group(0) % batches_; + const size_t reduction_batch_id = it.get_group(0) / batches_; + const size_t n_reduction_groups = it.get_group_range(0) / batches_; + + // work-items operate over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + // for each input + + auto batch_offsets_ = batch_indexer_(batch_id); + const auto &lhs_batch_offset = batch_offsets_.get_first_offset(); + const auto &rhs_batch_offset = batch_offsets_.get_second_offset(); + const auto &out_batch_offset = batch_offsets_.get_third_offset(); + + outT local_red_val(0); + size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + size_t arg_reduce_gid_max = std::min( + reduction_max_gid_, arg_reduce_gid0 + reductions_per_wi * wg); + + for (size_t arg_reduce_gid = arg_reduce_gid0; + arg_reduce_gid < arg_reduce_gid_max; arg_reduce_gid += wg) + { + auto reduction_offsets_ = reduced_dims_indexer_(arg_reduce_gid); + const auto &lhs_reduction_offset = + reduction_offsets_.get_first_offset(); + const auto &rhs_reduction_offset = + reduction_offsets_.get_second_offset(); + + using dpctl::tensor::type_utils::convert_impl; + outT val = convert_impl( + lhs_[lhs_batch_offset + lhs_reduction_offset]) * + convert_impl( + rhs_[rhs_batch_offset + rhs_reduction_offset]); + + local_red_val += val; + } + + auto work_group = it.get_group(); + outT red_val_over_wg = sycl::reduce_over_group( + work_group, local_red_val, outT(0), sycl::plus()); + + if (work_group.leader()) { + // each group writes to a different memory location + out_[out_batch_offset * n_reduction_groups + reduction_batch_id] = + red_val_over_wg; + } + } +}; + +template +class dot_product_tree_krn; + +template +class 
dot_product_reduction_over_group_temps_krn; + +template +sycl::event dot_product_tree_impl(sycl::queue &exec_q, + size_t batches, + size_t reduction_nelems, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + int batch_nd, + const py::ssize_t *batch_shape_and_strides, + py::ssize_t batch_lhs_offset, + py::ssize_t batch_rhs_offset, + py::ssize_t batch_res_offset, + int red_nd, + const py::ssize_t *reduction_shape_stride, + py::ssize_t reduction_lhs_offset, + py::ssize_t reduction_rhs_offset, + const std::vector &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + + InputOutputBatchIndexerT in_out_batch_indexer{ + batch_nd, batch_lhs_offset, batch_rhs_offset, batch_res_offset, + batch_shape_and_strides}; + ReductionIndexerT reduction_indexer{red_nd, reduction_lhs_offset, + reduction_rhs_offset, + reduction_shape_stride}; + + cgh.parallel_for>( + sycl::range<1>(batches), + SequentialDotProduct( + lhs_tp, rhs_tp, res_tp, in_out_batch_indexer, + reduction_indexer, reduction_nelems)); + }); + + return dot_ev; + } + + constexpr size_t preferred_reductions_per_wi = 8; + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, d.get_info() / 2); + + size_t reductions_per_wi(preferred_reductions_per_wi); + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using BatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + + BatchIndexerT batch_indexer{batch_nd, batch_lhs_offset, + batch_rhs_offset, batch_res_offset, + batch_shape_and_strides}; + ReductionIndexerT reduction_indexer{red_nd, reduction_lhs_offset, + reduction_rhs_offset, + reduction_shape_stride}; + + if (batches == 1) { + // increase GPU occupancy + wg = max_wg; + } + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = sycl::range<1>{batches * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + using KernelName = + dot_product_tree_krn; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + DotProductNoAtomicFunctor( + lhs_tp, rhs_tp, res_tp, batch_indexer, reduction_indexer, + reduction_nelems, batches, reductions_per_wi)); + }); + + return dot_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + // more than one work-groups is needed, requires a temporary + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + size_t second_iter_reduction_groups_ = + (reduction_groups + preferred_reductions_per_wi * wg 
- 1) / + (preferred_reductions_per_wi * wg); + + resTy *partially_reduced_tmp = sycl::malloc_device( + batches * (reduction_groups + second_iter_reduction_groups_), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_groups * batches; + } + + const sycl::event &first_reduction_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using LhsIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + using RhsIndexerT = + dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer< + LhsIndexerT, RhsIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + + LhsIndexerT lhs_indexer(batch_nd, batch_lhs_offset, + batch_shape_and_strides); + RhsIndexerT rhs_indexer(batch_nd, batch_rhs_offset, + batch_shape_and_strides, + batch_shape_and_strides + 2 * batch_nd); + ResIndexerT noop_tmp_indexer{}; + + InputOutputBatchIndexerT in_out_iter_indexer{ + lhs_indexer, rhs_indexer, noop_tmp_indexer}; + ReductionIndexerT reduction_indexer{ + red_nd, reduction_lhs_offset, reduction_rhs_offset, + reduction_shape_stride}; + + auto globalRange = + sycl::range<1>{batches * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = + class dot_product_tree_krn; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + DotProductNoAtomicFunctor( + lhs_tp, rhs_tp, partially_reduced_tmp, + in_out_iter_indexer, reduction_indexer, + reduction_nelems, batches, + preferred_reductions_per_wi)); + }); + + size_t remaining_reduction_nelems = reduction_groups; + + resTy *temp_arg = partially_reduced_tmp; + resTy *temp2_arg = partially_reduced_tmp2; + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > + preferred_reductions_per_wi * max_wg) { + size_t reduction_groups_ = (remaining_reduction_nelems + + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + sycl::event partial_reduction_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{ + 0, static_cast(batches), + static_cast(reduction_groups_)}; + ResIndexerT res_iter_indexer{}; + + InputOutputIterIndexerT in_out_iter_indexer{ + inp_indexer, res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; + + auto globalRange = + sycl::range<1>{batches * reduction_groups_ * wg}; + auto localRange = sycl::range<1>{wg}; + using KernelName = + class dot_product_reduction_over_group_temps_krn< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>( + temp_arg, temp2_arg, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + 
remaining_reduction_nelems, batches, + preferred_reductions_per_wi)); + }); + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + dependent_ev = std::move(partial_reduction_ev); + } + // final reduction to res + sycl::event final_reduction_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = + dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{ + 0, static_cast(batches), + static_cast(remaining_reduction_nelems)}; + ResIndexerT res_iter_indexer{batch_nd, batch_res_offset, + /* shape */ batch_shape_and_strides, + /* strides */ batch_shape_and_strides + + 2 * batch_nd}; + + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = + std::max(1, (remaining_reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = sycl::range<1>{batches * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class dot_product_reduction_over_group_temps_krn< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor( + temp_arg, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + remaining_reduction_nelems, batches, reductions_per_wi)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(final_reduction_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } +} + +template +sycl::event +dot_product_contig_tree_impl(sycl::queue &exec_q, + size_t batches, + size_t reduction_nelems, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + py::ssize_t batch_lhs_offset, + py::ssize_t batch_rhs_offset, + py::ssize_t batch_res_offset, + py::ssize_t reduction_lhs_offset, + py::ssize_t reduction_rhs_offset, + const std::vector &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp) + + batch_lhs_offset + reduction_lhs_offset; + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp) + + batch_rhs_offset + reduction_rhs_offset; + resTy *res_tp = reinterpret_cast(res_cp) + batch_res_offset; + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputBatchIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer< + InputBatchIndexerT, InputBatchIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + + InputBatchIndexerT inp_batch_indexer{ + 0, 
static_cast(reduction_nelems), + static_cast(batches)}; + InputOutputBatchIndexerT inp_out_batch_indexer{ + inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{NoOpIndexerT{}, NoOpIndexerT{}}; + + cgh.parallel_for>( + sycl::range<1>(batches), + SequentialDotProduct( + lhs_tp, rhs_tp, res_tp, inp_out_batch_indexer, + reduction_indexer, reduction_nelems)); + }); + + return dot_ev; + } + + constexpr size_t preferred_reductions_per_wi = 8; + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, d.get_info() / 2); + + size_t reductions_per_wi(preferred_reductions_per_wi); + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputBatchIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer< + InputBatchIndexerT, InputBatchIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + + InputBatchIndexerT inp_batch_indexer{ + 0, static_cast(reduction_nelems), + static_cast(batches)}; + InputOutputBatchIndexerT inp_out_batch_indexer{ + inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{NoOpIndexerT{}, NoOpIndexerT{}}; + + if (batches == 1) { + // increase GPU occupancy + wg = max_wg; + } + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = sycl::range<1>{batches * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + using KernelName = dot_product_tree_krn; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + DotProductNoAtomicFunctor( + lhs_tp, rhs_tp, res_tp, inp_out_batch_indexer, + reduction_indexer, reduction_nelems, batches, + reductions_per_wi)); + }); + + return dot_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + // more than one work-groups is needed, requires a temporary + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + size_t second_iter_reduction_groups_ = + (reduction_groups + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + resTy *partially_reduced_tmp = sycl::malloc_device( + batches * (reduction_groups + second_iter_reduction_groups_), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_groups * batches; + } + + const sycl::event &first_reduction_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using InputBatchIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer< + InputBatchIndexerT, InputBatchIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + 
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + + InputBatchIndexerT inp_batch_indexer{ + 0, static_cast(reduction_nelems), + static_cast(batches)}; + InputOutputBatchIndexerT inp_out_batch_indexer{ + inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{NoOpIndexerT{}, NoOpIndexerT{}}; + + auto globalRange = sycl::range<1>{batches * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = + class dot_product_tree_krn; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + DotProductNoAtomicFunctor( + lhs_tp, rhs_tp, partially_reduced_tmp, + inp_out_batch_indexer, reduction_indexer, reduction_nelems, + batches, preferred_reductions_per_wi)); + }); + + size_t remaining_reduction_nelems = reduction_groups; + + resTy *temp_arg = partially_reduced_tmp; + resTy *temp2_arg = partially_reduced_tmp2; + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > + preferred_reductions_per_wi * max_wg) { + size_t reduction_groups_ = (remaining_reduction_nelems + + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + sycl::event partial_reduction_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{ + 0, static_cast(batches), + static_cast(reduction_groups_)}; + ResIndexerT res_iter_indexer{}; + + InputOutputIterIndexerT in_out_iter_indexer{ + inp_indexer, res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; + + auto globalRange = + sycl::range<1>{batches * reduction_groups_ * wg}; + auto localRange = sycl::range<1>{wg}; + using KernelName = + class dot_product_reduction_over_group_temps_krn< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>( + temp_arg, temp2_arg, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + remaining_reduction_nelems, batches, + preferred_reductions_per_wi)); + }); + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + dependent_ev = std::move(partial_reduction_ev); + } + // final reduction to res + sycl::event final_reduction_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{ + 0, static_cast(batches), + static_cast(remaining_reduction_nelems)}; + ResIndexerT res_iter_indexer{}; + + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = + std::max(1, (remaining_reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + 
(remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = sycl::range<1>{batches * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class dot_product_reduction_over_group_temps_krn< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor( + temp_arg, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + remaining_reduction_nelems, batches, reductions_per_wi)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(final_reduction_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } +} + +} // namespace kernels +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp new file mode 100644 index 0000000000..a4a5d3b929 --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -0,0 +1,6968 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "kernels/reductions.hpp" +#include "pybind11/pybind11.h" +#include "utils/offset_utils.hpp" +#include "utils/sycl_utils.hpp" +#include "utils/type_utils.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace kernels +{ + +namespace gemm_detail +{ + +template +void scale_gemm_k_parameters(const size_t &local_mem_size, + const size_t &reserved_slm_size, + const size_t delta_k, + size_t &n_wi, + size_t &delta_n) +{ + constexpr size_t slm_elem_size = sizeof(T) * m_groups; + + while (slm_elem_size * (n_wi + delta_n) * delta_k + reserved_slm_size >= + local_mem_size) + { + n_wi = n_wi / 2; + delta_n = delta_n / 2; + if (delta_n == 0) + throw std::runtime_error("Insufficient resources"); + } +} + +template +void scale_gemm_nm_parameters(const size_t &local_mem_size, + const size_t &reserved_slm_size, + const size_t &wi_delta_n, + size_t &wi_delta_k, + size_t &wg_delta_n, + size_t &wg_delta_m) +{ + constexpr size_t slm_A_elem_size = sizeof(T); + constexpr size_t slm_B_elem_size = sizeof(T) * wi_delta_m; + + while ((wi_delta_n * wg_delta_n * wi_delta_k * slm_A_elem_size) + + (wi_delta_k * wg_delta_m * slm_B_elem_size) + + reserved_slm_size >= + local_mem_size) + { + wg_delta_n /= 2; + wg_delta_m /= 2; + wi_delta_k /= 2; + if (wg_delta_n == 0) + throw std::runtime_error("Insufficient resources"); + } +} +} // namespace gemm_detail + +using dpctl::tensor::sycl_utils::choose_workgroup_size; + +template +class gemm_seq_reduction_krn; + +template +class gemm_tree_reduction_krn; + +template +sycl::event single_reduction_for_gemm(sycl::queue &exec_q, + T *tmp_tp, + T *res_tp, + T identity_val, + size_t iter_nelems, + size_t reduction_nelems, + size_t reduction_groups, + size_t wg, + size_t max_wg, + size_t preferred_reductions_per_wi, + size_t reductions_per_wi, + int res_nd, + py::ssize_t res_offset, + const py::ssize_t *res_shapes_strides, + const std::vector &depends) +{ + sycl::event red_ev; + if (reduction_nelems < wg) { + red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using ResIndexerT = 
+                dpctl::tensor::offset_utils::StridedIndexer;
+            using InputOutputIterIndexerT =
+                dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
+                    NoOpIndexerT, ResIndexerT>;
+            using ReductionIndexerT =
+                dpctl::tensor::offset_utils::Strided1DIndexer;
+
+            ResIndexerT res_iter_indexer{res_nd, 0, res_shapes_strides};
+            InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{},
+                                                        res_iter_indexer};
+            ReductionIndexerT reduction_indexer{
+                0, static_cast<py::ssize_t>(reduction_nelems),
+                static_cast<py::ssize_t>(iter_nelems)};
+
+            sycl::range<1> iter_range{iter_nelems};
+
+            cgh.parallel_for<class gemm_seq_reduction_krn<
+                T, T, ReductionOpT, InputOutputIterIndexerT,
+                ReductionIndexerT>>(
+                iter_range,
+                SequentialReduction<T, T, ReductionOpT,
+                                    InputOutputIterIndexerT,
+                                    ReductionIndexerT>(
+                    tmp_tp, res_tp, ReductionOpT(), identity_val,
+                    in_out_iter_indexer, reduction_indexer, reduction_nelems));
+        });
+    }
+    else {
+        red_ev = exec_q.submit([&](sycl::handler &cgh) {
+            cgh.depends_on(depends);
+
+            using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
+            using ResIndexerT = dpctl::tensor::offset_utils::StridedIndexer;
+            using InputOutputIterIndexerT =
+                dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
+                    NoOpIndexerT, ResIndexerT>;
+            using ReductionIndexerT =
+                dpctl::tensor::offset_utils::Strided1DIndexer;
+
+            ResIndexerT res_iter_indexer{res_nd, 0, res_shapes_strides};
+            InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{},
+                                                        res_iter_indexer};
+            ReductionIndexerT reduction_indexer{
+                0, static_cast<py::ssize_t>(reduction_nelems),
+                static_cast<py::ssize_t>(iter_nelems)};
+
+            if (iter_nelems == 1) {
+                // increase GPU occupancy
+                wg = max_wg;
+            }
+            reductions_per_wi =
+                std::max<size_t>(1, (reduction_nelems + wg - 1) / wg);
+
+            size_t reduction_groups =
+                (reduction_nelems + reductions_per_wi * wg - 1) /
+                (reductions_per_wi * wg);
+            assert(reduction_groups == 1);
+
+            auto globalRange =
+                sycl::range<1>{iter_nelems * reduction_groups * wg};
+            auto localRange = sycl::range<1>{wg};
+
+            using KernelName = class gemm_tree_reduction_krn<
+                T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>;
+            cgh.parallel_for<KernelName>(
+                sycl::nd_range<1>(globalRange, localRange),
+                ReductionOverGroupNoAtomicFunctor<
+                    T, T, ReductionOpT, InputOutputIterIndexerT,
+                    ReductionIndexerT>(
+                    tmp_tp, res_tp, ReductionOpT(), identity_val,
+                    in_out_iter_indexer, reduction_indexer, reduction_nelems,
+                    iter_nelems, reductions_per_wi));
+        });
+    }
+    return red_ev;
+}
+
+template <typename T, typename ReductionOpT>
+sycl::event
+single_reduction_for_gemm_contig(sycl::queue &exec_q,
+                                 T *tmp_tp,
+                                 T *res_tp,
+                                 T identity_val,
+                                 size_t iter_nelems,
+                                 size_t reduction_nelems,
+                                 size_t reduction_groups,
+                                 size_t wg,
+                                 size_t max_wg,
+                                 size_t preferred_reductions_per_wi,
+                                 size_t reductions_per_wi,
+                                 const std::vector<sycl::event> &depends)
+{
+    sycl::event red_ev;
+    if (reduction_nelems < wg) {
+        red_ev = exec_q.submit([&](sycl::handler &cgh) {
+            cgh.depends_on(depends);
+
+            using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
+            using InputOutputIterIndexerT =
+                dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
+                    NoOpIndexerT, NoOpIndexerT>;
+            using ReductionIndexerT =
+                dpctl::tensor::offset_utils::Strided1DIndexer;
+
+            InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{},
+                                                        NoOpIndexerT{}};
+            ReductionIndexerT reduction_indexer{
+                0, static_cast<py::ssize_t>(reduction_nelems),
+                static_cast<py::ssize_t>(iter_nelems)};
+
+            sycl::range<1> iter_range{iter_nelems};
+
+            cgh.parallel_for<class gemm_seq_reduction_krn<
+                T, T, ReductionOpT, InputOutputIterIndexerT,
+                ReductionIndexerT>>(
+                iter_range,
+                SequentialReduction<T, T, ReductionOpT,
+                                    InputOutputIterIndexerT,
+                                    ReductionIndexerT>(
+                    tmp_tp, res_tp, ReductionOpT(), identity_val,
+                    in_out_iter_indexer, reduction_indexer, reduction_nelems));
+        });
+    }
+    else {
+        red_ev = exec_q.submit([&](sycl::handler &cgh) {
+            cgh.depends_on(depends);
+
+            using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
+            using InputOutputIterIndexerT =
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + if (iter_nelems == 1) { + // increase GPU occupancy + wg = max_wg; + } + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class gemm_tree_reduction_krn< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor( + tmp_tp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, reduction_nelems, + iter_nelems, reductions_per_wi)); + }); + } + return red_ev; +} + +template +sycl::event tree_reduction_for_gemm(sycl::queue &exec_q, + T *partially_reduced_tmp, + T *partially_reduced_tmp2, + T *res_tp, + T identity_val, + size_t iter_nelems, + size_t reduction_nelems, + size_t reduction_groups, + size_t wg, + size_t max_wg, + size_t preferred_reductions_per_wi, + size_t reductions_per_wi, + int res_nd, + py::ssize_t res_offset, + const py::ssize_t *res_shape_strides, + const std::vector &depends) +{ + + const sycl::event &first_reduction_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + + // Only 2*iter_nd entries describing shape and strides of + // iterated dimensions of input array from + // iter_shape_and_strides are going to be accessed by + // inp_indexer + + InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{ + 0, /* size */ static_cast(reduction_nelems), + /* step */ static_cast(iter_nelems)}; + + auto globalRange = sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class gemm_tree_reduction_krn< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>( + partially_reduced_tmp, partially_reduced_tmp2, ReductionOpT(), + identity_val, in_out_iter_indexer, reduction_indexer, + reduction_nelems, iter_nelems, reductions_per_wi)); + }); + + size_t remaining_reduction_nelems = reduction_groups; + + T *temp_arg = partially_reduced_tmp2; + T *temp2_arg = partially_reduced_tmp; + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > preferred_reductions_per_wi * max_wg) { + size_t reduction_groups_ = (remaining_reduction_nelems + + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + // keep reducing + sycl::event partial_reduction_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = 
dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{ + 0, static_cast(iter_nelems), + static_cast(reduction_groups_)}; + ResIndexerT res_iter_indexer{}; + + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + + ReductionIndexerT reduction_indexer{}; + + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups_ * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class gemm_tree_reduction_krn< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor( + temp_arg, temp2_arg, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + remaining_reduction_nelems, iter_nelems, + reductions_per_wi)); + }); + + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + dependent_ev = std::move(partial_reduction_ev); + } + + // final reduction to res + sycl::event final_reduction_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{ + 0, static_cast(iter_nelems), + static_cast(remaining_reduction_nelems)}; + ResIndexerT res_iter_indexer{ + res_nd, static_cast(res_offset), res_shape_strides}; + + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = + std::max(1, (remaining_reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class gemm_tree_reduction_krn< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>( + temp_arg, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + remaining_reduction_nelems, iter_nelems, reductions_per_wi)); + }); + + return final_reduction_ev; +} + +template +class gemm_reduction_over_group_temps_contig_krn; + +template +sycl::event +tree_reduction_for_gemm_contig(sycl::queue &exec_q, + T *partially_reduced_tmp, + T *partially_reduced_tmp2, + T *res_tp, + T identity_val, + size_t iter_nelems, + size_t reduction_nelems, + size_t reduction_groups, + size_t wg, + size_t max_wg, + size_t preferred_reductions_per_wi, + size_t reductions_per_wi, + const std::vector &depends) +{ + + const sycl::event &first_reduction_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + 
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + + // Only 2*iter_nd entries describing shape and strides of + // iterated dimensions of input array from + // iter_shape_and_strides are going to be accessed by + // inp_indexer + + InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{ + 0, /* size */ static_cast(reduction_nelems), + /* step */ static_cast(iter_nelems)}; + + auto globalRange = sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class gemm_reduction_over_group_temps_contig_krn< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>( + partially_reduced_tmp, partially_reduced_tmp2, ReductionOpT(), + identity_val, in_out_iter_indexer, reduction_indexer, + reduction_nelems, iter_nelems, reductions_per_wi)); + }); + + size_t remaining_reduction_nelems = reduction_groups; + + T *temp_arg = partially_reduced_tmp2; + T *temp2_arg = partially_reduced_tmp; + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > preferred_reductions_per_wi * max_wg) { + size_t reduction_groups_ = (remaining_reduction_nelems + + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + // keep reducing + sycl::event partial_reduction_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + // n * m = iter_nelems because essentially, this process + // creates a stack of reduction_nelems 2D matrices and we reduce + // along the stack axis + InputIndexerT inp_indexer{ + 0, static_cast(iter_nelems), + static_cast(reduction_groups_)}; + ResIndexerT res_iter_indexer{}; + + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + + ReductionIndexerT reduction_indexer{}; + + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups_ * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class gemm_reduction_over_group_temps_contig_krn< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor( + temp_arg, temp2_arg, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + remaining_reduction_nelems, iter_nelems, + reductions_per_wi)); + }); + + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + dependent_ev = std::move(partial_reduction_ev); + } + + // final reduction to res + sycl::event final_reduction_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using 
ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{ + 0, static_cast(iter_nelems), + static_cast(remaining_reduction_nelems)}; + ResIndexerT res_iter_indexer{}; + + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = + std::max(1, (remaining_reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class gemm_reduction_over_group_temps_contig_krn< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>( + temp_arg, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + remaining_reduction_nelems, iter_nelems, reductions_per_wi)); + }); + + return final_reduction_ev; +} + +template +class GemmFunctorThreadNM +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT1 local_A_block; + LocAccT2 local_B_block; + size_t n = 0; + size_t wg_delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t wi_delta_k = 0; + size_t m = 0; + size_t m_blocks = 0; + size_t wg_delta_m = 0; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + OuterInnerDimsIndexerT res_indexer; + +public: + GemmFunctorThreadNM(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_A_block_, + LocAccT2 local_B_block_, + size_t n_, + size_t wg_delta_n_, + size_t k_, + size_t k_blocks_, + size_t wi_delta_k_, + size_t m_, + size_t m_blocks_, + size_t wg_delta_m_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + OuterInnerDimsIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), + local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), + k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), + m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), + lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), + res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t gr_id = it.get_group_linear_id(); + // lift group_id to (block_i, block_j, block_s), + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s < + // k_blocks + size_t block_i = gr_id / (m_blocks * k_blocks); + size_t block_r = gr_id - block_i * (m_blocks * k_blocks); + size_t block_j = block_r / k_blocks; + size_t block_s = block_r - block_j * k_blocks; + + size_t lid = it.get_local_linear_id(); + size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n + size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m + + // load A block and B blocks into SLM + + size_t i = block_i * wi_delta_n * wg_delta_n; + size_t j = block_j * wi_delta_m * wg_delta_m; + size_t s = block_s * wi_delta_k; + + const std::int64_t a_st0 = k; + const std::int64_t a_st1 = 1; + + const std::int64_t b_st0 = m; + const std::int64_t b_st1 = 1; + + const std::int64_t c_st0 = m; + const std::int64_t c_st1 = 1; + + size_t lws = it.get_local_range(0); + + for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { + size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * 
wi_delta_n + size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_i = i + v_i; + size_t g_s = s + v_s; + + local_A_block[vid] = + (g_i < n && g_s < k) + ? static_cast( + lhs[lhs_indexer(g_i * a_st0 + g_s * a_st1)]) + : resT(0); + } + + for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { + size_t v_j = vid / wi_delta_k; // 0 <= v_i < wg_delta_m + size_t v_s = vid - v_j * wi_delta_k; // 0 <= v_s < wi_delta_k + + size_t g_j0 = j + v_j * wi_delta_m; + size_t g_s = s + v_s; + + sycl::vec vec{}; +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { + size_t g_j = g_j0 + lane_id; + vec[lane_id] = + (g_j < m && g_s < k) + ? static_cast( + rhs[rhs_indexer(g_s * b_st0 + g_j * b_st1)]) + : resT(0); + } + + local_B_block[vid] = vec; + } + + it.barrier(sycl::access::fence_space::local_space); + + i += local_i * wi_delta_n; + j += local_j * wi_delta_m; + + size_t a_offset = local_i * wi_delta_k * wi_delta_n; + size_t b_offset = local_j * wi_delta_k; + + constexpr resT identity_(0); + for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { + size_t a_pr_offset = private_i * wi_delta_k; + + sycl::vec local_sum(identity_); + for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { + local_sum = local_sum + + (local_A_block[a_offset + a_pr_offset + private_s] * + local_B_block[b_offset + private_s]); + } + + size_t gl_i = i + private_i; + +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { + size_t gl_j = j + lane_id; + + if (gl_i < n && gl_j < m) { + sycl::atomic_ref + aout(res[res_indexer(gl_i * c_st0 + gl_j * c_st1)]); + + aout += local_sum[lane_id]; + } + } + } + } +}; + +// specialization for wi_delta_m == 1 +template +class GemmFunctorThreadNM +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT1 local_A_block; + LocAccT2 local_B_block; + size_t n = 0; + size_t wg_delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t wi_delta_k = 0; + size_t m = 0; + size_t m_blocks = 0; + size_t wg_delta_m = 0; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + OuterInnerDimsIndexerT res_indexer; + +public: + GemmFunctorThreadNM(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_A_block_, + LocAccT2 local_B_block_, + size_t n_, + size_t wg_delta_n_, + size_t k_, + size_t k_blocks_, + size_t wi_delta_k_, + size_t m_, + size_t m_blocks_, + size_t wg_delta_m_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + OuterInnerDimsIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), + local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), + k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), + m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), + lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), + res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t gr_id = it.get_group_linear_id(); + // lift group_id to (block_i, block_j, block_s), + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s < + // k_blocks + size_t block_i = gr_id / (m_blocks * k_blocks); + size_t block_r = gr_id - block_i * (m_blocks * k_blocks); + size_t block_j = block_r / k_blocks; + size_t block_s = block_r - block_j * k_blocks; + + size_t lid = it.get_local_linear_id(); + size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n + size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < 
wg_delta_m + + // load A block and B blocks into SLM + + size_t i = block_i * wi_delta_n * wg_delta_n; + size_t j = block_j * wg_delta_m; + size_t s = block_s * wi_delta_k; + + const std::int64_t a_st0 = k; + const std::int64_t a_st1 = 1; + + const std::int64_t b_st0 = m; + const std::int64_t b_st1 = 1; + + const std::int64_t c_st0 = m; + const std::int64_t c_st1 = 1; + + size_t lws = it.get_local_range(0); + + for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { + size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n + size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_i = i + v_i; + size_t g_s = s + v_s; + + local_A_block[vid] = + (g_i < n && g_s < k) + ? static_cast( + lhs[lhs_indexer(g_i * a_st0 + g_s * a_st1)]) + : resT(0); + } + + for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { + size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m + size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_j0 = j + v_j; + size_t g_s = s + v_s; + + resT val = (g_j0 < m && g_s < k) + ? static_cast( + rhs[rhs_indexer(g_s * b_st0 + g_j0 * b_st1)]) + : resT(0); + + local_B_block[vid] = val; + } + + it.barrier(sycl::access::fence_space::local_space); + + i += local_i * wi_delta_n; + j += local_j; + + size_t a_offset = local_i * wi_delta_k * wi_delta_n; + size_t b_offset = local_j * wi_delta_k; + + constexpr resT identity_(0); + for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { + size_t a_pr_offset = private_i * wi_delta_k; + + resT local_sum(identity_); + for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { + local_sum = local_sum + + (local_A_block[a_offset + a_pr_offset + private_s] * + local_B_block[b_offset + private_s]); + } + + size_t gl_i = i + private_i; + + if (gl_i < n && j < m) { + sycl::atomic_ref + aout(res[res_indexer(gl_i * c_st0 + j * c_st1)]); + + aout += local_sum; + } + } + } +}; + +template +class GemmFunctorThreadK +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT workspace; + LocAccT local_B_block; + size_t n = 0; + size_t n_blocks = 0; + size_t delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t delta_k = 0; + size_t n_wi = 0; + size_t m = 0; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + OuterInnerDimsIndexerT res_indexer; + +public: + GemmFunctorThreadK(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT workspace_, + LocAccT local_B_block_, + size_t n_, + size_t n_blocks_, + size_t delta_n_, + size_t k_, + size_t k_blocks_, + size_t delta_k_, + size_t n_wi_, + size_t m_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + OuterInnerDimsIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), + local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), + delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), + n_wi(n_wi_), m(m_), lhs_indexer(lhs_indexer_), + rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t gr_id = it.get_group_linear_id(); + size_t lid = it.get_local_linear_id(); + + // lift gr_id -> (block_i, block_j, block_s) + // block_i moves fastest, then block_s, then block_j + + size_t block_j = + gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks + size_t block_r = + gr_id - block_j * (n_blocks * + k_blocks); // 0 <= block_r < n_blocks * k_blocks + size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks + 
size_t block_i = + block_r - block_s * n_blocks; // 0 <= block_i < n_blocks + + size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n + size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k + + size_t i = block_i * delta_n + local_i; + size_t j = m_groups * block_j; + size_t s = block_s * delta_k * n_wi + local_s; + + constexpr resT identity_ = resT(0); + if (local_i == 0) { + for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { + size_t sq = s + q; + size_t sqmj = sq * m + j; + sycl::vec local_B_vec; +#pragma unroll + for (size_t vec_idx = 0; vec_idx < m_groups; ++vec_idx) { + local_B_vec[vec_idx] = + (sq < k && j + vec_idx < m) + ? static_cast( + rhs[rhs_indexer(sqmj + vec_idx)]) + : identity_; + } + local_B_block[local_s + q] = local_B_vec; + } + } + + it.barrier(sycl::access::fence_space::local_space); + + size_t t_shift = block_s * delta_k * n_wi; + size_t global_s_offset = i * k + t_shift; + + sycl::vec private_sum(identity_); + constexpr sycl::vec vec_identity_(identity_); + for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { + private_sum += ((i < n) && (t + t_shift < k)) + ? (static_cast( + lhs[lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : vec_identity_; + } + + size_t workspace_i_shift = local_i * delta_k; + workspace[workspace_i_shift + local_s] = private_sum; + + it.barrier(sycl::access::fence_space::local_space); + + if (local_s == 0 && i < n) { + sycl::vec local_sum(workspace[workspace_i_shift]); + for (size_t t = 1; t < delta_k; ++t) { + local_sum += workspace[workspace_i_shift + t]; + } + + sycl::atomic_ref + aout0(res[res_indexer(i * m + j)]); + + aout0 += local_sum[0]; + +#pragma unroll + for (size_t vec_id = 1; vec_id < m_groups; ++vec_id) { + if (j + vec_id < m) { + sycl::atomic_ref + aout1(res[res_indexer(i * m + j + vec_id)]); + + aout1 += local_sum[vec_id]; + } + } + } + } +}; + +// specialization for m_groups == 1 +template +class GemmFunctorThreadK +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT workspace; + LocAccT local_B_block; + size_t n = 0; + size_t n_blocks = 0; + size_t delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t delta_k = 0; + size_t n_wi = 0; + size_t m = 0; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + OuterInnerDimsIndexerT res_indexer; + +public: + GemmFunctorThreadK(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT workspace_, + LocAccT local_B_block_, + size_t n_, + size_t n_blocks_, + size_t delta_n_, + size_t k_, + size_t k_blocks_, + size_t delta_k_, + size_t n_wi_, + size_t m_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + OuterInnerDimsIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), + local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), + delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), + n_wi(n_wi_), m(m_), lhs_indexer(lhs_indexer_), + rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t gr_id = it.get_group_linear_id(); + size_t lid = it.get_local_linear_id(); + + // lift gr_id -> (block_i, block_j, block_s) + // block_i moves fastest, then block_s, then block_j + + size_t block_j = + gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks + size_t block_r = + gr_id - block_j * (n_blocks * + k_blocks); // 0 <= block_r < n_blocks * k_blocks + size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks + size_t 
block_i = + block_r - block_s * n_blocks; // 0 <= block_i < n_blocks + + size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n + size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k + + size_t i = block_i * delta_n + local_i; + size_t j = block_j; + size_t s = block_s * delta_k * n_wi + local_s; + + constexpr resT identity_ = resT(0); + if (local_i == 0) { + for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { + size_t sq = s + q; + size_t sqmj = sq * m + j; + local_B_block[local_s + q] = + (sq < k && j < m) + ? static_cast(rhs[rhs_indexer(sqmj)]) + : identity_; + } + } + + it.barrier(sycl::access::fence_space::local_space); + + size_t t_shift = block_s * delta_k * n_wi; + size_t global_s_offset = i * k + t_shift; + + resT private_sum(identity_); + for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { + private_sum += ((i < n) && (t + t_shift < k)) + ? (static_cast( + lhs[lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : identity_; + } + + size_t workspace_i_shift = local_i * delta_k; + workspace[workspace_i_shift + local_s] = private_sum; + + it.barrier(sycl::access::fence_space::local_space); + + if (local_s == 0 && i < n) { + resT local_sum(workspace[workspace_i_shift]); + for (size_t t = 1; t < delta_k; ++t) { + local_sum += workspace[workspace_i_shift + t]; + } + + sycl::atomic_ref + aout(res[res_indexer(i * m + j)]); + + aout += local_sum; + } + } +}; + +template class gemm_init_krn; + +template +class gemm_k_krn; + +template +class gemm_nm_krn; + +typedef sycl::event (*gemm_impl_fn_ptr_t)( + sycl::queue &, + const char *, // lhs + const char *, // rhs + char *, // res + size_t, // lhs_outer_nelems (n) + size_t, // inner_nelems (k) + size_t, // rhs_outer_nelems (m) + int, // inner nd + int, // lhs outer nd + const py::ssize_t *, // lhs shape and strides + int, // rhs outer nd + const py::ssize_t *, // rhs shape and strides + int, // res outer nd + const py::ssize_t *, // res shape and strides + std::vector const &); + +template +sycl::event gemm_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t n, + size_t k, + size_t m, + int inner_nd, + int lhs_outer_nd, + const py::ssize_t *lhs_shape_strides, + int rhs_outer_nd, + const py::ssize_t *rhs_shape_strides, + int res_outer_nd, + const py::ssize_t *res_shape_strides, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using IndexerT = dpctl::tensor::offset_utils::StridedIndexer; + IndexerT res_indexer(res_outer_nd, 0, res_shape_strides); + using InitKernelName = class gemm_init_krn; + cgh.parallel_for( + sycl::range<1>(n * m), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = resTy(0); + }); + }); + + if (k == 0) { + return res_init_ev; + } + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); + + using OuterInnerIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_shape_strides); + OuterInnerIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_shape_strides); + OuterInnerIndexerT res_indexer(res_outer_nd, 0, res_shape_strides); + + 
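+        // Kernel dispatch heuristic for the three branches below: a narrow
+        // result (m < 4) uses the K-split kernel with scalar accumulation
+        // (m_groups = 1); a deep contraction (k > n && k > m) uses the
+        // K-split kernel vectorized over 4 result columns (m_groups = 4);
+        // otherwise the SLM-blocked NM kernel is used.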
if (m < 4) { + constexpr size_t m_groups = 1; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_k_krn; + cgh.parallel_for( + ndRange, GemmFunctorThreadK( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, + m, lhs_indexer, rhs_indexer, res_indexer)); + } + else if (k > n && k > m) { + constexpr size_t m_groups = 4; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_k_krn; + cgh.parallel_for( + ndRange, GemmFunctorThreadK( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, + m, lhs_indexer, rhs_indexer, res_indexer)); + } + else { + constexpr int wi_delta_n = 2; + constexpr int wi_delta_m = 4; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = + class gemm_nm_krn; + cgh.parallel_for( + ndRange, + GemmFunctorThreadNM( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); + } + }); + + return gemm_ev; +} + +typedef sycl::event 
(*gemm_contig_impl_fn_ptr_t)( + sycl::queue &, + const char *, // lhs + const char *, // rhs + char *, // res + size_t, // n + size_t, // k + size_t, // m + std::vector const &); + +template +sycl::event gemm_contig_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t n, + size_t k, + size_t m, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.fill(res_tp, resTy(0), n * m); + }); + + if (k == 0) { + return res_init_ev; + } + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); + + using OuterInnerIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerIndexerT lhs_indexer{}; + OuterInnerIndexerT rhs_indexer{}; + OuterInnerIndexerT res_indexer{}; + + if (m < 4) { + constexpr size_t m_groups = 1; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_k_krn; + cgh.parallel_for( + ndRange, GemmFunctorThreadK( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, + m, lhs_indexer, rhs_indexer, res_indexer)); + } + else if (k > n && k > m) { + constexpr size_t m_groups = 4; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_k_krn; + cgh.parallel_for( + ndRange, GemmFunctorThreadK( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, + m, lhs_indexer, rhs_indexer, res_indexer)); + } + else { + constexpr int wi_delta_n = 2; + constexpr int wi_delta_m = 4; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference 
+ wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = + class gemm_nm_krn; + cgh.parallel_for( + ndRange, + GemmFunctorThreadNM( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); + } + }); + + return gemm_ev; +} + +template +class GemmNoAtomicFunctorThreadNM +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT1 local_A_block; + LocAccT2 local_B_block; + size_t n = 0; + size_t wg_delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t wi_delta_k = 0; + size_t m = 0; + size_t m_blocks = 0; + size_t wg_delta_m = 0; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + ResIndexerT res_indexer; + +public: + GemmNoAtomicFunctorThreadNM(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_A_block_, + LocAccT2 local_B_block_, + size_t n_, + size_t wg_delta_n_, + size_t k_, + size_t k_blocks_, + size_t wi_delta_k_, + size_t m_, + size_t m_blocks_, + size_t wg_delta_m_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + ResIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), + local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), + k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), + m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), + lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), + res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t gr_id = it.get_group_linear_id(); + // lift group_id to (block_i, block_j, block_s), + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s < + // k_blocks + size_t block_i = gr_id / (m_blocks * k_blocks); + size_t block_r = gr_id - block_i * (m_blocks * k_blocks); + size_t block_j = block_r / k_blocks; + size_t block_s = block_r - block_j * k_blocks; + + size_t lid = it.get_local_linear_id(); + size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n + size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m + + // load A block and B blocks into SLM + + size_t i = block_i * wi_delta_n * wg_delta_n; + size_t j = block_j * wi_delta_m * wg_delta_m; + size_t s = block_s * wi_delta_k; + + const std::int64_t a_st0 = k; + const std::int64_t a_st1 = 1; + + const std::int64_t b_st0 = m; + const std::int64_t b_st1 = 1; + + const std::int64_t c_st0 = m; + const std::int64_t c_st1 = 1; + + size_t lws = it.get_local_range(0); + + for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { + size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n + size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_i = i + v_i; + size_t g_s = s + v_s; 
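+            // out-of-bounds elements (g_i >= n or g_s >= k) are stored as
+            // resT(0) below so they do not perturb the accumulated sums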
+ + local_A_block[vid] = + (g_i < n && g_s < k) + ? static_cast( + lhs[lhs_indexer(g_i * a_st0 + g_s * a_st1)]) + : resT(0); + } + + for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { + size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m + size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_j0 = j + v_j * wi_delta_m; + size_t g_s = s + v_s; + + sycl::vec vec{}; +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { + size_t g_j = g_j0 + lane_id; + vec[lane_id] = + (g_j < m && g_s < k) + ? static_cast( + rhs[rhs_indexer(g_s * b_st0 + g_j * b_st1)]) + : resT(0); + } + + local_B_block[vid] = vec; + } + + it.barrier(sycl::access::fence_space::local_space); + + i += local_i * wi_delta_n; + j += local_j * wi_delta_m; + + size_t a_offset = local_i * wi_delta_k * wi_delta_n; + size_t b_offset = local_j * wi_delta_k; + + constexpr resT identity_(0); + + for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { + size_t a_pr_offset = private_i * wi_delta_k; + + sycl::vec local_sum(identity_); + for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { + local_sum = local_sum + + (local_A_block[a_offset + a_pr_offset + private_s] * + local_B_block[b_offset + private_s]); + } + + size_t gl_i = i + private_i; + +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { + size_t gl_j = j + lane_id; + + if (gl_i < n && gl_j < m) { + res[res_indexer(gl_i * c_st0 + gl_j * c_st1 + + block_s * n * m)] = local_sum[lane_id]; + } + } + } + } +}; + +template +class GemmNoAtomicFunctorThreadNM +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT1 local_A_block; + LocAccT2 local_B_block; + size_t n = 0; + size_t wg_delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t wi_delta_k = 0; + size_t m = 0; + size_t m_blocks = 0; + size_t wg_delta_m = 0; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + ResIndexerT res_indexer; + +public: + GemmNoAtomicFunctorThreadNM(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_A_block_, + LocAccT2 local_B_block_, + size_t n_, + size_t wg_delta_n_, + size_t k_, + size_t k_blocks_, + size_t wi_delta_k_, + size_t m_, + size_t m_blocks_, + size_t wg_delta_m_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + ResIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), + local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), + k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), + m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), + lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), + res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t gr_id = it.get_group_linear_id(); + // lift group_id to (block_i, block_j, block_s), + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s < + // k_blocks + size_t block_i = gr_id / (m_blocks * k_blocks); + size_t block_r = gr_id - block_i * (m_blocks * k_blocks); + size_t block_j = block_r / k_blocks; + size_t block_s = block_r - block_j * k_blocks; + + size_t lid = it.get_local_linear_id(); + size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n + size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m + + // load A block and B blocks into SLM + + size_t i = block_i * wi_delta_n * wg_delta_n; + size_t j = block_j * wg_delta_m; + size_t s = block_s * wi_delta_k; + + const 
std::int64_t a_st0 = k; + const std::int64_t a_st1 = 1; + + const std::int64_t b_st0 = m; + const std::int64_t b_st1 = 1; + + const std::int64_t c_st0 = m; + const std::int64_t c_st1 = 1; + + size_t lws = it.get_local_range(0); + + for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { + size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n + size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_i = i + v_i; + size_t g_s = s + v_s; + + local_A_block[vid] = + (g_i < n && g_s < k) + ? static_cast( + lhs[lhs_indexer(g_i * a_st0 + g_s * a_st1)]) + : resT(0); + } + + for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { + size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m + size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_j0 = j + v_j; + size_t g_s = s + v_s; + + resT val = (g_j0 < m && g_s < k) + ? static_cast( + rhs[rhs_indexer(g_s * b_st0 + g_j0 * b_st1)]) + : resT(0); + + local_B_block[vid] = val; + } + + it.barrier(sycl::access::fence_space::local_space); + + i += local_i * wi_delta_n; + j += local_j; + + size_t a_offset = local_i * wi_delta_k * wi_delta_n; + size_t b_offset = local_j * wi_delta_k; + + constexpr resT identity_(0); + + for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { + size_t a_pr_offset = private_i * wi_delta_k; + + resT local_sum(identity_); + for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { + local_sum = local_sum + + (local_A_block[a_offset + a_pr_offset + private_s] * + local_B_block[b_offset + private_s]); + } + + size_t gl_i = i + private_i; + + if (gl_i < n && j < m) { + res[res_indexer(gl_i * c_st0 + j * c_st1 + block_s * n * m)] = + local_sum; + } + } + } +}; + +template +class GemmNoAtomicFunctorThreadK +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT workspace; + LocAccT local_B_block; + size_t n = 0; + size_t n_blocks = 0; + size_t delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t delta_k = 0; + size_t n_wi = 0; + size_t m = 0; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + ResIndexerT res_indexer; + +public: + GemmNoAtomicFunctorThreadK(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT workspace_, + LocAccT local_B_block_, + size_t n_, + size_t n_blocks_, + size_t delta_n_, + size_t k_, + size_t k_blocks_, + size_t delta_k_, + size_t n_wi_, + size_t m_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + ResIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), + local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), + delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), + n_wi(n_wi_), m(m_), lhs_indexer(lhs_indexer_), + rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t gr_id = it.get_group_linear_id(); + size_t lid = it.get_local_linear_id(); + + // lift gr_id -> (block_i, block_j, block_s) + // block_i moves fastest, then block_s, then block_j + + size_t block_j = + gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks + size_t block_r = + gr_id - block_j * (n_blocks * + k_blocks); // 0 <= block_r < n_blocks * k_blocks + size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks + size_t block_i = + block_r - block_s * n_blocks; // 0 <= block_i < n_blocks + + size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n + size_t local_s = lid - local_i * (delta_k); // 0 <= local_s 
< delta_k + + size_t i = block_i * delta_n + local_i; + size_t j = m_groups * block_j; + size_t s = block_s * delta_k * n_wi + local_s; + + constexpr resT identity_ = resT(0); + if (local_i == 0) { + for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { + size_t sq = s + q; + size_t sqmj = sq * m + j; + sycl::vec local_B_vec; +#pragma unroll + for (size_t vec_idx = 0; vec_idx < m_groups; ++vec_idx) { + local_B_vec[vec_idx] = + (sq < k && j + vec_idx < m) + ? static_cast( + rhs[rhs_indexer(sqmj + vec_idx)]) + : identity_; + } + local_B_block[local_s + q] = local_B_vec; + } + } + + it.barrier(sycl::access::fence_space::local_space); + + size_t t_shift = block_s * delta_k * n_wi; + size_t global_s_offset = i * k + t_shift; + + sycl::vec private_sum(identity_); + constexpr sycl::vec vec_identity_(identity_); + for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { + private_sum += ((i < n) && (t + t_shift < k)) + ? (static_cast( + lhs[lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : vec_identity_; + } + + size_t workspace_i_shift = local_i * delta_k; + workspace[workspace_i_shift + local_s] = private_sum; + + it.barrier(sycl::access::fence_space::local_space); + + if (local_s == 0 && i < n) { + sycl::vec local_sum(workspace[workspace_i_shift]); + for (size_t t = 1; t < delta_k; ++t) { + local_sum += workspace[workspace_i_shift + t]; + } + + const size_t res_offset = (block_s * n * m); + res[res_indexer(i * m + j) + res_offset] = local_sum[0]; + +#pragma unroll + for (size_t vec_id = 1; vec_id < m_groups; ++vec_id) { + if (j + vec_id < m) { + res[res_indexer(i * m + j + vec_id) + res_offset] = + local_sum[vec_id]; + } + } + } + } +}; + +template +class GemmNoAtomicFunctorThreadK +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT workspace; + LocAccT local_B_block; + size_t n = 0; + size_t n_blocks = 0; + size_t delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t delta_k = 0; + size_t n_wi = 0; + size_t m = 0; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + ResIndexerT res_indexer; + +public: + GemmNoAtomicFunctorThreadK(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT workspace_, + LocAccT local_B_block_, + size_t n_, + size_t n_blocks_, + size_t delta_n_, + size_t k_, + size_t k_blocks_, + size_t delta_k_, + size_t n_wi_, + size_t m_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + ResIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), + local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), + delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), + n_wi(n_wi_), m(m_), lhs_indexer(lhs_indexer_), + rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t gr_id = it.get_group_linear_id(); + size_t lid = it.get_local_linear_id(); + + // lift gr_id -> (block_i, block_j, block_s) + // block_i moves fastest, then block_s, then block_j + + size_t block_j = + gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks + size_t block_r = + gr_id - block_j * (n_blocks * + k_blocks); // 0 <= block_r < n_blocks * k_blocks + size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks + size_t block_i = + block_r - block_s * n_blocks; // 0 <= block_i < n_blocks + + size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n + size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k + + size_t i = block_i * delta_n + local_i; 
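+        // In this single-column variant each work-item accumulates one
+        // element of the result row, so j is simply the block index
+        // (contrast with j = m_groups * block_j in the vectorized functor
+        // above).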
+ size_t j = block_j; + size_t s = block_s * delta_k * n_wi + local_s; + + constexpr resT identity_ = resT(0); + if (local_i == 0) { + for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { + size_t sq = s + q; + size_t sqmj = sq * m + j; + local_B_block[local_s + q] = + (sq < k && j < m) + ? static_cast(rhs[rhs_indexer(sqmj)]) + : identity_; + } + } + + it.barrier(sycl::access::fence_space::local_space); + + size_t t_shift = block_s * delta_k * n_wi; + size_t global_s_offset = i * k + t_shift; + + resT private_sum(identity_); + for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { + private_sum += ((i < n) && (t + t_shift < k)) + ? (static_cast( + lhs[lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : identity_; + } + + size_t workspace_i_shift = local_i * delta_k; + workspace[workspace_i_shift + local_s] = private_sum; + + it.barrier(sycl::access::fence_space::local_space); + + if (local_s == 0 && i < n) { + resT local_sum(workspace[workspace_i_shift]); + for (size_t t = 1; t < delta_k; ++t) { + local_sum += workspace[workspace_i_shift + t]; + } + + res[res_indexer(i * m + j) + (block_s * n * m)] = local_sum; + } + } +}; + +template +class gemm_tree_nm_krn; + +template +class gemm_tree_k_krn; + +template +sycl::event gemm_tree_k_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t n, + size_t k, + size_t m, + int inner_nd, + int lhs_outer_nd, + const py::ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const py::ssize_t *rhs_outer_inner_shapes_strides, + int res_nd, + const py::ssize_t *res_shapes_strides, + const std::vector &depends) +{ + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + sycl::event gemm_ev; + if (k <= (delta_k * n_wi)) { + gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT res_indexer(res_nd, 0, res_shapes_strides); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + 
class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, res_indexer)); + } + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = n * m; + size_t reduction_nelems = (k + delta_k * n_wi - 1) / (delta_k * n_wi); + + // more than one work-groups is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + constexpr size_t preferred_reductions_per_wi = 8; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + if (!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + ResIndexerT res_indexer{}; + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, res_indexer)); + } + }); + + sycl::event red_ev = single_reduction_for_gemm( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, res_nd, 0, + res_shapes_strides, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + 
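+                    // The temporary buffer holding per-k-block partial
+                    // results is released only after the reduction event
+                    // (red_ev) completes; the host task's event is returned
+                    // so callers can also depend on the cleanup.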
cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + else { + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + ResIndexerT res_indexer{}; + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, lhs_indexer, rhs_indexer, + res_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, lhs_indexer, rhs_indexer, + res_indexer)); + } + }); + // tree_reduction_for_gemm returns sycl::event for reduction + sycl::event red_ev = tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, + identity_val, iter_nelems, reduction_nelems, reduction_groups, + wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, + res_nd, 0, res_shapes_strides, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } +} + +template +sycl::event gemm_tree_nm_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t n, + size_t k, + size_t m, + int inner_nd, + int lhs_outer_nd, + const py::ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const py::ssize_t *rhs_outer_inner_shapes_strides, + int res_nd, + const py::ssize_t *res_shapes_strides, + const std::vector &depends) +{ + constexpr int wi_delta_n = 2; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI + + const 
sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + // each group processes delta_k items in a column, + // so no need to allocate temp memory if one group needed + if (k <= wi_delta_k) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT res_indexer(res_nd, 0, res_shapes_strides); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); + + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); + } + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = n * m; + size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; + + // more than one work-groups is needed, requires a temporary + // wi_delta_k elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + constexpr size_t preferred_reductions_per_wi = 8; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = 
std::min( + max_max_wg, + dev.get_info() / 2); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + if (!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + ResIndexerT res_indexer{}; + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); + + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, ResIndexerT, wi_delta_n, + wi_delta_m>( + lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, + n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, ResIndexerT, wi_delta_n, + wi_delta_m>( + lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, + n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); + } + }); + + sycl::event red_ev = single_reduction_for_gemm( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, res_nd, 0, + res_shapes_strides, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + else { + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + 
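+                // Each k-block writes its partial tile into
+                // partially_reduced_tmp; those partial results are then
+                // combined by tree_reduction_for_gemm below before being
+                // written to res_tp.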
cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + ResIndexerT res_indexer{}; + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); + + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, ResIndexerT, wi_delta_n, + wi_delta_m>(lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, + m_blocks, wg_delta_m, lhs_indexer, + rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, ResIndexerT, wi_delta_n, + wi_delta_m>(lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, + m_blocks, wg_delta_m, lhs_indexer, + rhs_indexer, res_indexer)); + } + }); + + sycl::event red_ev = tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, + identity_val, iter_nelems, reduction_nelems, reduction_groups, + wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, + res_nd, 0, res_shapes_strides, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } +} + +template class gemm_tree_empty_krn; + +template +sycl::event gemm_tree_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t n, + size_t k, + size_t m, + int inner_nd, + int lhs_outer_nd, + const py::ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const py::ssize_t *rhs_outer_inner_shapes_strides, + int res_nd, + const py::ssize_t *res_shapes_strides, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + if (k == 0) { + sycl::event gemm_no_reduction_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using IndexerT = 
dpctl::tensor::offset_utils::StridedIndexer; + IndexerT res_indexer(res_nd, 0, res_shapes_strides); + using InitKernelName = + class gemm_tree_empty_krn; + cgh.parallel_for( + sycl::range<1>(n * m), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = resTy(0); + }); + }); + return gemm_no_reduction_ev; + } + + if ((k > n && k > m) || m < 4) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + if (m < 4) { + return gemm_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, + lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); + } + else { + return gemm_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, + lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); + } + } + else { + return gemm_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); + } + } + else { // m > 1, n > k or m > k + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + return gemm_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); + } + else { + return gemm_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); + } + } +} + +template +sycl::event gemm_contig_tree_k_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t n, + size_t k, + size_t m, + std::vector const &depends) +{ + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + sycl::event gemm_ev; + if (k <= (delta_k * n_wi)) { + gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT = 
+ sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, res_indexer)); + } + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = n * m; + size_t reduction_nelems = (k + delta_k * n_wi - 1) / (delta_k * n_wi); + + // more than one work-groups is needed, requires a + // temporary delta_k * n_wi elements processed along k, + // so if more to process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + constexpr size_t preferred_reductions_per_wi = 8; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + if (!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, res_indexer)); + } + }); + + sycl::event red_ev = + single_reduction_for_gemm_contig( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + 
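+            // The cleanup event depends on red_ev, which in turn depends on
+            // gemm_ev, so returning it orders any subsequent work after the
+            // full GEMM + reduction + deallocation chain.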
return cleanup_host_task_event; + } + else { + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, lhs_indexer, rhs_indexer, + res_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, lhs_indexer, rhs_indexer, + res_indexer)); + } + }); + // tree_reduction_for_gemm_contig returns sycl::event + // for reduction + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, preferred_reductions_per_wi, + reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } +} + +template +sycl::event gemm_contig_tree_nm_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t n, + size_t k, + size_t m, + std::vector const &depends) +{ + constexpr int wi_delta_n = 2; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + // each group processes delta_k items in a column, + // so no need to allocate temp memory if one group needed 
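+    // When k fits in a single k-block (k <= wi_delta_k) results are written
+    // directly to res_tp; otherwise each k-block produces a partial result
+    // that is reduced in a second step.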
+ if (k <= wi_delta_k) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); + + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); + } + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = n * m; + size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; + + // more than one work-groups is needed, requires a temporary + // wi_delta_k elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + constexpr size_t preferred_reductions_per_wi = 8; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + if (!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + + size_t lws = wg_delta_n * wg_delta_m; + + 
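+                // 1D launch of n_blocks * m_blocks * k_blocks work-groups of
+                // lws work-items each; every k-block deposits its partial
+                // tile of the n x m result into tmp for the follow-up
+                // reduction.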
size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); + + using KernelName = class gemm_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, + n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, + n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); + } + }); + + sycl::event red_ev = + single_reduction_for_gemm_contig( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + else { + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = 
sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); + + using KernelName = class gemm_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, + lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, + lhs_indexer, rhs_indexer, res_indexer)); + } + }); + + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, preferred_reductions_per_wi, + reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } +} + +template +sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t n, + size_t k, + size_t m, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + if (k == 0) { + sycl::event gemm_no_reduction_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.fill(res_tp, resTy(0), n * m); + }); + return gemm_no_reduction_ev; + } + + if ((k > n && k > m) || m < 4) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + if (m < 4) { + return gemm_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); + } + else { + return gemm_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); + } + } + else { + return gemm_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); + } + } + else { // m > 1, n > k or m > k + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + return gemm_contig_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); + } + else { + return gemm_contig_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); + } + } +} + +template +class GemmBatchFunctorThreadNM +{ +private: + const lhsT *lhs = nullptr; + const 
rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT1 local_A_block; + LocAccT2 local_B_block; + size_t n = 0; + size_t wg_delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t wi_delta_k = 0; + size_t m = 0; + size_t m_blocks = 0; + size_t wg_delta_m = 0; + size_t batch_nelems; + BatchDimsIndexerT batch_indexer; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + OuterInnerDimsIndexerT res_indexer; + +public: + GemmBatchFunctorThreadNM(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_A_block_, + LocAccT2 local_B_block_, + size_t n_, + size_t wg_delta_n_, + size_t k_, + size_t k_blocks_, + size_t wi_delta_k_, + size_t m_, + size_t m_blocks_, + size_t wg_delta_m_, + size_t batch_nelems_, + BatchDimsIndexerT batch_indexer_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + OuterInnerDimsIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), + local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), + k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), + m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), + batch_nelems(batch_nelems_), batch_indexer(batch_indexer_), + lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), + res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; + const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; + const size_t gr_id = + it.get_group_linear_id() - m_id * n_groups_per_batch; + + const auto &three_offsets_ = + batch_indexer(static_cast(m_id)); + + // lift group_id to (block_i, block_j, block_s), + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s + // < k_blocks + + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + size_t block_i = gr_id / (m_blocks * k_blocks); + size_t block_r = gr_id - block_i * (m_blocks * k_blocks); + size_t block_j = block_r / k_blocks; + size_t block_s = block_r - block_j * k_blocks; + + size_t lid = it.get_local_linear_id(); + size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n + size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m + + // load A block and B blocks into SLM + + size_t i = block_i * wi_delta_n * wg_delta_n; + size_t j = block_j * wi_delta_m * wg_delta_m; + size_t s = block_s * wi_delta_k; + + const std::int64_t a_st0 = k; + const std::int64_t a_st1 = 1; + + const std::int64_t b_st0 = m; + const std::int64_t b_st1 = 1; + + const std::int64_t c_st0 = m; + const std::int64_t c_st1 = 1; + + size_t lws = it.get_local_range(0); + + for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { + size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n + size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_i = i + v_i; + size_t g_s = s + v_s; + + local_A_block[vid] = + (g_i < n && g_s < k) + ? 
static_cast( + lhs[lhs_offset + + lhs_indexer(g_i * a_st0 + g_s * a_st1)]) + : resT(0); + } + + for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { + size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m + size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_j0 = j + v_j * wi_delta_m; + size_t g_s = s + v_s; + + sycl::vec vec{}; +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { + size_t g_j = g_j0 + lane_id; + vec[lane_id] = + (g_j < m && g_s < k) + ? static_cast( + rhs[rhs_offset + + rhs_indexer(g_s * b_st0 + g_j * b_st1)]) + : resT(0); + } + + local_B_block[vid] = vec; + } + + it.barrier(sycl::access::fence_space::local_space); + + i += local_i * wi_delta_n; + j += local_j * wi_delta_m; + + size_t a_offset = local_i * wi_delta_k * wi_delta_n; + size_t b_offset = local_j * wi_delta_k; + + constexpr resT identity_(0); + + for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { + size_t a_pr_offset = private_i * wi_delta_k; + + sycl::vec local_sum(identity_); + for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { + local_sum = local_sum + + (local_A_block[a_offset + a_pr_offset + private_s] * + local_B_block[b_offset + private_s]); + } + + size_t gl_i = i + private_i; + +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { + size_t gl_j = j + lane_id; + + if (gl_i < n && gl_j < m) { + sycl::atomic_ref + aout(res[res_offset + + res_indexer(gl_i * c_st0 + gl_j * c_st1)]); + + aout += local_sum[lane_id]; + } + } + } + } +}; + +template +class GemmBatchFunctorThreadNM +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT1 local_A_block; + LocAccT2 local_B_block; + size_t n = 0; + size_t wg_delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t wi_delta_k = 0; + size_t m = 0; + size_t m_blocks = 0; + size_t wg_delta_m = 0; + size_t batch_nelems; + BatchDimsIndexerT batch_indexer; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + OuterInnerDimsIndexerT res_indexer; + +public: + GemmBatchFunctorThreadNM(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_A_block_, + LocAccT2 local_B_block_, + size_t n_, + size_t wg_delta_n_, + size_t k_, + size_t k_blocks_, + size_t wi_delta_k_, + size_t m_, + size_t m_blocks_, + size_t wg_delta_m_, + size_t batch_nelems_, + BatchDimsIndexerT batch_indexer_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + OuterInnerDimsIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), + local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), + k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), + m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), + batch_nelems(batch_nelems_), batch_indexer(batch_indexer_), + lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), + res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; + const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; + const size_t gr_id = + it.get_group_linear_id() - m_id * n_groups_per_batch; + + const auto &three_offsets_ = + batch_indexer(static_cast(m_id)); + + // lift group_id to (block_i, block_j, block_s), + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s + // < k_blocks + + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = 
three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + size_t block_i = gr_id / (m_blocks * k_blocks); + size_t block_r = gr_id - block_i * (m_blocks * k_blocks); + size_t block_j = block_r / k_blocks; + size_t block_s = block_r - block_j * k_blocks; + + size_t lid = it.get_local_linear_id(); + size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n + size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m + + // load A block and B blocks into SLM + + size_t i = block_i * wi_delta_n * wg_delta_n; + size_t j = block_j * wg_delta_m; + size_t s = block_s * wi_delta_k; + + const std::int64_t a_st0 = k; + const std::int64_t a_st1 = 1; + + const std::int64_t b_st0 = m; + const std::int64_t b_st1 = 1; + + const std::int64_t c_st0 = m; + const std::int64_t c_st1 = 1; + + size_t lws = it.get_local_range(0); + + for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { + size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n + size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_i = i + v_i; + size_t g_s = s + v_s; + + local_A_block[vid] = + (g_i < n && g_s < k) + ? static_cast( + lhs[lhs_offset + + lhs_indexer(g_i * a_st0 + g_s * a_st1)]) + : resT(0); + } + + for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { + size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m + size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_j0 = j + v_j; + size_t g_s = s + v_s; + + resT val = (g_j0 < m && g_s < k) + ? static_cast( + rhs[rhs_offset + + rhs_indexer(g_s * b_st0 + g_j0 * b_st1)]) + : resT(0); + + local_B_block[vid] = val; + } + + it.barrier(sycl::access::fence_space::local_space); + + i += local_i * wi_delta_n; + j += local_j; + + size_t a_offset = local_i * wi_delta_k * wi_delta_n; + size_t b_offset = local_j * wi_delta_k; + + constexpr resT identity_(0); + for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { + size_t a_pr_offset = private_i * wi_delta_k; + + resT local_sum(identity_); + for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { + local_sum = local_sum + + (local_A_block[a_offset + a_pr_offset + private_s] * + local_B_block[b_offset + private_s]); + } + + size_t gl_i = i + private_i; + + if (gl_i < n && j < m) { + sycl::atomic_ref + aout(res[res_offset + + res_indexer(gl_i * c_st0 + j * c_st1)]); + + aout += local_sum; + } + } + } +}; + +template +class GemmBatchFunctorThreadK +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT workspace; + LocAccT local_B_block; + size_t n = 0; + size_t n_blocks = 0; + size_t delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t delta_k = 0; + size_t n_wi = 0; + size_t m = 0; + size_t batch_nelems = 0; + BatchDimsIndexerT batch_indexer; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + OuterInnerDimsIndexerT res_indexer; + +public: + GemmBatchFunctorThreadK(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT workspace_, + LocAccT local_B_block_, + size_t n_, + size_t n_blocks_, + size_t delta_n_, + size_t k_, + size_t k_blocks_, + size_t delta_k_, + size_t n_wi_, + size_t m_, + size_t batch_nelems_, + BatchDimsIndexerT batch_indexer_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + OuterInnerDimsIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), + local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), + delta_n(delta_n_), 
k(k_), k_blocks(k_blocks_), delta_k(delta_k_), + n_wi(n_wi_), m(m_), batch_nelems(batch_nelems_), + batch_indexer(batch_indexer_), lhs_indexer(lhs_indexer_), + rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + // for batching: + // (current matrix in batch) m_id = global_id / (global_range / + // batch_nelems) for lhs, offset = m_id * (n * k) for rhs, offset = + // m_id + // * (k * m) for res, offset = m_id * (n * m) + const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; + const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; + const size_t gr_id = + it.get_group_linear_id() - m_id * n_groups_per_batch; + size_t lid = it.get_local_linear_id(); + + const auto &three_offsets_ = + batch_indexer(static_cast(m_id)); + + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + // lift gr_id -> (block_i, block_j, block_s) + // block_i moves fastest, then block_s, then block_j + + size_t block_j = + gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks + size_t block_r = + gr_id - block_j * (n_blocks * + k_blocks); // 0 <= block_r < n_blocks * k_blocks + size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks + size_t block_i = + block_r - block_s * n_blocks; // 0 <= block_i < n_blocks + + size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n + size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k + + size_t i = block_i * delta_n + local_i; + size_t j = m_groups * block_j; + size_t s = block_s * delta_k * n_wi + local_s; + + constexpr resT identity_ = resT(0); + if (local_i == 0) { + for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { + size_t sq = s + q; + size_t sqmj = sq * m + j; + sycl::vec local_B_vec; +#pragma unroll + for (size_t vec_idx = 0; vec_idx < m_groups; ++vec_idx) { + local_B_vec[vec_idx] = + (sq < k && j + vec_idx < m) + ? static_cast( + rhs[rhs_offset + rhs_indexer(sqmj + vec_idx)]) + : identity_; + } + local_B_block[local_s + q] = local_B_vec; + } + } + + it.barrier(sycl::access::fence_space::local_space); + + size_t t_shift = block_s * delta_k * n_wi; + size_t global_s_offset = i * k + t_shift; + + sycl::vec private_sum(identity_); + constexpr sycl::vec vec_identity_(identity_); + for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { + private_sum += + ((i < n) && (t + t_shift < k)) + ? 
(static_cast( + lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : vec_identity_; + } + + size_t workspace_i_shift = local_i * delta_k; + workspace[workspace_i_shift + local_s] = private_sum; + + it.barrier(sycl::access::fence_space::local_space); + + if (local_s == 0 && i < n) { + sycl::vec local_sum(workspace[workspace_i_shift]); + for (size_t t = 1; t < delta_k; ++t) { + local_sum += workspace[workspace_i_shift + t]; + } + + sycl::atomic_ref + aout0(res[res_offset + res_indexer(i * m + j)]); + + aout0 += local_sum[0]; + +#pragma unroll + for (size_t vec_id = 1; vec_id < m_groups; ++vec_id) { + if (j + vec_id < m) { + sycl::atomic_ref + aout1( + res[res_offset + res_indexer(i * m + j + vec_id)]); + + aout1 += local_sum[vec_id]; + } + } + } + } +}; + +template +class GemmBatchFunctorThreadK +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT workspace; + LocAccT local_B_block; + size_t n = 0; + size_t n_blocks = 0; + size_t delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t delta_k = 0; + size_t n_wi = 0; + size_t m = 0; + size_t batch_nelems = 0; + BatchDimsIndexerT batch_indexer; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + OuterInnerDimsIndexerT res_indexer; + +public: + GemmBatchFunctorThreadK(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT workspace_, + LocAccT local_B_block_, + size_t n_, + size_t n_blocks_, + size_t delta_n_, + size_t k_, + size_t k_blocks_, + size_t delta_k_, + size_t n_wi_, + size_t m_, + size_t batch_nelems_, + BatchDimsIndexerT batch_indexer_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + OuterInnerDimsIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), + local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), + delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), + n_wi(n_wi_), m(m_), batch_nelems(batch_nelems_), + batch_indexer(batch_indexer_), lhs_indexer(lhs_indexer_), + rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + // for batching: + // (current matrix in batch) m_id = global_id / (global_range / + // batch_nelems) for lhs, offset = m_id * (n * k) for rhs, offset = + // m_id + // * (k * m) for res, offset = m_id * (n * m) + const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; + const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; + const size_t gr_id = + it.get_group_linear_id() - m_id * n_groups_per_batch; + size_t lid = it.get_local_linear_id(); + + const auto &three_offsets_ = + batch_indexer(static_cast(m_id)); + + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + // lift gr_id -> (block_i, block_j, block_s) + // block_i moves fastest, then block_s, then block_j + + size_t block_j = + gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks + size_t block_r = + gr_id - block_j * (n_blocks * + k_blocks); // 0 <= block_r < n_blocks * k_blocks + size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks + size_t block_i = + block_r - block_s * n_blocks; // 0 <= block_i < n_blocks + + size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n + size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k + + size_t i = block_i * delta_n + local_i; + size_t j = block_j; + size_t s = block_s * 
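+        // i: global result row assigned to this work-item; j: the single
+        // result column handled by this scalar (m_groups == 1)
+        // specialization; s: starting k-index within the group's k-slab of
+        // delta_k * n_wi elements (used by the local_i == 0 row to stage rhs
+        // values into local memory)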
delta_k * n_wi + local_s; + + constexpr resT identity_ = resT(0); + if (local_i == 0) { + for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { + size_t sq = s + q; + size_t sqmj = sq * m + j; + local_B_block[local_s + q] = + (sq < k && j < m) + ? static_cast(rhs[rhs_offset + rhs_indexer(sqmj)]) + : identity_; + ; + } + } + + it.barrier(sycl::access::fence_space::local_space); + + size_t t_shift = block_s * delta_k * n_wi; + size_t global_s_offset = i * k + t_shift; + + resT private_sum(identity_); + for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { + private_sum += + ((i < n) && (t + t_shift < k)) + ? (static_cast( + lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : identity_; + } + + size_t workspace_i_shift = local_i * delta_k; + workspace[workspace_i_shift + local_s] = private_sum; + + it.barrier(sycl::access::fence_space::local_space); + + if (local_s == 0 && i < n) { + resT local_sum(workspace[workspace_i_shift]); + for (size_t t = 1; t < delta_k; ++t) { + local_sum += workspace[workspace_i_shift + t]; + } + + sycl::atomic_ref + aout(res[res_offset + res_indexer(i * m + j)]); + + aout += local_sum; + } + } +}; + +template class gemm_batch_init_krn; + +template +class gemm_batch_k_krn; + +template +class gemm_batch_nm_krn; + +typedef sycl::event (*gemm_batch_impl_fn_ptr_t)( + sycl::queue &, + const char *, // lhs + const char *, // rhs + char *, // res + size_t, // batch nelems + size_t, // lhs outer nelems (n) + size_t, // inner nelems (k) + size_t, // rhs outer nelems (m) + int, // batching nd + const py::ssize_t *, // batch shape strides + py::ssize_t, // lhs batch offset + py::ssize_t, // rhs batch offset + py::ssize_t, // res batch offset + int, // inner dims + int, // lhs outer dims + const py::ssize_t *, // lhs outer and inner shape and strides + int, // rhs outer dims + const py::ssize_t *, // rhs outer and inner shape and strides + int, // res outer dims + const py::ssize_t *, // res outer and inner shape and strides + const py::ssize_t *, // res full shape and strides + std::vector const &); + +template +sycl::event gemm_batch_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + int batch_nd, + const py::ssize_t *batch_shape_strides, + py::ssize_t lhs_batch_offset, + py::ssize_t rhs_batch_offset, + py::ssize_t res_batch_offset, + int inner_nd, + int lhs_outer_nd, + const py::ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const py::ssize_t *rhs_outer_inner_shapes_strides, + int res_outer_nd, + const py::ssize_t *res_outer_shapes_strides, + const py::ssize_t *res_shape_strides, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using IndexerT = dpctl::tensor::offset_utils::StridedIndexer; + IndexerT res_indexer(batch_nd + res_outer_nd, res_batch_offset, + res_shape_strides); + using InitKernelName = class gemm_batch_init_krn; + cgh.parallel_for( + sycl::range<1>(n * m * batch_nelems), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = resTy(0); + }); + }); + + if (k == 0) { + return res_init_ev; + } + + sycl::event gemm_ev = 
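+    // kernel selection for the strided batched gemm:
+    //   m < 4           -> thread-K kernel, scalar accumulator
+    //   k > n && k > m  -> thread-K kernel, 4-wide sycl::vec accumulator
+    //   otherwise       -> tiled thread-NM kernel (2 x 4 per-work-item tile)
+    // every variant adds k-slab partial sums into res with atomics, which is
+    // why res was zero-filled by res_init_ev above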
exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT res_indexer(res_outer_nd, 0, + res_outer_shapes_strides); + using BatchDimsIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + BatchDimsIndexerT batch_indexer(batch_nd, lhs_batch_offset, + rhs_batch_offset, res_batch_offset, + batch_shape_strides); + if (m < 4) { + constexpr size_t m_groups = 1; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = + class gemm_batch_k_krn; + cgh.parallel_for( + ndRange, GemmBatchFunctorThreadK( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, + m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + else if (k > n && k > m) { + constexpr size_t m_groups = 4; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = + class gemm_batch_k_krn; + cgh.parallel_for( + ndRange, GemmBatchFunctorThreadK( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, + m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + else { + constexpr int wi_delta_n = 2; + constexpr int wi_delta_m = 4; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + size_t k_blocks = ((k + wi_delta_k - 1) / 
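+            // NM tiling: each work-group produces a
+            // (wi_delta_n * wg_delta_n) x (wi_delta_m * wg_delta_m) tile of
+            // the result for one k-slab of width wi_delta_k; with the default
+            // 16 x 16 work-group and 2 x 4 per-work-item tile that is a
+            // 32 x 64 output tile per group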
wi_delta_k); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = + class gemm_batch_nm_krn; + cgh.parallel_for( + ndRange, + GemmBatchFunctorThreadNM( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + }); + + return gemm_ev; +} + +typedef sycl::event (*gemm_batch_contig_impl_fn_ptr_t)( + sycl::queue &, + const char *, // lhs + const char *, // rhs + char *, // res + size_t, // batch nelems + size_t, // n + size_t, // k + size_t, // m + py::ssize_t, // lhs batch offset + py::ssize_t, // rhs batch offset + py::ssize_t, // res batch offset + std::vector const &); + +template +sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + py::ssize_t lhs_batch_offset, + py::ssize_t rhs_batch_offset, + py::ssize_t res_batch_offset, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = + reinterpret_cast(lhs_cp) + lhs_batch_offset; + const rhsTy *rhs_tp = + reinterpret_cast(rhs_cp) + rhs_batch_offset; + resTy *res_tp = reinterpret_cast(res_cp) + res_batch_offset; + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.fill(res_tp, resTy(0), n * m * batch_nelems); + }); + + if (k == 0) { + return res_init_ev; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); + + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); + if (m < 4) { + constexpr size_t m_groups = 1; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = + class 
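+            // contiguous case: matrices are densely packed, so per-matrix
+            // element indexing is a NoOpIndexer and batch b maps to flat
+            // offsets b * (n * k), b * (k * m) and b * (n * m) for lhs, rhs
+            // and res via the three Strided1DIndexer instances above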
gemm_batch_k_krn; + cgh.parallel_for( + ndRange, GemmBatchFunctorThreadK( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, + m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + else if (k > n && k > m) { + constexpr size_t m_groups = 4; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = + class gemm_batch_k_krn; + cgh.parallel_for( + ndRange, GemmBatchFunctorThreadK( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, + m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + else { + constexpr int wi_delta_n = 2; + constexpr int wi_delta_m = 4; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = + class gemm_batch_nm_krn; + cgh.parallel_for( + ndRange, + GemmBatchFunctorThreadNM( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + }); + + return gemm_ev; +} + +template +class GemmBatchNoAtomicFunctorThreadNM +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT1 local_A_block; + LocAccT2 local_B_block; + size_t n = 0; + size_t wg_delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t wi_delta_k = 0; + size_t m = 0; + size_t m_blocks = 0; + size_t wg_delta_m = 0; + size_t batch_nelems; + BatchDimsIndexerT batch_indexer; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + ResIndexerT res_indexer; + +public: + GemmBatchNoAtomicFunctorThreadNM(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_A_block_, + LocAccT2 local_B_block_, + 
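+                                     // partial tiles are written at offset
+                                     // block_s * n * m * batch_nelems rather
+                                     // than accumulated with atomics; when k
+                                     // spans several slabs the caller passes
+                                     // a temporary laid out as
+                                     // [k_blocks][batch][n][m] and reduces
+                                     // over the leading axis afterwards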
size_t n_, + size_t wg_delta_n_, + size_t k_, + size_t k_blocks_, + size_t wi_delta_k_, + size_t m_, + size_t m_blocks_, + size_t wg_delta_m_, + size_t batch_nelems_, + BatchDimsIndexerT batch_indexer_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + ResIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), + local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), + k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), + m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), + batch_nelems(batch_nelems_), batch_indexer(batch_indexer_), + lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), + res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; + const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; + const size_t gr_id = + it.get_group_linear_id() - m_id * n_groups_per_batch; + + const auto &three_offsets_ = + batch_indexer(static_cast(m_id)); + + // lift group_id to (block_i, block_j, block_s), + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s + // < k_blocks + + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + size_t block_i = gr_id / (m_blocks * k_blocks); + size_t block_r = gr_id - block_i * (m_blocks * k_blocks); + size_t block_j = block_r / k_blocks; + size_t block_s = block_r - block_j * k_blocks; + + size_t lid = it.get_local_linear_id(); + size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n + size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m + + // load A block and B blocks into SLM + + size_t i = block_i * wi_delta_n * wg_delta_n; + size_t j = block_j * wi_delta_m * wg_delta_m; + size_t s = block_s * wi_delta_k; + + const std::int64_t a_st0 = k; + const std::int64_t a_st1 = 1; + + const std::int64_t b_st0 = m; + const std::int64_t b_st1 = 1; + + const std::int64_t c_st0 = m; + const std::int64_t c_st1 = 1; + + size_t lws = it.get_local_range(0); + + for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { + size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n + size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_i = i + v_i; + size_t g_s = s + v_s; + + local_A_block[vid] = + (g_i < n && g_s < k) + ? static_cast( + lhs[lhs_offset + + lhs_indexer(g_i * a_st0 + g_s * a_st1)]) + : resT(0); + } + + for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { + size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m + size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_j0 = j + v_j * wi_delta_m; + size_t g_s = s + v_s; + + sycl::vec vec{}; +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { + size_t g_j = g_j0 + lane_id; + vec[lane_id] = + (g_j < m && g_s < k) + ? 
static_cast( + rhs[rhs_offset + + rhs_indexer(g_s * b_st0 + g_j * b_st1)]) + : resT(0); + } + + local_B_block[vid] = vec; + } + + it.barrier(sycl::access::fence_space::local_space); + + i += local_i * wi_delta_n; + j += local_j * wi_delta_m; + + size_t a_offset = local_i * wi_delta_k * wi_delta_n; + size_t b_offset = local_j * wi_delta_k; + + constexpr resT identity_(0); + + for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { + size_t a_pr_offset = private_i * wi_delta_k; + + sycl::vec local_sum(identity_); + for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { + local_sum = local_sum + + (local_A_block[a_offset + a_pr_offset + private_s] * + local_B_block[b_offset + private_s]); + } + + size_t gl_i = i + private_i; + +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { + size_t gl_j = j + lane_id; + + if (gl_i < n && gl_j < m) { + res[res_offset + res_indexer(gl_i * c_st0 + gl_j * c_st1) + + (block_s * n * m * batch_nelems)] = local_sum[lane_id]; + } + } + } + } +}; + +template +class GemmBatchNoAtomicFunctorThreadNM +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT1 local_A_block; + LocAccT2 local_B_block; + size_t n = 0; + size_t wg_delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t wi_delta_k = 0; + size_t m = 0; + size_t m_blocks = 0; + size_t wg_delta_m = 0; + size_t batch_nelems; + BatchDimsIndexerT batch_indexer; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + ResIndexerT res_indexer; + +public: + GemmBatchNoAtomicFunctorThreadNM(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_A_block_, + LocAccT2 local_B_block_, + size_t n_, + size_t wg_delta_n_, + size_t k_, + size_t k_blocks_, + size_t wi_delta_k_, + size_t m_, + size_t m_blocks_, + size_t wg_delta_m_, + size_t batch_nelems_, + BatchDimsIndexerT batch_indexer_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + ResIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), + local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), + k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), + m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), + batch_nelems(batch_nelems_), batch_indexer(batch_indexer_), + lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), + res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; + const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; + const size_t gr_id = + it.get_group_linear_id() - m_id * n_groups_per_batch; + + const auto &three_offsets_ = + batch_indexer(static_cast(m_id)); + + // lift group_id to (block_i, block_j, block_s), + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s + // < k_blocks + + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + size_t block_i = gr_id / (m_blocks * k_blocks); + size_t block_r = gr_id - block_i * (m_blocks * k_blocks); + size_t block_j = block_r / k_blocks; + size_t block_s = block_r - block_j * k_blocks; + + size_t lid = it.get_local_linear_id(); + size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n + size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m + + // load A block and B blocks into SLM + + size_t i = block_i * 
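+        // (i, j, s) is the origin of this group's tile: i the first result
+        // row, j the first result column, s the first k index of the slab;
+        // each work-item later shifts i by local_i * wi_delta_n and j by
+        // local_j to reach the wi_delta_n x 1 sub-tile it owns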
wi_delta_n * wg_delta_n; + size_t j = block_j * wg_delta_m; + size_t s = block_s * wi_delta_k; + + const std::int64_t a_st0 = k; + const std::int64_t a_st1 = 1; + + const std::int64_t b_st0 = m; + const std::int64_t b_st1 = 1; + + const std::int64_t c_st0 = m; + const std::int64_t c_st1 = 1; + + size_t lws = it.get_local_range(0); + + for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { + size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n + size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_i = i + v_i; + size_t g_s = s + v_s; + + local_A_block[vid] = + (g_i < n && g_s < k) + ? static_cast( + lhs[lhs_offset + + lhs_indexer(g_i * a_st0 + g_s * a_st1)]) + : resT(0); + } + + for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { + size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m + size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_j0 = j + v_j; + size_t g_s = s + v_s; + + resT val = (g_j0 < m && g_s < k) + ? static_cast( + rhs[rhs_offset + + rhs_indexer(g_s * b_st0 + g_j0 * b_st1)]) + : resT(0); + + local_B_block[vid] = val; + } + + it.barrier(sycl::access::fence_space::local_space); + + i += local_i * wi_delta_n; + j += local_j; + + size_t a_offset = local_i * wi_delta_k * wi_delta_n; + size_t b_offset = local_j * wi_delta_k; + + constexpr resT identity_(0); + for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { + size_t a_pr_offset = private_i * wi_delta_k; + + resT local_sum(identity_); + for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { + local_sum = local_sum + + (local_A_block[a_offset + a_pr_offset + private_s] * + local_B_block[b_offset + private_s]); + } + + size_t gl_i = i + private_i; + + if (gl_i < n && j < m) { + res[res_offset + res_indexer(gl_i * c_st0 + j * c_st1) + + (block_s * n * m * batch_nelems)] = local_sum; + } + } + } +}; + +template +class GemmBatchNoAtomicFunctorThreadK +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT workspace; + LocAccT local_B_block; + size_t n = 0; + size_t n_blocks = 0; + size_t delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t delta_k = 0; + size_t n_wi = 0; + size_t m = 0; + size_t batch_nelems = 0; + BatchDimsIndexerT batch_indexer; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + ResIndexerT res_indexer; + +public: + GemmBatchNoAtomicFunctorThreadK(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT workspace_, + LocAccT local_B_block_, + size_t n_, + size_t n_blocks_, + size_t delta_n_, + size_t k_, + size_t k_blocks_, + size_t delta_k_, + size_t n_wi_, + size_t m_, + size_t batch_nelems_, + BatchDimsIndexerT batch_indexer_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + ResIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), + local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), + delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), + n_wi(n_wi_), m(m_), batch_nelems(batch_nelems_), + batch_indexer(batch_indexer_), lhs_indexer(lhs_indexer_), + rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; + const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; + const size_t gr_id = + it.get_group_linear_id() - m_id * n_groups_per_batch; + size_t lid = it.get_local_linear_id(); + + const auto 
&three_offsets_ = + batch_indexer(static_cast(m_id)); + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + // lift gr_id -> (block_i, block_j, block_s) + // block_i moves fastest, then block_s, then block_j + + size_t block_j = + gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks + size_t block_r = + gr_id - block_j * (n_blocks * + k_blocks); // 0 <= block_r < n_blocks * k_blocks + size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks + size_t block_i = + block_r - block_s * n_blocks; // 0 <= block_i < n_blocks + + size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n + size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k + + size_t i = block_i * delta_n + local_i; + size_t j = m_groups * block_j; + size_t s = block_s * delta_k * n_wi + local_s; + + constexpr resT identity_ = resT(0); + if (local_i == 0) { + for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { + size_t sq = s + q; + size_t sqmj = sq * m + j; + sycl::vec local_B_vec; +#pragma unroll + for (size_t vec_idx = 0; vec_idx < m_groups; ++vec_idx) { + local_B_vec[vec_idx] = + (sq < k && j + vec_idx < m) + ? static_cast( + rhs[rhs_offset + rhs_indexer(sqmj + vec_idx)]) + : identity_; + } + local_B_block[local_s + q] = local_B_vec; + } + } + + it.barrier(sycl::access::fence_space::local_space); + + size_t t_shift = block_s * delta_k * n_wi; + size_t global_s_offset = i * k + t_shift; + + sycl::vec private_sum(identity_); + constexpr sycl::vec vec_identity_(identity_); + for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { + private_sum += + ((i < n) && (t + t_shift < k)) + ? (static_cast( + lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : vec_identity_; + } + + size_t workspace_i_shift = local_i * delta_k; + workspace[workspace_i_shift + local_s] = private_sum; + + it.barrier(sycl::access::fence_space::local_space); + + if (local_s == 0 && i < n) { + sycl::vec local_sum(workspace[workspace_i_shift]); + for (size_t t = 1; t < delta_k; ++t) { + local_sum += workspace[workspace_i_shift + t]; + } + + const size_t total_offset = + res_offset + (block_s * n * m * batch_nelems); + res[total_offset + res_indexer(i * m + j)] = local_sum[0]; + +#pragma unroll + for (size_t vec_id = 1; vec_id < m_groups; ++vec_id) { + if (j + vec_id < m) { + res[total_offset + res_indexer(i * m + j + vec_id)] = + local_sum[1]; + } + } + } + } +}; + +template +class GemmBatchNoAtomicFunctorThreadK +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT workspace; + LocAccT local_B_block; + size_t n = 0; + size_t n_blocks = 0; + size_t delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t delta_k = 0; + size_t n_wi = 0; + size_t m = 0; + size_t batch_nelems = 0; + BatchDimsIndexerT batch_indexer; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + ResIndexerT res_indexer; + +public: + GemmBatchNoAtomicFunctorThreadK(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT workspace_, + LocAccT local_B_block_, + size_t n_, + size_t n_blocks_, + size_t delta_n_, + size_t k_, + size_t k_blocks_, + size_t delta_k_, + size_t n_wi_, + size_t m_, + size_t batch_nelems_, + BatchDimsIndexerT batch_indexer_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + ResIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), 
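+          // workspace holds one partial sum per work-item
+          // (delta_n * delta_k entries); after a barrier the work-item with
+          // local_s == 0 in each row sums its delta_k entries and writes the
+          // row's k-slab contribution, while local_B_block caches an
+          // n_wi * delta_k slice of the rhs column in local memory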
workspace(workspace_), + local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), + delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), + n_wi(n_wi_), m(m_), batch_nelems(batch_nelems_), + batch_indexer(batch_indexer_), lhs_indexer(lhs_indexer_), + rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; + const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; + const size_t gr_id = + it.get_group_linear_id() - m_id * n_groups_per_batch; + size_t lid = it.get_local_linear_id(); + + const auto &three_offsets_ = + batch_indexer(static_cast(m_id)); + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + // lift gr_id -> (block_i, block_j, block_s) + // block_i moves fastest, then block_s, then block_j + + size_t block_j = + gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks + size_t block_r = + gr_id - block_j * (n_blocks * + k_blocks); // 0 <= block_r < n_blocks * k_blocks + size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks + size_t block_i = + block_r - block_s * n_blocks; // 0 <= block_i < n_blocks + + size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n + size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k + + size_t i = block_i * delta_n + local_i; + size_t j = block_j; + size_t s = block_s * delta_k * n_wi + local_s; + + constexpr resT identity_ = resT(0); + if (local_i == 0) { + for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { + size_t sq = s + q; + size_t sqmj = sq * m + j; + local_B_block[local_s + q] = + (sq < k && j < m) + ? static_cast(rhs[rhs_offset + rhs_indexer(sqmj)]) + : identity_; + } + } + + it.barrier(sycl::access::fence_space::local_space); + + size_t t_shift = block_s * delta_k * n_wi; + size_t global_s_offset = i * k + t_shift; + + resT private_sum(identity_); + for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { + private_sum += + ((i < n) && ((t + t_shift < k))) + ? 
(static_cast( + lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : identity_; + } + + size_t workspace_i_shift = local_i * delta_k; + workspace[workspace_i_shift + local_s] = private_sum; + + it.barrier(sycl::access::fence_space::local_space); + + if (local_s == 0 && i < n) { + resT local_sum(workspace[workspace_i_shift]); + for (size_t t = 1; t < delta_k; ++t) { + local_sum += workspace[workspace_i_shift + t]; + } + + res[res_offset + res_indexer(i * m + j) + + (block_s * n * m * batch_nelems)] = local_sum; + } + } +}; + +template +class gemm_batch_tree_k_krn; + +template +class gemm_batch_tree_nm_krn; + +template +sycl::event +gemm_batch_tree_k_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + int batch_nd, + const py::ssize_t *batch_shape_strides, + py::ssize_t lhs_batch_offset, + py::ssize_t rhs_batch_offset, + py::ssize_t res_batch_offset, + int inner_nd, + int lhs_outer_nd, + const py::ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const py::ssize_t *rhs_outer_inner_shapes_strides, + int res_outer_nd, + const py::ssize_t *res_outer_shapes_strides, + const py::ssize_t *res_shape_strides, + std::vector const &depends) +{ + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + if (k <= (delta_k * n_wi)) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT res_indexer(res_outer_nd, 0, + res_outer_shapes_strides); + using BatchDimsIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + BatchDimsIndexerT batch_indexer(batch_nd, lhs_batch_offset, + rhs_batch_offset, res_batch_offset, + batch_shape_strides); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; + + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, + res_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * 
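+            // k fits into a single k-slab (k <= delta_k * n_wi), so there is
+            // exactly one partial result per output element and the no-atomic
+            // kernel can write straight into res without a temporary or a
+            // reduction pass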
delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; + + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, + res_indexer)); + } + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = batch_nelems * n * m; + size_t reduction_nelems = (k + delta_k * n_wi - 1) / (delta_k * n_wi); + + // more than one work-group is needed, requires a + // temporary delta_k * n_wi elements processed along k, + // so if more to process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + if (!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + TmpIndexerT res_indexer{}; + + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + StridedIndexer, UnpackedStridedIndexer, Strided1DIndexer>; + StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, + batch_shape_strides); + UnpackedStridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, batch_shape_strides, + batch_shape_strides + 2 * batch_nd); + Strided1DIndexer tmp_batch_indexer( + 0, static_cast(batch_nelems), n * m); + BatchDimsIndexerT batch_indexer( + lhs_batch_indexer, rhs_batch_indexer, tmp_batch_indexer); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + 
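+                    // tmp is C-contiguous with layout
+                    // [reduction_nelems (k-slabs)][batch][n][m], hence the
+                    // NoOpIndexer for its inner dimensions; the reduction
+                    // kernel launched below sums the leading k-slab axis into
+                    // res through the strided result indexer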
TmpIndexerT, BatchDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + }); + + sycl::event red_ev = single_reduction_for_gemm( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, + batch_nd + res_outer_nd, res_batch_offset, res_shape_strides, + {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + else { + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + TmpIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, + batch_shape_strides); + StridedIndexer rhs_batch_indexer(batch_nd, rhs_batch_offset, + batch_shape_strides + + 2 * batch_nd); + Strided1DIndexer tmp_batch_indexer( + 0, static_cast(batch_nelems), n * m); + BatchDimsIndexerT batch_indexer( + lhs_batch_indexer, rhs_batch_indexer, tmp_batch_indexer); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { + using LocAccT = 
sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, m_groups>; + + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer)); + } + }); + + sycl::event red_ev = tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, + identity_val, iter_nelems, reduction_nelems, reduction_groups, + wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, + batch_nd + res_outer_nd, res_batch_offset, res_shape_strides, + {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } +} + +template +sycl::event +gemm_batch_tree_nm_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + int batch_nd, + const py::ssize_t *batch_shape_strides, + py::ssize_t lhs_batch_offset, + py::ssize_t rhs_batch_offset, + py::ssize_t res_batch_offset, + int inner_nd, + int lhs_outer_nd, + const py::ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const py::ssize_t *rhs_outer_inner_shapes_strides, + int res_outer_nd, + const py::ssize_t *res_outer_shapes_strides, + const py::ssize_t *res_shape_strides, + std::vector const &depends) +{ + constexpr int wi_delta_n = 2; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + // each group processes delta_k * n_wi + // items in a column, so no need for allocating + // temp memory if only one group is needed + if (k <= wi_delta_k) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + 
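+            // lhs/rhs indexers cover both the free (outer) and contracted
+            // (inner) dimensions of each operand in packed shape/stride form;
+            // the res indexer only needs the free (outer) dimensions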
lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT res_indexer(res_outer_nd, 0, + res_outer_shapes_strides); + using BatchDimsIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + BatchDimsIndexerT batch_indexer(batch_nd, lhs_batch_offset, + rhs_batch_offset, res_batch_offset, + batch_shape_strides); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + size_t iter_nelems = batch_nelems * n * m; + size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; + + // more than one work-group is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + resTy 
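+            // single-pass reduction path: the number of k-slabs
+            // (reduction_nelems) is small enough to be combined by one
+            // reduction kernel, so a temporary of
+            // iter_nelems * reduction_nelems elements suffices; otherwise the
+            // tree-reduction branch below also reserves space for the
+            // intermediate reduction groups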
*tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + if (!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + TmpIndexerT res_indexer{}; + + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + StridedIndexer, UnpackedStridedIndexer, Strided1DIndexer>; + StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, + batch_shape_strides); + UnpackedStridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, batch_shape_strides, + batch_shape_strides + 2 * batch_nd); + Strided1DIndexer tmp_batch_indexer( + 0, static_cast(batch_nelems), n * m); + BatchDimsIndexerT batch_indexer( + lhs_batch_indexer, rhs_batch_indexer, tmp_batch_indexer); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, + n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, + n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer)); + } + }); + + sycl::event red_ev = single_reduction_for_gemm( + exec_q, tmp, res_tp, 
identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, + batch_nd + res_outer_nd, res_batch_offset, res_shape_strides, + {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + else { + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + TmpIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + StridedIndexer, UnpackedStridedIndexer, Strided1DIndexer>; + StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, + batch_shape_strides); + UnpackedStridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, batch_shape_strides, + batch_shape_strides + 2 * batch_nd); + Strided1DIndexer tmp_batch_indexer( + 0, static_cast(batch_nelems), n * m); + BatchDimsIndexerT batch_indexer( + lhs_batch_indexer, rhs_batch_indexer, tmp_batch_indexer); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + 
sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + }); + + sycl::event red_ev = tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, + identity_val, iter_nelems, reduction_nelems, reduction_groups, + wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, + batch_nd + res_outer_nd, res_batch_offset, res_shape_strides, + {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } +} + +template +class gemm_batch_tree_empty_krn; + +template +sycl::event +gemm_batch_tree_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + int batch_nd, + const py::ssize_t *batch_shape_strides, + py::ssize_t lhs_batch_offset, + py::ssize_t rhs_batch_offset, + py::ssize_t res_batch_offset, + int inner_nd, + int lhs_outer_nd, + const py::ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const py::ssize_t *rhs_outer_inner_shapes_strides, + int res_outer_nd, + const py::ssize_t *res_outer_shapes_strides, + const py::ssize_t *res_shape_strides, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + if (k == 0) { + sycl::event gemm_batch_no_reduction_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using IndexerT = dpctl::tensor::offset_utils::StridedIndexer; + IndexerT res_indexer(batch_nd + res_outer_nd, res_batch_offset, + res_shape_strides); + using InitKernelName = + class gemm_batch_tree_empty_krn; + cgh.parallel_for( + sycl::range<1>(n * m * batch_nelems), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = resTy(0); + }); + }); + return gemm_batch_no_reduction_ev; + } + + if ((k > n && k > m) || m < 4) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + if (m < 4) { + return gemm_batch_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_nd, batch_shape_strides, lhs_batch_offset, + rhs_batch_offset, res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); + } + else { + return gemm_batch_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_nd, batch_shape_strides, lhs_batch_offset, + rhs_batch_offset, res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + 
rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); + } + } + else { + return gemm_batch_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, + batch_shape_strides, lhs_batch_offset, rhs_batch_offset, + res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); + } + } + else { // m > 1, n > k or m > k + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + return gemm_batch_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, + batch_shape_strides, lhs_batch_offset, rhs_batch_offset, + res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); + } + else { // m > 1, n > k or m > k, resTy complex + return gemm_batch_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, + batch_shape_strides, lhs_batch_offset, rhs_batch_offset, + res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); + } + } +} + +template +sycl::event +gemm_batch_contig_tree_k_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + std::vector const &depends) +{ + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + if (k <= (delta_k * n_wi)) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + using dpctl::tensor::offset_utils::Strided1DIndexer; + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; + + cgh.parallel_for( + ndRange, + 
GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, + res_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; + + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, + res_indexer)); + } + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = batch_nelems * n * m; + size_t reduction_nelems = (k + delta_k * n_wi - 1) / (delta_k * n_wi); + + // more than one work-group is needed, requires a + // temporary delta_k * n_wi elements processed along k, + // so if more to process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + if (!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT tmp_indexer{}; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, 
OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, tmp_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, tmp_indexer)); + } + }); + + sycl::event red_ev = + single_reduction_for_gemm_contig( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + else { + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT tmp_indexer{}; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; + 
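                    // Descriptive note (added commentary, not in the original
                    // patch): the parallel_for launched next covers
                    // batch_nelems * n_blocks * m_blocks * k_blocks work-groups
                    // of size lws = delta_n * delta_k. Each block count is a
                    // ceiling division, e.g. n_blocks = ceil(n / delta_n), so
                    // for n = 100 and delta_n = 32 four row-blocks are launched
                    // and the last one is only partially utilized. Every
                    // k-block writes its partial dot products into
                    // partially_reduced_tmp; because reduction_nelems exceeds
                    // preferred_reductions_per_wi * max_wg on this branch,
                    // those partial results are collapsed into res_tp by the
                    // multi-pass tree_reduction_for_gemm_contig call below.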
cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, tmp_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, tmp_indexer)); + } + }); + + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, preferred_reductions_per_wi, + reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } +} + +template +sycl::event +gemm_batch_contig_tree_nm_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + std::vector const &depends) +{ + constexpr int wi_delta_n = 2; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + // each group processes delta_k * n_wi + // items in a column, so no need for allocating + // temp memory if only one group is needed + if (k <= wi_delta_k) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = + ((m + wi_delta_m * 
wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + size_t iter_nelems = batch_nelems * n * m; + size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; + + // more than one work-group is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + if (!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT tmp_indexer{}; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + Strided1DIndexer, 
Strided1DIndexer, Strided1DIndexer>; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, + n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, tmp_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, + n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, tmp_indexer)); + } + }); + + sycl::event red_ev = + single_reduction_for_gemm_contig( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + else { + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + 
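            // Descriptive note (added commentary, not in the original patch):
            // this contiguous specialization uses NoOpIndexer for the lhs, rhs
            // and temporary buffers, i.e. the flattened element id computed by
            // the functor is used directly as the memory offset and no
            // per-element stride arithmetic is performed. The strided
            // gemm_batch_tree_nm_impl variant earlier in this file builds
            // StridedIndexer/UnpackedStridedIndexer objects from the
            // shape-and-strides arrays for exactly the same purpose.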
OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT tmp_indexer{}; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, tmp_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, tmp_indexer)); + } + }); + + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, preferred_reductions_per_wi, + reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } +} + +template +sycl::event +gemm_batch_contig_tree_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t batch_nelems, + size_t n, + size_t k, + 
size_t m,
+                            py::ssize_t lhs_batch_offset,
+                            py::ssize_t rhs_batch_offset,
+                            py::ssize_t res_batch_offset,
+                            std::vector<sycl::event> const &depends = {})
+{
+    const lhsTy *lhs_tp =
+        reinterpret_cast<const lhsTy *>(lhs_cp) + lhs_batch_offset;
+    const rhsTy *rhs_tp =
+        reinterpret_cast<const rhsTy *>(rhs_cp) + rhs_batch_offset;
+    resTy *res_tp = reinterpret_cast<resTy *>(res_cp) + res_batch_offset;
+
+    if (k == 0) {
+        sycl::event gemm_batch_no_reduction_ev =
+            exec_q.submit([&](sycl::handler &cgh) {
+                cgh.depends_on(depends);
+                cgh.fill(res_tp, resTy(0), n * m * batch_nelems);
+            });
+        return gemm_batch_no_reduction_ev;
+    }
+
+    if ((k > n && k > m) || m < 4) {
+        using dpctl::tensor::type_utils::is_complex;
+        if constexpr (!is_complex<resTy>::value) {
+            if (m < 4) {
+                return gemm_batch_contig_tree_k_impl(
+                    exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m,
+                    depends);
+            }
+            else {
+                return gemm_batch_contig_tree_k_impl(
+                    exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m,
+                    depends);
+            }
+        }
+        else {
+            return gemm_batch_contig_tree_k_impl(
+                exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends);
+        }
+    }
+    else { // m > 1, n > k or m > k
+        using dpctl::tensor::type_utils::is_complex;
+        if constexpr (!is_complex<resTy>::value) {
+            return gemm_batch_contig_tree_nm_impl(
+                exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends);
+        }
+        else { // m > 1, n > k or m > k, resTy complex
+            return gemm_batch_contig_tree_nm_impl(
+                exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends);
+        }
+    }
+}
+
+} // namespace kernels
+} // namespace tensor
+} // namespace dpctl
diff --git a/dpctl/tensor/libtensor/include/utils/offset_utils.hpp b/dpctl/tensor/libtensor/include/utils/offset_utils.hpp
index 523620737b..440d0d9d0b 100644
--- a/dpctl/tensor/libtensor/include/utils/offset_utils.hpp
+++ b/dpctl/tensor/libtensor/include/utils/offset_utils.hpp
@@ -450,6 +450,32 @@ struct ThreeZeroOffsets_Indexer
     }
 };
 
+template <typename FirstIndexerT,
+          typename SecondIndexerT,
+          typename ThirdIndexerT>
+struct ThreeOffsets_CombinedIndexer
+{
+private:
+    FirstIndexerT first_indexer_;
+    SecondIndexerT second_indexer_;
+    ThirdIndexerT third_indexer_;
+
+public:
+    ThreeOffsets_CombinedIndexer(const FirstIndexerT &first_indexer,
+                                 const SecondIndexerT &second_indexer,
+                                 const ThirdIndexerT &third_indexer)
+        : first_indexer_(first_indexer), second_indexer_(second_indexer),
+          third_indexer_(third_indexer)
+    {
+    }
+
+    ThreeOffsets<py::ssize_t> operator()(py::ssize_t gid) const
+    {
+        return ThreeOffsets<py::ssize_t>(
+            first_indexer_(gid), second_indexer_(gid), third_indexer_(gid));
+    }
+};
+
 template <typename indT> struct FourOffsets
 {
     FourOffsets()
diff --git a/dpctl/tensor/libtensor/source/elementwise_functions.cpp b/dpctl/tensor/libtensor/source/elementwise_functions.cpp
deleted file mode 100644
index 9ab7c0807c..0000000000
--- a/dpctl/tensor/libtensor/source/elementwise_functions.cpp
+++ /dev/null
@@ -1,5155 +0,0 @@
-//===----------- Implementation of _tensor_impl module ---------*-C++-*-/===//
-//
-// Data Parallel Control (dpctl)
-//
-// Copyright 2020-2023 Intel Corporation
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//    http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
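// Illustrative aside (added commentary, not part of the patch): the
// ThreeOffsets_CombinedIndexer introduced in offset_utils.hpp above lets a
// kernel turn one linear work-item id into three memory offsets by delegating
// to three per-array indexers, which is how the batched gemm kernels locate
// the lhs, rhs and result matrices of a batch element. The stand-alone model
// below uses hypothetical simplified types (Offsets3, Strided1D, Combined3),
// not the dpctl API, to show the idea for contiguous n x k, k x m and n x m
// batches.

#include <cstddef>
#include <iostream>

struct Offsets3 {
    std::ptrdiff_t first, second, third;
};

// Analogue of Strided1DIndexer: batch element `gid` lives `step` elements
// past the previous one, starting at `offset`.
struct Strided1D {
    std::ptrdiff_t offset;
    std::ptrdiff_t step;
    std::ptrdiff_t operator()(std::ptrdiff_t gid) const
    {
        return offset + gid * step;
    }
};

// Analogue of ThreeOffsets_CombinedIndexer: apply three indexers to one id.
template <typename I1, typename I2, typename I3> struct Combined3
{
    I1 i1;
    I2 i2;
    I3 i3;
    Offsets3 operator()(std::ptrdiff_t gid) const
    {
        return Offsets3{i1(gid), i2(gid), i3(gid)};
    }
};

int main()
{
    constexpr std::ptrdiff_t n = 2, k = 3, m = 4;
    // lhs batches are n*k apart, rhs batches k*m apart, results n*m apart
    Combined3<Strided1D, Strided1D, Strided1D> batch_indexer{
        {0, n * k}, {0, k * m}, {0, n * m}};
    const Offsets3 offs = batch_indexer(5); // offsets of batch element 5
    std::cout << offs.first << " " << offs.second << " " << offs.third
              << "\n"; // prints "30 60 40"
}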
-// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file defines functions of dpctl.tensor._tensor_impl extensions, -/// specifically functions for elementwise operations. -//===----------------------------------------------------------------------===// - -#include "dpctl4pybind11.hpp" -#include -#include -#include -#include -#include - -#include "elementwise_functions.hpp" -#include "utils/type_dispatch.hpp" - -#include "kernels/elementwise_functions/abs.hpp" -#include "kernels/elementwise_functions/acos.hpp" -#include "kernels/elementwise_functions/acosh.hpp" -#include "kernels/elementwise_functions/add.hpp" -#include "kernels/elementwise_functions/asin.hpp" -#include "kernels/elementwise_functions/asinh.hpp" -#include "kernels/elementwise_functions/atan.hpp" -#include "kernels/elementwise_functions/atan2.hpp" -#include "kernels/elementwise_functions/atanh.hpp" -#include "kernels/elementwise_functions/bitwise_and.hpp" -#include "kernels/elementwise_functions/bitwise_invert.hpp" -#include "kernels/elementwise_functions/bitwise_left_shift.hpp" -#include "kernels/elementwise_functions/bitwise_or.hpp" -#include "kernels/elementwise_functions/bitwise_right_shift.hpp" -#include "kernels/elementwise_functions/bitwise_xor.hpp" -#include "kernels/elementwise_functions/cbrt.hpp" -#include "kernels/elementwise_functions/ceil.hpp" -#include "kernels/elementwise_functions/conj.hpp" -#include "kernels/elementwise_functions/copysign.hpp" -#include "kernels/elementwise_functions/cos.hpp" -#include "kernels/elementwise_functions/cosh.hpp" -#include "kernels/elementwise_functions/equal.hpp" -#include "kernels/elementwise_functions/exp.hpp" -#include "kernels/elementwise_functions/exp2.hpp" -#include "kernels/elementwise_functions/expm1.hpp" -#include "kernels/elementwise_functions/floor.hpp" -#include "kernels/elementwise_functions/floor_divide.hpp" -#include "kernels/elementwise_functions/greater.hpp" -#include "kernels/elementwise_functions/greater_equal.hpp" -#include "kernels/elementwise_functions/hypot.hpp" -#include "kernels/elementwise_functions/imag.hpp" -#include "kernels/elementwise_functions/isfinite.hpp" -#include "kernels/elementwise_functions/isinf.hpp" -#include "kernels/elementwise_functions/isnan.hpp" -#include "kernels/elementwise_functions/less.hpp" -#include "kernels/elementwise_functions/less_equal.hpp" -#include "kernels/elementwise_functions/log.hpp" -#include "kernels/elementwise_functions/log10.hpp" -#include "kernels/elementwise_functions/log1p.hpp" -#include "kernels/elementwise_functions/log2.hpp" -#include "kernels/elementwise_functions/logaddexp.hpp" -#include "kernels/elementwise_functions/logical_and.hpp" -#include "kernels/elementwise_functions/logical_not.hpp" -#include "kernels/elementwise_functions/logical_or.hpp" -#include "kernels/elementwise_functions/logical_xor.hpp" -#include "kernels/elementwise_functions/maximum.hpp" -#include "kernels/elementwise_functions/minimum.hpp" -#include "kernels/elementwise_functions/multiply.hpp" -#include "kernels/elementwise_functions/negative.hpp" -#include "kernels/elementwise_functions/not_equal.hpp" -#include "kernels/elementwise_functions/positive.hpp" -#include "kernels/elementwise_functions/pow.hpp" -#include "kernels/elementwise_functions/proj.hpp" -#include "kernels/elementwise_functions/real.hpp" -#include "kernels/elementwise_functions/remainder.hpp" -#include "kernels/elementwise_functions/round.hpp" -#include "kernels/elementwise_functions/rsqrt.hpp" -#include 
"kernels/elementwise_functions/sign.hpp" -#include "kernels/elementwise_functions/signbit.hpp" -#include "kernels/elementwise_functions/sin.hpp" -#include "kernels/elementwise_functions/sinh.hpp" -#include "kernels/elementwise_functions/sqrt.hpp" -#include "kernels/elementwise_functions/square.hpp" -#include "kernels/elementwise_functions/subtract.hpp" -#include "kernels/elementwise_functions/tan.hpp" -#include "kernels/elementwise_functions/tanh.hpp" -#include "kernels/elementwise_functions/true_divide.hpp" -#include "kernels/elementwise_functions/trunc.hpp" - -namespace dpctl -{ -namespace tensor -{ -namespace py_internal -{ - -namespace td_ns = dpctl::tensor::type_dispatch; - -py::dtype _dtype_from_typenum(td_ns::typenum_t dst_typenum_t) -{ - switch (dst_typenum_t) { - case td_ns::typenum_t::BOOL: - return py::dtype("?"); - case td_ns::typenum_t::INT8: - return py::dtype("i1"); - case td_ns::typenum_t::UINT8: - return py::dtype("u1"); - case td_ns::typenum_t::INT16: - return py::dtype("i2"); - case td_ns::typenum_t::UINT16: - return py::dtype("u2"); - case td_ns::typenum_t::INT32: - return py::dtype("i4"); - case td_ns::typenum_t::UINT32: - return py::dtype("u4"); - case td_ns::typenum_t::INT64: - return py::dtype("i8"); - case td_ns::typenum_t::UINT64: - return py::dtype("u8"); - case td_ns::typenum_t::HALF: - return py::dtype("f2"); - case td_ns::typenum_t::FLOAT: - return py::dtype("f4"); - case td_ns::typenum_t::DOUBLE: - return py::dtype("f8"); - case td_ns::typenum_t::CFLOAT: - return py::dtype("c8"); - case td_ns::typenum_t::CDOUBLE: - return py::dtype("c16"); - default: - throw py::value_error("Unrecognized dst_typeid"); - } -} - -int _result_typeid(int arg_typeid, const int *fn_output_id) -{ - if (arg_typeid < 0 || arg_typeid >= td_ns::num_types) { - throw py::value_error("Input typeid " + std::to_string(arg_typeid) + - " is outside of expected bounds."); - } - - return fn_output_id[arg_typeid]; -} - -namespace ew_cmn_ns = dpctl::tensor::kernels::elementwise_common; -using ew_cmn_ns::binary_contig_impl_fn_ptr_t; -using ew_cmn_ns::binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t; -using ew_cmn_ns::binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t; -using ew_cmn_ns::binary_strided_impl_fn_ptr_t; -using ew_cmn_ns::unary_contig_impl_fn_ptr_t; -using ew_cmn_ns::unary_strided_impl_fn_ptr_t; - -using ew_cmn_ns::binary_inplace_contig_impl_fn_ptr_t; -using ew_cmn_ns::binary_inplace_row_matrix_broadcast_impl_fn_ptr_t; -using ew_cmn_ns::binary_inplace_strided_impl_fn_ptr_t; - -// U01: ==== ABS (x) -namespace impl -{ - -namespace abs_fn_ns = dpctl::tensor::kernels::abs; - -static unary_contig_impl_fn_ptr_t abs_contig_dispatch_vector[td_ns::num_types]; -static int abs_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - abs_strided_dispatch_vector[td_ns::num_types]; - -void populate_abs_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = abs_fn_ns; - - using fn_ns::AbsContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(abs_contig_dispatch_vector); - - using fn_ns::AbsStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(abs_strided_dispatch_vector); - - using fn_ns::AbsTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(abs_output_typeid_vector); -}; - -} // namespace impl - -// U02: ==== ACOS (x) -namespace impl -{ - -namespace acos_fn_ns = dpctl::tensor::kernels::acos; - -static unary_contig_impl_fn_ptr_t acos_contig_dispatch_vector[td_ns::num_types]; 
-static int acos_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - acos_strided_dispatch_vector[td_ns::num_types]; - -void populate_acos_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = acos_fn_ns; - - using fn_ns::AcosContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(acos_contig_dispatch_vector); - - using fn_ns::AcosStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(acos_strided_dispatch_vector); - - using fn_ns::AcosTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(acos_output_typeid_vector); -} - -} // namespace impl - -// U03: ===== ACOSH (x) -namespace impl -{ - -namespace acosh_fn_ns = dpctl::tensor::kernels::acosh; - -static unary_contig_impl_fn_ptr_t - acosh_contig_dispatch_vector[td_ns::num_types]; -static int acosh_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - acosh_strided_dispatch_vector[td_ns::num_types]; - -void populate_acosh_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = acosh_fn_ns; - - using fn_ns::AcoshContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(acosh_contig_dispatch_vector); - - using fn_ns::AcoshStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(acosh_strided_dispatch_vector); - - using fn_ns::AcoshTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(acosh_output_typeid_vector); -} - -} // namespace impl - -// B01: ===== ADD (x1, x2) -namespace impl -{ -namespace add_fn_ns = dpctl::tensor::kernels::add; - -static binary_contig_impl_fn_ptr_t add_contig_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static int add_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - add_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -// add(matrix, row) -static binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t - add_contig_matrix_contig_row_broadcast_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -// add(row, matrix) -static binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t - add_contig_row_contig_matrix_broadcast_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -static binary_inplace_contig_impl_fn_ptr_t - add_inplace_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static binary_inplace_strided_impl_fn_ptr_t - add_inplace_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; -static binary_inplace_row_matrix_broadcast_impl_fn_ptr_t - add_inplace_row_matrix_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_add_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = add_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::AddTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(add_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::AddStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(add_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::AddContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(add_contig_dispatch_table); - - // function pointers for operation on contiguous matrix, contiguous row - // with contiguous matrix output - using fn_ns::AddContigMatrixContigRowBroadcastFactory; - DispatchTableBuilder< - 
binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t, - AddContigMatrixContigRowBroadcastFactory, num_types> - dtb4; - dtb4.populate_dispatch_table( - add_contig_matrix_contig_row_broadcast_dispatch_table); - - // function pointers for operation on contiguous row, contiguous matrix - // with contiguous matrix output - using fn_ns::AddContigRowContigMatrixBroadcastFactory; - DispatchTableBuilder< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t, - AddContigRowContigMatrixBroadcastFactory, num_types> - dtb5; - dtb5.populate_dispatch_table( - add_contig_row_contig_matrix_broadcast_dispatch_table); - - // function pointers for inplace operation on general strided arrays - using fn_ns::AddInplaceStridedFactory; - DispatchTableBuilder - dtb6; - dtb6.populate_dispatch_table(add_inplace_strided_dispatch_table); - - // function pointers for inplace operation on contiguous inputs and output - using fn_ns::AddInplaceContigFactory; - DispatchTableBuilder - dtb7; - dtb7.populate_dispatch_table(add_inplace_contig_dispatch_table); - - // function pointers for inplace operation on contiguous matrix - // and contiguous row - using fn_ns::AddInplaceRowMatrixBroadcastFactory; - DispatchTableBuilder - dtb8; - dtb8.populate_dispatch_table(add_inplace_row_matrix_dispatch_table); -}; - -} // namespace impl - -// U04: ===== ASIN (x) -namespace impl -{ - -namespace asin_fn_ns = dpctl::tensor::kernels::asin; - -static unary_contig_impl_fn_ptr_t asin_contig_dispatch_vector[td_ns::num_types]; -static int asin_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - asin_strided_dispatch_vector[td_ns::num_types]; - -void populate_asin_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = asin_fn_ns; - - using fn_ns::AsinContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(asin_contig_dispatch_vector); - - using fn_ns::AsinStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(asin_strided_dispatch_vector); - - using fn_ns::AsinTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(asin_output_typeid_vector); -} - -} // namespace impl - -// U05: ===== ASINH (x) -namespace impl -{ - -namespace asinh_fn_ns = dpctl::tensor::kernels::asinh; - -static unary_contig_impl_fn_ptr_t - asinh_contig_dispatch_vector[td_ns::num_types]; -static int asinh_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - asinh_strided_dispatch_vector[td_ns::num_types]; - -void populate_asinh_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = asinh_fn_ns; - - using fn_ns::AsinhContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(asinh_contig_dispatch_vector); - - using fn_ns::AsinhStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(asinh_strided_dispatch_vector); - - using fn_ns::AsinhTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(asinh_output_typeid_vector); -} - -} // namespace impl - -// U06: ===== ATAN (x) -namespace impl -{ - -namespace atan_fn_ns = dpctl::tensor::kernels::atan; - -static unary_contig_impl_fn_ptr_t atan_contig_dispatch_vector[td_ns::num_types]; -static int atan_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - atan_strided_dispatch_vector[td_ns::num_types]; - -void populate_atan_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = atan_fn_ns; - - using fn_ns::AtanContigFactory; - DispatchVectorBuilder - dvb1; - 
dvb1.populate_dispatch_vector(atan_contig_dispatch_vector); - - using fn_ns::AtanStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(atan_strided_dispatch_vector); - - using fn_ns::AtanTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(atan_output_typeid_vector); -} - -} // namespace impl - -// B02: ===== ATAN2 (x1, x2) -namespace impl -{ -namespace atan2_fn_ns = dpctl::tensor::kernels::atan2; - -static binary_contig_impl_fn_ptr_t - atan2_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int atan2_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - atan2_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_atan2_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = atan2_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::Atan2TypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(atan2_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::Atan2StridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(atan2_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::Atan2ContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(atan2_contig_dispatch_table); -}; - -} // namespace impl - -// U07: ===== ATANH (x) -namespace impl -{ - -namespace atanh_fn_ns = dpctl::tensor::kernels::atanh; - -static unary_contig_impl_fn_ptr_t - atanh_contig_dispatch_vector[td_ns::num_types]; -static int atanh_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - atanh_strided_dispatch_vector[td_ns::num_types]; - -void populate_atanh_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = atanh_fn_ns; - - using fn_ns::AtanhContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(atanh_contig_dispatch_vector); - - using fn_ns::AtanhStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(atanh_strided_dispatch_vector); - - using fn_ns::AtanhTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(atanh_output_typeid_vector); -} - -} // namespace impl - -// B03: ===== BITWISE_AND (x1, x2) -namespace impl -{ -namespace bitwise_and_fn_ns = dpctl::tensor::kernels::bitwise_and; - -static binary_contig_impl_fn_ptr_t - bitwise_and_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int bitwise_and_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - bitwise_and_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_bitwise_and_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = bitwise_and_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::BitwiseAndTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(bitwise_and_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::BitwiseAndStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(bitwise_and_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::BitwiseAndContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(bitwise_and_contig_dispatch_table); -}; - -} // namespace impl - -// B04: ===== BITWISE_LEFT_SHIFT (x1, x2) 
-namespace impl -{ -namespace bitwise_left_shift_fn_ns = dpctl::tensor::kernels::bitwise_left_shift; - -static binary_contig_impl_fn_ptr_t - bitwise_left_shift_contig_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static int bitwise_left_shift_output_id_table[td_ns::num_types] - [td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - bitwise_left_shift_strided_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -void populate_bitwise_left_shift_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = bitwise_left_shift_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::BitwiseLeftShiftTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(bitwise_left_shift_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::BitwiseLeftShiftStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(bitwise_left_shift_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::BitwiseLeftShiftContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(bitwise_left_shift_contig_dispatch_table); -}; - -} // namespace impl - -// U08: ===== BITWISE_INVERT (x) -namespace impl -{ - -namespace bitwise_invert_fn_ns = dpctl::tensor::kernels::bitwise_invert; - -static unary_contig_impl_fn_ptr_t - bitwise_invert_contig_dispatch_vector[td_ns::num_types]; -static int bitwise_invert_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - bitwise_invert_strided_dispatch_vector[td_ns::num_types]; - -void populate_bitwise_invert_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = bitwise_invert_fn_ns; - - using fn_ns::BitwiseInvertContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(bitwise_invert_contig_dispatch_vector); - - using fn_ns::BitwiseInvertStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(bitwise_invert_strided_dispatch_vector); - - using fn_ns::BitwiseInvertTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(bitwise_invert_output_typeid_vector); -}; - -} // namespace impl - -// B05: ===== BITWISE_OR (x1, x2) -namespace impl -{ -namespace bitwise_or_fn_ns = dpctl::tensor::kernels::bitwise_or; - -static binary_contig_impl_fn_ptr_t - bitwise_or_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int bitwise_or_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - bitwise_or_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_bitwise_or_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = bitwise_or_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::BitwiseOrTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(bitwise_or_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::BitwiseOrStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(bitwise_or_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::BitwiseOrContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(bitwise_or_contig_dispatch_table); -}; -} // namespace impl - -// B06: ===== BITWISE_RIGHT_SHIFT (x1, x2) -namespace impl -{ -namespace bitwise_right_shift_fn_ns = - 
dpctl::tensor::kernels::bitwise_right_shift; - -static binary_contig_impl_fn_ptr_t - bitwise_right_shift_contig_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static int bitwise_right_shift_output_id_table[td_ns::num_types] - [td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - bitwise_right_shift_strided_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -void populate_bitwise_right_shift_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = bitwise_right_shift_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::BitwiseRightShiftTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(bitwise_right_shift_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::BitwiseRightShiftStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(bitwise_right_shift_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::BitwiseRightShiftContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(bitwise_right_shift_contig_dispatch_table); -}; - -} // namespace impl - -// B07: ===== BITWISE_XOR (x1, x2) -namespace impl -{ -namespace bitwise_xor_fn_ns = dpctl::tensor::kernels::bitwise_xor; - -static binary_contig_impl_fn_ptr_t - bitwise_xor_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int bitwise_xor_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - bitwise_xor_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_bitwise_xor_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = bitwise_xor_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::BitwiseXorTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(bitwise_xor_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::BitwiseXorStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(bitwise_xor_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::BitwiseXorContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(bitwise_xor_contig_dispatch_table); -}; -} // namespace impl - -// U09: ==== CEIL (x) -namespace impl -{ - -namespace ceil_fn_ns = dpctl::tensor::kernels::ceil; - -static unary_contig_impl_fn_ptr_t ceil_contig_dispatch_vector[td_ns::num_types]; -static int ceil_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - ceil_strided_dispatch_vector[td_ns::num_types]; - -void populate_ceil_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = ceil_fn_ns; - - using fn_ns::CeilContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(ceil_contig_dispatch_vector); - - using fn_ns::CeilStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(ceil_strided_dispatch_vector); - - using fn_ns::CeilTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(ceil_output_typeid_vector); -} - -} // namespace impl - -// U10: ==== CONJ (x) -namespace impl -{ - -namespace conj_fn_ns = dpctl::tensor::kernels::conj; - -static unary_contig_impl_fn_ptr_t conj_contig_dispatch_vector[td_ns::num_types]; -static int conj_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - 
conj_strided_dispatch_vector[td_ns::num_types]; - -void populate_conj_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = conj_fn_ns; - - using fn_ns::ConjContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(conj_contig_dispatch_vector); - - using fn_ns::ConjStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(conj_strided_dispatch_vector); - - using fn_ns::ConjTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(conj_output_typeid_vector); -} -} // namespace impl - -// U11: ==== COS (x) -namespace impl -{ - -namespace cos_fn_ns = dpctl::tensor::kernels::cos; - -static unary_contig_impl_fn_ptr_t cos_contig_dispatch_vector[td_ns::num_types]; -static int cos_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - cos_strided_dispatch_vector[td_ns::num_types]; - -void populate_cos_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = cos_fn_ns; - - using fn_ns::CosContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(cos_contig_dispatch_vector); - - using fn_ns::CosStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(cos_strided_dispatch_vector); - - using fn_ns::CosTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(cos_output_typeid_vector); -} - -} // namespace impl - -// U12: ==== COSH (x) -namespace impl -{ - -namespace cosh_fn_ns = dpctl::tensor::kernels::cosh; - -static unary_contig_impl_fn_ptr_t cosh_contig_dispatch_vector[td_ns::num_types]; -static int cosh_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - cosh_strided_dispatch_vector[td_ns::num_types]; - -void populate_cosh_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = cosh_fn_ns; - - using fn_ns::CoshContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(cosh_contig_dispatch_vector); - - using fn_ns::CoshStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(cosh_strided_dispatch_vector); - - using fn_ns::CoshTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(cosh_output_typeid_vector); -} - -} // namespace impl - -// B08: ==== DIVIDE (x1, x2) -namespace impl -{ -namespace true_divide_fn_ns = dpctl::tensor::kernels::true_divide; - -static binary_contig_impl_fn_ptr_t - true_divide_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int true_divide_output_id_table[td_ns::num_types][td_ns::num_types]; -static int true_divide_inplace_output_id_table[td_ns::num_types] - [td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - true_divide_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -// divide(matrix, row) -static binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t - true_divide_contig_matrix_contig_row_broadcast_dispatch_table - [td_ns::num_types][td_ns::num_types]; - -// divide(row, matrix) -static binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t - true_divide_contig_row_contig_matrix_broadcast_dispatch_table - [td_ns::num_types][td_ns::num_types]; - -static binary_inplace_contig_impl_fn_ptr_t - true_divide_inplace_contig_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static binary_inplace_strided_impl_fn_ptr_t - true_divide_inplace_strided_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static binary_inplace_row_matrix_broadcast_impl_fn_ptr_t - true_divide_inplace_row_matrix_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - 
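// Descriptive note (added commentary, not part of the deleted file): each of
// the static tables above is a num_types x num_types array of kernel function
// pointers (or of output type ids). populate_true_divide_dispatch_tables()
// below fills them once at module initialization by instantiating a factory
// for every pair of input type ids; at call time the binding layer looks up
// the entry for the two input typenums and either dispatches through the
// stored kernel pointer or reports the combination as unsupported.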
-void populate_true_divide_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = true_divide_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::TrueDivideTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(true_divide_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::TrueDivideStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(true_divide_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::TrueDivideContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(true_divide_contig_dispatch_table); - - // function pointers for operation on contiguous matrix, contiguous row - // with contiguous matrix output - using fn_ns::TrueDivideContigMatrixContigRowBroadcastFactory; - DispatchTableBuilder< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t, - TrueDivideContigMatrixContigRowBroadcastFactory, num_types> - dtb4; - dtb4.populate_dispatch_table( - true_divide_contig_matrix_contig_row_broadcast_dispatch_table); - - // function pointers for operation on contiguous row, contiguous matrix - // with contiguous matrix output - using fn_ns::TrueDivideContigRowContigMatrixBroadcastFactory; - DispatchTableBuilder< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t, - TrueDivideContigRowContigMatrixBroadcastFactory, num_types> - dtb5; - dtb5.populate_dispatch_table( - true_divide_contig_row_contig_matrix_broadcast_dispatch_table); - - // which input types are supported, and what is the type of the result - using fn_ns::TrueDivideInplaceTypeMapFactory; - DispatchTableBuilder dtb6; - dtb6.populate_dispatch_table(true_divide_inplace_output_id_table); - - // function pointers for inplace operation on general strided arrays - using fn_ns::TrueDivideInplaceStridedFactory; - DispatchTableBuilder - dtb7; - dtb7.populate_dispatch_table(true_divide_inplace_strided_dispatch_table); - - // function pointers for inplace operation on contiguous inputs and output - using fn_ns::TrueDivideInplaceContigFactory; - DispatchTableBuilder - dtb8; - dtb8.populate_dispatch_table(true_divide_inplace_contig_dispatch_table); - - // function pointers for inplace operation on contiguous matrix - // and contiguous row - using fn_ns::TrueDivideInplaceRowMatrixBroadcastFactory; - DispatchTableBuilder - dtb9; - dtb9.populate_dispatch_table(true_divide_inplace_row_matrix_dispatch_table); -}; - -} // namespace impl - -// B09: ==== EQUAL (x1, x2) -namespace impl -{ -namespace equal_fn_ns = dpctl::tensor::kernels::equal; - -static binary_contig_impl_fn_ptr_t - equal_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int equal_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - equal_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_equal_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = equal_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::EqualTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(equal_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::EqualStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(equal_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::EqualContigFactory; - 
DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(equal_contig_dispatch_table); -}; -} // namespace impl - -// U13: ==== EXP (x) -namespace impl -{ - -namespace exp_fn_ns = dpctl::tensor::kernels::exp; - -static unary_contig_impl_fn_ptr_t exp_contig_dispatch_vector[td_ns::num_types]; -static int exp_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - exp_strided_dispatch_vector[td_ns::num_types]; - -void populate_exp_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = exp_fn_ns; - - using fn_ns::ExpContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(exp_contig_dispatch_vector); - - using fn_ns::ExpStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(exp_strided_dispatch_vector); - - using fn_ns::ExpTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(exp_output_typeid_vector); -} - -} // namespace impl - -// U14: ==== EXPM1 (x) -namespace impl -{ - -namespace expm1_fn_ns = dpctl::tensor::kernels::expm1; - -static unary_contig_impl_fn_ptr_t - expm1_contig_dispatch_vector[td_ns::num_types]; -static int expm1_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - expm1_strided_dispatch_vector[td_ns::num_types]; - -void populate_expm1_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = expm1_fn_ns; - - using fn_ns::Expm1ContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(expm1_contig_dispatch_vector); - - using fn_ns::Expm1StridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(expm1_strided_dispatch_vector); - - using fn_ns::Expm1TypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(expm1_output_typeid_vector); -} - -} // namespace impl - -// U15: ==== FLOOR (x) -namespace impl -{ - -namespace floor_fn_ns = dpctl::tensor::kernels::floor; - -static unary_contig_impl_fn_ptr_t - floor_contig_dispatch_vector[td_ns::num_types]; -static int floor_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - floor_strided_dispatch_vector[td_ns::num_types]; - -void populate_floor_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = floor_fn_ns; - - using fn_ns::FloorContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(floor_contig_dispatch_vector); - - using fn_ns::FloorStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(floor_strided_dispatch_vector); - - using fn_ns::FloorTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(floor_output_typeid_vector); -} - -} // namespace impl - -// B10: ==== FLOOR_DIVIDE (x1, x2) -namespace impl -{ -namespace floor_divide_fn_ns = dpctl::tensor::kernels::floor_divide; - -static binary_contig_impl_fn_ptr_t - floor_divide_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int floor_divide_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - floor_divide_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -static binary_inplace_contig_impl_fn_ptr_t - floor_divide_inplace_contig_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static binary_inplace_strided_impl_fn_ptr_t - floor_divide_inplace_strided_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -void populate_floor_divide_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = floor_divide_fn_ns; - - // which input types are supported, and what is the type of the 
result - using fn_ns::FloorDivideTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(floor_divide_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::FloorDivideStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(floor_divide_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::FloorDivideContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(floor_divide_contig_dispatch_table); - - // function pointers for inplace operation on general strided arrays - using fn_ns::FloorDivideInplaceStridedFactory; - DispatchTableBuilder - dtb4; - dtb4.populate_dispatch_table(floor_divide_inplace_strided_dispatch_table); - - // function pointers for inplace operation on contiguous inputs and output - using fn_ns::FloorDivideInplaceContigFactory; - DispatchTableBuilder - dtb5; - dtb5.populate_dispatch_table(floor_divide_inplace_contig_dispatch_table); -}; - -} // namespace impl - -// B11: ==== GREATER (x1, x2) -namespace impl -{ -namespace greater_fn_ns = dpctl::tensor::kernels::greater; - -static binary_contig_impl_fn_ptr_t - greater_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int greater_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - greater_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_greater_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = greater_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::GreaterTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(greater_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::GreaterStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(greater_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::GreaterContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(greater_contig_dispatch_table); -}; -} // namespace impl - -// B12: ==== GREATER_EQUAL (x1, x2) -namespace impl -{ -namespace greater_equal_fn_ns = dpctl::tensor::kernels::greater_equal; - -static binary_contig_impl_fn_ptr_t - greater_equal_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int greater_equal_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - greater_equal_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_greater_equal_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = greater_equal_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::GreaterEqualTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(greater_equal_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::GreaterEqualStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(greater_equal_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::GreaterEqualContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(greater_equal_contig_dispatch_table); -}; -} // namespace impl - -// U16: ==== IMAG (x) -namespace impl -{ - -namespace imag_fn_ns = dpctl::tensor::kernels::imag; - -static unary_contig_impl_fn_ptr_t 
imag_contig_dispatch_vector[td_ns::num_types]; -static int imag_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - imag_strided_dispatch_vector[td_ns::num_types]; - -void populate_imag_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = imag_fn_ns; - - using fn_ns::ImagContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(imag_contig_dispatch_vector); - - using fn_ns::ImagStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(imag_strided_dispatch_vector); - - using fn_ns::ImagTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(imag_output_typeid_vector); -} -} // namespace impl - -// U17: ==== ISFINITE (x) -namespace impl -{ -namespace isfinite_fn_ns = dpctl::tensor::kernels::isfinite; - -static unary_contig_impl_fn_ptr_t - isfinite_contig_dispatch_vector[td_ns::num_types]; -static int isfinite_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - isfinite_strided_dispatch_vector[td_ns::num_types]; - -void populate_isfinite_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = isfinite_fn_ns; - - using fn_ns::IsFiniteContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(isfinite_contig_dispatch_vector); - - using fn_ns::IsFiniteStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(isfinite_strided_dispatch_vector); - - using fn_ns::IsFiniteTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(isfinite_output_typeid_vector); -} - -} // namespace impl - -// U18: ==== ISINF (x) -namespace impl -{ -namespace isinf_fn_ns = dpctl::tensor::kernels::isinf; - -static unary_contig_impl_fn_ptr_t - isinf_contig_dispatch_vector[td_ns::num_types]; -static int isinf_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - isinf_strided_dispatch_vector[td_ns::num_types]; - -void populate_isinf_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = isinf_fn_ns; - - using fn_ns::IsInfContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(isinf_contig_dispatch_vector); - - using fn_ns::IsInfStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(isinf_strided_dispatch_vector); - - using fn_ns::IsInfTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(isinf_output_typeid_vector); -} - -} // namespace impl - -// U19: ==== ISNAN (x) -namespace impl -{ -namespace isnan_fn_ns = dpctl::tensor::kernels::isnan; - -static unary_contig_impl_fn_ptr_t - isnan_contig_dispatch_vector[td_ns::num_types]; -static int isnan_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - isnan_strided_dispatch_vector[td_ns::num_types]; - -void populate_isnan_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = isnan_fn_ns; - - using fn_ns::IsNanContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(isnan_contig_dispatch_vector); - - using fn_ns::IsNanStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(isnan_strided_dispatch_vector); - - using fn_ns::IsNanTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(isnan_output_typeid_vector); -} - -} // namespace impl - -// B13: ==== LESS (x1, x2) -namespace impl -{ -namespace less_fn_ns = dpctl::tensor::kernels::less; - -static binary_contig_impl_fn_ptr_t less_contig_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static int 
less_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - less_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_less_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = less_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::LessTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(less_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::LessStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(less_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::LessContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(less_contig_dispatch_table); -}; -} // namespace impl - -// B14: ==== LESS_EQUAL (x1, x2) -namespace impl -{ -namespace less_equal_fn_ns = dpctl::tensor::kernels::less_equal; - -static binary_contig_impl_fn_ptr_t - less_equal_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int less_equal_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - less_equal_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_less_equal_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = less_equal_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::LessEqualTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(less_equal_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::LessEqualStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(less_equal_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::LessEqualContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(less_equal_contig_dispatch_table); -}; -} // namespace impl - -// U20: ==== LOG (x) -namespace impl -{ - -namespace log_fn_ns = dpctl::tensor::kernels::log; - -static unary_contig_impl_fn_ptr_t log_contig_dispatch_vector[td_ns::num_types]; -static int log_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - log_strided_dispatch_vector[td_ns::num_types]; - -void populate_log_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = log_fn_ns; - - using fn_ns::LogContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(log_contig_dispatch_vector); - - using fn_ns::LogStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(log_strided_dispatch_vector); - - using fn_ns::LogTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(log_output_typeid_vector); -} - -} // namespace impl - -// U21: ==== LOG1P (x) -namespace impl -{ - -namespace log1p_fn_ns = dpctl::tensor::kernels::log1p; - -static unary_contig_impl_fn_ptr_t - log1p_contig_dispatch_vector[td_ns::num_types]; -static int log1p_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - log1p_strided_dispatch_vector[td_ns::num_types]; - -void populate_log1p_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = log1p_fn_ns; - - using fn_ns::Log1pContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(log1p_contig_dispatch_vector); - - using fn_ns::Log1pStridedFactory; - DispatchVectorBuilder - dvb2; - 
dvb2.populate_dispatch_vector(log1p_strided_dispatch_vector); - - using fn_ns::Log1pTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(log1p_output_typeid_vector); -} - -} // namespace impl - -// U22: ==== LOG2 (x) -namespace impl -{ - -namespace log2_fn_ns = dpctl::tensor::kernels::log2; - -static unary_contig_impl_fn_ptr_t log2_contig_dispatch_vector[td_ns::num_types]; -static int log2_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - log2_strided_dispatch_vector[td_ns::num_types]; - -void populate_log2_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = log2_fn_ns; - - using fn_ns::Log2ContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(log2_contig_dispatch_vector); - - using fn_ns::Log2StridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(log2_strided_dispatch_vector); - - using fn_ns::Log2TypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(log2_output_typeid_vector); -}; - -} // namespace impl - -// U23: ==== LOG10 (x) -namespace impl -{ - -namespace log10_fn_ns = dpctl::tensor::kernels::log10; - -static unary_contig_impl_fn_ptr_t - log10_contig_dispatch_vector[td_ns::num_types]; -static int log10_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - log10_strided_dispatch_vector[td_ns::num_types]; - -void populate_log10_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = log10_fn_ns; - - using fn_ns::Log10ContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(log10_contig_dispatch_vector); - - using fn_ns::Log10StridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(log10_strided_dispatch_vector); - - using fn_ns::Log10TypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(log10_output_typeid_vector); -}; - -} // namespace impl - -// B15: ==== LOGADDEXP (x1, x2) -namespace impl -{ -namespace logaddexp_fn_ns = dpctl::tensor::kernels::logaddexp; - -static binary_contig_impl_fn_ptr_t - logaddexp_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int logaddexp_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - logaddexp_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_logaddexp_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = logaddexp_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::LogAddExpTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(logaddexp_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::LogAddExpStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(logaddexp_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::LogAddExpContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(logaddexp_contig_dispatch_table); -}; -} // namespace impl - -// B16: ==== LOGICAL_AND (x1, x2) -namespace impl -{ -namespace logical_and_fn_ns = dpctl::tensor::kernels::logical_and; - -static binary_contig_impl_fn_ptr_t - logical_and_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int logical_and_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - logical_and_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void 
populate_logical_and_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = logical_and_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::LogicalAndTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(logical_and_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::LogicalAndStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(logical_and_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::LogicalAndContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(logical_and_contig_dispatch_table); -}; -} // namespace impl - -// U24: ==== LOGICAL_NOT (x) -namespace impl -{ -namespace logical_not_fn_ns = dpctl::tensor::kernels::logical_not; - -static unary_contig_impl_fn_ptr_t - logical_not_contig_dispatch_vector[td_ns::num_types]; -static int logical_not_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - logical_not_strided_dispatch_vector[td_ns::num_types]; - -void populate_logical_not_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = logical_not_fn_ns; - - using fn_ns::LogicalNotContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(logical_not_contig_dispatch_vector); - - using fn_ns::LogicalNotStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(logical_not_strided_dispatch_vector); - - using fn_ns::LogicalNotTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(logical_not_output_typeid_vector); -}; -} // namespace impl - -// B17: ==== LOGICAL_OR (x1, x2) -namespace impl -{ -namespace logical_or_fn_ns = dpctl::tensor::kernels::logical_or; - -static binary_contig_impl_fn_ptr_t - logical_or_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int logical_or_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - logical_or_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_logical_or_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = logical_or_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::LogicalOrTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(logical_or_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::LogicalOrStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(logical_or_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::LogicalOrContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(logical_or_contig_dispatch_table); -}; -} // namespace impl - -// B18: ==== LOGICAL_XOR (x1, x2) -namespace impl -{ -namespace logical_xor_fn_ns = dpctl::tensor::kernels::logical_xor; - -static binary_contig_impl_fn_ptr_t - logical_xor_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int logical_xor_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - logical_xor_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_logical_xor_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = logical_xor_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::LogicalXorTypeMapFactory; - DispatchTableBuilder 
dtb1; - dtb1.populate_dispatch_table(logical_xor_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::LogicalXorStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(logical_xor_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::LogicalXorContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(logical_xor_contig_dispatch_table); -}; -} // namespace impl - -// B??: ==== MAXIMUM (x1, x2) -namespace impl -{ - -namespace maximum_fn_ns = dpctl::tensor::kernels::maximum; - -static binary_contig_impl_fn_ptr_t - maximum_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int maximum_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - maximum_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_maximum_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = maximum_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::MaximumTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(maximum_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::MaximumStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(maximum_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::MaximumContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(maximum_contig_dispatch_table); -}; - -} // namespace impl - -// B??: ==== MINIMUM (x1, x2) -namespace impl -{ - -namespace minimum_fn_ns = dpctl::tensor::kernels::minimum; - -static binary_contig_impl_fn_ptr_t - minimum_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int minimum_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - minimum_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_minimum_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = minimum_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::MinimumTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(minimum_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::MinimumStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(minimum_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::MinimumContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(minimum_contig_dispatch_table); -}; - -} // namespace impl - -// B19: ==== MULTIPLY (x1, x2) -namespace impl -{ - -namespace multiply_fn_ns = dpctl::tensor::kernels::multiply; - -static binary_contig_impl_fn_ptr_t - multiply_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int multiply_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - multiply_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -// mul(matrix, row) -static binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t - multiply_contig_matrix_contig_row_broadcast_dispatch_table - [td_ns::num_types][td_ns::num_types]; - -// mul(row, matrix) -static binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t - multiply_contig_row_contig_matrix_broadcast_dispatch_table - 
-// U25: ==== NEGATIVE (x)
-// B20: ==== NOT_EQUAL (x1, x2)
-// U26: ==== POSITIVE (x)
-// B21: ==== POW (x1, x2)
-// U??: ==== PROJ (x)
-// U27: ==== REAL (x)
-// B22: ==== REMAINDER (x1, x2)
-// U28: ==== ROUND (x)
-// U29: ==== SIGN (x)
-// ==== SIGNBIT (x)
-// U30: ==== SIN (x)
-// U31: ==== SINH (x)
-// U32: ==== SQUARE (x)
-// U33: ==== SQRT (x)
-// B23: ==== SUBTRACT (x1, x2)
-// U34: ==== TAN (x)
-// U35: ==== TANH (x)
-// U36: ==== TRUNC (x)
-// B24: ==== HYPOT (x1, x2)
-// U37: ==== CBRT (x)
-// B24: ==== COPYSIGN (x1, x2)
-// U38: ==== EXP2 (x)
-// U39: ==== RSQRT (x)
-
-// ==========================================================================================
-//
-
-namespace py = pybind11;
-
-void init_elementwise_functions(py::module_ m)
-{
-    using arrayT = dpctl::tensor::usm_ndarray;
-    using event_vecT = std::vector<sycl::event>;
-
-    // U01: ==== ABS (x)
-    // U02: ==== ACOS (x)
-    // U03: ===== ACOSH (x)
-    // B01: ===== ADD (x1, x2)
-    // U04: ===== ASIN (x)
-    // U05: ===== ASINH (x)
-    // U06: ===== ATAN (x)
-    // B02: ===== ATAN2 (x1, x2)
-    // U07: ===== ATANH (x)
-    // B03: ===== BITWISE_AND (x1, x2)
-    // B04: ===== BITWISE_LEFT_SHIFT (x1, x2)
operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto bitwise_left_shift_result_type_pyapi = - [&](const py::dtype &dtype1, const py::dtype &dtype2) { - return py_binary_ufunc_result_type( - dtype1, dtype2, bitwise_left_shift_output_id_table); - }; - m.def("_bitwise_left_shift", bitwise_left_shift_pyapi, "", - py::arg("src1"), py::arg("src2"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - m.def("_bitwise_left_shift_result_type", - bitwise_left_shift_result_type_pyapi, ""); - } - - // U08: ===== BITWISE_INVERT (x) - { - impl::populate_bitwise_invert_dispatch_vectors(); - using impl::bitwise_invert_contig_dispatch_vector; - using impl::bitwise_invert_output_typeid_vector; - using impl::bitwise_invert_strided_dispatch_vector; - - auto bitwise_invert_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc(src, dst, exec_q, depends, - bitwise_invert_output_typeid_vector, - bitwise_invert_contig_dispatch_vector, - bitwise_invert_strided_dispatch_vector); - }; - m.def("_bitwise_invert", bitwise_invert_pyapi, "", py::arg("src"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - - auto bitwise_invert_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type( - dtype, bitwise_invert_output_typeid_vector); - }; - m.def("_bitwise_invert_result_type", bitwise_invert_result_type_pyapi); - } - - // B05: ===== BITWISE_OR (x1, x2) - { - impl::populate_bitwise_or_dispatch_tables(); - using impl::bitwise_or_contig_dispatch_table; - using impl::bitwise_or_output_id_table; - using impl::bitwise_or_strided_dispatch_table; - - auto bitwise_or_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, bitwise_or_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - bitwise_or_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - bitwise_or_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto bitwise_or_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - bitwise_or_output_id_table); - }; - m.def("_bitwise_or", bitwise_or_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_bitwise_or_result_type", bitwise_or_result_type_pyapi, ""); - } - - // B06: ===== BITWISE_RIGHT_SHIFT (x1, x2) - { - impl::populate_bitwise_right_shift_dispatch_tables(); - using 
impl::bitwise_right_shift_contig_dispatch_table; - using impl::bitwise_right_shift_output_id_table; - using impl::bitwise_right_shift_strided_dispatch_table; - - auto bitwise_right_shift_pyapi = [&](const dpctl::tensor::usm_ndarray - &src1, - const dpctl::tensor::usm_ndarray - &src2, - const dpctl::tensor::usm_ndarray - &dst, - sycl::queue &exec_q, - const std::vector - &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, - bitwise_right_shift_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - bitwise_right_shift_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - bitwise_right_shift_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto bitwise_right_shift_result_type_pyapi = - [&](const py::dtype &dtype1, const py::dtype &dtype2) { - return py_binary_ufunc_result_type( - dtype1, dtype2, bitwise_right_shift_output_id_table); - }; - m.def("_bitwise_right_shift", bitwise_right_shift_pyapi, "", - py::arg("src1"), py::arg("src2"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - m.def("_bitwise_right_shift_result_type", - bitwise_right_shift_result_type_pyapi, ""); - } - - // B07: ===== BITWISE_XOR (x1, x2) - { - impl::populate_bitwise_xor_dispatch_tables(); - using impl::bitwise_xor_contig_dispatch_table; - using impl::bitwise_xor_output_id_table; - using impl::bitwise_xor_strided_dispatch_table; - - auto bitwise_xor_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, bitwise_xor_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - bitwise_xor_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - bitwise_xor_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto bitwise_xor_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - bitwise_xor_output_id_table); - }; - m.def("_bitwise_xor", bitwise_xor_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_bitwise_xor_result_type", bitwise_xor_result_type_pyapi, ""); - } - - // U09: ==== CEIL (x) - { - impl::populate_ceil_dispatch_vectors(); - using impl::ceil_contig_dispatch_vector; - using impl::ceil_output_typeid_vector; - using impl::ceil_strided_dispatch_vector; - - auto ceil_pyapi = [&](const arrayT &src, const arrayT &dst, - 
sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, ceil_output_typeid_vector, - ceil_contig_dispatch_vector, ceil_strided_dispatch_vector); - }; - m.def("_ceil", ceil_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto ceil_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, ceil_output_typeid_vector); - }; - m.def("_ceil_result_type", ceil_result_type_pyapi); - } - - // U10: ==== CONJ (x) - { - impl::populate_conj_dispatch_vectors(); - using impl::conj_contig_dispatch_vector; - using impl::conj_output_typeid_vector; - using impl::conj_strided_dispatch_vector; - - auto conj_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, conj_output_typeid_vector, - conj_contig_dispatch_vector, conj_strided_dispatch_vector); - }; - m.def("_conj", conj_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto conj_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, conj_output_typeid_vector); - }; - m.def("_conj_result_type", conj_result_type_pyapi); - } - - // U11: ==== COS (x) - { - impl::populate_cos_dispatch_vectors(); - using impl::cos_contig_dispatch_vector; - using impl::cos_output_typeid_vector; - using impl::cos_strided_dispatch_vector; - - auto cos_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, cos_output_typeid_vector, - cos_contig_dispatch_vector, cos_strided_dispatch_vector); - }; - m.def("_cos", cos_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto cos_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, cos_output_typeid_vector); - }; - m.def("_cos_result_type", cos_result_type_pyapi); - } - - // U12: ==== COSH (x) - { - impl::populate_cosh_dispatch_vectors(); - using impl::cosh_contig_dispatch_vector; - using impl::cosh_output_typeid_vector; - using impl::cosh_strided_dispatch_vector; - - auto cosh_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, cosh_output_typeid_vector, - cosh_contig_dispatch_vector, cosh_strided_dispatch_vector); - }; - m.def("_cosh", cosh_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto cosh_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, cosh_output_typeid_vector); - }; - m.def("_cosh_result_type", cosh_result_type_pyapi); - } - - // B08: ==== DIVIDE (x1, x2) - { - impl::populate_true_divide_dispatch_tables(); - using impl::true_divide_contig_dispatch_table; - using impl:: - true_divide_contig_matrix_contig_row_broadcast_dispatch_table; - using impl:: - true_divide_contig_row_contig_matrix_broadcast_dispatch_table; - using impl::true_divide_output_id_table; - using impl::true_divide_strided_dispatch_table; - - auto divide_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, 
true_divide_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - true_divide_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - true_divide_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - true_divide_contig_matrix_contig_row_broadcast_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - true_divide_contig_row_contig_matrix_broadcast_dispatch_table); - }; - auto divide_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - true_divide_output_id_table); - }; - m.def("_divide", divide_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_divide_result_type", divide_result_type_pyapi, ""); - - using impl::true_divide_inplace_contig_dispatch_table; - using impl::true_divide_inplace_output_id_table; - using impl::true_divide_inplace_row_matrix_dispatch_table; - using impl::true_divide_inplace_strided_dispatch_table; - - auto divide_inplace_pyapi = - [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_inplace_ufunc( - src, dst, exec_q, depends, - true_divide_inplace_output_id_table, - // function pointers to handle inplace operation on - // contiguous arrays (pointers may be nullptr) - true_divide_inplace_contig_dispatch_table, - // function pointers to handle inplace operation on strided - // arrays (most general case) - true_divide_inplace_strided_dispatch_table, - // function pointers to handle inplace operation on - // c-contig matrix with c-contig row with broadcasting - // (may be nullptr) - true_divide_inplace_row_matrix_dispatch_table); - }; - m.def("_divide_inplace", divide_inplace_pyapi, "", py::arg("lhs"), - py::arg("rhs"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - } - - // B09: ==== EQUAL (x1, x2) - { - impl::populate_equal_dispatch_tables(); - using impl::equal_contig_dispatch_table; - using impl::equal_output_id_table; - using impl::equal_strided_dispatch_table; - - auto equal_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, equal_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - equal_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - equal_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto equal_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - equal_output_id_table); - }; - m.def("_equal", equal_pyapi, "", 
py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_equal_result_type", equal_result_type_pyapi, ""); - } - - // U13: ==== EXP (x) - { - impl::populate_exp_dispatch_vectors(); - using impl::exp_contig_dispatch_vector; - using impl::exp_output_typeid_vector; - using impl::exp_strided_dispatch_vector; - - auto exp_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, exp_output_typeid_vector, - exp_contig_dispatch_vector, exp_strided_dispatch_vector); - }; - m.def("_exp", exp_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto exp_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, exp_output_typeid_vector); - }; - m.def("_exp_result_type", exp_result_type_pyapi); - } - - // U14: ==== EXPM1 (x) - { - impl::populate_expm1_dispatch_vectors(); - using impl::expm1_contig_dispatch_vector; - using impl::expm1_output_typeid_vector; - using impl::expm1_strided_dispatch_vector; - - auto expm1_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, expm1_output_typeid_vector, - expm1_contig_dispatch_vector, expm1_strided_dispatch_vector); - }; - m.def("_expm1", expm1_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto expm1_result_type_pyapi = [&](const py::dtype dtype) { - return py_unary_ufunc_result_type(dtype, - expm1_output_typeid_vector); - }; - m.def("_expm1_result_type", expm1_result_type_pyapi); - } - - // U15: ==== FLOOR (x) - { - impl::populate_floor_dispatch_vectors(); - using impl::floor_contig_dispatch_vector; - using impl::floor_output_typeid_vector; - using impl::floor_strided_dispatch_vector; - - auto floor_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, floor_output_typeid_vector, - floor_contig_dispatch_vector, floor_strided_dispatch_vector); - }; - m.def("_floor", floor_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto floor_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - floor_output_typeid_vector); - }; - m.def("_floor_result_type", floor_result_type_pyapi); - } - - // B10: ==== FLOOR_DIVIDE (x1, x2) - { - impl::populate_floor_divide_dispatch_tables(); - using impl::floor_divide_contig_dispatch_table; - using impl::floor_divide_output_id_table; - using impl::floor_divide_strided_dispatch_table; - - auto floor_divide_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, floor_divide_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - floor_divide_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - floor_divide_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - 
binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto floor_divide_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - floor_divide_output_id_table); - }; - m.def("_floor_divide", floor_divide_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_floor_divide_result_type", floor_divide_result_type_pyapi, ""); - - using impl::floor_divide_inplace_contig_dispatch_table; - using impl::floor_divide_inplace_strided_dispatch_table; - - auto floor_divide_inplace_pyapi = - [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_inplace_ufunc( - src, dst, exec_q, depends, floor_divide_output_id_table, - // function pointers to handle inplace operation on - // contiguous arrays (pointers may be nullptr) - floor_divide_inplace_contig_dispatch_table, - // function pointers to handle inplace operation on strided - // arrays (most general case) - floor_divide_inplace_strided_dispatch_table, - // function pointers to handle inplace operation on - // c-contig matrix with c-contig row with broadcasting - // (may be nullptr) - td_ns::NullPtrTable< - binary_inplace_row_matrix_broadcast_impl_fn_ptr_t>{}); - }; - m.def("_floor_divide_inplace", floor_divide_inplace_pyapi, "", - py::arg("lhs"), py::arg("rhs"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - } - - // B11: ==== GREATER (x1, x2) - { - impl::populate_greater_dispatch_tables(); - using impl::greater_contig_dispatch_table; - using impl::greater_output_id_table; - using impl::greater_strided_dispatch_table; - - auto greater_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, greater_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - greater_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - greater_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto greater_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - greater_output_id_table); - }; - m.def("_greater", greater_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_greater_result_type", greater_result_type_pyapi, ""); - } - - // B12: ==== GREATER_EQUAL (x1, x2) - { - impl::populate_greater_equal_dispatch_tables(); - using impl::greater_equal_contig_dispatch_table; - using impl::greater_equal_output_id_table; - using 
impl::greater_equal_strided_dispatch_table; - - auto greater_equal_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, greater_equal_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - greater_equal_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - greater_equal_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto greater_equal_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - greater_equal_output_id_table); - }; - m.def("_greater_equal", greater_equal_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_greater_equal_result_type", greater_equal_result_type_pyapi, - ""); - } - - // U16: ==== IMAG (x) - { - impl::populate_imag_dispatch_vectors(); - using impl::imag_contig_dispatch_vector; - using impl::imag_output_typeid_vector; - using impl::imag_strided_dispatch_vector; - - auto imag_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, imag_output_typeid_vector, - imag_contig_dispatch_vector, imag_strided_dispatch_vector); - }; - m.def("_imag", imag_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto imag_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, imag_output_typeid_vector); - }; - m.def("_imag_result_type", imag_result_type_pyapi); - } - - // U17: ==== ISFINITE (x) - { - impl::populate_isfinite_dispatch_vectors(); - - using impl::isfinite_contig_dispatch_vector; - using impl::isfinite_output_typeid_vector; - using impl::isfinite_strided_dispatch_vector; - auto isfinite_pyapi = - [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_unary_ufunc(src, dst, exec_q, depends, - isfinite_output_typeid_vector, - isfinite_contig_dispatch_vector, - isfinite_strided_dispatch_vector); - }; - auto isfinite_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - isfinite_output_typeid_vector); - }; - m.def("_isfinite", isfinite_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - m.def("_isfinite_result_type", isfinite_result_type_pyapi, ""); - } - - // U18: ==== ISINF (x) - { - impl::populate_isinf_dispatch_vectors(); - - using impl::isinf_contig_dispatch_vector; - using impl::isinf_output_typeid_vector; - using impl::isinf_strided_dispatch_vector; - auto isinf_pyapi = [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return 
py_unary_ufunc( - src, dst, exec_q, depends, isinf_output_typeid_vector, - isinf_contig_dispatch_vector, isinf_strided_dispatch_vector); - }; - auto isinf_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - isinf_output_typeid_vector); - }; - m.def("_isinf", isinf_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - m.def("_isinf_result_type", isinf_result_type_pyapi, ""); - } - - // U19: ==== ISNAN (x) - { - impl::populate_isnan_dispatch_vectors(); - - using impl::isnan_contig_dispatch_vector; - using impl::isnan_output_typeid_vector; - using impl::isnan_strided_dispatch_vector; - auto isnan_pyapi = [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, isnan_output_typeid_vector, - isnan_contig_dispatch_vector, isnan_strided_dispatch_vector); - }; - auto isnan_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - isnan_output_typeid_vector); - }; - m.def("_isnan", isnan_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - m.def("_isnan_result_type", isnan_result_type_pyapi, ""); - } - - // B13: ==== LESS (x1, x2) - { - impl::populate_less_dispatch_tables(); - using impl::less_contig_dispatch_table; - using impl::less_output_id_table; - using impl::less_strided_dispatch_table; - - auto less_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, less_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - less_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - less_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto less_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - less_output_id_table); - }; - m.def("_less", less_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_less_result_type", less_result_type_pyapi, ""); - } - - // B14: ==== LESS_EQUAL (x1, x2) - { - impl::populate_less_equal_dispatch_tables(); - using impl::less_equal_contig_dispatch_table; - using impl::less_equal_output_id_table; - using impl::less_equal_strided_dispatch_table; - - auto less_equal_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, less_equal_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - less_equal_contig_dispatch_table, - // function pointers to handle 
operation on strided arrays (most - // general case) - less_equal_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto less_equal_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - less_equal_output_id_table); - }; - m.def("_less_equal", less_equal_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_less_equal_result_type", less_equal_result_type_pyapi, ""); - } - - // U20: ==== LOG (x) - { - impl::populate_log_dispatch_vectors(); - using impl::log_contig_dispatch_vector; - using impl::log_output_typeid_vector; - using impl::log_strided_dispatch_vector; - - auto log_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, log_output_typeid_vector, - log_contig_dispatch_vector, log_strided_dispatch_vector); - }; - m.def("_log", log_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto log_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, log_output_typeid_vector); - }; - m.def("_log_result_type", log_result_type_pyapi); - } - - // U21: ==== LOG1P (x) - { - impl::populate_log1p_dispatch_vectors(); - using impl::log1p_contig_dispatch_vector; - using impl::log1p_output_typeid_vector; - using impl::log1p_strided_dispatch_vector; - - auto log1p_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, log1p_output_typeid_vector, - log1p_contig_dispatch_vector, log1p_strided_dispatch_vector); - }; - m.def("_log1p", log1p_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto log1p_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - log1p_output_typeid_vector); - }; - m.def("_log1p_result_type", log1p_result_type_pyapi); - } - - // U22: ==== LOG2 (x) - { - impl::populate_log2_dispatch_vectors(); - - using impl::log2_contig_dispatch_vector; - using impl::log2_output_typeid_vector; - using impl::log2_strided_dispatch_vector; - auto log2_pyapi = [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, log2_output_typeid_vector, - log2_contig_dispatch_vector, log2_strided_dispatch_vector); - }; - auto log2_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, log2_output_typeid_vector); - }; - m.def("_log2", log2_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - m.def("_log2_result_type", log2_result_type_pyapi, ""); - } - - // U23: ==== LOG10 (x) - { - impl::populate_log10_dispatch_vectors(); - - using impl::log10_contig_dispatch_vector; - using impl::log10_output_typeid_vector; - using impl::log10_strided_dispatch_vector; - 
auto log10_pyapi = [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, log10_output_typeid_vector, - log10_contig_dispatch_vector, log10_strided_dispatch_vector); - }; - auto log10_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - log10_output_typeid_vector); - }; - m.def("_log10", log10_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - m.def("_log10_result_type", log10_result_type_pyapi, ""); - } - - // B15: ==== LOGADDEXP (x1, x2) - { - impl::populate_logaddexp_dispatch_tables(); - using impl::logaddexp_contig_dispatch_table; - using impl::logaddexp_output_id_table; - using impl::logaddexp_strided_dispatch_table; - - auto logaddexp_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, logaddexp_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - logaddexp_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - logaddexp_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto logaddexp_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - logaddexp_output_id_table); - }; - m.def("_logaddexp", logaddexp_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_logaddexp_result_type", logaddexp_result_type_pyapi, ""); - } - - // B16: ==== LOGICAL_AND (x1, x2) - { - impl::populate_logical_and_dispatch_tables(); - using impl::logical_and_contig_dispatch_table; - using impl::logical_and_output_id_table; - using impl::logical_and_strided_dispatch_table; - - auto logical_and_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, logical_and_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - logical_and_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - logical_and_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto logical_and_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) 
{ - return py_binary_ufunc_result_type(dtype1, dtype2, - logical_and_output_id_table); - }; - m.def("_logical_and", logical_and_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_logical_and_result_type", logical_and_result_type_pyapi, ""); - } - - // U24: ==== LOGICAL_NOT (x) - { - impl::populate_logical_not_dispatch_vectors(); - using impl::logical_not_contig_dispatch_vector; - using impl::logical_not_output_typeid_vector; - using impl::logical_not_strided_dispatch_vector; - - auto logical_not_pyapi = [&](const arrayT &src, arrayT dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc(src, dst, exec_q, depends, - logical_not_output_typeid_vector, - logical_not_contig_dispatch_vector, - logical_not_strided_dispatch_vector); - }; - m.def("_logical_not", logical_not_pyapi, "", py::arg("src"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - - auto logical_not_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - logical_not_output_typeid_vector); - }; - m.def("_logical_not_result_type", logical_not_result_type_pyapi); - } - - // B17: ==== LOGICAL_OR (x1, x2) - { - impl::populate_logical_or_dispatch_tables(); - using impl::logical_or_contig_dispatch_table; - using impl::logical_or_output_id_table; - using impl::logical_or_strided_dispatch_table; - - auto logical_or_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, logical_or_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - logical_or_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - logical_or_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto logical_or_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - logical_or_output_id_table); - }; - m.def("_logical_or", logical_or_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_logical_or_result_type", logical_or_result_type_pyapi, ""); - } - - // B18: ==== LOGICAL_XOR (x1, x2) - { - impl::populate_logical_xor_dispatch_tables(); - using impl::logical_xor_contig_dispatch_table; - using impl::logical_xor_output_id_table; - using impl::logical_xor_strided_dispatch_table; - - auto logical_xor_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, logical_xor_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - logical_xor_contig_dispatch_table, - // function pointers to handle operation 
on strided arrays (most - // general case) - logical_xor_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto logical_xor_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - logical_xor_output_id_table); - }; - m.def("_logical_xor", logical_xor_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_logical_xor_result_type", logical_xor_result_type_pyapi, ""); - } - - // B??: ==== MAXIMUM (x1, x2) - { - impl::populate_maximum_dispatch_tables(); - using impl::maximum_contig_dispatch_table; - using impl::maximum_output_id_table; - using impl::maximum_strided_dispatch_table; - - auto maximum_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, maximum_output_id_table, - // function pointers to handle operation on contiguous - // arrays (pointers may be nullptr) - maximum_contig_dispatch_table, - // function pointers to handle operation on strided arrays - // (most general case) - maximum_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix - // and c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix - // and c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto maximum_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - maximum_output_id_table); - }; - m.def("_maximum", maximum_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_maximum_result_type", maximum_result_type_pyapi, ""); - } - - // B??: ==== MINIMUM (x1, x2) - { - impl::populate_minimum_dispatch_tables(); - using impl::minimum_contig_dispatch_table; - using impl::minimum_output_id_table; - using impl::minimum_strided_dispatch_table; - - auto minimum_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, minimum_output_id_table, - // function pointers to handle operation on contiguous - // arrays (pointers may be nullptr) - minimum_contig_dispatch_table, - // function pointers to handle operation on strided arrays - // (most general case) - minimum_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix - // and c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting 
(may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto minimum_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - minimum_output_id_table); - }; - m.def("_minimum", minimum_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_minimum_result_type", minimum_result_type_pyapi, ""); - } - - // B19: ==== MULTIPLY (x1, x2) - { - impl::populate_multiply_dispatch_tables(); - using impl::multiply_contig_dispatch_table; - using impl::multiply_contig_matrix_contig_row_broadcast_dispatch_table; - using impl::multiply_contig_row_contig_matrix_broadcast_dispatch_table; - using impl::multiply_output_id_table; - using impl::multiply_strided_dispatch_table; - - auto multiply_pyapi = - [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, multiply_output_id_table, - // function pointers to handle operation on contiguous - // arrays (pointers may be nullptr) - multiply_contig_dispatch_table, - // function pointers to handle operation on strided arrays - // (most general case) - multiply_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix - // and c-contig row with broadcasting (may be nullptr) - multiply_contig_matrix_contig_row_broadcast_dispatch_table, - // function pointers to handle operation of c-contig matrix - // and c-contig row with broadcasting (may be nullptr) - multiply_contig_row_contig_matrix_broadcast_dispatch_table); - }; - auto multiply_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - multiply_output_id_table); - }; - m.def("_multiply", multiply_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_multiply_result_type", multiply_result_type_pyapi, ""); - - using impl::multiply_inplace_contig_dispatch_table; - using impl::multiply_inplace_row_matrix_dispatch_table; - using impl::multiply_inplace_strided_dispatch_table; - - auto multiply_inplace_pyapi = - [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_inplace_ufunc( - src, dst, exec_q, depends, multiply_output_id_table, - // function pointers to handle inplace operation on - // contiguous arrays (pointers may be nullptr) - multiply_inplace_contig_dispatch_table, - // function pointers to handle inplace operation on strided - // arrays (most general case) - multiply_inplace_strided_dispatch_table, - // function pointers to handle inplace operation on - // c-contig matrix with c-contig row with broadcasting - // (may be nullptr) - multiply_inplace_row_matrix_dispatch_table); - }; - m.def("_multiply_inplace", multiply_inplace_pyapi, "", py::arg("lhs"), - py::arg("rhs"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - } - - // U25: ==== NEGATIVE (x) - { - impl::populate_negative_dispatch_vectors(); - using impl::negative_contig_dispatch_vector; - using impl::negative_output_typeid_vector; - using impl::negative_strided_dispatch_vector; - - auto negative_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue 
&exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc(src, dst, exec_q, depends, - negative_output_typeid_vector, - negative_contig_dispatch_vector, - negative_strided_dispatch_vector); - }; - m.def("_negative", negative_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto negative_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - negative_output_typeid_vector); - }; - m.def("_negative_result_type", negative_result_type_pyapi); - } - - // B20: ==== NOT_EQUAL (x1, x2) - { - impl::populate_not_equal_dispatch_tables(); - using impl::not_equal_contig_dispatch_table; - using impl::not_equal_output_id_table; - using impl::not_equal_strided_dispatch_table; - - auto not_equal_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, not_equal_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - not_equal_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - not_equal_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto not_equal_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - not_equal_output_id_table); - }; - m.def("_not_equal", not_equal_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_not_equal_result_type", not_equal_result_type_pyapi, ""); - } - - // U26: ==== POSITIVE (x) - { - impl::populate_positive_dispatch_vectors(); - using impl::positive_contig_dispatch_vector; - using impl::positive_output_typeid_vector; - using impl::positive_strided_dispatch_vector; - - auto positive_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc(src, dst, exec_q, depends, - positive_output_typeid_vector, - positive_contig_dispatch_vector, - positive_strided_dispatch_vector); - }; - m.def("_positive", positive_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto positive_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - positive_output_typeid_vector); - }; - m.def("_positive_result_type", positive_result_type_pyapi); - } - - // B21: ==== POW (x1, x2) - { - impl::populate_pow_dispatch_tables(); - using impl::pow_contig_dispatch_table; - using impl::pow_output_id_table; - using impl::pow_strided_dispatch_table; - - auto pow_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, pow_output_id_table, - // function pointers to handle operation on contiguous 
arrays - // (pointers may be nullptr) - pow_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - pow_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto pow_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - pow_output_id_table); - }; - m.def("_pow", pow_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_pow_result_type", pow_result_type_pyapi, ""); - } - - // U??: ==== PROJ (x) - { - impl::populate_proj_dispatch_vectors(); - using impl::proj_contig_dispatch_vector; - using impl::proj_output_typeid_vector; - using impl::proj_strided_dispatch_vector; - - auto proj_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, proj_output_typeid_vector, - proj_contig_dispatch_vector, proj_strided_dispatch_vector); - }; - m.def("_proj", proj_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto proj_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, proj_output_typeid_vector); - }; - m.def("_proj_result_type", proj_result_type_pyapi); - } - - // U27: ==== REAL (x) - { - impl::populate_real_dispatch_vectors(); - using impl::real_contig_dispatch_vector; - using impl::real_output_typeid_vector; - using impl::real_strided_dispatch_vector; - - auto real_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, real_output_typeid_vector, - real_contig_dispatch_vector, real_strided_dispatch_vector); - }; - m.def("_real", real_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto real_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, real_output_typeid_vector); - }; - m.def("_real_result_type", real_result_type_pyapi); - } - - // B22: ==== REMAINDER (x1, x2) - { - impl::populate_remainder_dispatch_tables(); - using impl::remainder_contig_dispatch_table; - using impl::remainder_output_id_table; - using impl::remainder_strided_dispatch_table; - - auto remainder_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, remainder_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - remainder_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - remainder_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function 
pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto remainder_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - remainder_output_id_table); - }; - m.def("_remainder", remainder_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_remainder_result_type", remainder_result_type_pyapi, ""); - } - - // U28: ==== ROUND (x) - { - impl::populate_round_dispatch_vectors(); - using impl::round_contig_dispatch_vector; - using impl::round_output_typeid_vector; - using impl::round_strided_dispatch_vector; - - auto round_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, round_output_typeid_vector, - round_contig_dispatch_vector, round_strided_dispatch_vector); - }; - m.def("_round", round_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto round_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - round_output_typeid_vector); - }; - m.def("_round_result_type", round_result_type_pyapi); - } - - // U29: ==== SIGN (x) - { - impl::populate_sign_dispatch_vectors(); - using impl::sign_contig_dispatch_vector; - using impl::sign_output_typeid_vector; - using impl::sign_strided_dispatch_vector; - - auto sign_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, sign_output_typeid_vector, - sign_contig_dispatch_vector, sign_strided_dispatch_vector); - }; - m.def("_sign", sign_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto sign_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, sign_output_typeid_vector); - }; - m.def("_sign_result_type", sign_result_type_pyapi); - } - - // ==== SIGNBIT (x) - { - impl::populate_signbit_dispatch_vectors(); - using impl::signbit_contig_dispatch_vector; - using impl::signbit_output_typeid_vector; - using impl::signbit_strided_dispatch_vector; - - auto signbit_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc(src, dst, exec_q, depends, - signbit_output_typeid_vector, - signbit_contig_dispatch_vector, - signbit_strided_dispatch_vector); - }; - m.def("_signbit", signbit_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto signbit_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - signbit_output_typeid_vector); - }; - m.def("_signbit_result_type", signbit_result_type_pyapi); - } - - // U30: ==== SIN (x) - { - impl::populate_sin_dispatch_vectors(); - using impl::sin_contig_dispatch_vector; - using impl::sin_output_typeid_vector; - using impl::sin_strided_dispatch_vector; - - auto sin_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, sin_output_typeid_vector, - sin_contig_dispatch_vector, sin_strided_dispatch_vector); - }; - m.def("_sin", sin_pyapi, "", py::arg("src"), 
py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto sin_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, sin_output_typeid_vector); - }; - m.def("_sin_result_type", sin_result_type_pyapi); - } - // U31: ==== SINH (x) - { - impl::populate_sinh_dispatch_vectors(); - using impl::sinh_contig_dispatch_vector; - using impl::sinh_output_typeid_vector; - using impl::sinh_strided_dispatch_vector; - - auto sinh_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, sinh_output_typeid_vector, - sinh_contig_dispatch_vector, sinh_strided_dispatch_vector); - }; - m.def("_sinh", sinh_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto sinh_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, sinh_output_typeid_vector); - }; - m.def("_sinh_result_type", sinh_result_type_pyapi); - } - - // U32: ==== SQUARE (x) - { - impl::populate_square_dispatch_vectors(); - using impl::square_contig_dispatch_vector; - using impl::square_output_typeid_vector; - using impl::square_strided_dispatch_vector; - - auto square_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, square_output_typeid_vector, - square_contig_dispatch_vector, square_strided_dispatch_vector); - }; - m.def("_square", square_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto square_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - square_output_typeid_vector); - }; - m.def("_square_result_type", square_result_type_pyapi); - } - - // U33: ==== SQRT (x) - { - impl::populate_sqrt_dispatch_vectors(); - using impl::sqrt_contig_dispatch_vector; - using impl::sqrt_output_typeid_vector; - using impl::sqrt_strided_dispatch_vector; - - auto sqrt_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, sqrt_output_typeid_vector, - sqrt_contig_dispatch_vector, sqrt_strided_dispatch_vector); - }; - m.def("_sqrt", sqrt_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto sqrt_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, sqrt_output_typeid_vector); - }; - m.def("_sqrt_result_type", sqrt_result_type_pyapi); - } - - // B23: ==== SUBTRACT (x1, x2) - { - impl::populate_subtract_dispatch_tables(); - using impl::subtract_contig_dispatch_table; - using impl::subtract_contig_matrix_contig_row_broadcast_dispatch_table; - using impl::subtract_contig_row_contig_matrix_broadcast_dispatch_table; - using impl::subtract_output_id_table; - using impl::subtract_strided_dispatch_table; - - auto subtract_pyapi = - [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, subtract_output_id_table, - // function pointers to handle operation on contiguous - // arrays (pointers may be nullptr) - subtract_contig_dispatch_table, - // function pointers to handle operation on strided arrays - // (most general 
case) - subtract_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix - // and c-contig row with broadcasting (may be nullptr) - subtract_contig_matrix_contig_row_broadcast_dispatch_table, - // function pointers to handle operation of c-contig matrix - // and c-contig row with broadcasting (may be nullptr) - subtract_contig_row_contig_matrix_broadcast_dispatch_table); - }; - auto subtract_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - subtract_output_id_table); - }; - m.def("_subtract", subtract_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_subtract_result_type", subtract_result_type_pyapi, ""); - - using impl::subtract_inplace_contig_dispatch_table; - using impl::subtract_inplace_row_matrix_dispatch_table; - using impl::subtract_inplace_strided_dispatch_table; - - auto subtract_inplace_pyapi = - [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_inplace_ufunc( - src, dst, exec_q, depends, subtract_output_id_table, - // function pointers to handle inplace operation on - // contiguous arrays (pointers may be nullptr) - subtract_inplace_contig_dispatch_table, - // function pointers to handle inplace operation on strided - // arrays (most general case) - subtract_inplace_strided_dispatch_table, - // function pointers to handle inplace operation on - // c-contig matrix with c-contig row with broadcasting - // (may be nullptr) - subtract_inplace_row_matrix_dispatch_table); - }; - m.def("_subtract_inplace", subtract_inplace_pyapi, "", py::arg("lhs"), - py::arg("rhs"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - } - - // U34: ==== TAN (x) - { - impl::populate_tan_dispatch_vectors(); - using impl::tan_contig_dispatch_vector; - using impl::tan_output_typeid_vector; - using impl::tan_strided_dispatch_vector; - - auto tan_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, tan_output_typeid_vector, - tan_contig_dispatch_vector, tan_strided_dispatch_vector); - }; - m.def("_tan", tan_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto tan_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, tan_output_typeid_vector); - }; - m.def("_tan_result_type", tan_result_type_pyapi); - } - - // U35: ==== TANH (x) - { - impl::populate_tanh_dispatch_vectors(); - using impl::tanh_contig_dispatch_vector; - using impl::tanh_output_typeid_vector; - using impl::tanh_strided_dispatch_vector; - - auto tanh_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, tanh_output_typeid_vector, - tanh_contig_dispatch_vector, tanh_strided_dispatch_vector); - }; - m.def("_tanh", tanh_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto tanh_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, tanh_output_typeid_vector); - }; - m.def("_tanh_result_type", tanh_result_type_pyapi); - } - - // U36: ==== TRUNC (x) - { - impl::populate_trunc_dispatch_vectors(); - using impl::trunc_contig_dispatch_vector; 
- using impl::trunc_output_typeid_vector; - using impl::trunc_strided_dispatch_vector; - - auto trunc_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, trunc_output_typeid_vector, - trunc_contig_dispatch_vector, trunc_strided_dispatch_vector); - }; - m.def("_trunc", trunc_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto trunc_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - trunc_output_typeid_vector); - }; - m.def("_trunc_result_type", trunc_result_type_pyapi); - } - - // B24: ==== HYPOT (x1, x2) - { - impl::populate_hypot_dispatch_tables(); - using impl::hypot_contig_dispatch_table; - using impl::hypot_output_id_table; - using impl::hypot_strided_dispatch_table; - - auto hypot_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, hypot_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - hypot_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - hypot_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto hypot_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - hypot_output_id_table); - }; - m.def("_hypot", hypot_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_hypot_result_type", hypot_result_type_pyapi, ""); - } - - // U37: ==== CBRT (x) - { - impl::populate_cbrt_dispatch_vectors(); - using impl::cbrt_contig_dispatch_vector; - using impl::cbrt_output_typeid_vector; - using impl::cbrt_strided_dispatch_vector; - - auto cbrt_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, cbrt_output_typeid_vector, - cbrt_contig_dispatch_vector, cbrt_strided_dispatch_vector); - }; - m.def("_cbrt", cbrt_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto cbrt_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, cbrt_output_typeid_vector); - }; - m.def("_cbrt_result_type", cbrt_result_type_pyapi); - } - - // B25: ==== COPYSIGN (x1, x2) - { - impl::populate_copysign_dispatch_tables(); - using impl::copysign_contig_dispatch_table; - using impl::copysign_output_id_table; - using impl::copysign_strided_dispatch_table; - - auto copysign_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, copysign_output_id_table, - // 
function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - copysign_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - copysign_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto copysign_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - copysign_output_id_table); - }; - m.def("_copysign", copysign_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_copysign_result_type", copysign_result_type_pyapi, ""); - } - - // U38: ==== EXP2 (x) - { - impl::populate_exp2_dispatch_vectors(); - using impl::exp2_contig_dispatch_vector; - using impl::exp2_output_typeid_vector; - using impl::exp2_strided_dispatch_vector; - - auto exp2_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, exp2_output_typeid_vector, - exp2_contig_dispatch_vector, exp2_strided_dispatch_vector); - }; - m.def("_exp2", exp2_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto exp2_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, exp2_output_typeid_vector); - }; - m.def("_exp2_result_type", exp2_result_type_pyapi); - } - - // U39: ==== RSQRT (x) - { - impl::populate_rsqrt_dispatch_vectors(); - using impl::rsqrt_contig_dispatch_vector; - using impl::rsqrt_output_typeid_vector; - using impl::rsqrt_strided_dispatch_vector; - - auto rsqrt_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, rsqrt_output_typeid_vector, - rsqrt_contig_dispatch_vector, rsqrt_strided_dispatch_vector); - }; - m.def("_rsqrt", rsqrt_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto rsqrt_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - rsqrt_output_typeid_vector); - }; - m.def("_rsqrt_result_type", rsqrt_result_type_pyapi); - } -} - -} // namespace py_internal -} // namespace tensor -} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/linalg_functions/dot.cpp b/dpctl/tensor/libtensor/source/linalg_functions/dot.cpp new file mode 100644 index 0000000000..926f5ffad6 --- /dev/null +++ b/dpctl/tensor/libtensor/source/linalg_functions/dot.cpp @@ -0,0 +1,857 @@ +#include "dpctl4pybind11.hpp" +#include +#include +#include +#include +#include +#include +#include + +#include "dot.hpp" +#include "dot_atomic_support.hpp" +#include "dot_dispatch.hpp" +#include "elementwise_functions/elementwise_functions_type_utils.hpp" +#include "kernels/linalg_functions/dot_product.hpp" +#include "kernels/linalg_functions/gemm.hpp" +#include "reductions/reduction_atomic_support.hpp" +#include "simplify_iteration_space.hpp" +#include "utils/memory_overlap.hpp" +#include "utils/offset_utils.hpp" + +namespace dpctl +{ 
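For orientation: the new `dot.cpp` implements `py_dot`, which `init_dot` (near the end of this file) registers as `_dot(x1, x2, batch_dims, x1_outer_dims, x2_outer_dims, inner_dims, dst, sycl_queue, depends)`. Its validation expects `x1` laid out as batch axes, then `x1`'s outer axes, then the contracted inner axes, while `x2` is laid out as batch axes, inner axes, then `x2`'s outer axes, and requires the destination rank to equal `batch_dims + x1_outer_dims + x2_outer_dims`. The sketch below is a pure-Python illustration of that bookkeeping, not part of the patch; the helper name is hypothetical.

def expected_dot_result_shape(
    x1_shape, x2_shape, batch_dims, x1_outer_dims, x2_outer_dims, inner_dims
):
    # x1 is laid out as (batch..., x1_outer..., inner...),
    # x2 as (batch..., inner..., x2_outer...)
    assert len(x1_shape) == batch_dims + x1_outer_dims + inner_dims
    assert len(x2_shape) == batch_dims + x2_outer_dims + inner_dims
    # batch axes must agree
    assert x1_shape[:batch_dims] == x2_shape[:batch_dims]
    # contracted axes must agree
    assert (
        x1_shape[batch_dims + x1_outer_dims :]
        == x2_shape[batch_dims : batch_dims + inner_dims]
    )
    # result: batch axes, then x1's outer axes, then x2's outer axes
    return (
        x1_shape[:batch_dims]
        + x1_shape[batch_dims : batch_dims + x1_outer_dims]
        + x2_shape[batch_dims + inner_dims :]
    )

# a batched matmul (b, n, k) @ (b, k, m) maps to batch_dims=1,
# x1_outer_dims=1, x2_outer_dims=1, inner_dims=1 (the case the batched
# branch below currently asserts) and yields (b, n, m)
assert expected_dot_result_shape((3, 4, 5), (3, 5, 7), 1, 1, 1, 1) == (3, 4, 7)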
+namespace tensor +{ +namespace py_internal +{ + +namespace td_ns = dpctl::tensor::type_dispatch; + +static int dot_output_id_table[td_ns::num_types][td_ns::num_types]; + +using dpctl::tensor::kernels::dot_product_impl_fn_ptr_t; +static dot_product_impl_fn_ptr_t dot_product_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +static dot_product_impl_fn_ptr_t + dot_product_temps_dispatch_table[td_ns::num_types][td_ns::num_types]; + +using dpctl::tensor::kernels::dot_product_contig_impl_fn_ptr_t; +static dot_product_contig_impl_fn_ptr_t + dot_product_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; + +static dot_product_contig_impl_fn_ptr_t + dot_product_contig_temps_dispatch_table[td_ns::num_types][td_ns::num_types]; + +using dpctl::tensor::kernels::gemm_impl_fn_ptr_t; +static gemm_impl_fn_ptr_t gemm_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +static gemm_impl_fn_ptr_t gemm_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::gemm_contig_impl_fn_ptr_t; +static gemm_contig_impl_fn_ptr_t + gemm_contig_atomic_dispatch_table[td_ns::num_types][td_ns::num_types]; + +static gemm_contig_impl_fn_ptr_t + gemm_contig_temps_dispatch_table[td_ns::num_types][td_ns::num_types]; + +using dpctl::tensor::kernels::gemm_batch_impl_fn_ptr_t; +static gemm_batch_impl_fn_ptr_t + gemm_batch_atomic_dispatch_table[td_ns::num_types][td_ns::num_types]; + +static gemm_batch_impl_fn_ptr_t + gemm_batch_temps_dispatch_table[td_ns::num_types][td_ns::num_types]; + +using dpctl::tensor::kernels::gemm_batch_contig_impl_fn_ptr_t; +static gemm_batch_contig_impl_fn_ptr_t + gemm_batch_contig_atomic_dispatch_table[td_ns::num_types][td_ns::num_types]; + +static gemm_batch_contig_impl_fn_ptr_t + gemm_batch_contig_temps_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void init_dot_dispatch_tables(void) +{ + using dpctl::tensor::py_internal::DotTypeMapFactory; + td_ns::DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(dot_output_id_table); + + using dpctl::tensor::py_internal::GemmBatchAtomicFactory; + td_ns::DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(gemm_batch_atomic_dispatch_table); + + using dpctl::tensor::py_internal::GemmBatchContigAtomicFactory; + td_ns::DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(gemm_batch_contig_atomic_dispatch_table); + + using dpctl::tensor::py_internal::GemmAtomicFactory; + td_ns::DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(gemm_atomic_dispatch_table); + + using dpctl::tensor::py_internal::GemmContigAtomicFactory; + td_ns::DispatchTableBuilder + dtb5; + dtb5.populate_dispatch_table(gemm_contig_atomic_dispatch_table); + + using dpctl::tensor::py_internal::GemmBatchTempsFactory; + td_ns::DispatchTableBuilder + dtb6; + dtb6.populate_dispatch_table(gemm_batch_temps_dispatch_table); + + using dpctl::tensor::py_internal::GemmBatchContigTempsFactory; + td_ns::DispatchTableBuilder + dtb7; + dtb7.populate_dispatch_table(gemm_batch_contig_temps_dispatch_table); + + using dpctl::tensor::py_internal::GemmTempsFactory; + td_ns::DispatchTableBuilder + dtb8; + dtb8.populate_dispatch_table(gemm_temps_dispatch_table); + + using dpctl::tensor::py_internal::GemmContigTempsFactory; + td_ns::DispatchTableBuilder + dtb9; + dtb9.populate_dispatch_table(gemm_contig_temps_dispatch_table); + + using dpctl::tensor::py_internal::DotProductAtomicFactory; + td_ns::DispatchTableBuilder + dtb10; + dtb10.populate_dispatch_table(dot_product_dispatch_table); + + using 
dpctl::tensor::py_internal::DotProductNoAtomicFactory; + td_ns::DispatchTableBuilder + dtb11; + dtb11.populate_dispatch_table(dot_product_temps_dispatch_table); + + using dpctl::tensor::py_internal::DotProductContigAtomicFactory; + td_ns::DispatchTableBuilder + dtb12; + dtb12.populate_dispatch_table(dot_product_contig_dispatch_table); + + using dpctl::tensor::py_internal::DotProductContigNoAtomicFactory; + td_ns::DispatchTableBuilder + dtb13; + dtb13.populate_dispatch_table(dot_product_contig_temps_dispatch_table); +} + +using atomic_support::atomic_support_fn_ptr_t; +static atomic_support_fn_ptr_t dot_atomic_support_vector[td_ns::num_types]; + +void init_dot_atomic_support_vector(void) +{ + + using atomic_support::DotAtomicSupportFactory; + td_ns::DispatchVectorBuilder + dvb; + dvb.populate_dispatch_vector(dot_atomic_support_vector); +} + +std::pair +py_dot(const dpctl::tensor::usm_ndarray &x1, + const dpctl::tensor::usm_ndarray &x2, + int batch_dims, + int x1_outer_dims, + int x2_outer_dims, + int inner_dims, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends) +{ + + if (!dst.is_writable()) { + throw py::value_error("Output array is read-only."); + } + + if (inner_dims == 0) { + throw py::value_error("No inner dimension for dot"); + } + + if (!dpctl::utils::queues_are_compatible(exec_q, {x1, x2, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + int x1_nd = x1.get_ndim(); + int x2_nd = x2.get_ndim(); + if (x1_nd != (batch_dims + x1_outer_dims + inner_dims) || + x2_nd != (batch_dims + x2_outer_dims + inner_dims)) + { + throw py::value_error("Input arrays do not have dimensions consistent " + "with input dimensions"); + } + + int dst_nd = dst.get_ndim(); + if (dst_nd != (batch_dims + x1_outer_dims + x2_outer_dims)) { + throw py::value_error("Destination array rank does not match input " + "array rank and number of input dimensions"); + } + + const py::ssize_t *x1_shape_ptr = x1.get_shape_raw(); + const py::ssize_t *x2_shape_ptr = x2.get_shape_raw(); + const py::ssize_t *dst_shape_ptr = dst.get_shape_raw(); + + bool same_shapes = true; + size_t batches(1); + for (int i = 0; same_shapes && (i < batch_dims); ++i) { + same_shapes = same_shapes && (x1_shape_ptr[i] == dst_shape_ptr[i]) && + (x2_shape_ptr[i] == dst_shape_ptr[i]); + batches *= x1_shape_ptr[i]; + } + size_t x1_outer_nelems(1); + for (int i = batch_dims; same_shapes && (i < (batch_dims + x1_outer_dims)); + ++i) { + same_shapes = same_shapes && (x1_shape_ptr[i] == dst_shape_ptr[i]); + x1_outer_nelems *= x1_shape_ptr[i]; + } + size_t inner_nelems(1); + for (int i = batch_dims; i < (batch_dims + inner_dims); ++i) { + auto x1_shape_idx = x1_outer_dims + i; + same_shapes = + same_shapes && (x1_shape_ptr[x1_shape_idx] == x2_shape_ptr[i]); + inner_nelems *= x1_shape_ptr[x1_shape_idx]; + } + size_t x2_outer_nelems(1); + for (int i = 0; same_shapes && (i < x2_outer_dims); ++i) { + auto x2_shape_idx = batch_dims + inner_dims + i; + same_shapes = + same_shapes && (x2_shape_ptr[x2_shape_idx] == + dst_shape_ptr[batch_dims + x1_outer_dims + i]); + x2_outer_nelems *= x2_shape_ptr[x2_shape_idx]; + } + if (!same_shapes) { + throw py::value_error("Input arrays to tensor dot product do not have " + "appropriate shapes"); + } + + size_t dst_nelems = batches * x1_outer_nelems * x2_outer_nelems; + if (dst_nelems == 0) { + return std::make_pair(sycl::event(), sycl::event()); + } + + if (static_cast(dst.get_size()) != dst_nelems) { + throw py::value_error("dst 
shape and size mismatch"); + } + + // ensure that dst is sufficiently ample + auto dst_offsets = dst.get_minmax_offsets(); + // destination must be ample enough to accommodate all elements + { + size_t range = + static_cast(dst_offsets.second - dst_offsets.first); + if (range + 1 < dst_nelems) { + throw py::value_error( + "Memory addressed by the destination array can not " + "accommodate all the " + "array elements."); + } + } + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + // check that dst does not intersect with x1 or x2 + if (overlap(dst, x1) || overlap(dst, x2)) { + throw py::value_error("Result array overlaps with inputs"); + } + + int x1_typenum = x1.get_typenum(); + int x2_typenum = x2.get_typenum(); + int dst_typenum = dst.get_typenum(); + + auto const &array_types = td_ns::usm_ndarray_types(); + int x1_typeid = array_types.typenum_to_lookup_id(x1_typenum); + int x2_typeid = array_types.typenum_to_lookup_id(x2_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + int output_typeid = dot_output_id_table[x1_typeid][x2_typeid]; + + if (output_typeid != dst_typeid) { + throw py::value_error( + "Result array has unexpected elemental data type."); + } + + void *data_ptr = dst.get_data(); + const auto &ctx = exec_q.get_context(); + auto usm_type = sycl::get_pointer_type(data_ptr, ctx); + bool supports_atomics = + dot_atomic_support_vector[output_typeid](exec_q, usm_type); + + const char *x1_data = x1.get_data(); + const char *x2_data = x2.get_data(); + char *dst_data = dst.get_data(); + + auto x1_shape_vec = x1.get_shape_vector(); + auto x1_strides_vec = x1.get_strides_vector(); + + auto x2_shape_vec = x2.get_shape_vector(); + auto x2_strides_vec = x2.get_strides_vector(); + + auto dst_shape_vec = dst.get_shape_vector(); + auto dst_strides_vec = dst.get_strides_vector(); + + bool is_x1_c_contig = x1.is_c_contiguous(); + bool is_x1_f_contig = x1.is_f_contiguous(); + bool is_x2_c_contig = x2.is_c_contiguous(); + bool is_x2_f_contig = x2.is_f_contiguous(); + bool is_dst_c_contig = dst.is_c_contiguous(); + + bool call_vecdot = ((x1_outer_dims == 0 && x1_outer_nelems == 1) && + (x2_outer_dims == 0 && x2_outer_nelems == 1)); + + bool call_batched = (batch_dims != 0 || batches > 1); + std::vector host_task_events{}; + sycl::event dot_ev; + if (call_vecdot) { + if ((is_x1_c_contig && is_x2_c_contig && is_dst_c_contig) || + ((is_x1_f_contig && is_x2_f_contig) && !call_batched)) + { + dot_product_contig_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = dot_product_contig_dispatch_table[x1_typeid][x2_typeid]; + } + else { + fn = dot_product_contig_temps_dispatch_table[x1_typeid] + [x2_typeid]; + } + if (fn != nullptr) { + constexpr py::ssize_t zero_offset = 0; + dot_ev = fn(exec_q, dst_nelems, inner_nelems, x1.get_data(), + x2.get_data(), dst.get_data(), + zero_offset, // lhs batch offset + zero_offset, // rhs batch offset + zero_offset, // res batch offset + zero_offset, // lhs reduction offset + zero_offset, // rhs reduction offset + depends); + return std::make_pair(dpctl::utils::keep_args_alive( + exec_q, {x1, x2, dst}, {dot_ev}), + dot_ev); + } + } + using dpctl::tensor::py_internal::simplify_iteration_space; + using dpctl::tensor::py_internal::simplify_iteration_space_3; + + int inner_nd = inner_dims; + const py::ssize_t *inner_shape_ptr = x1_shape_ptr + batch_dims; + using shT = std::vector; + shT inner_x1_strides(std::begin(x1_strides_vec) + batch_dims, + std::end(x1_strides_vec)); + shT inner_x2_strides(std::begin(x2_strides_vec) + 
batch_dims, + std::end(x2_strides_vec)); + + shT simplified_inner_shape; + shT simplified_inner_x1_strides; + shT simplified_inner_x2_strides; + py::ssize_t inner_x1_offset(0); + py::ssize_t inner_x2_offset(0); + + simplify_iteration_space( + inner_nd, inner_shape_ptr, inner_x1_strides, inner_x2_strides, + // output + simplified_inner_shape, simplified_inner_x1_strides, + simplified_inner_x2_strides, inner_x1_offset, inner_x2_offset); + + const py::ssize_t *batch_shape_ptr = x1_shape_ptr; + + shT batch_x1_strides(std::begin(x1_strides_vec), + std::begin(x1_strides_vec) + batch_dims); + shT batch_x2_strides(std::begin(x2_strides_vec), + std::begin(x2_strides_vec) + batch_dims); + shT const &batch_dst_strides = dst_strides_vec; + + shT simplified_batch_shape; + shT simplified_batch_x1_strides; + shT simplified_batch_x2_strides; + shT simplified_batch_dst_strides; + py::ssize_t batch_x1_offset(0); + py::ssize_t batch_x2_offset(0); + py::ssize_t batch_dst_offset(0); + + if (batch_dims == 0) { + if (dst_nelems != 1) { + throw std::runtime_error( + "batch_dims == 0, but dst_nelems != 1"); + } + batch_dims = 1; + simplified_batch_shape.push_back(1); + simplified_batch_x1_strides.push_back(0); + simplified_batch_x2_strides.push_back(0); + simplified_batch_dst_strides.push_back(0); + } + else { + simplify_iteration_space_3( + batch_dims, batch_shape_ptr, batch_x1_strides, batch_x2_strides, + batch_dst_strides, + // output + simplified_batch_shape, simplified_batch_x1_strides, + simplified_batch_x2_strides, simplified_batch_dst_strides, + batch_x1_offset, batch_x2_offset, batch_dst_offset); + } + + if (inner_nd == 1 && batch_dims == 1) { + bool dot_product_c_contig = false; + bool reduce_all_elems = false; + + if (simplified_inner_x1_strides[0] == 1 && + simplified_inner_x2_strides[0] == 1) { + reduce_all_elems = (simplified_batch_shape[0] == 1); + dot_product_c_contig = + (simplified_batch_dst_strides[0] == 1) && + (static_cast(simplified_batch_x1_strides[0]) == + inner_nelems) && + (static_cast(simplified_batch_x2_strides[0]) == + inner_nelems); + } + + if (dot_product_c_contig || reduce_all_elems) { + dot_product_contig_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = + dot_product_contig_dispatch_table[x1_typeid][x2_typeid]; + } + else { + fn = dot_product_contig_temps_dispatch_table[x1_typeid] + [x2_typeid]; + } + if (fn != nullptr) { + dot_ev = fn(exec_q, dst_nelems, inner_nelems, x1.get_data(), + x2.get_data(), dst.get_data(), + batch_x1_offset, // lhs batch offset + batch_x2_offset, // rhs batch offset + batch_dst_offset, // res batch offset + inner_x1_offset, // lhs reduction offset + inner_x2_offset, // rhs reduction offset + depends); + return std::make_pair(dpctl::utils::keep_args_alive( + exec_q, {x1, x2, dst}, {dot_ev}), + dot_ev); + } + } + } + + dot_product_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = dot_product_dispatch_table[x1_typeid][x2_typeid]; + } + if (fn == nullptr) { + fn = dot_product_temps_dispatch_table[x1_typeid][x2_typeid]; + if (fn == nullptr) { + throw std::runtime_error( + "Implementation is missing for x1_typeid=" + + std::to_string(x1_typeid) + + " and x2_typeid=" + std::to_string(x2_typeid)); + } + } + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + const auto &arrays_metainfo_packing_triple_ = + device_allocate_and_pack( + exec_q, host_task_events, + // iteration metadata + simplified_batch_shape, simplified_batch_x1_strides, + simplified_batch_x2_strides, simplified_batch_dst_strides, + // reduction metadata + 
simplified_inner_shape, simplified_inner_x1_strides, + simplified_inner_x2_strides); + py::ssize_t *temp_allocation_ptr = + std::get<0>(arrays_metainfo_packing_triple_); + if (temp_allocation_ptr == nullptr) { + throw std::runtime_error("Unable to allocate memory on device"); + } + const auto ©_metadata_ev = + std::get<2>(arrays_metainfo_packing_triple_); + + py::ssize_t *iter_shape_and_strides = temp_allocation_ptr; + py::ssize_t *inner_shape_stride = + temp_allocation_ptr + 4 * simplified_batch_shape.size(); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.resize(depends.size()); + std::copy(depends.begin(), depends.end(), all_deps.begin()); + all_deps.push_back(copy_metadata_ev); + + dot_ev = + fn(exec_q, dst_nelems, inner_nelems, x1.get_data(), x2.get_data(), + dst.get_data(), batch_dims, iter_shape_and_strides, + batch_x1_offset, batch_x2_offset, batch_dst_offset, + inner_nd, // number dimensions being reduced + inner_shape_stride, inner_x1_offset, inner_x2_offset, all_deps); + + sycl::event temp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dot_ev); + const auto &ctx = exec_q.get_context(); + cgh.host_task([ctx, temp_allocation_ptr] { + sycl::free(temp_allocation_ptr, ctx); + }); + }); + host_task_events.push_back(temp_cleanup_ev); + } + else { // if (!call_vecdot) + if (!call_batched) { + if ((is_x1_c_contig && is_x2_c_contig && is_dst_c_contig)) { + gemm_contig_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = + gemm_contig_atomic_dispatch_table[x1_typeid][x2_typeid]; + } + else { + fn = gemm_contig_temps_dispatch_table[x1_typeid][x2_typeid]; + } + if (fn != nullptr) { + dot_ev = fn(exec_q, x1_data, x2_data, dst_data, + x1_outer_nelems, // n + inner_nelems, // k + x2_outer_nelems, // m + depends); + return std::make_pair(dpctl::utils::keep_args_alive( + exec_q, {x1, x2, dst}, {dot_ev}), + dot_ev); + } + } + gemm_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = gemm_atomic_dispatch_table[x1_typeid][x2_typeid]; + } + if (fn == nullptr) { + fn = gemm_temps_dispatch_table[x1_typeid][x2_typeid]; + if (fn == nullptr) { + throw std::runtime_error( + "Implementation is missing for x1_typeid=" + + std::to_string(x1_typeid) + + " and x2_typeid=" + std::to_string(x2_typeid)); + } + } + using dpctl::tensor::offset_utils::device_allocate_and_pack; + const auto &ptr_size_event_tuple1 = + device_allocate_and_pack( + exec_q, host_task_events, x1_shape_vec, x1_strides_vec, + x2_shape_vec, x2_strides_vec, dst_shape_vec, + dst_strides_vec); + py::ssize_t *packed_shapes_strides = + std::get<0>(ptr_size_event_tuple1); + if (packed_shapes_strides == nullptr) { + throw std::runtime_error("Unable to allocate device memory"); + } + sycl::event copy_shapes_strides_ev = + std::get<2>(ptr_size_event_tuple1); + py::ssize_t *x1_shape_strides = packed_shapes_strides; + py::ssize_t *x2_shape_strides = packed_shapes_strides + 2 * (x1_nd); + py::ssize_t *dst_shape_strides = + packed_shapes_strides + 2 * (x1_nd + x2_nd); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.insert(all_deps.end(), depends.begin(), depends.end()); + all_deps.push_back(copy_shapes_strides_ev); + + // change gemm calls to pass inner dims and outer dims separately + dot_ev = + fn(exec_q, x1_data, x2_data, dst_data, x1_outer_nelems, + inner_nelems, x2_outer_nelems, inner_dims, x1_outer_dims, + x1_shape_strides, x2_outer_dims, x2_shape_strides, + x1_outer_dims + x2_outer_dims, dst_shape_strides, all_deps); + + sycl::event 
cleanup_tmp_allocations_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dot_ev); + const auto &ctx = exec_q.get_context(); + cgh.host_task([ctx, packed_shapes_strides] { + sycl::free(packed_shapes_strides, ctx); + }); + }); + host_task_events.push_back(cleanup_tmp_allocations_ev); + host_task_events.push_back(dot_ev); + } + else { // if (call_batched) + using shT = std::vector; + // temporary asserts for matmul + assert(x1_outer_dims == 1); + assert(x2_outer_dims == 1); + assert(inner_dims == 1); + + if ((is_x1_c_contig && is_x2_c_contig && is_dst_c_contig)) { + gemm_batch_contig_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = gemm_batch_contig_atomic_dispatch_table[x1_typeid] + [x2_typeid]; + } + else { + fn = gemm_batch_contig_temps_dispatch_table[x1_typeid] + [x2_typeid]; + } + if (fn != nullptr) { + constexpr py::ssize_t zero_offset = 0; + dot_ev = fn(exec_q, x1_data, x2_data, dst_data, batches, + x1_outer_nelems, // n + inner_nelems, // k + x2_outer_nelems, // m + zero_offset, zero_offset, zero_offset, depends); + return std::make_pair(dpctl::utils::keep_args_alive( + exec_q, {x1, x2, dst}, {dot_ev}), + dot_ev); + } + } + + auto x1_outer_inner_dims = x1_nd - batch_dims; + auto x2_outer_inner_dims = x2_nd - batch_dims; + auto dst_outer_inner_dims = dst_nd - batch_dims; + + shT batch_x1_shape; + shT outer_inner_x1_shape; + shT batch_x1_strides; + shT outer_inner_x1_strides; + dpctl::tensor::py_internal::split_iteration_space( + x1_shape_vec, x1_strides_vec, batch_dims, + batch_dims + x1_outer_inner_dims, batch_x1_shape, + outer_inner_x1_shape, // 4 vectors modified + batch_x1_strides, outer_inner_x1_strides); + + shT batch_x2_shape; + shT outer_inner_x2_shape; + shT batch_x2_strides; + shT outer_inner_x2_strides; + dpctl::tensor::py_internal::split_iteration_space( + x2_shape_vec, x2_strides_vec, batch_dims, + batch_dims + x2_outer_inner_dims, batch_x2_shape, + outer_inner_x2_shape, // 4 vectors modified + batch_x2_strides, outer_inner_x2_strides); + + shT batch_dst_shape; + shT outer_inner_dst_shape; + shT batch_dst_strides; + shT outer_inner_dst_strides; + dpctl::tensor::py_internal::split_iteration_space( + dst_shape_vec, dst_strides_vec, batch_dims, + batch_dims + dst_outer_inner_dims, batch_dst_shape, + outer_inner_dst_shape, // 4 vectors modified + batch_dst_strides, outer_inner_dst_strides); + + using shT = std::vector; + shT simplified_batch_shape; + shT simplified_batch_x1_strides; + shT simplified_batch_x2_strides; + shT simplified_batch_dst_strides; + py::ssize_t x1_batch_offset(0); + py::ssize_t x2_batch_offset(0); + py::ssize_t dst_batch_offset(0); + + const py::ssize_t *shape = x1_shape_ptr; + + using dpctl::tensor::py_internal::simplify_iteration_space_3; + simplify_iteration_space_3( + batch_dims, shape, batch_x1_strides, batch_x2_strides, + batch_dst_strides, + // outputs + simplified_batch_shape, simplified_batch_x1_strides, + simplified_batch_x2_strides, simplified_batch_dst_strides, + x1_batch_offset, x2_batch_offset, dst_batch_offset); + + if (batch_dims == 1 && x1_outer_dims == 1 && x2_outer_dims == 1 && + inner_dims == 1) + { + bool gemm_batch_c_contig = false; + + if ((static_cast(outer_inner_x1_strides[0]) == + inner_nelems && + outer_inner_x1_strides[1] == 1) && + (static_cast(outer_inner_x2_strides[0]) == + inner_nelems && + outer_inner_x2_strides[1] == 1) && + (static_cast(outer_inner_dst_strides[0]) == + x2_outer_nelems && + outer_inner_dst_strides[1] == 1)) + { + gemm_batch_c_contig = + 
(static_cast(simplified_batch_x1_strides[0]) == + x1_outer_nelems * inner_nelems) && + (static_cast(simplified_batch_x2_strides[0]) == + x2_outer_nelems * inner_nelems) && + (static_cast(simplified_batch_dst_strides[0]) == + x1_outer_nelems * x2_outer_nelems); + } + + if (gemm_batch_c_contig) { + gemm_batch_contig_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = gemm_batch_contig_atomic_dispatch_table[x1_typeid] + [x2_typeid]; + } + else { + fn = gemm_batch_contig_temps_dispatch_table[x1_typeid] + [x2_typeid]; + } + if (fn != nullptr) { + dot_ev = fn(exec_q, x1_data, x2_data, dst_data, batches, + x1_outer_nelems, // n + inner_nelems, // k + x2_outer_nelems, // m + x1_batch_offset, x2_batch_offset, + dst_batch_offset, depends); + return std::make_pair( + dpctl::utils::keep_args_alive(exec_q, {x1, x2, dst}, + {dot_ev}), + dot_ev); + } + } + } + + gemm_batch_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = gemm_batch_atomic_dispatch_table[x1_typeid][x2_typeid]; + } + if (fn == nullptr) { + fn = gemm_batch_temps_dispatch_table[x1_typeid][x2_typeid]; + if (fn == nullptr) { + throw std::runtime_error( + "Implementation is missing for x1_typeid=" + + std::to_string(x1_typeid) + + " and x2_typeid=" + std::to_string(x2_typeid)); + } + } + using dpctl::tensor::offset_utils::device_allocate_and_pack; + const auto &ptr_size_event_tuple1 = + device_allocate_and_pack( + exec_q, host_task_events, simplified_batch_shape, + simplified_batch_x1_strides, simplified_batch_x2_strides, + simplified_batch_dst_strides, outer_inner_x1_shape, + outer_inner_x1_strides, outer_inner_x2_shape, + outer_inner_x2_strides, outer_inner_dst_shape, + outer_inner_dst_strides, + // full shape and strides of the result array + // necessary for reduction and initialization + simplified_batch_shape, outer_inner_dst_shape, + simplified_batch_dst_strides, outer_inner_dst_strides); + py::ssize_t *packed_shapes_strides = + std::get<0>(ptr_size_event_tuple1); + if (packed_shapes_strides == nullptr) { + throw std::runtime_error("Unable to allocate device memory"); + } + sycl::event copy_shapes_strides_ev = + std::get<2>(ptr_size_event_tuple1); + + auto batch_shape_strides = packed_shapes_strides; + auto x1_outer_inner_shapes_strides = + packed_shapes_strides + 4 * batch_dims; + auto x2_outer_inner_shapes_strides = packed_shapes_strides + + 4 * batch_dims + + 2 * (x1_outer_inner_dims); + auto dst_outer_shapes_strides = + packed_shapes_strides + 4 * batch_dims + + 2 * (x1_outer_inner_dims) + 2 * (x2_outer_inner_dims); + auto dst_full_shape_strides = + packed_shapes_strides + 4 * batch_dims + + 2 * (x1_outer_inner_dims) + 2 * (x2_outer_inner_dims) + + 2 * (dst_outer_inner_dims); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.insert(all_deps.end(), depends.begin(), depends.end()); + all_deps.push_back(copy_shapes_strides_ev); + + dot_ev = fn( + exec_q, x1_data, x2_data, dst_data, batches, x1_outer_nelems, + inner_nelems, x2_outer_nelems, batch_dims, batch_shape_strides, + x1_batch_offset, x2_batch_offset, dst_batch_offset, inner_dims, + x1_outer_dims, x1_outer_inner_shapes_strides, x2_outer_dims, + x2_outer_inner_shapes_strides, x1_outer_dims + x2_outer_dims, + dst_outer_shapes_strides, dst_full_shape_strides, all_deps); + + sycl::event cleanup_tmp_allocations_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dot_ev); + const auto &ctx = exec_q.get_context(); + cgh.host_task([ctx, packed_shapes_strides] { + sycl::free(packed_shapes_strides, ctx); + }); + }); + 
host_task_events.push_back(cleanup_tmp_allocations_ev); + host_task_events.push_back(dot_ev); + } + } + return std::make_pair( + dpctl::utils::keep_args_alive(exec_q, {x1, x2, dst}, host_task_events), + dot_ev); +} + +template +py::object py_dot_result_type(const py::dtype &input1_dtype, + const py::dtype &input2_dtype, + const output_typesT &output_types_table) +{ + int tn1 = input1_dtype.num(); // NumPy type numbers are the same as in dpctl + int tn2 = input2_dtype.num(); // NumPy type numbers are the same as in dpctl + int src1_typeid = -1; + int src2_typeid = -1; + + auto array_types = td_ns::usm_ndarray_types(); + + try { + src1_typeid = array_types.typenum_to_lookup_id(tn1); + src2_typeid = array_types.typenum_to_lookup_id(tn2); + } catch (const std::exception &e) { + throw py::value_error(e.what()); + } + + if (src1_typeid < 0 || src1_typeid >= td_ns::num_types || src2_typeid < 0 || + src2_typeid >= td_ns::num_types) + { + throw std::runtime_error("binary output type lookup failed"); + } + int dst_typeid = output_types_table[src1_typeid][src2_typeid]; + + if (dst_typeid < 0) { + auto res = py::none(); + return py::cast(res); + } + else { + using dpctl::tensor::py_internal::type_utils::_dtype_from_typenum; + + auto dst_typenum_t = static_cast(dst_typeid); + auto dt = _dtype_from_typenum(dst_typenum_t); + + return py::cast(dt); + } +} + +void init_dot(py::module_ m) +{ + using dpctl::tensor::py_internal::init_dot_atomic_support_vector; + init_dot_atomic_support_vector(); + using dpctl::tensor::py_internal::init_dot_dispatch_tables; + init_dot_dispatch_tables(); + + using dpctl::tensor::py_internal::py_dot; + m.def("_dot", &py_dot, "", py::arg("x1"), py::arg("x2"), + py::arg("batch_dims"), py::arg("x1_outer_dims"), + py::arg("x2_outer_dims"), py::arg("inner_dims"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + using dpctl::tensor::py_internal::dot_output_id_table; + auto dot_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + using dpctl::tensor::py_internal::py_dot_result_type; + return py_dot_result_type(dtype1, dtype2, dot_output_id_table); + }; + m.def("_dot_result_type", dot_result_type_pyapi, ""); +} + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/linalg_functions/dot.hpp b/dpctl/tensor/libtensor/source/linalg_functions/dot.hpp new file mode 100644 index 0000000000..5f8f6cf494 --- /dev/null +++ b/dpctl/tensor/libtensor/source/linalg_functions/dot.hpp @@ -0,0 +1,17 @@ +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +extern void init_dot(py::module_ m); + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/linalg_functions/dot_atomic_support.hpp b/dpctl/tensor/libtensor/source/linalg_functions/dot_atomic_support.hpp new file mode 100644 index 0000000000..29022342a1 --- /dev/null +++ b/dpctl/tensor/libtensor/source/linalg_functions/dot_atomic_support.hpp @@ -0,0 +1,34 @@ +#pragma once + +#include + +#include "reductions/reduction_atomic_support.hpp" +#include "utils/type_utils.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ +namespace atomic_support +{ + +template struct DotAtomicSupportFactory +{ + fnT get() + { + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + return atomic_support::fixed_decision; + } + else { + return 
atomic_support::check_atomic_support; + } + } +}; + +} // namespace atomic_support +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/linalg_functions/dot_dispatch.hpp b/dpctl/tensor/libtensor/source/linalg_functions/dot_dispatch.hpp new file mode 100644 index 0000000000..de59450174 --- /dev/null +++ b/dpctl/tensor/libtensor/source/linalg_functions/dot_dispatch.hpp @@ -0,0 +1,336 @@ +#pragma once + +#include +#include +#include + +#include "kernels/linalg_functions/dot_product.hpp" +#include "kernels/linalg_functions/gemm.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +template struct DotAtomicOutputType +{ + using value_type = typename std::disjunction< // disjunction is C++17 + // feature, supported by DPC++ + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::DefaultResultEntry>::result_type; +}; + +// add separate type support lists for atomic vs. temps +// gemm, gevm, and dot product share output type struct +template struct DotNoAtomicOutputType +{ + using value_type = typename std::disjunction< // disjunction is C++17 + // feature, supported by DPC++ + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + T2, + std::complex, + std::complex>, + td_ns::BinaryTypeMapResultEntry, + T2, + std::complex, + std::complex>, + td_ns::DefaultResultEntry>::result_type; +}; + +template struct DotTypeMapFactory +{ + /*! 
@brief get typeid for output type of kernels called by py_dot */ + std::enable_if_t::value, int> get() + { + using rT1 = typename DotNoAtomicOutputType::value_type; + using rT2 = typename DotAtomicOutputType::value_type; + static_assert(std::is_same_v || std::is_same_v); + return td_ns::GetTypeid{}.get(); + } +}; + +template struct GemmBatchAtomicFactory +{ + fnT get() + { + using T3 = typename DotAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_batch_impl; + fnT fn = gemm_batch_impl; + return fn; + } + } +}; + +template +struct GemmBatchContigAtomicFactory +{ + fnT get() + { + using T3 = typename DotAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_batch_contig_impl; + fnT fn = gemm_batch_contig_impl; + return fn; + } + } +}; + +template struct GemmAtomicFactory +{ + fnT get() + { + using T3 = typename DotAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_impl; + fnT fn = gemm_impl; + return fn; + } + } +}; + +template struct GemmContigAtomicFactory +{ + fnT get() + { + using T3 = typename DotAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_contig_impl; + fnT fn = gemm_contig_impl; + return fn; + } + } +}; + +template struct GemmTempsFactory +{ + fnT get() + { + using T3 = typename DotNoAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_tree_impl; + fnT fn = gemm_tree_impl; + return fn; + } + } +}; + +template struct GemmContigTempsFactory +{ + fnT get() + { + using T3 = typename DotNoAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_contig_tree_impl; + fnT fn = gemm_contig_tree_impl; + return fn; + } + } +}; + +template struct GemmBatchTempsFactory +{ + fnT get() + { + using T3 = typename DotNoAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_batch_tree_impl; + fnT fn = gemm_batch_tree_impl; + return fn; + } + } +}; + +template +struct GemmBatchContigTempsFactory +{ + fnT get() + { + using T3 = typename DotNoAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_batch_contig_tree_impl; + fnT fn = gemm_batch_contig_tree_impl; + return fn; + } + } +}; + +template struct DotProductAtomicFactory +{ + fnT get() + { + using T3 = typename DotAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::dot_product_impl; + fnT fn = dot_product_impl; + return fn; + } + } +}; + +template +struct DotProductNoAtomicFactory +{ + fnT get() + { + using T3 = typename DotNoAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::dot_product_tree_impl; + fnT fn = dot_product_tree_impl; + return fn; + } + } +}; + +template +struct DotProductContigAtomicFactory +{ + fnT get() + { + using T3 = typename DotAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + 
return fn; + } + else { + using dpctl::tensor::kernels::dot_product_contig_impl; + fnT fn = dot_product_contig_impl; + return fn; + } + } +}; + +template +struct DotProductContigNoAtomicFactory +{ + fnT get() + { + using T3 = typename DotNoAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::dot_product_contig_tree_impl; + fnT fn = dot_product_contig_tree_impl; + return fn; + } + } +}; + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/tensor_linalg.cpp b/dpctl/tensor/libtensor/source/tensor_linalg.cpp new file mode 100644 index 0000000000..82c9893c08 --- /dev/null +++ b/dpctl/tensor/libtensor/source/tensor_linalg.cpp @@ -0,0 +1,34 @@ +//===-- tensor_linalg.cpp ---*-C++-*-/===// +// Implementation of _tensor_linalg_impl module +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===----------------------------------------------------------------------===// + +#include "linalg_functions/dot.hpp" +#include + +namespace py = pybind11; + +PYBIND11_MODULE(_tensor_linalg_impl, m) +{ + dpctl::tensor::py_internal::init_dot(m); +} diff --git a/dpctl/tests/test_usm_ndarray_linalg.py b/dpctl/tests/test_usm_ndarray_linalg.py index 4023eb8ad7..881729136d 100644 --- a/dpctl/tests/test_usm_ndarray_linalg.py +++ b/dpctl/tests/test_usm_ndarray_linalg.py @@ -14,10 +14,31 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import itertools + +import numpy as np import pytest +import dpctl import dpctl.tensor as dpt -from dpctl.tests.helper import get_queue_or_skip +from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported +from dpctl.utils import ExecutionPlacementError + +_numeric_types = [ + "i1", + "u1", + "i2", + "u2", + "i4", + "u4", + "i8", + "u8", + "f2", + "f4", + "f8", + "c8", + "c16", +] def test_matrix_transpose(): @@ -46,3 +67,780 @@ def test_matrix_transpose_arg_validation(): X = dpt.empty((5, 5), dtype="i4") assert isinstance(dpt.matrix_transpose(X), dpt.usm_ndarray) + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_matmul_simple(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + n, m = 235, 17 + m1 = dpt.ones((m, n), dtype=dtype) + m2 = dpt.ones((n, m), dtype=dtype) + + for k in [1, 2, 3, 4, 7, 8, 9, 15, 16, 17]: + r = dpt.matmul(m1[:k, :], m2[:, :k]) + assert dpt.all(r == dpt.full((k, k), n, dtype=dtype)) + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_matmul_nilpotent1(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + n = 77 + N_mat = dpt.eye(n, k=1, dtype=dtype) + I_mat = dpt.eye(n, dtype=dtype) + R_mat = dpt.eye(n, dtype=dtype) + for _ in range(n + 1): + R_mat = I_mat + dpt.matmul(N_mat, R_mat) + + assert dpt.allclose(dpt.matmul(I_mat - N_mat, R_mat), I_mat) + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_matmul_nilpotent2(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + n = 128 + u = dpt.ones((n, 1), dtype=dtype) + v = dpt.ones((1, n), dtype=dtype) + + uv = dpt.matmul(u, v) + uv_ref = u * v + + assert dpt.allclose(uv, uv_ref) + + +def test_matmul_null_axis(): + n = 3 + + A_mat = dpt.ones((n, 0), dtype="f4") + B_mat = dpt.ones((0, 1), dtype="f4") + + R_mat = dpt.matmul(A_mat, B_mat) + assert R_mat.shape == (n, 1) + + R_mat = dpt.matmul(A_mat, B_mat[:, :0]) + assert R_mat.shape == (n, 0) + + +@pytest.mark.parametrize("dtype", ["i4", "f4"]) +def test_matmul_dims(dtype): + get_queue_or_skip() + + n, m, k, b = 4, 5, 7, 3 + v = dpt.ones(k, dtype=dtype) + m1 = dpt.ones((n, k), dtype=dtype) + m2 = dpt.ones((k, m), dtype=dtype) + st1 = dpt.ones((b, n, k), dtype=dtype) + st2 = dpt.ones((b, k, m), dtype=dtype) + + r = dpt.matmul(v, v) + assert r.shape == tuple() + assert dpt.round(r) == k + + r = dpt.matmul(m1, v) + assert r.shape == (n,) + assert dpt.all(dpt.round(r) == k) + + r = dpt.matmul(v, m2) + assert r.shape == (m,) + assert dpt.all(dpt.round(r) == k) + + r = dpt.matmul(m1, m2) + assert r.shape == ( + n, + m, + ) + assert dpt.all(dpt.round(r) == k) + + r = dpt.matmul(v, st2) + assert r.shape == ( + b, + m, + ) + assert dpt.all(dpt.round(r) == k) + + r = dpt.matmul(st1, v) + assert r.shape == ( + b, + n, + ) + assert dpt.all(dpt.round(r) == k) + + r = dpt.matmul(st1, m2) + assert r.shape == ( + b, + n, + m, + ) + assert dpt.all(dpt.round(r) == k) + + r = dpt.matmul(m1, st2) + assert r.shape == ( + b, + n, + m, + ) + assert dpt.all(dpt.round(r) == k) + + r = dpt.matmul(st1, st2) + assert r.shape == ( + b, + n, + m, + ) + assert dpt.all(dpt.round(r) == k) + + +def test_matmul_arg_validation(): + get_queue_or_skip() + + s1, s2 = dpt.ones(tuple()), dpt.zeros(tuple()) + v1, v2 = dpt.ones(16), dpt.zeros(16) + + with pytest.raises(ValueError): + dpt.matmul(s1, v2) + + with pytest.raises(ValueError): + dpt.matmul(v1, s2) + + with pytest.raises(TypeError): + dpt.matmul(dict(), v2) + + with pytest.raises(TypeError): + dpt.matmul(v2, None) + + +def 
test_matmul_dims_validation(): + get_queue_or_skip() + + m1 = dpt.ones((16, 16)) + m2 = dpt.ones((16, 16)) + + # contraction dimensions mismatch + with pytest.raises(ValueError): + dpt.matmul(m1[:, :7], m2[:3, :]) + + m1 = dpt.ones((3, 4, 5)) + m2 = dpt.ones((2, 5, 3)) + # broadcasting dimensions mismatch + with pytest.raises(ValueError): + dpt.matmul(m1, m2) + + +def test_matmul_broadcasting(): + get_queue_or_skip() + + for dt1, dt2 in [ + (dpt.int16, dpt.int32), + (dpt.float32, dpt.int16), + (dpt.int32, dpt.uint32), + ]: + m1 = dpt.ones((7, 11, 16), dtype=dt1) + m2 = dpt.ones((16, 13), dtype=dt2) + + r = dpt.matmul(m1, m2[dpt.newaxis, ...]) + + assert r.shape == (7, 11, 13) + + +@pytest.mark.parametrize("dtype", ["i4", "i8", "f4", "c8"]) +def test_matmul_strided(dtype): + get_queue_or_skip() + + m1_shape = (14, 22, 32) + m1_size = 1 + for el in m1_shape: + m1_size = m1_size * el + + m1 = dpt.remainder(dpt.arange(1, m1_size + 1, dtype="i8"), 13) + m1_orig = dpt.reshape(dpt.astype(m1, dtype), m1_shape) + m2_orig = dpt.ones((14, 16, 13), dtype=dtype) + + m1 = m1_orig[::2, ::-2, ::2] + m2 = m2_orig[::2, :, :] + r = dpt.matmul(m1, m2) + + assert r.shape == m1.shape[:2] + m2.shape[-1:] + ref = np.matmul(dpt.asnumpy(m1), dpt.asnumpy(m2)) + assert np.allclose(dpt.asnumpy(r), ref) + + m1 = m1_orig[::2, ::2, ::-2] + m2 = m2_orig[::2, :, :] + r = dpt.matmul(m1, m2) + + assert r.shape == m1.shape[:2] + m2.shape[-1:] + ref = np.matmul(dpt.asnumpy(m1), dpt.asnumpy(m2)) + assert np.allclose(dpt.asnumpy(r), ref) + + m1 = m1_orig[::-2, ::2, ::2] + m2 = m2_orig[::-2, :, :] + r = dpt.matmul(m1, m2) + + assert r.shape == m1.shape[:2] + m2.shape[-1:] + ref = np.matmul(dpt.asnumpy(m1), dpt.asnumpy(m2)) + assert np.allclose(dpt.asnumpy(r), ref) + + +def test_matmul_out(): + get_queue_or_skip() + + m1 = ( + dpt.arange(14, dtype="f4")[:, dpt.newaxis, dpt.newaxis] + + dpt.arange(17, dtype="f4")[dpt.newaxis, :, dpt.newaxis] + + dpt.arange(128, dtype="f4")[dpt.newaxis, dpt.newaxis, :] + ) + assert m1.shape == (14, 17, 128) + m2 = dpt.tile( + dpt.reshape(dpt.asarray([1, 2], dtype="f4"), (2, 1, 1)), (7, 128, 13) + ) + assert m2.shape == (14, 128, 13) + + buf = dpt.zeros((2 * 14, 3 * 17, 13), dtype="f4") + res = dpt.matmul(m1, m2, out=buf[::-2, 1::3, :]) + + assert dpt.allclose(res, buf[::-2, 1::3, :]) + assert dpt.allclose(dpt.zeros_like(res), buf[::-2, 0::3, :]) + assert dpt.allclose(dpt.zeros_like(res), buf[::-2, 2::3, :]) + + m1_np = dpt.asnumpy(m1) + ref = np.matmul(m1_np, dpt.asnumpy(m2)) + assert np.allclose(ref, dpt.asnumpy(res)) + + res = dpt.matmul(m1[:, :10, :10], m1[:, :10, :10].mT, out=m1[:, :10, :10]) + ref = np.matmul( + m1_np[:, :10, :10], np.transpose(m1_np[:, :10, :10], (0, 2, 1)) + ) + assert np.allclose(ref, dpt.asnumpy(res)) + + +def test_matmul_dtype(): + get_queue_or_skip() + + for dt1, dt2 in [ + (dpt.int32, dpt.int16), + (dpt.int16, dpt.int32), + (dpt.float32, dpt.int16), + (dpt.int32, dpt.float32), + ]: + m1 = dpt.ones((10, 10), dtype=dt1) + m2 = dpt.ones((10, 10), dtype=dt2) + + for ord in ["C", "A", "F", "K"]: + r = dpt.matmul(m1, m2, dtype=dpt.float32, order=ord) + assert r.dtype == dpt.float32 + + +@pytest.mark.parametrize("dt1", _numeric_types) +@pytest.mark.parametrize("dt2", _numeric_types) +@pytest.mark.parametrize("order", ["C", "K"]) +def test_matmul_type_promotion(dt1, dt2, order): + get_queue_or_skip() + + q = get_queue_or_skip() + skip_if_dtype_not_supported(dt1, q) + skip_if_dtype_not_supported(dt2, q) + + b, n, k, m = 8, 10, 17, 10 + m1 = dpt.ones((1, n, k), dtype=dt1) + m2 = 
dpt.ones((b, k, m), dtype=dt2) + expected_dt = dpt.result_type(m1, m2) + + r = dpt.matmul(m1, m2, order=order) + assert r.shape == (b, n, m) + assert r.dtype == expected_dt + + m1 = dpt.ones((b, n, k), dtype=dt1) + m2 = dpt.ones((1, k, m), dtype=dt2) + + r = dpt.matmul(m1, m2, order=order) + assert r.shape == (b, n, m) + assert r.dtype == expected_dt + + m1 = dpt.ones((n, k), dtype=dt1) + m2 = dpt.ones((k, m), dtype=dt2) + + r = dpt.matmul(m1, m2, order=order) + assert r.shape == (n, m) + assert r.dtype == expected_dt + + +def test_matmul_invalid_dtype(): + get_queue_or_skip() + + m1 = dpt.zeros((10, 10), dtype="f4") + m2 = dpt.zeros((10, 10), dtype="f4") + m3 = dpt.zeros((10, 10), dtype="i4") + + with pytest.raises(ValueError): + dpt.matmul(m1, m2, dtype="i4") + + with pytest.raises(ValueError): + dpt.matmul(m1, m3, dtype="i4") + + with pytest.raises(ValueError): + dpt.matmul(m3, m1, dtype="i4") + + +def test_matmul_out_errors(): + q1 = get_queue_or_skip() + q2 = dpctl.SyclQueue() + + sh = (10, 10) + dt = "i4" + m1 = dpt.zeros(sh, dtype=dt, sycl_queue=q1) + m2 = dpt.zeros(sh, dtype=dt, sycl_queue=q1) + + with pytest.raises(TypeError): + dpt.matmul(m1, m2, out=dict()) + + with pytest.raises(ValueError): + dpt.matmul(m1, m2, out=dpt.empty((10,), dtype=dt, sycl_queue=q1)) + + with pytest.raises(ValueError): + dpt.matmul(m1, m2, out=dpt.empty(sh, dtype="f4", sycl_queue=q1)) + + with pytest.raises(ExecutionPlacementError): + dpt.matmul(m1, m2, out=dpt.empty(sh, dtype=dt, sycl_queue=q2)) + + +def test_matmul_order(): + get_queue_or_skip() + + sh = ( + 10, + 10, + ) + sh2 = tuple(2 * dim for dim in sh) + n = sh[-1] + + for dt1, dt2 in zip(["i4", "i4", "f4"], ["i4", "f4", "i4"]): + ar1 = dpt.ones(sh, dtype=dt1, order="C") + ar2 = dpt.ones(sh, dtype=dt2, order="C") + r1 = dpt.matmul(ar1, ar2, order="C") + assert r1.flags.c_contiguous + r2 = dpt.matmul(ar1, ar2, order="F") + assert r2.flags.f_contiguous + r3 = dpt.matmul(ar1, ar2, order="A") + assert r3.flags.c_contiguous + r4 = dpt.matmul(ar1, ar2, order="K") + assert r4.flags.c_contiguous + + ar1 = dpt.ones(sh, dtype=dt1, order="F") + ar2 = dpt.ones(sh, dtype=dt2, order="F") + r1 = dpt.matmul(ar1, ar2, order="C") + assert r1.flags.c_contiguous + r2 = dpt.matmul(ar1, ar2, order="F") + assert r2.flags.f_contiguous + r3 = dpt.matmul(ar1, ar2, order="A") + assert r3.flags.f_contiguous + r4 = dpt.matmul(ar1, ar2, order="K") + assert r4.flags.f_contiguous + + ar1 = dpt.ones(sh2, dtype=dt1, order="C")[:10, ::-2] + ar2 = dpt.ones(sh2, dtype=dt2, order="C")[:10, ::-2] + r4 = dpt.matmul(ar1, ar2, order="K") + assert r4.strides == (n, -1) + r5 = dpt.matmul(ar1, ar2, order="C") + assert r5.strides == (n, 1) + + ar1 = dpt.ones(sh2, dtype=dt1, order="C")[:10, ::-2].mT + ar2 = dpt.ones(sh2, dtype=dt2, order="C")[:10, ::-2].mT + r4 = dpt.matmul(ar1, ar2, order="K") + assert r4.strides == (-1, n) + r5 = dpt.matmul(ar1, ar2, order="C") + assert r5.strides == (n, 1) + + +def test_matmul_invalid_order(): + get_queue_or_skip() + + sh = ( + 10, + 10, + ) + dt = "i4" + + ar1 = dpt.ones(sh, dtype=dt, order="C") + ar2 = dpt.ones(sh, dtype=dt, order="C") + r = dpt.matmul(ar1, ar2, order="invalid") + assert r.flags.c_contiguous + + ar1 = dpt.ones(sh, dtype=dt, order="F") + ar2 = dpt.ones(sh, dtype=dt, order="F") + r = dpt.matmul(ar1, ar2, order="invalid") + assert r.flags.f_contiguous + + +def test_matmul_compute_follows_data(): + q1 = get_queue_or_skip() + q2 = dpctl.SyclQueue() + + sh = ( + 10, + 10, + ) + dt = "i4" + m1 = dpt.zeros(sh, dtype=dt, sycl_queue=q1) + m2 = 
dpt.zeros(sh, dtype=dt, sycl_queue=q2) + + with pytest.raises(ExecutionPlacementError): + dpt.matmul(m1, m2) + + +def test_matmul_inplace_broadcasting(): + get_queue_or_skip() + + sh = (3, 5, 5) + dt = "i4" + + m1 = dpt.ones((3, 5, 5), dtype=dt) + m2 = dpt.ones((1, 5, 5), dtype=dt) + m1 @= m2 + assert dpt.all(m1 == dpt.full(sh, 5, dtype=dt)) + + +def test_matmul_prepend_dims(): + get_queue_or_skip() + + n = 5 + for dt1, dt2 in [ + (dpt.int32, dpt.int32), + (dpt.int32, dpt.int64), + (dpt.int64, dpt.int32), + (dpt.int32, dpt.uint32), + ]: + m = dpt.ones((n, 4), dtype=dt1) + v = dpt.ones((4,), dtype=dt2) + r = dpt.matmul(m, v) + assert r.shape == (n,) + + r = dpt.matmul(v, m.mT) + assert r.shape == (n,) + + +def test_matmul_inplace_same_tensors(): + get_queue_or_skip() + + n = 5 + sh = ( + n, + n, + ) + + ar1 = dpt.ones(sh, dtype="i4") + ar1 @= ar1 + assert dpt.all(ar1 == dpt.full(sh, n, dtype="i4")) + + ar1 = dpt.ones(sh, dtype="i8") + ar2 = dpt.ones(sh, dtype="i4") + dpt.matmul(ar1, ar2, out=ar1) + assert dpt.all(ar1 == dpt.full(sh, n, dtype=ar1.dtype)) + + ar1 = dpt.ones(sh, dtype="i4") + ar2 = dpt.ones(sh, dtype="i8") + dpt.matmul(ar1, ar2, out=ar2) + assert dpt.all(ar2 == dpt.full(sh, n, dtype=ar2.dtype)) + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_tensordot_outer(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + t1 = dpt.ones((3, 8), dtype=dtype) + t2 = dpt.ones((4, 12), dtype=dtype) + + r = dpt.tensordot(t1, t2, axes=0) + assert r.shape == t1.shape + t2.shape + assert dpt.allclose(r, dpt.ones_like(r)) + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_tensordot_inner(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + t1 = dpt.ones((3, 8), dtype=dtype) + t2 = dpt.ones((4, 8), dtype=dtype) + + r = dpt.tensordot(t1, t2.mT, axes=1) + assert r.shape == t1.shape[:1] + t2.shape[:1] + assert dpt.allclose(r, dpt.full_like(r, fill_value=t1.shape[1])) + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_tensordot_double(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + t1 = dpt.ones((2, 4, 8), dtype=dtype) + t2 = dpt.ones((3, 4, 8), dtype=dtype) + + r = dpt.tensordot(t1, dpt.permute_dims(t2, (1, 2, 0)), axes=2) + assert r.shape == t1.shape[:1] + t2.shape[:1] + expected = dpt.prod(dpt.asarray(t1.shape[1:])) + assert dpt.allclose(r, dpt.full_like(r, fill_value=expected)) + + +@pytest.mark.parametrize("dtype", ["i4", "f4"]) +def test_tensordot_axes_sequence(dtype): + get_queue_or_skip() + + r = 4 + t1 = dpt.ones((2, 2, 4, 3), dtype=dtype) + t2 = dpt.ones((3, 2, 4, 3), dtype=dtype) + + assert len(t1.shape) == r + assert len(t2.shape) == r + + expected = dpt.prod(dpt.asarray(t1.shape[1:])) + ps1 = itertools.permutations(range(r)) + ps2 = itertools.permutations(range(r)) + + for p1 in ps1: + assert len(p1) == r + inv_p1 = sorted(range(r), key=p1.__getitem__) + u1 = dpt.permute_dims(t1, p1) + x1_axes = inv_p1[1:] + for p2 in ps2: + inv_p2 = sorted(range(r), key=p2.__getitem__) + u2 = dpt.permute_dims(t2, p2) + x2_axes = inv_p2[1:] + + tdr = dpt.tensordot(u1, u2, axes=(x1_axes, x2_axes)) + assert tdr.shape == t1.shape[:1] + t2.shape[:1] + assert dpt.allclose(tdr, dpt.full_like(tdr, fill_value=expected)) + + +def test_tensordot_validation(): + get_queue_or_skip() + + with pytest.raises(TypeError): + dpt.tensordot(dict(), dict()) + + t1 = dpt.empty((10, 10, 10)) + with pytest.raises(TypeError): + dpt.tensordot(t1, dict()) + + t2 = dpt.empty((10, 10, 10)) + q = 
dpctl.SyclQueue(t2.sycl_context, t2.sycl_device, property="in_order") + with pytest.raises(dpctl.utils.ExecutionPlacementError): + dpt.tensordot(t1, t2.to_device(q)) + + invalid_axes = ( + 1, + 2, + 3, + ) + with pytest.raises(ValueError): + dpt.tensordot(t1, t2, axes=invalid_axes) + + invalid_axes = 5.2 + with pytest.raises(TypeError): + dpt.tensordot(t1, t2, axes=invalid_axes) + + invalid_axes = ( + (1,), + ( + 0, + 2, + ), + ) + with pytest.raises(ValueError): + dpt.tensordot(t1, t2, axes=invalid_axes) + + with pytest.raises(ValueError): + dpt.tensordot(t1[..., :5], t2) + + +def test_tensordot_promotion(): + get_queue_or_skip() + + t1 = dpt.zeros((10, 10), dtype="i4") + t2 = dpt.zeros((10, 10), dtype="i8") + + r1 = dpt.tensordot(t1, t2) + assert r1.dtype == t2.dtype + + r2 = dpt.tensordot(t2, t1) + assert r2.dtype == t2.dtype + + t3 = dpt.zeros((10, 10), dtype="u4") + r3 = dpt.tensordot(t1, t3) + assert r3.dtype == dpt.result_type(t1, t3) + + +def test_tensordot_axes_errors(): + get_queue_or_skip() + + m1 = dpt.zeros((10, 10), dtype="i4") + m2 = dpt.zeros((10, 10), dtype="i4") + + with pytest.raises(ValueError): + dpt.tensordot(m1, m2, axes=-1) + + with pytest.raises(ValueError): + dpt.tensordot(m1, m2, axes=((-1,), (1,))) + + with pytest.raises(ValueError): + dpt.tensordot(m1, m2, axes=((1,), (-1,))) + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_vecdot_1d(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + n = 511 + v1 = dpt.ones(n, dtype=dtype) + + v2 = dpt.ones(n, dtype=dtype) + + r = dpt.vecdot(v1, v2) + + assert r == n + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_vecdot_3d(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + m1, m2, n = 7, 3, 511 + v1 = dpt.ones((m1, m2, n), dtype=dtype) + + v2 = dpt.ones((m1, m2, n), dtype=dtype) + + r = dpt.vecdot(v1, v2) + + assert r.shape == ( + m1, + m2, + ) + assert dpt.all(r == n) + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_vecdot_axis(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + m1, m2, n = 7, 3, 511 + v1 = dpt.ones((m1, n, m2), dtype=dtype) + + v2 = dpt.ones((m1, n, m2), dtype=dtype) + + r = dpt.vecdot(v1, v2, axis=1) + + assert r.shape == ( + m1, + m2, + ) + assert dpt.all(r == n) + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_vecdot_strided(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + m1, m2, n = 7, 3, 511 + list1 = [1, 0, 2, 0] + pattern1 = dpt.asarray(list1, dtype=dtype) + n_padded1 = pattern1.size * (1 + ((n - 1) // pattern1.size)) + v1 = dpt.tile(dpt.reshape(pattern1, (1, -1, 1)), (m1, n_padded1, m2))[ + ::-1, :n, : + ] + + list2 = [1, 2, 1, 2] + pattern2 = dpt.asarray(list2, dtype=dtype) + n_padded2 = pattern2.size * (1 + ((n - 1) // pattern2.size)) + v2 = dpt.tile(dpt.reshape(pattern2, (1, -1, 1)), (m1, n_padded2, m2))[ + :, :n, ::-1 + ] + + r = dpt.vecdot(v1, v2, axis=1) + + ref = sum( + el1 * el2 + for el1, el2 in zip((list1 * n_padded1)[:n], (list2 * n_padded1)[:n]) + ) + + assert r.shape == ( + m1, + m2, + ) + assert dpt.all(r == ref) + + +def test_vector_arg_validation(): + get_queue_or_skip() + + s1, s2 = dpt.ones(tuple()), dpt.zeros(tuple()) + v1, v2 = dpt.ones(16), dpt.zeros(16) + + with pytest.raises(ValueError): + dpt.vecdot(s1, v2) + + with pytest.raises(ValueError): + dpt.vecdot(v1, s2) + + with pytest.raises(TypeError): + dpt.vecdot(dict(), v2) + + with pytest.raises(TypeError): + dpt.vecdot(v2, None) + + with 
pytest.raises(ValueError): + dpt.vecdot(v1[:5], v2[:4]) + + with pytest.raises(ValueError): + dpt.vecdot(v1, v2, axis=2) + + q = dpctl.SyclQueue( + v2.sycl_context, v2.sycl_device, property="enable_profiling" + ) + with pytest.raises(dpctl.utils.ExecutionPlacementError): + dpt.vecdot(v1, v2.to_device(q)) + + m1 = dpt.empty((10, 5)) + m2 = dpt.empty((5, 5)) + with pytest.raises(ValueError): + dpt.vecdot(m1, m2, axis=-1) + + +def test_vecdot_broadcast(): + get_queue_or_skip() + + for dt1, dt2 in [ + (dpt.int32, dpt.int32), + (dpt.int32, dpt.int64), + (dpt.int64, dpt.int32), + (dpt.int32, dpt.uint32), + ]: + m1 = dpt.zeros((1, 5), dtype=dt1) + m2 = dpt.zeros((5, 5), dtype=dt2) + r1 = dpt.vecdot(m1, m2, axis=-1) + r2 = dpt.vecdot(m2, m1, axis=-1) + assert r1.shape == r2.shape + + +@pytest.mark.parametrize("dt1", _numeric_types) +@pytest.mark.parametrize("dt2", _numeric_types) +def test_vecdot_type_promotion(dt1, dt2): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dt1, q) + skip_if_dtype_not_supported(dt2, q) + + v1 = dpt.ones(128, dtype=dt1) + v2 = dpt.ones(128, dtype=dt2) + + r = dpt.vecdot(v1, v2) + mul = v1 * v2 + assert r.shape == tuple() + assert r.dtype == mul.dtype + assert dpt.allclose(r, dpt.sum(mul, dtype=mul.dtype))
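
For readers of this patch, the snippet below is a minimal usage sketch, not part of the diff, showing how the newly exported matmul, tensordot, and vecdot behave on small inputs and how results can be cross-checked against NumPy in the same spirit as the tests above. It assumes a default SYCL device is available to dpctl; the array names and shapes are illustrative only.

# Minimal usage sketch (not part of the patch); assumes a default SYCL device.
import numpy as np

import dpctl.tensor as dpt

a = dpt.reshape(dpt.arange(2 * 3 * 4, dtype="f4"), (2, 3, 4))
b = dpt.ones((2, 4, 5), dtype="f4")

# batched matrix product: (2, 3, 4) @ (2, 4, 5) -> (2, 3, 5)
c = dpt.matmul(a, b)
assert c.shape == (2, 3, 5)
assert np.allclose(dpt.asnumpy(c), np.matmul(dpt.asnumpy(a), dpt.asnumpy(b)))

# contraction over explicitly paired axes: axis 2 of `a` with axis 1 of `b`;
# result keeps the non-contracted axes of `a`, then those of `b`
t = dpt.tensordot(a, b, axes=([2], [1]))
assert t.shape == (2, 3, 2, 5)
assert np.allclose(
    dpt.asnumpy(t),
    np.tensordot(dpt.asnumpy(a), dpt.asnumpy(b), axes=([2], [1])),
)

# vecdot reduces over the last axis by default; against ones it is a row sum
v = dpt.vecdot(a, dpt.ones_like(a))
assert v.shape == (2, 3)
assert np.allclose(dpt.asnumpy(v), dpt.asnumpy(a).sum(axis=-1))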