Refactor backend views
This commit is contained in:
@ -0,0 +1,36 @@
|
||||
"""Remove invalidated, add deleted
|
||||
|
||||
Revision ID: b4e135ff347a
|
||||
Revises: f2dc50533f88
|
||||
Create Date: 2025-06-06 08:57:47.611854
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'b4e135ff347a'
|
||||
down_revision: Union[str, None] = 'f2dc50533f88'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('client_secret', sa.Column('deleted', sa.Boolean(), nullable=False))
|
||||
op.add_column('client_secret', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
|
||||
op.drop_column('client_secret', 'invalidated')
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('client_secret', sa.Column('invalidated', sa.BOOLEAN(), nullable=False))
|
||||
op.drop_column('client_secret', 'deleted_at')
|
||||
op.drop_column('client_secret', 'deleted')
|
||||
# ### end Alembic commands ###
|
||||
@ -0,0 +1,51 @@
|
||||
"""Make client object better
|
||||
|
||||
Revision ID: c251311b64c9
|
||||
Revises: 37329d9b5437
|
||||
Create Date: 2025-06-04 21:49:22.638698
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "c251311b64c9"
|
||||
down_revision: Union[str, None] = "37329d9b5437"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("client", schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column("version", sa.Integer(), nullable=False))
|
||||
batch_op.add_column(sa.Column("is_active", sa.Boolean(), nullable=False))
|
||||
batch_op.add_column(sa.Column("is_deleted", sa.Boolean(), nullable=False))
|
||||
batch_op.add_column(
|
||||
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True)
|
||||
)
|
||||
batch_op.add_column(sa.Column("parent_id", sa.Uuid(), nullable=True))
|
||||
batch_op.create_unique_constraint("uq_client_name_version", ["name", "version"])
|
||||
batch_op.create_foreign_key(
|
||||
"fk_client_parent", "client", ["parent_id"], ["id"], ondelete="SET NULL"
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("client", schema=None) as batch_op:
|
||||
batch_op.drop_constraint("fk_client_parent", type_="foreignkey")
|
||||
batch_op.drop_constraint("uq_client_name_version", type_="unique")
|
||||
batch_op.drop_column("parent_id")
|
||||
batch_op.drop_column("deleted_at")
|
||||
batch_op.drop_column("is_deleted")
|
||||
batch_op.drop_column("is_active")
|
||||
batch_op.drop_column("version")
|
||||
# ### end Alembic commands ###
|
||||
@ -0,0 +1,27 @@
|
||||
"""Rename parent to previous_version for clarity
|
||||
|
||||
Revision ID: f2dc50533f88
|
||||
Revises: c251311b64c9
|
||||
Create Date: 2025-06-05 13:24:32.465927
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'f2dc50533f88'
|
||||
down_revision: Union[str, None] = 'c251311b64c9'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
op.alter_column("client", "parent_id", new_column_name="previous_version_id")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
op.alter_column("client", "previous_version_id", new_column_name="parent_id")
|
||||
@ -1,7 +1 @@
|
||||
"""API factory modules."""
|
||||
|
||||
from .audit import get_audit_api
|
||||
from .policies import get_policy_api
|
||||
from .secrets import get_secrets_api
|
||||
|
||||
__all__ = ["get_audit_api", "get_policy_api", "get_secrets_api"]
|
||||
|
||||
@ -0,0 +1 @@
|
||||
|
||||
@ -15,7 +15,7 @@ from typing import Annotated
|
||||
|
||||
from sshecret_backend.models import AuditLog, Operation, SubSystem
|
||||
from sshecret_backend.types import AsyncDBSessionDep
|
||||
from sshecret_backend.view_models import AuditInfo, AuditView, AuditListResult
|
||||
from .schemas import AuditInfo, AuditView, AuditListResult
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -58,7 +58,7 @@ class AuditFilter(BaseModel):
|
||||
]
|
||||
|
||||
|
||||
def get_audit_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
||||
def create_audit_router(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
||||
"""Construct audit sub-api."""
|
||||
router = APIRouter()
|
||||
|
||||
@ -70,11 +70,13 @@ def get_audit_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
||||
"""Get audit logs."""
|
||||
# audit.audit_access_audit_log(session, request)
|
||||
|
||||
total = (await session.scalars(
|
||||
select(func.count("*"))
|
||||
.select_from(AuditLog)
|
||||
.where(and_(True, *filters.filter_mapping))
|
||||
)).one()
|
||||
total = (
|
||||
await session.scalars(
|
||||
select(func.count("*"))
|
||||
.select_from(AuditLog)
|
||||
.where(and_(True, *filters.filter_mapping))
|
||||
)
|
||||
).one()
|
||||
|
||||
remaining = total - filters.offset
|
||||
statement = (
|
||||
@ -107,12 +109,12 @@ def get_audit_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
||||
|
||||
@router.get("/audit/info")
|
||||
async def get_audit_info(
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)]
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> AuditInfo:
|
||||
"""Get audit info."""
|
||||
audit_count = (await session.scalars(
|
||||
select(func.count("*")).select_from(AuditLog)
|
||||
)).one()
|
||||
audit_count = (
|
||||
await session.scalars(select(func.count("*")).select_from(AuditLog))
|
||||
).one()
|
||||
return AuditInfo(entries=audit_count)
|
||||
|
||||
return router
|
||||
@ -0,0 +1,40 @@
|
||||
"""Models for Audit API views."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from collections.abc import Sequence
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
from sshecret_backend import models
|
||||
|
||||
|
||||
class AuditView(BaseModel):
|
||||
"""Audit log view."""
|
||||
|
||||
id: uuid.UUID | None = None
|
||||
subsystem: models.SubSystem
|
||||
message: str
|
||||
operation: models.Operation
|
||||
data: dict[str, str] | None = None
|
||||
client_id: uuid.UUID | None = None
|
||||
client_name: str | None = None
|
||||
secret_id: uuid.UUID | None = None
|
||||
secret_name: str | None = None
|
||||
origin: str | None = None
|
||||
timestamp: datetime | None = None
|
||||
|
||||
|
||||
class AuditInfo(BaseModel):
|
||||
"""Information about audit information."""
|
||||
|
||||
entries: int
|
||||
|
||||
|
||||
class AuditListResult(BaseModel):
|
||||
"""Class to return when listing audit entries."""
|
||||
|
||||
results: Sequence[AuditView]
|
||||
total: int
|
||||
remaining: int
|
||||
@ -0,0 +1,320 @@
|
||||
"""Client operations."""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, cast
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.sql import Select
|
||||
from sshecret_backend import audit
|
||||
from sshecret_backend.api.common import (
|
||||
FlexID,
|
||||
IdType,
|
||||
RelaxedId,
|
||||
create_new_client_version,
|
||||
query_active_clients,
|
||||
get_client_by_id,
|
||||
resolve_client_id,
|
||||
refresh_client,
|
||||
reload_client_with_relationships,
|
||||
)
|
||||
from sshecret_backend.models import Client, ClientAccessPolicy
|
||||
from .schemas import (
|
||||
ClientListParams,
|
||||
ClientCreate,
|
||||
ClientPolicyView,
|
||||
ClientView,
|
||||
ClientQueryResult,
|
||||
ClientPolicyUpdate,
|
||||
)
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _id(id: RelaxedId) -> uuid.UUID:
|
||||
"""Ensure that the ID is a uuid."""
|
||||
if isinstance(id, str):
|
||||
return uuid.UUID(id)
|
||||
return id
|
||||
|
||||
|
||||
class ClientOperations:
|
||||
"""Perform operations on a client."""
|
||||
|
||||
def __init__(self, session: AsyncSession, request: Request) -> None:
|
||||
"""Create operations class."""
|
||||
self.session: AsyncSession = session
|
||||
self.request: Request = request
|
||||
|
||||
self._client_id: uuid.UUID | None = None
|
||||
# self.client_name: str | None = None
|
||||
# if client.type is IdType.ID:
|
||||
# self._client_id = _id(client.value)
|
||||
# else:
|
||||
# self.client_name = str(client.value)
|
||||
|
||||
async def get_client_id(
|
||||
self,
|
||||
client: FlexID,
|
||||
version: int | None = None,
|
||||
) -> uuid.UUID | None:
|
||||
"""Get client ID."""
|
||||
if self._client_id:
|
||||
LOG.debug("Returning previously resolved client ID.")
|
||||
return self._client_id
|
||||
|
||||
if client.type is IdType.ID:
|
||||
self._client_id = _id(client.value)
|
||||
return self._client_id
|
||||
|
||||
client_name = str(client.value)
|
||||
client_id = await resolve_client_id(
|
||||
self.session,
|
||||
client_name,
|
||||
version=version,
|
||||
)
|
||||
if not client_id:
|
||||
return None
|
||||
self._client_id = client_id
|
||||
LOG.debug("Saving client ID %s", client_id)
|
||||
return client_id
|
||||
|
||||
async def _get_client(
|
||||
self,
|
||||
client: FlexID,
|
||||
version: int | None = None,
|
||||
) -> Client | None:
|
||||
"""Get client."""
|
||||
client_id = await self.get_client_id(client, version=version)
|
||||
if not client_id:
|
||||
return None
|
||||
db_client = await get_client_by_id(self.session, client_id)
|
||||
return db_client
|
||||
|
||||
async def get_client(
|
||||
self,
|
||||
client: FlexID,
|
||||
version: int | None = None,
|
||||
) -> ClientView:
|
||||
"""Get public client object."""
|
||||
db_client = await self._get_client(client, version)
|
||||
if not db_client:
|
||||
raise HTTPException(status_code=404, detail="Client not found.")
|
||||
return ClientView.from_client(db_client)
|
||||
|
||||
async def create_client(
|
||||
self,
|
||||
create_model: ClientCreate,
|
||||
) -> ClientView:
|
||||
"""Create a new client."""
|
||||
existing_id = await self.get_client_id(FlexID.name(create_model.name))
|
||||
if existing_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Error: A client already exists with this name."
|
||||
)
|
||||
client = create_model.to_client()
|
||||
self.session.add(client)
|
||||
await self.session.flush()
|
||||
await self.session.commit()
|
||||
await refresh_client(self.session, client)
|
||||
await audit.audit_create_client(self.session, self.request, client)
|
||||
return ClientView.from_client(client)
|
||||
|
||||
async def delete_client(self, client: FlexID) -> None:
|
||||
"""Delete client."""
|
||||
db_client = await self._get_client(client)
|
||||
if not db_client:
|
||||
return
|
||||
if db_client.is_deleted:
|
||||
return
|
||||
db_client.is_deleted = True
|
||||
db_client.deleted_at = datetime.now(timezone.utc)
|
||||
self.session.add(db_client)
|
||||
await self.session.commit()
|
||||
await audit.audit_delete_client(self.session, self.request, db_client)
|
||||
|
||||
async def update_client(
|
||||
self,
|
||||
client: FlexID,
|
||||
client_update: ClientCreate,
|
||||
) -> ClientView:
|
||||
"""Update client details."""
|
||||
db_client = await self._get_client(client)
|
||||
if not db_client:
|
||||
raise HTTPException(status_code=404, detail="Client not found.")
|
||||
|
||||
if (
|
||||
client_update.public_key
|
||||
and client_update.public_key != db_client.public_key
|
||||
):
|
||||
return await new_client_version(
|
||||
self.session,
|
||||
db_client.id,
|
||||
client_update.public_key,
|
||||
client_update.name,
|
||||
client_update.description,
|
||||
)
|
||||
db_client.name = client_update.name
|
||||
db_client.description = client_update.description
|
||||
self.session.add(db_client)
|
||||
await self.session.commit()
|
||||
await refresh_client(self.session, db_client)
|
||||
await audit.audit_update_client(self.session, self.request, db_client)
|
||||
return ClientView.from_client(db_client)
|
||||
|
||||
async def new_client_version(
|
||||
self,
|
||||
client: FlexID,
|
||||
public_key: str,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
) -> ClientView:
|
||||
"""Update a client to a new version."""
|
||||
current_client = await self._get_client(client)
|
||||
if not current_client:
|
||||
raise HTTPException(status_code=404, detail="Client not found.")
|
||||
new_client = await create_new_client_version(
|
||||
self.session, current_client, public_key
|
||||
)
|
||||
if name:
|
||||
new_client.name = name
|
||||
if description:
|
||||
new_client.description = description
|
||||
|
||||
current_client.is_active = False
|
||||
self.session.add(current_client)
|
||||
|
||||
await self.session.commit()
|
||||
|
||||
await refresh_client(self.session, new_client)
|
||||
await audit.audit_new_client_version(
|
||||
self.session, self.request, current_client, new_client
|
||||
)
|
||||
return ClientView.from_client(new_client)
|
||||
|
||||
async def get_client_policies(self, client: FlexID) -> ClientPolicyView:
|
||||
"""Get client policies."""
|
||||
db_client = await self._get_client(client)
|
||||
if not db_client:
|
||||
raise HTTPException(status_code=404, detail="Client not found.")
|
||||
return ClientPolicyView.from_client(db_client)
|
||||
|
||||
async def update_client_policies(
|
||||
self, client: FlexID, policy_update: ClientPolicyUpdate
|
||||
) -> ClientPolicyView:
|
||||
"""Update client policies."""
|
||||
db_client = await self._get_client(client)
|
||||
if not db_client:
|
||||
raise HTTPException(status_code=404, detail="Client not found.")
|
||||
|
||||
policies = await self.session.scalars(
|
||||
select(ClientAccessPolicy).where(
|
||||
ClientAccessPolicy.client_id == db_client.id
|
||||
)
|
||||
)
|
||||
deleted_policies: list[ClientAccessPolicy] = []
|
||||
added_policies: list[ClientAccessPolicy] = []
|
||||
for policy in policies.all():
|
||||
await self.session.delete(policy)
|
||||
deleted_policies.append(policy)
|
||||
|
||||
LOG.debug("Updating client policies with: %r", policy_update.sources)
|
||||
for source in policy_update.sources:
|
||||
LOG.debug("Source %r", source)
|
||||
policy = ClientAccessPolicy(source=str(source), client_id=db_client.id)
|
||||
self.session.add(policy)
|
||||
added_policies.append(policy)
|
||||
|
||||
await self.session.flush()
|
||||
await self.session.commit()
|
||||
|
||||
db_client = await reload_client_with_relationships(self.session, db_client)
|
||||
for policy in deleted_policies:
|
||||
await audit.audit_remove_policy(
|
||||
self.session, self.request, db_client, policy
|
||||
)
|
||||
|
||||
for policy in added_policies:
|
||||
await audit.audit_update_policy(
|
||||
self.session, self.request, db_client, policy
|
||||
)
|
||||
|
||||
return ClientPolicyView.from_client(db_client)
|
||||
|
||||
|
||||
def filter_client_statement(
|
||||
statement: Select[Any], params: ClientListParams, ignore_limits: bool = False
|
||||
) -> Select[Any]:
|
||||
"""Filter a statement with the provided params."""
|
||||
if params.id:
|
||||
statement = statement.where(Client.id == params.id)
|
||||
|
||||
if params.name:
|
||||
statement = statement.where(Client.name == params.name)
|
||||
elif params.name__like:
|
||||
statement = statement.where(Client.name.like(params.name__like))
|
||||
elif params.name__contains:
|
||||
statement = statement.where(Client.name.contains(params.name__contains))
|
||||
|
||||
if ignore_limits:
|
||||
return statement
|
||||
|
||||
return statement.limit(params.limit).offset(params.offset)
|
||||
|
||||
|
||||
async def get_clients(
|
||||
session: AsyncSession,
|
||||
filter_query: ClientListParams,
|
||||
) -> ClientQueryResult:
|
||||
"""Get Clients."""
|
||||
count_statement = select(func.count("*")).select_from(Client)
|
||||
count_statement = cast(
|
||||
Select[tuple[int]],
|
||||
filter_client_statement(count_statement, filter_query, True),
|
||||
)
|
||||
|
||||
total_results = (await session.scalars(count_statement)).one()
|
||||
|
||||
statement = filter_client_statement(query_active_clients(), filter_query, False)
|
||||
|
||||
results = await session.scalars(statement)
|
||||
remainder = total_results - filter_query.offset - filter_query.limit
|
||||
if remainder < 0:
|
||||
remainder = 0
|
||||
|
||||
clients = list(results.all())
|
||||
clients_view = ClientView.from_client_list(clients)
|
||||
return ClientQueryResult(
|
||||
clients=clients_view,
|
||||
total_results=total_results,
|
||||
remaining_results=remainder,
|
||||
)
|
||||
|
||||
|
||||
async def new_client_version(
|
||||
session: AsyncSession,
|
||||
client_id: uuid.UUID,
|
||||
public_key: str,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
) -> ClientView:
|
||||
"""Update a client to a new version."""
|
||||
current_client = await get_client_by_id(session, client_id)
|
||||
if not current_client:
|
||||
raise ValueError("Client not found.")
|
||||
new_client = await create_new_client_version(session, current_client, public_key)
|
||||
if name:
|
||||
new_client.name = name
|
||||
if description:
|
||||
new_client.description = description
|
||||
|
||||
current_client.is_active = False
|
||||
session.add(current_client)
|
||||
|
||||
await session.commit()
|
||||
|
||||
return ClientView.from_client(new_client)
|
||||
@ -0,0 +1,122 @@
|
||||
"""Client router."""
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
#
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from sshecret_backend.types import AsyncDBSessionDep
|
||||
from sshecret_backend.api.clients.schemas import (
|
||||
ClientCreate,
|
||||
ClientListParams,
|
||||
ClientQueryResult,
|
||||
ClientUpdate,
|
||||
ClientView,
|
||||
ClientPolicyUpdate,
|
||||
ClientPolicyView,
|
||||
)
|
||||
from sshecret_backend.api.clients import operations
|
||||
from sshecret_backend.api.clients.operations import ClientOperations
|
||||
from sshecret_backend.api.common import FlexID
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_client_router(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
||||
"""Create client router."""
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/clients/")
|
||||
async def route_get_clients(
|
||||
filter_query: Annotated[ClientListParams, Query()],
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientQueryResult:
|
||||
return await operations.get_clients(session, filter_query)
|
||||
|
||||
@router.post("/clients/")
|
||||
async def create_client(
|
||||
request: Request,
|
||||
client: ClientCreate,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientView:
|
||||
"""Create client."""
|
||||
client_op = ClientOperations(session, request)
|
||||
return await client_op.create_client(client)
|
||||
|
||||
@router.get("/clients/{client_identifier}")
|
||||
async def fetch_client_by_name(
|
||||
request: Request,
|
||||
client_identifier: str,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
version: Annotated[int | None, Query()] = None,
|
||||
) -> ClientView:
|
||||
"""Fetch client by name."""
|
||||
client_id = FlexID.from_string(client_identifier)
|
||||
client_op = ClientOperations(session, request)
|
||||
return await client_op.get_client(client_id, version)
|
||||
|
||||
@router.delete("/clients/{client_identifier}")
|
||||
async def delete_client(
|
||||
request: Request,
|
||||
client_identifier: str,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> None:
|
||||
"""Delete a client."""
|
||||
client_id = FlexID.from_string(client_identifier)
|
||||
client_op = ClientOperations(session, request)
|
||||
await client_op.delete_client(client_id)
|
||||
|
||||
@router.post("/clients/{client_identifier}/public-key")
|
||||
async def update_client_public_key(
|
||||
request: Request,
|
||||
client_identifier: str,
|
||||
client_update: ClientUpdate,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientView:
|
||||
"""Update client public key."""
|
||||
client_id = FlexID.from_string(client_identifier)
|
||||
client_op = ClientOperations(session, request)
|
||||
return await client_op.new_client_version(client_id, client_update.public_key)
|
||||
|
||||
@router.put("/clients/{client_identifier}")
|
||||
async def update_client_by_name(
|
||||
request: Request,
|
||||
client_identifier: str,
|
||||
client_update: ClientCreate,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientView:
|
||||
"""Update client by name."""
|
||||
client_id = FlexID.from_string(client_identifier)
|
||||
client_op = ClientOperations(session, request)
|
||||
return await client_op.update_client(client_id, client_update)
|
||||
|
||||
@router.get("/clients/{client_identifier}/policies/")
|
||||
async def get_client_policies(
|
||||
request: Request,
|
||||
client_identifier: str,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientPolicyView:
|
||||
"""Get client policies."""
|
||||
client_id = FlexID.from_string(client_identifier)
|
||||
client_op = ClientOperations(session, request)
|
||||
return await client_op.get_client_policies(client_id)
|
||||
|
||||
@router.put("/clients/{client_identifier}/policies/")
|
||||
async def update_client_policies(
|
||||
request: Request,
|
||||
client_identifier: str,
|
||||
policy_update: ClientPolicyUpdate,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientPolicyView:
|
||||
"""Update client policies.
|
||||
|
||||
This is also how you delete policies.
|
||||
"""
|
||||
client_id = FlexID.from_string(client_identifier)
|
||||
client_op = ClientOperations(session, request)
|
||||
return await client_op.update_client_policies(client_id, policy_update)
|
||||
|
||||
return router
|
||||
@ -0,0 +1,137 @@
|
||||
"""Client related schemas."""
|
||||
|
||||
from typing import Annotated, Self
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import (
|
||||
AfterValidator,
|
||||
BaseModel,
|
||||
Field,
|
||||
IPvAnyAddress,
|
||||
IPvAnyNetwork,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
from sshecret.crypto import public_key_validator
|
||||
|
||||
from sshecret_backend import models
|
||||
|
||||
|
||||
class ClientView(BaseModel):
|
||||
"""View for a single client."""
|
||||
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
description: str | None = None
|
||||
public_key: str
|
||||
policies: list[str] = ["0.0.0.0/0", "::/0"]
|
||||
is_active: bool = True
|
||||
is_deleted: bool = False
|
||||
secrets: list[str] = Field(default_factory=list)
|
||||
created_at: datetime | None
|
||||
updated_at: datetime | None = None
|
||||
deleted_at: datetime | None = None
|
||||
|
||||
@classmethod
|
||||
def from_client_list(cls, clients: list[models.Client]) -> list[Self]:
|
||||
"""Generate a list of responses from a list of clients."""
|
||||
responses: list[Self] = [cls.from_client(client) for client in clients]
|
||||
return responses
|
||||
|
||||
@classmethod
|
||||
def from_client(cls, client: models.Client) -> Self:
|
||||
"""Instantiate from a client."""
|
||||
view = cls(
|
||||
id=client.id,
|
||||
name=client.name,
|
||||
description=client.description,
|
||||
public_key=client.public_key,
|
||||
created_at=client.created_at,
|
||||
updated_at=client.updated_at or None,
|
||||
deleted_at=client.deleted_at or None,
|
||||
is_active=client.is_active,
|
||||
is_deleted=client.is_deleted,
|
||||
)
|
||||
if client.secrets:
|
||||
view.secrets = [secret.name for secret in client.secrets]
|
||||
|
||||
if client.policies:
|
||||
view.policies = [policy.source for policy in client.policies]
|
||||
|
||||
return view
|
||||
|
||||
|
||||
class ClientQueryResult(BaseModel):
|
||||
"""Result class for queries towards the client list."""
|
||||
|
||||
clients: list[ClientView] = Field(default_factory=list)
|
||||
total_results: int
|
||||
remaining_results: int
|
||||
|
||||
|
||||
class ClientListParams(BaseModel):
|
||||
"""Client list parameters."""
|
||||
|
||||
limit: int = Field(100, gt=0, le=100)
|
||||
offset: int = Field(0, ge=0)
|
||||
id: uuid.UUID | None = None
|
||||
name: str | None = None
|
||||
name__like: str | None = None
|
||||
name__contains: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_expressions(self) -> Self:
|
||||
"""Validate mutually exclusive expression."""
|
||||
name_filter = False
|
||||
if self.name__like or self.name__contains:
|
||||
name_filter = True
|
||||
if self.name__like and self.name__contains:
|
||||
raise ValueError("You may only specify one name expression")
|
||||
if self.name and name_filter:
|
||||
raise ValueError(
|
||||
"You must either specify name or one of name__like or name__contains"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class ClientCreate(BaseModel):
|
||||
"""Model to create a client."""
|
||||
|
||||
name: str
|
||||
description: str | None = None
|
||||
public_key: Annotated[str, AfterValidator(public_key_validator)]
|
||||
|
||||
def to_client(self) -> models.Client:
|
||||
"""Instantiate a client."""
|
||||
return models.Client(
|
||||
name=self.name,
|
||||
public_key=self.public_key,
|
||||
description=self.description,
|
||||
)
|
||||
|
||||
|
||||
class ClientUpdate(BaseModel):
|
||||
"""Model to update the client public key."""
|
||||
|
||||
public_key: Annotated[str, AfterValidator(public_key_validator)]
|
||||
|
||||
|
||||
class ClientPolicyView(BaseModel):
|
||||
"""Update object for client policy."""
|
||||
|
||||
sources: list[str] = ["0.0.0.0/0", "::/0"]
|
||||
|
||||
@classmethod
|
||||
def from_client(cls, client: models.Client) -> Self:
|
||||
"""Create from client."""
|
||||
if not client.policies:
|
||||
return cls()
|
||||
return cls(sources=[policy.source for policy in client.policies])
|
||||
|
||||
|
||||
class ClientPolicyUpdate(BaseModel):
|
||||
"""Model for updating policies."""
|
||||
|
||||
sources: list[IPvAnyAddress | IPvAnyNetwork]
|
||||
@ -1,21 +1,58 @@
|
||||
"""Common helpers."""
|
||||
|
||||
import re
|
||||
from typing import Self
|
||||
import uuid
|
||||
import bcrypt
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
import bcrypt
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sshecret_backend.models import Client, ClientAccessPolicy
|
||||
|
||||
RE_UUID = re.compile(
|
||||
"^[0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}$"
|
||||
)
|
||||
|
||||
RelaxedId = uuid.UUID | str
|
||||
|
||||
|
||||
class IdType(Enum):
|
||||
"""Id type."""
|
||||
|
||||
ID = "id"
|
||||
NAME = "name"
|
||||
|
||||
|
||||
class FlexID(BaseModel):
|
||||
"""Flexible identifier."""
|
||||
|
||||
type: IdType
|
||||
value: RelaxedId
|
||||
|
||||
@classmethod
|
||||
def id(cls, id: RelaxedId) -> Self:
|
||||
"""Construct from ID."""
|
||||
return cls(type=IdType.ID, value=id)
|
||||
|
||||
@classmethod
|
||||
def name(cls, name: str) -> Self:
|
||||
"""Construct from name."""
|
||||
return cls(type=IdType.NAME, value=name)
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, value: str) -> Self:
|
||||
"""Convert from path string."""
|
||||
if value.startswith("id:"):
|
||||
return cls.id(value[3:])
|
||||
elif value.startswith("name:"):
|
||||
return cls.name(value[5:])
|
||||
return cls.name(value)
|
||||
|
||||
|
||||
@dataclass
|
||||
class NewClientVersion:
|
||||
@ -60,7 +97,10 @@ def client_with_relationships() -> Select[tuple[Client]]:
|
||||
|
||||
|
||||
async def resolve_client_id(
|
||||
session: AsyncSession, name: str, version: int | None = None, include_deleted: bool = False,
|
||||
session: AsyncSession,
|
||||
name: str,
|
||||
version: int | None = None,
|
||||
include_deleted: bool = False,
|
||||
) -> uuid.UUID | None:
|
||||
"""Get the ID of a client name."""
|
||||
if include_deleted:
|
||||
@ -123,7 +163,8 @@ async def get_client_by_name(session: AsyncSession, name: str) -> Client | None:
|
||||
async def refresh_client(session: AsyncSession, client: Client) -> None:
|
||||
"""Refresh the client and load in all relationships."""
|
||||
await session.refresh(
|
||||
client, attribute_names=["secrets", "policies", "previous_version", "updated_at"]
|
||||
client,
|
||||
attribute_names=["secrets", "policies", "previous_version", "updated_at"],
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -1,86 +0,0 @@
|
||||
"""Policies sub-api router factory."""
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Annotated
|
||||
|
||||
from sshecret_backend.models import ClientAccessPolicy
|
||||
from sshecret_backend.view_models import (
|
||||
ClientPolicyView,
|
||||
ClientPolicyUpdate,
|
||||
)
|
||||
from sshecret_backend.types import AsyncDBSessionDep
|
||||
from sshecret_backend import audit
|
||||
from .common import get_client_by_id_or_name, reload_client_with_relationships
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_policy_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
||||
"""Construct clients sub-api."""
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/clients/{name}/policies/")
|
||||
async def get_client_policies(
|
||||
name: str, session: Annotated[AsyncSession, Depends(get_db_session)]
|
||||
) -> ClientPolicyView:
|
||||
"""Get client policies."""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
|
||||
return ClientPolicyView.from_client(client)
|
||||
|
||||
@router.put("/clients/{name}/policies/")
|
||||
async def update_client_policies(
|
||||
request: Request,
|
||||
name: str,
|
||||
policy_update: ClientPolicyUpdate,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientPolicyView:
|
||||
"""Update client policies.
|
||||
|
||||
This is also how you delete policies.
|
||||
"""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
# Remove old policies.
|
||||
policies = await session.scalars(
|
||||
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
|
||||
)
|
||||
deleted_policies: list[ClientAccessPolicy] = []
|
||||
added_policies: list[ClientAccessPolicy] = []
|
||||
for policy in policies.all():
|
||||
await session.delete(policy)
|
||||
deleted_policies.append(policy)
|
||||
|
||||
LOG.debug("Updating client policies with: %r", policy_update.sources)
|
||||
for source in policy_update.sources:
|
||||
LOG.debug("Source %r", source)
|
||||
policy = ClientAccessPolicy(source=str(source), client_id=client.id)
|
||||
session.add(policy)
|
||||
added_policies.append(policy)
|
||||
|
||||
await session.flush()
|
||||
await session.commit()
|
||||
|
||||
client = await reload_client_with_relationships(session, client)
|
||||
for policy in deleted_policies:
|
||||
await audit.audit_remove_policy(session, request, client, policy)
|
||||
|
||||
for policy in added_policies:
|
||||
await audit.audit_update_policy(session, request, client, policy)
|
||||
|
||||
return ClientPolicyView.from_client(client)
|
||||
|
||||
return router
|
||||
@ -0,0 +1,9 @@
|
||||
"""Common API schemas."""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BodyValue(BaseModel):
|
||||
"""A generic model with just a value parameter."""
|
||||
|
||||
value: str
|
||||
@ -1,257 +0,0 @@
|
||||
"""Secrets sub-api factory."""
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy import select
|
||||
from typing import Annotated
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from sshecret_backend.models import Client, ClientSecret
|
||||
from sshecret_backend.view_models import (
|
||||
ClientReference,
|
||||
ClientSecretDetailList,
|
||||
ClientSecretList,
|
||||
ClientSecretPublic,
|
||||
BodyValue,
|
||||
ClientSecretResponse,
|
||||
)
|
||||
from sshecret_backend import audit
|
||||
from sshecret_backend.types import AsyncDBSessionDep
|
||||
from .common import get_client_by_id_or_name, get_client_by_name
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def lookup_client_secret(
|
||||
session: AsyncSession, client: Client, name: str
|
||||
) -> ClientSecret | None:
|
||||
"""Look up a secret for a client."""
|
||||
statement = (
|
||||
select(ClientSecret)
|
||||
.where(ClientSecret.client_id == client.id)
|
||||
.where(ClientSecret.name == name)
|
||||
)
|
||||
results = await session.scalars(statement)
|
||||
return results.first()
|
||||
|
||||
|
||||
def get_secrets_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
||||
"""Construct clients sub-api."""
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/clients/{name}/secrets/")
|
||||
async def add_secret_to_client(
|
||||
request: Request,
|
||||
name: str,
|
||||
client_secret: ClientSecretPublic,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> None:
|
||||
"""Add secret to a client."""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
|
||||
existing_secret = await lookup_client_secret(
|
||||
session, client, client_secret.name
|
||||
)
|
||||
if existing_secret:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot add a secret. A different secret with the same name already exists.",
|
||||
)
|
||||
db_secret = ClientSecret(
|
||||
name=client_secret.name, client_id=client.id, secret=client_secret.secret
|
||||
)
|
||||
session.add(db_secret)
|
||||
await session.commit()
|
||||
await session.refresh(db_secret)
|
||||
await audit.audit_create_secret(session, request, client, db_secret)
|
||||
|
||||
@router.put("/clients/{name}/secrets/{secret_name}")
|
||||
async def update_client_secret(
|
||||
request: Request,
|
||||
name: str,
|
||||
secret_name: str,
|
||||
secret_data: BodyValue,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientSecretResponse:
|
||||
"""Update a client secret.
|
||||
|
||||
This can also be used for destructive creates.
|
||||
"""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
|
||||
existing_secret = await lookup_client_secret(session, client, secret_name)
|
||||
if existing_secret:
|
||||
existing_secret.secret = secret_data.value
|
||||
|
||||
session.add(existing_secret)
|
||||
await session.commit()
|
||||
await session.refresh(existing_secret)
|
||||
await audit.audit_update_secret(session, request, client, existing_secret)
|
||||
return ClientSecretResponse.from_client_secret(existing_secret)
|
||||
|
||||
db_secret = ClientSecret(
|
||||
name=secret_name,
|
||||
client_id=client.id,
|
||||
secret=secret_data.value,
|
||||
)
|
||||
session.add(db_secret)
|
||||
await session.commit()
|
||||
await session.refresh(db_secret)
|
||||
await audit.audit_create_secret(session, request, client, db_secret)
|
||||
return ClientSecretResponse.from_client_secret(db_secret)
|
||||
|
||||
@router.get("/clients/{name}/secrets/{secret_name}")
|
||||
async def request_client_secret(
|
||||
request: Request,
|
||||
name: str,
|
||||
secret_name: str,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientSecretResponse:
|
||||
"""Get a client secret."""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
|
||||
secret = await lookup_client_secret(session, client, secret_name)
|
||||
if not secret:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a secret with the given name."
|
||||
)
|
||||
|
||||
response_model = ClientSecretResponse.from_client_secret(secret)
|
||||
await audit.audit_access_secret(session, request, client, secret)
|
||||
return response_model
|
||||
|
||||
@router.delete("/clients/{name}/secrets/{secret_name}")
|
||||
async def delete_client_secret(
|
||||
request: Request,
|
||||
name: str,
|
||||
secret_name: str,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> None:
|
||||
"""Delete a secret."""
|
||||
client = await get_client_by_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
|
||||
secret = await lookup_client_secret(session, client, secret_name)
|
||||
if not secret:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a secret with the given name."
|
||||
)
|
||||
|
||||
await session.delete(secret)
|
||||
await session.commit()
|
||||
await audit.audit_delete_secret(session, request, client, secret)
|
||||
|
||||
@router.get("/secrets/")
|
||||
async def get_secret_map(
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> list[ClientSecretList]:
|
||||
"""Get a list of all secrets and which clients have them."""
|
||||
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
|
||||
client_secrets = await session.scalars(
|
||||
select(ClientSecret).options(selectinload(ClientSecret.client))
|
||||
)
|
||||
for client_secret in client_secrets.all():
|
||||
if not client_secret.client:
|
||||
if client_secret.name not in client_secret_map:
|
||||
client_secret_map[client_secret.name] = []
|
||||
continue
|
||||
client_secret_map[client_secret.name].append(client_secret.client.name)
|
||||
# audit.audit_client_secret_list(session, request)
|
||||
return [
|
||||
ClientSecretList(name=secret_name, clients=clients)
|
||||
for secret_name, clients in client_secret_map.items()
|
||||
]
|
||||
|
||||
@router.get("/secrets/detailed/")
|
||||
async def get_detailed_secret_map(
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> list[ClientSecretDetailList]:
|
||||
"""Get a list of all secrets and which clients have them."""
|
||||
client_secrets: dict[str, ClientSecretDetailList] = {}
|
||||
all_client_secrets = await session.execute(
|
||||
select(ClientSecret).options(selectinload(ClientSecret.client))
|
||||
)
|
||||
for client_secret in all_client_secrets.scalars().all():
|
||||
if client_secret.name not in client_secrets:
|
||||
client_secrets[client_secret.name] = ClientSecretDetailList(
|
||||
name=client_secret.name
|
||||
)
|
||||
client_secrets[client_secret.name].ids.append(str(client_secret.id))
|
||||
if not client_secret.client:
|
||||
continue
|
||||
client_secrets[client_secret.name].clients.append(
|
||||
ClientReference(
|
||||
id=str(client_secret.client.id), name=client_secret.client.name
|
||||
)
|
||||
)
|
||||
# `audit.audit_client_secret_list(session, request)
|
||||
return list(client_secrets.values())
|
||||
|
||||
@router.get("/secrets/{name}")
|
||||
async def get_secret_clients(
|
||||
name: str,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientSecretList:
|
||||
"""Get a list of which clients has a named secret."""
|
||||
clients: list[str] = []
|
||||
client_secrets = await session.scalars(
|
||||
select(ClientSecret)
|
||||
.join(ClientSecret.client)
|
||||
.options(selectinload(ClientSecret.client))
|
||||
.where(ClientSecret.name == name)
|
||||
.where(Client.is_active.is_(True))
|
||||
)
|
||||
for client_secret in client_secrets.all():
|
||||
if not client_secret.client:
|
||||
continue
|
||||
clients.append(client_secret.client.name)
|
||||
|
||||
return ClientSecretList(name=name, clients=clients)
|
||||
|
||||
@router.get("/secrets/{name}/detailed")
|
||||
async def get_secret_clients_detailed(
|
||||
name: str,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientSecretDetailList:
|
||||
"""Get a list of which clients has a named secret."""
|
||||
detail_list = ClientSecretDetailList(name=name)
|
||||
client_secrets = await session.scalars(
|
||||
select(ClientSecret)
|
||||
.options(selectinload(ClientSecret.client))
|
||||
.where(ClientSecret.name == name)
|
||||
.where(ClientSecret.client.is_(Client.is_active))
|
||||
)
|
||||
for client_secret in client_secrets.all():
|
||||
if not client_secret.client:
|
||||
continue
|
||||
detail_list.ids.append(str(client_secret.id))
|
||||
detail_list.clients.append(
|
||||
ClientReference(
|
||||
id=str(client_secret.client.id), name=client_secret.client.name
|
||||
)
|
||||
)
|
||||
|
||||
return detail_list
|
||||
|
||||
return router
|
||||
@ -0,0 +1 @@
|
||||
|
||||
@ -0,0 +1,328 @@
|
||||
"""ClientSecret operations."""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from sqlalchemy import Null, distinct, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sshecret_backend import audit
|
||||
from sshecret_backend.api.common import (
|
||||
FlexID,
|
||||
IdType,
|
||||
get_client_by_id,
|
||||
resolve_client_id,
|
||||
)
|
||||
from sshecret_backend.api.secrets.schemas import (
|
||||
ClientReference,
|
||||
ClientSecretDetailList,
|
||||
ClientSecretResponse,
|
||||
ClientSecretStats,
|
||||
DistinctClientSecretStats,
|
||||
)
|
||||
from sshecret_backend.models import Client, ClientSecret
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
RelaxedId = uuid.UUID | str
|
||||
|
||||
|
||||
def _id(id: RelaxedId) -> uuid.UUID:
|
||||
"""Ensure that the ID is a uuid."""
|
||||
if isinstance(id, str):
|
||||
return uuid.UUID(id)
|
||||
return id
|
||||
|
||||
|
||||
class ClientSecretOperations:
|
||||
"""Perform operations on a client's ClientSecrets"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
request: Request,
|
||||
client: FlexID,
|
||||
include_deleted: bool = False,
|
||||
) -> None:
|
||||
"""Create operations class."""
|
||||
self._client_id: uuid.UUID | None = None
|
||||
self.client_name: str | None = None
|
||||
self.client: Client | None = None
|
||||
if client.type is IdType.ID:
|
||||
self._client_id = _id(client.value)
|
||||
else:
|
||||
self.client_name = str(client.value)
|
||||
|
||||
self.include_deleted: bool = include_deleted
|
||||
|
||||
self.session: AsyncSession = session
|
||||
self.request: Request = request
|
||||
|
||||
async def get_client_id(self) -> uuid.UUID | None:
|
||||
"""Get client ID."""
|
||||
if self._client_id:
|
||||
LOG.debug("Returning previously resolved client ID.")
|
||||
return self._client_id
|
||||
|
||||
if not self.client_name:
|
||||
raise RuntimeError("Error: No client information registered.")
|
||||
|
||||
client_id = await resolve_client_id(
|
||||
self.session, self.client_name, include_deleted=self.include_deleted
|
||||
)
|
||||
self._client_id = client_id
|
||||
LOG.debug("Saving client ID %s", client_id)
|
||||
return client_id
|
||||
|
||||
async def get_client(self) -> Client | None:
|
||||
"""Get client."""
|
||||
if self.client:
|
||||
return self.client
|
||||
client_id = await self.get_client_id()
|
||||
if not client_id:
|
||||
return None
|
||||
|
||||
client = await get_client_by_id(self.session, client_id)
|
||||
self.client = client
|
||||
return client
|
||||
|
||||
async def _get_client_secret(
|
||||
self, secret_identifier: FlexID
|
||||
) -> ClientSecret | None:
|
||||
"""Get client secret private method."""
|
||||
client = await self.get_client()
|
||||
if not client:
|
||||
LOG.debug("Client lookup failed.")
|
||||
return None
|
||||
|
||||
if secret_identifier.type is IdType.ID:
|
||||
match_id = _id(secret_identifier.value)
|
||||
LOG.debug("Searching for secrets matching ID %r", match_id)
|
||||
matches = [secret for secret in client.secrets if secret.id == match_id]
|
||||
else:
|
||||
LOG.debug("Searching for secrets matching name")
|
||||
matches = [
|
||||
secret
|
||||
for secret in client.secrets
|
||||
if secret.name == str(secret_identifier.value)
|
||||
]
|
||||
|
||||
for secret_match in matches:
|
||||
if secret_match.deleted:
|
||||
# TODO: Add override for deleted.
|
||||
LOG.debug("Found deleted secret")
|
||||
continue
|
||||
|
||||
await audit.audit_access_secret(
|
||||
self.session, self.request, client, secret_match
|
||||
)
|
||||
LOG.debug("Found matching secret")
|
||||
return secret_match
|
||||
|
||||
LOG.debug("No secrets matched.")
|
||||
return None
|
||||
|
||||
async def get_client_secret(
|
||||
self, secret_identifier: FlexID
|
||||
) -> ClientSecretResponse:
|
||||
"""Get client secret public model."""
|
||||
client_secret = await self._get_client_secret(secret_identifier)
|
||||
if not client_secret:
|
||||
raise HTTPException(status_code=404, detail="Unknown client or secret")
|
||||
return ClientSecretResponse.from_client_secret(client_secret)
|
||||
|
||||
async def create_client_secret(
|
||||
self, name: str, value: str, description: str | None = None
|
||||
) -> ClientSecretResponse:
|
||||
"""Create client Secret."""
|
||||
client = await self.get_client()
|
||||
if not client:
|
||||
raise HTTPException(status_code=404, detail="Client not found.")
|
||||
|
||||
existing_secret = await self._get_client_secret(FlexID.name(name))
|
||||
if existing_secret:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot add a secret. A different secret with the same name already exists.",
|
||||
)
|
||||
secret = ClientSecret(
|
||||
name=name,
|
||||
description=description,
|
||||
client=client,
|
||||
secret=value,
|
||||
)
|
||||
self.session.add(secret)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(secret, attribute_names=["client", "updated_at"])
|
||||
await audit.audit_create_secret(self.session, self.request, client, secret)
|
||||
return ClientSecretResponse.from_client_secret(secret)
|
||||
|
||||
async def update_client_secret_value(
|
||||
self, secret_identifier: FlexID, value: str
|
||||
) -> ClientSecretResponse:
|
||||
"""Update client secret value."""
|
||||
client_secret = await self._get_client_secret(secret_identifier)
|
||||
if not client_secret:
|
||||
# We can use this method to create a secret too.
|
||||
if secret_identifier.type is IdType.NAME:
|
||||
return await self.create_client_secret(
|
||||
str(secret_identifier.value), value
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="Unknown client or secret")
|
||||
client_secret.secret = value
|
||||
self.session.add(client_secret)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(client_secret, ["updated_at"])
|
||||
await audit.audit_update_secret(
|
||||
self.session, self.request, client_secret.client, client_secret
|
||||
)
|
||||
return ClientSecretResponse.from_client_secret(client_secret)
|
||||
|
||||
async def delete_client_secret(self, secret_identifier: FlexID) -> None:
|
||||
"""Delete a client secret."""
|
||||
client_secret = await self._get_client_secret(secret_identifier)
|
||||
if not client_secret:
|
||||
return
|
||||
|
||||
client_secret.deleted = True
|
||||
client_secret.deleted_at = datetime.now(timezone.utc)
|
||||
self.session.add(client_secret)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(client_secret, ["deleted", "deleted_at"])
|
||||
await audit.audit_delete_secret(
|
||||
self.session, self.request, client_secret.client, client_secret
|
||||
)
|
||||
|
||||
|
||||
async def resolve_client_secret_mapping(
|
||||
session: AsyncSession,
|
||||
) -> list[ClientSecretDetailList]:
|
||||
"""Resolve mapping of clients to secrets."""
|
||||
result = await session.execute(
|
||||
select(ClientSecret)
|
||||
.join(ClientSecret.client)
|
||||
.options(selectinload(ClientSecret.client))
|
||||
.where(Client.is_active.is_not(False))
|
||||
.where(ClientSecret.deleted.is_not(True))
|
||||
)
|
||||
client_secrets: dict[str, ClientSecretDetailList] = {}
|
||||
for secret in result.scalars().all():
|
||||
if secret.name not in client_secrets:
|
||||
client_secrets[secret.name] = ClientSecretDetailList(name=secret.name)
|
||||
client_secrets[secret.name].ids.append(str(secret.id))
|
||||
if not secret.client:
|
||||
continue
|
||||
client_secrets[secret.name].clients.append(
|
||||
ClientReference(id=str(secret.client.id), name=secret.client.name)
|
||||
)
|
||||
|
||||
return list(client_secrets.values())
|
||||
|
||||
|
||||
async def resolve_client_secret_clients(
|
||||
session: AsyncSession, name: str, include_deleted: bool = False
|
||||
) -> ClientSecretDetailList | None:
|
||||
"""Resolve client association to a secret."""
|
||||
statement = (
|
||||
select(ClientSecret)
|
||||
.options(selectinload(ClientSecret.client))
|
||||
.where(ClientSecret.name == name)
|
||||
)
|
||||
if not include_deleted:
|
||||
statement = statement.where(ClientSecret.deleted.is_not(True))
|
||||
|
||||
results = await session.execute(statement)
|
||||
clients: ClientSecretDetailList | None = None
|
||||
for client_secret in results.scalars().all():
|
||||
if not clients:
|
||||
# Ensure we don't create the object before we have at least one client.
|
||||
clients = ClientSecretDetailList(name=name)
|
||||
clients.ids.append(str(client_secret.id))
|
||||
if client_secret.client:
|
||||
clients.clients.append(
|
||||
ClientReference(
|
||||
id=str(client_secret.client.id), name=client_secret.client.name
|
||||
)
|
||||
)
|
||||
|
||||
return clients
|
||||
|
||||
|
||||
async def _get_secrets_total(session: AsyncSession) -> int:
|
||||
"""Get total amount of secrets excluding deleted."""
|
||||
statement = (
|
||||
select(func.count())
|
||||
.select_from(ClientSecret)
|
||||
.where(ClientSecret.deleted.is_not(True))
|
||||
)
|
||||
result = await session.execute(statement)
|
||||
return result.scalar_one()
|
||||
|
||||
|
||||
async def _get_distinct_secret_count(session: AsyncSession) -> int:
|
||||
"""Get the amount of distinct secrets, excluding deleted."""
|
||||
statement = (
|
||||
select(func.count())
|
||||
.select_from(ClientSecret)
|
||||
.where(ClientSecret.deleted.is_not(True))
|
||||
.distinct()
|
||||
)
|
||||
result = await session.execute(statement)
|
||||
return result.scalar_one()
|
||||
|
||||
|
||||
async def _get_secret_client_stats(
|
||||
session: AsyncSession,
|
||||
) -> list[DistinctClientSecretStats]:
|
||||
"""Get stats for each named secret."""
|
||||
statement = (
|
||||
(
|
||||
select(
|
||||
ClientSecret.name, func.count(distinct(Client.id)).label("client_count")
|
||||
)
|
||||
)
|
||||
.join(Client, ClientSecret.client_id)
|
||||
.where(ClientSecret.deleted.is_not(True), Client.is_deleted.is_not(True))
|
||||
.group_by(ClientSecret.name)
|
||||
)
|
||||
results = await session.execute(statement)
|
||||
stats: list[DistinctClientSecretStats] = []
|
||||
rows = results.all()
|
||||
for row in rows:
|
||||
secret_name, client_count = row.tuple()
|
||||
stats.append(DistinctClientSecretStats(name=secret_name, clients=client_count))
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
async def _get_secrets_without_clients(
|
||||
session: AsyncSession,
|
||||
) -> int:
|
||||
"""Get secrets without clients."""
|
||||
statement = (
|
||||
select(func.count(distinct(ClientSecret.name)))
|
||||
.where(ClientSecret.deleted.is_not(True))
|
||||
.where(ClientSecret.client_id.is_(Null))
|
||||
)
|
||||
result = await session.execute(statement)
|
||||
return result.scalar_one()
|
||||
|
||||
|
||||
async def get_client_secret_stats(session: AsyncSession) -> ClientSecretStats:
|
||||
"""Get stats for the client secrets.
|
||||
|
||||
TODO: Implement a route with this.
|
||||
"""
|
||||
distinct_secrets = await _get_distinct_secret_count(session)
|
||||
total_secrets = await _get_secrets_total(session)
|
||||
secrets_without_clients = await _get_secrets_without_clients(session)
|
||||
client_stats = await _get_secret_client_stats(session)
|
||||
|
||||
return ClientSecretStats(
|
||||
distinct_secrets=distinct_secrets,
|
||||
total_secrets=total_secrets,
|
||||
secrets_without_clients=secrets_without_clients,
|
||||
client_stats=client_stats,
|
||||
)
|
||||
@ -0,0 +1,107 @@
|
||||
"""Client Secret Router."""
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from sshecret_backend.api.common import FlexID
|
||||
from sshecret_backend.types import AsyncDBSessionDep
|
||||
from sshecret_backend.api.secrets.operations import (
|
||||
ClientSecretOperations,
|
||||
resolve_client_secret_clients,
|
||||
resolve_client_secret_mapping,
|
||||
)
|
||||
from sshecret_backend.api.schemas import BodyValue
|
||||
from sshecret_backend.api.secrets.schemas import (
|
||||
ClientSecretDetailList,
|
||||
ClientSecretPublic,
|
||||
ClientSecretResponse,
|
||||
)
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_client_secrets_router(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
||||
"""Create client secret router."""
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/clients/{client_identifier}/secrets/")
|
||||
async def add_secret_to_client(
|
||||
request: Request,
|
||||
client_identifier: str,
|
||||
client_secret: ClientSecretPublic,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> None:
|
||||
"""Add secret to a client."""
|
||||
client = FlexID.from_string(client_identifier)
|
||||
client_op = ClientSecretOperations(session, request, client)
|
||||
await client_op.create_client_secret(
|
||||
client_secret.name, client_secret.secret, client_secret.description
|
||||
)
|
||||
|
||||
@router.put("/clients/{client_identifier}/secrets/{secret_identifier}")
|
||||
async def update_client_secret(
|
||||
request: Request,
|
||||
client_identifier: str,
|
||||
secret_identifier: str,
|
||||
secret_data: BodyValue,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientSecretResponse:
|
||||
"""Update client secret."""
|
||||
client = FlexID.from_string(client_identifier)
|
||||
secret = FlexID.from_string(secret_identifier)
|
||||
client_op = ClientSecretOperations(session, request, client)
|
||||
return await client_op.update_client_secret_value(secret, secret_data.value)
|
||||
|
||||
@router.get("/clients/{client_identifier}/secrets/{secret_identifier}")
|
||||
async def request_client_secret_named(
|
||||
request: Request,
|
||||
client_identifier: str,
|
||||
secret_identifier: str,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientSecretResponse:
|
||||
"""Get a named secret from a named client."""
|
||||
client = FlexID.from_string(client_identifier)
|
||||
secret = FlexID.from_string(secret_identifier)
|
||||
LOG.debug("Resolved client FlexID: %r", client)
|
||||
LOG.debug("Resolved secret FlexID: %r", secret)
|
||||
client_op = ClientSecretOperations(session, request, client)
|
||||
return await client_op.get_client_secret(secret)
|
||||
|
||||
# TODO: delete_client_secret
|
||||
@router.delete("/clients/{client_identifier}/secrets/{secret_identifier}")
|
||||
async def delete_client_secret(
|
||||
request: Request,
|
||||
client_identifier: str,
|
||||
secret_identifier: str,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> None:
|
||||
"""Delete client secret."""
|
||||
client = FlexID.from_string(client_identifier)
|
||||
secret = FlexID.from_string(secret_identifier)
|
||||
client_op = ClientSecretOperations(session, request, client)
|
||||
await client_op.delete_client_secret(secret)
|
||||
|
||||
@router.get("/secrets/")
|
||||
async def get_secret_map(
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> list[ClientSecretDetailList]:
|
||||
"""Get a list of secrets and which clients have them."""
|
||||
return await resolve_client_secret_mapping(session)
|
||||
|
||||
@router.get("/secrets/{name}")
|
||||
async def get_secret_clients(
|
||||
name: str,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientSecretDetailList:
|
||||
"""Get a list of which clients has a named secret."""
|
||||
result = await resolve_client_secret_clients(session, name)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Secret not found.")
|
||||
return result
|
||||
|
||||
return router
|
||||
@ -1,21 +1,12 @@
|
||||
"""Models for API views."""
|
||||
"""Schemas for the secrets router."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Self, override
|
||||
from collections.abc import Sequence
|
||||
import uuid
|
||||
|
||||
from pydantic import BaseModel, Field, IPvAnyAddress, IPvAnyNetwork
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
from . import models
|
||||
|
||||
|
||||
|
||||
class BodyValue(BaseModel):
|
||||
"""A generic model with just a value parameter."""
|
||||
|
||||
value: str
|
||||
from sshecret_backend import models
|
||||
|
||||
|
||||
class ClientSecretPublic(BaseModel):
|
||||
@ -38,6 +29,7 @@ class ClientSecretPublic(BaseModel):
|
||||
class ClientSecretResponse(ClientSecretPublic):
|
||||
"""A secret view."""
|
||||
|
||||
id: uuid.UUID
|
||||
created_at: datetime | None
|
||||
updated_at: datetime | None = None
|
||||
|
||||
@ -47,6 +39,7 @@ class ClientSecretResponse(ClientSecretPublic):
|
||||
"""Instantiate from ClientSecret."""
|
||||
|
||||
return cls(
|
||||
id=client_secret.id,
|
||||
name=client_secret.name,
|
||||
secret=client_secret.secret,
|
||||
created_at=client_secret.created_at,
|
||||
@ -54,25 +47,6 @@ class ClientSecretResponse(ClientSecretPublic):
|
||||
)
|
||||
|
||||
|
||||
class ClientPolicyView(BaseModel):
|
||||
"""Update object for client policy."""
|
||||
|
||||
sources: list[str] = ["0.0.0.0/0", "::/0"]
|
||||
|
||||
@classmethod
|
||||
def from_client(cls, client: models.Client) -> Self:
|
||||
"""Create from client."""
|
||||
if not client.policies:
|
||||
return cls()
|
||||
return cls(sources=[policy.source for policy in client.policies])
|
||||
|
||||
|
||||
class ClientPolicyUpdate(BaseModel):
|
||||
"""Model for updating policies."""
|
||||
|
||||
sources: list[IPvAnyAddress | IPvAnyNetwork]
|
||||
|
||||
|
||||
class ClientSecretList(BaseModel):
|
||||
"""Model for aggregating identically named secrets."""
|
||||
|
||||
@ -95,31 +69,20 @@ class ClientSecretDetailList(BaseModel):
|
||||
clients: list[ClientReference] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AuditView(BaseModel):
|
||||
"""Audit log view."""
|
||||
class DistinctClientSecretStats(BaseModel):
|
||||
"""Stats for distinct client secrets."""
|
||||
|
||||
name: str
|
||||
clients: int = 0
|
||||
|
||||
|
||||
id: uuid.UUID | None = None
|
||||
subsystem: models.SubSystem
|
||||
message: str
|
||||
operation: models.Operation
|
||||
data: dict[str, str] | None = None
|
||||
client_id: uuid.UUID | None = None
|
||||
client_name: str | None = None
|
||||
secret_id: uuid.UUID | None = None
|
||||
secret_name: str | None = None
|
||||
origin: str | None = None
|
||||
timestamp: datetime | None = None
|
||||
class ClientSecretStats(BaseModel):
|
||||
"""Stats row for the clientSecret model.
|
||||
|
||||
Useful for pagination and statistics.
|
||||
"""
|
||||
|
||||
class AuditInfo(BaseModel):
|
||||
"""Information about audit information."""
|
||||
|
||||
entries: int
|
||||
|
||||
|
||||
class AuditListResult(BaseModel):
|
||||
"""Class to return when listing audit entries."""
|
||||
results: Sequence[AuditView]
|
||||
total: int
|
||||
remaining: int
|
||||
distinct_secrets: int
|
||||
total_secrets: int
|
||||
secrets_without_clients: int
|
||||
client_stats: list[DistinctClientSecretStats] = Field(default_factory=list)
|
||||
@ -5,7 +5,14 @@ from fastapi import Request
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy, Operation, SubSystem
|
||||
from .models import (
|
||||
AuditLog,
|
||||
Client,
|
||||
ClientSecret,
|
||||
ClientAccessPolicy,
|
||||
Operation,
|
||||
SubSystem,
|
||||
)
|
||||
|
||||
|
||||
def _get_origin(request: Request) -> str | None:
|
||||
@ -128,6 +135,27 @@ async def audit_update_client(
|
||||
await _write_audit_log(session, request, entry, commit)
|
||||
|
||||
|
||||
async def audit_new_client_version(
|
||||
session: AsyncSession,
|
||||
request: Request,
|
||||
old_client: Client,
|
||||
new_client: Client,
|
||||
commit: bool = True,
|
||||
) -> None:
|
||||
"""Audit an update secret event."""
|
||||
entry = AuditLog(
|
||||
operation=Operation.UPDATE,
|
||||
client_id=old_client.id,
|
||||
client_name=old_client.name,
|
||||
message="Client data updated",
|
||||
data={
|
||||
"new_client_id": str(new_client.id),
|
||||
"new_client_version": new_client.version,
|
||||
},
|
||||
)
|
||||
await _write_audit_log(session, request, entry, commit)
|
||||
|
||||
|
||||
async def audit_update_secret(
|
||||
session: AsyncSession,
|
||||
request: Request,
|
||||
@ -224,6 +252,7 @@ async def audit_access_secret(
|
||||
)
|
||||
await _write_audit_log(session, request, entry, commit)
|
||||
|
||||
|
||||
async def audit_client_secret_list(
|
||||
session: AsyncSession, request: Request, commit: bool = True
|
||||
) -> None:
|
||||
@ -233,4 +262,3 @@ async def audit_client_secret_list(
|
||||
message="All secret names and their clients was viewed",
|
||||
)
|
||||
await _write_audit_log(session, request, entry, commit)
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
import bcrypt
|
||||
|
||||
|
||||
def hash_token(token: str) -> str:
|
||||
"""Hash a token."""
|
||||
pwbytes = token.encode("utf-8")
|
||||
|
||||
@ -9,7 +9,8 @@ from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sshecret_backend.db import DatabaseSessionManager
|
||||
from sshecret_backend.settings import BackendSettings
|
||||
from .api import get_audit_api, get_policy_api, get_secrets_api
|
||||
from .api.audit.router import create_audit_router
|
||||
from .api.secrets.router import create_client_secrets_router
|
||||
from .api.clients.router import create_client_router
|
||||
from .auth import verify_token
|
||||
from .models import (
|
||||
@ -60,9 +61,8 @@ def get_backend_api(
|
||||
dependencies=[Depends(validate_token)],
|
||||
)
|
||||
|
||||
backend_api.include_router(get_audit_api(get_db_session))
|
||||
backend_api.include_router(create_audit_router(get_db_session))
|
||||
backend_api.include_router(create_client_router(get_db_session))
|
||||
backend_api.include_router(get_policy_api(get_db_session))
|
||||
backend_api.include_router(get_secrets_api(get_db_session))
|
||||
backend_api.include_router(create_client_secrets_router(get_db_session))
|
||||
|
||||
return backend_api
|
||||
|
||||
@ -8,7 +8,13 @@ from collections.abc import AsyncIterator, Generator, Callable
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Literal
|
||||
from sqlalchemy import create_engine, Engine, event, select
|
||||
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession, async_sessionmaker, create_async_engine, AsyncEngine
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncConnection,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
AsyncEngine,
|
||||
)
|
||||
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
|
||||
@ -21,11 +27,14 @@ from .models import APIClient, SubSystem
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
class DatabaseSessionManager:
|
||||
def __init__(self, host: URL, **engine_kwargs: str):
|
||||
self._engine: AsyncEngine | None = get_async_engine(host)
|
||||
self._sessionmaker: async_sessionmaker[AsyncSession] | None = async_sessionmaker(autocommit=False, bind=self._engine, expire_on_commit=False)
|
||||
self._sessionmaker: async_sessionmaker[AsyncSession] | None = (
|
||||
async_sessionmaker(
|
||||
autocommit=False, bind=self._engine, expire_on_commit=False
|
||||
)
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
if self._engine is None:
|
||||
@ -81,7 +90,6 @@ def setup_database(
|
||||
return engine, get_db_session
|
||||
|
||||
|
||||
|
||||
def get_engine(url: URL, echo: bool = False) -> Engine:
|
||||
"""Initialize the engine."""
|
||||
engine = create_engine(url, echo=echo)
|
||||
@ -102,6 +110,7 @@ def get_async_engine(url: URL, echo: bool = False, **engine_kwargs: str) -> Asyn
|
||||
"""Get an async engine."""
|
||||
engine = create_async_engine(url, echo=echo, **engine_kwargs)
|
||||
if url.drivername.startswith("sqlite+"):
|
||||
|
||||
@event.listens_for(engine.sync_engine, "connect")
|
||||
def set_sqlite_pragma(
|
||||
dbapi_connection: sqlite3.Connection, _connection_record: object
|
||||
@ -113,18 +122,21 @@ def get_async_engine(url: URL, echo: bool = False, **engine_kwargs: str) -> Asyn
|
||||
return engine
|
||||
|
||||
|
||||
|
||||
def create_api_token_with_value(session: Session, token: str, subsystem: Literal["admin", "sshd"]) -> None:
|
||||
def create_api_token_with_value(
|
||||
session: Session, token: str, subsystem: Literal["admin", "sshd"]
|
||||
) -> None:
|
||||
"""Create API token with a given value."""
|
||||
|
||||
existing = session.scalars(select(APIClient).where(APIClient.subsystem == SubSystem(subsystem))).first()
|
||||
existing = session.scalars(
|
||||
select(APIClient).where(APIClient.subsystem == SubSystem(subsystem))
|
||||
).first()
|
||||
if existing:
|
||||
if verify_token(token, existing.token):
|
||||
LOG.info("Token is up to date.")
|
||||
return
|
||||
LOG.info("Updating token value for subsystem %s", subsystem)
|
||||
hashed = hash_token(token)
|
||||
existing.token=hashed
|
||||
existing.token = hashed
|
||||
session.commit()
|
||||
return
|
||||
|
||||
@ -135,12 +147,19 @@ def create_api_token_with_value(session: Session, token: str, subsystem: Literal
|
||||
session.add(api_token)
|
||||
session.commit()
|
||||
|
||||
def create_api_token(session: Session, subsystem: Literal["admin", "sshd", "test"], recreate: bool = False) -> str:
|
||||
|
||||
def create_api_token(
|
||||
session: Session,
|
||||
subsystem: Literal["admin", "sshd", "test"],
|
||||
recreate: bool = False,
|
||||
) -> str:
|
||||
"""Create API token."""
|
||||
subsys = SubSystem(subsystem)
|
||||
token = secrets.token_urlsafe(32)
|
||||
hashed = hash_token(token)
|
||||
if existing := session.scalars(select(APIClient).where(APIClient.subsystem == subsys)).first():
|
||||
if existing := session.scalars(
|
||||
select(APIClient).where(APIClient.subsystem == subsys)
|
||||
).first():
|
||||
if not recreate:
|
||||
raise RuntimeError("Error: A token already exist for this subsystem.")
|
||||
existing.token = hashed
|
||||
|
||||
@ -79,8 +79,7 @@ class Client(Base):
|
||||
)
|
||||
|
||||
deleted_at: Mapped[datetime | None] = mapped_column(
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=True
|
||||
sa.DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
secrets: Mapped[list["ClientSecret"]] = relationship(
|
||||
@ -93,9 +92,7 @@ class Client(Base):
|
||||
nullable=True,
|
||||
)
|
||||
previous_version: Mapped["Client | None"] = relationship(
|
||||
"Client",
|
||||
remote_side=[id],
|
||||
backref="versions"
|
||||
"Client", remote_side=[id], backref="versions"
|
||||
)
|
||||
|
||||
policies: Mapped[list["ClientAccessPolicy"]] = relationship(back_populates="client")
|
||||
@ -142,7 +139,7 @@ class ClientSecret(Base):
|
||||
sa.Uuid(as_uuid=True), sa.ForeignKey("client.id", ondelete="CASCADE")
|
||||
)
|
||||
client: Mapped[Client] = relationship(back_populates="secrets")
|
||||
invalidated: Mapped[bool] = mapped_column(default=False)
|
||||
deleted: Mapped[bool] = mapped_column(default=False)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||
@ -154,6 +151,10 @@ class ClientSecret(Base):
|
||||
onupdate=sa.func.now(),
|
||||
)
|
||||
|
||||
deleted_at: Mapped[datetime | None] = mapped_column(
|
||||
sa.DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
|
||||
class APIClient(Base):
|
||||
"""A client on the API.
|
||||
|
||||
Reference in New Issue
Block a user