Skip to content

Commit 720832f

Browse files
committed
Replace _array_module.linalg with just linalg in the linalg tests
1 parent 544d8ea commit 720832f

File tree

1 file changed

+37
-36
lines changed

1 file changed

+37
-36
lines changed

array_api_tests/test_linalg.py

Lines changed: 37 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from .test_broadcasting import broadcast_shapes
2929

3030
from . import _array_module
31+
from ._array_module import linalg
3132

3233
# Standin strategy for not yet implemented tests
3334
todo = none()
@@ -75,12 +76,12 @@ def _test_namedtuple(res, fields, func_name):
7576
kw=kwargs(upper=booleans())
7677
)
7778
def test_cholesky(x, kw):
78-
res = _array_module.linalg.cholesky(x, **kw)
79+
res = linalg.cholesky(x, **kw)
7980

8081
assert res.shape == x.shape, "cholesky() did not return the correct shape"
8182
assert res.dtype == x.dtype, "cholesky() did not return the correct dtype"
8283

83-
_test_stacks(_array_module.linalg.cholesky, x, **kw, res=res)
84+
_test_stacks(linalg.cholesky, x, **kw, res=res)
8485

8586
# Test that the result is upper or lower triangular
8687
if kw.get('upper', False):
@@ -129,7 +130,7 @@ def test_cross(x1_x2_kw):
129130
shape = x1.shape
130131
assert x1.shape[axis] == x2.shape[axis] == 3, err
131132

132-
res = _array_module.linalg.cross(x1, x2, **kw)
133+
res = linalg.cross(x1, x2, **kw)
133134

134135
# TODO: Replace result_type() with a helper function
135136
assert res.dtype == _array_module.result_type(x1, x2), "cross() did not return the correct dtype"
@@ -146,7 +147,7 @@ def test_cross(x1_x2_kw):
146147
x1_stack = x1[idx]
147148
x2_stack = x2[idx]
148149
assert x1_stack.shape == x2_stack.shape == (3,), "Invalid cross() stack shapes. This indicates a bug in the test suite."
149-
decomp_res_stack = _array_module.linalg.cross(x1_stack, x2_stack)
150+
decomp_res_stack = linalg.cross(x1_stack, x2_stack)
150151
assert_exactly_equal(res_stack, decomp_res_stack)
151152

152153
exact_cross = asarray([
@@ -160,12 +161,12 @@ def test_cross(x1_x2_kw):
160161
x=xps.arrays(dtype=xps.floating_dtypes(), shape=square_matrix_shapes),
161162
)
162163
def test_det(x):
163-
res = _array_module.linalg.det(x)
164+
res = linalg.det(x)
164165

165166
assert res.dtype == x.dtype, "det() did not return the correct dtype"
166167
assert res.shape == x.shape[:-2], "det() did not return the correct shape"
167168

168-
_test_stacks(_array_module.linalg.det, x, res=res, dims=0)
169+
_test_stacks(linalg.det, x, res=res, dims=0)
169170

170171
# TODO: Test that res actually corresponds to the determinant of x
171172

@@ -176,7 +177,7 @@ def test_det(x):
176177
kw=kwargs(offset=integers(-MAX_ARRAY_SIZE, MAX_ARRAY_SIZE))
177178
)
178179
def test_diagonal(x, kw):
179-
res = _array_module.linalg.diagonal(x, **kw)
180+
res = linalg.diagonal(x, **kw)
180181

181182
assert res.dtype == x.dtype, "diagonal() returned the wrong dtype"
182183

@@ -201,11 +202,11 @@ def true_diag(x_stack):
201202
x_stack_diag = [x_stack[i - offset, i] for i in range(diag_size)]
202203
return asarray(x_stack_diag, dtype=x.dtype)
203204

204-
_test_stacks(_array_module.linalg.diagonal, x, **kw, res=res, dims=1, true_val=true_diag)
205+
_test_stacks(linalg.diagonal, x, **kw, res=res, dims=1, true_val=true_diag)
205206

206207
@given(x=symmetric_matrices(finite=True))
207208
def test_eigh(x):
208-
res = _array_module.linalg.eigh(x)
209+
res = linalg.eigh(x)
209210

210211
_test_namedtuple(res, ['eigenvalues', 'eigenvectors'], 'eigh')
211212

@@ -218,35 +219,35 @@ def test_eigh(x):
218219
assert eigenvectors.dtype == x.dtype, "eigh().eigenvectors did not return the correct dtype"
219220
assert eigenvectors.shape == x.shape, "eigh().eigenvectors did not return the correct shape"
220221

221-
_test_stacks(lambda x: _array_module.linalg.eigh(x).eigenvalues, x,
222+
_test_stacks(lambda x: linalg.eigh(x).eigenvalues, x,
222223
res=eigenvalues, dims=1)
223-
_test_stacks(lambda x: _array_module.linalg.eigh(x).eigenvectors, x,
224+
_test_stacks(lambda x: linalg.eigh(x).eigenvectors, x,
224225
res=eigenvectors, dims=2)
225226

226227
# TODO: Test that res actually corresponds to the eigenvalues and
227228
# eigenvectors of x
228229

229230
@given(x=symmetric_matrices(finite=True))
230231
def test_eigvalsh(x):
231-
res = _array_module.linalg.eigvalsh(x)
232+
res = linalg.eigvalsh(x)
232233

233234
assert res.dtype == x.dtype, "eigvalsh() did not return the correct dtype"
234235
assert res.shape == x.shape[:-1], "eigvalsh() did not return the correct shape"
235236

236-
_test_stacks(_array_module.linalg.eigvalsh, x, res=res, dims=1)
237+
_test_stacks(linalg.eigvalsh, x, res=res, dims=1)
237238

238239
# TODO: Should we test that the result is the same as eigh(x).eigenvalues?
239240

240241
# TODO: Test that res actually corresponds to the eigenvalues of x
241242

242243
@given(x=invertible_matrices())
243244
def test_inv(x):
244-
res = _array_module.linalg.inv(x)
245+
res = linalg.inv(x)
245246

246247
assert res.shape == x.shape, "inv() did not return the correct shape"
247248
assert res.dtype == x.dtype, "inv() did not return the correct dtype"
248249

249-
_test_stacks(_array_module.linalg.inv, x, res=res)
250+
_test_stacks(linalg.inv, x, res=res)
250251

251252
# TODO: Test that the result is actually the inverse
252253

@@ -262,11 +263,11 @@ def test_matmul(x1, x2):
262263
or len(x1.shape) >= 2 and len(x2.shape) >= 2 and x1.shape[-1] != x2.shape[-2]):
263264
# The spec doesn't specify what kind of exception is used here. Most
264265
# libraries will use a custom exception class.
265-
raises(Exception, lambda: _array_module.linalg.matmul(x1, x2),
266+
raises(Exception, lambda: linalg.matmul(x1, x2),
266267
"matmul did not raise an exception for invalid shapes")
267268
return
268269
else:
269-
res = _array_module.linalg.matmul(x1, x2)
270+
res = linalg.matmul(x1, x2)
270271

271272
# TODO: Replace result_type() with a helper function
272273
assert res.dtype == _array_module.result_type(x1, x2), "matmul() did not return the correct dtype"
@@ -275,21 +276,21 @@ def test_matmul(x1, x2):
275276
assert res.shape == ()
276277
elif len(x1.shape) == 1:
277278
assert res.shape == x2.shape[:-2] + x2.shape[-1:]
278-
_test_stacks(_array_module.linalg.matmul, x1, x2, res=res, dims=1)
279+
_test_stacks(linalg.matmul, x1, x2, res=res, dims=1)
279280
elif len(x2.shape) == 1:
280281
assert res.shape == x1.shape[:-1]
281-
_test_stacks(_array_module.linalg.matmul, x1, x2, res=res, dims=1)
282+
_test_stacks(linalg.matmul, x1, x2, res=res, dims=1)
282283
else:
283284
stack_shape = broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
284285
assert res.shape == stack_shape + (x1.shape[-2], x2.shape[-1])
285-
_test_stacks(_array_module.linalg.matmul, x1, x2, res=res)
286+
_test_stacks(linalg.matmul, x1, x2, res=res)
286287

287288
@given(
288289
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
289290
kw=kwargs(axis=todo, keepdims=todo, ord=todo)
290291
)
291292
def test_matrix_norm(x, kw):
292-
# res = _array_module.linalg.matrix_norm(x, **kw)
293+
# res = linalg.matrix_norm(x, **kw)
293294
pass
294295

295296
matrix_power_n = shared(integers(-1000, 1000), key='matrix_power n')
@@ -301,82 +302,82 @@ def test_matrix_norm(x, kw):
301302
n=matrix_power_n,
302303
)
303304
def test_matrix_power(x, n):
304-
res = _array_module.linalg.matrix_power(x, n)
305+
res = linalg.matrix_power(x, n)
305306
if n == 0:
306307
true_val = lambda x: _array_module.eye(x.shape[0], dtype=x.dtype)
307308
else:
308309
true_val = None
309310
# _test_stacks only works with array arguments
310-
func = lambda x: _array_module.linalg.matrix_power(x, n)
311+
func = lambda x: linalg.matrix_power(x, n)
311312
_test_stacks(func, x, res=res, true_val=true_val)
312313

313314
@given(
314315
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
315316
kw=kwargs(rtol=todo)
316317
)
317318
def test_matrix_rank(x, kw):
318-
# res = _array_module.linalg.matrix_rank(x, **kw)
319+
# res = linalg.matrix_rank(x, **kw)
319320
pass
320321

321322
@given(
322323
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
323324
)
324325
def test_matrix_transpose(x):
325-
# res = _array_module.linalg.matrix_transpose(x)
326+
# res = linalg.matrix_transpose(x)
326327
pass
327328

328329
@given(
329330
x1=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
330331
x2=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
331332
)
332333
def test_outer(x1, x2):
333-
# res = _array_module.linalg.outer(x1, x2)
334+
# res = linalg.outer(x1, x2)
334335
pass
335336

336337
@given(
337338
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
338339
kw=kwargs(rtol=todo)
339340
)
340341
def test_pinv(x, kw):
341-
# res = _array_module.linalg.pinv(x, **kw)
342+
# res = linalg.pinv(x, **kw)
342343
pass
343344

344345
@given(
345346
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
346347
kw=kwargs(mode=todo)
347348
)
348349
def test_qr(x, kw):
349-
# res = _array_module.linalg.qr(x, **kw)
350+
# res = linalg.qr(x, **kw)
350351
pass
351352

352353
@given(
353354
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
354355
)
355356
def test_slogdet(x):
356-
# res = _array_module.linalg.slogdet(x)
357+
# res = linalg.slogdet(x)
357358
pass
358359

359360
@given(
360361
x1=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
361362
x2=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
362363
)
363364
def test_solve(x1, x2):
364-
# res = _array_module.linalg.solve(x1, x2)
365+
# res = linalg.solve(x1, x2)
365366
pass
366367

367368
@given(
368369
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
369370
kw=kwargs(full_matrices=todo)
370371
)
371372
def test_svd(x, kw):
372-
# res = _array_module.linalg.svd(x, **kw)
373+
# res = linalg.svd(x, **kw)
373374
pass
374375

375376
@given(
376377
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
377378
)
378379
def test_svdvals(x):
379-
# res = _array_module.linalg.svdvals(x)
380+
# res = linalg.svdvals(x)
380381
pass
381382

382383
@given(
@@ -385,15 +386,15 @@ def test_svdvals(x):
385386
kw=kwargs(axes=todo)
386387
)
387388
def test_tensordot(x1, x2, kw):
388-
# res = _array_module.linalg.tensordot(x1, x2, **kw)
389+
# res = linalg.tensordot(x1, x2, **kw)
389390
pass
390391

391392
@given(
392393
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
393394
kw=kwargs(offset=todo)
394395
)
395396
def test_trace(x, kw):
396-
# res = _array_module.linalg.trace(x, **kw)
397+
# res = linalg.trace(x, **kw)
397398
pass
398399

399400
@given(
@@ -402,13 +403,13 @@ def test_trace(x, kw):
402403
kw=kwargs(axis=todo)
403404
)
404405
def test_vecdot(x1, x2, kw):
405-
# res = _array_module.linalg.vecdot(x1, x2, **kw)
406+
# res = linalg.vecdot(x1, x2, **kw)
406407
pass
407408

408409
@given(
409410
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
410411
kw=kwargs(axis=todo, keepdims=todo, ord=todo)
411412
)
412413
def test_vector_norm(x, kw):
413-
# res = _array_module.linalg.vector_norm(x, **kw)
414+
# res = linalg.vector_norm(x, **kw)
414415
pass

0 commit comments

Comments
 (0)