Write new secret manager using existing RSA logic

This commit is contained in:
2025-06-22 17:17:56 +02:00
parent 5985a726e3
commit 82ec7fabb4
34 changed files with 2042 additions and 640 deletions

View File

@ -1,11 +1,11 @@
import os
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from sqlalchemy import Engine, engine_from_config, pool, create_engine
from alembic import context
from sshecret_admin.auth.models import Base
from sshecret_admin.core.settings import AdminServerSettings
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
@ -14,11 +14,32 @@ config = context.config
def get_database_url() -> str | None:
"""Get database URL."""
try:
settings = AdminServerSettings() # pyright: ignore[reportCallIssue]
return str(settings.admin_db)
except Exception:
if db_file := os.getenv("SSHECRET_ADMIN_DATABASE"):
return f"sqlite:///{db_file}"
return config.get_main_option("sqlalchemy.url")
def get_engine() -> Engine:
"""Get engine."""
try:
settings = AdminServerSettings() # pyright: ignore[reportCallIssue]
engine = create_engine(settings.admin_db)
return engine
except Exception:
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
return connectable
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
@ -68,12 +89,7 @@ def run_migrations_online() -> None:
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
connectable = get_engine()
with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata, render_as_batch=True

View File

@ -0,0 +1,44 @@
"""Implement db structures for internal password manager
Revision ID: 84356d0ea85f
Revises: 6c148590471f
Create Date: 2025-06-21 07:21:02.257865
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '84356d0ea85f'
down_revision: Union[str, None] = '6c148590471f'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('groups',
sa.Column('id', sa.Uuid(), nullable=False),
sa.Column('name', sa.String(), nullable=False),
sa.Column('parent_id', sa.Uuid(), nullable=True),
sa.ForeignKeyConstraint(['parent_id'], ['groups.id'], ),
sa.PrimaryKeyConstraint('id')
)
with op.batch_alter_table('password_db', schema=None) as batch_op:
batch_op.add_column(sa.Column('client_id', sa.Uuid(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('password_db', schema=None) as batch_op:
batch_op.drop_column('client_id')
op.drop_table('groups')
# ### end Alembic commands ###

View File

@ -0,0 +1,48 @@
"""Implement managed secrets
Revision ID: c34707a1ea3a
Revises: 84356d0ea85f
Create Date: 2025-06-21 07:38:12.994535
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'c34707a1ea3a'
down_revision: Union[str, None] = '84356d0ea85f'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('managed_secrets',
sa.Column('id', sa.Uuid(), nullable=False),
sa.Column('name', sa.String(), nullable=False),
sa.Column('is_deleted', sa.Boolean(), nullable=False),
sa.Column('group_id', sa.Uuid(), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ondelete='SET NULL'),
sa.PrimaryKeyConstraint('id')
)
with op.batch_alter_table('groups', schema=None) as batch_op:
batch_op.add_column(sa.Column('description', sa.String(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('groups', schema=None) as batch_op:
batch_op.drop_column('description')
op.drop_table('managed_secrets')
# ### end Alembic commands ###

View File

@ -5,9 +5,9 @@ import logging
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from sshecret_admin.auth import Token, authenticate_user, create_access_token
from sshecret_admin.auth import Token, authenticate_user_async, create_access_token
from sshecret_admin.core.dependencies import AdminDependencies
LOG = logging.getLogger(__name__)
@ -19,11 +19,12 @@ def create_router(dependencies: AdminDependencies) -> APIRouter:
@app.post("/token")
async def login_for_access_token(
session: Annotated[Session, Depends(dependencies.get_db_session)],
session: Annotated[AsyncSession, Depends(dependencies.get_async_session)],
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
) -> Token:
"""Login user and generate token."""
user = authenticate_user(session, form_data.username, form_data.password)
user = await authenticate_user_async(session, form_data.username, form_data.password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,

View File

@ -128,7 +128,7 @@ def create_router(dependencies: AdminDependencies) -> APIRouter:
group = await admin.get_secret_group(group_name)
if not group:
return
await admin.delete_secret_group(group_name, keep_entries=True)
await admin.delete_secret_group(group_name)
@app.post("/secrets/groups/{group_name}/{secret_name}")
async def move_secret_to_group(

View File

@ -5,8 +5,9 @@
import logging
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.security import OAuth2PasswordBearer
from fastapi.security.utils import get_authorization_scheme_param
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -57,6 +58,31 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
raise credentials_exception
return user
def get_client_origin(request: Request) -> str:
"""Get client origin."""
fallback_origin = "UNKNOWN"
if request.client:
return request.client.host
return fallback_origin
def get_optional_username(request: Request) -> str | None:
"""Get username, if available.
This is purely used for auditing purposes.
"""
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
return None
claims = decode_token(dependencies.settings, param)
if not claims:
return None
if claims.provider == LOCAL_ISSUER:
return claims.sub
return f"oidc:{claims.email}"
async def get_current_active_user(
current_user: Annotated[User, Depends(get_current_user)],
) -> User:
@ -66,9 +92,12 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
return current_user
async def get_admin_backend(
request: Request,
session: Annotated[Session, Depends(dependencies.get_db_session)],
):
"""Get admin backend API."""
username = get_optional_username(request)
origin = get_client_origin(request)
password_db = session.scalars(
select(PasswordDB).where(PasswordDB.id == 1)
).first()
@ -76,7 +105,11 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
raise HTTPException(
500, detail="Error: The password manager has not yet been set up."
)
admin = AdminBackend(dependencies.settings, password_db.encrypted_password)
admin = AdminBackend(
dependencies.settings,
username=username,
origin=origin,
)
yield admin
app = APIRouter(prefix=f"/api/{API_VERSION}")

View File

@ -1,12 +1,13 @@
"""Models for authentication."""
"""Models for authentication and secret management."""
import enum
from datetime import datetime
from typing import override
import uuid
import sqlalchemy as sa
from pydantic import BaseModel
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
JWT_ALGORITHM = "HS256"
@ -75,12 +76,15 @@ class PasswordDB(Base):
__tablename__: str = "password_db"
id: Mapped[int] = mapped_column(sa.INT, primary_key=True)
encrypted_password: Mapped[str] = mapped_column(sa.String)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
)
client_id: Mapped[uuid.UUID | None] = mapped_column(
sa.Uuid(as_uuid=True), nullable=True
)
updated_at: Mapped[datetime | None] = mapped_column(
sa.DateTime(timezone=True),
server_default=sa.func.now(),
@ -88,6 +92,65 @@ class PasswordDB(Base):
)
class Group(Base):
"""A secret group."""
__tablename__: str = "groups"
id: Mapped[uuid.UUID] = mapped_column(
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
)
name: Mapped[str] = mapped_column(sa.String, nullable=False)
description: Mapped[str | None] = mapped_column(sa.String, nullable=True)
parent_id: Mapped[uuid.UUID | None] = mapped_column(
sa.ForeignKey("groups.id"), nullable=True
)
parent: Mapped["Group | None"] = relationship(
"Group", remote_side=[id], back_populates="children"
)
children: Mapped[list["Group"]] = relationship(
"Group", back_populates="parent", cascade="all, delete"
)
secrets: Mapped[list["ManagedSecret"]] = relationship(back_populates="group")
@override
def __repr__(self) -> str:
return f"<Group id={self.id} name={self.name} parent_id={self.parent_id}>"
class ManagedSecret(Base):
"""Managed Secret."""
__tablename__: str = "managed_secrets"
id: Mapped[uuid.UUID] = mapped_column(
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
)
name: Mapped[str] = mapped_column(sa.String, nullable=False)
is_deleted: Mapped[bool] = mapped_column(sa.Boolean, default=False)
group_id: Mapped[uuid.UUID | None] = mapped_column(
sa.ForeignKey("groups.id", ondelete="SET NULL"), nullable=True
)
group: Mapped["Group | None"] = relationship(
Group, foreign_keys=[group_id], back_populates="secrets"
)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
)
updated_at: Mapped[datetime | None] = mapped_column(
sa.DateTime(timezone=True),
server_default=sa.func.now(),
onupdate=sa.func.now(),
)
deleted_at: Mapped[datetime | None] = mapped_column(
sa.DateTime(timezone=True), nullable=True
)
class IdentityClaims(BaseModel):
"""Normalized identity claim model."""
@ -125,6 +188,3 @@ class LocalUserInfo(BaseModel):
local: bool
def init_db(engine: sa.Engine) -> None:
"""Create database."""
Base.metadata.create_all(engine)

View File

@ -2,6 +2,7 @@
# pyright: reportUnusedFunction=false
#
from collections.abc import AsyncGenerator
import logging
import os
from contextlib import asynccontextmanager
@ -12,15 +13,15 @@ from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from sshecret_backend.db import DatabaseSessionManager
from starlette.middleware.sessions import SessionMiddleware
from sshecret_admin import api, frontend
from sshecret_admin.auth.models import PasswordDB, init_db
from sshecret_admin.auth.models import Base
from sshecret_admin.core.db import setup_database
from sshecret_admin.frontend.exceptions import RedirectException
from sshecret_admin.services.master_password import setup_master_password
from sshecret_admin.services.secret_manager import setup_private_key
from .dependencies import BaseDependencies
from .settings import AdminServerSettings
@ -40,44 +41,28 @@ def setup_frontend(app: FastAPI, dependencies: BaseDependencies) -> None:
def create_admin_app(
settings: AdminServerSettings, with_frontend: bool = True
settings: AdminServerSettings,
with_frontend: bool = True,
create_db: bool = False,
) -> FastAPI:
"""Create admin app."""
engine, get_db_session = setup_database(settings.admin_db)
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
"""Get async session."""
session_manager = DatabaseSessionManager(settings.async_db_url)
async with session_manager.session() as session:
yield session
def setup_password_manager() -> None:
"""Setup password manager."""
encr_master_password = setup_master_password(
settings=settings, regenerate=False
)
with Session(engine) as session:
existing_password = session.scalars(
select(PasswordDB).where(PasswordDB.id == 1)
).first()
if not encr_master_password:
if existing_password:
LOG.info("Master password already defined.")
return
# Looks like we have to regenerate it
LOG.warning(
"Master password was set, but not saved to the database. Regenerating it."
)
encr_master_password = setup_master_password(
settings=settings, regenerate=True
)
assert encr_master_password is not None
with Session(engine) as session:
pwdb = PasswordDB(id=1, encrypted_password=encr_master_password)
session.add(pwdb)
session.commit()
setup_private_key(settings, regenerate=False)
@asynccontextmanager
async def lifespan(_app: FastAPI):
"""Create database before starting the server."""
init_db(engine)
if create_db:
Base.metadata.create_all(engine)
setup_password_manager()
yield
@ -109,7 +94,7 @@ def create_admin_app(
status_code=status.HTTP_200_OK, content=jsonable_encoder({"status": "LIVE"})
)
dependencies = BaseDependencies(settings, get_db_session)
dependencies = BaseDependencies(settings, get_db_session, get_async_session)
app.include_router(api.create_api_router(dependencies))
if with_frontend:

View File

@ -12,7 +12,7 @@ from pydantic import ValidationError
from sqlalchemy import select, create_engine
from sqlalchemy.orm import Session
from sshecret_admin.auth.authentication import hash_password
from sshecret_admin.auth.models import AuthProvider, PasswordDB, User, init_db
from sshecret_admin.auth.models import AuthProvider, PasswordDB, User
from sshecret_admin.core.settings import AdminServerSettings
from sshecret_admin.services.admin_backend import AdminBackend
@ -72,7 +72,6 @@ def cli_create_user(
"""Create user."""
settings = cast(AdminServerSettings, ctx.obj)
engine = create_engine(settings.admin_db)
init_db(engine)
with Session(engine) as session:
create_user(session, username, email, password)
@ -87,7 +86,6 @@ def cli_change_user_passwd(ctx: click.Context, username: str, password: str) ->
"""Change password on user."""
settings = cast(AdminServerSettings, ctx.obj)
engine = create_engine(settings.admin_db)
init_db(engine)
with Session(engine) as session:
user = session.scalars(select(User).where(User.username == username)).first()
if not user:
@ -107,7 +105,6 @@ def cli_delete_user(ctx: click.Context, username: str) -> None:
"""Remove a user."""
settings = cast(AdminServerSettings, ctx.obj)
engine = create_engine(settings.admin_db)
init_db(engine)
with Session(engine) as session:
user = session.scalars(select(User).where(User.username == username)).first()
if not user:
@ -149,7 +146,6 @@ def cli_repl(ctx: click.Context) -> None:
"""Run an interactive console."""
settings = cast(AdminServerSettings, ctx.obj)
engine = create_engine(settings.admin_db)
init_db(engine)
with Session(engine) as session:
password_db = session.scalars(
select(PasswordDB).where(PasswordDB.id == 1)
@ -165,7 +161,7 @@ def cli_repl(ctx: click.Context) -> None:
loop = asyncio.get_event_loop()
return loop.run_until_complete(func)
admin = AdminBackend(settings, password_db.encrypted_password)
admin = AdminBackend(settings, )
locals = {
"run": run,
"admin": admin,

View File

@ -1,12 +1,13 @@
"""Database setup."""
import sqlite3
from contextlib import asynccontextmanager
from collections.abc import AsyncIterator, Generator, Callable
from sqlalchemy.orm import Session
from sqlalchemy.engine import URL
from sqlalchemy import create_engine, Engine
from sqlalchemy import create_engine, Engine, event
from sqlalchemy.ext.asyncio import (
AsyncConnection,
@ -18,11 +19,20 @@ from sqlalchemy.ext.asyncio import (
def setup_database(
db_url: URL | str,
db_url: URL,
) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]:
"""Setup database."""
engine = create_engine(db_url, echo=True, future=True)
if db_url.drivername.startswith("sqlite"):
@event.listens_for(engine, "connect")
def set_sqlite_pragma(
dbapi_connection: sqlite3.Connection, _connection_record: object
) -> None:
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
def get_db_session() -> Generator[Session, None, None]:
"""Get DB Session."""
@ -33,8 +43,18 @@ def setup_database(
class DatabaseSessionManager:
def __init__(self, host: URL | str, **engine_kwargs: str):
def __init__(self, host: URL, **engine_kwargs: str):
self._engine: AsyncEngine | None = create_async_engine(host, **engine_kwargs)
if host.drivername.startswith("sqlite+"):
@event.listens_for(self._engine.sync_engine, "connect")
def set_sqlite_pragma(
dbapi_connection: sqlite3.Connection, _connection_record: object
) -> None:
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
self._sessionmaker: async_sessionmaker[AsyncSession] | None = (
async_sessionmaker(
autocommit=False, bind=self._engine, expire_on_commit=False

View File

@ -4,6 +4,8 @@ from collections.abc import AsyncGenerator, Awaitable, Callable, Generator
from dataclasses import dataclass
from typing import Self
from fastapi import Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from sshecret_admin.auth import User
from sshecret_admin.services import AdminBackend
@ -11,8 +13,9 @@ from sshecret_admin.core.settings import AdminServerSettings
DBSessionDep = Callable[[], Generator[Session, None, None]]
AsyncSessionDep = Callable[[], AsyncGenerator[AsyncSession, None]]
AdminDep = Callable[[Session], AsyncGenerator[AdminBackend, None]]
AdminDep = Callable[[Request, Session], AsyncGenerator[AdminBackend, None]]
GetUserDep = Callable[[User], Awaitable[User]]
@ -23,6 +26,8 @@ class BaseDependencies:
settings: AdminServerSettings
get_db_session: DBSessionDep
get_async_session: AsyncSessionDep
@dataclass
@ -43,6 +48,7 @@ class AdminDependencies(BaseDependencies):
return cls(
settings=deps.settings,
get_db_session=deps.get_db_session,
get_async_session=deps.get_async_session,
get_admin_backend=get_admin_backend,
get_current_active_user=get_current_active_user,
)

View File

@ -30,7 +30,6 @@ class FrontendDependencies(BaseDependencies):
get_refresh_claims: RefreshTokenDep
get_login_status: LoginStatusDep
get_user_info: UserInfoDep
get_async_session: AsyncSessionDep
require_login: LoginGuardDep
@classmethod
@ -42,18 +41,17 @@ class FrontendDependencies(BaseDependencies):
get_refresh_claims: RefreshTokenDep,
get_login_status: LoginStatusDep,
get_user_info: UserInfoDep,
get_async_session: AsyncSessionDep,
require_login: LoginGuardDep,
) -> Self:
"""Create from base dependencies."""
return cls(
settings=deps.settings,
get_db_session=deps.get_db_session,
get_async_session=deps.get_async_session,
get_admin_backend=get_admin_backend,
templates=templates,
get_refresh_claims=get_refresh_claims,
get_login_status=get_login_status,
get_user_info=get_user_info,
get_async_session=get_async_session,
require_login=require_login,
)

View File

@ -24,7 +24,6 @@ from sshecret_admin.auth.constants import LOCAL_ISSUER
from sshecret_admin.core.dependencies import BaseDependencies
from sshecret_admin.services.admin_backend import AdminBackend
from sshecret_admin.core.db import DatabaseSessionManager
from .dependencies import FrontendDependencies
from .exceptions import RedirectException
@ -50,17 +49,24 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
templates = Jinja2Blocks(directory=template_path)
async def get_admin_backend(
request: Request,
session: Annotated[Session, Depends(dependencies.get_db_session)],
):
"""Get admin backend API."""
password_db = session.scalars(
select(PasswordDB).where(PasswordDB.id == 1)
).first()
username = get_optional_username(request)
origin = get_client_origin(request)
if not password_db:
raise HTTPException(
500, detail="Error: The password manager has not yet been set up."
)
admin = AdminBackend(dependencies.settings, password_db.encrypted_password)
admin = AdminBackend(
dependencies.settings,
username=username,
origin=origin,
)
yield admin
def get_identity_claims(request: Request) -> IdentityClaims:
@ -108,14 +114,9 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
next = URL("/refresh").include_query_params(next=request.url.path)
raise RedirectException(to=next)
async def get_async_session():
"""Get async session."""
sessionmanager = DatabaseSessionManager(dependencies.settings.async_db_url)
async with sessionmanager.session() as session:
yield session
async def get_user_info(
request: Request, session: Annotated[AsyncSession, Depends(get_async_session)]
request: Request,
session: Annotated[AsyncSession, Depends(dependencies.get_async_session)],
) -> LocalUserInfo:
"""Get User information."""
claims = get_identity_claims(request)
@ -142,6 +143,30 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
next = URL("/refresh").include_query_params(next=request.url.path)
raise RedirectException(to=next)
def get_optional_username(
request: Request,
) -> str | None:
"""Get username, if available.
This is purely used for auditing purposes.
"""
try:
claims = get_identity_claims(request)
except Exception:
return None
if claims.provider == LOCAL_ISSUER:
return claims.sub
return f"oidc:{claims.email}"
def get_client_origin(request: Request) -> str:
"""Get client origin."""
fallback_origin = "UNKNOWN"
if request.client:
return request.client.host
return fallback_origin
view_dependencies = FrontendDependencies.create(
dependencies,
get_admin_backend,
@ -149,7 +174,6 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
refresh_identity_claims,
get_login_status,
get_user_info,
get_async_session,
require_login,
)

View File

@ -16,7 +16,7 @@
<!-- Detail Pane -->
<section id="detail-pane"
class="flex-1 flex overflow-y-auto bg-white p-4 {%- if not mobile_show_details|default(false) -%} hidden{%- endif -%} lg:block dark:bg-gray-800">
class="flex-1 flex overflow-y-auto bg-white p-4 lg:block {% if not mobile_show_details|default(false) -%} hidden{%- endif -%} lg:block dark:bg-gray-800">
{% block detail %}

View File

@ -5,7 +5,7 @@ import ipaddress
import logging
import uuid
from typing import Annotated
from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request, Response
from fastapi import APIRouter, Depends, Form, HTTPException, Request, Response
from fastapi.responses import RedirectResponse
from pydantic import BaseModel, IPvAnyAddress, IPvAnyNetwork
from sshecret_admin.frontend.views.common import PagingInfo
@ -209,7 +209,7 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
page: int,
) -> Response:
"""Get more events for a client."""
if not "HX-Request" in request.headers:
if "HX-Request" not in request.headers:
return RedirectResponse(url=f"/clients/client/{id}")
client = await admin.get_client(("id", id))

View File

@ -4,8 +4,8 @@ Since we have a frontend and a REST API, it makes sense to have a generic librar
"""
import logging
from collections.abc import Iterator
from contextlib import contextmanager
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from sshecret.backend import (
AuditLog,
@ -20,7 +20,7 @@ from sshecret.backend.models import ClientQueryResult, DetailedSecrets
from sshecret.backend.api import AuditAPI, KeySpec
from sshecret.crypto import encrypt_string, load_public_key
from .keepass import PasswordContext, load_password_manager
from .secret_manager import AsyncSecretContext, password_manager_context
from sshecret_admin.core.settings import AdminServerSettings
from .models import (
ClientSecretGroup,
@ -86,19 +86,27 @@ def add_clients_to_secret_group(
class AdminBackend:
"""Admin backend API."""
def __init__(self, settings: AdminServerSettings, keepass_password: str) -> None:
def __init__(
self,
settings: AdminServerSettings,
username: str | None = None,
origin: str = "UNKNOWN",
) -> None:
"""Create client management API."""
self.settings: AdminServerSettings = settings
self.backend: SshecretBackend = SshecretBackend(
str(settings.backend_url), settings.backend_token
)
self.keepass_password: str = keepass_password
self.username: str = username or "UKNOWN_USER"
self.origin: str = origin
@contextmanager
def password_manager(self) -> Iterator[PasswordContext]:
"""Open the password manager."""
with load_password_manager(self.settings, self.keepass_password) as kp:
yield kp
@asynccontextmanager
async def secrets_manager(self) -> AsyncIterator[AsyncSecretContext]:
"""Open the secrets manager."""
async with password_manager_context(
self.settings, self.username, self.origin
) as manager:
yield manager
async def _get_clients(self, filter: ClientFilter | None = None) -> list[Client]:
"""Get clients from backend."""
@ -194,7 +202,7 @@ class AdminBackend:
self,
name: KeySpec,
new_key: str,
password_manager: PasswordContext,
password_manager: AsyncSecretContext,
) -> list[str]:
"""Update client public key."""
LOG.info(
@ -207,7 +215,7 @@ class AdminBackend:
updated_secrets: list[str] = []
for secret in client.secrets:
LOG.debug("Re-encrypting secret %s for client %s", secret, name)
secret_value = password_manager.get_secret(secret)
secret_value = await password_manager.get_secret(secret)
if not secret_value:
LOG.warning(
"Referenced secret %s does not exist! Skipping.", secret_value
@ -224,7 +232,7 @@ class AdminBackend:
async def update_client_public_key(self, name: KeySpec, new_key: str) -> list[str]:
"""Update client public key."""
try:
with self.password_manager() as password_manager:
async with self.secrets_manager() as password_manager:
return await self._update_client_public_key(
name, new_key, password_manager
)
@ -291,8 +299,8 @@ class AdminBackend:
This fetches the secret to client mapping from backend, and adds secrets from the password manager.
"""
backend_secrets = await self.backend.get_secrets()
with self.password_manager() as password_manager:
admin_secrets = password_manager.get_available_secrets()
async with self.secrets_manager() as password_manager:
admin_secrets = await password_manager.get_available_secrets()
secrets: dict[str, SecretListView] = {}
for secret in backend_secrets:
@ -324,8 +332,8 @@ class AdminBackend:
This fetches the secret to client mapping from backend, and adds secrets from the password manager.
"""
with self.password_manager() as password_manager:
all_secrets = password_manager.get_available_secrets()
async with self.secrets_manager() as password_manager:
all_secrets = await password_manager.get_available_secrets()
secrets = await self.backend.get_detailed_secrets()
backend_secret_names = [secret.name for secret in secrets]
@ -351,13 +359,13 @@ class AdminBackend:
parent_group: str | None = None,
) -> None:
"""Add secret group."""
with self.password_manager() as password_manager:
password_manager.add_group(group_name, description, parent_group)
async with self.secrets_manager() as password_manager:
await password_manager.add_group(group_name, description, parent_group)
async def set_secret_group(self, secret_name: str, group_name: str | None) -> None:
"""Assign a group to a secret."""
with self.password_manager() as password_manager:
password_manager.set_secret_group(secret_name, group_name)
async with self.secrets_manager() as password_manager:
await password_manager.set_secret_group(secret_name, group_name)
async def move_secret_group(
self, group_name: str, parent_group: str | None
@ -366,23 +374,21 @@ class AdminBackend:
If parent_group is None, it will be moved to the root.
"""
with self.password_manager() as password_manager:
password_manager.move_group(group_name, parent_group)
async with self.secrets_manager() as password_manager:
await password_manager.move_group(group_name, parent_group)
async def set_group_description(self, group_name: str, description: str) -> None:
"""Set a group description."""
with self.password_manager() as password_manager:
password_manager.set_group_description(group_name, description)
async with self.secrets_manager() as password_manager:
await password_manager.set_group_description(group_name, description)
async def delete_secret_group(
self, group_name: str, keep_entries: bool = True
) -> None:
async def delete_secret_group(self, group_name: str) -> None:
"""Delete a group.
If keep_entries is set to False, all entries in the group will be deleted.
"""
with self.password_manager() as password_manager:
password_manager.delete_group(group_name, keep_entries)
async with self.secrets_manager() as password_manager:
await password_manager.delete_group(group_name)
async def get_secret_groups(
self,
@ -399,18 +405,18 @@ class AdminBackend:
"""
all_secrets = await self.backend.get_detailed_secrets()
secrets_mapping = {secret.name: secret for secret in all_secrets}
with self.password_manager() as password_manager:
async with self.secrets_manager() as password_manager:
if flat:
all_groups = password_manager.get_secret_group_list(
all_groups = await password_manager.get_secret_group_list(
group_filter, regex=regex
)
else:
all_groups = password_manager.get_secret_groups(
all_groups = await password_manager.get_secret_groups(
group_filter, regex=regex
)
ungrouped = password_manager.get_ungrouped_secrets()
ungrouped = await password_manager.get_ungrouped_secrets()
all_admin_secrets = password_manager.get_available_secrets()
all_admin_secrets = await password_manager.get_available_secrets()
group_result: list[ClientSecretGroup] = []
for group in all_groups:
@ -452,8 +458,8 @@ class AdminBackend:
async def get_secret_group_by_path(self, path: str) -> ClientSecretGroup | None:
"""Get a group based on its path."""
with self.password_manager() as password_manager:
secret_group = password_manager.get_secret_group(path)
async with self.secrets_manager() as password_manager:
secret_group = await password_manager.get_secret_group(path)
if not secret_group:
return None
@ -476,9 +482,11 @@ class AdminBackend:
) -> SecretView | None:
"""Get a secret, including the actual unencrypted value and clients."""
secret: str | None = None
with self.password_manager() as password_manager:
secret = password_manager.get_secret(name)
secret_group = password_manager.get_entry_group(name)
async with self.secrets_manager() as password_manager:
secret = await password_manager.get_secret(name)
secret_group: str | None = None
if secret:
secret_group = await password_manager.get_entry_group(name)
secret_view = SecretView(name=name, secret=secret, group=secret_group)
@ -503,8 +511,8 @@ class AdminBackend:
async def _delete_secret(self, name: str) -> None:
"""Delete a secret."""
with self.password_manager() as password_manager:
password_manager.delete_entry(name)
async with self.secrets_manager() as password_manager:
await password_manager.delete_entry(name)
secret_mapping = await self.backend.get_secret(name)
if not secret_mapping:
@ -522,8 +530,8 @@ class AdminBackend:
group: str | None = None,
) -> None:
"""Add a secret."""
with self.password_manager() as password_manager:
password_manager.add_entry(name, value, update, group_name=group)
async with self.secrets_manager() as password_manager:
await password_manager.add_entry(name, value, update, group_path=group)
if update:
secret_map = await self.backend.get_secret(name)
@ -576,8 +584,8 @@ class AdminBackend:
if not client:
raise ClientNotFoundError(client_idname)
with self.password_manager() as password_manager:
secret = password_manager.get_secret(secret_name)
async with self.secrets_manager() as password_manager:
secret = await password_manager.get_secret(secret_name)
if not secret:
raise SecretNotFoundError()

View File

@ -1,348 +0,0 @@
"""Keepass password manager."""
import logging
from collections.abc import Iterator
from contextlib import contextmanager
from pathlib import Path
from typing import cast
import pykeepass
import pykeepass.exceptions
from sshecret_admin.core.settings import AdminServerSettings
from .models import SecretGroup
from .master_password import decrypt_master_password
LOG = logging.getLogger(__name__)
NO_USERNAME = "NO_USERNAME"
DEFAULT_LOCATION = "keepass.kdbx"
class PasswordCredentialsError(Exception):
pass
def create_password_db(location: Path, password: str) -> None:
"""Create the password database."""
LOG.info("Creating password database at %s", location)
pykeepass.create_database(str(location.absolute()), password=password)
def _kp_group_to_secret_group(
kp_group: pykeepass.group.Group,
parent: SecretGroup | None = None,
depth: int | None = None,
) -> SecretGroup:
"""Convert keepass group to secret group dataclass."""
group_name = cast(str, kp_group.name)
path = "/".join(cast(list[str], kp_group.path))
group = SecretGroup(name=group_name, path=path, description=kp_group.notes)
for entry in kp_group.entries:
group.entries.append(str(entry.title))
if parent:
group.parent_group = parent
current_depth = len(kp_group.path)
if not parent and current_depth > 1:
parent = _kp_group_to_secret_group(kp_group.parentgroup, depth=current_depth)
parent.children.append(group)
group.parent_group = parent
if depth and depth == current_depth:
return group
for subgroup in kp_group.subgroups:
group.children.append(_kp_group_to_secret_group(subgroup, group, depth=depth))
return group
class PasswordContext:
"""Password Context class."""
def __init__(self, keepass: pykeepass.PyKeePass) -> None:
"""Initialize password context."""
self.keepass: pykeepass.PyKeePass = keepass
@property
def _root_group(self) -> pykeepass.group.Group:
"""Return the root group."""
return cast(pykeepass.group.Group, self.keepass.root_group)
def _get_entry(self, name: str) -> pykeepass.entry.Entry | None:
"""Get entry."""
entry = cast(
"pykeepass.entry.Entry | None",
self.keepass.find_entries(title=name, first=True),
)
return entry
def _get_group(self, name: str) -> pykeepass.group.Group | None:
"""Find a group."""
group = cast(
pykeepass.group.Group | None,
self.keepass.find_groups(name=name, first=True),
)
return group
def add_entry(
self,
entry_name: str,
secret: str,
overwrite: bool = False,
group_name: str | None = None,
) -> None:
"""Add an entry.
Specify overwrite=True to overwrite the existing secret value, if it exists.
This will not move the entry, if the group_name is different from the original group.
"""
entry = self._get_entry(entry_name)
if entry and overwrite:
entry.password = secret
self.keepass.save()
return
if entry:
raise ValueError("Error: A secret with this name already exists.")
LOG.debug("Add secret entry to keepass: %s, group: %r", entry_name, group_name)
if group_name:
destination_group = self._get_group(group_name)
else:
destination_group = self._root_group
entry = self.keepass.add_entry(
destination_group=destination_group,
title=entry_name,
username=NO_USERNAME,
password=secret,
)
self.keepass.save()
def get_secret(self, entry_name: str) -> str | None:
"""Get the secret value."""
entry = self._get_entry(entry_name)
if not entry:
return None
LOG.warning("Secret name %s accessed", entry_name)
if password := cast(str, entry.password):
return str(password)
raise RuntimeError(f"Cannot get password for entry {entry_name}")
def get_entry_group(self, entry_name: str) -> str | None:
"""Get the group for an entry."""
entry = self._get_entry(entry_name)
if not entry:
return None
if entry.group.is_root_group:
return None
return str(entry.group.name)
def get_secret_groups(
self, pattern: str | None = None, regex: bool = True
) -> list[SecretGroup]:
"""Get secret groups.
A regex pattern may be provided to filter groups.
"""
if pattern:
groups = cast(
list[pykeepass.group.Group],
self.keepass.find_groups(name=pattern, regex=regex),
)
else:
groups = self._root_group.subgroups
secret_groups = [_kp_group_to_secret_group(group) for group in groups]
return secret_groups
def get_secret_group_list(
self, pattern: str | None = None, regex: bool = True
) -> list[SecretGroup]:
"""Get a flat list of groups."""
if pattern:
return self.get_secret_groups(pattern, regex)
groups = [group for group in self.keepass.groups if not group.is_root_group]
secret_groups = [_kp_group_to_secret_group(group) for group in groups]
return secret_groups
def get_secret_group(self, path: str) -> SecretGroup | None:
"""Get a secret group by path."""
elements = path.split("/")
final_element = elements[-1]
current = self._root_group
while elements:
groupname = elements.pop(0)
matches = [
subgroup for subgroup in current.subgroups if subgroup.name == groupname
]
if matches:
current = matches[0]
else:
return None
if not current.is_root_group and current.name == final_element:
return _kp_group_to_secret_group(current)
return None
def get_ungrouped_secrets(self) -> list[str]:
"""Get secrets without groups."""
entries: list[str] = []
for entry in self._root_group.entries:
entries.append(str(entry.title))
return entries
def add_group(
self, name: str, description: str | None = None, parent_group: str | None = None
) -> None:
"""Add a group."""
kp_parent_group = self._root_group
if parent_group:
query = cast(
pykeepass.group.Group | None,
self.keepass.find_groups(name=parent_group, first=True),
)
if not query:
raise ValueError(
f"Error: Cannot find a parent group named {parent_group}"
)
kp_parent_group = query
self.keepass.add_group(
destination_group=kp_parent_group, group_name=name, notes=description
)
self.keepass.save()
def set_group_description(self, name: str, description: str) -> None:
"""Set the description of a group."""
group = self._get_group(name)
if not group:
raise ValueError(f"Error: No such group {name}")
group.notes = description
self.keepass.save()
def set_secret_group(self, entry_name: str, group_name: str | None) -> None:
"""Move a secret to a group.
If group is None, the secret will be placed in the root group.
"""
entry = self._get_entry(entry_name)
if not entry:
raise ValueError(
f"Cannot find secret entry named {entry_name} in secrets database"
)
if group_name:
group = self._get_group(group_name)
if not group:
raise ValueError(f"Cannot find a group named {group_name}")
else:
group = self._root_group
self.keepass.move_entry(entry, group)
self.keepass.save()
def move_group(self, name: str, parent_group: str | None) -> None:
"""Move a group.
If parent_group is None, it will be moved to the root.
"""
group = self._get_group(name)
if not group:
raise ValueError(f"Error: No such group {name}")
if parent_group:
parent = self._get_group(parent_group)
if not parent:
raise ValueError(f"Error: No such group {parent_group}")
else:
parent = self._root_group
self.keepass.move_group(group, parent)
self.keepass.save()
def get_available_secrets(self, group_name: str | None = None) -> list[str]:
"""Get the names of all secrets in the database."""
if group_name:
group = self._get_group(group_name)
if not group:
raise ValueError(f"Error: No such group {group_name}")
entries = group.entries
else:
entries = cast(list[pykeepass.entry.Entry], self.keepass.entries)
if not entries:
return []
return [str(entry.title) for entry in entries]
def delete_entry(self, entry_name: str) -> None:
"""Delete entry."""
entry = cast(
"pykeepass.entry.Entry | None",
self.keepass.find_entries(title=entry_name, first=True),
)
if not entry:
return
entry.delete()
self.keepass.save()
def delete_group(self, name: str, keep_entries: bool = True) -> None:
"""Delete a group.
If keep_entries is set to False, all entries in the group will be deleted.
"""
group = self._get_group(name)
if not group:
return
if keep_entries:
for entry in cast(
list[pykeepass.entry.Entry],
self.keepass.find_entries(recursive=True, group=group),
):
# Move the entry to the root group.
LOG.warning(
"Moving orphaned secret entry %s to root group", entry.title
)
self.keepass.move_entry(entry, self._root_group)
self.keepass.delete_group(group)
self.keepass.save()
@contextmanager
def _password_context(location: Path, password: str) -> Iterator[PasswordContext]:
"""Open the password context."""
try:
database = pykeepass.PyKeePass(str(location.absolute()), password=password)
except pykeepass.exceptions.CredentialsError as e:
raise PasswordCredentialsError(
"Could not open password database. Invalid credentials."
) from e
context = PasswordContext(database)
yield context
@contextmanager
def load_password_manager(
settings: AdminServerSettings,
encrypted_password: str,
location: str = DEFAULT_LOCATION,
) -> Iterator[PasswordContext]:
"""Load password manager.
This function decrypts the password, and creates the password database if it
has not yet been created.
"""
db_location = Path(location)
password = decrypt_master_password(settings=settings, encrypted=encrypted_password)
if not db_location.exists():
create_password_db(db_location, password)
with _password_context(db_location, password) as context:
yield context

View File

@ -1,86 +0,0 @@
"""Functions related to handling the password database master password."""
import secrets
from pathlib import Path
from sshecret.crypto import (
create_private_rsa_key,
load_private_key,
encrypt_string,
decode_string,
)
from sshecret_admin.core.settings import AdminServerSettings
KEY_FILENAME = "sshecret-admin-key"
def setup_master_password(
settings: AdminServerSettings,
filename: str = KEY_FILENAME,
regenerate: bool = False,
) -> str | None:
"""Setup master password.
If regenerate is True, a new key will be generated.
This method should run just after setting up the database.
"""
keyfile = Path(filename)
if settings.password_manager_directory:
keyfile = settings.password_manager_directory / filename
created = _initial_key_setup(settings, keyfile, regenerate)
if not created:
return None
return _generate_master_password(settings, keyfile)
def decrypt_master_password(
settings: AdminServerSettings, encrypted: str, filename: str = KEY_FILENAME
) -> str:
"""Retrieve master password."""
keyfile = Path(filename)
if settings.password_manager_directory:
keyfile = settings.password_manager_directory / filename
if not keyfile.exists():
raise RuntimeError("Error: Private key has not been generated yet.")
private_key = load_private_key(
str(keyfile.absolute()), password=settings.secret_key
)
return decode_string(encrypted, private_key)
def _generate_password() -> str:
"""Generate a password."""
return secrets.token_urlsafe(32)
def _initial_key_setup(
settings: AdminServerSettings,
keyfile: Path,
regenerate: bool = False,
) -> bool:
"""Set up initial keys."""
if keyfile.exists() and not regenerate:
return False
assert settings.secret_key is not None, (
"Error: Could not load a secret key from environment."
)
create_private_rsa_key(keyfile, password=settings.secret_key)
return True
def _generate_master_password(settings: AdminServerSettings, keyfile: Path) -> str:
"""Generate master password for password database.
Returns the encrypted string, base64 encoded.
"""
if not keyfile.exists():
raise RuntimeError("Error: Private key has not been generated yet.")
private_key = load_private_key(
str(keyfile.absolute()), password=settings.secret_key
)
public_key = private_key.public_key()
master_password = _generate_password()
return encrypt_string(master_password, public_key)

View File

@ -0,0 +1,776 @@
"""Rewritten secret manager using a rsa keys."""
import logging
import os
import uuid
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass
from datetime import datetime, timezone
from functools import cached_property
from pathlib import Path
from cryptography.hazmat.primitives.asymmetric import rsa
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload, aliased
from sshecret.backend import SshecretBackend
from sshecret.backend.api import AuditAPI, KeySpec
from sshecret.backend.models import Client, ClientSecret, Operation, SubSystem
from sshecret.crypto import (
create_private_rsa_key,
decode_string,
encrypt_string,
generate_public_key_string,
load_private_key,
load_public_key,
)
from sshecret_admin.auth import PasswordDB
from sshecret_admin.auth.models import Group, ManagedSecret
from sshecret_admin.core.db import DatabaseSessionManager
from sshecret_admin.core.settings import AdminServerSettings
from sshecret_admin.services.models import SecretGroup
KEY_FILENAME = "sshecret-admin-key"
PASSWORD_MANAGER_ID = "SshecretAdminPasswordManager"
LOG = logging.getLogger(PASSWORD_MANAGER_ID)
class SecretManagerError(Exception):
"""Secret manager error."""
class InvalidGroupNameError(SecretManagerError):
"""Invalid group name."""
class InvalidSecretNameError(SecretManagerError):
"""Invalid secret name."""
@dataclass
class ClientAuditData:
"""Client audit data."""
username: str
origin: str
@dataclass
class ParsedPath:
"""Parsed path."""
item: str
full_path: str
parent: str | None = None
class SecretDataEntryExport(BaseModel):
"""Exportable secret entry."""
name: str
secret: str
group: str | None = None
class SecretDataGroupExport(BaseModel):
"""Exportable secret grouping."""
name: str
path: str
description: str | None = None
class SecretDataExport(BaseModel):
"""Exportable object containing secrets and groups."""
entries: list[SecretDataEntryExport]
groups: list[SecretDataGroupExport]
def split_path(path: str) -> list[str]:
"""Split a path into a list of groups."""
elements = path.split("/")
if path.startswith("/"):
elements = elements[1:]
return elements
def parse_path(path: str) -> ParsedPath:
"""Parse path."""
elements = split_path(path)
parsed = ParsedPath(elements[-1], path)
if len(elements) > 1:
parsed.parent = elements[-2]
return parsed
class AsyncSecretContext:
"""Async secret context."""
def __init__(
self,
private_key: rsa.RSAPrivateKey,
manager_client: Client,
session: AsyncSession,
backend: SshecretBackend,
audit_data: ClientAuditData,
) -> None:
"""Initialize secret manager"""
self._private_key: rsa.RSAPrivateKey = private_key
self._manager_client: Client = manager_client
self._id: KeySpec = ("id", str(manager_client.id))
self.backend: SshecretBackend = backend
self.session: AsyncSession = session
self.audit_data: ClientAuditData = audit_data
self.audit: AuditAPI = backend.audit(SubSystem.ADMIN)
self._import_has_run: bool = False
async def _create_missing_entries(self) -> None:
"""Create any missing entries."""
new_secrets: bool = False
to_check = set(self._manager_client.secrets)
for secret_name in to_check:
# entry = await self._get_entry(secret_name, include_deleted=True)
statement = select(ManagedSecret).where(ManagedSecret.name == secret_name)
result = await self.session.scalars(statement)
if not result.first():
new_secrets = True
managed_secret = ManagedSecret(name=secret_name)
self.session.add(managed_secret)
await self.session.flush()
await self.write_audit(
Operation.CREATE,
message="Imported managed secret from backend.",
secret_name=secret_name,
managed_secret=managed_secret,
)
if new_secrets:
await self.session.commit()
async def _get_group_depth(self, group: Group) -> int:
"""Get the depth of a group."""
depth = 1
if not group.parent_id:
return depth
current = group
while current.parent is not None:
if current.parent:
depth += 1
current = await self._get_group_by_id(current.parent.id)
else:
break
return depth
async def _get_group_path(self, group: Group) -> str:
"""Get the path of a group."""
if not group.parent_id:
return group.name
path: list[str] = []
current = group
while current.parent_id is not None:
path.append(current.name)
current = await self._get_group_by_id(current.parent_id)
path.append("")
path.reverse()
return "/".join(path)
async def _get_group_secrets(self, group: Group) -> list[ManagedSecret]:
"""Get secrets in a group."""
statement = (
select(ManagedSecret)
.where(ManagedSecret.group_id == group.id)
.where(ManagedSecret.is_deleted.is_not(True))
)
results = await self.session.scalars(statement)
return list(results.all())
async def _build_group_tree(
self, group: Group, parent: SecretGroup | None = None, depth: int | None = None
) -> SecretGroup:
"""Build a group tree."""
path = "/"
if parent:
path = os.path.join(parent.path, path)
secret_group = SecretGroup(
name=group.name, path=path, description=group.description
)
group_secrets = await self._get_group_secrets(group)
for secret in group_secrets:
secret_group.entries.append(secret.name)
if parent:
secret_group.parent_group = parent
current_depth = await self._get_group_depth(group)
if not parent and group.parent:
parent_group = await self._get_group_by_id(group.parent.id)
assert parent_group is not None
parent = await self._build_group_tree(parent_group, depth=current_depth)
parent.children.append(secret_group)
secret_group.parent_group = parent
if depth and depth == current_depth:
return secret_group
for subgroup in group.children:
child_group = await self._get_group_by_id(subgroup.id)
assert child_group is not None
secret_subgroup = await self._build_group_tree(
child_group, secret_group, depth=depth
)
secret_group.children.append(secret_subgroup)
return secret_group
async def write_audit(
self,
operation: Operation,
message: str,
group_name: str | None = None,
client_secret: ClientSecret | None = None,
secret_name: str | None = None,
managed_secret: ManagedSecret | None = None,
**data: str,
) -> None:
"""Write Audit message."""
if group_name:
data["group"] = group_name
data["username"] = self.audit_data.username
if client_secret and not secret_name:
secret_name = client_secret.name
if managed_secret:
data["managed_secret"] = str(managed_secret.id)
await self.audit.write_async(
operation=operation,
message=message,
origin=self.audit_data.origin,
client=self._manager_client,
secret=client_secret,
secret_name=secret_name,
**data,
)
@cached_property
def public_key(self) -> rsa.RSAPublicKey:
"""Get public key."""
keystring = self._manager_client.public_key
return load_public_key(keystring.encode())
async def _get_entry(
self, name: str, include_deleted: bool = False
) -> ManagedSecret | None:
"""Get managed secret."""
if not self._import_has_run:
await self._create_missing_entries()
self._import_has_run = True
statement = (
select(ManagedSecret)
.options(selectinload(ManagedSecret.group))
.where(ManagedSecret.name == name)
)
if not include_deleted:
statement = statement.where(ManagedSecret.is_deleted.is_not(True))
result = await self.session.scalars(statement)
return result.first()
async def add_entry(
self,
entry_name: str,
secret: str,
overwrite: bool = False,
group_path: str | None = None,
) -> None:
"""Add entry."""
existing_entry = await self._get_entry(entry_name)
if existing_entry and not overwrite:
raise InvalidSecretNameError(
"Another secret with this name is already defined."
)
encrypted = encrypt_string(secret, self.public_key)
client_secret = await self.backend.create_client_secret(
self._id, entry_name, encrypted
)
group_id: uuid.UUID | None = None
if group_path:
elements = parse_path(group_path)
group = await self._get_group(elements.item, elements.parent, True)
if not group:
raise InvalidGroupNameError("Invalid group name")
group_id = group.id
if existing_entry:
existing_entry.updated_at = datetime.now(timezone.utc)
if group_id:
existing_entry.group_id = group_id
self.session.add(existing_entry)
await self.session.commit()
await self.write_audit(
Operation.UPDATE,
"Updated secret value",
group_name=group_path,
client_secret=client_secret,
managed_secret=existing_entry,
)
else:
managed_secret = ManagedSecret(
name=entry_name,
group_id=group_id,
)
self.session.add(managed_secret)
await self.session.commit()
await self.write_audit(
Operation.CREATE,
"Created managed client secret",
group_path,
client_secret=client_secret,
managed_secret=managed_secret,
)
async def get_secret(self, entry_name: str) -> str | None:
"""Get secret."""
client_secret = await self.backend.get_client_secret(
self._id, ("name", entry_name)
)
if not client_secret:
return None
decrypted = decode_string(client_secret, self._private_key)
await self.write_audit(
Operation.READ,
"Secret was viewed from secret manager",
secret_name=entry_name,
)
return decrypted
async def get_available_secrets(self, group_path: str | None = None) -> list[str]:
"""Get the names of all secrets in the db."""
if not self._import_has_run:
await self._create_missing_entries()
if group_path:
elements = parse_path(group_path)
group = await self._get_group(elements.item, elements.parent)
if not group:
raise InvalidGroupNameError("Invalid or nonexisting group name.")
entries = group.secrets
else:
result = await self.session.scalars(
select(ManagedSecret)
.options(selectinload(ManagedSecret.group))
.where(ManagedSecret.is_deleted.is_not(True))
)
entries = list(result.all())
return [entry.name for entry in entries]
async def delete_entry(self, entry_name: str) -> None:
"""Delete a secret."""
entry = await self._get_entry(entry_name)
if not entry:
return
entry.is_deleted = True
entry.deleted_at = datetime.now(timezone.utc)
self.session.add(entry)
await self.session.commit()
await self.backend.delete_client_secret(
("id", str(self._manager_client.id)), ("name", entry_name)
)
await self.write_audit(
Operation.DELETE,
"Managed secret entry deleted",
secret_name=entry_name,
managed_secret=entry,
)
async def get_entry_group(self, entry_name: str) -> str | None:
"""Get group of entry."""
entry = await self._get_entry(entry_name)
if not entry:
raise InvalidSecretNameError("Invalid secret name or secret not found.")
if entry.group:
return entry.group.name
return None
async def _get_groups(
self, pattern: str | None = None, regex: bool = True, root_groups: bool = False
) -> list[Group]:
"""Get groups."""
statement = select(Group).options(
selectinload(Group.children), selectinload(Group.parent)
)
if pattern and regex:
statement = statement.where(Group.name.regexp_match(pattern))
elif pattern:
statement = statement.where(Group.name.contains(pattern))
if root_groups:
statement = statement.where(Group.parent_id == None)
results = await self.session.scalars(statement)
return list(results.all())
async def get_secret_groups(
self, pattern: str | None = None, regex: bool = True
) -> list[SecretGroup]:
"""Get secret groups, as a hierarcy."""
if pattern:
groups = await self._get_groups(pattern, regex)
else:
groups = await self._get_groups(root_groups=True)
secret_groups: list[SecretGroup] = []
for group in groups:
secret_group = await self._build_group_tree(group)
secret_groups.append(secret_group)
return secret_groups
async def get_secret_group_list(
self, pattern: str | None = None, regex: bool = True
) -> list[SecretGroup]:
"""Get secret group list."""
groups = await self._get_groups(pattern, regex)
return [(await self._build_group_tree(group)) for group in groups]
async def _get_group_by_id(self, id: uuid.UUID) -> Group:
"""Get group by ID."""
statement = (
select(Group)
.options(
selectinload(Group.parent),
selectinload(Group.children),
selectinload(Group.secrets),
)
.where(Group.id == id)
)
result = await self.session.scalars(statement)
return result.one()
async def _get_group(
self, name: str, parent: str | None = None, exact_match: bool = False
) -> Group | None:
"""Get a group."""
statement = (
select(Group)
.options(
selectinload(Group.parent),
selectinload(Group.children),
selectinload(Group.secrets),
)
.where(Group.name == name)
)
if parent:
ParentGroup = aliased(Group)
statement = statement.join(ParentGroup, Group.parent).where(
ParentGroup.name == parent
)
elif exact_match:
statement = statement.where(Group.parent_id == None)
result = await self.session.scalars(statement)
return result.first()
async def get_secret_group(self, path: str) -> SecretGroup | None:
"""Get a secret group by path."""
elements = parse_path(path)
group_name = elements.item
parent_group = elements.parent
group = await self._get_group(group_name, parent_group)
if not group:
return None
return await self._build_group_tree(group)
async def get_ungrouped_secrets(self) -> list[str]:
"""Get ungrouped secrets."""
statement = (
select(ManagedSecret)
.where(ManagedSecret.is_deleted.is_not(True))
.where(ManagedSecret.group_id == None)
)
result = await self.session.scalars(statement)
secrets = result.all()
return [secret.name for secret in secrets]
async def add_group(
self,
name_or_path: str,
description: str | None = None,
parent_group: str | None = None,
) -> None:
"""Add a group."""
parent_id: uuid.UUID | None = None
group_name = name_or_path
if parent_group and name_or_path.startswith("/"):
raise InvalidGroupNameError(
"Path as name cannot be used if parent is also specified."
)
if name_or_path.startswith("/"):
elements = parse_path(name_or_path)
group_name = elements.item
parent_group = elements.parent
if parent_group:
if parent := (await self._get_group(parent_group)):
child_names = [child.name for child in parent.children]
if group_name in child_names:
raise InvalidGroupNameError(
"Parent group already has a group with this name."
)
parent_id = parent.id
else:
raise InvalidGroupNameError(
"Invalid or non-existing parent group name."
)
else:
existing_group = await self._get_group(group_name)
if existing_group:
raise InvalidGroupNameError("A group with this name already exists.")
group = Group(
name=group_name,
description=description,
parent_id=parent_id,
)
self.session.add(group)
# We don't audit-log this operation.
await self.session.commit()
async def set_group_description(self, path: str, description: str) -> None:
"""Set group description."""
elements = parse_path(path)
group = await self._get_group(elements.item, elements.parent, True)
if not group:
raise InvalidGroupNameError("Invalid or non-existing group name.")
group.description = description
self.session.add(group)
await self.session.commit()
async def set_secret_group(self, entry_name: str, group_name: str | None) -> None:
"""Move a secret to a group.
If group_name is None, the secret will be moved out of any group it may exist in.
"""
entry = await self._get_entry(entry_name)
if not entry:
raise InvalidSecretNameError("Invalid or non-existing secret.")
if group_name:
elements = parse_path(group_name)
group = await self._get_group(elements.item, elements.parent, True)
if not group:
raise InvalidGroupNameError("Invalid or non-existing group name.")
entry.group_id = group.id
else:
entry.group_id = None
self.session.add(entry)
await self.session.commit()
await self.write_audit(
Operation.UPDATE,
"Secret group updated",
group_name=group_name or "ROOT",
secret_name=entry_name,
managed_secret=entry,
)
async def move_group(self, path: str, parent_group: str | None) -> None:
"""Move group.
If parent_group is None, it will be moved to the root.
"""
elements = parse_path(path)
group = await self._get_group(elements.item, elements.parent, True)
if not group:
raise InvalidGroupNameError("Invalid or non-existing group name.")
parent_group_id: uuid.UUID | None = None
if parent_group:
db_parent_group = await self._get_group(parent_group)
if not db_parent_group:
raise InvalidGroupNameError("Invalid or non-existing parent group.")
parent_group_id = db_parent_group.id
group.parent_id = parent_group_id
self.session.add(group)
await self.session.commit()
async def delete_group(self, path: str) -> None:
"""Delete a group."""
elements = parse_path(path)
group = await self._get_group(elements.item, elements.parent, True)
if not group:
return
await self.session.delete(group)
await self.session.commit()
# We don't audit-log this operation currently, even though it indirectly
# may affect secrets.
async def _export_entries(self) -> list[SecretDataEntryExport]:
"""Export entries as a pydantic object."""
statement = (
select(ManagedSecret)
.options(selectinload(ManagedSecret.group))
.where(ManagedSecret.is_deleted.is_(False))
)
results = await self.session.scalars(statement)
entries: list[SecretDataEntryExport] = []
for entry in results.all():
group: str | None = None
if entry.group:
group = await self._get_group_path(entry.group)
secret = await self.get_secret(entry.name)
if not secret:
continue
data = SecretDataEntryExport(name=entry.name, secret=secret, group=group)
entries.append(data)
return entries
async def _export_groups(self) -> list[SecretDataGroupExport]:
"""Export groups as pydantic objects."""
groups = await self.get_secret_group_list()
entries = [
SecretDataGroupExport(
name=group.name,
path=group.path,
description=group.description,
)
for group in groups
]
return entries
async def export_secrets(self) -> SecretDataExport:
"""Export the managed secrets as a pydantic object."""
entries = await self._export_entries()
groups = await self._export_groups()
return SecretDataExport(entries=entries, groups=groups)
async def export_secrets_json(self) -> str:
"""Export secrets as JSON."""
export = await self.export_secrets()
return export.model_dump_json(indent=2)
def get_managed_private_key(
settings: AdminServerSettings,
filename: str = KEY_FILENAME,
regenerate: bool = False,
) -> rsa.RSAPrivateKey:
"""Load our private key."""
keyfile = Path(filename)
if settings.password_manager_directory:
keyfile = settings.password_manager_directory / filename
if not keyfile.exists():
_initial_key_setup(settings, keyfile)
setup_password_manager(settings, keyfile, regenerate)
return load_private_key(str(keyfile.absolute()), password=settings.secret_key)
def setup_password_manager(
settings: AdminServerSettings, filename: Path, regenerate: bool = False
) -> bool:
"""Setup password manager."""
if filename.exists() and not regenerate:
return False
if not settings.secret_key:
raise RuntimeError("Error: Could not load secret key from environment.")
create_private_rsa_key(filename, password=settings.secret_key)
return True
async def create_manager_client(
backend: SshecretBackend, public_key: rsa.RSAPublicKey
) -> Client:
"""Create the manager client."""
public_key_string = generate_public_key_string(public_key)
new_client = await backend.create_system_client(
"AdminPasswordManager",
public_key_string,
)
return new_client
@asynccontextmanager
async def password_manager_context(
settings: AdminServerSettings, username: str, origin: str
) -> AsyncIterator[AsyncSecretContext]:
"""Start a context for the password manager."""
audit_context_data = ClientAuditData(username=username, origin=origin)
session_manager = DatabaseSessionManager(settings.async_db_url)
backend = SshecretBackend(str(settings.backend_url), settings.backend_token)
private_key = get_managed_private_key(settings)
async with session_manager.session() as session:
# Check if there is a client_id stored already.
query = select(PasswordDB).where(PasswordDB.id == 1)
result = await session.scalars(query)
password_db = result.first()
if not password_db:
password_db = PasswordDB(id=1)
session.add(password_db)
await session.flush()
if not password_db.client_id:
manager_client = await create_manager_client(
backend, private_key.public_key()
)
password_db.client_id = manager_client.id
session.add(password_db)
await session.commit()
else:
manager_client = await backend.get_client(
("id", str(password_db.client_id))
)
if not manager_client:
raise SecretManagerError("Error: Could not fetch system client.")
context = AsyncSecretContext(
private_key, manager_client, session, backend, audit_context_data
)
yield context
def setup_private_key(
settings: AdminServerSettings,
filename: str = KEY_FILENAME,
regenerate: bool = False,
) -> None:
"""Setup secret manager private key."""
keyfile = Path(filename)
if settings.password_manager_directory:
keyfile = settings.password_manager_directory / filename
_initial_key_setup(settings, keyfile, regenerate)
def _initial_key_setup(
settings: AdminServerSettings,
keyfile: Path,
regenerate: bool = False,
) -> bool:
"""Set up initial keys."""
if keyfile.exists() and not regenerate:
return False
assert (
settings.secret_key is not None
), "Error: Could not load a secret key from environment."
create_private_rsa_key(keyfile, password=settings.secret_key)
return True

View File

@ -0,0 +1,33 @@
"""Implement db structures for internal password manager
Revision ID: 1657c5d25d2c
Revises: b4e135ff347a
Create Date: 2025-06-21 07:22:17.792528
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '1657c5d25d2c'
down_revision: Union[str, None] = 'b4e135ff347a'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('client', sa.Column('is_system', sa.Boolean(), nullable=False, default=False, server_default="0"))
op.add_column('client_secret', sa.Column('is_system', sa.Boolean(), nullable=False, default=False, server_default="0"))
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('client_secret', 'is_system')
op.drop_column('client', 'is_system')

View File

@ -0,0 +1,44 @@
"""Remove secret key from password database
Revision ID: 71f7272a6ee1
Revises: 1657c5d25d2c
Create Date: 2025-06-22 18:42:53.207334
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '71f7272a6ee1'
down_revision: Union[str, None] = '1657c5d25d2c'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('managed_secret')
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('managed_secret',
sa.Column('id', sa.CHAR(length=32), nullable=False),
sa.Column('name', sa.VARCHAR(), nullable=False),
sa.Column('description', sa.VARCHAR(), nullable=True),
sa.Column('secret', sa.VARCHAR(), nullable=False),
sa.Column('client_id', sa.CHAR(length=32), nullable=True),
sa.Column('deleted', sa.BOOLEAN(), nullable=False),
sa.Column('created_at', sa.DATETIME(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
sa.Column('updated_at', sa.DATETIME(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
sa.Column('deleted_at', sa.DATETIME(), nullable=True),
sa.ForeignKeyConstraint(['client_id'], ['client.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
# ### end Alembic commands ###

View File

@ -107,8 +107,7 @@ class ClientOperations:
return ClientView.from_client(db_client)
async def create_client(
self,
create_model: ClientCreate,
self, create_model: ClientCreate, system_client: bool = False
) -> ClientView:
"""Create a new client."""
existing_id = await self.get_client_id(FlexID.name(create_model.name))
@ -117,6 +116,15 @@ class ClientOperations:
status_code=400, detail="Error: A client already exists with this name."
)
client = create_model.to_client()
if system_client:
statement = query_active_clients().where(Client.is_system.is_(True))
results = await self.session.scalars(statement)
other_system_clients = results.all()
if other_system_clients:
raise HTTPException(
status_code=400, detail="Only one system client may exist"
)
client.is_system = True
self.session.add(client)
await self.session.flush()
await self.session.commit()
@ -246,6 +254,15 @@ class ClientOperations:
return ClientPolicyView.from_client(db_client)
async def get_system_client(self) -> ClientView:
"""Get the system client, if it exists."""
statement = query_active_clients().where(Client.is_system.is_(True))
result = await self.session.scalars(statement)
client = result.first()
if not client:
raise HTTPException(status_code=404, detail="No system client registered")
return ClientView.from_client(client)
def resolve_order(statement: Select[Any], order_by: str, reversed: bool) -> Select[Any]:
"""Resolve ordering."""
@ -267,6 +284,7 @@ def resolve_order(statement: Select[Any], order_by: str, reversed: bool) -> Sele
LOG.warning("Unsupported order field: %s", order_by)
return statement
def filter_client_statement(
statement: Select[Any], params: ClientListParams, ignore_limits: bool = False
) -> Select[Any]:
@ -299,6 +317,7 @@ async def get_clients(
.select_from(Client)
.where(Client.is_deleted.is_not(True))
.where(Client.is_active.is_not(False))
.where(Client.is_system.is_not(True))
)
count_statement = cast(
Select[tuple[int]],
@ -307,7 +326,8 @@ async def get_clients(
total_results = (await session.scalars(count_statement)).one()
statement = filter_client_statement(query_active_clients(), filter_query, False)
statement = query_active_clients().where(Client.is_system.is_not(True))
statement = filter_client_statement(statement, filter_query, False)
results = await session.scalars(statement)
remainder = total_results - filter_query.offset - filter_query.limit

View File

@ -46,6 +46,25 @@ def create_client_router(get_db_session: AsyncDBSessionDep) -> APIRouter:
client_op = ClientOperations(session, request)
return await client_op.create_client(client)
@router.get("/internal/system_client/", include_in_schema=False)
async def get_system_client(
request: Request,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView:
"""Get the system client."""
client_op = ClientOperations(session, request)
return await client_op.get_system_client()
@router.post("/internal/system_client/", include_in_schema=False)
async def create_system_client(
request: Request,
client: ClientCreate,
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ClientView:
"""Create system client."""
client_op = ClientOperations(session, request)
return await client_op.create_client(client, system_client=True)
@router.get("/clients/{client_identifier}")
async def fetch_client_by_name(
request: Request,

View File

@ -242,7 +242,7 @@ async def resolve_client_secret_clients(
# Ensure we don't create the object before we have at least one client.
clients = ClientSecretDetailList(name=name)
clients.ids.append(str(client_secret.id))
if client_secret.client:
if client_secret.client and not client_secret.client.is_system:
clients.clients.append(
ClientReference(
id=str(client_secret.client.id), name=client_secret.client.name

View File

@ -110,7 +110,6 @@ def get_async_engine(url: URL, echo: bool = False, **engine_kwargs: str) -> Asyn
"""Get an async engine."""
engine = create_async_engine(url, echo=echo, **engine_kwargs)
if url.drivername.startswith("sqlite+"):
@event.listens_for(engine.sync_engine, "connect")
def set_sqlite_pragma(
dbapi_connection: sqlite3.Connection, _connection_record: object

View File

@ -67,6 +67,7 @@ class Client(Base):
is_active: Mapped[bool] = mapped_column(sa.Boolean, default=True)
is_deleted: Mapped[bool] = mapped_column(sa.Boolean, default=False)
is_system: Mapped[bool] = mapped_column(sa.Boolean, default=False)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
@ -141,6 +142,8 @@ class ClientSecret(Base):
client: Mapped[Client] = relationship(back_populates="secrets")
deleted: Mapped[bool] = mapped_column(default=False)
is_system: Mapped[bool] = mapped_column(sa.Boolean, default=False)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
)

View File

@ -140,11 +140,23 @@ class ShellStoreSecret(CommandDispatcher):
secret=secret_name,
)
await self.store_managed_secret(secret_name, secret_data)
def encrypt_secret(self, value: str) -> str:
"""Encrypt a secret."""
public_key = load_public_key(self.client.public_key.encode())
return encrypt_string(value, public_key)
async def store_managed_secret(self, secret_name: str, secret_data: str) -> None:
"""Store managed secret."""
system_client = await self.backend.get_system_client()
if not system_client:
return
public_key = load_public_key(system_client.public_key.encode())
encrypted = encrypt_string(secret_data, public_key)
await self.backend.create_client_secret(("id", str(system_client.id)), secret_name, encrypted)
await self.audit(operation=Operation.CREATE, message="Managed secret entry created.", secret=secret_name)
async def get_secret_on_stdin(self) -> str | None:
"""Get secret from stdin."""
secret_data = ""

View File

@ -6,6 +6,7 @@ admin and sshd library do not need to implement the same
import logging
from typing import Any, Literal, Self, override
import uuid
import httpx
from pydantic import TypeAdapter
@ -325,6 +326,28 @@ class SshecretBackend(BaseBackend):
path = "/api/v1/clients/"
response = await self._post(path, json=data)
async def create_system_client(self, name: str, public_key: str) -> Client:
"""Create system client."""
if not validate_public_key(public_key):
raise BackendValidationError("Error: Invalid public key format.")
data = {
"name": name,
"public_key": public_key,
"description": "Internal system client",
}
path = "/api/v1/internal/system_client/"
response = await self._post(path, json=data)
return Client.model_validate(response.json())
async def get_system_client(self) -> Client | None:
"""Get the system client."""
path = "/api/v1/internal/system_client/"
response = await self._get(path)
if response.status_code == 404:
return None
return Client.model_validate(response.json())
async def get_clients(self, filter: ClientFilter | None = None) -> list[Client]:
"""Get all clients."""
clients: list[Client] = []
@ -375,7 +398,7 @@ class SshecretBackend(BaseBackend):
async def create_client_secret(
self, client_idname: KeySpec, secret_name: str, encrypted_secret: str
) -> None:
) -> ClientSecret:
"""Create a secret.
This will overwrite any existing secret with that name.
@ -383,6 +406,8 @@ class SshecretBackend(BaseBackend):
client_key = _key(client_idname)
path = f"api/v1/clients/{client_key}/secrets/{secret_name}"
response = await self._put(path, json={"value": encrypted_secret})
secret = ClientSecret.model_validate(response.json())
return secret
async def get_client_secret(
self, client_idname: KeySpec, secret_idname: KeySpec

View File

@ -8,14 +8,14 @@ from pathlib import Path
from sqlmodel import Session, create_engine
from sshecret.crypto import generate_private_key, write_private_key
from sshecret_admin.auth.authentication import hash_password
from sshecret_admin.auth.models import AuthProvider, User, init_db
from sshecret_admin.auth.models import AuthProvider, User, Base
from sshecret_admin.core.settings import AdminServerSettings
def create_test_admin_user(settings: AdminServerSettings, username: str, password: str) -> None:
"""Create a test admin user."""
hashed_password = hash_password(password)
engine = create_engine(settings.admin_db)
init_db(engine)
Base.metadata.create_all(engine)
with Session(engine) as session:
user = User(username=username, hashed_password=hashed_password, provider=AuthProvider.LOCAL, email="test@test.com")
session.add(user)

View File

@ -0,0 +1 @@

View File

@ -0,0 +1,67 @@
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
import httpx
from sshecret.backend import Client
from sshecret.crypto import generate_private_key, generate_public_key_string
from ..types import AdminServer
def make_test_key() -> str:
"""Generate a test key."""
private_key = generate_private_key()
return generate_public_key_string(private_key.public_key())
class BaseAdminTests:
"""Base admin test class."""
@asynccontextmanager
async def http_client(
self, admin_server: AdminServer, authenticate: bool = True
) -> AsyncIterator[httpx.AsyncClient]:
"""Run a client towards the admin rest api."""
admin_url, credentials = admin_server
username, password = credentials
headers: dict[str, str] | None = None
if authenticate:
async with httpx.AsyncClient(base_url=admin_url) as client:
response = await client.post(
"api/v1/token", data={"username": username, "password": password}
)
assert response.status_code == 200
data = response.json()
assert "access_token" in data
token = data["access_token"]
headers = {"Authorization": f"Bearer {token}"}
async with httpx.AsyncClient(base_url=admin_url, headers=headers) as client:
yield client
async def create_client(
self,
admin_server: AdminServer,
name: str,
public_key: str | None = None,
) -> Client:
"""Create a client."""
if not public_key:
public_key = make_test_key()
new_client = {
"name": name,
"public_key": public_key,
"sources": ["192.0.2.0/24"],
}
async with self.http_client(admin_server, True) as http_client:
response = await http_client.post("api/v1/clients/", json=new_client)
assert response.status_code == 200
data = response.json()
client = Client.model_validate(data)
return client

View File

@ -1,76 +1,12 @@
"""Tests of the admin interface."""
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
import allure
import pytest
import httpx
from allure_commons.types import Severity
from sshecret.backend import Client
from sshecret.crypto import generate_private_key, generate_public_key_string
from .types import AdminServer
def make_test_key() -> str:
"""Generate a test key."""
private_key = generate_private_key()
return generate_public_key_string(private_key.public_key())
class BaseAdminTests:
"""Base admin test class."""
@asynccontextmanager
async def http_client(
self, admin_server: AdminServer, authenticate: bool = True
) -> AsyncIterator[httpx.AsyncClient]:
"""Run a client towards the admin rest api."""
admin_url, credentials = admin_server
username, password = credentials
headers: dict[str, str] | None = None
if authenticate:
async with httpx.AsyncClient(base_url=admin_url) as client:
response = await client.post(
"api/v1/token", data={"username": username, "password": password}
)
assert response.status_code == 200
data = response.json()
assert "access_token" in data
token = data["access_token"]
headers = {"Authorization": f"Bearer {token}"}
async with httpx.AsyncClient(base_url=admin_url, headers=headers) as client:
yield client
async def create_client(
self,
admin_server: AdminServer,
name: str,
public_key: str | None = None,
) -> Client:
"""Create a client."""
if not public_key:
public_key = make_test_key()
new_client = {
"name": name,
"public_key": public_key,
"sources": ["192.0.2.0/24"],
}
async with self.http_client(admin_server, True) as http_client:
response = await http_client.post("api/v1/clients/", json=new_client)
assert response.status_code == 200
data = response.json()
client = Client.model_validate(data)
return client
from ..types import AdminServer
from .base import BaseAdminTests
@allure.title("Admin API")
@ -196,7 +132,9 @@ class TestAdminApiSecrets(BaseAdminTests):
assert "testclient" in data["clients"]
@allure.title("Test adding a secret with automatic value")
@allure.description("Test that we can add a secret where we let the system come up with the value of a given length.")
@allure.description(
"Test that we can add a secret where we let the system come up with the value of a given length."
)
@pytest.mark.asyncio
async def test_add_secret_auto(self, admin_server: AdminServer) -> None:
"""Test adding a secret with an auto-generated value."""

View File

@ -0,0 +1,634 @@
"""Test secret manager.
This package tests the rewritten secret manager system.
This is technically an integration test, as it requires the other subsystems to
run, but it uses the internal API rather than the exposed routes.
"""
import allure
import pytest
import pytest_asyncio
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from sshecret_admin.core.settings import AdminServerSettings
from sshecret_admin.services.models import SecretGroup
from sshecret_admin.services.secret_manager import (
password_manager_context,
AsyncSecretContext,
InvalidSecretNameError,
InvalidGroupNameError,
)
from sshecret_admin.auth.models import Base, PasswordDB
from sshecret_admin.services.master_password import setup_master_password
# -------- global parameter sets start here -------- #
# -------- Fixtures start here -------- #
@pytest_asyncio.fixture(autouse=True)
async def create_admin_db(admin_server_settings: AdminServerSettings) -> None:
"""Create the database."""
engine = create_engine(admin_server_settings.admin_db)
Base.metadata.create_all(engine)
encr_master_password = setup_master_password(
settings=admin_server_settings, regenerate=True
)
with Session(engine) as session:
pwdb = PasswordDB(id=1, encrypted_password=encr_master_password)
session.add(pwdb)
session.commit()
@pytest_asyncio.fixture()
async def secrets_manager(admin_server_settings: AdminServerSettings):
"""Test that the context manager can be created."""
async with password_manager_context(
admin_server_settings, "TEST", "127.0.0.1"
) as manager:
yield manager
# -------- Tests start here -------- #
@allure.title("Adding entries")
@pytest.mark.parametrize("name,secret", [("testentry", "testsecret")])
class TestSecretsAddEntry:
"""Tests for the add_entry method."""
@pytest.mark.asyncio
async def test_add_entry(
self,
secrets_manager: AsyncSecretContext,
name: str,
secret: str,
) -> None:
"""Test add entry.
This tests add_entry and get_secret
"""
await secrets_manager.add_entry(name, secret)
stored_secret = await secrets_manager.get_secret(name)
assert stored_secret == secret
async def test_add_entry_duplicate(
self,
secrets_manager: AsyncSecretContext,
name: str,
secret: str,
) -> None:
"""Test adding an entry twice."""
await secrets_manager.add_entry(name, secret)
stored_secret = await secrets_manager.get_secret(name)
assert stored_secret == secret
with pytest.raises(InvalidSecretNameError):
await secrets_manager.add_entry(name, secret)
async def test_add_entry_with_group(
self,
secrets_manager: AsyncSecretContext,
name: str,
secret: str,
) -> None:
"""Test adding a secret with a group."""
group = "testgroup"
await secrets_manager.add_group(group)
await secrets_manager.add_entry(name, secret, group_path=group)
result = await secrets_manager.get_entry_group(name)
assert result == group
async def test_add_entry_with_nonexisting_group(
self,
secrets_manager: AsyncSecretContext,
name: str,
secret: str,
) -> None:
"""Test adding a secret where the group does not exist."""
group = "testgroup"
with pytest.raises(InvalidGroupNameError):
await secrets_manager.add_entry(name, secret, group_path=group)
async def test_add_entry_with_deep_path(
self,
secrets_manager: AsyncSecretContext,
name: str,
secret: str,
) -> None:
"""Test adding a secret to a nested group with a path specification."""
await secrets_manager.add_group("root")
await secrets_manager.add_group("nested", parent_group="root")
await secrets_manager.add_entry(name, secret, group_path="/root/nested")
group = await secrets_manager.get_secret_group("/root/nested")
assert group is not None
assert name in group.entries
async def test_overwrite_secret(
self, secrets_manager: AsyncSecretContext, name: str, secret: str
) -> None:
"""Test overwriting a secret."""
await secrets_manager.add_entry(name, secret)
stored_secret = await secrets_manager.get_secret(name)
assert stored_secret == secret
new_secret = "newsecret"
await secrets_manager.add_entry(name, new_secret, overwrite=True)
stored_secret = await secrets_manager.get_secret(name)
assert stored_secret == new_secret
@allure.title("Creating groups")
class TestSecretGroupCreation:
"""Test secret groups."""
@pytest.mark.parametrize(
"group_name", ["testgroup", "long group name with spaces", "blåbærgrød"]
)
@allure.title("Add a group name {group_name}")
@pytest.mark.asyncio
async def test_add_group(
self,
secrets_manager: AsyncSecretContext,
group_name: str,
) -> None:
"""Get adding a group."""
await secrets_manager.add_group(group_name)
groups = await secrets_manager.get_secret_groups()
assert len(groups) == 1
assert groups[0].name == group_name
@pytest.mark.asyncio
async def test_add_group_with_parent(
self, secrets_manager: AsyncSecretContext
) -> None:
"""Test add a group with a parent group."""
parent_name = "parent"
child_name = "child"
await secrets_manager.add_group(parent_name)
await secrets_manager.add_group(child_name, parent_group=parent_name)
parent_group = await secrets_manager.get_secret_group(f"/parent")
assert parent_group is not None
assert len(parent_group.children) == 1
assert parent_group.children[0].name == child_name
child_group = await secrets_manager.get_secret_group("/parent/child")
assert child_group is not None
assert child_group.name == child_name
assert child_group.parent_group is not None
assert child_group.parent_group.name == parent_name
assert len(child_group.children) == 0
@pytest.mark.asyncio
async def test_add_group_as_path(self, secrets_manager: AsyncSecretContext) -> None:
"""Add a nested group with path annotation."""
parent_name = "parent"
child_path = "/parent/child"
await secrets_manager.add_group(parent_name)
await secrets_manager.add_group(child_path)
parent_group = await secrets_manager.get_secret_group(f"/parent")
assert parent_group is not None
assert len(parent_group.children) == 1
assert parent_group.children[0].name == "child"
child_group = await secrets_manager.get_secret_group(child_path)
assert child_group is not None
@pytest.mark.asyncio
async def test_overlapping_names(self, secrets_manager: AsyncSecretContext) -> None:
"""Test having overlapping names in different groups."""
await secrets_manager.add_group("root")
with pytest.raises(InvalidGroupNameError):
await secrets_manager.add_group("/root")
await secrets_manager.add_group("/root/root")
group = await secrets_manager.get_secret_group("/root/root")
assert group is not None
assert group.name == "root"
@pytest.mark.asyncio
async def test_add_group_with_nonexisting_parent(
self, secrets_manager: AsyncSecretContext
) -> None:
"""Test adding a group with a nonexisting parent."""
with pytest.raises(InvalidGroupNameError):
await secrets_manager.add_group("orphan", parent_group="unknown")
@pytest.mark.asyncio
async def test_add_duplicate_group(
self, secrets_manager: AsyncSecretContext
) -> None:
"""Test adding the same group twice."""
await secrets_manager.add_group("snowflake")
with pytest.raises(InvalidGroupNameError):
await secrets_manager.add_group("snowflake")
@pytest.mark.parametrize(
"group_name,description", [("testgroup", "test description")]
)
@allure.title("Add a group name {group_name} with description {description}")
@pytest.mark.asyncio
async def test_add_group_with_description(
self,
secrets_manager: AsyncSecretContext,
group_name: str,
description: str,
) -> None:
"""Test adding a group with description."""
await secrets_manager.add_group(group_name, description)
result = await secrets_manager.get_secret_group(group_name)
assert result is not None
assert result.description == description
@pytest.mark.parametrize(
"groups",
[
[
("root", None, "root"),
("level1", "root", "/root/level1"),
("level2", "level1", "/root/level1/level2"),
],
[("flat1", None, "flat1"), ("flat2", None, "flat2")],
[
("stub", None, "stub"),
("root", None, "root"),
("nested", "root", "/root/nested"),
],
],
)
@allure.title("Listing groups")
class TestSecretGroupListing:
"""Tests for listing groups."""
@pytest_asyncio.fixture(autouse=True)
async def create_groups(
self,
secrets_manager: AsyncSecretContext,
groups: list[tuple[str, str | None, str]],
) -> None:
"""Pre-create groups."""
for name, parent_group, _path in groups:
await secrets_manager.add_group(name, parent_group=parent_group)
@pytest.mark.asyncio
async def test_get_secret_groups_list(
self,
secrets_manager: AsyncSecretContext,
groups: list[tuple[str, str | None, str]],
) -> None:
"""Test the flat get_secret_groups_list."""
# Create three levels of content
group_list = await secrets_manager.get_secret_group_list()
assert len(group_list) == len(groups)
@pytest.mark.asyncio
async def test_get_secret_groups(
self,
secrets_manager: AsyncSecretContext,
groups: list[tuple[str, str | None, str]],
) -> None:
"""Test the tree-oriented get_secretsgroups."""
group_map = dict([(group[0], group[1]) for group in sorted(groups)])
group_tree = await secrets_manager.get_secret_groups()
root_groups = [key for key, value in group_map.items() if value is None]
assert len(group_tree) == len(root_groups)
reconstructed_groups: list[tuple[str, str | None]] = []
def crawl_tree(item: SecretGroup) -> list[tuple[str, str | None]]:
"""Crawl a tree recursively."""
parent_group_name = None
if item.parent_group:
parent_group_name = item.parent_group.name
items: list[tuple[str, str | None]] = [(item.name, parent_group_name)]
for child in item.children:
items.extend(crawl_tree(child))
return items
for item in group_tree:
reconstructed_groups.extend(crawl_tree(item))
assert dict(sorted(reconstructed_groups)) == group_map
@pytest.mark.asyncio
async def test_get_secret_groups_with_secrets(
self,
secrets_manager: AsyncSecretContext,
groups: list[tuple[str, str | None, str]],
) -> None:
"""Test fetching groups where there are secrets in all groups."""
# We will create exactly two secrets in each group.
for group_name, _parent, path in groups:
await secrets_manager.add_entry(
f"{group_name}_1", f"{group_name}_secret_1", group_path=path
)
await secrets_manager.add_entry(
f"{group_name}_2", f"{group_name}_secret_2", group_path=path
)
group_list = await secrets_manager.get_secret_group_list()
for group in group_list:
assert len(group.entries) == 2
assert group.entries[0] == f"{group.name}_1"
assert group.entries[1] == f"{group.name}_2"
@pytest.mark.parametrize(
"query,expected,groups,children",
[
("MATCH", 1, [("MATCH", None), ("SOMETHINGELSE", None)], 0),
("MATCH", 1, [("root", None), ("MATCH", "root"), ("SOMETHINGELSE", "root")], 0),
("MATCH", 3, [("MATCH1", None), ("MATCH2", None), ("MATCH3", None)], 0),
("MATCH", 1, [("root", None), ("MATCH", "root"), ("CHILD", "MATCH")], 1),
(
"NOMATCH",
0,
[("foo", None), ("bar", None), ("foobar", "foo"), ("barfoo", "bar")],
0,
),
],
)
@allure.title("Searching in groups using patterns")
class TestGroupSearchPattern:
@pytest.mark.asyncio
async def test_group_list_pattern(
self,
secrets_manager: AsyncSecretContext,
query: str,
expected: int,
groups: list[tuple[str, str | None]],
children: int,
) -> None:
"""Test matching a pattern."""
for name, parent_group in groups:
await secrets_manager.add_group(name, parent_group=parent_group)
result = await secrets_manager.get_secret_group_list(pattern=query, regex=False)
assert len(result) == expected
@pytest.mark.asyncio
async def test_group_tree_pattern(
self,
secrets_manager: AsyncSecretContext,
query: str,
expected: int,
groups: list[tuple[str, str | None]],
children: int,
) -> None:
"""Test matching a pattern with a tree result."""
for name, parent_group in groups:
await secrets_manager.add_group(name, parent_group=parent_group)
result = await secrets_manager.get_secret_groups(pattern=query, regex=False)
assert len(result) == expected
if expected == 1 and children > 0:
assert len(result[0].children) == children
@allure.title("Modifying groups")
class TestGroupModification:
"""Test modifying groups."""
@pytest.mark.parametrize(
"group_name,parent,description",
[("test", None, "test description"), ("test", "root", "test_description")],
)
@pytest.mark.asyncio
async def test_set_group_description(
self,
secrets_manager: AsyncSecretContext,
group_name: str,
parent: str | None,
description: str,
) -> None:
"""Test setting a description on a group."""
if parent:
await secrets_manager.add_group(parent)
await secrets_manager.add_group(group_name, parent_group=parent)
path = group_name
if parent:
path = f"/{parent}/{group_name}"
group = await secrets_manager.get_secret_group(path)
assert group is not None
assert group.description is None
await secrets_manager.set_group_description(path, description)
group = await secrets_manager.get_secret_group(path)
assert group is not None
assert group.description == description
@pytest.mark.parametrize(
"groups,target_group, expected_path",
[
(
[("root", None), ("test", None)],
("test", "root"),
"/root/test",
),
(
[("root", None), ("test", "root")],
("/root/test", None),
"test",
),
([("test", None)], ("test", None), "test"),
],
)
@pytest.mark.asyncio
async def test_move_group(
self,
secrets_manager: AsyncSecretContext,
groups: list[tuple[str, str | None]],
target_group: tuple[str, str | None],
expected_path: str,
) -> None:
"""Test moving groups around."""
for group_name, parent_name in groups:
await secrets_manager.add_group(group_name, parent_group=parent_name)
group_name, target = target_group
await secrets_manager.move_group(group_name, target)
group = await secrets_manager.get_secret_group(expected_path)
assert group is not None
@allure.title("Deleting items")
class TestSecretManagerDeletions:
"""Test secret manager deletions."""
@pytest_asyncio.fixture(autouse=True)
async def create_test_data(self, secrets_manager: AsyncSecretContext) -> None:
"""Create some test data."""
groups = [
("root", None, "root"),
("level1", "root", "/root/level1"),
("level2", "level1", "/root/level1/level2"),
]
for n in range(2):
await secrets_manager.add_entry(f"ungrouped_{n}", "secret")
for group_name, parent_name, path in groups:
await secrets_manager.add_group(group_name, parent_group=parent_name)
for n in range(2):
await secrets_manager.add_entry(
f"{group_name}_{n}", "secret", group_path=path
)
@pytest.mark.parametrize(
"name,group_name",
[("root_1", "root"), ("level1_0", "/root/level1"), ("ungrouped_1", None)],
)
@allure.title("Delete secret {name in group {group_name}}")
@pytest.mark.asyncio
async def test_secret_deletion(
self, secrets_manager: AsyncSecretContext, name: str, group_name: str | None
) -> None:
"""Test secret deletion."""
if group_name:
group = await secrets_manager.get_secret_group(group_name)
assert group is not None
assert name in group.entries
secret = await secrets_manager.get_secret(name)
assert secret is not None
await secrets_manager.delete_entry(name)
secret = await secrets_manager.get_secret(name)
assert secret is None
if group_name:
group = await secrets_manager.get_secret_group(group_name)
assert group is not None
assert name not in group.entries
@pytest.mark.parametrize("name", ["NONEXISTING"])
@allure.title("Deleting non-existing entry {name}")
@pytest.mark.asyncio
async def test_nonexisting_entry(
self, secrets_manager: AsyncSecretContext, name: str
) -> None:
"""Test deleting something that doesn't exist."""
secret = await secrets_manager.get_secret(name)
assert secret is None
# Deleting something that is already deleted returns None
await secrets_manager.delete_entry(name)
@pytest.mark.parametrize("path", ["/root/level1"])
@allure.title("Deleting group {path}")
async def test_group_delete(
self, secrets_manager: AsyncSecretContext, path: str
) -> None:
"""Test deleting a group."""
group = await secrets_manager.get_secret_group(path)
assert group is not None
entries = list(group.entries)
await secrets_manager.delete_group(path)
group = await secrets_manager.get_secret_group(path)
assert group is None
for name in entries:
new_grouping = await secrets_manager.get_entry_group(name)
assert new_grouping is None
@allure.title("Other tests")
class TestSecretManagerOther:
"""Uncategorized tests to standardize module."""
@pytest.mark.asyncio
async def test_get_secret_nonexisting(
self, secrets_manager: AsyncSecretContext
) -> None:
"""Test get_secret with invalid name."""
result = await secrets_manager.get_secret("NOMATCH")
assert result is None
@pytest.mark.parametrize(
"group_name,num_grouped,num_ungrouped",
[("GROUP", 3, 3), ("GROUP", 3, 0), ("GROUP", 0, 0)],
)
@pytest.mark.asyncio
async def test_get_ungrouped_secrets(
self,
secrets_manager: AsyncSecretContext,
group_name: str,
num_grouped: int,
num_ungrouped: int,
) -> None:
"""Test get_ungrouped_secrets."""
await secrets_manager.add_group(group_name)
for n in range(num_ungrouped):
await secrets_manager.add_entry(f"ungrouped_{n}", "secret")
for n in range(num_grouped):
await secrets_manager.add_entry(
f"grouped_{n}", "secret", group_path="GROUP"
)
ungrouped = await secrets_manager.get_ungrouped_secrets()
assert len(ungrouped) == num_ungrouped
matching = [entry for entry in ungrouped if entry.startswith("ungrouped_")]
assert len(matching) == num_ungrouped
grouped_secrets = await secrets_manager.get_available_secrets(
group_path=group_name
)
assert len(grouped_secrets) == num_grouped
all_secrets = await secrets_manager.get_available_secrets()
assert len(all_secrets) == (num_ungrouped + num_grouped)
@pytest.mark.parametrize(
"entries", [[("test1", "secret1"), ("test2", "secret2")], []]
)
@pytest.mark.asyncio
async def test_get_available_secrets(
self, secrets_manager: AsyncSecretContext, entries: list[tuple[str, str]]
) -> None:
"""Test the get_available_secrets method."""
for name, secret in entries:
await secrets_manager.add_entry(name, secret)
entry_names = [entry[0] for entry in entries]
response = await secrets_manager.get_available_secrets()
assert len(response) == len(entries)
assert sorted(response) == sorted(entry_names)
async def test_get_secret_groups_none(
self, secrets_manager: AsyncSecretContext
) -> None:
"""Test get_secret_groups with no groups created."""
result = await secrets_manager.get_secret_groups()
assert len(result) == 0
result_flat = await secrets_manager.get_secret_group_list()
assert len(result_flat) == 0
@allure.title("Search for a group using regular expression")
async def test_group_regex_search(
self, secrets_manager: AsyncSecretContext
) -> None:
"""Search for entries with regular expressions."""
groups = [
"test1",
"test2",
"other",
"somethingelse",
]
for group in groups:
await secrets_manager.add_group(group)
results = await secrets_manager.get_secret_group_list(
pattern="^test", regex=True
)
assert len(results) == 2
for group in results:
assert group.name.startswith("test")

View File

@ -79,6 +79,32 @@ async def run_backend_server(test_ports: TestPorts):
await server_task
@pytest_asyncio.fixture(
scope=TEST_SCOPE, name="admin_server_settings", loop_scope=LOOP_SCOPE
)
async def get_admin_server_settings(
test_ports: TestPorts, backend_server: tuple[str, str]
):
"""Get admin server settings."""
backend_url, backend_token = backend_server
port = test_ports.admin
secret_key = secrets.token_urlsafe(32)
with in_tempdir() as admin_work_path:
admin_db = admin_work_path / "ssh_admin.db"
admin_settings = AdminServerSettings.model_validate(
{
"sshecret_backend_url": backend_url,
"backend_token": backend_token,
"secret_key": secret_key,
"listen_address": "0.0.0.0",
"port": port,
"database": str(admin_db.absolute()),
"password_manager_directory": str(admin_work_path.absolute()),
}
)
yield admin_settings
@pytest_asyncio.fixture(scope=TEST_SCOPE, name="admin_server", loop_scope=LOOP_SCOPE)
async def run_admin_server(test_ports: TestPorts, backend_server: tuple[str, str]):
"""Run admin server."""
@ -98,7 +124,7 @@ async def run_admin_server(test_ports: TestPorts, backend_server: tuple[str, str
"password_manager_directory": str(admin_work_path.absolute()),
}
)
admin_app = create_admin_app(admin_settings)
admin_app = create_admin_app(admin_settings, create_db=True)
config = uvicorn.Config(app=admin_app, port=port, loop="asyncio")
server = uvicorn.Server(config=config)
server_task = asyncio.create_task(server.serve())