|
2 | 2 | from django.conf import settings
|
3 | 3 | from django.core import serializers
|
4 | 4 | from django.db import connection, migrations, models
|
| 5 | +from django.db.models import Avg, Sum |
5 | 6 | from django.db.migrations.loader import MigrationLoader
|
6 | 7 | from django.forms import ModelForm
|
7 | 8 | from math import sqrt
|
@@ -132,6 +133,22 @@ def test_filter(self):
|
132 | 133 | items = Item.objects.alias(distance=distance).filter(distance__lt=1)
|
133 | 134 | assert [v.id for v in items] == [1]
|
134 | 135 |
|
| 136 | + def test_avg(self): |
| 137 | + avg = Item.objects.aggregate(Avg('embedding'))['embedding__avg'] |
| 138 | + assert avg is None |
| 139 | + Item(embedding=[1, 2, 3]).save() |
| 140 | + Item(embedding=[4, 5, 6]).save() |
| 141 | + avg = Item.objects.aggregate(Avg('embedding'))['embedding__avg'] |
| 142 | + assert np.array_equal(avg, np.array([2.5, 3.5, 4.5])) |
| 143 | + |
| 144 | + def test_sum(self): |
| 145 | + avg = Item.objects.aggregate(Sum('embedding'))['embedding__sum'] |
| 146 | + assert avg is None |
| 147 | + Item(embedding=[1, 2, 3]).save() |
| 148 | + Item(embedding=[4, 5, 6]).save() |
| 149 | + avg = Item.objects.aggregate(Sum('embedding'))['embedding__sum'] |
| 150 | + assert np.array_equal(avg, np.array([5, 7, 9])) |
| 151 | + |
135 | 152 | def test_serialization(self):
|
136 | 153 | create_items()
|
137 | 154 | items = Item.objects.all()
|
|
0 commit comments