28
28
from .test_broadcasting import broadcast_shapes
29
29
30
30
from . import _array_module
31
+ from ._array_module import linalg
31
32
32
33
# Standin strategy for not yet implemented tests
33
34
todo = none ()
@@ -75,12 +76,12 @@ def _test_namedtuple(res, fields, func_name):
75
76
kw = kwargs (upper = booleans ())
76
77
)
77
78
def test_cholesky (x , kw ):
78
- res = _array_module . linalg .cholesky (x , ** kw )
79
+ res = linalg .cholesky (x , ** kw )
79
80
80
81
assert res .shape == x .shape , "cholesky() did not return the correct shape"
81
82
assert res .dtype == x .dtype , "cholesky() did not return the correct dtype"
82
83
83
- _test_stacks (_array_module . linalg .cholesky , x , ** kw , res = res )
84
+ _test_stacks (linalg .cholesky , x , ** kw , res = res )
84
85
85
86
# Test that the result is upper or lower triangular
86
87
if kw .get ('upper' , False ):
@@ -129,7 +130,7 @@ def test_cross(x1_x2_kw):
129
130
shape = x1 .shape
130
131
assert x1 .shape [axis ] == x2 .shape [axis ] == 3 , err
131
132
132
- res = _array_module . linalg .cross (x1 , x2 , ** kw )
133
+ res = linalg .cross (x1 , x2 , ** kw )
133
134
134
135
# TODO: Replace result_type() with a helper function
135
136
assert res .dtype == _array_module .result_type (x1 , x2 ), "cross() did not return the correct dtype"
@@ -146,7 +147,7 @@ def test_cross(x1_x2_kw):
146
147
x1_stack = x1 [idx ]
147
148
x2_stack = x2 [idx ]
148
149
assert x1_stack .shape == x2_stack .shape == (3 ,), "Invalid cross() stack shapes. This indicates a bug in the test suite."
149
- decomp_res_stack = _array_module . linalg .cross (x1_stack , x2_stack )
150
+ decomp_res_stack = linalg .cross (x1_stack , x2_stack )
150
151
assert_exactly_equal (res_stack , decomp_res_stack )
151
152
152
153
exact_cross = asarray ([
@@ -160,12 +161,12 @@ def test_cross(x1_x2_kw):
160
161
x = xps .arrays (dtype = xps .floating_dtypes (), shape = square_matrix_shapes ),
161
162
)
162
163
def test_det (x ):
163
- res = _array_module . linalg .det (x )
164
+ res = linalg .det (x )
164
165
165
166
assert res .dtype == x .dtype , "det() did not return the correct dtype"
166
167
assert res .shape == x .shape [:- 2 ], "det() did not return the correct shape"
167
168
168
- _test_stacks (_array_module . linalg .det , x , res = res , dims = 0 )
169
+ _test_stacks (linalg .det , x , res = res , dims = 0 )
169
170
170
171
# TODO: Test that res actually corresponds to the determinant of x
171
172
@@ -176,7 +177,7 @@ def test_det(x):
176
177
kw = kwargs (offset = integers (- MAX_ARRAY_SIZE , MAX_ARRAY_SIZE ))
177
178
)
178
179
def test_diagonal (x , kw ):
179
- res = _array_module . linalg .diagonal (x , ** kw )
180
+ res = linalg .diagonal (x , ** kw )
180
181
181
182
assert res .dtype == x .dtype , "diagonal() returned the wrong dtype"
182
183
@@ -201,11 +202,11 @@ def true_diag(x_stack):
201
202
x_stack_diag = [x_stack [i - offset , i ] for i in range (diag_size )]
202
203
return asarray (x_stack_diag , dtype = x .dtype )
203
204
204
- _test_stacks (_array_module . linalg .diagonal , x , ** kw , res = res , dims = 1 , true_val = true_diag )
205
+ _test_stacks (linalg .diagonal , x , ** kw , res = res , dims = 1 , true_val = true_diag )
205
206
206
207
@given (x = symmetric_matrices (finite = True ))
207
208
def test_eigh (x ):
208
- res = _array_module . linalg .eigh (x )
209
+ res = linalg .eigh (x )
209
210
210
211
_test_namedtuple (res , ['eigenvalues' , 'eigenvectors' ], 'eigh' )
211
212
@@ -218,35 +219,35 @@ def test_eigh(x):
218
219
assert eigenvectors .dtype == x .dtype , "eigh().eigenvectors did not return the correct dtype"
219
220
assert eigenvectors .shape == x .shape , "eigh().eigenvectors did not return the correct shape"
220
221
221
- _test_stacks (lambda x : _array_module . linalg .eigh (x ).eigenvalues , x ,
222
+ _test_stacks (lambda x : linalg .eigh (x ).eigenvalues , x ,
222
223
res = eigenvalues , dims = 1 )
223
- _test_stacks (lambda x : _array_module . linalg .eigh (x ).eigenvectors , x ,
224
+ _test_stacks (lambda x : linalg .eigh (x ).eigenvectors , x ,
224
225
res = eigenvectors , dims = 2 )
225
226
226
227
# TODO: Test that res actually corresponds to the eigenvalues and
227
228
# eigenvectors of x
228
229
229
230
@given (x = symmetric_matrices (finite = True ))
230
231
def test_eigvalsh (x ):
231
- res = _array_module . linalg .eigvalsh (x )
232
+ res = linalg .eigvalsh (x )
232
233
233
234
assert res .dtype == x .dtype , "eigvalsh() did not return the correct dtype"
234
235
assert res .shape == x .shape [:- 1 ], "eigvalsh() did not return the correct shape"
235
236
236
- _test_stacks (_array_module . linalg .eigvalsh , x , res = res , dims = 1 )
237
+ _test_stacks (linalg .eigvalsh , x , res = res , dims = 1 )
237
238
238
239
# TODO: Should we test that the result is the same as eigh(x).eigenvalues?
239
240
240
241
# TODO: Test that res actually corresponds to the eigenvalues of x
241
242
242
243
@given (x = invertible_matrices ())
243
244
def test_inv (x ):
244
- res = _array_module . linalg .inv (x )
245
+ res = linalg .inv (x )
245
246
246
247
assert res .shape == x .shape , "inv() did not return the correct shape"
247
248
assert res .dtype == x .dtype , "inv() did not return the correct dtype"
248
249
249
- _test_stacks (_array_module . linalg .inv , x , res = res )
250
+ _test_stacks (linalg .inv , x , res = res )
250
251
251
252
# TODO: Test that the result is actually the inverse
252
253
@@ -262,11 +263,11 @@ def test_matmul(x1, x2):
262
263
or len (x1 .shape ) >= 2 and len (x2 .shape ) >= 2 and x1 .shape [- 1 ] != x2 .shape [- 2 ]):
263
264
# The spec doesn't specify what kind of exception is used here. Most
264
265
# libraries will use a custom exception class.
265
- raises (Exception , lambda : _array_module . linalg .matmul (x1 , x2 ),
266
+ raises (Exception , lambda : linalg .matmul (x1 , x2 ),
266
267
"matmul did not raise an exception for invalid shapes" )
267
268
return
268
269
else :
269
- res = _array_module . linalg .matmul (x1 , x2 )
270
+ res = linalg .matmul (x1 , x2 )
270
271
271
272
# TODO: Replace result_type() with a helper function
272
273
assert res .dtype == _array_module .result_type (x1 , x2 ), "matmul() did not return the correct dtype"
@@ -275,21 +276,21 @@ def test_matmul(x1, x2):
275
276
assert res .shape == ()
276
277
elif len (x1 .shape ) == 1 :
277
278
assert res .shape == x2 .shape [:- 2 ] + x2 .shape [- 1 :]
278
- _test_stacks (_array_module . linalg .matmul , x1 , x2 , res = res , dims = 1 )
279
+ _test_stacks (linalg .matmul , x1 , x2 , res = res , dims = 1 )
279
280
elif len (x2 .shape ) == 1 :
280
281
assert res .shape == x1 .shape [:- 1 ]
281
- _test_stacks (_array_module . linalg .matmul , x1 , x2 , res = res , dims = 1 )
282
+ _test_stacks (linalg .matmul , x1 , x2 , res = res , dims = 1 )
282
283
else :
283
284
stack_shape = broadcast_shapes (x1 .shape [:- 2 ], x2 .shape [:- 2 ])
284
285
assert res .shape == stack_shape + (x1 .shape [- 2 ], x2 .shape [- 1 ])
285
- _test_stacks (_array_module . linalg .matmul , x1 , x2 , res = res )
286
+ _test_stacks (linalg .matmul , x1 , x2 , res = res )
286
287
287
288
@given (
288
289
x = xps .arrays (dtype = xps .floating_dtypes (), shape = shapes ),
289
290
kw = kwargs (axis = todo , keepdims = todo , ord = todo )
290
291
)
291
292
def test_matrix_norm (x , kw ):
292
- # res = _array_module. linalg.matrix_norm(x, **kw)
293
+ # res = linalg.matrix_norm(x, **kw)
293
294
pass
294
295
295
296
matrix_power_n = shared (integers (- 1000 , 1000 ), key = 'matrix_power n' )
@@ -301,82 +302,82 @@ def test_matrix_norm(x, kw):
301
302
n = matrix_power_n ,
302
303
)
303
304
def test_matrix_power (x , n ):
304
- res = _array_module . linalg .matrix_power (x , n )
305
+ res = linalg .matrix_power (x , n )
305
306
if n == 0 :
306
307
true_val = lambda x : _array_module .eye (x .shape [0 ], dtype = x .dtype )
307
308
else :
308
309
true_val = None
309
310
# _test_stacks only works with array arguments
310
- func = lambda x : _array_module . linalg .matrix_power (x , n )
311
+ func = lambda x : linalg .matrix_power (x , n )
311
312
_test_stacks (func , x , res = res , true_val = true_val )
312
313
313
314
@given (
314
315
x = xps .arrays (dtype = xps .floating_dtypes (), shape = shapes ),
315
316
kw = kwargs (rtol = todo )
316
317
)
317
318
def test_matrix_rank (x , kw ):
318
- # res = _array_module. linalg.matrix_rank(x, **kw)
319
+ # res = linalg.matrix_rank(x, **kw)
319
320
pass
320
321
321
322
@given (
322
323
x = xps .arrays (dtype = xps .floating_dtypes (), shape = shapes ),
323
324
)
324
325
def test_matrix_transpose (x ):
325
- # res = _array_module. linalg.matrix_transpose(x)
326
+ # res = linalg.matrix_transpose(x)
326
327
pass
327
328
328
329
@given (
329
330
x1 = xps .arrays (dtype = xps .floating_dtypes (), shape = shapes ),
330
331
x2 = xps .arrays (dtype = xps .floating_dtypes (), shape = shapes ),
331
332
)
332
333
def test_outer (x1 , x2 ):
333
- # res = _array_module. linalg.outer(x1, x2)
334
+ # res = linalg.outer(x1, x2)
334
335
pass
335
336
336
337
@given (
337
338
x = xps .arrays (dtype = xps .floating_dtypes (), shape = shapes ),
338
339
kw = kwargs (rtol = todo )
339
340
)
340
341
def test_pinv (x , kw ):
341
- # res = _array_module. linalg.pinv(x, **kw)
342
+ # res = linalg.pinv(x, **kw)
342
343
pass
343
344
344
345
@given (
345
346
x = xps .arrays (dtype = xps .floating_dtypes (), shape = shapes ),
346
347
kw = kwargs (mode = todo )
347
348
)
348
349
def test_qr (x , kw ):
349
- # res = _array_module. linalg.qr(x, **kw)
350
+ # res = linalg.qr(x, **kw)
350
351
pass
351
352
352
353
@given (
353
354
x = xps .arrays (dtype = xps .floating_dtypes (), shape = shapes ),
354
355
)
355
356
def test_slogdet (x ):
356
- # res = _array_module. linalg.slogdet(x)
357
+ # res = linalg.slogdet(x)
357
358
pass
358
359
359
360
@given (
360
361
x1 = xps .arrays (dtype = xps .floating_dtypes (), shape = shapes ),
361
362
x2 = xps .arrays (dtype = xps .floating_dtypes (), shape = shapes ),
362
363
)
363
364
def test_solve (x1 , x2 ):
364
- # res = _array_module. linalg.solve(x1, x2)
365
+ # res = linalg.solve(x1, x2)
365
366
pass
366
367
367
368
@given (
368
369
x = xps .arrays (dtype = xps .floating_dtypes (), shape = shapes ),
369
370
kw = kwargs (full_matrices = todo )
370
371
)
371
372
def test_svd (x , kw ):
372
- # res = _array_module. linalg.svd(x, **kw)
373
+ # res = linalg.svd(x, **kw)
373
374
pass
374
375
375
376
@given (
376
377
x = xps .arrays (dtype = xps .floating_dtypes (), shape = shapes ),
377
378
)
378
379
def test_svdvals (x ):
379
- # res = _array_module. linalg.svdvals(x)
380
+ # res = linalg.svdvals(x)
380
381
pass
381
382
382
383
@given (
@@ -385,15 +386,15 @@ def test_svdvals(x):
385
386
kw = kwargs (axes = todo )
386
387
)
387
388
def test_tensordot (x1 , x2 , kw ):
388
- # res = _array_module. linalg.tensordot(x1, x2, **kw)
389
+ # res = linalg.tensordot(x1, x2, **kw)
389
390
pass
390
391
391
392
@given (
392
393
x = xps .arrays (dtype = xps .floating_dtypes (), shape = shapes ),
393
394
kw = kwargs (offset = todo )
394
395
)
395
396
def test_trace (x , kw ):
396
- # res = _array_module. linalg.trace(x, **kw)
397
+ # res = linalg.trace(x, **kw)
397
398
pass
398
399
399
400
@given (
@@ -402,13 +403,13 @@ def test_trace(x, kw):
402
403
kw = kwargs (axis = todo )
403
404
)
404
405
def test_vecdot (x1 , x2 , kw ):
405
- # res = _array_module. linalg.vecdot(x1, x2, **kw)
406
+ # res = linalg.vecdot(x1, x2, **kw)
406
407
pass
407
408
408
409
@given (
409
410
x = xps .arrays (dtype = xps .floating_dtypes (), shape = shapes ),
410
411
kw = kwargs (axis = todo , keepdims = todo , ord = todo )
411
412
)
412
413
def test_vector_norm (x , kw ):
413
- # res = _array_module. linalg.vector_norm(x, **kw)
414
+ # res = linalg.vector_norm(x, **kw)
414
415
pass
0 commit comments