Refactor database layer and auditing

This commit is contained in:
2025-05-10 08:38:57 +02:00
parent d866553ac1
commit 9ccd2f1d4d
20 changed files with 718 additions and 469 deletions

View File

@ -3,22 +3,22 @@ from logging.config import fileConfig
from sqlalchemy import engine_from_config from sqlalchemy import engine_from_config
from sqlalchemy import pool from sqlalchemy import pool
from sqlmodel import create_engine
from alembic import context from alembic import context
from sshecret_backend.models import * from sshecret_backend.models import Base
def get_database_url() -> str:
"""Get database URL."""
if db_file := os.getenv("SSHECRET_BACKEND_DB"):
return f"sqlite:///{db_file}"
return "sqlite:///sshecret.db"
# this is the Alembic Config object, which provides # this is the Alembic Config object, which provides
# access to the values within the .ini file in use. # access to the values within the .ini file in use.
config = context.config config = context.config
def get_database_url() -> str | None:
"""Get database URL."""
if db_file := os.getenv("SSHECRET_BACKEND_DB"):
return f"sqlite:///{db_file}"
return config.get_main_option("sqlalchemy.url")
# Interpret the config file for Python logging. # Interpret the config file for Python logging.
# This line sets up loggers basically. # This line sets up loggers basically.
if config.config_file_name is not None: if config.config_file_name is not None:
@ -28,8 +28,7 @@ if config.config_file_name is not None:
# for 'autogenerate' support # for 'autogenerate' support
# from myapp import mymodel # from myapp import mymodel
# target_metadata = mymodel.Base.metadata # target_metadata = mymodel.Base.metadata
#target_metadata = None target_metadata = Base.metadata
target_metadata = SQLModel.metadata
# other values from the config, defined by the needs of env.py, # other values from the config, defined by the needs of env.py,
# can be acquired: # can be acquired:
@ -68,7 +67,11 @@ def run_migrations_online() -> None:
and associate a connection with the context. and associate a connection with the context.
""" """
connectable = create_engine(get_database_url()) connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection: with connectable.connect() as connection:
context.configure( context.configure(

View File

@ -9,7 +9,6 @@ from typing import Sequence, Union
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
import sqlmodel
${imports if imports else ""} ${imports if imports else ""}
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.

View File

@ -5,13 +5,14 @@
import logging import logging
from collections.abc import Sequence from collections.abc import Sequence
from fastapi import APIRouter, Depends, Request, Query from fastapi import APIRouter, Depends, Request, Query
from sqlmodel import Session, col, func, select from pydantic import TypeAdapter
from sqlalchemy import desc from sqlalchemy import select, func
from sqlalchemy.orm import Session
from typing import Annotated from typing import Annotated
from sshecret_backend.models import AuditLog from sshecret_backend.models import AuditLog
from sshecret_backend.types import DBSessionDep from sshecret_backend.types import DBSessionDep
from sshecret_backend.view_models import AuditInfo from sshecret_backend.view_models import AuditInfo, AuditView
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -21,7 +22,7 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
"""Construct audit sub-api.""" """Construct audit sub-api."""
router = APIRouter() router = APIRouter()
@router.get("/audit/", response_model=list[AuditLog]) @router.get("/audit/", response_model=list[AuditView])
async def get_audit_logs( async def get_audit_logs(
request: Request, request: Request,
session: Annotated[Session, Depends(get_db_session)], session: Annotated[Session, Depends(get_db_session)],
@ -29,35 +30,37 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
limit: Annotated[int, Query(le=100)] = 100, limit: Annotated[int, Query(le=100)] = 100,
filter_client: Annotated[str | None, Query()] = None, filter_client: Annotated[str | None, Query()] = None,
filter_subsystem: Annotated[str | None, Query()] = None, filter_subsystem: Annotated[str | None, Query()] = None,
) -> Sequence[AuditLog]: ) -> Sequence[AuditView]:
"""Get audit logs.""" """Get audit logs."""
#audit.audit_access_audit_log(session, request) #audit.audit_access_audit_log(session, request)
statement = select(AuditLog).offset(offset).limit(limit).order_by(desc(col(AuditLog.timestamp))) statement = select(AuditLog).offset(offset).limit(limit).order_by(AuditLog.timestamp.desc())
if filter_client: if filter_client:
statement = statement.where(AuditLog.client_name == filter_client) statement = statement.where(AuditLog.client_name == filter_client)
if filter_subsystem: if filter_subsystem:
statement = statement.where(AuditLog.subsystem == filter_subsystem) statement = statement.where(AuditLog.subsystem == filter_subsystem)
results = session.exec(statement).all() LogAdapt = TypeAdapter(list[AuditView])
return results results = session.scalars(statement).all()
return LogAdapt.validate_python(results, from_attributes=True)
@router.post("/audit/") @router.post("/audit/")
async def add_audit_log( async def add_audit_log(
request: Request, request: Request,
session: Annotated[Session, Depends(get_db_session)], session: Annotated[Session, Depends(get_db_session)],
entry: AuditLog, entry: AuditView,
) -> AuditLog: ) -> AuditView:
"""Add entry to audit log.""" """Add entry to audit log."""
audit_log = AuditLog.model_validate(entry.model_dump(exclude_none=True)) audit_log = AuditLog(**entry.model_dump(exclude_none=True))
session.add(audit_log) session.add(audit_log)
session.commit() session.commit()
return audit_log return AuditView.model_validate(audit_log, from_attributes=True)
@router.get("/audit/info") @router.get("/audit/info")
async def get_audit_info(request: Request, session: Annotated[Session, Depends(get_db_session)]) -> AuditInfo: async def get_audit_info(request: Request, session: Annotated[Session, Depends(get_db_session)]) -> AuditInfo:
"""Get audit info.""" """Get audit info."""
audit_count = session.exec(select(func.count('*')).select_from(AuditLog)).one() audit_count = session.scalars(select(func.count('*')).select_from(AuditLog)).one()
return AuditInfo(entries=audit_count) return AuditInfo(entries=audit_count)

View File

@ -6,11 +6,11 @@ import uuid
import logging import logging
from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi import APIRouter, Depends, HTTPException, Query, Request
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
from sqlmodel import Session, col, select from typing import Annotated, Any, Self, TypeVar, cast
from sqlalchemy import func
from typing import Annotated, Self, TypeVar
from sqlmodel.sql.expression import SelectOfScalar from sqlalchemy import select, func
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select
from sshecret_backend.types import DBSessionDep from sshecret_backend.types import DBSessionDep
from sshecret_backend.models import Client, ClientSecret from sshecret_backend.models import Client, ClientSecret
from sshecret_backend.view_models import ( from sshecret_backend.view_models import (
@ -55,8 +55,8 @@ T = TypeVar("T")
def filter_client_statement( def filter_client_statement(
statement: SelectOfScalar[T], params: ClientListParams, ignore_limits: bool = False statement: Select[Any], params: ClientListParams, ignore_limits: bool = False
) -> SelectOfScalar[T]: ) -> Select[Any]:
"""Filter a statement with the provided params.""" """Filter a statement with the provided params."""
if params.id: if params.id:
statement = statement.where(Client.id == params.id) statement = statement.where(Client.id == params.id)
@ -64,9 +64,9 @@ def filter_client_statement(
if params.name: if params.name:
statement = statement.where(Client.name == params.name) statement = statement.where(Client.name == params.name)
elif params.name__like: elif params.name__like:
statement = statement.where(col(Client.name).like(params.name__like)) statement = statement.where(Client.name.like(params.name__like))
elif params.name__contains: elif params.name__contains:
statement = statement.where(col(Client.name).contains(params.name__contains)) statement = statement.where(Client.name.contains(params.name__contains))
if ignore_limits: if ignore_limits:
return statement return statement
@ -86,13 +86,13 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
"""Get clients.""" """Get clients."""
# Get total results first # Get total results first
count_statement = select(func.count("*")).select_from(Client) count_statement = select(func.count("*")).select_from(Client)
count_statement = filter_client_statement(count_statement, filter_query, True) count_statement = cast(Select[tuple[int]], filter_client_statement(count_statement, filter_query, True))
total_results = session.exec(count_statement).one() total_results = session.scalars(count_statement).one()
statement = filter_client_statement(select(Client), filter_query, False) statement = filter_client_statement(select(Client), filter_query, False)
results = session.exec(statement) results = session.scalars(statement)
remainder = total_results - filter_query.offset - filter_query.limit remainder = total_results - filter_query.offset - filter_query.limit
if remainder < 0: if remainder < 0:
remainder = 0 remainder = 0
@ -170,13 +170,12 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
status_code=404, detail="Cannot find a client with the given name." status_code=404, detail="Cannot find a client with the given name."
) )
client.public_key = client_update.public_key client.public_key = client_update.public_key
for secret in session.exec( for secret in session.scalars(
select(ClientSecret).where(ClientSecret.client_id == client.id) select(ClientSecret).where(ClientSecret.client_id == client.id)
).all(): ).all():
LOG.debug("Invalidated secret %s", secret.id) LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True secret.invalidated = True
secret.client_id = None secret.client_id = None
secret.client = None
session.add(client) session.add(client)
session.refresh(client) session.refresh(client)
@ -206,13 +205,12 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
public_key_updated = False public_key_updated = False
if client_update.public_key != client.public_key: if client_update.public_key != client.public_key:
public_key_updated = True public_key_updated = True
for secret in session.exec( for secret in session.scalars(
select(ClientSecret).where(ClientSecret.client_id == client.id) select(ClientSecret).where(ClientSecret.client_id == client.id)
).all(): ).all():
LOG.debug("Invalidated secret %s", secret.id) LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True secret.invalidated = True
secret.client_id = None secret.client_id = None
secret.client = None
session.add(client) session.add(client)
session.commit() session.commit()

View File

@ -4,7 +4,8 @@ import re
import uuid import uuid
import bcrypt import bcrypt
from sqlmodel import Session, select from sqlalchemy import select
from sqlalchemy.orm import Session
from sshecret_backend.models import Client from sshecret_backend.models import Client
@ -20,13 +21,13 @@ def verify_token(token: str, stored_hash: str) -> bool:
async def get_client_by_name(session: Session, name: str) -> Client | None: async def get_client_by_name(session: Session, name: str) -> Client | None:
"""Get client by name.""" """Get client by name."""
client_filter = select(Client).where(Client.name == name) client_filter = select(Client).where(Client.name == name)
client_results = session.exec(client_filter) client_results = session.scalars(client_filter)
return client_results.first() return client_results.first()
async def get_client_by_id(session: Session, id: uuid.UUID) -> Client | None: async def get_client_by_id(session: Session, id: uuid.UUID) -> Client | None:
"""Get client by name.""" """Get client by name."""
client_filter = select(Client).where(Client.id == id) client_filter = select(Client).where(Client.id == id)
client_results = session.exec(client_filter) client_results = session.scalars(client_filter)
return client_results.first() return client_results.first()
async def get_client_by_id_or_name(session: Session, id_or_name: str) -> Client | None: async def get_client_by_id_or_name(session: Session, id_or_name: str) -> Client | None:

View File

@ -4,7 +4,8 @@
import logging import logging
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request
from sqlmodel import Session, select from sqlalchemy import select
from sqlalchemy.orm import Session
from typing import Annotated from typing import Annotated
from sshecret_backend.models import ClientAccessPolicy from sshecret_backend.models import ClientAccessPolicy
@ -54,7 +55,7 @@ def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
status_code=404, detail="Cannot find a client with the given name." status_code=404, detail="Cannot find a client with the given name."
) )
# Remove old policies. # Remove old policies.
policies = session.exec( policies = session.scalars(
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id) select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
).all() ).all()
deleted_policies: list[ClientAccessPolicy] = [] deleted_policies: list[ClientAccessPolicy] = []

View File

@ -5,7 +5,8 @@
import logging import logging
from collections import defaultdict from collections import defaultdict
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request
from sqlmodel import Session, select from sqlalchemy import select
from sqlalchemy.orm import Session
from typing import Annotated from typing import Annotated
from sshecret_backend.models import Client, ClientSecret from sshecret_backend.models import Client, ClientSecret
@ -34,7 +35,7 @@ async def lookup_client_secret(
.where(ClientSecret.client_id == client.id) .where(ClientSecret.client_id == client.id)
.where(ClientSecret.name == name) .where(ClientSecret.name == name)
) )
results = session.exec(statement) results = session.scalars(statement)
return results.first() return results.first()
@ -165,7 +166,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
) -> list[ClientSecretList]: ) -> list[ClientSecretList]:
"""Get a list of all secrets and which clients have them.""" """Get a list of all secrets and which clients have them."""
client_secret_map: defaultdict[str, list[str]] = defaultdict(list) client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
for client_secret in session.exec(select(ClientSecret)).all(): for client_secret in session.scalars(select(ClientSecret)).all():
if not client_secret.client: if not client_secret.client:
if client_secret.name not in client_secret_map: if client_secret.name not in client_secret_map:
client_secret_map[client_secret.name] = [] client_secret_map[client_secret.name] = []
@ -182,7 +183,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
) -> list[ClientSecretDetailList]: ) -> list[ClientSecretDetailList]:
"""Get a list of all secrets and which clients have them.""" """Get a list of all secrets and which clients have them."""
client_secrets: dict[str, ClientSecretDetailList] = {} client_secrets: dict[str, ClientSecretDetailList] = {}
for client_secret in session.exec(select(ClientSecret)).all(): for client_secret in session.scalars(select(ClientSecret)).all():
if client_secret.name not in client_secrets: if client_secret.name not in client_secrets:
client_secrets[client_secret.name] = ClientSecretDetailList(name=client_secret.name) client_secrets[client_secret.name] = ClientSecretDetailList(name=client_secret.name)
@ -202,7 +203,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
) -> ClientSecretList: ) -> ClientSecretList:
"""Get a list of which clients has a named secret.""" """Get a list of which clients has a named secret."""
clients: list[str] = [] clients: list[str] = []
for client_secret in session.exec( for client_secret in session.scalars(
select(ClientSecret).where(ClientSecret.name == name) select(ClientSecret).where(ClientSecret.name == name)
).all(): ).all():
if not client_secret.client: if not client_secret.client:
@ -219,7 +220,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
) -> ClientSecretDetailList: ) -> ClientSecretDetailList:
"""Get a list of which clients has a named secret.""" """Get a list of which clients has a named secret."""
detail_list = ClientSecretDetailList(name=name) detail_list = ClientSecretDetailList(name=name)
for client_secret in session.exec( for client_secret in session.scalars(
select(ClientSecret).where(ClientSecret.name == name) select(ClientSecret).where(ClientSecret.name == name)
).all(): ).all():
if not client_secret.client: if not client_secret.client:

View File

@ -2,9 +2,10 @@
from collections.abc import Sequence from collections.abc import Sequence
from fastapi import Request from fastapi import Request
from sqlmodel import Session, select from sqlalchemy import select
from sqlalchemy.orm import Session
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy, Operation, SubSystem
def _get_origin(request: Request) -> str | None: def _get_origin(request: Request) -> str | None:
@ -22,7 +23,7 @@ def _write_audit_log(
"""Write the audit log.""" """Write the audit log."""
origin = _get_origin(request) origin = _get_origin(request)
entry.origin = origin entry.origin = origin
entry.subsystem = "backend" entry.subsystem = SubSystem.BACKEND
session.add(entry) session.add(entry)
if commit: if commit:
session.commit() session.commit()
@ -33,7 +34,7 @@ def audit_create_client(
) -> None: ) -> None:
"""Log the creation of a client.""" """Log the creation of a client."""
entry = AuditLog( entry = AuditLog(
operation="CREATE", operation=Operation.CREATE,
client_id=client.id, client_id=client.id,
client_name=client.name, client_name=client.name,
message="Client Created", message="Client Created",
@ -46,7 +47,7 @@ def audit_delete_client(
) -> None: ) -> None:
"""Log the creation of a client.""" """Log the creation of a client."""
entry = AuditLog( entry = AuditLog(
operation="CREATE", operation=Operation.CREATE,
client_id=client.id, client_id=client.id,
client_name=client.name, client_name=client.name,
message="Client deleted", message="Client deleted",
@ -63,9 +64,9 @@ def audit_create_secret(
) -> None: ) -> None:
"""Audit a create secret event.""" """Audit a create secret event."""
entry = AuditLog( entry = AuditLog(
operation="CREATE", operation=Operation.CREATE,
object="ClientSecret", secret_id=secret.id,
object_id=str(secret.id), secret_name=secret.name,
client_id=client.id, client_id=client.id,
client_name=client.name, client_name=client.name,
message="Added secret to client", message="Added secret to client",
@ -81,13 +82,13 @@ def audit_remove_policy(
commit: bool = True, commit: bool = True,
) -> None: ) -> None:
"""Audit removal of policy.""" """Audit removal of policy."""
data = {"object": "ClientAccessPolicy", "object_id": str(policy.id)}
entry = AuditLog( entry = AuditLog(
operation="DELETE", operation=Operation.DELETE,
object="ClientAccessPolicy",
object_id=str(policy.id),
client_id=client.id, client_id=client.id,
client_name=client.name, client_name=client.name,
message="Deleted client policy", message="Deleted client policy",
data=data,
) )
_write_audit_log(session, request, entry, commit) _write_audit_log(session, request, entry, commit)
@ -100,13 +101,13 @@ def audit_update_policy(
commit: bool = True, commit: bool = True,
) -> None: ) -> None:
"""Audit update of policy.""" """Audit update of policy."""
data: dict[str, str] = {"object": "ClientAccessPolicy", "object_id": str(policy.id)}
entry = AuditLog( entry = AuditLog(
operation="CREATE", operation=Operation.CREATE,
object="ClientAccessPolicy",
object_id=str(policy.id),
client_id=client.id,
client_name=client.name, client_name=client.name,
client_id=client.id,
message="Updated client policy", message="Updated client policy",
data=data,
) )
_write_audit_log(session, request, entry, commit) _write_audit_log(session, request, entry, commit)
@ -119,11 +120,10 @@ def audit_update_client(
) -> None: ) -> None:
"""Audit an update secret event.""" """Audit an update secret event."""
entry = AuditLog( entry = AuditLog(
operation="UPDATE", operation=Operation.UPDATE,
object="Client",
client_id=client.id, client_id=client.id,
client_name=client.name, client_name=client.name,
message="Client updated", message="Client data updated",
) )
_write_audit_log(session, request, entry, commit) _write_audit_log(session, request, entry, commit)
@ -137,11 +137,11 @@ def audit_update_secret(
) -> None: ) -> None:
"""Audit an update secret event.""" """Audit an update secret event."""
entry = AuditLog( entry = AuditLog(
operation="UPDATE", operation=Operation.UPDATE,
object="ClientSecret",
object_id=str(secret.id),
client_id=client.id, client_id=client.id,
client_name=client.name, client_name=client.name,
secret_name=secret.name,
secret_id=secret.id,
message="Secret value updated", message="Secret value updated",
) )
_write_audit_log(session, request, entry, commit) _write_audit_log(session, request, entry, commit)
@ -155,8 +155,7 @@ def audit_invalidate_secrets(
) -> None: ) -> None:
"""Audit Invalidate client secrets.""" """Audit Invalidate client secrets."""
entry = AuditLog( entry = AuditLog(
operation="INVALIDATE", operation=Operation.UPDATE,
object="ClientSecret",
client_name=client.name, client_name=client.name,
client_id=client.id, client_id=client.id,
message="Client public-key changed. All secrets invalidated.", message="Client public-key changed. All secrets invalidated.",
@ -173,9 +172,9 @@ def audit_delete_secret(
) -> None: ) -> None:
"""Audit Delete client secrets.""" """Audit Delete client secrets."""
entry = AuditLog( entry = AuditLog(
operation="DELETE", operation=Operation.DELETE,
object="ClientSecret", secret_name=secret.name,
object_id=str(secret.id), secret_id=secret.id,
client_name=client.name, client_name=client.name,
client_id=client.id, client_id=client.id,
message="Deleted secret.", message="Deleted secret.",
@ -195,7 +194,7 @@ def audit_access_secrets(
With no secrets provided, all secrets of the client will be resolved. With no secrets provided, all secrets of the client will be resolved.
""" """
if not secrets: if not secrets:
secrets = session.exec( secrets = session.scalars(
select(ClientSecret).where(ClientSecret.client_id == client.id) select(ClientSecret).where(ClientSecret.client_id == client.id)
).all() ).all()
@ -215,37 +214,21 @@ def audit_access_secret(
) -> None: ) -> None:
"""Audit that someone accessed one secrets.""" """Audit that someone accessed one secrets."""
entry = AuditLog( entry = AuditLog(
operation="ACCESS", operation=Operation.READ,
message="Secret was viewed", message="Secret was viewed",
object="ClientSecret", secret_name=secret.name,
object_id=str(secret.id), secret_id=secret.id,
client_id=client.id, client_id=client.id,
client_name=client.name, client_name=client.name,
) )
_write_audit_log(session, request, entry, commit) _write_audit_log(session, request, entry, commit)
def audit_access_audit_log(
session: Session, request: Request, commit: bool = True
) -> None:
"""Audit access to the audit log.
Because why not...
"""
entry = AuditLog(
operation="ACCESS",
message="Audit log was viewed",
object="AuditLog",
)
_write_audit_log(session, request, entry, commit)
def audit_client_secret_list( def audit_client_secret_list(
session: Session, request: Request, commit: bool = True session: Session, request: Request, commit: bool = True
) -> None: ) -> None:
"""Audit a list of all secrets.""" """Audit a list of all secrets."""
entry = AuditLog( entry = AuditLog(
operation="ACCESS", operation=Operation.READ,
message="All secret names and their clients was viewed", message="All secret names and their clients was viewed",
) )
_write_audit_log(session, request, entry, commit) _write_audit_log(session, request, entry, commit)

View File

@ -3,11 +3,11 @@
import logging import logging
from typing import Annotated from typing import Annotated
import bcrypt
from fastapi import APIRouter, Depends, Header, HTTPException from fastapi import APIRouter, Depends, Header, HTTPException
from sqlmodel import Session, select from sqlalchemy import select
from sqlalchemy.orm import Session
from .api import get_audit_api, get_clients_api, get_policy_api, get_secrets_api from .api import get_audit_api, get_clients_api, get_policy_api, get_secrets_api
from .auth import verify_token
from .models import ( from .models import (
APIClient, APIClient,
) )
@ -18,13 +18,6 @@ LOG = logging.getLogger(__name__)
API_VERSION = "v1" API_VERSION = "v1"
def verify_token(token: str, stored_hash: str) -> bool:
"""Verify token."""
token_bytes = token.encode("utf-8")
stored_bytes = stored_hash.encode("utf-8")
return bcrypt.checkpw(token_bytes, stored_bytes)
def get_backend_api( def get_backend_api(
get_db_session: DBSessionDep, get_db_session: DBSessionDep,
) -> APIRouter: ) -> APIRouter:
@ -37,7 +30,7 @@ def get_backend_api(
"""Validate token.""" """Validate token."""
LOG.debug("Validating token %s", x_api_token) LOG.debug("Validating token %s", x_api_token)
statement = select(APIClient) statement = select(APIClient)
results = session.exec(statement) results = session.scalars(statement)
valid = False valid = False
for result in results: for result in results:
if verify_token(x_api_token, result.token): if verify_token(x_api_token, result.token):

View File

@ -3,15 +3,24 @@
import code import code
import os import os
from pathlib import Path from pathlib import Path
from typing import cast from typing import Literal, cast
from dotenv import load_dotenv
import click import click
from sqlmodel import Session, col, func, select
import uvicorn import uvicorn
from dotenv import load_dotenv
from sqlalchemy import select
from sqlalchemy.orm import Session
from .db import get_engine, create_api_token from .db import create_api_token, get_engine, hash_token
from .models import (
from .models import Client, ClientSecret, ClientAccessPolicy, AuditLog, APIClient, init_db APIClient,
AuditLog,
Client,
ClientAccessPolicy,
ClientSecret,
SubSystem,
init_db,
)
from .settings import BackendSettings from .settings import BackendSettings
DEFAULT_LISTEN = "127.0.0.1" DEFAULT_LISTEN = "127.0.0.1"
@ -21,22 +30,44 @@ WORKDIR = Path(os.getcwd())
load_dotenv() load_dotenv()
def generate_token(settings: BackendSettings) -> str:
def generate_token(
settings: BackendSettings, subsystem: Literal["admin", "sshd"]
) -> str:
"""Generate a token.""" """Generate a token."""
engine = get_engine(settings.db_url) engine = get_engine(settings.db_url)
init_db(engine) init_db(engine)
with Session(engine) as session: with Session(engine) as session:
token = create_api_token(session, True) token = create_api_token(session, subsystem)
return token return token
def count_tokens(settings: BackendSettings) -> int:
"""Count the amount of tokens created.""" def add_system_tokens(settings: BackendSettings) -> None:
"""Add token for subsystems."""
if not settings.admin_token and not settings.sshd_token:
# Tokens should be generated manually.
return
engine = get_engine(settings.db_url) engine = get_engine(settings.db_url)
init_db(engine) init_db(engine)
tokens: list[tuple[str, SubSystem]] = []
if admin_token := settings.admin_token:
tokens.append((admin_token, SubSystem.ADMIN))
if sshd_token := settings.sshd_token:
tokens.append((sshd_token, SubSystem.SSHD))
with Session(engine) as session: with Session(engine) as session:
count = session.exec(select(func.count("*")).select_from(APIClient)).one() for token, subsystem in tokens:
hashed_token = hash_token(token)
if existing := session.scalars(
select(APIClient).where(APIClient.subsystem == subsystem)
).first():
existing.token = hashed_token
else:
new_token = APIClient(token=hashed_token, subsystem=subsystem)
session.add(new_token)
return count session.commit()
click.echo("Generated system tokens.")
@click.group() @click.group()
@ -49,27 +80,30 @@ def cli(ctx: click.Context, database: str) -> None:
else: else:
settings = BackendSettings() settings = BackendSettings()
add_system_tokens(settings)
if settings.generate_initial_tokens: # if settings.generate_initial_tokens:
if count_tokens(settings) == 0: # if count_tokens(settings) == 0:
click.echo("Creating initial tokens for admin and sshd.") # click.echo("Creating initial tokens for admin and sshd.")
admin_token = generate_token(settings) # admin_token = generate_token(settings)
sshd_token = generate_token(settings) # sshd_token = generate_token(settings)
click.echo(f"Admin token: {admin_token}") # click.echo(f"Admin token: {admin_token}")
click.echo(f"SSHD token: {sshd_token}") # click.echo(f"SSHD token: {sshd_token}")
ctx.obj = settings ctx.obj = settings
@cli.command("generate-token") @cli.command("generate-token")
@click.argument("subsystem", type=click.Choice(["sshd", "admin"]))
@click.pass_context @click.pass_context
def cli_generate_token(ctx: click.Context) -> None: def cli_generate_token(ctx: click.Context, subsystem: Literal["sshd", "admin"]) -> None:
"""Generate a token.""" """Generate a token for a subsystem.."""
settings = cast(BackendSettings, ctx.obj) settings = cast(BackendSettings, ctx.obj)
token = generate_token(settings) token = generate_token(settings, subsystem)
click.echo("Generated api token:") click.echo("Generated api token:")
click.echo(token) click.echo(token)
@cli.command("run") @cli.command("run")
@click.option("--host", default="127.0.0.1") @click.option("--host", default="127.0.0.1")
@click.option("--port", default=8022, type=click.INT) @click.option("--port", default=8022, type=click.INT)
@ -77,7 +111,10 @@ def cli_generate_token(ctx: click.Context) -> None:
@click.option("--workers", type=click.INT) @click.option("--workers", type=click.INT)
def cli_run(host: str, port: int, dev: bool, workers: int | None) -> None: def cli_run(host: str, port: int, dev: bool, workers: int | None) -> None:
"""Run the server.""" """Run the server."""
uvicorn.run("sshecret_backend.main:app", host=host, port=port, reload=dev, workers=workers) uvicorn.run(
"sshecret_backend.main:app", host=host, port=port, reload=dev, workers=workers
)
@cli.command("repl") @cli.command("repl")
@click.pass_context @click.pass_context

View File

@ -2,56 +2,108 @@
import logging import logging
import secrets import secrets
import sqlite3
from collections.abc import Generator, Callable from collections.abc import Generator, Callable
from pathlib import Path from typing import Literal
from sqlalchemy import Engine from sqlalchemy import create_engine, Engine, event, select
from sqlmodel import Session, create_engine, text from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
import bcrypt
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.engine import URL from sqlalchemy.engine import URL
from .auth import hash_token, verify_token
from .models import APIClient from .models import APIClient, SubSystem
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
def setup_database( def setup_database(
db_url: URL | str, db_url: URL,
) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]: ) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]:
"""Setup database.""" """Setup database."""
engine = create_engine(db_url, echo=False) engine = get_engine(db_url)
with engine.connect() as connection: SessionLocal = sessionmaker(autocommit=False, autoflush=False, future=True)
connection.execute(text("PRAGMA foreign_keys=ON")) # for SQLite only
def get_db_session() -> Generator[Session, None, None]: def get_db_session() -> Generator[Session, None, None]:
"""Get DB Session.""" """Get DB Session."""
with Session(engine) as session: session = SessionLocal(bind=engine)
try:
yield session yield session
finally:
session.close()
return engine, get_db_session return engine, get_db_session
def get_engine(url: URL, echo: bool = False) -> Engine: def get_engine(url: URL, echo: bool = False) -> Engine:
"""Initialize the engine.""" """Initialize the engine."""
engine = create_engine(url, echo=echo) engine = create_engine(url, echo=echo, future=True)
with engine.connect() as connection: if url.drivername.startswith("sqlite"):
connection.execute(text("PRAGMA foreign_keys=ON")) # for SQLite only
@event.listens_for(engine, "connect")
def set_sqlite_pragma(
dbapi_connection: sqlite3.Connection, _connection_record: object
) -> None:
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
return engine return engine
def create_api_token(session: Session, read_write: bool) -> str: def get_async_engine(url: URL, echo: bool = False) -> AsyncEngine:
"""Create API token.""" """Get an async engine."""
token = secrets.token_urlsafe(32) engine = create_async_engine(url, echo=echo, future=True)
pwbytes = token.encode("utf-8") if url.drivername.startswith("sqlite+"):
salt = bcrypt.gensalt()
hashed_bytes = bcrypt.hashpw(password=pwbytes, salt=salt)
hashed = hashed_bytes.decode()
api_token = APIClient(token=hashed, read_write=read_write) @event.listens_for(engine, "connect")
def set_sqlite_pragma(
dbapi_connection: sqlite3.Connection, _connection_record: object
) -> None:
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
return engine
def create_api_token_with_value(session: Session, token: str, subsystem: Literal["admin", "sshd"]) -> None:
"""Create API token with a given value."""
existing = session.scalars(select(APIClient).where(APIClient.subsystem == SubSystem(subsystem))).first()
if existing:
if verify_token(token, existing.token):
LOG.info("Token is up to date.")
return
LOG.info("Updating token value for subsystem %s", subsystem)
hashed = hash_token(token)
existing.token=hashed
session.commit()
return
LOG.info("No existing token found. Creating new")
hashed = hash_token(token)
api_token = APIClient(token=hashed, subsystem=SubSystem(subsystem))
session.add(api_token)
session.commit()
def create_api_token(session: Session, subsystem: Literal["admin", "sshd", "test"], recreate: bool = False) -> str:
"""Create API token."""
subsys = SubSystem(subsystem)
token = secrets.token_urlsafe(32)
hashed = hash_token(token)
if existing := session.scalars(select(APIClient).where(APIClient.subsystem == subsys)).first():
if not recreate:
raise RuntimeError("Error: A token already exist for this subsystem.")
existing.token = hashed
else:
api_token = APIClient(token=hashed, subsystem=subsys)
session.add(api_token) session.add(api_token)
session.commit() session.commit()

View File

@ -7,128 +7,182 @@ This might require some changes to these schemas.
""" """
import enum
import logging import logging
import uuid import uuid
from datetime import datetime from datetime import datetime
import sqlalchemy as sa import sqlalchemy as sa
from sqlmodel import JSON, Column, DateTime, Field, Relationship, SQLModel from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
class Client(SQLModel, table=True): class SubSystem(enum.StrEnum):
"""Client model.""" """Available subsystems."""
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) ADMIN = enum.auto()
name: str = Field(unique=True) SSHD = enum.auto()
description: str | None = None BACKEND = enum.auto()
public_key: str TEST = enum.auto()
created_at: datetime | None = Field(
default=None, class Operation(enum.StrEnum):
sa_type=sa.DateTime(timezone=True), """Various operations for the audit logging module."""
sa_column_kwargs={"server_default": sa.func.now()},
nullable=False, CREATE = enum.auto()
READ = enum.auto()
UPDATE = enum.auto()
DELETE = enum.auto()
DENY = enum.auto()
PERMIT = enum.auto()
LOGIN = enum.auto()
REGISTER = enum.auto()
NONE = enum.auto()
class Base(DeclarativeBase):
pass
class Client(Base):
"""Clients."""
__tablename__: str = "client"
id: Mapped[uuid.UUID] = mapped_column(
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
)
name: Mapped[str] = mapped_column(sa.String, unique=True)
description: Mapped[str | None] = mapped_column(sa.String, nullable=True)
public_key: Mapped[str] = mapped_column(sa.Text)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
) )
updated_at: datetime | None = Field( updated_at: Mapped[datetime | None] = mapped_column(
default=None, sa.DateTime(timezone=True),
sa_type=sa.DateTime(timezone=True), server_default=sa.func.now(),
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()}, onupdate=sa.func.now(),
) )
secrets: list["ClientSecret"] = Relationship( secrets: Mapped[list["ClientSecret"]] = relationship(
back_populates="client", passive_deletes="all" back_populates="client", passive_deletes=True
) )
policies: list["ClientAccessPolicy"] = Relationship(back_populates="client") policies: Mapped[list["ClientAccessPolicy"]] = relationship(back_populates="client")
class ClientAccessPolicy(SQLModel, table=True): class ClientAccessPolicy(Base):
"""Client access policies.""" """Client access policies."""
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) __tablename__: str = "client_access_policy"
source: str
client_id: uuid.UUID | None = Field(foreign_key="client.id", ondelete="CASCADE")
client: Client | None = Relationship(back_populates="policies")
created_at: datetime | None = Field( id: Mapped[uuid.UUID] = mapped_column(
default=None, sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
sa_type=sa.DateTime(timezone=True), )
sa_column_kwargs={"server_default": sa.func.now()}, source: Mapped[str] = mapped_column(sa.String)
nullable=False, client_id: Mapped[uuid.UUID | None] = mapped_column(
sa.Uuid(as_uuid=True), sa.ForeignKey("client.id", ondelete="CASCADE")
)
client: Mapped[Client] = relationship(back_populates="policies")
created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
) )
updated_at: datetime | None = Field( updated_at: Mapped[datetime | None] = mapped_column(
default=None, sa.DateTime(timezone=True),
sa_type=sa.DateTime(timezone=True), server_default=sa.func.now(),
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()}, onupdate=sa.func.now(),
) )
class ClientSecret(SQLModel, table=True): class ClientSecret(Base):
"""A client secret.""" """A client secret."""
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) __tablename__: str = "client_secret"
name: str
description: str | None = None id: Mapped[uuid.UUID] = mapped_column(
client_id: uuid.UUID | None = Field(foreign_key="client.id", ondelete="CASCADE") sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
client: Client | None = Relationship(back_populates="secrets") )
secret: str name: Mapped[str] = mapped_column(sa.String)
invalidated: bool = Field(default=False) description: Mapped[str | None] = mapped_column(sa.String, nullable=True)
created_at: datetime | None = Field( secret: Mapped[str] = mapped_column(sa.String)
default=None,
sa_type=sa.DateTime(timezone=True), client_id: Mapped[uuid.UUID | None] = mapped_column(
sa_column_kwargs={"server_default": sa.func.now()}, sa.Uuid(as_uuid=True), sa.ForeignKey("client.id", ondelete="CASCADE")
nullable=False, )
client: Mapped[Client] = relationship(back_populates="secrets")
invalidated: Mapped[bool] = mapped_column(default=False)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
) )
updated_at: datetime | None = Field( updated_at: Mapped[datetime | None] = mapped_column(
default=None, sa.DateTime(timezone=True),
sa_type=sa.DateTime(timezone=True), server_default=sa.func.now(),
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()}, onupdate=sa.func.now(),
) )
class AuditLog(SQLModel, table=True): class APIClient(Base):
"""A client on the API.
This should eventually get more granular permissions.
"""
__tablename__: str = "api_client"
id: Mapped[uuid.UUID] = mapped_column(
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
)
subsystem: Mapped[SubSystem | None] = mapped_column(sa.String, nullable=True)
token: Mapped[str] = mapped_column(sa.String)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
)
updated_at: Mapped[datetime | None] = mapped_column(
sa.DateTime(timezone=True),
server_default=sa.func.now(),
onupdate=sa.func.now(),
)
class AuditLog(Base):
"""Audit log. """Audit log.
This is implemented without any foreign keys to avoid losing data on This is implemented without any foreign keys to avoid losing data on
deletions. deletions.
""" """
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) __tablename__: str = "audit_log"
subsystem: str id: Mapped[uuid.UUID] = mapped_column(
message: str sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
operation: str
client_id: uuid.UUID | None = None
client_name: str | None = None
origin: str | None = None
Field(default=None, sa_column=Column(JSON))
timestamp: datetime | None = Field(
default=None,
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"server_default": sa.func.now()},
nullable=False,
) )
subsystem: Mapped[SubSystem] = mapped_column(sa.String)
message: Mapped[str] = mapped_column(sa.String)
operation: Mapped[Operation] = mapped_column(sa.String)
client_id: Mapped[uuid.UUID | None] = mapped_column(
sa.Uuid(as_uuid=True), nullable=True
)
data: Mapped[dict[str, str] | None] = mapped_column(sa.JSON, nullable=True)
client_name: Mapped[str | None] = mapped_column(sa.String, nullable=True)
secret_id: Mapped[uuid.UUID | None] = mapped_column(
sa.Uuid(as_uuid=True), nullable=True
)
secret_name: Mapped[str | None] = mapped_column(sa.String, nullable=True)
class APIClient(SQLModel, table=True): origin: Mapped[str | None] = mapped_column(sa.String, nullable=True)
"""Stores API Keys.""" timestamp: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
token: str
read_write: bool
created_at: datetime | None = Field(
default=None,
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"server_default": sa.func.now()},
nullable=False,
) )
def init_db(engine: sa.Engine) -> None: def init_db(engine: sa.Engine) -> None:
"""Create database.""" """Initialize database."""
LOG.info("Running init_db") Base.metadata.create_all(engine)
SQLModel.metadata.create_all(engine)

View File

@ -1,12 +1,10 @@
"""Settings management.""" """Settings management."""
from pathlib import Path from pathlib import Path
from typing import Annotated, Any from pydantic import Field
from pydantic import Field, field_validator
from pydantic_settings import ( from pydantic_settings import (
BaseSettings, BaseSettings,
SettingsConfigDict, SettingsConfigDict,
ForceDecode,
) )
from sqlalchemy import URL from sqlalchemy import URL
@ -22,24 +20,19 @@ class BackendSettings(BaseSettings):
) )
database: str = Field(default=DEFAULT_DATABASE) database: str = Field(default=DEFAULT_DATABASE)
generate_initial_tokens: Annotated[bool, ForceDecode] = Field(default=False) admin_token: str | None = Field(default=None, alias="sshecret_admin_backend_token")
sshd_token: str | None = Field(default=None, alias="sshecret_sshd_backend_token")
@field_validator("generate_initial_tokens", mode="before")
@classmethod
def cast_bool(cls, value: Any) -> bool:
"""Ensure we catch the boolean."""
if isinstance(value, str):
if value.lower() in ("1", "true", "on"):
return True
if value.lower() in ("0", "false", "off"):
return False
return bool(value)
@property @property
def db_url(self) -> URL: def db_url(self) -> URL:
"""Construct database url.""" """Construct database url."""
return URL.create(drivername="sqlite", database=self.database) return URL.create(drivername="sqlite", database=self.database)
@property
def async_db_url(self) -> URL:
"""Construct database url with sync handling."""
return URL.create(drivername="sqlite+aiosqlite", database=self.database)
@property @property
def db_exists(self) -> bool: def db_exists(self) -> bool:
"""Check if databatase exists.""" """Check if databatase exists."""

View File

@ -1,8 +1,9 @@
"""Test helpers.""" """Test helpers."""
import logging import logging
from sqlmodel import Session
from sshecret_backend.settings import BackendSettings from sshecret_backend.settings import BackendSettings
from sqlalchemy.orm import Session
from .models import init_db from .models import init_db
from .db import create_api_token, setup_database from .db import create_api_token, setup_database
@ -14,4 +15,4 @@ def create_test_token(settings: BackendSettings) -> str:
engine, _setupdb = setup_database(settings.db_url) engine, _setupdb = setup_database(settings.db_url)
with Session(engine) as session: with Session(engine) as session:
init_db(engine) init_db(engine)
return create_api_token(session, True) return create_api_token(session, "test")

View File

@ -2,7 +2,7 @@
from collections.abc import Callable, Generator from collections.abc import Callable, Generator
from sqlmodel import Session from sqlalchemy.orm import Session
DBSessionDep = Callable[[], Generator[Session, None, None]] DBSessionDep = Callable[[], Generator[Session, None, None]]

View File

@ -4,15 +4,14 @@ import uuid
from datetime import datetime from datetime import datetime
from typing import Annotated, Self, override from typing import Annotated, Self, override
from sqlmodel import Field, SQLModel from pydantic import AfterValidator, BaseModel, Field, IPvAnyAddress, IPvAnyNetwork
from pydantic import AfterValidator, IPvAnyAddress, IPvAnyNetwork
from sshecret.crypto import public_key_validator from sshecret.crypto import public_key_validator
from . import models from . import models
class ClientView(SQLModel): class ClientView(BaseModel):
"""View for a single client.""" """View for a single client."""
id: uuid.UUID id: uuid.UUID
@ -50,7 +49,7 @@ class ClientView(SQLModel):
return view return view
class ClientQueryResult(SQLModel): class ClientQueryResult(BaseModel):
"""Result class for queries towards the client list.""" """Result class for queries towards the client list."""
clients: list[ClientView] = Field(default_factory=list) clients: list[ClientView] = Field(default_factory=list)
@ -58,7 +57,7 @@ class ClientQueryResult(SQLModel):
remaining_results: int remaining_results: int
class ClientCreate(SQLModel): class ClientCreate(BaseModel):
"""Model to create a client.""" """Model to create a client."""
name: str name: str
@ -74,19 +73,19 @@ class ClientCreate(SQLModel):
) )
class ClientUpdate(SQLModel): class ClientUpdate(BaseModel):
"""Model to update the client public key.""" """Model to update the client public key."""
public_key: Annotated[str, AfterValidator(public_key_validator)] public_key: Annotated[str, AfterValidator(public_key_validator)]
class BodyValue(SQLModel): class BodyValue(BaseModel):
"""A generic model with just a value parameter.""" """A generic model with just a value parameter."""
value: str value: str
class ClientSecretPublic(SQLModel): class ClientSecretPublic(BaseModel):
"""Public model to manage client secrets.""" """Public model to manage client secrets."""
name: str name: str
@ -122,7 +121,7 @@ class ClientSecretResponse(ClientSecretPublic):
) )
class ClientPolicyView(SQLModel): class ClientPolicyView(BaseModel):
"""Update object for client policy.""" """Update object for client policy."""
sources: list[str] = ["0.0.0.0/0", "::/0"] sources: list[str] = ["0.0.0.0/0", "::/0"]
@ -135,27 +134,27 @@ class ClientPolicyView(SQLModel):
return cls(sources=[policy.source for policy in client.policies]) return cls(sources=[policy.source for policy in client.policies])
class ClientPolicyUpdate(SQLModel): class ClientPolicyUpdate(BaseModel):
"""Model for updating policies.""" """Model for updating policies."""
sources: list[IPvAnyAddress | IPvAnyNetwork] sources: list[IPvAnyAddress | IPvAnyNetwork]
class ClientSecretList(SQLModel): class ClientSecretList(BaseModel):
"""Model for aggregating identically named secrets.""" """Model for aggregating identically named secrets."""
name: str name: str
clients: list[str] clients: list[str]
class ClientReference(SQLModel): class ClientReference(BaseModel):
"""Reference to a client.""" """Reference to a client."""
id: str id: str
name: str name: str
class ClientSecretDetailList(SQLModel): class ClientSecretDetailList(BaseModel):
"""A more detailed version of the ClientSecretList.""" """A more detailed version of the ClientSecretList."""
name: str name: str
@ -163,7 +162,22 @@ class ClientSecretDetailList(SQLModel):
clients: list[ClientReference] = Field(default_factory=list) clients: list[ClientReference] = Field(default_factory=list)
class AuditInfo(SQLModel): class AuditView(BaseModel):
"""Audit log view."""
id: uuid.UUID | None = None
subsystem: models.SubSystem
message: str
operation: models.Operation
data: dict[str, str] | None = None
client_id: uuid.UUID | None = None
client_name: str | None = None
origin: str | None = None
timestamp: datetime | None = None
class AuditInfo(BaseModel):
"""Information about audit information.""" """Information about audit information."""
entries: int entries: int

View File

@ -10,7 +10,7 @@ from fastapi.testclient import TestClient
from sshecret.crypto import generate_private_key, generate_public_key_string from sshecret.crypto import generate_private_key, generate_public_key_string
from sshecret_backend.app import create_backend_app from sshecret_backend.app import create_backend_app
from sshecret_backend.testing import create_test_token from sshecret_backend.testing import create_test_token
from sshecret_backend.models import AuditLog from sshecret_backend.view_models import AuditView
from sshecret_backend.settings import BackendSettings from sshecret_backend.settings import BackendSettings
@ -53,7 +53,7 @@ def create_client_fixture(tmp_path: Path):
db_file = tmp_path / "backend.db" db_file = tmp_path / "backend.db"
print(f"DB File: {db_file.absolute()}") print(f"DB File: {db_file.absolute()}")
settings = BackendSettings(db_url=f"sqlite:///{db_file.absolute()}") settings = BackendSettings(database=str(db_file.absolute()))
app = create_backend_app(settings) app = create_backend_app(settings)
token = create_test_token(settings) token = create_test_token(settings)
@ -213,7 +213,7 @@ def test_audit_logging(test_client: TestClient) -> None:
assert len(audit_logs) > 0 assert len(audit_logs) > 0
for entry in audit_logs: for entry in audit_logs:
# Let's try to reassemble the objects # Let's try to reassemble the objects
audit_log = AuditLog.model_validate(entry) audit_log = AuditView.model_validate(entry)
assert audit_log is not None assert audit_log is not None
@ -522,9 +522,8 @@ def test_operations_with_id(test_client: TestClient) -> None:
def test_write_audit_log(test_client: TestClient) -> None: def test_write_audit_log(test_client: TestClient) -> None:
"""Test writing to the audit log.""" """Test writing to the audit log."""
params = { params = {
"object": "Test", "subsystem": "backend",
"operation": "TEST", "operation": "read",
"object_id": "Something",
"message": "Test Message" "message": "Test Message"
} }
resp = test_client.post("/api/v1/audit", json=params) resp = test_client.post("/api/v1/audit", json=params)

View File

@ -8,8 +8,10 @@ from .models import (
ClientReference, ClientReference,
ClientSecret, ClientSecret,
DetailedSecrets, DetailedSecrets,
Operation,
Policy, Policy,
Secret, Secret,
SubSystem,
FilterType, FilterType,
) )
@ -22,7 +24,9 @@ __all__ = [
"ClientSecret", "ClientSecret",
"DetailedSecrets", "DetailedSecrets",
"FilterType", "FilterType",
"Operation",
"Policy", "Policy",
"Secret", "Secret",
"SubSystem",
"SshecretBackend", "SshecretBackend",
] ]

View File

@ -1,21 +1,26 @@
"""Backend client. """Backend client.
This is an API calling the HTTP API of the sshecret-backend package so that the
admin and sshd library do not need to implement the same
""" """
import logging import logging
from typing import Any, Self from typing import Any, Self, override
import httpx import httpx
from pydantic import TypeAdapter from pydantic import TypeAdapter
from .models import ( from .models import (
AuditInfo,
AuditLog, AuditLog,
Client, Client,
ClientSecret, ClientSecret,
ClientQueryResult, ClientQueryResult,
ClientFilter, ClientFilter,
DetailedSecrets, DetailedSecrets,
Operation,
Secret, Secret,
SubSystem,
) )
from .exceptions import BackendValidationError, BackendConnectionError from .exceptions import BackendValidationError, BackendConnectionError
from .utils import validate_public_key from .utils import validate_public_key
@ -84,12 +89,14 @@ class ClientQueryIterator:
raise StopAsyncIteration raise StopAsyncIteration
class SshecretBackend: class BaseBackend:
"""Backend interface.""" """Base backend class."""
def __init__(self, backend_url: str, api_token: str) -> None: def __init__(self, backend_url: str, api_token: str) -> None:
"""Initialize backend client.""" """Initialize backend client."""
self._backend_url: str = backend_url
self._api_token: str = api_token
url = httpx.URL(backend_url) url = httpx.URL(backend_url)
self.http_client: httpx.AsyncClient = httpx.AsyncClient( self.http_client: httpx.AsyncClient = httpx.AsyncClient(
@ -101,31 +108,55 @@ class SshecretBackend:
base_url=url, base_url=url,
) )
async def _get(self, path: str) -> httpx.Response: def _handle_response(self, response: httpx.Response) -> httpx.Response:
"""Handle response."""
LOG.debug("Handling response with status_code %s", response.status_code)
if response.status_code == 422:
LOG.error("Validation error from backend:\n%s", response.text)
raise BackendValidationError(response.text)
if response.status_code != 404 and str(response.status_code).startswith("4"):
raise BackendConnectionError("Error from backend: %s", response.text)
return response
async def _get(
self, path: str, params: dict[str, str] | None = None
) -> httpx.Response:
"""Perform a get request.""" """Perform a get request."""
try: try:
return await self.http_client.get(path) response = await self.http_client.get(path, params=params)
return self._handle_response(response)
except httpx.ConnectError as e: except httpx.ConnectError as e:
raise BackendConnectionError() from e raise BackendConnectionError("Could not connect to backend.") from e
async def _delete(self, path: str) -> httpx.Response: async def _delete(self, path: str) -> httpx.Response:
"""Perform a delete request.""" """Perform a delete request."""
try: try:
return await self.http_client.delete(path) response = await self.http_client.delete(path)
return self._handle_response(response)
except httpx.ConnectError as e: except httpx.ConnectError as e:
raise BackendConnectionError() from e raise BackendConnectionError() from e
async def _post(self, path: str, json: Any | None = None) -> httpx.Response: async def _post(self, path: str, json: Any | None = None) -> httpx.Response:
"""Perform a POST request.""" """Perform a POST request."""
try: try:
return await self.http_client.post(path, json=json) response = await self.http_client.post(path, json=json)
return self._handle_response(response)
except httpx.ConnectError as e:
raise BackendConnectionError() from e
def _post_sync(self, path: str, json: Any | None = None) -> httpx.Response:
"""Perform a synchronous post request."""
try:
return self._handle_response(self.sync_client.post(path, json=json))
except httpx.ConnectError as e: except httpx.ConnectError as e:
raise BackendConnectionError() from e raise BackendConnectionError() from e
async def _put(self, path: str, json: Any | None = None) -> httpx.Response: async def _put(self, path: str, json: Any | None = None) -> httpx.Response:
"""Perform a PUT request.""" """Perform a PUT request."""
try: try:
return await self.http_client.put(path, json=json) response = await self.http_client.put(path, json=json)
return self._handle_response(response)
except httpx.ConnectError as e: except httpx.ConnectError as e:
raise BackendConnectionError() from e raise BackendConnectionError() from e
@ -134,175 +165,90 @@ class SshecretBackend:
response = await self._get(path) response = await self._get(path)
return response return response
async def create_client(
self, name: str, public_key: str, description: str | None = None
) -> None:
"""Register a new client."""
if not validate_public_key(public_key):
raise BackendValidationError("Error: Invalid public key format.")
data = {
"name": name,
"public_key": public_key,
}
if description:
data["description"] = description
path = "/api/v1/clients/"
response = await self._post(path, json=data)
response.raise_for_status() class AuditAPI(BaseBackend):
"""API for the audit logging."""
async def get_clients(self, filter: ClientFilter | None = None) -> list[Client]: @override
"""Get all clients.""" def __init__(self, backend_url: str, api_token: str, subsystem: str) -> None:
clients: list[Client] = [] """Initialize backend client."""
async for client in ClientQueryIterator(self.http_client, filter): super().__init__(backend_url, api_token)
clients.append(client) self.subsystem: SubSystem = SubSystem(subsystem)
return clients def _create_model(
self,
async def get_client(self, name: str) -> Client | None: operation: Operation,
"""Lookup a client on username.""" message: str,
path = f"/api/v1/clients/{name}" origin: str,
response = await self.request(path) client: Client | None = None,
if response.status_code == 404: secret: ClientSecret | None = None,
return None secret_name: str | None = None,
response.raise_for_status() data: dict[str, str] | None = None,
client = Client.model_validate(response.json()) ) -> AuditLog:
return client """Create the audit log object."""
model = AuditLog(
async def get_client_by_id(self, id: str) -> Client | None: subsystem=self.subsystem,
"""Lookup a client on username.""" operation=operation,
path = f"/api/v1/clients/id/{id}" message=message,
response = await self.request(path) origin=origin,
if response.status_code == 404:
return None
response.raise_for_status()
client = Client.model_validate(response.json())
return client
async def delete_client(self, client_name: str) -> None:
"""Delete a client."""
path = f"/api/v1/clients/{client_name}"
response = await self._delete(path)
response.raise_for_status()
async def delete_client_by_id(self, id: str) -> None:
"""Delete a client."""
path = f"/api/v1/clients/id/{id}"
response = await self._delete(path)
response.raise_for_status()
async def create_client_secret(
self, client_name: str, secret_name: str, encrypted_secret: str
) -> None:
"""Create a secret.
This will overwrite any existing secret with that name.
"""
path = f"api/v1/clients/{client_name}/secrets/{secret_name}"
response = await self._put(path, json={"value": encrypted_secret})
response.raise_for_status()
async def get_client_secret(self, name: str, secret_name: str) -> str:
"""Fetch a secret."""
path = f"/api/v1/clients/{name}/secrets/{secret_name}"
response = await self.request(path)
response.raise_for_status()
secret = ClientSecret.model_validate(response.json())
return secret.secret
async def delete_client_secret(self, client_name: str, secret_name: str) -> None:
"""Delete a secret from a client."""
path = f"api/v1/clients/{client_name}/secrets/{secret_name}"
response = await self._delete(path)
response.raise_for_status()
async def update_client(self, client: Client) -> Client:
"""Update the client."""
path = f"/api/v1/clients/{client.name}"
client_update = {
"name": client.name,
"description": client.description,
"public_key": client.public_key,
}
response = await self._put(path, json=client_update)
LOG.info("Response %s", response.text)
response.raise_for_status()
if client.policies:
await self.update_client_sources(
str(client.id), [str(source) for source in client.policies]
) )
return client if client:
model.client_id = str(client.id) or None
model.client_name = client.name
async def update_client_key(self, client_name: str, public_key: str) -> None: if secret:
"""Update the client key.""" model.secret_name = secret.name
path = f"/api/v1/clients/{client_name}/public-key" elif secret_name:
response = await self._post(path, json={"public_key": public_key}) model.secret_name = secret_name
if data:
model.data = data
response.raise_for_status() return model
async def update_client_sources( def write_model(self, model: AuditLog) -> None:
self, client_name: str, addresses: list[str] | None """Write model."""
path = f"/api/v1/audit/"
self.sync_client.post(path, json=model.model_dump())
async def write_model_async(self, model: AuditLog) -> None:
"""Write model async."""
path = f"/api/v1/audit/"
await self._post(path, json=model.model_dump())
def write(
self,
operation: Operation,
message: str,
origin: str,
client: Client | None = None,
secret: ClientSecret | None = None,
secret_name: str | None = None,
**data: str,
) -> None: ) -> None:
"""Update client source addresses. """Write an audit entry."""
model = self._create_model(
operation, message, origin, client, secret, secret_name, data
)
Pass None to sources to allow from all. self.write_model(model)
"""
if not addresses:
addresses = []
path = f"/api/v1/clients/{client_name}/policies/" async def write_async(
response = await self._put(path, json={"sources": addresses}) self,
operation: Operation,
message: str,
origin: str,
client: Client | None = None,
secret: ClientSecret | None = None,
secret_name: str | None = None,
**data: str,
) -> None:
"""Write an audit entry."""
model = self._create_model(
operation, message, origin, client, secret, secret_name, data
)
await self.write_model_async(model)
response.raise_for_status() async def get(
async def get_detailed_secrets(self) -> list[DetailedSecrets]:
"""Get detailed list of secrets."""
path = "/api/v1/secrets/detailed/"
response = await self._get(path)
response.raise_for_status()
secret_list = TypeAdapter(list[DetailedSecrets])
return secret_list.validate_python(response.json())
async def get_secrets(self) -> list[Secret]:
"""Get Secrets.
This provides a list of secret names and which clients have them.
"""
path = "/api/v1/secrets/"
response = await self._get(path)
response.raise_for_status()
secret_list = TypeAdapter(list[Secret])
return secret_list.validate_python(response.json())
async def get_secret(self, name: str) -> Secret | None:
"""Get clients mapped to a single secret."""
path = f"/api/v1/secrets/{name}"
response = await self._get(path)
if response.status_code == 404:
return None
response.raise_for_status()
return Secret.model_validate(response.json())
async def get_detailed_secret(self, name: str) -> DetailedSecrets | None:
"""Get clients mapped to a single secret."""
path = f"/api/v1/secrets/{name}/detailed"
response = await self._get(path)
if response.status_code == 404:
return None
response.raise_for_status()
return DetailedSecrets.model_validate(response.json())
async def get_audit_log(
self, self,
offset: int = 0, offset: int = 0,
limit: int = 100, limit: int = 100,
@ -321,30 +267,169 @@ class SshecretBackend:
if subsystem: if subsystem:
params["filter_subsystem"] = subsystem params["filter_subsystem"] = subsystem
response = await self.http_client.get(path, params=params) response = await self._get(path, params=params)
response.raise_for_status()
audit_log_adapter = TypeAdapter(list[AuditLog]) audit_log_adapter = TypeAdapter(list[AuditLog])
return audit_log_adapter.validate_python(response.json()) return audit_log_adapter.validate_python(response.json())
async def add_audit_log(self, entry: AuditLog) -> None: async def count(self) -> int:
"""Add audit log entry."""
path = f"/api/v1/audit/"
response = await self.http_client.post(path, json=entry.model_dump())
response.raise_for_status()
async def get_audit_log_count(self) -> int:
"""Get amount of messages in the audit log.""" """Get amount of messages in the audit log."""
path = f"/api/v1/audit/info" path = f"/api/v1/audit/info"
response = await self._get(path) response = await self._get(path)
response.raise_for_status() audit_info = AuditInfo.model_validate(response.json())
data = response.json() return audit_info.entries
return int(data["entries"])
def add_audit_log_sync(self, entry: AuditLog) -> None:
"""Add audit log entry."""
path = f"/api/v1/audit/"
LOG.info("AUDIT LOG SYNC %r", entry)
response = self.sync_client.post(path, json=entry.model_dump()) class SshecretBackend(BaseBackend):
response.raise_for_status() """Backend interface."""
async def create_client(
self, name: str, public_key: str, description: str | None = None
) -> None:
"""Register a new client."""
if not validate_public_key(public_key):
raise BackendValidationError("Error: Invalid public key format.")
data = {
"name": name,
"public_key": public_key,
}
if description:
data["description"] = description
path = "/api/v1/clients/"
response = await self._post(path, json=data)
async def get_clients(self, filter: ClientFilter | None = None) -> list[Client]:
"""Get all clients."""
clients: list[Client] = []
async for client in ClientQueryIterator(self.http_client, filter):
clients.append(client)
return clients
async def get_client(self, name: str) -> Client | None:
"""Lookup a client on username."""
path = f"/api/v1/clients/{name}"
response = await self._get(path)
if response.status_code == 404:
return None
client = Client.model_validate(response.json())
return client
async def get_client_by_id(self, id: str) -> Client | None:
"""Lookup a client on username."""
path = f"/api/v1/clients/id/{id}"
response = await self._get(path)
if response.status_code == 404:
return None
client = Client.model_validate(response.json())
return client
async def delete_client(self, client_name: str) -> None:
"""Delete a client."""
path = f"/api/v1/clients/{client_name}"
response = await self._delete(path)
async def delete_client_by_id(self, id: str) -> None:
"""Delete a client."""
path = f"/api/v1/clients/id/{id}"
response = await self._delete(path)
async def create_client_secret(
self, client_name: str, secret_name: str, encrypted_secret: str
) -> None:
"""Create a secret.
This will overwrite any existing secret with that name.
"""
path = f"api/v1/clients/{client_name}/secrets/{secret_name}"
response = await self._put(path, json={"value": encrypted_secret})
async def get_client_secret(self, name: str, secret_name: str) -> str | None:
"""Fetch a secret."""
path = f"/api/v1/clients/{name}/secrets/{secret_name}"
response = await self._get(path)
if response.status_code == 404:
return None
secret = ClientSecret.model_validate(response.json())
return secret.secret
async def delete_client_secret(self, client_name: str, secret_name: str) -> None:
"""Delete a secret from a client."""
path = f"api/v1/clients/{client_name}/secrets/{secret_name}"
await self._delete(path)
async def update_client(self, client: Client) -> Client:
"""Update the client."""
path = f"/api/v1/clients/{client.name}"
client_update = {
"name": client.name,
"description": client.description,
"public_key": client.public_key,
}
response = await self._put(path, json=client_update)
LOG.info("Response %s", response.text)
if client.policies:
await self.update_client_sources(
str(client.id), [str(source) for source in client.policies]
)
return client
async def update_client_key(self, client_name: str, public_key: str) -> None:
"""Update the client key."""
path = f"/api/v1/clients/{client_name}/public-key"
await self._post(path, json={"public_key": public_key})
async def update_client_sources(
self, client_name: str, addresses: list[str] | None
) -> None:
"""Update client source addresses.
Pass None to sources to allow from all.
"""
if not addresses:
addresses = []
path = f"/api/v1/clients/{client_name}/policies/"
await self._put(path, json={"sources": addresses})
async def get_detailed_secrets(self) -> list[DetailedSecrets]:
"""Get detailed list of secrets."""
path = "/api/v1/secrets/detailed/"
response = await self._get(path)
secret_list = TypeAdapter(list[DetailedSecrets])
return secret_list.validate_python(response.json())
async def get_secrets(self) -> list[Secret]:
"""Get Secrets.
This provides a list of secret names and which clients have them.
"""
path = "/api/v1/secrets/"
response = await self._get(path)
secret_list = TypeAdapter(list[Secret])
return secret_list.validate_python(response.json())
async def get_secret(self, name: str) -> Secret | None:
"""Get clients mapped to a single secret."""
path = f"/api/v1/secrets/{name}"
response = await self._get(path)
if response.status_code == 404:
return None
return Secret.model_validate(response.json())
async def get_detailed_secret(self, name: str) -> DetailedSecrets | None:
"""Get clients mapped to a single secret."""
path = f"/api/v1/secrets/{name}/detailed"
response = await self._get(path)
if response.status_code == 404:
return None
return DetailedSecrets.model_validate(response.json())
def audit(self, subsystem: SubSystem) -> AuditAPI:
"""Create the audit API."""
audit = AuditAPI(self._backend_url, self._api_token, subsystem)
return audit

View File

@ -17,6 +17,27 @@ class FilterType(enum.StrEnum):
CONTAINS = "contains" CONTAINS = "contains"
class SubSystem(enum.StrEnum):
"""Available subsystems."""
ADMIN = enum.auto()
SSHD = enum.auto()
BACKEND = enum.auto()
class Operation(enum.StrEnum):
"""Various operations for the audit logging module."""
CREATE = enum.auto()
READ = enum.auto()
UPDATE = enum.auto()
DELETE = enum.auto()
DENY = enum.auto()
PERMIT = enum.auto()
LOGIN = enum.auto()
NONE = enum.auto()
class Client(BaseModel): class Client(BaseModel):
"""Implementation of the backend class ClientView.""" """Implementation of the backend class ClientView."""
@ -101,15 +122,22 @@ class ClientFilter(BaseModel):
class AuditLog(BaseModel): class AuditLog(BaseModel):
"""Implementation of the backend class AuditLog.""" """Implementation of the backend class AuditView."""
id: str | None = None id: str | None = None
subsystem: str | None = None subsystem: SubSystem
object: str | None = None operation: Operation
object_id: str | None = None
operation: str
client_id: str | None = None client_id: str | None = None
client_name: str | None = None client_name: str | None = None
secret_id: str | None = None
secret_name: str | None = None
data: dict[str, str] | None = None
message: str message: str
origin: str | None = None origin: str | None = None
timestamp: datetime | None = None timestamp: datetime | None = None
class AuditInfo(BaseModel):
"""Implementation of the backend class AuditInfo."""
entries: int