Skip to content

MULTIPLY enable broadcasting #655

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 6, 2021
43 changes: 29 additions & 14 deletions dpnp/backend/kernels/dpnp_krnl_elemwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

#include <dpnp_iface.hpp>
#include "dpnp_fptr.hpp"
#include "dpnp_iterator.hpp"
#include "dpnp_utils.hpp"
#include "queue_sycl.hpp"

Expand Down Expand Up @@ -353,28 +354,41 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
const size_t* where) \
{ \
/* avoid warning unused variable*/ \
(void)input1_shape; \
(void)input1_shape_ndim; \
(void)input2_shape; \
(void)input2_shape_ndim; \
(void)where; \
\
if (!input1_size || !input2_size) \
{ \
return; \
} \
\
const size_t result_size = (input2_size > input1_size) ? input2_size : input1_size; \
\
const _DataType_input1* input1_data = reinterpret_cast<const _DataType_input1*>(input1_in); \
const _DataType_input2* input2_data = reinterpret_cast<const _DataType_input2*>(input2_in); \
_DataType_input1* input1_data = reinterpret_cast<_DataType_input1*>(const_cast<void*>(input1_in)); \
_DataType_input2* input2_data = reinterpret_cast<_DataType_input2*>(const_cast<void*>(input2_in)); \
_DataType_output* result = reinterpret_cast<_DataType_output*>(result_out); \
\
std::vector<size_t> result_shape = get_result_shape(input1_shape, input1_shape_ndim, \
input2_shape, input2_shape_ndim); \
\
DPNPC_id<_DataType_input1>* input1_it; \
const size_t input1_it_size_in_bytes = sizeof(DPNPC_id<_DataType_input1>); \
input1_it = reinterpret_cast<DPNPC_id<_DataType_input1>*>(dpnp_memory_alloc_c(input1_it_size_in_bytes)); \
new (input1_it) DPNPC_id<_DataType_input1>(input1_data, input1_shape, input1_shape_ndim); \
\
input1_it->broadcast_to_shape(result_shape); \
\
DPNPC_id<_DataType_input2>* input2_it; \
const size_t input2_it_size_in_bytes = sizeof(DPNPC_id<_DataType_input2>); \
input2_it = reinterpret_cast<DPNPC_id<_DataType_input2>*>(dpnp_memory_alloc_c(input2_it_size_in_bytes)); \
new (input2_it) DPNPC_id<_DataType_input2>(input2_data, input2_shape, input2_shape_ndim); \
\
input2_it->broadcast_to_shape(result_shape); \
\
const size_t result_size = input1_it->get_output_size(); \
\
cl::sycl::range<1> gws(result_size); \
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) { \
size_t i = global_id[0]; /*for (size_t i = 0; i < result_size; ++i)*/ \
const _DataType_output input1_elem = (input1_size == 1) ? input1_data[0] : input1_data[i]; \
const _DataType_output input2_elem = (input2_size == 1) ? input2_data[0] : input2_data[i]; \
const size_t i = global_id[0]; /*for (size_t i = 0; i < result_size; ++i)*/ \
const _DataType_output input1_elem = (*input1_it)[i]; \
const _DataType_output input2_elem = (*input2_it)[i]; \
result[i] = __operation1__; \
}; \
auto kernel_func = [&](cl::sycl::handler& cgh) { \
Expand All @@ -390,9 +404,7 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
std::is_same<_DataType_input1, float>::value) && \
std::is_same<_DataType_input2, _DataType_input1>::value) \
{ \
_DataType_input1* input1 = const_cast<_DataType_input1*>(input1_data); \
_DataType_input2* input2 = const_cast<_DataType_input2*>(input2_data); \
event = __operation2__(DPNP_QUEUE, result_size, input1, input2, result); \
event = __operation2__(DPNP_QUEUE, result_size, input1_data, input2_data, result); \
} \
else \
{ \
Expand All @@ -405,6 +417,9 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
} \
\
event.wait(); \
\
input1_it->~DPNPC_id(); \
input2_it->~DPNPC_id(); \
Comment on lines +421 to +422

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not needed here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needed, we are responsible for destructing placed objects.

Copy link

@samir-nasibli samir-nasibli Apr 2, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are obvious things that you are saying, but a little bit not about that.

The comment was not about freeing resources in general, it is about that we should avoid such explicitly call the destructor.
I see that the current iterator implementation has flaws; updates to the iterator interface are required for this.
In any case, this is not a problem for the current PR.

}

#include <dpnp_gen_2arg_3type_tbl.hpp>
Expand Down
99 changes: 99 additions & 0 deletions dpnp/backend/src/dpnp_iterator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,65 @@ class DPNPC_id final
return output_size;
}

/**
* @ingroup BACKEND_UTILS
* @brief Broadcast input data to specified shape.
*
* Set output shape to use in computation of input index by output index.
*
* @note this function is designed for non-SYCL environment execution
*
* @param [in] __shape Output shape.
*/
inline void broadcast_to_shape(const std::vector<size_type>& __shape)
{
    // Broadcast mode and reduction-axes mode are mutually exclusive: once
    // set_axes() has configured this object, broadcasting is ignored.
    if (axis_use)
    {
        return;
    }

    // Silently do nothing when the input shape cannot be broadcast to __shape.
    if (broadcastable(input_shape, input_shape_size, __shape))
    {
        // Drop any state left over from a previous broadcast_to_shape() call.
        free_broadcast_axes_memory();
        free_output_memory();

        std::vector<size_type> valid_axes;
        broadcast_use = true;

        output_shape_size = __shape.size();
        const size_type output_shape_size_in_bytes = output_shape_size * sizeof(size_type);
        output_shape = reinterpret_cast<size_type*>(dpnp_memory_alloc_c(output_shape_size_in_bytes));

        // Walk both shapes right-to-left (trailing dimensions aligned) and
        // record each output axis along which the input must be stretched,
        // i.e. where the input has no dimension (irit < 0) or a mismatching one.
        for (int irit = input_shape_size - 1, orit = output_shape_size - 1; orit >= 0; --irit, --orit)
        {
            output_shape[orit] = __shape[orit];

            // ex: input_shape = {7, 1, 5}, output_shape = {8, 7, 6, 5} => valid_axes = {0, 2}
            if (irit < 0 || input_shape[irit] != output_shape[orit])
            {
                valid_axes.insert(valid_axes.begin(), orit);
            }
        }

        broadcast_axes_size = valid_axes.size();
        const size_type broadcast_axes_size_in_bytes = broadcast_axes_size * sizeof(size_type);
        broadcast_axes = reinterpret_cast<size_type*>(dpnp_memory_alloc_c(broadcast_axes_size_in_bytes));
        std::copy(valid_axes.begin(), valid_axes.end(), broadcast_axes);

        // Total number of output elements = product of all output dimensions.
        output_size = std::accumulate(
            output_shape, output_shape + output_shape_size, size_type(1), std::multiplies<size_type>());

        output_shape_strides = reinterpret_cast<size_type*>(dpnp_memory_alloc_c(output_shape_size_in_bytes));
        get_shape_offsets_inkernel<size_type>(output_shape, output_shape_size, output_shape_strides);

        iteration_size = 1;

        // make thread private storage for each shape by multiplying memory
        // (one scratch index tuple per output element; presumably consumed by
        // the index-mapping code on the SYCL side — confirm against callers)
        sycl_output_xyz =
            reinterpret_cast<size_type*>(dpnp_memory_alloc_c(output_size * output_shape_size_in_bytes));
    }
}

/**
* @ingroup BACKEND_UTILS
* @brief Set axis for the data object to use in computation.
Expand Down Expand Up @@ -285,6 +344,11 @@ class DPNPC_id final
*/
inline void set_axes(const std::vector<long>& __axes)
{
if (broadcast_use)
{
return;
}

if (!__axes.empty() && input_shape_size)
{
free_axes_memory();
Expand Down Expand Up @@ -368,6 +432,11 @@ class DPNPC_id final
/// this function is designed for SYCL environment execution
inline reference operator[](size_type __n) const
{
    if (!broadcast_use)
    {
        // plain mode: index through a default iterator
        const iterator it = begin();
        return it[__n];
    }

    // broadcast mode: map the flat output index directly via begin(__n)
    return *begin(__n);
}
Expand Down Expand Up @@ -430,6 +499,24 @@ class DPNPC_id final
}
}
}
else if (broadcast_use)
{
assert(output_global_id < output_size);

// use thread private storage
size_type* sycl_output_xyz_thread = sycl_output_xyz + (output_global_id * output_shape_size);

get_xyz_by_id_inkernel(output_global_id, output_shape_strides, output_shape_size, sycl_output_xyz_thread);

for (int irit = input_shape_size - 1, orit = output_shape_size - 1; irit >= 0; --irit, --orit)
{
size_type* broadcast_axes_end = broadcast_axes + broadcast_axes_size;
if (std::find(broadcast_axes, broadcast_axes_end, orit) == broadcast_axes_end)
{
input_global_id += (sycl_output_xyz_thread[orit] * input_shape_strides[irit]);
}
}
}

return input_global_id;
}
Expand All @@ -447,6 +534,13 @@ class DPNPC_id final
axes_shape_strides = nullptr;
}

/// Release the broadcast-axes buffer and reset the related bookkeeping.
void free_broadcast_axes_memory()
{
    dpnp_memory_free_c(broadcast_axes);
    broadcast_axes = nullptr;
    broadcast_axes_size = size_type{};
}

void free_input_memory()
{
input_size = size_type{};
Expand Down Expand Up @@ -480,6 +574,7 @@ class DPNPC_id final
void free_memory()
{
free_axes_memory();
free_broadcast_axes_memory();
free_input_memory();
free_iteration_memory();
free_output_memory();
Expand All @@ -494,6 +589,10 @@ class DPNPC_id final
std::vector<size_type> axes; /**< input shape reduction axes */
bool axis_use = false;

size_type* broadcast_axes = nullptr; /**< input shape broadcast axes */
size_type broadcast_axes_size = size_type{}; /**< input shape broadcast axes size */
bool broadcast_use = false;

size_type output_size = size_type{}; /**< output array size. Expected is same as GWS */
size_type* output_shape = nullptr; /**< output array shape */
size_type output_shape_size = size_type{}; /**< output array shape size */
Expand Down
84 changes: 84 additions & 0 deletions dpnp/backend/src/dpnp_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,90 @@ size_t get_id_by_xyz_inkernel(const _DataType* xyz, size_t xyz_size, const _Data
return global_id;
}

/**
* @ingroup BACKEND_UTILS
* @brief Check input shape is broadcastable to output one.
*
* @param [in] input_shape Input shape.
* @param [in] output_shape Output shape.
*
* @return Input shape is broadcastable to output one or not.
*/
static inline bool
broadcastable(const std::vector<size_t>& input_shape, const std::vector<size_t>& output_shape)
{
if (input_shape.size() > output_shape.size())
{
return false;
}

std::vector<size_t>::const_reverse_iterator irit = input_shape.rbegin();
std::vector<size_t>::const_reverse_iterator orit = output_shape.rbegin();
for (; irit != input_shape.rend(); ++irit, ++orit)
{
if (*irit != 1 && *irit != *orit)
{
return false;
}
}

return true;
}

static inline bool
broadcastable(const size_t* input_shape, const size_t input_shape_size, const std::vector<size_t>& output_shape)
{
const std::vector<size_t> input_shape_vec(input_shape, input_shape + input_shape_size);
return broadcastable(input_shape_vec, output_shape);
}

/**
 * @ingroup BACKEND_UTILS
 * @brief Get common shape based on input shapes.
 *
 * Example:
 *   Input1 shape A[8, 1, 6, 1]
 *   Input2 shape B[7, 1, 5]
 *   Output shape will be C[8, 7, 6, 5]
 *
 * @param [in] input1_shape      Input1 shape.
 * @param [in] input1_shape_size Input1 shape size.
 * @param [in] input2_shape      Input2 shape.
 * @param [in] input2_shape_size Input2 shape size.
 *
 * @exception std::domain_error Input shapes are not broadcastable.
 * @return Common shape.
 */
static inline std::vector<size_t>
get_result_shape(const size_t* input1_shape, const size_t input1_shape_size,
                 const size_t* input2_shape, const size_t input2_shape_size)
{
    const size_t result_shape_size = (input2_shape_size > input1_shape_size) ? input2_shape_size : input1_shape_size;
    // Size the result up front and fill by index; repeatedly inserting at
    // begin() would shift all elements on every step (O(n^2)).
    std::vector<size_t> result_shape(result_shape_size);

    // Align trailing dimensions right-to-left; a missing dimension counts as 1
    // per the broadcasting rules.
    for (int irit1 = input1_shape_size - 1, irit2 = input2_shape_size - 1, orit = result_shape_size - 1; orit >= 0;
         --irit1, --irit2, --orit)
    {
        const size_t input1_val = (irit1 >= 0) ? input1_shape[irit1] : 1;
        const size_t input2_val = (irit2 >= 0) ? input2_shape[irit2] : 1;

        if (input1_val == input2_val || input1_val == 1)
        {
            result_shape[orit] = input2_val;
        }
        else if (input2_val == 1)
        {
            result_shape[orit] = input1_val;
        }
        else
        {
            // name the actual function in the diagnostic (was "get_common_shape()")
            throw std::domain_error("DPNP Error: get_result_shape() failed with input shapes check");
        }
    }

    return result_shape;
}

/**
* @ingroup BACKEND_UTILS
* @brief Normalizes an axes into a non-negative integer axes.
Expand Down
1 change: 1 addition & 0 deletions dpnp/backend/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ link_directories(${GTEST_LIB_DIR})

# TODO split
add_executable(dpnpc_tests
test_broadcast_iterator.cpp
test_main.cpp
test_random.cpp
test_utils.cpp
Expand Down
30 changes: 30 additions & 0 deletions dpnp/backend/tests/dpnp_test_utils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#include <iostream>
#include <vector>

#include "dpnp_iterator.hpp"

using namespace std;
using dpnpc_it_t = DPNPC_id<size_t>::iterator;
using dpnpc_value_t = dpnpc_it_t::value_type;
using dpnpc_index_t = dpnpc_it_t::size_type;

/// Build a test data vector for the given shape, filled with 1, 2, 3, ...
template <typename _DataType>
vector<_DataType> get_input_data(const vector<dpnpc_index_t>& shape)
{
    dpnpc_index_t total_elements = 1;
    for (const dpnpc_index_t& dim : shape)
    {
        total_elements *= dim;
    }

    vector<_DataType> input_data(total_elements);
    iota(input_data.begin(), input_data.end(), 1); // start from 1 to avoid cleaned memory comparison

    return input_data;
}

/// Copy host-side test data into a freshly allocated device-accessible buffer.
template <typename _DataType>
_DataType* get_shared_data(const vector<_DataType>& input_data)
{
    const size_t data_size_in_bytes = input_data.size() * sizeof(_DataType);
    _DataType* shared_data = reinterpret_cast<_DataType*>(dpnp_memory_alloc_c(data_size_in_bytes));

    copy(input_data.cbegin(), input_data.cend(), shared_data);

    return shared_data;
}
Loading