Skip to content

Commit 2836dd3

Browse files
committed
Added more SQLModel examples and tests
1 parent 913fbb0 commit 2836dd3

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

README.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,28 @@ session.exec(select(Item).order_by(Item.embedding.l2_distance([3, 1, 2])).limit(
233233

234234
Also supports `max_inner_product` and `cosine_distance`
235235

236+
Get the distance
237+
238+
```python
239+
session.exec(select(Item.embedding.l2_distance([3, 1, 2])))
240+
```
241+
242+
Get items within a certain distance
243+
244+
```python
245+
session.exec(select(Item).filter(Item.embedding.l2_distance([3, 1, 2]) < 5))
246+
```
247+
248+
Average vectors
249+
250+
```python
251+
from sqlalchemy.sql import func
252+
253+
session.exec(select(func.avg(Item.embedding))).first()
254+
```
255+
256+
Also supports `sum`
257+
236258
## Psycopg 3
237259

238260
Enable the extension

tests/test_sqlmodel.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44
from sqlalchemy import Column
55
from sqlalchemy.exc import StatementError
6+
from sqlalchemy.sql import func
67
from sqlmodel import Field, Session, SQLModel, create_engine, delete, select, text
78
from typing import List, Optional
89

@@ -81,6 +82,36 @@ def test_cosine_distance(self):
8182
items = session.exec(select(Item).order_by(Item.embedding.cosine_distance([1, 1, 1])))
8283
assert [v.id for v in items] == [1, 2, 3]
8384

85+
def test_filter(self):
86+
create_items()
87+
with Session(engine) as session:
88+
items = session.exec(select(Item).filter(Item.embedding.l2_distance([1, 1, 1]) < 1))
89+
assert [v.id for v in items] == [1]
90+
91+
def test_select(self):
92+
with Session(engine) as session:
93+
session.add(Item(embedding=[2, 3, 3]))
94+
item = session.exec(select(Item.embedding.l2_distance([1, 1, 1]))).all()
95+
assert item[0] == 3
96+
97+
def test_avg(self):
98+
with Session(engine) as session:
99+
avg = session.exec(select(func.avg(Item.embedding))).first()
100+
assert avg is None
101+
session.add(Item(embedding=[1, 2, 3]))
102+
session.add(Item(embedding=[4, 5, 6]))
103+
avg = session.exec(select(func.avg(Item.embedding))).first()
104+
assert np.array_equal(avg, np.array([2.5, 3.5, 4.5]))
105+
106+
def test_sum(self):
107+
with Session(engine) as session:
108+
sum = session.exec(select(func.sum(Item.embedding))).first()
109+
assert sum is None
110+
session.add(Item(embedding=[1, 2, 3]))
111+
session.add(Item(embedding=[4, 5, 6]))
112+
sum = session.exec(select(func.sum(Item.embedding))).first()
113+
assert np.array_equal(sum, np.array([5, 7, 9]))
114+
84115
def test_bad_dimensions(self):
85116
item = Item(embedding=[1, 2])
86117
session = Session(engine)

0 commit comments

Comments
 (0)