
Commit b9402a9

Merge pull request #18 from svecon/master
Added get_params and set_params; moved parameters into a params dictionary
2 parents: b14d96c + 947a66b

7 files changed, 79 insertions(+), 39 deletions(-)


metric_learn/base_metric.py

Lines changed: 26 additions & 0 deletions
@@ -45,3 +45,29 @@ def transform(self, X=None):
       X = self.X
     L = self.transformer()
     return X.dot(L.T)
+
+  def get_params(self, deep=False):
+    """Get parameters for this metric learner.
+
+    Parameters
+    ----------
+    deep: boolean, optional
+        @WARNING doesn't do anything, only exists because scikit-learn has this on BaseEstimator.
+
+    Returns
+    -------
+    params : mapping of string to any
+        Parameter names mapped to their values.
+    """
+    return self.params
+
+  def set_params(self, **kwarg):
+    """Set the parameters of this metric learner.
+
+    Overwrites any default parameters or parameters specified in constructor.
+
+    Returns
+    -------
+    self
+    """
+    self.params.update(kwarg)
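
Taken together, these base-class methods expose each learner's hyperparameters as a plain dictionary. A minimal usage sketch follows (the top-level import path is an assumption; also note that set_params as committed updates the dict but falls off the end without actually returning self, despite its docstring):

from metric_learn import ITML  # top-level import path assumed

itml = ITML(gamma=1., max_iters=500)
params = itml.get_params()
print(params['gamma'], params['convergence_threshold'])  # 1.0 0.001

# Override a default after construction; fit() reads self.params, so the
# new value is picked up on the next fit() call.
itml.set_params(gamma=0.5)
print(itml.get_params()['gamma'])  # 0.5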

metric_learn/itml.py

Lines changed: 8 additions & 6 deletions
@@ -30,9 +30,11 @@ def __init__(self, gamma=1., max_iters=1000, convergence_threshold=1e-3):
     max_iters : int, optional
     convergence_threshold : float, optional
     """
-    self.gamma = gamma
-    self.max_iters = max_iters
-    self.convergence_threshold = convergence_threshold
+    self.params = {
+      'gamma': gamma,
+      'max_iters': max_iters,
+      'convergence_threshold': convergence_threshold,
+    }
 
   def _process_inputs(self, X, constraints, bounds, A0):
     self.X = X
@@ -70,7 +72,7 @@ def fit(self, X, constraints, bounds=None, A0=None, verbose=False):
       initial regularization matrix, defaults to identity
     """
     a,b,c,d = self._process_inputs(X, constraints, bounds, A0)
-    gamma = self.gamma
+    gamma = self.params['gamma']
     num_pos = len(a)
     num_neg = len(c)
     _lambda = np.zeros(num_pos + num_neg)
@@ -80,7 +82,7 @@ def fit(self, X, constraints, bounds=None, A0=None, verbose=False):
     neg_bhat = np.zeros(num_neg) + self.bounds[1]
     A = self.A
 
-    for it in xrange(self.max_iters):
+    for it in xrange(self.params['max_iters']):
       # update positives
       vv = self.X[a] - self.X[b]
       for i,v in enumerate(vv):
@@ -106,7 +108,7 @@ def fit(self, X, constraints, bounds=None, A0=None, verbose=False):
         conv = np.inf
         break
       conv = np.abs(lambdaold - _lambda).sum() / normsum
-      if conv < self.convergence_threshold:
+      if conv < self.params['convergence_threshold']:
         break
      lambdaold = _lambda.copy()
    if verbose:

metric_learn/lfda.py

Lines changed: 17 additions & 13 deletions
@@ -34,9 +34,12 @@ def __init__(self, dim=None, k=7, metric='weighted'):
     '''
     if metric not in ('weighted', 'orthonormalized', 'plain'):
       raise ValueError('Invalid metric: %r' % metric)
-    self.dim = dim
-    self.metric = metric
-    self.k = k
+
+    self.params = {
+      'dim': dim,
+      'metric': metric,
+      'k': k,
+    }
 
   def transformer(self):
     return self._transformer
@@ -48,12 +51,12 @@ def _process_inputs(self, X, Y):
     unique_classes, Y = np.unique(Y, return_inverse=True)
     num_classes = len(unique_classes)
 
-    if self.dim is None:
-      self.dim = d
-    elif not 0 < self.dim <= d:
+    if self.params['dim'] is None:
+      self.params['dim'] = d
+    elif not 0 < self.params['dim'] <= d:
       raise ValueError('Invalid embedding dimension, must be in [1,%d]' % d)
 
-    if not 0 < self.k < d:
+    if not 0 < self.params['k'] < d:
       raise ValueError('Invalid k, must be in [0,%d]' % (d-1))
 
     return X, Y, num_classes, n, d
@@ -74,7 +77,7 @@ def fit(self, X, Y):
       # classwise affinity matrix
       dist = pairwise_distances(Xc, metric='l2', squared=True)
       # distances to k-th nearest neighbor
-      k = min(self.k, nc-1)
+      k = min(self.params['k'], nc-1)
       sigma = np.sqrt(np.partition(dist, k, axis=0)[:,k])
 
       local_scale = np.outer(sigma, sigma)
@@ -94,21 +97,22 @@ def fit(self, X, Y):
     tSw += tSw.T
     tSw /= 2
 
-    if self.dim == d:
+    if self.params['dim'] == d:
       vals, vecs = scipy.linalg.eigh(tSb, tSw)
     else:
-      vals, vecs = scipy.sparse.linalg.eigsh(tSb, k=self.dim, M=tSw, which='LA')
+      vals, vecs = scipy.sparse.linalg.eigsh(tSb, k=self.params['dim'], M=tSw, which='LA')
 
-    order = np.argsort(-vals)[:self.dim]
+    order = np.argsort(-vals)[:self.params['dim']]
     vals = vals[order]
     vecs = vecs[:,order]
 
-    if self.metric == 'weighted':
+    if self.params['metric'] == 'weighted':
       vecs *= np.sqrt(vals)
-    elif self.metric == 'orthonormalized':
+    elif self.params['metric'] == 'orthonormalized':
       vecs, _ = np.linalg.qr(vecs)
 
     self._transformer = vecs.T
+    return self
 
 
 def _sum_outer(x):
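
Besides the dictionary migration, this diff also makes LFDA.fit return self, so fitting and transforming can be chained, and _process_inputs now writes a resolved dim=None back into the params dict. A sketch, using random placeholder data and an assumed import path:

import numpy as np
from metric_learn import LFDA  # import path assumed

X = np.random.randn(50, 4)
Y = np.repeat([0, 1], 25)

# fit() now returns self, so the calls chain:
Xt = LFDA(k=3).fit(X, Y).transform(X)

# dim=None is resolved to the data dimension during fit and written back,
# so get_params() afterwards reflects the value actually used:
lfda = LFDA(k=3).fit(X, Y)
print(lfda.get_params()['dim'])  # 4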

metric_learn/lsml.py

Lines changed: 6 additions & 4 deletions
@@ -24,8 +24,10 @@ def __init__(self, tol=1e-3, max_iter=1000):
     tol : float, optional
     max_iter : int, optional
     """
-    self.tol = tol
-    self.max_iter = max_iter
+    self.params = {
+      'tol': tol,
+      'max_iter': max_iter,
+    }
 
   def _prepare_inputs(self, X, constraints, weights, prior):
     self.X = X
@@ -66,10 +68,10 @@ def fit(self, X, constraints, weights=None, prior=None, verbose=False):
     step_sizes = np.logspace(-10, 0, 10)
     if verbose:
       print('initial loss', s_best)
-    for it in xrange(1, self.max_iter+1):
+    for it in xrange(1, self.params['max_iter']+1):
       grad = self._gradient(self.M, prior_inv)
       grad_norm = scipy.linalg.norm(grad)
-      if grad_norm < self.tol:
+      if grad_norm < self.params['tol']:
         break
       if verbose:
         print('gradient norm', grad_norm)

metric_learn/nca.py

Lines changed: 6 additions & 4 deletions
@@ -11,8 +11,10 @@
 
 class NCA(BaseMetricLearner):
   def __init__(self, max_iter=100, learning_rate=0.01):
-    self.max_iter = max_iter
-    self.learning_rate = learning_rate
+    self.params = {
+      'max_iter': max_iter,
+      'learning_rate': learning_rate,
+    }
     self.A = None
 
   def transformer(self):
@@ -32,7 +34,7 @@ def fit(self, X, labels):
     dX = X[:,None] - X[None]  # shape (n, n, d)
     tmp = np.einsum('...i,...j->...ij', dX, dX)  # shape (n, n, d, d)
     masks = labels[:,None] == labels[None]
-    for it in xrange(self.max_iter):
+    for it in xrange(self.params['max_iter']):
       for i, label in enumerate(labels):
         mask = masks[i]
         Ax = A.dot(X.T).T  # shape (n, d)
@@ -43,7 +45,7 @@ def fit(self, X, labels):
 
         t = softmax[:, None, None] * tmp[i]  # shape (n, d, d)
         d = softmax[mask].sum() * t.sum(axis=0) - t[mask].sum(axis=0)
-        A += self.learning_rate * A.dot(d)
+        A += self.params['learning_rate'] * A.dot(d)
 
     self.X = X
     self.A = A

metric_learn/rca.py

Lines changed: 8 additions & 6 deletions
@@ -27,7 +27,9 @@ def __init__(self, dim=None):
     dim : int, optional
       embedding dimension (default: original dimension of data)
     """
-    self.dim = dim
+    self.params = {
+      'dim': dim,
+    }
 
   def transformer(self):
     return self._transformer
@@ -37,9 +39,9 @@ def _process_inputs(self, X, Y):
     self.X = X
     n, d = X.shape
 
-    if self.dim is None:
-      self.dim = d
-    elif not 0 < self.dim <= d:
+    if self.params['dim'] is None:
+      self.params['dim'] = d
+    elif not 0 < self.params['dim'] <= d:
       raise ValueError('Invalid embedding dimension, must be in [1,%d]' % d)
 
     Y = np.asanyarray(Y)
@@ -75,11 +77,11 @@ def fit(self, data, chunks):
     inner_cov = np.cov(chunk_data, rowvar=0, bias=1)
 
     # Fisher Linear Discriminant projection
-    if self.dim < d:
+    if self.params['dim'] < d:
       total_cov = np.cov(data[chunk_mask], rowvar=0)
       tmp = np.linalg.lstsq(total_cov, inner_cov)[0]
       vals, vecs = np.linalg.eig(tmp)
-      inds = np.argsort(vals)[:self.dim]
+      inds = np.argsort(vals)[:self.params['dim']]
       A = vecs[:,inds]
       inner_cov = A.T.dot(inner_cov).dot(A)
       self._transformer = _inv_sqrtm(inner_cov).dot(A.T)
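
Also worth noting for the new API as a whole: get_params returns the params dict itself, not a copy, so mutating the returned mapping edits the learner in place. A sketch with RCA (import path assumed):

from metric_learn import RCA  # import path assumed

rca = RCA()
p = rca.get_params()
p['dim'] = 2              # mutates the learner's own dict
print(rca.params['dim'])  # 2 -- same object

scikit-learn's get_params builds a fresh dict on each call, so code that treats the result as a scratch copy will behave differently here.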

metric_learn/sdml.py

Lines changed: 8 additions & 6 deletions
@@ -23,14 +23,16 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True):
     balance_param: trade off between sparsity and M0 prior
     sparsity_param: trade off between optimizer and sparseness (see graph_lasso)
     '''
-    self.balance_param = balance_param
-    self.sparsity_param = sparsity_param
-    self.use_cov = use_cov
+    self.params = {
+      'balance_param': balance_param,
+      'sparsity_param': sparsity_param,
+      'use_cov': use_cov,
+    }
 
   def _prepare_inputs(self, X, W):
     self.X = X
     # set up prior M
-    if self.use_cov:
+    if self.params['use_cov']:
       self.M = np.cov(X.T)
     else:
       self.M = np.identity(X.shape[1])
@@ -46,11 +48,11 @@ def fit(self, X, W, verbose=False):
     W: connectivity graph, (n x n). +1 for positive pairs, -1 for negative.
     """
     self._prepare_inputs(X, W)
-    P = pinvh(self.M) + self.balance_param * self.loss_matrix
+    P = pinvh(self.M) + self.params['balance_param'] * self.loss_matrix
     emp_cov = pinvh(P)
     # hack: ensure positive semidefinite
     emp_cov = emp_cov.T.dot(emp_cov)
-    self.M, _ = graph_lasso(emp_cov, self.sparsity_param, verbose=verbose)
+    self.M, _ = graph_lasso(emp_cov, self.params['sparsity_param'], verbose=verbose)
     return self
 
   @classmethod
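
The docstring's reference to BaseEstimator hints at the payoff of this change: with get_params returning exactly the constructor arguments, duck-typed scikit-learn helpers such as sklearn.base.clone can rebuild one of these learners from its params. A sketch, assuming a scikit-learn installation and the usual import path:

from sklearn.base import clone
from metric_learn import SDML  # import path assumed

sdml = SDML(balance_param=0.7)
sdml_copy = clone(sdml)  # effectively reconstructs SDML(**sdml.get_params())
print(sdml_copy.get_params()['balance_param'])  # 0.7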
