Skip to content

Commit d686807

Browse files
committed
Added tests for aggregates with Django
1 parent 8123b94 commit d686807

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

tests/test_django.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from django.conf import settings
33
from django.core import serializers
44
from django.db import connection, migrations, models
5+
from django.db.models import Avg, Sum
56
from django.db.migrations.loader import MigrationLoader
67
from django.forms import ModelForm
78
from math import sqrt
@@ -132,6 +133,22 @@ def test_filter(self):
132133
items = Item.objects.alias(distance=distance).filter(distance__lt=1)
133134
assert [v.id for v in items] == [1]
134135

136+
def test_avg(self):
137+
avg = Item.objects.aggregate(Avg('embedding'))['embedding__avg']
138+
assert avg is None
139+
Item(embedding=[1, 2, 3]).save()
140+
Item(embedding=[4, 5, 6]).save()
141+
avg = Item.objects.aggregate(Avg('embedding'))['embedding__avg']
142+
assert np.array_equal(avg, np.array([2.5, 3.5, 4.5]))
143+
144+
def test_sum(self):
145+
avg = Item.objects.aggregate(Sum('embedding'))['embedding__sum']
146+
assert avg is None
147+
Item(embedding=[1, 2, 3]).save()
148+
Item(embedding=[4, 5, 6]).save()
149+
avg = Item.objects.aggregate(Sum('embedding'))['embedding__sum']
150+
assert np.array_equal(avg, np.array([5, 7, 9]))
151+
135152
def test_serialization(self):
136153
create_items()
137154
items = Item.objects.all()

0 commit comments

Comments
 (0)