83 lines
2.7 KiB
Python
83 lines
2.7 KiB
Python
"""Policies sub-api router factory."""
|
|
|
|
# pyright: reportUnusedFunction=false
|
|
|
|
import logging
|
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
|
from sqlmodel import Session, select
|
|
from typing import Annotated
|
|
|
|
from sshecret_backend.models import Client, ClientAccessPolicy
|
|
from sshecret_backend.view_models import (
|
|
ClientPolicyView,
|
|
ClientPolicyUpdate,
|
|
)
|
|
from sshecret_backend.types import DBSessionDep
|
|
from sshecret_backend import audit
|
|
from .common import get_client_by_id_or_name
|
|
|
|
|
|
# Module logger; the policy update route logs each submitted source at DEBUG level.
LOG = logging.getLogger(__name__)
|
|
|
|
|
|
def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
    """Construct the client access-policies sub-api.

    Args:
        get_db_session: Dependency that provides a database ``Session``; it is
            injected into every route via ``Depends``.

    Returns:
        An ``APIRouter`` exposing GET and PUT on ``/clients/{name}/policies/``.
    """
    router = APIRouter()

    async def _get_client_or_404(session: Session, name: str) -> Client:
        """Resolve ``name`` to a client, raising HTTP 404 when none matches."""
        client = await get_client_by_id_or_name(session, name)
        if not client:
            raise HTTPException(
                status_code=404, detail="Cannot find a client with the given name."
            )
        return client

    @router.get("/clients/{name}/policies/")
    async def get_client_policies(
        name: str, session: Annotated[Session, Depends(get_db_session)]
    ) -> ClientPolicyView:
        """Get client policies.

        ``name`` is resolved with ``get_client_by_id_or_name``; responds with
        HTTP 404 when no client matches.
        """
        client = await _get_client_or_404(session, name)
        return ClientPolicyView.from_client(client)

    @router.put("/clients/{name}/policies/")
    async def update_client_policies(
        request: Request,
        name: str,
        policy_update: ClientPolicyUpdate,
        session: Annotated[Session, Depends(get_db_session)],
    ) -> ClientPolicyView:
        """Update client policies.

        The client's existing policies are all replaced by one policy per
        distinct entry of ``policy_update.sources``. This is also how you
        delete policies: send an empty list of sources. Responds with HTTP
        404 when no client matches ``name``. The change is committed before
        the removals and additions are audited.
        """
        client = await _get_client_or_404(session, name)

        # Remove old policies; the update is a full replacement.
        deleted_policies: list[ClientAccessPolicy] = list(
            session.exec(
                select(ClientAccessPolicy).where(
                    ClientAccessPolicy.client_id == client.id
                )
            ).all()
        )
        for policy in deleted_policies:
            session.delete(policy)
        # Within a single flush SQLAlchemy emits INSERTs before DELETEs for the
        # same mapper; flushing here makes the removals hit the database first
        # so re-submitting an existing source cannot collide with its old row.
        session.flush()

        added_policies: list[ClientAccessPolicy] = []
        seen_sources: set[str] = set()
        for source in policy_update.sources:
            LOG.debug("Source %r", source)
            source_text = str(source)
            # Skip repeated sources so each is stored (and audited) only once.
            if source_text in seen_sources:
                continue
            seen_sources.add(source_text)
            policy = ClientAccessPolicy(source=source_text, client_id=client.id)
            session.add(policy)
            added_policies.append(policy)

        session.commit()
        session.refresh(client)

        for policy in deleted_policies:
            audit.audit_remove_policy(session, request, client, policy)

        for policy in added_policies:
            audit.audit_update_policy(session, request, client, policy)

        return ClientPolicyView.from_client(client)

    return router
|