Skip to content

Commit 809cb70

Browse files
oleksandr-pavlykndgrigorian
authored andcommitted
Add static assers to async_smart_free
One asserts that at least one unique pointer is specified. Another that specified arguments are unique pointers with USMDeleter.
1 parent b71b128 commit 809cb70

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

dpctl/tensor/libtensor/include/utils/sycl_alloc_utils.hpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include <iostream>
3131
#include <memory>
3232
#include <stdexcept>
33+
#include <type_traits>
3334
#include <vector>
3435

3536
#include "sycl/sycl.hpp"
@@ -140,12 +141,50 @@ smart_malloc_host(std::size_t count,
140141
return smart_malloc<T>(count, q, sycl::usm::alloc::host, propList);
141142
}
142143

144+
namespace
145+
{
146+
template <typename T> struct valid_smart_ptr : public std::false_type
147+
{
148+
};
149+
150+
template <typename ValT, typename DelT>
151+
struct valid_smart_ptr<std::unique_ptr<ValT, DelT> &>
152+
: public std::is_same<DelT, USMDeleter>
153+
{
154+
};
155+
156+
template <typename ValT, typename DelT>
157+
struct valid_smart_ptr<std::unique_ptr<ValT, DelT>>
158+
: public std::is_same<DelT, USMDeleter>
159+
{
160+
};
161+
162+
// base case
163+
template <typename... Rest> struct all_valid_smart_ptrs
164+
{
165+
static constexpr bool value = true;
166+
};
167+
168+
template <typename Arg, typename... RestArgs>
169+
struct all_valid_smart_ptrs<Arg, RestArgs...>
170+
{
171+
static constexpr bool value = valid_smart_ptr<Arg>::value &&
172+
(all_valid_smart_ptrs<RestArgs...>::value);
173+
};
174+
} // namespace
175+
143176
template <typename... Args>
144177
sycl::event async_smart_free(sycl::queue &exec_q,
145178
const std::vector<sycl::event> &depends,
146179
Args &&...args)
147180
{
148181
constexpr std::size_t n = sizeof...(Args);
182+
static_assert(
183+
n > 0, "async_smart_free requires at least one smart pointer argument");
184+
185+
static_assert(
186+
all_valid_smart_ptrs<Args...>::value,
187+
"async_smart_free requires unique_ptr created with smart_malloc");
149188

150189
std::vector<void *> ptrs;
151190
ptrs.reserve(n);

0 commit comments

Comments
 (0)