diff --git a/.github/workflows/conda-package.yml b/.github/workflows/conda-package.yml index 8c1c6bfe4352..1a6650798e92 100644 --- a/.github/workflows/conda-package.yml +++ b/.github/workflows/conda-package.yml @@ -38,14 +38,7 @@ env: third_party/cupy/manipulation_tests/test_join.py third_party/cupy/manipulation_tests/test_rearrange.py third_party/cupy/manipulation_tests/test_transpose.py - third_party/cupy/math_tests/test_arithmetic.py - third_party/cupy/math_tests/test_explog.py - third_party/cupy/math_tests/test_floating.py - third_party/cupy/math_tests/test_hyperbolic.py - third_party/cupy/math_tests/test_matmul.py - third_party/cupy/math_tests/test_misc.py - third_party/cupy/math_tests/test_rounding.py - third_party/cupy/math_tests/test_trigonometric.py + third_party/cupy/math_tests third_party/cupy/sorting_tests/test_sort.py third_party/cupy/sorting_tests/test_count.py third_party/cupy/statistics_tests/test_meanvar.py diff --git a/dpnp/dpnp_algo/dpnp_algo_mathematical.pxi b/dpnp/dpnp_algo/dpnp_algo_mathematical.pxi index 431892f10217..ce1b0c5f894f 100644 --- a/dpnp/dpnp_algo/dpnp_algo_mathematical.pxi +++ b/dpnp/dpnp_algo/dpnp_algo_mathematical.pxi @@ -39,7 +39,6 @@ __all__ += [ "dpnp_cross", "dpnp_cumprod", "dpnp_cumsum", - "dpnp_diff", "dpnp_ediff1d", "dpnp_fabs", "dpnp_fmod", @@ -95,35 +94,6 @@ cpdef utils.dpnp_descriptor dpnp_cumsum(utils.dpnp_descriptor x1): return call_fptr_1in_1out(DPNP_FN_CUMSUM_EXT, x1, (x1.size,)) -cpdef utils.dpnp_descriptor dpnp_diff(utils.dpnp_descriptor x1, int n): - cdef utils.dpnp_descriptor res - - x1_obj = x1.get_array() - - if x1.size - n < 1: - res_obj = dpnp_container.empty(0, - dtype=x1.dtype, - device=x1_obj.sycl_device, - usm_type=x1_obj.usm_type, - sycl_queue=x1_obj.sycl_queue) - res = utils.dpnp_descriptor(res_obj) - return res - - res_obj = dpnp_container.empty(x1.size - 1, - dtype=x1.dtype, - device=x1_obj.sycl_device, - usm_type=x1_obj.usm_type, - sycl_queue=x1_obj.sycl_queue) - res = utils.dpnp_descriptor(res_obj) - for i in range(res.size): - res.get_pyobj()[i] = x1.get_pyobj()[i + 1] - x1.get_pyobj()[i] - - if n == 1: - return res - - return dpnp_diff(res, n - 1) - - cpdef utils.dpnp_descriptor dpnp_ediff1d(utils.dpnp_descriptor x1): if x1.size <= 1: diff --git a/dpnp/dpnp_iface_indexing.py b/dpnp/dpnp_iface_indexing.py index 6a61f728e7d2..ef21c3b9b185 100644 --- a/dpnp/dpnp_iface_indexing.py +++ b/dpnp/dpnp_iface_indexing.py @@ -550,17 +550,24 @@ def put_along_axis(a, indices, values, axis): For full documentation refer to :obj:`numpy.put_along_axis`. - Limitations - ----------- - Parameters `a` and `indices` are supported either as :class:`dpnp.ndarray` - or :class:`dpctl.tensor.usm_ndarray`. - Parameter `values` is supported either as scalar, :class:`dpnp.ndarray` - or :class:`dpctl.tensor.usm_ndarray`. - Otherwise ``TypeError`` exception will be raised. + Parameters + ---------- + a : {dpnp.ndarray, usm_ndarray}, (Ni..., M, Nk...) + Destination array. + indices : {dpnp.ndarray, usm_ndarray}, (Ni..., J, Nk...) + Indices to change along each 1d slice of `a`. This must match the + dimension of input array, but dimensions in ``Ni`` and ``Nj`` + may be 1 to broadcast against `a`. + values : {scalar, array_like}, (Ni..., J, Nk...) + Values to insert at those indices. Its shape and dimension are + broadcast to match that of `indices`. + axis : int + The axis to take 1d slices along. If axis is ``None``, the destination + array is treated as if a flattened 1d view had been created of it. See Also -------- - :obj:`dpnp.put` : Put values along an axis, using the same indices for every 1d slice. + :obj:`dpnp.put` : Put values along an axis, using the same indices for every 1d slice. :obj:`dpnp.take_along_axis` : Take values from the input array by matching 1d index and data slices. Examples @@ -736,17 +743,24 @@ def take_along_axis(a, indices, axis): For full documentation refer to :obj:`numpy.take_along_axis`. + Parameters + ---------- + a : {dpnp.ndarray, usm_ndarray}, (Ni..., M, Nk...) + Source array + indices : {dpnp.ndarray, usm_ndarray}, (Ni..., J, Nk...) + Indices to take along each 1d slice of `a`. This must match the + dimension of the input array, but dimensions ``Ni`` and ``Nj`` + only need to broadcast against `a`. + axis : int + The axis to take 1d slices along. If axis is ``None``, the input + array is treated as if it had first been flattened to 1d, + for consistency with `sort` and `argsort`. + Returns ------- out : dpnp.ndarray The indexed result. - Limitations - ----------- - Parameters `a` and `indices` are supported either as :class:`dpnp.ndarray` - or :class:`dpctl.tensor.usm_ndarray`. - Otherwise ``TypeError`` exception will be raised. - See Also -------- :obj:`dpnp.take` : Take along an axis, using the same indices for every 1d slice. diff --git a/dpnp/dpnp_iface_mathematical.py b/dpnp/dpnp_iface_mathematical.py index d619b5662b1b..89f4e831dcc2 100644 --- a/dpnp/dpnp_iface_mathematical.py +++ b/dpnp/dpnp_iface_mathematical.py @@ -43,7 +43,10 @@ import dpctl.tensor as dpt import dpctl.utils as du import numpy -from numpy.core.numeric import normalize_axis_tuple +from numpy.core.numeric import ( + normalize_axis_index, + normalize_axis_tuple, +) import dpnp from dpnp.dpnp_array import dpnp_array @@ -129,6 +132,31 @@ ] +def _append_to_diff_array(a, axis, combined, values): + """ + Append `values` to `combined` list based on data of array `a`. + + Scalar value (including case with 0d array) is expanded to an array + with length=1 in the direction of axis and the shape of the input array `a` + in along all other axes. + Note, if `values` is a scalar. then it is converted to 0d array allocating + on the same SYCL queue as the input array `a` and with the same USM type. + + """ + + dpnp.check_supported_arrays_type(values, scalar_type=True) + if dpnp.isscalar(values): + values = dpnp.asarray( + values, sycl_queue=a.sycl_queue, usm_type=a.usm_type + ) + + if values.ndim == 0: + shape = list(a.shape) + shape[axis] = 1 + values = dpnp.broadcast_to(values, tuple(shape)) + combined.append(values) + + def absolute( x, /, @@ -609,6 +637,10 @@ def cumsum(x1, **kwargs): Otherwise the function will be executed sequentially on CPU. Input array data types are limited by supported DPNP :ref:`Data types`. + See Also + -------- + :obj:`dpnp.diff` : Calculate the n-th discrete difference along the given axis. + Examples -------- >>> import dpnp as np @@ -630,39 +662,95 @@ def cumsum(x1, **kwargs): return call_origin(numpy.cumsum, x1, **kwargs) -def diff(x1, n=1, axis=-1, prepend=numpy._NoValue, append=numpy._NoValue): +def diff(a, n=1, axis=-1, prepend=None, append=None): """ Calculate the n-th discrete difference along the given axis. For full documentation refer to :obj:`numpy.diff`. - Limitations - ----------- - Input array is supported as :obj:`dpnp.ndarray`. - Parameters `axis`, `prepend` and `append` are supported only with default values. - Otherwise the function will be executed sequentially on CPU. + Parameters + ---------- + a : {dpnp.ndarray, usm_ndarray} + Input array + n : int, optional + The number of times values are differenced. If zero, the input + is returned as-is. + axis : int, optional + The axis along which the difference is taken, default is the + last axis. + prepend, append : {scalar, dpnp.ndarray, usm_ndarray}, optional + Values to prepend or append to `a` along axis prior to + performing the difference. Scalar values are expanded to + arrays with length 1 in the direction of axis and the shape + of the input array in along all other axes. Otherwise the + dimension and shape must match `a` except along axis. + + Returns + ------- + out : dpnp.ndarray + The n-th differences. The shape of the output is the same as `a` + except along `axis` where the dimension is smaller by `n`. The + type of the output is the same as the type of the difference + between any two elements of `a`. This is the same as the type of + `a` in most cases. + + See Also + -------- + :obj:`dpnp.gradient` : Return the gradient of an N-dimensional array. + :obj:`dpnp.ediff1d` : Compute the differences between consecutive elements of an array. + :obj:`dpnp.cumsum` : Return the cumulative sum of the elements along a given axis. + + Examples + -------- + >>> import dpnp as np + >>> x = np.array([1, 2, 4, 7, 0]) + >>> np.diff(x) + array([ 1, 2, 3, -7]) + >>> np.diff(x, n=2) + array([ 1, 1, -10]) + + >>> x = np.array([[1, 3, 6, 10], [0, 5, 6, 8]]) + >>> np.diff(x) + array([[2, 3, 4], + [5, 1, 2]]) + >>> np.diff(x, axis=0) + array([[-1, 2, 0, -2]]) + """ - x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False) - if x1_desc: - if not isinstance(n, int): - pass - elif n < 1: - pass - elif x1_desc.ndim != 1: - pass - elif axis != -1: - pass - elif prepend is not numpy._NoValue: - pass - elif append is not numpy._NoValue: - pass - else: - return dpnp_diff(x1_desc, n).get_pyobj() + dpnp.check_supported_arrays_type(a) + if n == 0: + return a + if n < 0: + raise ValueError(f"order must be non-negative but got {n}") - return call_origin( - numpy.diff, x1, n=n, axis=axis, prepend=prepend, append=append - ) + nd = a.ndim + if nd == 0: + raise ValueError("diff requires input that is at least one dimensional") + axis = normalize_axis_index(axis, nd) + + combined = [] + if prepend is not None: + _append_to_diff_array(a, axis, combined, prepend) + + combined.append(a) + if append is not None: + _append_to_diff_array(a, axis, combined, append) + + if len(combined) > 1: + a = dpnp.concatenate(combined, axis=axis) + + slice1 = [slice(None)] * nd + slice2 = [slice(None)] * nd + slice1[axis] = slice(1, None) + slice2[axis] = slice(None, -1) + slice1 = tuple(slice1) + slice2 = tuple(slice2) + + op = dpnp.not_equal if a.dtype == numpy.bool_ else dpnp.subtract + for _ in range(n): + a = op(a[slice1], a[slice2]) + return a def divide( @@ -1276,6 +1364,10 @@ def gradient(x1, *varargs, **kwargs): Otherwise the function will be executed sequentially on CPU. Input array data types are limited by supported DPNP :ref:`Data types`. + See Also + -------- + :obj:`dpnp.diff` : Calculate the n-th discrete difference along the given axis. + Examples -------- >>> import dpnp as np diff --git a/tests/helper.py b/tests/helper.py index de4db998a7bf..8fa26116756d 100644 --- a/tests/helper.py +++ b/tests/helper.py @@ -76,6 +76,14 @@ def get_integer_dtypes(): return [dpnp.int32, dpnp.int64] +def get_integer_dtypes(): + """ + Build a list of integer types supported by DPNP. + """ + + return [dpnp.int32, dpnp.int64] + + def get_complex_dtypes(device=None): """ Build a list of complex types supported by DPNP based on device capabilities. diff --git a/tests/skipped_tests.tbl b/tests/skipped_tests.tbl index d32f1ee78c0c..87a29e2cb0c5 100644 --- a/tests/skipped_tests.tbl +++ b/tests/skipped_tests.tbl @@ -557,24 +557,22 @@ tests/third_party/cupy/math_tests/test_misc.py::TestConvolve::test_convolve_diff tests/third_party/cupy/math_tests/test_rounding.py::TestRounding::test_fix tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_out tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_out_wrong_shape + tests/third_party/cupy/math_tests/test_sumprod.py::TestCumprod::test_ndarray_cumprod_2dim_with_axis -tests/third_party/cupy/math_tests/test_sumprod.py::TestDiff::test_diff_1dim -tests/third_party/cupy/math_tests/test_sumprod.py::TestDiff::test_diff_1dim_with_n -tests/third_party/cupy/math_tests/test_sumprod.py::TestDiff::test_diff_2dim_without_axis tests/third_party/cupy/math_tests/test_sumprod.py::TestCumprod::test_cumprod_arraylike tests/third_party/cupy/math_tests/test_sumprod.py::TestCumprod::test_cumprod_huge_array tests/third_party/cupy/math_tests/test_sumprod.py::TestCumprod::test_cumprod_numpy_array tests/third_party/cupy/math_tests/test_sumprod.py::TestCumprod::test_cumprod_out_noncontiguous +tests/third_party/cupy/math_tests/test_sumprod.py::TestCumprod::test_cumprod_1dim +tests/third_party/cupy/math_tests/test_sumprod.py::TestCumprod::test_cumprod_2dim_without_axis + tests/third_party/cupy/math_tests/test_sumprod.py::TestCumsum_param_0_{axis=0}::test_cumsum_arraylike tests/third_party/cupy/math_tests/test_sumprod.py::TestCumsum_param_0_{axis=0}::test_cumsum_numpy_array tests/third_party/cupy/math_tests/test_sumprod.py::TestCumsum_param_1_{axis=1}::test_cumsum_arraylike tests/third_party/cupy/math_tests/test_sumprod.py::TestCumsum_param_1_{axis=1}::test_cumsum_numpy_array tests/third_party/cupy/math_tests/test_sumprod.py::TestCumsum_param_2_{axis=2}::test_cumsum_arraylike tests/third_party/cupy/math_tests/test_sumprod.py::TestCumsum_param_2_{axis=2}::test_cumsum_numpy_array -tests/third_party/cupy/math_tests/test_sumprod.py::TestDiff::test_diff_2dim_with_append -tests/third_party/cupy/math_tests/test_sumprod.py::TestDiff::test_diff_2dim_with_axis -tests/third_party/cupy/math_tests/test_sumprod.py::TestDiff::test_diff_2dim_with_n_and_axis -tests/third_party/cupy/math_tests/test_sumprod.py::TestDiff::test_diff_2dim_with_prepend + tests/third_party/cupy/math_tests/test_sumprod.py::TestNansumNanprodAxes_param_0_{axis=(1, 3), shape=(2, 3, 4, 5)}::test_nansum_axes tests/third_party/cupy/math_tests/test_sumprod.py::TestNansumNanprodAxes_param_1_{axis=(1, 3), shape=(20, 30, 40, 50)}::test_nansum_axes tests/third_party/cupy/math_tests/test_sumprod.py::TestNansumNanprodAxes_param_2_{axis=(0, 2, 3), shape=(2, 3, 4, 5)}::test_nansum_axes diff --git a/tests/skipped_tests_gpu.tbl b/tests/skipped_tests_gpu.tbl index c5cf53b2a71c..b6f6ceb45913 100644 --- a/tests/skipped_tests_gpu.tbl +++ b/tests/skipped_tests_gpu.tbl @@ -654,16 +654,15 @@ tests/third_party/cupy/math_tests/test_misc.py::TestConvolve::test_convolve_diff tests/third_party/cupy/math_tests/test_rounding.py::TestRounding::test_fix tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_out tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_out_wrong_shape + tests/third_party/cupy/math_tests/test_sumprod.py::TestCumprod::test_ndarray_cumprod_2dim_with_axis -tests/third_party/cupy/math_tests/test_sumprod.py::TestDiff::test_diff_1dim -tests/third_party/cupy/math_tests/test_sumprod.py::TestDiff::test_diff_1dim_with_n -tests/third_party/cupy/math_tests/test_sumprod.py::TestDiff::test_diff_2dim_without_axis tests/third_party/cupy/math_tests/test_sumprod.py::TestCumprod::test_cumprod_arraylike tests/third_party/cupy/math_tests/test_sumprod.py::TestCumprod::test_cumprod_huge_array tests/third_party/cupy/math_tests/test_sumprod.py::TestCumprod::test_cumprod_numpy_array tests/third_party/cupy/math_tests/test_sumprod.py::TestCumprod::test_cumprod_out_noncontiguous tests/third_party/cupy/math_tests/test_sumprod.py::TestCumprod::test_cumprod_1dim tests/third_party/cupy/math_tests/test_sumprod.py::TestCumprod::test_cumprod_2dim_without_axis + tests/third_party/cupy/math_tests/test_sumprod.py::TestCumsum_param_0_{axis=0}::test_cumsum tests/third_party/cupy/math_tests/test_sumprod.py::TestCumsum_param_0_{axis=0}::test_cumsum_2dim tests/third_party/cupy/math_tests/test_sumprod.py::TestCumsum_param_1_{axis=1}::test_cumsum @@ -676,10 +675,7 @@ tests/third_party/cupy/math_tests/test_sumprod.py::TestCumsum_param_1_{axis=1}:: tests/third_party/cupy/math_tests/test_sumprod.py::TestCumsum_param_1_{axis=1}::test_cumsum_numpy_array tests/third_party/cupy/math_tests/test_sumprod.py::TestCumsum_param_2_{axis=2}::test_cumsum_arraylike tests/third_party/cupy/math_tests/test_sumprod.py::TestCumsum_param_2_{axis=2}::test_cumsum_numpy_array -tests/third_party/cupy/math_tests/test_sumprod.py::TestDiff::test_diff_2dim_with_append -tests/third_party/cupy/math_tests/test_sumprod.py::TestDiff::test_diff_2dim_with_axis -tests/third_party/cupy/math_tests/test_sumprod.py::TestDiff::test_diff_2dim_with_n_and_axis -tests/third_party/cupy/math_tests/test_sumprod.py::TestDiff::test_diff_2dim_with_prepend + tests/third_party/cupy/math_tests/test_sumprod.py::TestNansumNanprodAxes_param_0_{axis=(1, 3), shape=(2, 3, 4, 5)}::test_nansum_axes tests/third_party/cupy/math_tests/test_sumprod.py::TestNansumNanprodAxes_param_1_{axis=(1, 3), shape=(20, 30, 40, 50)}::test_nansum_axes tests/third_party/cupy/math_tests/test_sumprod.py::TestNansumNanprodAxes_param_2_{axis=(0, 2, 3), shape=(2, 3, 4, 5)}::test_nansum_axes diff --git a/tests/test_arraycreation.py b/tests/test_arraycreation.py index 779e62237a08..0a4ce2063379 100644 --- a/tests/test_arraycreation.py +++ b/tests/test_arraycreation.py @@ -14,6 +14,7 @@ import dpnp from .helper import ( + assert_dtype_allclose, get_all_dtypes, has_support_aspect64, ) @@ -876,4 +877,4 @@ def test_logspace_axis(axis): func = lambda xp: xp.logspace( [2, 3], [20, 15], num=2, base=[[1, 3], [5, 7]], axis=axis ) - assert_allclose(func(dpnp), func(numpy)) + assert_dtype_allclose(func(dpnp), func(numpy)) diff --git a/tests/test_mathematical.py b/tests/test_mathematical.py index 4f751b697fef..7484a66bfb53 100644 --- a/tests/test_mathematical.py +++ b/tests/test_mathematical.py @@ -5,6 +5,7 @@ import pytest from numpy.testing import ( assert_allclose, + assert_almost_equal, assert_array_almost_equal, assert_array_equal, assert_equal, @@ -19,12 +20,183 @@ get_complex_dtypes, get_float_complex_dtypes, get_float_dtypes, + get_integer_dtypes, has_support_aspect64, is_cpu_device, is_win_platform, ) +class TestDiff: + @pytest.mark.parametrize("n", list(range(0, 3))) + @pytest.mark.parametrize("dt", get_integer_dtypes()) + def test_basic_integer(self, n, dt): + x = [1, 4, 6, 7, 12] + np_a = numpy.array(x, dtype=dt) + dpnp_a = dpnp.array(x, dtype=dt) + + expected = numpy.diff(np_a, n=n) + result = dpnp.diff(dpnp_a, n=n) + assert_array_equal(expected, result) + + @pytest.mark.parametrize("dt", get_float_dtypes()) + def test_basic_floating(self, dt): + x = [1.1, 2.2, 3.0, -0.2, -0.1] + np_a = numpy.array(x, dtype=dt) + dpnp_a = dpnp.array(x, dtype=dt) + + expected = numpy.diff(np_a) + result = dpnp.diff(dpnp_a) + assert_almost_equal(expected, result) + + @pytest.mark.parametrize("n", [1, 2]) + def test_basic_boolean(self, n): + x = [True, True, False, False] + np_a = numpy.array(x) + dpnp_a = dpnp.array(x) + + expected = numpy.diff(np_a, n=n) + result = dpnp.diff(dpnp_a, n=n) + assert_array_equal(expected, result) + + @pytest.mark.parametrize("dt", get_complex_dtypes()) + def test_basic_complex(self, dt): + x = [1.1 + 1j, 2.2 + 4j, 3.0 + 6j, -0.2 + 7j, -0.1 + 12j] + np_a = numpy.array(x, dtype=dt) + dpnp_a = dpnp.array(x, dtype=dt) + + expected = numpy.diff(np_a) + result = dpnp.diff(dpnp_a) + assert_allclose(expected, result) + + @pytest.mark.parametrize("axis", [None] + list(range(-3, 2))) + def test_axis(self, axis): + np_a = numpy.zeros((10, 20, 30)) + np_a[:, 1::2, :] = 1 + dpnp_a = dpnp.array(np_a) + + kwargs = {} if axis is None else {"axis": axis} + expected = numpy.diff(np_a, **kwargs) + result = dpnp.diff(dpnp_a, **kwargs) + assert_array_equal(expected, result) + + @pytest.mark.parametrize("xp", [numpy, dpnp]) + @pytest.mark.parametrize("axis", [-4, 3]) + def test_axis_error(self, xp, axis): + a = xp.ones((10, 20, 30)) + assert_raises(numpy.AxisError, xp.diff, a, axis=axis) + + @pytest.mark.parametrize("xp", [numpy, dpnp]) + def test_ndim_error(self, xp): + a = xp.array(1.1111111, xp.float32) + assert_raises(ValueError, xp.diff, a) + + @pytest.mark.parametrize("n", [None, 2]) + @pytest.mark.parametrize("axis", [None, 0]) + def test_nd(self, n, axis): + np_a = 20 * numpy.random.rand(10, 20, 30) + dpnp_a = dpnp.array(np_a) + + kwargs = {} if n is None else {"n": n} + if axis is not None: + kwargs.update({"axis": axis}) + + expected = numpy.diff(np_a, **kwargs) + result = dpnp.diff(dpnp_a, **kwargs) + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize("n", list(range(0, 5))) + def test_n(self, n): + np_a = numpy.array(list(range(3))) + dpnp_a = dpnp.array(np_a) + + expected = numpy.diff(np_a, n=n) + result = dpnp.diff(dpnp_a, n=n) + assert_array_equal(expected, result) + + @pytest.mark.parametrize("xp", [numpy, dpnp]) + def test_n_error(self, xp): + a = xp.array(list(range(3))) + assert_raises(ValueError, xp.diff, a, n=-1) + + @pytest.mark.parametrize("prepend", [0, [0], [-1, 0]]) + def test_prepend(self, prepend): + np_a = numpy.arange(5) + 1 + dpnp_a = dpnp.array(np_a) + + np_p = prepend if numpy.isscalar(prepend) else numpy.array(prepend) + dpnp_p = prepend if dpnp.isscalar(prepend) else dpnp.array(prepend) + + expected = numpy.diff(np_a, prepend=np_p) + result = dpnp.diff(dpnp_a, prepend=dpnp_p) + assert_array_equal(expected, result) + + @pytest.mark.parametrize( + "axis, prepend", + [ + pytest.param(0, 0), + pytest.param(0, [[0, 0]]), + pytest.param(1, 0), + pytest.param(1, [[0], [0]]), + ], + ) + def test_prepend_axis(self, axis, prepend): + np_a = numpy.arange(4).reshape(2, 2) + dpnp_a = dpnp.array(np_a) + + np_p = prepend if numpy.isscalar(prepend) else numpy.array(prepend) + dpnp_p = prepend if dpnp.isscalar(prepend) else dpnp.array(prepend) + + expected = numpy.diff(np_a, axis=axis, prepend=np_p) + result = dpnp.diff(dpnp_a, axis=axis, prepend=dpnp_p) + assert_array_equal(expected, result) + + @pytest.mark.parametrize("append", [0, [0], [0, 2]]) + def test_append(self, append): + np_a = numpy.arange(5) + dpnp_a = dpnp.array(np_a) + + np_ap = append if numpy.isscalar(append) else numpy.array(append) + dpnp_ap = append if dpnp.isscalar(append) else dpnp.array(append) + + expected = numpy.diff(np_a, append=np_ap) + result = dpnp.diff(dpnp_a, append=dpnp_ap) + assert_array_equal(expected, result) + + @pytest.mark.parametrize( + "axis, append", + [ + pytest.param(0, 0), + pytest.param(0, [[0, 0]]), + pytest.param(1, 0), + pytest.param(1, [[0], [0]]), + ], + ) + def test_append_axis(self, axis, append): + np_a = numpy.arange(4).reshape(2, 2) + dpnp_a = dpnp.array(np_a) + + np_ap = append if numpy.isscalar(append) else numpy.array(append) + dpnp_ap = append if dpnp.isscalar(append) else dpnp.array(append) + + expected = numpy.diff(np_a, axis=axis, append=np_ap) + result = dpnp.diff(dpnp_a, axis=axis, append=dpnp_ap) + assert_array_equal(expected, result) + + @pytest.mark.parametrize("xp", [numpy, dpnp]) + def test_prepend_append_error(self, xp): + a = xp.arange(4).reshape(2, 2) + p = xp.zeros((3, 3)) + assert_raises(ValueError, xp.diff, a, prepend=p) + assert_raises(ValueError, xp.diff, a, append=p) + + @pytest.mark.parametrize("xp", [numpy, dpnp]) + def test_prepend_append_axis_error(self, xp): + a = xp.arange(4).reshape(2, 2) + assert_raises(numpy.AxisError, xp.diff, a, axis=3, prepend=0) + assert_raises(numpy.AxisError, xp.diff, a, axis=3, append=0) + + @pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestConvolve: def test_object(self): @@ -54,35 +226,6 @@ def test_mode(self): dpnp.convolve(d, k, mode=None) -@pytest.mark.usefixtures("allow_fall_back_on_numpy") -@pytest.mark.parametrize( - "array", - [ - [[0, 0], [0, 0]], - [[1, 2], [1, 2]], - [[1, 2], [3, 4]], - [[[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]]], - [ - [[[1, 2], [3, 4]], [[1, 2], [2, 1]]], - [[[1, 3], [3, 1]], [[0, 1], [1, 3]]], - ], - ], - ids=[ - "[[0, 0], [0, 0]]", - "[[1, 2], [1, 2]]", - "[[1, 2], [3, 4]]", - "[[[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]]]", - "[[[[1, 2], [3, 4]], [[1, 2], [2, 1]]], [[[1, 3], [3, 1]], [[0, 1], [1, 3]]]]", - ], -) -def test_diff(array): - np_a = numpy.array(array) - dpnp_a = dpnp.array(array) - expected = numpy.diff(np_a) - result = dpnp.diff(dpnp_a) - assert_allclose(expected, result) - - @pytest.mark.parametrize("dtype1", get_all_dtypes()) @pytest.mark.parametrize("dtype2", get_all_dtypes()) @pytest.mark.parametrize( @@ -110,10 +253,6 @@ def test_op_multiple_dtypes(dtype1, func, dtype2, data): "rhs", [[[1, 2, 3], [4, 5, 6]], [2.0, 1.5, 1.0], 3, 0.3] ) @pytest.mark.parametrize("lhs", [[[6, 5, 4], [3, 2, 1]], [1.3, 2.6, 3.9]]) -# TODO: achieve the same level of dtype support for all mathematical operations, like -# @pytest.mark.parametrize("dtype", get_all_dtypes()) -# and to get rid of fallbacks on numpy allowed by below fixture -# @pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestMathematical: @staticmethod def array_or_scalar(xp, data, dtype=None): diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index fc4dbf9f0d6e..3c658c14fe52 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -1299,6 +1299,32 @@ def test_asarray(device_x, device_y): assert_sycl_queue_equal(y.sycl_queue, x.to_device(device_y).sycl_queue) +@pytest.mark.parametrize( + "device", + valid_devices, + ids=[device.filter_string for device in valid_devices], +) +@pytest.mark.parametrize( + "kwargs", + [ + pytest.param({"prepend": 7}), + pytest.param({"append": -2}), + pytest.param({"prepend": -4, "append": 5}), + ], +) +def test_diff_scalar_append(device, kwargs): + numpy_data = numpy.arange(7) + dpnp_data = dpnp.array(numpy_data, device=device) + + expected = numpy.diff(numpy_data, **kwargs) + result = dpnp.diff(dpnp_data, **kwargs) + assert_allclose(expected, result) + + expected_queue = dpnp_data.get_array().sycl_queue + result_queue = result.get_array().sycl_queue + assert_sycl_queue_equal(result_queue, expected_queue) + + @pytest.mark.parametrize("func", ["take", "take_along_axis"]) @pytest.mark.parametrize( "device", @@ -1318,5 +1344,4 @@ def test_take(func, device): expected_queue = dpnp_data.get_array().sycl_queue result_queue = result.get_array().sycl_queue - assert_sycl_queue_equal(result_queue, expected_queue) diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index f82e04a2a566..4982ed424140 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -382,6 +382,7 @@ def test_meshgrid(usm_type_x, usm_type_y): ), pytest.param("cosh", [-5.0, -3.5, 0.0, 3.5, 5.0]), pytest.param("count_nonzero", [0, 1, 7, 0]), + pytest.param("diff", [1.0, 2.0, 4.0, 7.0, 0.0]), pytest.param("exp", [1.0, 2.0, 4.0, 7.0]), pytest.param("exp2", [0.0, 1.0, 2.0]), pytest.param("expm1", [1.0e-10, 1.0, 2.0, 4.0, 7.0]), diff --git a/tests/third_party/cupy/math_tests/test_sumprod.py b/tests/third_party/cupy/math_tests/test_sumprod.py index 5834ac94fe2f..0728382a5b43 100644 --- a/tests/third_party/cupy/math_tests/test_sumprod.py +++ b/tests/third_party/cupy/math_tests/test_sumprod.py @@ -360,10 +360,10 @@ def test_nansum_axes(self, xp, dtype): @testing.parameterize(*testing.product({"axis": axes})) @pytest.mark.usefixtures("allow_fall_back_on_numpy") -@testing.gpu +# TODO: remove "type_check=False" once leveraged on dpctl call class TestCumsum(unittest.TestCase): @testing.for_all_dtypes() - @testing.numpy_cupy_allclose() + @testing.numpy_cupy_allclose(type_check=False) def test_cumsum(self, xp, dtype): a = testing.shaped_arange((5,), xp, dtype) return xp.cumsum(a) @@ -385,7 +385,7 @@ def test_cumsum_out_noncontiguous(self, xp, dtype): return out @testing.for_all_dtypes() - @testing.numpy_cupy_allclose() + @testing.numpy_cupy_allclose(type_check=False) def test_cumsum_2dim(self, xp, dtype): a = testing.shaped_arange((4, 5), xp, dtype) return xp.cumsum(a) @@ -569,8 +569,7 @@ def test_cumprod_numpy_array(self, dtype): return cupy.cumprod(a_numpy) -@testing.gpu -class TestDiff(unittest.TestCase): +class TestDiff: @testing.for_all_dtypes() @testing.numpy_cupy_allclose() def test_diff_1dim(self, xp, dtype): @@ -617,7 +616,6 @@ def test_diff_2dim_with_append(self, xp, dtype): b = testing.shaped_arange((1, 5), xp, dtype) return xp.diff(a, axis=0, append=b, n=2) - @pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.with_requires("numpy>=1.16") @testing.for_all_dtypes() @testing.numpy_cupy_allclose(type_check=has_support_aspect64()) @@ -625,7 +623,6 @@ def test_diff_2dim_with_scalar_append(self, xp, dtype): a = testing.shaped_arange((4, 5), xp, dtype) return xp.diff(a, prepend=1, append=0) - @pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.with_requires("numpy>=1.16") def test_diff_invalid_axis(self): for xp in (numpy, cupy):