Refactor database layer and auditing

This commit is contained in:
2025-05-10 08:38:57 +02:00
parent d866553ac1
commit 9ccd2f1d4d
20 changed files with 718 additions and 469 deletions

View File

@ -5,13 +5,14 @@
import logging
from collections.abc import Sequence
from fastapi import APIRouter, Depends, Request, Query
from sqlmodel import Session, col, func, select
from sqlalchemy import desc
from pydantic import TypeAdapter
from sqlalchemy import select, func
from sqlalchemy.orm import Session
from typing import Annotated
from sshecret_backend.models import AuditLog
from sshecret_backend.types import DBSessionDep
from sshecret_backend.view_models import AuditInfo
from sshecret_backend.view_models import AuditInfo, AuditView
LOG = logging.getLogger(__name__)
@ -21,7 +22,7 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
"""Construct audit sub-api."""
router = APIRouter()
@router.get("/audit/", response_model=list[AuditLog])
@router.get("/audit/", response_model=list[AuditView])
async def get_audit_logs(
request: Request,
session: Annotated[Session, Depends(get_db_session)],
@ -29,35 +30,37 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
limit: Annotated[int, Query(le=100)] = 100,
filter_client: Annotated[str | None, Query()] = None,
filter_subsystem: Annotated[str | None, Query()] = None,
) -> Sequence[AuditLog]:
) -> Sequence[AuditView]:
"""Get audit logs."""
#audit.audit_access_audit_log(session, request)
statement = select(AuditLog).offset(offset).limit(limit).order_by(desc(col(AuditLog.timestamp)))
statement = select(AuditLog).offset(offset).limit(limit).order_by(AuditLog.timestamp.desc())
if filter_client:
statement = statement.where(AuditLog.client_name == filter_client)
if filter_subsystem:
statement = statement.where(AuditLog.subsystem == filter_subsystem)
results = session.exec(statement).all()
return results
LogAdapt = TypeAdapter(list[AuditView])
results = session.scalars(statement).all()
return LogAdapt.validate_python(results, from_attributes=True)
@router.post("/audit/")
async def add_audit_log(
request: Request,
session: Annotated[Session, Depends(get_db_session)],
entry: AuditLog,
) -> AuditLog:
entry: AuditView,
) -> AuditView:
"""Add entry to audit log."""
audit_log = AuditLog.model_validate(entry.model_dump(exclude_none=True))
audit_log = AuditLog(**entry.model_dump(exclude_none=True))
session.add(audit_log)
session.commit()
return audit_log
return AuditView.model_validate(audit_log, from_attributes=True)
@router.get("/audit/info")
async def get_audit_info(request: Request, session: Annotated[Session, Depends(get_db_session)]) -> AuditInfo:
"""Get audit info."""
audit_count = session.exec(select(func.count('*')).select_from(AuditLog)).one()
audit_count = session.scalars(select(func.count('*')).select_from(AuditLog)).one()
return AuditInfo(entries=audit_count)

View File

@ -6,11 +6,11 @@ import uuid
import logging
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from pydantic import BaseModel, Field, model_validator
from sqlmodel import Session, col, select
from sqlalchemy import func
from typing import Annotated, Self, TypeVar
from typing import Annotated, Any, Self, TypeVar, cast
from sqlmodel.sql.expression import SelectOfScalar
from sqlalchemy import select, func
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select
from sshecret_backend.types import DBSessionDep
from sshecret_backend.models import Client, ClientSecret
from sshecret_backend.view_models import (
@ -55,8 +55,8 @@ T = TypeVar("T")
def filter_client_statement(
statement: SelectOfScalar[T], params: ClientListParams, ignore_limits: bool = False
) -> SelectOfScalar[T]:
statement: Select[Any], params: ClientListParams, ignore_limits: bool = False
) -> Select[Any]:
"""Filter a statement with the provided params."""
if params.id:
statement = statement.where(Client.id == params.id)
@ -64,9 +64,9 @@ def filter_client_statement(
if params.name:
statement = statement.where(Client.name == params.name)
elif params.name__like:
statement = statement.where(col(Client.name).like(params.name__like))
statement = statement.where(Client.name.like(params.name__like))
elif params.name__contains:
statement = statement.where(col(Client.name).contains(params.name__contains))
statement = statement.where(Client.name.contains(params.name__contains))
if ignore_limits:
return statement
@ -86,13 +86,13 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
"""Get clients."""
# Get total results first
count_statement = select(func.count("*")).select_from(Client)
count_statement = filter_client_statement(count_statement, filter_query, True)
count_statement = cast(Select[tuple[int]], filter_client_statement(count_statement, filter_query, True))
total_results = session.exec(count_statement).one()
total_results = session.scalars(count_statement).one()
statement = filter_client_statement(select(Client), filter_query, False)
results = session.exec(statement)
results = session.scalars(statement)
remainder = total_results - filter_query.offset - filter_query.limit
if remainder < 0:
remainder = 0
@ -170,13 +170,12 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
status_code=404, detail="Cannot find a client with the given name."
)
client.public_key = client_update.public_key
for secret in session.exec(
for secret in session.scalars(
select(ClientSecret).where(ClientSecret.client_id == client.id)
).all():
LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True
secret.client_id = None
secret.client = None
session.add(client)
session.refresh(client)
@ -206,13 +205,12 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
public_key_updated = False
if client_update.public_key != client.public_key:
public_key_updated = True
for secret in session.exec(
for secret in session.scalars(
select(ClientSecret).where(ClientSecret.client_id == client.id)
).all():
LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True
secret.client_id = None
secret.client = None
session.add(client)
session.commit()

View File

@ -4,7 +4,8 @@ import re
import uuid
import bcrypt
from sqlmodel import Session, select
from sqlalchemy import select
from sqlalchemy.orm import Session
from sshecret_backend.models import Client
@ -20,13 +21,13 @@ def verify_token(token: str, stored_hash: str) -> bool:
async def get_client_by_name(session: Session, name: str) -> Client | None:
"""Get client by name."""
client_filter = select(Client).where(Client.name == name)
client_results = session.exec(client_filter)
client_results = session.scalars(client_filter)
return client_results.first()
async def get_client_by_id(session: Session, id: uuid.UUID) -> Client | None:
"""Get client by name."""
client_filter = select(Client).where(Client.id == id)
client_results = session.exec(client_filter)
client_results = session.scalars(client_filter)
return client_results.first()
async def get_client_by_id_or_name(session: Session, id_or_name: str) -> Client | None:

View File

@ -4,7 +4,8 @@
import logging
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlmodel import Session, select
from sqlalchemy import select
from sqlalchemy.orm import Session
from typing import Annotated
from sshecret_backend.models import ClientAccessPolicy
@ -54,7 +55,7 @@ def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
status_code=404, detail="Cannot find a client with the given name."
)
# Remove old policies.
policies = session.exec(
policies = session.scalars(
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
).all()
deleted_policies: list[ClientAccessPolicy] = []

View File

@ -5,7 +5,8 @@
import logging
from collections import defaultdict
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlmodel import Session, select
from sqlalchemy import select
from sqlalchemy.orm import Session
from typing import Annotated
from sshecret_backend.models import Client, ClientSecret
@ -34,7 +35,7 @@ async def lookup_client_secret(
.where(ClientSecret.client_id == client.id)
.where(ClientSecret.name == name)
)
results = session.exec(statement)
results = session.scalars(statement)
return results.first()
@ -165,7 +166,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
) -> list[ClientSecretList]:
"""Get a list of all secrets and which clients have them."""
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
for client_secret in session.exec(select(ClientSecret)).all():
for client_secret in session.scalars(select(ClientSecret)).all():
if not client_secret.client:
if client_secret.name not in client_secret_map:
client_secret_map[client_secret.name] = []
@ -182,7 +183,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
) -> list[ClientSecretDetailList]:
"""Get a list of all secrets and which clients have them."""
client_secrets: dict[str, ClientSecretDetailList] = {}
for client_secret in session.exec(select(ClientSecret)).all():
for client_secret in session.scalars(select(ClientSecret)).all():
if client_secret.name not in client_secrets:
client_secrets[client_secret.name] = ClientSecretDetailList(name=client_secret.name)
@ -202,7 +203,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
) -> ClientSecretList:
"""Get a list of which clients has a named secret."""
clients: list[str] = []
for client_secret in session.exec(
for client_secret in session.scalars(
select(ClientSecret).where(ClientSecret.name == name)
).all():
if not client_secret.client:
@ -219,7 +220,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
) -> ClientSecretDetailList:
"""Get a list of which clients has a named secret."""
detail_list = ClientSecretDetailList(name=name)
for client_secret in session.exec(
for client_secret in session.scalars(
select(ClientSecret).where(ClientSecret.name == name)
).all():
if not client_secret.client: