Skip to content

Support different routing for each id in multiget. #1956

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package org.springframework.data.elasticsearch.core;

import java.util.Collection;
import java.util.List;

import org.springframework.data.elasticsearch.core.mapping.IndexCoordinates;
Expand Down Expand Up @@ -121,6 +122,8 @@ public interface DocumentOperations {
* @param query the query defining the ids of the objects to get
* @param clazz the type of the object to be returned
* @return list of {@link MultiGetItem}s
* @see Query#multiGetQuery(Collection)
* @see Query#multiGetQueryWithRouting(List)
* @since 4.1
*/
<T> List<MultiGetItem<T>> multiGet(Query query, Class<T> clazz);
Expand All @@ -132,6 +135,8 @@ public interface DocumentOperations {
* @param clazz the type of the object to be returned
* @param index the index(es) from which the objects are read.
* @return list of {@link MultiGetItem}s
* @see Query#multiGetQuery(Collection)
* @see Query#multiGetQueryWithRouting(List)
*/
<T> List<MultiGetItem<T>> multiGet(Query query, Class<T> clazz, IndexCoordinates index);

Expand Down Expand Up @@ -283,7 +288,7 @@ default void bulkUpdate(List<UpdateQuery> queries, IndexCoordinates index) {

/**
* Delete all records matching the query.
*
*
* @param query query defining the objects
* @param clazz The entity class, must be annotated with
* {@link org.springframework.data.elasticsearch.annotations.Document}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,6 @@ public <T> T get(String id, Class<T> clazz, IndexCoordinates index) {
public <T> List<MultiGetItem<T>> multiGet(Query query, Class<T> clazz, IndexCoordinates index) {

Assert.notNull(index, "index must not be null");
Assert.notEmpty(query.getIds(), "No Id defined for Query");

MultiGetRequest request = requestFactory.multiGetRequest(query, clazz, index);
MultiGetResponse result = execute(client -> client.mget(request, RequestOptions.DEFAULT));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,6 @@ public <T> T get(String id, Class<T> clazz, IndexCoordinates index) {
public <T> List<MultiGetItem<T>> multiGet(Query query, Class<T> clazz, IndexCoordinates index) {

Assert.notNull(index, "index must not be null");
Assert.notEmpty(query.getIds(), "No Ids defined for Query");

MultiGetRequestBuilder builder = requestFactory.multiGetRequestBuilder(client, query, clazz, index);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ default <T> Flux<T> saveAll(Iterable<T> entities, IndexCoordinates index) {
* @param query the query defining the ids of the objects to get
* @param clazz the type of the object to be returned, used to determine the index
* @return flux with list of {@link MultiGetItem}s that contain the entities
* @see Query#multiGetQuery(Collection)
* @see Query#multiGetQueryWithRouting(List)
* @since 4.1
*/
<T> Flux<MultiGetItem<T>> multiGet(Query query, Class<T> clazz);
Expand All @@ -159,6 +161,8 @@ default <T> Flux<T> saveAll(Iterable<T> entities, IndexCoordinates index) {
* @param clazz the type of the object to be returned
* @param index the index(es) from which the objects are read.
* @return flux with list of {@link MultiGetItem}s that contain the entities
* @see Query#multiGetQuery(Collection)
* @see Query#multiGetQueryWithRouting(List)
* @since 4.0
*/
<T> Flux<MultiGetItem<T>> multiGet(Query query, Class<T> clazz, IndexCoordinates index);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,6 @@ public <T> Flux<MultiGetItem<T>> multiGet(Query query, Class<T> clazz, IndexCoor
Assert.notNull(index, "Index must not be null");
Assert.notNull(clazz, "Class must not be null");
Assert.notNull(query, "Query must not be null");
Assert.notEmpty(query.getIds(), "No Id define for Query");

DocumentCallback<T> callback = new ReadDocumentCallback<>(converter, clazz, index);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -677,13 +677,13 @@ private List<MultiGetRequest.Item> getMultiRequestItems(Query searchQuery, Class

FetchSourceContext fetchSourceContext = getFetchSourceContext(searchQuery);

if (!isEmpty(searchQuery.getIds())) {
if (!isEmpty(searchQuery.getIdsWithRouting())) {
String indexName = index.getIndexName();
for (String id : searchQuery.getIds()) {
MultiGetRequest.Item item = new MultiGetRequest.Item(indexName, id);

if (searchQuery.getRoute() != null) {
item = item.routing(searchQuery.getRoute());
for (Query.IdWithRouting idWithRouting : searchQuery.getIdsWithRouting()) {
MultiGetRequest.Item item = new MultiGetRequest.Item(indexName, idWithRouting.getId());
if (idWithRouting.getRouting() != null) {
item = item.routing(idWithRouting.getRouting());
}

// note: multiGet does not have fields, need to set sourceContext to filter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,23 @@
package org.springframework.data.elasticsearch.core.query;

import static java.util.Collections.*;
import static org.springframework.util.CollectionUtils.*;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
* AbstractQuery
* BaseQuery
*
* @author Rizwan Idrees
* @author Mohsin Husen
Expand All @@ -40,7 +43,7 @@
* @author Peter-Josef Meisch
* @author Peer Mueller
*/
abstract class AbstractQuery implements Query {
public class BaseQuery implements Query {

protected Pageable pageable = DEFAULT_PAGE;
@Nullable protected Sort sort;
Expand All @@ -63,6 +66,7 @@ abstract class AbstractQuery implements Query {
@Nullable private List<Object> searchAfter;
protected List<RescorerQuery> rescorerQueries = new ArrayList<>();
@Nullable protected Boolean requestCache;
private List<IdWithRouting> idsWithRouting = Collections.emptyList();

@Override
@Nullable
Expand All @@ -81,7 +85,7 @@ public final <T extends Query> T setPageable(Pageable pageable) {
Assert.notNull(pageable, "Pageable must not be null!");

this.pageable = pageable;
return (T) this.addSort(pageable.getSort());
return this.addSort(pageable.getSort());
}

@Override
Expand Down Expand Up @@ -116,7 +120,7 @@ public SourceFilter getSourceFilter() {

@Override
@SuppressWarnings("unchecked")
public final <T extends Query> T addSort(Sort sort) {
public final <T extends Query> T addSort(@Nullable Sort sort) {
if (sort == null) {
return (T) this;
}
Expand All @@ -139,14 +143,46 @@ public void setMinScore(float minScore) {
this.minScore = minScore;
}

@Nullable
/**
* Set Ids for a multi-get request with on this query.
*
* @param ids list of id values
*/
public void setIds(@Nullable Collection<String> ids) {
this.ids = ids;
}

@Override
@Nullable
public Collection<String> getIds() {
return ids;
}

public void setIds(Collection<String> ids) {
this.ids = ids;
@Override
public List<IdWithRouting> getIdsWithRouting() {

if (!isEmpty(idsWithRouting)) {
return Collections.unmodifiableList(idsWithRouting);
}

if (!isEmpty(ids)) {
return ids.stream().map(id -> new IdWithRouting(id, route)).collect(Collectors.toList());
}

return Collections.emptyList();
}

/**
* Set Ids with routing values for a multi-get request set on this query.
*
* @param idsWithRouting list of id values, must not be {@literal null}
* @since 4.3
*/
public void setIdsWithRouting(List<IdWithRouting> idsWithRouting) {

Assert.notNull(idsWithRouting, "idsWithRouting must not be null");

this.idsWithRouting = idsWithRouting;
}

@Nullable
Expand Down Expand Up @@ -337,4 +373,5 @@ public void setRequestCache(@Nullable Boolean value) {
public Boolean getRequestCache() {
return this.requestCache;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
* @author Mark Paluch
* @author Peter-Josef Meisch
*/
public class CriteriaQuery extends AbstractQuery {
public class CriteriaQuery extends BaseQuery {

private Criteria criteria;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
package org.springframework.data.elasticsearch.core.query;

import static java.util.Collections.*;
import static org.springframework.data.elasticsearch.core.query.AbstractQuery.*;
import static org.springframework.data.elasticsearch.core.query.BaseQuery.*;

import java.util.ArrayList;
import java.util.List;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
* @author Martin Choraine
* @author Peter-Josef Meisch
*/
public class NativeSearchQuery extends AbstractQuery {
public class NativeSearchQuery extends BaseQuery {

@Nullable private final QueryBuilder query;
@Nullable private QueryBuilder filter;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.springframework.data.domain.Sort;
import org.springframework.data.elasticsearch.core.SearchHit;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
* Query
Expand Down Expand Up @@ -76,7 +77,7 @@ static Query findAll() {
* @param sort
* @return
*/
<T extends Query> T addSort(Sort sort);
<T extends Query> T addSort(@Nullable Sort sort);

/**
* @return null if not set
Expand Down Expand Up @@ -137,13 +138,48 @@ static Query findAll() {
boolean getTrackScores();

/**
* Get Ids
*
* @return
* @return Get ids set on this query.
*/
@Nullable
Collection<String> getIds();

/**
* @return Ids with routing values used in a multi-get request.
* @see #multiGetQueryWithRouting(List)
* @since 4.3
*/
List<IdWithRouting> getIdsWithRouting();

/**
* Utility method to get a query for a multiget request
*
* @param idsWithRouting Ids with routing values used in a multi-get request.
* @return Query instance
*/
static Query multiGetQueryWithRouting(List<IdWithRouting> idsWithRouting) {

Assert.notNull(idsWithRouting, "idsWithRouting must not be null");

BaseQuery query = new BaseQuery();
query.setIdsWithRouting(idsWithRouting);
return query;
}

/**
* Utility method to get a query for a multiget request
*
* @param ids Ids used in a multi-get request.
* @return Query instance
*/
static Query multiGetQuery(Collection<String> ids) {

Assert.notNull(ids, "ids must not be null");

BaseQuery query = new BaseQuery();
query.setIds(ids);
return query;
}

/**
* Get route
*
Expand Down Expand Up @@ -362,4 +398,31 @@ default List<RescorerQuery> getRescorerQueries() {
enum SearchType {
QUERY_THEN_FETCH, DFS_QUERY_THEN_FETCH
}

/**
* Value class combining an id with a routing value. Used in multi-get requests.
*
* @since 4.3
*/
final class IdWithRouting {
private final String id;
@Nullable private final String routing;

public IdWithRouting(String id, @Nullable String routing) {

Assert.notNull(id, "id must not be null");

this.id = id;
this.routing = routing;
}

public String getId() {
return id;
}

@Nullable
public String getRouting() {
return routing;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
* @author Rizwan Idrees
* @author Mohsin Husen
*/
public class StringQuery extends AbstractQuery {
public class StringQuery extends BaseQuery {

private String source;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package org.springframework.data.elasticsearch.repository.support;

import static org.elasticsearch.index.query.QueryBuilders.*;
import static org.springframework.util.CollectionUtils.*;

import java.util.ArrayList;
import java.util.Collections;
Expand Down Expand Up @@ -50,7 +51,6 @@
import org.springframework.data.util.Streamable;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

/**
* Elasticsearch specific repository implementation. Likely to be used as target within
Expand Down Expand Up @@ -149,7 +149,7 @@ public Iterable<T> findAllById(Iterable<ID> ids) {

List<T> result = new ArrayList<>();
Query idQuery = getIdQuery(ids);
if (CollectionUtils.isEmpty(idQuery.getIds())) {
if (isEmpty(idQuery.getIds())) {
return result;
}

Expand Down
Loading