@@ -181,19 +181,24 @@ def test_gromov2_gradients():
181
181
182
182
if torch :
183
183
184
- p1 = torch .tensor (p , requires_grad = True )
185
- q1 = torch .tensor (q , requires_grad = True )
186
- C11 = torch .tensor (C1 , requires_grad = True )
187
- C12 = torch .tensor (C2 , requires_grad = True )
184
+ devices = [torch .device ("cpu" )]
185
+ if torch .cuda .is_available ():
186
+ devices .append (torch .device ("cuda" ))
187
+ for device in devices :
188
+ p1 = torch .tensor (p , requires_grad = True , device = device )
189
+ q1 = torch .tensor (q , requires_grad = True , device = device )
190
+ C11 = torch .tensor (C1 , requires_grad = True , device = device )
191
+ C12 = torch .tensor (C2 , requires_grad = True , device = device )
188
192
189
- val = ot .gromov_wasserstein2 (C11 , C12 , p1 , q1 )
193
+ val = ot .gromov_wasserstein2 (C11 , C12 , p1 , q1 )
190
194
191
- val .backward ()
195
+ val .backward ()
192
196
193
- assert q1 .shape == q1 .grad .shape
194
- assert p1 .shape == p1 .grad .shape
195
- assert C11 .shape == C11 .grad .shape
196
- assert C12 .shape == C12 .grad .shape
197
+ assert val .device == p1 .device
198
+ assert q1 .shape == q1 .grad .shape
199
+ assert p1 .shape == p1 .grad .shape
200
+ assert C11 .shape == C11 .grad .shape
201
+ assert C12 .shape == C12 .grad .shape
197
202
198
203
199
204
@pytest .skip_backend ("jax" , reason = "test very slow with jax backend" )
@@ -636,21 +641,26 @@ def test_fgw2_gradients():
636
641
637
642
if torch :
638
643
639
- p1 = torch .tensor (p , requires_grad = True )
640
- q1 = torch .tensor (q , requires_grad = True )
641
- C11 = torch .tensor (C1 , requires_grad = True )
642
- C12 = torch .tensor (C2 , requires_grad = True )
643
- M1 = torch .tensor (M , requires_grad = True )
644
-
645
- val = ot .fused_gromov_wasserstein2 (M1 , C11 , C12 , p1 , q1 )
646
-
647
- val .backward ()
648
-
649
- assert q1 .shape == q1 .grad .shape
650
- assert p1 .shape == p1 .grad .shape
651
- assert C11 .shape == C11 .grad .shape
652
- assert C12 .shape == C12 .grad .shape
653
- assert M1 .shape == M1 .grad .shape
644
+ devices = [torch .device ("cpu" )]
645
+ if torch .cuda .is_available ():
646
+ devices .append (torch .device ("cuda" ))
647
+ for device in devices :
648
+ p1 = torch .tensor (p , requires_grad = True , device = device )
649
+ q1 = torch .tensor (q , requires_grad = True , device = device )
650
+ C11 = torch .tensor (C1 , requires_grad = True , device = device )
651
+ C12 = torch .tensor (C2 , requires_grad = True , device = device )
652
+ M1 = torch .tensor (M , requires_grad = True , device = device )
653
+
654
+ val = ot .fused_gromov_wasserstein2 (M1 , C11 , C12 , p1 , q1 )
655
+
656
+ val .backward ()
657
+
658
+ assert val .device == p1 .device
659
+ assert q1 .shape == q1 .grad .shape
660
+ assert p1 .shape == p1 .grad .shape
661
+ assert C11 .shape == C11 .grad .shape
662
+ assert C12 .shape == C12 .grad .shape
663
+ assert M1 .shape == M1 .grad .shape
654
664
655
665
656
666
def test_fgw_barycenter (nx ):
0 commit comments