|
3 | 3 | import pytest
|
4 | 4 | from sqlalchemy import Column
|
5 | 5 | from sqlalchemy.exc import StatementError
|
| 6 | +from sqlalchemy.sql import func |
6 | 7 | from sqlmodel import Field, Session, SQLModel, create_engine, delete, select, text
|
7 | 8 | from typing import List, Optional
|
8 | 9 |
|
@@ -81,6 +82,36 @@ def test_cosine_distance(self):
|
81 | 82 | items = session.exec(select(Item).order_by(Item.embedding.cosine_distance([1, 1, 1])))
|
82 | 83 | assert [v.id for v in items] == [1, 2, 3]
|
83 | 84 |
|
| 85 | + def test_filter(self): |
| 86 | + create_items() |
| 87 | + with Session(engine) as session: |
| 88 | + items = session.exec(select(Item).filter(Item.embedding.l2_distance([1, 1, 1]) < 1)) |
| 89 | + assert [v.id for v in items] == [1] |
| 90 | + |
| 91 | + def test_select(self): |
| 92 | + with Session(engine) as session: |
| 93 | + session.add(Item(embedding=[2, 3, 3])) |
| 94 | + item = session.exec(select(Item.embedding.l2_distance([1, 1, 1]))).all() |
| 95 | + assert item[0] == 3 |
| 96 | + |
| 97 | + def test_avg(self): |
| 98 | + with Session(engine) as session: |
| 99 | + avg = session.exec(select(func.avg(Item.embedding))).first() |
| 100 | + assert avg is None |
| 101 | + session.add(Item(embedding=[1, 2, 3])) |
| 102 | + session.add(Item(embedding=[4, 5, 6])) |
| 103 | + avg = session.exec(select(func.avg(Item.embedding))).first() |
| 104 | + assert np.array_equal(avg, np.array([2.5, 3.5, 4.5])) |
| 105 | + |
| 106 | + def test_sum(self): |
| 107 | + with Session(engine) as session: |
| 108 | + sum = session.exec(select(func.sum(Item.embedding))).first() |
| 109 | + assert sum is None |
| 110 | + session.add(Item(embedding=[1, 2, 3])) |
| 111 | + session.add(Item(embedding=[4, 5, 6])) |
| 112 | + sum = session.exec(select(func.sum(Item.embedding))).first() |
| 113 | + assert np.array_equal(sum, np.array([5, 7, 9])) |
| 114 | + |
84 | 115 | def test_bad_dimensions(self):
|
85 | 116 | item = Item(embedding=[1, 2])
|
86 | 117 | session = Session(engine)
|
|
0 commit comments