Complete backend

This commit is contained in:
2025-04-18 16:39:05 +02:00
parent 83551ffb4a
commit ec90fb7680
11 changed files with 561 additions and 121 deletions

View File

@ -14,6 +14,9 @@ dependencies = [
"sqlmodel>=0.0.24", "sqlmodel>=0.0.24",
] ]
[project.scripts]
sshecret-backend = "sshecret_backend.cli:cli"
[build-system] [build-system]
requires = ["hatchling"] requires = ["hatchling"]
build-backend = "hatchling.build" build-backend = "hatchling.build"

View File

@ -1 +1,5 @@
"""Sshecret backend.""" """Sshecret backend."""
from .app import app as app
#from .router import app as app
__all__ = ["app"]

View File

@ -1,26 +1,52 @@
"""FastAPI api.""" """FastAPI api.
TODO: We may want to allow a consumer to generate audit log entries manually.
"""
import logging import logging
from collections.abc import Sequence
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Annotated from typing import Annotated
from collections.abc import Sequence
import bcrypt import bcrypt
from fastapi import APIRouter, Depends, FastAPI, Header, HTTPException, Query, Request from fastapi import (
APIRouter,
Depends,
FastAPI,
Header,
HTTPException,
Query,
Request,
status,
)
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from sqlmodel import Session, select from sqlmodel import Session, select
from . import audit from . import audit
from .db import get_engine from .db import get_engine
from .models import APIClient, AuditLog, Client, ClientSecret, init_db from .models import (
APIClient,
AuditLog,
Client,
ClientAccessPolicy,
ClientSecret,
init_db,
)
from .settings import get_settings from .settings import get_settings
from .view_models import ( from .view_models import (
BodyValue, BodyValue,
ClientCreate, ClientCreate,
ClientListResponse,
ClientSecretPublic, ClientSecretPublic,
ClientSecretResponse, ClientSecretResponse,
ClientUpdate, ClientUpdate,
ClientView, ClientView,
ClientPolicyView,
ClientPolicyUpdate,
) )
settings = get_settings() settings = get_settings()
@ -104,12 +130,12 @@ backend_api = APIRouter(
@backend_api.get("/clients/") @backend_api.get("/clients/")
async def get_clients( async def get_clients(
session: Annotated[Session, Depends(get_session)] session: Annotated[Session, Depends(get_session)]
) -> list[ClientListResponse]: ) -> list[ClientView]:
"""Get clients.""" """Get clients."""
statement = select(Client) statement = select(Client)
results = session.exec(statement) results = session.exec(statement)
clients = list(results) clients = list(results)
return ClientListResponse.from_clients(clients) return ClientView.from_client_list(clients)
@backend_api.get("/clients/{name}") @backend_api.get("/clients/{name}")
@ -128,14 +154,93 @@ async def get_client(
return ClientView.from_client(client) return ClientView.from_client(client)
@backend_api.post("/clients/{name}/update_fingerprint") @backend_api.delete("/clients/{name}")
async def update_client_fingerprint( async def delete_client(
request: Request, name: str, session: Annotated[Session, Depends(get_session)]
) -> None:
"""Delete a client."""
statement = select(Client).where(Client.name == name)
results = session.exec(statement)
client = results.first()
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
session.delete(client)
session.commit()
audit.audit_delete_client(session, request, client)
@backend_api.get("/clients/{name}/policies/")
async def get_client_policies(
name: str, session: Annotated[Session, Depends(get_session)]
) -> ClientPolicyView:
"""Get client policies."""
statement = select(Client).where(Client.name == name)
results = session.exec(statement)
client = results.first()
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
return ClientPolicyView.from_client(client)
@backend_api.put("/clients/{name}/policies/")
async def update_client_policies(
request: Request,
name: str,
policy_update: ClientPolicyUpdate,
session: Annotated[Session, Depends(get_session)],
) -> ClientPolicyView:
"""Update client policies.
This is also how you delete policies.
"""
statement = select(Client).where(Client.name == name)
results = session.exec(statement)
client = results.first()
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
# Remove old policies.
policies = session.exec(
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
).all()
deleted_policies: list[ClientAccessPolicy] = []
added_policies: list[ClientAccessPolicy] = []
for policy in policies:
session.delete(policy)
deleted_policies.append(policy)
for source in policy_update.sources:
LOG.debug("Source %r", source)
policy = ClientAccessPolicy(source=str(source), client_id=client.id)
session.add(policy)
added_policies.append(policy)
session.commit()
session.refresh(client)
for policy in deleted_policies:
audit.audit_remove_policy(session, request, client, policy)
for policy in added_policies:
audit.audit_update_policy(session, request, client, policy)
return ClientPolicyView.from_client(client)
@backend_api.post("/clients/{name}/public-key")
async def update_client_public_key(
request: Request, request: Request,
name: str, name: str,
client_update: ClientUpdate, client_update: ClientUpdate,
session: Annotated[Session, Depends(get_session)], session: Annotated[Session, Depends(get_session)],
) -> ClientView: ) -> ClientView:
"""Update the client fingerprint. """Change the public key of a client.
This invalidates all secrets. This invalidates all secrets.
""" """
@ -146,7 +251,7 @@ async def update_client_fingerprint(
raise HTTPException( raise HTTPException(
status_code=404, detail="Cannot find a client with the given name." status_code=404, detail="Cannot find a client with the given name."
) )
client.fingerprint = client_update.fingerprint client.public_key = client_update.public_key
for secret in session.exec( for secret in session.exec(
select(ClientSecret).where(ClientSecret.client_id == client.id) select(ClientSecret).where(ClientSecret.client_id == client.id)
).all(): ).all():
@ -170,7 +275,11 @@ async def create_client(
session: Annotated[Session, Depends(get_session)], session: Annotated[Session, Depends(get_session)],
) -> ClientView: ) -> ClientView:
"""Create client.""" """Create client."""
db_client = Client.model_validate(client) existing = await get_client_by_name(session, client.name)
if existing:
raise HTTPException(400, detail="Error: Already a client with that name.")
db_client = client.to_client()
session.add(db_client) session.add(db_client)
session.commit() session.commit()
session.refresh(db_client) session.refresh(db_client)
@ -270,6 +379,30 @@ async def request_client_secret(
audit.audit_access_secret(session, request, client, secret) audit.audit_access_secret(session, request, client, secret)
return response_model return response_model
@backend_api.delete("/clients/{name}/secrets/{secret_name}")
async def delete_client_secret(
request: Request,
name: str,
secret_name: str,
session: Annotated[Session, Depends(get_session)],
) -> None:
"""Delete a secret."""
client = await get_client_by_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
secret = await lookup_client_secret(session, client, secret_name)
if not secret:
raise HTTPException(
status_code=404, detail="Cannot find a secret with the given name."
)
session.delete(secret)
session.commit()
audit.audit_delete_secret(session, request, client, secret)
@backend_api.get("/audit/", response_model=list[AuditLog]) @backend_api.get("/audit/", response_model=list[AuditLog])
async def get_audit_logs( async def get_audit_logs(
@ -287,3 +420,17 @@ async def get_audit_logs(
results = session.exec(statement).all() results = session.exec(statement).all()
return results return results
app = FastAPI()
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
)
app.include_router(backend_api)

View File

@ -3,7 +3,7 @@
from collections.abc import Sequence from collections.abc import Sequence
from fastapi import Request from fastapi import Request
from sqlmodel import Session, select from sqlmodel import Session, select
from .models import AuditLog, Client, ClientSecret from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy
def _get_origin(request: Request) -> str | None: def _get_origin(request: Request) -> str | None:
@ -39,6 +39,19 @@ def audit_create_client(
_write_audit_log(session, request, entry, commit) _write_audit_log(session, request, entry, commit)
def audit_delete_client(
session: Session, request: Request, client: Client, commit: bool = True
) -> None:
"""Log the creation of a client."""
entry = AuditLog(
operation="CREATE",
client_id=client.id,
client_name=client.name,
message="Client deleted",
)
_write_audit_log(session, request, entry, commit)
def audit_create_secret( def audit_create_secret(
session: Session, session: Session,
request: Request, request: Request,
@ -58,6 +71,44 @@ def audit_create_secret(
_write_audit_log(session, request, entry, commit) _write_audit_log(session, request, entry, commit)
def audit_remove_policy(
session: Session,
request: Request,
client: Client,
policy: ClientAccessPolicy,
commit: bool = True,
) -> None:
"""Audit removal of policy."""
entry = AuditLog(
operation="DELETE",
object="ClientAccessPolicy",
object_id=str(policy.id),
client_id=client.id,
client_name=client.name,
message="Deleted client policy",
)
_write_audit_log(session, request, entry, commit)
def audit_update_policy(
session: Session,
request: Request,
client: Client,
policy: ClientAccessPolicy,
commit: bool = True,
) -> None:
"""Audit update of policy."""
entry = AuditLog(
operation="CREATE",
object="ClientAccessPolicy",
object_id=str(policy.id),
client_id=client.id,
client_name=client.name,
message="Updated client policy",
)
_write_audit_log(session, request, entry, commit)
def audit_update_secret( def audit_update_secret(
session: Session, session: Session,
request: Request, request: Request,
@ -89,7 +140,26 @@ def audit_invalidate_secrets(
object="ClientSecret", object="ClientSecret",
client_name=client.name, client_name=client.name,
client_id=client.id, client_id=client.id,
message="Client fingerprint updated. All secrets invalidated.", message="Client public-key changed. All secrets invalidated.",
)
_write_audit_log(session, request, entry, commit)
def audit_delete_secret(
session: Session,
request: Request,
client: Client,
secret: ClientSecret,
commit: bool = True,
) -> None:
"""Audit Delete client secrets."""
entry = AuditLog(
operation="DELETE",
object="ClientSecret",
object_id=str(secret.id),
client_name=client.name,
client_id=client.id,
message="Deleted secret.",
) )
_write_audit_log(session, request, entry, commit) _write_audit_log(session, request, entry, commit)

View File

@ -1,24 +1,31 @@
"""CLI and main entry point.""" """CLI and main entry point."""
import os
from pathlib import Path from pathlib import Path
from dotenv import load_dotenv
import click import click
from pydantic import BaseModel, FilePath
from .db import generate_api_token
DEFAULT_LISTEN = "127.0.0.1"
DEFAULT_PORT = 8022
WORKDIR = Path(os.getcwd())
class BackendSettings(BaseModel): load_dotenv()
"""Backend Settings."""
db_file: FilePath
regenerate_tokens: bool = False
@click.group() @click.group()
@click.option("--db-file", envvar="sshecret_db_file", type=click.Path(path_type=Path)) @click.option("--database", help="Path to the sqlite database file.")
@click.option("--regenerate-tokens", is_flag=True, default=False) def cli(database: str) -> None:
@click.pass_context """CLI group."""
def cli(ctx: click.Context, db_file: Path, regenerate_tokens: bool) -> None: if database:
"""Sshecret database handler.""" # Hopefully it's enough to set the environment variable as so.
if not isinstance(ctx.obj, BackendSettings): os.environ["SSHECRET_DB_FILE"] = str(Path(database).absolute())
ctx.obj = BackendSettings(db_file=db_file, regenerate_tokens=regenerate_tokens)
@cli.command("generate-token")
def cli_generate_token() -> None:
"""Generate a token."""
token = generate_api_token()
click.echo("Generated api token:")
click.echo(token)

View File

@ -1,18 +1,20 @@
#!/usr/bin/env python3 """Database related functions."""
import logging
import secrets import secrets
from pathlib import Path from pathlib import Path
from sqlalchemy import Engine from sqlalchemy import Engine
from sqlmodel import Session, create_engine, text from sqlmodel import Session, create_engine, text
import bcrypt import bcrypt
from dotenv import load_dotenv
from sqlalchemy.engine import URL from sqlalchemy.engine import URL
from .models import APIClient, init_db from .models import APIClient, init_db
from .settings import get_settings
load_dotenv()
LOG = logging.getLogger(__name__)
def get_engine(filename: Path, echo: bool = False) -> Engine: def get_engine(filename: Path, echo: bool = False) -> Engine:
@ -25,7 +27,7 @@ def get_engine(filename: Path, echo: bool = False) -> Engine:
return engine return engine
def create_db_and_tables(filename: Path, echo: bool = True) -> bool: def create_db_and_tables(filename: Path, echo: bool = False) -> bool:
"""Create database and tables. """Create database and tables.
Returns True if the database was created. Returns True if the database was created.
@ -52,3 +54,14 @@ def create_api_token(session: Session, read_write: bool) -> str:
session.commit() session.commit()
return token return token
def generate_api_token() -> str:
"""Generate API token."""
settings = get_settings()
engine = get_engine(settings.db_file)
init_db(engine)
with Session(engine) as session:
token = create_api_token(session, True)
return token

View File

@ -1,8 +1,15 @@
#!/usr/bin/env python3 """Database models.
TODO:
We might want to pass on audit information from the SSH server.
This might require some changes to these schemas.
"""
import uuid import uuid
from datetime import datetime from datetime import datetime
from sqlalchemy import Engine, Column, DateTime, func import sqlalchemy as sa
from sqlmodel import Field, Relationship, SQLModel from sqlmodel import Field, Relationship, SQLModel
@ -11,22 +18,48 @@ class Client(SQLModel, table=True):
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
name: str = Field(unique=True) name: str = Field(unique=True)
fingerprint: str public_key: str
created_at: datetime = Field(
created_at: datetime | None = Field(
default=None, default=None,
sa_column=Column( sa_type=sa.DateTime(timezone=True),
DateTime(timezone=True), server_default=func.now(), nullable=True sa_column_kwargs={"server_default": sa.func.now()},
), nullable=False,
) )
updated_at: datetime | None = Field( updated_at: datetime | None = Field(
default=None, default=None,
sa_column=Column(DateTime(timezone=True), onupdate=func.now(), nullable=True), sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
) )
secrets: list["ClientSecret"] = Relationship( secrets: list["ClientSecret"] = Relationship(
back_populates="client", passive_deletes="all" back_populates="client", passive_deletes="all"
) )
policies: list["ClientAccessPolicy"] = Relationship(back_populates="client")
class ClientAccessPolicy(SQLModel, table=True):
"""Client access policies."""
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
source: str
client_id: uuid.UUID | None = Field(foreign_key="client.id", ondelete="CASCADE")
client: Client | None = Relationship(back_populates="policies")
created_at: datetime | None = Field(
default=None,
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"server_default": sa.func.now()},
nullable=False,
)
updated_at: datetime | None = Field(
default=None,
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
)
class ClientSecret(SQLModel, table=True): class ClientSecret(SQLModel, table=True):
"""A client secret.""" """A client secret."""
@ -37,15 +70,17 @@ class ClientSecret(SQLModel, table=True):
client: Client | None = Relationship(back_populates="secrets") client: Client | None = Relationship(back_populates="secrets")
secret: str secret: str
invalidated: bool = Field(default=False) invalidated: bool = Field(default=False)
created_at: datetime = Field( created_at: datetime | None = Field(
default=None, default=None,
sa_column=Column( sa_type=sa.DateTime(timezone=True),
DateTime(timezone=True), server_default=func.now(), nullable=True sa_column_kwargs={"server_default": sa.func.now()},
), nullable=False,
) )
updated_at: datetime | None = Field( updated_at: datetime | None = Field(
default=None, default=None,
sa_column=Column(DateTime(timezone=True), onupdate=func.now(), nullable=True), sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
) )
@ -64,28 +99,27 @@ class AuditLog(SQLModel, table=True):
client_name: str | None = None client_name: str | None = None
message: str message: str
origin: str | None = None origin: str | None = None
timestamp: datetime | None = Field( timestamp: datetime | None = Field(
default=None, default=None,
sa_column=Column( sa_type=sa.DateTime(timezone=True),
DateTime(timezone=True), server_default=func.now(), nullable=True sa_column_kwargs={"server_default": sa.func.now()},
), nullable=False,
) )
class APIClient(SQLModel, table=True): class APIClient(SQLModel, table=True):
"""Stores API Keys.""" """Stores API Keys."""
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
token: str token: str
read_write: bool read_write: bool
created_at: datetime = Field( created_at: datetime | None = Field(
default=None, default=None,
sa_column=Column( sa_type=sa.DateTime(timezone=True),
DateTime(timezone=True), server_default=func.now(), nullable=True sa_column_kwargs={"server_default": sa.func.now()},
), nullable=False,
) )
def init_db(engine: sa.Engine) -> None:
def init_db(engine: Engine) -> None:
"""Create database.""" """Create database."""
SQLModel.metadata.create_all(engine) SQLModel.metadata.create_all(engine)

View File

@ -1,9 +0,0 @@
"""API Router."""
from fastapi import FastAPI
from .app import backend_api
app = FastAPI()
app.include_router(backend_api)

View File

@ -1,26 +1,25 @@
"""Settings management.""" """Settings management."""
import os from typing import override
from pathlib import Path from pathlib import Path
from pydantic import BaseModel from pydantic import BaseModel, Field
from dotenv import load_dotenv from pydantic_settings import (
BaseSettings,
SettingsConfigDict,
)
DEFAULT_DATABASE = "sshecret.db" DEFAULT_DATABASE = "sshecret.db"
load_dotenv() class BackendSettings(BaseSettings):
"""Backend settings."""
model_config = SettingsConfigDict(env_file=".backend.env", env_prefix="sshecret_")
class BackendSettings(BaseModel): db_file: Path = Field(default=Path(DEFAULT_DATABASE).absolute())
"""Backend Settings."""
db_file: Path
regenerate_tokens: bool = False
def get_settings() -> BackendSettings: def get_settings() -> BackendSettings:
"""Get settings.""" """Get settings."""
db_filename = os.getenv("SSHECRET_DATABASE") or DEFAULT_DATABASE return BackendSettings()
db_file = Path(db_filename).absolute()
return BackendSettings(db_file=db_file)

View File

@ -1,57 +1,48 @@
"""Models for API views.""" """Models for API views."""
import ipaddress
import uuid import uuid
from datetime import datetime from datetime import datetime
from typing import Self, override from typing import Annotated, Any, Self, override
from sqlmodel import Field, SQLModel from sqlmodel import Field, SQLModel
from pydantic import IPvAnyAddress, IPvAnyNetwork
from . import models from . import models
class ClientListResponse(SQLModel):
"""Model list responses."""
class ClientView(SQLModel):
"""View for a single client."""
id: uuid.UUID id: uuid.UUID
name: str name: str
fingerprint: str public_key: str
policies: list[str] = ["0.0.0.0/0", "::/0"]
secrets: list[str] = Field(default_factory=list)
created_at: datetime created_at: datetime
updated_at: datetime | None = None updated_at: datetime | None = None
@classmethod @classmethod
def from_clients(cls, clients: list[models.Client]) -> list[Self]: def from_client_list(cls, clients: list[models.Client]) -> list[Self]:
"""Generate a list of responses from a list of clients.""" """Generate a list of responses from a list of clients."""
responses: list[Self] = [] responses: list[Self] = [cls.from_client(client) for client in clients]
for client in clients:
responses.append(
cls(
id=client.id,
name=client.name,
fingerprint=client.fingerprint,
created_at=client.created_at,
updated_at=client.updated_at or None,
)
)
return responses return responses
class ClientView(ClientListResponse):
"""View for a single client."""
secrets: list[str] = Field(default_factory=list)
@classmethod @classmethod
def from_client(cls, client: models.Client) -> Self: def from_client(cls, client: models.Client) -> Self:
"""Instantiate from a client.""" """Instantiate from a client."""
view = cls( view = cls(
id=client.id, id=client.id,
name=client.name, name=client.name,
fingerprint=client.fingerprint, public_key=client.public_key,
created_at=client.created_at, created_at=client.created_at,
updated_at=client.updated_at or None, updated_at=client.updated_at or None,
) )
if client.secrets: if client.secrets:
view.secrets = [secret.name for secret in client.secrets] view.secrets = [secret.name for secret in client.secrets]
if client.policies:
view.policies = [policy.source for policy in client.policies]
return view return view
@ -59,17 +50,20 @@ class ClientCreate(SQLModel):
"""Model to create a client.""" """Model to create a client."""
name: str name: str
fingerprint: str public_key: str
def to_client(self) -> models.Client: def to_client(self) -> models.Client:
"""Instantiate a client.""" """Instantiate a client."""
return models.Client(name=self.name, fingerprint=self.fingerprint) public_key = self.public_key
return models.Client(
name=self.name, public_key=public_key
)
class ClientUpdate(SQLModel): class ClientUpdate(SQLModel):
"""Model to update the client fingerprint.""" """Model to update the client public key."""
fingerprint: str public_key: str
class BodyValue(SQLModel): class BodyValue(SQLModel):
@ -110,3 +104,22 @@ class ClientSecretResponse(ClientSecretPublic):
created_at=client_secret.created_at, created_at=client_secret.created_at,
updated_at=client_secret.updated_at, updated_at=client_secret.updated_at,
) )
class ClientPolicyView(SQLModel):
"""Update object for client policy."""
sources: list[str] = ["0.0.0.0/0", "::/0"]
@classmethod
def from_client(cls, client: models.Client) -> Self:
"""Create from client."""
if not client.policies:
return cls()
return cls(sources=[policy.source for policy in client.policies])
class ClientPolicyUpdate(SQLModel):
"""Model for updating policies."""
sources: list[IPvAnyAddress | IPvAnyNetwork]

View File

@ -1,6 +1,8 @@
"""Tests of the backend api using pytest.""" """Tests of the backend api using pytest."""
import logging import logging
import random
import string
from httpx import Response from httpx import Response
import pytest import pytest
@ -8,8 +10,7 @@ from fastapi.testclient import TestClient
from sqlmodel import Session, SQLModel, create_engine from sqlmodel import Session, SQLModel, create_engine
from sqlmodel.pool import StaticPool from sqlmodel.pool import StaticPool
from sshecret_backend.router import app from sshecret_backend.app import app, get_session
from sshecret_backend.app import get_session
from sshecret_backend.testing import create_test_token from sshecret_backend.testing import create_test_token
from sshecret_backend.models import AuditLog from sshecret_backend.models import AuditLog
@ -22,17 +23,30 @@ LOG.addHandler(handler)
LOG.setLevel(logging.DEBUG) LOG.setLevel(logging.DEBUG)
TEST_FINGERPRINT = "00:11:22:33:44:55:66:77:88:99:aa:bb:cc:dd:ee:ff" def make_test_key() -> str:
"""Generate a test key."""
randomlength = 540
key = "ssh-rsa "
randompart = "".join(
random.choices(string.ascii_letters + string.digits, k=randomlength)
)
comment = " invalid-test-key"
return key + randompart + comment
def create_client( def create_client(
test_client: TestClient, test_client: TestClient,
headers: dict[str, str], headers: dict[str, str],
name: str, name: str,
fingerprint: str = TEST_FINGERPRINT, public_key: str | None = None,
) -> Response: ) -> Response:
"""Create client.""" """Create client."""
data = {"name": name, "fingerprint": fingerprint} if not public_key:
public_key = make_test_key()
data = {
"name": name,
"public_key": public_key,
}
create_response = test_client.post("/api/v1/clients", headers=headers, json=data) create_response = test_client.post("/api/v1/clients", headers=headers, json=data)
return create_response return create_response
@ -95,10 +109,8 @@ def test_with_token(test_client: TestClient, token: str) -> None:
def test_create_client(test_client: TestClient, headers: dict[str, str]) -> None: def test_create_client(test_client: TestClient, headers: dict[str, str]) -> None:
"""Test creating a client.""" """Test creating a client."""
client_name = "test" client_name = "test"
client_fingerprint = TEST_FINGERPRINT client_publickey = make_test_key()
create_response = create_client( create_response = create_client(test_client, headers, client_name, client_publickey)
test_client, headers, client_name, client_fingerprint
)
assert create_response.status_code == 200 assert create_response.status_code == 200
response = test_client.get("/api/v1/clients/", headers=headers) response = test_client.get("/api/v1/clients/", headers=headers)
assert response.status_code == 200 assert response.status_code == 200
@ -107,16 +119,35 @@ def test_create_client(test_client: TestClient, headers: dict[str, str]) -> None
client = clients[0] client = clients[0]
assert isinstance(client, dict) assert isinstance(client, dict)
assert client.get("name") == client_name assert client.get("name") == client_name
assert client.get("fingerprint") == client_fingerprint
assert client.get("created_at") is not None assert client.get("created_at") is not None
def test_delete_client(test_client: TestClient, headers: dict[str, str]) -> None:
"""Test creating a client."""
client_name = "test"
create_response = create_client(
test_client,
headers,
client_name,
)
assert create_response.status_code == 200
resp = test_client.delete("/api/v1/clients/test", headers=headers)
assert resp.status_code == 200
resp = test_client.get("/api/v1/clients/test", headers=headers)
assert resp.status_code == 404
def test_add_secret(test_client: TestClient, headers: dict[str, str]) -> None: def test_add_secret(test_client: TestClient, headers: dict[str, str]) -> None:
"""Test adding a secret to a client.""" """Test adding a secret to a client."""
client_name = "test" client_name = "test"
client_fingerprint = "00:11:22:33:44:55:66:77:88:99:aa:bb:cc:dd:ee:ff" client_publickey = make_test_key()
create_response = create_client( create_response = create_client(
test_client, headers, client_name, client_fingerprint test_client,
headers,
client_name,
client_publickey,
) )
assert create_response.status_code == 200 assert create_response.status_code == 200
secret_name = "mysecret" secret_name = "mysecret"
@ -136,6 +167,17 @@ def test_add_secret(test_client: TestClient, headers: dict[str, str]) -> None:
assert secret_body["secret"] == data["secret"] assert secret_body["secret"] == data["secret"]
def test_delete_secret(test_client: TestClient, headers: dict[str, str]) -> None:
"""Test deleting a secret."""
test_add_secret(test_client, headers)
resp = test_client.delete("/api/v1/clients/test/secrets/mysecret", headers=headers)
assert resp.status_code == 200
get_response = test_client.get(
"/api/v1/clients/test/secrets/mysecret", headers=headers
)
assert get_response.status_code == 404
def test_put_add_secret(test_client: TestClient, headers: dict[str, str]) -> None: def test_put_add_secret(test_client: TestClient, headers: dict[str, str]) -> None:
"""Test adding secret via PUT.""" """Test adding secret via PUT."""
# Use the test_create_client function to create a client. # Use the test_create_client function to create a client.
@ -178,7 +220,8 @@ def test_put_update_secret(test_client: TestClient, headers: dict[str, str]) ->
def test_audit_logging(test_client: TestClient, headers: dict[str, str]) -> None: def test_audit_logging(test_client: TestClient, headers: dict[str, str]) -> None:
"""Test audit logging.""" """Test audit logging."""
create_client_resp = create_client(test_client, headers, "test") public_key = make_test_key()
create_client_resp = create_client(test_client, headers, "test", public_key)
assert create_client_resp.status_code == 200 assert create_client_resp.status_code == 200
secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"} secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"}
for name, secret in secrets.items(): for name, secret in secrets.items():
@ -251,7 +294,8 @@ def test_audit_log_filtering(
def test_secret_invalidation(test_client: TestClient, headers: dict[str, str]) -> None: def test_secret_invalidation(test_client: TestClient, headers: dict[str, str]) -> None:
"""Test secret invalidation.""" """Test secret invalidation."""
create_client_resp = create_client(test_client, headers, "test") initial_key = make_test_key()
create_client_resp = create_client(test_client, headers, "test", initial_key)
assert create_client_resp.status_code == 200 assert create_client_resp.status_code == 200
secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"} secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"}
for name, secret in secrets.items(): for name, secret in secrets.items():
@ -262,12 +306,13 @@ def test_secret_invalidation(test_client: TestClient, headers: dict[str, str]) -
) )
assert add_resp.status_code == 200 assert add_resp.status_code == 200
# Update the fingerprint. This should cause all secrets to be invalidated # Update the public-key. This should cause all secrets to be invalidated
# and no longer associated with a client. # and no longer associated with a client.
new_key = make_test_key()
update_resp = test_client.post( update_resp = test_client.post(
"/api/v1/clients/test/update_fingerprint", "/api/v1/clients/test/public-key",
headers=headers, headers=headers,
json={"fingerprint": "foobar"}, json={"public_key": new_key},
) )
assert update_resp.status_code == 200 assert update_resp.status_code == 200
@ -277,3 +322,117 @@ def test_secret_invalidation(test_client: TestClient, headers: dict[str, str]) -
client = get_resp.json() client = get_resp.json()
secrets = client.get("secrets") secrets = client.get("secrets")
assert bool(secrets) is False assert bool(secrets) is False
def test_client_default_policies(
test_client: TestClient, headers: dict[str, str]
) -> None:
"""Test client policies."""
public_key = make_test_key()
resp = create_client(test_client, headers, "test", public_key)
assert resp.status_code == 200
# Fetch policies, should return *
resp = test_client.get("/api/v1/clients/test/policies/", headers=headers)
assert resp.status_code == 200
policies = resp.json()
assert policies["sources"] == ["0.0.0.0/0", "::/0"]
def test_client_policy_update_one(
test_client: TestClient, headers: dict[str, str]
) -> None:
"""Update client policy with single policy."""
public_key = make_test_key()
resp = create_client(test_client, headers, "test", public_key)
assert resp.status_code == 200
policy = ["192.0.2.1"]
resp = test_client.put(
"/api/v1/clients/test/policies/", headers=headers, json={"sources": policy}
)
assert resp.status_code == 200
resp = test_client.get("/api/v1/clients/test/policies/", headers=headers)
assert resp.status_code == 200
policies = resp.json()
assert policies["sources"] == policy
def test_client_policy_update_advanced(
test_client: TestClient, headers: dict[str, str]
) -> None:
"""Test other policy update scenarios."""
public_key = make_test_key()
resp = create_client(test_client, headers, "test", public_key)
assert resp.status_code == 200
policy = ["192.0.2.1", "198.18.0.0/24"]
resp = test_client.put(
"/api/v1/clients/test/policies/", headers=headers, json={"sources": policy}
)
assert resp.status_code == 200
resp = test_client.get("/api/v1/clients/test/policies/", headers=headers)
assert resp.status_code == 200
policies = resp.json()
assert "192.0.2.1" in policies["sources"]
assert "198.18.0.0/24" in policies["sources"]
# Try to set it to something incorrect
policy = ["obviosly_wrong"]
resp = test_client.put(
"/api/v1/clients/test/policies/", headers=headers, json={"sources": policy}
)
assert resp.status_code == 422
# Check that the old value is still there
resp = test_client.get("/api/v1/clients/test/policies/", headers=headers)
assert resp.status_code == 200
policies = resp.json()
assert "192.0.2.1" in policies["sources"]
assert "198.18.0.0/24" in policies["sources"]
# Clear the policies
#
def test_client_policy_update_unset(
test_client: TestClient, headers: dict[str, str]
) -> None:
"""Test clearing the client policy."""
public_key = make_test_key()
resp = create_client(test_client, headers, "test", public_key)
assert resp.status_code == 200
policy = ["192.0.2.1", "198.18.0.0/24"]
resp = test_client.put(
"/api/v1/clients/test/policies/", headers=headers, json={"sources": policy}
)
assert resp.status_code == 200
policies = resp.json()
assert "192.0.2.1" in policies["sources"]
assert "198.18.0.0/24" in policies["sources"]
# Now we clear the policies
resp = test_client.put(
"/api/v1/clients/test/policies/", headers=headers, json={"sources": []}
)
assert resp.status_code == 200
policies = resp.json()
assert policies["sources"] == ["0.0.0.0/0", "::/0"]