Refactor to use async database model
This commit is contained in:
@ -3,9 +3,11 @@
|
||||
import re
|
||||
import uuid
|
||||
import bcrypt
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from sshecret_backend.models import Client
|
||||
|
||||
@ -17,20 +19,38 @@ def verify_token(token: str, stored_hash: str) -> bool:
|
||||
stored_bytes = stored_hash.encode("utf-8")
|
||||
return bcrypt.checkpw(token_bytes, stored_bytes)
|
||||
|
||||
async def reload_client_with_relationships(session: AsyncSession, client: Client) -> Client:
|
||||
"""Reload a client from the database."""
|
||||
session.expunge(client)
|
||||
stmt = (
|
||||
select(Client)
|
||||
.options(selectinload(Client.policies), selectinload(Client.secrets))
|
||||
.where(Client.id == client.id)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one()
|
||||
|
||||
async def get_client_by_name(session: Session, name: str) -> Client | None:
|
||||
|
||||
def client_with_relationships() -> Select[tuple[Client]]:
|
||||
"""Base select statement for client with relationships."""
|
||||
return select(Client).options(
|
||||
selectinload(Client.secrets),
|
||||
selectinload(Client.policies),
|
||||
)
|
||||
|
||||
async def get_client_by_name(session: AsyncSession, name: str) -> Client | None:
|
||||
"""Get client by name."""
|
||||
client_filter = select(Client).where(Client.name == name)
|
||||
client_results = session.scalars(client_filter)
|
||||
return client_results.first()
|
||||
client_filter = client_with_relationships().where(Client.name == name)
|
||||
client_results = await session.execute(client_filter)
|
||||
return client_results.scalars().first()
|
||||
|
||||
async def get_client_by_id(session: Session, id: uuid.UUID) -> Client | None:
|
||||
"""Get client by name."""
|
||||
client_filter = select(Client).where(Client.id == id)
|
||||
client_results = session.scalars(client_filter)
|
||||
return client_results.first()
|
||||
async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | None:
|
||||
"""Get client by ID."""
|
||||
client_filter = client_with_relationships().where(Client.id == id)
|
||||
client_results = await session.execute(client_filter)
|
||||
return client_results.scalars().first()
|
||||
|
||||
async def get_client_by_id_or_name(session: Session, id_or_name: str) -> Client | None:
|
||||
async def get_client_by_id_or_name(session: AsyncSession, id_or_name: str) -> Client | None:
|
||||
"""Get client either by id or name."""
|
||||
if RE_UUID.match(id_or_name):
|
||||
id = uuid.UUID(id_or_name)
|
||||
|
||||
Reference in New Issue
Block a user