Skip to content

Commit b8530b2

Browse files
Support for SLEP010
This requires setting a public `n_features_in_` attribute as part of the fit() logic. For details, see: https://scikit-learn-enhancement-proposals.readthedocs.io/en/latest/slep010/proposal.html
1 parent 72b76c8 commit b8530b2

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ htmlcov/
88
.pytest_cache/
99
doc/auto_examples/*
1010
doc/generated/*
11-
venv/
11+
venv/
12+
.vscode/

metric_learn/base_metric.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,12 +166,16 @@ def _prepare_inputs(self, X, y=None, type_of_inputs='classic',
166166
self._check_preprocessor()
167167

168168
check_is_fitted(self, ['preprocessor_'])
169-
return check_input(X, y,
169+
outs = check_input(X, y,
170170
type_of_inputs=type_of_inputs,
171171
preprocessor=self.preprocessor_,
172172
estimator=self,
173173
tuple_size=getattr(self, '_tuple_size', None),
174174
**kwargs)
175+
# Conform to SLEP010
176+
if not hasattr(self, 'n_features_in_'):
177+
self.n_features_in_ = (outs if y is None else outs[0]).shape[1]
178+
return outs
175179

176180
@abstractmethod
177181
def get_metric(self):

0 commit comments

Comments
 (0)