-
Notifications
You must be signed in to change notification settings - Fork 229
[WIP] New API proposal #85
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
3e5fbc3
0cbf1ae
300dada
8615634
a478baa
3744bec
214d991
4f4ce8b
ac00b8b
33561ab
7f40c56
402f397
47a9372
41dc123
5f63f24
fb0d118
df8a340
e3e7e0c
5a9c2e5
cf94740
52f4516
079bb13
da7c8e7
8192d11
2d0f1ca
6c59a1a
b70163a
a12eb9a
b1f6c23
b0ec33b
64f5762
2cf78dd
11a8ff1
a768cbf
335d8f4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,10 +12,10 @@ | |
import numpy as np | ||
from sklearn.utils.validation import check_array | ||
|
||
from .base_metric import BaseMetricLearner | ||
from .base_metric import SupervisedMetricLearner | ||
|
||
|
||
class Covariance(BaseMetricLearner): | ||
class Covariance(SupervisedMetricLearner): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I understand why this was chosen, but this particular base class made me stop and consider a moment. We may eventually want a base class for unsupervised methods as well. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree, this simple covariance method should not be a SupervisedMetricLearner (as it is completely unsupervised). Whether we will really need an unsupervised class in the long run is unclear, but maybe the best for now is to create an UnsupervisedMetricLearner class which takes only X in fit. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the review ! I agree, I did not notice but indeed Covariance is unsupervised, so I will change this in a following PR |
||
def __init__(self): | ||
pass | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import unittest | ||
import numpy as np | ||
from metric_learn.constraints import ConstrainedDataset | ||
from numpy.testing import assert_array_equal | ||
from sklearn.model_selection import StratifiedKFold, KFold | ||
from sklearn.utils.testing import assert_raise_message | ||
|
||
X = np.random.randn(20, 5) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be nice to test sparse class _BaseTestConstrainedDataset(unittest.TestCase):
# everything currently under TestConstrainedDataset, but using self.X instead of X
class TestDenseConstrainedDataset(_BaseTestConstrainedDataset):
def setUp(self):
self.X = np.random.randn(20, 5)
self.c = ... # and so on
class TestSparseConstrainedDataset(_BaseTestConstrainedDataset):
# similar, but setUp creates a dense X There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe it would make sense to also test other data types? like lists instead of numpy arrays There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree, I will add these tests in a next commit |
||
c = np.random.randint(0, X.shape[0], (15, 2)) | ||
cd = ConstrainedDataset(X, c) | ||
y = np.random.randint(0, 2, c.shape[0]) | ||
group = np.random.randint(0, 3, c.shape[0]) | ||
|
||
c_shape = c.shape[0] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is only used by There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. actually this is also used in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree I will rename it and put it before the definitions that use it |
||
|
||
|
||
class TestConstrainedDataset(unittest.TestCase): | ||
|
||
@staticmethod | ||
def check_indexing(idx): | ||
# checks that an indexing returns the data we expect | ||
np.testing.assert_array_equal(cd[idx].c, c[idx]) | ||
np.testing.assert_array_equal(cd[idx].toarray(), X[c[idx]]) | ||
np.testing.assert_array_equal(cd[idx].toarray(), X[c][idx]) | ||
|
||
def test_inputs(self): | ||
# test the allowed and forbidden ways to create a ConstrainedDataset | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd split this into two separate tests, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree, will do |
||
ConstrainedDataset(X, c) | ||
two_points = [[1, 2], [3, 5]] | ||
out_of_range_constraints = [[1, 2], [0, 1]] | ||
msg = "ConstrainedDataset cannot be created: the length of " \ | ||
"the dataset is 2, so index 2 is out of " \ | ||
"range." | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to imports, use parentheses here: msg = ("First part of string. "
"Second part of string.") There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks, will do |
||
assert_raise_message(IndexError, msg, ConstrainedDataset, two_points, | ||
out_of_range_constraints) | ||
|
||
def test_getitem(self): | ||
# test different types of slicing | ||
i = np.random.randint(1, c_shape - 1) | ||
begin = np.random.randint(1, c_shape - 1) | ||
end = np.random.randint(begin + 1, c_shape) | ||
fancy_index = np.random.randint(0, c_shape, 20) | ||
binary_index = np.random.randint(0, 2, c_shape) | ||
boolean_index = binary_index.astype(bool) | ||
items = [0, c_shape - 1, i, slice(i), slice(0, begin), slice(begin, | ||
end), slice(end, c_shape), slice(0, c_shape), fancy_index, | ||
binary_index, boolean_index] | ||
for item in items: | ||
self.check_indexing(item) | ||
|
||
def test_repr(self): | ||
assert repr(cd) == repr(X[c]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks ! |
||
|
||
def test_str(self): | ||
assert str(cd) == str(X[c]) | ||
|
||
def test_shape(self): | ||
assert cd.shape == (c.shape[0], X.shape[1]) | ||
assert cd[0, 0].shape == (0, X.shape[1]) | ||
|
||
def test_toarray(self): | ||
assert_array_equal(cd.toarray(), cd.X[c]) | ||
|
||
def test_folding(self): | ||
# test that ConstrainedDataset is compatible with scikit-learn folding | ||
shuffle_list = [True, False] | ||
groups_list = [group, None] | ||
for alg in [KFold, StratifiedKFold]: | ||
for shuffle_i in shuffle_list: | ||
for group_i in groups_list: | ||
for train_idx, test_idx in alg( | ||
shuffle=shuffle_i).split(cd, y, group_i): | ||
self.check_indexing(train_idx) | ||
self.check_indexing(test_idx) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe this could also check for potential duplicates? could simply show a warning when this is the case. (one could also remove them but this might create problems later when constraint labels are used)
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, I will implement it in a next commit