"""Client sub-api factory.""" # pyright: reportUnusedFunction=false import uuid import logging from fastapi import APIRouter, Depends, HTTPException, Query, Request from pydantic import BaseModel, Field, model_validator from sqlmodel import Session, col, select from sqlalchemy import func from typing import Annotated, Self, TypeVar from sqlmodel.sql.expression import SelectOfScalar from sshecret_backend.types import DBSessionDep from sshecret_backend.models import Client, ClientSecret from sshecret_backend.view_models import ( ClientCreate, ClientQueryResult, ClientView, ClientUpdate, ) from sshecret_backend import audit from .common import get_client_by_id_or_name class ClientListParams(BaseModel): """Client list parameters.""" limit: int = Field(100, gt=0, le=100) offset: int = Field(0, ge=0) id: uuid.UUID | None = None name: str | None = None name__like: str | None = None name__contains: str | None = None @model_validator(mode="after") def validate_expressions(self) -> Self: """Validate mutually exclusive expression.""" name_filter = False if self.name__like or self.name__contains: name_filter = True if self.name__like and self.name__contains: raise ValueError("You may only specify one name expression") if self.name and name_filter: raise ValueError( "You must either specify name or one of name__like or name__contains" ) return self LOG = logging.getLogger(__name__) T = TypeVar("T") def filter_client_statement( statement: SelectOfScalar[T], params: ClientListParams, ignore_limits: bool = False ) -> SelectOfScalar[T]: """Filter a statement with the provided params.""" if params.id: statement = statement.where(Client.id == params.id) if params.name: statement = statement.where(Client.name == params.name) elif params.name__like: statement = statement.where(col(Client.name).like(params.name__like)) elif params.name__contains: statement = statement.where(col(Client.name).contains(params.name__contains)) if ignore_limits: return statement return statement.limit(params.limit).offset(params.offset) def get_clients_api(get_db_session: DBSessionDep) -> APIRouter: """Construct clients sub-api.""" router = APIRouter() @router.get("/clients/") async def get_clients( filter_query: Annotated[ClientListParams, Query()], session: Annotated[Session, Depends(get_db_session)], ) -> ClientQueryResult: """Get clients.""" # Get total results first count_statement = select(func.count("*")).select_from(Client) count_statement = filter_client_statement(count_statement, filter_query, True) total_results = session.exec(count_statement).one() statement = filter_client_statement(select(Client), filter_query, False) results = session.exec(statement) remainder = total_results - filter_query.offset - filter_query.limit if remainder < 0: remainder = 0 clients = list(results) clients_view = ClientView.from_client_list(clients) return ClientQueryResult( clients=clients_view, total_results=total_results, remaining_results=remainder, ) @router.get("/clients/{name}") async def get_client( name: str, session: Annotated[Session, Depends(get_db_session)], ) -> ClientView: """Fetch a client.""" client = await get_client_by_id_or_name(session, name) if not client: raise HTTPException( status_code=404, detail="Cannot find a client with the given name." ) return ClientView.from_client(client) @router.delete("/clients/{name}") async def delete_client( request: Request, name: str, session: Annotated[Session, Depends(get_db_session)], ) -> None: """Delete a client.""" client = await get_client_by_id_or_name(session, name) if not client: raise HTTPException( status_code=404, detail="Cannot find a client with the given name." ) session.delete(client) session.commit() audit.audit_delete_client(session, request, client) @router.post("/clients/") async def create_client( request: Request, client: ClientCreate, session: Annotated[Session, Depends(get_db_session)], ) -> ClientView: """Create client.""" existing = await get_client_by_id_or_name(session, client.name) if existing: raise HTTPException(400, detail="Error: Already a client with that name.") db_client = client.to_client() session.add(db_client) session.commit() session.refresh(db_client) audit.audit_create_client(session, request, db_client) return ClientView.from_client(db_client) @router.post("/clients/{name}/public-key") async def update_client_public_key( request: Request, name: str, client_update: ClientUpdate, session: Annotated[Session, Depends(get_db_session)], ) -> ClientView: """Change the public key of a client. This invalidates all secrets. """ client = await get_client_by_id_or_name(session, name) if not client: raise HTTPException( status_code=404, detail="Cannot find a client with the given name." ) client.public_key = client_update.public_key for secret in session.exec( select(ClientSecret).where(ClientSecret.client_id == client.id) ).all(): LOG.debug("Invalidated secret %s", secret.id) secret.invalidated = True secret.client_id = None secret.client = None session.add(client) session.refresh(client) session.commit() audit.audit_invalidate_secrets(session, request, client) return ClientView.from_client(client) @router.put("/clients/{name}") async def update_client( request: Request, name: str, client_update: ClientCreate, session: Annotated[Session, Depends(get_db_session)], ) -> ClientView: """Change the public key of a client. This invalidates all secrets. """ client = await get_client_by_id_or_name(session, name) if not client: raise HTTPException( status_code=404, detail="Cannot find a client with the given name." ) client.name = client_update.name client.description = client_update.description public_key_updated = False if client_update.public_key != client.public_key: public_key_updated = True for secret in session.exec( select(ClientSecret).where(ClientSecret.client_id == client.id) ).all(): LOG.debug("Invalidated secret %s", secret.id) secret.invalidated = True secret.client_id = None secret.client = None session.add(client) session.commit() session.refresh(client) audit.audit_update_client(session, request, client) if public_key_updated: audit.audit_invalidate_secrets(session, request, client) return ClientView.from_client(client) return router