
feat: Adding multi modal support for PGVectorStore #207

Status: Open. Wants to merge 21 commits into base branch main.
4 changes: 2 additions & 2 deletions .github/workflows/_test.yml
@@ -9,7 +9,7 @@ on:
description: "From which folder this pipeline executes"

env:
-  POETRY_VERSION: "1.7.1"
+  POETRY_VERSION: "2.1.1"

jobs:
build:
@@ -37,7 +37,7 @@ jobs:

- name: Install dependencies
shell: bash
-        run: poetry install
+        run: poetry install --extras gcs

- name: Run core tests
shell: bash
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
@@ -18,7 +18,7 @@ concurrency:
cancel-in-progress: true

env:
-  POETRY_VERSION: "1.7.1"
+  POETRY_VERSION: "2.1.1"
WORKDIR: "."

jobs:
@@ -89,7 +89,7 @@ jobs:
shell: bash
run: |
echo "Running tests, installing dependencies with poetry..."
-          poetry install --with test,lint,typing,docs
+          poetry install --with test,lint,typing,docs --extras gcs
- name: Run tests
run: make test
env:
117 changes: 117 additions & 0 deletions langchain_postgres/v2/async_vectorstore.py
@@ -1,12 +1,15 @@
# TODO: Remove below import when minimum supported Python version is 3.10
from __future__ import annotations

import base64
import copy
import json
import uuid
from typing import Any, Callable, Iterable, Optional, Sequence
from urllib.parse import urlparse

import numpy as np
import requests
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore, utils
@@ -365,6 +368,98 @@ async def aadd_documents(
        ids = await self.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs)
        return ids

    def _encode_image(self, uri: str) -> str:
        """Get a base64 string from an image URI."""
        if uri.startswith("gs://"):
            from google.cloud import storage  # type: ignore

            path_without_prefix = uri[len("gs://") :]
            parts = path_without_prefix.split("/", 1)
            bucket_name = parts[0]
            object_name = ""  # Default for bucket root if no object specified
            if len(parts) == 2:
                object_name = parts[1]
            storage_client = storage.Client()
Collaborator: We may want to wrap this in a try/except block to provide a clearer error. Or do you think the error is clear enough if they are not running in a Google Cloud environment or have not set up credentials? (A sketch of such a wrapper appears after _encode_image below.)

Collaborator: The other langchain packages don't run integration tests for third-party providers. We could mock this test or just test this functionality downstream in our package. (A sketch of such a mock appears after this file's diff.)

Collaborator (author): Currently this is the error:

    google.auth.exceptions.DefaultCredentialsError: Your default credentials were not found. To set up Application Default Credentials, see https://cloud.google.com/docs/authentication/external/set-up-adc for more information.

I think this is pretty descriptive; let me know what you think.

The options for the tests are:

  1. Make the images publicly accessible on a GCP project (which claims we would not need credentials to fetch them).
  2. Store the images directly and skip testing the GCS pathway.
  3. Not test add_images at all.

What do you suggest?

Collaborator: If the GCS storage call can be easily mocked, let's go ahead and do that. If it can't, let's keep the test but skip it.

Collaborator (author): Currently the mock solutions need more debugging, so I've removed the GCS URI from the tests; the other images, created locally, are still covered. I will recreate the GCS path testing in our libraries.

            bucket = storage_client.bucket(bucket_name)
            blob = bucket.blob(object_name)
            return base64.b64encode(blob.download_as_bytes()).decode("utf-8")

        parsed_uri = urlparse(uri)
        if parsed_uri.scheme in ["http", "https"]:
            response = requests.get(uri, stream=True)
Collaborator: This is an SSRF attack.

            response.raise_for_status()
            return base64.b64encode(response.content).decode("utf-8")

        with open(uri, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")
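
A minimal sketch of the try/except wrapping discussed in the thread above, assuming google-cloud-storage is installed and that google.auth.exceptions.DefaultCredentialsError is the error to surface more clearly; the helper name is illustrative and this is not part of the PR as submitted:

# Hedged sketch only: one way to wrap the GCS download with a clearer error,
# per the review suggestion. The helper name is illustrative, not part of this PR.
import base64


def _download_gcs_object_as_base64(bucket_name: str, object_name: str) -> str:
    """Download a GCS object and return it base64-encoded, with a friendlier auth error."""
    from google.auth.exceptions import DefaultCredentialsError
    from google.cloud import storage  # type: ignore

    try:
        client = storage.Client()
        blob = client.bucket(bucket_name).blob(object_name)
        return base64.b64encode(blob.download_as_bytes()).decode("utf-8")
    except DefaultCredentialsError as e:
        # Re-raise with context so users outside a Google Cloud environment know
        # they need Application Default Credentials for gs:// URIs.
        raise ValueError(
            "Could not authenticate to Google Cloud Storage for a gs:// URI. "
            "Set up Application Default Credentials "
            "(https://cloud.google.com/docs/authentication/external/set-up-adc) "
            f"or use a local path or http(s) URL instead. Original error: {e}"
        ) from e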

    async def aadd_images(
        self,
        uris: list[str],
Collaborator: Accepting URIs without safeguards is an SSRF attack.

Collaborator (author): My understanding is that SSRF attacks are generally dealt with at the application layer. Is that correct, or is it more of a framework responsibility? (One possible safeguard is sketched after this method below.)

        metadatas: Optional[list[dict]] = None,
        ids: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> list[str]:
        """Embed images and add to the table.

        Args:
            uris (list[str]): List of image URIs (local paths, gs:// objects, or
                http/https URLs) to add to the table.
            metadatas (Optional[list[dict]]): List of metadatas to add to table records.
            ids (Optional[list[str]]): List of IDs to add to table records.

        Returns:
            List of record IDs added.
        """
        encoded_images = []
        if metadatas is None:
            metadatas = [{"image_uri": uri} for uri in uris]

        for uri in uris:
            encoded_image = self._encode_image(uri)
            encoded_images.append(encoded_image)

        embeddings = self._images_embedding_helper(uris)
        ids = await self.aadd_embeddings(
            encoded_images, embeddings, metadatas=metadatas, ids=ids, **kwargs
        )
        return ids
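
Regarding the SSRF concern raised above, one possible safeguard is to validate remote URIs before fetching them: allow only http/https and refuse hostnames that resolve to private, loopback, or link-local addresses. The sketch below is illustrative only; the helper name and policy are assumptions rather than part of the PR, and whether this belongs in the library or the application layer is exactly the open question from the thread.

# Hedged sketch only: an illustrative pre-fetch check for remote image URIs.
# The helper name and the policy are assumptions, not part of this PR.
import ipaddress
import socket
from urllib.parse import urlparse


def _validate_remote_image_uri(uri: str) -> None:
    """Reject URIs that are not http/https or that resolve to internal addresses."""
    parsed = urlparse(uri)
    if parsed.scheme not in ("http", "https"):
        raise ValueError(f"Unsupported scheme for remote images: {parsed.scheme!r}")
    if not parsed.hostname:
        raise ValueError(f"Remote image URI has no hostname: {uri!r}")
    # Resolve the hostname and refuse loopback, private, and link-local targets.
    for info in socket.getaddrinfo(parsed.hostname, parsed.port or 443):
        address = ipaddress.ip_address(info[4][0])
        if address.is_private or address.is_loopback or address.is_link_local:
            raise ValueError(f"Refusing to fetch image from internal address: {uri!r}")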

    def _images_embedding_helper(self, image_uris: list[str]) -> list[list[float]]:
        # Check whether the embedding service supports the `embed_images()` or
        # `embed_image()` API before falling back to an error.
        if hasattr(self.embedding_service, "embed_images"):
            try:
                embeddings = self.embedding_service.embed_images(image_uris)
            except Exception as e:
                raise Exception(
                    f"Make sure your selected embedding model supports a list of image URIs as input. {str(e)}"
                )
        elif hasattr(self.embedding_service, "embed_image"):
            try:
                embeddings = self.embedding_service.embed_image(image_uris)
            except Exception as e:
                raise Exception(
                    f"Make sure your selected embedding model supports a list of image URIs as input. {str(e)}"
                )
        else:
            raise ValueError(
                "Please use an embedding model that supports image embedding."
            )
        return embeddings
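
`_images_embedding_helper` duck-types the embedding service: it prefers a batch `embed_images(uris)` call and falls back to `embed_image(uris)`. A minimal stand-in that would satisfy the check is sketched below; the class and its fixed-size vectors are purely illustrative and not any real provider's API.

# Hedged sketch only: a stand-in embedding service that satisfies the duck-typed
# check in _images_embedding_helper. Purely illustrative, not a real provider.
from langchain_core.embeddings import Embeddings


class FakeImageEmbeddings(Embeddings):
    """Returns fixed-size dummy vectors for both text and image inputs."""

    def __init__(self, size: int = 8) -> None:
        self.size = size

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        return [[1.0] * self.size for _ in texts]

    def embed_query(self, text: str) -> list[float]:
        return [1.0] * self.size

    def embed_images(self, uris: list[str]) -> list[list[float]]:
        # Found first by hasattr(), so the helper uses the batch API.
        return [[1.0] * self.size for _ in uris]

A service exposing only `embed_image(uris)` would hit the elif branch instead; anything else raises the ValueError above.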

    async def asimilarity_search_image(
        self,
        image_uri: str,
        k: Optional[int] = None,
        filter: Optional[dict] = None,
        **kwargs: Any,
    ) -> list[Document]:
        """Return docs selected by similarity search on an image query."""
        embedding = self._images_embedding_helper([image_uri])[0]

        return await self.asimilarity_search_by_vector(
            embedding=embedding, k=k, filter=filter, **kwargs
        )

    async def adelete(
        self,
        ids: Optional[list] = None,
@@ -1268,3 +1363,25 @@ def max_marginal_relevance_search_with_score_by_vector(
        raise NotImplementedError(
            "Sync methods are not implemented for AsyncPGVectorStore. Use PGVectorStore interface instead."
        )

    def add_images(
        self,
        uris: list[str],
        metadatas: Optional[list[dict]] = None,
        ids: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> list[str]:
        raise NotImplementedError(
            "Sync methods are not implemented for AsyncPGVectorStore. Use PGVectorStore interface instead."
        )

    def similarity_search_image(
        self,
        image_uri: str,
        k: Optional[int] = None,
        filter: Optional[dict] = None,
        **kwargs: Any,
    ) -> list[Document]:
        raise NotImplementedError(
            "Sync methods are not implemented for AsyncPGVectorStore. Use PGVectorStore interface instead."
        )
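
On the testing question discussed earlier in this file, mocking the GCS client is one of the options the reviewers weighed. A sketch of what such a unit test might look like with unittest.mock is below; the patch target, the `store` fixture, and the assertions are assumptions about one possible test, not the test the PR ships.

# Hedged sketch only: one way the gs:// branch of _encode_image could be unit
# tested without credentials, per the review discussion. The `store` fixture
# and patch target are assumptions, not part of this PR.
import base64
from unittest import mock


def test_encode_image_gcs_with_mocked_client(store) -> None:
    fake_blob = mock.Mock()
    fake_blob.download_as_bytes.return_value = b"fake-image-bytes"
    fake_bucket = mock.Mock()
    fake_bucket.blob.return_value = fake_blob
    fake_client = mock.Mock()
    fake_client.bucket.return_value = fake_bucket

    # Patch the client constructor so no real GCP credentials are required.
    with mock.patch("google.cloud.storage.Client", return_value=fake_client):
        encoded = store._encode_image("gs://my-bucket/path/to/image.png")

    assert encoded == base64.b64encode(b"fake-image-bytes").decode("utf-8")
    fake_client.bucket.assert_called_once_with("my-bucket")
    fake_bucket.blob.assert_called_once_with("path/to/image.png")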
48 changes: 48 additions & 0 deletions langchain_postgres/v2/vectorstores.py
@@ -840,3 +840,51 @@ def get_by_ids(self, ids: Sequence[str]) -> list[Document]:

    def get_table_name(self) -> str:
        return self.__vs.table_name

    async def aadd_images(
        self,
        uris: list[str],
        metadatas: Optional[list[dict]] = None,
        ids: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> list[str]:
        """Embed images and add to the table."""
        return await self._engine._run_as_async(
            self.__vs.aadd_images(uris, metadatas, ids, **kwargs)  # type: ignore
        )

    def add_images(
        self,
        uris: list[str],
        metadatas: Optional[list[dict]] = None,
        ids: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> list[str]:
        """Embed images and add to the table."""
        return self._engine._run_as_sync(
            self.__vs.aadd_images(uris, metadatas, ids, **kwargs)  # type: ignore
        )

    def similarity_search_image(
        self,
        image_uri: str,
        k: Optional[int] = None,
        filter: Optional[dict] = None,
        **kwargs: Any,
    ) -> list[Document]:
        """Return docs selected by similarity search on an image."""
        return self._engine._run_as_sync(
            self.__vs.asimilarity_search_image(image_uri, k, filter, **kwargs)  # type: ignore
        )

    async def asimilarity_search_image(
        self,
        image_uri: str,
        k: Optional[int] = None,
        filter: Optional[dict] = None,
        **kwargs: Any,
    ) -> list[Document]:
        """Return docs selected by similarity search on an image."""
        return await self._engine._run_as_async(
            self.__vs.asimilarity_search_image(image_uri, k, filter, **kwargs)  # type: ignore
        )
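
Taken together, the new surface on PGVectorStore would be used roughly as follows. This is a sketch under assumptions: it presumes the v2 setup pattern (PGEngine.from_connection_string, ainit_vectorstore_table, PGVectorStore.create) with PGEngine and PGVectorStore importable from the package top level, an embedding service that exposes embed_images/embed_image (such as the illustrative FakeImageEmbeddings above), and placeholder connection details and file paths.

# Hedged sketch only: illustrative end-to-end usage of the new image APIs.
# Connection string, table name, vector size, and file paths are placeholders,
# and the setup calls assume the v2 engine pattern used elsewhere in this package.
import asyncio

from langchain_postgres import PGEngine, PGVectorStore


async def main() -> None:
    engine = PGEngine.from_connection_string(
        url="postgresql+asyncpg://user:password@localhost:5432/vectordb"
    )
    await engine.ainit_vectorstore_table(table_name="image_vectors", vector_size=8)
    store = await PGVectorStore.create(
        engine=engine,
        table_name="image_vectors",
        embedding_service=FakeImageEmbeddings(size=8),  # any model with embed_image(s)
    )

    # Add images by URI (local path, gs:// object, or http/https URL) ...
    await store.aadd_images(["./photos/cat.png", "./photos/dog.png"])
    # ... then search with an image query.
    for doc in await store.asimilarity_search_image("./photos/query.png", k=2):
        print(doc.metadata.get("image_uri"))


asyncio.run(main())

The synchronous add_images and similarity_search_image wrappers follow the same pattern through _run_as_sync.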