Check in backend in working state

This commit is contained in:
2025-04-30 08:23:31 +02:00
parent 76ef97d9c4
commit 20f1ee707a
26 changed files with 1505 additions and 621 deletions

View File

@ -1,436 +1,65 @@
"""FastAPI api.
TODO: We may want to allow a consumer to generate audit log entries manually.
"""
"""FastAPI api."""
# pyright: reportUnusedFunction=false
import logging
from collections.abc import Sequence
from contextlib import asynccontextmanager
from typing import Annotated
import bcrypt
from fastapi import (
APIRouter,
Depends,
FastAPI,
Header,
HTTPException,
Query,
Request,
status,
)
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from sqlalchemy import Engine
from sqlmodel import Session, select
from . import audit
from .db import get_engine
from .models import (
APIClient,
AuditLog,
Client,
ClientAccessPolicy,
ClientSecret,
init_db,
)
from .settings import get_settings
from .view_models import (
BodyValue,
ClientCreate,
ClientSecretPublic,
ClientSecretResponse,
ClientUpdate,
ClientView,
ClientPolicyView,
ClientPolicyUpdate,
)
settings = get_settings()
engine = get_engine(settings.db_file)
from .models import init_db
from .backend_api import get_backend_api
from .db import setup_database
from .settings import BackendSettings
from .types import DBSessionDep
LOG = logging.getLogger(__name__)
API_VERSION = "v1"
def init_backend_app(engine: Engine, get_db_session: DBSessionDep) -> FastAPI:
"""Initialize backend app."""
def verify_token(token: str, stored_hash: str) -> bool:
"""Verify token."""
token_bytes = token.encode("utf-8")
stored_bytes = stored_hash.encode("utf-8")
return bcrypt.checkpw(token_bytes, stored_bytes)
@asynccontextmanager
async def lifespan(_app: FastAPI):
"""Create database before starting the server."""
LOG.debug("Running lifespan")
init_db(engine)
yield
app = FastAPI(lifespan=lifespan)
app.include_router(get_backend_api(get_db_session))
@asynccontextmanager
async def lifespan(_app: FastAPI):
"""Create database before starting the server."""
init_db(engine)
yield
async def get_session():
"""Get the session."""
with Session(engine) as session:
yield session
async def validate_token(
x_api_token: Annotated[str, Header()],
session: Annotated[Session, Depends(get_session)],
) -> str:
"""Validate token."""
LOG.debug("Validating token %s", x_api_token)
statement = select(APIClient)
results = session.exec(statement)
valid = False
for result in results:
if verify_token(x_api_token, result.token):
valid = True
LOG.debug("Token is valid")
break
if not valid:
LOG.debug("Token is not valid.")
raise HTTPException(status_code=401, detail="unauthorized. invalid api token.")
return x_api_token
async def get_client_by_name(session: Session, name: str) -> Client | None:
"""Get client by name."""
client_filter = select(Client).where(Client.name == name)
client_results = session.exec(client_filter)
return client_results.first()
async def lookup_client_secret(
session: Session, client: Client, name: str
) -> ClientSecret | None:
"""Look up a secret for a client."""
statement = (
select(ClientSecret)
.where(ClientSecret.client_id == client.id)
.where(ClientSecret.name == name)
)
results = session.exec(statement)
return results.first()
LOG.info("Initializing app.")
backend_api = APIRouter(
prefix=f"/api/{API_VERSION}",
lifespan=lifespan,
dependencies=[Depends(validate_token)],
)
@backend_api.get("/clients/")
async def get_clients(
session: Annotated[Session, Depends(get_session)]
) -> list[ClientView]:
"""Get clients."""
statement = select(Client)
results = session.exec(statement)
clients = list(results)
return ClientView.from_client_list(clients)
@backend_api.get("/clients/{name}")
async def get_client(
request: Request, name: str, session: Annotated[Session, Depends(get_session)]
) -> ClientView:
"""Fetch a client."""
statement = select(Client).where(Client.name == name)
results = session.exec(statement)
client = results.first()
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
audit.audit_access_secrets(session, request, client)
return ClientView.from_client(client)
@backend_api.delete("/clients/{name}")
async def delete_client(
request: Request, name: str, session: Annotated[Session, Depends(get_session)]
) -> None:
"""Delete a client."""
statement = select(Client).where(Client.name == name)
results = session.exec(statement)
client = results.first()
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(
request: Request, exc: RequestValidationError
):
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
)
session.delete(client)
session.commit()
audit.audit_delete_client(session, request, client)
@backend_api.get("/clients/{name}/policies/")
async def get_client_policies(
name: str, session: Annotated[Session, Depends(get_session)]
) -> ClientPolicyView:
"""Get client policies."""
statement = select(Client).where(Client.name == name)
results = session.exec(statement)
client = results.first()
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
@app.get("/health")
async def get_health() -> JSONResponse:
"""Provide simple health check."""
return JSONResponse(
status_code=status.HTTP_200_OK, content=jsonable_encoder({"status": "LIVE"})
)
return ClientPolicyView.from_client(client)
return app
@backend_api.put("/clients/{name}/policies/")
async def update_client_policies(
request: Request,
name: str,
policy_update: ClientPolicyUpdate,
session: Annotated[Session, Depends(get_session)],
) -> ClientPolicyView:
"""Update client policies.
def create_backend_app(settings: BackendSettings) -> FastAPI:
"""Create the backend app."""
This is also how you delete policies.
"""
statement = select(Client).where(Client.name == name)
results = session.exec(statement)
client = results.first()
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
# Remove old policies.
policies = session.exec(
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
).all()
deleted_policies: list[ClientAccessPolicy] = []
added_policies: list[ClientAccessPolicy] = []
for policy in policies:
session.delete(policy)
deleted_policies.append(policy)
engine, get_db_session = setup_database(settings.db_url)
for source in policy_update.sources:
LOG.debug("Source %r", source)
policy = ClientAccessPolicy(source=str(source), client_id=client.id)
session.add(policy)
added_policies.append(policy)
session.commit()
session.refresh(client)
for policy in deleted_policies:
audit.audit_remove_policy(session, request, client, policy)
for policy in added_policies:
audit.audit_update_policy(session, request, client, policy)
return ClientPolicyView.from_client(client)
@backend_api.post("/clients/{name}/public-key")
async def update_client_public_key(
request: Request,
name: str,
client_update: ClientUpdate,
session: Annotated[Session, Depends(get_session)],
) -> ClientView:
"""Change the public key of a client.
This invalidates all secrets.
"""
statement = select(Client).where(Client.name == name)
results = session.exec(statement)
client = results.first()
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
client.public_key = client_update.public_key
for secret in session.exec(
select(ClientSecret).where(ClientSecret.client_id == client.id)
).all():
LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True
secret.client_id = None
secret.client = None
session.add(client)
session.refresh(client)
session.commit()
audit.audit_invalidate_secrets(session, request, client)
return ClientView.from_client(client)
@backend_api.post("/clients/")
async def create_client(
request: Request,
client: ClientCreate,
session: Annotated[Session, Depends(get_session)],
) -> ClientView:
"""Create client."""
existing = await get_client_by_name(session, client.name)
if existing:
raise HTTPException(400, detail="Error: Already a client with that name.")
db_client = client.to_client()
session.add(db_client)
session.commit()
session.refresh(db_client)
audit.audit_create_client(session, request, db_client)
return ClientView.from_client(db_client)
@backend_api.post("/clients/{name}/secrets/")
async def add_secret_to_client(
request: Request,
name: str,
client_secret: ClientSecretPublic,
session: Annotated[Session, Depends(get_session)],
) -> None:
"""Add secret to a client."""
client = await get_client_by_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
existing_secret = await lookup_client_secret(session, client, client_secret.name)
if existing_secret:
raise HTTPException(
status_code=400,
detail="Cannot add a secret. A different secret with the same name already exists.",
)
db_secret = ClientSecret(
name=client_secret.name, client_id=client.id, secret=client_secret.secret
)
session.add(db_secret)
session.commit()
session.refresh(db_secret)
audit.audit_create_secret(session, request, client, db_secret)
@backend_api.put("/clients/{name}/secrets/{secret_name}")
async def update_client_secret(
request: Request,
name: str,
secret_name: str,
secret_data: BodyValue,
session: Annotated[Session, Depends(get_session)],
) -> ClientSecretResponse:
"""Update a client secret.
This can also be used for destructive creates.
"""
client = await get_client_by_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
existing_secret = await lookup_client_secret(session, client, secret_name)
if existing_secret:
existing_secret.secret = secret_data.value
session.add(existing_secret)
session.commit()
session.refresh(existing_secret)
audit.audit_update_secret(session, request, client, existing_secret)
return ClientSecretResponse.from_client_secret(existing_secret)
db_secret = ClientSecret(
name=secret_name,
client_id=client.id,
secret=secret_data.value,
)
session.add(db_secret)
session.commit()
session.refresh(db_secret)
audit.audit_create_secret(session, request, client, db_secret)
return ClientSecretResponse.from_client_secret(db_secret)
@backend_api.get("/clients/{name}/secrets/{secret_name}")
async def request_client_secret(
request: Request,
name: str,
secret_name: str,
session: Annotated[Session, Depends(get_session)],
) -> ClientSecretResponse:
"""Get a client secret."""
client = await get_client_by_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
secret = await lookup_client_secret(session, client, secret_name)
if not secret:
raise HTTPException(
status_code=404, detail="Cannot find a secret with the given name."
)
response_model = ClientSecretResponse.from_client_secret(secret)
audit.audit_access_secret(session, request, client, secret)
return response_model
@backend_api.delete("/clients/{name}/secrets/{secret_name}")
async def delete_client_secret(
request: Request,
name: str,
secret_name: str,
session: Annotated[Session, Depends(get_session)],
) -> None:
"""Delete a secret."""
client = await get_client_by_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
secret = await lookup_client_secret(session, client, secret_name)
if not secret:
raise HTTPException(
status_code=404, detail="Cannot find a secret with the given name."
)
session.delete(secret)
session.commit()
audit.audit_delete_secret(session, request, client, secret)
@backend_api.get("/audit/", response_model=list[AuditLog])
async def get_audit_logs(
request: Request,
session: Annotated[Session, Depends(get_session)],
offset: Annotated[int, Query()] = 0,
limit: Annotated[int, Query(le=100)] = 100,
filter_client: Annotated[str | None, Query()] = None,
) -> Sequence[AuditLog]:
"""Get audit logs."""
audit.audit_access_audit_log(session, request)
statement = select(AuditLog).offset(offset).limit(limit)
if filter_client:
statement = statement.where(AuditLog.client_name == filter_client)
results = session.exec(statement).all()
return results
app = FastAPI()
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
)
app.include_router(backend_api)
return init_backend_app(engine, get_db_session)