40
40
#include " utils/sycl_utils.hpp"
41
41
#include " utils/type_utils.hpp"
42
42
43
+ #define SYCL_EXT_ONEAPI_COMPLEX
44
+ #include < sycl/ext/oneapi/experimental/complex/complex.hpp>
45
+
43
46
namespace dpctl
44
47
{
45
48
namespace tensor
@@ -48,6 +51,8 @@ namespace kernels
48
51
{
49
52
50
53
using dpctl::tensor::ssize_t ;
54
+ namespace tu_ns = dpctl::tensor::type_utils;
55
+ namespace exprm_ns = sycl::ext::oneapi::experimental;
51
56
52
57
namespace gemm_detail
53
58
{
@@ -1082,8 +1087,21 @@ class GemmBatchFunctorThreadNM_vecm
1082
1087
#pragma unroll
1083
1088
for (std::uint32_t pr_j = 0 ; pr_j < wi_delta_m_vecs; ++pr_j)
1084
1089
{
1085
- private_C[pr_i * wi_delta_m_vecs + pr_j] +=
1086
- pr_lhs[pr_i] * pr_rhs[pr_j];
1090
+ if constexpr (tu_ns::is_complex_v<resT>) {
1091
+ using realT = typename resT::value_type;
1092
+ using sycl_complex = exprm_ns::complex<realT>;
1093
+
1094
+ auto tmp = sycl_complex (
1095
+ private_C[pr_i * wi_delta_m_vecs + pr_j]);
1096
+ tmp += sycl_complex (pr_lhs[pr_i]) *
1097
+ sycl_complex (pr_rhs[pr_j]);
1098
+ private_C[pr_i * wi_delta_m_vecs + pr_j] =
1099
+ resT (tmp);
1100
+ }
1101
+ else {
1102
+ private_C[pr_i * wi_delta_m_vecs + pr_j] +=
1103
+ pr_lhs[pr_i] * pr_rhs[pr_j];
1104
+ }
1087
1105
}
1088
1106
}
1089
1107
}
@@ -1949,9 +1967,21 @@ class GemmBatchNoAtomicFunctorThreadNM
1949
1967
slmB_t local_sum (identity_);
1950
1968
for (std::size_t private_s = 0 ; private_s < wi_delta_k; ++private_s)
1951
1969
{
1952
- local_sum = local_sum +
1953
- (local_A_block[a_offset + a_pr_offset + private_s] *
1954
- local_B_block[b_offset + private_s]);
1970
+ if constexpr (tu_ns::is_complex_v<resT>) {
1971
+ using realT = typename resT::value_type;
1972
+ using sycl_complex = exprm_ns::complex<realT>;
1973
+ auto tmp = sycl_complex (local_sum);
1974
+ tmp += (sycl_complex (local_A_block[a_offset + a_pr_offset +
1975
+ private_s]) *
1976
+ sycl_complex (local_B_block[b_offset + private_s]));
1977
+ local_sum = resT (tmp);
1978
+ }
1979
+ else {
1980
+ local_sum =
1981
+ local_sum +
1982
+ (local_A_block[a_offset + a_pr_offset + private_s] *
1983
+ local_B_block[b_offset + private_s]);
1984
+ }
1955
1985
}
1956
1986
1957
1987
const std::size_t gl_i = i + private_i;
@@ -2114,12 +2144,28 @@ class GemmBatchNoAtomicFunctorThreadK
2114
2144
accV_t private_sum (identity_);
2115
2145
constexpr accV_t vec_identity_ (identity_);
2116
2146
for (std::size_t t = local_s; t < local_B_block.size (); t += delta_k) {
2117
- private_sum +=
2118
- ((i < n) && (t + t_shift < k))
2119
- ? (static_cast <resT>(
2120
- lhs[lhs_offset + lhs_indexer (global_s_offset + t)]) *
2121
- local_B_block[t])
2122
- : vec_identity_;
2147
+ if constexpr (tu_ns::is_complex_v<resT>) {
2148
+ using realT = typename resT::value_type;
2149
+ using sycl_complex = exprm_ns::complex<realT>;
2150
+
2151
+ auto tmp = sycl_complex (private_sum);
2152
+ tmp += ((i < n) && (t + t_shift < k))
2153
+ ? sycl_complex (static_cast <resT>(
2154
+ lhs[lhs_offset +
2155
+ lhs_indexer (global_s_offset + t)])) *
2156
+ sycl_complex (local_B_block[t])
2157
+ : sycl_complex (vec_identity_);
2158
+ private_sum = resT (tmp);
2159
+ }
2160
+ else {
2161
+ private_sum +=
2162
+ ((i < n) && (t + t_shift < k))
2163
+ ? (static_cast <resT>(
2164
+ lhs[lhs_offset +
2165
+ lhs_indexer (global_s_offset + t)]) *
2166
+ local_B_block[t])
2167
+ : vec_identity_;
2168
+ }
2123
2169
}
2124
2170
2125
2171
std::size_t workspace_i_shift = local_i * delta_k;
@@ -2130,7 +2176,17 @@ class GemmBatchNoAtomicFunctorThreadK
2130
2176
if (local_s == 0 && i < n) {
2131
2177
accV_t local_sum (workspace[workspace_i_shift]);
2132
2178
for (std::size_t t = 1 ; t < delta_k; ++t) {
2133
- local_sum += workspace[workspace_i_shift + t];
2179
+ if constexpr (tu_ns::is_complex_v<resT>) {
2180
+ using realT = typename resT::value_type;
2181
+ using sycl_complex = exprm_ns::complex<realT>;
2182
+
2183
+ auto tmp = sycl_complex (local_sum);
2184
+ tmp += sycl_complex (workspace[workspace_i_shift + t]);
2185
+ local_sum = resT (tmp);
2186
+ }
2187
+ else {
2188
+ local_sum += workspace[workspace_i_shift + t];
2189
+ }
2134
2190
}
2135
2191
2136
2192
const std::size_t total_offset =
@@ -2863,8 +2919,7 @@ sycl::event gemm_batch_tree_impl(sycl::queue &exec_q,
2863
2919
}
2864
2920
2865
2921
if (max_nm < 64 ) {
2866
- using dpctl::tensor::type_utils::is_complex;
2867
- if constexpr (!is_complex<resTy>::value) {
2922
+ if constexpr (!tu_ns::is_complex_v<resTy>) {
2868
2923
if (m < 4 ) {
2869
2924
constexpr std::uint32_t m_groups_one = 1 ;
2870
2925
return gemm_batch_tree_k_impl<lhsTy, rhsTy, resTy,
@@ -2900,8 +2955,7 @@ sycl::event gemm_batch_tree_impl(sycl::queue &exec_q,
2900
2955
}
2901
2956
}
2902
2957
else { // m > 1, n > k or m > k
2903
- using dpctl::tensor::type_utils::is_complex;
2904
- if constexpr (!is_complex<resTy>::value) {
2958
+ if constexpr (!tu_ns::is_complex_v<resTy>) {
2905
2959
constexpr std::uint32_t m_groups_four = 4 ;
2906
2960
return gemm_batch_tree_nm_impl<lhsTy, rhsTy, resTy, m_groups_four>(
2907
2961
exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd,
@@ -3435,8 +3489,7 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q,
3435
3489
}
3436
3490
3437
3491
if (max_nm < 64 ) {
3438
- using dpctl::tensor::type_utils::is_complex;
3439
- if constexpr (!is_complex<resTy>::value) {
3492
+ if constexpr (!tu_ns::is_complex_v<resTy>) {
3440
3493
if (m < 4 ) {
3441
3494
return gemm_batch_contig_tree_k_impl<lhsTy, rhsTy, resTy, 1 >(
3442
3495
exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m,
@@ -3454,8 +3507,7 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q,
3454
3507
}
3455
3508
}
3456
3509
else { // m > 1, n > k or m > k
3457
- using dpctl::tensor::type_utils::is_complex;
3458
- if constexpr (!is_complex<resTy>::value) {
3510
+ if constexpr (!tu_ns::is_complex_v<resTy>) {
3459
3511
return gemm_batch_contig_tree_nm_impl<lhsTy, rhsTy, resTy, 4 >(
3460
3512
exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends);
3461
3513
}
@@ -3840,8 +3892,7 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q,
3840
3892
}
3841
3893
3842
3894
if (max_nm < 64 ) {
3843
- using dpctl::tensor::type_utils::is_complex;
3844
- if constexpr (!is_complex<resTy>::value) {
3895
+ if constexpr (!tu_ns::is_complex_v<resTy>) {
3845
3896
if (m < 4 ) {
3846
3897
return gemm_tree_k_impl<lhsTy, rhsTy, resTy, 1 >(
3847
3898
exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd,
@@ -3866,8 +3917,7 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q,
3866
3917
}
3867
3918
}
3868
3919
else { // m > 1, n > k or m > k
3869
- using dpctl::tensor::type_utils::is_complex;
3870
- if constexpr (!is_complex<resTy>::value) {
3920
+ if constexpr (!tu_ns::is_complex_v<resTy>) {
3871
3921
return gemm_tree_nm_impl<lhsTy, rhsTy, resTy, 4 >(
3872
3922
exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd,
3873
3923
lhs_outer_inner_shapes_strides, rhs_outer_nd,
@@ -4191,8 +4241,7 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q,
4191
4241
}
4192
4242
4193
4243
if (max_nm < 64 ) {
4194
- using dpctl::tensor::type_utils::is_complex;
4195
- if constexpr (!is_complex<resTy>::value) {
4244
+ if constexpr (!tu_ns::is_complex_v<resTy>) {
4196
4245
if (m < 4 ) {
4197
4246
return gemm_contig_tree_k_impl<lhsTy, rhsTy, resTy, 1 >(
4198
4247
exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends);
@@ -4208,8 +4257,7 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q,
4208
4257
}
4209
4258
}
4210
4259
else { // m > 1, n > k or m > k
4211
- using dpctl::tensor::type_utils::is_complex;
4212
- if constexpr (!is_complex<resTy>::value) {
4260
+ if constexpr (!tu_ns::is_complex_v<resTy>) {
4213
4261
return gemm_contig_tree_nm_impl<lhsTy, rhsTy, resTy, 4 >(
4214
4262
exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends);
4215
4263
}
0 commit comments