Skip to content

Commit 8860dfd

Browse files
committed
Adds test for warm_start parameter
1 parent 6ef1268 commit 8860dfd

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

test/metric_learn_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,22 @@ def test_large_output_iter(self):
323323
scml.fit(triplets)
324324
assert msg == raised_error.value.args[0]
325325

326+
@pytest.mark.parametrize("basis", ("lda", "triplet_diffs"))
327+
def test_warm_start(self, basis):
328+
X, y = load_iris(return_X_y=True)
329+
# Should work with warm_start=True even with first fit
330+
scml = SCML_Supervised(basis=basis, n_basis=85, k_genuine=7, k_impostor=5,
331+
random_state=42, warm_start=True)
332+
scml.fit(X, y)
333+
# Re-fitting should continue from previous fit
334+
before = class_separation(scml.transform(X), y)
335+
scml.fit(X, y)
336+
# We used the whole same dataset, so it can led to overfitting
337+
after = class_separation(scml.transform(X), y)
338+
if basis == "lda":
339+
assert before > after + 0.05 # For lda, it's better by a margin of 0.05
340+
else:
341+
assert before < after # For triplet_diffs, it overfits
326342

327343
class TestLSML(MetricTestCase):
328344
def test_iris(self):

0 commit comments

Comments
 (0)