Add support for semantic search with pandas #293

Open · wants to merge 3 commits into main

4 changes: 2 additions & 2 deletions llmstack/apps/runner/agent_actor.py
@@ -5,10 +5,10 @@
from typing import Any, Dict, List

from llmstack.apps.runner.agent_controller import (
AgentController,
AgentControllerConfig,
AgentControllerData,
AgentControllerDataType,
AgentControllerFactory,
AgentMessageContent,
AgentMessageContentType,
AgentToolCallsMessage,
@@ -60,7 +60,7 @@ def __init__(
)

self._agent_output_queue = asyncio.Queue()
self._agent_controller = AgentController(self._agent_output_queue, self._controller_config)
self._agent_controller = AgentControllerFactory.create(self._agent_output_queue, self._controller_config)

def _add_error_from_tool_call(self, output_index, tool_name, tool_call_id, errors):
error_message = "\n".join([error for error in errors])
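
Note on this change: agent_actor.py now obtains its controller from AgentControllerFactory.create() instead of constructing AgentController directly, so the factory can pick a controller implementation from the configuration. Since the agent_controller.py diff is not rendered below, the following is only a minimal sketch of such a factory pattern; the "provider" field and the RemoteAgentController class are illustrative assumptions, not code from this PR.

# Sketch only; names other than AgentController/AgentControllerFactory are hypothetical.
import asyncio

class AgentController:
    def __init__(self, output_queue: asyncio.Queue, config):
        self._output_queue = output_queue
        self._config = config

class RemoteAgentController(AgentController):
    """Placeholder for an alternative controller implementation."""

class AgentControllerFactory:
    @staticmethod
    def create(output_queue: asyncio.Queue, config) -> AgentController:
        # Select an implementation from the config instead of hard-coding one at the call site.
        if getattr(config, "provider", None) == "remote":
            return RemoteAgentController(output_queue, config)
        return AgentController(output_queue, config)
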
420 changes: 219 additions & 201 deletions llmstack/apps/runner/agent_controller.py

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions llmstack/data/apis.py
@@ -140,6 +140,7 @@ def destinations(self, request):
def transformations(self, request):
from llmstack.data.transformations import (
CodeSplitter,
CSVTextSplitter,
SemanticDoubleMergingSplitterNodeParser,
SentenceSplitter,
UnstructuredIOSplitter,
@@ -171,6 +172,12 @@ def transformations(self, request):
"schema": UnstructuredIOSplitter.get_schema(),
"ui_schema": UnstructuredIOSplitter.get_ui_schema(),
},
{
"slug": CSVTextSplitter.slug(),
"provider_slug": CSVTextSplitter.provider_slug(),
"schema": CSVTextSplitter.get_schema(),
"ui_schema": CSVTextSplitter.get_ui_schema(),
},
]
)

22 changes: 16 additions & 6 deletions llmstack/data/destinations/stores/pandas.py
@@ -95,7 +95,7 @@ def add(self, document):
k: v for k, v in document.extra_info.get("extra_data", {}).items() if k in [m.target for m in self.mapping]
}
for item in self.schema:
if item.name == "id" or item.name == "text":
if item.name == "id" or item.name == "text" or item.name == "embedding":
continue
if item.name not in extra_data:
if item.type == "string":
@@ -112,7 +112,8 @@ def add(self, document):
extra_data[item.name] = str(extra_data[item.name])

for node in document.nodes:
document_dict = {"text": node.text, **extra_data}
node_metadata = node.metadata
document_dict = {"text": node.text, "embedding": node.embedding, **extra_data, **node_metadata}
entry_dict = {
"id": node.id_,
**{mapping.source: document_dict.get(mapping.target) for mapping in self.mapping},
@@ -137,10 +138,19 @@ def delete(self, document: DataDocument):
self._asset.update_file(buffer.getvalue(), filename)

def search(self, query: str, **kwargs):
result = self._dataframe.query(query).to_dict(orient="records")
nodes = list(
map(lambda x: TextNode(text=json.dumps(x), metadata={"query": query, "source": self._name}), result)
)
df = self._dataframe
df = df.query(kwargs.get("search_filters") or query)
result = df.to_dict(orient="records")
nodes = []
for entry in result:
entry.pop("embedding")
nodes.append(
TextNode(
text=json.dumps(entry),
metadata={"query": query, "source": self._name, "search_filters": kwargs.get("search_filters")},
)
)

node_ids = list(map(lambda x: x["id"], result))
return VectorStoreQueryResult(nodes=nodes, ids=node_ids, similarities=[])

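
For reference, the destination-side search now evaluates the caller-supplied filters with pandas.DataFrame.query (falling back to the query text), drops the stored embedding column, and serializes each matching row into a TextNode. A standalone sketch of just that filtering step, with made-up columns and filter expression:

# Minimal sketch of the pandas filtering used by search(); data and filter are illustrative.
import json
import pandas as pd

df = pd.DataFrame(
    [
        {"id": "1", "text": "alpha", "embedding": [0.1, 0.2], "price": 10},
        {"id": "2", "text": "beta", "embedding": [0.3, 0.4], "price": 25},
    ]
)

search_filters = "price > 15"  # a pandas query expression
rows = df.query(search_filters).to_dict(orient="records")

for row in rows:
    row.pop("embedding", None)  # embeddings are stored but not returned to the caller
    print(json.dumps(row))      # {"id": "2", "text": "beta", "price": 25}
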
5 changes: 1 addition & 4 deletions llmstack/data/pipeline.py
@@ -99,12 +99,9 @@ def search(self, query: str, use_hybrid_search=True, **kwargs) -> List[dict]:
content_key = self.datasource.destination_text_content_key
query_embedding = None

if kwargs.get("search_filters", None):
raise NotImplementedError("Search filters are not supported for this data source.")

documents = []

if self._embedding_generator:
if query and self._embedding_generator:
query_embedding = self._embedding_generator.get_embedding(query)

if self._destination:
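
In the pipeline, the hard NotImplementedError for search filters is gone: filters are forwarded to the destination, and a query embedding is computed only when query text is present, so pure filter lookups skip the embedding call. A rough sketch of that control flow; the function signature and argument names here are assumptions for illustration only:

# Illustrative control flow, not the actual pipeline code.
def search(destination, embedding_generator, query=None, search_filters=None, limit=10):
    query_embedding = None
    if query and embedding_generator:
        # Only compute an embedding when there is query text to embed.
        query_embedding = embedding_generator.get_embedding(query)

    # Filters (e.g. a pandas query expression) pass straight through to the destination.
    return destination.search(
        query=query,
        query_embedding=query_embedding,
        search_filters=search_filters,
        limit=limit,
    )
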
5 changes: 1 addition & 4 deletions llmstack/data/transformations/__init__.py
@@ -17,10 +17,7 @@

@cache
def get_transformer_cls(slug, provider_slug):
for cls in [
UnstructuredIOSplitter,
EmbeddingsGenerator,
]:
for cls in [UnstructuredIOSplitter, EmbeddingsGenerator, CSVTextSplitter]:
if cls.slug() == slug and cls.provider_slug() == provider_slug:
return cls

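
get_transformer_cls() resolves a transformer class by its (slug, provider_slug) pair, so adding CSVTextSplitter to this list is what makes it reachable from the API. A small usage sketch; the exact slug string returned by CSVTextSplitter.slug() is assumed here, since it is not shown in this diff:

# Usage sketch; the "csv-text-splitter" slug is an assumption.
from llmstack.data.transformations import get_transformer_cls

cls = get_transformer_cls("csv-text-splitter", "promptly")
if cls is not None:
    splitter = cls(exclude_columns=["internal_id"], metadata_prefix="cts_")
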
109 changes: 78 additions & 31 deletions llmstack/data/transformations/splitters.py
@@ -1,20 +1,54 @@
import csv
from io import StringIO
import json
import logging
from typing import List, Optional

from llama_index.core.node_parser.interface import MetadataAwareTextSplitter
from pydantic import Field
from llama_index.core.bridge.pydantic import Field
from llama_index.core.node_parser.interface import TextSplitter

from llmstack.assets.utils import get_asset_by_objref
from llmstack.common.blocks.base.schema import get_ui_schema_from_json_schema

class CSVTextSplitter(MetadataAwareTextSplitter):
include_columns: Optional[List[str]] = Field(
default=None,
description="Columns to include in the text",
)
logger = logging.getLogger(__name__)


class PromptlyTransformers:
@classmethod
def get_schema(cls):
json_schema = cls.schema()
json_schema["properties"].pop("callback_manager", None)
json_schema["properties"].pop("class_name", None)
json_schema["properties"].pop("include_metadata", None)
json_schema["properties"].pop("include_prev_next_rel", None)
return json_schema

@classmethod
def get_ui_schema(cls):
return get_ui_schema_from_json_schema(cls.get_schema())

@classmethod
def get_default_data(cls):
data = cls().dict()
data.pop("callback_manager", None)
data.pop("class_name", None)
data.pop("include_metadata", None)
data.pop("include_prev_next_rel", None)
return data


class CSVTextSplitter(TextSplitter, PromptlyTransformers):
exclude_columns: Optional[List[str]] = Field(
default=None,
description="Columns to exclude from the text",
)
text_columns: Optional[List[str]] = Field(
default=None,
description="Columns to include in the text",
)
metadata_prefix: Optional[str] = Field(
default="cts_",
description="Prefix for metadata columns",
)

@classmethod
def slug(cls):
@@ -24,27 +58,40 @@ def slug(cls):
def provider_slug(cls):
return "promptly"

@classmethod
def class_name(cls) -> str:
return "CSVTextSplitter"

def _split_text(self, text: str) -> List[str]:
chunks = []
file_handle = StringIO(text)
csv_reader = csv.DictReader(file_handle)
for i, row in enumerate(csv_reader):
content = ""
for column_name, value in row.items():
if self.include_columns and column_name not in self.include_columns:
continue
if self.exclude_columns and column_name in self.exclude_columns:
continue
content += f"{column_name}: {value}\n"
chunks.append(content)
return chunks

def split_text_metadata_aware(self, text: str, metadata_str: str) -> List[str]:
return self._split_text(text)

def split_text(self, text: str) -> List[str]:
return self._split_text(text)
raise NotImplementedError

def split_texts(self, texts: List[str]) -> List[str]:
raise NotImplementedError

def _parse_nodes(self, nodes, show_progress: bool = False, **kwargs):
from llama_index.core.node_parser.node_utils import build_nodes_from_splits
from llama_index.core.utils import get_tqdm_iterable

all_nodes = []
nodes_with_progress = get_tqdm_iterable(nodes, show_progress, "Parsing nodes")
for node in nodes_with_progress:
if hasattr(node, "content"):
asset = get_asset_by_objref(node.content, None, None)
with asset.file.open(mode="r") as f:
csv_reader = csv.DictReader(f)
for row in csv_reader:
content = {}
for column_name, value in row.items():
if self.exclude_columns and column_name in self.exclude_columns:
continue
content[column_name] = value
row_text = json.dumps(content)
if self.text_columns:
if len(self.text_columns) == 1:
row_text = content[self.text_columns[0]]
else:
text_parts = {}
for column_name in self.text_columns:
text_parts[column_name] = content.get(column_name, "")
row_text = json.dumps(text_parts)
all_nodes.extend(build_nodes_from_splits([row_text], node, id_func=self.id_func))
for column_name, value in content.items():
all_nodes[-1].metadata[f"{self.metadata_prefix}{column_name}"] = value

return all_nodes
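
To summarize the per-row behavior of CSVTextSplitter._parse_nodes: each CSV row becomes one node whose text is the full row as JSON, the single configured text column, or a JSON object of the selected text columns, and every kept column is also copied into node metadata under metadata_prefix. A self-contained sketch of that row handling, independent of llama_index; the split_csv_row helper is invented for illustration:

# Standalone sketch of the per-row logic; split_csv_row is not part of the PR.
import csv
import json
from io import StringIO

def split_csv_row(row, exclude_columns=None, text_columns=None, metadata_prefix="cts_"):
    content = {k: v for k, v in row.items() if not (exclude_columns and k in exclude_columns)}
    if text_columns and len(text_columns) == 1:
        row_text = content[text_columns[0]]
    elif text_columns:
        row_text = json.dumps({k: content.get(k, "") for k in text_columns})
    else:
        row_text = json.dumps(content)
    metadata = {f"{metadata_prefix}{k}": v for k, v in content.items()}
    return row_text, metadata

reader = csv.DictReader(StringIO("name,price,sku\nwidget,10,A1\ngadget,25,B2\n"))
for row in reader:
    text, metadata = split_csv_row(row, exclude_columns=["sku"], text_columns=["name"])
    print(text, metadata)  # e.g. widget {'cts_name': 'widget', 'cts_price': '10'}
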
4 changes: 3 additions & 1 deletion llmstack/processors/providers/promptly/datasource_search.py
@@ -18,7 +18,8 @@


class DataSourceSearchInput(ApiProcessorSchema):
query: str
query: Optional[str] = None
filters: Optional[str] = None


class DocumentMetadata(BaseModel):
@@ -104,6 +105,7 @@ def process(self) -> DataSourceSearchOutput:
alpha=hybrid_semantic_search_ratio,
limit=self._config.document_limit,
use_hybrid_search=True,
search_filters=input_data.filters or self._config.search_filters,
)
documents.extend(result)
except BaseException:
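
On the processor side, DataSourceSearchInput now accepts an optional filters string that takes precedence over the processor's configured search_filters and is forwarded to the pipeline search call. A hedged example of an input payload; the field values are illustrative, and for a pandas destination the filter is interpreted as a DataFrame.query expression:

# Illustrative input for the datasource search processor after this change.
input_data = {
    "query": "waterproof hiking boots",        # optional; may be omitted for a pure filter lookup
    "filters": "price < 100 and in_stock == True",
}
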
75 changes: 28 additions & 47 deletions poetry.lock

Some generated files are not rendered by default.

3 changes: 1 addition & 2 deletions pyproject.toml
@@ -40,7 +40,6 @@ django-picklefield = "^3.2"
django-redis = "^5.4.0"
djangorestframework = "^3.15.2"
django-flags = "^5.0.13"
django-jsonform = {version = "^2.17.4"}
django-ratelimit = {version = "^4.1.0"}
croniter = {version ="^2.0.1"}
pykka = "^4.0.2"
@@ -75,7 +74,7 @@ django-rq = "^3.0.0"
distlib = "^0.3.9"

[tool.poetry.group.faiss.dependencies]
faiss-cpu = "^1.8.0"
faiss-cpu = "^1.9.0"

[tool.poetry.group.processors]
