Refactor backend views, update secret model #24
@ -1,8 +1,7 @@
|
||||
"""API factory modules."""
|
||||
|
||||
from .audit import get_audit_api
|
||||
from .clients import get_clients_api
|
||||
from .policies import get_policy_api
|
||||
from .secrets import get_secrets_api
|
||||
|
||||
__all__ = ["get_audit_api", "get_clients_api", "get_policy_api", "get_secrets_api"]
|
||||
__all__ = ["get_audit_api", "get_policy_api", "get_secrets_api"]
|
||||
|
||||
@ -1,227 +0,0 @@
|
||||
"""Client sub-api factory."""
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
|
||||
import uuid
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing import Annotated, Any, Self, TypeVar, cast
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.sql import Select
|
||||
from sshecret_backend.types import AsyncDBSessionDep
|
||||
from sshecret_backend.models import Client, ClientSecret
|
||||
from sshecret_backend.view_models import (
|
||||
ClientCreate,
|
||||
ClientQueryResult,
|
||||
ClientView,
|
||||
ClientUpdate,
|
||||
)
|
||||
from sshecret_backend import audit
|
||||
from .common import get_client_by_id_or_name, client_with_relationships
|
||||
|
||||
|
||||
class ClientListParams(BaseModel):
|
||||
"""Client list parameters."""
|
||||
|
||||
limit: int = Field(100, gt=0, le=100)
|
||||
offset: int = Field(0, ge=0)
|
||||
id: uuid.UUID | None = None
|
||||
name: str | None = None
|
||||
name__like: str | None = None
|
||||
name__contains: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_expressions(self) -> Self:
|
||||
"""Validate mutually exclusive expression."""
|
||||
name_filter = False
|
||||
if self.name__like or self.name__contains:
|
||||
name_filter = True
|
||||
if self.name__like and self.name__contains:
|
||||
raise ValueError("You may only specify one name expression")
|
||||
if self.name and name_filter:
|
||||
raise ValueError(
|
||||
"You must either specify name or one of name__like or name__contains"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def filter_client_statement(
|
||||
statement: Select[Any], params: ClientListParams, ignore_limits: bool = False
|
||||
) -> Select[Any]:
|
||||
"""Filter a statement with the provided params."""
|
||||
if params.id:
|
||||
statement = statement.where(Client.id == params.id)
|
||||
|
||||
if params.name:
|
||||
statement = statement.where(Client.name == params.name)
|
||||
elif params.name__like:
|
||||
statement = statement.where(Client.name.like(params.name__like))
|
||||
elif params.name__contains:
|
||||
statement = statement.where(Client.name.contains(params.name__contains))
|
||||
|
||||
if ignore_limits:
|
||||
return statement
|
||||
|
||||
return statement.limit(params.limit).offset(params.offset)
|
||||
|
||||
|
||||
def get_clients_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
||||
"""Construct clients sub-api."""
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/clients/")
|
||||
async def get_clients(
|
||||
filter_query: Annotated[ClientListParams, Query()],
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientQueryResult:
|
||||
"""Get clients."""
|
||||
# Get total results first
|
||||
count_statement = select(func.count("*")).select_from(Client)
|
||||
count_statement = cast(Select[tuple[int]], filter_client_statement(count_statement, filter_query, True))
|
||||
|
||||
total_results = (await session.scalars(count_statement)).one()
|
||||
|
||||
statement = filter_client_statement(client_with_relationships(), filter_query, False)
|
||||
|
||||
results = await session.scalars(statement)
|
||||
remainder = total_results - filter_query.offset - filter_query.limit
|
||||
if remainder < 0:
|
||||
remainder = 0
|
||||
|
||||
clients = list(results.all())
|
||||
clients_view = ClientView.from_client_list(clients)
|
||||
return ClientQueryResult(
|
||||
clients=clients_view,
|
||||
total_results=total_results,
|
||||
remaining_results=remainder,
|
||||
)
|
||||
|
||||
@router.get("/clients/{name}")
|
||||
async def get_client(
|
||||
name: str,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientView:
|
||||
"""Fetch a client."""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
return ClientView.from_client(client)
|
||||
|
||||
@router.delete("/clients/{name}")
|
||||
async def delete_client(
|
||||
request: Request,
|
||||
name: str,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> None:
|
||||
"""Delete a client."""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
|
||||
await session.delete(client)
|
||||
await session.commit()
|
||||
await audit.audit_delete_client(session, request, client)
|
||||
|
||||
@router.post("/clients/")
|
||||
async def create_client(
|
||||
request: Request,
|
||||
client: ClientCreate,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientView:
|
||||
"""Create client."""
|
||||
existing = await get_client_by_id_or_name(session, client.name)
|
||||
if existing:
|
||||
raise HTTPException(400, detail="Error: Already a client with that name.")
|
||||
|
||||
db_client = client.to_client()
|
||||
session.add(db_client)
|
||||
await session.commit()
|
||||
await session.refresh(db_client)
|
||||
db_client = await get_client_by_id_or_name(session, client.name)
|
||||
if not db_client:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Could not create the client.")
|
||||
await audit.audit_create_client(session, request, db_client)
|
||||
return ClientView.from_client(db_client)
|
||||
|
||||
@router.post("/clients/{name}/public-key")
|
||||
async def update_client_public_key(
|
||||
request: Request,
|
||||
name: str,
|
||||
client_update: ClientUpdate,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientView:
|
||||
"""Change the public key of a client.
|
||||
|
||||
This invalidates all secrets.
|
||||
"""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
client.public_key = client_update.public_key
|
||||
matching_secrets = await session.scalars(select(ClientSecret).where(ClientSecret.client_id == client.id))
|
||||
for secret in matching_secrets.all():
|
||||
LOG.debug("Invalidated secret %s", secret.id)
|
||||
secret.invalidated = True
|
||||
secret.client_id = None
|
||||
|
||||
session.add(client)
|
||||
await session.refresh(client)
|
||||
await session.commit()
|
||||
await audit.audit_invalidate_secrets(session, request, client)
|
||||
|
||||
return ClientView.from_client(client)
|
||||
|
||||
@router.put("/clients/{name}")
|
||||
async def update_client(
|
||||
request: Request,
|
||||
name: str,
|
||||
client_update: ClientCreate,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientView:
|
||||
"""Change the public key of a client.
|
||||
|
||||
This invalidates all secrets.
|
||||
"""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
client.name = client_update.name
|
||||
client.description = client_update.description
|
||||
public_key_updated = False
|
||||
if client_update.public_key != client.public_key:
|
||||
public_key_updated = True
|
||||
client_secrets = await session.scalars(
|
||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||
)
|
||||
for secret in client_secrets.all():
|
||||
LOG.debug("Invalidated secret %s", secret.id)
|
||||
secret.invalidated = True
|
||||
secret.client_id = None
|
||||
|
||||
session.add(client)
|
||||
await session.commit()
|
||||
await session.refresh(client)
|
||||
await audit.audit_update_client(session, request, client)
|
||||
if public_key_updated:
|
||||
await audit.audit_invalidate_secrets(session, request, client)
|
||||
|
||||
return ClientView.from_client(client)
|
||||
|
||||
return router
|
||||
@ -3,15 +3,27 @@
|
||||
import re
|
||||
import uuid
|
||||
import bcrypt
|
||||
from dataclasses import dataclass, field
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from sshecret_backend.models import Client
|
||||
from sshecret_backend.models import Client, ClientAccessPolicy
|
||||
|
||||
RE_UUID = re.compile(
|
||||
"^[0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}$"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class NewClientVersion:
|
||||
"""New client version dataclass."""
|
||||
|
||||
client: Client
|
||||
policies: list[ClientAccessPolicy] = field(default_factory=list)
|
||||
|
||||
RE_UUID = re.compile("^[0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}$")
|
||||
|
||||
def verify_token(token: str, stored_hash: str) -> bool:
|
||||
"""Verify token."""
|
||||
@ -19,12 +31,19 @@ def verify_token(token: str, stored_hash: str) -> bool:
|
||||
stored_bytes = stored_hash.encode("utf-8")
|
||||
return bcrypt.checkpw(token_bytes, stored_bytes)
|
||||
|
||||
async def reload_client_with_relationships(session: AsyncSession, client: Client) -> Client:
|
||||
|
||||
async def reload_client_with_relationships(
|
||||
session: AsyncSession, client: Client
|
||||
) -> Client:
|
||||
"""Reload a client from the database."""
|
||||
session.expunge(client)
|
||||
stmt = (
|
||||
select(Client)
|
||||
.options(selectinload(Client.policies), selectinload(Client.secrets))
|
||||
.options(
|
||||
selectinload(Client.policies),
|
||||
selectinload(Client.secrets),
|
||||
selectinload(Client.previous_version),
|
||||
)
|
||||
.where(Client.id == client.id)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
@ -36,13 +55,26 @@ def client_with_relationships() -> Select[tuple[Client]]:
|
||||
return select(Client).options(
|
||||
selectinload(Client.secrets),
|
||||
selectinload(Client.policies),
|
||||
selectinload(Client.previous_version),
|
||||
)
|
||||
|
||||
async def get_client_by_name(session: AsyncSession, name: str) -> Client | None:
|
||||
"""Get client by name."""
|
||||
|
||||
async def resolve_client_id(
|
||||
session: AsyncSession, name: str, version: int | None = None, include_deleted: bool = False,
|
||||
) -> uuid.UUID | None:
|
||||
"""Get the ID of a client name."""
|
||||
if include_deleted:
|
||||
client_filter = client_with_relationships().where(Client.name == name)
|
||||
client_results = await session.execute(client_filter)
|
||||
return client_results.scalars().first()
|
||||
else:
|
||||
client_filter = query_active_clients().where(Client.name == name)
|
||||
if version:
|
||||
client_filter = client_filter.where(Client.version == version)
|
||||
|
||||
client_result = await session.execute(client_filter)
|
||||
if client := client_result.scalars().first():
|
||||
return client.id
|
||||
return None
|
||||
|
||||
|
||||
async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | None:
|
||||
"""Get client by ID."""
|
||||
@ -50,10 +82,75 @@ async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | Non
|
||||
client_results = await session.execute(client_filter)
|
||||
return client_results.scalars().first()
|
||||
|
||||
async def get_client_by_id_or_name(session: AsyncSession, id_or_name: str) -> Client | None:
|
||||
|
||||
async def get_client_by_id_or_name(
|
||||
session: AsyncSession, id_or_name: str
|
||||
) -> Client | None:
|
||||
"""Get client either by id or name."""
|
||||
if RE_UUID.match(id_or_name):
|
||||
id = uuid.UUID(id_or_name)
|
||||
return await get_client_by_id(session, id)
|
||||
|
||||
return await get_client_by_name(session, id_or_name)
|
||||
|
||||
|
||||
def query_active_clients() -> Select[tuple[Client]]:
|
||||
"""Get all active clients."""
|
||||
client_filter = (
|
||||
client_with_relationships()
|
||||
.where(Client.is_active.is_(True))
|
||||
.where(Client.is_deleted.is_(False))
|
||||
)
|
||||
return client_filter
|
||||
|
||||
|
||||
async def get_client_by_name(session: AsyncSession, name: str) -> Client | None:
|
||||
"""Get client by name.
|
||||
|
||||
This will get the latest client version, unless it's deleted.
|
||||
"""
|
||||
client_filter = (
|
||||
client_with_relationships()
|
||||
.where(Client.is_active.is_(True))
|
||||
.where(Client.is_deleted.is_not(True))
|
||||
.where(Client.name == name)
|
||||
.order_by(Client.version.desc())
|
||||
)
|
||||
client_result = await session.execute(client_filter)
|
||||
return client_result.scalars().first()
|
||||
|
||||
|
||||
async def refresh_client(session: AsyncSession, client: Client) -> None:
|
||||
"""Refresh the client and load in all relationships."""
|
||||
await session.refresh(
|
||||
client, attribute_names=["secrets", "policies", "previous_version", "updated_at"]
|
||||
)
|
||||
|
||||
|
||||
async def create_new_client_version(
|
||||
session: AsyncSession, current_client: Client, new_public_key: str
|
||||
) -> Client:
|
||||
new_client = Client(
|
||||
name=current_client.name,
|
||||
version=current_client.version + 1,
|
||||
description=current_client.description,
|
||||
public_key=new_public_key,
|
||||
previous_version_id=current_client.id,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
# Mark current client as inactive
|
||||
current_client.is_active = False
|
||||
|
||||
# Copy policies
|
||||
for policy in current_client.policies:
|
||||
copied_policy = ClientAccessPolicy(
|
||||
client=new_client,
|
||||
address=policy.source,
|
||||
)
|
||||
session.add(copied_policy)
|
||||
|
||||
session.add(new_client)
|
||||
await session.flush()
|
||||
await refresh_client(session, new_client)
|
||||
return new_client
|
||||
|
||||
@ -12,16 +12,13 @@ from fastapi import (
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
|
||||
|
||||
from .models import init_db_async
|
||||
from .backend_api import get_backend_api
|
||||
from .db import setup_database, get_async_engine
|
||||
from .db import get_async_engine
|
||||
|
||||
from .settings import BackendSettings
|
||||
from .types import AsyncDBSessionDep
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -9,7 +9,8 @@ from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sshecret_backend.db import DatabaseSessionManager
|
||||
from sshecret_backend.settings import BackendSettings
|
||||
from .api import get_audit_api, get_clients_api, get_policy_api, get_secrets_api
|
||||
from .api import get_audit_api, get_policy_api, get_secrets_api
|
||||
from .api.clients.router import create_client_router
|
||||
from .auth import verify_token
|
||||
from .models import (
|
||||
APIClient,
|
||||
@ -60,7 +61,7 @@ def get_backend_api(
|
||||
)
|
||||
|
||||
backend_api.include_router(get_audit_api(get_db_session))
|
||||
backend_api.include_router(get_clients_api(get_db_session))
|
||||
backend_api.include_router(create_client_router(get_db_session))
|
||||
backend_api.include_router(get_policy_api(get_db_session))
|
||||
backend_api.include_router(get_secrets_api(get_db_session))
|
||||
|
||||
|
||||
@ -51,13 +51,22 @@ class Client(Base):
|
||||
"""Clients."""
|
||||
|
||||
__tablename__: str = "client"
|
||||
__table_args__: tuple[sa.UniqueConstraint, ...] = (
|
||||
sa.UniqueConstraint("name", "version", name="uq_client_name_version"),
|
||||
)
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
name: Mapped[str] = mapped_column(sa.String, unique=True)
|
||||
version: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=1)
|
||||
|
||||
name: Mapped[str] = mapped_column(sa.String, nullable=False)
|
||||
|
||||
description: Mapped[str | None] = mapped_column(sa.String, nullable=True)
|
||||
public_key: Mapped[str] = mapped_column(sa.Text)
|
||||
public_key: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
|
||||
is_active: Mapped[bool] = mapped_column(sa.Boolean, default=True)
|
||||
is_deleted: Mapped[bool] = mapped_column(sa.Boolean, default=False)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||
@ -69,10 +78,26 @@ class Client(Base):
|
||||
onupdate=sa.func.now(),
|
||||
)
|
||||
|
||||
deleted_at: Mapped[datetime | None] = mapped_column(
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=True
|
||||
)
|
||||
|
||||
secrets: Mapped[list["ClientSecret"]] = relationship(
|
||||
back_populates="client", passive_deletes=True
|
||||
)
|
||||
|
||||
previous_version_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
sa.Uuid(as_uuid=True),
|
||||
sa.ForeignKey("client.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
previous_version: Mapped["Client | None"] = relationship(
|
||||
"Client",
|
||||
remote_side=[id],
|
||||
backref="versions"
|
||||
)
|
||||
|
||||
policies: Mapped[list["ClientAccessPolicy"]] = relationship(back_populates="client")
|
||||
|
||||
|
||||
|
||||
@ -2,82 +2,15 @@
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Self, Sequence, override
|
||||
from typing import Self, override
|
||||
from collections.abc import Sequence
|
||||
|
||||
from pydantic import AfterValidator, BaseModel, Field, IPvAnyAddress, IPvAnyNetwork
|
||||
from pydantic import BaseModel, Field, IPvAnyAddress, IPvAnyNetwork
|
||||
|
||||
from sshecret.crypto import public_key_validator
|
||||
|
||||
from . import models
|
||||
|
||||
|
||||
class ClientView(BaseModel):
|
||||
"""View for a single client."""
|
||||
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
description: str | None = None
|
||||
public_key: str
|
||||
policies: list[str] = ["0.0.0.0/0", "::/0"]
|
||||
secrets: list[str] = Field(default_factory=list)
|
||||
created_at: datetime | None
|
||||
updated_at: datetime | None = None
|
||||
|
||||
@classmethod
|
||||
def from_client_list(cls, clients: list[models.Client]) -> list[Self]:
|
||||
"""Generate a list of responses from a list of clients."""
|
||||
responses: list[Self] = [cls.from_client(client) for client in clients]
|
||||
return responses
|
||||
|
||||
@classmethod
|
||||
def from_client(cls, client: models.Client) -> Self:
|
||||
"""Instantiate from a client."""
|
||||
view = cls(
|
||||
id=client.id,
|
||||
name=client.name,
|
||||
description=client.description,
|
||||
public_key=client.public_key,
|
||||
created_at=client.created_at,
|
||||
updated_at=client.updated_at or None,
|
||||
)
|
||||
if client.secrets:
|
||||
view.secrets = [secret.name for secret in client.secrets]
|
||||
|
||||
if client.policies:
|
||||
view.policies = [policy.source for policy in client.policies]
|
||||
|
||||
return view
|
||||
|
||||
|
||||
class ClientQueryResult(BaseModel):
|
||||
"""Result class for queries towards the client list."""
|
||||
|
||||
clients: list[ClientView] = Field(default_factory=list)
|
||||
total_results: int
|
||||
remaining_results: int
|
||||
|
||||
|
||||
class ClientCreate(BaseModel):
|
||||
"""Model to create a client."""
|
||||
|
||||
name: str
|
||||
description: str | None = None
|
||||
public_key: Annotated[str, AfterValidator(public_key_validator)]
|
||||
|
||||
def to_client(self) -> models.Client:
|
||||
"""Instantiate a client."""
|
||||
return models.Client(
|
||||
name=self.name,
|
||||
public_key=self.public_key,
|
||||
description=self.description,
|
||||
)
|
||||
|
||||
|
||||
class ClientUpdate(BaseModel):
|
||||
"""Model to update the client public key."""
|
||||
|
||||
public_key: Annotated[str, AfterValidator(public_key_validator)]
|
||||
|
||||
|
||||
class BodyValue(BaseModel):
|
||||
"""A generic model with just a value parameter."""
|
||||
|
||||
@ -20,7 +20,7 @@ handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter("'%(asctime)s - %(levelname)s - %(message)s'")
|
||||
handler.setFormatter(formatter)
|
||||
LOG.addHandler(handler)
|
||||
LOG.setLevel(logging.DEBUG)
|
||||
#LOG.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
def make_test_key() -> str:
|
||||
@ -473,7 +473,7 @@ def test_operations_with_id(test_client: TestClient) -> None:
|
||||
data = resp.json()
|
||||
client = data["clients"][0]
|
||||
client_id = client["id"]
|
||||
resp = test_client.get(f"/api/v1/clients/{client_id}")
|
||||
resp = test_client.get(f"/api/v1/clients/by-id/{client_id}")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["name"] == "test"
|
||||
|
||||
Reference in New Issue
Block a user