198 lines
5.4 KiB
Python
198 lines
5.4 KiB
Python
"""Common helpers."""
|
|
|
|
import re
|
|
from typing import Self
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
|
|
import bcrypt
|
|
from pydantic import BaseModel
|
|
from sqlalchemy import Select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.future import select
|
|
from sqlalchemy.orm import selectinload
|
|
from sshecret_backend.models import Client, ClientAccessPolicy
|
|
|
|
RE_UUID = re.compile(
|
|
"^[0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}$"
|
|
)
|
|
|
|
RelaxedId = uuid.UUID | str
|
|
|
|
|
|
class IdType(Enum):
    """Tells how the value of a ``FlexID`` is to be interpreted."""

    ID = "id"  # the value is an identifier
    NAME = "name"  # the value is a name
|
|
|
|
|
|
class FlexID(BaseModel):
    """Reference to an entity either by identifier or by name.

    ``type`` selects how ``value`` is to be interpreted.
    """

    type: IdType
    value: RelaxedId

    @classmethod
    def id(cls, id: RelaxedId) -> Self:
        """Build a reference that points at an entity by its ID."""
        return cls(type=IdType.ID, value=id)

    @classmethod
    def name(cls, name: str) -> Self:
        """Build a reference that points at an entity by its name."""
        return cls(type=IdType.NAME, value=name)

    @classmethod
    def from_string(cls, value: str) -> Self:
        """Parse an identifier taken from a URL path segment.

        ``"id:<x>"`` becomes an ID reference to ``<x>``. ``"name:<x>"`` becomes
        a name reference to ``<x>``. Any other string is taken verbatim as a name.
        """
        if value.startswith("id:"):
            return cls.id(value.removeprefix("id:"))
        # The "name:" prefix is optional; removeprefix leaves unprefixed
        # strings untouched.
        return cls.name(value.removeprefix("name:"))
|
|
|
|
|
|
@dataclass
class NewClientVersion:
    """Container for a new client version and the access policies that go with it."""

    # The client row for the new version.
    client: Client
    # Access policies attached to ``client``; each instance gets its own list.
    policies: list[ClientAccessPolicy] = field(default_factory=list)
|
|
|
|
|
|
def verify_token(token: str, stored_hash: str) -> bool:
    """Check a plaintext token against a stored bcrypt hash.

    Both arguments are UTF-8 encoded before being handed to ``bcrypt.checkpw``.

    Args:
        token: The plaintext token presented by the caller.
        stored_hash: The bcrypt hash string previously stored for the token.

    Returns:
        Whatever ``bcrypt.checkpw`` reports: True when the token matches.
    """
    return bcrypt.checkpw(token.encode("utf-8"), stored_hash.encode("utf-8"))
|
|
|
|
|
|
async def reload_client_with_relationships(
    session: AsyncSession, client: Client
) -> Client:
    """Reload a client from the database with its relationships eager-loaded.

    The given instance is expunged from ``session`` before querying. The
    SELECT therefore does not hand back the existing identity-map object, and
    presumably the intent is a fresh instance whose relationships reflect the
    database.

    Args:
        session: The async session used for the query.
        client: The client to reload. As a side effect it is detached from
            ``session``.

    Returns:
        A ``Client`` with ``secrets``, ``policies`` and ``previous_version``
        loaded via ``client_with_relationships``.

    Raises:
        sqlalchemy.exc.NoResultFound: No row has the client's ID
            (raised by ``scalar_one``).
    """
    # Read the primary key while the instance is still attached to the
    # session. This avoids touching attributes on a detached instance.
    client_id = client.id
    session.expunge(client)
    # Reuse the shared eager-loading select instead of duplicating its options.
    stmt = client_with_relationships().where(Client.id == client_id)
    result = await session.execute(stmt)
    return result.scalar_one()
|
|
|
|
|
|
def client_with_relationships() -> Select[tuple[Client]]:
    """Return a ``SELECT Client`` statement that eager-loads relationships.

    ``secrets``, ``policies`` and ``previous_version`` are loaded with
    ``selectinload``. No filtering is applied; callers chain their own
    ``where``/``order_by`` clauses.
    """
    eager_loaders = (
        selectinload(Client.secrets),
        selectinload(Client.policies),
        selectinload(Client.previous_version),
    )
    return select(Client).options(*eager_loaders)
|
|
|
|
|
|
async def resolve_client_id(
    session: AsyncSession,
    name: str,
    version: int | None = None,
    include_deleted: bool = False,
) -> uuid.UUID | None:
    """Resolve a client name (and optionally a version) to the client's ID.

    Args:
        session: The async session used for the query.
        name: The client name to look up.
        version: If not ``None``, only this exact version matches.
        include_deleted: If True, search all clients with no active/deleted
            filtering. Otherwise only rows matched by
            ``query_active_clients()`` are considered.

    Returns:
        The ID of the highest matching version, or ``None`` if nothing matches.
    """
    if include_deleted:
        client_filter = client_with_relationships()
    else:
        client_filter = query_active_clients()
    client_filter = client_filter.where(Client.name == name)

    # Compare against None so an explicit version 0 is still applied as a
    # filter instead of being silently ignored.
    if version is not None:
        client_filter = client_filter.where(Client.version == version)

    # Several versions can match (e.g. include_deleted=True without a
    # version). Without an ORDER BY, first() would pick an arbitrary row, so
    # prefer the newest, as get_client_by_name does.
    client_filter = client_filter.order_by(Client.version.desc())

    client_result = await session.execute(client_filter)
    client = client_result.scalars().first()
    return client.id if client is not None else None
|
|
|
|
|
|
async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | None:
    """Fetch the client whose primary key is ``id``, relationships included.

    Returns ``None`` when no row has that ID. No active/deleted filtering is
    applied.
    """
    result = await session.execute(
        client_with_relationships().where(Client.id == id)
    )
    return result.scalars().first()
|
|
|
|
|
|
async def get_client_by_id_or_name(
    session: AsyncSession, id_or_name: str
) -> Client | None:
    """Get a client either by UUID or by name.

    If ``id_or_name`` is a canonical hyphenated UUID accepted by ``RE_UUID``
    (in any letter case), the lookup is by ID via ``get_client_by_id``.
    Otherwise the string is treated as a name and resolved with
    ``get_client_by_name``.

    Returns:
        The matching client, or ``None``. A UUID that matches no client does
        not fall back to a name lookup.
    """
    # UUID text is case-insensitive (RFC 4122), but RE_UUID only accepts
    # lowercase hex, so normalise case before matching. fullmatch() also
    # rejects a trailing newline, which "$" would otherwise let through, and
    # uuid.UUID() would then raise ValueError on it.
    # NOTE(review): RE_UUID also limits the version nibble to 0-5, so e.g.
    # UUIDv7 strings are treated as names.
    if RE_UUID.fullmatch(id_or_name.lower()):
        return await get_client_by_id(session, uuid.UUID(id_or_name))
    return await get_client_by_name(session, id_or_name)
|
|
|
|
|
|
def query_active_clients() -> Select[tuple[Client]]:
    """Return a client select restricted to active, non-deleted rows.

    Builds on ``client_with_relationships()``, so relationships are
    eager-loaded.
    """
    # NOTE(review): ``is_(False)`` excludes rows where is_deleted is NULL,
    # while get_client_by_name uses ``is_not(True)``, which admits them.
    # Confirm which is intended.
    return client_with_relationships().where(
        Client.is_active.is_(True),
        Client.is_deleted.is_(False),
    )
|
|
|
|
|
|
async def get_client_by_name(session: AsyncSession, name: str) -> Client | None:
    """Get the newest active, non-deleted client with the given name.

    Candidate rows are ordered by ``version`` descending and the first one is
    returned, with relationships eager-loaded. Returns ``None`` if none match.
    """
    # NOTE(review): ``is_not(True)`` also matches rows where is_deleted is
    # NULL, unlike query_active_clients' ``is_(False)``.
    newest_first = (
        client_with_relationships()
        .where(
            Client.is_active.is_(True),
            Client.is_deleted.is_not(True),
            Client.name == name,
        )
        .order_by(Client.version.desc())
    )
    rows = await session.execute(newest_first)
    return rows.scalars().first()
|
|
|
|
|
|
async def refresh_client(session: AsyncSession, client: Client) -> None:
    """Reload selected attributes of ``client`` from the database in place.

    The attributes reloaded are the ``secrets``, ``policies`` and
    ``previous_version`` relationships plus the ``updated_at`` column.
    """
    to_reload = ["secrets", "policies", "previous_version", "updated_at"]
    await session.refresh(client, attribute_names=to_reload)
|
|
|
|
|
|
async def create_new_client_version(
    session: AsyncSession, current_client: Client, new_public_key: str
) -> Client:
    """Create the next version of a client with a new public key.

    The new client takes ``name`` and ``description`` from ``current_client``,
    gets ``version + 1``, links back through ``previous_version_id``, and is
    marked active. ``current_client`` is marked inactive. One new
    ``ClientAccessPolicy`` is created for the new client per policy on
    ``current_client``. Secrets are not copied by this function.

    The session is flushed (not committed), then the new client's
    relationships are reloaded via ``refresh_client``.

    Args:
        session: The async session the new rows are added to.
        current_client: The version being superseded. It is mutated in place.
        new_public_key: Public key stored on the new version.

    Returns:
        The newly created, flushed and refreshed ``Client``.
    """
    successor = Client(
        name=current_client.name,
        version=current_client.version + 1,
        description=current_client.description,
        public_key=new_public_key,
        previous_version_id=current_client.id,
        is_active=True,
    )

    # Retire the predecessor; query_active_clients filters on is_active.
    current_client.is_active = False

    # NOTE(review): the copy reads ``policy.source`` but passes it as
    # ``address=``. Confirm both names exist on ClientAccessPolicy.
    # Iterating ``current_client.policies`` presumably requires that
    # relationship to be loaded already (async session); confirm at call sites.
    copied_policies = [
        ClientAccessPolicy(client=successor, address=policy.source)
        for policy in current_client.policies
    ]
    session.add_all(copied_policies)
    session.add(successor)

    await session.flush()
    await refresh_client(session, successor)
    return successor
|