Skip to content

Commit 36ebb57

Browse files
Expand _divide_by_scalar to work with an FP denominator.
1 parent 361d03d commit 36ebb57

File tree

1 file changed

+54
-15
lines changed

1 file changed

+54
-15
lines changed

dpctl/tensor/libtensor/source/elementwise_functions/true_divide.cpp

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,13 @@
2424
//===----------------------------------------------------------------------===//
2525

2626
#include "dpctl4pybind11.hpp"
27+
#include <complex>
2728
#include <cstdint>
2829
#include <pybind11/numpy.h>
2930
#include <pybind11/pybind11.h>
3031
#include <pybind11/stl.h>
3132
#include <sycl/sycl.hpp>
33+
#include <utility>
3234
#include <vector>
3335

3436
#include "elementwise_functions.hpp"
@@ -179,28 +181,31 @@ typedef sycl::event (*divide_by_scalar_fn_ptr_t)(
179181
const ssize_t *,
180182
const char *,
181183
py::ssize_t,
182-
std::int64_t,
184+
const char *,
183185
char *,
184186
py::ssize_t,
185187
const std::vector<sycl::event> &);
186188

187-
template <typename T>
189+
template <typename T, typename scalarT>
188190
sycl::event divide_by_scalar(sycl::queue &exec_q,
189191
size_t nelems,
190192
int nd,
191193
const ssize_t *shape_and_strides,
192194
const char *arg_p,
193195
py::ssize_t arg_offset,
194-
std::int64_t scalar,
196+
const char *scalar_ptr,
195197
char *res_p,
196198
py::ssize_t res_offset,
197199
const std::vector<sycl::event> &depends = {})
198200
{
201+
const scalarT sc_v = *reinterpret_cast<const scalarT *>(scalar_ptr);
202+
199203
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
200204
cgh.depends_on(depends);
201205

202-
using BinOpT = dpctl::tensor::kernels::true_divide::TrueDivideFunctor<
203-
T, std::int64_t, T>;
206+
using BinOpT =
207+
dpctl::tensor::kernels::true_divide::TrueDivideFunctor<T, scalarT,
208+
T>;
204209

205210
auto op = BinOpT();
206211

@@ -220,15 +225,15 @@ sycl::event divide_by_scalar(sycl::queue &exec_q,
220225

221226
const auto &arg_i = two_offsets_.get_first_offset();
222227
const auto &res_i = two_offsets_.get_second_offset();
223-
res_tp[res_i] = op(arg_tp[arg_i], scalar);
228+
res_tp[res_i] = op(arg_tp[arg_i], sc_v);
224229
});
225230
});
226231
return comp_ev;
227232
}
228233

229234
std::pair<sycl::event, sycl::event>
230235
py_divide_by_scalar(const dpctl::tensor::usm_ndarray &src,
231-
const std::int64_t scalar,
236+
double scalar,
232237
const dpctl::tensor::usm_ndarray &dst,
233238
sycl::queue &exec_q,
234239
const std::vector<sycl::event> &depends = {})
@@ -293,18 +298,41 @@ py_divide_by_scalar(const dpctl::tensor::usm_ndarray &src,
293298
constexpr int float16_typeid = static_cast<int>(td_ns::typenum_t::HALF);
294299
constexpr int float32_typeid = static_cast<int>(td_ns::typenum_t::FLOAT);
295300
constexpr int float64_typeid = static_cast<int>(td_ns::typenum_t::DOUBLE);
301+
constexpr int complex64_typeid = static_cast<int>(td_ns::typenum_t::CFLOAT);
302+
constexpr int complex128_typeid =
303+
static_cast<int>(td_ns::typenum_t::CDOUBLE);
304+
305+
// statically pre-allocated memory for scalar
306+
alignas(double) char scalar_alloc[sizeof(double)] = {0};
296307

297308
divide_by_scalar_fn_ptr_t fn;
298309
switch (src_typeid) {
299310
case float16_typeid:
300-
fn = divide_by_scalar<sycl::half>;
301-
break;
311+
{
312+
fn = divide_by_scalar<sycl::half, sycl::half>;
313+
std::ignore =
314+
new (scalar_alloc) sycl::half(static_cast<sycl::half>(scalar));
315+
} break;
302316
case float32_typeid:
303-
fn = divide_by_scalar<float>;
304-
break;
317+
{
318+
fn = divide_by_scalar<float, float>;
319+
std::ignore = new (scalar_alloc) float(scalar);
320+
} break;
305321
case float64_typeid:
306-
fn = divide_by_scalar<double>;
307-
break;
322+
{
323+
fn = divide_by_scalar<double, double>;
324+
std::ignore = new (scalar_alloc) double(scalar);
325+
} break;
326+
case complex64_typeid:
327+
{
328+
fn = divide_by_scalar<std::complex<float>, float>;
329+
std::ignore = new (scalar_alloc) float(scalar);
330+
} break;
331+
case complex128_typeid:
332+
{
333+
fn = divide_by_scalar<std::complex<double>, double>;
334+
std::ignore = new (scalar_alloc) double(scalar);
335+
} break;
308336
default:
309337
throw std::runtime_error("Implementation is missing for typeid=" +
310338
std::to_string(src_typeid));
@@ -331,6 +359,16 @@ py_divide_by_scalar(const dpctl::tensor::usm_ndarray &src,
331359
simplified_shape, simplified_src_strides, simplified_dst_strides,
332360
src_offset, dst_offset);
333361

362+
if (nd == 0) {
363+
// handle 0d array as 1d array with 1 element
364+
constexpr py::ssize_t one{1};
365+
simplified_shape.push_back(one);
366+
simplified_src_strides.push_back(one);
367+
simplified_dst_strides.push_back(one);
368+
src_offset = 0;
369+
dst_offset = 0;
370+
}
371+
334372
using dpctl::tensor::offset_utils::device_allocate_and_pack;
335373
const auto &ptr_sz_event_triple_ = device_allocate_and_pack<py::ssize_t>(
336374
exec_q, host_tasks, simplified_shape, simplified_src_strides,
@@ -349,8 +387,9 @@ py_divide_by_scalar(const dpctl::tensor::usm_ndarray &src,
349387
throw std::runtime_error("Unable to allocate device memory");
350388
}
351389

352-
sycl::event div_ev = fn(exec_q, src_nelems, nd, shape_strides, src_data,
353-
src_offset, scalar, dst_data, dst_offset, all_deps);
390+
sycl::event div_ev =
391+
fn(exec_q, src_nelems, nd, shape_strides, src_data, src_offset,
392+
scalar_alloc, dst_data, dst_offset, all_deps);
354393

355394
// async free of shape_strides temporary
356395
auto ctx = exec_q.get_context();

0 commit comments

Comments
 (0)