Write new secret manager using existing RSA logic
This commit is contained in:
@ -5,9 +5,9 @@ import logging
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from sshecret_admin.auth import Token, authenticate_user, create_access_token
|
||||
from sshecret_admin.auth import Token, authenticate_user_async, create_access_token
|
||||
from sshecret_admin.core.dependencies import AdminDependencies
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -19,11 +19,12 @@ def create_router(dependencies: AdminDependencies) -> APIRouter:
|
||||
|
||||
@app.post("/token")
|
||||
async def login_for_access_token(
|
||||
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||
|
||||
session: Annotated[AsyncSession, Depends(dependencies.get_async_session)],
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
) -> Token:
|
||||
"""Login user and generate token."""
|
||||
user = authenticate_user(session, form_data.username, form_data.password)
|
||||
user = await authenticate_user_async(session, form_data.username, form_data.password)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
||||
@ -128,7 +128,7 @@ def create_router(dependencies: AdminDependencies) -> APIRouter:
|
||||
group = await admin.get_secret_group(group_name)
|
||||
if not group:
|
||||
return
|
||||
await admin.delete_secret_group(group_name, keep_entries=True)
|
||||
await admin.delete_secret_group(group_name)
|
||||
|
||||
@app.post("/secrets/groups/{group_name}/{secret_name}")
|
||||
async def move_secret_to_group(
|
||||
|
||||
@ -5,8 +5,9 @@
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from fastapi.security.utils import get_authorization_scheme_param
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@ -57,6 +58,31 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||
raise credentials_exception
|
||||
return user
|
||||
|
||||
def get_client_origin(request: Request) -> str:
|
||||
"""Get client origin."""
|
||||
fallback_origin = "UNKNOWN"
|
||||
if request.client:
|
||||
return request.client.host
|
||||
return fallback_origin
|
||||
|
||||
def get_optional_username(request: Request) -> str | None:
|
||||
"""Get username, if available.
|
||||
|
||||
This is purely used for auditing purposes.
|
||||
"""
|
||||
authorization = request.headers.get("Authorization")
|
||||
scheme, param = get_authorization_scheme_param(authorization)
|
||||
if not authorization or scheme.lower() != "bearer":
|
||||
return None
|
||||
claims = decode_token(dependencies.settings, param)
|
||||
if not claims:
|
||||
return None
|
||||
|
||||
if claims.provider == LOCAL_ISSUER:
|
||||
return claims.sub
|
||||
|
||||
return f"oidc:{claims.email}"
|
||||
|
||||
async def get_current_active_user(
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
) -> User:
|
||||
@ -66,9 +92,12 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||
return current_user
|
||||
|
||||
async def get_admin_backend(
|
||||
request: Request,
|
||||
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||
):
|
||||
"""Get admin backend API."""
|
||||
username = get_optional_username(request)
|
||||
origin = get_client_origin(request)
|
||||
password_db = session.scalars(
|
||||
select(PasswordDB).where(PasswordDB.id == 1)
|
||||
).first()
|
||||
@ -76,7 +105,11 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||
raise HTTPException(
|
||||
500, detail="Error: The password manager has not yet been set up."
|
||||
)
|
||||
admin = AdminBackend(dependencies.settings, password_db.encrypted_password)
|
||||
admin = AdminBackend(
|
||||
dependencies.settings,
|
||||
username=username,
|
||||
origin=origin,
|
||||
)
|
||||
yield admin
|
||||
|
||||
app = APIRouter(prefix=f"/api/{API_VERSION}")
|
||||
|
||||
@ -1,12 +1,13 @@
|
||||
"""Models for authentication."""
|
||||
"""Models for authentication and secret management."""
|
||||
|
||||
import enum
|
||||
from datetime import datetime
|
||||
from typing import override
|
||||
import uuid
|
||||
import sqlalchemy as sa
|
||||
from pydantic import BaseModel
|
||||
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||
|
||||
|
||||
JWT_ALGORITHM = "HS256"
|
||||
@ -75,12 +76,15 @@ class PasswordDB(Base):
|
||||
__tablename__: str = "password_db"
|
||||
|
||||
id: Mapped[int] = mapped_column(sa.INT, primary_key=True)
|
||||
encrypted_password: Mapped[str] = mapped_column(sa.String)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||
)
|
||||
|
||||
client_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
sa.Uuid(as_uuid=True), nullable=True
|
||||
)
|
||||
|
||||
updated_at: Mapped[datetime | None] = mapped_column(
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
@ -88,6 +92,65 @@ class PasswordDB(Base):
|
||||
)
|
||||
|
||||
|
||||
class Group(Base):
|
||||
"""A secret group."""
|
||||
|
||||
__tablename__: str = "groups"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
name: Mapped[str] = mapped_column(sa.String, nullable=False)
|
||||
description: Mapped[str | None] = mapped_column(sa.String, nullable=True)
|
||||
|
||||
parent_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
sa.ForeignKey("groups.id"), nullable=True
|
||||
)
|
||||
parent: Mapped["Group | None"] = relationship(
|
||||
"Group", remote_side=[id], back_populates="children"
|
||||
)
|
||||
children: Mapped[list["Group"]] = relationship(
|
||||
"Group", back_populates="parent", cascade="all, delete"
|
||||
)
|
||||
secrets: Mapped[list["ManagedSecret"]] = relationship(back_populates="group")
|
||||
|
||||
@override
|
||||
def __repr__(self) -> str:
|
||||
return f"<Group id={self.id} name={self.name} parent_id={self.parent_id}>"
|
||||
|
||||
|
||||
class ManagedSecret(Base):
|
||||
"""Managed Secret."""
|
||||
|
||||
__tablename__: str = "managed_secrets"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
name: Mapped[str] = mapped_column(sa.String, nullable=False)
|
||||
|
||||
is_deleted: Mapped[bool] = mapped_column(sa.Boolean, default=False)
|
||||
|
||||
group_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
sa.ForeignKey("groups.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
group: Mapped["Group | None"] = relationship(
|
||||
Group, foreign_keys=[group_id], back_populates="secrets"
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||
)
|
||||
updated_at: Mapped[datetime | None] = mapped_column(
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
onupdate=sa.func.now(),
|
||||
)
|
||||
|
||||
deleted_at: Mapped[datetime | None] = mapped_column(
|
||||
sa.DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
|
||||
class IdentityClaims(BaseModel):
|
||||
"""Normalized identity claim model."""
|
||||
|
||||
@ -125,6 +188,3 @@ class LocalUserInfo(BaseModel):
|
||||
local: bool
|
||||
|
||||
|
||||
def init_db(engine: sa.Engine) -> None:
|
||||
"""Create database."""
|
||||
Base.metadata.create_all(engine)
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
#
|
||||
from collections.abc import AsyncGenerator
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
@ -12,15 +13,15 @@ from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sshecret_backend.db import DatabaseSessionManager
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
|
||||
from sshecret_admin import api, frontend
|
||||
from sshecret_admin.auth.models import PasswordDB, init_db
|
||||
from sshecret_admin.auth.models import Base
|
||||
from sshecret_admin.core.db import setup_database
|
||||
from sshecret_admin.frontend.exceptions import RedirectException
|
||||
from sshecret_admin.services.master_password import setup_master_password
|
||||
from sshecret_admin.services.secret_manager import setup_private_key
|
||||
|
||||
from .dependencies import BaseDependencies
|
||||
from .settings import AdminServerSettings
|
||||
@ -40,44 +41,28 @@ def setup_frontend(app: FastAPI, dependencies: BaseDependencies) -> None:
|
||||
|
||||
|
||||
def create_admin_app(
|
||||
settings: AdminServerSettings, with_frontend: bool = True
|
||||
settings: AdminServerSettings,
|
||||
with_frontend: bool = True,
|
||||
create_db: bool = False,
|
||||
) -> FastAPI:
|
||||
"""Create admin app."""
|
||||
engine, get_db_session = setup_database(settings.admin_db)
|
||||
|
||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Get async session."""
|
||||
session_manager = DatabaseSessionManager(settings.async_db_url)
|
||||
async with session_manager.session() as session:
|
||||
yield session
|
||||
|
||||
def setup_password_manager() -> None:
|
||||
"""Setup password manager."""
|
||||
encr_master_password = setup_master_password(
|
||||
settings=settings, regenerate=False
|
||||
)
|
||||
with Session(engine) as session:
|
||||
existing_password = session.scalars(
|
||||
select(PasswordDB).where(PasswordDB.id == 1)
|
||||
).first()
|
||||
|
||||
if not encr_master_password:
|
||||
if existing_password:
|
||||
LOG.info("Master password already defined.")
|
||||
return
|
||||
# Looks like we have to regenerate it
|
||||
LOG.warning(
|
||||
"Master password was set, but not saved to the database. Regenerating it."
|
||||
)
|
||||
encr_master_password = setup_master_password(
|
||||
settings=settings, regenerate=True
|
||||
)
|
||||
|
||||
assert encr_master_password is not None
|
||||
|
||||
with Session(engine) as session:
|
||||
pwdb = PasswordDB(id=1, encrypted_password=encr_master_password)
|
||||
session.add(pwdb)
|
||||
session.commit()
|
||||
setup_private_key(settings, regenerate=False)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_app: FastAPI):
|
||||
"""Create database before starting the server."""
|
||||
init_db(engine)
|
||||
if create_db:
|
||||
Base.metadata.create_all(engine)
|
||||
setup_password_manager()
|
||||
yield
|
||||
|
||||
@ -109,7 +94,7 @@ def create_admin_app(
|
||||
status_code=status.HTTP_200_OK, content=jsonable_encoder({"status": "LIVE"})
|
||||
)
|
||||
|
||||
dependencies = BaseDependencies(settings, get_db_session)
|
||||
dependencies = BaseDependencies(settings, get_db_session, get_async_session)
|
||||
|
||||
app.include_router(api.create_api_router(dependencies))
|
||||
if with_frontend:
|
||||
|
||||
@ -12,7 +12,7 @@ from pydantic import ValidationError
|
||||
from sqlalchemy import select, create_engine
|
||||
from sqlalchemy.orm import Session
|
||||
from sshecret_admin.auth.authentication import hash_password
|
||||
from sshecret_admin.auth.models import AuthProvider, PasswordDB, User, init_db
|
||||
from sshecret_admin.auth.models import AuthProvider, PasswordDB, User
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
from sshecret_admin.services.admin_backend import AdminBackend
|
||||
|
||||
@ -72,7 +72,6 @@ def cli_create_user(
|
||||
"""Create user."""
|
||||
settings = cast(AdminServerSettings, ctx.obj)
|
||||
engine = create_engine(settings.admin_db)
|
||||
init_db(engine)
|
||||
with Session(engine) as session:
|
||||
create_user(session, username, email, password)
|
||||
|
||||
@ -87,7 +86,6 @@ def cli_change_user_passwd(ctx: click.Context, username: str, password: str) ->
|
||||
"""Change password on user."""
|
||||
settings = cast(AdminServerSettings, ctx.obj)
|
||||
engine = create_engine(settings.admin_db)
|
||||
init_db(engine)
|
||||
with Session(engine) as session:
|
||||
user = session.scalars(select(User).where(User.username == username)).first()
|
||||
if not user:
|
||||
@ -107,7 +105,6 @@ def cli_delete_user(ctx: click.Context, username: str) -> None:
|
||||
"""Remove a user."""
|
||||
settings = cast(AdminServerSettings, ctx.obj)
|
||||
engine = create_engine(settings.admin_db)
|
||||
init_db(engine)
|
||||
with Session(engine) as session:
|
||||
user = session.scalars(select(User).where(User.username == username)).first()
|
||||
if not user:
|
||||
@ -149,7 +146,6 @@ def cli_repl(ctx: click.Context) -> None:
|
||||
"""Run an interactive console."""
|
||||
settings = cast(AdminServerSettings, ctx.obj)
|
||||
engine = create_engine(settings.admin_db)
|
||||
init_db(engine)
|
||||
with Session(engine) as session:
|
||||
password_db = session.scalars(
|
||||
select(PasswordDB).where(PasswordDB.id == 1)
|
||||
@ -165,7 +161,7 @@ def cli_repl(ctx: click.Context) -> None:
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(func)
|
||||
|
||||
admin = AdminBackend(settings, password_db.encrypted_password)
|
||||
admin = AdminBackend(settings, )
|
||||
locals = {
|
||||
"run": run,
|
||||
"admin": admin,
|
||||
|
||||
@ -1,12 +1,13 @@
|
||||
"""Database setup."""
|
||||
|
||||
import sqlite3
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from collections.abc import AsyncIterator, Generator, Callable
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.engine import URL
|
||||
from sqlalchemy import create_engine, Engine
|
||||
from sqlalchemy import create_engine, Engine, event
|
||||
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncConnection,
|
||||
@ -18,11 +19,20 @@ from sqlalchemy.ext.asyncio import (
|
||||
|
||||
|
||||
def setup_database(
|
||||
db_url: URL | str,
|
||||
db_url: URL,
|
||||
) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]:
|
||||
"""Setup database."""
|
||||
|
||||
engine = create_engine(db_url, echo=True, future=True)
|
||||
if db_url.drivername.startswith("sqlite"):
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
def set_sqlite_pragma(
|
||||
dbapi_connection: sqlite3.Connection, _connection_record: object
|
||||
) -> None:
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
def get_db_session() -> Generator[Session, None, None]:
|
||||
"""Get DB Session."""
|
||||
@ -33,8 +43,18 @@ def setup_database(
|
||||
|
||||
|
||||
class DatabaseSessionManager:
|
||||
def __init__(self, host: URL | str, **engine_kwargs: str):
|
||||
def __init__(self, host: URL, **engine_kwargs: str):
|
||||
self._engine: AsyncEngine | None = create_async_engine(host, **engine_kwargs)
|
||||
if host.drivername.startswith("sqlite+"):
|
||||
|
||||
@event.listens_for(self._engine.sync_engine, "connect")
|
||||
def set_sqlite_pragma(
|
||||
dbapi_connection: sqlite3.Connection, _connection_record: object
|
||||
) -> None:
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
self._sessionmaker: async_sessionmaker[AsyncSession] | None = (
|
||||
async_sessionmaker(
|
||||
autocommit=False, bind=self._engine, expire_on_commit=False
|
||||
|
||||
@ -4,6 +4,8 @@ from collections.abc import AsyncGenerator, Awaitable, Callable, Generator
|
||||
from dataclasses import dataclass
|
||||
from typing import Self
|
||||
|
||||
from fastapi import Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
from sshecret_admin.auth import User
|
||||
from sshecret_admin.services import AdminBackend
|
||||
@ -11,8 +13,9 @@ from sshecret_admin.core.settings import AdminServerSettings
|
||||
|
||||
|
||||
DBSessionDep = Callable[[], Generator[Session, None, None]]
|
||||
AsyncSessionDep = Callable[[], AsyncGenerator[AsyncSession, None]]
|
||||
|
||||
AdminDep = Callable[[Session], AsyncGenerator[AdminBackend, None]]
|
||||
AdminDep = Callable[[Request, Session], AsyncGenerator[AdminBackend, None]]
|
||||
|
||||
GetUserDep = Callable[[User], Awaitable[User]]
|
||||
|
||||
@ -23,6 +26,8 @@ class BaseDependencies:
|
||||
|
||||
settings: AdminServerSettings
|
||||
get_db_session: DBSessionDep
|
||||
get_async_session: AsyncSessionDep
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -43,6 +48,7 @@ class AdminDependencies(BaseDependencies):
|
||||
return cls(
|
||||
settings=deps.settings,
|
||||
get_db_session=deps.get_db_session,
|
||||
get_async_session=deps.get_async_session,
|
||||
get_admin_backend=get_admin_backend,
|
||||
get_current_active_user=get_current_active_user,
|
||||
)
|
||||
|
||||
@ -30,7 +30,6 @@ class FrontendDependencies(BaseDependencies):
|
||||
get_refresh_claims: RefreshTokenDep
|
||||
get_login_status: LoginStatusDep
|
||||
get_user_info: UserInfoDep
|
||||
get_async_session: AsyncSessionDep
|
||||
require_login: LoginGuardDep
|
||||
|
||||
@classmethod
|
||||
@ -42,18 +41,17 @@ class FrontendDependencies(BaseDependencies):
|
||||
get_refresh_claims: RefreshTokenDep,
|
||||
get_login_status: LoginStatusDep,
|
||||
get_user_info: UserInfoDep,
|
||||
get_async_session: AsyncSessionDep,
|
||||
require_login: LoginGuardDep,
|
||||
) -> Self:
|
||||
"""Create from base dependencies."""
|
||||
return cls(
|
||||
settings=deps.settings,
|
||||
get_db_session=deps.get_db_session,
|
||||
get_async_session=deps.get_async_session,
|
||||
get_admin_backend=get_admin_backend,
|
||||
templates=templates,
|
||||
get_refresh_claims=get_refresh_claims,
|
||||
get_login_status=get_login_status,
|
||||
get_user_info=get_user_info,
|
||||
get_async_session=get_async_session,
|
||||
require_login=require_login,
|
||||
)
|
||||
|
||||
@ -24,7 +24,6 @@ from sshecret_admin.auth.constants import LOCAL_ISSUER
|
||||
|
||||
from sshecret_admin.core.dependencies import BaseDependencies
|
||||
from sshecret_admin.services.admin_backend import AdminBackend
|
||||
from sshecret_admin.core.db import DatabaseSessionManager
|
||||
|
||||
from .dependencies import FrontendDependencies
|
||||
from .exceptions import RedirectException
|
||||
@ -50,17 +49,24 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||
templates = Jinja2Blocks(directory=template_path)
|
||||
|
||||
async def get_admin_backend(
|
||||
request: Request,
|
||||
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||
):
|
||||
"""Get admin backend API."""
|
||||
password_db = session.scalars(
|
||||
select(PasswordDB).where(PasswordDB.id == 1)
|
||||
).first()
|
||||
username = get_optional_username(request)
|
||||
origin = get_client_origin(request)
|
||||
if not password_db:
|
||||
raise HTTPException(
|
||||
500, detail="Error: The password manager has not yet been set up."
|
||||
)
|
||||
admin = AdminBackend(dependencies.settings, password_db.encrypted_password)
|
||||
admin = AdminBackend(
|
||||
dependencies.settings,
|
||||
username=username,
|
||||
origin=origin,
|
||||
)
|
||||
yield admin
|
||||
|
||||
def get_identity_claims(request: Request) -> IdentityClaims:
|
||||
@ -108,14 +114,9 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||
next = URL("/refresh").include_query_params(next=request.url.path)
|
||||
raise RedirectException(to=next)
|
||||
|
||||
async def get_async_session():
|
||||
"""Get async session."""
|
||||
sessionmanager = DatabaseSessionManager(dependencies.settings.async_db_url)
|
||||
async with sessionmanager.session() as session:
|
||||
yield session
|
||||
|
||||
async def get_user_info(
|
||||
request: Request, session: Annotated[AsyncSession, Depends(get_async_session)]
|
||||
request: Request,
|
||||
session: Annotated[AsyncSession, Depends(dependencies.get_async_session)],
|
||||
) -> LocalUserInfo:
|
||||
"""Get User information."""
|
||||
claims = get_identity_claims(request)
|
||||
@ -142,6 +143,30 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||
next = URL("/refresh").include_query_params(next=request.url.path)
|
||||
raise RedirectException(to=next)
|
||||
|
||||
def get_optional_username(
|
||||
request: Request,
|
||||
) -> str | None:
|
||||
"""Get username, if available.
|
||||
|
||||
This is purely used for auditing purposes.
|
||||
"""
|
||||
try:
|
||||
claims = get_identity_claims(request)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if claims.provider == LOCAL_ISSUER:
|
||||
return claims.sub
|
||||
|
||||
return f"oidc:{claims.email}"
|
||||
|
||||
def get_client_origin(request: Request) -> str:
|
||||
"""Get client origin."""
|
||||
fallback_origin = "UNKNOWN"
|
||||
if request.client:
|
||||
return request.client.host
|
||||
return fallback_origin
|
||||
|
||||
view_dependencies = FrontendDependencies.create(
|
||||
dependencies,
|
||||
get_admin_backend,
|
||||
@ -149,7 +174,6 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||
refresh_identity_claims,
|
||||
get_login_status,
|
||||
get_user_info,
|
||||
get_async_session,
|
||||
require_login,
|
||||
)
|
||||
|
||||
|
||||
@ -16,7 +16,7 @@
|
||||
<!-- Detail Pane -->
|
||||
|
||||
<section id="detail-pane"
|
||||
class="flex-1 flex overflow-y-auto bg-white p-4 {%- if not mobile_show_details|default(false) -%} hidden{%- endif -%} lg:block dark:bg-gray-800">
|
||||
class="flex-1 flex overflow-y-auto bg-white p-4 lg:block {% if not mobile_show_details|default(false) -%} hidden{%- endif -%} lg:block dark:bg-gray-800">
|
||||
|
||||
|
||||
{% block detail %}
|
||||
|
||||
@ -5,7 +5,7 @@ import ipaddress
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request, Response
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException, Request, Response
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import BaseModel, IPvAnyAddress, IPvAnyNetwork
|
||||
from sshecret_admin.frontend.views.common import PagingInfo
|
||||
@ -209,7 +209,7 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
page: int,
|
||||
) -> Response:
|
||||
"""Get more events for a client."""
|
||||
if not "HX-Request" in request.headers:
|
||||
if "HX-Request" not in request.headers:
|
||||
return RedirectResponse(url=f"/clients/client/{id}")
|
||||
|
||||
client = await admin.get_client(("id", id))
|
||||
|
||||
@ -4,8 +4,8 @@ Since we have a frontend and a REST API, it makes sense to have a generic librar
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from sshecret.backend import (
|
||||
AuditLog,
|
||||
@ -20,7 +20,7 @@ from sshecret.backend.models import ClientQueryResult, DetailedSecrets
|
||||
from sshecret.backend.api import AuditAPI, KeySpec
|
||||
from sshecret.crypto import encrypt_string, load_public_key
|
||||
|
||||
from .keepass import PasswordContext, load_password_manager
|
||||
from .secret_manager import AsyncSecretContext, password_manager_context
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
from .models import (
|
||||
ClientSecretGroup,
|
||||
@ -86,19 +86,27 @@ def add_clients_to_secret_group(
|
||||
class AdminBackend:
|
||||
"""Admin backend API."""
|
||||
|
||||
def __init__(self, settings: AdminServerSettings, keepass_password: str) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
settings: AdminServerSettings,
|
||||
username: str | None = None,
|
||||
origin: str = "UNKNOWN",
|
||||
) -> None:
|
||||
"""Create client management API."""
|
||||
self.settings: AdminServerSettings = settings
|
||||
self.backend: SshecretBackend = SshecretBackend(
|
||||
str(settings.backend_url), settings.backend_token
|
||||
)
|
||||
self.keepass_password: str = keepass_password
|
||||
self.username: str = username or "UKNOWN_USER"
|
||||
self.origin: str = origin
|
||||
|
||||
@contextmanager
|
||||
def password_manager(self) -> Iterator[PasswordContext]:
|
||||
"""Open the password manager."""
|
||||
with load_password_manager(self.settings, self.keepass_password) as kp:
|
||||
yield kp
|
||||
@asynccontextmanager
|
||||
async def secrets_manager(self) -> AsyncIterator[AsyncSecretContext]:
|
||||
"""Open the secrets manager."""
|
||||
async with password_manager_context(
|
||||
self.settings, self.username, self.origin
|
||||
) as manager:
|
||||
yield manager
|
||||
|
||||
async def _get_clients(self, filter: ClientFilter | None = None) -> list[Client]:
|
||||
"""Get clients from backend."""
|
||||
@ -194,7 +202,7 @@ class AdminBackend:
|
||||
self,
|
||||
name: KeySpec,
|
||||
new_key: str,
|
||||
password_manager: PasswordContext,
|
||||
password_manager: AsyncSecretContext,
|
||||
) -> list[str]:
|
||||
"""Update client public key."""
|
||||
LOG.info(
|
||||
@ -207,7 +215,7 @@ class AdminBackend:
|
||||
updated_secrets: list[str] = []
|
||||
for secret in client.secrets:
|
||||
LOG.debug("Re-encrypting secret %s for client %s", secret, name)
|
||||
secret_value = password_manager.get_secret(secret)
|
||||
secret_value = await password_manager.get_secret(secret)
|
||||
if not secret_value:
|
||||
LOG.warning(
|
||||
"Referenced secret %s does not exist! Skipping.", secret_value
|
||||
@ -224,7 +232,7 @@ class AdminBackend:
|
||||
async def update_client_public_key(self, name: KeySpec, new_key: str) -> list[str]:
|
||||
"""Update client public key."""
|
||||
try:
|
||||
with self.password_manager() as password_manager:
|
||||
async with self.secrets_manager() as password_manager:
|
||||
return await self._update_client_public_key(
|
||||
name, new_key, password_manager
|
||||
)
|
||||
@ -291,8 +299,8 @@ class AdminBackend:
|
||||
This fetches the secret to client mapping from backend, and adds secrets from the password manager.
|
||||
"""
|
||||
backend_secrets = await self.backend.get_secrets()
|
||||
with self.password_manager() as password_manager:
|
||||
admin_secrets = password_manager.get_available_secrets()
|
||||
async with self.secrets_manager() as password_manager:
|
||||
admin_secrets = await password_manager.get_available_secrets()
|
||||
|
||||
secrets: dict[str, SecretListView] = {}
|
||||
for secret in backend_secrets:
|
||||
@ -324,8 +332,8 @@ class AdminBackend:
|
||||
|
||||
This fetches the secret to client mapping from backend, and adds secrets from the password manager.
|
||||
"""
|
||||
with self.password_manager() as password_manager:
|
||||
all_secrets = password_manager.get_available_secrets()
|
||||
async with self.secrets_manager() as password_manager:
|
||||
all_secrets = await password_manager.get_available_secrets()
|
||||
|
||||
secrets = await self.backend.get_detailed_secrets()
|
||||
backend_secret_names = [secret.name for secret in secrets]
|
||||
@ -351,13 +359,13 @@ class AdminBackend:
|
||||
parent_group: str | None = None,
|
||||
) -> None:
|
||||
"""Add secret group."""
|
||||
with self.password_manager() as password_manager:
|
||||
password_manager.add_group(group_name, description, parent_group)
|
||||
async with self.secrets_manager() as password_manager:
|
||||
await password_manager.add_group(group_name, description, parent_group)
|
||||
|
||||
async def set_secret_group(self, secret_name: str, group_name: str | None) -> None:
|
||||
"""Assign a group to a secret."""
|
||||
with self.password_manager() as password_manager:
|
||||
password_manager.set_secret_group(secret_name, group_name)
|
||||
async with self.secrets_manager() as password_manager:
|
||||
await password_manager.set_secret_group(secret_name, group_name)
|
||||
|
||||
async def move_secret_group(
|
||||
self, group_name: str, parent_group: str | None
|
||||
@ -366,23 +374,21 @@ class AdminBackend:
|
||||
|
||||
If parent_group is None, it will be moved to the root.
|
||||
"""
|
||||
with self.password_manager() as password_manager:
|
||||
password_manager.move_group(group_name, parent_group)
|
||||
async with self.secrets_manager() as password_manager:
|
||||
await password_manager.move_group(group_name, parent_group)
|
||||
|
||||
async def set_group_description(self, group_name: str, description: str) -> None:
|
||||
"""Set a group description."""
|
||||
with self.password_manager() as password_manager:
|
||||
password_manager.set_group_description(group_name, description)
|
||||
async with self.secrets_manager() as password_manager:
|
||||
await password_manager.set_group_description(group_name, description)
|
||||
|
||||
async def delete_secret_group(
|
||||
self, group_name: str, keep_entries: bool = True
|
||||
) -> None:
|
||||
async def delete_secret_group(self, group_name: str) -> None:
|
||||
"""Delete a group.
|
||||
|
||||
If keep_entries is set to False, all entries in the group will be deleted.
|
||||
"""
|
||||
with self.password_manager() as password_manager:
|
||||
password_manager.delete_group(group_name, keep_entries)
|
||||
async with self.secrets_manager() as password_manager:
|
||||
await password_manager.delete_group(group_name)
|
||||
|
||||
async def get_secret_groups(
|
||||
self,
|
||||
@ -399,18 +405,18 @@ class AdminBackend:
|
||||
"""
|
||||
all_secrets = await self.backend.get_detailed_secrets()
|
||||
secrets_mapping = {secret.name: secret for secret in all_secrets}
|
||||
with self.password_manager() as password_manager:
|
||||
async with self.secrets_manager() as password_manager:
|
||||
if flat:
|
||||
all_groups = password_manager.get_secret_group_list(
|
||||
all_groups = await password_manager.get_secret_group_list(
|
||||
group_filter, regex=regex
|
||||
)
|
||||
else:
|
||||
all_groups = password_manager.get_secret_groups(
|
||||
all_groups = await password_manager.get_secret_groups(
|
||||
group_filter, regex=regex
|
||||
)
|
||||
ungrouped = password_manager.get_ungrouped_secrets()
|
||||
ungrouped = await password_manager.get_ungrouped_secrets()
|
||||
|
||||
all_admin_secrets = password_manager.get_available_secrets()
|
||||
all_admin_secrets = await password_manager.get_available_secrets()
|
||||
|
||||
group_result: list[ClientSecretGroup] = []
|
||||
for group in all_groups:
|
||||
@ -452,8 +458,8 @@ class AdminBackend:
|
||||
|
||||
async def get_secret_group_by_path(self, path: str) -> ClientSecretGroup | None:
|
||||
"""Get a group based on its path."""
|
||||
with self.password_manager() as password_manager:
|
||||
secret_group = password_manager.get_secret_group(path)
|
||||
async with self.secrets_manager() as password_manager:
|
||||
secret_group = await password_manager.get_secret_group(path)
|
||||
|
||||
if not secret_group:
|
||||
return None
|
||||
@ -476,9 +482,11 @@ class AdminBackend:
|
||||
) -> SecretView | None:
|
||||
"""Get a secret, including the actual unencrypted value and clients."""
|
||||
secret: str | None = None
|
||||
with self.password_manager() as password_manager:
|
||||
secret = password_manager.get_secret(name)
|
||||
secret_group = password_manager.get_entry_group(name)
|
||||
async with self.secrets_manager() as password_manager:
|
||||
secret = await password_manager.get_secret(name)
|
||||
secret_group: str | None = None
|
||||
if secret:
|
||||
secret_group = await password_manager.get_entry_group(name)
|
||||
|
||||
secret_view = SecretView(name=name, secret=secret, group=secret_group)
|
||||
|
||||
@ -503,8 +511,8 @@ class AdminBackend:
|
||||
|
||||
async def _delete_secret(self, name: str) -> None:
|
||||
"""Delete a secret."""
|
||||
with self.password_manager() as password_manager:
|
||||
password_manager.delete_entry(name)
|
||||
async with self.secrets_manager() as password_manager:
|
||||
await password_manager.delete_entry(name)
|
||||
|
||||
secret_mapping = await self.backend.get_secret(name)
|
||||
if not secret_mapping:
|
||||
@ -522,8 +530,8 @@ class AdminBackend:
|
||||
group: str | None = None,
|
||||
) -> None:
|
||||
"""Add a secret."""
|
||||
with self.password_manager() as password_manager:
|
||||
password_manager.add_entry(name, value, update, group_name=group)
|
||||
async with self.secrets_manager() as password_manager:
|
||||
await password_manager.add_entry(name, value, update, group_path=group)
|
||||
|
||||
if update:
|
||||
secret_map = await self.backend.get_secret(name)
|
||||
@ -576,8 +584,8 @@ class AdminBackend:
|
||||
if not client:
|
||||
raise ClientNotFoundError(client_idname)
|
||||
|
||||
with self.password_manager() as password_manager:
|
||||
secret = password_manager.get_secret(secret_name)
|
||||
async with self.secrets_manager() as password_manager:
|
||||
secret = await password_manager.get_secret(secret_name)
|
||||
if not secret:
|
||||
raise SecretNotFoundError()
|
||||
|
||||
|
||||
@ -1,348 +0,0 @@
|
||||
"""Keepass password manager."""
|
||||
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import pykeepass
|
||||
import pykeepass.exceptions
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
|
||||
from .models import SecretGroup
|
||||
from .master_password import decrypt_master_password
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
NO_USERNAME = "NO_USERNAME"
|
||||
|
||||
DEFAULT_LOCATION = "keepass.kdbx"
|
||||
|
||||
|
||||
class PasswordCredentialsError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def create_password_db(location: Path, password: str) -> None:
|
||||
"""Create the password database."""
|
||||
LOG.info("Creating password database at %s", location)
|
||||
pykeepass.create_database(str(location.absolute()), password=password)
|
||||
|
||||
|
||||
def _kp_group_to_secret_group(
|
||||
kp_group: pykeepass.group.Group,
|
||||
parent: SecretGroup | None = None,
|
||||
depth: int | None = None,
|
||||
) -> SecretGroup:
|
||||
"""Convert keepass group to secret group dataclass."""
|
||||
group_name = cast(str, kp_group.name)
|
||||
path = "/".join(cast(list[str], kp_group.path))
|
||||
group = SecretGroup(name=group_name, path=path, description=kp_group.notes)
|
||||
for entry in kp_group.entries:
|
||||
group.entries.append(str(entry.title))
|
||||
if parent:
|
||||
group.parent_group = parent
|
||||
|
||||
current_depth = len(kp_group.path)
|
||||
|
||||
if not parent and current_depth > 1:
|
||||
parent = _kp_group_to_secret_group(kp_group.parentgroup, depth=current_depth)
|
||||
parent.children.append(group)
|
||||
group.parent_group = parent
|
||||
|
||||
if depth and depth == current_depth:
|
||||
return group
|
||||
|
||||
for subgroup in kp_group.subgroups:
|
||||
group.children.append(_kp_group_to_secret_group(subgroup, group, depth=depth))
|
||||
|
||||
return group
|
||||
|
||||
|
||||
class PasswordContext:
|
||||
"""Password Context class."""
|
||||
|
||||
def __init__(self, keepass: pykeepass.PyKeePass) -> None:
|
||||
"""Initialize password context."""
|
||||
self.keepass: pykeepass.PyKeePass = keepass
|
||||
|
||||
@property
|
||||
def _root_group(self) -> pykeepass.group.Group:
|
||||
"""Return the root group."""
|
||||
return cast(pykeepass.group.Group, self.keepass.root_group)
|
||||
|
||||
def _get_entry(self, name: str) -> pykeepass.entry.Entry | None:
|
||||
"""Get entry."""
|
||||
entry = cast(
|
||||
"pykeepass.entry.Entry | None",
|
||||
self.keepass.find_entries(title=name, first=True),
|
||||
)
|
||||
return entry
|
||||
|
||||
def _get_group(self, name: str) -> pykeepass.group.Group | None:
|
||||
"""Find a group."""
|
||||
group = cast(
|
||||
pykeepass.group.Group | None,
|
||||
self.keepass.find_groups(name=name, first=True),
|
||||
)
|
||||
return group
|
||||
|
||||
def add_entry(
|
||||
self,
|
||||
entry_name: str,
|
||||
secret: str,
|
||||
overwrite: bool = False,
|
||||
group_name: str | None = None,
|
||||
) -> None:
|
||||
"""Add an entry.
|
||||
|
||||
Specify overwrite=True to overwrite the existing secret value, if it exists.
|
||||
This will not move the entry, if the group_name is different from the original group.
|
||||
|
||||
"""
|
||||
entry = self._get_entry(entry_name)
|
||||
if entry and overwrite:
|
||||
entry.password = secret
|
||||
self.keepass.save()
|
||||
return
|
||||
|
||||
if entry:
|
||||
raise ValueError("Error: A secret with this name already exists.")
|
||||
LOG.debug("Add secret entry to keepass: %s, group: %r", entry_name, group_name)
|
||||
if group_name:
|
||||
destination_group = self._get_group(group_name)
|
||||
else:
|
||||
destination_group = self._root_group
|
||||
|
||||
entry = self.keepass.add_entry(
|
||||
destination_group=destination_group,
|
||||
title=entry_name,
|
||||
username=NO_USERNAME,
|
||||
password=secret,
|
||||
)
|
||||
self.keepass.save()
|
||||
|
||||
def get_secret(self, entry_name: str) -> str | None:
|
||||
"""Get the secret value."""
|
||||
entry = self._get_entry(entry_name)
|
||||
if not entry:
|
||||
return None
|
||||
|
||||
LOG.warning("Secret name %s accessed", entry_name)
|
||||
if password := cast(str, entry.password):
|
||||
return str(password)
|
||||
|
||||
raise RuntimeError(f"Cannot get password for entry {entry_name}")
|
||||
|
||||
def get_entry_group(self, entry_name: str) -> str | None:
|
||||
"""Get the group for an entry."""
|
||||
entry = self._get_entry(entry_name)
|
||||
if not entry:
|
||||
return None
|
||||
if entry.group.is_root_group:
|
||||
return None
|
||||
return str(entry.group.name)
|
||||
|
||||
def get_secret_groups(
|
||||
self, pattern: str | None = None, regex: bool = True
|
||||
) -> list[SecretGroup]:
|
||||
"""Get secret groups.
|
||||
|
||||
A regex pattern may be provided to filter groups.
|
||||
"""
|
||||
if pattern:
|
||||
groups = cast(
|
||||
list[pykeepass.group.Group],
|
||||
self.keepass.find_groups(name=pattern, regex=regex),
|
||||
)
|
||||
else:
|
||||
groups = self._root_group.subgroups
|
||||
|
||||
secret_groups = [_kp_group_to_secret_group(group) for group in groups]
|
||||
return secret_groups
|
||||
|
||||
def get_secret_group_list(
|
||||
self, pattern: str | None = None, regex: bool = True
|
||||
) -> list[SecretGroup]:
|
||||
"""Get a flat list of groups."""
|
||||
if pattern:
|
||||
return self.get_secret_groups(pattern, regex)
|
||||
|
||||
groups = [group for group in self.keepass.groups if not group.is_root_group]
|
||||
secret_groups = [_kp_group_to_secret_group(group) for group in groups]
|
||||
return secret_groups
|
||||
|
||||
def get_secret_group(self, path: str) -> SecretGroup | None:
|
||||
"""Get a secret group by path."""
|
||||
elements = path.split("/")
|
||||
final_element = elements[-1]
|
||||
|
||||
current = self._root_group
|
||||
while elements:
|
||||
groupname = elements.pop(0)
|
||||
matches = [
|
||||
subgroup for subgroup in current.subgroups if subgroup.name == groupname
|
||||
]
|
||||
if matches:
|
||||
current = matches[0]
|
||||
else:
|
||||
return None
|
||||
if not current.is_root_group and current.name == final_element:
|
||||
return _kp_group_to_secret_group(current)
|
||||
return None
|
||||
|
||||
def get_ungrouped_secrets(self) -> list[str]:
|
||||
"""Get secrets without groups."""
|
||||
entries: list[str] = []
|
||||
for entry in self._root_group.entries:
|
||||
entries.append(str(entry.title))
|
||||
|
||||
return entries
|
||||
|
||||
def add_group(
|
||||
self, name: str, description: str | None = None, parent_group: str | None = None
|
||||
) -> None:
|
||||
"""Add a group."""
|
||||
kp_parent_group = self._root_group
|
||||
if parent_group:
|
||||
query = cast(
|
||||
pykeepass.group.Group | None,
|
||||
self.keepass.find_groups(name=parent_group, first=True),
|
||||
)
|
||||
if not query:
|
||||
raise ValueError(
|
||||
f"Error: Cannot find a parent group named {parent_group}"
|
||||
)
|
||||
kp_parent_group = query
|
||||
self.keepass.add_group(
|
||||
destination_group=kp_parent_group, group_name=name, notes=description
|
||||
)
|
||||
self.keepass.save()
|
||||
|
||||
def set_group_description(self, name: str, description: str) -> None:
|
||||
"""Set the description of a group."""
|
||||
group = self._get_group(name)
|
||||
if not group:
|
||||
raise ValueError(f"Error: No such group {name}")
|
||||
|
||||
group.notes = description
|
||||
self.keepass.save()
|
||||
|
||||
def set_secret_group(self, entry_name: str, group_name: str | None) -> None:
|
||||
"""Move a secret to a group.
|
||||
|
||||
If group is None, the secret will be placed in the root group.
|
||||
"""
|
||||
entry = self._get_entry(entry_name)
|
||||
if not entry:
|
||||
raise ValueError(
|
||||
f"Cannot find secret entry named {entry_name} in secrets database"
|
||||
)
|
||||
if group_name:
|
||||
group = self._get_group(group_name)
|
||||
if not group:
|
||||
raise ValueError(f"Cannot find a group named {group_name}")
|
||||
else:
|
||||
group = self._root_group
|
||||
|
||||
self.keepass.move_entry(entry, group)
|
||||
self.keepass.save()
|
||||
|
||||
def move_group(self, name: str, parent_group: str | None) -> None:
|
||||
"""Move a group.
|
||||
|
||||
If parent_group is None, it will be moved to the root.
|
||||
"""
|
||||
group = self._get_group(name)
|
||||
if not group:
|
||||
raise ValueError(f"Error: No such group {name}")
|
||||
if parent_group:
|
||||
parent = self._get_group(parent_group)
|
||||
if not parent:
|
||||
raise ValueError(f"Error: No such group {parent_group}")
|
||||
else:
|
||||
parent = self._root_group
|
||||
|
||||
self.keepass.move_group(group, parent)
|
||||
self.keepass.save()
|
||||
|
||||
def get_available_secrets(self, group_name: str | None = None) -> list[str]:
|
||||
"""Get the names of all secrets in the database."""
|
||||
if group_name:
|
||||
group = self._get_group(group_name)
|
||||
if not group:
|
||||
raise ValueError(f"Error: No such group {group_name}")
|
||||
entries = group.entries
|
||||
else:
|
||||
entries = cast(list[pykeepass.entry.Entry], self.keepass.entries)
|
||||
if not entries:
|
||||
return []
|
||||
return [str(entry.title) for entry in entries]
|
||||
|
||||
def delete_entry(self, entry_name: str) -> None:
|
||||
"""Delete entry."""
|
||||
entry = cast(
|
||||
"pykeepass.entry.Entry | None",
|
||||
self.keepass.find_entries(title=entry_name, first=True),
|
||||
)
|
||||
if not entry:
|
||||
return
|
||||
entry.delete()
|
||||
self.keepass.save()
|
||||
|
||||
def delete_group(self, name: str, keep_entries: bool = True) -> None:
|
||||
"""Delete a group.
|
||||
|
||||
If keep_entries is set to False, all entries in the group will be deleted.
|
||||
"""
|
||||
group = self._get_group(name)
|
||||
if not group:
|
||||
return
|
||||
if keep_entries:
|
||||
for entry in cast(
|
||||
list[pykeepass.entry.Entry],
|
||||
self.keepass.find_entries(recursive=True, group=group),
|
||||
):
|
||||
# Move the entry to the root group.
|
||||
LOG.warning(
|
||||
"Moving orphaned secret entry %s to root group", entry.title
|
||||
)
|
||||
self.keepass.move_entry(entry, self._root_group)
|
||||
|
||||
self.keepass.delete_group(group)
|
||||
self.keepass.save()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _password_context(location: Path, password: str) -> Iterator[PasswordContext]:
|
||||
"""Open the password context."""
|
||||
try:
|
||||
database = pykeepass.PyKeePass(str(location.absolute()), password=password)
|
||||
except pykeepass.exceptions.CredentialsError as e:
|
||||
raise PasswordCredentialsError(
|
||||
"Could not open password database. Invalid credentials."
|
||||
) from e
|
||||
context = PasswordContext(database)
|
||||
yield context
|
||||
|
||||
|
||||
@contextmanager
|
||||
def load_password_manager(
|
||||
settings: AdminServerSettings,
|
||||
encrypted_password: str,
|
||||
location: str = DEFAULT_LOCATION,
|
||||
) -> Iterator[PasswordContext]:
|
||||
"""Load password manager.
|
||||
|
||||
This function decrypts the password, and creates the password database if it
|
||||
has not yet been created.
|
||||
"""
|
||||
db_location = Path(location)
|
||||
password = decrypt_master_password(settings=settings, encrypted=encrypted_password)
|
||||
if not db_location.exists():
|
||||
create_password_db(db_location, password)
|
||||
|
||||
with _password_context(db_location, password) as context:
|
||||
yield context
|
||||
@ -1,86 +0,0 @@
|
||||
"""Functions related to handling the password database master password."""
|
||||
|
||||
import secrets
|
||||
from pathlib import Path
|
||||
from sshecret.crypto import (
|
||||
create_private_rsa_key,
|
||||
load_private_key,
|
||||
encrypt_string,
|
||||
decode_string,
|
||||
)
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
|
||||
KEY_FILENAME = "sshecret-admin-key"
|
||||
|
||||
|
||||
def setup_master_password(
|
||||
settings: AdminServerSettings,
|
||||
filename: str = KEY_FILENAME,
|
||||
regenerate: bool = False,
|
||||
) -> str | None:
|
||||
"""Setup master password.
|
||||
|
||||
If regenerate is True, a new key will be generated.
|
||||
|
||||
This method should run just after setting up the database.
|
||||
"""
|
||||
keyfile = Path(filename)
|
||||
if settings.password_manager_directory:
|
||||
keyfile = settings.password_manager_directory / filename
|
||||
created = _initial_key_setup(settings, keyfile, regenerate)
|
||||
if not created:
|
||||
return None
|
||||
|
||||
return _generate_master_password(settings, keyfile)
|
||||
|
||||
|
||||
def decrypt_master_password(
|
||||
settings: AdminServerSettings, encrypted: str, filename: str = KEY_FILENAME
|
||||
) -> str:
|
||||
"""Retrieve master password."""
|
||||
keyfile = Path(filename)
|
||||
if settings.password_manager_directory:
|
||||
keyfile = settings.password_manager_directory / filename
|
||||
if not keyfile.exists():
|
||||
raise RuntimeError("Error: Private key has not been generated yet.")
|
||||
|
||||
private_key = load_private_key(
|
||||
str(keyfile.absolute()), password=settings.secret_key
|
||||
)
|
||||
return decode_string(encrypted, private_key)
|
||||
|
||||
|
||||
def _generate_password() -> str:
|
||||
"""Generate a password."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
def _initial_key_setup(
|
||||
settings: AdminServerSettings,
|
||||
keyfile: Path,
|
||||
regenerate: bool = False,
|
||||
) -> bool:
|
||||
"""Set up initial keys."""
|
||||
if keyfile.exists() and not regenerate:
|
||||
return False
|
||||
|
||||
assert settings.secret_key is not None, (
|
||||
"Error: Could not load a secret key from environment."
|
||||
)
|
||||
create_private_rsa_key(keyfile, password=settings.secret_key)
|
||||
return True
|
||||
|
||||
|
||||
def _generate_master_password(settings: AdminServerSettings, keyfile: Path) -> str:
|
||||
"""Generate master password for password database.
|
||||
|
||||
Returns the encrypted string, base64 encoded.
|
||||
"""
|
||||
if not keyfile.exists():
|
||||
raise RuntimeError("Error: Private key has not been generated yet.")
|
||||
private_key = load_private_key(
|
||||
str(keyfile.absolute()), password=settings.secret_key
|
||||
)
|
||||
public_key = private_key.public_key()
|
||||
master_password = _generate_password()
|
||||
return encrypt_string(master_password, public_key)
|
||||
@ -0,0 +1,776 @@
|
||||
"""Rewritten secret manager using a rsa keys."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload, aliased
|
||||
from sshecret.backend import SshecretBackend
|
||||
from sshecret.backend.api import AuditAPI, KeySpec
|
||||
from sshecret.backend.models import Client, ClientSecret, Operation, SubSystem
|
||||
from sshecret.crypto import (
|
||||
create_private_rsa_key,
|
||||
decode_string,
|
||||
encrypt_string,
|
||||
generate_public_key_string,
|
||||
load_private_key,
|
||||
load_public_key,
|
||||
)
|
||||
from sshecret_admin.auth import PasswordDB
|
||||
from sshecret_admin.auth.models import Group, ManagedSecret
|
||||
from sshecret_admin.core.db import DatabaseSessionManager
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
from sshecret_admin.services.models import SecretGroup
|
||||
|
||||
|
||||
KEY_FILENAME = "sshecret-admin-key"
|
||||
PASSWORD_MANAGER_ID = "SshecretAdminPasswordManager"
|
||||
|
||||
LOG = logging.getLogger(PASSWORD_MANAGER_ID)
|
||||
|
||||
|
||||
class SecretManagerError(Exception):
|
||||
"""Secret manager error."""
|
||||
|
||||
|
||||
class InvalidGroupNameError(SecretManagerError):
|
||||
"""Invalid group name."""
|
||||
|
||||
|
||||
class InvalidSecretNameError(SecretManagerError):
|
||||
"""Invalid secret name."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClientAuditData:
|
||||
"""Client audit data."""
|
||||
|
||||
username: str
|
||||
origin: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParsedPath:
|
||||
"""Parsed path."""
|
||||
|
||||
item: str
|
||||
full_path: str
|
||||
parent: str | None = None
|
||||
|
||||
|
||||
class SecretDataEntryExport(BaseModel):
|
||||
"""Exportable secret entry."""
|
||||
|
||||
name: str
|
||||
secret: str
|
||||
group: str | None = None
|
||||
|
||||
|
||||
class SecretDataGroupExport(BaseModel):
|
||||
"""Exportable secret grouping."""
|
||||
|
||||
name: str
|
||||
path: str
|
||||
description: str | None = None
|
||||
|
||||
|
||||
class SecretDataExport(BaseModel):
|
||||
"""Exportable object containing secrets and groups."""
|
||||
|
||||
entries: list[SecretDataEntryExport]
|
||||
groups: list[SecretDataGroupExport]
|
||||
|
||||
|
||||
def split_path(path: str) -> list[str]:
|
||||
"""Split a path into a list of groups."""
|
||||
elements = path.split("/")
|
||||
if path.startswith("/"):
|
||||
elements = elements[1:]
|
||||
|
||||
return elements
|
||||
|
||||
|
||||
def parse_path(path: str) -> ParsedPath:
|
||||
"""Parse path."""
|
||||
elements = split_path(path)
|
||||
parsed = ParsedPath(elements[-1], path)
|
||||
if len(elements) > 1:
|
||||
parsed.parent = elements[-2]
|
||||
return parsed
|
||||
|
||||
|
||||
class AsyncSecretContext:
|
||||
"""Async secret context."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
private_key: rsa.RSAPrivateKey,
|
||||
manager_client: Client,
|
||||
session: AsyncSession,
|
||||
backend: SshecretBackend,
|
||||
audit_data: ClientAuditData,
|
||||
) -> None:
|
||||
"""Initialize secret manager"""
|
||||
self._private_key: rsa.RSAPrivateKey = private_key
|
||||
self._manager_client: Client = manager_client
|
||||
self._id: KeySpec = ("id", str(manager_client.id))
|
||||
self.backend: SshecretBackend = backend
|
||||
self.session: AsyncSession = session
|
||||
|
||||
self.audit_data: ClientAuditData = audit_data
|
||||
self.audit: AuditAPI = backend.audit(SubSystem.ADMIN)
|
||||
self._import_has_run: bool = False
|
||||
|
||||
async def _create_missing_entries(self) -> None:
|
||||
"""Create any missing entries."""
|
||||
new_secrets: bool = False
|
||||
to_check = set(self._manager_client.secrets)
|
||||
for secret_name in to_check:
|
||||
# entry = await self._get_entry(secret_name, include_deleted=True)
|
||||
statement = select(ManagedSecret).where(ManagedSecret.name == secret_name)
|
||||
result = await self.session.scalars(statement)
|
||||
if not result.first():
|
||||
new_secrets = True
|
||||
managed_secret = ManagedSecret(name=secret_name)
|
||||
self.session.add(managed_secret)
|
||||
|
||||
await self.session.flush()
|
||||
await self.write_audit(
|
||||
Operation.CREATE,
|
||||
message="Imported managed secret from backend.",
|
||||
secret_name=secret_name,
|
||||
managed_secret=managed_secret,
|
||||
)
|
||||
if new_secrets:
|
||||
await self.session.commit()
|
||||
|
||||
async def _get_group_depth(self, group: Group) -> int:
|
||||
"""Get the depth of a group."""
|
||||
depth = 1
|
||||
if not group.parent_id:
|
||||
return depth
|
||||
|
||||
current = group
|
||||
while current.parent is not None:
|
||||
if current.parent:
|
||||
depth += 1
|
||||
current = await self._get_group_by_id(current.parent.id)
|
||||
else:
|
||||
break
|
||||
|
||||
return depth
|
||||
|
||||
async def _get_group_path(self, group: Group) -> str:
|
||||
"""Get the path of a group."""
|
||||
|
||||
if not group.parent_id:
|
||||
return group.name
|
||||
path: list[str] = []
|
||||
current = group
|
||||
while current.parent_id is not None:
|
||||
path.append(current.name)
|
||||
current = await self._get_group_by_id(current.parent_id)
|
||||
|
||||
path.append("")
|
||||
path.reverse()
|
||||
return "/".join(path)
|
||||
|
||||
async def _get_group_secrets(self, group: Group) -> list[ManagedSecret]:
|
||||
"""Get secrets in a group."""
|
||||
statement = (
|
||||
select(ManagedSecret)
|
||||
.where(ManagedSecret.group_id == group.id)
|
||||
.where(ManagedSecret.is_deleted.is_not(True))
|
||||
)
|
||||
results = await self.session.scalars(statement)
|
||||
return list(results.all())
|
||||
|
||||
async def _build_group_tree(
|
||||
self, group: Group, parent: SecretGroup | None = None, depth: int | None = None
|
||||
) -> SecretGroup:
|
||||
"""Build a group tree."""
|
||||
path = "/"
|
||||
if parent:
|
||||
path = os.path.join(parent.path, path)
|
||||
secret_group = SecretGroup(
|
||||
name=group.name, path=path, description=group.description
|
||||
)
|
||||
group_secrets = await self._get_group_secrets(group)
|
||||
for secret in group_secrets:
|
||||
secret_group.entries.append(secret.name)
|
||||
if parent:
|
||||
secret_group.parent_group = parent
|
||||
|
||||
current_depth = await self._get_group_depth(group)
|
||||
|
||||
if not parent and group.parent:
|
||||
parent_group = await self._get_group_by_id(group.parent.id)
|
||||
assert parent_group is not None
|
||||
parent = await self._build_group_tree(parent_group, depth=current_depth)
|
||||
parent.children.append(secret_group)
|
||||
secret_group.parent_group = parent
|
||||
|
||||
if depth and depth == current_depth:
|
||||
return secret_group
|
||||
|
||||
for subgroup in group.children:
|
||||
child_group = await self._get_group_by_id(subgroup.id)
|
||||
assert child_group is not None
|
||||
secret_subgroup = await self._build_group_tree(
|
||||
child_group, secret_group, depth=depth
|
||||
)
|
||||
secret_group.children.append(secret_subgroup)
|
||||
|
||||
return secret_group
|
||||
|
||||
async def write_audit(
|
||||
self,
|
||||
operation: Operation,
|
||||
message: str,
|
||||
group_name: str | None = None,
|
||||
client_secret: ClientSecret | None = None,
|
||||
secret_name: str | None = None,
|
||||
managed_secret: ManagedSecret | None = None,
|
||||
**data: str,
|
||||
) -> None:
|
||||
"""Write Audit message."""
|
||||
if group_name:
|
||||
data["group"] = group_name
|
||||
|
||||
data["username"] = self.audit_data.username
|
||||
if client_secret and not secret_name:
|
||||
secret_name = client_secret.name
|
||||
|
||||
if managed_secret:
|
||||
data["managed_secret"] = str(managed_secret.id)
|
||||
|
||||
await self.audit.write_async(
|
||||
operation=operation,
|
||||
message=message,
|
||||
origin=self.audit_data.origin,
|
||||
client=self._manager_client,
|
||||
secret=client_secret,
|
||||
secret_name=secret_name,
|
||||
**data,
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def public_key(self) -> rsa.RSAPublicKey:
|
||||
"""Get public key."""
|
||||
keystring = self._manager_client.public_key
|
||||
return load_public_key(keystring.encode())
|
||||
|
||||
async def _get_entry(
|
||||
self, name: str, include_deleted: bool = False
|
||||
) -> ManagedSecret | None:
|
||||
"""Get managed secret."""
|
||||
if not self._import_has_run:
|
||||
await self._create_missing_entries()
|
||||
self._import_has_run = True
|
||||
statement = (
|
||||
select(ManagedSecret)
|
||||
.options(selectinload(ManagedSecret.group))
|
||||
.where(ManagedSecret.name == name)
|
||||
)
|
||||
if not include_deleted:
|
||||
statement = statement.where(ManagedSecret.is_deleted.is_not(True))
|
||||
|
||||
result = await self.session.scalars(statement)
|
||||
return result.first()
|
||||
|
||||
async def add_entry(
|
||||
self,
|
||||
entry_name: str,
|
||||
secret: str,
|
||||
overwrite: bool = False,
|
||||
group_path: str | None = None,
|
||||
) -> None:
|
||||
"""Add entry."""
|
||||
existing_entry = await self._get_entry(entry_name)
|
||||
if existing_entry and not overwrite:
|
||||
raise InvalidSecretNameError(
|
||||
"Another secret with this name is already defined."
|
||||
)
|
||||
|
||||
encrypted = encrypt_string(secret, self.public_key)
|
||||
client_secret = await self.backend.create_client_secret(
|
||||
self._id, entry_name, encrypted
|
||||
)
|
||||
group_id: uuid.UUID | None = None
|
||||
if group_path:
|
||||
elements = parse_path(group_path)
|
||||
group = await self._get_group(elements.item, elements.parent, True)
|
||||
if not group:
|
||||
raise InvalidGroupNameError("Invalid group name")
|
||||
group_id = group.id
|
||||
|
||||
if existing_entry:
|
||||
existing_entry.updated_at = datetime.now(timezone.utc)
|
||||
if group_id:
|
||||
existing_entry.group_id = group_id
|
||||
self.session.add(existing_entry)
|
||||
await self.session.commit()
|
||||
await self.write_audit(
|
||||
Operation.UPDATE,
|
||||
"Updated secret value",
|
||||
group_name=group_path,
|
||||
client_secret=client_secret,
|
||||
managed_secret=existing_entry,
|
||||
)
|
||||
else:
|
||||
managed_secret = ManagedSecret(
|
||||
name=entry_name,
|
||||
group_id=group_id,
|
||||
)
|
||||
self.session.add(managed_secret)
|
||||
|
||||
await self.session.commit()
|
||||
await self.write_audit(
|
||||
Operation.CREATE,
|
||||
"Created managed client secret",
|
||||
group_path,
|
||||
client_secret=client_secret,
|
||||
managed_secret=managed_secret,
|
||||
)
|
||||
|
||||
async def get_secret(self, entry_name: str) -> str | None:
|
||||
"""Get secret."""
|
||||
client_secret = await self.backend.get_client_secret(
|
||||
self._id, ("name", entry_name)
|
||||
)
|
||||
if not client_secret:
|
||||
return None
|
||||
decrypted = decode_string(client_secret, self._private_key)
|
||||
await self.write_audit(
|
||||
Operation.READ,
|
||||
"Secret was viewed from secret manager",
|
||||
secret_name=entry_name,
|
||||
)
|
||||
|
||||
return decrypted
|
||||
|
||||
async def get_available_secrets(self, group_path: str | None = None) -> list[str]:
|
||||
"""Get the names of all secrets in the db."""
|
||||
if not self._import_has_run:
|
||||
await self._create_missing_entries()
|
||||
if group_path:
|
||||
elements = parse_path(group_path)
|
||||
group = await self._get_group(elements.item, elements.parent)
|
||||
if not group:
|
||||
raise InvalidGroupNameError("Invalid or nonexisting group name.")
|
||||
entries = group.secrets
|
||||
else:
|
||||
result = await self.session.scalars(
|
||||
select(ManagedSecret)
|
||||
.options(selectinload(ManagedSecret.group))
|
||||
.where(ManagedSecret.is_deleted.is_not(True))
|
||||
)
|
||||
|
||||
entries = list(result.all())
|
||||
|
||||
return [entry.name for entry in entries]
|
||||
|
||||
async def delete_entry(self, entry_name: str) -> None:
|
||||
"""Delete a secret."""
|
||||
entry = await self._get_entry(entry_name)
|
||||
if not entry:
|
||||
return
|
||||
entry.is_deleted = True
|
||||
entry.deleted_at = datetime.now(timezone.utc)
|
||||
self.session.add(entry)
|
||||
await self.session.commit()
|
||||
await self.backend.delete_client_secret(
|
||||
("id", str(self._manager_client.id)), ("name", entry_name)
|
||||
)
|
||||
await self.write_audit(
|
||||
Operation.DELETE,
|
||||
"Managed secret entry deleted",
|
||||
secret_name=entry_name,
|
||||
managed_secret=entry,
|
||||
)
|
||||
|
||||
async def get_entry_group(self, entry_name: str) -> str | None:
|
||||
"""Get group of entry."""
|
||||
entry = await self._get_entry(entry_name)
|
||||
if not entry:
|
||||
raise InvalidSecretNameError("Invalid secret name or secret not found.")
|
||||
if entry.group:
|
||||
return entry.group.name
|
||||
return None
|
||||
|
||||
async def _get_groups(
|
||||
self, pattern: str | None = None, regex: bool = True, root_groups: bool = False
|
||||
) -> list[Group]:
|
||||
"""Get groups."""
|
||||
statement = select(Group).options(
|
||||
selectinload(Group.children), selectinload(Group.parent)
|
||||
)
|
||||
if pattern and regex:
|
||||
statement = statement.where(Group.name.regexp_match(pattern))
|
||||
elif pattern:
|
||||
statement = statement.where(Group.name.contains(pattern))
|
||||
if root_groups:
|
||||
statement = statement.where(Group.parent_id == None)
|
||||
results = await self.session.scalars(statement)
|
||||
return list(results.all())
|
||||
|
||||
async def get_secret_groups(
|
||||
self, pattern: str | None = None, regex: bool = True
|
||||
) -> list[SecretGroup]:
|
||||
"""Get secret groups, as a hierarcy."""
|
||||
if pattern:
|
||||
groups = await self._get_groups(pattern, regex)
|
||||
else:
|
||||
groups = await self._get_groups(root_groups=True)
|
||||
|
||||
secret_groups: list[SecretGroup] = []
|
||||
for group in groups:
|
||||
secret_group = await self._build_group_tree(group)
|
||||
secret_groups.append(secret_group)
|
||||
|
||||
return secret_groups
|
||||
|
||||
async def get_secret_group_list(
|
||||
self, pattern: str | None = None, regex: bool = True
|
||||
) -> list[SecretGroup]:
|
||||
"""Get secret group list."""
|
||||
groups = await self._get_groups(pattern, regex)
|
||||
return [(await self._build_group_tree(group)) for group in groups]
|
||||
|
||||
async def _get_group_by_id(self, id: uuid.UUID) -> Group:
|
||||
"""Get group by ID."""
|
||||
statement = (
|
||||
select(Group)
|
||||
.options(
|
||||
selectinload(Group.parent),
|
||||
selectinload(Group.children),
|
||||
selectinload(Group.secrets),
|
||||
)
|
||||
.where(Group.id == id)
|
||||
)
|
||||
|
||||
result = await self.session.scalars(statement)
|
||||
return result.one()
|
||||
|
||||
async def _get_group(
|
||||
self, name: str, parent: str | None = None, exact_match: bool = False
|
||||
) -> Group | None:
|
||||
"""Get a group."""
|
||||
statement = (
|
||||
select(Group)
|
||||
.options(
|
||||
selectinload(Group.parent),
|
||||
selectinload(Group.children),
|
||||
selectinload(Group.secrets),
|
||||
)
|
||||
.where(Group.name == name)
|
||||
)
|
||||
if parent:
|
||||
ParentGroup = aliased(Group)
|
||||
statement = statement.join(ParentGroup, Group.parent).where(
|
||||
ParentGroup.name == parent
|
||||
)
|
||||
elif exact_match:
|
||||
statement = statement.where(Group.parent_id == None)
|
||||
result = await self.session.scalars(statement)
|
||||
return result.first()
|
||||
|
||||
async def get_secret_group(self, path: str) -> SecretGroup | None:
|
||||
"""Get a secret group by path."""
|
||||
elements = parse_path(path)
|
||||
|
||||
group_name = elements.item
|
||||
parent_group = elements.parent
|
||||
|
||||
group = await self._get_group(group_name, parent_group)
|
||||
if not group:
|
||||
return None
|
||||
|
||||
return await self._build_group_tree(group)
|
||||
|
||||
async def get_ungrouped_secrets(self) -> list[str]:
|
||||
"""Get ungrouped secrets."""
|
||||
statement = (
|
||||
select(ManagedSecret)
|
||||
.where(ManagedSecret.is_deleted.is_not(True))
|
||||
.where(ManagedSecret.group_id == None)
|
||||
)
|
||||
result = await self.session.scalars(statement)
|
||||
secrets = result.all()
|
||||
return [secret.name for secret in secrets]
|
||||
|
||||
async def add_group(
|
||||
self,
|
||||
name_or_path: str,
|
||||
description: str | None = None,
|
||||
parent_group: str | None = None,
|
||||
) -> None:
|
||||
"""Add a group."""
|
||||
parent_id: uuid.UUID | None = None
|
||||
group_name = name_or_path
|
||||
if parent_group and name_or_path.startswith("/"):
|
||||
raise InvalidGroupNameError(
|
||||
"Path as name cannot be used if parent is also specified."
|
||||
)
|
||||
if name_or_path.startswith("/"):
|
||||
elements = parse_path(name_or_path)
|
||||
group_name = elements.item
|
||||
parent_group = elements.parent
|
||||
|
||||
if parent_group:
|
||||
if parent := (await self._get_group(parent_group)):
|
||||
child_names = [child.name for child in parent.children]
|
||||
if group_name in child_names:
|
||||
raise InvalidGroupNameError(
|
||||
"Parent group already has a group with this name."
|
||||
)
|
||||
parent_id = parent.id
|
||||
|
||||
else:
|
||||
raise InvalidGroupNameError(
|
||||
"Invalid or non-existing parent group name."
|
||||
)
|
||||
else:
|
||||
existing_group = await self._get_group(group_name)
|
||||
if existing_group:
|
||||
raise InvalidGroupNameError("A group with this name already exists.")
|
||||
|
||||
group = Group(
|
||||
name=group_name,
|
||||
description=description,
|
||||
parent_id=parent_id,
|
||||
)
|
||||
self.session.add(group)
|
||||
# We don't audit-log this operation.
|
||||
await self.session.commit()
|
||||
|
||||
async def set_group_description(self, path: str, description: str) -> None:
|
||||
"""Set group description."""
|
||||
elements = parse_path(path)
|
||||
group = await self._get_group(elements.item, elements.parent, True)
|
||||
if not group:
|
||||
raise InvalidGroupNameError("Invalid or non-existing group name.")
|
||||
group.description = description
|
||||
self.session.add(group)
|
||||
await self.session.commit()
|
||||
|
||||
async def set_secret_group(self, entry_name: str, group_name: str | None) -> None:
|
||||
"""Move a secret to a group.
|
||||
|
||||
If group_name is None, the secret will be moved out of any group it may exist in.
|
||||
"""
|
||||
entry = await self._get_entry(entry_name)
|
||||
if not entry:
|
||||
raise InvalidSecretNameError("Invalid or non-existing secret.")
|
||||
if group_name:
|
||||
elements = parse_path(group_name)
|
||||
group = await self._get_group(elements.item, elements.parent, True)
|
||||
if not group:
|
||||
raise InvalidGroupNameError("Invalid or non-existing group name.")
|
||||
entry.group_id = group.id
|
||||
else:
|
||||
entry.group_id = None
|
||||
|
||||
self.session.add(entry)
|
||||
await self.session.commit()
|
||||
await self.write_audit(
|
||||
Operation.UPDATE,
|
||||
"Secret group updated",
|
||||
group_name=group_name or "ROOT",
|
||||
secret_name=entry_name,
|
||||
managed_secret=entry,
|
||||
)
|
||||
|
||||
async def move_group(self, path: str, parent_group: str | None) -> None:
|
||||
"""Move group.
|
||||
|
||||
If parent_group is None, it will be moved to the root.
|
||||
"""
|
||||
elements = parse_path(path)
|
||||
group = await self._get_group(elements.item, elements.parent, True)
|
||||
if not group:
|
||||
raise InvalidGroupNameError("Invalid or non-existing group name.")
|
||||
|
||||
parent_group_id: uuid.UUID | None = None
|
||||
if parent_group:
|
||||
db_parent_group = await self._get_group(parent_group)
|
||||
if not db_parent_group:
|
||||
raise InvalidGroupNameError("Invalid or non-existing parent group.")
|
||||
parent_group_id = db_parent_group.id
|
||||
|
||||
group.parent_id = parent_group_id
|
||||
|
||||
self.session.add(group)
|
||||
await self.session.commit()
|
||||
|
||||
async def delete_group(self, path: str) -> None:
|
||||
"""Delete a group."""
|
||||
elements = parse_path(path)
|
||||
group = await self._get_group(elements.item, elements.parent, True)
|
||||
if not group:
|
||||
return
|
||||
await self.session.delete(group)
|
||||
|
||||
await self.session.commit()
|
||||
# We don't audit-log this operation currently, even though it indirectly
|
||||
# may affect secrets.
|
||||
|
||||
async def _export_entries(self) -> list[SecretDataEntryExport]:
|
||||
"""Export entries as a pydantic object."""
|
||||
statement = (
|
||||
select(ManagedSecret)
|
||||
.options(selectinload(ManagedSecret.group))
|
||||
.where(ManagedSecret.is_deleted.is_(False))
|
||||
)
|
||||
results = await self.session.scalars(statement)
|
||||
entries: list[SecretDataEntryExport] = []
|
||||
for entry in results.all():
|
||||
group: str | None = None
|
||||
if entry.group:
|
||||
group = await self._get_group_path(entry.group)
|
||||
secret = await self.get_secret(entry.name)
|
||||
if not secret:
|
||||
continue
|
||||
data = SecretDataEntryExport(name=entry.name, secret=secret, group=group)
|
||||
entries.append(data)
|
||||
return entries
|
||||
|
||||
async def _export_groups(self) -> list[SecretDataGroupExport]:
|
||||
"""Export groups as pydantic objects."""
|
||||
groups = await self.get_secret_group_list()
|
||||
entries = [
|
||||
SecretDataGroupExport(
|
||||
name=group.name,
|
||||
path=group.path,
|
||||
description=group.description,
|
||||
)
|
||||
for group in groups
|
||||
]
|
||||
return entries
|
||||
|
||||
async def export_secrets(self) -> SecretDataExport:
|
||||
"""Export the managed secrets as a pydantic object."""
|
||||
entries = await self._export_entries()
|
||||
groups = await self._export_groups()
|
||||
return SecretDataExport(entries=entries, groups=groups)
|
||||
|
||||
async def export_secrets_json(self) -> str:
|
||||
"""Export secrets as JSON."""
|
||||
export = await self.export_secrets()
|
||||
return export.model_dump_json(indent=2)
|
||||
|
||||
|
||||
def get_managed_private_key(
|
||||
settings: AdminServerSettings,
|
||||
filename: str = KEY_FILENAME,
|
||||
regenerate: bool = False,
|
||||
) -> rsa.RSAPrivateKey:
|
||||
"""Load our private key."""
|
||||
keyfile = Path(filename)
|
||||
if settings.password_manager_directory:
|
||||
keyfile = settings.password_manager_directory / filename
|
||||
if not keyfile.exists():
|
||||
_initial_key_setup(settings, keyfile)
|
||||
setup_password_manager(settings, keyfile, regenerate)
|
||||
return load_private_key(str(keyfile.absolute()), password=settings.secret_key)
|
||||
|
||||
|
||||
def setup_password_manager(
|
||||
settings: AdminServerSettings, filename: Path, regenerate: bool = False
|
||||
) -> bool:
|
||||
"""Setup password manager."""
|
||||
if filename.exists() and not regenerate:
|
||||
return False
|
||||
|
||||
if not settings.secret_key:
|
||||
raise RuntimeError("Error: Could not load secret key from environment.")
|
||||
create_private_rsa_key(filename, password=settings.secret_key)
|
||||
return True
|
||||
|
||||
|
||||
async def create_manager_client(
|
||||
backend: SshecretBackend, public_key: rsa.RSAPublicKey
|
||||
) -> Client:
|
||||
"""Create the manager client."""
|
||||
public_key_string = generate_public_key_string(public_key)
|
||||
new_client = await backend.create_system_client(
|
||||
"AdminPasswordManager",
|
||||
public_key_string,
|
||||
)
|
||||
return new_client
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def password_manager_context(
|
||||
settings: AdminServerSettings, username: str, origin: str
|
||||
) -> AsyncIterator[AsyncSecretContext]:
|
||||
"""Start a context for the password manager."""
|
||||
audit_context_data = ClientAuditData(username=username, origin=origin)
|
||||
session_manager = DatabaseSessionManager(settings.async_db_url)
|
||||
backend = SshecretBackend(str(settings.backend_url), settings.backend_token)
|
||||
private_key = get_managed_private_key(settings)
|
||||
async with session_manager.session() as session:
|
||||
# Check if there is a client_id stored already.
|
||||
query = select(PasswordDB).where(PasswordDB.id == 1)
|
||||
result = await session.scalars(query)
|
||||
password_db = result.first()
|
||||
if not password_db:
|
||||
password_db = PasswordDB(id=1)
|
||||
session.add(password_db)
|
||||
await session.flush()
|
||||
if not password_db.client_id:
|
||||
manager_client = await create_manager_client(
|
||||
backend, private_key.public_key()
|
||||
)
|
||||
password_db.client_id = manager_client.id
|
||||
session.add(password_db)
|
||||
await session.commit()
|
||||
else:
|
||||
manager_client = await backend.get_client(
|
||||
("id", str(password_db.client_id))
|
||||
)
|
||||
if not manager_client:
|
||||
raise SecretManagerError("Error: Could not fetch system client.")
|
||||
|
||||
context = AsyncSecretContext(
|
||||
private_key, manager_client, session, backend, audit_context_data
|
||||
)
|
||||
yield context
|
||||
|
||||
|
||||
def setup_private_key(
|
||||
settings: AdminServerSettings,
|
||||
filename: str = KEY_FILENAME,
|
||||
regenerate: bool = False,
|
||||
) -> None:
|
||||
"""Setup secret manager private key."""
|
||||
keyfile = Path(filename)
|
||||
if settings.password_manager_directory:
|
||||
keyfile = settings.password_manager_directory / filename
|
||||
_initial_key_setup(settings, keyfile, regenerate)
|
||||
|
||||
|
||||
def _initial_key_setup(
|
||||
settings: AdminServerSettings,
|
||||
keyfile: Path,
|
||||
regenerate: bool = False,
|
||||
) -> bool:
|
||||
"""Set up initial keys."""
|
||||
if keyfile.exists() and not regenerate:
|
||||
return False
|
||||
|
||||
assert (
|
||||
settings.secret_key is not None
|
||||
), "Error: Could not load a secret key from environment."
|
||||
create_private_rsa_key(keyfile, password=settings.secret_key)
|
||||
return True
|
||||
Reference in New Issue
Block a user