diff --git a/metric_learn/itml.py b/metric_learn/itml.py index b40145b6..4c154ad4 100644 --- a/metric_learn/itml.py +++ b/metric_learn/itml.py @@ -80,7 +80,8 @@ def fit(self, X, constraints, bounds=None): X : (n x d) data matrix each row corresponds to a single instance constraints : 4-tuple of arrays - (a,b,c,d) indices into X, such that d(X[a],X[b]) < d(X[c],X[d]) + (a,b,c,d) indices into X, with (a,b) specifying positive and (c,d) + negative pairs bounds : list (pos,neg) pairs, optional bounds on similarity, s.t. d(X[a],X[b]) < pos and d(X[c],X[d]) > neg """