diff --git a/.github/workflows/ci-python.yml b/.github/workflows/ci-python.yml new file mode 100644 index 0000000000..639827723d --- /dev/null +++ b/.github/workflows/ci-python.yml @@ -0,0 +1,40 @@ +name: python +on: [push, pull_request] +jobs: + + build: + name: test + runs-on: ubuntu-latest + + services: + postgres: + image: postgres:11 + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: postgres + ports: + - 5432:5432 + # needed because the postgres container does not provide a healthcheck + options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 + + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: 3.9 + - name: Install python dependencies + working-directory: ./examples/python + run: | + python -m pip install --upgrade pip + python -m pip install -r requirements.txt + - name: Test python code + working-directory: ./examples/python + env: + PG_USER: postgres + PG_HOST: localhost + PG_DATABASE: postgres + PG_PASSWORD: postgres + PG_PORT: ${{ job.services.postgres.ports['5432'] }} + run: | + pytest src/tests diff --git a/examples/python/requirements.txt b/examples/python/requirements.txt new file mode 100644 index 0000000000..c22530eaa9 --- /dev/null +++ b/examples/python/requirements.txt @@ -0,0 +1,6 @@ +pytest~=6.2.2 +pytest-asyncio~=0.14.0 +psycopg2-binary~=2.8.6 +asyncpg~=0.21.0 +pydantic~=1.7.3 +sqlc-python-runtime~=1.0.0 diff --git a/examples/python/sqlc.json b/examples/python/sqlc.json new file mode 100644 index 0000000000..ba987b5af1 --- /dev/null +++ b/examples/python/sqlc.json @@ -0,0 +1,49 @@ +{ + "version": "2", + "sql": [ + { + "schema": "../authors/postgresql/schema.sql", + "queries": "../authors/postgresql/query.sql", + "engine": "postgresql", + "gen": { + "python": { + "out": "src/authors", + "package": "authors" + } + } + }, + { + "schema": "../booktest/postgresql/schema.sql", + "queries": "../booktest/postgresql/query.sql", + "engine": "postgresql", + "gen": { + "python": { + "out": "src/booktest", + "package": "booktest" + } + } + }, + { + "schema": "../jets/schema.sql", + "queries": "../jets/query-building.sql", + "engine": "postgresql", + "gen": { + "python": { + "out": "src/jets", + "package": "jets" + } + } + }, + { + "schema": "../ondeck/postgresql/schema", + "queries": "../ondeck/postgresql/query", + "engine": "postgresql", + "gen": { + "python": { + "out": "src/ondeck", + "package": "ondeck" + } + } + } + ] +} diff --git a/examples/python/src/authors/__init__.py b/examples/python/src/authors/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/python/src/authors/models.py b/examples/python/src/authors/models.py new file mode 100644 index 0000000000..b282d77dac --- /dev/null +++ b/examples/python/src/authors/models.py @@ -0,0 +1,15 @@ +# Code generated by sqlc. DO NOT EDIT. +from typing import Optional + +import pydantic + + +# Enums + +# Models +class Author(pydantic.BaseModel): + id: int + name: str + bio: Optional[str] + + diff --git a/examples/python/src/authors/query.py b/examples/python/src/authors/query.py new file mode 100644 index 0000000000..947e535309 --- /dev/null +++ b/examples/python/src/authors/query.py @@ -0,0 +1,92 @@ +# Code generated by sqlc. DO NOT EDIT. +from typing import AsyncIterator, Awaitable, Iterator, Optional, overload + +import sqlc_runtime as sqlc + +from authors import models + + +CREATE_AUTHOR = """-- name: create_author :one +INSERT INTO authors ( + name, bio +) VALUES ( + $1, $2 +) +RETURNING id, name, bio +""" + + +DELETE_AUTHOR = """-- name: delete_author :exec +DELETE FROM authors +WHERE id = $1 +""" + + +GET_AUTHOR = """-- name: get_author :one +SELECT id, name, bio FROM authors +WHERE id = $1 LIMIT 1 +""" + + +LIST_AUTHORS = """-- name: list_authors :many +SELECT id, name, bio FROM authors +ORDER BY name +""" + + +@overload +def create_author(conn: sqlc.Connection, name: str, bio: Optional[str]) -> Optional[models.Author]: + pass + + +@overload +def create_author(conn: sqlc.AsyncConnection, name: str, bio: Optional[str]) -> Awaitable[Optional[models.Author]]: + pass + + +def create_author(conn: sqlc.GenericConnection, name: str, bio: Optional[str]) -> sqlc.ReturnType[Optional[models.Author]]: + return conn.execute_one_model(models.Author, CREATE_AUTHOR, name, bio) + + +@overload +def delete_author(conn: sqlc.Connection, id: int) -> None: + pass + + +@overload +def delete_author(conn: sqlc.AsyncConnection, id: int) -> Awaitable[None]: + pass + + +def delete_author(conn: sqlc.GenericConnection, id: int) -> sqlc.ReturnType[None]: + return conn.execute_none(DELETE_AUTHOR, id) + + +@overload +def get_author(conn: sqlc.Connection, id: int) -> Optional[models.Author]: + pass + + +@overload +def get_author(conn: sqlc.AsyncConnection, id: int) -> Awaitable[Optional[models.Author]]: + pass + + +def get_author(conn: sqlc.GenericConnection, id: int) -> sqlc.ReturnType[Optional[models.Author]]: + return conn.execute_one_model(models.Author, GET_AUTHOR, id) + + +@overload +def list_authors(conn: sqlc.Connection) -> Iterator[models.Author]: + pass + + +@overload +def list_authors(conn: sqlc.AsyncConnection) -> AsyncIterator[models.Author]: + pass + + +def list_authors(conn: sqlc.GenericConnection) -> sqlc.IteratorReturn[models.Author]: + return conn.execute_many_model(models.Author, LIST_AUTHORS) + + diff --git a/examples/python/src/booktest/__init__.py b/examples/python/src/booktest/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/python/src/booktest/models.py b/examples/python/src/booktest/models.py new file mode 100644 index 0000000000..9c7d1a1970 --- /dev/null +++ b/examples/python/src/booktest/models.py @@ -0,0 +1,31 @@ +# Code generated by sqlc. DO NOT EDIT. +from typing import List +import datetime +import enum + +import pydantic + + +# Enums +class BookType(str, enum.Enum): + FICTION = "FICTION" + NONFICTION = "NONFICTION" + + +# Models +class Author(pydantic.BaseModel): + author_id: int + name: str + + +class Book(pydantic.BaseModel): + book_id: int + author_id: int + isbn: str + book_type: BookType + title: str + year: int + available: datetime.datetime + tags: List[str] + + diff --git a/examples/python/src/booktest/query.py b/examples/python/src/booktest/query.py new file mode 100644 index 0000000000..3081b2a598 --- /dev/null +++ b/examples/python/src/booktest/query.py @@ -0,0 +1,266 @@ +# Code generated by sqlc. DO NOT EDIT. +from typing import AsyncIterator, Awaitable, Iterator, List, Optional, overload +import datetime + +import pydantic +import sqlc_runtime as sqlc + +from booktest import models + + +BOOKS_BY_TAGS = """-- name: books_by_tags :many +SELECT + book_id, + title, + name, + isbn, + tags +FROM books +LEFT JOIN authors ON books.author_id = authors.author_id +WHERE tags && $1::varchar[] +""" + + +class BooksByTagsRow(pydantic.BaseModel): + book_id: int + title: str + name: str + isbn: str + tags: List[str] + + +BOOKS_BY_TITLE_YEAR = """-- name: books_by_title_year :many +SELECT book_id, author_id, isbn, book_type, title, year, available, tags FROM books +WHERE title = $1 AND year = $2 +""" + + +class BooksByTitleYearRow(pydantic.BaseModel): + book_id: int + author_id: int + isbn: str + book_type: models.BookType + title: str + year: int + available: datetime.datetime + tags: List[str] + + +CREATE_AUTHOR = """-- name: create_author :one +INSERT INTO authors (name) VALUES ($1) +RETURNING author_id, name +""" + + +CREATE_BOOK = """-- name: create_book :one +INSERT INTO books ( + author_id, + isbn, + book_type, + title, + year, + available, + tags +) VALUES ( + $1, + $2, + $3, + $4, + $5, + $6, + $7 +) +RETURNING book_id, author_id, isbn, book_type, title, year, available, tags +""" + + +class CreateBookParams(pydantic.BaseModel): + author_id: int + isbn: str + book_type: models.BookType + title: str + year: int + available: datetime.datetime + tags: List[str] + + +class CreateBookRow(pydantic.BaseModel): + book_id: int + author_id: int + isbn: str + book_type: models.BookType + title: str + year: int + available: datetime.datetime + tags: List[str] + + +DELETE_BOOK = """-- name: delete_book :exec +DELETE FROM books +WHERE book_id = $1 +""" + + +GET_AUTHOR = """-- name: get_author :one +SELECT author_id, name FROM authors +WHERE author_id = $1 +""" + + +GET_BOOK = """-- name: get_book :one +SELECT book_id, author_id, isbn, book_type, title, year, available, tags FROM books +WHERE book_id = $1 +""" + + +class GetBookRow(pydantic.BaseModel): + book_id: int + author_id: int + isbn: str + book_type: models.BookType + title: str + year: int + available: datetime.datetime + tags: List[str] + + +UPDATE_BOOK = """-- name: update_book :exec +UPDATE books +SET title = $1, tags = $2 +WHERE book_id = $3 +""" + + +UPDATE_BOOK_ISBN = """-- name: update_book_isbn :exec +UPDATE books +SET title = $1, tags = $2, isbn = $4 +WHERE book_id = $3 +""" + + +@overload +def books_by_tags(conn: sqlc.Connection, dollar_1: List[str]) -> Iterator[BooksByTagsRow]: + pass + + +@overload +def books_by_tags(conn: sqlc.AsyncConnection, dollar_1: List[str]) -> AsyncIterator[BooksByTagsRow]: + pass + + +def books_by_tags(conn: sqlc.GenericConnection, dollar_1: List[str]) -> sqlc.IteratorReturn[BooksByTagsRow]: + return conn.execute_many_model(BooksByTagsRow, BOOKS_BY_TAGS, dollar_1) + + +@overload +def books_by_title_year(conn: sqlc.Connection, title: str, year: int) -> Iterator[BooksByTitleYearRow]: + pass + + +@overload +def books_by_title_year(conn: sqlc.AsyncConnection, title: str, year: int) -> AsyncIterator[BooksByTitleYearRow]: + pass + + +def books_by_title_year(conn: sqlc.GenericConnection, title: str, year: int) -> sqlc.IteratorReturn[BooksByTitleYearRow]: + return conn.execute_many_model(BooksByTitleYearRow, BOOKS_BY_TITLE_YEAR, title, year) + + +@overload +def create_author(conn: sqlc.Connection, name: str) -> Optional[models.Author]: + pass + + +@overload +def create_author(conn: sqlc.AsyncConnection, name: str) -> Awaitable[Optional[models.Author]]: + pass + + +def create_author(conn: sqlc.GenericConnection, name: str) -> sqlc.ReturnType[Optional[models.Author]]: + return conn.execute_one_model(models.Author, CREATE_AUTHOR, name) + + +@overload +def create_book(conn: sqlc.Connection, arg: CreateBookParams) -> Optional[CreateBookRow]: + pass + + +@overload +def create_book(conn: sqlc.AsyncConnection, arg: CreateBookParams) -> Awaitable[Optional[CreateBookRow]]: + pass + + +def create_book(conn: sqlc.GenericConnection, arg: CreateBookParams) -> sqlc.ReturnType[Optional[CreateBookRow]]: + return conn.execute_one_model(CreateBookRow, CREATE_BOOK, arg.author_id, arg.isbn, arg.book_type, arg.title, arg.year, arg.available, arg.tags) + + +@overload +def delete_book(conn: sqlc.Connection, book_id: int) -> None: + pass + + +@overload +def delete_book(conn: sqlc.AsyncConnection, book_id: int) -> Awaitable[None]: + pass + + +def delete_book(conn: sqlc.GenericConnection, book_id: int) -> sqlc.ReturnType[None]: + return conn.execute_none(DELETE_BOOK, book_id) + + +@overload +def get_author(conn: sqlc.Connection, author_id: int) -> Optional[models.Author]: + pass + + +@overload +def get_author(conn: sqlc.AsyncConnection, author_id: int) -> Awaitable[Optional[models.Author]]: + pass + + +def get_author(conn: sqlc.GenericConnection, author_id: int) -> sqlc.ReturnType[Optional[models.Author]]: + return conn.execute_one_model(models.Author, GET_AUTHOR, author_id) + + +@overload +def get_book(conn: sqlc.Connection, book_id: int) -> Optional[GetBookRow]: + pass + + +@overload +def get_book(conn: sqlc.AsyncConnection, book_id: int) -> Awaitable[Optional[GetBookRow]]: + pass + + +def get_book(conn: sqlc.GenericConnection, book_id: int) -> sqlc.ReturnType[Optional[GetBookRow]]: + return conn.execute_one_model(GetBookRow, GET_BOOK, book_id) + + +@overload +def update_book(conn: sqlc.Connection, title: str, tags: List[str], book_id: int) -> None: + pass + + +@overload +def update_book(conn: sqlc.AsyncConnection, title: str, tags: List[str], book_id: int) -> Awaitable[None]: + pass + + +def update_book(conn: sqlc.GenericConnection, title: str, tags: List[str], book_id: int) -> sqlc.ReturnType[None]: + return conn.execute_none(UPDATE_BOOK, title, tags, book_id) + + +@overload +def update_book_isbn(conn: sqlc.Connection, title: str, tags: List[str], book_id: int, isbn: str) -> None: + pass + + +@overload +def update_book_isbn(conn: sqlc.AsyncConnection, title: str, tags: List[str], book_id: int, isbn: str) -> Awaitable[None]: + pass + + +def update_book_isbn(conn: sqlc.GenericConnection, title: str, tags: List[str], book_id: int, isbn: str) -> sqlc.ReturnType[None]: + return conn.execute_none(UPDATE_BOOK_ISBN, title, tags, book_id, isbn) + + diff --git a/examples/python/src/dbtest/__init__.py b/examples/python/src/dbtest/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/python/src/dbtest/migrations.py b/examples/python/src/dbtest/migrations.py new file mode 100644 index 0000000000..c3c72b78a6 --- /dev/null +++ b/examples/python/src/dbtest/migrations.py @@ -0,0 +1,41 @@ +import os +from typing import List + +import asyncpg +import psycopg2.extensions + + +def apply_migrations(db: psycopg2.extensions.connection, paths: List[str]): + files = _find_sql_files(paths) + + for file in files: + with open(file, "r") as fd: + blob = fd.read() + cur = db.cursor() + cur.execute(blob) + cur.close() + db.commit() + + +async def apply_migrations_async(db: asyncpg.Connection, paths: List[str]): + files = _find_sql_files(paths) + + for file in files: + with open(file, "r") as fd: + blob = fd.read() + await db.execute(blob) + + +def _find_sql_files(paths: List[str]) -> List[str]: + files = [] + for path in paths: + if not os.path.exists(path): + raise FileNotFoundError(f"{path} does not exist") + if os.path.isdir(path): + for file in os.listdir(path): + if file.endswith(".sql"): + files.append(os.path.join(path, file)) + else: + files.append(path) + files.sort() + return files diff --git a/examples/python/src/jets/__init__.py b/examples/python/src/jets/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/python/src/jets/models.py b/examples/python/src/jets/models.py new file mode 100644 index 0000000000..cb543dcb22 --- /dev/null +++ b/examples/python/src/jets/models.py @@ -0,0 +1,32 @@ +# Code generated by sqlc. DO NOT EDIT. + + +import pydantic + + +# Enums + +# Models +class Jet(pydantic.BaseModel): + id: int + pilot_id: int + age: int + name: str + color: str + + +class Language(pydantic.BaseModel): + id: int + language: str + + +class Pilot(pydantic.BaseModel): + id: int + name: str + + +class PilotLanguage(pydantic.BaseModel): + pilot_id: int + language_id: int + + diff --git a/examples/python/src/jets/query-building.py b/examples/python/src/jets/query-building.py new file mode 100644 index 0000000000..21056fe10c --- /dev/null +++ b/examples/python/src/jets/query-building.py @@ -0,0 +1,65 @@ +# Code generated by sqlc. DO NOT EDIT. +from typing import AsyncIterator, Awaitable, Iterator, Optional, overload + +import sqlc_runtime as sqlc + +from jets import models + + +COUNT_PILOTS = """-- name: count_pilots :one +SELECT COUNT(*) FROM pilots +""" + + +DELETE_PILOT = """-- name: delete_pilot :exec +DELETE FROM pilots WHERE id = $1 +""" + + +LIST_PILOTS = """-- name: list_pilots :many +SELECT id, name FROM pilots LIMIT 5 +""" + + +@overload +def count_pilots(conn: sqlc.Connection) -> Optional[int]: + pass + + +@overload +def count_pilots(conn: sqlc.AsyncConnection) -> Awaitable[Optional[int]]: + pass + + +def count_pilots(conn: sqlc.GenericConnection) -> sqlc.ReturnType[Optional[int]]: + return conn.execute_one(COUNT_PILOTS) + + +@overload +def delete_pilot(conn: sqlc.Connection, id: int) -> None: + pass + + +@overload +def delete_pilot(conn: sqlc.AsyncConnection, id: int) -> Awaitable[None]: + pass + + +def delete_pilot(conn: sqlc.GenericConnection, id: int) -> sqlc.ReturnType[None]: + return conn.execute_none(DELETE_PILOT, id) + + +@overload +def list_pilots(conn: sqlc.Connection) -> Iterator[models.Pilot]: + pass + + +@overload +def list_pilots(conn: sqlc.AsyncConnection) -> AsyncIterator[models.Pilot]: + pass + + +def list_pilots(conn: sqlc.GenericConnection) -> sqlc.IteratorReturn[models.Pilot]: + return conn.execute_many_model(models.Pilot, LIST_PILOTS) + + diff --git a/examples/python/src/ondeck/__init__.py b/examples/python/src/ondeck/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/python/src/ondeck/city.py b/examples/python/src/ondeck/city.py new file mode 100644 index 0000000000..baf0051e9c --- /dev/null +++ b/examples/python/src/ondeck/city.py @@ -0,0 +1,96 @@ +# Code generated by sqlc. DO NOT EDIT. +from typing import AsyncIterator, Awaitable, Iterator, Optional, overload + +import sqlc_runtime as sqlc + +from ondeck import models + + +CREATE_CITY = """-- name: create_city :one +INSERT INTO city ( + name, + slug +) VALUES ( + $1, + $2 +) RETURNING slug, name +""" + + +GET_CITY = """-- name: get_city :one +SELECT slug, name +FROM city +WHERE slug = $1 +""" + + +LIST_CITIES = """-- name: list_cities :many +SELECT slug, name +FROM city +ORDER BY name +""" + + +UPDATE_CITY_NAME = """-- name: update_city_name :exec +UPDATE city +SET name = $2 +WHERE slug = $1 +""" + + +@overload +def create_city(conn: sqlc.Connection, name: str, slug: str) -> Optional[models.City]: + pass + + +@overload +def create_city(conn: sqlc.AsyncConnection, name: str, slug: str) -> Awaitable[Optional[models.City]]: + pass + + +def create_city(conn: sqlc.GenericConnection, name: str, slug: str) -> sqlc.ReturnType[Optional[models.City]]: + return conn.execute_one_model(models.City, CREATE_CITY, name, slug) + + +@overload +def get_city(conn: sqlc.Connection, slug: str) -> Optional[models.City]: + pass + + +@overload +def get_city(conn: sqlc.AsyncConnection, slug: str) -> Awaitable[Optional[models.City]]: + pass + + +def get_city(conn: sqlc.GenericConnection, slug: str) -> sqlc.ReturnType[Optional[models.City]]: + return conn.execute_one_model(models.City, GET_CITY, slug) + + +@overload +def list_cities(conn: sqlc.Connection) -> Iterator[models.City]: + pass + + +@overload +def list_cities(conn: sqlc.AsyncConnection) -> AsyncIterator[models.City]: + pass + + +def list_cities(conn: sqlc.GenericConnection) -> sqlc.IteratorReturn[models.City]: + return conn.execute_many_model(models.City, LIST_CITIES) + + +@overload +def update_city_name(conn: sqlc.Connection, slug: str, name: str) -> None: + pass + + +@overload +def update_city_name(conn: sqlc.AsyncConnection, slug: str, name: str) -> Awaitable[None]: + pass + + +def update_city_name(conn: sqlc.GenericConnection, slug: str, name: str) -> sqlc.ReturnType[None]: + return conn.execute_none(UPDATE_CITY_NAME, slug, name) + + diff --git a/examples/python/src/ondeck/models.py b/examples/python/src/ondeck/models.py new file mode 100644 index 0000000000..bc8244b2a3 --- /dev/null +++ b/examples/python/src/ondeck/models.py @@ -0,0 +1,34 @@ +# Code generated by sqlc. DO NOT EDIT. +from typing import List, Optional +import datetime +import enum + +import pydantic + + +# Enums# Venues can be either open or closed +class Status(str, enum.Enum): + OPEN = "op!en" + CLOSED = "clo@sed" + + +# Models +class City(pydantic.BaseModel): + slug: str + name: str + +# Venues are places where muisc happens +class Venue(pydantic.BaseModel): + id: int + status: Status + statuses: Optional[List[Status]] + # This value appears in public URLs + slug: str + name: str + city: str + spotify_playlist: str + songkick_id: Optional[str] + tags: Optional[List[str]] + created_at: datetime.datetime + + diff --git a/examples/python/src/ondeck/venue.py b/examples/python/src/ondeck/venue.py new file mode 100644 index 0000000000..289725e0d4 --- /dev/null +++ b/examples/python/src/ondeck/venue.py @@ -0,0 +1,197 @@ +# Code generated by sqlc. DO NOT EDIT. +from typing import AsyncIterator, Awaitable, Iterator, List, Optional, overload +import datetime + +import pydantic +import sqlc_runtime as sqlc + +from ondeck import models + + +CREATE_VENUE = """-- name: create_venue :one +INSERT INTO venue ( + slug, + name, + city, + created_at, + spotify_playlist, + status, + statuses, + tags +) VALUES ( + $1, + $2, + $3, + NOW(), + $4, + $5, + $6, + $7 +) RETURNING id +""" + + +class CreateVenueParams(pydantic.BaseModel): + slug: str + name: str + city: str + spotify_playlist: str + status: models.Status + statuses: Optional[List[models.Status]] + tags: Optional[List[str]] + + +DELETE_VENUE = """-- name: delete_venue :exec +DELETE FROM venue +WHERE slug = $1 AND slug = $1 +""" + + +GET_VENUE = """-- name: get_venue :one +SELECT id, status, statuses, slug, name, city, spotify_playlist, songkick_id, tags, created_at +FROM venue +WHERE slug = $1 AND city = $2 +""" + + +class GetVenueRow(pydantic.BaseModel): + id: int + status: models.Status + statuses: Optional[List[models.Status]] + slug: str + name: str + city: str + spotify_playlist: str + songkick_id: Optional[str] + tags: Optional[List[str]] + created_at: datetime.datetime + + +LIST_VENUES = """-- name: list_venues :many +SELECT id, status, statuses, slug, name, city, spotify_playlist, songkick_id, tags, created_at +FROM venue +WHERE city = $1 +ORDER BY name +""" + + +class ListVenuesRow(pydantic.BaseModel): + id: int + status: models.Status + statuses: Optional[List[models.Status]] + slug: str + name: str + city: str + spotify_playlist: str + songkick_id: Optional[str] + tags: Optional[List[str]] + created_at: datetime.datetime + + +UPDATE_VENUE_NAME = """-- name: update_venue_name :one +UPDATE venue +SET name = $2 +WHERE slug = $1 +RETURNING id +""" + + +VENUE_COUNT_BY_CITY = """-- name: venue_count_by_city :many +SELECT + city, + count(*) +FROM venue +GROUP BY 1 +ORDER BY 1 +""" + + +class VenueCountByCityRow(pydantic.BaseModel): + city: str + count: int + + +@overload +def create_venue(conn: sqlc.Connection, arg: CreateVenueParams) -> Optional[int]: + pass + + +@overload +def create_venue(conn: sqlc.AsyncConnection, arg: CreateVenueParams) -> Awaitable[Optional[int]]: + pass + + +def create_venue(conn: sqlc.GenericConnection, arg: CreateVenueParams) -> sqlc.ReturnType[Optional[int]]: + return conn.execute_one(CREATE_VENUE, arg.slug, arg.name, arg.city, arg.spotify_playlist, arg.status, arg.statuses, arg.tags) + + +@overload +def delete_venue(conn: sqlc.Connection, slug: str) -> None: + pass + + +@overload +def delete_venue(conn: sqlc.AsyncConnection, slug: str) -> Awaitable[None]: + pass + + +def delete_venue(conn: sqlc.GenericConnection, slug: str) -> sqlc.ReturnType[None]: + return conn.execute_none(DELETE_VENUE, slug) + + +@overload +def get_venue(conn: sqlc.Connection, slug: str, city: str) -> Optional[GetVenueRow]: + pass + + +@overload +def get_venue(conn: sqlc.AsyncConnection, slug: str, city: str) -> Awaitable[Optional[GetVenueRow]]: + pass + + +def get_venue(conn: sqlc.GenericConnection, slug: str, city: str) -> sqlc.ReturnType[Optional[GetVenueRow]]: + return conn.execute_one_model(GetVenueRow, GET_VENUE, slug, city) + + +@overload +def list_venues(conn: sqlc.Connection, city: str) -> Iterator[ListVenuesRow]: + pass + + +@overload +def list_venues(conn: sqlc.AsyncConnection, city: str) -> AsyncIterator[ListVenuesRow]: + pass + + +def list_venues(conn: sqlc.GenericConnection, city: str) -> sqlc.IteratorReturn[ListVenuesRow]: + return conn.execute_many_model(ListVenuesRow, LIST_VENUES, city) + + +@overload +def update_venue_name(conn: sqlc.Connection, slug: str, name: str) -> Optional[int]: + pass + + +@overload +def update_venue_name(conn: sqlc.AsyncConnection, slug: str, name: str) -> Awaitable[Optional[int]]: + pass + + +def update_venue_name(conn: sqlc.GenericConnection, slug: str, name: str) -> sqlc.ReturnType[Optional[int]]: + return conn.execute_one(UPDATE_VENUE_NAME, slug, name) + + +@overload +def venue_count_by_city(conn: sqlc.Connection) -> Iterator[VenueCountByCityRow]: + pass + + +@overload +def venue_count_by_city(conn: sqlc.AsyncConnection) -> AsyncIterator[VenueCountByCityRow]: + pass + + +def venue_count_by_city(conn: sqlc.GenericConnection) -> sqlc.IteratorReturn[VenueCountByCityRow]: + return conn.execute_many_model(VenueCountByCityRow, VENUE_COUNT_BY_CITY) + + diff --git a/examples/python/src/tests/__init__.py b/examples/python/src/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/python/src/tests/conftest.py b/examples/python/src/tests/conftest.py new file mode 100644 index 0000000000..e3df5f77dc --- /dev/null +++ b/examples/python/src/tests/conftest.py @@ -0,0 +1,69 @@ +import asyncio +import os +import random + +import asyncpg +import psycopg2 +import psycopg2.extensions +import pytest + + +@pytest.fixture(scope="session") +def postgres_uri() -> str: + pg_host = os.environ.get("PG_HOST", "postgres") + pg_port = os.environ.get("PG_PORT", 5432) + pg_user = os.environ.get("PG_USER", "postgres") + pg_password = os.environ.get("PG_PASSWORD", "mysecretpassword") + pg_db = os.environ.get("PG_DATABASE", "dinotest") + + return f"postgres://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{pg_db}?sslmode=disable" + + +@pytest.fixture(scope="session") +def postgres_connection(postgres_uri) -> psycopg2.extensions.connection: + conn = psycopg2.connect(postgres_uri) + yield conn + conn.close() + + +@pytest.fixture() +def postgres_db(postgres_connection) -> psycopg2.extensions.connection: + schema_name = f"sqltest_{random.randint(0, 1000)}" + cur = postgres_connection.cursor() + cur.execute(f"CREATE SCHEMA {schema_name}") + cur.execute(f"SET search_path TO {schema_name}") + cur.close() + postgres_connection.commit() + yield postgres_connection + postgres_connection.rollback() + cur = postgres_connection.cursor() + cur.execute(f"DROP SCHEMA {schema_name} CASCADE") + cur.execute(f"SET search_path TO public") + cur.close() + postgres_connection.commit() + + +@pytest.fixture(scope="session") +def event_loop(): + """Change event_loop fixture to session level.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + +@pytest.fixture(scope="session") +async def async_postgres_connection(postgres_uri: str) -> asyncpg.Connection: + conn = await asyncpg.connect(postgres_uri) + yield conn + await conn.close() + + +@pytest.fixture() +async def async_postgres_db(async_postgres_connection: asyncpg.Connection) -> asyncpg.Connection: + conn = async_postgres_connection + schema_name = f"sqltest_{random.randint(0, 1000)}" + await conn.execute(f"CREATE SCHEMA {schema_name}") + await conn.execute(f"SET search_path TO {schema_name}") + yield conn + await conn.execute(f"DROP SCHEMA {schema_name} CASCADE") + await conn.execute(f"SET search_path TO public") diff --git a/examples/python/src/tests/test_authors.py b/examples/python/src/tests/test_authors.py new file mode 100644 index 0000000000..bc01f3133b --- /dev/null +++ b/examples/python/src/tests/test_authors.py @@ -0,0 +1,59 @@ +import os + +import asyncpg +import psycopg2.extensions +import pytest +from sqlc_runtime.psycopg2 import build_psycopg2_connection +from sqlc_runtime.asyncpg import build_asyncpg_connection + +from authors import query +from dbtest.migrations import apply_migrations, apply_migrations_async + + +def test_authors(postgres_db: psycopg2.extensions.connection): + apply_migrations(postgres_db, [os.path.dirname(__file__) + "/../../../authors/postgresql/schema.sql"]) + + db = build_psycopg2_connection(postgres_db) + + authors = list(query.list_authors(db)) + assert authors == [] + + author_name = "Brian Kernighan" + author_bio = "Co-author of The C Programming Language and The Go Programming Language" + new_author = query.create_author(db, name=author_name, bio=author_bio) + assert new_author.id > 0 + assert new_author.name == author_name + assert new_author.bio == author_bio + + db_author = query.get_author(db, new_author.id) + assert db_author == new_author + + author_list = list(query.list_authors(db)) + assert len(author_list) == 1 + assert author_list[0] == new_author + + +@pytest.mark.asyncio +async def test_authors_async(async_postgres_db: asyncpg.Connection): + await apply_migrations_async(async_postgres_db, [os.path.dirname(__file__) + "/../../../authors/postgresql/schema.sql"]) + + db = build_asyncpg_connection(async_postgres_db) + + async for _ in query.list_authors(db): + assert False, "No authors should exist" + + author_name = "Brian Kernighan" + author_bio = "Co-author of The C Programming Language and The Go Programming Language" + new_author = await query.create_author(db, name=author_name, bio=author_bio) + assert new_author.id > 0 + assert new_author.name == author_name + assert new_author.bio == author_bio + + db_author = await query.get_author(db, new_author.id) + assert db_author == new_author + + author_list = [] + async for author in query.list_authors(db): + author_list.append(author) + assert len(author_list) == 1 + assert author_list[0] == new_author diff --git a/examples/python/src/tests/test_booktest.py b/examples/python/src/tests/test_booktest.py new file mode 100644 index 0000000000..b0ba38891a --- /dev/null +++ b/examples/python/src/tests/test_booktest.py @@ -0,0 +1,87 @@ +import datetime +import os + +import asyncpg +import pytest +from sqlc_runtime.asyncpg import build_asyncpg_connection + +from booktest import query, models +from dbtest.migrations import apply_migrations_async + + +@pytest.mark.asyncio +async def test_books(async_postgres_db: asyncpg.Connection): + await apply_migrations_async(async_postgres_db, [os.path.dirname(__file__) + "/../../../booktest/postgresql/schema.sql"]) + + db = build_asyncpg_connection(async_postgres_db) + + author = await query.create_author(db, "Unknown Master") + assert author is not None + + async with async_postgres_db.transaction(): + now = datetime.datetime.now() + await query.create_book(db, query.CreateBookParams( + author_id=author.author_id, + isbn="1", + title="my book title", + book_type=models.BookType.FICTION, + year=2016, + available=now, + tags=[], + )) + + b1 = await query.create_book(db, query.CreateBookParams( + author_id=author.author_id, + isbn="2", + title="the second book", + book_type=models.BookType.FICTION, + year=2016, + available=now, + tags=["cool", "unique"], + )) + + await query.update_book(db, book_id=b1.book_id, title="changed second title", tags=["cool", "disastor"]) + + b3 = await query.create_book(db, query.CreateBookParams( + author_id=author.author_id, + isbn="3", + title="the third book", + book_type=models.BookType.FICTION, + year=2001, + available=now, + tags=["cool"], + )) + + b4 = await query.create_book(db, query.CreateBookParams( + author_id=author.author_id, + isbn="4", + title="4th place finisher", + book_type=models.BookType.NONFICTION, + year=2011, + available=now, + tags=["other"], + )) + + await query.update_book_isbn(db, book_id=b4.book_id, isbn="NEW ISBN", title="never ever gonna finish, a quatrain", tags=["someother"]) + + books0 = query.books_by_title_year(db, title="my book title", year=2016) + expected_titles = {"my book title"} + async for book in books0: + expected_titles.remove(book.title) # raises a key error if the title does not exist + assert len(book.tags) == 0 + + author = await query.get_author(db, author_id=book.author_id) + assert author.name == "Unknown Master" + assert len(expected_titles) == 0 + + books = query.books_by_tags(db, ["cool", "other", "someother"]) + expected_titles = {"changed second title", "the third book", "never ever gonna finish, a quatrain"} + async for book in books: + expected_titles.remove(book.title) + assert len(expected_titles) == 0 + + b5 = await query.get_book(db, b3.book_id) + assert b5 is not None + await query.delete_book(db, book_id=b5.book_id) + b6 = await query.get_book(db, b5.book_id) + assert b6 is None diff --git a/examples/python/src/tests/test_ondeck.py b/examples/python/src/tests/test_ondeck.py new file mode 100644 index 0000000000..f12fbe985c --- /dev/null +++ b/examples/python/src/tests/test_ondeck.py @@ -0,0 +1,53 @@ +import os + +import asyncpg +import pytest +from sqlc_runtime.asyncpg import build_asyncpg_connection + +from ondeck import models +from ondeck import city as city_queries +from ondeck import venue as venue_queries +from dbtest.migrations import apply_migrations_async + + +@pytest.mark.asyncio +async def test_ondeck(async_postgres_db: asyncpg.Connection): + await apply_migrations_async(async_postgres_db, [os.path.dirname(__file__) + "/../../../ondeck/postgresql/schema"]) + + db = build_asyncpg_connection(async_postgres_db) + + city = await city_queries.create_city(db, slug="san-francisco", name="San Francisco") + assert city is not None + + venue_id = await venue_queries.create_venue(db, venue_queries.CreateVenueParams( + slug="the-fillmore", + name="The Fillmore", + city=city.slug, + spotify_playlist="spotify:uri", + status=models.Status.OPEN, + statuses=[models.Status.OPEN, models.Status.CLOSED], + tags=["rock", "punk"], + )) + assert venue_id is not None + + venue = await venue_queries.get_venue(db, slug="the-fillmore", city=city.slug) + assert venue is not None + assert venue.id == venue_id + + assert city == await city_queries.get_city(db, city.slug) + assert [venue_queries.VenueCountByCityRow(city=city.slug, count=1)] == await _to_list(venue_queries.venue_count_by_city(db)) + assert [city] == await _to_list(city_queries.list_cities(db)) + assert [venue] == await _to_list(venue_queries.list_venues(db, city=city.slug)) + + await city_queries.update_city_name(db, slug=city.slug, name="SF") + _id = await venue_queries.update_venue_name(db, slug=venue.slug, name="Fillmore") + assert _id == venue_id + + await venue_queries.delete_venue(db, slug=venue.slug) + + +async def _to_list(it): + out = [] + async for i in it: + out.append(i) + return out diff --git a/internal/cmd/generate.go b/internal/cmd/generate.go index 55f9f4c18d..f5ff1f71ac 100644 --- a/internal/cmd/generate.go +++ b/internal/cmd/generate.go @@ -12,6 +12,7 @@ import ( "github.com/kyleconroy/sqlc/internal/codegen/golang" "github.com/kyleconroy/sqlc/internal/codegen/kotlin" + "github.com/kyleconroy/sqlc/internal/codegen/python" "github.com/kyleconroy/sqlc/internal/compiler" "github.com/kyleconroy/sqlc/internal/config" "github.com/kyleconroy/sqlc/internal/debug" @@ -120,6 +121,12 @@ func Generate(e Env, dir, filename string, stderr io.Writer) (map[string]string, Gen: config.SQLGen{Kotlin: sql.Gen.Kotlin}, }) } + if sql.Gen.Python != nil { + pairs = append(pairs, outPair{ + SQL: sql, + Gen: config.SQLGen{Python: sql.Gen.Python}, + }) + } } for _, sql := range pairs { @@ -149,6 +156,8 @@ func Generate(e Env, dir, filename string, stderr io.Writer) (map[string]string, parseOpts.UsePositionalParameters = true } name = combo.Kotlin.Package + } else if sql.Gen.Python != nil { + name = combo.Python.Package } result, failed := parse(e, name, dir, sql.SQL, combo, parseOpts, stderr) @@ -166,6 +175,9 @@ func Generate(e Env, dir, filename string, stderr io.Writer) (map[string]string, case sql.Gen.Kotlin != nil: out = combo.Kotlin.Out files, err = kotlin.Generate(result, combo) + case sql.Gen.Python != nil: + out = combo.Python.Out + files, err = python.Generate(result, combo) default: panic("missing language backend") } diff --git a/internal/codegen/golang/go_type.go b/internal/codegen/golang/go_type.go index 2736f5e4d0..11c4f96dbd 100644 --- a/internal/codegen/golang/go_type.go +++ b/internal/codegen/golang/go_type.go @@ -8,6 +8,9 @@ import ( func goType(r *compiler.Result, col *compiler.Column, settings config.CombinedSettings) string { // package overrides have a higher precedence for _, oride := range settings.Overrides { + if oride.GoTypeName == "" { + continue + } sameTable := sameTableName(col.Table, oride.Table, r.Catalog.DefaultSchema) if oride.Column != "" && oride.ColumnName == col.Name && sameTable { return oride.GoTypeName @@ -26,6 +29,9 @@ func goInnerType(r *compiler.Result, col *compiler.Column, settings config.Combi // package overrides have a higher precedence for _, oride := range settings.Overrides { + if oride.GoTypeName == "" { + continue + } if oride.DBType != "" && oride.DBType == columnType && oride.Nullable != notNull { return oride.GoTypeName } diff --git a/internal/codegen/golang/imports.go b/internal/codegen/golang/imports.go index e60380abda..d30d20285f 100644 --- a/internal/codegen/golang/imports.go +++ b/internal/codegen/golang/imports.go @@ -155,7 +155,7 @@ func (i *importer) interfaceImports() fileImports { pkg := make(map[ImportSpec]struct{}) overrideTypes := map[string]string{} for _, o := range i.Settings.Overrides { - if o.GoBasicType { + if o.GoBasicType || o.GoTypeName == "" { continue } overrideTypes[o.GoTypeName] = o.GoImportPath @@ -172,7 +172,7 @@ func (i *importer) interfaceImports() fileImports { // Custom imports for _, o := range i.Settings.Overrides { - if o.GoBasicType { + if o.GoBasicType || o.GoTypeName == "" { continue } _, alreadyImported := std[o.GoImportPath] @@ -215,7 +215,7 @@ func (i *importer) modelImports() fileImports { pkg := make(map[ImportSpec]struct{}) overrideTypes := map[string]string{} for _, o := range i.Settings.Overrides { - if o.GoBasicType { + if o.GoBasicType || o.GoTypeName == "" { continue } overrideTypes[o.GoTypeName] = o.GoImportPath @@ -232,7 +232,7 @@ func (i *importer) modelImports() fileImports { } for _, o := range i.Settings.Overrides { - if o.GoBasicType { + if o.GoBasicType || o.GoTypeName == "" { continue } _, alreadyImported := std[o.GoImportPath] @@ -349,7 +349,7 @@ func (i *importer) queryImports(filename string) fileImports { pkg := make(map[ImportSpec]struct{}) overrideTypes := map[string]string{} for _, o := range i.Settings.Overrides { - if o.GoBasicType { + if o.GoBasicType || o.GoTypeName == "" { continue } overrideTypes[o.GoTypeName] = o.GoImportPath @@ -369,7 +369,7 @@ func (i *importer) queryImports(filename string) fileImports { // Custom imports for _, o := range i.Settings.Overrides { - if o.GoBasicType { + if o.GoBasicType || o.GoTypeName == "" { continue } _, alreadyImported := std[o.GoImportPath] diff --git a/internal/codegen/python/gen.go b/internal/codegen/python/gen.go new file mode 100644 index 0000000000..fda5374524 --- /dev/null +++ b/internal/codegen/python/gen.go @@ -0,0 +1,678 @@ +package python + +import ( + "bufio" + "bytes" + "fmt" + "github.com/kyleconroy/sqlc/internal/codegen" + "github.com/kyleconroy/sqlc/internal/compiler" + "github.com/kyleconroy/sqlc/internal/config" + "github.com/kyleconroy/sqlc/internal/core" + "github.com/kyleconroy/sqlc/internal/inflection" + "github.com/kyleconroy/sqlc/internal/sql/ast" + "github.com/kyleconroy/sqlc/internal/sql/catalog" + "log" + "regexp" + "sort" + "strings" + "text/template" +) + +type Constant struct { + Name string + Type string + Value string +} + +type Enum struct { + Name string + Comment string + Constants []Constant +} + +type pyType struct { + InnerType string + IsArray bool + IsNull bool +} + +func (t pyType) String() string { + v := t.InnerType + if t.IsArray { + v = fmt.Sprintf("List[%s]", v) + } + if t.IsNull { + v = fmt.Sprintf("Optional[%s]", v) + } + return v +} + +type Field struct { + Name string + Type pyType + Comment string +} + +type Struct struct { + Table core.FQN + Name string + Fields []Field + Comment string +} + +func (s Struct) DedupFields() []Field { + seen := map[string]struct{}{} + dedupFields := make([]Field, 0) + for _, f := range s.Fields { + if _, ok := seen[f.Name]; ok { + continue + } + seen[f.Name] = struct{}{} + dedupFields = append(dedupFields, f) + } + return dedupFields +} + +type QueryValue struct { + Emit bool + Name string + Struct *Struct + Typ pyType +} + +func (v QueryValue) EmitStruct() bool { + return v.Emit +} + +func (v QueryValue) IsStruct() bool { + return v.Struct != nil +} + +func (v QueryValue) isEmpty() bool { + return v.Typ == (pyType{}) && v.Name == "" && v.Struct == nil +} + +func (v QueryValue) Pair() string { + if v.isEmpty() { + return "" + } + return v.Name + ": " + v.Type() +} + +func (v QueryValue) Type() string { + if v.Typ != (pyType{}) { + return v.Typ.String() + } + if v.Struct != nil { + if v.Emit { + return v.Struct.Name + } else { + return "models." + v.Struct.Name + } + } + panic("no type for QueryValue: " + v.Name) +} + +// A struct used to generate methods and fields on the Queries struct +type Query struct { + Cmd string + Comments []string + MethodName string + FieldName string + ConstantName string + SQL string + SourceName string + Ret QueryValue + Args []QueryValue +} + +func (q Query) ArgPairs() string { + argPairs := make([]string, 0, len(q.Args)) + for _, a := range q.Args { + argPairs = append(argPairs, a.Pair()) + } + if len(argPairs) == 0 { + return "" + } + return ", " + strings.Join(argPairs, ", ") +} + +func (q Query) ArgParams() string { + params := make([]string, 0, len(q.Args)) + for _, a := range q.Args { + if a.isEmpty() { + continue + } + if a.IsStruct() { + for _, f := range a.Struct.Fields { + params = append(params, a.Name+"."+f.Name) + } + } else { + params = append(params, a.Name) + } + } + if len(params) == 0 { + return "" + } + return ", " + strings.Join(params, ", ") +} + +func makePyType(r *compiler.Result, col *compiler.Column, settings config.CombinedSettings) pyType { + typ := pyInnerType(r, col, settings) + return pyType{ + InnerType: typ, + IsArray: col.IsArray, + IsNull: !col.NotNull, + } +} + +func pyInnerType(r *compiler.Result, col *compiler.Column, settings config.CombinedSettings) string { + for _, oride := range settings.Overrides { + if !oride.PythonType.IsSet() { + continue + } + sameTable := sameTableName(col.Table, oride.Table, r.Catalog.DefaultSchema) + if oride.Column != "" && oride.ColumnName == col.Name && sameTable { + return oride.PythonType.TypeString() + } + if oride.DBType != "" && oride.DBType == col.DataType && oride.Nullable != (col.NotNull || col.IsArray) { + return oride.PythonType.TypeString() + } + } + + switch settings.Package.Engine { + case config.EnginePostgreSQL: + return postgresType(r, col, settings) + default: + log.Println("unsupported engine type") + return "Any" + } +} + +func ModelName(name string, settings config.CombinedSettings) string { + if rename := settings.Rename[name]; rename != "" { + return rename + } + out := "" + for _, p := range strings.Split(name, "_") { + out += strings.Title(p) + } + return out +} + +var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)") +var matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])") + +func MethodName(name string) string { + snake := matchFirstCap.ReplaceAllString(name, "${1}_${2}") + snake = matchAllCap.ReplaceAllString(snake, "${1}_${2}") + return strings.ToLower(snake) +} + +var pyIdentPattern = regexp.MustCompile("[^a-zA-Z0-9_]+") + +func pyEnumValueName(value string) string { + id := strings.Replace(value, "-", "_", -1) + id = strings.Replace(id, ":", "_", -1) + id = strings.Replace(id, "/", "_", -1) + id = pyIdentPattern.ReplaceAllString(id, "") + return strings.ToUpper(id) +} + +func buildEnums(r *compiler.Result, settings config.CombinedSettings) []Enum { + var enums []Enum + for _, schema := range r.Catalog.Schemas { + if schema.Name == "pg_catalog" { + continue + } + for _, typ := range schema.Types { + enum, ok := typ.(*catalog.Enum) + if !ok { + continue + } + var enumName string + if schema.Name == r.Catalog.DefaultSchema { + enumName = enum.Name + } else { + enumName = schema.Name + "_" + enum.Name + } + e := Enum{ + Name: ModelName(enumName, settings), + Comment: enum.Comment, + } + for _, v := range enum.Vals { + e.Constants = append(e.Constants, Constant{ + Name: pyEnumValueName(v), + Value: v, + Type: e.Name, + }) + } + enums = append(enums, e) + } + } + if len(enums) > 0 { + sort.Slice(enums, func(i, j int) bool { return enums[i].Name < enums[j].Name }) + } + return enums +} + +func buildModels(r *compiler.Result, settings config.CombinedSettings) []Struct { + var structs []Struct + for _, schema := range r.Catalog.Schemas { + if schema.Name == "pg_catalog" { + continue + } + for _, table := range schema.Tables { + var tableName string + if schema.Name == r.Catalog.DefaultSchema { + tableName = table.Rel.Name + } else { + tableName = schema.Name + "_" + table.Rel.Name + } + structName := tableName + if !settings.Python.EmitExactTableNames { + structName = inflection.Singular(structName) + } + s := Struct{ + Table: core.FQN{Schema: schema.Name, Rel: table.Rel.Name}, + Name: ModelName(structName, settings), + Comment: table.Comment, + } + for _, column := range table.Columns { + typ := makePyType(r, compiler.ConvertColumn(table.Rel, column), settings) + typ.InnerType = strings.TrimPrefix(typ.InnerType, "models.") + s.Fields = append(s.Fields, Field{ + Name: column.Name, + Type: typ, + Comment: column.Comment, + }) + } + structs = append(structs, s) + } + } + if len(structs) > 0 { + sort.Slice(structs, func(i, j int) bool { return structs[i].Name < structs[j].Name }) + } + return structs +} + +func columnName(c *compiler.Column, pos int) string { + if c.Name != "" { + return c.Name + } + return fmt.Sprintf("column_%d", pos+1) +} + +func paramName(p compiler.Parameter) string { + if p.Column.Name != "" { + return p.Column.Name + } + return fmt.Sprintf("dollar_%d", p.Number) +} + +type pyColumn struct { + id int + *compiler.Column +} + +func columnsToStruct(r *compiler.Result, name string, columns []pyColumn, settings config.CombinedSettings) *Struct { + gs := Struct{ + Name: name, + } + seen := map[string]int{} + suffixes := map[int]int{} + for i, c := range columns { + colName := columnName(c.Column, i) + fieldName := colName + // Track suffixes by the ID of the column, so that columns referring to the same numbered parameter can be + // reused. + suffix := 0 + if o, ok := suffixes[c.id]; ok { + suffix = o + } else if v := seen[colName]; v > 0 { + suffix = v + 1 + } + suffixes[c.id] = suffix + if suffix > 0 { + fieldName = fmt.Sprintf("%s_%d", fieldName, suffix) + } + gs.Fields = append(gs.Fields, Field{ + Name: fieldName, + Type: makePyType(r, c.Column, settings), + }) + seen[colName]++ + } + return &gs +} + +func sameTableName(n *ast.TableName, f core.FQN, defaultSchema string) bool { + if n == nil { + return false + } + schema := n.Schema + if n.Schema == "" { + schema = defaultSchema + } + return n.Catalog == f.Catalog && schema == f.Schema && n.Name == f.Rel +} + +func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs []Struct) []Query { + qs := make([]Query, 0, len(r.Queries)) + for _, query := range r.Queries { + if query.Name == "" { + continue + } + if query.Cmd == "" { + continue + } + + methodName := MethodName(query.Name) + + gq := Query{ + Cmd: query.Cmd, + Comments: query.Comments, + MethodName: methodName, + FieldName: codegen.LowerTitle(query.Name) + "Stmt", + ConstantName: strings.ToUpper(methodName), + SQL: query.SQL, + SourceName: query.Filename, + } + + if len(query.Params) > 4 { + var cols []pyColumn + for _, p := range query.Params { + cols = append(cols, pyColumn{ + id: p.Number, + Column: p.Column, + }) + } + gq.Args = []QueryValue{{ + Emit: true, + Name: "arg", + Struct: columnsToStruct(r, query.Name+"Params", cols, settings), + }} + } else { + args := make([]QueryValue, 0, len(query.Params)) + for _, p := range query.Params { + args = append(args, QueryValue{ + Name: paramName(p), + Typ: makePyType(r, p.Column, settings), + }) + } + gq.Args = args + } + + if len(query.Columns) == 1 { + c := query.Columns[0] + gq.Ret = QueryValue{ + Name: columnName(c, 0), + Typ: makePyType(r, c, settings), + } + } else if len(query.Columns) > 1 { + var gs *Struct + var emit bool + + for _, s := range structs { + if len(s.Fields) != len(query.Columns) { + continue + } + same := true + for i, f := range s.Fields { + c := query.Columns[i] + sameName := f.Name == columnName(c, i) + sameType := f.Type == makePyType(r, c, settings) + sameTable := sameTableName(c.Table, s.Table, r.Catalog.DefaultSchema) + if !sameName || !sameType || !sameTable { + same = false + } + } + if same { + gs = &s + break + } + } + + if gs == nil { + var columns []pyColumn + for i, c := range query.Columns { + columns = append(columns, pyColumn{ + id: i, + Column: c, + }) + } + gs = columnsToStruct(r, query.Name+"Row", columns, settings) + emit = true + } + gq.Ret = QueryValue{ + Emit: emit, + Name: "i", + Struct: gs, + } + } + + qs = append(qs, gq) + } + sort.Slice(qs, func(i, j int) bool { return qs[i].MethodName < qs[j].MethodName }) + return qs +} + +var modelsTmpl = `# Code generated by sqlc. DO NOT EDIT. +{{- range imports .SourceName}} +{{.}} +{{- end}} + + +# Enums +{{- range .Enums}} +{{- if .Comment}}{{comment .Comment}}{{- end}} +class {{.Name}}(str, enum.Enum): + {{- range .Constants}} + {{.Name}} = "{{.Value}}" + {{- end}} +{{end}} + +# Models +{{- range .Models}} +{{- if .Comment}}{{comment .Comment}}{{- end}} +class {{.Name}}(pydantic.BaseModel): {{- range .DedupFields}} + {{- if .Comment}} + {{comment .Comment}}{{else}} + {{- end}} + {{.Name}}: {{.Type}} + {{- end}} + +{{end}} +` + +var queriesTmpl = `# Code generated by sqlc. DO NOT EDIT. +{{- range imports .SourceName}} +{{.}} +{{- end}} + +{{range .Queries}} +{{- if $.OutputQuery .SourceName}} +{{.ConstantName}} = """-- name: {{.MethodName}} {{.Cmd}} +{{.SQL}} +""" +{{range .Args}} +{{- if .EmitStruct}} + +class {{.Type}}(pydantic.BaseModel): {{- range .Struct.DedupFields}} + {{.Name}}: {{.Type}} + {{- end}} +{{end}}{{end}} +{{- if .Ret.EmitStruct}} + +class {{.Ret.Type}}(pydantic.BaseModel): {{- range .Ret.Struct.DedupFields}} + {{.Name}}: {{.Type}} + {{- end}} +{{end}} +{{end}} +{{- end}} + +{{- range .Queries}} +{{- if $.OutputQuery .SourceName}} +{{- if eq .Cmd ":one"}} +@overload +def {{.MethodName}}(conn: sqlc.Connection{{.ArgPairs}}) -> Optional[{{.Ret.Type}}]: + pass + + +@overload +def {{.MethodName}}(conn: sqlc.AsyncConnection{{.ArgPairs}}) -> Awaitable[Optional[{{.Ret.Type}}]]: + pass + + +def {{.MethodName}}(conn: sqlc.GenericConnection{{.ArgPairs}}) -> sqlc.ReturnType[Optional[{{.Ret.Type}}]]: + {{- if .Ret.IsStruct}} + return conn.execute_one_model({{.Ret.Type}}, {{.ConstantName}}{{.ArgParams}}) + {{- else}} + return conn.execute_one({{.ConstantName}}{{.ArgParams}}) + {{- end}} +{{end}} + +{{- if eq .Cmd ":many"}} +@overload +def {{.MethodName}}(conn: sqlc.Connection{{.ArgPairs}}) -> Iterator[{{.Ret.Type}}]: + pass + + +@overload +def {{.MethodName}}(conn: sqlc.AsyncConnection{{.ArgPairs}}) -> AsyncIterator[{{.Ret.Type}}]: + pass + + +def {{.MethodName}}(conn: sqlc.GenericConnection{{.ArgPairs}}) -> sqlc.IteratorReturn[{{.Ret.Type}}]: + {{- if .Ret.IsStruct}} + return conn.execute_many_model({{.Ret.Type}}, {{.ConstantName}}{{.ArgParams}}) + {{- else}} + return conn.execute_many({{.ConstantName}}{{.ArgParams}}) + {{- end}} +{{end}} + +{{- if eq .Cmd ":exec"}} +@overload +def {{.MethodName}}(conn: sqlc.Connection{{.ArgPairs}}) -> None: + pass + + +@overload +def {{.MethodName}}(conn: sqlc.AsyncConnection{{.ArgPairs}}) -> Awaitable[None]: + pass + + +def {{.MethodName}}(conn: sqlc.GenericConnection{{.ArgPairs}}) -> sqlc.ReturnType[None]: + return conn.execute_none({{.ConstantName}}{{.ArgParams}}) +{{end}} + +{{- if eq .Cmd ":execrows"}} +@overload +def {{.MethodName}}(conn: sqlc.Connection{{.ArgPairs}}) -> int: + pass + + +@overload +def {{.MethodName}}(conn: sqlc.AsyncConnection{{.ArgPairs}}) -> Awaitable[int]: + pass + + +def {{.MethodName}}(conn: sqlc.GenericConnection{{.ArgPairs}}) -> sqlc.ReturnType[int]: + return conn.execute_rowcount({{.ConstantName}}{{.ArgParams}}) +{{end}} + +{{- if eq .Cmd ":execresult"}} +@overload +def {{.MethodName}}(conn: sqlc.Connection{{.ArgPairs}}) -> sqlc.Cursor: + pass + + +@overload +def {{.MethodName}}(conn: sqlc.AsyncConnection{{.ArgPairs}}) -> sqlc.AsyncCursor: + pass + + +def {{.MethodName}}(conn: sqlc.GenericConnection{{.ArgPairs}}) -> sqlc.GenericCursor: + return conn.execute({{.ConstantName}}{{.ArgParams}}) +{{end}} +{{end}} +{{- end}} +` + +type pyTmplCtx struct { + Models []Struct + Queries []Query + Enums []Enum + SourceName string +} + +func (t *pyTmplCtx) OutputQuery(sourceName string) bool { + return t.SourceName == sourceName +} + +func HashComment(s string) string { + return "# " + strings.ReplaceAll(s, "\n", "\n# ") +} + +func Generate(r *compiler.Result, settings config.CombinedSettings) (map[string]string, error) { + enums := buildEnums(r, settings) + models := buildModels(r, settings) + queries := buildQueries(r, settings, models) + + i := &importer{ + Settings: settings, + Models: models, + Queries: queries, + Enums: enums, + } + + funcMap := template.FuncMap{ + "lowerTitle": codegen.LowerTitle, + "comment": HashComment, + "imports": i.Imports, + } + + modelsFile := template.Must(template.New("table").Funcs(funcMap).Parse(modelsTmpl)) + queriesFile := template.Must(template.New("table").Funcs(funcMap).Parse(queriesTmpl)) + + tctx := pyTmplCtx{ + Models: models, + Queries: queries, + Enums: enums, + } + + output := map[string]string{} + + execute := func(name string, t *template.Template) error { + var b bytes.Buffer + w := bufio.NewWriter(&b) + tctx.SourceName = name + err := t.Execute(w, &tctx) + w.Flush() + if err != nil { + return err + } + if !strings.HasSuffix(name, ".py") { + name = strings.TrimSuffix(name, ".sql") + name += ".py" + } + output[name] = b.String() + return nil + } + + if err := execute("models.py", modelsFile); err != nil { + return nil, err + } + + files := map[string]struct{}{} + for _, q := range queries { + files[q.SourceName] = struct{}{} + } + + for source := range files { + if err := execute(source, queriesFile); err != nil { + return nil, err + } + } + + return output, nil +} diff --git a/internal/codegen/python/imports.go b/internal/codegen/python/imports.go new file mode 100644 index 0000000000..493284a564 --- /dev/null +++ b/internal/codegen/python/imports.go @@ -0,0 +1,231 @@ +package python + +import ( + "fmt" + "github.com/kyleconroy/sqlc/internal/config" + "sort" + "strings" +) + +type importSpec struct { + Module string + Name string + Alias string +} + +func (i importSpec) String() string { + if i.Alias != "" { + if i.Name == "" { + return fmt.Sprintf("import %s as %s", i.Module, i.Alias) + } + return fmt.Sprintf("from %s import %s as %s", i.Module, i.Name, i.Alias) + } + if i.Name == "" { + return "import " + i.Module + } + return fmt.Sprintf("from %s import %s", i.Module, i.Name) +} + +type importer struct { + Settings config.CombinedSettings + Models []Struct + Queries []Query + Enums []Enum +} + +func structUses(name string, s Struct) bool { + for _, f := range s.Fields { + if name == "typing.List" && f.Type.IsArray { + return true + } + if name == "typing.Optional" && f.Type.IsNull { + return true + } + if f.Type.InnerType == name { + return true + } + } + return false +} + +func queryValueUses(name string, qv QueryValue) bool { + if !qv.isEmpty() { + if name == "typing.List" && qv.Typ.IsArray { + return true + } + if name == "typing.Optional" && qv.Typ.IsNull { + return true + } + if qv.IsStruct() && qv.EmitStruct() { + if structUses(name, *qv.Struct) { + return true + } + } else { + if qv.Typ.InnerType == name { + return true + } + } + } + return false +} + +func (i *importer) Imports(fileName string) []string { + if fileName == "models.py" { + return i.modelImports() + } + return i.queryImports(fileName) +} + +func (i *importer) modelImports() []string { + modelUses := func(name string) bool { + for _, model := range i.Models { + if structUses(name, model) { + return true + } + } + return false + } + + std := stdImports(modelUses) + if len(i.Enums) > 0 { + std["enum"] = importSpec{Module: "enum"} + } + + pkg := make(map[string]importSpec) + pkg["pydantic"] = importSpec{Module: "pydantic"} + + for _, o := range i.Settings.Overrides { + if o.PythonType.IsSet() && o.PythonType.Module != "" { + if modelUses(o.PythonType.TypeString()) { + pkg[o.PythonType.Module] = importSpec{Module: o.PythonType.Module} + } + } + } + + importLines := []string{ + buildImportBlock(std), + "", + buildImportBlock(pkg), + } + return importLines +} + +func (i *importer) queryImports(fileName string) []string { + queryUses := func(name string) bool { + for _, q := range i.Queries { + if q.SourceName != fileName { + continue + } + if queryValueUses(name, q.Ret) { + return true + } + for _, arg := range q.Args { + if queryValueUses(name, arg) { + return true + } + } + } + return false + } + + std := stdImports(queryUses) + std["typing.overload"] = importSpec{Module: "typing", Name: "overload"} + std["typing.Awaitable"] = importSpec{Module: "typing", Name: "Awaitable"} + + pkg := make(map[string]importSpec) + pkg["sqlc_runtime"] = importSpec{Module: "sqlc_runtime", Alias: "sqlc"} + + for _, o := range i.Settings.Overrides { + if o.PythonType.IsSet() && o.PythonType.Module != "" { + if queryUses(o.PythonType.TypeString()) { + pkg[o.PythonType.Module] = importSpec{Module: o.PythonType.Module} + } + } + } + + queryValueModelImports := func(qv QueryValue) { + if qv.IsStruct() && qv.EmitStruct() { + pkg["pydantic"] = importSpec{Module: "pydantic"} + } + } + + for _, q := range i.Queries { + if q.SourceName != fileName { + continue + } + if q.Cmd == ":one" { + std["typing.Optional"] = importSpec{Module: "typing", Name: "Optional"} + } + if q.Cmd == ":many" { + std["typing.Iterator"] = importSpec{Module: "typing", Name: "Iterator"} + std["typing.AsyncIterator"] = importSpec{Module: "typing", Name: "AsyncIterator"} + } + queryValueModelImports(q.Ret) + for _, qv := range q.Args { + queryValueModelImports(qv) + } + } + + modelImportStr := fmt.Sprintf("from %s import models", i.Settings.Python.Package) + + importLines := []string{ + buildImportBlock(std), + "", + buildImportBlock(pkg), + "", + modelImportStr, + } + return importLines +} + +func buildImportBlock(pkgs map[string]importSpec) string { + pkgImports := make([]importSpec, 0) + fromImports := make(map[string][]string) + for _, is := range pkgs { + if is.Name == "" || is.Alias != "" { + pkgImports = append(pkgImports, is) + } else { + names, ok := fromImports[is.Module] + if !ok { + names = make([]string, 0, 1) + } + names = append(names, is.Name) + fromImports[is.Module] = names + } + } + + importStrings := make([]string, 0, len(pkgImports)+len(fromImports)) + for _, is := range pkgImports { + importStrings = append(importStrings, is.String()) + } + for modName, names := range fromImports { + sort.Strings(names) + nameString := strings.Join(names, ", ") + importStrings = append(importStrings, fmt.Sprintf("from %s import %s", modName, nameString)) + } + sort.Strings(importStrings) + return strings.Join(importStrings, "\n") +} + +func stdImports(uses func(name string) bool) map[string]importSpec { + std := make(map[string]importSpec) + if uses("decimal.Decimal") { + std["decimal"] = importSpec{Module: "decimal"} + } + if uses("datetime.date") || uses("datetime.time") || uses("datetime.datetime") || uses("datetime.timedelta") { + std["datetime"] = importSpec{Module: "datetime"} + } + if uses("uuid.UUID") { + std["uuid"] = importSpec{Module: "uuid"} + } + if uses("typing.List") { + std["typing.List"] = importSpec{Module: "typing", Name: "List"} + } + if uses("typing.Optional") { + std["typing.Optional"] = importSpec{Module: "typing", Name: "Optional"} + } + if uses("Any") { + std["typing.Any"] = importSpec{Module: "typing", Name: "Any"} + } + return std +} diff --git a/internal/codegen/python/postgresql_type.go b/internal/codegen/python/postgresql_type.go new file mode 100644 index 0000000000..5f8663e35d --- /dev/null +++ b/internal/codegen/python/postgresql_type.go @@ -0,0 +1,66 @@ +package python + +import ( + "github.com/kyleconroy/sqlc/internal/compiler" + "github.com/kyleconroy/sqlc/internal/config" + "github.com/kyleconroy/sqlc/internal/sql/catalog" + "log" +) + +func postgresType(r *compiler.Result, col *compiler.Column, settings config.CombinedSettings) string { + columnType := col.DataType + + switch columnType { + case "serial", "serial4", "pg_catalog.serial4", "bigserial", "serial8", "pg_catalog.serial8", "smallserial", "serial2", "pg_catalog.serial2", "integer", "int", "int4", "pg_catalog.int4", "bigint", "int8", "pg_catalog.int8", "smallint", "int2", "pg_catalog.int2": + return "int" + case "float", "double precision", "float8", "pg_catalog.float8", "real", "float4", "pg_catalog.float4": + return "float" + case "numeric", "pg_catalog.numeric", "money": + return "decimal.Decimal" + case "boolean", "bool", "pg_catalog.bool": + return "bool" + case "json", "jsonb": + return "Any" + case "bytea", "blob", "pg_catalog.bytea": + return "memoryview" + case "date": + return "datetime.date" + case "pg_catalog.time", "pg_catalog.timetz": + return "datetime.time" + case "pg_catalog.timestamp", "pg_catalog.timestamptz", "timestamptz": + return "datetime.datetime" + case "interval", "pg_catalog.interval": + return "datetime.timedelta" + case "text", "pg_catalog.varchar", "pg_catalog.bpchar", "string", "citext": + return "str" + case "uuid": + return "uuid.UUID" + case "inet", "cidr", "macaddr", "macaddr8": + // psycopg2 does have support for ipaddress objects, but it is not enabled by default + // + // https://www.psycopg.org/docs/extras.html#adapt-network + return "str" + case "ltree", "lquery", "ltxtquery": + return "str" + default: + for _, schema := range r.Catalog.Schemas { + if schema.Name == "pg_catalog" { + continue + } + for _, typ := range schema.Types { + enum, ok := typ.(*catalog.Enum) + if !ok { + continue + } + if columnType == enum.Name { + if schema.Name == r.Catalog.DefaultSchema { + return "models." + ModelName(enum.Name, settings) + } + return "models." + ModelName(schema.Name+"_"+enum.Name, settings) + } + } + } + log.Printf("unknown PostgreSQL type: %s\n", columnType) + return "Any" + } +} diff --git a/internal/config/config.go b/internal/config/config.go index e4eeff91ca..3092965545 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -106,6 +106,7 @@ type SQL struct { type SQLGen struct { Go *SQLGo `json:"go,omitempty" yaml:"go"` Kotlin *SQLKotlin `json:"kotlin,omitempty" yaml:"kotlin"` + Python *SQLPython `json:"python,omitempty" yaml:"python"` } type SQLGo struct { @@ -127,10 +128,20 @@ type SQLKotlin struct { Out string `json:"out" yaml:"out"` } +type SQLPython struct { + EmitExactTableNames bool `json:"emit_exact_table_names" yaml:"emit_exact_table_names"` + Package string `json:"package" yaml:"package"` + Out string `json:"out" yaml:"out"` + Overrides []Override `json:"overrides,omitempty" yaml:"overrides"` +} + type Override struct { // name of the golang type to use, e.g. `github.com/segmentio/ksuid.KSUID` GoType GoType `json:"go_type" yaml:"go_type"` + // name of the python type to use, e.g. `mymodule.TypeName` + PythonType PythonType `json:"python_type" yaml:"python_type"` + // fully qualified name of the Go type, e.g. `github.com/segmentio/ksuid.KSUID` DBType string `json:"db_type" yaml:"db_type"` Deprecated_PostgresType string `json:"postgres_type" yaml:"postgres_type"` @@ -247,6 +258,7 @@ type CombinedSettings struct { Package SQL Go SQLGo Kotlin SQLKotlin + Python SQLPython Rename map[string]string Overrides []Override } @@ -270,5 +282,9 @@ func Combine(conf Config, pkg SQL) CombinedSettings { if pkg.Gen.Kotlin != nil { cs.Kotlin = *pkg.Gen.Kotlin } + if pkg.Gen.Python != nil { + cs.Python = *pkg.Gen.Python + cs.Overrides = append(cs.Overrides, pkg.Gen.Python.Overrides...) + } return cs } diff --git a/internal/config/python_type.go b/internal/config/python_type.go new file mode 100644 index 0000000000..e908448057 --- /dev/null +++ b/internal/config/python_type.go @@ -0,0 +1,17 @@ +package config + +type PythonType struct { + Module string `json:"module" yaml:"module"` + Name string `json:"name" yaml:"name"` +} + +func (t PythonType) IsSet() bool { + return t.Module != "" || t.Name != "" +} + +func (t PythonType) TypeString() string { + if t.Name != "" && t.Module == "" { + return t.Name + } + return t.Module + "." + t.Name +} diff --git a/internal/config/v_two.go b/internal/config/v_two.go index 46c8ad31c4..73a699dc7a 100644 --- a/internal/config/v_two.go +++ b/internal/config/v_two.go @@ -59,6 +59,13 @@ func v2ParseConfig(rd io.Reader) (Config, error) { return conf, ErrNoPackageName } } + if conf.SQL[j].Gen.Python != nil { + for i := range conf.SQL[j].Gen.Python.Overrides { + if err := conf.SQL[j].Gen.Python.Overrides[i].Parse(); err != nil { + return conf, err + } + } + } } return conf, nil } diff --git a/internal/endtoend/endtoend_test.go b/internal/endtoend/endtoend_test.go index c0fa4f4423..e2deb70c58 100644 --- a/internal/endtoend/endtoend_test.go +++ b/internal/endtoend/endtoend_test.go @@ -112,7 +112,7 @@ func cmpDirectory(t *testing.T, dir string, actual map[string]string) { if file.IsDir() { return nil } - if !strings.HasSuffix(path, ".go") && !strings.HasSuffix(path, ".kt") { + if !strings.HasSuffix(path, ".go") && !strings.HasSuffix(path, ".kt") && !strings.HasSuffix(path, ".py") { return nil } if strings.Contains(path, "/kotlin/build") { @@ -121,6 +121,10 @@ func cmpDirectory(t *testing.T, dir string, actual map[string]string) { if strings.HasSuffix(path, "_test.go") || strings.Contains(path, "src/test/") { return nil } + if strings.Contains(path, "/python/.venv") || strings.Contains(path, "/python/src/tests/") || + strings.HasSuffix(path, "__init__.py") || strings.Contains(path, "/python/src/dbtest/") { + return nil + } blob, err := ioutil.ReadFile(path) if err != nil { return err