|
20 | 20 | from sklearn.utils.validation import check_array, check_X_y
|
21 | 21 |
|
22 | 22 | from .base_metric import BaseMetricLearner
|
23 |
| -from .constraints import Constraints |
| 23 | +from .constraints import Constraints, wrap_pairs |
24 | 24 | from ._util import vector_norm
|
25 | 25 |
|
26 | 26 |
|
@@ -51,52 +51,63 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
|
51 | 51 | self.A0 = A0
|
52 | 52 | self.verbose = verbose
|
53 | 53 |
|
54 |
| - def _process_inputs(self, X, constraints, bounds): |
55 |
| - self.X_ = X = check_array(X) |
| 54 | + def _process_pairs(self, pairs, y, bounds): |
| 55 | + pairs, y = check_X_y(pairs, y, accept_sparse=False, |
| 56 | + ensure_2d=False, allow_nd=True) |
| 57 | + |
56 | 58 | # check to make sure that no two constrained vectors are identical
|
57 |
| - a,b,c,d = constraints |
58 |
| - no_ident = vector_norm(X[a] - X[b]) > 1e-9 |
59 |
| - a, b = a[no_ident], b[no_ident] |
60 |
| - no_ident = vector_norm(X[c] - X[d]) > 1e-9 |
61 |
| - c, d = c[no_ident], d[no_ident] |
| 59 | + pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1] |
| 60 | + pos_no_ident = vector_norm(pos_pairs[:, 0, :] - pos_pairs[:, 1, :]) > 1e-9 |
| 61 | + pos_pairs = pos_pairs[pos_no_ident] |
| 62 | + neg_no_ident = vector_norm(neg_pairs[:, 0, :] - neg_pairs[:, 1, :]) > 1e-9 |
| 63 | + neg_pairs = neg_pairs[neg_no_ident] |
62 | 64 | # init bounds
|
63 | 65 | if bounds is None:
|
| 66 | + X = np.vstack({tuple(row) for row in pairs.reshape(-1, pairs.shape[2])}) |
64 | 67 | self.bounds_ = np.percentile(pairwise_distances(X), (5, 95))
|
65 | 68 | else:
|
66 | 69 | assert len(bounds) == 2
|
67 | 70 | self.bounds_ = bounds
|
68 | 71 | self.bounds_[self.bounds_==0] = 1e-9
|
69 | 72 | # init metric
|
70 | 73 | if self.A0 is None:
|
71 |
| - self.A_ = np.identity(X.shape[1]) |
| 74 | + self.A_ = np.identity(pairs.shape[2]) |
72 | 75 | else:
|
73 | 76 | self.A_ = check_array(self.A0)
|
74 |
| - return a,b,c,d |
| 77 | + pairs = np.vstack([pos_pairs, neg_pairs]) |
| 78 | + y = np.hstack([np.ones(len(pos_pairs)), - np.ones(len(neg_pairs))]) |
| 79 | + return pairs, y |
| 80 | + |
75 | 81 |
|
76 |
| - def fit(self, X, constraints, bounds=None): |
| 82 | + def fit(self, pairs, y, bounds=None): |
77 | 83 | """Learn the ITML model.
|
78 | 84 |
|
79 | 85 | Parameters
|
80 | 86 | ----------
|
81 |
| - X : (n x d) data matrix |
82 |
| - each row corresponds to a single instance |
83 |
| - constraints : 4-tuple of arrays |
84 |
| - (a,b,c,d) indices into X, with (a,b) specifying positive and (c,d) |
85 |
| - negative pairs |
| 87 | + pairs: array-like, shape=(n_constraints, 2, n_features) |
| 88 | + Array of pairs. Each row corresponds to two points. |
| 89 | + y: array-like, of shape (n_constraints,) |
| 90 | + Labels of constraints. Should be -1 for dissimilar pair, 1 for similar. |
86 | 91 | bounds : list (pos,neg) pairs, optional
|
87 | 92 | bounds on similarity, s.t. d(X[a],X[b]) < pos and d(X[c],X[d]) > neg
|
| 93 | +
|
| 94 | + Returns |
| 95 | + ------- |
| 96 | + self : object |
| 97 | + Returns the instance. |
88 | 98 | """
|
89 |
| - a,b,c,d = self._process_inputs(X, constraints, bounds) |
| 99 | + pairs, y = self._process_pairs(pairs, y, bounds) |
90 | 100 | gamma = self.gamma
|
91 |
| - num_pos = len(a) |
92 |
| - num_neg = len(c) |
| 101 | + pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1] |
| 102 | + num_pos = len(pos_pairs) |
| 103 | + num_neg = len(neg_pairs) |
93 | 104 | _lambda = np.zeros(num_pos + num_neg)
|
94 | 105 | lambdaold = np.zeros_like(_lambda)
|
95 | 106 | gamma_proj = 1. if gamma is np.inf else gamma/(gamma+1.)
|
96 | 107 | pos_bhat = np.zeros(num_pos) + self.bounds_[0]
|
97 | 108 | neg_bhat = np.zeros(num_neg) + self.bounds_[1]
|
98 |
| - pos_vv = self.X_[a] - self.X_[b] |
99 |
| - neg_vv = self.X_[c] - self.X_[d] |
| 109 | + pos_vv = pos_pairs[:, 0, :] - pos_pairs[:, 1, :] |
| 110 | + neg_vv = neg_pairs[:, 0, :] - neg_pairs[:, 1, :] |
100 | 111 | A = self.A_
|
101 | 112 |
|
102 | 113 | for it in xrange(self.max_iter):
|
@@ -195,4 +206,5 @@ def fit(self, X, y, random_state=np.random):
|
195 | 206 | random_state=random_state)
|
196 | 207 | pos_neg = c.positive_negative_pairs(num_constraints,
|
197 | 208 | random_state=random_state)
|
198 |
| - return ITML.fit(self, X, pos_neg, bounds=self.bounds) |
| 209 | + pairs, y = wrap_pairs(X, pos_neg) |
| 210 | + return ITML.fit(self, pairs, y, bounds=self.bounds) |
0 commit comments