Commit 71e891c

Check in of generic reduction templates and some reductions (#1399)
* Implements necessary sycl utilities for custom reductions
* Implements dpctl.tensor.max and dpctl.tensor.min
* Adds tests for min and max
* Reductions now set max_wg to the minimum of the max work group size and 2048
  - This prevents running out of resources when using local memory on CPU
* max and min nan propagation fixed for CPU devices
  - drops use of fetch_max/fetch_min for floats, which do not handle nans correctly
* Tweak to test_reduction_kernels
* Implements dpctl.tensor.argmax and argmin
* Tests for argmin and argmax
  Also fixes argmin and argmax for scalar inputs
* Argmin and argmax now handle identities correctly
  Adds a test for this behavior
  Fixed a typo in argmin and argmax causing shared local memory variant to be used for more types than expected
* Replaced `std::min` with `idx_reduction_op_`
* reductions now well-behaved for size-zero arrays
  - comparison and search reductions will throw an error in this case
  - slips in change to align sum signature with array API spec
* removed unnecessary copies in reduction templates
* Refactors sum to use generic reduction templates
* Sum now uses a generic Python API
* Docstrings added for argmax, argmin, max, and min
* Small reduction clean-ups
  Removed unnecessary copies in custom_reduce_over_group
  Sequential reduction now casts before calling operator (makes behavior explicit rather than implicit)
* Added test for argmin with keepdims=True
* Added a test for raised errors in reductions
  Also removed unused `_usm_types` in `test_tensor_sum`
* Removed `void` overloads from reduction utilities
  These were unused by dpctl
* Added missing include, Identity to use has_known_identity
  Implementation of Identity trait should call sycl::known_identity if trait sycl::has_known_identity is a true_type.
  Added IsMultiplies, and identity value for it, since sycl::known_identity for multiplies is only defined for real-valued types.
* Adding functor factories for product over axis
* Added Python API for _prod_over_axis
* Common reduction template takes functions to test if atomics are applicable
  Passing these function pointers around allows to turn atomic off altogether if desired.
  Use custom trait to check if reduce_over_groups can be used. This allows to work-around bug, or switch to custom code for reduction over group if desired.
  Such custom trait type works around issue with incorrect result returned from sycl::reduce_over_group for sycl::multiplies operator for 64-bit integral types.
* Defined dpctl.tensor.prod
  Also tweaked docstring for sum.
* Added tests for dpt.prod, removed uses of numpy
* Corrected prod docstring
  Small tweaks to sum, min, and max docstrings

---------

Co-authored-by: Oleksandr Pavlyk <oleksandr.pavlyk@intel.com>
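Two behaviours named in the message, NaN propagation in max/min and the handling of size-zero inputs, can be modelled in a few lines of plain Python. This is only a sketch of the semantics described above, not the SYCL implementation from this commit: the helper names `nan_propagating_max`, `reduce_max`, and `reduce_prod` are hypothetical, and the choice of `ValueError` for the empty case is an assumption (the message only says an error is thrown).

```python
import math
from functools import reduce


def nan_propagating_max(a, b):
    """Binary max that returns NaN as soon as either operand is NaN."""
    if math.isnan(a) or math.isnan(b):
        return math.nan
    return a if a > b else b


def reduce_max(values):
    """Comparison reduction: there is no safe identity, so empty input is an error.

    The exception type is an assumption made for this sketch.
    """
    values = list(values)
    if not values:
        raise ValueError("max reduction is undefined for zero-size input")
    return reduce(nan_propagating_max, values)


def reduce_prod(values):
    """Product reduction: the multiplicative identity 1 covers empty input."""
    return reduce(lambda x, y: x * y, values, 1)


# A purely comparison-based max silently drops a NaN that is not in first position,
# because every comparison against NaN is False.
print(max([1.0, math.nan]))                 # 1.0
print(reduce_max([1.0, math.nan, 3.0]))     # nan
print(reduce_prod([]))                      # 1
```

The built-in `max` call shows the kind of order-dependent result a comparison-only maximum gives in the presence of NaN, which is a rough analogue of the problem the message attributes to `fetch_max`/`fetch_min` on floats.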
1 parent caa0939 commit 71e891c

11 files changed: +3759 -434 lines changed


dpctl/tensor/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
@@ -49,8 +49,8 @@ pybind11_add_module(${python_module_name} MODULE
 ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/boolean_reductions.cpp
 ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp
 ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp
-${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sum_reductions.cpp
 ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/repeat.cpp
+${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reduction_over_axis.cpp
 )
 set(_clang_prefix "")
 if (WIN32)
@@ -60,6 +60,7 @@ set_source_files_properties(
 ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
 ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
 ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp
+${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reduction_over_axis.cpp
 PROPERTIES COMPILE_OPTIONS "${_clang_prefix}-fno-fast-math")
 if (UNIX)
 set_source_files_properties(

dpctl/tensor/__init__.py

Lines changed: 6 additions & 1 deletion
@@ -160,7 +160,7 @@
 tanh,
 trunc,
 )
-from ._reduction import sum
+from ._reduction import argmax, argmin, max, min, prod, sum
 from ._testing import allclose
 
 __all__ = [
@@ -309,4 +309,9 @@
 "allclose",
 "repeat",
 "tile",
+"max",
+"min",
+"argmax",
+"argmin",
+"prod",
 ]
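The two hunks above have to stay in step: every name newly imported from `._reduction` should also appear in `__all__`. A rough way to check that with the standard-library `ast` module is sketched below. The `INIT_EXCERPT` string and the helper `names_missing_from_all` are hypothetical and exist only for this illustration; the excerpt reproduces just the lines visible in the diff, not the real `__init__.py`.

```python
import ast

# Hypothetical excerpt assembled from the visible hunks only; the real file is longer.
INIT_EXCERPT = '''
from ._reduction import argmax, argmin, max, min, prod, sum
from ._testing import allclose

__all__ = [
    "allclose",
    "repeat",
    "tile",
    "max",
    "min",
    "argmax",
    "argmin",
    "prod",
]
'''


def names_missing_from_all(source, module):
    """Return names imported via `from .<module> import ...` that __all__ does not list."""
    imported, exported = set(), set()
    for node in ast.parse(source).body:
        if (
            isinstance(node, ast.ImportFrom)
            and node.level == 1          # single leading dot: relative import
            and node.module == module
        ):
            imported.update(alias.asname or alias.name for alias in node.names)
        elif isinstance(node, ast.Assign) and any(
            isinstance(target, ast.Name) and target.id == "__all__"
            for target in node.targets
        ):
            exported.update(ast.literal_eval(node.value))
    return sorted(imported - exported)


print(names_missing_from_all(INIT_EXCERPT, "_reduction"))  # ['sum']
```

`sum` is reported here only because the excerpt omits the portion of `__all__` that the diff does not show. All five names added by this commit are present, and checking `sum` would require running the helper against the full file.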
