
Commit f7b422f

Minor PEP8 tweaks, fixing return value
1 parent b9402a9 commit f7b422f

7 files changed: 23 additions & 15 deletions


metric_learn/base_metric.py
Lines changed: 8 additions & 6 deletions

@@ -48,26 +48,28 @@ def transform(self, X=None):
 
   def get_params(self, deep=False):
     """Get parameters for this metric learner.
-
+
     Parameters
     ----------
     deep: boolean, optional
-        @WARNING doesn't do anything, only exists because scikit-learn has this on BaseEstimator.
-
+        @WARNING doesn't do anything, only exists because
+        scikit-learn has this on BaseEstimator.
+
     Returns
     -------
     params : mapping of string to any
         Parameter names mapped to their values.
     """
     return self.params
-
+
   def set_params(self, **kwarg):
     """Set the parameters of this metric learner.
-
+
     Overwrites any default parameters or parameters specified in constructor.
-
+
     Returns
     -------
     self
     """
     self.params.update(kwarg)
+    return self
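The added return self is the "fixing return value" part of the commit message: it makes set_params chainable, matching scikit-learn's BaseEstimator convention. A minimal sketch of the pattern; the ToyLearner class is illustrative, not part of the library:

  class ToyLearner:
    def __init__(self):
      self.params = {'tol': 1e-5}

    def set_params(self, **kwarg):
      self.params.update(kwarg)
      return self  # returning self enables chaining

    def fit(self, X):
      return self

  # set_params(...).fit(...) now works in one expression:
  learner = ToyLearner().set_params(tol=1e-3).fit(X=None)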

metric_learn/itml.py
Lines changed: 2 additions & 1 deletion

@@ -73,6 +73,7 @@ def fit(self, X, constraints, bounds=None, A0=None, verbose=False):
     """
     a,b,c,d = self._process_inputs(X, constraints, bounds, A0)
     gamma = self.params['gamma']
+    conv_thresh = self.params['convergence_threshold']
     num_pos = len(a)
     num_neg = len(c)
     _lambda = np.zeros(num_pos + num_neg)
@@ -108,7 +109,7 @@ def fit(self, X, constraints, bounds=None, A0=None, verbose=False):
         conv = np.inf
         break
       conv = np.abs(lambdaold - _lambda).sum() / normsum
-      if conv < self.params['convergence_threshold']:
+      if conv < conv_thresh:
         break
       lambdaold = _lambda.copy()
     if verbose:
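Hoisting the self.params['convergence_threshold'] lookup into a local variable before the loop avoids a repeated dict lookup per iteration; the same micro-optimization appears in lsml.py and nca.py below. A rough standalone illustration (names are hypothetical; the timing difference is small and machine-dependent):

  import timeit

  params = {'convergence_threshold': 1e-3}

  def in_loop():
    for _ in range(1000):
      if 0.5 < params['convergence_threshold']:
        pass

  def hoisted():
    conv_thresh = params['convergence_threshold']
    for _ in range(1000):
      if 0.5 < conv_thresh:
        pass

  print(timeit.timeit(in_loop, number=1000))  # slower: 1000 dict lookups per call
  print(timeit.timeit(hoisted, number=1000))  # faster: one lookup per call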

metric_learn/lfda.py
Lines changed: 3 additions & 2 deletions

@@ -34,7 +34,7 @@ def __init__(self, dim=None, k=7, metric='weighted'):
     '''
     if metric not in ('weighted', 'orthonormalized', 'plain'):
       raise ValueError('Invalid metric: %r' % metric)
-
+
     self.params = {
       'dim': dim,
       'metric': metric,
@@ -100,7 +100,8 @@ def fit(self, X, Y):
     if self.params['dim'] == d:
       vals, vecs = scipy.linalg.eigh(tSb, tSw)
     else:
-      vals, vecs = scipy.sparse.linalg.eigsh(tSb, k=self.params['dim'], M=tSw, which='LA')
+      vals, vecs = scipy.sparse.linalg.eigsh(tSb, k=self.params['dim'], M=tSw,
+                                             which='LA')
 
     order = np.argsort(-vals)[:self.params['dim']]
     vals = vals[order]
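The wrapped eigsh call uses Python's implicit line continuation inside parentheses, keeping lines under PEP8's 79-character limit without backslashes. A self-contained sketch of the same generalized eigenproblem call, using random symmetric matrices (illustrative only):

  import numpy as np
  import scipy.sparse.linalg

  rng = np.random.RandomState(0)
  A = rng.rand(10, 10)
  tSb = A + A.T                       # symmetric, like the between-class scatter
  B = rng.rand(10, 10)
  tSw = B.dot(B.T) + 10 * np.eye(10)  # symmetric positive definite, like the within-class scatter
  vals, vecs = scipy.sparse.linalg.eigsh(tSb, k=3, M=tSw,
                                         which='LA')  # 3 largest algebraic eigenvalues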

metric_learn/lmnn.py
Lines changed: 3 additions & 3 deletions

@@ -47,9 +47,9 @@ def _process_inputs(self, X, labels):
     self.X = X
     self.L = np.eye(X.shape[1])
     required_k = np.bincount(self.label_inds).min()
-    k = self.params['k']
-    assert k <= required_k, ('not enough class labels for specified k' +
-                             ' (smallest class has %d)' % required_k)
+    assert self.params['k'] <= required_k, (
+        'not enough class labels for specified k'
+        ' (smallest class has %d)' % required_k)
 
   def fit(self, X, labels, verbose=False):
     k = self.params['k']
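The rewritten assert drops the explicit '+': adjacent string literals are concatenated at compile time, so the two fragments still form a single message before the '%' formatting applies. A quick check of that behavior:

  required_k = 3
  msg = ('not enough class labels for specified k'
         ' (smallest class has %d)' % required_k)
  assert msg == 'not enough class labels for specified k (smallest class has 3)'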

metric_learn/lsml.py
Lines changed: 2 additions & 1 deletion

@@ -68,10 +68,11 @@ def fit(self, X, constraints, weights=None, prior=None, verbose=False):
     step_sizes = np.logspace(-10, 0, 10)
     if verbose:
       print('initial loss', s_best)
+    tol = self.params['tol']
     for it in xrange(1, self.params['max_iter']+1):
       grad = self._gradient(self.M, prior_inv)
       grad_norm = scipy.linalg.norm(grad)
-      if grad_norm < self.params['tol']:
+      if grad_norm < tol:
         break
       if verbose:
         print('gradient norm', grad_norm)
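In the context lines, step_sizes = np.logspace(-10, 0, 10) builds the grid of candidate step sizes searched each iteration: ten values evenly spaced on a log scale from 1e-10 up to 1. A quick look at its output:

  import numpy as np

  steps = np.logspace(-10, 0, 10)
  print(steps[0], steps[-1])  # 1e-10 1.0
  print(len(steps))           # 10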

metric_learn/nca.py
Lines changed: 2 additions & 1 deletion

@@ -34,6 +34,7 @@ def fit(self, X, labels):
     dX = X[:,None] - X[None]  # shape (n, n, d)
     tmp = np.einsum('...i,...j->...ij', dX, dX)  # shape (n, n, d, d)
     masks = labels[:,None] == labels[None]
+    learning_rate = self.params['learning_rate']
     for it in xrange(self.params['max_iter']):
       for i, label in enumerate(labels):
         mask = masks[i]
@@ -45,7 +46,7 @@ def fit(self, X, labels):
 
         t = softmax[:, None, None] * tmp[i]  # shape (n, d, d)
         d = softmax[mask].sum() * t.sum(axis=0) - t[mask].sum(axis=0)
-        A += self.params['learning_rate'] * A.dot(d)
+        A += learning_rate * A.dot(d)
 
     self.X = X
     self.A = A
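The einsum in the context lines computes an outer product for every pair of difference vectors. A small check against np.outer (sizes are arbitrary):

  import numpy as np

  rng = np.random.RandomState(0)
  X = rng.rand(4, 3)                           # n=4 points in d=3 dimensions
  dX = X[:, None] - X[None]                    # pairwise differences, shape (4, 4, 3)
  tmp = np.einsum('...i,...j->...ij', dX, dX)  # shape (4, 4, 3, 3)
  assert np.allclose(tmp[1, 2], np.outer(dX[1, 2], dX[1, 2]))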

metric_learn/sdml.py
Lines changed: 3 additions & 1 deletion

@@ -22,6 +22,7 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True):
     '''
     balance_param: trade off between sparsity and M0 prior
     sparsity_param: trade off between optimizer and sparseness (see graph_lasso)
+    use_cov: controls prior matrix, will use the identity if use_cov=False
     '''
     self.params = {
       'balance_param': balance_param,
@@ -52,7 +53,8 @@ def fit(self, X, W, verbose=False):
     emp_cov = pinvh(P)
     # hack: ensure positive semidefinite
     emp_cov = emp_cov.T.dot(emp_cov)
-    self.M, _ = graph_lasso(emp_cov, self.params['sparsity_param'], verbose=verbose)
+    self.M, _ = graph_lasso(emp_cov, self.params['sparsity_param'],
+                            verbose=verbose)
     return self
 
   @classmethod
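The "hack" kept in the context lines works because A.T.dot(A) is symmetric positive semidefinite for any real matrix A, which the graphical lasso expects of a covariance input. A quick numerical check (illustrative):

  import numpy as np

  rng = np.random.RandomState(0)
  A = rng.rand(5, 5) - 0.5  # arbitrary, possibly indefinite
  psd = A.T.dot(A)          # symmetric PSD by construction
  assert np.all(np.linalg.eigvalsh(psd) >= -1e-10)  # eigenvalues nonnegative up to rounding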
