Skip to content

Commit ff51015

Browse files
committed
Fix tests
1 parent b069230 commit ff51015

File tree

4 files changed

+32
-17
lines changed

4 files changed

+32
-17
lines changed

array_api_compat/_internal.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,29 @@ def wrapped_f(*args, **kwargs):
4444
return inner
4545

4646

47-
def _get_all_public_members(module, filter_=None):
48-
"""Get all public members of a module."""
49-
try:
50-
return getattr(module, '__all__')
51-
except AttributeError:
52-
pass
47+
def _get_all_public_members(module, exclude=None, extend_all=False):
48+
"""Get all public members of a module.
5349
54-
if filter_ is None:
55-
filter_ = lambda name: name.startswith('_') # noqa: E731
50+
Parameters
51+
----------
52+
module : module
53+
The module to get members from.
54+
exclude : callable, optional
55+
A callable that takes a name and returns True if the name should be
56+
excluded from the list of members.
57+
extend_all : bool, optional
58+
If True, extend the module's __all__ attribute with the members of the
59+
module derive from dir(module)
60+
"""
61+
members = getattr(module, '__all__', [])
62+
63+
if members and not extend_all:
64+
return members
65+
66+
if exclude is None:
67+
exclude = lambda name: name.startswith('_') # noqa: E731
68+
69+
members += [_ for _ in dir(module) if not exclude(_)]
5670

57-
return map(filter_, dir(module))
71+
# remove duplicates
72+
return list(set(members))

array_api_compat/torch/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,26 @@
11
# Several names are not included in the above import *
2-
import torch
2+
import torch as _torch
33
from torch import * # noqa: F401, F403
44

55
from .._internal import _get_all_public_members
66

77

8-
def filter_(name):
8+
def exlcude(name):
99
if (
1010
name.startswith("_")
1111
or name.endswith("_")
1212
or "cuda" in name
1313
or "cpu" in name
1414
or "backward" in name
1515
):
16-
return False
17-
return True
16+
return True
17+
return False
1818

1919

20-
_torch_all = _get_all_public_members(torch, filter_=filter_)
20+
_torch_all = _get_all_public_members(_torch, exclude=exlcude, extend_all=True)
2121

2222
for _name in _torch_all:
23-
globals()[_name] = getattr(torch, _name)
23+
globals()[_name] = getattr(_torch, _name)
2424

2525

2626
from ..common._helpers import ( # noqa: E402

tests/test_array_namespace.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,10 @@ def test_array_namespace_errors():
3434

3535
def test_array_namespace_errors_torch():
3636
torch = pytest.importorskip("torch")
37+
np = pytest.importorskip("numpy")
3738

3839
y = torch.asarray([1, 2])
40+
x = np.asarray([1, 2])
3941
pytest.raises(TypeError, lambda: array_namespace(x, y))
4042
pytest.raises(ValueError, lambda: array_namespace(x, api_version="2022.12"))
4143

tests/test_isdtype.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55

66
import pytest
77

8-
from ._helpers import import_
9-
108
# Check the known dtypes by their string names
119

1210
def _spec_dtypes(library):

0 commit comments

Comments
 (0)