-
Notifications
You must be signed in to change notification settings - Fork 11
ENH: pad
: add delegation
#72
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 14 commits
d17fd2f
38690bb
d4d05b0
ea09206
9dcb9e5
5db1a93
486ebef
44ec95a
303c7fc
6f72daf
a54357b
71edc05
fd6b9d8
1e59fbd
e0046c5
2bd8205
1531841
64db422
ed7fd25
93f2591
47e9b5b
4c786a2
9daa33a
3ac8e45
82e5258
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
expand_dims | ||
kron | ||
nunique | ||
pad | ||
setdiff1d | ||
sinc | ||
``` |
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
"""Delegation to existing implementations for Public API Functions.""" | ||
|
||
import functools | ||
from enum import Enum | ||
from types import ModuleType | ||
from typing import final | ||
|
||
from ._lib import _funcs | ||
from ._lib._utils._compat import ( | ||
array_namespace, | ||
is_cupy_namespace, | ||
is_jax_namespace, | ||
is_numpy_namespace, | ||
is_torch_namespace, | ||
) | ||
from ._lib._utils._typing import Array | ||
|
||
__all__ = ["pad"] | ||
|
||
|
||
@final | ||
class IsNamespace(Enum): | ||
"""Enum to access is_namespace functions as the backend.""" | ||
|
||
# TODO: when Python 3.10 is dropped, use `enum.member` | ||
# https://stackoverflow.com/a/74302109 | ||
CUPY = functools.partial(is_cupy_namespace) | ||
JAX = functools.partial(is_jax_namespace) | ||
NUMPY = functools.partial(is_numpy_namespace) | ||
TORCH = functools.partial(is_torch_namespace) | ||
|
||
def __call__(self, xp: ModuleType) -> bool: | ||
""" | ||
Call the is_namespace function. | ||
|
||
Parameters | ||
---------- | ||
xp : array_namespace | ||
Array namespace to check. | ||
|
||
Returns | ||
------- | ||
bool | ||
``True`` if xp matches the namespace, ``False`` otherwise. | ||
""" | ||
return self.value(xp) | ||
|
||
|
||
CUPY = IsNamespace.CUPY | ||
JAX = IsNamespace.JAX | ||
NUMPY = IsNamespace.NUMPY | ||
TORCH = IsNamespace.TORCH | ||
|
||
|
||
def _delegate(xp: ModuleType, *backends: IsNamespace) -> bool: | ||
""" | ||
Check whether `xp` is one of the `backends` to delegate to. | ||
|
||
Parameters | ||
---------- | ||
xp : array_namespace | ||
Array namespace to check. | ||
*backends : IsNamespace | ||
Arbitrarily many backends (from the ``IsNamespace`` enum) to check. | ||
|
||
Returns | ||
------- | ||
bool | ||
``True`` if `xp` matches one of the `backends`, ``False`` otherwise. | ||
""" | ||
return any(is_namespace(xp) for is_namespace in backends) | ||
|
||
|
||
def pad( | ||
x: Array, | ||
pad_width: int | tuple[int, int] | list[tuple[int, int]], | ||
mode: str = "constant", | ||
lucascolley marked this conversation as resolved.
Show resolved
Hide resolved
|
||
*, | ||
constant_values: bool | int | float | complex = 0, | ||
xp: ModuleType | None = None, | ||
) -> Array: | ||
""" | ||
Pad the input array. | ||
|
||
Parameters | ||
---------- | ||
x : array | ||
Input array. | ||
pad_width : int or tuple of ints or list of pairs of ints | ||
Pad the input array with this many elements from each side. | ||
If a list of tuples, ``[(before_0, after_0), ... (before_N, after_N)]``, | ||
each pair applies to the corresponding axis of ``x``. | ||
A single tuple, ``(before, after)``, is equivalent to a list of ``x.ndim`` | ||
copies of this tuple. | ||
mode : str, optional | ||
Only "constant" mode is currently supported, which pads with | ||
the value passed to `constant_values`. | ||
constant_values : python scalar, optional | ||
Use this value to pad the input. Default is zero. | ||
xp : array_namespace, optional | ||
The standard-compatible namespace for `x`. Default: infer. | ||
|
||
Returns | ||
------- | ||
array | ||
The input array, | ||
padded with ``pad_width`` elements equal to ``constant_values``. | ||
""" | ||
xp = array_namespace(x) if xp is None else xp | ||
|
||
if mode != "constant": | ||
msg = "Only `'constant'` mode is currently supported" | ||
raise NotImplementedError(msg) | ||
|
||
# https://github.com/pytorch/pytorch/blob/cf76c05b4dc629ac989d1fb8e789d4fac04a095a/torch/_numpy/_funcs_impl.py#L2045-L2056 | ||
if _delegate(xp, TORCH): | ||
pad_width = xp.asarray(pad_width) | ||
pad_width = xp.broadcast_to(pad_width, (x.ndim, 2)) | ||
pad_width = xp.flip(pad_width, axis=(0,)).flatten() | ||
return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] | ||
|
||
if _delegate(xp, NUMPY, JAX, CUPY): | ||
return xp.pad(x, pad_width, mode, constant_values=constant_values) | ||
|
||
return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
"""Modules housing private functions.""" | ||
"""Internals of array-api-extra.""" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
"""Public API Functions.""" | ||
"""Array-agnostic implementations for the public API.""" | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Except.... it isn't agnostic, see for example the special paths in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, I would like to split the file structure so that functions which make use of special paths are separate from array-agnostic implementations. I'll save that for a follow-up. |
||
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972 | ||
from __future__ import annotations | ||
|
@@ -11,13 +11,9 @@ | |
from types import ModuleType | ||
from typing import ClassVar, cast | ||
|
||
from ._lib import _compat, _utils | ||
from ._lib._compat import ( | ||
array_namespace, | ||
is_jax_array, | ||
is_writeable_array, | ||
) | ||
from ._lib._typing import Array, Index | ||
from ._utils import _compat, _helpers | ||
from ._utils._compat import array_namespace, is_jax_array, is_writeable_array | ||
from ._utils._typing import Array, Index | ||
|
||
__all__ = [ | ||
"at", | ||
|
@@ -151,7 +147,7 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: | |
m = atleast_nd(m, ndim=2, xp=xp) | ||
m = xp.astype(m, dtype) | ||
|
||
avg = _utils.mean(m, axis=1, xp=xp) | ||
avg = _helpers.mean(m, axis=1, xp=xp) | ||
fact = m.shape[1] - 1 | ||
|
||
if fact <= 0: | ||
|
@@ -467,7 +463,7 @@ def setdiff1d( | |
else: | ||
x1 = xp.unique_values(x1) | ||
x2 = xp.unique_values(x2) | ||
return x1[_utils.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)] | ||
return x1[_helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)] | ||
|
||
|
||
def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: | ||
|
@@ -562,54 +558,18 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: | |
def pad( | ||
x: Array, | ||
pad_width: int | tuple[int, int] | list[tuple[int, int]], | ||
mode: str = "constant", | ||
*, | ||
xp: ModuleType | None = None, | ||
constant_values: bool | int | float | complex = 0, | ||
) -> Array: | ||
""" | ||
Pad the input array. | ||
|
||
Parameters | ||
---------- | ||
x : array | ||
Input array. | ||
pad_width : int or tuple of ints or list of pairs of ints | ||
Pad the input array with this many elements from each side. | ||
If a list of tuples, ``[(before_0, after_0), ... (before_N, after_N)]``, | ||
each pair applies to the corresponding axis of ``x``. | ||
A single tuple, ``(before, after)``, is equivalent to a list of ``x.ndim`` | ||
copies of this tuple. | ||
mode : str, optional | ||
Only "constant" mode is currently supported, which pads with | ||
the value passed to `constant_values`. | ||
xp : array_namespace, optional | ||
The standard-compatible namespace for `x`. Default: infer. | ||
constant_values : python scalar, optional | ||
Use this value to pad the input. Default is zero. | ||
|
||
Returns | ||
------- | ||
array | ||
The input array, | ||
padded with ``pad_width`` elements equal to ``constant_values``. | ||
""" | ||
if mode != "constant": | ||
msg = "Only `'constant'` mode is currently supported" | ||
raise NotImplementedError(msg) | ||
|
||
value = constant_values | ||
|
||
xp: ModuleType, | ||
) -> Array: # numpydoc ignore=PR01,RT01 | ||
lucascolley marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""See docstring in `array_api_extra._delegation.py`.""" | ||
# make pad_width a list of length-2 tuples of ints | ||
x_ndim = cast(int, x.ndim) | ||
if isinstance(pad_width, int): | ||
pad_width = [(pad_width, pad_width)] * x_ndim | ||
if isinstance(pad_width, tuple): | ||
pad_width = [pad_width] * x_ndim | ||
|
||
if xp is None: | ||
xp = array_namespace(x) | ||
|
||
# https://github.com/python/typeshed/issues/13376 | ||
slices: list[slice] = [] # type: ignore[no-any-explicit] | ||
newshape: list[int] = [] | ||
|
@@ -633,7 +593,7 @@ def pad( | |
|
||
padded = xp.full( | ||
tuple(newshape), | ||
fill_value=value, | ||
fill_value=constant_values, | ||
dtype=x.dtype, | ||
device=_compat.device(x), | ||
) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Modules housing private utility functions.""" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,10 @@ | ||
"""Static typing helpers.""" | ||
|
||
from types import ModuleType | ||
from typing import Any | ||
|
||
# To be changed to a Protocol later (see data-apis/array-api#589) | ||
Array = Any # type: ignore[no-any-explicit] | ||
Device = Any # type: ignore[no-any-explicit] | ||
Index = Any # type: ignore[no-any-explicit] | ||
|
||
__all__ = ["Array", "Device", "Index", "ModuleType"] | ||
__all__ = ["Array", "Device", "Index"] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I liked this before. If anything, we should rename it to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah, this should be fixed eventually by array-api-typing. In the meantime, feel free to submit a PR changing use of |
Uh oh!
There was an error while loading. Please reload this page.