diff --git a/packages/sshecret-backend/README.md b/packages/sshecret-backend/README.md new file mode 100644 index 0000000..c411364 --- /dev/null +++ b/packages/sshecret-backend/README.md @@ -0,0 +1,11 @@ +# Backend + +This is the backend part of the SSHecret library. + +The principle here is that it stores encrypted secrets that can be long to clients. + +It does not store much data about the clients, and purely manages access to +encrypted values based on the SSH RSA fingerprint and an optional list of +allowed IP addresses. + +While there is a model for the client, it is purely meant for aggregation and utility. diff --git a/packages/sshecret-backend/pyproject.toml b/packages/sshecret-backend/pyproject.toml new file mode 100644 index 0000000..9cedf2e --- /dev/null +++ b/packages/sshecret-backend/pyproject.toml @@ -0,0 +1,23 @@ +[project] +name = "sshecret-backend" +version = "0.1.0" +description = "Add your description here" +readme = "README.md" +authors = [ + { name = "Allan Eising", email = "allan@eising.dk" } +] +requires-python = ">=3.13" +dependencies = [ + "passlib[bcrypt]>=1.7.4", + "pydantic>=2.10.6", + "pytest>=8.3.5", + "sqlmodel>=0.0.24", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.pytest.ini_options] +log_cli = true +log_cli_level = "INFO" diff --git a/packages/sshecret-backend/src/sshecret_backend/__init__.py b/packages/sshecret-backend/src/sshecret_backend/__init__.py new file mode 100644 index 0000000..c2288fe --- /dev/null +++ b/packages/sshecret-backend/src/sshecret_backend/__init__.py @@ -0,0 +1 @@ +"""Sshecret backend.""" diff --git a/packages/sshecret-backend/src/sshecret_backend/app.py b/packages/sshecret-backend/src/sshecret_backend/app.py new file mode 100644 index 0000000..1e995dd --- /dev/null +++ b/packages/sshecret-backend/src/sshecret_backend/app.py @@ -0,0 +1,289 @@ +"""FastAPI api.""" + +import logging +from contextlib import asynccontextmanager +from typing import Annotated +from collections.abc import Sequence + +import bcrypt +from fastapi import APIRouter, Depends, FastAPI, Header, HTTPException, Query, Request +from sqlmodel import Session, select + +from . import audit +from .db import get_engine +from .models import APIClient, AuditLog, Client, ClientSecret, init_db +from .settings import get_settings +from .view_models import ( + BodyValue, + ClientCreate, + ClientListResponse, + ClientSecretPublic, + ClientSecretResponse, + ClientUpdate, + ClientView, +) + +settings = get_settings() +engine = get_engine(settings.db_file) + + +LOG = logging.getLogger(__name__) + +API_VERSION = "v1" + + +def verify_token(token: str, stored_hash: str) -> bool: + """Verify token.""" + token_bytes = token.encode("utf-8") + stored_bytes = stored_hash.encode("utf-8") + return bcrypt.checkpw(token_bytes, stored_bytes) + + +@asynccontextmanager +async def lifespan(_app: FastAPI): + """Create database before starting the server.""" + init_db(engine) + yield + + +async def get_session(): + """Get the session.""" + with Session(engine) as session: + yield session + + +async def validate_token( + x_api_token: Annotated[str, Header()], + session: Annotated[Session, Depends(get_session)], +) -> str: + """Validate token.""" + LOG.debug("Validating token %s", x_api_token) + statement = select(APIClient) + results = session.exec(statement) + valid = False + for result in results: + if verify_token(x_api_token, result.token): + valid = True + LOG.debug("Token is valid") + break + + if not valid: + LOG.debug("Token is not valid.") + raise HTTPException(status_code=401, detail="unauthorized. invalid api token.") + return x_api_token + + +async def get_client_by_name(session: Session, name: str) -> Client | None: + """Get client by name.""" + client_filter = select(Client).where(Client.name == name) + client_results = session.exec(client_filter) + return client_results.first() + + +async def lookup_client_secret( + session: Session, client: Client, name: str +) -> ClientSecret | None: + """Look up a secret for a client.""" + statement = ( + select(ClientSecret) + .where(ClientSecret.client_id == client.id) + .where(ClientSecret.name == name) + ) + results = session.exec(statement) + return results.first() + + +LOG.info("Initializing app.") +backend_api = APIRouter( + prefix=f"/api/{API_VERSION}", + lifespan=lifespan, + dependencies=[Depends(validate_token)], +) + + +@backend_api.get("/clients/") +async def get_clients( + session: Annotated[Session, Depends(get_session)] +) -> list[ClientListResponse]: + """Get clients.""" + statement = select(Client) + results = session.exec(statement) + clients = list(results) + return ClientListResponse.from_clients(clients) + + +@backend_api.get("/clients/{name}") +async def get_client( + request: Request, name: str, session: Annotated[Session, Depends(get_session)] +) -> ClientView: + """Fetch a client.""" + statement = select(Client).where(Client.name == name) + results = session.exec(statement) + client = results.first() + if not client: + raise HTTPException( + status_code=404, detail="Cannot find a client with the given name." + ) + audit.audit_access_secrets(session, request, client) + return ClientView.from_client(client) + + +@backend_api.post("/clients/{name}/update_fingerprint") +async def update_client_fingerprint( + request: Request, + name: str, + client_update: ClientUpdate, + session: Annotated[Session, Depends(get_session)], +) -> ClientView: + """Update the client fingerprint. + + This invalidates all secrets. + """ + statement = select(Client).where(Client.name == name) + results = session.exec(statement) + client = results.first() + if not client: + raise HTTPException( + status_code=404, detail="Cannot find a client with the given name." + ) + client.fingerprint = client_update.fingerprint + for secret in session.exec( + select(ClientSecret).where(ClientSecret.client_id == client.id) + ).all(): + LOG.debug("Invalidated secret %s", secret.id) + secret.invalidated = True + secret.client_id = None + secret.client = None + + session.add(client) + session.refresh(client) + session.commit() + audit.audit_invalidate_secrets(session, request, client) + + return ClientView.from_client(client) + + +@backend_api.post("/clients/") +async def create_client( + request: Request, + client: ClientCreate, + session: Annotated[Session, Depends(get_session)], +) -> ClientView: + """Create client.""" + db_client = Client.model_validate(client) + session.add(db_client) + session.commit() + session.refresh(db_client) + audit.audit_create_client(session, request, db_client) + return ClientView.from_client(db_client) + + +@backend_api.post("/clients/{name}/secrets/") +async def add_secret_to_client( + request: Request, + name: str, + client_secret: ClientSecretPublic, + session: Annotated[Session, Depends(get_session)], +) -> None: + """Add secret to a client.""" + client = await get_client_by_name(session, name) + if not client: + raise HTTPException( + status_code=404, detail="Cannot find a client with the given name." + ) + + existing_secret = await lookup_client_secret(session, client, client_secret.name) + if existing_secret: + raise HTTPException( + status_code=400, + detail="Cannot add a secret. A different secret with the same name already exists.", + ) + db_secret = ClientSecret( + name=client_secret.name, client_id=client.id, secret=client_secret.secret + ) + session.add(db_secret) + session.commit() + session.refresh(db_secret) + audit.audit_create_secret(session, request, client, db_secret) + + +@backend_api.put("/clients/{name}/secrets/{secret_name}") +async def update_client_secret( + request: Request, + name: str, + secret_name: str, + secret_data: BodyValue, + session: Annotated[Session, Depends(get_session)], +) -> ClientSecretResponse: + """Update a client secret. + + This can also be used for destructive creates. + """ + client = await get_client_by_name(session, name) + if not client: + raise HTTPException( + status_code=404, detail="Cannot find a client with the given name." + ) + + existing_secret = await lookup_client_secret(session, client, secret_name) + if existing_secret: + existing_secret.secret = secret_data.value + session.add(existing_secret) + session.commit() + session.refresh(existing_secret) + audit.audit_update_secret(session, request, client, existing_secret) + return ClientSecretResponse.from_client_secret(existing_secret) + + db_secret = ClientSecret( + name=secret_name, + client_id=client.id, + secret=secret_data.value, + ) + session.add(db_secret) + session.commit() + session.refresh(db_secret) + audit.audit_create_secret(session, request, client, db_secret) + return ClientSecretResponse.from_client_secret(db_secret) + + +@backend_api.get("/clients/{name}/secrets/{secret_name}") +async def request_client_secret( + request: Request, + name: str, + secret_name: str, + session: Annotated[Session, Depends(get_session)], +) -> ClientSecretResponse: + """Get a client secret.""" + client = await get_client_by_name(session, name) + if not client: + raise HTTPException( + status_code=404, detail="Cannot find a client with the given name." + ) + + secret = await lookup_client_secret(session, client, secret_name) + if not secret: + raise HTTPException( + status_code=404, detail="Cannot find a secret with the given name." + ) + + response_model = ClientSecretResponse.from_client_secret(secret) + audit.audit_access_secret(session, request, client, secret) + return response_model + + +@backend_api.get("/audit/", response_model=list[AuditLog]) +async def get_audit_logs( + request: Request, + session: Annotated[Session, Depends(get_session)], + offset: Annotated[int, Query()] = 0, + limit: Annotated[int, Query(le=100)] = 100, + filter_client: Annotated[str | None, Query()] = None, +) -> Sequence[AuditLog]: + """Get audit logs.""" + audit.audit_access_audit_log(session, request) + statement = select(AuditLog).offset(offset).limit(limit) + if filter_client: + statement = statement.where(AuditLog.client_name == filter_client) + + results = session.exec(statement).all() + return results diff --git a/packages/sshecret-backend/src/sshecret_backend/audit.py b/packages/sshecret-backend/src/sshecret_backend/audit.py new file mode 100644 index 0000000..6b74cf8 --- /dev/null +++ b/packages/sshecret-backend/src/sshecret_backend/audit.py @@ -0,0 +1,151 @@ +"""Audit methods.""" + +from collections.abc import Sequence +from fastapi import Request +from sqlmodel import Session, select +from .models import AuditLog, Client, ClientSecret + + +def _get_origin(request: Request) -> str | None: + """Resolve the request origin.""" + origin: str | None = None + if request.client: + origin = request.client.host + + return origin + + +def _write_audit_log( + session: Session, request: Request, entry: AuditLog, commit: bool = True +) -> None: + """Write the audit log.""" + origin = _get_origin(request) + entry.origin = origin + session.add(entry) + if commit: + session.commit() + + +def audit_create_client( + session: Session, request: Request, client: Client, commit: bool = True +) -> None: + """Log the creation of a client.""" + entry = AuditLog( + operation="CREATE", + client_id=client.id, + client_name=client.name, + message="Client Created", + ) + _write_audit_log(session, request, entry, commit) + + +def audit_create_secret( + session: Session, + request: Request, + client: Client, + secret: ClientSecret, + commit: bool = True, +) -> None: + """Audit a create secret event.""" + entry = AuditLog( + operation="CREATE", + object="ClientSecret", + object_id=str(secret.id), + client_id=client.id, + client_name=client.name, + message="Added secret to client", + ) + _write_audit_log(session, request, entry, commit) + + +def audit_update_secret( + session: Session, + request: Request, + client: Client, + secret: ClientSecret, + commit: bool = True, +) -> None: + """Audit an update secret event.""" + entry = AuditLog( + operation="UPDATE", + object="ClientSecret", + object_id=str(secret.id), + client_id=client.id, + client_name=client.name, + message="Secret value updated", + ) + _write_audit_log(session, request, entry, commit) + + +def audit_invalidate_secrets( + session: Session, + request: Request, + client: Client, + commit: bool = True, +) -> None: + """Audit Invalidate client secrets.""" + entry = AuditLog( + operation="INVALIDATE", + object="ClientSecret", + client_name=client.name, + client_id=client.id, + message="Client fingerprint updated. All secrets invalidated.", + ) + _write_audit_log(session, request, entry, commit) + + +def audit_access_secrets( + session: Session, + request: Request, + client: Client, + secrets: Sequence[ClientSecret] | None = None, + commit: bool = True, +) -> None: + """Audit that multiple secrets were accessed. + + With no secrets provided, all secrets of the client will be resolved. + """ + if not secrets: + secrets = session.exec( + select(ClientSecret).where(ClientSecret.client_id == client.id) + ).all() + + for secret in secrets: + audit_access_secret(session, request, client, secret, False) + + if commit: + session.commit() + + +def audit_access_secret( + session: Session, + request: Request, + client: Client, + secret: ClientSecret, + commit: bool = True, +) -> None: + """Audit that someone accessed one secrets.""" + entry = AuditLog( + operation="ACCESS", + message="Secret was viewed", + object="ClientSecret", + object_id=str(secret.id), + client_id=client.id, + client_name=client.name, + ) + _write_audit_log(session, request, entry, commit) + + +def audit_access_audit_log( + session: Session, request: Request, commit: bool = True +) -> None: + """Audit access to the audit log. + + Because why not... + """ + entry = AuditLog( + operation="ACCESS", + message="Audit log was viewed", + object="AuditLog", + ) + _write_audit_log(session, request, entry, commit) diff --git a/packages/sshecret-backend/src/sshecret_backend/cli.py b/packages/sshecret-backend/src/sshecret_backend/cli.py new file mode 100644 index 0000000..e847f8e --- /dev/null +++ b/packages/sshecret-backend/src/sshecret_backend/cli.py @@ -0,0 +1,24 @@ +"""CLI and main entry point.""" + +from pathlib import Path +import click +from pydantic import BaseModel, FilePath + + + + +class BackendSettings(BaseModel): + """Backend Settings.""" + + db_file: FilePath + regenerate_tokens: bool = False + + +@click.group() +@click.option("--db-file", envvar="sshecret_db_file", type=click.Path(path_type=Path)) +@click.option("--regenerate-tokens", is_flag=True, default=False) +@click.pass_context +def cli(ctx: click.Context, db_file: Path, regenerate_tokens: bool) -> None: + """Sshecret database handler.""" + if not isinstance(ctx.obj, BackendSettings): + ctx.obj = BackendSettings(db_file=db_file, regenerate_tokens=regenerate_tokens) diff --git a/packages/sshecret-backend/src/sshecret_backend/db.py b/packages/sshecret-backend/src/sshecret_backend/db.py new file mode 100644 index 0000000..9e1a353 --- /dev/null +++ b/packages/sshecret-backend/src/sshecret_backend/db.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 + +import secrets +from pathlib import Path +from sqlalchemy import Engine +from sqlmodel import Session, create_engine, text +import bcrypt + +from dotenv import load_dotenv +from sqlalchemy.engine import URL + +from .models import APIClient, init_db + + +load_dotenv() + + +def get_engine(filename: Path, echo: bool = False) -> Engine: + """Initialize the engine.""" + url = URL.create(drivername="sqlite", database=str(filename.absolute())) + engine = create_engine(url, echo=echo) + with engine.connect() as connection: + connection.execute(text("PRAGMA foreign_keys=ON")) # for SQLite only + + return engine + + +def create_db_and_tables(filename: Path, echo: bool = True) -> bool: + """Create database and tables. + + Returns True if the database was created. + """ + created = False + if not filename.exists(): + created = True + engine = get_engine(filename, echo) + + init_db(engine) + return created + + +def create_api_token(session: Session, read_write: bool) -> str: + """Create API token.""" + token = secrets.token_urlsafe(32) + pwbytes = token.encode("utf-8") + salt = bcrypt.gensalt() + hashed_bytes = bcrypt.hashpw(password=pwbytes, salt=salt) + hashed = hashed_bytes.decode() + + api_token = APIClient(token=hashed, read_write=read_write) + session.add(api_token) + session.commit() + + return token diff --git a/packages/sshecret-backend/src/sshecret_backend/models.py b/packages/sshecret-backend/src/sshecret_backend/models.py new file mode 100644 index 0000000..00f7af5 --- /dev/null +++ b/packages/sshecret-backend/src/sshecret_backend/models.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 + +import uuid +from datetime import datetime +from sqlalchemy import Engine, Column, DateTime, func +from sqlmodel import Field, Relationship, SQLModel + + +class Client(SQLModel, table=True): + """Client model.""" + + id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + name: str = Field(unique=True) + fingerprint: str + created_at: datetime = Field( + default=None, + sa_column=Column( + DateTime(timezone=True), server_default=func.now(), nullable=True + ), + ) + updated_at: datetime | None = Field( + default=None, + sa_column=Column(DateTime(timezone=True), onupdate=func.now(), nullable=True), + ) + + secrets: list["ClientSecret"] = Relationship( + back_populates="client", passive_deletes="all" + ) + + +class ClientSecret(SQLModel, table=True): + """A client secret.""" + + id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + name: str + client_id: uuid.UUID | None = Field(foreign_key="client.id", ondelete="CASCADE") + client: Client | None = Relationship(back_populates="secrets") + secret: str + invalidated: bool = Field(default=False) + created_at: datetime = Field( + default=None, + sa_column=Column( + DateTime(timezone=True), server_default=func.now(), nullable=True + ), + ) + updated_at: datetime | None = Field( + default=None, + sa_column=Column(DateTime(timezone=True), onupdate=func.now(), nullable=True), + ) + + +class AuditLog(SQLModel, table=True): + """Audit log. + + This is implemented without any foreign keys to avoid losing data on + deletions. + """ + + id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + object: str | None = None + object_id: str | None = None + operation: str + client_id: uuid.UUID | None = None + client_name: str | None = None + message: str + origin: str | None = None + timestamp: datetime | None = Field( + default=None, + sa_column=Column( + DateTime(timezone=True), server_default=func.now(), nullable=True + ), + ) + + +class APIClient(SQLModel, table=True): + """Stores API Keys.""" + + id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + token: str + read_write: bool + created_at: datetime = Field( + default=None, + sa_column=Column( + DateTime(timezone=True), server_default=func.now(), nullable=True + ), + ) + + +def init_db(engine: Engine) -> None: + """Create database.""" + SQLModel.metadata.create_all(engine) diff --git a/packages/sshecret-backend/src/sshecret_backend/py.typed b/packages/sshecret-backend/src/sshecret_backend/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/packages/sshecret-backend/src/sshecret_backend/router.py b/packages/sshecret-backend/src/sshecret_backend/router.py new file mode 100644 index 0000000..b15c5dc --- /dev/null +++ b/packages/sshecret-backend/src/sshecret_backend/router.py @@ -0,0 +1,9 @@ +"""API Router.""" + +from fastapi import FastAPI + +from .app import backend_api + + +app = FastAPI() +app.include_router(backend_api) diff --git a/packages/sshecret-backend/src/sshecret_backend/settings.py b/packages/sshecret-backend/src/sshecret_backend/settings.py new file mode 100644 index 0000000..de75d9a --- /dev/null +++ b/packages/sshecret-backend/src/sshecret_backend/settings.py @@ -0,0 +1,26 @@ +"""Settings management.""" + +import os +from pathlib import Path +from pydantic import BaseModel +from dotenv import load_dotenv + +DEFAULT_DATABASE = "sshecret.db" + + +load_dotenv() + + +class BackendSettings(BaseModel): + """Backend Settings.""" + + db_file: Path + regenerate_tokens: bool = False + + +def get_settings() -> BackendSettings: + """Get settings.""" + db_filename = os.getenv("SSHECRET_DATABASE") or DEFAULT_DATABASE + db_file = Path(db_filename).absolute() + + return BackendSettings(db_file=db_file) diff --git a/packages/sshecret-backend/src/sshecret_backend/testing.py b/packages/sshecret-backend/src/sshecret_backend/testing.py new file mode 100644 index 0000000..a27364a --- /dev/null +++ b/packages/sshecret-backend/src/sshecret_backend/testing.py @@ -0,0 +1,13 @@ +"""Test helpers.""" + +from sqlmodel import Session +from .db import get_engine, create_api_token +from .models import init_db +from .settings import get_settings + +def create_test_token(session: Session) -> str: + """Create test token.""" + settings = get_settings() + engine = get_engine(settings.db_file) + init_db(engine) + return create_api_token(session, True) diff --git a/packages/sshecret-backend/src/sshecret_backend/view_models.py b/packages/sshecret-backend/src/sshecret_backend/view_models.py new file mode 100644 index 0000000..52b3fc1 --- /dev/null +++ b/packages/sshecret-backend/src/sshecret_backend/view_models.py @@ -0,0 +1,112 @@ +"""Models for API views.""" + +import uuid +from datetime import datetime +from typing import Self, override + +from sqlmodel import Field, SQLModel +from . import models + + +class ClientListResponse(SQLModel): + """Model list responses.""" + + id: uuid.UUID + name: str + fingerprint: str + created_at: datetime + updated_at: datetime | None = None + + @classmethod + def from_clients(cls, clients: list[models.Client]) -> list[Self]: + """Generate a list of responses from a list of clients.""" + responses: list[Self] = [] + for client in clients: + responses.append( + cls( + id=client.id, + name=client.name, + fingerprint=client.fingerprint, + created_at=client.created_at, + updated_at=client.updated_at or None, + ) + ) + return responses + + +class ClientView(ClientListResponse): + """View for a single client.""" + + secrets: list[str] = Field(default_factory=list) + + @classmethod + def from_client(cls, client: models.Client) -> Self: + """Instantiate from a client.""" + view = cls( + id=client.id, + name=client.name, + fingerprint=client.fingerprint, + created_at=client.created_at, + updated_at=client.updated_at or None, + ) + if client.secrets: + view.secrets = [secret.name for secret in client.secrets] + + return view + + +class ClientCreate(SQLModel): + """Model to create a client.""" + + name: str + fingerprint: str + + def to_client(self) -> models.Client: + """Instantiate a client.""" + return models.Client(name=self.name, fingerprint=self.fingerprint) + + +class ClientUpdate(SQLModel): + """Model to update the client fingerprint.""" + + fingerprint: str + + +class BodyValue(SQLModel): + """A generic model with just a value parameter.""" + + value: str + + +class ClientSecretPublic(SQLModel): + """Public model to manage client secrets.""" + + name: str + secret: str + + @classmethod + def from_client_secret(cls, client_secret: models.ClientSecret) -> Self: + """Instantiate from ClientSecret.""" + return cls( + name=client_secret.name, + secret=client_secret.secret, + ) + + +class ClientSecretResponse(ClientSecretPublic): + """A secret view.""" + + created_at: datetime + updated_at: datetime | None = None + + @override + @classmethod + def from_client_secret(cls, client_secret: models.ClientSecret) -> Self: + """Instantiate from ClientSecret.""" + + return cls( + name=client_secret.name, + secret=client_secret.secret, + created_at=client_secret.created_at, + updated_at=client_secret.updated_at, + ) diff --git a/packages/sshecret-backend/sshecret.db b/packages/sshecret-backend/sshecret.db new file mode 100644 index 0000000..b6f9b70 Binary files /dev/null and b/packages/sshecret-backend/sshecret.db differ diff --git a/packages/sshecret-backend/tests/test_backend.py b/packages/sshecret-backend/tests/test_backend.py new file mode 100644 index 0000000..eb4659d --- /dev/null +++ b/packages/sshecret-backend/tests/test_backend.py @@ -0,0 +1,279 @@ +"""Tests of the backend api using pytest.""" + +import logging +from httpx import Response +import pytest + +from fastapi.testclient import TestClient +from sqlmodel import Session, SQLModel, create_engine +from sqlmodel.pool import StaticPool + +from sshecret_backend.router import app +from sshecret_backend.app import get_session +from sshecret_backend.testing import create_test_token +from sshecret_backend.models import AuditLog + + +LOG = logging.getLogger() +handler = logging.StreamHandler() +formatter = logging.Formatter("'%(asctime)s - %(levelname)s - %(message)s'") +handler.setFormatter(formatter) +LOG.addHandler(handler) +LOG.setLevel(logging.DEBUG) + + +TEST_FINGERPRINT = "00:11:22:33:44:55:66:77:88:99:aa:bb:cc:dd:ee:ff" + + +def create_client( + test_client: TestClient, + headers: dict[str, str], + name: str, + fingerprint: str = TEST_FINGERPRINT, +) -> Response: + """Create client.""" + data = {"name": name, "fingerprint": fingerprint} + create_response = test_client.post("/api/v1/clients", headers=headers, json=data) + return create_response + + +@pytest.fixture(name="session") +def session_fixture(): + engine = create_engine( + "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool + ) + SQLModel.metadata.create_all(engine) + with Session(engine) as session: + yield session + + +@pytest.fixture(name="token") +def token_fixture(session: Session): + """Generate a token.""" + token = create_test_token(session) + return token + + +@pytest.fixture(name="headers") +def headers_fixture(token: str) -> dict[str, str]: + """Generate headers.""" + return {"X-API-Token": token} + + +@pytest.fixture(name="test_client") +def test_client_fixture(session: Session): + """Test client fixture.""" + + def get_session_override(): + return session + + app.dependency_overrides[get_session] = get_session_override + test_client = TestClient(app) + yield test_client + app.dependency_overrides.clear() + + +def test_missing_token(test_client: TestClient) -> None: + """Test logging in with missing token.""" + response = test_client.get("/api/v1/clients/") + assert response.status_code == 422 + + +def test_incorrect_token(test_client: TestClient) -> None: + """Test logging in with missing token.""" + response = test_client.get("/api/v1/clients/", headers={"X-API-Token": "WRONG"}) + assert response.status_code == 401 + + +def test_with_token(test_client: TestClient, token: str) -> None: + """Test with a valid token.""" + response = test_client.get("/api/v1/clients/", headers={"X-API-Token": token}) + assert response.status_code == 200 + assert len(response.json()) == 0 + + +def test_create_client(test_client: TestClient, headers: dict[str, str]) -> None: + """Test creating a client.""" + client_name = "test" + client_fingerprint = TEST_FINGERPRINT + create_response = create_client( + test_client, headers, client_name, client_fingerprint + ) + assert create_response.status_code == 200 + response = test_client.get("/api/v1/clients/", headers=headers) + assert response.status_code == 200 + clients = response.json() + assert isinstance(clients, list) + client = clients[0] + assert isinstance(client, dict) + assert client.get("name") == client_name + assert client.get("fingerprint") == client_fingerprint + assert client.get("created_at") is not None + + +def test_add_secret(test_client: TestClient, headers: dict[str, str]) -> None: + """Test adding a secret to a client.""" + client_name = "test" + client_fingerprint = "00:11:22:33:44:55:66:77:88:99:aa:bb:cc:dd:ee:ff" + create_response = create_client( + test_client, headers, client_name, client_fingerprint + ) + assert create_response.status_code == 200 + secret_name = "mysecret" + secret_value = "shhhh" + data = {"name": secret_name, "secret": secret_value} + response = test_client.post( + "/api/v1/clients/test/secrets/", headers=headers, json=data + ) + assert response.status_code == 200 + # Get it back + get_response = test_client.get( + "/api/v1/clients/test/secrets/mysecret", headers=headers + ) + assert get_response.status_code == 200 + secret_body = get_response.json() + assert secret_body["name"] == data["name"] + assert secret_body["secret"] == data["secret"] + + +def test_put_add_secret(test_client: TestClient, headers: dict[str, str]) -> None: + """Test adding secret via PUT.""" + # Use the test_create_client function to create a client. + test_create_client(test_client, headers) + secret_name = "mysecret" + secret_value = "shhhh" + data = {"name": secret_name, "secret": secret_value} + response = test_client.put( + "/api/v1/clients/test/secrets/mysecret", + headers=headers, + json={"value": secret_value}, + ) + assert response.status_code == 200 + response_model = response.json() + del response_model["created_at"] + del response_model["updated_at"] + assert response_model == data + + +def test_put_update_secret(test_client: TestClient, headers: dict[str, str]) -> None: + """Test updating a client secret.""" + test_add_secret(test_client, headers) + new_value = "itsasecret" + update_response = test_client.put( + "/api/v1/clients/test/secrets/mysecret", + headers=headers, + json={"value": new_value}, + ) + assert update_response.status_code == 200 + expected = {"name": "mysecret", "secret": new_value} + response_model = update_response.json() + + assert { + "name": response_model["name"], + "secret": response_model["secret"], + } == expected + # Ensure that the updated_at has been set. + assert "updated_at" in response_model + + +def test_audit_logging(test_client: TestClient, headers: dict[str, str]) -> None: + """Test audit logging.""" + create_client_resp = create_client(test_client, headers, "test") + assert create_client_resp.status_code == 200 + secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"} + for name, secret in secrets.items(): + add_resp = test_client.post( + "/api/v1/clients/test/secrets/", + headers=headers, + json={"name": name, "secret": secret}, + ) + assert add_resp.status_code == 200 + + # Fetch the entire client. + get_client_resp = test_client.get("/api/v1/clients/test", headers=headers) + assert get_client_resp.status_code == 200 + + # Fetch the audit log + audit_log_resp = test_client.get("/api/v1/audit/", headers=headers) + assert audit_log_resp.status_code == 200 + audit_logs = audit_log_resp.json() + assert len(audit_logs) > 0 + for entry in audit_logs: + # Let's try to reassemble the objects + audit_log = AuditLog.model_validate(entry) + assert audit_log is not None + + +def test_audit_log_filtering( + session: Session, test_client: TestClient, headers: dict[str, str] +) -> None: + """Test audit log filtering.""" + # Create a lot of test data, but just manually. + audit_log_amount = 150 + entries: list[AuditLog] = [] + for i in range(audit_log_amount): + client_id = i % 5 + entries.append( + AuditLog( + operation="TEST", + object_id=str(i), + client_name=f"client-{client_id}", + message="Test Message", + ) + ) + + session.add_all(entries) + session.commit() + + # This should have generated a lot of audit messages + + audit_path = "/api/v1/audit/" + audit_log_resp = test_client.get(audit_path, headers=headers) + assert audit_log_resp.status_code == 200 + entries = audit_log_resp.json() + assert len(entries) == 100 # We get 100 at a time + + audit_log_resp = test_client.get( + audit_path, headers=headers, params={"offset": 100} + ) + entries = audit_log_resp.json() + assert len(entries) == 52 # There should be 50 + the two requests we made + + # Try to get a specific client + # There should be 30 log entries for each client. + audit_log_resp = test_client.get( + audit_path, headers=headers, params={"filter_client": "client-1"} + ) + + entries = audit_log_resp.json() + assert len(entries) == 30 + + +def test_secret_invalidation(test_client: TestClient, headers: dict[str, str]) -> None: + """Test secret invalidation.""" + create_client_resp = create_client(test_client, headers, "test") + assert create_client_resp.status_code == 200 + secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"} + for name, secret in secrets.items(): + add_resp = test_client.post( + "/api/v1/clients/test/secrets/", + headers=headers, + json={"name": name, "secret": secret}, + ) + assert add_resp.status_code == 200 + + # Update the fingerprint. This should cause all secrets to be invalidated + # and no longer associated with a client. + update_resp = test_client.post( + "/api/v1/clients/test/update_fingerprint", + headers=headers, + json={"fingerprint": "foobar"}, + ) + assert update_resp.status_code == 200 + + # Fetch the client. The list of secrets should be empty. + get_resp = test_client.get("/api/v1/clients/test", headers=headers) + assert get_resp.status_code == 200 + client = get_resp.json() + secrets = client.get("secrets") + assert bool(secrets) is False diff --git a/packages/sshecret-sshd/README.md b/packages/sshecret-sshd/README.md new file mode 100644 index 0000000..e69de29 diff --git a/packages/sshecret-sshd/pyproject.toml b/packages/sshecret-sshd/pyproject.toml new file mode 100644 index 0000000..42991b9 --- /dev/null +++ b/packages/sshecret-sshd/pyproject.toml @@ -0,0 +1,17 @@ +[project] +name = "sshecret-sshd" +version = "0.1.0" +description = "Add your description here" +readme = "README.md" +authors = [ + { name = "Allan Eising", email = "allan@eising.dk" } +] +requires-python = ">=3.13" +dependencies = [ + "asyncssh>=2.20.0", + "httpx>=0.28.1", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" diff --git a/packages/sshecret-sshd/src/sshecret_sshd/__init__.py b/packages/sshecret-sshd/src/sshecret_sshd/__init__.py new file mode 100644 index 0000000..66941f4 --- /dev/null +++ b/packages/sshecret-sshd/src/sshecret_sshd/__init__.py @@ -0,0 +1,2 @@ +def hello() -> str: + return "Hello from sshecret-sshd!" diff --git a/packages/sshecret-sshd/src/sshecret_sshd/py.typed b/packages/sshecret-sshd/src/sshecret_sshd/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/packages/sshecret_client/README.md b/packages/sshecret_client/README.md new file mode 100644 index 0000000..e69de29 diff --git a/packages/sshecret_client/client.py b/packages/sshecret_client/client.py new file mode 100644 index 0000000..1561b7b --- /dev/null +++ b/packages/sshecret_client/client.py @@ -0,0 +1,39 @@ +"""Client code""" + +import base64 + +from typing import TextIO +import click +import asyncio +import asyncssh + +from sshecret.crypto import decode_string, load_private_key + + +# async def request_secret(host: str, port: str, username: str, client_key: str, secretname: str) -> str: +# """Request secret.""" +# async with asyncssh.connect(host, port, client_username=username, client_keys=[client_key]) as conn: +# result = await conn.run(secretname, check=True) + +# if encoded := result.stdout: +# if isinstance(encoded, str): +# return encoded +# return encoded.decode() + + +def decrypt_secret(encoded: str, client_key: str) -> str: + """Decrypt secret.""" + private_key = load_private_key(client_key) + return decode_string(encoded, private_key) + + +@click.command() +@click.argument("keyfile", type=click.Path(exists=True, readable=True, dir_okay=False)) +@click.argument("encrypted_input", type=click.File("r")) +def cli_decrypt(keyfile: str, encrypted_input: TextIO) -> None: + """Decrypt on command line.""" + decrypted = decrypt_secret(encrypted_input.read(), keyfile) + click.echo(decrypted) + +if __name__ == "__main__": + cli_decrypt() diff --git a/packages/sshecret_client/pyproject.toml b/packages/sshecret_client/pyproject.toml new file mode 100644 index 0000000..7f01381 --- /dev/null +++ b/packages/sshecret_client/pyproject.toml @@ -0,0 +1,16 @@ +[project] +name = "sshecret-client" +version = "0.1.0" +description = "Add your description here" +readme = "README.md" +requires-python = ">=3.13" +dependencies = [ + "asyncssh>=2.20.0", + "click>=8.1.8", + "cryptography>=44.0.2", + "paramiko>=3.5.1", + "sshecret", +] + +[tool.uv.sources] +sshecret = { workspace = true }