Skip to content

Commit 20264e8

Browse files
authored
remove temporary workaround for dpnp.prod (#1768)
1 parent d380776 commit 20264e8

File tree

3 files changed

+1
-36
lines changed

3 files changed

+1
-36
lines changed

dpnp/dpnp_iface_mathematical.py

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939

4040

4141
import dpctl.tensor as dpt
42-
import dpctl.utils as du
4342
import numpy
4443
from numpy.core.numeric import (
4544
normalize_axis_index,
@@ -2266,32 +2265,7 @@ def prod(
22662265
22672266
"""
22682267

2269-
# Product reduction for complex output are known to fail for Gen9 with 2024.0 compiler
2270-
# TODO: get rid of this temporary work around when OneAPI 2024.1 is released
2271-
dpnp.check_supported_arrays_type(a)
2272-
_dtypes = (a.dtype, dtype)
2273-
_any_complex = any(
2274-
dpnp.issubdtype(dt, dpnp.complexfloating) for dt in _dtypes
2275-
)
2276-
device_mask = (
2277-
du.intel_device_info(a.sycl_device).get("device_id", 0) & 0xFF00
2278-
)
2279-
if _any_complex and device_mask in [0x3E00, 0x9B00]:
2280-
res = call_origin(
2281-
numpy.prod,
2282-
a,
2283-
axis=axis,
2284-
dtype=dtype,
2285-
out=out,
2286-
keepdims=keepdims,
2287-
initial=initial,
2288-
where=where,
2289-
)
2290-
if dpnp.isscalar(res):
2291-
# numpy may return a scalar, convert it back to dpnp array
2292-
return dpnp.array(res, sycl_queue=a.sycl_queue, usm_type=a.usm_type)
2293-
return res
2294-
elif initial is not None:
2268+
if initial is not None:
22952269
raise NotImplementedError(
22962270
"initial keyword argument is only supported with its default value."
22972271
)

tests/test_linalg.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -229,9 +229,6 @@ def test_cond(arr, p):
229229

230230

231231
class TestDet:
232-
# TODO: Remove the use of fixture for test_det
233-
# when dpnp.prod() will support complex dtypes on Gen9
234-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
235232
@pytest.mark.parametrize(
236233
"array",
237234
[
@@ -1054,9 +1051,6 @@ def test_solve_errors(self):
10541051

10551052

10561053
class TestSlogdet:
1057-
# TODO: Remove the use of fixture for test_slogdet_2d and test_slogdet_3d
1058-
# when dpnp.prod() will support complex dtypes on Gen9
1059-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
10601054
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
10611055
def test_slogdet_2d(self, dtype):
10621056
a_np = numpy.array([[1, 2], [3, 4]], dtype=dtype)
@@ -1068,7 +1062,6 @@ def test_slogdet_2d(self, dtype):
10681062
assert_allclose(sign_expected, sign_result)
10691063
assert_allclose(logdet_expected, logdet_result, rtol=1e-3, atol=1e-4)
10701064

1071-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
10721065
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
10731066
def test_slogdet_3d(self, dtype):
10741067
a_np = numpy.array(

tests/test_mathematical.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -752,7 +752,6 @@ def test_positive_boolean():
752752

753753

754754
class TestProd:
755-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
756755
@pytest.mark.parametrize("func", ["prod", "nanprod"])
757756
@pytest.mark.parametrize("axis", [None, 0, 1, -1, 2, -2, (1, 2), (0, -2)])
758757
@pytest.mark.parametrize("keepdims", [False, True])
@@ -790,7 +789,6 @@ def test_prod_nanprod_bool(self, func, axis, keepdims):
790789
dpnp_res = getattr(dpnp, func)(ia, axis=axis, keepdims=keepdims)
791790
assert_dtype_allclose(dpnp_res, np_res)
792791

793-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
794792
@pytest.mark.usefixtures("suppress_complex_warning")
795793
@pytest.mark.usefixtures("suppress_invalid_numpy_warnings")
796794
@pytest.mark.parametrize("func", ["prod", "nanprod"])

0 commit comments

Comments
 (0)