Skip to content

Commit d82c026

Browse files
Update tests due to new generate_random_numpy_array()
1 parent 5d0f39a commit d82c026

File tree

2 files changed

+9
-15
lines changed

2 files changed

+9
-15
lines changed

dpnp/tests/test_fft.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -380,8 +380,7 @@ def setup_method(self):
380380
@pytest.mark.parametrize("norm", [None, "forward", "backward", "ortho"])
381381
@pytest.mark.parametrize("order", ["C", "F"])
382382
def test_fft2(self, dtype, axes, norm, order):
383-
x = generate_random_numpy_array((2, 3, 4), dtype)
384-
a_np = numpy.array(x, order=order)
383+
a_np = generate_random_numpy_array((2, 3, 4), dtype, order)
385384
a = dpnp.array(a_np)
386385

387386
result = dpnp.fft.fft2(a, axes=axes, norm=norm)
@@ -443,8 +442,7 @@ def setup_method(self):
443442
@pytest.mark.parametrize("norm", [None, "backward", "forward", "ortho"])
444443
@pytest.mark.parametrize("order", ["C", "F"])
445444
def test_fftn(self, dtype, axes, norm, order):
446-
x = generate_random_numpy_array((2, 3, 4, 5), dtype)
447-
a_np = numpy.array(x, order=order)
445+
a_np = generate_random_numpy_array((2, 3, 4, 5), dtype, order)
448446
a = dpnp.array(a_np)
449447

450448
result = dpnp.fft.fftn(a, axes=axes, norm=norm)
@@ -698,8 +696,7 @@ def test_irfft_1D_on_2D_array(self, dtype, n, axis, norm, order):
698696
@pytest.mark.parametrize("norm", [None, "backward", "forward", "ortho"])
699697
@pytest.mark.parametrize("order", ["C", "F"])
700698
def test_irfft_1D_on_3D_array(self, dtype, n, axis, norm, order):
701-
x = generate_random_numpy_array((4, 5, 6), dtype)
702-
a_np = numpy.array(x, order=order)
699+
a_np = generate_random_numpy_array((4, 5, 6), dtype, order)
703700
# each 1-D array of input should be Hermitian
704701
if axis == 0:
705702
a_np[0].imag = 0
@@ -936,8 +933,7 @@ def setup_method(self):
936933
@pytest.mark.parametrize("norm", [None, "backward", "forward", "ortho"])
937934
@pytest.mark.parametrize("order", ["C", "F"])
938935
def test_rfft2(self, dtype, axes, norm, order):
939-
x = generate_random_numpy_array((2, 3, 4), dtype)
940-
a_np = numpy.array(x, order=order)
936+
a_np = generate_random_numpy_array((2, 3, 4), dtype, order)
941937
a = dpnp.asarray(a_np)
942938

943939
result = dpnp.fft.rfft2(a, axes=axes, norm=norm)
@@ -1001,8 +997,7 @@ def setup_method(self):
1001997
@pytest.mark.parametrize("norm", [None, "backward", "forward", "ortho"])
1002998
@pytest.mark.parametrize("order", ["C", "F"])
1003999
def test_rfftn(self, dtype, axes, norm, order):
1004-
x = generate_random_numpy_array((2, 3, 4, 5), dtype)
1005-
a_np = numpy.array(x, order=order)
1000+
a_np = generate_random_numpy_array((2, 3, 4, 5), dtype, order)
10061001
a = dpnp.asarray(a_np)
10071002

10081003
result = dpnp.fft.rfftn(a, axes=axes, norm=norm)

dpnp/tests/test_linalg.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -503,24 +503,23 @@ def test_eigenvalues(self, func, shape, dtype, order):
503503
# non-symmetric for eig() and eigvals()
504504
is_hermitian = func in ("eigh, eigvalsh")
505505
a = generate_random_numpy_array(
506-
shape, dtype, hermitian=is_hermitian, low=-4, high=4
506+
shape, dtype, order, hermitian=is_hermitian, low=-4, high=4
507507
)
508-
a_order = numpy.array(a, order=order)
509-
a_dp = dpnp.array(a, order=order)
508+
a_dp = dpnp.array(a)
510509

511510
# NumPy with OneMKL and with rocSOLVER sorts in ascending order,
512511
# so w's should be directly comparable.
513512
# However, both OneMKL and rocSOLVER pick a different convention for
514513
# constructing eigenvectors, so v's are not directly comparable and
515514
# we verify them through the eigen equation A*v=w*v.
516515
if func in ("eig", "eigh"):
517-
w, _ = getattr(numpy.linalg, func)(a_order)
516+
w, _ = getattr(numpy.linalg, func)(a)
518517
w_dp, v_dp = getattr(dpnp.linalg, func)(a_dp)
519518

520519
self.assert_eigen_decomposition(a_dp, w_dp, v_dp)
521520

522521
else: # eighvals or eigvalsh
523-
w = getattr(numpy.linalg, func)(a_order)
522+
w = getattr(numpy.linalg, func)(a)
524523
w_dp = getattr(dpnp.linalg, func)(a_dp)
525524

526525
assert_dtype_allclose(w_dp, w, factor=24)

0 commit comments

Comments
 (0)