Skip to content

Commit 8fc3e5f

Browse files
committed
Don't use "linalg" for functions that are in the main namespace
1 parent 628c5ed commit 8fc3e5f

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

array_api_tests/test_linalg.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -265,26 +265,26 @@ def test_matmul(x1, x2):
265265
or len(x1.shape) >= 2 and len(x2.shape) >= 2 and x1.shape[-1] != x2.shape[-2]):
266266
# The spec doesn't specify what kind of exception is used here. Most
267267
# libraries will use a custom exception class.
268-
raises(Exception, lambda: linalg.matmul(x1, x2),
268+
raises(Exception, lambda: _array_module.matmul(x1, x2),
269269
"matmul did not raise an exception for invalid shapes")
270270
return
271271
else:
272-
res = linalg.matmul(x1, x2)
272+
res = _array_module.matmul(x1, x2)
273273

274274
assert res.dtype == promote_dtypes(x1, x2), "matmul() did not return the correct dtype"
275275

276276
if len(x1.shape) == len(x2.shape) == 1:
277277
assert res.shape == ()
278278
elif len(x1.shape) == 1:
279279
assert res.shape == x2.shape[:-2] + x2.shape[-1:]
280-
_test_stacks(linalg.matmul, x1, x2, res=res, dims=1)
280+
_test_stacks(_array_module.matmul, x1, x2, res=res, dims=1)
281281
elif len(x2.shape) == 1:
282282
assert res.shape == x1.shape[:-1]
283-
_test_stacks(linalg.matmul, x1, x2, res=res, dims=1)
283+
_test_stacks(_array_module.matmul, x1, x2, res=res, dims=1)
284284
else:
285285
stack_shape = broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
286286
assert res.shape == stack_shape + (x1.shape[-2], x2.shape[-1])
287-
_test_stacks(linalg.matmul, x1, x2, res=res)
287+
_test_stacks(_array_module.matmul, x1, x2, res=res)
288288

289289
@given(
290290
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
@@ -328,7 +328,7 @@ def test_matrix_rank(x, kw):
328328
x=xps.arrays(dtype=dtypes, shape=matrix_shapes()),
329329
)
330330
def test_matrix_transpose(x):
331-
res = linalg.matrix_transpose(x)
331+
res = _array_module.matrix_transpose(x)
332332
true_val = lambda a: _array_module.asarray([[a[i, j] for i in
333333
range(a.shape[0])] for j in
334334
range(a.shape[1])],
@@ -339,7 +339,7 @@ def test_matrix_transpose(x):
339339
assert res.shape == shape, "matrix_transpose() did not return the correct shape"
340340
assert res.dtype == x.dtype, "matrix_transpose() did not return the correct dtype"
341341

342-
_test_stacks(linalg.matrix_transpose, x, res=res, true_val=true_val)
342+
_test_stacks(_array_module.matrix_transpose, x, res=res, true_val=true_val)
343343

344344
@given(
345345
*two_mutual_arrays(dtype_objects=numeric_dtype_objects,
@@ -497,7 +497,7 @@ def test_svdvals(x):
497497
kw=kwargs(axes=todo)
498498
)
499499
def test_tensordot(x1, x2, kw):
500-
# res = linalg.tensordot(x1, x2, **kw)
500+
# res = _array_module.tensordot(x1, x2, **kw)
501501
pass
502502

503503
@given(
@@ -514,7 +514,7 @@ def test_trace(x, kw):
514514
kw=kwargs(axis=todo)
515515
)
516516
def test_vecdot(x1, x2, kw):
517-
# res = linalg.vecdot(x1, x2, **kw)
517+
# res = _array_module.vecdot(x1, x2, **kw)
518518
pass
519519

520520
@given(

0 commit comments

Comments
 (0)