Compare commits

...

2 Commits

Author SHA1 Message Date
5865cc450f Implement async db access in admin 2025-05-19 09:22:02 +02:00
fc0c3fb950 Refactor to use async database model 2025-05-19 09:15:48 +02:00
19 changed files with 373 additions and 192 deletions

View File

@ -2,6 +2,7 @@
from .authentication import ( from .authentication import (
authenticate_user, authenticate_user,
authenticate_user_async,
create_access_token, create_access_token,
create_refresh_token, create_refresh_token,
check_password, check_password,
@ -16,6 +17,7 @@ __all__ = [
"Token", "Token",
"User", "User",
"authenticate_user", "authenticate_user",
"authenticate_user_async",
"check_password", "check_password",
"create_access_token", "create_access_token",
"create_refresh_token", "create_refresh_token",

View File

@ -8,6 +8,7 @@ import bcrypt
import jwt import jwt
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sshecret_admin.core.settings import AdminServerSettings from sshecret_admin.core.settings import AdminServerSettings
@ -72,6 +73,16 @@ def check_password(plain_password: str, hashed_password: str) -> None:
raise AuthenticationFailedError() raise AuthenticationFailedError()
async def authenticate_user_async(session: AsyncSession, username: str, password: str) -> User | None:
"""Authenticate user async."""
user = (await session.scalars(select(User).where(User.username == username))).first()
if not user:
return None
if not verify_password(password, user.hashed_password):
return None
return user
def authenticate_user(session: Session, username: str, password: str) -> User | None: def authenticate_user(session: Session, username: str, password: str) -> User | None:
"""Authenticate user.""" """Authenticate user."""
user = session.scalars(select(User).where(User.username == username)).first() user = session.scalars(select(User).where(User.username == username)).first()

View File

@ -1,11 +1,15 @@
"""Database setup.""" """Database setup."""
from collections.abc import Generator, Callable from contextlib import asynccontextmanager
from collections.abc import AsyncIterator, Generator, Callable
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy.engine import URL from sqlalchemy.engine import URL
from sqlalchemy import create_engine, Engine from sqlalchemy import create_engine, Engine
from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine, AsyncEngine, AsyncSession, async_sessionmaker
def setup_database( def setup_database(
db_url: URL | str, db_url: URL | str,
@ -20,3 +24,43 @@ def setup_database(
yield session yield session
return engine, get_db_session return engine, get_db_session
class DatabaseSessionManager:
def __init__(self, host: URL | str, **engine_kwargs: str):
self._engine: AsyncEngine | None = create_async_engine(host, **engine_kwargs)
self._sessionmaker: async_sessionmaker[AsyncSession] | None = async_sessionmaker(autocommit=False, bind=self._engine, expire_on_commit=False)
async def close(self):
if self._engine is None:
raise Exception("DatabaseSessionManager is not initialized")
await self._engine.dispose()
self._engine = None
self._sessionmaker = None
@asynccontextmanager
async def connect(self) -> AsyncIterator[AsyncConnection]:
if self._engine is None:
raise Exception("DatabaseSessionManager is not initialized")
async with self._engine.begin() as connection:
try:
yield connection
except Exception:
await connection.rollback()
raise
@asynccontextmanager
async def session(self) -> AsyncIterator[AsyncSession]:
if self._sessionmaker is None:
raise Exception("DatabaseSessionManager is not initialized")
session = self._sessionmaker()
try:
yield session
except Exception:
await session.rollback()
raise
finally:
await session.close()

View File

@ -4,6 +4,7 @@ from collections.abc import AsyncGenerator, Awaitable, Callable, Generator
from dataclasses import dataclass from dataclasses import dataclass
from typing import Self from typing import Self
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sshecret_admin.auth import User from sshecret_admin.auth import User
from sshecret_admin.services import AdminBackend from sshecret_admin.services import AdminBackend

View File

@ -31,3 +31,8 @@ class AdminServerSettings(BaseSettings):
def admin_db(self) -> URL: def admin_db(self) -> URL:
"""Construct database url.""" """Construct database url."""
return URL.create(drivername="sqlite", database=self.database) return URL.create(drivername="sqlite", database=self.database)
@property
def async_db_url(self) -> URL:
"""Construct database url with sync handling."""
return URL.create(drivername="sqlite+aiosqlite", database=self.database)

View File

@ -1,10 +1,11 @@
"""Frontend dependencies.""" """Frontend dependencies."""
from dataclasses import dataclass from dataclasses import dataclass
from collections.abc import Callable, Awaitable from collections.abc import AsyncGenerator, Callable, Awaitable
from typing import Self from typing import Self
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from jinja2_fragments.fastapi import Jinja2Blocks from jinja2_fragments.fastapi import Jinja2Blocks
from fastapi import Request from fastapi import Request
@ -14,6 +15,7 @@ from sshecret_admin.auth.models import User
UserTokenDep = Callable[[Request, Session], Awaitable[User]] UserTokenDep = Callable[[Request, Session], Awaitable[User]]
UserLoginDep = Callable[[Request, Session], Awaitable[bool]] UserLoginDep = Callable[[Request, Session], Awaitable[bool]]
AsyncSessionDep = Callable[[], AsyncGenerator[AsyncSession, None]]
@dataclass @dataclass
@ -25,6 +27,7 @@ class FrontendDependencies(BaseDependencies):
get_user_from_access_token: UserTokenDep get_user_from_access_token: UserTokenDep
get_user_from_refresh_token: UserTokenDep get_user_from_refresh_token: UserTokenDep
get_login_status: UserLoginDep get_login_status: UserLoginDep
get_async_session: AsyncSessionDep
@classmethod @classmethod
def create( def create(
@ -35,6 +38,7 @@ class FrontendDependencies(BaseDependencies):
get_user_from_access_token: UserTokenDep, get_user_from_access_token: UserTokenDep,
get_user_from_refresh_token: UserTokenDep, get_user_from_refresh_token: UserTokenDep,
get_login_status: UserLoginDep, get_login_status: UserLoginDep,
get_async_session: AsyncSessionDep
) -> Self: ) -> Self:
"""Create from base dependencies.""" """Create from base dependencies."""
return cls( return cls(
@ -45,4 +49,5 @@ class FrontendDependencies(BaseDependencies):
get_user_from_access_token=get_user_from_access_token, get_user_from_access_token=get_user_from_access_token,
get_user_from_refresh_token=get_user_from_refresh_token, get_user_from_refresh_token=get_user_from_refresh_token,
get_login_status=get_login_status, get_login_status=get_login_status,
get_async_session=get_async_session,
) )

View File

@ -19,6 +19,7 @@ from starlette.datastructures import URL
from sshecret_admin.auth import PasswordDB, User, decode_token from sshecret_admin.auth import PasswordDB, User, decode_token
from sshecret_admin.core.dependencies import BaseDependencies from sshecret_admin.core.dependencies import BaseDependencies
from sshecret_admin.services.admin_backend import AdminBackend from sshecret_admin.services.admin_backend import AdminBackend
from sshecret_admin.core.db import DatabaseSessionManager
from .dependencies import FrontendDependencies from .dependencies import FrontendDependencies
from .exceptions import RedirectException from .exceptions import RedirectException
@ -47,7 +48,9 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
session: Annotated[Session, Depends(dependencies.get_db_session)] session: Annotated[Session, Depends(dependencies.get_db_session)]
): ):
"""Get admin backend API.""" """Get admin backend API."""
password_db = session.scalars(select(PasswordDB).where(PasswordDB.id == 1)).first() password_db = session.scalars(
select(PasswordDB).where(PasswordDB.id == 1)
).first()
if not password_db: if not password_db:
raise HTTPException( raise HTTPException(
500, detail="Error: The password manager has not yet been set up." 500, detail="Error: The password manager has not yet been set up."
@ -116,6 +119,12 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
return False return False
return True return True
async def get_async_session():
"""Get async session."""
sessionmanager = DatabaseSessionManager(dependencies.settings.async_db_url)
async with sessionmanager.session() as session:
yield session
view_dependencies = FrontendDependencies.create( view_dependencies = FrontendDependencies.create(
dependencies, dependencies,
get_admin_backend, get_admin_backend,
@ -123,6 +132,7 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
get_user_from_access_token, get_user_from_access_token,
get_user_from_refresh_token, get_user_from_refresh_token,
get_login_status, get_login_status,
get_async_session,
) )
app.include_router(audit.create_router(view_dependencies)) app.include_router(audit.create_router(view_dependencies))

View File

@ -8,13 +8,13 @@ from fastapi import APIRouter, Depends, Query, Request, Response, status
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.orm import Session from sqlalchemy.ext.asyncio import AsyncSession
from sshecret_admin.services import AdminBackend from sshecret_admin.services import AdminBackend
from starlette.datastructures import URL from starlette.datastructures import URL
from sshecret_admin.auth import ( from sshecret_admin.auth import (
User, User,
authenticate_user, authenticate_user_async,
create_access_token, create_access_token,
create_refresh_token, create_refresh_token,
) )
@ -80,7 +80,7 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
async def login_user( async def login_user(
request: Request, request: Request,
response: Response, response: Response,
session: Annotated[Session, Depends(dependencies.get_db_session)], session: Annotated[AsyncSession, Depends(dependencies.get_async_session)],
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)], admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
form_data: Annotated[OAuth2PasswordRequestForm, Depends()], form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
next: Annotated[str, Query()] = "/dashboard", next: Annotated[str, Query()] = "/dashboard",
@ -100,7 +100,7 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
}, },
) )
user = authenticate_user(session, form_data.username, form_data.password) user = await authenticate_user_async(session, form_data.username, form_data.password)
login_failed = RedirectException( login_failed = RedirectException(
to=URL("/login").include_query_params( to=URL("/login").include_query_params(
error_title="Login Error", error_message="Invalid username or password" error_title="Login Error", error_message="Invalid username or password"

View File

@ -3,18 +3,18 @@
# pyright: reportUnusedFunction=false # pyright: reportUnusedFunction=false
import logging import logging
from collections.abc import Sequence from typing import Any
from typing import Any, cast from fastapi import APIRouter, Depends, Request
from fastapi import APIRouter, Depends, Request, Query
from pydantic import BaseModel, Field, TypeAdapter from pydantic import BaseModel, Field, TypeAdapter
from sqlalchemy import select, func, and_ from sqlalchemy import select, func, and_
from sqlalchemy.orm import InstrumentedAttribute, Session from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import InstrumentedAttribute
from sqlalchemy.sql.expression import ColumnExpressionArgument from sqlalchemy.sql.expression import ColumnExpressionArgument
from typing import Annotated from typing import Annotated
from sshecret_backend.models import AuditLog, Operation, SubSystem from sshecret_backend.models import AuditLog, Operation, SubSystem
from sshecret_backend.types import DBSessionDep from sshecret_backend.types import AsyncDBSessionDep
from sshecret_backend.view_models import AuditInfo, AuditView, AuditListResult from sshecret_backend.view_models import AuditInfo, AuditView, AuditListResult
@ -58,24 +58,23 @@ class AuditFilter(BaseModel):
] ]
def get_audit_api(get_db_session: DBSessionDep) -> APIRouter: def get_audit_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
"""Construct audit sub-api.""" """Construct audit sub-api."""
router = APIRouter() router = APIRouter()
@router.get("/audit/", response_model=AuditListResult) @router.get("/audit/", response_model=AuditListResult)
async def get_audit_logs( async def get_audit_logs(
request: Request, session: Annotated[AsyncSession, Depends(get_db_session)],
session: Annotated[Session, Depends(get_db_session)],
filters: Annotated[AuditFilter, Depends()], filters: Annotated[AuditFilter, Depends()],
) -> AuditListResult: ) -> AuditListResult:
"""Get audit logs.""" """Get audit logs."""
# audit.audit_access_audit_log(session, request) # audit.audit_access_audit_log(session, request)
total = session.scalars( total = (await session.scalars(
select(func.count("*")) select(func.count("*"))
.select_from(AuditLog) .select_from(AuditLog)
.where(and_(True, *filters.filter_mapping)) .where(and_(True, *filters.filter_mapping))
).one() )).one()
remaining = total - filters.offset remaining = total - filters.offset
statement = ( statement = (
@ -87,7 +86,7 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
) )
LogAdapt = TypeAdapter(list[AuditView]) LogAdapt = TypeAdapter(list[AuditView])
results = session.scalars(statement).all() results = (await session.scalars(statement)).all()
entries = LogAdapt.validate_python(results, from_attributes=True) entries = LogAdapt.validate_python(results, from_attributes=True)
return AuditListResult( return AuditListResult(
results=entries, results=entries,
@ -97,24 +96,23 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
@router.post("/audit/") @router.post("/audit/")
async def add_audit_log( async def add_audit_log(
request: Request, session: Annotated[AsyncSession, Depends(get_db_session)],
session: Annotated[Session, Depends(get_db_session)],
entry: AuditView, entry: AuditView,
) -> AuditView: ) -> AuditView:
"""Add entry to audit log.""" """Add entry to audit log."""
audit_log = AuditLog(**entry.model_dump(exclude_none=True)) audit_log = AuditLog(**entry.model_dump(exclude_none=True))
session.add(audit_log) session.add(audit_log)
session.commit() await session.commit()
return AuditView.model_validate(audit_log, from_attributes=True) return AuditView.model_validate(audit_log, from_attributes=True)
@router.get("/audit/info") @router.get("/audit/info")
async def get_audit_info( async def get_audit_info(
request: Request, session: Annotated[Session, Depends(get_db_session)] session: Annotated[AsyncSession, Depends(get_db_session)]
) -> AuditInfo: ) -> AuditInfo:
"""Get audit info.""" """Get audit info."""
audit_count = session.scalars( audit_count = (await session.scalars(
select(func.count("*")).select_from(AuditLog) select(func.count("*")).select_from(AuditLog)
).one() )).one()
return AuditInfo(entries=audit_count) return AuditInfo(entries=audit_count)
return router return router

View File

@ -4,14 +4,14 @@
import uuid import uuid
import logging import logging
from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
from typing import Annotated, Any, Self, TypeVar, cast from typing import Annotated, Any, Self, TypeVar, cast
from sqlalchemy import select, func from sqlalchemy import select, func
from sqlalchemy.orm import Session from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import Select from sqlalchemy.sql import Select
from sshecret_backend.types import DBSessionDep from sshecret_backend.types import AsyncDBSessionDep
from sshecret_backend.models import Client, ClientSecret from sshecret_backend.models import Client, ClientSecret
from sshecret_backend.view_models import ( from sshecret_backend.view_models import (
ClientCreate, ClientCreate,
@ -20,7 +20,7 @@ from sshecret_backend.view_models import (
ClientUpdate, ClientUpdate,
) )
from sshecret_backend import audit from sshecret_backend import audit
from .common import get_client_by_id_or_name from .common import get_client_by_id_or_name, client_with_relationships
class ClientListParams(BaseModel): class ClientListParams(BaseModel):
@ -74,30 +74,30 @@ def filter_client_statement(
return statement.limit(params.limit).offset(params.offset) return statement.limit(params.limit).offset(params.offset)
def get_clients_api(get_db_session: DBSessionDep) -> APIRouter: def get_clients_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
"""Construct clients sub-api.""" """Construct clients sub-api."""
router = APIRouter() router = APIRouter()
@router.get("/clients/") @router.get("/clients/")
async def get_clients( async def get_clients(
filter_query: Annotated[ClientListParams, Query()], filter_query: Annotated[ClientListParams, Query()],
session: Annotated[Session, Depends(get_db_session)], session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientQueryResult: ) -> ClientQueryResult:
"""Get clients.""" """Get clients."""
# Get total results first # Get total results first
count_statement = select(func.count("*")).select_from(Client) count_statement = select(func.count("*")).select_from(Client)
count_statement = cast(Select[tuple[int]], filter_client_statement(count_statement, filter_query, True)) count_statement = cast(Select[tuple[int]], filter_client_statement(count_statement, filter_query, True))
total_results = session.scalars(count_statement).one() total_results = (await session.scalars(count_statement)).one()
statement = filter_client_statement(select(Client), filter_query, False) statement = filter_client_statement(client_with_relationships(), filter_query, False)
results = session.scalars(statement) results = await session.scalars(statement)
remainder = total_results - filter_query.offset - filter_query.limit remainder = total_results - filter_query.offset - filter_query.limit
if remainder < 0: if remainder < 0:
remainder = 0 remainder = 0
clients = list(results) clients = list(results.all())
clients_view = ClientView.from_client_list(clients) clients_view = ClientView.from_client_list(clients)
return ClientQueryResult( return ClientQueryResult(
clients=clients_view, clients=clients_view,
@ -108,7 +108,7 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
@router.get("/clients/{name}") @router.get("/clients/{name}")
async def get_client( async def get_client(
name: str, name: str,
session: Annotated[Session, Depends(get_db_session)], session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView: ) -> ClientView:
"""Fetch a client.""" """Fetch a client."""
client = await get_client_by_id_or_name(session, name) client = await get_client_by_id_or_name(session, name)
@ -122,7 +122,7 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
async def delete_client( async def delete_client(
request: Request, request: Request,
name: str, name: str,
session: Annotated[Session, Depends(get_db_session)], session: Annotated[AsyncSession, Depends(get_db_session)],
) -> None: ) -> None:
"""Delete a client.""" """Delete a client."""
client = await get_client_by_id_or_name(session, name) client = await get_client_by_id_or_name(session, name)
@ -131,15 +131,15 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
status_code=404, detail="Cannot find a client with the given name." status_code=404, detail="Cannot find a client with the given name."
) )
session.delete(client) await session.delete(client)
session.commit() await session.commit()
audit.audit_delete_client(session, request, client) await audit.audit_delete_client(session, request, client)
@router.post("/clients/") @router.post("/clients/")
async def create_client( async def create_client(
request: Request, request: Request,
client: ClientCreate, client: ClientCreate,
session: Annotated[Session, Depends(get_db_session)], session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView: ) -> ClientView:
"""Create client.""" """Create client."""
existing = await get_client_by_id_or_name(session, client.name) existing = await get_client_by_id_or_name(session, client.name)
@ -148,9 +148,12 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
db_client = client.to_client() db_client = client.to_client()
session.add(db_client) session.add(db_client)
session.commit() await session.commit()
session.refresh(db_client) await session.refresh(db_client)
audit.audit_create_client(session, request, db_client) db_client = await get_client_by_id_or_name(session, client.name)
if not db_client:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Could not create the client.")
await audit.audit_create_client(session, request, db_client)
return ClientView.from_client(db_client) return ClientView.from_client(db_client)
@router.post("/clients/{name}/public-key") @router.post("/clients/{name}/public-key")
@ -158,7 +161,7 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
request: Request, request: Request,
name: str, name: str,
client_update: ClientUpdate, client_update: ClientUpdate,
session: Annotated[Session, Depends(get_db_session)], session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView: ) -> ClientView:
"""Change the public key of a client. """Change the public key of a client.
@ -170,17 +173,16 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
status_code=404, detail="Cannot find a client with the given name." status_code=404, detail="Cannot find a client with the given name."
) )
client.public_key = client_update.public_key client.public_key = client_update.public_key
for secret in session.scalars( matching_secrets = await session.scalars(select(ClientSecret).where(ClientSecret.client_id == client.id))
select(ClientSecret).where(ClientSecret.client_id == client.id) for secret in matching_secrets.all():
).all():
LOG.debug("Invalidated secret %s", secret.id) LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True secret.invalidated = True
secret.client_id = None secret.client_id = None
session.add(client) session.add(client)
session.refresh(client) await session.refresh(client)
session.commit() await session.commit()
audit.audit_invalidate_secrets(session, request, client) await audit.audit_invalidate_secrets(session, request, client)
return ClientView.from_client(client) return ClientView.from_client(client)
@ -189,7 +191,7 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
request: Request, request: Request,
name: str, name: str,
client_update: ClientCreate, client_update: ClientCreate,
session: Annotated[Session, Depends(get_db_session)], session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView: ) -> ClientView:
"""Change the public key of a client. """Change the public key of a client.
@ -205,19 +207,20 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
public_key_updated = False public_key_updated = False
if client_update.public_key != client.public_key: if client_update.public_key != client.public_key:
public_key_updated = True public_key_updated = True
for secret in session.scalars( client_secrets = await session.scalars(
select(ClientSecret).where(ClientSecret.client_id == client.id) select(ClientSecret).where(ClientSecret.client_id == client.id)
).all(): )
for secret in client_secrets.all():
LOG.debug("Invalidated secret %s", secret.id) LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True secret.invalidated = True
secret.client_id = None secret.client_id = None
session.add(client) session.add(client)
session.commit() await session.commit()
session.refresh(client) await session.refresh(client)
audit.audit_update_client(session, request, client) await audit.audit_update_client(session, request, client)
if public_key_updated: if public_key_updated:
audit.audit_invalidate_secrets(session, request, client) await audit.audit_invalidate_secrets(session, request, client)
return ClientView.from_client(client) return ClientView.from_client(client)

View File

@ -3,9 +3,11 @@
import re import re
import uuid import uuid
import bcrypt import bcrypt
from sqlalchemy import Select
from sqlalchemy.orm import selectinload
from sqlalchemy import select from sqlalchemy.future import select
from sqlalchemy.orm import Session from sqlalchemy.ext.asyncio import AsyncSession
from sshecret_backend.models import Client from sshecret_backend.models import Client
@ -17,20 +19,38 @@ def verify_token(token: str, stored_hash: str) -> bool:
stored_bytes = stored_hash.encode("utf-8") stored_bytes = stored_hash.encode("utf-8")
return bcrypt.checkpw(token_bytes, stored_bytes) return bcrypt.checkpw(token_bytes, stored_bytes)
async def reload_client_with_relationships(session: AsyncSession, client: Client) -> Client:
"""Reload a client from the database."""
session.expunge(client)
stmt = (
select(Client)
.options(selectinload(Client.policies), selectinload(Client.secrets))
.where(Client.id == client.id)
)
result = await session.execute(stmt)
return result.scalar_one()
async def get_client_by_name(session: Session, name: str) -> Client | None:
def client_with_relationships() -> Select[tuple[Client]]:
"""Base select statement for client with relationships."""
return select(Client).options(
selectinload(Client.secrets),
selectinload(Client.policies),
)
async def get_client_by_name(session: AsyncSession, name: str) -> Client | None:
"""Get client by name.""" """Get client by name."""
client_filter = select(Client).where(Client.name == name) client_filter = client_with_relationships().where(Client.name == name)
client_results = session.scalars(client_filter) client_results = await session.execute(client_filter)
return client_results.first() return client_results.scalars().first()
async def get_client_by_id(session: Session, id: uuid.UUID) -> Client | None: async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | None:
"""Get client by name.""" """Get client by ID."""
client_filter = select(Client).where(Client.id == id) client_filter = client_with_relationships().where(Client.id == id)
client_results = session.scalars(client_filter) client_results = await session.execute(client_filter)
return client_results.first() return client_results.scalars().first()
async def get_client_by_id_or_name(session: Session, id_or_name: str) -> Client | None: async def get_client_by_id_or_name(session: AsyncSession, id_or_name: str) -> Client | None:
"""Get client either by id or name.""" """Get client either by id or name."""
if RE_UUID.match(id_or_name): if RE_UUID.match(id_or_name):
id = uuid.UUID(id_or_name) id = uuid.UUID(id_or_name)

View File

@ -5,7 +5,7 @@
import logging import logging
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.ext.asyncio import AsyncSession
from typing import Annotated from typing import Annotated
from sshecret_backend.models import ClientAccessPolicy from sshecret_backend.models import ClientAccessPolicy
@ -13,21 +13,21 @@ from sshecret_backend.view_models import (
ClientPolicyView, ClientPolicyView,
ClientPolicyUpdate, ClientPolicyUpdate,
) )
from sshecret_backend.types import DBSessionDep from sshecret_backend.types import AsyncDBSessionDep
from sshecret_backend import audit from sshecret_backend import audit
from .common import get_client_by_id_or_name from .common import get_client_by_id_or_name, reload_client_with_relationships
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
def get_policy_api(get_db_session: DBSessionDep) -> APIRouter: def get_policy_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
"""Construct clients sub-api.""" """Construct clients sub-api."""
router = APIRouter() router = APIRouter()
@router.get("/clients/{name}/policies/") @router.get("/clients/{name}/policies/")
async def get_client_policies( async def get_client_policies(
name: str, session: Annotated[Session, Depends(get_db_session)] name: str, session: Annotated[AsyncSession, Depends(get_db_session)]
) -> ClientPolicyView: ) -> ClientPolicyView:
"""Get client policies.""" """Get client policies."""
client = await get_client_by_id_or_name(session, name) client = await get_client_by_id_or_name(session, name)
@ -43,7 +43,7 @@ def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
request: Request, request: Request,
name: str, name: str,
policy_update: ClientPolicyUpdate, policy_update: ClientPolicyUpdate,
session: Annotated[Session, Depends(get_db_session)], session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientPolicyView: ) -> ClientPolicyView:
"""Update client policies. """Update client policies.
@ -55,28 +55,31 @@ def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
status_code=404, detail="Cannot find a client with the given name." status_code=404, detail="Cannot find a client with the given name."
) )
# Remove old policies. # Remove old policies.
policies = session.scalars( policies = await session.scalars(
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id) select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
).all() )
deleted_policies: list[ClientAccessPolicy] = [] deleted_policies: list[ClientAccessPolicy] = []
added_policies: list[ClientAccessPolicy] = [] added_policies: list[ClientAccessPolicy] = []
for policy in policies: for policy in policies.all():
session.delete(policy) await session.delete(policy)
deleted_policies.append(policy) deleted_policies.append(policy)
LOG.debug("Updating client policies with: %r", policy_update.sources)
for source in policy_update.sources: for source in policy_update.sources:
LOG.debug("Source %r", source) LOG.debug("Source %r", source)
policy = ClientAccessPolicy(source=str(source), client_id=client.id) policy = ClientAccessPolicy(source=str(source), client_id=client.id)
session.add(policy) session.add(policy)
added_policies.append(policy) added_policies.append(policy)
session.commit() await session.flush()
session.refresh(client) await session.commit()
client = await reload_client_with_relationships(session, client)
for policy in deleted_policies: for policy in deleted_policies:
audit.audit_remove_policy(session, request, client, policy) await audit.audit_remove_policy(session, request, client, policy)
for policy in added_policies: for policy in added_policies:
audit.audit_update_policy(session, request, client, policy) await audit.audit_update_policy(session, request, client, policy)
return ClientPolicyView.from_client(client) return ClientPolicyView.from_client(client)

View File

@ -6,9 +6,11 @@ import logging
from collections import defaultdict from collections import defaultdict
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session
from typing import Annotated from typing import Annotated
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sshecret_backend.models import Client, ClientSecret from sshecret_backend.models import Client, ClientSecret
from sshecret_backend.view_models import ( from sshecret_backend.view_models import (
ClientReference, ClientReference,
@ -19,7 +21,7 @@ from sshecret_backend.view_models import (
ClientSecretResponse, ClientSecretResponse,
) )
from sshecret_backend import audit from sshecret_backend import audit
from sshecret_backend.types import DBSessionDep from sshecret_backend.types import AsyncDBSessionDep
from .common import get_client_by_id_or_name from .common import get_client_by_id_or_name
@ -27,7 +29,7 @@ LOG = logging.getLogger(__name__)
async def lookup_client_secret( async def lookup_client_secret(
session: Session, client: Client, name: str session: AsyncSession, client: Client, name: str
) -> ClientSecret | None: ) -> ClientSecret | None:
"""Look up a secret for a client.""" """Look up a secret for a client."""
statement = ( statement = (
@ -35,11 +37,11 @@ async def lookup_client_secret(
.where(ClientSecret.client_id == client.id) .where(ClientSecret.client_id == client.id)
.where(ClientSecret.name == name) .where(ClientSecret.name == name)
) )
results = session.scalars(statement) results = await session.scalars(statement)
return results.first() return results.first()
def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter: def get_secrets_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
"""Construct clients sub-api.""" """Construct clients sub-api."""
router = APIRouter() router = APIRouter()
@ -48,7 +50,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
request: Request, request: Request,
name: str, name: str,
client_secret: ClientSecretPublic, client_secret: ClientSecretPublic,
session: Annotated[Session, Depends(get_db_session)], session: Annotated[AsyncSession, Depends(get_db_session)],
) -> None: ) -> None:
"""Add secret to a client.""" """Add secret to a client."""
client = await get_client_by_id_or_name(session, name) client = await get_client_by_id_or_name(session, name)
@ -69,9 +71,9 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
name=client_secret.name, client_id=client.id, secret=client_secret.secret name=client_secret.name, client_id=client.id, secret=client_secret.secret
) )
session.add(db_secret) session.add(db_secret)
session.commit() await session.commit()
session.refresh(db_secret) await session.refresh(db_secret)
audit.audit_create_secret(session, request, client, db_secret) await audit.audit_create_secret(session, request, client, db_secret)
@router.put("/clients/{name}/secrets/{secret_name}") @router.put("/clients/{name}/secrets/{secret_name}")
async def update_client_secret( async def update_client_secret(
@ -79,7 +81,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
name: str, name: str,
secret_name: str, secret_name: str,
secret_data: BodyValue, secret_data: BodyValue,
session: Annotated[Session, Depends(get_db_session)], session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientSecretResponse: ) -> ClientSecretResponse:
"""Update a client secret. """Update a client secret.
@ -96,9 +98,9 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
existing_secret.secret = secret_data.value existing_secret.secret = secret_data.value
session.add(existing_secret) session.add(existing_secret)
session.commit() await session.commit()
session.refresh(existing_secret) await session.refresh(existing_secret)
audit.audit_update_secret(session, request, client, existing_secret) await audit.audit_update_secret(session, request, client, existing_secret)
return ClientSecretResponse.from_client_secret(existing_secret) return ClientSecretResponse.from_client_secret(existing_secret)
db_secret = ClientSecret( db_secret = ClientSecret(
@ -107,9 +109,9 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
secret=secret_data.value, secret=secret_data.value,
) )
session.add(db_secret) session.add(db_secret)
session.commit() await session.commit()
session.refresh(db_secret) await session.refresh(db_secret)
audit.audit_create_secret(session, request, client, db_secret) await audit.audit_create_secret(session, request, client, db_secret)
return ClientSecretResponse.from_client_secret(db_secret) return ClientSecretResponse.from_client_secret(db_secret)
@router.get("/clients/{name}/secrets/{secret_name}") @router.get("/clients/{name}/secrets/{secret_name}")
@ -117,7 +119,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
request: Request, request: Request,
name: str, name: str,
secret_name: str, secret_name: str,
session: Annotated[Session, Depends(get_db_session)], session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientSecretResponse: ) -> ClientSecretResponse:
"""Get a client secret.""" """Get a client secret."""
client = await get_client_by_id_or_name(session, name) client = await get_client_by_id_or_name(session, name)
@ -133,7 +135,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
) )
response_model = ClientSecretResponse.from_client_secret(secret) response_model = ClientSecretResponse.from_client_secret(secret)
audit.audit_access_secret(session, request, client, secret) await audit.audit_access_secret(session, request, client, secret)
return response_model return response_model
@router.delete("/clients/{name}/secrets/{secret_name}") @router.delete("/clients/{name}/secrets/{secret_name}")
@ -141,7 +143,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
request: Request, request: Request,
name: str, name: str,
secret_name: str, secret_name: str,
session: Annotated[Session, Depends(get_db_session)], session: Annotated[AsyncSession, Depends(get_db_session)],
) -> None: ) -> None:
"""Delete a secret.""" """Delete a secret."""
client = await get_client_by_id_or_name(session, name) client = await get_client_by_id_or_name(session, name)
@ -156,56 +158,69 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
status_code=404, detail="Cannot find a secret with the given name." status_code=404, detail="Cannot find a secret with the given name."
) )
session.delete(secret) await session.delete(secret)
session.commit() await session.commit()
audit.audit_delete_secret(session, request, client, secret) await audit.audit_delete_secret(session, request, client, secret)
@router.get("/secrets/") @router.get("/secrets/")
async def get_secret_map( async def get_secret_map(
request: Request, session: Annotated[Session, Depends(get_db_session)] session: Annotated[AsyncSession, Depends(get_db_session)],
) -> list[ClientSecretList]: ) -> list[ClientSecretList]:
"""Get a list of all secrets and which clients have them.""" """Get a list of all secrets and which clients have them."""
client_secret_map: defaultdict[str, list[str]] = defaultdict(list) client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
for client_secret in session.scalars(select(ClientSecret)).all(): client_secrets = await session.scalars(
select(ClientSecret).options(selectinload(ClientSecret.client))
)
for client_secret in client_secrets.all():
if not client_secret.client: if not client_secret.client:
if client_secret.name not in client_secret_map: if client_secret.name not in client_secret_map:
client_secret_map[client_secret.name] = [] client_secret_map[client_secret.name] = []
continue continue
client_secret_map[client_secret.name].append(client_secret.client.name) client_secret_map[client_secret.name].append(client_secret.client.name)
#audit.audit_client_secret_list(session, request) # audit.audit_client_secret_list(session, request)
return [ return [
ClientSecretList(name=secret_name, clients=clients) ClientSecretList(name=secret_name, clients=clients)
for secret_name, clients in client_secret_map.items() for secret_name, clients in client_secret_map.items()
] ]
@router.get("/secrets/detailed/") @router.get("/secrets/detailed/")
async def get_detailed_secret_map( async def get_detailed_secret_map(
request: Request, session: Annotated[Session, Depends(get_db_session)] session: Annotated[AsyncSession, Depends(get_db_session)],
) -> list[ClientSecretDetailList]: ) -> list[ClientSecretDetailList]:
"""Get a list of all secrets and which clients have them.""" """Get a list of all secrets and which clients have them."""
client_secrets: dict[str, ClientSecretDetailList] = {} client_secrets: dict[str, ClientSecretDetailList] = {}
for client_secret in session.scalars(select(ClientSecret)).all(): all_client_secrets = await session.execute(
select(ClientSecret).options(selectinload(ClientSecret.client))
)
for client_secret in all_client_secrets.scalars().all():
if client_secret.name not in client_secrets: if client_secret.name not in client_secrets:
client_secrets[client_secret.name] = ClientSecretDetailList(name=client_secret.name) client_secrets[client_secret.name] = ClientSecretDetailList(
name=client_secret.name
)
client_secrets[client_secret.name].ids.append(str(client_secret.id)) client_secrets[client_secret.name].ids.append(str(client_secret.id))
if not client_secret.client: if not client_secret.client:
continue continue
client_secrets[client_secret.name].clients.append(ClientReference(id=str(client_secret.client.id), name=client_secret.client.name)) client_secrets[client_secret.name].clients.append(
#`audit.audit_client_secret_list(session, request) ClientReference(
id=str(client_secret.client.id), name=client_secret.client.name
)
)
# `audit.audit_client_secret_list(session, request)
return list(client_secrets.values()) return list(client_secrets.values())
@router.get("/secrets/{name}") @router.get("/secrets/{name}")
async def get_secret_clients( async def get_secret_clients(
request: Request,
name: str, name: str,
session: Annotated[Session, Depends(get_db_session)], session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientSecretList: ) -> ClientSecretList:
"""Get a list of which clients has a named secret.""" """Get a list of which clients has a named secret."""
clients: list[str] = [] clients: list[str] = []
for client_secret in session.scalars( client_secrets = await session.scalars(
select(ClientSecret).where(ClientSecret.name == name) select(ClientSecret)
).all(): .options(selectinload(ClientSecret.client))
.where(ClientSecret.name == name)
)
for client_secret in client_secrets.all():
if not client_secret.client: if not client_secret.client:
continue continue
clients.append(client_secret.client.name) clients.append(client_secret.client.name)
@ -214,19 +229,23 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
@router.get("/secrets/{name}/detailed") @router.get("/secrets/{name}/detailed")
async def get_secret_clients_detailed( async def get_secret_clients_detailed(
request: Request,
name: str, name: str,
session: Annotated[Session, Depends(get_db_session)], session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientSecretDetailList: ) -> ClientSecretDetailList:
"""Get a list of which clients has a named secret.""" """Get a list of which clients has a named secret."""
detail_list = ClientSecretDetailList(name=name) detail_list = ClientSecretDetailList(name=name)
for client_secret in session.scalars( client_secrets = await session.scalars(
select(ClientSecret).where(ClientSecret.name == name) select(ClientSecret).where(ClientSecret.name == name)
).all(): )
for client_secret in client_secrets.all():
if not client_secret.client: if not client_secret.client:
continue continue
detail_list.ids.append(str(client_secret.id)) detail_list.ids.append(str(client_secret.id))
detail_list.clients.append(ClientReference(id=str(client_secret.client.id), name=client_secret.client.name)) detail_list.clients.append(
ClientReference(
id=str(client_secret.client.id), name=client_secret.client.name
)
)
return detail_list return detail_list

View File

@ -13,30 +13,32 @@ from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from sqlalchemy import Engine from sqlalchemy import Engine
from sqlalchemy.ext.asyncio import AsyncEngine
from .models import init_db from .models import init_db_async
from .backend_api import get_backend_api from .backend_api import get_backend_api
from .db import setup_database from .db import setup_database, get_async_engine
from .settings import BackendSettings from .settings import BackendSettings
from .types import DBSessionDep from .types import AsyncDBSessionDep
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
def init_backend_app(engine: Engine, get_db_session: DBSessionDep) -> FastAPI: def init_backend_app(settings: BackendSettings) -> FastAPI:
"""Initialize backend app.""" """Initialize backend app."""
@asynccontextmanager @asynccontextmanager
async def lifespan(_app: FastAPI): async def lifespan(_app: FastAPI):
"""Create database before starting the server.""" """Create database before starting the server."""
LOG.debug("Running lifespan") LOG.debug("Running lifespan")
init_db(engine) engine = get_async_engine(settings.async_db_url)
await init_db_async(engine)
yield yield
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
app.include_router(get_backend_api(get_db_session)) app.include_router(get_backend_api(settings))
@app.exception_handler(RequestValidationError) @app.exception_handler(RequestValidationError)
async def validation_exception_handler( async def validation_exception_handler(
@ -60,6 +62,4 @@ def init_backend_app(engine: Engine, get_db_session: DBSessionDep) -> FastAPI:
def create_backend_app(settings: BackendSettings) -> FastAPI: def create_backend_app(settings: BackendSettings) -> FastAPI:
"""Create the backend app.""" """Create the backend app."""
engine, get_db_session = setup_database(settings.db_url) return init_backend_app(settings)
return init_backend_app(engine, get_db_session)

View File

@ -3,7 +3,7 @@
from collections.abc import Sequence from collections.abc import Sequence
from fastapi import Request from fastapi import Request
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.ext.asyncio import AsyncSession
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy, Operation, SubSystem from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy, Operation, SubSystem
@ -17,8 +17,8 @@ def _get_origin(request: Request) -> str | None:
return origin return origin
def _write_audit_log( async def _write_audit_log(
session: Session, request: Request, entry: AuditLog, commit: bool = True session: AsyncSession, request: Request, entry: AuditLog, commit: bool = True
) -> None: ) -> None:
"""Write the audit log.""" """Write the audit log."""
origin = _get_origin(request) origin = _get_origin(request)
@ -26,11 +26,11 @@ def _write_audit_log(
entry.subsystem = SubSystem.BACKEND entry.subsystem = SubSystem.BACKEND
session.add(entry) session.add(entry)
if commit: if commit:
session.commit() await session.commit()
def audit_create_client( async def audit_create_client(
session: Session, request: Request, client: Client, commit: bool = True session: AsyncSession, request: Request, client: Client, commit: bool = True
) -> None: ) -> None:
"""Log the creation of a client.""" """Log the creation of a client."""
entry = AuditLog( entry = AuditLog(
@ -39,11 +39,11 @@ def audit_create_client(
client_name=client.name, client_name=client.name,
message="Client Created", message="Client Created",
) )
_write_audit_log(session, request, entry, commit) await _write_audit_log(session, request, entry, commit)
def audit_delete_client( async def audit_delete_client(
session: Session, request: Request, client: Client, commit: bool = True session: AsyncSession, request: Request, client: Client, commit: bool = True
) -> None: ) -> None:
"""Log the creation of a client.""" """Log the creation of a client."""
entry = AuditLog( entry = AuditLog(
@ -52,11 +52,11 @@ def audit_delete_client(
client_name=client.name, client_name=client.name,
message="Client deleted", message="Client deleted",
) )
_write_audit_log(session, request, entry, commit) await _write_audit_log(session, request, entry, commit)
def audit_create_secret( async def audit_create_secret(
session: Session, session: AsyncSession,
request: Request, request: Request,
client: Client, client: Client,
secret: ClientSecret, secret: ClientSecret,
@ -71,11 +71,11 @@ def audit_create_secret(
client_name=client.name, client_name=client.name,
message="Added secret to client", message="Added secret to client",
) )
_write_audit_log(session, request, entry, commit) await _write_audit_log(session, request, entry, commit)
def audit_remove_policy( async def audit_remove_policy(
session: Session, session: AsyncSession,
request: Request, request: Request,
client: Client, client: Client,
policy: ClientAccessPolicy, policy: ClientAccessPolicy,
@ -90,11 +90,11 @@ def audit_remove_policy(
message="Deleted client policy", message="Deleted client policy",
data=data, data=data,
) )
_write_audit_log(session, request, entry, commit) await _write_audit_log(session, request, entry, commit)
def audit_update_policy( async def audit_update_policy(
session: Session, session: AsyncSession,
request: Request, request: Request,
client: Client, client: Client,
policy: ClientAccessPolicy, policy: ClientAccessPolicy,
@ -109,11 +109,11 @@ def audit_update_policy(
message="Updated client policy", message="Updated client policy",
data=data, data=data,
) )
_write_audit_log(session, request, entry, commit) await _write_audit_log(session, request, entry, commit)
def audit_update_client( async def audit_update_client(
session: Session, session: AsyncSession,
request: Request, request: Request,
client: Client, client: Client,
commit: bool = True, commit: bool = True,
@ -125,11 +125,11 @@ def audit_update_client(
client_name=client.name, client_name=client.name,
message="Client data updated", message="Client data updated",
) )
_write_audit_log(session, request, entry, commit) await _write_audit_log(session, request, entry, commit)
def audit_update_secret( async def audit_update_secret(
session: Session, session: AsyncSession,
request: Request, request: Request,
client: Client, client: Client,
secret: ClientSecret, secret: ClientSecret,
@ -144,11 +144,11 @@ def audit_update_secret(
secret_id=secret.id, secret_id=secret.id,
message="Secret value updated", message="Secret value updated",
) )
_write_audit_log(session, request, entry, commit) await _write_audit_log(session, request, entry, commit)
def audit_invalidate_secrets( async def audit_invalidate_secrets(
session: Session, session: AsyncSession,
request: Request, request: Request,
client: Client, client: Client,
commit: bool = True, commit: bool = True,
@ -160,11 +160,11 @@ def audit_invalidate_secrets(
client_id=client.id, client_id=client.id,
message="Client public-key changed. All secrets invalidated.", message="Client public-key changed. All secrets invalidated.",
) )
_write_audit_log(session, request, entry, commit) await _write_audit_log(session, request, entry, commit)
def audit_delete_secret( async def audit_delete_secret(
session: Session, session: AsyncSession,
request: Request, request: Request,
client: Client, client: Client,
secret: ClientSecret, secret: ClientSecret,
@ -179,11 +179,11 @@ def audit_delete_secret(
client_id=client.id, client_id=client.id,
message="Secret removed from client", message="Secret removed from client",
) )
_write_audit_log(session, request, entry, commit) await _write_audit_log(session, request, entry, commit)
def audit_access_secrets( async def audit_access_secrets(
session: Session, session: AsyncSession,
request: Request, request: Request,
client: Client, client: Client,
secrets: Sequence[ClientSecret] | None = None, secrets: Sequence[ClientSecret] | None = None,
@ -194,19 +194,20 @@ def audit_access_secrets(
With no secrets provided, all secrets of the client will be resolved. With no secrets provided, all secrets of the client will be resolved.
""" """
if not secrets: if not secrets:
secrets = session.scalars( secrets_q = await session.scalars(
select(ClientSecret).where(ClientSecret.client_id == client.id) select(ClientSecret).where(ClientSecret.client_id == client.id)
).all() )
secrets = secrets_q.all()
for secret in secrets: for secret in secrets:
audit_access_secret(session, request, client, secret, False) await audit_access_secret(session, request, client, secret, False)
if commit: if commit:
session.commit() await session.commit()
def audit_access_secret( async def audit_access_secret(
session: Session, session: AsyncSession,
request: Request, request: Request,
client: Client, client: Client,
secret: ClientSecret, secret: ClientSecret,
@ -221,15 +222,15 @@ def audit_access_secret(
client_id=client.id, client_id=client.id,
client_name=client.name, client_name=client.name,
) )
_write_audit_log(session, request, entry, commit) await _write_audit_log(session, request, entry, commit)
def audit_client_secret_list( async def audit_client_secret_list(
session: Session, request: Request, commit: bool = True session: AsyncSession, request: Request, commit: bool = True
) -> None: ) -> None:
"""Audit a list of all secrets.""" """Audit a list of all secrets."""
entry = AuditLog( entry = AuditLog(
operation=Operation.READ, operation=Operation.READ,
message="All secret names and their clients was viewed", message="All secret names and their clients was viewed",
) )
_write_audit_log(session, request, entry, commit) await _write_audit_log(session, request, entry, commit)

View File

@ -1,17 +1,19 @@
"""Backend API.""" """Backend API."""
from collections.abc import AsyncGenerator
import logging import logging
from typing import Annotated from typing import Annotated
from fastapi import APIRouter, Depends, Header, HTTPException from fastapi import APIRouter, Depends, Header, HTTPException
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.ext.asyncio import AsyncSession
from sshecret_backend.db import DatabaseSessionManager
from sshecret_backend.settings import BackendSettings
from .api import get_audit_api, get_clients_api, get_policy_api, get_secrets_api from .api import get_audit_api, get_clients_api, get_policy_api, get_secrets_api
from .auth import verify_token from .auth import verify_token
from .models import ( from .models import (
APIClient, APIClient,
) )
from .types import DBSessionDep
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -19,20 +21,25 @@ API_VERSION = "v1"
def get_backend_api( def get_backend_api(
get_db_session: DBSessionDep, settings: BackendSettings,
) -> APIRouter: ) -> APIRouter:
"""Construct backend API.""" """Construct backend API."""
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
sessionmanager = DatabaseSessionManager(settings.async_db_url)
async with sessionmanager.session() as session:
yield session
async def validate_token( async def validate_token(
x_api_token: Annotated[str, Header()], x_api_token: Annotated[str, Header()],
session: Annotated[Session, Depends(get_db_session)], session: Annotated[AsyncSession, Depends(get_db_session)],
) -> str: ) -> str:
"""Validate token.""" """Validate token."""
LOG.debug("Validating token %s", x_api_token) LOG.debug("Validating token %s", x_api_token)
statement = select(APIClient) statement = select(APIClient)
results = session.scalars(statement) results = await session.scalars(statement)
valid = False valid = False
for result in results: for result in results.all():
if verify_token(x_api_token, result.token): if verify_token(x_api_token, result.token):
valid = True valid = True
LOG.debug("Token is valid") LOG.debug("Token is valid")

View File

@ -4,10 +4,11 @@ import logging
import secrets import secrets
import sqlite3 import sqlite3
from collections.abc import Generator, Callable from collections.abc import AsyncIterator, Generator, Callable
from typing import Literal from contextlib import asynccontextmanager
from typing import Any, Literal
from sqlalchemy import create_engine, Engine, event, select from sqlalchemy import create_engine, Engine, event, select
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession, async_sessionmaker, create_async_engine, AsyncEngine
from sqlalchemy.orm import sessionmaker, Session from sqlalchemy.orm import sessionmaker, Session
@ -20,6 +21,47 @@ from .models import APIClient, SubSystem
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
class DatabaseSessionManager:
def __init__(self, host: URL, **engine_kwargs: str):
self._engine: AsyncEngine | None = get_async_engine(host)
self._sessionmaker: async_sessionmaker[AsyncSession] | None = async_sessionmaker(autocommit=False, bind=self._engine, expire_on_commit=False)
async def close(self):
if self._engine is None:
raise Exception("DatabaseSessionManager is not initialized")
await self._engine.dispose()
self._engine = None
self._sessionmaker = None
@asynccontextmanager
async def connect(self) -> AsyncIterator[AsyncConnection]:
if self._engine is None:
raise Exception("DatabaseSessionManager is not initialized")
async with self._engine.begin() as connection:
try:
yield connection
except Exception:
await connection.rollback()
raise
@asynccontextmanager
async def session(self) -> AsyncIterator[AsyncSession]:
if self._sessionmaker is None:
raise Exception("DatabaseSessionManager is not initialized")
session = self._sessionmaker()
try:
yield session
except Exception:
await session.rollback()
raise
finally:
await session.close()
def setup_database( def setup_database(
db_url: URL, db_url: URL,
) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]: ) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]:
@ -39,9 +81,10 @@ def setup_database(
return engine, get_db_session return engine, get_db_session
def get_engine(url: URL, echo: bool = False) -> Engine: def get_engine(url: URL, echo: bool = False) -> Engine:
"""Initialize the engine.""" """Initialize the engine."""
engine = create_engine(url, echo=echo, future=True) engine = create_engine(url, echo=echo)
if url.drivername.startswith("sqlite"): if url.drivername.startswith("sqlite"):
@event.listens_for(engine, "connect") @event.listens_for(engine, "connect")
@ -55,12 +98,11 @@ def get_engine(url: URL, echo: bool = False) -> Engine:
return engine return engine
def get_async_engine(url: URL, echo: bool = False) -> AsyncEngine: def get_async_engine(url: URL, echo: bool = False, **engine_kwargs: str) -> AsyncEngine:
"""Get an async engine.""" """Get an async engine."""
engine = create_async_engine(url, echo=echo, future=True) engine = create_async_engine(url, echo=echo, **engine_kwargs)
if url.drivername.startswith("sqlite+"): if url.drivername.startswith("sqlite+"):
@event.listens_for(engine.sync_engine, "connect")
@event.listens_for(engine, "connect")
def set_sqlite_pragma( def set_sqlite_pragma(
dbapi_connection: sqlite3.Connection, _connection_record: object dbapi_connection: sqlite3.Connection, _connection_record: object
) -> None: ) -> None:

View File

@ -13,6 +13,7 @@ import uuid
from datetime import datetime from datetime import datetime
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
@ -186,3 +187,9 @@ class AuditLog(Base):
def init_db(engine: sa.Engine) -> None: def init_db(engine: sa.Engine) -> None:
"""Initialize database.""" """Initialize database."""
Base.metadata.create_all(engine) Base.metadata.create_all(engine)
async def init_db_async(engine: AsyncEngine) -> None:
"""Initialize database."""
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)

View File

@ -1,8 +1,11 @@
"""Common type definitions.""" """Common type definitions."""
from collections.abc import Callable, Generator from collections.abc import AsyncGenerator, Callable, Generator
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
DBSessionDep = Callable[[], Generator[Session, None, None]] DBSessionDep = Callable[[], Generator[Session, None, None]]
AsyncDBSessionDep = Callable[[], AsyncGenerator[AsyncSession, None]]