Refactor to use async database model
This commit is contained in:
@ -3,18 +3,18 @@
|
||||
# pyright: reportUnusedFunction=false
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, cast
|
||||
from fastapi import APIRouter, Depends, Request, Query
|
||||
from typing import Any
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from sqlalchemy import select, func, and_
|
||||
from sqlalchemy.orm import InstrumentedAttribute, Session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import InstrumentedAttribute
|
||||
from sqlalchemy.sql.expression import ColumnExpressionArgument
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from sshecret_backend.models import AuditLog, Operation, SubSystem
|
||||
from sshecret_backend.types import DBSessionDep
|
||||
from sshecret_backend.types import AsyncDBSessionDep
|
||||
from sshecret_backend.view_models import AuditInfo, AuditView, AuditListResult
|
||||
|
||||
|
||||
@ -58,24 +58,23 @@ class AuditFilter(BaseModel):
|
||||
]
|
||||
|
||||
|
||||
def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
def get_audit_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
||||
"""Construct audit sub-api."""
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/audit/", response_model=AuditListResult)
|
||||
async def get_audit_logs(
|
||||
request: Request,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
filters: Annotated[AuditFilter, Depends()],
|
||||
) -> AuditListResult:
|
||||
"""Get audit logs."""
|
||||
# audit.audit_access_audit_log(session, request)
|
||||
|
||||
total = session.scalars(
|
||||
total = (await session.scalars(
|
||||
select(func.count("*"))
|
||||
.select_from(AuditLog)
|
||||
.where(and_(True, *filters.filter_mapping))
|
||||
).one()
|
||||
)).one()
|
||||
|
||||
remaining = total - filters.offset
|
||||
statement = (
|
||||
@ -87,7 +86,7 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
)
|
||||
|
||||
LogAdapt = TypeAdapter(list[AuditView])
|
||||
results = session.scalars(statement).all()
|
||||
results = (await session.scalars(statement)).all()
|
||||
entries = LogAdapt.validate_python(results, from_attributes=True)
|
||||
return AuditListResult(
|
||||
results=entries,
|
||||
@ -97,24 +96,23 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
|
||||
@router.post("/audit/")
|
||||
async def add_audit_log(
|
||||
request: Request,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
entry: AuditView,
|
||||
) -> AuditView:
|
||||
"""Add entry to audit log."""
|
||||
audit_log = AuditLog(**entry.model_dump(exclude_none=True))
|
||||
session.add(audit_log)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
return AuditView.model_validate(audit_log, from_attributes=True)
|
||||
|
||||
@router.get("/audit/info")
|
||||
async def get_audit_info(
|
||||
request: Request, session: Annotated[Session, Depends(get_db_session)]
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)]
|
||||
) -> AuditInfo:
|
||||
"""Get audit info."""
|
||||
audit_count = session.scalars(
|
||||
audit_count = (await session.scalars(
|
||||
select(func.count("*")).select_from(AuditLog)
|
||||
).one()
|
||||
)).one()
|
||||
return AuditInfo(entries=audit_count)
|
||||
|
||||
return router
|
||||
|
||||
@ -4,14 +4,14 @@
|
||||
|
||||
import uuid
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing import Annotated, Any, Self, TypeVar, cast
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.sql import Select
|
||||
from sshecret_backend.types import DBSessionDep
|
||||
from sshecret_backend.types import AsyncDBSessionDep
|
||||
from sshecret_backend.models import Client, ClientSecret
|
||||
from sshecret_backend.view_models import (
|
||||
ClientCreate,
|
||||
@ -20,7 +20,7 @@ from sshecret_backend.view_models import (
|
||||
ClientUpdate,
|
||||
)
|
||||
from sshecret_backend import audit
|
||||
from .common import get_client_by_id_or_name
|
||||
from .common import get_client_by_id_or_name, client_with_relationships
|
||||
|
||||
|
||||
class ClientListParams(BaseModel):
|
||||
@ -74,30 +74,30 @@ def filter_client_statement(
|
||||
return statement.limit(params.limit).offset(params.offset)
|
||||
|
||||
|
||||
def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
def get_clients_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
||||
"""Construct clients sub-api."""
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/clients/")
|
||||
async def get_clients(
|
||||
filter_query: Annotated[ClientListParams, Query()],
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientQueryResult:
|
||||
"""Get clients."""
|
||||
# Get total results first
|
||||
count_statement = select(func.count("*")).select_from(Client)
|
||||
count_statement = cast(Select[tuple[int]], filter_client_statement(count_statement, filter_query, True))
|
||||
|
||||
total_results = session.scalars(count_statement).one()
|
||||
total_results = (await session.scalars(count_statement)).one()
|
||||
|
||||
statement = filter_client_statement(select(Client), filter_query, False)
|
||||
statement = filter_client_statement(client_with_relationships(), filter_query, False)
|
||||
|
||||
results = session.scalars(statement)
|
||||
results = await session.scalars(statement)
|
||||
remainder = total_results - filter_query.offset - filter_query.limit
|
||||
if remainder < 0:
|
||||
remainder = 0
|
||||
|
||||
clients = list(results)
|
||||
clients = list(results.all())
|
||||
clients_view = ClientView.from_client_list(clients)
|
||||
return ClientQueryResult(
|
||||
clients=clients_view,
|
||||
@ -108,7 +108,7 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
@router.get("/clients/{name}")
|
||||
async def get_client(
|
||||
name: str,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientView:
|
||||
"""Fetch a client."""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
@ -122,7 +122,7 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
async def delete_client(
|
||||
request: Request,
|
||||
name: str,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> None:
|
||||
"""Delete a client."""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
@ -131,15 +131,15 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
|
||||
session.delete(client)
|
||||
session.commit()
|
||||
audit.audit_delete_client(session, request, client)
|
||||
await session.delete(client)
|
||||
await session.commit()
|
||||
await audit.audit_delete_client(session, request, client)
|
||||
|
||||
@router.post("/clients/")
|
||||
async def create_client(
|
||||
request: Request,
|
||||
client: ClientCreate,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientView:
|
||||
"""Create client."""
|
||||
existing = await get_client_by_id_or_name(session, client.name)
|
||||
@ -148,9 +148,12 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
|
||||
db_client = client.to_client()
|
||||
session.add(db_client)
|
||||
session.commit()
|
||||
session.refresh(db_client)
|
||||
audit.audit_create_client(session, request, db_client)
|
||||
await session.commit()
|
||||
await session.refresh(db_client)
|
||||
db_client = await get_client_by_id_or_name(session, client.name)
|
||||
if not db_client:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Could not create the client.")
|
||||
await audit.audit_create_client(session, request, db_client)
|
||||
return ClientView.from_client(db_client)
|
||||
|
||||
@router.post("/clients/{name}/public-key")
|
||||
@ -158,7 +161,7 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
request: Request,
|
||||
name: str,
|
||||
client_update: ClientUpdate,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientView:
|
||||
"""Change the public key of a client.
|
||||
|
||||
@ -170,17 +173,16 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
client.public_key = client_update.public_key
|
||||
for secret in session.scalars(
|
||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||
).all():
|
||||
matching_secrets = await session.scalars(select(ClientSecret).where(ClientSecret.client_id == client.id))
|
||||
for secret in matching_secrets.all():
|
||||
LOG.debug("Invalidated secret %s", secret.id)
|
||||
secret.invalidated = True
|
||||
secret.client_id = None
|
||||
|
||||
session.add(client)
|
||||
session.refresh(client)
|
||||
session.commit()
|
||||
audit.audit_invalidate_secrets(session, request, client)
|
||||
await session.refresh(client)
|
||||
await session.commit()
|
||||
await audit.audit_invalidate_secrets(session, request, client)
|
||||
|
||||
return ClientView.from_client(client)
|
||||
|
||||
@ -189,7 +191,7 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
request: Request,
|
||||
name: str,
|
||||
client_update: ClientCreate,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientView:
|
||||
"""Change the public key of a client.
|
||||
|
||||
@ -205,19 +207,20 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
public_key_updated = False
|
||||
if client_update.public_key != client.public_key:
|
||||
public_key_updated = True
|
||||
for secret in session.scalars(
|
||||
client_secrets = await session.scalars(
|
||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||
).all():
|
||||
)
|
||||
for secret in client_secrets.all():
|
||||
LOG.debug("Invalidated secret %s", secret.id)
|
||||
secret.invalidated = True
|
||||
secret.client_id = None
|
||||
|
||||
session.add(client)
|
||||
session.commit()
|
||||
session.refresh(client)
|
||||
audit.audit_update_client(session, request, client)
|
||||
await session.commit()
|
||||
await session.refresh(client)
|
||||
await audit.audit_update_client(session, request, client)
|
||||
if public_key_updated:
|
||||
audit.audit_invalidate_secrets(session, request, client)
|
||||
await audit.audit_invalidate_secrets(session, request, client)
|
||||
|
||||
return ClientView.from_client(client)
|
||||
|
||||
|
||||
@ -3,9 +3,11 @@
|
||||
import re
|
||||
import uuid
|
||||
import bcrypt
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from sshecret_backend.models import Client
|
||||
|
||||
@ -17,20 +19,38 @@ def verify_token(token: str, stored_hash: str) -> bool:
|
||||
stored_bytes = stored_hash.encode("utf-8")
|
||||
return bcrypt.checkpw(token_bytes, stored_bytes)
|
||||
|
||||
async def reload_client_with_relationships(session: AsyncSession, client: Client) -> Client:
|
||||
"""Reload a client from the database."""
|
||||
session.expunge(client)
|
||||
stmt = (
|
||||
select(Client)
|
||||
.options(selectinload(Client.policies), selectinload(Client.secrets))
|
||||
.where(Client.id == client.id)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one()
|
||||
|
||||
async def get_client_by_name(session: Session, name: str) -> Client | None:
|
||||
|
||||
def client_with_relationships() -> Select[tuple[Client]]:
|
||||
"""Base select statement for client with relationships."""
|
||||
return select(Client).options(
|
||||
selectinload(Client.secrets),
|
||||
selectinload(Client.policies),
|
||||
)
|
||||
|
||||
async def get_client_by_name(session: AsyncSession, name: str) -> Client | None:
|
||||
"""Get client by name."""
|
||||
client_filter = select(Client).where(Client.name == name)
|
||||
client_results = session.scalars(client_filter)
|
||||
return client_results.first()
|
||||
client_filter = client_with_relationships().where(Client.name == name)
|
||||
client_results = await session.execute(client_filter)
|
||||
return client_results.scalars().first()
|
||||
|
||||
async def get_client_by_id(session: Session, id: uuid.UUID) -> Client | None:
|
||||
"""Get client by name."""
|
||||
client_filter = select(Client).where(Client.id == id)
|
||||
client_results = session.scalars(client_filter)
|
||||
return client_results.first()
|
||||
async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | None:
|
||||
"""Get client by ID."""
|
||||
client_filter = client_with_relationships().where(Client.id == id)
|
||||
client_results = await session.execute(client_filter)
|
||||
return client_results.scalars().first()
|
||||
|
||||
async def get_client_by_id_or_name(session: Session, id_or_name: str) -> Client | None:
|
||||
async def get_client_by_id_or_name(session: AsyncSession, id_or_name: str) -> Client | None:
|
||||
"""Get client either by id or name."""
|
||||
if RE_UUID.match(id_or_name):
|
||||
id = uuid.UUID(id_or_name)
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Annotated
|
||||
|
||||
from sshecret_backend.models import ClientAccessPolicy
|
||||
@ -13,21 +13,21 @@ from sshecret_backend.view_models import (
|
||||
ClientPolicyView,
|
||||
ClientPolicyUpdate,
|
||||
)
|
||||
from sshecret_backend.types import DBSessionDep
|
||||
from sshecret_backend.types import AsyncDBSessionDep
|
||||
from sshecret_backend import audit
|
||||
from .common import get_client_by_id_or_name
|
||||
from .common import get_client_by_id_or_name, reload_client_with_relationships
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
def get_policy_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
||||
"""Construct clients sub-api."""
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/clients/{name}/policies/")
|
||||
async def get_client_policies(
|
||||
name: str, session: Annotated[Session, Depends(get_db_session)]
|
||||
name: str, session: Annotated[AsyncSession, Depends(get_db_session)]
|
||||
) -> ClientPolicyView:
|
||||
"""Get client policies."""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
@ -43,7 +43,7 @@ def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
request: Request,
|
||||
name: str,
|
||||
policy_update: ClientPolicyUpdate,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientPolicyView:
|
||||
"""Update client policies.
|
||||
|
||||
@ -55,28 +55,31 @@ def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
# Remove old policies.
|
||||
policies = session.scalars(
|
||||
policies = await session.scalars(
|
||||
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
|
||||
).all()
|
||||
)
|
||||
deleted_policies: list[ClientAccessPolicy] = []
|
||||
added_policies: list[ClientAccessPolicy] = []
|
||||
for policy in policies:
|
||||
session.delete(policy)
|
||||
for policy in policies.all():
|
||||
await session.delete(policy)
|
||||
deleted_policies.append(policy)
|
||||
|
||||
LOG.debug("Updating client policies with: %r", policy_update.sources)
|
||||
for source in policy_update.sources:
|
||||
LOG.debug("Source %r", source)
|
||||
policy = ClientAccessPolicy(source=str(source), client_id=client.id)
|
||||
session.add(policy)
|
||||
added_policies.append(policy)
|
||||
|
||||
session.commit()
|
||||
session.refresh(client)
|
||||
await session.flush()
|
||||
await session.commit()
|
||||
|
||||
client = await reload_client_with_relationships(session, client)
|
||||
for policy in deleted_policies:
|
||||
audit.audit_remove_policy(session, request, client, policy)
|
||||
await audit.audit_remove_policy(session, request, client, policy)
|
||||
|
||||
for policy in added_policies:
|
||||
audit.audit_update_policy(session, request, client, policy)
|
||||
await audit.audit_update_policy(session, request, client, policy)
|
||||
|
||||
return ClientPolicyView.from_client(client)
|
||||
|
||||
|
||||
@ -6,9 +6,11 @@ import logging
|
||||
from collections import defaultdict
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Annotated
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from sshecret_backend.models import Client, ClientSecret
|
||||
from sshecret_backend.view_models import (
|
||||
ClientReference,
|
||||
@ -19,7 +21,7 @@ from sshecret_backend.view_models import (
|
||||
ClientSecretResponse,
|
||||
)
|
||||
from sshecret_backend import audit
|
||||
from sshecret_backend.types import DBSessionDep
|
||||
from sshecret_backend.types import AsyncDBSessionDep
|
||||
from .common import get_client_by_id_or_name
|
||||
|
||||
|
||||
@ -27,7 +29,7 @@ LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def lookup_client_secret(
|
||||
session: Session, client: Client, name: str
|
||||
session: AsyncSession, client: Client, name: str
|
||||
) -> ClientSecret | None:
|
||||
"""Look up a secret for a client."""
|
||||
statement = (
|
||||
@ -35,11 +37,11 @@ async def lookup_client_secret(
|
||||
.where(ClientSecret.client_id == client.id)
|
||||
.where(ClientSecret.name == name)
|
||||
)
|
||||
results = session.scalars(statement)
|
||||
results = await session.scalars(statement)
|
||||
return results.first()
|
||||
|
||||
|
||||
def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
def get_secrets_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
||||
"""Construct clients sub-api."""
|
||||
router = APIRouter()
|
||||
|
||||
@ -48,7 +50,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
request: Request,
|
||||
name: str,
|
||||
client_secret: ClientSecretPublic,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> None:
|
||||
"""Add secret to a client."""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
@ -69,9 +71,9 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
name=client_secret.name, client_id=client.id, secret=client_secret.secret
|
||||
)
|
||||
session.add(db_secret)
|
||||
session.commit()
|
||||
session.refresh(db_secret)
|
||||
audit.audit_create_secret(session, request, client, db_secret)
|
||||
await session.commit()
|
||||
await session.refresh(db_secret)
|
||||
await audit.audit_create_secret(session, request, client, db_secret)
|
||||
|
||||
@router.put("/clients/{name}/secrets/{secret_name}")
|
||||
async def update_client_secret(
|
||||
@ -79,7 +81,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
name: str,
|
||||
secret_name: str,
|
||||
secret_data: BodyValue,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientSecretResponse:
|
||||
"""Update a client secret.
|
||||
|
||||
@ -96,9 +98,9 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
existing_secret.secret = secret_data.value
|
||||
|
||||
session.add(existing_secret)
|
||||
session.commit()
|
||||
session.refresh(existing_secret)
|
||||
audit.audit_update_secret(session, request, client, existing_secret)
|
||||
await session.commit()
|
||||
await session.refresh(existing_secret)
|
||||
await audit.audit_update_secret(session, request, client, existing_secret)
|
||||
return ClientSecretResponse.from_client_secret(existing_secret)
|
||||
|
||||
db_secret = ClientSecret(
|
||||
@ -107,9 +109,9 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
secret=secret_data.value,
|
||||
)
|
||||
session.add(db_secret)
|
||||
session.commit()
|
||||
session.refresh(db_secret)
|
||||
audit.audit_create_secret(session, request, client, db_secret)
|
||||
await session.commit()
|
||||
await session.refresh(db_secret)
|
||||
await audit.audit_create_secret(session, request, client, db_secret)
|
||||
return ClientSecretResponse.from_client_secret(db_secret)
|
||||
|
||||
@router.get("/clients/{name}/secrets/{secret_name}")
|
||||
@ -117,7 +119,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
request: Request,
|
||||
name: str,
|
||||
secret_name: str,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientSecretResponse:
|
||||
"""Get a client secret."""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
@ -133,7 +135,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
)
|
||||
|
||||
response_model = ClientSecretResponse.from_client_secret(secret)
|
||||
audit.audit_access_secret(session, request, client, secret)
|
||||
await audit.audit_access_secret(session, request, client, secret)
|
||||
return response_model
|
||||
|
||||
@router.delete("/clients/{name}/secrets/{secret_name}")
|
||||
@ -141,7 +143,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
request: Request,
|
||||
name: str,
|
||||
secret_name: str,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> None:
|
||||
"""Delete a secret."""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
@ -156,56 +158,69 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
status_code=404, detail="Cannot find a secret with the given name."
|
||||
)
|
||||
|
||||
session.delete(secret)
|
||||
session.commit()
|
||||
audit.audit_delete_secret(session, request, client, secret)
|
||||
await session.delete(secret)
|
||||
await session.commit()
|
||||
await audit.audit_delete_secret(session, request, client, secret)
|
||||
|
||||
@router.get("/secrets/")
|
||||
async def get_secret_map(
|
||||
request: Request, session: Annotated[Session, Depends(get_db_session)]
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> list[ClientSecretList]:
|
||||
"""Get a list of all secrets and which clients have them."""
|
||||
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
|
||||
for client_secret in session.scalars(select(ClientSecret)).all():
|
||||
client_secrets = await session.scalars(
|
||||
select(ClientSecret).options(selectinload(ClientSecret.client))
|
||||
)
|
||||
for client_secret in client_secrets.all():
|
||||
if not client_secret.client:
|
||||
if client_secret.name not in client_secret_map:
|
||||
client_secret_map[client_secret.name] = []
|
||||
continue
|
||||
client_secret_map[client_secret.name].append(client_secret.client.name)
|
||||
#audit.audit_client_secret_list(session, request)
|
||||
# audit.audit_client_secret_list(session, request)
|
||||
return [
|
||||
ClientSecretList(name=secret_name, clients=clients)
|
||||
for secret_name, clients in client_secret_map.items()
|
||||
]
|
||||
|
||||
@router.get("/secrets/detailed/")
|
||||
async def get_detailed_secret_map(
|
||||
request: Request, session: Annotated[Session, Depends(get_db_session)]
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> list[ClientSecretDetailList]:
|
||||
"""Get a list of all secrets and which clients have them."""
|
||||
client_secrets: dict[str, ClientSecretDetailList] = {}
|
||||
for client_secret in session.scalars(select(ClientSecret)).all():
|
||||
|
||||
all_client_secrets = await session.execute(
|
||||
select(ClientSecret).options(selectinload(ClientSecret.client))
|
||||
)
|
||||
for client_secret in all_client_secrets.scalars().all():
|
||||
if client_secret.name not in client_secrets:
|
||||
client_secrets[client_secret.name] = ClientSecretDetailList(name=client_secret.name)
|
||||
client_secrets[client_secret.name] = ClientSecretDetailList(
|
||||
name=client_secret.name
|
||||
)
|
||||
client_secrets[client_secret.name].ids.append(str(client_secret.id))
|
||||
if not client_secret.client:
|
||||
continue
|
||||
client_secrets[client_secret.name].clients.append(ClientReference(id=str(client_secret.client.id), name=client_secret.client.name))
|
||||
#`audit.audit_client_secret_list(session, request)
|
||||
client_secrets[client_secret.name].clients.append(
|
||||
ClientReference(
|
||||
id=str(client_secret.client.id), name=client_secret.client.name
|
||||
)
|
||||
)
|
||||
# `audit.audit_client_secret_list(session, request)
|
||||
return list(client_secrets.values())
|
||||
|
||||
|
||||
@router.get("/secrets/{name}")
|
||||
async def get_secret_clients(
|
||||
request: Request,
|
||||
name: str,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientSecretList:
|
||||
"""Get a list of which clients has a named secret."""
|
||||
clients: list[str] = []
|
||||
for client_secret in session.scalars(
|
||||
select(ClientSecret).where(ClientSecret.name == name)
|
||||
).all():
|
||||
client_secrets = await session.scalars(
|
||||
select(ClientSecret)
|
||||
.options(selectinload(ClientSecret.client))
|
||||
.where(ClientSecret.name == name)
|
||||
)
|
||||
for client_secret in client_secrets.all():
|
||||
if not client_secret.client:
|
||||
continue
|
||||
clients.append(client_secret.client.name)
|
||||
@ -214,19 +229,23 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
|
||||
@router.get("/secrets/{name}/detailed")
|
||||
async def get_secret_clients_detailed(
|
||||
request: Request,
|
||||
name: str,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientSecretDetailList:
|
||||
"""Get a list of which clients has a named secret."""
|
||||
detail_list = ClientSecretDetailList(name=name)
|
||||
for client_secret in session.scalars(
|
||||
client_secrets = await session.scalars(
|
||||
select(ClientSecret).where(ClientSecret.name == name)
|
||||
).all():
|
||||
)
|
||||
for client_secret in client_secrets.all():
|
||||
if not client_secret.client:
|
||||
continue
|
||||
detail_list.ids.append(str(client_secret.id))
|
||||
detail_list.clients.append(ClientReference(id=str(client_secret.client.id), name=client_secret.client.name))
|
||||
detail_list.clients.append(
|
||||
ClientReference(
|
||||
id=str(client_secret.client.id), name=client_secret.client.name
|
||||
)
|
||||
)
|
||||
|
||||
return detail_list
|
||||
|
||||
|
||||
Reference in New Issue
Block a user