Skip to content

Commit 89353bd

Browse files
committed
Actually test dtype in TestGer.given_dtype
1 parent b294656 commit 89353bd

File tree

1 file changed

+30
-18
lines changed

1 file changed

+30
-18
lines changed

tests/tensor/test_blas.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1959,52 +1959,64 @@ def test_scaled_A_plus_scaled_outer(self):
19591959
rng.random((4)).astype(self.dtype),
19601960
).shape == (5, 4)
19611961

1962-
def given_dtype(self, dtype, M, N):
1962+
def given_dtype(self, dtype, M, N, *, destructive=True):
19631963
# test corner case shape and dtype
19641964
rng = np.random.default_rng(unittest_tools.fetch_seed())
19651965

1966-
f = self.function(
1967-
[self.A, self.x, self.y], self.A + 0.1 * outer(self.x, self.y)
1966+
A = tensor(dtype=dtype, shape=(False, False))
1967+
x = tensor(dtype=dtype, shape=(False,))
1968+
y = tensor(dtype=dtype, shape=(False,))
1969+
1970+
f = self.function([A, x, y], A + 0.1 * outer(x, y))
1971+
self.assertFunctionContains(
1972+
f, self.ger_destructive if destructive else self.ger
19681973
)
1969-
self.assertFunctionContains(f, self.ger)
19701974
f(
1971-
rng.random((M, N)).astype(self.dtype),
1972-
rng.random((M)).astype(self.dtype),
1973-
rng.random((N)).astype(self.dtype),
1975+
rng.random((M, N)).astype(dtype),
1976+
rng.random((M)).astype(dtype),
1977+
rng.random((N)).astype(dtype),
19741978
).shape == (5, 4)
19751979
f(
1976-
rng.random((M, N)).astype(self.dtype)[::-1, ::-1],
1977-
rng.random((M)).astype(self.dtype),
1978-
rng.random((N)).astype(self.dtype),
1980+
rng.random((M, N)).astype(dtype)[::-1, ::-1],
1981+
rng.random((M)).astype(dtype),
1982+
rng.random((N)).astype(dtype),
19791983
).shape == (5, 4)
19801984

19811985
def test_f32_0_0(self):
1982-
return self.given_dtype("float32", 0, 0)
1986+
return self.given_dtype("float32", 0, 0, destructive=config.floatX != "float32")
19831987

19841988
def test_f32_1_0(self):
1985-
return self.given_dtype("float32", 1, 0)
1989+
return self.given_dtype("float32", 1, 0, destructive=config.floatX != "float32")
19861990

19871991
def test_f32_0_1(self):
1988-
return self.given_dtype("float32", 0, 1)
1992+
return self.given_dtype("float32", 0, 1, destructive=config.floatX != "float32")
19891993

19901994
def test_f32_1_1(self):
1991-
return self.given_dtype("float32", 1, 1)
1995+
return self.given_dtype("float32", 1, 1, destructive=config.floatX != "float32")
19921996

19931997
def test_f32_4_4(self):
1994-
return self.given_dtype("float32", 4, 4)
1998+
return self.given_dtype("float32", 4, 4, destructive=config.floatX != "float32")
19951999

19962000
def test_f32_7_1(self):
1997-
return self.given_dtype("float32", 7, 1)
2001+
return self.given_dtype("float32", 7, 1, destructive=config.floatX != "float32")
19982002

19992003
def test_f32_1_2(self):
2000-
return self.given_dtype("float32", 1, 2)
2004+
return self.given_dtype("float32", 1, 2, destructive=config.floatX != "float32")
20012005

20022006
def test_f64_4_5(self):
2003-
return self.given_dtype("float64", 4, 5)
2007+
return self.given_dtype("float64", 4, 5, destructive=False)
20042008

2009+
@pytest.mark.xfail(
2010+
condition=config.floatX == "float32",
2011+
reason="GER from complex64 is not introduced in float32 mode",
2012+
)
20052013
def test_c64_7_1(self):
20062014
return self.given_dtype("complex64", 7, 1)
20072015

2016+
@pytest.mark.xfail(
2017+
raises=AssertionError,
2018+
reason="Unclear how this test was supposed to work with complex128",
2019+
)
20082020
def test_c128_1_9(self):
20092021
return self.given_dtype("complex128", 1, 9)
20102022

0 commit comments

Comments
 (0)