Skip to content

Commit 98edb71

Browse files
committed
update retunrs for angle and reciprocal
update out keyword implementation for angle and reciprocal to make it consistent with implementation of other functions
1 parent 2e59ee9 commit 98edb71

File tree

2 files changed

+35
-49
lines changed

2 files changed

+35
-49
lines changed

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -447,10 +447,7 @@ def dpnp_angle(x, out=None, order="K"):
447447
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
448448

449449
res_usm = angle_func(x1_usm, out=out_usm, order=order)
450-
if out is None:
451-
return dpnp_array._create_from_usm_ndarray(res_usm)
452-
else:
453-
return out
450+
return _get_result(res_usm, out=out)
454451

455452

456453
_asin_docstring = """
@@ -2535,10 +2532,7 @@ def dpnp_reciprocal(x, out=None, order="K"):
25352532
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
25362533

25372534
res_usm = reciprocal_func(x1_usm, out=out_usm, order=order)
2538-
if out is None:
2539-
return dpnp_array._create_from_usm_ndarray(res_usm)
2540-
else:
2541-
return out
2535+
return _get_result(res_usm, out=out)
25422536

25432537

25442538
_remainder_docstring = """

tests/test_umath.py

Lines changed: 33 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -268,47 +268,6 @@ def test_invalid_shape(self, shape):
268268
dpnp.cbrt(dp_array, out=dp_out)
269269

270270

271-
class TestReciprocal:
272-
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
273-
def test_reciprocal(self, dtype):
274-
np_array = numpy.arange(1, 7, dtype=dtype)
275-
expected = numpy.reciprocal(np_array)
276-
277-
dp_out_dtype = (
278-
dtype
279-
if has_support_aspect64()
280-
else dpnp.complex64
281-
if numpy.iscomplexobj(np_array)
282-
else dpnp.float32
283-
)
284-
dp_array = dpnp.array(np_array)
285-
dp_out = dpnp.empty(6, dtype=dp_out_dtype)
286-
result = dpnp.reciprocal(dp_array, out=dp_out)
287-
288-
assert result is dp_out
289-
assert_dtype_allclose(result, expected)
290-
291-
@pytest.mark.parametrize("dtype", get_float_complex_dtypes()[:-1])
292-
def test_invalid_dtype(self, dtype):
293-
dpnp_dtype = get_float_complex_dtypes()[-1]
294-
dp_array = dpnp.arange(1, 10, dtype=dpnp_dtype)
295-
dp_out = dpnp.empty(9, dtype=dtype)
296-
297-
with pytest.raises(TypeError):
298-
dpnp.reciprocal(dp_array, out=dp_out)
299-
300-
@pytest.mark.parametrize("dtype", get_float_dtypes())
301-
@pytest.mark.parametrize(
302-
"shape", [(0,), (15,), (2, 2)], ids=["(0,)", "(15, )", "(2,2)"]
303-
)
304-
def test_invalid_shape(self, shape, dtype):
305-
dp_array = dpnp.arange(10, dtype=dtype)
306-
dp_out = dpnp.empty(shape, dtype=dtype)
307-
308-
with pytest.raises(ValueError):
309-
dpnp.reciprocal(dp_array, out=dp_out)
310-
311-
312271
class TestRsqrt:
313272
@pytest.mark.usefixtures("suppress_divide_numpy_warnings")
314273
@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True))
@@ -397,6 +356,39 @@ def test_invalid_out(self, out):
397356
numpy.testing.assert_raises(TypeError, numpy.square, a.asnumpy(), out)
398357

399358

359+
class TestReciprocal:
360+
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
361+
def test_reciprocal(self, dtype):
362+
np_array, expected = _get_numpy_arrays("reciprocal", dtype, [-5, 5, 10])
363+
364+
dp_array = dpnp.array(np_array)
365+
out_dtype = _get_output_data_type(dtype)
366+
dp_out = dpnp.empty(expected.shape, dtype=out_dtype)
367+
result = dpnp.reciprocal(dp_array, out=dp_out)
368+
369+
assert result is dp_out
370+
assert_dtype_allclose(result, expected)
371+
372+
@pytest.mark.parametrize("dtype", get_float_complex_dtypes()[:-1])
373+
def test_invalid_dtype(self, dtype):
374+
dpnp_dtype = get_float_complex_dtypes()[-1]
375+
dp_array = dpnp.arange(10, dtype=dpnp_dtype)
376+
dp_out = dpnp.empty(10, dtype=dtype)
377+
378+
with pytest.raises(TypeError):
379+
dpnp.reciprocal(dp_array, out=dp_out)
380+
381+
@pytest.mark.parametrize(
382+
"shape", [(0,), (15,), (2, 2)], ids=["(0,)", "(15, )", "(2,2)"]
383+
)
384+
def test_invalid_shape(self, shape):
385+
dp_array = dpnp.arange(10)
386+
dp_out = dpnp.empty(shape)
387+
388+
with pytest.raises(ValueError):
389+
dpnp.reciprocal(dp_array, out=dp_out)
390+
391+
400392
class TestArctan2:
401393
@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True))
402394
def test_arctan2(self, dtype):

0 commit comments

Comments
 (0)