Skip to content

Commit afb3153

Browse files
committed
Add implementation for aggregate groupby
1 parent 6e362a0 commit afb3153

File tree

5 files changed

+159
-19
lines changed

5 files changed

+159
-19
lines changed

src/collections/aggregate/index.ts

Lines changed: 88 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -330,13 +330,19 @@ export type AggregateResult<T, M extends PropertiesMetrics<T> | undefined = unde
330330
totalCount: number;
331331
};
332332

333+
/**
 * Geo-coordinate shape that a group-by aggregation may report as the value a
 * group was formed on. It is one member of the `groupedBy.value` union of
 * `AggregateGroupByResult`.
 *
 * NOTE(review): the units/meaning of `distance` are not visible in this diff —
 * confirm against the gRPC reply definition before documenting them further.
 */
export type AggregatedGeoCoordinate = {
334+
latitude: number;
335+
longitude: number;
336+
distance: number;
337+
};
338+
333339
/**
 * Result for one group of a group-by aggregation: everything in
 * `AggregateResult<T, M>` plus a `groupedBy` descriptor of the group itself.
 *
 * - `groupedBy.prop`: filled from the first element of the reply's
 *   `groupedBy.path` by `Deserialize.aggregateGroupBy`.
 * - `groupedBy.value`: the value shared by the group. This diff widens it from
 *   `string` (the removed line below) to a union of scalar, array and
 *   geo-coordinate values (the added line below).
 */
export type AggregateGroupByResult<
334340
T,
335341
M extends PropertiesMetrics<T> | undefined = undefined
336342
> = AggregateResult<T, M> & {
337343
groupedBy: {
338344
prop: string;
339-
value: string;
345+
value: string | number | boolean | AggregatedGeoCoordinate | string[] | number[] | boolean[];
340346
};
341347
};
342348

@@ -365,10 +371,22 @@ class AggregateManager<T> implements Aggregate<T> {
365371
this.grpcChecker = this.dbVersionSupport.supportsAggregateGRPC().then((res) => res.supports);
366372

367373
this.groupBy = {
368-
hybrid: <M extends PropertiesMetrics<T> | undefined = undefined>(
374+
hybrid: async <M extends PropertiesMetrics<T>>(
369375
query: string,
370376
opts: AggregateGroupByHybridOptions<T, M>
371377
): Promise<AggregateGroupByResult<T, M>[]> => {
378+
if (await this.grpcChecker) {
379+
const group = typeof opts.groupBy === 'string' ? { property: opts.groupBy } : opts.groupBy;
380+
return this.grpc()
381+
.then((aggregate) =>
382+
aggregate.withHybrid({
383+
...Serialize.aggregate.hybrid(query, opts),
384+
groupBy: Serialize.aggregate.groupBy(group),
385+
limit: group.limit,
386+
})
387+
)
388+
.then((reply) => Deserialize.aggregateGroupBy(reply));
389+
}
372390
let builder = this.base(opts?.returnMetrics, opts?.filters, opts?.groupBy).withHybrid({
373391
query: query,
374392
alpha: opts?.alpha,
@@ -382,12 +400,25 @@ class AggregateManager<T> implements Aggregate<T> {
382400
}
383401
return this.doGroupBy(builder);
384402
},
385-
nearImage: async <M extends PropertiesMetrics<T> | undefined = undefined>(
403+
nearImage: async <M extends PropertiesMetrics<T>>(
386404
image: string | Buffer,
387405
opts: AggregateGroupByNearOptions<T, M>
388406
): Promise<AggregateGroupByResult<T, M>[]> => {
407+
const [b64, usesGrpc] = await Promise.all([await toBase64FromMedia(image), await this.grpcChecker]);
408+
if (usesGrpc) {
409+
const group = typeof opts.groupBy === 'string' ? { property: opts.groupBy } : opts.groupBy;
410+
return this.grpc()
411+
.then((aggregate) =>
412+
aggregate.withNearImage({
413+
...Serialize.aggregate.nearImage(b64, opts),
414+
groupBy: Serialize.aggregate.groupBy(group),
415+
limit: group.limit,
416+
})
417+
)
418+
.then((reply) => Deserialize.aggregateGroupBy(reply));
419+
}
389420
const builder = this.base(opts?.returnMetrics, opts?.filters, opts?.groupBy).withNearImage({
390-
image: await toBase64FromMedia(image),
421+
image: b64,
391422
certainty: opts?.certainty,
392423
distance: opts?.distance,
393424
targetVectors: opts?.targetVector ? [opts.targetVector] : undefined,
@@ -397,10 +428,22 @@ class AggregateManager<T> implements Aggregate<T> {
397428
}
398429
return this.doGroupBy(builder);
399430
},
400-
nearObject: <M extends PropertiesMetrics<T> | undefined = undefined>(
431+
nearObject: async <M extends PropertiesMetrics<T>>(
401432
id: string,
402433
opts: AggregateGroupByNearOptions<T, M>
403434
): Promise<AggregateGroupByResult<T, M>[]> => {
435+
if (await this.grpcChecker) {
436+
const group = typeof opts.groupBy === 'string' ? { property: opts.groupBy } : opts.groupBy;
437+
return this.grpc()
438+
.then((aggregate) =>
439+
aggregate.withNearObject({
440+
...Serialize.aggregate.nearObject(id, opts),
441+
groupBy: Serialize.aggregate.groupBy(group),
442+
limit: group.limit,
443+
})
444+
)
445+
.then((reply) => Deserialize.aggregateGroupBy(reply));
446+
}
404447
const builder = this.base(opts?.returnMetrics, opts?.filters, opts?.groupBy).withNearObject({
405448
id: id,
406449
certainty: opts?.certainty,
@@ -412,10 +455,22 @@ class AggregateManager<T> implements Aggregate<T> {
412455
}
413456
return this.doGroupBy(builder);
414457
},
415-
nearText: <M extends PropertiesMetrics<T> | undefined = undefined>(
458+
nearText: async <M extends PropertiesMetrics<T>>(
416459
query: string | string[],
417460
opts: AggregateGroupByNearOptions<T, M>
418461
): Promise<AggregateGroupByResult<T, M>[]> => {
462+
if (await this.grpcChecker) {
463+
const group = typeof opts.groupBy === 'string' ? { property: opts.groupBy } : opts.groupBy;
464+
return this.grpc()
465+
.then((aggregate) =>
466+
aggregate.withNearText({
467+
...Serialize.aggregate.nearText(query, opts),
468+
groupBy: Serialize.aggregate.groupBy(group),
469+
limit: group.limit,
470+
})
471+
)
472+
.then((reply) => Deserialize.aggregateGroupBy(reply));
473+
}
419474
const builder = this.base(opts?.returnMetrics, opts?.filters, opts?.groupBy).withNearText({
420475
concepts: Array.isArray(query) ? query : [query],
421476
certainty: opts?.certainty,
@@ -427,10 +482,22 @@ class AggregateManager<T> implements Aggregate<T> {
427482
}
428483
return this.doGroupBy(builder);
429484
},
430-
nearVector: <M extends PropertiesMetrics<T> | undefined = undefined>(
485+
nearVector: async <M extends PropertiesMetrics<T>>(
431486
vector: number[],
432487
opts: AggregateGroupByNearOptions<T, M>
433488
): Promise<AggregateGroupByResult<T, M>[]> => {
489+
if (await this.grpcChecker) {
490+
const group = typeof opts.groupBy === 'string' ? { property: opts.groupBy } : opts.groupBy;
491+
return this.grpc()
492+
.then((aggregate) =>
493+
aggregate.withNearVector({
494+
...Serialize.aggregate.nearVector(vector, opts),
495+
groupBy: Serialize.aggregate.groupBy(group),
496+
limit: group.limit,
497+
})
498+
)
499+
.then((reply) => Deserialize.aggregateGroupBy(reply));
500+
}
434501
const builder = this.base(opts?.returnMetrics, opts?.filters, opts?.groupBy).withNearVector({
435502
vector: vector,
436503
certainty: opts?.certainty,
@@ -442,9 +509,22 @@ class AggregateManager<T> implements Aggregate<T> {
442509
}
443510
return this.doGroupBy(builder);
444511
},
445-
overAll: <M extends PropertiesMetrics<T> | undefined = undefined>(
512+
/**
 * Group-by aggregation with no search operator.
 *
 * Awaits `this.grpcChecker`; when it resolves truthy the request goes through
 * the gRPC aggregator (`withFetch`) and the reply is converted with
 * `Deserialize.aggregateGroupBy`. Otherwise it falls back to the builder
 * obtained from `this.base(...)` and runs it through `this.doGroupBy`.
 *
 * @param opts Group-by options; `opts.groupBy` may be a property name or an object.
 * @returns One `AggregateGroupByResult` per group.
 */
overAll: async <M extends PropertiesMetrics<T>>(
446513
opts: AggregateGroupByOptions<T, M>
447514
): Promise<AggregateGroupByResult<T, M>[]> => {
515+
if (await this.grpcChecker) {
516+
// Normalise the string shorthand to object form so `property` and `limit`
// can be read uniformly below (`limit` is simply absent for the shorthand).
// NOTE(review): `opts` is dereferenced directly here, while the fallback
// path below uses `opts?.` — confirm whether `opts` can be undefined.
const group = typeof opts.groupBy === 'string' ? { property: opts.groupBy } : opts.groupBy;
517+
return this.grpc()
518+
.then((aggregate) =>
519+
aggregate.withFetch({
520+
...Serialize.aggregate.overAll(opts),
521+
522+
groupBy: Serialize.aggregate.groupBy(group),
523+
// The limit is passed next to `groupBy`, not inside the serialized group-by message.
limit: group.limit,
524+
})
525+
)
526+
.then((reply) => Deserialize.aggregateGroupBy(reply));
527+
}
448528
// Fallback path used when the gRPC check resolves falsy.
const builder = this.base(opts?.returnMetrics, opts?.filters, opts?.groupBy);
449529
return this.doGroupBy(builder);
450530
},

src/collections/deserialize/index.ts

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@ import { WeaviateDeserializationError } from '../../errors.js';
22
import { Tenant as TenantREST } from '../../openapi/types.js';
33
import {
44
AggregateReply,
5+
AggregateReply_Aggregations,
56
AggregateReply_Aggregations_Aggregation,
67
AggregateReply_Aggregations_Aggregation_Boolean,
78
AggregateReply_Aggregations_Aggregation_DateMessage,
89
AggregateReply_Aggregations_Aggregation_Integer,
910
AggregateReply_Aggregations_Aggregation_Number,
1011
AggregateReply_Aggregations_Aggregation_Text,
12+
AggregateReply_Group_GroupedBy,
1113
} from '../../proto/v1/aggregate.js';
1214
import { BatchObject as BatchObjectGRPC, BatchObjectsReply } from '../../proto/v1/batch.js';
1315
import { BatchDeleteReply } from '../../proto/v1/batch_delete.js';
@@ -18,6 +20,7 @@ import { DbVersionSupport } from '../../utils/dbVersion.js';
1820
import {
1921
AggregateBoolean,
2022
AggregateDate,
23+
AggregateGroupByResult,
2124
AggregateNumber,
2225
AggregateResult,
2326
AggregateText,
@@ -127,21 +130,67 @@ export class Deserialize {
127130
throw new WeaviateDeserializationError(`Unknown aggregation type: ${aggregation}`);
128131
}
129132

133+
/**
 * Converts the aggregations of a gRPC reply into a record keyed by property
 * name, mapping each entry through `Deserialize.mapAggregate`.
 *
 * Shared by `aggregate` and `aggregateGroupBy`.
 *
 * @param aggregations Aggregations message from the reply; may be undefined.
 * @returns The keyed record, or an empty object when `aggregations` is undefined.
 */
private static aggregations(aggregations?: AggregateReply_Aggregations): Record<string, AggregateType> {
134+
return aggregations
135+
? Object.fromEntries(
136+
aggregations.aggregations.map((aggregation) => [
137+
aggregation.property,
138+
Deserialize.mapAggregate(aggregation),
139+
])
140+
)
141+
: {};
142+
}
143+
130144
/**
 * Deserializes a non-grouped aggregate reply.
 *
 * In this diff the inline `Object.fromEntries(...)` mapping (the removed `-`
 * lines) is replaced by a call to `Deserialize.aggregations` (the added `+`
 * lines).
 *
 * @param reply gRPC aggregate reply.
 * @returns `totalCount` taken from `singleResult.objectsCount` and the mapped `properties`.
 * @throws WeaviateDeserializationError when `reply.singleResult` is undefined.
 */
public static aggregate<T, M extends PropertiesMetrics<T>>(reply: AggregateReply): AggregateResult<T, M> {
131145
if (reply.singleResult === undefined) {
132146
throw new WeaviateDeserializationError('No single result in aggregate response');
133147
}
134148
return {
135149
// NOTE(review): `!` assumes `objectsCount` is always set on the reply — confirm.
totalCount: reply.singleResult.objectsCount!,
136-
properties: (reply.singleResult.aggregations
137-
? Object.fromEntries(
138-
reply.singleResult.aggregations.aggregations.map((aggregation) => [
139-
aggregation.property,
140-
Deserialize.mapAggregate(aggregation),
141-
])
142-
)
143-
: {}) as AggregateResult<T, M>['properties'],
150+
properties: Deserialize.aggregations(reply.singleResult.aggregations) as AggregateResult<
151+
T,
152+
M
153+
>['properties'],
154+
};
155+
}
156+
157+
/**
 * Deserializes a grouped aggregate reply into one result per group.
 *
 * @param reply gRPC aggregate reply.
 * @returns For every entry of `reply.groupedResults.groups`: its `totalCount`,
 *   its `groupedBy` descriptor and its mapped `properties`.
 * @throws WeaviateDeserializationError when `reply.groupedResults` is undefined
 *   or when a group has no `groupedBy`.
 */
public static aggregateGroupBy<T, M extends PropertiesMetrics<T>>(
158+
reply: AggregateReply
159+
): AggregateGroupByResult<T, M>[] {
160+
if (reply.groupedResults === undefined)
161+
throw new WeaviateDeserializationError('No grouped results in aggregate response');
162+
163+
// Converts the reply's `groupedBy` message into the public `{ prop, value }` shape.
const parse = (groupedBy?: AggregateReply_Group_GroupedBy): AggregateGroupByResult<T, M>['groupedBy'] => {
164+
if (groupedBy === undefined)
165+
throw new WeaviateDeserializationError('No groupedBy in aggregate response');
166+
167+
// The first field that is not undefined wins, checked in the order below;
// the array-typed fields are unwrapped via `.values`.
let value: AggregateGroupByResult<T, M>['groupedBy']['value'];
168+
if (groupedBy.boolean !== undefined) value = groupedBy.boolean;
169+
else if (groupedBy.booleans !== undefined) value = groupedBy.booleans.values;
170+
else if (groupedBy.geo !== undefined) value = groupedBy.geo;
171+
else if (groupedBy.int !== undefined) value = groupedBy.int;
172+
else if (groupedBy.ints !== undefined) value = groupedBy.ints.values;
173+
else if (groupedBy.number !== undefined) value = groupedBy.number;
174+
else if (groupedBy.numbers !== undefined) value = groupedBy.numbers.values;
175+
else if (groupedBy.text !== undefined) value = groupedBy.text;
176+
else if (groupedBy.texts !== undefined) value = groupedBy.texts.values;
177+
else {
178+
// No recognised field is set: log a warning and fall back to an empty string.
// NOTE(review): this does not throw, so callers cannot tell this fallback
// apart from a genuine empty-string group value — confirm this is intended.
console.warn(`Unknown groupBy type: ${JSON.stringify(groupedBy, null, 2)}`);
179+
value = '';
180+
}
181+
182+
return {
183+
// NOTE(review): only the first path element is used; it is undefined if `path` is empty.
prop: groupedBy.path[0],
184+
value,
185+
};
144186
};
187+
return reply.groupedResults.groups.map((group) => {
188+
return {
189+
// NOTE(review): `!` assumes `objectsCount` is always set on each group — confirm.
totalCount: group.objectsCount!,
190+
groupedBy: parse(group.groupedBy),
191+
properties: Deserialize.aggregations(group.aggregations) as AggregateResult<T, M>['properties'],
192+
};
193+
});
145194
}
146195

147196
public query<T>(reply: SearchReply): WeaviateReturn<T> {
@@ -174,7 +223,7 @@ export class Deserialize {
174223
};
175224
}
176225

177-
public groupBy<T>(reply: SearchReply): GroupByReturn<T> {
226+
public queryGroupBy<T>(reply: SearchReply): GroupByReturn<T> {
178227
const objects: GroupByObject<T>[] = [];
179228
const groups: Record<string, GroupByResult<T>> = {};
180229
reply.groupByResults.forEach((result) => {

src/collections/query/index.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ class QueryManager<T> implements Query<T> {
6161
reply: SearchReply
6262
) {
6363
const deserialize = await Deserialize.use(this.check.dbVersionSupport);
64-
return Serialize.search.isGroupBy(opts) ? deserialize.groupBy<T>(reply) : deserialize.query<T>(reply);
64+
return Serialize.search.isGroupBy(opts)
65+
? deserialize.queryGroupBy<T>(reply)
66+
: deserialize.query<T>(reply);
6567
}
6668

6769
public fetchObjectById(id: string, opts?: FetchObjectByIdOptions<T>): Promise<WeaviateObject<T> | null> {

src/collections/serialize/index.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ import {
7070
AggregateRequest_Aggregation_Integer,
7171
AggregateRequest_Aggregation_Number,
7272
AggregateRequest_Aggregation_Text,
73+
AggregateRequest_GroupBy,
7374
} from '../../proto/v1/aggregate.js';
7475
import {
7576
BooleanArrayProperties,
@@ -96,6 +97,7 @@ import {
9697
AggregateBaseOptions,
9798
AggregateHybridOptions,
9899
AggregateNearOptions,
100+
GroupByAggregate,
99101
MultiTargetVectorJoin,
100102
PrimitiveKeys,
101103
PropertiesMetrics,
@@ -389,6 +391,12 @@ class Aggregate {
389391
};
390392
};
391393

394+
/**
 * Builds the gRPC group-by message from the client-side group-by options.
 *
 * Only `property` is copied into the message. The group-by methods of
 * `AggregateManager` pass the group limit separately, as `limit`, next to the
 * serialized `groupBy`.
 *
 * @param groupBy Client-side group-by options; may be undefined, in which case
 *   `property` is passed to `fromPartial` as undefined.
 * @returns The `AggregateRequest_GroupBy` produced by `fromPartial`.
 */
public static groupBy = <T>(groupBy?: GroupByAggregate<T>): AggregateRequest_GroupBy => {
395+
return AggregateRequest_GroupBy.fromPartial({
396+
property: groupBy?.property,
397+
});
398+
};
399+
392400
public static hybrid = (
393401
query: string,
394402
opts?: AggregateHybridOptions<any, PropertiesMetrics<any>>

src/grpc/aggregator.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,10 @@ import { retryOptions } from './retry.js';
3434

3535
/**
 * Arguments common to every gRPC aggregate call; all fields are optional.
 *
 * This diff adds `limit` and moves `objectLimit` below it (the removed and
 * re-added `objectLimit` lines are both shown).
 *
 * NOTE(review): the exact semantics of `limit` versus `objectLimit` are not
 * visible in this diff. The group-by call sites pass the group-by option's
 * `limit` as `limit` — confirm against the aggregator implementation.
 */
export type BaseAggregateArgs = {
3636
aggregations?: AggregateRequest_Aggregation[];
37-
objectLimit?: number;
3837
filters?: Filters;
3938
groupBy?: AggregateRequest_GroupBy;
39+
limit?: number;
40+
objectLimit?: number;
4041
};
4142

4243
export type AggregateFetchArgs = BaseAggregateArgs;

0 commit comments

Comments
 (0)