Skip to content

Commit b340201

Browse files
committed
refactor db functions from main into db.py
Signed-off-by: Grant Ramsay <seapagan@gmail.com>
1 parent 0effb23 commit b340201

File tree

2 files changed

+21
-21
lines changed

2 files changed

+21
-21
lines changed

db.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,23 @@
77
engine = create_async_engine(DATABASE_URL, echo=False)
88
Base = declarative_base()
99
async_session = async_sessionmaker(engine, expire_on_commit=False)
10+
11+
12+
async def get_db():
13+
"""Get a database session.
14+
15+
To be used for dependency injection.
16+
"""
17+
async with async_session() as session:
18+
async with session.begin():
19+
yield session
20+
21+
22+
async def init_models():
23+
"""Create tables if they don't already exist.
24+
25+
In a real-life example we would use Alembic to manage migrations.
26+
"""
27+
async with engine.begin() as conn:
28+
# await conn.run_sync(Base.metadata.drop_all)
29+
await conn.run_sync(Base.metadata.create_all)

main.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,10 @@
55
from fastapi import Depends, FastAPI
66
from sqlalchemy import select
77

8-
from db import Base, async_session, engine
8+
from db import get_db, init_models
99
from models import User
1010

1111

12-
async def init_models():
13-
"""Create tables if they don't already exist.
14-
15-
In a real-life example we would use Alembic to manage migrations.
16-
"""
17-
async with engine.begin() as conn:
18-
# await conn.run_sync(Base.metadata.drop_all)
19-
await conn.run_sync(Base.metadata.create_all)
20-
21-
2212
@asynccontextmanager
2313
async def lifespan(app: FastAPI):
2414
"""Run tasks before and after the server starts."""
@@ -29,16 +19,6 @@ async def lifespan(app: FastAPI):
2919
app = FastAPI(lifespan=lifespan)
3020

3121

32-
async def get_db():
33-
"""Get a database session.
34-
35-
To be used for dependency injection.
36-
"""
37-
async with async_session() as session:
38-
async with session.begin():
39-
yield session
40-
41-
4222
@app.get("/")
4323
async def root():
4424
"""Root endpoint."""

0 commit comments

Comments
 (0)