
UnbalancedSinkhornTransport fails to transform due to "nx.array_equal" #650

Open
@smestern

Description


Describe the bug

UnbalancedSinkhornTransport appears to fail on .transform() calls. The failure occurs at the nx.array_equal(self.xs_, Xs) call, which suggests that the transporter never sets its backend (nx or self.nx).

To Reproduce

  1. Initialize an UnbalancedSinkhornTransport object.
  2. Call fit on the source and target samples.
  3. Call transform separately on the source samples.

Code sample

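    import ot
    from ot.datasets import make_2D_samples_gauss

    OT = ot.da.UnbalancedSinkhornTransport()

    # Source samples: 1000 points from a 2D Gaussian centered at (10, 10)
    Xs = make_2D_samples_gauss(n=1000, m=10, sigma=[[2, 1], [1, 2]], random_state=42)
    # This Xt is overwritten below; the target is a shifted copy of Xs
    Xt = make_2D_samples_gauss(n=1000, m=5, sigma=[[2, 1], [1, 2]], random_state=42)
    Xs = Xs.astype('float32')
    Xt = Xs + 0.5
    Xt = Xt.astype('float32')

    OT.fit(Xs, Xt)    # fit the unbalanced coupling between Xs and Xt
    OT.transform(Xs)  # fails here, at nx.array_equal(self.xs_, Xs)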

Expected behavior

transform should return the transported Xs samples (the source samples mapped onto the target domain) instead of failing.
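As a possible stopgap for in-sample data, the transported points can in principle be recovered from the fitted coupling. As I understand it, POT's transport classes map samples seen during fit by barycentric mapping: each row of the coupling is normalized to sum to one and used to average the target samples. Below is a minimal plain-NumPy sketch of that computation. It assumes the fitted coupling is available as OT.coupling_ and that Xt is the target array passed to fit. barycentric_map is a hypothetical helper name, and this is a workaround sketch, not POT's own code.

```python
import numpy as np


def barycentric_map(G, Xt):
    """Hypothetical helper: map each source sample to the G-weighted mean of Xt.

    G  : (n_s, n_t) coupling matrix (assumed to be the fitted OT.coupling_)
    Xt : (n_t, d) target samples that were passed to fit
    """
    G = np.asarray(G)
    row_mass = G.sum(axis=1, keepdims=True)
    # Normalize each row of the coupling; rows with zero mass give 0/0 = NaN
    with np.errstate(divide="ignore", invalid="ignore"):
        transp = G / row_mass
    # Source points that received no mass are mapped to the origin
    transp[~np.isfinite(transp)] = 0.0
    return transp @ np.asarray(Xt)


# Tiny hand-made coupling for illustration (not the coupling from this issue)
G = np.array([[0.2, 0.2],
              [0.0, 0.5],
              [0.0, 0.0]])
Xt = np.array([[0.0, 0.0],
               [2.0, 4.0]])
print(barycentric_map(G, Xt))  # rows: [1, 2], [2, 4], [0, 0]
```

With POT, the call would then be barycentric_map(OT.coupling_, Xt) after OT.fit(Xs, Xt). This only covers the Xs used during fit, not out-of-sample points.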

Environment (please complete the following information):

  • OS (e.g. MacOS, Windows, Linux): Windows/Linux
  • Python version: 3.11
  • How was POT installed (source, pip, conda): pip
Output of the following code snippet:

    import platform; print(platform.platform())
    import sys; print("Python", sys.version)
    import numpy; print("NumPy", numpy.__version__)
    import scipy; print("SciPy", scipy.__version__)
    import ot; print("POT", ot.__version__)

    Linux-5.15.0-1061-realtime-x86_64-with-glibc2.35
    Python 3.11.3 (main, May 15 2023, 15:45:52) [GCC 11.2.0]
    NumPy 1.26.4
    SciPy 1.10.1
    POT 0.9.3

Additional context

Reproduced on POT 0.9.3 and 0.9.4.
