Refactor to use async database model

This commit is contained in:
2025-05-19 09:15:48 +02:00
parent f10ae027e5
commit fc0c3fb950
11 changed files with 288 additions and 185 deletions

View File

@ -3,18 +3,18 @@
# pyright: reportUnusedFunction=false
import logging
from collections.abc import Sequence
from typing import Any, cast
from fastapi import APIRouter, Depends, Request, Query
from typing import Any
from fastapi import APIRouter, Depends, Request
from pydantic import BaseModel, Field, TypeAdapter
from sqlalchemy import select, func, and_
from sqlalchemy.orm import InstrumentedAttribute, Session
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import InstrumentedAttribute
from sqlalchemy.sql.expression import ColumnExpressionArgument
from typing import Annotated
from sshecret_backend.models import AuditLog, Operation, SubSystem
from sshecret_backend.types import DBSessionDep
from sshecret_backend.types import AsyncDBSessionDep
from sshecret_backend.view_models import AuditInfo, AuditView, AuditListResult
@ -58,24 +58,23 @@ class AuditFilter(BaseModel):
]
def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
def get_audit_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
"""Construct audit sub-api."""
router = APIRouter()
@router.get("/audit/", response_model=AuditListResult)
async def get_audit_logs(
request: Request,
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
filters: Annotated[AuditFilter, Depends()],
) -> AuditListResult:
"""Get audit logs."""
# audit.audit_access_audit_log(session, request)
total = session.scalars(
total = (await session.scalars(
select(func.count("*"))
.select_from(AuditLog)
.where(and_(True, *filters.filter_mapping))
).one()
)).one()
remaining = total - filters.offset
statement = (
@ -87,7 +86,7 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
)
LogAdapt = TypeAdapter(list[AuditView])
results = session.scalars(statement).all()
results = (await session.scalars(statement)).all()
entries = LogAdapt.validate_python(results, from_attributes=True)
return AuditListResult(
results=entries,
@ -97,24 +96,23 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
@router.post("/audit/")
async def add_audit_log(
request: Request,
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
entry: AuditView,
) -> AuditView:
"""Add entry to audit log."""
audit_log = AuditLog(**entry.model_dump(exclude_none=True))
session.add(audit_log)
session.commit()
await session.commit()
return AuditView.model_validate(audit_log, from_attributes=True)
@router.get("/audit/info")
async def get_audit_info(
request: Request, session: Annotated[Session, Depends(get_db_session)]
session: Annotated[AsyncSession, Depends(get_db_session)]
) -> AuditInfo:
"""Get audit info."""
audit_count = session.scalars(
audit_count = (await session.scalars(
select(func.count("*")).select_from(AuditLog)
).one()
)).one()
return AuditInfo(entries=audit_count)
return router

View File

@ -4,14 +4,14 @@
import uuid
import logging
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from pydantic import BaseModel, Field, model_validator
from typing import Annotated, Any, Self, TypeVar, cast
from sqlalchemy import select, func
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import Select
from sshecret_backend.types import DBSessionDep
from sshecret_backend.types import AsyncDBSessionDep
from sshecret_backend.models import Client, ClientSecret
from sshecret_backend.view_models import (
ClientCreate,
@ -20,7 +20,7 @@ from sshecret_backend.view_models import (
ClientUpdate,
)
from sshecret_backend import audit
from .common import get_client_by_id_or_name
from .common import get_client_by_id_or_name, client_with_relationships
class ClientListParams(BaseModel):
@ -74,30 +74,30 @@ def filter_client_statement(
return statement.limit(params.limit).offset(params.offset)
def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
def get_clients_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
"""Construct clients sub-api."""
router = APIRouter()
@router.get("/clients/")
async def get_clients(
filter_query: Annotated[ClientListParams, Query()],
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientQueryResult:
"""Get clients."""
# Get total results first
count_statement = select(func.count("*")).select_from(Client)
count_statement = cast(Select[tuple[int]], filter_client_statement(count_statement, filter_query, True))
total_results = session.scalars(count_statement).one()
total_results = (await session.scalars(count_statement)).one()
statement = filter_client_statement(select(Client), filter_query, False)
statement = filter_client_statement(client_with_relationships(), filter_query, False)
results = session.scalars(statement)
results = await session.scalars(statement)
remainder = total_results - filter_query.offset - filter_query.limit
if remainder < 0:
remainder = 0
clients = list(results)
clients = list(results.all())
clients_view = ClientView.from_client_list(clients)
return ClientQueryResult(
clients=clients_view,
@ -108,7 +108,7 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
@router.get("/clients/{name}")
async def get_client(
name: str,
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView:
"""Fetch a client."""
client = await get_client_by_id_or_name(session, name)
@ -122,7 +122,7 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
async def delete_client(
request: Request,
name: str,
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> None:
"""Delete a client."""
client = await get_client_by_id_or_name(session, name)
@ -131,15 +131,15 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
status_code=404, detail="Cannot find a client with the given name."
)
session.delete(client)
session.commit()
audit.audit_delete_client(session, request, client)
await session.delete(client)
await session.commit()
await audit.audit_delete_client(session, request, client)
@router.post("/clients/")
async def create_client(
request: Request,
client: ClientCreate,
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView:
"""Create client."""
existing = await get_client_by_id_or_name(session, client.name)
@ -148,9 +148,12 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
db_client = client.to_client()
session.add(db_client)
session.commit()
session.refresh(db_client)
audit.audit_create_client(session, request, db_client)
await session.commit()
await session.refresh(db_client)
db_client = await get_client_by_id_or_name(session, client.name)
if not db_client:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Could not create the client.")
await audit.audit_create_client(session, request, db_client)
return ClientView.from_client(db_client)
@router.post("/clients/{name}/public-key")
@ -158,7 +161,7 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
request: Request,
name: str,
client_update: ClientUpdate,
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView:
"""Change the public key of a client.
@ -170,17 +173,16 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
status_code=404, detail="Cannot find a client with the given name."
)
client.public_key = client_update.public_key
for secret in session.scalars(
select(ClientSecret).where(ClientSecret.client_id == client.id)
).all():
matching_secrets = await session.scalars(select(ClientSecret).where(ClientSecret.client_id == client.id))
for secret in matching_secrets.all():
LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True
secret.client_id = None
session.add(client)
session.refresh(client)
session.commit()
audit.audit_invalidate_secrets(session, request, client)
await session.refresh(client)
await session.commit()
await audit.audit_invalidate_secrets(session, request, client)
return ClientView.from_client(client)
@ -189,7 +191,7 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
request: Request,
name: str,
client_update: ClientCreate,
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView:
"""Change the public key of a client.
@ -205,19 +207,20 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
public_key_updated = False
if client_update.public_key != client.public_key:
public_key_updated = True
for secret in session.scalars(
client_secrets = await session.scalars(
select(ClientSecret).where(ClientSecret.client_id == client.id)
).all():
)
for secret in client_secrets.all():
LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True
secret.client_id = None
session.add(client)
session.commit()
session.refresh(client)
audit.audit_update_client(session, request, client)
await session.commit()
await session.refresh(client)
await audit.audit_update_client(session, request, client)
if public_key_updated:
audit.audit_invalidate_secrets(session, request, client)
await audit.audit_invalidate_secrets(session, request, client)
return ClientView.from_client(client)

View File

@ -3,9 +3,11 @@
import re
import uuid
import bcrypt
from sqlalchemy import Select
from sqlalchemy.orm import selectinload
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.future import select
from sqlalchemy.ext.asyncio import AsyncSession
from sshecret_backend.models import Client
@ -17,20 +19,38 @@ def verify_token(token: str, stored_hash: str) -> bool:
stored_bytes = stored_hash.encode("utf-8")
return bcrypt.checkpw(token_bytes, stored_bytes)
async def reload_client_with_relationships(session: AsyncSession, client: Client) -> Client:
"""Reload a client from the database."""
session.expunge(client)
stmt = (
select(Client)
.options(selectinload(Client.policies), selectinload(Client.secrets))
.where(Client.id == client.id)
)
result = await session.execute(stmt)
return result.scalar_one()
async def get_client_by_name(session: Session, name: str) -> Client | None:
def client_with_relationships() -> Select[tuple[Client]]:
"""Base select statement for client with relationships."""
return select(Client).options(
selectinload(Client.secrets),
selectinload(Client.policies),
)
async def get_client_by_name(session: AsyncSession, name: str) -> Client | None:
"""Get client by name."""
client_filter = select(Client).where(Client.name == name)
client_results = session.scalars(client_filter)
return client_results.first()
client_filter = client_with_relationships().where(Client.name == name)
client_results = await session.execute(client_filter)
return client_results.scalars().first()
async def get_client_by_id(session: Session, id: uuid.UUID) -> Client | None:
"""Get client by name."""
client_filter = select(Client).where(Client.id == id)
client_results = session.scalars(client_filter)
return client_results.first()
async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | None:
"""Get client by ID."""
client_filter = client_with_relationships().where(Client.id == id)
client_results = await session.execute(client_filter)
return client_results.scalars().first()
async def get_client_by_id_or_name(session: Session, id_or_name: str) -> Client | None:
async def get_client_by_id_or_name(session: AsyncSession, id_or_name: str) -> Client | None:
"""Get client either by id or name."""
if RE_UUID.match(id_or_name):
id = uuid.UUID(id_or_name)

View File

@ -5,7 +5,7 @@
import logging
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Annotated
from sshecret_backend.models import ClientAccessPolicy
@ -13,21 +13,21 @@ from sshecret_backend.view_models import (
ClientPolicyView,
ClientPolicyUpdate,
)
from sshecret_backend.types import DBSessionDep
from sshecret_backend.types import AsyncDBSessionDep
from sshecret_backend import audit
from .common import get_client_by_id_or_name
from .common import get_client_by_id_or_name, reload_client_with_relationships
LOG = logging.getLogger(__name__)
def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
def get_policy_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
"""Construct clients sub-api."""
router = APIRouter()
@router.get("/clients/{name}/policies/")
async def get_client_policies(
name: str, session: Annotated[Session, Depends(get_db_session)]
name: str, session: Annotated[AsyncSession, Depends(get_db_session)]
) -> ClientPolicyView:
"""Get client policies."""
client = await get_client_by_id_or_name(session, name)
@ -43,7 +43,7 @@ def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
request: Request,
name: str,
policy_update: ClientPolicyUpdate,
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientPolicyView:
"""Update client policies.
@ -55,28 +55,31 @@ def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
status_code=404, detail="Cannot find a client with the given name."
)
# Remove old policies.
policies = session.scalars(
policies = await session.scalars(
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
).all()
)
deleted_policies: list[ClientAccessPolicy] = []
added_policies: list[ClientAccessPolicy] = []
for policy in policies:
session.delete(policy)
for policy in policies.all():
await session.delete(policy)
deleted_policies.append(policy)
LOG.debug("Updating client policies with: %r", policy_update.sources)
for source in policy_update.sources:
LOG.debug("Source %r", source)
policy = ClientAccessPolicy(source=str(source), client_id=client.id)
session.add(policy)
added_policies.append(policy)
session.commit()
session.refresh(client)
await session.flush()
await session.commit()
client = await reload_client_with_relationships(session, client)
for policy in deleted_policies:
audit.audit_remove_policy(session, request, client, policy)
await audit.audit_remove_policy(session, request, client, policy)
for policy in added_policies:
audit.audit_update_policy(session, request, client, policy)
await audit.audit_update_policy(session, request, client, policy)
return ClientPolicyView.from_client(client)

View File

@ -6,9 +6,11 @@ import logging
from collections import defaultdict
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy import select
from sqlalchemy.orm import Session
from typing import Annotated
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sshecret_backend.models import Client, ClientSecret
from sshecret_backend.view_models import (
ClientReference,
@ -19,7 +21,7 @@ from sshecret_backend.view_models import (
ClientSecretResponse,
)
from sshecret_backend import audit
from sshecret_backend.types import DBSessionDep
from sshecret_backend.types import AsyncDBSessionDep
from .common import get_client_by_id_or_name
@ -27,7 +29,7 @@ LOG = logging.getLogger(__name__)
async def lookup_client_secret(
session: Session, client: Client, name: str
session: AsyncSession, client: Client, name: str
) -> ClientSecret | None:
"""Look up a secret for a client."""
statement = (
@ -35,11 +37,11 @@ async def lookup_client_secret(
.where(ClientSecret.client_id == client.id)
.where(ClientSecret.name == name)
)
results = session.scalars(statement)
results = await session.scalars(statement)
return results.first()
def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
def get_secrets_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
"""Construct clients sub-api."""
router = APIRouter()
@ -48,7 +50,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
request: Request,
name: str,
client_secret: ClientSecretPublic,
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> None:
"""Add secret to a client."""
client = await get_client_by_id_or_name(session, name)
@ -69,9 +71,9 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
name=client_secret.name, client_id=client.id, secret=client_secret.secret
)
session.add(db_secret)
session.commit()
session.refresh(db_secret)
audit.audit_create_secret(session, request, client, db_secret)
await session.commit()
await session.refresh(db_secret)
await audit.audit_create_secret(session, request, client, db_secret)
@router.put("/clients/{name}/secrets/{secret_name}")
async def update_client_secret(
@ -79,7 +81,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
name: str,
secret_name: str,
secret_data: BodyValue,
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientSecretResponse:
"""Update a client secret.
@ -96,9 +98,9 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
existing_secret.secret = secret_data.value
session.add(existing_secret)
session.commit()
session.refresh(existing_secret)
audit.audit_update_secret(session, request, client, existing_secret)
await session.commit()
await session.refresh(existing_secret)
await audit.audit_update_secret(session, request, client, existing_secret)
return ClientSecretResponse.from_client_secret(existing_secret)
db_secret = ClientSecret(
@ -107,9 +109,9 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
secret=secret_data.value,
)
session.add(db_secret)
session.commit()
session.refresh(db_secret)
audit.audit_create_secret(session, request, client, db_secret)
await session.commit()
await session.refresh(db_secret)
await audit.audit_create_secret(session, request, client, db_secret)
return ClientSecretResponse.from_client_secret(db_secret)
@router.get("/clients/{name}/secrets/{secret_name}")
@ -117,7 +119,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
request: Request,
name: str,
secret_name: str,
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientSecretResponse:
"""Get a client secret."""
client = await get_client_by_id_or_name(session, name)
@ -133,7 +135,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
)
response_model = ClientSecretResponse.from_client_secret(secret)
audit.audit_access_secret(session, request, client, secret)
await audit.audit_access_secret(session, request, client, secret)
return response_model
@router.delete("/clients/{name}/secrets/{secret_name}")
@ -141,7 +143,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
request: Request,
name: str,
secret_name: str,
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> None:
"""Delete a secret."""
client = await get_client_by_id_or_name(session, name)
@ -156,56 +158,69 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
status_code=404, detail="Cannot find a secret with the given name."
)
session.delete(secret)
session.commit()
audit.audit_delete_secret(session, request, client, secret)
await session.delete(secret)
await session.commit()
await audit.audit_delete_secret(session, request, client, secret)
@router.get("/secrets/")
async def get_secret_map(
request: Request, session: Annotated[Session, Depends(get_db_session)]
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> list[ClientSecretList]:
"""Get a list of all secrets and which clients have them."""
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
for client_secret in session.scalars(select(ClientSecret)).all():
client_secrets = await session.scalars(
select(ClientSecret).options(selectinload(ClientSecret.client))
)
for client_secret in client_secrets.all():
if not client_secret.client:
if client_secret.name not in client_secret_map:
client_secret_map[client_secret.name] = []
continue
client_secret_map[client_secret.name].append(client_secret.client.name)
#audit.audit_client_secret_list(session, request)
# audit.audit_client_secret_list(session, request)
return [
ClientSecretList(name=secret_name, clients=clients)
for secret_name, clients in client_secret_map.items()
]
@router.get("/secrets/detailed/")
async def get_detailed_secret_map(
request: Request, session: Annotated[Session, Depends(get_db_session)]
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> list[ClientSecretDetailList]:
"""Get a list of all secrets and which clients have them."""
client_secrets: dict[str, ClientSecretDetailList] = {}
for client_secret in session.scalars(select(ClientSecret)).all():
all_client_secrets = await session.execute(
select(ClientSecret).options(selectinload(ClientSecret.client))
)
for client_secret in all_client_secrets.scalars().all():
if client_secret.name not in client_secrets:
client_secrets[client_secret.name] = ClientSecretDetailList(name=client_secret.name)
client_secrets[client_secret.name] = ClientSecretDetailList(
name=client_secret.name
)
client_secrets[client_secret.name].ids.append(str(client_secret.id))
if not client_secret.client:
continue
client_secrets[client_secret.name].clients.append(ClientReference(id=str(client_secret.client.id), name=client_secret.client.name))
#`audit.audit_client_secret_list(session, request)
client_secrets[client_secret.name].clients.append(
ClientReference(
id=str(client_secret.client.id), name=client_secret.client.name
)
)
# `audit.audit_client_secret_list(session, request)
return list(client_secrets.values())
@router.get("/secrets/{name}")
async def get_secret_clients(
request: Request,
name: str,
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientSecretList:
"""Get a list of which clients has a named secret."""
clients: list[str] = []
for client_secret in session.scalars(
select(ClientSecret).where(ClientSecret.name == name)
).all():
client_secrets = await session.scalars(
select(ClientSecret)
.options(selectinload(ClientSecret.client))
.where(ClientSecret.name == name)
)
for client_secret in client_secrets.all():
if not client_secret.client:
continue
clients.append(client_secret.client.name)
@ -214,19 +229,23 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
@router.get("/secrets/{name}/detailed")
async def get_secret_clients_detailed(
request: Request,
name: str,
session: Annotated[Session, Depends(get_db_session)],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientSecretDetailList:
"""Get a list of which clients has a named secret."""
detail_list = ClientSecretDetailList(name=name)
for client_secret in session.scalars(
client_secrets = await session.scalars(
select(ClientSecret).where(ClientSecret.name == name)
).all():
)
for client_secret in client_secrets.all():
if not client_secret.client:
continue
detail_list.ids.append(str(client_secret.id))
detail_list.clients.append(ClientReference(id=str(client_secret.client.id), name=client_secret.client.name))
detail_list.clients.append(
ClientReference(
id=str(client_secret.client.id), name=client_secret.client.name
)
)
return detail_list