Files
sshecret/packages/sshecret-backend/src/sshecret_backend/api/common.py
2025-06-08 17:43:34 +02:00

198 lines
5.4 KiB
Python

"""Common helpers."""
import re
from typing import Self
import uuid
from dataclasses import dataclass, field
from enum import Enum
import bcrypt
from pydantic import BaseModel
from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from sshecret_backend.models import Client, ClientAccessPolicy
RE_UUID = re.compile(
"^[0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}$"
)
RelaxedId = uuid.UUID | str
class IdType(Enum):
    """Tells how the value of a flexible identifier should be interpreted."""

    # The value is an identifier (a UUID or its string form).
    ID = "id"
    # The value is a name.
    NAME = "name"
class FlexID(BaseModel):
    """Identifier that refers to an object either by ID or by name.

    ``type`` records which kind of reference ``value`` holds.
    """

    type: IdType
    value: RelaxedId

    @classmethod
    def id(cls, id: RelaxedId) -> Self:
        """Build a FlexID that refers to an object by its ID."""
        return cls(value=id, type=IdType.ID)

    @classmethod
    def name(cls, name: str) -> Self:
        """Build a FlexID that refers to an object by its name."""
        return cls(value=name, type=IdType.NAME)

    @classmethod
    def from_string(cls, value: str) -> Self:
        """Parse a path segment of the form ``id:<id>``, ``name:<name>`` or ``<name>``.

        Only the first ``:`` is significant. A string without a recognised
        ``id:``/``name:`` prefix is treated as a name in its entirety.
        """
        prefix, separator, remainder = value.partition(":")
        if separator and prefix == "id":
            return cls.id(remainder)
        if separator and prefix == "name":
            return cls.name(remainder)
        return cls.name(value)
@dataclass
class NewClientVersion:
    """Bundle of a new client version together with its access policies."""

    client: Client
    # default_factory gives each instance its own (initially empty) list.
    policies: list[ClientAccessPolicy] = field(default_factory=list)
def verify_token(token: str, stored_hash: str) -> bool:
    """Check a plaintext token against a stored bcrypt hash.

    Both strings are UTF-8 encoded and handed to ``bcrypt.checkpw``.

    Returns:
        True if the token matches the stored hash, False otherwise.
    """
    return bcrypt.checkpw(token.encode("utf-8"), stored_hash.encode("utf-8"))
async def reload_client_with_relationships(
    session: AsyncSession, client: Client
) -> Client:
    """Reload a client from the database with its relationships eagerly loaded.

    The given instance is expunged from the session before the SELECT runs.
    Presumably this is so the query builds a fresh instance instead of handing
    back the cached identity-map object with its existing state.

    Args:
        session: Session used for the reload.
        client: The client to reload. It is detached from ``session``.

    Returns:
        A newly loaded ``Client`` with ``secrets``, ``policies`` and
        ``previous_version`` loaded.

    Raises:
        sqlalchemy.exc.NoResultFound: If no client with that ID exists.
    """
    # Read the primary key while the instance is still attached to the session.
    client_id = client.id
    session.expunge(client)
    # Reuse the shared eager-load statement so the set of loaded relationships
    # stays in sync with the other client queries in this module.
    stmt = client_with_relationships().where(Client.id == client_id)
    result = await session.execute(stmt)
    return result.scalar_one()
def client_with_relationships() -> Select[tuple[Client]]:
    """Return a ``SELECT Client`` statement with relationships eager-loaded.

    ``secrets``, ``policies`` and ``previous_version`` are each loaded with
    ``selectinload``, so callers can use them without lazy loading.
    """
    eager_relationships = (Client.secrets, Client.policies, Client.previous_version)
    return select(Client).options(
        *(selectinload(relationship) for relationship in eager_relationships)
    )
async def resolve_client_id(
    session: AsyncSession,
    name: str,
    version: int | None = None,
    include_deleted: bool = False,
) -> uuid.UUID | None:
    """Look up the ID of a client by name.

    Args:
        session: Database session.
        name: Client name to resolve.
        version: If given, only that exact version is considered. Otherwise
            the highest matching version wins.
        include_deleted: When False (the default), only active, non-deleted
            clients are considered. When True, every row with the name is
            considered, including inactive and deleted versions.

    Returns:
        The matching client's ID, or None if nothing matches.
    """
    if include_deleted:
        client_filter = client_with_relationships().where(Client.name == name)
    else:
        client_filter = query_active_clients().where(Client.name == name)
    # Compare against None so that an explicit version=0 still filters.
    if version is not None:
        client_filter = client_filter.where(Client.version == version)
    # Several rows can match (e.g. all versions when include_deleted is set).
    # Order them so the result is deterministic and prefers the newest version,
    # as get_client_by_name does.
    client_filter = client_filter.order_by(Client.version.desc())
    client_result = await session.execute(client_filter)
    if client := client_result.scalars().first():
        return client.id
    return None
async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | None:
    """Fetch a client by primary key with its relationships eager-loaded.

    Returns:
        The client, or None if no client has this ID.
    """
    result = await session.execute(client_with_relationships().where(Client.id == id))
    return result.scalars().first()
async def get_client_by_id_or_name(
    session: AsyncSession, id_or_name: str
) -> Client | None:
    """Get a client either by ID or by name.

    If the whole string is a UUID (in any letter case), it is treated as a
    client ID. Anything else is looked up as a client name.

    Returns:
        The client, or None if no match is found.
    """
    # fullmatch requires the entire string to be a UUID. A "$" anchor alone
    # would still accept a trailing newline, which uuid.UUID() then rejects
    # with ValueError. Lower-casing lets upper-case UUIDs, which uuid.UUID()
    # accepts, resolve as IDs too.
    if RE_UUID.fullmatch(id_or_name.lower()):
        return await get_client_by_id(session, uuid.UUID(id_or_name))
    return await get_client_by_name(session, id_or_name)
def query_active_clients() -> Select[tuple[Client]]:
    """Return the client SELECT restricted to active, non-deleted clients.

    Builds on ``client_with_relationships()``, so relationships are
    eager-loaded as well.
    """
    return client_with_relationships().where(
        Client.is_active.is_(True),
        Client.is_deleted.is_(False),
    )
async def get_client_by_name(session: AsyncSession, name: str) -> Client | None:
    """Get the active client with the given name.

    Only active clients that are not flagged as deleted are considered. If
    several rows match, the one with the highest version is returned.

    Returns:
        The client, or None if no active client has this name.
    """
    # NOTE(review): this filters on "is_deleted IS NOT true", which also admits
    # NULL, whereas query_active_clients uses "IS false". Confirm whether the
    # column can be NULL and whether the two filters should agree.
    stmt = (
        client_with_relationships()
        .where(
            Client.is_active.is_(True),
            Client.is_deleted.is_not(True),
            Client.name == name,
        )
        .order_by(Client.version.desc())
    )
    result = await session.execute(stmt)
    return result.scalars().first()
async def refresh_client(session: AsyncSession, client: Client) -> None:
    """Re-read selected attributes of ``client`` from the database in place.

    Reloads the ``secrets``, ``policies`` and ``previous_version``
    relationships and the ``updated_at`` attribute.
    """
    attributes_to_reload = ["secrets", "policies", "previous_version", "updated_at"]
    await session.refresh(client, attribute_names=attributes_to_reload)
async def create_new_client_version(
    session: AsyncSession, current_client: Client, new_public_key: str
) -> Client:
    """Create the next version of a client with a new public key.

    The new version copies ``name`` and ``description`` from
    ``current_client``, increments ``version``, links back through
    ``previous_version_id`` and is marked active. ``current_client`` is marked
    inactive, and its access policies are copied onto the new version.

    The session is flushed but not committed. Committing is left to the caller.

    Args:
        session: Session the new rows are added to.
        current_client: The client version being superseded. NOTE(review):
            its ``policies`` presumably must already be loaded, because
            implicit lazy loading is not supported under asyncio. Confirm that
            callers always eager-load them.
        new_public_key: Public key for the new version.

    Returns:
        The new client, refreshed with its relationships loaded.
    """
    new_client = Client(
        name=current_client.name,
        version=current_client.version + 1,
        description=current_client.description,
        public_key=new_public_key,
        previous_version_id=current_client.id,
        is_active=True,
    )
    # The superseded version is no longer active.
    current_client.is_active = False
    # Copy each access policy onto the new version. The copy is populated via
    # the same ``source`` attribute that is read from the existing policy.
    for policy in current_client.policies:
        session.add(ClientAccessPolicy(client=new_client, source=policy.source))
    session.add(new_client)
    await session.flush()
    await refresh_client(session, new_client)
    return new_client