-
Notifications
You must be signed in to change notification settings - Fork 81
feat: Adding multi modal support for PGVectorStore #207
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
43f89a5
97b387c
3dc9cc5
b5bb4ff
f9f5337
ffe8c7a
f45e4de
273a57b
3dfbad6
aaa1514
9b41ade
cc26044
d68c75c
9efdac8
92663b9
5c35c6f
a5399a4
477a038
39dc8f1
7532fdf
0287543
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,15 @@ | ||
# TODO: Remove below import when minimum supported Python version is 3.10 | ||
from __future__ import annotations | ||
|
||
import base64 | ||
import copy | ||
import json | ||
import uuid | ||
from typing import Any, Callable, Iterable, Optional, Sequence | ||
from urllib.parse import urlparse | ||
|
||
import numpy as np | ||
import requests | ||
from langchain_core.documents import Document | ||
from langchain_core.embeddings import Embeddings | ||
from langchain_core.vectorstores import VectorStore, utils | ||
|
@@ -365,6 +368,98 @@ async def aadd_documents( | |
ids = await self.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs) | ||
return ids | ||
|
||
def _encode_image(self, uri: str) -> str: | ||
"""Get base64 string from a image URI.""" | ||
if uri.startswith("gs://"): | ||
from google.cloud import storage # type: ignore | ||
|
||
path_without_prefix = uri[len("gs://") :] | ||
parts = path_without_prefix.split("/", 1) | ||
bucket_name = parts[0] | ||
object_name = "" # Default for bucket root if no object specified | ||
if len(parts) == 2: | ||
object_name = parts[1] | ||
storage_client = storage.Client() | ||
bucket = storage_client.bucket(bucket_name) | ||
blob = bucket.blob(object_name) | ||
return base64.b64encode(blob.download_as_bytes()).decode("utf-8") | ||
|
||
parsed_uri = urlparse(uri) | ||
if parsed_uri.scheme in ["http", "https"]: | ||
response = requests.get(uri, stream=True) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is an SSRF attack |
||
response.raise_for_status() | ||
return base64.b64encode(response.content).decode("utf-8") | ||
|
||
with open(uri, "rb") as image_file: | ||
return base64.b64encode(image_file.read()).decode("utf-8") | ||
|
||
async def aadd_images(
    self,
    uris: list[str],
    metadatas: Optional[list[dict]] = None,
    ids: Optional[list[str]] = None,
    **kwargs: Any,
) -> list[str]:
    """Embed a batch of images and insert them into the table.

    Args:
        uris (list[str]): Image URIs to embed and store.
        metadatas (Optional[list[dict]]): Per-record metadata; defaults to
            ``{"image_uri": uri}`` for each image when omitted.
        ids (Optional[list[str]]): Explicit record IDs, if any.

    Returns:
        list[str]: The IDs of the records that were added.
    """
    # Fall back to recording the source URI as the only metadata.
    if metadatas is None:
        metadatas = [{"image_uri": uri} for uri in uris]

    # Store the base64 content as the document text; embed from the URIs.
    encoded_images = [self._encode_image(uri) for uri in uris]
    embeddings = self._images_embedding_helper(uris)

    return await self.aadd_embeddings(
        encoded_images, embeddings, metadatas=metadatas, ids=ids, **kwargs
    )
|
||
def _images_embedding_helper(self, image_uris: list[str]) -> list[list[float]]: | ||
# check if either `embed_images()` or `embed_image()` API is supported by the embedding service used | ||
if hasattr(self.embedding_service, "embed_images"): | ||
try: | ||
embeddings = self.embedding_service.embed_images(image_uris) | ||
except Exception as e: | ||
raise Exception( | ||
f"Make sure your selected embedding model supports list of image URIs as input. {str(e)}" | ||
) | ||
elif hasattr(self.embedding_service, "embed_image"): | ||
try: | ||
embeddings = self.embedding_service.embed_image(image_uris) | ||
except Exception as e: | ||
raise Exception( | ||
f"Make sure your selected embedding model supports list of image URIs as input. {str(e)}" | ||
) | ||
else: | ||
raise ValueError( | ||
"Please use an embedding model that supports image embedding." | ||
) | ||
return embeddings | ||
|
||
async def asimilarity_search_image(
    self,
    image_uri: str,
    k: Optional[int] = None,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> list[Document]:
    """Return the documents most similar to the given image."""
    # Embed the single query image, then defer to the vector search path.
    query_embedding = self._images_embedding_helper([image_uri])[0]
    return await self.asimilarity_search_by_vector(
        embedding=query_embedding, k=k, filter=filter, **kwargs
    )
|
||
async def adelete( | ||
self, | ||
ids: Optional[list] = None, | ||
|
@@ -1268,3 +1363,25 @@ def max_marginal_relevance_search_with_score_by_vector( | |
raise NotImplementedError( | ||
"Sync methods are not implemented for AsyncPGVectorStore. Use PGVectorStore interface instead." | ||
) | ||
|
||
def add_images(
    self,
    uris: list[str],
    metadatas: Optional[list[dict]] = None,
    ids: Optional[list[str]] = None,
    **kwargs: Any,
) -> list[str]:
    """Sync stub for the async-only store.

    Raises:
        NotImplementedError: Always; use the sync wrapper class instead.
    """
    # Message fixed to reference PGVectorStore (was a copy-paste leftover
    # from the AlloyDB library), matching the other sync stubs in this class.
    raise NotImplementedError(
        "Sync methods are not implemented for AsyncPGVectorStore. Use PGVectorStore interface instead."
    )
|
||
def similarity_search_image( | ||
self, | ||
image_uri: str, | ||
k: Optional[int] = None, | ||
filter: Optional[dict] = None, | ||
**kwargs: Any, | ||
) -> list[Document]: | ||
raise NotImplementedError( | ||
"Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We may want to wrap this in a try/except block to provide a clearer error — or do you think the existing error is clear enough when users are not running in a Google Cloud environment or have not set up credentials?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The other langchain packages don't have running integrations tests for 3P providers. We could mock this test or just test this functionality in our package downstream.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently this is the error
google.auth.exceptions.DefaultCredentialsError: Your default credentials were not found. To set up Application Default Credentials, see https://cloud.google.com/docs/authentication/external/set-up-adc for more information.
I think this is pretty descriptive, let me know what you think.
The options for the tests are:
What do you suggest?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If GCS storage call can be easily mock, let's go ahead and do that. If it can't let's keep the test but skip it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently, the mock solutions would need more debugging, so I've removed the GCS URI from the tests; the other images, which are created locally, are still covered by the test.
I will recreate the GCS path testing in our libraries.