From 7ad41f43d8673de971450f6026e52d91892a38cc Mon Sep 17 00:00:00 2001 From: Allan Eising Date: Sun, 8 Jun 2025 17:40:50 +0200 Subject: [PATCH] Refactor backend views --- ...35ff347a_remove_invalidated_add_deleted.py | 36 ++ .../c251311b64c9_make_client_object_better.py | 51 +++ ..._rename_parent_to_previous_version_for_.py | 27 ++ .../src/sshecret_backend/api/__init__.py | 6 - .../sshecret_backend/api/audit/__init__.py | 1 + .../api/{audit.py => audit/router.py} | 24 +- .../src/sshecret_backend/api/audit/schemas.py | 40 +++ .../sshecret_backend/api/clients/__init__.py | 0 .../api/clients/operations.py | 320 +++++++++++++++++ .../sshecret_backend/api/clients/router.py | 122 +++++++ .../sshecret_backend/api/clients/schemas.py | 137 ++++++++ .../src/sshecret_backend/api/common.py | 55 ++- .../src/sshecret_backend/api/policies.py | 86 ----- .../src/sshecret_backend/api/schemas.py | 9 + .../src/sshecret_backend/api/secrets.py | 257 -------------- .../sshecret_backend/api/secrets/__init__.py | 1 + .../api/secrets/operations.py | 328 ++++++++++++++++++ .../sshecret_backend/api/secrets/router.py | 107 ++++++ .../secrets/schemas.py} | 75 +--- .../src/sshecret_backend/audit.py | 32 +- .../src/sshecret_backend/auth.py | 1 + .../src/sshecret_backend/backend_api.py | 8 +- .../src/sshecret_backend/db.py | 39 ++- .../src/sshecret_backend/models.py | 13 +- tests/packages/backend/test_backend.py | 59 +++- 25 files changed, 1382 insertions(+), 452 deletions(-) create mode 100644 packages/sshecret-backend/migrations/versions/b4e135ff347a_remove_invalidated_add_deleted.py create mode 100644 packages/sshecret-backend/migrations/versions/c251311b64c9_make_client_object_better.py create mode 100644 packages/sshecret-backend/migrations/versions/f2dc50533f88_rename_parent_to_previous_version_for_.py create mode 100644 packages/sshecret-backend/src/sshecret_backend/api/audit/__init__.py rename packages/sshecret-backend/src/sshecret_backend/api/{audit.py => audit/router.py} (87%) create mode 100644 packages/sshecret-backend/src/sshecret_backend/api/audit/schemas.py create mode 100644 packages/sshecret-backend/src/sshecret_backend/api/clients/__init__.py create mode 100644 packages/sshecret-backend/src/sshecret_backend/api/clients/operations.py create mode 100644 packages/sshecret-backend/src/sshecret_backend/api/clients/router.py create mode 100644 packages/sshecret-backend/src/sshecret_backend/api/clients/schemas.py delete mode 100644 packages/sshecret-backend/src/sshecret_backend/api/policies.py create mode 100644 packages/sshecret-backend/src/sshecret_backend/api/schemas.py delete mode 100644 packages/sshecret-backend/src/sshecret_backend/api/secrets.py create mode 100644 packages/sshecret-backend/src/sshecret_backend/api/secrets/__init__.py create mode 100644 packages/sshecret-backend/src/sshecret_backend/api/secrets/operations.py create mode 100644 packages/sshecret-backend/src/sshecret_backend/api/secrets/router.py rename packages/sshecret-backend/src/sshecret_backend/{view_models.py => api/secrets/schemas.py} (51%) diff --git a/packages/sshecret-backend/migrations/versions/b4e135ff347a_remove_invalidated_add_deleted.py b/packages/sshecret-backend/migrations/versions/b4e135ff347a_remove_invalidated_add_deleted.py new file mode 100644 index 0000000..22d803c --- /dev/null +++ b/packages/sshecret-backend/migrations/versions/b4e135ff347a_remove_invalidated_add_deleted.py @@ -0,0 +1,36 @@ +"""Remove invalidated, add deleted + +Revision ID: b4e135ff347a +Revises: f2dc50533f88 +Create Date: 2025-06-06 08:57:47.611854 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'b4e135ff347a' +down_revision: Union[str, None] = 'f2dc50533f88' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('client_secret', sa.Column('deleted', sa.Boolean(), nullable=False)) + op.add_column('client_secret', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True)) + op.drop_column('client_secret', 'invalidated') + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('client_secret', sa.Column('invalidated', sa.BOOLEAN(), nullable=False)) + op.drop_column('client_secret', 'deleted_at') + op.drop_column('client_secret', 'deleted') + # ### end Alembic commands ### diff --git a/packages/sshecret-backend/migrations/versions/c251311b64c9_make_client_object_better.py b/packages/sshecret-backend/migrations/versions/c251311b64c9_make_client_object_better.py new file mode 100644 index 0000000..7a1c71b --- /dev/null +++ b/packages/sshecret-backend/migrations/versions/c251311b64c9_make_client_object_better.py @@ -0,0 +1,51 @@ +"""Make client object better + +Revision ID: c251311b64c9 +Revises: 37329d9b5437 +Create Date: 2025-06-04 21:49:22.638698 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "c251311b64c9" +down_revision: Union[str, None] = "37329d9b5437" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("client", schema=None) as batch_op: + batch_op.add_column(sa.Column("version", sa.Integer(), nullable=False)) + batch_op.add_column(sa.Column("is_active", sa.Boolean(), nullable=False)) + batch_op.add_column(sa.Column("is_deleted", sa.Boolean(), nullable=False)) + batch_op.add_column( + sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True) + ) + batch_op.add_column(sa.Column("parent_id", sa.Uuid(), nullable=True)) + batch_op.create_unique_constraint("uq_client_name_version", ["name", "version"]) + batch_op.create_foreign_key( + "fk_client_parent", "client", ["parent_id"], ["id"], ondelete="SET NULL" + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("client", schema=None) as batch_op: + batch_op.drop_constraint("fk_client_parent", type_="foreignkey") + batch_op.drop_constraint("uq_client_name_version", type_="unique") + batch_op.drop_column("parent_id") + batch_op.drop_column("deleted_at") + batch_op.drop_column("is_deleted") + batch_op.drop_column("is_active") + batch_op.drop_column("version") + # ### end Alembic commands ### diff --git a/packages/sshecret-backend/migrations/versions/f2dc50533f88_rename_parent_to_previous_version_for_.py b/packages/sshecret-backend/migrations/versions/f2dc50533f88_rename_parent_to_previous_version_for_.py new file mode 100644 index 0000000..72af7a3 --- /dev/null +++ b/packages/sshecret-backend/migrations/versions/f2dc50533f88_rename_parent_to_previous_version_for_.py @@ -0,0 +1,27 @@ +"""Rename parent to previous_version for clarity + +Revision ID: f2dc50533f88 +Revises: c251311b64c9 +Create Date: 2025-06-05 13:24:32.465927 + +""" +from typing import Sequence, Union + +from alembic import op + + +# revision identifiers, used by Alembic. +revision: str = 'f2dc50533f88' +down_revision: Union[str, None] = 'c251311b64c9' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + op.alter_column("client", "parent_id", new_column_name="previous_version_id") + + +def downgrade() -> None: + """Downgrade schema.""" + op.alter_column("client", "previous_version_id", new_column_name="parent_id") diff --git a/packages/sshecret-backend/src/sshecret_backend/api/__init__.py b/packages/sshecret-backend/src/sshecret_backend/api/__init__.py index 611c3e8..31044ad 100644 --- a/packages/sshecret-backend/src/sshecret_backend/api/__init__.py +++ b/packages/sshecret-backend/src/sshecret_backend/api/__init__.py @@ -1,7 +1 @@ """API factory modules.""" - -from .audit import get_audit_api -from .policies import get_policy_api -from .secrets import get_secrets_api - -__all__ = ["get_audit_api", "get_policy_api", "get_secrets_api"] diff --git a/packages/sshecret-backend/src/sshecret_backend/api/audit/__init__.py b/packages/sshecret-backend/src/sshecret_backend/api/audit/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/packages/sshecret-backend/src/sshecret_backend/api/audit/__init__.py @@ -0,0 +1 @@ + diff --git a/packages/sshecret-backend/src/sshecret_backend/api/audit.py b/packages/sshecret-backend/src/sshecret_backend/api/audit/router.py similarity index 87% rename from packages/sshecret-backend/src/sshecret_backend/api/audit.py rename to packages/sshecret-backend/src/sshecret_backend/api/audit/router.py index 517bf57..9b298e5 100644 --- a/packages/sshecret-backend/src/sshecret_backend/api/audit.py +++ b/packages/sshecret-backend/src/sshecret_backend/api/audit/router.py @@ -15,7 +15,7 @@ from typing import Annotated from sshecret_backend.models import AuditLog, Operation, SubSystem from sshecret_backend.types import AsyncDBSessionDep -from sshecret_backend.view_models import AuditInfo, AuditView, AuditListResult +from .schemas import AuditInfo, AuditView, AuditListResult LOG = logging.getLogger(__name__) @@ -58,7 +58,7 @@ class AuditFilter(BaseModel): ] -def get_audit_api(get_db_session: AsyncDBSessionDep) -> APIRouter: +def create_audit_router(get_db_session: AsyncDBSessionDep) -> APIRouter: """Construct audit sub-api.""" router = APIRouter() @@ -70,11 +70,13 @@ def get_audit_api(get_db_session: AsyncDBSessionDep) -> APIRouter: """Get audit logs.""" # audit.audit_access_audit_log(session, request) - total = (await session.scalars( - select(func.count("*")) - .select_from(AuditLog) - .where(and_(True, *filters.filter_mapping)) - )).one() + total = ( + await session.scalars( + select(func.count("*")) + .select_from(AuditLog) + .where(and_(True, *filters.filter_mapping)) + ) + ).one() remaining = total - filters.offset statement = ( @@ -107,12 +109,12 @@ def get_audit_api(get_db_session: AsyncDBSessionDep) -> APIRouter: @router.get("/audit/info") async def get_audit_info( - session: Annotated[AsyncSession, Depends(get_db_session)] + session: Annotated[AsyncSession, Depends(get_db_session)], ) -> AuditInfo: """Get audit info.""" - audit_count = (await session.scalars( - select(func.count("*")).select_from(AuditLog) - )).one() + audit_count = ( + await session.scalars(select(func.count("*")).select_from(AuditLog)) + ).one() return AuditInfo(entries=audit_count) return router diff --git a/packages/sshecret-backend/src/sshecret_backend/api/audit/schemas.py b/packages/sshecret-backend/src/sshecret_backend/api/audit/schemas.py new file mode 100644 index 0000000..fd8abf2 --- /dev/null +++ b/packages/sshecret-backend/src/sshecret_backend/api/audit/schemas.py @@ -0,0 +1,40 @@ +"""Models for Audit API views.""" + +import uuid +from datetime import datetime +from collections.abc import Sequence + +from pydantic import BaseModel + + +from sshecret_backend import models + + +class AuditView(BaseModel): + """Audit log view.""" + + id: uuid.UUID | None = None + subsystem: models.SubSystem + message: str + operation: models.Operation + data: dict[str, str] | None = None + client_id: uuid.UUID | None = None + client_name: str | None = None + secret_id: uuid.UUID | None = None + secret_name: str | None = None + origin: str | None = None + timestamp: datetime | None = None + + +class AuditInfo(BaseModel): + """Information about audit information.""" + + entries: int + + +class AuditListResult(BaseModel): + """Class to return when listing audit entries.""" + + results: Sequence[AuditView] + total: int + remaining: int diff --git a/packages/sshecret-backend/src/sshecret_backend/api/clients/__init__.py b/packages/sshecret-backend/src/sshecret_backend/api/clients/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/packages/sshecret-backend/src/sshecret_backend/api/clients/operations.py b/packages/sshecret-backend/src/sshecret_backend/api/clients/operations.py new file mode 100644 index 0000000..58d6ef3 --- /dev/null +++ b/packages/sshecret-backend/src/sshecret_backend/api/clients/operations.py @@ -0,0 +1,320 @@ +"""Client operations.""" + +import logging +import uuid +from typing import Any, cast +from datetime import datetime, timezone + +from fastapi import HTTPException, Request +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.sql import Select +from sshecret_backend import audit +from sshecret_backend.api.common import ( + FlexID, + IdType, + RelaxedId, + create_new_client_version, + query_active_clients, + get_client_by_id, + resolve_client_id, + refresh_client, + reload_client_with_relationships, +) +from sshecret_backend.models import Client, ClientAccessPolicy +from .schemas import ( + ClientListParams, + ClientCreate, + ClientPolicyView, + ClientView, + ClientQueryResult, + ClientPolicyUpdate, +) + + +LOG = logging.getLogger(__name__) + + +def _id(id: RelaxedId) -> uuid.UUID: + """Ensure that the ID is a uuid.""" + if isinstance(id, str): + return uuid.UUID(id) + return id + + +class ClientOperations: + """Perform operations on a client.""" + + def __init__(self, session: AsyncSession, request: Request) -> None: + """Create operations class.""" + self.session: AsyncSession = session + self.request: Request = request + + self._client_id: uuid.UUID | None = None + # self.client_name: str | None = None + # if client.type is IdType.ID: + # self._client_id = _id(client.value) + # else: + # self.client_name = str(client.value) + + async def get_client_id( + self, + client: FlexID, + version: int | None = None, + ) -> uuid.UUID | None: + """Get client ID.""" + if self._client_id: + LOG.debug("Returning previously resolved client ID.") + return self._client_id + + if client.type is IdType.ID: + self._client_id = _id(client.value) + return self._client_id + + client_name = str(client.value) + client_id = await resolve_client_id( + self.session, + client_name, + version=version, + ) + if not client_id: + return None + self._client_id = client_id + LOG.debug("Saving client ID %s", client_id) + return client_id + + async def _get_client( + self, + client: FlexID, + version: int | None = None, + ) -> Client | None: + """Get client.""" + client_id = await self.get_client_id(client, version=version) + if not client_id: + return None + db_client = await get_client_by_id(self.session, client_id) + return db_client + + async def get_client( + self, + client: FlexID, + version: int | None = None, + ) -> ClientView: + """Get public client object.""" + db_client = await self._get_client(client, version) + if not db_client: + raise HTTPException(status_code=404, detail="Client not found.") + return ClientView.from_client(db_client) + + async def create_client( + self, + create_model: ClientCreate, + ) -> ClientView: + """Create a new client.""" + existing_id = await self.get_client_id(FlexID.name(create_model.name)) + if existing_id: + raise HTTPException( + status_code=400, detail="Error: A client already exists with this name." + ) + client = create_model.to_client() + self.session.add(client) + await self.session.flush() + await self.session.commit() + await refresh_client(self.session, client) + await audit.audit_create_client(self.session, self.request, client) + return ClientView.from_client(client) + + async def delete_client(self, client: FlexID) -> None: + """Delete client.""" + db_client = await self._get_client(client) + if not db_client: + return + if db_client.is_deleted: + return + db_client.is_deleted = True + db_client.deleted_at = datetime.now(timezone.utc) + self.session.add(db_client) + await self.session.commit() + await audit.audit_delete_client(self.session, self.request, db_client) + + async def update_client( + self, + client: FlexID, + client_update: ClientCreate, + ) -> ClientView: + """Update client details.""" + db_client = await self._get_client(client) + if not db_client: + raise HTTPException(status_code=404, detail="Client not found.") + + if ( + client_update.public_key + and client_update.public_key != db_client.public_key + ): + return await new_client_version( + self.session, + db_client.id, + client_update.public_key, + client_update.name, + client_update.description, + ) + db_client.name = client_update.name + db_client.description = client_update.description + self.session.add(db_client) + await self.session.commit() + await refresh_client(self.session, db_client) + await audit.audit_update_client(self.session, self.request, db_client) + return ClientView.from_client(db_client) + + async def new_client_version( + self, + client: FlexID, + public_key: str, + name: str | None = None, + description: str | None = None, + ) -> ClientView: + """Update a client to a new version.""" + current_client = await self._get_client(client) + if not current_client: + raise HTTPException(status_code=404, detail="Client not found.") + new_client = await create_new_client_version( + self.session, current_client, public_key + ) + if name: + new_client.name = name + if description: + new_client.description = description + + current_client.is_active = False + self.session.add(current_client) + + await self.session.commit() + + await refresh_client(self.session, new_client) + await audit.audit_new_client_version( + self.session, self.request, current_client, new_client + ) + return ClientView.from_client(new_client) + + async def get_client_policies(self, client: FlexID) -> ClientPolicyView: + """Get client policies.""" + db_client = await self._get_client(client) + if not db_client: + raise HTTPException(status_code=404, detail="Client not found.") + return ClientPolicyView.from_client(db_client) + + async def update_client_policies( + self, client: FlexID, policy_update: ClientPolicyUpdate + ) -> ClientPolicyView: + """Update client policies.""" + db_client = await self._get_client(client) + if not db_client: + raise HTTPException(status_code=404, detail="Client not found.") + + policies = await self.session.scalars( + select(ClientAccessPolicy).where( + ClientAccessPolicy.client_id == db_client.id + ) + ) + deleted_policies: list[ClientAccessPolicy] = [] + added_policies: list[ClientAccessPolicy] = [] + for policy in policies.all(): + await self.session.delete(policy) + deleted_policies.append(policy) + + LOG.debug("Updating client policies with: %r", policy_update.sources) + for source in policy_update.sources: + LOG.debug("Source %r", source) + policy = ClientAccessPolicy(source=str(source), client_id=db_client.id) + self.session.add(policy) + added_policies.append(policy) + + await self.session.flush() + await self.session.commit() + + db_client = await reload_client_with_relationships(self.session, db_client) + for policy in deleted_policies: + await audit.audit_remove_policy( + self.session, self.request, db_client, policy + ) + + for policy in added_policies: + await audit.audit_update_policy( + self.session, self.request, db_client, policy + ) + + return ClientPolicyView.from_client(db_client) + + +def filter_client_statement( + statement: Select[Any], params: ClientListParams, ignore_limits: bool = False +) -> Select[Any]: + """Filter a statement with the provided params.""" + if params.id: + statement = statement.where(Client.id == params.id) + + if params.name: + statement = statement.where(Client.name == params.name) + elif params.name__like: + statement = statement.where(Client.name.like(params.name__like)) + elif params.name__contains: + statement = statement.where(Client.name.contains(params.name__contains)) + + if ignore_limits: + return statement + + return statement.limit(params.limit).offset(params.offset) + + +async def get_clients( + session: AsyncSession, + filter_query: ClientListParams, +) -> ClientQueryResult: + """Get Clients.""" + count_statement = select(func.count("*")).select_from(Client) + count_statement = cast( + Select[tuple[int]], + filter_client_statement(count_statement, filter_query, True), + ) + + total_results = (await session.scalars(count_statement)).one() + + statement = filter_client_statement(query_active_clients(), filter_query, False) + + results = await session.scalars(statement) + remainder = total_results - filter_query.offset - filter_query.limit + if remainder < 0: + remainder = 0 + + clients = list(results.all()) + clients_view = ClientView.from_client_list(clients) + return ClientQueryResult( + clients=clients_view, + total_results=total_results, + remaining_results=remainder, + ) + + +async def new_client_version( + session: AsyncSession, + client_id: uuid.UUID, + public_key: str, + name: str | None = None, + description: str | None = None, +) -> ClientView: + """Update a client to a new version.""" + current_client = await get_client_by_id(session, client_id) + if not current_client: + raise ValueError("Client not found.") + new_client = await create_new_client_version(session, current_client, public_key) + if name: + new_client.name = name + if description: + new_client.description = description + + current_client.is_active = False + session.add(current_client) + + await session.commit() + + return ClientView.from_client(new_client) diff --git a/packages/sshecret-backend/src/sshecret_backend/api/clients/router.py b/packages/sshecret-backend/src/sshecret_backend/api/clients/router.py new file mode 100644 index 0000000..c3eba04 --- /dev/null +++ b/packages/sshecret-backend/src/sshecret_backend/api/clients/router.py @@ -0,0 +1,122 @@ +"""Client router.""" + +# pyright: reportUnusedFunction=false +# +import logging +from typing import Annotated + +from fastapi import APIRouter, Depends, Query, Request +from sqlalchemy.ext.asyncio import AsyncSession + +from sshecret_backend.types import AsyncDBSessionDep +from sshecret_backend.api.clients.schemas import ( + ClientCreate, + ClientListParams, + ClientQueryResult, + ClientUpdate, + ClientView, + ClientPolicyUpdate, + ClientPolicyView, +) +from sshecret_backend.api.clients import operations +from sshecret_backend.api.clients.operations import ClientOperations +from sshecret_backend.api.common import FlexID + +LOG = logging.getLogger(__name__) + + +def create_client_router(get_db_session: AsyncDBSessionDep) -> APIRouter: + """Create client router.""" + router = APIRouter() + + @router.get("/clients/") + async def route_get_clients( + filter_query: Annotated[ClientListParams, Query()], + session: Annotated[AsyncSession, Depends(get_db_session)], + ) -> ClientQueryResult: + return await operations.get_clients(session, filter_query) + + @router.post("/clients/") + async def create_client( + request: Request, + client: ClientCreate, + session: Annotated[AsyncSession, Depends(get_db_session)], + ) -> ClientView: + """Create client.""" + client_op = ClientOperations(session, request) + return await client_op.create_client(client) + + @router.get("/clients/{client_identifier}") + async def fetch_client_by_name( + request: Request, + client_identifier: str, + session: Annotated[AsyncSession, Depends(get_db_session)], + version: Annotated[int | None, Query()] = None, + ) -> ClientView: + """Fetch client by name.""" + client_id = FlexID.from_string(client_identifier) + client_op = ClientOperations(session, request) + return await client_op.get_client(client_id, version) + + @router.delete("/clients/{client_identifier}") + async def delete_client( + request: Request, + client_identifier: str, + session: Annotated[AsyncSession, Depends(get_db_session)], + ) -> None: + """Delete a client.""" + client_id = FlexID.from_string(client_identifier) + client_op = ClientOperations(session, request) + await client_op.delete_client(client_id) + + @router.post("/clients/{client_identifier}/public-key") + async def update_client_public_key( + request: Request, + client_identifier: str, + client_update: ClientUpdate, + session: Annotated[AsyncSession, Depends(get_db_session)], + ) -> ClientView: + """Update client public key.""" + client_id = FlexID.from_string(client_identifier) + client_op = ClientOperations(session, request) + return await client_op.new_client_version(client_id, client_update.public_key) + + @router.put("/clients/{client_identifier}") + async def update_client_by_name( + request: Request, + client_identifier: str, + client_update: ClientCreate, + session: Annotated[AsyncSession, Depends(get_db_session)], + ) -> ClientView: + """Update client by name.""" + client_id = FlexID.from_string(client_identifier) + client_op = ClientOperations(session, request) + return await client_op.update_client(client_id, client_update) + + @router.get("/clients/{client_identifier}/policies/") + async def get_client_policies( + request: Request, + client_identifier: str, + session: Annotated[AsyncSession, Depends(get_db_session)], + ) -> ClientPolicyView: + """Get client policies.""" + client_id = FlexID.from_string(client_identifier) + client_op = ClientOperations(session, request) + return await client_op.get_client_policies(client_id) + + @router.put("/clients/{client_identifier}/policies/") + async def update_client_policies( + request: Request, + client_identifier: str, + policy_update: ClientPolicyUpdate, + session: Annotated[AsyncSession, Depends(get_db_session)], + ) -> ClientPolicyView: + """Update client policies. + + This is also how you delete policies. + """ + client_id = FlexID.from_string(client_identifier) + client_op = ClientOperations(session, request) + return await client_op.update_client_policies(client_id, policy_update) + + return router diff --git a/packages/sshecret-backend/src/sshecret_backend/api/clients/schemas.py b/packages/sshecret-backend/src/sshecret_backend/api/clients/schemas.py new file mode 100644 index 0000000..d5fa680 --- /dev/null +++ b/packages/sshecret-backend/src/sshecret_backend/api/clients/schemas.py @@ -0,0 +1,137 @@ +"""Client related schemas.""" + +from typing import Annotated, Self +import uuid +from datetime import datetime + +from pydantic import ( + AfterValidator, + BaseModel, + Field, + IPvAnyAddress, + IPvAnyNetwork, + model_validator, +) + +from sshecret.crypto import public_key_validator + +from sshecret_backend import models + + +class ClientView(BaseModel): + """View for a single client.""" + + id: uuid.UUID + name: str + description: str | None = None + public_key: str + policies: list[str] = ["0.0.0.0/0", "::/0"] + is_active: bool = True + is_deleted: bool = False + secrets: list[str] = Field(default_factory=list) + created_at: datetime | None + updated_at: datetime | None = None + deleted_at: datetime | None = None + + @classmethod + def from_client_list(cls, clients: list[models.Client]) -> list[Self]: + """Generate a list of responses from a list of clients.""" + responses: list[Self] = [cls.from_client(client) for client in clients] + return responses + + @classmethod + def from_client(cls, client: models.Client) -> Self: + """Instantiate from a client.""" + view = cls( + id=client.id, + name=client.name, + description=client.description, + public_key=client.public_key, + created_at=client.created_at, + updated_at=client.updated_at or None, + deleted_at=client.deleted_at or None, + is_active=client.is_active, + is_deleted=client.is_deleted, + ) + if client.secrets: + view.secrets = [secret.name for secret in client.secrets] + + if client.policies: + view.policies = [policy.source for policy in client.policies] + + return view + + +class ClientQueryResult(BaseModel): + """Result class for queries towards the client list.""" + + clients: list[ClientView] = Field(default_factory=list) + total_results: int + remaining_results: int + + +class ClientListParams(BaseModel): + """Client list parameters.""" + + limit: int = Field(100, gt=0, le=100) + offset: int = Field(0, ge=0) + id: uuid.UUID | None = None + name: str | None = None + name__like: str | None = None + name__contains: str | None = None + + @model_validator(mode="after") + def validate_expressions(self) -> Self: + """Validate mutually exclusive expression.""" + name_filter = False + if self.name__like or self.name__contains: + name_filter = True + if self.name__like and self.name__contains: + raise ValueError("You may only specify one name expression") + if self.name and name_filter: + raise ValueError( + "You must either specify name or one of name__like or name__contains" + ) + + return self + + +class ClientCreate(BaseModel): + """Model to create a client.""" + + name: str + description: str | None = None + public_key: Annotated[str, AfterValidator(public_key_validator)] + + def to_client(self) -> models.Client: + """Instantiate a client.""" + return models.Client( + name=self.name, + public_key=self.public_key, + description=self.description, + ) + + +class ClientUpdate(BaseModel): + """Model to update the client public key.""" + + public_key: Annotated[str, AfterValidator(public_key_validator)] + + +class ClientPolicyView(BaseModel): + """Update object for client policy.""" + + sources: list[str] = ["0.0.0.0/0", "::/0"] + + @classmethod + def from_client(cls, client: models.Client) -> Self: + """Create from client.""" + if not client.policies: + return cls() + return cls(sources=[policy.source for policy in client.policies]) + + +class ClientPolicyUpdate(BaseModel): + """Model for updating policies.""" + + sources: list[IPvAnyAddress | IPvAnyNetwork] diff --git a/packages/sshecret-backend/src/sshecret_backend/api/common.py b/packages/sshecret-backend/src/sshecret_backend/api/common.py index 41ad016..bab5d5e 100644 --- a/packages/sshecret-backend/src/sshecret_backend/api/common.py +++ b/packages/sshecret-backend/src/sshecret_backend/api/common.py @@ -1,21 +1,58 @@ """Common helpers.""" import re +from typing import Self import uuid -import bcrypt from dataclasses import dataclass, field +from enum import Enum + +import bcrypt +from pydantic import BaseModel from sqlalchemy import Select -from sqlalchemy.orm import selectinload - -from sqlalchemy.future import select from sqlalchemy.ext.asyncio import AsyncSession - +from sqlalchemy.future import select +from sqlalchemy.orm import selectinload from sshecret_backend.models import Client, ClientAccessPolicy RE_UUID = re.compile( "^[0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}$" ) +RelaxedId = uuid.UUID | str + + +class IdType(Enum): + """Id type.""" + + ID = "id" + NAME = "name" + + +class FlexID(BaseModel): + """Flexible identifier.""" + + type: IdType + value: RelaxedId + + @classmethod + def id(cls, id: RelaxedId) -> Self: + """Construct from ID.""" + return cls(type=IdType.ID, value=id) + + @classmethod + def name(cls, name: str) -> Self: + """Construct from name.""" + return cls(type=IdType.NAME, value=name) + + @classmethod + def from_string(cls, value: str) -> Self: + """Convert from path string.""" + if value.startswith("id:"): + return cls.id(value[3:]) + elif value.startswith("name:"): + return cls.name(value[5:]) + return cls.name(value) + @dataclass class NewClientVersion: @@ -60,7 +97,10 @@ def client_with_relationships() -> Select[tuple[Client]]: async def resolve_client_id( - session: AsyncSession, name: str, version: int | None = None, include_deleted: bool = False, + session: AsyncSession, + name: str, + version: int | None = None, + include_deleted: bool = False, ) -> uuid.UUID | None: """Get the ID of a client name.""" if include_deleted: @@ -123,7 +163,8 @@ async def get_client_by_name(session: AsyncSession, name: str) -> Client | None: async def refresh_client(session: AsyncSession, client: Client) -> None: """Refresh the client and load in all relationships.""" await session.refresh( - client, attribute_names=["secrets", "policies", "previous_version", "updated_at"] + client, + attribute_names=["secrets", "policies", "previous_version", "updated_at"], ) diff --git a/packages/sshecret-backend/src/sshecret_backend/api/policies.py b/packages/sshecret-backend/src/sshecret_backend/api/policies.py deleted file mode 100644 index 48e677e..0000000 --- a/packages/sshecret-backend/src/sshecret_backend/api/policies.py +++ /dev/null @@ -1,86 +0,0 @@ -"""Policies sub-api router factory.""" - -# pyright: reportUnusedFunction=false - -import logging -from fastapi import APIRouter, Depends, HTTPException, Request -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession -from typing import Annotated - -from sshecret_backend.models import ClientAccessPolicy -from sshecret_backend.view_models import ( - ClientPolicyView, - ClientPolicyUpdate, -) -from sshecret_backend.types import AsyncDBSessionDep -from sshecret_backend import audit -from .common import get_client_by_id_or_name, reload_client_with_relationships - - -LOG = logging.getLogger(__name__) - - -def get_policy_api(get_db_session: AsyncDBSessionDep) -> APIRouter: - """Construct clients sub-api.""" - router = APIRouter() - - @router.get("/clients/{name}/policies/") - async def get_client_policies( - name: str, session: Annotated[AsyncSession, Depends(get_db_session)] - ) -> ClientPolicyView: - """Get client policies.""" - client = await get_client_by_id_or_name(session, name) - if not client: - raise HTTPException( - status_code=404, detail="Cannot find a client with the given name." - ) - - return ClientPolicyView.from_client(client) - - @router.put("/clients/{name}/policies/") - async def update_client_policies( - request: Request, - name: str, - policy_update: ClientPolicyUpdate, - session: Annotated[AsyncSession, Depends(get_db_session)], - ) -> ClientPolicyView: - """Update client policies. - - This is also how you delete policies. - """ - client = await get_client_by_id_or_name(session, name) - if not client: - raise HTTPException( - status_code=404, detail="Cannot find a client with the given name." - ) - # Remove old policies. - policies = await session.scalars( - select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id) - ) - deleted_policies: list[ClientAccessPolicy] = [] - added_policies: list[ClientAccessPolicy] = [] - for policy in policies.all(): - await session.delete(policy) - deleted_policies.append(policy) - - LOG.debug("Updating client policies with: %r", policy_update.sources) - for source in policy_update.sources: - LOG.debug("Source %r", source) - policy = ClientAccessPolicy(source=str(source), client_id=client.id) - session.add(policy) - added_policies.append(policy) - - await session.flush() - await session.commit() - - client = await reload_client_with_relationships(session, client) - for policy in deleted_policies: - await audit.audit_remove_policy(session, request, client, policy) - - for policy in added_policies: - await audit.audit_update_policy(session, request, client, policy) - - return ClientPolicyView.from_client(client) - - return router diff --git a/packages/sshecret-backend/src/sshecret_backend/api/schemas.py b/packages/sshecret-backend/src/sshecret_backend/api/schemas.py new file mode 100644 index 0000000..201ac47 --- /dev/null +++ b/packages/sshecret-backend/src/sshecret_backend/api/schemas.py @@ -0,0 +1,9 @@ +"""Common API schemas.""" + +from pydantic import BaseModel + + +class BodyValue(BaseModel): + """A generic model with just a value parameter.""" + + value: str diff --git a/packages/sshecret-backend/src/sshecret_backend/api/secrets.py b/packages/sshecret-backend/src/sshecret_backend/api/secrets.py deleted file mode 100644 index d77ef5c..0000000 --- a/packages/sshecret-backend/src/sshecret_backend/api/secrets.py +++ /dev/null @@ -1,257 +0,0 @@ -"""Secrets sub-api factory.""" - -# pyright: reportUnusedFunction=false - -import logging -from collections import defaultdict -from fastapi import APIRouter, Depends, HTTPException, Request -from sqlalchemy import select -from typing import Annotated - -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload - -from sshecret_backend.models import Client, ClientSecret -from sshecret_backend.view_models import ( - ClientReference, - ClientSecretDetailList, - ClientSecretList, - ClientSecretPublic, - BodyValue, - ClientSecretResponse, -) -from sshecret_backend import audit -from sshecret_backend.types import AsyncDBSessionDep -from .common import get_client_by_id_or_name, get_client_by_name - - -LOG = logging.getLogger(__name__) - - -async def lookup_client_secret( - session: AsyncSession, client: Client, name: str -) -> ClientSecret | None: - """Look up a secret for a client.""" - statement = ( - select(ClientSecret) - .where(ClientSecret.client_id == client.id) - .where(ClientSecret.name == name) - ) - results = await session.scalars(statement) - return results.first() - - -def get_secrets_api(get_db_session: AsyncDBSessionDep) -> APIRouter: - """Construct clients sub-api.""" - router = APIRouter() - - @router.post("/clients/{name}/secrets/") - async def add_secret_to_client( - request: Request, - name: str, - client_secret: ClientSecretPublic, - session: Annotated[AsyncSession, Depends(get_db_session)], - ) -> None: - """Add secret to a client.""" - client = await get_client_by_id_or_name(session, name) - if not client: - raise HTTPException( - status_code=404, detail="Cannot find a client with the given name." - ) - - existing_secret = await lookup_client_secret( - session, client, client_secret.name - ) - if existing_secret: - raise HTTPException( - status_code=400, - detail="Cannot add a secret. A different secret with the same name already exists.", - ) - db_secret = ClientSecret( - name=client_secret.name, client_id=client.id, secret=client_secret.secret - ) - session.add(db_secret) - await session.commit() - await session.refresh(db_secret) - await audit.audit_create_secret(session, request, client, db_secret) - - @router.put("/clients/{name}/secrets/{secret_name}") - async def update_client_secret( - request: Request, - name: str, - secret_name: str, - secret_data: BodyValue, - session: Annotated[AsyncSession, Depends(get_db_session)], - ) -> ClientSecretResponse: - """Update a client secret. - - This can also be used for destructive creates. - """ - client = await get_client_by_id_or_name(session, name) - if not client: - raise HTTPException( - status_code=404, detail="Cannot find a client with the given name." - ) - - existing_secret = await lookup_client_secret(session, client, secret_name) - if existing_secret: - existing_secret.secret = secret_data.value - - session.add(existing_secret) - await session.commit() - await session.refresh(existing_secret) - await audit.audit_update_secret(session, request, client, existing_secret) - return ClientSecretResponse.from_client_secret(existing_secret) - - db_secret = ClientSecret( - name=secret_name, - client_id=client.id, - secret=secret_data.value, - ) - session.add(db_secret) - await session.commit() - await session.refresh(db_secret) - await audit.audit_create_secret(session, request, client, db_secret) - return ClientSecretResponse.from_client_secret(db_secret) - - @router.get("/clients/{name}/secrets/{secret_name}") - async def request_client_secret( - request: Request, - name: str, - secret_name: str, - session: Annotated[AsyncSession, Depends(get_db_session)], - ) -> ClientSecretResponse: - """Get a client secret.""" - client = await get_client_by_id_or_name(session, name) - if not client: - raise HTTPException( - status_code=404, detail="Cannot find a client with the given name." - ) - - secret = await lookup_client_secret(session, client, secret_name) - if not secret: - raise HTTPException( - status_code=404, detail="Cannot find a secret with the given name." - ) - - response_model = ClientSecretResponse.from_client_secret(secret) - await audit.audit_access_secret(session, request, client, secret) - return response_model - - @router.delete("/clients/{name}/secrets/{secret_name}") - async def delete_client_secret( - request: Request, - name: str, - secret_name: str, - session: Annotated[AsyncSession, Depends(get_db_session)], - ) -> None: - """Delete a secret.""" - client = await get_client_by_name(session, name) - if not client: - raise HTTPException( - status_code=404, detail="Cannot find a client with the given name." - ) - - secret = await lookup_client_secret(session, client, secret_name) - if not secret: - raise HTTPException( - status_code=404, detail="Cannot find a secret with the given name." - ) - - await session.delete(secret) - await session.commit() - await audit.audit_delete_secret(session, request, client, secret) - - @router.get("/secrets/") - async def get_secret_map( - session: Annotated[AsyncSession, Depends(get_db_session)], - ) -> list[ClientSecretList]: - """Get a list of all secrets and which clients have them.""" - client_secret_map: defaultdict[str, list[str]] = defaultdict(list) - client_secrets = await session.scalars( - select(ClientSecret).options(selectinload(ClientSecret.client)) - ) - for client_secret in client_secrets.all(): - if not client_secret.client: - if client_secret.name not in client_secret_map: - client_secret_map[client_secret.name] = [] - continue - client_secret_map[client_secret.name].append(client_secret.client.name) - # audit.audit_client_secret_list(session, request) - return [ - ClientSecretList(name=secret_name, clients=clients) - for secret_name, clients in client_secret_map.items() - ] - - @router.get("/secrets/detailed/") - async def get_detailed_secret_map( - session: Annotated[AsyncSession, Depends(get_db_session)], - ) -> list[ClientSecretDetailList]: - """Get a list of all secrets and which clients have them.""" - client_secrets: dict[str, ClientSecretDetailList] = {} - all_client_secrets = await session.execute( - select(ClientSecret).options(selectinload(ClientSecret.client)) - ) - for client_secret in all_client_secrets.scalars().all(): - if client_secret.name not in client_secrets: - client_secrets[client_secret.name] = ClientSecretDetailList( - name=client_secret.name - ) - client_secrets[client_secret.name].ids.append(str(client_secret.id)) - if not client_secret.client: - continue - client_secrets[client_secret.name].clients.append( - ClientReference( - id=str(client_secret.client.id), name=client_secret.client.name - ) - ) - # `audit.audit_client_secret_list(session, request) - return list(client_secrets.values()) - - @router.get("/secrets/{name}") - async def get_secret_clients( - name: str, - session: Annotated[AsyncSession, Depends(get_db_session)], - ) -> ClientSecretList: - """Get a list of which clients has a named secret.""" - clients: list[str] = [] - client_secrets = await session.scalars( - select(ClientSecret) - .join(ClientSecret.client) - .options(selectinload(ClientSecret.client)) - .where(ClientSecret.name == name) - .where(Client.is_active.is_(True)) - ) - for client_secret in client_secrets.all(): - if not client_secret.client: - continue - clients.append(client_secret.client.name) - - return ClientSecretList(name=name, clients=clients) - - @router.get("/secrets/{name}/detailed") - async def get_secret_clients_detailed( - name: str, - session: Annotated[AsyncSession, Depends(get_db_session)], - ) -> ClientSecretDetailList: - """Get a list of which clients has a named secret.""" - detail_list = ClientSecretDetailList(name=name) - client_secrets = await session.scalars( - select(ClientSecret) - .options(selectinload(ClientSecret.client)) - .where(ClientSecret.name == name) - .where(ClientSecret.client.is_(Client.is_active)) - ) - for client_secret in client_secrets.all(): - if not client_secret.client: - continue - detail_list.ids.append(str(client_secret.id)) - detail_list.clients.append( - ClientReference( - id=str(client_secret.client.id), name=client_secret.client.name - ) - ) - - return detail_list - - return router diff --git a/packages/sshecret-backend/src/sshecret_backend/api/secrets/__init__.py b/packages/sshecret-backend/src/sshecret_backend/api/secrets/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/packages/sshecret-backend/src/sshecret_backend/api/secrets/__init__.py @@ -0,0 +1 @@ + diff --git a/packages/sshecret-backend/src/sshecret_backend/api/secrets/operations.py b/packages/sshecret-backend/src/sshecret_backend/api/secrets/operations.py new file mode 100644 index 0000000..752b8e0 --- /dev/null +++ b/packages/sshecret-backend/src/sshecret_backend/api/secrets/operations.py @@ -0,0 +1,328 @@ +"""ClientSecret operations.""" + +import logging +import uuid +from datetime import datetime, timezone + +from fastapi import HTTPException, Request +from sqlalchemy import Null, distinct, func, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload +from sshecret_backend import audit +from sshecret_backend.api.common import ( + FlexID, + IdType, + get_client_by_id, + resolve_client_id, +) +from sshecret_backend.api.secrets.schemas import ( + ClientReference, + ClientSecretDetailList, + ClientSecretResponse, + ClientSecretStats, + DistinctClientSecretStats, +) +from sshecret_backend.models import Client, ClientSecret + +LOG = logging.getLogger(__name__) + +RelaxedId = uuid.UUID | str + + +def _id(id: RelaxedId) -> uuid.UUID: + """Ensure that the ID is a uuid.""" + if isinstance(id, str): + return uuid.UUID(id) + return id + + +class ClientSecretOperations: + """Perform operations on a client's ClientSecrets""" + + def __init__( + self, + session: AsyncSession, + request: Request, + client: FlexID, + include_deleted: bool = False, + ) -> None: + """Create operations class.""" + self._client_id: uuid.UUID | None = None + self.client_name: str | None = None + self.client: Client | None = None + if client.type is IdType.ID: + self._client_id = _id(client.value) + else: + self.client_name = str(client.value) + + self.include_deleted: bool = include_deleted + + self.session: AsyncSession = session + self.request: Request = request + + async def get_client_id(self) -> uuid.UUID | None: + """Get client ID.""" + if self._client_id: + LOG.debug("Returning previously resolved client ID.") + return self._client_id + + if not self.client_name: + raise RuntimeError("Error: No client information registered.") + + client_id = await resolve_client_id( + self.session, self.client_name, include_deleted=self.include_deleted + ) + self._client_id = client_id + LOG.debug("Saving client ID %s", client_id) + return client_id + + async def get_client(self) -> Client | None: + """Get client.""" + if self.client: + return self.client + client_id = await self.get_client_id() + if not client_id: + return None + + client = await get_client_by_id(self.session, client_id) + self.client = client + return client + + async def _get_client_secret( + self, secret_identifier: FlexID + ) -> ClientSecret | None: + """Get client secret private method.""" + client = await self.get_client() + if not client: + LOG.debug("Client lookup failed.") + return None + + if secret_identifier.type is IdType.ID: + match_id = _id(secret_identifier.value) + LOG.debug("Searching for secrets matching ID %r", match_id) + matches = [secret for secret in client.secrets if secret.id == match_id] + else: + LOG.debug("Searching for secrets matching name") + matches = [ + secret + for secret in client.secrets + if secret.name == str(secret_identifier.value) + ] + + for secret_match in matches: + if secret_match.deleted: + # TODO: Add override for deleted. + LOG.debug("Found deleted secret") + continue + + await audit.audit_access_secret( + self.session, self.request, client, secret_match + ) + LOG.debug("Found matching secret") + return secret_match + + LOG.debug("No secrets matched.") + return None + + async def get_client_secret( + self, secret_identifier: FlexID + ) -> ClientSecretResponse: + """Get client secret public model.""" + client_secret = await self._get_client_secret(secret_identifier) + if not client_secret: + raise HTTPException(status_code=404, detail="Unknown client or secret") + return ClientSecretResponse.from_client_secret(client_secret) + + async def create_client_secret( + self, name: str, value: str, description: str | None = None + ) -> ClientSecretResponse: + """Create client Secret.""" + client = await self.get_client() + if not client: + raise HTTPException(status_code=404, detail="Client not found.") + + existing_secret = await self._get_client_secret(FlexID.name(name)) + if existing_secret: + raise HTTPException( + status_code=400, + detail="Cannot add a secret. A different secret with the same name already exists.", + ) + secret = ClientSecret( + name=name, + description=description, + client=client, + secret=value, + ) + self.session.add(secret) + await self.session.commit() + await self.session.refresh(secret, attribute_names=["client", "updated_at"]) + await audit.audit_create_secret(self.session, self.request, client, secret) + return ClientSecretResponse.from_client_secret(secret) + + async def update_client_secret_value( + self, secret_identifier: FlexID, value: str + ) -> ClientSecretResponse: + """Update client secret value.""" + client_secret = await self._get_client_secret(secret_identifier) + if not client_secret: + # We can use this method to create a secret too. + if secret_identifier.type is IdType.NAME: + return await self.create_client_secret( + str(secret_identifier.value), value + ) + raise HTTPException(status_code=404, detail="Unknown client or secret") + client_secret.secret = value + self.session.add(client_secret) + await self.session.commit() + await self.session.refresh(client_secret, ["updated_at"]) + await audit.audit_update_secret( + self.session, self.request, client_secret.client, client_secret + ) + return ClientSecretResponse.from_client_secret(client_secret) + + async def delete_client_secret(self, secret_identifier: FlexID) -> None: + """Delete a client secret.""" + client_secret = await self._get_client_secret(secret_identifier) + if not client_secret: + return + + client_secret.deleted = True + client_secret.deleted_at = datetime.now(timezone.utc) + self.session.add(client_secret) + await self.session.commit() + await self.session.refresh(client_secret, ["deleted", "deleted_at"]) + await audit.audit_delete_secret( + self.session, self.request, client_secret.client, client_secret + ) + + +async def resolve_client_secret_mapping( + session: AsyncSession, +) -> list[ClientSecretDetailList]: + """Resolve mapping of clients to secrets.""" + result = await session.execute( + select(ClientSecret) + .join(ClientSecret.client) + .options(selectinload(ClientSecret.client)) + .where(Client.is_active.is_not(False)) + .where(ClientSecret.deleted.is_not(True)) + ) + client_secrets: dict[str, ClientSecretDetailList] = {} + for secret in result.scalars().all(): + if secret.name not in client_secrets: + client_secrets[secret.name] = ClientSecretDetailList(name=secret.name) + client_secrets[secret.name].ids.append(str(secret.id)) + if not secret.client: + continue + client_secrets[secret.name].clients.append( + ClientReference(id=str(secret.client.id), name=secret.client.name) + ) + + return list(client_secrets.values()) + + +async def resolve_client_secret_clients( + session: AsyncSession, name: str, include_deleted: bool = False +) -> ClientSecretDetailList | None: + """Resolve client association to a secret.""" + statement = ( + select(ClientSecret) + .options(selectinload(ClientSecret.client)) + .where(ClientSecret.name == name) + ) + if not include_deleted: + statement = statement.where(ClientSecret.deleted.is_not(True)) + + results = await session.execute(statement) + clients: ClientSecretDetailList | None = None + for client_secret in results.scalars().all(): + if not clients: + # Ensure we don't create the object before we have at least one client. + clients = ClientSecretDetailList(name=name) + clients.ids.append(str(client_secret.id)) + if client_secret.client: + clients.clients.append( + ClientReference( + id=str(client_secret.client.id), name=client_secret.client.name + ) + ) + + return clients + + +async def _get_secrets_total(session: AsyncSession) -> int: + """Get total amount of secrets excluding deleted.""" + statement = ( + select(func.count()) + .select_from(ClientSecret) + .where(ClientSecret.deleted.is_not(True)) + ) + result = await session.execute(statement) + return result.scalar_one() + + +async def _get_distinct_secret_count(session: AsyncSession) -> int: + """Get the amount of distinct secrets, excluding deleted.""" + statement = ( + select(func.count()) + .select_from(ClientSecret) + .where(ClientSecret.deleted.is_not(True)) + .distinct() + ) + result = await session.execute(statement) + return result.scalar_one() + + +async def _get_secret_client_stats( + session: AsyncSession, +) -> list[DistinctClientSecretStats]: + """Get stats for each named secret.""" + statement = ( + ( + select( + ClientSecret.name, func.count(distinct(Client.id)).label("client_count") + ) + ) + .join(Client, ClientSecret.client_id) + .where(ClientSecret.deleted.is_not(True), Client.is_deleted.is_not(True)) + .group_by(ClientSecret.name) + ) + results = await session.execute(statement) + stats: list[DistinctClientSecretStats] = [] + rows = results.all() + for row in rows: + secret_name, client_count = row.tuple() + stats.append(DistinctClientSecretStats(name=secret_name, clients=client_count)) + + return stats + + +async def _get_secrets_without_clients( + session: AsyncSession, +) -> int: + """Get secrets without clients.""" + statement = ( + select(func.count(distinct(ClientSecret.name))) + .where(ClientSecret.deleted.is_not(True)) + .where(ClientSecret.client_id.is_(Null)) + ) + result = await session.execute(statement) + return result.scalar_one() + + +async def get_client_secret_stats(session: AsyncSession) -> ClientSecretStats: + """Get stats for the client secrets. + + TODO: Implement a route with this. + """ + distinct_secrets = await _get_distinct_secret_count(session) + total_secrets = await _get_secrets_total(session) + secrets_without_clients = await _get_secrets_without_clients(session) + client_stats = await _get_secret_client_stats(session) + + return ClientSecretStats( + distinct_secrets=distinct_secrets, + total_secrets=total_secrets, + secrets_without_clients=secrets_without_clients, + client_stats=client_stats, + ) diff --git a/packages/sshecret-backend/src/sshecret_backend/api/secrets/router.py b/packages/sshecret-backend/src/sshecret_backend/api/secrets/router.py new file mode 100644 index 0000000..7116f86 --- /dev/null +++ b/packages/sshecret-backend/src/sshecret_backend/api/secrets/router.py @@ -0,0 +1,107 @@ +"""Client Secret Router.""" + +# pyright: reportUnusedFunction=false + +import logging +from typing import Annotated + +from fastapi import APIRouter, Depends, HTTPException, Request +from sqlalchemy.ext.asyncio import AsyncSession + +from sshecret_backend.api.common import FlexID +from sshecret_backend.types import AsyncDBSessionDep +from sshecret_backend.api.secrets.operations import ( + ClientSecretOperations, + resolve_client_secret_clients, + resolve_client_secret_mapping, +) +from sshecret_backend.api.schemas import BodyValue +from sshecret_backend.api.secrets.schemas import ( + ClientSecretDetailList, + ClientSecretPublic, + ClientSecretResponse, +) + +LOG = logging.getLogger(__name__) + + +def create_client_secrets_router(get_db_session: AsyncDBSessionDep) -> APIRouter: + """Create client secret router.""" + router = APIRouter() + + @router.post("/clients/{client_identifier}/secrets/") + async def add_secret_to_client( + request: Request, + client_identifier: str, + client_secret: ClientSecretPublic, + session: Annotated[AsyncSession, Depends(get_db_session)], + ) -> None: + """Add secret to a client.""" + client = FlexID.from_string(client_identifier) + client_op = ClientSecretOperations(session, request, client) + await client_op.create_client_secret( + client_secret.name, client_secret.secret, client_secret.description + ) + + @router.put("/clients/{client_identifier}/secrets/{secret_identifier}") + async def update_client_secret( + request: Request, + client_identifier: str, + secret_identifier: str, + secret_data: BodyValue, + session: Annotated[AsyncSession, Depends(get_db_session)], + ) -> ClientSecretResponse: + """Update client secret.""" + client = FlexID.from_string(client_identifier) + secret = FlexID.from_string(secret_identifier) + client_op = ClientSecretOperations(session, request, client) + return await client_op.update_client_secret_value(secret, secret_data.value) + + @router.get("/clients/{client_identifier}/secrets/{secret_identifier}") + async def request_client_secret_named( + request: Request, + client_identifier: str, + secret_identifier: str, + session: Annotated[AsyncSession, Depends(get_db_session)], + ) -> ClientSecretResponse: + """Get a named secret from a named client.""" + client = FlexID.from_string(client_identifier) + secret = FlexID.from_string(secret_identifier) + LOG.debug("Resolved client FlexID: %r", client) + LOG.debug("Resolved secret FlexID: %r", secret) + client_op = ClientSecretOperations(session, request, client) + return await client_op.get_client_secret(secret) + + # TODO: delete_client_secret + @router.delete("/clients/{client_identifier}/secrets/{secret_identifier}") + async def delete_client_secret( + request: Request, + client_identifier: str, + secret_identifier: str, + session: Annotated[AsyncSession, Depends(get_db_session)], + ) -> None: + """Delete client secret.""" + client = FlexID.from_string(client_identifier) + secret = FlexID.from_string(secret_identifier) + client_op = ClientSecretOperations(session, request, client) + await client_op.delete_client_secret(secret) + + @router.get("/secrets/") + async def get_secret_map( + session: Annotated[AsyncSession, Depends(get_db_session)], + ) -> list[ClientSecretDetailList]: + """Get a list of secrets and which clients have them.""" + return await resolve_client_secret_mapping(session) + + @router.get("/secrets/{name}") + async def get_secret_clients( + name: str, + session: Annotated[AsyncSession, Depends(get_db_session)], + ) -> ClientSecretDetailList: + """Get a list of which clients has a named secret.""" + result = await resolve_client_secret_clients(session, name) + if not result: + raise HTTPException(status_code=404, detail="Secret not found.") + return result + + return router diff --git a/packages/sshecret-backend/src/sshecret_backend/view_models.py b/packages/sshecret-backend/src/sshecret_backend/api/secrets/schemas.py similarity index 51% rename from packages/sshecret-backend/src/sshecret_backend/view_models.py rename to packages/sshecret-backend/src/sshecret_backend/api/secrets/schemas.py index dcbd93e..44017ba 100644 --- a/packages/sshecret-backend/src/sshecret_backend/view_models.py +++ b/packages/sshecret-backend/src/sshecret_backend/api/secrets/schemas.py @@ -1,21 +1,12 @@ -"""Models for API views.""" +"""Schemas for the secrets router.""" -import uuid from datetime import datetime from typing import Self, override -from collections.abc import Sequence +import uuid -from pydantic import BaseModel, Field, IPvAnyAddress, IPvAnyNetwork +from pydantic import BaseModel, Field - -from . import models - - - -class BodyValue(BaseModel): - """A generic model with just a value parameter.""" - - value: str +from sshecret_backend import models class ClientSecretPublic(BaseModel): @@ -38,6 +29,7 @@ class ClientSecretPublic(BaseModel): class ClientSecretResponse(ClientSecretPublic): """A secret view.""" + id: uuid.UUID created_at: datetime | None updated_at: datetime | None = None @@ -47,6 +39,7 @@ class ClientSecretResponse(ClientSecretPublic): """Instantiate from ClientSecret.""" return cls( + id=client_secret.id, name=client_secret.name, secret=client_secret.secret, created_at=client_secret.created_at, @@ -54,25 +47,6 @@ class ClientSecretResponse(ClientSecretPublic): ) -class ClientPolicyView(BaseModel): - """Update object for client policy.""" - - sources: list[str] = ["0.0.0.0/0", "::/0"] - - @classmethod - def from_client(cls, client: models.Client) -> Self: - """Create from client.""" - if not client.policies: - return cls() - return cls(sources=[policy.source for policy in client.policies]) - - -class ClientPolicyUpdate(BaseModel): - """Model for updating policies.""" - - sources: list[IPvAnyAddress | IPvAnyNetwork] - - class ClientSecretList(BaseModel): """Model for aggregating identically named secrets.""" @@ -95,31 +69,20 @@ class ClientSecretDetailList(BaseModel): clients: list[ClientReference] = Field(default_factory=list) -class AuditView(BaseModel): - """Audit log view.""" +class DistinctClientSecretStats(BaseModel): + """Stats for distinct client secrets.""" + + name: str + clients: int = 0 - id: uuid.UUID | None = None - subsystem: models.SubSystem - message: str - operation: models.Operation - data: dict[str, str] | None = None - client_id: uuid.UUID | None = None - client_name: str | None = None - secret_id: uuid.UUID | None = None - secret_name: str | None = None - origin: str | None = None - timestamp: datetime | None = None +class ClientSecretStats(BaseModel): + """Stats row for the clientSecret model. + Useful for pagination and statistics. + """ -class AuditInfo(BaseModel): - """Information about audit information.""" - - entries: int - - -class AuditListResult(BaseModel): - """Class to return when listing audit entries.""" - results: Sequence[AuditView] - total: int - remaining: int + distinct_secrets: int + total_secrets: int + secrets_without_clients: int + client_stats: list[DistinctClientSecretStats] = Field(default_factory=list) diff --git a/packages/sshecret-backend/src/sshecret_backend/audit.py b/packages/sshecret-backend/src/sshecret_backend/audit.py index 89c2704..ca88c24 100644 --- a/packages/sshecret-backend/src/sshecret_backend/audit.py +++ b/packages/sshecret-backend/src/sshecret_backend/audit.py @@ -5,7 +5,14 @@ from fastapi import Request from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy, Operation, SubSystem +from .models import ( + AuditLog, + Client, + ClientSecret, + ClientAccessPolicy, + Operation, + SubSystem, +) def _get_origin(request: Request) -> str | None: @@ -128,6 +135,27 @@ async def audit_update_client( await _write_audit_log(session, request, entry, commit) +async def audit_new_client_version( + session: AsyncSession, + request: Request, + old_client: Client, + new_client: Client, + commit: bool = True, +) -> None: + """Audit an update secret event.""" + entry = AuditLog( + operation=Operation.UPDATE, + client_id=old_client.id, + client_name=old_client.name, + message="Client data updated", + data={ + "new_client_id": str(new_client.id), + "new_client_version": new_client.version, + }, + ) + await _write_audit_log(session, request, entry, commit) + + async def audit_update_secret( session: AsyncSession, request: Request, @@ -224,6 +252,7 @@ async def audit_access_secret( ) await _write_audit_log(session, request, entry, commit) + async def audit_client_secret_list( session: AsyncSession, request: Request, commit: bool = True ) -> None: @@ -233,4 +262,3 @@ async def audit_client_secret_list( message="All secret names and their clients was viewed", ) await _write_audit_log(session, request, entry, commit) - diff --git a/packages/sshecret-backend/src/sshecret_backend/auth.py b/packages/sshecret-backend/src/sshecret_backend/auth.py index c6a94cd..1aa2041 100644 --- a/packages/sshecret-backend/src/sshecret_backend/auth.py +++ b/packages/sshecret-backend/src/sshecret_backend/auth.py @@ -2,6 +2,7 @@ import bcrypt + def hash_token(token: str) -> str: """Hash a token.""" pwbytes = token.encode("utf-8") diff --git a/packages/sshecret-backend/src/sshecret_backend/backend_api.py b/packages/sshecret-backend/src/sshecret_backend/backend_api.py index f81615e..390411b 100644 --- a/packages/sshecret-backend/src/sshecret_backend/backend_api.py +++ b/packages/sshecret-backend/src/sshecret_backend/backend_api.py @@ -9,7 +9,8 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sshecret_backend.db import DatabaseSessionManager from sshecret_backend.settings import BackendSettings -from .api import get_audit_api, get_policy_api, get_secrets_api +from .api.audit.router import create_audit_router +from .api.secrets.router import create_client_secrets_router from .api.clients.router import create_client_router from .auth import verify_token from .models import ( @@ -60,9 +61,8 @@ def get_backend_api( dependencies=[Depends(validate_token)], ) - backend_api.include_router(get_audit_api(get_db_session)) + backend_api.include_router(create_audit_router(get_db_session)) backend_api.include_router(create_client_router(get_db_session)) - backend_api.include_router(get_policy_api(get_db_session)) - backend_api.include_router(get_secrets_api(get_db_session)) + backend_api.include_router(create_client_secrets_router(get_db_session)) return backend_api diff --git a/packages/sshecret-backend/src/sshecret_backend/db.py b/packages/sshecret-backend/src/sshecret_backend/db.py index f7c2f13..4fd8290 100644 --- a/packages/sshecret-backend/src/sshecret_backend/db.py +++ b/packages/sshecret-backend/src/sshecret_backend/db.py @@ -8,7 +8,13 @@ from collections.abc import AsyncIterator, Generator, Callable from contextlib import asynccontextmanager from typing import Literal from sqlalchemy import create_engine, Engine, event, select -from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession, async_sessionmaker, create_async_engine, AsyncEngine +from sqlalchemy.ext.asyncio import ( + AsyncConnection, + AsyncSession, + async_sessionmaker, + create_async_engine, + AsyncEngine, +) from sqlalchemy.orm import sessionmaker, Session @@ -21,11 +27,14 @@ from .models import APIClient, SubSystem LOG = logging.getLogger(__name__) - class DatabaseSessionManager: def __init__(self, host: URL, **engine_kwargs: str): self._engine: AsyncEngine | None = get_async_engine(host) - self._sessionmaker: async_sessionmaker[AsyncSession] | None = async_sessionmaker(autocommit=False, bind=self._engine, expire_on_commit=False) + self._sessionmaker: async_sessionmaker[AsyncSession] | None = ( + async_sessionmaker( + autocommit=False, bind=self._engine, expire_on_commit=False + ) + ) async def close(self): if self._engine is None: @@ -81,7 +90,6 @@ def setup_database( return engine, get_db_session - def get_engine(url: URL, echo: bool = False) -> Engine: """Initialize the engine.""" engine = create_engine(url, echo=echo) @@ -102,6 +110,7 @@ def get_async_engine(url: URL, echo: bool = False, **engine_kwargs: str) -> Asyn """Get an async engine.""" engine = create_async_engine(url, echo=echo, **engine_kwargs) if url.drivername.startswith("sqlite+"): + @event.listens_for(engine.sync_engine, "connect") def set_sqlite_pragma( dbapi_connection: sqlite3.Connection, _connection_record: object @@ -113,18 +122,21 @@ def get_async_engine(url: URL, echo: bool = False, **engine_kwargs: str) -> Asyn return engine - -def create_api_token_with_value(session: Session, token: str, subsystem: Literal["admin", "sshd"]) -> None: +def create_api_token_with_value( + session: Session, token: str, subsystem: Literal["admin", "sshd"] +) -> None: """Create API token with a given value.""" - existing = session.scalars(select(APIClient).where(APIClient.subsystem == SubSystem(subsystem))).first() + existing = session.scalars( + select(APIClient).where(APIClient.subsystem == SubSystem(subsystem)) + ).first() if existing: if verify_token(token, existing.token): LOG.info("Token is up to date.") return LOG.info("Updating token value for subsystem %s", subsystem) hashed = hash_token(token) - existing.token=hashed + existing.token = hashed session.commit() return @@ -135,12 +147,19 @@ def create_api_token_with_value(session: Session, token: str, subsystem: Literal session.add(api_token) session.commit() -def create_api_token(session: Session, subsystem: Literal["admin", "sshd", "test"], recreate: bool = False) -> str: + +def create_api_token( + session: Session, + subsystem: Literal["admin", "sshd", "test"], + recreate: bool = False, +) -> str: """Create API token.""" subsys = SubSystem(subsystem) token = secrets.token_urlsafe(32) hashed = hash_token(token) - if existing := session.scalars(select(APIClient).where(APIClient.subsystem == subsys)).first(): + if existing := session.scalars( + select(APIClient).where(APIClient.subsystem == subsys) + ).first(): if not recreate: raise RuntimeError("Error: A token already exist for this subsystem.") existing.token = hashed diff --git a/packages/sshecret-backend/src/sshecret_backend/models.py b/packages/sshecret-backend/src/sshecret_backend/models.py index cddef14..21d8ae2 100644 --- a/packages/sshecret-backend/src/sshecret_backend/models.py +++ b/packages/sshecret-backend/src/sshecret_backend/models.py @@ -79,8 +79,7 @@ class Client(Base): ) deleted_at: Mapped[datetime | None] = mapped_column( - sa.DateTime(timezone=True), - nullable=True + sa.DateTime(timezone=True), nullable=True ) secrets: Mapped[list["ClientSecret"]] = relationship( @@ -93,9 +92,7 @@ class Client(Base): nullable=True, ) previous_version: Mapped["Client | None"] = relationship( - "Client", - remote_side=[id], - backref="versions" + "Client", remote_side=[id], backref="versions" ) policies: Mapped[list["ClientAccessPolicy"]] = relationship(back_populates="client") @@ -142,7 +139,7 @@ class ClientSecret(Base): sa.Uuid(as_uuid=True), sa.ForeignKey("client.id", ondelete="CASCADE") ) client: Mapped[Client] = relationship(back_populates="secrets") - invalidated: Mapped[bool] = mapped_column(default=False) + deleted: Mapped[bool] = mapped_column(default=False) created_at: Mapped[datetime] = mapped_column( sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False @@ -154,6 +151,10 @@ class ClientSecret(Base): onupdate=sa.func.now(), ) + deleted_at: Mapped[datetime | None] = mapped_column( + sa.DateTime(timezone=True), nullable=True + ) + class APIClient(Base): """A client on the API. diff --git a/tests/packages/backend/test_backend.py b/tests/packages/backend/test_backend.py index 89c46f9..0313771 100644 --- a/tests/packages/backend/test_backend.py +++ b/tests/packages/backend/test_backend.py @@ -11,7 +11,7 @@ from fastapi.testclient import TestClient from sshecret.crypto import generate_private_key, generate_public_key_string from sshecret_backend.app import create_backend_app from sshecret_backend.testing import create_test_token -from sshecret_backend.view_models import AuditView +from sshecret_backend.api.audit.schemas import AuditView from sshecret_backend.settings import BackendSettings @@ -167,8 +167,8 @@ def test_put_add_secret(test_client: TestClient) -> None: response_model = response.json() del response_model["created_at"] del response_model["updated_at"] - assert response_model == data - + for key, value in data.items(): + assert response_model.get(key) == value def test_put_update_secret(test_client: TestClient) -> None: """Test updating a client secret.""" @@ -407,7 +407,7 @@ def test_get_secret_list(test_client: TestClient) -> None: assert len(entry["clients"]) == 4 else: assert len(entry["clients"]) == 1 - assert entry["clients"][0] == entry["name"] + assert entry["clients"][0]["name"] == entry["name"] def test_get_secret_clients(test_client: TestClient) -> None: @@ -428,8 +428,9 @@ def test_get_secret_clients(test_client: TestClient) -> None: data = resp.json() assert data["name"] == "commonsecret" - assert "client-0" in data["clients"] - assert "client-1" not in data["clients"] + client_names = [client["name"] for client in data["clients"]] + assert "client-0" in client_names + assert "client-1" not in client_names assert len(data["clients"]) == 2 @@ -473,7 +474,7 @@ def test_operations_with_id(test_client: TestClient) -> None: data = resp.json() client = data["clients"][0] client_id = client["id"] - resp = test_client.get(f"/api/v1/clients/by-id/{client_id}") + resp = test_client.get(f"/api/v1/clients/id:{client_id}") assert resp.status_code == 200 data = resp.json() assert data["name"] == "test" @@ -559,3 +560,47 @@ def test_filter_audit_log(test_client: TestClient) -> None: assert data["results"][0]["operation"] == "login" assert data["results"][0]["message"] == "message1" + + +def test_secret_flexid(test_client: TestClient) -> None: + """Test flexible IDs in the secret API.""" + client_name = "test" + create_response = create_client( + test_client, + client_name, + ) + assert create_response.status_code == 200 + assert "id" in create_response.json() + client_id = create_response.json()["id"] + + # Create a secret using the client name + secrets: dict[str, str] = {} + resp = test_client.put( + "/api/v1/clients/test/secrets/clientnamesecret", + json={"value": "secret"}, + ) + assert resp.status_code == 200 + secret_data = resp.json() + assert "id" in secret_data + secrets["clientnamesecret"] = secret_data["id"] + + # Create one using the client ID + resp = test_client.put( + f"/api/v1/clients/id:{client_id}/secrets/clientidsecret", + json={"value": "secret"}, + ) + assert resp.status_code == 200 + + secret_data = resp.json() + assert "id" in secret_data + secrets["clientidsecret"] = secret_data["id"] + + # Let's try fetching the various permutations + for client_identifier in ("test", f"id:{client_id}"): + for secret_name, secret_id in secrets.items(): + for secret_identifier in (secret_name, f"id:{secret_id}"): + resp = test_client.get(f"/api/v1/clients/{client_identifier}/secrets/{secret_identifier}") + assert resp.status_code == 200 + resp_body = resp.json() + assert "id" in resp_body + assert resp_body["id"] == secret_id