Skip to content

Commit be6e00f

Browse files
Implementation of correlate (#2180)
* Implementation of dot product sliding window which is can be used to implement both `correlate` and `convolve` functions. * Implementation of `correlate` function using dot product sliding window
1 parent 0bfe1e3 commit be6e00f

23 files changed

+1741
-183
lines changed

dpnp/backend/extensions/statistics/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,15 @@
2626

2727
set(python_module_name _statistics_impl)
2828
set(_module_src
29-
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
3029
${CMAKE_CURRENT_SOURCE_DIR}/bincount.cpp
30+
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
3131
${CMAKE_CURRENT_SOURCE_DIR}/histogram.cpp
3232
${CMAKE_CURRENT_SOURCE_DIR}/histogramdd.cpp
3333
${CMAKE_CURRENT_SOURCE_DIR}/histogram_common.cpp
34+
${CMAKE_CURRENT_SOURCE_DIR}/sliding_dot_product1d.cpp
35+
${CMAKE_CURRENT_SOURCE_DIR}/sliding_window1d.cpp
3436
${CMAKE_CURRENT_SOURCE_DIR}/statistics_py.cpp
37+
${CMAKE_CURRENT_SOURCE_DIR}/validation_utils.cpp
3538
)
3639

3740
pybind11_add_module(${python_module_name} MODULE ${_module_src})

dpnp/backend/extensions/statistics/bincount.hpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,7 @@
3333

3434
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
3535

36-
namespace statistics
37-
{
38-
namespace histogram
36+
namespace statistics::histogram
3937
{
4038
struct Bincount
4139
{
@@ -63,5 +61,4 @@ struct Bincount
6361
};
6462

6563
void populate_bincount(py::module_ m);
66-
} // namespace histogram
67-
} // namespace statistics
64+
} // namespace statistics::histogram

dpnp/backend/extensions/statistics/common.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,8 @@
2929

3030
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
3131

32-
namespace statistics
32+
namespace statistics::common
3333
{
34-
namespace common
35-
{
36-
3734
size_t get_max_local_size(const sycl::device &device)
3835
{
3936
constexpr const int default_max_cpu_local_size = 256;
@@ -120,5 +117,4 @@ pybind11::dtype dtype_from_typenum(int dst_typenum)
120117
}
121118
}
122119

123-
} // namespace common
124-
} // namespace statistics
120+
} // namespace statistics::common

dpnp/backend/extensions/statistics/common.hpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,7 @@
3535

3636
namespace type_utils = dpctl::tensor::type_utils;
3737

38-
namespace statistics
39-
{
40-
namespace common
38+
namespace statistics::common
4139
{
4240

4341
template <typename N, typename D>
@@ -187,5 +185,4 @@ sycl::nd_range<1>
187185
// headers of dpctl.
188186
pybind11::dtype dtype_from_typenum(int dst_typenum);
189187

190-
} // namespace common
191-
} // namespace statistics
188+
} // namespace statistics::common

dpnp/backend/extensions/statistics/dispatch_table.hpp

Lines changed: 100 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,8 @@
3939
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
4040
namespace py = pybind11;
4141

42-
namespace statistics
42+
namespace statistics::common
4343
{
44-
namespace common
45-
{
46-
4744
template <typename T, typename Rest>
4845
struct one_of
4946
{
@@ -97,6 +94,32 @@ using DTypePair = std::pair<DType, DType>;
9794
using SupportedDTypeList = std::vector<DType>;
9895
using SupportedDTypeList2 = std::vector<DTypePair>;
9996

97+
template <typename FnT,
98+
typename SupportedTypes,
99+
template <typename>
100+
typename Func>
101+
struct TableBuilder
102+
{
103+
template <typename _FnT, typename T>
104+
struct impl
105+
{
106+
static constexpr bool is_defined = one_of_v<T, SupportedTypes>;
107+
108+
_FnT get()
109+
{
110+
if constexpr (is_defined) {
111+
return Func<T>::impl;
112+
}
113+
else {
114+
return nullptr;
115+
}
116+
}
117+
};
118+
119+
using type =
120+
dpctl_td_ns::DispatchVectorBuilder<FnT, impl, dpctl_td_ns::num_types>;
121+
};
122+
100123
template <typename FnT,
101124
typename SupportedTypes,
102125
template <typename, typename>
@@ -124,6 +147,78 @@ struct TableBuilder2
124147
dpctl_td_ns::DispatchTableBuilder<FnT, impl, dpctl_td_ns::num_types>;
125148
};
126149

150+
template <typename FnT>
151+
class DispatchTable
152+
{
153+
public:
154+
DispatchTable(std::string name) : name(name) {}
155+
156+
template <typename SupportedTypes, template <typename> typename Func>
157+
void populate_dispatch_table()
158+
{
159+
using TBulder = typename TableBuilder<FnT, SupportedTypes, Func>::type;
160+
TBulder builder;
161+
162+
builder.populate_dispatch_vector(table);
163+
populate_supported_types();
164+
}
165+
166+
FnT get_unsafe(int _typenum) const
167+
{
168+
auto array_types = dpctl_td_ns::usm_ndarray_types();
169+
const int type_id = array_types.typenum_to_lookup_id(_typenum);
170+
171+
return table[type_id];
172+
}
173+
174+
FnT get(int _typenum) const
175+
{
176+
auto fn = get_unsafe(_typenum);
177+
178+
if (fn == nullptr) {
179+
auto array_types = dpctl_td_ns::usm_ndarray_types();
180+
const int _type_id = array_types.typenum_to_lookup_id(_typenum);
181+
182+
py::dtype _dtype = dtype_from_typenum(_type_id);
183+
auto _type_pos = std::find(supported_types.begin(),
184+
supported_types.end(), _dtype);
185+
if (_type_pos == supported_types.end()) {
186+
py::str types = py::str(py::cast(supported_types));
187+
py::str dtype = py::str(_dtype);
188+
189+
py::str err_msg =
190+
py::str("'" + name + "' has unsupported type '") + dtype +
191+
py::str("'."
192+
" Supported types are: ") +
193+
types;
194+
195+
throw py::value_error(static_cast<std::string>(err_msg));
196+
}
197+
}
198+
199+
return fn;
200+
}
201+
202+
const SupportedDTypeList &get_all_supported_types() const
203+
{
204+
return supported_types;
205+
}
206+
207+
private:
208+
void populate_supported_types()
209+
{
210+
for (int i = 0; i < dpctl_td_ns::num_types; ++i) {
211+
if (table[i] != nullptr) {
212+
supported_types.emplace_back(dtype_from_typenum(i));
213+
}
214+
}
215+
}
216+
217+
std::string name;
218+
SupportedDTypeList supported_types;
219+
Table<FnT> table;
220+
};
221+
127222
template <typename FnT>
128223
class DispatchTable2
129224
{
@@ -288,5 +383,4 @@ class DispatchTable2
288383
Table2<FnT> table;
289384
};
290385

291-
} // namespace common
292-
} // namespace statistics
386+
} // namespace statistics::common

dpnp/backend/extensions/statistics/histogram.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@
2626
#include <algorithm>
2727
#include <complex>
2828
#include <memory>
29-
#include <string>
30-
#include <type_traits>
31-
#include <unordered_map>
29+
#include <tuple>
3230
#include <vector>
3331

3432
#include <pybind11/pybind11.h>

dpnp/backend/extensions/statistics/histogram.hpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,7 @@
3131
#include "dispatch_table.hpp"
3232
#include "dpctl4pybind11.hpp"
3333

34-
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
35-
36-
namespace statistics
37-
{
38-
namespace histogram
34+
namespace statistics::histogram
3935
{
4036
struct Histogram
4137
{
@@ -61,5 +57,4 @@ struct Histogram
6157
};
6258

6359
void populate_histogram(py::module_ m);
64-
} // namespace histogram
65-
} // namespace statistics
60+
} // namespace statistics::histogram

0 commit comments

Comments
 (0)