 def test_class_jax_tf():
+    from ot.backend import tf
+
     backends = []
-    from ot.backend import jax, tf
-    if jax:
-        backends.append(ot.backend.JaxBackend())
     if tf:
         backends.append(ot.backend.TensorflowBackend())
@@ -70,7 +69,6 @@ def test_log_da(nx, class_to_test):
     assert hasattr(otda, "log_")


-@pytest.skip_backend("jax")
 @pytest.skip_backend("tf")
 def test_sinkhorn_lpl1_transport_class(nx):
     """test_sinkhorn_transport
@@ -79,10 +77,13 @@ def test_sinkhorn_lpl1_transport_class(nx):
     ns = 50
     nt = 50

-    Xs, ys = make_data_classif('3gauss', ns)
-    Xt, yt = make_data_classif('3gauss2', nt)
+    Xs, ys = make_data_classif('3gauss', ns, random_state=42)
+    Xt, yt = make_data_classif('3gauss2', nt, random_state=43)
+    # prepare semi-supervised labels
+    yt_semi = np.copy(yt)
+    yt_semi[np.arange(0, nt, 2)] = -1

-    Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+    Xs, ys, Xt, yt, yt_semi = nx.from_numpy(Xs, ys, Xt, yt, yt_semi)

     otda = ot.da.SinkhornLpl1Transport()

@@ -109,7 +110,7 @@ def test_sinkhorn_lpl1_transport_class(nx):
     transp_Xs = otda.transform(Xs=Xs)
     assert_equal(transp_Xs.shape, Xs.shape)

-    Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0])
+    Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1, random_state=44)[0])
     transp_Xs_new = otda.transform(Xs_new)

     # check that the oos method is working
@@ -119,7 +120,7 @@ def test_sinkhorn_lpl1_transport_class(nx):
     transp_Xt = otda.inverse_transform(Xt=Xt)
     assert_equal(transp_Xt.shape, Xt.shape)

-    Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0])
+    Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1, random_state=45)[0])
     transp_Xt_new = otda.inverse_transform(Xt=Xt_new)

     # check that the oos method is working
@@ -142,10 +143,12 @@ def test_sinkhorn_lpl1_transport_class(nx):
     # test unsupervised vs semi-supervised mode
     otda_unsup = ot.da.SinkhornLpl1Transport()
     otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
+    assert np.all(np.isfinite(nx.to_numpy(otda_unsup.coupling_))), "unsup coupling is finite"
     n_unsup = nx.sum(otda_unsup.cost_)

     otda_semi = ot.da.SinkhornLpl1Transport()
-    otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+    otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt_semi)
+    assert np.all(np.isfinite(nx.to_numpy(otda_semi.coupling_))), "semi coupling is finite"
     assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
     n_semisup = nx.sum(otda_semi.cost_)

@@ -944,3 +947,42 @@ def df2(G):

     assert np.allclose(f(G), f2(G))
     assert np.allclose(df(G), df2(G))
+
+
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_sinkhorn_lpl1_vectorization(nx):
+    n_samples, n_labels = 150, 3
+    rng = np.random.RandomState(42)
+    M = rng.rand(n_samples, n_samples)
+    labels_a = rng.randint(n_labels, size=(n_samples,))
+    M, labels_a = nx.from_numpy(M), nx.from_numpy(labels_a)
+
+    # hard-coded params from the original code
+    p, epsilon = 0.5, 1e-3
+    T = nx.from_numpy(rng.rand(n_samples, n_samples))
+
+    def unvectorized(transp):
+        indices_labels = []
+        classes = nx.unique(labels_a)
+        for c in classes:
+            idxc, = nx.where(labels_a == c)
+            indices_labels.append(idxc)
+        W = nx.ones(M.shape, type_as=M)
+        for (i, c) in enumerate(classes):
+            majs = nx.sum(transp[indices_labels[i]], axis=0)
+            majs = p * ((majs + epsilon) ** (p - 1))
+            W[indices_labels[i]] = majs
+        return W
+
+    def vectorized(transp):
+        labels_u, labels_idx = nx.unique(labels_a, return_inverse=True)
+        n_labels = labels_u.shape[0]
+        unroll_labels_idx = nx.eye(n_labels, type_as=transp)[labels_idx]
+        W = nx.repeat(transp.T[:, :, None], n_labels, axis=2) * unroll_labels_idx[None, :, :]
+        W = nx.sum(W, axis=1)
+        W = p * ((W + epsilon) ** (p - 1))
+        W = nx.dot(W, unroll_labels_idx.T)
+        return W.T
+
+    assert np.allclose(unvectorized(T), vectorized(T))
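Note: the per-class weight update compared in test_sinkhorn_lpl1_vectorization can be seen more directly with a one-hot matrix product. Below is a minimal plain-NumPy sketch of that equivalence (it does not use POT's nx backend wrapper); transp, labels_a, p and epsilon mirror the test, while the small sizes and helper names are illustrative assumptions only.

import numpy as np

rng = np.random.RandomState(0)
n_s, n_t, n_classes = 6, 5, 3
transp = rng.rand(n_s, n_t)                  # coupling matrix
labels_a = rng.randint(n_classes, size=n_s)  # source labels
p, epsilon = 0.5, 1e-3

# loop version: one masked row-sum of transp per class
W_loop = np.ones_like(transp)
for c in np.unique(labels_a):
    majs = transp[labels_a == c].sum(axis=0)
    W_loop[labels_a == c] = p * (majs + epsilon) ** (p - 1)

# vectorized version: one-hot encode the labels, then two matrix products
one_hot = np.eye(n_classes)[labels_a]        # (n_s, n_classes)
class_sums = one_hot.T @ transp              # (n_classes, n_t), row-sums per class
W_vec = one_hot @ (p * (class_sums + epsilon) ** (p - 1))

assert np.allclose(W_loop, W_vec)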