Skip to content

Commit a1d92cb

Browse files
committed
Skip alias stubs and fallback on source stubs, e.g. for matmul
1 parent afc3e8c commit a1d92cb

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

array_api_tests/stubs.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,18 +40,29 @@
4040
if name.endswith("_functions"):
4141
category = name.replace("_functions", "")
4242
objects = [getattr(mod, name) for name in mod.__all__]
43-
assert all(isinstance(o, FunctionType) for o in objects)
43+
assert all(isinstance(o, FunctionType) for o in objects) # sanity check
4444
category_to_funcs[category] = objects
4545

46+
all_funcs = []
47+
for funcs in [array_methods, *category_to_funcs.values()]:
48+
all_funcs.extend(funcs)
49+
name_to_func: Dict[str, FunctionType] = {f.__name__: f for f in all_funcs}
50+
4651
EXTENSIONS: str = ["linalg"]
4752
extension_to_funcs: Dict[str, List[FunctionType]] = {}
4853
for ext in EXTENSIONS:
4954
mod = name_to_mod[ext]
5055
objects = [getattr(mod, name) for name in mod.__all__]
51-
assert all(isinstance(o, FunctionType) for o in objects)
52-
extension_to_funcs[ext] = objects
56+
assert all(isinstance(o, FunctionType) for o in objects) # sanity check
57+
funcs = []
58+
for func in objects:
59+
if "Alias" in func.__doc__:
60+
funcs.append(name_to_func[func.__name__])
61+
else:
62+
funcs.append(func)
63+
extension_to_funcs[ext] = funcs
5364

54-
all_funcs = []
55-
for funcs in [array_methods, *category_to_funcs.values(), *extension_to_funcs.values()]:
56-
all_funcs.extend(funcs)
57-
name_to_func: Dict[str, FunctionType] = {f.__name__: f for f in all_funcs}
65+
for funcs in extension_to_funcs.values():
66+
for func in funcs:
67+
if func.__name__ not in name_to_func.keys():
68+
name_to_func[func.__name__] = func

0 commit comments

Comments
 (0)