Complete backend

This commit is contained in:
2025-04-18 16:39:05 +02:00
parent 83551ffb4a
commit ec90fb7680
11 changed files with 561 additions and 121 deletions

View File

@ -1 +1,5 @@
"""Sshecret backend."""
from .app import app as app
#from .router import app as app
__all__ = ["app"]

View File

@ -1,26 +1,52 @@
"""FastAPI api."""
"""FastAPI api.
TODO: We may want to allow a consumer to generate audit log entries manually.
"""
import logging
from collections.abc import Sequence
from contextlib import asynccontextmanager
from typing import Annotated
from collections.abc import Sequence
import bcrypt
from fastapi import APIRouter, Depends, FastAPI, Header, HTTPException, Query, Request
from fastapi import (
APIRouter,
Depends,
FastAPI,
Header,
HTTPException,
Query,
Request,
status,
)
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from sqlmodel import Session, select
from . import audit
from .db import get_engine
from .models import APIClient, AuditLog, Client, ClientSecret, init_db
from .models import (
APIClient,
AuditLog,
Client,
ClientAccessPolicy,
ClientSecret,
init_db,
)
from .settings import get_settings
from .view_models import (
BodyValue,
ClientCreate,
ClientListResponse,
ClientSecretPublic,
ClientSecretResponse,
ClientUpdate,
ClientView,
ClientPolicyView,
ClientPolicyUpdate,
)
settings = get_settings()
@ -104,12 +130,12 @@ backend_api = APIRouter(
@backend_api.get("/clients/")
async def get_clients(
session: Annotated[Session, Depends(get_session)]
) -> list[ClientListResponse]:
) -> list[ClientView]:
"""Get clients."""
statement = select(Client)
results = session.exec(statement)
clients = list(results)
return ClientListResponse.from_clients(clients)
return ClientView.from_client_list(clients)
@backend_api.get("/clients/{name}")
@ -128,14 +154,93 @@ async def get_client(
return ClientView.from_client(client)
@backend_api.post("/clients/{name}/update_fingerprint")
async def update_client_fingerprint(
@backend_api.delete("/clients/{name}")
async def delete_client(
request: Request, name: str, session: Annotated[Session, Depends(get_session)]
) -> None:
"""Delete a client."""
statement = select(Client).where(Client.name == name)
results = session.exec(statement)
client = results.first()
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
session.delete(client)
session.commit()
audit.audit_delete_client(session, request, client)
@backend_api.get("/clients/{name}/policies/")
async def get_client_policies(
name: str, session: Annotated[Session, Depends(get_session)]
) -> ClientPolicyView:
"""Get client policies."""
statement = select(Client).where(Client.name == name)
results = session.exec(statement)
client = results.first()
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
return ClientPolicyView.from_client(client)
@backend_api.put("/clients/{name}/policies/")
async def update_client_policies(
request: Request,
name: str,
policy_update: ClientPolicyUpdate,
session: Annotated[Session, Depends(get_session)],
) -> ClientPolicyView:
"""Update client policies.
This is also how you delete policies.
"""
statement = select(Client).where(Client.name == name)
results = session.exec(statement)
client = results.first()
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
# Remove old policies.
policies = session.exec(
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
).all()
deleted_policies: list[ClientAccessPolicy] = []
added_policies: list[ClientAccessPolicy] = []
for policy in policies:
session.delete(policy)
deleted_policies.append(policy)
for source in policy_update.sources:
LOG.debug("Source %r", source)
policy = ClientAccessPolicy(source=str(source), client_id=client.id)
session.add(policy)
added_policies.append(policy)
session.commit()
session.refresh(client)
for policy in deleted_policies:
audit.audit_remove_policy(session, request, client, policy)
for policy in added_policies:
audit.audit_update_policy(session, request, client, policy)
return ClientPolicyView.from_client(client)
@backend_api.post("/clients/{name}/public-key")
async def update_client_public_key(
request: Request,
name: str,
client_update: ClientUpdate,
session: Annotated[Session, Depends(get_session)],
) -> ClientView:
"""Update the client fingerprint.
"""Change the public key of a client.
This invalidates all secrets.
"""
@ -146,7 +251,7 @@ async def update_client_fingerprint(
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
client.fingerprint = client_update.fingerprint
client.public_key = client_update.public_key
for secret in session.exec(
select(ClientSecret).where(ClientSecret.client_id == client.id)
).all():
@ -170,7 +275,11 @@ async def create_client(
session: Annotated[Session, Depends(get_session)],
) -> ClientView:
"""Create client."""
db_client = Client.model_validate(client)
existing = await get_client_by_name(session, client.name)
if existing:
raise HTTPException(400, detail="Error: Already a client with that name.")
db_client = client.to_client()
session.add(db_client)
session.commit()
session.refresh(db_client)
@ -270,6 +379,30 @@ async def request_client_secret(
audit.audit_access_secret(session, request, client, secret)
return response_model
@backend_api.delete("/clients/{name}/secrets/{secret_name}")
async def delete_client_secret(
request: Request,
name: str,
secret_name: str,
session: Annotated[Session, Depends(get_session)],
) -> None:
"""Delete a secret."""
client = await get_client_by_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
secret = await lookup_client_secret(session, client, secret_name)
if not secret:
raise HTTPException(
status_code=404, detail="Cannot find a secret with the given name."
)
session.delete(secret)
session.commit()
audit.audit_delete_secret(session, request, client, secret)
@backend_api.get("/audit/", response_model=list[AuditLog])
async def get_audit_logs(
@ -287,3 +420,17 @@ async def get_audit_logs(
results = session.exec(statement).all()
return results
app = FastAPI()
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
)
app.include_router(backend_api)

View File

@ -3,7 +3,7 @@
from collections.abc import Sequence
from fastapi import Request
from sqlmodel import Session, select
from .models import AuditLog, Client, ClientSecret
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy
def _get_origin(request: Request) -> str | None:
@ -39,6 +39,19 @@ def audit_create_client(
_write_audit_log(session, request, entry, commit)
def audit_delete_client(
session: Session, request: Request, client: Client, commit: bool = True
) -> None:
"""Log the creation of a client."""
entry = AuditLog(
operation="CREATE",
client_id=client.id,
client_name=client.name,
message="Client deleted",
)
_write_audit_log(session, request, entry, commit)
def audit_create_secret(
session: Session,
request: Request,
@ -58,6 +71,44 @@ def audit_create_secret(
_write_audit_log(session, request, entry, commit)
def audit_remove_policy(
session: Session,
request: Request,
client: Client,
policy: ClientAccessPolicy,
commit: bool = True,
) -> None:
"""Audit removal of policy."""
entry = AuditLog(
operation="DELETE",
object="ClientAccessPolicy",
object_id=str(policy.id),
client_id=client.id,
client_name=client.name,
message="Deleted client policy",
)
_write_audit_log(session, request, entry, commit)
def audit_update_policy(
session: Session,
request: Request,
client: Client,
policy: ClientAccessPolicy,
commit: bool = True,
) -> None:
"""Audit update of policy."""
entry = AuditLog(
operation="CREATE",
object="ClientAccessPolicy",
object_id=str(policy.id),
client_id=client.id,
client_name=client.name,
message="Updated client policy",
)
_write_audit_log(session, request, entry, commit)
def audit_update_secret(
session: Session,
request: Request,
@ -89,7 +140,26 @@ def audit_invalidate_secrets(
object="ClientSecret",
client_name=client.name,
client_id=client.id,
message="Client fingerprint updated. All secrets invalidated.",
message="Client public-key changed. All secrets invalidated.",
)
_write_audit_log(session, request, entry, commit)
def audit_delete_secret(
session: Session,
request: Request,
client: Client,
secret: ClientSecret,
commit: bool = True,
) -> None:
"""Audit Delete client secrets."""
entry = AuditLog(
operation="DELETE",
object="ClientSecret",
object_id=str(secret.id),
client_name=client.name,
client_id=client.id,
message="Deleted secret.",
)
_write_audit_log(session, request, entry, commit)

View File

@ -1,24 +1,31 @@
"""CLI and main entry point."""
import os
from pathlib import Path
from dotenv import load_dotenv
import click
from pydantic import BaseModel, FilePath
from .db import generate_api_token
DEFAULT_LISTEN = "127.0.0.1"
DEFAULT_PORT = 8022
WORKDIR = Path(os.getcwd())
class BackendSettings(BaseModel):
"""Backend Settings."""
db_file: FilePath
regenerate_tokens: bool = False
load_dotenv()
@click.group()
@click.option("--db-file", envvar="sshecret_db_file", type=click.Path(path_type=Path))
@click.option("--regenerate-tokens", is_flag=True, default=False)
@click.pass_context
def cli(ctx: click.Context, db_file: Path, regenerate_tokens: bool) -> None:
"""Sshecret database handler."""
if not isinstance(ctx.obj, BackendSettings):
ctx.obj = BackendSettings(db_file=db_file, regenerate_tokens=regenerate_tokens)
@click.option("--database", help="Path to the sqlite database file.")
def cli(database: str) -> None:
"""CLI group."""
if database:
# Hopefully it's enough to set the environment variable as so.
os.environ["SSHECRET_DB_FILE"] = str(Path(database).absolute())
@cli.command("generate-token")
def cli_generate_token() -> None:
"""Generate a token."""
token = generate_api_token()
click.echo("Generated api token:")
click.echo(token)

View File

@ -1,18 +1,20 @@
#!/usr/bin/env python3
"""Database related functions."""
import logging
import secrets
from pathlib import Path
from sqlalchemy import Engine
from sqlmodel import Session, create_engine, text
import bcrypt
from dotenv import load_dotenv
from sqlalchemy.engine import URL
from .models import APIClient, init_db
from .settings import get_settings
load_dotenv()
LOG = logging.getLogger(__name__)
def get_engine(filename: Path, echo: bool = False) -> Engine:
@ -25,7 +27,7 @@ def get_engine(filename: Path, echo: bool = False) -> Engine:
return engine
def create_db_and_tables(filename: Path, echo: bool = True) -> bool:
def create_db_and_tables(filename: Path, echo: bool = False) -> bool:
"""Create database and tables.
Returns True if the database was created.
@ -52,3 +54,14 @@ def create_api_token(session: Session, read_write: bool) -> str:
session.commit()
return token
def generate_api_token() -> str:
"""Generate API token."""
settings = get_settings()
engine = get_engine(settings.db_file)
init_db(engine)
with Session(engine) as session:
token = create_api_token(session, True)
return token

View File

@ -1,8 +1,15 @@
#!/usr/bin/env python3
"""Database models.
TODO:
We might want to pass on audit information from the SSH server.
This might require some changes to these schemas.
"""
import uuid
from datetime import datetime
from sqlalchemy import Engine, Column, DateTime, func
import sqlalchemy as sa
from sqlmodel import Field, Relationship, SQLModel
@ -11,22 +18,48 @@ class Client(SQLModel, table=True):
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
name: str = Field(unique=True)
fingerprint: str
created_at: datetime = Field(
public_key: str
created_at: datetime | None = Field(
default=None,
sa_column=Column(
DateTime(timezone=True), server_default=func.now(), nullable=True
),
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"server_default": sa.func.now()},
nullable=False,
)
updated_at: datetime | None = Field(
default=None,
sa_column=Column(DateTime(timezone=True), onupdate=func.now(), nullable=True),
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
)
secrets: list["ClientSecret"] = Relationship(
back_populates="client", passive_deletes="all"
)
policies: list["ClientAccessPolicy"] = Relationship(back_populates="client")
class ClientAccessPolicy(SQLModel, table=True):
"""Client access policies."""
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
source: str
client_id: uuid.UUID | None = Field(foreign_key="client.id", ondelete="CASCADE")
client: Client | None = Relationship(back_populates="policies")
created_at: datetime | None = Field(
default=None,
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"server_default": sa.func.now()},
nullable=False,
)
updated_at: datetime | None = Field(
default=None,
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
)
class ClientSecret(SQLModel, table=True):
"""A client secret."""
@ -37,15 +70,17 @@ class ClientSecret(SQLModel, table=True):
client: Client | None = Relationship(back_populates="secrets")
secret: str
invalidated: bool = Field(default=False)
created_at: datetime = Field(
created_at: datetime | None = Field(
default=None,
sa_column=Column(
DateTime(timezone=True), server_default=func.now(), nullable=True
),
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"server_default": sa.func.now()},
nullable=False,
)
updated_at: datetime | None = Field(
default=None,
sa_column=Column(DateTime(timezone=True), onupdate=func.now(), nullable=True),
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
)
@ -64,28 +99,27 @@ class AuditLog(SQLModel, table=True):
client_name: str | None = None
message: str
origin: str | None = None
timestamp: datetime | None = Field(
default=None,
sa_column=Column(
DateTime(timezone=True), server_default=func.now(), nullable=True
),
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"server_default": sa.func.now()},
nullable=False,
)
class APIClient(SQLModel, table=True):
"""Stores API Keys."""
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
token: str
read_write: bool
created_at: datetime = Field(
created_at: datetime | None = Field(
default=None,
sa_column=Column(
DateTime(timezone=True), server_default=func.now(), nullable=True
),
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"server_default": sa.func.now()},
nullable=False,
)
def init_db(engine: Engine) -> None:
def init_db(engine: sa.Engine) -> None:
"""Create database."""
SQLModel.metadata.create_all(engine)

View File

@ -1,9 +0,0 @@
"""API Router."""
from fastapi import FastAPI
from .app import backend_api
app = FastAPI()
app.include_router(backend_api)

View File

@ -1,26 +1,25 @@
"""Settings management."""
import os
from typing import override
from pathlib import Path
from pydantic import BaseModel
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from pydantic_settings import (
BaseSettings,
SettingsConfigDict,
)
DEFAULT_DATABASE = "sshecret.db"
load_dotenv()
class BackendSettings(BaseSettings):
"""Backend settings."""
model_config = SettingsConfigDict(env_file=".backend.env", env_prefix="sshecret_")
class BackendSettings(BaseModel):
"""Backend Settings."""
db_file: Path
regenerate_tokens: bool = False
db_file: Path = Field(default=Path(DEFAULT_DATABASE).absolute())
def get_settings() -> BackendSettings:
"""Get settings."""
db_filename = os.getenv("SSHECRET_DATABASE") or DEFAULT_DATABASE
db_file = Path(db_filename).absolute()
return BackendSettings(db_file=db_file)
return BackendSettings()

View File

@ -1,57 +1,48 @@
"""Models for API views."""
import ipaddress
import uuid
from datetime import datetime
from typing import Self, override
from typing import Annotated, Any, Self, override
from sqlmodel import Field, SQLModel
from pydantic import IPvAnyAddress, IPvAnyNetwork
from . import models
class ClientListResponse(SQLModel):
"""Model list responses."""
class ClientView(SQLModel):
"""View for a single client."""
id: uuid.UUID
name: str
fingerprint: str
public_key: str
policies: list[str] = ["0.0.0.0/0", "::/0"]
secrets: list[str] = Field(default_factory=list)
created_at: datetime
updated_at: datetime | None = None
@classmethod
def from_clients(cls, clients: list[models.Client]) -> list[Self]:
def from_client_list(cls, clients: list[models.Client]) -> list[Self]:
"""Generate a list of responses from a list of clients."""
responses: list[Self] = []
for client in clients:
responses.append(
cls(
id=client.id,
name=client.name,
fingerprint=client.fingerprint,
created_at=client.created_at,
updated_at=client.updated_at or None,
)
)
responses: list[Self] = [cls.from_client(client) for client in clients]
return responses
class ClientView(ClientListResponse):
"""View for a single client."""
secrets: list[str] = Field(default_factory=list)
@classmethod
def from_client(cls, client: models.Client) -> Self:
"""Instantiate from a client."""
view = cls(
id=client.id,
name=client.name,
fingerprint=client.fingerprint,
public_key=client.public_key,
created_at=client.created_at,
updated_at=client.updated_at or None,
)
if client.secrets:
view.secrets = [secret.name for secret in client.secrets]
if client.policies:
view.policies = [policy.source for policy in client.policies]
return view
@ -59,17 +50,20 @@ class ClientCreate(SQLModel):
"""Model to create a client."""
name: str
fingerprint: str
public_key: str
def to_client(self) -> models.Client:
"""Instantiate a client."""
return models.Client(name=self.name, fingerprint=self.fingerprint)
public_key = self.public_key
return models.Client(
name=self.name, public_key=public_key
)
class ClientUpdate(SQLModel):
"""Model to update the client fingerprint."""
"""Model to update the client public key."""
fingerprint: str
public_key: str
class BodyValue(SQLModel):
@ -110,3 +104,22 @@ class ClientSecretResponse(ClientSecretPublic):
created_at=client_secret.created_at,
updated_at=client_secret.updated_at,
)
class ClientPolicyView(SQLModel):
"""Update object for client policy."""
sources: list[str] = ["0.0.0.0/0", "::/0"]
@classmethod
def from_client(cls, client: models.Client) -> Self:
"""Create from client."""
if not client.policies:
return cls()
return cls(sources=[policy.source for policy in client.policies])
class ClientPolicyUpdate(SQLModel):
"""Model for updating policies."""
sources: list[IPvAnyAddress | IPvAnyNetwork]