Skip to content

Commit ef4efe3

Browse files
author
William de Vazelhes
committed
DOC: retry to bump the doc for v0.5.0
1 parent 1f1546e commit ef4efe3

File tree

165 files changed

+28037
-10240
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

165 files changed

+28037
-10240
lines changed

.buildinfo

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Sphinx build info version 1
2+
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
3+
config: 89e7bf0528b77a77d89d8167f4184500
4+
tags: 645f666f9bcd5a90fca523b33c5a78b7
Binary file not shown.
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Sandwich demo
4+
=============
5+
6+
Sandwich demo based on code from http://nbviewer.ipython.org/6576096
7+
"""
8+
9+
import numpy as np
10+
from matplotlib import pyplot as plt
11+
from sklearn.metrics import pairwise_distances
12+
from sklearn.neighbors import NearestNeighbors
13+
14+
from metric_learn import LMNN, ITML_Supervised, LSML_Supervised, SDML_Supervised
15+
16+
17+
def sandwich_demo():
18+
x, y = sandwich_data()
19+
knn = nearest_neighbors(x, k=2)
20+
ax = plt.subplot(3, 1, 1) # take the whole top row
21+
plot_sandwich_data(x, y, ax)
22+
plot_neighborhood_graph(x, knn, y, ax)
23+
ax.set_title('input space')
24+
ax.set_aspect('equal')
25+
ax.set_xticks([])
26+
ax.set_yticks([])
27+
28+
mls = [
29+
LMNN(),
30+
ITML_Supervised(num_constraints=200),
31+
SDML_Supervised(num_constraints=200, balance_param=0.001),
32+
LSML_Supervised(num_constraints=200),
33+
]
34+
35+
for ax_num, ml in enumerate(mls, start=3):
36+
ml.fit(x, y)
37+
tx = ml.transform(x)
38+
ml_knn = nearest_neighbors(tx, k=2)
39+
ax = plt.subplot(3, 2, ax_num)
40+
plot_sandwich_data(tx, y, axis=ax)
41+
plot_neighborhood_graph(tx, ml_knn, y, axis=ax)
42+
ax.set_title(ml.__class__.__name__)
43+
ax.set_xticks([])
44+
ax.set_yticks([])
45+
plt.show()
46+
47+
48+
# TODO: use this somewhere
49+
def visualize_class_separation(X, labels):
50+
_, (ax1,ax2) = plt.subplots(ncols=2)
51+
label_order = np.argsort(labels)
52+
ax1.imshow(pairwise_distances(X[label_order]), interpolation='nearest')
53+
ax2.imshow(pairwise_distances(labels[label_order,None]),
54+
interpolation='nearest')
55+
56+
57+
def nearest_neighbors(X, k=5):
58+
knn = NearestNeighbors(n_neighbors=k)
59+
knn.fit(X)
60+
return knn.kneighbors(X, return_distance=False)
61+
62+
63+
def sandwich_data():
64+
# number of distinct classes
65+
num_classes = 6
66+
# number of points per class
67+
num_points = 9
68+
# distance between layers, the points of each class are in a layer
69+
dist = 0.7
70+
71+
data = np.zeros((num_classes, num_points, 2), dtype=float)
72+
labels = np.zeros((num_classes, num_points), dtype=int)
73+
74+
x_centers = np.arange(num_points, dtype=float) - num_points / 2
75+
y_centers = dist * (np.arange(num_classes, dtype=float) - num_classes / 2)
76+
for i, yc in enumerate(y_centers):
77+
for k, xc in enumerate(x_centers):
78+
data[i, k, 0] = np.random.normal(xc, 0.1)
79+
data[i, k, 1] = np.random.normal(yc, 0.1)
80+
labels[i,:] = i
81+
return data.reshape((-1, 2)), labels.ravel()
82+
83+
84+
def plot_sandwich_data(x, y, axis=plt, colors='rbgmky'):
85+
for idx, val in enumerate(np.unique(y)):
86+
xi = x[y==val]
87+
axis.scatter(*xi.T, s=50, facecolors='none', edgecolors=colors[idx])
88+
89+
90+
def plot_neighborhood_graph(x, nn, y, axis=plt, colors='rbgmky'):
91+
for i, a in enumerate(x):
92+
b = x[nn[i,1]]
93+
axis.plot((a[0], b[0]), (a[1], b[1]), colors[y[i]])
94+
95+
96+
if __name__ == '__main__':
97+
sandwich_demo()

0 commit comments

Comments
 (0)