diff --git a/dpnp/backend/src/dpnp_iterator.hpp b/dpnp/backend/src/dpnp_iterator.hpp index 65336649dff3..9ab0132c79cd 100644 --- a/dpnp/backend/src/dpnp_iterator.hpp +++ b/dpnp/backend/src/dpnp_iterator.hpp @@ -351,13 +351,21 @@ class DPNPC_id final free_iteration_memory(); free_output_memory(); - axes = get_validated_axes(__axes, input_shape_size); + std::vector valid_axes = get_validated_axes(__axes, input_shape_size); axis_use = true; - output_shape_size = input_shape_size - axes.size(); + axes_size = valid_axes.size(); + const size_type axes_size_in_bytes = axes_size * sizeof(size_type); + axes = reinterpret_cast(dpnp_memory_alloc_c(axes_size_in_bytes)); + for (size_type i = 0; i < axes_size; ++i) + { + axes[i] = valid_axes[i]; + } + + output_shape_size = input_shape_size - axes_size; const size_type output_shape_size_in_bytes = output_shape_size * sizeof(size_type); - iteration_shape_size = axes.size(); + iteration_shape_size = axes_size; const size_type iteration_shape_size_in_bytes = iteration_shape_size * sizeof(size_type); std::vector iteration_shape; @@ -365,7 +373,7 @@ class DPNPC_id final size_type* output_shape_it = output_shape; for (size_type i = 0; i < input_shape_size; ++i) { - if (std::find(axes.begin(), axes.end(), i) == axes.end()) + if (std::find(valid_axes.begin(), valid_axes.end(), i) == valid_axes.end()) { *output_shape_it = input_shape[i]; ++output_shape_it; @@ -380,7 +388,7 @@ class DPNPC_id final iteration_size = 1; iteration_shape.reserve(iteration_shape_size); - for (const auto& axis : axes) + for (const auto& axis : valid_axes) { const size_type axis_dim = input_shape[axis]; iteration_shape.push_back(axis_dim); @@ -479,7 +487,7 @@ class DPNPC_id final for (size_t iit = 0, oit = 0; iit < input_shape_size; ++iit) { - if (std::find(axes.begin(), axes.end(), iit) == axes.end()) + if (std::find(axes, axes + axes_size, iit) == axes + axes_size) { const size_type output_xyz_id = get_xyz_id_by_id_inkernel(output_global_id, output_shape_strides, output_shape_size, oit); @@ -516,8 +524,10 @@ class DPNPC_id final void free_axes_memory() { - axes.clear(); + axes_size = size_type{}; + dpnp_memory_free_c(axes); dpnp_memory_free_c(axes_shape_strides); + axes = nullptr; axes_shape_strides = nullptr; } @@ -571,7 +581,8 @@ class DPNPC_id final size_type input_shape_size = size_type{}; /**< input array shape size */ size_type* input_shape_strides = nullptr; /**< input array shape strides (same size as input_shape) */ - std::vector axes; /**< input shape reduction axes */ + size_type* axes = nullptr; /**< input shape reduction axes */ + size_type axes_size = size_type{}; /**< input shape reduction axes size */ bool axis_use = false; size_type* broadcast_axes = nullptr; /**< input shape broadcast axes */ diff --git a/dpnp/backend/tests/test_utils_iterator.cpp b/dpnp/backend/tests/test_utils_iterator.cpp index 569555e7b773..5bdd562eeaa0 100644 --- a/dpnp/backend/tests/test_utils_iterator.cpp +++ b/dpnp/backend/tests/test_utils_iterator.cpp @@ -24,6 +24,7 @@ //***************************************************************************** #include +#include #include #include @@ -35,6 +36,19 @@ using namespace std; +template +_DataType* get_shared_data(const vector<_DataType>& input_data) +{ + const size_t data_size_in_bytes = input_data.size() * sizeof(_DataType); + _DataType* shared_data = reinterpret_cast<_DataType*>(dpnp_memory_alloc_c(data_size_in_bytes)); + for (size_t i = 0; i < input_data.size(); ++i) + { + shared_data[i] = input_data[i]; + } + + return shared_data; +} + TEST(TestUtilsIterator, begin_prefix_postfix) { using test_it = dpnpc_it_t; @@ -306,6 +320,43 @@ TEST(TestUtilsIterator, iterator_distance) EXPECT_EQ(axis_1_1_diff_distance, 4); } +TEST(TestUtilsIterator, sycl_getitem) +{ + using data_type = double; + + const dpnpc_index_t result_size = 12; + data_type* result = reinterpret_cast(dpnp_memory_alloc_c(result_size * sizeof(data_type))); + + vector input_data = get_input_data({3, 4}); + data_type* shared_data = get_shared_data(input_data); + + DPNPC_id* input_it; + input_it = reinterpret_cast*>(dpnp_memory_alloc_c(sizeof(DPNPC_id))); + new (input_it) DPNPC_id(shared_data, {3, 4}); + + cl::sycl::range<1> gws(result_size); + auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) { + const size_t idx = global_id[0]; + result[idx] = (*input_it)[idx]; + }; + + auto kernel_func = [&](cl::sycl::handler& cgh) { + cgh.parallel_for(gws, kernel_parallel_for_func); + }; + + cl::sycl::event event = DPNP_QUEUE.submit(kernel_func); + event.wait(); + + for (dpnpc_index_t i = 0; i < result_size; ++i) + { + EXPECT_EQ(result[i], shared_data[i]); + } + + input_it->~DPNPC_id(); + dpnp_memory_free_c(shared_data); + dpnp_memory_free_c(result); +} + struct IteratorParameters { vector input_shape; @@ -381,17 +432,20 @@ TEST_P(IteratorReduction, sycl_reduce_axis) const IteratorParameters& param = GetParam(); const dpnpc_index_t result_size = param.result.size(); - vector result(result_size, 42); - data_type* result_ptr = result.data(); + data_type* result = reinterpret_cast(dpnp_memory_alloc_c(result_size * sizeof(data_type))); vector input_data = get_input_data(param.input_shape); - DPNPC_id input(input_data.data(), param.input_shape); - input.set_axes(param.axes); + data_type* shared_data = get_shared_data(input_data); - ASSERT_EQ(input.get_output_size(), result_size); + DPNPC_id* input_it; + input_it = reinterpret_cast*>(dpnp_memory_alloc_c(sizeof(DPNPC_id))); + new (input_it) DPNPC_id(shared_data, param.input_shape); + + input_it->set_axes(param.axes); + + ASSERT_EQ(input_it->get_output_size(), result_size); cl::sycl::range<1> gws(result_size); - const DPNPC_id* input_it = &input; auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) { const size_t idx = global_id[0]; @@ -400,7 +454,7 @@ TEST_P(IteratorReduction, sycl_reduce_axis) { accumulator += *data_it; } - result_ptr[idx] = accumulator; + result[idx] = accumulator; }; auto kernel_func = [&](cl::sycl::handler& cgh) { @@ -412,8 +466,12 @@ TEST_P(IteratorReduction, sycl_reduce_axis) for (dpnpc_index_t i = 0; i < result_size; ++i) { - EXPECT_EQ(result.at(i), param.result.at(i)); + EXPECT_EQ(result[i], param.result.at(i)); } + + input_it->~DPNPC_id(); + dpnp_memory_free_c(shared_data); + dpnp_memory_free_c(result); } /**