diff --git a/packages/sshecret-backend/src/sshecret_backend/api/__init__.py b/packages/sshecret-backend/src/sshecret_backend/api/__init__.py index 95ef5da..611c3e8 100644 --- a/packages/sshecret-backend/src/sshecret_backend/api/__init__.py +++ b/packages/sshecret-backend/src/sshecret_backend/api/__init__.py @@ -1,8 +1,7 @@ """API factory modules.""" from .audit import get_audit_api -from .clients import get_clients_api from .policies import get_policy_api from .secrets import get_secrets_api -__all__ = ["get_audit_api", "get_clients_api", "get_policy_api", "get_secrets_api"] +__all__ = ["get_audit_api", "get_policy_api", "get_secrets_api"] diff --git a/packages/sshecret-backend/src/sshecret_backend/api/clients.py b/packages/sshecret-backend/src/sshecret_backend/api/clients.py deleted file mode 100644 index 872cfbc..0000000 --- a/packages/sshecret-backend/src/sshecret_backend/api/clients.py +++ /dev/null @@ -1,227 +0,0 @@ -"""Client sub-api factory.""" - -# pyright: reportUnusedFunction=false - -import uuid -import logging -from fastapi import APIRouter, Depends, HTTPException, Query, Request, status -from pydantic import BaseModel, Field, model_validator -from typing import Annotated, Any, Self, TypeVar, cast - -from sqlalchemy import select, func -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.sql import Select -from sshecret_backend.types import AsyncDBSessionDep -from sshecret_backend.models import Client, ClientSecret -from sshecret_backend.view_models import ( - ClientCreate, - ClientQueryResult, - ClientView, - ClientUpdate, -) -from sshecret_backend import audit -from .common import get_client_by_id_or_name, client_with_relationships - - -class ClientListParams(BaseModel): - """Client list parameters.""" - - limit: int = Field(100, gt=0, le=100) - offset: int = Field(0, ge=0) - id: uuid.UUID | None = None - name: str | None = None - name__like: str | None = None - name__contains: str | None = None - - @model_validator(mode="after") - def validate_expressions(self) -> Self: - """Validate mutually exclusive expression.""" - name_filter = False - if self.name__like or self.name__contains: - name_filter = True - if self.name__like and self.name__contains: - raise ValueError("You may only specify one name expression") - if self.name and name_filter: - raise ValueError( - "You must either specify name or one of name__like or name__contains" - ) - - return self - - -LOG = logging.getLogger(__name__) - -T = TypeVar("T") - - -def filter_client_statement( - statement: Select[Any], params: ClientListParams, ignore_limits: bool = False -) -> Select[Any]: - """Filter a statement with the provided params.""" - if params.id: - statement = statement.where(Client.id == params.id) - - if params.name: - statement = statement.where(Client.name == params.name) - elif params.name__like: - statement = statement.where(Client.name.like(params.name__like)) - elif params.name__contains: - statement = statement.where(Client.name.contains(params.name__contains)) - - if ignore_limits: - return statement - - return statement.limit(params.limit).offset(params.offset) - - -def get_clients_api(get_db_session: AsyncDBSessionDep) -> APIRouter: - """Construct clients sub-api.""" - router = APIRouter() - - @router.get("/clients/") - async def get_clients( - filter_query: Annotated[ClientListParams, Query()], - session: Annotated[AsyncSession, Depends(get_db_session)], - ) -> ClientQueryResult: - """Get clients.""" - # Get total results first - count_statement = select(func.count("*")).select_from(Client) - count_statement = cast(Select[tuple[int]], filter_client_statement(count_statement, filter_query, True)) - - total_results = (await session.scalars(count_statement)).one() - - statement = filter_client_statement(client_with_relationships(), filter_query, False) - - results = await session.scalars(statement) - remainder = total_results - filter_query.offset - filter_query.limit - if remainder < 0: - remainder = 0 - - clients = list(results.all()) - clients_view = ClientView.from_client_list(clients) - return ClientQueryResult( - clients=clients_view, - total_results=total_results, - remaining_results=remainder, - ) - - @router.get("/clients/{name}") - async def get_client( - name: str, - session: Annotated[AsyncSession, Depends(get_db_session)], - ) -> ClientView: - """Fetch a client.""" - client = await get_client_by_id_or_name(session, name) - if not client: - raise HTTPException( - status_code=404, detail="Cannot find a client with the given name." - ) - return ClientView.from_client(client) - - @router.delete("/clients/{name}") - async def delete_client( - request: Request, - name: str, - session: Annotated[AsyncSession, Depends(get_db_session)], - ) -> None: - """Delete a client.""" - client = await get_client_by_id_or_name(session, name) - if not client: - raise HTTPException( - status_code=404, detail="Cannot find a client with the given name." - ) - - await session.delete(client) - await session.commit() - await audit.audit_delete_client(session, request, client) - - @router.post("/clients/") - async def create_client( - request: Request, - client: ClientCreate, - session: Annotated[AsyncSession, Depends(get_db_session)], - ) -> ClientView: - """Create client.""" - existing = await get_client_by_id_or_name(session, client.name) - if existing: - raise HTTPException(400, detail="Error: Already a client with that name.") - - db_client = client.to_client() - session.add(db_client) - await session.commit() - await session.refresh(db_client) - db_client = await get_client_by_id_or_name(session, client.name) - if not db_client: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Could not create the client.") - await audit.audit_create_client(session, request, db_client) - return ClientView.from_client(db_client) - - @router.post("/clients/{name}/public-key") - async def update_client_public_key( - request: Request, - name: str, - client_update: ClientUpdate, - session: Annotated[AsyncSession, Depends(get_db_session)], - ) -> ClientView: - """Change the public key of a client. - - This invalidates all secrets. - """ - client = await get_client_by_id_or_name(session, name) - if not client: - raise HTTPException( - status_code=404, detail="Cannot find a client with the given name." - ) - client.public_key = client_update.public_key - matching_secrets = await session.scalars(select(ClientSecret).where(ClientSecret.client_id == client.id)) - for secret in matching_secrets.all(): - LOG.debug("Invalidated secret %s", secret.id) - secret.invalidated = True - secret.client_id = None - - session.add(client) - await session.refresh(client) - await session.commit() - await audit.audit_invalidate_secrets(session, request, client) - - return ClientView.from_client(client) - - @router.put("/clients/{name}") - async def update_client( - request: Request, - name: str, - client_update: ClientCreate, - session: Annotated[AsyncSession, Depends(get_db_session)], - ) -> ClientView: - """Change the public key of a client. - - This invalidates all secrets. - """ - client = await get_client_by_id_or_name(session, name) - if not client: - raise HTTPException( - status_code=404, detail="Cannot find a client with the given name." - ) - client.name = client_update.name - client.description = client_update.description - public_key_updated = False - if client_update.public_key != client.public_key: - public_key_updated = True - client_secrets = await session.scalars( - select(ClientSecret).where(ClientSecret.client_id == client.id) - ) - for secret in client_secrets.all(): - LOG.debug("Invalidated secret %s", secret.id) - secret.invalidated = True - secret.client_id = None - - session.add(client) - await session.commit() - await session.refresh(client) - await audit.audit_update_client(session, request, client) - if public_key_updated: - await audit.audit_invalidate_secrets(session, request, client) - - return ClientView.from_client(client) - - return router diff --git a/packages/sshecret-backend/src/sshecret_backend/api/common.py b/packages/sshecret-backend/src/sshecret_backend/api/common.py index 0693907..41ad016 100644 --- a/packages/sshecret-backend/src/sshecret_backend/api/common.py +++ b/packages/sshecret-backend/src/sshecret_backend/api/common.py @@ -3,15 +3,27 @@ import re import uuid import bcrypt +from dataclasses import dataclass, field from sqlalchemy import Select from sqlalchemy.orm import selectinload from sqlalchemy.future import select from sqlalchemy.ext.asyncio import AsyncSession -from sshecret_backend.models import Client +from sshecret_backend.models import Client, ClientAccessPolicy + +RE_UUID = re.compile( + "^[0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}$" +) + + +@dataclass +class NewClientVersion: + """New client version dataclass.""" + + client: Client + policies: list[ClientAccessPolicy] = field(default_factory=list) -RE_UUID = re.compile("^[0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}$") def verify_token(token: str, stored_hash: str) -> bool: """Verify token.""" @@ -19,12 +31,19 @@ def verify_token(token: str, stored_hash: str) -> bool: stored_bytes = stored_hash.encode("utf-8") return bcrypt.checkpw(token_bytes, stored_bytes) -async def reload_client_with_relationships(session: AsyncSession, client: Client) -> Client: + +async def reload_client_with_relationships( + session: AsyncSession, client: Client +) -> Client: """Reload a client from the database.""" session.expunge(client) stmt = ( select(Client) - .options(selectinload(Client.policies), selectinload(Client.secrets)) + .options( + selectinload(Client.policies), + selectinload(Client.secrets), + selectinload(Client.previous_version), + ) .where(Client.id == client.id) ) result = await session.execute(stmt) @@ -36,13 +55,26 @@ def client_with_relationships() -> Select[tuple[Client]]: return select(Client).options( selectinload(Client.secrets), selectinload(Client.policies), + selectinload(Client.previous_version), ) -async def get_client_by_name(session: AsyncSession, name: str) -> Client | None: - """Get client by name.""" - client_filter = client_with_relationships().where(Client.name == name) - client_results = await session.execute(client_filter) - return client_results.scalars().first() + +async def resolve_client_id( + session: AsyncSession, name: str, version: int | None = None, include_deleted: bool = False, +) -> uuid.UUID | None: + """Get the ID of a client name.""" + if include_deleted: + client_filter = client_with_relationships().where(Client.name == name) + else: + client_filter = query_active_clients().where(Client.name == name) + if version: + client_filter = client_filter.where(Client.version == version) + + client_result = await session.execute(client_filter) + if client := client_result.scalars().first(): + return client.id + return None + async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | None: """Get client by ID.""" @@ -50,10 +82,75 @@ async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | Non client_results = await session.execute(client_filter) return client_results.scalars().first() -async def get_client_by_id_or_name(session: AsyncSession, id_or_name: str) -> Client | None: + +async def get_client_by_id_or_name( + session: AsyncSession, id_or_name: str +) -> Client | None: """Get client either by id or name.""" if RE_UUID.match(id_or_name): id = uuid.UUID(id_or_name) return await get_client_by_id(session, id) return await get_client_by_name(session, id_or_name) + + +def query_active_clients() -> Select[tuple[Client]]: + """Get all active clients.""" + client_filter = ( + client_with_relationships() + .where(Client.is_active.is_(True)) + .where(Client.is_deleted.is_(False)) + ) + return client_filter + + +async def get_client_by_name(session: AsyncSession, name: str) -> Client | None: + """Get client by name. + + This will get the latest client version, unless it's deleted. + """ + client_filter = ( + client_with_relationships() + .where(Client.is_active.is_(True)) + .where(Client.is_deleted.is_not(True)) + .where(Client.name == name) + .order_by(Client.version.desc()) + ) + client_result = await session.execute(client_filter) + return client_result.scalars().first() + + +async def refresh_client(session: AsyncSession, client: Client) -> None: + """Refresh the client and load in all relationships.""" + await session.refresh( + client, attribute_names=["secrets", "policies", "previous_version", "updated_at"] + ) + + +async def create_new_client_version( + session: AsyncSession, current_client: Client, new_public_key: str +) -> Client: + new_client = Client( + name=current_client.name, + version=current_client.version + 1, + description=current_client.description, + public_key=new_public_key, + previous_version_id=current_client.id, + is_active=True, + ) + + # Mark current client as inactive + current_client.is_active = False + + # Copy policies + for policy in current_client.policies: + copied_policy = ClientAccessPolicy( + client=new_client, + address=policy.source, + ) + session.add(copied_policy) + + session.add(new_client) + await session.flush() + await refresh_client(session, new_client) + return new_client diff --git a/packages/sshecret-backend/src/sshecret_backend/app.py b/packages/sshecret-backend/src/sshecret_backend/app.py index cf0da7a..b312e6a 100644 --- a/packages/sshecret-backend/src/sshecret_backend/app.py +++ b/packages/sshecret-backend/src/sshecret_backend/app.py @@ -12,16 +12,13 @@ from fastapi import ( from fastapi.encoders import jsonable_encoder from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse -from sqlalchemy import Engine -from sqlalchemy.ext.asyncio import AsyncEngine from .models import init_db_async from .backend_api import get_backend_api -from .db import setup_database, get_async_engine +from .db import get_async_engine from .settings import BackendSettings -from .types import AsyncDBSessionDep LOG = logging.getLogger(__name__) diff --git a/packages/sshecret-backend/src/sshecret_backend/backend_api.py b/packages/sshecret-backend/src/sshecret_backend/backend_api.py index 06441f0..f81615e 100644 --- a/packages/sshecret-backend/src/sshecret_backend/backend_api.py +++ b/packages/sshecret-backend/src/sshecret_backend/backend_api.py @@ -9,7 +9,8 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sshecret_backend.db import DatabaseSessionManager from sshecret_backend.settings import BackendSettings -from .api import get_audit_api, get_clients_api, get_policy_api, get_secrets_api +from .api import get_audit_api, get_policy_api, get_secrets_api +from .api.clients.router import create_client_router from .auth import verify_token from .models import ( APIClient, @@ -60,7 +61,7 @@ def get_backend_api( ) backend_api.include_router(get_audit_api(get_db_session)) - backend_api.include_router(get_clients_api(get_db_session)) + backend_api.include_router(create_client_router(get_db_session)) backend_api.include_router(get_policy_api(get_db_session)) backend_api.include_router(get_secrets_api(get_db_session)) diff --git a/packages/sshecret-backend/src/sshecret_backend/models.py b/packages/sshecret-backend/src/sshecret_backend/models.py index a2123ab..cddef14 100644 --- a/packages/sshecret-backend/src/sshecret_backend/models.py +++ b/packages/sshecret-backend/src/sshecret_backend/models.py @@ -51,13 +51,22 @@ class Client(Base): """Clients.""" __tablename__: str = "client" + __table_args__: tuple[sa.UniqueConstraint, ...] = ( + sa.UniqueConstraint("name", "version", name="uq_client_name_version"), + ) id: Mapped[uuid.UUID] = mapped_column( sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4 ) - name: Mapped[str] = mapped_column(sa.String, unique=True) + version: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=1) + + name: Mapped[str] = mapped_column(sa.String, nullable=False) + description: Mapped[str | None] = mapped_column(sa.String, nullable=True) - public_key: Mapped[str] = mapped_column(sa.Text) + public_key: Mapped[str] = mapped_column(sa.Text, nullable=False) + + is_active: Mapped[bool] = mapped_column(sa.Boolean, default=True) + is_deleted: Mapped[bool] = mapped_column(sa.Boolean, default=False) created_at: Mapped[datetime] = mapped_column( sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False @@ -69,10 +78,26 @@ class Client(Base): onupdate=sa.func.now(), ) + deleted_at: Mapped[datetime | None] = mapped_column( + sa.DateTime(timezone=True), + nullable=True + ) + secrets: Mapped[list["ClientSecret"]] = relationship( back_populates="client", passive_deletes=True ) + previous_version_id: Mapped[uuid.UUID | None] = mapped_column( + sa.Uuid(as_uuid=True), + sa.ForeignKey("client.id", ondelete="SET NULL"), + nullable=True, + ) + previous_version: Mapped["Client | None"] = relationship( + "Client", + remote_side=[id], + backref="versions" + ) + policies: Mapped[list["ClientAccessPolicy"]] = relationship(back_populates="client") diff --git a/packages/sshecret-backend/src/sshecret_backend/view_models.py b/packages/sshecret-backend/src/sshecret_backend/view_models.py index 395bb94..dcbd93e 100644 --- a/packages/sshecret-backend/src/sshecret_backend/view_models.py +++ b/packages/sshecret-backend/src/sshecret_backend/view_models.py @@ -2,82 +2,15 @@ import uuid from datetime import datetime -from typing import Annotated, Self, Sequence, override +from typing import Self, override +from collections.abc import Sequence -from pydantic import AfterValidator, BaseModel, Field, IPvAnyAddress, IPvAnyNetwork +from pydantic import BaseModel, Field, IPvAnyAddress, IPvAnyNetwork -from sshecret.crypto import public_key_validator from . import models -class ClientView(BaseModel): - """View for a single client.""" - - id: uuid.UUID - name: str - description: str | None = None - public_key: str - policies: list[str] = ["0.0.0.0/0", "::/0"] - secrets: list[str] = Field(default_factory=list) - created_at: datetime | None - updated_at: datetime | None = None - - @classmethod - def from_client_list(cls, clients: list[models.Client]) -> list[Self]: - """Generate a list of responses from a list of clients.""" - responses: list[Self] = [cls.from_client(client) for client in clients] - return responses - - @classmethod - def from_client(cls, client: models.Client) -> Self: - """Instantiate from a client.""" - view = cls( - id=client.id, - name=client.name, - description=client.description, - public_key=client.public_key, - created_at=client.created_at, - updated_at=client.updated_at or None, - ) - if client.secrets: - view.secrets = [secret.name for secret in client.secrets] - - if client.policies: - view.policies = [policy.source for policy in client.policies] - - return view - - -class ClientQueryResult(BaseModel): - """Result class for queries towards the client list.""" - - clients: list[ClientView] = Field(default_factory=list) - total_results: int - remaining_results: int - - -class ClientCreate(BaseModel): - """Model to create a client.""" - - name: str - description: str | None = None - public_key: Annotated[str, AfterValidator(public_key_validator)] - - def to_client(self) -> models.Client: - """Instantiate a client.""" - return models.Client( - name=self.name, - public_key=self.public_key, - description=self.description, - ) - - -class ClientUpdate(BaseModel): - """Model to update the client public key.""" - - public_key: Annotated[str, AfterValidator(public_key_validator)] - class BodyValue(BaseModel): """A generic model with just a value parameter.""" diff --git a/tests/packages/backend/test_backend.py b/tests/packages/backend/test_backend.py index a6576c1..89c46f9 100644 --- a/tests/packages/backend/test_backend.py +++ b/tests/packages/backend/test_backend.py @@ -20,7 +20,7 @@ handler = logging.StreamHandler() formatter = logging.Formatter("'%(asctime)s - %(levelname)s - %(message)s'") handler.setFormatter(formatter) LOG.addHandler(handler) -LOG.setLevel(logging.DEBUG) +#LOG.setLevel(logging.DEBUG) def make_test_key() -> str: @@ -473,7 +473,7 @@ def test_operations_with_id(test_client: TestClient) -> None: data = resp.json() client = data["clients"][0] client_id = client["id"] - resp = test_client.get(f"/api/v1/clients/{client_id}") + resp = test_client.get(f"/api/v1/clients/by-id/{client_id}") assert resp.status_code == 200 data = resp.json() assert data["name"] == "test"