diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py index 044bece..b4abe29 100644 --- a/langchain_postgres/vectorstores.py +++ b/langchain_postgres/vectorstores.py @@ -98,7 +98,7 @@ class DistanceStrategy(str, enum.Enum): ) -def _get_embedding_collection_store(vector_dimension: Optional[int] = None) -> Any: +def _get_embedding_collection_store(vector_dimension: Optional[int] = None, table_schema: Optional[str] = "public") -> Any: global _classes if _classes is not None: return _classes @@ -109,6 +109,7 @@ class CollectionStore(Base): """Collection store.""" __tablename__ = "langchain_pg_collection" + __table_args__ = {"schema":table_schema} uuid = sqlalchemy.Column( UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 @@ -204,7 +205,7 @@ class EmbeddingStore(Base): collection_id = sqlalchemy.Column( UUID(as_uuid=True), sqlalchemy.ForeignKey( - f"{CollectionStore.__tablename__}.uuid", + f"{table_schema}.{CollectionStore.__tablename__}.uuid", ondelete="CASCADE", ), ) @@ -220,7 +221,8 @@ class EmbeddingStore(Base): "cmetadata", postgresql_using="gin", postgresql_ops={"cmetadata": "jsonb_path_ops"}, - ), + ) + , {"schema": table_schema} ) _classes = (EmbeddingStore, CollectionStore) @@ -386,6 +388,7 @@ def __init__( use_jsonb: bool = True, create_extension: bool = True, async_mode: bool = False, + table_schema: str = 'public' ) -> None: """Initialize the PGVector store. For an async version, use `PGVector.acreate()` instead. @@ -422,6 +425,7 @@ def __init__( self.collection_metadata = collection_metadata self._distance_strategy = distance_strategy self.pre_delete_collection = pre_delete_collection + self.table_schema = table_schema self.logger = logger or logging.getLogger(__name__) self.override_relevance_score_fn = relevance_score_fn self._engine: Optional[Engine] = None @@ -469,7 +473,8 @@ def __post_init__( self.create_vector_extension() EmbeddingStore, CollectionStore = _get_embedding_collection_store( - self._embedding_length + self._embedding_length, + table_schema = self.table_schema ) self.CollectionStore = CollectionStore self.EmbeddingStore = EmbeddingStore @@ -485,7 +490,8 @@ async def __apost_init__( self._async_init = True EmbeddingStore, CollectionStore = _get_embedding_collection_store( - self._embedding_length + self._embedding_length, + table_schema = self.table_schema ) self.CollectionStore = CollectionStore self.EmbeddingStore = EmbeddingStore