Skip to content

Commit 7689c5a

Browse files
SiyaoIsHidingabsurdfarce
authored andcommitted
JAVA-3118: Add support for vector data type in Schema Builder, QueryBuilder
patch by Jane He; reviewed by Mick Semb Wever and Bret McGuire for JAVA-3118 reference: #1931
1 parent 6a8674f commit 7689c5a

File tree

12 files changed

+233
-8
lines changed

12 files changed

+233
-8
lines changed

manual/query_builder/select/README.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,29 @@ selectFrom("sensor_data")
387387
// SELECT reading FROM sensor_data WHERE id=? ORDER BY date DESC
388388
```
389389

390+
Vector Search:
391+
392+
```java
393+
394+
import com.datastax.oss.driver.api.core.data.CqlVector;
395+
396+
selectFrom("foo")
397+
.all()
398+
.where(Relation.column("k").isEqualTo(literal(1)))
399+
.orderByAnnOf("c1", CqlVector.newInstance(0.1, 0.2, 0.3));
400+
// SELECT * FROM foo WHERE k=1 ORDER BY c1 ANN OF [0.1, 0.2, 0.3]
401+
402+
selectFrom("cycling", "comments_vs")
403+
.column("comment")
404+
.function(
405+
"similarity_cosine",
406+
Selector.column("comment_vector"),
407+
literal(CqlVector.newInstance(0.2, 0.15, 0.3, 0.2, 0.05)))
408+
.orderByAnnOf("comment_vector", CqlVector.newInstance(0.1, 0.15, 0.3, 0.12, 0.05))
409+
.limit(1);
410+
// SELECT comment,similarity_cosine(comment_vector,[0.2, 0.15, 0.3, 0.2, 0.05]) FROM cycling.comments_vs ORDER BY comment_vector ANN OF [0.1, 0.15, 0.3, 0.12, 0.05] LIMIT 1
411+
```
412+
390413
Limits:
391414

392415
```java

query-builder/revapi.json

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2772,6 +2772,16 @@
27722772
"code": "java.method.addedToInterface",
27732773
"new": "method com.datastax.oss.driver.api.querybuilder.update.UpdateStart com.datastax.oss.driver.api.querybuilder.update.UpdateStart::usingTtl(int)",
27742774
"justification": "JAVA-2210: Add ability to set TTL for modification queries"
2775+
},
2776+
{
2777+
"code": "java.method.addedToInterface",
2778+
"new": "method com.datastax.oss.driver.api.querybuilder.select.Select com.datastax.oss.driver.api.querybuilder.select.Select::orderByAnnOf(java.lang.String, com.datastax.oss.driver.api.core.data.CqlVector<?>)",
2779+
"justification": "JAVA-3118: Add support for vector data type in Schema Builder, QueryBuilder"
2780+
},
2781+
{
2782+
"code": "java.method.addedToInterface",
2783+
"new": "method com.datastax.oss.driver.api.querybuilder.select.Select com.datastax.oss.driver.api.querybuilder.select.Select::orderByAnnOf(com.datastax.oss.driver.api.core.CqlIdentifier, com.datastax.oss.driver.api.core.data.CqlVector<?>)",
2784+
"justification": "JAVA-3118: Add support for vector data type in Schema Builder, QueryBuilder"
27752785
}
27762786
]
27772787
}

query-builder/src/main/java/com/datastax/oss/driver/api/querybuilder/select/Select.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package com.datastax.oss.driver.api.querybuilder.select;
1919

2020
import com.datastax.oss.driver.api.core.CqlIdentifier;
21+
import com.datastax.oss.driver.api.core.data.CqlVector;
2122
import com.datastax.oss.driver.api.core.metadata.schema.ClusteringOrder;
2223
import com.datastax.oss.driver.api.querybuilder.BindMarker;
2324
import com.datastax.oss.driver.api.querybuilder.BuildableQuery;
@@ -146,6 +147,16 @@ default Select orderBy(@NonNull String columnName, @NonNull ClusteringOrder orde
146147
return orderBy(CqlIdentifier.fromCql(columnName), order);
147148
}
148149

150+
/**
151+
* Shortcut for {@link #orderByAnnOf(CqlIdentifier, CqlVector)}, adding an ORDER BY ... ANN OF ...
152+
* clause
153+
*/
154+
@NonNull
155+
Select orderByAnnOf(@NonNull String columnName, @NonNull CqlVector<?> ann);
156+
157+
/** Adds the ORDER BY ... ANN OF ... clause, usually used for vector search */
158+
@NonNull
159+
Select orderByAnnOf(@NonNull CqlIdentifier columnId, @NonNull CqlVector<?> ann);
149160
/**
150161
* Adds a LIMIT clause to this query with a literal value.
151162
*

query-builder/src/main/java/com/datastax/oss/driver/internal/querybuilder/select/DefaultSelect.java

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@
2020
import com.datastax.oss.driver.api.core.CqlIdentifier;
2121
import com.datastax.oss.driver.api.core.cql.SimpleStatement;
2222
import com.datastax.oss.driver.api.core.cql.SimpleStatementBuilder;
23+
import com.datastax.oss.driver.api.core.data.CqlVector;
2324
import com.datastax.oss.driver.api.core.metadata.schema.ClusteringOrder;
2425
import com.datastax.oss.driver.api.querybuilder.BindMarker;
26+
import com.datastax.oss.driver.api.querybuilder.QueryBuilder;
2527
import com.datastax.oss.driver.api.querybuilder.relation.Relation;
2628
import com.datastax.oss.driver.api.querybuilder.select.Select;
2729
import com.datastax.oss.driver.api.querybuilder.select.SelectFrom;
@@ -49,6 +51,7 @@ public class DefaultSelect implements SelectFrom, Select {
4951
private final ImmutableList<Relation> relations;
5052
private final ImmutableList<Selector> groupByClauses;
5153
private final ImmutableMap<CqlIdentifier, ClusteringOrder> orderings;
54+
private final Ann ann;
5255
private final Object limit;
5356
private final Object perPartitionLimit;
5457
private final boolean allowsFiltering;
@@ -65,6 +68,7 @@ public DefaultSelect(@Nullable CqlIdentifier keyspace, @NonNull CqlIdentifier ta
6568
ImmutableMap.of(),
6669
null,
6770
null,
71+
null,
6872
false);
6973
}
7074

@@ -74,6 +78,8 @@ public DefaultSelect(@Nullable CqlIdentifier keyspace, @NonNull CqlIdentifier ta
7478
* @param selectors if it contains {@link AllSelector#INSTANCE}, that must be the only element.
7579
* This isn't re-checked because methods that call this constructor internally already do it,
7680
* make sure you do it yourself.
81+
* @param ann Approximate nearest neighbor. ANN ordering does not support secondary ordering or
82+
* ASC order.
7783
*/
7884
public DefaultSelect(
7985
@Nullable CqlIdentifier keyspace,
@@ -84,6 +90,7 @@ public DefaultSelect(
8490
@NonNull ImmutableList<Relation> relations,
8591
@NonNull ImmutableList<Selector> groupByClauses,
8692
@NonNull ImmutableMap<CqlIdentifier, ClusteringOrder> orderings,
93+
@Nullable Ann ann,
8794
@Nullable Object limit,
8895
@Nullable Object perPartitionLimit,
8996
boolean allowsFiltering) {
@@ -94,6 +101,9 @@ public DefaultSelect(
94101
|| (limit instanceof Integer && (Integer) limit > 0)
95102
|| limit instanceof BindMarker,
96103
"limit must be a strictly positive integer or a bind marker");
104+
Preconditions.checkArgument(
105+
orderings.isEmpty() || ann == null, "ANN ordering does not support secondary ordering");
106+
this.ann = ann;
97107
this.keyspace = keyspace;
98108
this.table = table;
99109
this.isJson = isJson;
@@ -117,6 +127,7 @@ public SelectFrom json() {
117127
relations,
118128
groupByClauses,
119129
orderings,
130+
ann,
120131
limit,
121132
perPartitionLimit,
122133
allowsFiltering);
@@ -134,6 +145,7 @@ public SelectFrom distinct() {
134145
relations,
135146
groupByClauses,
136147
orderings,
148+
ann,
137149
limit,
138150
perPartitionLimit,
139151
allowsFiltering);
@@ -193,6 +205,7 @@ public Select withSelectors(@NonNull ImmutableList<Selector> newSelectors) {
193205
relations,
194206
groupByClauses,
195207
orderings,
208+
ann,
196209
limit,
197210
perPartitionLimit,
198211
allowsFiltering);
@@ -221,6 +234,7 @@ public Select withRelations(@NonNull ImmutableList<Relation> newRelations) {
221234
newRelations,
222235
groupByClauses,
223236
orderings,
237+
ann,
224238
limit,
225239
perPartitionLimit,
226240
allowsFiltering);
@@ -249,6 +263,7 @@ public Select withGroupByClauses(@NonNull ImmutableList<Selector> newGroupByClau
249263
relations,
250264
newGroupByClauses,
251265
orderings,
266+
ann,
252267
limit,
253268
perPartitionLimit,
254269
allowsFiltering);
@@ -260,6 +275,18 @@ public Select orderBy(@NonNull CqlIdentifier columnId, @NonNull ClusteringOrder
260275
return withOrderings(ImmutableCollections.append(orderings, columnId, order));
261276
}
262277

278+
@NonNull
279+
@Override
280+
public Select orderByAnnOf(@NonNull String columnName, @NonNull CqlVector<?> ann) {
281+
return withAnn(new Ann(CqlIdentifier.fromCql(columnName), ann));
282+
}
283+
284+
@NonNull
285+
@Override
286+
public Select orderByAnnOf(@NonNull CqlIdentifier columnId, @NonNull CqlVector<?> ann) {
287+
return withAnn(new Ann(columnId, ann));
288+
}
289+
263290
@NonNull
264291
@Override
265292
public Select orderByIds(@NonNull Map<CqlIdentifier, ClusteringOrder> newOrderings) {
@@ -277,6 +304,24 @@ public Select withOrderings(@NonNull ImmutableMap<CqlIdentifier, ClusteringOrder
277304
relations,
278305
groupByClauses,
279306
newOrderings,
307+
ann,
308+
limit,
309+
perPartitionLimit,
310+
allowsFiltering);
311+
}
312+
313+
@NonNull
314+
Select withAnn(@NonNull Ann ann) {
315+
return new DefaultSelect(
316+
keyspace,
317+
table,
318+
isJson,
319+
isDistinct,
320+
selectors,
321+
relations,
322+
groupByClauses,
323+
orderings,
324+
ann,
280325
limit,
281326
perPartitionLimit,
282327
allowsFiltering);
@@ -295,6 +340,7 @@ public Select limit(int limit) {
295340
relations,
296341
groupByClauses,
297342
orderings,
343+
ann,
298344
limit,
299345
perPartitionLimit,
300346
allowsFiltering);
@@ -312,6 +358,7 @@ public Select limit(@Nullable BindMarker bindMarker) {
312358
relations,
313359
groupByClauses,
314360
orderings,
361+
ann,
315362
bindMarker,
316363
perPartitionLimit,
317364
allowsFiltering);
@@ -331,6 +378,7 @@ public Select perPartitionLimit(int perPartitionLimit) {
331378
relations,
332379
groupByClauses,
333380
orderings,
381+
ann,
334382
limit,
335383
perPartitionLimit,
336384
allowsFiltering);
@@ -348,6 +396,7 @@ public Select perPartitionLimit(@Nullable BindMarker bindMarker) {
348396
relations,
349397
groupByClauses,
350398
orderings,
399+
ann,
351400
limit,
352401
bindMarker,
353402
allowsFiltering);
@@ -365,6 +414,7 @@ public Select allowFiltering() {
365414
relations,
366415
groupByClauses,
367416
orderings,
417+
ann,
368418
limit,
369419
perPartitionLimit,
370420
true);
@@ -391,15 +441,20 @@ public String asCql() {
391441
CqlHelper.append(relations, builder, " WHERE ", " AND ", null);
392442
CqlHelper.append(groupByClauses, builder, " GROUP BY ", ",", null);
393443

394-
boolean first = true;
395-
for (Map.Entry<CqlIdentifier, ClusteringOrder> entry : orderings.entrySet()) {
396-
if (first) {
397-
builder.append(" ORDER BY ");
398-
first = false;
399-
} else {
400-
builder.append(",");
444+
if (ann != null) {
445+
builder.append(" ORDER BY ").append(this.ann.columnId.asCql(true)).append(" ANN OF ");
446+
QueryBuilder.literal(ann.vector).appendTo(builder);
447+
} else {
448+
boolean first = true;
449+
for (Map.Entry<CqlIdentifier, ClusteringOrder> entry : orderings.entrySet()) {
450+
if (first) {
451+
builder.append(" ORDER BY ");
452+
first = false;
453+
} else {
454+
builder.append(",");
455+
}
456+
builder.append(entry.getKey().asCql(true)).append(" ").append(entry.getValue().name());
401457
}
402-
builder.append(entry.getKey().asCql(true)).append(" ").append(entry.getValue().name());
403458
}
404459

405460
if (limit != null) {
@@ -499,6 +554,11 @@ public Object getLimit() {
499554
return limit;
500555
}
501556

557+
@Nullable
558+
public Ann getAnn() {
559+
return ann;
560+
}
561+
502562
@Nullable
503563
public Object getPerPartitionLimit() {
504564
return perPartitionLimit;
@@ -512,4 +572,14 @@ public boolean allowsFiltering() {
512572
public String toString() {
513573
return asCql();
514574
}
575+
576+
public static class Ann {
577+
private final CqlVector<?> vector;
578+
private final CqlIdentifier columnId;
579+
580+
private Ann(CqlIdentifier columnId, CqlVector<?> vector) {
581+
this.vector = vector;
582+
this.columnId = columnId;
583+
}
584+
}
515585
}

query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/delete/DeleteSelectorTest.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.deleteFrom;
2323
import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.literal;
2424

25+
import com.datastax.oss.driver.api.core.data.CqlVector;
2526
import org.junit.Test;
2627

2728
public class DeleteSelectorTest {
@@ -34,6 +35,16 @@ public void should_generate_column_deletion() {
3435
.hasCql("DELETE v FROM ks.foo WHERE k=?");
3536
}
3637

38+
@Test
39+
public void should_generate_vector_deletion() {
40+
assertThat(
41+
deleteFrom("foo")
42+
.column("v")
43+
.whereColumn("k")
44+
.isEqualTo(literal(CqlVector.newInstance(0.1, 0.2))))
45+
.hasCql("DELETE v FROM foo WHERE k=[0.1, 0.2]");
46+
}
47+
3748
@Test
3849
public void should_generate_field_deletion() {
3950
assertThat(

query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/insert/RegularInsertTest.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.literal;
2424
import static org.assertj.core.api.Assertions.catchThrowable;
2525

26+
import com.datastax.oss.driver.api.core.data.CqlVector;
2627
import com.datastax.oss.driver.api.querybuilder.term.Term;
2728
import com.datastax.oss.driver.internal.querybuilder.insert.DefaultInsert;
2829
import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableMap;
@@ -41,6 +42,12 @@ public void should_generate_column_assignments() {
4142
.hasCql("INSERT INTO foo (a,b) VALUES (?,?)");
4243
}
4344

45+
@Test
46+
public void should_generate_vector_literals() {
47+
assertThat(insertInto("foo").value("a", literal(CqlVector.newInstance(0.1, 0.2, 0.3))))
48+
.hasCql("INSERT INTO foo (a) VALUES ([0.1, 0.2, 0.3])");
49+
}
50+
4451
@Test
4552
public void should_keep_last_assignment_if_column_listed_twice() {
4653
assertThat(

query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/schema/AlterTableTest.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,4 +108,10 @@ public void should_generate_alter_table_with_no_compression() {
108108
assertThat(alterTable("bar").withNoCompression())
109109
.hasCql("ALTER TABLE bar WITH compression={'sstable_compression':''}");
110110
}
111+
112+
@Test
113+
public void should_generate_alter_table_with_vector() {
114+
assertThat(alterTable("bar").alterColumn("v", DataTypes.vectorOf(DataTypes.FLOAT, 3)))
115+
.hasCql("ALTER TABLE bar ALTER v TYPE vector<float, 3>");
116+
}
111117
}

query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/schema/AlterTypeTest.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,10 @@ public void should_generate_alter_table_with_rename_three_columns() {
5353
assertThat(alterType("bar").renameField("x", "y").renameField("u", "v").renameField("b", "a"))
5454
.hasCql("ALTER TYPE bar RENAME x TO y AND u TO v AND b TO a");
5555
}
56+
57+
@Test
58+
public void should_generate_alter_type_with_vector() {
59+
assertThat(alterType("foo", "bar").alterField("vec", DataTypes.vectorOf(DataTypes.FLOAT, 3)))
60+
.hasCql("ALTER TYPE foo.bar ALTER vec TYPE vector<float, 3>");
61+
}
5662
}

query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/schema/CreateTableTest.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,4 +314,13 @@ public void should_generate_create_table_time_window_compaction() {
314314
.hasCql(
315315
"CREATE TABLE bar (k int PRIMARY KEY,v text) WITH compaction={'class':'TimeWindowCompactionStrategy','compaction_window_size':10,'compaction_window_unit':'DAYS','timestamp_resolution':'MICROSECONDS','unsafe_aggressive_sstable_expiration':false}");
316316
}
317+
318+
@Test
319+
public void should_generate_vector_column() {
320+
assertThat(
321+
createTable("foo")
322+
.withPartitionKey("k", DataTypes.INT)
323+
.withColumn("v", DataTypes.vectorOf(DataTypes.FLOAT, 3)))
324+
.hasCql("CREATE TABLE foo (k int PRIMARY KEY,v vector<float, 3>)");
325+
}
317326
}

0 commit comments

Comments
 (0)