Skip to content

Commit 241f566

Browse files
authored
Merge pull request #3 from lucascolley/atleast_nd
2 parents 7415c53 + 89b9c8c commit 241f566

File tree

12 files changed

+2045
-249
lines changed

12 files changed

+2045
-249
lines changed

.github/workflows/ci.yml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,10 @@ jobs:
3434
with:
3535
pixi-version: v0.30.0
3636
cache: true
37-
- name: Run Pylint
38-
run: pixi run -e lint pylint
37+
- name: Run Pylint & Mypy
38+
run: |
39+
pixi run -e lint pylint
40+
pixi run -e lint mypy
3941
4042
checks:
4143
name: Check ${{ matrix.environment }}

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ instance/
7070

7171
# Sphinx documentation
7272
docs/_build/
73+
docs/generated/
7374

7475
# PyBuilder
7576
.pybuilder/

.pre-commit-config.yaml

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,6 @@ repos:
4848
args: ["--fix", "--show-fixes"]
4949
- id: ruff-format
5050

51-
- repo: https://github.com/pre-commit/mirrors-mypy
52-
rev: "v1.11.1"
53-
hooks:
54-
- id: mypy
55-
files: src|tests
56-
args: []
57-
additional_dependencies:
58-
- pytest
59-
6051
- repo: https://github.com/codespell-project/codespell
6152
rev: "v2.3.0"
6253
hooks:

docs/api-reference.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# API Reference
2+
3+
```{eval-rst}
4+
.. currentmodule:: array_api_extra
5+
.. autosummary::
6+
:nosignatures:
7+
:toctree: generated
8+
9+
atleast_nd
10+
```

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
extensions = [
1010
"myst_parser",
1111
"sphinx.ext.autodoc",
12+
"sphinx.ext.autosummary",
1213
"sphinx.ext.intersphinx",
1314
"sphinx.ext.mathjax",
1415
"sphinx.ext.napoleon",

docs/index.md

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,9 @@
33
```{toctree}
44
:maxdepth: 2
55
:hidden:
6-
6+
api-reference.md
77
```
88

99
```{include} ../README.md
1010
:start-after: <!-- SPHINX-START -->
1111
```
12-
13-
## Indices and tables
14-
15-
- {ref}`genindex`
16-
- {ref}`modindex`
17-
- {ref}`search`

pixi.lock

Lines changed: 1876 additions & 214 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,8 @@ dependencies = []
3232
test = [
3333
"pytest >=6",
3434
"pytest-cov >=3",
35-
]
36-
dev = [
37-
"pytest >=6",
38-
"pytest-cov >=3",
39-
"pylint",
35+
"array-api-strict",
36+
"numpy",
4037
]
4138
docs = [
4239
"sphinx>=7.0",
@@ -68,29 +65,30 @@ platforms = ["linux-64", "osx-arm64", "win-64"]
6865
[tool.pixi.pypi-dependencies]
6966
array-api-extra = { path = ".", editable = true }
7067

71-
[tool.pixi.tasks]
72-
pre-commit = { cmd = "pre-commit install && pre-commit run -v --all-files --show-diff-on-failure" }
73-
7468
[tool.pixi.feature.lint.dependencies]
69+
pre-commit = "*"
70+
mypy = "*"
7571
pylint = "*"
72+
# import dependencies for mypy:
73+
array-api-strict = "*"
74+
numpy = "*"
7675

7776
[tool.pixi.feature.lint.tasks]
77+
pre-commit = { cmd = "pre-commit install && pre-commit run -v --all-files --show-diff-on-failure" }
78+
mypy = { cmd = "mypy", cwd = "." }
7879
pylint = { cmd = ["pylint", "array_api_extra"], cwd = "src" }
79-
lint = { depends-on = ["pre-commit", "pylint"] }
80+
lint = { depends-on = ["pre-commit", "pylint", "mypy"] }
8081

8182
[tool.pixi.feature.test.dependencies]
8283
pytest = ">=6"
8384
pytest-cov = ">=3"
85+
array-api-strict = "*"
86+
numpy = "*"
8487

8588
[tool.pixi.feature.test.tasks]
8689
test = { cmd = "pytest" }
8790
test-ci = { cmd = "pytest -ra --cov --cov-report=xml --cov-report=term --durations=20" }
8891

89-
[tool.pixi.feature.dev.dependencies]
90-
pytest = ">=6"
91-
pytest-cov = ">=3"
92-
pylint = "*"
93-
9492
[tool.pixi.feature.docs.dependencies]
9593
sphinx = ">=7.0"
9694
furo = ">=2023.08.17"
@@ -100,6 +98,15 @@ myst_parser = ">=0.13"
10098
sphinx_copybutton = "*"
10199
sphinx_autodoc_typehints = "*"
102100

101+
[tool.pixi.feature.docs.tasks]
102+
docs = { cmd = ["sphinx-build", ".", "build/"], cwd = "docs" }
103+
104+
[tool.pixi.feature.dev.dependencies]
105+
ipython = "*"
106+
107+
[tool.pixi.feature.dev.tasks]
108+
ipython = { cmd = "ipython" }
109+
103110
[tool.pixi.feature.py309.dependencies]
104111
python = "~=3.9.0"
105112

@@ -109,9 +116,9 @@ python = "~=3.12.0"
109116
[tool.pixi.environments]
110117
default = { solve-group = "default" }
111118
lint = { features = ["lint"], solve-group = "default" }
112-
docs = { features = ["docs"], solve-group = "default" }
113119
test = { features = ["test"], solve-group = "default" }
114-
dev = { features = ["dev", "docs"], solve-group = "default" }
120+
docs = { features = ["docs"], solve-group = "default" }
121+
dev = { features = ["lint", "test", "docs", "dev"], solve-group = "default" }
115122
ci-py309 = ["py309", "test"]
116123
ci-py312 = ["py312", "test"]
117124

src/array_api_extra/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from ._funcs import atleast_nd
4+
35
__version__ = "0.1.dev0"
46

5-
__all__ = ["__version__"]
7+
__all__ = ["__version__", "atleast_nd"]

src/array_api_extra/_funcs.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
if TYPE_CHECKING:
6+
from ._typing import Array, ModuleType
7+
8+
__all__ = ["atleast_nd"]
9+
10+
11+
def atleast_nd(x: Array, *, ndim: int, xp: ModuleType) -> Array:
12+
"""
13+
Recursively expand the dimension of an array to at least `ndim`.
14+
15+
Parameters
16+
----------
17+
x : array
18+
ndim : int
19+
The minimum number of dimensions for the result.
20+
xp : array_namespace
21+
The standard-compatible namespace for `x`.
22+
23+
Returns
24+
-------
25+
res : array
26+
An array with ``res.ndim`` >= `ndim`.
27+
If ``x.ndim`` >= `ndim`, `x` is returned.
28+
If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes
29+
until ``res.ndim`` equals `ndim`.
30+
31+
Examples
32+
--------
33+
>>> import array_api_strict as xp
34+
>>> import array_api_extra as xpx
35+
>>> x = xp.asarray([1])
36+
>>> xpx.atleast_nd(x, ndim=3, xp=xp)
37+
Array([[[1]]], dtype=array_api_strict.int64)
38+
39+
>>> x = xp.asarray([[[1, 2],
40+
... [3, 4]]])
41+
>>> xpx.atleast_nd(x, ndim=1, xp=xp) is x
42+
True
43+
44+
"""
45+
if x.ndim < ndim:
46+
x = xp.expand_dims(x, axis=0)
47+
x = atleast_nd(x, ndim=ndim, xp=xp)
48+
return x

src/array_api_extra/_typing.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from __future__ import annotations
2+
3+
from types import ModuleType
4+
from typing import TYPE_CHECKING, Any
5+
6+
if TYPE_CHECKING:
7+
Array = Any # To be changed to a Protocol later (see array-api#589)
8+
9+
__all__ = ["Array", "ModuleType"]

tests/test_funcs.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from __future__ import annotations
2+
3+
# array-api-strict#6
4+
import array_api_strict as xp # type: ignore[import-untyped]
5+
from numpy.testing import assert_array_equal
6+
7+
from array_api_extra import atleast_nd
8+
9+
10+
class TestAtLeastND:
11+
def test_0D(self):
12+
x = xp.asarray(1)
13+
14+
y = atleast_nd(x, ndim=0, xp=xp)
15+
assert_array_equal(y, x)
16+
17+
y = atleast_nd(x, ndim=1, xp=xp)
18+
assert_array_equal(y, xp.ones((1,)))
19+
20+
y = atleast_nd(x, ndim=5, xp=xp)
21+
assert_array_equal(y, xp.ones((1, 1, 1, 1, 1)))
22+
23+
def test_1D(self):
24+
x = xp.asarray([0, 1])
25+
26+
y = atleast_nd(x, ndim=0, xp=xp)
27+
assert_array_equal(y, x)
28+
29+
y = atleast_nd(x, ndim=1, xp=xp)
30+
assert_array_equal(y, x)
31+
32+
y = atleast_nd(x, ndim=2, xp=xp)
33+
assert_array_equal(y, xp.asarray([[0, 1]]))
34+
35+
y = atleast_nd(x, ndim=5, xp=xp)
36+
assert_array_equal(y, xp.reshape(xp.arange(2), (1, 1, 1, 1, 2)))
37+
38+
def test_2D(self):
39+
x = xp.asarray([[3]])
40+
41+
y = atleast_nd(x, ndim=0, xp=xp)
42+
assert_array_equal(y, x)
43+
44+
y = atleast_nd(x, ndim=2, xp=xp)
45+
assert_array_equal(y, x)
46+
47+
y = atleast_nd(x, ndim=3, xp=xp)
48+
assert_array_equal(y, 3 * xp.ones((1, 1, 1)))
49+
50+
y = atleast_nd(x, ndim=5, xp=xp)
51+
assert_array_equal(y, 3 * xp.ones((1, 1, 1, 1, 1)))
52+
53+
def test_5D(self):
54+
x = xp.ones((1, 1, 1, 1, 1))
55+
56+
y = atleast_nd(x, ndim=0, xp=xp)
57+
assert_array_equal(y, x)
58+
59+
y = atleast_nd(x, ndim=4, xp=xp)
60+
assert_array_equal(y, x)
61+
62+
y = atleast_nd(x, ndim=5, xp=xp)
63+
assert_array_equal(y, x)
64+
65+
y = atleast_nd(x, ndim=6, xp=xp)
66+
assert_array_equal(y, xp.ones((1, 1, 1, 1, 1, 1)))
67+
68+
y = atleast_nd(x, ndim=9, xp=xp)
69+
assert_array_equal(y, xp.ones((1, 1, 1, 1, 1, 1, 1, 1, 1)))

0 commit comments

Comments
 (0)