Describe the bug
Per the documentation of ot.emd2(), uniform weights are used if empty lists/arrays are passed as the histogram arguments. However, doing so with the JAX backend causes a broadcasting error.
To Reproduce
Simulate some data first:
import jax
from jax import numpy as jnp
key = jax.random.PRNGKey(1)
x = jax.random.normal(key, (100, 2))
y = jax.random.normal(key, (100, 2))
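(As an aside, reusing the same key makes x and y identical samples; splitting the key gives independent draws. This does not affect the bug, but a cleaner simulation would be:)

```python
import jax
from jax import numpy as jnp

key = jax.random.PRNGKey(1)
# Split the key so x and y are independent rather than identical
key_x, key_y = jax.random.split(key)
x = jax.random.normal(key_x, (100, 2))
y = jax.random.normal(key_y, (100, 2))
```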
With the numpy backend, the following works without an issue:
import numpy as np
import ot
from opt_einsum import contract

M = contract('mi,ni->mn', x, y, backend='numpy') ** 2.
emt = np.empty((0,))  # empty array -> uniform weights per the docs
Wass_dis = ot.emd2(emt, emt, M=M)
Wass_dis
However, an error occurs once we switch to jnp:
M = contract('mi,ni->mn', x, y, backend='jax') ** 2.
emt = jnp.empty((0))
Wass_dis = ot.emd2(emt, emt, M=M)
Wass_dis
Partial error message:
File [c:\ProgramData\anaconda3\Lib\site-packages\ot\lp\__init__.py:567](file:///C:/ProgramData/anaconda3/Lib/site-packages/ot/lp/__init__.py:567), in emd2.<locals>.f(b)
559 warnings.warn(
560 "Input histogram consists of integer. The transport plan will be "
561 "casted accordingly, possibly resulting in a loss of precision. "
(...)
564 stacklevel=2
565 )
566 G = nx.from_numpy(G, type_as=type_as)
--> 567 cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as),
568 (a0, b0, M0), (nx.from_numpy(u - np.mean(u), type_as=type_as),
569 nx.from_numpy(v - np.mean(v), type_as=type_as), G))
571 check_result(result_code)
572 return cost
File [c:\ProgramData\anaconda3\Lib\site-packages\ot\backend.py:1392](file:///C:/ProgramData/anaconda3/Lib/site-packages/ot/backend.py:1392), in JaxBackend.set_gradients(self, val, inputs, grads)
1389 ravelled_inputs, _ = ravel_pytree(inputs)
1390 ravelled_grads, _ = ravel_pytree(grads)
-> 1392 aux = jnp.sum(ravelled_inputs * ravelled_grads) / 2
1393 aux = aux - jax.lax.stop_gradient(aux)
1395 val, = jax.tree_map(lambda z: z + aux, (val,))
File [c:\ProgramData\anaconda3\Lib\site-packages\jax\_src\numpy\array_methods.py:256](file:///C:/ProgramData/anaconda3/Lib/site-packages/jax/_src/numpy/array_methods.py:256), in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
254 args = (other, self) if swap else (self, other)
255 if isinstance(other, _accepted_binop_types):
--> 256 return binary_op(*args)
257 # Note: don't use isinstance here, because we don't want to raise for
258 # subclasses, e.g. NamedTuple objects that may override operators.
259 if type(other) in _rejected_binop_types:
[... skipping hidden 12 frame]
File [c:\ProgramData\anaconda3\Lib\site-packages\jax\_src\numpy\ufuncs.py:97](file:///C:/ProgramData/anaconda3/Lib/site-packages/jax/_src/numpy/ufuncs.py:97), in _maybe_bool_binop.<locals>.fn(x1, x2)
95 def fn(x1, x2, /):
96 x1, x2 = promote_args(numpy_fn.__name__, x1, x2)
---> 97 return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)
[... skipping hidden 7 frame]
File [c:\ProgramData\anaconda3\Lib\site-packages\jax\_src\lax\lax.py:1591](file:///C:/ProgramData/anaconda3/Lib/site-packages/jax/_src/lax/lax.py:1591), in broadcasting_shape_rule(name, *avals)
1589 result_shape.append(non_1s[0])
1590 else:
-> 1591 raise TypeError(f'{name} got incompatible shapes for broadcasting: '
1592 f'{", ".join(map(str, map(tuple, shapes)))}.')
1594 return tuple(result_shape)
TypeError: mul got incompatible shapes for broadcasting: (10000,), (10200,).
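The shapes in the error are consistent with the empty histograms: ravel_pytree flattens the inputs (a0, b0, M0) to only the 100x100 = 10000 entries of M0 (a0 and b0 are empty), while the gradient tuple still carries the full-length duals u and v (100 entries each) plus G, i.e. 10200 entries. A minimal NumPy sketch of the mismatch, mimicking ravel_pytree by concatenating raveled arrays (the shapes are my reconstruction, not taken from the POT source):

```python
import numpy as np

n = 100
# Inputs as emd2 receives them: empty weight arrays plus the cost matrix
inputs = [np.empty(0), np.empty(0), np.zeros((n, n))]
# Gradients as set_gradients builds them: full-size duals u, v plus plan G
grads = [np.zeros(n), np.zeros(n), np.zeros((n, n))]

flat_inputs = np.concatenate([a.ravel() for a in inputs])
flat_grads = np.concatenate([g.ravel() for g in grads])
print(flat_inputs.shape, flat_grads.shape)  # (10000,) (10200,)
```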
Possible solution:
The problem can be avoided if we generate the uniform weights ourselves:
M = contract('mi,ni->mn', x, y, backend='jax') ** 2.
emt0 = jnp.ones((M.shape[0],)) / M.shape[0]
emt1 = jnp.ones((M.shape[1],)) / M.shape[1]
Wass_dis = ot.emd2(emt0, emt1, M=M)
Wass_dis # correct result
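POT also ships a helper for this: ot.unif(n) returns the uniform histogram np.ones(n) / n, so the workaround can equivalently be written as ot.emd2(ot.unif(M.shape[0]), ot.unif(M.shape[1]), M=M). A plain-NumPy sketch of what that helper produces (written without importing POT so it runs standalone):

```python
import numpy as np

def uniform_hist(n):
    # Same vector as ot.unif(n): a length-n histogram summing to 1
    return np.ones(n) / n

w = uniform_hist(100)
print(w.shape)  # (100,)
```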
Environment (please complete the following information):
- OS (e.g. MacOS, Windows, Linux): Windows
- Python version: 3.11.4
- How was POT installed (source, pip, conda): pip
Output of the following code snippet:
import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
Windows-10-10.0.22621-SP0
Python 3.11.4 | packaged by Anaconda, Inc. | (main, Jul 5 2023, 13:38:37) [MSC v.1916 64 bit (AMD64)]
NumPy 1.24.3
SciPy 1.10.1
POT 0.9.1