Skip to content

Commit 6ef1268

Browse files
committed
Adds warm_start parameter to SCML and SCML_Supervised
1 parent 4e0c444 commit 6ef1268

File tree

1 file changed

+36
-15
lines changed

1 file changed

+36
-15
lines changed

metric_learn/scml.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ class _BaseSCML(MahalanobisMixin):
2323

2424
def __init__(self, beta=1e-5, basis='triplet_diffs', n_basis=None,
2525
gamma=5e-3, max_iter=10000, output_iter=500, batch_size=10,
26-
verbose=False, preprocessor=None, random_state=None):
26+
verbose=False, preprocessor=None, random_state=None,
27+
warm_start=False):
2728
self.beta = beta
2829
self.basis = basis
2930
self.n_basis = n_basis
@@ -34,6 +35,7 @@ def __init__(self, beta=1e-5, basis='triplet_diffs', n_basis=None,
3435
self.verbose = verbose
3536
self.preprocessor = preprocessor
3637
self.random_state = random_state
38+
self.warm_start = warm_start
3739
super(_BaseSCML, self).__init__(preprocessor)
3840

3941
def _fit(self, triplets, basis=None, n_basis=None):
@@ -74,10 +76,13 @@ def _fit(self, triplets, basis=None, n_basis=None):
7476

7577
n_triplets = triplets.shape[0]
7678

77-
# weight vector
78-
w = np.zeros((1, n_basis))
79-
# average obj gradient wrt weights
80-
avg_grad_w = np.zeros((1, n_basis))
79+
if self.warm_start is False or not hasattr(self, "w_"):
80+
# weight vector
81+
self.w_ = np.zeros((1, n_basis))
82+
# average obj gradient wrt weights
83+
self.avg_grad_w_ = np.zeros((1, n_basis))
84+
# l2 norm in time of all obj gradients wrt weights
85+
self.ada_grad_w_ = np.zeros((1, n_basis))
8186

8287
# l2 norm in time of all obj gradients wrt weights
8388
ada_grad_w = np.zeros((1, n_basis))
@@ -93,27 +98,28 @@ def _fit(self, triplets, basis=None, n_basis=None):
9398

9499
idx = rand_int[iter]
95100

96-
slack_val = 1 + np.matmul(dist_diff[idx, :], w.T)
101+
slack_val = 1 + np.matmul(dist_diff[idx, :], self.w_.T)
97102
slack_mask = np.squeeze(slack_val > 0, axis=1)
98103

99104
grad_w = np.sum(dist_diff[idx[slack_mask], :],
100105
axis=0, keepdims=True)/self.batch_size
101-
avg_grad_w = (iter * avg_grad_w + grad_w) / (iter+1)
102106

103-
ada_grad_w = np.sqrt(np.square(ada_grad_w) + np.square(grad_w))
107+
self.avg_grad_w_ = (iter * self.avg_grad_w_ + grad_w) / (iter + 1)
104108

105-
scale_f = -(iter+1) / (self.gamma * (delta + ada_grad_w))
109+
self.ada_grad_w_ = np.sqrt(np.square(self.ada_grad_w_) + np.square(grad_w))
110+
111+
scale_f = -(iter+1) / (self.gamma * (delta + self.ada_grad_w_))
106112

107113
# proximal operator with negative trimming equivalent
108-
w = scale_f * np.minimum(avg_grad_w + self.beta, 0)
114+
self.w_ = scale_f * np.minimum(self.avg_grad_w_ + self.beta, 0)
109115

110116
if (iter + 1) % self.output_iter == 0:
111117
# regularization part of obj function
112-
obj1 = np.sum(w)*self.beta
118+
obj1 = np.sum(self.w_)*self.beta
113119

114120
# Every triplet distance difference in the space given by L
115121
# plus a slack of one
116-
slack_val = 1 + np.matmul(dist_diff, w.T)
122+
slack_val = 1 + np.matmul(dist_diff, self.w_.T)
117123
# Mask of places with positive slack
118124
slack_mask = slack_val > 0
119125

@@ -129,7 +135,7 @@ def _fit(self, triplets, basis=None, n_basis=None):
129135
# update the best
130136
if obj < best_obj:
131137
best_obj = obj
132-
best_w = w
138+
best_w = self.w_
133139

134140
if self.verbose:
135141
print("max iteration reached.")
@@ -355,6 +361,13 @@ class SCML(_BaseSCML, _TripletsClassifierMixin):
355361
random_state : int or numpy.RandomState or None, optional (default=None)
356362
A pseudo random number generator object or a seed for it if int.
357363
364+
warm_start : bool, default=False
365+
When set to True, reuse the solution of the previous call to fit as
366+
initialization, otherwise, just erase the previous solution.
367+
Repeatedly calling fit when warm_start is True can result in a different
368+
solution than when calling fit a single time because of the way the data
369+
is shuffled.
370+
358371
Attributes
359372
----------
360373
components_ : `numpy.ndarray`, shape=(n_features, n_features)
@@ -465,6 +478,13 @@ class SCML_Supervised(_BaseSCML, TransformerMixin):
465478
random_state : int or numpy.RandomState or None, optional (default=None)
466479
A pseudo random number generator object or a seed for it if int.
467480
481+
warm_start : bool, default=False
482+
When set to True, reuse the solution of the previous call to fit as
483+
initialization, otherwise, just erase the previous solution.
484+
Repeatedly calling fit when warm_start is True can result in a different
485+
solution than when calling fit a single time because of the way the data
486+
is shuffled.
487+
468488
Attributes
469489
----------
470490
components_ : `numpy.ndarray`, shape=(n_features, n_features)
@@ -506,13 +526,14 @@ class SCML_Supervised(_BaseSCML, TransformerMixin):
506526
def __init__(self, k_genuine=3, k_impostor=10, beta=1e-5, basis='lda',
507527
n_basis=None, gamma=5e-3, max_iter=10000, output_iter=500,
508528
batch_size=10, verbose=False, preprocessor=None,
509-
random_state=None):
529+
random_state=None, warm_start=False):
510530
self.k_genuine = k_genuine
511531
self.k_impostor = k_impostor
512532
_BaseSCML.__init__(self, beta=beta, basis=basis, n_basis=n_basis,
513533
max_iter=max_iter, output_iter=output_iter,
514534
batch_size=batch_size, verbose=verbose,
515-
preprocessor=preprocessor, random_state=random_state)
535+
preprocessor=preprocessor, random_state=random_state,
536+
warm_start=warm_start)
516537

517538
def fit(self, X, y):
518539
"""Create constraints from labels and learn the SCML model.

0 commit comments

Comments
 (0)