Skip to content

Commit 6a7efb0

Browse files
Applying review comments
1 parent 494b841 commit 6a7efb0

File tree

8 files changed

+47
-97
lines changed

8 files changed

+47
-97
lines changed

dpnp/backend/extensions/statistics/bincount.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@
2525

2626
#pragma once
2727

28-
#include <dpctl4pybind11.hpp>
28+
#include <pybind11/pybind11.h>
2929
#include <sycl/sycl.hpp>
3030

3131
#include "dispatch_table.hpp"
32+
#include "dpctl4pybind11.hpp"
3233

3334
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
3435

dpnp/backend/extensions/statistics/common.hpp

Lines changed: 7 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -28,33 +28,10 @@
2828
#include <complex>
2929
#include <pybind11/numpy.h>
3030
#include <pybind11/pybind11.h>
31-
32-
// clang-format off
33-
// math_utils.hpp doesn't include sycl header but uses sycl types
34-
// so sycl.hpp must be included before math_utils.hpp
3531
#include <sycl/sycl.hpp>
32+
3633
#include "utils/math_utils.hpp"
3734
#include "utils/type_utils.hpp"
38-
// clang-format on
39-
40-
namespace dpctl
41-
{
42-
namespace tensor
43-
{
44-
namespace type_utils
45-
{
46-
// Upstream to dpctl
47-
template <class T>
48-
struct is_complex<const std::complex<T>> : std::true_type
49-
{
50-
};
51-
52-
template <typename T>
53-
constexpr bool is_complex_v = is_complex<T>::value;
54-
55-
} // namespace type_utils
56-
} // namespace tensor
57-
} // namespace dpctl
5835

5936
namespace type_utils = dpctl::tensor::type_utils;
6037

@@ -115,18 +92,16 @@ struct IsNan
11592
static bool isnan(const T &v)
11693
{
11794
if constexpr (type_utils::is_complex_v<T>) {
118-
const auto real1 = std::real(v);
119-
const auto imag1 = std::imag(v);
120-
12195
using vT = typename T::value_type;
12296

97+
const vT real1 = std::real(v);
98+
const vT imag1 = std::imag(v);
99+
123100
return IsNan<vT>::isnan(real1) || IsNan<vT>::isnan(imag1);
124101
}
125-
else {
126-
if constexpr (std::is_floating_point_v<T> ||
127-
std::is_same_v<T, sycl::half>) {
128-
return sycl::isnan(v);
129-
}
102+
else if constexpr (std::is_floating_point_v<T> ||
103+
std::is_same_v<T, sycl::half>) {
104+
return sycl::isnan(v);
130105
}
131106

132107
return false;

dpnp/backend/extensions/statistics/histogram.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,11 @@
2525

2626
#pragma once
2727

28+
#include <pybind11/pybind11.h>
2829
#include <sycl/sycl.hpp>
2930

3031
#include "dispatch_table.hpp"
32+
#include "dpctl4pybind11.hpp"
3133

3234
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
3335

dpnp/backend/extensions/statistics/histogram_common.hpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,13 @@ struct CachedData
7474
template <int _Dims>
7575
void init(const sycl::nd_item<_Dims> &item) const
7676
{
77-
int32_t llid = item.get_local_linear_id();
77+
uint32_t llid = item.get_local_linear_id();
7878
auto local_ptr = &local_data[0];
79-
int32_t size = local_data.size();
79+
uint32_t size = local_data.size();
8080
auto group = item.get_group();
81-
int32_t local_size = group.get_local_linear_range();
81+
uint32_t local_size = group.get_local_linear_range();
8282

83-
for (int32_t i = llid; i < size; i += local_size) {
83+
for (uint32_t i = llid; i < size; i += local_size) {
8484
local_ptr[i] = global_data[i];
8585
}
8686
}
@@ -218,15 +218,15 @@ struct HistWithLocalCopies
218218
template <int _Dims>
219219
void finalize(const sycl::nd_item<_Dims> &item) const
220220
{
221-
int32_t llid = item.get_local_linear_id();
222-
int32_t bins_count = local_hist.get_range().get(1);
223-
int32_t local_hist_count = local_hist.get_range().get(0);
221+
uint32_t llid = item.get_local_linear_id();
222+
uint32_t bins_count = local_hist.get_range().get(1);
223+
uint32_t local_hist_count = local_hist.get_range().get(0);
224224
auto group = item.get_group();
225-
int32_t local_size = group.get_local_linear_range();
225+
uint32_t local_size = group.get_local_linear_range();
226226

227-
for (int32_t i = llid; i < bins_count; i += local_size) {
227+
for (uint32_t i = llid; i < bins_count; i += local_size) {
228228
auto value = local_hist[0][i];
229-
for (int32_t lhc = 1; lhc < local_hist_count; ++lhc) {
229+
for (uint32_t lhc = 1; lhc < local_hist_count; ++lhc) {
230230
value += local_hist[lhc][i];
231231
}
232232
if (value != T(0)) {

dpnp/backend/extensions/statistics/histogramdd.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ struct HistogramddF
245245
};
246246

247247
template <typename T, typename HistType = size_t>
248-
using HistogramddF2 = HistogramddF<T, T, HistType>;
248+
using HistogramddF_ = HistogramddF<T, T, HistType>;
249249

250250
using SupportedTypes =
251251
std::tuple<std::tuple<uint64_t, float>,
@@ -268,7 +268,7 @@ using SupportedTypes =
268268

269269
Histogramdd::Histogramdd() : dispatch_table("sample", "histogram")
270270
{
271-
dispatch_table.populate_dispatch_table<SupportedTypes, HistogramddF2>();
271+
dispatch_table.populate_dispatch_table<SupportedTypes, HistogramddF_>();
272272
}
273273

274274
std::tuple<sycl::event, sycl::event> Histogramdd::call(

dpnp/backend/extensions/statistics/histogramdd.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,11 @@
2525

2626
#pragma once
2727

28+
#include <pybind11/pybind11.h>
2829
#include <sycl/sycl.hpp>
2930

30-
// dpctl tensor headers
31-
#include "dpctl4pybind11.hpp"
32-
3331
#include "dispatch_table.hpp"
32+
#include "dpctl4pybind11.hpp"
3433

3534
namespace statistics
3635
{

dpnp/dpnp_iface_histograms.py

Lines changed: 15 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -235,32 +235,6 @@ def _get_bin_edges(a, bins, range, usm_type):
235235
return bin_edges, None
236236

237237

238-
def _normalize_array(a, dtype, usm_type=None):
239-
if usm_type is None:
240-
usm_type = a.usm_type
241-
242-
try:
243-
return dpnp.asarray(
244-
a,
245-
dtype=dtype,
246-
usm_type=usm_type,
247-
sycl_queue=a.sycl_queue,
248-
order="C",
249-
copy=False,
250-
)
251-
except ValueError:
252-
pass
253-
254-
return dpnp.asarray(
255-
a,
256-
dtype=dtype,
257-
usm_type=usm_type,
258-
sycl_queue=a.sycl_queue,
259-
order="C",
260-
copy=True,
261-
)
262-
263-
264238
def _bincount_validate(x, weights, minlength):
265239
if x.ndim > 1:
266240
raise ValueError("object too deep for desired array")
@@ -426,16 +400,16 @@ def bincount(x, weights=None, minlength=None):
426400
"supported types"
427401
)
428402

429-
x_casted = _normalize_array(x, dtype=x_casted_dtype)
403+
x_casted = dpnp.asarray(x, dtype=x_casted_dtype, order="C")
430404

431405
if weights is not None:
432-
weights_casted = _normalize_array(weights, dtype=ntype_casted)
406+
weights_casted = dpnp.asarray(weights, dtype=ntype_casted, order="C")
433407

434408
n_casted = _bincount_run_native(
435409
x_casted, weights_casted, minlength, ntype_casted, usm_type
436410
)
437411

438-
n = _normalize_array(n_casted, dtype=ntype, usm_type=usm_type)
412+
n = dpnp.asarray(n_casted, dtype=ntype, usm_type=usm_type, order="C")
439413

440414
return n
441415

@@ -643,10 +617,12 @@ def histogram(a, bins=10, range=None, density=None, weights=None):
643617
"supported types"
644618
)
645619

646-
a_casted = _normalize_array(a, a_bin_dtype)
647-
bin_edges_casted = _normalize_array(bin_edges, a_bin_dtype)
620+
a_casted = dpnp.asarray(a, dtype=a_bin_dtype, order="C")
621+
bin_edges_casted = dpnp.asarray(bin_edges, dtype=a_bin_dtype, order="C")
648622
weights_casted = (
649-
_normalize_array(weights, hist_dtype) if weights is not None else None
623+
dpnp.asarray(weights, dtype=hist_dtype, order="C")
624+
if weights is not None
625+
else None
650626
)
651627

652628
# histogram implementation uses atomics, but atomics doesn't work with
@@ -681,7 +657,7 @@ def histogram(a, bins=10, range=None, density=None, weights=None):
681657
)
682658
_manager.add_event_pair(mem_ev, ht_ev)
683659

684-
n = _normalize_array(n_casted, dtype=ntype, usm_type=usm_type)
660+
n = dpnp.asarray(n_casted, dtype=ntype, usm_type=usm_type, order="C")
685661

686662
if density:
687663
db = dpnp.astype(
@@ -794,17 +770,9 @@ def _histdd_validate_bins(bins):
794770
)
795771

796772

797-
def _histdd_check_monotonicity(bins):
798-
for i, b in enumerate(bins):
799-
if dpnp.any(b[:-1] > b[1:]):
800-
raise ValueError(
801-
f"bins[{i}] must increase monotonically, when an array"
802-
)
803-
804-
805773
def _histdd_normalize_bins(bins, ndims):
806774
if not isinstance(bins, Iterable):
807-
if not isinstance(bins, int):
775+
if not dpnp.issubdtype(type(bins), dpnp.integer):
808776
raise ValueError("'bins' must be an integer, when a scalar")
809777

810778
bins = [bins] * ndims
@@ -1053,11 +1021,11 @@ def histogramdd(sample, bins=10, range=None, weights=None, density=False):
10531021
bin_edges_list, edges_count_list, sample_dtype
10541022
)
10551023

1056-
_histdd_check_monotonicity(bin_edges_view_list)
1057-
1058-
sample_ = _normalize_array(sample, sample_dtype)
1024+
sample_ = dpnp.asarray(sample, dtype=sample_dtype, order="C")
10591025
weights_ = (
1060-
_normalize_array(weights, hist_dtype) if weights is not None else None
1026+
dpnp.asarray(weights, dtype=hist_dtype, order="C")
1027+
if weights is not None
1028+
else None
10611029
)
10621030
n = _histdd_run_native(
10631031
sample_,
@@ -1069,7 +1037,7 @@ def histogramdd(sample, bins=10, range=None, weights=None, density=False):
10691037
)
10701038

10711039
expexted_hist_dtype = _histdd_hist_dtype(queue, weights)
1072-
n = _normalize_array(n, expexted_hist_dtype, usm_type)
1040+
n = dpnp.asarray(n, dtype=expexted_hist_dtype, usm_type=usm_type, order="C")
10731041

10741042
if density:
10751043
# calculate the probability density function

dpnp/tests/test_histogram.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -642,10 +642,15 @@ def test_linspace_data(self, dtype):
642642
assert_array_equal(result_hist, expected_hist)
643643

644644
@pytest.mark.parametrize("xp", [numpy, dpnp])
645-
def test_invalid_bin(self, xp):
645+
def test_invalid_bin_float(self, xp):
646646
a = xp.array([[1, 2]])
647647
assert_raises(ValueError, xp.histogramdd, a, bins=0.1)
648648

649+
@pytest.mark.parametrize("xp", [numpy, dpnp])
650+
def test_invalid_bin_2d_array(self, xp):
651+
a = xp.array([[1, 2]])
652+
assert_raises(ValueError, xp.histogramdd, a, bins=[[[10]], 10])
653+
649654
@pytest.mark.parametrize(
650655
"bins",
651656
[
@@ -730,7 +735,7 @@ def test_infinite_edge(self, xp, inf_val):
730735
def test_unsigned_monotonicity_check(self, xp):
731736
# bins must increase monotonically when bins contain unsigned values
732737
arr = xp.array([2])
733-
bins = xp.array([1, 3, 1], dtype="uint64")
738+
bins = [xp.array([1, 3, 1], dtype="uint64")]
734739
with assert_raises(ValueError):
735740
xp.histogramdd(arr, bins=bins)
736741

0 commit comments

Comments
 (0)