Skip to content

Commit dd1fca3

Browse files
Applying review comments
1 parent 494b841 commit dd1fca3

File tree

7 files changed

+35
-75
lines changed

7 files changed

+35
-75
lines changed

dpnp/backend/extensions/statistics/bincount.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@
2525

2626
#pragma once
2727

28-
#include <dpctl4pybind11.hpp>
28+
#include <pybind11/pybind11.h>
2929
#include <sycl/sycl.hpp>
3030

3131
#include "dispatch_table.hpp"
32+
#include "dpctl4pybind11.hpp"
3233

3334
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
3435

dpnp/backend/extensions/statistics/common.hpp

Lines changed: 7 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -28,33 +28,10 @@
2828
#include <complex>
2929
#include <pybind11/numpy.h>
3030
#include <pybind11/pybind11.h>
31-
32-
// clang-format off
33-
// math_utils.hpp doesn't include sycl header but uses sycl types
34-
// so sycl.hpp must be included before math_utils.hpp
3531
#include <sycl/sycl.hpp>
32+
3633
#include "utils/math_utils.hpp"
3734
#include "utils/type_utils.hpp"
38-
// clang-format on
39-
40-
namespace dpctl
41-
{
42-
namespace tensor
43-
{
44-
namespace type_utils
45-
{
46-
// Upstream to dpctl
47-
template <class T>
48-
struct is_complex<const std::complex<T>> : std::true_type
49-
{
50-
};
51-
52-
template <typename T>
53-
constexpr bool is_complex_v = is_complex<T>::value;
54-
55-
} // namespace type_utils
56-
} // namespace tensor
57-
} // namespace dpctl
5835

5936
namespace type_utils = dpctl::tensor::type_utils;
6037

@@ -115,18 +92,16 @@ struct IsNan
11592
static bool isnan(const T &v)
11693
{
11794
if constexpr (type_utils::is_complex_v<T>) {
118-
const auto real1 = std::real(v);
119-
const auto imag1 = std::imag(v);
120-
12195
using vT = typename T::value_type;
12296

97+
const vT real1 = std::real(v);
98+
const vT imag1 = std::imag(v);
99+
123100
return IsNan<vT>::isnan(real1) || IsNan<vT>::isnan(imag1);
124101
}
125-
else {
126-
if constexpr (std::is_floating_point_v<T> ||
127-
std::is_same_v<T, sycl::half>) {
128-
return sycl::isnan(v);
129-
}
102+
else if constexpr (std::is_floating_point_v<T> ||
103+
std::is_same_v<T, sycl::half>) {
104+
return sycl::isnan(v);
130105
}
131106

132107
return false;

dpnp/backend/extensions/statistics/histogram.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,11 @@
2525

2626
#pragma once
2727

28+
#include <pybind11/pybind11.h>
2829
#include <sycl/sycl.hpp>
2930

3031
#include "dispatch_table.hpp"
32+
#include "dpctl4pybind11.hpp"
3133

3234
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
3335

dpnp/backend/extensions/statistics/histogramdd.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ struct HistogramddF
245245
};
246246

247247
template <typename T, typename HistType = size_t>
248-
using HistogramddF2 = HistogramddF<T, T, HistType>;
248+
using HistogramddF_ = HistogramddF<T, T, HistType>;
249249

250250
using SupportedTypes =
251251
std::tuple<std::tuple<uint64_t, float>,
@@ -268,7 +268,7 @@ using SupportedTypes =
268268

269269
Histogramdd::Histogramdd() : dispatch_table("sample", "histogram")
270270
{
271-
dispatch_table.populate_dispatch_table<SupportedTypes, HistogramddF2>();
271+
dispatch_table.populate_dispatch_table<SupportedTypes, HistogramddF_>();
272272
}
273273

274274
std::tuple<sycl::event, sycl::event> Histogramdd::call(

dpnp/backend/extensions/statistics/histogramdd.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,11 @@
2525

2626
#pragma once
2727

28+
#include <pybind11/pybind11.h>
2829
#include <sycl/sycl.hpp>
2930

30-
// dpctl tensor headers
31-
#include "dpctl4pybind11.hpp"
32-
3331
#include "dispatch_table.hpp"
32+
#include "dpctl4pybind11.hpp"
3433

3534
namespace statistics
3635
{

dpnp/dpnp_iface_histograms.py

Lines changed: 14 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -235,32 +235,6 @@ def _get_bin_edges(a, bins, range, usm_type):
235235
return bin_edges, None
236236

237237

238-
def _normalize_array(a, dtype, usm_type=None):
239-
if usm_type is None:
240-
usm_type = a.usm_type
241-
242-
try:
243-
return dpnp.asarray(
244-
a,
245-
dtype=dtype,
246-
usm_type=usm_type,
247-
sycl_queue=a.sycl_queue,
248-
order="C",
249-
copy=False,
250-
)
251-
except ValueError:
252-
pass
253-
254-
return dpnp.asarray(
255-
a,
256-
dtype=dtype,
257-
usm_type=usm_type,
258-
sycl_queue=a.sycl_queue,
259-
order="C",
260-
copy=True,
261-
)
262-
263-
264238
def _bincount_validate(x, weights, minlength):
265239
if x.ndim > 1:
266240
raise ValueError("object too deep for desired array")
@@ -426,16 +400,16 @@ def bincount(x, weights=None, minlength=None):
426400
"supported types"
427401
)
428402

429-
x_casted = _normalize_array(x, dtype=x_casted_dtype)
403+
x_casted = dpnp.asarray(x, dtype=x_casted_dtype, order="C")
430404

431405
if weights is not None:
432-
weights_casted = _normalize_array(weights, dtype=ntype_casted)
406+
weights_casted = dpnp.asarray(weights, dtype=ntype_casted, order="C")
433407

434408
n_casted = _bincount_run_native(
435409
x_casted, weights_casted, minlength, ntype_casted, usm_type
436410
)
437411

438-
n = _normalize_array(n_casted, dtype=ntype, usm_type=usm_type)
412+
n = dpnp.asarray(n_casted, dtype=ntype, usm_type=usm_type, order="C")
439413

440414
return n
441415

@@ -643,10 +617,12 @@ def histogram(a, bins=10, range=None, density=None, weights=None):
643617
"supported types"
644618
)
645619

646-
a_casted = _normalize_array(a, a_bin_dtype)
647-
bin_edges_casted = _normalize_array(bin_edges, a_bin_dtype)
620+
a_casted = dpnp.asarray(a, dtype=a_bin_dtype, order="C")
621+
bin_edges_casted = dpnp.asarray(bin_edges, dtype=a_bin_dtype, order="C")
648622
weights_casted = (
649-
_normalize_array(weights, hist_dtype) if weights is not None else None
623+
dpnp.asarray(weights, dtype=hist_dtype, order="C")
624+
if weights is not None
625+
else None
650626
)
651627

652628
# histogram implementation uses atomics, but atomics doesn't work with
@@ -681,7 +657,7 @@ def histogram(a, bins=10, range=None, density=None, weights=None):
681657
)
682658
_manager.add_event_pair(mem_ev, ht_ev)
683659

684-
n = _normalize_array(n_casted, dtype=ntype, usm_type=usm_type)
660+
n = dpnp.asarray(n_casted, dtype=ntype, usm_type=usm_type, order="C")
685661

686662
if density:
687663
db = dpnp.astype(
@@ -1055,9 +1031,11 @@ def histogramdd(sample, bins=10, range=None, weights=None, density=False):
10551031

10561032
_histdd_check_monotonicity(bin_edges_view_list)
10571033

1058-
sample_ = _normalize_array(sample, sample_dtype)
1034+
sample_ = dpnp.asarray(sample, dtype=sample_dtype, order="C")
10591035
weights_ = (
1060-
_normalize_array(weights, hist_dtype) if weights is not None else None
1036+
dpnp.asarray(weights, dtype=hist_dtype, order="C")
1037+
if weights is not None
1038+
else None
10611039
)
10621040
n = _histdd_run_native(
10631041
sample_,
@@ -1069,7 +1047,7 @@ def histogramdd(sample, bins=10, range=None, weights=None, density=False):
10691047
)
10701048

10711049
expexted_hist_dtype = _histdd_hist_dtype(queue, weights)
1072-
n = _normalize_array(n, expexted_hist_dtype, usm_type)
1050+
n = dpnp.asarray(n, dtype=expexted_hist_dtype, usm_type=usm_type, order="C")
10731051

10741052
if density:
10751053
# calculate the probability density function

dpnp/tests/test_histogram.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -642,10 +642,15 @@ def test_linspace_data(self, dtype):
642642
assert_array_equal(result_hist, expected_hist)
643643

644644
@pytest.mark.parametrize("xp", [numpy, dpnp])
645-
def test_invalid_bin(self, xp):
645+
def test_invalid_bin_float(self, xp):
646646
a = xp.array([[1, 2]])
647647
assert_raises(ValueError, xp.histogramdd, a, bins=0.1)
648648

649+
@pytest.mark.parametrize("xp", [numpy, dpnp])
650+
def test_invalid_bin_2d_array(self, xp):
651+
a = xp.array([[1, 2]])
652+
assert_raises(ValueError, xp.histogramdd, a, bins=[[[10]]])
653+
649654
@pytest.mark.parametrize(
650655
"bins",
651656
[

0 commit comments

Comments
 (0)