From fc0c3fb9501ae63a2cf94174bf5f4a76e69c940b Mon Sep 17 00:00:00 2001 From: Allan Eising Date: Mon, 19 May 2025 09:15:48 +0200 Subject: [PATCH] Refactor to use async database model --- .../src/sshecret_backend/api/audit.py | 32 +++--- .../src/sshecret_backend/api/clients.py | 69 ++++++------ .../src/sshecret_backend/api/common.py | 44 ++++++-- .../src/sshecret_backend/api/policies.py | 31 +++--- .../src/sshecret_backend/api/secrets.py | 103 +++++++++++------- .../src/sshecret_backend/app.py | 18 +-- .../src/sshecret_backend/audit.py | 87 +++++++-------- .../src/sshecret_backend/backend_api.py | 19 +++- .../src/sshecret_backend/db.py | 58 ++++++++-- .../src/sshecret_backend/models.py | 7 ++ .../src/sshecret_backend/types.py | 5 +- 11 files changed, 288 insertions(+), 185 deletions(-) diff --git a/packages/sshecret-backend/src/sshecret_backend/api/audit.py b/packages/sshecret-backend/src/sshecret_backend/api/audit.py index 6356adb..14706db 100644 --- a/packages/sshecret-backend/src/sshecret_backend/api/audit.py +++ b/packages/sshecret-backend/src/sshecret_backend/api/audit.py @@ -3,18 +3,18 @@ # pyright: reportUnusedFunction=false import logging -from collections.abc import Sequence -from typing import Any, cast -from fastapi import APIRouter, Depends, Request, Query +from typing import Any +from fastapi import APIRouter, Depends, Request from pydantic import BaseModel, Field, TypeAdapter from sqlalchemy import select, func, and_ -from sqlalchemy.orm import InstrumentedAttribute, Session +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import InstrumentedAttribute from sqlalchemy.sql.expression import ColumnExpressionArgument from typing import Annotated from sshecret_backend.models import AuditLog, Operation, SubSystem -from sshecret_backend.types import DBSessionDep +from sshecret_backend.types import AsyncDBSessionDep from sshecret_backend.view_models import AuditInfo, AuditView, AuditListResult @@ -58,24 +58,23 @@ class AuditFilter(BaseModel): ] -def get_audit_api(get_db_session: DBSessionDep) -> APIRouter: +def get_audit_api(get_db_session: AsyncDBSessionDep) -> APIRouter: """Construct audit sub-api.""" router = APIRouter() @router.get("/audit/", response_model=AuditListResult) async def get_audit_logs( - request: Request, - session: Annotated[Session, Depends(get_db_session)], + session: Annotated[AsyncSession, Depends(get_db_session)], filters: Annotated[AuditFilter, Depends()], ) -> AuditListResult: """Get audit logs.""" # audit.audit_access_audit_log(session, request) - total = session.scalars( + total = (await session.scalars( select(func.count("*")) .select_from(AuditLog) .where(and_(True, *filters.filter_mapping)) - ).one() + )).one() remaining = total - filters.offset statement = ( @@ -87,7 +86,7 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter: ) LogAdapt = TypeAdapter(list[AuditView]) - results = session.scalars(statement).all() + results = (await session.scalars(statement)).all() entries = LogAdapt.validate_python(results, from_attributes=True) return AuditListResult( results=entries, @@ -97,24 +96,23 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter: @router.post("/audit/") async def add_audit_log( - request: Request, - session: Annotated[Session, Depends(get_db_session)], + session: Annotated[AsyncSession, Depends(get_db_session)], entry: AuditView, ) -> AuditView: """Add entry to audit log.""" audit_log = AuditLog(**entry.model_dump(exclude_none=True)) session.add(audit_log) - session.commit() + await session.commit() return AuditView.model_validate(audit_log, from_attributes=True) @router.get("/audit/info") async def get_audit_info( - request: Request, session: Annotated[Session, Depends(get_db_session)] + session: Annotated[AsyncSession, Depends(get_db_session)] ) -> AuditInfo: """Get audit info.""" - audit_count = session.scalars( + audit_count = (await session.scalars( select(func.count("*")).select_from(AuditLog) - ).one() + )).one() return AuditInfo(entries=audit_count) return router diff --git a/packages/sshecret-backend/src/sshecret_backend/api/clients.py b/packages/sshecret-backend/src/sshecret_backend/api/clients.py index 536e176..872cfbc 100644 --- a/packages/sshecret-backend/src/sshecret_backend/api/clients.py +++ b/packages/sshecret-backend/src/sshecret_backend/api/clients.py @@ -4,14 +4,14 @@ import uuid import logging -from fastapi import APIRouter, Depends, HTTPException, Query, Request +from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from pydantic import BaseModel, Field, model_validator from typing import Annotated, Any, Self, TypeVar, cast from sqlalchemy import select, func -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.sql import Select -from sshecret_backend.types import DBSessionDep +from sshecret_backend.types import AsyncDBSessionDep from sshecret_backend.models import Client, ClientSecret from sshecret_backend.view_models import ( ClientCreate, @@ -20,7 +20,7 @@ from sshecret_backend.view_models import ( ClientUpdate, ) from sshecret_backend import audit -from .common import get_client_by_id_or_name +from .common import get_client_by_id_or_name, client_with_relationships class ClientListParams(BaseModel): @@ -74,30 +74,30 @@ def filter_client_statement( return statement.limit(params.limit).offset(params.offset) -def get_clients_api(get_db_session: DBSessionDep) -> APIRouter: +def get_clients_api(get_db_session: AsyncDBSessionDep) -> APIRouter: """Construct clients sub-api.""" router = APIRouter() @router.get("/clients/") async def get_clients( filter_query: Annotated[ClientListParams, Query()], - session: Annotated[Session, Depends(get_db_session)], + session: Annotated[AsyncSession, Depends(get_db_session)], ) -> ClientQueryResult: """Get clients.""" # Get total results first count_statement = select(func.count("*")).select_from(Client) count_statement = cast(Select[tuple[int]], filter_client_statement(count_statement, filter_query, True)) - total_results = session.scalars(count_statement).one() + total_results = (await session.scalars(count_statement)).one() - statement = filter_client_statement(select(Client), filter_query, False) + statement = filter_client_statement(client_with_relationships(), filter_query, False) - results = session.scalars(statement) + results = await session.scalars(statement) remainder = total_results - filter_query.offset - filter_query.limit if remainder < 0: remainder = 0 - clients = list(results) + clients = list(results.all()) clients_view = ClientView.from_client_list(clients) return ClientQueryResult( clients=clients_view, @@ -108,7 +108,7 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter: @router.get("/clients/{name}") async def get_client( name: str, - session: Annotated[Session, Depends(get_db_session)], + session: Annotated[AsyncSession, Depends(get_db_session)], ) -> ClientView: """Fetch a client.""" client = await get_client_by_id_or_name(session, name) @@ -122,7 +122,7 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter: async def delete_client( request: Request, name: str, - session: Annotated[Session, Depends(get_db_session)], + session: Annotated[AsyncSession, Depends(get_db_session)], ) -> None: """Delete a client.""" client = await get_client_by_id_or_name(session, name) @@ -131,15 +131,15 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter: status_code=404, detail="Cannot find a client with the given name." ) - session.delete(client) - session.commit() - audit.audit_delete_client(session, request, client) + await session.delete(client) + await session.commit() + await audit.audit_delete_client(session, request, client) @router.post("/clients/") async def create_client( request: Request, client: ClientCreate, - session: Annotated[Session, Depends(get_db_session)], + session: Annotated[AsyncSession, Depends(get_db_session)], ) -> ClientView: """Create client.""" existing = await get_client_by_id_or_name(session, client.name) @@ -148,9 +148,12 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter: db_client = client.to_client() session.add(db_client) - session.commit() - session.refresh(db_client) - audit.audit_create_client(session, request, db_client) + await session.commit() + await session.refresh(db_client) + db_client = await get_client_by_id_or_name(session, client.name) + if not db_client: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Could not create the client.") + await audit.audit_create_client(session, request, db_client) return ClientView.from_client(db_client) @router.post("/clients/{name}/public-key") @@ -158,7 +161,7 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter: request: Request, name: str, client_update: ClientUpdate, - session: Annotated[Session, Depends(get_db_session)], + session: Annotated[AsyncSession, Depends(get_db_session)], ) -> ClientView: """Change the public key of a client. @@ -170,17 +173,16 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter: status_code=404, detail="Cannot find a client with the given name." ) client.public_key = client_update.public_key - for secret in session.scalars( - select(ClientSecret).where(ClientSecret.client_id == client.id) - ).all(): + matching_secrets = await session.scalars(select(ClientSecret).where(ClientSecret.client_id == client.id)) + for secret in matching_secrets.all(): LOG.debug("Invalidated secret %s", secret.id) secret.invalidated = True secret.client_id = None session.add(client) - session.refresh(client) - session.commit() - audit.audit_invalidate_secrets(session, request, client) + await session.refresh(client) + await session.commit() + await audit.audit_invalidate_secrets(session, request, client) return ClientView.from_client(client) @@ -189,7 +191,7 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter: request: Request, name: str, client_update: ClientCreate, - session: Annotated[Session, Depends(get_db_session)], + session: Annotated[AsyncSession, Depends(get_db_session)], ) -> ClientView: """Change the public key of a client. @@ -205,19 +207,20 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter: public_key_updated = False if client_update.public_key != client.public_key: public_key_updated = True - for secret in session.scalars( + client_secrets = await session.scalars( select(ClientSecret).where(ClientSecret.client_id == client.id) - ).all(): + ) + for secret in client_secrets.all(): LOG.debug("Invalidated secret %s", secret.id) secret.invalidated = True secret.client_id = None session.add(client) - session.commit() - session.refresh(client) - audit.audit_update_client(session, request, client) + await session.commit() + await session.refresh(client) + await audit.audit_update_client(session, request, client) if public_key_updated: - audit.audit_invalidate_secrets(session, request, client) + await audit.audit_invalidate_secrets(session, request, client) return ClientView.from_client(client) diff --git a/packages/sshecret-backend/src/sshecret_backend/api/common.py b/packages/sshecret-backend/src/sshecret_backend/api/common.py index fad0b57..0693907 100644 --- a/packages/sshecret-backend/src/sshecret_backend/api/common.py +++ b/packages/sshecret-backend/src/sshecret_backend/api/common.py @@ -3,9 +3,11 @@ import re import uuid import bcrypt +from sqlalchemy import Select +from sqlalchemy.orm import selectinload -from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.future import select +from sqlalchemy.ext.asyncio import AsyncSession from sshecret_backend.models import Client @@ -17,20 +19,38 @@ def verify_token(token: str, stored_hash: str) -> bool: stored_bytes = stored_hash.encode("utf-8") return bcrypt.checkpw(token_bytes, stored_bytes) +async def reload_client_with_relationships(session: AsyncSession, client: Client) -> Client: + """Reload a client from the database.""" + session.expunge(client) + stmt = ( + select(Client) + .options(selectinload(Client.policies), selectinload(Client.secrets)) + .where(Client.id == client.id) + ) + result = await session.execute(stmt) + return result.scalar_one() -async def get_client_by_name(session: Session, name: str) -> Client | None: + +def client_with_relationships() -> Select[tuple[Client]]: + """Base select statement for client with relationships.""" + return select(Client).options( + selectinload(Client.secrets), + selectinload(Client.policies), + ) + +async def get_client_by_name(session: AsyncSession, name: str) -> Client | None: """Get client by name.""" - client_filter = select(Client).where(Client.name == name) - client_results = session.scalars(client_filter) - return client_results.first() + client_filter = client_with_relationships().where(Client.name == name) + client_results = await session.execute(client_filter) + return client_results.scalars().first() -async def get_client_by_id(session: Session, id: uuid.UUID) -> Client | None: - """Get client by name.""" - client_filter = select(Client).where(Client.id == id) - client_results = session.scalars(client_filter) - return client_results.first() +async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | None: + """Get client by ID.""" + client_filter = client_with_relationships().where(Client.id == id) + client_results = await session.execute(client_filter) + return client_results.scalars().first() -async def get_client_by_id_or_name(session: Session, id_or_name: str) -> Client | None: +async def get_client_by_id_or_name(session: AsyncSession, id_or_name: str) -> Client | None: """Get client either by id or name.""" if RE_UUID.match(id_or_name): id = uuid.UUID(id_or_name) diff --git a/packages/sshecret-backend/src/sshecret_backend/api/policies.py b/packages/sshecret-backend/src/sshecret_backend/api/policies.py index 624d2ec..48e677e 100644 --- a/packages/sshecret-backend/src/sshecret_backend/api/policies.py +++ b/packages/sshecret-backend/src/sshecret_backend/api/policies.py @@ -5,7 +5,7 @@ import logging from fastapi import APIRouter, Depends, HTTPException, Request from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession from typing import Annotated from sshecret_backend.models import ClientAccessPolicy @@ -13,21 +13,21 @@ from sshecret_backend.view_models import ( ClientPolicyView, ClientPolicyUpdate, ) -from sshecret_backend.types import DBSessionDep +from sshecret_backend.types import AsyncDBSessionDep from sshecret_backend import audit -from .common import get_client_by_id_or_name +from .common import get_client_by_id_or_name, reload_client_with_relationships LOG = logging.getLogger(__name__) -def get_policy_api(get_db_session: DBSessionDep) -> APIRouter: +def get_policy_api(get_db_session: AsyncDBSessionDep) -> APIRouter: """Construct clients sub-api.""" router = APIRouter() @router.get("/clients/{name}/policies/") async def get_client_policies( - name: str, session: Annotated[Session, Depends(get_db_session)] + name: str, session: Annotated[AsyncSession, Depends(get_db_session)] ) -> ClientPolicyView: """Get client policies.""" client = await get_client_by_id_or_name(session, name) @@ -43,7 +43,7 @@ def get_policy_api(get_db_session: DBSessionDep) -> APIRouter: request: Request, name: str, policy_update: ClientPolicyUpdate, - session: Annotated[Session, Depends(get_db_session)], + session: Annotated[AsyncSession, Depends(get_db_session)], ) -> ClientPolicyView: """Update client policies. @@ -55,28 +55,31 @@ def get_policy_api(get_db_session: DBSessionDep) -> APIRouter: status_code=404, detail="Cannot find a client with the given name." ) # Remove old policies. - policies = session.scalars( + policies = await session.scalars( select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id) - ).all() + ) deleted_policies: list[ClientAccessPolicy] = [] added_policies: list[ClientAccessPolicy] = [] - for policy in policies: - session.delete(policy) + for policy in policies.all(): + await session.delete(policy) deleted_policies.append(policy) + LOG.debug("Updating client policies with: %r", policy_update.sources) for source in policy_update.sources: LOG.debug("Source %r", source) policy = ClientAccessPolicy(source=str(source), client_id=client.id) session.add(policy) added_policies.append(policy) - session.commit() - session.refresh(client) + await session.flush() + await session.commit() + + client = await reload_client_with_relationships(session, client) for policy in deleted_policies: - audit.audit_remove_policy(session, request, client, policy) + await audit.audit_remove_policy(session, request, client, policy) for policy in added_policies: - audit.audit_update_policy(session, request, client, policy) + await audit.audit_update_policy(session, request, client, policy) return ClientPolicyView.from_client(client) diff --git a/packages/sshecret-backend/src/sshecret_backend/api/secrets.py b/packages/sshecret-backend/src/sshecret_backend/api/secrets.py index 096d594..5fca3c2 100644 --- a/packages/sshecret-backend/src/sshecret_backend/api/secrets.py +++ b/packages/sshecret-backend/src/sshecret_backend/api/secrets.py @@ -6,9 +6,11 @@ import logging from collections import defaultdict from fastapi import APIRouter, Depends, HTTPException, Request from sqlalchemy import select -from sqlalchemy.orm import Session from typing import Annotated +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + from sshecret_backend.models import Client, ClientSecret from sshecret_backend.view_models import ( ClientReference, @@ -19,7 +21,7 @@ from sshecret_backend.view_models import ( ClientSecretResponse, ) from sshecret_backend import audit -from sshecret_backend.types import DBSessionDep +from sshecret_backend.types import AsyncDBSessionDep from .common import get_client_by_id_or_name @@ -27,7 +29,7 @@ LOG = logging.getLogger(__name__) async def lookup_client_secret( - session: Session, client: Client, name: str + session: AsyncSession, client: Client, name: str ) -> ClientSecret | None: """Look up a secret for a client.""" statement = ( @@ -35,11 +37,11 @@ async def lookup_client_secret( .where(ClientSecret.client_id == client.id) .where(ClientSecret.name == name) ) - results = session.scalars(statement) + results = await session.scalars(statement) return results.first() -def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter: +def get_secrets_api(get_db_session: AsyncDBSessionDep) -> APIRouter: """Construct clients sub-api.""" router = APIRouter() @@ -48,7 +50,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter: request: Request, name: str, client_secret: ClientSecretPublic, - session: Annotated[Session, Depends(get_db_session)], + session: Annotated[AsyncSession, Depends(get_db_session)], ) -> None: """Add secret to a client.""" client = await get_client_by_id_or_name(session, name) @@ -69,9 +71,9 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter: name=client_secret.name, client_id=client.id, secret=client_secret.secret ) session.add(db_secret) - session.commit() - session.refresh(db_secret) - audit.audit_create_secret(session, request, client, db_secret) + await session.commit() + await session.refresh(db_secret) + await audit.audit_create_secret(session, request, client, db_secret) @router.put("/clients/{name}/secrets/{secret_name}") async def update_client_secret( @@ -79,7 +81,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter: name: str, secret_name: str, secret_data: BodyValue, - session: Annotated[Session, Depends(get_db_session)], + session: Annotated[AsyncSession, Depends(get_db_session)], ) -> ClientSecretResponse: """Update a client secret. @@ -96,9 +98,9 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter: existing_secret.secret = secret_data.value session.add(existing_secret) - session.commit() - session.refresh(existing_secret) - audit.audit_update_secret(session, request, client, existing_secret) + await session.commit() + await session.refresh(existing_secret) + await audit.audit_update_secret(session, request, client, existing_secret) return ClientSecretResponse.from_client_secret(existing_secret) db_secret = ClientSecret( @@ -107,9 +109,9 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter: secret=secret_data.value, ) session.add(db_secret) - session.commit() - session.refresh(db_secret) - audit.audit_create_secret(session, request, client, db_secret) + await session.commit() + await session.refresh(db_secret) + await audit.audit_create_secret(session, request, client, db_secret) return ClientSecretResponse.from_client_secret(db_secret) @router.get("/clients/{name}/secrets/{secret_name}") @@ -117,7 +119,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter: request: Request, name: str, secret_name: str, - session: Annotated[Session, Depends(get_db_session)], + session: Annotated[AsyncSession, Depends(get_db_session)], ) -> ClientSecretResponse: """Get a client secret.""" client = await get_client_by_id_or_name(session, name) @@ -133,7 +135,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter: ) response_model = ClientSecretResponse.from_client_secret(secret) - audit.audit_access_secret(session, request, client, secret) + await audit.audit_access_secret(session, request, client, secret) return response_model @router.delete("/clients/{name}/secrets/{secret_name}") @@ -141,7 +143,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter: request: Request, name: str, secret_name: str, - session: Annotated[Session, Depends(get_db_session)], + session: Annotated[AsyncSession, Depends(get_db_session)], ) -> None: """Delete a secret.""" client = await get_client_by_id_or_name(session, name) @@ -156,56 +158,69 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter: status_code=404, detail="Cannot find a secret with the given name." ) - session.delete(secret) - session.commit() - audit.audit_delete_secret(session, request, client, secret) + await session.delete(secret) + await session.commit() + await audit.audit_delete_secret(session, request, client, secret) @router.get("/secrets/") async def get_secret_map( - request: Request, session: Annotated[Session, Depends(get_db_session)] + session: Annotated[AsyncSession, Depends(get_db_session)], ) -> list[ClientSecretList]: """Get a list of all secrets and which clients have them.""" client_secret_map: defaultdict[str, list[str]] = defaultdict(list) - for client_secret in session.scalars(select(ClientSecret)).all(): + client_secrets = await session.scalars( + select(ClientSecret).options(selectinload(ClientSecret.client)) + ) + for client_secret in client_secrets.all(): if not client_secret.client: if client_secret.name not in client_secret_map: client_secret_map[client_secret.name] = [] continue client_secret_map[client_secret.name].append(client_secret.client.name) - #audit.audit_client_secret_list(session, request) + # audit.audit_client_secret_list(session, request) return [ ClientSecretList(name=secret_name, clients=clients) for secret_name, clients in client_secret_map.items() ] + @router.get("/secrets/detailed/") async def get_detailed_secret_map( - request: Request, session: Annotated[Session, Depends(get_db_session)] + session: Annotated[AsyncSession, Depends(get_db_session)], ) -> list[ClientSecretDetailList]: """Get a list of all secrets and which clients have them.""" client_secrets: dict[str, ClientSecretDetailList] = {} - for client_secret in session.scalars(select(ClientSecret)).all(): - + all_client_secrets = await session.execute( + select(ClientSecret).options(selectinload(ClientSecret.client)) + ) + for client_secret in all_client_secrets.scalars().all(): if client_secret.name not in client_secrets: - client_secrets[client_secret.name] = ClientSecretDetailList(name=client_secret.name) + client_secrets[client_secret.name] = ClientSecretDetailList( + name=client_secret.name + ) client_secrets[client_secret.name].ids.append(str(client_secret.id)) if not client_secret.client: continue - client_secrets[client_secret.name].clients.append(ClientReference(id=str(client_secret.client.id), name=client_secret.client.name)) - #`audit.audit_client_secret_list(session, request) + client_secrets[client_secret.name].clients.append( + ClientReference( + id=str(client_secret.client.id), name=client_secret.client.name + ) + ) + # `audit.audit_client_secret_list(session, request) return list(client_secrets.values()) - @router.get("/secrets/{name}") async def get_secret_clients( - request: Request, name: str, - session: Annotated[Session, Depends(get_db_session)], + session: Annotated[AsyncSession, Depends(get_db_session)], ) -> ClientSecretList: """Get a list of which clients has a named secret.""" clients: list[str] = [] - for client_secret in session.scalars( - select(ClientSecret).where(ClientSecret.name == name) - ).all(): + client_secrets = await session.scalars( + select(ClientSecret) + .options(selectinload(ClientSecret.client)) + .where(ClientSecret.name == name) + ) + for client_secret in client_secrets.all(): if not client_secret.client: continue clients.append(client_secret.client.name) @@ -214,19 +229,23 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter: @router.get("/secrets/{name}/detailed") async def get_secret_clients_detailed( - request: Request, name: str, - session: Annotated[Session, Depends(get_db_session)], + session: Annotated[AsyncSession, Depends(get_db_session)], ) -> ClientSecretDetailList: """Get a list of which clients has a named secret.""" detail_list = ClientSecretDetailList(name=name) - for client_secret in session.scalars( + client_secrets = await session.scalars( select(ClientSecret).where(ClientSecret.name == name) - ).all(): + ) + for client_secret in client_secrets.all(): if not client_secret.client: continue detail_list.ids.append(str(client_secret.id)) - detail_list.clients.append(ClientReference(id=str(client_secret.client.id), name=client_secret.client.name)) + detail_list.clients.append( + ClientReference( + id=str(client_secret.client.id), name=client_secret.client.name + ) + ) return detail_list diff --git a/packages/sshecret-backend/src/sshecret_backend/app.py b/packages/sshecret-backend/src/sshecret_backend/app.py index d8627e1..cf0da7a 100644 --- a/packages/sshecret-backend/src/sshecret_backend/app.py +++ b/packages/sshecret-backend/src/sshecret_backend/app.py @@ -13,30 +13,32 @@ from fastapi.encoders import jsonable_encoder from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse from sqlalchemy import Engine +from sqlalchemy.ext.asyncio import AsyncEngine -from .models import init_db +from .models import init_db_async from .backend_api import get_backend_api -from .db import setup_database +from .db import setup_database, get_async_engine from .settings import BackendSettings -from .types import DBSessionDep +from .types import AsyncDBSessionDep LOG = logging.getLogger(__name__) -def init_backend_app(engine: Engine, get_db_session: DBSessionDep) -> FastAPI: +def init_backend_app(settings: BackendSettings) -> FastAPI: """Initialize backend app.""" @asynccontextmanager async def lifespan(_app: FastAPI): """Create database before starting the server.""" LOG.debug("Running lifespan") - init_db(engine) + engine = get_async_engine(settings.async_db_url) + await init_db_async(engine) yield app = FastAPI(lifespan=lifespan) - app.include_router(get_backend_api(get_db_session)) + app.include_router(get_backend_api(settings)) @app.exception_handler(RequestValidationError) async def validation_exception_handler( @@ -60,6 +62,4 @@ def init_backend_app(engine: Engine, get_db_session: DBSessionDep) -> FastAPI: def create_backend_app(settings: BackendSettings) -> FastAPI: """Create the backend app.""" - engine, get_db_session = setup_database(settings.db_url) - - return init_backend_app(engine, get_db_session) + return init_backend_app(settings) diff --git a/packages/sshecret-backend/src/sshecret_backend/audit.py b/packages/sshecret-backend/src/sshecret_backend/audit.py index 6593143..89c2704 100644 --- a/packages/sshecret-backend/src/sshecret_backend/audit.py +++ b/packages/sshecret-backend/src/sshecret_backend/audit.py @@ -3,7 +3,7 @@ from collections.abc import Sequence from fastapi import Request from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy, Operation, SubSystem @@ -17,8 +17,8 @@ def _get_origin(request: Request) -> str | None: return origin -def _write_audit_log( - session: Session, request: Request, entry: AuditLog, commit: bool = True +async def _write_audit_log( + session: AsyncSession, request: Request, entry: AuditLog, commit: bool = True ) -> None: """Write the audit log.""" origin = _get_origin(request) @@ -26,11 +26,11 @@ def _write_audit_log( entry.subsystem = SubSystem.BACKEND session.add(entry) if commit: - session.commit() + await session.commit() -def audit_create_client( - session: Session, request: Request, client: Client, commit: bool = True +async def audit_create_client( + session: AsyncSession, request: Request, client: Client, commit: bool = True ) -> None: """Log the creation of a client.""" entry = AuditLog( @@ -39,11 +39,11 @@ def audit_create_client( client_name=client.name, message="Client Created", ) - _write_audit_log(session, request, entry, commit) + await _write_audit_log(session, request, entry, commit) -def audit_delete_client( - session: Session, request: Request, client: Client, commit: bool = True +async def audit_delete_client( + session: AsyncSession, request: Request, client: Client, commit: bool = True ) -> None: """Log the creation of a client.""" entry = AuditLog( @@ -52,11 +52,11 @@ def audit_delete_client( client_name=client.name, message="Client deleted", ) - _write_audit_log(session, request, entry, commit) + await _write_audit_log(session, request, entry, commit) -def audit_create_secret( - session: Session, +async def audit_create_secret( + session: AsyncSession, request: Request, client: Client, secret: ClientSecret, @@ -71,11 +71,11 @@ def audit_create_secret( client_name=client.name, message="Added secret to client", ) - _write_audit_log(session, request, entry, commit) + await _write_audit_log(session, request, entry, commit) -def audit_remove_policy( - session: Session, +async def audit_remove_policy( + session: AsyncSession, request: Request, client: Client, policy: ClientAccessPolicy, @@ -90,11 +90,11 @@ def audit_remove_policy( message="Deleted client policy", data=data, ) - _write_audit_log(session, request, entry, commit) + await _write_audit_log(session, request, entry, commit) -def audit_update_policy( - session: Session, +async def audit_update_policy( + session: AsyncSession, request: Request, client: Client, policy: ClientAccessPolicy, @@ -109,11 +109,11 @@ def audit_update_policy( message="Updated client policy", data=data, ) - _write_audit_log(session, request, entry, commit) + await _write_audit_log(session, request, entry, commit) -def audit_update_client( - session: Session, +async def audit_update_client( + session: AsyncSession, request: Request, client: Client, commit: bool = True, @@ -125,11 +125,11 @@ def audit_update_client( client_name=client.name, message="Client data updated", ) - _write_audit_log(session, request, entry, commit) + await _write_audit_log(session, request, entry, commit) -def audit_update_secret( - session: Session, +async def audit_update_secret( + session: AsyncSession, request: Request, client: Client, secret: ClientSecret, @@ -144,11 +144,11 @@ def audit_update_secret( secret_id=secret.id, message="Secret value updated", ) - _write_audit_log(session, request, entry, commit) + await _write_audit_log(session, request, entry, commit) -def audit_invalidate_secrets( - session: Session, +async def audit_invalidate_secrets( + session: AsyncSession, request: Request, client: Client, commit: bool = True, @@ -160,11 +160,11 @@ def audit_invalidate_secrets( client_id=client.id, message="Client public-key changed. All secrets invalidated.", ) - _write_audit_log(session, request, entry, commit) + await _write_audit_log(session, request, entry, commit) -def audit_delete_secret( - session: Session, +async def audit_delete_secret( + session: AsyncSession, request: Request, client: Client, secret: ClientSecret, @@ -179,11 +179,11 @@ def audit_delete_secret( client_id=client.id, message="Secret removed from client", ) - _write_audit_log(session, request, entry, commit) + await _write_audit_log(session, request, entry, commit) -def audit_access_secrets( - session: Session, +async def audit_access_secrets( + session: AsyncSession, request: Request, client: Client, secrets: Sequence[ClientSecret] | None = None, @@ -194,19 +194,20 @@ def audit_access_secrets( With no secrets provided, all secrets of the client will be resolved. """ if not secrets: - secrets = session.scalars( + secrets_q = await session.scalars( select(ClientSecret).where(ClientSecret.client_id == client.id) - ).all() + ) + secrets = secrets_q.all() for secret in secrets: - audit_access_secret(session, request, client, secret, False) + await audit_access_secret(session, request, client, secret, False) if commit: - session.commit() + await session.commit() -def audit_access_secret( - session: Session, +async def audit_access_secret( + session: AsyncSession, request: Request, client: Client, secret: ClientSecret, @@ -221,15 +222,15 @@ def audit_access_secret( client_id=client.id, client_name=client.name, ) - _write_audit_log(session, request, entry, commit) + await _write_audit_log(session, request, entry, commit) -def audit_client_secret_list( - session: Session, request: Request, commit: bool = True +async def audit_client_secret_list( + session: AsyncSession, request: Request, commit: bool = True ) -> None: """Audit a list of all secrets.""" entry = AuditLog( operation=Operation.READ, message="All secret names and their clients was viewed", ) - _write_audit_log(session, request, entry, commit) + await _write_audit_log(session, request, entry, commit) diff --git a/packages/sshecret-backend/src/sshecret_backend/backend_api.py b/packages/sshecret-backend/src/sshecret_backend/backend_api.py index aca6a52..06441f0 100644 --- a/packages/sshecret-backend/src/sshecret_backend/backend_api.py +++ b/packages/sshecret-backend/src/sshecret_backend/backend_api.py @@ -1,17 +1,19 @@ """Backend API.""" +from collections.abc import AsyncGenerator import logging from typing import Annotated from fastapi import APIRouter, Depends, Header, HTTPException from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession +from sshecret_backend.db import DatabaseSessionManager +from sshecret_backend.settings import BackendSettings from .api import get_audit_api, get_clients_api, get_policy_api, get_secrets_api from .auth import verify_token from .models import ( APIClient, ) -from .types import DBSessionDep LOG = logging.getLogger(__name__) @@ -19,20 +21,25 @@ API_VERSION = "v1" def get_backend_api( - get_db_session: DBSessionDep, + settings: BackendSettings, ) -> APIRouter: """Construct backend API.""" + async def get_db_session() -> AsyncGenerator[AsyncSession, None]: + sessionmanager = DatabaseSessionManager(settings.async_db_url) + async with sessionmanager.session() as session: + yield session + async def validate_token( x_api_token: Annotated[str, Header()], - session: Annotated[Session, Depends(get_db_session)], + session: Annotated[AsyncSession, Depends(get_db_session)], ) -> str: """Validate token.""" LOG.debug("Validating token %s", x_api_token) statement = select(APIClient) - results = session.scalars(statement) + results = await session.scalars(statement) valid = False - for result in results: + for result in results.all(): if verify_token(x_api_token, result.token): valid = True LOG.debug("Token is valid") diff --git a/packages/sshecret-backend/src/sshecret_backend/db.py b/packages/sshecret-backend/src/sshecret_backend/db.py index ff5a0e5..0719503 100644 --- a/packages/sshecret-backend/src/sshecret_backend/db.py +++ b/packages/sshecret-backend/src/sshecret_backend/db.py @@ -4,10 +4,11 @@ import logging import secrets import sqlite3 -from collections.abc import Generator, Callable -from typing import Literal +from collections.abc import AsyncIterator, Generator, Callable +from contextlib import asynccontextmanager +from typing import Any, Literal from sqlalchemy import create_engine, Engine, event, select -from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession, async_sessionmaker, create_async_engine, AsyncEngine from sqlalchemy.orm import sessionmaker, Session @@ -20,6 +21,47 @@ from .models import APIClient, SubSystem LOG = logging.getLogger(__name__) + +class DatabaseSessionManager: + def __init__(self, host: URL, **engine_kwargs: str): + self._engine: AsyncEngine | None = get_async_engine(host) + self._sessionmaker: async_sessionmaker[AsyncSession] | None = async_sessionmaker(autocommit=False, bind=self._engine, expire_on_commit=False) + + async def close(self): + if self._engine is None: + raise Exception("DatabaseSessionManager is not initialized") + await self._engine.dispose() + + self._engine = None + self._sessionmaker = None + + @asynccontextmanager + async def connect(self) -> AsyncIterator[AsyncConnection]: + if self._engine is None: + raise Exception("DatabaseSessionManager is not initialized") + + async with self._engine.begin() as connection: + try: + yield connection + except Exception: + await connection.rollback() + raise + + @asynccontextmanager + async def session(self) -> AsyncIterator[AsyncSession]: + if self._sessionmaker is None: + raise Exception("DatabaseSessionManager is not initialized") + + session = self._sessionmaker() + try: + yield session + except Exception: + await session.rollback() + raise + finally: + await session.close() + + def setup_database( db_url: URL, ) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]: @@ -39,9 +81,10 @@ def setup_database( return engine, get_db_session + def get_engine(url: URL, echo: bool = False) -> Engine: """Initialize the engine.""" - engine = create_engine(url, echo=echo, future=True) + engine = create_engine(url, echo=echo) if url.drivername.startswith("sqlite"): @event.listens_for(engine, "connect") @@ -55,12 +98,11 @@ def get_engine(url: URL, echo: bool = False) -> Engine: return engine -def get_async_engine(url: URL, echo: bool = False) -> AsyncEngine: +def get_async_engine(url: URL, echo: bool = False, **engine_kwargs: str) -> AsyncEngine: """Get an async engine.""" - engine = create_async_engine(url, echo=echo, future=True) + engine = create_async_engine(url, echo=echo, **engine_kwargs) if url.drivername.startswith("sqlite+"): - - @event.listens_for(engine, "connect") + @event.listens_for(engine.sync_engine, "connect") def set_sqlite_pragma( dbapi_connection: sqlite3.Connection, _connection_record: object ) -> None: diff --git a/packages/sshecret-backend/src/sshecret_backend/models.py b/packages/sshecret-backend/src/sshecret_backend/models.py index 6eab429..a2123ab 100644 --- a/packages/sshecret-backend/src/sshecret_backend/models.py +++ b/packages/sshecret-backend/src/sshecret_backend/models.py @@ -13,6 +13,7 @@ import uuid from datetime import datetime import sqlalchemy as sa +from sqlalchemy.ext.asyncio import AsyncEngine from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship @@ -186,3 +187,9 @@ class AuditLog(Base): def init_db(engine: sa.Engine) -> None: """Initialize database.""" Base.metadata.create_all(engine) + + +async def init_db_async(engine: AsyncEngine) -> None: + """Initialize database.""" + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) diff --git a/packages/sshecret-backend/src/sshecret_backend/types.py b/packages/sshecret-backend/src/sshecret_backend/types.py index 39ba47b..d9df5e7 100644 --- a/packages/sshecret-backend/src/sshecret_backend/types.py +++ b/packages/sshecret-backend/src/sshecret_backend/types.py @@ -1,8 +1,11 @@ """Common type definitions.""" -from collections.abc import Callable, Generator +from collections.abc import AsyncGenerator, Callable, Generator +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession from sqlalchemy.orm import Session DBSessionDep = Callable[[], Generator[Session, None, None]] + +AsyncDBSessionDep = Callable[[], AsyncGenerator[AsyncSession, None]]