Refactor to use async database model

This commit is contained in:
2025-05-19 09:15:48 +02:00
parent f10ae027e5
commit fc0c3fb950
11 changed files with 288 additions and 185 deletions

View File

@ -3,9 +3,11 @@
import re
import uuid
import bcrypt
from sqlalchemy import Select
from sqlalchemy.orm import selectinload
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.future import select
from sqlalchemy.ext.asyncio import AsyncSession
from sshecret_backend.models import Client
@ -17,20 +19,38 @@ def verify_token(token: str, stored_hash: str) -> bool:
stored_bytes = stored_hash.encode("utf-8")
return bcrypt.checkpw(token_bytes, stored_bytes)
async def reload_client_with_relationships(session: AsyncSession, client: Client) -> Client:
"""Reload a client from the database."""
session.expunge(client)
stmt = (
select(Client)
.options(selectinload(Client.policies), selectinload(Client.secrets))
.where(Client.id == client.id)
)
result = await session.execute(stmt)
return result.scalar_one()
async def get_client_by_name(session: Session, name: str) -> Client | None:
def client_with_relationships() -> Select[tuple[Client]]:
"""Base select statement for client with relationships."""
return select(Client).options(
selectinload(Client.secrets),
selectinload(Client.policies),
)
async def get_client_by_name(session: AsyncSession, name: str) -> Client | None:
"""Get client by name."""
client_filter = select(Client).where(Client.name == name)
client_results = session.scalars(client_filter)
return client_results.first()
client_filter = client_with_relationships().where(Client.name == name)
client_results = await session.execute(client_filter)
return client_results.scalars().first()
async def get_client_by_id(session: Session, id: uuid.UUID) -> Client | None:
"""Get client by name."""
client_filter = select(Client).where(Client.id == id)
client_results = session.scalars(client_filter)
return client_results.first()
async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | None:
"""Get client by ID."""
client_filter = client_with_relationships().where(Client.id == id)
client_results = await session.execute(client_filter)
return client_results.scalars().first()
async def get_client_by_id_or_name(session: Session, id_or_name: str) -> Client | None:
async def get_client_by_id_or_name(session: AsyncSession, id_or_name: str) -> Client | None:
"""Get client either by id or name."""
if RE_UUID.match(id_or_name):
id = uuid.UUID(id_or_name)