Refactor backend views

This commit is contained in:
2025-06-08 17:40:50 +02:00
parent aa6b55a911
commit 7ad41f43d8
25 changed files with 1382 additions and 452 deletions

View File

@ -1,7 +1 @@
"""API factory modules."""
from .audit import get_audit_api
from .policies import get_policy_api
from .secrets import get_secrets_api
__all__ = ["get_audit_api", "get_policy_api", "get_secrets_api"]

View File

@ -15,7 +15,7 @@ from typing import Annotated
from sshecret_backend.models import AuditLog, Operation, SubSystem
from sshecret_backend.types import AsyncDBSessionDep
from sshecret_backend.view_models import AuditInfo, AuditView, AuditListResult
from .schemas import AuditInfo, AuditView, AuditListResult
LOG = logging.getLogger(__name__)
@ -58,7 +58,7 @@ class AuditFilter(BaseModel):
]
def get_audit_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
def create_audit_router(get_db_session: AsyncDBSessionDep) -> APIRouter:
"""Construct audit sub-api."""
router = APIRouter()
@ -70,11 +70,13 @@ def get_audit_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
"""Get audit logs."""
# audit.audit_access_audit_log(session, request)
total = (await session.scalars(
select(func.count("*"))
.select_from(AuditLog)
.where(and_(True, *filters.filter_mapping))
)).one()
total = (
await session.scalars(
select(func.count("*"))
.select_from(AuditLog)
.where(and_(True, *filters.filter_mapping))
)
).one()
remaining = total - filters.offset
statement = (
@ -107,12 +109,12 @@ def get_audit_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
@router.get("/audit/info")
async def get_audit_info(
session: Annotated[AsyncSession, Depends(get_db_session)]
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> AuditInfo:
"""Get audit info."""
audit_count = (await session.scalars(
select(func.count("*")).select_from(AuditLog)
)).one()
audit_count = (
await session.scalars(select(func.count("*")).select_from(AuditLog))
).one()
return AuditInfo(entries=audit_count)
return router

View File

@ -0,0 +1,40 @@
"""Models for Audit API views."""
import uuid
from datetime import datetime
from collections.abc import Sequence
from pydantic import BaseModel
from sshecret_backend import models
class AuditView(BaseModel):
"""Audit log view."""
id: uuid.UUID | None = None
subsystem: models.SubSystem
message: str
operation: models.Operation
data: dict[str, str] | None = None
client_id: uuid.UUID | None = None
client_name: str | None = None
secret_id: uuid.UUID | None = None
secret_name: str | None = None
origin: str | None = None
timestamp: datetime | None = None
class AuditInfo(BaseModel):
"""Information about audit information."""
entries: int
class AuditListResult(BaseModel):
"""Class to return when listing audit entries."""
results: Sequence[AuditView]
total: int
remaining: int

View File

@ -0,0 +1,320 @@
"""Client operations."""
import logging
import uuid
from typing import Any, cast
from datetime import datetime, timezone
from fastapi import HTTPException, Request
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import Select
from sshecret_backend import audit
from sshecret_backend.api.common import (
FlexID,
IdType,
RelaxedId,
create_new_client_version,
query_active_clients,
get_client_by_id,
resolve_client_id,
refresh_client,
reload_client_with_relationships,
)
from sshecret_backend.models import Client, ClientAccessPolicy
from .schemas import (
ClientListParams,
ClientCreate,
ClientPolicyView,
ClientView,
ClientQueryResult,
ClientPolicyUpdate,
)
LOG = logging.getLogger(__name__)
def _id(id: RelaxedId) -> uuid.UUID:
"""Ensure that the ID is a uuid."""
if isinstance(id, str):
return uuid.UUID(id)
return id
class ClientOperations:
"""Perform operations on a client."""
def __init__(self, session: AsyncSession, request: Request) -> None:
"""Create operations class."""
self.session: AsyncSession = session
self.request: Request = request
self._client_id: uuid.UUID | None = None
# self.client_name: str | None = None
# if client.type is IdType.ID:
# self._client_id = _id(client.value)
# else:
# self.client_name = str(client.value)
async def get_client_id(
self,
client: FlexID,
version: int | None = None,
) -> uuid.UUID | None:
"""Get client ID."""
if self._client_id:
LOG.debug("Returning previously resolved client ID.")
return self._client_id
if client.type is IdType.ID:
self._client_id = _id(client.value)
return self._client_id
client_name = str(client.value)
client_id = await resolve_client_id(
self.session,
client_name,
version=version,
)
if not client_id:
return None
self._client_id = client_id
LOG.debug("Saving client ID %s", client_id)
return client_id
async def _get_client(
self,
client: FlexID,
version: int | None = None,
) -> Client | None:
"""Get client."""
client_id = await self.get_client_id(client, version=version)
if not client_id:
return None
db_client = await get_client_by_id(self.session, client_id)
return db_client
async def get_client(
self,
client: FlexID,
version: int | None = None,
) -> ClientView:
"""Get public client object."""
db_client = await self._get_client(client, version)
if not db_client:
raise HTTPException(status_code=404, detail="Client not found.")
return ClientView.from_client(db_client)
async def create_client(
self,
create_model: ClientCreate,
) -> ClientView:
"""Create a new client."""
existing_id = await self.get_client_id(FlexID.name(create_model.name))
if existing_id:
raise HTTPException(
status_code=400, detail="Error: A client already exists with this name."
)
client = create_model.to_client()
self.session.add(client)
await self.session.flush()
await self.session.commit()
await refresh_client(self.session, client)
await audit.audit_create_client(self.session, self.request, client)
return ClientView.from_client(client)
async def delete_client(self, client: FlexID) -> None:
"""Delete client."""
db_client = await self._get_client(client)
if not db_client:
return
if db_client.is_deleted:
return
db_client.is_deleted = True
db_client.deleted_at = datetime.now(timezone.utc)
self.session.add(db_client)
await self.session.commit()
await audit.audit_delete_client(self.session, self.request, db_client)
async def update_client(
self,
client: FlexID,
client_update: ClientCreate,
) -> ClientView:
"""Update client details."""
db_client = await self._get_client(client)
if not db_client:
raise HTTPException(status_code=404, detail="Client not found.")
if (
client_update.public_key
and client_update.public_key != db_client.public_key
):
return await new_client_version(
self.session,
db_client.id,
client_update.public_key,
client_update.name,
client_update.description,
)
db_client.name = client_update.name
db_client.description = client_update.description
self.session.add(db_client)
await self.session.commit()
await refresh_client(self.session, db_client)
await audit.audit_update_client(self.session, self.request, db_client)
return ClientView.from_client(db_client)
async def new_client_version(
self,
client: FlexID,
public_key: str,
name: str | None = None,
description: str | None = None,
) -> ClientView:
"""Update a client to a new version."""
current_client = await self._get_client(client)
if not current_client:
raise HTTPException(status_code=404, detail="Client not found.")
new_client = await create_new_client_version(
self.session, current_client, public_key
)
if name:
new_client.name = name
if description:
new_client.description = description
current_client.is_active = False
self.session.add(current_client)
await self.session.commit()
await refresh_client(self.session, new_client)
await audit.audit_new_client_version(
self.session, self.request, current_client, new_client
)
return ClientView.from_client(new_client)
async def get_client_policies(self, client: FlexID) -> ClientPolicyView:
"""Get client policies."""
db_client = await self._get_client(client)
if not db_client:
raise HTTPException(status_code=404, detail="Client not found.")
return ClientPolicyView.from_client(db_client)
async def update_client_policies(
self, client: FlexID, policy_update: ClientPolicyUpdate
) -> ClientPolicyView:
"""Update client policies."""
db_client = await self._get_client(client)
if not db_client:
raise HTTPException(status_code=404, detail="Client not found.")
policies = await self.session.scalars(
select(ClientAccessPolicy).where(
ClientAccessPolicy.client_id == db_client.id
)
)
deleted_policies: list[ClientAccessPolicy] = []
added_policies: list[ClientAccessPolicy] = []
for policy in policies.all():
await self.session.delete(policy)
deleted_policies.append(policy)
LOG.debug("Updating client policies with: %r", policy_update.sources)
for source in policy_update.sources:
LOG.debug("Source %r", source)
policy = ClientAccessPolicy(source=str(source), client_id=db_client.id)
self.session.add(policy)
added_policies.append(policy)
await self.session.flush()
await self.session.commit()
db_client = await reload_client_with_relationships(self.session, db_client)
for policy in deleted_policies:
await audit.audit_remove_policy(
self.session, self.request, db_client, policy
)
for policy in added_policies:
await audit.audit_update_policy(
self.session, self.request, db_client, policy
)
return ClientPolicyView.from_client(db_client)
def filter_client_statement(
statement: Select[Any], params: ClientListParams, ignore_limits: bool = False
) -> Select[Any]:
"""Filter a statement with the provided params."""
if params.id:
statement = statement.where(Client.id == params.id)
if params.name:
statement = statement.where(Client.name == params.name)
elif params.name__like:
statement = statement.where(Client.name.like(params.name__like))
elif params.name__contains:
statement = statement.where(Client.name.contains(params.name__contains))
if ignore_limits:
return statement
return statement.limit(params.limit).offset(params.offset)
async def get_clients(
session: AsyncSession,
filter_query: ClientListParams,
) -> ClientQueryResult:
"""Get Clients."""
count_statement = select(func.count("*")).select_from(Client)
count_statement = cast(
Select[tuple[int]],
filter_client_statement(count_statement, filter_query, True),
)
total_results = (await session.scalars(count_statement)).one()
statement = filter_client_statement(query_active_clients(), filter_query, False)
results = await session.scalars(statement)
remainder = total_results - filter_query.offset - filter_query.limit
if remainder < 0:
remainder = 0
clients = list(results.all())
clients_view = ClientView.from_client_list(clients)
return ClientQueryResult(
clients=clients_view,
total_results=total_results,
remaining_results=remainder,
)
async def new_client_version(
session: AsyncSession,
client_id: uuid.UUID,
public_key: str,
name: str | None = None,
description: str | None = None,
) -> ClientView:
"""Update a client to a new version."""
current_client = await get_client_by_id(session, client_id)
if not current_client:
raise ValueError("Client not found.")
new_client = await create_new_client_version(session, current_client, public_key)
if name:
new_client.name = name
if description:
new_client.description = description
current_client.is_active = False
session.add(current_client)
await session.commit()
return ClientView.from_client(new_client)

View File

@ -0,0 +1,122 @@
"""Client router."""
# pyright: reportUnusedFunction=false
#
import logging
from typing import Annotated
from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sshecret_backend.types import AsyncDBSessionDep
from sshecret_backend.api.clients.schemas import (
ClientCreate,
ClientListParams,
ClientQueryResult,
ClientUpdate,
ClientView,
ClientPolicyUpdate,
ClientPolicyView,
)
from sshecret_backend.api.clients import operations
from sshecret_backend.api.clients.operations import ClientOperations
from sshecret_backend.api.common import FlexID
LOG = logging.getLogger(__name__)
def create_client_router(get_db_session: AsyncDBSessionDep) -> APIRouter:
"""Create client router."""
router = APIRouter()
@router.get("/clients/")
async def route_get_clients(
filter_query: Annotated[ClientListParams, Query()],
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientQueryResult:
return await operations.get_clients(session, filter_query)
@router.post("/clients/")
async def create_client(
request: Request,
client: ClientCreate,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView:
"""Create client."""
client_op = ClientOperations(session, request)
return await client_op.create_client(client)
@router.get("/clients/{client_identifier}")
async def fetch_client_by_name(
request: Request,
client_identifier: str,
session: Annotated[AsyncSession, Depends(get_db_session)],
version: Annotated[int | None, Query()] = None,
) -> ClientView:
"""Fetch client by name."""
client_id = FlexID.from_string(client_identifier)
client_op = ClientOperations(session, request)
return await client_op.get_client(client_id, version)
@router.delete("/clients/{client_identifier}")
async def delete_client(
request: Request,
client_identifier: str,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> None:
"""Delete a client."""
client_id = FlexID.from_string(client_identifier)
client_op = ClientOperations(session, request)
await client_op.delete_client(client_id)
@router.post("/clients/{client_identifier}/public-key")
async def update_client_public_key(
request: Request,
client_identifier: str,
client_update: ClientUpdate,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView:
"""Update client public key."""
client_id = FlexID.from_string(client_identifier)
client_op = ClientOperations(session, request)
return await client_op.new_client_version(client_id, client_update.public_key)
@router.put("/clients/{client_identifier}")
async def update_client_by_name(
request: Request,
client_identifier: str,
client_update: ClientCreate,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView:
"""Update client by name."""
client_id = FlexID.from_string(client_identifier)
client_op = ClientOperations(session, request)
return await client_op.update_client(client_id, client_update)
@router.get("/clients/{client_identifier}/policies/")
async def get_client_policies(
request: Request,
client_identifier: str,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientPolicyView:
"""Get client policies."""
client_id = FlexID.from_string(client_identifier)
client_op = ClientOperations(session, request)
return await client_op.get_client_policies(client_id)
@router.put("/clients/{client_identifier}/policies/")
async def update_client_policies(
request: Request,
client_identifier: str,
policy_update: ClientPolicyUpdate,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientPolicyView:
"""Update client policies.
This is also how you delete policies.
"""
client_id = FlexID.from_string(client_identifier)
client_op = ClientOperations(session, request)
return await client_op.update_client_policies(client_id, policy_update)
return router

View File

@ -0,0 +1,137 @@
"""Client related schemas."""
from typing import Annotated, Self
import uuid
from datetime import datetime
from pydantic import (
AfterValidator,
BaseModel,
Field,
IPvAnyAddress,
IPvAnyNetwork,
model_validator,
)
from sshecret.crypto import public_key_validator
from sshecret_backend import models
class ClientView(BaseModel):
"""View for a single client."""
id: uuid.UUID
name: str
description: str | None = None
public_key: str
policies: list[str] = ["0.0.0.0/0", "::/0"]
is_active: bool = True
is_deleted: bool = False
secrets: list[str] = Field(default_factory=list)
created_at: datetime | None
updated_at: datetime | None = None
deleted_at: datetime | None = None
@classmethod
def from_client_list(cls, clients: list[models.Client]) -> list[Self]:
"""Generate a list of responses from a list of clients."""
responses: list[Self] = [cls.from_client(client) for client in clients]
return responses
@classmethod
def from_client(cls, client: models.Client) -> Self:
"""Instantiate from a client."""
view = cls(
id=client.id,
name=client.name,
description=client.description,
public_key=client.public_key,
created_at=client.created_at,
updated_at=client.updated_at or None,
deleted_at=client.deleted_at or None,
is_active=client.is_active,
is_deleted=client.is_deleted,
)
if client.secrets:
view.secrets = [secret.name for secret in client.secrets]
if client.policies:
view.policies = [policy.source for policy in client.policies]
return view
class ClientQueryResult(BaseModel):
"""Result class for queries towards the client list."""
clients: list[ClientView] = Field(default_factory=list)
total_results: int
remaining_results: int
class ClientListParams(BaseModel):
"""Client list parameters."""
limit: int = Field(100, gt=0, le=100)
offset: int = Field(0, ge=0)
id: uuid.UUID | None = None
name: str | None = None
name__like: str | None = None
name__contains: str | None = None
@model_validator(mode="after")
def validate_expressions(self) -> Self:
"""Validate mutually exclusive expression."""
name_filter = False
if self.name__like or self.name__contains:
name_filter = True
if self.name__like and self.name__contains:
raise ValueError("You may only specify one name expression")
if self.name and name_filter:
raise ValueError(
"You must either specify name or one of name__like or name__contains"
)
return self
class ClientCreate(BaseModel):
"""Model to create a client."""
name: str
description: str | None = None
public_key: Annotated[str, AfterValidator(public_key_validator)]
def to_client(self) -> models.Client:
"""Instantiate a client."""
return models.Client(
name=self.name,
public_key=self.public_key,
description=self.description,
)
class ClientUpdate(BaseModel):
"""Model to update the client public key."""
public_key: Annotated[str, AfterValidator(public_key_validator)]
class ClientPolicyView(BaseModel):
"""Update object for client policy."""
sources: list[str] = ["0.0.0.0/0", "::/0"]
@classmethod
def from_client(cls, client: models.Client) -> Self:
"""Create from client."""
if not client.policies:
return cls()
return cls(sources=[policy.source for policy in client.policies])
class ClientPolicyUpdate(BaseModel):
"""Model for updating policies."""
sources: list[IPvAnyAddress | IPvAnyNetwork]

View File

@ -1,21 +1,58 @@
"""Common helpers."""
import re
from typing import Self
import uuid
import bcrypt
from dataclasses import dataclass, field
from enum import Enum
import bcrypt
from pydantic import BaseModel
from sqlalchemy import Select
from sqlalchemy.orm import selectinload
from sqlalchemy.future import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from sshecret_backend.models import Client, ClientAccessPolicy
RE_UUID = re.compile(
"^[0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}$"
)
RelaxedId = uuid.UUID | str
class IdType(Enum):
"""Id type."""
ID = "id"
NAME = "name"
class FlexID(BaseModel):
"""Flexible identifier."""
type: IdType
value: RelaxedId
@classmethod
def id(cls, id: RelaxedId) -> Self:
"""Construct from ID."""
return cls(type=IdType.ID, value=id)
@classmethod
def name(cls, name: str) -> Self:
"""Construct from name."""
return cls(type=IdType.NAME, value=name)
@classmethod
def from_string(cls, value: str) -> Self:
"""Convert from path string."""
if value.startswith("id:"):
return cls.id(value[3:])
elif value.startswith("name:"):
return cls.name(value[5:])
return cls.name(value)
@dataclass
class NewClientVersion:
@ -60,7 +97,10 @@ def client_with_relationships() -> Select[tuple[Client]]:
async def resolve_client_id(
session: AsyncSession, name: str, version: int | None = None, include_deleted: bool = False,
session: AsyncSession,
name: str,
version: int | None = None,
include_deleted: bool = False,
) -> uuid.UUID | None:
"""Get the ID of a client name."""
if include_deleted:
@ -123,7 +163,8 @@ async def get_client_by_name(session: AsyncSession, name: str) -> Client | None:
async def refresh_client(session: AsyncSession, client: Client) -> None:
"""Refresh the client and load in all relationships."""
await session.refresh(
client, attribute_names=["secrets", "policies", "previous_version", "updated_at"]
client,
attribute_names=["secrets", "policies", "previous_version", "updated_at"],
)

View File

@ -1,86 +0,0 @@
"""Policies sub-api router factory."""
# pyright: reportUnusedFunction=false
import logging
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Annotated
from sshecret_backend.models import ClientAccessPolicy
from sshecret_backend.view_models import (
ClientPolicyView,
ClientPolicyUpdate,
)
from sshecret_backend.types import AsyncDBSessionDep
from sshecret_backend import audit
from .common import get_client_by_id_or_name, reload_client_with_relationships
LOG = logging.getLogger(__name__)
def get_policy_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
"""Construct clients sub-api."""
router = APIRouter()
@router.get("/clients/{name}/policies/")
async def get_client_policies(
name: str, session: Annotated[AsyncSession, Depends(get_db_session)]
) -> ClientPolicyView:
"""Get client policies."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
return ClientPolicyView.from_client(client)
@router.put("/clients/{name}/policies/")
async def update_client_policies(
request: Request,
name: str,
policy_update: ClientPolicyUpdate,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientPolicyView:
"""Update client policies.
This is also how you delete policies.
"""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
# Remove old policies.
policies = await session.scalars(
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
)
deleted_policies: list[ClientAccessPolicy] = []
added_policies: list[ClientAccessPolicy] = []
for policy in policies.all():
await session.delete(policy)
deleted_policies.append(policy)
LOG.debug("Updating client policies with: %r", policy_update.sources)
for source in policy_update.sources:
LOG.debug("Source %r", source)
policy = ClientAccessPolicy(source=str(source), client_id=client.id)
session.add(policy)
added_policies.append(policy)
await session.flush()
await session.commit()
client = await reload_client_with_relationships(session, client)
for policy in deleted_policies:
await audit.audit_remove_policy(session, request, client, policy)
for policy in added_policies:
await audit.audit_update_policy(session, request, client, policy)
return ClientPolicyView.from_client(client)
return router

View File

@ -0,0 +1,9 @@
"""Common API schemas."""
from pydantic import BaseModel
class BodyValue(BaseModel):
"""A generic model with just a value parameter."""
value: str

View File

@ -1,257 +0,0 @@
"""Secrets sub-api factory."""
# pyright: reportUnusedFunction=false
import logging
from collections import defaultdict
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy import select
from typing import Annotated
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sshecret_backend.models import Client, ClientSecret
from sshecret_backend.view_models import (
ClientReference,
ClientSecretDetailList,
ClientSecretList,
ClientSecretPublic,
BodyValue,
ClientSecretResponse,
)
from sshecret_backend import audit
from sshecret_backend.types import AsyncDBSessionDep
from .common import get_client_by_id_or_name, get_client_by_name
LOG = logging.getLogger(__name__)
async def lookup_client_secret(
session: AsyncSession, client: Client, name: str
) -> ClientSecret | None:
"""Look up a secret for a client."""
statement = (
select(ClientSecret)
.where(ClientSecret.client_id == client.id)
.where(ClientSecret.name == name)
)
results = await session.scalars(statement)
return results.first()
def get_secrets_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
"""Construct clients sub-api."""
router = APIRouter()
@router.post("/clients/{name}/secrets/")
async def add_secret_to_client(
request: Request,
name: str,
client_secret: ClientSecretPublic,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> None:
"""Add secret to a client."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
existing_secret = await lookup_client_secret(
session, client, client_secret.name
)
if existing_secret:
raise HTTPException(
status_code=400,
detail="Cannot add a secret. A different secret with the same name already exists.",
)
db_secret = ClientSecret(
name=client_secret.name, client_id=client.id, secret=client_secret.secret
)
session.add(db_secret)
await session.commit()
await session.refresh(db_secret)
await audit.audit_create_secret(session, request, client, db_secret)
@router.put("/clients/{name}/secrets/{secret_name}")
async def update_client_secret(
request: Request,
name: str,
secret_name: str,
secret_data: BodyValue,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientSecretResponse:
"""Update a client secret.
This can also be used for destructive creates.
"""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
existing_secret = await lookup_client_secret(session, client, secret_name)
if existing_secret:
existing_secret.secret = secret_data.value
session.add(existing_secret)
await session.commit()
await session.refresh(existing_secret)
await audit.audit_update_secret(session, request, client, existing_secret)
return ClientSecretResponse.from_client_secret(existing_secret)
db_secret = ClientSecret(
name=secret_name,
client_id=client.id,
secret=secret_data.value,
)
session.add(db_secret)
await session.commit()
await session.refresh(db_secret)
await audit.audit_create_secret(session, request, client, db_secret)
return ClientSecretResponse.from_client_secret(db_secret)
@router.get("/clients/{name}/secrets/{secret_name}")
async def request_client_secret(
request: Request,
name: str,
secret_name: str,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientSecretResponse:
"""Get a client secret."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
secret = await lookup_client_secret(session, client, secret_name)
if not secret:
raise HTTPException(
status_code=404, detail="Cannot find a secret with the given name."
)
response_model = ClientSecretResponse.from_client_secret(secret)
await audit.audit_access_secret(session, request, client, secret)
return response_model
@router.delete("/clients/{name}/secrets/{secret_name}")
async def delete_client_secret(
request: Request,
name: str,
secret_name: str,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> None:
"""Delete a secret."""
client = await get_client_by_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
secret = await lookup_client_secret(session, client, secret_name)
if not secret:
raise HTTPException(
status_code=404, detail="Cannot find a secret with the given name."
)
await session.delete(secret)
await session.commit()
await audit.audit_delete_secret(session, request, client, secret)
@router.get("/secrets/")
async def get_secret_map(
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> list[ClientSecretList]:
"""Get a list of all secrets and which clients have them."""
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
client_secrets = await session.scalars(
select(ClientSecret).options(selectinload(ClientSecret.client))
)
for client_secret in client_secrets.all():
if not client_secret.client:
if client_secret.name not in client_secret_map:
client_secret_map[client_secret.name] = []
continue
client_secret_map[client_secret.name].append(client_secret.client.name)
# audit.audit_client_secret_list(session, request)
return [
ClientSecretList(name=secret_name, clients=clients)
for secret_name, clients in client_secret_map.items()
]
@router.get("/secrets/detailed/")
async def get_detailed_secret_map(
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> list[ClientSecretDetailList]:
"""Get a list of all secrets and which clients have them."""
client_secrets: dict[str, ClientSecretDetailList] = {}
all_client_secrets = await session.execute(
select(ClientSecret).options(selectinload(ClientSecret.client))
)
for client_secret in all_client_secrets.scalars().all():
if client_secret.name not in client_secrets:
client_secrets[client_secret.name] = ClientSecretDetailList(
name=client_secret.name
)
client_secrets[client_secret.name].ids.append(str(client_secret.id))
if not client_secret.client:
continue
client_secrets[client_secret.name].clients.append(
ClientReference(
id=str(client_secret.client.id), name=client_secret.client.name
)
)
# `audit.audit_client_secret_list(session, request)
return list(client_secrets.values())
@router.get("/secrets/{name}")
async def get_secret_clients(
name: str,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientSecretList:
"""Get a list of which clients has a named secret."""
clients: list[str] = []
client_secrets = await session.scalars(
select(ClientSecret)
.join(ClientSecret.client)
.options(selectinload(ClientSecret.client))
.where(ClientSecret.name == name)
.where(Client.is_active.is_(True))
)
for client_secret in client_secrets.all():
if not client_secret.client:
continue
clients.append(client_secret.client.name)
return ClientSecretList(name=name, clients=clients)
@router.get("/secrets/{name}/detailed")
async def get_secret_clients_detailed(
name: str,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientSecretDetailList:
"""Get a list of which clients has a named secret."""
detail_list = ClientSecretDetailList(name=name)
client_secrets = await session.scalars(
select(ClientSecret)
.options(selectinload(ClientSecret.client))
.where(ClientSecret.name == name)
.where(ClientSecret.client.is_(Client.is_active))
)
for client_secret in client_secrets.all():
if not client_secret.client:
continue
detail_list.ids.append(str(client_secret.id))
detail_list.clients.append(
ClientReference(
id=str(client_secret.client.id), name=client_secret.client.name
)
)
return detail_list
return router

View File

@ -0,0 +1,328 @@
"""ClientSecret operations."""
import logging
import uuid
from datetime import datetime, timezone
from fastapi import HTTPException, Request
from sqlalchemy import Null, distinct, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sshecret_backend import audit
from sshecret_backend.api.common import (
FlexID,
IdType,
get_client_by_id,
resolve_client_id,
)
from sshecret_backend.api.secrets.schemas import (
ClientReference,
ClientSecretDetailList,
ClientSecretResponse,
ClientSecretStats,
DistinctClientSecretStats,
)
from sshecret_backend.models import Client, ClientSecret
LOG = logging.getLogger(__name__)
RelaxedId = uuid.UUID | str
def _id(id: RelaxedId) -> uuid.UUID:
"""Ensure that the ID is a uuid."""
if isinstance(id, str):
return uuid.UUID(id)
return id
class ClientSecretOperations:
"""Perform operations on a client's ClientSecrets"""
def __init__(
self,
session: AsyncSession,
request: Request,
client: FlexID,
include_deleted: bool = False,
) -> None:
"""Create operations class."""
self._client_id: uuid.UUID | None = None
self.client_name: str | None = None
self.client: Client | None = None
if client.type is IdType.ID:
self._client_id = _id(client.value)
else:
self.client_name = str(client.value)
self.include_deleted: bool = include_deleted
self.session: AsyncSession = session
self.request: Request = request
async def get_client_id(self) -> uuid.UUID | None:
"""Get client ID."""
if self._client_id:
LOG.debug("Returning previously resolved client ID.")
return self._client_id
if not self.client_name:
raise RuntimeError("Error: No client information registered.")
client_id = await resolve_client_id(
self.session, self.client_name, include_deleted=self.include_deleted
)
self._client_id = client_id
LOG.debug("Saving client ID %s", client_id)
return client_id
async def get_client(self) -> Client | None:
"""Get client."""
if self.client:
return self.client
client_id = await self.get_client_id()
if not client_id:
return None
client = await get_client_by_id(self.session, client_id)
self.client = client
return client
async def _get_client_secret(
self, secret_identifier: FlexID
) -> ClientSecret | None:
"""Get client secret private method."""
client = await self.get_client()
if not client:
LOG.debug("Client lookup failed.")
return None
if secret_identifier.type is IdType.ID:
match_id = _id(secret_identifier.value)
LOG.debug("Searching for secrets matching ID %r", match_id)
matches = [secret for secret in client.secrets if secret.id == match_id]
else:
LOG.debug("Searching for secrets matching name")
matches = [
secret
for secret in client.secrets
if secret.name == str(secret_identifier.value)
]
for secret_match in matches:
if secret_match.deleted:
# TODO: Add override for deleted.
LOG.debug("Found deleted secret")
continue
await audit.audit_access_secret(
self.session, self.request, client, secret_match
)
LOG.debug("Found matching secret")
return secret_match
LOG.debug("No secrets matched.")
return None
async def get_client_secret(
self, secret_identifier: FlexID
) -> ClientSecretResponse:
"""Get client secret public model."""
client_secret = await self._get_client_secret(secret_identifier)
if not client_secret:
raise HTTPException(status_code=404, detail="Unknown client or secret")
return ClientSecretResponse.from_client_secret(client_secret)
async def create_client_secret(
self, name: str, value: str, description: str | None = None
) -> ClientSecretResponse:
"""Create client Secret."""
client = await self.get_client()
if not client:
raise HTTPException(status_code=404, detail="Client not found.")
existing_secret = await self._get_client_secret(FlexID.name(name))
if existing_secret:
raise HTTPException(
status_code=400,
detail="Cannot add a secret. A different secret with the same name already exists.",
)
secret = ClientSecret(
name=name,
description=description,
client=client,
secret=value,
)
self.session.add(secret)
await self.session.commit()
await self.session.refresh(secret, attribute_names=["client", "updated_at"])
await audit.audit_create_secret(self.session, self.request, client, secret)
return ClientSecretResponse.from_client_secret(secret)
async def update_client_secret_value(
self, secret_identifier: FlexID, value: str
) -> ClientSecretResponse:
"""Update client secret value."""
client_secret = await self._get_client_secret(secret_identifier)
if not client_secret:
# We can use this method to create a secret too.
if secret_identifier.type is IdType.NAME:
return await self.create_client_secret(
str(secret_identifier.value), value
)
raise HTTPException(status_code=404, detail="Unknown client or secret")
client_secret.secret = value
self.session.add(client_secret)
await self.session.commit()
await self.session.refresh(client_secret, ["updated_at"])
await audit.audit_update_secret(
self.session, self.request, client_secret.client, client_secret
)
return ClientSecretResponse.from_client_secret(client_secret)
async def delete_client_secret(self, secret_identifier: FlexID) -> None:
"""Delete a client secret."""
client_secret = await self._get_client_secret(secret_identifier)
if not client_secret:
return
client_secret.deleted = True
client_secret.deleted_at = datetime.now(timezone.utc)
self.session.add(client_secret)
await self.session.commit()
await self.session.refresh(client_secret, ["deleted", "deleted_at"])
await audit.audit_delete_secret(
self.session, self.request, client_secret.client, client_secret
)
async def resolve_client_secret_mapping(
session: AsyncSession,
) -> list[ClientSecretDetailList]:
"""Resolve mapping of clients to secrets."""
result = await session.execute(
select(ClientSecret)
.join(ClientSecret.client)
.options(selectinload(ClientSecret.client))
.where(Client.is_active.is_not(False))
.where(ClientSecret.deleted.is_not(True))
)
client_secrets: dict[str, ClientSecretDetailList] = {}
for secret in result.scalars().all():
if secret.name not in client_secrets:
client_secrets[secret.name] = ClientSecretDetailList(name=secret.name)
client_secrets[secret.name].ids.append(str(secret.id))
if not secret.client:
continue
client_secrets[secret.name].clients.append(
ClientReference(id=str(secret.client.id), name=secret.client.name)
)
return list(client_secrets.values())
async def resolve_client_secret_clients(
session: AsyncSession, name: str, include_deleted: bool = False
) -> ClientSecretDetailList | None:
"""Resolve client association to a secret."""
statement = (
select(ClientSecret)
.options(selectinload(ClientSecret.client))
.where(ClientSecret.name == name)
)
if not include_deleted:
statement = statement.where(ClientSecret.deleted.is_not(True))
results = await session.execute(statement)
clients: ClientSecretDetailList | None = None
for client_secret in results.scalars().all():
if not clients:
# Ensure we don't create the object before we have at least one client.
clients = ClientSecretDetailList(name=name)
clients.ids.append(str(client_secret.id))
if client_secret.client:
clients.clients.append(
ClientReference(
id=str(client_secret.client.id), name=client_secret.client.name
)
)
return clients
async def _get_secrets_total(session: AsyncSession) -> int:
"""Get total amount of secrets excluding deleted."""
statement = (
select(func.count())
.select_from(ClientSecret)
.where(ClientSecret.deleted.is_not(True))
)
result = await session.execute(statement)
return result.scalar_one()
async def _get_distinct_secret_count(session: AsyncSession) -> int:
"""Get the amount of distinct secrets, excluding deleted."""
statement = (
select(func.count())
.select_from(ClientSecret)
.where(ClientSecret.deleted.is_not(True))
.distinct()
)
result = await session.execute(statement)
return result.scalar_one()
async def _get_secret_client_stats(
session: AsyncSession,
) -> list[DistinctClientSecretStats]:
"""Get stats for each named secret."""
statement = (
(
select(
ClientSecret.name, func.count(distinct(Client.id)).label("client_count")
)
)
.join(Client, ClientSecret.client_id)
.where(ClientSecret.deleted.is_not(True), Client.is_deleted.is_not(True))
.group_by(ClientSecret.name)
)
results = await session.execute(statement)
stats: list[DistinctClientSecretStats] = []
rows = results.all()
for row in rows:
secret_name, client_count = row.tuple()
stats.append(DistinctClientSecretStats(name=secret_name, clients=client_count))
return stats
async def _get_secrets_without_clients(
session: AsyncSession,
) -> int:
"""Get secrets without clients."""
statement = (
select(func.count(distinct(ClientSecret.name)))
.where(ClientSecret.deleted.is_not(True))
.where(ClientSecret.client_id.is_(Null))
)
result = await session.execute(statement)
return result.scalar_one()
async def get_client_secret_stats(session: AsyncSession) -> ClientSecretStats:
"""Get stats for the client secrets.
TODO: Implement a route with this.
"""
distinct_secrets = await _get_distinct_secret_count(session)
total_secrets = await _get_secrets_total(session)
secrets_without_clients = await _get_secrets_without_clients(session)
client_stats = await _get_secret_client_stats(session)
return ClientSecretStats(
distinct_secrets=distinct_secrets,
total_secrets=total_secrets,
secrets_without_clients=secrets_without_clients,
client_stats=client_stats,
)

View File

@ -0,0 +1,107 @@
"""Client Secret Router."""
# pyright: reportUnusedFunction=false
import logging
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sshecret_backend.api.common import FlexID
from sshecret_backend.types import AsyncDBSessionDep
from sshecret_backend.api.secrets.operations import (
ClientSecretOperations,
resolve_client_secret_clients,
resolve_client_secret_mapping,
)
from sshecret_backend.api.schemas import BodyValue
from sshecret_backend.api.secrets.schemas import (
ClientSecretDetailList,
ClientSecretPublic,
ClientSecretResponse,
)
LOG = logging.getLogger(__name__)
def create_client_secrets_router(get_db_session: AsyncDBSessionDep) -> APIRouter:
"""Create client secret router."""
router = APIRouter()
@router.post("/clients/{client_identifier}/secrets/")
async def add_secret_to_client(
request: Request,
client_identifier: str,
client_secret: ClientSecretPublic,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> None:
"""Add secret to a client."""
client = FlexID.from_string(client_identifier)
client_op = ClientSecretOperations(session, request, client)
await client_op.create_client_secret(
client_secret.name, client_secret.secret, client_secret.description
)
@router.put("/clients/{client_identifier}/secrets/{secret_identifier}")
async def update_client_secret(
request: Request,
client_identifier: str,
secret_identifier: str,
secret_data: BodyValue,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientSecretResponse:
"""Update client secret."""
client = FlexID.from_string(client_identifier)
secret = FlexID.from_string(secret_identifier)
client_op = ClientSecretOperations(session, request, client)
return await client_op.update_client_secret_value(secret, secret_data.value)
@router.get("/clients/{client_identifier}/secrets/{secret_identifier}")
async def request_client_secret_named(
request: Request,
client_identifier: str,
secret_identifier: str,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientSecretResponse:
"""Get a named secret from a named client."""
client = FlexID.from_string(client_identifier)
secret = FlexID.from_string(secret_identifier)
LOG.debug("Resolved client FlexID: %r", client)
LOG.debug("Resolved secret FlexID: %r", secret)
client_op = ClientSecretOperations(session, request, client)
return await client_op.get_client_secret(secret)
# TODO: delete_client_secret
@router.delete("/clients/{client_identifier}/secrets/{secret_identifier}")
async def delete_client_secret(
request: Request,
client_identifier: str,
secret_identifier: str,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> None:
"""Delete client secret."""
client = FlexID.from_string(client_identifier)
secret = FlexID.from_string(secret_identifier)
client_op = ClientSecretOperations(session, request, client)
await client_op.delete_client_secret(secret)
@router.get("/secrets/")
async def get_secret_map(
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> list[ClientSecretDetailList]:
"""Get a list of secrets and which clients have them."""
return await resolve_client_secret_mapping(session)
@router.get("/secrets/{name}")
async def get_secret_clients(
name: str,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientSecretDetailList:
"""Get a list of which clients has a named secret."""
result = await resolve_client_secret_clients(session, name)
if not result:
raise HTTPException(status_code=404, detail="Secret not found.")
return result
return router

View File

@ -1,21 +1,12 @@
"""Models for API views."""
"""Schemas for the secrets router."""
import uuid
from datetime import datetime
from typing import Self, override
from collections.abc import Sequence
import uuid
from pydantic import BaseModel, Field, IPvAnyAddress, IPvAnyNetwork
from pydantic import BaseModel, Field
from . import models
class BodyValue(BaseModel):
"""A generic model with just a value parameter."""
value: str
from sshecret_backend import models
class ClientSecretPublic(BaseModel):
@ -38,6 +29,7 @@ class ClientSecretPublic(BaseModel):
class ClientSecretResponse(ClientSecretPublic):
"""A secret view."""
id: uuid.UUID
created_at: datetime | None
updated_at: datetime | None = None
@ -47,6 +39,7 @@ class ClientSecretResponse(ClientSecretPublic):
"""Instantiate from ClientSecret."""
return cls(
id=client_secret.id,
name=client_secret.name,
secret=client_secret.secret,
created_at=client_secret.created_at,
@ -54,25 +47,6 @@ class ClientSecretResponse(ClientSecretPublic):
)
class ClientPolicyView(BaseModel):
"""Update object for client policy."""
sources: list[str] = ["0.0.0.0/0", "::/0"]
@classmethod
def from_client(cls, client: models.Client) -> Self:
"""Create from client."""
if not client.policies:
return cls()
return cls(sources=[policy.source for policy in client.policies])
class ClientPolicyUpdate(BaseModel):
"""Model for updating policies."""
sources: list[IPvAnyAddress | IPvAnyNetwork]
class ClientSecretList(BaseModel):
"""Model for aggregating identically named secrets."""
@ -95,31 +69,20 @@ class ClientSecretDetailList(BaseModel):
clients: list[ClientReference] = Field(default_factory=list)
class AuditView(BaseModel):
"""Audit log view."""
class DistinctClientSecretStats(BaseModel):
"""Stats for distinct client secrets."""
name: str
clients: int = 0
id: uuid.UUID | None = None
subsystem: models.SubSystem
message: str
operation: models.Operation
data: dict[str, str] | None = None
client_id: uuid.UUID | None = None
client_name: str | None = None
secret_id: uuid.UUID | None = None
secret_name: str | None = None
origin: str | None = None
timestamp: datetime | None = None
class ClientSecretStats(BaseModel):
"""Stats row for the clientSecret model.
Useful for pagination and statistics.
"""
class AuditInfo(BaseModel):
"""Information about audit information."""
entries: int
class AuditListResult(BaseModel):
"""Class to return when listing audit entries."""
results: Sequence[AuditView]
total: int
remaining: int
distinct_secrets: int
total_secrets: int
secrets_without_clients: int
client_stats: list[DistinctClientSecretStats] = Field(default_factory=list)

View File

@ -5,7 +5,14 @@ from fastapi import Request
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy, Operation, SubSystem
from .models import (
AuditLog,
Client,
ClientSecret,
ClientAccessPolicy,
Operation,
SubSystem,
)
def _get_origin(request: Request) -> str | None:
@ -128,6 +135,27 @@ async def audit_update_client(
await _write_audit_log(session, request, entry, commit)
async def audit_new_client_version(
session: AsyncSession,
request: Request,
old_client: Client,
new_client: Client,
commit: bool = True,
) -> None:
"""Audit an update secret event."""
entry = AuditLog(
operation=Operation.UPDATE,
client_id=old_client.id,
client_name=old_client.name,
message="Client data updated",
data={
"new_client_id": str(new_client.id),
"new_client_version": new_client.version,
},
)
await _write_audit_log(session, request, entry, commit)
async def audit_update_secret(
session: AsyncSession,
request: Request,
@ -224,6 +252,7 @@ async def audit_access_secret(
)
await _write_audit_log(session, request, entry, commit)
async def audit_client_secret_list(
session: AsyncSession, request: Request, commit: bool = True
) -> None:
@ -233,4 +262,3 @@ async def audit_client_secret_list(
message="All secret names and their clients was viewed",
)
await _write_audit_log(session, request, entry, commit)

View File

@ -2,6 +2,7 @@
import bcrypt
def hash_token(token: str) -> str:
"""Hash a token."""
pwbytes = token.encode("utf-8")

View File

@ -9,7 +9,8 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sshecret_backend.db import DatabaseSessionManager
from sshecret_backend.settings import BackendSettings
from .api import get_audit_api, get_policy_api, get_secrets_api
from .api.audit.router import create_audit_router
from .api.secrets.router import create_client_secrets_router
from .api.clients.router import create_client_router
from .auth import verify_token
from .models import (
@ -60,9 +61,8 @@ def get_backend_api(
dependencies=[Depends(validate_token)],
)
backend_api.include_router(get_audit_api(get_db_session))
backend_api.include_router(create_audit_router(get_db_session))
backend_api.include_router(create_client_router(get_db_session))
backend_api.include_router(get_policy_api(get_db_session))
backend_api.include_router(get_secrets_api(get_db_session))
backend_api.include_router(create_client_secrets_router(get_db_session))
return backend_api

View File

@ -8,7 +8,13 @@ from collections.abc import AsyncIterator, Generator, Callable
from contextlib import asynccontextmanager
from typing import Literal
from sqlalchemy import create_engine, Engine, event, select
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession, async_sessionmaker, create_async_engine, AsyncEngine
from sqlalchemy.ext.asyncio import (
AsyncConnection,
AsyncSession,
async_sessionmaker,
create_async_engine,
AsyncEngine,
)
from sqlalchemy.orm import sessionmaker, Session
@ -21,11 +27,14 @@ from .models import APIClient, SubSystem
LOG = logging.getLogger(__name__)
class DatabaseSessionManager:
def __init__(self, host: URL, **engine_kwargs: str):
self._engine: AsyncEngine | None = get_async_engine(host)
self._sessionmaker: async_sessionmaker[AsyncSession] | None = async_sessionmaker(autocommit=False, bind=self._engine, expire_on_commit=False)
self._sessionmaker: async_sessionmaker[AsyncSession] | None = (
async_sessionmaker(
autocommit=False, bind=self._engine, expire_on_commit=False
)
)
async def close(self):
if self._engine is None:
@ -81,7 +90,6 @@ def setup_database(
return engine, get_db_session
def get_engine(url: URL, echo: bool = False) -> Engine:
"""Initialize the engine."""
engine = create_engine(url, echo=echo)
@ -102,6 +110,7 @@ def get_async_engine(url: URL, echo: bool = False, **engine_kwargs: str) -> Asyn
"""Get an async engine."""
engine = create_async_engine(url, echo=echo, **engine_kwargs)
if url.drivername.startswith("sqlite+"):
@event.listens_for(engine.sync_engine, "connect")
def set_sqlite_pragma(
dbapi_connection: sqlite3.Connection, _connection_record: object
@ -113,18 +122,21 @@ def get_async_engine(url: URL, echo: bool = False, **engine_kwargs: str) -> Asyn
return engine
def create_api_token_with_value(session: Session, token: str, subsystem: Literal["admin", "sshd"]) -> None:
def create_api_token_with_value(
session: Session, token: str, subsystem: Literal["admin", "sshd"]
) -> None:
"""Create API token with a given value."""
existing = session.scalars(select(APIClient).where(APIClient.subsystem == SubSystem(subsystem))).first()
existing = session.scalars(
select(APIClient).where(APIClient.subsystem == SubSystem(subsystem))
).first()
if existing:
if verify_token(token, existing.token):
LOG.info("Token is up to date.")
return
LOG.info("Updating token value for subsystem %s", subsystem)
hashed = hash_token(token)
existing.token=hashed
existing.token = hashed
session.commit()
return
@ -135,12 +147,19 @@ def create_api_token_with_value(session: Session, token: str, subsystem: Literal
session.add(api_token)
session.commit()
def create_api_token(session: Session, subsystem: Literal["admin", "sshd", "test"], recreate: bool = False) -> str:
def create_api_token(
session: Session,
subsystem: Literal["admin", "sshd", "test"],
recreate: bool = False,
) -> str:
"""Create API token."""
subsys = SubSystem(subsystem)
token = secrets.token_urlsafe(32)
hashed = hash_token(token)
if existing := session.scalars(select(APIClient).where(APIClient.subsystem == subsys)).first():
if existing := session.scalars(
select(APIClient).where(APIClient.subsystem == subsys)
).first():
if not recreate:
raise RuntimeError("Error: A token already exist for this subsystem.")
existing.token = hashed

View File

@ -79,8 +79,7 @@ class Client(Base):
)
deleted_at: Mapped[datetime | None] = mapped_column(
sa.DateTime(timezone=True),
nullable=True
sa.DateTime(timezone=True), nullable=True
)
secrets: Mapped[list["ClientSecret"]] = relationship(
@ -93,9 +92,7 @@ class Client(Base):
nullable=True,
)
previous_version: Mapped["Client | None"] = relationship(
"Client",
remote_side=[id],
backref="versions"
"Client", remote_side=[id], backref="versions"
)
policies: Mapped[list["ClientAccessPolicy"]] = relationship(back_populates="client")
@ -142,7 +139,7 @@ class ClientSecret(Base):
sa.Uuid(as_uuid=True), sa.ForeignKey("client.id", ondelete="CASCADE")
)
client: Mapped[Client] = relationship(back_populates="secrets")
invalidated: Mapped[bool] = mapped_column(default=False)
deleted: Mapped[bool] = mapped_column(default=False)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
@ -154,6 +151,10 @@ class ClientSecret(Base):
onupdate=sa.func.now(),
)
deleted_at: Mapped[datetime | None] = mapped_column(
sa.DateTime(timezone=True), nullable=True
)
class APIClient(Base):
"""A client on the API.