Skip to content

Commit e2fbfbe

Browse files
maxi-marufoMaximiliano Marufo da Silva
authored and
Maximiliano Marufo da Silva
committed
Adds test for warm_start parameter
1 parent 676165a commit e2fbfbe

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

test/metric_learn_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,22 @@ def test_large_output_iter(self):
324324
scml.fit(triplets)
325325
assert msg == raised_error.value.args[0]
326326

327+
@pytest.mark.parametrize("basis", ("lda", "triplet_diffs"))
328+
def test_warm_start(self, basis):
329+
X, y = load_iris(return_X_y=True)
330+
# Should work with warm_start=True even with first fit
331+
scml = SCML_Supervised(basis=basis, n_basis=85, k_genuine=7, k_impostor=5,
332+
random_state=42, warm_start=True)
333+
scml.fit(X, y)
334+
# Re-fitting should continue from previous fit
335+
before = class_separation(scml.transform(X), y)
336+
scml.fit(X, y)
337+
# We used the whole same dataset, so it can led to overfitting
338+
after = class_separation(scml.transform(X), y)
339+
if basis == "lda":
340+
assert before > after + 0.05 # For lda, it's better by a margin of 0.05
341+
else:
342+
assert before < after # For triplet_diffs, it overfits
327343

328344
class TestLSML(MetricTestCase):
329345
def test_iris(self):

0 commit comments

Comments
 (0)