1
+ // === log10.hpp - Unary function LOG10 ------
2
+ // *-C++-*--/===//
3
+ //
4
+ // Data Parallel Control (dpctl)
5
+ //
6
+ // Copyright 2020-2023 Intel Corporation
7
+ //
8
+ // Licensed under the Apache License, Version 2.0 (the "License");
9
+ // you may not use this file except in compliance with the License.
10
+ // You may obtain a copy of the License at
11
+ //
12
+ // http://www.apache.org/licenses/LICENSE-2.0
13
+ //
14
+ // Unless required by applicable law or agreed to in writing, software
15
+ // distributed under the License is distributed on an "AS IS" BASIS,
16
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ // See the License for the specific language governing permissions and
18
+ // limitations under the License.
19
+ //
20
+ // ===---------------------------------------------------------------------===//
21
+ // /
22
+ // / \file
23
+ // / This file defines kernels for elementwise evaluation of LOG10(x) function.
24
+ // ===---------------------------------------------------------------------===//
25
+
26
+ #pragma once
27
+ #include < CL/sycl.hpp>
28
+ #include < cmath>
29
+ #include < cstddef>
30
+ #include < cstdint>
31
+ #include < type_traits>
32
+
33
+ #include " kernels/elementwise_functions/common.hpp"
34
+
35
+ #include " utils/offset_utils.hpp"
36
+ #include " utils/type_dispatch.hpp"
37
+ #include " utils/type_utils.hpp"
38
+ #include < pybind11/pybind11.h>
39
+
40
+ namespace dpctl
41
+ {
42
+ namespace tensor
43
+ {
44
+ namespace kernels
45
+ {
46
+ namespace log10
47
+ {
48
+
49
+ namespace py = pybind11;
50
+ namespace td_ns = dpctl::tensor::type_dispatch;
51
+
52
+ using dpctl::tensor::type_utils::is_complex;
53
+ using dpctl::tensor::type_utils::vec_cast;
54
+
55
+ template <typename argT, typename resT> struct Log10Functor
56
+ {
57
+
58
+ // is function constant for given argT
59
+ using is_constant = typename std::false_type;
60
+ // constant value, if constant
61
+ // constexpr resT constant_value = resT{};
62
+ // is function defined for sycl::vec
63
+ using supports_vec = typename std::negation<
64
+ std::disjunction<is_complex<resT>, is_complex<argT>>>;
65
+ // do both argTy and resTy support sugroup store/load operation
66
+ using supports_sg_loadstore = typename std::negation<
67
+ std::disjunction<is_complex<resT>, is_complex<argT>>>;
68
+
69
+ resT operator ()(const argT &in)
70
+ {
71
+ if constexpr (is_complex<argT>::value) {
72
+ using realT = typename argT::value_type;
73
+ return (std::log (in) / std::log (realT{10 }));
74
+ }
75
+ else {
76
+ return std::log10 (in);
77
+ }
78
+ }
79
+
80
+ template <int vec_sz>
81
+ sycl::vec<resT, vec_sz> operator ()(const sycl::vec<argT, vec_sz> &in)
82
+ {
83
+ auto const &res_vec = sycl::log10 (in);
84
+ using deducedT = typename std::remove_cv_t <
85
+ std::remove_reference_t <decltype (res_vec)>>::element_type;
86
+ if constexpr (std::is_same_v<resT, deducedT>) {
87
+ return res_vec;
88
+ }
89
+ else {
90
+ return vec_cast<resT, deducedT, vec_sz>(res_vec);
91
+ }
92
+ }
93
+ };
94
+
95
+ template <typename argTy,
96
+ typename resTy = argTy,
97
+ unsigned int vec_sz = 4 ,
98
+ unsigned int n_vecs = 2 >
99
+ using Log10ContigFunctor =
100
+ elementwise_common::UnaryContigFunctor<argTy,
101
+ resTy,
102
+ Log10Functor<argTy, resTy>,
103
+ vec_sz,
104
+ n_vecs>;
105
+
106
+ template <typename argTy, typename resTy, typename IndexerT>
107
+ using Log10StridedFunctor = elementwise_common::
108
+ UnaryStridedFunctor<argTy, resTy, IndexerT, Log10Functor<argTy, resTy>>;
109
+
110
+ template <typename T> struct Log10OutputType
111
+ {
112
+ using value_type = typename std::disjunction< // disjunction is C++17
113
+ // feature, supported by DPC++
114
+ td_ns::TypeMapResultEntry<T, sycl::half, sycl::half>,
115
+ td_ns::TypeMapResultEntry<T, float , float >,
116
+ td_ns::TypeMapResultEntry<T, double , double >,
117
+ td_ns::TypeMapResultEntry<T, std::complex<float >, std::complex<float >>,
118
+ td_ns::
119
+ TypeMapResultEntry<T, std::complex<double >, std::complex<double >>,
120
+ td_ns::DefaultResultEntry<void >>::result_type;
121
+ };
122
+
123
+ typedef sycl::event (*log10_contig_impl_fn_ptr_t )(
124
+ sycl::queue,
125
+ size_t ,
126
+ const char *,
127
+ char *,
128
+ const std::vector<sycl::event> &);
129
+
130
// Declaration only: passed as the kernel-name template to
// elementwise_common::unary_contig_impl by log10_contig_impl.
template <typename argT,
          typename resT,
          unsigned int vec_sz,
          unsigned int n_vecs>
class log10_contig_kernel;
132
+
133
+ template <typename argTy>
134
+ sycl::event log10_contig_impl (sycl::queue exec_q,
135
+ size_t nelems,
136
+ const char *arg_p,
137
+ char *res_p,
138
+ const std::vector<sycl::event> &depends = {})
139
+ {
140
+ return elementwise_common::unary_contig_impl<
141
+ argTy, Log10OutputType, Log10ContigFunctor, log10_contig_kernel>(
142
+ exec_q, nelems, arg_p, res_p, depends);
143
+ }
144
+
145
+ template <typename fnT, typename T> struct Log10ContigFactory
146
+ {
147
+ fnT get ()
148
+ {
149
+ if constexpr (std::is_same_v<typename Log10OutputType<T>::value_type,
150
+ void >) {
151
+ fnT fn = nullptr ;
152
+ return fn;
153
+ }
154
+ else {
155
+ fnT fn = log10_contig_impl<T>;
156
+ return fn;
157
+ }
158
+ }
159
+ };
160
+
161
+ template <typename fnT, typename T> struct Log10TypeMapFactory
162
+ {
163
+ /* ! @brief get typeid for output type of std::log10(T x) */
164
+ std::enable_if_t <std::is_same<fnT, int >::value, int > get ()
165
+ {
166
+ using rT = typename Log10OutputType<T>::value_type;
167
+ ;
168
+ return td_ns::GetTypeid<rT>{}.get ();
169
+ }
170
+ };
171
+
172
// Declaration only: passed as the kernel-name template to
// elementwise_common::unary_strided_impl by log10_strided_impl.
template <typename argT, typename resT, typename IndexerT>
class log10_strided_kernel;
173
+
174
+ typedef sycl::event (*log10_strided_impl_fn_ptr_t )(
175
+ sycl::queue,
176
+ size_t ,
177
+ int ,
178
+ const py::ssize_t *,
179
+ const char *,
180
+ py::ssize_t ,
181
+ char *,
182
+ py::ssize_t ,
183
+ const std::vector<sycl::event> &,
184
+ const std::vector<sycl::event> &);
185
+
186
+ template <typename argTy>
187
+ sycl::event
188
+ log10_strided_impl (sycl::queue exec_q,
189
+ size_t nelems,
190
+ int nd,
191
+ const py::ssize_t *shape_and_strides,
192
+ const char *arg_p,
193
+ py::ssize_t arg_offset,
194
+ char *res_p,
195
+ py::ssize_t res_offset,
196
+ const std::vector<sycl::event> &depends,
197
+ const std::vector<sycl::event> &additional_depends)
198
+ {
199
+ return elementwise_common::unary_strided_impl<
200
+ argTy, Log10OutputType, Log10StridedFunctor, log10_strided_kernel>(
201
+ exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
202
+ res_offset, depends, additional_depends);
203
+ }
204
+
205
+ template <typename fnT, typename T> struct Log10StridedFactory
206
+ {
207
+ fnT get ()
208
+ {
209
+ if constexpr (std::is_same_v<typename Log10OutputType<T>::value_type,
210
+ void >) {
211
+ fnT fn = nullptr ;
212
+ return fn;
213
+ }
214
+ else {
215
+ fnT fn = log10_strided_impl<T>;
216
+ return fn;
217
+ }
218
+ }
219
+ };
220
+
221
+ } // namespace log10
222
+ } // namespace kernels
223
+ } // namespace tensor
224
+ } // namespace dpctl
0 commit comments