@@ -23,7 +23,8 @@ class _BaseSCML(MahalanobisMixin):
 
   def __init__(self, beta=1e-5, basis='triplet_diffs', n_basis=None,
                gamma=5e-3, max_iter=10000, output_iter=500, batch_size=10,
-               verbose=False, preprocessor=None, random_state=None):
+               verbose=False, preprocessor=None, random_state=None,
+               warm_start=False):
     self.beta = beta
     self.basis = basis
     self.n_basis = n_basis
@@ -34,6 +35,7 @@ def __init__(self, beta=1e-5, basis='triplet_diffs', n_basis=None,
     self.verbose = verbose
     self.preprocessor = preprocessor
     self.random_state = random_state
+    self.warm_start = warm_start
     super(_BaseSCML, self).__init__(preprocessor)
 
   def _fit(self, triplets, basis=None, n_basis=None):
@@ -74,10 +76,13 @@ def _fit(self, triplets, basis=None, n_basis=None):
 
     n_triplets = triplets.shape[0]
 
-    # weight vector
-    w = np.zeros((1, n_basis))
-    # avarage obj gradient wrt weights
-    avg_grad_w = np.zeros((1, n_basis))
+    if not self.warm_start or not hasattr(self, "w_"):
+      # weight vector
+      self.w_ = np.zeros((1, n_basis))
+      # average obj gradient wrt weights
+      self.avg_grad_w_ = np.zeros((1, n_basis))
+      # l2 norm in time of all obj gradients wrt weights
+      self.ada_grad_w_ = np.zeros((1, n_basis))
 
     # l2 norm in time of all obj gradients wrt weights
     ada_grad_w = np.zeros((1, n_basis))
@@ -93,27 +98,28 @@ def _fit(self, triplets, basis=None, n_basis=None):
 
       idx = rand_int[iter]
 
-      slack_val = 1 + np.matmul(dist_diff[idx, :], w.T)
+      slack_val = 1 + np.matmul(dist_diff[idx, :], self.w_.T)
       slack_mask = np.squeeze(slack_val > 0, axis=1)
 
       grad_w = np.sum(dist_diff[idx[slack_mask], :],
                       axis=0, keepdims=True)/self.batch_size
-      avg_grad_w = (iter * avg_grad_w + grad_w) / (iter + 1)
 
-      ada_grad_w = np.sqrt(np.square(ada_grad_w) + np.square(grad_w))
+      self.avg_grad_w_ = (iter * self.avg_grad_w_ + grad_w) / (iter + 1)
 
-      scale_f = -(iter + 1) / (self.gamma * (delta + ada_grad_w))
+      self.ada_grad_w_ = np.sqrt(np.square(self.ada_grad_w_) + np.square(grad_w))
+
+      scale_f = -(iter + 1) / (self.gamma * (delta + self.ada_grad_w_))
 
       # proximal operator with negative trimming equivalent
-      w = scale_f * np.minimum(avg_grad_w + self.beta, 0)
+      self.w_ = scale_f * np.minimum(self.avg_grad_w_ + self.beta, 0)
 
       if (iter + 1) % self.output_iter == 0:
         # regularization part of obj function
-        obj1 = np.sum(w)*self.beta
+        obj1 = np.sum(self.w_)*self.beta
 
         # Every triplet distance difference in the space given by L
         # plus a slack of one
-        slack_val = 1 + np.matmul(dist_diff, w.T)
+        slack_val = 1 + np.matmul(dist_diff, self.w_.T)
         # Mask of places with positive slack
         slack_mask = slack_val > 0
 
@@ -129,7 +135,7 @@ def _fit(self, triplets, basis=None, n_basis=None):
         # update the best
         if obj < best_obj:
           best_obj = obj
-          best_w = w
+          best_w = self.w_
 
     if self.verbose:
       print("max iteration reached.")
@@ -355,6 +361,13 @@ class SCML(_BaseSCML, _TripletsClassifierMixin):
   random_state : int or numpy.RandomState or None, optional (default=None)
     A pseudo random number generator object or a seed for it if int.
 
+  warm_start : bool, optional (default=False)
+    When set to True, reuse the solution of the previous call to fit as
+    initialization; otherwise, just erase the previous solution.
+    Repeatedly calling fit with warm_start=True can result in a different
+    solution than a single call to fit, because of the way the data is
+    shuffled.
+
   Attributes
   ----------
   components_ : `numpy.ndarray`, shape=(n_features, n_features)
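
As a usage illustration for the docstring addition above, here is a sketch of fitting the triplets estimator twice with the new flag; the random triplets, the explicit `n_basis`, and `warm_start` itself (which only exists with this change) are assumptions made for the example and may need adjusting.

```python
import numpy as np
from metric_learn import SCML

rng = np.random.RandomState(42)
X = rng.randn(100, 5)
# toy (anchor, positive, negative) triplets, shape (n_triplets, 3, n_features)
triplets_a = X[rng.randint(100, size=(200, 3))]
triplets_b = X[rng.randint(100, size=(200, 3))]

scml = SCML(n_basis=40, max_iter=1000, warm_start=True, random_state=42)
scml.fit(triplets_a)   # first call starts from zero weights
scml.fit(triplets_b)   # second call reuses w_, avg_grad_w_ and ada_grad_w_
```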
@@ -465,6 +478,13 @@ class SCML_Supervised(_BaseSCML, TransformerMixin):
   random_state : int or numpy.RandomState or None, optional (default=None)
     A pseudo random number generator object or a seed for it if int.
 
+  warm_start : bool, optional (default=False)
+    When set to True, reuse the solution of the previous call to fit as
+    initialization; otherwise, just erase the previous solution.
+    Repeatedly calling fit with warm_start=True can result in a different
+    solution than a single call to fit, because of the way the data is
+    shuffled.
+
   Attributes
   ----------
   components_ : `numpy.ndarray`, shape=(n_features, n_features)
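
The same idea on the supervised estimator; apart from `warm_start` (introduced by this change), the calls below follow the existing SCML_Supervised API.

```python
from sklearn.datasets import load_iris
from metric_learn import SCML_Supervised

X, y = load_iris(return_X_y=True)

scml = SCML_Supervised(k_genuine=3, k_impostor=10, max_iter=1000,
                       warm_start=True, random_state=0)
scml.fit(X, y)                 # first fit initializes w_ and the AdaGrad state
scml.fit(X, y)                 # warm-started: continues from the previous solution
print(scml.components_.shape)  # learned linear transformation
```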
@@ -506,13 +526,14 @@ class SCML_Supervised(_BaseSCML, TransformerMixin):
   def __init__(self, k_genuine=3, k_impostor=10, beta=1e-5, basis='lda',
                n_basis=None, gamma=5e-3, max_iter=10000, output_iter=500,
                batch_size=10, verbose=False, preprocessor=None,
-               random_state=None):
+               random_state=None, warm_start=False):
     self.k_genuine = k_genuine
     self.k_impostor = k_impostor
     _BaseSCML.__init__(self, beta=beta, basis=basis, n_basis=n_basis,
                        max_iter=max_iter, output_iter=output_iter,
                        batch_size=batch_size, verbose=verbose,
-                       preprocessor=preprocessor, random_state=random_state)
+                       preprocessor=preprocessor, random_state=random_state,
+                       warm_start=warm_start)
 
   def fit(self, X, y):
     """Create constraints from labels and learn the SCML model.