"""Policies sub-api router factory."""

# pyright: reportUnusedFunction=false

import logging
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException, Request
from sqlmodel import Session, select

from sshecret_backend.models import Client, ClientAccessPolicy
from sshecret_backend.view_models import (
    ClientPolicyView,
    ClientPolicyUpdate,
)
from sshecret_backend.types import DBSessionDep
from sshecret_backend import audit
from .common import get_client_by_id_or_name


LOG = logging.getLogger(__name__)


async def _get_client_or_404(session: Session, name: str) -> Client:
    """Look up a client by id or name, raising HTTP 404 when none matches.

    Args:
        session: Open database session used for the lookup.
        name: Client identifier as given in the URL path (id or name).

    Returns:
        The matching client.

    Raises:
        HTTPException: 404 if ``get_client_by_id_or_name`` returns a falsy value.
    """
    client = await get_client_by_id_or_name(session, name)
    if not client:
        raise HTTPException(
            status_code=404, detail="Cannot find a client with the given name."
        )
    return client


def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
    """Construct the policies sub-api.

    Args:
        get_db_session: Dependency passed to ``Depends`` to provide a database
            ``Session`` to each route.

    Returns:
        An ``APIRouter`` exposing GET and PUT on ``/clients/{name}/policies/``.
    """
    router = APIRouter()

    @router.get("/clients/{name}/policies/")
    async def get_client_policies(
        name: str, session: Annotated[Session, Depends(get_db_session)]
    ) -> ClientPolicyView:
        """Get client policies.

        Raises:
            HTTPException: 404 if no client matches ``name``.
        """
        client = await _get_client_or_404(session, name)
        return ClientPolicyView.from_client(client)

    @router.put("/clients/{name}/policies/")
    async def update_client_policies(
        request: Request,
        name: str,
        policy_update: ClientPolicyUpdate,
        session: Annotated[Session, Depends(get_db_session)],
    ) -> ClientPolicyView:
        """Update client policies.

        This is also how you delete policies: every existing policy of the
        client is removed and one new policy is created per entry in
        ``policy_update.sources``. Sending an empty ``sources`` list therefore
        removes all policies.

        Raises:
            HTTPException: 404 if no client matches ``name``.
        """
        client = await _get_client_or_404(session, name)

        # Remove old policies.
        existing_policies = session.exec(
            select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
        ).all()
        deleted_policies: list[ClientAccessPolicy] = list(existing_policies)
        for old_policy in deleted_policies:
            session.delete(old_policy)

        # Create one policy row per requested source.
        added_policies: list[ClientAccessPolicy] = []
        for source in policy_update.sources:
            LOG.debug("Source %r", source)
            new_policy = ClientAccessPolicy(source=str(source), client_id=client.id)
            session.add(new_policy)
            added_policies.append(new_policy)

        session.commit()
        session.refresh(client)

        # Audit only after the commit succeeded, so nothing is recorded for a
        # change that failed to persist.
        for removed in deleted_policies:
            audit.audit_remove_policy(session, request, client, removed)
        for added in added_policies:
            audit.audit_update_policy(session, request, client, added)

        return ClientPolicyView.from_client(client)

    return router