@@ -1959,52 +1959,64 @@ def test_scaled_A_plus_scaled_outer(self):
1959
1959
rng .random ((4 )).astype (self .dtype ),
1960
1960
).shape == (5 , 4 )
1961
1961
1962
- def given_dtype (self , dtype , M , N ):
1962
+ def given_dtype (self , dtype , M , N , * , destructive = True ):
1963
1963
# test corner case shape and dtype
1964
1964
rng = np .random .default_rng (unittest_tools .fetch_seed ())
1965
1965
1966
- f = self .function (
1967
- [self .A , self .x , self .y ], self .A + 0.1 * outer (self .x , self .y )
1966
+ A = tensor (dtype = dtype , shape = (False , False ))
1967
+ x = tensor (dtype = dtype , shape = (False ,))
1968
+ y = tensor (dtype = dtype , shape = (False ,))
1969
+
1970
+ f = self .function ([A , x , y ], A + 0.1 * outer (x , y ))
1971
+ self .assertFunctionContains (
1972
+ f , self .ger_destructive if destructive else self .ger
1968
1973
)
1969
- self .assertFunctionContains (f , self .ger )
1970
1974
f (
1971
- rng .random ((M , N )).astype (self . dtype ),
1972
- rng .random ((M )).astype (self . dtype ),
1973
- rng .random ((N )).astype (self . dtype ),
1975
+ rng .random ((M , N )).astype (dtype ),
1976
+ rng .random ((M )).astype (dtype ),
1977
+ rng .random ((N )).astype (dtype ),
1974
1978
).shape == (5 , 4 )
1975
1979
f (
1976
- rng .random ((M , N )).astype (self . dtype )[::- 1 , ::- 1 ],
1977
- rng .random ((M )).astype (self . dtype ),
1978
- rng .random ((N )).astype (self . dtype ),
1980
+ rng .random ((M , N )).astype (dtype )[::- 1 , ::- 1 ],
1981
+ rng .random ((M )).astype (dtype ),
1982
+ rng .random ((N )).astype (dtype ),
1979
1983
).shape == (5 , 4 )
1980
1984
1981
1985
def test_f32_0_0 (self ):
1982
- return self .given_dtype ("float32" , 0 , 0 )
1986
+ return self .given_dtype ("float32" , 0 , 0 , destructive = config . floatX != "float32" )
1983
1987
1984
1988
def test_f32_1_0 (self ):
1985
- return self .given_dtype ("float32" , 1 , 0 )
1989
+ return self .given_dtype ("float32" , 1 , 0 , destructive = config . floatX != "float32" )
1986
1990
1987
1991
def test_f32_0_1 (self ):
1988
- return self .given_dtype ("float32" , 0 , 1 )
1992
+ return self .given_dtype ("float32" , 0 , 1 , destructive = config . floatX != "float32" )
1989
1993
1990
1994
def test_f32_1_1 (self ):
1991
- return self .given_dtype ("float32" , 1 , 1 )
1995
+ return self .given_dtype ("float32" , 1 , 1 , destructive = config . floatX != "float32" )
1992
1996
1993
1997
def test_f32_4_4 (self ):
1994
- return self .given_dtype ("float32" , 4 , 4 )
1998
+ return self .given_dtype ("float32" , 4 , 4 , destructive = config . floatX != "float32" )
1995
1999
1996
2000
def test_f32_7_1 (self ):
1997
- return self .given_dtype ("float32" , 7 , 1 )
2001
+ return self .given_dtype ("float32" , 7 , 1 , destructive = config . floatX != "float32" )
1998
2002
1999
2003
def test_f32_1_2 (self ):
2000
- return self .given_dtype ("float32" , 1 , 2 )
2004
+ return self .given_dtype ("float32" , 1 , 2 , destructive = config . floatX != "float32" )
2001
2005
2002
2006
def test_f64_4_5 (self ):
2003
- return self .given_dtype ("float64" , 4 , 5 )
2007
+ return self .given_dtype ("float64" , 4 , 5 , destructive = False )
2004
2008
2009
+ @pytest .mark .xfail (
2010
+ condition = config .floatX == "float32" ,
2011
+ reason = "GER from complex64 is not introduced in float32 mode" ,
2012
+ )
2005
2013
def test_c64_7_1 (self ):
2006
2014
return self .given_dtype ("complex64" , 7 , 1 )
2007
2015
2016
+ @pytest .mark .xfail (
2017
+ raises = AssertionError ,
2018
+ reason = "Unclear how this test was supposed to work with complex128" ,
2019
+ )
2008
2020
def test_c128_1_9 (self ):
2009
2021
return self .given_dtype ("complex128" , 1 , 9 )
2010
2022
0 commit comments