Refactor database layer and auditing

This commit is contained in:
2025-05-10 08:38:57 +02:00
parent d866553ac1
commit 9ccd2f1d4d
20 changed files with 718 additions and 469 deletions

View File

@ -5,13 +5,14 @@
import logging
from collections.abc import Sequence
from fastapi import APIRouter, Depends, Request, Query
from sqlmodel import Session, col, func, select
from sqlalchemy import desc
from pydantic import TypeAdapter
from sqlalchemy import select, func
from sqlalchemy.orm import Session
from typing import Annotated
from sshecret_backend.models import AuditLog
from sshecret_backend.types import DBSessionDep
from sshecret_backend.view_models import AuditInfo
from sshecret_backend.view_models import AuditInfo, AuditView
LOG = logging.getLogger(__name__)
@ -21,7 +22,7 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
"""Construct audit sub-api."""
router = APIRouter()
@router.get("/audit/", response_model=list[AuditLog])
@router.get("/audit/", response_model=list[AuditView])
async def get_audit_logs(
request: Request,
session: Annotated[Session, Depends(get_db_session)],
@ -29,35 +30,37 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
limit: Annotated[int, Query(le=100)] = 100,
filter_client: Annotated[str | None, Query()] = None,
filter_subsystem: Annotated[str | None, Query()] = None,
) -> Sequence[AuditLog]:
) -> Sequence[AuditView]:
"""Get audit logs."""
#audit.audit_access_audit_log(session, request)
statement = select(AuditLog).offset(offset).limit(limit).order_by(desc(col(AuditLog.timestamp)))
statement = select(AuditLog).offset(offset).limit(limit).order_by(AuditLog.timestamp.desc())
if filter_client:
statement = statement.where(AuditLog.client_name == filter_client)
if filter_subsystem:
statement = statement.where(AuditLog.subsystem == filter_subsystem)
results = session.exec(statement).all()
return results
LogAdapt = TypeAdapter(list[AuditView])
results = session.scalars(statement).all()
return LogAdapt.validate_python(results, from_attributes=True)
@router.post("/audit/")
async def add_audit_log(
request: Request,
session: Annotated[Session, Depends(get_db_session)],
entry: AuditLog,
) -> AuditLog:
entry: AuditView,
) -> AuditView:
"""Add entry to audit log."""
audit_log = AuditLog.model_validate(entry.model_dump(exclude_none=True))
audit_log = AuditLog(**entry.model_dump(exclude_none=True))
session.add(audit_log)
session.commit()
return audit_log
return AuditView.model_validate(audit_log, from_attributes=True)
@router.get("/audit/info")
async def get_audit_info(request: Request, session: Annotated[Session, Depends(get_db_session)]) -> AuditInfo:
"""Get audit info."""
audit_count = session.exec(select(func.count('*')).select_from(AuditLog)).one()
audit_count = session.scalars(select(func.count('*')).select_from(AuditLog)).one()
return AuditInfo(entries=audit_count)

View File

@ -6,11 +6,11 @@ import uuid
import logging
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from pydantic import BaseModel, Field, model_validator
from sqlmodel import Session, col, select
from sqlalchemy import func
from typing import Annotated, Self, TypeVar
from typing import Annotated, Any, Self, TypeVar, cast
from sqlmodel.sql.expression import SelectOfScalar
from sqlalchemy import select, func
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select
from sshecret_backend.types import DBSessionDep
from sshecret_backend.models import Client, ClientSecret
from sshecret_backend.view_models import (
@ -55,8 +55,8 @@ T = TypeVar("T")
def filter_client_statement(
statement: SelectOfScalar[T], params: ClientListParams, ignore_limits: bool = False
) -> SelectOfScalar[T]:
statement: Select[Any], params: ClientListParams, ignore_limits: bool = False
) -> Select[Any]:
"""Filter a statement with the provided params."""
if params.id:
statement = statement.where(Client.id == params.id)
@ -64,9 +64,9 @@ def filter_client_statement(
if params.name:
statement = statement.where(Client.name == params.name)
elif params.name__like:
statement = statement.where(col(Client.name).like(params.name__like))
statement = statement.where(Client.name.like(params.name__like))
elif params.name__contains:
statement = statement.where(col(Client.name).contains(params.name__contains))
statement = statement.where(Client.name.contains(params.name__contains))
if ignore_limits:
return statement
@ -86,13 +86,13 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
"""Get clients."""
# Get total results first
count_statement = select(func.count("*")).select_from(Client)
count_statement = filter_client_statement(count_statement, filter_query, True)
count_statement = cast(Select[tuple[int]], filter_client_statement(count_statement, filter_query, True))
total_results = session.exec(count_statement).one()
total_results = session.scalars(count_statement).one()
statement = filter_client_statement(select(Client), filter_query, False)
results = session.exec(statement)
results = session.scalars(statement)
remainder = total_results - filter_query.offset - filter_query.limit
if remainder < 0:
remainder = 0
@ -170,13 +170,12 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
status_code=404, detail="Cannot find a client with the given name."
)
client.public_key = client_update.public_key
for secret in session.exec(
for secret in session.scalars(
select(ClientSecret).where(ClientSecret.client_id == client.id)
).all():
LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True
secret.client_id = None
secret.client = None
session.add(client)
session.refresh(client)
@ -206,13 +205,12 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
public_key_updated = False
if client_update.public_key != client.public_key:
public_key_updated = True
for secret in session.exec(
for secret in session.scalars(
select(ClientSecret).where(ClientSecret.client_id == client.id)
).all():
LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True
secret.client_id = None
secret.client = None
session.add(client)
session.commit()

View File

@ -4,7 +4,8 @@ import re
import uuid
import bcrypt
from sqlmodel import Session, select
from sqlalchemy import select
from sqlalchemy.orm import Session
from sshecret_backend.models import Client
@ -20,13 +21,13 @@ def verify_token(token: str, stored_hash: str) -> bool:
async def get_client_by_name(session: Session, name: str) -> Client | None:
"""Get client by name."""
client_filter = select(Client).where(Client.name == name)
client_results = session.exec(client_filter)
client_results = session.scalars(client_filter)
return client_results.first()
async def get_client_by_id(session: Session, id: uuid.UUID) -> Client | None:
"""Get client by name."""
client_filter = select(Client).where(Client.id == id)
client_results = session.exec(client_filter)
client_results = session.scalars(client_filter)
return client_results.first()
async def get_client_by_id_or_name(session: Session, id_or_name: str) -> Client | None:

View File

@ -4,7 +4,8 @@
import logging
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlmodel import Session, select
from sqlalchemy import select
from sqlalchemy.orm import Session
from typing import Annotated
from sshecret_backend.models import ClientAccessPolicy
@ -54,7 +55,7 @@ def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
status_code=404, detail="Cannot find a client with the given name."
)
# Remove old policies.
policies = session.exec(
policies = session.scalars(
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
).all()
deleted_policies: list[ClientAccessPolicy] = []

View File

@ -5,7 +5,8 @@
import logging
from collections import defaultdict
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlmodel import Session, select
from sqlalchemy import select
from sqlalchemy.orm import Session
from typing import Annotated
from sshecret_backend.models import Client, ClientSecret
@ -34,7 +35,7 @@ async def lookup_client_secret(
.where(ClientSecret.client_id == client.id)
.where(ClientSecret.name == name)
)
results = session.exec(statement)
results = session.scalars(statement)
return results.first()
@ -165,7 +166,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
) -> list[ClientSecretList]:
"""Get a list of all secrets and which clients have them."""
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
for client_secret in session.exec(select(ClientSecret)).all():
for client_secret in session.scalars(select(ClientSecret)).all():
if not client_secret.client:
if client_secret.name not in client_secret_map:
client_secret_map[client_secret.name] = []
@ -182,7 +183,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
) -> list[ClientSecretDetailList]:
"""Get a list of all secrets and which clients have them."""
client_secrets: dict[str, ClientSecretDetailList] = {}
for client_secret in session.exec(select(ClientSecret)).all():
for client_secret in session.scalars(select(ClientSecret)).all():
if client_secret.name not in client_secrets:
client_secrets[client_secret.name] = ClientSecretDetailList(name=client_secret.name)
@ -202,7 +203,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
) -> ClientSecretList:
"""Get a list of which clients has a named secret."""
clients: list[str] = []
for client_secret in session.exec(
for client_secret in session.scalars(
select(ClientSecret).where(ClientSecret.name == name)
).all():
if not client_secret.client:
@ -219,7 +220,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
) -> ClientSecretDetailList:
"""Get a list of which clients has a named secret."""
detail_list = ClientSecretDetailList(name=name)
for client_secret in session.exec(
for client_secret in session.scalars(
select(ClientSecret).where(ClientSecret.name == name)
).all():
if not client_secret.client:

View File

@ -2,9 +2,10 @@
from collections.abc import Sequence
from fastapi import Request
from sqlmodel import Session, select
from sqlalchemy import select
from sqlalchemy.orm import Session
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy, Operation, SubSystem
def _get_origin(request: Request) -> str | None:
@ -22,7 +23,7 @@ def _write_audit_log(
"""Write the audit log."""
origin = _get_origin(request)
entry.origin = origin
entry.subsystem = "backend"
entry.subsystem = SubSystem.BACKEND
session.add(entry)
if commit:
session.commit()
@ -33,7 +34,7 @@ def audit_create_client(
) -> None:
"""Log the creation of a client."""
entry = AuditLog(
operation="CREATE",
operation=Operation.CREATE,
client_id=client.id,
client_name=client.name,
message="Client Created",
@ -46,7 +47,7 @@ def audit_delete_client(
) -> None:
"""Log the creation of a client."""
entry = AuditLog(
operation="CREATE",
operation=Operation.CREATE,
client_id=client.id,
client_name=client.name,
message="Client deleted",
@ -63,9 +64,9 @@ def audit_create_secret(
) -> None:
"""Audit a create secret event."""
entry = AuditLog(
operation="CREATE",
object="ClientSecret",
object_id=str(secret.id),
operation=Operation.CREATE,
secret_id=secret.id,
secret_name=secret.name,
client_id=client.id,
client_name=client.name,
message="Added secret to client",
@ -81,13 +82,13 @@ def audit_remove_policy(
commit: bool = True,
) -> None:
"""Audit removal of policy."""
data = {"object": "ClientAccessPolicy", "object_id": str(policy.id)}
entry = AuditLog(
operation="DELETE",
object="ClientAccessPolicy",
object_id=str(policy.id),
operation=Operation.DELETE,
client_id=client.id,
client_name=client.name,
message="Deleted client policy",
data=data,
)
_write_audit_log(session, request, entry, commit)
@ -100,13 +101,13 @@ def audit_update_policy(
commit: bool = True,
) -> None:
"""Audit update of policy."""
data: dict[str, str] = {"object": "ClientAccessPolicy", "object_id": str(policy.id)}
entry = AuditLog(
operation="CREATE",
object="ClientAccessPolicy",
object_id=str(policy.id),
client_id=client.id,
operation=Operation.CREATE,
client_name=client.name,
client_id=client.id,
message="Updated client policy",
data=data,
)
_write_audit_log(session, request, entry, commit)
@ -119,11 +120,10 @@ def audit_update_client(
) -> None:
"""Audit an update secret event."""
entry = AuditLog(
operation="UPDATE",
object="Client",
operation=Operation.UPDATE,
client_id=client.id,
client_name=client.name,
message="Client updated",
message="Client data updated",
)
_write_audit_log(session, request, entry, commit)
@ -137,11 +137,11 @@ def audit_update_secret(
) -> None:
"""Audit an update secret event."""
entry = AuditLog(
operation="UPDATE",
object="ClientSecret",
object_id=str(secret.id),
operation=Operation.UPDATE,
client_id=client.id,
client_name=client.name,
secret_name=secret.name,
secret_id=secret.id,
message="Secret value updated",
)
_write_audit_log(session, request, entry, commit)
@ -155,8 +155,7 @@ def audit_invalidate_secrets(
) -> None:
"""Audit Invalidate client secrets."""
entry = AuditLog(
operation="INVALIDATE",
object="ClientSecret",
operation=Operation.UPDATE,
client_name=client.name,
client_id=client.id,
message="Client public-key changed. All secrets invalidated.",
@ -173,9 +172,9 @@ def audit_delete_secret(
) -> None:
"""Audit Delete client secrets."""
entry = AuditLog(
operation="DELETE",
object="ClientSecret",
object_id=str(secret.id),
operation=Operation.DELETE,
secret_name=secret.name,
secret_id=secret.id,
client_name=client.name,
client_id=client.id,
message="Deleted secret.",
@ -195,7 +194,7 @@ def audit_access_secrets(
With no secrets provided, all secrets of the client will be resolved.
"""
if not secrets:
secrets = session.exec(
secrets = session.scalars(
select(ClientSecret).where(ClientSecret.client_id == client.id)
).all()
@ -215,37 +214,21 @@ def audit_access_secret(
) -> None:
"""Audit that someone accessed one secrets."""
entry = AuditLog(
operation="ACCESS",
operation=Operation.READ,
message="Secret was viewed",
object="ClientSecret",
object_id=str(secret.id),
secret_name=secret.name,
secret_id=secret.id,
client_id=client.id,
client_name=client.name,
)
_write_audit_log(session, request, entry, commit)
def audit_access_audit_log(
session: Session, request: Request, commit: bool = True
) -> None:
"""Audit access to the audit log.
Because why not...
"""
entry = AuditLog(
operation="ACCESS",
message="Audit log was viewed",
object="AuditLog",
)
_write_audit_log(session, request, entry, commit)
def audit_client_secret_list(
session: Session, request: Request, commit: bool = True
) -> None:
"""Audit a list of all secrets."""
entry = AuditLog(
operation="ACCESS",
operation=Operation.READ,
message="All secret names and their clients was viewed",
)
_write_audit_log(session, request, entry, commit)

View File

@ -3,11 +3,11 @@
import logging
from typing import Annotated
import bcrypt
from fastapi import APIRouter, Depends, Header, HTTPException
from sqlmodel import Session, select
from sqlalchemy import select
from sqlalchemy.orm import Session
from .api import get_audit_api, get_clients_api, get_policy_api, get_secrets_api
from .auth import verify_token
from .models import (
APIClient,
)
@ -18,13 +18,6 @@ LOG = logging.getLogger(__name__)
API_VERSION = "v1"
def verify_token(token: str, stored_hash: str) -> bool:
"""Verify token."""
token_bytes = token.encode("utf-8")
stored_bytes = stored_hash.encode("utf-8")
return bcrypt.checkpw(token_bytes, stored_bytes)
def get_backend_api(
get_db_session: DBSessionDep,
) -> APIRouter:
@ -37,7 +30,7 @@ def get_backend_api(
"""Validate token."""
LOG.debug("Validating token %s", x_api_token)
statement = select(APIClient)
results = session.exec(statement)
results = session.scalars(statement)
valid = False
for result in results:
if verify_token(x_api_token, result.token):

View File

@ -3,15 +3,24 @@
import code
import os
from pathlib import Path
from typing import cast
from dotenv import load_dotenv
from typing import Literal, cast
import click
from sqlmodel import Session, col, func, select
import uvicorn
from dotenv import load_dotenv
from sqlalchemy import select
from sqlalchemy.orm import Session
from .db import get_engine, create_api_token
from .models import Client, ClientSecret, ClientAccessPolicy, AuditLog, APIClient, init_db
from .db import create_api_token, get_engine, hash_token
from .models import (
APIClient,
AuditLog,
Client,
ClientAccessPolicy,
ClientSecret,
SubSystem,
init_db,
)
from .settings import BackendSettings
DEFAULT_LISTEN = "127.0.0.1"
@ -21,22 +30,44 @@ WORKDIR = Path(os.getcwd())
load_dotenv()
def generate_token(settings: BackendSettings) -> str:
def generate_token(
settings: BackendSettings, subsystem: Literal["admin", "sshd"]
) -> str:
"""Generate a token."""
engine = get_engine(settings.db_url)
init_db(engine)
with Session(engine) as session:
token = create_api_token(session, True)
token = create_api_token(session, subsystem)
return token
def count_tokens(settings: BackendSettings) -> int:
"""Count the amount of tokens created."""
def add_system_tokens(settings: BackendSettings) -> None:
"""Add token for subsystems."""
if not settings.admin_token and not settings.sshd_token:
# Tokens should be generated manually.
return
engine = get_engine(settings.db_url)
init_db(engine)
tokens: list[tuple[str, SubSystem]] = []
if admin_token := settings.admin_token:
tokens.append((admin_token, SubSystem.ADMIN))
if sshd_token := settings.sshd_token:
tokens.append((sshd_token, SubSystem.SSHD))
with Session(engine) as session:
count = session.exec(select(func.count("*")).select_from(APIClient)).one()
for token, subsystem in tokens:
hashed_token = hash_token(token)
if existing := session.scalars(
select(APIClient).where(APIClient.subsystem == subsystem)
).first():
existing.token = hashed_token
else:
new_token = APIClient(token=hashed_token, subsystem=subsystem)
session.add(new_token)
return count
session.commit()
click.echo("Generated system tokens.")
@click.group()
@ -49,27 +80,30 @@ def cli(ctx: click.Context, database: str) -> None:
else:
settings = BackendSettings()
add_system_tokens(settings)
if settings.generate_initial_tokens:
if count_tokens(settings) == 0:
click.echo("Creating initial tokens for admin and sshd.")
admin_token = generate_token(settings)
sshd_token = generate_token(settings)
click.echo(f"Admin token: {admin_token}")
click.echo(f"SSHD token: {sshd_token}")
# if settings.generate_initial_tokens:
# if count_tokens(settings) == 0:
# click.echo("Creating initial tokens for admin and sshd.")
# admin_token = generate_token(settings)
# sshd_token = generate_token(settings)
# click.echo(f"Admin token: {admin_token}")
# click.echo(f"SSHD token: {sshd_token}")
ctx.obj = settings
@cli.command("generate-token")
@click.argument("subsystem", type=click.Choice(["sshd", "admin"]))
@click.pass_context
def cli_generate_token(ctx: click.Context) -> None:
"""Generate a token."""
def cli_generate_token(ctx: click.Context, subsystem: Literal["sshd", "admin"]) -> None:
"""Generate a token for a subsystem.."""
settings = cast(BackendSettings, ctx.obj)
token = generate_token(settings)
token = generate_token(settings, subsystem)
click.echo("Generated api token:")
click.echo(token)
@cli.command("run")
@click.option("--host", default="127.0.0.1")
@click.option("--port", default=8022, type=click.INT)
@ -77,7 +111,10 @@ def cli_generate_token(ctx: click.Context) -> None:
@click.option("--workers", type=click.INT)
def cli_run(host: str, port: int, dev: bool, workers: int | None) -> None:
"""Run the server."""
uvicorn.run("sshecret_backend.main:app", host=host, port=port, reload=dev, workers=workers)
uvicorn.run(
"sshecret_backend.main:app", host=host, port=port, reload=dev, workers=workers
)
@cli.command("repl")
@click.pass_context

View File

@ -2,57 +2,109 @@
import logging
import secrets
import sqlite3
from collections.abc import Generator, Callable
from pathlib import Path
from sqlalchemy import Engine
from sqlmodel import Session, create_engine, text
import bcrypt
from typing import Literal
from sqlalchemy import create_engine, Engine, event, select
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.engine import URL
from .models import APIClient
from .auth import hash_token, verify_token
from .models import APIClient, SubSystem
LOG = logging.getLogger(__name__)
def setup_database(
db_url: URL | str,
db_url: URL,
) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]:
"""Setup database."""
engine = create_engine(db_url, echo=False)
with engine.connect() as connection:
connection.execute(text("PRAGMA foreign_keys=ON")) # for SQLite only
engine = get_engine(db_url)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, future=True)
def get_db_session() -> Generator[Session, None, None]:
"""Get DB Session."""
with Session(engine) as session:
session = SessionLocal(bind=engine)
try:
yield session
finally:
session.close()
return engine, get_db_session
def get_engine(url: URL, echo: bool = False) -> Engine:
"""Initialize the engine."""
engine = create_engine(url, echo=echo)
with engine.connect() as connection:
connection.execute(text("PRAGMA foreign_keys=ON")) # for SQLite only
engine = create_engine(url, echo=echo, future=True)
if url.drivername.startswith("sqlite"):
@event.listens_for(engine, "connect")
def set_sqlite_pragma(
dbapi_connection: sqlite3.Connection, _connection_record: object
) -> None:
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
return engine
def create_api_token(session: Session, read_write: bool) -> str:
"""Create API token."""
token = secrets.token_urlsafe(32)
pwbytes = token.encode("utf-8")
salt = bcrypt.gensalt()
hashed_bytes = bcrypt.hashpw(password=pwbytes, salt=salt)
hashed = hashed_bytes.decode()
def get_async_engine(url: URL, echo: bool = False) -> AsyncEngine:
"""Get an async engine."""
engine = create_async_engine(url, echo=echo, future=True)
if url.drivername.startswith("sqlite+"):
@event.listens_for(engine, "connect")
def set_sqlite_pragma(
dbapi_connection: sqlite3.Connection, _connection_record: object
) -> None:
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
return engine
def create_api_token_with_value(session: Session, token: str, subsystem: Literal["admin", "sshd"]) -> None:
"""Create API token with a given value."""
existing = session.scalars(select(APIClient).where(APIClient.subsystem == SubSystem(subsystem))).first()
if existing:
if verify_token(token, existing.token):
LOG.info("Token is up to date.")
return
LOG.info("Updating token value for subsystem %s", subsystem)
hashed = hash_token(token)
existing.token=hashed
session.commit()
return
LOG.info("No existing token found. Creating new")
hashed = hash_token(token)
api_token = APIClient(token=hashed, subsystem=SubSystem(subsystem))
api_token = APIClient(token=hashed, read_write=read_write)
session.add(api_token)
session.commit()
def create_api_token(session: Session, subsystem: Literal["admin", "sshd", "test"], recreate: bool = False) -> str:
"""Create API token."""
subsys = SubSystem(subsystem)
token = secrets.token_urlsafe(32)
hashed = hash_token(token)
if existing := session.scalars(select(APIClient).where(APIClient.subsystem == subsys)).first():
if not recreate:
raise RuntimeError("Error: A token already exist for this subsystem.")
existing.token = hashed
else:
api_token = APIClient(token=hashed, subsystem=subsys)
session.add(api_token)
session.commit()
return token

View File

@ -7,128 +7,182 @@ This might require some changes to these schemas.
"""
import enum
import logging
import uuid
from datetime import datetime
import sqlalchemy as sa
from sqlmodel import JSON, Column, DateTime, Field, Relationship, SQLModel
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
LOG = logging.getLogger(__name__)
class Client(SQLModel, table=True):
"""Client model."""
class SubSystem(enum.StrEnum):
"""Available subsystems."""
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
name: str = Field(unique=True)
description: str | None = None
public_key: str
ADMIN = enum.auto()
SSHD = enum.auto()
BACKEND = enum.auto()
TEST = enum.auto()
created_at: datetime | None = Field(
default=None,
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"server_default": sa.func.now()},
nullable=False,
class Operation(enum.StrEnum):
"""Various operations for the audit logging module."""
CREATE = enum.auto()
READ = enum.auto()
UPDATE = enum.auto()
DELETE = enum.auto()
DENY = enum.auto()
PERMIT = enum.auto()
LOGIN = enum.auto()
REGISTER = enum.auto()
NONE = enum.auto()
class Base(DeclarativeBase):
pass
class Client(Base):
"""Clients."""
__tablename__: str = "client"
id: Mapped[uuid.UUID] = mapped_column(
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
)
name: Mapped[str] = mapped_column(sa.String, unique=True)
description: Mapped[str | None] = mapped_column(sa.String, nullable=True)
public_key: Mapped[str] = mapped_column(sa.Text)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
)
updated_at: datetime | None = Field(
default=None,
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
updated_at: Mapped[datetime | None] = mapped_column(
sa.DateTime(timezone=True),
server_default=sa.func.now(),
onupdate=sa.func.now(),
)
secrets: list["ClientSecret"] = Relationship(
back_populates="client", passive_deletes="all"
secrets: Mapped[list["ClientSecret"]] = relationship(
back_populates="client", passive_deletes=True
)
policies: list["ClientAccessPolicy"] = Relationship(back_populates="client")
policies: Mapped[list["ClientAccessPolicy"]] = relationship(back_populates="client")
class ClientAccessPolicy(SQLModel, table=True):
class ClientAccessPolicy(Base):
"""Client access policies."""
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
source: str
client_id: uuid.UUID | None = Field(foreign_key="client.id", ondelete="CASCADE")
client: Client | None = Relationship(back_populates="policies")
__tablename__: str = "client_access_policy"
created_at: datetime | None = Field(
default=None,
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"server_default": sa.func.now()},
nullable=False,
id: Mapped[uuid.UUID] = mapped_column(
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
)
source: Mapped[str] = mapped_column(sa.String)
client_id: Mapped[uuid.UUID | None] = mapped_column(
sa.Uuid(as_uuid=True), sa.ForeignKey("client.id", ondelete="CASCADE")
)
client: Mapped[Client] = relationship(back_populates="policies")
created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
)
updated_at: datetime | None = Field(
default=None,
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
updated_at: Mapped[datetime | None] = mapped_column(
sa.DateTime(timezone=True),
server_default=sa.func.now(),
onupdate=sa.func.now(),
)
class ClientSecret(SQLModel, table=True):
class ClientSecret(Base):
"""A client secret."""
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
name: str
description: str | None = None
client_id: uuid.UUID | None = Field(foreign_key="client.id", ondelete="CASCADE")
client: Client | None = Relationship(back_populates="secrets")
secret: str
invalidated: bool = Field(default=False)
created_at: datetime | None = Field(
default=None,
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"server_default": sa.func.now()},
nullable=False,
__tablename__: str = "client_secret"
id: Mapped[uuid.UUID] = mapped_column(
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
)
name: Mapped[str] = mapped_column(sa.String)
description: Mapped[str | None] = mapped_column(sa.String, nullable=True)
secret: Mapped[str] = mapped_column(sa.String)
client_id: Mapped[uuid.UUID | None] = mapped_column(
sa.Uuid(as_uuid=True), sa.ForeignKey("client.id", ondelete="CASCADE")
)
client: Mapped[Client] = relationship(back_populates="secrets")
invalidated: Mapped[bool] = mapped_column(default=False)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
)
updated_at: datetime | None = Field(
default=None,
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
updated_at: Mapped[datetime | None] = mapped_column(
sa.DateTime(timezone=True),
server_default=sa.func.now(),
onupdate=sa.func.now(),
)
class AuditLog(SQLModel, table=True):
class APIClient(Base):
"""A client on the API.
This should eventually get more granular permissions.
"""
__tablename__: str = "api_client"
id: Mapped[uuid.UUID] = mapped_column(
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
)
subsystem: Mapped[SubSystem | None] = mapped_column(sa.String, nullable=True)
token: Mapped[str] = mapped_column(sa.String)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
)
updated_at: Mapped[datetime | None] = mapped_column(
sa.DateTime(timezone=True),
server_default=sa.func.now(),
onupdate=sa.func.now(),
)
class AuditLog(Base):
"""Audit log.
This is implemented without any foreign keys to avoid losing data on
deletions.
"""
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
subsystem: str
message: str
operation: str
client_id: uuid.UUID | None = None
client_name: str | None = None
origin: str | None = None
Field(default=None, sa_column=Column(JSON))
timestamp: datetime | None = Field(
default=None,
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"server_default": sa.func.now()},
nullable=False,
__tablename__: str = "audit_log"
id: Mapped[uuid.UUID] = mapped_column(
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
)
subsystem: Mapped[SubSystem] = mapped_column(sa.String)
message: Mapped[str] = mapped_column(sa.String)
operation: Mapped[Operation] = mapped_column(sa.String)
client_id: Mapped[uuid.UUID | None] = mapped_column(
sa.Uuid(as_uuid=True), nullable=True
)
data: Mapped[dict[str, str] | None] = mapped_column(sa.JSON, nullable=True)
client_name: Mapped[str | None] = mapped_column(sa.String, nullable=True)
secret_id: Mapped[uuid.UUID | None] = mapped_column(
sa.Uuid(as_uuid=True), nullable=True
)
secret_name: Mapped[str | None] = mapped_column(sa.String, nullable=True)
class APIClient(SQLModel, table=True):
"""Stores API Keys."""
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
token: str
read_write: bool
created_at: datetime | None = Field(
default=None,
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"server_default": sa.func.now()},
nullable=False,
origin: Mapped[str | None] = mapped_column(sa.String, nullable=True)
timestamp: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
)
def init_db(engine: sa.Engine) -> None:
"""Create database."""
LOG.info("Running init_db")
SQLModel.metadata.create_all(engine)
"""Initialize database."""
Base.metadata.create_all(engine)

View File

@ -1,12 +1,10 @@
"""Settings management."""
from pathlib import Path
from typing import Annotated, Any
from pydantic import Field, field_validator
from pydantic import Field
from pydantic_settings import (
BaseSettings,
SettingsConfigDict,
ForceDecode,
)
from sqlalchemy import URL
@ -22,24 +20,19 @@ class BackendSettings(BaseSettings):
)
database: str = Field(default=DEFAULT_DATABASE)
generate_initial_tokens: Annotated[bool, ForceDecode] = Field(default=False)
@field_validator("generate_initial_tokens", mode="before")
@classmethod
def cast_bool(cls, value: Any) -> bool:
"""Ensure we catch the boolean."""
if isinstance(value, str):
if value.lower() in ("1", "true", "on"):
return True
if value.lower() in ("0", "false", "off"):
return False
return bool(value)
admin_token: str | None = Field(default=None, alias="sshecret_admin_backend_token")
sshd_token: str | None = Field(default=None, alias="sshecret_sshd_backend_token")
@property
def db_url(self) -> URL:
"""Construct database url."""
return URL.create(drivername="sqlite", database=self.database)
@property
def async_db_url(self) -> URL:
"""Construct database url with sync handling."""
return URL.create(drivername="sqlite+aiosqlite", database=self.database)
@property
def db_exists(self) -> bool:
"""Check if databatase exists."""

View File

@ -1,8 +1,9 @@
"""Test helpers."""
import logging
from sqlmodel import Session
from sshecret_backend.settings import BackendSettings
from sqlalchemy.orm import Session
from .models import init_db
from .db import create_api_token, setup_database
@ -14,4 +15,4 @@ def create_test_token(settings: BackendSettings) -> str:
engine, _setupdb = setup_database(settings.db_url)
with Session(engine) as session:
init_db(engine)
return create_api_token(session, True)
return create_api_token(session, "test")

View File

@ -2,7 +2,7 @@
from collections.abc import Callable, Generator
from sqlmodel import Session
from sqlalchemy.orm import Session
DBSessionDep = Callable[[], Generator[Session, None, None]]

View File

@ -4,15 +4,14 @@ import uuid
from datetime import datetime
from typing import Annotated, Self, override
from sqlmodel import Field, SQLModel
from pydantic import AfterValidator, IPvAnyAddress, IPvAnyNetwork
from pydantic import AfterValidator, BaseModel, Field, IPvAnyAddress, IPvAnyNetwork
from sshecret.crypto import public_key_validator
from . import models
class ClientView(SQLModel):
class ClientView(BaseModel):
"""View for a single client."""
id: uuid.UUID
@ -50,7 +49,7 @@ class ClientView(SQLModel):
return view
class ClientQueryResult(SQLModel):
class ClientQueryResult(BaseModel):
"""Result class for queries towards the client list."""
clients: list[ClientView] = Field(default_factory=list)
@ -58,7 +57,7 @@ class ClientQueryResult(SQLModel):
remaining_results: int
class ClientCreate(SQLModel):
class ClientCreate(BaseModel):
"""Model to create a client."""
name: str
@ -74,19 +73,19 @@ class ClientCreate(SQLModel):
)
class ClientUpdate(SQLModel):
class ClientUpdate(BaseModel):
"""Model to update the client public key."""
public_key: Annotated[str, AfterValidator(public_key_validator)]
class BodyValue(SQLModel):
class BodyValue(BaseModel):
"""A generic model with just a value parameter."""
value: str
class ClientSecretPublic(SQLModel):
class ClientSecretPublic(BaseModel):
"""Public model to manage client secrets."""
name: str
@ -122,7 +121,7 @@ class ClientSecretResponse(ClientSecretPublic):
)
class ClientPolicyView(SQLModel):
class ClientPolicyView(BaseModel):
"""Update object for client policy."""
sources: list[str] = ["0.0.0.0/0", "::/0"]
@ -135,27 +134,27 @@ class ClientPolicyView(SQLModel):
return cls(sources=[policy.source for policy in client.policies])
class ClientPolicyUpdate(SQLModel):
class ClientPolicyUpdate(BaseModel):
"""Model for updating policies."""
sources: list[IPvAnyAddress | IPvAnyNetwork]
class ClientSecretList(SQLModel):
class ClientSecretList(BaseModel):
"""Model for aggregating identically named secrets."""
name: str
clients: list[str]
class ClientReference(SQLModel):
class ClientReference(BaseModel):
"""Reference to a client."""
id: str
name: str
class ClientSecretDetailList(SQLModel):
class ClientSecretDetailList(BaseModel):
"""A more detailed version of the ClientSecretList."""
name: str
@ -163,7 +162,22 @@ class ClientSecretDetailList(SQLModel):
clients: list[ClientReference] = Field(default_factory=list)
class AuditInfo(SQLModel):
class AuditView(BaseModel):
"""Audit log view."""
id: uuid.UUID | None = None
subsystem: models.SubSystem
message: str
operation: models.Operation
data: dict[str, str] | None = None
client_id: uuid.UUID | None = None
client_name: str | None = None
origin: str | None = None
timestamp: datetime | None = None
class AuditInfo(BaseModel):
"""Information about audit information."""
entries: int