Skip to content

Commit 676165a

Browse files
maxi-marufo (Maximiliano Marufo da Silva)
authored, and
Maximiliano Marufo da Silva
committed
Adds warm_start parameter to SCML and SCML_Supervised
1 parent dc7e449 commit 676165a

File tree

1 file changed

+36
-15
lines changed

1 file changed

+36
-15
lines changed

metric_learn/scml.py

Lines changed: 36 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,8 @@ class _BaseSCML(MahalanobisMixin):
2323

2424
def __init__(self, beta=1e-5, basis='triplet_diffs', n_basis=None,
2525
gamma=5e-3, max_iter=10000, output_iter=500, batch_size=10,
26-
verbose=False, preprocessor=None, random_state=None):
26+
verbose=False, preprocessor=None, random_state=None,
27+
warm_start=False):
2728
self.beta = beta
2829
self.basis = basis
2930
self.n_basis = n_basis
@@ -34,6 +35,7 @@ def __init__(self, beta=1e-5, basis='triplet_diffs', n_basis=None,
3435
self.verbose = verbose
3536
self.preprocessor = preprocessor
3637
self.random_state = random_state
38+
self.warm_start = warm_start
3739
super(_BaseSCML, self).__init__(preprocessor)
3840

3941
def _fit(self, triplets, basis=None, n_basis=None):
@@ -74,10 +76,13 @@ def _fit(self, triplets, basis=None, n_basis=None):
7476

7577
n_triplets = triplets.shape[0]
7678

77-
# weight vector
78-
w = np.zeros((1, n_basis))
79-
# avarage obj gradient wrt weights
80-
avg_grad_w = np.zeros((1, n_basis))
79+
if self.warm_start is False or not hasattr(self, "w_"):
80+
# weight vector
81+
self.w_ = np.zeros((1, n_basis))
82+
# avarage obj gradient wrt weights
83+
self.avg_grad_w_ = np.zeros((1, n_basis))
84+
# l2 norm in time of all obj gradients wrt weights
85+
self.ada_grad_w_ = np.zeros((1, n_basis))
8186

8287
# l2 norm in time of all obj gradients wrt weights
8388
ada_grad_w = np.zeros((1, n_basis))
@@ -93,27 +98,28 @@ def _fit(self, triplets, basis=None, n_basis=None):
9398

9499
idx = rand_int[iter]
95100

96-
slack_val = 1 + np.matmul(dist_diff[idx, :], w.T)
101+
slack_val = 1 + np.matmul(dist_diff[idx, :], self.w_.T)
97102
slack_mask = np.squeeze(slack_val > 0, axis=1)
98103

99104
grad_w = np.sum(dist_diff[idx[slack_mask], :],
100105
axis=0, keepdims=True)/self.batch_size
101-
avg_grad_w = (iter * avg_grad_w + grad_w) / (iter+1)
102106

103-
ada_grad_w = np.sqrt(np.square(ada_grad_w) + np.square(grad_w))
107+
self.avg_grad_w_ = (iter * self.avg_grad_w_ + grad_w) / (iter + 1)
104108

105-
scale_f = -(iter+1) / (self.gamma * (delta + ada_grad_w))
109+
self.ada_grad_w_ = np.sqrt(np.square(self.ada_grad_w_) + np.square(grad_w))
110+
111+
scale_f = -(iter+1) / (self.gamma * (delta + self.ada_grad_w_))
106112

107113
# proximal operator with negative trimming equivalent
108-
w = scale_f * np.minimum(avg_grad_w + self.beta, 0)
114+
self.w_ = scale_f * np.minimum(self.avg_grad_w_ + self.beta, 0)
109115

110116
if (iter + 1) % self.output_iter == 0:
111117
# regularization part of obj function
112-
obj1 = np.sum(w)*self.beta
118+
obj1 = np.sum(self.w_)*self.beta
113119

114120
# Every triplet distance difference in the space given by L
115121
# plus a slack of one
116-
slack_val = 1 + np.matmul(dist_diff, w.T)
122+
slack_val = 1 + np.matmul(dist_diff, self.w_.T)
117123
# Mask of places with positive slack
118124
slack_mask = slack_val > 0
119125

@@ -129,7 +135,7 @@ def _fit(self, triplets, basis=None, n_basis=None):
129135
# update the best
130136
if obj < best_obj:
131137
best_obj = obj
132-
best_w = w
138+
best_w = self.w_
133139

134140
if self.verbose:
135141
print("max iteration reached.")
@@ -359,6 +365,13 @@ class SCML(_BaseSCML, _TripletsClassifierMixin):
359365
random_state : int or numpy.RandomState or None, optional (default=None)
360366
A pseudo random number generator object or a seed for it if int.
361367
368+
warm_start : bool, default=False
369+
When set to True, reuse the solution of the previous call to fit as
370+
initialization, otherwise, just erase the previous solution.
371+
Repeatedly calling fit when warm_start is True can result in a different
372+
solution than when calling fit a single time because of the way the data
373+
is shuffled.
374+
362375
Attributes
363376
----------
364377
components_ : `numpy.ndarray`, shape=(n_features, n_features)
@@ -469,6 +482,13 @@ class SCML_Supervised(_BaseSCML, TransformerMixin):
469482
random_state : int or numpy.RandomState or None, optional (default=None)
470483
A pseudo random number generator object or a seed for it if int.
471484
485+
warm_start : bool, default=False
486+
When set to True, reuse the solution of the previous call to fit as
487+
initialization, otherwise, just erase the previous solution.
488+
Repeatedly calling fit when warm_start is True can result in a different
489+
solution than when calling fit a single time because of the way the data
490+
is shuffled.
491+
472492
Attributes
473493
----------
474494
components_ : `numpy.ndarray`, shape=(n_features, n_features)
@@ -510,13 +530,14 @@ class SCML_Supervised(_BaseSCML, TransformerMixin):
510530
def __init__(self, k_genuine=3, k_impostor=10, beta=1e-5, basis='lda',
511531
n_basis=None, gamma=5e-3, max_iter=10000, output_iter=500,
512532
batch_size=10, verbose=False, preprocessor=None,
513-
random_state=None):
533+
random_state=None, warm_start=False):
514534
self.k_genuine = k_genuine
515535
self.k_impostor = k_impostor
516536
_BaseSCML.__init__(self, beta=beta, basis=basis, n_basis=n_basis,
517537
max_iter=max_iter, output_iter=output_iter,
518538
batch_size=batch_size, verbose=verbose,
519-
preprocessor=preprocessor, random_state=random_state)
539+
preprocessor=preprocessor, random_state=random_state,
540+
warm_start=warm_start)
520541

521542
def fit(self, X, y):
522543
"""Create constraints from labels and learn the SCML model.

0 commit comments

Comments (0)