Skip to content

Commit 442e46f

Browse files
Implement kernels for in-place pow, remainder, and bitwise operators (#1447)
* Implements dedicated __ipow__ kernel * Implements in-place remainder * Implements in-place bitwise_and and bitwise_or * Implements in-place bitwise_xor * Implements in-place bitwise_left_shift and bitwise_right_shift * Adds tests for in-place bitwise elementwise funcs * Added tests for in-place remainder and pow Fixed in-place remainder for devices that do not support 64-bit floating point data types * Test commit splitting up elementwise functions * Added missing includes of common_inplace * Split elementwise functions into two more files and added them to the build * Fix more missing includes * Splits elementwise functions into separate source files * Corrected numbers of elementwise functions * Added missing vector include to elementwise function source files Removed utility include * Remove variable name in function declaration * No need to import init functions into namespace, since they are defined in it Removed "using dpctl::tensor::py_internal::init_abs`, since this imports `init_abs` into the current namespace from `dpctl::tensor::py_internal`, but this namespace is the current namespace and so the import is a no-op. Also added brief docstring for the common init module. * Changed use of "static inline" for utility functions Instead, moved common functions into anonymous namespace as inline, which is C++ way of expressing that multiple definitions of the same function may exist in different C++ translation units, which linker unifies. * Moved inline functions into separate translation units Instead of using inline keyword to allow multiple definitions of the same function in different translation units, introduced elementwise_functions_type_utils.cpp that defines these functions and a header file to use in other translatioon units. This should reduce the binary size of the produced object files and simplify the linker's job reducing the link-time. * Added license header for 2 new files --------- Co-authored-by: Oleksandr Pavlyk <oleksandr.pavlyk@intel.com>
1 parent 386bd8b commit 442e46f

File tree

163 files changed

+14358
-5201
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

163 files changed

+14358
-5201
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 80 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,78 @@ if(WIN32)
3030
endif()
3131
endif()
3232

33+
set(_elementwise_sources
34+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/elementwise_common.cpp
35+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/elementwise_functions_type_utils.cpp
36+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/abs.cpp
37+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/acos.cpp
38+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/acosh.cpp
39+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/add.cpp
40+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/asin.cpp
41+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/asinh.cpp
42+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/atan.cpp
43+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/atan2.cpp
44+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/atanh.cpp
45+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_and.cpp
46+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_invert.cpp
47+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_left_shift.cpp
48+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_or.cpp
49+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_right_shift.cpp
50+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_xor.cpp
51+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/cbrt.cpp
52+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/ceil.cpp
53+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/conj.cpp
54+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/copysign.cpp
55+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/cos.cpp
56+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/cosh.cpp
57+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/equal.cpp
58+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/exp.cpp
59+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/exp2.cpp
60+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/expm1.cpp
61+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/floor_divide.cpp
62+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/floor.cpp
63+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/greater_equal.cpp
64+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/greater.cpp
65+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/hypot.cpp
66+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/imag.cpp
67+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/isfinite.cpp
68+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/isinf.cpp
69+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/isnan.cpp
70+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/less_equal.cpp
71+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/less.cpp
72+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/log.cpp
73+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/log1p.cpp
74+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/log2.cpp
75+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/log10.cpp
76+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/logaddexp.cpp
77+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/logical_and.cpp
78+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/logical_not.cpp
79+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/logical_or.cpp
80+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/logical_xor.cpp
81+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/maximum.cpp
82+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/minimum.cpp
83+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/multiply.cpp
84+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/negative.cpp
85+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/not_equal.cpp
86+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/positive.cpp
87+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/pow.cpp
88+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/proj.cpp
89+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/real.cpp
90+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/remainder.cpp
91+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/round.cpp
92+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/rsqrt.cpp
93+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/sign.cpp
94+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/signbit.cpp
95+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/sin.cpp
96+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/sinh.cpp
97+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/sqrt.cpp
98+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/square.cpp
99+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/subtract.cpp
100+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/tan.cpp
101+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/tanh.cpp
102+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/true_divide.cpp
103+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/trunc.cpp
104+
)
33105
set(_tensor_impl_sources
34106
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_py.cpp
35107
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators.cpp
@@ -47,10 +119,12 @@ set(_tensor_impl_sources
47119
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/where.cpp
48120
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/boolean_reductions.cpp
49121
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp
50-
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp
51122
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/repeat.cpp
52123
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reduction_over_axis.cpp
53124
)
125+
list(APPEND _tensor_impl_sources
126+
${_elementwise_sources}
127+
)
54128

55129
set(python_module_name _tensor_impl)
56130
pybind11_add_module(${python_module_name} MODULE ${_tensor_impl_sources})
@@ -63,9 +137,11 @@ endif()
63137
set(_no_fast_math_sources
64138
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
65139
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
66-
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp
67140
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reduction_over_axis.cpp
68141
)
142+
list(APPEND _no_fast_math_sources
143+
${_elementwise_sources}
144+
)
69145
foreach(_src_fn ${_no_fast_math_sources})
70146
get_source_file_property(_cmpl_options_prop ${_src_fn} COMPILE_OPTIONS)
71147
set(_combined_options_prop ${_cmpl_options_prop} "${_clang_prefix}-fno-fast-math")
@@ -76,7 +152,8 @@ foreach(_src_fn ${_no_fast_math_sources})
76152
endforeach()
77153
if (UNIX)
78154
set_source_files_properties(
79-
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp
155+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/abs.cpp
156+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/sqrt.cpp
80157
PROPERTIES COMPILE_DEFINITIONS "USE_STD_ABS_FOR_COMPLEX_TYPES;USE_STD_SQRT_FOR_COMPLEX_TYPES")
81158
endif()
82159
target_compile_options(${python_module_name} PRIVATE -fno-sycl-id-queries-fit-in-int)

dpctl/tensor/_elementwise_funcs.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,7 @@
297297
ti._bitwise_and_result_type,
298298
ti._bitwise_and,
299299
_bitwise_and_docstring_,
300+
binary_inplace_fn=ti._bitwise_and_inplace,
300301
)
301302

302303
# B04: ===== BITWISE_LEFT_SHIFT (x1, x2)
@@ -330,6 +331,7 @@
330331
ti._bitwise_left_shift_result_type,
331332
ti._bitwise_left_shift,
332333
_bitwise_left_shift_docstring_,
334+
binary_inplace_fn=ti._bitwise_left_shift_inplace,
333335
)
334336

335337

@@ -393,6 +395,7 @@
393395
ti._bitwise_or_result_type,
394396
ti._bitwise_or,
395397
_bitwise_or_docstring_,
398+
binary_inplace_fn=ti._bitwise_or_inplace,
396399
)
397400

398401
# B06: ===== BITWISE_RIGHT_SHIFT (x1, x2)
@@ -425,6 +428,7 @@
425428
ti._bitwise_right_shift_result_type,
426429
ti._bitwise_right_shift,
427430
_bitwise_right_shift_docstring_,
431+
binary_inplace_fn=ti._bitwise_right_shift_inplace,
428432
)
429433

430434

@@ -459,6 +463,7 @@
459463
ti._bitwise_xor_result_type,
460464
ti._bitwise_xor,
461465
_bitwise_xor_docstring_,
466+
binary_inplace_fn=ti._bitwise_xor_inplace,
462467
)
463468

464469

@@ -1178,7 +1183,7 @@
11781183
_logical_xor_docstring_,
11791184
)
11801185

1181-
# B??: ==== MAXIMUM (x1, x2)
1186+
# B26: ==== MAXIMUM (x1, x2)
11821187
_maximum_docstring_ = """
11831188
maximum(x1, x2, out=None, order='K')
11841189
@@ -1208,7 +1213,7 @@
12081213
_maximum_docstring_,
12091214
)
12101215

1211-
# B??: ==== MINIMUM (x1, x2)
1216+
# B27: ==== MINIMUM (x1, x2)
12121217
_minimum_docstring_ = """
12131218
minimum(x1, x2, out=None, order='K')
12141219
@@ -1266,7 +1271,7 @@
12661271
ti._multiply_result_type,
12671272
ti._multiply,
12681273
_multiply_docstring_,
1269-
ti._multiply_inplace,
1274+
binary_inplace_fn=ti._multiply_inplace,
12701275
)
12711276

12721277
# U25: ==== NEGATIVE (x)
@@ -1361,10 +1366,14 @@
13611366
the returned array is determined by the Type Promotion Rules.
13621367
"""
13631368
pow = BinaryElementwiseFunc(
1364-
"pow", ti._pow_result_type, ti._pow, _pow_docstring_
1369+
"pow",
1370+
ti._pow_result_type,
1371+
ti._pow,
1372+
_pow_docstring_,
1373+
binary_inplace_fn=ti._pow_inplace,
13651374
)
13661375

1367-
# U??: ==== PROJ (x)
1376+
# U40: ==== PROJ (x)
13681377
_proj_docstring = """
13691378
proj(x, out=None, order='K')
13701379
@@ -1443,7 +1452,11 @@
14431452
the returned array is determined by the Type Promotion Rules.
14441453
"""
14451454
remainder = BinaryElementwiseFunc(
1446-
"remainder", ti._remainder_result_type, ti._remainder, _remainder_docstring_
1455+
"remainder",
1456+
ti._remainder_result_type,
1457+
ti._remainder,
1458+
_remainder_docstring_,
1459+
binary_inplace_fn=ti._remainder_inplace,
14471460
)
14481461

14491462
# U28: ==== ROUND (x)
@@ -1501,7 +1514,7 @@
15011514
"sign", ti._sign_result_type, ti._sign, _sign_docstring
15021515
)
15031516

1504-
# ==== SIGNBIT (x)
1517+
# U41: ==== SIGNBIT (x)
15051518
_signbit_docstring = """
15061519
signbit(x, out=None, order='K')
15071520
@@ -1654,7 +1667,7 @@
16541667
ti._subtract_result_type,
16551668
ti._subtract,
16561669
_subtract_docstring_,
1657-
ti._subtract_inplace,
1670+
binary_inplace_fn=ti._subtract_inplace,
16581671
)
16591672

16601673

dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_and.hpp

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "utils/type_utils.hpp"
3434

3535
#include "kernels/elementwise_functions/common.hpp"
36+
#include "kernels/elementwise_functions/common_inplace.hpp"
3637
#include <pybind11/pybind11.h>
3738

3839
namespace dpctl
@@ -257,6 +258,144 @@ struct BitwiseAndStridedFactory
257258
}
258259
};
259260

261+
template <typename argT, typename resT> struct BitwiseAndInplaceFunctor
262+
{
263+
using supports_sg_loadstore = typename std::true_type;
264+
using supports_vec = typename std::true_type;
265+
266+
void operator()(resT &res, const argT &in) const
267+
{
268+
using tu_ns::convert_impl;
269+
270+
if constexpr (std::is_same_v<resT, bool>) {
271+
res = res && in;
272+
}
273+
else {
274+
res &= in;
275+
}
276+
}
277+
278+
template <int vec_sz>
279+
void operator()(sycl::vec<resT, vec_sz> &res,
280+
const sycl::vec<argT, vec_sz> &in) const
281+
{
282+
283+
if constexpr (std::is_same_v<resT, bool>) {
284+
using dpctl::tensor::type_utils::vec_cast;
285+
286+
auto tmp = (res && in);
287+
res = vec_cast<resT, typename decltype(tmp)::element_type, vec_sz>(
288+
tmp);
289+
}
290+
else {
291+
res &= in;
292+
}
293+
}
294+
};
295+
296+
template <typename argT,
297+
typename resT,
298+
unsigned int vec_sz = 4,
299+
unsigned int n_vecs = 2>
300+
using BitwiseAndInplaceContigFunctor =
301+
elementwise_common::BinaryInplaceContigFunctor<
302+
argT,
303+
resT,
304+
BitwiseAndInplaceFunctor<argT, resT>,
305+
vec_sz,
306+
n_vecs>;
307+
308+
template <typename argT, typename resT, typename IndexerT>
309+
using BitwiseAndInplaceStridedFunctor =
310+
elementwise_common::BinaryInplaceStridedFunctor<
311+
argT,
312+
resT,
313+
IndexerT,
314+
BitwiseAndInplaceFunctor<argT, resT>>;
315+
316+
template <typename argT,
317+
typename resT,
318+
unsigned int vec_sz,
319+
unsigned int n_vecs>
320+
class bitwise_and_inplace_contig_kernel;
321+
322+
template <typename argTy, typename resTy>
323+
sycl::event
324+
bitwise_and_inplace_contig_impl(sycl::queue &exec_q,
325+
size_t nelems,
326+
const char *arg_p,
327+
py::ssize_t arg_offset,
328+
char *res_p,
329+
py::ssize_t res_offset,
330+
const std::vector<sycl::event> &depends = {})
331+
{
332+
return elementwise_common::binary_inplace_contig_impl<
333+
argTy, resTy, BitwiseAndInplaceContigFunctor,
334+
bitwise_and_inplace_contig_kernel>(exec_q, nelems, arg_p, arg_offset,
335+
res_p, res_offset, depends);
336+
}
337+
338+
template <typename fnT, typename T1, typename T2>
339+
struct BitwiseAndInplaceContigFactory
340+
{
341+
fnT get()
342+
{
343+
if constexpr (std::is_same_v<
344+
typename BitwiseAndOutputType<T1, T2>::value_type,
345+
void>)
346+
{
347+
fnT fn = nullptr;
348+
return fn;
349+
}
350+
else {
351+
fnT fn = bitwise_and_inplace_contig_impl<T1, T2>;
352+
return fn;
353+
}
354+
}
355+
};
356+
357+
template <typename resT, typename argT, typename IndexerT>
358+
class bitwise_and_inplace_strided_kernel;
359+
360+
template <typename argTy, typename resTy>
361+
sycl::event bitwise_and_inplace_strided_impl(
362+
sycl::queue &exec_q,
363+
size_t nelems,
364+
int nd,
365+
const py::ssize_t *shape_and_strides,
366+
const char *arg_p,
367+
py::ssize_t arg_offset,
368+
char *res_p,
369+
py::ssize_t res_offset,
370+
const std::vector<sycl::event> &depends,
371+
const std::vector<sycl::event> &additional_depends)
372+
{
373+
return elementwise_common::binary_inplace_strided_impl<
374+
argTy, resTy, BitwiseAndInplaceStridedFunctor,
375+
bitwise_and_inplace_strided_kernel>(
376+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
377+
res_offset, depends, additional_depends);
378+
}
379+
380+
template <typename fnT, typename T1, typename T2>
381+
struct BitwiseAndInplaceStridedFactory
382+
{
383+
fnT get()
384+
{
385+
if constexpr (std::is_same_v<
386+
typename BitwiseAndOutputType<T1, T2>::value_type,
387+
void>)
388+
{
389+
fnT fn = nullptr;
390+
return fn;
391+
}
392+
else {
393+
fnT fn = bitwise_and_inplace_strided_impl<T1, T2>;
394+
return fn;
395+
}
396+
}
397+
};
398+
260399
} // namespace bitwise_and
261400
} // namespace kernels
262401
} // namespace tensor

dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_invert.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
#include "utils/type_utils.hpp"
3636
#include <pybind11/pybind11.h>
3737

38+
#include "kernels/elementwise_functions/common.hpp"
39+
3840
namespace dpctl
3941
{
4042
namespace tensor

0 commit comments

Comments
 (0)