Skip to content

Commit 21dc351

Browse files
authored
MULTIPLY enable broadcasting (#655)
* MULTIPLY enable broadcasting
1 parent 7ab68ea commit 21dc351

File tree

8 files changed

+397
-32
lines changed

8 files changed

+397
-32
lines changed

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
#include <dpnp_iface.hpp>
3030
#include "dpnp_fptr.hpp"
31+
#include "dpnp_iterator.hpp"
3132
#include "dpnp_utils.hpp"
3233
#include "queue_sycl.hpp"
3334

@@ -353,28 +354,41 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
353354
const size_t* where) \
354355
{ \
355356
/* avoid warning unused variable*/ \
356-
(void)input1_shape; \
357-
(void)input1_shape_ndim; \
358-
(void)input2_shape; \
359-
(void)input2_shape_ndim; \
360357
(void)where; \
361358
\
362359
if (!input1_size || !input2_size) \
363360
{ \
364361
return; \
365362
} \
366363
\
367-
const size_t result_size = (input2_size > input1_size) ? input2_size : input1_size; \
368-
\
369-
const _DataType_input1* input1_data = reinterpret_cast<const _DataType_input1*>(input1_in); \
370-
const _DataType_input2* input2_data = reinterpret_cast<const _DataType_input2*>(input2_in); \
364+
_DataType_input1* input1_data = reinterpret_cast<_DataType_input1*>(const_cast<void*>(input1_in)); \
365+
_DataType_input2* input2_data = reinterpret_cast<_DataType_input2*>(const_cast<void*>(input2_in)); \
371366
_DataType_output* result = reinterpret_cast<_DataType_output*>(result_out); \
372367
\
368+
std::vector<size_t> result_shape = get_result_shape(input1_shape, input1_shape_ndim, \
369+
input2_shape, input2_shape_ndim); \
370+
\
371+
DPNPC_id<_DataType_input1>* input1_it; \
372+
const size_t input1_it_size_in_bytes = sizeof(DPNPC_id<_DataType_input1>); \
373+
input1_it = reinterpret_cast<DPNPC_id<_DataType_input1>*>(dpnp_memory_alloc_c(input1_it_size_in_bytes)); \
374+
new (input1_it) DPNPC_id<_DataType_input1>(input1_data, input1_shape, input1_shape_ndim); \
375+
\
376+
input1_it->broadcast_to_shape(result_shape); \
377+
\
378+
DPNPC_id<_DataType_input2>* input2_it; \
379+
const size_t input2_it_size_in_bytes = sizeof(DPNPC_id<_DataType_input2>); \
380+
input2_it = reinterpret_cast<DPNPC_id<_DataType_input2>*>(dpnp_memory_alloc_c(input2_it_size_in_bytes)); \
381+
new (input2_it) DPNPC_id<_DataType_input2>(input2_data, input2_shape, input2_shape_ndim); \
382+
\
383+
input2_it->broadcast_to_shape(result_shape); \
384+
\
385+
const size_t result_size = input1_it->get_output_size(); \
386+
\
373387
cl::sycl::range<1> gws(result_size); \
374388
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) { \
375-
size_t i = global_id[0]; /*for (size_t i = 0; i < result_size; ++i)*/ \
376-
const _DataType_output input1_elem = (input1_size == 1) ? input1_data[0] : input1_data[i]; \
377-
const _DataType_output input2_elem = (input2_size == 1) ? input2_data[0] : input2_data[i]; \
389+
const size_t i = global_id[0]; /*for (size_t i = 0; i < result_size; ++i)*/ \
390+
const _DataType_output input1_elem = (*input1_it)[i]; \
391+
const _DataType_output input2_elem = (*input2_it)[i]; \
378392
result[i] = __operation1__; \
379393
}; \
380394
auto kernel_func = [&](cl::sycl::handler& cgh) { \
@@ -390,9 +404,7 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
390404
std::is_same<_DataType_input1, float>::value) && \
391405
std::is_same<_DataType_input2, _DataType_input1>::value) \
392406
{ \
393-
_DataType_input1* input1 = const_cast<_DataType_input1*>(input1_data); \
394-
_DataType_input2* input2 = const_cast<_DataType_input2*>(input2_data); \
395-
event = __operation2__(DPNP_QUEUE, result_size, input1, input2, result); \
407+
event = __operation2__(DPNP_QUEUE, result_size, input1_data, input2_data, result); \
396408
} \
397409
else \
398410
{ \
@@ -405,6 +417,9 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
405417
} \
406418
\
407419
event.wait(); \
420+
\
421+
input1_it->~DPNPC_id(); \
422+
input2_it->~DPNPC_id(); \
408423
}
409424

410425
#include <dpnp_gen_2arg_3type_tbl.hpp>

dpnp/backend/src/dpnp_iterator.hpp

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,65 @@ class DPNPC_id final
238238
return output_size;
239239
}
240240

241+
/**
242+
* @ingroup BACKEND_UTILS
243+
* @brief Broadcast input data to specified shape.
244+
*
245+
* Set output shape to use in computation of input index by output index.
246+
*
247+
* @note this function is designed for non-SYCL environment execution
248+
*
249+
* @param [in] __shape Output shape.
250+
*/
251+
inline void broadcast_to_shape(const std::vector<size_type>& __shape)
252+
{
253+
if (axis_use)
254+
{
255+
return;
256+
}
257+
258+
if (broadcastable(input_shape, input_shape_size, __shape))
259+
{
260+
free_broadcast_axes_memory();
261+
free_output_memory();
262+
263+
std::vector<size_type> valid_axes;
264+
broadcast_use = true;
265+
266+
output_shape_size = __shape.size();
267+
const size_type output_shape_size_in_bytes = output_shape_size * sizeof(size_type);
268+
output_shape = reinterpret_cast<size_type*>(dpnp_memory_alloc_c(output_shape_size_in_bytes));
269+
270+
for (int irit = input_shape_size - 1, orit = output_shape_size - 1; orit >= 0; --irit, --orit)
271+
{
272+
output_shape[orit] = __shape[orit];
273+
274+
// ex: input_shape = {7, 1, 5}, output_shape = {8, 7, 6, 5} => valid_axes = {0, 2}
275+
if (irit < 0 || input_shape[irit] != output_shape[orit])
276+
{
277+
valid_axes.insert(valid_axes.begin(), orit);
278+
}
279+
}
280+
281+
broadcast_axes_size = valid_axes.size();
282+
const size_type broadcast_axes_size_in_bytes = broadcast_axes_size * sizeof(size_type);
283+
broadcast_axes = reinterpret_cast<size_type*>(dpnp_memory_alloc_c(broadcast_axes_size_in_bytes));
284+
std::copy(valid_axes.begin(), valid_axes.end(), broadcast_axes);
285+
286+
output_size = std::accumulate(
287+
output_shape, output_shape + output_shape_size, size_type(1), std::multiplies<size_type>());
288+
289+
output_shape_strides = reinterpret_cast<size_type*>(dpnp_memory_alloc_c(output_shape_size_in_bytes));
290+
get_shape_offsets_inkernel<size_type>(output_shape, output_shape_size, output_shape_strides);
291+
292+
iteration_size = 1;
293+
294+
// make thread private storage for each shape by multiplying memory
295+
sycl_output_xyz =
296+
reinterpret_cast<size_type*>(dpnp_memory_alloc_c(output_size * output_shape_size_in_bytes));
297+
}
298+
}
299+
241300
/**
242301
* @ingroup BACKEND_UTILS
243302
* @brief Set axis for the data object to use in computation.
@@ -285,6 +344,11 @@ class DPNPC_id final
285344
*/
286345
inline void set_axes(const std::vector<long>& __axes)
287346
{
347+
if (broadcast_use)
348+
{
349+
return;
350+
}
351+
288352
if (!__axes.empty() && input_shape_size)
289353
{
290354
free_axes_memory();
@@ -368,6 +432,11 @@ class DPNPC_id final
368432
/// this function is designed for SYCL environment execution
369433
inline reference operator[](size_type __n) const
370434
{
435+
if (broadcast_use)
436+
{
437+
return *begin(__n);
438+
}
439+
371440
const iterator it = begin();
372441
return it[__n];
373442
}
@@ -430,6 +499,24 @@ class DPNPC_id final
430499
}
431500
}
432501
}
502+
else if (broadcast_use)
503+
{
504+
assert(output_global_id < output_size);
505+
506+
// use thread private storage
507+
size_type* sycl_output_xyz_thread = sycl_output_xyz + (output_global_id * output_shape_size);
508+
509+
get_xyz_by_id_inkernel(output_global_id, output_shape_strides, output_shape_size, sycl_output_xyz_thread);
510+
511+
for (int irit = input_shape_size - 1, orit = output_shape_size - 1; irit >= 0; --irit, --orit)
512+
{
513+
size_type* broadcast_axes_end = broadcast_axes + broadcast_axes_size;
514+
if (std::find(broadcast_axes, broadcast_axes_end, orit) == broadcast_axes_end)
515+
{
516+
input_global_id += (sycl_output_xyz_thread[orit] * input_shape_strides[irit]);
517+
}
518+
}
519+
}
433520

434521
return input_global_id;
435522
}
@@ -447,6 +534,13 @@ class DPNPC_id final
447534
axes_shape_strides = nullptr;
448535
}
449536

537+
/// Reset the broadcast bookkeeping: zero the stored axes count and release
/// the axes buffer allocated in broadcast_to_shape().
void free_broadcast_axes_memory()
{
    broadcast_axes_size = size_type{};
    dpnp_memory_free_c(broadcast_axes);
    broadcast_axes = nullptr;
}
543+
450544
void free_input_memory()
451545
{
452546
input_size = size_type{};
@@ -480,6 +574,7 @@ class DPNPC_id final
480574
void free_memory()
481575
{
482576
free_axes_memory();
577+
free_broadcast_axes_memory();
483578
free_input_memory();
484579
free_iteration_memory();
485580
free_output_memory();
@@ -494,6 +589,10 @@ class DPNPC_id final
494589
std::vector<size_type> axes; /**< input shape reduction axes */
495590
bool axis_use = false;
496591

592+
size_type* broadcast_axes = nullptr; /**< input shape broadcast axes */
593+
size_type broadcast_axes_size = size_type{}; /**< input shape broadcast axes size */
594+
bool broadcast_use = false;
595+
497596
size_type output_size = size_type{}; /**< output array size. Expected is same as GWS */
498597
size_type* output_shape = nullptr; /**< output array shape */
499598
size_type output_shape_size = size_type{}; /**< output array shape size */

dpnp/backend/src/dpnp_utils.hpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,90 @@ size_t get_id_by_xyz_inkernel(const _DataType* xyz, size_t xyz_size, const _Data
127127
return global_id;
128128
}
129129

130+
/**
 * @ingroup BACKEND_UTILS
 * @brief Check input shape is broadcastable to output one.
 *
 * Shapes are compared aligned at their trailing dimensions: every input
 * dimension must either equal the corresponding output dimension or be 1.
 *
 * @param [in] input_shape  Input shape.
 * @param [in] output_shape Output shape.
 *
 * @return Input shape is broadcastable to output one or not.
 */
static inline bool
broadcastable(const std::vector<size_t>& input_shape, const std::vector<size_t>& output_shape)
{
    const size_t input_ndim = input_shape.size();
    const size_t output_ndim = output_shape.size();

    // An input with more dimensions than the output can never be broadcast.
    if (input_ndim > output_ndim)
    {
        return false;
    }

    for (size_t offset = 0; offset < input_ndim; ++offset)
    {
        const size_t input_dim = input_shape[input_ndim - 1 - offset];
        const size_t output_dim = output_shape[output_ndim - 1 - offset];

        if ((input_dim != output_dim) && (input_dim != 1))
        {
            return false;
        }
    }

    return true;
}
159+
160+
static inline bool
161+
broadcastable(const size_t* input_shape, const size_t input_shape_size, const std::vector<size_t>& output_shape)
162+
{
163+
const std::vector<size_t> input_shape_vec(input_shape, input_shape + input_shape_size);
164+
return broadcastable(input_shape_vec, output_shape);
165+
}
166+
167+
/**
 * @ingroup BACKEND_UTILS
 * @brief Get common shape based on input shapes.
 *
 * Example:
 *   Input1 shape A[8, 1, 6, 1]
 *   Input2 shape B[7, 1, 5]
 *   Output shape will be C[8, 7, 6, 5]
 *
 * @param [in] input1_shape      Input1 shape.
 * @param [in] input1_shape_size Input1 shape size.
 * @param [in] input2_shape      Input2 shape.
 * @param [in] input2_shape_size Input2 shape size.
 *
 * @exception std::domain_error Input shapes are not broadcastable.
 * @return Common shape.
 */
static inline std::vector<size_t>
get_result_shape(const size_t* input1_shape, const size_t input1_shape_size,
                 const size_t* input2_shape, const size_t input2_shape_size)
{
    const size_t result_shape_size = (input2_shape_size > input1_shape_size) ? input2_shape_size : input1_shape_size;
    // Pre-size the result and fill it back-to-front by index; this avoids the
    // O(n^2) cost of repeatedly inserting at the front of the vector.
    std::vector<size_t> result_shape(result_shape_size);

    // Align both shapes at their trailing dimensions; a missing leading
    // dimension behaves as extent 1 (standard broadcasting rules).
    for (int irit1 = input1_shape_size - 1, irit2 = input2_shape_size - 1, orit = result_shape_size - 1; orit >= 0;
         --irit1, --irit2, --orit)
    {
        const size_t input1_val = (irit1 >= 0) ? input1_shape[irit1] : 1;
        const size_t input2_val = (irit2 >= 0) ? input2_shape[irit2] : 1;

        if (input1_val == input2_val || input1_val == 1)
        {
            result_shape[orit] = input2_val;
        }
        else if (input2_val == 1)
        {
            result_shape[orit] = input1_val;
        }
        else
        {
            // Fixed: the message previously referenced a non-existent
            // get_common_shape(); report the actual function name.
            throw std::domain_error("DPNP Error: get_result_shape() failed with input shapes check");
        }
    }

    return result_shape;
}
213+
130214
/**
131215
* @ingroup BACKEND_UTILS
132216
* @brief Normalizes an axes into a non-negative integer axes.

dpnp/backend/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ link_directories(${GTEST_LIB_DIR})
4646

4747
# TODO split
4848
add_executable(dpnpc_tests
49+
test_broadcast_iterator.cpp
4950
test_main.cpp
5051
test_random.cpp
5152
test_utils.cpp
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#include <iostream>
2+
#include <vector>
3+
4+
#include "dpnp_iterator.hpp"
5+
6+
using namespace std;
7+
using dpnpc_it_t = DPNPC_id<size_t>::iterator;
8+
using dpnpc_value_t = dpnpc_it_t::value_type;
9+
using dpnpc_index_t = dpnpc_it_t::size_type;
10+
11+
template <typename _DataType>
12+
vector<_DataType> get_input_data(const vector<dpnpc_index_t>& shape)
13+
{
14+
const dpnpc_index_t size = accumulate(shape.begin(), shape.end(), dpnpc_index_t(1), multiplies<dpnpc_index_t>());
15+
16+
vector<_DataType> input_data(size);
17+
iota(input_data.begin(), input_data.end(), 1); // let's start from 1 to avoid cleaned memory comparison
18+
19+
return input_data;
20+
}
21+
22+
template <typename _DataType>
23+
_DataType* get_shared_data(const vector<_DataType>& input_data)
24+
{
25+
const size_t data_size_in_bytes = input_data.size() * sizeof(_DataType);
26+
_DataType* shared_data = reinterpret_cast<_DataType*>(dpnp_memory_alloc_c(data_size_in_bytes));
27+
copy(input_data.begin(), input_data.end(), shared_data);
28+
29+
return shared_data;
30+
}

0 commit comments

Comments
 (0)