Skip to content

Commit ae84df5

Browse files
committed
Fix sort indicies for unique(), switched to tf.gather
1 parent 5ac5553 commit ae84df5

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

ot/backend.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2793,7 +2793,8 @@ def diag(self, a, k=0):
27932793
def unique(self, a, return_inverse=False):
27942794
y, idx = tf.unique(tf.reshape(a, [-1]))
27952795
sort_idx = tf.argsort(y)
2796-
return y[sort_idx] if not return_inverse else (y[sort_idx], sort_idx[idx])
2796+
y_prime = tf.gather(y, sort_idx)
2797+
return y_prime if not return_inverse else (y_prime, tf.gather(y, idx))
27972798

27982799
def logsumexp(self, a, axis=None):
27992800
return tf.math.reduce_logsumexp(a, axis=axis)

0 commit comments

Comments
 (0)