Refactor database layer and auditing
This commit is contained in:
@ -5,7 +5,8 @@
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlmodel import Session, select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Annotated
|
||||
|
||||
from sshecret_backend.models import Client, ClientSecret
|
||||
@ -34,7 +35,7 @@ async def lookup_client_secret(
|
||||
.where(ClientSecret.client_id == client.id)
|
||||
.where(ClientSecret.name == name)
|
||||
)
|
||||
results = session.exec(statement)
|
||||
results = session.scalars(statement)
|
||||
return results.first()
|
||||
|
||||
|
||||
@ -165,7 +166,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
) -> list[ClientSecretList]:
|
||||
"""Get a list of all secrets and which clients have them."""
|
||||
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
|
||||
for client_secret in session.exec(select(ClientSecret)).all():
|
||||
for client_secret in session.scalars(select(ClientSecret)).all():
|
||||
if not client_secret.client:
|
||||
if client_secret.name not in client_secret_map:
|
||||
client_secret_map[client_secret.name] = []
|
||||
@ -182,7 +183,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
) -> list[ClientSecretDetailList]:
|
||||
"""Get a list of all secrets and which clients have them."""
|
||||
client_secrets: dict[str, ClientSecretDetailList] = {}
|
||||
for client_secret in session.exec(select(ClientSecret)).all():
|
||||
for client_secret in session.scalars(select(ClientSecret)).all():
|
||||
|
||||
if client_secret.name not in client_secrets:
|
||||
client_secrets[client_secret.name] = ClientSecretDetailList(name=client_secret.name)
|
||||
@ -202,7 +203,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
) -> ClientSecretList:
|
||||
"""Get a list of which clients has a named secret."""
|
||||
clients: list[str] = []
|
||||
for client_secret in session.exec(
|
||||
for client_secret in session.scalars(
|
||||
select(ClientSecret).where(ClientSecret.name == name)
|
||||
).all():
|
||||
if not client_secret.client:
|
||||
@ -219,7 +220,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
) -> ClientSecretDetailList:
|
||||
"""Get a list of which clients has a named secret."""
|
||||
detail_list = ClientSecretDetailList(name=name)
|
||||
for client_secret in session.exec(
|
||||
for client_secret in session.scalars(
|
||||
select(ClientSecret).where(ClientSecret.name == name)
|
||||
).all():
|
||||
if not client_secret.client:
|
||||
|
||||
Reference in New Issue
Block a user