Skip to content

Commit cc6ac97

Browse files
committed
Update generate_stubs.py to generate stubs for the linalg extension
1 parent 7267090 commit cc6ac97

File tree

3 files changed

+120
-5
lines changed

3 files changed

+120
-5
lines changed

array_api_tests/function_stubs/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,5 @@
5656
from .utility_functions import all, any
5757

5858
__all__ += ['all', 'any']
59+
60+
from . import linalg
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
"""
2+
Function stubs for linear algebra functions (Extension).
3+
4+
NOTE: This file is generated automatically by the generate_stubs.py script. Do
5+
not modify it directly.
6+
7+
See
8+
https://github.com/data-apis/array-api/blob/master/spec/API_specification/linear_algebra_functions.md
9+
"""
10+
11+
from __future__ import annotations
12+
13+
14+
def cholesky(x, /, *, upper=False):
15+
pass
16+
17+
def cross(x1, x2, /, *, axis=-1):
18+
pass
19+
20+
def det(x, /):
21+
pass
22+
23+
def diagonal(x, /, *, axis1=0, axis2=1, offset=0):
24+
pass
25+
26+
def eig():
27+
pass
28+
29+
def eigh(x, /, *, upper=False):
30+
pass
31+
32+
def eigvals():
33+
pass
34+
35+
def eigvalsh(x, /, *, upper=False):
36+
pass
37+
38+
def einsum():
39+
pass
40+
41+
def inv(x, /):
42+
pass
43+
44+
def lstsq(x1, x2, /, *, rtol=None):
45+
pass
46+
47+
def matmul(x1, x2, /):
48+
pass
49+
50+
def matrix_power(x, n, /):
51+
pass
52+
53+
def matrix_rank(x, /, *, rtol=None):
54+
pass
55+
56+
def norm(x, /, *, axis=None, keepdims=False, ord=None):
57+
pass
58+
59+
def outer(x1, x2, /):
60+
pass
61+
62+
def pinv(x, /, *, rtol=None):
63+
pass
64+
65+
def qr(x, /, *, mode='reduced'):
66+
pass
67+
68+
def slogdet(x, /):
69+
pass
70+
71+
def solve(x1, x2, /):
72+
pass
73+
74+
def svd(x, /, *, full_matrices=True):
75+
pass
76+
77+
def tensordot(x1, x2, /, *, axes=2):
78+
pass
79+
80+
def svdvals(x, /):
81+
pass
82+
83+
def trace(x, /, *, axis1=0, axis2=1, offset=0):
84+
pass
85+
86+
def transpose(x, /, *, axes=None):
87+
pass
88+
89+
def vecdot(x1, x2, /, *, axis=None):
90+
pass
91+
92+
__all__ = ['cholesky', 'cross', 'det', 'diagonal', 'eig', 'eigh', 'eigvals', 'eigvalsh', 'einsum', 'inv', 'lstsq', 'matmul', 'matrix_power', 'matrix_rank', 'norm', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'tensordot', 'svdvals', 'trace', 'transpose', 'vecdot']

generate_stubs.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,14 @@ def main():
116116
f.write(TYPES_HEADER)
117117

118118
spec_dir = os.path.join(args.array_api_repo, 'spec', 'API_specification')
119+
extensions_dir = os.path.join(args.array_api_repo, 'spec', 'extensions')
120+
files = sorted([os.path.join(spec_dir, f) for f in os.listdir(spec_dir)]
121+
+ [os.path.join(extensions_dir, f) for f in os.listdir(extensions_dir)])
119122
modules = {}
120-
for filename in sorted(os.listdir(spec_dir)):
121-
with open(os.path.join(spec_dir, filename)) as f:
122-
text = f.read()
123+
for file in files:
124+
filename = os.path.basename(file)
125+
with open(file) as f:
126+
text = f.read()
123127
functions = FUNCTION_RE.findall(text)
124128
methods = METHOD_RE.findall(text)
125129
constants = CONSTANT_RE.findall(text)
@@ -130,9 +134,18 @@ def main():
130134
print(f"Found signatures in {filename}")
131135
if not args.write:
132136
continue
133-
py_file = filename.replace('.md', '.py')
134-
py_path = os.path.join('array_api_tests', 'function_stubs', py_file)
137+
135138
title = filename.replace('.md', '').replace('_', ' ')
139+
if 'extensions' in file:
140+
if filename == 'index.md':
141+
continue
142+
elif filename != 'linear_algebra_functions.md':
143+
raise RuntimeError(f"Don't know how to handle extension file {filename}")
144+
py_file = 'linalg.py'
145+
title += " (Extension)"
146+
else:
147+
py_file = filename.replace('.md', '.py')
148+
py_path = os.path.join('array_api_tests', 'function_stubs', py_file)
136149
module_name = py_file.replace('.py', '')
137150
modules[module_name] = []
138151
if not args.quiet:
@@ -153,6 +166,11 @@ def main():
153166
ismethod = sig in methods
154167
sig = sig.replace(r'\_', '_')
155168
func_name = NAME_RE.match(sig).group(1)
169+
if '.' in func_name:
170+
mod, func_name = func_name.split('.', 2)
171+
if mod != 'linalg':
172+
raise RuntimeError(f"Unexpected namespace prefix {mod!r}")
173+
sig = sig.replace(mod + '.', '')
156174
doc = ""
157175
if ismethod:
158176
doc = f'''
@@ -236,6 +254,9 @@ def {annotated_sig}:{doc}
236254
with open(init_path, 'w') as f:
237255
f.write(INIT_HEADER)
238256
for module_name in modules:
257+
if module_name == 'linalg':
258+
f.write(f'\nfrom . import {module_name}\n')
259+
continue
239260
f.write(f"\nfrom .{module_name} import ")
240261
f.write(', '.join(modules[module_name]))
241262
f.write('\n\n')

0 commit comments

Comments
 (0)