Skip to content

Commit 9f73250

Browse files
authored
[MRG] Fix RCA_Supervised sklearn compat test (#198)
* FIX fix RCA_Supervised sklearn compat test
* Address #198 (review)
* Refactor comment
1 parent 05a8d41 commit 9f73250

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

metric_learn/rca.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
def _chunk_mean_centering(data, chunks):
2727
num_chunks = chunks.max() + 1
2828
chunk_mask = chunks != -1
29-
chunk_data = data[chunk_mask]
29+
# We need to ensure the data is float so that we can subtract the
30+
# mean on it
31+
chunk_data = data[chunk_mask].astype(float, copy=False)
3032
chunk_labels = chunks[chunk_mask]
3133
for c in xrange(num_chunks):
3234
mask = chunk_labels == c
@@ -98,7 +100,7 @@ def fit(self, X, chunks):
98100
When ``chunks[i] == -1``, point i doesn't belong to any chunklet.
99101
When ``chunks[i] == j``, point i belongs to chunklet j.
100102
"""
101-
X = self._prepare_inputs(X, ensure_min_samples=2)
103+
X, chunks = self._prepare_inputs(X, chunks, ensure_min_samples=2)
102104

103105
# PCA projection to remove noise and redundant information.
104106
if self.pca_comps is not None:
@@ -109,7 +111,6 @@ def fit(self, X, chunks):
109111
X_t = X - X.mean(axis=0)
110112
M_pca = None
111113

112-
chunks = np.asanyarray(chunks, dtype=int)
113114
chunk_mask, chunked_data = _chunk_mean_centering(X_t, chunks)
114115

115116
inner_cov = np.atleast_2d(np.cov(chunked_data, rowvar=0, bias=1))

test/test_sklearn_compat.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,15 @@ def stable_init(self, sparsity_param=0.01, num_labeled='deprecated',
8989
dSDML.__init__ = stable_init
9090
check_estimator(dSDML)
9191

92-
# This fails because the default num_chunks isn't data-dependent.
93-
# def test_rca(self):
94-
# check_estimator(RCA_Supervised)
92+
def test_rca(self):
93+
def stable_init(self, num_dims=None, pca_comps=None,
94+
chunk_size=2, preprocessor=None):
95+
# this init makes RCA stable for scikit-learn examples.
96+
RCA_Supervised.__init__(self, num_chunks=2, num_dims=num_dims,
97+
pca_comps=pca_comps, chunk_size=chunk_size,
98+
preprocessor=preprocessor)
99+
dRCA.__init__ = stable_init
100+
check_estimator(dRCA)
95101

96102

97103
RNG = check_random_state(0)

0 commit comments

Comments
 (0)