Refactor client view

This commit is contained in:
2025-06-06 07:32:51 +02:00
parent a7a09f7784
commit aa6b55a911
8 changed files with 144 additions and 319 deletions

View File

@ -1,8 +1,7 @@
"""API factory modules."""
from .audit import get_audit_api
from .clients import get_clients_api
from .policies import get_policy_api
from .secrets import get_secrets_api
__all__ = ["get_audit_api", "get_clients_api", "get_policy_api", "get_secrets_api"]
__all__ = ["get_audit_api", "get_policy_api", "get_secrets_api"]

View File

@ -1,227 +0,0 @@
"""Client sub-api factory."""
# pyright: reportUnusedFunction=false
import uuid
import logging
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from pydantic import BaseModel, Field, model_validator
from typing import Annotated, Any, Self, TypeVar, cast
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import Select
from sshecret_backend.types import AsyncDBSessionDep
from sshecret_backend.models import Client, ClientSecret
from sshecret_backend.view_models import (
ClientCreate,
ClientQueryResult,
ClientView,
ClientUpdate,
)
from sshecret_backend import audit
from .common import get_client_by_id_or_name, client_with_relationships
class ClientListParams(BaseModel):
"""Client list parameters."""
limit: int = Field(100, gt=0, le=100)
offset: int = Field(0, ge=0)
id: uuid.UUID | None = None
name: str | None = None
name__like: str | None = None
name__contains: str | None = None
@model_validator(mode="after")
def validate_expressions(self) -> Self:
"""Validate mutually exclusive expression."""
name_filter = False
if self.name__like or self.name__contains:
name_filter = True
if self.name__like and self.name__contains:
raise ValueError("You may only specify one name expression")
if self.name and name_filter:
raise ValueError(
"You must either specify name or one of name__like or name__contains"
)
return self
LOG = logging.getLogger(__name__)
T = TypeVar("T")
def filter_client_statement(
statement: Select[Any], params: ClientListParams, ignore_limits: bool = False
) -> Select[Any]:
"""Filter a statement with the provided params."""
if params.id:
statement = statement.where(Client.id == params.id)
if params.name:
statement = statement.where(Client.name == params.name)
elif params.name__like:
statement = statement.where(Client.name.like(params.name__like))
elif params.name__contains:
statement = statement.where(Client.name.contains(params.name__contains))
if ignore_limits:
return statement
return statement.limit(params.limit).offset(params.offset)
def get_clients_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
"""Construct clients sub-api."""
router = APIRouter()
@router.get("/clients/")
async def get_clients(
filter_query: Annotated[ClientListParams, Query()],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientQueryResult:
"""Get clients."""
# Get total results first
count_statement = select(func.count("*")).select_from(Client)
count_statement = cast(Select[tuple[int]], filter_client_statement(count_statement, filter_query, True))
total_results = (await session.scalars(count_statement)).one()
statement = filter_client_statement(client_with_relationships(), filter_query, False)
results = await session.scalars(statement)
remainder = total_results - filter_query.offset - filter_query.limit
if remainder < 0:
remainder = 0
clients = list(results.all())
clients_view = ClientView.from_client_list(clients)
return ClientQueryResult(
clients=clients_view,
total_results=total_results,
remaining_results=remainder,
)
@router.get("/clients/{name}")
async def get_client(
name: str,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView:
"""Fetch a client."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
return ClientView.from_client(client)
@router.delete("/clients/{name}")
async def delete_client(
request: Request,
name: str,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> None:
"""Delete a client."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
await session.delete(client)
await session.commit()
await audit.audit_delete_client(session, request, client)
@router.post("/clients/")
async def create_client(
request: Request,
client: ClientCreate,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView:
"""Create client."""
existing = await get_client_by_id_or_name(session, client.name)
if existing:
raise HTTPException(400, detail="Error: Already a client with that name.")
db_client = client.to_client()
session.add(db_client)
await session.commit()
await session.refresh(db_client)
db_client = await get_client_by_id_or_name(session, client.name)
if not db_client:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Could not create the client.")
await audit.audit_create_client(session, request, db_client)
return ClientView.from_client(db_client)
@router.post("/clients/{name}/public-key")
async def update_client_public_key(
request: Request,
name: str,
client_update: ClientUpdate,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView:
"""Change the public key of a client.
This invalidates all secrets.
"""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
client.public_key = client_update.public_key
matching_secrets = await session.scalars(select(ClientSecret).where(ClientSecret.client_id == client.id))
for secret in matching_secrets.all():
LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True
secret.client_id = None
session.add(client)
await session.refresh(client)
await session.commit()
await audit.audit_invalidate_secrets(session, request, client)
return ClientView.from_client(client)
@router.put("/clients/{name}")
async def update_client(
request: Request,
name: str,
client_update: ClientCreate,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView:
"""Change the public key of a client.
This invalidates all secrets.
"""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
client.name = client_update.name
client.description = client_update.description
public_key_updated = False
if client_update.public_key != client.public_key:
public_key_updated = True
client_secrets = await session.scalars(
select(ClientSecret).where(ClientSecret.client_id == client.id)
)
for secret in client_secrets.all():
LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True
secret.client_id = None
session.add(client)
await session.commit()
await session.refresh(client)
await audit.audit_update_client(session, request, client)
if public_key_updated:
await audit.audit_invalidate_secrets(session, request, client)
return ClientView.from_client(client)
return router

View File

@ -3,15 +3,27 @@
import re
import uuid
import bcrypt
from dataclasses import dataclass, field
from sqlalchemy import Select
from sqlalchemy.orm import selectinload
from sqlalchemy.future import select
from sqlalchemy.ext.asyncio import AsyncSession
from sshecret_backend.models import Client
from sshecret_backend.models import Client, ClientAccessPolicy
RE_UUID = re.compile(
"^[0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}$"
)
@dataclass
class NewClientVersion:
"""New client version dataclass."""
client: Client
policies: list[ClientAccessPolicy] = field(default_factory=list)
RE_UUID = re.compile("^[0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}$")
def verify_token(token: str, stored_hash: str) -> bool:
"""Verify token."""
@ -19,12 +31,19 @@ def verify_token(token: str, stored_hash: str) -> bool:
stored_bytes = stored_hash.encode("utf-8")
return bcrypt.checkpw(token_bytes, stored_bytes)
async def reload_client_with_relationships(session: AsyncSession, client: Client) -> Client:
async def reload_client_with_relationships(
session: AsyncSession, client: Client
) -> Client:
"""Reload a client from the database."""
session.expunge(client)
stmt = (
select(Client)
.options(selectinload(Client.policies), selectinload(Client.secrets))
.options(
selectinload(Client.policies),
selectinload(Client.secrets),
selectinload(Client.previous_version),
)
.where(Client.id == client.id)
)
result = await session.execute(stmt)
@ -36,13 +55,26 @@ def client_with_relationships() -> Select[tuple[Client]]:
return select(Client).options(
selectinload(Client.secrets),
selectinload(Client.policies),
selectinload(Client.previous_version),
)
async def get_client_by_name(session: AsyncSession, name: str) -> Client | None:
"""Get client by name."""
client_filter = client_with_relationships().where(Client.name == name)
client_results = await session.execute(client_filter)
return client_results.scalars().first()
async def resolve_client_id(
session: AsyncSession, name: str, version: int | None = None, include_deleted: bool = False,
) -> uuid.UUID | None:
"""Get the ID of a client name."""
if include_deleted:
client_filter = client_with_relationships().where(Client.name == name)
else:
client_filter = query_active_clients().where(Client.name == name)
if version:
client_filter = client_filter.where(Client.version == version)
client_result = await session.execute(client_filter)
if client := client_result.scalars().first():
return client.id
return None
async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | None:
"""Get client by ID."""
@ -50,10 +82,75 @@ async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | Non
client_results = await session.execute(client_filter)
return client_results.scalars().first()
async def get_client_by_id_or_name(session: AsyncSession, id_or_name: str) -> Client | None:
async def get_client_by_id_or_name(
session: AsyncSession, id_or_name: str
) -> Client | None:
"""Get client either by id or name."""
if RE_UUID.match(id_or_name):
id = uuid.UUID(id_or_name)
return await get_client_by_id(session, id)
return await get_client_by_name(session, id_or_name)
def query_active_clients() -> Select[tuple[Client]]:
"""Get all active clients."""
client_filter = (
client_with_relationships()
.where(Client.is_active.is_(True))
.where(Client.is_deleted.is_(False))
)
return client_filter
async def get_client_by_name(session: AsyncSession, name: str) -> Client | None:
"""Get client by name.
This will get the latest client version, unless it's deleted.
"""
client_filter = (
client_with_relationships()
.where(Client.is_active.is_(True))
.where(Client.is_deleted.is_not(True))
.where(Client.name == name)
.order_by(Client.version.desc())
)
client_result = await session.execute(client_filter)
return client_result.scalars().first()
async def refresh_client(session: AsyncSession, client: Client) -> None:
"""Refresh the client and load in all relationships."""
await session.refresh(
client, attribute_names=["secrets", "policies", "previous_version", "updated_at"]
)
async def create_new_client_version(
session: AsyncSession, current_client: Client, new_public_key: str
) -> Client:
new_client = Client(
name=current_client.name,
version=current_client.version + 1,
description=current_client.description,
public_key=new_public_key,
previous_version_id=current_client.id,
is_active=True,
)
# Mark current client as inactive
current_client.is_active = False
# Copy policies
for policy in current_client.policies:
copied_policy = ClientAccessPolicy(
client=new_client,
address=policy.source,
)
session.add(copied_policy)
session.add(new_client)
await session.flush()
await refresh_client(session, new_client)
return new_client