
sinkhorn2 and its functorch.vmap compatibility #482

Open
@hmdolatabadi

Description


🚀 Feature

Making the ot.sinkhorn2 function compatible with functorch.vmap.

Motivation

I'm using the Python Optimal Transport (POT) library. I want to define a loss function that iterates over every sample in my batch and computes the Sinkhorn distance between that sample and its ground truth. Until now I have been using a for loop:

import ot

# Accumulate the Sinkhorn loss one sample at a time
loss = 0.
for i in range(len(P_batch)):
    loss += ot.sinkhorn2(P_batch[i].view(-1, 1), Q_batch[i].view(-1, 1), C, epsilon)

but this is far too slow for my application. From reading the functorch documentation, I expected to be able to vectorize it with vmap:

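from functorch import vmap

# Map over dim 0 of P and Q; the cost matrix C and epsilon are shared across the batch
losses = vmap(ot.sinkhorn2, in_dims=(0, 0, None, None))(P, Q, C, epsilon)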

However, when I wrap the function in vmap, I get the following error:

File /anaconda3/envs/my_env/lib/python3.8/site-packages/ot/bregman.py:505, in sinkhorn_knopp(a, b, M, reg, numItermax, stopThr, verbose, log, warn, warmstart, **kwargs)
    502 v = b / KtransposeU
    503 u = 1. / nx.dot(Kp, v)
--> 505 if (nx.any(KtransposeU == 0)
    506         or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
    507         or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
    508     # we have reached the machine precision
    509     # come back to previous solution and quit loop
    510     warnings.warn('Warning: numerical errors at iteration %d' % ii)
    511     u = uprev

RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow. We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 .
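The failure can be reproduced without POT. This is a hedged sketch, assuming a recent PyTorch that provides `torch.func.vmap` (with a fallback to the standalone `functorch.vmap` used in this issue); the two `safe_recip_*` functions are hypothetical illustrations of the pattern at `bregman.py:505`. A Python `if` on a tensor value fails under vmap, while the same "any element bad" logic written with `torch.where` is batched per sample.

```python
import torch

try:
    from torch.func import vmap  # PyTorch >= 2.0
except ImportError:  # older installs with the standalone functorch package
    from functorch import vmap


def safe_recip_branching(x):
    # Python `if` on a tensor value: works eagerly, but vmap cannot batch it
    if (x == 0).any():
        return torch.zeros_like(x)
    return 1.0 / x


def safe_recip_masked(x):
    # Same intent without a Python branch: compute both results, then select
    bad = (x == 0).any()
    return torch.where(bad, torch.zeros_like(x), 1.0 / x)


xs = torch.tensor([[1.0, 2.0], [0.0, 4.0]])

raised = False
try:
    vmap(safe_recip_branching)(xs)
except RuntimeError as err:
    raised = True
    print("branching version fails under vmap:", err)

out = vmap(safe_recip_masked)(xs)
print(out)
```

The masked version always evaluates `1.0 / x`, including the discarded division by zero; only the selection depends on the data, which is what vmap can handle.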

Pitch

It seems the data-dependent if statement in sinkhorn_knopp (the numerical-error check at bregman.py:505) would need to be replaced with a vmap-compatible alternative. Any help is appreciated.
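One possible workaround, sketched under assumptions rather than taken from POT: run the same Sinkhorn-Knopp updates as in the traceback (`v = b / (K^T u)`, `u = a / (K v)`) for a fixed number of iterations, and replace the `if ...: break` safeguard with `torch.where`, so nothing branches on tensor values. `sinkhorn2_fixed_iters` and `num_iter` are hypothetical names; unlike POT there is no `stopThr` convergence check, and the histograms are 1-D rather than `(n, 1)` columns.

```python
import torch

try:
    from torch.func import vmap  # PyTorch >= 2.0
except ImportError:  # older installs with the standalone functorch package
    from functorch import vmap


def sinkhorn2_fixed_iters(a, b, M, reg, num_iter=200):
    """Hypothetical helper (not part of POT): entropic OT cost <T, M>.

    a: (n,) source histogram, b: (m,) target histogram, M: (n, m) cost matrix.
    """
    K = torch.exp(-M / reg)
    u = torch.ones_like(a) / a.shape[0]
    v = torch.ones_like(b) / b.shape[0]
    for _ in range(num_iter):  # iteration count does not depend on tensor values
        v_new = b / (K.T @ u)
        u_new = a / (K @ v_new)
        # Instead of `if ...: break`, keep the previous iterate whenever the
        # update produced NaN/Inf; torch.where is a selection vmap can batch.
        ok = torch.isfinite(u_new).all() & torch.isfinite(v_new).all()
        u = torch.where(ok, u_new, u)
        v = torch.where(ok, v_new, v)
    T = u[:, None] * K * v[None, :]  # transport plan
    return (T * M).sum()


# Usage with random data: batch over P and Q, share C and epsilon
torch.manual_seed(0)
n, m, batch = 5, 3, 4
P = torch.rand(batch, n, dtype=torch.float64)
P = P / P.sum(dim=1, keepdim=True)
Q = torch.rand(batch, m, dtype=torch.float64)
Q = Q / Q.sum(dim=1, keepdim=True)
C = torch.cdist(torch.rand(n, 2, dtype=torch.float64),
                torch.rand(m, 2, dtype=torch.float64))
epsilon = 0.1

batched = vmap(sinkhorn2_fixed_iters, in_dims=(0, 0, None, None))(P, Q, C, epsilon)
looped = torch.stack([sinkhorn2_fixed_iters(P[i], Q[i], C, epsilon)
                      for i in range(batch)])
print(batched, looped)
```

Dropping the early-stopping test trades some possibly wasted iterations for vmap compatibility. Gradients through `torch.where` may still become NaN if the discarded branch overflows, so for small `reg` a log-domain variant could be more robust.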
