Refactor backend views, update secret model #24
@ -1,8 +1,7 @@
|
|||||||
"""API factory modules."""
|
"""API factory modules."""
|
||||||
|
|
||||||
from .audit import get_audit_api
|
from .audit import get_audit_api
|
||||||
from .clients import get_clients_api
|
|
||||||
from .policies import get_policy_api
|
from .policies import get_policy_api
|
||||||
from .secrets import get_secrets_api
|
from .secrets import get_secrets_api
|
||||||
|
|
||||||
__all__ = ["get_audit_api", "get_clients_api", "get_policy_api", "get_secrets_api"]
|
__all__ = ["get_audit_api", "get_policy_api", "get_secrets_api"]
|
||||||
|
|||||||
@ -1,227 +0,0 @@
|
|||||||
"""Client sub-api factory."""
|
|
||||||
|
|
||||||
# pyright: reportUnusedFunction=false
|
|
||||||
|
|
||||||
import uuid
|
|
||||||
import logging
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
|
||||||
from typing import Annotated, Any, Self, TypeVar, cast
|
|
||||||
|
|
||||||
from sqlalchemy import select, func
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.sql import Select
|
|
||||||
from sshecret_backend.types import AsyncDBSessionDep
|
|
||||||
from sshecret_backend.models import Client, ClientSecret
|
|
||||||
from sshecret_backend.view_models import (
|
|
||||||
ClientCreate,
|
|
||||||
ClientQueryResult,
|
|
||||||
ClientView,
|
|
||||||
ClientUpdate,
|
|
||||||
)
|
|
||||||
from sshecret_backend import audit
|
|
||||||
from .common import get_client_by_id_or_name, client_with_relationships
|
|
||||||
|
|
||||||
|
|
||||||
class ClientListParams(BaseModel):
|
|
||||||
"""Client list parameters."""
|
|
||||||
|
|
||||||
limit: int = Field(100, gt=0, le=100)
|
|
||||||
offset: int = Field(0, ge=0)
|
|
||||||
id: uuid.UUID | None = None
|
|
||||||
name: str | None = None
|
|
||||||
name__like: str | None = None
|
|
||||||
name__contains: str | None = None
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def validate_expressions(self) -> Self:
|
|
||||||
"""Validate mutually exclusive expression."""
|
|
||||||
name_filter = False
|
|
||||||
if self.name__like or self.name__contains:
|
|
||||||
name_filter = True
|
|
||||||
if self.name__like and self.name__contains:
|
|
||||||
raise ValueError("You may only specify one name expression")
|
|
||||||
if self.name and name_filter:
|
|
||||||
raise ValueError(
|
|
||||||
"You must either specify name or one of name__like or name__contains"
|
|
||||||
)
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
def filter_client_statement(
|
|
||||||
statement: Select[Any], params: ClientListParams, ignore_limits: bool = False
|
|
||||||
) -> Select[Any]:
|
|
||||||
"""Filter a statement with the provided params."""
|
|
||||||
if params.id:
|
|
||||||
statement = statement.where(Client.id == params.id)
|
|
||||||
|
|
||||||
if params.name:
|
|
||||||
statement = statement.where(Client.name == params.name)
|
|
||||||
elif params.name__like:
|
|
||||||
statement = statement.where(Client.name.like(params.name__like))
|
|
||||||
elif params.name__contains:
|
|
||||||
statement = statement.where(Client.name.contains(params.name__contains))
|
|
||||||
|
|
||||||
if ignore_limits:
|
|
||||||
return statement
|
|
||||||
|
|
||||||
return statement.limit(params.limit).offset(params.offset)
|
|
||||||
|
|
||||||
|
|
||||||
def get_clients_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
|
||||||
"""Construct clients sub-api."""
|
|
||||||
router = APIRouter()
|
|
||||||
|
|
||||||
@router.get("/clients/")
|
|
||||||
async def get_clients(
|
|
||||||
filter_query: Annotated[ClientListParams, Query()],
|
|
||||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
|
||||||
) -> ClientQueryResult:
|
|
||||||
"""Get clients."""
|
|
||||||
# Get total results first
|
|
||||||
count_statement = select(func.count("*")).select_from(Client)
|
|
||||||
count_statement = cast(Select[tuple[int]], filter_client_statement(count_statement, filter_query, True))
|
|
||||||
|
|
||||||
total_results = (await session.scalars(count_statement)).one()
|
|
||||||
|
|
||||||
statement = filter_client_statement(client_with_relationships(), filter_query, False)
|
|
||||||
|
|
||||||
results = await session.scalars(statement)
|
|
||||||
remainder = total_results - filter_query.offset - filter_query.limit
|
|
||||||
if remainder < 0:
|
|
||||||
remainder = 0
|
|
||||||
|
|
||||||
clients = list(results.all())
|
|
||||||
clients_view = ClientView.from_client_list(clients)
|
|
||||||
return ClientQueryResult(
|
|
||||||
clients=clients_view,
|
|
||||||
total_results=total_results,
|
|
||||||
remaining_results=remainder,
|
|
||||||
)
|
|
||||||
|
|
||||||
@router.get("/clients/{name}")
|
|
||||||
async def get_client(
|
|
||||||
name: str,
|
|
||||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
|
||||||
) -> ClientView:
|
|
||||||
"""Fetch a client."""
|
|
||||||
client = await get_client_by_id_or_name(session, name)
|
|
||||||
if not client:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404, detail="Cannot find a client with the given name."
|
|
||||||
)
|
|
||||||
return ClientView.from_client(client)
|
|
||||||
|
|
||||||
@router.delete("/clients/{name}")
|
|
||||||
async def delete_client(
|
|
||||||
request: Request,
|
|
||||||
name: str,
|
|
||||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
|
||||||
) -> None:
|
|
||||||
"""Delete a client."""
|
|
||||||
client = await get_client_by_id_or_name(session, name)
|
|
||||||
if not client:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404, detail="Cannot find a client with the given name."
|
|
||||||
)
|
|
||||||
|
|
||||||
await session.delete(client)
|
|
||||||
await session.commit()
|
|
||||||
await audit.audit_delete_client(session, request, client)
|
|
||||||
|
|
||||||
@router.post("/clients/")
|
|
||||||
async def create_client(
|
|
||||||
request: Request,
|
|
||||||
client: ClientCreate,
|
|
||||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
|
||||||
) -> ClientView:
|
|
||||||
"""Create client."""
|
|
||||||
existing = await get_client_by_id_or_name(session, client.name)
|
|
||||||
if existing:
|
|
||||||
raise HTTPException(400, detail="Error: Already a client with that name.")
|
|
||||||
|
|
||||||
db_client = client.to_client()
|
|
||||||
session.add(db_client)
|
|
||||||
await session.commit()
|
|
||||||
await session.refresh(db_client)
|
|
||||||
db_client = await get_client_by_id_or_name(session, client.name)
|
|
||||||
if not db_client:
|
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Could not create the client.")
|
|
||||||
await audit.audit_create_client(session, request, db_client)
|
|
||||||
return ClientView.from_client(db_client)
|
|
||||||
|
|
||||||
@router.post("/clients/{name}/public-key")
|
|
||||||
async def update_client_public_key(
|
|
||||||
request: Request,
|
|
||||||
name: str,
|
|
||||||
client_update: ClientUpdate,
|
|
||||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
|
||||||
) -> ClientView:
|
|
||||||
"""Change the public key of a client.
|
|
||||||
|
|
||||||
This invalidates all secrets.
|
|
||||||
"""
|
|
||||||
client = await get_client_by_id_or_name(session, name)
|
|
||||||
if not client:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404, detail="Cannot find a client with the given name."
|
|
||||||
)
|
|
||||||
client.public_key = client_update.public_key
|
|
||||||
matching_secrets = await session.scalars(select(ClientSecret).where(ClientSecret.client_id == client.id))
|
|
||||||
for secret in matching_secrets.all():
|
|
||||||
LOG.debug("Invalidated secret %s", secret.id)
|
|
||||||
secret.invalidated = True
|
|
||||||
secret.client_id = None
|
|
||||||
|
|
||||||
session.add(client)
|
|
||||||
await session.refresh(client)
|
|
||||||
await session.commit()
|
|
||||||
await audit.audit_invalidate_secrets(session, request, client)
|
|
||||||
|
|
||||||
return ClientView.from_client(client)
|
|
||||||
|
|
||||||
@router.put("/clients/{name}")
|
|
||||||
async def update_client(
|
|
||||||
request: Request,
|
|
||||||
name: str,
|
|
||||||
client_update: ClientCreate,
|
|
||||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
|
||||||
) -> ClientView:
|
|
||||||
"""Change the public key of a client.
|
|
||||||
|
|
||||||
This invalidates all secrets.
|
|
||||||
"""
|
|
||||||
client = await get_client_by_id_or_name(session, name)
|
|
||||||
if not client:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404, detail="Cannot find a client with the given name."
|
|
||||||
)
|
|
||||||
client.name = client_update.name
|
|
||||||
client.description = client_update.description
|
|
||||||
public_key_updated = False
|
|
||||||
if client_update.public_key != client.public_key:
|
|
||||||
public_key_updated = True
|
|
||||||
client_secrets = await session.scalars(
|
|
||||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
|
||||||
)
|
|
||||||
for secret in client_secrets.all():
|
|
||||||
LOG.debug("Invalidated secret %s", secret.id)
|
|
||||||
secret.invalidated = True
|
|
||||||
secret.client_id = None
|
|
||||||
|
|
||||||
session.add(client)
|
|
||||||
await session.commit()
|
|
||||||
await session.refresh(client)
|
|
||||||
await audit.audit_update_client(session, request, client)
|
|
||||||
if public_key_updated:
|
|
||||||
await audit.audit_invalidate_secrets(session, request, client)
|
|
||||||
|
|
||||||
return ClientView.from_client(client)
|
|
||||||
|
|
||||||
return router
|
|
||||||
@ -3,15 +3,27 @@
|
|||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
import bcrypt
|
import bcrypt
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from sqlalchemy import Select
|
from sqlalchemy import Select
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from sshecret_backend.models import Client
|
from sshecret_backend.models import Client, ClientAccessPolicy
|
||||||
|
|
||||||
|
RE_UUID = re.compile(
|
||||||
|
"^[0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NewClientVersion:
|
||||||
|
"""New client version dataclass."""
|
||||||
|
|
||||||
|
client: Client
|
||||||
|
policies: list[ClientAccessPolicy] = field(default_factory=list)
|
||||||
|
|
||||||
RE_UUID = re.compile("^[0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}$")
|
|
||||||
|
|
||||||
def verify_token(token: str, stored_hash: str) -> bool:
|
def verify_token(token: str, stored_hash: str) -> bool:
|
||||||
"""Verify token."""
|
"""Verify token."""
|
||||||
@ -19,12 +31,19 @@ def verify_token(token: str, stored_hash: str) -> bool:
|
|||||||
stored_bytes = stored_hash.encode("utf-8")
|
stored_bytes = stored_hash.encode("utf-8")
|
||||||
return bcrypt.checkpw(token_bytes, stored_bytes)
|
return bcrypt.checkpw(token_bytes, stored_bytes)
|
||||||
|
|
||||||
async def reload_client_with_relationships(session: AsyncSession, client: Client) -> Client:
|
|
||||||
|
async def reload_client_with_relationships(
|
||||||
|
session: AsyncSession, client: Client
|
||||||
|
) -> Client:
|
||||||
"""Reload a client from the database."""
|
"""Reload a client from the database."""
|
||||||
session.expunge(client)
|
session.expunge(client)
|
||||||
stmt = (
|
stmt = (
|
||||||
select(Client)
|
select(Client)
|
||||||
.options(selectinload(Client.policies), selectinload(Client.secrets))
|
.options(
|
||||||
|
selectinload(Client.policies),
|
||||||
|
selectinload(Client.secrets),
|
||||||
|
selectinload(Client.previous_version),
|
||||||
|
)
|
||||||
.where(Client.id == client.id)
|
.where(Client.id == client.id)
|
||||||
)
|
)
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
@ -36,13 +55,26 @@ def client_with_relationships() -> Select[tuple[Client]]:
|
|||||||
return select(Client).options(
|
return select(Client).options(
|
||||||
selectinload(Client.secrets),
|
selectinload(Client.secrets),
|
||||||
selectinload(Client.policies),
|
selectinload(Client.policies),
|
||||||
|
selectinload(Client.previous_version),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_client_by_name(session: AsyncSession, name: str) -> Client | None:
|
|
||||||
"""Get client by name."""
|
async def resolve_client_id(
|
||||||
client_filter = client_with_relationships().where(Client.name == name)
|
session: AsyncSession, name: str, version: int | None = None, include_deleted: bool = False,
|
||||||
client_results = await session.execute(client_filter)
|
) -> uuid.UUID | None:
|
||||||
return client_results.scalars().first()
|
"""Get the ID of a client name."""
|
||||||
|
if include_deleted:
|
||||||
|
client_filter = client_with_relationships().where(Client.name == name)
|
||||||
|
else:
|
||||||
|
client_filter = query_active_clients().where(Client.name == name)
|
||||||
|
if version:
|
||||||
|
client_filter = client_filter.where(Client.version == version)
|
||||||
|
|
||||||
|
client_result = await session.execute(client_filter)
|
||||||
|
if client := client_result.scalars().first():
|
||||||
|
return client.id
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | None:
|
async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | None:
|
||||||
"""Get client by ID."""
|
"""Get client by ID."""
|
||||||
@ -50,10 +82,75 @@ async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | Non
|
|||||||
client_results = await session.execute(client_filter)
|
client_results = await session.execute(client_filter)
|
||||||
return client_results.scalars().first()
|
return client_results.scalars().first()
|
||||||
|
|
||||||
async def get_client_by_id_or_name(session: AsyncSession, id_or_name: str) -> Client | None:
|
|
||||||
|
async def get_client_by_id_or_name(
|
||||||
|
session: AsyncSession, id_or_name: str
|
||||||
|
) -> Client | None:
|
||||||
"""Get client either by id or name."""
|
"""Get client either by id or name."""
|
||||||
if RE_UUID.match(id_or_name):
|
if RE_UUID.match(id_or_name):
|
||||||
id = uuid.UUID(id_or_name)
|
id = uuid.UUID(id_or_name)
|
||||||
return await get_client_by_id(session, id)
|
return await get_client_by_id(session, id)
|
||||||
|
|
||||||
return await get_client_by_name(session, id_or_name)
|
return await get_client_by_name(session, id_or_name)
|
||||||
|
|
||||||
|
|
||||||
|
def query_active_clients() -> Select[tuple[Client]]:
|
||||||
|
"""Get all active clients."""
|
||||||
|
client_filter = (
|
||||||
|
client_with_relationships()
|
||||||
|
.where(Client.is_active.is_(True))
|
||||||
|
.where(Client.is_deleted.is_(False))
|
||||||
|
)
|
||||||
|
return client_filter
|
||||||
|
|
||||||
|
|
||||||
|
async def get_client_by_name(session: AsyncSession, name: str) -> Client | None:
|
||||||
|
"""Get client by name.
|
||||||
|
|
||||||
|
This will get the latest client version, unless it's deleted.
|
||||||
|
"""
|
||||||
|
client_filter = (
|
||||||
|
client_with_relationships()
|
||||||
|
.where(Client.is_active.is_(True))
|
||||||
|
.where(Client.is_deleted.is_not(True))
|
||||||
|
.where(Client.name == name)
|
||||||
|
.order_by(Client.version.desc())
|
||||||
|
)
|
||||||
|
client_result = await session.execute(client_filter)
|
||||||
|
return client_result.scalars().first()
|
||||||
|
|
||||||
|
|
||||||
|
async def refresh_client(session: AsyncSession, client: Client) -> None:
|
||||||
|
"""Refresh the client and load in all relationships."""
|
||||||
|
await session.refresh(
|
||||||
|
client, attribute_names=["secrets", "policies", "previous_version", "updated_at"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def create_new_client_version(
|
||||||
|
session: AsyncSession, current_client: Client, new_public_key: str
|
||||||
|
) -> Client:
|
||||||
|
new_client = Client(
|
||||||
|
name=current_client.name,
|
||||||
|
version=current_client.version + 1,
|
||||||
|
description=current_client.description,
|
||||||
|
public_key=new_public_key,
|
||||||
|
previous_version_id=current_client.id,
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mark current client as inactive
|
||||||
|
current_client.is_active = False
|
||||||
|
|
||||||
|
# Copy policies
|
||||||
|
for policy in current_client.policies:
|
||||||
|
copied_policy = ClientAccessPolicy(
|
||||||
|
client=new_client,
|
||||||
|
address=policy.source,
|
||||||
|
)
|
||||||
|
session.add(copied_policy)
|
||||||
|
|
||||||
|
session.add(new_client)
|
||||||
|
await session.flush()
|
||||||
|
await refresh_client(session, new_client)
|
||||||
|
return new_client
|
||||||
|
|||||||
@ -12,16 +12,13 @@ from fastapi import (
|
|||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from sqlalchemy import Engine
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
|
||||||
|
|
||||||
|
|
||||||
from .models import init_db_async
|
from .models import init_db_async
|
||||||
from .backend_api import get_backend_api
|
from .backend_api import get_backend_api
|
||||||
from .db import setup_database, get_async_engine
|
from .db import get_async_engine
|
||||||
|
|
||||||
from .settings import BackendSettings
|
from .settings import BackendSettings
|
||||||
from .types import AsyncDBSessionDep
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@ -9,7 +9,8 @@ from sqlalchemy import select
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sshecret_backend.db import DatabaseSessionManager
|
from sshecret_backend.db import DatabaseSessionManager
|
||||||
from sshecret_backend.settings import BackendSettings
|
from sshecret_backend.settings import BackendSettings
|
||||||
from .api import get_audit_api, get_clients_api, get_policy_api, get_secrets_api
|
from .api import get_audit_api, get_policy_api, get_secrets_api
|
||||||
|
from .api.clients.router import create_client_router
|
||||||
from .auth import verify_token
|
from .auth import verify_token
|
||||||
from .models import (
|
from .models import (
|
||||||
APIClient,
|
APIClient,
|
||||||
@ -60,7 +61,7 @@ def get_backend_api(
|
|||||||
)
|
)
|
||||||
|
|
||||||
backend_api.include_router(get_audit_api(get_db_session))
|
backend_api.include_router(get_audit_api(get_db_session))
|
||||||
backend_api.include_router(get_clients_api(get_db_session))
|
backend_api.include_router(create_client_router(get_db_session))
|
||||||
backend_api.include_router(get_policy_api(get_db_session))
|
backend_api.include_router(get_policy_api(get_db_session))
|
||||||
backend_api.include_router(get_secrets_api(get_db_session))
|
backend_api.include_router(get_secrets_api(get_db_session))
|
||||||
|
|
||||||
|
|||||||
@ -51,13 +51,22 @@ class Client(Base):
|
|||||||
"""Clients."""
|
"""Clients."""
|
||||||
|
|
||||||
__tablename__: str = "client"
|
__tablename__: str = "client"
|
||||||
|
__table_args__: tuple[sa.UniqueConstraint, ...] = (
|
||||||
|
sa.UniqueConstraint("name", "version", name="uq_client_name_version"),
|
||||||
|
)
|
||||||
|
|
||||||
id: Mapped[uuid.UUID] = mapped_column(
|
id: Mapped[uuid.UUID] = mapped_column(
|
||||||
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
|
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||||
)
|
)
|
||||||
name: Mapped[str] = mapped_column(sa.String, unique=True)
|
version: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=1)
|
||||||
|
|
||||||
|
name: Mapped[str] = mapped_column(sa.String, nullable=False)
|
||||||
|
|
||||||
description: Mapped[str | None] = mapped_column(sa.String, nullable=True)
|
description: Mapped[str | None] = mapped_column(sa.String, nullable=True)
|
||||||
public_key: Mapped[str] = mapped_column(sa.Text)
|
public_key: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||||
|
|
||||||
|
is_active: Mapped[bool] = mapped_column(sa.Boolean, default=True)
|
||||||
|
is_deleted: Mapped[bool] = mapped_column(sa.Boolean, default=False)
|
||||||
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||||
@ -69,10 +78,26 @@ class Client(Base):
|
|||||||
onupdate=sa.func.now(),
|
onupdate=sa.func.now(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
deleted_at: Mapped[datetime | None] = mapped_column(
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=True
|
||||||
|
)
|
||||||
|
|
||||||
secrets: Mapped[list["ClientSecret"]] = relationship(
|
secrets: Mapped[list["ClientSecret"]] = relationship(
|
||||||
back_populates="client", passive_deletes=True
|
back_populates="client", passive_deletes=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
previous_version_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
|
sa.Uuid(as_uuid=True),
|
||||||
|
sa.ForeignKey("client.id", ondelete="SET NULL"),
|
||||||
|
nullable=True,
|
||||||
|
)
|
||||||
|
previous_version: Mapped["Client | None"] = relationship(
|
||||||
|
"Client",
|
||||||
|
remote_side=[id],
|
||||||
|
backref="versions"
|
||||||
|
)
|
||||||
|
|
||||||
policies: Mapped[list["ClientAccessPolicy"]] = relationship(back_populates="client")
|
policies: Mapped[list["ClientAccessPolicy"]] = relationship(back_populates="client")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -2,82 +2,15 @@
|
|||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Annotated, Self, Sequence, override
|
from typing import Self, override
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
from pydantic import AfterValidator, BaseModel, Field, IPvAnyAddress, IPvAnyNetwork
|
from pydantic import BaseModel, Field, IPvAnyAddress, IPvAnyNetwork
|
||||||
|
|
||||||
from sshecret.crypto import public_key_validator
|
|
||||||
|
|
||||||
from . import models
|
from . import models
|
||||||
|
|
||||||
|
|
||||||
class ClientView(BaseModel):
|
|
||||||
"""View for a single client."""
|
|
||||||
|
|
||||||
id: uuid.UUID
|
|
||||||
name: str
|
|
||||||
description: str | None = None
|
|
||||||
public_key: str
|
|
||||||
policies: list[str] = ["0.0.0.0/0", "::/0"]
|
|
||||||
secrets: list[str] = Field(default_factory=list)
|
|
||||||
created_at: datetime | None
|
|
||||||
updated_at: datetime | None = None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_client_list(cls, clients: list[models.Client]) -> list[Self]:
|
|
||||||
"""Generate a list of responses from a list of clients."""
|
|
||||||
responses: list[Self] = [cls.from_client(client) for client in clients]
|
|
||||||
return responses
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_client(cls, client: models.Client) -> Self:
|
|
||||||
"""Instantiate from a client."""
|
|
||||||
view = cls(
|
|
||||||
id=client.id,
|
|
||||||
name=client.name,
|
|
||||||
description=client.description,
|
|
||||||
public_key=client.public_key,
|
|
||||||
created_at=client.created_at,
|
|
||||||
updated_at=client.updated_at or None,
|
|
||||||
)
|
|
||||||
if client.secrets:
|
|
||||||
view.secrets = [secret.name for secret in client.secrets]
|
|
||||||
|
|
||||||
if client.policies:
|
|
||||||
view.policies = [policy.source for policy in client.policies]
|
|
||||||
|
|
||||||
return view
|
|
||||||
|
|
||||||
|
|
||||||
class ClientQueryResult(BaseModel):
|
|
||||||
"""Result class for queries towards the client list."""
|
|
||||||
|
|
||||||
clients: list[ClientView] = Field(default_factory=list)
|
|
||||||
total_results: int
|
|
||||||
remaining_results: int
|
|
||||||
|
|
||||||
|
|
||||||
class ClientCreate(BaseModel):
|
|
||||||
"""Model to create a client."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
description: str | None = None
|
|
||||||
public_key: Annotated[str, AfterValidator(public_key_validator)]
|
|
||||||
|
|
||||||
def to_client(self) -> models.Client:
|
|
||||||
"""Instantiate a client."""
|
|
||||||
return models.Client(
|
|
||||||
name=self.name,
|
|
||||||
public_key=self.public_key,
|
|
||||||
description=self.description,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ClientUpdate(BaseModel):
|
|
||||||
"""Model to update the client public key."""
|
|
||||||
|
|
||||||
public_key: Annotated[str, AfterValidator(public_key_validator)]
|
|
||||||
|
|
||||||
|
|
||||||
class BodyValue(BaseModel):
|
class BodyValue(BaseModel):
|
||||||
"""A generic model with just a value parameter."""
|
"""A generic model with just a value parameter."""
|
||||||
|
|||||||
@ -20,7 +20,7 @@ handler = logging.StreamHandler()
|
|||||||
formatter = logging.Formatter("'%(asctime)s - %(levelname)s - %(message)s'")
|
formatter = logging.Formatter("'%(asctime)s - %(levelname)s - %(message)s'")
|
||||||
handler.setFormatter(formatter)
|
handler.setFormatter(formatter)
|
||||||
LOG.addHandler(handler)
|
LOG.addHandler(handler)
|
||||||
LOG.setLevel(logging.DEBUG)
|
#LOG.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
def make_test_key() -> str:
|
def make_test_key() -> str:
|
||||||
@ -473,7 +473,7 @@ def test_operations_with_id(test_client: TestClient) -> None:
|
|||||||
data = resp.json()
|
data = resp.json()
|
||||||
client = data["clients"][0]
|
client = data["clients"][0]
|
||||||
client_id = client["id"]
|
client_id = client["id"]
|
||||||
resp = test_client.get(f"/api/v1/clients/{client_id}")
|
resp = test_client.get(f"/api/v1/clients/by-id/{client_id}")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
assert data["name"] == "test"
|
assert data["name"] == "test"
|
||||||
|
|||||||
Reference in New Issue
Block a user