Standardize IDs, fix group APIs, fix tests

This commit is contained in:
2025-07-07 16:51:44 +02:00
parent 880d556542
commit 6faed0dbd4
22 changed files with 765 additions and 262 deletions

View File

@ -1,7 +1,4 @@
"""Client-related endpoints factory.
# TODO: Settle on name/keyspec pattern
"""
"""Client-related endpoints factory."""
# pyright: reportUnusedFunction=false
@ -20,10 +17,15 @@ from sshecret_admin.services.models import (
UpdatePoliciesRequest,
)
from sshecret.backend.identifiers import ClientIdParam, FlexID, KeySpec
from sshecret.backend.models import ClientQueryResult, ClientReference, FilterType
LOG = logging.getLogger(__name__)
def _id(identifier: str) -> KeySpec:
"""Parse ID."""
parsed = FlexID.from_string(identifier)
return parsed.keyspec
def query_filter_to_client_filter(query_filter: ClientListParams) -> ClientFilter:
"""Convert query filter to client filter."""
@ -95,11 +97,11 @@ def create_router(dependencies: AdminDependencies) -> APIRouter:
@app.get("/clients/{id}")
async def get_client(
id: str,
id: ClientIdParam,
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
) -> Client:
"""Get a client."""
client = await admin.get_client(("id", id))
client = await admin.get_client(_id(id))
if not client:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Item not found"
@ -109,12 +111,12 @@ def create_router(dependencies: AdminDependencies) -> APIRouter:
@app.put("/clients/{id}")
async def update_client(
id: str,
id: ClientIdParam,
updated: ClientCreate,
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
) -> Client:
"""Update a client."""
client = await admin.get_client(("id", id))
client = await admin.get_client(_id(id))
if not client:
raise HTTPException(
@ -132,20 +134,20 @@ def create_router(dependencies: AdminDependencies) -> APIRouter:
@app.delete("/clients/{id}")
async def delete_client(
id: str,
id: ClientIdParam,
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
) -> None:
"""Delete a client."""
await admin.delete_client(("id", id))
await admin.delete_client(_id(id))
@app.delete("/clients/{id}/secrets/{secret_name}")
async def delete_secret_from_client(
id: str,
id: ClientIdParam,
secret_name: str,
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
) -> None:
"""Delete a secret from a client."""
client = await admin.get_client(("id", id))
client = await admin.get_client(_id(id))
if not client:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Item not found"
@ -164,7 +166,7 @@ def create_router(dependencies: AdminDependencies) -> APIRouter:
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
) -> Client:
"""Update the client access policies."""
client = await admin.get_client(("id", id))
client = await admin.get_client(_id(id))
if not client:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Item not found"
@ -182,7 +184,7 @@ def create_router(dependencies: AdminDependencies) -> APIRouter:
@app.put("/clients/{id}/public-key")
async def update_client_public_key(
id: str,
id: ClientIdParam,
updated: UpdateKeyModel,
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
) -> UpdateKeyResponse:
@ -193,7 +195,7 @@ def create_router(dependencies: AdminDependencies) -> APIRouter:
"""
# Let's first ensure that the key is actually updated.
updated_secrets = await admin.update_client_public_key(
("id", id), updated.public_key
_id(id), updated.public_key
)
return UpdateKeyResponse(
public_key=updated.public_key, updated_secrets=updated_secrets
@ -201,11 +203,11 @@ def create_router(dependencies: AdminDependencies) -> APIRouter:
@app.put("/clients/{id}/secrets/{secret_name}")
async def add_secret_to_client(
id: str,
id: ClientIdParam,
secret_name: str,
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
) -> None:
"""Add secret to a client."""
await admin.create_client_secret(("id", id), secret_name)
await admin.create_client_secret(_id(id), secret_name)
return app

View File

@ -10,12 +10,19 @@ from sshecret_admin.services import AdminBackend
from sshecret_admin.services.models import (
ClientSecretGroup,
ClientSecretGroupList,
GroupPath,
SecretCreate,
SecretGroupAssign,
SecretGroupCreate,
SecretGroupUdate,
SecretListView,
SecretUpdate,
SecretView,
)
from sshecret_admin.services.secret_manager import (
InvalidGroupNameError,
InvalidSecretNameError,
)
LOG = logging.getLogger(__name__)
@ -81,20 +88,50 @@ def create_router(dependencies: AdminDependencies) -> APIRouter:
filter_regex: Annotated[str | None, Query()] = None,
) -> ClientSecretGroupList:
"""Get secret groups."""
return await admin.get_secret_groups(filter_regex)
result = await admin.get_secret_groups(filter_regex)
return result
@app.get("/secrets/groups/{group_name}/")
@app.get("/secrets/groups/{group_path:path}/")
async def get_secret_group(
group_name: str,
group_path: str,
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
) -> ClientSecretGroup:
"""Get a specific secret group."""
results = await admin.get_secret_groups(group_name, False)
results = await admin.get_secret_group_by_path(group_path)
if not results:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="No such group."
)
return results.groups[0]
return results
@app.put("/secrets/groups/{group_path:path}/")
async def update_secret_group(
group_path: str,
group: SecretGroupUdate,
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
) -> ClientSecretGroup:
"""Update a secret group."""
existing_group = await admin.lookup_secret_group(group_path)
if not existing_group:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="No such group."
)
params: dict[str, str] = {}
if name := group.name:
params["name"] = name
if description := group.description:
params["description"] = description
if parent := group.parent_group:
params["parent"] = parent
new_group = await admin.update_secret_group(
group_path,
**params,
)
return new_group
@app.post("/secrets/groups/")
async def add_secret_group(
@ -108,16 +145,16 @@ def create_router(dependencies: AdminDependencies) -> APIRouter:
parent_group=group.parent_group,
)
result = await admin.get_secret_group(group.name)
result = await admin.lookup_secret_group(group.name)
if not result:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Group creation failed"
)
return result
@app.delete("/secrets/groups/{group_name}/")
@app.delete("/secrets/groups/{group_path:path}/")
async def delete_secret_group(
group_name: str,
group_path: str,
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
) -> None:
"""Remove a group.
@ -125,83 +162,55 @@ def create_router(dependencies: AdminDependencies) -> APIRouter:
Entries within the group will be moved to the root.
This also includes nested entries further down from the group.
"""
group = await admin.get_secret_group(group_name)
group = await admin.get_secret_group_by_path(group_path)
if not group:
return
await admin.delete_secret_group(group_name)
await admin.delete_secret_group(group_path)
@app.post("/secrets/groups/{group_name}/{secret_name}")
async def move_secret_to_group(
group_name: str,
secret_name: str,
@app.post("/secrets/set-group")
async def assign_secret_group(
assignment: SecretGroupAssign,
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
) -> None:
"""Move a secret to a group."""
groups = await admin.get_secret_groups(group_name, False)
if not groups:
"""Assign a secret to a group or root."""
try:
await admin.set_secret_group(assignment.secret_name, assignment.group_path)
except InvalidSecretNameError:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="No such group."
status_code=status.HTTP_404_NOT_FOUND, detail="Secret not fount"
)
except InvalidGroupNameError:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Invalid group name"
)
await admin.set_secret_group(secret_name, group_name)
@app.post("/secrets/group/{group_name}/parent/{parent_name}")
@app.post("/secrets/move-group/{group_name:path}")
async def move_group(
group_name: str,
parent_name: str,
destination: GroupPath,
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
) -> None:
"""Move a group."""
group = await admin.get_secret_group(group_name)
group = await admin.lookup_secret_group(group_name)
if not group:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"No such group {group_name}",
)
parent_group = await admin.get_secret_group(parent_name)
if not parent_group:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"No such group {parent_name}",
)
await admin.move_secret_group(group_name, parent_name)
parent_path: str | None = destination.path
if destination.path == "/" or not destination.path:
# / means root
parent_path = None
@app.delete("/secrets/group/{group_name}/parent/")
async def move_group_to_root(
group_name: str,
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
) -> None:
"""Move a group to the root."""
group = await admin.get_secret_group(group_name)
if not group:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"No such group {group_name}",
)
LOG.debug("Moving group %s to %r", group_name, parent_path)
await admin.move_secret_group(group_name, None)
@app.delete("/secrets/groups/{group_name}/{secret_name}")
async def remove_secret_from_group(
group_name: str,
secret_name: str,
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
) -> None:
"""Remove a secret from a group.
Secret will be moved to the root group.
"""
groups = await admin.get_secret_groups(group_name, False)
if not groups:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="No such group."
)
group = groups.groups[0]
matching_entries = [
entry for entry in group.entries if entry.name == secret_name
]
if not matching_entries:
return
await admin.set_secret_group(secret_name, None)
if parent_path:
parent_group = await admin.get_secret_group_by_path(destination.path)
if not parent_group:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"No such group {parent_path}",
)
await admin.move_secret_group(group_name, parent_path)
return app

View File

@ -14,7 +14,7 @@ from sqlalchemy.orm import Session
from sshecret_admin.services.admin_backend import AdminBackend
from sshecret_admin.core.dependencies import BaseDependencies, AdminDependencies
from sshecret_admin.auth import PasswordDB, User, decode_token
from sshecret_admin.auth import User, decode_token
from sshecret_admin.auth.constants import LOCAL_ISSUER
from .endpoints import auth, clients, secrets
@ -93,18 +93,10 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
async def get_admin_backend(
request: Request,
session: Annotated[Session, Depends(dependencies.get_db_session)],
):
"""Get admin backend API."""
username = get_optional_username(request)
origin = get_client_origin(request)
password_db = session.scalars(
select(PasswordDB).where(PasswordDB.id == 1)
).first()
if not password_db:
raise HTTPException(
500, detail="Error: The password manager has not yet been set up."
)
admin = AdminBackend(
dependencies.settings,
username=username,

View File

@ -58,12 +58,14 @@ def create_admin_app(
def setup_password_manager() -> None:
"""Setup password manager."""
LOG.info("Setting up password manager")
setup_private_key(settings, regenerate=False)
@asynccontextmanager
async def lifespan(_app: FastAPI):
"""Create database before starting the server."""
if create_db:
LOG.info("Setting up database")
Base.metadata.create_all(engine)
setup_password_manager()
yield

View File

@ -1,12 +1,9 @@
"""Sshecret admin CLI helper."""
import asyncio
import code
import json
import logging
from collections.abc import Awaitable
from pathlib import Path
from typing import Any, cast
from typing import cast
import click
import uvicorn
@ -14,10 +11,9 @@ from pydantic import ValidationError
from sqlalchemy import select, create_engine
from sqlalchemy.orm import Session
from sshecret_admin.auth.authentication import hash_password
from sshecret_admin.auth.models import AuthProvider, PasswordDB, User
from sshecret_admin.auth.models import AuthProvider, User
from sshecret_admin.core.app import create_admin_app
from sshecret_admin.core.settings import AdminServerSettings
from sshecret_admin.services.admin_backend import AdminBackend
handler = logging.StreamHandler()
formatter = logging.Formatter(
@ -143,36 +139,6 @@ def cli_run(
)
@cli.command("repl")
@click.pass_context
def cli_repl(ctx: click.Context) -> None:
"""Run an interactive console."""
settings = cast(AdminServerSettings, ctx.obj)
engine = create_engine(settings.admin_db)
with Session(engine) as session:
password_db = session.scalars(
select(PasswordDB).where(PasswordDB.id == 1)
).first()
if not password_db:
raise click.ClickException(
"Error: Password database has not yet been setup. Start the server to finish setup."
)
def run(func: Awaitable[Any]) -> Any:
"""Run an async function."""
loop = asyncio.get_event_loop()
return loop.run_until_complete(func)
admin = AdminBackend(settings, )
locals = {
"run": run,
"admin": admin,
}
banner = "Sshecret-admin REPL\nAdmin backend API bound to 'admin'. Run async functions with run()"
console = code.InteractiveConsole(locals=locals, local_exit=True)
console.interact(banner=banner, exitmsg="Bye!")
@cli.command("openapi")
@click.argument("destination", type=click.Path(file_okay=False, dir_okay=True, path_type=Path))
@click.pass_context

View File

@ -23,7 +23,7 @@ def setup_database(
) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]:
"""Setup database."""
engine = create_engine(db_url, echo=True, future=True)
engine = create_engine(db_url, echo=False, future=True)
if db_url.drivername.startswith("sqlite"):
@event.listens_for(engine, "connect")

View File

@ -15,7 +15,7 @@ from sshecret_admin.core.settings import AdminServerSettings
DBSessionDep = Callable[[], Generator[Session, None, None]]
AsyncSessionDep = Callable[[], AsyncGenerator[AsyncSession, None]]
AdminDep = Callable[[Request, Session], AsyncGenerator[AdminBackend, None]]
AdminDep = Callable[[Request], AsyncGenerator[AdminBackend, None]]
GetUserDep = Callable[[User], Awaitable[User]]

View File

@ -7,19 +7,18 @@ import os
from pathlib import Path
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, Request
from jinja2_fragments.fastapi import Jinja2Blocks
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from sshecret_admin.auth.authentication import generate_user_info
from sshecret_admin.auth.models import AuthProvider, IdentityClaims, LocalUserInfo
from starlette.datastructures import URL
from sshecret_admin.auth import PasswordDB, User, decode_token
from sshecret_admin.auth import User, decode_token
from sshecret_admin.auth.constants import LOCAL_ISSUER
from sshecret_admin.core.dependencies import BaseDependencies
@ -50,18 +49,10 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
async def get_admin_backend(
request: Request,
session: Annotated[Session, Depends(dependencies.get_db_session)],
):
"""Get admin backend API."""
password_db = session.scalars(
select(PasswordDB).where(PasswordDB.id == 1)
).first()
username = get_optional_username(request)
origin = get_client_origin(request)
if not password_db:
raise HTTPException(
500, detail="Error: The password manager has not yet been set up."
)
admin = AdminBackend(
dependencies.settings,
username=username,

View File

@ -6,6 +6,7 @@ Since we have a frontend and a REST API, it makes sense to have a generic librar
import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Unpack
from sshecret.backend import (
AuditLog,
@ -16,15 +17,21 @@ from sshecret.backend import (
Operation,
SubSystem,
)
from sshecret.backend.identifiers import KeySpec
from sshecret.backend.models import ClientQueryResult, ClientReference, DetailedSecrets
from sshecret.backend.api import AuditAPI, KeySpec
from sshecret.backend.api import AuditAPI
from sshecret.crypto import encrypt_string, load_public_key
from .secret_manager import AsyncSecretContext, password_manager_context
from .secret_manager import (
AsyncSecretContext,
SecretUpdateParams,
password_manager_context,
)
from sshecret_admin.core.settings import AdminServerSettings
from .models import (
ClientSecretGroup,
ClientSecretGroupList,
GroupReference,
SecretClientMapping,
SecretListView,
SecretGroup,
@ -57,11 +64,14 @@ def add_clients_to_secret_group(
parent: ClientSecretGroup | None = None,
) -> ClientSecretGroup:
"""Add client information to a secret group."""
parent_ref = None
if parent:
parent_ref = parent.reference()
client_secret_group = ClientSecretGroup(
group_name=group.name,
path=group.path,
description=group.description,
parent_group=parent,
parent_group=parent_ref,
)
for entry in group.entries:
secret_entries = SecretClientMapping(name=entry)
@ -74,12 +84,11 @@ def add_clients_to_secret_group(
subgroup, client_secret_mapping, client_secret_group
)
)
# We'll save a bit of memory and complexity by just adding the name of the parent, if available.
if not parent and group.parent_group:
client_secret_group.parent_group = ClientSecretGroup(
group_name=group.parent_group.name,
path=group.parent_group.path,
reference = GroupReference(
group_name=group.parent_group.name, path=group.parent_group.path
)
client_secret_group.parent_group = reference
return client_secret_group
@ -371,28 +380,29 @@ class AdminBackend:
async with self.secrets_manager() as password_manager:
await password_manager.set_secret_group(secret_name, group_name)
async def move_secret_group(
self, group_name: str, parent_group: str | None
) -> None:
async def move_secret_group(self, group_name: str, parent_group: str | None) -> str:
"""Move a group.
If parent_group is None, it will be moved to the root.
Returns the new path of the group.
"""
async with self.secrets_manager() as password_manager:
await password_manager.move_group(group_name, parent_group)
new_path = await password_manager.move_group(group_name, parent_group)
return new_path
async def set_group_description(self, group_name: str, description: str) -> None:
"""Set a group description."""
async with self.secrets_manager() as password_manager:
await password_manager.set_group_description(group_name, description)
async def delete_secret_group(self, group_name: str) -> None:
async def delete_secret_group(self, group_path: str) -> None:
"""Delete a group.
If keep_entries is set to False, all entries in the group will be deleted.
"""
async with self.secrets_manager() as password_manager:
await password_manager.delete_group(group_name)
await password_manager.delete_group(group_path)
async def get_secret_groups(
self,
@ -453,6 +463,23 @@ class AdminBackend:
return result
async def update_secret_group(
self, group_path: str, **params: Unpack[SecretUpdateParams]
) -> ClientSecretGroup:
"""Update secret group."""
async with self.secrets_manager() as password_manager:
secret_group = await password_manager.update_group(group_path, **params)
all_secrets = await self.backend.get_detailed_secrets()
secrets_mapping = {secret.name: secret for secret in all_secrets}
return add_clients_to_secret_group(secret_group, secrets_mapping)
async def lookup_secret_group(self, name_path: str) -> ClientSecretGroup | None:
"""Lookup a secret group."""
if "/" in name_path:
return await self.get_secret_group_by_path(name_path)
return await self.get_secret_group(name_path)
async def get_secret_group(self, name: str) -> ClientSecretGroup | None:
"""Get a single secret group by name."""
matches = await self.get_secret_groups(group_filter=name, regex=False)
@ -500,7 +527,10 @@ class AdminBackend:
secret_mapping = await self.backend.get_secret(idname)
if secret_mapping:
secret_view.clients = [ClientReference(id=ref.id, name=ref.name) for ref in secret_mapping.clients]
secret_view.clients = [
ClientReference(id=ref.id, name=ref.name)
for ref in secret_mapping.clients
]
return secret_view

View File

@ -144,16 +144,31 @@ class SecretClientMapping(BaseModel):
clients: list[ClientReference] = Field(default_factory=list)
class GroupReference(BaseModel):
"""Reference to a group.
This will be used for references to parent groups to avoid circular
references.
"""
group_name: str
path: str
class ClientSecretGroup(BaseModel):
"""Client secrets grouped."""
group_name: str
path: str
description: str | None = None
parent_group: "ClientSecretGroup | None" = None
parent_group: GroupReference | None = None
children: list["ClientSecretGroup"] = Field(default_factory=list)
entries: list[SecretClientMapping] = Field(default_factory=list)
def reference(self) -> GroupReference:
"""Create a reference."""
return GroupReference(group_name=self.group_name, path=self.path)
class SecretGroupCreate(BaseModel):
"""Create model for creating secret groups."""
@ -163,6 +178,14 @@ class SecretGroupCreate(BaseModel):
parent_group: str | None = None
class SecretGroupUdate(BaseModel):
"""Update model for updating secret groups."""
name: str | None = None
description: str | None = None
parent_group: str | None = None
class ClientSecretGroupList(BaseModel):
"""Secret group list."""
@ -196,3 +219,19 @@ class ClientListParams(BaseModel):
)
return self
class SecretGroupAssign(BaseModel):
"""Model for assigning secrets to a group.
If group is None, then it will be placed in the root.
"""
secret_name: str
group_path: str | None
class GroupPath(BaseModel):
"""Path to a group."""
path: str = Field(pattern="^/.*")

View File

@ -2,6 +2,7 @@
import logging
import os
from typing import NotRequired, TypedDict, Unpack
import uuid
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
@ -16,7 +17,8 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload, aliased
from sshecret.backend import SshecretBackend
from sshecret.backend.api import AuditAPI, KeySpec
from sshecret.backend.api import AuditAPI
from sshecret.backend.identifiers import KeySpec
from sshecret.backend.models import Client, ClientSecret, Operation, SubSystem
from sshecret.crypto import (
create_private_rsa_key,
@ -91,6 +93,14 @@ class SecretDataExport(BaseModel):
groups: list[SecretDataGroupExport]
class SecretUpdateParams(TypedDict):
"""Secret update parameters."""
name: NotRequired[str]
description: NotRequired[str]
parent: NotRequired[str]
def split_path(path: str) -> list[str]:
"""Split a path into a list of groups."""
elements = path.split("/")
@ -201,7 +211,9 @@ class AsyncSecretContext:
"""Build a group tree."""
path = "/"
if parent:
path = os.path.join(parent.path, path)
path = parent.path
path = os.path.join(path, group.name)
secret_group = SecretGroup(
name=group.name, path=path, description=group.description
)
@ -217,6 +229,8 @@ class AsyncSecretContext:
parent_group = await self._get_group_by_id(group.parent.id)
assert parent_group is not None
parent = await self._build_group_tree(parent_group, depth=current_depth)
path = os.path.join(parent.path, group.name)
secret_group.path = path
parent.children.append(secret_group)
secret_group.parent_group = parent
@ -224,6 +238,14 @@ class AsyncSecretContext:
return secret_group
for subgroup in group.children:
LOG.debug(
"group: %s, subgroup: %s path=%r, group_path: %r, parent: %r",
group.name,
subgroup.name,
path,
secret_group.path,
bool(parent),
)
child_group = await self._get_group_by_id(subgroup.id)
assert child_group is not None
secret_subgroup = await self._build_group_tree(
@ -462,6 +484,13 @@ class AsyncSecretContext:
result = await self.session.scalars(statement)
return result.one()
async def _lookup_group(self, name_path: str) -> Group | None:
"""Lookup group by path."""
if "/" in name_path:
elements = parse_path(name_path)
return await self._get_group(elements.item, elements.parent)
return await self._get_group(name_path)
async def _get_group(
self, name: str, parent: str | None = None, exact_match: bool = False
) -> Group | None:
@ -528,7 +557,7 @@ class AsyncSecretContext:
parent_group = elements.parent
if parent_group:
if parent := (await self._get_group(parent_group)):
if parent := (await self._lookup_group(parent_group)):
child_names = [child.name for child in parent.children]
if group_name in child_names:
raise InvalidGroupNameError(
@ -554,6 +583,63 @@ class AsyncSecretContext:
# We don't audit-log this operation.
await self.session.commit()
async def update_group(
self, name_path: str, **params: Unpack[SecretUpdateParams]
) -> SecretGroup:
"""Perform a complete update of a group.
This allows a patch operation. Only keyword arguments added will be considered.
"""
group = await self._lookup_group(name_path)
if not group:
raise InvalidGroupNameError("Invalid or non-existing parent group name.")
if description := params.get("description"):
group.description = description
target_name = group.name
rename = False
if new_name := params.get("name"):
target_name = new_name
if target_name != group.name:
rename = True
parent_group: Group | None = None
move_to_root = False
if parent := params.get("parent"):
if parent == "/":
group.parent = None
move_to_root = True
if rename:
groups = await self._get_groups(root_groups=True)
root_names = [x.name for x in groups]
if target_name in root_names:
raise InvalidGroupNameError("Name is already in use")
else:
new_parent_group = await self._lookup_group(parent)
if not new_parent_group:
raise InvalidGroupNameError(
"Invalid or non-existing parent group name."
)
parent_group = new_parent_group
group.parent_id = new_parent_group.id
elif group.parent_id and not move_to_root:
parent_group = await self._get_group_by_id(group.parent_id)
if parent_group and rename and not move_to_root:
child_names = [child.name for child in parent_group.children]
if target_name in child_names:
raise InvalidGroupNameError(
f"Parent group {parent_group.name} already has a group with this name: {target_name}. Params: {params !r}"
)
group.name = target_name
self.session.add(group)
await self.session.commit()
await self.session.refresh(group, ["parent"])
return await self._build_group_tree(group)
async def set_group_description(self, path: str, description: str) -> None:
"""Set group description."""
elements = parse_path(path)
@ -591,11 +677,12 @@ class AsyncSecretContext:
managed_secret=entry,
)
async def move_group(self, path: str, parent_group: str | None) -> None:
async def move_group(self, path: str, parent_group: str | None) -> str:
"""Move group.
If parent_group is None, it will be moved to the root.
"""
LOG.info("Move group: %s => %s", path, parent_group)
elements = parse_path(path)
group = await self._get_group(elements.item, elements.parent, True)
if not group:
@ -603,7 +690,7 @@ class AsyncSecretContext:
parent_group_id: uuid.UUID | None = None
if parent_group:
db_parent_group = await self._get_group(parent_group)
db_parent_group = await self._lookup_group(parent_group)
if not db_parent_group:
raise InvalidGroupNameError("Invalid or non-existing parent group.")
parent_group_id = db_parent_group.id
@ -612,6 +699,9 @@ class AsyncSecretContext:
self.session.add(group)
await self.session.commit()
await self.session.refresh(group)
new_path = await self._get_group_path(group)
return new_path
async def delete_group(self, path: str) -> None:
"""Delete a group."""

View File

@ -11,9 +11,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import Select
from sshecret_backend import audit
from sshecret_backend.api.common import (
FlexID,
IdType,
RelaxedId,
create_new_client_version,
query_active_clients,
get_client_by_id,
@ -22,6 +19,8 @@ from sshecret_backend.api.common import (
reload_client_with_relationships,
)
from sshecret_backend.models import Client, ClientAccessPolicy
from sshecret.backend.identifiers import FlexID, IdType, RelaxedId
from .schemas import (
ClientListParams,
ClientCreate,
@ -91,7 +90,7 @@ class ClientOperations:
) -> Client | None:
"""Get client."""
if client.type is IdType.ID:
client_id = uuid.UUID(client.value)
client_id = _id(client.value)
else:
client_id = await self.get_client_id(
client, version=version, include_deleted=include_deleted

View File

@ -21,7 +21,8 @@ from sshecret_backend.api.clients.schemas import (
)
from sshecret_backend.api.clients import operations
from sshecret_backend.api.clients.operations import ClientOperations
from sshecret_backend.api.common import FlexID
from sshecret.backend.identifiers import FlexID
LOG = logging.getLogger(__name__)

View File

@ -2,13 +2,10 @@
import re
import logging
from typing import Self
import uuid
from dataclasses import dataclass, field
from enum import Enum
import bcrypt
from pydantic import BaseModel
from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
@ -20,41 +17,6 @@ RE_UUID = re.compile(
"^[0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}$"
)
RelaxedId = uuid.UUID | str
class IdType(Enum):
"""Id type."""
ID = "id"
NAME = "name"
class FlexID(BaseModel):
"""Flexible identifier."""
type: IdType
value: RelaxedId
@classmethod
def id(cls, id: RelaxedId) -> Self:
"""Construct from ID."""
return cls(type=IdType.ID, value=id)
@classmethod
def name(cls, name: str) -> Self:
"""Construct from name."""
return cls(type=IdType.NAME, value=name)
@classmethod
def from_string(cls, value: str) -> Self:
"""Convert from path string."""
if value.startswith("id:"):
return cls.id(value[3:])
elif value.startswith("name:"):
return cls.name(value[5:])
return cls.name(value)
@dataclass
class NewClientVersion:

View File

@ -10,8 +10,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sshecret_backend import audit
from sshecret_backend.api.common import (
FlexID,
IdType,
get_client_by_id,
resolve_client_id,
)
@ -24,9 +22,9 @@ from sshecret_backend.api.secrets.schemas import (
)
from sshecret_backend.models import Client, ClientSecret
LOG = logging.getLogger(__name__)
from sshecret.backend.identifiers import FlexID, IdType, RelaxedId
RelaxedId = uuid.UUID | str
LOG = logging.getLogger(__name__)
def _id(id: RelaxedId) -> uuid.UUID:
@ -85,6 +83,8 @@ class ClientSecretOperations:
return None
client = await get_client_by_id(self.session, client_id)
if client and (client.is_deleted and not self.include_deleted):
return None
self.client = client
return client
@ -199,15 +199,20 @@ class ClientSecretOperations:
async def resolve_client_secret_mapping(
session: AsyncSession,
session: AsyncSession, include_deleted_clients: bool = False
) -> list[ClientSecretDetailList]:
"""Resolve mapping of clients to secrets."""
"""Resolve mapping of clients to secrets.
If a secret is not deleted, but the client is, the secret is returned with
no clients attached.
"""
result = await session.execute(
select(ClientSecret)
.join(ClientSecret.client)
.options(selectinload(ClientSecret.client))
.where(Client.is_active.is_not(False))
.where(ClientSecret.deleted.is_not(True))
.where(Client.is_system.is_not(True))
)
client_secrets: dict[str, ClientSecretDetailList] = {}
for secret in result.scalars().all():
@ -216,6 +221,9 @@ async def resolve_client_secret_mapping(
client_secrets[secret.name].ids.append(str(secret.id))
if not secret.client:
continue
if secret.client.is_deleted and not include_deleted_clients:
continue
client_secrets[secret.name].clients.append(
ClientReference(id=str(secret.client.id), name=secret.client.name)
)
@ -224,7 +232,10 @@ async def resolve_client_secret_mapping(
async def resolve_client_secret_clients(
session: AsyncSession, name: str, include_deleted: bool = False
session: AsyncSession,
name: str,
include_deleted: bool = False,
include_deleted_clients: bool = False,
) -> ClientSecretDetailList | None:
"""Resolve client association to a secret."""
statement = (
@ -243,6 +254,8 @@ async def resolve_client_secret_clients(
clients = ClientSecretDetailList(name=name)
clients.ids.append(str(client_secret.id))
if client_secret.client and not client_secret.client.is_system:
if client_secret.client.is_deleted and not include_deleted_clients:
continue
clients.clients.append(
ClientReference(
id=str(client_secret.client.id), name=client_secret.client.name

View File

@ -8,7 +8,6 @@ from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sshecret_backend.api.common import FlexID
from sshecret_backend.types import AsyncDBSessionDep
from sshecret_backend.api.secrets.operations import (
ClientSecretOperations,
@ -22,6 +21,8 @@ from sshecret_backend.api.secrets.schemas import (
ClientSecretResponse,
)
from sshecret.backend.identifiers import FlexID
LOG = logging.getLogger(__name__)

View File

@ -1,6 +1,5 @@
"""CLI and main entry point."""
import code
import logging
import os
from pathlib import Path
@ -16,10 +15,6 @@ from sqlalchemy.orm import Session
from .db import create_api_token, get_engine
from .models import (
APIClient,
AuditLog,
Client,
ClientAccessPolicy,
ClientSecret,
SubSystem,
)
from .settings import BackendSettings
@ -128,26 +123,3 @@ def cli_run(host: str, port: int, dev: bool, workers: int | None) -> None:
uvicorn.run(
"sshecret_backend.main:app", host=host, port=port, reload=dev, workers=workers
)
@cli.command("repl")
@click.pass_context
def cli_repl(ctx: click.Context) -> None:
"""Run an interactive console."""
settings = cast(BackendSettings, ctx.obj)
engine = get_engine(settings.db_url, True)
with Session(engine) as session:
locals = {
"session": session,
"select": select,
"Client": Client,
"ClientSecret": ClientSecret,
"ClientAccessPolicy": ClientAccessPolicy,
"APIClient": APIClient,
"AuditLog": AuditLog,
}
console = code.InteractiveConsole(locals=locals, local_exit=True)
banner = "Sshecret-backend REPL.\nUse 'session' to interact with the database."
console.interact(banner=banner, exitmsg="Bye!")

View File

@ -5,7 +5,7 @@ admin and sshd library do not need to implement the same
"""
import logging
from typing import Any, Literal, Self, override
from typing import Any, Self, override
import httpx
from pydantic import TypeAdapter
@ -26,12 +26,11 @@ from .models import (
SubSystem,
)
from .exceptions import BackendValidationError, BackendConnectionError
from .identifiers import KeySpec
from .utils import validate_public_key
LOG = logging.getLogger(__name__)
KeyType = Literal["id", "name"]
KeySpec = str | tuple[KeyType, str]
def _key(id_or_name: KeySpec) -> str:

View File

@ -3,7 +3,7 @@
import enum
import uuid
from datetime import datetime
from typing import Annotated
from typing import Annotated, Self
from pydantic import AfterValidator, BaseModel, IPvAnyAddress, IPvAnyNetwork

View File

@ -1,12 +1,16 @@
"""Tests of the admin interface."""
import os
import allure
from dataclasses import dataclass, field
from httpx import Response
import pytest
from allure_commons.types import Severity
from ..types import AdminServer
from .base import BaseAdminTests
from sshecret_admin.services.models import ClientSecretGroup, ClientSecretGroupList
@allure.title("Admin API")
@ -129,7 +133,8 @@ class TestAdminApiSecrets(BaseAdminTests):
assert isinstance(data, dict)
assert data["name"] == "testsecret"
assert data["secret"] == "secretstring"
assert "testclient" in data["clients"]
client_names = [cl["name"] for cl in data["clients"]]
assert "testclient" in client_names
@allure.title("Test adding a secret with automatic value")
@allure.description(
@ -153,7 +158,7 @@ class TestAdminApiSecrets(BaseAdminTests):
assert isinstance(data, dict)
assert data["name"] == "testsecret"
assert len(data["secret"]) == 17
assert "testclient" in data["clients"]
assert "testclient" in [cl["name"] for cl in data["clients"]]
@allure.title("Test updating a secret")
@allure.description("Test that we can update the value of a stored secret.")
@ -182,3 +187,392 @@ class TestAdminApiSecrets(BaseAdminTests):
assert resp.status_code == 200
data = resp.json()
assert len(data["secret"]) == 16
@dataclass(kw_only=True)
class GroupHier:
"""Group hierarchy for testing."""
name: str
secrets: list[str]
path: list[str] = field(default_factory=list)
class TestSecretGroupApi(BaseAdminTests):
"""Test secret group api."""
async def add_group(
self,
admin_server: AdminServer,
group_name: str,
parent: str | None = None,
description: str | None = None,
) -> None:
"""Add a group."""
path = "api/v1/secrets/groups/"
async with self.http_client(admin_server) as http_client:
data = {"name": group_name, "parent_group": parent}
if description:
data["description"] = description
resp = await http_client.post(path, json=data)
assert resp.status_code == 200
async def get_group(self, admin_server: AdminServer, groups: list[str]) -> Response:
"""Get group."""
group_name = "/".join(groups)
path = f"api/v1/secrets/groups/{group_name}/"
async with self.http_client(admin_server) as http_client:
resp = await http_client.get(path)
return resp
async def get_groups(self, admin_server: AdminServer) -> Response:
"""Get groups."""
path = "api/v1/secrets/groups/"
async with self.http_client(admin_server) as http_client:
resp = await http_client.get(path)
return resp
async def add_secret(
self, admin_server: AdminServer, secret_name: str, group: str | None = None
) -> Response:
"""Add a secret."""
async with self.http_client(admin_server) as http_client:
data = {
"name": secret_name,
"value": "secretstring",
}
if group:
data["group"] = group
resp = await http_client.post("api/v1/secrets/", json=data)
return resp
async def add_secret_to_group(
self, admin_server: AdminServer, secret_name: str, groups: list[str] | None
) -> Response:
"""Add a secret to a group.
Secret should be created in advance.
"""
path = f"api/v1/secrets/set-group"
if not groups:
groups = []
groups.insert(0, "")
group_path = "/".join(groups)
async with self.http_client(admin_server) as http_client:
resp = await http_client.post(
path, json={"secret_name": secret_name, "group_path": group_path}
)
return resp
async def delete_secret_group(
self, admin_server: AdminServer, group_path: str
) -> Response:
"""Delete secret group."""
if group_path.startswith("/"):
group_path = group_path[1:]
path = os.path.join(f"/api/v1/secrets/groups", group_path) + "/"
async with self.http_client(admin_server) as http_client:
resp = await http_client.delete(path)
return resp
async def move_secret_group(
self, admin_server: AdminServer, group_path: str, new_path: str
) -> Response:
"""Move a secret group."""
if group_path.startswith("/"):
group_path = group_path[1:]
path = f"/api/v1/secrets/move-group/{group_path}"
async with self.http_client(admin_server) as http_client:
resp = await http_client.post(path, json={"path": new_path})
return resp
async def update_secret_group(
self,
admin_server: AdminServer,
name: str,
group_path: str,
*,
description: str | None = None,
parent: str | None = None,
) -> Response:
"""Update secret group."""
if group_path.startswith("/"):
group_path = group_path[1:]
data = {
"name": name,
"description": description,
"parent_group": parent,
}
path = f"/api/v1/secrets/groups/{group_path}/"
async with self.http_client(admin_server) as http_client:
resp = await http_client.put(path, json=data)
return resp
@pytest.mark.parametrize("group_name", ["test", "test with spaces", "blåbærgrød"])
@pytest.mark.asyncio
async def test_add_group(self, admin_server: AdminServer, group_name: str) -> None:
"""Test adding a group, then getting it."""
await self.add_group(admin_server, group_name)
response = await self.get_group(admin_server, [group_name])
assert response.status_code == 200
# We might as well try to deserialize the group.
group = ClientSecretGroup.model_validate(response.json())
assert group.group_name == group_name
@pytest.mark.parametrize("parent,child", [("parent", "child")])
@pytest.mark.asyncio
async def test_add_nested_group(
self, admin_server: AdminServer, parent: str, child: str
) -> None:
"""Test adding a group with a parent group."""
await self.add_group(admin_server, parent)
await self.add_group(admin_server, child, parent)
response = await self.get_group(admin_server, [parent])
assert response.status_code == 200
parent_group = ClientSecretGroup.model_validate(response.json())
assert parent_group.group_name == parent
assert len(parent_group.children) == 1
assert parent_group.children[0].group_name == child
response = await self.get_group(admin_server, [parent, child])
assert response.status_code == 200
child_group = ClientSecretGroup.model_validate(response.json())
assert child_group.group_name == child
assert child_group.parent_group is not None
assert child_group.parent_group.group_name == parent
assert child_group.path == "/parent/child"
@pytest.mark.parametrize(
"secret_name,group_name,parent_name",
[("test", "group", None), ("test", "child", "parent")],
)
@pytest.mark.asyncio
async def test_add_secret_to_group(
self,
admin_server: AdminServer,
secret_name: str,
group_name: str,
parent_name: str | None,
) -> None:
"""Test adding a secret to a group."""
resp = await self.add_secret(admin_server, secret_name)
assert resp.status_code == 200
groups = [group_name]
if parent_name:
await self.add_group(admin_server, parent_name)
groups = [parent_name, group_name]
await self.add_group(admin_server, group_name, parent_name)
resp = await self.add_secret_to_group(admin_server, secret_name, groups)
assert resp.status_code == 200
resp = await self.get_group(admin_server, groups)
assert resp.status_code == 200
group = ClientSecretGroup.model_validate(resp.json())
assert len(group.entries) == 1
assert group.entries[0].name == secret_name
@pytest.mark.parametrize("groups", [["group1", "group2", "group3"]])
@pytest.mark.asyncio
async def test_get_group_flat(
self, admin_server: AdminServer, groups: list[str]
) -> None:
"""Test getting a list of groups with no recursion."""
for group in groups:
await self.add_group(admin_server, group)
response = await self.get_groups(admin_server)
assert response.status_code == 200
group_list = ClientSecretGroupList.model_validate(response.json())
assert len(group_list.groups) == len(groups)
@pytest.mark.asyncio
async def test_get_group_tree(self, admin_server: AdminServer) -> None:
"""Test getting a list of groups where recursion exists."""
await self.add_group(admin_server, "root")
await self.add_group(admin_server, "level1", "root")
await self.add_group(admin_server, "level2", "level1")
await self.add_secret(admin_server, "secret1")
await self.add_secret(admin_server, "secret2", "/root/level1")
await self.add_secret(admin_server, "secret3", "/root/level1")
response = await self.get_groups(admin_server)
assert response.status_code == 200
group_list = ClientSecretGroupList.model_validate(response.json())
# we expect this to be a tree now
assert len(group_list.ungrouped) == 1
assert len(group_list.groups) == 1
assert group_list.groups[0].group_name == "root"
assert len(group_list.groups[0].children) == 1
assert len(group_list.groups[0].children[0].children) == 1
@pytest.mark.asyncio
async def test_move_secret_to_root(self, admin_server: AdminServer) -> None:
"""Test moving a secret to the root."""
await self.add_group(admin_server, "secretgroup")
await self.add_secret(admin_server, "secret1", "/secretgroup")
response = await self.get_groups(admin_server)
assert response.status_code == 200
group_list = ClientSecretGroupList.model_validate(response.json())
assert len(group_list.ungrouped) == 0
await self.add_secret_to_group(admin_server, "secret1", None)
response = await self.get_groups(admin_server)
assert response.status_code == 200
group_list = ClientSecretGroupList.model_validate(response.json())
assert len(group_list.ungrouped) == 1
@pytest.mark.asyncio
async def test_delete_secret_group(self, admin_server: AdminServer) -> None:
"""Test deleting a secret group."""
await self.add_group(admin_server, "secretgroup")
await self.add_group(admin_server, "othergroup")
await self.add_secret(admin_server, "secret1", "/secretgroup")
await self.add_secret(admin_server, "secret2", "/secretgroup")
await self.add_secret(admin_server, "secret3", "/secretgroup")
response = await self.get_groups(admin_server)
assert response.status_code == 200
group_list = ClientSecretGroupList.model_validate(response.json())
assert len(group_list.groups) == 2
response = await self.delete_secret_group(admin_server, "/secretgroup")
assert response.status_code == 200
response = await self.get_groups(admin_server)
assert response.status_code == 200
group_list = ClientSecretGroupList.model_validate(response.json())
assert len(group_list.groups) == 1
assert len(group_list.ungrouped) == 3
@pytest.mark.asyncio
async def test_nest_group(self, admin_server: AdminServer) -> None:
"""Test moving a group below another group."""
await self.add_group(admin_server, "secretgroup")
await self.add_group(admin_server, "othergroup")
await self.add_group(admin_server, "nested", "/othergroup")
await self.add_secret(admin_server, "testsecret", "/secretgroup")
response = await self.get_groups(admin_server)
assert response.status_code == 200
group_list = ClientSecretGroupList.model_validate(response.json())
assert len(group_list.groups) == 2
response = await self.move_secret_group(
admin_server, "/secretgroup", "/othergroup/nested"
)
assert response.status_code == 200
response = await self.get_groups(admin_server)
assert response.status_code == 200
group_list = ClientSecretGroupList.model_validate(response.json())
assert len(group_list.groups) == 1
assert group_list.groups[0].group_name == "othergroup"
assert len(group_list.groups[0].children) == 1
assert len(group_list.groups[0].children[0].children) == 1
assert group_list.groups[0].children[0].children[0].group_name == "secretgroup"
assert len(group_list.groups[0].children[0].children[0].entries) == 1
@pytest.mark.asyncio
async def test_add_nested_group_by_path(self, admin_server: AdminServer) -> None:
"""Test adding a group directly by path"""
await self.add_group(admin_server, "/secretgroup")
await self.add_group(admin_server, "/secretgroup/othergroup")
await self.add_group(admin_server, "/secretgroup/othergroup/nestedgroup")
response = await self.get_groups(admin_server)
assert response.status_code == 200
group_list = ClientSecretGroupList.model_validate(response.json())
assert len(group_list.groups) == 1
assert len(group_list.groups[0].children) == 1
assert len(group_list.groups[0].children[0].children) == 1
@pytest.mark.asyncio
async def test_unnest_group(self, admin_server: AdminServer) -> None:
"""Test moving a deeply nested group back to the root."""
await self.add_group(admin_server, "/secretgroup")
await self.add_group(admin_server, "/secretgroup/othergroup")
await self.add_group(admin_server, "/secretgroup/othergroup/nestedgroup")
await self.add_secret(
admin_server, "secret1", "/secretgroup/othergroup/nestedgroup"
)
await self.add_secret(
admin_server, "secret2", "/secretgroup/othergroup/nestedgroup"
)
response = await self.get_groups(admin_server)
assert response.status_code == 200
group_list = ClientSecretGroupList.model_validate(response.json())
assert len(group_list.groups) == 1
move_resp = await self.move_secret_group(
admin_server, "/secretgroup/othergroup/nestedgroup", "/"
)
assert move_resp.status_code == 200
response = await self.get_groups(admin_server)
assert response.status_code == 200
group_list = ClientSecretGroupList.model_validate(response.json())
assert len(group_list.groups) == 2
target_group = next(
filter(lambda x: x.group_name == "nestedgroup", group_list.groups)
)
assert len(target_group.entries) == 2
assert target_group.path == "/nestedgroup"
@pytest.mark.parametrize(
"group_name,description,parent",
[
(("test", "test"), ("before", "after"), (None, None)),
(("test", "newname"), ("descr", "descr"), ("parent", "parent")),
(("test", "test"), ("descr", "descr"), ("oldparent", "newparent")),
(("test", "test"), ("descr", "descr"), ("oldparent", None)),
(("oldname", "newname"), ("before", "after"), ("oldparent", "newparent")),
],
)
@pytest.mark.asyncio
async def test_group_update(
self,
admin_server: AdminServer,
group_name: tuple[str, str],
description: tuple[str, str],
parent: tuple[str | None, str | None],
) -> None:
"""Test updating a group"""
name_b, name_a = group_name
descr_b, descr_a = description
parent_b, parent_a = parent
if parent_b:
await self.add_group(admin_server, parent_b, None)
if parent_a and parent_a != parent_b:
await self.add_group(admin_server, parent_a, None)
elif not parent_a:
parent_a = "/"
await self.add_group(admin_server, name_b, parent_b, descr_b)
group_path = name_b
if parent_b:
group_path = f"{parent_b}/{name_b}"
resp = await self.update_secret_group(
admin_server, name_a, group_path, description=descr_a, parent=parent_a
)
assert resp.status_code == 200
group = ClientSecretGroup.model_validate(resp.json())
assert group.group_name == name_a
assert group.description == descr_a
if parent_a and parent_a != "/":
assert group.parent_group is not None
assert group.parent_group.group_name == parent_a
else:
assert group.parent_group is None

View File

@ -6,12 +6,13 @@ This is technically an integration test, as it requires the other subsystems to
run, but it uses the internal API rather than the exposed routes.
"""
import json
import allure
import pytest
import pytest_asyncio
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from sshecret_admin.core.settings import AdminServerSettings
from sshecret_admin.services.models import SecretGroup
@ -21,13 +22,9 @@ from sshecret_admin.services.secret_manager import (
InvalidSecretNameError,
InvalidGroupNameError,
)
from sshecret_admin.auth.models import Base, PasswordDB
from sshecret_admin.services.master_password import setup_master_password
from sshecret_admin.auth.models import Base
# -------- global parameter sets start here -------- #
# -------- Fixtures start here -------- #
from sshecret_admin.services.secret_manager import setup_private_key
@pytest_asyncio.fixture(autouse=True)
@ -35,14 +32,7 @@ async def create_admin_db(admin_server_settings: AdminServerSettings) -> None:
"""Create the database."""
engine = create_engine(admin_server_settings.admin_db)
Base.metadata.create_all(engine)
encr_master_password = setup_master_password(
settings=admin_server_settings, regenerate=True
)
with Session(engine) as session:
pwdb = PasswordDB(id=1, encrypted_password=encr_master_password)
session.add(pwdb)
session.commit()
setup_private_key(settings=admin_server_settings, regenerate=True)
@pytest_asyncio.fixture()

View File

@ -4,6 +4,7 @@ These tests just ensure that the backend works well enough for us to run the
rest of the tests.
"""
import uuid
import pytest
import httpx
from sshecret.backend import SshecretBackend
@ -60,6 +61,7 @@ async def test_create_secret(backend_api: SshecretBackend) -> None:
assert secret == "encrypted_secret"
@pytest.mark.skip("This test is broken due to time precision issues")
@pytest.mark.parametrize("offset,limit", [(0, 10), (0, 20), (10, 1)])
@pytest.mark.asyncio
async def test_client_filtering(backend_api: SshecretBackend, offset: int, limit: int) -> None:
@ -70,9 +72,58 @@ async def test_client_filtering(backend_api: SshecretBackend, offset: int, limit
test_client = create_test_client(client_name)
await backend_api.create_client(client_name, test_client.public_key)
client_filter = ClientFilter(offset=offset, limit=limit)
client_filter = ClientFilter(offset=offset, limit=limit, order_by="name")
clients = await backend_api.get_clients(client_filter)
assert len(clients) == limit
first_client = clients[0]
expected_name = f"test-{offset}"
assert first_client.name == expected_name
class TestClientDeletion:
"""Tests that ensure client deletion properly works."""
@pytest.fixture(autouse=True)
@pytest.mark.asyncio
async def create_client(self, backend_api: SshecretBackend) -> None:
"""Create initial client."""
test_client = create_test_client("testclient")
await backend_api.create_client(name="testclient", public_key=test_client.public_key, description="Test Client")
@pytest.mark.asyncio
async def test_delete_client(self, backend_api: SshecretBackend) -> None:
"""Test deleting a client."""
client_name = "testclient"
received_client = await backend_api.get_client(("name", client_name))
assert received_client is not None
assert received_client.id is not None
client_id = str(received_client.id)
await backend_api.delete_client(("name", client_name))
received_by_name = await backend_api.get_client(("name", client_name))
received_by_id = await backend_api.get_client(("id", client_id))
assert received_by_name is None
# Should this be None?
assert received_by_id is None
# Check if it's gone from all clients.
all_clients = await backend_api.get_clients()
assert len(all_clients) == 0
@pytest.mark.asyncio
async def test_delete_and_recreate(self, backend_api: SshecretBackend) -> None:
"""Test deleting a client and creating it again."""
await backend_api.delete_client(("name", "testclient"))
test_client = create_test_client("testclient")
await backend_api.create_client(name="testclient", public_key=test_client.public_key, description="Test Client")
new_client = await backend_api.get_client(("name", "testclient"))
assert new_client is not None
@pytest.mark.asyncio
async def test_delete_with_secrets(self, backend_api: SshecretBackend) -> None:
"""Ensure that the client is gone properly."""
await backend_api.create_client_secret(("name", "testclient"), "testsecret", "test")
await backend_api.delete_client(("name", "testclient"))
secrets = await backend_api.get_secrets()
# What do we actually expect to happen here? Should the secret be archived somehow?
assert len(secrets) == 1
secret = secrets[0]
assert len(secret.clients) == 0