Skip to content

[MRG + 1] Allow already formed tuples as an input. #92

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions metric_learn/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,13 @@ def random_subset(all_labels, num_preserved=np.inf, random_state=np.random):
partial_labels = np.array(all_labels, copy=True)
partial_labels[idx] = -1
return Constraints(partial_labels)

def wrap_pairs(X, constraints):
    """Turn a 4-tuple of constraint index arrays into a (pairs, y) dataset.

    Parameters
    ----------
    X : (n, d) array of points
        The data the constraint indices refer to.
    constraints : 4-tuple of index arrays (a, b, c, d)
        (a[i], b[i]) index similar pairs; (c[j], d[j]) index dissimilar pairs.

    Returns
    -------
    pairs : (m, 2, d) array
        Stacked point pairs, similar pairs first, then dissimilar ones.
    y : (m, 1) array
        Label column: 1.0 for a similar pair, 0.0 for a dissimilar pair.
    """
    sim_left, sim_right, diff_left, diff_right = (
        np.array(constraints[i]) for i in range(4))
    # Build the (m, 2) index array: one row per pair, similar pairs on top.
    similar_idx = np.column_stack((sim_left, sim_right))
    dissimilar_idx = np.column_stack((diff_left, diff_right))
    pair_indices = np.vstack((similar_idx, dissimilar_idx))
    # Column vector of labels aligned with the stacked pair rows.
    y = np.vstack([np.ones((len(sim_left), 1)),
                   np.zeros((len(diff_left), 1))])
    # Fancy indexing with the (m, 2) array yields an (m, 2, d) pair tensor.
    pairs = X[pair_indices]
    return pairs, y
44 changes: 27 additions & 17 deletions metric_learn/itml.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from sklearn.utils.validation import check_array, check_X_y

from .base_metric import BaseMetricLearner
from .constraints import Constraints
from .constraints import Constraints, wrap_pairs
from ._util import vector_norm


Expand Down Expand Up @@ -51,29 +51,37 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
self.A0 = A0
self.verbose = verbose

def _process_inputs(self, X, constraints, bounds):
self.X_ = X = check_array(X)
def _process_pairs(self, pairs, y, bounds):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would make sense for this function _process_pairs to be shared across the class of pair metric learners. For instance, ruling out the potential pairs that are identical is useful for all algorithms

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I agree, I added it to the small features TODO list at the end of the main issue: #91 (comment)

pairs, y = check_X_y(pairs, y, accept_sparse=False,
ensure_2d=False, allow_nd=True)
y = y.astype(bool)

# check to make sure that no two constrained vectors are identical
a,b,c,d = constraints
no_ident = vector_norm(X[a] - X[b]) > 1e-9
a, b = a[no_ident], b[no_ident]
no_ident = vector_norm(X[c] - X[d]) > 1e-9
c, d = c[no_ident], d[no_ident]
pos_pairs, neg_pairs = pairs[y], pairs[~y]
pos_no_ident = vector_norm(pos_pairs[:, 0, :] - pos_pairs[:, 1, :]) > 1e-9
pos_pairs = pos_pairs[pos_no_ident]
neg_no_ident = vector_norm(neg_pairs[:, 0, :] - neg_pairs[:, 1, :]) > 1e-9
neg_pairs = neg_pairs[neg_no_ident]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe showing a warning to the user when such a pair is found and discarded would be useful. In particular, if a negative pair is made of two identical points, there is probably a problem with the way the user generated the pairs, or with the dataset.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I agree. I added it to the TODO

# init bounds
if bounds is None:
X = np.unique(pairs.reshape(-1, pairs.shape[2]), axis=0)
self.bounds_ = np.percentile(pairwise_distances(X), (5, 95))
else:
assert len(bounds) == 2
self.bounds_ = bounds
self.bounds_[self.bounds_==0] = 1e-9
# init metric
if self.A0 is None:
self.A_ = np.identity(X.shape[1])
self.A_ = np.identity(pairs.shape[2])
else:
self.A_ = check_array(self.A0)
return a,b,c,d
pairs = np.vstack([pos_pairs, neg_pairs])
y = np.hstack([np.ones(len(pos_pairs)), np.zeros(len(neg_pairs))])
y = y.astype(bool)
return pairs, y


def fit(self, X, constraints, bounds=None):
def fit(self, pairs, y, bounds=None):
"""Learn the ITML model.

Parameters
Expand All @@ -86,17 +94,18 @@ def fit(self, X, constraints, bounds=None):
bounds : list (pos,neg) pairs, optional
bounds on similarity, s.t. d(X[a],X[b]) < pos and d(X[c],X[d]) > neg
"""
a,b,c,d = self._process_inputs(X, constraints, bounds)
pairs, y = self._process_pairs(pairs, y, bounds)
gamma = self.gamma
num_pos = len(a)
num_neg = len(c)
pos_pairs, neg_pairs = pairs[y], pairs[~y]
num_pos = len(pos_pairs)
num_neg = len(neg_pairs)
_lambda = np.zeros(num_pos + num_neg)
lambdaold = np.zeros_like(_lambda)
gamma_proj = 1. if gamma is np.inf else gamma/(gamma+1.)
pos_bhat = np.zeros(num_pos) + self.bounds_[0]
neg_bhat = np.zeros(num_neg) + self.bounds_[1]
pos_vv = self.X_[a] - self.X_[b]
neg_vv = self.X_[c] - self.X_[d]
pos_vv = pos_pairs[:, 0, :] - pos_pairs[:, 1, :]
neg_vv = neg_pairs[:, 0, :] - neg_pairs[:, 1, :]
A = self.A_

for it in xrange(self.max_iter):
Expand Down Expand Up @@ -195,4 +204,5 @@ def fit(self, X, y, random_state=np.random):
random_state=random_state)
pos_neg = c.positive_negative_pairs(num_constraints,
random_state=random_state)
return ITML.fit(self, X, pos_neg, bounds=self.bounds)
pairs, y = wrap_pairs(X, pos_neg)
return ITML.fit(self, pairs, y, bounds=self.bounds)
11 changes: 6 additions & 5 deletions metric_learn/lfda.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,11 @@ def _sum_outer(x):
def _eigh(a, b, dim):
try:
return scipy.sparse.linalg.eigsh(a, k=dim, M=b, which='LA')
except (ValueError, scipy.sparse.linalg.ArpackNoConvergence):
pass
try:
return scipy.linalg.eigh(a, b)
except np.linalg.LinAlgError:
pass
pass # scipy already tried eigh for us
except (ValueError, scipy.sparse.linalg.ArpackNoConvergence):
try:
return scipy.linalg.eigh(a, b)
except np.linalg.LinAlgError:
pass
return scipy.linalg.eig(a, b)
23 changes: 13 additions & 10 deletions metric_learn/lsml.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from sklearn.utils.validation import check_array, check_X_y

from .base_metric import BaseMetricLearner
from .constraints import Constraints
from .constraints import Constraints, wrap_pairs


class LSML(BaseMetricLearner):
Expand All @@ -35,11 +35,13 @@ def __init__(self, tol=1e-3, max_iter=1000, prior=None, verbose=False):
self.max_iter = max_iter
self.verbose = verbose

def _prepare_inputs(self, X, constraints, weights):
self.X_ = X = check_array(X)
a,b,c,d = constraints
self.vab_ = X[a] - X[b]
self.vcd_ = X[c] - X[d]
def _prepare_quadruplets(self, quadruplets, weights):
pairs = check_array(quadruplets, accept_sparse=False,
ensure_2d=False, allow_nd=True)

# check to make sure that no two constrained vectors are identical
self.vab_ = quadruplets[:, 0, :] - quadruplets[:, 1, :]
self.vcd_ = quadruplets[:, 2, :] - quadruplets[:, 3, :]
if self.vab_.shape != self.vcd_.shape:
raise ValueError('Constraints must have same length')
if weights is None:
Expand All @@ -48,6 +50,7 @@ def _prepare_inputs(self, X, constraints, weights):
self.w_ = weights
self.w_ /= self.w_.sum() # weights must sum to 1
if self.prior is None:
X = np.unique(pairs.reshape(-1, pairs.shape[2]), axis=0)
self.prior_inv_ = np.atleast_2d(np.cov(X, rowvar=False))
self.M_ = np.linalg.inv(self.prior_inv_)
else:
Expand All @@ -57,7 +60,7 @@ def _prepare_inputs(self, X, constraints, weights):
def metric(self):
return self.M_

def fit(self, X, constraints, weights=None):
def fit(self, quadruplets, weights=None):
"""Learn the LSML model.

Parameters
Expand All @@ -69,7 +72,7 @@ def fit(self, X, constraints, weights=None):
weights : (m,) array of floats, optional
scale factor for each constraint
"""
self._prepare_inputs(X, constraints, weights)
self._prepare_quadruplets(quadruplets, weights)
step_sizes = np.logspace(-10, 0, 10)
# Keep track of the best step size and the loss at that step.
l_best = 0
Expand Down Expand Up @@ -179,6 +182,6 @@ def fit(self, X, y, random_state=np.random):

c = Constraints.random_subset(y, self.num_labeled,
random_state=random_state)
pairs = c.positive_negative_pairs(num_constraints, same_length=True,
pos_neg = c.positive_negative_pairs(num_constraints, same_length=True,
random_state=random_state)
return LSML.fit(self, X, pairs, weights=self.weights)
return LSML.fit(self, X[np.column_stack(pos_neg)], weights=self.weights)
Loading