Refactor database layer and auditing
This commit is contained in:
@ -5,13 +5,14 @@
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from fastapi import APIRouter, Depends, Request, Query
|
||||
from sqlmodel import Session, col, func, select
|
||||
from sqlalchemy import desc
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Annotated
|
||||
|
||||
from sshecret_backend.models import AuditLog
|
||||
from sshecret_backend.types import DBSessionDep
|
||||
from sshecret_backend.view_models import AuditInfo
|
||||
from sshecret_backend.view_models import AuditInfo, AuditView
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -21,7 +22,7 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
"""Construct audit sub-api."""
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/audit/", response_model=list[AuditLog])
|
||||
@router.get("/audit/", response_model=list[AuditView])
|
||||
async def get_audit_logs(
|
||||
request: Request,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
@ -29,35 +30,37 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
limit: Annotated[int, Query(le=100)] = 100,
|
||||
filter_client: Annotated[str | None, Query()] = None,
|
||||
filter_subsystem: Annotated[str | None, Query()] = None,
|
||||
) -> Sequence[AuditLog]:
|
||||
) -> Sequence[AuditView]:
|
||||
"""Get audit logs."""
|
||||
#audit.audit_access_audit_log(session, request)
|
||||
statement = select(AuditLog).offset(offset).limit(limit).order_by(desc(col(AuditLog.timestamp)))
|
||||
statement = select(AuditLog).offset(offset).limit(limit).order_by(AuditLog.timestamp.desc())
|
||||
if filter_client:
|
||||
statement = statement.where(AuditLog.client_name == filter_client)
|
||||
|
||||
if filter_subsystem:
|
||||
statement = statement.where(AuditLog.subsystem == filter_subsystem)
|
||||
|
||||
results = session.exec(statement).all()
|
||||
return results
|
||||
LogAdapt = TypeAdapter(list[AuditView])
|
||||
results = session.scalars(statement).all()
|
||||
return LogAdapt.validate_python(results, from_attributes=True)
|
||||
|
||||
|
||||
@router.post("/audit/")
|
||||
async def add_audit_log(
|
||||
request: Request,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
entry: AuditLog,
|
||||
) -> AuditLog:
|
||||
entry: AuditView,
|
||||
) -> AuditView:
|
||||
"""Add entry to audit log."""
|
||||
audit_log = AuditLog.model_validate(entry.model_dump(exclude_none=True))
|
||||
audit_log = AuditLog(**entry.model_dump(exclude_none=True))
|
||||
session.add(audit_log)
|
||||
session.commit()
|
||||
return audit_log
|
||||
return AuditView.model_validate(audit_log, from_attributes=True)
|
||||
|
||||
@router.get("/audit/info")
|
||||
async def get_audit_info(request: Request, session: Annotated[Session, Depends(get_db_session)]) -> AuditInfo:
|
||||
"""Get audit info."""
|
||||
audit_count = session.exec(select(func.count('*')).select_from(AuditLog)).one()
|
||||
audit_count = session.scalars(select(func.count('*')).select_from(AuditLog)).one()
|
||||
return AuditInfo(entries=audit_count)
|
||||
|
||||
|
||||
|
||||
@ -6,11 +6,11 @@ import uuid
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from sqlmodel import Session, col, select
|
||||
from sqlalchemy import func
|
||||
from typing import Annotated, Self, TypeVar
|
||||
from typing import Annotated, Any, Self, TypeVar, cast
|
||||
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import Select
|
||||
from sshecret_backend.types import DBSessionDep
|
||||
from sshecret_backend.models import Client, ClientSecret
|
||||
from sshecret_backend.view_models import (
|
||||
@ -55,8 +55,8 @@ T = TypeVar("T")
|
||||
|
||||
|
||||
def filter_client_statement(
|
||||
statement: SelectOfScalar[T], params: ClientListParams, ignore_limits: bool = False
|
||||
) -> SelectOfScalar[T]:
|
||||
statement: Select[Any], params: ClientListParams, ignore_limits: bool = False
|
||||
) -> Select[Any]:
|
||||
"""Filter a statement with the provided params."""
|
||||
if params.id:
|
||||
statement = statement.where(Client.id == params.id)
|
||||
@ -64,9 +64,9 @@ def filter_client_statement(
|
||||
if params.name:
|
||||
statement = statement.where(Client.name == params.name)
|
||||
elif params.name__like:
|
||||
statement = statement.where(col(Client.name).like(params.name__like))
|
||||
statement = statement.where(Client.name.like(params.name__like))
|
||||
elif params.name__contains:
|
||||
statement = statement.where(col(Client.name).contains(params.name__contains))
|
||||
statement = statement.where(Client.name.contains(params.name__contains))
|
||||
|
||||
if ignore_limits:
|
||||
return statement
|
||||
@ -86,13 +86,13 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
"""Get clients."""
|
||||
# Get total results first
|
||||
count_statement = select(func.count("*")).select_from(Client)
|
||||
count_statement = filter_client_statement(count_statement, filter_query, True)
|
||||
count_statement = cast(Select[tuple[int]], filter_client_statement(count_statement, filter_query, True))
|
||||
|
||||
total_results = session.exec(count_statement).one()
|
||||
total_results = session.scalars(count_statement).one()
|
||||
|
||||
statement = filter_client_statement(select(Client), filter_query, False)
|
||||
|
||||
results = session.exec(statement)
|
||||
results = session.scalars(statement)
|
||||
remainder = total_results - filter_query.offset - filter_query.limit
|
||||
if remainder < 0:
|
||||
remainder = 0
|
||||
@ -170,13 +170,12 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
client.public_key = client_update.public_key
|
||||
for secret in session.exec(
|
||||
for secret in session.scalars(
|
||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||
).all():
|
||||
LOG.debug("Invalidated secret %s", secret.id)
|
||||
secret.invalidated = True
|
||||
secret.client_id = None
|
||||
secret.client = None
|
||||
|
||||
session.add(client)
|
||||
session.refresh(client)
|
||||
@ -206,13 +205,12 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
public_key_updated = False
|
||||
if client_update.public_key != client.public_key:
|
||||
public_key_updated = True
|
||||
for secret in session.exec(
|
||||
for secret in session.scalars(
|
||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||
).all():
|
||||
LOG.debug("Invalidated secret %s", secret.id)
|
||||
secret.invalidated = True
|
||||
secret.client_id = None
|
||||
secret.client = None
|
||||
|
||||
session.add(client)
|
||||
session.commit()
|
||||
|
||||
@ -4,7 +4,8 @@ import re
|
||||
import uuid
|
||||
import bcrypt
|
||||
|
||||
from sqlmodel import Session, select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from sshecret_backend.models import Client
|
||||
|
||||
@ -20,13 +21,13 @@ def verify_token(token: str, stored_hash: str) -> bool:
|
||||
async def get_client_by_name(session: Session, name: str) -> Client | None:
|
||||
"""Get client by name."""
|
||||
client_filter = select(Client).where(Client.name == name)
|
||||
client_results = session.exec(client_filter)
|
||||
client_results = session.scalars(client_filter)
|
||||
return client_results.first()
|
||||
|
||||
async def get_client_by_id(session: Session, id: uuid.UUID) -> Client | None:
|
||||
"""Get client by name."""
|
||||
client_filter = select(Client).where(Client.id == id)
|
||||
client_results = session.exec(client_filter)
|
||||
client_results = session.scalars(client_filter)
|
||||
return client_results.first()
|
||||
|
||||
async def get_client_by_id_or_name(session: Session, id_or_name: str) -> Client | None:
|
||||
|
||||
@ -4,7 +4,8 @@
|
||||
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlmodel import Session, select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Annotated
|
||||
|
||||
from sshecret_backend.models import ClientAccessPolicy
|
||||
@ -54,7 +55,7 @@ def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
# Remove old policies.
|
||||
policies = session.exec(
|
||||
policies = session.scalars(
|
||||
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
|
||||
).all()
|
||||
deleted_policies: list[ClientAccessPolicy] = []
|
||||
|
||||
@ -5,7 +5,8 @@
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlmodel import Session, select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Annotated
|
||||
|
||||
from sshecret_backend.models import Client, ClientSecret
|
||||
@ -34,7 +35,7 @@ async def lookup_client_secret(
|
||||
.where(ClientSecret.client_id == client.id)
|
||||
.where(ClientSecret.name == name)
|
||||
)
|
||||
results = session.exec(statement)
|
||||
results = session.scalars(statement)
|
||||
return results.first()
|
||||
|
||||
|
||||
@ -165,7 +166,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
) -> list[ClientSecretList]:
|
||||
"""Get a list of all secrets and which clients have them."""
|
||||
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
|
||||
for client_secret in session.exec(select(ClientSecret)).all():
|
||||
for client_secret in session.scalars(select(ClientSecret)).all():
|
||||
if not client_secret.client:
|
||||
if client_secret.name not in client_secret_map:
|
||||
client_secret_map[client_secret.name] = []
|
||||
@ -182,7 +183,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
) -> list[ClientSecretDetailList]:
|
||||
"""Get a list of all secrets and which clients have them."""
|
||||
client_secrets: dict[str, ClientSecretDetailList] = {}
|
||||
for client_secret in session.exec(select(ClientSecret)).all():
|
||||
for client_secret in session.scalars(select(ClientSecret)).all():
|
||||
|
||||
if client_secret.name not in client_secrets:
|
||||
client_secrets[client_secret.name] = ClientSecretDetailList(name=client_secret.name)
|
||||
@ -202,7 +203,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
) -> ClientSecretList:
|
||||
"""Get a list of which clients has a named secret."""
|
||||
clients: list[str] = []
|
||||
for client_secret in session.exec(
|
||||
for client_secret in session.scalars(
|
||||
select(ClientSecret).where(ClientSecret.name == name)
|
||||
).all():
|
||||
if not client_secret.client:
|
||||
@ -219,7 +220,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
) -> ClientSecretDetailList:
|
||||
"""Get a list of which clients has a named secret."""
|
||||
detail_list = ClientSecretDetailList(name=name)
|
||||
for client_secret in session.exec(
|
||||
for client_secret in session.scalars(
|
||||
select(ClientSecret).where(ClientSecret.name == name)
|
||||
).all():
|
||||
if not client_secret.client:
|
||||
|
||||
Reference in New Issue
Block a user