Skip to content

Commit 9842cdd

Browse files
committed
Implements log2, log10, and logaddexp
1 parent 36a7cd7 commit 9842cdd

File tree

8 files changed

+1172
-9
lines changed

8 files changed

+1172
-9
lines changed

dpctl/tensor/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,9 @@
111111
less_equal,
112112
log,
113113
log1p,
114+
log2,
115+
log10,
116+
logaddexp,
114117
multiply,
115118
not_equal,
116119
proj,
@@ -212,6 +215,9 @@
212215
"less_equal",
213216
"log",
214217
"log1p",
218+
"log2",
219+
"log10",
220+
"logaddexp",
215221
"proj",
216222
"real",
217223
"sin",

dpctl/tensor/_elementwise_funcs.py

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -563,13 +563,72 @@
563563
)
564564

565565
# U22: ==== LOG2 (x)
566-
# FIXME: implement U22
566+
_log2_docstring_ = """
log2(x, out=None, order='K')

Computes the base 2 logarithm element-wise.

Args:
    x (usm_ndarray):
        Input array, expected to have numeric data type.
    out ({None, usm_ndarray}, optional):
        Output array to populate. Array must have the correct
        shape and the expected data type.
    order ("C","F","A","K", optional): memory layout of the new
        output array, if parameter `out` is `None`.
        Default: "K".
Returns:
    usm_ndarray:
        An array containing the element-wise base 2 logarithm values.
"""

# Elementwise function object exposed as `log2`; it is built from the
# `ti._log2_result_type` and `ti._log2` callables and carries the docstring
# defined above.
log2 = UnaryElementwiseFunc(
    "log2", ti._log2_result_type, ti._log2, _log2_docstring_
)
567586

568587
# U23: ==== LOG10 (x)
569-
# FIXME: implement U23
588+
_log10_docstring_ = """
log10(x, out=None, order='K')

Computes the base 10 logarithm element-wise.

Args:
    x (usm_ndarray):
        Input array, expected to have numeric data type.
    out ({None, usm_ndarray}, optional):
        Output array to populate. Array must have the correct
        shape and the expected data type.
    order ("C","F","A","K", optional): memory layout of the new
        output array, if parameter `out` is `None`.
        Default: "K".
Returns:
    usm_ndarray:
        An array containing the element-wise base 10 logarithm values.
"""

# Elementwise function object exposed as `log10`; it is built from the
# `ti._log10_result_type` and `ti._log10` callables and carries the docstring
# defined above.
log10 = UnaryElementwiseFunc(
    "log10", ti._log10_result_type, ti._log10, _log10_docstring_
)
570608

571609
# B15: ==== LOGADDEXP (x1, x2)
572-
# FIXME: implement B15
610+
_logaddexp_docstring_ = """
logaddexp(x1, x2, out=None, order='K')

Calculates the logarithm of the sum of exponentiations for each element
`x1_i` of the input array `x1` with the respective element `x2_i` of the input
array `x2`, i.e. `log(exp(x1_i) + exp(x2_i))`.

Args:
    x1 (usm_ndarray):
        First input array, expected to have a real-valued floating point data
        type.
    x2 (usm_ndarray):
        Second input array, also expected to have a real-valued floating point
        data type.
    out ({None, usm_ndarray}, optional):
        Output array to populate. Array must have the correct
        shape and the expected data type.
    order ("C","F","A","K", optional): memory layout of the new
        output array, if parameter `out` is `None`.
        Default: "K".
Returns:
    usm_ndarray:
        an array containing the element-wise results. The data type of
        the returned array is determined by the Type Promotion Rules.
"""

# Elementwise function object exposed as `logaddexp`; it is built from the
# `ti._logaddexp_result_type` and `ti._logaddexp` callables and carries the
# docstring defined above.
logaddexp = BinaryElementwiseFunc(
    "logaddexp",
    ti._logaddexp_result_type,
    ti._logaddexp,
    _logaddexp_docstring_,
)
573632

574633
# B16: ==== LOGICAL_AND (x1, x2)
575634
# FIXME: implement B16
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
//=== log10.hpp - Unary function LOG10 ------
2+
//*-C++-*--/===//
3+
//
4+
// Data Parallel Control (dpctl)
5+
//
6+
// Copyright 2020-2023 Intel Corporation
7+
//
8+
// Licensed under the Apache License, Version 2.0 (the "License");
9+
// you may not use this file except in compliance with the License.
10+
// You may obtain a copy of the License at
11+
//
12+
// http://www.apache.org/licenses/LICENSE-2.0
13+
//
14+
// Unless required by applicable law or agreed to in writing, software
15+
// distributed under the License is distributed on an "AS IS" BASIS,
16+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
// See the License for the specific language governing permissions and
18+
// limitations under the License.
19+
//
20+
//===---------------------------------------------------------------------===//
21+
///
22+
/// \file
23+
/// This file defines kernels for elementwise evaluation of LOG10(x) function.
24+
//===---------------------------------------------------------------------===//
25+
26+
#pragma once
27+
#include <CL/sycl.hpp>
28+
#include <cmath>
29+
#include <cstddef>
30+
#include <cstdint>
31+
#include <type_traits>
32+
33+
#include "kernels/elementwise_functions/common.hpp"
34+
35+
#include "utils/offset_utils.hpp"
36+
#include "utils/type_dispatch.hpp"
37+
#include "utils/type_utils.hpp"
38+
#include <pybind11/pybind11.h>
39+
40+
namespace dpctl
41+
{
42+
namespace tensor
43+
{
44+
namespace kernels
45+
{
46+
namespace log10
47+
{
48+
49+
namespace py = pybind11;
50+
namespace td_ns = dpctl::tensor::type_dispatch;
51+
52+
using dpctl::tensor::type_utils::is_complex;
53+
using dpctl::tensor::type_utils::vec_cast;
54+
55+
template <typename argT, typename resT> struct Log10Functor
56+
{
57+
58+
// is function constant for given argT
59+
using is_constant = typename std::false_type;
60+
// constant value, if constant
61+
// constexpr resT constant_value = resT{};
62+
// is function defined for sycl::vec
63+
using supports_vec = typename std::negation<
64+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
65+
// do both argTy and resTy support sugroup store/load operation
66+
using supports_sg_loadstore = typename std::negation<
67+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
68+
69+
resT operator()(const argT &in)
70+
{
71+
if constexpr (is_complex<argT>::value) {
72+
using realT = typename argT::value_type;
73+
return (std::log(in) / std::log(realT{10}));
74+
}
75+
else {
76+
return std::log10(in);
77+
}
78+
}
79+
80+
template <int vec_sz>
81+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
82+
{
83+
auto const &res_vec = sycl::log10(in);
84+
using deducedT = typename std::remove_cv_t<
85+
std::remove_reference_t<decltype(res_vec)>>::element_type;
86+
if constexpr (std::is_same_v<resT, deducedT>) {
87+
return res_vec;
88+
}
89+
else {
90+
return vec_cast<resT, deducedT, vec_sz>(res_vec);
91+
}
92+
}
93+
};
94+
95+
template <typename argTy,
96+
typename resTy = argTy,
97+
unsigned int vec_sz = 4,
98+
unsigned int n_vecs = 2>
99+
using Log10ContigFunctor =
100+
elementwise_common::UnaryContigFunctor<argTy,
101+
resTy,
102+
Log10Functor<argTy, resTy>,
103+
vec_sz,
104+
n_vecs>;
105+
106+
template <typename argTy, typename resTy, typename IndexerT>
107+
using Log10StridedFunctor = elementwise_common::
108+
UnaryStridedFunctor<argTy, resTy, IndexerT, Log10Functor<argTy, resTy>>;
109+
110+
template <typename T> struct Log10OutputType
111+
{
112+
using value_type = typename std::disjunction< // disjunction is C++17
113+
// feature, supported by DPC++
114+
td_ns::TypeMapResultEntry<T, sycl::half, sycl::half>,
115+
td_ns::TypeMapResultEntry<T, float, float>,
116+
td_ns::TypeMapResultEntry<T, double, double>,
117+
td_ns::TypeMapResultEntry<T, std::complex<float>, std::complex<float>>,
118+
td_ns::
119+
TypeMapResultEntry<T, std::complex<double>, std::complex<double>>,
120+
td_ns::DefaultResultEntry<void>>::result_type;
121+
};
122+
123+
/*!
 * @brief Pointer-to-function type with the parameter list of
 * log10_contig_impl: queue, element count, input pointer, output pointer,
 * and a vector of events.
 */
using log10_contig_impl_fn_ptr_t =
    sycl::event (*)(sycl::queue,
                    size_t,
                    const char *,
                    char *,
                    const std::vector<sycl::event> &);
129+
130+
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
131+
class log10_contig_kernel;
132+
133+
template <typename argTy>
134+
sycl::event log10_contig_impl(sycl::queue exec_q,
135+
size_t nelems,
136+
const char *arg_p,
137+
char *res_p,
138+
const std::vector<sycl::event> &depends = {})
139+
{
140+
return elementwise_common::unary_contig_impl<
141+
argTy, Log10OutputType, Log10ContigFunctor, log10_contig_kernel>(
142+
exec_q, nelems, arg_p, res_p, depends);
143+
}
144+
145+
template <typename fnT, typename T> struct Log10ContigFactory
146+
{
147+
fnT get()
148+
{
149+
if constexpr (std::is_same_v<typename Log10OutputType<T>::value_type,
150+
void>) {
151+
fnT fn = nullptr;
152+
return fn;
153+
}
154+
else {
155+
fnT fn = log10_contig_impl<T>;
156+
return fn;
157+
}
158+
}
159+
};
160+
161+
template <typename fnT, typename T> struct Log10TypeMapFactory
162+
{
163+
/*! @brief get typeid for output type of std::log10(T x) */
164+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
165+
{
166+
using rT = typename Log10OutputType<T>::value_type;
167+
;
168+
return td_ns::GetTypeid<rT>{}.get();
169+
}
170+
};
171+
172+
template <typename T1, typename T2, typename T3> class log10_strided_kernel;
173+
174+
typedef sycl::event (*log10_strided_impl_fn_ptr_t)(
175+
sycl::queue,
176+
size_t,
177+
int,
178+
const py::ssize_t *,
179+
const char *,
180+
py::ssize_t,
181+
char *,
182+
py::ssize_t,
183+
const std::vector<sycl::event> &,
184+
const std::vector<sycl::event> &);
185+
186+
template <typename argTy>
187+
sycl::event
188+
log10_strided_impl(sycl::queue exec_q,
189+
size_t nelems,
190+
int nd,
191+
const py::ssize_t *shape_and_strides,
192+
const char *arg_p,
193+
py::ssize_t arg_offset,
194+
char *res_p,
195+
py::ssize_t res_offset,
196+
const std::vector<sycl::event> &depends,
197+
const std::vector<sycl::event> &additional_depends)
198+
{
199+
return elementwise_common::unary_strided_impl<
200+
argTy, Log10OutputType, Log10StridedFunctor, log10_strided_kernel>(
201+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
202+
res_offset, depends, additional_depends);
203+
}
204+
205+
template <typename fnT, typename T> struct Log10StridedFactory
206+
{
207+
fnT get()
208+
{
209+
if constexpr (std::is_same_v<typename Log10OutputType<T>::value_type,
210+
void>) {
211+
fnT fn = nullptr;
212+
return fn;
213+
}
214+
else {
215+
fnT fn = log10_strided_impl<T>;
216+
return fn;
217+
}
218+
}
219+
};
220+
221+
} // namespace log10
222+
} // namespace kernels
223+
} // namespace tensor
224+
} // namespace dpctl

0 commit comments

Comments
 (0)