@@ -23,7 +23,8 @@ class _BaseSCML(MahalanobisMixin):
 
   def __init__(self, beta=1e-5, basis='triplet_diffs', n_basis=None,
                gamma=5e-3, max_iter=10000, output_iter=500, batch_size=10,
-               verbose=False, preprocessor=None, random_state=None):
+               verbose=False, preprocessor=None, random_state=None,
+               warm_start=False):
     self.beta = beta
     self.basis = basis
     self.n_basis = n_basis
@@ -34,6 +35,7 @@ def __init__(self, beta=1e-5, basis='triplet_diffs', n_basis=None,
     self.verbose = verbose
     self.preprocessor = preprocessor
     self.random_state = random_state
+    self.warm_start = warm_start
     super(_BaseSCML, self).__init__(preprocessor)
 
   def _fit(self, triplets, basis=None, n_basis=None):
@@ -74,10 +76,13 @@ def _fit(self, triplets, basis=None, n_basis=None):
 
     n_triplets = triplets.shape[0]
 
-    # weight vector
-    w = np.zeros((1, n_basis))
-    # average obj gradient wrt weights
-    avg_grad_w = np.zeros((1, n_basis))
+    if self.warm_start is False or not hasattr(self, "w_"):
+      # weight vector
+      self.w_ = np.zeros((1, n_basis))
+      # average obj gradient wrt weights
+      self.avg_grad_w_ = np.zeros((1, n_basis))
+      # l2 norm in time of all obj gradients wrt weights
+      self.ada_grad_w_ = np.zeros((1, n_basis))
 
     # l2 norm in time of all obj gradients wrt weights
     ada_grad_w = np.zeros((1, n_basis))
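The hunk above only resets the optimizer state when `warm_start` is off or no previous solution exists. A minimal standalone sketch of that re-initialization idiom (the `_WarmStartSketch` class is hypothetical and only illustrates the condition used in the diff):

```python
import numpy as np


class _WarmStartSketch:
  """Toy class illustrating the conditional re-initialization above."""

  def __init__(self, n_basis, warm_start=False):
    self.n_basis = n_basis
    self.warm_start = warm_start

  def _fit(self):
    # Reset the optimizer state unless warm_start is on and a previous
    # fit already created w_ (same condition as in the diff).
    if self.warm_start is False or not hasattr(self, "w_"):
      self.w_ = np.zeros((1, self.n_basis))
      self.avg_grad_w_ = np.zeros((1, self.n_basis))
      self.ada_grad_w_ = np.zeros((1, self.n_basis))
    return self


sketch = _WarmStartSketch(n_basis=5, warm_start=True)
sketch._fit()           # first call: no w_ yet, so the state is created
sketch.w_ += 1.0        # pretend the optimizer moved the weights
sketch._fit()           # second call: w_ is kept because warm_start=True
print(sketch.w_[0, 0])  # 1.0, the previous solution survived
```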
@@ -93,27 +98,28 @@ def _fit(self, triplets, basis=None, n_basis=None):
 
       idx = rand_int[iter]
 
-      slack_val = 1 + np.matmul(dist_diff[idx, :], w.T)
+      slack_val = 1 + np.matmul(dist_diff[idx, :], self.w_.T)
       slack_mask = np.squeeze(slack_val > 0, axis=1)
 
       grad_w = np.sum(dist_diff[idx[slack_mask], :],
                       axis=0, keepdims=True)/self.batch_size
-      avg_grad_w = (iter * avg_grad_w + grad_w) / (iter + 1)
 
-      ada_grad_w = np.sqrt(np.square(ada_grad_w) + np.square(grad_w))
+      self.avg_grad_w_ = (iter * self.avg_grad_w_ + grad_w) / (iter + 1)
 
-      scale_f = -(iter + 1) / (self.gamma * (delta + ada_grad_w))
+      self.ada_grad_w_ = np.sqrt(np.square(self.ada_grad_w_) + np.square(grad_w))
+
+      scale_f = -(iter + 1) / (self.gamma * (delta + self.ada_grad_w_))
 
       # proximal operator with negative trimming equivalent
-      w = scale_f * np.minimum(avg_grad_w + self.beta, 0)
+      self.w_ = scale_f * np.minimum(self.avg_grad_w_ + self.beta, 0)
 
       if (iter + 1) % self.output_iter == 0:
         # regularization part of obj function
-        obj1 = np.sum(w)*self.beta
+        obj1 = np.sum(self.w_)*self.beta
 
         # Every triplet distance difference in the space given by L
         # plus a slack of one
-        slack_val = 1 + np.matmul(dist_diff, w.T)
+        slack_val = 1 + np.matmul(dist_diff, self.w_.T)
         # Mask of places with positive slack
         slack_mask = slack_val > 0
 
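For readers less familiar with the optimizer, this loop body is an averaged-gradient step with an AdaGrad-style accumulator, followed by a proximal operator that trims negative entries so the learned weights stay non-negative. A self-contained numpy sketch of the same update on made-up data (the names `dist_diff`, `gamma`, `beta`, `delta` mirror the diff; all values here are placeholders):

```python
import numpy as np

rng = np.random.RandomState(0)
n_iter, batch_size, n_basis = 50, 10, 8
dist_diff = rng.randn(n_iter * batch_size, n_basis)
gamma, beta, delta = 5e-3, 1e-5, 0.1

w = np.zeros((1, n_basis))
avg_grad_w = np.zeros((1, n_basis))
ada_grad_w = np.zeros((1, n_basis))

for it in range(n_iter):          # "it" plays the role of "iter" in the diff
  idx = np.arange(it * batch_size, (it + 1) * batch_size)

  # hinge-style slack: only rows with positive slack contribute to the gradient
  slack_val = 1 + np.matmul(dist_diff[idx, :], w.T)
  slack_mask = np.squeeze(slack_val > 0, axis=1)
  grad_w = np.sum(dist_diff[idx[slack_mask], :],
                  axis=0, keepdims=True) / batch_size

  # running average of gradients and AdaGrad-style accumulated gradient norm
  avg_grad_w = (it * avg_grad_w + grad_w) / (it + 1)
  ada_grad_w = np.sqrt(np.square(ada_grad_w) + np.square(grad_w))

  # proximal step with negative trimming: the resulting weights are >= 0
  scale_f = -(it + 1) / (gamma * (delta + ada_grad_w))
  w = scale_f * np.minimum(avg_grad_w + beta, 0)

print(w.shape, bool((w >= 0).all()))   # (1, 8) True
```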
@@ -129,7 +135,7 @@ def _fit(self, triplets, basis=None, n_basis=None):
         # update the best
         if obj < best_obj:
           best_obj = obj
-          best_w = w
+          best_w = self.w_
 
     if self.verbose:
       print("max iteration reached.")
@@ -359,6 +365,13 @@ class SCML(_BaseSCML, _TripletsClassifierMixin):
   random_state : int or numpy.RandomState or None, optional (default=None)
     A pseudo random number generator object or a seed for it if int.
 
+  warm_start : bool, optional (default=False)
+    When set to True, reuse the solution of the previous call to fit as
+    initialization; otherwise, just erase the previous solution.
+    Repeatedly calling fit with warm_start=True can yield a different
+    solution than a single call to fit, because of the way the data is
+    shuffled.
+
   Attributes
   ----------
   components_ : `numpy.ndarray`, shape=(n_features, n_features)
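If the parameter lands as documented above, usage on the triplets learner would look roughly as follows (toy random triplets, purely illustrative; the `n_basis` and `max_iter` values are arbitrary):

```python
import numpy as np
from metric_learn import SCML

# toy triplets of shape (n_triplets, 3, n_features): the first point should
# end up closer to the second point than to the third one
rng = np.random.RandomState(42)
triplets_a = rng.randn(60, 3, 5)
triplets_b = rng.randn(60, 3, 5)

scml = SCML(n_basis=40, max_iter=1000, warm_start=True, random_state=42)
scml.fit(triplets_a)   # cold start: weights initialized to zeros
scml.fit(triplets_b)   # warm start: optimization resumes from the previous w_
```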
@@ -469,6 +482,13 @@ class SCML_Supervised(_BaseSCML, TransformerMixin):
   random_state : int or numpy.RandomState or None, optional (default=None)
     A pseudo random number generator object or a seed for it if int.
 
+  warm_start : bool, optional (default=False)
+    When set to True, reuse the solution of the previous call to fit as
+    initialization; otherwise, just erase the previous solution.
+    Repeatedly calling fit with warm_start=True can yield a different
+    solution than a single call to fit, because of the way the data is
+    shuffled.
+
   Attributes
   ----------
   components_ : `numpy.ndarray`, shape=(n_features, n_features)
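The same sketch for the supervised variant, assuming the constructor forwards `warm_start` as in the next hunk (toy data and arbitrary hyperparameters):

```python
import numpy as np
from metric_learn import SCML_Supervised

rng = np.random.RandomState(0)
X1, y1 = rng.randn(80, 4), rng.randint(0, 3, 80)
X2, y2 = rng.randn(80, 4), rng.randint(0, 3, 80)

scml = SCML_Supervised(k_genuine=3, k_impostor=5, max_iter=1000,
                       warm_start=True, random_state=0)
scml.fit(X1, y1)   # cold start on the first batch of labeled data
scml.fit(X2, y2)   # warm start: continues from the solution found on X1, y1
```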
@@ -510,13 +530,14 @@ class SCML_Supervised(_BaseSCML, TransformerMixin):
   def __init__(self, k_genuine=3, k_impostor=10, beta=1e-5, basis='lda',
                n_basis=None, gamma=5e-3, max_iter=10000, output_iter=500,
                batch_size=10, verbose=False, preprocessor=None,
-               random_state=None):
+               random_state=None, warm_start=False):
     self.k_genuine = k_genuine
     self.k_impostor = k_impostor
     _BaseSCML.__init__(self, beta=beta, basis=basis, n_basis=n_basis,
                        max_iter=max_iter, output_iter=output_iter,
                        batch_size=batch_size, verbose=verbose,
-                       preprocessor=preprocessor, random_state=random_state)
+                       preprocessor=preprocessor, random_state=random_state,
+                       warm_start=warm_start)
 
   def fit(self, X, y):
     """Create constraints from labels and learn the SCML model.
0 commit comments