Skip to content

Commit 374a851

Browse files
author
William de Vazelhes
committed
Change labels y to be +1/-1 (cf. comment #92 (comment)).
1 parent 903f174 commit 374a851

File tree

4 files changed

+8
-13
lines changed

4 files changed

+8
-13
lines changed

metric_learn/constraints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,6 @@ def wrap_pairs(X, constraints):
107107
c = np.array(constraints[2])
108108
d = np.array(constraints[3])
109109
constraints = np.vstack((np.column_stack((a, b)), np.column_stack((c, d))))
110-
y = np.vstack([np.ones((len(a), 1)), np.zeros((len(c), 1))])
110+
y = np.vstack([np.ones((len(a), 1)), - np.ones((len(c), 1))])
111111
pairs = X[constraints]
112112
return pairs, y

metric_learn/itml.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,9 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
5454
def _process_pairs(self, pairs, y, bounds):
5555
pairs, y = check_X_y(pairs, y, accept_sparse=False,
5656
ensure_2d=False, allow_nd=True)
57-
y = y.astype(bool)
5857

5958
# check to make sure that no two constrained vectors are identical
60-
pos_pairs, neg_pairs = pairs[y], pairs[~y]
59+
pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1]
6160
pos_no_ident = vector_norm(pos_pairs[:, 0, :] - pos_pairs[:, 1, :]) > 1e-9
6261
pos_pairs = pos_pairs[pos_no_ident]
6362
neg_no_ident = vector_norm(neg_pairs[:, 0, :] - neg_pairs[:, 1, :]) > 1e-9
@@ -76,8 +75,7 @@ def _process_pairs(self, pairs, y, bounds):
7675
else:
7776
self.A_ = check_array(self.A0)
7877
pairs = np.vstack([pos_pairs, neg_pairs])
79-
y = np.hstack([np.ones(len(pos_pairs)), np.zeros(len(neg_pairs))])
80-
y = y.astype(bool)
78+
y = np.hstack([np.ones(len(pos_pairs)), - np.ones(len(neg_pairs))])
8179
return pairs, y
8280

8381

@@ -100,7 +98,7 @@ def fit(self, pairs, y, bounds=None):
10098
"""
10199
pairs, y = self._process_pairs(pairs, y, bounds)
102100
gamma = self.gamma
103-
pos_pairs, neg_pairs = pairs[y], pairs[~y]
101+
pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1]
104102
num_pos = len(pos_pairs)
105103
num_neg = len(neg_pairs)
106104
_lambda = np.zeros(num_pos + num_neg)

metric_learn/mmc.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,9 @@ def fit(self, pairs, y):
8383
def _process_pairs(self, pairs, y):
8484
pairs, y = check_X_y(pairs, y, accept_sparse=False,
8585
ensure_2d=False, allow_nd=True)
86-
y = y.astype(bool)
8786

8887
# check to make sure that no two constrained vectors are identical
89-
pos_pairs, neg_pairs = pairs[y], pairs[~y]
88+
pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1]
9089
pos_no_ident = vector_norm(pos_pairs[:, 0, :] - pos_pairs[:, 1, :]) > 1e-9
9190
pos_pairs = pos_pairs[pos_no_ident]
9291
neg_no_ident = vector_norm(neg_pairs[:, 0, :] - neg_pairs[:, 1, :]) > 1e-9
@@ -107,8 +106,7 @@ def _process_pairs(self, pairs, y):
107106
self.A_ = check_array(self.A0)
108107

109108
pairs = np.vstack([pos_pairs, neg_pairs])
110-
y = np.hstack([np.ones(len(pos_pairs)), np.zeros(len(neg_pairs))])
111-
y = y.astype(bool)
109+
y = np.hstack([np.ones(len(pos_pairs)), - np.ones(len(neg_pairs))])
112110
return pairs, y
113111

114112
def _fit_full(self, pairs, y):
@@ -128,7 +126,7 @@ def _fit_full(self, pairs, y):
128126
eps = 0.01 # error-bound of iterative projection on C1 and C2
129127
A = self.A_
130128

131-
pos_pairs, neg_pairs = pairs[y], pairs[~y]
129+
pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1]
132130

133131
# Create weight vector from similar samples
134132
pos_diff = pos_pairs[:, 0, :] - pos_pairs[:, 1, :]
@@ -244,7 +242,7 @@ def _fit_diag(self, pairs, y):
244242
dissimilar pairs
245243
"""
246244
num_dim = pairs.shape[2]
247-
pos_pairs, neg_pairs = pairs[y], pairs[~y]
245+
pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1]
248246
s_sum = np.sum((pos_pairs[:, 0, :] - pos_pairs[:, 1, :]) ** 2, axis=0)
249247

250248
it = 0

metric_learn/sdml.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,5 +135,4 @@ def fit(self, X, y, random_state=np.random):
135135
pos_neg = c.positive_negative_pairs(num_constraints,
136136
random_state=random_state)
137137
pairs, y = wrap_pairs(X, pos_neg)
138-
y = 2 * y - 1
139138
return SDML.fit(self, pairs, y)

0 commit comments

Comments
 (0)