Skip to content

Commit 0d84a06

Browse files
committed
Add SimilarityNormalizer.
1 parent bc5176e commit 0d84a06

13 files changed

+612
-86
lines changed

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
<dependency>
6161
<groupId>com.github.mp911de.microbenchmark-runner</groupId>
6262
<artifactId>microbenchmark-runner-junit5</artifactId>
63-
<version>0.4.0.RELEASE</version>
63+
<version>0.5.0.RELEASE</version>
6464
<scope>test</scope>
6565
</dependency>
6666
</dependencies>

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/AbstractJpaQuery.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ protected JpaQueryExecution getExecution(JpaParametersParameterAccessor accessor
176176

177177
ReturnedType returnedType = method.getResultProcessor().withDynamicProjection(accessor).getReturnedType();
178178
return new JpaQueryExecution.SearchResultExecution(execution == null ? new SingleEntityExecution() : execution,
179-
returnedType, accessor.getScoringFunction());
179+
returnedType, accessor.getScoringFunction(), accessor.normalizeSimilarity());
180180
}
181181

182182
if (execution != null) {

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaParametersParameterAccessor.java

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,16 @@
1515
*/
1616
package org.springframework.data.jpa.repository.query;
1717

18+
import java.util.function.Function;
19+
import java.util.function.Predicate;
20+
import java.util.function.Supplier;
21+
1822
import org.jspecify.annotations.Nullable;
1923

2024
import org.springframework.data.domain.Range;
2125
import org.springframework.data.domain.Score;
2226
import org.springframework.data.domain.ScoringFunction;
27+
import org.springframework.data.domain.Similarity;
2328
import org.springframework.data.jpa.repository.query.JpaParameters.JpaParameter;
2429
import org.springframework.data.repository.query.Parameter;
2530
import org.springframework.data.repository.query.Parameters;
@@ -71,28 +76,54 @@ protected Object potentiallyUnwrap(Object parameterValue) {
7176
return parameterValue;
7277
}
7378

79+
/**
80+
* Returns the {@link ScoringFunction}.
81+
*
82+
* @return
83+
*/
7484
public ScoringFunction getScoringFunction() {
85+
return doWithScore(Score::getFunction, Score.class::isInstance, () -> ScoringFunction.UNSPECIFIED);
86+
}
87+
88+
/**
89+
* Returns whether to normalize similarities (i.e. translate the database-specific score into {@link Similarity}).
90+
*
91+
* @return
92+
*/
93+
public boolean normalizeSimilarity() {
94+
return doWithScore(it -> true, Similarity.class::isInstance, () -> false);
95+
}
96+
97+
/**
98+
* Returns the {@link ScoringFunction}.
99+
*
100+
* @return
101+
*/
102+
public <T> T doWithScore(Function<Score, T> function, Predicate<Score> scoreFilter, Supplier<T> defaultValue) {
75103

76104
Score score = getScore();
77-
if (score != null) {
78-
return score.getFunction();
105+
if (score != null && scoreFilter.test(score)) {
106+
return function.apply(score);
79107
}
80108

81109
JpaParameters parameters = getParameters();
82110
if (parameters.hasScoreRangeParameter()) {
83111

84112
Range<Score> range = getScoreRange();
85113

86-
if (range.getUpperBound().isBounded()) {
87-
return range.getUpperBound().getValue().get().getFunction();
114+
if (range != null && range.getLowerBound().isBounded()
115+
&& scoreFilter.test(range.getLowerBound().getValue().get())) {
116+
return function.apply(range.getUpperBound().getValue().get());
88117
}
89118

90-
if (range.getLowerBound().isBounded()) {
91-
return range.getLowerBound().getValue().get().getFunction();
119+
if (range != null && range.getUpperBound().isBounded()
120+
&& scoreFilter.test(range.getUpperBound().getValue().get())) {
121+
return function.apply(range.getUpperBound().getValue().get());
92122
}
123+
93124
}
94125

95-
return ScoringFunction.UNSPECIFIED;
126+
return defaultValue.get();
96127
}
97128

98129
}

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryCreator.java

Lines changed: 45 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
import org.springframework.data.domain.Range;
4141
import org.springframework.data.domain.Score;
4242
import org.springframework.data.domain.ScoringFunction;
43-
import org.springframework.data.domain.Similarity;
4443
import org.springframework.data.domain.Sort;
4544
import org.springframework.data.domain.VectorScoringFunctions;
4645
import org.springframework.data.jpa.domain.JpaSort;
@@ -73,11 +72,14 @@
7372
public class JpaQueryCreator extends AbstractQueryCreator<String, JpqlQueryBuilder.Predicate> implements JpqlQueryCreator {
7473

7574
private static final Map<ScoringFunction, DistanceFunction> DISTANCE_FUNCTIONS = Map.of(VectorScoringFunctions.COSINE,
76-
new DistanceFunction("cosine_distance", Sort.Direction.ASC), VectorScoringFunctions.EUCLIDEAN,
77-
new DistanceFunction("euclidean_distance", Sort.Direction.ASC), VectorScoringFunctions.TAXICAB,
78-
new DistanceFunction("taxicab_distance", Sort.Direction.ASC), VectorScoringFunctions.HAMMING,
79-
new DistanceFunction("hamming_distance", Sort.Direction.ASC), VectorScoringFunctions.INNER_PRODUCT,
80-
new DistanceFunction("negative_inner_product", Sort.Direction.DESC));
75+
new DistanceFunction("cosine_distance", Sort.Direction.ASC), //
76+
VectorScoringFunctions.EUCLIDEAN, new DistanceFunction("euclidean_distance", Sort.Direction.ASC), //
77+
VectorScoringFunctions.TAXICAB, new DistanceFunction("taxicab_distance", Sort.Direction.ASC), //
78+
VectorScoringFunctions.HAMMING, new DistanceFunction("hamming_distance", Sort.Direction.ASC), //
79+
VectorScoringFunctions.INNER_PRODUCT, new DistanceFunction("negative_inner_product", Sort.Direction.ASC), //
80+
81+
// TODO: Do we need both, dot and inner product? Aren't these the same in some sense?
82+
VectorScoringFunctions.DOT, new DistanceFunction("negative_inner_product", Sort.Direction.ASC));
8183

8284
record DistanceFunction(String distanceFunction, Sort.Direction direction) {
8385

@@ -92,6 +94,7 @@ record DistanceFunction(String distanceFunction, Sort.Direction direction) {
9294
private final EntityType<?> entityType;
9395
private final JpqlQueryBuilder.Entity entity;
9496
private final Metamodel metamodel;
97+
private final SimilarityNormalizer similarityNormalizer;
9598

9699
/**
97100
* Create a new {@link JpaQueryCreator}.
@@ -127,6 +130,7 @@ public JpaQueryCreator(PartTree tree, boolean searchQuery, ReturnedType type, Pa
127130
this.entityType = metamodel.entity(type.getDomainType());
128131
this.entity = JpqlQueryBuilder.entity(returnedType.getDomainType());
129132
this.metamodel = metamodel;
133+
this.similarityNormalizer = provider.getSimilarityNormalizer();
130134
}
131135

132136
Bindable<?> getFrom() {
@@ -384,29 +388,31 @@ JpqlQueryBuilder.Expression placeholder(int position) {
384388
* @return
385389
*/
386390
private JpqlQueryBuilder.Predicate toPredicate(Part part) {
387-
return new PredicateBuilder(part).build();
391+
return new PredicateBuilder(part, similarityNormalizer).build();
388392
}
389393

390394
/**
391395
* Simple builder to contain logic to create JPA {@link Predicate}s from {@link Part}s.
392396
*
393397
* @author Phil Webb
394398
* @author Oliver Gierke
399+
* @author Mark Paluch
395400
*/
396401
private class PredicateBuilder {
397402

398403
private final Part part;
404+
private final SimilarityNormalizer normalizer;
399405

400406
/**
401407
* Creates a new {@link PredicateBuilder} for the given {@link Part}.
402408
*
403409
* @param part must not be {@literal null}.
410+
* @param normalizer must not be {@literal null}.
404411
*/
405-
public PredicateBuilder(Part part) {
406-
407-
Assert.notNull(part, "Part must not be null");
412+
public PredicateBuilder(Part part, SimilarityNormalizer normalizer) {
408413

409414
this.part = part;
415+
this.normalizer = normalizer;
410416
}
411417

412418
/**
@@ -516,24 +522,17 @@ public JpqlQueryBuilder.Predicate build() {
516522

517523
JpqlQueryBuilder.Predicate lowerPredicate = null;
518524
JpqlQueryBuilder.Predicate upperPredicate = null;
519-
if (lower.isBounded()) {
520-
521-
JpqlQueryBuilder.Expression distanceValue = JpqlQueryBuilder
522-
.expression("" + lower.getValue().get().getValue());
523-
524-
where = JpqlQueryBuilder.where(distance);
525525

526-
lowerPredicate = lower.isInclusive() ? where.gte(distanceValue) : where.gt(distanceValue);
526+
// Score is a distance function, you typically want less when you specify a lower boundary,
527+
// therefore lower and upper predicates are inverted.
528+
if (lower.isBounded()) {
529+
JpqlQueryBuilder.Expression distanceValue = placeholder(provider.lower(within, normalizer));
530+
lowerPredicate = getUpperPredicate(lower.isInclusive(), distance, distanceValue);
527531
}
528532

529533
if (upper.isBounded()) {
530-
531-
JpqlQueryBuilder.Expression distanceValue = JpqlQueryBuilder
532-
.expression("" + upper.getValue().get().getValue());
533-
534-
where = JpqlQueryBuilder.where(distance);
535-
536-
upperPredicate = upper.isInclusive() ? where.lte(distanceValue) : where.lt(distanceValue);
534+
JpqlQueryBuilder.Expression distanceValue = placeholder(provider.upper(within, normalizer));
535+
upperPredicate = getLowerPredicate(upper.isInclusive(), distance, distanceValue);
537536
}
538537

539538
if (lowerPredicate != null && upperPredicate != null) {
@@ -549,19 +548,38 @@ public JpqlQueryBuilder.Predicate build() {
549548
if (within.getValue() instanceof Score score) {
550549

551550
String distanceFunction = getDistanceFunction(score.getFunction());
552-
JpqlQueryBuilder.Expression distanceValue = placeholder(within);
551+
JpqlQueryBuilder.Expression distanceValue = placeholder(provider.normalize(within, normalizer));
553552
JpqlQueryBuilder.Expression distance = JpqlQueryBuilder.function(distanceFunction, pas,
554553
placeholder(vector));
555554

556-
return score instanceof Similarity ? JpqlQueryBuilder.where(distance).lte(distanceValue)
557-
: JpqlQueryBuilder.where(distance).gte(distanceValue);
555+
return getUpperPredicate(true, distance, distanceValue);
558556
}
559557

560558
default:
561559
throw new IllegalArgumentException("Unsupported keyword " + type);
562560
}
563561
}
564562

563+
private JpqlQueryBuilder.Predicate getLowerPredicate(boolean inclusive, JpqlQueryBuilder.Expression lhs,
564+
JpqlQueryBuilder.Expression distance) {
565+
return doLower(inclusive, lhs, distance);
566+
}
567+
568+
private JpqlQueryBuilder.Predicate getUpperPredicate(boolean inclusive, JpqlQueryBuilder.Expression lhs,
569+
JpqlQueryBuilder.Expression distance) {
570+
return doUpper(inclusive, lhs, distance);
571+
}
572+
573+
private static JpqlQueryBuilder.Predicate doLower(boolean inclusive, JpqlQueryBuilder.Expression lhs,
574+
JpqlQueryBuilder.Expression distance) {
575+
return inclusive ? JpqlQueryBuilder.where(lhs).gte(distance) : JpqlQueryBuilder.where(lhs).gt(distance);
576+
}
577+
578+
private static JpqlQueryBuilder.Predicate doUpper(boolean inclusive, JpqlQueryBuilder.Expression lhs,
579+
JpqlQueryBuilder.Expression distance) {
580+
return inclusive ? JpqlQueryBuilder.where(lhs).lte(distance) : JpqlQueryBuilder.where(lhs).lt(distance);
581+
}
582+
565583
private static String getDistanceFunction(ScoringFunction scoringFunction) {
566584

567585
DistanceFunction distanceFunction = JpaQueryCreator.DISTANCE_FUNCTIONS.get(scoringFunction);

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryExecution.java

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.springframework.data.domain.ScrollPosition;
4040
import org.springframework.data.domain.SearchResult;
4141
import org.springframework.data.domain.SearchResults;
42+
import org.springframework.data.domain.Similarity;
4243
import org.springframework.data.domain.Slice;
4344
import org.springframework.data.domain.SliceImpl;
4445
import org.springframework.data.domain.Sort;
@@ -135,11 +136,17 @@ static class SearchResultExecution extends JpaQueryExecution {
135136
private final JpaQueryExecution delegate;
136137
private final ReturnedType returnedType;
137138
private final ScoringFunction function;
139+
private final boolean normalizeSimilarity;
140+
private final SimilarityNormalizer normalizer;
141+
142+
SearchResultExecution(JpaQueryExecution delegate, ReturnedType returnedType, ScoringFunction function,
143+
boolean normalizeSimilarity) {
138144

139-
SearchResultExecution(JpaQueryExecution delegate, ReturnedType returnedType, ScoringFunction function) {
140145
this.delegate = delegate;
141146
this.returnedType = returnedType;
142147
this.function = function;
148+
this.normalizeSimilarity = normalizeSimilarity;
149+
this.normalizer = normalizeSimilarity ? SimilarityNormalizer.get(function) : SimilarityNormalizer.IDENTITY;
143150
}
144151

145152
@Override
@@ -171,26 +178,31 @@ static class SearchResultExecution extends JpaQueryExecution {
171178

172179
Object value = returnedType.needsCustomConstruction() ? t : t.get(0);
173180
try {
174-
return new SearchResult<>(value, Score.of(t.get("distance", Number.class).doubleValue(), function));
181+
return new SearchResult<>(value, getScore(t.get("distance", Number.class).doubleValue()));
175182
} catch (RuntimeException e) {
176-
return new SearchResult<>(value, Score.of(0, function));
183+
return new SearchResult<>(value, getScore(0));
177184
}
178185
}
179186

180187
if (result instanceof Object[] objects) {
181188

182189
Object value = returnedType.needsCustomConstruction() ? objects : objects[0];
183-
184190
try {
185191

186-
return new SearchResult<>(value, Score.of(((Number) (objects[objects.length - 1])).doubleValue(), function));
192+
return new SearchResult<>(value, getScore(((Number) (objects[objects.length - 1])).doubleValue()));
187193
} catch (RuntimeException e) {
188-
return new SearchResult<>(value, Score.of(0, function));
194+
return new SearchResult<>(value, getScore(0));
189195
}
190196
}
191197

192198
return null;
193199
}
200+
201+
private Score getScore(double score) {
202+
return normalizeSimilarity ? Similarity.raw(normalizer.getSimilarity(score), function)
203+
: Score.of(score, function);
204+
}
205+
194206
}
195207

196208
/**

0 commit comments

Comments
 (0)