Skip to content

Commit 1554c3c

Browse files
authored
Support multi search template API.
Original Pull Request #2807 Closes #2704
1 parent 260dadd commit 1554c3c

File tree

3 files changed

+358
-89
lines changed

3 files changed

+358
-89
lines changed

src/main/java/org/springframework/data/elasticsearch/client/elc/ElasticsearchTemplate.java

Lines changed: 72 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import java.io.IOException;
3030
import java.time.Duration;
3131
import java.util.ArrayList;
32+
import java.util.Collections;
3233
import java.util.HashMap;
3334
import java.util.Iterator;
3435
import java.util.List;
@@ -72,6 +73,7 @@
7273
* @author Peter-Josef Meisch
7374
* @author Hamid Rahimi
7475
* @author Illia Ulianov
76+
* @author Haibo Liu
7577
* @since 4.4
7678
*/
7779
public class ElasticsearchTemplate extends AbstractElasticsearchTemplate {
@@ -437,13 +439,10 @@ public <T> List<SearchHits<T>> multiSearch(List<? extends Query> queries, Class<
437439
Assert.notNull(queries, "queries must not be null");
438440
Assert.notNull(clazz, "clazz must not be null");
439441

440-
List<MultiSearchQueryParameter> multiSearchQueryParameters = new ArrayList<>(queries.size());
441-
for (Query query : queries) {
442-
multiSearchQueryParameters.add(new MultiSearchQueryParameter(query, clazz, getIndexCoordinatesFor(clazz)));
443-
}
444-
442+
int size = queries.size();
445443
// noinspection unchecked
446-
return doMultiSearch(multiSearchQueryParameters).stream().map(searchHits -> (SearchHits<T>) searchHits)
444+
return multiSearch(queries, Collections.nCopies(size, clazz), Collections.nCopies(size, index))
445+
.stream().map(searchHits -> (SearchHits<T>) searchHits)
447446
.collect(Collectors.toList());
448447
}
449448

@@ -454,14 +453,7 @@ public List<SearchHits<?>> multiSearch(List<? extends Query> queries, List<Class
454453
Assert.notNull(classes, "classes must not be null");
455454
Assert.isTrue(queries.size() == classes.size(), "queries and classes must have the same size");
456455

457-
List<MultiSearchQueryParameter> multiSearchQueryParameters = new ArrayList<>(queries.size());
458-
Iterator<Class<?>> it = classes.iterator();
459-
for (Query query : queries) {
460-
Class<?> clazz = it.next();
461-
multiSearchQueryParameters.add(new MultiSearchQueryParameter(query, clazz, getIndexCoordinatesFor(clazz)));
462-
}
463-
464-
return doMultiSearch(multiSearchQueryParameters);
456+
return multiSearch(queries, classes, classes.stream().map(this::getIndexCoordinatesFor).toList());
465457
}
466458

467459
@Override
@@ -473,14 +465,7 @@ public List<SearchHits<?>> multiSearch(List<? extends Query> queries, List<Class
473465
Assert.notNull(index, "index must not be null");
474466
Assert.isTrue(queries.size() == classes.size(), "queries and classes must have the same size");
475467

476-
List<MultiSearchQueryParameter> multiSearchQueryParameters = new ArrayList<>(queries.size());
477-
Iterator<Class<?>> it = classes.iterator();
478-
for (Query query : queries) {
479-
Class<?> clazz = it.next();
480-
multiSearchQueryParameters.add(new MultiSearchQueryParameter(query, clazz, index));
481-
}
482-
483-
return doMultiSearch(multiSearchQueryParameters);
468+
return multiSearch(queries, classes, Collections.nCopies(queries.size(), index));
484469
}
485470

486471
@Override
@@ -497,16 +482,49 @@ public List<SearchHits<?>> multiSearch(List<? extends Query> queries, List<Class
497482
Iterator<Class<?>> it = classes.iterator();
498483
Iterator<IndexCoordinates> indexesIt = indexes.iterator();
499484

485+
Assert.isTrue(!queries.isEmpty(), "queries should have at least 1 query");
486+
boolean isSearchTemplateQuery = queries.get(0) instanceof SearchTemplateQuery;
487+
500488
for (Query query : queries) {
489+
Assert.isTrue((query instanceof SearchTemplateQuery) == isSearchTemplateQuery,
490+
"SearchTemplateQuery can't be mixed with other types of query in multiple search");
491+
501492
Class<?> clazz = it.next();
502493
IndexCoordinates index = indexesIt.next();
503494
multiSearchQueryParameters.add(new MultiSearchQueryParameter(query, clazz, index));
504495
}
505496

506-
return doMultiSearch(multiSearchQueryParameters);
497+
return multiSearch(multiSearchQueryParameters, isSearchTemplateQuery);
498+
}
499+
500+
private List<SearchHits<?>> multiSearch(List<MultiSearchQueryParameter> multiSearchQueryParameters,
501+
boolean isSearchTemplateQuery) {
502+
return isSearchTemplateQuery ?
503+
doMultiTemplateSearch(multiSearchQueryParameters.stream()
504+
.map(p -> new MultiSearchTemplateQueryParameter((SearchTemplateQuery) p.query, p.clazz, p.index))
505+
.toList())
506+
: doMultiSearch(multiSearchQueryParameters);
507+
}
508+
509+
private List<SearchHits<?>> doMultiTemplateSearch(List<MultiSearchTemplateQueryParameter> mSearchTemplateQueryParameters) {
510+
MsearchTemplateRequest request = requestConverter.searchMsearchTemplateRequest(mSearchTemplateQueryParameters,
511+
routingResolver.getRouting());
512+
513+
MsearchTemplateResponse<EntityAsMap> response = execute(client -> client.msearchTemplate(request, EntityAsMap.class));
514+
List<MultiSearchResponseItem<EntityAsMap>> responseItems = response.responses();
515+
516+
Assert.isTrue(mSearchTemplateQueryParameters.size() == responseItems.size(),
517+
"number of response items does not match number of requests");
518+
519+
int size = mSearchTemplateQueryParameters.size();
520+
List<Class<?>> classes = mSearchTemplateQueryParameters
521+
.stream().map(MultiSearchTemplateQueryParameter::clazz).collect(Collectors.toList());
522+
List<IndexCoordinates> indices = mSearchTemplateQueryParameters
523+
.stream().map(MultiSearchTemplateQueryParameter::index).collect(Collectors.toList());
524+
525+
return getSearchHitsFromMsearchResponse(size, classes, indices, responseItems);
507526
}
508527

509-
@SuppressWarnings({ "unchecked", "rawtypes" })
510528
private List<SearchHits<?>> doMultiSearch(List<MultiSearchQueryParameter> multiSearchQueryParameters) {
511529

512530
MsearchRequest request = requestConverter.searchMsearchRequest(multiSearchQueryParameters,
@@ -518,31 +536,46 @@ private List<SearchHits<?>> doMultiSearch(List<MultiSearchQueryParameter> multiS
518536
Assert.isTrue(multiSearchQueryParameters.size() == responseItems.size(),
519537
"number of response items does not match number of requests");
520538

521-
List<SearchHits<?>> searchHitsList = new ArrayList<>(multiSearchQueryParameters.size());
539+
int size = multiSearchQueryParameters.size();
540+
List<Class<?>> classes = multiSearchQueryParameters
541+
.stream().map(MultiSearchQueryParameter::clazz).collect(Collectors.toList());
542+
List<IndexCoordinates> indices = multiSearchQueryParameters
543+
.stream().map(MultiSearchQueryParameter::index).collect(Collectors.toList());
522544

523-
Iterator<MultiSearchQueryParameter> queryIterator = multiSearchQueryParameters.iterator();
545+
return getSearchHitsFromMsearchResponse(size, classes, indices, responseItems);
546+
}
547+
548+
/**
549+
* {@link MsearchResponse} and {@link MsearchTemplateResponse} share the same {@link MultiSearchResponseItem}
550+
*/
551+
@SuppressWarnings({ "unchecked", "rawtypes" })
552+
private List<SearchHits<?>> getSearchHitsFromMsearchResponse(int size, List<Class<?>> classes,
553+
List<IndexCoordinates> indices, List<MultiSearchResponseItem<EntityAsMap>> responseItems) {
554+
List<SearchHits<?>> searchHitsList = new ArrayList<>(size);
555+
Iterator<Class<?>> clazzIter = classes.iterator();
556+
Iterator<IndexCoordinates> indexIter = indices.iterator();
524557
Iterator<MultiSearchResponseItem<EntityAsMap>> responseIterator = responseItems.iterator();
525558

526-
while (queryIterator.hasNext()) {
527-
MultiSearchQueryParameter queryParameter = queryIterator.next();
559+
while (clazzIter.hasNext() && indexIter.hasNext()) {
528560
MultiSearchResponseItem<EntityAsMap> responseItem = responseIterator.next();
529561

530562
if (responseItem.isResult()) {
531563

532-
Class clazz = queryParameter.clazz;
564+
Class clazz = clazzIter.next();
565+
IndexCoordinates index = indexIter.next();
533566
ReadDocumentCallback<?> documentCallback = new ReadDocumentCallback<>(elasticsearchConverter, clazz,
534-
queryParameter.index);
567+
index);
535568
SearchDocumentResponseCallback<SearchHits<?>> callback = new ReadSearchDocumentResponseCallback<>(clazz,
536-
queryParameter.index);
569+
index);
537570

538571
SearchHits<?> searchHits = callback.doWith(
539572
SearchDocumentResponseBuilder.from(responseItem.result(), getEntityCreator(documentCallback), jsonpMapper));
540573

541574
searchHitsList.add(searchHits);
542575
} else {
543576
if (LOGGER.isWarnEnabled()) {
544-
LOGGER
545-
.warn(String.format("multisearch responsecontains failure: {}", responseItem.failure().error().reason()));
577+
LOGGER.warn(String.format("multisearch response contains failure: %s",
578+
responseItem.failure().error().reason()));
546579
}
547580
}
548581
}
@@ -556,6 +589,12 @@ private List<SearchHits<?>> doMultiSearch(List<MultiSearchQueryParameter> multiS
556589
record MultiSearchQueryParameter(Query query, Class<?> clazz, IndexCoordinates index) {
557590
}
558591

592+
/**
593+
* value class combining the information needed for a single query in a template multisearch request.
594+
*/
595+
record MultiSearchTemplateQueryParameter(SearchTemplateQuery query, Class<?> clazz, IndexCoordinates index) {
596+
}
597+
559598
@Override
560599
public String openPointInTime(IndexCoordinates index, Duration keepAlive, Boolean ignoreUnavailable) {
561600

src/main/java/org/springframework/data/elasticsearch/client/elc/RequestConverter.java

Lines changed: 75 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import co.elastic.clients.elasticsearch.core.bulk.UpdateOperation;
4545
import co.elastic.clients.elasticsearch.core.mget.MultiGetOperation;
4646
import co.elastic.clients.elasticsearch.core.msearch.MultisearchBody;
47+
import co.elastic.clients.elasticsearch.core.msearch.MultisearchHeader;
4748
import co.elastic.clients.elasticsearch.core.search.Highlight;
4849
import co.elastic.clients.elasticsearch.core.search.Rescore;
4950
import co.elastic.clients.elasticsearch.core.search.SourceConfig;
@@ -54,6 +55,7 @@
5455
import co.elastic.clients.json.JsonData;
5556
import co.elastic.clients.json.JsonpDeserializer;
5657
import co.elastic.clients.json.JsonpMapper;
58+
import co.elastic.clients.util.ObjectBuilder;
5759
import jakarta.json.stream.JsonParser;
5860

5961
import java.io.ByteArrayInputStream;
@@ -66,6 +68,7 @@
6668
import java.util.HashMap;
6769
import java.util.List;
6870
import java.util.Map;
71+
import java.util.Set;
6972
import java.util.function.Function;
7073
import java.util.stream.Collectors;
7174

@@ -397,10 +400,7 @@ public co.elastic.clients.elasticsearch.indices.PutTemplateRequest indicesPutTem
397400
.order(putTemplateRequest.getOrder());
398401

399402
if (putTemplateRequest.getSettings() != null) {
400-
Function<Map.Entry<String, Object>, String> keyMapper = Map.Entry::getKey;
401-
Function<Map.Entry<String, Object>, JsonData> valueMapper = entry -> JsonData.of(entry.getValue(), jsonpMapper);
402-
Map<String, JsonData> settings = putTemplateRequest.getSettings().entrySet().stream()
403-
.collect(Collectors.toMap(keyMapper, valueMapper));
403+
Map<String, JsonData> settings = getTemplateParams(putTemplateRequest.getSettings().entrySet());
404404
builder.settings(settings);
405405
}
406406

@@ -1146,6 +1146,36 @@ public <T> SearchRequest searchRequest(Query query, @Nullable String routing, @N
11461146
return builder.build();
11471147
}
11481148

1149+
public MsearchTemplateRequest searchMsearchTemplateRequest(
1150+
List<ElasticsearchTemplate.MultiSearchTemplateQueryParameter> multiSearchTemplateQueryParameters,
1151+
@Nullable String routing) {
1152+
1153+
// basically the same stuff as in template search
1154+
return MsearchTemplateRequest.of(mtrb -> {
1155+
multiSearchTemplateQueryParameters.forEach(param -> {
1156+
var query = param.query();
1157+
mtrb.searchTemplates(stb -> stb
1158+
.header(msearchHeaderBuilder(query, param.index(), routing))
1159+
.body(bb -> {
1160+
bb //
1161+
.explain(query.getExplain()) //
1162+
.id(query.getId()) //
1163+
.source(query.getSource()) //
1164+
;
1165+
1166+
if (!CollectionUtils.isEmpty(query.getParams())) {
1167+
Map<String, JsonData> params = getTemplateParams(query.getParams().entrySet());
1168+
bb.params(params);
1169+
}
1170+
1171+
return bb;
1172+
})
1173+
);
1174+
});
1175+
return mtrb;
1176+
});
1177+
}
1178+
11491179
public MsearchRequest searchMsearchRequest(
11501180
List<ElasticsearchTemplate.MultiSearchQueryParameter> multiSearchQueryParameters, @Nullable String routing) {
11511181

@@ -1157,28 +1187,7 @@ public MsearchRequest searchMsearchRequest(
11571187

11581188
var query = param.query();
11591189
mrb.searches(sb -> sb //
1160-
.header(h -> {
1161-
var searchType = (query instanceof NativeQuery nativeQuery && nativeQuery.getKnnQuery() != null) ? null
1162-
: searchType(query.getSearchType());
1163-
1164-
h //
1165-
.index(Arrays.asList(param.index().getIndexNames())) //
1166-
.searchType(searchType) //
1167-
.requestCache(query.getRequestCache()) //
1168-
;
1169-
1170-
if (StringUtils.hasText(query.getRoute())) {
1171-
h.routing(query.getRoute());
1172-
} else if (StringUtils.hasText(routing)) {
1173-
h.routing(routing);
1174-
}
1175-
1176-
if (query.getPreference() != null) {
1177-
h.preference(query.getPreference());
1178-
}
1179-
1180-
return h;
1181-
}) //
1190+
.header(msearchHeaderBuilder(query, param.index(), routing)) //
11821191
.body(bb -> {
11831192
bb //
11841193
.query(getQuery(query, param.clazz()))//
@@ -1284,6 +1293,35 @@ public MsearchRequest searchMsearchRequest(
12841293
});
12851294
}
12861295

1296+
/**
1297+
* {@link MsearchRequest} and {@link MsearchTemplateRequest} share the same {@link MultisearchHeader}
1298+
*/
1299+
private Function<MultisearchHeader.Builder, ObjectBuilder<MultisearchHeader>> msearchHeaderBuilder(Query query,
1300+
IndexCoordinates index, @Nullable String routing) {
1301+
return h -> {
1302+
var searchType = (query instanceof NativeQuery nativeQuery && nativeQuery.getKnnQuery() != null) ? null
1303+
: searchType(query.getSearchType());
1304+
1305+
h //
1306+
.index(Arrays.asList(index.getIndexNames())) //
1307+
.searchType(searchType) //
1308+
.requestCache(query.getRequestCache()) //
1309+
;
1310+
1311+
if (StringUtils.hasText(query.getRoute())) {
1312+
h.routing(query.getRoute());
1313+
} else if (StringUtils.hasText(routing)) {
1314+
h.routing(routing);
1315+
}
1316+
1317+
if (query.getPreference() != null) {
1318+
h.preference(query.getPreference());
1319+
}
1320+
1321+
return h;
1322+
};
1323+
}
1324+
12871325
private <T> void prepareSearchRequest(Query query, @Nullable String routing, @Nullable Class<T> clazz,
12881326
IndexCoordinates indexCoordinates, SearchRequest.Builder builder, boolean forCount, boolean forBatchedSearch) {
12891327

@@ -1770,7 +1808,8 @@ public SearchTemplateRequest searchTemplate(SearchTemplateQuery query, @Nullable
17701808
.id(query.getId()) //
17711809
.index(Arrays.asList(index.getIndexNames())) //
17721810
.preference(query.getPreference()) //
1773-
.searchType(searchType(query.getSearchType())).source(query.getSource()) //
1811+
.searchType(searchType(query.getSearchType())) //
1812+
.source(query.getSource()) //
17741813
;
17751814

17761815
if (query.getRoute() != null) {
@@ -1789,17 +1828,22 @@ public SearchTemplateRequest searchTemplate(SearchTemplateQuery query, @Nullable
17891828
}
17901829

17911830
if (!CollectionUtils.isEmpty(query.getParams())) {
1792-
Function<Map.Entry<String, Object>, String> keyMapper = Map.Entry::getKey;
1793-
Function<Map.Entry<String, Object>, JsonData> valueMapper = entry -> JsonData.of(entry.getValue(), jsonpMapper);
1794-
Map<String, JsonData> params = query.getParams().entrySet().stream()
1795-
.collect(Collectors.toMap(keyMapper, valueMapper));
1831+
Map<String, JsonData> params = getTemplateParams(query.getParams().entrySet());
17961832
builder.params(params);
17971833
}
17981834

17991835
return builder;
18001836
});
18011837
}
18021838

1839+
@NotNull
1840+
private Map<String, JsonData> getTemplateParams(Set<Map.Entry<String, Object>> query) {
1841+
Function<Map.Entry<String, Object>, String> keyMapper = Map.Entry::getKey;
1842+
Function<Map.Entry<String, Object>, JsonData> valueMapper = entry -> JsonData.of(entry.getValue(), jsonpMapper);
1843+
return query.stream()
1844+
.collect(Collectors.toMap(keyMapper, valueMapper));
1845+
}
1846+
18031847
// endregion
18041848

18051849
public PutScriptRequest scriptPut(Script script) {

0 commit comments

Comments
 (0)