Backend fixed and features
This commit is contained in:
@ -29,6 +29,7 @@ from .schemas import (
|
||||
ClientView,
|
||||
ClientQueryResult,
|
||||
ClientPolicyUpdate,
|
||||
ClientReference,
|
||||
)
|
||||
|
||||
|
||||
@ -61,6 +62,7 @@ class ClientOperations:
|
||||
self,
|
||||
client: FlexID,
|
||||
version: int | None = None,
|
||||
include_deleted: bool = False,
|
||||
) -> uuid.UUID | None:
|
||||
"""Get client ID."""
|
||||
if self._client_id:
|
||||
@ -76,6 +78,7 @@ class ClientOperations:
|
||||
self.session,
|
||||
client_name,
|
||||
version=version,
|
||||
include_deleted=include_deleted,
|
||||
)
|
||||
if not client_id:
|
||||
return None
|
||||
@ -84,17 +87,26 @@ class ClientOperations:
|
||||
return client_id
|
||||
|
||||
async def _get_client(
|
||||
self,
|
||||
client: FlexID,
|
||||
version: int | None = None,
|
||||
self, client: FlexID, version: int | None = None, include_deleted: bool = False
|
||||
) -> Client | None:
|
||||
"""Get client."""
|
||||
client_id = await self.get_client_id(client, version=version)
|
||||
if not client_id:
|
||||
return None
|
||||
db_client = await get_client_by_id(self.session, client_id)
|
||||
if client.type is IdType.ID:
|
||||
client_id = uuid.UUID(client.value)
|
||||
else:
|
||||
client_id = await self.get_client_id(
|
||||
client, version=version, include_deleted=include_deleted
|
||||
)
|
||||
if not client_id:
|
||||
return None
|
||||
db_client = await get_client_by_id(
|
||||
self.session, client_id, include_deleted=include_deleted
|
||||
)
|
||||
return db_client
|
||||
|
||||
async def get_clients_terse(self) -> list[ClientReference]:
|
||||
"""Get a list of client id and names"""
|
||||
return await get_client_references(self.session)
|
||||
|
||||
async def get_client(
|
||||
self,
|
||||
client: FlexID,
|
||||
@ -115,7 +127,25 @@ class ClientOperations:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Error: A client already exists with this name."
|
||||
)
|
||||
deleted_id = await resolve_client_id(
|
||||
self.session, create_model.name, include_deleted=True
|
||||
)
|
||||
|
||||
client = create_model.to_client()
|
||||
if deleted_id:
|
||||
# Some other client had this name before, let's make it a new version
|
||||
LOG.info(
|
||||
"Client %s had this name before, we're creating a new version.",
|
||||
deleted_id,
|
||||
)
|
||||
return await self.new_client_version(
|
||||
FlexID.id(deleted_id),
|
||||
public_key=create_model.public_key,
|
||||
name=create_model.name,
|
||||
description=create_model.description,
|
||||
from_deleted=True,
|
||||
)
|
||||
|
||||
if system_client:
|
||||
statement = query_active_clients().where(Client.is_system.is_(True))
|
||||
results = await self.session.scalars(statement)
|
||||
@ -181,10 +211,12 @@ class ClientOperations:
|
||||
public_key: str,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
from_deleted: bool = False,
|
||||
) -> ClientView:
|
||||
"""Update a client to a new version."""
|
||||
current_client = await self._get_client(client)
|
||||
current_client = await self._get_client(client, include_deleted=from_deleted)
|
||||
if not current_client:
|
||||
LOG.info("Could not find previous version.")
|
||||
raise HTTPException(status_code=404, detail="Client not found.")
|
||||
new_client = await create_new_client_version(
|
||||
self.session, current_client, public_key
|
||||
@ -343,6 +375,23 @@ async def get_clients(
|
||||
)
|
||||
|
||||
|
||||
async def get_client_references(
|
||||
session: AsyncSession,
|
||||
) -> list[ClientReference]:
|
||||
"""Get a list of client names and IDs."""
|
||||
query = (
|
||||
select(Client)
|
||||
.where(Client.is_active.is_(True))
|
||||
.where(Client.is_deleted.is_not(True))
|
||||
.where(Client.is_system.is_not(True))
|
||||
)
|
||||
clients = await session.scalars(query)
|
||||
references: list[ClientReference] = []
|
||||
for client in clients.all():
|
||||
references.append(ClientReference(id=client.id, name=client.name))
|
||||
return references
|
||||
|
||||
|
||||
async def new_client_version(
|
||||
session: AsyncSession,
|
||||
client_id: uuid.UUID,
|
||||
|
||||
@ -17,6 +17,7 @@ from sshecret_backend.api.clients.schemas import (
|
||||
ClientView,
|
||||
ClientPolicyUpdate,
|
||||
ClientPolicyView,
|
||||
ClientReference,
|
||||
)
|
||||
from sshecret_backend.api.clients import operations
|
||||
from sshecret_backend.api.clients.operations import ClientOperations
|
||||
@ -46,6 +47,15 @@ def create_client_router(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
||||
client_op = ClientOperations(session, request)
|
||||
return await client_op.create_client(client)
|
||||
|
||||
@router.get("/clients/terse/")
|
||||
async def get_clients_terse(
|
||||
request: Request,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> list[ClientReference]:
|
||||
"""Get a list of client ids and names."""
|
||||
client_op = ClientOperations(session, request)
|
||||
return await client_op.get_clients_terse()
|
||||
|
||||
@router.get("/internal/system_client/", include_in_schema=False)
|
||||
async def get_system_client(
|
||||
request: Request,
|
||||
|
||||
@ -70,6 +70,12 @@ class ClientView(BaseModel):
|
||||
|
||||
return view
|
||||
|
||||
class ClientReference(BaseModel):
|
||||
"""A list of client names and IDs."""
|
||||
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
|
||||
|
||||
class ClientQueryResult(BaseModel):
|
||||
"""Result class for queries towards the client list."""
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
"""Common helpers."""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from typing import Self
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
@ -14,6 +15,7 @@ from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sshecret_backend.models import Client, ClientAccessPolicy
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
RE_UUID = re.compile(
|
||||
"^[0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}$"
|
||||
)
|
||||
@ -165,7 +167,7 @@ async def create_new_client_version(
|
||||
for policy in current_client.policies:
|
||||
copied_policy = ClientAccessPolicy(
|
||||
client=new_client,
|
||||
address=policy.source,
|
||||
source=policy.source,
|
||||
)
|
||||
session.add(copied_policy)
|
||||
|
||||
|
||||
@ -150,7 +150,7 @@ async def audit_new_client_version(
|
||||
message="Client data updated",
|
||||
data={
|
||||
"new_client_id": str(new_client.id),
|
||||
"new_client_version": new_client.version,
|
||||
"new_client_version": str(new_client.version),
|
||||
},
|
||||
)
|
||||
await _write_audit_log(session, request, entry, commit)
|
||||
|
||||
Reference in New Issue
Block a user