"""Common helpers."""

import re
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Self

import bcrypt
from pydantic import BaseModel
from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload

from sshecret_backend.models import Client, ClientAccessPolicy

# Lower-case, hyphenated UUID string. Only version digits 0-5 and variant
# digits 0/8/9/a/b are accepted; upper-case hex is not matched.
RE_UUID = re.compile(
    "^[0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}$"
)

# An identifier that may be given either as a UUID object or as a string.
RelaxedId = uuid.UUID | str


class IdType(Enum):
    """How the value of a ``FlexID`` should be interpreted."""

    ID = "id"
    NAME = "name"


class FlexID(BaseModel):
    """Flexible identifier: either an ID or a name, tagged by ``type``."""

    type: IdType
    value: RelaxedId

    @classmethod
    def id(cls, id: RelaxedId) -> Self:
        """Construct an ID-typed identifier."""
        return cls(type=IdType.ID, value=id)

    @classmethod
    def name(cls, name: str) -> Self:
        """Construct a name-typed identifier."""
        return cls(type=IdType.NAME, value=name)

    @classmethod
    def from_string(cls, value: str) -> Self:
        """Parse an identifier from a path string.

        ``"id:<value>"`` produces an ID identifier, ``"name:<value>"`` a name
        identifier, and any other string is used verbatim as a name.
        """
        if value.startswith("id:"):
            return cls.id(value[3:])
        if value.startswith("name:"):
            return cls.name(value[5:])
        return cls.name(value)


@dataclass
class NewClientVersion:
    """A newly created client version together with its access policies."""

    client: Client
    policies: list[ClientAccessPolicy] = field(default_factory=list)


def verify_token(token: str, stored_hash: str) -> bool:
    """Check a plaintext token against a stored bcrypt hash.

    Returns:
        True if the token matches the hash. False if it does not match, or if
        bcrypt rejects the input (e.g. a malformed stored hash). The latter
        fails closed instead of surfacing as an unhandled error.
    """
    token_bytes = token.encode("utf-8")
    stored_bytes = stored_hash.encode("utf-8")
    try:
        return bcrypt.checkpw(token_bytes, stored_bytes)
    except ValueError:
        return False


async def reload_client_with_relationships(
    session: AsyncSession, client: Client
) -> Client:
    """Reload a client from the database with its relationships eager-loaded.

    The given instance is expunged from the session. The query therefore
    builds a fresh object instead of returning the cached identity.

    Raises:
        sqlalchemy.exc.NoResultFound: if no client with that ID exists.
    """
    # Read the primary key while the instance is still attached: accessing an
    # expired attribute on a detached instance raises DetachedInstanceError.
    client_id = client.id
    session.expunge(client)
    stmt = client_with_relationships().where(Client.id == client_id)
    result = await session.execute(stmt)
    return result.scalar_one()


def client_with_relationships() -> Select[tuple[Client]]:
    """Base select statement for clients with secrets, policies and previous version."""
    return select(Client).options(
        selectinload(Client.secrets),
        selectinload(Client.policies),
        selectinload(Client.previous_version),
    )


async def resolve_client_id(
    session: AsyncSession,
    name: str,
    version: int | None = None,
    include_deleted: bool = False,
) -> uuid.UUID | None:
    """Resolve a client name (and optionally a version) to a client ID.

    Args:
        session: Database session.
        name: Client name to look up.
        version: Exact version to match; ``None`` matches any version.
        include_deleted: If True, inactive and deleted clients are also
            considered; otherwise only active, non-deleted clients are.

    Returns:
        The ID of the highest matching version, or None if nothing matches.
    """
    if include_deleted:
        client_filter = client_with_relationships()
    else:
        client_filter = query_active_clients()
    client_filter = client_filter.where(Client.name == name)
    # Compare against None so an explicit version (even 0) is never ignored.
    if version is not None:
        client_filter = client_filter.where(Client.version == version)
    # Make the pick deterministic when several versions match.
    client_filter = client_filter.order_by(Client.version.desc())
    client_result = await session.execute(client_filter)
    if client := client_result.scalars().first():
        return client.id
    return None


async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | None:
    """Get a client by ID, regardless of active/deleted state."""
    client_filter = client_with_relationships().where(Client.id == id)
    client_results = await session.execute(client_filter)
    return client_results.scalars().first()


async def get_client_by_id_or_name(
    session: AsyncSession, id_or_name: str
) -> Client | None:
    """Get a client either by ID or by name.

    The string is treated as an ID only if it matches ``RE_UUID``; otherwise
    it is looked up as a name.
    """
    if RE_UUID.match(id_or_name):
        id = uuid.UUID(id_or_name)
        return await get_client_by_id(session, id)
    return await get_client_by_name(session, id_or_name)


def query_active_clients() -> Select[tuple[Client]]:
    """Select statement for active, non-deleted clients with relationships."""
    client_filter = (
        client_with_relationships()
        .where(Client.is_active.is_(True))
        .where(Client.is_deleted.is_(False))
    )
    return client_filter


async def get_client_by_name(session: AsyncSession, name: str) -> Client | None:
    """Get client by name.

    This will get the latest active client version, unless it's deleted.
    """
    client_filter = (
        client_with_relationships()
        .where(Client.is_active.is_(True))
        # NOTE: ``IS NOT true`` also matches rows where is_deleted is NULL,
        # unlike the ``IS false`` used by query_active_clients.
        .where(Client.is_deleted.is_not(True))
        .where(Client.name == name)
        .order_by(Client.version.desc())
    )
    client_result = await session.execute(client_filter)
    return client_result.scalars().first()


async def refresh_client(session: AsyncSession, client: Client) -> None:
    """Refresh the client and load in all relationships and ``updated_at``."""
    await session.refresh(
        client,
        attribute_names=["secrets", "policies", "previous_version", "updated_at"],
    )


async def create_new_client_version(
    session: AsyncSession, current_client: Client, new_public_key: str
) -> Client:
    """Create the next version of a client with a new public key.

    The current version is marked inactive and its access policies are copied
    to the new version. Secrets are not copied (presumably because they must
    be re-encrypted for the new key; confirm with callers). The new client is
    flushed and refreshed but the transaction is not committed here.

    Args:
        session: Database session.
        current_client: The client version being superseded. NOTE(review):
            ``policies`` is iterated here, so it presumably must already be
            loaded (e.g. via ``client_with_relationships``).
        new_public_key: Public key for the new version.

    Returns:
        The new, active client version.
    """
    new_client = Client(
        name=current_client.name,
        version=current_client.version + 1,
        description=current_client.description,
        public_key=new_public_key,
        previous_version_id=current_client.id,
        is_active=True,
    )

    # Mark current client as inactive
    current_client.is_active = False

    # Copy policies, keeping the same field on both sides of the copy.
    for policy in current_client.policies:
        copied_policy = ClientAccessPolicy(
            client=new_client,
            source=policy.source,
        )
        session.add(copied_policy)

    session.add(new_client)
    await session.flush()
    await refresh_client(session, new_client)
    return new_client