Skip to content

Commit 610d88d

Browse files
Address PR review feedback
1 parent f4705c0 commit 610d88d

File tree

1 file changed

+37
-6
lines changed

1 file changed

+37
-6
lines changed

dpctl/tensor/libtensor/source/copy_as_contig.cpp

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "kernels/copy_as_contiguous.hpp"
3636
#include "utils/memory_overlap.hpp"
3737
#include "utils/offset_utils.hpp"
38+
#include "utils/output_validation.hpp"
3839
#include "utils/sycl_alloc_utils.hpp"
3940
#include "utils/type_dispatch.hpp"
4041

@@ -107,10 +108,16 @@ void init_copy_as_contig_dispatch_vectors(void)
107108
namespace
108109
{
109110

110-
template <typename dimT> dimT get_nelems(const std::vector<dimT> &shape)
111+
template <typename dimT> std::size_t get_nelems(const std::vector<dimT> &shape)
111112
{
112-
const dimT nelems = std::accumulate(std::begin(shape), std::end(shape),
113-
dimT(1), std::multiplies<dimT>{});
113+
auto mult_fn = [](std::size_t prod, const dimT &term) -> std::size_t {
114+
return prod * static_cast<std::size_t>(term);
115+
};
116+
117+
constexpr std::size_t unit{1};
118+
119+
const std::size_t nelems =
120+
std::accumulate(std::begin(shape), std::end(shape), unit, mult_fn);
114121
return nelems;
115122
}
116123

@@ -163,6 +170,14 @@ py_as_c_contig(const dpctl::tensor::usm_ndarray &src,
163170
throw py::value_error("Destination array must be C-contiguous");
164171
}
165172

173+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
174+
175+
// check compatibility of execution queue and allocation queue
176+
if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {
177+
throw py::value_error(
178+
"Execution queue is not compatible with allocation queues");
179+
}
180+
166181
const auto &src_strides_vec = src.get_strides_vector();
167182

168183
if (src_nd >= 2) {
@@ -175,7 +190,7 @@ py_as_c_contig(const dpctl::tensor::usm_ndarray &src,
175190
}
176191
}
177192

178-
const py::ssize_t nelems = get_nelems(src_shape_vec);
193+
const std::size_t nelems = get_nelems(src_shape_vec);
179194

180195
if (nelems == 0) {
181196
// nothing to do
@@ -254,7 +269,7 @@ py_as_f_contig(const dpctl::tensor::usm_ndarray &src,
254269
const std::vector<sycl::event> &depends)
255270
{
256271
/* Same dimensions, same shape, same data-type
257-
* dst is C-contiguous.
272+
* dst is F-contiguous.
258273
*/
259274
int src_nd = src.get_ndim();
260275
int dst_nd = dst.get_ndim();
@@ -288,6 +303,14 @@ py_as_f_contig(const dpctl::tensor::usm_ndarray &src,
288303
throw py::value_error("Destination array must be F-contiguous");
289304
}
290305

306+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
307+
308+
// check compatibility of execution queue and allocation queue
309+
if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {
310+
throw py::value_error(
311+
"Execution queue is not compatible with allocation queues");
312+
}
313+
291314
const auto &src_strides_vec = src.get_strides_vector();
292315

293316
if (src_nd >= 2) {
@@ -300,7 +323,7 @@ py_as_f_contig(const dpctl::tensor::usm_ndarray &src,
300323
}
301324
}
302325

303-
const py::ssize_t nelems = get_nelems(src_shape_vec);
326+
const std::size_t nelems = get_nelems(src_shape_vec);
304327

305328
if (nelems == 0) {
306329
// nothing to do
@@ -433,6 +456,14 @@ py_as_c_contig_f2c(const dpctl::tensor::usm_ndarray &src,
433456
throw py::value_error("Destination array must be C-contiguous");
434457
}
435458

459+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
460+
461+
// check compatibility of execution queue and allocation queue
462+
if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {
463+
throw py::value_error(
464+
"Execution queue is not compatible with allocation queues");
465+
}
466+
436467
if (nelems == 0) {
437468
// nothing to do
438469
return std::make_pair(sycl::event(), sycl::event());

0 commit comments

Comments
 (0)