Commit 1b40c3b

grudloff authored and bellet committed
Changes in documentation. Rephrasing, fixed examples, standardized notation, etc. (#274)
* Multiple changes to the documentation. Rephrasing, fixed examples and standardized notation, and others.
* Forgot to change one A to L
* Replaced broken modindex link for module list
* fixed compliance with flake8
* Fixed typos, misplaced example, etc
* No new bullet and rectification
* remove modules index link
* add "respectively"
* fix rca examples
* fix rca examples again
1 parent f48a55d commit 1b40c3b

10 files changed (+114 / -66 lines)

README.rst

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@ metric-learn contains efficient Python implementations of several popular superv

 - For SDML, using skggm will allow the algorithm to solve problematic cases
   (install from commit `a0ed406 <https://github.com/skggm/skggm/commit/a0ed406586c4364ea3297a658f415e13b5cbdaf8>`_).
+  ``pip install 'git+https://github.com/skggm/skggm.git@a0ed406586c4364ea3297a658f415e13b5cbdaf8'`` to install the required version of skggm from GitHub.
 - For running the examples only: matplotlib

 **Installation/Setup**
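For context, a minimal sketch (not part of this commit) of the SDML use case that the skggm dependency targets; the ``SDML_Supervised`` constructor argument ``num_constraints`` is assumed from the metric-learn API of this era and may differ between versions::

    # Hedged sketch: SDML uses skggm's solver when it is installed, and
    # otherwise falls back to scikit-learn's graphical lasso, which can
    # struggle on ill-conditioned problems.
    from metric_learn import SDML_Supervised
    from sklearn.datasets import load_iris

    iris_data = load_iris()
    X = iris_data['data']
    Y = iris_data['target']

    sdml = SDML_Supervised(num_constraints=200)  # assumed signature
    sdml.fit(X, Y)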

doc/getting_started.rst

Lines changed: 2 additions & 1 deletion
@@ -10,7 +10,7 @@ Run ``pip install metric-learn`` to download and install from PyPI.
 Alternately, download the source repository and run:

 - ``python setup.py install`` for default installation.
-- ``python setup.py test`` to run all tests.
+- ``pytest test`` to run all tests.

 **Dependencies**

@@ -21,6 +21,7 @@ Alternately, download the source repository and run:

 - For SDML, using skggm will allow the algorithm to solve problematic cases
   (install from commit `a0ed406 <https://github.com/skggm/skggm/commit/a0ed406586c4364ea3297a658f415e13b5cbdaf8>`_).
+  ``pip install 'git+https://github.com/skggm/skggm.git@a0ed406586c4364ea3297a658f415e13b5cbdaf8'`` to install the required version of skggm from GitHub.
 - For running the examples only: matplotlib

 Quick start
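As a rough orientation for the Quick start section touched above, a hedged sketch (not the documentation's own example) of the typical supervised workflow; the ``NCA`` constructor argument is an assumption about the API of this era::

    # Learn an NCA metric on iris and embed the data into the learned space.
    from metric_learn import NCA
    from sklearn.datasets import load_iris

    iris_data = load_iris()
    X = iris_data['data']
    Y = iris_data['target']

    nca = NCA(max_iter=1000)        # assumed constructor argument
    nca.fit(X, Y)
    X_embedded = nca.transform(X)   # data mapped by the learned transformation L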

doc/index.rst

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ Documentation outline

    auto_examples/index

-:ref:`genindex` | :ref:`modindex` | :ref:`search`
+:ref:`genindex` | :ref:`search`

 .. |Travis-CI Build Status| image:: https://api.travis-ci.org/scikit-learn-contrib/metric-learn.svg?branch=master
    :target: https://travis-ci.org/scikit-learn-contrib/metric-learn

doc/supervised.rst

Lines changed: 15 additions & 15 deletions
@@ -131,13 +131,13 @@ The distance is learned by solving the following optimization problem:
   c\sum_{i, j, l}\eta_{ij}(1-y_{ij})[1+||\mathbf{L(x_i-x_j)}||^2-||
   \mathbf{L(x_i-x_l)}||^2]_+)

-where :math:`\mathbf{x}_i` is an data point, :math:`\mathbf{x}_j` is one
-of its k nearest neighbors sharing the same label, and :math:`\mathbf{x}_l`
+where :math:`\mathbf{x}_i` is a data point, :math:`\mathbf{x}_j` is one
+of its k-nearest neighbors sharing the same label, and :math:`\mathbf{x}_l`
 are all the other instances within that region with different labels,
 :math:`\eta_{ij}, y_{ij} \in \{0, 1\}` are both the indicators,
-:math:`\eta_{ij}` represents :math:`\mathbf{x}_{j}` is the k nearest
-neighbors(with same labels) of :math:`\mathbf{x}_{i}`, :math:`y_{ij}=0`
-indicates :math:`\mathbf{x}_{i}, \mathbf{x}_{j}` belong to different class,
+:math:`\eta_{ij}` represents :math:`\mathbf{x}_{j}` is the k-nearest
+neighbors (with same labels) of :math:`\mathbf{x}_{i}`, :math:`y_{ij}=0`
+indicates :math:`\mathbf{x}_{i}, \mathbf{x}_{j}` belong to different classes,
 :math:`[\cdot]_+=\max(0, \cdot)` is the Hinge loss.

 .. topic:: Example Code:
@@ -235,7 +235,7 @@ the sum of probability of being correctly classified:

 Local Fisher Discriminant Analysis (:py:class:`LFDA <metric_learn.LFDA>`)

-`LFDA` is a linear supervised dimensionality reduction method. It is
+`LFDA` is a linear supervised dimensionality reduction method which effectively combines the ideas of `Linear Discriminant Analysis <https://en.wikipedia.org/wiki/Linear_discriminant_analysis>` and Locality-Preserving Projection . It is
 particularly useful when dealing with multi-modality, where one ore more classes
 consist of separate clusters in input space. The core optimization problem of
 LFDA is solved as a generalized eigenvalue problem.
@@ -261,18 +261,18 @@ where
   \,\,\mathbf{A}_{i,j}(1/n-1/n_l) \qquad y_i = y_j\end{aligned}\right.\\

 here :math:`\mathbf{A}_{i,j}` is the :math:`(i,j)`-th entry of the affinity
-matrix :math:`\mathbf{A}`:, which can be calculated with local scaling methods.
+matrix :math:`\mathbf{A}`:, which can be calculated with local scaling methods, `n` and `n_l` are the total number of points and the number of points per cluster `l` respectively.

 Then the learning problem becomes derive the LFDA transformation matrix
-:math:`\mathbf{T}_{LFDA}`:
+:math:`\mathbf{L}_{LFDA}`:

 .. math::

-  \mathbf{T}_{LFDA} = \arg\max_\mathbf{T}
-  [\text{tr}((\mathbf{T}^T\mathbf{S}^{(w)}
-  \mathbf{T})^{-1}\mathbf{T}^T\mathbf{S}^{(b)}\mathbf{T})]
+  \mathbf{L}_{LFDA} = \arg\max_\mathbf{L}
+  [\text{tr}((\mathbf{L}^T\mathbf{S}^{(w)}
+  \mathbf{L})^{-1}\mathbf{L}^T\mathbf{S}^{(b)}\mathbf{L})]

-That is, it is looking for a transformation matrix :math:`\mathbf{T}` such that
+That is, it is looking for a transformation matrix :math:`\mathbf{L}` such that
 nearby data pairs in the same class are made close and the data pairs in
 different classes are separated from each other; far apart data pairs in the
 same class are not imposed to be close.
@@ -326,9 +326,9 @@ empirical development. The Gaussian kernel is denoted as:

 where :math:`d(\cdot, \cdot)` is the squared distance under some metrics,
 here in the fashion of Mahalanobis, it should be :math:`d(\mathbf{x}_i,
-\mathbf{x}_j) = ||\mathbf{A}(\mathbf{x}_i - \mathbf{x}_j)||`, the transition
-matrix :math:`\mathbf{A}` is derived from the decomposition of Mahalanobis
-matrix :math:`\mathbf{M=A^TA}`.
+\mathbf{x}_j) = ||\mathbf{L}(\mathbf{x}_i - \mathbf{x}_j)||`, the transition
+matrix :math:`\mathbf{L}` is derived from the decomposition of Mahalanobis
+matrix :math:`\mathbf{M=L^TL}`.

 Since :math:`\sigma^2` can be integrated into :math:`d(\cdot)`, we can set
 :math:`\sigma^2=1` for the sake of simplicity. Here we use the cumulative
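Since these hunks standardize the notation around the transformation matrix :math:`\mathbf{L}`, a hedged LFDA usage sketch (not part of the commit) may help connect the math to the API; the ``k`` and ``n_components`` arguments and the ``components_`` attribute are assumed from the metric-learn API of this era::

    # Fit LFDA on iris and recover the learned transformation matrix L_{LFDA}.
    from metric_learn import LFDA
    from sklearn.datasets import load_iris

    iris_data = load_iris()
    X = iris_data['data']
    Y = iris_data['target']

    lfda = LFDA(k=2, n_components=2)   # assumed constructor arguments
    lfda.fit(X, Y)
    L = lfda.components_               # the learned transformation matrix
    X_embedded = lfda.transform(X)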

doc/weakly_supervised.rst

Lines changed: 17 additions & 20 deletions
@@ -367,36 +367,36 @@ other methods, `ITML` does not rely on an eigenvalue computation or
 semi-definite programming.


-Given a Mahalanobis distance parameterized by :math:`A`, its corresponding
+Given a Mahalanobis distance parameterized by :math:`M`, its corresponding
 multivariate Gaussian is denoted as:

 .. math::
-  p(\mathbf{x}; \mathbf{A}) = \frac{1}{Z}\exp(-\frac{1}{2}d_\mathbf{A}
+  p(\mathbf{x}; \mathbf{M}) = \frac{1}{Z}\exp(-\frac{1}{2}d_\mathbf{M}
   (\mathbf{x}, \mu))
-  = \frac{1}{Z}\exp(-\frac{1}{2}((\mathbf{x} - \mu)^T\mathbf{A}
+  = \frac{1}{Z}\exp(-\frac{1}{2}((\mathbf{x} - \mu)^T\mathbf{M}
   (\mathbf{x} - \mu))

 where :math:`Z` is the normalization constant, the inverse of Mahalanobis
-matrix :math:`\mathbf{A}^{-1}` is the covariance of the Gaussian.
+matrix :math:`\mathbf{M}^{-1}` is the covariance of the Gaussian.

 Given pairs of similar points :math:`S` and pairs of dissimilar points
 :math:`D`, the distance metric learning problem is to minimize the LogDet
 divergence, which is equivalent as minimizing :math:`\textbf{KL}(p(\mathbf{x};
-\mathbf{A}_0) || p(\mathbf{x}; \mathbf{A}))`:
+\mathbf{M}_0) || p(\mathbf{x}; \mathbf{M}))`:

 .. math::

-  \min_\mathbf{A} D_{\ell \mathrm{d}}\left(A, A_{0}\right) =
-  \operatorname{tr}\left(A A_{0}^{-1}\right)-\log \operatorname{det}
-  \left(A A_{0}^{-1}\right)-n\\
-  \text{subject to } \quad d_\mathbf{A}(\mathbf{x}_i, \mathbf{x}_j)
+  \min_\mathbf{A} D_{\ell \mathrm{d}}\left(M, M_{0}\right) =
+  \operatorname{tr}\left(M M_{0}^{-1}\right)-\log \operatorname{det}
+  \left(M M_{0}^{-1}\right)-n\\
+  \text{subject to } \quad d_\mathbf{M}(\mathbf{x}_i, \mathbf{x}_j)
   \leq u \qquad (\mathbf{x}_i, \mathbf{x}_j)\in S \\
-  d_\mathbf{A}(\mathbf{x}_i, \mathbf{x}_j) \geq l \qquad (\mathbf{x}_i,
+  d_\mathbf{M}(\mathbf{x}_i, \mathbf{x}_j) \geq l \qquad (\mathbf{x}_i,
   \mathbf{x}_j)\in D


 where :math:`u` and :math:`l` is the upper and the lower bound of distance
-for similar and dissimilar pairs respectively, and :math:`\mathbf{A}_0`
+for similar and dissimilar pairs respectively, and :math:`\mathbf{M}_0`
 is the prior distance metric, set to identity matrix by default,
 :math:`D_{\ell \mathrm{d}}(\cdot)` is the log determinant.

@@ -518,17 +518,14 @@ as the Mahalanobis matrix.

   from metric_learn import RCA

-  pairs = [[[1.2, 7.5], [1.3, 1.5]],
-           [[6.4, 2.6], [6.2, 9.7]],
-           [[1.3, 4.5], [3.2, 4.6]],
-           [[6.2, 5.5], [5.4, 5.4]]]
-  y = [1, 1, -1, -1]
-
-  # in this task we want points where the first feature is close to be closer
-  # to each other, no matter how close the second feature is
+  X = [[-0.05, 3.0],[0.05, -3.0],
+       [0.1, -3.55],[-0.1, 3.55],
+       [-0.95, -0.05],[0.95, 0.05],
+       [0.4, 0.05],[-0.4, -0.05]]
+  chunks = [0, 0, 1, 1, 2, 2, 3, 3]

   rca = RCA()
-  rca.fit(pairs, y)
+  rca.fit(X, chunks)

 .. topic:: References:

examples/plot_metric_learning_examples.py

Lines changed: 1 addition & 1 deletion
@@ -175,7 +175,7 @@ def plot_tsne(X, y, colormap=plt.cm.Paired):
 #
 # ITML uses a regularizer that automatically enforces a Semi-Definite
 # Positive Matrix condition - the LogDet divergence. It uses soft
-# must-link or cannot like constraints, and a simple algorithm based on
+# must-link or cannot-link constraints, and a simple algorithm based on
 # Bregman projections. Unlike LMNN, ITML will implicitly enforce points from
 # the same class to belong to the same cluster, as you can see below.
 #

metric_learn/itml.py

Lines changed: 20 additions & 7 deletions
@@ -198,13 +198,16 @@ class ITML(_BaseITML, _PairsClassifierMixin):

 Examples
 --------
->>> from metric_learn import ITML_Supervised
->>> from sklearn.datasets import load_iris
->>> iris_data = load_iris()
->>> X = iris_data['data']
->>> Y = iris_data['target']
->>> itml = ITML_Supervised(num_constraints=200)
->>> itml.fit(X, Y)
+>>> from metric_learn import ITML
+>>> pairs = [[[1.2, 7.5], [1.3, 1.5]],
+>>>          [[6.4, 2.6], [6.2, 9.7]],
+>>>          [[1.3, 4.5], [3.2, 4.6]],
+>>>          [[6.2, 5.5], [5.4, 5.4]]]
+>>> y = [1, 1, -1, -1]
+>>> # in this task we want points where the first feature is close to be
+>>> # closer to each other, no matter how close the second feature is
+>>> itml = ITML()
+>>> itml.fit(pairs, y)

 References
 ----------
@@ -335,6 +338,16 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
 The linear transformation ``L`` deduced from the learned Mahalanobis
 metric (See function `components_from_metric`.)

+Examples
+--------
+>>> from metric_learn import ITML_Supervised
+>>> from sklearn.datasets import load_iris
+>>> iris_data = load_iris()
+>>> X = iris_data['data']
+>>> Y = iris_data['target']
+>>> itml = ITML_Supervised(num_constraints=200)
+>>> itml.fit(X, Y)
+
 See Also
 --------
 metric_learn.ITML : The original weakly-supervised algorithm
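A hedged follow-up sketch (not part of the commit) showing how the fitted pairs classifier from the new ``ITML`` example can be used on unseen pairs, assuming the ``_PairsClassifierMixin`` API of this era::

    new_pairs = [[[1.0, 7.0], [1.1, 1.0]]]   # hypothetical test pair
    itml.predict(new_pairs)                  # +1 (similar) or -1 (dissimilar)
    itml.decision_function(new_pairs)        # signed score relative to the learned threshold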

metric_learn/lsml.py

Lines changed: 19 additions & 7 deletions
@@ -186,13 +186,15 @@ class LSML(_BaseLSML, _QuadrupletsClassifierMixin):

 Examples
 --------
->>> from metric_learn import LSML_Supervised
->>> from sklearn.datasets import load_iris
->>> iris_data = load_iris()
->>> X = iris_data['data']
->>> Y = iris_data['target']
->>> lsml = LSML_Supervised(num_constraints=200)
->>> lsml.fit(X, Y)
+>>> from metric_learn import LSML
+>>> quadruplets = [[[1.2, 7.5], [1.3, 1.5], [6.4, 2.6], [6.2, 9.7]],
+>>>                [[1.3, 4.5], [3.2, 4.6], [6.2, 5.5], [5.4, 5.4]],
+>>>                [[3.2, 7.5], [3.3, 1.5], [8.4, 2.6], [8.2, 9.7]],
+>>>                [[3.3, 4.5], [5.2, 4.6], [8.2, 5.5], [7.4, 5.4]]]
+>>> # we want to make closer points where the first feature is close, and
+>>> # further if the second feature is close
+>>> lsml = LSML()
+>>> lsml.fit(quadruplets)

 References
 ----------
@@ -290,6 +292,16 @@ class LSML_Supervised(_BaseLSML, TransformerMixin):
 prior. In any case, `random_state` is also used to randomly sample
 constraints from labels.

+Examples
+--------
+>>> from metric_learn import LSML_Supervised
+>>> from sklearn.datasets import load_iris
+>>> iris_data = load_iris()
+>>> X = iris_data['data']
+>>> Y = iris_data['target']
+>>> lsml = LSML_Supervised(num_constraints=200)
+>>> lsml.fit(X, Y)
+
 Attributes
 ----------
 n_iter_ : `int`
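A hedged note on the new ``LSML`` example above (not part of the commit): each quadruplet ``[a, b, c, d]`` encodes the constraint that ``a`` and ``b`` should end up closer than ``c`` and ``d``. Assuming the quadruplets-classifier API of this era, new quadruplets can then be checked against the learned metric::

    test_quadruplets = [[[1.0, 7.0], [1.1, 1.2], [6.0, 2.0], [6.1, 9.0]]]  # hypothetical
    lsml.predict(test_quadruplets)   # +1 if the first pair is predicted closer, else -1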

metric_learn/mmc.py

Lines changed: 20 additions & 7 deletions
@@ -426,13 +426,16 @@ class MMC(_BaseMMC, _PairsClassifierMixin):

 Examples
 --------
->>> from metric_learn import MMC_Supervised
->>> from sklearn.datasets import load_iris
->>> iris_data = load_iris()
->>> X = iris_data['data']
->>> Y = iris_data['target']
->>> mmc = MMC_Supervised(num_constraints=200)
->>> mmc.fit(X, Y)
+>>> from metric_learn import MMC
+>>> pairs = [[[1.2, 7.5], [1.3, 1.5]],
+>>>          [[6.4, 2.6], [6.2, 9.7]],
+>>>          [[1.3, 4.5], [3.2, 4.6]],
+>>>          [[6.2, 5.5], [5.4, 5.4]]]
+>>> y = [1, 1, -1, -1]
+>>> # in this task we want points where the first feature is close to be
+>>> # closer to each other, no matter how close the second feature is
+>>> mmc = MMC()
+>>> mmc.fit(pairs, y)

 References
 ----------
@@ -552,6 +555,16 @@ class MMC_Supervised(_BaseMMC, TransformerMixin):
 samples, and pairs of dissimilar samples by taking different class
 samples. It then passes these pairs to `MMC` for training.

+Examples
+--------
+>>> from metric_learn import MMC_Supervised
+>>> from sklearn.datasets import load_iris
+>>> iris_data = load_iris()
+>>> X = iris_data['data']
+>>> Y = iris_data['target']
+>>> mmc = MMC_Supervised(num_constraints=200)
+>>> mmc.fit(X, Y)
+
 Attributes
 ----------
 n_iter_ : `int`

metric_learn/rca.py

Lines changed: 18 additions & 7 deletions
@@ -62,13 +62,14 @@ class RCA(MahalanobisMixin, TransformerMixin):

 Examples
 --------
->>> from metric_learn import RCA_Supervised
->>> from sklearn.datasets import load_iris
->>> iris_data = load_iris()
->>> X = iris_data['data']
->>> Y = iris_data['target']
->>> rca = RCA_Supervised(num_chunks=30, chunk_size=2)
->>> rca.fit(X, Y)
+>>> from metric_learn import RCA
+>>> X = [[-0.05, 3.0],[0.05, -3.0],
+>>>      [0.1, -3.55],[-0.1, 3.55],
+>>>      [-0.95, -0.05],[0.95, 0.05],
+>>>      [0.4, 0.05],[-0.4, -0.05]]
+>>> chunks = [0, 0, 1, 1, 2, 2, 3, 3]
+>>> rca = RCA()
+>>> rca.fit(X, chunks)

 References
 ------------------
@@ -196,6 +197,16 @@ class RCA_Supervised(RCA):
 A pseudo random number generator object or a seed for it if int.
 It is used to randomly sample constraints from labels.

+Examples
+--------
+>>> from metric_learn import RCA_Supervised
+>>> from sklearn.datasets import load_iris
+>>> iris_data = load_iris()
+>>> X = iris_data['data']
+>>> Y = iris_data['target']
+>>> rca = RCA_Supervised(num_chunks=30, chunk_size=2)
+>>> rca.fit(X, Y)
+
 Attributes
 ----------
 components_ : `numpy.ndarray`, shape=(n_components, n_features)
