Skip to content

Commit 0fd49ae

Browse files
committed
Use experimental complex namespace in gemm
1 parent 34168bd commit 0fd49ae

File tree

1 file changed

+76
-28
lines changed
  • dpctl/tensor/libtensor/include/kernels/linalg_functions

1 file changed

+76
-28
lines changed

dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp

Lines changed: 76 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@
4040
#include "utils/sycl_utils.hpp"
4141
#include "utils/type_utils.hpp"
4242

43+
#define SYCL_EXT_ONEAPI_COMPLEX
44+
#include <sycl/ext/oneapi/experimental/complex/complex.hpp>
45+
4346
namespace dpctl
4447
{
4548
namespace tensor
@@ -48,6 +51,8 @@ namespace kernels
4851
{
4952

5053
using dpctl::tensor::ssize_t;
54+
namespace tu_ns = dpctl::tensor::type_utils;
55+
namespace exprm_ns = sycl::ext::oneapi::experimental;
5156

5257
namespace gemm_detail
5358
{
@@ -1082,8 +1087,21 @@ class GemmBatchFunctorThreadNM_vecm
10821087
#pragma unroll
10831088
for (std::uint32_t pr_j = 0; pr_j < wi_delta_m_vecs; ++pr_j)
10841089
{
1085-
private_C[pr_i * wi_delta_m_vecs + pr_j] +=
1086-
pr_lhs[pr_i] * pr_rhs[pr_j];
1090+
if constexpr (tu_ns::is_complex_v<resT>) {
1091+
using realT = typename resT::value_type;
1092+
using sycl_complex = exprm_ns::complex<realT>;
1093+
1094+
auto tmp = sycl_complex(
1095+
private_C[pr_i * wi_delta_m_vecs + pr_j]);
1096+
tmp += sycl_complex(pr_lhs[pr_i]) *
1097+
sycl_complex(pr_rhs[pr_j]);
1098+
private_C[pr_i * wi_delta_m_vecs + pr_j] =
1099+
resT(tmp);
1100+
}
1101+
else {
1102+
private_C[pr_i * wi_delta_m_vecs + pr_j] +=
1103+
pr_lhs[pr_i] * pr_rhs[pr_j];
1104+
}
10871105
}
10881106
}
10891107
}
@@ -1949,9 +1967,21 @@ class GemmBatchNoAtomicFunctorThreadNM
19491967
slmB_t local_sum(identity_);
19501968
for (std::size_t private_s = 0; private_s < wi_delta_k; ++private_s)
19511969
{
1952-
local_sum = local_sum +
1953-
(local_A_block[a_offset + a_pr_offset + private_s] *
1954-
local_B_block[b_offset + private_s]);
1970+
if constexpr (tu_ns::is_complex_v<resT>) {
1971+
using realT = typename resT::value_type;
1972+
using sycl_complex = exprm_ns::complex<realT>;
1973+
auto tmp = sycl_complex(local_sum);
1974+
tmp += (sycl_complex(local_A_block[a_offset + a_pr_offset +
1975+
private_s]) *
1976+
sycl_complex(local_B_block[b_offset + private_s]));
1977+
local_sum = resT(tmp);
1978+
}
1979+
else {
1980+
local_sum =
1981+
local_sum +
1982+
(local_A_block[a_offset + a_pr_offset + private_s] *
1983+
local_B_block[b_offset + private_s]);
1984+
}
19551985
}
19561986

19571987
const std::size_t gl_i = i + private_i;
@@ -2114,12 +2144,28 @@ class GemmBatchNoAtomicFunctorThreadK
21142144
accV_t private_sum(identity_);
21152145
constexpr accV_t vec_identity_(identity_);
21162146
for (std::size_t t = local_s; t < local_B_block.size(); t += delta_k) {
2117-
private_sum +=
2118-
((i < n) && (t + t_shift < k))
2119-
? (static_cast<resT>(
2120-
lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) *
2121-
local_B_block[t])
2122-
: vec_identity_;
2147+
if constexpr (tu_ns::is_complex_v<resT>) {
2148+
using realT = typename resT::value_type;
2149+
using sycl_complex = exprm_ns::complex<realT>;
2150+
2151+
auto tmp = sycl_complex(private_sum);
2152+
tmp += ((i < n) && (t + t_shift < k))
2153+
? sycl_complex(static_cast<resT>(
2154+
lhs[lhs_offset +
2155+
lhs_indexer(global_s_offset + t)])) *
2156+
sycl_complex(local_B_block[t])
2157+
: sycl_complex(vec_identity_);
2158+
private_sum = resT(tmp);
2159+
}
2160+
else {
2161+
private_sum +=
2162+
((i < n) && (t + t_shift < k))
2163+
? (static_cast<resT>(
2164+
lhs[lhs_offset +
2165+
lhs_indexer(global_s_offset + t)]) *
2166+
local_B_block[t])
2167+
: vec_identity_;
2168+
}
21232169
}
21242170

21252171
std::size_t workspace_i_shift = local_i * delta_k;
@@ -2130,7 +2176,17 @@ class GemmBatchNoAtomicFunctorThreadK
21302176
if (local_s == 0 && i < n) {
21312177
accV_t local_sum(workspace[workspace_i_shift]);
21322178
for (std::size_t t = 1; t < delta_k; ++t) {
2133-
local_sum += workspace[workspace_i_shift + t];
2179+
if constexpr (tu_ns::is_complex_v<resT>) {
2180+
using realT = typename resT::value_type;
2181+
using sycl_complex = exprm_ns::complex<realT>;
2182+
2183+
auto tmp = sycl_complex(local_sum);
2184+
tmp += sycl_complex(workspace[workspace_i_shift + t]);
2185+
local_sum = resT(tmp);
2186+
}
2187+
else {
2188+
local_sum += workspace[workspace_i_shift + t];
2189+
}
21342190
}
21352191

21362192
const std::size_t total_offset =
@@ -2863,8 +2919,7 @@ sycl::event gemm_batch_tree_impl(sycl::queue &exec_q,
28632919
}
28642920

28652921
if (max_nm < 64) {
2866-
using dpctl::tensor::type_utils::is_complex;
2867-
if constexpr (!is_complex<resTy>::value) {
2922+
if constexpr (!tu_ns::is_complex_v<resTy>) {
28682923
if (m < 4) {
28692924
constexpr std::uint32_t m_groups_one = 1;
28702925
return gemm_batch_tree_k_impl<lhsTy, rhsTy, resTy,
@@ -2900,8 +2955,7 @@ sycl::event gemm_batch_tree_impl(sycl::queue &exec_q,
29002955
}
29012956
}
29022957
else { // m > 1, n > k or m > k
2903-
using dpctl::tensor::type_utils::is_complex;
2904-
if constexpr (!is_complex<resTy>::value) {
2958+
if constexpr (!tu_ns::is_complex_v<resTy>) {
29052959
constexpr std::uint32_t m_groups_four = 4;
29062960
return gemm_batch_tree_nm_impl<lhsTy, rhsTy, resTy, m_groups_four>(
29072961
exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd,
@@ -3435,8 +3489,7 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q,
34353489
}
34363490

34373491
if (max_nm < 64) {
3438-
using dpctl::tensor::type_utils::is_complex;
3439-
if constexpr (!is_complex<resTy>::value) {
3492+
if constexpr (!tu_ns::is_complex_v<resTy>) {
34403493
if (m < 4) {
34413494
return gemm_batch_contig_tree_k_impl<lhsTy, rhsTy, resTy, 1>(
34423495
exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m,
@@ -3454,8 +3507,7 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q,
34543507
}
34553508
}
34563509
else { // m > 1, n > k or m > k
3457-
using dpctl::tensor::type_utils::is_complex;
3458-
if constexpr (!is_complex<resTy>::value) {
3510+
if constexpr (!tu_ns::is_complex_v<resTy>) {
34593511
return gemm_batch_contig_tree_nm_impl<lhsTy, rhsTy, resTy, 4>(
34603512
exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends);
34613513
}
@@ -3840,8 +3892,7 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q,
38403892
}
38413893

38423894
if (max_nm < 64) {
3843-
using dpctl::tensor::type_utils::is_complex;
3844-
if constexpr (!is_complex<resTy>::value) {
3895+
if constexpr (!tu_ns::is_complex_v<resTy>) {
38453896
if (m < 4) {
38463897
return gemm_tree_k_impl<lhsTy, rhsTy, resTy, 1>(
38473898
exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd,
@@ -3866,8 +3917,7 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q,
38663917
}
38673918
}
38683919
else { // m > 1, n > k or m > k
3869-
using dpctl::tensor::type_utils::is_complex;
3870-
if constexpr (!is_complex<resTy>::value) {
3920+
if constexpr (!tu_ns::is_complex_v<resTy>) {
38713921
return gemm_tree_nm_impl<lhsTy, rhsTy, resTy, 4>(
38723922
exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd,
38733923
lhs_outer_inner_shapes_strides, rhs_outer_nd,
@@ -4191,8 +4241,7 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q,
41914241
}
41924242

41934243
if (max_nm < 64) {
4194-
using dpctl::tensor::type_utils::is_complex;
4195-
if constexpr (!is_complex<resTy>::value) {
4244+
if constexpr (!tu_ns::is_complex_v<resTy>) {
41964245
if (m < 4) {
41974246
return gemm_contig_tree_k_impl<lhsTy, rhsTy, resTy, 1>(
41984247
exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends);
@@ -4208,8 +4257,7 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q,
42084257
}
42094258
}
42104259
else { // m > 1, n > k or m > k
4211-
using dpctl::tensor::type_utils::is_complex;
4212-
if constexpr (!is_complex<resTy>::value) {
4260+
if constexpr (!tu_ns::is_complex_v<resTy>) {
42134261
return gemm_contig_tree_nm_impl<lhsTy, rhsTy, resTy, 4>(
42144262
exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends);
42154263
}

0 commit comments

Comments
 (0)