Commit aaf8d44

score_pairs refactor (#333)
* Remove 3.9 from compatibility
* First draft of refactoring BaseMetricLearner and Mahalanobis Learner
* Avoid warning related to score_pairs deprecation in tests of pair_calibraiton
* Minor fix
* Replaced score_pairs with pair_distance in tests
* Replace score_pairs with pair_distance inb docs.
* Fix weird commit
* Update classifiers to use pair_similarity
* Updated rst docs
* Fix identation
* Update docs of score_pairs, get_metric
* Add deprecation Test. Fix identation
* Fixed changes requested 1
* Fixed changes requested 2
* Add equivalence test, p_dist == p_score
* Fix tests and identation.
* Fixed changes requested 3
* Fix identation
* Last requested changes
* Last small detail
1 parent e2c3e92 commit aaf8d44
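
The commit message mentions a ``score_pairs`` deprecation and a deprecation test. As a rough illustration of that pattern, the sketch below keeps the old method name as a thin alias that emits a ``FutureWarning``. ``ToyLearner`` is a hypothetical stand-in using plain Euclidean distance, and the warning text is made up; this is not the library's actual implementation. The alias forwards to ``pair_distance`` only because the removed documentation shows ``score_pairs`` returning distances.

```python
import math
import warnings


class ToyLearner:
    """Hypothetical stand-in for a metric learner (not metric-learn's class)."""

    def pair_distance(self, pairs):
        # Plain Euclidean distance between the two points of each pair.
        return [math.dist(x, y) for x, y in pairs]

    def pair_score(self, pairs):
        # Larger score means more similar, so negate the distance.
        return [-d for d in self.pair_distance(pairs)]

    def score_pairs(self, pairs):
        # Deprecated alias: warn, then forward to the replacement method.
        warnings.warn(
            "score_pairs is deprecated; use pair_distance or pair_score.",
            FutureWarning,
        )
        return self.pair_distance(pairs)


learner = ToyLearner()
pairs = [[[0.0, 0.0], [3.0, 4.0]], [[1.0, 1.0], [1.0, 2.0]]]

# Record warnings instead of printing them, as a deprecation test might do.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_result = learner.score_pairs(pairs)

print(old_result)
print([w.category.__name__ for w in caught])
```

Keeping the old name as a forwarding alias lets existing callers keep working for a release cycle while the warning points them to the new methods.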

9 files changed: +360 -106 lines changed
doc/introduction.rst

Lines changed: 0 additions & 23 deletions
@@ -123,26 +123,3 @@ to the following resources:
 Survey <http://dx.doi.org/10.1561/2200000019>`_ (2012)
 - **Book:** `Metric Learning
 <http://dx.doi.org/10.2200/S00626ED1V01Y201501AIM030>`_ (2015)
-
-.. Methods [TO MOVE TO SUPERVISED/WEAK SECTIONS]
-.. =============================================
-
-.. Currently, each metric learning algorithm supports the following methods:
-
-.. - ``fit(...)``, which learns the model.
-.. - ``get_mahalanobis_matrix()``, which returns a Mahalanobis matrix
-.. - ``get_metric()``, which returns a function that takes as input two 1D
-arrays and outputs the learned metric score on these two points
-.. :math:`M = L^{\top}L` such that distance between vectors ``x`` and
-.. ``y`` can be computed as :math:`\sqrt{\left(x-y\right)M\left(x-y\right)}`.
-.. - ``components_from_metric(metric)``, which returns a transformation matrix
-.. :math:`L \in \mathbb{R}^{D \times d}`, which can be used to convert a
-.. data matrix :math:`X \in \mathbb{R}^{n \times d}` to the
-.. :math:`D`-dimensional learned metric space :math:`X L^{\top}`,
-.. in which standard Euclidean distances may be used.
-.. - ``transform(X)``, which applies the aforementioned transformation.
-.. - ``score_pairs(pairs)`` which returns the distance between pairs of
-.. points. ``pairs`` should be a 3D array-like of pairs of shape ``(n_pairs,
-.. 2, n_features)``, or it can be a 2D array-like of pairs indicators of
-.. shape ``(n_pairs, 2)`` (see section :ref:`preprocessor_section` for more
-.. details).

doc/supervised.rst

Lines changed: 16 additions & 4 deletions
@@ -69,10 +69,10 @@ Also, as explained before, our metric learners has learn a distance between
 points. You can use this distance in two main ways:
 
 - You can either return the distance between pairs of points using the
-`score_pairs` function:
+`pair_distance` function:
 
->>> nca.score_pairs([[[3.5, 3.6], [5.6, 2.4]], [[1.2, 4.2], [2.1, 6.4]]])
-array([0.49627072, 3.65287282])
+>>> nca.pair_distance([[[3.5, 3.6], [5.6, 2.4]], [[1.2, 4.2], [2.1, 6.4]], [[3.3, 7.8], [10.9, 0.1]]])
+array([0.49627072, 3.65287282, 6.06079877])
 
 - Or you can return a function that will return the distance (in the new
 space) between two 1D arrays (the coordinates of the points in the original
@@ -82,6 +82,18 @@ array([0.49627072, 3.65287282])
 >>> metric_fun([3.5, 3.6], [5.6, 2.4])
 0.4962707194621285
 
+- Alternatively, you can use `pair_score` to return the **score** between
+pairs of points (the larger the score, the more similar the pair).
+For Mahalanobis learners, it is equal to the opposite of the distance.
+
+>>> score = nca.pair_score([[[3.5, 3.6], [5.6, 2.4]], [[1.2, 4.2], [2.1, 6.4]], [[3.3, 7.8], [10.9, 0.1]]])
+>>> score
+array([-0.49627072, -3.65287282, -6.06079877])
+
+This is useful because `pair_score` matches the **score** semantic of
+scikit-learn's `Classification metrics
+<https://scikit-learn.org/stable/modules/model_evaluation.html#classification-metrics>`_.
+
 .. note::
 
 If the metric learner that you use learns a :ref:`Mahalanobis distance
@@ -93,7 +105,6 @@ array([0.49627072, 3.65287282])
 array([[0.43680409, 0.89169412],
 [0.89169412, 1.9542479 ]])
 
-.. TODO: remove the "like it is the case etc..." if it's not the case anymore
 
 Scikit-learn compatibility
 --------------------------
@@ -105,6 +116,7 @@ All supervised algorithms are scikit-learn estimators
 scikit-learn model selection routines
 (`sklearn.model_selection.cross_val_score`,
 `sklearn.model_selection.GridSearchCV`, etc).
+You can also use some of the scoring functions from `sklearn.metrics`.
 
 Algorithms
 ==========
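
The relation between ``pair_distance`` and ``pair_score`` described in the supervised docs can be checked by hand. The sketch below is a minimal, dependency-free illustration, assuming the learned distance is the Mahalanobis form :math:`\sqrt{(x-y)^{\top} M (x-y)}` with the matrix ``M`` shown in the note of ``doc/supervised.rst``. The helper names ``mahalanobis_pair_distance`` and ``mahalanobis_pair_score`` are hypothetical and not part of the library.

```python
import math

# Mahalanobis matrix copied from the documentation example (rounded values).
M = [[0.43680409, 0.89169412],
     [0.89169412, 1.9542479]]

# First two pairs from the documentation example, shape (n_pairs, 2, n_features).
pairs = [[[3.5, 3.6], [5.6, 2.4]],
         [[1.2, 4.2], [2.1, 6.4]]]


def mahalanobis_pair_distance(pairs, M):
    """Hypothetical helper: sqrt((x - y)^T M (x - y)) for each pair."""
    distances = []
    for x, y in pairs:
        d = [xi - yi for xi, yi in zip(x, y)]
        n = len(d)
        squared = sum(d[i] * M[i][j] * d[j] for i in range(n) for j in range(n))
        # Guard against tiny negative values caused by rounding.
        distances.append(math.sqrt(max(squared, 0.0)))
    return distances


def mahalanobis_pair_score(pairs, M):
    """Hypothetical helper: the score is the opposite of the distance."""
    return [-d for d in mahalanobis_pair_distance(pairs, M)]


distances = mahalanobis_pair_distance(pairs, M)
scores = mahalanobis_pair_score(pairs, M)
print(distances)  # close to the documented 0.49627072 and 3.65287282
print(scores)
```

Because ``M`` is printed with limited precision in the docs, the results agree with the documented values only up to rounding.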

doc/weakly_supervised.rst

Lines changed: 20 additions & 10 deletions
@@ -160,9 +160,9 @@ Also, as explained before, our metric learner has learned a distance between
 points. You can use this distance in two main ways:
 
 - You can either return the distance between pairs of points using the
-`score_pairs` function:
+`pair_distance` function:
 
->>> mmc.score_pairs([[[3.5, 3.6, 5.2], [5.6, 2.4, 6.7]],
+>>> mmc.pair_distance([[[3.5, 3.6, 5.2], [5.6, 2.4, 6.7]],
 ... [[1.2, 4.2, 7.7], [2.1, 6.4, 0.9]]])
 array([7.27607365, 0.88853014])
 
@@ -175,6 +175,18 @@ array([7.27607365, 0.88853014])
 >>> metric_fun([3.5, 3.6, 5.2], [5.6, 2.4, 6.7])
 7.276073646278203
 
+- Alternatively, you can use `pair_score` to return the **score** between
+pairs of points (the larger the score, the more similar the pair).
+For Mahalanobis learners, it is equal to the opposite of the distance.
+
+>>> score = mmc.pair_score([[[3.5, 3.6], [5.6, 2.4]], [[1.2, 4.2], [2.1, 6.4]], [[3.3, 7.8], [10.9, 0.1]]])
+>>> score
+array([-0.49627072, -3.65287282, -6.06079877])
+
+This is useful because `pair_score` matches the **score** semantic of
+scikit-learn's `Classification metrics
+<https://scikit-learn.org/stable/modules/model_evaluation.html#classification-metrics>`_.
+
 .. note::
 
 If the metric learner that you use learns a :ref:`Mahalanobis distance
@@ -187,8 +199,6 @@ array([[ 0.58603894, -5.69883982, -1.66614919],
 [-5.69883982, 55.41743549, 16.20219519],
 [-1.66614919, 16.20219519, 4.73697721]])
 
-.. TODO: remove the "like it is the case etc..." if it's not the case anymore
-
 .. _sklearn_compat_ws:
 
 Prediction and scoring
@@ -344,8 +354,8 @@ returns the `sklearn.metrics.roc_auc_score` (which is threshold-independent).
 
 .. note::
 See :ref:`fit_ws` for more details on metric learners functions that are
-not specific to learning on pairs, like `transform`, `score_pairs`,
-`get_metric` and `get_mahalanobis_matrix`.
+not specific to learning on pairs, like `transform`, `pair_distance`,
+`pair_score`, `get_metric` and `get_mahalanobis_matrix`.
 
 Algorithms
 ----------
@@ -691,8 +701,8 @@ of triplets that have the right predicted ordering.
 
 .. note::
 See :ref:`fit_ws` for more details on metric learners functions that are
-not specific to learning on pairs, like `transform`, `score_pairs`,
-`get_metric` and `get_mahalanobis_matrix`.
+not specific to learning on pairs, like `transform`, `pair_distance`,
+`pair_score`, `get_metric` and `get_mahalanobis_matrix`.
 
 
 
@@ -859,8 +869,8 @@ of quadruplets have the right predicted ordering.
 
 .. note::
 See :ref:`fit_ws` for more details on metric learners functions that are
-not specific to learning on pairs, like `transform`, `score_pairs`,
-`get_metric` and `get_mahalanobis_matrix`.
+not specific to learning on pairs, like `transform`, `pair_distance`,
+`pair_score`, `get_metric` and `get_mahalanobis_matrix`.
 
 
 
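
The added docs say that ``pair_score`` follows the "larger is more similar" convention used by scikit-learn's classification metrics, and the weakly supervised docs mention the threshold-independent ROC AUC. The sketch below illustrates why the sign matters, using a small hand-written AUC rather than scikit-learn itself. The function ``pairwise_auc`` and the labels are hypothetical; the scores are the values printed in the ``pair_score`` example.

```python
def pairwise_auc(labels, scores):
    """Hypothetical helper: fraction of (similar, dissimilar) pair couples
    in which the similar pair receives the higher score (ties count 0.5)."""
    positives = [s for label, s in zip(labels, scores) if label == 1]
    negatives = [s for label, s in zip(labels, scores) if label == -1]
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


# Scores from the documentation example (opposite of the distances).
scores = [-0.49627072, -3.65287282, -6.06079877]
distances = [-s for s in scores]

# Assumed labels for illustration: +1 for a similar pair, -1 for a dissimilar one.
labels = [1, -1, -1]

auc_with_scores = pairwise_auc(labels, scores)
auc_with_distances = pairwise_auc(labels, distances)
print(auc_with_scores)     # 1.0: the similar pair has the highest score
print(auc_with_distances)  # 0.0: raw distances rank the pairs in reverse
```

Feeding raw distances to a metric that expects scores inverts the ranking, which is the situation the separate ``pair_score`` method avoids.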