From b8f87bb6215f19eadd780cc04826c0f0fdf43a3c Mon Sep 17 00:00:00 2001 From: minglu7 <104314597+minglu7@users.noreply.github.com> Date: Thu, 4 Jul 2024 20:06:23 +0800 Subject: [PATCH] feat: extend default settings in AsyncVectorStore (#2602) * feat: extend default settings in AsyncVectorStore This commit introduces a new parameter, custom_index_settings, to the AsyncVectorStore class. This allows users to provide their own settings that will extend the default settings. This increases the flexibility of the class and allows it to be tailored to specific use cases. The custom settings are applied in the _create_index_if_not_exists method. * Update vectorstore.py * Update vectorstore.py apply changes in vectorstore * Update vectorstore.py format the py file * Update test_vectorstore.py add test_custom_index_settings in test_vectorstore * Update test_vectorstore.py * Update vectorstore.py fix file format * Update test_vectorstore.py fix format * Update vectorstore.py add error tips in vectorstore when the settings conflict * Update vectorstore.py * Update vectorstore.py modify the comments of the param custom_index_settings * Update vectorstore.py * add settings conflict test * reformat --------- Co-authored-by: Quentin Pradet Co-authored-by: Miguel Grinberg (cherry picked from commit beb03deb2a97b34d9e3c212ec5ec839a73677879) --- .../helpers/vectorstore/_async/vectorstore.py | 17 ++++ .../helpers/vectorstore/_sync/vectorstore.py | 17 ++++ .../test_vectorstore/test_vectorstore.py | 85 +++++++++++++++++++ 3 files changed, 119 insertions(+) diff --git a/elasticsearch/helpers/vectorstore/_async/vectorstore.py b/elasticsearch/helpers/vectorstore/_async/vectorstore.py index b79e2dcaf..81356cf92 100644 --- a/elasticsearch/helpers/vectorstore/_async/vectorstore.py +++ b/elasticsearch/helpers/vectorstore/_async/vectorstore.py @@ -60,6 +60,7 @@ def __init__( vector_field: str = "vector_field", metadata_mappings: Optional[Dict[str, Any]] = None, user_agent: str = 
f"elasticsearch-py-vs/{lib_version}", + custom_index_settings: Optional[Dict[str, Any]] = None, ) -> None: """ :param user_header: user agent header specific to the 3rd party integration. @@ -72,6 +73,11 @@ def __init__( the embedding vector goes in this field. :param client: Elasticsearch client connection. Alternatively specify the Elasticsearch connection with the other es_* parameters. + :param custom_index_settings: A dictionary of custom settings for the index. + This can include configurations like the number of shards, number of replicas, + analysis settings, and other index-specific settings. If not provided, default + settings will be used. Note that if the same setting is provided by both the user + and the strategy, will raise an error. """ # Add integration-specific usage header for tracking usage in Elastic Cloud. # client.options preserves existing (non-user-agent) headers. @@ -90,6 +96,7 @@ def __init__( self.text_field = text_field self.vector_field = vector_field self.metadata_mappings = metadata_mappings + self.custom_index_settings = custom_index_settings async def close(self) -> None: return await self.client.close() @@ -306,6 +313,16 @@ async def _create_index_if_not_exists(self) -> None: vector_field=self.vector_field, num_dimensions=self.num_dimensions, ) + + if self.custom_index_settings: + conflicting_keys = set(self.custom_index_settings.keys()) & set( + settings.keys() + ) + if conflicting_keys: + raise ValueError(f"Conflicting settings: {conflicting_keys}") + else: + settings.update(self.custom_index_settings) + if self.metadata_mappings: metadata = mappings["properties"].get("metadata", {"properties": {}}) for key in self.metadata_mappings.keys(): diff --git a/elasticsearch/helpers/vectorstore/_sync/vectorstore.py b/elasticsearch/helpers/vectorstore/_sync/vectorstore.py index 2feb96ec4..9aaa966f3 100644 --- a/elasticsearch/helpers/vectorstore/_sync/vectorstore.py +++ b/elasticsearch/helpers/vectorstore/_sync/vectorstore.py @@ -57,6 
+57,7 @@ def __init__( vector_field: str = "vector_field", metadata_mappings: Optional[Dict[str, Any]] = None, user_agent: str = f"elasticsearch-py-vs/{lib_version}", + custom_index_settings: Optional[Dict[str, Any]] = None, ) -> None: """ :param user_header: user agent header specific to the 3rd party integration. @@ -69,6 +70,11 @@ def __init__( the embedding vector goes in this field. :param client: Elasticsearch client connection. Alternatively specify the Elasticsearch connection with the other es_* parameters. + :param custom_index_settings: A dictionary of custom settings for the index. + This can include configurations like the number of shards, number of replicas, + analysis settings, and other index-specific settings. If not provided, default + settings will be used. Note that if the same setting is provided by both the user + and the strategy, will raise an error. """ # Add integration-specific usage header for tracking usage in Elastic Cloud. # client.options preserves existing (non-user-agent) headers. 
@@ -87,6 +93,7 @@ def __init__( self.text_field = text_field self.vector_field = vector_field self.metadata_mappings = metadata_mappings + self.custom_index_settings = custom_index_settings def close(self) -> None: return self.client.close() @@ -303,6 +310,16 @@ def _create_index_if_not_exists(self) -> None: vector_field=self.vector_field, num_dimensions=self.num_dimensions, ) + + if self.custom_index_settings: + conflicting_keys = set(self.custom_index_settings.keys()) & set( + settings.keys() + ) + if conflicting_keys: + raise ValueError(f"Conflicting settings: {conflicting_keys}") + else: + settings.update(self.custom_index_settings) + if self.metadata_mappings: metadata = mappings["properties"].get("metadata", {"properties": {}}) for key in self.metadata_mappings.keys(): diff --git a/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py b/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py index bb15d3dc7..a8cae670f 100644 --- a/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py +++ b/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py @@ -907,3 +907,88 @@ def test_metadata_mapping(self, sync_client: Elasticsearch, index: str) -> None: assert "metadata" in mapping_properties for key, val in test_mappings.items(): assert mapping_properties["metadata"]["properties"][key] == val + + def test_custom_index_settings( + self, sync_client: Elasticsearch, index: str + ) -> None: + """Test that the custom index settings are applied.""" + test_settings = { + "analysis": { + "tokenizer": { + "custom_tokenizer": {"type": "pattern", "pattern": "[,;\\s]+"} + }, + "analyzer": { + "custom_analyzer": { + "type": "custom", + "tokenizer": "custom_tokenizer", + } + }, + } + } + + test_mappings = { + "my_field": {"type": "keyword"}, + "another_field": {"type": "text", "analyzer": "custom_analyzer"}, + } + + store = VectorStore( + index=index, + retrieval_strategy=DenseVectorStrategy(distance=DistanceMetric.COSINE), + 
embedding_service=FakeEmbeddings(), + num_dimensions=10, + client=sync_client, + metadata_mappings=test_mappings, + custom_index_settings=test_settings, + ) + + sample_texts = [ + "Sample text one, with some keywords.", + "Another; sample, text with; different keywords.", + "Third example text, with more keywords.", + ] + store.add_texts(texts=sample_texts) + + # Fetch the actual index settings from Elasticsearch + actual_settings = sync_client.indices.get_settings(index=index) + + # Assert that the custom settings were applied correctly + custom_settings_applied = actual_settings[index]["settings"]["index"][ + "analysis" + ] + assert ( + custom_settings_applied == test_settings["analysis"] + ), f"Expected custom index settings {test_settings} but got {custom_settings_applied}" + + def test_custom_index_settings_with_collision( + self, sync_client: Elasticsearch, index: str + ) -> None: + """Test that custom index settings that collide cause an error.""" + test_settings = { + "default_pipeline": "my_pipeline", + "analysis": { + "tokenizer": { + "custom_tokenizer": {"type": "pattern", "pattern": "[,;\\s]+"} + }, + "analyzer": { + "custom_analyzer": { + "type": "custom", + "tokenizer": "custom_tokenizer", + } + }, + }, + } + + test_mappings = { + "my_field": {"type": "keyword"}, + "another_field": {"type": "text", "analyzer": "custom_analyzer"}, + } + + store = VectorStore( + index=index, + retrieval_strategy=SparseVectorStrategy(), + client=sync_client, + metadata_mappings=test_mappings, + custom_index_settings=test_settings, + ) + with pytest.raises(ValueError, match="Conflicting settings"): + store.add_texts(texts=["some text"])