Check in backend in working state

This commit is contained in:
2025-04-30 08:23:31 +02:00
parent 76ef97d9c4
commit 20f1ee707a
26 changed files with 1505 additions and 621 deletions

View File

@ -0,0 +1,8 @@
"""API factory modules."""
from .audit import get_audit_api
from .clients import get_clients_api
from .policies import get_policy_api
from .secrets import get_secrets_api
__all__ = ["get_audit_api", "get_clients_api", "get_policy_api", "get_secrets_api"]

View File

@ -0,0 +1,65 @@
"""Audit sub-api factory."""
# pyright: reportUnusedFunction=false
import logging
from collections.abc import Sequence
from fastapi import APIRouter, Depends, Request, Query
from sqlmodel import Session, col, func, select
from sqlalchemy import desc
from typing import Annotated
from sshecret_backend.models import AuditLog
from sshecret_backend.types import DBSessionDep
from sshecret_backend import audit
from sshecret_backend.view_models import AuditInfo
LOG = logging.getLogger(__name__)
def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
"""Construct audit sub-api."""
router = APIRouter()
@router.get("/audit/", response_model=list[AuditLog])
async def get_audit_logs(
request: Request,
session: Annotated[Session, Depends(get_db_session)],
offset: Annotated[int, Query()] = 0,
limit: Annotated[int, Query(le=100)] = 100,
filter_client: Annotated[str | None, Query()] = None,
filter_subsystem: Annotated[str | None, Query()] = None,
) -> Sequence[AuditLog]:
"""Get audit logs."""
#audit.audit_access_audit_log(session, request)
statement = select(AuditLog).offset(offset).limit(limit).order_by(desc(col(AuditLog.timestamp)))
if filter_client:
statement = statement.where(AuditLog.client_name == filter_client)
if filter_subsystem:
statement = statement.where(AuditLog.subsystem == filter_subsystem)
results = session.exec(statement).all()
return results
@router.post("/audit/")
async def add_audit_log(
request: Request,
session: Annotated[Session, Depends(get_db_session)],
entry: AuditLog,
) -> AuditLog:
"""Add entry to audit log."""
audit_log = AuditLog.model_validate(entry.model_dump(exclude_none=True))
session.add(audit_log)
session.commit()
return audit_log
@router.get("/audit/info")
async def get_audit_info(request: Request, session: Annotated[Session, Depends(get_db_session)]) -> AuditInfo:
"""Get audit info."""
audit_count = session.exec(select(func.count('*')).select_from(AuditLog)).one()
return AuditInfo(entries=audit_count)
return router

View File

@ -0,0 +1,226 @@
"""Client sub-api factory."""
# pyright: reportUnusedFunction=false
import uuid
import logging
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from pydantic import BaseModel, Field, model_validator
from sqlmodel import Session, col, select
from sqlalchemy import func
from typing import Annotated, Self, TypeVar
from sqlmodel.sql.expression import SelectOfScalar
from sshecret_backend.types import DBSessionDep
from sshecret_backend.models import Client, ClientSecret
from sshecret_backend.view_models import (
ClientCreate,
ClientQueryResult,
ClientView,
ClientUpdate,
)
from sshecret_backend import audit
from .common import get_client_by_id_or_name
class ClientListParams(BaseModel):
"""Client list parameters."""
limit: int = Field(100, gt=0, le=100)
offset: int = Field(0, ge=0)
id: uuid.UUID | None = None
name: str | None = None
name__like: str | None = None
name__contains: str | None = None
@model_validator(mode="after")
def validate_expressions(self) -> Self:
"""Validate mutually exclusive expression."""
name_filter = False
if self.name__like or self.name__contains:
name_filter = True
if self.name__like and self.name__contains:
raise ValueError("You may only specify one name expression")
if self.name and name_filter:
raise ValueError(
"You must either specify name or one of name__like or name__contains"
)
return self
LOG = logging.getLogger(__name__)
T = TypeVar("T")
def filter_client_statement(
statement: SelectOfScalar[T], params: ClientListParams, ignore_limits: bool = False
) -> SelectOfScalar[T]:
"""Filter a statement with the provided params."""
if params.id:
statement = statement.where(Client.id == params.id)
if params.name:
statement = statement.where(Client.name == params.name)
elif params.name__like:
statement = statement.where(col(Client.name).like(params.name__like))
elif params.name__contains:
statement = statement.where(col(Client.name).contains(params.name__contains))
if ignore_limits:
return statement
return statement.limit(params.limit).offset(params.offset)
def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
"""Construct clients sub-api."""
router = APIRouter()
@router.get("/clients/")
async def get_clients(
filter_query: Annotated[ClientListParams, Query()],
session: Annotated[Session, Depends(get_db_session)],
) -> ClientQueryResult:
"""Get clients."""
# Get total results first
count_statement = select(func.count("*")).select_from(Client)
count_statement = filter_client_statement(count_statement, filter_query, True)
total_results = session.exec(count_statement).one()
statement = filter_client_statement(select(Client), filter_query, False)
results = session.exec(statement)
remainder = total_results - filter_query.offset - filter_query.limit
if remainder < 0:
remainder = 0
clients = list(results)
clients_view = ClientView.from_client_list(clients)
return ClientQueryResult(
clients=clients_view,
total_results=total_results,
remaining_results=remainder,
)
@router.get("/clients/{name}")
async def get_client(
name: str,
session: Annotated[Session, Depends(get_db_session)],
) -> ClientView:
"""Fetch a client."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
return ClientView.from_client(client)
@router.delete("/clients/{name}")
async def delete_client(
request: Request,
name: str,
session: Annotated[Session, Depends(get_db_session)],
) -> None:
"""Delete a client."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
session.delete(client)
session.commit()
audit.audit_delete_client(session, request, client)
@router.post("/clients/")
async def create_client(
request: Request,
client: ClientCreate,
session: Annotated[Session, Depends(get_db_session)],
) -> ClientView:
"""Create client."""
existing = await get_client_by_id_or_name(session, client.name)
if existing:
raise HTTPException(400, detail="Error: Already a client with that name.")
db_client = client.to_client()
session.add(db_client)
session.commit()
session.refresh(db_client)
audit.audit_create_client(session, request, db_client)
return ClientView.from_client(db_client)
@router.post("/clients/{name}/public-key")
async def update_client_public_key(
request: Request,
name: str,
client_update: ClientUpdate,
session: Annotated[Session, Depends(get_db_session)],
) -> ClientView:
"""Change the public key of a client.
This invalidates all secrets.
"""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
client.public_key = client_update.public_key
for secret in session.exec(
select(ClientSecret).where(ClientSecret.client_id == client.id)
).all():
LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True
secret.client_id = None
secret.client = None
session.add(client)
session.refresh(client)
session.commit()
audit.audit_invalidate_secrets(session, request, client)
return ClientView.from_client(client)
@router.put("/clients/{name}")
async def update_client(
request: Request,
name: str,
client_update: ClientCreate,
session: Annotated[Session, Depends(get_db_session)],
) -> ClientView:
"""Change the public key of a client.
This invalidates all secrets.
"""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
client.name = client_update.name
client.description = client_update.description
public_key_updated = False
if client_update.public_key != client.public_key:
public_key_updated = True
for secret in session.exec(
select(ClientSecret).where(ClientSecret.client_id == client.id)
).all():
LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True
secret.client_id = None
secret.client = None
session.add(client)
session.commit()
session.refresh(client)
audit.audit_update_client(session, request, client)
if public_key_updated:
audit.audit_invalidate_secrets(session, request, client)
return ClientView.from_client(client)
return router

View File

@ -0,0 +1,38 @@
"""Common helpers."""
import re
import uuid
import bcrypt
from sqlmodel import Session, select
from sshecret_backend.models import Client
RE_UUID = re.compile("^[0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}$")
def verify_token(token: str, stored_hash: str) -> bool:
"""Verify token."""
token_bytes = token.encode("utf-8")
stored_bytes = stored_hash.encode("utf-8")
return bcrypt.checkpw(token_bytes, stored_bytes)
async def get_client_by_name(session: Session, name: str) -> Client | None:
"""Get client by name."""
client_filter = select(Client).where(Client.name == name)
client_results = session.exec(client_filter)
return client_results.first()
async def get_client_by_id(session: Session, id: uuid.UUID) -> Client | None:
"""Get client by name."""
client_filter = select(Client).where(Client.id == id)
client_results = session.exec(client_filter)
return client_results.first()
async def get_client_by_id_or_name(session: Session, id_or_name: str) -> Client | None:
"""Get client either by id or name."""
if RE_UUID.match(id_or_name):
id = uuid.UUID(id_or_name)
return await get_client_by_id(session, id)
return await get_client_by_name(session, id_or_name)

View File

@ -0,0 +1,82 @@
"""Policies sub-api router factory."""
# pyright: reportUnusedFunction=false
import logging
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlmodel import Session, select
from typing import Annotated
from sshecret_backend.models import Client, ClientAccessPolicy
from sshecret_backend.view_models import (
ClientPolicyView,
ClientPolicyUpdate,
)
from sshecret_backend.types import DBSessionDep
from sshecret_backend import audit
from .common import get_client_by_id_or_name
LOG = logging.getLogger(__name__)
def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
"""Construct clients sub-api."""
router = APIRouter()
@router.get("/clients/{name}/policies/")
async def get_client_policies(
name: str, session: Annotated[Session, Depends(get_db_session)]
) -> ClientPolicyView:
"""Get client policies."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
return ClientPolicyView.from_client(client)
@router.put("/clients/{name}/policies/")
async def update_client_policies(
request: Request,
name: str,
policy_update: ClientPolicyUpdate,
session: Annotated[Session, Depends(get_db_session)],
) -> ClientPolicyView:
"""Update client policies.
This is also how you delete policies.
"""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
# Remove old policies.
policies = session.exec(
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
).all()
deleted_policies: list[ClientAccessPolicy] = []
added_policies: list[ClientAccessPolicy] = []
for policy in policies:
session.delete(policy)
deleted_policies.append(policy)
for source in policy_update.sources:
LOG.debug("Source %r", source)
policy = ClientAccessPolicy(source=str(source), client_id=client.id)
session.add(policy)
added_policies.append(policy)
session.commit()
session.refresh(client)
for policy in deleted_policies:
audit.audit_remove_policy(session, request, client, policy)
for policy in added_policies:
audit.audit_update_policy(session, request, client, policy)
return ClientPolicyView.from_client(client)
return router

View File

@ -0,0 +1,232 @@
"""Secrets sub-api factory."""
# pyright: reportUnusedFunction=false
import logging
from collections import defaultdict
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlmodel import Session, select
from typing import Annotated
from sshecret_backend.models import Client, ClientSecret
from sshecret_backend.view_models import (
ClientReference,
ClientSecretDetailList,
ClientSecretList,
ClientSecretPublic,
BodyValue,
ClientSecretResponse,
)
from sshecret_backend import audit
from sshecret_backend.types import DBSessionDep
from .common import get_client_by_id_or_name
LOG = logging.getLogger(__name__)
async def lookup_client_secret(
session: Session, client: Client, name: str
) -> ClientSecret | None:
"""Look up a secret for a client."""
statement = (
select(ClientSecret)
.where(ClientSecret.client_id == client.id)
.where(ClientSecret.name == name)
)
results = session.exec(statement)
return results.first()
def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
"""Construct clients sub-api."""
router = APIRouter()
@router.post("/clients/{name}/secrets/")
async def add_secret_to_client(
request: Request,
name: str,
client_secret: ClientSecretPublic,
session: Annotated[Session, Depends(get_db_session)],
) -> None:
"""Add secret to a client."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
existing_secret = await lookup_client_secret(
session, client, client_secret.name
)
if existing_secret:
raise HTTPException(
status_code=400,
detail="Cannot add a secret. A different secret with the same name already exists.",
)
db_secret = ClientSecret(
name=client_secret.name, client_id=client.id, secret=client_secret.secret
)
session.add(db_secret)
session.commit()
session.refresh(db_secret)
audit.audit_create_secret(session, request, client, db_secret)
@router.put("/clients/{name}/secrets/{secret_name}")
async def update_client_secret(
request: Request,
name: str,
secret_name: str,
secret_data: BodyValue,
session: Annotated[Session, Depends(get_db_session)],
) -> ClientSecretResponse:
"""Update a client secret.
This can also be used for destructive creates.
"""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
existing_secret = await lookup_client_secret(session, client, secret_name)
if existing_secret:
existing_secret.secret = secret_data.value
session.add(existing_secret)
session.commit()
session.refresh(existing_secret)
audit.audit_update_secret(session, request, client, existing_secret)
return ClientSecretResponse.from_client_secret(existing_secret)
db_secret = ClientSecret(
name=secret_name,
client_id=client.id,
secret=secret_data.value,
)
session.add(db_secret)
session.commit()
session.refresh(db_secret)
audit.audit_create_secret(session, request, client, db_secret)
return ClientSecretResponse.from_client_secret(db_secret)
@router.get("/clients/{name}/secrets/{secret_name}")
async def request_client_secret(
request: Request,
name: str,
secret_name: str,
session: Annotated[Session, Depends(get_db_session)],
) -> ClientSecretResponse:
"""Get a client secret."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
secret = await lookup_client_secret(session, client, secret_name)
if not secret:
raise HTTPException(
status_code=404, detail="Cannot find a secret with the given name."
)
response_model = ClientSecretResponse.from_client_secret(secret)
audit.audit_access_secret(session, request, client, secret)
return response_model
@router.delete("/clients/{name}/secrets/{secret_name}")
async def delete_client_secret(
request: Request,
name: str,
secret_name: str,
session: Annotated[Session, Depends(get_db_session)],
) -> None:
"""Delete a secret."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
secret = await lookup_client_secret(session, client, secret_name)
if not secret:
raise HTTPException(
status_code=404, detail="Cannot find a secret with the given name."
)
session.delete(secret)
session.commit()
audit.audit_delete_secret(session, request, client, secret)
@router.get("/secrets/")
async def get_secret_map(
request: Request, session: Annotated[Session, Depends(get_db_session)]
) -> list[ClientSecretList]:
"""Get a list of all secrets and which clients have them."""
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
for client_secret in session.exec(select(ClientSecret)).all():
if not client_secret.client:
if client_secret.name not in client_secret_map:
client_secret_map[client_secret.name] = []
continue
client_secret_map[client_secret.name].append(client_secret.client.name)
audit.audit_client_secret_list(session, request)
return [
ClientSecretList(name=secret_name, clients=clients)
for secret_name, clients in client_secret_map.items()
]
@router.get("/secrets/detailed/")
async def get_detailed_secret_map(
request: Request, session: Annotated[Session, Depends(get_db_session)]
) -> list[ClientSecretDetailList]:
"""Get a list of all secrets and which clients have them."""
client_secrets: dict[str, ClientSecretDetailList] = {}
for client_secret in session.exec(select(ClientSecret)).all():
if client_secret.name not in client_secrets:
client_secrets[client_secret.name] = ClientSecretDetailList(name=client_secret.name)
client_secrets[client_secret.name].ids.append(str(client_secret.id))
if not client_secret.client:
continue
client_secrets[client_secret.name].clients.append(ClientReference(id=str(client_secret.client.id), name=client_secret.client.name))
audit.audit_client_secret_list(session, request)
return list(client_secrets.values())
@router.get("/secrets/{name}")
async def get_secret_clients(
request: Request,
name: str,
session: Annotated[Session, Depends(get_db_session)],
) -> ClientSecretList:
"""Get a list of which clients has a named secret."""
clients: list[str] = []
for client_secret in session.exec(
select(ClientSecret).where(ClientSecret.name == name)
).all():
if not client_secret.client:
continue
clients.append(client_secret.client.name)
return ClientSecretList(name=name, clients=clients)
@router.get("/secrets/{name}/detailed")
async def get_secret_clients_detailed(
request: Request,
name: str,
session: Annotated[Session, Depends(get_db_session)],
) -> ClientSecretDetailList:
"""Get a list of which clients has a named secret."""
detail_list = ClientSecretDetailList(name=name)
for client_secret in session.exec(
select(ClientSecret).where(ClientSecret.name == name)
).all():
if not client_secret.client:
continue
detail_list.ids.append(str(client_secret.id))
detail_list.clients.append(ClientReference(id=str(client_secret.client.id), name=client_secret.client.name))
return detail_list
return router