diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 027a0261..335008e4 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -548,8 +548,12 @@ def count_nonzero( ) -> Array: result = torch.count_nonzero(x, dim=axis) if keepdims: - if axis is not None: + if isinstance(axis, int): return result.unsqueeze(axis) + elif isinstance(axis, tuple): + n_axis = [x.ndim + ax if ax < 0 else ax for ax in axis] + sh = [1 if i in n_axis else x.shape[i] for i in range(x.ndim)] + return torch.reshape(result, sh) return _axis_none_keepdims(result, x.ndim, keepdims) else: return result diff --git a/numpy-1-22-xfails.txt b/numpy-1-22-xfails.txt index c1de77d8..cacb95b7 100644 --- a/numpy-1-22-xfails.txt +++ b/numpy-1-22-xfails.txt @@ -127,6 +127,13 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_sc array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[divide] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[hypot] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[subtract] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[nextafter] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[multiply] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow] array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars