diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py index a3743e4..a9873b3 100644 --- a/langchain_postgres/vectorstores.py +++ b/langchain_postgres/vectorstores.py @@ -6,6 +6,7 @@ import logging import uuid import warnings +import threading from typing import ( Any, AsyncGenerator, @@ -100,11 +101,14 @@ class DistanceStrategy(str, enum.Enum): .union(SPECIAL_CASED_OPERATORS) ) +_embedding_collection_store_lock = threading.Lock() -def _get_embedding_collection_store(vector_dimension: Optional[int] = None) -> Any: +def _get_embedding_collection_store(vector_dimension: Optional[int] = None, extend_existing: bool = False) -> Any: global _classes - if _classes is not None: - return _classes + + with _embedding_collection_store_lock: + if _classes is not None: + return _classes from pgvector.sqlalchemy import Vector # type: ignore @@ -113,6 +117,9 @@ class CollectionStore(Base): __tablename__ = "langchain_pg_collection" + if extend_existing: + __table_args__ = {'extend_existing': True} + uuid = sqlalchemy.Column( UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 ) @@ -216,13 +223,15 @@ class EmbeddingStore(Base): cmetadata = sqlalchemy.Column(JSONB, nullable=True) __table_args__ = ( - sqlalchemy.Index( - "ix_cmetadata_gin", - "cmetadata", - postgresql_using="gin", - postgresql_ops={"cmetadata": "jsonb_path_ops"}, - ), - ) + sqlalchemy.Index( + "ix_cmetadata_gin", + "cmetadata", + postgresql_using="gin", + postgresql_ops={"cmetadata": "jsonb_path_ops"}, + ), + ) + if extend_existing: + __table_args__ = __table_args__ + ({'extend_existing': True}, ) _classes = (EmbeddingStore, CollectionStore) @@ -387,6 +396,7 @@ def __init__( use_jsonb: bool = True, create_extension: bool = True, async_mode: bool = False, + extend_existing: bool = False, ) -> None: """Initialize the PGVector store. For an async version, use `PGVector.acreate()` instead. @@ -415,6 +425,9 @@ def __init__( create_extension: If True, will create the vector extension if it doesn't exist. disabling creation is useful when using ReadOnly Databases. + extend_existing: If True, will set extend_existing=True in table_args for + SQLAlchemy models. This helps prevent race conditions when multiple + threads try to create tables simultaneously. (default: False) """ self.async_mode = async_mode self.embedding_function = embeddings @@ -428,6 +441,7 @@ def __init__( self._engine: Optional[Engine] = None self._async_engine: Optional[AsyncEngine] = None self._async_init = False + self._extend_existing = extend_existing if isinstance(connection, str): if async_mode: @@ -470,7 +484,8 @@ def __post_init__( self.create_vector_extension() EmbeddingStore, CollectionStore = _get_embedding_collection_store( - self._embedding_length + self._embedding_length, + extend_existing=self._extend_existing, ) self.CollectionStore = CollectionStore self.EmbeddingStore = EmbeddingStore @@ -486,7 +501,8 @@ async def __apost_init__( self._async_init = True EmbeddingStore, CollectionStore = _get_embedding_collection_store( - self._embedding_length + self._embedding_length, + extend_existing=self._extend_existing, ) self.CollectionStore = CollectionStore self.EmbeddingStore = EmbeddingStore @@ -514,10 +530,13 @@ async def acreate_vector_extension(self) -> None: async with self._async_engine.begin() as conn: await conn.run_sync(_create_vector_extension) + _create_tables_lock = threading.Lock() def create_tables_if_not_exists(self) -> None: - with self._make_sync_session() as session: - Base.metadata.create_all(session.get_bind()) - session.commit() + """Create tables if they don't exist in a thread-safe manner.""" + with _create_tables_lock: + with self._make_sync_session() as session: + Base.metadata.create_all(session.get_bind()) + session.commit() async def acreate_tables_if_not_exists(self) -> None: assert self._async_engine, "This method must be called with async_mode"