@@ -105,6 +105,70 @@ def stable_init(self, n_components=None, pca_comps=None,
105
105
106
106
# ---------------------- Test scikit-learn compatibility ----------------------
107
107
108
+ def generate_array_like (input_data , labels = None ):
109
+ """Helper function to generate array-like variants of numpy datasets,
110
+ for testing purposes."""
111
+ list_data = input_data .tolist ()
112
+ input_data_changed = [input_data , list_data , tuple (list_data )]
113
+ if input_data .ndim >= 2 :
114
+ input_data_changed .append (tuple (tuple (x ) for x in list_data ))
115
+ if input_data .ndim >= 3 :
116
+ input_data_changed .append (tuple (tuple (tuple (x ) for x in y ) for y in
117
+ list_data ))
118
+ if input_data .ndim == 2 :
119
+ pd = pytest .importorskip ('pandas' )
120
+ input_data_changed .append (pd .DataFrame (input_data ))
121
+ if labels is not None :
122
+ labels_changed = [labels , list (labels ), tuple (labels )]
123
+ else :
124
+ labels_changed = [labels ]
125
+ return input_data_changed , labels_changed
126
+
127
+
128
+ @pytest .mark .integration
129
+ @pytest .mark .parametrize ('with_preprocessor' , [True , False ])
130
+ @pytest .mark .parametrize ('estimator, build_dataset' , metric_learners ,
131
+ ids = ids_metric_learners )
132
+ def test_array_like_inputs (estimator , build_dataset , with_preprocessor ):
133
+ """Test that metric-learners can have as input (of all functions that are
134
+ applied on data) any array-like object."""
135
+ input_data , labels , preprocessor , X = build_dataset (with_preprocessor )
136
+
137
+ # we subsample the data for the test to be more efficient
138
+ input_data , _ , labels , _ = train_test_split (input_data , labels ,
139
+ train_size = 20 )
140
+ X = X [:10 ]
141
+
142
+ estimator = clone (estimator )
143
+ estimator .set_params (preprocessor = preprocessor )
144
+ set_random_state (estimator )
145
+ input_variants , label_variants = generate_array_like (input_data , labels )
146
+ for input_variant in input_variants :
147
+ for label_variant in label_variants :
148
+ estimator .fit (* remove_y_quadruplets (estimator , input_variant ,
149
+ label_variant ))
150
+ if hasattr (estimator , "predict" ):
151
+ estimator .predict (input_variant )
152
+ if hasattr (estimator , "predict_proba" ):
153
+ estimator .predict_proba (input_variant ) # anticipation in case some
154
+ # time we have that, or if ppl want to contribute with new algorithms
155
+ # it will be checked automatically
156
+ if hasattr (estimator , "decision_function" ):
157
+ estimator .decision_function (input_variant )
158
+ if hasattr (estimator , "score" ):
159
+ for label_variant in label_variants :
160
+ estimator .score (* remove_y_quadruplets (estimator , input_variant ,
161
+ label_variant ))
162
+
163
+ X_variants , _ = generate_array_like (X )
164
+ for X_variant in X_variants :
165
+ estimator .transform (X_variant )
166
+
167
+ pairs = np .array ([[X [0 ], X [1 ]], [X [0 ], X [2 ]]])
168
+ pairs_variants , _ = generate_array_like (pairs )
169
+ for pairs_variant in pairs_variants :
170
+ estimator .score_pairs (pairs_variant )
171
+
108
172
109
173
@pytest .mark .parametrize ('with_preprocessor' , [True , False ])
110
174
@pytest .mark .parametrize ('estimator, build_dataset' , pairs_learners ,
0 commit comments