Refactor database layer and auditing

This commit is contained in:
2025-05-10 08:38:57 +02:00
parent d866553ac1
commit 9ccd2f1d4d
20 changed files with 718 additions and 469 deletions

View File

@ -3,22 +3,22 @@ from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from sqlmodel import create_engine
from alembic import context
from sshecret_backend.models import *
def get_database_url() -> str:
"""Get database URL."""
if db_file := os.getenv("SSHECRET_BACKEND_DB"):
return f"sqlite:///{db_file}"
return "sqlite:///sshecret.db"
from sshecret_backend.models import Base
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
def get_database_url() -> str | None:
"""Get database URL."""
if db_file := os.getenv("SSHECRET_BACKEND_DB"):
return f"sqlite:///{db_file}"
return config.get_main_option("sqlalchemy.url")
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
@ -28,8 +28,7 @@ if config.config_file_name is not None:
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
#target_metadata = None
target_metadata = SQLModel.metadata
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
@ -68,7 +67,11 @@ def run_migrations_online() -> None:
and associate a connection with the context.
"""
connectable = create_engine(get_database_url())
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(

View File

@ -9,7 +9,6 @@ from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel
${imports if imports else ""}
# revision identifiers, used by Alembic.

View File

@ -5,13 +5,14 @@
import logging
from collections.abc import Sequence
from fastapi import APIRouter, Depends, Request, Query
from sqlmodel import Session, col, func, select
from sqlalchemy import desc
from pydantic import TypeAdapter
from sqlalchemy import select, func
from sqlalchemy.orm import Session
from typing import Annotated
from sshecret_backend.models import AuditLog
from sshecret_backend.types import DBSessionDep
from sshecret_backend.view_models import AuditInfo
from sshecret_backend.view_models import AuditInfo, AuditView
LOG = logging.getLogger(__name__)
@ -21,7 +22,7 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
"""Construct audit sub-api."""
router = APIRouter()
@router.get("/audit/", response_model=list[AuditLog])
@router.get("/audit/", response_model=list[AuditView])
async def get_audit_logs(
request: Request,
session: Annotated[Session, Depends(get_db_session)],
@ -29,35 +30,37 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
limit: Annotated[int, Query(le=100)] = 100,
filter_client: Annotated[str | None, Query()] = None,
filter_subsystem: Annotated[str | None, Query()] = None,
) -> Sequence[AuditLog]:
) -> Sequence[AuditView]:
"""Get audit logs."""
#audit.audit_access_audit_log(session, request)
statement = select(AuditLog).offset(offset).limit(limit).order_by(desc(col(AuditLog.timestamp)))
statement = select(AuditLog).offset(offset).limit(limit).order_by(AuditLog.timestamp.desc())
if filter_client:
statement = statement.where(AuditLog.client_name == filter_client)
if filter_subsystem:
statement = statement.where(AuditLog.subsystem == filter_subsystem)
results = session.exec(statement).all()
return results
LogAdapt = TypeAdapter(list[AuditView])
results = session.scalars(statement).all()
return LogAdapt.validate_python(results, from_attributes=True)
@router.post("/audit/")
async def add_audit_log(
request: Request,
session: Annotated[Session, Depends(get_db_session)],
entry: AuditLog,
) -> AuditLog:
entry: AuditView,
) -> AuditView:
"""Add entry to audit log."""
audit_log = AuditLog.model_validate(entry.model_dump(exclude_none=True))
audit_log = AuditLog(**entry.model_dump(exclude_none=True))
session.add(audit_log)
session.commit()
return audit_log
return AuditView.model_validate(audit_log, from_attributes=True)
@router.get("/audit/info")
async def get_audit_info(request: Request, session: Annotated[Session, Depends(get_db_session)]) -> AuditInfo:
"""Get audit info."""
audit_count = session.exec(select(func.count('*')).select_from(AuditLog)).one()
audit_count = session.scalars(select(func.count('*')).select_from(AuditLog)).one()
return AuditInfo(entries=audit_count)

View File

@ -6,11 +6,11 @@ import uuid
import logging
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from pydantic import BaseModel, Field, model_validator
from sqlmodel import Session, col, select
from sqlalchemy import func
from typing import Annotated, Self, TypeVar
from typing import Annotated, Any, Self, TypeVar, cast
from sqlmodel.sql.expression import SelectOfScalar
from sqlalchemy import select, func
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select
from sshecret_backend.types import DBSessionDep
from sshecret_backend.models import Client, ClientSecret
from sshecret_backend.view_models import (
@ -55,8 +55,8 @@ T = TypeVar("T")
def filter_client_statement(
statement: SelectOfScalar[T], params: ClientListParams, ignore_limits: bool = False
) -> SelectOfScalar[T]:
statement: Select[Any], params: ClientListParams, ignore_limits: bool = False
) -> Select[Any]:
"""Filter a statement with the provided params."""
if params.id:
statement = statement.where(Client.id == params.id)
@ -64,9 +64,9 @@ def filter_client_statement(
if params.name:
statement = statement.where(Client.name == params.name)
elif params.name__like:
statement = statement.where(col(Client.name).like(params.name__like))
statement = statement.where(Client.name.like(params.name__like))
elif params.name__contains:
statement = statement.where(col(Client.name).contains(params.name__contains))
statement = statement.where(Client.name.contains(params.name__contains))
if ignore_limits:
return statement
@ -86,13 +86,13 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
"""Get clients."""
# Get total results first
count_statement = select(func.count("*")).select_from(Client)
count_statement = filter_client_statement(count_statement, filter_query, True)
count_statement = cast(Select[tuple[int]], filter_client_statement(count_statement, filter_query, True))
total_results = session.exec(count_statement).one()
total_results = session.scalars(count_statement).one()
statement = filter_client_statement(select(Client), filter_query, False)
results = session.exec(statement)
results = session.scalars(statement)
remainder = total_results - filter_query.offset - filter_query.limit
if remainder < 0:
remainder = 0
@ -170,13 +170,12 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
status_code=404, detail="Cannot find a client with the given name."
)
client.public_key = client_update.public_key
for secret in session.exec(
for secret in session.scalars(
select(ClientSecret).where(ClientSecret.client_id == client.id)
).all():
LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True
secret.client_id = None
secret.client = None
session.add(client)
session.refresh(client)
@ -206,13 +205,12 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
public_key_updated = False
if client_update.public_key != client.public_key:
public_key_updated = True
for secret in session.exec(
for secret in session.scalars(
select(ClientSecret).where(ClientSecret.client_id == client.id)
).all():
LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True
secret.client_id = None
secret.client = None
session.add(client)
session.commit()

View File

@ -4,7 +4,8 @@ import re
import uuid
import bcrypt
from sqlmodel import Session, select
from sqlalchemy import select
from sqlalchemy.orm import Session
from sshecret_backend.models import Client
@ -20,13 +21,13 @@ def verify_token(token: str, stored_hash: str) -> bool:
async def get_client_by_name(session: Session, name: str) -> Client | None:
"""Get client by name."""
client_filter = select(Client).where(Client.name == name)
client_results = session.exec(client_filter)
client_results = session.scalars(client_filter)
return client_results.first()
async def get_client_by_id(session: Session, id: uuid.UUID) -> Client | None:
"""Get client by name."""
client_filter = select(Client).where(Client.id == id)
client_results = session.exec(client_filter)
client_results = session.scalars(client_filter)
return client_results.first()
async def get_client_by_id_or_name(session: Session, id_or_name: str) -> Client | None:

View File

@ -4,7 +4,8 @@
import logging
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlmodel import Session, select
from sqlalchemy import select
from sqlalchemy.orm import Session
from typing import Annotated
from sshecret_backend.models import ClientAccessPolicy
@ -54,7 +55,7 @@ def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
status_code=404, detail="Cannot find a client with the given name."
)
# Remove old policies.
policies = session.exec(
policies = session.scalars(
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
).all()
deleted_policies: list[ClientAccessPolicy] = []

View File

@ -5,7 +5,8 @@
import logging
from collections import defaultdict
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlmodel import Session, select
from sqlalchemy import select
from sqlalchemy.orm import Session
from typing import Annotated
from sshecret_backend.models import Client, ClientSecret
@ -34,7 +35,7 @@ async def lookup_client_secret(
.where(ClientSecret.client_id == client.id)
.where(ClientSecret.name == name)
)
results = session.exec(statement)
results = session.scalars(statement)
return results.first()
@ -165,7 +166,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
) -> list[ClientSecretList]:
"""Get a list of all secrets and which clients have them."""
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
for client_secret in session.exec(select(ClientSecret)).all():
for client_secret in session.scalars(select(ClientSecret)).all():
if not client_secret.client:
if client_secret.name not in client_secret_map:
client_secret_map[client_secret.name] = []
@ -182,7 +183,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
) -> list[ClientSecretDetailList]:
"""Get a list of all secrets and which clients have them."""
client_secrets: dict[str, ClientSecretDetailList] = {}
for client_secret in session.exec(select(ClientSecret)).all():
for client_secret in session.scalars(select(ClientSecret)).all():
if client_secret.name not in client_secrets:
client_secrets[client_secret.name] = ClientSecretDetailList(name=client_secret.name)
@ -202,7 +203,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
) -> ClientSecretList:
"""Get a list of which clients has a named secret."""
clients: list[str] = []
for client_secret in session.exec(
for client_secret in session.scalars(
select(ClientSecret).where(ClientSecret.name == name)
).all():
if not client_secret.client:
@ -219,7 +220,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
) -> ClientSecretDetailList:
"""Get a list of which clients has a named secret."""
detail_list = ClientSecretDetailList(name=name)
for client_secret in session.exec(
for client_secret in session.scalars(
select(ClientSecret).where(ClientSecret.name == name)
).all():
if not client_secret.client:

View File

@ -2,9 +2,10 @@
from collections.abc import Sequence
from fastapi import Request
from sqlmodel import Session, select
from sqlalchemy import select
from sqlalchemy.orm import Session
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy, Operation, SubSystem
def _get_origin(request: Request) -> str | None:
@ -22,7 +23,7 @@ def _write_audit_log(
"""Write the audit log."""
origin = _get_origin(request)
entry.origin = origin
entry.subsystem = "backend"
entry.subsystem = SubSystem.BACKEND
session.add(entry)
if commit:
session.commit()
@ -33,7 +34,7 @@ def audit_create_client(
) -> None:
"""Log the creation of a client."""
entry = AuditLog(
operation="CREATE",
operation=Operation.CREATE,
client_id=client.id,
client_name=client.name,
message="Client Created",
@ -46,7 +47,7 @@ def audit_delete_client(
) -> None:
"""Log the creation of a client."""
entry = AuditLog(
operation="CREATE",
operation=Operation.CREATE,
client_id=client.id,
client_name=client.name,
message="Client deleted",
@ -63,9 +64,9 @@ def audit_create_secret(
) -> None:
"""Audit a create secret event."""
entry = AuditLog(
operation="CREATE",
object="ClientSecret",
object_id=str(secret.id),
operation=Operation.CREATE,
secret_id=secret.id,
secret_name=secret.name,
client_id=client.id,
client_name=client.name,
message="Added secret to client",
@ -81,13 +82,13 @@ def audit_remove_policy(
commit: bool = True,
) -> None:
"""Audit removal of policy."""
data = {"object": "ClientAccessPolicy", "object_id": str(policy.id)}
entry = AuditLog(
operation="DELETE",
object="ClientAccessPolicy",
object_id=str(policy.id),
operation=Operation.DELETE,
client_id=client.id,
client_name=client.name,
message="Deleted client policy",
data=data,
)
_write_audit_log(session, request, entry, commit)
@ -100,13 +101,13 @@ def audit_update_policy(
commit: bool = True,
) -> None:
"""Audit update of policy."""
data: dict[str, str] = {"object": "ClientAccessPolicy", "object_id": str(policy.id)}
entry = AuditLog(
operation="CREATE",
object="ClientAccessPolicy",
object_id=str(policy.id),
client_id=client.id,
operation=Operation.CREATE,
client_name=client.name,
client_id=client.id,
message="Updated client policy",
data=data,
)
_write_audit_log(session, request, entry, commit)
@ -119,11 +120,10 @@ def audit_update_client(
) -> None:
"""Audit an update secret event."""
entry = AuditLog(
operation="UPDATE",
object="Client",
operation=Operation.UPDATE,
client_id=client.id,
client_name=client.name,
message="Client updated",
message="Client data updated",
)
_write_audit_log(session, request, entry, commit)
@ -137,11 +137,11 @@ def audit_update_secret(
) -> None:
"""Audit an update secret event."""
entry = AuditLog(
operation="UPDATE",
object="ClientSecret",
object_id=str(secret.id),
operation=Operation.UPDATE,
client_id=client.id,
client_name=client.name,
secret_name=secret.name,
secret_id=secret.id,
message="Secret value updated",
)
_write_audit_log(session, request, entry, commit)
@ -155,8 +155,7 @@ def audit_invalidate_secrets(
) -> None:
"""Audit Invalidate client secrets."""
entry = AuditLog(
operation="INVALIDATE",
object="ClientSecret",
operation=Operation.UPDATE,
client_name=client.name,
client_id=client.id,
message="Client public-key changed. All secrets invalidated.",
@ -173,9 +172,9 @@ def audit_delete_secret(
) -> None:
"""Audit Delete client secrets."""
entry = AuditLog(
operation="DELETE",
object="ClientSecret",
object_id=str(secret.id),
operation=Operation.DELETE,
secret_name=secret.name,
secret_id=secret.id,
client_name=client.name,
client_id=client.id,
message="Deleted secret.",
@ -195,7 +194,7 @@ def audit_access_secrets(
With no secrets provided, all secrets of the client will be resolved.
"""
if not secrets:
secrets = session.exec(
secrets = session.scalars(
select(ClientSecret).where(ClientSecret.client_id == client.id)
).all()
@ -215,37 +214,21 @@ def audit_access_secret(
) -> None:
"""Audit that someone accessed one secrets."""
entry = AuditLog(
operation="ACCESS",
operation=Operation.READ,
message="Secret was viewed",
object="ClientSecret",
object_id=str(secret.id),
secret_name=secret.name,
secret_id=secret.id,
client_id=client.id,
client_name=client.name,
)
_write_audit_log(session, request, entry, commit)
def audit_access_audit_log(
session: Session, request: Request, commit: bool = True
) -> None:
"""Audit access to the audit log.
Because why not...
"""
entry = AuditLog(
operation="ACCESS",
message="Audit log was viewed",
object="AuditLog",
)
_write_audit_log(session, request, entry, commit)
def audit_client_secret_list(
session: Session, request: Request, commit: bool = True
) -> None:
"""Audit a list of all secrets."""
entry = AuditLog(
operation="ACCESS",
operation=Operation.READ,
message="All secret names and their clients was viewed",
)
_write_audit_log(session, request, entry, commit)

View File

@ -3,11 +3,11 @@
import logging
from typing import Annotated
import bcrypt
from fastapi import APIRouter, Depends, Header, HTTPException
from sqlmodel import Session, select
from sqlalchemy import select
from sqlalchemy.orm import Session
from .api import get_audit_api, get_clients_api, get_policy_api, get_secrets_api
from .auth import verify_token
from .models import (
APIClient,
)
@ -18,13 +18,6 @@ LOG = logging.getLogger(__name__)
API_VERSION = "v1"
def verify_token(token: str, stored_hash: str) -> bool:
"""Verify token."""
token_bytes = token.encode("utf-8")
stored_bytes = stored_hash.encode("utf-8")
return bcrypt.checkpw(token_bytes, stored_bytes)
def get_backend_api(
get_db_session: DBSessionDep,
) -> APIRouter:
@ -37,7 +30,7 @@ def get_backend_api(
"""Validate token."""
LOG.debug("Validating token %s", x_api_token)
statement = select(APIClient)
results = session.exec(statement)
results = session.scalars(statement)
valid = False
for result in results:
if verify_token(x_api_token, result.token):

View File

@ -3,15 +3,24 @@
import code
import os
from pathlib import Path
from typing import cast
from dotenv import load_dotenv
from typing import Literal, cast
import click
from sqlmodel import Session, col, func, select
import uvicorn
from dotenv import load_dotenv
from sqlalchemy import select
from sqlalchemy.orm import Session
from .db import get_engine, create_api_token
from .models import Client, ClientSecret, ClientAccessPolicy, AuditLog, APIClient, init_db
from .db import create_api_token, get_engine, hash_token
from .models import (
APIClient,
AuditLog,
Client,
ClientAccessPolicy,
ClientSecret,
SubSystem,
init_db,
)
from .settings import BackendSettings
DEFAULT_LISTEN = "127.0.0.1"
@ -21,22 +30,44 @@ WORKDIR = Path(os.getcwd())
load_dotenv()
def generate_token(settings: BackendSettings) -> str:
def generate_token(
settings: BackendSettings, subsystem: Literal["admin", "sshd"]
) -> str:
"""Generate a token."""
engine = get_engine(settings.db_url)
init_db(engine)
with Session(engine) as session:
token = create_api_token(session, True)
token = create_api_token(session, subsystem)
return token
def count_tokens(settings: BackendSettings) -> int:
"""Count the amount of tokens created."""
def add_system_tokens(settings: BackendSettings) -> None:
"""Add token for subsystems."""
if not settings.admin_token and not settings.sshd_token:
# Tokens should be generated manually.
return
engine = get_engine(settings.db_url)
init_db(engine)
tokens: list[tuple[str, SubSystem]] = []
if admin_token := settings.admin_token:
tokens.append((admin_token, SubSystem.ADMIN))
if sshd_token := settings.sshd_token:
tokens.append((sshd_token, SubSystem.SSHD))
with Session(engine) as session:
count = session.exec(select(func.count("*")).select_from(APIClient)).one()
for token, subsystem in tokens:
hashed_token = hash_token(token)
if existing := session.scalars(
select(APIClient).where(APIClient.subsystem == subsystem)
).first():
existing.token = hashed_token
else:
new_token = APIClient(token=hashed_token, subsystem=subsystem)
session.add(new_token)
return count
session.commit()
click.echo("Generated system tokens.")
@click.group()
@ -49,27 +80,30 @@ def cli(ctx: click.Context, database: str) -> None:
else:
settings = BackendSettings()
add_system_tokens(settings)
if settings.generate_initial_tokens:
if count_tokens(settings) == 0:
click.echo("Creating initial tokens for admin and sshd.")
admin_token = generate_token(settings)
sshd_token = generate_token(settings)
click.echo(f"Admin token: {admin_token}")
click.echo(f"SSHD token: {sshd_token}")
# if settings.generate_initial_tokens:
# if count_tokens(settings) == 0:
# click.echo("Creating initial tokens for admin and sshd.")
# admin_token = generate_token(settings)
# sshd_token = generate_token(settings)
# click.echo(f"Admin token: {admin_token}")
# click.echo(f"SSHD token: {sshd_token}")
ctx.obj = settings
@cli.command("generate-token")
@click.argument("subsystem", type=click.Choice(["sshd", "admin"]))
@click.pass_context
def cli_generate_token(ctx: click.Context) -> None:
"""Generate a token."""
def cli_generate_token(ctx: click.Context, subsystem: Literal["sshd", "admin"]) -> None:
"""Generate a token for a subsystem.."""
settings = cast(BackendSettings, ctx.obj)
token = generate_token(settings)
token = generate_token(settings, subsystem)
click.echo("Generated api token:")
click.echo(token)
@cli.command("run")
@click.option("--host", default="127.0.0.1")
@click.option("--port", default=8022, type=click.INT)
@ -77,7 +111,10 @@ def cli_generate_token(ctx: click.Context) -> None:
@click.option("--workers", type=click.INT)
def cli_run(host: str, port: int, dev: bool, workers: int | None) -> None:
"""Run the server."""
uvicorn.run("sshecret_backend.main:app", host=host, port=port, reload=dev, workers=workers)
uvicorn.run(
"sshecret_backend.main:app", host=host, port=port, reload=dev, workers=workers
)
@cli.command("repl")
@click.pass_context

View File

@ -2,56 +2,108 @@
import logging
import secrets
import sqlite3
from collections.abc import Generator, Callable
from pathlib import Path
from sqlalchemy import Engine
from sqlmodel import Session, create_engine, text
import bcrypt
from typing import Literal
from sqlalchemy import create_engine, Engine, event, select
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.engine import URL
from .models import APIClient
from .auth import hash_token, verify_token
from .models import APIClient, SubSystem
LOG = logging.getLogger(__name__)
def setup_database(
db_url: URL | str,
db_url: URL,
) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]:
"""Setup database."""
engine = create_engine(db_url, echo=False)
with engine.connect() as connection:
connection.execute(text("PRAGMA foreign_keys=ON")) # for SQLite only
engine = get_engine(db_url)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, future=True)
def get_db_session() -> Generator[Session, None, None]:
"""Get DB Session."""
with Session(engine) as session:
session = SessionLocal(bind=engine)
try:
yield session
finally:
session.close()
return engine, get_db_session
def get_engine(url: URL, echo: bool = False) -> Engine:
"""Initialize the engine."""
engine = create_engine(url, echo=echo)
with engine.connect() as connection:
connection.execute(text("PRAGMA foreign_keys=ON")) # for SQLite only
engine = create_engine(url, echo=echo, future=True)
if url.drivername.startswith("sqlite"):
@event.listens_for(engine, "connect")
def set_sqlite_pragma(
dbapi_connection: sqlite3.Connection, _connection_record: object
) -> None:
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
return engine
def create_api_token(session: Session, read_write: bool) -> str:
"""Create API token."""
token = secrets.token_urlsafe(32)
pwbytes = token.encode("utf-8")
salt = bcrypt.gensalt()
hashed_bytes = bcrypt.hashpw(password=pwbytes, salt=salt)
hashed = hashed_bytes.decode()
def get_async_engine(url: URL, echo: bool = False) -> AsyncEngine:
"""Get an async engine."""
engine = create_async_engine(url, echo=echo, future=True)
if url.drivername.startswith("sqlite+"):
api_token = APIClient(token=hashed, read_write=read_write)
@event.listens_for(engine, "connect")
def set_sqlite_pragma(
dbapi_connection: sqlite3.Connection, _connection_record: object
) -> None:
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
return engine
def create_api_token_with_value(session: Session, token: str, subsystem: Literal["admin", "sshd"]) -> None:
"""Create API token with a given value."""
existing = session.scalars(select(APIClient).where(APIClient.subsystem == SubSystem(subsystem))).first()
if existing:
if verify_token(token, existing.token):
LOG.info("Token is up to date.")
return
LOG.info("Updating token value for subsystem %s", subsystem)
hashed = hash_token(token)
existing.token=hashed
session.commit()
return
LOG.info("No existing token found. Creating new")
hashed = hash_token(token)
api_token = APIClient(token=hashed, subsystem=SubSystem(subsystem))
session.add(api_token)
session.commit()
def create_api_token(session: Session, subsystem: Literal["admin", "sshd", "test"], recreate: bool = False) -> str:
"""Create API token."""
subsys = SubSystem(subsystem)
token = secrets.token_urlsafe(32)
hashed = hash_token(token)
if existing := session.scalars(select(APIClient).where(APIClient.subsystem == subsys)).first():
if not recreate:
raise RuntimeError("Error: A token already exist for this subsystem.")
existing.token = hashed
else:
api_token = APIClient(token=hashed, subsystem=subsys)
session.add(api_token)
session.commit()

View File

@ -7,128 +7,182 @@ This might require some changes to these schemas.
"""
import enum
import logging
import uuid
from datetime import datetime
import sqlalchemy as sa
from sqlmodel import JSON, Column, DateTime, Field, Relationship, SQLModel
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
LOG = logging.getLogger(__name__)
class Client(SQLModel, table=True):
"""Client model."""
class SubSystem(enum.StrEnum):
"""Available subsystems."""
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
name: str = Field(unique=True)
description: str | None = None
public_key: str
ADMIN = enum.auto()
SSHD = enum.auto()
BACKEND = enum.auto()
TEST = enum.auto()
created_at: datetime | None = Field(
default=None,
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"server_default": sa.func.now()},
nullable=False,
class Operation(enum.StrEnum):
"""Various operations for the audit logging module."""
CREATE = enum.auto()
READ = enum.auto()
UPDATE = enum.auto()
DELETE = enum.auto()
DENY = enum.auto()
PERMIT = enum.auto()
LOGIN = enum.auto()
REGISTER = enum.auto()
NONE = enum.auto()
class Base(DeclarativeBase):
pass
class Client(Base):
"""Clients."""
__tablename__: str = "client"
id: Mapped[uuid.UUID] = mapped_column(
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
)
name: Mapped[str] = mapped_column(sa.String, unique=True)
description: Mapped[str | None] = mapped_column(sa.String, nullable=True)
public_key: Mapped[str] = mapped_column(sa.Text)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
)
updated_at: datetime | None = Field(
default=None,
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
updated_at: Mapped[datetime | None] = mapped_column(
sa.DateTime(timezone=True),
server_default=sa.func.now(),
onupdate=sa.func.now(),
)
secrets: list["ClientSecret"] = Relationship(
back_populates="client", passive_deletes="all"
secrets: Mapped[list["ClientSecret"]] = relationship(
back_populates="client", passive_deletes=True
)
policies: list["ClientAccessPolicy"] = Relationship(back_populates="client")
policies: Mapped[list["ClientAccessPolicy"]] = relationship(back_populates="client")
class ClientAccessPolicy(SQLModel, table=True):
class ClientAccessPolicy(Base):
"""Client access policies."""
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
source: str
client_id: uuid.UUID | None = Field(foreign_key="client.id", ondelete="CASCADE")
client: Client | None = Relationship(back_populates="policies")
__tablename__: str = "client_access_policy"
created_at: datetime | None = Field(
default=None,
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"server_default": sa.func.now()},
nullable=False,
id: Mapped[uuid.UUID] = mapped_column(
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
)
source: Mapped[str] = mapped_column(sa.String)
client_id: Mapped[uuid.UUID | None] = mapped_column(
sa.Uuid(as_uuid=True), sa.ForeignKey("client.id", ondelete="CASCADE")
)
client: Mapped[Client] = relationship(back_populates="policies")
created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
)
updated_at: datetime | None = Field(
default=None,
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
updated_at: Mapped[datetime | None] = mapped_column(
sa.DateTime(timezone=True),
server_default=sa.func.now(),
onupdate=sa.func.now(),
)
class ClientSecret(SQLModel, table=True):
class ClientSecret(Base):
"""A client secret."""
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
name: str
description: str | None = None
client_id: uuid.UUID | None = Field(foreign_key="client.id", ondelete="CASCADE")
client: Client | None = Relationship(back_populates="secrets")
secret: str
invalidated: bool = Field(default=False)
created_at: datetime | None = Field(
default=None,
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"server_default": sa.func.now()},
nullable=False,
__tablename__: str = "client_secret"
id: Mapped[uuid.UUID] = mapped_column(
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
)
name: Mapped[str] = mapped_column(sa.String)
description: Mapped[str | None] = mapped_column(sa.String, nullable=True)
secret: Mapped[str] = mapped_column(sa.String)
client_id: Mapped[uuid.UUID | None] = mapped_column(
sa.Uuid(as_uuid=True), sa.ForeignKey("client.id", ondelete="CASCADE")
)
client: Mapped[Client] = relationship(back_populates="secrets")
invalidated: Mapped[bool] = mapped_column(default=False)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
)
updated_at: datetime | None = Field(
default=None,
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
updated_at: Mapped[datetime | None] = mapped_column(
sa.DateTime(timezone=True),
server_default=sa.func.now(),
onupdate=sa.func.now(),
)
class AuditLog(SQLModel, table=True):
class APIClient(Base):
"""A client on the API.
This should eventually get more granular permissions.
"""
__tablename__: str = "api_client"
id: Mapped[uuid.UUID] = mapped_column(
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
)
subsystem: Mapped[SubSystem | None] = mapped_column(sa.String, nullable=True)
token: Mapped[str] = mapped_column(sa.String)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
)
updated_at: Mapped[datetime | None] = mapped_column(
sa.DateTime(timezone=True),
server_default=sa.func.now(),
onupdate=sa.func.now(),
)
class AuditLog(Base):
"""Audit log.
This is implemented without any foreign keys to avoid losing data on
deletions.
"""
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
subsystem: str
message: str
operation: str
client_id: uuid.UUID | None = None
client_name: str | None = None
origin: str | None = None
Field(default=None, sa_column=Column(JSON))
timestamp: datetime | None = Field(
default=None,
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"server_default": sa.func.now()},
nullable=False,
__tablename__: str = "audit_log"
id: Mapped[uuid.UUID] = mapped_column(
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
)
subsystem: Mapped[SubSystem] = mapped_column(sa.String)
message: Mapped[str] = mapped_column(sa.String)
operation: Mapped[Operation] = mapped_column(sa.String)
client_id: Mapped[uuid.UUID | None] = mapped_column(
sa.Uuid(as_uuid=True), nullable=True
)
data: Mapped[dict[str, str] | None] = mapped_column(sa.JSON, nullable=True)
client_name: Mapped[str | None] = mapped_column(sa.String, nullable=True)
secret_id: Mapped[uuid.UUID | None] = mapped_column(
sa.Uuid(as_uuid=True), nullable=True
)
secret_name: Mapped[str | None] = mapped_column(sa.String, nullable=True)
class APIClient(SQLModel, table=True):
"""Stores API Keys."""
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
token: str
read_write: bool
created_at: datetime | None = Field(
default=None,
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"server_default": sa.func.now()},
nullable=False,
origin: Mapped[str | None] = mapped_column(sa.String, nullable=True)
timestamp: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
)
def init_db(engine: sa.Engine) -> None:
"""Create database."""
LOG.info("Running init_db")
SQLModel.metadata.create_all(engine)
"""Initialize database."""
Base.metadata.create_all(engine)

View File

@ -1,12 +1,10 @@
"""Settings management."""
from pathlib import Path
from typing import Annotated, Any
from pydantic import Field, field_validator
from pydantic import Field
from pydantic_settings import (
BaseSettings,
SettingsConfigDict,
ForceDecode,
)
from sqlalchemy import URL
@ -22,24 +20,19 @@ class BackendSettings(BaseSettings):
)
database: str = Field(default=DEFAULT_DATABASE)
generate_initial_tokens: Annotated[bool, ForceDecode] = Field(default=False)
@field_validator("generate_initial_tokens", mode="before")
@classmethod
def cast_bool(cls, value: Any) -> bool:
"""Ensure we catch the boolean."""
if isinstance(value, str):
if value.lower() in ("1", "true", "on"):
return True
if value.lower() in ("0", "false", "off"):
return False
return bool(value)
admin_token: str | None = Field(default=None, alias="sshecret_admin_backend_token")
sshd_token: str | None = Field(default=None, alias="sshecret_sshd_backend_token")
@property
def db_url(self) -> URL:
"""Construct database url."""
return URL.create(drivername="sqlite", database=self.database)
@property
def async_db_url(self) -> URL:
"""Construct database url with sync handling."""
return URL.create(drivername="sqlite+aiosqlite", database=self.database)
@property
def db_exists(self) -> bool:
"""Check if databatase exists."""

View File

@ -1,8 +1,9 @@
"""Test helpers."""
import logging
from sqlmodel import Session
from sshecret_backend.settings import BackendSettings
from sqlalchemy.orm import Session
from .models import init_db
from .db import create_api_token, setup_database
@ -14,4 +15,4 @@ def create_test_token(settings: BackendSettings) -> str:
engine, _setupdb = setup_database(settings.db_url)
with Session(engine) as session:
init_db(engine)
return create_api_token(session, True)
return create_api_token(session, "test")

View File

@ -2,7 +2,7 @@
from collections.abc import Callable, Generator
from sqlmodel import Session
from sqlalchemy.orm import Session
DBSessionDep = Callable[[], Generator[Session, None, None]]

View File

@ -4,15 +4,14 @@ import uuid
from datetime import datetime
from typing import Annotated, Self, override
from sqlmodel import Field, SQLModel
from pydantic import AfterValidator, IPvAnyAddress, IPvAnyNetwork
from pydantic import AfterValidator, BaseModel, Field, IPvAnyAddress, IPvAnyNetwork
from sshecret.crypto import public_key_validator
from . import models
class ClientView(SQLModel):
class ClientView(BaseModel):
"""View for a single client."""
id: uuid.UUID
@ -50,7 +49,7 @@ class ClientView(SQLModel):
return view
class ClientQueryResult(SQLModel):
class ClientQueryResult(BaseModel):
"""Result class for queries towards the client list."""
clients: list[ClientView] = Field(default_factory=list)
@ -58,7 +57,7 @@ class ClientQueryResult(SQLModel):
remaining_results: int
class ClientCreate(SQLModel):
class ClientCreate(BaseModel):
"""Model to create a client."""
name: str
@ -74,19 +73,19 @@ class ClientCreate(SQLModel):
)
class ClientUpdate(SQLModel):
class ClientUpdate(BaseModel):
"""Model to update the client public key."""
public_key: Annotated[str, AfterValidator(public_key_validator)]
class BodyValue(SQLModel):
class BodyValue(BaseModel):
"""A generic model with just a value parameter."""
value: str
class ClientSecretPublic(SQLModel):
class ClientSecretPublic(BaseModel):
"""Public model to manage client secrets."""
name: str
@ -122,7 +121,7 @@ class ClientSecretResponse(ClientSecretPublic):
)
class ClientPolicyView(SQLModel):
class ClientPolicyView(BaseModel):
"""Update object for client policy."""
sources: list[str] = ["0.0.0.0/0", "::/0"]
@ -135,27 +134,27 @@ class ClientPolicyView(SQLModel):
return cls(sources=[policy.source for policy in client.policies])
class ClientPolicyUpdate(SQLModel):
class ClientPolicyUpdate(BaseModel):
"""Model for updating policies."""
sources: list[IPvAnyAddress | IPvAnyNetwork]
class ClientSecretList(SQLModel):
class ClientSecretList(BaseModel):
"""Model for aggregating identically named secrets."""
name: str
clients: list[str]
class ClientReference(SQLModel):
class ClientReference(BaseModel):
"""Reference to a client."""
id: str
name: str
class ClientSecretDetailList(SQLModel):
class ClientSecretDetailList(BaseModel):
"""A more detailed version of the ClientSecretList."""
name: str
@ -163,7 +162,22 @@ class ClientSecretDetailList(SQLModel):
clients: list[ClientReference] = Field(default_factory=list)
class AuditInfo(SQLModel):
class AuditView(BaseModel):
"""Audit log view."""
id: uuid.UUID | None = None
subsystem: models.SubSystem
message: str
operation: models.Operation
data: dict[str, str] | None = None
client_id: uuid.UUID | None = None
client_name: str | None = None
origin: str | None = None
timestamp: datetime | None = None
class AuditInfo(BaseModel):
"""Information about audit information."""
entries: int

View File

@ -10,7 +10,7 @@ from fastapi.testclient import TestClient
from sshecret.crypto import generate_private_key, generate_public_key_string
from sshecret_backend.app import create_backend_app
from sshecret_backend.testing import create_test_token
from sshecret_backend.models import AuditLog
from sshecret_backend.view_models import AuditView
from sshecret_backend.settings import BackendSettings
@ -53,7 +53,7 @@ def create_client_fixture(tmp_path: Path):
db_file = tmp_path / "backend.db"
print(f"DB File: {db_file.absolute()}")
settings = BackendSettings(db_url=f"sqlite:///{db_file.absolute()}")
settings = BackendSettings(database=str(db_file.absolute()))
app = create_backend_app(settings)
token = create_test_token(settings)
@ -213,7 +213,7 @@ def test_audit_logging(test_client: TestClient) -> None:
assert len(audit_logs) > 0
for entry in audit_logs:
# Let's try to reassemble the objects
audit_log = AuditLog.model_validate(entry)
audit_log = AuditView.model_validate(entry)
assert audit_log is not None
@ -522,9 +522,8 @@ def test_operations_with_id(test_client: TestClient) -> None:
def test_write_audit_log(test_client: TestClient) -> None:
"""Test writing to the audit log."""
params = {
"object": "Test",
"operation": "TEST",
"object_id": "Something",
"subsystem": "backend",
"operation": "read",
"message": "Test Message"
}
resp = test_client.post("/api/v1/audit", json=params)

View File

@ -8,8 +8,10 @@ from .models import (
ClientReference,
ClientSecret,
DetailedSecrets,
Operation,
Policy,
Secret,
SubSystem,
FilterType,
)
@ -22,7 +24,9 @@ __all__ = [
"ClientSecret",
"DetailedSecrets",
"FilterType",
"Operation",
"Policy",
"Secret",
"SubSystem",
"SshecretBackend",
]

View File

@ -1,21 +1,26 @@
"""Backend client.
This is an API calling the HTTP API of the sshecret-backend package so that the
admin and sshd library do not need to implement the same
"""
import logging
from typing import Any, Self
from typing import Any, Self, override
import httpx
from pydantic import TypeAdapter
from .models import (
AuditInfo,
AuditLog,
Client,
ClientSecret,
ClientQueryResult,
ClientFilter,
DetailedSecrets,
Operation,
Secret,
SubSystem,
)
from .exceptions import BackendValidationError, BackendConnectionError
from .utils import validate_public_key
@ -84,12 +89,14 @@ class ClientQueryIterator:
raise StopAsyncIteration
class SshecretBackend:
"""Backend interface."""
class BaseBackend:
"""Base backend class."""
def __init__(self, backend_url: str, api_token: str) -> None:
"""Initialize backend client."""
self._backend_url: str = backend_url
self._api_token: str = api_token
url = httpx.URL(backend_url)
self.http_client: httpx.AsyncClient = httpx.AsyncClient(
@ -101,31 +108,55 @@ class SshecretBackend:
base_url=url,
)
async def _get(self, path: str) -> httpx.Response:
def _handle_response(self, response: httpx.Response) -> httpx.Response:
"""Handle response."""
LOG.debug("Handling response with status_code %s", response.status_code)
if response.status_code == 422:
LOG.error("Validation error from backend:\n%s", response.text)
raise BackendValidationError(response.text)
if response.status_code != 404 and str(response.status_code).startswith("4"):
raise BackendConnectionError("Error from backend: %s", response.text)
return response
async def _get(
self, path: str, params: dict[str, str] | None = None
) -> httpx.Response:
"""Perform a get request."""
try:
return await self.http_client.get(path)
response = await self.http_client.get(path, params=params)
return self._handle_response(response)
except httpx.ConnectError as e:
raise BackendConnectionError() from e
raise BackendConnectionError("Could not connect to backend.") from e
async def _delete(self, path: str) -> httpx.Response:
"""Perform a delete request."""
try:
return await self.http_client.delete(path)
response = await self.http_client.delete(path)
return self._handle_response(response)
except httpx.ConnectError as e:
raise BackendConnectionError() from e
async def _post(self, path: str, json: Any | None = None) -> httpx.Response:
"""Perform a POST request."""
try:
return await self.http_client.post(path, json=json)
response = await self.http_client.post(path, json=json)
return self._handle_response(response)
except httpx.ConnectError as e:
raise BackendConnectionError() from e
def _post_sync(self, path: str, json: Any | None = None) -> httpx.Response:
"""Perform a synchronous post request."""
try:
return self._handle_response(self.sync_client.post(path, json=json))
except httpx.ConnectError as e:
raise BackendConnectionError() from e
async def _put(self, path: str, json: Any | None = None) -> httpx.Response:
"""Perform a PUT request."""
try:
return await self.http_client.put(path, json=json)
response = await self.http_client.put(path, json=json)
return self._handle_response(response)
except httpx.ConnectError as e:
raise BackendConnectionError() from e
@ -134,175 +165,90 @@ class SshecretBackend:
response = await self._get(path)
return response
async def create_client(
self, name: str, public_key: str, description: str | None = None
) -> None:
"""Register a new client."""
if not validate_public_key(public_key):
raise BackendValidationError("Error: Invalid public key format.")
data = {
"name": name,
"public_key": public_key,
}
if description:
data["description"] = description
path = "/api/v1/clients/"
response = await self._post(path, json=data)
response.raise_for_status()
class AuditAPI(BaseBackend):
"""API for the audit logging."""
async def get_clients(self, filter: ClientFilter | None = None) -> list[Client]:
"""Get all clients."""
clients: list[Client] = []
async for client in ClientQueryIterator(self.http_client, filter):
clients.append(client)
@override
def __init__(self, backend_url: str, api_token: str, subsystem: str) -> None:
"""Initialize backend client."""
super().__init__(backend_url, api_token)
self.subsystem: SubSystem = SubSystem(subsystem)
return clients
async def get_client(self, name: str) -> Client | None:
"""Lookup a client on username."""
path = f"/api/v1/clients/{name}"
response = await self.request(path)
if response.status_code == 404:
return None
response.raise_for_status()
client = Client.model_validate(response.json())
return client
async def get_client_by_id(self, id: str) -> Client | None:
"""Lookup a client on username."""
path = f"/api/v1/clients/id/{id}"
response = await self.request(path)
if response.status_code == 404:
return None
response.raise_for_status()
client = Client.model_validate(response.json())
return client
async def delete_client(self, client_name: str) -> None:
"""Delete a client."""
path = f"/api/v1/clients/{client_name}"
response = await self._delete(path)
response.raise_for_status()
async def delete_client_by_id(self, id: str) -> None:
"""Delete a client."""
path = f"/api/v1/clients/id/{id}"
response = await self._delete(path)
response.raise_for_status()
async def create_client_secret(
self, client_name: str, secret_name: str, encrypted_secret: str
) -> None:
"""Create a secret.
This will overwrite any existing secret with that name.
"""
path = f"api/v1/clients/{client_name}/secrets/{secret_name}"
response = await self._put(path, json={"value": encrypted_secret})
response.raise_for_status()
async def get_client_secret(self, name: str, secret_name: str) -> str:
"""Fetch a secret."""
path = f"/api/v1/clients/{name}/secrets/{secret_name}"
response = await self.request(path)
response.raise_for_status()
secret = ClientSecret.model_validate(response.json())
return secret.secret
async def delete_client_secret(self, client_name: str, secret_name: str) -> None:
"""Delete a secret from a client."""
path = f"api/v1/clients/{client_name}/secrets/{secret_name}"
response = await self._delete(path)
response.raise_for_status()
async def update_client(self, client: Client) -> Client:
"""Update the client."""
path = f"/api/v1/clients/{client.name}"
client_update = {
"name": client.name,
"description": client.description,
"public_key": client.public_key,
}
response = await self._put(path, json=client_update)
LOG.info("Response %s", response.text)
response.raise_for_status()
if client.policies:
await self.update_client_sources(
str(client.id), [str(source) for source in client.policies]
def _create_model(
self,
operation: Operation,
message: str,
origin: str,
client: Client | None = None,
secret: ClientSecret | None = None,
secret_name: str | None = None,
data: dict[str, str] | None = None,
) -> AuditLog:
"""Create the audit log object."""
model = AuditLog(
subsystem=self.subsystem,
operation=operation,
message=message,
origin=origin,
)
return client
if client:
model.client_id = str(client.id) or None
model.client_name = client.name
async def update_client_key(self, client_name: str, public_key: str) -> None:
"""Update the client key."""
path = f"/api/v1/clients/{client_name}/public-key"
response = await self._post(path, json={"public_key": public_key})
if secret:
model.secret_name = secret.name
elif secret_name:
model.secret_name = secret_name
if data:
model.data = data
response.raise_for_status()
return model
async def update_client_sources(
self, client_name: str, addresses: list[str] | None
def write_model(self, model: AuditLog) -> None:
"""Write model."""
path = f"/api/v1/audit/"
self.sync_client.post(path, json=model.model_dump())
async def write_model_async(self, model: AuditLog) -> None:
"""Write model async."""
path = f"/api/v1/audit/"
await self._post(path, json=model.model_dump())
def write(
self,
operation: Operation,
message: str,
origin: str,
client: Client | None = None,
secret: ClientSecret | None = None,
secret_name: str | None = None,
**data: str,
) -> None:
"""Update client source addresses.
"""Write an audit entry."""
model = self._create_model(
operation, message, origin, client, secret, secret_name, data
)
Pass None to sources to allow from all.
"""
if not addresses:
addresses = []
self.write_model(model)
path = f"/api/v1/clients/{client_name}/policies/"
response = await self._put(path, json={"sources": addresses})
async def write_async(
self,
operation: Operation,
message: str,
origin: str,
client: Client | None = None,
secret: ClientSecret | None = None,
secret_name: str | None = None,
**data: str,
) -> None:
"""Write an audit entry."""
model = self._create_model(
operation, message, origin, client, secret, secret_name, data
)
await self.write_model_async(model)
response.raise_for_status()
async def get_detailed_secrets(self) -> list[DetailedSecrets]:
"""Get detailed list of secrets."""
path = "/api/v1/secrets/detailed/"
response = await self._get(path)
response.raise_for_status()
secret_list = TypeAdapter(list[DetailedSecrets])
return secret_list.validate_python(response.json())
async def get_secrets(self) -> list[Secret]:
"""Get Secrets.
This provides a list of secret names and which clients have them.
"""
path = "/api/v1/secrets/"
response = await self._get(path)
response.raise_for_status()
secret_list = TypeAdapter(list[Secret])
return secret_list.validate_python(response.json())
async def get_secret(self, name: str) -> Secret | None:
"""Get clients mapped to a single secret."""
path = f"/api/v1/secrets/{name}"
response = await self._get(path)
if response.status_code == 404:
return None
response.raise_for_status()
return Secret.model_validate(response.json())
async def get_detailed_secret(self, name: str) -> DetailedSecrets | None:
"""Get clients mapped to a single secret."""
path = f"/api/v1/secrets/{name}/detailed"
response = await self._get(path)
if response.status_code == 404:
return None
response.raise_for_status()
return DetailedSecrets.model_validate(response.json())
async def get_audit_log(
async def get(
self,
offset: int = 0,
limit: int = 100,
@ -321,30 +267,169 @@ class SshecretBackend:
if subsystem:
params["filter_subsystem"] = subsystem
response = await self.http_client.get(path, params=params)
response.raise_for_status()
response = await self._get(path, params=params)
audit_log_adapter = TypeAdapter(list[AuditLog])
return audit_log_adapter.validate_python(response.json())
async def add_audit_log(self, entry: AuditLog) -> None:
"""Add audit log entry."""
path = f"/api/v1/audit/"
response = await self.http_client.post(path, json=entry.model_dump())
response.raise_for_status()
async def get_audit_log_count(self) -> int:
async def count(self) -> int:
"""Get amount of messages in the audit log."""
path = f"/api/v1/audit/info"
response = await self._get(path)
response.raise_for_status()
data = response.json()
return int(data["entries"])
audit_info = AuditInfo.model_validate(response.json())
return audit_info.entries
def add_audit_log_sync(self, entry: AuditLog) -> None:
"""Add audit log entry."""
path = f"/api/v1/audit/"
LOG.info("AUDIT LOG SYNC %r", entry)
response = self.sync_client.post(path, json=entry.model_dump())
response.raise_for_status()
class SshecretBackend(BaseBackend):
"""Backend interface."""
async def create_client(
self, name: str, public_key: str, description: str | None = None
) -> None:
"""Register a new client."""
if not validate_public_key(public_key):
raise BackendValidationError("Error: Invalid public key format.")
data = {
"name": name,
"public_key": public_key,
}
if description:
data["description"] = description
path = "/api/v1/clients/"
response = await self._post(path, json=data)
async def get_clients(self, filter: ClientFilter | None = None) -> list[Client]:
"""Get all clients."""
clients: list[Client] = []
async for client in ClientQueryIterator(self.http_client, filter):
clients.append(client)
return clients
async def get_client(self, name: str) -> Client | None:
"""Lookup a client on username."""
path = f"/api/v1/clients/{name}"
response = await self._get(path)
if response.status_code == 404:
return None
client = Client.model_validate(response.json())
return client
async def get_client_by_id(self, id: str) -> Client | None:
"""Lookup a client on username."""
path = f"/api/v1/clients/id/{id}"
response = await self._get(path)
if response.status_code == 404:
return None
client = Client.model_validate(response.json())
return client
async def delete_client(self, client_name: str) -> None:
"""Delete a client."""
path = f"/api/v1/clients/{client_name}"
response = await self._delete(path)
async def delete_client_by_id(self, id: str) -> None:
"""Delete a client."""
path = f"/api/v1/clients/id/{id}"
response = await self._delete(path)
async def create_client_secret(
self, client_name: str, secret_name: str, encrypted_secret: str
) -> None:
"""Create a secret.
This will overwrite any existing secret with that name.
"""
path = f"api/v1/clients/{client_name}/secrets/{secret_name}"
response = await self._put(path, json={"value": encrypted_secret})
async def get_client_secret(self, name: str, secret_name: str) -> str | None:
"""Fetch a secret."""
path = f"/api/v1/clients/{name}/secrets/{secret_name}"
response = await self._get(path)
if response.status_code == 404:
return None
secret = ClientSecret.model_validate(response.json())
return secret.secret
async def delete_client_secret(self, client_name: str, secret_name: str) -> None:
"""Delete a secret from a client."""
path = f"api/v1/clients/{client_name}/secrets/{secret_name}"
await self._delete(path)
async def update_client(self, client: Client) -> Client:
"""Update the client."""
path = f"/api/v1/clients/{client.name}"
client_update = {
"name": client.name,
"description": client.description,
"public_key": client.public_key,
}
response = await self._put(path, json=client_update)
LOG.info("Response %s", response.text)
if client.policies:
await self.update_client_sources(
str(client.id), [str(source) for source in client.policies]
)
return client
async def update_client_key(self, client_name: str, public_key: str) -> None:
"""Update the client key."""
path = f"/api/v1/clients/{client_name}/public-key"
await self._post(path, json={"public_key": public_key})
async def update_client_sources(
self, client_name: str, addresses: list[str] | None
) -> None:
"""Update client source addresses.
Pass None to sources to allow from all.
"""
if not addresses:
addresses = []
path = f"/api/v1/clients/{client_name}/policies/"
await self._put(path, json={"sources": addresses})
async def get_detailed_secrets(self) -> list[DetailedSecrets]:
"""Get detailed list of secrets."""
path = "/api/v1/secrets/detailed/"
response = await self._get(path)
secret_list = TypeAdapter(list[DetailedSecrets])
return secret_list.validate_python(response.json())
async def get_secrets(self) -> list[Secret]:
"""Get Secrets.
This provides a list of secret names and which clients have them.
"""
path = "/api/v1/secrets/"
response = await self._get(path)
secret_list = TypeAdapter(list[Secret])
return secret_list.validate_python(response.json())
async def get_secret(self, name: str) -> Secret | None:
"""Get clients mapped to a single secret."""
path = f"/api/v1/secrets/{name}"
response = await self._get(path)
if response.status_code == 404:
return None
return Secret.model_validate(response.json())
async def get_detailed_secret(self, name: str) -> DetailedSecrets | None:
"""Get clients mapped to a single secret."""
path = f"/api/v1/secrets/{name}/detailed"
response = await self._get(path)
if response.status_code == 404:
return None
return DetailedSecrets.model_validate(response.json())
def audit(self, subsystem: SubSystem) -> AuditAPI:
"""Create the audit API."""
audit = AuditAPI(self._backend_url, self._api_token, subsystem)
return audit

View File

@ -17,6 +17,27 @@ class FilterType(enum.StrEnum):
CONTAINS = "contains"
class SubSystem(enum.StrEnum):
"""Available subsystems."""
ADMIN = enum.auto()
SSHD = enum.auto()
BACKEND = enum.auto()
class Operation(enum.StrEnum):
"""Various operations for the audit logging module."""
CREATE = enum.auto()
READ = enum.auto()
UPDATE = enum.auto()
DELETE = enum.auto()
DENY = enum.auto()
PERMIT = enum.auto()
LOGIN = enum.auto()
NONE = enum.auto()
class Client(BaseModel):
"""Implementation of the backend class ClientView."""
@ -101,15 +122,22 @@ class ClientFilter(BaseModel):
class AuditLog(BaseModel):
"""Implementation of the backend class AuditLog."""
"""Implementation of the backend class AuditView."""
id: str | None = None
subsystem: str | None = None
object: str | None = None
object_id: str | None = None
operation: str
subsystem: SubSystem
operation: Operation
client_id: str | None = None
client_name: str | None = None
secret_id: str | None = None
secret_name: str | None = None
data: dict[str, str] | None = None
message: str
origin: str | None = None
timestamp: datetime | None = None
class AuditInfo(BaseModel):
"""Implementation of the backend class AuditInfo."""
entries: int