Skip to content

Commit 30d7dec

Browse files
committed
ENH: add atleast_nd
1 parent e765ddf commit 30d7dec

File tree

6 files changed

+848
-3
lines changed

6 files changed

+848
-3
lines changed

pixi.lock

Lines changed: 783 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,20 @@ classifiers = [
2727
"Typing :: Typed",
2828
]
2929
dynamic = ["version"]
30-
dependencies = []
30+
dependencies = [
31+
"array-api-compat",
32+
]
3133

3234
[project.optional-dependencies]
3335
test = [
3436
"pytest >=6",
3537
"pytest-cov >=3",
38+
"array-api-strict",
3639
]
3740
dev = [
3841
"pytest >=6",
3942
"pytest-cov >=3",
43+
"array-api-strict",
4044
"pylint",
4145
]
4246
docs = [
@@ -83,6 +87,7 @@ lint = { depends-on = ["pre-commit", "pylint"] }
8387
[tool.pixi.feature.test.dependencies]
8488
pytest = ">=6"
8589
pytest-cov = ">=3"
90+
array-api-strict = "*"
8691

8792
[tool.pixi.feature.test.tasks]
8893
test = { cmd = "pytest" }

src/array_api_extra/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from ._funcs import atleast_nd
4+
35
__version__ = "0.1.dev0"
46

5-
__all__ = ["__version__"]
7+
__all__ = ["__version__", "atleast_nd"]

src/array_api_extra/_funcs.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
from array_api_compat import array_namespace # type: ignore[import-not-found]
6+
7+
if TYPE_CHECKING:
8+
from ._typing import Array, ModuleType
9+
10+
__all__ = ["atleast_nd"]
11+
12+
13+
def atleast_nd(x: Array, *, ndim: int, xp: ModuleType | None = None) -> Array:
14+
"""
15+
Recursively expand the dimension of an array to have at least `ndim`.
16+
17+
Parameters
18+
----------
19+
x: array
20+
An array.
21+
22+
Returns
23+
-------
24+
res: array
25+
An array with ``res.ndim`` >= `ndim`.
26+
If ``x.ndim`` >= `ndim`, `x` is returned.
27+
If ``x.ndim`` < `ndim`, ``res.ndim`` will equal `ndim`.
28+
"""
29+
xp = array_namespace(x) if xp is None else xp
30+
31+
x = xp.asarray(x)
32+
if x.ndim < ndim:
33+
x = xp.expand_dims(x, axis=0)
34+
x = atleast_nd(x, ndim=ndim, xp=xp)
35+
return x

src/array_api_extra/_typing.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from __future__ import annotations
2+
3+
from types import ModuleType
4+
from typing import TYPE_CHECKING, Any
5+
6+
if TYPE_CHECKING:
7+
Array = Any # To be changed to a Protocol later (see array-api#589)
8+
9+
__all__ = ["Array", "ModuleType"]

tests/test_funcs.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from __future__ import annotations
2+
3+
import array_api_strict as xp # type: ignore[import-not-found]
4+
5+
from array_api_extra import atleast_nd
6+
7+
8+
class TestAtLeastND:
9+
def test_1d_to_2d(self):
10+
x = xp.asarray([0, 1])
11+
y = atleast_nd(x, ndim=2, xp=xp)
12+
assert y.ndim == 2

0 commit comments

Comments
 (0)