@@ -3552,16 +3552,10 @@ class TestDiag:
     """
     Test that linalg.diag has the same behavior as numpy.diag.
     numpy.diag has two behaviors:
-    (1) when given a vector, it returns a matrix with that vector as the
-    diagonal.
-    (2) when given a matrix, returns a vector which is the diagonal of the
-    matrix.
+    (1) when given a vector, it returns a matrix with that vector as the diagonal.
+    (2) when given a matrix, returns a vector which is the diagonal of the matrix.
 
-    (1) and (2) are tested by test_alloc_diag and test_extract_diag
-    respectively.
-
-    test_diag test makes sure that linalg.diag instantiates
-    the right op based on the dimension of the input.
+    (1) and (2) are further tested by TestAllocDiag and TestExtractDiag, respectively.
 
     """
 
     def setup_method(self):
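For reference, a minimal NumPy sketch of the two behaviors the docstring above describes (illustrative only, not part of the diff):

    import numpy as np

    v = np.array([1, 2, 3])
    np.diag(v)   # behavior (1): vector -> 3x3 matrix with v on the main diagonal

    m = np.arange(9).reshape(3, 3)
    np.diag(m)   # behavior (2): matrix -> its main diagonal, array([0, 4, 8])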
@@ -3571,6 +3565,7 @@ def setup_method(self):
         self.type = TensorType
 
     def test_diag(self):
+        """Makes sure that diag instantiates the right op based on the dimension of the input."""
         rng = np.random.default_rng(utt.fetch_seed())
 
         # test vector input
@@ -3609,38 +3604,55 @@ def test_diag(self):
         f = function([], g)
         assert np.array_equal(f(), np.diag(xx))
 
-    def test_infer_shape(self):
+
+class TestExtractDiag:
+    @pytest.mark.parametrize("axis1, axis2", [(0, 1), (1, 0)])
+    @pytest.mark.parametrize("offset", (-1, 0, 2))
+    def test_infer_shape(self, offset, axis1, axis2):
         rng = np.random.default_rng(utt.fetch_seed())
 
-        x = vector()
-        g = diag(x)
-        f = pytensor.function([x], g.shape)
-        topo = f.maker.fgraph.toposort()
-        if config.mode != "FAST_COMPILE":
-            assert sum(isinstance(node.op, AllocDiag) for node in topo) == 0
-        for shp in [5, 0, 1]:
-            m = rng.random(shp).astype(self.floatX)
-            assert (f(m) == np.diag(m).shape).all()
-
-        x = matrix()
-        g = diag(x)
+        x = matrix("x")
+        g = ExtractDiag(offset=offset, axis1=axis1, axis2=axis2)(x)
         f = pytensor.function([x], g.shape)
         topo = f.maker.fgraph.toposort()
         if config.mode != "FAST_COMPILE":
             assert sum(isinstance(node.op, ExtractDiag) for node in topo) == 0
         for shp in [(5, 3), (3, 5), (5, 1), (1, 5), (5, 0), (0, 5), (1, 0), (0, 1)]:
-            m = rng.random(shp).astype(self.floatX)
-            assert (f(m) == np.diag(m).shape).all()
+            m = rng.random(shp).astype(config.floatX)
+            assert (
+                f(m) == np.diagonal(m, offset=offset, axis1=axis1, axis2=axis2).shape
+            ).all()
 
-    def test_diag_grad(self):
+    @pytest.mark.parametrize("axis1, axis2", [(0, 1), (1, 0)])
+    @pytest.mark.parametrize("offset", (0, 1, -1))
+    def test_grad_2d(self, offset, axis1, axis2):
+        diag_fn = ExtractDiag(offset=offset, axis1=axis1, axis2=axis2)
         rng = np.random.default_rng(utt.fetch_seed())
-        x = rng.random(5)
-        utt.verify_grad(diag, [x], rng=rng)
         x = rng.random((5, 3))
-        utt.verify_grad(diag, [x], rng=rng)
+        utt.verify_grad(diag_fn, [x], rng=rng)
+
+    @pytest.mark.parametrize(
+        "axis1, axis2",
+        [
+            (0, 1),
+            (1, 0),
+            (1, 2),
+            (2, 1),
+            (0, 2),
+            (2, 0),
+        ],
+    )
+    @pytest.mark.parametrize("offset", (0, 1, -1))
+    def test_grad_3d(self, offset, axis1, axis2):
+        diag_fn = ExtractDiag(offset=offset, axis1=axis1, axis2=axis2)
+        rng = np.random.default_rng(utt.fetch_seed())
+        x = rng.random((5, 4, 3))
+        utt.verify_grad(diag_fn, [x], rng=rng)
 
 
 class TestAllocDiag:
+    # TODO: Separate perform, grad and infer_shape tests
+
     def setup_method(self):
         self.alloc_diag = AllocDiag
         self.mode = pytensor.compile.mode.get_default_mode()
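For a quick sense of the reference values the rewritten test_infer_shape checks against: np.diagonal takes the diagonal over the plane spanned by axis1 and axis2, shifted by offset, and swapping axis1 and axis2 transposes that plane. The inputs below are illustrative only; utt.verify_grad in the new test_grad_2d/test_grad_3d then compares the symbolic gradient of ExtractDiag against a finite-difference estimate on similar random inputs.

    import numpy as np

    m = np.arange(15).reshape(5, 3)
    np.diagonal(m, offset=0, axis1=0, axis2=1).shape    # (3,), main diagonal of a 5x3 array
    np.diagonal(m, offset=-1, axis1=0, axis2=1).shape   # (3,), one step below the main diagonal
    np.diagonal(m, offset=2, axis1=1, axis2=0).shape    # (3,), axes swapped: diagonal of m.T shifted by 2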
@@ -3674,7 +3686,7 @@ def test_alloc_diag_values(self):
             (-2, 0, 1),
             (-1, 1, 2),
         ]:
-            # Test AllocDiag values
+            # Test perform
             if np.maximum(axis1, axis2) > len(test_val.shape):
                 continue
             adiag_op = self.alloc_diag(offset=offset, axis1=axis1, axis2=axis2)
@@ -3688,7 +3700,6 @@ def test_alloc_diag_values(self):
             # Test infer_shape
             f_shape = pytensor.function([x], adiag_op(x).shape, mode="FAST_RUN")
 
-            # pytensor.printing.debugprint(f_shape.maker.fgraph.outputs[0])
             output_shape = f_shape(test_val)
             assert not any(
                 isinstance(node.op, self.alloc_diag)
@@ -3699,6 +3710,7 @@ def test_alloc_diag_values(self):
             ).shape
             assert np.all(rediag_shape == test_val.shape)
 
+            # Test grad
             diag_x = adiag_op(x)
             sum_diag_x = at_sum(diag_x)
             grad_x = pytensor.grad(sum_diag_x, x)
@@ -3710,7 +3722,6 @@ def test_alloc_diag_values(self):
             true_grad_input = np.diagonal(
                 grad_diag_input, offset=offset, axis1=axis1, axis2=axis2
             )
-
             assert np.all(true_grad_input == grad_input)
 
 
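The gradient check in the final hunk relies on the fact that each entry of x appears exactly once in the output of the AllocDiag op, so the gradient of sum(adiag_op(x)) with respect to x is recovered by extracting the embedded diagonal from an all-ones array of the output's shape. A rough NumPy sketch of that reasoning, with illustrative names not taken from the test:

    import numpy as np

    x = np.random.random(5)
    out = np.zeros((5, 5))
    np.fill_diagonal(out, x)                    # what an offset=0, axis1=0, axis2=1 alloc-diag produces
    grad_wrt_out = np.ones_like(out)            # gradient of sum() w.r.t. every output entry
    expected_grad_x = np.diagonal(grad_wrt_out, offset=0, axis1=0, axis2=1)
    assert np.array_equal(expected_grad_x, np.ones_like(x))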