diff --git a/packages/sshecret-admin/migrations/env.py b/packages/sshecret-admin/migrations/env.py index 95597d0..53e2672 100644 --- a/packages/sshecret-admin/migrations/env.py +++ b/packages/sshecret-admin/migrations/env.py @@ -1,11 +1,11 @@ import os from logging.config import fileConfig -from sqlalchemy import engine_from_config -from sqlalchemy import pool +from sqlalchemy import Engine, engine_from_config, pool, create_engine from alembic import context from sshecret_admin.auth.models import Base +from sshecret_admin.core.settings import AdminServerSettings # this is the Alembic Config object, which provides # access to the values within the .ini file in use. @@ -14,9 +14,30 @@ config = context.config def get_database_url() -> str | None: """Get database URL.""" - if db_file := os.getenv("SSHECRET_ADMIN_DATABASE"): - return f"sqlite:///{db_file}" - return config.get_main_option("sqlalchemy.url") + try: + settings = AdminServerSettings() # pyright: ignore[reportCallIssue] + return str(settings.admin_db) + except Exception: + if db_file := os.getenv("SSHECRET_ADMIN_DATABASE"): + return f"sqlite:///{db_file}" + return config.get_main_option("sqlalchemy.url") + + +def get_engine() -> Engine: + """Get engine.""" + try: + settings = AdminServerSettings() # pyright: ignore[reportCallIssue] + engine = create_engine(settings.admin_db) + return engine + except Exception: + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + return connectable + + # Interpret the config file for Python logging. @@ -68,12 +89,7 @@ def run_migrations_online() -> None: and associate a connection with the context. """ - connectable = engine_from_config( - config.get_section(config.config_ini_section, {}), - prefix="sqlalchemy.", - poolclass=pool.NullPool, - ) - + connectable = get_engine() with connectable.connect() as connection: context.configure( connection=connection, target_metadata=target_metadata, render_as_batch=True diff --git a/packages/sshecret-admin/migrations/versions/84356d0ea85f_implement_db_structures_for_internal_.py b/packages/sshecret-admin/migrations/versions/84356d0ea85f_implement_db_structures_for_internal_.py new file mode 100644 index 0000000..6e4f272 --- /dev/null +++ b/packages/sshecret-admin/migrations/versions/84356d0ea85f_implement_db_structures_for_internal_.py @@ -0,0 +1,44 @@ +"""Implement db structures for internal password manager + +Revision ID: 84356d0ea85f +Revises: 6c148590471f +Create Date: 2025-06-21 07:21:02.257865 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '84356d0ea85f' +down_revision: Union[str, None] = '6c148590471f' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('groups', + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('name', sa.String(), nullable=False), + sa.Column('parent_id', sa.Uuid(), nullable=True), + sa.ForeignKeyConstraint(['parent_id'], ['groups.id'], ), + sa.PrimaryKeyConstraint('id') + ) + with op.batch_alter_table('password_db', schema=None) as batch_op: + batch_op.add_column(sa.Column('client_id', sa.Uuid(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('password_db', schema=None) as batch_op: + batch_op.drop_column('client_id') + + op.drop_table('groups') + # ### end Alembic commands ### diff --git a/packages/sshecret-admin/migrations/versions/c34707a1ea3a_implement_managed_secrets.py b/packages/sshecret-admin/migrations/versions/c34707a1ea3a_implement_managed_secrets.py new file mode 100644 index 0000000..80fb9ad --- /dev/null +++ b/packages/sshecret-admin/migrations/versions/c34707a1ea3a_implement_managed_secrets.py @@ -0,0 +1,48 @@ +"""Implement managed secrets + +Revision ID: c34707a1ea3a +Revises: 84356d0ea85f +Create Date: 2025-06-21 07:38:12.994535 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'c34707a1ea3a' +down_revision: Union[str, None] = '84356d0ea85f' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('managed_secrets', + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('name', sa.String(), nullable=False), + sa.Column('is_deleted', sa.Boolean(), nullable=False), + sa.Column('group_id', sa.Uuid(), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), + sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True), + sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ondelete='SET NULL'), + sa.PrimaryKeyConstraint('id') + ) + with op.batch_alter_table('groups', schema=None) as batch_op: + batch_op.add_column(sa.Column('description', sa.String(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('groups', schema=None) as batch_op: + batch_op.drop_column('description') + + op.drop_table('managed_secrets') + # ### end Alembic commands ### diff --git a/packages/sshecret-admin/src/sshecret_admin/api/endpoints/auth.py b/packages/sshecret-admin/src/sshecret_admin/api/endpoints/auth.py index e528c1e..c5b8aaa 100644 --- a/packages/sshecret-admin/src/sshecret_admin/api/endpoints/auth.py +++ b/packages/sshecret-admin/src/sshecret_admin/api/endpoints/auth.py @@ -5,9 +5,9 @@ import logging from typing import Annotated from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordRequestForm -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession -from sshecret_admin.auth import Token, authenticate_user, create_access_token +from sshecret_admin.auth import Token, authenticate_user_async, create_access_token from sshecret_admin.core.dependencies import AdminDependencies LOG = logging.getLogger(__name__) @@ -19,11 +19,12 @@ def create_router(dependencies: AdminDependencies) -> APIRouter: @app.post("/token") async def login_for_access_token( - session: Annotated[Session, Depends(dependencies.get_db_session)], + + session: Annotated[AsyncSession, Depends(dependencies.get_async_session)], form_data: Annotated[OAuth2PasswordRequestForm, Depends()], ) -> Token: """Login user and generate token.""" - user = authenticate_user(session, form_data.username, form_data.password) + user = await authenticate_user_async(session, form_data.username, form_data.password) if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, diff --git a/packages/sshecret-admin/src/sshecret_admin/api/endpoints/secrets.py b/packages/sshecret-admin/src/sshecret_admin/api/endpoints/secrets.py index cce0b62..c8afd42 100644 --- a/packages/sshecret-admin/src/sshecret_admin/api/endpoints/secrets.py +++ b/packages/sshecret-admin/src/sshecret_admin/api/endpoints/secrets.py @@ -128,7 +128,7 @@ def create_router(dependencies: AdminDependencies) -> APIRouter: group = await admin.get_secret_group(group_name) if not group: return - await admin.delete_secret_group(group_name, keep_entries=True) + await admin.delete_secret_group(group_name) @app.post("/secrets/groups/{group_name}/{secret_name}") async def move_secret_to_group( diff --git a/packages/sshecret-admin/src/sshecret_admin/api/router.py b/packages/sshecret-admin/src/sshecret_admin/api/router.py index 13a9dd4..89cf62d 100644 --- a/packages/sshecret-admin/src/sshecret_admin/api/router.py +++ b/packages/sshecret-admin/src/sshecret_admin/api/router.py @@ -5,8 +5,9 @@ import logging from typing import Annotated -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.security import OAuth2PasswordBearer +from fastapi.security.utils import get_authorization_scheme_param from sqlalchemy import select from sqlalchemy.orm import Session @@ -57,6 +58,31 @@ def create_router(dependencies: BaseDependencies) -> APIRouter: raise credentials_exception return user + def get_client_origin(request: Request) -> str: + """Get client origin.""" + fallback_origin = "UNKNOWN" + if request.client: + return request.client.host + return fallback_origin + + def get_optional_username(request: Request) -> str | None: + """Get username, if available. + + This is purely used for auditing purposes. + """ + authorization = request.headers.get("Authorization") + scheme, param = get_authorization_scheme_param(authorization) + if not authorization or scheme.lower() != "bearer": + return None + claims = decode_token(dependencies.settings, param) + if not claims: + return None + + if claims.provider == LOCAL_ISSUER: + return claims.sub + + return f"oidc:{claims.email}" + async def get_current_active_user( current_user: Annotated[User, Depends(get_current_user)], ) -> User: @@ -66,9 +92,12 @@ def create_router(dependencies: BaseDependencies) -> APIRouter: return current_user async def get_admin_backend( + request: Request, session: Annotated[Session, Depends(dependencies.get_db_session)], ): """Get admin backend API.""" + username = get_optional_username(request) + origin = get_client_origin(request) password_db = session.scalars( select(PasswordDB).where(PasswordDB.id == 1) ).first() @@ -76,7 +105,11 @@ def create_router(dependencies: BaseDependencies) -> APIRouter: raise HTTPException( 500, detail="Error: The password manager has not yet been set up." ) - admin = AdminBackend(dependencies.settings, password_db.encrypted_password) + admin = AdminBackend( + dependencies.settings, + username=username, + origin=origin, + ) yield admin app = APIRouter(prefix=f"/api/{API_VERSION}") diff --git a/packages/sshecret-admin/src/sshecret_admin/auth/models.py b/packages/sshecret-admin/src/sshecret_admin/auth/models.py index fb285f1..7c543f1 100644 --- a/packages/sshecret-admin/src/sshecret_admin/auth/models.py +++ b/packages/sshecret-admin/src/sshecret_admin/auth/models.py @@ -1,12 +1,13 @@ -"""Models for authentication.""" +"""Models for authentication and secret management.""" import enum from datetime import datetime +from typing import override import uuid import sqlalchemy as sa from pydantic import BaseModel -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship JWT_ALGORITHM = "HS256" @@ -75,12 +76,15 @@ class PasswordDB(Base): __tablename__: str = "password_db" id: Mapped[int] = mapped_column(sa.INT, primary_key=True) - encrypted_password: Mapped[str] = mapped_column(sa.String) created_at: Mapped[datetime] = mapped_column( sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False ) + client_id: Mapped[uuid.UUID | None] = mapped_column( + sa.Uuid(as_uuid=True), nullable=True + ) + updated_at: Mapped[datetime | None] = mapped_column( sa.DateTime(timezone=True), server_default=sa.func.now(), @@ -88,6 +92,65 @@ class PasswordDB(Base): ) +class Group(Base): + """A secret group.""" + + __tablename__: str = "groups" + + id: Mapped[uuid.UUID] = mapped_column( + sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + name: Mapped[str] = mapped_column(sa.String, nullable=False) + description: Mapped[str | None] = mapped_column(sa.String, nullable=True) + + parent_id: Mapped[uuid.UUID | None] = mapped_column( + sa.ForeignKey("groups.id"), nullable=True + ) + parent: Mapped["Group | None"] = relationship( + "Group", remote_side=[id], back_populates="children" + ) + children: Mapped[list["Group"]] = relationship( + "Group", back_populates="parent", cascade="all, delete" + ) + secrets: Mapped[list["ManagedSecret"]] = relationship(back_populates="group") + + @override + def __repr__(self) -> str: + return f"" + + +class ManagedSecret(Base): + """Managed Secret.""" + + __tablename__: str = "managed_secrets" + + id: Mapped[uuid.UUID] = mapped_column( + sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + name: Mapped[str] = mapped_column(sa.String, nullable=False) + + is_deleted: Mapped[bool] = mapped_column(sa.Boolean, default=False) + + group_id: Mapped[uuid.UUID | None] = mapped_column( + sa.ForeignKey("groups.id", ondelete="SET NULL"), nullable=True + ) + group: Mapped["Group | None"] = relationship( + Group, foreign_keys=[group_id], back_populates="secrets" + ) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False + ) + updated_at: Mapped[datetime | None] = mapped_column( + sa.DateTime(timezone=True), + server_default=sa.func.now(), + onupdate=sa.func.now(), + ) + + deleted_at: Mapped[datetime | None] = mapped_column( + sa.DateTime(timezone=True), nullable=True + ) + + class IdentityClaims(BaseModel): """Normalized identity claim model.""" @@ -125,6 +188,3 @@ class LocalUserInfo(BaseModel): local: bool -def init_db(engine: sa.Engine) -> None: - """Create database.""" - Base.metadata.create_all(engine) diff --git a/packages/sshecret-admin/src/sshecret_admin/core/app.py b/packages/sshecret-admin/src/sshecret_admin/core/app.py index a81c340..06a7119 100644 --- a/packages/sshecret-admin/src/sshecret_admin/core/app.py +++ b/packages/sshecret-admin/src/sshecret_admin/core/app.py @@ -2,6 +2,7 @@ # pyright: reportUnusedFunction=false # +from collections.abc import AsyncGenerator import logging import os from contextlib import asynccontextmanager @@ -12,15 +13,15 @@ from fastapi.encoders import jsonable_encoder from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, RedirectResponse from fastapi.staticfiles import StaticFiles -from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession +from sshecret_backend.db import DatabaseSessionManager from starlette.middleware.sessions import SessionMiddleware from sshecret_admin import api, frontend -from sshecret_admin.auth.models import PasswordDB, init_db +from sshecret_admin.auth.models import Base from sshecret_admin.core.db import setup_database from sshecret_admin.frontend.exceptions import RedirectException -from sshecret_admin.services.master_password import setup_master_password +from sshecret_admin.services.secret_manager import setup_private_key from .dependencies import BaseDependencies from .settings import AdminServerSettings @@ -40,44 +41,28 @@ def setup_frontend(app: FastAPI, dependencies: BaseDependencies) -> None: def create_admin_app( - settings: AdminServerSettings, with_frontend: bool = True + settings: AdminServerSettings, + with_frontend: bool = True, + create_db: bool = False, ) -> FastAPI: """Create admin app.""" engine, get_db_session = setup_database(settings.admin_db) + async def get_async_session() -> AsyncGenerator[AsyncSession, None]: + """Get async session.""" + session_manager = DatabaseSessionManager(settings.async_db_url) + async with session_manager.session() as session: + yield session + def setup_password_manager() -> None: """Setup password manager.""" - encr_master_password = setup_master_password( - settings=settings, regenerate=False - ) - with Session(engine) as session: - existing_password = session.scalars( - select(PasswordDB).where(PasswordDB.id == 1) - ).first() - - if not encr_master_password: - if existing_password: - LOG.info("Master password already defined.") - return - # Looks like we have to regenerate it - LOG.warning( - "Master password was set, but not saved to the database. Regenerating it." - ) - encr_master_password = setup_master_password( - settings=settings, regenerate=True - ) - - assert encr_master_password is not None - - with Session(engine) as session: - pwdb = PasswordDB(id=1, encrypted_password=encr_master_password) - session.add(pwdb) - session.commit() + setup_private_key(settings, regenerate=False) @asynccontextmanager async def lifespan(_app: FastAPI): """Create database before starting the server.""" - init_db(engine) + if create_db: + Base.metadata.create_all(engine) setup_password_manager() yield @@ -109,7 +94,7 @@ def create_admin_app( status_code=status.HTTP_200_OK, content=jsonable_encoder({"status": "LIVE"}) ) - dependencies = BaseDependencies(settings, get_db_session) + dependencies = BaseDependencies(settings, get_db_session, get_async_session) app.include_router(api.create_api_router(dependencies)) if with_frontend: diff --git a/packages/sshecret-admin/src/sshecret_admin/core/cli.py b/packages/sshecret-admin/src/sshecret_admin/core/cli.py index 256260c..69baffa 100644 --- a/packages/sshecret-admin/src/sshecret_admin/core/cli.py +++ b/packages/sshecret-admin/src/sshecret_admin/core/cli.py @@ -12,7 +12,7 @@ from pydantic import ValidationError from sqlalchemy import select, create_engine from sqlalchemy.orm import Session from sshecret_admin.auth.authentication import hash_password -from sshecret_admin.auth.models import AuthProvider, PasswordDB, User, init_db +from sshecret_admin.auth.models import AuthProvider, PasswordDB, User from sshecret_admin.core.settings import AdminServerSettings from sshecret_admin.services.admin_backend import AdminBackend @@ -72,7 +72,6 @@ def cli_create_user( """Create user.""" settings = cast(AdminServerSettings, ctx.obj) engine = create_engine(settings.admin_db) - init_db(engine) with Session(engine) as session: create_user(session, username, email, password) @@ -87,7 +86,6 @@ def cli_change_user_passwd(ctx: click.Context, username: str, password: str) -> """Change password on user.""" settings = cast(AdminServerSettings, ctx.obj) engine = create_engine(settings.admin_db) - init_db(engine) with Session(engine) as session: user = session.scalars(select(User).where(User.username == username)).first() if not user: @@ -107,7 +105,6 @@ def cli_delete_user(ctx: click.Context, username: str) -> None: """Remove a user.""" settings = cast(AdminServerSettings, ctx.obj) engine = create_engine(settings.admin_db) - init_db(engine) with Session(engine) as session: user = session.scalars(select(User).where(User.username == username)).first() if not user: @@ -149,7 +146,6 @@ def cli_repl(ctx: click.Context) -> None: """Run an interactive console.""" settings = cast(AdminServerSettings, ctx.obj) engine = create_engine(settings.admin_db) - init_db(engine) with Session(engine) as session: password_db = session.scalars( select(PasswordDB).where(PasswordDB.id == 1) @@ -165,7 +161,7 @@ def cli_repl(ctx: click.Context) -> None: loop = asyncio.get_event_loop() return loop.run_until_complete(func) - admin = AdminBackend(settings, password_db.encrypted_password) + admin = AdminBackend(settings, ) locals = { "run": run, "admin": admin, diff --git a/packages/sshecret-admin/src/sshecret_admin/core/db.py b/packages/sshecret-admin/src/sshecret_admin/core/db.py index 781d50d..210a995 100644 --- a/packages/sshecret-admin/src/sshecret_admin/core/db.py +++ b/packages/sshecret-admin/src/sshecret_admin/core/db.py @@ -1,12 +1,13 @@ """Database setup.""" +import sqlite3 from contextlib import asynccontextmanager from collections.abc import AsyncIterator, Generator, Callable from sqlalchemy.orm import Session from sqlalchemy.engine import URL -from sqlalchemy import create_engine, Engine +from sqlalchemy import create_engine, Engine, event from sqlalchemy.ext.asyncio import ( AsyncConnection, @@ -18,11 +19,20 @@ from sqlalchemy.ext.asyncio import ( def setup_database( - db_url: URL | str, + db_url: URL, ) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]: """Setup database.""" engine = create_engine(db_url, echo=True, future=True) + if db_url.drivername.startswith("sqlite"): + + @event.listens_for(engine, "connect") + def set_sqlite_pragma( + dbapi_connection: sqlite3.Connection, _connection_record: object + ) -> None: + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() def get_db_session() -> Generator[Session, None, None]: """Get DB Session.""" @@ -33,8 +43,18 @@ def setup_database( class DatabaseSessionManager: - def __init__(self, host: URL | str, **engine_kwargs: str): + def __init__(self, host: URL, **engine_kwargs: str): self._engine: AsyncEngine | None = create_async_engine(host, **engine_kwargs) + if host.drivername.startswith("sqlite+"): + + @event.listens_for(self._engine.sync_engine, "connect") + def set_sqlite_pragma( + dbapi_connection: sqlite3.Connection, _connection_record: object + ) -> None: + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + self._sessionmaker: async_sessionmaker[AsyncSession] | None = ( async_sessionmaker( autocommit=False, bind=self._engine, expire_on_commit=False diff --git a/packages/sshecret-admin/src/sshecret_admin/core/dependencies.py b/packages/sshecret-admin/src/sshecret_admin/core/dependencies.py index 358cc7a..82bc742 100644 --- a/packages/sshecret-admin/src/sshecret_admin/core/dependencies.py +++ b/packages/sshecret-admin/src/sshecret_admin/core/dependencies.py @@ -4,6 +4,8 @@ from collections.abc import AsyncGenerator, Awaitable, Callable, Generator from dataclasses import dataclass from typing import Self +from fastapi import Request +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from sshecret_admin.auth import User from sshecret_admin.services import AdminBackend @@ -11,8 +13,9 @@ from sshecret_admin.core.settings import AdminServerSettings DBSessionDep = Callable[[], Generator[Session, None, None]] +AsyncSessionDep = Callable[[], AsyncGenerator[AsyncSession, None]] -AdminDep = Callable[[Session], AsyncGenerator[AdminBackend, None]] +AdminDep = Callable[[Request, Session], AsyncGenerator[AdminBackend, None]] GetUserDep = Callable[[User], Awaitable[User]] @@ -23,6 +26,8 @@ class BaseDependencies: settings: AdminServerSettings get_db_session: DBSessionDep + get_async_session: AsyncSessionDep + @dataclass @@ -43,6 +48,7 @@ class AdminDependencies(BaseDependencies): return cls( settings=deps.settings, get_db_session=deps.get_db_session, + get_async_session=deps.get_async_session, get_admin_backend=get_admin_backend, get_current_active_user=get_current_active_user, ) diff --git a/packages/sshecret-admin/src/sshecret_admin/frontend/dependencies.py b/packages/sshecret-admin/src/sshecret_admin/frontend/dependencies.py index 43b262f..6ab6a27 100644 --- a/packages/sshecret-admin/src/sshecret_admin/frontend/dependencies.py +++ b/packages/sshecret-admin/src/sshecret_admin/frontend/dependencies.py @@ -30,7 +30,6 @@ class FrontendDependencies(BaseDependencies): get_refresh_claims: RefreshTokenDep get_login_status: LoginStatusDep get_user_info: UserInfoDep - get_async_session: AsyncSessionDep require_login: LoginGuardDep @classmethod @@ -42,18 +41,17 @@ class FrontendDependencies(BaseDependencies): get_refresh_claims: RefreshTokenDep, get_login_status: LoginStatusDep, get_user_info: UserInfoDep, - get_async_session: AsyncSessionDep, require_login: LoginGuardDep, ) -> Self: """Create from base dependencies.""" return cls( settings=deps.settings, get_db_session=deps.get_db_session, + get_async_session=deps.get_async_session, get_admin_backend=get_admin_backend, templates=templates, get_refresh_claims=get_refresh_claims, get_login_status=get_login_status, get_user_info=get_user_info, - get_async_session=get_async_session, require_login=require_login, ) diff --git a/packages/sshecret-admin/src/sshecret_admin/frontend/router.py b/packages/sshecret-admin/src/sshecret_admin/frontend/router.py index 2faf8f4..d47af1d 100644 --- a/packages/sshecret-admin/src/sshecret_admin/frontend/router.py +++ b/packages/sshecret-admin/src/sshecret_admin/frontend/router.py @@ -24,7 +24,6 @@ from sshecret_admin.auth.constants import LOCAL_ISSUER from sshecret_admin.core.dependencies import BaseDependencies from sshecret_admin.services.admin_backend import AdminBackend -from sshecret_admin.core.db import DatabaseSessionManager from .dependencies import FrontendDependencies from .exceptions import RedirectException @@ -50,17 +49,24 @@ def create_router(dependencies: BaseDependencies) -> APIRouter: templates = Jinja2Blocks(directory=template_path) async def get_admin_backend( + request: Request, session: Annotated[Session, Depends(dependencies.get_db_session)], ): """Get admin backend API.""" password_db = session.scalars( select(PasswordDB).where(PasswordDB.id == 1) ).first() + username = get_optional_username(request) + origin = get_client_origin(request) if not password_db: raise HTTPException( 500, detail="Error: The password manager has not yet been set up." ) - admin = AdminBackend(dependencies.settings, password_db.encrypted_password) + admin = AdminBackend( + dependencies.settings, + username=username, + origin=origin, + ) yield admin def get_identity_claims(request: Request) -> IdentityClaims: @@ -108,14 +114,9 @@ def create_router(dependencies: BaseDependencies) -> APIRouter: next = URL("/refresh").include_query_params(next=request.url.path) raise RedirectException(to=next) - async def get_async_session(): - """Get async session.""" - sessionmanager = DatabaseSessionManager(dependencies.settings.async_db_url) - async with sessionmanager.session() as session: - yield session - async def get_user_info( - request: Request, session: Annotated[AsyncSession, Depends(get_async_session)] + request: Request, + session: Annotated[AsyncSession, Depends(dependencies.get_async_session)], ) -> LocalUserInfo: """Get User information.""" claims = get_identity_claims(request) @@ -142,6 +143,30 @@ def create_router(dependencies: BaseDependencies) -> APIRouter: next = URL("/refresh").include_query_params(next=request.url.path) raise RedirectException(to=next) + def get_optional_username( + request: Request, + ) -> str | None: + """Get username, if available. + + This is purely used for auditing purposes. + """ + try: + claims = get_identity_claims(request) + except Exception: + return None + + if claims.provider == LOCAL_ISSUER: + return claims.sub + + return f"oidc:{claims.email}" + + def get_client_origin(request: Request) -> str: + """Get client origin.""" + fallback_origin = "UNKNOWN" + if request.client: + return request.client.host + return fallback_origin + view_dependencies = FrontendDependencies.create( dependencies, get_admin_backend, @@ -149,7 +174,6 @@ def create_router(dependencies: BaseDependencies) -> APIRouter: refresh_identity_claims, get_login_status, get_user_info, - get_async_session, require_login, ) diff --git a/packages/sshecret-admin/src/sshecret_admin/frontend/templates/base/master-detail-email.html.j2 b/packages/sshecret-admin/src/sshecret_admin/frontend/templates/base/master-detail-email.html.j2 index 21706af..ee35e6e 100644 --- a/packages/sshecret-admin/src/sshecret_admin/frontend/templates/base/master-detail-email.html.j2 +++ b/packages/sshecret-admin/src/sshecret_admin/frontend/templates/base/master-detail-email.html.j2 @@ -16,7 +16,7 @@
+ class="flex-1 flex overflow-y-auto bg-white p-4 lg:block {% if not mobile_show_details|default(false) -%} hidden{%- endif -%} lg:block dark:bg-gray-800"> {% block detail %} diff --git a/packages/sshecret-admin/src/sshecret_admin/frontend/views/clients.py b/packages/sshecret-admin/src/sshecret_admin/frontend/views/clients.py index ced2cc0..0265e65 100644 --- a/packages/sshecret-admin/src/sshecret_admin/frontend/views/clients.py +++ b/packages/sshecret-admin/src/sshecret_admin/frontend/views/clients.py @@ -5,7 +5,7 @@ import ipaddress import logging import uuid from typing import Annotated -from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request, Response +from fastapi import APIRouter, Depends, Form, HTTPException, Request, Response from fastapi.responses import RedirectResponse from pydantic import BaseModel, IPvAnyAddress, IPvAnyNetwork from sshecret_admin.frontend.views.common import PagingInfo @@ -209,7 +209,7 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter: page: int, ) -> Response: """Get more events for a client.""" - if not "HX-Request" in request.headers: + if "HX-Request" not in request.headers: return RedirectResponse(url=f"/clients/client/{id}") client = await admin.get_client(("id", id)) diff --git a/packages/sshecret-admin/src/sshecret_admin/services/admin_backend.py b/packages/sshecret-admin/src/sshecret_admin/services/admin_backend.py index 0b0b8fa..29bfab0 100644 --- a/packages/sshecret-admin/src/sshecret_admin/services/admin_backend.py +++ b/packages/sshecret-admin/src/sshecret_admin/services/admin_backend.py @@ -4,8 +4,8 @@ Since we have a frontend and a REST API, it makes sense to have a generic librar """ import logging -from collections.abc import Iterator -from contextlib import contextmanager +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager from sshecret.backend import ( AuditLog, @@ -20,7 +20,7 @@ from sshecret.backend.models import ClientQueryResult, DetailedSecrets from sshecret.backend.api import AuditAPI, KeySpec from sshecret.crypto import encrypt_string, load_public_key -from .keepass import PasswordContext, load_password_manager +from .secret_manager import AsyncSecretContext, password_manager_context from sshecret_admin.core.settings import AdminServerSettings from .models import ( ClientSecretGroup, @@ -86,19 +86,27 @@ def add_clients_to_secret_group( class AdminBackend: """Admin backend API.""" - def __init__(self, settings: AdminServerSettings, keepass_password: str) -> None: + def __init__( + self, + settings: AdminServerSettings, + username: str | None = None, + origin: str = "UNKNOWN", + ) -> None: """Create client management API.""" self.settings: AdminServerSettings = settings self.backend: SshecretBackend = SshecretBackend( str(settings.backend_url), settings.backend_token ) - self.keepass_password: str = keepass_password + self.username: str = username or "UKNOWN_USER" + self.origin: str = origin - @contextmanager - def password_manager(self) -> Iterator[PasswordContext]: - """Open the password manager.""" - with load_password_manager(self.settings, self.keepass_password) as kp: - yield kp + @asynccontextmanager + async def secrets_manager(self) -> AsyncIterator[AsyncSecretContext]: + """Open the secrets manager.""" + async with password_manager_context( + self.settings, self.username, self.origin + ) as manager: + yield manager async def _get_clients(self, filter: ClientFilter | None = None) -> list[Client]: """Get clients from backend.""" @@ -194,7 +202,7 @@ class AdminBackend: self, name: KeySpec, new_key: str, - password_manager: PasswordContext, + password_manager: AsyncSecretContext, ) -> list[str]: """Update client public key.""" LOG.info( @@ -207,7 +215,7 @@ class AdminBackend: updated_secrets: list[str] = [] for secret in client.secrets: LOG.debug("Re-encrypting secret %s for client %s", secret, name) - secret_value = password_manager.get_secret(secret) + secret_value = await password_manager.get_secret(secret) if not secret_value: LOG.warning( "Referenced secret %s does not exist! Skipping.", secret_value @@ -224,7 +232,7 @@ class AdminBackend: async def update_client_public_key(self, name: KeySpec, new_key: str) -> list[str]: """Update client public key.""" try: - with self.password_manager() as password_manager: + async with self.secrets_manager() as password_manager: return await self._update_client_public_key( name, new_key, password_manager ) @@ -291,8 +299,8 @@ class AdminBackend: This fetches the secret to client mapping from backend, and adds secrets from the password manager. """ backend_secrets = await self.backend.get_secrets() - with self.password_manager() as password_manager: - admin_secrets = password_manager.get_available_secrets() + async with self.secrets_manager() as password_manager: + admin_secrets = await password_manager.get_available_secrets() secrets: dict[str, SecretListView] = {} for secret in backend_secrets: @@ -324,8 +332,8 @@ class AdminBackend: This fetches the secret to client mapping from backend, and adds secrets from the password manager. """ - with self.password_manager() as password_manager: - all_secrets = password_manager.get_available_secrets() + async with self.secrets_manager() as password_manager: + all_secrets = await password_manager.get_available_secrets() secrets = await self.backend.get_detailed_secrets() backend_secret_names = [secret.name for secret in secrets] @@ -351,13 +359,13 @@ class AdminBackend: parent_group: str | None = None, ) -> None: """Add secret group.""" - with self.password_manager() as password_manager: - password_manager.add_group(group_name, description, parent_group) + async with self.secrets_manager() as password_manager: + await password_manager.add_group(group_name, description, parent_group) async def set_secret_group(self, secret_name: str, group_name: str | None) -> None: """Assign a group to a secret.""" - with self.password_manager() as password_manager: - password_manager.set_secret_group(secret_name, group_name) + async with self.secrets_manager() as password_manager: + await password_manager.set_secret_group(secret_name, group_name) async def move_secret_group( self, group_name: str, parent_group: str | None @@ -366,23 +374,21 @@ class AdminBackend: If parent_group is None, it will be moved to the root. """ - with self.password_manager() as password_manager: - password_manager.move_group(group_name, parent_group) + async with self.secrets_manager() as password_manager: + await password_manager.move_group(group_name, parent_group) async def set_group_description(self, group_name: str, description: str) -> None: """Set a group description.""" - with self.password_manager() as password_manager: - password_manager.set_group_description(group_name, description) + async with self.secrets_manager() as password_manager: + await password_manager.set_group_description(group_name, description) - async def delete_secret_group( - self, group_name: str, keep_entries: bool = True - ) -> None: + async def delete_secret_group(self, group_name: str) -> None: """Delete a group. If keep_entries is set to False, all entries in the group will be deleted. """ - with self.password_manager() as password_manager: - password_manager.delete_group(group_name, keep_entries) + async with self.secrets_manager() as password_manager: + await password_manager.delete_group(group_name) async def get_secret_groups( self, @@ -399,18 +405,18 @@ class AdminBackend: """ all_secrets = await self.backend.get_detailed_secrets() secrets_mapping = {secret.name: secret for secret in all_secrets} - with self.password_manager() as password_manager: + async with self.secrets_manager() as password_manager: if flat: - all_groups = password_manager.get_secret_group_list( + all_groups = await password_manager.get_secret_group_list( group_filter, regex=regex ) else: - all_groups = password_manager.get_secret_groups( + all_groups = await password_manager.get_secret_groups( group_filter, regex=regex ) - ungrouped = password_manager.get_ungrouped_secrets() + ungrouped = await password_manager.get_ungrouped_secrets() - all_admin_secrets = password_manager.get_available_secrets() + all_admin_secrets = await password_manager.get_available_secrets() group_result: list[ClientSecretGroup] = [] for group in all_groups: @@ -452,8 +458,8 @@ class AdminBackend: async def get_secret_group_by_path(self, path: str) -> ClientSecretGroup | None: """Get a group based on its path.""" - with self.password_manager() as password_manager: - secret_group = password_manager.get_secret_group(path) + async with self.secrets_manager() as password_manager: + secret_group = await password_manager.get_secret_group(path) if not secret_group: return None @@ -476,9 +482,11 @@ class AdminBackend: ) -> SecretView | None: """Get a secret, including the actual unencrypted value and clients.""" secret: str | None = None - with self.password_manager() as password_manager: - secret = password_manager.get_secret(name) - secret_group = password_manager.get_entry_group(name) + async with self.secrets_manager() as password_manager: + secret = await password_manager.get_secret(name) + secret_group: str | None = None + if secret: + secret_group = await password_manager.get_entry_group(name) secret_view = SecretView(name=name, secret=secret, group=secret_group) @@ -503,8 +511,8 @@ class AdminBackend: async def _delete_secret(self, name: str) -> None: """Delete a secret.""" - with self.password_manager() as password_manager: - password_manager.delete_entry(name) + async with self.secrets_manager() as password_manager: + await password_manager.delete_entry(name) secret_mapping = await self.backend.get_secret(name) if not secret_mapping: @@ -522,8 +530,8 @@ class AdminBackend: group: str | None = None, ) -> None: """Add a secret.""" - with self.password_manager() as password_manager: - password_manager.add_entry(name, value, update, group_name=group) + async with self.secrets_manager() as password_manager: + await password_manager.add_entry(name, value, update, group_path=group) if update: secret_map = await self.backend.get_secret(name) @@ -576,8 +584,8 @@ class AdminBackend: if not client: raise ClientNotFoundError(client_idname) - with self.password_manager() as password_manager: - secret = password_manager.get_secret(secret_name) + async with self.secrets_manager() as password_manager: + secret = await password_manager.get_secret(secret_name) if not secret: raise SecretNotFoundError() diff --git a/packages/sshecret-admin/src/sshecret_admin/services/keepass.py b/packages/sshecret-admin/src/sshecret_admin/services/keepass.py deleted file mode 100644 index 04af2fd..0000000 --- a/packages/sshecret-admin/src/sshecret_admin/services/keepass.py +++ /dev/null @@ -1,348 +0,0 @@ -"""Keepass password manager.""" - -import logging -from collections.abc import Iterator -from contextlib import contextmanager -from pathlib import Path -from typing import cast - -import pykeepass -import pykeepass.exceptions -from sshecret_admin.core.settings import AdminServerSettings - -from .models import SecretGroup -from .master_password import decrypt_master_password - - -LOG = logging.getLogger(__name__) - -NO_USERNAME = "NO_USERNAME" - -DEFAULT_LOCATION = "keepass.kdbx" - - -class PasswordCredentialsError(Exception): - pass - - -def create_password_db(location: Path, password: str) -> None: - """Create the password database.""" - LOG.info("Creating password database at %s", location) - pykeepass.create_database(str(location.absolute()), password=password) - - -def _kp_group_to_secret_group( - kp_group: pykeepass.group.Group, - parent: SecretGroup | None = None, - depth: int | None = None, -) -> SecretGroup: - """Convert keepass group to secret group dataclass.""" - group_name = cast(str, kp_group.name) - path = "/".join(cast(list[str], kp_group.path)) - group = SecretGroup(name=group_name, path=path, description=kp_group.notes) - for entry in kp_group.entries: - group.entries.append(str(entry.title)) - if parent: - group.parent_group = parent - - current_depth = len(kp_group.path) - - if not parent and current_depth > 1: - parent = _kp_group_to_secret_group(kp_group.parentgroup, depth=current_depth) - parent.children.append(group) - group.parent_group = parent - - if depth and depth == current_depth: - return group - - for subgroup in kp_group.subgroups: - group.children.append(_kp_group_to_secret_group(subgroup, group, depth=depth)) - - return group - - -class PasswordContext: - """Password Context class.""" - - def __init__(self, keepass: pykeepass.PyKeePass) -> None: - """Initialize password context.""" - self.keepass: pykeepass.PyKeePass = keepass - - @property - def _root_group(self) -> pykeepass.group.Group: - """Return the root group.""" - return cast(pykeepass.group.Group, self.keepass.root_group) - - def _get_entry(self, name: str) -> pykeepass.entry.Entry | None: - """Get entry.""" - entry = cast( - "pykeepass.entry.Entry | None", - self.keepass.find_entries(title=name, first=True), - ) - return entry - - def _get_group(self, name: str) -> pykeepass.group.Group | None: - """Find a group.""" - group = cast( - pykeepass.group.Group | None, - self.keepass.find_groups(name=name, first=True), - ) - return group - - def add_entry( - self, - entry_name: str, - secret: str, - overwrite: bool = False, - group_name: str | None = None, - ) -> None: - """Add an entry. - - Specify overwrite=True to overwrite the existing secret value, if it exists. - This will not move the entry, if the group_name is different from the original group. - - """ - entry = self._get_entry(entry_name) - if entry and overwrite: - entry.password = secret - self.keepass.save() - return - - if entry: - raise ValueError("Error: A secret with this name already exists.") - LOG.debug("Add secret entry to keepass: %s, group: %r", entry_name, group_name) - if group_name: - destination_group = self._get_group(group_name) - else: - destination_group = self._root_group - - entry = self.keepass.add_entry( - destination_group=destination_group, - title=entry_name, - username=NO_USERNAME, - password=secret, - ) - self.keepass.save() - - def get_secret(self, entry_name: str) -> str | None: - """Get the secret value.""" - entry = self._get_entry(entry_name) - if not entry: - return None - - LOG.warning("Secret name %s accessed", entry_name) - if password := cast(str, entry.password): - return str(password) - - raise RuntimeError(f"Cannot get password for entry {entry_name}") - - def get_entry_group(self, entry_name: str) -> str | None: - """Get the group for an entry.""" - entry = self._get_entry(entry_name) - if not entry: - return None - if entry.group.is_root_group: - return None - return str(entry.group.name) - - def get_secret_groups( - self, pattern: str | None = None, regex: bool = True - ) -> list[SecretGroup]: - """Get secret groups. - - A regex pattern may be provided to filter groups. - """ - if pattern: - groups = cast( - list[pykeepass.group.Group], - self.keepass.find_groups(name=pattern, regex=regex), - ) - else: - groups = self._root_group.subgroups - - secret_groups = [_kp_group_to_secret_group(group) for group in groups] - return secret_groups - - def get_secret_group_list( - self, pattern: str | None = None, regex: bool = True - ) -> list[SecretGroup]: - """Get a flat list of groups.""" - if pattern: - return self.get_secret_groups(pattern, regex) - - groups = [group for group in self.keepass.groups if not group.is_root_group] - secret_groups = [_kp_group_to_secret_group(group) for group in groups] - return secret_groups - - def get_secret_group(self, path: str) -> SecretGroup | None: - """Get a secret group by path.""" - elements = path.split("/") - final_element = elements[-1] - - current = self._root_group - while elements: - groupname = elements.pop(0) - matches = [ - subgroup for subgroup in current.subgroups if subgroup.name == groupname - ] - if matches: - current = matches[0] - else: - return None - if not current.is_root_group and current.name == final_element: - return _kp_group_to_secret_group(current) - return None - - def get_ungrouped_secrets(self) -> list[str]: - """Get secrets without groups.""" - entries: list[str] = [] - for entry in self._root_group.entries: - entries.append(str(entry.title)) - - return entries - - def add_group( - self, name: str, description: str | None = None, parent_group: str | None = None - ) -> None: - """Add a group.""" - kp_parent_group = self._root_group - if parent_group: - query = cast( - pykeepass.group.Group | None, - self.keepass.find_groups(name=parent_group, first=True), - ) - if not query: - raise ValueError( - f"Error: Cannot find a parent group named {parent_group}" - ) - kp_parent_group = query - self.keepass.add_group( - destination_group=kp_parent_group, group_name=name, notes=description - ) - self.keepass.save() - - def set_group_description(self, name: str, description: str) -> None: - """Set the description of a group.""" - group = self._get_group(name) - if not group: - raise ValueError(f"Error: No such group {name}") - - group.notes = description - self.keepass.save() - - def set_secret_group(self, entry_name: str, group_name: str | None) -> None: - """Move a secret to a group. - - If group is None, the secret will be placed in the root group. - """ - entry = self._get_entry(entry_name) - if not entry: - raise ValueError( - f"Cannot find secret entry named {entry_name} in secrets database" - ) - if group_name: - group = self._get_group(group_name) - if not group: - raise ValueError(f"Cannot find a group named {group_name}") - else: - group = self._root_group - - self.keepass.move_entry(entry, group) - self.keepass.save() - - def move_group(self, name: str, parent_group: str | None) -> None: - """Move a group. - - If parent_group is None, it will be moved to the root. - """ - group = self._get_group(name) - if not group: - raise ValueError(f"Error: No such group {name}") - if parent_group: - parent = self._get_group(parent_group) - if not parent: - raise ValueError(f"Error: No such group {parent_group}") - else: - parent = self._root_group - - self.keepass.move_group(group, parent) - self.keepass.save() - - def get_available_secrets(self, group_name: str | None = None) -> list[str]: - """Get the names of all secrets in the database.""" - if group_name: - group = self._get_group(group_name) - if not group: - raise ValueError(f"Error: No such group {group_name}") - entries = group.entries - else: - entries = cast(list[pykeepass.entry.Entry], self.keepass.entries) - if not entries: - return [] - return [str(entry.title) for entry in entries] - - def delete_entry(self, entry_name: str) -> None: - """Delete entry.""" - entry = cast( - "pykeepass.entry.Entry | None", - self.keepass.find_entries(title=entry_name, first=True), - ) - if not entry: - return - entry.delete() - self.keepass.save() - - def delete_group(self, name: str, keep_entries: bool = True) -> None: - """Delete a group. - - If keep_entries is set to False, all entries in the group will be deleted. - """ - group = self._get_group(name) - if not group: - return - if keep_entries: - for entry in cast( - list[pykeepass.entry.Entry], - self.keepass.find_entries(recursive=True, group=group), - ): - # Move the entry to the root group. - LOG.warning( - "Moving orphaned secret entry %s to root group", entry.title - ) - self.keepass.move_entry(entry, self._root_group) - - self.keepass.delete_group(group) - self.keepass.save() - - -@contextmanager -def _password_context(location: Path, password: str) -> Iterator[PasswordContext]: - """Open the password context.""" - try: - database = pykeepass.PyKeePass(str(location.absolute()), password=password) - except pykeepass.exceptions.CredentialsError as e: - raise PasswordCredentialsError( - "Could not open password database. Invalid credentials." - ) from e - context = PasswordContext(database) - yield context - - -@contextmanager -def load_password_manager( - settings: AdminServerSettings, - encrypted_password: str, - location: str = DEFAULT_LOCATION, -) -> Iterator[PasswordContext]: - """Load password manager. - - This function decrypts the password, and creates the password database if it - has not yet been created. - """ - db_location = Path(location) - password = decrypt_master_password(settings=settings, encrypted=encrypted_password) - if not db_location.exists(): - create_password_db(db_location, password) - - with _password_context(db_location, password) as context: - yield context diff --git a/packages/sshecret-admin/src/sshecret_admin/services/master_password.py b/packages/sshecret-admin/src/sshecret_admin/services/master_password.py deleted file mode 100644 index 8b87918..0000000 --- a/packages/sshecret-admin/src/sshecret_admin/services/master_password.py +++ /dev/null @@ -1,86 +0,0 @@ -"""Functions related to handling the password database master password.""" - -import secrets -from pathlib import Path -from sshecret.crypto import ( - create_private_rsa_key, - load_private_key, - encrypt_string, - decode_string, -) -from sshecret_admin.core.settings import AdminServerSettings - -KEY_FILENAME = "sshecret-admin-key" - - -def setup_master_password( - settings: AdminServerSettings, - filename: str = KEY_FILENAME, - regenerate: bool = False, -) -> str | None: - """Setup master password. - - If regenerate is True, a new key will be generated. - - This method should run just after setting up the database. - """ - keyfile = Path(filename) - if settings.password_manager_directory: - keyfile = settings.password_manager_directory / filename - created = _initial_key_setup(settings, keyfile, regenerate) - if not created: - return None - - return _generate_master_password(settings, keyfile) - - -def decrypt_master_password( - settings: AdminServerSettings, encrypted: str, filename: str = KEY_FILENAME -) -> str: - """Retrieve master password.""" - keyfile = Path(filename) - if settings.password_manager_directory: - keyfile = settings.password_manager_directory / filename - if not keyfile.exists(): - raise RuntimeError("Error: Private key has not been generated yet.") - - private_key = load_private_key( - str(keyfile.absolute()), password=settings.secret_key - ) - return decode_string(encrypted, private_key) - - -def _generate_password() -> str: - """Generate a password.""" - return secrets.token_urlsafe(32) - - -def _initial_key_setup( - settings: AdminServerSettings, - keyfile: Path, - regenerate: bool = False, -) -> bool: - """Set up initial keys.""" - if keyfile.exists() and not regenerate: - return False - - assert settings.secret_key is not None, ( - "Error: Could not load a secret key from environment." - ) - create_private_rsa_key(keyfile, password=settings.secret_key) - return True - - -def _generate_master_password(settings: AdminServerSettings, keyfile: Path) -> str: - """Generate master password for password database. - - Returns the encrypted string, base64 encoded. - """ - if not keyfile.exists(): - raise RuntimeError("Error: Private key has not been generated yet.") - private_key = load_private_key( - str(keyfile.absolute()), password=settings.secret_key - ) - public_key = private_key.public_key() - master_password = _generate_password() - return encrypt_string(master_password, public_key) diff --git a/packages/sshecret-admin/src/sshecret_admin/services/secret_manager.py b/packages/sshecret-admin/src/sshecret_admin/services/secret_manager.py new file mode 100644 index 0000000..28a2557 --- /dev/null +++ b/packages/sshecret-admin/src/sshecret_admin/services/secret_manager.py @@ -0,0 +1,776 @@ +"""Rewritten secret manager using a rsa keys.""" + +import logging +import os +import uuid +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from dataclasses import dataclass +from datetime import datetime, timezone +from functools import cached_property +from pathlib import Path + +from cryptography.hazmat.primitives.asymmetric import rsa +from pydantic import BaseModel +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload, aliased +from sshecret.backend import SshecretBackend +from sshecret.backend.api import AuditAPI, KeySpec +from sshecret.backend.models import Client, ClientSecret, Operation, SubSystem +from sshecret.crypto import ( + create_private_rsa_key, + decode_string, + encrypt_string, + generate_public_key_string, + load_private_key, + load_public_key, +) +from sshecret_admin.auth import PasswordDB +from sshecret_admin.auth.models import Group, ManagedSecret +from sshecret_admin.core.db import DatabaseSessionManager +from sshecret_admin.core.settings import AdminServerSettings +from sshecret_admin.services.models import SecretGroup + + +KEY_FILENAME = "sshecret-admin-key" +PASSWORD_MANAGER_ID = "SshecretAdminPasswordManager" + +LOG = logging.getLogger(PASSWORD_MANAGER_ID) + + +class SecretManagerError(Exception): + """Secret manager error.""" + + +class InvalidGroupNameError(SecretManagerError): + """Invalid group name.""" + + +class InvalidSecretNameError(SecretManagerError): + """Invalid secret name.""" + + +@dataclass +class ClientAuditData: + """Client audit data.""" + + username: str + origin: str + + +@dataclass +class ParsedPath: + """Parsed path.""" + + item: str + full_path: str + parent: str | None = None + + +class SecretDataEntryExport(BaseModel): + """Exportable secret entry.""" + + name: str + secret: str + group: str | None = None + + +class SecretDataGroupExport(BaseModel): + """Exportable secret grouping.""" + + name: str + path: str + description: str | None = None + + +class SecretDataExport(BaseModel): + """Exportable object containing secrets and groups.""" + + entries: list[SecretDataEntryExport] + groups: list[SecretDataGroupExport] + + +def split_path(path: str) -> list[str]: + """Split a path into a list of groups.""" + elements = path.split("/") + if path.startswith("/"): + elements = elements[1:] + + return elements + + +def parse_path(path: str) -> ParsedPath: + """Parse path.""" + elements = split_path(path) + parsed = ParsedPath(elements[-1], path) + if len(elements) > 1: + parsed.parent = elements[-2] + return parsed + + +class AsyncSecretContext: + """Async secret context.""" + + def __init__( + self, + private_key: rsa.RSAPrivateKey, + manager_client: Client, + session: AsyncSession, + backend: SshecretBackend, + audit_data: ClientAuditData, + ) -> None: + """Initialize secret manager""" + self._private_key: rsa.RSAPrivateKey = private_key + self._manager_client: Client = manager_client + self._id: KeySpec = ("id", str(manager_client.id)) + self.backend: SshecretBackend = backend + self.session: AsyncSession = session + + self.audit_data: ClientAuditData = audit_data + self.audit: AuditAPI = backend.audit(SubSystem.ADMIN) + self._import_has_run: bool = False + + async def _create_missing_entries(self) -> None: + """Create any missing entries.""" + new_secrets: bool = False + to_check = set(self._manager_client.secrets) + for secret_name in to_check: + # entry = await self._get_entry(secret_name, include_deleted=True) + statement = select(ManagedSecret).where(ManagedSecret.name == secret_name) + result = await self.session.scalars(statement) + if not result.first(): + new_secrets = True + managed_secret = ManagedSecret(name=secret_name) + self.session.add(managed_secret) + + await self.session.flush() + await self.write_audit( + Operation.CREATE, + message="Imported managed secret from backend.", + secret_name=secret_name, + managed_secret=managed_secret, + ) + if new_secrets: + await self.session.commit() + + async def _get_group_depth(self, group: Group) -> int: + """Get the depth of a group.""" + depth = 1 + if not group.parent_id: + return depth + + current = group + while current.parent is not None: + if current.parent: + depth += 1 + current = await self._get_group_by_id(current.parent.id) + else: + break + + return depth + + async def _get_group_path(self, group: Group) -> str: + """Get the path of a group.""" + + if not group.parent_id: + return group.name + path: list[str] = [] + current = group + while current.parent_id is not None: + path.append(current.name) + current = await self._get_group_by_id(current.parent_id) + + path.append("") + path.reverse() + return "/".join(path) + + async def _get_group_secrets(self, group: Group) -> list[ManagedSecret]: + """Get secrets in a group.""" + statement = ( + select(ManagedSecret) + .where(ManagedSecret.group_id == group.id) + .where(ManagedSecret.is_deleted.is_not(True)) + ) + results = await self.session.scalars(statement) + return list(results.all()) + + async def _build_group_tree( + self, group: Group, parent: SecretGroup | None = None, depth: int | None = None + ) -> SecretGroup: + """Build a group tree.""" + path = "/" + if parent: + path = os.path.join(parent.path, path) + secret_group = SecretGroup( + name=group.name, path=path, description=group.description + ) + group_secrets = await self._get_group_secrets(group) + for secret in group_secrets: + secret_group.entries.append(secret.name) + if parent: + secret_group.parent_group = parent + + current_depth = await self._get_group_depth(group) + + if not parent and group.parent: + parent_group = await self._get_group_by_id(group.parent.id) + assert parent_group is not None + parent = await self._build_group_tree(parent_group, depth=current_depth) + parent.children.append(secret_group) + secret_group.parent_group = parent + + if depth and depth == current_depth: + return secret_group + + for subgroup in group.children: + child_group = await self._get_group_by_id(subgroup.id) + assert child_group is not None + secret_subgroup = await self._build_group_tree( + child_group, secret_group, depth=depth + ) + secret_group.children.append(secret_subgroup) + + return secret_group + + async def write_audit( + self, + operation: Operation, + message: str, + group_name: str | None = None, + client_secret: ClientSecret | None = None, + secret_name: str | None = None, + managed_secret: ManagedSecret | None = None, + **data: str, + ) -> None: + """Write Audit message.""" + if group_name: + data["group"] = group_name + + data["username"] = self.audit_data.username + if client_secret and not secret_name: + secret_name = client_secret.name + + if managed_secret: + data["managed_secret"] = str(managed_secret.id) + + await self.audit.write_async( + operation=operation, + message=message, + origin=self.audit_data.origin, + client=self._manager_client, + secret=client_secret, + secret_name=secret_name, + **data, + ) + + @cached_property + def public_key(self) -> rsa.RSAPublicKey: + """Get public key.""" + keystring = self._manager_client.public_key + return load_public_key(keystring.encode()) + + async def _get_entry( + self, name: str, include_deleted: bool = False + ) -> ManagedSecret | None: + """Get managed secret.""" + if not self._import_has_run: + await self._create_missing_entries() + self._import_has_run = True + statement = ( + select(ManagedSecret) + .options(selectinload(ManagedSecret.group)) + .where(ManagedSecret.name == name) + ) + if not include_deleted: + statement = statement.where(ManagedSecret.is_deleted.is_not(True)) + + result = await self.session.scalars(statement) + return result.first() + + async def add_entry( + self, + entry_name: str, + secret: str, + overwrite: bool = False, + group_path: str | None = None, + ) -> None: + """Add entry.""" + existing_entry = await self._get_entry(entry_name) + if existing_entry and not overwrite: + raise InvalidSecretNameError( + "Another secret with this name is already defined." + ) + + encrypted = encrypt_string(secret, self.public_key) + client_secret = await self.backend.create_client_secret( + self._id, entry_name, encrypted + ) + group_id: uuid.UUID | None = None + if group_path: + elements = parse_path(group_path) + group = await self._get_group(elements.item, elements.parent, True) + if not group: + raise InvalidGroupNameError("Invalid group name") + group_id = group.id + + if existing_entry: + existing_entry.updated_at = datetime.now(timezone.utc) + if group_id: + existing_entry.group_id = group_id + self.session.add(existing_entry) + await self.session.commit() + await self.write_audit( + Operation.UPDATE, + "Updated secret value", + group_name=group_path, + client_secret=client_secret, + managed_secret=existing_entry, + ) + else: + managed_secret = ManagedSecret( + name=entry_name, + group_id=group_id, + ) + self.session.add(managed_secret) + + await self.session.commit() + await self.write_audit( + Operation.CREATE, + "Created managed client secret", + group_path, + client_secret=client_secret, + managed_secret=managed_secret, + ) + + async def get_secret(self, entry_name: str) -> str | None: + """Get secret.""" + client_secret = await self.backend.get_client_secret( + self._id, ("name", entry_name) + ) + if not client_secret: + return None + decrypted = decode_string(client_secret, self._private_key) + await self.write_audit( + Operation.READ, + "Secret was viewed from secret manager", + secret_name=entry_name, + ) + + return decrypted + + async def get_available_secrets(self, group_path: str | None = None) -> list[str]: + """Get the names of all secrets in the db.""" + if not self._import_has_run: + await self._create_missing_entries() + if group_path: + elements = parse_path(group_path) + group = await self._get_group(elements.item, elements.parent) + if not group: + raise InvalidGroupNameError("Invalid or nonexisting group name.") + entries = group.secrets + else: + result = await self.session.scalars( + select(ManagedSecret) + .options(selectinload(ManagedSecret.group)) + .where(ManagedSecret.is_deleted.is_not(True)) + ) + + entries = list(result.all()) + + return [entry.name for entry in entries] + + async def delete_entry(self, entry_name: str) -> None: + """Delete a secret.""" + entry = await self._get_entry(entry_name) + if not entry: + return + entry.is_deleted = True + entry.deleted_at = datetime.now(timezone.utc) + self.session.add(entry) + await self.session.commit() + await self.backend.delete_client_secret( + ("id", str(self._manager_client.id)), ("name", entry_name) + ) + await self.write_audit( + Operation.DELETE, + "Managed secret entry deleted", + secret_name=entry_name, + managed_secret=entry, + ) + + async def get_entry_group(self, entry_name: str) -> str | None: + """Get group of entry.""" + entry = await self._get_entry(entry_name) + if not entry: + raise InvalidSecretNameError("Invalid secret name or secret not found.") + if entry.group: + return entry.group.name + return None + + async def _get_groups( + self, pattern: str | None = None, regex: bool = True, root_groups: bool = False + ) -> list[Group]: + """Get groups.""" + statement = select(Group).options( + selectinload(Group.children), selectinload(Group.parent) + ) + if pattern and regex: + statement = statement.where(Group.name.regexp_match(pattern)) + elif pattern: + statement = statement.where(Group.name.contains(pattern)) + if root_groups: + statement = statement.where(Group.parent_id == None) + results = await self.session.scalars(statement) + return list(results.all()) + + async def get_secret_groups( + self, pattern: str | None = None, regex: bool = True + ) -> list[SecretGroup]: + """Get secret groups, as a hierarcy.""" + if pattern: + groups = await self._get_groups(pattern, regex) + else: + groups = await self._get_groups(root_groups=True) + + secret_groups: list[SecretGroup] = [] + for group in groups: + secret_group = await self._build_group_tree(group) + secret_groups.append(secret_group) + + return secret_groups + + async def get_secret_group_list( + self, pattern: str | None = None, regex: bool = True + ) -> list[SecretGroup]: + """Get secret group list.""" + groups = await self._get_groups(pattern, regex) + return [(await self._build_group_tree(group)) for group in groups] + + async def _get_group_by_id(self, id: uuid.UUID) -> Group: + """Get group by ID.""" + statement = ( + select(Group) + .options( + selectinload(Group.parent), + selectinload(Group.children), + selectinload(Group.secrets), + ) + .where(Group.id == id) + ) + + result = await self.session.scalars(statement) + return result.one() + + async def _get_group( + self, name: str, parent: str | None = None, exact_match: bool = False + ) -> Group | None: + """Get a group.""" + statement = ( + select(Group) + .options( + selectinload(Group.parent), + selectinload(Group.children), + selectinload(Group.secrets), + ) + .where(Group.name == name) + ) + if parent: + ParentGroup = aliased(Group) + statement = statement.join(ParentGroup, Group.parent).where( + ParentGroup.name == parent + ) + elif exact_match: + statement = statement.where(Group.parent_id == None) + result = await self.session.scalars(statement) + return result.first() + + async def get_secret_group(self, path: str) -> SecretGroup | None: + """Get a secret group by path.""" + elements = parse_path(path) + + group_name = elements.item + parent_group = elements.parent + + group = await self._get_group(group_name, parent_group) + if not group: + return None + + return await self._build_group_tree(group) + + async def get_ungrouped_secrets(self) -> list[str]: + """Get ungrouped secrets.""" + statement = ( + select(ManagedSecret) + .where(ManagedSecret.is_deleted.is_not(True)) + .where(ManagedSecret.group_id == None) + ) + result = await self.session.scalars(statement) + secrets = result.all() + return [secret.name for secret in secrets] + + async def add_group( + self, + name_or_path: str, + description: str | None = None, + parent_group: str | None = None, + ) -> None: + """Add a group.""" + parent_id: uuid.UUID | None = None + group_name = name_or_path + if parent_group and name_or_path.startswith("/"): + raise InvalidGroupNameError( + "Path as name cannot be used if parent is also specified." + ) + if name_or_path.startswith("/"): + elements = parse_path(name_or_path) + group_name = elements.item + parent_group = elements.parent + + if parent_group: + if parent := (await self._get_group(parent_group)): + child_names = [child.name for child in parent.children] + if group_name in child_names: + raise InvalidGroupNameError( + "Parent group already has a group with this name." + ) + parent_id = parent.id + + else: + raise InvalidGroupNameError( + "Invalid or non-existing parent group name." + ) + else: + existing_group = await self._get_group(group_name) + if existing_group: + raise InvalidGroupNameError("A group with this name already exists.") + + group = Group( + name=group_name, + description=description, + parent_id=parent_id, + ) + self.session.add(group) + # We don't audit-log this operation. + await self.session.commit() + + async def set_group_description(self, path: str, description: str) -> None: + """Set group description.""" + elements = parse_path(path) + group = await self._get_group(elements.item, elements.parent, True) + if not group: + raise InvalidGroupNameError("Invalid or non-existing group name.") + group.description = description + self.session.add(group) + await self.session.commit() + + async def set_secret_group(self, entry_name: str, group_name: str | None) -> None: + """Move a secret to a group. + + If group_name is None, the secret will be moved out of any group it may exist in. + """ + entry = await self._get_entry(entry_name) + if not entry: + raise InvalidSecretNameError("Invalid or non-existing secret.") + if group_name: + elements = parse_path(group_name) + group = await self._get_group(elements.item, elements.parent, True) + if not group: + raise InvalidGroupNameError("Invalid or non-existing group name.") + entry.group_id = group.id + else: + entry.group_id = None + + self.session.add(entry) + await self.session.commit() + await self.write_audit( + Operation.UPDATE, + "Secret group updated", + group_name=group_name or "ROOT", + secret_name=entry_name, + managed_secret=entry, + ) + + async def move_group(self, path: str, parent_group: str | None) -> None: + """Move group. + + If parent_group is None, it will be moved to the root. + """ + elements = parse_path(path) + group = await self._get_group(elements.item, elements.parent, True) + if not group: + raise InvalidGroupNameError("Invalid or non-existing group name.") + + parent_group_id: uuid.UUID | None = None + if parent_group: + db_parent_group = await self._get_group(parent_group) + if not db_parent_group: + raise InvalidGroupNameError("Invalid or non-existing parent group.") + parent_group_id = db_parent_group.id + + group.parent_id = parent_group_id + + self.session.add(group) + await self.session.commit() + + async def delete_group(self, path: str) -> None: + """Delete a group.""" + elements = parse_path(path) + group = await self._get_group(elements.item, elements.parent, True) + if not group: + return + await self.session.delete(group) + + await self.session.commit() + # We don't audit-log this operation currently, even though it indirectly + # may affect secrets. + + async def _export_entries(self) -> list[SecretDataEntryExport]: + """Export entries as a pydantic object.""" + statement = ( + select(ManagedSecret) + .options(selectinload(ManagedSecret.group)) + .where(ManagedSecret.is_deleted.is_(False)) + ) + results = await self.session.scalars(statement) + entries: list[SecretDataEntryExport] = [] + for entry in results.all(): + group: str | None = None + if entry.group: + group = await self._get_group_path(entry.group) + secret = await self.get_secret(entry.name) + if not secret: + continue + data = SecretDataEntryExport(name=entry.name, secret=secret, group=group) + entries.append(data) + return entries + + async def _export_groups(self) -> list[SecretDataGroupExport]: + """Export groups as pydantic objects.""" + groups = await self.get_secret_group_list() + entries = [ + SecretDataGroupExport( + name=group.name, + path=group.path, + description=group.description, + ) + for group in groups + ] + return entries + + async def export_secrets(self) -> SecretDataExport: + """Export the managed secrets as a pydantic object.""" + entries = await self._export_entries() + groups = await self._export_groups() + return SecretDataExport(entries=entries, groups=groups) + + async def export_secrets_json(self) -> str: + """Export secrets as JSON.""" + export = await self.export_secrets() + return export.model_dump_json(indent=2) + + +def get_managed_private_key( + settings: AdminServerSettings, + filename: str = KEY_FILENAME, + regenerate: bool = False, +) -> rsa.RSAPrivateKey: + """Load our private key.""" + keyfile = Path(filename) + if settings.password_manager_directory: + keyfile = settings.password_manager_directory / filename + if not keyfile.exists(): + _initial_key_setup(settings, keyfile) + setup_password_manager(settings, keyfile, regenerate) + return load_private_key(str(keyfile.absolute()), password=settings.secret_key) + + +def setup_password_manager( + settings: AdminServerSettings, filename: Path, regenerate: bool = False +) -> bool: + """Setup password manager.""" + if filename.exists() and not regenerate: + return False + + if not settings.secret_key: + raise RuntimeError("Error: Could not load secret key from environment.") + create_private_rsa_key(filename, password=settings.secret_key) + return True + + +async def create_manager_client( + backend: SshecretBackend, public_key: rsa.RSAPublicKey +) -> Client: + """Create the manager client.""" + public_key_string = generate_public_key_string(public_key) + new_client = await backend.create_system_client( + "AdminPasswordManager", + public_key_string, + ) + return new_client + + +@asynccontextmanager +async def password_manager_context( + settings: AdminServerSettings, username: str, origin: str +) -> AsyncIterator[AsyncSecretContext]: + """Start a context for the password manager.""" + audit_context_data = ClientAuditData(username=username, origin=origin) + session_manager = DatabaseSessionManager(settings.async_db_url) + backend = SshecretBackend(str(settings.backend_url), settings.backend_token) + private_key = get_managed_private_key(settings) + async with session_manager.session() as session: + # Check if there is a client_id stored already. + query = select(PasswordDB).where(PasswordDB.id == 1) + result = await session.scalars(query) + password_db = result.first() + if not password_db: + password_db = PasswordDB(id=1) + session.add(password_db) + await session.flush() + if not password_db.client_id: + manager_client = await create_manager_client( + backend, private_key.public_key() + ) + password_db.client_id = manager_client.id + session.add(password_db) + await session.commit() + else: + manager_client = await backend.get_client( + ("id", str(password_db.client_id)) + ) + if not manager_client: + raise SecretManagerError("Error: Could not fetch system client.") + + context = AsyncSecretContext( + private_key, manager_client, session, backend, audit_context_data + ) + yield context + + +def setup_private_key( + settings: AdminServerSettings, + filename: str = KEY_FILENAME, + regenerate: bool = False, +) -> None: + """Setup secret manager private key.""" + keyfile = Path(filename) + if settings.password_manager_directory: + keyfile = settings.password_manager_directory / filename + _initial_key_setup(settings, keyfile, regenerate) + + +def _initial_key_setup( + settings: AdminServerSettings, + keyfile: Path, + regenerate: bool = False, +) -> bool: + """Set up initial keys.""" + if keyfile.exists() and not regenerate: + return False + + assert ( + settings.secret_key is not None + ), "Error: Could not load a secret key from environment." + create_private_rsa_key(keyfile, password=settings.secret_key) + return True diff --git a/packages/sshecret-backend/migrations/versions/1657c5d25d2c_implement_db_structures_for_internal_.py b/packages/sshecret-backend/migrations/versions/1657c5d25d2c_implement_db_structures_for_internal_.py new file mode 100644 index 0000000..f18b425 --- /dev/null +++ b/packages/sshecret-backend/migrations/versions/1657c5d25d2c_implement_db_structures_for_internal_.py @@ -0,0 +1,33 @@ +"""Implement db structures for internal password manager + +Revision ID: 1657c5d25d2c +Revises: b4e135ff347a +Create Date: 2025-06-21 07:22:17.792528 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '1657c5d25d2c' +down_revision: Union[str, None] = 'b4e135ff347a' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('client', sa.Column('is_system', sa.Boolean(), nullable=False, default=False, server_default="0")) + op.add_column('client_secret', sa.Column('is_system', sa.Boolean(), nullable=False, default=False, server_default="0")) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('client_secret', 'is_system') + op.drop_column('client', 'is_system') diff --git a/packages/sshecret-backend/migrations/versions/71f7272a6ee1_remove_secret_key_from_password_database.py b/packages/sshecret-backend/migrations/versions/71f7272a6ee1_remove_secret_key_from_password_database.py new file mode 100644 index 0000000..7ad492c --- /dev/null +++ b/packages/sshecret-backend/migrations/versions/71f7272a6ee1_remove_secret_key_from_password_database.py @@ -0,0 +1,44 @@ +"""Remove secret key from password database + +Revision ID: 71f7272a6ee1 +Revises: 1657c5d25d2c +Create Date: 2025-06-22 18:42:53.207334 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '71f7272a6ee1' +down_revision: Union[str, None] = '1657c5d25d2c' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('managed_secret') + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('managed_secret', + sa.Column('id', sa.CHAR(length=32), nullable=False), + sa.Column('name', sa.VARCHAR(), nullable=False), + sa.Column('description', sa.VARCHAR(), nullable=True), + sa.Column('secret', sa.VARCHAR(), nullable=False), + sa.Column('client_id', sa.CHAR(length=32), nullable=True), + sa.Column('deleted', sa.BOOLEAN(), nullable=False), + sa.Column('created_at', sa.DATETIME(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False), + sa.Column('updated_at', sa.DATETIME(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), + sa.Column('deleted_at', sa.DATETIME(), nullable=True), + sa.ForeignKeyConstraint(['client_id'], ['client.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id') + ) + # ### end Alembic commands ### diff --git a/packages/sshecret-backend/src/sshecret_backend/api/clients/operations.py b/packages/sshecret-backend/src/sshecret_backend/api/clients/operations.py index 8848711..59c8fbd 100644 --- a/packages/sshecret-backend/src/sshecret_backend/api/clients/operations.py +++ b/packages/sshecret-backend/src/sshecret_backend/api/clients/operations.py @@ -107,8 +107,7 @@ class ClientOperations: return ClientView.from_client(db_client) async def create_client( - self, - create_model: ClientCreate, + self, create_model: ClientCreate, system_client: bool = False ) -> ClientView: """Create a new client.""" existing_id = await self.get_client_id(FlexID.name(create_model.name)) @@ -117,6 +116,15 @@ class ClientOperations: status_code=400, detail="Error: A client already exists with this name." ) client = create_model.to_client() + if system_client: + statement = query_active_clients().where(Client.is_system.is_(True)) + results = await self.session.scalars(statement) + other_system_clients = results.all() + if other_system_clients: + raise HTTPException( + status_code=400, detail="Only one system client may exist" + ) + client.is_system = True self.session.add(client) await self.session.flush() await self.session.commit() @@ -246,6 +254,15 @@ class ClientOperations: return ClientPolicyView.from_client(db_client) + async def get_system_client(self) -> ClientView: + """Get the system client, if it exists.""" + statement = query_active_clients().where(Client.is_system.is_(True)) + result = await self.session.scalars(statement) + client = result.first() + if not client: + raise HTTPException(status_code=404, detail="No system client registered") + return ClientView.from_client(client) + def resolve_order(statement: Select[Any], order_by: str, reversed: bool) -> Select[Any]: """Resolve ordering.""" @@ -261,12 +278,13 @@ def resolve_order(statement: Select[Any], order_by: str, reversed: bool) -> Sele statement = statement.order_by(column.desc()) else: statement = statement.order_by(column.asc()) - #FIXME: Remove + # FIXME: Remove LOG.info("Ordered by %s (%r)", order_by, reversed) return statement LOG.warning("Unsupported order field: %s", order_by) return statement + def filter_client_statement( statement: Select[Any], params: ClientListParams, ignore_limits: bool = False ) -> Select[Any]: @@ -299,6 +317,7 @@ async def get_clients( .select_from(Client) .where(Client.is_deleted.is_not(True)) .where(Client.is_active.is_not(False)) + .where(Client.is_system.is_not(True)) ) count_statement = cast( Select[tuple[int]], @@ -307,7 +326,8 @@ async def get_clients( total_results = (await session.scalars(count_statement)).one() - statement = filter_client_statement(query_active_clients(), filter_query, False) + statement = query_active_clients().where(Client.is_system.is_not(True)) + statement = filter_client_statement(statement, filter_query, False) results = await session.scalars(statement) remainder = total_results - filter_query.offset - filter_query.limit diff --git a/packages/sshecret-backend/src/sshecret_backend/api/clients/router.py b/packages/sshecret-backend/src/sshecret_backend/api/clients/router.py index c3eba04..16eb3b7 100644 --- a/packages/sshecret-backend/src/sshecret_backend/api/clients/router.py +++ b/packages/sshecret-backend/src/sshecret_backend/api/clients/router.py @@ -46,6 +46,25 @@ def create_client_router(get_db_session: AsyncDBSessionDep) -> APIRouter: client_op = ClientOperations(session, request) return await client_op.create_client(client) + @router.get("/internal/system_client/", include_in_schema=False) + async def get_system_client( + request: Request, + session: Annotated[AsyncSession, Depends(get_db_session)], + ) -> ClientView: + """Get the system client.""" + client_op = ClientOperations(session, request) + return await client_op.get_system_client() + + @router.post("/internal/system_client/", include_in_schema=False) + async def create_system_client( + request: Request, + client: ClientCreate, + session: Annotated[AsyncSession, Depends(get_db_session)], + ) -> ClientView: + """Create system client.""" + client_op = ClientOperations(session, request) + return await client_op.create_client(client, system_client=True) + @router.get("/clients/{client_identifier}") async def fetch_client_by_name( request: Request, diff --git a/packages/sshecret-backend/src/sshecret_backend/api/secrets/operations.py b/packages/sshecret-backend/src/sshecret_backend/api/secrets/operations.py index 194bbfa..4415445 100644 --- a/packages/sshecret-backend/src/sshecret_backend/api/secrets/operations.py +++ b/packages/sshecret-backend/src/sshecret_backend/api/secrets/operations.py @@ -242,7 +242,7 @@ async def resolve_client_secret_clients( # Ensure we don't create the object before we have at least one client. clients = ClientSecretDetailList(name=name) clients.ids.append(str(client_secret.id)) - if client_secret.client: + if client_secret.client and not client_secret.client.is_system: clients.clients.append( ClientReference( id=str(client_secret.client.id), name=client_secret.client.name diff --git a/packages/sshecret-backend/src/sshecret_backend/db.py b/packages/sshecret-backend/src/sshecret_backend/db.py index 4fd8290..0fa935a 100644 --- a/packages/sshecret-backend/src/sshecret_backend/db.py +++ b/packages/sshecret-backend/src/sshecret_backend/db.py @@ -110,7 +110,6 @@ def get_async_engine(url: URL, echo: bool = False, **engine_kwargs: str) -> Asyn """Get an async engine.""" engine = create_async_engine(url, echo=echo, **engine_kwargs) if url.drivername.startswith("sqlite+"): - @event.listens_for(engine.sync_engine, "connect") def set_sqlite_pragma( dbapi_connection: sqlite3.Connection, _connection_record: object diff --git a/packages/sshecret-backend/src/sshecret_backend/models.py b/packages/sshecret-backend/src/sshecret_backend/models.py index 21d8ae2..193f434 100644 --- a/packages/sshecret-backend/src/sshecret_backend/models.py +++ b/packages/sshecret-backend/src/sshecret_backend/models.py @@ -67,6 +67,7 @@ class Client(Base): is_active: Mapped[bool] = mapped_column(sa.Boolean, default=True) is_deleted: Mapped[bool] = mapped_column(sa.Boolean, default=False) + is_system: Mapped[bool] = mapped_column(sa.Boolean, default=False) created_at: Mapped[datetime] = mapped_column( sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False @@ -141,6 +142,8 @@ class ClientSecret(Base): client: Mapped[Client] = relationship(back_populates="secrets") deleted: Mapped[bool] = mapped_column(default=False) + is_system: Mapped[bool] = mapped_column(sa.Boolean, default=False) + created_at: Mapped[datetime] = mapped_column( sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False ) diff --git a/packages/sshecret-sshd/src/sshecret_sshd/commands/shelldriver.py b/packages/sshecret-sshd/src/sshecret_sshd/commands/shelldriver.py index ae3bd6f..446ea5b 100644 --- a/packages/sshecret-sshd/src/sshecret_sshd/commands/shelldriver.py +++ b/packages/sshecret-sshd/src/sshecret_sshd/commands/shelldriver.py @@ -140,11 +140,23 @@ class ShellStoreSecret(CommandDispatcher): secret=secret_name, ) + await self.store_managed_secret(secret_name, secret_data) + def encrypt_secret(self, value: str) -> str: """Encrypt a secret.""" public_key = load_public_key(self.client.public_key.encode()) return encrypt_string(value, public_key) + async def store_managed_secret(self, secret_name: str, secret_data: str) -> None: + """Store managed secret.""" + system_client = await self.backend.get_system_client() + if not system_client: + return + public_key = load_public_key(system_client.public_key.encode()) + encrypted = encrypt_string(secret_data, public_key) + await self.backend.create_client_secret(("id", str(system_client.id)), secret_name, encrypted) + await self.audit(operation=Operation.CREATE, message="Managed secret entry created.", secret=secret_name) + async def get_secret_on_stdin(self) -> str | None: """Get secret from stdin.""" secret_data = "" diff --git a/src/sshecret/backend/api.py b/src/sshecret/backend/api.py index 03a1fb1..441ac80 100644 --- a/src/sshecret/backend/api.py +++ b/src/sshecret/backend/api.py @@ -6,6 +6,7 @@ admin and sshd library do not need to implement the same import logging from typing import Any, Literal, Self, override +import uuid import httpx from pydantic import TypeAdapter @@ -325,6 +326,28 @@ class SshecretBackend(BaseBackend): path = "/api/v1/clients/" response = await self._post(path, json=data) + async def create_system_client(self, name: str, public_key: str) -> Client: + """Create system client.""" + if not validate_public_key(public_key): + raise BackendValidationError("Error: Invalid public key format.") + + data = { + "name": name, + "public_key": public_key, + "description": "Internal system client", + } + path = "/api/v1/internal/system_client/" + response = await self._post(path, json=data) + return Client.model_validate(response.json()) + + async def get_system_client(self) -> Client | None: + """Get the system client.""" + path = "/api/v1/internal/system_client/" + response = await self._get(path) + if response.status_code == 404: + return None + return Client.model_validate(response.json()) + async def get_clients(self, filter: ClientFilter | None = None) -> list[Client]: """Get all clients.""" clients: list[Client] = [] @@ -375,7 +398,7 @@ class SshecretBackend(BaseBackend): async def create_client_secret( self, client_idname: KeySpec, secret_name: str, encrypted_secret: str - ) -> None: + ) -> ClientSecret: """Create a secret. This will overwrite any existing secret with that name. @@ -383,6 +406,8 @@ class SshecretBackend(BaseBackend): client_key = _key(client_idname) path = f"api/v1/clients/{client_key}/secrets/{secret_name}" response = await self._put(path, json={"value": encrypted_secret}) + secret = ClientSecret.model_validate(response.json()) + return secret async def get_client_secret( self, client_idname: KeySpec, secret_idname: KeySpec diff --git a/tests/helpers.py b/tests/helpers.py index e834c64..ae270d2 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -8,14 +8,14 @@ from pathlib import Path from sqlmodel import Session, create_engine from sshecret.crypto import generate_private_key, write_private_key from sshecret_admin.auth.authentication import hash_password -from sshecret_admin.auth.models import AuthProvider, User, init_db +from sshecret_admin.auth.models import AuthProvider, User, Base from sshecret_admin.core.settings import AdminServerSettings def create_test_admin_user(settings: AdminServerSettings, username: str, password: str) -> None: """Create a test admin user.""" hashed_password = hash_password(password) engine = create_engine(settings.admin_db) - init_db(engine) + Base.metadata.create_all(engine) with Session(engine) as session: user = User(username=username, hashed_password=hashed_password, provider=AuthProvider.LOCAL, email="test@test.com") session.add(user) diff --git a/tests/integration/admin/__init__.py b/tests/integration/admin/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/tests/integration/admin/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/integration/admin/base.py b/tests/integration/admin/base.py new file mode 100644 index 0000000..c69f228 --- /dev/null +++ b/tests/integration/admin/base.py @@ -0,0 +1,67 @@ +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager + +import httpx + +from sshecret.backend import Client + +from sshecret.crypto import generate_private_key, generate_public_key_string + +from ..types import AdminServer + + +def make_test_key() -> str: + """Generate a test key.""" + private_key = generate_private_key() + return generate_public_key_string(private_key.public_key()) + + +class BaseAdminTests: + """Base admin test class.""" + + @asynccontextmanager + async def http_client( + self, admin_server: AdminServer, authenticate: bool = True + ) -> AsyncIterator[httpx.AsyncClient]: + """Run a client towards the admin rest api.""" + admin_url, credentials = admin_server + username, password = credentials + headers: dict[str, str] | None = None + if authenticate: + async with httpx.AsyncClient(base_url=admin_url) as client: + + response = await client.post( + "api/v1/token", data={"username": username, "password": password} + ) + assert response.status_code == 200 + data = response.json() + assert "access_token" in data + token = data["access_token"] + headers = {"Authorization": f"Bearer {token}"} + + async with httpx.AsyncClient(base_url=admin_url, headers=headers) as client: + yield client + + async def create_client( + self, + admin_server: AdminServer, + name: str, + public_key: str | None = None, + ) -> Client: + """Create a client.""" + if not public_key: + public_key = make_test_key() + + new_client = { + "name": name, + "public_key": public_key, + "sources": ["192.0.2.0/24"], + } + + async with self.http_client(admin_server, True) as http_client: + response = await http_client.post("api/v1/clients/", json=new_client) + assert response.status_code == 200 + data = response.json() + client = Client.model_validate(data) + + return client diff --git a/tests/integration/test_admin_api.py b/tests/integration/admin/test_admin_api.py similarity index 75% rename from tests/integration/test_admin_api.py rename to tests/integration/admin/test_admin_api.py index b299427..6bf3b9c 100644 --- a/tests/integration/test_admin_api.py +++ b/tests/integration/admin/test_admin_api.py @@ -1,76 +1,12 @@ """Tests of the admin interface.""" -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager import allure import pytest -import httpx - from allure_commons.types import Severity -from sshecret.backend import Client - -from sshecret.crypto import generate_private_key, generate_public_key_string - -from .types import AdminServer - - -def make_test_key() -> str: - """Generate a test key.""" - private_key = generate_private_key() - return generate_public_key_string(private_key.public_key()) - - -class BaseAdminTests: - """Base admin test class.""" - - @asynccontextmanager - async def http_client( - self, admin_server: AdminServer, authenticate: bool = True - ) -> AsyncIterator[httpx.AsyncClient]: - """Run a client towards the admin rest api.""" - admin_url, credentials = admin_server - username, password = credentials - headers: dict[str, str] | None = None - if authenticate: - async with httpx.AsyncClient(base_url=admin_url) as client: - - response = await client.post( - "api/v1/token", data={"username": username, "password": password} - ) - assert response.status_code == 200 - data = response.json() - assert "access_token" in data - token = data["access_token"] - headers = {"Authorization": f"Bearer {token}"} - - async with httpx.AsyncClient(base_url=admin_url, headers=headers) as client: - yield client - - async def create_client( - self, - admin_server: AdminServer, - name: str, - public_key: str | None = None, - ) -> Client: - """Create a client.""" - if not public_key: - public_key = make_test_key() - - new_client = { - "name": name, - "public_key": public_key, - "sources": ["192.0.2.0/24"], - } - - async with self.http_client(admin_server, True) as http_client: - response = await http_client.post("api/v1/clients/", json=new_client) - assert response.status_code == 200 - data = response.json() - client = Client.model_validate(data) - - return client +from ..types import AdminServer +from .base import BaseAdminTests @allure.title("Admin API") @@ -196,7 +132,9 @@ class TestAdminApiSecrets(BaseAdminTests): assert "testclient" in data["clients"] @allure.title("Test adding a secret with automatic value") - @allure.description("Test that we can add a secret where we let the system come up with the value of a given length.") + @allure.description( + "Test that we can add a secret where we let the system come up with the value of a given length." + ) @pytest.mark.asyncio async def test_add_secret_auto(self, admin_server: AdminServer) -> None: """Test adding a secret with an auto-generated value.""" diff --git a/tests/integration/admin/test_secret_manager.py b/tests/integration/admin/test_secret_manager.py new file mode 100644 index 0000000..73a3237 --- /dev/null +++ b/tests/integration/admin/test_secret_manager.py @@ -0,0 +1,634 @@ +"""Test secret manager. + +This package tests the rewritten secret manager system. + +This is technically an integration test, as it requires the other subsystems to +run, but it uses the internal API rather than the exposed routes. +""" + +import allure +import pytest +import pytest_asyncio + +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +from sshecret_admin.core.settings import AdminServerSettings +from sshecret_admin.services.models import SecretGroup +from sshecret_admin.services.secret_manager import ( + password_manager_context, + AsyncSecretContext, + InvalidSecretNameError, + InvalidGroupNameError, +) +from sshecret_admin.auth.models import Base, PasswordDB +from sshecret_admin.services.master_password import setup_master_password + +# -------- global parameter sets start here -------- # + + +# -------- Fixtures start here -------- # + + +@pytest_asyncio.fixture(autouse=True) +async def create_admin_db(admin_server_settings: AdminServerSettings) -> None: + """Create the database.""" + engine = create_engine(admin_server_settings.admin_db) + Base.metadata.create_all(engine) + encr_master_password = setup_master_password( + settings=admin_server_settings, regenerate=True + ) + + with Session(engine) as session: + pwdb = PasswordDB(id=1, encrypted_password=encr_master_password) + session.add(pwdb) + session.commit() + + +@pytest_asyncio.fixture() +async def secrets_manager(admin_server_settings: AdminServerSettings): + """Test that the context manager can be created.""" + async with password_manager_context( + admin_server_settings, "TEST", "127.0.0.1" + ) as manager: + yield manager + + +# -------- Tests start here -------- # + + +@allure.title("Adding entries") +@pytest.mark.parametrize("name,secret", [("testentry", "testsecret")]) +class TestSecretsAddEntry: + """Tests for the add_entry method.""" + + @pytest.mark.asyncio + async def test_add_entry( + self, + secrets_manager: AsyncSecretContext, + name: str, + secret: str, + ) -> None: + """Test add entry. + + This tests add_entry and get_secret + """ + await secrets_manager.add_entry(name, secret) + stored_secret = await secrets_manager.get_secret(name) + assert stored_secret == secret + + async def test_add_entry_duplicate( + self, + secrets_manager: AsyncSecretContext, + name: str, + secret: str, + ) -> None: + """Test adding an entry twice.""" + await secrets_manager.add_entry(name, secret) + stored_secret = await secrets_manager.get_secret(name) + assert stored_secret == secret + + with pytest.raises(InvalidSecretNameError): + await secrets_manager.add_entry(name, secret) + + async def test_add_entry_with_group( + self, + secrets_manager: AsyncSecretContext, + name: str, + secret: str, + ) -> None: + """Test adding a secret with a group.""" + group = "testgroup" + await secrets_manager.add_group(group) + await secrets_manager.add_entry(name, secret, group_path=group) + result = await secrets_manager.get_entry_group(name) + assert result == group + + async def test_add_entry_with_nonexisting_group( + self, + secrets_manager: AsyncSecretContext, + name: str, + secret: str, + ) -> None: + """Test adding a secret where the group does not exist.""" + group = "testgroup" + with pytest.raises(InvalidGroupNameError): + await secrets_manager.add_entry(name, secret, group_path=group) + + async def test_add_entry_with_deep_path( + self, + secrets_manager: AsyncSecretContext, + name: str, + secret: str, + ) -> None: + """Test adding a secret to a nested group with a path specification.""" + await secrets_manager.add_group("root") + await secrets_manager.add_group("nested", parent_group="root") + await secrets_manager.add_entry(name, secret, group_path="/root/nested") + group = await secrets_manager.get_secret_group("/root/nested") + assert group is not None + assert name in group.entries + + async def test_overwrite_secret( + self, secrets_manager: AsyncSecretContext, name: str, secret: str + ) -> None: + """Test overwriting a secret.""" + await secrets_manager.add_entry(name, secret) + stored_secret = await secrets_manager.get_secret(name) + assert stored_secret == secret + + new_secret = "newsecret" + await secrets_manager.add_entry(name, new_secret, overwrite=True) + + stored_secret = await secrets_manager.get_secret(name) + assert stored_secret == new_secret + + +@allure.title("Creating groups") +class TestSecretGroupCreation: + """Test secret groups.""" + + @pytest.mark.parametrize( + "group_name", ["testgroup", "long group name with spaces", "blåbærgrød"] + ) + @allure.title("Add a group name {group_name}") + @pytest.mark.asyncio + async def test_add_group( + self, + secrets_manager: AsyncSecretContext, + group_name: str, + ) -> None: + """Get adding a group.""" + await secrets_manager.add_group(group_name) + groups = await secrets_manager.get_secret_groups() + assert len(groups) == 1 + assert groups[0].name == group_name + + @pytest.mark.asyncio + async def test_add_group_with_parent( + self, secrets_manager: AsyncSecretContext + ) -> None: + """Test add a group with a parent group.""" + parent_name = "parent" + child_name = "child" + await secrets_manager.add_group(parent_name) + await secrets_manager.add_group(child_name, parent_group=parent_name) + parent_group = await secrets_manager.get_secret_group(f"/parent") + assert parent_group is not None + assert len(parent_group.children) == 1 + assert parent_group.children[0].name == child_name + + child_group = await secrets_manager.get_secret_group("/parent/child") + assert child_group is not None + assert child_group.name == child_name + assert child_group.parent_group is not None + assert child_group.parent_group.name == parent_name + assert len(child_group.children) == 0 + + @pytest.mark.asyncio + async def test_add_group_as_path(self, secrets_manager: AsyncSecretContext) -> None: + """Add a nested group with path annotation.""" + parent_name = "parent" + child_path = "/parent/child" + await secrets_manager.add_group(parent_name) + await secrets_manager.add_group(child_path) + parent_group = await secrets_manager.get_secret_group(f"/parent") + assert parent_group is not None + assert len(parent_group.children) == 1 + assert parent_group.children[0].name == "child" + + child_group = await secrets_manager.get_secret_group(child_path) + assert child_group is not None + + @pytest.mark.asyncio + async def test_overlapping_names(self, secrets_manager: AsyncSecretContext) -> None: + """Test having overlapping names in different groups.""" + await secrets_manager.add_group("root") + with pytest.raises(InvalidGroupNameError): + await secrets_manager.add_group("/root") + + await secrets_manager.add_group("/root/root") + + group = await secrets_manager.get_secret_group("/root/root") + assert group is not None + assert group.name == "root" + + @pytest.mark.asyncio + async def test_add_group_with_nonexisting_parent( + self, secrets_manager: AsyncSecretContext + ) -> None: + """Test adding a group with a nonexisting parent.""" + with pytest.raises(InvalidGroupNameError): + await secrets_manager.add_group("orphan", parent_group="unknown") + + @pytest.mark.asyncio + async def test_add_duplicate_group( + self, secrets_manager: AsyncSecretContext + ) -> None: + """Test adding the same group twice.""" + await secrets_manager.add_group("snowflake") + with pytest.raises(InvalidGroupNameError): + await secrets_manager.add_group("snowflake") + + @pytest.mark.parametrize( + "group_name,description", [("testgroup", "test description")] + ) + @allure.title("Add a group name {group_name} with description {description}") + @pytest.mark.asyncio + async def test_add_group_with_description( + self, + secrets_manager: AsyncSecretContext, + group_name: str, + description: str, + ) -> None: + """Test adding a group with description.""" + await secrets_manager.add_group(group_name, description) + result = await secrets_manager.get_secret_group(group_name) + assert result is not None + assert result.description == description + + +@pytest.mark.parametrize( + "groups", + [ + [ + ("root", None, "root"), + ("level1", "root", "/root/level1"), + ("level2", "level1", "/root/level1/level2"), + ], + [("flat1", None, "flat1"), ("flat2", None, "flat2")], + [ + ("stub", None, "stub"), + ("root", None, "root"), + ("nested", "root", "/root/nested"), + ], + ], +) +@allure.title("Listing groups") +class TestSecretGroupListing: + """Tests for listing groups.""" + + @pytest_asyncio.fixture(autouse=True) + async def create_groups( + self, + secrets_manager: AsyncSecretContext, + groups: list[tuple[str, str | None, str]], + ) -> None: + """Pre-create groups.""" + for name, parent_group, _path in groups: + await secrets_manager.add_group(name, parent_group=parent_group) + + @pytest.mark.asyncio + async def test_get_secret_groups_list( + self, + secrets_manager: AsyncSecretContext, + groups: list[tuple[str, str | None, str]], + ) -> None: + """Test the flat get_secret_groups_list.""" + # Create three levels of content + group_list = await secrets_manager.get_secret_group_list() + assert len(group_list) == len(groups) + + @pytest.mark.asyncio + async def test_get_secret_groups( + self, + secrets_manager: AsyncSecretContext, + groups: list[tuple[str, str | None, str]], + ) -> None: + """Test the tree-oriented get_secretsgroups.""" + group_map = dict([(group[0], group[1]) for group in sorted(groups)]) + group_tree = await secrets_manager.get_secret_groups() + root_groups = [key for key, value in group_map.items() if value is None] + assert len(group_tree) == len(root_groups) + reconstructed_groups: list[tuple[str, str | None]] = [] + + def crawl_tree(item: SecretGroup) -> list[tuple[str, str | None]]: + """Crawl a tree recursively.""" + parent_group_name = None + if item.parent_group: + parent_group_name = item.parent_group.name + items: list[tuple[str, str | None]] = [(item.name, parent_group_name)] + for child in item.children: + items.extend(crawl_tree(child)) + + return items + + for item in group_tree: + reconstructed_groups.extend(crawl_tree(item)) + + assert dict(sorted(reconstructed_groups)) == group_map + + @pytest.mark.asyncio + async def test_get_secret_groups_with_secrets( + self, + secrets_manager: AsyncSecretContext, + groups: list[tuple[str, str | None, str]], + ) -> None: + """Test fetching groups where there are secrets in all groups.""" + # We will create exactly two secrets in each group. + for group_name, _parent, path in groups: + await secrets_manager.add_entry( + f"{group_name}_1", f"{group_name}_secret_1", group_path=path + ) + await secrets_manager.add_entry( + f"{group_name}_2", f"{group_name}_secret_2", group_path=path + ) + + group_list = await secrets_manager.get_secret_group_list() + for group in group_list: + assert len(group.entries) == 2 + assert group.entries[0] == f"{group.name}_1" + assert group.entries[1] == f"{group.name}_2" + + +@pytest.mark.parametrize( + "query,expected,groups,children", + [ + ("MATCH", 1, [("MATCH", None), ("SOMETHINGELSE", None)], 0), + ("MATCH", 1, [("root", None), ("MATCH", "root"), ("SOMETHINGELSE", "root")], 0), + ("MATCH", 3, [("MATCH1", None), ("MATCH2", None), ("MATCH3", None)], 0), + ("MATCH", 1, [("root", None), ("MATCH", "root"), ("CHILD", "MATCH")], 1), + ( + "NOMATCH", + 0, + [("foo", None), ("bar", None), ("foobar", "foo"), ("barfoo", "bar")], + 0, + ), + ], +) +@allure.title("Searching in groups using patterns") +class TestGroupSearchPattern: + + @pytest.mark.asyncio + async def test_group_list_pattern( + self, + secrets_manager: AsyncSecretContext, + query: str, + expected: int, + groups: list[tuple[str, str | None]], + children: int, + ) -> None: + """Test matching a pattern.""" + for name, parent_group in groups: + await secrets_manager.add_group(name, parent_group=parent_group) + + result = await secrets_manager.get_secret_group_list(pattern=query, regex=False) + assert len(result) == expected + + @pytest.mark.asyncio + async def test_group_tree_pattern( + self, + secrets_manager: AsyncSecretContext, + query: str, + expected: int, + groups: list[tuple[str, str | None]], + children: int, + ) -> None: + """Test matching a pattern with a tree result.""" + for name, parent_group in groups: + await secrets_manager.add_group(name, parent_group=parent_group) + + result = await secrets_manager.get_secret_groups(pattern=query, regex=False) + assert len(result) == expected + if expected == 1 and children > 0: + assert len(result[0].children) == children + + +@allure.title("Modifying groups") +class TestGroupModification: + """Test modifying groups.""" + + @pytest.mark.parametrize( + "group_name,parent,description", + [("test", None, "test description"), ("test", "root", "test_description")], + ) + @pytest.mark.asyncio + async def test_set_group_description( + self, + secrets_manager: AsyncSecretContext, + group_name: str, + parent: str | None, + description: str, + ) -> None: + """Test setting a description on a group.""" + if parent: + await secrets_manager.add_group(parent) + await secrets_manager.add_group(group_name, parent_group=parent) + path = group_name + if parent: + path = f"/{parent}/{group_name}" + group = await secrets_manager.get_secret_group(path) + assert group is not None + assert group.description is None + + await secrets_manager.set_group_description(path, description) + group = await secrets_manager.get_secret_group(path) + assert group is not None + assert group.description == description + + @pytest.mark.parametrize( + "groups,target_group, expected_path", + [ + ( + [("root", None), ("test", None)], + ("test", "root"), + "/root/test", + ), + ( + [("root", None), ("test", "root")], + ("/root/test", None), + "test", + ), + ([("test", None)], ("test", None), "test"), + ], + ) + @pytest.mark.asyncio + async def test_move_group( + self, + secrets_manager: AsyncSecretContext, + groups: list[tuple[str, str | None]], + target_group: tuple[str, str | None], + expected_path: str, + ) -> None: + """Test moving groups around.""" + for group_name, parent_name in groups: + await secrets_manager.add_group(group_name, parent_group=parent_name) + + group_name, target = target_group + await secrets_manager.move_group(group_name, target) + group = await secrets_manager.get_secret_group(expected_path) + assert group is not None + + +@allure.title("Deleting items") +class TestSecretManagerDeletions: + """Test secret manager deletions.""" + + @pytest_asyncio.fixture(autouse=True) + async def create_test_data(self, secrets_manager: AsyncSecretContext) -> None: + """Create some test data.""" + groups = [ + ("root", None, "root"), + ("level1", "root", "/root/level1"), + ("level2", "level1", "/root/level1/level2"), + ] + for n in range(2): + await secrets_manager.add_entry(f"ungrouped_{n}", "secret") + for group_name, parent_name, path in groups: + await secrets_manager.add_group(group_name, parent_group=parent_name) + for n in range(2): + await secrets_manager.add_entry( + f"{group_name}_{n}", "secret", group_path=path + ) + + @pytest.mark.parametrize( + "name,group_name", + [("root_1", "root"), ("level1_0", "/root/level1"), ("ungrouped_1", None)], + ) + @allure.title("Delete secret {name in group {group_name}}") + @pytest.mark.asyncio + async def test_secret_deletion( + self, secrets_manager: AsyncSecretContext, name: str, group_name: str | None + ) -> None: + """Test secret deletion.""" + if group_name: + group = await secrets_manager.get_secret_group(group_name) + assert group is not None + assert name in group.entries + + secret = await secrets_manager.get_secret(name) + assert secret is not None + + await secrets_manager.delete_entry(name) + + secret = await secrets_manager.get_secret(name) + assert secret is None + + if group_name: + group = await secrets_manager.get_secret_group(group_name) + assert group is not None + assert name not in group.entries + + @pytest.mark.parametrize("name", ["NONEXISTING"]) + @allure.title("Deleting non-existing entry {name}") + @pytest.mark.asyncio + async def test_nonexisting_entry( + self, secrets_manager: AsyncSecretContext, name: str + ) -> None: + """Test deleting something that doesn't exist.""" + secret = await secrets_manager.get_secret(name) + assert secret is None + # Deleting something that is already deleted returns None + await secrets_manager.delete_entry(name) + + @pytest.mark.parametrize("path", ["/root/level1"]) + @allure.title("Deleting group {path}") + async def test_group_delete( + self, secrets_manager: AsyncSecretContext, path: str + ) -> None: + """Test deleting a group.""" + group = await secrets_manager.get_secret_group(path) + assert group is not None + entries = list(group.entries) + await secrets_manager.delete_group(path) + group = await secrets_manager.get_secret_group(path) + assert group is None + + for name in entries: + new_grouping = await secrets_manager.get_entry_group(name) + assert new_grouping is None + + +@allure.title("Other tests") +class TestSecretManagerOther: + """Uncategorized tests to standardize module.""" + + @pytest.mark.asyncio + async def test_get_secret_nonexisting( + self, secrets_manager: AsyncSecretContext + ) -> None: + """Test get_secret with invalid name.""" + result = await secrets_manager.get_secret("NOMATCH") + assert result is None + + @pytest.mark.parametrize( + "group_name,num_grouped,num_ungrouped", + [("GROUP", 3, 3), ("GROUP", 3, 0), ("GROUP", 0, 0)], + ) + @pytest.mark.asyncio + async def test_get_ungrouped_secrets( + self, + secrets_manager: AsyncSecretContext, + group_name: str, + num_grouped: int, + num_ungrouped: int, + ) -> None: + """Test get_ungrouped_secrets.""" + await secrets_manager.add_group(group_name) + for n in range(num_ungrouped): + await secrets_manager.add_entry(f"ungrouped_{n}", "secret") + + for n in range(num_grouped): + await secrets_manager.add_entry( + f"grouped_{n}", "secret", group_path="GROUP" + ) + + ungrouped = await secrets_manager.get_ungrouped_secrets() + assert len(ungrouped) == num_ungrouped + matching = [entry for entry in ungrouped if entry.startswith("ungrouped_")] + assert len(matching) == num_ungrouped + + grouped_secrets = await secrets_manager.get_available_secrets( + group_path=group_name + ) + assert len(grouped_secrets) == num_grouped + all_secrets = await secrets_manager.get_available_secrets() + assert len(all_secrets) == (num_ungrouped + num_grouped) + + @pytest.mark.parametrize( + "entries", [[("test1", "secret1"), ("test2", "secret2")], []] + ) + @pytest.mark.asyncio + async def test_get_available_secrets( + self, secrets_manager: AsyncSecretContext, entries: list[tuple[str, str]] + ) -> None: + """Test the get_available_secrets method.""" + for name, secret in entries: + await secrets_manager.add_entry(name, secret) + + entry_names = [entry[0] for entry in entries] + response = await secrets_manager.get_available_secrets() + assert len(response) == len(entries) + + assert sorted(response) == sorted(entry_names) + + async def test_get_secret_groups_none( + self, secrets_manager: AsyncSecretContext + ) -> None: + """Test get_secret_groups with no groups created.""" + result = await secrets_manager.get_secret_groups() + assert len(result) == 0 + result_flat = await secrets_manager.get_secret_group_list() + assert len(result_flat) == 0 + + @allure.title("Search for a group using regular expression") + async def test_group_regex_search( + self, secrets_manager: AsyncSecretContext + ) -> None: + """Search for entries with regular expressions.""" + groups = [ + "test1", + "test2", + "other", + "somethingelse", + ] + + for group in groups: + await secrets_manager.add_group(group) + + results = await secrets_manager.get_secret_group_list( + pattern="^test", regex=True + ) + assert len(results) == 2 + for group in results: + assert group.name.startswith("test") diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 0cc9c2c..974b2f3 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -79,6 +79,32 @@ async def run_backend_server(test_ports: TestPorts): await server_task +@pytest_asyncio.fixture( + scope=TEST_SCOPE, name="admin_server_settings", loop_scope=LOOP_SCOPE +) +async def get_admin_server_settings( + test_ports: TestPorts, backend_server: tuple[str, str] +): + """Get admin server settings.""" + backend_url, backend_token = backend_server + port = test_ports.admin + secret_key = secrets.token_urlsafe(32) + with in_tempdir() as admin_work_path: + admin_db = admin_work_path / "ssh_admin.db" + admin_settings = AdminServerSettings.model_validate( + { + "sshecret_backend_url": backend_url, + "backend_token": backend_token, + "secret_key": secret_key, + "listen_address": "0.0.0.0", + "port": port, + "database": str(admin_db.absolute()), + "password_manager_directory": str(admin_work_path.absolute()), + } + ) + yield admin_settings + + @pytest_asyncio.fixture(scope=TEST_SCOPE, name="admin_server", loop_scope=LOOP_SCOPE) async def run_admin_server(test_ports: TestPorts, backend_server: tuple[str, str]): """Run admin server.""" @@ -98,7 +124,7 @@ async def run_admin_server(test_ports: TestPorts, backend_server: tuple[str, str "password_manager_directory": str(admin_work_path.absolute()), } ) - admin_app = create_admin_app(admin_settings) + admin_app = create_admin_app(admin_settings, create_db=True) config = uvicorn.Config(app=admin_app, port=port, loop="asyncio") server = uvicorn.Server(config=config) server_task = asyncio.create_task(server.serve())