Skip to content

Commit c951d41

Browse files
authored
Allow different output type than input type when dispatching (#1590)
1 parent d4e3e79 commit c951d41

35 files changed

+75
-62
lines changed

dpnp/backend/extensions/vm/abs.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event abs_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::AbsOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::abs(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/acos.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event acos_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::AcosOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::acos(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/acosh.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event acosh_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::AcoshOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::acosh(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/add.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ sycl::event add_contig_impl(sycl::queue exec_q,
5050

5151
const T *a = reinterpret_cast<const T *>(in_a);
5252
const T *b = reinterpret_cast<const T *>(in_b);
53-
T *y = reinterpret_cast<T *>(out_y);
53+
using resTy = typename types::AddOutputType<T>::value_type;
54+
resTy *y = reinterpret_cast<resTy *>(out_y);
5455

5556
return mkl_vm::add(exec_q,
5657
n, // number of elements to be calculated

dpnp/backend/extensions/vm/asin.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event asin_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::AsinOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::asin(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/asinh.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event asinh_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::AsinhOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::asinh(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/atan.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event atan_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::AtanOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::atan(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/atan2.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ sycl::event atan2_contig_impl(sycl::queue exec_q,
5050

5151
const T *a = reinterpret_cast<const T *>(in_a);
5252
const T *b = reinterpret_cast<const T *>(in_b);
53-
T *y = reinterpret_cast<T *>(out_y);
53+
using resTy = typename types::Atan2OutputType<T>::value_type;
54+
resTy *y = reinterpret_cast<resTy *>(out_y);
5455

5556
return mkl_vm::atan2(exec_q,
5657
n, // number of elements to be calculated

dpnp/backend/extensions/vm/atanh.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event atanh_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::AtanhOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::atanh(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/ceil.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event ceil_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::CeilOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::ceil(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/common.hpp

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,8 @@ std::pair<sycl::event, sycl::event>
8282
{
8383
// check type_nums
8484
int src_typenum = src.get_typenum();
85-
int dst_typenum = dst.get_typenum();
86-
8785
auto array_types = dpctl_td_ns::usm_ndarray_types();
8886
int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
89-
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
90-
91-
if (src_typeid != dst_typeid) {
92-
throw py::value_error("Input and output arrays have different types.");
93-
}
9487

9588
// check that queues are compatible
9689
if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {
@@ -155,7 +148,7 @@ std::pair<sycl::event, sycl::event>
155148
throw py::value_error("Input and outpur arrays must be C-contiguous.");
156149
}
157150

158-
auto dispatch_fn = dispatch_vector[dst_typeid];
151+
auto dispatch_fn = dispatch_vector[src_typeid];
159152
if (dispatch_fn == nullptr) {
160153
throw py::value_error("No implementation is defined for ufunc.");
161154
}
@@ -179,16 +172,13 @@ std::pair<sycl::event, sycl::event> binary_ufunc(
179172
// check type_nums
180173
int src1_typenum = src1.get_typenum();
181174
int src2_typenum = src2.get_typenum();
182-
int dst_typenum = dst.get_typenum();
183175

184176
auto array_types = dpctl_td_ns::usm_ndarray_types();
185177
int src1_typeid = array_types.typenum_to_lookup_id(src1_typenum);
186178
int src2_typeid = array_types.typenum_to_lookup_id(src2_typenum);
187-
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
188179

189-
if (src1_typeid != src2_typeid || src2_typeid != dst_typeid) {
190-
throw py::value_error(
191-
"Either any of input arrays or output array have different types.");
180+
if (src1_typeid != src2_typeid) {
181+
throw py::value_error("Input arrays have different types.");
192182
}
193183

194184
// check that queues are compatible
@@ -259,7 +249,7 @@ std::pair<sycl::event, sycl::event> binary_ufunc(
259249
throw py::value_error("Input and outpur arrays must be C-contiguous.");
260250
}
261251

262-
auto dispatch_fn = dispatch_vector[dst_typeid];
252+
auto dispatch_fn = dispatch_vector[src1_typeid];
263253
if (dispatch_fn == nullptr) {
264254
throw py::value_error("No implementation is defined for ufunc.");
265255
}
@@ -279,16 +269,8 @@ bool need_to_call_unary_ufunc(sycl::queue exec_q,
279269
{
280270
// check type_nums
281271
int src_typenum = src.get_typenum();
282-
int dst_typenum = dst.get_typenum();
283-
284272
auto array_types = dpctl_td_ns::usm_ndarray_types();
285273
int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
286-
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
287-
288-
// types must be the same
289-
if (src_typeid != dst_typeid) {
290-
return false;
291-
}
292274

293275
// OneMKL VM functions perform a copy on host if no double type support
294276
if (!exec_q.get_device().has(sycl::aspect::fp64)) {
@@ -356,7 +338,7 @@ bool need_to_call_unary_ufunc(sycl::queue exec_q,
356338
}
357339

358340
// MKL function is not defined for the type
359-
if (dispatch_vector[dst_typeid] == nullptr) {
341+
if (dispatch_vector[src_typeid] == nullptr) {
360342
return false;
361343
}
362344
return true;
@@ -372,15 +354,13 @@ bool need_to_call_binary_ufunc(sycl::queue exec_q,
372354
// check type_nums
373355
int src1_typenum = src1.get_typenum();
374356
int src2_typenum = src2.get_typenum();
375-
int dst_typenum = dst.get_typenum();
376357

377358
auto array_types = dpctl_td_ns::usm_ndarray_types();
378359
int src1_typeid = array_types.typenum_to_lookup_id(src1_typenum);
379360
int src2_typeid = array_types.typenum_to_lookup_id(src2_typenum);
380-
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
381361

382362
// types must be the same
383-
if (src1_typeid != src2_typeid || src2_typeid != dst_typeid) {
363+
if (src1_typeid != src2_typeid) {
384364
return false;
385365
}
386366

@@ -454,7 +434,7 @@ bool need_to_call_binary_ufunc(sycl::queue exec_q,
454434
}
455435

456436
// MKL function is not defined for the type
457-
if (dispatch_vector[dst_typeid] == nullptr) {
437+
if (dispatch_vector[src1_typeid] == nullptr) {
458438
return false;
459439
}
460440
return true;

dpnp/backend/extensions/vm/conj.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event conj_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::ConjOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::conj(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/cos.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event cos_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::CosOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::cos(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/cosh.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event cosh_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::CoshOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::cosh(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/div.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ sycl::event div_contig_impl(sycl::queue exec_q,
5050

5151
const T *a = reinterpret_cast<const T *>(in_a);
5252
const T *b = reinterpret_cast<const T *>(in_b);
53-
T *y = reinterpret_cast<T *>(out_y);
53+
using resTy = typename types::DivOutputType<T>::value_type;
54+
resTy *y = reinterpret_cast<resTy *>(out_y);
5455

5556
return mkl_vm::div(exec_q,
5657
n, // number of elements to be calculated

dpnp/backend/extensions/vm/exp.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event exp_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::ExpOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::exp(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/expm1.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event expm1_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::Expm1OutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::expm1(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/floor.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event floor_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::FloorOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::floor(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/hypot.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ sycl::event hypot_contig_impl(sycl::queue exec_q,
5050

5151
const T *a = reinterpret_cast<const T *>(in_a);
5252
const T *b = reinterpret_cast<const T *>(in_b);
53-
T *y = reinterpret_cast<T *>(out_y);
53+
using resTy = typename types::HypotOutputType<T>::value_type;
54+
resTy *y = reinterpret_cast<resTy *>(out_y);
5455

5556
return mkl_vm::hypot(exec_q,
5657
n, // number of elements to be calculated

dpnp/backend/extensions/vm/ln.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event ln_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::LnOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::ln(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/log10.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event log10_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::Log10OutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::log10(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/log1p.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event log1p_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::Log1pOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::log1p(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/log2.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event log2_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::Log2OutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::log2(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/mul.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ sycl::event mul_contig_impl(sycl::queue exec_q,
5050

5151
const T *a = reinterpret_cast<const T *>(in_a);
5252
const T *b = reinterpret_cast<const T *>(in_b);
53-
T *y = reinterpret_cast<T *>(out_y);
53+
using resTy = typename types::MulOutputType<T>::value_type;
54+
resTy *y = reinterpret_cast<resTy *>(out_y);
5455

5556
return mkl_vm::mul(exec_q,
5657
n, // number of elements to be calculated

dpnp/backend/extensions/vm/pow.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ sycl::event pow_contig_impl(sycl::queue exec_q,
5050

5151
const T *a = reinterpret_cast<const T *>(in_a);
5252
const T *b = reinterpret_cast<const T *>(in_b);
53-
T *y = reinterpret_cast<T *>(out_y);
53+
using resTy = typename types::PowOutputType<T>::value_type;
54+
resTy *y = reinterpret_cast<resTy *>(out_y);
5455

5556
return mkl_vm::pow(exec_q,
5657
n, // number of elements to be calculated

dpnp/backend/extensions/vm/round.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event round_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::RoundOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::rint(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/sin.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event sin_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::SinOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::sin(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/sinh.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event sinh_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::SinhOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::sinh(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/sqr.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event sqr_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::SqrOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::sqr(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/sqrt.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event sqrt_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::SqrtOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::sqrt(exec_q,
5455
n, // number of elements to be calculated

0 commit comments

Comments
 (0)