Write new secret manager using existing RSA logic
This commit is contained in:
@ -1,11 +1,11 @@
|
||||
import os
|
||||
from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy import Engine, engine_from_config, pool, create_engine
|
||||
|
||||
from alembic import context
|
||||
from sshecret_admin.auth.models import Base
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
@ -14,11 +14,32 @@ config = context.config
|
||||
|
||||
def get_database_url() -> str | None:
|
||||
"""Get database URL."""
|
||||
try:
|
||||
settings = AdminServerSettings() # pyright: ignore[reportCallIssue]
|
||||
return str(settings.admin_db)
|
||||
except Exception:
|
||||
if db_file := os.getenv("SSHECRET_ADMIN_DATABASE"):
|
||||
return f"sqlite:///{db_file}"
|
||||
return config.get_main_option("sqlalchemy.url")
|
||||
|
||||
|
||||
def get_engine() -> Engine:
|
||||
"""Get engine."""
|
||||
try:
|
||||
settings = AdminServerSettings() # pyright: ignore[reportCallIssue]
|
||||
engine = create_engine(settings.admin_db)
|
||||
return engine
|
||||
except Exception:
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
return connectable
|
||||
|
||||
|
||||
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
@ -68,12 +89,7 @@ def run_migrations_online() -> None:
|
||||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
connectable = get_engine()
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection, target_metadata=target_metadata, render_as_batch=True
|
||||
|
||||
@ -0,0 +1,44 @@
|
||||
"""Implement db structures for internal password manager
|
||||
|
||||
Revision ID: 84356d0ea85f
|
||||
Revises: 6c148590471f
|
||||
Create Date: 2025-06-21 07:21:02.257865
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '84356d0ea85f'
|
||||
down_revision: Union[str, None] = '6c148590471f'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('groups',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=False),
|
||||
sa.Column('parent_id', sa.Uuid(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['parent_id'], ['groups.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
with op.batch_alter_table('password_db', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('client_id', sa.Uuid(), nullable=True))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('password_db', schema=None) as batch_op:
|
||||
batch_op.drop_column('client_id')
|
||||
|
||||
op.drop_table('groups')
|
||||
# ### end Alembic commands ###
|
||||
@ -0,0 +1,48 @@
|
||||
"""Implement managed secrets
|
||||
|
||||
Revision ID: c34707a1ea3a
|
||||
Revises: 84356d0ea85f
|
||||
Create Date: 2025-06-21 07:38:12.994535
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'c34707a1ea3a'
|
||||
down_revision: Union[str, None] = '84356d0ea85f'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('managed_secrets',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=False),
|
||||
sa.Column('is_deleted', sa.Boolean(), nullable=False),
|
||||
sa.Column('group_id', sa.Uuid(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
|
||||
sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ondelete='SET NULL'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
with op.batch_alter_table('groups', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('description', sa.String(), nullable=True))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('groups', schema=None) as batch_op:
|
||||
batch_op.drop_column('description')
|
||||
|
||||
op.drop_table('managed_secrets')
|
||||
# ### end Alembic commands ###
|
||||
@ -5,9 +5,9 @@ import logging
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from sshecret_admin.auth import Token, authenticate_user, create_access_token
|
||||
from sshecret_admin.auth import Token, authenticate_user_async, create_access_token
|
||||
from sshecret_admin.core.dependencies import AdminDependencies
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -19,11 +19,12 @@ def create_router(dependencies: AdminDependencies) -> APIRouter:
|
||||
|
||||
@app.post("/token")
|
||||
async def login_for_access_token(
|
||||
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||
|
||||
session: Annotated[AsyncSession, Depends(dependencies.get_async_session)],
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
) -> Token:
|
||||
"""Login user and generate token."""
|
||||
user = authenticate_user(session, form_data.username, form_data.password)
|
||||
user = await authenticate_user_async(session, form_data.username, form_data.password)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
||||
@ -128,7 +128,7 @@ def create_router(dependencies: AdminDependencies) -> APIRouter:
|
||||
group = await admin.get_secret_group(group_name)
|
||||
if not group:
|
||||
return
|
||||
await admin.delete_secret_group(group_name, keep_entries=True)
|
||||
await admin.delete_secret_group(group_name)
|
||||
|
||||
@app.post("/secrets/groups/{group_name}/{secret_name}")
|
||||
async def move_secret_to_group(
|
||||
|
||||
@ -5,8 +5,9 @@
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from fastapi.security.utils import get_authorization_scheme_param
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@ -57,6 +58,31 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||
raise credentials_exception
|
||||
return user
|
||||
|
||||
def get_client_origin(request: Request) -> str:
|
||||
"""Get client origin."""
|
||||
fallback_origin = "UNKNOWN"
|
||||
if request.client:
|
||||
return request.client.host
|
||||
return fallback_origin
|
||||
|
||||
def get_optional_username(request: Request) -> str | None:
|
||||
"""Get username, if available.
|
||||
|
||||
This is purely used for auditing purposes.
|
||||
"""
|
||||
authorization = request.headers.get("Authorization")
|
||||
scheme, param = get_authorization_scheme_param(authorization)
|
||||
if not authorization or scheme.lower() != "bearer":
|
||||
return None
|
||||
claims = decode_token(dependencies.settings, param)
|
||||
if not claims:
|
||||
return None
|
||||
|
||||
if claims.provider == LOCAL_ISSUER:
|
||||
return claims.sub
|
||||
|
||||
return f"oidc:{claims.email}"
|
||||
|
||||
async def get_current_active_user(
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
) -> User:
|
||||
@ -66,9 +92,12 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||
return current_user
|
||||
|
||||
async def get_admin_backend(
|
||||
request: Request,
|
||||
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||
):
|
||||
"""Get admin backend API."""
|
||||
username = get_optional_username(request)
|
||||
origin = get_client_origin(request)
|
||||
password_db = session.scalars(
|
||||
select(PasswordDB).where(PasswordDB.id == 1)
|
||||
).first()
|
||||
@ -76,7 +105,11 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||
raise HTTPException(
|
||||
500, detail="Error: The password manager has not yet been set up."
|
||||
)
|
||||
admin = AdminBackend(dependencies.settings, password_db.encrypted_password)
|
||||
admin = AdminBackend(
|
||||
dependencies.settings,
|
||||
username=username,
|
||||
origin=origin,
|
||||
)
|
||||
yield admin
|
||||
|
||||
app = APIRouter(prefix=f"/api/{API_VERSION}")
|
||||
|
||||
@ -1,12 +1,13 @@
|
||||
"""Models for authentication."""
|
||||
"""Models for authentication and secret management."""
|
||||
|
||||
import enum
|
||||
from datetime import datetime
|
||||
from typing import override
|
||||
import uuid
|
||||
import sqlalchemy as sa
|
||||
from pydantic import BaseModel
|
||||
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||
|
||||
|
||||
JWT_ALGORITHM = "HS256"
|
||||
@ -75,12 +76,15 @@ class PasswordDB(Base):
|
||||
__tablename__: str = "password_db"
|
||||
|
||||
id: Mapped[int] = mapped_column(sa.INT, primary_key=True)
|
||||
encrypted_password: Mapped[str] = mapped_column(sa.String)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||
)
|
||||
|
||||
client_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
sa.Uuid(as_uuid=True), nullable=True
|
||||
)
|
||||
|
||||
updated_at: Mapped[datetime | None] = mapped_column(
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
@ -88,6 +92,65 @@ class PasswordDB(Base):
|
||||
)
|
||||
|
||||
|
||||
class Group(Base):
|
||||
"""A secret group."""
|
||||
|
||||
__tablename__: str = "groups"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
name: Mapped[str] = mapped_column(sa.String, nullable=False)
|
||||
description: Mapped[str | None] = mapped_column(sa.String, nullable=True)
|
||||
|
||||
parent_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
sa.ForeignKey("groups.id"), nullable=True
|
||||
)
|
||||
parent: Mapped["Group | None"] = relationship(
|
||||
"Group", remote_side=[id], back_populates="children"
|
||||
)
|
||||
children: Mapped[list["Group"]] = relationship(
|
||||
"Group", back_populates="parent", cascade="all, delete"
|
||||
)
|
||||
secrets: Mapped[list["ManagedSecret"]] = relationship(back_populates="group")
|
||||
|
||||
@override
|
||||
def __repr__(self) -> str:
|
||||
return f"<Group id={self.id} name={self.name} parent_id={self.parent_id}>"
|
||||
|
||||
|
||||
class ManagedSecret(Base):
|
||||
"""Managed Secret."""
|
||||
|
||||
__tablename__: str = "managed_secrets"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
name: Mapped[str] = mapped_column(sa.String, nullable=False)
|
||||
|
||||
is_deleted: Mapped[bool] = mapped_column(sa.Boolean, default=False)
|
||||
|
||||
group_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
sa.ForeignKey("groups.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
group: Mapped["Group | None"] = relationship(
|
||||
Group, foreign_keys=[group_id], back_populates="secrets"
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||
)
|
||||
updated_at: Mapped[datetime | None] = mapped_column(
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
onupdate=sa.func.now(),
|
||||
)
|
||||
|
||||
deleted_at: Mapped[datetime | None] = mapped_column(
|
||||
sa.DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
|
||||
class IdentityClaims(BaseModel):
|
||||
"""Normalized identity claim model."""
|
||||
|
||||
@ -125,6 +188,3 @@ class LocalUserInfo(BaseModel):
|
||||
local: bool
|
||||
|
||||
|
||||
def init_db(engine: sa.Engine) -> None:
|
||||
"""Create database."""
|
||||
Base.metadata.create_all(engine)
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
#
|
||||
from collections.abc import AsyncGenerator
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
@ -12,15 +13,15 @@ from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sshecret_backend.db import DatabaseSessionManager
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
|
||||
from sshecret_admin import api, frontend
|
||||
from sshecret_admin.auth.models import PasswordDB, init_db
|
||||
from sshecret_admin.auth.models import Base
|
||||
from sshecret_admin.core.db import setup_database
|
||||
from sshecret_admin.frontend.exceptions import RedirectException
|
||||
from sshecret_admin.services.master_password import setup_master_password
|
||||
from sshecret_admin.services.secret_manager import setup_private_key
|
||||
|
||||
from .dependencies import BaseDependencies
|
||||
from .settings import AdminServerSettings
|
||||
@ -40,44 +41,28 @@ def setup_frontend(app: FastAPI, dependencies: BaseDependencies) -> None:
|
||||
|
||||
|
||||
def create_admin_app(
|
||||
settings: AdminServerSettings, with_frontend: bool = True
|
||||
settings: AdminServerSettings,
|
||||
with_frontend: bool = True,
|
||||
create_db: bool = False,
|
||||
) -> FastAPI:
|
||||
"""Create admin app."""
|
||||
engine, get_db_session = setup_database(settings.admin_db)
|
||||
|
||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Get async session."""
|
||||
session_manager = DatabaseSessionManager(settings.async_db_url)
|
||||
async with session_manager.session() as session:
|
||||
yield session
|
||||
|
||||
def setup_password_manager() -> None:
|
||||
"""Setup password manager."""
|
||||
encr_master_password = setup_master_password(
|
||||
settings=settings, regenerate=False
|
||||
)
|
||||
with Session(engine) as session:
|
||||
existing_password = session.scalars(
|
||||
select(PasswordDB).where(PasswordDB.id == 1)
|
||||
).first()
|
||||
|
||||
if not encr_master_password:
|
||||
if existing_password:
|
||||
LOG.info("Master password already defined.")
|
||||
return
|
||||
# Looks like we have to regenerate it
|
||||
LOG.warning(
|
||||
"Master password was set, but not saved to the database. Regenerating it."
|
||||
)
|
||||
encr_master_password = setup_master_password(
|
||||
settings=settings, regenerate=True
|
||||
)
|
||||
|
||||
assert encr_master_password is not None
|
||||
|
||||
with Session(engine) as session:
|
||||
pwdb = PasswordDB(id=1, encrypted_password=encr_master_password)
|
||||
session.add(pwdb)
|
||||
session.commit()
|
||||
setup_private_key(settings, regenerate=False)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_app: FastAPI):
|
||||
"""Create database before starting the server."""
|
||||
init_db(engine)
|
||||
if create_db:
|
||||
Base.metadata.create_all(engine)
|
||||
setup_password_manager()
|
||||
yield
|
||||
|
||||
@ -109,7 +94,7 @@ def create_admin_app(
|
||||
status_code=status.HTTP_200_OK, content=jsonable_encoder({"status": "LIVE"})
|
||||
)
|
||||
|
||||
dependencies = BaseDependencies(settings, get_db_session)
|
||||
dependencies = BaseDependencies(settings, get_db_session, get_async_session)
|
||||
|
||||
app.include_router(api.create_api_router(dependencies))
|
||||
if with_frontend:
|
||||
|
||||
@ -12,7 +12,7 @@ from pydantic import ValidationError
|
||||
from sqlalchemy import select, create_engine
|
||||
from sqlalchemy.orm import Session
|
||||
from sshecret_admin.auth.authentication import hash_password
|
||||
from sshecret_admin.auth.models import AuthProvider, PasswordDB, User, init_db
|
||||
from sshecret_admin.auth.models import AuthProvider, PasswordDB, User
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
from sshecret_admin.services.admin_backend import AdminBackend
|
||||
|
||||
@ -72,7 +72,6 @@ def cli_create_user(
|
||||
"""Create user."""
|
||||
settings = cast(AdminServerSettings, ctx.obj)
|
||||
engine = create_engine(settings.admin_db)
|
||||
init_db(engine)
|
||||
with Session(engine) as session:
|
||||
create_user(session, username, email, password)
|
||||
|
||||
@ -87,7 +86,6 @@ def cli_change_user_passwd(ctx: click.Context, username: str, password: str) ->
|
||||
"""Change password on user."""
|
||||
settings = cast(AdminServerSettings, ctx.obj)
|
||||
engine = create_engine(settings.admin_db)
|
||||
init_db(engine)
|
||||
with Session(engine) as session:
|
||||
user = session.scalars(select(User).where(User.username == username)).first()
|
||||
if not user:
|
||||
@ -107,7 +105,6 @@ def cli_delete_user(ctx: click.Context, username: str) -> None:
|
||||
"""Remove a user."""
|
||||
settings = cast(AdminServerSettings, ctx.obj)
|
||||
engine = create_engine(settings.admin_db)
|
||||
init_db(engine)
|
||||
with Session(engine) as session:
|
||||
user = session.scalars(select(User).where(User.username == username)).first()
|
||||
if not user:
|
||||
@ -149,7 +146,6 @@ def cli_repl(ctx: click.Context) -> None:
|
||||
"""Run an interactive console."""
|
||||
settings = cast(AdminServerSettings, ctx.obj)
|
||||
engine = create_engine(settings.admin_db)
|
||||
init_db(engine)
|
||||
with Session(engine) as session:
|
||||
password_db = session.scalars(
|
||||
select(PasswordDB).where(PasswordDB.id == 1)
|
||||
@ -165,7 +161,7 @@ def cli_repl(ctx: click.Context) -> None:
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(func)
|
||||
|
||||
admin = AdminBackend(settings, password_db.encrypted_password)
|
||||
admin = AdminBackend(settings, )
|
||||
locals = {
|
||||
"run": run,
|
||||
"admin": admin,
|
||||
|
||||
@ -1,12 +1,13 @@
|
||||
"""Database setup."""
|
||||
|
||||
import sqlite3
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from collections.abc import AsyncIterator, Generator, Callable
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.engine import URL
|
||||
from sqlalchemy import create_engine, Engine
|
||||
from sqlalchemy import create_engine, Engine, event
|
||||
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncConnection,
|
||||
@ -18,11 +19,20 @@ from sqlalchemy.ext.asyncio import (
|
||||
|
||||
|
||||
def setup_database(
|
||||
db_url: URL | str,
|
||||
db_url: URL,
|
||||
) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]:
|
||||
"""Setup database."""
|
||||
|
||||
engine = create_engine(db_url, echo=True, future=True)
|
||||
if db_url.drivername.startswith("sqlite"):
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
def set_sqlite_pragma(
|
||||
dbapi_connection: sqlite3.Connection, _connection_record: object
|
||||
) -> None:
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
def get_db_session() -> Generator[Session, None, None]:
|
||||
"""Get DB Session."""
|
||||
@ -33,8 +43,18 @@ def setup_database(
|
||||
|
||||
|
||||
class DatabaseSessionManager:
|
||||
def __init__(self, host: URL | str, **engine_kwargs: str):
|
||||
def __init__(self, host: URL, **engine_kwargs: str):
|
||||
self._engine: AsyncEngine | None = create_async_engine(host, **engine_kwargs)
|
||||
if host.drivername.startswith("sqlite+"):
|
||||
|
||||
@event.listens_for(self._engine.sync_engine, "connect")
|
||||
def set_sqlite_pragma(
|
||||
dbapi_connection: sqlite3.Connection, _connection_record: object
|
||||
) -> None:
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
self._sessionmaker: async_sessionmaker[AsyncSession] | None = (
|
||||
async_sessionmaker(
|
||||
autocommit=False, bind=self._engine, expire_on_commit=False
|
||||
|
||||
@ -4,6 +4,8 @@ from collections.abc import AsyncGenerator, Awaitable, Callable, Generator
|
||||
from dataclasses import dataclass
|
||||
from typing import Self
|
||||
|
||||
from fastapi import Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
from sshecret_admin.auth import User
|
||||
from sshecret_admin.services import AdminBackend
|
||||
@ -11,8 +13,9 @@ from sshecret_admin.core.settings import AdminServerSettings
|
||||
|
||||
|
||||
DBSessionDep = Callable[[], Generator[Session, None, None]]
|
||||
AsyncSessionDep = Callable[[], AsyncGenerator[AsyncSession, None]]
|
||||
|
||||
AdminDep = Callable[[Session], AsyncGenerator[AdminBackend, None]]
|
||||
AdminDep = Callable[[Request, Session], AsyncGenerator[AdminBackend, None]]
|
||||
|
||||
GetUserDep = Callable[[User], Awaitable[User]]
|
||||
|
||||
@ -23,6 +26,8 @@ class BaseDependencies:
|
||||
|
||||
settings: AdminServerSettings
|
||||
get_db_session: DBSessionDep
|
||||
get_async_session: AsyncSessionDep
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -43,6 +48,7 @@ class AdminDependencies(BaseDependencies):
|
||||
return cls(
|
||||
settings=deps.settings,
|
||||
get_db_session=deps.get_db_session,
|
||||
get_async_session=deps.get_async_session,
|
||||
get_admin_backend=get_admin_backend,
|
||||
get_current_active_user=get_current_active_user,
|
||||
)
|
||||
|
||||
@ -30,7 +30,6 @@ class FrontendDependencies(BaseDependencies):
|
||||
get_refresh_claims: RefreshTokenDep
|
||||
get_login_status: LoginStatusDep
|
||||
get_user_info: UserInfoDep
|
||||
get_async_session: AsyncSessionDep
|
||||
require_login: LoginGuardDep
|
||||
|
||||
@classmethod
|
||||
@ -42,18 +41,17 @@ class FrontendDependencies(BaseDependencies):
|
||||
get_refresh_claims: RefreshTokenDep,
|
||||
get_login_status: LoginStatusDep,
|
||||
get_user_info: UserInfoDep,
|
||||
get_async_session: AsyncSessionDep,
|
||||
require_login: LoginGuardDep,
|
||||
) -> Self:
|
||||
"""Create from base dependencies."""
|
||||
return cls(
|
||||
settings=deps.settings,
|
||||
get_db_session=deps.get_db_session,
|
||||
get_async_session=deps.get_async_session,
|
||||
get_admin_backend=get_admin_backend,
|
||||
templates=templates,
|
||||
get_refresh_claims=get_refresh_claims,
|
||||
get_login_status=get_login_status,
|
||||
get_user_info=get_user_info,
|
||||
get_async_session=get_async_session,
|
||||
require_login=require_login,
|
||||
)
|
||||
|
||||
@ -24,7 +24,6 @@ from sshecret_admin.auth.constants import LOCAL_ISSUER
|
||||
|
||||
from sshecret_admin.core.dependencies import BaseDependencies
|
||||
from sshecret_admin.services.admin_backend import AdminBackend
|
||||
from sshecret_admin.core.db import DatabaseSessionManager
|
||||
|
||||
from .dependencies import FrontendDependencies
|
||||
from .exceptions import RedirectException
|
||||
@ -50,17 +49,24 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||
templates = Jinja2Blocks(directory=template_path)
|
||||
|
||||
async def get_admin_backend(
|
||||
request: Request,
|
||||
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||
):
|
||||
"""Get admin backend API."""
|
||||
password_db = session.scalars(
|
||||
select(PasswordDB).where(PasswordDB.id == 1)
|
||||
).first()
|
||||
username = get_optional_username(request)
|
||||
origin = get_client_origin(request)
|
||||
if not password_db:
|
||||
raise HTTPException(
|
||||
500, detail="Error: The password manager has not yet been set up."
|
||||
)
|
||||
admin = AdminBackend(dependencies.settings, password_db.encrypted_password)
|
||||
admin = AdminBackend(
|
||||
dependencies.settings,
|
||||
username=username,
|
||||
origin=origin,
|
||||
)
|
||||
yield admin
|
||||
|
||||
def get_identity_claims(request: Request) -> IdentityClaims:
|
||||
@ -108,14 +114,9 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||
next = URL("/refresh").include_query_params(next=request.url.path)
|
||||
raise RedirectException(to=next)
|
||||
|
||||
async def get_async_session():
|
||||
"""Get async session."""
|
||||
sessionmanager = DatabaseSessionManager(dependencies.settings.async_db_url)
|
||||
async with sessionmanager.session() as session:
|
||||
yield session
|
||||
|
||||
async def get_user_info(
|
||||
request: Request, session: Annotated[AsyncSession, Depends(get_async_session)]
|
||||
request: Request,
|
||||
session: Annotated[AsyncSession, Depends(dependencies.get_async_session)],
|
||||
) -> LocalUserInfo:
|
||||
"""Get User information."""
|
||||
claims = get_identity_claims(request)
|
||||
@ -142,6 +143,30 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||
next = URL("/refresh").include_query_params(next=request.url.path)
|
||||
raise RedirectException(to=next)
|
||||
|
||||
def get_optional_username(
|
||||
request: Request,
|
||||
) -> str | None:
|
||||
"""Get username, if available.
|
||||
|
||||
This is purely used for auditing purposes.
|
||||
"""
|
||||
try:
|
||||
claims = get_identity_claims(request)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if claims.provider == LOCAL_ISSUER:
|
||||
return claims.sub
|
||||
|
||||
return f"oidc:{claims.email}"
|
||||
|
||||
def get_client_origin(request: Request) -> str:
|
||||
"""Get client origin."""
|
||||
fallback_origin = "UNKNOWN"
|
||||
if request.client:
|
||||
return request.client.host
|
||||
return fallback_origin
|
||||
|
||||
view_dependencies = FrontendDependencies.create(
|
||||
dependencies,
|
||||
get_admin_backend,
|
||||
@ -149,7 +174,6 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||
refresh_identity_claims,
|
||||
get_login_status,
|
||||
get_user_info,
|
||||
get_async_session,
|
||||
require_login,
|
||||
)
|
||||
|
||||
|
||||
@ -16,7 +16,7 @@
|
||||
<!-- Detail Pane -->
|
||||
|
||||
<section id="detail-pane"
|
||||
class="flex-1 flex overflow-y-auto bg-white p-4 {%- if not mobile_show_details|default(false) -%} hidden{%- endif -%} lg:block dark:bg-gray-800">
|
||||
class="flex-1 flex overflow-y-auto bg-white p-4 lg:block {% if not mobile_show_details|default(false) -%} hidden{%- endif -%} lg:block dark:bg-gray-800">
|
||||
|
||||
|
||||
{% block detail %}
|
||||
|
||||
@ -5,7 +5,7 @@ import ipaddress
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request, Response
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException, Request, Response
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import BaseModel, IPvAnyAddress, IPvAnyNetwork
|
||||
from sshecret_admin.frontend.views.common import PagingInfo
|
||||
@ -209,7 +209,7 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
page: int,
|
||||
) -> Response:
|
||||
"""Get more events for a client."""
|
||||
if not "HX-Request" in request.headers:
|
||||
if "HX-Request" not in request.headers:
|
||||
return RedirectResponse(url=f"/clients/client/{id}")
|
||||
|
||||
client = await admin.get_client(("id", id))
|
||||
|
||||
@ -4,8 +4,8 @@ Since we have a frontend and a REST API, it makes sense to have a generic librar
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from sshecret.backend import (
|
||||
AuditLog,
|
||||
@ -20,7 +20,7 @@ from sshecret.backend.models import ClientQueryResult, DetailedSecrets
|
||||
from sshecret.backend.api import AuditAPI, KeySpec
|
||||
from sshecret.crypto import encrypt_string, load_public_key
|
||||
|
||||
from .keepass import PasswordContext, load_password_manager
|
||||
from .secret_manager import AsyncSecretContext, password_manager_context
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
from .models import (
|
||||
ClientSecretGroup,
|
||||
@ -86,19 +86,27 @@ def add_clients_to_secret_group(
|
||||
class AdminBackend:
|
||||
"""Admin backend API."""
|
||||
|
||||
def __init__(self, settings: AdminServerSettings, keepass_password: str) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
settings: AdminServerSettings,
|
||||
username: str | None = None,
|
||||
origin: str = "UNKNOWN",
|
||||
) -> None:
|
||||
"""Create client management API."""
|
||||
self.settings: AdminServerSettings = settings
|
||||
self.backend: SshecretBackend = SshecretBackend(
|
||||
str(settings.backend_url), settings.backend_token
|
||||
)
|
||||
self.keepass_password: str = keepass_password
|
||||
self.username: str = username or "UKNOWN_USER"
|
||||
self.origin: str = origin
|
||||
|
||||
@contextmanager
|
||||
def password_manager(self) -> Iterator[PasswordContext]:
|
||||
"""Open the password manager."""
|
||||
with load_password_manager(self.settings, self.keepass_password) as kp:
|
||||
yield kp
|
||||
@asynccontextmanager
|
||||
async def secrets_manager(self) -> AsyncIterator[AsyncSecretContext]:
|
||||
"""Open the secrets manager."""
|
||||
async with password_manager_context(
|
||||
self.settings, self.username, self.origin
|
||||
) as manager:
|
||||
yield manager
|
||||
|
||||
async def _get_clients(self, filter: ClientFilter | None = None) -> list[Client]:
|
||||
"""Get clients from backend."""
|
||||
@ -194,7 +202,7 @@ class AdminBackend:
|
||||
self,
|
||||
name: KeySpec,
|
||||
new_key: str,
|
||||
password_manager: PasswordContext,
|
||||
password_manager: AsyncSecretContext,
|
||||
) -> list[str]:
|
||||
"""Update client public key."""
|
||||
LOG.info(
|
||||
@ -207,7 +215,7 @@ class AdminBackend:
|
||||
updated_secrets: list[str] = []
|
||||
for secret in client.secrets:
|
||||
LOG.debug("Re-encrypting secret %s for client %s", secret, name)
|
||||
secret_value = password_manager.get_secret(secret)
|
||||
secret_value = await password_manager.get_secret(secret)
|
||||
if not secret_value:
|
||||
LOG.warning(
|
||||
"Referenced secret %s does not exist! Skipping.", secret_value
|
||||
@ -224,7 +232,7 @@ class AdminBackend:
|
||||
async def update_client_public_key(self, name: KeySpec, new_key: str) -> list[str]:
|
||||
"""Update client public key."""
|
||||
try:
|
||||
with self.password_manager() as password_manager:
|
||||
async with self.secrets_manager() as password_manager:
|
||||
return await self._update_client_public_key(
|
||||
name, new_key, password_manager
|
||||
)
|
||||
@ -291,8 +299,8 @@ class AdminBackend:
|
||||
This fetches the secret to client mapping from backend, and adds secrets from the password manager.
|
||||
"""
|
||||
backend_secrets = await self.backend.get_secrets()
|
||||
with self.password_manager() as password_manager:
|
||||
admin_secrets = password_manager.get_available_secrets()
|
||||
async with self.secrets_manager() as password_manager:
|
||||
admin_secrets = await password_manager.get_available_secrets()
|
||||
|
||||
secrets: dict[str, SecretListView] = {}
|
||||
for secret in backend_secrets:
|
||||
@ -324,8 +332,8 @@ class AdminBackend:
|
||||
|
||||
This fetches the secret to client mapping from backend, and adds secrets from the password manager.
|
||||
"""
|
||||
with self.password_manager() as password_manager:
|
||||
all_secrets = password_manager.get_available_secrets()
|
||||
async with self.secrets_manager() as password_manager:
|
||||
all_secrets = await password_manager.get_available_secrets()
|
||||
|
||||
secrets = await self.backend.get_detailed_secrets()
|
||||
backend_secret_names = [secret.name for secret in secrets]
|
||||
@ -351,13 +359,13 @@ class AdminBackend:
|
||||
parent_group: str | None = None,
|
||||
) -> None:
|
||||
"""Add secret group."""
|
||||
with self.password_manager() as password_manager:
|
||||
password_manager.add_group(group_name, description, parent_group)
|
||||
async with self.secrets_manager() as password_manager:
|
||||
await password_manager.add_group(group_name, description, parent_group)
|
||||
|
||||
async def set_secret_group(self, secret_name: str, group_name: str | None) -> None:
|
||||
"""Assign a group to a secret."""
|
||||
with self.password_manager() as password_manager:
|
||||
password_manager.set_secret_group(secret_name, group_name)
|
||||
async with self.secrets_manager() as password_manager:
|
||||
await password_manager.set_secret_group(secret_name, group_name)
|
||||
|
||||
async def move_secret_group(
|
||||
self, group_name: str, parent_group: str | None
|
||||
@ -366,23 +374,21 @@ class AdminBackend:
|
||||
|
||||
If parent_group is None, it will be moved to the root.
|
||||
"""
|
||||
with self.password_manager() as password_manager:
|
||||
password_manager.move_group(group_name, parent_group)
|
||||
async with self.secrets_manager() as password_manager:
|
||||
await password_manager.move_group(group_name, parent_group)
|
||||
|
||||
async def set_group_description(self, group_name: str, description: str) -> None:
|
||||
"""Set a group description."""
|
||||
with self.password_manager() as password_manager:
|
||||
password_manager.set_group_description(group_name, description)
|
||||
async with self.secrets_manager() as password_manager:
|
||||
await password_manager.set_group_description(group_name, description)
|
||||
|
||||
async def delete_secret_group(
|
||||
self, group_name: str, keep_entries: bool = True
|
||||
) -> None:
|
||||
async def delete_secret_group(self, group_name: str) -> None:
|
||||
"""Delete a group.
|
||||
|
||||
If keep_entries is set to False, all entries in the group will be deleted.
|
||||
"""
|
||||
with self.password_manager() as password_manager:
|
||||
password_manager.delete_group(group_name, keep_entries)
|
||||
async with self.secrets_manager() as password_manager:
|
||||
await password_manager.delete_group(group_name)
|
||||
|
||||
async def get_secret_groups(
|
||||
self,
|
||||
@ -399,18 +405,18 @@ class AdminBackend:
|
||||
"""
|
||||
all_secrets = await self.backend.get_detailed_secrets()
|
||||
secrets_mapping = {secret.name: secret for secret in all_secrets}
|
||||
with self.password_manager() as password_manager:
|
||||
async with self.secrets_manager() as password_manager:
|
||||
if flat:
|
||||
all_groups = password_manager.get_secret_group_list(
|
||||
all_groups = await password_manager.get_secret_group_list(
|
||||
group_filter, regex=regex
|
||||
)
|
||||
else:
|
||||
all_groups = password_manager.get_secret_groups(
|
||||
all_groups = await password_manager.get_secret_groups(
|
||||
group_filter, regex=regex
|
||||
)
|
||||
ungrouped = password_manager.get_ungrouped_secrets()
|
||||
ungrouped = await password_manager.get_ungrouped_secrets()
|
||||
|
||||
all_admin_secrets = password_manager.get_available_secrets()
|
||||
all_admin_secrets = await password_manager.get_available_secrets()
|
||||
|
||||
group_result: list[ClientSecretGroup] = []
|
||||
for group in all_groups:
|
||||
@ -452,8 +458,8 @@ class AdminBackend:
|
||||
|
||||
async def get_secret_group_by_path(self, path: str) -> ClientSecretGroup | None:
|
||||
"""Get a group based on its path."""
|
||||
with self.password_manager() as password_manager:
|
||||
secret_group = password_manager.get_secret_group(path)
|
||||
async with self.secrets_manager() as password_manager:
|
||||
secret_group = await password_manager.get_secret_group(path)
|
||||
|
||||
if not secret_group:
|
||||
return None
|
||||
@ -476,9 +482,11 @@ class AdminBackend:
|
||||
) -> SecretView | None:
|
||||
"""Get a secret, including the actual unencrypted value and clients."""
|
||||
secret: str | None = None
|
||||
with self.password_manager() as password_manager:
|
||||
secret = password_manager.get_secret(name)
|
||||
secret_group = password_manager.get_entry_group(name)
|
||||
async with self.secrets_manager() as password_manager:
|
||||
secret = await password_manager.get_secret(name)
|
||||
secret_group: str | None = None
|
||||
if secret:
|
||||
secret_group = await password_manager.get_entry_group(name)
|
||||
|
||||
secret_view = SecretView(name=name, secret=secret, group=secret_group)
|
||||
|
||||
@ -503,8 +511,8 @@ class AdminBackend:
|
||||
|
||||
async def _delete_secret(self, name: str) -> None:
|
||||
"""Delete a secret."""
|
||||
with self.password_manager() as password_manager:
|
||||
password_manager.delete_entry(name)
|
||||
async with self.secrets_manager() as password_manager:
|
||||
await password_manager.delete_entry(name)
|
||||
|
||||
secret_mapping = await self.backend.get_secret(name)
|
||||
if not secret_mapping:
|
||||
@ -522,8 +530,8 @@ class AdminBackend:
|
||||
group: str | None = None,
|
||||
) -> None:
|
||||
"""Add a secret."""
|
||||
with self.password_manager() as password_manager:
|
||||
password_manager.add_entry(name, value, update, group_name=group)
|
||||
async with self.secrets_manager() as password_manager:
|
||||
await password_manager.add_entry(name, value, update, group_path=group)
|
||||
|
||||
if update:
|
||||
secret_map = await self.backend.get_secret(name)
|
||||
@ -576,8 +584,8 @@ class AdminBackend:
|
||||
if not client:
|
||||
raise ClientNotFoundError(client_idname)
|
||||
|
||||
with self.password_manager() as password_manager:
|
||||
secret = password_manager.get_secret(secret_name)
|
||||
async with self.secrets_manager() as password_manager:
|
||||
secret = await password_manager.get_secret(secret_name)
|
||||
if not secret:
|
||||
raise SecretNotFoundError()
|
||||
|
||||
|
||||
@ -1,348 +0,0 @@
|
||||
"""Keepass password manager."""
|
||||
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import pykeepass
|
||||
import pykeepass.exceptions
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
|
||||
from .models import SecretGroup
|
||||
from .master_password import decrypt_master_password
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
NO_USERNAME = "NO_USERNAME"
|
||||
|
||||
DEFAULT_LOCATION = "keepass.kdbx"
|
||||
|
||||
|
||||
class PasswordCredentialsError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def create_password_db(location: Path, password: str) -> None:
|
||||
"""Create the password database."""
|
||||
LOG.info("Creating password database at %s", location)
|
||||
pykeepass.create_database(str(location.absolute()), password=password)
|
||||
|
||||
|
||||
def _kp_group_to_secret_group(
|
||||
kp_group: pykeepass.group.Group,
|
||||
parent: SecretGroup | None = None,
|
||||
depth: int | None = None,
|
||||
) -> SecretGroup:
|
||||
"""Convert keepass group to secret group dataclass."""
|
||||
group_name = cast(str, kp_group.name)
|
||||
path = "/".join(cast(list[str], kp_group.path))
|
||||
group = SecretGroup(name=group_name, path=path, description=kp_group.notes)
|
||||
for entry in kp_group.entries:
|
||||
group.entries.append(str(entry.title))
|
||||
if parent:
|
||||
group.parent_group = parent
|
||||
|
||||
current_depth = len(kp_group.path)
|
||||
|
||||
if not parent and current_depth > 1:
|
||||
parent = _kp_group_to_secret_group(kp_group.parentgroup, depth=current_depth)
|
||||
parent.children.append(group)
|
||||
group.parent_group = parent
|
||||
|
||||
if depth and depth == current_depth:
|
||||
return group
|
||||
|
||||
for subgroup in kp_group.subgroups:
|
||||
group.children.append(_kp_group_to_secret_group(subgroup, group, depth=depth))
|
||||
|
||||
return group
|
||||
|
||||
|
||||
class PasswordContext:
|
||||
"""Password Context class."""
|
||||
|
||||
def __init__(self, keepass: pykeepass.PyKeePass) -> None:
|
||||
"""Initialize password context."""
|
||||
self.keepass: pykeepass.PyKeePass = keepass
|
||||
|
||||
@property
|
||||
def _root_group(self) -> pykeepass.group.Group:
|
||||
"""Return the root group."""
|
||||
return cast(pykeepass.group.Group, self.keepass.root_group)
|
||||
|
||||
def _get_entry(self, name: str) -> pykeepass.entry.Entry | None:
|
||||
"""Get entry."""
|
||||
entry = cast(
|
||||
"pykeepass.entry.Entry | None",
|
||||
self.keepass.find_entries(title=name, first=True),
|
||||
)
|
||||
return entry
|
||||
|
||||
def _get_group(self, name: str) -> pykeepass.group.Group | None:
|
||||
"""Find a group."""
|
||||
group = cast(
|
||||
pykeepass.group.Group | None,
|
||||
self.keepass.find_groups(name=name, first=True),
|
||||
)
|
||||
return group
|
||||
|
||||
def add_entry(
|
||||
self,
|
||||
entry_name: str,
|
||||
secret: str,
|
||||
overwrite: bool = False,
|
||||
group_name: str | None = None,
|
||||
) -> None:
|
||||
"""Add an entry.
|
||||
|
||||
Specify overwrite=True to overwrite the existing secret value, if it exists.
|
||||
This will not move the entry, if the group_name is different from the original group.
|
||||
|
||||
"""
|
||||
entry = self._get_entry(entry_name)
|
||||
if entry and overwrite:
|
||||
entry.password = secret
|
||||
self.keepass.save()
|
||||
return
|
||||
|
||||
if entry:
|
||||
raise ValueError("Error: A secret with this name already exists.")
|
||||
LOG.debug("Add secret entry to keepass: %s, group: %r", entry_name, group_name)
|
||||
if group_name:
|
||||
destination_group = self._get_group(group_name)
|
||||
else:
|
||||
destination_group = self._root_group
|
||||
|
||||
entry = self.keepass.add_entry(
|
||||
destination_group=destination_group,
|
||||
title=entry_name,
|
||||
username=NO_USERNAME,
|
||||
password=secret,
|
||||
)
|
||||
self.keepass.save()
|
||||
|
||||
def get_secret(self, entry_name: str) -> str | None:
|
||||
"""Get the secret value."""
|
||||
entry = self._get_entry(entry_name)
|
||||
if not entry:
|
||||
return None
|
||||
|
||||
LOG.warning("Secret name %s accessed", entry_name)
|
||||
if password := cast(str, entry.password):
|
||||
return str(password)
|
||||
|
||||
raise RuntimeError(f"Cannot get password for entry {entry_name}")
|
||||
|
||||
def get_entry_group(self, entry_name: str) -> str | None:
|
||||
"""Get the group for an entry."""
|
||||
entry = self._get_entry(entry_name)
|
||||
if not entry:
|
||||
return None
|
||||
if entry.group.is_root_group:
|
||||
return None
|
||||
return str(entry.group.name)
|
||||
|
||||
def get_secret_groups(
|
||||
self, pattern: str | None = None, regex: bool = True
|
||||
) -> list[SecretGroup]:
|
||||
"""Get secret groups.
|
||||
|
||||
A regex pattern may be provided to filter groups.
|
||||
"""
|
||||
if pattern:
|
||||
groups = cast(
|
||||
list[pykeepass.group.Group],
|
||||
self.keepass.find_groups(name=pattern, regex=regex),
|
||||
)
|
||||
else:
|
||||
groups = self._root_group.subgroups
|
||||
|
||||
secret_groups = [_kp_group_to_secret_group(group) for group in groups]
|
||||
return secret_groups
|
||||
|
||||
def get_secret_group_list(
|
||||
self, pattern: str | None = None, regex: bool = True
|
||||
) -> list[SecretGroup]:
|
||||
"""Get a flat list of groups."""
|
||||
if pattern:
|
||||
return self.get_secret_groups(pattern, regex)
|
||||
|
||||
groups = [group for group in self.keepass.groups if not group.is_root_group]
|
||||
secret_groups = [_kp_group_to_secret_group(group) for group in groups]
|
||||
return secret_groups
|
||||
|
||||
def get_secret_group(self, path: str) -> SecretGroup | None:
|
||||
"""Get a secret group by path."""
|
||||
elements = path.split("/")
|
||||
final_element = elements[-1]
|
||||
|
||||
current = self._root_group
|
||||
while elements:
|
||||
groupname = elements.pop(0)
|
||||
matches = [
|
||||
subgroup for subgroup in current.subgroups if subgroup.name == groupname
|
||||
]
|
||||
if matches:
|
||||
current = matches[0]
|
||||
else:
|
||||
return None
|
||||
if not current.is_root_group and current.name == final_element:
|
||||
return _kp_group_to_secret_group(current)
|
||||
return None
|
||||
|
||||
def get_ungrouped_secrets(self) -> list[str]:
|
||||
"""Get secrets without groups."""
|
||||
entries: list[str] = []
|
||||
for entry in self._root_group.entries:
|
||||
entries.append(str(entry.title))
|
||||
|
||||
return entries
|
||||
|
||||
def add_group(
|
||||
self, name: str, description: str | None = None, parent_group: str | None = None
|
||||
) -> None:
|
||||
"""Add a group."""
|
||||
kp_parent_group = self._root_group
|
||||
if parent_group:
|
||||
query = cast(
|
||||
pykeepass.group.Group | None,
|
||||
self.keepass.find_groups(name=parent_group, first=True),
|
||||
)
|
||||
if not query:
|
||||
raise ValueError(
|
||||
f"Error: Cannot find a parent group named {parent_group}"
|
||||
)
|
||||
kp_parent_group = query
|
||||
self.keepass.add_group(
|
||||
destination_group=kp_parent_group, group_name=name, notes=description
|
||||
)
|
||||
self.keepass.save()
|
||||
|
||||
def set_group_description(self, name: str, description: str) -> None:
|
||||
"""Set the description of a group."""
|
||||
group = self._get_group(name)
|
||||
if not group:
|
||||
raise ValueError(f"Error: No such group {name}")
|
||||
|
||||
group.notes = description
|
||||
self.keepass.save()
|
||||
|
||||
def set_secret_group(self, entry_name: str, group_name: str | None) -> None:
|
||||
"""Move a secret to a group.
|
||||
|
||||
If group is None, the secret will be placed in the root group.
|
||||
"""
|
||||
entry = self._get_entry(entry_name)
|
||||
if not entry:
|
||||
raise ValueError(
|
||||
f"Cannot find secret entry named {entry_name} in secrets database"
|
||||
)
|
||||
if group_name:
|
||||
group = self._get_group(group_name)
|
||||
if not group:
|
||||
raise ValueError(f"Cannot find a group named {group_name}")
|
||||
else:
|
||||
group = self._root_group
|
||||
|
||||
self.keepass.move_entry(entry, group)
|
||||
self.keepass.save()
|
||||
|
||||
def move_group(self, name: str, parent_group: str | None) -> None:
|
||||
"""Move a group.
|
||||
|
||||
If parent_group is None, it will be moved to the root.
|
||||
"""
|
||||
group = self._get_group(name)
|
||||
if not group:
|
||||
raise ValueError(f"Error: No such group {name}")
|
||||
if parent_group:
|
||||
parent = self._get_group(parent_group)
|
||||
if not parent:
|
||||
raise ValueError(f"Error: No such group {parent_group}")
|
||||
else:
|
||||
parent = self._root_group
|
||||
|
||||
self.keepass.move_group(group, parent)
|
||||
self.keepass.save()
|
||||
|
||||
def get_available_secrets(self, group_name: str | None = None) -> list[str]:
|
||||
"""Get the names of all secrets in the database."""
|
||||
if group_name:
|
||||
group = self._get_group(group_name)
|
||||
if not group:
|
||||
raise ValueError(f"Error: No such group {group_name}")
|
||||
entries = group.entries
|
||||
else:
|
||||
entries = cast(list[pykeepass.entry.Entry], self.keepass.entries)
|
||||
if not entries:
|
||||
return []
|
||||
return [str(entry.title) for entry in entries]
|
||||
|
||||
def delete_entry(self, entry_name: str) -> None:
|
||||
"""Delete entry."""
|
||||
entry = cast(
|
||||
"pykeepass.entry.Entry | None",
|
||||
self.keepass.find_entries(title=entry_name, first=True),
|
||||
)
|
||||
if not entry:
|
||||
return
|
||||
entry.delete()
|
||||
self.keepass.save()
|
||||
|
||||
def delete_group(self, name: str, keep_entries: bool = True) -> None:
|
||||
"""Delete a group.
|
||||
|
||||
If keep_entries is set to False, all entries in the group will be deleted.
|
||||
"""
|
||||
group = self._get_group(name)
|
||||
if not group:
|
||||
return
|
||||
if keep_entries:
|
||||
for entry in cast(
|
||||
list[pykeepass.entry.Entry],
|
||||
self.keepass.find_entries(recursive=True, group=group),
|
||||
):
|
||||
# Move the entry to the root group.
|
||||
LOG.warning(
|
||||
"Moving orphaned secret entry %s to root group", entry.title
|
||||
)
|
||||
self.keepass.move_entry(entry, self._root_group)
|
||||
|
||||
self.keepass.delete_group(group)
|
||||
self.keepass.save()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _password_context(location: Path, password: str) -> Iterator[PasswordContext]:
|
||||
"""Open the password context."""
|
||||
try:
|
||||
database = pykeepass.PyKeePass(str(location.absolute()), password=password)
|
||||
except pykeepass.exceptions.CredentialsError as e:
|
||||
raise PasswordCredentialsError(
|
||||
"Could not open password database. Invalid credentials."
|
||||
) from e
|
||||
context = PasswordContext(database)
|
||||
yield context
|
||||
|
||||
|
||||
@contextmanager
|
||||
def load_password_manager(
|
||||
settings: AdminServerSettings,
|
||||
encrypted_password: str,
|
||||
location: str = DEFAULT_LOCATION,
|
||||
) -> Iterator[PasswordContext]:
|
||||
"""Load password manager.
|
||||
|
||||
This function decrypts the password, and creates the password database if it
|
||||
has not yet been created.
|
||||
"""
|
||||
db_location = Path(location)
|
||||
password = decrypt_master_password(settings=settings, encrypted=encrypted_password)
|
||||
if not db_location.exists():
|
||||
create_password_db(db_location, password)
|
||||
|
||||
with _password_context(db_location, password) as context:
|
||||
yield context
|
||||
@ -1,86 +0,0 @@
|
||||
"""Functions related to handling the password database master password."""
|
||||
|
||||
import secrets
|
||||
from pathlib import Path
|
||||
from sshecret.crypto import (
|
||||
create_private_rsa_key,
|
||||
load_private_key,
|
||||
encrypt_string,
|
||||
decode_string,
|
||||
)
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
|
||||
KEY_FILENAME = "sshecret-admin-key"
|
||||
|
||||
|
||||
def setup_master_password(
|
||||
settings: AdminServerSettings,
|
||||
filename: str = KEY_FILENAME,
|
||||
regenerate: bool = False,
|
||||
) -> str | None:
|
||||
"""Setup master password.
|
||||
|
||||
If regenerate is True, a new key will be generated.
|
||||
|
||||
This method should run just after setting up the database.
|
||||
"""
|
||||
keyfile = Path(filename)
|
||||
if settings.password_manager_directory:
|
||||
keyfile = settings.password_manager_directory / filename
|
||||
created = _initial_key_setup(settings, keyfile, regenerate)
|
||||
if not created:
|
||||
return None
|
||||
|
||||
return _generate_master_password(settings, keyfile)
|
||||
|
||||
|
||||
def decrypt_master_password(
|
||||
settings: AdminServerSettings, encrypted: str, filename: str = KEY_FILENAME
|
||||
) -> str:
|
||||
"""Retrieve master password."""
|
||||
keyfile = Path(filename)
|
||||
if settings.password_manager_directory:
|
||||
keyfile = settings.password_manager_directory / filename
|
||||
if not keyfile.exists():
|
||||
raise RuntimeError("Error: Private key has not been generated yet.")
|
||||
|
||||
private_key = load_private_key(
|
||||
str(keyfile.absolute()), password=settings.secret_key
|
||||
)
|
||||
return decode_string(encrypted, private_key)
|
||||
|
||||
|
||||
def _generate_password() -> str:
|
||||
"""Generate a password."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
def _initial_key_setup(
|
||||
settings: AdminServerSettings,
|
||||
keyfile: Path,
|
||||
regenerate: bool = False,
|
||||
) -> bool:
|
||||
"""Set up initial keys."""
|
||||
if keyfile.exists() and not regenerate:
|
||||
return False
|
||||
|
||||
assert settings.secret_key is not None, (
|
||||
"Error: Could not load a secret key from environment."
|
||||
)
|
||||
create_private_rsa_key(keyfile, password=settings.secret_key)
|
||||
return True
|
||||
|
||||
|
||||
def _generate_master_password(settings: AdminServerSettings, keyfile: Path) -> str:
|
||||
"""Generate master password for password database.
|
||||
|
||||
Returns the encrypted string, base64 encoded.
|
||||
"""
|
||||
if not keyfile.exists():
|
||||
raise RuntimeError("Error: Private key has not been generated yet.")
|
||||
private_key = load_private_key(
|
||||
str(keyfile.absolute()), password=settings.secret_key
|
||||
)
|
||||
public_key = private_key.public_key()
|
||||
master_password = _generate_password()
|
||||
return encrypt_string(master_password, public_key)
|
||||
@ -0,0 +1,776 @@
|
||||
"""Rewritten secret manager using a rsa keys."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload, aliased
|
||||
from sshecret.backend import SshecretBackend
|
||||
from sshecret.backend.api import AuditAPI, KeySpec
|
||||
from sshecret.backend.models import Client, ClientSecret, Operation, SubSystem
|
||||
from sshecret.crypto import (
|
||||
create_private_rsa_key,
|
||||
decode_string,
|
||||
encrypt_string,
|
||||
generate_public_key_string,
|
||||
load_private_key,
|
||||
load_public_key,
|
||||
)
|
||||
from sshecret_admin.auth import PasswordDB
|
||||
from sshecret_admin.auth.models import Group, ManagedSecret
|
||||
from sshecret_admin.core.db import DatabaseSessionManager
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
from sshecret_admin.services.models import SecretGroup
|
||||
|
||||
|
||||
KEY_FILENAME = "sshecret-admin-key"
|
||||
PASSWORD_MANAGER_ID = "SshecretAdminPasswordManager"
|
||||
|
||||
LOG = logging.getLogger(PASSWORD_MANAGER_ID)
|
||||
|
||||
|
||||
class SecretManagerError(Exception):
|
||||
"""Secret manager error."""
|
||||
|
||||
|
||||
class InvalidGroupNameError(SecretManagerError):
|
||||
"""Invalid group name."""
|
||||
|
||||
|
||||
class InvalidSecretNameError(SecretManagerError):
|
||||
"""Invalid secret name."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClientAuditData:
|
||||
"""Client audit data."""
|
||||
|
||||
username: str
|
||||
origin: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParsedPath:
|
||||
"""Parsed path."""
|
||||
|
||||
item: str
|
||||
full_path: str
|
||||
parent: str | None = None
|
||||
|
||||
|
||||
class SecretDataEntryExport(BaseModel):
|
||||
"""Exportable secret entry."""
|
||||
|
||||
name: str
|
||||
secret: str
|
||||
group: str | None = None
|
||||
|
||||
|
||||
class SecretDataGroupExport(BaseModel):
|
||||
"""Exportable secret grouping."""
|
||||
|
||||
name: str
|
||||
path: str
|
||||
description: str | None = None
|
||||
|
||||
|
||||
class SecretDataExport(BaseModel):
|
||||
"""Exportable object containing secrets and groups."""
|
||||
|
||||
entries: list[SecretDataEntryExport]
|
||||
groups: list[SecretDataGroupExport]
|
||||
|
||||
|
||||
def split_path(path: str) -> list[str]:
|
||||
"""Split a path into a list of groups."""
|
||||
elements = path.split("/")
|
||||
if path.startswith("/"):
|
||||
elements = elements[1:]
|
||||
|
||||
return elements
|
||||
|
||||
|
||||
def parse_path(path: str) -> ParsedPath:
|
||||
"""Parse path."""
|
||||
elements = split_path(path)
|
||||
parsed = ParsedPath(elements[-1], path)
|
||||
if len(elements) > 1:
|
||||
parsed.parent = elements[-2]
|
||||
return parsed
|
||||
|
||||
|
||||
class AsyncSecretContext:
|
||||
"""Async secret context."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
private_key: rsa.RSAPrivateKey,
|
||||
manager_client: Client,
|
||||
session: AsyncSession,
|
||||
backend: SshecretBackend,
|
||||
audit_data: ClientAuditData,
|
||||
) -> None:
|
||||
"""Initialize secret manager"""
|
||||
self._private_key: rsa.RSAPrivateKey = private_key
|
||||
self._manager_client: Client = manager_client
|
||||
self._id: KeySpec = ("id", str(manager_client.id))
|
||||
self.backend: SshecretBackend = backend
|
||||
self.session: AsyncSession = session
|
||||
|
||||
self.audit_data: ClientAuditData = audit_data
|
||||
self.audit: AuditAPI = backend.audit(SubSystem.ADMIN)
|
||||
self._import_has_run: bool = False
|
||||
|
||||
async def _create_missing_entries(self) -> None:
|
||||
"""Create any missing entries."""
|
||||
new_secrets: bool = False
|
||||
to_check = set(self._manager_client.secrets)
|
||||
for secret_name in to_check:
|
||||
# entry = await self._get_entry(secret_name, include_deleted=True)
|
||||
statement = select(ManagedSecret).where(ManagedSecret.name == secret_name)
|
||||
result = await self.session.scalars(statement)
|
||||
if not result.first():
|
||||
new_secrets = True
|
||||
managed_secret = ManagedSecret(name=secret_name)
|
||||
self.session.add(managed_secret)
|
||||
|
||||
await self.session.flush()
|
||||
await self.write_audit(
|
||||
Operation.CREATE,
|
||||
message="Imported managed secret from backend.",
|
||||
secret_name=secret_name,
|
||||
managed_secret=managed_secret,
|
||||
)
|
||||
if new_secrets:
|
||||
await self.session.commit()
|
||||
|
||||
async def _get_group_depth(self, group: Group) -> int:
|
||||
"""Get the depth of a group."""
|
||||
depth = 1
|
||||
if not group.parent_id:
|
||||
return depth
|
||||
|
||||
current = group
|
||||
while current.parent is not None:
|
||||
if current.parent:
|
||||
depth += 1
|
||||
current = await self._get_group_by_id(current.parent.id)
|
||||
else:
|
||||
break
|
||||
|
||||
return depth
|
||||
|
||||
async def _get_group_path(self, group: Group) -> str:
|
||||
"""Get the path of a group."""
|
||||
|
||||
if not group.parent_id:
|
||||
return group.name
|
||||
path: list[str] = []
|
||||
current = group
|
||||
while current.parent_id is not None:
|
||||
path.append(current.name)
|
||||
current = await self._get_group_by_id(current.parent_id)
|
||||
|
||||
path.append("")
|
||||
path.reverse()
|
||||
return "/".join(path)
|
||||
|
||||
async def _get_group_secrets(self, group: Group) -> list[ManagedSecret]:
|
||||
"""Get secrets in a group."""
|
||||
statement = (
|
||||
select(ManagedSecret)
|
||||
.where(ManagedSecret.group_id == group.id)
|
||||
.where(ManagedSecret.is_deleted.is_not(True))
|
||||
)
|
||||
results = await self.session.scalars(statement)
|
||||
return list(results.all())
|
||||
|
||||
async def _build_group_tree(
|
||||
self, group: Group, parent: SecretGroup | None = None, depth: int | None = None
|
||||
) -> SecretGroup:
|
||||
"""Build a group tree."""
|
||||
path = "/"
|
||||
if parent:
|
||||
path = os.path.join(parent.path, path)
|
||||
secret_group = SecretGroup(
|
||||
name=group.name, path=path, description=group.description
|
||||
)
|
||||
group_secrets = await self._get_group_secrets(group)
|
||||
for secret in group_secrets:
|
||||
secret_group.entries.append(secret.name)
|
||||
if parent:
|
||||
secret_group.parent_group = parent
|
||||
|
||||
current_depth = await self._get_group_depth(group)
|
||||
|
||||
if not parent and group.parent:
|
||||
parent_group = await self._get_group_by_id(group.parent.id)
|
||||
assert parent_group is not None
|
||||
parent = await self._build_group_tree(parent_group, depth=current_depth)
|
||||
parent.children.append(secret_group)
|
||||
secret_group.parent_group = parent
|
||||
|
||||
if depth and depth == current_depth:
|
||||
return secret_group
|
||||
|
||||
for subgroup in group.children:
|
||||
child_group = await self._get_group_by_id(subgroup.id)
|
||||
assert child_group is not None
|
||||
secret_subgroup = await self._build_group_tree(
|
||||
child_group, secret_group, depth=depth
|
||||
)
|
||||
secret_group.children.append(secret_subgroup)
|
||||
|
||||
return secret_group
|
||||
|
||||
async def write_audit(
|
||||
self,
|
||||
operation: Operation,
|
||||
message: str,
|
||||
group_name: str | None = None,
|
||||
client_secret: ClientSecret | None = None,
|
||||
secret_name: str | None = None,
|
||||
managed_secret: ManagedSecret | None = None,
|
||||
**data: str,
|
||||
) -> None:
|
||||
"""Write Audit message."""
|
||||
if group_name:
|
||||
data["group"] = group_name
|
||||
|
||||
data["username"] = self.audit_data.username
|
||||
if client_secret and not secret_name:
|
||||
secret_name = client_secret.name
|
||||
|
||||
if managed_secret:
|
||||
data["managed_secret"] = str(managed_secret.id)
|
||||
|
||||
await self.audit.write_async(
|
||||
operation=operation,
|
||||
message=message,
|
||||
origin=self.audit_data.origin,
|
||||
client=self._manager_client,
|
||||
secret=client_secret,
|
||||
secret_name=secret_name,
|
||||
**data,
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def public_key(self) -> rsa.RSAPublicKey:
|
||||
"""Get public key."""
|
||||
keystring = self._manager_client.public_key
|
||||
return load_public_key(keystring.encode())
|
||||
|
||||
async def _get_entry(
|
||||
self, name: str, include_deleted: bool = False
|
||||
) -> ManagedSecret | None:
|
||||
"""Get managed secret."""
|
||||
if not self._import_has_run:
|
||||
await self._create_missing_entries()
|
||||
self._import_has_run = True
|
||||
statement = (
|
||||
select(ManagedSecret)
|
||||
.options(selectinload(ManagedSecret.group))
|
||||
.where(ManagedSecret.name == name)
|
||||
)
|
||||
if not include_deleted:
|
||||
statement = statement.where(ManagedSecret.is_deleted.is_not(True))
|
||||
|
||||
result = await self.session.scalars(statement)
|
||||
return result.first()
|
||||
|
||||
async def add_entry(
|
||||
self,
|
||||
entry_name: str,
|
||||
secret: str,
|
||||
overwrite: bool = False,
|
||||
group_path: str | None = None,
|
||||
) -> None:
|
||||
"""Add entry."""
|
||||
existing_entry = await self._get_entry(entry_name)
|
||||
if existing_entry and not overwrite:
|
||||
raise InvalidSecretNameError(
|
||||
"Another secret with this name is already defined."
|
||||
)
|
||||
|
||||
encrypted = encrypt_string(secret, self.public_key)
|
||||
client_secret = await self.backend.create_client_secret(
|
||||
self._id, entry_name, encrypted
|
||||
)
|
||||
group_id: uuid.UUID | None = None
|
||||
if group_path:
|
||||
elements = parse_path(group_path)
|
||||
group = await self._get_group(elements.item, elements.parent, True)
|
||||
if not group:
|
||||
raise InvalidGroupNameError("Invalid group name")
|
||||
group_id = group.id
|
||||
|
||||
if existing_entry:
|
||||
existing_entry.updated_at = datetime.now(timezone.utc)
|
||||
if group_id:
|
||||
existing_entry.group_id = group_id
|
||||
self.session.add(existing_entry)
|
||||
await self.session.commit()
|
||||
await self.write_audit(
|
||||
Operation.UPDATE,
|
||||
"Updated secret value",
|
||||
group_name=group_path,
|
||||
client_secret=client_secret,
|
||||
managed_secret=existing_entry,
|
||||
)
|
||||
else:
|
||||
managed_secret = ManagedSecret(
|
||||
name=entry_name,
|
||||
group_id=group_id,
|
||||
)
|
||||
self.session.add(managed_secret)
|
||||
|
||||
await self.session.commit()
|
||||
await self.write_audit(
|
||||
Operation.CREATE,
|
||||
"Created managed client secret",
|
||||
group_path,
|
||||
client_secret=client_secret,
|
||||
managed_secret=managed_secret,
|
||||
)
|
||||
|
||||
async def get_secret(self, entry_name: str) -> str | None:
|
||||
"""Get secret."""
|
||||
client_secret = await self.backend.get_client_secret(
|
||||
self._id, ("name", entry_name)
|
||||
)
|
||||
if not client_secret:
|
||||
return None
|
||||
decrypted = decode_string(client_secret, self._private_key)
|
||||
await self.write_audit(
|
||||
Operation.READ,
|
||||
"Secret was viewed from secret manager",
|
||||
secret_name=entry_name,
|
||||
)
|
||||
|
||||
return decrypted
|
||||
|
||||
async def get_available_secrets(self, group_path: str | None = None) -> list[str]:
|
||||
"""Get the names of all secrets in the db."""
|
||||
if not self._import_has_run:
|
||||
await self._create_missing_entries()
|
||||
if group_path:
|
||||
elements = parse_path(group_path)
|
||||
group = await self._get_group(elements.item, elements.parent)
|
||||
if not group:
|
||||
raise InvalidGroupNameError("Invalid or nonexisting group name.")
|
||||
entries = group.secrets
|
||||
else:
|
||||
result = await self.session.scalars(
|
||||
select(ManagedSecret)
|
||||
.options(selectinload(ManagedSecret.group))
|
||||
.where(ManagedSecret.is_deleted.is_not(True))
|
||||
)
|
||||
|
||||
entries = list(result.all())
|
||||
|
||||
return [entry.name for entry in entries]
|
||||
|
||||
async def delete_entry(self, entry_name: str) -> None:
|
||||
"""Delete a secret."""
|
||||
entry = await self._get_entry(entry_name)
|
||||
if not entry:
|
||||
return
|
||||
entry.is_deleted = True
|
||||
entry.deleted_at = datetime.now(timezone.utc)
|
||||
self.session.add(entry)
|
||||
await self.session.commit()
|
||||
await self.backend.delete_client_secret(
|
||||
("id", str(self._manager_client.id)), ("name", entry_name)
|
||||
)
|
||||
await self.write_audit(
|
||||
Operation.DELETE,
|
||||
"Managed secret entry deleted",
|
||||
secret_name=entry_name,
|
||||
managed_secret=entry,
|
||||
)
|
||||
|
||||
async def get_entry_group(self, entry_name: str) -> str | None:
|
||||
"""Get group of entry."""
|
||||
entry = await self._get_entry(entry_name)
|
||||
if not entry:
|
||||
raise InvalidSecretNameError("Invalid secret name or secret not found.")
|
||||
if entry.group:
|
||||
return entry.group.name
|
||||
return None
|
||||
|
||||
async def _get_groups(
|
||||
self, pattern: str | None = None, regex: bool = True, root_groups: bool = False
|
||||
) -> list[Group]:
|
||||
"""Get groups."""
|
||||
statement = select(Group).options(
|
||||
selectinload(Group.children), selectinload(Group.parent)
|
||||
)
|
||||
if pattern and regex:
|
||||
statement = statement.where(Group.name.regexp_match(pattern))
|
||||
elif pattern:
|
||||
statement = statement.where(Group.name.contains(pattern))
|
||||
if root_groups:
|
||||
statement = statement.where(Group.parent_id == None)
|
||||
results = await self.session.scalars(statement)
|
||||
return list(results.all())
|
||||
|
||||
async def get_secret_groups(
|
||||
self, pattern: str | None = None, regex: bool = True
|
||||
) -> list[SecretGroup]:
|
||||
"""Get secret groups, as a hierarcy."""
|
||||
if pattern:
|
||||
groups = await self._get_groups(pattern, regex)
|
||||
else:
|
||||
groups = await self._get_groups(root_groups=True)
|
||||
|
||||
secret_groups: list[SecretGroup] = []
|
||||
for group in groups:
|
||||
secret_group = await self._build_group_tree(group)
|
||||
secret_groups.append(secret_group)
|
||||
|
||||
return secret_groups
|
||||
|
||||
async def get_secret_group_list(
|
||||
self, pattern: str | None = None, regex: bool = True
|
||||
) -> list[SecretGroup]:
|
||||
"""Get secret group list."""
|
||||
groups = await self._get_groups(pattern, regex)
|
||||
return [(await self._build_group_tree(group)) for group in groups]
|
||||
|
||||
async def _get_group_by_id(self, id: uuid.UUID) -> Group:
|
||||
"""Get group by ID."""
|
||||
statement = (
|
||||
select(Group)
|
||||
.options(
|
||||
selectinload(Group.parent),
|
||||
selectinload(Group.children),
|
||||
selectinload(Group.secrets),
|
||||
)
|
||||
.where(Group.id == id)
|
||||
)
|
||||
|
||||
result = await self.session.scalars(statement)
|
||||
return result.one()
|
||||
|
||||
async def _get_group(
|
||||
self, name: str, parent: str | None = None, exact_match: bool = False
|
||||
) -> Group | None:
|
||||
"""Get a group."""
|
||||
statement = (
|
||||
select(Group)
|
||||
.options(
|
||||
selectinload(Group.parent),
|
||||
selectinload(Group.children),
|
||||
selectinload(Group.secrets),
|
||||
)
|
||||
.where(Group.name == name)
|
||||
)
|
||||
if parent:
|
||||
ParentGroup = aliased(Group)
|
||||
statement = statement.join(ParentGroup, Group.parent).where(
|
||||
ParentGroup.name == parent
|
||||
)
|
||||
elif exact_match:
|
||||
statement = statement.where(Group.parent_id == None)
|
||||
result = await self.session.scalars(statement)
|
||||
return result.first()
|
||||
|
||||
async def get_secret_group(self, path: str) -> SecretGroup | None:
|
||||
"""Get a secret group by path."""
|
||||
elements = parse_path(path)
|
||||
|
||||
group_name = elements.item
|
||||
parent_group = elements.parent
|
||||
|
||||
group = await self._get_group(group_name, parent_group)
|
||||
if not group:
|
||||
return None
|
||||
|
||||
return await self._build_group_tree(group)
|
||||
|
||||
async def get_ungrouped_secrets(self) -> list[str]:
|
||||
"""Get ungrouped secrets."""
|
||||
statement = (
|
||||
select(ManagedSecret)
|
||||
.where(ManagedSecret.is_deleted.is_not(True))
|
||||
.where(ManagedSecret.group_id == None)
|
||||
)
|
||||
result = await self.session.scalars(statement)
|
||||
secrets = result.all()
|
||||
return [secret.name for secret in secrets]
|
||||
|
||||
async def add_group(
|
||||
self,
|
||||
name_or_path: str,
|
||||
description: str | None = None,
|
||||
parent_group: str | None = None,
|
||||
) -> None:
|
||||
"""Add a group."""
|
||||
parent_id: uuid.UUID | None = None
|
||||
group_name = name_or_path
|
||||
if parent_group and name_or_path.startswith("/"):
|
||||
raise InvalidGroupNameError(
|
||||
"Path as name cannot be used if parent is also specified."
|
||||
)
|
||||
if name_or_path.startswith("/"):
|
||||
elements = parse_path(name_or_path)
|
||||
group_name = elements.item
|
||||
parent_group = elements.parent
|
||||
|
||||
if parent_group:
|
||||
if parent := (await self._get_group(parent_group)):
|
||||
child_names = [child.name for child in parent.children]
|
||||
if group_name in child_names:
|
||||
raise InvalidGroupNameError(
|
||||
"Parent group already has a group with this name."
|
||||
)
|
||||
parent_id = parent.id
|
||||
|
||||
else:
|
||||
raise InvalidGroupNameError(
|
||||
"Invalid or non-existing parent group name."
|
||||
)
|
||||
else:
|
||||
existing_group = await self._get_group(group_name)
|
||||
if existing_group:
|
||||
raise InvalidGroupNameError("A group with this name already exists.")
|
||||
|
||||
group = Group(
|
||||
name=group_name,
|
||||
description=description,
|
||||
parent_id=parent_id,
|
||||
)
|
||||
self.session.add(group)
|
||||
# We don't audit-log this operation.
|
||||
await self.session.commit()
|
||||
|
||||
async def set_group_description(self, path: str, description: str) -> None:
|
||||
"""Set group description."""
|
||||
elements = parse_path(path)
|
||||
group = await self._get_group(elements.item, elements.parent, True)
|
||||
if not group:
|
||||
raise InvalidGroupNameError("Invalid or non-existing group name.")
|
||||
group.description = description
|
||||
self.session.add(group)
|
||||
await self.session.commit()
|
||||
|
||||
async def set_secret_group(self, entry_name: str, group_name: str | None) -> None:
|
||||
"""Move a secret to a group.
|
||||
|
||||
If group_name is None, the secret will be moved out of any group it may exist in.
|
||||
"""
|
||||
entry = await self._get_entry(entry_name)
|
||||
if not entry:
|
||||
raise InvalidSecretNameError("Invalid or non-existing secret.")
|
||||
if group_name:
|
||||
elements = parse_path(group_name)
|
||||
group = await self._get_group(elements.item, elements.parent, True)
|
||||
if not group:
|
||||
raise InvalidGroupNameError("Invalid or non-existing group name.")
|
||||
entry.group_id = group.id
|
||||
else:
|
||||
entry.group_id = None
|
||||
|
||||
self.session.add(entry)
|
||||
await self.session.commit()
|
||||
await self.write_audit(
|
||||
Operation.UPDATE,
|
||||
"Secret group updated",
|
||||
group_name=group_name or "ROOT",
|
||||
secret_name=entry_name,
|
||||
managed_secret=entry,
|
||||
)
|
||||
|
||||
async def move_group(self, path: str, parent_group: str | None) -> None:
|
||||
"""Move group.
|
||||
|
||||
If parent_group is None, it will be moved to the root.
|
||||
"""
|
||||
elements = parse_path(path)
|
||||
group = await self._get_group(elements.item, elements.parent, True)
|
||||
if not group:
|
||||
raise InvalidGroupNameError("Invalid or non-existing group name.")
|
||||
|
||||
parent_group_id: uuid.UUID | None = None
|
||||
if parent_group:
|
||||
db_parent_group = await self._get_group(parent_group)
|
||||
if not db_parent_group:
|
||||
raise InvalidGroupNameError("Invalid or non-existing parent group.")
|
||||
parent_group_id = db_parent_group.id
|
||||
|
||||
group.parent_id = parent_group_id
|
||||
|
||||
self.session.add(group)
|
||||
await self.session.commit()
|
||||
|
||||
async def delete_group(self, path: str) -> None:
|
||||
"""Delete a group."""
|
||||
elements = parse_path(path)
|
||||
group = await self._get_group(elements.item, elements.parent, True)
|
||||
if not group:
|
||||
return
|
||||
await self.session.delete(group)
|
||||
|
||||
await self.session.commit()
|
||||
# We don't audit-log this operation currently, even though it indirectly
|
||||
# may affect secrets.
|
||||
|
||||
async def _export_entries(self) -> list[SecretDataEntryExport]:
|
||||
"""Export entries as a pydantic object."""
|
||||
statement = (
|
||||
select(ManagedSecret)
|
||||
.options(selectinload(ManagedSecret.group))
|
||||
.where(ManagedSecret.is_deleted.is_(False))
|
||||
)
|
||||
results = await self.session.scalars(statement)
|
||||
entries: list[SecretDataEntryExport] = []
|
||||
for entry in results.all():
|
||||
group: str | None = None
|
||||
if entry.group:
|
||||
group = await self._get_group_path(entry.group)
|
||||
secret = await self.get_secret(entry.name)
|
||||
if not secret:
|
||||
continue
|
||||
data = SecretDataEntryExport(name=entry.name, secret=secret, group=group)
|
||||
entries.append(data)
|
||||
return entries
|
||||
|
||||
async def _export_groups(self) -> list[SecretDataGroupExport]:
|
||||
"""Export groups as pydantic objects."""
|
||||
groups = await self.get_secret_group_list()
|
||||
entries = [
|
||||
SecretDataGroupExport(
|
||||
name=group.name,
|
||||
path=group.path,
|
||||
description=group.description,
|
||||
)
|
||||
for group in groups
|
||||
]
|
||||
return entries
|
||||
|
||||
async def export_secrets(self) -> SecretDataExport:
|
||||
"""Export the managed secrets as a pydantic object."""
|
||||
entries = await self._export_entries()
|
||||
groups = await self._export_groups()
|
||||
return SecretDataExport(entries=entries, groups=groups)
|
||||
|
||||
async def export_secrets_json(self) -> str:
|
||||
"""Export secrets as JSON."""
|
||||
export = await self.export_secrets()
|
||||
return export.model_dump_json(indent=2)
|
||||
|
||||
|
||||
def get_managed_private_key(
|
||||
settings: AdminServerSettings,
|
||||
filename: str = KEY_FILENAME,
|
||||
regenerate: bool = False,
|
||||
) -> rsa.RSAPrivateKey:
|
||||
"""Load our private key."""
|
||||
keyfile = Path(filename)
|
||||
if settings.password_manager_directory:
|
||||
keyfile = settings.password_manager_directory / filename
|
||||
if not keyfile.exists():
|
||||
_initial_key_setup(settings, keyfile)
|
||||
setup_password_manager(settings, keyfile, regenerate)
|
||||
return load_private_key(str(keyfile.absolute()), password=settings.secret_key)
|
||||
|
||||
|
||||
def setup_password_manager(
|
||||
settings: AdminServerSettings, filename: Path, regenerate: bool = False
|
||||
) -> bool:
|
||||
"""Setup password manager."""
|
||||
if filename.exists() and not regenerate:
|
||||
return False
|
||||
|
||||
if not settings.secret_key:
|
||||
raise RuntimeError("Error: Could not load secret key from environment.")
|
||||
create_private_rsa_key(filename, password=settings.secret_key)
|
||||
return True
|
||||
|
||||
|
||||
async def create_manager_client(
|
||||
backend: SshecretBackend, public_key: rsa.RSAPublicKey
|
||||
) -> Client:
|
||||
"""Create the manager client."""
|
||||
public_key_string = generate_public_key_string(public_key)
|
||||
new_client = await backend.create_system_client(
|
||||
"AdminPasswordManager",
|
||||
public_key_string,
|
||||
)
|
||||
return new_client
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def password_manager_context(
|
||||
settings: AdminServerSettings, username: str, origin: str
|
||||
) -> AsyncIterator[AsyncSecretContext]:
|
||||
"""Start a context for the password manager."""
|
||||
audit_context_data = ClientAuditData(username=username, origin=origin)
|
||||
session_manager = DatabaseSessionManager(settings.async_db_url)
|
||||
backend = SshecretBackend(str(settings.backend_url), settings.backend_token)
|
||||
private_key = get_managed_private_key(settings)
|
||||
async with session_manager.session() as session:
|
||||
# Check if there is a client_id stored already.
|
||||
query = select(PasswordDB).where(PasswordDB.id == 1)
|
||||
result = await session.scalars(query)
|
||||
password_db = result.first()
|
||||
if not password_db:
|
||||
password_db = PasswordDB(id=1)
|
||||
session.add(password_db)
|
||||
await session.flush()
|
||||
if not password_db.client_id:
|
||||
manager_client = await create_manager_client(
|
||||
backend, private_key.public_key()
|
||||
)
|
||||
password_db.client_id = manager_client.id
|
||||
session.add(password_db)
|
||||
await session.commit()
|
||||
else:
|
||||
manager_client = await backend.get_client(
|
||||
("id", str(password_db.client_id))
|
||||
)
|
||||
if not manager_client:
|
||||
raise SecretManagerError("Error: Could not fetch system client.")
|
||||
|
||||
context = AsyncSecretContext(
|
||||
private_key, manager_client, session, backend, audit_context_data
|
||||
)
|
||||
yield context
|
||||
|
||||
|
||||
def setup_private_key(
|
||||
settings: AdminServerSettings,
|
||||
filename: str = KEY_FILENAME,
|
||||
regenerate: bool = False,
|
||||
) -> None:
|
||||
"""Setup secret manager private key."""
|
||||
keyfile = Path(filename)
|
||||
if settings.password_manager_directory:
|
||||
keyfile = settings.password_manager_directory / filename
|
||||
_initial_key_setup(settings, keyfile, regenerate)
|
||||
|
||||
|
||||
def _initial_key_setup(
|
||||
settings: AdminServerSettings,
|
||||
keyfile: Path,
|
||||
regenerate: bool = False,
|
||||
) -> bool:
|
||||
"""Set up initial keys."""
|
||||
if keyfile.exists() and not regenerate:
|
||||
return False
|
||||
|
||||
assert (
|
||||
settings.secret_key is not None
|
||||
), "Error: Could not load a secret key from environment."
|
||||
create_private_rsa_key(keyfile, password=settings.secret_key)
|
||||
return True
|
||||
@ -0,0 +1,33 @@
|
||||
"""Implement db structures for internal password manager
|
||||
|
||||
Revision ID: 1657c5d25d2c
|
||||
Revises: b4e135ff347a
|
||||
Create Date: 2025-06-21 07:22:17.792528
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '1657c5d25d2c'
|
||||
down_revision: Union[str, None] = 'b4e135ff347a'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('client', sa.Column('is_system', sa.Boolean(), nullable=False, default=False, server_default="0"))
|
||||
op.add_column('client_secret', sa.Column('is_system', sa.Boolean(), nullable=False, default=False, server_default="0"))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('client_secret', 'is_system')
|
||||
op.drop_column('client', 'is_system')
|
||||
@ -0,0 +1,44 @@
|
||||
"""Remove secret key from password database
|
||||
|
||||
Revision ID: 71f7272a6ee1
|
||||
Revises: 1657c5d25d2c
|
||||
Create Date: 2025-06-22 18:42:53.207334
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '71f7272a6ee1'
|
||||
down_revision: Union[str, None] = '1657c5d25d2c'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('managed_secret')
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('managed_secret',
|
||||
sa.Column('id', sa.CHAR(length=32), nullable=False),
|
||||
sa.Column('name', sa.VARCHAR(), nullable=False),
|
||||
sa.Column('description', sa.VARCHAR(), nullable=True),
|
||||
sa.Column('secret', sa.VARCHAR(), nullable=False),
|
||||
sa.Column('client_id', sa.CHAR(length=32), nullable=True),
|
||||
sa.Column('deleted', sa.BOOLEAN(), nullable=False),
|
||||
sa.Column('created_at', sa.DATETIME(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
|
||||
sa.Column('updated_at', sa.DATETIME(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
|
||||
sa.Column('deleted_at', sa.DATETIME(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['client_id'], ['client.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
@ -107,8 +107,7 @@ class ClientOperations:
|
||||
return ClientView.from_client(db_client)
|
||||
|
||||
async def create_client(
|
||||
self,
|
||||
create_model: ClientCreate,
|
||||
self, create_model: ClientCreate, system_client: bool = False
|
||||
) -> ClientView:
|
||||
"""Create a new client."""
|
||||
existing_id = await self.get_client_id(FlexID.name(create_model.name))
|
||||
@ -117,6 +116,15 @@ class ClientOperations:
|
||||
status_code=400, detail="Error: A client already exists with this name."
|
||||
)
|
||||
client = create_model.to_client()
|
||||
if system_client:
|
||||
statement = query_active_clients().where(Client.is_system.is_(True))
|
||||
results = await self.session.scalars(statement)
|
||||
other_system_clients = results.all()
|
||||
if other_system_clients:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Only one system client may exist"
|
||||
)
|
||||
client.is_system = True
|
||||
self.session.add(client)
|
||||
await self.session.flush()
|
||||
await self.session.commit()
|
||||
@ -246,6 +254,15 @@ class ClientOperations:
|
||||
|
||||
return ClientPolicyView.from_client(db_client)
|
||||
|
||||
async def get_system_client(self) -> ClientView:
|
||||
"""Get the system client, if it exists."""
|
||||
statement = query_active_clients().where(Client.is_system.is_(True))
|
||||
result = await self.session.scalars(statement)
|
||||
client = result.first()
|
||||
if not client:
|
||||
raise HTTPException(status_code=404, detail="No system client registered")
|
||||
return ClientView.from_client(client)
|
||||
|
||||
|
||||
def resolve_order(statement: Select[Any], order_by: str, reversed: bool) -> Select[Any]:
|
||||
"""Resolve ordering."""
|
||||
@ -261,12 +278,13 @@ def resolve_order(statement: Select[Any], order_by: str, reversed: bool) -> Sele
|
||||
statement = statement.order_by(column.desc())
|
||||
else:
|
||||
statement = statement.order_by(column.asc())
|
||||
#FIXME: Remove
|
||||
# FIXME: Remove
|
||||
LOG.info("Ordered by %s (%r)", order_by, reversed)
|
||||
return statement
|
||||
LOG.warning("Unsupported order field: %s", order_by)
|
||||
return statement
|
||||
|
||||
|
||||
def filter_client_statement(
|
||||
statement: Select[Any], params: ClientListParams, ignore_limits: bool = False
|
||||
) -> Select[Any]:
|
||||
@ -299,6 +317,7 @@ async def get_clients(
|
||||
.select_from(Client)
|
||||
.where(Client.is_deleted.is_not(True))
|
||||
.where(Client.is_active.is_not(False))
|
||||
.where(Client.is_system.is_not(True))
|
||||
)
|
||||
count_statement = cast(
|
||||
Select[tuple[int]],
|
||||
@ -307,7 +326,8 @@ async def get_clients(
|
||||
|
||||
total_results = (await session.scalars(count_statement)).one()
|
||||
|
||||
statement = filter_client_statement(query_active_clients(), filter_query, False)
|
||||
statement = query_active_clients().where(Client.is_system.is_not(True))
|
||||
statement = filter_client_statement(statement, filter_query, False)
|
||||
|
||||
results = await session.scalars(statement)
|
||||
remainder = total_results - filter_query.offset - filter_query.limit
|
||||
|
||||
@ -46,6 +46,25 @@ def create_client_router(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
||||
client_op = ClientOperations(session, request)
|
||||
return await client_op.create_client(client)
|
||||
|
||||
@router.get("/internal/system_client/", include_in_schema=False)
|
||||
async def get_system_client(
|
||||
request: Request,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientView:
|
||||
"""Get the system client."""
|
||||
client_op = ClientOperations(session, request)
|
||||
return await client_op.get_system_client()
|
||||
|
||||
@router.post("/internal/system_client/", include_in_schema=False)
|
||||
async def create_system_client(
|
||||
request: Request,
|
||||
client: ClientCreate,
|
||||
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||
) -> ClientView:
|
||||
"""Create system client."""
|
||||
client_op = ClientOperations(session, request)
|
||||
return await client_op.create_client(client, system_client=True)
|
||||
|
||||
@router.get("/clients/{client_identifier}")
|
||||
async def fetch_client_by_name(
|
||||
request: Request,
|
||||
|
||||
@ -242,7 +242,7 @@ async def resolve_client_secret_clients(
|
||||
# Ensure we don't create the object before we have at least one client.
|
||||
clients = ClientSecretDetailList(name=name)
|
||||
clients.ids.append(str(client_secret.id))
|
||||
if client_secret.client:
|
||||
if client_secret.client and not client_secret.client.is_system:
|
||||
clients.clients.append(
|
||||
ClientReference(
|
||||
id=str(client_secret.client.id), name=client_secret.client.name
|
||||
|
||||
@ -110,7 +110,6 @@ def get_async_engine(url: URL, echo: bool = False, **engine_kwargs: str) -> Asyn
|
||||
"""Get an async engine."""
|
||||
engine = create_async_engine(url, echo=echo, **engine_kwargs)
|
||||
if url.drivername.startswith("sqlite+"):
|
||||
|
||||
@event.listens_for(engine.sync_engine, "connect")
|
||||
def set_sqlite_pragma(
|
||||
dbapi_connection: sqlite3.Connection, _connection_record: object
|
||||
|
||||
@ -67,6 +67,7 @@ class Client(Base):
|
||||
|
||||
is_active: Mapped[bool] = mapped_column(sa.Boolean, default=True)
|
||||
is_deleted: Mapped[bool] = mapped_column(sa.Boolean, default=False)
|
||||
is_system: Mapped[bool] = mapped_column(sa.Boolean, default=False)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||
@ -141,6 +142,8 @@ class ClientSecret(Base):
|
||||
client: Mapped[Client] = relationship(back_populates="secrets")
|
||||
deleted: Mapped[bool] = mapped_column(default=False)
|
||||
|
||||
is_system: Mapped[bool] = mapped_column(sa.Boolean, default=False)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||
)
|
||||
|
||||
@ -140,11 +140,23 @@ class ShellStoreSecret(CommandDispatcher):
|
||||
secret=secret_name,
|
||||
)
|
||||
|
||||
await self.store_managed_secret(secret_name, secret_data)
|
||||
|
||||
def encrypt_secret(self, value: str) -> str:
|
||||
"""Encrypt a secret."""
|
||||
public_key = load_public_key(self.client.public_key.encode())
|
||||
return encrypt_string(value, public_key)
|
||||
|
||||
async def store_managed_secret(self, secret_name: str, secret_data: str) -> None:
|
||||
"""Store managed secret."""
|
||||
system_client = await self.backend.get_system_client()
|
||||
if not system_client:
|
||||
return
|
||||
public_key = load_public_key(system_client.public_key.encode())
|
||||
encrypted = encrypt_string(secret_data, public_key)
|
||||
await self.backend.create_client_secret(("id", str(system_client.id)), secret_name, encrypted)
|
||||
await self.audit(operation=Operation.CREATE, message="Managed secret entry created.", secret=secret_name)
|
||||
|
||||
async def get_secret_on_stdin(self) -> str | None:
|
||||
"""Get secret from stdin."""
|
||||
secret_data = ""
|
||||
|
||||
@ -6,6 +6,7 @@ admin and sshd library do not need to implement the same
|
||||
|
||||
import logging
|
||||
from typing import Any, Literal, Self, override
|
||||
import uuid
|
||||
import httpx
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
@ -325,6 +326,28 @@ class SshecretBackend(BaseBackend):
|
||||
path = "/api/v1/clients/"
|
||||
response = await self._post(path, json=data)
|
||||
|
||||
async def create_system_client(self, name: str, public_key: str) -> Client:
|
||||
"""Create system client."""
|
||||
if not validate_public_key(public_key):
|
||||
raise BackendValidationError("Error: Invalid public key format.")
|
||||
|
||||
data = {
|
||||
"name": name,
|
||||
"public_key": public_key,
|
||||
"description": "Internal system client",
|
||||
}
|
||||
path = "/api/v1/internal/system_client/"
|
||||
response = await self._post(path, json=data)
|
||||
return Client.model_validate(response.json())
|
||||
|
||||
async def get_system_client(self) -> Client | None:
|
||||
"""Get the system client."""
|
||||
path = "/api/v1/internal/system_client/"
|
||||
response = await self._get(path)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
return Client.model_validate(response.json())
|
||||
|
||||
async def get_clients(self, filter: ClientFilter | None = None) -> list[Client]:
|
||||
"""Get all clients."""
|
||||
clients: list[Client] = []
|
||||
@ -375,7 +398,7 @@ class SshecretBackend(BaseBackend):
|
||||
|
||||
async def create_client_secret(
|
||||
self, client_idname: KeySpec, secret_name: str, encrypted_secret: str
|
||||
) -> None:
|
||||
) -> ClientSecret:
|
||||
"""Create a secret.
|
||||
|
||||
This will overwrite any existing secret with that name.
|
||||
@ -383,6 +406,8 @@ class SshecretBackend(BaseBackend):
|
||||
client_key = _key(client_idname)
|
||||
path = f"api/v1/clients/{client_key}/secrets/{secret_name}"
|
||||
response = await self._put(path, json={"value": encrypted_secret})
|
||||
secret = ClientSecret.model_validate(response.json())
|
||||
return secret
|
||||
|
||||
async def get_client_secret(
|
||||
self, client_idname: KeySpec, secret_idname: KeySpec
|
||||
|
||||
@ -8,14 +8,14 @@ from pathlib import Path
|
||||
from sqlmodel import Session, create_engine
|
||||
from sshecret.crypto import generate_private_key, write_private_key
|
||||
from sshecret_admin.auth.authentication import hash_password
|
||||
from sshecret_admin.auth.models import AuthProvider, User, init_db
|
||||
from sshecret_admin.auth.models import AuthProvider, User, Base
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
|
||||
def create_test_admin_user(settings: AdminServerSettings, username: str, password: str) -> None:
|
||||
"""Create a test admin user."""
|
||||
hashed_password = hash_password(password)
|
||||
engine = create_engine(settings.admin_db)
|
||||
init_db(engine)
|
||||
Base.metadata.create_all(engine)
|
||||
with Session(engine) as session:
|
||||
user = User(username=username, hashed_password=hashed_password, provider=AuthProvider.LOCAL, email="test@test.com")
|
||||
session.add(user)
|
||||
|
||||
1
tests/integration/admin/__init__.py
Normal file
1
tests/integration/admin/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
|
||||
67
tests/integration/admin/base.py
Normal file
67
tests/integration/admin/base.py
Normal file
@ -0,0 +1,67 @@
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import httpx
|
||||
|
||||
from sshecret.backend import Client
|
||||
|
||||
from sshecret.crypto import generate_private_key, generate_public_key_string
|
||||
|
||||
from ..types import AdminServer
|
||||
|
||||
|
||||
def make_test_key() -> str:
|
||||
"""Generate a test key."""
|
||||
private_key = generate_private_key()
|
||||
return generate_public_key_string(private_key.public_key())
|
||||
|
||||
|
||||
class BaseAdminTests:
|
||||
"""Base admin test class."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def http_client(
|
||||
self, admin_server: AdminServer, authenticate: bool = True
|
||||
) -> AsyncIterator[httpx.AsyncClient]:
|
||||
"""Run a client towards the admin rest api."""
|
||||
admin_url, credentials = admin_server
|
||||
username, password = credentials
|
||||
headers: dict[str, str] | None = None
|
||||
if authenticate:
|
||||
async with httpx.AsyncClient(base_url=admin_url) as client:
|
||||
|
||||
response = await client.post(
|
||||
"api/v1/token", data={"username": username, "password": password}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "access_token" in data
|
||||
token = data["access_token"]
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
async with httpx.AsyncClient(base_url=admin_url, headers=headers) as client:
|
||||
yield client
|
||||
|
||||
async def create_client(
|
||||
self,
|
||||
admin_server: AdminServer,
|
||||
name: str,
|
||||
public_key: str | None = None,
|
||||
) -> Client:
|
||||
"""Create a client."""
|
||||
if not public_key:
|
||||
public_key = make_test_key()
|
||||
|
||||
new_client = {
|
||||
"name": name,
|
||||
"public_key": public_key,
|
||||
"sources": ["192.0.2.0/24"],
|
||||
}
|
||||
|
||||
async with self.http_client(admin_server, True) as http_client:
|
||||
response = await http_client.post("api/v1/clients/", json=new_client)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
client = Client.model_validate(data)
|
||||
|
||||
return client
|
||||
@ -1,76 +1,12 @@
|
||||
"""Tests of the admin interface."""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
import allure
|
||||
import pytest
|
||||
|
||||
import httpx
|
||||
|
||||
from allure_commons.types import Severity
|
||||
|
||||
from sshecret.backend import Client
|
||||
|
||||
from sshecret.crypto import generate_private_key, generate_public_key_string
|
||||
|
||||
from .types import AdminServer
|
||||
|
||||
|
||||
def make_test_key() -> str:
|
||||
"""Generate a test key."""
|
||||
private_key = generate_private_key()
|
||||
return generate_public_key_string(private_key.public_key())
|
||||
|
||||
|
||||
class BaseAdminTests:
|
||||
"""Base admin test class."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def http_client(
|
||||
self, admin_server: AdminServer, authenticate: bool = True
|
||||
) -> AsyncIterator[httpx.AsyncClient]:
|
||||
"""Run a client towards the admin rest api."""
|
||||
admin_url, credentials = admin_server
|
||||
username, password = credentials
|
||||
headers: dict[str, str] | None = None
|
||||
if authenticate:
|
||||
async with httpx.AsyncClient(base_url=admin_url) as client:
|
||||
|
||||
response = await client.post(
|
||||
"api/v1/token", data={"username": username, "password": password}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "access_token" in data
|
||||
token = data["access_token"]
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
async with httpx.AsyncClient(base_url=admin_url, headers=headers) as client:
|
||||
yield client
|
||||
|
||||
async def create_client(
|
||||
self,
|
||||
admin_server: AdminServer,
|
||||
name: str,
|
||||
public_key: str | None = None,
|
||||
) -> Client:
|
||||
"""Create a client."""
|
||||
if not public_key:
|
||||
public_key = make_test_key()
|
||||
|
||||
new_client = {
|
||||
"name": name,
|
||||
"public_key": public_key,
|
||||
"sources": ["192.0.2.0/24"],
|
||||
}
|
||||
|
||||
async with self.http_client(admin_server, True) as http_client:
|
||||
response = await http_client.post("api/v1/clients/", json=new_client)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
client = Client.model_validate(data)
|
||||
|
||||
return client
|
||||
from ..types import AdminServer
|
||||
from .base import BaseAdminTests
|
||||
|
||||
|
||||
@allure.title("Admin API")
|
||||
@ -196,7 +132,9 @@ class TestAdminApiSecrets(BaseAdminTests):
|
||||
assert "testclient" in data["clients"]
|
||||
|
||||
@allure.title("Test adding a secret with automatic value")
|
||||
@allure.description("Test that we can add a secret where we let the system come up with the value of a given length.")
|
||||
@allure.description(
|
||||
"Test that we can add a secret where we let the system come up with the value of a given length."
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_secret_auto(self, admin_server: AdminServer) -> None:
|
||||
"""Test adding a secret with an auto-generated value."""
|
||||
634
tests/integration/admin/test_secret_manager.py
Normal file
634
tests/integration/admin/test_secret_manager.py
Normal file
@ -0,0 +1,634 @@
|
||||
"""Test secret manager.
|
||||
|
||||
This package tests the rewritten secret manager system.
|
||||
|
||||
This is technically an integration test, as it requires the other subsystems to
|
||||
run, but it uses the internal API rather than the exposed routes.
|
||||
"""
|
||||
|
||||
import allure
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
from sshecret_admin.services.models import SecretGroup
|
||||
from sshecret_admin.services.secret_manager import (
|
||||
password_manager_context,
|
||||
AsyncSecretContext,
|
||||
InvalidSecretNameError,
|
||||
InvalidGroupNameError,
|
||||
)
|
||||
from sshecret_admin.auth.models import Base, PasswordDB
|
||||
from sshecret_admin.services.master_password import setup_master_password
|
||||
|
||||
# -------- global parameter sets start here -------- #
|
||||
|
||||
|
||||
# -------- Fixtures start here -------- #
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(autouse=True)
|
||||
async def create_admin_db(admin_server_settings: AdminServerSettings) -> None:
|
||||
"""Create the database."""
|
||||
engine = create_engine(admin_server_settings.admin_db)
|
||||
Base.metadata.create_all(engine)
|
||||
encr_master_password = setup_master_password(
|
||||
settings=admin_server_settings, regenerate=True
|
||||
)
|
||||
|
||||
with Session(engine) as session:
|
||||
pwdb = PasswordDB(id=1, encrypted_password=encr_master_password)
|
||||
session.add(pwdb)
|
||||
session.commit()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture()
|
||||
async def secrets_manager(admin_server_settings: AdminServerSettings):
|
||||
"""Test that the context manager can be created."""
|
||||
async with password_manager_context(
|
||||
admin_server_settings, "TEST", "127.0.0.1"
|
||||
) as manager:
|
||||
yield manager
|
||||
|
||||
|
||||
# -------- Tests start here -------- #
|
||||
|
||||
|
||||
@allure.title("Adding entries")
|
||||
@pytest.mark.parametrize("name,secret", [("testentry", "testsecret")])
|
||||
class TestSecretsAddEntry:
|
||||
"""Tests for the add_entry method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_entry(
|
||||
self,
|
||||
secrets_manager: AsyncSecretContext,
|
||||
name: str,
|
||||
secret: str,
|
||||
) -> None:
|
||||
"""Test add entry.
|
||||
|
||||
This tests add_entry and get_secret
|
||||
"""
|
||||
await secrets_manager.add_entry(name, secret)
|
||||
stored_secret = await secrets_manager.get_secret(name)
|
||||
assert stored_secret == secret
|
||||
|
||||
async def test_add_entry_duplicate(
|
||||
self,
|
||||
secrets_manager: AsyncSecretContext,
|
||||
name: str,
|
||||
secret: str,
|
||||
) -> None:
|
||||
"""Test adding an entry twice."""
|
||||
await secrets_manager.add_entry(name, secret)
|
||||
stored_secret = await secrets_manager.get_secret(name)
|
||||
assert stored_secret == secret
|
||||
|
||||
with pytest.raises(InvalidSecretNameError):
|
||||
await secrets_manager.add_entry(name, secret)
|
||||
|
||||
async def test_add_entry_with_group(
|
||||
self,
|
||||
secrets_manager: AsyncSecretContext,
|
||||
name: str,
|
||||
secret: str,
|
||||
) -> None:
|
||||
"""Test adding a secret with a group."""
|
||||
group = "testgroup"
|
||||
await secrets_manager.add_group(group)
|
||||
await secrets_manager.add_entry(name, secret, group_path=group)
|
||||
result = await secrets_manager.get_entry_group(name)
|
||||
assert result == group
|
||||
|
||||
async def test_add_entry_with_nonexisting_group(
|
||||
self,
|
||||
secrets_manager: AsyncSecretContext,
|
||||
name: str,
|
||||
secret: str,
|
||||
) -> None:
|
||||
"""Test adding a secret where the group does not exist."""
|
||||
group = "testgroup"
|
||||
with pytest.raises(InvalidGroupNameError):
|
||||
await secrets_manager.add_entry(name, secret, group_path=group)
|
||||
|
||||
async def test_add_entry_with_deep_path(
|
||||
self,
|
||||
secrets_manager: AsyncSecretContext,
|
||||
name: str,
|
||||
secret: str,
|
||||
) -> None:
|
||||
"""Test adding a secret to a nested group with a path specification."""
|
||||
await secrets_manager.add_group("root")
|
||||
await secrets_manager.add_group("nested", parent_group="root")
|
||||
await secrets_manager.add_entry(name, secret, group_path="/root/nested")
|
||||
group = await secrets_manager.get_secret_group("/root/nested")
|
||||
assert group is not None
|
||||
assert name in group.entries
|
||||
|
||||
async def test_overwrite_secret(
|
||||
self, secrets_manager: AsyncSecretContext, name: str, secret: str
|
||||
) -> None:
|
||||
"""Test overwriting a secret."""
|
||||
await secrets_manager.add_entry(name, secret)
|
||||
stored_secret = await secrets_manager.get_secret(name)
|
||||
assert stored_secret == secret
|
||||
|
||||
new_secret = "newsecret"
|
||||
await secrets_manager.add_entry(name, new_secret, overwrite=True)
|
||||
|
||||
stored_secret = await secrets_manager.get_secret(name)
|
||||
assert stored_secret == new_secret
|
||||
|
||||
|
||||
@allure.title("Creating groups")
|
||||
class TestSecretGroupCreation:
|
||||
"""Test secret groups."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"group_name", ["testgroup", "long group name with spaces", "blåbærgrød"]
|
||||
)
|
||||
@allure.title("Add a group name {group_name}")
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_group(
|
||||
self,
|
||||
secrets_manager: AsyncSecretContext,
|
||||
group_name: str,
|
||||
) -> None:
|
||||
"""Get adding a group."""
|
||||
await secrets_manager.add_group(group_name)
|
||||
groups = await secrets_manager.get_secret_groups()
|
||||
assert len(groups) == 1
|
||||
assert groups[0].name == group_name
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_group_with_parent(
|
||||
self, secrets_manager: AsyncSecretContext
|
||||
) -> None:
|
||||
"""Test add a group with a parent group."""
|
||||
parent_name = "parent"
|
||||
child_name = "child"
|
||||
await secrets_manager.add_group(parent_name)
|
||||
await secrets_manager.add_group(child_name, parent_group=parent_name)
|
||||
parent_group = await secrets_manager.get_secret_group(f"/parent")
|
||||
assert parent_group is not None
|
||||
assert len(parent_group.children) == 1
|
||||
assert parent_group.children[0].name == child_name
|
||||
|
||||
child_group = await secrets_manager.get_secret_group("/parent/child")
|
||||
assert child_group is not None
|
||||
assert child_group.name == child_name
|
||||
assert child_group.parent_group is not None
|
||||
assert child_group.parent_group.name == parent_name
|
||||
assert len(child_group.children) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_group_as_path(self, secrets_manager: AsyncSecretContext) -> None:
|
||||
"""Add a nested group with path annotation."""
|
||||
parent_name = "parent"
|
||||
child_path = "/parent/child"
|
||||
await secrets_manager.add_group(parent_name)
|
||||
await secrets_manager.add_group(child_path)
|
||||
parent_group = await secrets_manager.get_secret_group(f"/parent")
|
||||
assert parent_group is not None
|
||||
assert len(parent_group.children) == 1
|
||||
assert parent_group.children[0].name == "child"
|
||||
|
||||
child_group = await secrets_manager.get_secret_group(child_path)
|
||||
assert child_group is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_overlapping_names(self, secrets_manager: AsyncSecretContext) -> None:
|
||||
"""Test having overlapping names in different groups."""
|
||||
await secrets_manager.add_group("root")
|
||||
with pytest.raises(InvalidGroupNameError):
|
||||
await secrets_manager.add_group("/root")
|
||||
|
||||
await secrets_manager.add_group("/root/root")
|
||||
|
||||
group = await secrets_manager.get_secret_group("/root/root")
|
||||
assert group is not None
|
||||
assert group.name == "root"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_group_with_nonexisting_parent(
|
||||
self, secrets_manager: AsyncSecretContext
|
||||
) -> None:
|
||||
"""Test adding a group with a nonexisting parent."""
|
||||
with pytest.raises(InvalidGroupNameError):
|
||||
await secrets_manager.add_group("orphan", parent_group="unknown")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_duplicate_group(
|
||||
self, secrets_manager: AsyncSecretContext
|
||||
) -> None:
|
||||
"""Test adding the same group twice."""
|
||||
await secrets_manager.add_group("snowflake")
|
||||
with pytest.raises(InvalidGroupNameError):
|
||||
await secrets_manager.add_group("snowflake")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"group_name,description", [("testgroup", "test description")]
|
||||
)
|
||||
@allure.title("Add a group name {group_name} with description {description}")
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_group_with_description(
|
||||
self,
|
||||
secrets_manager: AsyncSecretContext,
|
||||
group_name: str,
|
||||
description: str,
|
||||
) -> None:
|
||||
"""Test adding a group with description."""
|
||||
await secrets_manager.add_group(group_name, description)
|
||||
result = await secrets_manager.get_secret_group(group_name)
|
||||
assert result is not None
|
||||
assert result.description == description
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"groups",
|
||||
[
|
||||
[
|
||||
("root", None, "root"),
|
||||
("level1", "root", "/root/level1"),
|
||||
("level2", "level1", "/root/level1/level2"),
|
||||
],
|
||||
[("flat1", None, "flat1"), ("flat2", None, "flat2")],
|
||||
[
|
||||
("stub", None, "stub"),
|
||||
("root", None, "root"),
|
||||
("nested", "root", "/root/nested"),
|
||||
],
|
||||
],
|
||||
)
|
||||
@allure.title("Listing groups")
|
||||
class TestSecretGroupListing:
|
||||
"""Tests for listing groups."""
|
||||
|
||||
@pytest_asyncio.fixture(autouse=True)
|
||||
async def create_groups(
|
||||
self,
|
||||
secrets_manager: AsyncSecretContext,
|
||||
groups: list[tuple[str, str | None, str]],
|
||||
) -> None:
|
||||
"""Pre-create groups."""
|
||||
for name, parent_group, _path in groups:
|
||||
await secrets_manager.add_group(name, parent_group=parent_group)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_secret_groups_list(
|
||||
self,
|
||||
secrets_manager: AsyncSecretContext,
|
||||
groups: list[tuple[str, str | None, str]],
|
||||
) -> None:
|
||||
"""Test the flat get_secret_groups_list."""
|
||||
# Create three levels of content
|
||||
group_list = await secrets_manager.get_secret_group_list()
|
||||
assert len(group_list) == len(groups)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_secret_groups(
|
||||
self,
|
||||
secrets_manager: AsyncSecretContext,
|
||||
groups: list[tuple[str, str | None, str]],
|
||||
) -> None:
|
||||
"""Test the tree-oriented get_secretsgroups."""
|
||||
group_map = dict([(group[0], group[1]) for group in sorted(groups)])
|
||||
group_tree = await secrets_manager.get_secret_groups()
|
||||
root_groups = [key for key, value in group_map.items() if value is None]
|
||||
assert len(group_tree) == len(root_groups)
|
||||
reconstructed_groups: list[tuple[str, str | None]] = []
|
||||
|
||||
def crawl_tree(item: SecretGroup) -> list[tuple[str, str | None]]:
|
||||
"""Crawl a tree recursively."""
|
||||
parent_group_name = None
|
||||
if item.parent_group:
|
||||
parent_group_name = item.parent_group.name
|
||||
items: list[tuple[str, str | None]] = [(item.name, parent_group_name)]
|
||||
for child in item.children:
|
||||
items.extend(crawl_tree(child))
|
||||
|
||||
return items
|
||||
|
||||
for item in group_tree:
|
||||
reconstructed_groups.extend(crawl_tree(item))
|
||||
|
||||
assert dict(sorted(reconstructed_groups)) == group_map
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_secret_groups_with_secrets(
|
||||
self,
|
||||
secrets_manager: AsyncSecretContext,
|
||||
groups: list[tuple[str, str | None, str]],
|
||||
) -> None:
|
||||
"""Test fetching groups where there are secrets in all groups."""
|
||||
# We will create exactly two secrets in each group.
|
||||
for group_name, _parent, path in groups:
|
||||
await secrets_manager.add_entry(
|
||||
f"{group_name}_1", f"{group_name}_secret_1", group_path=path
|
||||
)
|
||||
await secrets_manager.add_entry(
|
||||
f"{group_name}_2", f"{group_name}_secret_2", group_path=path
|
||||
)
|
||||
|
||||
group_list = await secrets_manager.get_secret_group_list()
|
||||
for group in group_list:
|
||||
assert len(group.entries) == 2
|
||||
assert group.entries[0] == f"{group.name}_1"
|
||||
assert group.entries[1] == f"{group.name}_2"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query,expected,groups,children",
|
||||
[
|
||||
("MATCH", 1, [("MATCH", None), ("SOMETHINGELSE", None)], 0),
|
||||
("MATCH", 1, [("root", None), ("MATCH", "root"), ("SOMETHINGELSE", "root")], 0),
|
||||
("MATCH", 3, [("MATCH1", None), ("MATCH2", None), ("MATCH3", None)], 0),
|
||||
("MATCH", 1, [("root", None), ("MATCH", "root"), ("CHILD", "MATCH")], 1),
|
||||
(
|
||||
"NOMATCH",
|
||||
0,
|
||||
[("foo", None), ("bar", None), ("foobar", "foo"), ("barfoo", "bar")],
|
||||
0,
|
||||
),
|
||||
],
|
||||
)
|
||||
@allure.title("Searching in groups using patterns")
|
||||
class TestGroupSearchPattern:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_list_pattern(
|
||||
self,
|
||||
secrets_manager: AsyncSecretContext,
|
||||
query: str,
|
||||
expected: int,
|
||||
groups: list[tuple[str, str | None]],
|
||||
children: int,
|
||||
) -> None:
|
||||
"""Test matching a pattern."""
|
||||
for name, parent_group in groups:
|
||||
await secrets_manager.add_group(name, parent_group=parent_group)
|
||||
|
||||
result = await secrets_manager.get_secret_group_list(pattern=query, regex=False)
|
||||
assert len(result) == expected
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_tree_pattern(
|
||||
self,
|
||||
secrets_manager: AsyncSecretContext,
|
||||
query: str,
|
||||
expected: int,
|
||||
groups: list[tuple[str, str | None]],
|
||||
children: int,
|
||||
) -> None:
|
||||
"""Test matching a pattern with a tree result."""
|
||||
for name, parent_group in groups:
|
||||
await secrets_manager.add_group(name, parent_group=parent_group)
|
||||
|
||||
result = await secrets_manager.get_secret_groups(pattern=query, regex=False)
|
||||
assert len(result) == expected
|
||||
if expected == 1 and children > 0:
|
||||
assert len(result[0].children) == children
|
||||
|
||||
|
||||
@allure.title("Modifying groups")
|
||||
class TestGroupModification:
|
||||
"""Test modifying groups."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"group_name,parent,description",
|
||||
[("test", None, "test description"), ("test", "root", "test_description")],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_group_description(
|
||||
self,
|
||||
secrets_manager: AsyncSecretContext,
|
||||
group_name: str,
|
||||
parent: str | None,
|
||||
description: str,
|
||||
) -> None:
|
||||
"""Test setting a description on a group."""
|
||||
if parent:
|
||||
await secrets_manager.add_group(parent)
|
||||
await secrets_manager.add_group(group_name, parent_group=parent)
|
||||
path = group_name
|
||||
if parent:
|
||||
path = f"/{parent}/{group_name}"
|
||||
group = await secrets_manager.get_secret_group(path)
|
||||
assert group is not None
|
||||
assert group.description is None
|
||||
|
||||
await secrets_manager.set_group_description(path, description)
|
||||
group = await secrets_manager.get_secret_group(path)
|
||||
assert group is not None
|
||||
assert group.description == description
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"groups,target_group, expected_path",
|
||||
[
|
||||
(
|
||||
[("root", None), ("test", None)],
|
||||
("test", "root"),
|
||||
"/root/test",
|
||||
),
|
||||
(
|
||||
[("root", None), ("test", "root")],
|
||||
("/root/test", None),
|
||||
"test",
|
||||
),
|
||||
([("test", None)], ("test", None), "test"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_move_group(
|
||||
self,
|
||||
secrets_manager: AsyncSecretContext,
|
||||
groups: list[tuple[str, str | None]],
|
||||
target_group: tuple[str, str | None],
|
||||
expected_path: str,
|
||||
) -> None:
|
||||
"""Test moving groups around."""
|
||||
for group_name, parent_name in groups:
|
||||
await secrets_manager.add_group(group_name, parent_group=parent_name)
|
||||
|
||||
group_name, target = target_group
|
||||
await secrets_manager.move_group(group_name, target)
|
||||
group = await secrets_manager.get_secret_group(expected_path)
|
||||
assert group is not None
|
||||
|
||||
|
||||
@allure.title("Deleting items")
|
||||
class TestSecretManagerDeletions:
|
||||
"""Test secret manager deletions."""
|
||||
|
||||
@pytest_asyncio.fixture(autouse=True)
|
||||
async def create_test_data(self, secrets_manager: AsyncSecretContext) -> None:
|
||||
"""Create some test data."""
|
||||
groups = [
|
||||
("root", None, "root"),
|
||||
("level1", "root", "/root/level1"),
|
||||
("level2", "level1", "/root/level1/level2"),
|
||||
]
|
||||
for n in range(2):
|
||||
await secrets_manager.add_entry(f"ungrouped_{n}", "secret")
|
||||
for group_name, parent_name, path in groups:
|
||||
await secrets_manager.add_group(group_name, parent_group=parent_name)
|
||||
for n in range(2):
|
||||
await secrets_manager.add_entry(
|
||||
f"{group_name}_{n}", "secret", group_path=path
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"name,group_name",
|
||||
[("root_1", "root"), ("level1_0", "/root/level1"), ("ungrouped_1", None)],
|
||||
)
|
||||
@allure.title("Delete secret {name in group {group_name}}")
|
||||
@pytest.mark.asyncio
|
||||
async def test_secret_deletion(
|
||||
self, secrets_manager: AsyncSecretContext, name: str, group_name: str | None
|
||||
) -> None:
|
||||
"""Test secret deletion."""
|
||||
if group_name:
|
||||
group = await secrets_manager.get_secret_group(group_name)
|
||||
assert group is not None
|
||||
assert name in group.entries
|
||||
|
||||
secret = await secrets_manager.get_secret(name)
|
||||
assert secret is not None
|
||||
|
||||
await secrets_manager.delete_entry(name)
|
||||
|
||||
secret = await secrets_manager.get_secret(name)
|
||||
assert secret is None
|
||||
|
||||
if group_name:
|
||||
group = await secrets_manager.get_secret_group(group_name)
|
||||
assert group is not None
|
||||
assert name not in group.entries
|
||||
|
||||
@pytest.mark.parametrize("name", ["NONEXISTING"])
|
||||
@allure.title("Deleting non-existing entry {name}")
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonexisting_entry(
|
||||
self, secrets_manager: AsyncSecretContext, name: str
|
||||
) -> None:
|
||||
"""Test deleting something that doesn't exist."""
|
||||
secret = await secrets_manager.get_secret(name)
|
||||
assert secret is None
|
||||
# Deleting something that is already deleted returns None
|
||||
await secrets_manager.delete_entry(name)
|
||||
|
||||
@pytest.mark.parametrize("path", ["/root/level1"])
|
||||
@allure.title("Deleting group {path}")
|
||||
async def test_group_delete(
|
||||
self, secrets_manager: AsyncSecretContext, path: str
|
||||
) -> None:
|
||||
"""Test deleting a group."""
|
||||
group = await secrets_manager.get_secret_group(path)
|
||||
assert group is not None
|
||||
entries = list(group.entries)
|
||||
await secrets_manager.delete_group(path)
|
||||
group = await secrets_manager.get_secret_group(path)
|
||||
assert group is None
|
||||
|
||||
for name in entries:
|
||||
new_grouping = await secrets_manager.get_entry_group(name)
|
||||
assert new_grouping is None
|
||||
|
||||
|
||||
@allure.title("Other tests")
|
||||
class TestSecretManagerOther:
|
||||
"""Uncategorized tests to standardize module."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_secret_nonexisting(
|
||||
self, secrets_manager: AsyncSecretContext
|
||||
) -> None:
|
||||
"""Test get_secret with invalid name."""
|
||||
result = await secrets_manager.get_secret("NOMATCH")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"group_name,num_grouped,num_ungrouped",
|
||||
[("GROUP", 3, 3), ("GROUP", 3, 0), ("GROUP", 0, 0)],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_ungrouped_secrets(
|
||||
self,
|
||||
secrets_manager: AsyncSecretContext,
|
||||
group_name: str,
|
||||
num_grouped: int,
|
||||
num_ungrouped: int,
|
||||
) -> None:
|
||||
"""Test get_ungrouped_secrets."""
|
||||
await secrets_manager.add_group(group_name)
|
||||
for n in range(num_ungrouped):
|
||||
await secrets_manager.add_entry(f"ungrouped_{n}", "secret")
|
||||
|
||||
for n in range(num_grouped):
|
||||
await secrets_manager.add_entry(
|
||||
f"grouped_{n}", "secret", group_path="GROUP"
|
||||
)
|
||||
|
||||
ungrouped = await secrets_manager.get_ungrouped_secrets()
|
||||
assert len(ungrouped) == num_ungrouped
|
||||
matching = [entry for entry in ungrouped if entry.startswith("ungrouped_")]
|
||||
assert len(matching) == num_ungrouped
|
||||
|
||||
grouped_secrets = await secrets_manager.get_available_secrets(
|
||||
group_path=group_name
|
||||
)
|
||||
assert len(grouped_secrets) == num_grouped
|
||||
all_secrets = await secrets_manager.get_available_secrets()
|
||||
assert len(all_secrets) == (num_ungrouped + num_grouped)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"entries", [[("test1", "secret1"), ("test2", "secret2")], []]
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_available_secrets(
|
||||
self, secrets_manager: AsyncSecretContext, entries: list[tuple[str, str]]
|
||||
) -> None:
|
||||
"""Test the get_available_secrets method."""
|
||||
for name, secret in entries:
|
||||
await secrets_manager.add_entry(name, secret)
|
||||
|
||||
entry_names = [entry[0] for entry in entries]
|
||||
response = await secrets_manager.get_available_secrets()
|
||||
assert len(response) == len(entries)
|
||||
|
||||
assert sorted(response) == sorted(entry_names)
|
||||
|
||||
async def test_get_secret_groups_none(
|
||||
self, secrets_manager: AsyncSecretContext
|
||||
) -> None:
|
||||
"""Test get_secret_groups with no groups created."""
|
||||
result = await secrets_manager.get_secret_groups()
|
||||
assert len(result) == 0
|
||||
result_flat = await secrets_manager.get_secret_group_list()
|
||||
assert len(result_flat) == 0
|
||||
|
||||
@allure.title("Search for a group using regular expression")
|
||||
async def test_group_regex_search(
|
||||
self, secrets_manager: AsyncSecretContext
|
||||
) -> None:
|
||||
"""Search for entries with regular expressions."""
|
||||
groups = [
|
||||
"test1",
|
||||
"test2",
|
||||
"other",
|
||||
"somethingelse",
|
||||
]
|
||||
|
||||
for group in groups:
|
||||
await secrets_manager.add_group(group)
|
||||
|
||||
results = await secrets_manager.get_secret_group_list(
|
||||
pattern="^test", regex=True
|
||||
)
|
||||
assert len(results) == 2
|
||||
for group in results:
|
||||
assert group.name.startswith("test")
|
||||
@ -79,6 +79,32 @@ async def run_backend_server(test_ports: TestPorts):
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(
|
||||
scope=TEST_SCOPE, name="admin_server_settings", loop_scope=LOOP_SCOPE
|
||||
)
|
||||
async def get_admin_server_settings(
|
||||
test_ports: TestPorts, backend_server: tuple[str, str]
|
||||
):
|
||||
"""Get admin server settings."""
|
||||
backend_url, backend_token = backend_server
|
||||
port = test_ports.admin
|
||||
secret_key = secrets.token_urlsafe(32)
|
||||
with in_tempdir() as admin_work_path:
|
||||
admin_db = admin_work_path / "ssh_admin.db"
|
||||
admin_settings = AdminServerSettings.model_validate(
|
||||
{
|
||||
"sshecret_backend_url": backend_url,
|
||||
"backend_token": backend_token,
|
||||
"secret_key": secret_key,
|
||||
"listen_address": "0.0.0.0",
|
||||
"port": port,
|
||||
"database": str(admin_db.absolute()),
|
||||
"password_manager_directory": str(admin_work_path.absolute()),
|
||||
}
|
||||
)
|
||||
yield admin_settings
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope=TEST_SCOPE, name="admin_server", loop_scope=LOOP_SCOPE)
|
||||
async def run_admin_server(test_ports: TestPorts, backend_server: tuple[str, str]):
|
||||
"""Run admin server."""
|
||||
@ -98,7 +124,7 @@ async def run_admin_server(test_ports: TestPorts, backend_server: tuple[str, str
|
||||
"password_manager_directory": str(admin_work_path.absolute()),
|
||||
}
|
||||
)
|
||||
admin_app = create_admin_app(admin_settings)
|
||||
admin_app = create_admin_app(admin_settings, create_db=True)
|
||||
config = uvicorn.Config(app=admin_app, port=port, loop="asyncio")
|
||||
server = uvicorn.Server(config=config)
|
||||
server_task = asyncio.create_task(server.serve())
|
||||
|
||||
Reference in New Issue
Block a user