|
30 | 30 | #include <iostream>
|
31 | 31 | #include <memory>
|
32 | 32 | #include <stdexcept>
|
| 33 | +#include <type_traits> |
33 | 34 | #include <vector>
|
34 | 35 |
|
35 | 36 | #include "sycl/sycl.hpp"
|
@@ -140,12 +141,50 @@ smart_malloc_host(std::size_t count,
|
140 | 141 | return smart_malloc<T>(count, q, sycl::usm::alloc::host, propList);
|
141 | 142 | }
|
142 | 143 |
|
| 144 | +namespace |
| 145 | +{ |
| 146 | +template <typename T> struct valid_smart_ptr : public std::false_type |
| 147 | +{ |
| 148 | +}; |
| 149 | + |
| 150 | +template <typename ValT, typename DelT> |
| 151 | +struct valid_smart_ptr<std::unique_ptr<ValT, DelT> &> |
| 152 | + : public std::is_same<DelT, USMDeleter> |
| 153 | +{ |
| 154 | +}; |
| 155 | + |
| 156 | +template <typename ValT, typename DelT> |
| 157 | +struct valid_smart_ptr<std::unique_ptr<ValT, DelT>> |
| 158 | + : public std::is_same<DelT, USMDeleter> |
| 159 | +{ |
| 160 | +}; |
| 161 | + |
| 162 | +// base case |
| 163 | +template <typename... Rest> struct all_valid_smart_ptrs |
| 164 | +{ |
| 165 | + static constexpr bool value = true; |
| 166 | +}; |
| 167 | + |
| 168 | +template <typename Arg, typename... RestArgs> |
| 169 | +struct all_valid_smart_ptrs<Arg, RestArgs...> |
| 170 | +{ |
| 171 | + static constexpr bool value = valid_smart_ptr<Arg>::value && |
| 172 | + (all_valid_smart_ptrs<RestArgs...>::value); |
| 173 | +}; |
| 174 | +} // namespace |
| 175 | + |
143 | 176 | template <typename... Args>
|
144 | 177 | sycl::event async_smart_free(sycl::queue &exec_q,
|
145 | 178 | const std::vector<sycl::event> &depends,
|
146 | 179 | Args &&...args)
|
147 | 180 | {
|
148 | 181 | constexpr std::size_t n = sizeof...(Args);
|
| 182 | + static_assert( |
| 183 | + n > 0, "async_smart_free requires at least one smart pointer argument"); |
| 184 | + |
| 185 | + static_assert( |
| 186 | + all_valid_smart_ptrs<Args...>::value, |
| 187 | + "async_smart_free requires unique_ptr created with smart_malloc"); |
149 | 188 |
|
150 | 189 | std::vector<void *> ptrs;
|
151 | 190 | ptrs.reserve(n);
|
|
0 commit comments