Fit constraints #19

Merged (12 commits, Jun 20, 2016)
9 changes: 5 additions & 4 deletions metric_learn/__init__.py
@@ -1,9 +1,10 @@
from __future__ import absolute_import

-from .itml import ITML
+from .itml import ITML, ITML_Supervised
from .lmnn import LMNN
-from .lsml import LSML
-from .sdml import SDML
+from .lsml import LSML, LSML_Supervised
+from .sdml import SDML, SDML_Supervised
from .nca import NCA
from .lfda import LFDA
-from .rca import RCA
+from .rca import RCA, RCA_Supervised
+from .constraints import adjacencyMatrix, positiveNegativePairs, relativeQuadruplets, chunks
62 changes: 62 additions & 0 deletions metric_learn/constraints.py
@@ -0,0 +1,62 @@
""" Helper class that can generate different types of constraints from supervised data labels."""

import numpy as np
import random
from six.moves import xrange

# @TODO: consider creating a stateful class
# https://github.com/all-umass/metric-learn/pull/19#discussion_r67386226

def adjacencyMatrix(labels, num_points, num_constraints):
a, c = np.random.randint(len(labels), size=(2,num_constraints))
b, d = np.empty((2, num_constraints), dtype=int)
for i,(al,cl) in enumerate(zip(labels[a],labels[c])):
b[i] = random.choice(np.nonzero(labels == al)[0])
d[i] = random.choice(np.nonzero(labels != cl)[0])
W = np.zeros((num_points,num_points))
W[a,b] = 1
W[c,d] = -1
# make W symmetric
W[b,a] = 1
W[d,c] = -1
return W

def positiveNegativePairs(labels, num_points, num_constraints):
ac,bd = np.random.randint(num_points, size=(2,num_constraints))
pos = labels[ac] == labels[bd]
a,c = ac[pos], ac[~pos]
b,d = bd[pos], bd[~pos]
return a,b,c,d

def relativeQuadruplets(labels, num_constraints):
C = np.empty((num_constraints,4), dtype=int)
a, c = np.random.randint(len(labels), size=(2,num_constraints))
for i,(al,cl) in enumerate(zip(labels[a],labels[c])):
C[i,1] = random.choice(np.nonzero(labels == al)[0])
C[i,3] = random.choice(np.nonzero(labels != cl)[0])
C[:,0] = a
C[:,2] = c
return C

def chunks(Y, num_chunks=100, chunk_size=2, seed=None):
# @TODO: remove seed from params and use numpy RandomState
# https://github.com/all-umass/metric-learn/pull/19#discussion_r67386666
random.seed(seed)
chunks = -np.ones_like(Y, dtype=int)
uniq, lookup = np.unique(Y, return_inverse=True)
all_inds = [set(np.where(lookup==c)[0]) for c in xrange(len(uniq))]
idx = 0
while idx < num_chunks and all_inds:
c = random.randint(0, len(all_inds)-1)
inds = all_inds[c]
if len(inds) < chunk_size:
del all_inds[c]
continue
ii = random.sample(inds, chunk_size)
inds.difference_update(ii)
chunks[ii] = idx
idx += 1
if idx < num_chunks:
raise ValueError('Unable to make %d chunks of %d examples each' %
(num_chunks, chunk_size))
return chunks
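
To make the new helpers concrete, here is a minimal usage sketch; the toy labels and variable names are illustrative, not part of the PR:

import numpy as np
from metric_learn.constraints import positiveNegativePairs, chunks

labels = np.array([0, 0, 0, 1, 1, 1])   # toy labels: two classes of three points

# (a, b) index same-label pairs; (c, d) index different-label pairs
a, b, c, d = positiveNegativePairs(labels, len(labels), 4)
assert np.all(labels[a] == labels[b]) and np.all(labels[c] != labels[d])

# chunk assignment for RCA-style learners: points in the same chunk share a
# label; points left unassigned get -1
chunk_ids = chunks(labels, num_chunks=2, chunk_size=2, seed=0)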
60 changes: 50 additions & 10 deletions metric_learn/itml.py
@@ -16,11 +16,12 @@
from six.moves import xrange
from sklearn.metrics import pairwise_distances
from .base_metric import BaseMetricLearner
+from .constraints import positiveNegativePairs


class ITML(BaseMetricLearner):
  """Information Theoretic Metric Learning (ITML)"""
-  def __init__(self, gamma=1., max_iters=1000, convergence_threshold=1e-3):
+  def __init__(self, gamma=1., max_iters=1000, convergence_threshold=1e-3, verbose=False):
    """Initialize the learner.

    Parameters
@@ -29,11 +30,14 @@ def __init__(self, gamma=1., max_iters=1000, convergence_threshold=1e-3):
      value for slack variables
    max_iters : int, optional
    convergence_threshold : float, optional
+    verbose : bool, optional
+      if True, prints information while learning
    """
    self.params = {
      'gamma': gamma,
      'max_iters': max_iters,
      'convergence_threshold': convergence_threshold,
+      'verbose': verbose,
    }

  def _process_inputs(self, X, constraints, bounds, A0):
@@ -57,7 +61,7 @@ def _process_inputs(self, X, constraints, bounds, A0):
      self.A = A0
    return a,b,c,d

-  def fit(self, X, constraints, bounds=None, A0=None, verbose=False):
+  def fit(self, X, constraints, bounds=None, A0=None):
    """Learn the ITML model.

    Parameters
@@ -71,6 +75,7 @@ def fit(self, X, constraints, bounds=None, A0=None, verbose=False):
    A0 : (d x d) matrix, optional
      initial regularization matrix, defaults to identity
    """
+    verbose = self.params['verbose']
    a,b,c,d = self._process_inputs(X, constraints, bounds, A0)
    gamma = self.params['gamma']
    conv_thresh = self.params['convergence_threshold']
@@ -121,14 +126,6 @@ def fit(self, X, constraints, bounds=None, A0=None, verbose=False):
  def metric(self):
    return self.A

-  @classmethod
-  def prepare_constraints(self, labels, num_points, num_constraints):
-    ac,bd = np.random.randint(num_points, size=(2,num_constraints))
-    pos = labels[ac] == labels[bd]
-    a,c = ac[pos], ac[~pos]
-    b,d = bd[pos], bd[~pos]
-    return a,b,c,d

# hack around lack of axis kwarg in older numpy versions
try:
  np.linalg.norm([[4]], axis=1)
@@ -138,3 +135,46 @@ def _vector_norm(X):
else:
  def _vector_norm(X):
    return np.linalg.norm(X, axis=1)
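
For context, the shim above exists because np.linalg.norm only grew its axis kwarg in numpy 1.8. A quick check of what the axis=1 path computes (my example, not in the diff):

import numpy as np

X = np.array([[3.0, 4.0], [6.0, 8.0]])
print(np.linalg.norm(X, axis=1))       # [ 5. 10.]
print(np.sqrt((X**2).sum(axis=1)))     # same row-wise L2 norms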


class ITML_Supervised(ITML):
  """Supervised version of Information Theoretic Metric Learning (ITML)"""
  def __init__(self, gamma=1., max_iters=1000, convergence_threshold=1e-3,
               num_constraints=None, bounds=None, A0=None, verbose=False):
    """Initialize the learner.

    Parameters
    ----------
    gamma : float, optional
      value for slack variables
    max_iters : int, optional
    convergence_threshold : float, optional
    num_constraints : int, optional
      number of constraints to generate in .fit()
    bounds : (pos, neg) distance bounds, optional
    A0 : (d x d) matrix, optional
      initial regularization matrix, defaults to identity
    verbose : bool, optional
      if True, prints information while learning
    """
    ITML.__init__(self, gamma=gamma, max_iters=max_iters,
                  convergence_threshold=convergence_threshold, verbose=verbose)
    self.params.update({
      'num_constraints': num_constraints,
      'bounds': bounds,
      'A0': A0,
    })

  def fit(self, X, labels):
    """Create constraints from labels and learn the ITML model.
    If num_constraints was not given to the constructor, a default is
    derived from the number of classes.

    Parameters
    ----------
    X : (n x d) data matrix
      each row corresponds to a single instance
    labels : (n) data labels
    """
    num_constraints = self.params['num_constraints']
    if num_constraints is None:
      num_classes = len(np.unique(labels))
      num_constraints = 20 * num_classes**2

    C = positiveNegativePairs(labels, X.shape[0], num_constraints)
    return ITML.fit(self, X, C, bounds=self.params['bounds'],
                    A0=self.params['A0'])
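
A minimal end-to-end sketch of the new supervised wrapper, assuming the metric_learn package with the exports this PR adds; the synthetic data is mine:

import numpy as np
from metric_learn import ITML_Supervised

X = np.random.randn(40, 3)       # 40 points in 3 dimensions
labels = np.repeat([0, 1], 20)   # two balanced classes

itml = ITML_Supervised(num_constraints=50)
itml.fit(X, labels)              # positive/negative pairs are drawn internally
M = itml.metric()                # learned (3 x 3) Mahalanobis matrix

With num_constraints=None, fit() would fall back to 20 * num_classes**2 = 80 constraints for these two classes.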
13 changes: 7 additions & 6 deletions metric_learn/lmnn.py
@@ -29,15 +29,15 @@ def transformer(self):
# slower Python version
class python_LMNN(_base_LMNN):
  def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
-               regularization=0.5, convergence_tol=0.001):
+               regularization=0.5, convergence_tol=0.001, verbose=False):
    """Initialize the LMNN object

    k: number of neighbors to consider. (does not include self-edges)
    regularization: weighting of pull and push terms
    """
    _base_LMNN.__init__(self, k=k, min_iter=min_iter, max_iter=max_iter,
                        learn_rate=learn_rate, regularization=regularization,
-                        convergence_tol=convergence_tol)
+                        convergence_tol=convergence_tol, verbose=verbose)

  def _process_inputs(self, X, labels):
    num_pts = X.shape[0]
@@ -51,8 +51,9 @@ def _process_inputs(self, X, labels):
                       'not enough class labels for specified k'
                       ' (smallest class has %d)' % required_k)

-  def fit(self, X, labels, verbose=False):
+  def fit(self, X, labels):
    k = self.params['k']
+    verbose = self.params['verbose']
    reg = self.params['regularization']
    learn_rate = self.params['learn_rate']
    convergence_tol = self.params['convergence_tol']
@@ -236,12 +237,12 @@ def _sum_outer_products(data, a_inds, b_inds, weights=None):

class LMNN(_base_LMNN):
  def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
-               regularization=0.5, convergence_tol=0.001, use_pca=True):
+               regularization=0.5, convergence_tol=0.001, use_pca=True, verbose=False):
    _base_LMNN.__init__(self, k=k, min_iter=min_iter, max_iter=max_iter,
                        learn_rate=learn_rate, regularization=regularization,
-                        convergence_tol=convergence_tol, use_pca=use_pca)
+                        convergence_tol=convergence_tol, use_pca=use_pca, verbose=verbose)

-  def fit(self, X, labels, verbose=False):
+  def fit(self, X, labels):
    self.X = X
    self.L = np.eye(X.shape[1])
    labels = MulticlassLabels(labels.astype(np.float64))
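
The LMNN change is purely an interface move: verbose now travels with the estimator's params rather than with fit(), matching the other learners. A before/after sketch (X and labels as in the ITML example above):

from metric_learn import LMNN

# old: LMNN(k=3).fit(X, labels, verbose=True)
# new: verbose is set once, at construction
lmnn = LMNN(k=3, verbose=True)
lmnn.fit(X, labels)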
67 changes: 50 additions & 17 deletions metric_learn/lsml.py
@@ -13,20 +13,24 @@
from random import choice
from six.moves import xrange
from .base_metric import BaseMetricLearner
+from .constraints import relativeQuadruplets


class LSML(BaseMetricLearner):
-  def __init__(self, tol=1e-3, max_iter=1000):
+  def __init__(self, tol=1e-3, max_iter=1000, verbose=False):
    """Initialize the learner.

    Parameters
    ----------
    tol : float, optional
    max_iter : int, optional
+    verbose : bool, optional
+      if True, prints information while learning
    """
    self.params = {
      'tol': tol,
      'max_iter': max_iter,
+      'verbose': verbose,
    }

  def _prepare_inputs(self, X, constraints, weights, prior):
@@ -46,7 +50,7 @@ def _prepare_inputs(self, X, constraints, weights, prior):
  def metric(self):
    return self.M

-  def fit(self, X, constraints, weights=None, prior=None, verbose=False):
+  def fit(self, X, constraints, weights=None, prior=None):
    """Learn the LSML model.

    Parameters
@@ -59,9 +63,8 @@ def fit(self, X, constraints, weights=None, prior=None, verbose=False):
      scale factor for each constraint
    prior : (d x d) matrix, optional
      guess at a metric [default: covariance(X)]
-    verbose : bool, optional
-      if True, prints information while learning
    """
+    verbose = self.params['verbose']
    self._prepare_inputs(X, constraints, weights, prior)
    prior_inv = scipy.linalg.inv(self.M)
    s_best = self._total_loss(self.M, prior_inv)
@@ -93,7 +96,8 @@ def fit(self, X, constraints, weights=None, prior=None, verbose=False):
        break
      self.M = M_best
    else:
-      print("Didn't converge after", it, "iterations. Final loss:", s_best)
+      if verbose:
+        print("Didn't converge after", it, "iterations. Final loss:", s_best)
    return self

  def _comparison_loss(self, metric):
@@ -119,18 +123,47 @@ def _gradient(self, metric, prior_inv):
                 (1-np.sqrt(dab/dcd))*np.outer(vcd, vcd))
    return dMetric

-  @classmethod
-  def prepare_constraints(cls, labels, num_constraints):
-    C = np.empty((num_constraints,4), dtype=int)
-    a, c = np.random.randint(len(labels), size=(2,num_constraints))
-    for i,(al,cl) in enumerate(zip(labels[a],labels[c])):
-      C[i,1] = choice(np.nonzero(labels == al)[0])
-      C[i,3] = choice(np.nonzero(labels != cl)[0])
-    C[:,0] = a
-    C[:,2] = c
-    return C

def _regularization_loss(metric, prior_inv):
  sign, logdet = np.linalg.slogdet(metric)
  return np.sum(metric * prior_inv) - sign * logdet

class LSML_Supervised(LSML):
  def __init__(self, tol=1e-3, max_iter=1000, prior=None, num_constraints=None,
               weights=None, verbose=False):
    """Initialize the learner.

    Parameters
    ----------
    tol : float, optional
    max_iter : int, optional
    prior : (d x d) matrix, optional
      guess at a metric [default: covariance(X)]
    num_constraints : int, optional
      number of constraints to generate in .fit()
    weights : (m,) array of floats, optional
      scale factor for each constraint
    verbose : bool, optional
      if True, prints information while learning
    """
    LSML.__init__(self, tol=tol, max_iter=max_iter, verbose=verbose)
    self.params.update({
      'prior': prior,
      'num_constraints': num_constraints,
      'weights': weights,
    })

  def fit(self, X, labels):
    """Create constraints from labels and learn the LSML model.
    If num_constraints was not given to the constructor, a default is
    derived from the number of classes.

    Parameters
    ----------
    X : (n x d) data matrix
      each row corresponds to a single instance
    labels : (n) data labels
    """
    num_constraints = self.params['num_constraints']
    if num_constraints is None:
      num_classes = len(np.unique(labels))
      num_constraints = 20 * num_classes**2

    C = relativeQuadruplets(labels, num_constraints)
    return LSML.fit(self, X, C, weights=self.params['weights'],
                    prior=self.params['prior'])
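
The same pattern applies to LSML_Supervised; a minimal sketch with synthetic data (mine, not from the PR):

import numpy as np
from metric_learn import LSML_Supervised

X = np.random.randn(40, 3)
labels = np.repeat([0, 1], 20)

lsml = LSML_Supervised(num_constraints=50)
lsml.fit(X, labels)   # quadruplets come from relativeQuadruplets(labels, 50)
M = lsml.metric()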