@@ -160,6 +160,58 @@ def transform(self, X):
160
160
Input data transformed to the metric space by :math:`XL^{\\ top}`
161
161
"""
162
162
163
+ class BilinearMixin (BaseMetricLearner , metaclass = ABCMeta ):
164
+
165
+ def score_pairs (self , pairs ):
166
+ r"""
167
+ Parameters
168
+ ----------
169
+ pairs : array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2)
170
+ 3D Array of pairs to score, with each row corresponding to two points,
171
+ for 2D array of indices of pairs if the metric learner uses a
172
+ preprocessor.
173
+
174
+ Returns
175
+ -------
176
+ scores : `numpy.ndarray` of shape=(n_pairs,)
177
+ The learned Mahalanobis distance for every pair.
178
+ """
179
+ check_is_fitted (self , ['preprocessor_' , 'components_' ])
180
+ pairs = check_input (pairs , type_of_inputs = 'tuples' ,
181
+ preprocessor = self .preprocessor_ ,
182
+ estimator = self , tuple_size = 2 )
183
+ return np .dot (np .dot (pairs [:, 1 , :], self .components_ ), pairs [:, 0 , :].T )
184
+
185
+ def get_metric (self ):
186
+ check_is_fitted (self , 'components_' )
187
+ components = self .components_ .copy ()
188
+
189
+ def metric_fun (u , v ):
190
+ """This function computes the metric between u and v, according to the
191
+ previously learned metric.
192
+
193
+ Parameters
194
+ ----------
195
+ u : array-like, shape=(n_features,)
196
+ The first point involved in the distance computation.
197
+
198
+ v : array-like, shape=(n_features,)
199
+ The second point involved in the distance computation.
200
+
201
+ Returns
202
+ -------
203
+ distance : float
204
+ The distance between u and v according to the new metric.
205
+ """
206
+ u = validate_vector (u )
207
+ v = validate_vector (v )
208
+ return np .dot (np .dot (u , components ), v .T )
209
+
210
+ return metric_fun
211
+
212
+ def get_bilinear_matrix (self ):
213
+ check_is_fitted (self , 'components_' )
214
+ return self .components_
163
215
164
216
class MahalanobisMixin (BaseMetricLearner , MetricTransformer ,
165
217
metaclass = ABCMeta ):
0 commit comments