From c6b11aa6f2554d2387158ea7ec7756b4d9d771b6 Mon Sep 17 00:00:00 2001 From: F Date: Mon, 18 Sep 2023 07:38:37 +0000 Subject: [PATCH] feat: add a full example for sqlmodel and fix error in README.md. --- README.md | 7 +++++- examples/simple_sqlmodel_vector.py | 34 ++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 examples/simple_sqlmodel_vector.py diff --git a/README.md b/README.md index 2122299..8ebf2a5 100644 --- a/README.md +++ b/README.md @@ -214,10 +214,13 @@ session.exec(text('CREATE EXTENSION IF NOT EXISTS vector')) Add a vector column ```python +from typing import List, Optional + from pgvector.sqlalchemy import Vector -from sqlalchemy import Column +from sqlmodel import Column, Field, Session, SQLModel, create_engine, select class Item(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) embedding: List[float] = Field(sa_column=Column(Vector(3))) ``` @@ -237,6 +240,8 @@ session.exec(select(Item).order_by(Item.embedding.l2_distance([3, 1, 2])).limit( Also supports `max_inner_product` and `cosine_distance` +See [examples/simple_sqlmodel_vector.py](examples/simple_sqlmodel_vector.py) for full code. + Get the distance ```python diff --git a/examples/simple_sqlmodel_vector.py b/examples/simple_sqlmodel_vector.py new file mode 100644 index 0000000..a790275 --- /dev/null +++ b/examples/simple_sqlmodel_vector.py @@ -0,0 +1,34 @@ +""" +A simple sqlmodel vector demo via pgvector. + +For mac, if depdency missing or error, try `pip install pgvector-binary` +""" + +from typing import List, Optional + +from pgvector.sqlalchemy import Vector +from sqlmodel import Column, Field, Session, SQLModel, create_engine, select + + +class Item(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + embedding: List[float] = Field(sa_column=Column(Vector(3))) + +sqlite_url = f"postgresql://testuser:testuser@localhost:5432/testdb" + +engine = create_engine(sqlite_url, echo=False) + +SQLModel.metadata.create_all(engine) + +with Session(engine) as session: + item = Item(embedding=[1, 2, 3]) + session.add(item) + session.commit() + + res = session.exec( + select(Item).order_by(Item.embedding.l2_distance([3, 1, 2])).limit(5) + ) + + for i in res: + print(i)