diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py index a05fe99..9ae44d4 100644 --- a/langchain_postgres/vectorstores.py +++ b/langchain_postgres/vectorstores.py @@ -3,6 +3,7 @@ import contextlib import enum +from functools import lru_cache import logging import uuid from typing import ( @@ -17,6 +18,7 @@ Tuple, Type, Union, + Literal, ) from typing import ( cast as typing_cast, @@ -62,11 +64,11 @@ class DistanceStrategy(str, enum.Enum): Base = declarative_base() # type: Any +_LANGCHAIN_DEFAULT_COLLECTION_TABLE_NAME = "langchain_pg_collection" +_LANGCHAIN_DEFAULT_EMBEDDING_TABLE_NAME = "langchain_pg_embedding" _LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain" -_classes: Any = None - COMPARISONS_TO_NATIVE = { "$eq": "==", "$ne": "!=", @@ -98,17 +100,21 @@ class DistanceStrategy(str, enum.Enum): ) -def _get_embedding_collection_store(vector_dimension: Optional[int] = None) -> Any: - global _classes - if _classes is not None: - return _classes - +@lru_cache(None) +def _get_embedding_collection_store( + vector_dimension: Optional[int] = None, + collection_table_name: str = _LANGCHAIN_DEFAULT_COLLECTION_TABLE_NAME, + embedding_table_name: str = _LANGCHAIN_DEFAULT_EMBEDDING_TABLE_NAME, +) -> Any: from pgvector.sqlalchemy import Vector # type: ignore + collection_class_name = f"CollectionStore_{collection_table_name}_{vector_dimension}" + embedding_class_name = f"EmbeddingStore_{embedding_table_name}_{vector_dimension}" + class CollectionStore(Base): """Collection store.""" - __tablename__ = "langchain_pg_collection" + __abstract__ = True uuid = sqlalchemy.Column( UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 @@ -116,12 +122,6 @@ class CollectionStore(Base): name = sqlalchemy.Column(sqlalchemy.String, nullable=False, unique=True) cmetadata = sqlalchemy.Column(JSON) - embeddings = relationship( - "EmbeddingStore", - back_populates="collection", - passive_deletes=True, - ) - @classmethod def get_by_name( cls, session: Session, 
name: str @@ -195,22 +195,19 @@ async def aget_or_create( class EmbeddingStore(Base): """Embedding store.""" - __tablename__ = "langchain_pg_embedding" + __abstract__ = True id = sqlalchemy.Column( sqlalchemy.String, nullable=True, primary_key=True, index=True, unique=True ) - collection_id = sqlalchemy.Column( UUID(as_uuid=True), sqlalchemy.ForeignKey( - f"{CollectionStore.__tablename__}.uuid", + f"{collection_table_name}.uuid", ondelete="CASCADE", ), ) - collection = relationship(CollectionStore, back_populates="embeddings") - - embedding: Vector = sqlalchemy.Column(Vector(vector_dimension)) + embedding = sqlalchemy.Column(Vector(vector_dimension)) document = sqlalchemy.Column(sqlalchemy.String, nullable=True) cmetadata = sqlalchemy.Column(JSONB, nullable=True) @@ -223,9 +220,33 @@ class EmbeddingStore(Base): ), ) - _classes = (EmbeddingStore, CollectionStore) + # Model classes with the same name cannot be shared between different tables in SQLAlchemy, + # so a new class with a unique name is created for each table. 
+ c = type( + collection_class_name, + (CollectionStore,), + { + "__tablename__" : collection_table_name, + "embeddings" : relationship( + embedding_class_name, + back_populates="collection", + passive_deletes=True, + ), + } + ) + e = type( + embedding_class_name, + (EmbeddingStore,), + { + "__tablename__" : embedding_table_name, + "collection" : relationship( + collection_class_name, + back_populates="embeddings", + ), + } + ) - return _classes + return (e, c) def _results_to_docs(docs_and_scores: Any) -> List[Document]: @@ -377,6 +398,8 @@ def __init__( *, connection: Union[None, DBConnection, Engine, AsyncEngine, str] = None, embedding_length: Optional[int] = None, + collection_table_name: str = _LANGCHAIN_DEFAULT_COLLECTION_TABLE_NAME, + embedding_table_name: str = _LANGCHAIN_DEFAULT_EMBEDDING_TABLE_NAME, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, collection_metadata: Optional[dict] = None, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, @@ -399,6 +422,12 @@ def __init__( NOTE: This is not mandatory. Defining it will prevent vectors of any other size to be added to the embeddings table but, without it, the embeddings can't be indexed. + collection_table_name: Name of the table that stores collection information. + (default: langchain_pg_collection) + embedding_table_name: Name of the table that stores documents and their embeddings. + (default: langchain_pg_embedding) + NOTE: collection_table_name and embedding_table_name form a pair; + do not mix table names from different pairs. collection_name: The name of the collection to use. (default: langchain) NOTE: This is not the name of the table, but the name of the collection. 
The tables will be created when initializing the store (if not exists) @@ -419,6 +448,8 @@ def __init__( self.async_mode = async_mode self.embedding_function = embeddings self._embedding_length = embedding_length + self._collection_table_name = collection_table_name + self._embedding_table_name = embedding_table_name self.collection_name = collection_name self.collection_metadata = collection_metadata self._distance_strategy = distance_strategy @@ -470,7 +501,9 @@ def __post_init__( self.create_vector_extension() EmbeddingStore, CollectionStore = _get_embedding_collection_store( - self._embedding_length + vector_dimension=self._embedding_length, + collection_table_name=self._collection_table_name, + embedding_table_name=self._embedding_table_name, ) self.CollectionStore = CollectionStore self.EmbeddingStore = EmbeddingStore @@ -486,7 +519,9 @@ async def __apost_init__( self._async_init = True EmbeddingStore, CollectionStore = _get_embedding_collection_store( - self._embedding_length + vector_dimension=self._embedding_length, + collection_table_name=self._collection_table_name, + embedding_table_name=self._embedding_table_name, ) self.CollectionStore = CollectionStore self.EmbeddingStore = EmbeddingStore @@ -675,6 +710,8 @@ def __from( embedding: Embeddings, metadatas: Optional[List[dict]] = None, ids: Optional[List[str]] = None, + collection_table_name: str = _LANGCHAIN_DEFAULT_COLLECTION_TABLE_NAME, + embedding_table_name: str = _LANGCHAIN_DEFAULT_EMBEDDING_TABLE_NAME, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, connection: Optional[str] = None, @@ -691,6 +728,8 @@ def __from( store = cls( connection=connection, + collection_table_name=collection_table_name, + embedding_table_name=embedding_table_name, collection_name=collection_name, embeddings=embedding, distance_strategy=distance_strategy, @@ -713,6 +752,8 @@ async def __afrom( embedding: Embeddings, metadatas: Optional[List[dict]] 
= None, ids: Optional[List[str]] = None, + collection_table_name: str = _LANGCHAIN_DEFAULT_COLLECTION_TABLE_NAME, + embedding_table_name: str = _LANGCHAIN_DEFAULT_EMBEDDING_TABLE_NAME, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, connection: Optional[str] = None, @@ -729,6 +770,8 @@ async def __afrom( store = cls( connection=connection, + collection_table_name=collection_table_name, + embedding_table_name=embedding_table_name, collection_name=collection_name, embeddings=embedding, distance_strategy=distance_strategy, @@ -996,13 +1039,13 @@ def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, floa docs = [ ( Document( - id=str(result.EmbeddingStore.id), - page_content=result.EmbeddingStore.document, - metadata=result.EmbeddingStore.cmetadata, + id=result.id, + page_content=result.document, + metadata=result.cmetadata, ), - result.distance if self.embedding_function is not None else None, + score, ) - for result in results + for result, score in results ] return docs @@ -1331,7 +1374,7 @@ def __query_collection( embedding: List[float], k: int = 4, filter: Optional[Dict[str, str]] = None, - ) -> Sequence[Any]: + ) -> Sequence[Tuple[Any, Union[float, int]]]: """Query the collection.""" with self._make_sync_session() as session: # type: ignore[arg-type] collection = self.get_collection(session) @@ -1466,6 +1509,8 @@ def from_texts( embedding: Embeddings, metadatas: Optional[List[dict]] = None, *, + collection_table_name: str = _LANGCHAIN_DEFAULT_COLLECTION_TABLE_NAME, + embedding_table_name: str = _LANGCHAIN_DEFAULT_EMBEDDING_TABLE_NAME, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, ids: Optional[List[str]] = None, @@ -1482,6 +1527,8 @@ def from_texts( embedding, metadatas=metadatas, ids=ids, + collection_table_name=collection_table_name, + embedding_table_name=embedding_table_name, 
collection_name=collection_name, distance_strategy=distance_strategy, pre_delete_collection=pre_delete_collection, @@ -1494,12 +1541,14 @@ async def afrom_texts( cls: Type[PGVector], texts: List[str], embedding: Embeddings, + *, metadatas: Optional[List[dict]] = None, + collection_table_name: str = _LANGCHAIN_DEFAULT_COLLECTION_TABLE_NAME, + embedding_table_name: str = _LANGCHAIN_DEFAULT_EMBEDDING_TABLE_NAME, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, ids: Optional[List[str]] = None, pre_delete_collection: bool = False, - *, use_jsonb: bool = True, **kwargs: Any, ) -> PGVector: @@ -1511,6 +1560,8 @@ async def afrom_texts( embedding, metadatas=metadatas, ids=ids, + collection_table_name=collection_table_name, + embedding_table_name=embedding_table_name, collection_name=collection_name, distance_strategy=distance_strategy, pre_delete_collection=pre_delete_collection, @@ -1525,6 +1576,8 @@ def from_embeddings( embedding: Embeddings, *, metadatas: Optional[List[dict]] = None, + collection_table_name: str = _LANGCHAIN_DEFAULT_COLLECTION_TABLE_NAME, + embedding_table_name: str = _LANGCHAIN_DEFAULT_EMBEDDING_TABLE_NAME, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, ids: Optional[List[str]] = None, @@ -1569,6 +1622,8 @@ def from_embeddings( embedding, metadatas=metadatas, ids=ids, + collection_table_name=collection_table_name, + embedding_table_name=embedding_table_name, collection_name=collection_name, distance_strategy=distance_strategy, pre_delete_collection=pre_delete_collection, @@ -1580,7 +1635,10 @@ async def afrom_embeddings( cls, text_embeddings: List[Tuple[str, List[float]]], embedding: Embeddings, + *, metadatas: Optional[List[dict]] = None, + collection_table_name: str = _LANGCHAIN_DEFAULT_COLLECTION_TABLE_NAME, + embedding_table_name: str = _LANGCHAIN_DEFAULT_EMBEDDING_TABLE_NAME, collection_name: str = 
_LANGCHAIN_DEFAULT_COLLECTION_NAME, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, ids: Optional[List[str]] = None, @@ -1614,6 +1672,8 @@ async def afrom_embeddings( embedding, metadatas=metadatas, ids=ids, + collection_table_name=collection_table_name, + embedding_table_name=embedding_table_name, collection_name=collection_name, distance_strategy=distance_strategy, pre_delete_collection=pre_delete_collection, @@ -1625,6 +1685,8 @@ def from_existing_index( cls: Type[PGVector], embedding: Embeddings, *, + collection_table_name: str = _LANGCHAIN_DEFAULT_COLLECTION_TABLE_NAME, + embedding_table_name: str = _LANGCHAIN_DEFAULT_EMBEDDING_TABLE_NAME, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, pre_delete_collection: bool = False, @@ -1638,6 +1700,8 @@ def from_existing_index( """ store = cls( connection=connection, + collection_table_name=collection_table_name, + embedding_table_name=embedding_table_name, collection_name=collection_name, embeddings=embedding, distance_strategy=distance_strategy, @@ -1652,6 +1716,8 @@ async def afrom_existing_index( cls: Type[PGVector], embedding: Embeddings, *, + collection_table_name: str = _LANGCHAIN_DEFAULT_COLLECTION_TABLE_NAME, + embedding_table_name: str = _LANGCHAIN_DEFAULT_EMBEDDING_TABLE_NAME, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, pre_delete_collection: bool = False, @@ -1665,6 +1731,8 @@ async def afrom_existing_index( """ store = PGVector( connection=connection, + collection_table_name=collection_table_name, + embedding_table_name=embedding_table_name, collection_name=collection_name, embeddings=embedding, distance_strategy=distance_strategy, @@ -1699,6 +1767,8 @@ def from_documents( embedding: Embeddings, *, connection: Optional[DBConnection] = None, + collection_table_name: str = _LANGCHAIN_DEFAULT_COLLECTION_TABLE_NAME, + embedding_table_name: 
str = _LANGCHAIN_DEFAULT_EMBEDDING_TABLE_NAME, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, ids: Optional[List[str]] = None, @@ -1719,6 +1789,8 @@ def from_documents( metadatas=metadatas, connection=connection, ids=ids, + collection_table_name=collection_table_name, + embedding_table_name=embedding_table_name, collection_name=collection_name, use_jsonb=use_jsonb, **kwargs, @@ -1729,11 +1801,13 @@ async def afrom_documents( cls: Type[PGVector], documents: List[Document], embedding: Embeddings, + *, + collection_table_name: str = _LANGCHAIN_DEFAULT_COLLECTION_TABLE_NAME, + embedding_table_name: str = _LANGCHAIN_DEFAULT_EMBEDDING_TABLE_NAME, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, ids: Optional[List[str]] = None, pre_delete_collection: bool = False, - *, use_jsonb: bool = True, **kwargs: Any, ) -> PGVector: @@ -1757,6 +1831,8 @@ async def afrom_documents( distance_strategy=distance_strategy, metadatas=metadatas, ids=ids, + collection_table_name=collection_table_name, + embedding_table_name=embedding_table_name, collection_name=collection_name, use_jsonb=use_jsonb, **kwargs, @@ -1765,12 +1841,12 @@ async def afrom_documents( @classmethod def connection_string_from_db_params( cls, - driver: str, - host: str, - port: int, database: str, - user: str, - password: str, + host: str = "127.0.0.1", + port: int = 5432, + user: str = "postgres", + password: str = "postgres", + driver: Literal["psycopg"] = "psycopg", ) -> str: """Return connection string from database parameters.""" if driver != "psycopg": @@ -1837,7 +1913,7 @@ def max_marginal_relevance_search_with_score_by_vector( assert not self._async_engine, "This method must be called without async_mode" results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter) - embedding_list = [result.EmbeddingStore.embedding for result in results] + 
embedding_list = [result.embedding for result, _ in results] mmr_selected = maximal_marginal_relevance( np.array(embedding, dtype=np.float32), @@ -1886,7 +1962,7 @@ async def amax_marginal_relevance_search_with_score_by_vector( session=session, embedding=embedding, k=fetch_k, filter=filter ) - embedding_list = [result.EmbeddingStore.embedding for result in results] + embedding_list = [result.embedding for result, _ in results] mmr_selected = maximal_marginal_relevance( np.array(embedding, dtype=np.float32),