Skip to content

Commit 5e29295

Browse files
author
William de Vazelhes
committed
ENH: add squared option
1 parent d2c0614 commit 5e29295

File tree

2 files changed

+34
-3
lines changed

2 files changed

+34
-3
lines changed

metric_learn/base_metric.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,24 +243,35 @@ def transform(self, X):
243243
def get_metric(self):
244244
transformer_T = self.transformer_.T.copy()
245245

246-
def metric_fun(u, v):
246+
def metric_fun(u, v, squared=False):
247247
"""This function computes the metric between u and v, according to the
248248
previously learned metric.
249249
250250
Parameters
251251
----------
252252
u : array-like, shape=(n_features,)
253253
The first point involved in the distance computation.
254+
254255
v : array-like, shape=(n_features,)
255256
The second point involved in the distance computation.
257+
258+
squared : `bool`
259+
If True, the function will return the squared metric between u and
260+
v, which is faster to compute.
261+
256262
Returns
257263
-------
258264
distance: float
259265
The distance between u and v according to the new metric.
260266
"""
261267
u = validate_vector(u)
262268
v = validate_vector(v)
263-
return euclidean(u.dot(transformer_T), v.dot(transformer_T))
269+
transformed_diff = (u - v).dot(transformer_T)
270+
dist = transformed_diff.dot(transformed_diff.T)
271+
if not squared:
272+
dist = np.sqrt(dist)
273+
return dist
274+
264275
return metric_fun
265276

266277
get_metric.__doc__ = BaseMetricLearner.get_metric.__doc__

test/test_mahalanobis_mixin.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,4 +245,24 @@ def test_get_metric_compatible_with_scikit_learn(estimator, build_dataset):
245245
set_random_state(model)
246246
model.fit(input_data, labels)
247247
clustering = DBSCAN(metric=model.get_metric())
248-
clustering.fit(X)
248+
clustering.fit(X)
249+
250+
251+
@pytest.mark.parametrize('estimator, build_dataset', metric_learners,
252+
ids=ids_metric_learners)
253+
def test_get_squared_metric(estimator, build_dataset):
254+
"""Test that the squared metric returned is indeed the square of the
255+
metric"""
256+
input_data, labels, _, X = build_dataset()
257+
model = clone(estimator)
258+
set_random_state(model)
259+
model.fit(input_data, labels)
260+
metric = model.get_metric()
261+
262+
n_features = X.shape[1]
263+
for seed in range(10):
264+
rng = np.random.RandomState(seed)
265+
a, b = (rng.randn(n_features) for _ in range(2))
266+
assert_allclose(metric(a, b, squared=True),
267+
metric(a, b, squared=False)**2,
268+
rtol=1e-15)

0 commit comments

Comments
 (0)