Implement async db access in admin
This commit is contained in:
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from .authentication import (
|
from .authentication import (
|
||||||
authenticate_user,
|
authenticate_user,
|
||||||
|
authenticate_user_async,
|
||||||
create_access_token,
|
create_access_token,
|
||||||
create_refresh_token,
|
create_refresh_token,
|
||||||
check_password,
|
check_password,
|
||||||
@ -16,6 +17,7 @@ __all__ = [
|
|||||||
"Token",
|
"Token",
|
||||||
"User",
|
"User",
|
||||||
"authenticate_user",
|
"authenticate_user",
|
||||||
|
"authenticate_user_async",
|
||||||
"check_password",
|
"check_password",
|
||||||
"create_access_token",
|
"create_access_token",
|
||||||
"create_refresh_token",
|
"create_refresh_token",
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import bcrypt
|
|||||||
import jwt
|
import jwt
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from sshecret_admin.core.settings import AdminServerSettings
|
from sshecret_admin.core.settings import AdminServerSettings
|
||||||
|
|
||||||
@ -72,6 +73,16 @@ def check_password(plain_password: str, hashed_password: str) -> None:
|
|||||||
raise AuthenticationFailedError()
|
raise AuthenticationFailedError()
|
||||||
|
|
||||||
|
|
||||||
|
async def authenticate_user_async(session: AsyncSession, username: str, password: str) -> User | None:
|
||||||
|
"""Authenticate user async."""
|
||||||
|
user = (await session.scalars(select(User).where(User.username == username))).first()
|
||||||
|
if not user:
|
||||||
|
return None
|
||||||
|
if not verify_password(password, user.hashed_password):
|
||||||
|
return None
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
def authenticate_user(session: Session, username: str, password: str) -> User | None:
|
def authenticate_user(session: Session, username: str, password: str) -> User | None:
|
||||||
"""Authenticate user."""
|
"""Authenticate user."""
|
||||||
user = session.scalars(select(User).where(User.username == username)).first()
|
user = session.scalars(select(User).where(User.username == username)).first()
|
||||||
|
|||||||
@ -1,11 +1,15 @@
|
|||||||
"""Database setup."""
|
"""Database setup."""
|
||||||
|
|
||||||
from collections.abc import Generator, Callable
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterator, Generator, Callable
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from sqlalchemy.engine import URL
|
from sqlalchemy.engine import URL
|
||||||
from sqlalchemy import create_engine, Engine
|
from sqlalchemy import create_engine, Engine
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine, AsyncEngine, AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
|
|
||||||
def setup_database(
|
def setup_database(
|
||||||
db_url: URL | str,
|
db_url: URL | str,
|
||||||
@ -20,3 +24,43 @@ def setup_database(
|
|||||||
yield session
|
yield session
|
||||||
|
|
||||||
return engine, get_db_session
|
return engine, get_db_session
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseSessionManager:
|
||||||
|
def __init__(self, host: URL | str, **engine_kwargs: str):
|
||||||
|
self._engine: AsyncEngine | None = create_async_engine(host, **engine_kwargs)
|
||||||
|
self._sessionmaker: async_sessionmaker[AsyncSession] | None = async_sessionmaker(autocommit=False, bind=self._engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
if self._engine is None:
|
||||||
|
raise Exception("DatabaseSessionManager is not initialized")
|
||||||
|
await self._engine.dispose()
|
||||||
|
|
||||||
|
self._engine = None
|
||||||
|
self._sessionmaker = None
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def connect(self) -> AsyncIterator[AsyncConnection]:
|
||||||
|
if self._engine is None:
|
||||||
|
raise Exception("DatabaseSessionManager is not initialized")
|
||||||
|
|
||||||
|
async with self._engine.begin() as connection:
|
||||||
|
try:
|
||||||
|
yield connection
|
||||||
|
except Exception:
|
||||||
|
await connection.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def session(self) -> AsyncIterator[AsyncSession]:
|
||||||
|
if self._sessionmaker is None:
|
||||||
|
raise Exception("DatabaseSessionManager is not initialized")
|
||||||
|
|
||||||
|
session = self._sessionmaker()
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
except Exception:
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from collections.abc import AsyncGenerator, Awaitable, Callable, Generator
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Self
|
from typing import Self
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from sshecret_admin.auth import User
|
from sshecret_admin.auth import User
|
||||||
from sshecret_admin.services import AdminBackend
|
from sshecret_admin.services import AdminBackend
|
||||||
|
|||||||
@ -31,3 +31,8 @@ class AdminServerSettings(BaseSettings):
|
|||||||
def admin_db(self) -> URL:
|
def admin_db(self) -> URL:
|
||||||
"""Construct database url."""
|
"""Construct database url."""
|
||||||
return URL.create(drivername="sqlite", database=self.database)
|
return URL.create(drivername="sqlite", database=self.database)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def async_db_url(self) -> URL:
|
||||||
|
"""Construct database url with sync handling."""
|
||||||
|
return URL.create(drivername="sqlite+aiosqlite", database=self.database)
|
||||||
|
|||||||
@ -1,10 +1,11 @@
|
|||||||
"""Frontend dependencies."""
|
"""Frontend dependencies."""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from collections.abc import Callable, Awaitable
|
from collections.abc import AsyncGenerator, Callable, Awaitable
|
||||||
from typing import Self
|
from typing import Self
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from jinja2_fragments.fastapi import Jinja2Blocks
|
from jinja2_fragments.fastapi import Jinja2Blocks
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
|
||||||
@ -14,6 +15,7 @@ from sshecret_admin.auth.models import User
|
|||||||
|
|
||||||
UserTokenDep = Callable[[Request, Session], Awaitable[User]]
|
UserTokenDep = Callable[[Request, Session], Awaitable[User]]
|
||||||
UserLoginDep = Callable[[Request, Session], Awaitable[bool]]
|
UserLoginDep = Callable[[Request, Session], Awaitable[bool]]
|
||||||
|
AsyncSessionDep = Callable[[], AsyncGenerator[AsyncSession, None]]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -25,6 +27,7 @@ class FrontendDependencies(BaseDependencies):
|
|||||||
get_user_from_access_token: UserTokenDep
|
get_user_from_access_token: UserTokenDep
|
||||||
get_user_from_refresh_token: UserTokenDep
|
get_user_from_refresh_token: UserTokenDep
|
||||||
get_login_status: UserLoginDep
|
get_login_status: UserLoginDep
|
||||||
|
get_async_session: AsyncSessionDep
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(
|
def create(
|
||||||
@ -35,6 +38,7 @@ class FrontendDependencies(BaseDependencies):
|
|||||||
get_user_from_access_token: UserTokenDep,
|
get_user_from_access_token: UserTokenDep,
|
||||||
get_user_from_refresh_token: UserTokenDep,
|
get_user_from_refresh_token: UserTokenDep,
|
||||||
get_login_status: UserLoginDep,
|
get_login_status: UserLoginDep,
|
||||||
|
get_async_session: AsyncSessionDep
|
||||||
) -> Self:
|
) -> Self:
|
||||||
"""Create from base dependencies."""
|
"""Create from base dependencies."""
|
||||||
return cls(
|
return cls(
|
||||||
@ -45,4 +49,5 @@ class FrontendDependencies(BaseDependencies):
|
|||||||
get_user_from_access_token=get_user_from_access_token,
|
get_user_from_access_token=get_user_from_access_token,
|
||||||
get_user_from_refresh_token=get_user_from_refresh_token,
|
get_user_from_refresh_token=get_user_from_refresh_token,
|
||||||
get_login_status=get_login_status,
|
get_login_status=get_login_status,
|
||||||
|
get_async_session=get_async_session,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -19,6 +19,7 @@ from starlette.datastructures import URL
|
|||||||
from sshecret_admin.auth import PasswordDB, User, decode_token
|
from sshecret_admin.auth import PasswordDB, User, decode_token
|
||||||
from sshecret_admin.core.dependencies import BaseDependencies
|
from sshecret_admin.core.dependencies import BaseDependencies
|
||||||
from sshecret_admin.services.admin_backend import AdminBackend
|
from sshecret_admin.services.admin_backend import AdminBackend
|
||||||
|
from sshecret_admin.core.db import DatabaseSessionManager
|
||||||
|
|
||||||
from .dependencies import FrontendDependencies
|
from .dependencies import FrontendDependencies
|
||||||
from .exceptions import RedirectException
|
from .exceptions import RedirectException
|
||||||
@ -47,7 +48,9 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
|||||||
session: Annotated[Session, Depends(dependencies.get_db_session)]
|
session: Annotated[Session, Depends(dependencies.get_db_session)]
|
||||||
):
|
):
|
||||||
"""Get admin backend API."""
|
"""Get admin backend API."""
|
||||||
password_db = session.scalars(select(PasswordDB).where(PasswordDB.id == 1)).first()
|
password_db = session.scalars(
|
||||||
|
select(PasswordDB).where(PasswordDB.id == 1)
|
||||||
|
).first()
|
||||||
if not password_db:
|
if not password_db:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
500, detail="Error: The password manager has not yet been set up."
|
500, detail="Error: The password manager has not yet been set up."
|
||||||
@ -116,6 +119,12 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
async def get_async_session():
|
||||||
|
"""Get async session."""
|
||||||
|
sessionmanager = DatabaseSessionManager(dependencies.settings.async_db_url)
|
||||||
|
async with sessionmanager.session() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
view_dependencies = FrontendDependencies.create(
|
view_dependencies = FrontendDependencies.create(
|
||||||
dependencies,
|
dependencies,
|
||||||
get_admin_backend,
|
get_admin_backend,
|
||||||
@ -123,6 +132,7 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
|||||||
get_user_from_access_token,
|
get_user_from_access_token,
|
||||||
get_user_from_refresh_token,
|
get_user_from_refresh_token,
|
||||||
get_login_status,
|
get_login_status,
|
||||||
|
get_async_session,
|
||||||
)
|
)
|
||||||
|
|
||||||
app.include_router(audit.create_router(view_dependencies))
|
app.include_router(audit.create_router(view_dependencies))
|
||||||
|
|||||||
@ -8,13 +8,13 @@ from fastapi import APIRouter, Depends, Query, Request, Response, status
|
|||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sshecret_admin.services import AdminBackend
|
from sshecret_admin.services import AdminBackend
|
||||||
from starlette.datastructures import URL
|
from starlette.datastructures import URL
|
||||||
|
|
||||||
from sshecret_admin.auth import (
|
from sshecret_admin.auth import (
|
||||||
User,
|
User,
|
||||||
authenticate_user,
|
authenticate_user_async,
|
||||||
create_access_token,
|
create_access_token,
|
||||||
create_refresh_token,
|
create_refresh_token,
|
||||||
)
|
)
|
||||||
@ -80,7 +80,7 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
|||||||
async def login_user(
|
async def login_user(
|
||||||
request: Request,
|
request: Request,
|
||||||
response: Response,
|
response: Response,
|
||||||
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
session: Annotated[AsyncSession, Depends(dependencies.get_async_session)],
|
||||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||||
next: Annotated[str, Query()] = "/dashboard",
|
next: Annotated[str, Query()] = "/dashboard",
|
||||||
@ -100,7 +100,7 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
user = authenticate_user(session, form_data.username, form_data.password)
|
user = await authenticate_user_async(session, form_data.username, form_data.password)
|
||||||
login_failed = RedirectException(
|
login_failed = RedirectException(
|
||||||
to=URL("/login").include_query_params(
|
to=URL("/login").include_query_params(
|
||||||
error_title="Login Error", error_message="Invalid username or password"
|
error_title="Login Error", error_message="Invalid username or password"
|
||||||
|
|||||||
Reference in New Issue
Block a user