diff --git a/examples/python/requirements.txt b/examples/python/requirements.txt index c22530eaa9..26f645a169 100644 --- a/examples/python/requirements.txt +++ b/examples/python/requirements.txt @@ -2,5 +2,4 @@ pytest~=6.2.2 pytest-asyncio~=0.14.0 psycopg2-binary~=2.8.6 asyncpg~=0.21.0 -pydantic~=1.7.3 -sqlc-python-runtime~=1.0.0 +sqlalchemy==1.4.0 diff --git a/examples/python/sqlc.json b/examples/python/sqlc.json index ba987b5af1..583513a184 100644 --- a/examples/python/sqlc.json +++ b/examples/python/sqlc.json @@ -8,7 +8,9 @@ "gen": { "python": { "out": "src/authors", - "package": "authors" + "package": "authors", + "emit_sync_querier": true, + "emit_async_querier": true } } }, @@ -19,7 +21,8 @@ "gen": { "python": { "out": "src/booktest", - "package": "booktest" + "package": "booktest", + "emit_async_querier": true } } }, @@ -30,7 +33,8 @@ "gen": { "python": { "out": "src/jets", - "package": "jets" + "package": "jets", + "emit_async_querier": true } } }, @@ -41,7 +45,8 @@ "gen": { "python": { "out": "src/ondeck", - "package": "ondeck" + "package": "ondeck", + "emit_async_querier": true } } } diff --git a/examples/python/src/authors/models.py b/examples/python/src/authors/models.py index b282d77dac..42d945abe1 100644 --- a/examples/python/src/authors/models.py +++ b/examples/python/src/authors/models.py @@ -1,13 +1,13 @@ # Code generated by sqlc. DO NOT EDIT. from typing import Optional -import pydantic +import dataclasses -# Enums -# Models -class Author(pydantic.BaseModel): + +@dataclasses.dataclass() +class Author: id: int name: str bio: Optional[str] diff --git a/examples/python/src/authors/query.py b/examples/python/src/authors/query.py index 947e535309..5571b49364 100644 --- a/examples/python/src/authors/query.py +++ b/examples/python/src/authors/query.py @@ -1,92 +1,111 @@ + # Code generated by sqlc. DO NOT EDIT. -from typing import AsyncIterator, Awaitable, Iterator, Optional, overload +from typing import AsyncIterator, Iterator, Optional -import sqlc_runtime as sqlc +import sqlalchemy +import sqlalchemy.ext.asyncio from authors import models -CREATE_AUTHOR = """-- name: create_author :one +CREATE_AUTHOR = """-- name: create_author \\:one INSERT INTO authors ( name, bio ) VALUES ( - $1, $2 + :p1, :p2 ) RETURNING id, name, bio """ -DELETE_AUTHOR = """-- name: delete_author :exec +DELETE_AUTHOR = """-- name: delete_author \\:exec DELETE FROM authors -WHERE id = $1 +WHERE id = :p1 """ -GET_AUTHOR = """-- name: get_author :one +GET_AUTHOR = """-- name: get_author \\:one SELECT id, name, bio FROM authors -WHERE id = $1 LIMIT 1 +WHERE id = :p1 LIMIT 1 """ -LIST_AUTHORS = """-- name: list_authors :many +LIST_AUTHORS = """-- name: list_authors \\:many SELECT id, name, bio FROM authors ORDER BY name """ -@overload -def create_author(conn: sqlc.Connection, name: str, bio: Optional[str]) -> Optional[models.Author]: - pass - - -@overload -def create_author(conn: sqlc.AsyncConnection, name: str, bio: Optional[str]) -> Awaitable[Optional[models.Author]]: - pass - - -def create_author(conn: sqlc.GenericConnection, name: str, bio: Optional[str]) -> sqlc.ReturnType[Optional[models.Author]]: - return conn.execute_one_model(models.Author, CREATE_AUTHOR, name, bio) - - -@overload -def delete_author(conn: sqlc.Connection, id: int) -> None: - pass - - -@overload -def delete_author(conn: sqlc.AsyncConnection, id: int) -> Awaitable[None]: - pass - - -def delete_author(conn: sqlc.GenericConnection, id: int) -> sqlc.ReturnType[None]: - return conn.execute_none(DELETE_AUTHOR, id) - - -@overload -def get_author(conn: sqlc.Connection, id: int) -> Optional[models.Author]: - pass - - -@overload -def get_author(conn: sqlc.AsyncConnection, id: int) -> Awaitable[Optional[models.Author]]: - pass - - -def get_author(conn: sqlc.GenericConnection, id: int) -> sqlc.ReturnType[Optional[models.Author]]: - return conn.execute_one_model(models.Author, GET_AUTHOR, id) - - -@overload -def list_authors(conn: sqlc.Connection) -> Iterator[models.Author]: - pass - - -@overload -def list_authors(conn: sqlc.AsyncConnection) -> AsyncIterator[models.Author]: - pass - - -def list_authors(conn: sqlc.GenericConnection) -> sqlc.IteratorReturn[models.Author]: - return conn.execute_many_model(models.Author, LIST_AUTHORS) - +class Querier: + def __init__(self, conn: sqlalchemy.engine.Connection): + self._conn = conn + + def create_author(self, *, name: str, bio: Optional[str]) -> Optional[models.Author]: + row = self._conn.execute(sqlalchemy.text(CREATE_AUTHOR), {"p1": name, "p2": bio}).first() + if row is None: + return None + return models.Author( + id=row[0], + name=row[1], + bio=row[2], + ) + + def delete_author(self, *, id: int) -> None: + self._conn.execute(sqlalchemy.text(DELETE_AUTHOR), {"p1": id}) + + def get_author(self, *, id: int) -> Optional[models.Author]: + row = self._conn.execute(sqlalchemy.text(GET_AUTHOR), {"p1": id}).first() + if row is None: + return None + return models.Author( + id=row[0], + name=row[1], + bio=row[2], + ) + + def list_authors(self) -> Iterator[models.Author]: + result = self._conn.execute(sqlalchemy.text(LIST_AUTHORS)) + for row in result: + yield models.Author( + id=row[0], + name=row[1], + bio=row[2], + ) + + +class AsyncQuerier: + def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection): + self._conn = conn + + async def create_author(self, *, name: str, bio: Optional[str]) -> Optional[models.Author]: + row = (await self._conn.execute(sqlalchemy.text(CREATE_AUTHOR), {"p1": name, "p2": bio})).first() + if row is None: + return None + return models.Author( + id=row[0], + name=row[1], + bio=row[2], + ) + + async def delete_author(self, *, id: int) -> None: + await self._conn.execute(sqlalchemy.text(DELETE_AUTHOR), {"p1": id}) + + async def get_author(self, *, id: int) -> Optional[models.Author]: + row = (await self._conn.execute(sqlalchemy.text(GET_AUTHOR), {"p1": id})).first() + if row is None: + return None + return models.Author( + id=row[0], + name=row[1], + bio=row[2], + ) + + async def list_authors(self) -> AsyncIterator[models.Author]: + result = await self._conn.stream(sqlalchemy.text(LIST_AUTHORS)) + async for row in result: + yield models.Author( + id=row[0], + name=row[1], + bio=row[2], + ) diff --git a/examples/python/src/booktest/models.py b/examples/python/src/booktest/models.py index 9c7d1a1970..7367c14720 100644 --- a/examples/python/src/booktest/models.py +++ b/examples/python/src/booktest/models.py @@ -3,22 +3,24 @@ import datetime import enum -import pydantic +import dataclasses + -# Enums class BookType(str, enum.Enum): FICTION = "FICTION" NONFICTION = "NONFICTION" -# Models -class Author(pydantic.BaseModel): +@dataclasses.dataclass() +class Author: author_id: int name: str -class Book(pydantic.BaseModel): + +@dataclasses.dataclass() +class Book: book_id: int author_id: int isbn: str diff --git a/examples/python/src/booktest/query.py b/examples/python/src/booktest/query.py index 3081b2a598..6bc73be5fb 100644 --- a/examples/python/src/booktest/query.py +++ b/examples/python/src/booktest/query.py @@ -1,14 +1,16 @@ + # Code generated by sqlc. DO NOT EDIT. -from typing import AsyncIterator, Awaitable, Iterator, List, Optional, overload +from typing import AsyncIterator, List, Optional import datetime -import pydantic -import sqlc_runtime as sqlc +import dataclasses +import sqlalchemy +import sqlalchemy.ext.asyncio from booktest import models -BOOKS_BY_TAGS = """-- name: books_by_tags :many +BOOKS_BY_TAGS = """-- name: books_by_tags \\:many SELECT book_id, title, @@ -17,11 +19,12 @@ tags FROM books LEFT JOIN authors ON books.author_id = authors.author_id -WHERE tags && $1::varchar[] +WHERE tags && :p1\\:\\:varchar[] """ -class BooksByTagsRow(pydantic.BaseModel): +@dataclasses.dataclass() +class BooksByTagsRow: book_id: int title: str name: str @@ -29,30 +32,19 @@ class BooksByTagsRow(pydantic.BaseModel): tags: List[str] -BOOKS_BY_TITLE_YEAR = """-- name: books_by_title_year :many +BOOKS_BY_TITLE_YEAR = """-- name: books_by_title_year \\:many SELECT book_id, author_id, isbn, book_type, title, year, available, tags FROM books -WHERE title = $1 AND year = $2 +WHERE title = :p1 AND year = :p2 """ -class BooksByTitleYearRow(pydantic.BaseModel): - book_id: int - author_id: int - isbn: str - book_type: models.BookType - title: str - year: int - available: datetime.datetime - tags: List[str] - - -CREATE_AUTHOR = """-- name: create_author :one -INSERT INTO authors (name) VALUES ($1) +CREATE_AUTHOR = """-- name: create_author \\:one +INSERT INTO authors (name) VALUES (:p1) RETURNING author_id, name """ -CREATE_BOOK = """-- name: create_book :one +CREATE_BOOK = """-- name: create_book \\:one INSERT INTO books ( author_id, isbn, @@ -62,30 +54,20 @@ class BooksByTitleYearRow(pydantic.BaseModel): available, tags ) VALUES ( - $1, - $2, - $3, - $4, - $5, - $6, - $7 + :p1, + :p2, + :p3, + :p4, + :p5, + :p6, + :p7 ) RETURNING book_id, author_id, isbn, book_type, title, year, available, tags """ -class CreateBookParams(pydantic.BaseModel): - author_id: int - isbn: str - book_type: models.BookType - title: str - year: int - available: datetime.datetime - tags: List[str] - - -class CreateBookRow(pydantic.BaseModel): - book_id: int +@dataclasses.dataclass() +class CreateBookParams: author_id: int isbn: str book_type: models.BookType @@ -95,172 +77,135 @@ class CreateBookRow(pydantic.BaseModel): tags: List[str] -DELETE_BOOK = """-- name: delete_book :exec +DELETE_BOOK = """-- name: delete_book \\:exec DELETE FROM books -WHERE book_id = $1 +WHERE book_id = :p1 """ -GET_AUTHOR = """-- name: get_author :one +GET_AUTHOR = """-- name: get_author \\:one SELECT author_id, name FROM authors -WHERE author_id = $1 +WHERE author_id = :p1 """ -GET_BOOK = """-- name: get_book :one +GET_BOOK = """-- name: get_book \\:one SELECT book_id, author_id, isbn, book_type, title, year, available, tags FROM books -WHERE book_id = $1 +WHERE book_id = :p1 """ -class GetBookRow(pydantic.BaseModel): - book_id: int - author_id: int - isbn: str - book_type: models.BookType - title: str - year: int - available: datetime.datetime - tags: List[str] - - -UPDATE_BOOK = """-- name: update_book :exec +UPDATE_BOOK = """-- name: update_book \\:exec UPDATE books -SET title = $1, tags = $2 -WHERE book_id = $3 +SET title = :p1, tags = :p2 +WHERE book_id = :p3 """ -UPDATE_BOOK_ISBN = """-- name: update_book_isbn :exec +UPDATE_BOOK_ISBN = """-- name: update_book_isbn \\:exec UPDATE books -SET title = $1, tags = $2, isbn = $4 -WHERE book_id = $3 +SET title = :p1, tags = :p2, isbn = :p4 +WHERE book_id = :p3 """ -@overload -def books_by_tags(conn: sqlc.Connection, dollar_1: List[str]) -> Iterator[BooksByTagsRow]: - pass - - -@overload -def books_by_tags(conn: sqlc.AsyncConnection, dollar_1: List[str]) -> AsyncIterator[BooksByTagsRow]: - pass - - -def books_by_tags(conn: sqlc.GenericConnection, dollar_1: List[str]) -> sqlc.IteratorReturn[BooksByTagsRow]: - return conn.execute_many_model(BooksByTagsRow, BOOKS_BY_TAGS, dollar_1) - - -@overload -def books_by_title_year(conn: sqlc.Connection, title: str, year: int) -> Iterator[BooksByTitleYearRow]: - pass - - -@overload -def books_by_title_year(conn: sqlc.AsyncConnection, title: str, year: int) -> AsyncIterator[BooksByTitleYearRow]: - pass - - -def books_by_title_year(conn: sqlc.GenericConnection, title: str, year: int) -> sqlc.IteratorReturn[BooksByTitleYearRow]: - return conn.execute_many_model(BooksByTitleYearRow, BOOKS_BY_TITLE_YEAR, title, year) - - -@overload -def create_author(conn: sqlc.Connection, name: str) -> Optional[models.Author]: - pass - - -@overload -def create_author(conn: sqlc.AsyncConnection, name: str) -> Awaitable[Optional[models.Author]]: - pass - - -def create_author(conn: sqlc.GenericConnection, name: str) -> sqlc.ReturnType[Optional[models.Author]]: - return conn.execute_one_model(models.Author, CREATE_AUTHOR, name) - - -@overload -def create_book(conn: sqlc.Connection, arg: CreateBookParams) -> Optional[CreateBookRow]: - pass - - -@overload -def create_book(conn: sqlc.AsyncConnection, arg: CreateBookParams) -> Awaitable[Optional[CreateBookRow]]: - pass - - -def create_book(conn: sqlc.GenericConnection, arg: CreateBookParams) -> sqlc.ReturnType[Optional[CreateBookRow]]: - return conn.execute_one_model(CreateBookRow, CREATE_BOOK, arg.author_id, arg.isbn, arg.book_type, arg.title, arg.year, arg.available, arg.tags) - - -@overload -def delete_book(conn: sqlc.Connection, book_id: int) -> None: - pass - - -@overload -def delete_book(conn: sqlc.AsyncConnection, book_id: int) -> Awaitable[None]: - pass - - -def delete_book(conn: sqlc.GenericConnection, book_id: int) -> sqlc.ReturnType[None]: - return conn.execute_none(DELETE_BOOK, book_id) - - -@overload -def get_author(conn: sqlc.Connection, author_id: int) -> Optional[models.Author]: - pass - - -@overload -def get_author(conn: sqlc.AsyncConnection, author_id: int) -> Awaitable[Optional[models.Author]]: - pass - - -def get_author(conn: sqlc.GenericConnection, author_id: int) -> sqlc.ReturnType[Optional[models.Author]]: - return conn.execute_one_model(models.Author, GET_AUTHOR, author_id) - - -@overload -def get_book(conn: sqlc.Connection, book_id: int) -> Optional[GetBookRow]: - pass - - -@overload -def get_book(conn: sqlc.AsyncConnection, book_id: int) -> Awaitable[Optional[GetBookRow]]: - pass - - -def get_book(conn: sqlc.GenericConnection, book_id: int) -> sqlc.ReturnType[Optional[GetBookRow]]: - return conn.execute_one_model(GetBookRow, GET_BOOK, book_id) - - -@overload -def update_book(conn: sqlc.Connection, title: str, tags: List[str], book_id: int) -> None: - pass - - -@overload -def update_book(conn: sqlc.AsyncConnection, title: str, tags: List[str], book_id: int) -> Awaitable[None]: - pass - - -def update_book(conn: sqlc.GenericConnection, title: str, tags: List[str], book_id: int) -> sqlc.ReturnType[None]: - return conn.execute_none(UPDATE_BOOK, title, tags, book_id) - - -@overload -def update_book_isbn(conn: sqlc.Connection, title: str, tags: List[str], book_id: int, isbn: str) -> None: - pass - - -@overload -def update_book_isbn(conn: sqlc.AsyncConnection, title: str, tags: List[str], book_id: int, isbn: str) -> Awaitable[None]: - pass - - -def update_book_isbn(conn: sqlc.GenericConnection, title: str, tags: List[str], book_id: int, isbn: str) -> sqlc.ReturnType[None]: - return conn.execute_none(UPDATE_BOOK_ISBN, title, tags, book_id, isbn) +class AsyncQuerier: + def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection): + self._conn = conn + + async def books_by_tags(self, *, dollar_1: List[str]) -> AsyncIterator[BooksByTagsRow]: + result = await self._conn.stream(sqlalchemy.text(BOOKS_BY_TAGS), {"p1": dollar_1}) + async for row in result: + yield BooksByTagsRow( + book_id=row[0], + title=row[1], + name=row[2], + isbn=row[3], + tags=row[4], + ) + + async def books_by_title_year(self, *, title: str, year: int) -> AsyncIterator[models.Book]: + result = await self._conn.stream(sqlalchemy.text(BOOKS_BY_TITLE_YEAR), {"p1": title, "p2": year}) + async for row in result: + yield models.Book( + book_id=row[0], + author_id=row[1], + isbn=row[2], + book_type=row[3], + title=row[4], + year=row[5], + available=row[6], + tags=row[7], + ) + + async def create_author(self, *, name: str) -> Optional[models.Author]: + row = (await self._conn.execute(sqlalchemy.text(CREATE_AUTHOR), {"p1": name})).first() + if row is None: + return None + return models.Author( + author_id=row[0], + name=row[1], + ) + + async def create_book(self, arg: CreateBookParams) -> Optional[models.Book]: + row = (await self._conn.execute(sqlalchemy.text(CREATE_BOOK), { + "p1": arg.author_id, + "p2": arg.isbn, + "p3": arg.book_type, + "p4": arg.title, + "p5": arg.year, + "p6": arg.available, + "p7": arg.tags, + })).first() + if row is None: + return None + return models.Book( + book_id=row[0], + author_id=row[1], + isbn=row[2], + book_type=row[3], + title=row[4], + year=row[5], + available=row[6], + tags=row[7], + ) + + async def delete_book(self, *, book_id: int) -> None: + await self._conn.execute(sqlalchemy.text(DELETE_BOOK), {"p1": book_id}) + + async def get_author(self, *, author_id: int) -> Optional[models.Author]: + row = (await self._conn.execute(sqlalchemy.text(GET_AUTHOR), {"p1": author_id})).first() + if row is None: + return None + return models.Author( + author_id=row[0], + name=row[1], + ) + + async def get_book(self, *, book_id: int) -> Optional[models.Book]: + row = (await self._conn.execute(sqlalchemy.text(GET_BOOK), {"p1": book_id})).first() + if row is None: + return None + return models.Book( + book_id=row[0], + author_id=row[1], + isbn=row[2], + book_type=row[3], + title=row[4], + year=row[5], + available=row[6], + tags=row[7], + ) + + async def update_book(self, *, title: str, tags: List[str], book_id: int) -> None: + await self._conn.execute(sqlalchemy.text(UPDATE_BOOK), {"p1": title, "p2": tags, "p3": book_id}) + + async def update_book_isbn(self, *, title: str, tags: List[str], book_id: int, isbn: str) -> None: + await self._conn.execute(sqlalchemy.text(UPDATE_BOOK_ISBN), { + "p1": title, + "p2": tags, + "p3": book_id, + "p4": isbn, + }) diff --git a/examples/python/src/dbtest/migrations.py b/examples/python/src/dbtest/migrations.py index c3c72b78a6..6ace6bcd63 100644 --- a/examples/python/src/dbtest/migrations.py +++ b/examples/python/src/dbtest/migrations.py @@ -1,29 +1,31 @@ import os from typing import List -import asyncpg -import psycopg2.extensions +import sqlalchemy +import sqlalchemy.ext.asyncio -def apply_migrations(db: psycopg2.extensions.connection, paths: List[str]): +def apply_migrations(conn: sqlalchemy.engine.Connection, paths: List[str]): files = _find_sql_files(paths) for file in files: with open(file, "r") as fd: blob = fd.read() - cur = db.cursor() - cur.execute(blob) - cur.close() - db.commit() + stmts = blob.split(";") + for stmt in stmts: + if stmt.strip(): + conn.execute(sqlalchemy.text(stmt)) -async def apply_migrations_async(db: asyncpg.Connection, paths: List[str]): +async def apply_migrations_async(conn: sqlalchemy.ext.asyncio.AsyncConnection, paths: List[str]): files = _find_sql_files(paths) for file in files: with open(file, "r") as fd: blob = fd.read() - await db.execute(blob) + raw_conn = await conn.get_raw_connection() + # The asyncpg sqlalchemy adapter uses a prepared statement cache which can't handle the migration statements + await raw_conn._connection.execute(blob) def _find_sql_files(paths: List[str]) -> List[str]: diff --git a/examples/python/src/jets/models.py b/examples/python/src/jets/models.py index cb543dcb22..13ff1a122c 100644 --- a/examples/python/src/jets/models.py +++ b/examples/python/src/jets/models.py @@ -1,13 +1,13 @@ # Code generated by sqlc. DO NOT EDIT. -import pydantic +import dataclasses -# Enums -# Models -class Jet(pydantic.BaseModel): + +@dataclasses.dataclass() +class Jet: id: int pilot_id: int age: int @@ -15,17 +15,23 @@ class Jet(pydantic.BaseModel): color: str -class Language(pydantic.BaseModel): + +@dataclasses.dataclass() +class Language: id: int language: str -class Pilot(pydantic.BaseModel): + +@dataclasses.dataclass() +class Pilot: id: int name: str -class PilotLanguage(pydantic.BaseModel): + +@dataclasses.dataclass() +class PilotLanguage: pilot_id: int language_id: int diff --git a/examples/python/src/jets/query-building.py b/examples/python/src/jets/query-building.py index 21056fe10c..0725a80902 100644 --- a/examples/python/src/jets/query-building.py +++ b/examples/python/src/jets/query-building.py @@ -1,65 +1,47 @@ + # Code generated by sqlc. DO NOT EDIT. -from typing import AsyncIterator, Awaitable, Iterator, Optional, overload +from typing import AsyncIterator, Optional -import sqlc_runtime as sqlc +import sqlalchemy +import sqlalchemy.ext.asyncio from jets import models -COUNT_PILOTS = """-- name: count_pilots :one +COUNT_PILOTS = """-- name: count_pilots \\:one SELECT COUNT(*) FROM pilots """ -DELETE_PILOT = """-- name: delete_pilot :exec -DELETE FROM pilots WHERE id = $1 +DELETE_PILOT = """-- name: delete_pilot \\:exec +DELETE FROM pilots WHERE id = :p1 """ -LIST_PILOTS = """-- name: list_pilots :many +LIST_PILOTS = """-- name: list_pilots \\:many SELECT id, name FROM pilots LIMIT 5 """ -@overload -def count_pilots(conn: sqlc.Connection) -> Optional[int]: - pass - - -@overload -def count_pilots(conn: sqlc.AsyncConnection) -> Awaitable[Optional[int]]: - pass - - -def count_pilots(conn: sqlc.GenericConnection) -> sqlc.ReturnType[Optional[int]]: - return conn.execute_one(COUNT_PILOTS) - - -@overload -def delete_pilot(conn: sqlc.Connection, id: int) -> None: - pass - - -@overload -def delete_pilot(conn: sqlc.AsyncConnection, id: int) -> Awaitable[None]: - pass - - -def delete_pilot(conn: sqlc.GenericConnection, id: int) -> sqlc.ReturnType[None]: - return conn.execute_none(DELETE_PILOT, id) - - -@overload -def list_pilots(conn: sqlc.Connection) -> Iterator[models.Pilot]: - pass - -@overload -def list_pilots(conn: sqlc.AsyncConnection) -> AsyncIterator[models.Pilot]: - pass +class AsyncQuerier: + def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection): + self._conn = conn + async def count_pilots(self) -> Optional[int]: + row = (await self._conn.execute(sqlalchemy.text(COUNT_PILOTS))).first() + if row is None: + return None + return row[0] -def list_pilots(conn: sqlc.GenericConnection) -> sqlc.IteratorReturn[models.Pilot]: - return conn.execute_many_model(models.Pilot, LIST_PILOTS) + async def delete_pilot(self, *, id: int) -> None: + await self._conn.execute(sqlalchemy.text(DELETE_PILOT), {"p1": id}) + async def list_pilots(self) -> AsyncIterator[models.Pilot]: + result = await self._conn.stream(sqlalchemy.text(LIST_PILOTS)) + async for row in result: + yield models.Pilot( + id=row[0], + name=row[1], + ) diff --git a/examples/python/src/ondeck/city.py b/examples/python/src/ondeck/city.py index baf0051e9c..d7a2123b97 100644 --- a/examples/python/src/ondeck/city.py +++ b/examples/python/src/ondeck/city.py @@ -1,96 +1,76 @@ + # Code generated by sqlc. DO NOT EDIT. -from typing import AsyncIterator, Awaitable, Iterator, Optional, overload +from typing import AsyncIterator, Optional -import sqlc_runtime as sqlc +import sqlalchemy +import sqlalchemy.ext.asyncio from ondeck import models -CREATE_CITY = """-- name: create_city :one +CREATE_CITY = """-- name: create_city \\:one INSERT INTO city ( name, slug ) VALUES ( - $1, - $2 + :p1, + :p2 ) RETURNING slug, name """ -GET_CITY = """-- name: get_city :one +GET_CITY = """-- name: get_city \\:one SELECT slug, name FROM city -WHERE slug = $1 +WHERE slug = :p1 """ -LIST_CITIES = """-- name: list_cities :many +LIST_CITIES = """-- name: list_cities \\:many SELECT slug, name FROM city ORDER BY name """ -UPDATE_CITY_NAME = """-- name: update_city_name :exec +UPDATE_CITY_NAME = """-- name: update_city_name \\:exec UPDATE city -SET name = $2 -WHERE slug = $1 +SET name = :p2 +WHERE slug = :p1 """ -@overload -def create_city(conn: sqlc.Connection, name: str, slug: str) -> Optional[models.City]: - pass - - -@overload -def create_city(conn: sqlc.AsyncConnection, name: str, slug: str) -> Awaitable[Optional[models.City]]: - pass - - -def create_city(conn: sqlc.GenericConnection, name: str, slug: str) -> sqlc.ReturnType[Optional[models.City]]: - return conn.execute_one_model(models.City, CREATE_CITY, name, slug) - - -@overload -def get_city(conn: sqlc.Connection, slug: str) -> Optional[models.City]: - pass - - -@overload -def get_city(conn: sqlc.AsyncConnection, slug: str) -> Awaitable[Optional[models.City]]: - pass - - -def get_city(conn: sqlc.GenericConnection, slug: str) -> sqlc.ReturnType[Optional[models.City]]: - return conn.execute_one_model(models.City, GET_CITY, slug) - - -@overload -def list_cities(conn: sqlc.Connection) -> Iterator[models.City]: - pass - - -@overload -def list_cities(conn: sqlc.AsyncConnection) -> AsyncIterator[models.City]: - pass - - -def list_cities(conn: sqlc.GenericConnection) -> sqlc.IteratorReturn[models.City]: - return conn.execute_many_model(models.City, LIST_CITIES) - - -@overload -def update_city_name(conn: sqlc.Connection, slug: str, name: str) -> None: - pass - - -@overload -def update_city_name(conn: sqlc.AsyncConnection, slug: str, name: str) -> Awaitable[None]: - pass - - -def update_city_name(conn: sqlc.GenericConnection, slug: str, name: str) -> sqlc.ReturnType[None]: - return conn.execute_none(UPDATE_CITY_NAME, slug, name) +class AsyncQuerier: + def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection): + self._conn = conn + + async def create_city(self, *, name: str, slug: str) -> Optional[models.City]: + row = (await self._conn.execute(sqlalchemy.text(CREATE_CITY), {"p1": name, "p2": slug})).first() + if row is None: + return None + return models.City( + slug=row[0], + name=row[1], + ) + + async def get_city(self, *, slug: str) -> Optional[models.City]: + row = (await self._conn.execute(sqlalchemy.text(GET_CITY), {"p1": slug})).first() + if row is None: + return None + return models.City( + slug=row[0], + name=row[1], + ) + + async def list_cities(self) -> AsyncIterator[models.City]: + result = await self._conn.stream(sqlalchemy.text(LIST_CITIES)) + async for row in result: + yield models.City( + slug=row[0], + name=row[1], + ) + + async def update_city_name(self, *, slug: str, name: str) -> None: + await self._conn.execute(sqlalchemy.text(UPDATE_CITY_NAME), {"p1": slug, "p2": name}) diff --git a/examples/python/src/ondeck/models.py b/examples/python/src/ondeck/models.py index bc8244b2a3..6e2de33b5e 100644 --- a/examples/python/src/ondeck/models.py +++ b/examples/python/src/ondeck/models.py @@ -3,22 +3,24 @@ import datetime import enum -import pydantic +import dataclasses -# Enums# Venues can be either open or closed +# Venues can be either open or closed class Status(str, enum.Enum): OPEN = "op!en" CLOSED = "clo@sed" -# Models -class City(pydantic.BaseModel): +@dataclasses.dataclass() +class City: slug: str name: str + # Venues are places where muisc happens -class Venue(pydantic.BaseModel): +@dataclasses.dataclass() +class Venue: id: int status: Status statuses: Optional[List[Status]] diff --git a/examples/python/src/ondeck/venue.py b/examples/python/src/ondeck/venue.py index 289725e0d4..c0e31be474 100644 --- a/examples/python/src/ondeck/venue.py +++ b/examples/python/src/ondeck/venue.py @@ -1,14 +1,15 @@ + # Code generated by sqlc. DO NOT EDIT. -from typing import AsyncIterator, Awaitable, Iterator, List, Optional, overload -import datetime +from typing import AsyncIterator, List, Optional -import pydantic -import sqlc_runtime as sqlc +import dataclasses +import sqlalchemy +import sqlalchemy.ext.asyncio from ondeck import models -CREATE_VENUE = """-- name: create_venue :one +CREATE_VENUE = """-- name: create_venue \\:one INSERT INTO venue ( slug, name, @@ -19,19 +20,20 @@ statuses, tags ) VALUES ( - $1, - $2, - $3, + :p1, + :p2, + :p3, NOW(), - $4, - $5, - $6, - $7 + :p4, + :p5, + :p6, + :p7 ) RETURNING id """ -class CreateVenueParams(pydantic.BaseModel): +@dataclasses.dataclass() +class CreateVenueParams: slug: str name: str city: str @@ -41,62 +43,36 @@ class CreateVenueParams(pydantic.BaseModel): tags: Optional[List[str]] -DELETE_VENUE = """-- name: delete_venue :exec +DELETE_VENUE = """-- name: delete_venue \\:exec DELETE FROM venue -WHERE slug = $1 AND slug = $1 +WHERE slug = :p1 AND slug = :p1 """ -GET_VENUE = """-- name: get_venue :one +GET_VENUE = """-- name: get_venue \\:one SELECT id, status, statuses, slug, name, city, spotify_playlist, songkick_id, tags, created_at FROM venue -WHERE slug = $1 AND city = $2 +WHERE slug = :p1 AND city = :p2 """ -class GetVenueRow(pydantic.BaseModel): - id: int - status: models.Status - statuses: Optional[List[models.Status]] - slug: str - name: str - city: str - spotify_playlist: str - songkick_id: Optional[str] - tags: Optional[List[str]] - created_at: datetime.datetime - - -LIST_VENUES = """-- name: list_venues :many +LIST_VENUES = """-- name: list_venues \\:many SELECT id, status, statuses, slug, name, city, spotify_playlist, songkick_id, tags, created_at FROM venue -WHERE city = $1 +WHERE city = :p1 ORDER BY name """ -class ListVenuesRow(pydantic.BaseModel): - id: int - status: models.Status - statuses: Optional[List[models.Status]] - slug: str - name: str - city: str - spotify_playlist: str - songkick_id: Optional[str] - tags: Optional[List[str]] - created_at: datetime.datetime - - -UPDATE_VENUE_NAME = """-- name: update_venue_name :one +UPDATE_VENUE_NAME = """-- name: update_venue_name \\:one UPDATE venue -SET name = $2 -WHERE slug = $1 +SET name = :p2 +WHERE slug = :p1 RETURNING id """ -VENUE_COUNT_BY_CITY = """-- name: venue_count_by_city :many +VENUE_COUNT_BY_CITY = """-- name: venue_count_by_city \\:many SELECT city, count(*) @@ -106,92 +82,78 @@ class ListVenuesRow(pydantic.BaseModel): """ -class VenueCountByCityRow(pydantic.BaseModel): +@dataclasses.dataclass() +class VenueCountByCityRow: city: str count: int -@overload -def create_venue(conn: sqlc.Connection, arg: CreateVenueParams) -> Optional[int]: - pass - - -@overload -def create_venue(conn: sqlc.AsyncConnection, arg: CreateVenueParams) -> Awaitable[Optional[int]]: - pass - - -def create_venue(conn: sqlc.GenericConnection, arg: CreateVenueParams) -> sqlc.ReturnType[Optional[int]]: - return conn.execute_one(CREATE_VENUE, arg.slug, arg.name, arg.city, arg.spotify_playlist, arg.status, arg.statuses, arg.tags) - - -@overload -def delete_venue(conn: sqlc.Connection, slug: str) -> None: - pass - - -@overload -def delete_venue(conn: sqlc.AsyncConnection, slug: str) -> Awaitable[None]: - pass - - -def delete_venue(conn: sqlc.GenericConnection, slug: str) -> sqlc.ReturnType[None]: - return conn.execute_none(DELETE_VENUE, slug) - - -@overload -def get_venue(conn: sqlc.Connection, slug: str, city: str) -> Optional[GetVenueRow]: - pass - - -@overload -def get_venue(conn: sqlc.AsyncConnection, slug: str, city: str) -> Awaitable[Optional[GetVenueRow]]: - pass - - -def get_venue(conn: sqlc.GenericConnection, slug: str, city: str) -> sqlc.ReturnType[Optional[GetVenueRow]]: - return conn.execute_one_model(GetVenueRow, GET_VENUE, slug, city) - - -@overload -def list_venues(conn: sqlc.Connection, city: str) -> Iterator[ListVenuesRow]: - pass - - -@overload -def list_venues(conn: sqlc.AsyncConnection, city: str) -> AsyncIterator[ListVenuesRow]: - pass - - -def list_venues(conn: sqlc.GenericConnection, city: str) -> sqlc.IteratorReturn[ListVenuesRow]: - return conn.execute_many_model(ListVenuesRow, LIST_VENUES, city) - - -@overload -def update_venue_name(conn: sqlc.Connection, slug: str, name: str) -> Optional[int]: - pass - - -@overload -def update_venue_name(conn: sqlc.AsyncConnection, slug: str, name: str) -> Awaitable[Optional[int]]: - pass - - -def update_venue_name(conn: sqlc.GenericConnection, slug: str, name: str) -> sqlc.ReturnType[Optional[int]]: - return conn.execute_one(UPDATE_VENUE_NAME, slug, name) - - -@overload -def venue_count_by_city(conn: sqlc.Connection) -> Iterator[VenueCountByCityRow]: - pass - - -@overload -def venue_count_by_city(conn: sqlc.AsyncConnection) -> AsyncIterator[VenueCountByCityRow]: - pass - - -def venue_count_by_city(conn: sqlc.GenericConnection) -> sqlc.IteratorReturn[VenueCountByCityRow]: - return conn.execute_many_model(VenueCountByCityRow, VENUE_COUNT_BY_CITY) +class AsyncQuerier: + def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection): + self._conn = conn + + async def create_venue(self, arg: CreateVenueParams) -> Optional[int]: + row = (await self._conn.execute(sqlalchemy.text(CREATE_VENUE), { + "p1": arg.slug, + "p2": arg.name, + "p3": arg.city, + "p4": arg.spotify_playlist, + "p5": arg.status, + "p6": arg.statuses, + "p7": arg.tags, + })).first() + if row is None: + return None + return row[0] + + async def delete_venue(self, *, slug: str) -> None: + await self._conn.execute(sqlalchemy.text(DELETE_VENUE), {"p1": slug}) + + async def get_venue(self, *, slug: str, city: str) -> Optional[models.Venue]: + row = (await self._conn.execute(sqlalchemy.text(GET_VENUE), {"p1": slug, "p2": city})).first() + if row is None: + return None + return models.Venue( + id=row[0], + status=row[1], + statuses=row[2], + slug=row[3], + name=row[4], + city=row[5], + spotify_playlist=row[6], + songkick_id=row[7], + tags=row[8], + created_at=row[9], + ) + + async def list_venues(self, *, city: str) -> AsyncIterator[models.Venue]: + result = await self._conn.stream(sqlalchemy.text(LIST_VENUES), {"p1": city}) + async for row in result: + yield models.Venue( + id=row[0], + status=row[1], + statuses=row[2], + slug=row[3], + name=row[4], + city=row[5], + spotify_playlist=row[6], + songkick_id=row[7], + tags=row[8], + created_at=row[9], + ) + + async def update_venue_name(self, *, slug: str, name: str) -> Optional[int]: + row = (await self._conn.execute(sqlalchemy.text(UPDATE_VENUE_NAME), {"p1": slug, "p2": name})).first() + if row is None: + return None + return row[0] + + async def venue_count_by_city(self) -> AsyncIterator[VenueCountByCityRow]: + result = await self._conn.stream(sqlalchemy.text(VENUE_COUNT_BY_CITY)) + async for row in result: + yield VenueCountByCityRow( + city=row[0], + count=row[1], + ) diff --git a/examples/python/src/tests/conftest.py b/examples/python/src/tests/conftest.py index e3df5f77dc..f807209229 100644 --- a/examples/python/src/tests/conftest.py +++ b/examples/python/src/tests/conftest.py @@ -2,10 +2,9 @@ import os import random -import asyncpg -import psycopg2 -import psycopg2.extensions import pytest +import sqlalchemy +import sqlalchemy.ext.asyncio @pytest.fixture(scope="session") @@ -16,31 +15,48 @@ def postgres_uri() -> str: pg_password = os.environ.get("PG_PASSWORD", "mysecretpassword") pg_db = os.environ.get("PG_DATABASE", "dinotest") - return f"postgres://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{pg_db}?sslmode=disable" + return f"postgresql://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{pg_db}" @pytest.fixture(scope="session") -def postgres_connection(postgres_uri) -> psycopg2.extensions.connection: - conn = psycopg2.connect(postgres_uri) +def sqlalchemy_connection(postgres_uri) -> sqlalchemy.engine.Connection: + engine = sqlalchemy.create_engine(postgres_uri, future=True) + with engine.connect() as conn: + yield conn + + +@pytest.fixture(scope="function") +def db(sqlalchemy_connection: sqlalchemy.engine.Connection) -> sqlalchemy.engine.Connection: + conn = sqlalchemy_connection + schema_name = f"sqltest_{random.randint(0, 1000)}" + conn.execute(sqlalchemy.text(f"CREATE SCHEMA {schema_name}")) + conn.execute(sqlalchemy.text(f"SET search_path TO {schema_name}")) + conn.commit() yield conn - conn.close() + conn.rollback() + conn.execute(sqlalchemy.text(f"DROP SCHEMA {schema_name} CASCADE")) + conn.execute(sqlalchemy.text("SET search_path TO public")) + + +@pytest.fixture(scope="session") +async def async_sqlalchemy_connection(postgres_uri) -> sqlalchemy.ext.asyncio.AsyncConnection: + postgres_uri = postgres_uri.replace("postgresql", "postgresql+asyncpg") + engine = sqlalchemy.ext.asyncio.create_async_engine(postgres_uri) + async with engine.connect() as conn: + yield conn -@pytest.fixture() -def postgres_db(postgres_connection) -> psycopg2.extensions.connection: +@pytest.fixture(scope="function") +async def async_db(async_sqlalchemy_connection: sqlalchemy.ext.asyncio.AsyncConnection) -> sqlalchemy.ext.asyncio.AsyncConnection: + conn = async_sqlalchemy_connection schema_name = f"sqltest_{random.randint(0, 1000)}" - cur = postgres_connection.cursor() - cur.execute(f"CREATE SCHEMA {schema_name}") - cur.execute(f"SET search_path TO {schema_name}") - cur.close() - postgres_connection.commit() - yield postgres_connection - postgres_connection.rollback() - cur = postgres_connection.cursor() - cur.execute(f"DROP SCHEMA {schema_name} CASCADE") - cur.execute(f"SET search_path TO public") - cur.close() - postgres_connection.commit() + await conn.execute(sqlalchemy.text(f"CREATE SCHEMA {schema_name}")) + await conn.execute(sqlalchemy.text(f"SET search_path TO {schema_name}")) + await conn.commit() + yield conn + await conn.rollback() + await conn.execute(sqlalchemy.text(f"DROP SCHEMA {schema_name} CASCADE")) + await conn.execute(sqlalchemy.text("SET search_path TO public")) @pytest.fixture(scope="session") @@ -49,21 +65,3 @@ def event_loop(): loop = asyncio.get_event_loop_policy().new_event_loop() yield loop loop.close() - - -@pytest.fixture(scope="session") -async def async_postgres_connection(postgres_uri: str) -> asyncpg.Connection: - conn = await asyncpg.connect(postgres_uri) - yield conn - await conn.close() - - -@pytest.fixture() -async def async_postgres_db(async_postgres_connection: asyncpg.Connection) -> asyncpg.Connection: - conn = async_postgres_connection - schema_name = f"sqltest_{random.randint(0, 1000)}" - await conn.execute(f"CREATE SCHEMA {schema_name}") - await conn.execute(f"SET search_path TO {schema_name}") - yield conn - await conn.execute(f"DROP SCHEMA {schema_name} CASCADE") - await conn.execute(f"SET search_path TO public") diff --git a/examples/python/src/tests/test_authors.py b/examples/python/src/tests/test_authors.py index bc01f3133b..7b0a954276 100644 --- a/examples/python/src/tests/test_authors.py +++ b/examples/python/src/tests/test_authors.py @@ -1,59 +1,56 @@ import os -import asyncpg -import psycopg2.extensions import pytest -from sqlc_runtime.psycopg2 import build_psycopg2_connection -from sqlc_runtime.asyncpg import build_asyncpg_connection +import sqlalchemy.ext.asyncio from authors import query from dbtest.migrations import apply_migrations, apply_migrations_async -def test_authors(postgres_db: psycopg2.extensions.connection): - apply_migrations(postgres_db, [os.path.dirname(__file__) + "/../../../authors/postgresql/schema.sql"]) +def test_authors(db: sqlalchemy.engine.Connection): + apply_migrations(db, [os.path.dirname(__file__) + "/../../../authors/postgresql/schema.sql"]) - db = build_psycopg2_connection(postgres_db) + querier = query.Querier(db) - authors = list(query.list_authors(db)) + authors = list(querier.list_authors()) assert authors == [] author_name = "Brian Kernighan" author_bio = "Co-author of The C Programming Language and The Go Programming Language" - new_author = query.create_author(db, name=author_name, bio=author_bio) + new_author = querier.create_author(name=author_name, bio=author_bio) assert new_author.id > 0 assert new_author.name == author_name assert new_author.bio == author_bio - db_author = query.get_author(db, new_author.id) + db_author = querier.get_author(id=new_author.id) assert db_author == new_author - author_list = list(query.list_authors(db)) + author_list = list(querier.list_authors()) assert len(author_list) == 1 assert author_list[0] == new_author @pytest.mark.asyncio -async def test_authors_async(async_postgres_db: asyncpg.Connection): - await apply_migrations_async(async_postgres_db, [os.path.dirname(__file__) + "/../../../authors/postgresql/schema.sql"]) +async def test_authors_async(async_db: sqlalchemy.ext.asyncio.AsyncConnection): + await apply_migrations_async(async_db, [os.path.dirname(__file__) + "/../../../authors/postgresql/schema.sql"]) - db = build_asyncpg_connection(async_postgres_db) + querier = query.AsyncQuerier(async_db) - async for _ in query.list_authors(db): + async for _ in querier.list_authors(): assert False, "No authors should exist" author_name = "Brian Kernighan" author_bio = "Co-author of The C Programming Language and The Go Programming Language" - new_author = await query.create_author(db, name=author_name, bio=author_bio) + new_author = await querier.create_author(name=author_name, bio=author_bio) assert new_author.id > 0 assert new_author.name == author_name assert new_author.bio == author_bio - db_author = await query.get_author(db, new_author.id) + db_author = await querier.get_author(id=new_author.id) assert db_author == new_author author_list = [] - async for author in query.list_authors(db): + async for author in querier.list_authors(): author_list.append(author) assert len(author_list) == 1 assert author_list[0] == new_author diff --git a/examples/python/src/tests/test_booktest.py b/examples/python/src/tests/test_booktest.py index b0ba38891a..6106d9d3fd 100644 --- a/examples/python/src/tests/test_booktest.py +++ b/examples/python/src/tests/test_booktest.py @@ -1,87 +1,85 @@ import datetime import os -import asyncpg import pytest -from sqlc_runtime.asyncpg import build_asyncpg_connection +import sqlalchemy.ext.asyncio from booktest import query, models from dbtest.migrations import apply_migrations_async @pytest.mark.asyncio -async def test_books(async_postgres_db: asyncpg.Connection): - await apply_migrations_async(async_postgres_db, [os.path.dirname(__file__) + "/../../../booktest/postgresql/schema.sql"]) +async def test_books(async_db: sqlalchemy.ext.asyncio.AsyncConnection): + await apply_migrations_async(async_db, [os.path.dirname(__file__) + "/../../../booktest/postgresql/schema.sql"]) - db = build_asyncpg_connection(async_postgres_db) + querier = query.AsyncQuerier(async_db) - author = await query.create_author(db, "Unknown Master") + author = await querier.create_author(name="Unknown Master") assert author is not None - async with async_postgres_db.transaction(): - now = datetime.datetime.now() - await query.create_book(db, query.CreateBookParams( - author_id=author.author_id, - isbn="1", - title="my book title", - book_type=models.BookType.FICTION, - year=2016, - available=now, - tags=[], - )) - - b1 = await query.create_book(db, query.CreateBookParams( - author_id=author.author_id, - isbn="2", - title="the second book", - book_type=models.BookType.FICTION, - year=2016, - available=now, - tags=["cool", "unique"], - )) - - await query.update_book(db, book_id=b1.book_id, title="changed second title", tags=["cool", "disastor"]) - - b3 = await query.create_book(db, query.CreateBookParams( - author_id=author.author_id, - isbn="3", - title="the third book", - book_type=models.BookType.FICTION, - year=2001, - available=now, - tags=["cool"], - )) - - b4 = await query.create_book(db, query.CreateBookParams( - author_id=author.author_id, - isbn="4", - title="4th place finisher", - book_type=models.BookType.NONFICTION, - year=2011, - available=now, - tags=["other"], - )) - - await query.update_book_isbn(db, book_id=b4.book_id, isbn="NEW ISBN", title="never ever gonna finish, a quatrain", tags=["someother"]) - - books0 = query.books_by_title_year(db, title="my book title", year=2016) + now = datetime.datetime.now() + await querier.create_book(query.CreateBookParams( + author_id=author.author_id, + isbn="1", + title="my book title", + book_type=models.BookType.FICTION, + year=2016, + available=now, + tags=[], + )) + + b1 = await querier.create_book(query.CreateBookParams( + author_id=author.author_id, + isbn="2", + title="the second book", + book_type=models.BookType.FICTION, + year=2016, + available=now, + tags=["cool", "unique"], + )) + + await querier.update_book(book_id=b1.book_id, title="changed second title", tags=["cool", "disastor"]) + + b3 = await querier.create_book(query.CreateBookParams( + author_id=author.author_id, + isbn="3", + title="the third book", + book_type=models.BookType.FICTION, + year=2001, + available=now, + tags=["cool"], + )) + + b4 = await querier.create_book(query.CreateBookParams( + author_id=author.author_id, + isbn="4", + title="4th place finisher", + book_type=models.BookType.NONFICTION, + year=2011, + available=now, + tags=["other"], + )) + + await querier.update_book_isbn(book_id=b4.book_id, isbn="NEW ISBN", title="never ever gonna finish, a quatrain", tags=["someother"]) + + books0 = querier.books_by_title_year(title="my book title", year=2016) expected_titles = {"my book title"} async for book in books0: expected_titles.remove(book.title) # raises a key error if the title does not exist assert len(book.tags) == 0 - author = await query.get_author(db, author_id=book.author_id) + author = await querier.get_author(author_id=book.author_id) assert author.name == "Unknown Master" assert len(expected_titles) == 0 - books = query.books_by_tags(db, ["cool", "other", "someother"]) + books = querier.books_by_tags(dollar_1=["cool", "other", "someother"]) expected_titles = {"changed second title", "the third book", "never ever gonna finish, a quatrain"} async for book in books: expected_titles.remove(book.title) assert len(expected_titles) == 0 - b5 = await query.get_book(db, b3.book_id) + b5 = await querier.get_book(book_id=b3.book_id) assert b5 is not None - await query.delete_book(db, book_id=b5.book_id) - b6 = await query.get_book(db, b5.book_id) + await querier.delete_book(book_id=b5.book_id) + b6 = await querier.get_book(book_id=b5.book_id) assert b6 is None diff --git a/examples/python/src/tests/test_ondeck.py b/examples/python/src/tests/test_ondeck.py index f12fbe985c..68cfbc9bcb 100644 --- a/examples/python/src/tests/test_ondeck.py +++ b/examples/python/src/tests/test_ondeck.py @@ -1,8 +1,7 @@ import os -import asyncpg import pytest -from sqlc_runtime.asyncpg import build_asyncpg_connection +import sqlalchemy.ext.asyncio from ondeck import models from ondeck import city as city_queries @@ -11,15 +10,16 @@ @pytest.mark.asyncio -async def test_ondeck(async_postgres_db: asyncpg.Connection): - await apply_migrations_async(async_postgres_db, [os.path.dirname(__file__) + "/../../../ondeck/postgresql/schema"]) +async def test_ondeck(async_db: sqlalchemy.ext.asyncio.AsyncConnection): + await apply_migrations_async(async_db, [os.path.dirname(__file__) + "/../../../ondeck/postgresql/schema"]) - db = build_asyncpg_connection(async_postgres_db) + city_querier = city_queries.AsyncQuerier(async_db) + venue_querier = venue_queries.AsyncQuerier(async_db) - city = await city_queries.create_city(db, slug="san-francisco", name="San Francisco") + city = await city_querier.create_city(slug="san-francisco", name="San Francisco") assert city is not None - venue_id = await venue_queries.create_venue(db, venue_queries.CreateVenueParams( + venue_id = await venue_querier.create_venue(venue_queries.CreateVenueParams( slug="the-fillmore", name="The Fillmore", city=city.slug, @@ -30,20 +30,20 @@ async def test_ondeck(async_postgres_db: asyncpg.Connection): )) assert venue_id is not None - venue = await venue_queries.get_venue(db, slug="the-fillmore", city=city.slug) + venue = await venue_querier.get_venue(slug="the-fillmore", city=city.slug) assert venue is not None assert venue.id == venue_id - assert city == await city_queries.get_city(db, city.slug) - assert [venue_queries.VenueCountByCityRow(city=city.slug, count=1)] == await _to_list(venue_queries.venue_count_by_city(db)) - assert [city] == await _to_list(city_queries.list_cities(db)) - assert [venue] == await _to_list(venue_queries.list_venues(db, city=city.slug)) + assert city == await city_querier.get_city(slug=city.slug) + assert [venue_queries.VenueCountByCityRow(city=city.slug, count=1)] == await _to_list(venue_querier.venue_count_by_city()) + assert [city] == await _to_list(city_querier.list_cities()) + assert [venue] == await _to_list(venue_querier.list_venues(city=city.slug)) - await city_queries.update_city_name(db, slug=city.slug, name="SF") - _id = await venue_queries.update_venue_name(db, slug=venue.slug, name="Fillmore") + await city_querier.update_city_name(slug=city.slug, name="SF") + _id = await venue_querier.update_venue_name(slug=venue.slug, name="Fillmore") assert _id == venue_id - await venue_queries.delete_venue(db, slug=venue.slug) + await venue_querier.delete_venue(slug=venue.slug) async def _to_list(it): diff --git a/internal/codegen/python/gen.go b/internal/codegen/python/gen.go index fda5374524..a30dd3ef50 100644 --- a/internal/codegen/python/gen.go +++ b/internal/codegen/python/gen.go @@ -60,19 +60,6 @@ type Struct struct { Comment string } -func (s Struct) DedupFields() []Field { - seen := map[string]struct{}{} - dedupFields := make([]Field, 0) - for _, f := range s.Fields { - if _, ok := seen[f.Name]; ok { - continue - } - seen[f.Name] = struct{}{} - dedupFields = append(dedupFields, f) - } - return dedupFields -} - type QueryValue struct { Emit bool Name string @@ -113,6 +100,19 @@ func (v QueryValue) Type() string { panic("no type for QueryValue: " + v.Name) } +func (v QueryValue) StructRowParser(rowVar string, indentCount int) string { + if !v.IsStruct() { + panic("StructRowParse called on non-struct QueryValue") + } + indent := strings.Repeat(" ", indentCount+4) + params := make([]string, 0, len(v.Struct.Fields)) + for i, f := range v.Struct.Fields { + params = append(params, fmt.Sprintf("%s%s=%s[%v],", indent, f.Name, rowVar, i)) + } + indent = strings.Repeat(" ", indentCount) + return v.Type() + "(\n" + strings.Join(params, "\n") + "\n" + indent + ")" +} + // A struct used to generate methods and fields on the Queries struct type Query struct { Cmd string @@ -127,6 +127,10 @@ type Query struct { } func (q Query) ArgPairs() string { + // A single struct arg does not need to be passed as a keyword argument + if len(q.Args) == 1 && q.Args[0].IsStruct() { + return ", " + q.Args[0].Pair() + } argPairs := make([]string, 0, len(q.Args)) for _, a := range q.Args { argPairs = append(argPairs, a.Pair()) @@ -134,27 +138,33 @@ func (q Query) ArgPairs() string { if len(argPairs) == 0 { return "" } - return ", " + strings.Join(argPairs, ", ") + return ", *, " + strings.Join(argPairs, ", ") } -func (q Query) ArgParams() string { +func (q Query) ArgDict() string { params := make([]string, 0, len(q.Args)) + i := 1 for _, a := range q.Args { if a.isEmpty() { continue } if a.IsStruct() { for _, f := range a.Struct.Fields { - params = append(params, a.Name+"."+f.Name) + params = append(params, fmt.Sprintf("\"p%v\": %s", i, a.Name+"."+f.Name)) + i++ } } else { - params = append(params, a.Name) + params = append(params, fmt.Sprintf("\"p%v\": %s", i, a.Name)) + i++ } } if len(params) == 0 { return "" } - return ", " + strings.Join(params, ", ") + if len(params) < 4 { + return ", {" + strings.Join(params, ", ") + "}" + } + return ", {\n " + strings.Join(params, ",\n ") + ",\n }" } func makePyType(r *compiler.Result, col *compiler.Column, settings config.CombinedSettings) pyType { @@ -356,6 +366,18 @@ func sameTableName(n *ast.TableName, f core.FQN, defaultSchema string) bool { return n.Catalog == f.Catalog && schema == f.Schema && n.Name == f.Rel } +var postgresPlaceholderRegexp = regexp.MustCompile(`\B\$(\d+)\b`) + +// Sqlalchemy uses ":name" for placeholders, so "$N" is converted to ":pN" +// This also means ":" has special meaning to sqlalchemy, so it must be escaped. +func sqlalchemySQL(s string, engine config.Engine) string { + s = strings.ReplaceAll(s, ":", `\\:`) + if engine == config.EnginePostgreSQL { + return postgresPlaceholderRegexp.ReplaceAllString(s, ":p$1") + } + return s +} + func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs []Struct) []Query { qs := make([]Query, 0, len(r.Queries)) for _, query := range r.Queries { @@ -374,7 +396,7 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs MethodName: methodName, FieldName: codegen.LowerTitle(query.Name) + "Stmt", ConstantName: strings.ToUpper(methodName), - SQL: query.SQL, + SQL: sqlalchemySQL(query.SQL, settings.Package.Engine), SourceName: query.Filename, } @@ -419,8 +441,11 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs same := true for i, f := range s.Fields { c := query.Columns[i] + // HACK: models do not have "models." on their types, so trim that so we can find matches + trimmedPyType := makePyType(r, c, settings) + trimmedPyType.InnerType = strings.TrimPrefix(trimmedPyType.InnerType, "models.") sameName := f.Name == columnName(c, i) - sameType := f.Type == makePyType(r, c, settings) + sameType := f.Type == trimmedPyType sameTable := sameTableName(c.Table, s.Table, r.Catalog.DefaultSchema) if !sameName || !sameType || !sameTable { same = false @@ -462,8 +487,7 @@ var modelsTmpl = `# Code generated by sqlc. DO NOT EDIT. {{- end}} -# Enums -{{- range .Enums}} +{{range .Enums}} {{- if .Comment}}{{comment .Comment}}{{- end}} class {{.Name}}(str, enum.Enum): {{- range .Constants}} @@ -471,10 +495,10 @@ class {{.Name}}(str, enum.Enum): {{- end}} {{end}} -# Models {{- range .Models}} -{{- if .Comment}}{{comment .Comment}}{{- end}} -class {{.Name}}(pydantic.BaseModel): {{- range .DedupFields}} +{{if .Comment}}{{comment .Comment}}{{- end}} +@dataclasses.dataclass() +class {{.Name}}: {{- range .Fields}} {{- if .Comment}} {{comment .Comment}}{{else}} {{- end}} @@ -484,124 +508,142 @@ class {{.Name}}(pydantic.BaseModel): {{- range .DedupFields}} {{end}} ` -var queriesTmpl = `# Code generated by sqlc. DO NOT EDIT. +var queriesTmpl = ` +{{- define "dataclassParse"}} + +{{end}} +# Code generated by sqlc. DO NOT EDIT. {{- range imports .SourceName}} {{.}} {{- end}} {{range .Queries}} {{- if $.OutputQuery .SourceName}} -{{.ConstantName}} = """-- name: {{.MethodName}} {{.Cmd}} +{{.ConstantName}} = """-- name: {{.MethodName}} \\{{.Cmd}} {{.SQL}} """ {{range .Args}} {{- if .EmitStruct}} -class {{.Type}}(pydantic.BaseModel): {{- range .Struct.DedupFields}} +@dataclasses.dataclass() +class {{.Type}}: {{- range .Struct.Fields}} {{.Name}}: {{.Type}} {{- end}} {{end}}{{end}} {{- if .Ret.EmitStruct}} -class {{.Ret.Type}}(pydantic.BaseModel): {{- range .Ret.Struct.DedupFields}} +@dataclasses.dataclass() +class {{.Ret.Type}}: {{- range .Ret.Struct.Fields}} {{.Name}}: {{.Type}} {{- end}} {{end}} {{end}} {{- end}} -{{- range .Queries}} +{{- if .EmitSync}} +class Querier: + def __init__(self, conn: sqlalchemy.engine.Connection): + self._conn = conn +{{range .Queries}} {{- if $.OutputQuery .SourceName}} {{- if eq .Cmd ":one"}} -@overload -def {{.MethodName}}(conn: sqlc.Connection{{.ArgPairs}}) -> Optional[{{.Ret.Type}}]: - pass - - -@overload -def {{.MethodName}}(conn: sqlc.AsyncConnection{{.ArgPairs}}) -> Awaitable[Optional[{{.Ret.Type}}]]: - pass - - -def {{.MethodName}}(conn: sqlc.GenericConnection{{.ArgPairs}}) -> sqlc.ReturnType[Optional[{{.Ret.Type}}]]: - {{- if .Ret.IsStruct}} - return conn.execute_one_model({{.Ret.Type}}, {{.ConstantName}}{{.ArgParams}}) - {{- else}} - return conn.execute_one({{.ConstantName}}{{.ArgParams}}) - {{- end}} + def {{.MethodName}}(self{{.ArgPairs}}) -> Optional[{{.Ret.Type}}]: + row = self._conn.execute(sqlalchemy.text({{.ConstantName}}){{.ArgDict}}).first() + if row is None: + return None + {{- if .Ret.IsStruct}} + return {{.Ret.StructRowParser "row" 8}} + {{- else}} + return row[0] + {{- end}} {{end}} {{- if eq .Cmd ":many"}} -@overload -def {{.MethodName}}(conn: sqlc.Connection{{.ArgPairs}}) -> Iterator[{{.Ret.Type}}]: - pass - - -@overload -def {{.MethodName}}(conn: sqlc.AsyncConnection{{.ArgPairs}}) -> AsyncIterator[{{.Ret.Type}}]: - pass - - -def {{.MethodName}}(conn: sqlc.GenericConnection{{.ArgPairs}}) -> sqlc.IteratorReturn[{{.Ret.Type}}]: - {{- if .Ret.IsStruct}} - return conn.execute_many_model({{.Ret.Type}}, {{.ConstantName}}{{.ArgParams}}) - {{- else}} - return conn.execute_many({{.ConstantName}}{{.ArgParams}}) - {{- end}} + def {{.MethodName}}(self{{.ArgPairs}}) -> Iterator[{{.Ret.Type}}]: + result = self._conn.execute(sqlalchemy.text({{.ConstantName}}){{.ArgDict}}) + for row in result: + {{- if .Ret.IsStruct}} + yield {{.Ret.StructRowParser "row" 12}} + {{- else}} + yield row[0] + {{- end}} {{end}} {{- if eq .Cmd ":exec"}} -@overload -def {{.MethodName}}(conn: sqlc.Connection{{.ArgPairs}}) -> None: - pass - - -@overload -def {{.MethodName}}(conn: sqlc.AsyncConnection{{.ArgPairs}}) -> Awaitable[None]: - pass - - -def {{.MethodName}}(conn: sqlc.GenericConnection{{.ArgPairs}}) -> sqlc.ReturnType[None]: - return conn.execute_none({{.ConstantName}}{{.ArgParams}}) + def {{.MethodName}}(self{{.ArgPairs}}) -> None: + self._conn.execute(sqlalchemy.text({{.ConstantName}}){{.ArgDict}}) {{end}} {{- if eq .Cmd ":execrows"}} -@overload -def {{.MethodName}}(conn: sqlc.Connection{{.ArgPairs}}) -> int: - pass - - -@overload -def {{.MethodName}}(conn: sqlc.AsyncConnection{{.ArgPairs}}) -> Awaitable[int]: - pass - - -def {{.MethodName}}(conn: sqlc.GenericConnection{{.ArgPairs}}) -> sqlc.ReturnType[int]: - return conn.execute_rowcount({{.ConstantName}}{{.ArgParams}}) + def {{.MethodName}}(self{{.ArgPairs}}) -> int: + result = self._conn.execute(sqlalchemy.text({{.ConstantName}}){{.ArgDict}}) + return result.rowcount {{end}} {{- if eq .Cmd ":execresult"}} -@overload -def {{.MethodName}}(conn: sqlc.Connection{{.ArgPairs}}) -> sqlc.Cursor: - pass + def {{.MethodName}}(self{{.ArgPairs}}) -> sqlalchemy.engine.Result: + return self._conn.execute(sqlalchemy.text({{.ConstantName}}){{.ArgDict}}) +{{end}} +{{- end}} +{{- end}} +{{- end}} +{{- if .EmitAsync}} -@overload -def {{.MethodName}}(conn: sqlc.AsyncConnection{{.ArgPairs}}) -> sqlc.AsyncCursor: - pass +class AsyncQuerier: + def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection): + self._conn = conn +{{range .Queries}} +{{- if $.OutputQuery .SourceName}} +{{- if eq .Cmd ":one"}} + async def {{.MethodName}}(self{{.ArgPairs}}) -> Optional[{{.Ret.Type}}]: + row = (await self._conn.execute(sqlalchemy.text({{.ConstantName}}){{.ArgDict}})).first() + if row is None: + return None + {{- if .Ret.IsStruct}} + return {{.Ret.StructRowParser "row" 8}} + {{- else}} + return row[0] + {{- end}} +{{end}} +{{- if eq .Cmd ":many"}} + async def {{.MethodName}}(self{{.ArgPairs}}) -> AsyncIterator[{{.Ret.Type}}]: + result = await self._conn.stream(sqlalchemy.text({{.ConstantName}}){{.ArgDict}}) + async for row in result: + {{- if .Ret.IsStruct}} + yield {{.Ret.StructRowParser "row" 12}} + {{- else}} + yield row[0] + {{- end}} +{{end}} -def {{.MethodName}}(conn: sqlc.GenericConnection{{.ArgPairs}}) -> sqlc.GenericCursor: - return conn.execute({{.ConstantName}}{{.ArgParams}}) +{{- if eq .Cmd ":exec"}} + async def {{.MethodName}}(self{{.ArgPairs}}) -> None: + await self._conn.execute(sqlalchemy.text({{.ConstantName}}){{.ArgDict}}) {{end}} + +{{- if eq .Cmd ":execrows"}} + async def {{.MethodName}}(self{{.ArgPairs}}) -> int: + result = await self._conn.execute(sqlalchemy.text({{.ConstantName}}){{.ArgDict}}) + return result.rowcount +{{end}} + +{{- if eq .Cmd ":execresult"}} + async def {{.MethodName}}(self{{.ArgPairs}}) -> sqlalchemy.engine.Result: + return await self._conn.execute(sqlalchemy.text({{.ConstantName}}){{.ArgDict}}) {{end}} {{- end}} +{{- end}} +{{- end}} ` type pyTmplCtx struct { Models []Struct Queries []Query Enums []Enum + EmitSync bool + EmitAsync bool SourceName string } @@ -635,9 +677,11 @@ func Generate(r *compiler.Result, settings config.CombinedSettings) (map[string] queriesFile := template.Must(template.New("table").Funcs(funcMap).Parse(queriesTmpl)) tctx := pyTmplCtx{ - Models: models, - Queries: queries, - Enums: enums, + Models: models, + Queries: queries, + Enums: enums, + EmitSync: settings.Python.EmitSyncQuerier, + EmitAsync: settings.Python.EmitAsyncQuerier, } output := map[string]string{} diff --git a/internal/codegen/python/imports.go b/internal/codegen/python/imports.go index 493284a564..dfce83a085 100644 --- a/internal/codegen/python/imports.go +++ b/internal/codegen/python/imports.go @@ -92,7 +92,7 @@ func (i *importer) modelImports() []string { } pkg := make(map[string]importSpec) - pkg["pydantic"] = importSpec{Module: "pydantic"} + pkg["dataclasses"] = importSpec{Module: "dataclasses"} for _, o := range i.Settings.Overrides { if o.PythonType.IsSet() && o.PythonType.Module != "" { @@ -129,11 +129,12 @@ func (i *importer) queryImports(fileName string) []string { } std := stdImports(queryUses) - std["typing.overload"] = importSpec{Module: "typing", Name: "overload"} - std["typing.Awaitable"] = importSpec{Module: "typing", Name: "Awaitable"} pkg := make(map[string]importSpec) - pkg["sqlc_runtime"] = importSpec{Module: "sqlc_runtime", Alias: "sqlc"} + pkg["sqlalchemy"] = importSpec{Module: "sqlalchemy"} + if i.Settings.Python.EmitAsyncQuerier { + pkg["sqlalchemy.ext.asyncio"] = importSpec{Module: "sqlalchemy.ext.asyncio"} + } for _, o := range i.Settings.Overrides { if o.PythonType.IsSet() && o.PythonType.Module != "" { @@ -145,7 +146,7 @@ func (i *importer) queryImports(fileName string) []string { queryValueModelImports := func(qv QueryValue) { if qv.IsStruct() && qv.EmitStruct() { - pkg["pydantic"] = importSpec{Module: "pydantic"} + pkg["dataclasses"] = importSpec{Module: "dataclasses"} } } @@ -157,8 +158,12 @@ func (i *importer) queryImports(fileName string) []string { std["typing.Optional"] = importSpec{Module: "typing", Name: "Optional"} } if q.Cmd == ":many" { - std["typing.Iterator"] = importSpec{Module: "typing", Name: "Iterator"} - std["typing.AsyncIterator"] = importSpec{Module: "typing", Name: "AsyncIterator"} + if i.Settings.Python.EmitSyncQuerier { + std["typing.Iterator"] = importSpec{Module: "typing", Name: "Iterator"} + } + if i.Settings.Python.EmitAsyncQuerier { + std["typing.AsyncIterator"] = importSpec{Module: "typing", Name: "AsyncIterator"} + } } queryValueModelImports(q.Ret) for _, qv := range q.Args { diff --git a/internal/config/config.go b/internal/config/config.go index 54001049e8..bdfba2d28d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -130,10 +130,12 @@ type SQLKotlin struct { } type SQLPython struct { - EmitExactTableNames bool `json:"emit_exact_table_names" yaml:"emit_exact_table_names"` - Package string `json:"package" yaml:"package"` - Out string `json:"out" yaml:"out"` - Overrides []Override `json:"overrides,omitempty" yaml:"overrides"` + EmitExactTableNames bool `json:"emit_exact_table_names" yaml:"emit_exact_table_names"` + EmitSyncQuerier bool `json:"emit_sync_querier" yaml:"emit_sync_querier"` + EmitAsyncQuerier bool `json:"emit_async_querier" yaml:"emit_async_querier"` + Package string `json:"package" yaml:"package"` + Out string `json:"out" yaml:"out"` + Overrides []Override `json:"overrides,omitempty" yaml:"overrides"` } type Override struct { @@ -229,7 +231,8 @@ var ErrUnknownEngine = errors.New("invalid engine") var ErrNoPackages = errors.New("no packages") var ErrNoPackageName = errors.New("missing package name") var ErrNoPackagePath = errors.New("missing package path") -var ErrKotlinNoOutPath = errors.New("no output path") +var ErrNoOutPath = errors.New("no output path") +var ErrNoQuerierType = errors.New("no querier emit type enabled") func ParseConfig(rd io.Reader) (Config, error) { var buf bytes.Buffer diff --git a/internal/config/v_two.go b/internal/config/v_two.go index 73a699dc7a..3f0db3cda4 100644 --- a/internal/config/v_two.go +++ b/internal/config/v_two.go @@ -53,13 +53,22 @@ func v2ParseConfig(rd io.Reader) (Config, error) { } if conf.SQL[j].Gen.Kotlin != nil { if conf.SQL[j].Gen.Kotlin.Out == "" { - return conf, ErrKotlinNoOutPath + return conf, ErrNoOutPath } if conf.SQL[j].Gen.Kotlin.Package == "" { return conf, ErrNoPackageName } } if conf.SQL[j].Gen.Python != nil { + if conf.SQL[j].Gen.Python.Out == "" { + return conf, ErrNoOutPath + } + if conf.SQL[j].Gen.Python.Package == "" { + return conf, ErrNoPackageName + } + if !conf.SQL[j].Gen.Python.EmitSyncQuerier && !conf.SQL[j].Gen.Python.EmitAsyncQuerier { + return conf, ErrNoQuerierType + } for i := range conf.SQL[j].Gen.Python.Overrides { if err := conf.SQL[j].Gen.Python.Overrides[i].Parse(); err != nil { return conf, err