Compare commits
5 Commits
435b9dee83
...
3779e93b8c
| Author | SHA1 | Date | |
|---|---|---|---|
| 3779e93b8c | |||
| 7ad41f43d8 | |||
| aa6b55a911 | |||
| a7a09f7784 | |||
| ee1e7a16ec |
@ -43,7 +43,7 @@ jobs:
|
||||
git.eising.cloud/${{ env.DOCKER_ORG }}/${{ steps.meta.outputs.REPO_NAME }}-backend:${{ gitea.ref == 'refs/heads/main' && 'latest' || gitea.sha }}
|
||||
|
||||
|
||||
- name: Build backend and push
|
||||
- name: Build sshd and push
|
||||
uses: docker/build-push-action@v4
|
||||
with:
|
||||
context: .
|
||||
@ -54,7 +54,7 @@ jobs:
|
||||
tags: |
|
||||
git.eising.cloud/${{ env.DOCKER_ORG }}/${{ steps.meta.outputs.REPO_NAME }}-sshd:${{ gitea.ref == 'refs/heads/main' && 'latest' || gitea.sha }}
|
||||
|
||||
- name: Build backend and push
|
||||
- name: Build admin and push
|
||||
uses: docker/build-push-action@v4
|
||||
with:
|
||||
context: .
|
||||
|
||||
@ -0,0 +1,36 @@
|
||||
"""Remove invalidated, add deleted
|
||||
|
||||
Revision ID: b4e135ff347a
|
||||
Revises: f2dc50533f88
|
||||
Create Date: 2025-06-06 08:57:47.611854
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'b4e135ff347a'
|
||||
down_revision: Union[str, None] = 'f2dc50533f88'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('client_secret', sa.Column('deleted', sa.Boolean(), nullable=False))
|
||||
op.add_column('client_secret', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
|
||||
op.drop_column('client_secret', 'invalidated')
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('client_secret', sa.Column('invalidated', sa.BOOLEAN(), nullable=False))
|
||||
op.drop_column('client_secret', 'deleted_at')
|
||||
op.drop_column('client_secret', 'deleted')
|
||||
# ### end Alembic commands ###
|
||||
@ -0,0 +1,51 @@
|
||||
"""Make client object better
|
||||
|
||||
Revision ID: c251311b64c9
|
||||
Revises: 37329d9b5437
|
||||
Create Date: 2025-06-04 21:49:22.638698
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "c251311b64c9"
|
||||
down_revision: Union[str, None] = "37329d9b5437"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("client", schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column("version", sa.Integer(), nullable=False))
|
||||
batch_op.add_column(sa.Column("is_active", sa.Boolean(), nullable=False))
|
||||
batch_op.add_column(sa.Column("is_deleted", sa.Boolean(), nullable=False))
|
||||
batch_op.add_column(
|
||||
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True)
|
||||
)
|
||||
batch_op.add_column(sa.Column("parent_id", sa.Uuid(), nullable=True))
|
||||
batch_op.create_unique_constraint("uq_client_name_version", ["name", "version"])
|
||||
batch_op.create_foreign_key(
|
||||
"fk_client_parent", "client", ["parent_id"], ["id"], ondelete="SET NULL"
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("client", schema=None) as batch_op:
|
||||
batch_op.drop_constraint("fk_client_parent", type_="foreignkey")
|
||||
batch_op.drop_constraint("uq_client_name_version", type_="unique")
|
||||
batch_op.drop_column("parent_id")
|
||||
batch_op.drop_column("deleted_at")
|
||||
batch_op.drop_column("is_deleted")
|
||||
batch_op.drop_column("is_active")
|
||||
batch_op.drop_column("version")
|
||||
# ### end Alembic commands ###
|
||||
@ -0,0 +1,27 @@
|
||||
"""Rename parent to previous_version for clarity
|
||||
|
||||
Revision ID: f2dc50533f88
|
||||
Revises: c251311b64c9
|
||||
Create Date: 2025-06-05 13:24:32.465927
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'f2dc50533f88'
|
||||
down_revision: Union[str, None] = 'c251311b64c9'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
op.alter_column("client", "parent_id", new_column_name="previous_version_id")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
op.alter_column("client", "previous_version_id", new_column_name="parent_id")
|
||||
@ -1,8 +1 @@
|
||||
"""API factory modules."""
|
||||
|
||||
from .audit import get_audit_api
|
||||
from .clients import get_clients_api
|
||||
from .policies import get_policy_api
|
||||
from .secrets import get_secrets_api
|
||||
|
||||
__all__ = ["get_audit_api", "get_clients_api", "get_policy_api", "get_secrets_api"]
|
||||
|
||||
@ -0,0 +1 @@
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from sqlalchemy import select, func, and_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@ -15,7 +15,7 @@ from typing import Annotated
|
||||
|
||||
from sshecret_backend.models import AuditLog, Operation, SubSystem
|
||||
from sshecret_backend.types import AsyncDBSessionDep
|
||||
from sshecret_backend.view_models import AuditInfo, AuditView, AuditListResult
|
||||
from .schemas import AuditInfo, AuditView, AuditListResult
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -58,7 +58,7 @@ class AuditFilter(BaseModel):
|
||||
]
|
||||
|
||||
|
||||
def get_audit_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
||||
def create_audit_router(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
||||
"""Construct audit sub-api."""
|
||||
router = APIRouter()
|
||||
|
||||
@ -70,11 +70,13 @@ def get_audit_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
||||
"""Get audit logs."""
|
||||
# audit.audit_access_audit_log(session, request)
|
||||
|
||||
total = (await session.scalars(
|
||||
total = (
|
||||
await session.scalars(
|
||||
select(func.count("*"))
|
||||
.select_from(AuditLog)
|
||||
.where(and_(True, *filters.filter_mapping))
|
||||
)).one()
|
||||
)
|
||||
).one()
|
||||
|
||||
remaining = total - filters.offset
|
||||
statement = (
|
||||
@ -107,12 +109,12 @@ def get_audit_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
||||
|
||||
@router.get("/audit/info")
|
||||
async def get_audit_info(
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)]
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> AuditInfo:
|
||||
"""Get audit info."""
|
||||
audit_count = (await session.scalars(
|
||||
select(func.count("*")).select_from(AuditLog)
|
||||
)).one()
|
||||
audit_count = (
|
||||
await session.scalars(select(func.count("*")).select_from(AuditLog))
|
||||
).one()
|
||||
return AuditInfo(entries=audit_count)
|
||||
|
||||
return router
|
||||
@ -0,0 +1,40 @@
|
||||
"""Models for Audit API views."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from collections.abc import Sequence
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
from sshecret_backend import models
|
||||
|
||||
|
||||
class AuditView(BaseModel):
|
||||
"""Audit log view."""
|
||||
|
||||
id: uuid.UUID | None = None
|
||||
subsystem: models.SubSystem
|
||||
message: str
|
||||
operation: models.Operation
|
||||
data: dict[str, str] | None = None
|
||||
client_id: uuid.UUID | None = None
|
||||
client_name: str | None = None
|
||||
secret_id: uuid.UUID | None = None
|
||||
secret_name: str | None = None
|
||||
origin: str | None = None
|
||||
timestamp: datetime | None = None
|
||||
|
||||
|
||||
class AuditInfo(BaseModel):
|
||||
"""Information about audit information."""
|
||||
|
||||
entries: int
|
||||
|
||||
|
||||
class AuditListResult(BaseModel):
|
||||
"""Class to return when listing audit entries."""
|
||||
|
||||
results: Sequence[AuditView]
|
||||
total: int
|
||||
remaining: int
|
||||
@ -1,227 +0,0 @@
|
||||
"""Client sub-api factory."""
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
|
||||
import uuid
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing import Annotated, Any, Self, TypeVar, cast
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.sql import Select
|
||||
from sshecret_backend.types import AsyncDBSessionDep
|
||||
from sshecret_backend.models import Client, ClientSecret
|
||||
from sshecret_backend.view_models import (
|
||||
ClientCreate,
|
||||
ClientQueryResult,
|
||||
ClientView,
|
||||
ClientUpdate,
|
||||
)
|
||||
from sshecret_backend import audit
|
||||
from .common import get_client_by_id_or_name, client_with_relationships
|
||||
|
||||
|
||||
class ClientListParams(BaseModel):
|
||||
"""Client list parameters."""
|
||||
|
||||
limit: int = Field(100, gt=0, le=100)
|
||||
offset: int = Field(0, ge=0)
|
||||
id: uuid.UUID | None = None
|
||||
name: str | None = None
|
||||
name__like: str | None = None
|
||||
name__contains: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_expressions(self) -> Self:
|
||||
"""Validate mutually exclusive expression."""
|
||||
name_filter = False
|
||||
if self.name__like or self.name__contains:
|
||||
name_filter = True
|
||||
if self.name__like and self.name__contains:
|
||||
raise ValueError("You may only specify one name expression")
|
||||
if self.name and name_filter:
|
||||
raise ValueError(
|
||||
"You must either specify name or one of name__like or name__contains"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def filter_client_statement(
|
||||
statement: Select[Any], params: ClientListParams, ignore_limits: bool = False
|
||||
) -> Select[Any]:
|
||||
"""Filter a statement with the provided params."""
|
||||
if params.id:
|
||||
statement = statement.where(Client.id == params.id)
|
||||
|
||||
if params.name:
|
||||
statement = statement.where(Client.name == params.name)
|
||||
elif params.name__like:
|
||||
statement = statement.where(Client.name.like(params.name__like))
|
||||
elif params.name__contains:
|
||||
statement = statement.where(Client.name.contains(params.name__contains))
|
||||
|
||||
if ignore_limits:
|
||||
return statement
|
||||
|
||||
return statement.limit(params.limit).offset(params.offset)
|
||||
|
||||
|
||||
def get_clients_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
||||
"""Construct clients sub-api."""
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/clients/")
|
||||
async def get_clients(
|
||||
filter_query: Annotated[ClientListParams, Query()],
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientQueryResult:
|
||||
"""Get clients."""
|
||||
# Get total results first
|
||||
count_statement = select(func.count("*")).select_from(Client)
|
||||
count_statement = cast(Select[tuple[int]], filter_client_statement(count_statement, filter_query, True))
|
||||
|
||||
total_results = (await session.scalars(count_statement)).one()
|
||||
|
||||
statement = filter_client_statement(client_with_relationships(), filter_query, False)
|
||||
|
||||
results = await session.scalars(statement)
|
||||
remainder = total_results - filter_query.offset - filter_query.limit
|
||||
if remainder < 0:
|
||||
remainder = 0
|
||||
|
||||
clients = list(results.all())
|
||||
clients_view = ClientView.from_client_list(clients)
|
||||
return ClientQueryResult(
|
||||
clients=clients_view,
|
||||
total_results=total_results,
|
||||
remaining_results=remainder,
|
||||
)
|
||||
|
||||
@router.get("/clients/{name}")
|
||||
async def get_client(
|
||||
name: str,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientView:
|
||||
"""Fetch a client."""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
return ClientView.from_client(client)
|
||||
|
||||
@router.delete("/clients/{name}")
|
||||
async def delete_client(
|
||||
request: Request,
|
||||
name: str,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> None:
|
||||
"""Delete a client."""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
|
||||
await session.delete(client)
|
||||
await session.commit()
|
||||
await audit.audit_delete_client(session, request, client)
|
||||
|
||||
@router.post("/clients/")
|
||||
async def create_client(
|
||||
request: Request,
|
||||
client: ClientCreate,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientView:
|
||||
"""Create client."""
|
||||
existing = await get_client_by_id_or_name(session, client.name)
|
||||
if existing:
|
||||
raise HTTPException(400, detail="Error: Already a client with that name.")
|
||||
|
||||
db_client = client.to_client()
|
||||
session.add(db_client)
|
||||
await session.commit()
|
||||
await session.refresh(db_client)
|
||||
db_client = await get_client_by_id_or_name(session, client.name)
|
||||
if not db_client:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Could not create the client.")
|
||||
await audit.audit_create_client(session, request, db_client)
|
||||
return ClientView.from_client(db_client)
|
||||
|
||||
@router.post("/clients/{name}/public-key")
|
||||
async def update_client_public_key(
|
||||
request: Request,
|
||||
name: str,
|
||||
client_update: ClientUpdate,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientView:
|
||||
"""Change the public key of a client.
|
||||
|
||||
This invalidates all secrets.
|
||||
"""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
client.public_key = client_update.public_key
|
||||
matching_secrets = await session.scalars(select(ClientSecret).where(ClientSecret.client_id == client.id))
|
||||
for secret in matching_secrets.all():
|
||||
LOG.debug("Invalidated secret %s", secret.id)
|
||||
secret.invalidated = True
|
||||
secret.client_id = None
|
||||
|
||||
session.add(client)
|
||||
await session.refresh(client)
|
||||
await session.commit()
|
||||
await audit.audit_invalidate_secrets(session, request, client)
|
||||
|
||||
return ClientView.from_client(client)
|
||||
|
||||
@router.put("/clients/{name}")
|
||||
async def update_client(
|
||||
request: Request,
|
||||
name: str,
|
||||
client_update: ClientCreate,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientView:
|
||||
"""Change the public key of a client.
|
||||
|
||||
This invalidates all secrets.
|
||||
"""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
client.name = client_update.name
|
||||
client.description = client_update.description
|
||||
public_key_updated = False
|
||||
if client_update.public_key != client.public_key:
|
||||
public_key_updated = True
|
||||
client_secrets = await session.scalars(
|
||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||
)
|
||||
for secret in client_secrets.all():
|
||||
LOG.debug("Invalidated secret %s", secret.id)
|
||||
secret.invalidated = True
|
||||
secret.client_id = None
|
||||
|
||||
session.add(client)
|
||||
await session.commit()
|
||||
await session.refresh(client)
|
||||
await audit.audit_update_client(session, request, client)
|
||||
if public_key_updated:
|
||||
await audit.audit_invalidate_secrets(session, request, client)
|
||||
|
||||
return ClientView.from_client(client)
|
||||
|
||||
return router
|
||||
@ -0,0 +1,320 @@
|
||||
"""Client operations."""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, cast
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.sql import Select
|
||||
from sshecret_backend import audit
|
||||
from sshecret_backend.api.common import (
|
||||
FlexID,
|
||||
IdType,
|
||||
RelaxedId,
|
||||
create_new_client_version,
|
||||
query_active_clients,
|
||||
get_client_by_id,
|
||||
resolve_client_id,
|
||||
refresh_client,
|
||||
reload_client_with_relationships,
|
||||
)
|
||||
from sshecret_backend.models import Client, ClientAccessPolicy
|
||||
from .schemas import (
|
||||
ClientListParams,
|
||||
ClientCreate,
|
||||
ClientPolicyView,
|
||||
ClientView,
|
||||
ClientQueryResult,
|
||||
ClientPolicyUpdate,
|
||||
)
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _id(id: RelaxedId) -> uuid.UUID:
|
||||
"""Ensure that the ID is a uuid."""
|
||||
if isinstance(id, str):
|
||||
return uuid.UUID(id)
|
||||
return id
|
||||
|
||||
|
||||
class ClientOperations:
|
||||
"""Perform operations on a client."""
|
||||
|
||||
def __init__(self, session: AsyncSession, request: Request) -> None:
|
||||
"""Create operations class."""
|
||||
self.session: AsyncSession = session
|
||||
self.request: Request = request
|
||||
|
||||
self._client_id: uuid.UUID | None = None
|
||||
# self.client_name: str | None = None
|
||||
# if client.type is IdType.ID:
|
||||
# self._client_id = _id(client.value)
|
||||
# else:
|
||||
# self.client_name = str(client.value)
|
||||
|
||||
async def get_client_id(
|
||||
self,
|
||||
client: FlexID,
|
||||
version: int | None = None,
|
||||
) -> uuid.UUID | None:
|
||||
"""Get client ID."""
|
||||
if self._client_id:
|
||||
LOG.debug("Returning previously resolved client ID.")
|
||||
return self._client_id
|
||||
|
||||
if client.type is IdType.ID:
|
||||
self._client_id = _id(client.value)
|
||||
return self._client_id
|
||||
|
||||
client_name = str(client.value)
|
||||
client_id = await resolve_client_id(
|
||||
self.session,
|
||||
client_name,
|
||||
version=version,
|
||||
)
|
||||
if not client_id:
|
||||
return None
|
||||
self._client_id = client_id
|
||||
LOG.debug("Saving client ID %s", client_id)
|
||||
return client_id
|
||||
|
||||
async def _get_client(
|
||||
self,
|
||||
client: FlexID,
|
||||
version: int | None = None,
|
||||
) -> Client | None:
|
||||
"""Get client."""
|
||||
client_id = await self.get_client_id(client, version=version)
|
||||
if not client_id:
|
||||
return None
|
||||
db_client = await get_client_by_id(self.session, client_id)
|
||||
return db_client
|
||||
|
||||
async def get_client(
|
||||
self,
|
||||
client: FlexID,
|
||||
version: int | None = None,
|
||||
) -> ClientView:
|
||||
"""Get public client object."""
|
||||
db_client = await self._get_client(client, version)
|
||||
if not db_client:
|
||||
raise HTTPException(status_code=404, detail="Client not found.")
|
||||
return ClientView.from_client(db_client)
|
||||
|
||||
async def create_client(
|
||||
self,
|
||||
create_model: ClientCreate,
|
||||
) -> ClientView:
|
||||
"""Create a new client."""
|
||||
existing_id = await self.get_client_id(FlexID.name(create_model.name))
|
||||
if existing_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Error: A client already exists with this name."
|
||||
)
|
||||
client = create_model.to_client()
|
||||
self.session.add(client)
|
||||
await self.session.flush()
|
||||
await self.session.commit()
|
||||
await refresh_client(self.session, client)
|
||||
await audit.audit_create_client(self.session, self.request, client)
|
||||
return ClientView.from_client(client)
|
||||
|
||||
async def delete_client(self, client: FlexID) -> None:
|
||||
"""Delete client."""
|
||||
db_client = await self._get_client(client)
|
||||
if not db_client:
|
||||
return
|
||||
if db_client.is_deleted:
|
||||
return
|
||||
db_client.is_deleted = True
|
||||
db_client.deleted_at = datetime.now(timezone.utc)
|
||||
self.session.add(db_client)
|
||||
await self.session.commit()
|
||||
await audit.audit_delete_client(self.session, self.request, db_client)
|
||||
|
||||
async def update_client(
|
||||
self,
|
||||
client: FlexID,
|
||||
client_update: ClientCreate,
|
||||
) -> ClientView:
|
||||
"""Update client details."""
|
||||
db_client = await self._get_client(client)
|
||||
if not db_client:
|
||||
raise HTTPException(status_code=404, detail="Client not found.")
|
||||
|
||||
if (
|
||||
client_update.public_key
|
||||
and client_update.public_key != db_client.public_key
|
||||
):
|
||||
return await new_client_version(
|
||||
self.session,
|
||||
db_client.id,
|
||||
client_update.public_key,
|
||||
client_update.name,
|
||||
client_update.description,
|
||||
)
|
||||
db_client.name = client_update.name
|
||||
db_client.description = client_update.description
|
||||
self.session.add(db_client)
|
||||
await self.session.commit()
|
||||
await refresh_client(self.session, db_client)
|
||||
await audit.audit_update_client(self.session, self.request, db_client)
|
||||
return ClientView.from_client(db_client)
|
||||
|
||||
async def new_client_version(
|
||||
self,
|
||||
client: FlexID,
|
||||
public_key: str,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
) -> ClientView:
|
||||
"""Update a client to a new version."""
|
||||
current_client = await self._get_client(client)
|
||||
if not current_client:
|
||||
raise HTTPException(status_code=404, detail="Client not found.")
|
||||
new_client = await create_new_client_version(
|
||||
self.session, current_client, public_key
|
||||
)
|
||||
if name:
|
||||
new_client.name = name
|
||||
if description:
|
||||
new_client.description = description
|
||||
|
||||
current_client.is_active = False
|
||||
self.session.add(current_client)
|
||||
|
||||
await self.session.commit()
|
||||
|
||||
await refresh_client(self.session, new_client)
|
||||
await audit.audit_new_client_version(
|
||||
self.session, self.request, current_client, new_client
|
||||
)
|
||||
return ClientView.from_client(new_client)
|
||||
|
||||
async def get_client_policies(self, client: FlexID) -> ClientPolicyView:
|
||||
"""Get client policies."""
|
||||
db_client = await self._get_client(client)
|
||||
if not db_client:
|
||||
raise HTTPException(status_code=404, detail="Client not found.")
|
||||
return ClientPolicyView.from_client(db_client)
|
||||
|
||||
async def update_client_policies(
|
||||
self, client: FlexID, policy_update: ClientPolicyUpdate
|
||||
) -> ClientPolicyView:
|
||||
"""Update client policies."""
|
||||
db_client = await self._get_client(client)
|
||||
if not db_client:
|
||||
raise HTTPException(status_code=404, detail="Client not found.")
|
||||
|
||||
policies = await self.session.scalars(
|
||||
select(ClientAccessPolicy).where(
|
||||
ClientAccessPolicy.client_id == db_client.id
|
||||
)
|
||||
)
|
||||
deleted_policies: list[ClientAccessPolicy] = []
|
||||
added_policies: list[ClientAccessPolicy] = []
|
||||
for policy in policies.all():
|
||||
await self.session.delete(policy)
|
||||
deleted_policies.append(policy)
|
||||
|
||||
LOG.debug("Updating client policies with: %r", policy_update.sources)
|
||||
for source in policy_update.sources:
|
||||
LOG.debug("Source %r", source)
|
||||
policy = ClientAccessPolicy(source=str(source), client_id=db_client.id)
|
||||
self.session.add(policy)
|
||||
added_policies.append(policy)
|
||||
|
||||
await self.session.flush()
|
||||
await self.session.commit()
|
||||
|
||||
db_client = await reload_client_with_relationships(self.session, db_client)
|
||||
for policy in deleted_policies:
|
||||
await audit.audit_remove_policy(
|
||||
self.session, self.request, db_client, policy
|
||||
)
|
||||
|
||||
for policy in added_policies:
|
||||
await audit.audit_update_policy(
|
||||
self.session, self.request, db_client, policy
|
||||
)
|
||||
|
||||
return ClientPolicyView.from_client(db_client)
|
||||
|
||||
|
||||
def filter_client_statement(
|
||||
statement: Select[Any], params: ClientListParams, ignore_limits: bool = False
|
||||
) -> Select[Any]:
|
||||
"""Filter a statement with the provided params."""
|
||||
if params.id:
|
||||
statement = statement.where(Client.id == params.id)
|
||||
|
||||
if params.name:
|
||||
statement = statement.where(Client.name == params.name)
|
||||
elif params.name__like:
|
||||
statement = statement.where(Client.name.like(params.name__like))
|
||||
elif params.name__contains:
|
||||
statement = statement.where(Client.name.contains(params.name__contains))
|
||||
|
||||
if ignore_limits:
|
||||
return statement
|
||||
|
||||
return statement.limit(params.limit).offset(params.offset)
|
||||
|
||||
|
||||
async def get_clients(
|
||||
session: AsyncSession,
|
||||
filter_query: ClientListParams,
|
||||
) -> ClientQueryResult:
|
||||
"""Get Clients."""
|
||||
count_statement = select(func.count("*")).select_from(Client)
|
||||
count_statement = cast(
|
||||
Select[tuple[int]],
|
||||
filter_client_statement(count_statement, filter_query, True),
|
||||
)
|
||||
|
||||
total_results = (await session.scalars(count_statement)).one()
|
||||
|
||||
statement = filter_client_statement(query_active_clients(), filter_query, False)
|
||||
|
||||
results = await session.scalars(statement)
|
||||
remainder = total_results - filter_query.offset - filter_query.limit
|
||||
if remainder < 0:
|
||||
remainder = 0
|
||||
|
||||
clients = list(results.all())
|
||||
clients_view = ClientView.from_client_list(clients)
|
||||
return ClientQueryResult(
|
||||
clients=clients_view,
|
||||
total_results=total_results,
|
||||
remaining_results=remainder,
|
||||
)
|
||||
|
||||
|
||||
async def new_client_version(
|
||||
session: AsyncSession,
|
||||
client_id: uuid.UUID,
|
||||
public_key: str,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
) -> ClientView:
|
||||
"""Update a client to a new version."""
|
||||
current_client = await get_client_by_id(session, client_id)
|
||||
if not current_client:
|
||||
raise ValueError("Client not found.")
|
||||
new_client = await create_new_client_version(session, current_client, public_key)
|
||||
if name:
|
||||
new_client.name = name
|
||||
if description:
|
||||
new_client.description = description
|
||||
|
||||
current_client.is_active = False
|
||||
session.add(current_client)
|
||||
|
||||
await session.commit()
|
||||
|
||||
return ClientView.from_client(new_client)
|
||||
@ -0,0 +1,122 @@
|
||||
"""Client router."""
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
#
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from sshecret_backend.types import AsyncDBSessionDep
|
||||
from sshecret_backend.api.clients.schemas import (
|
||||
ClientCreate,
|
||||
ClientListParams,
|
||||
ClientQueryResult,
|
||||
ClientUpdate,
|
||||
ClientView,
|
||||
ClientPolicyUpdate,
|
||||
ClientPolicyView,
|
||||
)
|
||||
from sshecret_backend.api.clients import operations
|
||||
from sshecret_backend.api.clients.operations import ClientOperations
|
||||
from sshecret_backend.api.common import FlexID
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_client_router(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
||||
"""Create client router."""
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/clients/")
|
||||
async def route_get_clients(
|
||||
filter_query: Annotated[ClientListParams, Query()],
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientQueryResult:
|
||||
return await operations.get_clients(session, filter_query)
|
||||
|
||||
@router.post("/clients/")
|
||||
async def create_client(
|
||||
request: Request,
|
||||
client: ClientCreate,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientView:
|
||||
"""Create client."""
|
||||
client_op = ClientOperations(session, request)
|
||||
return await client_op.create_client(client)
|
||||
|
||||
@router.get("/clients/{client_identifier}")
|
||||
async def fetch_client_by_name(
|
||||
request: Request,
|
||||
client_identifier: str,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
version: Annotated[int | None, Query()] = None,
|
||||
) -> ClientView:
|
||||
"""Fetch client by name."""
|
||||
client_id = FlexID.from_string(client_identifier)
|
||||
client_op = ClientOperations(session, request)
|
||||
return await client_op.get_client(client_id, version)
|
||||
|
||||
@router.delete("/clients/{client_identifier}")
|
||||
async def delete_client(
|
||||
request: Request,
|
||||
client_identifier: str,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> None:
|
||||
"""Delete a client."""
|
||||
client_id = FlexID.from_string(client_identifier)
|
||||
client_op = ClientOperations(session, request)
|
||||
await client_op.delete_client(client_id)
|
||||
|
||||
@router.post("/clients/{client_identifier}/public-key")
|
||||
async def update_client_public_key(
|
||||
request: Request,
|
||||
client_identifier: str,
|
||||
client_update: ClientUpdate,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientView:
|
||||
"""Update client public key."""
|
||||
client_id = FlexID.from_string(client_identifier)
|
||||
client_op = ClientOperations(session, request)
|
||||
return await client_op.new_client_version(client_id, client_update.public_key)
|
||||
|
||||
@router.put("/clients/{client_identifier}")
|
||||
async def update_client_by_name(
|
||||
request: Request,
|
||||
client_identifier: str,
|
||||
client_update: ClientCreate,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientView:
|
||||
"""Update client by name."""
|
||||
client_id = FlexID.from_string(client_identifier)
|
||||
client_op = ClientOperations(session, request)
|
||||
return await client_op.update_client(client_id, client_update)
|
||||
|
||||
@router.get("/clients/{client_identifier}/policies/")
|
||||
async def get_client_policies(
|
||||
request: Request,
|
||||
client_identifier: str,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientPolicyView:
|
||||
"""Get client policies."""
|
||||
client_id = FlexID.from_string(client_identifier)
|
||||
client_op = ClientOperations(session, request)
|
||||
return await client_op.get_client_policies(client_id)
|
||||
|
||||
@router.put("/clients/{client_identifier}/policies/")
|
||||
async def update_client_policies(
|
||||
request: Request,
|
||||
client_identifier: str,
|
||||
policy_update: ClientPolicyUpdate,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientPolicyView:
|
||||
"""Update client policies.
|
||||
|
||||
This is also how you delete policies.
|
||||
"""
|
||||
client_id = FlexID.from_string(client_identifier)
|
||||
client_op = ClientOperations(session, request)
|
||||
return await client_op.update_client_policies(client_id, policy_update)
|
||||
|
||||
return router
|
||||
@ -1,14 +1,21 @@
|
||||
"""Models for API views."""
|
||||
"""Client related schemas."""
|
||||
|
||||
from typing import Annotated, Self
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Self, Sequence, override
|
||||
|
||||
from pydantic import AfterValidator, BaseModel, Field, IPvAnyAddress, IPvAnyNetwork
|
||||
from pydantic import (
|
||||
AfterValidator,
|
||||
BaseModel,
|
||||
Field,
|
||||
IPvAnyAddress,
|
||||
IPvAnyNetwork,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
from sshecret.crypto import public_key_validator
|
||||
|
||||
from . import models
|
||||
from sshecret_backend import models
|
||||
|
||||
|
||||
class ClientView(BaseModel):
|
||||
@ -19,9 +26,12 @@ class ClientView(BaseModel):
|
||||
description: str | None = None
|
||||
public_key: str
|
||||
policies: list[str] = ["0.0.0.0/0", "::/0"]
|
||||
is_active: bool = True
|
||||
is_deleted: bool = False
|
||||
secrets: list[str] = Field(default_factory=list)
|
||||
created_at: datetime | None
|
||||
updated_at: datetime | None = None
|
||||
deleted_at: datetime | None = None
|
||||
|
||||
@classmethod
|
||||
def from_client_list(cls, clients: list[models.Client]) -> list[Self]:
|
||||
@ -39,6 +49,9 @@ class ClientView(BaseModel):
|
||||
public_key=client.public_key,
|
||||
created_at=client.created_at,
|
||||
updated_at=client.updated_at or None,
|
||||
deleted_at=client.deleted_at or None,
|
||||
is_active=client.is_active,
|
||||
is_deleted=client.is_deleted,
|
||||
)
|
||||
if client.secrets:
|
||||
view.secrets = [secret.name for secret in client.secrets]
|
||||
@ -57,6 +70,32 @@ class ClientQueryResult(BaseModel):
|
||||
remaining_results: int
|
||||
|
||||
|
||||
class ClientListParams(BaseModel):
|
||||
"""Client list parameters."""
|
||||
|
||||
limit: int = Field(100, gt=0, le=100)
|
||||
offset: int = Field(0, ge=0)
|
||||
id: uuid.UUID | None = None
|
||||
name: str | None = None
|
||||
name__like: str | None = None
|
||||
name__contains: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_expressions(self) -> Self:
|
||||
"""Validate mutually exclusive expression."""
|
||||
name_filter = False
|
||||
if self.name__like or self.name__contains:
|
||||
name_filter = True
|
||||
if self.name__like and self.name__contains:
|
||||
raise ValueError("You may only specify one name expression")
|
||||
if self.name and name_filter:
|
||||
raise ValueError(
|
||||
"You must either specify name or one of name__like or name__contains"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class ClientCreate(BaseModel):
|
||||
"""Model to create a client."""
|
||||
|
||||
@ -79,48 +118,6 @@ class ClientUpdate(BaseModel):
|
||||
public_key: Annotated[str, AfterValidator(public_key_validator)]
|
||||
|
||||
|
||||
class BodyValue(BaseModel):
|
||||
"""A generic model with just a value parameter."""
|
||||
|
||||
value: str
|
||||
|
||||
|
||||
class ClientSecretPublic(BaseModel):
|
||||
"""Public model to manage client secrets."""
|
||||
|
||||
name: str
|
||||
secret: str
|
||||
description: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_client_secret(cls, client_secret: models.ClientSecret) -> Self:
|
||||
"""Instantiate from ClientSecret."""
|
||||
return cls(
|
||||
name=client_secret.name,
|
||||
secret=client_secret.secret,
|
||||
description=client_secret.description,
|
||||
)
|
||||
|
||||
|
||||
class ClientSecretResponse(ClientSecretPublic):
|
||||
"""A secret view."""
|
||||
|
||||
created_at: datetime | None
|
||||
updated_at: datetime | None = None
|
||||
|
||||
@override
|
||||
@classmethod
|
||||
def from_client_secret(cls, client_secret: models.ClientSecret) -> Self:
|
||||
"""Instantiate from ClientSecret."""
|
||||
|
||||
return cls(
|
||||
name=client_secret.name,
|
||||
secret=client_secret.secret,
|
||||
created_at=client_secret.created_at,
|
||||
updated_at=client_secret.updated_at,
|
||||
)
|
||||
|
||||
|
||||
class ClientPolicyView(BaseModel):
|
||||
"""Update object for client policy."""
|
||||
|
||||
@ -138,55 +135,3 @@ class ClientPolicyUpdate(BaseModel):
|
||||
"""Model for updating policies."""
|
||||
|
||||
sources: list[IPvAnyAddress | IPvAnyNetwork]
|
||||
|
||||
|
||||
class ClientSecretList(BaseModel):
|
||||
"""Model for aggregating identically named secrets."""
|
||||
|
||||
name: str
|
||||
clients: list[str]
|
||||
|
||||
|
||||
class ClientReference(BaseModel):
|
||||
"""Reference to a client."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
class ClientSecretDetailList(BaseModel):
|
||||
"""A more detailed version of the ClientSecretList."""
|
||||
|
||||
name: str
|
||||
ids: list[str] = Field(default_factory=list)
|
||||
clients: list[ClientReference] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AuditView(BaseModel):
|
||||
"""Audit log view."""
|
||||
|
||||
|
||||
id: uuid.UUID | None = None
|
||||
subsystem: models.SubSystem
|
||||
message: str
|
||||
operation: models.Operation
|
||||
data: dict[str, str] | None = None
|
||||
client_id: uuid.UUID | None = None
|
||||
client_name: str | None = None
|
||||
secret_id: uuid.UUID | None = None
|
||||
secret_name: str | None = None
|
||||
origin: str | None = None
|
||||
timestamp: datetime | None = None
|
||||
|
||||
|
||||
class AuditInfo(BaseModel):
|
||||
"""Information about audit information."""
|
||||
|
||||
entries: int
|
||||
|
||||
|
||||
class AuditListResult(BaseModel):
|
||||
"""Class to return when listing audit entries."""
|
||||
results: Sequence[AuditView]
|
||||
total: int
|
||||
remaining: int
|
||||
@ -1,17 +1,66 @@
|
||||
"""Common helpers."""
|
||||
|
||||
import re
|
||||
from typing import Self
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
import bcrypt
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sshecret_backend.models import Client, ClientAccessPolicy
|
||||
|
||||
from sshecret_backend.models import Client
|
||||
RE_UUID = re.compile(
|
||||
"^[0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}$"
|
||||
)
|
||||
|
||||
RelaxedId = uuid.UUID | str
|
||||
|
||||
|
||||
class IdType(Enum):
|
||||
"""Id type."""
|
||||
|
||||
ID = "id"
|
||||
NAME = "name"
|
||||
|
||||
|
||||
class FlexID(BaseModel):
|
||||
"""Flexible identifier."""
|
||||
|
||||
type: IdType
|
||||
value: RelaxedId
|
||||
|
||||
@classmethod
|
||||
def id(cls, id: RelaxedId) -> Self:
|
||||
"""Construct from ID."""
|
||||
return cls(type=IdType.ID, value=id)
|
||||
|
||||
@classmethod
|
||||
def name(cls, name: str) -> Self:
|
||||
"""Construct from name."""
|
||||
return cls(type=IdType.NAME, value=name)
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, value: str) -> Self:
|
||||
"""Convert from path string."""
|
||||
if value.startswith("id:"):
|
||||
return cls.id(value[3:])
|
||||
elif value.startswith("name:"):
|
||||
return cls.name(value[5:])
|
||||
return cls.name(value)
|
||||
|
||||
|
||||
@dataclass
|
||||
class NewClientVersion:
|
||||
"""New client version dataclass."""
|
||||
|
||||
client: Client
|
||||
policies: list[ClientAccessPolicy] = field(default_factory=list)
|
||||
|
||||
RE_UUID = re.compile("^[0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}$")
|
||||
|
||||
def verify_token(token: str, stored_hash: str) -> bool:
|
||||
"""Verify token."""
|
||||
@ -19,12 +68,19 @@ def verify_token(token: str, stored_hash: str) -> bool:
|
||||
stored_bytes = stored_hash.encode("utf-8")
|
||||
return bcrypt.checkpw(token_bytes, stored_bytes)
|
||||
|
||||
async def reload_client_with_relationships(session: AsyncSession, client: Client) -> Client:
|
||||
|
||||
async def reload_client_with_relationships(
|
||||
session: AsyncSession, client: Client
|
||||
) -> Client:
|
||||
"""Reload a client from the database."""
|
||||
session.expunge(client)
|
||||
stmt = (
|
||||
select(Client)
|
||||
.options(selectinload(Client.policies), selectinload(Client.secrets))
|
||||
.options(
|
||||
selectinload(Client.policies),
|
||||
selectinload(Client.secrets),
|
||||
selectinload(Client.previous_version),
|
||||
)
|
||||
.where(Client.id == client.id)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
@ -36,13 +92,29 @@ def client_with_relationships() -> Select[tuple[Client]]:
|
||||
return select(Client).options(
|
||||
selectinload(Client.secrets),
|
||||
selectinload(Client.policies),
|
||||
selectinload(Client.previous_version),
|
||||
)
|
||||
|
||||
async def get_client_by_name(session: AsyncSession, name: str) -> Client | None:
|
||||
"""Get client by name."""
|
||||
|
||||
async def resolve_client_id(
|
||||
session: AsyncSession,
|
||||
name: str,
|
||||
version: int | None = None,
|
||||
include_deleted: bool = False,
|
||||
) -> uuid.UUID | None:
|
||||
"""Get the ID of a client name."""
|
||||
if include_deleted:
|
||||
client_filter = client_with_relationships().where(Client.name == name)
|
||||
client_results = await session.execute(client_filter)
|
||||
return client_results.scalars().first()
|
||||
else:
|
||||
client_filter = query_active_clients().where(Client.name == name)
|
||||
if version:
|
||||
client_filter = client_filter.where(Client.version == version)
|
||||
|
||||
client_result = await session.execute(client_filter)
|
||||
if client := client_result.scalars().first():
|
||||
return client.id
|
||||
return None
|
||||
|
||||
|
||||
async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | None:
|
||||
"""Get client by ID."""
|
||||
@ -50,10 +122,76 @@ async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | Non
|
||||
client_results = await session.execute(client_filter)
|
||||
return client_results.scalars().first()
|
||||
|
||||
async def get_client_by_id_or_name(session: AsyncSession, id_or_name: str) -> Client | None:
|
||||
|
||||
async def get_client_by_id_or_name(
|
||||
session: AsyncSession, id_or_name: str
|
||||
) -> Client | None:
|
||||
"""Get client either by id or name."""
|
||||
if RE_UUID.match(id_or_name):
|
||||
id = uuid.UUID(id_or_name)
|
||||
return await get_client_by_id(session, id)
|
||||
|
||||
return await get_client_by_name(session, id_or_name)
|
||||
|
||||
|
||||
def query_active_clients() -> Select[tuple[Client]]:
|
||||
"""Get all active clients."""
|
||||
client_filter = (
|
||||
client_with_relationships()
|
||||
.where(Client.is_active.is_(True))
|
||||
.where(Client.is_deleted.is_(False))
|
||||
)
|
||||
return client_filter
|
||||
|
||||
|
||||
async def get_client_by_name(session: AsyncSession, name: str) -> Client | None:
|
||||
"""Get client by name.
|
||||
|
||||
This will get the latest client version, unless it's deleted.
|
||||
"""
|
||||
client_filter = (
|
||||
client_with_relationships()
|
||||
.where(Client.is_active.is_(True))
|
||||
.where(Client.is_deleted.is_not(True))
|
||||
.where(Client.name == name)
|
||||
.order_by(Client.version.desc())
|
||||
)
|
||||
client_result = await session.execute(client_filter)
|
||||
return client_result.scalars().first()
|
||||
|
||||
|
||||
async def refresh_client(session: AsyncSession, client: Client) -> None:
|
||||
"""Refresh the client and load in all relationships."""
|
||||
await session.refresh(
|
||||
client,
|
||||
attribute_names=["secrets", "policies", "previous_version", "updated_at"],
|
||||
)
|
||||
|
||||
|
||||
async def create_new_client_version(
|
||||
session: AsyncSession, current_client: Client, new_public_key: str
|
||||
) -> Client:
|
||||
new_client = Client(
|
||||
name=current_client.name,
|
||||
version=current_client.version + 1,
|
||||
description=current_client.description,
|
||||
public_key=new_public_key,
|
||||
previous_version_id=current_client.id,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
# Mark current client as inactive
|
||||
current_client.is_active = False
|
||||
|
||||
# Copy policies
|
||||
for policy in current_client.policies:
|
||||
copied_policy = ClientAccessPolicy(
|
||||
client=new_client,
|
||||
address=policy.source,
|
||||
)
|
||||
session.add(copied_policy)
|
||||
|
||||
session.add(new_client)
|
||||
await session.flush()
|
||||
await refresh_client(session, new_client)
|
||||
return new_client
|
||||
|
||||
@ -1,86 +0,0 @@
|
||||
"""Policies sub-api router factory."""
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Annotated
|
||||
|
||||
from sshecret_backend.models import ClientAccessPolicy
|
||||
from sshecret_backend.view_models import (
|
||||
ClientPolicyView,
|
||||
ClientPolicyUpdate,
|
||||
)
|
||||
from sshecret_backend.types import AsyncDBSessionDep
|
||||
from sshecret_backend import audit
|
||||
from .common import get_client_by_id_or_name, reload_client_with_relationships
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_policy_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
||||
"""Construct clients sub-api."""
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/clients/{name}/policies/")
|
||||
async def get_client_policies(
|
||||
name: str, session: Annotated[AsyncSession, Depends(get_db_session)]
|
||||
) -> ClientPolicyView:
|
||||
"""Get client policies."""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
|
||||
return ClientPolicyView.from_client(client)
|
||||
|
||||
@router.put("/clients/{name}/policies/")
|
||||
async def update_client_policies(
|
||||
request: Request,
|
||||
name: str,
|
||||
policy_update: ClientPolicyUpdate,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientPolicyView:
|
||||
"""Update client policies.
|
||||
|
||||
This is also how you delete policies.
|
||||
"""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
# Remove old policies.
|
||||
policies = await session.scalars(
|
||||
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
|
||||
)
|
||||
deleted_policies: list[ClientAccessPolicy] = []
|
||||
added_policies: list[ClientAccessPolicy] = []
|
||||
for policy in policies.all():
|
||||
await session.delete(policy)
|
||||
deleted_policies.append(policy)
|
||||
|
||||
LOG.debug("Updating client policies with: %r", policy_update.sources)
|
||||
for source in policy_update.sources:
|
||||
LOG.debug("Source %r", source)
|
||||
policy = ClientAccessPolicy(source=str(source), client_id=client.id)
|
||||
session.add(policy)
|
||||
added_policies.append(policy)
|
||||
|
||||
await session.flush()
|
||||
await session.commit()
|
||||
|
||||
client = await reload_client_with_relationships(session, client)
|
||||
for policy in deleted_policies:
|
||||
await audit.audit_remove_policy(session, request, client, policy)
|
||||
|
||||
for policy in added_policies:
|
||||
await audit.audit_update_policy(session, request, client, policy)
|
||||
|
||||
return ClientPolicyView.from_client(client)
|
||||
|
||||
return router
|
||||
@ -0,0 +1,9 @@
|
||||
"""Common API schemas."""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BodyValue(BaseModel):
|
||||
"""A generic model with just a value parameter."""
|
||||
|
||||
value: str
|
||||
@ -1,252 +0,0 @@
|
||||
"""Secrets sub-api factory."""
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy import select
|
||||
from typing import Annotated
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from sshecret_backend.models import Client, ClientSecret
|
||||
from sshecret_backend.view_models import (
|
||||
ClientReference,
|
||||
ClientSecretDetailList,
|
||||
ClientSecretList,
|
||||
ClientSecretPublic,
|
||||
BodyValue,
|
||||
ClientSecretResponse,
|
||||
)
|
||||
from sshecret_backend import audit
|
||||
from sshecret_backend.types import AsyncDBSessionDep
|
||||
from .common import get_client_by_id_or_name
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def lookup_client_secret(
|
||||
session: AsyncSession, client: Client, name: str
|
||||
) -> ClientSecret | None:
|
||||
"""Look up a secret for a client."""
|
||||
statement = (
|
||||
select(ClientSecret)
|
||||
.where(ClientSecret.client_id == client.id)
|
||||
.where(ClientSecret.name == name)
|
||||
)
|
||||
results = await session.scalars(statement)
|
||||
return results.first()
|
||||
|
||||
|
||||
def get_secrets_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
||||
"""Construct clients sub-api."""
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/clients/{name}/secrets/")
|
||||
async def add_secret_to_client(
|
||||
request: Request,
|
||||
name: str,
|
||||
client_secret: ClientSecretPublic,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> None:
|
||||
"""Add secret to a client."""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
|
||||
existing_secret = await lookup_client_secret(
|
||||
session, client, client_secret.name
|
||||
)
|
||||
if existing_secret:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot add a secret. A different secret with the same name already exists.",
|
||||
)
|
||||
db_secret = ClientSecret(
|
||||
name=client_secret.name, client_id=client.id, secret=client_secret.secret
|
||||
)
|
||||
session.add(db_secret)
|
||||
await session.commit()
|
||||
await session.refresh(db_secret)
|
||||
await audit.audit_create_secret(session, request, client, db_secret)
|
||||
|
||||
@router.put("/clients/{name}/secrets/{secret_name}")
|
||||
async def update_client_secret(
|
||||
request: Request,
|
||||
name: str,
|
||||
secret_name: str,
|
||||
secret_data: BodyValue,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientSecretResponse:
|
||||
"""Update a client secret.
|
||||
|
||||
This can also be used for destructive creates.
|
||||
"""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
|
||||
existing_secret = await lookup_client_secret(session, client, secret_name)
|
||||
if existing_secret:
|
||||
existing_secret.secret = secret_data.value
|
||||
|
||||
session.add(existing_secret)
|
||||
await session.commit()
|
||||
await session.refresh(existing_secret)
|
||||
await audit.audit_update_secret(session, request, client, existing_secret)
|
||||
return ClientSecretResponse.from_client_secret(existing_secret)
|
||||
|
||||
db_secret = ClientSecret(
|
||||
name=secret_name,
|
||||
client_id=client.id,
|
||||
secret=secret_data.value,
|
||||
)
|
||||
session.add(db_secret)
|
||||
await session.commit()
|
||||
await session.refresh(db_secret)
|
||||
await audit.audit_create_secret(session, request, client, db_secret)
|
||||
return ClientSecretResponse.from_client_secret(db_secret)
|
||||
|
||||
@router.get("/clients/{name}/secrets/{secret_name}")
|
||||
async def request_client_secret(
|
||||
request: Request,
|
||||
name: str,
|
||||
secret_name: str,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientSecretResponse:
|
||||
"""Get a client secret."""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
|
||||
secret = await lookup_client_secret(session, client, secret_name)
|
||||
if not secret:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a secret with the given name."
|
||||
)
|
||||
|
||||
response_model = ClientSecretResponse.from_client_secret(secret)
|
||||
await audit.audit_access_secret(session, request, client, secret)
|
||||
return response_model
|
||||
|
||||
@router.delete("/clients/{name}/secrets/{secret_name}")
|
||||
async def delete_client_secret(
|
||||
request: Request,
|
||||
name: str,
|
||||
secret_name: str,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> None:
|
||||
"""Delete a secret."""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
|
||||
secret = await lookup_client_secret(session, client, secret_name)
|
||||
if not secret:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a secret with the given name."
|
||||
)
|
||||
|
||||
await session.delete(secret)
|
||||
await session.commit()
|
||||
await audit.audit_delete_secret(session, request, client, secret)
|
||||
|
||||
@router.get("/secrets/")
|
||||
async def get_secret_map(
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> list[ClientSecretList]:
|
||||
"""Get a list of all secrets and which clients have them."""
|
||||
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
|
||||
client_secrets = await session.scalars(
|
||||
select(ClientSecret).options(selectinload(ClientSecret.client))
|
||||
)
|
||||
for client_secret in client_secrets.all():
|
||||
if not client_secret.client:
|
||||
if client_secret.name not in client_secret_map:
|
||||
client_secret_map[client_secret.name] = []
|
||||
continue
|
||||
client_secret_map[client_secret.name].append(client_secret.client.name)
|
||||
# audit.audit_client_secret_list(session, request)
|
||||
return [
|
||||
ClientSecretList(name=secret_name, clients=clients)
|
||||
for secret_name, clients in client_secret_map.items()
|
||||
]
|
||||
|
||||
@router.get("/secrets/detailed/")
|
||||
async def get_detailed_secret_map(
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> list[ClientSecretDetailList]:
|
||||
"""Get a list of all secrets and which clients have them."""
|
||||
client_secrets: dict[str, ClientSecretDetailList] = {}
|
||||
all_client_secrets = await session.execute(
|
||||
select(ClientSecret).options(selectinload(ClientSecret.client))
|
||||
)
|
||||
for client_secret in all_client_secrets.scalars().all():
|
||||
if client_secret.name not in client_secrets:
|
||||
client_secrets[client_secret.name] = ClientSecretDetailList(
|
||||
name=client_secret.name
|
||||
)
|
||||
client_secrets[client_secret.name].ids.append(str(client_secret.id))
|
||||
if not client_secret.client:
|
||||
continue
|
||||
client_secrets[client_secret.name].clients.append(
|
||||
ClientReference(
|
||||
id=str(client_secret.client.id), name=client_secret.client.name
|
||||
)
|
||||
)
|
||||
# `audit.audit_client_secret_list(session, request)
|
||||
return list(client_secrets.values())
|
||||
|
||||
@router.get("/secrets/{name}")
|
||||
async def get_secret_clients(
|
||||
name: str,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientSecretList:
|
||||
"""Get a list of which clients has a named secret."""
|
||||
clients: list[str] = []
|
||||
client_secrets = await session.scalars(
|
||||
select(ClientSecret)
|
||||
.options(selectinload(ClientSecret.client))
|
||||
.where(ClientSecret.name == name)
|
||||
)
|
||||
for client_secret in client_secrets.all():
|
||||
if not client_secret.client:
|
||||
continue
|
||||
clients.append(client_secret.client.name)
|
||||
|
||||
return ClientSecretList(name=name, clients=clients)
|
||||
|
||||
@router.get("/secrets/{name}/detailed")
|
||||
async def get_secret_clients_detailed(
|
||||
name: str,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientSecretDetailList:
|
||||
"""Get a list of which clients has a named secret."""
|
||||
detail_list = ClientSecretDetailList(name=name)
|
||||
client_secrets = await session.scalars(
|
||||
select(ClientSecret).where(ClientSecret.name == name)
|
||||
)
|
||||
for client_secret in client_secrets.all():
|
||||
if not client_secret.client:
|
||||
continue
|
||||
detail_list.ids.append(str(client_secret.id))
|
||||
detail_list.clients.append(
|
||||
ClientReference(
|
||||
id=str(client_secret.client.id), name=client_secret.client.name
|
||||
)
|
||||
)
|
||||
|
||||
return detail_list
|
||||
|
||||
return router
|
||||
@ -0,0 +1 @@
|
||||
|
||||
@ -0,0 +1,328 @@
|
||||
"""ClientSecret operations."""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from sqlalchemy import Null, distinct, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sshecret_backend import audit
|
||||
from sshecret_backend.api.common import (
|
||||
FlexID,
|
||||
IdType,
|
||||
get_client_by_id,
|
||||
resolve_client_id,
|
||||
)
|
||||
from sshecret_backend.api.secrets.schemas import (
|
||||
ClientReference,
|
||||
ClientSecretDetailList,
|
||||
ClientSecretResponse,
|
||||
ClientSecretStats,
|
||||
DistinctClientSecretStats,
|
||||
)
|
||||
from sshecret_backend.models import Client, ClientSecret
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
RelaxedId = uuid.UUID | str
|
||||
|
||||
|
||||
def _id(id: RelaxedId) -> uuid.UUID:
|
||||
"""Ensure that the ID is a uuid."""
|
||||
if isinstance(id, str):
|
||||
return uuid.UUID(id)
|
||||
return id
|
||||
|
||||
|
||||
class ClientSecretOperations:
|
||||
"""Perform operations on a client's ClientSecrets"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
request: Request,
|
||||
client: FlexID,
|
||||
include_deleted: bool = False,
|
||||
) -> None:
|
||||
"""Create operations class."""
|
||||
self._client_id: uuid.UUID | None = None
|
||||
self.client_name: str | None = None
|
||||
self.client: Client | None = None
|
||||
if client.type is IdType.ID:
|
||||
self._client_id = _id(client.value)
|
||||
else:
|
||||
self.client_name = str(client.value)
|
||||
|
||||
self.include_deleted: bool = include_deleted
|
||||
|
||||
self.session: AsyncSession = session
|
||||
self.request: Request = request
|
||||
|
||||
async def get_client_id(self) -> uuid.UUID | None:
|
||||
"""Get client ID."""
|
||||
if self._client_id:
|
||||
LOG.debug("Returning previously resolved client ID.")
|
||||
return self._client_id
|
||||
|
||||
if not self.client_name:
|
||||
raise RuntimeError("Error: No client information registered.")
|
||||
|
||||
client_id = await resolve_client_id(
|
||||
self.session, self.client_name, include_deleted=self.include_deleted
|
||||
)
|
||||
self._client_id = client_id
|
||||
LOG.debug("Saving client ID %s", client_id)
|
||||
return client_id
|
||||
|
||||
async def get_client(self) -> Client | None:
|
||||
"""Get client."""
|
||||
if self.client:
|
||||
return self.client
|
||||
client_id = await self.get_client_id()
|
||||
if not client_id:
|
||||
return None
|
||||
|
||||
client = await get_client_by_id(self.session, client_id)
|
||||
self.client = client
|
||||
return client
|
||||
|
||||
async def _get_client_secret(
|
||||
self, secret_identifier: FlexID
|
||||
) -> ClientSecret | None:
|
||||
"""Get client secret private method."""
|
||||
client = await self.get_client()
|
||||
if not client:
|
||||
LOG.debug("Client lookup failed.")
|
||||
return None
|
||||
|
||||
if secret_identifier.type is IdType.ID:
|
||||
match_id = _id(secret_identifier.value)
|
||||
LOG.debug("Searching for secrets matching ID %r", match_id)
|
||||
matches = [secret for secret in client.secrets if secret.id == match_id]
|
||||
else:
|
||||
LOG.debug("Searching for secrets matching name")
|
||||
matches = [
|
||||
secret
|
||||
for secret in client.secrets
|
||||
if secret.name == str(secret_identifier.value)
|
||||
]
|
||||
|
||||
for secret_match in matches:
|
||||
if secret_match.deleted:
|
||||
# TODO: Add override for deleted.
|
||||
LOG.debug("Found deleted secret")
|
||||
continue
|
||||
|
||||
await audit.audit_access_secret(
|
||||
self.session, self.request, client, secret_match
|
||||
)
|
||||
LOG.debug("Found matching secret")
|
||||
return secret_match
|
||||
|
||||
LOG.debug("No secrets matched.")
|
||||
return None
|
||||
|
||||
async def get_client_secret(
|
||||
self, secret_identifier: FlexID
|
||||
) -> ClientSecretResponse:
|
||||
"""Get client secret public model."""
|
||||
client_secret = await self._get_client_secret(secret_identifier)
|
||||
if not client_secret:
|
||||
raise HTTPException(status_code=404, detail="Unknown client or secret")
|
||||
return ClientSecretResponse.from_client_secret(client_secret)
|
||||
|
||||
async def create_client_secret(
|
||||
self, name: str, value: str, description: str | None = None
|
||||
) -> ClientSecretResponse:
|
||||
"""Create client Secret."""
|
||||
client = await self.get_client()
|
||||
if not client:
|
||||
raise HTTPException(status_code=404, detail="Client not found.")
|
||||
|
||||
existing_secret = await self._get_client_secret(FlexID.name(name))
|
||||
if existing_secret:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot add a secret. A different secret with the same name already exists.",
|
||||
)
|
||||
secret = ClientSecret(
|
||||
name=name,
|
||||
description=description,
|
||||
client=client,
|
||||
secret=value,
|
||||
)
|
||||
self.session.add(secret)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(secret, attribute_names=["client", "updated_at"])
|
||||
await audit.audit_create_secret(self.session, self.request, client, secret)
|
||||
return ClientSecretResponse.from_client_secret(secret)
|
||||
|
||||
async def update_client_secret_value(
|
||||
self, secret_identifier: FlexID, value: str
|
||||
) -> ClientSecretResponse:
|
||||
"""Update client secret value."""
|
||||
client_secret = await self._get_client_secret(secret_identifier)
|
||||
if not client_secret:
|
||||
# We can use this method to create a secret too.
|
||||
if secret_identifier.type is IdType.NAME:
|
||||
return await self.create_client_secret(
|
||||
str(secret_identifier.value), value
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="Unknown client or secret")
|
||||
client_secret.secret = value
|
||||
self.session.add(client_secret)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(client_secret, ["updated_at"])
|
||||
await audit.audit_update_secret(
|
||||
self.session, self.request, client_secret.client, client_secret
|
||||
)
|
||||
return ClientSecretResponse.from_client_secret(client_secret)
|
||||
|
||||
async def delete_client_secret(self, secret_identifier: FlexID) -> None:
|
||||
"""Delete a client secret."""
|
||||
client_secret = await self._get_client_secret(secret_identifier)
|
||||
if not client_secret:
|
||||
return
|
||||
|
||||
client_secret.deleted = True
|
||||
client_secret.deleted_at = datetime.now(timezone.utc)
|
||||
self.session.add(client_secret)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(client_secret, ["deleted", "deleted_at"])
|
||||
await audit.audit_delete_secret(
|
||||
self.session, self.request, client_secret.client, client_secret
|
||||
)
|
||||
|
||||
|
||||
async def resolve_client_secret_mapping(
|
||||
session: AsyncSession,
|
||||
) -> list[ClientSecretDetailList]:
|
||||
"""Resolve mapping of clients to secrets."""
|
||||
result = await session.execute(
|
||||
select(ClientSecret)
|
||||
.join(ClientSecret.client)
|
||||
.options(selectinload(ClientSecret.client))
|
||||
.where(Client.is_active.is_not(False))
|
||||
.where(ClientSecret.deleted.is_not(True))
|
||||
)
|
||||
client_secrets: dict[str, ClientSecretDetailList] = {}
|
||||
for secret in result.scalars().all():
|
||||
if secret.name not in client_secrets:
|
||||
client_secrets[secret.name] = ClientSecretDetailList(name=secret.name)
|
||||
client_secrets[secret.name].ids.append(str(secret.id))
|
||||
if not secret.client:
|
||||
continue
|
||||
client_secrets[secret.name].clients.append(
|
||||
ClientReference(id=str(secret.client.id), name=secret.client.name)
|
||||
)
|
||||
|
||||
return list(client_secrets.values())
|
||||
|
||||
|
||||
async def resolve_client_secret_clients(
|
||||
session: AsyncSession, name: str, include_deleted: bool = False
|
||||
) -> ClientSecretDetailList | None:
|
||||
"""Resolve client association to a secret."""
|
||||
statement = (
|
||||
select(ClientSecret)
|
||||
.options(selectinload(ClientSecret.client))
|
||||
.where(ClientSecret.name == name)
|
||||
)
|
||||
if not include_deleted:
|
||||
statement = statement.where(ClientSecret.deleted.is_not(True))
|
||||
|
||||
results = await session.execute(statement)
|
||||
clients: ClientSecretDetailList | None = None
|
||||
for client_secret in results.scalars().all():
|
||||
if not clients:
|
||||
# Ensure we don't create the object before we have at least one client.
|
||||
clients = ClientSecretDetailList(name=name)
|
||||
clients.ids.append(str(client_secret.id))
|
||||
if client_secret.client:
|
||||
clients.clients.append(
|
||||
ClientReference(
|
||||
id=str(client_secret.client.id), name=client_secret.client.name
|
||||
)
|
||||
)
|
||||
|
||||
return clients
|
||||
|
||||
|
||||
async def _get_secrets_total(session: AsyncSession) -> int:
|
||||
"""Get total amount of secrets excluding deleted."""
|
||||
statement = (
|
||||
select(func.count())
|
||||
.select_from(ClientSecret)
|
||||
.where(ClientSecret.deleted.is_not(True))
|
||||
)
|
||||
result = await session.execute(statement)
|
||||
return result.scalar_one()
|
||||
|
||||
|
||||
async def _get_distinct_secret_count(session: AsyncSession) -> int:
|
||||
"""Get the amount of distinct secrets, excluding deleted."""
|
||||
statement = (
|
||||
select(func.count())
|
||||
.select_from(ClientSecret)
|
||||
.where(ClientSecret.deleted.is_not(True))
|
||||
.distinct()
|
||||
)
|
||||
result = await session.execute(statement)
|
||||
return result.scalar_one()
|
||||
|
||||
|
||||
async def _get_secret_client_stats(
|
||||
session: AsyncSession,
|
||||
) -> list[DistinctClientSecretStats]:
|
||||
"""Get stats for each named secret."""
|
||||
statement = (
|
||||
(
|
||||
select(
|
||||
ClientSecret.name, func.count(distinct(Client.id)).label("client_count")
|
||||
)
|
||||
)
|
||||
.join(Client, ClientSecret.client_id)
|
||||
.where(ClientSecret.deleted.is_not(True), Client.is_deleted.is_not(True))
|
||||
.group_by(ClientSecret.name)
|
||||
)
|
||||
results = await session.execute(statement)
|
||||
stats: list[DistinctClientSecretStats] = []
|
||||
rows = results.all()
|
||||
for row in rows:
|
||||
secret_name, client_count = row.tuple()
|
||||
stats.append(DistinctClientSecretStats(name=secret_name, clients=client_count))
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
async def _get_secrets_without_clients(
|
||||
session: AsyncSession,
|
||||
) -> int:
|
||||
"""Get secrets without clients."""
|
||||
statement = (
|
||||
select(func.count(distinct(ClientSecret.name)))
|
||||
.where(ClientSecret.deleted.is_not(True))
|
||||
.where(ClientSecret.client_id.is_(Null))
|
||||
)
|
||||
result = await session.execute(statement)
|
||||
return result.scalar_one()
|
||||
|
||||
|
||||
async def get_client_secret_stats(session: AsyncSession) -> ClientSecretStats:
|
||||
"""Get stats for the client secrets.
|
||||
|
||||
TODO: Implement a route with this.
|
||||
"""
|
||||
distinct_secrets = await _get_distinct_secret_count(session)
|
||||
total_secrets = await _get_secrets_total(session)
|
||||
secrets_without_clients = await _get_secrets_without_clients(session)
|
||||
client_stats = await _get_secret_client_stats(session)
|
||||
|
||||
return ClientSecretStats(
|
||||
distinct_secrets=distinct_secrets,
|
||||
total_secrets=total_secrets,
|
||||
secrets_without_clients=secrets_without_clients,
|
||||
client_stats=client_stats,
|
||||
)
|
||||
@ -0,0 +1,107 @@
|
||||
"""Client Secret Router."""
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from sshecret_backend.api.common import FlexID
|
||||
from sshecret_backend.types import AsyncDBSessionDep
|
||||
from sshecret_backend.api.secrets.operations import (
|
||||
ClientSecretOperations,
|
||||
resolve_client_secret_clients,
|
||||
resolve_client_secret_mapping,
|
||||
)
|
||||
from sshecret_backend.api.schemas import BodyValue
|
||||
from sshecret_backend.api.secrets.schemas import (
|
||||
ClientSecretDetailList,
|
||||
ClientSecretPublic,
|
||||
ClientSecretResponse,
|
||||
)
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_client_secrets_router(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
||||
"""Create client secret router."""
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/clients/{client_identifier}/secrets/")
|
||||
async def add_secret_to_client(
|
||||
request: Request,
|
||||
client_identifier: str,
|
||||
client_secret: ClientSecretPublic,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> None:
|
||||
"""Add secret to a client."""
|
||||
client = FlexID.from_string(client_identifier)
|
||||
client_op = ClientSecretOperations(session, request, client)
|
||||
await client_op.create_client_secret(
|
||||
client_secret.name, client_secret.secret, client_secret.description
|
||||
)
|
||||
|
||||
@router.put("/clients/{client_identifier}/secrets/{secret_identifier}")
|
||||
async def update_client_secret(
|
||||
request: Request,
|
||||
client_identifier: str,
|
||||
secret_identifier: str,
|
||||
secret_data: BodyValue,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientSecretResponse:
|
||||
"""Update client secret."""
|
||||
client = FlexID.from_string(client_identifier)
|
||||
secret = FlexID.from_string(secret_identifier)
|
||||
client_op = ClientSecretOperations(session, request, client)
|
||||
return await client_op.update_client_secret_value(secret, secret_data.value)
|
||||
|
||||
@router.get("/clients/{client_identifier}/secrets/{secret_identifier}")
|
||||
async def request_client_secret_named(
|
||||
request: Request,
|
||||
client_identifier: str,
|
||||
secret_identifier: str,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientSecretResponse:
|
||||
"""Get a named secret from a named client."""
|
||||
client = FlexID.from_string(client_identifier)
|
||||
secret = FlexID.from_string(secret_identifier)
|
||||
LOG.debug("Resolved client FlexID: %r", client)
|
||||
LOG.debug("Resolved secret FlexID: %r", secret)
|
||||
client_op = ClientSecretOperations(session, request, client)
|
||||
return await client_op.get_client_secret(secret)
|
||||
|
||||
# TODO: delete_client_secret
|
||||
@router.delete("/clients/{client_identifier}/secrets/{secret_identifier}")
|
||||
async def delete_client_secret(
|
||||
request: Request,
|
||||
client_identifier: str,
|
||||
secret_identifier: str,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> None:
|
||||
"""Delete client secret."""
|
||||
client = FlexID.from_string(client_identifier)
|
||||
secret = FlexID.from_string(secret_identifier)
|
||||
client_op = ClientSecretOperations(session, request, client)
|
||||
await client_op.delete_client_secret(secret)
|
||||
|
||||
@router.get("/secrets/")
|
||||
async def get_secret_map(
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> list[ClientSecretDetailList]:
|
||||
"""Get a list of secrets and which clients have them."""
|
||||
return await resolve_client_secret_mapping(session)
|
||||
|
||||
@router.get("/secrets/{name}")
|
||||
async def get_secret_clients(
|
||||
name: str,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientSecretDetailList:
|
||||
"""Get a list of which clients has a named secret."""
|
||||
result = await resolve_client_secret_clients(session, name)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Secret not found.")
|
||||
return result
|
||||
|
||||
return router
|
||||
@ -0,0 +1,88 @@
|
||||
"""Schemas for the secrets router."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Self, override
|
||||
import uuid
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from sshecret_backend import models
|
||||
|
||||
|
||||
class ClientSecretPublic(BaseModel):
|
||||
"""Public model to manage client secrets."""
|
||||
|
||||
name: str
|
||||
secret: str
|
||||
description: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_client_secret(cls, client_secret: models.ClientSecret) -> Self:
|
||||
"""Instantiate from ClientSecret."""
|
||||
return cls(
|
||||
name=client_secret.name,
|
||||
secret=client_secret.secret,
|
||||
description=client_secret.description,
|
||||
)
|
||||
|
||||
|
||||
class ClientSecretResponse(ClientSecretPublic):
|
||||
"""A secret view."""
|
||||
|
||||
id: uuid.UUID
|
||||
created_at: datetime | None
|
||||
updated_at: datetime | None = None
|
||||
|
||||
@override
|
||||
@classmethod
|
||||
def from_client_secret(cls, client_secret: models.ClientSecret) -> Self:
|
||||
"""Instantiate from ClientSecret."""
|
||||
|
||||
return cls(
|
||||
id=client_secret.id,
|
||||
name=client_secret.name,
|
||||
secret=client_secret.secret,
|
||||
created_at=client_secret.created_at,
|
||||
updated_at=client_secret.updated_at,
|
||||
)
|
||||
|
||||
|
||||
class ClientSecretList(BaseModel):
|
||||
"""Model for aggregating identically named secrets."""
|
||||
|
||||
name: str
|
||||
clients: list[str]
|
||||
|
||||
|
||||
class ClientReference(BaseModel):
|
||||
"""Reference to a client."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
class ClientSecretDetailList(BaseModel):
|
||||
"""A more detailed version of the ClientSecretList."""
|
||||
|
||||
name: str
|
||||
ids: list[str] = Field(default_factory=list)
|
||||
clients: list[ClientReference] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DistinctClientSecretStats(BaseModel):
|
||||
"""Stats for distinct client secrets."""
|
||||
|
||||
name: str
|
||||
clients: int = 0
|
||||
|
||||
|
||||
class ClientSecretStats(BaseModel):
|
||||
"""Stats row for the clientSecret model.
|
||||
|
||||
Useful for pagination and statistics.
|
||||
"""
|
||||
|
||||
distinct_secrets: int
|
||||
total_secrets: int
|
||||
secrets_without_clients: int
|
||||
client_stats: list[DistinctClientSecretStats] = Field(default_factory=list)
|
||||
@ -12,16 +12,13 @@ from fastapi import (
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
|
||||
|
||||
from .models import init_db_async
|
||||
from .backend_api import get_backend_api
|
||||
from .db import setup_database, get_async_engine
|
||||
from .db import get_async_engine
|
||||
|
||||
from .settings import BackendSettings
|
||||
from .types import AsyncDBSessionDep
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -5,7 +5,14 @@ from fastapi import Request
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy, Operation, SubSystem
|
||||
from .models import (
|
||||
AuditLog,
|
||||
Client,
|
||||
ClientSecret,
|
||||
ClientAccessPolicy,
|
||||
Operation,
|
||||
SubSystem,
|
||||
)
|
||||
|
||||
|
||||
def _get_origin(request: Request) -> str | None:
|
||||
@ -128,6 +135,27 @@ async def audit_update_client(
|
||||
await _write_audit_log(session, request, entry, commit)
|
||||
|
||||
|
||||
async def audit_new_client_version(
|
||||
session: AsyncSession,
|
||||
request: Request,
|
||||
old_client: Client,
|
||||
new_client: Client,
|
||||
commit: bool = True,
|
||||
) -> None:
|
||||
"""Audit an update secret event."""
|
||||
entry = AuditLog(
|
||||
operation=Operation.UPDATE,
|
||||
client_id=old_client.id,
|
||||
client_name=old_client.name,
|
||||
message="Client data updated",
|
||||
data={
|
||||
"new_client_id": str(new_client.id),
|
||||
"new_client_version": new_client.version,
|
||||
},
|
||||
)
|
||||
await _write_audit_log(session, request, entry, commit)
|
||||
|
||||
|
||||
async def audit_update_secret(
|
||||
session: AsyncSession,
|
||||
request: Request,
|
||||
@ -224,6 +252,7 @@ async def audit_access_secret(
|
||||
)
|
||||
await _write_audit_log(session, request, entry, commit)
|
||||
|
||||
|
||||
async def audit_client_secret_list(
|
||||
session: AsyncSession, request: Request, commit: bool = True
|
||||
) -> None:
|
||||
@ -233,4 +262,3 @@ async def audit_client_secret_list(
|
||||
message="All secret names and their clients was viewed",
|
||||
)
|
||||
await _write_audit_log(session, request, entry, commit)
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
import bcrypt
|
||||
|
||||
|
||||
def hash_token(token: str) -> str:
|
||||
"""Hash a token."""
|
||||
pwbytes = token.encode("utf-8")
|
||||
|
||||
@ -9,7 +9,9 @@ from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sshecret_backend.db import DatabaseSessionManager
|
||||
from sshecret_backend.settings import BackendSettings
|
||||
from .api import get_audit_api, get_clients_api, get_policy_api, get_secrets_api
|
||||
from .api.audit.router import create_audit_router
|
||||
from .api.secrets.router import create_client_secrets_router
|
||||
from .api.clients.router import create_client_router
|
||||
from .auth import verify_token
|
||||
from .models import (
|
||||
APIClient,
|
||||
@ -59,9 +61,8 @@ def get_backend_api(
|
||||
dependencies=[Depends(validate_token)],
|
||||
)
|
||||
|
||||
backend_api.include_router(get_audit_api(get_db_session))
|
||||
backend_api.include_router(get_clients_api(get_db_session))
|
||||
backend_api.include_router(get_policy_api(get_db_session))
|
||||
backend_api.include_router(get_secrets_api(get_db_session))
|
||||
backend_api.include_router(create_audit_router(get_db_session))
|
||||
backend_api.include_router(create_client_router(get_db_session))
|
||||
backend_api.include_router(create_client_secrets_router(get_db_session))
|
||||
|
||||
return backend_api
|
||||
|
||||
@ -6,9 +6,15 @@ import sqlite3
|
||||
|
||||
from collections.abc import AsyncIterator, Generator, Callable
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, Literal
|
||||
from typing import Literal
|
||||
from sqlalchemy import create_engine, Engine, event, select
|
||||
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession, async_sessionmaker, create_async_engine, AsyncEngine
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncConnection,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
AsyncEngine,
|
||||
)
|
||||
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
|
||||
@ -21,11 +27,14 @@ from .models import APIClient, SubSystem
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
class DatabaseSessionManager:
|
||||
def __init__(self, host: URL, **engine_kwargs: str):
|
||||
self._engine: AsyncEngine | None = get_async_engine(host)
|
||||
self._sessionmaker: async_sessionmaker[AsyncSession] | None = async_sessionmaker(autocommit=False, bind=self._engine, expire_on_commit=False)
|
||||
self._sessionmaker: async_sessionmaker[AsyncSession] | None = (
|
||||
async_sessionmaker(
|
||||
autocommit=False, bind=self._engine, expire_on_commit=False
|
||||
)
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
if self._engine is None:
|
||||
@ -81,7 +90,6 @@ def setup_database(
|
||||
return engine, get_db_session
|
||||
|
||||
|
||||
|
||||
def get_engine(url: URL, echo: bool = False) -> Engine:
|
||||
"""Initialize the engine."""
|
||||
engine = create_engine(url, echo=echo)
|
||||
@ -102,6 +110,7 @@ def get_async_engine(url: URL, echo: bool = False, **engine_kwargs: str) -> Asyn
|
||||
"""Get an async engine."""
|
||||
engine = create_async_engine(url, echo=echo, **engine_kwargs)
|
||||
if url.drivername.startswith("sqlite+"):
|
||||
|
||||
@event.listens_for(engine.sync_engine, "connect")
|
||||
def set_sqlite_pragma(
|
||||
dbapi_connection: sqlite3.Connection, _connection_record: object
|
||||
@ -113,18 +122,21 @@ def get_async_engine(url: URL, echo: bool = False, **engine_kwargs: str) -> Asyn
|
||||
return engine
|
||||
|
||||
|
||||
|
||||
def create_api_token_with_value(session: Session, token: str, subsystem: Literal["admin", "sshd"]) -> None:
|
||||
def create_api_token_with_value(
|
||||
session: Session, token: str, subsystem: Literal["admin", "sshd"]
|
||||
) -> None:
|
||||
"""Create API token with a given value."""
|
||||
|
||||
existing = session.scalars(select(APIClient).where(APIClient.subsystem == SubSystem(subsystem))).first()
|
||||
existing = session.scalars(
|
||||
select(APIClient).where(APIClient.subsystem == SubSystem(subsystem))
|
||||
).first()
|
||||
if existing:
|
||||
if verify_token(token, existing.token):
|
||||
LOG.info("Token is up to date.")
|
||||
return
|
||||
LOG.info("Updating token value for subsystem %s", subsystem)
|
||||
hashed = hash_token(token)
|
||||
existing.token=hashed
|
||||
existing.token = hashed
|
||||
session.commit()
|
||||
return
|
||||
|
||||
@ -135,12 +147,19 @@ def create_api_token_with_value(session: Session, token: str, subsystem: Literal
|
||||
session.add(api_token)
|
||||
session.commit()
|
||||
|
||||
def create_api_token(session: Session, subsystem: Literal["admin", "sshd", "test"], recreate: bool = False) -> str:
|
||||
|
||||
def create_api_token(
|
||||
session: Session,
|
||||
subsystem: Literal["admin", "sshd", "test"],
|
||||
recreate: bool = False,
|
||||
) -> str:
|
||||
"""Create API token."""
|
||||
subsys = SubSystem(subsystem)
|
||||
token = secrets.token_urlsafe(32)
|
||||
hashed = hash_token(token)
|
||||
if existing := session.scalars(select(APIClient).where(APIClient.subsystem == subsys)).first():
|
||||
if existing := session.scalars(
|
||||
select(APIClient).where(APIClient.subsystem == subsys)
|
||||
).first():
|
||||
if not recreate:
|
||||
raise RuntimeError("Error: A token already exist for this subsystem.")
|
||||
existing.token = hashed
|
||||
|
||||
@ -51,13 +51,22 @@ class Client(Base):
|
||||
"""Clients."""
|
||||
|
||||
__tablename__: str = "client"
|
||||
__table_args__: tuple[sa.UniqueConstraint, ...] = (
|
||||
sa.UniqueConstraint("name", "version", name="uq_client_name_version"),
|
||||
)
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
name: Mapped[str] = mapped_column(sa.String, unique=True)
|
||||
version: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=1)
|
||||
|
||||
name: Mapped[str] = mapped_column(sa.String, nullable=False)
|
||||
|
||||
description: Mapped[str | None] = mapped_column(sa.String, nullable=True)
|
||||
public_key: Mapped[str] = mapped_column(sa.Text)
|
||||
public_key: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
|
||||
is_active: Mapped[bool] = mapped_column(sa.Boolean, default=True)
|
||||
is_deleted: Mapped[bool] = mapped_column(sa.Boolean, default=False)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||
@ -69,10 +78,23 @@ class Client(Base):
|
||||
onupdate=sa.func.now(),
|
||||
)
|
||||
|
||||
deleted_at: Mapped[datetime | None] = mapped_column(
|
||||
sa.DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
secrets: Mapped[list["ClientSecret"]] = relationship(
|
||||
back_populates="client", passive_deletes=True
|
||||
)
|
||||
|
||||
previous_version_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
sa.Uuid(as_uuid=True),
|
||||
sa.ForeignKey("client.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
previous_version: Mapped["Client | None"] = relationship(
|
||||
"Client", remote_side=[id], backref="versions"
|
||||
)
|
||||
|
||||
policies: Mapped[list["ClientAccessPolicy"]] = relationship(back_populates="client")
|
||||
|
||||
|
||||
@ -117,7 +139,7 @@ class ClientSecret(Base):
|
||||
sa.Uuid(as_uuid=True), sa.ForeignKey("client.id", ondelete="CASCADE")
|
||||
)
|
||||
client: Mapped[Client] = relationship(back_populates="secrets")
|
||||
invalidated: Mapped[bool] = mapped_column(default=False)
|
||||
deleted: Mapped[bool] = mapped_column(default=False)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||
@ -129,6 +151,10 @@ class ClientSecret(Base):
|
||||
onupdate=sa.func.now(),
|
||||
)
|
||||
|
||||
deleted_at: Mapped[datetime | None] = mapped_column(
|
||||
sa.DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
|
||||
class APIClient(Base):
|
||||
"""A client on the API.
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
from collections.abc import AsyncGenerator, Callable, Generator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
|
||||
@ -11,7 +11,7 @@ from fastapi.testclient import TestClient
|
||||
from sshecret.crypto import generate_private_key, generate_public_key_string
|
||||
from sshecret_backend.app import create_backend_app
|
||||
from sshecret_backend.testing import create_test_token
|
||||
from sshecret_backend.view_models import AuditView
|
||||
from sshecret_backend.api.audit.schemas import AuditView
|
||||
from sshecret_backend.settings import BackendSettings
|
||||
|
||||
|
||||
@ -20,7 +20,7 @@ handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter("'%(asctime)s - %(levelname)s - %(message)s'")
|
||||
handler.setFormatter(formatter)
|
||||
LOG.addHandler(handler)
|
||||
LOG.setLevel(logging.DEBUG)
|
||||
#LOG.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
def make_test_key() -> str:
|
||||
@ -167,8 +167,8 @@ def test_put_add_secret(test_client: TestClient) -> None:
|
||||
response_model = response.json()
|
||||
del response_model["created_at"]
|
||||
del response_model["updated_at"]
|
||||
assert response_model == data
|
||||
|
||||
for key, value in data.items():
|
||||
assert response_model.get(key) == value
|
||||
|
||||
def test_put_update_secret(test_client: TestClient) -> None:
|
||||
"""Test updating a client secret."""
|
||||
@ -407,7 +407,7 @@ def test_get_secret_list(test_client: TestClient) -> None:
|
||||
assert len(entry["clients"]) == 4
|
||||
else:
|
||||
assert len(entry["clients"]) == 1
|
||||
assert entry["clients"][0] == entry["name"]
|
||||
assert entry["clients"][0]["name"] == entry["name"]
|
||||
|
||||
|
||||
def test_get_secret_clients(test_client: TestClient) -> None:
|
||||
@ -428,8 +428,9 @@ def test_get_secret_clients(test_client: TestClient) -> None:
|
||||
|
||||
data = resp.json()
|
||||
assert data["name"] == "commonsecret"
|
||||
assert "client-0" in data["clients"]
|
||||
assert "client-1" not in data["clients"]
|
||||
client_names = [client["name"] for client in data["clients"]]
|
||||
assert "client-0" in client_names
|
||||
assert "client-1" not in client_names
|
||||
assert len(data["clients"]) == 2
|
||||
|
||||
|
||||
@ -473,7 +474,7 @@ def test_operations_with_id(test_client: TestClient) -> None:
|
||||
data = resp.json()
|
||||
client = data["clients"][0]
|
||||
client_id = client["id"]
|
||||
resp = test_client.get(f"/api/v1/clients/{client_id}")
|
||||
resp = test_client.get(f"/api/v1/clients/id:{client_id}")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["name"] == "test"
|
||||
@ -559,3 +560,47 @@ def test_filter_audit_log(test_client: TestClient) -> None:
|
||||
|
||||
assert data["results"][0]["operation"] == "login"
|
||||
assert data["results"][0]["message"] == "message1"
|
||||
|
||||
|
||||
def test_secret_flexid(test_client: TestClient) -> None:
|
||||
"""Test flexible IDs in the secret API."""
|
||||
client_name = "test"
|
||||
create_response = create_client(
|
||||
test_client,
|
||||
client_name,
|
||||
)
|
||||
assert create_response.status_code == 200
|
||||
assert "id" in create_response.json()
|
||||
client_id = create_response.json()["id"]
|
||||
|
||||
# Create a secret using the client name
|
||||
secrets: dict[str, str] = {}
|
||||
resp = test_client.put(
|
||||
"/api/v1/clients/test/secrets/clientnamesecret",
|
||||
json={"value": "secret"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
secret_data = resp.json()
|
||||
assert "id" in secret_data
|
||||
secrets["clientnamesecret"] = secret_data["id"]
|
||||
|
||||
# Create one using the client ID
|
||||
resp = test_client.put(
|
||||
f"/api/v1/clients/id:{client_id}/secrets/clientidsecret",
|
||||
json={"value": "secret"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
secret_data = resp.json()
|
||||
assert "id" in secret_data
|
||||
secrets["clientidsecret"] = secret_data["id"]
|
||||
|
||||
# Let's try fetching the various permutations
|
||||
for client_identifier in ("test", f"id:{client_id}"):
|
||||
for secret_name, secret_id in secrets.items():
|
||||
for secret_identifier in (secret_name, f"id:{secret_id}"):
|
||||
resp = test_client.get(f"/api/v1/clients/{client_identifier}/secrets/{secret_identifier}")
|
||||
assert resp.status_code == 200
|
||||
resp_body = resp.json()
|
||||
assert "id" in resp_body
|
||||
assert resp_body["id"] == secret_id
|
||||
|
||||
Reference in New Issue
Block a user