Merge pull request 'Refactor backend views, update secret model' (#24) from feature/expanded-secrets into main
All checks were successful
Build and push image / build-containers (push) Successful in 8m3s

Reviewed-on: #24
This commit is contained in:
2025-06-08 15:44:59 +00:00
29 changed files with 1491 additions and 731 deletions

View File

@ -43,7 +43,7 @@ jobs:
git.eising.cloud/${{ env.DOCKER_ORG }}/${{ steps.meta.outputs.REPO_NAME }}-backend:${{ gitea.ref == 'refs/heads/main' && 'latest' || gitea.sha }} git.eising.cloud/${{ env.DOCKER_ORG }}/${{ steps.meta.outputs.REPO_NAME }}-backend:${{ gitea.ref == 'refs/heads/main' && 'latest' || gitea.sha }}
- name: Build backend and push - name: Build sshd and push
uses: docker/build-push-action@v4 uses: docker/build-push-action@v4
with: with:
context: . context: .
@ -54,7 +54,7 @@ jobs:
tags: | tags: |
git.eising.cloud/${{ env.DOCKER_ORG }}/${{ steps.meta.outputs.REPO_NAME }}-sshd:${{ gitea.ref == 'refs/heads/main' && 'latest' || gitea.sha }} git.eising.cloud/${{ env.DOCKER_ORG }}/${{ steps.meta.outputs.REPO_NAME }}-sshd:${{ gitea.ref == 'refs/heads/main' && 'latest' || gitea.sha }}
- name: Build backend and push - name: Build admin and push
uses: docker/build-push-action@v4 uses: docker/build-push-action@v4
with: with:
context: . context: .

View File

@ -0,0 +1,36 @@
"""Remove invalidated, add deleted
Revision ID: b4e135ff347a
Revises: f2dc50533f88
Create Date: 2025-06-06 08:57:47.611854
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'b4e135ff347a'
down_revision: Union[str, None] = 'f2dc50533f88'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('client_secret', sa.Column('deleted', sa.Boolean(), nullable=False))
op.add_column('client_secret', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
op.drop_column('client_secret', 'invalidated')
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('client_secret', sa.Column('invalidated', sa.BOOLEAN(), nullable=False))
op.drop_column('client_secret', 'deleted_at')
op.drop_column('client_secret', 'deleted')
# ### end Alembic commands ###

View File

@ -0,0 +1,51 @@
"""Make client object better
Revision ID: c251311b64c9
Revises: 37329d9b5437
Create Date: 2025-06-04 21:49:22.638698
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "c251311b64c9"
down_revision: Union[str, None] = "37329d9b5437"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("client", schema=None) as batch_op:
batch_op.add_column(sa.Column("version", sa.Integer(), nullable=False))
batch_op.add_column(sa.Column("is_active", sa.Boolean(), nullable=False))
batch_op.add_column(sa.Column("is_deleted", sa.Boolean(), nullable=False))
batch_op.add_column(
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True)
)
batch_op.add_column(sa.Column("parent_id", sa.Uuid(), nullable=True))
batch_op.create_unique_constraint("uq_client_name_version", ["name", "version"])
batch_op.create_foreign_key(
"fk_client_parent", "client", ["parent_id"], ["id"], ondelete="SET NULL"
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("client", schema=None) as batch_op:
batch_op.drop_constraint("fk_client_parent", type_="foreignkey")
batch_op.drop_constraint("uq_client_name_version", type_="unique")
batch_op.drop_column("parent_id")
batch_op.drop_column("deleted_at")
batch_op.drop_column("is_deleted")
batch_op.drop_column("is_active")
batch_op.drop_column("version")
# ### end Alembic commands ###

View File

@ -0,0 +1,27 @@
"""Rename parent to previous_version for clarity
Revision ID: f2dc50533f88
Revises: c251311b64c9
Create Date: 2025-06-05 13:24:32.465927
"""
from typing import Sequence, Union
from alembic import op
# revision identifiers, used by Alembic.
revision: str = 'f2dc50533f88'
down_revision: Union[str, None] = 'c251311b64c9'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
op.alter_column("client", "parent_id", new_column_name="previous_version_id")
def downgrade() -> None:
"""Downgrade schema."""
op.alter_column("client", "previous_version_id", new_column_name="parent_id")

View File

@ -1,8 +1 @@
"""API factory modules.""" """API factory modules."""
from .audit import get_audit_api
from .clients import get_clients_api
from .policies import get_policy_api
from .secrets import get_secrets_api
__all__ = ["get_audit_api", "get_clients_api", "get_policy_api", "get_secrets_api"]

View File

@ -4,7 +4,7 @@
import logging import logging
from typing import Any from typing import Any
from fastapi import APIRouter, Depends, Request from fastapi import APIRouter, Depends
from pydantic import BaseModel, Field, TypeAdapter from pydantic import BaseModel, Field, TypeAdapter
from sqlalchemy import select, func, and_ from sqlalchemy import select, func, and_
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@ -15,7 +15,7 @@ from typing import Annotated
from sshecret_backend.models import AuditLog, Operation, SubSystem from sshecret_backend.models import AuditLog, Operation, SubSystem
from sshecret_backend.types import AsyncDBSessionDep from sshecret_backend.types import AsyncDBSessionDep
from sshecret_backend.view_models import AuditInfo, AuditView, AuditListResult from .schemas import AuditInfo, AuditView, AuditListResult
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -58,7 +58,7 @@ class AuditFilter(BaseModel):
] ]
def get_audit_api(get_db_session: AsyncDBSessionDep) -> APIRouter: def create_audit_router(get_db_session: AsyncDBSessionDep) -> APIRouter:
"""Construct audit sub-api.""" """Construct audit sub-api."""
router = APIRouter() router = APIRouter()
@ -70,11 +70,13 @@ def get_audit_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
"""Get audit logs.""" """Get audit logs."""
# audit.audit_access_audit_log(session, request) # audit.audit_access_audit_log(session, request)
total = (await session.scalars( total = (
await session.scalars(
select(func.count("*")) select(func.count("*"))
.select_from(AuditLog) .select_from(AuditLog)
.where(and_(True, *filters.filter_mapping)) .where(and_(True, *filters.filter_mapping))
)).one() )
).one()
remaining = total - filters.offset remaining = total - filters.offset
statement = ( statement = (
@ -107,12 +109,12 @@ def get_audit_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
@router.get("/audit/info") @router.get("/audit/info")
async def get_audit_info( async def get_audit_info(
session: Annotated[AsyncSession, Depends(get_db_session)] session: Annotated[AsyncSession, Depends(get_db_session)],
) -> AuditInfo: ) -> AuditInfo:
"""Get audit info.""" """Get audit info."""
audit_count = (await session.scalars( audit_count = (
select(func.count("*")).select_from(AuditLog) await session.scalars(select(func.count("*")).select_from(AuditLog))
)).one() ).one()
return AuditInfo(entries=audit_count) return AuditInfo(entries=audit_count)
return router return router

View File

@ -0,0 +1,40 @@
"""Models for Audit API views."""
import uuid
from datetime import datetime
from collections.abc import Sequence
from pydantic import BaseModel
from sshecret_backend import models
class AuditView(BaseModel):
"""Audit log view."""
id: uuid.UUID | None = None
subsystem: models.SubSystem
message: str
operation: models.Operation
data: dict[str, str] | None = None
client_id: uuid.UUID | None = None
client_name: str | None = None
secret_id: uuid.UUID | None = None
secret_name: str | None = None
origin: str | None = None
timestamp: datetime | None = None
class AuditInfo(BaseModel):
"""Information about audit information."""
entries: int
class AuditListResult(BaseModel):
"""Class to return when listing audit entries."""
results: Sequence[AuditView]
total: int
remaining: int

View File

@ -1,227 +0,0 @@
"""Client sub-api factory."""
# pyright: reportUnusedFunction=false
import uuid
import logging
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from pydantic import BaseModel, Field, model_validator
from typing import Annotated, Any, Self, TypeVar, cast
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import Select
from sshecret_backend.types import AsyncDBSessionDep
from sshecret_backend.models import Client, ClientSecret
from sshecret_backend.view_models import (
ClientCreate,
ClientQueryResult,
ClientView,
ClientUpdate,
)
from sshecret_backend import audit
from .common import get_client_by_id_or_name, client_with_relationships
class ClientListParams(BaseModel):
"""Client list parameters."""
limit: int = Field(100, gt=0, le=100)
offset: int = Field(0, ge=0)
id: uuid.UUID | None = None
name: str | None = None
name__like: str | None = None
name__contains: str | None = None
@model_validator(mode="after")
def validate_expressions(self) -> Self:
"""Validate mutually exclusive expression."""
name_filter = False
if self.name__like or self.name__contains:
name_filter = True
if self.name__like and self.name__contains:
raise ValueError("You may only specify one name expression")
if self.name and name_filter:
raise ValueError(
"You must either specify name or one of name__like or name__contains"
)
return self
LOG = logging.getLogger(__name__)
T = TypeVar("T")
def filter_client_statement(
statement: Select[Any], params: ClientListParams, ignore_limits: bool = False
) -> Select[Any]:
"""Filter a statement with the provided params."""
if params.id:
statement = statement.where(Client.id == params.id)
if params.name:
statement = statement.where(Client.name == params.name)
elif params.name__like:
statement = statement.where(Client.name.like(params.name__like))
elif params.name__contains:
statement = statement.where(Client.name.contains(params.name__contains))
if ignore_limits:
return statement
return statement.limit(params.limit).offset(params.offset)
def get_clients_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
"""Construct clients sub-api."""
router = APIRouter()
@router.get("/clients/")
async def get_clients(
filter_query: Annotated[ClientListParams, Query()],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientQueryResult:
"""Get clients."""
# Get total results first
count_statement = select(func.count("*")).select_from(Client)
count_statement = cast(Select[tuple[int]], filter_client_statement(count_statement, filter_query, True))
total_results = (await session.scalars(count_statement)).one()
statement = filter_client_statement(client_with_relationships(), filter_query, False)
results = await session.scalars(statement)
remainder = total_results - filter_query.offset - filter_query.limit
if remainder < 0:
remainder = 0
clients = list(results.all())
clients_view = ClientView.from_client_list(clients)
return ClientQueryResult(
clients=clients_view,
total_results=total_results,
remaining_results=remainder,
)
@router.get("/clients/{name}")
async def get_client(
name: str,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView:
"""Fetch a client."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
return ClientView.from_client(client)
@router.delete("/clients/{name}")
async def delete_client(
request: Request,
name: str,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> None:
"""Delete a client."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
await session.delete(client)
await session.commit()
await audit.audit_delete_client(session, request, client)
@router.post("/clients/")
async def create_client(
request: Request,
client: ClientCreate,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView:
"""Create client."""
existing = await get_client_by_id_or_name(session, client.name)
if existing:
raise HTTPException(400, detail="Error: Already a client with that name.")
db_client = client.to_client()
session.add(db_client)
await session.commit()
await session.refresh(db_client)
db_client = await get_client_by_id_or_name(session, client.name)
if not db_client:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Could not create the client.")
await audit.audit_create_client(session, request, db_client)
return ClientView.from_client(db_client)
@router.post("/clients/{name}/public-key")
async def update_client_public_key(
request: Request,
name: str,
client_update: ClientUpdate,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView:
"""Change the public key of a client.
This invalidates all secrets.
"""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
client.public_key = client_update.public_key
matching_secrets = await session.scalars(select(ClientSecret).where(ClientSecret.client_id == client.id))
for secret in matching_secrets.all():
LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True
secret.client_id = None
session.add(client)
await session.refresh(client)
await session.commit()
await audit.audit_invalidate_secrets(session, request, client)
return ClientView.from_client(client)
@router.put("/clients/{name}")
async def update_client(
request: Request,
name: str,
client_update: ClientCreate,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView:
"""Change the public key of a client.
This invalidates all secrets.
"""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
client.name = client_update.name
client.description = client_update.description
public_key_updated = False
if client_update.public_key != client.public_key:
public_key_updated = True
client_secrets = await session.scalars(
select(ClientSecret).where(ClientSecret.client_id == client.id)
)
for secret in client_secrets.all():
LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True
secret.client_id = None
session.add(client)
await session.commit()
await session.refresh(client)
await audit.audit_update_client(session, request, client)
if public_key_updated:
await audit.audit_invalidate_secrets(session, request, client)
return ClientView.from_client(client)
return router

View File

@ -0,0 +1,320 @@
"""Client operations."""
import logging
import uuid
from typing import Any, cast
from datetime import datetime, timezone
from fastapi import HTTPException, Request
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import Select
from sshecret_backend import audit
from sshecret_backend.api.common import (
FlexID,
IdType,
RelaxedId,
create_new_client_version,
query_active_clients,
get_client_by_id,
resolve_client_id,
refresh_client,
reload_client_with_relationships,
)
from sshecret_backend.models import Client, ClientAccessPolicy
from .schemas import (
ClientListParams,
ClientCreate,
ClientPolicyView,
ClientView,
ClientQueryResult,
ClientPolicyUpdate,
)
LOG = logging.getLogger(__name__)
def _id(id: RelaxedId) -> uuid.UUID:
"""Ensure that the ID is a uuid."""
if isinstance(id, str):
return uuid.UUID(id)
return id
class ClientOperations:
"""Perform operations on a client."""
def __init__(self, session: AsyncSession, request: Request) -> None:
"""Create operations class."""
self.session: AsyncSession = session
self.request: Request = request
self._client_id: uuid.UUID | None = None
# self.client_name: str | None = None
# if client.type is IdType.ID:
# self._client_id = _id(client.value)
# else:
# self.client_name = str(client.value)
async def get_client_id(
self,
client: FlexID,
version: int | None = None,
) -> uuid.UUID | None:
"""Get client ID."""
if self._client_id:
LOG.debug("Returning previously resolved client ID.")
return self._client_id
if client.type is IdType.ID:
self._client_id = _id(client.value)
return self._client_id
client_name = str(client.value)
client_id = await resolve_client_id(
self.session,
client_name,
version=version,
)
if not client_id:
return None
self._client_id = client_id
LOG.debug("Saving client ID %s", client_id)
return client_id
async def _get_client(
self,
client: FlexID,
version: int | None = None,
) -> Client | None:
"""Get client."""
client_id = await self.get_client_id(client, version=version)
if not client_id:
return None
db_client = await get_client_by_id(self.session, client_id)
return db_client
async def get_client(
self,
client: FlexID,
version: int | None = None,
) -> ClientView:
"""Get public client object."""
db_client = await self._get_client(client, version)
if not db_client:
raise HTTPException(status_code=404, detail="Client not found.")
return ClientView.from_client(db_client)
async def create_client(
self,
create_model: ClientCreate,
) -> ClientView:
"""Create a new client."""
existing_id = await self.get_client_id(FlexID.name(create_model.name))
if existing_id:
raise HTTPException(
status_code=400, detail="Error: A client already exists with this name."
)
client = create_model.to_client()
self.session.add(client)
await self.session.flush()
await self.session.commit()
await refresh_client(self.session, client)
await audit.audit_create_client(self.session, self.request, client)
return ClientView.from_client(client)
async def delete_client(self, client: FlexID) -> None:
"""Delete client."""
db_client = await self._get_client(client)
if not db_client:
return
if db_client.is_deleted:
return
db_client.is_deleted = True
db_client.deleted_at = datetime.now(timezone.utc)
self.session.add(db_client)
await self.session.commit()
await audit.audit_delete_client(self.session, self.request, db_client)
async def update_client(
self,
client: FlexID,
client_update: ClientCreate,
) -> ClientView:
"""Update client details."""
db_client = await self._get_client(client)
if not db_client:
raise HTTPException(status_code=404, detail="Client not found.")
if (
client_update.public_key
and client_update.public_key != db_client.public_key
):
return await new_client_version(
self.session,
db_client.id,
client_update.public_key,
client_update.name,
client_update.description,
)
db_client.name = client_update.name
db_client.description = client_update.description
self.session.add(db_client)
await self.session.commit()
await refresh_client(self.session, db_client)
await audit.audit_update_client(self.session, self.request, db_client)
return ClientView.from_client(db_client)
async def new_client_version(
self,
client: FlexID,
public_key: str,
name: str | None = None,
description: str | None = None,
) -> ClientView:
"""Update a client to a new version."""
current_client = await self._get_client(client)
if not current_client:
raise HTTPException(status_code=404, detail="Client not found.")
new_client = await create_new_client_version(
self.session, current_client, public_key
)
if name:
new_client.name = name
if description:
new_client.description = description
current_client.is_active = False
self.session.add(current_client)
await self.session.commit()
await refresh_client(self.session, new_client)
await audit.audit_new_client_version(
self.session, self.request, current_client, new_client
)
return ClientView.from_client(new_client)
async def get_client_policies(self, client: FlexID) -> ClientPolicyView:
"""Get client policies."""
db_client = await self._get_client(client)
if not db_client:
raise HTTPException(status_code=404, detail="Client not found.")
return ClientPolicyView.from_client(db_client)
async def update_client_policies(
self, client: FlexID, policy_update: ClientPolicyUpdate
) -> ClientPolicyView:
"""Update client policies."""
db_client = await self._get_client(client)
if not db_client:
raise HTTPException(status_code=404, detail="Client not found.")
policies = await self.session.scalars(
select(ClientAccessPolicy).where(
ClientAccessPolicy.client_id == db_client.id
)
)
deleted_policies: list[ClientAccessPolicy] = []
added_policies: list[ClientAccessPolicy] = []
for policy in policies.all():
await self.session.delete(policy)
deleted_policies.append(policy)
LOG.debug("Updating client policies with: %r", policy_update.sources)
for source in policy_update.sources:
LOG.debug("Source %r", source)
policy = ClientAccessPolicy(source=str(source), client_id=db_client.id)
self.session.add(policy)
added_policies.append(policy)
await self.session.flush()
await self.session.commit()
db_client = await reload_client_with_relationships(self.session, db_client)
for policy in deleted_policies:
await audit.audit_remove_policy(
self.session, self.request, db_client, policy
)
for policy in added_policies:
await audit.audit_update_policy(
self.session, self.request, db_client, policy
)
return ClientPolicyView.from_client(db_client)
def filter_client_statement(
statement: Select[Any], params: ClientListParams, ignore_limits: bool = False
) -> Select[Any]:
"""Filter a statement with the provided params."""
if params.id:
statement = statement.where(Client.id == params.id)
if params.name:
statement = statement.where(Client.name == params.name)
elif params.name__like:
statement = statement.where(Client.name.like(params.name__like))
elif params.name__contains:
statement = statement.where(Client.name.contains(params.name__contains))
if ignore_limits:
return statement
return statement.limit(params.limit).offset(params.offset)
async def get_clients(
session: AsyncSession,
filter_query: ClientListParams,
) -> ClientQueryResult:
"""Get Clients."""
count_statement = select(func.count("*")).select_from(Client)
count_statement = cast(
Select[tuple[int]],
filter_client_statement(count_statement, filter_query, True),
)
total_results = (await session.scalars(count_statement)).one()
statement = filter_client_statement(query_active_clients(), filter_query, False)
results = await session.scalars(statement)
remainder = total_results - filter_query.offset - filter_query.limit
if remainder < 0:
remainder = 0
clients = list(results.all())
clients_view = ClientView.from_client_list(clients)
return ClientQueryResult(
clients=clients_view,
total_results=total_results,
remaining_results=remainder,
)
async def new_client_version(
session: AsyncSession,
client_id: uuid.UUID,
public_key: str,
name: str | None = None,
description: str | None = None,
) -> ClientView:
"""Update a client to a new version."""
current_client = await get_client_by_id(session, client_id)
if not current_client:
raise ValueError("Client not found.")
new_client = await create_new_client_version(session, current_client, public_key)
if name:
new_client.name = name
if description:
new_client.description = description
current_client.is_active = False
session.add(current_client)
await session.commit()
return ClientView.from_client(new_client)

View File

@ -0,0 +1,122 @@
"""Client router."""
# pyright: reportUnusedFunction=false
#
import logging
from typing import Annotated
from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sshecret_backend.types import AsyncDBSessionDep
from sshecret_backend.api.clients.schemas import (
ClientCreate,
ClientListParams,
ClientQueryResult,
ClientUpdate,
ClientView,
ClientPolicyUpdate,
ClientPolicyView,
)
from sshecret_backend.api.clients import operations
from sshecret_backend.api.clients.operations import ClientOperations
from sshecret_backend.api.common import FlexID
LOG = logging.getLogger(__name__)
def create_client_router(get_db_session: AsyncDBSessionDep) -> APIRouter:
"""Create client router."""
router = APIRouter()
@router.get("/clients/")
async def route_get_clients(
filter_query: Annotated[ClientListParams, Query()],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientQueryResult:
return await operations.get_clients(session, filter_query)
@router.post("/clients/")
async def create_client(
request: Request,
client: ClientCreate,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView:
"""Create client."""
client_op = ClientOperations(session, request)
return await client_op.create_client(client)
@router.get("/clients/{client_identifier}")
async def fetch_client_by_name(
request: Request,
client_identifier: str,
session: Annotated[AsyncSession, Depends(get_db_session)],
version: Annotated[int | None, Query()] = None,
) -> ClientView:
"""Fetch client by name."""
client_id = FlexID.from_string(client_identifier)
client_op = ClientOperations(session, request)
return await client_op.get_client(client_id, version)
@router.delete("/clients/{client_identifier}")
async def delete_client(
request: Request,
client_identifier: str,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> None:
"""Delete a client."""
client_id = FlexID.from_string(client_identifier)
client_op = ClientOperations(session, request)
await client_op.delete_client(client_id)
@router.post("/clients/{client_identifier}/public-key")
async def update_client_public_key(
request: Request,
client_identifier: str,
client_update: ClientUpdate,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView:
"""Update client public key."""
client_id = FlexID.from_string(client_identifier)
client_op = ClientOperations(session, request)
return await client_op.new_client_version(client_id, client_update.public_key)
@router.put("/clients/{client_identifier}")
async def update_client_by_name(
request: Request,
client_identifier: str,
client_update: ClientCreate,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView:
"""Update client by name."""
client_id = FlexID.from_string(client_identifier)
client_op = ClientOperations(session, request)
return await client_op.update_client(client_id, client_update)
@router.get("/clients/{client_identifier}/policies/")
async def get_client_policies(
request: Request,
client_identifier: str,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientPolicyView:
"""Get client policies."""
client_id = FlexID.from_string(client_identifier)
client_op = ClientOperations(session, request)
return await client_op.get_client_policies(client_id)
@router.put("/clients/{client_identifier}/policies/")
async def update_client_policies(
request: Request,
client_identifier: str,
policy_update: ClientPolicyUpdate,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientPolicyView:
"""Update client policies.
This is also how you delete policies.
"""
client_id = FlexID.from_string(client_identifier)
client_op = ClientOperations(session, request)
return await client_op.update_client_policies(client_id, policy_update)
return router

View File

@ -1,14 +1,21 @@
"""Models for API views.""" """Client related schemas."""
from typing import Annotated, Self
import uuid import uuid
from datetime import datetime from datetime import datetime
from typing import Annotated, Self, Sequence, override
from pydantic import AfterValidator, BaseModel, Field, IPvAnyAddress, IPvAnyNetwork from pydantic import (
AfterValidator,
BaseModel,
Field,
IPvAnyAddress,
IPvAnyNetwork,
model_validator,
)
from sshecret.crypto import public_key_validator from sshecret.crypto import public_key_validator
from . import models from sshecret_backend import models
class ClientView(BaseModel): class ClientView(BaseModel):
@ -19,9 +26,12 @@ class ClientView(BaseModel):
description: str | None = None description: str | None = None
public_key: str public_key: str
policies: list[str] = ["0.0.0.0/0", "::/0"] policies: list[str] = ["0.0.0.0/0", "::/0"]
is_active: bool = True
is_deleted: bool = False
secrets: list[str] = Field(default_factory=list) secrets: list[str] = Field(default_factory=list)
created_at: datetime | None created_at: datetime | None
updated_at: datetime | None = None updated_at: datetime | None = None
deleted_at: datetime | None = None
@classmethod @classmethod
def from_client_list(cls, clients: list[models.Client]) -> list[Self]: def from_client_list(cls, clients: list[models.Client]) -> list[Self]:
@ -39,6 +49,9 @@ class ClientView(BaseModel):
public_key=client.public_key, public_key=client.public_key,
created_at=client.created_at, created_at=client.created_at,
updated_at=client.updated_at or None, updated_at=client.updated_at or None,
deleted_at=client.deleted_at or None,
is_active=client.is_active,
is_deleted=client.is_deleted,
) )
if client.secrets: if client.secrets:
view.secrets = [secret.name for secret in client.secrets] view.secrets = [secret.name for secret in client.secrets]
@ -57,6 +70,32 @@ class ClientQueryResult(BaseModel):
remaining_results: int remaining_results: int
class ClientListParams(BaseModel):
"""Client list parameters."""
limit: int = Field(100, gt=0, le=100)
offset: int = Field(0, ge=0)
id: uuid.UUID | None = None
name: str | None = None
name__like: str | None = None
name__contains: str | None = None
@model_validator(mode="after")
def validate_expressions(self) -> Self:
"""Validate mutually exclusive expression."""
name_filter = False
if self.name__like or self.name__contains:
name_filter = True
if self.name__like and self.name__contains:
raise ValueError("You may only specify one name expression")
if self.name and name_filter:
raise ValueError(
"You must either specify name or one of name__like or name__contains"
)
return self
class ClientCreate(BaseModel): class ClientCreate(BaseModel):
"""Model to create a client.""" """Model to create a client."""
@ -79,48 +118,6 @@ class ClientUpdate(BaseModel):
public_key: Annotated[str, AfterValidator(public_key_validator)] public_key: Annotated[str, AfterValidator(public_key_validator)]
class BodyValue(BaseModel):
"""A generic model with just a value parameter."""
value: str
class ClientSecretPublic(BaseModel):
"""Public model to manage client secrets."""
name: str
secret: str
description: str | None = None
@classmethod
def from_client_secret(cls, client_secret: models.ClientSecret) -> Self:
"""Instantiate from ClientSecret."""
return cls(
name=client_secret.name,
secret=client_secret.secret,
description=client_secret.description,
)
class ClientSecretResponse(ClientSecretPublic):
"""A secret view."""
created_at: datetime | None
updated_at: datetime | None = None
@override
@classmethod
def from_client_secret(cls, client_secret: models.ClientSecret) -> Self:
"""Instantiate from ClientSecret."""
return cls(
name=client_secret.name,
secret=client_secret.secret,
created_at=client_secret.created_at,
updated_at=client_secret.updated_at,
)
class ClientPolicyView(BaseModel): class ClientPolicyView(BaseModel):
"""Update object for client policy.""" """Update object for client policy."""
@ -138,55 +135,3 @@ class ClientPolicyUpdate(BaseModel):
"""Model for updating policies.""" """Model for updating policies."""
sources: list[IPvAnyAddress | IPvAnyNetwork] sources: list[IPvAnyAddress | IPvAnyNetwork]
class ClientSecretList(BaseModel):
"""Model for aggregating identically named secrets."""
name: str
clients: list[str]
class ClientReference(BaseModel):
"""Reference to a client."""
id: str
name: str
class ClientSecretDetailList(BaseModel):
"""A more detailed version of the ClientSecretList."""
name: str
ids: list[str] = Field(default_factory=list)
clients: list[ClientReference] = Field(default_factory=list)
class AuditView(BaseModel):
"""Audit log view."""
id: uuid.UUID | None = None
subsystem: models.SubSystem
message: str
operation: models.Operation
data: dict[str, str] | None = None
client_id: uuid.UUID | None = None
client_name: str | None = None
secret_id: uuid.UUID | None = None
secret_name: str | None = None
origin: str | None = None
timestamp: datetime | None = None
class AuditInfo(BaseModel):
"""Information about audit information."""
entries: int
class AuditListResult(BaseModel):
"""Class to return when listing audit entries."""
results: Sequence[AuditView]
total: int
remaining: int

View File

@ -1,17 +1,66 @@
"""Common helpers.""" """Common helpers."""
import re import re
from typing import Self
import uuid import uuid
from dataclasses import dataclass, field
from enum import Enum
import bcrypt import bcrypt
from pydantic import BaseModel
from sqlalchemy import Select from sqlalchemy import Select
from sqlalchemy.orm import selectinload
from sqlalchemy.future import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from sshecret_backend.models import Client, ClientAccessPolicy
from sshecret_backend.models import Client RE_UUID = re.compile(
"^[0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}$"
)
RelaxedId = uuid.UUID | str
class IdType(Enum):
"""Id type."""
ID = "id"
NAME = "name"
class FlexID(BaseModel):
"""Flexible identifier."""
type: IdType
value: RelaxedId
@classmethod
def id(cls, id: RelaxedId) -> Self:
"""Construct from ID."""
return cls(type=IdType.ID, value=id)
@classmethod
def name(cls, name: str) -> Self:
"""Construct from name."""
return cls(type=IdType.NAME, value=name)
@classmethod
def from_string(cls, value: str) -> Self:
"""Convert from path string."""
if value.startswith("id:"):
return cls.id(value[3:])
elif value.startswith("name:"):
return cls.name(value[5:])
return cls.name(value)
@dataclass
class NewClientVersion:
"""New client version dataclass."""
client: Client
policies: list[ClientAccessPolicy] = field(default_factory=list)
RE_UUID = re.compile("^[0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}$")
def verify_token(token: str, stored_hash: str) -> bool: def verify_token(token: str, stored_hash: str) -> bool:
"""Verify token.""" """Verify token."""
@ -19,12 +68,19 @@ def verify_token(token: str, stored_hash: str) -> bool:
stored_bytes = stored_hash.encode("utf-8") stored_bytes = stored_hash.encode("utf-8")
return bcrypt.checkpw(token_bytes, stored_bytes) return bcrypt.checkpw(token_bytes, stored_bytes)
async def reload_client_with_relationships(session: AsyncSession, client: Client) -> Client:
async def reload_client_with_relationships(
session: AsyncSession, client: Client
) -> Client:
"""Reload a client from the database.""" """Reload a client from the database."""
session.expunge(client) session.expunge(client)
stmt = ( stmt = (
select(Client) select(Client)
.options(selectinload(Client.policies), selectinload(Client.secrets)) .options(
selectinload(Client.policies),
selectinload(Client.secrets),
selectinload(Client.previous_version),
)
.where(Client.id == client.id) .where(Client.id == client.id)
) )
result = await session.execute(stmt) result = await session.execute(stmt)
@ -36,13 +92,29 @@ def client_with_relationships() -> Select[tuple[Client]]:
return select(Client).options( return select(Client).options(
selectinload(Client.secrets), selectinload(Client.secrets),
selectinload(Client.policies), selectinload(Client.policies),
selectinload(Client.previous_version),
) )
async def get_client_by_name(session: AsyncSession, name: str) -> Client | None:
"""Get client by name.""" async def resolve_client_id(
session: AsyncSession,
name: str,
version: int | None = None,
include_deleted: bool = False,
) -> uuid.UUID | None:
"""Get the ID of a client name."""
if include_deleted:
client_filter = client_with_relationships().where(Client.name == name) client_filter = client_with_relationships().where(Client.name == name)
client_results = await session.execute(client_filter) else:
return client_results.scalars().first() client_filter = query_active_clients().where(Client.name == name)
if version:
client_filter = client_filter.where(Client.version == version)
client_result = await session.execute(client_filter)
if client := client_result.scalars().first():
return client.id
return None
async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | None: async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | None:
"""Get client by ID.""" """Get client by ID."""
@ -50,10 +122,76 @@ async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | Non
client_results = await session.execute(client_filter) client_results = await session.execute(client_filter)
return client_results.scalars().first() return client_results.scalars().first()
async def get_client_by_id_or_name(session: AsyncSession, id_or_name: str) -> Client | None:
async def get_client_by_id_or_name(
session: AsyncSession, id_or_name: str
) -> Client | None:
"""Get client either by id or name.""" """Get client either by id or name."""
if RE_UUID.match(id_or_name): if RE_UUID.match(id_or_name):
id = uuid.UUID(id_or_name) id = uuid.UUID(id_or_name)
return await get_client_by_id(session, id) return await get_client_by_id(session, id)
return await get_client_by_name(session, id_or_name) return await get_client_by_name(session, id_or_name)
def query_active_clients() -> Select[tuple[Client]]:
"""Get all active clients."""
client_filter = (
client_with_relationships()
.where(Client.is_active.is_(True))
.where(Client.is_deleted.is_(False))
)
return client_filter
async def get_client_by_name(session: AsyncSession, name: str) -> Client | None:
"""Get client by name.
This will get the latest client version, unless it's deleted.
"""
client_filter = (
client_with_relationships()
.where(Client.is_active.is_(True))
.where(Client.is_deleted.is_not(True))
.where(Client.name == name)
.order_by(Client.version.desc())
)
client_result = await session.execute(client_filter)
return client_result.scalars().first()
async def refresh_client(session: AsyncSession, client: Client) -> None:
"""Refresh the client and load in all relationships."""
await session.refresh(
client,
attribute_names=["secrets", "policies", "previous_version", "updated_at"],
)
async def create_new_client_version(
session: AsyncSession, current_client: Client, new_public_key: str
) -> Client:
new_client = Client(
name=current_client.name,
version=current_client.version + 1,
description=current_client.description,
public_key=new_public_key,
previous_version_id=current_client.id,
is_active=True,
)
# Mark current client as inactive
current_client.is_active = False
# Copy policies
for policy in current_client.policies:
copied_policy = ClientAccessPolicy(
client=new_client,
address=policy.source,
)
session.add(copied_policy)
session.add(new_client)
await session.flush()
await refresh_client(session, new_client)
return new_client

View File

@ -1,86 +0,0 @@
"""Policies sub-api router factory."""
# pyright: reportUnusedFunction=false
import logging
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Annotated
from sshecret_backend.models import ClientAccessPolicy
from sshecret_backend.view_models import (
ClientPolicyView,
ClientPolicyUpdate,
)
from sshecret_backend.types import AsyncDBSessionDep
from sshecret_backend import audit
from .common import get_client_by_id_or_name, reload_client_with_relationships
LOG = logging.getLogger(__name__)
def get_policy_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
"""Construct clients sub-api."""
router = APIRouter()
@router.get("/clients/{name}/policies/")
async def get_client_policies(
name: str, session: Annotated[AsyncSession, Depends(get_db_session)]
) -> ClientPolicyView:
"""Get client policies."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
return ClientPolicyView.from_client(client)
@router.put("/clients/{name}/policies/")
async def update_client_policies(
request: Request,
name: str,
policy_update: ClientPolicyUpdate,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientPolicyView:
"""Update client policies.
This is also how you delete policies.
"""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
# Remove old policies.
policies = await session.scalars(
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
)
deleted_policies: list[ClientAccessPolicy] = []
added_policies: list[ClientAccessPolicy] = []
for policy in policies.all():
await session.delete(policy)
deleted_policies.append(policy)
LOG.debug("Updating client policies with: %r", policy_update.sources)
for source in policy_update.sources:
LOG.debug("Source %r", source)
policy = ClientAccessPolicy(source=str(source), client_id=client.id)
session.add(policy)
added_policies.append(policy)
await session.flush()
await session.commit()
client = await reload_client_with_relationships(session, client)
for policy in deleted_policies:
await audit.audit_remove_policy(session, request, client, policy)
for policy in added_policies:
await audit.audit_update_policy(session, request, client, policy)
return ClientPolicyView.from_client(client)
return router

View File

@ -0,0 +1,9 @@
"""Common API schemas."""
from pydantic import BaseModel
class BodyValue(BaseModel):
"""A generic model with just a value parameter."""
value: str

View File

@ -1,252 +0,0 @@
"""Secrets sub-api factory."""
# pyright: reportUnusedFunction=false
import logging
from collections import defaultdict
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy import select
from typing import Annotated
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sshecret_backend.models import Client, ClientSecret
from sshecret_backend.view_models import (
ClientReference,
ClientSecretDetailList,
ClientSecretList,
ClientSecretPublic,
BodyValue,
ClientSecretResponse,
)
from sshecret_backend import audit
from sshecret_backend.types import AsyncDBSessionDep
from .common import get_client_by_id_or_name
LOG = logging.getLogger(__name__)
async def lookup_client_secret(
session: AsyncSession, client: Client, name: str
) -> ClientSecret | None:
"""Look up a secret for a client."""
statement = (
select(ClientSecret)
.where(ClientSecret.client_id == client.id)
.where(ClientSecret.name == name)
)
results = await session.scalars(statement)
return results.first()
def get_secrets_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
"""Construct clients sub-api."""
router = APIRouter()
@router.post("/clients/{name}/secrets/")
async def add_secret_to_client(
request: Request,
name: str,
client_secret: ClientSecretPublic,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> None:
"""Add secret to a client."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
existing_secret = await lookup_client_secret(
session, client, client_secret.name
)
if existing_secret:
raise HTTPException(
status_code=400,
detail="Cannot add a secret. A different secret with the same name already exists.",
)
db_secret = ClientSecret(
name=client_secret.name, client_id=client.id, secret=client_secret.secret
)
session.add(db_secret)
await session.commit()
await session.refresh(db_secret)
await audit.audit_create_secret(session, request, client, db_secret)
@router.put("/clients/{name}/secrets/{secret_name}")
async def update_client_secret(
request: Request,
name: str,
secret_name: str,
secret_data: BodyValue,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientSecretResponse:
"""Update a client secret.
This can also be used for destructive creates.
"""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
existing_secret = await lookup_client_secret(session, client, secret_name)
if existing_secret:
existing_secret.secret = secret_data.value
session.add(existing_secret)
await session.commit()
await session.refresh(existing_secret)
await audit.audit_update_secret(session, request, client, existing_secret)
return ClientSecretResponse.from_client_secret(existing_secret)
db_secret = ClientSecret(
name=secret_name,
client_id=client.id,
secret=secret_data.value,
)
session.add(db_secret)
await session.commit()
await session.refresh(db_secret)
await audit.audit_create_secret(session, request, client, db_secret)
return ClientSecretResponse.from_client_secret(db_secret)
@router.get("/clients/{name}/secrets/{secret_name}")
async def request_client_secret(
request: Request,
name: str,
secret_name: str,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientSecretResponse:
"""Get a client secret."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
secret = await lookup_client_secret(session, client, secret_name)
if not secret:
raise HTTPException(
status_code=404, detail="Cannot find a secret with the given name."
)
response_model = ClientSecretResponse.from_client_secret(secret)
await audit.audit_access_secret(session, request, client, secret)
return response_model
@router.delete("/clients/{name}/secrets/{secret_name}")
async def delete_client_secret(
request: Request,
name: str,
secret_name: str,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> None:
"""Delete a secret."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
secret = await lookup_client_secret(session, client, secret_name)
if not secret:
raise HTTPException(
status_code=404, detail="Cannot find a secret with the given name."
)
await session.delete(secret)
await session.commit()
await audit.audit_delete_secret(session, request, client, secret)
@router.get("/secrets/")
async def get_secret_map(
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> list[ClientSecretList]:
"""Get a list of all secrets and which clients have them."""
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
client_secrets = await session.scalars(
select(ClientSecret).options(selectinload(ClientSecret.client))
)
for client_secret in client_secrets.all():
if not client_secret.client:
if client_secret.name not in client_secret_map:
client_secret_map[client_secret.name] = []
continue
client_secret_map[client_secret.name].append(client_secret.client.name)
# audit.audit_client_secret_list(session, request)
return [
ClientSecretList(name=secret_name, clients=clients)
for secret_name, clients in client_secret_map.items()
]
@router.get("/secrets/detailed/")
async def get_detailed_secret_map(
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> list[ClientSecretDetailList]:
"""Get a list of all secrets and which clients have them."""
client_secrets: dict[str, ClientSecretDetailList] = {}
all_client_secrets = await session.execute(
select(ClientSecret).options(selectinload(ClientSecret.client))
)
for client_secret in all_client_secrets.scalars().all():
if client_secret.name not in client_secrets:
client_secrets[client_secret.name] = ClientSecretDetailList(
name=client_secret.name
)
client_secrets[client_secret.name].ids.append(str(client_secret.id))
if not client_secret.client:
continue
client_secrets[client_secret.name].clients.append(
ClientReference(
id=str(client_secret.client.id), name=client_secret.client.name
)
)
# `audit.audit_client_secret_list(session, request)
return list(client_secrets.values())
@router.get("/secrets/{name}")
async def get_secret_clients(
name: str,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientSecretList:
"""Get a list of which clients has a named secret."""
clients: list[str] = []
client_secrets = await session.scalars(
select(ClientSecret)
.options(selectinload(ClientSecret.client))
.where(ClientSecret.name == name)
)
for client_secret in client_secrets.all():
if not client_secret.client:
continue
clients.append(client_secret.client.name)
return ClientSecretList(name=name, clients=clients)
@router.get("/secrets/{name}/detailed")
async def get_secret_clients_detailed(
name: str,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientSecretDetailList:
"""Get a list of which clients has a named secret."""
detail_list = ClientSecretDetailList(name=name)
client_secrets = await session.scalars(
select(ClientSecret).where(ClientSecret.name == name)
)
for client_secret in client_secrets.all():
if not client_secret.client:
continue
detail_list.ids.append(str(client_secret.id))
detail_list.clients.append(
ClientReference(
id=str(client_secret.client.id), name=client_secret.client.name
)
)
return detail_list
return router

View File

@ -0,0 +1,328 @@
"""ClientSecret operations."""
import logging
import uuid
from datetime import datetime, timezone
from fastapi import HTTPException, Request
from sqlalchemy import Null, distinct, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sshecret_backend import audit
from sshecret_backend.api.common import (
FlexID,
IdType,
get_client_by_id,
resolve_client_id,
)
from sshecret_backend.api.secrets.schemas import (
ClientReference,
ClientSecretDetailList,
ClientSecretResponse,
ClientSecretStats,
DistinctClientSecretStats,
)
from sshecret_backend.models import Client, ClientSecret
LOG = logging.getLogger(__name__)
RelaxedId = uuid.UUID | str
def _id(id: RelaxedId) -> uuid.UUID:
"""Ensure that the ID is a uuid."""
if isinstance(id, str):
return uuid.UUID(id)
return id
class ClientSecretOperations:
"""Perform operations on a client's ClientSecrets"""
def __init__(
self,
session: AsyncSession,
request: Request,
client: FlexID,
include_deleted: bool = False,
) -> None:
"""Create operations class."""
self._client_id: uuid.UUID | None = None
self.client_name: str | None = None
self.client: Client | None = None
if client.type is IdType.ID:
self._client_id = _id(client.value)
else:
self.client_name = str(client.value)
self.include_deleted: bool = include_deleted
self.session: AsyncSession = session
self.request: Request = request
async def get_client_id(self) -> uuid.UUID | None:
"""Get client ID."""
if self._client_id:
LOG.debug("Returning previously resolved client ID.")
return self._client_id
if not self.client_name:
raise RuntimeError("Error: No client information registered.")
client_id = await resolve_client_id(
self.session, self.client_name, include_deleted=self.include_deleted
)
self._client_id = client_id
LOG.debug("Saving client ID %s", client_id)
return client_id
async def get_client(self) -> Client | None:
"""Get client."""
if self.client:
return self.client
client_id = await self.get_client_id()
if not client_id:
return None
client = await get_client_by_id(self.session, client_id)
self.client = client
return client
async def _get_client_secret(
self, secret_identifier: FlexID
) -> ClientSecret | None:
"""Get client secret private method."""
client = await self.get_client()
if not client:
LOG.debug("Client lookup failed.")
return None
if secret_identifier.type is IdType.ID:
match_id = _id(secret_identifier.value)
LOG.debug("Searching for secrets matching ID %r", match_id)
matches = [secret for secret in client.secrets if secret.id == match_id]
else:
LOG.debug("Searching for secrets matching name")
matches = [
secret
for secret in client.secrets
if secret.name == str(secret_identifier.value)
]
for secret_match in matches:
if secret_match.deleted:
# TODO: Add override for deleted.
LOG.debug("Found deleted secret")
continue
await audit.audit_access_secret(
self.session, self.request, client, secret_match
)
LOG.debug("Found matching secret")
return secret_match
LOG.debug("No secrets matched.")
return None
async def get_client_secret(
self, secret_identifier: FlexID
) -> ClientSecretResponse:
"""Get client secret public model."""
client_secret = await self._get_client_secret(secret_identifier)
if not client_secret:
raise HTTPException(status_code=404, detail="Unknown client or secret")
return ClientSecretResponse.from_client_secret(client_secret)
async def create_client_secret(
self, name: str, value: str, description: str | None = None
) -> ClientSecretResponse:
"""Create client Secret."""
client = await self.get_client()
if not client:
raise HTTPException(status_code=404, detail="Client not found.")
existing_secret = await self._get_client_secret(FlexID.name(name))
if existing_secret:
raise HTTPException(
status_code=400,
detail="Cannot add a secret. A different secret with the same name already exists.",
)
secret = ClientSecret(
name=name,
description=description,
client=client,
secret=value,
)
self.session.add(secret)
await self.session.commit()
await self.session.refresh(secret, attribute_names=["client", "updated_at"])
await audit.audit_create_secret(self.session, self.request, client, secret)
return ClientSecretResponse.from_client_secret(secret)
async def update_client_secret_value(
self, secret_identifier: FlexID, value: str
) -> ClientSecretResponse:
"""Update client secret value."""
client_secret = await self._get_client_secret(secret_identifier)
if not client_secret:
# We can use this method to create a secret too.
if secret_identifier.type is IdType.NAME:
return await self.create_client_secret(
str(secret_identifier.value), value
)
raise HTTPException(status_code=404, detail="Unknown client or secret")
client_secret.secret = value
self.session.add(client_secret)
await self.session.commit()
await self.session.refresh(client_secret, ["updated_at"])
await audit.audit_update_secret(
self.session, self.request, client_secret.client, client_secret
)
return ClientSecretResponse.from_client_secret(client_secret)
async def delete_client_secret(self, secret_identifier: FlexID) -> None:
"""Delete a client secret."""
client_secret = await self._get_client_secret(secret_identifier)
if not client_secret:
return
client_secret.deleted = True
client_secret.deleted_at = datetime.now(timezone.utc)
self.session.add(client_secret)
await self.session.commit()
await self.session.refresh(client_secret, ["deleted", "deleted_at"])
await audit.audit_delete_secret(
self.session, self.request, client_secret.client, client_secret
)
async def resolve_client_secret_mapping(
session: AsyncSession,
) -> list[ClientSecretDetailList]:
"""Resolve mapping of clients to secrets."""
result = await session.execute(
select(ClientSecret)
.join(ClientSecret.client)
.options(selectinload(ClientSecret.client))
.where(Client.is_active.is_not(False))
.where(ClientSecret.deleted.is_not(True))
)
client_secrets: dict[str, ClientSecretDetailList] = {}
for secret in result.scalars().all():
if secret.name not in client_secrets:
client_secrets[secret.name] = ClientSecretDetailList(name=secret.name)
client_secrets[secret.name].ids.append(str(secret.id))
if not secret.client:
continue
client_secrets[secret.name].clients.append(
ClientReference(id=str(secret.client.id), name=secret.client.name)
)
return list(client_secrets.values())
async def resolve_client_secret_clients(
session: AsyncSession, name: str, include_deleted: bool = False
) -> ClientSecretDetailList | None:
"""Resolve client association to a secret."""
statement = (
select(ClientSecret)
.options(selectinload(ClientSecret.client))
.where(ClientSecret.name == name)
)
if not include_deleted:
statement = statement.where(ClientSecret.deleted.is_not(True))
results = await session.execute(statement)
clients: ClientSecretDetailList | None = None
for client_secret in results.scalars().all():
if not clients:
# Ensure we don't create the object before we have at least one client.
clients = ClientSecretDetailList(name=name)
clients.ids.append(str(client_secret.id))
if client_secret.client:
clients.clients.append(
ClientReference(
id=str(client_secret.client.id), name=client_secret.client.name
)
)
return clients
async def _get_secrets_total(session: AsyncSession) -> int:
"""Get total amount of secrets excluding deleted."""
statement = (
select(func.count())
.select_from(ClientSecret)
.where(ClientSecret.deleted.is_not(True))
)
result = await session.execute(statement)
return result.scalar_one()
async def _get_distinct_secret_count(session: AsyncSession) -> int:
"""Get the amount of distinct secrets, excluding deleted."""
statement = (
select(func.count())
.select_from(ClientSecret)
.where(ClientSecret.deleted.is_not(True))
.distinct()
)
result = await session.execute(statement)
return result.scalar_one()
async def _get_secret_client_stats(
session: AsyncSession,
) -> list[DistinctClientSecretStats]:
"""Get stats for each named secret."""
statement = (
(
select(
ClientSecret.name, func.count(distinct(Client.id)).label("client_count")
)
)
.join(Client, ClientSecret.client_id)
.where(ClientSecret.deleted.is_not(True), Client.is_deleted.is_not(True))
.group_by(ClientSecret.name)
)
results = await session.execute(statement)
stats: list[DistinctClientSecretStats] = []
rows = results.all()
for row in rows:
secret_name, client_count = row.tuple()
stats.append(DistinctClientSecretStats(name=secret_name, clients=client_count))
return stats
async def _get_secrets_without_clients(
session: AsyncSession,
) -> int:
"""Get secrets without clients."""
statement = (
select(func.count(distinct(ClientSecret.name)))
.where(ClientSecret.deleted.is_not(True))
.where(ClientSecret.client_id.is_(Null))
)
result = await session.execute(statement)
return result.scalar_one()
async def get_client_secret_stats(session: AsyncSession) -> ClientSecretStats:
"""Get stats for the client secrets.
TODO: Implement a route with this.
"""
distinct_secrets = await _get_distinct_secret_count(session)
total_secrets = await _get_secrets_total(session)
secrets_without_clients = await _get_secrets_without_clients(session)
client_stats = await _get_secret_client_stats(session)
return ClientSecretStats(
distinct_secrets=distinct_secrets,
total_secrets=total_secrets,
secrets_without_clients=secrets_without_clients,
client_stats=client_stats,
)

View File

@ -0,0 +1,107 @@
"""Client Secret Router."""
# pyright: reportUnusedFunction=false
import logging
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sshecret_backend.api.common import FlexID
from sshecret_backend.types import AsyncDBSessionDep
from sshecret_backend.api.secrets.operations import (
ClientSecretOperations,
resolve_client_secret_clients,
resolve_client_secret_mapping,
)
from sshecret_backend.api.schemas import BodyValue
from sshecret_backend.api.secrets.schemas import (
ClientSecretDetailList,
ClientSecretPublic,
ClientSecretResponse,
)
LOG = logging.getLogger(__name__)
def create_client_secrets_router(get_db_session: AsyncDBSessionDep) -> APIRouter:
"""Create client secret router."""
router = APIRouter()
@router.post("/clients/{client_identifier}/secrets/")
async def add_secret_to_client(
request: Request,
client_identifier: str,
client_secret: ClientSecretPublic,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> None:
"""Add secret to a client."""
client = FlexID.from_string(client_identifier)
client_op = ClientSecretOperations(session, request, client)
await client_op.create_client_secret(
client_secret.name, client_secret.secret, client_secret.description
)
@router.put("/clients/{client_identifier}/secrets/{secret_identifier}")
async def update_client_secret(
request: Request,
client_identifier: str,
secret_identifier: str,
secret_data: BodyValue,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientSecretResponse:
"""Update client secret."""
client = FlexID.from_string(client_identifier)
secret = FlexID.from_string(secret_identifier)
client_op = ClientSecretOperations(session, request, client)
return await client_op.update_client_secret_value(secret, secret_data.value)
@router.get("/clients/{client_identifier}/secrets/{secret_identifier}")
async def request_client_secret_named(
request: Request,
client_identifier: str,
secret_identifier: str,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientSecretResponse:
"""Get a named secret from a named client."""
client = FlexID.from_string(client_identifier)
secret = FlexID.from_string(secret_identifier)
LOG.debug("Resolved client FlexID: %r", client)
LOG.debug("Resolved secret FlexID: %r", secret)
client_op = ClientSecretOperations(session, request, client)
return await client_op.get_client_secret(secret)
# TODO: delete_client_secret
@router.delete("/clients/{client_identifier}/secrets/{secret_identifier}")
async def delete_client_secret(
request: Request,
client_identifier: str,
secret_identifier: str,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> None:
"""Delete client secret."""
client = FlexID.from_string(client_identifier)
secret = FlexID.from_string(secret_identifier)
client_op = ClientSecretOperations(session, request, client)
await client_op.delete_client_secret(secret)
@router.get("/secrets/")
async def get_secret_map(
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> list[ClientSecretDetailList]:
"""Get a list of secrets and which clients have them."""
return await resolve_client_secret_mapping(session)
@router.get("/secrets/{name}")
async def get_secret_clients(
name: str,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientSecretDetailList:
"""Get a list of which clients has a named secret."""
result = await resolve_client_secret_clients(session, name)
if not result:
raise HTTPException(status_code=404, detail="Secret not found.")
return result
return router

View File

@ -0,0 +1,88 @@
"""Schemas for the secrets router."""
from datetime import datetime
from typing import Self, override
import uuid
from pydantic import BaseModel, Field
from sshecret_backend import models
class ClientSecretPublic(BaseModel):
"""Public model to manage client secrets."""
name: str
secret: str
description: str | None = None
@classmethod
def from_client_secret(cls, client_secret: models.ClientSecret) -> Self:
"""Instantiate from ClientSecret."""
return cls(
name=client_secret.name,
secret=client_secret.secret,
description=client_secret.description,
)
class ClientSecretResponse(ClientSecretPublic):
"""A secret view."""
id: uuid.UUID
created_at: datetime | None
updated_at: datetime | None = None
@override
@classmethod
def from_client_secret(cls, client_secret: models.ClientSecret) -> Self:
"""Instantiate from ClientSecret."""
return cls(
id=client_secret.id,
name=client_secret.name,
secret=client_secret.secret,
created_at=client_secret.created_at,
updated_at=client_secret.updated_at,
)
class ClientSecretList(BaseModel):
"""Model for aggregating identically named secrets."""
name: str
clients: list[str]
class ClientReference(BaseModel):
"""Reference to a client."""
id: str
name: str
class ClientSecretDetailList(BaseModel):
"""A more detailed version of the ClientSecretList."""
name: str
ids: list[str] = Field(default_factory=list)
clients: list[ClientReference] = Field(default_factory=list)
class DistinctClientSecretStats(BaseModel):
"""Stats for distinct client secrets."""
name: str
clients: int = 0
class ClientSecretStats(BaseModel):
"""Stats row for the clientSecret model.
Useful for pagination and statistics.
"""
distinct_secrets: int
total_secrets: int
secrets_without_clients: int
client_stats: list[DistinctClientSecretStats] = Field(default_factory=list)

View File

@ -12,16 +12,13 @@ from fastapi import (
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from sqlalchemy import Engine
from sqlalchemy.ext.asyncio import AsyncEngine
from .models import init_db_async from .models import init_db_async
from .backend_api import get_backend_api from .backend_api import get_backend_api
from .db import setup_database, get_async_engine from .db import get_async_engine
from .settings import BackendSettings from .settings import BackendSettings
from .types import AsyncDBSessionDep
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)

View File

@ -5,7 +5,14 @@ from fastapi import Request
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy, Operation, SubSystem from .models import (
AuditLog,
Client,
ClientSecret,
ClientAccessPolicy,
Operation,
SubSystem,
)
def _get_origin(request: Request) -> str | None: def _get_origin(request: Request) -> str | None:
@ -128,6 +135,27 @@ async def audit_update_client(
await _write_audit_log(session, request, entry, commit) await _write_audit_log(session, request, entry, commit)
async def audit_new_client_version(
session: AsyncSession,
request: Request,
old_client: Client,
new_client: Client,
commit: bool = True,
) -> None:
"""Audit an update secret event."""
entry = AuditLog(
operation=Operation.UPDATE,
client_id=old_client.id,
client_name=old_client.name,
message="Client data updated",
data={
"new_client_id": str(new_client.id),
"new_client_version": new_client.version,
},
)
await _write_audit_log(session, request, entry, commit)
async def audit_update_secret( async def audit_update_secret(
session: AsyncSession, session: AsyncSession,
request: Request, request: Request,
@ -224,6 +252,7 @@ async def audit_access_secret(
) )
await _write_audit_log(session, request, entry, commit) await _write_audit_log(session, request, entry, commit)
async def audit_client_secret_list( async def audit_client_secret_list(
session: AsyncSession, request: Request, commit: bool = True session: AsyncSession, request: Request, commit: bool = True
) -> None: ) -> None:
@ -233,4 +262,3 @@ async def audit_client_secret_list(
message="All secret names and their clients was viewed", message="All secret names and their clients was viewed",
) )
await _write_audit_log(session, request, entry, commit) await _write_audit_log(session, request, entry, commit)

View File

@ -2,6 +2,7 @@
import bcrypt import bcrypt
def hash_token(token: str) -> str: def hash_token(token: str) -> str:
"""Hash a token.""" """Hash a token."""
pwbytes = token.encode("utf-8") pwbytes = token.encode("utf-8")

View File

@ -9,7 +9,9 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sshecret_backend.db import DatabaseSessionManager from sshecret_backend.db import DatabaseSessionManager
from sshecret_backend.settings import BackendSettings from sshecret_backend.settings import BackendSettings
from .api import get_audit_api, get_clients_api, get_policy_api, get_secrets_api from .api.audit.router import create_audit_router
from .api.secrets.router import create_client_secrets_router
from .api.clients.router import create_client_router
from .auth import verify_token from .auth import verify_token
from .models import ( from .models import (
APIClient, APIClient,
@ -59,9 +61,8 @@ def get_backend_api(
dependencies=[Depends(validate_token)], dependencies=[Depends(validate_token)],
) )
backend_api.include_router(get_audit_api(get_db_session)) backend_api.include_router(create_audit_router(get_db_session))
backend_api.include_router(get_clients_api(get_db_session)) backend_api.include_router(create_client_router(get_db_session))
backend_api.include_router(get_policy_api(get_db_session)) backend_api.include_router(create_client_secrets_router(get_db_session))
backend_api.include_router(get_secrets_api(get_db_session))
return backend_api return backend_api

View File

@ -6,9 +6,15 @@ import sqlite3
from collections.abc import AsyncIterator, Generator, Callable from collections.abc import AsyncIterator, Generator, Callable
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Any, Literal from typing import Literal
from sqlalchemy import create_engine, Engine, event, select from sqlalchemy import create_engine, Engine, event, select
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession, async_sessionmaker, create_async_engine, AsyncEngine from sqlalchemy.ext.asyncio import (
AsyncConnection,
AsyncSession,
async_sessionmaker,
create_async_engine,
AsyncEngine,
)
from sqlalchemy.orm import sessionmaker, Session from sqlalchemy.orm import sessionmaker, Session
@ -21,11 +27,14 @@ from .models import APIClient, SubSystem
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
class DatabaseSessionManager: class DatabaseSessionManager:
def __init__(self, host: URL, **engine_kwargs: str): def __init__(self, host: URL, **engine_kwargs: str):
self._engine: AsyncEngine | None = get_async_engine(host) self._engine: AsyncEngine | None = get_async_engine(host)
self._sessionmaker: async_sessionmaker[AsyncSession] | None = async_sessionmaker(autocommit=False, bind=self._engine, expire_on_commit=False) self._sessionmaker: async_sessionmaker[AsyncSession] | None = (
async_sessionmaker(
autocommit=False, bind=self._engine, expire_on_commit=False
)
)
async def close(self): async def close(self):
if self._engine is None: if self._engine is None:
@ -81,7 +90,6 @@ def setup_database(
return engine, get_db_session return engine, get_db_session
def get_engine(url: URL, echo: bool = False) -> Engine: def get_engine(url: URL, echo: bool = False) -> Engine:
"""Initialize the engine.""" """Initialize the engine."""
engine = create_engine(url, echo=echo) engine = create_engine(url, echo=echo)
@ -102,6 +110,7 @@ def get_async_engine(url: URL, echo: bool = False, **engine_kwargs: str) -> Asyn
"""Get an async engine.""" """Get an async engine."""
engine = create_async_engine(url, echo=echo, **engine_kwargs) engine = create_async_engine(url, echo=echo, **engine_kwargs)
if url.drivername.startswith("sqlite+"): if url.drivername.startswith("sqlite+"):
@event.listens_for(engine.sync_engine, "connect") @event.listens_for(engine.sync_engine, "connect")
def set_sqlite_pragma( def set_sqlite_pragma(
dbapi_connection: sqlite3.Connection, _connection_record: object dbapi_connection: sqlite3.Connection, _connection_record: object
@ -113,11 +122,14 @@ def get_async_engine(url: URL, echo: bool = False, **engine_kwargs: str) -> Asyn
return engine return engine
def create_api_token_with_value(
def create_api_token_with_value(session: Session, token: str, subsystem: Literal["admin", "sshd"]) -> None: session: Session, token: str, subsystem: Literal["admin", "sshd"]
) -> None:
"""Create API token with a given value.""" """Create API token with a given value."""
existing = session.scalars(select(APIClient).where(APIClient.subsystem == SubSystem(subsystem))).first() existing = session.scalars(
select(APIClient).where(APIClient.subsystem == SubSystem(subsystem))
).first()
if existing: if existing:
if verify_token(token, existing.token): if verify_token(token, existing.token):
LOG.info("Token is up to date.") LOG.info("Token is up to date.")
@ -135,12 +147,19 @@ def create_api_token_with_value(session: Session, token: str, subsystem: Literal
session.add(api_token) session.add(api_token)
session.commit() session.commit()
def create_api_token(session: Session, subsystem: Literal["admin", "sshd", "test"], recreate: bool = False) -> str:
def create_api_token(
session: Session,
subsystem: Literal["admin", "sshd", "test"],
recreate: bool = False,
) -> str:
"""Create API token.""" """Create API token."""
subsys = SubSystem(subsystem) subsys = SubSystem(subsystem)
token = secrets.token_urlsafe(32) token = secrets.token_urlsafe(32)
hashed = hash_token(token) hashed = hash_token(token)
if existing := session.scalars(select(APIClient).where(APIClient.subsystem == subsys)).first(): if existing := session.scalars(
select(APIClient).where(APIClient.subsystem == subsys)
).first():
if not recreate: if not recreate:
raise RuntimeError("Error: A token already exist for this subsystem.") raise RuntimeError("Error: A token already exist for this subsystem.")
existing.token = hashed existing.token = hashed

View File

@ -51,13 +51,22 @@ class Client(Base):
"""Clients.""" """Clients."""
__tablename__: str = "client" __tablename__: str = "client"
__table_args__: tuple[sa.UniqueConstraint, ...] = (
sa.UniqueConstraint("name", "version", name="uq_client_name_version"),
)
id: Mapped[uuid.UUID] = mapped_column( id: Mapped[uuid.UUID] = mapped_column(
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4 sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
) )
name: Mapped[str] = mapped_column(sa.String, unique=True) version: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=1)
name: Mapped[str] = mapped_column(sa.String, nullable=False)
description: Mapped[str | None] = mapped_column(sa.String, nullable=True) description: Mapped[str | None] = mapped_column(sa.String, nullable=True)
public_key: Mapped[str] = mapped_column(sa.Text) public_key: Mapped[str] = mapped_column(sa.Text, nullable=False)
is_active: Mapped[bool] = mapped_column(sa.Boolean, default=True)
is_deleted: Mapped[bool] = mapped_column(sa.Boolean, default=False)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
@ -69,10 +78,23 @@ class Client(Base):
onupdate=sa.func.now(), onupdate=sa.func.now(),
) )
deleted_at: Mapped[datetime | None] = mapped_column(
sa.DateTime(timezone=True), nullable=True
)
secrets: Mapped[list["ClientSecret"]] = relationship( secrets: Mapped[list["ClientSecret"]] = relationship(
back_populates="client", passive_deletes=True back_populates="client", passive_deletes=True
) )
previous_version_id: Mapped[uuid.UUID | None] = mapped_column(
sa.Uuid(as_uuid=True),
sa.ForeignKey("client.id", ondelete="SET NULL"),
nullable=True,
)
previous_version: Mapped["Client | None"] = relationship(
"Client", remote_side=[id], backref="versions"
)
policies: Mapped[list["ClientAccessPolicy"]] = relationship(back_populates="client") policies: Mapped[list["ClientAccessPolicy"]] = relationship(back_populates="client")
@ -117,7 +139,7 @@ class ClientSecret(Base):
sa.Uuid(as_uuid=True), sa.ForeignKey("client.id", ondelete="CASCADE") sa.Uuid(as_uuid=True), sa.ForeignKey("client.id", ondelete="CASCADE")
) )
client: Mapped[Client] = relationship(back_populates="secrets") client: Mapped[Client] = relationship(back_populates="secrets")
invalidated: Mapped[bool] = mapped_column(default=False) deleted: Mapped[bool] = mapped_column(default=False)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
@ -129,6 +151,10 @@ class ClientSecret(Base):
onupdate=sa.func.now(), onupdate=sa.func.now(),
) )
deleted_at: Mapped[datetime | None] = mapped_column(
sa.DateTime(timezone=True), nullable=True
)
class APIClient(Base): class APIClient(Base):
"""A client on the API. """A client on the API.

View File

@ -2,7 +2,7 @@
from collections.abc import AsyncGenerator, Callable, Generator from collections.abc import AsyncGenerator, Callable, Generator
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session from sqlalchemy.orm import Session

View File

@ -11,7 +11,7 @@ from fastapi.testclient import TestClient
from sshecret.crypto import generate_private_key, generate_public_key_string from sshecret.crypto import generate_private_key, generate_public_key_string
from sshecret_backend.app import create_backend_app from sshecret_backend.app import create_backend_app
from sshecret_backend.testing import create_test_token from sshecret_backend.testing import create_test_token
from sshecret_backend.view_models import AuditView from sshecret_backend.api.audit.schemas import AuditView
from sshecret_backend.settings import BackendSettings from sshecret_backend.settings import BackendSettings
@ -20,7 +20,7 @@ handler = logging.StreamHandler()
formatter = logging.Formatter("'%(asctime)s - %(levelname)s - %(message)s'") formatter = logging.Formatter("'%(asctime)s - %(levelname)s - %(message)s'")
handler.setFormatter(formatter) handler.setFormatter(formatter)
LOG.addHandler(handler) LOG.addHandler(handler)
LOG.setLevel(logging.DEBUG) #LOG.setLevel(logging.DEBUG)
def make_test_key() -> str: def make_test_key() -> str:
@ -167,8 +167,8 @@ def test_put_add_secret(test_client: TestClient) -> None:
response_model = response.json() response_model = response.json()
del response_model["created_at"] del response_model["created_at"]
del response_model["updated_at"] del response_model["updated_at"]
assert response_model == data for key, value in data.items():
assert response_model.get(key) == value
def test_put_update_secret(test_client: TestClient) -> None: def test_put_update_secret(test_client: TestClient) -> None:
"""Test updating a client secret.""" """Test updating a client secret."""
@ -407,7 +407,7 @@ def test_get_secret_list(test_client: TestClient) -> None:
assert len(entry["clients"]) == 4 assert len(entry["clients"]) == 4
else: else:
assert len(entry["clients"]) == 1 assert len(entry["clients"]) == 1
assert entry["clients"][0] == entry["name"] assert entry["clients"][0]["name"] == entry["name"]
def test_get_secret_clients(test_client: TestClient) -> None: def test_get_secret_clients(test_client: TestClient) -> None:
@ -428,8 +428,9 @@ def test_get_secret_clients(test_client: TestClient) -> None:
data = resp.json() data = resp.json()
assert data["name"] == "commonsecret" assert data["name"] == "commonsecret"
assert "client-0" in data["clients"] client_names = [client["name"] for client in data["clients"]]
assert "client-1" not in data["clients"] assert "client-0" in client_names
assert "client-1" not in client_names
assert len(data["clients"]) == 2 assert len(data["clients"]) == 2
@ -473,7 +474,7 @@ def test_operations_with_id(test_client: TestClient) -> None:
data = resp.json() data = resp.json()
client = data["clients"][0] client = data["clients"][0]
client_id = client["id"] client_id = client["id"]
resp = test_client.get(f"/api/v1/clients/{client_id}") resp = test_client.get(f"/api/v1/clients/id:{client_id}")
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
assert data["name"] == "test" assert data["name"] == "test"
@ -559,3 +560,47 @@ def test_filter_audit_log(test_client: TestClient) -> None:
assert data["results"][0]["operation"] == "login" assert data["results"][0]["operation"] == "login"
assert data["results"][0]["message"] == "message1" assert data["results"][0]["message"] == "message1"
def test_secret_flexid(test_client: TestClient) -> None:
"""Test flexible IDs in the secret API."""
client_name = "test"
create_response = create_client(
test_client,
client_name,
)
assert create_response.status_code == 200
assert "id" in create_response.json()
client_id = create_response.json()["id"]
# Create a secret using the client name
secrets: dict[str, str] = {}
resp = test_client.put(
"/api/v1/clients/test/secrets/clientnamesecret",
json={"value": "secret"},
)
assert resp.status_code == 200
secret_data = resp.json()
assert "id" in secret_data
secrets["clientnamesecret"] = secret_data["id"]
# Create one using the client ID
resp = test_client.put(
f"/api/v1/clients/id:{client_id}/secrets/clientidsecret",
json={"value": "secret"},
)
assert resp.status_code == 200
secret_data = resp.json()
assert "id" in secret_data
secrets["clientidsecret"] = secret_data["id"]
# Let's try fetching the various permutations
for client_identifier in ("test", f"id:{client_id}"):
for secret_name, secret_id in secrets.items():
for secret_identifier in (secret_name, f"id:{secret_id}"):
resp = test_client.get(f"/api/v1/clients/{client_identifier}/secrets/{secret_identifier}")
assert resp.status_code == 200
resp_body = resp.json()
assert "id" in resp_body
assert resp_body["id"] == secret_id