Skip to content

Support multi search template API #2807

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
Expand Down Expand Up @@ -72,6 +73,7 @@
* @author Peter-Josef Meisch
* @author Hamid Rahimi
* @author Illia Ulianov
* @author Haibo Liu
* @since 4.4
*/
public class ElasticsearchTemplate extends AbstractElasticsearchTemplate {
Expand Down Expand Up @@ -437,13 +439,10 @@ public <T> List<SearchHits<T>> multiSearch(List<? extends Query> queries, Class<
Assert.notNull(queries, "queries must not be null");
Assert.notNull(clazz, "clazz must not be null");

List<MultiSearchQueryParameter> multiSearchQueryParameters = new ArrayList<>(queries.size());
for (Query query : queries) {
multiSearchQueryParameters.add(new MultiSearchQueryParameter(query, clazz, getIndexCoordinatesFor(clazz)));
}

int size = queries.size();
// noinspection unchecked
return doMultiSearch(multiSearchQueryParameters).stream().map(searchHits -> (SearchHits<T>) searchHits)
return multiSearch(queries, Collections.nCopies(size, clazz), Collections.nCopies(size, index))
.stream().map(searchHits -> (SearchHits<T>) searchHits)
.collect(Collectors.toList());
}

Expand All @@ -454,14 +453,7 @@ public List<SearchHits<?>> multiSearch(List<? extends Query> queries, List<Class
Assert.notNull(classes, "classes must not be null");
Assert.isTrue(queries.size() == classes.size(), "queries and classes must have the same size");

List<MultiSearchQueryParameter> multiSearchQueryParameters = new ArrayList<>(queries.size());
Iterator<Class<?>> it = classes.iterator();
for (Query query : queries) {
Class<?> clazz = it.next();
multiSearchQueryParameters.add(new MultiSearchQueryParameter(query, clazz, getIndexCoordinatesFor(clazz)));
}

return doMultiSearch(multiSearchQueryParameters);
return multiSearch(queries, classes, classes.stream().map(this::getIndexCoordinatesFor).toList());
}

@Override
Expand All @@ -473,14 +465,7 @@ public List<SearchHits<?>> multiSearch(List<? extends Query> queries, List<Class
Assert.notNull(index, "index must not be null");
Assert.isTrue(queries.size() == classes.size(), "queries and classes must have the same size");

List<MultiSearchQueryParameter> multiSearchQueryParameters = new ArrayList<>(queries.size());
Iterator<Class<?>> it = classes.iterator();
for (Query query : queries) {
Class<?> clazz = it.next();
multiSearchQueryParameters.add(new MultiSearchQueryParameter(query, clazz, index));
}

return doMultiSearch(multiSearchQueryParameters);
return multiSearch(queries, classes, Collections.nCopies(queries.size(), index));
}

@Override
Expand All @@ -497,16 +482,49 @@ public List<SearchHits<?>> multiSearch(List<? extends Query> queries, List<Class
Iterator<Class<?>> it = classes.iterator();
Iterator<IndexCoordinates> indexesIt = indexes.iterator();

Assert.isTrue(!queries.isEmpty(), "queries should have at least 1 query");
boolean isSearchTemplateQuery = queries.get(0) instanceof SearchTemplateQuery;

for (Query query : queries) {
Assert.isTrue((query instanceof SearchTemplateQuery) == isSearchTemplateQuery,
"SearchTemplateQuery can't be mixed with other types of query in multiple search");

Class<?> clazz = it.next();
IndexCoordinates index = indexesIt.next();
multiSearchQueryParameters.add(new MultiSearchQueryParameter(query, clazz, index));
}

return doMultiSearch(multiSearchQueryParameters);
return multiSearch(multiSearchQueryParameters, isSearchTemplateQuery);
}

private List<SearchHits<?>> multiSearch(List<MultiSearchQueryParameter> multiSearchQueryParameters,
boolean isSearchTemplateQuery) {
return isSearchTemplateQuery ?
doMultiTemplateSearch(multiSearchQueryParameters.stream()
.map(p -> new MultiSearchTemplateQueryParameter((SearchTemplateQuery) p.query, p.clazz, p.index))
.toList())
: doMultiSearch(multiSearchQueryParameters);
}

private List<SearchHits<?>> doMultiTemplateSearch(List<MultiSearchTemplateQueryParameter> mSearchTemplateQueryParameters) {
MsearchTemplateRequest request = requestConverter.searchMsearchTemplateRequest(mSearchTemplateQueryParameters,
routingResolver.getRouting());

MsearchTemplateResponse<EntityAsMap> response = execute(client -> client.msearchTemplate(request, EntityAsMap.class));
List<MultiSearchResponseItem<EntityAsMap>> responseItems = response.responses();

Assert.isTrue(mSearchTemplateQueryParameters.size() == responseItems.size(),
"number of response items does not match number of requests");

int size = mSearchTemplateQueryParameters.size();
List<Class<?>> classes = mSearchTemplateQueryParameters
.stream().map(MultiSearchTemplateQueryParameter::clazz).collect(Collectors.toList());
List<IndexCoordinates> indices = mSearchTemplateQueryParameters
.stream().map(MultiSearchTemplateQueryParameter::index).collect(Collectors.toList());

return getSearchHitsFromMsearchResponse(size, classes, indices, responseItems);
}

@SuppressWarnings({ "unchecked", "rawtypes" })
private List<SearchHits<?>> doMultiSearch(List<MultiSearchQueryParameter> multiSearchQueryParameters) {

MsearchRequest request = requestConverter.searchMsearchRequest(multiSearchQueryParameters,
Expand All @@ -518,31 +536,46 @@ private List<SearchHits<?>> doMultiSearch(List<MultiSearchQueryParameter> multiS
Assert.isTrue(multiSearchQueryParameters.size() == responseItems.size(),
"number of response items does not match number of requests");

List<SearchHits<?>> searchHitsList = new ArrayList<>(multiSearchQueryParameters.size());
int size = multiSearchQueryParameters.size();
List<Class<?>> classes = multiSearchQueryParameters
.stream().map(MultiSearchQueryParameter::clazz).collect(Collectors.toList());
List<IndexCoordinates> indices = multiSearchQueryParameters
.stream().map(MultiSearchQueryParameter::index).collect(Collectors.toList());

Iterator<MultiSearchQueryParameter> queryIterator = multiSearchQueryParameters.iterator();
return getSearchHitsFromMsearchResponse(size, classes, indices, responseItems);
}

/**
* {@link MsearchResponse} and {@link MsearchTemplateResponse} share the same {@link MultiSearchResponseItem}
*/
@SuppressWarnings({ "unchecked", "rawtypes" })
private List<SearchHits<?>> getSearchHitsFromMsearchResponse(int size, List<Class<?>> classes,
List<IndexCoordinates> indices, List<MultiSearchResponseItem<EntityAsMap>> responseItems) {
List<SearchHits<?>> searchHitsList = new ArrayList<>(size);
Iterator<Class<?>> clazzIter = classes.iterator();
Iterator<IndexCoordinates> indexIter = indices.iterator();
Iterator<MultiSearchResponseItem<EntityAsMap>> responseIterator = responseItems.iterator();

while (queryIterator.hasNext()) {
MultiSearchQueryParameter queryParameter = queryIterator.next();
while (clazzIter.hasNext() && indexIter.hasNext()) {
MultiSearchResponseItem<EntityAsMap> responseItem = responseIterator.next();

if (responseItem.isResult()) {

Class clazz = queryParameter.clazz;
Class clazz = clazzIter.next();
IndexCoordinates index = indexIter.next();
ReadDocumentCallback<?> documentCallback = new ReadDocumentCallback<>(elasticsearchConverter, clazz,
queryParameter.index);
index);
SearchDocumentResponseCallback<SearchHits<?>> callback = new ReadSearchDocumentResponseCallback<>(clazz,
queryParameter.index);
index);

SearchHits<?> searchHits = callback.doWith(
SearchDocumentResponseBuilder.from(responseItem.result(), getEntityCreator(documentCallback), jsonpMapper));

searchHitsList.add(searchHits);
} else {
if (LOGGER.isWarnEnabled()) {
LOGGER
.warn(String.format("multisearch responsecontains failure: {}", responseItem.failure().error().reason()));
LOGGER.warn(String.format("multisearch response contains failure: %s",
responseItem.failure().error().reason()));
}
}
}
Expand All @@ -556,6 +589,12 @@ private List<SearchHits<?>> doMultiSearch(List<MultiSearchQueryParameter> multiS
record MultiSearchQueryParameter(Query query, Class<?> clazz, IndexCoordinates index) {
}

/**
* value class combining the information needed for a single query in a template multisearch request.
*/
record MultiSearchTemplateQueryParameter(SearchTemplateQuery query, Class<?> clazz, IndexCoordinates index) {
}

@Override
public String openPointInTime(IndexCoordinates index, Duration keepAlive, Boolean ignoreUnavailable) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import co.elastic.clients.elasticsearch.core.bulk.UpdateOperation;
import co.elastic.clients.elasticsearch.core.mget.MultiGetOperation;
import co.elastic.clients.elasticsearch.core.msearch.MultisearchBody;
import co.elastic.clients.elasticsearch.core.msearch.MultisearchHeader;
import co.elastic.clients.elasticsearch.core.search.Highlight;
import co.elastic.clients.elasticsearch.core.search.Rescore;
import co.elastic.clients.elasticsearch.core.search.SourceConfig;
Expand All @@ -54,6 +55,7 @@
import co.elastic.clients.json.JsonData;
import co.elastic.clients.json.JsonpDeserializer;
import co.elastic.clients.json.JsonpMapper;
import co.elastic.clients.util.ObjectBuilder;
import jakarta.json.stream.JsonParser;

import java.io.ByteArrayInputStream;
Expand All @@ -66,6 +68,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -397,10 +400,7 @@ public co.elastic.clients.elasticsearch.indices.PutTemplateRequest indicesPutTem
.order(putTemplateRequest.getOrder());

if (putTemplateRequest.getSettings() != null) {
Function<Map.Entry<String, Object>, String> keyMapper = Map.Entry::getKey;
Function<Map.Entry<String, Object>, JsonData> valueMapper = entry -> JsonData.of(entry.getValue(), jsonpMapper);
Map<String, JsonData> settings = putTemplateRequest.getSettings().entrySet().stream()
.collect(Collectors.toMap(keyMapper, valueMapper));
Map<String, JsonData> settings = getTemplateParams(putTemplateRequest.getSettings().entrySet());
builder.settings(settings);
}

Expand Down Expand Up @@ -1146,6 +1146,36 @@ public <T> SearchRequest searchRequest(Query query, @Nullable String routing, @N
return builder.build();
}

public MsearchTemplateRequest searchMsearchTemplateRequest(
List<ElasticsearchTemplate.MultiSearchTemplateQueryParameter> multiSearchTemplateQueryParameters,
@Nullable String routing) {

// basically the same stuff as in template search
return MsearchTemplateRequest.of(mtrb -> {
multiSearchTemplateQueryParameters.forEach(param -> {
var query = param.query();
mtrb.searchTemplates(stb -> stb
.header(msearchHeaderBuilder(query, param.index(), routing))
.body(bb -> {
bb //
.explain(query.getExplain()) //
.id(query.getId()) //
.source(query.getSource()) //
;

if (!CollectionUtils.isEmpty(query.getParams())) {
Map<String, JsonData> params = getTemplateParams(query.getParams().entrySet());
bb.params(params);
}

return bb;
})
);
});
return mtrb;
});
}

public MsearchRequest searchMsearchRequest(
List<ElasticsearchTemplate.MultiSearchQueryParameter> multiSearchQueryParameters, @Nullable String routing) {

Expand All @@ -1157,28 +1187,7 @@ public MsearchRequest searchMsearchRequest(

var query = param.query();
mrb.searches(sb -> sb //
.header(h -> {
var searchType = (query instanceof NativeQuery nativeQuery && nativeQuery.getKnnQuery() != null) ? null
: searchType(query.getSearchType());

h //
.index(Arrays.asList(param.index().getIndexNames())) //
.searchType(searchType) //
.requestCache(query.getRequestCache()) //
;

if (StringUtils.hasText(query.getRoute())) {
h.routing(query.getRoute());
} else if (StringUtils.hasText(routing)) {
h.routing(routing);
}

if (query.getPreference() != null) {
h.preference(query.getPreference());
}

return h;
}) //
.header(msearchHeaderBuilder(query, param.index(), routing)) //
.body(bb -> {
bb //
.query(getQuery(query, param.clazz()))//
Expand Down Expand Up @@ -1284,6 +1293,35 @@ public MsearchRequest searchMsearchRequest(
});
}

/**
* {@link MsearchRequest} and {@link MsearchTemplateRequest} share the same {@link MultisearchHeader}
*/
private Function<MultisearchHeader.Builder, ObjectBuilder<MultisearchHeader>> msearchHeaderBuilder(Query query,
IndexCoordinates index, @Nullable String routing) {
return h -> {
var searchType = (query instanceof NativeQuery nativeQuery && nativeQuery.getKnnQuery() != null) ? null
: searchType(query.getSearchType());

h //
.index(Arrays.asList(index.getIndexNames())) //
.searchType(searchType) //
.requestCache(query.getRequestCache()) //
;

if (StringUtils.hasText(query.getRoute())) {
h.routing(query.getRoute());
} else if (StringUtils.hasText(routing)) {
h.routing(routing);
}

if (query.getPreference() != null) {
h.preference(query.getPreference());
}

return h;
};
}

private <T> void prepareSearchRequest(Query query, @Nullable String routing, @Nullable Class<T> clazz,
IndexCoordinates indexCoordinates, SearchRequest.Builder builder, boolean forCount, boolean forBatchedSearch) {

Expand Down Expand Up @@ -1770,7 +1808,8 @@ public SearchTemplateRequest searchTemplate(SearchTemplateQuery query, @Nullable
.id(query.getId()) //
.index(Arrays.asList(index.getIndexNames())) //
.preference(query.getPreference()) //
.searchType(searchType(query.getSearchType())).source(query.getSource()) //
.searchType(searchType(query.getSearchType())) //
.source(query.getSource()) //
;

if (query.getRoute() != null) {
Expand All @@ -1789,17 +1828,22 @@ public SearchTemplateRequest searchTemplate(SearchTemplateQuery query, @Nullable
}

if (!CollectionUtils.isEmpty(query.getParams())) {
Function<Map.Entry<String, Object>, String> keyMapper = Map.Entry::getKey;
Function<Map.Entry<String, Object>, JsonData> valueMapper = entry -> JsonData.of(entry.getValue(), jsonpMapper);
Map<String, JsonData> params = query.getParams().entrySet().stream()
.collect(Collectors.toMap(keyMapper, valueMapper));
Map<String, JsonData> params = getTemplateParams(query.getParams().entrySet());
builder.params(params);
}

return builder;
});
}

@NotNull
private Map<String, JsonData> getTemplateParams(Set<Map.Entry<String, Object>> query) {
Function<Map.Entry<String, Object>, String> keyMapper = Map.Entry::getKey;
Function<Map.Entry<String, Object>, JsonData> valueMapper = entry -> JsonData.of(entry.getValue(), jsonpMapper);
return query.stream()
.collect(Collectors.toMap(keyMapper, valueMapper));
}

// endregion

public PutScriptRequest scriptPut(Script script) {
Expand Down
Loading