Implement async db access in admin
This commit is contained in:
@ -1,11 +1,15 @@
|
||||
"""Database setup."""
|
||||
|
||||
from collections.abc import Generator, Callable
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from collections.abc import AsyncIterator, Generator, Callable
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.engine import URL
|
||||
from sqlalchemy import create_engine, Engine
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine, AsyncEngine, AsyncSession, async_sessionmaker
|
||||
|
||||
|
||||
def setup_database(
|
||||
db_url: URL | str,
|
||||
@ -20,3 +24,43 @@ def setup_database(
|
||||
yield session
|
||||
|
||||
return engine, get_db_session
|
||||
|
||||
|
||||
class DatabaseSessionManager:
|
||||
def __init__(self, host: URL | str, **engine_kwargs: str):
|
||||
self._engine: AsyncEngine | None = create_async_engine(host, **engine_kwargs)
|
||||
self._sessionmaker: async_sessionmaker[AsyncSession] | None = async_sessionmaker(autocommit=False, bind=self._engine, expire_on_commit=False)
|
||||
|
||||
async def close(self):
|
||||
if self._engine is None:
|
||||
raise Exception("DatabaseSessionManager is not initialized")
|
||||
await self._engine.dispose()
|
||||
|
||||
self._engine = None
|
||||
self._sessionmaker = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def connect(self) -> AsyncIterator[AsyncConnection]:
|
||||
if self._engine is None:
|
||||
raise Exception("DatabaseSessionManager is not initialized")
|
||||
|
||||
async with self._engine.begin() as connection:
|
||||
try:
|
||||
yield connection
|
||||
except Exception:
|
||||
await connection.rollback()
|
||||
raise
|
||||
|
||||
@asynccontextmanager
|
||||
async def session(self) -> AsyncIterator[AsyncSession]:
|
||||
if self._sessionmaker is None:
|
||||
raise Exception("DatabaseSessionManager is not initialized")
|
||||
|
||||
session = self._sessionmaker()
|
||||
try:
|
||||
yield session
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
@ -4,6 +4,7 @@ from collections.abc import AsyncGenerator, Awaitable, Callable, Generator
|
||||
from dataclasses import dataclass
|
||||
from typing import Self
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
from sshecret_admin.auth import User
|
||||
from sshecret_admin.services import AdminBackend
|
||||
|
||||
@ -31,3 +31,8 @@ class AdminServerSettings(BaseSettings):
|
||||
def admin_db(self) -> URL:
|
||||
"""Construct database url."""
|
||||
return URL.create(drivername="sqlite", database=self.database)
|
||||
|
||||
@property
|
||||
def async_db_url(self) -> URL:
|
||||
"""Construct database url with sync handling."""
|
||||
return URL.create(drivername="sqlite+aiosqlite", database=self.database)
|
||||
|
||||
Reference in New Issue
Block a user