234 lines
8.5 KiB
Python
234 lines
8.5 KiB
Python
"""Secrets sub-api factory."""
|
|
|
|
# pyright: reportUnusedFunction=false
|
|
|
|
import logging
|
|
from collections import defaultdict
|
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import Session
|
|
from typing import Annotated
|
|
|
|
from sshecret_backend.models import Client, ClientSecret
|
|
from sshecret_backend.view_models import (
|
|
ClientReference,
|
|
ClientSecretDetailList,
|
|
ClientSecretList,
|
|
ClientSecretPublic,
|
|
BodyValue,
|
|
ClientSecretResponse,
|
|
)
|
|
from sshecret_backend import audit
|
|
from sshecret_backend.types import DBSessionDep
|
|
from .common import get_client_by_id_or_name
|
|
|
|
|
|
# Module-level logger named after this module (not referenced in the code visible here).
LOG: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def lookup_client_secret(
    session: Session, client: Client, name: str
) -> ClientSecret | None:
    """Find the secret called ``name`` that is owned by ``client``.

    Args:
        session: Database session used to run the query.
        client: The client whose secrets are searched.
        name: Name of the secret to find.

    Returns:
        The first matching ``ClientSecret`` row, or ``None`` if the client
        has no secret with that name.
    """
    # Both criteria in one ``where`` call are combined with AND, exactly like
    # chaining two ``where`` calls.
    query = select(ClientSecret).where(
        ClientSecret.client_id == client.id,
        ClientSecret.name == name,
    )
    return session.scalars(query).first()
|
|
|
|
|
|
async def _get_client_or_404(session: Session, name: str) -> Client:
    """Resolve a client by id or name, or abort the request with a 404.

    Args:
        session: Database session used for the lookup.
        name: Client identifier taken from the URL path (id or name).

    Returns:
        The matching client.

    Raises:
        HTTPException: 404 if no client matches ``name``.
    """
    client = await get_client_by_id_or_name(session, name)
    if not client:
        raise HTTPException(
            status_code=404, detail="Cannot find a client with the given name."
        )
    return client


async def _get_secret_or_404(
    session: Session, client: Client, secret_name: str
) -> ClientSecret:
    """Resolve a named secret owned by ``client``, or abort with a 404.

    Args:
        session: Database session used for the lookup.
        client: Owner of the secret.
        secret_name: Name of the secret taken from the URL path.

    Returns:
        The matching secret.

    Raises:
        HTTPException: 404 if the client has no secret named ``secret_name``.
    """
    secret = await lookup_client_secret(session, client, secret_name)
    if not secret:
        raise HTTPException(
            status_code=404, detail="Cannot find a secret with the given name."
        )
    return secret


def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
    """Construct the secrets sub-api.

    Args:
        get_db_session: Dependency that provides a database ``Session``; it is
            injected into every route below through ``Depends``.

    Returns:
        An ``APIRouter`` exposing per-client secret CRUD routes and
        secret-listing routes.
    """
    router = APIRouter()

    @router.post("/clients/{name}/secrets/")
    async def add_secret_to_client(
        request: Request,
        name: str,
        client_secret: ClientSecretPublic,
        session: Annotated[Session, Depends(get_db_session)],
    ) -> None:
        """Add a new secret to a client.

        Raises:
            HTTPException: 404 if the client does not exist; 400 if the client
                already has a secret with the same name.
        """
        client = await _get_client_or_404(session, name)

        if await lookup_client_secret(session, client, client_secret.name):
            raise HTTPException(
                status_code=400,
                detail="Cannot add a secret. A different secret with the same name already exists.",
            )

        db_secret = ClientSecret(
            name=client_secret.name, client_id=client.id, secret=client_secret.secret
        )
        session.add(db_secret)
        session.commit()
        session.refresh(db_secret)
        audit.audit_create_secret(session, request, client, db_secret)

    @router.put("/clients/{name}/secrets/{secret_name}")
    async def update_client_secret(
        request: Request,
        name: str,
        secret_name: str,
        secret_data: BodyValue,
        session: Annotated[Session, Depends(get_db_session)],
    ) -> ClientSecretResponse:
        """Update a client secret, creating it if it does not exist yet.

        This can also be used for destructive creates: an existing value is
        overwritten without checking its previous content.

        Raises:
            HTTPException: 404 if the client does not exist.
        """
        client = await _get_client_or_404(session, name)

        existing_secret = await lookup_client_secret(session, client, secret_name)
        if existing_secret:
            existing_secret.secret = secret_data.value
            session.add(existing_secret)
            session.commit()
            session.refresh(existing_secret)
            audit.audit_update_secret(session, request, client, existing_secret)
            return ClientSecretResponse.from_client_secret(existing_secret)

        db_secret = ClientSecret(
            name=secret_name,
            client_id=client.id,
            secret=secret_data.value,
        )
        session.add(db_secret)
        session.commit()
        session.refresh(db_secret)
        audit.audit_create_secret(session, request, client, db_secret)
        return ClientSecretResponse.from_client_secret(db_secret)

    @router.get("/clients/{name}/secrets/{secret_name}")
    async def request_client_secret(
        request: Request,
        name: str,
        secret_name: str,
        session: Annotated[Session, Depends(get_db_session)],
    ) -> ClientSecretResponse:
        """Get a client secret and record the access in the audit log.

        Raises:
            HTTPException: 404 if the client or the secret does not exist.
        """
        client = await _get_client_or_404(session, name)
        secret = await _get_secret_or_404(session, client, secret_name)

        # The response is built before auditing. The audit call shares this
        # session and presumably commits, which would expire the loaded
        # attributes of ``secret`` (assumption; audit is not visible here).
        response_model = ClientSecretResponse.from_client_secret(secret)
        audit.audit_access_secret(session, request, client, secret)
        return response_model

    @router.delete("/clients/{name}/secrets/{secret_name}")
    async def delete_client_secret(
        request: Request,
        name: str,
        secret_name: str,
        session: Annotated[Session, Depends(get_db_session)],
    ) -> None:
        """Delete a secret from a client.

        Raises:
            HTTPException: 404 if the client or the secret does not exist.
        """
        client = await _get_client_or_404(session, name)
        secret = await _get_secret_or_404(session, client, secret_name)

        session.delete(secret)
        session.commit()
        # NOTE(review): ``secret`` is already deleted and committed here, so the
        # audit call gets a detached object. Only attributes loaded before the
        # commit are safe to read. Confirm audit_delete_secret does not
        # lazy-load relationships.
        audit.audit_delete_secret(session, request, client, secret)

    @router.get("/secrets/")
    async def get_secret_map(
        request: Request, session: Annotated[Session, Depends(get_db_session)]
    ) -> list[ClientSecretList]:
        """Get a list of all secret names and the clients that hold each one.

        Secrets whose client relationship is empty still appear, with an empty
        client list.
        """
        client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
        for client_secret in session.scalars(select(ClientSecret)).all():
            # Indexing the defaultdict registers the name even for orphaned
            # secrets, so they are listed with no clients.
            holders = client_secret_map[client_secret.name]
            if client_secret.client:
                holders.append(client_secret.client.name)

        # NOTE: listing secrets is currently not audited
        # (audit.audit_client_secret_list is intentionally not called).
        return [
            ClientSecretList(name=secret_name, clients=clients)
            for secret_name, clients in client_secret_map.items()
        ]

    @router.get("/secrets/detailed/")
    async def get_detailed_secret_map(
        request: Request, session: Annotated[Session, Depends(get_db_session)]
    ) -> list[ClientSecretDetailList]:
        """Get all secret names with their row ids and client references."""
        client_secrets: dict[str, ClientSecretDetailList] = {}
        for client_secret in session.scalars(select(ClientSecret)).all():
            entry = client_secrets.get(client_secret.name)
            if entry is None:
                entry = ClientSecretDetailList(name=client_secret.name)
                client_secrets[client_secret.name] = entry

            # NOTE(review): ids of orphaned secrets (no client) are included
            # here, whereas get_secret_clients_detailed skips them. As a result
            # ``ids`` and ``clients`` may differ in length. Confirm which
            # behavior is intended.
            entry.ids.append(str(client_secret.id))
            if not client_secret.client:
                continue
            entry.clients.append(
                ClientReference(
                    id=str(client_secret.client.id), name=client_secret.client.name
                )
            )

        # NOTE: listing secrets is currently not audited
        # (audit.audit_client_secret_list is intentionally not called).
        return list(client_secrets.values())

    @router.get("/secrets/{name}")
    async def get_secret_clients(
        request: Request,
        name: str,
        session: Annotated[Session, Depends(get_db_session)],
    ) -> ClientSecretList:
        """Get the names of the clients that have a secret called ``name``."""
        statement = select(ClientSecret).where(ClientSecret.name == name)
        clients = [
            client_secret.client.name
            for client_secret in session.scalars(statement).all()
            if client_secret.client
        ]
        return ClientSecretList(name=name, clients=clients)

    @router.get("/secrets/{name}/detailed")
    async def get_secret_clients_detailed(
        request: Request,
        name: str,
        session: Annotated[Session, Depends(get_db_session)],
    ) -> ClientSecretDetailList:
        """Get ids and client references for every holder of secret ``name``.

        Secrets without a client are skipped entirely, so ``ids`` and
        ``clients`` are built in lockstep.
        """
        detail_list = ClientSecretDetailList(name=name)
        statement = select(ClientSecret).where(ClientSecret.name == name)
        for client_secret in session.scalars(statement).all():
            if not client_secret.client:
                continue
            detail_list.ids.append(str(client_secret.id))
            detail_list.clients.append(
                ClientReference(
                    id=str(client_secret.client.id), name=client_secret.client.name
                )
            )

        return detail_list

    return router
|