Implement async db access in admin

This commit is contained in:
2025-05-19 09:22:02 +02:00
parent fc0c3fb950
commit 5865cc450f
8 changed files with 85 additions and 7 deletions

View File

@ -2,6 +2,7 @@
from .authentication import ( from .authentication import (
authenticate_user, authenticate_user,
authenticate_user_async,
create_access_token, create_access_token,
create_refresh_token, create_refresh_token,
check_password, check_password,
@ -16,6 +17,7 @@ __all__ = [
"Token", "Token",
"User", "User",
"authenticate_user", "authenticate_user",
"authenticate_user_async",
"check_password", "check_password",
"create_access_token", "create_access_token",
"create_refresh_token", "create_refresh_token",

View File

@ -8,6 +8,7 @@ import bcrypt
import jwt import jwt
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sshecret_admin.core.settings import AdminServerSettings from sshecret_admin.core.settings import AdminServerSettings
@ -72,6 +73,16 @@ def check_password(plain_password: str, hashed_password: str) -> None:
raise AuthenticationFailedError() raise AuthenticationFailedError()
async def authenticate_user_async(session: AsyncSession, username: str, password: str) -> User | None:
"""Authenticate user async."""
user = (await session.scalars(select(User).where(User.username == username))).first()
if not user:
return None
if not verify_password(password, user.hashed_password):
return None
return user
def authenticate_user(session: Session, username: str, password: str) -> User | None: def authenticate_user(session: Session, username: str, password: str) -> User | None:
"""Authenticate user.""" """Authenticate user."""
user = session.scalars(select(User).where(User.username == username)).first() user = session.scalars(select(User).where(User.username == username)).first()

View File

@ -1,11 +1,15 @@
"""Database setup.""" """Database setup."""
from collections.abc import Generator, Callable from contextlib import asynccontextmanager
from collections.abc import AsyncIterator, Generator, Callable
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy.engine import URL from sqlalchemy.engine import URL
from sqlalchemy import create_engine, Engine from sqlalchemy import create_engine, Engine
from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine, AsyncEngine, AsyncSession, async_sessionmaker
def setup_database( def setup_database(
db_url: URL | str, db_url: URL | str,
@ -20,3 +24,43 @@ def setup_database(
yield session yield session
return engine, get_db_session return engine, get_db_session
class DatabaseSessionManager:
def __init__(self, host: URL | str, **engine_kwargs: str):
self._engine: AsyncEngine | None = create_async_engine(host, **engine_kwargs)
self._sessionmaker: async_sessionmaker[AsyncSession] | None = async_sessionmaker(autocommit=False, bind=self._engine, expire_on_commit=False)
async def close(self):
if self._engine is None:
raise Exception("DatabaseSessionManager is not initialized")
await self._engine.dispose()
self._engine = None
self._sessionmaker = None
@asynccontextmanager
async def connect(self) -> AsyncIterator[AsyncConnection]:
if self._engine is None:
raise Exception("DatabaseSessionManager is not initialized")
async with self._engine.begin() as connection:
try:
yield connection
except Exception:
await connection.rollback()
raise
@asynccontextmanager
async def session(self) -> AsyncIterator[AsyncSession]:
if self._sessionmaker is None:
raise Exception("DatabaseSessionManager is not initialized")
session = self._sessionmaker()
try:
yield session
except Exception:
await session.rollback()
raise
finally:
await session.close()

View File

@ -4,6 +4,7 @@ from collections.abc import AsyncGenerator, Awaitable, Callable, Generator
from dataclasses import dataclass from dataclasses import dataclass
from typing import Self from typing import Self
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sshecret_admin.auth import User from sshecret_admin.auth import User
from sshecret_admin.services import AdminBackend from sshecret_admin.services import AdminBackend

View File

@ -31,3 +31,8 @@ class AdminServerSettings(BaseSettings):
def admin_db(self) -> URL: def admin_db(self) -> URL:
"""Construct database url.""" """Construct database url."""
return URL.create(drivername="sqlite", database=self.database) return URL.create(drivername="sqlite", database=self.database)
@property
def async_db_url(self) -> URL:
"""Construct database url with sync handling."""
return URL.create(drivername="sqlite+aiosqlite", database=self.database)

View File

@ -1,10 +1,11 @@
"""Frontend dependencies.""" """Frontend dependencies."""
from dataclasses import dataclass from dataclasses import dataclass
from collections.abc import Callable, Awaitable from collections.abc import AsyncGenerator, Callable, Awaitable
from typing import Self from typing import Self
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from jinja2_fragments.fastapi import Jinja2Blocks from jinja2_fragments.fastapi import Jinja2Blocks
from fastapi import Request from fastapi import Request
@ -14,6 +15,7 @@ from sshecret_admin.auth.models import User
UserTokenDep = Callable[[Request, Session], Awaitable[User]] UserTokenDep = Callable[[Request, Session], Awaitable[User]]
UserLoginDep = Callable[[Request, Session], Awaitable[bool]] UserLoginDep = Callable[[Request, Session], Awaitable[bool]]
AsyncSessionDep = Callable[[], AsyncGenerator[AsyncSession, None]]
@dataclass @dataclass
@ -25,6 +27,7 @@ class FrontendDependencies(BaseDependencies):
get_user_from_access_token: UserTokenDep get_user_from_access_token: UserTokenDep
get_user_from_refresh_token: UserTokenDep get_user_from_refresh_token: UserTokenDep
get_login_status: UserLoginDep get_login_status: UserLoginDep
get_async_session: AsyncSessionDep
@classmethod @classmethod
def create( def create(
@ -35,6 +38,7 @@ class FrontendDependencies(BaseDependencies):
get_user_from_access_token: UserTokenDep, get_user_from_access_token: UserTokenDep,
get_user_from_refresh_token: UserTokenDep, get_user_from_refresh_token: UserTokenDep,
get_login_status: UserLoginDep, get_login_status: UserLoginDep,
get_async_session: AsyncSessionDep
) -> Self: ) -> Self:
"""Create from base dependencies.""" """Create from base dependencies."""
return cls( return cls(
@ -45,4 +49,5 @@ class FrontendDependencies(BaseDependencies):
get_user_from_access_token=get_user_from_access_token, get_user_from_access_token=get_user_from_access_token,
get_user_from_refresh_token=get_user_from_refresh_token, get_user_from_refresh_token=get_user_from_refresh_token,
get_login_status=get_login_status, get_login_status=get_login_status,
get_async_session=get_async_session,
) )

View File

@ -19,6 +19,7 @@ from starlette.datastructures import URL
from sshecret_admin.auth import PasswordDB, User, decode_token from sshecret_admin.auth import PasswordDB, User, decode_token
from sshecret_admin.core.dependencies import BaseDependencies from sshecret_admin.core.dependencies import BaseDependencies
from sshecret_admin.services.admin_backend import AdminBackend from sshecret_admin.services.admin_backend import AdminBackend
from sshecret_admin.core.db import DatabaseSessionManager
from .dependencies import FrontendDependencies from .dependencies import FrontendDependencies
from .exceptions import RedirectException from .exceptions import RedirectException
@ -47,7 +48,9 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
session: Annotated[Session, Depends(dependencies.get_db_session)] session: Annotated[Session, Depends(dependencies.get_db_session)]
): ):
"""Get admin backend API.""" """Get admin backend API."""
password_db = session.scalars(select(PasswordDB).where(PasswordDB.id == 1)).first() password_db = session.scalars(
select(PasswordDB).where(PasswordDB.id == 1)
).first()
if not password_db: if not password_db:
raise HTTPException( raise HTTPException(
500, detail="Error: The password manager has not yet been set up." 500, detail="Error: The password manager has not yet been set up."
@ -116,6 +119,12 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
return False return False
return True return True
async def get_async_session():
"""Get async session."""
sessionmanager = DatabaseSessionManager(dependencies.settings.async_db_url)
async with sessionmanager.session() as session:
yield session
view_dependencies = FrontendDependencies.create( view_dependencies = FrontendDependencies.create(
dependencies, dependencies,
get_admin_backend, get_admin_backend,
@ -123,6 +132,7 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
get_user_from_access_token, get_user_from_access_token,
get_user_from_refresh_token, get_user_from_refresh_token,
get_login_status, get_login_status,
get_async_session,
) )
app.include_router(audit.create_router(view_dependencies)) app.include_router(audit.create_router(view_dependencies))

View File

@ -8,13 +8,13 @@ from fastapi import APIRouter, Depends, Query, Request, Response, status
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.orm import Session from sqlalchemy.ext.asyncio import AsyncSession
from sshecret_admin.services import AdminBackend from sshecret_admin.services import AdminBackend
from starlette.datastructures import URL from starlette.datastructures import URL
from sshecret_admin.auth import ( from sshecret_admin.auth import (
User, User,
authenticate_user, authenticate_user_async,
create_access_token, create_access_token,
create_refresh_token, create_refresh_token,
) )
@ -80,7 +80,7 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
async def login_user( async def login_user(
request: Request, request: Request,
response: Response, response: Response,
session: Annotated[Session, Depends(dependencies.get_db_session)], session: Annotated[AsyncSession, Depends(dependencies.get_async_session)],
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)], admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
form_data: Annotated[OAuth2PasswordRequestForm, Depends()], form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
next: Annotated[str, Query()] = "/dashboard", next: Annotated[str, Query()] = "/dashboard",
@ -100,7 +100,7 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
}, },
) )
user = authenticate_user(session, form_data.username, form_data.password) user = await authenticate_user_async(session, form_data.username, form_data.password)
login_failed = RedirectException( login_failed = RedirectException(
to=URL("/login").include_query_params( to=URL("/login").include_query_params(
error_title="Login Error", error_message="Invalid username or password" error_title="Login Error", error_message="Invalid username or password"