-
Notifications
You must be signed in to change notification settings - Fork 229
[MRG + 1] Allow already formed tuples as an input. #92
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
9f5c998
4c887d7
3acf31a
a7e4807
903f174
374a851
b4bdec4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,7 +20,7 @@ | |
from sklearn.utils.validation import check_array, check_X_y | ||
|
||
from .base_metric import BaseMetricLearner | ||
from .constraints import Constraints | ||
from .constraints import Constraints, wrap_pairs | ||
from ._util import vector_norm | ||
|
||
|
||
|
@@ -51,29 +51,37 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, | |
self.A0 = A0 | ||
self.verbose = verbose | ||
|
||
def _process_inputs(self, X, constraints, bounds): | ||
self.X_ = X = check_array(X) | ||
def _process_pairs(self, pairs, y, bounds): | ||
pairs, y = check_X_y(pairs, y, accept_sparse=False, | ||
ensure_2d=False, allow_nd=True) | ||
y = y.astype(bool) | ||
|
||
# check to make sure that no two constrained vectors are identical | ||
a,b,c,d = constraints | ||
no_ident = vector_norm(X[a] - X[b]) > 1e-9 | ||
a, b = a[no_ident], b[no_ident] | ||
no_ident = vector_norm(X[c] - X[d]) > 1e-9 | ||
c, d = c[no_ident], d[no_ident] | ||
pos_pairs, neg_pairs = pairs[y], pairs[~y] | ||
pos_no_ident = vector_norm(pos_pairs[:, 0, :] - pos_pairs[:, 1, :]) > 1e-9 | ||
pos_pairs = pos_pairs[pos_no_ident] | ||
neg_no_ident = vector_norm(neg_pairs[:, 0, :] - neg_pairs[:, 1, :]) > 1e-9 | ||
neg_pairs = neg_pairs[neg_no_ident] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe showing a warning to the user when such pair is found and discarded is useful. in particular if a negative pair is made of two identical points, probably there is a problem with the way the user generated the pairs, or the dataset There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes I agree. I added it to the TODO |
||
# init bounds | ||
if bounds is None: | ||
X = np.unique(pairs.reshape(-1, pairs.shape[2]), axis=0) | ||
self.bounds_ = np.percentile(pairwise_distances(X), (5, 95)) | ||
else: | ||
assert len(bounds) == 2 | ||
self.bounds_ = bounds | ||
self.bounds_[self.bounds_==0] = 1e-9 | ||
# init metric | ||
if self.A0 is None: | ||
self.A_ = np.identity(X.shape[1]) | ||
self.A_ = np.identity(pairs.shape[2]) | ||
else: | ||
self.A_ = check_array(self.A0) | ||
return a,b,c,d | ||
pairs = np.vstack([pos_pairs, neg_pairs]) | ||
y = np.hstack([np.ones(len(pos_pairs)), np.zeros(len(neg_pairs))]) | ||
y = y.astype(bool) | ||
return pairs, y | ||
|
||
|
||
def fit(self, X, constraints, bounds=None): | ||
def fit(self, pairs, y, bounds=None): | ||
"""Learn the ITML model. | ||
|
||
Parameters | ||
|
@@ -86,17 +94,18 @@ def fit(self, X, constraints, bounds=None): | |
bounds : list (pos,neg) pairs, optional | ||
bounds on similarity, s.t. d(X[a],X[b]) < pos and d(X[c],X[d]) > neg | ||
""" | ||
a,b,c,d = self._process_inputs(X, constraints, bounds) | ||
pairs, y = self._process_pairs(pairs, y, bounds) | ||
gamma = self.gamma | ||
num_pos = len(a) | ||
num_neg = len(c) | ||
pos_pairs, neg_pairs = pairs[y], pairs[~y] | ||
num_pos = len(pos_pairs) | ||
num_neg = len(neg_pairs) | ||
_lambda = np.zeros(num_pos + num_neg) | ||
lambdaold = np.zeros_like(_lambda) | ||
gamma_proj = 1. if gamma is np.inf else gamma/(gamma+1.) | ||
pos_bhat = np.zeros(num_pos) + self.bounds_[0] | ||
neg_bhat = np.zeros(num_neg) + self.bounds_[1] | ||
pos_vv = self.X_[a] - self.X_[b] | ||
neg_vv = self.X_[c] - self.X_[d] | ||
pos_vv = pos_pairs[:, 0, :] - pos_pairs[:, 1, :] | ||
neg_vv = neg_pairs[:, 0, :] - neg_pairs[:, 1, :] | ||
A = self.A_ | ||
|
||
for it in xrange(self.max_iter): | ||
|
@@ -195,4 +204,5 @@ def fit(self, X, y, random_state=np.random): | |
random_state=random_state) | ||
pos_neg = c.positive_negative_pairs(num_constraints, | ||
random_state=random_state) | ||
return ITML.fit(self, X, pos_neg, bounds=self.bounds) | ||
pairs, y = wrap_pairs(X, pos_neg) | ||
return ITML.fit(self, pairs, y, bounds=self.bounds) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would make sense that this function
_process_pairs
is shared across the class of pair metric learners. For instance, ruling the potential pairs that are identical is useful for all algorithmsThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes I agree, I added it to the small features TODO list at the end of the main issue: #91 (comment)