
Commit 3c57c64

sveconperimosocordiae authored and committed
Fit constraints (#19)

* ITML fit_constraints
* SDML fit_constraints
* LSML fit_constraints
* RCA fit_constraints
* Renamed fit->fit_constrains in test
* Created new supervised classes for methods with semi-supervised constraints
* Comment
* Removed duplicate code
* Super compatible with Python 2
* Code standards
* Fix overriding of params
1 parent c1372af commit 3c57c64

8 files changed: +262 −95 lines


metric_learn/__init__.py

Lines changed: 5 additions & 4 deletions
@@ -1,9 +1,10 @@
 from __future__ import absolute_import

-from .itml import ITML
+from .itml import ITML, ITML_Supervised
 from .lmnn import LMNN
-from .lsml import LSML
-from .sdml import SDML
+from .lsml import LSML, LSML_Supervised
+from .sdml import SDML, SDML_Supervised
 from .nca import NCA
 from .lfda import LFDA
-from .rca import RCA
+from .rca import RCA, RCA_Supervised
+from .constraints import adjacencyMatrix, positiveNegativePairs, relativeQuadruplets, chunks
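
With these exports in place, the supervised variants and the constraint helpers are importable from the package top level. A minimal sketch of the resulting public surface (assuming a build that includes this commit):

    # supervised wrappers live alongside the original estimators
    from metric_learn import ITML_Supervised, LSML_Supervised, SDML_Supervised, RCA_Supervised
    # the constraint-generation helpers are re-exported too
    from metric_learn import positiveNegativePairs, relativeQuadruplets, chunks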

metric_learn/constraints.py

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+""" Helper class that can generate different types of constraints from supervised data labels."""
+
+import numpy as np
+import random
+from six.moves import xrange
+
+# @TODO: consider creating a stateful class
+# https://github.com/all-umass/metric-learn/pull/19#discussion_r67386226
+
+def adjacencyMatrix(labels, num_points, num_constraints):
+  a, c = np.random.randint(len(labels), size=(2,num_constraints))
+  b, d = np.empty((2, num_constraints), dtype=int)
+  for i,(al,cl) in enumerate(zip(labels[a],labels[c])):
+    b[i] = random.choice(np.nonzero(labels == al)[0])
+    d[i] = random.choice(np.nonzero(labels != cl)[0])
+  W = np.zeros((num_points,num_points))
+  W[a,b] = 1
+  W[c,d] = -1
+  # make W symmetric
+  W[b,a] = 1
+  W[d,c] = -1
+  return W
+
+def positiveNegativePairs(labels, num_points, num_constraints):
+  ac,bd = np.random.randint(num_points, size=(2,num_constraints))
+  pos = labels[ac] == labels[bd]
+  a,c = ac[pos], ac[~pos]
+  b,d = bd[pos], bd[~pos]
+  return a,b,c,d
+
+def relativeQuadruplets(labels, num_constraints):
+  C = np.empty((num_constraints,4), dtype=int)
+  a, c = np.random.randint(len(labels), size=(2,num_constraints))
+  for i,(al,cl) in enumerate(zip(labels[a],labels[c])):
+    C[i,1] = random.choice(np.nonzero(labels == al)[0])
+    C[i,3] = random.choice(np.nonzero(labels != cl)[0])
+  C[:,0] = a
+  C[:,2] = c
+  return C
+
+def chunks(Y, num_chunks=100, chunk_size=2, seed=None):
+  # @TODO: remove seed from params and use numpy RandomState
+  # https://github.com/all-umass/metric-learn/pull/19#discussion_r67386666
+  random.seed(seed)
+  chunks = -np.ones_like(Y, dtype=int)
+  uniq, lookup = np.unique(Y, return_inverse=True)
+  all_inds = [set(np.where(lookup==c)[0]) for c in xrange(len(uniq))]
+  idx = 0
+  while idx < num_chunks and all_inds:
+    c = random.randint(0, len(all_inds)-1)
+    inds = all_inds[c]
+    if len(inds) < chunk_size:
+      del all_inds[c]
+      continue
+    ii = random.sample(inds, chunk_size)
+    inds.difference_update(ii)
+    chunks[ii] = idx
+    idx += 1
+  if idx < num_chunks:
+    raise ValueError('Unable to make %d chunks of %d examples each' %
+                     (num_chunks, chunk_size))
+  return chunks
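
To make the helpers' contracts concrete, here is a small illustrative sketch; the toy labels array and the assertions are mine, not part of the commit:

    import numpy as np
    from metric_learn.constraints import positiveNegativePairs, chunks

    labels = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])

    # (a[i], b[i]) share a label; (c[i], d[i]) have different labels
    a, b, c, d = positiveNegativePairs(labels, len(labels), num_constraints=10)
    assert (labels[a] == labels[b]).all()
    assert (labels[c] != labels[d]).all()

    # chunk_ids[i] is the chunk assigned to point i, or -1 if unassigned;
    # each chunk draws chunk_size points from a single class
    chunk_ids = chunks(labels, num_chunks=3, chunk_size=2, seed=0)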

metric_learn/itml.py

Lines changed: 50 additions & 10 deletions
@@ -16,11 +16,12 @@
 from six.moves import xrange
 from sklearn.metrics import pairwise_distances
 from .base_metric import BaseMetricLearner
+from .constraints import positiveNegativePairs


 class ITML(BaseMetricLearner):
   """Information Theoretic Metric Learning (ITML)"""
-  def __init__(self, gamma=1., max_iters=1000, convergence_threshold=1e-3):
+  def __init__(self, gamma=1., max_iters=1000, convergence_threshold=1e-3, verbose=False):
     """Initialize the learner.

     Parameters
@@ -29,11 +30,14 @@ def __init__(self, gamma=1., max_iters=1000, convergence_threshold=1e-3):
       value for slack variables
     max_iters : int, optional
     convergence_threshold : float, optional
+    verbose : bool, optional
+      if True, prints information while learning
     """
     self.params = {
       'gamma': gamma,
       'max_iters': max_iters,
       'convergence_threshold': convergence_threshold,
+      'verbose': verbose,
     }

   def _process_inputs(self, X, constraints, bounds, A0):
@@ -57,7 +61,7 @@ def _process_inputs(self, X, constraints, bounds, A0):
     self.A = A0
     return a,b,c,d

-  def fit(self, X, constraints, bounds=None, A0=None, verbose=False):
+  def fit(self, X, constraints, bounds=None, A0=None):
     """Learn the ITML model.

     Parameters
@@ -71,6 +75,7 @@ def fit(self, X, constraints, bounds=None, A0=None, verbose=False):
     A0 : (d x d) matrix, optional
       initial regularization matrix, defaults to identity
     """
+    verbose = self.params['verbose']
     a,b,c,d = self._process_inputs(X, constraints, bounds, A0)
     gamma = self.params['gamma']
     conv_thresh = self.params['convergence_threshold']
@@ -121,14 +126,6 @@ def fit(self, X, constraints, bounds=None, A0=None, verbose=False):
   def metric(self):
     return self.A

-  @classmethod
-  def prepare_constraints(self, labels, num_points, num_constraints):
-    ac,bd = np.random.randint(num_points, size=(2,num_constraints))
-    pos = labels[ac] == labels[bd]
-    a,c = ac[pos], ac[~pos]
-    b,d = bd[pos], bd[~pos]
-    return a,b,c,d
-
 # hack around lack of axis kwarg in older numpy versions
 try:
   np.linalg.norm([[4]], axis=1)
@@ -138,3 +135,46 @@ def _vector_norm(X):
 else:
   def _vector_norm(X):
     return np.linalg.norm(X, axis=1)
+
+
+class ITML_Supervised(ITML):
+  """Information Theoretic Metric Learning (ITML)"""
+  def __init__(self, gamma=1., max_iters=1000, convergence_threshold=1e-3, num_constraints=None,
+               bounds=None, A0=None, verbose=False):
+    """Initialize the learner.
+
+    Parameters
+    ----------
+    gamma : float, optional
+      value for slack variables
+    max_iters : int, optional
+    convergence_threshold : float, optional
+    num_constraints: int, needed for .fit()
+    verbose : bool, optional
+      if True, prints information while learning
+    """
+    ITML.__init__(self, gamma=gamma, max_iters=max_iters,
+                  convergence_threshold=convergence_threshold, verbose=verbose)
+    self.params.update({
+      'num_constraints': num_constraints,
+      'bounds': bounds,
+      'A0': A0,
+    })
+
+  def fit(self, X, labels):
+    """Create constraints from labels and learn the ITML model.
+    Needs num_constraints specified in constructor.
+
+    Parameters
+    ----------
+    X : (n x d) data matrix
+      each row corresponds to a single instance
+    labels : (n) data labels
+    """
+    num_constraints = self.params['num_constraints']
+    if num_constraints is None:
+      num_classes = np.unique(labels)
+      num_constraints = 20*(len(num_classes))**2
+
+    C = positiveNegativePairs(labels, X.shape[0], num_constraints)
+    return ITML.fit(self, X, C, bounds=self.params['bounds'], A0=self.params['A0'])
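
The end-to-end effect for ITML: verbose now lives in the constructor, and ITML_Supervised generates its own pairs from the labels. A hedged usage sketch on made-up synthetic data:

    import numpy as np
    from metric_learn import ITML_Supervised

    X = np.random.randn(40, 3)            # toy data: 40 points, 3 features
    labels = np.repeat(np.arange(4), 10)  # 4 classes, 10 points each

    # pairs are drawn internally via positiveNegativePairs; when
    # num_constraints is None it defaults to 20 * (number of classes)**2
    itml = ITML_Supervised(num_constraints=200)
    itml.fit(X, labels)
    M = itml.metric()  # learned (3 x 3) Mahalanobis matrix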

metric_learn/lmnn.py

Lines changed: 7 additions & 6 deletions
@@ -29,15 +29,15 @@ def transformer(self):
 # slower Python version
 class python_LMNN(_base_LMNN):
   def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
-               regularization=0.5, convergence_tol=0.001):
+               regularization=0.5, convergence_tol=0.001, verbose=False):
     """Initialize the LMNN object

     k: number of neighbors to consider. (does not include self-edges)
     regularization: weighting of pull and push terms
     """
     _base_LMNN.__init__(self, k=k, min_iter=min_iter, max_iter=max_iter,
                         learn_rate=learn_rate, regularization=regularization,
-                        convergence_tol=convergence_tol)
+                        convergence_tol=convergence_tol, verbose=verbose)

   def _process_inputs(self, X, labels):
     num_pts = X.shape[0]
@@ -51,8 +51,9 @@ def _process_inputs(self, X, labels):
       'not enough class labels for specified k'
       ' (smallest class has %d)' % required_k)

-  def fit(self, X, labels, verbose=False):
+  def fit(self, X, labels):
     k = self.params['k']
+    verbose = self.params['verbose']
     reg = self.params['regularization']
     learn_rate = self.params['learn_rate']
     convergence_tol = self.params['convergence_tol']
@@ -236,12 +237,12 @@ def _sum_outer_products(data, a_inds, b_inds, weights=None):

 class LMNN(_base_LMNN):
   def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
-               regularization=0.5, convergence_tol=0.001, use_pca=True):
+               regularization=0.5, convergence_tol=0.001, use_pca=True, verbose=False):
     _base_LMNN.__init__(self, k=k, min_iter=min_iter, max_iter=max_iter,
                         learn_rate=learn_rate, regularization=regularization,
-                        convergence_tol=convergence_tol, use_pca=use_pca)
+                        convergence_tol=convergence_tol, use_pca=use_pca, verbose=verbose)

-  def fit(self, X, labels, verbose=False):
+  def fit(self, X, labels):
     self.X = X
     self.L = np.eye(X.shape[1])
     labels = MulticlassLabels(labels.astype(np.float64))
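
LMNN gets no supervised wrapper here, since its fit already consumes labels directly; the change only moves verbose into the constructor for consistency with the other learners. A sketch, reusing the hypothetical X and labels arrays from the ITML example above:

    from metric_learn import LMNN

    lmnn = LMNN(k=3, verbose=True)  # verbosity is a constructor param now
    lmnn.fit(X, labels)             # fit() no longer accepts verbose=...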

metric_learn/lsml.py

Lines changed: 50 additions & 17 deletions
@@ -13,20 +13,24 @@
 from random import choice
 from six.moves import xrange
 from .base_metric import BaseMetricLearner
+from .constraints import relativeQuadruplets


 class LSML(BaseMetricLearner):
-  def __init__(self, tol=1e-3, max_iter=1000):
+  def __init__(self, tol=1e-3, max_iter=1000, verbose=False):
     """Initialize the learner.

     Parameters
     ----------
     tol : float, optional
     max_iter : int, optional
+    verbose : bool, optional
+      if True, prints information while learning
     """
     self.params = {
       'tol': tol,
       'max_iter': max_iter,
+      'verbose': verbose,
     }

   def _prepare_inputs(self, X, constraints, weights, prior):
@@ -46,7 +50,7 @@ def _prepare_inputs(self, X, constraints, weights, prior):
   def metric(self):
     return self.M

-  def fit(self, X, constraints, weights=None, prior=None, verbose=False):
+  def fit(self, X, constraints, weights=None, prior=None):
     """Learn the LSML model.

     Parameters
@@ -59,9 +63,8 @@ def fit(self, X, constraints, weights=None, prior=None, verbose=False):
       scale factor for each constraint
     prior : (d x d) matrix, optional
       guess at a metric [default: covariance(X)]
-    verbose : bool, optional
-      if True, prints information while learning
     """
+    verbose = self.params['verbose']
     self._prepare_inputs(X, constraints, weights, prior)
     prior_inv = scipy.linalg.inv(self.M)
     s_best = self._total_loss(self.M, prior_inv)
@@ -93,7 +96,8 @@ def fit(self, X, constraints, weights=None, prior=None, verbose=False):
         break
       self.M = M_best
     else:
-      print("Didn't converge after", it, "iterations. Final loss:", s_best)
+      if verbose:
+        print("Didn't converge after", it, "iterations. Final loss:", s_best)
     return self

   def _comparison_loss(self, metric):
@@ -119,18 +123,47 @@ def _gradient(self, metric, prior_inv):
                (1-np.sqrt(dab/dcd))*np.outer(vcd, vcd))
     return dMetric

-  @classmethod
-  def prepare_constraints(cls, labels, num_constraints):
-    C = np.empty((num_constraints,4), dtype=int)
-    a, c = np.random.randint(len(labels), size=(2,num_constraints))
-    for i,(al,cl) in enumerate(zip(labels[a],labels[c])):
-      C[i,1] = choice(np.nonzero(labels == al)[0])
-      C[i,3] = choice(np.nonzero(labels != cl)[0])
-    C[:,0] = a
-    C[:,2] = c
-    return C
-
-
 def _regularization_loss(metric, prior_inv):
   sign, logdet = np.linalg.slogdet(metric)
   return np.sum(metric * prior_inv) - sign * logdet
+
+
+class LSML_Supervised(LSML):
+  def __init__(self, tol=1e-3, max_iter=1000, prior=None, num_constraints=None, weights=None, verbose=False):
+    """Initialize the learner.
+
+    Parameters
+    ----------
+    tol : float, optional
+    max_iter : int, optional
+    prior : (d x d) matrix, optional
+      guess at a metric [default: covariance(X)]
+    num_constraints: int, needed for .fit()
+    weights : (m,) array of floats, optional
+      scale factor for each constraint
+    verbose : bool, optional
+      if True, prints information while learning
+    """
+    LSML.__init__(self, tol=tol, max_iter=max_iter, verbose=verbose)
+    self.params.update({
+      'prior': prior,
+      'num_constraints': num_constraints,
+      'weights': weights,
+    })
+
+  def fit(self, X, labels):
+    """Create constraints from labels and learn the LSML model.
+    Needs num_constraints specified in constructor.
+
+    Parameters
+    ----------
+    X : (n x d) data matrix
+      each row corresponds to a single instance
+    labels : (n) data labels
+    """
+    num_constraints = self.params['num_constraints']
+    if num_constraints is None:
+      num_classes = np.unique(labels)
+      num_constraints = 20*(len(num_classes))**2
+
+    C = relativeQuadruplets(labels, num_constraints)
+    return LSML.fit(self, X, C, weights=self.params['weights'], prior=self.params['prior'])
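
LSML_Supervised follows the same pattern as ITML_Supervised, but its constraints are relative quadruplets (a, b, c, d) asserting d(a, b) < d(c, d), drawn by relativeQuadruplets. A hedged sketch on toy data:

    import numpy as np
    from metric_learn import LSML_Supervised

    X = np.random.randn(30, 4)            # toy data: 30 points, 4 features
    labels = np.repeat(np.arange(3), 10)  # 3 classes, 10 points each

    lsml = LSML_Supervised(num_constraints=100)
    lsml.fit(X, labels)  # quadruplets come from relativeQuadruplets(labels, 100)
    M = lsml.metric()    # learned (4 x 4) metric; prior defaults to cov(X)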
