227 lines
7.4 KiB
Python
227 lines
7.4 KiB
Python
"""Client sub-api factory."""
|
|
|
|
# pyright: reportUnusedFunction=false
|
|
|
|
import uuid
|
|
import logging
|
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
|
from pydantic import BaseModel, Field, model_validator
|
|
from sqlmodel import Session, col, select
|
|
from sqlalchemy import func
|
|
from typing import Annotated, Self, TypeVar
|
|
|
|
from sqlmodel.sql.expression import SelectOfScalar
|
|
from sshecret_backend.types import DBSessionDep
|
|
from sshecret_backend.models import Client, ClientSecret
|
|
from sshecret_backend.view_models import (
|
|
ClientCreate,
|
|
ClientQueryResult,
|
|
ClientView,
|
|
ClientUpdate,
|
|
)
|
|
from sshecret_backend import audit
|
|
from .common import get_client_by_id_or_name
|
|
|
|
|
|
class ClientListParams(BaseModel):
    """Query parameters accepted by the client list endpoint.

    Pagination is controlled by ``limit`` and ``offset``. Results may be
    filtered by ``id`` and by at most one name filter: an exact ``name``, a
    SQL ``LIKE`` pattern (``name__like``) or a substring (``name__contains``).
    """

    # Page size; must be between 1 and 100 inclusive.
    limit: int = Field(100, gt=0, le=100)
    # Number of matching rows to skip before the page starts (>= 0).
    offset: int = Field(0, ge=0)
    # Exact client id to match.
    id: uuid.UUID | None = None
    # Exact client name to match.
    name: str | None = None
    # Pattern passed to a SQL LIKE comparison on the client name.
    name__like: str | None = None
    # Substring the client name must contain.
    name__contains: str | None = None

    @model_validator(mode="after")
    def validate_expressions(self) -> Self:
        """Reject ambiguous combinations of name filters.

        Raises:
            ValueError: if both ``name__like`` and ``name__contains`` are set,
                or if ``name`` is combined with either of them.
        """
        has_like = bool(self.name__like)
        has_contains = bool(self.name__contains)

        # Checked first so that setting all three reports this error.
        if has_like and has_contains:
            raise ValueError("You may only specify one name expression")

        if self.name and (has_like or has_contains):
            raise ValueError(
                "You must either specify name or one of name__like or name__contains"
            )

        return self
|
|
|
|
|
|
# Logger for this sub-api, named after the module.
LOG: logging.Logger = logging.getLogger(__name__)

# Row type carried through unchanged by filter_client_statement.
T = TypeVar("T")
|
|
|
|
|
|
def filter_client_statement(
    statement: SelectOfScalar[T], params: ClientListParams, ignore_limits: bool = False
) -> SelectOfScalar[T]:
    """Narrow *statement* according to the client list parameters.

    Args:
        statement: The select statement to filter (e.g. rows or a count).
        params: Validated list parameters; at most one name filter is set.
        ignore_limits: When true, ``limit``/``offset`` are not applied, which
            is what a total-count query needs.

    Returns:
        A new statement with the requested WHERE clauses, and, unless
        *ignore_limits* is set, LIMIT/OFFSET applied.
    """
    if params.id:
        statement = statement.where(Client.id == params.id)

    # Pick the single name clause requested, if any. Exact match wins,
    # then LIKE, then substring.
    name_clause = None
    if params.name:
        name_clause = Client.name == params.name
    elif params.name__like:
        name_clause = col(Client.name).like(params.name__like)
    elif params.name__contains:
        name_clause = col(Client.name).contains(params.name__contains)

    # Compare against None explicitly: SQL clause objects must not be
    # evaluated for truthiness.
    if name_clause is not None:
        statement = statement.where(name_clause)

    if not ignore_limits:
        statement = statement.limit(params.limit).offset(params.offset)
    return statement
|
|
|
|
|
|
def _invalidate_client_secrets(session: Session, client: Client) -> None:
    """Mark every secret bound to *client* invalidated and detach it.

    The changes are only staged on *session*; the caller must commit.
    """
    statement = select(ClientSecret).where(ClientSecret.client_id == client.id)
    for secret in session.exec(statement).all():
        LOG.debug("Invalidated secret %s", secret.id)
        secret.invalidated = True
        secret.client_id = None
        secret.client = None


async def _get_client_or_404(session: Session, name: str) -> Client:
    """Look up a client by id or name, raising HTTP 404 if none is found."""
    client = await get_client_by_id_or_name(session, name)
    if not client:
        raise HTTPException(
            status_code=404, detail="Cannot find a client with the given name."
        )
    return client


def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
    """Construct the clients sub-api.

    Args:
        get_db_session: Dependency injected into every route via ``Depends``
            to obtain the database session.

    Returns:
        An ``APIRouter`` exposing the ``/clients/`` endpoints.
    """
    router = APIRouter()

    @router.get("/clients/")
    async def get_clients(
        filter_query: Annotated[ClientListParams, Query()],
        session: Annotated[Session, Depends(get_db_session)],
    ) -> ClientQueryResult:
        """Get one page of clients matching the filter, with pagination totals."""
        # Count every match first, without limit/offset, so the response can
        # report how many results lie beyond this page.
        count_statement = filter_client_statement(
            select(func.count("*")).select_from(Client), filter_query, True
        )
        total_results = session.exec(count_statement).one()

        statement = filter_client_statement(select(Client), filter_query, False)
        clients = list(session.exec(statement))

        remainder = max(0, total_results - filter_query.offset - filter_query.limit)
        return ClientQueryResult(
            clients=ClientView.from_client_list(clients),
            total_results=total_results,
            remaining_results=remainder,
        )

    @router.get("/clients/{name}")
    async def get_client(
        name: str,
        session: Annotated[Session, Depends(get_db_session)],
    ) -> ClientView:
        """Fetch a client by id or name (404 if it does not exist)."""
        client = await _get_client_or_404(session, name)
        return ClientView.from_client(client)

    @router.delete("/clients/{name}")
    async def delete_client(
        request: Request,
        name: str,
        session: Annotated[Session, Depends(get_db_session)],
    ) -> None:
        """Delete a client by id or name (404 if it does not exist)."""
        client = await _get_client_or_404(session, name)

        session.delete(client)
        session.commit()
        audit.audit_delete_client(session, request, client)

    @router.post("/clients/")
    async def create_client(
        request: Request,
        client: ClientCreate,
        session: Annotated[Session, Depends(get_db_session)],
    ) -> ClientView:
        """Create a client; 400 if a client with that name already exists."""
        existing = await get_client_by_id_or_name(session, client.name)
        if existing:
            raise HTTPException(400, detail="Error: Already a client with that name.")

        db_client = client.to_client()
        session.add(db_client)
        session.commit()
        session.refresh(db_client)
        audit.audit_create_client(session, request, db_client)
        return ClientView.from_client(db_client)

    @router.post("/clients/{name}/public-key")
    async def update_client_public_key(
        request: Request,
        name: str,
        client_update: ClientUpdate,
        session: Annotated[Session, Depends(get_db_session)],
    ) -> ClientView:
        """Change the public key of a client.

        This invalidates all secrets.
        """
        client = await _get_client_or_404(session, name)
        client.public_key = client_update.public_key
        _invalidate_client_secrets(session, client)

        session.add(client)
        # Commit before refreshing: refresh() expires the instance first,
        # which would discard an unflushed public_key change.
        session.commit()
        session.refresh(client)
        audit.audit_invalidate_secrets(session, request, client)

        return ClientView.from_client(client)

    @router.put("/clients/{name}")
    async def update_client(
        request: Request,
        name: str,
        client_update: ClientCreate,
        session: Annotated[Session, Depends(get_db_session)],
    ) -> ClientView:
        """Update a client's name, description and public key.

        Secrets are invalidated only when the public key actually changes.
        Renaming onto another client's name is rejected with HTTP 400.
        """
        client = await _get_client_or_404(session, name)

        # Mirror create_client: names must stay unique across clients.
        if client_update.name != client.name:
            existing = await get_client_by_id_or_name(session, client_update.name)
            if existing and existing.id != client.id:
                raise HTTPException(
                    400, detail="Error: Already a client with that name."
                )

        public_key_updated = client_update.public_key != client.public_key

        client.name = client_update.name
        client.description = client_update.description
        if public_key_updated:
            # Previously the secrets were invalidated but the new key was
            # never stored; persist it together with the invalidation.
            client.public_key = client_update.public_key
            _invalidate_client_secrets(session, client)

        session.add(client)
        session.commit()
        session.refresh(client)
        audit.audit_update_client(session, request, client)
        if public_key_updated:
            audit.audit_invalidate_secrets(session, request, client)

        return ClientView.from_client(client)

    return router
|