From 7780063ae4cb3b6fa5eb8c91811b1f6695143259 Mon Sep 17 00:00:00 2001 From: CJ Carey Date: Fri, 22 May 2020 14:26:03 -0400 Subject: [PATCH] Use scipy's logsumexp function Fixes gh-289 --- metric_learn/mlkr.py | 8 +++----- metric_learn/nca.py | 6 +++--- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/metric_learn/mlkr.py b/metric_learn/mlkr.py index 5fffee9b..c65341be 100644 --- a/metric_learn/mlkr.py +++ b/metric_learn/mlkr.py @@ -6,16 +6,14 @@ import sys import warnings import numpy as np -from sklearn.exceptions import ConvergenceWarning, ChangedBehaviorWarning -from sklearn.utils.fixes import logsumexp from scipy.optimize import minimize +from scipy.special import logsumexp from sklearn.base import TransformerMixin - +from sklearn.exceptions import ConvergenceWarning, ChangedBehaviorWarning from sklearn.metrics import pairwise_distances -from metric_learn._util import _check_n_components from .base_metric import MahalanobisMixin -from ._util import _initialize_components +from ._util import _initialize_components, _check_n_components EPS = np.finfo(float).eps diff --git a/metric_learn/nca.py b/metric_learn/nca.py index fbce5658..d09e7282 100644 --- a/metric_learn/nca.py +++ b/metric_learn/nca.py @@ -8,10 +8,10 @@ import sys import numpy as np from scipy.optimize import minimize -from sklearn.metrics import pairwise_distances -from sklearn.exceptions import ConvergenceWarning, ChangedBehaviorWarning -from sklearn.utils.fixes import logsumexp +from scipy.special import logsumexp from sklearn.base import TransformerMixin +from sklearn.exceptions import ConvergenceWarning, ChangedBehaviorWarning +from sklearn.metrics import pairwise_distances from ._util import _initialize_components, _check_n_components from .base_metric import MahalanobisMixin