Closed
Description
While trying to update the metric_plotting.ipynb notebook, I found that it didn't work for SDML, so here is a minimal example to reproduce the problem.
Steps/Code to Reproduce
import numpy as np
from sklearn.datasets import load_iris
from metric_learn import SDML_Supervised
dataset = load_iris()
X, y = dataset.data, dataset.target
sdml = SDML_Supervised(num_constraints=200)
sdml.fit(X, y, random_state = np.random.RandomState(1234))
Expected Results
No error is thrown
Actual Results
/home/will/anaconda3/envs/standard/lib/python3.7/site-packages/sklearn/utils/validation.py:761: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().
y = column_or_1d(y, warn=True)
/home/will/anaconda3/envs/standard/lib/python3.7/site-packages/sklearn/utils/deprecation.py:77: DeprecationWarning: Function pinvh is deprecated; sklearn.utils.extmath.pinvh was deprecated in version 0.19 and will be removed in 0.21. Use scipy.linalg.pinvh instead.
warnings.warn(msg, category=DeprecationWarning)
/home/will/anaconda3/envs/standard/lib/python3.7/site-packages/sklearn/utils/deprecation.py:77: DeprecationWarning: Function pinvh is deprecated; sklearn.utils.extmath.pinvh was deprecated in version 0.19 and will be removed in 0.21. Use scipy.linalg.pinvh instead.
warnings.warn(msg, category=DeprecationWarning)
/home/will/anaconda3/envs/standard/lib/python3.7/site-packages/sklearn/utils/deprecation.py:77: DeprecationWarning: Function graph_lasso is deprecated; The 'graph_lasso' was renamed to 'graphical_lasso' in version 0.20 and will be removed in 0.22.
warnings.warn(msg, category=DeprecationWarning)
---------------------------------------------------------------------------
FloatingPointError Traceback (most recent call last)
<ipython-input-2-caf37592524f> in <module>
3 X, y = load_iris(return_X_y=True)
4 sdml = SDML_Supervised(num_constraints=200)
----> 5 sdml.fit(X, y, random_state = np.random.RandomState(1234))
~/Code/metric-learn/metric_learn/sdml.py in fit(self, X, y, random_state)
181 random_state=random_state)
182 pairs, y = wrap_pairs(X, pos_neg)
--> 183 return _BaseSDML._fit(self, pairs, y)
~/Code/metric-learn/metric_learn/sdml.py in _fit(self, pairs, y)
68 # hack: ensure positive semidefinite
69 emp_cov = emp_cov.T.dot(emp_cov)
---> 70 _, self.M_ = graph_lasso(emp_cov, self.sparsity_param, verbose=self.verbose)
71
72 self.transformer_ = transformer_from_metric(self.M_)
~/anaconda3/envs/standard/lib/python3.7/site-packages/sklearn/utils/deprecation.py in wrapped(*args, **kwargs)
76 def wrapped(*args, **kwargs):
77 warnings.warn(msg, category=DeprecationWarning)
---> 78 return fun(*args, **kwargs)
79
80 wrapped.__doc__ = self._update_doc(wrapped.__doc__)
~/anaconda3/envs/standard/lib/python3.7/site-packages/sklearn/covariance/graph_lasso_.py in graph_lasso(emp_cov, alpha, cov_init, mode, tol, enet_tol, max_iter, verbose, return_costs, eps, return_n_iter)
815 return graphical_lasso(emp_cov, alpha, cov_init, mode, tol,
816 enet_tol, max_iter, verbose, return_costs,
--> 817 eps, return_n_iter)
818
819
~/anaconda3/envs/standard/lib/python3.7/site-packages/sklearn/covariance/graph_lasso_.py in graphical_lasso(emp_cov, alpha, cov_init, mode, tol, enet_tol, max_iter, verbose, return_costs, eps, return_n_iter)
267 e.args = (e.args[0]
268 + '. The system is too ill-conditioned for this solver',)
--> 269 raise e
270
271 if return_costs:
~/anaconda3/envs/standard/lib/python3.7/site-packages/sklearn/covariance/graph_lasso_.py in graphical_lasso(emp_cov, alpha, cov_init, mode, tol, enet_tol, max_iter, verbose, return_costs, eps, return_n_iter)
258 break
259 if not np.isfinite(cost) and i > 0:
--> 260 raise FloatingPointError('Non SPD result: the system is '
261 'too ill-conditioned for this solver')
262 else:
FloatingPointError: Non SPD result: the system is too ill-conditioned for this solver. The system is too ill-conditioned for this solver
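The sdml.py frame in the traceback shows that, just before calling graph_lasso, SDML replaces emp_cov with emp_cov.T.dot(emp_cov) (commented as a hack to ensure positive semidefiniteness). One possible contributor to the "too ill-conditioned" error, not confirmed by this report, is that this step squares the 2-norm condition number of the matrix. The sketch below uses only NumPy and a hypothetical 4x4 symmetric matrix standing in for SDML's internal matrix (it is not derived from the iris data); cond2 and is_spd are illustrative helpers, not metric-learn functions.

```python
import numpy as np


def cond2(a):
    """2-norm condition number: largest over smallest singular value."""
    s = np.linalg.svd(a, compute_uv=False)
    return s[0] / s[-1]


def is_spd(a):
    """True if `a` is symmetric and a Cholesky factorization succeeds."""
    if not np.allclose(a, a.T):
        return False
    try:
        np.linalg.cholesky(a)
        return True
    except np.linalg.LinAlgError:
        return False


# Hypothetical stand-in for the matrix SDML builds before the "hack":
# symmetric, with one small negative eigenvalue, so it is not PSD.
rng = np.random.RandomState(1234)
q, _ = np.linalg.qr(rng.randn(4, 4))          # random orthogonal basis
a = q @ np.diag([10.0, 1.0, 0.1, -1e-3]) @ q.T

# The PSD "hack" from sdml.py: squaring makes the matrix PSD,
# but also squares its condition number.
squared = a.T.dot(a)

print("a is SPD:", is_spd(a), "| cond:", cond2(a))
print("a.T.dot(a) is SPD:", is_spd(squared), "| cond:", cond2(squared))
```

Here the eigenvalues of `a` have magnitudes from 10 down to 1e-3, so squaring moves the condition number from roughly 1e4 to roughly 1e8; a matrix that was already hard for graphical lasso becomes much harder.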
Versions
Linux-4.4.0-141-generic-x86_64-with-debian-stretch-sid
Python 3.7.1 (default, Dec 14 2018, 19:28:38)
[GCC 7.3.0]
NumPy 1.15.4
SciPy 1.2.0
Scikit-Learn 0.20.2
Metric-Learn 0.4.0
Note that it does work with the following versions, although it prints this warning:
/home/will/anaconda3/envs/old_scipy_sk/lib/python2.7/site-packages/sklearn/covariance/graph_lasso_.py:252: ConvergenceWarning: graph_lasso: did not converge after 100 iteration: dual gap: 2.377e-04
ConvergenceWarning)
Linux-4.4.0-141-generic-x86_64-with-debian-stretch-sid
('Python', '2.7.15 |Anaconda, Inc.| (default, Dec 14 2018, 19:04:19) \n[GCC 7.3.0]')
('NumPy', '1.11.3')
('SciPy', '0.17.1')
('Scikit-Learn', '0.17.1')
metric-learn: commit: ddfac99
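The deprecation warnings in the Actual Results name replacements for both helpers SDML calls: scipy.linalg.pinvh instead of sklearn.utils.extmath.pinvh, and graphical_lasso instead of graph_lasso. Because the older working environment (scikit-learn 0.17.1) predates the graphical_lasso rename, a version-tolerant import is one way to use the new names in both setups. This is a sketch under that assumption, not metric-learn's actual fix, and the synthetic covariance is for illustration only.

```python
import numpy as np
from scipy.linalg import pinvh  # replacement suggested by the DeprecationWarning

try:
    # scikit-learn >= 0.20 exposes the renamed function
    from sklearn.covariance import graphical_lasso
except ImportError:
    # older scikit-learn (e.g. 0.17.1) only has the original name
    from sklearn.covariance import graph_lasso as graphical_lasso

# Small, well-conditioned covariance from synthetic data (illustration only)
rng = np.random.RandomState(0)
X = rng.randn(200, 3)
emp_cov = np.cov(X, rowvar=False)

# With default arguments, both names take (emp_cov, alpha)
# and return (covariance, precision)
covariance, precision = graphical_lasso(emp_cov, alpha=0.01)

# Pseudo-inverse of a symmetric matrix via SciPy
inv_cov = pinvh(emp_cov)
```

Swapping in these names only removes the deprecation noise; it does not address the FloatingPointError, which comes from the conditioning of the matrix passed to the solver.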