Skip to content

Commit 4c063b0

Browse files
committed
Fix the testing of extension module names in test_signatures
Previously names that were both in the extension module and the main namespace were only tested in the extension module both times. Now the two are tested separately. Extension module names are also more explicit in the test parameterizations as a result.
1 parent daafc40 commit 4c063b0

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

array_api_tests/test_signatures.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ def stub_module(name):
2121
def extension_module(name):
2222
return name in submodules and name in function_stubs.__all__
2323

24-
extension_module_names = {}
24+
extension_module_names = []
2525
for n in function_stubs.__all__:
2626
if extension_module(n):
27-
extension_module_names.update({i: n for i in getattr(function_stubs, n).__all__})
27+
extension_module_names.extend([f'{n}.{i}' for i in getattr(function_stubs, n).__all__])
2828

29-
all_names = function_stubs.__all__ + list(extension_module_names)
29+
all_names = function_stubs.__all__ + extension_module_names
3030

3131
def array_method(name):
3232
return stub_module(name) == 'array_object'
@@ -134,8 +134,8 @@ def example_argument(arg, func_name, dtype):
134134
def test_has_names(name):
135135
if extension_module(name):
136136
assert hasattr(mod, name), f'{mod_name} is missing the {name} extension'
137-
elif name in extension_module_names:
138-
extension_mod = extension_module_names[name]
137+
elif '.' in name:
138+
extension_mod, name = name.split('.')
139139
assert hasattr(getattr(mod, extension_mod), name), f"{mod_name} is missing the {function_category(name)} extension function {name}()"
140140
elif array_method(name):
141141
arr = ones((1, 1))
@@ -178,9 +178,10 @@ def test_function_positional_args(name):
178178
else:
179179
_mod = example_argument('self', name, dtype)
180180
stub_func = getattr(function_stubs, name)
181-
elif name in extension_module_names:
182-
_mod = getattr(mod, extension_module_names[name])
183-
stub_func = getattr(getattr(function_stubs, extension_module_names[name]), name)
181+
elif '.' in name:
182+
extension_module_name, name = name.split('.')
183+
_mod = getattr(mod, extension_module_name)
184+
stub_func = getattr(getattr(function_stubs, extension_module_name), name)
184185
else:
185186
_mod = mod
186187
stub_func = getattr(function_stubs, name)
@@ -230,9 +231,10 @@ def test_function_keyword_only_args(name):
230231
if array_method(name):
231232
_mod = ones((1, 1))
232233
stub_func = getattr(function_stubs, name)
233-
elif name in extension_module_names:
234-
_mod = getattr(mod, extension_module_names[name])
235-
stub_func = getattr(getattr(function_stubs, extension_module_names[name]), name)
234+
elif '.' in name:
235+
extension_module_name, name = name.split('.')
236+
_mod = getattr(mod, extension_module_name)
237+
stub_func = getattr(getattr(function_stubs, extension_module_name), name)
236238
else:
237239
_mod = mod
238240
stub_func = getattr(function_stubs, name)

0 commit comments

Comments
 (0)