
Commit 361d03d

Adds simple _divide_by_scalar to _tensor_elementwise_impl
1 parent: 46dc288

2 files changed: 213 additions & 0 deletions


dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp

Lines changed: 5 additions & 0 deletions
```diff
@@ -84,6 +84,11 @@ struct TrueDivideFunctor
 
             return in1 / exprm_ns::complex<realT2>(in2);
         }
+        else if constexpr (std::is_floating_point_v<argT1> &&
+                           std::is_integral_v<argT2>)
+        {
+            return in1 / static_cast<argT1>(in2);
+        }
         else {
             return in1 / in2;
         }
```
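
Note on the change: the new `constexpr` branch covers a floating-point dividend paired with an integral divisor, casting the integer to the dividend's type (`argT1`) so the division runs entirely in the array's precision. A minimal sketch of the intended Python-level effect, assuming `dpctl.tensor`'s division routes such cases through this functor (the sketch itself is not part of the diff):

```python
# Hedged sketch: a floating-point array divided by an integer scalar
# should stay in the array's dtype, since the functor now computes
# in1 / static_cast<argT1>(in2) instead of relying on implicit promotion.
import dpctl.tensor as dpt

x = dpt.ones(8, dtype="f4")  # float32 dividend
y = x / 3                    # integral divisor
print(y.dtype)               # expected: float32, not float64
```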

dpctl/tensor/libtensor/source/elementwise_functions/true_divide.cpp

Lines changed: 208 additions & 0 deletions
```diff
@@ -24,14 +24,19 @@
 //===----------------------------------------------------------------------===//
 
 #include "dpctl4pybind11.hpp"
+#include <cstdint>
 #include <pybind11/numpy.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
 #include <sycl/sycl.hpp>
 #include <vector>
 
 #include "elementwise_functions.hpp"
+#include "simplify_iteration_space.hpp"
 #include "true_divide.hpp"
+#include "utils/memory_overlap.hpp"
+#include "utils/offset_utils.hpp"
+#include "utils/output_validation.hpp"
 #include "utils/type_dispatch.hpp"
 
 #include "kernels/elementwise_functions/common.hpp"
```
```diff
@@ -165,6 +170,204 @@ void populate_true_divide_dispatch_tables(void)
     dtb9.populate_dispatch_table(true_divide_inplace_row_matrix_dispatch_table);
 };
 
+template <typename T> class divide_by_scalar_krn;
+
+typedef sycl::event (*divide_by_scalar_fn_ptr_t)(
+    sycl::queue &,
+    size_t,
+    int,
+    const ssize_t *,
+    const char *,
+    py::ssize_t,
+    std::int64_t,
+    char *,
+    py::ssize_t,
+    const std::vector<sycl::event> &);
+
+template <typename T>
+sycl::event divide_by_scalar(sycl::queue &exec_q,
+                             size_t nelems,
+                             int nd,
+                             const ssize_t *shape_and_strides,
+                             const char *arg_p,
+                             py::ssize_t arg_offset,
+                             std::int64_t scalar,
+                             char *res_p,
+                             py::ssize_t res_offset,
+                             const std::vector<sycl::event> &depends = {})
+{
+    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(depends);
+
+        using BinOpT = dpctl::tensor::kernels::true_divide::TrueDivideFunctor<
+            T, std::int64_t, T>;
+
+        auto op = BinOpT();
+
+        using IndexerT =
+            typename dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
+
+        const IndexerT two_offsets_indexer{nd, arg_offset, res_offset,
+                                           shape_and_strides};
+
+        const T *arg_tp = reinterpret_cast<const T *>(arg_p);
+        T *res_tp = reinterpret_cast<T *>(res_p);
+
+        cgh.parallel_for<divide_by_scalar_krn<T>>(
+            {nelems}, [=](sycl::id<1> id) {
+                const auto &two_offsets_ =
+                    two_offsets_indexer(static_cast<ssize_t>(id.get(0)));
+
+                const auto &arg_i = two_offsets_.get_first_offset();
+                const auto &res_i = two_offsets_.get_second_offset();
+                res_tp[res_i] = op(arg_tp[arg_i], scalar);
+            });
+    });
+    return comp_ev;
+}
+
+std::pair<sycl::event, sycl::event>
+py_divide_by_scalar(const dpctl::tensor::usm_ndarray &src,
+                    const std::int64_t scalar,
+                    const dpctl::tensor::usm_ndarray &dst,
+                    sycl::queue &exec_q,
+                    const std::vector<sycl::event> &depends = {})
+{
+    int src_typenum = src.get_typenum();
+    int dst_typenum = dst.get_typenum();
+
+    auto array_types = td_ns::usm_ndarray_types();
+    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
+    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
+
+    if (src_typeid != dst_typeid) {
+        throw py::value_error(
+            "Destination array has unexpected elemental data type.");
+    }
+
+    // check that queues are compatible
+    if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {
+        throw py::value_error(
+            "Execution queue is not compatible with allocation queues");
+    }
+
+    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
+    // check shapes, broadcasting is assumed done by caller
+    // check that dimensions are the same
+    int dst_nd = dst.get_ndim();
+    if (dst_nd != src.get_ndim()) {
+        throw py::value_error("Array dimensions are not the same.");
+    }
+
+    // check that shapes are the same
+    const py::ssize_t *src_shape = src.get_shape_raw();
+    const py::ssize_t *dst_shape = dst.get_shape_raw();
+    bool shapes_equal(true);
+    size_t src_nelems(1);
+
+    for (int i = 0; i < dst_nd; ++i) {
+        src_nelems *= static_cast<size_t>(src_shape[i]);
+        shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]);
+    }
+    if (!shapes_equal) {
+        throw py::value_error("Array shapes are not the same.");
+    }
+
+    // if nelems is zero, return
+    if (src_nelems == 0) {
+        return std::make_pair(sycl::event(), sycl::event());
+    }
+
+    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, src_nelems);
+
+    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
+    auto const &same_logical_tensors =
+        dpctl::tensor::overlap::SameLogicalTensors();
+    if ((overlap(src, dst) && !same_logical_tensors(src, dst))) {
+        throw py::value_error("Arrays index overlapping segments of memory");
+    }
+
+    const char *src_data = src.get_data();
+    char *dst_data = dst.get_data();
+
+    constexpr int float16_typeid = static_cast<int>(td_ns::typenum_t::HALF);
+    constexpr int float32_typeid = static_cast<int>(td_ns::typenum_t::FLOAT);
+    constexpr int float64_typeid = static_cast<int>(td_ns::typenum_t::DOUBLE);
+
+    divide_by_scalar_fn_ptr_t fn;
+    switch (src_typeid) {
+    case float16_typeid:
+        fn = divide_by_scalar<sycl::half>;
+        break;
+    case float32_typeid:
+        fn = divide_by_scalar<float>;
+        break;
+    case float64_typeid:
+        fn = divide_by_scalar<double>;
+        break;
+    default:
+        throw std::runtime_error("Implementation is missing for typeid=" +
+                                 std::to_string(src_typeid));
+    }
+
+    // simplify strides
+    auto const &src_strides = src.get_strides_vector();
+    auto const &dst_strides = dst.get_strides_vector();
+
+    using shT = std::vector<py::ssize_t>;
+    shT simplified_shape;
+    shT simplified_src_strides;
+    shT simplified_dst_strides;
+    py::ssize_t src_offset(0);
+    py::ssize_t dst_offset(0);
+
+    int nd = dst_nd;
+    const py::ssize_t *shape = src_shape;
+
+    std::vector<sycl::event> host_tasks{};
+    dpctl::tensor::py_internal::simplify_iteration_space(
+        nd, shape, src_strides, dst_strides,
+        // outputs
+        simplified_shape, simplified_src_strides, simplified_dst_strides,
+        src_offset, dst_offset);
+
+    using dpctl::tensor::offset_utils::device_allocate_and_pack;
+    const auto &ptr_sz_event_triple_ = device_allocate_and_pack<py::ssize_t>(
+        exec_q, host_tasks, simplified_shape, simplified_src_strides,
+        simplified_dst_strides);
+
+    py::ssize_t *shape_strides = std::get<0>(ptr_sz_event_triple_);
+    const sycl::event &copy_metadata_ev = std::get<2>(ptr_sz_event_triple_);
+
+    std::vector<sycl::event> all_deps;
+    all_deps.reserve(depends.size() + 1);
+    all_deps.resize(depends.size());
+    std::copy(depends.begin(), depends.end(), all_deps.begin());
+    all_deps.push_back(copy_metadata_ev);
+
+    if (shape_strides == nullptr) {
+        throw std::runtime_error("Unable to allocate device memory");
+    }
+
+    sycl::event div_ev = fn(exec_q, src_nelems, nd, shape_strides, src_data,
+                            src_offset, scalar, dst_data, dst_offset, all_deps);
+
+    // async free of shape_strides temporary
+    auto ctx = exec_q.get_context();
+
+    sycl::event tmp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(div_ev);
+        using dpctl::tensor::alloc_utils::sycl_free_noexcept;
+        cgh.host_task(
+            [ctx, shape_strides]() { sycl_free_noexcept(shape_strides, ctx); });
+    });
+
+    host_tasks.push_back(tmp_cleanup_ev);
+
+    return std::make_pair(
+        dpctl::utils::keep_args_alive(exec_q, {src, dst}, host_tasks), div_ev);
+}
+
 } // namespace impl
 
 void init_divide(py::module_ m)
```
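
The wrapper validates its contract before launching anything: matching shapes and dtypes, compatible queues, a writable non-overlapping destination, and one of the three floating-point types. A hedged sketch of how a violation would surface in Python, assuming the `dpctl.tensor._tensor_elementwise_impl` module path implied by the commit title (pybind11 maps `py::value_error` to `ValueError`):

```python
# Hedged sketch: the validation in py_divide_by_scalar surfaces as Python
# exceptions before any kernel is launched. Module path is inferred from
# the commit title, not stated in the diff.
import dpctl.tensor as dpt
import dpctl.tensor._tensor_elementwise_impl as ti  # private module

x = dpt.ones((4, 6), dtype="f4")
bad_out = dpt.empty((4, 5), dtype="f4")  # shape mismatch

try:
    ti._divide_by_scalar(src=x, scalar=2, dst=bad_out, sycl_queue=x.sycl_queue)
except ValueError as e:
    print(e)  # "Array shapes are not the same."
```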
```diff
@@ -233,6 +436,11 @@ void init_divide(py::module_ m)
     m.def("_divide_inplace", divide_inplace_pyapi, "", py::arg("lhs"),
           py::arg("rhs"), py::arg("sycl_queue"),
           py::arg("depends") = py::list());
+
+    using impl::py_divide_by_scalar;
+    m.def("_divide_by_scalar", &py_divide_by_scalar, "", py::arg("src"),
+          py::arg("scalar"), py::arg("dst"), py::arg("sycl_queue"),
+          py::arg("depends") = py::list());
 }
 }
```
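
A hedged end-to-end sketch of the new binding. Keyword names follow the `m.def` above and the returned event pair mirrors `py_divide_by_scalar`'s return value; the module path is inferred from the commit title, and this private entry point is shown for illustration, not as documented public API:

```python
# Hedged sketch: calling the private _divide_by_scalar binding directly.
# src and dst must have identical shapes and a matching floating-point
# dtype (f2/f4/f8); broadcasting is assumed to be done by the caller.
import dpctl.tensor as dpt
import dpctl.tensor._tensor_elementwise_impl as ti  # private module

x = dpt.arange(10, dtype="f4")
out = dpt.empty_like(x)

# Returns (host-task event keeping args alive, computation event).
ht_ev, div_ev = ti._divide_by_scalar(
    src=x, scalar=7, dst=out, sycl_queue=x.sycl_queue
)
ht_ev.wait()  # wait before inspecting the result
print(dpt.asnumpy(out)[:3])
```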
