Refactor backend views, update secret model #24

Merged
eising merged 4 commits from feature/expanded-secrets into main 2025-06-08 15:45:00 +00:00
13 changed files with 157 additions and 327 deletions
Showing only changes of commit aa6b55a911 - Show all commits

View File

@ -1,8 +1,7 @@
"""API factory modules.""" """API factory modules."""
from .audit import get_audit_api from .audit import get_audit_api
from .clients import get_clients_api
from .policies import get_policy_api from .policies import get_policy_api
from .secrets import get_secrets_api from .secrets import get_secrets_api
__all__ = ["get_audit_api", "get_clients_api", "get_policy_api", "get_secrets_api"] __all__ = ["get_audit_api", "get_policy_api", "get_secrets_api"]

View File

@ -1,227 +0,0 @@
"""Client sub-api factory."""
# pyright: reportUnusedFunction=false
import uuid
import logging
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from pydantic import BaseModel, Field, model_validator
from typing import Annotated, Any, Self, TypeVar, cast
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import Select
from sshecret_backend.types import AsyncDBSessionDep
from sshecret_backend.models import Client, ClientSecret
from sshecret_backend.view_models import (
ClientCreate,
ClientQueryResult,
ClientView,
ClientUpdate,
)
from sshecret_backend import audit
from .common import get_client_by_id_or_name, client_with_relationships
class ClientListParams(BaseModel):
"""Client list parameters."""
limit: int = Field(100, gt=0, le=100)
offset: int = Field(0, ge=0)
id: uuid.UUID | None = None
name: str | None = None
name__like: str | None = None
name__contains: str | None = None
@model_validator(mode="after")
def validate_expressions(self) -> Self:
"""Validate mutually exclusive expression."""
name_filter = False
if self.name__like or self.name__contains:
name_filter = True
if self.name__like and self.name__contains:
raise ValueError("You may only specify one name expression")
if self.name and name_filter:
raise ValueError(
"You must either specify name or one of name__like or name__contains"
)
return self
LOG = logging.getLogger(__name__)
T = TypeVar("T")
def filter_client_statement(
statement: Select[Any], params: ClientListParams, ignore_limits: bool = False
) -> Select[Any]:
"""Filter a statement with the provided params."""
if params.id:
statement = statement.where(Client.id == params.id)
if params.name:
statement = statement.where(Client.name == params.name)
elif params.name__like:
statement = statement.where(Client.name.like(params.name__like))
elif params.name__contains:
statement = statement.where(Client.name.contains(params.name__contains))
if ignore_limits:
return statement
return statement.limit(params.limit).offset(params.offset)
def get_clients_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
"""Construct clients sub-api."""
router = APIRouter()
@router.get("/clients/")
async def get_clients(
filter_query: Annotated[ClientListParams, Query()],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientQueryResult:
"""Get clients."""
# Get total results first
count_statement = select(func.count("*")).select_from(Client)
count_statement = cast(Select[tuple[int]], filter_client_statement(count_statement, filter_query, True))
total_results = (await session.scalars(count_statement)).one()
statement = filter_client_statement(client_with_relationships(), filter_query, False)
results = await session.scalars(statement)
remainder = total_results - filter_query.offset - filter_query.limit
if remainder < 0:
remainder = 0
clients = list(results.all())
clients_view = ClientView.from_client_list(clients)
return ClientQueryResult(
clients=clients_view,
total_results=total_results,
remaining_results=remainder,
)
@router.get("/clients/{name}")
async def get_client(
name: str,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView:
"""Fetch a client."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
return ClientView.from_client(client)
@router.delete("/clients/{name}")
async def delete_client(
request: Request,
name: str,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> None:
"""Delete a client."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
await session.delete(client)
await session.commit()
await audit.audit_delete_client(session, request, client)
@router.post("/clients/")
async def create_client(
request: Request,
client: ClientCreate,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView:
"""Create client."""
existing = await get_client_by_id_or_name(session, client.name)
if existing:
raise HTTPException(400, detail="Error: Already a client with that name.")
db_client = client.to_client()
session.add(db_client)
await session.commit()
await session.refresh(db_client)
db_client = await get_client_by_id_or_name(session, client.name)
if not db_client:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Could not create the client.")
await audit.audit_create_client(session, request, db_client)
return ClientView.from_client(db_client)
@router.post("/clients/{name}/public-key")
async def update_client_public_key(
request: Request,
name: str,
client_update: ClientUpdate,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView:
"""Change the public key of a client.
This invalidates all secrets.
"""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
client.public_key = client_update.public_key
matching_secrets = await session.scalars(select(ClientSecret).where(ClientSecret.client_id == client.id))
for secret in matching_secrets.all():
LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True
secret.client_id = None
session.add(client)
await session.refresh(client)
await session.commit()
await audit.audit_invalidate_secrets(session, request, client)
return ClientView.from_client(client)
@router.put("/clients/{name}")
async def update_client(
request: Request,
name: str,
client_update: ClientCreate,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView:
"""Change the public key of a client.
This invalidates all secrets.
"""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
client.name = client_update.name
client.description = client_update.description
public_key_updated = False
if client_update.public_key != client.public_key:
public_key_updated = True
client_secrets = await session.scalars(
select(ClientSecret).where(ClientSecret.client_id == client.id)
)
for secret in client_secrets.all():
LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True
secret.client_id = None
session.add(client)
await session.commit()
await session.refresh(client)
await audit.audit_update_client(session, request, client)
if public_key_updated:
await audit.audit_invalidate_secrets(session, request, client)
return ClientView.from_client(client)
return router

View File

@ -3,15 +3,27 @@
import re import re
import uuid import uuid
import bcrypt import bcrypt
from dataclasses import dataclass, field
from sqlalchemy import Select from sqlalchemy import Select
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from sqlalchemy.future import select from sqlalchemy.future import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sshecret_backend.models import Client from sshecret_backend.models import Client, ClientAccessPolicy
RE_UUID = re.compile(
"^[0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}$"
)
@dataclass
class NewClientVersion:
"""New client version dataclass."""
client: Client
policies: list[ClientAccessPolicy] = field(default_factory=list)
RE_UUID = re.compile("^[0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}$")
def verify_token(token: str, stored_hash: str) -> bool: def verify_token(token: str, stored_hash: str) -> bool:
"""Verify token.""" """Verify token."""
@ -19,12 +31,19 @@ def verify_token(token: str, stored_hash: str) -> bool:
stored_bytes = stored_hash.encode("utf-8") stored_bytes = stored_hash.encode("utf-8")
return bcrypt.checkpw(token_bytes, stored_bytes) return bcrypt.checkpw(token_bytes, stored_bytes)
async def reload_client_with_relationships(session: AsyncSession, client: Client) -> Client:
async def reload_client_with_relationships(
session: AsyncSession, client: Client
) -> Client:
"""Reload a client from the database.""" """Reload a client from the database."""
session.expunge(client) session.expunge(client)
stmt = ( stmt = (
select(Client) select(Client)
.options(selectinload(Client.policies), selectinload(Client.secrets)) .options(
selectinload(Client.policies),
selectinload(Client.secrets),
selectinload(Client.previous_version),
)
.where(Client.id == client.id) .where(Client.id == client.id)
) )
result = await session.execute(stmt) result = await session.execute(stmt)
@ -36,13 +55,26 @@ def client_with_relationships() -> Select[tuple[Client]]:
return select(Client).options( return select(Client).options(
selectinload(Client.secrets), selectinload(Client.secrets),
selectinload(Client.policies), selectinload(Client.policies),
selectinload(Client.previous_version),
) )
async def get_client_by_name(session: AsyncSession, name: str) -> Client | None:
"""Get client by name.""" async def resolve_client_id(
client_filter = client_with_relationships().where(Client.name == name) session: AsyncSession, name: str, version: int | None = None, include_deleted: bool = False,
client_results = await session.execute(client_filter) ) -> uuid.UUID | None:
return client_results.scalars().first() """Get the ID of a client name."""
if include_deleted:
client_filter = client_with_relationships().where(Client.name == name)
else:
client_filter = query_active_clients().where(Client.name == name)
if version:
client_filter = client_filter.where(Client.version == version)
client_result = await session.execute(client_filter)
if client := client_result.scalars().first():
return client.id
return None
async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | None: async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | None:
"""Get client by ID.""" """Get client by ID."""
@ -50,10 +82,75 @@ async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | Non
client_results = await session.execute(client_filter) client_results = await session.execute(client_filter)
return client_results.scalars().first() return client_results.scalars().first()
async def get_client_by_id_or_name(session: AsyncSession, id_or_name: str) -> Client | None:
async def get_client_by_id_or_name(
session: AsyncSession, id_or_name: str
) -> Client | None:
"""Get client either by id or name.""" """Get client either by id or name."""
if RE_UUID.match(id_or_name): if RE_UUID.match(id_or_name):
id = uuid.UUID(id_or_name) id = uuid.UUID(id_or_name)
return await get_client_by_id(session, id) return await get_client_by_id(session, id)
return await get_client_by_name(session, id_or_name) return await get_client_by_name(session, id_or_name)
def query_active_clients() -> Select[tuple[Client]]:
"""Get all active clients."""
client_filter = (
client_with_relationships()
.where(Client.is_active.is_(True))
.where(Client.is_deleted.is_(False))
)
return client_filter
async def get_client_by_name(session: AsyncSession, name: str) -> Client | None:
"""Get client by name.
This will get the latest client version, unless it's deleted.
"""
client_filter = (
client_with_relationships()
.where(Client.is_active.is_(True))
.where(Client.is_deleted.is_not(True))
.where(Client.name == name)
.order_by(Client.version.desc())
)
client_result = await session.execute(client_filter)
return client_result.scalars().first()
async def refresh_client(session: AsyncSession, client: Client) -> None:
"""Refresh the client and load in all relationships."""
await session.refresh(
client, attribute_names=["secrets", "policies", "previous_version", "updated_at"]
)
async def create_new_client_version(
session: AsyncSession, current_client: Client, new_public_key: str
) -> Client:
new_client = Client(
name=current_client.name,
version=current_client.version + 1,
description=current_client.description,
public_key=new_public_key,
previous_version_id=current_client.id,
is_active=True,
)
# Mark current client as inactive
current_client.is_active = False
# Copy policies
for policy in current_client.policies:
copied_policy = ClientAccessPolicy(
client=new_client,
address=policy.source,
)
session.add(copied_policy)
session.add(new_client)
await session.flush()
await refresh_client(session, new_client)
return new_client

View File

@ -12,16 +12,13 @@ from fastapi import (
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from sqlalchemy import Engine
from sqlalchemy.ext.asyncio import AsyncEngine
from .models import init_db_async from .models import init_db_async
from .backend_api import get_backend_api from .backend_api import get_backend_api
from .db import setup_database, get_async_engine from .db import get_async_engine
from .settings import BackendSettings from .settings import BackendSettings
from .types import AsyncDBSessionDep
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)

View File

@ -9,7 +9,8 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sshecret_backend.db import DatabaseSessionManager from sshecret_backend.db import DatabaseSessionManager
from sshecret_backend.settings import BackendSettings from sshecret_backend.settings import BackendSettings
from .api import get_audit_api, get_clients_api, get_policy_api, get_secrets_api from .api import get_audit_api, get_policy_api, get_secrets_api
from .api.clients.router import create_client_router
from .auth import verify_token from .auth import verify_token
from .models import ( from .models import (
APIClient, APIClient,
@ -60,7 +61,7 @@ def get_backend_api(
) )
backend_api.include_router(get_audit_api(get_db_session)) backend_api.include_router(get_audit_api(get_db_session))
backend_api.include_router(get_clients_api(get_db_session)) backend_api.include_router(create_client_router(get_db_session))
backend_api.include_router(get_policy_api(get_db_session)) backend_api.include_router(get_policy_api(get_db_session))
backend_api.include_router(get_secrets_api(get_db_session)) backend_api.include_router(get_secrets_api(get_db_session))

View File

@ -51,13 +51,22 @@ class Client(Base):
"""Clients.""" """Clients."""
__tablename__: str = "client" __tablename__: str = "client"
__table_args__: tuple[sa.UniqueConstraint, ...] = (
sa.UniqueConstraint("name", "version", name="uq_client_name_version"),
)
id: Mapped[uuid.UUID] = mapped_column( id: Mapped[uuid.UUID] = mapped_column(
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4 sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
) )
name: Mapped[str] = mapped_column(sa.String, unique=True) version: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=1)
name: Mapped[str] = mapped_column(sa.String, nullable=False)
description: Mapped[str | None] = mapped_column(sa.String, nullable=True) description: Mapped[str | None] = mapped_column(sa.String, nullable=True)
public_key: Mapped[str] = mapped_column(sa.Text) public_key: Mapped[str] = mapped_column(sa.Text, nullable=False)
is_active: Mapped[bool] = mapped_column(sa.Boolean, default=True)
is_deleted: Mapped[bool] = mapped_column(sa.Boolean, default=False)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
@ -69,10 +78,26 @@ class Client(Base):
onupdate=sa.func.now(), onupdate=sa.func.now(),
) )
deleted_at: Mapped[datetime | None] = mapped_column(
sa.DateTime(timezone=True),
nullable=True
)
secrets: Mapped[list["ClientSecret"]] = relationship( secrets: Mapped[list["ClientSecret"]] = relationship(
back_populates="client", passive_deletes=True back_populates="client", passive_deletes=True
) )
previous_version_id: Mapped[uuid.UUID | None] = mapped_column(
sa.Uuid(as_uuid=True),
sa.ForeignKey("client.id", ondelete="SET NULL"),
nullable=True,
)
previous_version: Mapped["Client | None"] = relationship(
"Client",
remote_side=[id],
backref="versions"
)
policies: Mapped[list["ClientAccessPolicy"]] = relationship(back_populates="client") policies: Mapped[list["ClientAccessPolicy"]] = relationship(back_populates="client")

View File

@ -2,82 +2,15 @@
import uuid import uuid
from datetime import datetime from datetime import datetime
from typing import Annotated, Self, Sequence, override from typing import Self, override
from collections.abc import Sequence
from pydantic import AfterValidator, BaseModel, Field, IPvAnyAddress, IPvAnyNetwork from pydantic import BaseModel, Field, IPvAnyAddress, IPvAnyNetwork
from sshecret.crypto import public_key_validator
from . import models from . import models
class ClientView(BaseModel):
"""View for a single client."""
id: uuid.UUID
name: str
description: str | None = None
public_key: str
policies: list[str] = ["0.0.0.0/0", "::/0"]
secrets: list[str] = Field(default_factory=list)
created_at: datetime | None
updated_at: datetime | None = None
@classmethod
def from_client_list(cls, clients: list[models.Client]) -> list[Self]:
"""Generate a list of responses from a list of clients."""
responses: list[Self] = [cls.from_client(client) for client in clients]
return responses
@classmethod
def from_client(cls, client: models.Client) -> Self:
"""Instantiate from a client."""
view = cls(
id=client.id,
name=client.name,
description=client.description,
public_key=client.public_key,
created_at=client.created_at,
updated_at=client.updated_at or None,
)
if client.secrets:
view.secrets = [secret.name for secret in client.secrets]
if client.policies:
view.policies = [policy.source for policy in client.policies]
return view
class ClientQueryResult(BaseModel):
"""Result class for queries towards the client list."""
clients: list[ClientView] = Field(default_factory=list)
total_results: int
remaining_results: int
class ClientCreate(BaseModel):
"""Model to create a client."""
name: str
description: str | None = None
public_key: Annotated[str, AfterValidator(public_key_validator)]
def to_client(self) -> models.Client:
"""Instantiate a client."""
return models.Client(
name=self.name,
public_key=self.public_key,
description=self.description,
)
class ClientUpdate(BaseModel):
"""Model to update the client public key."""
public_key: Annotated[str, AfterValidator(public_key_validator)]
class BodyValue(BaseModel): class BodyValue(BaseModel):
"""A generic model with just a value parameter.""" """A generic model with just a value parameter."""

View File

@ -20,7 +20,7 @@ handler = logging.StreamHandler()
formatter = logging.Formatter("'%(asctime)s - %(levelname)s - %(message)s'") formatter = logging.Formatter("'%(asctime)s - %(levelname)s - %(message)s'")
handler.setFormatter(formatter) handler.setFormatter(formatter)
LOG.addHandler(handler) LOG.addHandler(handler)
LOG.setLevel(logging.DEBUG) #LOG.setLevel(logging.DEBUG)
def make_test_key() -> str: def make_test_key() -> str:
@ -473,7 +473,7 @@ def test_operations_with_id(test_client: TestClient) -> None:
data = resp.json() data = resp.json()
client = data["clients"][0] client = data["clients"][0]
client_id = client["id"] client_id = client["id"]
resp = test_client.get(f"/api/v1/clients/{client_id}") resp = test_client.get(f"/api/v1/clients/by-id/{client_id}")
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
assert data["name"] == "test" assert data["name"] == "test"