Skip to content

Commit b336eba

Browse files
wdevazelhesbellet
authored and committed
[MRG] MAINT: remove variables not needed to store (#159)
* MAINT: remove variables not needed to store
* Address review #159 (review)
* DOC: add more precise docstring
* DOC: make description clearer
1 parent d3620bb commit b336eba

File tree

7 files changed

+105
-49
lines changed

7 files changed

+105
-49
lines changed

metric_learn/itml.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,9 @@ def _fit(self, pairs, y, bounds=None):
7373
self.bounds_[self.bounds_==0] = 1e-9
7474
# init metric
7575
if self.A0 is None:
76-
self.A_ = np.identity(pairs.shape[2])
76+
A = np.identity(pairs.shape[2])
7777
else:
78-
self.A_ = check_array(self.A0)
78+
A = check_array(self.A0, copy=True)
7979
gamma = self.gamma
8080
pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1]
8181
num_pos = len(pos_pairs)
@@ -87,7 +87,6 @@ def _fit(self, pairs, y, bounds=None):
8787
neg_bhat = np.zeros(num_neg) + self.bounds_[1]
8888
pos_vv = pos_pairs[:, 0, :] - pos_pairs[:, 1, :]
8989
neg_vv = neg_pairs[:, 0, :] - neg_pairs[:, 1, :]
90-
A = self.A_
9190

9291
for it in xrange(self.max_iter):
9392
# update positives
@@ -125,7 +124,7 @@ def _fit(self, pairs, y, bounds=None):
125124
print('itml converged at iter: %d, conv = %f' % (it, conv))
126125
self.n_iter_ = it
127126

128-
self.transformer_ = transformer_from_metric(self.A_)
127+
self.transformer_ = transformer_from_metric(A)
129128
return self
130129

131130

@@ -134,6 +133,18 @@ class ITML(_BaseITML, _PairsClassifierMixin):
134133
135134
Attributes
136135
----------
136+
bounds_ : array-like, shape=(2,)
137+
Bounds on similarity, aside from slack variables, s.t.
138+
``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
139+
and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
140+
dissimilar points ``c`` and ``d``, with ``d`` the learned distance. If
141+
not provided at initialization, bounds_[0] and bounds_[1] are set at
142+
train time to the 5th and 95th percentile of the pairwise distances among
143+
all points present in the input `pairs`.
144+
145+
n_iter_ : `int`
146+
The number of iterations the solver has run.
147+
137148
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
138149
The linear transformation ``L`` deduced from the learned Mahalanobis
139150
metric (See function `transformer_from_metric`.)
@@ -151,8 +162,14 @@ def fit(self, pairs, y, bounds=None):
151162
preprocessor.
152163
y: array-like, of shape (n_constraints,)
153164
Labels of constraints. Should be -1 for dissimilar pair, 1 for similar.
154-
bounds : list (pos,neg) pairs, optional
155-
bounds on similarity, s.t. d(X[a],X[b]) < pos and d(X[c],X[d]) > neg
165+
bounds : `list` of two numbers, optional
166+
Bounds on similarity, aside from slack variables, s.t.
167+
``d(a, b) < bounds[0]`` for all given pairs of similar points ``a``
168+
and ``b``, and ``d(c, d) > bounds[1]`` for all given pairs of
169+
dissimilar points ``c`` and ``d``, with ``d`` the learned distance.
170+
If not provided, bounds[0] and bounds[1] will be
171+
set to the 5th and 95th percentile of the pairwise distances among all
172+
points present in the input `pairs`.
156173
157174
Returns
158175
-------
@@ -167,6 +184,18 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
167184
168185
Attributes
169186
----------
187+
bounds_ : array-like, shape=(2,)
188+
Bounds on similarity, aside from slack variables, s.t.
189+
``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
190+
and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
191+
dissimilar points ``c`` and ``d``, with ``d`` the learned distance.
192+
If not provided at initialization, bounds_[0] and bounds_[1] are set at
193+
train time to the 5th and 95th percentile of the pairwise distances
194+
among all points in the training data `X`.
195+
196+
n_iter_ : `int`
197+
The number of iterations the solver has run.
198+
170199
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
171200
The linear transformation ``L`` deduced from the learned Mahalanobis
172201
metric (See function `transformer_from_metric`.)
@@ -193,8 +222,14 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
193222
be removed in 0.6.0.
194223
num_constraints: int, optional
195224
number of constraints to generate
196-
bounds : list (pos,neg) pairs, optional
197-
bounds on similarity, s.t. d(X[a],X[b]) < pos and d(X[c],X[d]) > neg
225+
bounds : `list` of two numbers, optional
226+
Bounds on similarity, aside from slack variables, s.t.
227+
``d(a, b) < bounds[0]`` for all given pairs of similar points ``a``
228+
and ``b``, and ``d(c, d) > bounds[1]`` for all given pairs of
229+
dissimilar points ``c`` and ``d``, with ``d`` the learned distance.
230+
If not provided at initialization, bounds[0] and bounds[1] will be
231+
set to the 5th and 95th percentile of the pairwise distances among all
232+
points in the training data `X`.
198233
A0 : (d x d) matrix, optional
199234
initial regularization matrix, defaults to identity
200235
verbose : bool, optional

metric_learn/lmnn.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,20 +60,20 @@ def fit(self, X, y):
6060
X, y = self._prepare_inputs(X, y, dtype=float,
6161
ensure_min_samples=2)
6262
num_pts, num_dims = X.shape
63-
unique_labels, self.label_inds_ = np.unique(y, return_inverse=True)
64-
if len(self.label_inds_) != num_pts:
63+
unique_labels, label_inds = np.unique(y, return_inverse=True)
64+
if len(label_inds) != num_pts:
6565
raise ValueError('Must have one label per point.')
6666
self.labels_ = np.arange(len(unique_labels))
6767
if self.use_pca:
6868
warnings.warn('use_pca does nothing for the python_LMNN implementation')
6969
self.transformer_ = np.eye(num_dims)
70-
required_k = np.bincount(self.label_inds_).min()
70+
required_k = np.bincount(label_inds).min()
7171
if self.k > required_k:
7272
raise ValueError('not enough class labels for specified k'
7373
' (smallest class has %d)' % required_k)
7474

75-
target_neighbors = self._select_targets(X)
76-
impostors = self._find_impostors(target_neighbors[:, -1], X)
75+
target_neighbors = self._select_targets(X, label_inds)
76+
impostors = self._find_impostors(target_neighbors[:, -1], X, label_inds)
7777
if len(impostors) == 0:
7878
# L has already been initialized to an identity matrix
7979
return
@@ -196,23 +196,23 @@ def _loss_grad(self, X, L, dfG, impostors, it, k, reg, target_neighbors, df,
196196
objective += G.flatten().dot(L.T.dot(L).flatten())
197197
return G, objective, total_active, df, a1, a2
198198

199-
def _select_targets(self, X):
199+
def _select_targets(self, X, label_inds):
200200
target_neighbors = np.empty((X.shape[0], self.k), dtype=int)
201201
for label in self.labels_:
202-
inds, = np.nonzero(self.label_inds_ == label)
202+
inds, = np.nonzero(label_inds == label)
203203
dd = euclidean_distances(X[inds], squared=True)
204204
np.fill_diagonal(dd, np.inf)
205205
nn = np.argsort(dd)[..., :self.k]
206206
target_neighbors[inds] = inds[nn]
207207
return target_neighbors
208208

209-
def _find_impostors(self, furthest_neighbors, X):
209+
def _find_impostors(self, furthest_neighbors, X, label_inds):
210210
Lx = self.transform(X)
211211
margin_radii = 1 + _inplace_paired_L2(Lx[furthest_neighbors], Lx)
212212
impostors = []
213213
for label in self.labels_[:-1]:
214-
in_inds, = np.nonzero(self.label_inds_ == label)
215-
out_inds, = np.nonzero(self.label_inds_ > label)
214+
in_inds, = np.nonzero(label_inds == label)
215+
out_inds, = np.nonzero(label_inds > label)
216216
dist = euclidean_distances(Lx[out_inds], Lx[in_inds], squared=True)
217217
i1,j1 = np.nonzero(dist < margin_radii[out_inds][:,None])
218218
i2,j2 = np.nonzero(dist < margin_radii[in_inds])
@@ -265,6 +265,9 @@ class LMNN(_base_LMNN):
265265
266266
Attributes
267267
----------
268+
n_iter_ : `int`
269+
The number of iterations the solver has run.
270+
268271
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
269272
The learned linear transformation ``L``.
270273
"""

metric_learn/lsml.py

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -50,32 +50,32 @@ def _fit(self, quadruplets, y=None, weights=None):
5050
type_of_inputs='tuples')
5151

5252
# check to make sure that no two constrained vectors are identical
53-
self.vab_ = quadruplets[:, 0, :] - quadruplets[:, 1, :]
54-
self.vcd_ = quadruplets[:, 2, :] - quadruplets[:, 3, :]
55-
if self.vab_.shape != self.vcd_.shape:
53+
vab = quadruplets[:, 0, :] - quadruplets[:, 1, :]
54+
vcd = quadruplets[:, 2, :] - quadruplets[:, 3, :]
55+
if vab.shape != vcd.shape:
5656
raise ValueError('Constraints must have same length')
5757
if weights is None:
58-
self.w_ = np.ones(self.vab_.shape[0])
58+
self.w_ = np.ones(vab.shape[0])
5959
else:
6060
self.w_ = weights
6161
self.w_ /= self.w_.sum() # weights must sum to 1
6262
if self.prior is None:
6363
X = np.vstack({tuple(row) for row in
6464
quadruplets.reshape(-1, quadruplets.shape[2])})
65-
self.prior_inv_ = np.atleast_2d(np.cov(X, rowvar=False))
66-
self.M_ = np.linalg.inv(self.prior_inv_)
65+
prior_inv = np.atleast_2d(np.cov(X, rowvar=False))
66+
M = np.linalg.inv(prior_inv)
6767
else:
68-
self.M_ = self.prior
69-
self.prior_inv_ = np.linalg.inv(self.prior)
68+
M = self.prior
69+
prior_inv = np.linalg.inv(self.prior)
7070

7171
step_sizes = np.logspace(-10, 0, 10)
7272
# Keep track of the best step size and the loss at that step.
7373
l_best = 0
74-
s_best = self._total_loss(self.M_)
74+
s_best = self._total_loss(M, vab, vcd, prior_inv)
7575
if self.verbose:
7676
print('initial loss', s_best)
7777
for it in xrange(1, self.max_iter+1):
78-
grad = self._gradient(self.M_)
78+
grad = self._gradient(M, vab, vcd, prior_inv)
7979
grad_norm = scipy.linalg.norm(grad)
8080
if grad_norm < self.tol:
8181
break
@@ -84,10 +84,10 @@ def _fit(self, quadruplets, y=None, weights=None):
8484
M_best = None
8585
for step_size in step_sizes:
8686
step_size /= grad_norm
87-
new_metric = self.M_ - step_size * grad
87+
new_metric = M - step_size * grad
8888
w, v = scipy.linalg.eigh(new_metric)
8989
new_metric = v.dot((np.maximum(w, 1e-8) * v).T)
90-
cur_s = self._total_loss(new_metric)
90+
cur_s = self._total_loss(new_metric, vab, vcd, prior_inv)
9191
if cur_s < s_best:
9292
l_best = step_size
9393
s_best = cur_s
@@ -96,36 +96,36 @@ def _fit(self, quadruplets, y=None, weights=None):
9696
print('iter', it, 'cost', s_best, 'best step', l_best * grad_norm)
9797
if M_best is None:
9898
break
99-
self.M_ = M_best
99+
M = M_best
100100
else:
101101
if self.verbose:
102102
print("Didn't converge after", it, "iterations. Final loss:", s_best)
103103
self.n_iter_ = it
104104

105-
self.transformer_ = transformer_from_metric(self.M_)
105+
self.transformer_ = transformer_from_metric(M)
106106
return self
107107

108-
def _comparison_loss(self, metric):
109-
dab = np.sum(self.vab_.dot(metric) * self.vab_, axis=1)
110-
dcd = np.sum(self.vcd_.dot(metric) * self.vcd_, axis=1)
108+
def _comparison_loss(self, metric, vab, vcd):
109+
dab = np.sum(vab.dot(metric) * vab, axis=1)
110+
dcd = np.sum(vcd.dot(metric) * vcd, axis=1)
111111
violations = dab > dcd
112112
return self.w_[violations].dot((np.sqrt(dab[violations]) -
113113
np.sqrt(dcd[violations]))**2)
114114

115-
def _total_loss(self, metric):
115+
def _total_loss(self, metric, vab, vcd, prior_inv):
116116
# Regularization loss
117117
sign, logdet = np.linalg.slogdet(metric)
118-
reg_loss = np.sum(metric * self.prior_inv_) - sign * logdet
119-
return self._comparison_loss(metric) + reg_loss
118+
reg_loss = np.sum(metric * prior_inv) - sign * logdet
119+
return self._comparison_loss(metric, vab, vcd) + reg_loss
120120

121-
def _gradient(self, metric):
122-
dMetric = self.prior_inv_ - np.linalg.inv(metric)
123-
dabs = np.sum(self.vab_.dot(metric) * self.vab_, axis=1)
124-
dcds = np.sum(self.vcd_.dot(metric) * self.vcd_, axis=1)
121+
def _gradient(self, metric, vab, vcd, prior_inv):
122+
dMetric = prior_inv - np.linalg.inv(metric)
123+
dabs = np.sum(vab.dot(metric) * vab, axis=1)
124+
dcds = np.sum(vcd.dot(metric) * vcd, axis=1)
125125
violations = dabs > dcds
126126
# TODO: vectorize
127-
for vab, dab, vcd, dcd in zip(self.vab_[violations], dabs[violations],
128-
self.vcd_[violations], dcds[violations]):
127+
for vab, dab, vcd, dcd in zip(vab[violations], dabs[violations],
128+
vcd[violations], dcds[violations]):
129129
dMetric += ((1-np.sqrt(dcd/dab))*np.outer(vab, vab) +
130130
(1-np.sqrt(dab/dcd))*np.outer(vcd, vcd))
131131
return dMetric
@@ -136,6 +136,9 @@ class LSML(_BaseLSML, _QuadrupletsClassifierMixin):
136136
137137
Attributes
138138
----------
139+
n_iter_ : `int`
140+
The number of iterations the solver has run.
141+
139142
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
140143
The linear transformation ``L`` deduced from the learned Mahalanobis
141144
metric (See function `transformer_from_metric`.)
@@ -169,6 +172,9 @@ class LSML_Supervised(_BaseLSML, TransformerMixin):
169172
170173
Attributes
171174
----------
175+
n_iter_ : `int`
176+
The number of iterations the solver has run.
177+
172178
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
173179
The linear transformation ``L`` deduced from the learned Mahalanobis
174180
metric (See function `transformer_from_metric`.)

metric_learn/mlkr.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ class MLKR(MahalanobisMixin, TransformerMixin):
3030
3131
Attributes
3232
----------
33+
n_iter_ : `int`
34+
The number of iterations the solver has run.
35+
3336
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
3437
The learned linear transformation ``L``.
3538
"""

metric_learn/mmc.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,9 @@ class MMC(_BaseMMC, _PairsClassifierMixin):
353353
354354
Attributes
355355
----------
356+
n_iter_ : `int`
357+
The number of iterations the solver has run.
358+
356359
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
357360
The linear transformation ``L`` deduced from the learned Mahalanobis
358361
metric (See function `transformer_from_metric`.)
@@ -384,6 +387,9 @@ class MMC_Supervised(_BaseMMC, TransformerMixin):
384387
385388
Attributes
386389
----------
390+
n_iter_ : `int`
391+
The number of iterations the solver has run.
392+
387393
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
388394
The linear transformation ``L`` deduced from the learned Mahalanobis
389395
metric (See function `transformer_from_metric`.)

metric_learn/nca.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ class NCA(MahalanobisMixin, TransformerMixin):
2424
2525
Attributes
2626
----------
27+
n_iter_ : `int`
28+
The number of iterations the solver has run.
29+
2730
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
2831
The learned linear transformation ``L``.
2932
"""

metric_learn/sdml.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,18 +58,18 @@ def _fit(self, pairs, y):
5858
# set up prior M
5959
if self.use_cov:
6060
X = np.vstack({tuple(row) for row in pairs.reshape(-1, pairs.shape[2])})
61-
self.M_ = pinvh(np.atleast_2d(np.cov(X, rowvar = False)))
61+
M = pinvh(np.atleast_2d(np.cov(X, rowvar = False)))
6262
else:
63-
self.M_ = np.identity(pairs.shape[2])
63+
M = np.identity(pairs.shape[2])
6464
diff = pairs[:, 0] - pairs[:, 1]
6565
loss_matrix = (diff.T * y).dot(diff)
66-
P = self.M_ + self.balance_param * loss_matrix
66+
P = M + self.balance_param * loss_matrix
6767
emp_cov = pinvh(P)
6868
# hack: ensure positive semidefinite
6969
emp_cov = emp_cov.T.dot(emp_cov)
70-
_, self.M_ = graph_lasso(emp_cov, self.sparsity_param, verbose=self.verbose)
70+
_, M = graph_lasso(emp_cov, self.sparsity_param, verbose=self.verbose)
7171

72-
self.transformer_ = transformer_from_metric(self.M_)
72+
self.transformer_ = transformer_from_metric(M)
7373
return self
7474

7575

0 commit comments

Comments
 (0)