Skip to content

Commit d1aa4c8

Browse files
authored
Implement dpnp.flatnonzero function (#1956)
* Add implementation of dpnp.flatnonzero() * Update third party tests * Add tests for SYCL queue and USM type * Roll back chnage in .rst
1 parent c94f9f8 commit d1aa4c8

File tree

6 files changed

+58
-10
lines changed

6 files changed

+58
-10
lines changed

dpnp/dpnp_iface.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,9 @@ def are_same_logical_tensors(ar1, ar2):
138138
139139
Parameters
140140
----------
141-
ar1 : {dpnp_array, usm_ndarray}
141+
ar1 : {dpnp.ndarray, usm_ndarray}
142142
First input array.
143-
ar2 : {dpnp_array, usm_ndarray}
143+
ar2 : {dpnp.ndarray, usm_ndarray}
144144
Second input array.
145145
146146
Returns
@@ -399,7 +399,7 @@ def check_supported_arrays_type(*arrays, scalar_type=False, all_scalars=False):
399399
400400
Parameters
401401
----------
402-
arrays : {dpnp_array, usm_ndarray}
402+
arrays : {dpnp.ndarray, usm_ndarray}
403403
Input arrays to check for supported types.
404404
scalar_type : {bool}, optional
405405
A scalar type is also considered as supported if flag is ``True``.
@@ -656,7 +656,7 @@ def get_result_array(a, out=None, casting="safe"):
656656
----------
657657
a : {dpnp_array}
658658
Input array.
659-
out : {dpnp_array, usm_ndarray}
659+
out : {dpnp.ndarray, usm_ndarray}
660660
If provided, value of `a` array will be copied into it
661661
according to ``safe`` casting rule.
662662
It should be of the appropriate shape.
@@ -694,7 +694,7 @@ def get_usm_ndarray(a):
694694
695695
Parameters
696696
----------
697-
a : {dpnp_array, usm_ndarray}
697+
a : {dpnp.ndarray, usm_ndarray}
698698
Input array of supported type :class:`dpnp.ndarray`
699699
or :class:`dpctl.tensor.usm_ndarray`.
700700
@@ -774,7 +774,7 @@ def is_supported_array_type(a):
774774
775775
Parameters
776776
----------
777-
a : {dpnp_array, usm_ndarray}
777+
a : {dpnp.ndarray, usm_ndarray}
778778
An input array to check the type.
779779
780780
Returns

dpnp/dpnp_iface_indexing.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
"diagonal",
6565
"extract",
6666
"fill_diagonal",
67+
"flatnonzero",
6768
"indices",
6869
"mask_indices",
6970
"nonzero",
@@ -509,7 +510,7 @@ def extract(condition, a):
509510
condition : {array_like, scalar}
510511
An array whose non-zero or ``True`` entries indicate the element of `a`
511512
to extract.
512-
a : {dpnp_array, usm_ndarray}
513+
a : {dpnp.ndarray, usm_ndarray}
513514
Input array of the same size as `condition`.
514515
515516
Returns
@@ -585,7 +586,7 @@ def fill_diagonal(a, val, wrap=False):
585586
586587
Parameters
587588
----------
588-
a : {dpnp_array, usm_ndarray}
589+
a : {dpnp.ndarray, usm_ndarray}
589590
Array whose diagonal is to be filled in-place. It must be at least 2-D.
590591
val : {dpnp.ndarray, usm_ndarray, scalar}
591592
Value(s) to write on the diagonal. If `val` is scalar, the value is
@@ -716,6 +717,52 @@ def fill_diagonal(a, val, wrap=False):
716717
usm_a[:] = tmp_a
717718

718719

720+
def flatnonzero(a):
721+
"""
722+
Return indices that are non-zero in the flattened version of `a`.
723+
724+
This is equivalent to ``dpnp.nonzero(dpnp.ravel(a))[0]``.
725+
726+
For full documentation refer to :obj:`numpy.flatnonzero`.
727+
728+
Parameters
729+
----------
730+
a : {dpnp.ndarray, usm_ndarray}
731+
Input data.
732+
733+
Returns
734+
-------
735+
out : dpnp.ndarray
736+
Output array, containing the indices of the elements of ``a.ravel()``
737+
that are non-zero.
738+
739+
See Also
740+
--------
741+
:obj:`dpnp.nonzero` : Return the indices of the non-zero elements of
742+
the input array.
743+
:obj:`dpnp.ravel` : Return a 1-D array containing the elements of
744+
the input array.
745+
746+
Examples
747+
--------
748+
>>> import dpnp as np
749+
>>> x = np.arange(-2, 3)
750+
>>> x
751+
array([-2, -1, 0, 1, 2])
752+
>>> np.flatnonzero(x)
753+
array([0, 1, 3, 4])
754+
755+
Use the indices of the non-zero elements as an index array to extract
756+
these elements:
757+
758+
>>> x.ravel()[np.flatnonzero(x)]
759+
array([-2, -1, 1, 2])
760+
761+
"""
762+
763+
return dpnp.nonzero(dpnp.ravel(a))[0]
764+
765+
719766
def indices(
720767
dimensions,
721768
dtype=int,

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,7 @@ def _calculate_determinant_sign(ipiv, diag, res_type, n):
608608
609609
Returns
610610
-------
611-
sign : {dpnp_array, usm_ndarray}
611+
sign : {dpnp.ndarray, usm_ndarray}
612612
The sign of the determinant.
613613
614614
"""

tests/test_sycl_queue.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,7 @@ def test_meshgrid(device):
422422
pytest.param("exp2", [0.0, 1.0, 2.0]),
423423
pytest.param("expm1", [1.0e-10, 1.0, 2.0, 4.0, 7.0]),
424424
pytest.param("fabs", [-1.2, 1.2]),
425+
pytest.param("flatnonzero", [-2, -1, 0, 1, 2]),
425426
pytest.param("floor", [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]),
426427
pytest.param("gradient", [1.0, 2.0, 4.0, 7.0, 11.0, 16.0]),
427428
pytest.param("histogram_bin_edges", [0, 0, 0, 1, 2, 3, 3, 4, 5]),

tests/test_usm_type.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,7 @@ def test_norm(usm_type, ord, axis):
553553
pytest.param("exp2", [0.0, 1.0, 2.0]),
554554
pytest.param("expm1", [1.0e-10, 1.0, 2.0, 4.0, 7.0]),
555555
pytest.param("fabs", [-1.2, 1.2]),
556+
pytest.param("flatnonzero", [-2, -1, 0, 1, 2]),
556557
pytest.param("floor", [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]),
557558
pytest.param("gradient", [1, 2, 4, 7, 11, 16]),
558559
pytest.param("histogram_bin_edges", [0, 0, 0, 1, 2, 3, 3, 4, 5]),

tests/third_party/cupy/sorting_tests/test_search.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,6 @@ def test_nonzero(self, xp, dtype):
383383
{"array": numpy.empty((0, 2, 0))},
384384
_ids=False, # Do not generate ids from randomly generated params
385385
)
386-
@pytest.mark.skip("flatnonzero isn't implemented yet")
387386
class TestFlatNonzero:
388387
@testing.for_all_dtypes()
389388
@testing.numpy_cupy_array_equal()

0 commit comments

Comments
 (0)