40
40
import org .springframework .data .domain .Range ;
41
41
import org .springframework .data .domain .Score ;
42
42
import org .springframework .data .domain .ScoringFunction ;
43
- import org .springframework .data .domain .Similarity ;
44
43
import org .springframework .data .domain .Sort ;
45
44
import org .springframework .data .domain .VectorScoringFunctions ;
46
45
75
74
public class JpaQueryCreator extends AbstractQueryCreator <String , JpqlQueryBuilder .Predicate > implements JpqlQueryCreator {
76
75
77
76
private static final Map <ScoringFunction , DistanceFunction > DISTANCE_FUNCTIONS = Map .of (VectorScoringFunctions .COSINE ,
78
- new DistanceFunction ("cosine_distance" , Sort .Direction .ASC ), VectorScoringFunctions .EUCLIDEAN ,
79
- new DistanceFunction ("euclidean_distance" , Sort .Direction .ASC ), VectorScoringFunctions .TAXICAB ,
80
- new DistanceFunction ("taxicab_distance" , Sort .Direction .ASC ), VectorScoringFunctions .HAMMING ,
81
- new DistanceFunction ("hamming_distance" , Sort .Direction .ASC ), VectorScoringFunctions .INNER_PRODUCT ,
82
- new DistanceFunction ("negative_inner_product" , Sort .Direction .DESC ));
77
+ new DistanceFunction ("cosine_distance" , Sort .Direction .ASC ), //
78
+ VectorScoringFunctions .EUCLIDEAN , new DistanceFunction ("euclidean_distance" , Sort .Direction .ASC ), //
79
+ VectorScoringFunctions .TAXICAB , new DistanceFunction ("taxicab_distance" , Sort .Direction .ASC ), //
80
+ VectorScoringFunctions .HAMMING , new DistanceFunction ("hamming_distance" , Sort .Direction .ASC ), //
81
+ VectorScoringFunctions .INNER_PRODUCT , new DistanceFunction ("negative_inner_product" , Sort .Direction .ASC ), //
82
+
83
+ // TODO: Do we need both, dot and inner product? Aren't these the same in some sense?
84
+ VectorScoringFunctions .DOT , new DistanceFunction ("negative_inner_product" , Sort .Direction .ASC ));
83
85
84
86
record DistanceFunction (String distanceFunction , Sort .Direction direction ) {
85
87
@@ -94,6 +96,7 @@ record DistanceFunction(String distanceFunction, Sort.Direction direction) {
94
96
private final EntityType <?> entityType ;
95
97
private final JpqlQueryBuilder .Entity entity ;
96
98
private final Metamodel metamodel ;
99
+ private final SimilarityNormalizer similarityNormalizer ;
97
100
private final boolean useNamedParameters ;
98
101
99
102
/**
@@ -147,6 +150,7 @@ public JpaQueryCreator(PartTree tree, boolean searchQuery, ReturnedType type, Pa
147
150
this .entityType = metamodel .entity (type .getDomainType ());
148
151
this .entity = JpqlQueryBuilder .entity (returnedType .getDomainType ());
149
152
this .metamodel = metamodel ;
153
+ this .similarityNormalizer = provider .getSimilarityNormalizer ();
150
154
}
151
155
152
156
Bindable <?> getFrom () {
@@ -405,29 +409,31 @@ JpqlQueryBuilder.Expression placeholder(ParameterBinding binding) {
405
409
* @return
406
410
*/
407
411
private JpqlQueryBuilder .Predicate toPredicate (Part part ) {
408
- return new PredicateBuilder (part ).build ();
412
+ return new PredicateBuilder (part , similarityNormalizer ).build ();
409
413
}
410
414
411
415
/**
412
416
* Simple builder to contain logic to create JPA {@link Predicate}s from {@link Part}s.
413
417
*
414
418
* @author Phil Webb
415
419
* @author Oliver Gierke
420
+ * @author Mark Paluch
416
421
*/
417
422
private class PredicateBuilder {
418
423
419
424
private final Part part ;
425
+ private final SimilarityNormalizer normalizer ;
420
426
421
427
/**
422
428
* Creates a new {@link PredicateBuilder} for the given {@link Part}.
423
429
*
424
430
* @param part must not be {@literal null}.
431
+ * @param normalizer must not be {@literal null}.
425
432
*/
426
- public PredicateBuilder (Part part ) {
427
-
428
- Assert .notNull (part , "Part must not be null" );
433
+ public PredicateBuilder (Part part , SimilarityNormalizer normalizer ) {
429
434
430
435
this .part = part ;
436
+ this .normalizer = normalizer ;
431
437
}
432
438
433
439
/**
@@ -537,24 +543,17 @@ public JpqlQueryBuilder.Predicate build() {
537
543
538
544
JpqlQueryBuilder .Predicate lowerPredicate = null ;
539
545
JpqlQueryBuilder .Predicate upperPredicate = null ;
540
- if (lower .isBounded ()) {
541
-
542
- JpqlQueryBuilder .Expression distanceValue = JpqlQueryBuilder
543
- .expression ("" + lower .getValue ().get ().getValue ());
544
-
545
- where = JpqlQueryBuilder .where (distance );
546
546
547
- lowerPredicate = lower .isInclusive () ? where .gte (distanceValue ) : where .gt (distanceValue );
547
+ // Score is a distance function, you typically want less when you specify a lower boundary,
548
+ // therefore lower and upper predicates are inverted.
549
+ if (lower .isBounded ()) {
550
+ JpqlQueryBuilder .Expression distanceValue = placeholder (provider .lower (within , normalizer ));
551
+ lowerPredicate = getUpperPredicate (lower .isInclusive (), distance , distanceValue );
548
552
}
549
553
550
554
if (upper .isBounded ()) {
551
-
552
- JpqlQueryBuilder .Expression distanceValue = JpqlQueryBuilder
553
- .expression ("" + upper .getValue ().get ().getValue ());
554
-
555
- where = JpqlQueryBuilder .where (distance );
556
-
557
- upperPredicate = upper .isInclusive () ? where .lte (distanceValue ) : where .lt (distanceValue );
555
+ JpqlQueryBuilder .Expression distanceValue = placeholder (provider .upper (within , normalizer ));
556
+ upperPredicate = getLowerPredicate (upper .isInclusive (), distance , distanceValue );
558
557
}
559
558
560
559
if (lowerPredicate != null && upperPredicate != null ) {
@@ -570,19 +569,38 @@ public JpqlQueryBuilder.Predicate build() {
570
569
if (within .getValue () instanceof Score score ) {
571
570
572
571
String distanceFunction = getDistanceFunction (score .getFunction ());
573
- JpqlQueryBuilder .Expression distanceValue = placeholder (within );
572
+ JpqlQueryBuilder .Expression distanceValue = placeholder (provider . normalize ( within , normalizer ) );
574
573
JpqlQueryBuilder .Expression distance = JpqlQueryBuilder .function (distanceFunction , pas ,
575
574
placeholder (vector ));
576
575
577
- return score instanceof Similarity ? JpqlQueryBuilder .where (distance ).lte (distanceValue )
578
- : JpqlQueryBuilder .where (distance ).gte (distanceValue );
576
+ return getUpperPredicate (true , distance , distanceValue );
579
577
}
580
578
581
579
default :
582
580
throw new IllegalArgumentException ("Unsupported keyword " + type );
583
581
}
584
582
}
585
583
584
+ private JpqlQueryBuilder .Predicate getLowerPredicate (boolean inclusive , JpqlQueryBuilder .Expression lhs ,
585
+ JpqlQueryBuilder .Expression distance ) {
586
+ return doLower (inclusive , lhs , distance );
587
+ }
588
+
589
+ private JpqlQueryBuilder .Predicate getUpperPredicate (boolean inclusive , JpqlQueryBuilder .Expression lhs ,
590
+ JpqlQueryBuilder .Expression distance ) {
591
+ return doUpper (inclusive , lhs , distance );
592
+ }
593
+
594
+ private static JpqlQueryBuilder .Predicate doLower (boolean inclusive , JpqlQueryBuilder .Expression lhs ,
595
+ JpqlQueryBuilder .Expression distance ) {
596
+ return inclusive ? JpqlQueryBuilder .where (lhs ).gte (distance ) : JpqlQueryBuilder .where (lhs ).gt (distance );
597
+ }
598
+
599
+ private static JpqlQueryBuilder .Predicate doUpper (boolean inclusive , JpqlQueryBuilder .Expression lhs ,
600
+ JpqlQueryBuilder .Expression distance ) {
601
+ return inclusive ? JpqlQueryBuilder .where (lhs ).lte (distance ) : JpqlQueryBuilder .where (lhs ).lt (distance );
602
+ }
603
+
586
604
private static String getDistanceFunction (ScoringFunction scoringFunction ) {
587
605
588
606
DistanceFunction distanceFunction = JpaQueryCreator .DISTANCE_FUNCTIONS .get (scoringFunction );
0 commit comments