40
40
import org .springframework .data .domain .Range ;
41
41
import org .springframework .data .domain .Score ;
42
42
import org .springframework .data .domain .ScoringFunction ;
43
- import org .springframework .data .domain .Similarity ;
44
43
import org .springframework .data .domain .Sort ;
45
44
import org .springframework .data .domain .VectorScoringFunctions ;
46
45
import org .springframework .data .jpa .domain .JpaSort ;
73
72
public class JpaQueryCreator extends AbstractQueryCreator <String , JpqlQueryBuilder .Predicate > implements JpqlQueryCreator {
74
73
75
74
private static final Map <ScoringFunction , DistanceFunction > DISTANCE_FUNCTIONS = Map .of (VectorScoringFunctions .COSINE ,
76
- new DistanceFunction ("cosine_distance" , Sort .Direction .ASC ), VectorScoringFunctions .EUCLIDEAN ,
77
- new DistanceFunction ("euclidean_distance" , Sort .Direction .ASC ), VectorScoringFunctions .TAXICAB ,
78
- new DistanceFunction ("taxicab_distance" , Sort .Direction .ASC ), VectorScoringFunctions .HAMMING ,
79
- new DistanceFunction ("hamming_distance" , Sort .Direction .ASC ), VectorScoringFunctions .INNER_PRODUCT ,
80
- new DistanceFunction ("negative_inner_product" , Sort .Direction .DESC ));
75
+ new DistanceFunction ("cosine_distance" , Sort .Direction .ASC ), //
76
+ VectorScoringFunctions .EUCLIDEAN , new DistanceFunction ("euclidean_distance" , Sort .Direction .ASC ), //
77
+ VectorScoringFunctions .TAXICAB , new DistanceFunction ("taxicab_distance" , Sort .Direction .ASC ), //
78
+ VectorScoringFunctions .HAMMING , new DistanceFunction ("hamming_distance" , Sort .Direction .ASC ), //
79
+ VectorScoringFunctions .INNER_PRODUCT , new DistanceFunction ("negative_inner_product" , Sort .Direction .ASC ), //
80
+
81
+ // TODO: Do we need both, dot and inner product? Aren't these the same in some sense?
82
+ VectorScoringFunctions .DOT , new DistanceFunction ("negative_inner_product" , Sort .Direction .ASC ));
81
83
82
84
record DistanceFunction (String distanceFunction , Sort .Direction direction ) {
83
85
@@ -92,6 +94,7 @@ record DistanceFunction(String distanceFunction, Sort.Direction direction) {
92
94
private final EntityType <?> entityType ;
93
95
private final JpqlQueryBuilder .Entity entity ;
94
96
private final Metamodel metamodel ;
97
+ private final SimilarityNormalizer similarityNormalizer ;
95
98
96
99
/**
97
100
* Create a new {@link JpaQueryCreator}.
@@ -127,6 +130,7 @@ public JpaQueryCreator(PartTree tree, boolean searchQuery, ReturnedType type, Pa
127
130
this .entityType = metamodel .entity (type .getDomainType ());
128
131
this .entity = JpqlQueryBuilder .entity (returnedType .getDomainType ());
129
132
this .metamodel = metamodel ;
133
+ this .similarityNormalizer = provider .getSimilarityNormalizer ();
130
134
}
131
135
132
136
Bindable <?> getFrom () {
@@ -384,29 +388,31 @@ JpqlQueryBuilder.Expression placeholder(int position) {
384
388
* @return
385
389
*/
386
390
private JpqlQueryBuilder .Predicate toPredicate (Part part ) {
387
- return new PredicateBuilder (part ).build ();
391
+ return new PredicateBuilder (part , similarityNormalizer ).build ();
388
392
}
389
393
390
394
/**
391
395
* Simple builder to contain logic to create JPA {@link Predicate}s from {@link Part}s.
392
396
*
393
397
* @author Phil Webb
394
398
* @author Oliver Gierke
399
+ * @author Mark Paluch
395
400
*/
396
401
private class PredicateBuilder {
397
402
398
403
private final Part part ;
404
+ private final SimilarityNormalizer normalizer ;
399
405
400
406
/**
401
407
* Creates a new {@link PredicateBuilder} for the given {@link Part}.
402
408
*
403
409
* @param part must not be {@literal null}.
410
+ * @param normalizer must not be {@literal null}.
404
411
*/
405
- public PredicateBuilder (Part part ) {
406
-
407
- Assert .notNull (part , "Part must not be null" );
412
+ public PredicateBuilder (Part part , SimilarityNormalizer normalizer ) {
408
413
409
414
this .part = part ;
415
+ this .normalizer = normalizer ;
410
416
}
411
417
412
418
/**
@@ -516,24 +522,17 @@ public JpqlQueryBuilder.Predicate build() {
516
522
517
523
JpqlQueryBuilder .Predicate lowerPredicate = null ;
518
524
JpqlQueryBuilder .Predicate upperPredicate = null ;
519
- if (lower .isBounded ()) {
520
-
521
- JpqlQueryBuilder .Expression distanceValue = JpqlQueryBuilder
522
- .expression ("" + lower .getValue ().get ().getValue ());
523
-
524
- where = JpqlQueryBuilder .where (distance );
525
525
526
- lowerPredicate = lower .isInclusive () ? where .gte (distanceValue ) : where .gt (distanceValue );
526
+ // Score is a distance function, you typically want less when you specify a lower boundary,
527
+ // therefore lower and upper predicates are inverted.
528
+ if (lower .isBounded ()) {
529
+ JpqlQueryBuilder .Expression distanceValue = placeholder (provider .lower (within , normalizer ));
530
+ lowerPredicate = getUpperPredicate (lower .isInclusive (), distance , distanceValue );
527
531
}
528
532
529
533
if (upper .isBounded ()) {
530
-
531
- JpqlQueryBuilder .Expression distanceValue = JpqlQueryBuilder
532
- .expression ("" + upper .getValue ().get ().getValue ());
533
-
534
- where = JpqlQueryBuilder .where (distance );
535
-
536
- upperPredicate = upper .isInclusive () ? where .lte (distanceValue ) : where .lt (distanceValue );
534
+ JpqlQueryBuilder .Expression distanceValue = placeholder (provider .upper (within , normalizer ));
535
+ upperPredicate = getLowerPredicate (upper .isInclusive (), distance , distanceValue );
537
536
}
538
537
539
538
if (lowerPredicate != null && upperPredicate != null ) {
@@ -549,19 +548,38 @@ public JpqlQueryBuilder.Predicate build() {
549
548
if (within .getValue () instanceof Score score ) {
550
549
551
550
String distanceFunction = getDistanceFunction (score .getFunction ());
552
- JpqlQueryBuilder .Expression distanceValue = placeholder (within );
551
+ JpqlQueryBuilder .Expression distanceValue = placeholder (provider . normalize ( within , normalizer ) );
553
552
JpqlQueryBuilder .Expression distance = JpqlQueryBuilder .function (distanceFunction , pas ,
554
553
placeholder (vector ));
555
554
556
- return score instanceof Similarity ? JpqlQueryBuilder .where (distance ).lte (distanceValue )
557
- : JpqlQueryBuilder .where (distance ).gte (distanceValue );
555
+ return getUpperPredicate (true , distance , distanceValue );
558
556
}
559
557
560
558
default :
561
559
throw new IllegalArgumentException ("Unsupported keyword " + type );
562
560
}
563
561
}
564
562
563
+ private JpqlQueryBuilder .Predicate getLowerPredicate (boolean inclusive , JpqlQueryBuilder .Expression lhs ,
564
+ JpqlQueryBuilder .Expression distance ) {
565
+ return doLower (inclusive , lhs , distance );
566
+ }
567
+
568
+ private JpqlQueryBuilder .Predicate getUpperPredicate (boolean inclusive , JpqlQueryBuilder .Expression lhs ,
569
+ JpqlQueryBuilder .Expression distance ) {
570
+ return doUpper (inclusive , lhs , distance );
571
+ }
572
+
573
+ private static JpqlQueryBuilder .Predicate doLower (boolean inclusive , JpqlQueryBuilder .Expression lhs ,
574
+ JpqlQueryBuilder .Expression distance ) {
575
+ return inclusive ? JpqlQueryBuilder .where (lhs ).gte (distance ) : JpqlQueryBuilder .where (lhs ).gt (distance );
576
+ }
577
+
578
+ private static JpqlQueryBuilder .Predicate doUpper (boolean inclusive , JpqlQueryBuilder .Expression lhs ,
579
+ JpqlQueryBuilder .Expression distance ) {
580
+ return inclusive ? JpqlQueryBuilder .where (lhs ).lte (distance ) : JpqlQueryBuilder .where (lhs ).lt (distance );
581
+ }
582
+
565
583
private static String getDistanceFunction (ScoringFunction scoringFunction ) {
566
584
567
585
DistanceFunction distanceFunction = JpaQueryCreator .DISTANCE_FUNCTIONS .get (scoringFunction );
0 commit comments