Skip to content

Commit 55e6483

Browse files
committed
Add SimilarityNormalizer.
1 parent 3c4f4b1 commit 55e6483

13 files changed

+618
-84
lines changed

pom.xml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,14 @@
5656
<profiles>
5757
<profile>
5858
<id>jmh</id>
59+
<dependencies>
60+
<dependency>
61+
<groupId>com.github.mp911de.microbenchmark-runner</groupId>
62+
<artifactId>microbenchmark-runner-junit5</artifactId>
63+
<version>0.5.0.RELEASE</version>
64+
<scope>test</scope>
65+
</dependency>
66+
</dependencies>
5967
<repositories>
6068
<repository>
6169
<id>jitpack</id>

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/AbstractJpaQuery.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ protected JpaQueryExecution getExecution(JpaParametersParameterAccessor accessor
176176

177177
ReturnedType returnedType = method.getResultProcessor().withDynamicProjection(accessor).getReturnedType();
178178
return new JpaQueryExecution.SearchResultExecution(execution == null ? new SingleEntityExecution() : execution,
179-
returnedType, accessor.getScoringFunction());
179+
returnedType, accessor.getScoringFunction(), accessor.normalizeSimilarity());
180180
}
181181

182182
if (execution != null) {

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaParametersParameterAccessor.java

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,16 @@
1515
*/
1616
package org.springframework.data.jpa.repository.query;
1717

18+
import java.util.function.Function;
19+
import java.util.function.Predicate;
20+
import java.util.function.Supplier;
21+
1822
import org.jspecify.annotations.Nullable;
1923

2024
import org.springframework.data.domain.Range;
2125
import org.springframework.data.domain.Score;
2226
import org.springframework.data.domain.ScoringFunction;
27+
import org.springframework.data.domain.Similarity;
2328
import org.springframework.data.jpa.repository.query.JpaParameters.JpaParameter;
2429
import org.springframework.data.repository.query.Parameter;
2530
import org.springframework.data.repository.query.Parameters;
@@ -71,28 +76,54 @@ protected Object potentiallyUnwrap(Object parameterValue) {
7176
return parameterValue;
7277
}
7378

79+
/**
80+
* Returns the {@link ScoringFunction}.
81+
*
82+
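* @return the {@link ScoringFunction} of the score argument (or of a score range bound), or {@link ScoringFunction#UNSPECIFIED} if none is available.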
83+
*/
7484
public ScoringFunction getScoringFunction() {
85+
return doWithScore(Score::getFunction, Score.class::isInstance, () -> ScoringFunction.UNSPECIFIED);
86+
}
87+
88+
/**
89+
* Returns whether to normalize similarities (i.e. translate the database-specific score into {@link Similarity}).
90+
*
91+
* @return
92+
*/
93+
public boolean normalizeSimilarity() {
94+
return doWithScore(it -> true, Similarity.class::isInstance, () -> false);
95+
}
96+
97+
/**
98+
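* Applies the given {@link Function} to the {@link Score} argument, or to a bound of the score {@link Range}, that matches {@code scoreFilter}; falls back to {@code defaultValue} if no score matches.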
99+
*
100+
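* @return the result of {@code function} for the matching score, or the value supplied by {@code defaultValue}.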
101+
*/
102+
public <T> T doWithScore(Function<Score, T> function, Predicate<Score> scoreFilter, Supplier<T> defaultValue) {
75103

76104
Score score = getScore();
77-
if (score != null) {
78-
return score.getFunction();
105+
if (score != null && scoreFilter.test(score)) {
106+
return function.apply(score);
79107
}
80108

81109
JpaParameters parameters = getParameters();
82110
if (parameters.hasScoreRangeParameter()) {
83111

84112
Range<Score> range = getScoreRange();
85113

86-
if (range.getUpperBound().isBounded()) {
87-
return range.getUpperBound().getValue().get().getFunction();
114+
// Apply the function to the bound that was actually tested. The upper bound may be
// unbounded in this branch, in which case reading its value would throw.
if (range != null && range.getLowerBound().isBounded()
115+
&& scoreFilter.test(range.getLowerBound().getValue().get())) {
116+
return function.apply(range.getLowerBound().getValue().get());
88117
}
89118

90-
if (range.getLowerBound().isBounded()) {
91-
return range.getLowerBound().getValue().get().getFunction();
119+
if (range != null && range.getUpperBound().isBounded()
120+
&& scoreFilter.test(range.getUpperBound().getValue().get())) {
121+
return function.apply(range.getUpperBound().getValue().get());
92122
}
123+
93124
}
94125

95-
return ScoringFunction.UNSPECIFIED;
126+
return defaultValue.get();
96127
}
97128

98129
}

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryCreator.java

Lines changed: 45 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
import org.springframework.data.domain.Range;
4141
import org.springframework.data.domain.Score;
4242
import org.springframework.data.domain.ScoringFunction;
43-
import org.springframework.data.domain.Similarity;
4443
import org.springframework.data.domain.Sort;
4544
import org.springframework.data.domain.VectorScoringFunctions;
4645

@@ -75,11 +74,14 @@
7574
public class JpaQueryCreator extends AbstractQueryCreator<String, JpqlQueryBuilder.Predicate> implements JpqlQueryCreator {
7675

7776
private static final Map<ScoringFunction, DistanceFunction> DISTANCE_FUNCTIONS = Map.of(VectorScoringFunctions.COSINE,
78-
new DistanceFunction("cosine_distance", Sort.Direction.ASC), VectorScoringFunctions.EUCLIDEAN,
79-
new DistanceFunction("euclidean_distance", Sort.Direction.ASC), VectorScoringFunctions.TAXICAB,
80-
new DistanceFunction("taxicab_distance", Sort.Direction.ASC), VectorScoringFunctions.HAMMING,
81-
new DistanceFunction("hamming_distance", Sort.Direction.ASC), VectorScoringFunctions.INNER_PRODUCT,
82-
new DistanceFunction("negative_inner_product", Sort.Direction.DESC));
77+
new DistanceFunction("cosine_distance", Sort.Direction.ASC), //
78+
VectorScoringFunctions.EUCLIDEAN, new DistanceFunction("euclidean_distance", Sort.Direction.ASC), //
79+
VectorScoringFunctions.TAXICAB, new DistanceFunction("taxicab_distance", Sort.Direction.ASC), //
80+
VectorScoringFunctions.HAMMING, new DistanceFunction("hamming_distance", Sort.Direction.ASC), //
81+
VectorScoringFunctions.INNER_PRODUCT, new DistanceFunction("negative_inner_product", Sort.Direction.ASC), //
82+
83+
// TODO: Do we need both, dot and inner product? Aren't these the same in some sense?
84+
VectorScoringFunctions.DOT, new DistanceFunction("negative_inner_product", Sort.Direction.ASC));
8385

8486
record DistanceFunction(String distanceFunction, Sort.Direction direction) {
8587

@@ -94,6 +96,7 @@ record DistanceFunction(String distanceFunction, Sort.Direction direction) {
9496
private final EntityType<?> entityType;
9597
private final JpqlQueryBuilder.Entity entity;
9698
private final Metamodel metamodel;
99+
private final SimilarityNormalizer similarityNormalizer;
97100
private final boolean useNamedParameters;
98101

99102
/**
@@ -147,6 +150,7 @@ public JpaQueryCreator(PartTree tree, boolean searchQuery, ReturnedType type, Pa
147150
this.entityType = metamodel.entity(type.getDomainType());
148151
this.entity = JpqlQueryBuilder.entity(returnedType.getDomainType());
149152
this.metamodel = metamodel;
153+
this.similarityNormalizer = provider.getSimilarityNormalizer();
150154
}
151155

152156
Bindable<?> getFrom() {
@@ -405,29 +409,31 @@ JpqlQueryBuilder.Expression placeholder(ParameterBinding binding) {
405409
* @return
406410
*/
407411
private JpqlQueryBuilder.Predicate toPredicate(Part part) {
408-
return new PredicateBuilder(part).build();
412+
return new PredicateBuilder(part, similarityNormalizer).build();
409413
}
410414

411415
/**
412416
* Simple builder to contain logic to create JPA {@link Predicate}s from {@link Part}s.
413417
*
414418
* @author Phil Webb
415419
* @author Oliver Gierke
420+
* @author Mark Paluch
416421
*/
417422
private class PredicateBuilder {
418423

419424
private final Part part;
425+
private final SimilarityNormalizer normalizer;
420426

421427
/**
422428
* Creates a new {@link PredicateBuilder} for the given {@link Part}.
423429
*
424430
* @param part must not be {@literal null}.
431+
* @param normalizer must not be {@literal null}.
425432
*/
426-
public PredicateBuilder(Part part) {
427-
428-
Assert.notNull(part, "Part must not be null");
433+
public PredicateBuilder(Part part, SimilarityNormalizer normalizer) {
429434

430435
this.part = part;
436+
this.normalizer = normalizer;
431437
}
432438

433439
/**
@@ -537,24 +543,17 @@ public JpqlQueryBuilder.Predicate build() {
537543

538544
JpqlQueryBuilder.Predicate lowerPredicate = null;
539545
JpqlQueryBuilder.Predicate upperPredicate = null;
540-
if (lower.isBounded()) {
541-
542-
JpqlQueryBuilder.Expression distanceValue = JpqlQueryBuilder
543-
.expression("" + lower.getValue().get().getValue());
544-
545-
where = JpqlQueryBuilder.where(distance);
546546

547-
lowerPredicate = lower.isInclusive() ? where.gte(distanceValue) : where.gt(distanceValue);
547+
// Score is a distance function, you typically want less when you specify a lower boundary,
548+
// therefore lower and upper predicates are inverted.
549+
if (lower.isBounded()) {
550+
JpqlQueryBuilder.Expression distanceValue = placeholder(provider.lower(within, normalizer));
551+
lowerPredicate = getUpperPredicate(lower.isInclusive(), distance, distanceValue);
548552
}
549553

550554
if (upper.isBounded()) {
551-
552-
JpqlQueryBuilder.Expression distanceValue = JpqlQueryBuilder
553-
.expression("" + upper.getValue().get().getValue());
554-
555-
where = JpqlQueryBuilder.where(distance);
556-
557-
upperPredicate = upper.isInclusive() ? where.lte(distanceValue) : where.lt(distanceValue);
555+
JpqlQueryBuilder.Expression distanceValue = placeholder(provider.upper(within, normalizer));
556+
upperPredicate = getLowerPredicate(upper.isInclusive(), distance, distanceValue);
558557
}
559558

560559
if (lowerPredicate != null && upperPredicate != null) {
@@ -570,19 +569,38 @@ public JpqlQueryBuilder.Predicate build() {
570569
if (within.getValue() instanceof Score score) {
571570

572571
String distanceFunction = getDistanceFunction(score.getFunction());
573-
JpqlQueryBuilder.Expression distanceValue = placeholder(within);
572+
JpqlQueryBuilder.Expression distanceValue = placeholder(provider.normalize(within, normalizer));
574573
JpqlQueryBuilder.Expression distance = JpqlQueryBuilder.function(distanceFunction, pas,
575574
placeholder(vector));
576575

577-
return score instanceof Similarity ? JpqlQueryBuilder.where(distance).lte(distanceValue)
578-
: JpqlQueryBuilder.where(distance).gte(distanceValue);
576+
return getUpperPredicate(true, distance, distanceValue);
579577
}
580578

581579
default:
582580
throw new IllegalArgumentException("Unsupported keyword " + type);
583581
}
584582
}
585583

584+
private JpqlQueryBuilder.Predicate getLowerPredicate(boolean inclusive, JpqlQueryBuilder.Expression lhs,
585+
JpqlQueryBuilder.Expression distance) {
586+
return doLower(inclusive, lhs, distance);
587+
}
588+
589+
private JpqlQueryBuilder.Predicate getUpperPredicate(boolean inclusive, JpqlQueryBuilder.Expression lhs,
590+
JpqlQueryBuilder.Expression distance) {
591+
return doUpper(inclusive, lhs, distance);
592+
}
593+
594+
private static JpqlQueryBuilder.Predicate doLower(boolean inclusive, JpqlQueryBuilder.Expression lhs,
595+
JpqlQueryBuilder.Expression distance) {
596+
return inclusive ? JpqlQueryBuilder.where(lhs).gte(distance) : JpqlQueryBuilder.where(lhs).gt(distance);
597+
}
598+
599+
private static JpqlQueryBuilder.Predicate doUpper(boolean inclusive, JpqlQueryBuilder.Expression lhs,
600+
JpqlQueryBuilder.Expression distance) {
601+
return inclusive ? JpqlQueryBuilder.where(lhs).lte(distance) : JpqlQueryBuilder.where(lhs).lt(distance);
602+
}
603+
586604
private static String getDistanceFunction(ScoringFunction scoringFunction) {
587605

588606
DistanceFunction distanceFunction = JpaQueryCreator.DISTANCE_FUNCTIONS.get(scoringFunction);

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryExecution.java

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.springframework.data.domain.ScrollPosition;
4040
import org.springframework.data.domain.SearchResult;
4141
import org.springframework.data.domain.SearchResults;
42+
import org.springframework.data.domain.Similarity;
4243
import org.springframework.data.domain.Slice;
4344
import org.springframework.data.domain.SliceImpl;
4445
import org.springframework.data.domain.Sort;
@@ -135,11 +136,17 @@ static class SearchResultExecution extends JpaQueryExecution {
135136
private final JpaQueryExecution delegate;
136137
private final ReturnedType returnedType;
137138
private final ScoringFunction function;
139+
private final boolean normalizeSimilarity;
140+
private final SimilarityNormalizer normalizer;
141+
142+
SearchResultExecution(JpaQueryExecution delegate, ReturnedType returnedType, ScoringFunction function,
143+
boolean normalizeSimilarity) {
138144

139-
SearchResultExecution(JpaQueryExecution delegate, ReturnedType returnedType, ScoringFunction function) {
140145
this.delegate = delegate;
141146
this.returnedType = returnedType;
142147
this.function = function;
148+
this.normalizeSimilarity = normalizeSimilarity;
149+
this.normalizer = normalizeSimilarity ? SimilarityNormalizer.get(function) : SimilarityNormalizer.IDENTITY;
143150
}
144151

145152
@Override
@@ -171,26 +178,31 @@ static class SearchResultExecution extends JpaQueryExecution {
171178

172179
Object value = returnedType.needsCustomConstruction() ? t : t.get(0);
173180
try {
174-
return new SearchResult<>(value, Score.of(t.get("distance", Number.class).doubleValue(), function));
181+
return new SearchResult<>(value, getScore(t.get("distance", Number.class).doubleValue()));
175182
} catch (RuntimeException e) {
176-
return new SearchResult<>(value, Score.of(0, function));
183+
return new SearchResult<>(value, getScore(0));
177184
}
178185
}
179186

180187
if (result instanceof Object[] objects) {
181188

182189
Object value = returnedType.needsCustomConstruction() ? objects : objects[0];
183-
184190
try {
185191

186-
return new SearchResult<>(value, Score.of(((Number) (objects[objects.length - 1])).doubleValue(), function));
192+
return new SearchResult<>(value, getScore(((Number) (objects[objects.length - 1])).doubleValue()));
187193
} catch (RuntimeException e) {
188-
return new SearchResult<>(value, Score.of(0, function));
194+
return new SearchResult<>(value, getScore(0));
189195
}
190196
}
191197

192198
return null;
193199
}
200+
201+
private Score getScore(double score) {
202+
return normalizeSimilarity ? Similarity.raw(normalizer.getSimilarity(score), function)
203+
: Score.of(score, function);
204+
}
205+
194206
}
195207

196208
/**

0 commit comments

Comments
 (0)