Skip to content

Commit 308a5b4

Browse files
committed
Avoid catching all warnings as JAX throws deprecation
1 parent dfc79f0 commit 308a5b4

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

test/test_da.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,11 +175,9 @@ def test_sinkhorn_l1l2_transport_class(nx):
175175
Xs, ys, Xt, yt, yt_semi = nx.from_numpy(Xs, ys, Xt, yt, yt_semi)
176176

177177
otda = ot.da.SinkhornL1l2Transport(max_inner_iter=500)
178+
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
178179

179180
# test its computed
180-
with warnings.catch_warnings():
181-
warnings.simplefilter("error")
182-
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
183181
assert hasattr(otda, "cost_")
184182
assert hasattr(otda, "coupling_")
185183
assert hasattr(otda, "log_")

0 commit comments

Comments
 (0)