From ec90fb768010a43b65ee625fc5ac2688413b785a Mon Sep 17 00:00:00 2001 From: Allan Eising Date: Fri, 18 Apr 2025 16:39:05 +0200 Subject: [PATCH] Complete backend --- packages/sshecret-backend/pyproject.toml | 3 + .../src/sshecret_backend/__init__.py | 4 + .../src/sshecret_backend/app.py | 171 ++++++++++++++-- .../src/sshecret_backend/audit.py | 74 ++++++- .../src/sshecret_backend/cli.py | 35 ++-- .../src/sshecret_backend/db.py | 21 +- .../src/sshecret_backend/models.py | 80 +++++--- .../src/sshecret_backend/router.py | 9 - .../src/sshecret_backend/settings.py | 25 ++- .../src/sshecret_backend/view_models.py | 67 +++--- .../sshecret-backend/tests/test_backend.py | 193 ++++++++++++++++-- 11 files changed, 561 insertions(+), 121 deletions(-) delete mode 100644 packages/sshecret-backend/src/sshecret_backend/router.py diff --git a/packages/sshecret-backend/pyproject.toml b/packages/sshecret-backend/pyproject.toml index 9cedf2e..eaa7268 100644 --- a/packages/sshecret-backend/pyproject.toml +++ b/packages/sshecret-backend/pyproject.toml @@ -14,6 +14,9 @@ dependencies = [ "sqlmodel>=0.0.24", ] +[project.scripts] +sshecret-backend = "sshecret_backend.cli:cli" + [build-system] requires = ["hatchling"] build-backend = "hatchling.build" diff --git a/packages/sshecret-backend/src/sshecret_backend/__init__.py b/packages/sshecret-backend/src/sshecret_backend/__init__.py index c2288fe..4053693 100644 --- a/packages/sshecret-backend/src/sshecret_backend/__init__.py +++ b/packages/sshecret-backend/src/sshecret_backend/__init__.py @@ -1 +1,5 @@ """Sshecret backend.""" +from .app import app as app +#from .router import app as app + +__all__ = ["app"] diff --git a/packages/sshecret-backend/src/sshecret_backend/app.py b/packages/sshecret-backend/src/sshecret_backend/app.py index 1e995dd..f9a6296 100644 --- a/packages/sshecret-backend/src/sshecret_backend/app.py +++ b/packages/sshecret-backend/src/sshecret_backend/app.py @@ -1,26 +1,52 @@ -"""FastAPI api.""" +"""FastAPI api. + +TODO: We may want to allow a consumer to generate audit log entries manually. + +""" import logging +from collections.abc import Sequence from contextlib import asynccontextmanager from typing import Annotated -from collections.abc import Sequence import bcrypt -from fastapi import APIRouter, Depends, FastAPI, Header, HTTPException, Query, Request +from fastapi import ( + APIRouter, + Depends, + FastAPI, + Header, + HTTPException, + Query, + Request, + status, +) +from fastapi.encoders import jsonable_encoder +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse + + from sqlmodel import Session, select from . import audit from .db import get_engine -from .models import APIClient, AuditLog, Client, ClientSecret, init_db +from .models import ( + APIClient, + AuditLog, + Client, + ClientAccessPolicy, + ClientSecret, + init_db, +) from .settings import get_settings from .view_models import ( BodyValue, ClientCreate, - ClientListResponse, ClientSecretPublic, ClientSecretResponse, ClientUpdate, ClientView, + ClientPolicyView, + ClientPolicyUpdate, ) settings = get_settings() @@ -104,12 +130,12 @@ backend_api = APIRouter( @backend_api.get("/clients/") async def get_clients( session: Annotated[Session, Depends(get_session)] -) -> list[ClientListResponse]: +) -> list[ClientView]: """Get clients.""" statement = select(Client) results = session.exec(statement) clients = list(results) - return ClientListResponse.from_clients(clients) + return ClientView.from_client_list(clients) @backend_api.get("/clients/{name}") @@ -128,14 +154,93 @@ async def get_client( return ClientView.from_client(client) -@backend_api.post("/clients/{name}/update_fingerprint") -async def update_client_fingerprint( +@backend_api.delete("/clients/{name}") +async def delete_client( + request: Request, name: str, session: Annotated[Session, Depends(get_session)] +) -> None: + """Delete a client.""" + statement = select(Client).where(Client.name == name) + results = session.exec(statement) + client = results.first() + if not client: + raise HTTPException( + status_code=404, detail="Cannot find a client with the given name." + ) + + session.delete(client) + session.commit() + audit.audit_delete_client(session, request, client) + + +@backend_api.get("/clients/{name}/policies/") +async def get_client_policies( + name: str, session: Annotated[Session, Depends(get_session)] +) -> ClientPolicyView: + """Get client policies.""" + statement = select(Client).where(Client.name == name) + results = session.exec(statement) + client = results.first() + if not client: + raise HTTPException( + status_code=404, detail="Cannot find a client with the given name." + ) + + return ClientPolicyView.from_client(client) + + +@backend_api.put("/clients/{name}/policies/") +async def update_client_policies( + request: Request, + name: str, + policy_update: ClientPolicyUpdate, + session: Annotated[Session, Depends(get_session)], +) -> ClientPolicyView: + """Update client policies. + + This is also how you delete policies. + """ + statement = select(Client).where(Client.name == name) + results = session.exec(statement) + client = results.first() + if not client: + raise HTTPException( + status_code=404, detail="Cannot find a client with the given name." + ) + # Remove old policies. + policies = session.exec( + select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id) + ).all() + deleted_policies: list[ClientAccessPolicy] = [] + added_policies: list[ClientAccessPolicy] = [] + for policy in policies: + session.delete(policy) + deleted_policies.append(policy) + + for source in policy_update.sources: + LOG.debug("Source %r", source) + policy = ClientAccessPolicy(source=str(source), client_id=client.id) + session.add(policy) + added_policies.append(policy) + + session.commit() + session.refresh(client) + for policy in deleted_policies: + audit.audit_remove_policy(session, request, client, policy) + + for policy in added_policies: + audit.audit_update_policy(session, request, client, policy) + + return ClientPolicyView.from_client(client) + + +@backend_api.post("/clients/{name}/public-key") +async def update_client_public_key( request: Request, name: str, client_update: ClientUpdate, session: Annotated[Session, Depends(get_session)], ) -> ClientView: - """Update the client fingerprint. + """Change the public key of a client. This invalidates all secrets. """ @@ -146,7 +251,7 @@ async def update_client_fingerprint( raise HTTPException( status_code=404, detail="Cannot find a client with the given name." ) - client.fingerprint = client_update.fingerprint + client.public_key = client_update.public_key for secret in session.exec( select(ClientSecret).where(ClientSecret.client_id == client.id) ).all(): @@ -170,7 +275,11 @@ async def create_client( session: Annotated[Session, Depends(get_session)], ) -> ClientView: """Create client.""" - db_client = Client.model_validate(client) + existing = await get_client_by_name(session, client.name) + if existing: + raise HTTPException(400, detail="Error: Already a client with that name.") + + db_client = client.to_client() session.add(db_client) session.commit() session.refresh(db_client) @@ -270,6 +379,30 @@ async def request_client_secret( audit.audit_access_secret(session, request, client, secret) return response_model +@backend_api.delete("/clients/{name}/secrets/{secret_name}") +async def delete_client_secret( + request: Request, + name: str, + secret_name: str, + session: Annotated[Session, Depends(get_session)], +) -> None: + """Delete a secret.""" + client = await get_client_by_name(session, name) + if not client: + raise HTTPException( + status_code=404, detail="Cannot find a client with the given name." + ) + + secret = await lookup_client_secret(session, client, secret_name) + if not secret: + raise HTTPException( + status_code=404, detail="Cannot find a secret with the given name." + ) + + session.delete(secret) + session.commit() + audit.audit_delete_secret(session, request, client, secret) + @backend_api.get("/audit/", response_model=list[AuditLog]) async def get_audit_logs( @@ -287,3 +420,17 @@ async def get_audit_logs( results = session.exec(statement).all() return results + + +app = FastAPI() + + +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request: Request, exc: RequestValidationError): + return JSONResponse( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}), + ) + + +app.include_router(backend_api) diff --git a/packages/sshecret-backend/src/sshecret_backend/audit.py b/packages/sshecret-backend/src/sshecret_backend/audit.py index 6b74cf8..69f15b0 100644 --- a/packages/sshecret-backend/src/sshecret_backend/audit.py +++ b/packages/sshecret-backend/src/sshecret_backend/audit.py @@ -3,7 +3,7 @@ from collections.abc import Sequence from fastapi import Request from sqlmodel import Session, select -from .models import AuditLog, Client, ClientSecret +from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy def _get_origin(request: Request) -> str | None: @@ -39,6 +39,19 @@ def audit_create_client( _write_audit_log(session, request, entry, commit) +def audit_delete_client( + session: Session, request: Request, client: Client, commit: bool = True +) -> None: + """Log the creation of a client.""" + entry = AuditLog( + operation="CREATE", + client_id=client.id, + client_name=client.name, + message="Client deleted", + ) + _write_audit_log(session, request, entry, commit) + + def audit_create_secret( session: Session, request: Request, @@ -58,6 +71,44 @@ def audit_create_secret( _write_audit_log(session, request, entry, commit) +def audit_remove_policy( + session: Session, + request: Request, + client: Client, + policy: ClientAccessPolicy, + commit: bool = True, +) -> None: + """Audit removal of policy.""" + entry = AuditLog( + operation="DELETE", + object="ClientAccessPolicy", + object_id=str(policy.id), + client_id=client.id, + client_name=client.name, + message="Deleted client policy", + ) + _write_audit_log(session, request, entry, commit) + + +def audit_update_policy( + session: Session, + request: Request, + client: Client, + policy: ClientAccessPolicy, + commit: bool = True, +) -> None: + """Audit update of policy.""" + entry = AuditLog( + operation="CREATE", + object="ClientAccessPolicy", + object_id=str(policy.id), + client_id=client.id, + client_name=client.name, + message="Updated client policy", + ) + _write_audit_log(session, request, entry, commit) + + def audit_update_secret( session: Session, request: Request, @@ -89,7 +140,26 @@ def audit_invalidate_secrets( object="ClientSecret", client_name=client.name, client_id=client.id, - message="Client fingerprint updated. All secrets invalidated.", + message="Client public-key changed. All secrets invalidated.", + ) + _write_audit_log(session, request, entry, commit) + + +def audit_delete_secret( + session: Session, + request: Request, + client: Client, + secret: ClientSecret, + commit: bool = True, +) -> None: + """Audit Delete client secrets.""" + entry = AuditLog( + operation="DELETE", + object="ClientSecret", + object_id=str(secret.id), + client_name=client.name, + client_id=client.id, + message="Deleted secret.", ) _write_audit_log(session, request, entry, commit) diff --git a/packages/sshecret-backend/src/sshecret_backend/cli.py b/packages/sshecret-backend/src/sshecret_backend/cli.py index e847f8e..d314e4f 100644 --- a/packages/sshecret-backend/src/sshecret_backend/cli.py +++ b/packages/sshecret-backend/src/sshecret_backend/cli.py @@ -1,24 +1,31 @@ """CLI and main entry point.""" +import os from pathlib import Path +from dotenv import load_dotenv import click -from pydantic import BaseModel, FilePath +from .db import generate_api_token +DEFAULT_LISTEN = "127.0.0.1" +DEFAULT_PORT = 8022 +WORKDIR = Path(os.getcwd()) -class BackendSettings(BaseModel): - """Backend Settings.""" - - db_file: FilePath - regenerate_tokens: bool = False - +load_dotenv() @click.group() -@click.option("--db-file", envvar="sshecret_db_file", type=click.Path(path_type=Path)) -@click.option("--regenerate-tokens", is_flag=True, default=False) -@click.pass_context -def cli(ctx: click.Context, db_file: Path, regenerate_tokens: bool) -> None: - """Sshecret database handler.""" - if not isinstance(ctx.obj, BackendSettings): - ctx.obj = BackendSettings(db_file=db_file, regenerate_tokens=regenerate_tokens) +@click.option("--database", help="Path to the sqlite database file.") +def cli(database: str) -> None: + """CLI group.""" + if database: + # Hopefully it's enough to set the environment variable as so. + os.environ["SSHECRET_DB_FILE"] = str(Path(database).absolute()) + + +@cli.command("generate-token") +def cli_generate_token() -> None: + """Generate a token.""" + token = generate_api_token() + click.echo("Generated api token:") + click.echo(token) diff --git a/packages/sshecret-backend/src/sshecret_backend/db.py b/packages/sshecret-backend/src/sshecret_backend/db.py index 9e1a353..3cc9626 100644 --- a/packages/sshecret-backend/src/sshecret_backend/db.py +++ b/packages/sshecret-backend/src/sshecret_backend/db.py @@ -1,18 +1,20 @@ -#!/usr/bin/env python3 +"""Database related functions.""" +import logging import secrets from pathlib import Path from sqlalchemy import Engine from sqlmodel import Session, create_engine, text import bcrypt -from dotenv import load_dotenv from sqlalchemy.engine import URL from .models import APIClient, init_db +from .settings import get_settings -load_dotenv() + +LOG = logging.getLogger(__name__) def get_engine(filename: Path, echo: bool = False) -> Engine: @@ -25,7 +27,7 @@ def get_engine(filename: Path, echo: bool = False) -> Engine: return engine -def create_db_and_tables(filename: Path, echo: bool = True) -> bool: +def create_db_and_tables(filename: Path, echo: bool = False) -> bool: """Create database and tables. Returns True if the database was created. @@ -52,3 +54,14 @@ def create_api_token(session: Session, read_write: bool) -> str: session.commit() return token + + +def generate_api_token() -> str: + """Generate API token.""" + settings = get_settings() + engine = get_engine(settings.db_file) + init_db(engine) + with Session(engine) as session: + token = create_api_token(session, True) + + return token diff --git a/packages/sshecret-backend/src/sshecret_backend/models.py b/packages/sshecret-backend/src/sshecret_backend/models.py index 00f7af5..6e370db 100644 --- a/packages/sshecret-backend/src/sshecret_backend/models.py +++ b/packages/sshecret-backend/src/sshecret_backend/models.py @@ -1,8 +1,15 @@ -#!/usr/bin/env python3 +"""Database models. + +TODO: + +We might want to pass on audit information from the SSH server. +This might require some changes to these schemas. + +""" import uuid from datetime import datetime -from sqlalchemy import Engine, Column, DateTime, func +import sqlalchemy as sa from sqlmodel import Field, Relationship, SQLModel @@ -11,22 +18,48 @@ class Client(SQLModel, table=True): id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) name: str = Field(unique=True) - fingerprint: str - created_at: datetime = Field( + public_key: str + + created_at: datetime | None = Field( default=None, - sa_column=Column( - DateTime(timezone=True), server_default=func.now(), nullable=True - ), + sa_type=sa.DateTime(timezone=True), + sa_column_kwargs={"server_default": sa.func.now()}, + nullable=False, ) + updated_at: datetime | None = Field( default=None, - sa_column=Column(DateTime(timezone=True), onupdate=func.now(), nullable=True), + sa_type=sa.DateTime(timezone=True), + sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()}, ) secrets: list["ClientSecret"] = Relationship( back_populates="client", passive_deletes="all" ) + policies: list["ClientAccessPolicy"] = Relationship(back_populates="client") + + +class ClientAccessPolicy(SQLModel, table=True): + """Client access policies.""" + + id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + source: str + client_id: uuid.UUID | None = Field(foreign_key="client.id", ondelete="CASCADE") + client: Client | None = Relationship(back_populates="policies") + + created_at: datetime | None = Field( + default=None, + sa_type=sa.DateTime(timezone=True), + sa_column_kwargs={"server_default": sa.func.now()}, + nullable=False, + ) + + updated_at: datetime | None = Field( + default=None, + sa_type=sa.DateTime(timezone=True), + sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()}, + ) class ClientSecret(SQLModel, table=True): """A client secret.""" @@ -37,15 +70,17 @@ class ClientSecret(SQLModel, table=True): client: Client | None = Relationship(back_populates="secrets") secret: str invalidated: bool = Field(default=False) - created_at: datetime = Field( + created_at: datetime | None = Field( default=None, - sa_column=Column( - DateTime(timezone=True), server_default=func.now(), nullable=True - ), + sa_type=sa.DateTime(timezone=True), + sa_column_kwargs={"server_default": sa.func.now()}, + nullable=False, ) + updated_at: datetime | None = Field( default=None, - sa_column=Column(DateTime(timezone=True), onupdate=func.now(), nullable=True), + sa_type=sa.DateTime(timezone=True), + sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()}, ) @@ -64,28 +99,27 @@ class AuditLog(SQLModel, table=True): client_name: str | None = None message: str origin: str | None = None + timestamp: datetime | None = Field( default=None, - sa_column=Column( - DateTime(timezone=True), server_default=func.now(), nullable=True - ), + sa_type=sa.DateTime(timezone=True), + sa_column_kwargs={"server_default": sa.func.now()}, + nullable=False, ) - class APIClient(SQLModel, table=True): """Stores API Keys.""" id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) token: str read_write: bool - created_at: datetime = Field( + created_at: datetime | None = Field( default=None, - sa_column=Column( - DateTime(timezone=True), server_default=func.now(), nullable=True - ), + sa_type=sa.DateTime(timezone=True), + sa_column_kwargs={"server_default": sa.func.now()}, + nullable=False, ) - -def init_db(engine: Engine) -> None: +def init_db(engine: sa.Engine) -> None: """Create database.""" SQLModel.metadata.create_all(engine) diff --git a/packages/sshecret-backend/src/sshecret_backend/router.py b/packages/sshecret-backend/src/sshecret_backend/router.py deleted file mode 100644 index b15c5dc..0000000 --- a/packages/sshecret-backend/src/sshecret_backend/router.py +++ /dev/null @@ -1,9 +0,0 @@ -"""API Router.""" - -from fastapi import FastAPI - -from .app import backend_api - - -app = FastAPI() -app.include_router(backend_api) diff --git a/packages/sshecret-backend/src/sshecret_backend/settings.py b/packages/sshecret-backend/src/sshecret_backend/settings.py index de75d9a..c06f0f1 100644 --- a/packages/sshecret-backend/src/sshecret_backend/settings.py +++ b/packages/sshecret-backend/src/sshecret_backend/settings.py @@ -1,26 +1,25 @@ """Settings management.""" -import os +from typing import override from pathlib import Path -from pydantic import BaseModel -from dotenv import load_dotenv +from pydantic import BaseModel, Field +from pydantic_settings import ( + BaseSettings, + SettingsConfigDict, +) + DEFAULT_DATABASE = "sshecret.db" -load_dotenv() +class BackendSettings(BaseSettings): + """Backend settings.""" + model_config = SettingsConfigDict(env_file=".backend.env", env_prefix="sshecret_") -class BackendSettings(BaseModel): - """Backend Settings.""" - - db_file: Path - regenerate_tokens: bool = False + db_file: Path = Field(default=Path(DEFAULT_DATABASE).absolute()) def get_settings() -> BackendSettings: """Get settings.""" - db_filename = os.getenv("SSHECRET_DATABASE") or DEFAULT_DATABASE - db_file = Path(db_filename).absolute() - - return BackendSettings(db_file=db_file) + return BackendSettings() diff --git a/packages/sshecret-backend/src/sshecret_backend/view_models.py b/packages/sshecret-backend/src/sshecret_backend/view_models.py index 52b3fc1..3799d72 100644 --- a/packages/sshecret-backend/src/sshecret_backend/view_models.py +++ b/packages/sshecret-backend/src/sshecret_backend/view_models.py @@ -1,57 +1,48 @@ """Models for API views.""" +import ipaddress import uuid from datetime import datetime -from typing import Self, override +from typing import Annotated, Any, Self, override from sqlmodel import Field, SQLModel +from pydantic import IPvAnyAddress, IPvAnyNetwork from . import models -class ClientListResponse(SQLModel): - """Model list responses.""" +class ClientView(SQLModel): + """View for a single client.""" id: uuid.UUID name: str - fingerprint: str + public_key: str + policies: list[str] = ["0.0.0.0/0", "::/0"] + secrets: list[str] = Field(default_factory=list) created_at: datetime updated_at: datetime | None = None @classmethod - def from_clients(cls, clients: list[models.Client]) -> list[Self]: + def from_client_list(cls, clients: list[models.Client]) -> list[Self]: """Generate a list of responses from a list of clients.""" - responses: list[Self] = [] - for client in clients: - responses.append( - cls( - id=client.id, - name=client.name, - fingerprint=client.fingerprint, - created_at=client.created_at, - updated_at=client.updated_at or None, - ) - ) + responses: list[Self] = [cls.from_client(client) for client in clients] return responses - -class ClientView(ClientListResponse): - """View for a single client.""" - - secrets: list[str] = Field(default_factory=list) - @classmethod def from_client(cls, client: models.Client) -> Self: """Instantiate from a client.""" view = cls( id=client.id, name=client.name, - fingerprint=client.fingerprint, + public_key=client.public_key, created_at=client.created_at, updated_at=client.updated_at or None, ) if client.secrets: view.secrets = [secret.name for secret in client.secrets] + if client.policies: + view.policies = [policy.source for policy in client.policies] + return view @@ -59,17 +50,20 @@ class ClientCreate(SQLModel): """Model to create a client.""" name: str - fingerprint: str + public_key: str def to_client(self) -> models.Client: """Instantiate a client.""" - return models.Client(name=self.name, fingerprint=self.fingerprint) + public_key = self.public_key + return models.Client( + name=self.name, public_key=public_key + ) class ClientUpdate(SQLModel): - """Model to update the client fingerprint.""" + """Model to update the client public key.""" - fingerprint: str + public_key: str class BodyValue(SQLModel): @@ -110,3 +104,22 @@ class ClientSecretResponse(ClientSecretPublic): created_at=client_secret.created_at, updated_at=client_secret.updated_at, ) + + +class ClientPolicyView(SQLModel): + """Update object for client policy.""" + + sources: list[str] = ["0.0.0.0/0", "::/0"] + + @classmethod + def from_client(cls, client: models.Client) -> Self: + """Create from client.""" + if not client.policies: + return cls() + return cls(sources=[policy.source for policy in client.policies]) + + +class ClientPolicyUpdate(SQLModel): + """Model for updating policies.""" + + sources: list[IPvAnyAddress | IPvAnyNetwork] diff --git a/packages/sshecret-backend/tests/test_backend.py b/packages/sshecret-backend/tests/test_backend.py index eb4659d..febbd3f 100644 --- a/packages/sshecret-backend/tests/test_backend.py +++ b/packages/sshecret-backend/tests/test_backend.py @@ -1,6 +1,8 @@ """Tests of the backend api using pytest.""" import logging +import random +import string from httpx import Response import pytest @@ -8,8 +10,7 @@ from fastapi.testclient import TestClient from sqlmodel import Session, SQLModel, create_engine from sqlmodel.pool import StaticPool -from sshecret_backend.router import app -from sshecret_backend.app import get_session +from sshecret_backend.app import app, get_session from sshecret_backend.testing import create_test_token from sshecret_backend.models import AuditLog @@ -22,17 +23,30 @@ LOG.addHandler(handler) LOG.setLevel(logging.DEBUG) -TEST_FINGERPRINT = "00:11:22:33:44:55:66:77:88:99:aa:bb:cc:dd:ee:ff" +def make_test_key() -> str: + """Generate a test key.""" + randomlength = 540 + key = "ssh-rsa " + randompart = "".join( + random.choices(string.ascii_letters + string.digits, k=randomlength) + ) + comment = " invalid-test-key" + return key + randompart + comment def create_client( test_client: TestClient, headers: dict[str, str], name: str, - fingerprint: str = TEST_FINGERPRINT, + public_key: str | None = None, ) -> Response: """Create client.""" - data = {"name": name, "fingerprint": fingerprint} + if not public_key: + public_key = make_test_key() + data = { + "name": name, + "public_key": public_key, + } create_response = test_client.post("/api/v1/clients", headers=headers, json=data) return create_response @@ -95,10 +109,8 @@ def test_with_token(test_client: TestClient, token: str) -> None: def test_create_client(test_client: TestClient, headers: dict[str, str]) -> None: """Test creating a client.""" client_name = "test" - client_fingerprint = TEST_FINGERPRINT - create_response = create_client( - test_client, headers, client_name, client_fingerprint - ) + client_publickey = make_test_key() + create_response = create_client(test_client, headers, client_name, client_publickey) assert create_response.status_code == 200 response = test_client.get("/api/v1/clients/", headers=headers) assert response.status_code == 200 @@ -107,16 +119,35 @@ def test_create_client(test_client: TestClient, headers: dict[str, str]) -> None client = clients[0] assert isinstance(client, dict) assert client.get("name") == client_name - assert client.get("fingerprint") == client_fingerprint assert client.get("created_at") is not None +def test_delete_client(test_client: TestClient, headers: dict[str, str]) -> None: + """Test creating a client.""" + client_name = "test" + create_response = create_client( + test_client, + headers, + client_name, + ) + assert create_response.status_code == 200 + + resp = test_client.delete("/api/v1/clients/test", headers=headers) + assert resp.status_code == 200 + + resp = test_client.get("/api/v1/clients/test", headers=headers) + assert resp.status_code == 404 + + def test_add_secret(test_client: TestClient, headers: dict[str, str]) -> None: """Test adding a secret to a client.""" client_name = "test" - client_fingerprint = "00:11:22:33:44:55:66:77:88:99:aa:bb:cc:dd:ee:ff" + client_publickey = make_test_key() create_response = create_client( - test_client, headers, client_name, client_fingerprint + test_client, + headers, + client_name, + client_publickey, ) assert create_response.status_code == 200 secret_name = "mysecret" @@ -136,6 +167,17 @@ def test_add_secret(test_client: TestClient, headers: dict[str, str]) -> None: assert secret_body["secret"] == data["secret"] +def test_delete_secret(test_client: TestClient, headers: dict[str, str]) -> None: + """Test deleting a secret.""" + test_add_secret(test_client, headers) + resp = test_client.delete("/api/v1/clients/test/secrets/mysecret", headers=headers) + assert resp.status_code == 200 + get_response = test_client.get( + "/api/v1/clients/test/secrets/mysecret", headers=headers + ) + assert get_response.status_code == 404 + + def test_put_add_secret(test_client: TestClient, headers: dict[str, str]) -> None: """Test adding secret via PUT.""" # Use the test_create_client function to create a client. @@ -178,7 +220,8 @@ def test_put_update_secret(test_client: TestClient, headers: dict[str, str]) -> def test_audit_logging(test_client: TestClient, headers: dict[str, str]) -> None: """Test audit logging.""" - create_client_resp = create_client(test_client, headers, "test") + public_key = make_test_key() + create_client_resp = create_client(test_client, headers, "test", public_key) assert create_client_resp.status_code == 200 secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"} for name, secret in secrets.items(): @@ -251,7 +294,8 @@ def test_audit_log_filtering( def test_secret_invalidation(test_client: TestClient, headers: dict[str, str]) -> None: """Test secret invalidation.""" - create_client_resp = create_client(test_client, headers, "test") + initial_key = make_test_key() + create_client_resp = create_client(test_client, headers, "test", initial_key) assert create_client_resp.status_code == 200 secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"} for name, secret in secrets.items(): @@ -262,12 +306,13 @@ def test_secret_invalidation(test_client: TestClient, headers: dict[str, str]) - ) assert add_resp.status_code == 200 - # Update the fingerprint. This should cause all secrets to be invalidated + # Update the public-key. This should cause all secrets to be invalidated # and no longer associated with a client. + new_key = make_test_key() update_resp = test_client.post( - "/api/v1/clients/test/update_fingerprint", + "/api/v1/clients/test/public-key", headers=headers, - json={"fingerprint": "foobar"}, + json={"public_key": new_key}, ) assert update_resp.status_code == 200 @@ -277,3 +322,117 @@ def test_secret_invalidation(test_client: TestClient, headers: dict[str, str]) - client = get_resp.json() secrets = client.get("secrets") assert bool(secrets) is False + + +def test_client_default_policies( + test_client: TestClient, headers: dict[str, str] +) -> None: + """Test client policies.""" + public_key = make_test_key() + resp = create_client(test_client, headers, "test", public_key) + assert resp.status_code == 200 + # Fetch policies, should return * + resp = test_client.get("/api/v1/clients/test/policies/", headers=headers) + assert resp.status_code == 200 + + policies = resp.json() + + assert policies["sources"] == ["0.0.0.0/0", "::/0"] + + +def test_client_policy_update_one( + test_client: TestClient, headers: dict[str, str] +) -> None: + """Update client policy with single policy.""" + public_key = make_test_key() + resp = create_client(test_client, headers, "test", public_key) + assert resp.status_code == 200 + + policy = ["192.0.2.1"] + resp = test_client.put( + "/api/v1/clients/test/policies/", headers=headers, json={"sources": policy} + ) + assert resp.status_code == 200 + + resp = test_client.get("/api/v1/clients/test/policies/", headers=headers) + assert resp.status_code == 200 + + policies = resp.json() + + assert policies["sources"] == policy + + +def test_client_policy_update_advanced( + test_client: TestClient, headers: dict[str, str] +) -> None: + """Test other policy update scenarios.""" + public_key = make_test_key() + resp = create_client(test_client, headers, "test", public_key) + assert resp.status_code == 200 + + policy = ["192.0.2.1", "198.18.0.0/24"] + + resp = test_client.put( + "/api/v1/clients/test/policies/", headers=headers, json={"sources": policy} + ) + assert resp.status_code == 200 + + resp = test_client.get("/api/v1/clients/test/policies/", headers=headers) + assert resp.status_code == 200 + + policies = resp.json() + + assert "192.0.2.1" in policies["sources"] + assert "198.18.0.0/24" in policies["sources"] + + # Try to set it to something incorrect + + policy = ["obviosly_wrong"] + + resp = test_client.put( + "/api/v1/clients/test/policies/", headers=headers, json={"sources": policy} + ) + assert resp.status_code == 422 + # Check that the old value is still there + + resp = test_client.get("/api/v1/clients/test/policies/", headers=headers) + assert resp.status_code == 200 + + policies = resp.json() + + assert "192.0.2.1" in policies["sources"] + assert "198.18.0.0/24" in policies["sources"] + + # Clear the policies + # + + +def test_client_policy_update_unset( + test_client: TestClient, headers: dict[str, str] +) -> None: + """Test clearing the client policy.""" + public_key = make_test_key() + resp = create_client(test_client, headers, "test", public_key) + assert resp.status_code == 200 + policy = ["192.0.2.1", "198.18.0.0/24"] + + resp = test_client.put( + "/api/v1/clients/test/policies/", headers=headers, json={"sources": policy} + ) + assert resp.status_code == 200 + + policies = resp.json() + + assert "192.0.2.1" in policies["sources"] + assert "198.18.0.0/24" in policies["sources"] + + # Now we clear the policies + + resp = test_client.put( + "/api/v1/clients/test/policies/", headers=headers, json={"sources": []} + ) + assert resp.status_code == 200 + + policies = resp.json() + + assert policies["sources"] == ["0.0.0.0/0", "::/0"]