Refactor to use async database model

This commit is contained in:
2025-05-19 09:15:48 +02:00
parent f10ae027e5
commit fc0c3fb950
11 changed files with 288 additions and 185 deletions

View File

@ -3,18 +3,18 @@
# pyright: reportUnusedFunction=false
import logging
from collections.abc import Sequence
from typing import Any, cast
from fastapi import APIRouter, Depends, Request, Query
from typing import Any
from fastapi import APIRouter, Depends, Request
from pydantic import BaseModel, Field, TypeAdapter
from sqlalchemy import select, func, and_
from sqlalchemy.orm import InstrumentedAttribute, Session
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import InstrumentedAttribute
from sqlalchemy.sql.expression import ColumnExpressionArgument
from typing import Annotated
from sshecret_backend.models import AuditLog, Operation, SubSystem
from sshecret_backend.types import DBSessionDep
from sshecret_backend.types import AsyncDBSessionDep
from sshecret_backend.view_models import AuditInfo, AuditView, AuditListResult
@ -58,24 +58,23 @@ class AuditFilter(BaseModel):
]
def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
def get_audit_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
"""Construct audit sub-api."""
router = APIRouter()
@router.get("/audit/", response_model=AuditListResult)
async def get_audit_logs(
request: Request,
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
filters: Annotated[AuditFilter, Depends()],
) -> AuditListResult:
"""Get audit logs."""
# audit.audit_access_audit_log(session, request)
total = session.scalars(
total = (await session.scalars(
select(func.count("*"))
.select_from(AuditLog)
.where(and_(True, *filters.filter_mapping))
).one()
)).one()
remaining = total - filters.offset
statement = (
@ -87,7 +86,7 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
)
LogAdapt = TypeAdapter(list[AuditView])
results = session.scalars(statement).all()
results = (await session.scalars(statement)).all()
entries = LogAdapt.validate_python(results, from_attributes=True)
return AuditListResult(
results=entries,
@ -97,24 +96,23 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
@router.post("/audit/")
async def add_audit_log(
request: Request,
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
entry: AuditView,
) -> AuditView:
"""Add entry to audit log."""
audit_log = AuditLog(**entry.model_dump(exclude_none=True))
session.add(audit_log)
session.commit()
await session.commit()
return AuditView.model_validate(audit_log, from_attributes=True)
@router.get("/audit/info")
async def get_audit_info(
request: Request, session: Annotated[Session, Depends(get_db_session)]
session: Annotated[AsyncSession, Depends(get_db_session)]
) -> AuditInfo:
"""Get audit info."""
audit_count = session.scalars(
audit_count = (await session.scalars(
select(func.count("*")).select_from(AuditLog)
).one()
)).one()
return AuditInfo(entries=audit_count)
return router

View File

@ -4,14 +4,14 @@
import uuid
import logging
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from pydantic import BaseModel, Field, model_validator
from typing import Annotated, Any, Self, TypeVar, cast
from sqlalchemy import select, func
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import Select
from sshecret_backend.types import DBSessionDep
from sshecret_backend.types import AsyncDBSessionDep
from sshecret_backend.models import Client, ClientSecret
from sshecret_backend.view_models import (
ClientCreate,
@ -20,7 +20,7 @@ from sshecret_backend.view_models import (
ClientUpdate,
)
from sshecret_backend import audit
from .common import get_client_by_id_or_name
from .common import get_client_by_id_or_name, client_with_relationships
class ClientListParams(BaseModel):
@ -74,30 +74,30 @@ def filter_client_statement(
return statement.limit(params.limit).offset(params.offset)
def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
def get_clients_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
"""Construct clients sub-api."""
router = APIRouter()
@router.get("/clients/")
async def get_clients(
filter_query: Annotated[ClientListParams, Query()],
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientQueryResult:
"""Get clients."""
# Get total results first
count_statement = select(func.count("*")).select_from(Client)
count_statement = cast(Select[tuple[int]], filter_client_statement(count_statement, filter_query, True))
total_results = session.scalars(count_statement).one()
total_results = (await session.scalars(count_statement)).one()
statement = filter_client_statement(select(Client), filter_query, False)
statement = filter_client_statement(client_with_relationships(), filter_query, False)
results = session.scalars(statement)
results = await session.scalars(statement)
remainder = total_results - filter_query.offset - filter_query.limit
if remainder < 0:
remainder = 0
clients = list(results)
clients = list(results.all())
clients_view = ClientView.from_client_list(clients)
return ClientQueryResult(
clients=clients_view,
@ -108,7 +108,7 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
@router.get("/clients/{name}")
async def get_client(
name: str,
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView:
"""Fetch a client."""
client = await get_client_by_id_or_name(session, name)
@ -122,7 +122,7 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
async def delete_client(
request: Request,
name: str,
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> None:
"""Delete a client."""
client = await get_client_by_id_or_name(session, name)
@ -131,15 +131,15 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
status_code=404, detail="Cannot find a client with the given name."
)
session.delete(client)
session.commit()
audit.audit_delete_client(session, request, client)
await session.delete(client)
await session.commit()
await audit.audit_delete_client(session, request, client)
@router.post("/clients/")
async def create_client(
request: Request,
client: ClientCreate,
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView:
"""Create client."""
existing = await get_client_by_id_or_name(session, client.name)
@ -148,9 +148,12 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
db_client = client.to_client()
session.add(db_client)
session.commit()
session.refresh(db_client)
audit.audit_create_client(session, request, db_client)
await session.commit()
await session.refresh(db_client)
db_client = await get_client_by_id_or_name(session, client.name)
if not db_client:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Could not create the client.")
await audit.audit_create_client(session, request, db_client)
return ClientView.from_client(db_client)
@router.post("/clients/{name}/public-key")
@ -158,7 +161,7 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
request: Request,
name: str,
client_update: ClientUpdate,
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView:
"""Change the public key of a client.
@ -170,17 +173,16 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
status_code=404, detail="Cannot find a client with the given name."
)
client.public_key = client_update.public_key
for secret in session.scalars(
select(ClientSecret).where(ClientSecret.client_id == client.id)
).all():
matching_secrets = await session.scalars(select(ClientSecret).where(ClientSecret.client_id == client.id))
for secret in matching_secrets.all():
LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True
secret.client_id = None
session.add(client)
session.refresh(client)
session.commit()
audit.audit_invalidate_secrets(session, request, client)
await session.refresh(client)
await session.commit()
await audit.audit_invalidate_secrets(session, request, client)
return ClientView.from_client(client)
@ -189,7 +191,7 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
request: Request,
name: str,
client_update: ClientCreate,
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView:
"""Change the public key of a client.
@ -205,19 +207,20 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
public_key_updated = False
if client_update.public_key != client.public_key:
public_key_updated = True
for secret in session.scalars(
client_secrets = await session.scalars(
select(ClientSecret).where(ClientSecret.client_id == client.id)
).all():
)
for secret in client_secrets.all():
LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True
secret.client_id = None
session.add(client)
session.commit()
session.refresh(client)
audit.audit_update_client(session, request, client)
await session.commit()
await session.refresh(client)
await audit.audit_update_client(session, request, client)
if public_key_updated:
audit.audit_invalidate_secrets(session, request, client)
await audit.audit_invalidate_secrets(session, request, client)
return ClientView.from_client(client)

View File

@ -3,9 +3,11 @@
import re
import uuid
import bcrypt
from sqlalchemy import Select
from sqlalchemy.orm import selectinload
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.future import select
from sqlalchemy.ext.asyncio import AsyncSession
from sshecret_backend.models import Client
@ -17,20 +19,38 @@ def verify_token(token: str, stored_hash: str) -> bool:
stored_bytes = stored_hash.encode("utf-8")
return bcrypt.checkpw(token_bytes, stored_bytes)
async def reload_client_with_relationships(session: AsyncSession, client: Client) -> Client:
"""Reload a client from the database."""
session.expunge(client)
stmt = (
select(Client)
.options(selectinload(Client.policies), selectinload(Client.secrets))
.where(Client.id == client.id)
)
result = await session.execute(stmt)
return result.scalar_one()
async def get_client_by_name(session: Session, name: str) -> Client | None:
def client_with_relationships() -> Select[tuple[Client]]:
"""Base select statement for client with relationships."""
return select(Client).options(
selectinload(Client.secrets),
selectinload(Client.policies),
)
async def get_client_by_name(session: AsyncSession, name: str) -> Client | None:
"""Get client by name."""
client_filter = select(Client).where(Client.name == name)
client_results = session.scalars(client_filter)
return client_results.first()
client_filter = client_with_relationships().where(Client.name == name)
client_results = await session.execute(client_filter)
return client_results.scalars().first()
async def get_client_by_id(session: Session, id: uuid.UUID) -> Client | None:
"""Get client by name."""
client_filter = select(Client).where(Client.id == id)
client_results = session.scalars(client_filter)
return client_results.first()
async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | None:
"""Get client by ID."""
client_filter = client_with_relationships().where(Client.id == id)
client_results = await session.execute(client_filter)
return client_results.scalars().first()
async def get_client_by_id_or_name(session: Session, id_or_name: str) -> Client | None:
async def get_client_by_id_or_name(session: AsyncSession, id_or_name: str) -> Client | None:
"""Get client either by id or name."""
if RE_UUID.match(id_or_name):
id = uuid.UUID(id_or_name)

View File

@ -5,7 +5,7 @@
import logging
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Annotated
from sshecret_backend.models import ClientAccessPolicy
@ -13,21 +13,21 @@ from sshecret_backend.view_models import (
ClientPolicyView,
ClientPolicyUpdate,
)
from sshecret_backend.types import DBSessionDep
from sshecret_backend.types import AsyncDBSessionDep
from sshecret_backend import audit
from .common import get_client_by_id_or_name
from .common import get_client_by_id_or_name, reload_client_with_relationships
LOG = logging.getLogger(__name__)
def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
def get_policy_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
"""Construct clients sub-api."""
router = APIRouter()
@router.get("/clients/{name}/policies/")
async def get_client_policies(
name: str, session: Annotated[Session, Depends(get_db_session)]
name: str, session: Annotated[AsyncSession, Depends(get_db_session)]
) -> ClientPolicyView:
"""Get client policies."""
client = await get_client_by_id_or_name(session, name)
@ -43,7 +43,7 @@ def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
request: Request,
name: str,
policy_update: ClientPolicyUpdate,
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientPolicyView:
"""Update client policies.
@ -55,28 +55,31 @@ def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
status_code=404, detail="Cannot find a client with the given name."
)
# Remove old policies.
policies = session.scalars(
policies = await session.scalars(
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
).all()
)
deleted_policies: list[ClientAccessPolicy] = []
added_policies: list[ClientAccessPolicy] = []
for policy in policies:
session.delete(policy)
for policy in policies.all():
await session.delete(policy)
deleted_policies.append(policy)
LOG.debug("Updating client policies with: %r", policy_update.sources)
for source in policy_update.sources:
LOG.debug("Source %r", source)
policy = ClientAccessPolicy(source=str(source), client_id=client.id)
session.add(policy)
added_policies.append(policy)
session.commit()
session.refresh(client)
await session.flush()
await session.commit()
client = await reload_client_with_relationships(session, client)
for policy in deleted_policies:
audit.audit_remove_policy(session, request, client, policy)
await audit.audit_remove_policy(session, request, client, policy)
for policy in added_policies:
audit.audit_update_policy(session, request, client, policy)
await audit.audit_update_policy(session, request, client, policy)
return ClientPolicyView.from_client(client)

View File

@ -6,9 +6,11 @@ import logging
from collections import defaultdict
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy import select
from sqlalchemy.orm import Session
from typing import Annotated
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sshecret_backend.models import Client, ClientSecret
from sshecret_backend.view_models import (
ClientReference,
@ -19,7 +21,7 @@ from sshecret_backend.view_models import (
ClientSecretResponse,
)
from sshecret_backend import audit
from sshecret_backend.types import DBSessionDep
from sshecret_backend.types import AsyncDBSessionDep
from .common import get_client_by_id_or_name
@ -27,7 +29,7 @@ LOG = logging.getLogger(__name__)
async def lookup_client_secret(
session: Session, client: Client, name: str
session: AsyncSession, client: Client, name: str
) -> ClientSecret | None:
"""Look up a secret for a client."""
statement = (
@ -35,11 +37,11 @@ async def lookup_client_secret(
.where(ClientSecret.client_id == client.id)
.where(ClientSecret.name == name)
)
results = session.scalars(statement)
results = await session.scalars(statement)
return results.first()
def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
def get_secrets_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
"""Construct clients sub-api."""
router = APIRouter()
@ -48,7 +50,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
request: Request,
name: str,
client_secret: ClientSecretPublic,
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> None:
"""Add secret to a client."""
client = await get_client_by_id_or_name(session, name)
@ -69,9 +71,9 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
name=client_secret.name, client_id=client.id, secret=client_secret.secret
)
session.add(db_secret)
session.commit()
session.refresh(db_secret)
audit.audit_create_secret(session, request, client, db_secret)
await session.commit()
await session.refresh(db_secret)
await audit.audit_create_secret(session, request, client, db_secret)
@router.put("/clients/{name}/secrets/{secret_name}")
async def update_client_secret(
@ -79,7 +81,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
name: str,
secret_name: str,
secret_data: BodyValue,
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientSecretResponse:
"""Update a client secret.
@ -96,9 +98,9 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
existing_secret.secret = secret_data.value
session.add(existing_secret)
session.commit()
session.refresh(existing_secret)
audit.audit_update_secret(session, request, client, existing_secret)
await session.commit()
await session.refresh(existing_secret)
await audit.audit_update_secret(session, request, client, existing_secret)
return ClientSecretResponse.from_client_secret(existing_secret)
db_secret = ClientSecret(
@ -107,9 +109,9 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
secret=secret_data.value,
)
session.add(db_secret)
session.commit()
session.refresh(db_secret)
audit.audit_create_secret(session, request, client, db_secret)
await session.commit()
await session.refresh(db_secret)
await audit.audit_create_secret(session, request, client, db_secret)
return ClientSecretResponse.from_client_secret(db_secret)
@router.get("/clients/{name}/secrets/{secret_name}")
@ -117,7 +119,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
request: Request,
name: str,
secret_name: str,
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientSecretResponse:
"""Get a client secret."""
client = await get_client_by_id_or_name(session, name)
@ -133,7 +135,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
)
response_model = ClientSecretResponse.from_client_secret(secret)
audit.audit_access_secret(session, request, client, secret)
await audit.audit_access_secret(session, request, client, secret)
return response_model
@router.delete("/clients/{name}/secrets/{secret_name}")
@ -141,7 +143,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
request: Request,
name: str,
secret_name: str,
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> None:
"""Delete a secret."""
client = await get_client_by_id_or_name(session, name)
@ -156,17 +158,20 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
status_code=404, detail="Cannot find a secret with the given name."
)
session.delete(secret)
session.commit()
audit.audit_delete_secret(session, request, client, secret)
await session.delete(secret)
await session.commit()
await audit.audit_delete_secret(session, request, client, secret)
@router.get("/secrets/")
async def get_secret_map(
request: Request, session: Annotated[Session, Depends(get_db_session)]
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> list[ClientSecretList]:
"""Get a list of all secrets and which clients have them."""
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
for client_secret in session.scalars(select(ClientSecret)).all():
client_secrets = await session.scalars(
select(ClientSecret).options(selectinload(ClientSecret.client))
)
for client_secret in client_secrets.all():
if not client_secret.client:
if client_secret.name not in client_secret_map:
client_secret_map[client_secret.name] = []
@ -177,35 +182,45 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
ClientSecretList(name=secret_name, clients=clients)
for secret_name, clients in client_secret_map.items()
]
@router.get("/secrets/detailed/")
async def get_detailed_secret_map(
request: Request, session: Annotated[Session, Depends(get_db_session)]
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> list[ClientSecretDetailList]:
"""Get a list of all secrets and which clients have them."""
client_secrets: dict[str, ClientSecretDetailList] = {}
for client_secret in session.scalars(select(ClientSecret)).all():
all_client_secrets = await session.execute(
select(ClientSecret).options(selectinload(ClientSecret.client))
)
for client_secret in all_client_secrets.scalars().all():
if client_secret.name not in client_secrets:
client_secrets[client_secret.name] = ClientSecretDetailList(name=client_secret.name)
client_secrets[client_secret.name] = ClientSecretDetailList(
name=client_secret.name
)
client_secrets[client_secret.name].ids.append(str(client_secret.id))
if not client_secret.client:
continue
client_secrets[client_secret.name].clients.append(ClientReference(id=str(client_secret.client.id), name=client_secret.client.name))
client_secrets[client_secret.name].clients.append(
ClientReference(
id=str(client_secret.client.id), name=client_secret.client.name
)
)
# `audit.audit_client_secret_list(session, request)
return list(client_secrets.values())
@router.get("/secrets/{name}")
async def get_secret_clients(
request: Request,
name: str,
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientSecretList:
"""Get a list of which clients has a named secret."""
clients: list[str] = []
for client_secret in session.scalars(
select(ClientSecret).where(ClientSecret.name == name)
).all():
client_secrets = await session.scalars(
select(ClientSecret)
.options(selectinload(ClientSecret.client))
.where(ClientSecret.name == name)
)
for client_secret in client_secrets.all():
if not client_secret.client:
continue
clients.append(client_secret.client.name)
@ -214,19 +229,23 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
@router.get("/secrets/{name}/detailed")
async def get_secret_clients_detailed(
request: Request,
name: str,
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientSecretDetailList:
"""Get a list of which clients has a named secret."""
detail_list = ClientSecretDetailList(name=name)
for client_secret in session.scalars(
client_secrets = await session.scalars(
select(ClientSecret).where(ClientSecret.name == name)
).all():
)
for client_secret in client_secrets.all():
if not client_secret.client:
continue
detail_list.ids.append(str(client_secret.id))
detail_list.clients.append(ClientReference(id=str(client_secret.client.id), name=client_secret.client.name))
detail_list.clients.append(
ClientReference(
id=str(client_secret.client.id), name=client_secret.client.name
)
)
return detail_list

View File

@ -13,30 +13,32 @@ from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from sqlalchemy import Engine
from sqlalchemy.ext.asyncio import AsyncEngine
from .models import init_db
from .models import init_db_async
from .backend_api import get_backend_api
from .db import setup_database
from .db import setup_database, get_async_engine
from .settings import BackendSettings
from .types import DBSessionDep
from .types import AsyncDBSessionDep
LOG = logging.getLogger(__name__)
def init_backend_app(engine: Engine, get_db_session: DBSessionDep) -> FastAPI:
def init_backend_app(settings: BackendSettings) -> FastAPI:
"""Initialize backend app."""
@asynccontextmanager
async def lifespan(_app: FastAPI):
"""Create database before starting the server."""
LOG.debug("Running lifespan")
init_db(engine)
engine = get_async_engine(settings.async_db_url)
await init_db_async(engine)
yield
app = FastAPI(lifespan=lifespan)
app.include_router(get_backend_api(get_db_session))
app.include_router(get_backend_api(settings))
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(
@ -60,6 +62,4 @@ def init_backend_app(engine: Engine, get_db_session: DBSessionDep) -> FastAPI:
def create_backend_app(settings: BackendSettings) -> FastAPI:
"""Create the backend app."""
engine, get_db_session = setup_database(settings.db_url)
return init_backend_app(engine, get_db_session)
return init_backend_app(settings)

View File

@ -3,7 +3,7 @@
from collections.abc import Sequence
from fastapi import Request
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy, Operation, SubSystem
@ -17,8 +17,8 @@ def _get_origin(request: Request) -> str | None:
return origin
def _write_audit_log(
session: Session, request: Request, entry: AuditLog, commit: bool = True
async def _write_audit_log(
session: AsyncSession, request: Request, entry: AuditLog, commit: bool = True
) -> None:
"""Write the audit log."""
origin = _get_origin(request)
@ -26,11 +26,11 @@ def _write_audit_log(
entry.subsystem = SubSystem.BACKEND
session.add(entry)
if commit:
session.commit()
await session.commit()
def audit_create_client(
session: Session, request: Request, client: Client, commit: bool = True
async def audit_create_client(
session: AsyncSession, request: Request, client: Client, commit: bool = True
) -> None:
"""Log the creation of a client."""
entry = AuditLog(
@ -39,11 +39,11 @@ def audit_create_client(
client_name=client.name,
message="Client Created",
)
_write_audit_log(session, request, entry, commit)
await _write_audit_log(session, request, entry, commit)
def audit_delete_client(
session: Session, request: Request, client: Client, commit: bool = True
async def audit_delete_client(
session: AsyncSession, request: Request, client: Client, commit: bool = True
) -> None:
"""Log the creation of a client."""
entry = AuditLog(
@ -52,11 +52,11 @@ def audit_delete_client(
client_name=client.name,
message="Client deleted",
)
_write_audit_log(session, request, entry, commit)
await _write_audit_log(session, request, entry, commit)
def audit_create_secret(
session: Session,
async def audit_create_secret(
session: AsyncSession,
request: Request,
client: Client,
secret: ClientSecret,
@ -71,11 +71,11 @@ def audit_create_secret(
client_name=client.name,
message="Added secret to client",
)
_write_audit_log(session, request, entry, commit)
await _write_audit_log(session, request, entry, commit)
def audit_remove_policy(
session: Session,
async def audit_remove_policy(
session: AsyncSession,
request: Request,
client: Client,
policy: ClientAccessPolicy,
@ -90,11 +90,11 @@ def audit_remove_policy(
message="Deleted client policy",
data=data,
)
_write_audit_log(session, request, entry, commit)
await _write_audit_log(session, request, entry, commit)
def audit_update_policy(
session: Session,
async def audit_update_policy(
session: AsyncSession,
request: Request,
client: Client,
policy: ClientAccessPolicy,
@ -109,11 +109,11 @@ def audit_update_policy(
message="Updated client policy",
data=data,
)
_write_audit_log(session, request, entry, commit)
await _write_audit_log(session, request, entry, commit)
def audit_update_client(
session: Session,
async def audit_update_client(
session: AsyncSession,
request: Request,
client: Client,
commit: bool = True,
@ -125,11 +125,11 @@ def audit_update_client(
client_name=client.name,
message="Client data updated",
)
_write_audit_log(session, request, entry, commit)
await _write_audit_log(session, request, entry, commit)
def audit_update_secret(
session: Session,
async def audit_update_secret(
session: AsyncSession,
request: Request,
client: Client,
secret: ClientSecret,
@ -144,11 +144,11 @@ def audit_update_secret(
secret_id=secret.id,
message="Secret value updated",
)
_write_audit_log(session, request, entry, commit)
await _write_audit_log(session, request, entry, commit)
def audit_invalidate_secrets(
session: Session,
async def audit_invalidate_secrets(
session: AsyncSession,
request: Request,
client: Client,
commit: bool = True,
@ -160,11 +160,11 @@ def audit_invalidate_secrets(
client_id=client.id,
message="Client public-key changed. All secrets invalidated.",
)
_write_audit_log(session, request, entry, commit)
await _write_audit_log(session, request, entry, commit)
def audit_delete_secret(
session: Session,
async def audit_delete_secret(
session: AsyncSession,
request: Request,
client: Client,
secret: ClientSecret,
@ -179,11 +179,11 @@ def audit_delete_secret(
client_id=client.id,
message="Secret removed from client",
)
_write_audit_log(session, request, entry, commit)
await _write_audit_log(session, request, entry, commit)
def audit_access_secrets(
session: Session,
async def audit_access_secrets(
session: AsyncSession,
request: Request,
client: Client,
secrets: Sequence[ClientSecret] | None = None,
@ -194,19 +194,20 @@ def audit_access_secrets(
With no secrets provided, all secrets of the client will be resolved.
"""
if not secrets:
secrets = session.scalars(
secrets_q = await session.scalars(
select(ClientSecret).where(ClientSecret.client_id == client.id)
).all()
)
secrets = secrets_q.all()
for secret in secrets:
audit_access_secret(session, request, client, secret, False)
await audit_access_secret(session, request, client, secret, False)
if commit:
session.commit()
await session.commit()
def audit_access_secret(
session: Session,
async def audit_access_secret(
session: AsyncSession,
request: Request,
client: Client,
secret: ClientSecret,
@ -221,15 +222,15 @@ def audit_access_secret(
client_id=client.id,
client_name=client.name,
)
_write_audit_log(session, request, entry, commit)
await _write_audit_log(session, request, entry, commit)
def audit_client_secret_list(
session: Session, request: Request, commit: bool = True
async def audit_client_secret_list(
session: AsyncSession, request: Request, commit: bool = True
) -> None:
"""Audit a list of all secrets."""
entry = AuditLog(
operation=Operation.READ,
message="All secret names and their clients was viewed",
)
_write_audit_log(session, request, entry, commit)
await _write_audit_log(session, request, entry, commit)

View File

@ -1,17 +1,19 @@
"""Backend API."""
from collections.abc import AsyncGenerator
import logging
from typing import Annotated
from fastapi import APIRouter, Depends, Header, HTTPException
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from sshecret_backend.db import DatabaseSessionManager
from sshecret_backend.settings import BackendSettings
from .api import get_audit_api, get_clients_api, get_policy_api, get_secrets_api
from .auth import verify_token
from .models import (
APIClient,
)
from .types import DBSessionDep
LOG = logging.getLogger(__name__)
@ -19,20 +21,25 @@ API_VERSION = "v1"
def get_backend_api(
get_db_session: DBSessionDep,
settings: BackendSettings,
) -> APIRouter:
"""Construct backend API."""
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
sessionmanager = DatabaseSessionManager(settings.async_db_url)
async with sessionmanager.session() as session:
yield session
async def validate_token(
x_api_token: Annotated[str, Header()],
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> str:
"""Validate token."""
LOG.debug("Validating token %s", x_api_token)
statement = select(APIClient)
results = session.scalars(statement)
results = await session.scalars(statement)
valid = False
for result in results:
for result in results.all():
if verify_token(x_api_token, result.token):
valid = True
LOG.debug("Token is valid")

View File

@ -4,10 +4,11 @@ import logging
import secrets
import sqlite3
from collections.abc import Generator, Callable
from typing import Literal
from collections.abc import AsyncIterator, Generator, Callable
from contextlib import asynccontextmanager
from typing import Any, Literal
from sqlalchemy import create_engine, Engine, event, select
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession, async_sessionmaker, create_async_engine, AsyncEngine
from sqlalchemy.orm import sessionmaker, Session
@ -20,6 +21,47 @@ from .models import APIClient, SubSystem
LOG = logging.getLogger(__name__)
class DatabaseSessionManager:
def __init__(self, host: URL, **engine_kwargs: str):
self._engine: AsyncEngine | None = get_async_engine(host)
self._sessionmaker: async_sessionmaker[AsyncSession] | None = async_sessionmaker(autocommit=False, bind=self._engine, expire_on_commit=False)
async def close(self):
if self._engine is None:
raise Exception("DatabaseSessionManager is not initialized")
await self._engine.dispose()
self._engine = None
self._sessionmaker = None
@asynccontextmanager
async def connect(self) -> AsyncIterator[AsyncConnection]:
if self._engine is None:
raise Exception("DatabaseSessionManager is not initialized")
async with self._engine.begin() as connection:
try:
yield connection
except Exception:
await connection.rollback()
raise
@asynccontextmanager
async def session(self) -> AsyncIterator[AsyncSession]:
if self._sessionmaker is None:
raise Exception("DatabaseSessionManager is not initialized")
session = self._sessionmaker()
try:
yield session
except Exception:
await session.rollback()
raise
finally:
await session.close()
def setup_database(
db_url: URL,
) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]:
@ -39,9 +81,10 @@ def setup_database(
return engine, get_db_session
def get_engine(url: URL, echo: bool = False) -> Engine:
"""Initialize the engine."""
engine = create_engine(url, echo=echo, future=True)
engine = create_engine(url, echo=echo)
if url.drivername.startswith("sqlite"):
@event.listens_for(engine, "connect")
@ -55,12 +98,11 @@ def get_engine(url: URL, echo: bool = False) -> Engine:
return engine
def get_async_engine(url: URL, echo: bool = False) -> AsyncEngine:
def get_async_engine(url: URL, echo: bool = False, **engine_kwargs: str) -> AsyncEngine:
"""Get an async engine."""
engine = create_async_engine(url, echo=echo, future=True)
engine = create_async_engine(url, echo=echo, **engine_kwargs)
if url.drivername.startswith("sqlite+"):
@event.listens_for(engine, "connect")
@event.listens_for(engine.sync_engine, "connect")
def set_sqlite_pragma(
dbapi_connection: sqlite3.Connection, _connection_record: object
) -> None:

View File

@ -13,6 +13,7 @@ import uuid
from datetime import datetime
import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
@ -186,3 +187,9 @@ class AuditLog(Base):
def init_db(engine: sa.Engine) -> None:
"""Initialize database."""
Base.metadata.create_all(engine)
async def init_db_async(engine: AsyncEngine) -> None:
"""Initialize database."""
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)

View File

@ -1,8 +1,11 @@
"""Common type definitions."""
from collections.abc import Callable, Generator
from collections.abc import AsyncGenerator, Callable, Generator
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from sqlalchemy.orm import Session
DBSessionDep = Callable[[], Generator[Session, None, None]]
AsyncDBSessionDep = Callable[[], AsyncGenerator[AsyncSession, None]]