Skip to content

Commit a7ce443

Browse files
committed
update stuff
1 parent 7053e00 commit a7ce443

File tree

2 files changed

+21
-13
lines changed

2 files changed

+21
-13
lines changed

ot/da.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,12 +1058,13 @@ class SinkhornTransport(BaseTransport):
10581058
can occur with large metric values.
10591059
distribution_estimation : callable, optional (defaults to the uniform)
10601060
The kind of distribution estimation to employ
1061-
out_of_sample_map : string, optional (default="pooladian")
1061+
out_of_sample_map : string, optional (default="continuous")
10621062
The kind of out of sample mapping to apply to transport samples
10631063
from a domain into another one. Currently the only possible option is
10641064
"ferradans" which uses the nearest neighbor method proposed in :ref:`[6]
1065-
<references-sinkhorntransport>` while "pooladian" use the out of sample
1065+
<references-sinkhorntransport>` while "continuous" use the out of sample
10661066
method from :ref:`[66]
1067+
<references-sinkhorntransport>` and :ref:`[19]
10671068
<references-sinkhorntransport>`.
10681069
limit_max: float, optional (default=np.infty)
10691070
Controls the semi supervised mode. Transport between labeled source
@@ -1093,6 +1094,10 @@ class SinkhornTransport(BaseTransport):
10931094
Regularized discrete optimal transport. SIAM Journal on Imaging
10941095
Sciences, 7(3), 1853-1882.
10951096
1097+
.. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.
1098+
& Blondel, M. Large-scale Optimal Transport and Mapping Estimation.
1099+
International Conference on Learning Representation (2018)
1100+
10961101
.. [66] Pooladian, Aram-Alexandre, and Jonathan Niles-Weed. "Entropic
10971102
estimation of optimal transport maps." arXiv preprint
10981103
arXiv:2109.12004 (2021).
@@ -1103,9 +1108,9 @@ def __init__(self, reg_e=1., method="sinkhorn_log", max_iter=1000,
11031108
tol=10e-9, verbose=False, log=False,
11041109
metric="sqeuclidean", norm=None,
11051110
distribution_estimation=distribution_estimation_uniform,
1106-
out_of_sample_map='pooladian', limit_max=np.infty):
1111+
out_of_sample_map='continuous', limit_max=np.infty):
11071112

1108-
if out_of_sample_map not in ['ferradans', 'pooladian']:
1113+
if out_of_sample_map not in ['ferradans', 'continuous']:
11091114
raise ValueError('Unknown out_of_sample_map method')
11101115

11111116
self.reg_e = reg_e
@@ -1147,11 +1152,11 @@ class label
11471152

11481153
super(SinkhornTransport, self).fit(Xs, ys, Xt, yt)
11491154

1150-
if self.out_of_sample_map == 'pooladian':
1155+
if self.out_of_sample_map == 'continuous':
11511156
self.log = True
11521157
if not self.method == 'sinkhorn_log':
11531158
self.method = 'sinkhorn_log'
1154-
warnings.warn("The method has been set to 'sinkhorn_log' as it is the only method available for out_of_sample_map='pooladian'")
1159+
warnings.warn("The method has been set to 'sinkhorn_log' as it is the only method available for out_of_sample_map='continuous'")
11551160

11561161
# coupling estimation
11571162
returned_ = sinkhorn(
@@ -1198,7 +1203,7 @@ class label
11981203
if self.out_of_sample_map == 'ferradans':
11991204
return super(SinkhornTransport, self).transform(Xs, ys, Xt, yt, batch_size)
12001205

1201-
else: # self.out_of_sample_map == 'pooladian':
1206+
else: # self.out_of_sample_map == 'continuous':
12021207

12031208
# check the necessary inputs parameters are here
12041209
g = self.log_['log_v']
@@ -1243,7 +1248,7 @@ class label
12431248
if self.out_of_sample_map == 'ferradans':
12441249
return super(SinkhornTransport, self).inverse_transform(Xs, ys, Xt, yt, batch_size)
12451250

1246-
else: # self.out_of_sample_map == 'pooladian':
1251+
else: # self.out_of_sample_map == 'continuous':
12471252

12481253
f = self.log_['log_u']
12491254

test/test_da.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -354,11 +354,14 @@ def test_sinkhorn_transport_class(nx):
354354
assert_equal(transp_Xt.shape, Xt.shape)
355355

356356
# test diffeernt transform
357-
otda = ot.da.SinkhornTransport(out_of_sample_map='pooladian')
358-
transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt)
359-
assert_equal(transp_Xs.shape, Xs.shape)
360-
transp_Xt = otda.inverse_transform(Xt=Xt)
361-
assert_equal(transp_Xt.shape, Xt.shape)
357+
otda = ot.da.SinkhornTransport(out_of_sample_map='continuous')
358+
transp_Xs2 = otda.fit_transform(Xs=Xs, Xt=Xt)
359+
assert_equal(transp_Xs2.shape, Xs.shape)
360+
transp_Xt2 = otda.inverse_transform(Xt=Xt)
361+
assert_equal(transp_Xt2.shape, Xt.shape)
362+
363+
np.testing.assert_almost_equal(nx.to_numpy(transp_Xs), nx.to_numpy(transp_Xs2), decimal=5)
364+
np.testing.assert_almost_equal(nx.to_numpy(transp_Xt), nx.to_numpy(transp_Xt2), decimal=5)
362365

363366
with pytest.raises(ValueError):
364367
otda = ot.da.SinkhornTransport(out_of_sample_map='unknown')

0 commit comments

Comments
 (0)