Complete backend

This commit is contained in:
2025-04-18 16:39:05 +02:00
parent 83551ffb4a
commit ec90fb7680
11 changed files with 561 additions and 121 deletions

View File

@ -1,26 +1,52 @@
"""FastAPI api."""
"""FastAPI api.
TODO: We may want to allow a consumer to generate audit log entries manually.
"""
import logging
from collections.abc import Sequence
from contextlib import asynccontextmanager
from typing import Annotated
from collections.abc import Sequence
import bcrypt
from fastapi import APIRouter, Depends, FastAPI, Header, HTTPException, Query, Request
from fastapi import (
APIRouter,
Depends,
FastAPI,
Header,
HTTPException,
Query,
Request,
status,
)
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from sqlmodel import Session, select
from . import audit
from .db import get_engine
from .models import APIClient, AuditLog, Client, ClientSecret, init_db
from .models import (
APIClient,
AuditLog,
Client,
ClientAccessPolicy,
ClientSecret,
init_db,
)
from .settings import get_settings
from .view_models import (
BodyValue,
ClientCreate,
ClientListResponse,
ClientSecretPublic,
ClientSecretResponse,
ClientUpdate,
ClientView,
ClientPolicyView,
ClientPolicyUpdate,
)
settings = get_settings()
@ -104,12 +130,12 @@ backend_api = APIRouter(
@backend_api.get("/clients/")
async def get_clients(
session: Annotated[Session, Depends(get_session)]
) -> list[ClientListResponse]:
) -> list[ClientView]:
"""Get clients."""
statement = select(Client)
results = session.exec(statement)
clients = list(results)
return ClientListResponse.from_clients(clients)
return ClientView.from_client_list(clients)
@backend_api.get("/clients/{name}")
@ -128,14 +154,93 @@ async def get_client(
return ClientView.from_client(client)
@backend_api.post("/clients/{name}/update_fingerprint")
async def update_client_fingerprint(
@backend_api.delete("/clients/{name}")
async def delete_client(
request: Request, name: str, session: Annotated[Session, Depends(get_session)]
) -> None:
"""Delete a client."""
statement = select(Client).where(Client.name == name)
results = session.exec(statement)
client = results.first()
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
session.delete(client)
session.commit()
audit.audit_delete_client(session, request, client)
@backend_api.get("/clients/{name}/policies/")
async def get_client_policies(
name: str, session: Annotated[Session, Depends(get_session)]
) -> ClientPolicyView:
"""Get client policies."""
statement = select(Client).where(Client.name == name)
results = session.exec(statement)
client = results.first()
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
return ClientPolicyView.from_client(client)
@backend_api.put("/clients/{name}/policies/")
async def update_client_policies(
request: Request,
name: str,
policy_update: ClientPolicyUpdate,
session: Annotated[Session, Depends(get_session)],
) -> ClientPolicyView:
"""Update client policies.
This is also how you delete policies.
"""
statement = select(Client).where(Client.name == name)
results = session.exec(statement)
client = results.first()
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
# Remove old policies.
policies = session.exec(
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
).all()
deleted_policies: list[ClientAccessPolicy] = []
added_policies: list[ClientAccessPolicy] = []
for policy in policies:
session.delete(policy)
deleted_policies.append(policy)
for source in policy_update.sources:
LOG.debug("Source %r", source)
policy = ClientAccessPolicy(source=str(source), client_id=client.id)
session.add(policy)
added_policies.append(policy)
session.commit()
session.refresh(client)
for policy in deleted_policies:
audit.audit_remove_policy(session, request, client, policy)
for policy in added_policies:
audit.audit_update_policy(session, request, client, policy)
return ClientPolicyView.from_client(client)
@backend_api.post("/clients/{name}/public-key")
async def update_client_public_key(
request: Request,
name: str,
client_update: ClientUpdate,
session: Annotated[Session, Depends(get_session)],
) -> ClientView:
"""Update the client fingerprint.
"""Change the public key of a client.
This invalidates all secrets.
"""
@ -146,7 +251,7 @@ async def update_client_fingerprint(
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
client.fingerprint = client_update.fingerprint
client.public_key = client_update.public_key
for secret in session.exec(
select(ClientSecret).where(ClientSecret.client_id == client.id)
).all():
@ -170,7 +275,11 @@ async def create_client(
session: Annotated[Session, Depends(get_session)],
) -> ClientView:
"""Create client."""
db_client = Client.model_validate(client)
existing = await get_client_by_name(session, client.name)
if existing:
raise HTTPException(400, detail="Error: Already a client with that name.")
db_client = client.to_client()
session.add(db_client)
session.commit()
session.refresh(db_client)
@ -270,6 +379,30 @@ async def request_client_secret(
audit.audit_access_secret(session, request, client, secret)
return response_model
@backend_api.delete("/clients/{name}/secrets/{secret_name}")
async def delete_client_secret(
request: Request,
name: str,
secret_name: str,
session: Annotated[Session, Depends(get_session)],
) -> None:
"""Delete a secret."""
client = await get_client_by_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
secret = await lookup_client_secret(session, client, secret_name)
if not secret:
raise HTTPException(
status_code=404, detail="Cannot find a secret with the given name."
)
session.delete(secret)
session.commit()
audit.audit_delete_secret(session, request, client, secret)
@backend_api.get("/audit/", response_model=list[AuditLog])
async def get_audit_logs(
@ -287,3 +420,17 @@ async def get_audit_logs(
results = session.exec(statement).all()
return results
app = FastAPI()
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
)
app.include_router(backend_api)