Check in backend in working state

This commit is contained in:
2025-04-30 08:23:31 +02:00
parent 76ef97d9c4
commit 20f1ee707a
26 changed files with 1505 additions and 621 deletions

View File

@ -1,5 +1,2 @@
"""Sshecret backend."""
from .app import app as app
#from .router import app as app
__all__ = ["app"]
# from .router import app as app

View File

@ -0,0 +1,8 @@
"""API factory modules."""
from .audit import get_audit_api
from .clients import get_clients_api
from .policies import get_policy_api
from .secrets import get_secrets_api
__all__ = ["get_audit_api", "get_clients_api", "get_policy_api", "get_secrets_api"]

View File

@ -0,0 +1,65 @@
"""Audit sub-api factory."""
# pyright: reportUnusedFunction=false
import logging
from collections.abc import Sequence
from fastapi import APIRouter, Depends, Request, Query
from sqlmodel import Session, col, func, select
from sqlalchemy import desc
from typing import Annotated
from sshecret_backend.models import AuditLog
from sshecret_backend.types import DBSessionDep
from sshecret_backend import audit
from sshecret_backend.view_models import AuditInfo
LOG = logging.getLogger(__name__)
def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
"""Construct audit sub-api."""
router = APIRouter()
@router.get("/audit/", response_model=list[AuditLog])
async def get_audit_logs(
request: Request,
session: Annotated[Session, Depends(get_db_session)],
offset: Annotated[int, Query()] = 0,
limit: Annotated[int, Query(le=100)] = 100,
filter_client: Annotated[str | None, Query()] = None,
filter_subsystem: Annotated[str | None, Query()] = None,
) -> Sequence[AuditLog]:
"""Get audit logs."""
#audit.audit_access_audit_log(session, request)
statement = select(AuditLog).offset(offset).limit(limit).order_by(desc(col(AuditLog.timestamp)))
if filter_client:
statement = statement.where(AuditLog.client_name == filter_client)
if filter_subsystem:
statement = statement.where(AuditLog.subsystem == filter_subsystem)
results = session.exec(statement).all()
return results
@router.post("/audit/")
async def add_audit_log(
request: Request,
session: Annotated[Session, Depends(get_db_session)],
entry: AuditLog,
) -> AuditLog:
"""Add entry to audit log."""
audit_log = AuditLog.model_validate(entry.model_dump(exclude_none=True))
session.add(audit_log)
session.commit()
return audit_log
@router.get("/audit/info")
async def get_audit_info(request: Request, session: Annotated[Session, Depends(get_db_session)]) -> AuditInfo:
"""Get audit info."""
audit_count = session.exec(select(func.count('*')).select_from(AuditLog)).one()
return AuditInfo(entries=audit_count)
return router

View File

@ -0,0 +1,226 @@
"""Client sub-api factory."""
# pyright: reportUnusedFunction=false
import uuid
import logging
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from pydantic import BaseModel, Field, model_validator
from sqlmodel import Session, col, select
from sqlalchemy import func
from typing import Annotated, Self, TypeVar
from sqlmodel.sql.expression import SelectOfScalar
from sshecret_backend.types import DBSessionDep
from sshecret_backend.models import Client, ClientSecret
from sshecret_backend.view_models import (
ClientCreate,
ClientQueryResult,
ClientView,
ClientUpdate,
)
from sshecret_backend import audit
from .common import get_client_by_id_or_name
class ClientListParams(BaseModel):
"""Client list parameters."""
limit: int = Field(100, gt=0, le=100)
offset: int = Field(0, ge=0)
id: uuid.UUID | None = None
name: str | None = None
name__like: str | None = None
name__contains: str | None = None
@model_validator(mode="after")
def validate_expressions(self) -> Self:
"""Validate mutually exclusive expression."""
name_filter = False
if self.name__like or self.name__contains:
name_filter = True
if self.name__like and self.name__contains:
raise ValueError("You may only specify one name expression")
if self.name and name_filter:
raise ValueError(
"You must either specify name or one of name__like or name__contains"
)
return self
LOG = logging.getLogger(__name__)
T = TypeVar("T")
def filter_client_statement(
statement: SelectOfScalar[T], params: ClientListParams, ignore_limits: bool = False
) -> SelectOfScalar[T]:
"""Filter a statement with the provided params."""
if params.id:
statement = statement.where(Client.id == params.id)
if params.name:
statement = statement.where(Client.name == params.name)
elif params.name__like:
statement = statement.where(col(Client.name).like(params.name__like))
elif params.name__contains:
statement = statement.where(col(Client.name).contains(params.name__contains))
if ignore_limits:
return statement
return statement.limit(params.limit).offset(params.offset)
def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
"""Construct clients sub-api."""
router = APIRouter()
@router.get("/clients/")
async def get_clients(
filter_query: Annotated[ClientListParams, Query()],
session: Annotated[Session, Depends(get_db_session)],
) -> ClientQueryResult:
"""Get clients."""
# Get total results first
count_statement = select(func.count("*")).select_from(Client)
count_statement = filter_client_statement(count_statement, filter_query, True)
total_results = session.exec(count_statement).one()
statement = filter_client_statement(select(Client), filter_query, False)
results = session.exec(statement)
remainder = total_results - filter_query.offset - filter_query.limit
if remainder < 0:
remainder = 0
clients = list(results)
clients_view = ClientView.from_client_list(clients)
return ClientQueryResult(
clients=clients_view,
total_results=total_results,
remaining_results=remainder,
)
@router.get("/clients/{name}")
async def get_client(
name: str,
session: Annotated[Session, Depends(get_db_session)],
) -> ClientView:
"""Fetch a client."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
return ClientView.from_client(client)
@router.delete("/clients/{name}")
async def delete_client(
request: Request,
name: str,
session: Annotated[Session, Depends(get_db_session)],
) -> None:
"""Delete a client."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
session.delete(client)
session.commit()
audit.audit_delete_client(session, request, client)
@router.post("/clients/")
async def create_client(
request: Request,
client: ClientCreate,
session: Annotated[Session, Depends(get_db_session)],
) -> ClientView:
"""Create client."""
existing = await get_client_by_id_or_name(session, client.name)
if existing:
raise HTTPException(400, detail="Error: Already a client with that name.")
db_client = client.to_client()
session.add(db_client)
session.commit()
session.refresh(db_client)
audit.audit_create_client(session, request, db_client)
return ClientView.from_client(db_client)
@router.post("/clients/{name}/public-key")
async def update_client_public_key(
request: Request,
name: str,
client_update: ClientUpdate,
session: Annotated[Session, Depends(get_db_session)],
) -> ClientView:
"""Change the public key of a client.
This invalidates all secrets.
"""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
client.public_key = client_update.public_key
for secret in session.exec(
select(ClientSecret).where(ClientSecret.client_id == client.id)
).all():
LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True
secret.client_id = None
secret.client = None
session.add(client)
session.refresh(client)
session.commit()
audit.audit_invalidate_secrets(session, request, client)
return ClientView.from_client(client)
@router.put("/clients/{name}")
async def update_client(
request: Request,
name: str,
client_update: ClientCreate,
session: Annotated[Session, Depends(get_db_session)],
) -> ClientView:
"""Change the public key of a client.
This invalidates all secrets.
"""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
client.name = client_update.name
client.description = client_update.description
public_key_updated = False
if client_update.public_key != client.public_key:
public_key_updated = True
for secret in session.exec(
select(ClientSecret).where(ClientSecret.client_id == client.id)
).all():
LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True
secret.client_id = None
secret.client = None
session.add(client)
session.commit()
session.refresh(client)
audit.audit_update_client(session, request, client)
if public_key_updated:
audit.audit_invalidate_secrets(session, request, client)
return ClientView.from_client(client)
return router

View File

@ -0,0 +1,38 @@
"""Common helpers."""
import re
import uuid
import bcrypt
from sqlmodel import Session, select
from sshecret_backend.models import Client
RE_UUID = re.compile("^[0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}$")
def verify_token(token: str, stored_hash: str) -> bool:
"""Verify token."""
token_bytes = token.encode("utf-8")
stored_bytes = stored_hash.encode("utf-8")
return bcrypt.checkpw(token_bytes, stored_bytes)
async def get_client_by_name(session: Session, name: str) -> Client | None:
"""Get client by name."""
client_filter = select(Client).where(Client.name == name)
client_results = session.exec(client_filter)
return client_results.first()
async def get_client_by_id(session: Session, id: uuid.UUID) -> Client | None:
"""Get client by name."""
client_filter = select(Client).where(Client.id == id)
client_results = session.exec(client_filter)
return client_results.first()
async def get_client_by_id_or_name(session: Session, id_or_name: str) -> Client | None:
"""Get client either by id or name."""
if RE_UUID.match(id_or_name):
id = uuid.UUID(id_or_name)
return await get_client_by_id(session, id)
return await get_client_by_name(session, id_or_name)

View File

@ -0,0 +1,82 @@
"""Policies sub-api router factory."""
# pyright: reportUnusedFunction=false
import logging
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlmodel import Session, select
from typing import Annotated
from sshecret_backend.models import Client, ClientAccessPolicy
from sshecret_backend.view_models import (
ClientPolicyView,
ClientPolicyUpdate,
)
from sshecret_backend.types import DBSessionDep
from sshecret_backend import audit
from .common import get_client_by_id_or_name
LOG = logging.getLogger(__name__)
def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
"""Construct clients sub-api."""
router = APIRouter()
@router.get("/clients/{name}/policies/")
async def get_client_policies(
name: str, session: Annotated[Session, Depends(get_db_session)]
) -> ClientPolicyView:
"""Get client policies."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
return ClientPolicyView.from_client(client)
@router.put("/clients/{name}/policies/")
async def update_client_policies(
request: Request,
name: str,
policy_update: ClientPolicyUpdate,
session: Annotated[Session, Depends(get_db_session)],
) -> ClientPolicyView:
"""Update client policies.
This is also how you delete policies.
"""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
# Remove old policies.
policies = session.exec(
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
).all()
deleted_policies: list[ClientAccessPolicy] = []
added_policies: list[ClientAccessPolicy] = []
for policy in policies:
session.delete(policy)
deleted_policies.append(policy)
for source in policy_update.sources:
LOG.debug("Source %r", source)
policy = ClientAccessPolicy(source=str(source), client_id=client.id)
session.add(policy)
added_policies.append(policy)
session.commit()
session.refresh(client)
for policy in deleted_policies:
audit.audit_remove_policy(session, request, client, policy)
for policy in added_policies:
audit.audit_update_policy(session, request, client, policy)
return ClientPolicyView.from_client(client)
return router

View File

@ -0,0 +1,232 @@
"""Secrets sub-api factory."""
# pyright: reportUnusedFunction=false
import logging
from collections import defaultdict
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlmodel import Session, select
from typing import Annotated
from sshecret_backend.models import Client, ClientSecret
from sshecret_backend.view_models import (
ClientReference,
ClientSecretDetailList,
ClientSecretList,
ClientSecretPublic,
BodyValue,
ClientSecretResponse,
)
from sshecret_backend import audit
from sshecret_backend.types import DBSessionDep
from .common import get_client_by_id_or_name
LOG = logging.getLogger(__name__)
async def lookup_client_secret(
session: Session, client: Client, name: str
) -> ClientSecret | None:
"""Look up a secret for a client."""
statement = (
select(ClientSecret)
.where(ClientSecret.client_id == client.id)
.where(ClientSecret.name == name)
)
results = session.exec(statement)
return results.first()
def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
"""Construct clients sub-api."""
router = APIRouter()
@router.post("/clients/{name}/secrets/")
async def add_secret_to_client(
request: Request,
name: str,
client_secret: ClientSecretPublic,
session: Annotated[Session, Depends(get_db_session)],
) -> None:
"""Add secret to a client."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
existing_secret = await lookup_client_secret(
session, client, client_secret.name
)
if existing_secret:
raise HTTPException(
status_code=400,
detail="Cannot add a secret. A different secret with the same name already exists.",
)
db_secret = ClientSecret(
name=client_secret.name, client_id=client.id, secret=client_secret.secret
)
session.add(db_secret)
session.commit()
session.refresh(db_secret)
audit.audit_create_secret(session, request, client, db_secret)
@router.put("/clients/{name}/secrets/{secret_name}")
async def update_client_secret(
request: Request,
name: str,
secret_name: str,
secret_data: BodyValue,
session: Annotated[Session, Depends(get_db_session)],
) -> ClientSecretResponse:
"""Update a client secret.
This can also be used for destructive creates.
"""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
existing_secret = await lookup_client_secret(session, client, secret_name)
if existing_secret:
existing_secret.secret = secret_data.value
session.add(existing_secret)
session.commit()
session.refresh(existing_secret)
audit.audit_update_secret(session, request, client, existing_secret)
return ClientSecretResponse.from_client_secret(existing_secret)
db_secret = ClientSecret(
name=secret_name,
client_id=client.id,
secret=secret_data.value,
)
session.add(db_secret)
session.commit()
session.refresh(db_secret)
audit.audit_create_secret(session, request, client, db_secret)
return ClientSecretResponse.from_client_secret(db_secret)
@router.get("/clients/{name}/secrets/{secret_name}")
async def request_client_secret(
request: Request,
name: str,
secret_name: str,
session: Annotated[Session, Depends(get_db_session)],
) -> ClientSecretResponse:
"""Get a client secret."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
secret = await lookup_client_secret(session, client, secret_name)
if not secret:
raise HTTPException(
status_code=404, detail="Cannot find a secret with the given name."
)
response_model = ClientSecretResponse.from_client_secret(secret)
audit.audit_access_secret(session, request, client, secret)
return response_model
@router.delete("/clients/{name}/secrets/{secret_name}")
async def delete_client_secret(
request: Request,
name: str,
secret_name: str,
session: Annotated[Session, Depends(get_db_session)],
) -> None:
"""Delete a secret."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
secret = await lookup_client_secret(session, client, secret_name)
if not secret:
raise HTTPException(
status_code=404, detail="Cannot find a secret with the given name."
)
session.delete(secret)
session.commit()
audit.audit_delete_secret(session, request, client, secret)
@router.get("/secrets/")
async def get_secret_map(
request: Request, session: Annotated[Session, Depends(get_db_session)]
) -> list[ClientSecretList]:
"""Get a list of all secrets and which clients have them."""
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
for client_secret in session.exec(select(ClientSecret)).all():
if not client_secret.client:
if client_secret.name not in client_secret_map:
client_secret_map[client_secret.name] = []
continue
client_secret_map[client_secret.name].append(client_secret.client.name)
audit.audit_client_secret_list(session, request)
return [
ClientSecretList(name=secret_name, clients=clients)
for secret_name, clients in client_secret_map.items()
]
@router.get("/secrets/detailed/")
async def get_detailed_secret_map(
request: Request, session: Annotated[Session, Depends(get_db_session)]
) -> list[ClientSecretDetailList]:
"""Get a list of all secrets and which clients have them."""
client_secrets: dict[str, ClientSecretDetailList] = {}
for client_secret in session.exec(select(ClientSecret)).all():
if client_secret.name not in client_secrets:
client_secrets[client_secret.name] = ClientSecretDetailList(name=client_secret.name)
client_secrets[client_secret.name].ids.append(str(client_secret.id))
if not client_secret.client:
continue
client_secrets[client_secret.name].clients.append(ClientReference(id=str(client_secret.client.id), name=client_secret.client.name))
audit.audit_client_secret_list(session, request)
return list(client_secrets.values())
@router.get("/secrets/{name}")
async def get_secret_clients(
request: Request,
name: str,
session: Annotated[Session, Depends(get_db_session)],
) -> ClientSecretList:
"""Get a list of which clients has a named secret."""
clients: list[str] = []
for client_secret in session.exec(
select(ClientSecret).where(ClientSecret.name == name)
).all():
if not client_secret.client:
continue
clients.append(client_secret.client.name)
return ClientSecretList(name=name, clients=clients)
@router.get("/secrets/{name}/detailed")
async def get_secret_clients_detailed(
request: Request,
name: str,
session: Annotated[Session, Depends(get_db_session)],
) -> ClientSecretDetailList:
"""Get a list of which clients has a named secret."""
detail_list = ClientSecretDetailList(name=name)
for client_secret in session.exec(
select(ClientSecret).where(ClientSecret.name == name)
).all():
if not client_secret.client:
continue
detail_list.ids.append(str(client_secret.id))
detail_list.clients.append(ClientReference(id=str(client_secret.client.id), name=client_secret.client.name))
return detail_list
return router

View File

@ -1,436 +1,65 @@
"""FastAPI api.
TODO: We may want to allow a consumer to generate audit log entries manually.
"""
"""FastAPI api."""
# pyright: reportUnusedFunction=false
import logging
from collections.abc import Sequence
from contextlib import asynccontextmanager
from typing import Annotated
import bcrypt
from fastapi import (
APIRouter,
Depends,
FastAPI,
Header,
HTTPException,
Query,
Request,
status,
)
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from sqlalchemy import Engine
from sqlmodel import Session, select
from . import audit
from .db import get_engine
from .models import (
APIClient,
AuditLog,
Client,
ClientAccessPolicy,
ClientSecret,
init_db,
)
from .settings import get_settings
from .view_models import (
BodyValue,
ClientCreate,
ClientSecretPublic,
ClientSecretResponse,
ClientUpdate,
ClientView,
ClientPolicyView,
ClientPolicyUpdate,
)
settings = get_settings()
engine = get_engine(settings.db_file)
from .models import init_db
from .backend_api import get_backend_api
from .db import setup_database
from .settings import BackendSettings
from .types import DBSessionDep
LOG = logging.getLogger(__name__)
API_VERSION = "v1"
def init_backend_app(engine: Engine, get_db_session: DBSessionDep) -> FastAPI:
"""Initialize backend app."""
def verify_token(token: str, stored_hash: str) -> bool:
"""Verify token."""
token_bytes = token.encode("utf-8")
stored_bytes = stored_hash.encode("utf-8")
return bcrypt.checkpw(token_bytes, stored_bytes)
@asynccontextmanager
async def lifespan(_app: FastAPI):
"""Create database before starting the server."""
LOG.debug("Running lifespan")
init_db(engine)
yield
app = FastAPI(lifespan=lifespan)
app.include_router(get_backend_api(get_db_session))
@asynccontextmanager
async def lifespan(_app: FastAPI):
"""Create database before starting the server."""
init_db(engine)
yield
async def get_session():
"""Get the session."""
with Session(engine) as session:
yield session
async def validate_token(
x_api_token: Annotated[str, Header()],
session: Annotated[Session, Depends(get_session)],
) -> str:
"""Validate token."""
LOG.debug("Validating token %s", x_api_token)
statement = select(APIClient)
results = session.exec(statement)
valid = False
for result in results:
if verify_token(x_api_token, result.token):
valid = True
LOG.debug("Token is valid")
break
if not valid:
LOG.debug("Token is not valid.")
raise HTTPException(status_code=401, detail="unauthorized. invalid api token.")
return x_api_token
async def get_client_by_name(session: Session, name: str) -> Client | None:
"""Get client by name."""
client_filter = select(Client).where(Client.name == name)
client_results = session.exec(client_filter)
return client_results.first()
async def lookup_client_secret(
session: Session, client: Client, name: str
) -> ClientSecret | None:
"""Look up a secret for a client."""
statement = (
select(ClientSecret)
.where(ClientSecret.client_id == client.id)
.where(ClientSecret.name == name)
)
results = session.exec(statement)
return results.first()
LOG.info("Initializing app.")
backend_api = APIRouter(
prefix=f"/api/{API_VERSION}",
lifespan=lifespan,
dependencies=[Depends(validate_token)],
)
@backend_api.get("/clients/")
async def get_clients(
session: Annotated[Session, Depends(get_session)]
) -> list[ClientView]:
"""Get clients."""
statement = select(Client)
results = session.exec(statement)
clients = list(results)
return ClientView.from_client_list(clients)
@backend_api.get("/clients/{name}")
async def get_client(
request: Request, name: str, session: Annotated[Session, Depends(get_session)]
) -> ClientView:
"""Fetch a client."""
statement = select(Client).where(Client.name == name)
results = session.exec(statement)
client = results.first()
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
audit.audit_access_secrets(session, request, client)
return ClientView.from_client(client)
@backend_api.delete("/clients/{name}")
async def delete_client(
request: Request, name: str, session: Annotated[Session, Depends(get_session)]
) -> None:
"""Delete a client."""
statement = select(Client).where(Client.name == name)
results = session.exec(statement)
client = results.first()
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(
request: Request, exc: RequestValidationError
):
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
)
session.delete(client)
session.commit()
audit.audit_delete_client(session, request, client)
@backend_api.get("/clients/{name}/policies/")
async def get_client_policies(
name: str, session: Annotated[Session, Depends(get_session)]
) -> ClientPolicyView:
"""Get client policies."""
statement = select(Client).where(Client.name == name)
results = session.exec(statement)
client = results.first()
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
@app.get("/health")
async def get_health() -> JSONResponse:
"""Provide simple health check."""
return JSONResponse(
status_code=status.HTTP_200_OK, content=jsonable_encoder({"status": "LIVE"})
)
return ClientPolicyView.from_client(client)
return app
@backend_api.put("/clients/{name}/policies/")
async def update_client_policies(
request: Request,
name: str,
policy_update: ClientPolicyUpdate,
session: Annotated[Session, Depends(get_session)],
) -> ClientPolicyView:
"""Update client policies.
def create_backend_app(settings: BackendSettings) -> FastAPI:
"""Create the backend app."""
This is also how you delete policies.
"""
statement = select(Client).where(Client.name == name)
results = session.exec(statement)
client = results.first()
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
# Remove old policies.
policies = session.exec(
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
).all()
deleted_policies: list[ClientAccessPolicy] = []
added_policies: list[ClientAccessPolicy] = []
for policy in policies:
session.delete(policy)
deleted_policies.append(policy)
engine, get_db_session = setup_database(settings.db_url)
for source in policy_update.sources:
LOG.debug("Source %r", source)
policy = ClientAccessPolicy(source=str(source), client_id=client.id)
session.add(policy)
added_policies.append(policy)
session.commit()
session.refresh(client)
for policy in deleted_policies:
audit.audit_remove_policy(session, request, client, policy)
for policy in added_policies:
audit.audit_update_policy(session, request, client, policy)
return ClientPolicyView.from_client(client)
@backend_api.post("/clients/{name}/public-key")
async def update_client_public_key(
request: Request,
name: str,
client_update: ClientUpdate,
session: Annotated[Session, Depends(get_session)],
) -> ClientView:
"""Change the public key of a client.
This invalidates all secrets.
"""
statement = select(Client).where(Client.name == name)
results = session.exec(statement)
client = results.first()
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
client.public_key = client_update.public_key
for secret in session.exec(
select(ClientSecret).where(ClientSecret.client_id == client.id)
).all():
LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True
secret.client_id = None
secret.client = None
session.add(client)
session.refresh(client)
session.commit()
audit.audit_invalidate_secrets(session, request, client)
return ClientView.from_client(client)
@backend_api.post("/clients/")
async def create_client(
request: Request,
client: ClientCreate,
session: Annotated[Session, Depends(get_session)],
) -> ClientView:
"""Create client."""
existing = await get_client_by_name(session, client.name)
if existing:
raise HTTPException(400, detail="Error: Already a client with that name.")
db_client = client.to_client()
session.add(db_client)
session.commit()
session.refresh(db_client)
audit.audit_create_client(session, request, db_client)
return ClientView.from_client(db_client)
@backend_api.post("/clients/{name}/secrets/")
async def add_secret_to_client(
request: Request,
name: str,
client_secret: ClientSecretPublic,
session: Annotated[Session, Depends(get_session)],
) -> None:
"""Add secret to a client."""
client = await get_client_by_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
existing_secret = await lookup_client_secret(session, client, client_secret.name)
if existing_secret:
raise HTTPException(
status_code=400,
detail="Cannot add a secret. A different secret with the same name already exists.",
)
db_secret = ClientSecret(
name=client_secret.name, client_id=client.id, secret=client_secret.secret
)
session.add(db_secret)
session.commit()
session.refresh(db_secret)
audit.audit_create_secret(session, request, client, db_secret)
@backend_api.put("/clients/{name}/secrets/{secret_name}")
async def update_client_secret(
request: Request,
name: str,
secret_name: str,
secret_data: BodyValue,
session: Annotated[Session, Depends(get_session)],
) -> ClientSecretResponse:
"""Update a client secret.
This can also be used for destructive creates.
"""
client = await get_client_by_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
existing_secret = await lookup_client_secret(session, client, secret_name)
if existing_secret:
existing_secret.secret = secret_data.value
session.add(existing_secret)
session.commit()
session.refresh(existing_secret)
audit.audit_update_secret(session, request, client, existing_secret)
return ClientSecretResponse.from_client_secret(existing_secret)
db_secret = ClientSecret(
name=secret_name,
client_id=client.id,
secret=secret_data.value,
)
session.add(db_secret)
session.commit()
session.refresh(db_secret)
audit.audit_create_secret(session, request, client, db_secret)
return ClientSecretResponse.from_client_secret(db_secret)
@backend_api.get("/clients/{name}/secrets/{secret_name}")
async def request_client_secret(
request: Request,
name: str,
secret_name: str,
session: Annotated[Session, Depends(get_session)],
) -> ClientSecretResponse:
"""Get a client secret."""
client = await get_client_by_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
secret = await lookup_client_secret(session, client, secret_name)
if not secret:
raise HTTPException(
status_code=404, detail="Cannot find a secret with the given name."
)
response_model = ClientSecretResponse.from_client_secret(secret)
audit.audit_access_secret(session, request, client, secret)
return response_model
@backend_api.delete("/clients/{name}/secrets/{secret_name}")
async def delete_client_secret(
request: Request,
name: str,
secret_name: str,
session: Annotated[Session, Depends(get_session)],
) -> None:
"""Delete a secret."""
client = await get_client_by_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
secret = await lookup_client_secret(session, client, secret_name)
if not secret:
raise HTTPException(
status_code=404, detail="Cannot find a secret with the given name."
)
session.delete(secret)
session.commit()
audit.audit_delete_secret(session, request, client, secret)
@backend_api.get("/audit/", response_model=list[AuditLog])
async def get_audit_logs(
request: Request,
session: Annotated[Session, Depends(get_session)],
offset: Annotated[int, Query()] = 0,
limit: Annotated[int, Query(le=100)] = 100,
filter_client: Annotated[str | None, Query()] = None,
) -> Sequence[AuditLog]:
"""Get audit logs."""
audit.audit_access_audit_log(session, request)
statement = select(AuditLog).offset(offset).limit(limit)
if filter_client:
statement = statement.where(AuditLog.client_name == filter_client)
results = session.exec(statement).all()
return results
app = FastAPI()
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
)
app.include_router(backend_api)
return init_backend_app(engine, get_db_session)

View File

@ -3,6 +3,7 @@
from collections.abc import Sequence
from fastapi import Request
from sqlmodel import Session, select
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy
@ -21,6 +22,7 @@ def _write_audit_log(
"""Write the audit log."""
origin = _get_origin(request)
entry.origin = origin
entry.subsystem = "backend"
session.add(entry)
if commit:
session.commit()
@ -109,6 +111,23 @@ def audit_update_policy(
_write_audit_log(session, request, entry, commit)
def audit_update_client(
session: Session,
request: Request,
client: Client,
commit: bool = True,
) -> None:
"""Audit an update secret event."""
entry = AuditLog(
operation="UPDATE",
object="Client",
client_id=client.id,
client_name=client.name,
message="Client updated",
)
_write_audit_log(session, request, entry, commit)
def audit_update_secret(
session: Session,
request: Request,
@ -219,3 +238,15 @@ def audit_access_audit_log(
object="AuditLog",
)
_write_audit_log(session, request, entry, commit)
def audit_client_secret_list(
session: Session, request: Request, commit: bool = True
) -> None:
"""Audit a list of all secrets."""
entry = AuditLog(
operation="ACCESS",
message="All secret names and their clients was viewed",
)
_write_audit_log(session, request, entry, commit)

View File

@ -0,0 +1,67 @@
"""Backend API."""
import logging
from typing import Annotated
import bcrypt
from fastapi import APIRouter, Depends, Header, HTTPException
from sqlmodel import Session, select
from .api import get_audit_api, get_clients_api, get_policy_api, get_secrets_api
from .models import (
APIClient,
)
from .types import DBSessionDep
LOG = logging.getLogger(__name__)
API_VERSION = "v1"
def verify_token(token: str, stored_hash: str) -> bool:
"""Verify token."""
token_bytes = token.encode("utf-8")
stored_bytes = stored_hash.encode("utf-8")
return bcrypt.checkpw(token_bytes, stored_bytes)
def get_backend_api(
get_db_session: DBSessionDep,
) -> APIRouter:
"""Construct backend API."""
async def validate_token(
x_api_token: Annotated[str, Header()],
session: Annotated[Session, Depends(get_db_session)],
) -> str:
"""Validate token."""
LOG.debug("Validating token %s", x_api_token)
statement = select(APIClient)
results = session.exec(statement)
valid = False
for result in results:
if verify_token(x_api_token, result.token):
valid = True
LOG.debug("Token is valid")
break
if not valid:
LOG.debug("Token is not valid.")
raise HTTPException(
status_code=401, detail="unauthorized. invalid api token."
)
return x_api_token
LOG.info("Initializing app.")
backend_api = APIRouter(
prefix=f"/api/{API_VERSION}",
dependencies=[Depends(validate_token)],
)
backend_api.include_router(get_audit_api(get_db_session))
backend_api.include_router(get_clients_api(get_db_session))
backend_api.include_router(get_policy_api(get_db_session))
backend_api.include_router(get_secrets_api(get_db_session))
return backend_api

View File

@ -1,11 +1,18 @@
"""CLI and main entry point."""
import code
import os
from pathlib import Path
from typing import cast
from dotenv import load_dotenv
import click
from sqlmodel import Session, create_engine, select
import uvicorn
from .db import generate_api_token
from .db import create_api_token
from .models import Client, ClientSecret, ClientAccessPolicy, AuditLog, APIClient
from .settings import BackendSettings
DEFAULT_LISTEN = "127.0.0.1"
DEFAULT_PORT = 8022
@ -14,18 +21,59 @@ WORKDIR = Path(os.getcwd())
load_dotenv()
@click.group()
@click.option("--database", help="Path to the sqlite database file.")
def cli(database: str) -> None:
@click.pass_context
def cli(ctx: click.Context, database: str) -> None:
"""CLI group."""
if database:
# Hopefully it's enough to set the environment variable as so.
os.environ["SSHECRET_DB_FILE"] = str(Path(database).absolute())
settings = BackendSettings(db_url=f"sqlite:///{Path(database).absolute()}")
else:
settings = BackendSettings()
ctx.obj = settings
@cli.command("generate-token")
def cli_generate_token() -> None:
@click.pass_context
def cli_generate_token(ctx: click.Context) -> None:
"""Generate a token."""
token = generate_api_token()
settings = cast(BackendSettings, ctx.obj)
engine = create_engine(settings.db_url)
with Session(engine) as session:
token = create_api_token(session, True)
click.echo("Generated api token:")
click.echo(token)
@cli.command("run")
@click.option("--host", default="127.0.0.1")
@click.option("--port", default=8022, type=click.INT)
@click.option("--dev", is_flag=True)
@click.option("--workers", type=click.INT)
def cli_run(host: str, port: int, dev: bool, workers: int | None) -> None:
"""Run the server."""
uvicorn.run("sshecret_backend.main:app", host=host, port=port, reload=dev, workers=workers)
@cli.command("repl")
@click.pass_context
def cli_repl(ctx: click.Context) -> None:
"""Run an interactive console."""
settings = cast(BackendSettings, ctx.obj)
engine = create_engine(settings.db_url)
with Session(engine) as session:
locals = {
"session": session,
"select": select,
"Client": Client,
"ClientSecret": ClientSecret,
"ClientAccessPolicy": ClientAccessPolicy,
"APIClient": APIClient,
"AuditLog": AuditLog,
}
console = code.InteractiveConsole(locals=locals, local_exit=True)
banner = "Sshecret-backend REPL.\nUse 'session' to interact with the database."
console.interact(banner=banner, exitmsg="Bye!")

View File

@ -2,6 +2,7 @@
import logging
import secrets
from collections.abc import Generator, Callable
from pathlib import Path
from sqlalchemy import Engine
from sqlmodel import Session, create_engine, text
@ -9,14 +10,30 @@ import bcrypt
from sqlalchemy.engine import URL
from .models import APIClient, init_db
from .settings import get_settings
from .models import APIClient
LOG = logging.getLogger(__name__)
def setup_database(
db_url: URL | str,
) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]:
"""Setup database."""
engine = create_engine(db_url, echo=False)
with engine.connect() as connection:
connection.execute(text("PRAGMA foreign_keys=ON")) # for SQLite only
def get_db_session() -> Generator[Session, None, None]:
"""Get DB Session."""
with Session(engine) as session:
yield session
return engine, get_db_session
def get_engine(filename: Path, echo: bool = False) -> Engine:
"""Initialize the engine."""
url = URL.create(drivername="sqlite", database=str(filename.absolute()))
@ -27,20 +44,6 @@ def get_engine(filename: Path, echo: bool = False) -> Engine:
return engine
def create_db_and_tables(filename: Path, echo: bool = False) -> bool:
"""Create database and tables.
Returns True if the database was created.
"""
created = False
if not filename.exists():
created = True
engine = get_engine(filename, echo)
init_db(engine)
return created
def create_api_token(session: Session, read_write: bool) -> str:
"""Create API token."""
token = secrets.token_urlsafe(32)
@ -54,14 +57,3 @@ def create_api_token(session: Session, read_write: bool) -> str:
session.commit()
return token
def generate_api_token() -> str:
"""Generate API token."""
settings = get_settings()
engine = get_engine(settings.db_file)
init_db(engine)
with Session(engine) as session:
token = create_api_token(session, True)
return token

View File

@ -0,0 +1,7 @@
"""Main script entrypoint."""
from .settings import BackendSettings
from .app import create_backend_app
app = create_backend_app(BackendSettings())

View File

@ -7,17 +7,21 @@ This might require some changes to these schemas.
"""
import logging
import uuid
from datetime import datetime
import sqlalchemy as sa
from sqlmodel import Field, Relationship, SQLModel
LOG = logging.getLogger(__name__)
class Client(SQLModel, table=True):
"""Client model."""
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
name: str = Field(unique=True)
description: str | None = None
public_key: str
created_at: datetime | None = Field(
@ -61,11 +65,13 @@ class ClientAccessPolicy(SQLModel, table=True):
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
)
class ClientSecret(SQLModel, table=True):
"""A client secret."""
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
name: str
description: str | None = None
client_id: uuid.UUID | None = Field(foreign_key="client.id", ondelete="CASCADE")
client: Client | None = Relationship(back_populates="secrets")
secret: str
@ -92,6 +98,7 @@ class AuditLog(SQLModel, table=True):
"""
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
subsystem: str | None = None
object: str | None = None
object_id: str | None = None
operation: str
@ -107,6 +114,7 @@ class AuditLog(SQLModel, table=True):
nullable=False,
)
class APIClient(SQLModel, table=True):
"""Stores API Keys."""
@ -120,6 +128,8 @@ class APIClient(SQLModel, table=True):
nullable=False,
)
def init_db(engine: sa.Engine) -> None:
"""Create database."""
LOG.info("Starting init_db")
SQLModel.metadata.create_all(engine)

View File

@ -1,15 +1,13 @@
"""Settings management."""
from typing import override
from pathlib import Path
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import (
BaseSettings,
SettingsConfigDict,
)
DEFAULT_DATABASE = "sshecret.db"
DEFAULT_DATABASE = "sqlite:///sshecret.db"
class BackendSettings(BaseSettings):
@ -17,7 +15,7 @@ class BackendSettings(BaseSettings):
model_config = SettingsConfigDict(env_file=".backend.env", env_prefix="sshecret_")
db_file: Path = Field(default=Path(DEFAULT_DATABASE).absolute())
db_url: str = Field(default=DEFAULT_DATABASE)
def get_settings() -> BackendSettings:

View File

@ -1,13 +1,17 @@
"""Test helpers."""
import logging
from sqlmodel import Session
from .db import get_engine, create_api_token
from sshecret_backend.settings import BackendSettings
from .models import init_db
from .settings import get_settings
from .db import create_api_token, setup_database
def create_test_token(session: Session) -> str:
LOG = logging.getLogger(__name__)
def create_test_token(settings: BackendSettings) -> str:
"""Create test token."""
settings = get_settings()
engine = get_engine(settings.db_file)
init_db(engine)
return create_api_token(session, True)
engine, _setupdb = setup_database(settings.db_url)
with Session(engine) as session:
init_db(engine)
return create_api_token(session, True)

View File

@ -0,0 +1,8 @@
"""Common type definitions."""
from collections.abc import Callable, Generator
from sqlmodel import Session
DBSessionDep = Callable[[], Generator[Session, None, None]]

View File

@ -1,24 +1,27 @@
"""Models for API views."""
import ipaddress
import uuid
from datetime import datetime
from typing import Annotated, Any, Self, override
from typing import Annotated, Self, override
from sqlmodel import Field, SQLModel
from pydantic import IPvAnyAddress, IPvAnyNetwork
from . import models
from pydantic import AfterValidator, IPvAnyAddress, IPvAnyNetwork
from sshecret.crypto import public_key_validator
from . import models
class ClientView(SQLModel):
"""View for a single client."""
id: uuid.UUID
name: str
description: str | None = None
public_key: str
policies: list[str] = ["0.0.0.0/0", "::/0"]
secrets: list[str] = Field(default_factory=list)
created_at: datetime
created_at: datetime | None
updated_at: datetime | None = None
@classmethod
@ -33,6 +36,7 @@ class ClientView(SQLModel):
view = cls(
id=client.id,
name=client.name,
description=client.description,
public_key=client.public_key,
created_at=client.created_at,
updated_at=client.updated_at or None,
@ -46,24 +50,34 @@ class ClientView(SQLModel):
return view
class ClientQueryResult(SQLModel):
"""Result class for queries towards the client list."""
clients: list[ClientView] = Field(default_factory=list)
total_results: int
remaining_results: int
class ClientCreate(SQLModel):
"""Model to create a client."""
name: str
public_key: str
description: str | None = None
public_key: Annotated[str, AfterValidator(public_key_validator)]
def to_client(self) -> models.Client:
"""Instantiate a client."""
public_key = self.public_key
return models.Client(
name=self.name, public_key=public_key
name=self.name,
public_key=self.public_key,
description=self.description,
)
class ClientUpdate(SQLModel):
"""Model to update the client public key."""
public_key: str
public_key: Annotated[str, AfterValidator(public_key_validator)]
class BodyValue(SQLModel):
@ -77,6 +91,7 @@ class ClientSecretPublic(SQLModel):
name: str
secret: str
description: str | None = None
@classmethod
def from_client_secret(cls, client_secret: models.ClientSecret) -> Self:
@ -84,13 +99,14 @@ class ClientSecretPublic(SQLModel):
return cls(
name=client_secret.name,
secret=client_secret.secret,
description=client_secret.description,
)
class ClientSecretResponse(ClientSecretPublic):
"""A secret view."""
created_at: datetime
created_at: datetime | None
updated_at: datetime | None = None
@override
@ -123,3 +139,31 @@ class ClientPolicyUpdate(SQLModel):
"""Model for updating policies."""
sources: list[IPvAnyAddress | IPvAnyNetwork]
class ClientSecretList(SQLModel):
"""Model for aggregating identically named secrets."""
name: str
clients: list[str]
class ClientReference(SQLModel):
"""Reference to a client."""
id: str
name: str
class ClientSecretDetailList(SQLModel):
"""A more detailed version of the ClientSecretList."""
name: str
ids: list[str] = Field(default_factory=list)
clients: list[ClientReference] = Field(default_factory=list)
class AuditInfo(SQLModel):
"""Information about audit information."""
entries: int