From 4fd4a0dbe2338fe366a4f9fb7e57e536b2ce1465 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Sat, 15 Jul 2023 11:16:10 -0500 Subject: [PATCH] Fix the torch.take() wrapper to make axis optional for ndim = 1 Closes #34 --- array_api_compat/torch/_aliases.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index a50350cc..af5510d7 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -681,7 +681,11 @@ def isdtype( else: return dtype == kind -def take(x: array, indices: array, /, *, axis: int, **kwargs) -> array: +def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -> array: + if axis is None: + if x.ndim != 1: + raise ValueError("axis must be specified when ndim > 1") + axis = 0 return torch.index_select(x, axis, indices, **kwargs) __all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'add',