Standardize IDs, fix group APIs, fix tests

This commit is contained in:
2025-07-07 16:51:44 +02:00
parent 880d556542
commit 6faed0dbd4
22 changed files with 765 additions and 262 deletions

View File

@ -6,6 +6,7 @@ Since we have a frontend and a REST API, it makes sense to have a generic librar
import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Unpack
from sshecret.backend import (
AuditLog,
@ -16,15 +17,21 @@ from sshecret.backend import (
Operation,
SubSystem,
)
from sshecret.backend.identifiers import KeySpec
from sshecret.backend.models import ClientQueryResult, ClientReference, DetailedSecrets
from sshecret.backend.api import AuditAPI, KeySpec
from sshecret.backend.api import AuditAPI
from sshecret.crypto import encrypt_string, load_public_key
from .secret_manager import AsyncSecretContext, password_manager_context
from .secret_manager import (
AsyncSecretContext,
SecretUpdateParams,
password_manager_context,
)
from sshecret_admin.core.settings import AdminServerSettings
from .models import (
ClientSecretGroup,
ClientSecretGroupList,
GroupReference,
SecretClientMapping,
SecretListView,
SecretGroup,
@ -57,11 +64,14 @@ def add_clients_to_secret_group(
parent: ClientSecretGroup | None = None,
) -> ClientSecretGroup:
"""Add client information to a secret group."""
parent_ref = None
if parent:
parent_ref = parent.reference()
client_secret_group = ClientSecretGroup(
group_name=group.name,
path=group.path,
description=group.description,
parent_group=parent,
parent_group=parent_ref,
)
for entry in group.entries:
secret_entries = SecretClientMapping(name=entry)
@ -74,12 +84,11 @@ def add_clients_to_secret_group(
subgroup, client_secret_mapping, client_secret_group
)
)
# We'll save a bit of memory and complexity by just adding the name of the parent, if available.
if not parent and group.parent_group:
client_secret_group.parent_group = ClientSecretGroup(
group_name=group.parent_group.name,
path=group.parent_group.path,
reference = GroupReference(
group_name=group.parent_group.name, path=group.parent_group.path
)
client_secret_group.parent_group = reference
return client_secret_group
@ -371,28 +380,29 @@ class AdminBackend:
async with self.secrets_manager() as password_manager:
await password_manager.set_secret_group(secret_name, group_name)
async def move_secret_group(
self, group_name: str, parent_group: str | None
) -> None:
async def move_secret_group(self, group_name: str, parent_group: str | None) -> str:
"""Move a group.
If parent_group is None, it will be moved to the root.
Returns the new path of the group.
"""
async with self.secrets_manager() as password_manager:
await password_manager.move_group(group_name, parent_group)
new_path = await password_manager.move_group(group_name, parent_group)
return new_path
async def set_group_description(self, group_name: str, description: str) -> None:
"""Set a group description."""
async with self.secrets_manager() as password_manager:
await password_manager.set_group_description(group_name, description)
async def delete_secret_group(self, group_name: str) -> None:
async def delete_secret_group(self, group_path: str) -> None:
"""Delete a group.
If keep_entries is set to False, all entries in the group will be deleted.
"""
async with self.secrets_manager() as password_manager:
await password_manager.delete_group(group_name)
await password_manager.delete_group(group_path)
async def get_secret_groups(
self,
@ -453,6 +463,23 @@ class AdminBackend:
return result
async def update_secret_group(
self, group_path: str, **params: Unpack[SecretUpdateParams]
) -> ClientSecretGroup:
"""Update secret group."""
async with self.secrets_manager() as password_manager:
secret_group = await password_manager.update_group(group_path, **params)
all_secrets = await self.backend.get_detailed_secrets()
secrets_mapping = {secret.name: secret for secret in all_secrets}
return add_clients_to_secret_group(secret_group, secrets_mapping)
async def lookup_secret_group(self, name_path: str) -> ClientSecretGroup | None:
"""Lookup a secret group."""
if "/" in name_path:
return await self.get_secret_group_by_path(name_path)
return await self.get_secret_group(name_path)
async def get_secret_group(self, name: str) -> ClientSecretGroup | None:
"""Get a single secret group by name."""
matches = await self.get_secret_groups(group_filter=name, regex=False)
@ -500,7 +527,10 @@ class AdminBackend:
secret_mapping = await self.backend.get_secret(idname)
if secret_mapping:
secret_view.clients = [ClientReference(id=ref.id, name=ref.name) for ref in secret_mapping.clients]
secret_view.clients = [
ClientReference(id=ref.id, name=ref.name)
for ref in secret_mapping.clients
]
return secret_view

View File

@ -144,16 +144,31 @@ class SecretClientMapping(BaseModel):
clients: list[ClientReference] = Field(default_factory=list)
class GroupReference(BaseModel):
"""Reference to a group.
This will be used for references to parent groups to avoid circular
references.
"""
group_name: str
path: str
class ClientSecretGroup(BaseModel):
"""Client secrets grouped."""
group_name: str
path: str
description: str | None = None
parent_group: "ClientSecretGroup | None" = None
parent_group: GroupReference | None = None
children: list["ClientSecretGroup"] = Field(default_factory=list)
entries: list[SecretClientMapping] = Field(default_factory=list)
def reference(self) -> GroupReference:
"""Create a reference."""
return GroupReference(group_name=self.group_name, path=self.path)
class SecretGroupCreate(BaseModel):
"""Create model for creating secret groups."""
@ -163,6 +178,14 @@ class SecretGroupCreate(BaseModel):
parent_group: str | None = None
class SecretGroupUdate(BaseModel):
"""Update model for updating secret groups."""
name: str | None = None
description: str | None = None
parent_group: str | None = None
class ClientSecretGroupList(BaseModel):
"""Secret group list."""
@ -196,3 +219,19 @@ class ClientListParams(BaseModel):
)
return self
class SecretGroupAssign(BaseModel):
"""Model for assigning secrets to a group.
If group is None, then it will be placed in the root.
"""
secret_name: str
group_path: str | None
class GroupPath(BaseModel):
"""Path to a group."""
path: str = Field(pattern="^/.*")

View File

@ -2,6 +2,7 @@
import logging
import os
from typing import NotRequired, TypedDict, Unpack
import uuid
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
@ -16,7 +17,8 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload, aliased
from sshecret.backend import SshecretBackend
from sshecret.backend.api import AuditAPI, KeySpec
from sshecret.backend.api import AuditAPI
from sshecret.backend.identifiers import KeySpec
from sshecret.backend.models import Client, ClientSecret, Operation, SubSystem
from sshecret.crypto import (
create_private_rsa_key,
@ -91,6 +93,14 @@ class SecretDataExport(BaseModel):
groups: list[SecretDataGroupExport]
class SecretUpdateParams(TypedDict):
"""Secret update parameters."""
name: NotRequired[str]
description: NotRequired[str]
parent: NotRequired[str]
def split_path(path: str) -> list[str]:
"""Split a path into a list of groups."""
elements = path.split("/")
@ -201,7 +211,9 @@ class AsyncSecretContext:
"""Build a group tree."""
path = "/"
if parent:
path = os.path.join(parent.path, path)
path = parent.path
path = os.path.join(path, group.name)
secret_group = SecretGroup(
name=group.name, path=path, description=group.description
)
@ -217,6 +229,8 @@ class AsyncSecretContext:
parent_group = await self._get_group_by_id(group.parent.id)
assert parent_group is not None
parent = await self._build_group_tree(parent_group, depth=current_depth)
path = os.path.join(parent.path, group.name)
secret_group.path = path
parent.children.append(secret_group)
secret_group.parent_group = parent
@ -224,6 +238,14 @@ class AsyncSecretContext:
return secret_group
for subgroup in group.children:
LOG.debug(
"group: %s, subgroup: %s path=%r, group_path: %r, parent: %r",
group.name,
subgroup.name,
path,
secret_group.path,
bool(parent),
)
child_group = await self._get_group_by_id(subgroup.id)
assert child_group is not None
secret_subgroup = await self._build_group_tree(
@ -462,6 +484,13 @@ class AsyncSecretContext:
result = await self.session.scalars(statement)
return result.one()
async def _lookup_group(self, name_path: str) -> Group | None:
"""Lookup group by path."""
if "/" in name_path:
elements = parse_path(name_path)
return await self._get_group(elements.item, elements.parent)
return await self._get_group(name_path)
async def _get_group(
self, name: str, parent: str | None = None, exact_match: bool = False
) -> Group | None:
@ -528,7 +557,7 @@ class AsyncSecretContext:
parent_group = elements.parent
if parent_group:
if parent := (await self._get_group(parent_group)):
if parent := (await self._lookup_group(parent_group)):
child_names = [child.name for child in parent.children]
if group_name in child_names:
raise InvalidGroupNameError(
@ -554,6 +583,63 @@ class AsyncSecretContext:
# We don't audit-log this operation.
await self.session.commit()
async def update_group(
self, name_path: str, **params: Unpack[SecretUpdateParams]
) -> SecretGroup:
"""Perform a complete update of a group.
This allows a patch operation. Only keyword arguments added will be considered.
"""
group = await self._lookup_group(name_path)
if not group:
raise InvalidGroupNameError("Invalid or non-existing parent group name.")
if description := params.get("description"):
group.description = description
target_name = group.name
rename = False
if new_name := params.get("name"):
target_name = new_name
if target_name != group.name:
rename = True
parent_group: Group | None = None
move_to_root = False
if parent := params.get("parent"):
if parent == "/":
group.parent = None
move_to_root = True
if rename:
groups = await self._get_groups(root_groups=True)
root_names = [x.name for x in groups]
if target_name in root_names:
raise InvalidGroupNameError("Name is already in use")
else:
new_parent_group = await self._lookup_group(parent)
if not new_parent_group:
raise InvalidGroupNameError(
"Invalid or non-existing parent group name."
)
parent_group = new_parent_group
group.parent_id = new_parent_group.id
elif group.parent_id and not move_to_root:
parent_group = await self._get_group_by_id(group.parent_id)
if parent_group and rename and not move_to_root:
child_names = [child.name for child in parent_group.children]
if target_name in child_names:
raise InvalidGroupNameError(
f"Parent group {parent_group.name} already has a group with this name: {target_name}. Params: {params !r}"
)
group.name = target_name
self.session.add(group)
await self.session.commit()
await self.session.refresh(group, ["parent"])
return await self._build_group_tree(group)
async def set_group_description(self, path: str, description: str) -> None:
"""Set group description."""
elements = parse_path(path)
@ -591,11 +677,12 @@ class AsyncSecretContext:
managed_secret=entry,
)
async def move_group(self, path: str, parent_group: str | None) -> None:
async def move_group(self, path: str, parent_group: str | None) -> str:
"""Move group.
If parent_group is None, it will be moved to the root.
"""
LOG.info("Move group: %s => %s", path, parent_group)
elements = parse_path(path)
group = await self._get_group(elements.item, elements.parent, True)
if not group:
@ -603,7 +690,7 @@ class AsyncSecretContext:
parent_group_id: uuid.UUID | None = None
if parent_group:
db_parent_group = await self._get_group(parent_group)
db_parent_group = await self._lookup_group(parent_group)
if not db_parent_group:
raise InvalidGroupNameError("Invalid or non-existing parent group.")
parent_group_id = db_parent_group.id
@ -612,6 +699,9 @@ class AsyncSecretContext:
self.session.add(group)
await self.session.commit()
await self.session.refresh(group)
new_path = await self._get_group_path(group)
return new_path
async def delete_group(self, path: str) -> None:
"""Delete a group."""