diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index a50350cc..af5510d7 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -681,7 +681,11 @@ def isdtype( else: return dtype == kind -def take(x: array, indices: array, /, *, axis: int, **kwargs) -> array: +def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -> array: + if axis is None: + if x.ndim != 1: + raise ValueError("axis must be specified when ndim > 1") + axis = 0 return torch.index_select(x, axis, indices, **kwargs) __all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'add',