feat: Added Hybrid Search Config and Tests [1/N] #211
Draft: vishwarajanand wants to merge 1 commit into langchain-ai:main from vishwarajanand:hybrid_search_1
langchain_postgres/v2/hybrid_search_config.py (new file, +143 lines):
```python
from abc import ABC
from dataclasses import dataclass, field
from typing import Any, Callable, Optional, Sequence

from sqlalchemy import RowMapping


def weighted_sum_ranking(
    primary_search_results: Sequence[RowMapping],
    secondary_search_results: Sequence[RowMapping],
    primary_results_weight: float = 0.5,
    secondary_results_weight: float = 0.5,
    fetch_top_k: int = 4,
) -> Sequence[dict[str, Any]]:
    """
    Ranks documents using a weighted sum of scores from two sources.

    Args:
        primary_search_results: A sequence of RowMappings from the primary
            search; the first column is the document id and the last column
            is the distance.
        secondary_search_results: A sequence of RowMappings from the
            secondary search, with the same column layout.
        primary_results_weight: The weight for the primary source's scores.
            Defaults to 0.5.
        secondary_results_weight: The weight for the secondary source's
            scores. Defaults to 0.5.
        fetch_top_k: The number of documents to return after merging the
            results. Defaults to 4.

    Returns:
        A list of row dictionaries, sorted by the fused "distance" value in
        descending order.
    """
    # Stores each row keyed by doc id, with "distance" holding the weighted
    # score accumulated so far.
    weighted_scores: dict[str, dict[str, Any]] = {}

    # Process results from the primary source.
    for row in primary_search_results:
        values = list(row.values())
        doc_id = str(values[0])  # first value is doc_id
        distance = float(values[-1])  # type: ignore # last value is distance
        row_values = dict(row)
        row_values["distance"] = primary_results_weight * distance
        weighted_scores[doc_id] = row_values

    # Process results from the secondary source, adding to existing scores
    # or creating new entries.
    for row in secondary_search_results:
        values = list(row.values())
        doc_id = str(values[0])  # first value is doc_id
        distance = float(values[-1])  # type: ignore # last value is distance
        primary_score = (
            weighted_scores[doc_id]["distance"] if doc_id in weighted_scores else 0.0
        )
        row_values = dict(row)
        row_values["distance"] = distance * secondary_results_weight + primary_score
        weighted_scores[doc_id] = row_values

    # Sort the results by weighted score in descending order.
    ranked_results = sorted(
        weighted_scores.values(), key=lambda item: item["distance"], reverse=True
    )
    return ranked_results[:fetch_top_k]


def reciprocal_rank_fusion(
    primary_search_results: Sequence[RowMapping],
    secondary_search_results: Sequence[RowMapping],
    rrf_k: float = 60,
    fetch_top_k: int = 4,
) -> Sequence[dict[str, Any]]:
    """
    Ranks documents using Reciprocal Rank Fusion (RRF) of scores from two
    sources.

    Args:
        primary_search_results: A sequence of RowMappings from the primary
            search; the first column is the document id and the last column
            is the distance.
        secondary_search_results: A sequence of RowMappings from the
            secondary search, with the same column layout.
        rrf_k: The RRF smoothing parameter k. Defaults to 60.
        fetch_top_k: The number of documents to return after merging the
            results. Defaults to 4.

    Returns:
        A list of row dictionaries, sorted by the fused "distance" (RRF
        score) in descending order.
    """
    rrf_scores: dict[str, dict[str, Any]] = {}

    # Process results from the primary source.
    for rank, row in enumerate(
        sorted(primary_search_results, key=lambda item: item["distance"], reverse=True)
    ):
        values = list(row.values())
        doc_id = str(values[0])
        row_values = dict(row)
        primary_score = rrf_scores[doc_id]["distance"] if doc_id in rrf_scores else 0.0
        primary_score += 1.0 / (rank + rrf_k)
        row_values["distance"] = primary_score
        rrf_scores[doc_id] = row_values

    # Process results from the secondary source.
    for rank, row in enumerate(
        sorted(
            secondary_search_results, key=lambda item: item["distance"], reverse=True
        )
    ):
        values = list(row.values())
        doc_id = str(values[0])
        row_values = dict(row)
        secondary_score = (
            rrf_scores[doc_id]["distance"] if doc_id in rrf_scores else 0.0
        )
        secondary_score += 1.0 / (rank + rrf_k)
        row_values["distance"] = secondary_score
        rrf_scores[doc_id] = row_values

    # Sort the results by RRF score in descending order.
    ranked_results = sorted(
        rrf_scores.values(), key=lambda item: item["distance"], reverse=True
    )
    return ranked_results[:fetch_top_k]


@dataclass
class HybridSearchConfig(ABC):
    """Google AlloyDB Vector Store Hybrid Search Config."""

    tsv_column: Optional[str] = ""
    tsv_lang: Optional[str] = "pg_catalog.english"
    fts_query: Optional[str] = ""
    fusion_function: Callable[
        [Sequence[RowMapping], Sequence[RowMapping], Any], Sequence[Any]
    ] = weighted_sum_ranking
    fusion_function_parameters: dict[str, Any] = field(default_factory=dict)
    primary_top_k: int = 4
    secondary_top_k: int = 4
    index_name: str = "langchain_tsv_index"
    index_type: str = "GIN"
```
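For intuition, the weighted-sum fusion above can be sketched standalone, with plain dicts standing in for RowMappings (the ids and scores below are made up for illustration; no package import is needed):

```python
# Mock rows: each dict stands in for a RowMapping, with the doc id and its
# search distance (higher = more relevant in this sketch).
primary = [{"id": "common", "distance": 0.8}, {"id": "p_only", "distance": 0.7}]
secondary = [{"id": "common", "distance": 0.9}, {"id": "s_only", "distance": 0.6}]

# Weighted sum with the default 0.5/0.5 weights, mirroring weighted_sum_ranking:
# a doc seen by both sources accumulates both weighted contributions.
scores: dict[str, float] = {}
for row in primary:
    scores[row["id"]] = 0.5 * row["distance"]
for row in secondary:
    scores[row["id"]] = scores.get(row["id"], 0.0) + 0.5 * row["distance"]

# Sort by fused score, descending: common (0.85), p_only (0.35), s_only (0.30).
ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
```

The doc found by both sources ("common") outranks either single-source doc even though its individual distances are comparable, which is the point of the fusion.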
Unit tests for the fusion functions (new file, +220 lines):
```python
import pytest

from langchain_postgres.v2.hybrid_search_config import (
    reciprocal_rank_fusion,
    weighted_sum_ranking,
)


# Helper to create mock input items that mimic RowMapping for the fusion functions.
def get_row(doc_id: str, score: float, content: str = "content") -> dict:
    """
    Simulates a RowMapping-like dictionary.

    The fusion functions expect to extract doc_id as the first value and the
    initial score/distance as the last value when listing a RowMapping's
    values. They then operate on dictionaries, using the 'distance' key for
    the fused score.
    """
    # Python dicts maintain insertion order (Python 3.7+).
    # This structure ensures list(row.values())[0] is doc_id and
    # list(row.values())[-1] is score.
    return {"id_val": doc_id, "content_field": content, "distance": score}


class TestWeightedSumRanking:
    def test_empty_inputs(self):
        results = weighted_sum_ranking([], [])
        assert results == []

    def test_primary_only(self):
        primary = [get_row("p1", 0.8), get_row("p2", 0.6)]
        # Expected scores: p1 = 0.8 * 0.5 = 0.4, p2 = 0.6 * 0.5 = 0.3
        results = weighted_sum_ranking(
            primary, [], primary_results_weight=0.5, secondary_results_weight=0.5
        )
        assert len(results) == 2
        assert results[0]["id_val"] == "p1"
        assert results[0]["distance"] == pytest.approx(0.4)
        assert results[1]["id_val"] == "p2"
        assert results[1]["distance"] == pytest.approx(0.3)

    def test_secondary_only(self):
        secondary = [get_row("s1", 0.9), get_row("s2", 0.7)]
        # Expected scores: s1 = 0.9 * 0.5 = 0.45, s2 = 0.7 * 0.5 = 0.35
        results = weighted_sum_ranking(
            [], secondary, primary_results_weight=0.5, secondary_results_weight=0.5
        )
        assert len(results) == 2
        assert results[0]["id_val"] == "s1"
        assert results[0]["distance"] == pytest.approx(0.45)
        assert results[1]["id_val"] == "s2"
        assert results[1]["distance"] == pytest.approx(0.35)

    def test_mixed_results_default_weights(self):
        primary = [get_row("common", 0.8), get_row("p_only", 0.7)]
        secondary = [get_row("common", 0.9), get_row("s_only", 0.6)]
        # Weights are 0.5, 0.5
        # common_score = (0.8 * 0.5) + (0.9 * 0.5) = 0.4 + 0.45 = 0.85
        # p_only_score = (0.7 * 0.5) = 0.35
        # s_only_score = (0.6 * 0.5) = 0.30
        # Order: common (0.85), p_only (0.35), s_only (0.30)
        results = weighted_sum_ranking(primary, secondary)
        assert len(results) == 3
        assert results[0]["id_val"] == "common"
        assert results[0]["distance"] == pytest.approx(0.85)
        assert results[1]["id_val"] == "p_only"
        assert results[1]["distance"] == pytest.approx(0.35)
        assert results[2]["id_val"] == "s_only"
        assert results[2]["distance"] == pytest.approx(0.30)

    def test_mixed_results_custom_weights(self):
        primary = [get_row("d1", 1.0)]  # p_w=0.2 -> 0.2
        secondary = [get_row("d1", 0.5)]  # s_w=0.8 -> 0.4
        # Expected: d1_score = (1.0 * 0.2) + (0.5 * 0.8) = 0.2 + 0.4 = 0.6
        results = weighted_sum_ranking(
            primary, secondary, primary_results_weight=0.2, secondary_results_weight=0.8
        )
        assert len(results) == 1
        assert results[0]["id_val"] == "d1"
        assert results[0]["distance"] == pytest.approx(0.6)

    def test_fetch_top_k(self):
        primary = [get_row(f"p{i}", (10 - i) / 10.0) for i in range(5)]
        # Scores: 1.0, 0.9, 0.8, 0.7, 0.6
        # Weighted (0.5): 0.5, 0.45, 0.4, 0.35, 0.3
        secondary = []
        results = weighted_sum_ranking(primary, secondary, fetch_top_k=2)
        assert len(results) == 2
        assert results[0]["id_val"] == "p0"
        assert results[0]["distance"] == pytest.approx(0.5)
        assert results[1]["id_val"] == "p1"
        assert results[1]["distance"] == pytest.approx(0.45)


class TestReciprocalRankFusion:
    def test_empty_inputs(self):
        results = reciprocal_rank_fusion([], [])
        assert results == []

    def test_primary_only(self):
        primary = [
            get_row("p1", 0.8),
            get_row("p2", 0.6),
        ]  # p1 rank 0, p2 rank 1
        rrf_k = 60
        # p1_score = 1 / (0 + 60)
        # p2_score = 1 / (1 + 60)
        results = reciprocal_rank_fusion(primary, [], rrf_k=rrf_k)
        assert len(results) == 2
        assert results[0]["id_val"] == "p1"
        assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k))
        assert results[1]["id_val"] == "p2"
        assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k))

    def test_secondary_only(self):
        secondary = [
            get_row("s1", 0.9),
            get_row("s2", 0.7),
        ]  # s1 rank 0, s2 rank 1
        rrf_k = 60
        results = reciprocal_rank_fusion([], secondary, rrf_k=rrf_k)
        assert len(results) == 2
        assert results[0]["id_val"] == "s1"
        assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k))
        assert results[1]["id_val"] == "s2"
        assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k))

    def test_mixed_results_default_k(self):
        primary = [get_row("common", 0.8), get_row("p_only", 0.7)]
        secondary = [get_row("common", 0.9), get_row("s_only", 0.6)]
        rrf_k = 60
        # common_score = (1/(0+k))_prim + (1/(0+k))_sec = 2/k
        # p_only_score = (1/(1+k))_prim = 1/(k+1)
        # s_only_score = (1/(1+k))_sec = 1/(k+1)
        results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k)
        assert len(results) == 3
        assert results[0]["id_val"] == "common"
        assert results[0]["distance"] == pytest.approx(2.0 / rrf_k)
        # Check the next two elements; their order may vary due to the tie in score.
        next_ids = {results[1]["id_val"], results[2]["id_val"]}
        next_scores = {results[1]["distance"], results[2]["distance"]}
        assert next_ids == {"p_only", "s_only"}
        for score in next_scores:
            assert score == pytest.approx(1.0 / (1 + rrf_k))

    def test_fetch_top_k_rrf(self):
        primary = [get_row(f"p{i}", (10 - i) / 10.0) for i in range(5)]
        secondary = []
        rrf_k = 1
        results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k, fetch_top_k=2)
        assert len(results) == 2
        assert results[0]["id_val"] == "p0"
        assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k))
        assert results[1]["id_val"] == "p1"
        assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k))

    def test_rrf_content_preservation(self):
        primary = [get_row("doc1", 0.9, content="Primary Content")]
        secondary = [get_row("doc1", 0.8, content="Secondary Content")]
        # RRF processes primary then secondary. If a doc is in both,
        # the content from the secondary list will overwrite primary's.
        results = reciprocal_rank_fusion(primary, secondary, rrf_k=60)
        assert len(results) == 1
        assert results[0]["id_val"] == "doc1"
        assert results[0]["content_field"] == "Secondary Content"

        # If only in primary
        results_prim_only = reciprocal_rank_fusion(primary, [], rrf_k=60)
        assert results_prim_only[0]["content_field"] == "Primary Content"

    def test_reordering_from_inputs_rrf(self):
        """
        Tests that the RRF fused ranking can differ from both the primary
        and secondary input rankings.

        Primary order: A, B, C
        Secondary order: C, B, A
        Fused order: (A, C) tied, then B
        """
        primary = [
            get_row("docA", 0.9),
            get_row("docB", 0.8),
            get_row("docC", 0.1),
        ]
        secondary = [
            get_row("docC", 0.9),
            get_row("docB", 0.5),
            get_row("docA", 0.2),
        ]
        rrf_k = 1.0  # Using 1.0 for k to simplify the rank score calculation.
        # docA_score = 1/(0+1) [P] + 1/(2+1) [S] = 1 + 1/3 = 4/3
        # docB_score = 1/(1+1) [P] + 1/(1+1) [S] = 1/2 + 1/2 = 1
        # docC_score = 1/(2+1) [P] + 1/(0+1) [S] = 1/3 + 1 = 4/3
        results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k)
        assert len(results) == 3
        assert {results[0]["id_val"], results[1]["id_val"]} == {"docA", "docC"}
        assert results[0]["distance"] == pytest.approx(4.0 / 3.0)
        assert results[1]["distance"] == pytest.approx(4.0 / 3.0)
        assert results[2]["id_val"] == "docB"
        assert results[2]["distance"] == pytest.approx(1.0)

    def test_reordering_from_inputs_weighted_sum(self):
        """
        Tests that the fused ranking can differ from both the primary and
        secondary input rankings.

        Primary order: A (0.9), B (0.7)
        Secondary order: B (0.8), A (0.2)
        Fusion (0.5/0.5 weights):
            docA_score = (0.9 * 0.5) + (0.2 * 0.5) = 0.45 + 0.10 = 0.55
            docB_score = (0.7 * 0.5) + (0.8 * 0.5) = 0.35 + 0.40 = 0.75
        Expected fused order: docB (0.75), docA (0.55), which differs from
        both input orderings.
        """
        primary = [get_row("docA", 0.9), get_row("docB", 0.7)]
        secondary = [get_row("docB", 0.8), get_row("docA", 0.2)]

        results = weighted_sum_ranking(primary, secondary)
        assert len(results) == 2
        assert results[0]["id_val"] == "docB"
        assert results[0]["distance"] == pytest.approx(0.75)
        assert results[1]["id_val"] == "docA"
        assert results[1]["distance"] == pytest.approx(0.55)
```
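The RRF reordering test above can also be checked by hand. This standalone sketch computes the same scores directly from the 0-based ranks, without importing the package (doc names and ranks copied from the test scenario):

```python
# 0-based rank of each doc in the two result lists, as in the test scenario:
# primary order A, B, C; secondary order C, B, A.
primary_rank = {"docA": 0, "docB": 1, "docC": 2}
secondary_rank = {"docC": 0, "docB": 1, "docA": 2}
k = 1.0  # rrf_k, chosen small so the arithmetic is easy to follow

# RRF score: sum over sources of 1 / (rank + k).
rrf = {
    doc: 1.0 / (primary_rank[doc] + k) + 1.0 / (secondary_rank[doc] + k)
    for doc in primary_rank
}
# docA = 1/1 + 1/3 = 4/3; docB = 1/2 + 1/2 = 1; docC = 1/3 + 1/1 = 4/3
```

Because RRF depends only on ranks, not raw distances, docA and docC tie despite very different raw scores, which is exactly what the test asserts.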
Review comment: Update docstring :)