Skip to content

Commit 0bfe1e3

Browse files
Implementing histogramdd (#2143)
Implementation of histogramdd
1 parent 6dc39f9 commit 0bfe1e3

File tree

15 files changed

+1093
-137
lines changed

15 files changed

+1093
-137
lines changed

dpnp/backend/extensions/statistics/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ set(_module_src
2929
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
3030
${CMAKE_CURRENT_SOURCE_DIR}/bincount.cpp
3131
${CMAKE_CURRENT_SOURCE_DIR}/histogram.cpp
32+
${CMAKE_CURRENT_SOURCE_DIR}/histogramdd.cpp
3233
${CMAKE_CURRENT_SOURCE_DIR}/histogram_common.cpp
3334
${CMAKE_CURRENT_SOURCE_DIR}/statistics_py.cpp
3435
)

dpnp/backend/extensions/statistics/bincount.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@
2525

2626
#pragma once
2727

28-
#include <dpctl4pybind11.hpp>
28+
#include <pybind11/pybind11.h>
2929
#include <sycl/sycl.hpp>
3030

3131
#include "dispatch_table.hpp"
32+
#include "dpctl4pybind11.hpp"
3233

3334
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
3435

dpnp/backend/extensions/statistics/common.hpp

Lines changed: 32 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,12 @@
2828
#include <complex>
2929
#include <pybind11/numpy.h>
3030
#include <pybind11/pybind11.h>
31-
32-
// clang-format off
33-
// math_utils.hpp doesn't include sycl header but uses sycl types
34-
// so sycl.hpp must be included before math_utils.hpp
3531
#include <sycl/sycl.hpp>
32+
3633
#include "utils/math_utils.hpp"
37-
// clang-format on
34+
#include "utils/type_utils.hpp"
35+
36+
namespace type_utils = dpctl::tensor::type_utils;
3837

3938
namespace statistics
4039
{
@@ -56,24 +55,20 @@ constexpr auto Align(N n, D d)
5655
template <typename T, sycl::memory_order Order, sycl::memory_scope Scope>
5756
struct AtomicOp
5857
{
59-
static void add(T &lhs, const T value)
58+
static void add(T &lhs, const T &value)
6059
{
61-
sycl::atomic_ref<T, Order, Scope> lh(lhs);
62-
lh += value;
63-
}
64-
};
60+
if constexpr (type_utils::is_complex_v<T>) {
61+
using vT = typename T::value_type;
62+
vT *_lhs = reinterpret_cast<vT(&)[2]>(lhs);
63+
const vT *_val = reinterpret_cast<const vT(&)[2]>(value);
6564

66-
template <typename T, sycl::memory_order Order, sycl::memory_scope Scope>
67-
struct AtomicOp<std::complex<T>, Order, Scope>
68-
{
69-
static void add(std::complex<T> &lhs, const std::complex<T> value)
70-
{
71-
T *_lhs = reinterpret_cast<T(&)[2]>(lhs);
72-
const T *_val = reinterpret_cast<const T(&)[2]>(value);
73-
sycl::atomic_ref<T, Order, Scope> lh0(_lhs[0]);
74-
lh0 += _val[0];
75-
sycl::atomic_ref<T, Order, Scope> lh1(_lhs[1]);
76-
lh1 += _val[1];
65+
AtomicOp<vT, Order, Scope>::add(_lhs[0], _val[0]);
66+
AtomicOp<vT, Order, Scope>::add(_lhs[1], _val[1]);
67+
}
68+
else {
69+
sycl::atomic_ref<T, Order, Scope> lh(lhs);
70+
lh += value;
71+
}
7772
}
7873
};
7974

@@ -82,17 +77,12 @@ struct Less
8277
{
8378
bool operator()(const T &lhs, const T &rhs) const
8479
{
85-
return std::less{}(lhs, rhs);
86-
}
87-
};
88-
89-
template <typename T>
90-
struct Less<std::complex<T>>
91-
{
92-
bool operator()(const std::complex<T> &lhs,
93-
const std::complex<T> &rhs) const
94-
{
95-
return dpctl::tensor::math_utils::less_complex(lhs, rhs);
80+
if constexpr (type_utils::is_complex_v<T>) {
81+
return dpctl::tensor::math_utils::less_complex(lhs, rhs);
82+
}
83+
else {
84+
return std::less{}(lhs, rhs);
85+
}
9686
}
9787
};
9888

@@ -101,26 +91,23 @@ struct IsNan
10191
{
10292
static bool isnan(const T &v)
10393
{
104-
if constexpr (std::is_floating_point_v<T> ||
105-
std::is_same_v<T, sycl::half>) {
94+
if constexpr (type_utils::is_complex_v<T>) {
95+
using vT = typename T::value_type;
96+
97+
const vT real1 = std::real(v);
98+
const vT imag1 = std::imag(v);
99+
100+
return IsNan<vT>::isnan(real1) || IsNan<vT>::isnan(imag1);
101+
}
102+
else if constexpr (std::is_floating_point_v<T> ||
103+
std::is_same_v<T, sycl::half>) {
106104
return sycl::isnan(v);
107105
}
108106

109107
return false;
110108
}
111109
};
112110

113-
template <typename T>
114-
struct IsNan<std::complex<T>>
115-
{
116-
static bool isnan(const std::complex<T> &v)
117-
{
118-
T real1 = std::real(v);
119-
T imag1 = std::imag(v);
120-
return sycl::isnan(real1) || sycl::isnan(imag1);
121-
}
122-
};
123-
124111
size_t get_max_local_size(const sycl::device &device);
125112
size_t get_max_local_size(const sycl::device &device,
126113
int cpu_local_size_limit,

dpnp/backend/extensions/statistics/histogram.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,11 @@
2525

2626
#pragma once
2727

28+
#include <pybind11/pybind11.h>
2829
#include <sycl/sycl.hpp>
2930

3031
#include "dispatch_table.hpp"
32+
#include "dpctl4pybind11.hpp"
3133

3234
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
3335

dpnp/backend/extensions/statistics/histogram_common.cpp

Lines changed: 22 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -137,22 +137,16 @@ void validate(const usm_ndarray &sample,
137137
" parameter must have at least 1 element");
138138
}
139139

140-
if (histogram.get_ndim() != 1) {
141-
throw py::value_error(get_name(&histogram) +
142-
" parameter must be 1d. Actual " +
143-
std::to_string(histogram.get_ndim()) + "d");
144-
}
145-
146140
if (weights_ptr) {
147141
if (weights_ptr->get_ndim() != 1) {
148142
throw py::value_error(
149143
get_name(weights_ptr) + " parameter must be 1d. Actual " +
150144
std::to_string(weights_ptr->get_ndim()) + "d");
151145
}
152146

153-
auto sample_size = sample.get_size();
147+
auto sample_size = sample.get_shape(0);
154148
auto weights_size = weights_ptr->get_size();
155-
if (sample.get_size() != weights_ptr->get_size()) {
149+
if (sample_size != weights_ptr->get_size()) {
156150
throw py::value_error(
157151
get_name(&sample) + " size (" + std::to_string(sample_size) +
158152
") and " + get_name(weights_ptr) + " size (" +
@@ -168,42 +162,37 @@ void validate(const usm_ndarray &sample,
168162
}
169163

170164
if (sample.get_ndim() == 1) {
171-
if (bins_ptr != nullptr && bins_ptr->get_ndim() != 1) {
165+
if (histogram.get_ndim() != 1) {
172166
throw py::value_error(get_name(&sample) + " parameter is 1d, but " +
173-
get_name(bins_ptr) + " is " +
174-
std::to_string(bins_ptr->get_ndim()) + "d");
167+
get_name(&histogram) + " is " +
168+
std::to_string(histogram.get_ndim()) + "d");
169+
}
170+
171+
if (bins_ptr && histogram.get_size() != bins_ptr->get_size() - 1) {
172+
auto hist_size = histogram.get_size();
173+
auto bins_size = bins_ptr->get_size();
174+
throw py::value_error(
175+
get_name(&histogram) + " parameter and " + get_name(bins_ptr) +
176+
" parameters shape mismatch. " + get_name(&histogram) +
177+
" size is " + std::to_string(hist_size) + get_name(bins_ptr) +
178+
" must have size " + std::to_string(hist_size + 1) +
179+
" but have " + std::to_string(bins_size));
175180
}
176181
}
177182
else if (sample.get_ndim() == 2) {
178183
auto sample_count = sample.get_shape(0);
179184
auto expected_dims = sample.get_shape(1);
180185

181-
if (bins_ptr != nullptr && bins_ptr->get_ndim() != expected_dims) {
182-
throw py::value_error(get_name(&sample) + " parameter has shape {" +
183-
std::to_string(sample_count) + "x" +
184-
std::to_string(expected_dims) + "}" +
185-
", so " + get_name(bins_ptr) +
186+
if (histogram.get_ndim() != expected_dims) {
187+
throw py::value_error(get_name(&sample) + " parameter has shape (" +
188+
std::to_string(sample_count) + ", " +
189+
std::to_string(expected_dims) + ")" +
190+
", so " + get_name(&histogram) +
186191
" parameter expected to be " +
187192
std::to_string(expected_dims) +
188193
"d. "
189194
"Actual " +
190-
std::to_string(bins->get_ndim()) + "d");
191-
}
192-
}
193-
194-
if (bins_ptr != nullptr) {
195-
py::ssize_t expected_hist_size = 1;
196-
for (int i = 0; i < bins_ptr->get_ndim(); ++i) {
197-
expected_hist_size *= (bins_ptr->get_shape(i) - 1);
198-
}
199-
200-
if (histogram.get_size() != expected_hist_size) {
201-
throw py::value_error(
202-
get_name(&histogram) + " and " + get_name(bins_ptr) +
203-
" shape mismatch. " + get_name(&histogram) +
204-
" expected to have size = " +
205-
std::to_string(expected_hist_size) + ". Actual " +
206-
std::to_string(histogram.get_size()));
195+
std::to_string(histogram.get_ndim()) + "d");
207196
}
208197
}
209198

dpnp/backend/extensions/statistics/histogram_common.hpp

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,15 @@ template <typename T, int Dims>
5252
struct CachedData
5353
{
5454
static constexpr bool const sync_after_init = true;
55-
using pointer_type = T *;
55+
using Shape = sycl::range<Dims>;
56+
using value_type = T;
57+
using pointer_type = value_type *;
58+
static constexpr auto dims = Dims;
5659

57-
using ncT = typename std::remove_const<T>::type;
60+
using ncT = typename std::remove_const<value_type>::type;
5861
using LocalData = sycl::local_accessor<ncT, Dims>;
5962

60-
CachedData(T *global_data, sycl::range<Dims> shape, sycl::handler &cgh)
63+
CachedData(T *global_data, Shape shape, sycl::handler &cgh)
6164
{
6265
this->global_data = global_data;
6366
local_data = LocalData(shape, cgh);
@@ -71,13 +74,13 @@ struct CachedData
7174
template <int _Dims>
7275
void init(const sycl::nd_item<_Dims> &item) const
7376
{
74-
int32_t llid = item.get_local_linear_id();
77+
uint32_t llid = item.get_local_linear_id();
7578
auto local_ptr = &local_data[0];
76-
int32_t size = local_data.size();
79+
uint32_t size = local_data.size();
7780
auto group = item.get_group();
78-
int32_t local_size = group.get_local_linear_range();
81+
uint32_t local_size = group.get_local_linear_range();
7982

80-
for (int32_t i = llid; i < size; i += local_size) {
83+
for (uint32_t i = llid; i < size; i += local_size) {
8184
local_ptr[i] = global_data[i];
8285
}
8386
}
@@ -87,17 +90,30 @@ struct CachedData
8790
return local_data.size();
8891
}
8992

93+
T &operator[](const sycl::id<Dims> &id) const
94+
{
95+
return local_data[id];
96+
}
97+
98+
template <typename = std::enable_if_t<Dims == 1>>
99+
T &operator[](const size_t id) const
100+
{
101+
return local_data[id];
102+
}
103+
90104
private:
91105
LocalData local_data;
92-
T *global_data = nullptr;
106+
value_type *global_data = nullptr;
93107
};
94108

95109
template <typename T, int Dims>
96110
struct UncachedData
97111
{
98112
static constexpr bool const sync_after_init = false;
99113
using Shape = sycl::range<Dims>;
100-
using pointer_type = T *;
114+
using value_type = T;
115+
using pointer_type = value_type *;
116+
static constexpr auto dims = Dims;
101117

102118
UncachedData(T *global_data, const Shape &shape, sycl::handler &)
103119
{
@@ -120,6 +136,17 @@ struct UncachedData
120136
return _shape.size();
121137
}
122138

139+
T &operator[](const sycl::id<Dims> &id) const
140+
{
141+
return global_data[id];
142+
}
143+
144+
template <typename = std::enable_if_t<Dims == 1>>
145+
T &operator[](const size_t id) const
146+
{
147+
return global_data[id];
148+
}
149+
123150
private:
124151
T *global_data = nullptr;
125152
Shape _shape;
@@ -191,15 +218,15 @@ struct HistWithLocalCopies
191218
template <int _Dims>
192219
void finalize(const sycl::nd_item<_Dims> &item) const
193220
{
194-
int32_t llid = item.get_local_linear_id();
195-
int32_t bins_count = local_hist.get_range().get(1);
196-
int32_t local_hist_count = local_hist.get_range().get(0);
221+
uint32_t llid = item.get_local_linear_id();
222+
uint32_t bins_count = local_hist.get_range().get(1);
223+
uint32_t local_hist_count = local_hist.get_range().get(0);
197224
auto group = item.get_group();
198-
int32_t local_size = group.get_local_linear_range();
225+
uint32_t local_size = group.get_local_linear_range();
199226

200-
for (int32_t i = llid; i < bins_count; i += local_size) {
227+
for (uint32_t i = llid; i < bins_count; i += local_size) {
201228
auto value = local_hist[0][i];
202-
for (int32_t lhc = 1; lhc < local_hist_count; ++lhc) {
229+
for (uint32_t lhc = 1; lhc < local_hist_count; ++lhc) {
203230
value += local_hist[lhc][i];
204231
}
205232
if (value != T(0)) {
@@ -290,9 +317,9 @@ class histogram_kernel;
290317

291318
template <typename T, typename HistImpl, typename Edges, typename Weights>
292319
void submit_histogram(const T *in,
293-
size_t size,
294-
size_t dims,
295-
uint32_t WorkPI,
320+
const size_t size,
321+
const size_t dims,
322+
const uint32_t WorkPI,
296323
const HistImpl &hist,
297324
const Edges &edges,
298325
const Weights &weights,

0 commit comments

Comments
 (0)