Skip to content

Support collection parameters in @Query methods #1856

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,18 @@
*/
package org.springframework.data.elasticsearch.repository.support;

import java.util.Collection;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import org.springframework.core.convert.support.GenericConversionService;
import org.springframework.data.repository.query.ParameterAccessor;
import org.springframework.util.NumberUtils;

/**
* @author Peter-Josef Meisch
* @author Niklas Herder
*/
final public class StringQueryUtil {

Expand Down Expand Up @@ -53,6 +56,28 @@ private static String getParameterWithIndex(ParameterAccessor accessor, int inde
// noinspection ConstantConditions
if (parameter != null) {

parameterValue = convert(parameter);
}

return parameterValue;

}

private static String convert(Object parameter) {
if (Collection.class.isAssignableFrom(parameter.getClass())) {
Collection<?> collectionParam = (Collection<?>) parameter;
StringBuilder sb = new StringBuilder("[");
sb.append(collectionParam.stream().map(o -> {
if (o instanceof String) {
return "\"" + convert(o) + "\"";
} else {
return convert(o);
}
}).collect(Collectors.joining(",")));
sb.append("]");
return sb.toString();
} else {
String parameterValue = "null";
if (conversionService.canConvert(parameter.getClass(), String.class)) {
String converted = conversionService.convert(parameter, String.class);

Expand All @@ -62,11 +87,10 @@ private static String getParameterWithIndex(ParameterAccessor accessor, int inde
} else {
parameterValue = parameter.toString();
}
}

parameterValue = parameterValue.replaceAll("\"", Matcher.quoteReplacement("\\\""));
return parameterValue;

parameterValue = parameterValue.replaceAll("\"", Matcher.quoteReplacement("\\\""));
return parameterValue;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import static org.assertj.core.api.Assertions.*;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
Expand Down Expand Up @@ -51,6 +52,7 @@
/**
* @author Christoph Strobl
* @author Peter-Josef Meisch
* @author Niklas Herder
*/
@ExtendWith(MockitoExtension.class)
public class ElasticsearchStringQueryUnitTests {
Expand Down Expand Up @@ -95,14 +97,50 @@ void shouldEscapeStringsInQueryParameters() throws Exception {
.isEqualTo("{\"bool\":{\"must\": [{\"match\": {\"prefix\": {\"name\" : \"hello \\\"Stranger\\\"\"}}]}}");
}

private org.springframework.data.elasticsearch.core.query.Query createQuery(String methodName, String... args)
@Test // #1858
@DisplayName("should only quote String query parameters")
void shouldOnlyEscapeStringQueryParameters() throws Exception {
org.springframework.data.elasticsearch.core.query.Query query = createQuery("findByAge", Integer.valueOf(30));

assertThat(query).isInstanceOf(StringQuery.class);
assertThat(((StringQuery) query).getSource()).isEqualTo("{ 'bool' : { 'must' : { 'term' : { 'age' : 30 } } } }");

}

@Test // #1858
@DisplayName("should only quote String collection query parameters")
void shouldOnlyEscapeStringCollectionQueryParameters() throws Exception {
org.springframework.data.elasticsearch.core.query.Query query = createQuery("findByAgeIn",
new ArrayList<>(Arrays.asList(30, 35, 40)));

assertThat(query).isInstanceOf(StringQuery.class);
assertThat(((StringQuery) query).getSource())
.isEqualTo("{ 'bool' : { 'must' : { 'term' : { 'age' : [30,35,40] } } } }");

}

@Test // #1858
@DisplayName("should escape Strings in collection query parameters")
void shouldEscapeStringsInCollectionsQueryParameters() throws Exception {

final List<String> another_string = Arrays.asList("hello \"Stranger\"", "Another string");
List<String> params = new ArrayList<>(another_string);
org.springframework.data.elasticsearch.core.query.Query query = createQuery("findByNameIn", params);

assertThat(query).isInstanceOf(StringQuery.class);
assertThat(((StringQuery) query).getSource()).isEqualTo(
"{ 'bool' : { 'must' : { 'terms' : { 'name' : [\"hello \\\"Stranger\\\"\",\"Another string\"] } } } }");
}

private org.springframework.data.elasticsearch.core.query.Query createQuery(String methodName, Object... args)
throws NoSuchMethodException {

Class<?>[] argTypes = Arrays.stream(args).map(Object::getClass).toArray(Class[]::new);
ElasticsearchQueryMethod queryMethod = getQueryMethod(methodName, argTypes);
ElasticsearchStringQuery elasticsearchStringQuery = queryForMethod(queryMethod);
return elasticsearchStringQuery.createQuery(new ElasticsearchParametersParameterAccessor(queryMethod, args));
}

private ElasticsearchStringQuery queryForMethod(ElasticsearchQueryMethod queryMethod) {
return new ElasticsearchStringQuery(queryMethod, operations, queryMethod.getAnnotatedQuery());
}
Expand All @@ -116,9 +154,18 @@ private ElasticsearchQueryMethod getQueryMethod(String name, Class<?>... paramet

private interface SampleRepository extends Repository<Person, String> {

@Query("{ 'bool' : { 'must' : { 'term' : { 'age' : ?0 } } } }")
List<Person> findByAge(Integer age);

@Query("{ 'bool' : { 'must' : { 'term' : { 'age' : ?0 } } } }")
List<Person> findByAgeIn(ArrayList<Integer> age);

@Query("{ 'bool' : { 'must' : { 'term' : { 'name' : '?0' } } } }")
Person findByName(String name);

@Query("{ 'bool' : { 'must' : { 'terms' : { 'name' : ?0 } } } }")
Person findByNameIn(ArrayList<String> names);

@Query(value = "name:(?0, ?11, ?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?0, ?1)")
Person findWithRepeatedPlaceholder(String arg0, String arg1, String arg2, String arg3, String arg4, String arg5,
String arg6, String arg7, String arg8, String arg9, String arg10, String arg11);
Expand All @@ -131,16 +178,27 @@ Person findWithRepeatedPlaceholder(String arg0, String arg1, String arg2, String
* @author Rizwan Idrees
* @author Mohsin Husen
* @author Artur Konczak
* @author Niklas Herder
*/

@Document(indexName = "test-index-person-query-unittest")
static class Person {

@Nullable public int age;
@Nullable @Id private String id;
@Nullable private String name;
@Nullable @Field(type = FieldType.Nested) private List<Car> car;
@Nullable @Field(type = FieldType.Nested, includeInParent = true) private List<Book> books;

@Nullable
public int getAge() {
return age;
}

public void setAge(int age) {
this.age = age;
}

@Nullable
public String getId() {
return id;
Expand Down