Skip to content

Commit 96bfbbf

Browse files
committed
ENH: add pad
1 parent 6df1916 commit 96bfbbf

File tree

4 files changed

+81
-2
lines changed

4 files changed

+81
-2
lines changed

src/array_api_extra/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Extra array functions built on top of the array API standard."""
22

3-
from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc
3+
from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc, pad
44

55
__version__ = "0.4.1.dev0"
66

@@ -14,4 +14,5 @@
1414
"kron",
1515
"setdiff1d",
1616
"sinc",
17+
"pad",
1718
]

src/array_api_extra/_funcs.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import warnings
44

55
from ._lib import _compat, _utils
6-
from ._lib._compat import array_namespace
6+
from ._lib._compat import (
7+
array_namespace, is_torch_namespace, is_array_api_strict_namespace
8+
)
79
from ._lib._typing import Array, ModuleType
810

911
__all__ = [
@@ -14,6 +16,7 @@
1416
"kron",
1517
"setdiff1d",
1618
"sinc",
19+
"pad",
1720
]
1821

1922

@@ -538,3 +541,54 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
538541
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)),
539542
)
540543
return xp.sin(y) / y
544+
545+
546+
def pad(x: Array, pad_width: int, mode: str = 'constant', *, xp: ModuleType = None, **kwargs):
547+
"""
548+
Pad the input array.
549+
550+
Parameters
551+
----------
552+
x : array
553+
Input array
554+
pad_width: int
555+
Pad the input array with this many elements from each side
556+
mode: str, optional
557+
Only "constant" mode is currently supported.
558+
xp : array_namespace, optional
559+
The standard-compatible namespace for `x`. Default: infer.
560+
constant_values: python scalar, optional
561+
Use this value to pad the input. Default is zero.
562+
563+
Returns
564+
-------
565+
array
566+
The input array, padded with ``pad_width`` elements equal to ``constant_values``
567+
"""
568+
# xp.pad is available on numpy, cupy and jax.numpy; on torch, reuse
569+
# http://github.com/pytorch/pytorch/blob/main/torch/_numpy/_funcs_impl.py#L2045
570+
571+
if mode != 'constant':
572+
raise NotImplementedError()
573+
574+
value = kwargs.get("constant_values", 0)
575+
if kwargs and list(kwargs.keys()) != ['constant_values']:
576+
raise ValueError(f"Unknown kwargs: {kwargs}")
577+
578+
if xp is None:
579+
xp = array_namespace(x)
580+
581+
if is_array_api_strict_namespace(xp):
582+
padded = xp.full(
583+
tuple(x + 2*pad_width for x in x.shape), fill_value=value, dtype=x.dtype
584+
)
585+
padded[(slice(pad_width, -pad_width, None),)*x.ndim] = x
586+
return padded
587+
elif is_torch_namespace(xp):
588+
pad_width = xp.asarray(pad_width)
589+
pad_width = xp.broadcast_to(pad_width, (x.ndim, 2))
590+
pad_width = xp.flip(pad_width, axis=(0,)).flatten()
591+
return xp.nn.functional.pad(x, tuple(pad_width), value=value)
592+
593+
else:
594+
return xp.pad(x, pad_width, mode=mode, **kwargs)

src/array_api_extra/_lib/_compat.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs]
1212
array_namespace, # pyright: ignore[reportUnknownVariableType]
1313
device,
14+
is_torch_namespace,
15+
is_array_api_strict_namespace,
1416
)
1517

1618
__all__ = [

tests/test_funcs.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
kron,
1616
setdiff1d,
1717
sinc,
18+
pad,
1819
)
1920
from array_api_extra._lib._typing import Array
2021

@@ -385,3 +386,24 @@ def test_device(self):
385386

386387
def test_xp(self):
387388
assert_array_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0))
389+
390+
391+
class TestPad:
392+
def test_simple(self):
393+
a = xp.arange(1, 4)
394+
padded = pad(a, 2)
395+
assert xp.all(padded == xp.asarray([0, 0, 1, 2, 3, 0, 0]))
396+
397+
def test_fill_value(self):
398+
a = xp.arange(1, 4)
399+
padded = pad(a, 2, constant_values=42)
400+
assert xp.all(padded == xp.asarray([42, 42, 1, 2, 3, 42, 42]))
401+
402+
def test_ndim(self):
403+
a = xp.reshape(xp.arange(2*3*4), (2, 3, 4))
404+
padded = pad(a, 2)
405+
assert padded.shape == (6, 7, 8)
406+
407+
def test_typo(self):
408+
with pytest.raises(ValueError, match="Unknown"):
409+
pad(xp.arange(2), pad_width=3, oops=3)

0 commit comments

Comments
 (0)