Skip to content

Commit ba688e4

Browse files
committed
tf.math.invert_permutation
1 parent ae84df5 commit ba688e4

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

ot/backend.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2794,7 +2794,11 @@ def unique(self, a, return_inverse=False):
27942794
y, idx = tf.unique(tf.reshape(a, [-1]))
27952795
sort_idx = tf.argsort(y)
27962796
y_prime = tf.gather(y, sort_idx)
2797-
return y_prime if not return_inverse else (y_prime, tf.gather(y, idx))
2797+
if return_inverse:
2798+
inv_sort_idx = tf.math.invert_permutation(sort_idx)
2799+
return y_prime, tf.gather(inv_sort_idx, idx)
2800+
else:
2801+
return y_prime
27982802

27992803
def logsumexp(self, a, axis=None):
28002804
return tf.math.reduce_logsumexp(a, axis=axis)

0 commit comments

Comments
 (0)