diff --git a/packages/sshecret-backend/migrations/env.py b/packages/sshecret-backend/migrations/env.py index 1e35f40..d2797e8 100644 --- a/packages/sshecret-backend/migrations/env.py +++ b/packages/sshecret-backend/migrations/env.py @@ -3,22 +3,22 @@ from logging.config import fileConfig from sqlalchemy import engine_from_config from sqlalchemy import pool -from sqlmodel import create_engine from alembic import context -from sshecret_backend.models import * - -def get_database_url() -> str: - """Get database URL.""" - if db_file := os.getenv("SSHECRET_BACKEND_DB"): - return f"sqlite:///{db_file}" - return "sqlite:///sshecret.db" - +from sshecret_backend.models import Base # this is the Alembic Config object, which provides # access to the values within the .ini file in use. config = context.config + +def get_database_url() -> str | None: + """Get database URL.""" + if db_file := os.getenv("SSHECRET_BACKEND_DB"): + return f"sqlite:///{db_file}" + return config.get_main_option("sqlalchemy.url") + + # Interpret the config file for Python logging. # This line sets up loggers basically. if config.config_file_name is not None: @@ -28,8 +28,7 @@ if config.config_file_name is not None: # for 'autogenerate' support # from myapp import mymodel # target_metadata = mymodel.Base.metadata -#target_metadata = None -target_metadata = SQLModel.metadata +target_metadata = Base.metadata # other values from the config, defined by the needs of env.py, # can be acquired: @@ -68,7 +67,11 @@ def run_migrations_online() -> None: and associate a connection with the context. """ - connectable = create_engine(get_database_url()) + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) with connectable.connect() as connection: context.configure( diff --git a/packages/sshecret-backend/migrations/script.py.mako b/packages/sshecret-backend/migrations/script.py.mako index 81f5923..480b130 100644 --- a/packages/sshecret-backend/migrations/script.py.mako +++ b/packages/sshecret-backend/migrations/script.py.mako @@ -9,7 +9,6 @@ from typing import Sequence, Union from alembic import op import sqlalchemy as sa -import sqlmodel ${imports if imports else ""} # revision identifiers, used by Alembic. diff --git a/packages/sshecret-backend/src/sshecret_backend/api/audit.py b/packages/sshecret-backend/src/sshecret_backend/api/audit.py index f429ec6..00eabd4 100644 --- a/packages/sshecret-backend/src/sshecret_backend/api/audit.py +++ b/packages/sshecret-backend/src/sshecret_backend/api/audit.py @@ -5,13 +5,14 @@ import logging from collections.abc import Sequence from fastapi import APIRouter, Depends, Request, Query -from sqlmodel import Session, col, func, select -from sqlalchemy import desc +from pydantic import TypeAdapter +from sqlalchemy import select, func +from sqlalchemy.orm import Session from typing import Annotated from sshecret_backend.models import AuditLog from sshecret_backend.types import DBSessionDep -from sshecret_backend.view_models import AuditInfo +from sshecret_backend.view_models import AuditInfo, AuditView LOG = logging.getLogger(__name__) @@ -21,7 +22,7 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter: """Construct audit sub-api.""" router = APIRouter() - @router.get("/audit/", response_model=list[AuditLog]) + @router.get("/audit/", response_model=list[AuditView]) async def get_audit_logs( request: Request, session: Annotated[Session, Depends(get_db_session)], @@ -29,35 +30,37 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter: limit: Annotated[int, Query(le=100)] = 100, filter_client: Annotated[str | None, Query()] = None, filter_subsystem: Annotated[str | None, Query()] = None, - ) -> Sequence[AuditLog]: + ) -> Sequence[AuditView]: """Get audit logs.""" #audit.audit_access_audit_log(session, request) - statement = select(AuditLog).offset(offset).limit(limit).order_by(desc(col(AuditLog.timestamp))) + statement = select(AuditLog).offset(offset).limit(limit).order_by(AuditLog.timestamp.desc()) if filter_client: statement = statement.where(AuditLog.client_name == filter_client) if filter_subsystem: statement = statement.where(AuditLog.subsystem == filter_subsystem) - results = session.exec(statement).all() - return results + LogAdapt = TypeAdapter(list[AuditView]) + results = session.scalars(statement).all() + return LogAdapt.validate_python(results, from_attributes=True) + @router.post("/audit/") async def add_audit_log( request: Request, session: Annotated[Session, Depends(get_db_session)], - entry: AuditLog, - ) -> AuditLog: + entry: AuditView, + ) -> AuditView: """Add entry to audit log.""" - audit_log = AuditLog.model_validate(entry.model_dump(exclude_none=True)) + audit_log = AuditLog(**entry.model_dump(exclude_none=True)) session.add(audit_log) session.commit() - return audit_log + return AuditView.model_validate(audit_log, from_attributes=True) @router.get("/audit/info") async def get_audit_info(request: Request, session: Annotated[Session, Depends(get_db_session)]) -> AuditInfo: """Get audit info.""" - audit_count = session.exec(select(func.count('*')).select_from(AuditLog)).one() + audit_count = session.scalars(select(func.count('*')).select_from(AuditLog)).one() return AuditInfo(entries=audit_count) diff --git a/packages/sshecret-backend/src/sshecret_backend/api/clients.py b/packages/sshecret-backend/src/sshecret_backend/api/clients.py index 56a3b3c..536e176 100644 --- a/packages/sshecret-backend/src/sshecret_backend/api/clients.py +++ b/packages/sshecret-backend/src/sshecret_backend/api/clients.py @@ -6,11 +6,11 @@ import uuid import logging from fastapi import APIRouter, Depends, HTTPException, Query, Request from pydantic import BaseModel, Field, model_validator -from sqlmodel import Session, col, select -from sqlalchemy import func -from typing import Annotated, Self, TypeVar +from typing import Annotated, Any, Self, TypeVar, cast -from sqlmodel.sql.expression import SelectOfScalar +from sqlalchemy import select, func +from sqlalchemy.orm import Session +from sqlalchemy.sql import Select from sshecret_backend.types import DBSessionDep from sshecret_backend.models import Client, ClientSecret from sshecret_backend.view_models import ( @@ -55,8 +55,8 @@ T = TypeVar("T") def filter_client_statement( - statement: SelectOfScalar[T], params: ClientListParams, ignore_limits: bool = False -) -> SelectOfScalar[T]: + statement: Select[Any], params: ClientListParams, ignore_limits: bool = False +) -> Select[Any]: """Filter a statement with the provided params.""" if params.id: statement = statement.where(Client.id == params.id) @@ -64,9 +64,9 @@ def filter_client_statement( if params.name: statement = statement.where(Client.name == params.name) elif params.name__like: - statement = statement.where(col(Client.name).like(params.name__like)) + statement = statement.where(Client.name.like(params.name__like)) elif params.name__contains: - statement = statement.where(col(Client.name).contains(params.name__contains)) + statement = statement.where(Client.name.contains(params.name__contains)) if ignore_limits: return statement @@ -86,13 +86,13 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter: """Get clients.""" # Get total results first count_statement = select(func.count("*")).select_from(Client) - count_statement = filter_client_statement(count_statement, filter_query, True) + count_statement = cast(Select[tuple[int]], filter_client_statement(count_statement, filter_query, True)) - total_results = session.exec(count_statement).one() + total_results = session.scalars(count_statement).one() statement = filter_client_statement(select(Client), filter_query, False) - results = session.exec(statement) + results = session.scalars(statement) remainder = total_results - filter_query.offset - filter_query.limit if remainder < 0: remainder = 0 @@ -170,13 +170,12 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter: status_code=404, detail="Cannot find a client with the given name." ) client.public_key = client_update.public_key - for secret in session.exec( + for secret in session.scalars( select(ClientSecret).where(ClientSecret.client_id == client.id) ).all(): LOG.debug("Invalidated secret %s", secret.id) secret.invalidated = True secret.client_id = None - secret.client = None session.add(client) session.refresh(client) @@ -206,13 +205,12 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter: public_key_updated = False if client_update.public_key != client.public_key: public_key_updated = True - for secret in session.exec( + for secret in session.scalars( select(ClientSecret).where(ClientSecret.client_id == client.id) ).all(): LOG.debug("Invalidated secret %s", secret.id) secret.invalidated = True secret.client_id = None - secret.client = None session.add(client) session.commit() diff --git a/packages/sshecret-backend/src/sshecret_backend/api/common.py b/packages/sshecret-backend/src/sshecret_backend/api/common.py index 566ddc4..fad0b57 100644 --- a/packages/sshecret-backend/src/sshecret_backend/api/common.py +++ b/packages/sshecret-backend/src/sshecret_backend/api/common.py @@ -4,7 +4,8 @@ import re import uuid import bcrypt -from sqlmodel import Session, select +from sqlalchemy import select +from sqlalchemy.orm import Session from sshecret_backend.models import Client @@ -20,13 +21,13 @@ def verify_token(token: str, stored_hash: str) -> bool: async def get_client_by_name(session: Session, name: str) -> Client | None: """Get client by name.""" client_filter = select(Client).where(Client.name == name) - client_results = session.exec(client_filter) + client_results = session.scalars(client_filter) return client_results.first() async def get_client_by_id(session: Session, id: uuid.UUID) -> Client | None: """Get client by name.""" client_filter = select(Client).where(Client.id == id) - client_results = session.exec(client_filter) + client_results = session.scalars(client_filter) return client_results.first() async def get_client_by_id_or_name(session: Session, id_or_name: str) -> Client | None: diff --git a/packages/sshecret-backend/src/sshecret_backend/api/policies.py b/packages/sshecret-backend/src/sshecret_backend/api/policies.py index 8c7a89e..624d2ec 100644 --- a/packages/sshecret-backend/src/sshecret_backend/api/policies.py +++ b/packages/sshecret-backend/src/sshecret_backend/api/policies.py @@ -4,7 +4,8 @@ import logging from fastapi import APIRouter, Depends, HTTPException, Request -from sqlmodel import Session, select +from sqlalchemy import select +from sqlalchemy.orm import Session from typing import Annotated from sshecret_backend.models import ClientAccessPolicy @@ -54,7 +55,7 @@ def get_policy_api(get_db_session: DBSessionDep) -> APIRouter: status_code=404, detail="Cannot find a client with the given name." ) # Remove old policies. - policies = session.exec( + policies = session.scalars( select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id) ).all() deleted_policies: list[ClientAccessPolicy] = [] diff --git a/packages/sshecret-backend/src/sshecret_backend/api/secrets.py b/packages/sshecret-backend/src/sshecret_backend/api/secrets.py index fadb9ce..0808db4 100644 --- a/packages/sshecret-backend/src/sshecret_backend/api/secrets.py +++ b/packages/sshecret-backend/src/sshecret_backend/api/secrets.py @@ -5,7 +5,8 @@ import logging from collections import defaultdict from fastapi import APIRouter, Depends, HTTPException, Request -from sqlmodel import Session, select +from sqlalchemy import select +from sqlalchemy.orm import Session from typing import Annotated from sshecret_backend.models import Client, ClientSecret @@ -34,7 +35,7 @@ async def lookup_client_secret( .where(ClientSecret.client_id == client.id) .where(ClientSecret.name == name) ) - results = session.exec(statement) + results = session.scalars(statement) return results.first() @@ -165,7 +166,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter: ) -> list[ClientSecretList]: """Get a list of all secrets and which clients have them.""" client_secret_map: defaultdict[str, list[str]] = defaultdict(list) - for client_secret in session.exec(select(ClientSecret)).all(): + for client_secret in session.scalars(select(ClientSecret)).all(): if not client_secret.client: if client_secret.name not in client_secret_map: client_secret_map[client_secret.name] = [] @@ -182,7 +183,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter: ) -> list[ClientSecretDetailList]: """Get a list of all secrets and which clients have them.""" client_secrets: dict[str, ClientSecretDetailList] = {} - for client_secret in session.exec(select(ClientSecret)).all(): + for client_secret in session.scalars(select(ClientSecret)).all(): if client_secret.name not in client_secrets: client_secrets[client_secret.name] = ClientSecretDetailList(name=client_secret.name) @@ -202,7 +203,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter: ) -> ClientSecretList: """Get a list of which clients has a named secret.""" clients: list[str] = [] - for client_secret in session.exec( + for client_secret in session.scalars( select(ClientSecret).where(ClientSecret.name == name) ).all(): if not client_secret.client: @@ -219,7 +220,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter: ) -> ClientSecretDetailList: """Get a list of which clients has a named secret.""" detail_list = ClientSecretDetailList(name=name) - for client_secret in session.exec( + for client_secret in session.scalars( select(ClientSecret).where(ClientSecret.name == name) ).all(): if not client_secret.client: diff --git a/packages/sshecret-backend/src/sshecret_backend/audit.py b/packages/sshecret-backend/src/sshecret_backend/audit.py index d369f63..3f0b0be 100644 --- a/packages/sshecret-backend/src/sshecret_backend/audit.py +++ b/packages/sshecret-backend/src/sshecret_backend/audit.py @@ -2,9 +2,10 @@ from collections.abc import Sequence from fastapi import Request -from sqlmodel import Session, select +from sqlalchemy import select +from sqlalchemy.orm import Session -from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy +from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy, Operation, SubSystem def _get_origin(request: Request) -> str | None: @@ -22,7 +23,7 @@ def _write_audit_log( """Write the audit log.""" origin = _get_origin(request) entry.origin = origin - entry.subsystem = "backend" + entry.subsystem = SubSystem.BACKEND session.add(entry) if commit: session.commit() @@ -33,7 +34,7 @@ def audit_create_client( ) -> None: """Log the creation of a client.""" entry = AuditLog( - operation="CREATE", + operation=Operation.CREATE, client_id=client.id, client_name=client.name, message="Client Created", @@ -46,7 +47,7 @@ def audit_delete_client( ) -> None: """Log the creation of a client.""" entry = AuditLog( - operation="CREATE", + operation=Operation.CREATE, client_id=client.id, client_name=client.name, message="Client deleted", @@ -63,9 +64,9 @@ def audit_create_secret( ) -> None: """Audit a create secret event.""" entry = AuditLog( - operation="CREATE", - object="ClientSecret", - object_id=str(secret.id), + operation=Operation.CREATE, + secret_id=secret.id, + secret_name=secret.name, client_id=client.id, client_name=client.name, message="Added secret to client", @@ -81,13 +82,13 @@ def audit_remove_policy( commit: bool = True, ) -> None: """Audit removal of policy.""" + data = {"object": "ClientAccessPolicy", "object_id": str(policy.id)} entry = AuditLog( - operation="DELETE", - object="ClientAccessPolicy", - object_id=str(policy.id), + operation=Operation.DELETE, client_id=client.id, client_name=client.name, message="Deleted client policy", + data=data, ) _write_audit_log(session, request, entry, commit) @@ -100,13 +101,13 @@ def audit_update_policy( commit: bool = True, ) -> None: """Audit update of policy.""" + data: dict[str, str] = {"object": "ClientAccessPolicy", "object_id": str(policy.id)} entry = AuditLog( - operation="CREATE", - object="ClientAccessPolicy", - object_id=str(policy.id), - client_id=client.id, + operation=Operation.CREATE, client_name=client.name, + client_id=client.id, message="Updated client policy", + data=data, ) _write_audit_log(session, request, entry, commit) @@ -119,11 +120,10 @@ def audit_update_client( ) -> None: """Audit an update secret event.""" entry = AuditLog( - operation="UPDATE", - object="Client", + operation=Operation.UPDATE, client_id=client.id, client_name=client.name, - message="Client updated", + message="Client data updated", ) _write_audit_log(session, request, entry, commit) @@ -137,11 +137,11 @@ def audit_update_secret( ) -> None: """Audit an update secret event.""" entry = AuditLog( - operation="UPDATE", - object="ClientSecret", - object_id=str(secret.id), + operation=Operation.UPDATE, client_id=client.id, client_name=client.name, + secret_name=secret.name, + secret_id=secret.id, message="Secret value updated", ) _write_audit_log(session, request, entry, commit) @@ -155,8 +155,7 @@ def audit_invalidate_secrets( ) -> None: """Audit Invalidate client secrets.""" entry = AuditLog( - operation="INVALIDATE", - object="ClientSecret", + operation=Operation.UPDATE, client_name=client.name, client_id=client.id, message="Client public-key changed. All secrets invalidated.", @@ -173,9 +172,9 @@ def audit_delete_secret( ) -> None: """Audit Delete client secrets.""" entry = AuditLog( - operation="DELETE", - object="ClientSecret", - object_id=str(secret.id), + operation=Operation.DELETE, + secret_name=secret.name, + secret_id=secret.id, client_name=client.name, client_id=client.id, message="Deleted secret.", @@ -195,7 +194,7 @@ def audit_access_secrets( With no secrets provided, all secrets of the client will be resolved. """ if not secrets: - secrets = session.exec( + secrets = session.scalars( select(ClientSecret).where(ClientSecret.client_id == client.id) ).all() @@ -215,37 +214,21 @@ def audit_access_secret( ) -> None: """Audit that someone accessed one secrets.""" entry = AuditLog( - operation="ACCESS", + operation=Operation.READ, message="Secret was viewed", - object="ClientSecret", - object_id=str(secret.id), + secret_name=secret.name, + secret_id=secret.id, client_id=client.id, client_name=client.name, ) _write_audit_log(session, request, entry, commit) - -def audit_access_audit_log( - session: Session, request: Request, commit: bool = True -) -> None: - """Audit access to the audit log. - - Because why not... - """ - entry = AuditLog( - operation="ACCESS", - message="Audit log was viewed", - object="AuditLog", - ) - _write_audit_log(session, request, entry, commit) - - def audit_client_secret_list( session: Session, request: Request, commit: bool = True ) -> None: """Audit a list of all secrets.""" entry = AuditLog( - operation="ACCESS", + operation=Operation.READ, message="All secret names and their clients was viewed", ) _write_audit_log(session, request, entry, commit) diff --git a/packages/sshecret-backend/src/sshecret_backend/backend_api.py b/packages/sshecret-backend/src/sshecret_backend/backend_api.py index 43a2563..aca6a52 100644 --- a/packages/sshecret-backend/src/sshecret_backend/backend_api.py +++ b/packages/sshecret-backend/src/sshecret_backend/backend_api.py @@ -3,11 +3,11 @@ import logging from typing import Annotated -import bcrypt from fastapi import APIRouter, Depends, Header, HTTPException -from sqlmodel import Session, select - +from sqlalchemy import select +from sqlalchemy.orm import Session from .api import get_audit_api, get_clients_api, get_policy_api, get_secrets_api +from .auth import verify_token from .models import ( APIClient, ) @@ -18,13 +18,6 @@ LOG = logging.getLogger(__name__) API_VERSION = "v1" -def verify_token(token: str, stored_hash: str) -> bool: - """Verify token.""" - token_bytes = token.encode("utf-8") - stored_bytes = stored_hash.encode("utf-8") - return bcrypt.checkpw(token_bytes, stored_bytes) - - def get_backend_api( get_db_session: DBSessionDep, ) -> APIRouter: @@ -37,7 +30,7 @@ def get_backend_api( """Validate token.""" LOG.debug("Validating token %s", x_api_token) statement = select(APIClient) - results = session.exec(statement) + results = session.scalars(statement) valid = False for result in results: if verify_token(x_api_token, result.token): diff --git a/packages/sshecret-backend/src/sshecret_backend/cli.py b/packages/sshecret-backend/src/sshecret_backend/cli.py index 956bad4..1eda37f 100644 --- a/packages/sshecret-backend/src/sshecret_backend/cli.py +++ b/packages/sshecret-backend/src/sshecret_backend/cli.py @@ -3,15 +3,24 @@ import code import os from pathlib import Path -from typing import cast -from dotenv import load_dotenv +from typing import Literal, cast + import click -from sqlmodel import Session, col, func, select import uvicorn +from dotenv import load_dotenv +from sqlalchemy import select +from sqlalchemy.orm import Session -from .db import get_engine, create_api_token - -from .models import Client, ClientSecret, ClientAccessPolicy, AuditLog, APIClient, init_db +from .db import create_api_token, get_engine, hash_token +from .models import ( + APIClient, + AuditLog, + Client, + ClientAccessPolicy, + ClientSecret, + SubSystem, + init_db, +) from .settings import BackendSettings DEFAULT_LISTEN = "127.0.0.1" @@ -21,22 +30,44 @@ WORKDIR = Path(os.getcwd()) load_dotenv() -def generate_token(settings: BackendSettings) -> str: + +def generate_token( + settings: BackendSettings, subsystem: Literal["admin", "sshd"] +) -> str: """Generate a token.""" engine = get_engine(settings.db_url) init_db(engine) with Session(engine) as session: - token = create_api_token(session, True) + token = create_api_token(session, subsystem) return token -def count_tokens(settings: BackendSettings) -> int: - """Count the amount of tokens created.""" + +def add_system_tokens(settings: BackendSettings) -> None: + """Add token for subsystems.""" + if not settings.admin_token and not settings.sshd_token: + # Tokens should be generated manually. + return + engine = get_engine(settings.db_url) init_db(engine) + tokens: list[tuple[str, SubSystem]] = [] + if admin_token := settings.admin_token: + tokens.append((admin_token, SubSystem.ADMIN)) + if sshd_token := settings.sshd_token: + tokens.append((sshd_token, SubSystem.SSHD)) with Session(engine) as session: - count = session.exec(select(func.count("*")).select_from(APIClient)).one() + for token, subsystem in tokens: + hashed_token = hash_token(token) + if existing := session.scalars( + select(APIClient).where(APIClient.subsystem == subsystem) + ).first(): + existing.token = hashed_token + else: + new_token = APIClient(token=hashed_token, subsystem=subsystem) + session.add(new_token) - return count + session.commit() + click.echo("Generated system tokens.") @click.group() @@ -49,27 +80,30 @@ def cli(ctx: click.Context, database: str) -> None: else: settings = BackendSettings() + add_system_tokens(settings) - if settings.generate_initial_tokens: - if count_tokens(settings) == 0: - click.echo("Creating initial tokens for admin and sshd.") - admin_token = generate_token(settings) - sshd_token = generate_token(settings) - click.echo(f"Admin token: {admin_token}") - click.echo(f"SSHD token: {sshd_token}") + # if settings.generate_initial_tokens: + # if count_tokens(settings) == 0: + # click.echo("Creating initial tokens for admin and sshd.") + # admin_token = generate_token(settings) + # sshd_token = generate_token(settings) + # click.echo(f"Admin token: {admin_token}") + # click.echo(f"SSHD token: {sshd_token}") ctx.obj = settings @cli.command("generate-token") +@click.argument("subsystem", type=click.Choice(["sshd", "admin"])) @click.pass_context -def cli_generate_token(ctx: click.Context) -> None: - """Generate a token.""" +def cli_generate_token(ctx: click.Context, subsystem: Literal["sshd", "admin"]) -> None: + """Generate a token for a subsystem..""" settings = cast(BackendSettings, ctx.obj) - token = generate_token(settings) + token = generate_token(settings, subsystem) click.echo("Generated api token:") click.echo(token) + @cli.command("run") @click.option("--host", default="127.0.0.1") @click.option("--port", default=8022, type=click.INT) @@ -77,7 +111,10 @@ def cli_generate_token(ctx: click.Context) -> None: @click.option("--workers", type=click.INT) def cli_run(host: str, port: int, dev: bool, workers: int | None) -> None: """Run the server.""" - uvicorn.run("sshecret_backend.main:app", host=host, port=port, reload=dev, workers=workers) + uvicorn.run( + "sshecret_backend.main:app", host=host, port=port, reload=dev, workers=workers + ) + @cli.command("repl") @click.pass_context diff --git a/packages/sshecret-backend/src/sshecret_backend/db.py b/packages/sshecret-backend/src/sshecret_backend/db.py index 5840def..ff5a0e5 100644 --- a/packages/sshecret-backend/src/sshecret_backend/db.py +++ b/packages/sshecret-backend/src/sshecret_backend/db.py @@ -2,57 +2,109 @@ import logging import secrets +import sqlite3 + from collections.abc import Generator, Callable -from pathlib import Path -from sqlalchemy import Engine -from sqlmodel import Session, create_engine, text -import bcrypt +from typing import Literal +from sqlalchemy import create_engine, Engine, event, select +from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine + +from sqlalchemy.orm import sessionmaker, Session from sqlalchemy.engine import URL - -from .models import APIClient +from .auth import hash_token, verify_token +from .models import APIClient, SubSystem LOG = logging.getLogger(__name__) def setup_database( - db_url: URL | str, + db_url: URL, ) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]: """Setup database.""" - engine = create_engine(db_url, echo=False) - with engine.connect() as connection: - connection.execute(text("PRAGMA foreign_keys=ON")) # for SQLite only + engine = get_engine(db_url) + SessionLocal = sessionmaker(autocommit=False, autoflush=False, future=True) def get_db_session() -> Generator[Session, None, None]: """Get DB Session.""" - with Session(engine) as session: + session = SessionLocal(bind=engine) + try: yield session + finally: + session.close() return engine, get_db_session def get_engine(url: URL, echo: bool = False) -> Engine: """Initialize the engine.""" - engine = create_engine(url, echo=echo) - with engine.connect() as connection: - connection.execute(text("PRAGMA foreign_keys=ON")) # for SQLite only + engine = create_engine(url, echo=echo, future=True) + if url.drivername.startswith("sqlite"): + + @event.listens_for(engine, "connect") + def set_sqlite_pragma( + dbapi_connection: sqlite3.Connection, _connection_record: object + ) -> None: + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() return engine -def create_api_token(session: Session, read_write: bool) -> str: - """Create API token.""" - token = secrets.token_urlsafe(32) - pwbytes = token.encode("utf-8") - salt = bcrypt.gensalt() - hashed_bytes = bcrypt.hashpw(password=pwbytes, salt=salt) - hashed = hashed_bytes.decode() +def get_async_engine(url: URL, echo: bool = False) -> AsyncEngine: + """Get an async engine.""" + engine = create_async_engine(url, echo=echo, future=True) + if url.drivername.startswith("sqlite+"): + + @event.listens_for(engine, "connect") + def set_sqlite_pragma( + dbapi_connection: sqlite3.Connection, _connection_record: object + ) -> None: + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + return engine + + + +def create_api_token_with_value(session: Session, token: str, subsystem: Literal["admin", "sshd"]) -> None: + """Create API token with a given value.""" + + existing = session.scalars(select(APIClient).where(APIClient.subsystem == SubSystem(subsystem))).first() + if existing: + if verify_token(token, existing.token): + LOG.info("Token is up to date.") + return + LOG.info("Updating token value for subsystem %s", subsystem) + hashed = hash_token(token) + existing.token=hashed + session.commit() + return + + LOG.info("No existing token found. Creating new") + hashed = hash_token(token) + api_token = APIClient(token=hashed, subsystem=SubSystem(subsystem)) - api_token = APIClient(token=hashed, read_write=read_write) session.add(api_token) session.commit() +def create_api_token(session: Session, subsystem: Literal["admin", "sshd", "test"], recreate: bool = False) -> str: + """Create API token.""" + subsys = SubSystem(subsystem) + token = secrets.token_urlsafe(32) + hashed = hash_token(token) + if existing := session.scalars(select(APIClient).where(APIClient.subsystem == subsys)).first(): + if not recreate: + raise RuntimeError("Error: A token already exist for this subsystem.") + existing.token = hashed + else: + api_token = APIClient(token=hashed, subsystem=subsys) + session.add(api_token) + session.commit() + return token diff --git a/packages/sshecret-backend/src/sshecret_backend/models.py b/packages/sshecret-backend/src/sshecret_backend/models.py index 4030e82..6eab429 100644 --- a/packages/sshecret-backend/src/sshecret_backend/models.py +++ b/packages/sshecret-backend/src/sshecret_backend/models.py @@ -7,128 +7,182 @@ This might require some changes to these schemas. """ +import enum import logging import uuid from datetime import datetime + import sqlalchemy as sa -from sqlmodel import JSON, Column, DateTime, Field, Relationship, SQLModel +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship + LOG = logging.getLogger(__name__) -class Client(SQLModel, table=True): - """Client model.""" +class SubSystem(enum.StrEnum): + """Available subsystems.""" - id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) - name: str = Field(unique=True) - description: str | None = None - public_key: str + ADMIN = enum.auto() + SSHD = enum.auto() + BACKEND = enum.auto() + TEST = enum.auto() - created_at: datetime | None = Field( - default=None, - sa_type=sa.DateTime(timezone=True), - sa_column_kwargs={"server_default": sa.func.now()}, - nullable=False, + +class Operation(enum.StrEnum): + """Various operations for the audit logging module.""" + + CREATE = enum.auto() + READ = enum.auto() + UPDATE = enum.auto() + DELETE = enum.auto() + DENY = enum.auto() + PERMIT = enum.auto() + LOGIN = enum.auto() + REGISTER = enum.auto() + NONE = enum.auto() + + +class Base(DeclarativeBase): + pass + + +class Client(Base): + """Clients.""" + + __tablename__: str = "client" + + id: Mapped[uuid.UUID] = mapped_column( + sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + name: Mapped[str] = mapped_column(sa.String, unique=True) + description: Mapped[str | None] = mapped_column(sa.String, nullable=True) + public_key: Mapped[str] = mapped_column(sa.Text) + + created_at: Mapped[datetime] = mapped_column( + sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False ) - updated_at: datetime | None = Field( - default=None, - sa_type=sa.DateTime(timezone=True), - sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()}, + updated_at: Mapped[datetime | None] = mapped_column( + sa.DateTime(timezone=True), + server_default=sa.func.now(), + onupdate=sa.func.now(), ) - secrets: list["ClientSecret"] = Relationship( - back_populates="client", passive_deletes="all" + secrets: Mapped[list["ClientSecret"]] = relationship( + back_populates="client", passive_deletes=True ) - policies: list["ClientAccessPolicy"] = Relationship(back_populates="client") + policies: Mapped[list["ClientAccessPolicy"]] = relationship(back_populates="client") -class ClientAccessPolicy(SQLModel, table=True): +class ClientAccessPolicy(Base): """Client access policies.""" - id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) - source: str - client_id: uuid.UUID | None = Field(foreign_key="client.id", ondelete="CASCADE") - client: Client | None = Relationship(back_populates="policies") + __tablename__: str = "client_access_policy" - created_at: datetime | None = Field( - default=None, - sa_type=sa.DateTime(timezone=True), - sa_column_kwargs={"server_default": sa.func.now()}, - nullable=False, + id: Mapped[uuid.UUID] = mapped_column( + sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + source: Mapped[str] = mapped_column(sa.String) + client_id: Mapped[uuid.UUID | None] = mapped_column( + sa.Uuid(as_uuid=True), sa.ForeignKey("client.id", ondelete="CASCADE") + ) + client: Mapped[Client] = relationship(back_populates="policies") + + created_at: Mapped[datetime] = mapped_column( + sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False ) - updated_at: datetime | None = Field( - default=None, - sa_type=sa.DateTime(timezone=True), - sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()}, + updated_at: Mapped[datetime | None] = mapped_column( + sa.DateTime(timezone=True), + server_default=sa.func.now(), + onupdate=sa.func.now(), ) -class ClientSecret(SQLModel, table=True): +class ClientSecret(Base): """A client secret.""" - id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) - name: str - description: str | None = None - client_id: uuid.UUID | None = Field(foreign_key="client.id", ondelete="CASCADE") - client: Client | None = Relationship(back_populates="secrets") - secret: str - invalidated: bool = Field(default=False) - created_at: datetime | None = Field( - default=None, - sa_type=sa.DateTime(timezone=True), - sa_column_kwargs={"server_default": sa.func.now()}, - nullable=False, + __tablename__: str = "client_secret" + + id: Mapped[uuid.UUID] = mapped_column( + sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + name: Mapped[str] = mapped_column(sa.String) + description: Mapped[str | None] = mapped_column(sa.String, nullable=True) + secret: Mapped[str] = mapped_column(sa.String) + + client_id: Mapped[uuid.UUID | None] = mapped_column( + sa.Uuid(as_uuid=True), sa.ForeignKey("client.id", ondelete="CASCADE") + ) + client: Mapped[Client] = relationship(back_populates="secrets") + invalidated: Mapped[bool] = mapped_column(default=False) + + created_at: Mapped[datetime] = mapped_column( + sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False ) - updated_at: datetime | None = Field( - default=None, - sa_type=sa.DateTime(timezone=True), - sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()}, + updated_at: Mapped[datetime | None] = mapped_column( + sa.DateTime(timezone=True), + server_default=sa.func.now(), + onupdate=sa.func.now(), ) -class AuditLog(SQLModel, table=True): +class APIClient(Base): + """A client on the API. + + This should eventually get more granular permissions. + """ + + __tablename__: str = "api_client" + + id: Mapped[uuid.UUID] = mapped_column( + sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + subsystem: Mapped[SubSystem | None] = mapped_column(sa.String, nullable=True) + token: Mapped[str] = mapped_column(sa.String) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False + ) + updated_at: Mapped[datetime | None] = mapped_column( + sa.DateTime(timezone=True), + server_default=sa.func.now(), + onupdate=sa.func.now(), + ) + + +class AuditLog(Base): """Audit log. This is implemented without any foreign keys to avoid losing data on deletions. """ - id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) - subsystem: str - message: str - operation: str - client_id: uuid.UUID | None = None - client_name: str | None = None - origin: str | None = None - Field(default=None, sa_column=Column(JSON)) - - timestamp: datetime | None = Field( - default=None, - sa_type=sa.DateTime(timezone=True), - sa_column_kwargs={"server_default": sa.func.now()}, - nullable=False, + __tablename__: str = "audit_log" + id: Mapped[uuid.UUID] = mapped_column( + sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4 ) + subsystem: Mapped[SubSystem] = mapped_column(sa.String) + message: Mapped[str] = mapped_column(sa.String) + operation: Mapped[Operation] = mapped_column(sa.String) + client_id: Mapped[uuid.UUID | None] = mapped_column( + sa.Uuid(as_uuid=True), nullable=True + ) + data: Mapped[dict[str, str] | None] = mapped_column(sa.JSON, nullable=True) + client_name: Mapped[str | None] = mapped_column(sa.String, nullable=True) + secret_id: Mapped[uuid.UUID | None] = mapped_column( + sa.Uuid(as_uuid=True), nullable=True + ) + secret_name: Mapped[str | None] = mapped_column(sa.String, nullable=True) -class APIClient(SQLModel, table=True): - """Stores API Keys.""" - - id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) - token: str - read_write: bool - created_at: datetime | None = Field( - default=None, - sa_type=sa.DateTime(timezone=True), - sa_column_kwargs={"server_default": sa.func.now()}, - nullable=False, + origin: Mapped[str | None] = mapped_column(sa.String, nullable=True) + timestamp: Mapped[datetime] = mapped_column( + sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False ) def init_db(engine: sa.Engine) -> None: - """Create database.""" - LOG.info("Running init_db") - SQLModel.metadata.create_all(engine) + """Initialize database.""" + Base.metadata.create_all(engine) diff --git a/packages/sshecret-backend/src/sshecret_backend/settings.py b/packages/sshecret-backend/src/sshecret_backend/settings.py index 7159fe5..0ddb63e 100644 --- a/packages/sshecret-backend/src/sshecret_backend/settings.py +++ b/packages/sshecret-backend/src/sshecret_backend/settings.py @@ -1,12 +1,10 @@ """Settings management.""" from pathlib import Path -from typing import Annotated, Any -from pydantic import Field, field_validator +from pydantic import Field from pydantic_settings import ( BaseSettings, SettingsConfigDict, - ForceDecode, ) from sqlalchemy import URL @@ -22,24 +20,19 @@ class BackendSettings(BaseSettings): ) database: str = Field(default=DEFAULT_DATABASE) - generate_initial_tokens: Annotated[bool, ForceDecode] = Field(default=False) - - @field_validator("generate_initial_tokens", mode="before") - @classmethod - def cast_bool(cls, value: Any) -> bool: - """Ensure we catch the boolean.""" - if isinstance(value, str): - if value.lower() in ("1", "true", "on"): - return True - if value.lower() in ("0", "false", "off"): - return False - return bool(value) + admin_token: str | None = Field(default=None, alias="sshecret_admin_backend_token") + sshd_token: str | None = Field(default=None, alias="sshecret_sshd_backend_token") @property def db_url(self) -> URL: """Construct database url.""" return URL.create(drivername="sqlite", database=self.database) + @property + def async_db_url(self) -> URL: + """Construct database url with sync handling.""" + return URL.create(drivername="sqlite+aiosqlite", database=self.database) + @property def db_exists(self) -> bool: """Check if databatase exists.""" diff --git a/packages/sshecret-backend/src/sshecret_backend/testing.py b/packages/sshecret-backend/src/sshecret_backend/testing.py index cae9ba6..f18664f 100644 --- a/packages/sshecret-backend/src/sshecret_backend/testing.py +++ b/packages/sshecret-backend/src/sshecret_backend/testing.py @@ -1,8 +1,9 @@ """Test helpers.""" import logging -from sqlmodel import Session from sshecret_backend.settings import BackendSettings + +from sqlalchemy.orm import Session from .models import init_db from .db import create_api_token, setup_database @@ -14,4 +15,4 @@ def create_test_token(settings: BackendSettings) -> str: engine, _setupdb = setup_database(settings.db_url) with Session(engine) as session: init_db(engine) - return create_api_token(session, True) + return create_api_token(session, "test") diff --git a/packages/sshecret-backend/src/sshecret_backend/types.py b/packages/sshecret-backend/src/sshecret_backend/types.py index c2cec04..39ba47b 100644 --- a/packages/sshecret-backend/src/sshecret_backend/types.py +++ b/packages/sshecret-backend/src/sshecret_backend/types.py @@ -2,7 +2,7 @@ from collections.abc import Callable, Generator -from sqlmodel import Session +from sqlalchemy.orm import Session DBSessionDep = Callable[[], Generator[Session, None, None]] diff --git a/packages/sshecret-backend/src/sshecret_backend/view_models.py b/packages/sshecret-backend/src/sshecret_backend/view_models.py index c9cb847..cc92384 100644 --- a/packages/sshecret-backend/src/sshecret_backend/view_models.py +++ b/packages/sshecret-backend/src/sshecret_backend/view_models.py @@ -4,15 +4,14 @@ import uuid from datetime import datetime from typing import Annotated, Self, override -from sqlmodel import Field, SQLModel -from pydantic import AfterValidator, IPvAnyAddress, IPvAnyNetwork +from pydantic import AfterValidator, BaseModel, Field, IPvAnyAddress, IPvAnyNetwork from sshecret.crypto import public_key_validator from . import models -class ClientView(SQLModel): +class ClientView(BaseModel): """View for a single client.""" id: uuid.UUID @@ -50,7 +49,7 @@ class ClientView(SQLModel): return view -class ClientQueryResult(SQLModel): +class ClientQueryResult(BaseModel): """Result class for queries towards the client list.""" clients: list[ClientView] = Field(default_factory=list) @@ -58,7 +57,7 @@ class ClientQueryResult(SQLModel): remaining_results: int -class ClientCreate(SQLModel): +class ClientCreate(BaseModel): """Model to create a client.""" name: str @@ -74,19 +73,19 @@ class ClientCreate(SQLModel): ) -class ClientUpdate(SQLModel): +class ClientUpdate(BaseModel): """Model to update the client public key.""" public_key: Annotated[str, AfterValidator(public_key_validator)] -class BodyValue(SQLModel): +class BodyValue(BaseModel): """A generic model with just a value parameter.""" value: str -class ClientSecretPublic(SQLModel): +class ClientSecretPublic(BaseModel): """Public model to manage client secrets.""" name: str @@ -122,7 +121,7 @@ class ClientSecretResponse(ClientSecretPublic): ) -class ClientPolicyView(SQLModel): +class ClientPolicyView(BaseModel): """Update object for client policy.""" sources: list[str] = ["0.0.0.0/0", "::/0"] @@ -135,27 +134,27 @@ class ClientPolicyView(SQLModel): return cls(sources=[policy.source for policy in client.policies]) -class ClientPolicyUpdate(SQLModel): +class ClientPolicyUpdate(BaseModel): """Model for updating policies.""" sources: list[IPvAnyAddress | IPvAnyNetwork] -class ClientSecretList(SQLModel): +class ClientSecretList(BaseModel): """Model for aggregating identically named secrets.""" name: str clients: list[str] -class ClientReference(SQLModel): +class ClientReference(BaseModel): """Reference to a client.""" id: str name: str -class ClientSecretDetailList(SQLModel): +class ClientSecretDetailList(BaseModel): """A more detailed version of the ClientSecretList.""" name: str @@ -163,7 +162,22 @@ class ClientSecretDetailList(SQLModel): clients: list[ClientReference] = Field(default_factory=list) -class AuditInfo(SQLModel): +class AuditView(BaseModel): + """Audit log view.""" + + + id: uuid.UUID | None = None + subsystem: models.SubSystem + message: str + operation: models.Operation + data: dict[str, str] | None = None + client_id: uuid.UUID | None = None + client_name: str | None = None + origin: str | None = None + timestamp: datetime | None = None + + +class AuditInfo(BaseModel): """Information about audit information.""" entries: int diff --git a/packages/sshecret-backend/tests/test_backend.py b/packages/sshecret-backend/tests/test_backend.py index 69c7dc0..fc6031c 100644 --- a/packages/sshecret-backend/tests/test_backend.py +++ b/packages/sshecret-backend/tests/test_backend.py @@ -10,7 +10,7 @@ from fastapi.testclient import TestClient from sshecret.crypto import generate_private_key, generate_public_key_string from sshecret_backend.app import create_backend_app from sshecret_backend.testing import create_test_token -from sshecret_backend.models import AuditLog +from sshecret_backend.view_models import AuditView from sshecret_backend.settings import BackendSettings @@ -53,7 +53,7 @@ def create_client_fixture(tmp_path: Path): db_file = tmp_path / "backend.db" print(f"DB File: {db_file.absolute()}") - settings = BackendSettings(db_url=f"sqlite:///{db_file.absolute()}") + settings = BackendSettings(database=str(db_file.absolute())) app = create_backend_app(settings) token = create_test_token(settings) @@ -213,7 +213,7 @@ def test_audit_logging(test_client: TestClient) -> None: assert len(audit_logs) > 0 for entry in audit_logs: # Let's try to reassemble the objects - audit_log = AuditLog.model_validate(entry) + audit_log = AuditView.model_validate(entry) assert audit_log is not None @@ -522,9 +522,8 @@ def test_operations_with_id(test_client: TestClient) -> None: def test_write_audit_log(test_client: TestClient) -> None: """Test writing to the audit log.""" params = { - "object": "Test", - "operation": "TEST", - "object_id": "Something", + "subsystem": "backend", + "operation": "read", "message": "Test Message" } resp = test_client.post("/api/v1/audit", json=params) diff --git a/src/sshecret/backend/__init__.py b/src/sshecret/backend/__init__.py index be1b7d5..619e3d1 100644 --- a/src/sshecret/backend/__init__.py +++ b/src/sshecret/backend/__init__.py @@ -8,8 +8,10 @@ from .models import ( ClientReference, ClientSecret, DetailedSecrets, + Operation, Policy, Secret, + SubSystem, FilterType, ) @@ -22,7 +24,9 @@ __all__ = [ "ClientSecret", "DetailedSecrets", "FilterType", + "Operation", "Policy", "Secret", + "SubSystem", "SshecretBackend", ] diff --git a/src/sshecret/backend/api.py b/src/sshecret/backend/api.py index c35db51..7125c48 100644 --- a/src/sshecret/backend/api.py +++ b/src/sshecret/backend/api.py @@ -1,21 +1,26 @@ """Backend client. +This is an API calling the HTTP API of the sshecret-backend package so that the +admin and sshd library do not need to implement the same """ import logging -from typing import Any, Self +from typing import Any, Self, override import httpx from pydantic import TypeAdapter from .models import ( + AuditInfo, AuditLog, Client, ClientSecret, ClientQueryResult, ClientFilter, DetailedSecrets, + Operation, Secret, + SubSystem, ) from .exceptions import BackendValidationError, BackendConnectionError from .utils import validate_public_key @@ -84,12 +89,14 @@ class ClientQueryIterator: raise StopAsyncIteration -class SshecretBackend: - """Backend interface.""" +class BaseBackend: + """Base backend class.""" def __init__(self, backend_url: str, api_token: str) -> None: """Initialize backend client.""" + self._backend_url: str = backend_url + self._api_token: str = api_token url = httpx.URL(backend_url) self.http_client: httpx.AsyncClient = httpx.AsyncClient( @@ -101,31 +108,55 @@ class SshecretBackend: base_url=url, ) - async def _get(self, path: str) -> httpx.Response: + def _handle_response(self, response: httpx.Response) -> httpx.Response: + """Handle response.""" + LOG.debug("Handling response with status_code %s", response.status_code) + if response.status_code == 422: + LOG.error("Validation error from backend:\n%s", response.text) + raise BackendValidationError(response.text) + if response.status_code != 404 and str(response.status_code).startswith("4"): + raise BackendConnectionError("Error from backend: %s", response.text) + + return response + + async def _get( + self, path: str, params: dict[str, str] | None = None + ) -> httpx.Response: """Perform a get request.""" try: - return await self.http_client.get(path) + response = await self.http_client.get(path, params=params) + return self._handle_response(response) except httpx.ConnectError as e: - raise BackendConnectionError() from e + raise BackendConnectionError("Could not connect to backend.") from e async def _delete(self, path: str) -> httpx.Response: """Perform a delete request.""" try: - return await self.http_client.delete(path) + response = await self.http_client.delete(path) + return self._handle_response(response) except httpx.ConnectError as e: raise BackendConnectionError() from e async def _post(self, path: str, json: Any | None = None) -> httpx.Response: """Perform a POST request.""" try: - return await self.http_client.post(path, json=json) + response = await self.http_client.post(path, json=json) + return self._handle_response(response) + except httpx.ConnectError as e: + raise BackendConnectionError() from e + + def _post_sync(self, path: str, json: Any | None = None) -> httpx.Response: + """Perform a synchronous post request.""" + try: + return self._handle_response(self.sync_client.post(path, json=json)) except httpx.ConnectError as e: raise BackendConnectionError() from e async def _put(self, path: str, json: Any | None = None) -> httpx.Response: """Perform a PUT request.""" try: - return await self.http_client.put(path, json=json) + response = await self.http_client.put(path, json=json) + return self._handle_response(response) except httpx.ConnectError as e: raise BackendConnectionError() from e @@ -134,175 +165,90 @@ class SshecretBackend: response = await self._get(path) return response - async def create_client( - self, name: str, public_key: str, description: str | None = None + +class AuditAPI(BaseBackend): + """API for the audit logging.""" + + @override + def __init__(self, backend_url: str, api_token: str, subsystem: str) -> None: + """Initialize backend client.""" + super().__init__(backend_url, api_token) + self.subsystem: SubSystem = SubSystem(subsystem) + + def _create_model( + self, + operation: Operation, + message: str, + origin: str, + client: Client | None = None, + secret: ClientSecret | None = None, + secret_name: str | None = None, + data: dict[str, str] | None = None, + ) -> AuditLog: + """Create the audit log object.""" + model = AuditLog( + subsystem=self.subsystem, + operation=operation, + message=message, + origin=origin, + ) + if client: + model.client_id = str(client.id) or None + model.client_name = client.name + + if secret: + model.secret_name = secret.name + elif secret_name: + model.secret_name = secret_name + if data: + model.data = data + + return model + + def write_model(self, model: AuditLog) -> None: + """Write model.""" + path = f"/api/v1/audit/" + self.sync_client.post(path, json=model.model_dump()) + + async def write_model_async(self, model: AuditLog) -> None: + """Write model async.""" + path = f"/api/v1/audit/" + await self._post(path, json=model.model_dump()) + + def write( + self, + operation: Operation, + message: str, + origin: str, + client: Client | None = None, + secret: ClientSecret | None = None, + secret_name: str | None = None, + **data: str, ) -> None: - """Register a new client.""" - if not validate_public_key(public_key): - raise BackendValidationError("Error: Invalid public key format.") - data = { - "name": name, - "public_key": public_key, - } - if description: - data["description"] = description - path = "/api/v1/clients/" - response = await self._post(path, json=data) + """Write an audit entry.""" + model = self._create_model( + operation, message, origin, client, secret, secret_name, data + ) - response.raise_for_status() + self.write_model(model) - async def get_clients(self, filter: ClientFilter | None = None) -> list[Client]: - """Get all clients.""" - clients: list[Client] = [] - async for client in ClientQueryIterator(self.http_client, filter): - clients.append(client) - - return clients - - async def get_client(self, name: str) -> Client | None: - """Lookup a client on username.""" - path = f"/api/v1/clients/{name}" - response = await self.request(path) - if response.status_code == 404: - return None - response.raise_for_status() - client = Client.model_validate(response.json()) - return client - - async def get_client_by_id(self, id: str) -> Client | None: - """Lookup a client on username.""" - path = f"/api/v1/clients/id/{id}" - response = await self.request(path) - if response.status_code == 404: - return None - response.raise_for_status() - client = Client.model_validate(response.json()) - return client - - async def delete_client(self, client_name: str) -> None: - """Delete a client.""" - path = f"/api/v1/clients/{client_name}" - response = await self._delete(path) - - response.raise_for_status() - - async def delete_client_by_id(self, id: str) -> None: - """Delete a client.""" - path = f"/api/v1/clients/id/{id}" - response = await self._delete(path) - - response.raise_for_status() - - async def create_client_secret( - self, client_name: str, secret_name: str, encrypted_secret: str + async def write_async( + self, + operation: Operation, + message: str, + origin: str, + client: Client | None = None, + secret: ClientSecret | None = None, + secret_name: str | None = None, + **data: str, ) -> None: - """Create a secret. + """Write an audit entry.""" + model = self._create_model( + operation, message, origin, client, secret, secret_name, data + ) + await self.write_model_async(model) - This will overwrite any existing secret with that name. - """ - path = f"api/v1/clients/{client_name}/secrets/{secret_name}" - response = await self._put(path, json={"value": encrypted_secret}) - - response.raise_for_status() - - async def get_client_secret(self, name: str, secret_name: str) -> str: - """Fetch a secret.""" - path = f"/api/v1/clients/{name}/secrets/{secret_name}" - response = await self.request(path) - response.raise_for_status() - secret = ClientSecret.model_validate(response.json()) - return secret.secret - - async def delete_client_secret(self, client_name: str, secret_name: str) -> None: - """Delete a secret from a client.""" - path = f"api/v1/clients/{client_name}/secrets/{secret_name}" - response = await self._delete(path) - - response.raise_for_status() - - async def update_client(self, client: Client) -> Client: - """Update the client.""" - path = f"/api/v1/clients/{client.name}" - client_update = { - "name": client.name, - "description": client.description, - "public_key": client.public_key, - } - response = await self._put(path, json=client_update) - LOG.info("Response %s", response.text) - - response.raise_for_status() - if client.policies: - await self.update_client_sources( - str(client.id), [str(source) for source in client.policies] - ) - return client - - async def update_client_key(self, client_name: str, public_key: str) -> None: - """Update the client key.""" - path = f"/api/v1/clients/{client_name}/public-key" - response = await self._post(path, json={"public_key": public_key}) - - response.raise_for_status() - - async def update_client_sources( - self, client_name: str, addresses: list[str] | None - ) -> None: - """Update client source addresses. - - Pass None to sources to allow from all. - """ - if not addresses: - addresses = [] - - path = f"/api/v1/clients/{client_name}/policies/" - response = await self._put(path, json={"sources": addresses}) - - response.raise_for_status() - - async def get_detailed_secrets(self) -> list[DetailedSecrets]: - """Get detailed list of secrets.""" - path = "/api/v1/secrets/detailed/" - response = await self._get(path) - response.raise_for_status() - - secret_list = TypeAdapter(list[DetailedSecrets]) - return secret_list.validate_python(response.json()) - - async def get_secrets(self) -> list[Secret]: - """Get Secrets. - - This provides a list of secret names and which clients have them. - """ - path = "/api/v1/secrets/" - response = await self._get(path) - - response.raise_for_status() - - secret_list = TypeAdapter(list[Secret]) - return secret_list.validate_python(response.json()) - - async def get_secret(self, name: str) -> Secret | None: - """Get clients mapped to a single secret.""" - path = f"/api/v1/secrets/{name}" - response = await self._get(path) - if response.status_code == 404: - return None - response.raise_for_status() - - return Secret.model_validate(response.json()) - - async def get_detailed_secret(self, name: str) -> DetailedSecrets | None: - """Get clients mapped to a single secret.""" - path = f"/api/v1/secrets/{name}/detailed" - response = await self._get(path) - if response.status_code == 404: - return None - response.raise_for_status() - - return DetailedSecrets.model_validate(response.json()) - - async def get_audit_log( + async def get( self, offset: int = 0, limit: int = 100, @@ -321,30 +267,169 @@ class SshecretBackend: if subsystem: params["filter_subsystem"] = subsystem - response = await self.http_client.get(path, params=params) - response.raise_for_status() + response = await self._get(path, params=params) audit_log_adapter = TypeAdapter(list[AuditLog]) return audit_log_adapter.validate_python(response.json()) - async def add_audit_log(self, entry: AuditLog) -> None: - """Add audit log entry.""" - path = f"/api/v1/audit/" - - response = await self.http_client.post(path, json=entry.model_dump()) - response.raise_for_status() - - async def get_audit_log_count(self) -> int: + async def count(self) -> int: """Get amount of messages in the audit log.""" path = f"/api/v1/audit/info" response = await self._get(path) - response.raise_for_status() - data = response.json() - return int(data["entries"]) + audit_info = AuditInfo.model_validate(response.json()) + return audit_info.entries - def add_audit_log_sync(self, entry: AuditLog) -> None: - """Add audit log entry.""" - path = f"/api/v1/audit/" - LOG.info("AUDIT LOG SYNC %r", entry) - response = self.sync_client.post(path, json=entry.model_dump()) - response.raise_for_status() +class SshecretBackend(BaseBackend): + """Backend interface.""" + + async def create_client( + self, name: str, public_key: str, description: str | None = None + ) -> None: + """Register a new client.""" + if not validate_public_key(public_key): + raise BackendValidationError("Error: Invalid public key format.") + data = { + "name": name, + "public_key": public_key, + } + if description: + data["description"] = description + path = "/api/v1/clients/" + response = await self._post(path, json=data) + + async def get_clients(self, filter: ClientFilter | None = None) -> list[Client]: + """Get all clients.""" + clients: list[Client] = [] + async for client in ClientQueryIterator(self.http_client, filter): + clients.append(client) + + return clients + + async def get_client(self, name: str) -> Client | None: + """Lookup a client on username.""" + path = f"/api/v1/clients/{name}" + response = await self._get(path) + if response.status_code == 404: + return None + client = Client.model_validate(response.json()) + return client + + async def get_client_by_id(self, id: str) -> Client | None: + """Lookup a client on username.""" + path = f"/api/v1/clients/id/{id}" + response = await self._get(path) + if response.status_code == 404: + return None + client = Client.model_validate(response.json()) + return client + + async def delete_client(self, client_name: str) -> None: + """Delete a client.""" + path = f"/api/v1/clients/{client_name}" + response = await self._delete(path) + + async def delete_client_by_id(self, id: str) -> None: + """Delete a client.""" + path = f"/api/v1/clients/id/{id}" + response = await self._delete(path) + + async def create_client_secret( + self, client_name: str, secret_name: str, encrypted_secret: str + ) -> None: + """Create a secret. + + This will overwrite any existing secret with that name. + """ + path = f"api/v1/clients/{client_name}/secrets/{secret_name}" + response = await self._put(path, json={"value": encrypted_secret}) + + async def get_client_secret(self, name: str, secret_name: str) -> str | None: + """Fetch a secret.""" + path = f"/api/v1/clients/{name}/secrets/{secret_name}" + response = await self._get(path) + if response.status_code == 404: + return None + secret = ClientSecret.model_validate(response.json()) + return secret.secret + + async def delete_client_secret(self, client_name: str, secret_name: str) -> None: + """Delete a secret from a client.""" + path = f"api/v1/clients/{client_name}/secrets/{secret_name}" + await self._delete(path) + + async def update_client(self, client: Client) -> Client: + """Update the client.""" + path = f"/api/v1/clients/{client.name}" + client_update = { + "name": client.name, + "description": client.description, + "public_key": client.public_key, + } + response = await self._put(path, json=client_update) + LOG.info("Response %s", response.text) + + if client.policies: + await self.update_client_sources( + str(client.id), [str(source) for source in client.policies] + ) + return client + + async def update_client_key(self, client_name: str, public_key: str) -> None: + """Update the client key.""" + path = f"/api/v1/clients/{client_name}/public-key" + await self._post(path, json={"public_key": public_key}) + + async def update_client_sources( + self, client_name: str, addresses: list[str] | None + ) -> None: + """Update client source addresses. + + Pass None to sources to allow from all. + """ + if not addresses: + addresses = [] + + path = f"/api/v1/clients/{client_name}/policies/" + await self._put(path, json={"sources": addresses}) + + async def get_detailed_secrets(self) -> list[DetailedSecrets]: + """Get detailed list of secrets.""" + path = "/api/v1/secrets/detailed/" + response = await self._get(path) + + secret_list = TypeAdapter(list[DetailedSecrets]) + return secret_list.validate_python(response.json()) + + async def get_secrets(self) -> list[Secret]: + """Get Secrets. + + This provides a list of secret names and which clients have them. + """ + path = "/api/v1/secrets/" + response = await self._get(path) + + secret_list = TypeAdapter(list[Secret]) + return secret_list.validate_python(response.json()) + + async def get_secret(self, name: str) -> Secret | None: + """Get clients mapped to a single secret.""" + path = f"/api/v1/secrets/{name}" + response = await self._get(path) + if response.status_code == 404: + return None + + return Secret.model_validate(response.json()) + + async def get_detailed_secret(self, name: str) -> DetailedSecrets | None: + """Get clients mapped to a single secret.""" + path = f"/api/v1/secrets/{name}/detailed" + response = await self._get(path) + if response.status_code == 404: + return None + + return DetailedSecrets.model_validate(response.json()) + + def audit(self, subsystem: SubSystem) -> AuditAPI: + """Create the audit API.""" + audit = AuditAPI(self._backend_url, self._api_token, subsystem) + return audit diff --git a/src/sshecret/backend/models.py b/src/sshecret/backend/models.py index 9ca4523..423f8fd 100644 --- a/src/sshecret/backend/models.py +++ b/src/sshecret/backend/models.py @@ -17,6 +17,27 @@ class FilterType(enum.StrEnum): CONTAINS = "contains" +class SubSystem(enum.StrEnum): + """Available subsystems.""" + + ADMIN = enum.auto() + SSHD = enum.auto() + BACKEND = enum.auto() + + +class Operation(enum.StrEnum): + """Various operations for the audit logging module.""" + + CREATE = enum.auto() + READ = enum.auto() + UPDATE = enum.auto() + DELETE = enum.auto() + DENY = enum.auto() + PERMIT = enum.auto() + LOGIN = enum.auto() + NONE = enum.auto() + + class Client(BaseModel): """Implementation of the backend class ClientView.""" @@ -101,15 +122,22 @@ class ClientFilter(BaseModel): class AuditLog(BaseModel): - """Implementation of the backend class AuditLog.""" + """Implementation of the backend class AuditView.""" id: str | None = None - subsystem: str | None = None - object: str | None = None - object_id: str | None = None - operation: str + subsystem: SubSystem + operation: Operation client_id: str | None = None client_name: str | None = None + secret_id: str | None = None + secret_name: str | None = None + data: dict[str, str] | None = None message: str origin: str | None = None timestamp: datetime | None = None + + +class AuditInfo(BaseModel): + """Implementation of the backend class AuditInfo.""" + + entries: int