From b554b3a1b2f1b885f61b3183668c5a9a30b4c1fd Mon Sep 17 00:00:00 2001 From: IQ Date: Tue, 6 Aug 2024 11:50:13 -0700 Subject: [PATCH] address issue for mmr calculation --- langchain_postgres/_utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/langchain_postgres/_utils.py b/langchain_postgres/_utils.py index 9d8055af..14eae3b7 100644 --- a/langchain_postgres/_utils.py +++ b/langchain_postgres/_utils.py @@ -30,10 +30,8 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: X = np.array(X, dtype=np.float32) Y = np.array(Y, dtype=np.float32) - Z = 1 - simd.cdist(X, Y, metric="cosine") - if isinstance(Z, float): - return np.array([Z]) - return np.array(Z) + Z = 1 - np.array(simd.cdist(X, Y, metric="cosine")) + return Z except ImportError: logger.debug( "Unable to import simsimd, defaulting to NumPy implementation. If you want "