Skip to content

Commit e4132d6

Browse files
author
William de Vazelhes
committed
TST: remove skipping SDML in test_cross_validation_manual_vs_scikit
1 parent 8c50a0d commit e4132d6

File tree

1 file changed

+29
-31
lines changed

1 file changed

+29
-31
lines changed

test/test_sklearn_compat.py

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -125,39 +125,37 @@ def test_cross_validation_manual_vs_scikit(estimator, build_dataset,
125125
same as scikit-learn's cross-validation (some code for generating the
126126
folds is taken from scikit-learn).
127127
"""
128-
# TODO: remove this check when SDML has become deterministic
129-
if not str(estimator).startswith('SDML'):
130-
if any(hasattr(estimator, method) for method in ["predict", "score"]):
131-
input_data, labels, preprocessor, _ = build_dataset(with_preprocessor)
132-
estimator = clone(estimator)
133-
estimator.set_params(preprocessor=preprocessor)
134-
set_random_state(estimator)
135-
n_splits = 3
136-
kfold = KFold(shuffle=False, n_splits=n_splits)
137-
n_samples = input_data.shape[0]
138-
fold_sizes = (n_samples // n_splits) * np.ones(n_splits, dtype=np.int)
139-
fold_sizes[:n_samples % n_splits] += 1
140-
current = 0
141-
scores, predictions = [], np.zeros(input_data.shape[0])
142-
for fold_size in fold_sizes:
143-
start, stop = current, current + fold_size
144-
current = stop
145-
test_slice = slice(start, stop)
146-
train_mask = np.ones(input_data.shape[0], bool)
147-
train_mask[test_slice] = False
148-
y_train, y_test = labels[train_mask], labels[test_slice]
149-
estimator.fit(input_data[train_mask], y_train)
150-
if hasattr(estimator, "score"):
151-
scores.append(estimator.score(input_data[test_slice], y_test))
152-
if hasattr(estimator, "predict"):
153-
predictions[test_slice] = estimator.predict(input_data[test_slice])
128+
if any(hasattr(estimator, method) for method in ["predict", "score"]):
129+
input_data, labels, preprocessor, _ = build_dataset(with_preprocessor)
130+
estimator = clone(estimator)
131+
estimator.set_params(preprocessor=preprocessor)
132+
set_random_state(estimator)
133+
n_splits = 3
134+
kfold = KFold(shuffle=False, n_splits=n_splits)
135+
n_samples = input_data.shape[0]
136+
fold_sizes = (n_samples // n_splits) * np.ones(n_splits, dtype=np.int)
137+
fold_sizes[:n_samples % n_splits] += 1
138+
current = 0
139+
scores, predictions = [], np.zeros(input_data.shape[0])
140+
for fold_size in fold_sizes:
141+
start, stop = current, current + fold_size
142+
current = stop
143+
test_slice = slice(start, stop)
144+
train_mask = np.ones(input_data.shape[0], bool)
145+
train_mask[test_slice] = False
146+
y_train, y_test = labels[train_mask], labels[test_slice]
147+
estimator.fit(input_data[train_mask], y_train)
154148
if hasattr(estimator, "score"):
155-
assert all(scores == cross_val_score(estimator, input_data, labels,
156-
cv=kfold))
149+
scores.append(estimator.score(input_data[test_slice], y_test))
157150
if hasattr(estimator, "predict"):
158-
assert all(predictions == cross_val_predict(estimator, input_data,
159-
labels,
160-
cv=kfold))
151+
predictions[test_slice] = estimator.predict(input_data[test_slice])
152+
if hasattr(estimator, "score"):
153+
assert all(scores == cross_val_score(estimator, input_data, labels,
154+
cv=kfold))
155+
if hasattr(estimator, "predict"):
156+
assert all(predictions == cross_val_predict(estimator, input_data,
157+
labels,
158+
cv=kfold))
161159

162160

163161
def check_score(estimator, tuples, y):

0 commit comments

Comments (0)