Skip to content

Fix search_after field values. #2679

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,15 @@
import static org.springframework.data.elasticsearch.client.elc.TypeUtils.*;
import static org.springframework.util.CollectionUtils.*;

import co.elastic.clients.elasticsearch._types.*;
import co.elastic.clients.elasticsearch._types.Conflicts;
import co.elastic.clients.elasticsearch._types.ExpandWildcard;
import co.elastic.clients.elasticsearch._types.InlineScript;
import co.elastic.clients.elasticsearch._types.NestedSortValue;
import co.elastic.clients.elasticsearch._types.OpType;
import co.elastic.clients.elasticsearch._types.SortOptions;
import co.elastic.clients.elasticsearch._types.SortOrder;
import co.elastic.clients.elasticsearch._types.VersionType;
import co.elastic.clients.elasticsearch._types.WaitForActiveShardOptions;
import co.elastic.clients.elasticsearch._types.mapping.FieldType;
import co.elastic.clients.elasticsearch._types.mapping.RuntimeField;
import co.elastic.clients.elasticsearch._types.mapping.RuntimeFieldType;
Expand Down Expand Up @@ -81,7 +89,6 @@
import org.springframework.data.elasticsearch.core.mapping.ElasticsearchPersistentProperty;
import org.springframework.data.elasticsearch.core.mapping.IndexCoordinates;
import org.springframework.data.elasticsearch.core.query.*;
import org.springframework.data.elasticsearch.core.query.IndicesOptions;
import org.springframework.data.elasticsearch.core.reindex.ReindexRequest;
import org.springframework.data.elasticsearch.core.reindex.Remote;
import org.springframework.data.elasticsearch.core.script.Script;
Expand Down Expand Up @@ -1226,8 +1233,7 @@ public MsearchRequest searchMsearchRequest(
}

if (!isEmpty(query.getSearchAfter())) {
bb.searchAfter(query.getSearchAfter().stream().map(it -> FieldValue.of(it.toString()))
.collect(Collectors.toList()));
bb.searchAfter(query.getSearchAfter().stream().map(TypeUtils::toFieldValue).toList());
}

query.getRescorerQueries().forEach(rescorerQuery -> bb.rescore(getRescore(rescorerQuery)));
Expand Down Expand Up @@ -1391,8 +1397,7 @@ private <T> void prepareSearchRequest(Query query, @Nullable String routing, @Nu
}

if (!isEmpty(query.getSearchAfter())) {
builder.searchAfter(
query.getSearchAfter().stream().map(it -> FieldValue.of(it.toString())).collect(Collectors.toList()));
builder.searchAfter(query.getSearchAfter().stream().map(TypeUtils::toFieldValue).toList());
}

query.getRescorerQueries().forEach(rescorerQuery -> builder.rescore(getRescore(rescorerQuery)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,40 @@ static Object toObject(@Nullable FieldValue fieldValue) {
}
}

@Nullable
static FieldValue toFieldValue(@Nullable Object fieldValue) {

if (fieldValue == null) {
return FieldValue.NULL;
}

if (fieldValue instanceof Boolean b) {
return b ? FieldValue.TRUE : FieldValue.FALSE;
}

if (fieldValue instanceof String s) {
return FieldValue.of(s);
}

if (fieldValue instanceof Long l) {
return FieldValue.of(l);
}

if (fieldValue instanceof Integer i) {
return FieldValue.of((long) i);
}

if (fieldValue instanceof Double d) {
return FieldValue.of(d);
}

if (fieldValue instanceof Float f) {
return FieldValue.of((double) f);
}

return FieldValue.of(JsonData.of(fieldValue));
}

@Nullable
static GeoDistanceType geoDistanceType(GeoDistanceOrder.DistanceType distanceType) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public void before() {
@Test
@Order(java.lang.Integer.MAX_VALUE)
void cleanup() {
operations.indexOps(IndexCoordinates.of(indexNameProvider.getPrefix() + "*")).delete();
operations.indexOps(IndexCoordinates.of(indexNameProvider.getPrefix() + '*')).delete();
}

@Test // #1143
Expand All @@ -85,11 +85,11 @@ void shouldReadPagesWithSearchAfter() {
query.setSearchAfter(searchAfter);
SearchHits<Entity> searchHits = operations.search(query, Entity.class);

if (searchHits.getSearchHits().size() == 0) {
if (searchHits.getSearchHits().isEmpty()) {
break;
}
foundEntities.addAll(searchHits.stream().map(SearchHit::getContent).collect(Collectors.toList()));
searchAfter = searchHits.getSearchHit((int) (searchHits.getSearchHits().size() - 1)).getSortValues();
foundEntities.addAll(searchHits.stream().map(SearchHit::getContent).toList());
searchAfter = searchHits.getSearchHit(searchHits.getSearchHits().size() - 1).getSortValues();

if (++loop > 10) {
fail("loop not terminating");
Expand All @@ -99,16 +99,69 @@ void shouldReadPagesWithSearchAfter() {
assertThat(foundEntities).containsExactlyElementsOf(entities);
}

@Test // #2678
@DisplayName("should be able to handle different search after type values including null")
void shouldBeAbleToHandleDifferentSearchAfterTypeValuesIncludingNull() {

List<Entity> entities = IntStream.rangeClosed(1, 10)
.mapToObj(i -> {
var message = (i % 2 == 0) ? null : "message " + i;
var value = (i % 3 == 0) ? null : (long) i;
return new Entity((long) i, message, value);
})
.collect(Collectors.toList());
operations.save(entities);

Query query = Query.findAll();
query.setPageable(PageRequest.of(0, 3));
query.addSort(Sort.by(Sort.Direction.ASC, "id"));
query.addSort(Sort.by(Sort.Direction.ASC, "keyword"));
query.addSort(Sort.by(Sort.Direction.ASC, "value"));

List<Object> searchAfter = null;
List<Entity> foundEntities = new ArrayList<>();

int loop = 0;
do {
query.setSearchAfter(searchAfter);
SearchHits<Entity> searchHits = operations.search(query, Entity.class);

if (searchHits.getSearchHits().isEmpty()) {
break;
}
foundEntities.addAll(searchHits.stream().map(SearchHit::getContent).toList());
searchAfter = searchHits.getSearchHit(searchHits.getSearchHits().size() - 1).getSortValues();

if (++loop > 10) {
fail("loop not terminating");
}
} while (true);

assertThat(foundEntities).containsExactlyElementsOf(entities);
}

@SuppressWarnings("unused")
@Document(indexName = "#{@indexNameProvider.indexName()}")
private static class Entity {
@Nullable
@Id private Long id;
@Nullable
@Field(type = FieldType.Text) private String message;
@Field(type = FieldType.Keyword) private String keyword;

@Nullable
@Field(type = FieldType.Long) private Long value;

public Entity() {}

public Entity(@Nullable Long id, @Nullable String message) {
public Entity(@Nullable Long id, @Nullable String keyword) {
this.id = id;
this.message = message;
this.keyword = keyword;
}

public Entity(@Nullable Long id, @Nullable String keyword, @Nullable Long value) {
this.id = id;
this.keyword = keyword;
this.value = value;
}

@Nullable
Expand All @@ -121,30 +174,44 @@ public void setId(@Nullable Long id) {
}

@Nullable
public String getMessage() {
return message;
public String getKeyword() {
return keyword;
}

public void setKeyword(@Nullable String keyword) {
this.keyword = keyword;
}

@Nullable
public Long getValue() {
return value;
}

public void setMessage(@Nullable String message) {
this.message = message;
public void setValue(@Nullable Long value) {
this.value = value;
}

@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (!(o instanceof Entity entity))
if (o == null || getClass() != o.getClass())
return false;

Entity entity = (Entity) o;

if (!Objects.equals(id, entity.id))
return false;
return Objects.equals(message, entity.message);
if (!Objects.equals(keyword, entity.keyword))
return false;
return Objects.equals(value, entity.value);
}

@Override
public int hashCode() {
int result = id != null ? id.hashCode() : 0;
result = 31 * result + (message != null ? message.hashCode() : 0);
result = 31 * result + (keyword != null ? keyword.hashCode() : 0);
result = 31 * result + (value != null ? value.hashCode() : 0);
return result;
}
}
Expand Down