Refactor database layer and auditing

This commit is contained in:
2025-05-10 08:38:57 +02:00
parent d866553ac1
commit 9ccd2f1d4d
20 changed files with 718 additions and 469 deletions

View File

@ -5,7 +5,8 @@
import logging
from collections import defaultdict
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlmodel import Session, select
from sqlalchemy import select
from sqlalchemy.orm import Session
from typing import Annotated
from sshecret_backend.models import Client, ClientSecret
@ -34,7 +35,7 @@ async def lookup_client_secret(
.where(ClientSecret.client_id == client.id)
.where(ClientSecret.name == name)
)
results = session.exec(statement)
results = session.scalars(statement)
return results.first()
@ -165,7 +166,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
) -> list[ClientSecretList]:
"""Get a list of all secrets and which clients have them."""
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
for client_secret in session.exec(select(ClientSecret)).all():
for client_secret in session.scalars(select(ClientSecret)).all():
if not client_secret.client:
if client_secret.name not in client_secret_map:
client_secret_map[client_secret.name] = []
@ -182,7 +183,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
) -> list[ClientSecretDetailList]:
"""Get a list of all secrets and which clients have them."""
client_secrets: dict[str, ClientSecretDetailList] = {}
for client_secret in session.exec(select(ClientSecret)).all():
for client_secret in session.scalars(select(ClientSecret)).all():
if client_secret.name not in client_secrets:
client_secrets[client_secret.name] = ClientSecretDetailList(name=client_secret.name)
@ -202,7 +203,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
) -> ClientSecretList:
"""Get a list of which clients has a named secret."""
clients: list[str] = []
for client_secret in session.exec(
for client_secret in session.scalars(
select(ClientSecret).where(ClientSecret.name == name)
).all():
if not client_secret.client:
@ -219,7 +220,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
) -> ClientSecretDetailList:
"""Get a list of which clients has a named secret."""
detail_list = ClientSecretDetailList(name=name)
for client_secret in session.exec(
for client_secret in session.scalars(
select(ClientSecret).where(ClientSecret.name == name)
).all():
if not client_secret.client: