Write new secret manager using existing RSA logic
This commit is contained in:
@ -1,11 +1,11 @@
|
|||||||
import os
|
import os
|
||||||
from logging.config import fileConfig
|
from logging.config import fileConfig
|
||||||
|
|
||||||
from sqlalchemy import engine_from_config
|
from sqlalchemy import Engine, engine_from_config, pool, create_engine
|
||||||
from sqlalchemy import pool
|
|
||||||
|
|
||||||
from alembic import context
|
from alembic import context
|
||||||
from sshecret_admin.auth.models import Base
|
from sshecret_admin.auth.models import Base
|
||||||
|
from sshecret_admin.core.settings import AdminServerSettings
|
||||||
|
|
||||||
# this is the Alembic Config object, which provides
|
# this is the Alembic Config object, which provides
|
||||||
# access to the values within the .ini file in use.
|
# access to the values within the .ini file in use.
|
||||||
@ -14,11 +14,32 @@ config = context.config
|
|||||||
|
|
||||||
def get_database_url() -> str | None:
|
def get_database_url() -> str | None:
|
||||||
"""Get database URL."""
|
"""Get database URL."""
|
||||||
|
try:
|
||||||
|
settings = AdminServerSettings() # pyright: ignore[reportCallIssue]
|
||||||
|
return str(settings.admin_db)
|
||||||
|
except Exception:
|
||||||
if db_file := os.getenv("SSHECRET_ADMIN_DATABASE"):
|
if db_file := os.getenv("SSHECRET_ADMIN_DATABASE"):
|
||||||
return f"sqlite:///{db_file}"
|
return f"sqlite:///{db_file}"
|
||||||
return config.get_main_option("sqlalchemy.url")
|
return config.get_main_option("sqlalchemy.url")
|
||||||
|
|
||||||
|
|
||||||
|
def get_engine() -> Engine:
|
||||||
|
"""Get engine."""
|
||||||
|
try:
|
||||||
|
settings = AdminServerSettings() # pyright: ignore[reportCallIssue]
|
||||||
|
engine = create_engine(settings.admin_db)
|
||||||
|
return engine
|
||||||
|
except Exception:
|
||||||
|
connectable = engine_from_config(
|
||||||
|
config.get_section(config.config_ini_section, {}),
|
||||||
|
prefix="sqlalchemy.",
|
||||||
|
poolclass=pool.NullPool,
|
||||||
|
)
|
||||||
|
return connectable
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Interpret the config file for Python logging.
|
# Interpret the config file for Python logging.
|
||||||
# This line sets up loggers basically.
|
# This line sets up loggers basically.
|
||||||
if config.config_file_name is not None:
|
if config.config_file_name is not None:
|
||||||
@ -68,12 +89,7 @@ def run_migrations_online() -> None:
|
|||||||
and associate a connection with the context.
|
and associate a connection with the context.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
connectable = engine_from_config(
|
connectable = get_engine()
|
||||||
config.get_section(config.config_ini_section, {}),
|
|
||||||
prefix="sqlalchemy.",
|
|
||||||
poolclass=pool.NullPool,
|
|
||||||
)
|
|
||||||
|
|
||||||
with connectable.connect() as connection:
|
with connectable.connect() as connection:
|
||||||
context.configure(
|
context.configure(
|
||||||
connection=connection, target_metadata=target_metadata, render_as_batch=True
|
connection=connection, target_metadata=target_metadata, render_as_batch=True
|
||||||
|
|||||||
@ -0,0 +1,44 @@
|
|||||||
|
"""Implement db structures for internal password manager
|
||||||
|
|
||||||
|
Revision ID: 84356d0ea85f
|
||||||
|
Revises: 6c148590471f
|
||||||
|
Create Date: 2025-06-21 07:21:02.257865
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '84356d0ea85f'
|
||||||
|
down_revision: Union[str, None] = '6c148590471f'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('groups',
|
||||||
|
sa.Column('id', sa.Uuid(), nullable=False),
|
||||||
|
sa.Column('name', sa.String(), nullable=False),
|
||||||
|
sa.Column('parent_id', sa.Uuid(), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(['parent_id'], ['groups.id'], ),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
with op.batch_alter_table('password_db', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('client_id', sa.Uuid(), nullable=True))
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('password_db', schema=None) as batch_op:
|
||||||
|
batch_op.drop_column('client_id')
|
||||||
|
|
||||||
|
op.drop_table('groups')
|
||||||
|
# ### end Alembic commands ###
|
||||||
@ -0,0 +1,48 @@
|
|||||||
|
"""Implement managed secrets
|
||||||
|
|
||||||
|
Revision ID: c34707a1ea3a
|
||||||
|
Revises: 84356d0ea85f
|
||||||
|
Create Date: 2025-06-21 07:38:12.994535
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = 'c34707a1ea3a'
|
||||||
|
down_revision: Union[str, None] = '84356d0ea85f'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('managed_secrets',
|
||||||
|
sa.Column('id', sa.Uuid(), nullable=False),
|
||||||
|
sa.Column('name', sa.String(), nullable=False),
|
||||||
|
sa.Column('is_deleted', sa.Boolean(), nullable=False),
|
||||||
|
sa.Column('group_id', sa.Uuid(), nullable=True),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
|
||||||
|
sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ondelete='SET NULL'),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
with op.batch_alter_table('groups', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('description', sa.String(), nullable=True))
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('groups', schema=None) as batch_op:
|
||||||
|
batch_op.drop_column('description')
|
||||||
|
|
||||||
|
op.drop_table('managed_secrets')
|
||||||
|
# ### end Alembic commands ###
|
||||||
@ -5,9 +5,9 @@ import logging
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from sshecret_admin.auth import Token, authenticate_user, create_access_token
|
from sshecret_admin.auth import Token, authenticate_user_async, create_access_token
|
||||||
from sshecret_admin.core.dependencies import AdminDependencies
|
from sshecret_admin.core.dependencies import AdminDependencies
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
@ -19,11 +19,12 @@ def create_router(dependencies: AdminDependencies) -> APIRouter:
|
|||||||
|
|
||||||
@app.post("/token")
|
@app.post("/token")
|
||||||
async def login_for_access_token(
|
async def login_for_access_token(
|
||||||
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
|
||||||
|
session: Annotated[AsyncSession, Depends(dependencies.get_async_session)],
|
||||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||||
) -> Token:
|
) -> Token:
|
||||||
"""Login user and generate token."""
|
"""Login user and generate token."""
|
||||||
user = authenticate_user(session, form_data.username, form_data.password)
|
user = await authenticate_user_async(session, form_data.username, form_data.password)
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
|||||||
@ -128,7 +128,7 @@ def create_router(dependencies: AdminDependencies) -> APIRouter:
|
|||||||
group = await admin.get_secret_group(group_name)
|
group = await admin.get_secret_group(group_name)
|
||||||
if not group:
|
if not group:
|
||||||
return
|
return
|
||||||
await admin.delete_secret_group(group_name, keep_entries=True)
|
await admin.delete_secret_group(group_name)
|
||||||
|
|
||||||
@app.post("/secrets/groups/{group_name}/{secret_name}")
|
@app.post("/secrets/groups/{group_name}/{secret_name}")
|
||||||
async def move_secret_to_group(
|
async def move_secret_to_group(
|
||||||
|
|||||||
@ -5,8 +5,9 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from fastapi.security.utils import get_authorization_scheme_param
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@ -57,6 +58,31 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
|||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
def get_client_origin(request: Request) -> str:
|
||||||
|
"""Get client origin."""
|
||||||
|
fallback_origin = "UNKNOWN"
|
||||||
|
if request.client:
|
||||||
|
return request.client.host
|
||||||
|
return fallback_origin
|
||||||
|
|
||||||
|
def get_optional_username(request: Request) -> str | None:
|
||||||
|
"""Get username, if available.
|
||||||
|
|
||||||
|
This is purely used for auditing purposes.
|
||||||
|
"""
|
||||||
|
authorization = request.headers.get("Authorization")
|
||||||
|
scheme, param = get_authorization_scheme_param(authorization)
|
||||||
|
if not authorization or scheme.lower() != "bearer":
|
||||||
|
return None
|
||||||
|
claims = decode_token(dependencies.settings, param)
|
||||||
|
if not claims:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if claims.provider == LOCAL_ISSUER:
|
||||||
|
return claims.sub
|
||||||
|
|
||||||
|
return f"oidc:{claims.email}"
|
||||||
|
|
||||||
async def get_current_active_user(
|
async def get_current_active_user(
|
||||||
current_user: Annotated[User, Depends(get_current_user)],
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
) -> User:
|
) -> User:
|
||||||
@ -66,9 +92,12 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
|||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
async def get_admin_backend(
|
async def get_admin_backend(
|
||||||
|
request: Request,
|
||||||
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||||
):
|
):
|
||||||
"""Get admin backend API."""
|
"""Get admin backend API."""
|
||||||
|
username = get_optional_username(request)
|
||||||
|
origin = get_client_origin(request)
|
||||||
password_db = session.scalars(
|
password_db = session.scalars(
|
||||||
select(PasswordDB).where(PasswordDB.id == 1)
|
select(PasswordDB).where(PasswordDB.id == 1)
|
||||||
).first()
|
).first()
|
||||||
@ -76,7 +105,11 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
|||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
500, detail="Error: The password manager has not yet been set up."
|
500, detail="Error: The password manager has not yet been set up."
|
||||||
)
|
)
|
||||||
admin = AdminBackend(dependencies.settings, password_db.encrypted_password)
|
admin = AdminBackend(
|
||||||
|
dependencies.settings,
|
||||||
|
username=username,
|
||||||
|
origin=origin,
|
||||||
|
)
|
||||||
yield admin
|
yield admin
|
||||||
|
|
||||||
app = APIRouter(prefix=f"/api/{API_VERSION}")
|
app = APIRouter(prefix=f"/api/{API_VERSION}")
|
||||||
|
|||||||
@ -1,12 +1,13 @@
|
|||||||
"""Models for authentication."""
|
"""Models for authentication and secret management."""
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import override
|
||||||
import uuid
|
import uuid
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
|
||||||
JWT_ALGORITHM = "HS256"
|
JWT_ALGORITHM = "HS256"
|
||||||
@ -75,12 +76,15 @@ class PasswordDB(Base):
|
|||||||
__tablename__: str = "password_db"
|
__tablename__: str = "password_db"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(sa.INT, primary_key=True)
|
id: Mapped[int] = mapped_column(sa.INT, primary_key=True)
|
||||||
encrypted_password: Mapped[str] = mapped_column(sa.String)
|
|
||||||
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
client_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
|
sa.Uuid(as_uuid=True), nullable=True
|
||||||
|
)
|
||||||
|
|
||||||
updated_at: Mapped[datetime | None] = mapped_column(
|
updated_at: Mapped[datetime | None] = mapped_column(
|
||||||
sa.DateTime(timezone=True),
|
sa.DateTime(timezone=True),
|
||||||
server_default=sa.func.now(),
|
server_default=sa.func.now(),
|
||||||
@ -88,6 +92,65 @@ class PasswordDB(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Group(Base):
|
||||||
|
"""A secret group."""
|
||||||
|
|
||||||
|
__tablename__: str = "groups"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||||
|
)
|
||||||
|
name: Mapped[str] = mapped_column(sa.String, nullable=False)
|
||||||
|
description: Mapped[str | None] = mapped_column(sa.String, nullable=True)
|
||||||
|
|
||||||
|
parent_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
|
sa.ForeignKey("groups.id"), nullable=True
|
||||||
|
)
|
||||||
|
parent: Mapped["Group | None"] = relationship(
|
||||||
|
"Group", remote_side=[id], back_populates="children"
|
||||||
|
)
|
||||||
|
children: Mapped[list["Group"]] = relationship(
|
||||||
|
"Group", back_populates="parent", cascade="all, delete"
|
||||||
|
)
|
||||||
|
secrets: Mapped[list["ManagedSecret"]] = relationship(back_populates="group")
|
||||||
|
|
||||||
|
@override
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<Group id={self.id} name={self.name} parent_id={self.parent_id}>"
|
||||||
|
|
||||||
|
|
||||||
|
class ManagedSecret(Base):
|
||||||
|
"""Managed Secret."""
|
||||||
|
|
||||||
|
__tablename__: str = "managed_secrets"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||||
|
)
|
||||||
|
name: Mapped[str] = mapped_column(sa.String, nullable=False)
|
||||||
|
|
||||||
|
is_deleted: Mapped[bool] = mapped_column(sa.Boolean, default=False)
|
||||||
|
|
||||||
|
group_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
|
sa.ForeignKey("groups.id", ondelete="SET NULL"), nullable=True
|
||||||
|
)
|
||||||
|
group: Mapped["Group | None"] = relationship(
|
||||||
|
Group, foreign_keys=[group_id], back_populates="secrets"
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime | None] = mapped_column(
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.func.now(),
|
||||||
|
onupdate=sa.func.now(),
|
||||||
|
)
|
||||||
|
|
||||||
|
deleted_at: Mapped[datetime | None] = mapped_column(
|
||||||
|
sa.DateTime(timezone=True), nullable=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class IdentityClaims(BaseModel):
|
class IdentityClaims(BaseModel):
|
||||||
"""Normalized identity claim model."""
|
"""Normalized identity claim model."""
|
||||||
|
|
||||||
@ -125,6 +188,3 @@ class LocalUserInfo(BaseModel):
|
|||||||
local: bool
|
local: bool
|
||||||
|
|
||||||
|
|
||||||
def init_db(engine: sa.Engine) -> None:
|
|
||||||
"""Create database."""
|
|
||||||
Base.metadata.create_all(engine)
|
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
# pyright: reportUnusedFunction=false
|
# pyright: reportUnusedFunction=false
|
||||||
#
|
#
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
@ -12,15 +13,15 @@ from fastapi.encoders import jsonable_encoder
|
|||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.responses import JSONResponse, RedirectResponse
|
from fastapi.responses import JSONResponse, RedirectResponse
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from sqlalchemy import select
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import Session
|
from sshecret_backend.db import DatabaseSessionManager
|
||||||
from starlette.middleware.sessions import SessionMiddleware
|
from starlette.middleware.sessions import SessionMiddleware
|
||||||
|
|
||||||
from sshecret_admin import api, frontend
|
from sshecret_admin import api, frontend
|
||||||
from sshecret_admin.auth.models import PasswordDB, init_db
|
from sshecret_admin.auth.models import Base
|
||||||
from sshecret_admin.core.db import setup_database
|
from sshecret_admin.core.db import setup_database
|
||||||
from sshecret_admin.frontend.exceptions import RedirectException
|
from sshecret_admin.frontend.exceptions import RedirectException
|
||||||
from sshecret_admin.services.master_password import setup_master_password
|
from sshecret_admin.services.secret_manager import setup_private_key
|
||||||
|
|
||||||
from .dependencies import BaseDependencies
|
from .dependencies import BaseDependencies
|
||||||
from .settings import AdminServerSettings
|
from .settings import AdminServerSettings
|
||||||
@ -40,44 +41,28 @@ def setup_frontend(app: FastAPI, dependencies: BaseDependencies) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def create_admin_app(
|
def create_admin_app(
|
||||||
settings: AdminServerSettings, with_frontend: bool = True
|
settings: AdminServerSettings,
|
||||||
|
with_frontend: bool = True,
|
||||||
|
create_db: bool = False,
|
||||||
) -> FastAPI:
|
) -> FastAPI:
|
||||||
"""Create admin app."""
|
"""Create admin app."""
|
||||||
engine, get_db_session = setup_database(settings.admin_db)
|
engine, get_db_session = setup_database(settings.admin_db)
|
||||||
|
|
||||||
|
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""Get async session."""
|
||||||
|
session_manager = DatabaseSessionManager(settings.async_db_url)
|
||||||
|
async with session_manager.session() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
def setup_password_manager() -> None:
|
def setup_password_manager() -> None:
|
||||||
"""Setup password manager."""
|
"""Setup password manager."""
|
||||||
encr_master_password = setup_master_password(
|
setup_private_key(settings, regenerate=False)
|
||||||
settings=settings, regenerate=False
|
|
||||||
)
|
|
||||||
with Session(engine) as session:
|
|
||||||
existing_password = session.scalars(
|
|
||||||
select(PasswordDB).where(PasswordDB.id == 1)
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not encr_master_password:
|
|
||||||
if existing_password:
|
|
||||||
LOG.info("Master password already defined.")
|
|
||||||
return
|
|
||||||
# Looks like we have to regenerate it
|
|
||||||
LOG.warning(
|
|
||||||
"Master password was set, but not saved to the database. Regenerating it."
|
|
||||||
)
|
|
||||||
encr_master_password = setup_master_password(
|
|
||||||
settings=settings, regenerate=True
|
|
||||||
)
|
|
||||||
|
|
||||||
assert encr_master_password is not None
|
|
||||||
|
|
||||||
with Session(engine) as session:
|
|
||||||
pwdb = PasswordDB(id=1, encrypted_password=encr_master_password)
|
|
||||||
session.add(pwdb)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(_app: FastAPI):
|
async def lifespan(_app: FastAPI):
|
||||||
"""Create database before starting the server."""
|
"""Create database before starting the server."""
|
||||||
init_db(engine)
|
if create_db:
|
||||||
|
Base.metadata.create_all(engine)
|
||||||
setup_password_manager()
|
setup_password_manager()
|
||||||
yield
|
yield
|
||||||
|
|
||||||
@ -109,7 +94,7 @@ def create_admin_app(
|
|||||||
status_code=status.HTTP_200_OK, content=jsonable_encoder({"status": "LIVE"})
|
status_code=status.HTTP_200_OK, content=jsonable_encoder({"status": "LIVE"})
|
||||||
)
|
)
|
||||||
|
|
||||||
dependencies = BaseDependencies(settings, get_db_session)
|
dependencies = BaseDependencies(settings, get_db_session, get_async_session)
|
||||||
|
|
||||||
app.include_router(api.create_api_router(dependencies))
|
app.include_router(api.create_api_router(dependencies))
|
||||||
if with_frontend:
|
if with_frontend:
|
||||||
|
|||||||
@ -12,7 +12,7 @@ from pydantic import ValidationError
|
|||||||
from sqlalchemy import select, create_engine
|
from sqlalchemy import select, create_engine
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from sshecret_admin.auth.authentication import hash_password
|
from sshecret_admin.auth.authentication import hash_password
|
||||||
from sshecret_admin.auth.models import AuthProvider, PasswordDB, User, init_db
|
from sshecret_admin.auth.models import AuthProvider, PasswordDB, User
|
||||||
from sshecret_admin.core.settings import AdminServerSettings
|
from sshecret_admin.core.settings import AdminServerSettings
|
||||||
from sshecret_admin.services.admin_backend import AdminBackend
|
from sshecret_admin.services.admin_backend import AdminBackend
|
||||||
|
|
||||||
@ -72,7 +72,6 @@ def cli_create_user(
|
|||||||
"""Create user."""
|
"""Create user."""
|
||||||
settings = cast(AdminServerSettings, ctx.obj)
|
settings = cast(AdminServerSettings, ctx.obj)
|
||||||
engine = create_engine(settings.admin_db)
|
engine = create_engine(settings.admin_db)
|
||||||
init_db(engine)
|
|
||||||
with Session(engine) as session:
|
with Session(engine) as session:
|
||||||
create_user(session, username, email, password)
|
create_user(session, username, email, password)
|
||||||
|
|
||||||
@ -87,7 +86,6 @@ def cli_change_user_passwd(ctx: click.Context, username: str, password: str) ->
|
|||||||
"""Change password on user."""
|
"""Change password on user."""
|
||||||
settings = cast(AdminServerSettings, ctx.obj)
|
settings = cast(AdminServerSettings, ctx.obj)
|
||||||
engine = create_engine(settings.admin_db)
|
engine = create_engine(settings.admin_db)
|
||||||
init_db(engine)
|
|
||||||
with Session(engine) as session:
|
with Session(engine) as session:
|
||||||
user = session.scalars(select(User).where(User.username == username)).first()
|
user = session.scalars(select(User).where(User.username == username)).first()
|
||||||
if not user:
|
if not user:
|
||||||
@ -107,7 +105,6 @@ def cli_delete_user(ctx: click.Context, username: str) -> None:
|
|||||||
"""Remove a user."""
|
"""Remove a user."""
|
||||||
settings = cast(AdminServerSettings, ctx.obj)
|
settings = cast(AdminServerSettings, ctx.obj)
|
||||||
engine = create_engine(settings.admin_db)
|
engine = create_engine(settings.admin_db)
|
||||||
init_db(engine)
|
|
||||||
with Session(engine) as session:
|
with Session(engine) as session:
|
||||||
user = session.scalars(select(User).where(User.username == username)).first()
|
user = session.scalars(select(User).where(User.username == username)).first()
|
||||||
if not user:
|
if not user:
|
||||||
@ -149,7 +146,6 @@ def cli_repl(ctx: click.Context) -> None:
|
|||||||
"""Run an interactive console."""
|
"""Run an interactive console."""
|
||||||
settings = cast(AdminServerSettings, ctx.obj)
|
settings = cast(AdminServerSettings, ctx.obj)
|
||||||
engine = create_engine(settings.admin_db)
|
engine = create_engine(settings.admin_db)
|
||||||
init_db(engine)
|
|
||||||
with Session(engine) as session:
|
with Session(engine) as session:
|
||||||
password_db = session.scalars(
|
password_db = session.scalars(
|
||||||
select(PasswordDB).where(PasswordDB.id == 1)
|
select(PasswordDB).where(PasswordDB.id == 1)
|
||||||
@ -165,7 +161,7 @@ def cli_repl(ctx: click.Context) -> None:
|
|||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
return loop.run_until_complete(func)
|
return loop.run_until_complete(func)
|
||||||
|
|
||||||
admin = AdminBackend(settings, password_db.encrypted_password)
|
admin = AdminBackend(settings, )
|
||||||
locals = {
|
locals = {
|
||||||
"run": run,
|
"run": run,
|
||||||
"admin": admin,
|
"admin": admin,
|
||||||
|
|||||||
@ -1,12 +1,13 @@
|
|||||||
"""Database setup."""
|
"""Database setup."""
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from collections.abc import AsyncIterator, Generator, Callable
|
from collections.abc import AsyncIterator, Generator, Callable
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from sqlalchemy.engine import URL
|
from sqlalchemy.engine import URL
|
||||||
from sqlalchemy import create_engine, Engine
|
from sqlalchemy import create_engine, Engine, event
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import (
|
from sqlalchemy.ext.asyncio import (
|
||||||
AsyncConnection,
|
AsyncConnection,
|
||||||
@ -18,11 +19,20 @@ from sqlalchemy.ext.asyncio import (
|
|||||||
|
|
||||||
|
|
||||||
def setup_database(
|
def setup_database(
|
||||||
db_url: URL | str,
|
db_url: URL,
|
||||||
) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]:
|
) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]:
|
||||||
"""Setup database."""
|
"""Setup database."""
|
||||||
|
|
||||||
engine = create_engine(db_url, echo=True, future=True)
|
engine = create_engine(db_url, echo=True, future=True)
|
||||||
|
if db_url.drivername.startswith("sqlite"):
|
||||||
|
|
||||||
|
@event.listens_for(engine, "connect")
|
||||||
|
def set_sqlite_pragma(
|
||||||
|
dbapi_connection: sqlite3.Connection, _connection_record: object
|
||||||
|
) -> None:
|
||||||
|
cursor = dbapi_connection.cursor()
|
||||||
|
cursor.execute("PRAGMA foreign_keys=ON")
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
def get_db_session() -> Generator[Session, None, None]:
|
def get_db_session() -> Generator[Session, None, None]:
|
||||||
"""Get DB Session."""
|
"""Get DB Session."""
|
||||||
@ -33,8 +43,18 @@ def setup_database(
|
|||||||
|
|
||||||
|
|
||||||
class DatabaseSessionManager:
|
class DatabaseSessionManager:
|
||||||
def __init__(self, host: URL | str, **engine_kwargs: str):
|
def __init__(self, host: URL, **engine_kwargs: str):
|
||||||
self._engine: AsyncEngine | None = create_async_engine(host, **engine_kwargs)
|
self._engine: AsyncEngine | None = create_async_engine(host, **engine_kwargs)
|
||||||
|
if host.drivername.startswith("sqlite+"):
|
||||||
|
|
||||||
|
@event.listens_for(self._engine.sync_engine, "connect")
|
||||||
|
def set_sqlite_pragma(
|
||||||
|
dbapi_connection: sqlite3.Connection, _connection_record: object
|
||||||
|
) -> None:
|
||||||
|
cursor = dbapi_connection.cursor()
|
||||||
|
cursor.execute("PRAGMA foreign_keys=ON")
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
self._sessionmaker: async_sessionmaker[AsyncSession] | None = (
|
self._sessionmaker: async_sessionmaker[AsyncSession] | None = (
|
||||||
async_sessionmaker(
|
async_sessionmaker(
|
||||||
autocommit=False, bind=self._engine, expire_on_commit=False
|
autocommit=False, bind=self._engine, expire_on_commit=False
|
||||||
|
|||||||
@ -4,6 +4,8 @@ from collections.abc import AsyncGenerator, Awaitable, Callable, Generator
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Self
|
from typing import Self
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from sshecret_admin.auth import User
|
from sshecret_admin.auth import User
|
||||||
from sshecret_admin.services import AdminBackend
|
from sshecret_admin.services import AdminBackend
|
||||||
@ -11,8 +13,9 @@ from sshecret_admin.core.settings import AdminServerSettings
|
|||||||
|
|
||||||
|
|
||||||
DBSessionDep = Callable[[], Generator[Session, None, None]]
|
DBSessionDep = Callable[[], Generator[Session, None, None]]
|
||||||
|
AsyncSessionDep = Callable[[], AsyncGenerator[AsyncSession, None]]
|
||||||
|
|
||||||
AdminDep = Callable[[Session], AsyncGenerator[AdminBackend, None]]
|
AdminDep = Callable[[Request, Session], AsyncGenerator[AdminBackend, None]]
|
||||||
|
|
||||||
GetUserDep = Callable[[User], Awaitable[User]]
|
GetUserDep = Callable[[User], Awaitable[User]]
|
||||||
|
|
||||||
@ -23,6 +26,8 @@ class BaseDependencies:
|
|||||||
|
|
||||||
settings: AdminServerSettings
|
settings: AdminServerSettings
|
||||||
get_db_session: DBSessionDep
|
get_db_session: DBSessionDep
|
||||||
|
get_async_session: AsyncSessionDep
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -43,6 +48,7 @@ class AdminDependencies(BaseDependencies):
|
|||||||
return cls(
|
return cls(
|
||||||
settings=deps.settings,
|
settings=deps.settings,
|
||||||
get_db_session=deps.get_db_session,
|
get_db_session=deps.get_db_session,
|
||||||
|
get_async_session=deps.get_async_session,
|
||||||
get_admin_backend=get_admin_backend,
|
get_admin_backend=get_admin_backend,
|
||||||
get_current_active_user=get_current_active_user,
|
get_current_active_user=get_current_active_user,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -30,7 +30,6 @@ class FrontendDependencies(BaseDependencies):
|
|||||||
get_refresh_claims: RefreshTokenDep
|
get_refresh_claims: RefreshTokenDep
|
||||||
get_login_status: LoginStatusDep
|
get_login_status: LoginStatusDep
|
||||||
get_user_info: UserInfoDep
|
get_user_info: UserInfoDep
|
||||||
get_async_session: AsyncSessionDep
|
|
||||||
require_login: LoginGuardDep
|
require_login: LoginGuardDep
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -42,18 +41,17 @@ class FrontendDependencies(BaseDependencies):
|
|||||||
get_refresh_claims: RefreshTokenDep,
|
get_refresh_claims: RefreshTokenDep,
|
||||||
get_login_status: LoginStatusDep,
|
get_login_status: LoginStatusDep,
|
||||||
get_user_info: UserInfoDep,
|
get_user_info: UserInfoDep,
|
||||||
get_async_session: AsyncSessionDep,
|
|
||||||
require_login: LoginGuardDep,
|
require_login: LoginGuardDep,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
"""Create from base dependencies."""
|
"""Create from base dependencies."""
|
||||||
return cls(
|
return cls(
|
||||||
settings=deps.settings,
|
settings=deps.settings,
|
||||||
get_db_session=deps.get_db_session,
|
get_db_session=deps.get_db_session,
|
||||||
|
get_async_session=deps.get_async_session,
|
||||||
get_admin_backend=get_admin_backend,
|
get_admin_backend=get_admin_backend,
|
||||||
templates=templates,
|
templates=templates,
|
||||||
get_refresh_claims=get_refresh_claims,
|
get_refresh_claims=get_refresh_claims,
|
||||||
get_login_status=get_login_status,
|
get_login_status=get_login_status,
|
||||||
get_user_info=get_user_info,
|
get_user_info=get_user_info,
|
||||||
get_async_session=get_async_session,
|
|
||||||
require_login=require_login,
|
require_login=require_login,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -24,7 +24,6 @@ from sshecret_admin.auth.constants import LOCAL_ISSUER
|
|||||||
|
|
||||||
from sshecret_admin.core.dependencies import BaseDependencies
|
from sshecret_admin.core.dependencies import BaseDependencies
|
||||||
from sshecret_admin.services.admin_backend import AdminBackend
|
from sshecret_admin.services.admin_backend import AdminBackend
|
||||||
from sshecret_admin.core.db import DatabaseSessionManager
|
|
||||||
|
|
||||||
from .dependencies import FrontendDependencies
|
from .dependencies import FrontendDependencies
|
||||||
from .exceptions import RedirectException
|
from .exceptions import RedirectException
|
||||||
@ -50,17 +49,24 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
|||||||
templates = Jinja2Blocks(directory=template_path)
|
templates = Jinja2Blocks(directory=template_path)
|
||||||
|
|
||||||
async def get_admin_backend(
|
async def get_admin_backend(
|
||||||
|
request: Request,
|
||||||
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||||
):
|
):
|
||||||
"""Get admin backend API."""
|
"""Get admin backend API."""
|
||||||
password_db = session.scalars(
|
password_db = session.scalars(
|
||||||
select(PasswordDB).where(PasswordDB.id == 1)
|
select(PasswordDB).where(PasswordDB.id == 1)
|
||||||
).first()
|
).first()
|
||||||
|
username = get_optional_username(request)
|
||||||
|
origin = get_client_origin(request)
|
||||||
if not password_db:
|
if not password_db:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
500, detail="Error: The password manager has not yet been set up."
|
500, detail="Error: The password manager has not yet been set up."
|
||||||
)
|
)
|
||||||
admin = AdminBackend(dependencies.settings, password_db.encrypted_password)
|
admin = AdminBackend(
|
||||||
|
dependencies.settings,
|
||||||
|
username=username,
|
||||||
|
origin=origin,
|
||||||
|
)
|
||||||
yield admin
|
yield admin
|
||||||
|
|
||||||
def get_identity_claims(request: Request) -> IdentityClaims:
|
def get_identity_claims(request: Request) -> IdentityClaims:
|
||||||
@ -108,14 +114,9 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
|||||||
next = URL("/refresh").include_query_params(next=request.url.path)
|
next = URL("/refresh").include_query_params(next=request.url.path)
|
||||||
raise RedirectException(to=next)
|
raise RedirectException(to=next)
|
||||||
|
|
||||||
async def get_async_session():
|
|
||||||
"""Get async session."""
|
|
||||||
sessionmanager = DatabaseSessionManager(dependencies.settings.async_db_url)
|
|
||||||
async with sessionmanager.session() as session:
|
|
||||||
yield session
|
|
||||||
|
|
||||||
async def get_user_info(
|
async def get_user_info(
|
||||||
request: Request, session: Annotated[AsyncSession, Depends(get_async_session)]
|
request: Request,
|
||||||
|
session: Annotated[AsyncSession, Depends(dependencies.get_async_session)],
|
||||||
) -> LocalUserInfo:
|
) -> LocalUserInfo:
|
||||||
"""Get User information."""
|
"""Get User information."""
|
||||||
claims = get_identity_claims(request)
|
claims = get_identity_claims(request)
|
||||||
@ -142,6 +143,30 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
|||||||
next = URL("/refresh").include_query_params(next=request.url.path)
|
next = URL("/refresh").include_query_params(next=request.url.path)
|
||||||
raise RedirectException(to=next)
|
raise RedirectException(to=next)
|
||||||
|
|
||||||
|
def get_optional_username(
|
||||||
|
request: Request,
|
||||||
|
) -> str | None:
|
||||||
|
"""Get username, if available.
|
||||||
|
|
||||||
|
This is purely used for auditing purposes.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
claims = get_identity_claims(request)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if claims.provider == LOCAL_ISSUER:
|
||||||
|
return claims.sub
|
||||||
|
|
||||||
|
return f"oidc:{claims.email}"
|
||||||
|
|
||||||
|
def get_client_origin(request: Request) -> str:
|
||||||
|
"""Get client origin."""
|
||||||
|
fallback_origin = "UNKNOWN"
|
||||||
|
if request.client:
|
||||||
|
return request.client.host
|
||||||
|
return fallback_origin
|
||||||
|
|
||||||
view_dependencies = FrontendDependencies.create(
|
view_dependencies = FrontendDependencies.create(
|
||||||
dependencies,
|
dependencies,
|
||||||
get_admin_backend,
|
get_admin_backend,
|
||||||
@ -149,7 +174,6 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
|||||||
refresh_identity_claims,
|
refresh_identity_claims,
|
||||||
get_login_status,
|
get_login_status,
|
||||||
get_user_info,
|
get_user_info,
|
||||||
get_async_session,
|
|
||||||
require_login,
|
require_login,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -16,7 +16,7 @@
|
|||||||
<!-- Detail Pane -->
|
<!-- Detail Pane -->
|
||||||
|
|
||||||
<section id="detail-pane"
|
<section id="detail-pane"
|
||||||
class="flex-1 flex overflow-y-auto bg-white p-4 {%- if not mobile_show_details|default(false) -%} hidden{%- endif -%} lg:block dark:bg-gray-800">
|
class="flex-1 flex overflow-y-auto bg-white p-4 lg:block {% if not mobile_show_details|default(false) -%} hidden{%- endif -%} lg:block dark:bg-gray-800">
|
||||||
|
|
||||||
|
|
||||||
{% block detail %}
|
{% block detail %}
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import ipaddress
|
|||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request, Response
|
from fastapi import APIRouter, Depends, Form, HTTPException, Request, Response
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
from pydantic import BaseModel, IPvAnyAddress, IPvAnyNetwork
|
from pydantic import BaseModel, IPvAnyAddress, IPvAnyNetwork
|
||||||
from sshecret_admin.frontend.views.common import PagingInfo
|
from sshecret_admin.frontend.views.common import PagingInfo
|
||||||
@ -209,7 +209,7 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
|||||||
page: int,
|
page: int,
|
||||||
) -> Response:
|
) -> Response:
|
||||||
"""Get more events for a client."""
|
"""Get more events for a client."""
|
||||||
if not "HX-Request" in request.headers:
|
if "HX-Request" not in request.headers:
|
||||||
return RedirectResponse(url=f"/clients/client/{id}")
|
return RedirectResponse(url=f"/clients/client/{id}")
|
||||||
|
|
||||||
client = await admin.get_client(("id", id))
|
client = await admin.get_client(("id", id))
|
||||||
|
|||||||
@ -4,8 +4,8 @@ Since we have a frontend and a REST API, it makes sense to have a generic librar
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Iterator
|
from collections.abc import AsyncIterator
|
||||||
from contextlib import contextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from sshecret.backend import (
|
from sshecret.backend import (
|
||||||
AuditLog,
|
AuditLog,
|
||||||
@ -20,7 +20,7 @@ from sshecret.backend.models import ClientQueryResult, DetailedSecrets
|
|||||||
from sshecret.backend.api import AuditAPI, KeySpec
|
from sshecret.backend.api import AuditAPI, KeySpec
|
||||||
from sshecret.crypto import encrypt_string, load_public_key
|
from sshecret.crypto import encrypt_string, load_public_key
|
||||||
|
|
||||||
from .keepass import PasswordContext, load_password_manager
|
from .secret_manager import AsyncSecretContext, password_manager_context
|
||||||
from sshecret_admin.core.settings import AdminServerSettings
|
from sshecret_admin.core.settings import AdminServerSettings
|
||||||
from .models import (
|
from .models import (
|
||||||
ClientSecretGroup,
|
ClientSecretGroup,
|
||||||
@ -86,19 +86,27 @@ def add_clients_to_secret_group(
|
|||||||
class AdminBackend:
|
class AdminBackend:
|
||||||
"""Admin backend API."""
|
"""Admin backend API."""
|
||||||
|
|
||||||
def __init__(self, settings: AdminServerSettings, keepass_password: str) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
settings: AdminServerSettings,
|
||||||
|
username: str | None = None,
|
||||||
|
origin: str = "UNKNOWN",
|
||||||
|
) -> None:
|
||||||
"""Create client management API."""
|
"""Create client management API."""
|
||||||
self.settings: AdminServerSettings = settings
|
self.settings: AdminServerSettings = settings
|
||||||
self.backend: SshecretBackend = SshecretBackend(
|
self.backend: SshecretBackend = SshecretBackend(
|
||||||
str(settings.backend_url), settings.backend_token
|
str(settings.backend_url), settings.backend_token
|
||||||
)
|
)
|
||||||
self.keepass_password: str = keepass_password
|
self.username: str = username or "UKNOWN_USER"
|
||||||
|
self.origin: str = origin
|
||||||
|
|
||||||
@contextmanager
|
@asynccontextmanager
|
||||||
def password_manager(self) -> Iterator[PasswordContext]:
|
async def secrets_manager(self) -> AsyncIterator[AsyncSecretContext]:
|
||||||
"""Open the password manager."""
|
"""Open the secrets manager."""
|
||||||
with load_password_manager(self.settings, self.keepass_password) as kp:
|
async with password_manager_context(
|
||||||
yield kp
|
self.settings, self.username, self.origin
|
||||||
|
) as manager:
|
||||||
|
yield manager
|
||||||
|
|
||||||
async def _get_clients(self, filter: ClientFilter | None = None) -> list[Client]:
|
async def _get_clients(self, filter: ClientFilter | None = None) -> list[Client]:
|
||||||
"""Get clients from backend."""
|
"""Get clients from backend."""
|
||||||
@ -194,7 +202,7 @@ class AdminBackend:
|
|||||||
self,
|
self,
|
||||||
name: KeySpec,
|
name: KeySpec,
|
||||||
new_key: str,
|
new_key: str,
|
||||||
password_manager: PasswordContext,
|
password_manager: AsyncSecretContext,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""Update client public key."""
|
"""Update client public key."""
|
||||||
LOG.info(
|
LOG.info(
|
||||||
@ -207,7 +215,7 @@ class AdminBackend:
|
|||||||
updated_secrets: list[str] = []
|
updated_secrets: list[str] = []
|
||||||
for secret in client.secrets:
|
for secret in client.secrets:
|
||||||
LOG.debug("Re-encrypting secret %s for client %s", secret, name)
|
LOG.debug("Re-encrypting secret %s for client %s", secret, name)
|
||||||
secret_value = password_manager.get_secret(secret)
|
secret_value = await password_manager.get_secret(secret)
|
||||||
if not secret_value:
|
if not secret_value:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"Referenced secret %s does not exist! Skipping.", secret_value
|
"Referenced secret %s does not exist! Skipping.", secret_value
|
||||||
@ -224,7 +232,7 @@ class AdminBackend:
|
|||||||
async def update_client_public_key(self, name: KeySpec, new_key: str) -> list[str]:
|
async def update_client_public_key(self, name: KeySpec, new_key: str) -> list[str]:
|
||||||
"""Update client public key."""
|
"""Update client public key."""
|
||||||
try:
|
try:
|
||||||
with self.password_manager() as password_manager:
|
async with self.secrets_manager() as password_manager:
|
||||||
return await self._update_client_public_key(
|
return await self._update_client_public_key(
|
||||||
name, new_key, password_manager
|
name, new_key, password_manager
|
||||||
)
|
)
|
||||||
@ -291,8 +299,8 @@ class AdminBackend:
|
|||||||
This fetches the secret to client mapping from backend, and adds secrets from the password manager.
|
This fetches the secret to client mapping from backend, and adds secrets from the password manager.
|
||||||
"""
|
"""
|
||||||
backend_secrets = await self.backend.get_secrets()
|
backend_secrets = await self.backend.get_secrets()
|
||||||
with self.password_manager() as password_manager:
|
async with self.secrets_manager() as password_manager:
|
||||||
admin_secrets = password_manager.get_available_secrets()
|
admin_secrets = await password_manager.get_available_secrets()
|
||||||
|
|
||||||
secrets: dict[str, SecretListView] = {}
|
secrets: dict[str, SecretListView] = {}
|
||||||
for secret in backend_secrets:
|
for secret in backend_secrets:
|
||||||
@ -324,8 +332,8 @@ class AdminBackend:
|
|||||||
|
|
||||||
This fetches the secret to client mapping from backend, and adds secrets from the password manager.
|
This fetches the secret to client mapping from backend, and adds secrets from the password manager.
|
||||||
"""
|
"""
|
||||||
with self.password_manager() as password_manager:
|
async with self.secrets_manager() as password_manager:
|
||||||
all_secrets = password_manager.get_available_secrets()
|
all_secrets = await password_manager.get_available_secrets()
|
||||||
|
|
||||||
secrets = await self.backend.get_detailed_secrets()
|
secrets = await self.backend.get_detailed_secrets()
|
||||||
backend_secret_names = [secret.name for secret in secrets]
|
backend_secret_names = [secret.name for secret in secrets]
|
||||||
@ -351,13 +359,13 @@ class AdminBackend:
|
|||||||
parent_group: str | None = None,
|
parent_group: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add secret group."""
|
"""Add secret group."""
|
||||||
with self.password_manager() as password_manager:
|
async with self.secrets_manager() as password_manager:
|
||||||
password_manager.add_group(group_name, description, parent_group)
|
await password_manager.add_group(group_name, description, parent_group)
|
||||||
|
|
||||||
async def set_secret_group(self, secret_name: str, group_name: str | None) -> None:
|
async def set_secret_group(self, secret_name: str, group_name: str | None) -> None:
|
||||||
"""Assign a group to a secret."""
|
"""Assign a group to a secret."""
|
||||||
with self.password_manager() as password_manager:
|
async with self.secrets_manager() as password_manager:
|
||||||
password_manager.set_secret_group(secret_name, group_name)
|
await password_manager.set_secret_group(secret_name, group_name)
|
||||||
|
|
||||||
async def move_secret_group(
|
async def move_secret_group(
|
||||||
self, group_name: str, parent_group: str | None
|
self, group_name: str, parent_group: str | None
|
||||||
@ -366,23 +374,21 @@ class AdminBackend:
|
|||||||
|
|
||||||
If parent_group is None, it will be moved to the root.
|
If parent_group is None, it will be moved to the root.
|
||||||
"""
|
"""
|
||||||
with self.password_manager() as password_manager:
|
async with self.secrets_manager() as password_manager:
|
||||||
password_manager.move_group(group_name, parent_group)
|
await password_manager.move_group(group_name, parent_group)
|
||||||
|
|
||||||
async def set_group_description(self, group_name: str, description: str) -> None:
|
async def set_group_description(self, group_name: str, description: str) -> None:
|
||||||
"""Set a group description."""
|
"""Set a group description."""
|
||||||
with self.password_manager() as password_manager:
|
async with self.secrets_manager() as password_manager:
|
||||||
password_manager.set_group_description(group_name, description)
|
await password_manager.set_group_description(group_name, description)
|
||||||
|
|
||||||
async def delete_secret_group(
|
async def delete_secret_group(self, group_name: str) -> None:
|
||||||
self, group_name: str, keep_entries: bool = True
|
|
||||||
) -> None:
|
|
||||||
"""Delete a group.
|
"""Delete a group.
|
||||||
|
|
||||||
If keep_entries is set to False, all entries in the group will be deleted.
|
If keep_entries is set to False, all entries in the group will be deleted.
|
||||||
"""
|
"""
|
||||||
with self.password_manager() as password_manager:
|
async with self.secrets_manager() as password_manager:
|
||||||
password_manager.delete_group(group_name, keep_entries)
|
await password_manager.delete_group(group_name)
|
||||||
|
|
||||||
async def get_secret_groups(
|
async def get_secret_groups(
|
||||||
self,
|
self,
|
||||||
@ -399,18 +405,18 @@ class AdminBackend:
|
|||||||
"""
|
"""
|
||||||
all_secrets = await self.backend.get_detailed_secrets()
|
all_secrets = await self.backend.get_detailed_secrets()
|
||||||
secrets_mapping = {secret.name: secret for secret in all_secrets}
|
secrets_mapping = {secret.name: secret for secret in all_secrets}
|
||||||
with self.password_manager() as password_manager:
|
async with self.secrets_manager() as password_manager:
|
||||||
if flat:
|
if flat:
|
||||||
all_groups = password_manager.get_secret_group_list(
|
all_groups = await password_manager.get_secret_group_list(
|
||||||
group_filter, regex=regex
|
group_filter, regex=regex
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
all_groups = password_manager.get_secret_groups(
|
all_groups = await password_manager.get_secret_groups(
|
||||||
group_filter, regex=regex
|
group_filter, regex=regex
|
||||||
)
|
)
|
||||||
ungrouped = password_manager.get_ungrouped_secrets()
|
ungrouped = await password_manager.get_ungrouped_secrets()
|
||||||
|
|
||||||
all_admin_secrets = password_manager.get_available_secrets()
|
all_admin_secrets = await password_manager.get_available_secrets()
|
||||||
|
|
||||||
group_result: list[ClientSecretGroup] = []
|
group_result: list[ClientSecretGroup] = []
|
||||||
for group in all_groups:
|
for group in all_groups:
|
||||||
@ -452,8 +458,8 @@ class AdminBackend:
|
|||||||
|
|
||||||
async def get_secret_group_by_path(self, path: str) -> ClientSecretGroup | None:
|
async def get_secret_group_by_path(self, path: str) -> ClientSecretGroup | None:
|
||||||
"""Get a group based on its path."""
|
"""Get a group based on its path."""
|
||||||
with self.password_manager() as password_manager:
|
async with self.secrets_manager() as password_manager:
|
||||||
secret_group = password_manager.get_secret_group(path)
|
secret_group = await password_manager.get_secret_group(path)
|
||||||
|
|
||||||
if not secret_group:
|
if not secret_group:
|
||||||
return None
|
return None
|
||||||
@ -476,9 +482,11 @@ class AdminBackend:
|
|||||||
) -> SecretView | None:
|
) -> SecretView | None:
|
||||||
"""Get a secret, including the actual unencrypted value and clients."""
|
"""Get a secret, including the actual unencrypted value and clients."""
|
||||||
secret: str | None = None
|
secret: str | None = None
|
||||||
with self.password_manager() as password_manager:
|
async with self.secrets_manager() as password_manager:
|
||||||
secret = password_manager.get_secret(name)
|
secret = await password_manager.get_secret(name)
|
||||||
secret_group = password_manager.get_entry_group(name)
|
secret_group: str | None = None
|
||||||
|
if secret:
|
||||||
|
secret_group = await password_manager.get_entry_group(name)
|
||||||
|
|
||||||
secret_view = SecretView(name=name, secret=secret, group=secret_group)
|
secret_view = SecretView(name=name, secret=secret, group=secret_group)
|
||||||
|
|
||||||
@ -503,8 +511,8 @@ class AdminBackend:
|
|||||||
|
|
||||||
async def _delete_secret(self, name: str) -> None:
|
async def _delete_secret(self, name: str) -> None:
|
||||||
"""Delete a secret."""
|
"""Delete a secret."""
|
||||||
with self.password_manager() as password_manager:
|
async with self.secrets_manager() as password_manager:
|
||||||
password_manager.delete_entry(name)
|
await password_manager.delete_entry(name)
|
||||||
|
|
||||||
secret_mapping = await self.backend.get_secret(name)
|
secret_mapping = await self.backend.get_secret(name)
|
||||||
if not secret_mapping:
|
if not secret_mapping:
|
||||||
@ -522,8 +530,8 @@ class AdminBackend:
|
|||||||
group: str | None = None,
|
group: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add a secret."""
|
"""Add a secret."""
|
||||||
with self.password_manager() as password_manager:
|
async with self.secrets_manager() as password_manager:
|
||||||
password_manager.add_entry(name, value, update, group_name=group)
|
await password_manager.add_entry(name, value, update, group_path=group)
|
||||||
|
|
||||||
if update:
|
if update:
|
||||||
secret_map = await self.backend.get_secret(name)
|
secret_map = await self.backend.get_secret(name)
|
||||||
@ -576,8 +584,8 @@ class AdminBackend:
|
|||||||
if not client:
|
if not client:
|
||||||
raise ClientNotFoundError(client_idname)
|
raise ClientNotFoundError(client_idname)
|
||||||
|
|
||||||
with self.password_manager() as password_manager:
|
async with self.secrets_manager() as password_manager:
|
||||||
secret = password_manager.get_secret(secret_name)
|
secret = await password_manager.get_secret(secret_name)
|
||||||
if not secret:
|
if not secret:
|
||||||
raise SecretNotFoundError()
|
raise SecretNotFoundError()
|
||||||
|
|
||||||
|
|||||||
@ -1,348 +0,0 @@
|
|||||||
"""Keepass password manager."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from collections.abc import Iterator
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import cast
|
|
||||||
|
|
||||||
import pykeepass
|
|
||||||
import pykeepass.exceptions
|
|
||||||
from sshecret_admin.core.settings import AdminServerSettings
|
|
||||||
|
|
||||||
from .models import SecretGroup
|
|
||||||
from .master_password import decrypt_master_password
|
|
||||||
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
NO_USERNAME = "NO_USERNAME"
|
|
||||||
|
|
||||||
DEFAULT_LOCATION = "keepass.kdbx"
|
|
||||||
|
|
||||||
|
|
||||||
class PasswordCredentialsError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def create_password_db(location: Path, password: str) -> None:
|
|
||||||
"""Create the password database."""
|
|
||||||
LOG.info("Creating password database at %s", location)
|
|
||||||
pykeepass.create_database(str(location.absolute()), password=password)
|
|
||||||
|
|
||||||
|
|
||||||
def _kp_group_to_secret_group(
|
|
||||||
kp_group: pykeepass.group.Group,
|
|
||||||
parent: SecretGroup | None = None,
|
|
||||||
depth: int | None = None,
|
|
||||||
) -> SecretGroup:
|
|
||||||
"""Convert keepass group to secret group dataclass."""
|
|
||||||
group_name = cast(str, kp_group.name)
|
|
||||||
path = "/".join(cast(list[str], kp_group.path))
|
|
||||||
group = SecretGroup(name=group_name, path=path, description=kp_group.notes)
|
|
||||||
for entry in kp_group.entries:
|
|
||||||
group.entries.append(str(entry.title))
|
|
||||||
if parent:
|
|
||||||
group.parent_group = parent
|
|
||||||
|
|
||||||
current_depth = len(kp_group.path)
|
|
||||||
|
|
||||||
if not parent and current_depth > 1:
|
|
||||||
parent = _kp_group_to_secret_group(kp_group.parentgroup, depth=current_depth)
|
|
||||||
parent.children.append(group)
|
|
||||||
group.parent_group = parent
|
|
||||||
|
|
||||||
if depth and depth == current_depth:
|
|
||||||
return group
|
|
||||||
|
|
||||||
for subgroup in kp_group.subgroups:
|
|
||||||
group.children.append(_kp_group_to_secret_group(subgroup, group, depth=depth))
|
|
||||||
|
|
||||||
return group
|
|
||||||
|
|
||||||
|
|
||||||
class PasswordContext:
|
|
||||||
"""Password Context class."""
|
|
||||||
|
|
||||||
def __init__(self, keepass: pykeepass.PyKeePass) -> None:
|
|
||||||
"""Initialize password context."""
|
|
||||||
self.keepass: pykeepass.PyKeePass = keepass
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _root_group(self) -> pykeepass.group.Group:
|
|
||||||
"""Return the root group."""
|
|
||||||
return cast(pykeepass.group.Group, self.keepass.root_group)
|
|
||||||
|
|
||||||
def _get_entry(self, name: str) -> pykeepass.entry.Entry | None:
|
|
||||||
"""Get entry."""
|
|
||||||
entry = cast(
|
|
||||||
"pykeepass.entry.Entry | None",
|
|
||||||
self.keepass.find_entries(title=name, first=True),
|
|
||||||
)
|
|
||||||
return entry
|
|
||||||
|
|
||||||
def _get_group(self, name: str) -> pykeepass.group.Group | None:
|
|
||||||
"""Find a group."""
|
|
||||||
group = cast(
|
|
||||||
pykeepass.group.Group | None,
|
|
||||||
self.keepass.find_groups(name=name, first=True),
|
|
||||||
)
|
|
||||||
return group
|
|
||||||
|
|
||||||
def add_entry(
|
|
||||||
self,
|
|
||||||
entry_name: str,
|
|
||||||
secret: str,
|
|
||||||
overwrite: bool = False,
|
|
||||||
group_name: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Add an entry.
|
|
||||||
|
|
||||||
Specify overwrite=True to overwrite the existing secret value, if it exists.
|
|
||||||
This will not move the entry, if the group_name is different from the original group.
|
|
||||||
|
|
||||||
"""
|
|
||||||
entry = self._get_entry(entry_name)
|
|
||||||
if entry and overwrite:
|
|
||||||
entry.password = secret
|
|
||||||
self.keepass.save()
|
|
||||||
return
|
|
||||||
|
|
||||||
if entry:
|
|
||||||
raise ValueError("Error: A secret with this name already exists.")
|
|
||||||
LOG.debug("Add secret entry to keepass: %s, group: %r", entry_name, group_name)
|
|
||||||
if group_name:
|
|
||||||
destination_group = self._get_group(group_name)
|
|
||||||
else:
|
|
||||||
destination_group = self._root_group
|
|
||||||
|
|
||||||
entry = self.keepass.add_entry(
|
|
||||||
destination_group=destination_group,
|
|
||||||
title=entry_name,
|
|
||||||
username=NO_USERNAME,
|
|
||||||
password=secret,
|
|
||||||
)
|
|
||||||
self.keepass.save()
|
|
||||||
|
|
||||||
def get_secret(self, entry_name: str) -> str | None:
|
|
||||||
"""Get the secret value."""
|
|
||||||
entry = self._get_entry(entry_name)
|
|
||||||
if not entry:
|
|
||||||
return None
|
|
||||||
|
|
||||||
LOG.warning("Secret name %s accessed", entry_name)
|
|
||||||
if password := cast(str, entry.password):
|
|
||||||
return str(password)
|
|
||||||
|
|
||||||
raise RuntimeError(f"Cannot get password for entry {entry_name}")
|
|
||||||
|
|
||||||
def get_entry_group(self, entry_name: str) -> str | None:
|
|
||||||
"""Get the group for an entry."""
|
|
||||||
entry = self._get_entry(entry_name)
|
|
||||||
if not entry:
|
|
||||||
return None
|
|
||||||
if entry.group.is_root_group:
|
|
||||||
return None
|
|
||||||
return str(entry.group.name)
|
|
||||||
|
|
||||||
def get_secret_groups(
|
|
||||||
self, pattern: str | None = None, regex: bool = True
|
|
||||||
) -> list[SecretGroup]:
|
|
||||||
"""Get secret groups.
|
|
||||||
|
|
||||||
A regex pattern may be provided to filter groups.
|
|
||||||
"""
|
|
||||||
if pattern:
|
|
||||||
groups = cast(
|
|
||||||
list[pykeepass.group.Group],
|
|
||||||
self.keepass.find_groups(name=pattern, regex=regex),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
groups = self._root_group.subgroups
|
|
||||||
|
|
||||||
secret_groups = [_kp_group_to_secret_group(group) for group in groups]
|
|
||||||
return secret_groups
|
|
||||||
|
|
||||||
def get_secret_group_list(
|
|
||||||
self, pattern: str | None = None, regex: bool = True
|
|
||||||
) -> list[SecretGroup]:
|
|
||||||
"""Get a flat list of groups."""
|
|
||||||
if pattern:
|
|
||||||
return self.get_secret_groups(pattern, regex)
|
|
||||||
|
|
||||||
groups = [group for group in self.keepass.groups if not group.is_root_group]
|
|
||||||
secret_groups = [_kp_group_to_secret_group(group) for group in groups]
|
|
||||||
return secret_groups
|
|
||||||
|
|
||||||
def get_secret_group(self, path: str) -> SecretGroup | None:
|
|
||||||
"""Get a secret group by path."""
|
|
||||||
elements = path.split("/")
|
|
||||||
final_element = elements[-1]
|
|
||||||
|
|
||||||
current = self._root_group
|
|
||||||
while elements:
|
|
||||||
groupname = elements.pop(0)
|
|
||||||
matches = [
|
|
||||||
subgroup for subgroup in current.subgroups if subgroup.name == groupname
|
|
||||||
]
|
|
||||||
if matches:
|
|
||||||
current = matches[0]
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
if not current.is_root_group and current.name == final_element:
|
|
||||||
return _kp_group_to_secret_group(current)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_ungrouped_secrets(self) -> list[str]:
|
|
||||||
"""Get secrets without groups."""
|
|
||||||
entries: list[str] = []
|
|
||||||
for entry in self._root_group.entries:
|
|
||||||
entries.append(str(entry.title))
|
|
||||||
|
|
||||||
return entries
|
|
||||||
|
|
||||||
def add_group(
|
|
||||||
self, name: str, description: str | None = None, parent_group: str | None = None
|
|
||||||
) -> None:
|
|
||||||
"""Add a group."""
|
|
||||||
kp_parent_group = self._root_group
|
|
||||||
if parent_group:
|
|
||||||
query = cast(
|
|
||||||
pykeepass.group.Group | None,
|
|
||||||
self.keepass.find_groups(name=parent_group, first=True),
|
|
||||||
)
|
|
||||||
if not query:
|
|
||||||
raise ValueError(
|
|
||||||
f"Error: Cannot find a parent group named {parent_group}"
|
|
||||||
)
|
|
||||||
kp_parent_group = query
|
|
||||||
self.keepass.add_group(
|
|
||||||
destination_group=kp_parent_group, group_name=name, notes=description
|
|
||||||
)
|
|
||||||
self.keepass.save()
|
|
||||||
|
|
||||||
def set_group_description(self, name: str, description: str) -> None:
|
|
||||||
"""Set the description of a group."""
|
|
||||||
group = self._get_group(name)
|
|
||||||
if not group:
|
|
||||||
raise ValueError(f"Error: No such group {name}")
|
|
||||||
|
|
||||||
group.notes = description
|
|
||||||
self.keepass.save()
|
|
||||||
|
|
||||||
def set_secret_group(self, entry_name: str, group_name: str | None) -> None:
|
|
||||||
"""Move a secret to a group.
|
|
||||||
|
|
||||||
If group is None, the secret will be placed in the root group.
|
|
||||||
"""
|
|
||||||
entry = self._get_entry(entry_name)
|
|
||||||
if not entry:
|
|
||||||
raise ValueError(
|
|
||||||
f"Cannot find secret entry named {entry_name} in secrets database"
|
|
||||||
)
|
|
||||||
if group_name:
|
|
||||||
group = self._get_group(group_name)
|
|
||||||
if not group:
|
|
||||||
raise ValueError(f"Cannot find a group named {group_name}")
|
|
||||||
else:
|
|
||||||
group = self._root_group
|
|
||||||
|
|
||||||
self.keepass.move_entry(entry, group)
|
|
||||||
self.keepass.save()
|
|
||||||
|
|
||||||
def move_group(self, name: str, parent_group: str | None) -> None:
|
|
||||||
"""Move a group.
|
|
||||||
|
|
||||||
If parent_group is None, it will be moved to the root.
|
|
||||||
"""
|
|
||||||
group = self._get_group(name)
|
|
||||||
if not group:
|
|
||||||
raise ValueError(f"Error: No such group {name}")
|
|
||||||
if parent_group:
|
|
||||||
parent = self._get_group(parent_group)
|
|
||||||
if not parent:
|
|
||||||
raise ValueError(f"Error: No such group {parent_group}")
|
|
||||||
else:
|
|
||||||
parent = self._root_group
|
|
||||||
|
|
||||||
self.keepass.move_group(group, parent)
|
|
||||||
self.keepass.save()
|
|
||||||
|
|
||||||
def get_available_secrets(self, group_name: str | None = None) -> list[str]:
|
|
||||||
"""Get the names of all secrets in the database."""
|
|
||||||
if group_name:
|
|
||||||
group = self._get_group(group_name)
|
|
||||||
if not group:
|
|
||||||
raise ValueError(f"Error: No such group {group_name}")
|
|
||||||
entries = group.entries
|
|
||||||
else:
|
|
||||||
entries = cast(list[pykeepass.entry.Entry], self.keepass.entries)
|
|
||||||
if not entries:
|
|
||||||
return []
|
|
||||||
return [str(entry.title) for entry in entries]
|
|
||||||
|
|
||||||
def delete_entry(self, entry_name: str) -> None:
|
|
||||||
"""Delete entry."""
|
|
||||||
entry = cast(
|
|
||||||
"pykeepass.entry.Entry | None",
|
|
||||||
self.keepass.find_entries(title=entry_name, first=True),
|
|
||||||
)
|
|
||||||
if not entry:
|
|
||||||
return
|
|
||||||
entry.delete()
|
|
||||||
self.keepass.save()
|
|
||||||
|
|
||||||
def delete_group(self, name: str, keep_entries: bool = True) -> None:
|
|
||||||
"""Delete a group.
|
|
||||||
|
|
||||||
If keep_entries is set to False, all entries in the group will be deleted.
|
|
||||||
"""
|
|
||||||
group = self._get_group(name)
|
|
||||||
if not group:
|
|
||||||
return
|
|
||||||
if keep_entries:
|
|
||||||
for entry in cast(
|
|
||||||
list[pykeepass.entry.Entry],
|
|
||||||
self.keepass.find_entries(recursive=True, group=group),
|
|
||||||
):
|
|
||||||
# Move the entry to the root group.
|
|
||||||
LOG.warning(
|
|
||||||
"Moving orphaned secret entry %s to root group", entry.title
|
|
||||||
)
|
|
||||||
self.keepass.move_entry(entry, self._root_group)
|
|
||||||
|
|
||||||
self.keepass.delete_group(group)
|
|
||||||
self.keepass.save()
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def _password_context(location: Path, password: str) -> Iterator[PasswordContext]:
|
|
||||||
"""Open the password context."""
|
|
||||||
try:
|
|
||||||
database = pykeepass.PyKeePass(str(location.absolute()), password=password)
|
|
||||||
except pykeepass.exceptions.CredentialsError as e:
|
|
||||||
raise PasswordCredentialsError(
|
|
||||||
"Could not open password database. Invalid credentials."
|
|
||||||
) from e
|
|
||||||
context = PasswordContext(database)
|
|
||||||
yield context
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def load_password_manager(
|
|
||||||
settings: AdminServerSettings,
|
|
||||||
encrypted_password: str,
|
|
||||||
location: str = DEFAULT_LOCATION,
|
|
||||||
) -> Iterator[PasswordContext]:
|
|
||||||
"""Load password manager.
|
|
||||||
|
|
||||||
This function decrypts the password, and creates the password database if it
|
|
||||||
has not yet been created.
|
|
||||||
"""
|
|
||||||
db_location = Path(location)
|
|
||||||
password = decrypt_master_password(settings=settings, encrypted=encrypted_password)
|
|
||||||
if not db_location.exists():
|
|
||||||
create_password_db(db_location, password)
|
|
||||||
|
|
||||||
with _password_context(db_location, password) as context:
|
|
||||||
yield context
|
|
||||||
@ -1,86 +0,0 @@
|
|||||||
"""Functions related to handling the password database master password."""
|
|
||||||
|
|
||||||
import secrets
|
|
||||||
from pathlib import Path
|
|
||||||
from sshecret.crypto import (
|
|
||||||
create_private_rsa_key,
|
|
||||||
load_private_key,
|
|
||||||
encrypt_string,
|
|
||||||
decode_string,
|
|
||||||
)
|
|
||||||
from sshecret_admin.core.settings import AdminServerSettings
|
|
||||||
|
|
||||||
KEY_FILENAME = "sshecret-admin-key"
|
|
||||||
|
|
||||||
|
|
||||||
def setup_master_password(
|
|
||||||
settings: AdminServerSettings,
|
|
||||||
filename: str = KEY_FILENAME,
|
|
||||||
regenerate: bool = False,
|
|
||||||
) -> str | None:
|
|
||||||
"""Setup master password.
|
|
||||||
|
|
||||||
If regenerate is True, a new key will be generated.
|
|
||||||
|
|
||||||
This method should run just after setting up the database.
|
|
||||||
"""
|
|
||||||
keyfile = Path(filename)
|
|
||||||
if settings.password_manager_directory:
|
|
||||||
keyfile = settings.password_manager_directory / filename
|
|
||||||
created = _initial_key_setup(settings, keyfile, regenerate)
|
|
||||||
if not created:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return _generate_master_password(settings, keyfile)
|
|
||||||
|
|
||||||
|
|
||||||
def decrypt_master_password(
|
|
||||||
settings: AdminServerSettings, encrypted: str, filename: str = KEY_FILENAME
|
|
||||||
) -> str:
|
|
||||||
"""Retrieve master password."""
|
|
||||||
keyfile = Path(filename)
|
|
||||||
if settings.password_manager_directory:
|
|
||||||
keyfile = settings.password_manager_directory / filename
|
|
||||||
if not keyfile.exists():
|
|
||||||
raise RuntimeError("Error: Private key has not been generated yet.")
|
|
||||||
|
|
||||||
private_key = load_private_key(
|
|
||||||
str(keyfile.absolute()), password=settings.secret_key
|
|
||||||
)
|
|
||||||
return decode_string(encrypted, private_key)
|
|
||||||
|
|
||||||
|
|
||||||
def _generate_password() -> str:
|
|
||||||
"""Generate a password."""
|
|
||||||
return secrets.token_urlsafe(32)
|
|
||||||
|
|
||||||
|
|
||||||
def _initial_key_setup(
|
|
||||||
settings: AdminServerSettings,
|
|
||||||
keyfile: Path,
|
|
||||||
regenerate: bool = False,
|
|
||||||
) -> bool:
|
|
||||||
"""Set up initial keys."""
|
|
||||||
if keyfile.exists() and not regenerate:
|
|
||||||
return False
|
|
||||||
|
|
||||||
assert settings.secret_key is not None, (
|
|
||||||
"Error: Could not load a secret key from environment."
|
|
||||||
)
|
|
||||||
create_private_rsa_key(keyfile, password=settings.secret_key)
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def _generate_master_password(settings: AdminServerSettings, keyfile: Path) -> str:
|
|
||||||
"""Generate master password for password database.
|
|
||||||
|
|
||||||
Returns the encrypted string, base64 encoded.
|
|
||||||
"""
|
|
||||||
if not keyfile.exists():
|
|
||||||
raise RuntimeError("Error: Private key has not been generated yet.")
|
|
||||||
private_key = load_private_key(
|
|
||||||
str(keyfile.absolute()), password=settings.secret_key
|
|
||||||
)
|
|
||||||
public_key = private_key.public_key()
|
|
||||||
master_password = _generate_password()
|
|
||||||
return encrypt_string(master_password, public_key)
|
|
||||||
@ -0,0 +1,776 @@
|
|||||||
|
"""Rewritten secret manager using a rsa keys."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from functools import cached_property
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import selectinload, aliased
|
||||||
|
from sshecret.backend import SshecretBackend
|
||||||
|
from sshecret.backend.api import AuditAPI, KeySpec
|
||||||
|
from sshecret.backend.models import Client, ClientSecret, Operation, SubSystem
|
||||||
|
from sshecret.crypto import (
|
||||||
|
create_private_rsa_key,
|
||||||
|
decode_string,
|
||||||
|
encrypt_string,
|
||||||
|
generate_public_key_string,
|
||||||
|
load_private_key,
|
||||||
|
load_public_key,
|
||||||
|
)
|
||||||
|
from sshecret_admin.auth import PasswordDB
|
||||||
|
from sshecret_admin.auth.models import Group, ManagedSecret
|
||||||
|
from sshecret_admin.core.db import DatabaseSessionManager
|
||||||
|
from sshecret_admin.core.settings import AdminServerSettings
|
||||||
|
from sshecret_admin.services.models import SecretGroup
|
||||||
|
|
||||||
|
|
||||||
|
KEY_FILENAME = "sshecret-admin-key"
|
||||||
|
PASSWORD_MANAGER_ID = "SshecretAdminPasswordManager"
|
||||||
|
|
||||||
|
LOG = logging.getLogger(PASSWORD_MANAGER_ID)
|
||||||
|
|
||||||
|
|
||||||
|
class SecretManagerError(Exception):
|
||||||
|
"""Secret manager error."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidGroupNameError(SecretManagerError):
|
||||||
|
"""Invalid group name."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidSecretNameError(SecretManagerError):
|
||||||
|
"""Invalid secret name."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ClientAuditData:
|
||||||
|
"""Client audit data."""
|
||||||
|
|
||||||
|
username: str
|
||||||
|
origin: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ParsedPath:
|
||||||
|
"""Parsed path."""
|
||||||
|
|
||||||
|
item: str
|
||||||
|
full_path: str
|
||||||
|
parent: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class SecretDataEntryExport(BaseModel):
|
||||||
|
"""Exportable secret entry."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
secret: str
|
||||||
|
group: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class SecretDataGroupExport(BaseModel):
|
||||||
|
"""Exportable secret grouping."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
path: str
|
||||||
|
description: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class SecretDataExport(BaseModel):
|
||||||
|
"""Exportable object containing secrets and groups."""
|
||||||
|
|
||||||
|
entries: list[SecretDataEntryExport]
|
||||||
|
groups: list[SecretDataGroupExport]
|
||||||
|
|
||||||
|
|
||||||
|
def split_path(path: str) -> list[str]:
|
||||||
|
"""Split a path into a list of groups."""
|
||||||
|
elements = path.split("/")
|
||||||
|
if path.startswith("/"):
|
||||||
|
elements = elements[1:]
|
||||||
|
|
||||||
|
return elements
|
||||||
|
|
||||||
|
|
||||||
|
def parse_path(path: str) -> ParsedPath:
|
||||||
|
"""Parse path."""
|
||||||
|
elements = split_path(path)
|
||||||
|
parsed = ParsedPath(elements[-1], path)
|
||||||
|
if len(elements) > 1:
|
||||||
|
parsed.parent = elements[-2]
|
||||||
|
return parsed
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncSecretContext:
|
||||||
|
"""Async secret context."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
private_key: rsa.RSAPrivateKey,
|
||||||
|
manager_client: Client,
|
||||||
|
session: AsyncSession,
|
||||||
|
backend: SshecretBackend,
|
||||||
|
audit_data: ClientAuditData,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize secret manager"""
|
||||||
|
self._private_key: rsa.RSAPrivateKey = private_key
|
||||||
|
self._manager_client: Client = manager_client
|
||||||
|
self._id: KeySpec = ("id", str(manager_client.id))
|
||||||
|
self.backend: SshecretBackend = backend
|
||||||
|
self.session: AsyncSession = session
|
||||||
|
|
||||||
|
self.audit_data: ClientAuditData = audit_data
|
||||||
|
self.audit: AuditAPI = backend.audit(SubSystem.ADMIN)
|
||||||
|
self._import_has_run: bool = False
|
||||||
|
|
||||||
|
async def _create_missing_entries(self) -> None:
|
||||||
|
"""Create any missing entries."""
|
||||||
|
new_secrets: bool = False
|
||||||
|
to_check = set(self._manager_client.secrets)
|
||||||
|
for secret_name in to_check:
|
||||||
|
# entry = await self._get_entry(secret_name, include_deleted=True)
|
||||||
|
statement = select(ManagedSecret).where(ManagedSecret.name == secret_name)
|
||||||
|
result = await self.session.scalars(statement)
|
||||||
|
if not result.first():
|
||||||
|
new_secrets = True
|
||||||
|
managed_secret = ManagedSecret(name=secret_name)
|
||||||
|
self.session.add(managed_secret)
|
||||||
|
|
||||||
|
await self.session.flush()
|
||||||
|
await self.write_audit(
|
||||||
|
Operation.CREATE,
|
||||||
|
message="Imported managed secret from backend.",
|
||||||
|
secret_name=secret_name,
|
||||||
|
managed_secret=managed_secret,
|
||||||
|
)
|
||||||
|
if new_secrets:
|
||||||
|
await self.session.commit()
|
||||||
|
|
||||||
|
async def _get_group_depth(self, group: Group) -> int:
|
||||||
|
"""Get the depth of a group."""
|
||||||
|
depth = 1
|
||||||
|
if not group.parent_id:
|
||||||
|
return depth
|
||||||
|
|
||||||
|
current = group
|
||||||
|
while current.parent is not None:
|
||||||
|
if current.parent:
|
||||||
|
depth += 1
|
||||||
|
current = await self._get_group_by_id(current.parent.id)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
return depth
|
||||||
|
|
||||||
|
async def _get_group_path(self, group: Group) -> str:
|
||||||
|
"""Get the path of a group."""
|
||||||
|
|
||||||
|
if not group.parent_id:
|
||||||
|
return group.name
|
||||||
|
path: list[str] = []
|
||||||
|
current = group
|
||||||
|
while current.parent_id is not None:
|
||||||
|
path.append(current.name)
|
||||||
|
current = await self._get_group_by_id(current.parent_id)
|
||||||
|
|
||||||
|
path.append("")
|
||||||
|
path.reverse()
|
||||||
|
return "/".join(path)
|
||||||
|
|
||||||
|
async def _get_group_secrets(self, group: Group) -> list[ManagedSecret]:
|
||||||
|
"""Get secrets in a group."""
|
||||||
|
statement = (
|
||||||
|
select(ManagedSecret)
|
||||||
|
.where(ManagedSecret.group_id == group.id)
|
||||||
|
.where(ManagedSecret.is_deleted.is_not(True))
|
||||||
|
)
|
||||||
|
results = await self.session.scalars(statement)
|
||||||
|
return list(results.all())
|
||||||
|
|
||||||
|
async def _build_group_tree(
|
||||||
|
self, group: Group, parent: SecretGroup | None = None, depth: int | None = None
|
||||||
|
) -> SecretGroup:
|
||||||
|
"""Build a group tree."""
|
||||||
|
path = "/"
|
||||||
|
if parent:
|
||||||
|
path = os.path.join(parent.path, path)
|
||||||
|
secret_group = SecretGroup(
|
||||||
|
name=group.name, path=path, description=group.description
|
||||||
|
)
|
||||||
|
group_secrets = await self._get_group_secrets(group)
|
||||||
|
for secret in group_secrets:
|
||||||
|
secret_group.entries.append(secret.name)
|
||||||
|
if parent:
|
||||||
|
secret_group.parent_group = parent
|
||||||
|
|
||||||
|
current_depth = await self._get_group_depth(group)
|
||||||
|
|
||||||
|
if not parent and group.parent:
|
||||||
|
parent_group = await self._get_group_by_id(group.parent.id)
|
||||||
|
assert parent_group is not None
|
||||||
|
parent = await self._build_group_tree(parent_group, depth=current_depth)
|
||||||
|
parent.children.append(secret_group)
|
||||||
|
secret_group.parent_group = parent
|
||||||
|
|
||||||
|
if depth and depth == current_depth:
|
||||||
|
return secret_group
|
||||||
|
|
||||||
|
for subgroup in group.children:
|
||||||
|
child_group = await self._get_group_by_id(subgroup.id)
|
||||||
|
assert child_group is not None
|
||||||
|
secret_subgroup = await self._build_group_tree(
|
||||||
|
child_group, secret_group, depth=depth
|
||||||
|
)
|
||||||
|
secret_group.children.append(secret_subgroup)
|
||||||
|
|
||||||
|
return secret_group
|
||||||
|
|
||||||
|
async def write_audit(
|
||||||
|
self,
|
||||||
|
operation: Operation,
|
||||||
|
message: str,
|
||||||
|
group_name: str | None = None,
|
||||||
|
client_secret: ClientSecret | None = None,
|
||||||
|
secret_name: str | None = None,
|
||||||
|
managed_secret: ManagedSecret | None = None,
|
||||||
|
**data: str,
|
||||||
|
) -> None:
|
||||||
|
"""Write Audit message."""
|
||||||
|
if group_name:
|
||||||
|
data["group"] = group_name
|
||||||
|
|
||||||
|
data["username"] = self.audit_data.username
|
||||||
|
if client_secret and not secret_name:
|
||||||
|
secret_name = client_secret.name
|
||||||
|
|
||||||
|
if managed_secret:
|
||||||
|
data["managed_secret"] = str(managed_secret.id)
|
||||||
|
|
||||||
|
await self.audit.write_async(
|
||||||
|
operation=operation,
|
||||||
|
message=message,
|
||||||
|
origin=self.audit_data.origin,
|
||||||
|
client=self._manager_client,
|
||||||
|
secret=client_secret,
|
||||||
|
secret_name=secret_name,
|
||||||
|
**data,
|
||||||
|
)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def public_key(self) -> rsa.RSAPublicKey:
|
||||||
|
"""Get public key."""
|
||||||
|
keystring = self._manager_client.public_key
|
||||||
|
return load_public_key(keystring.encode())
|
||||||
|
|
||||||
|
async def _get_entry(
|
||||||
|
self, name: str, include_deleted: bool = False
|
||||||
|
) -> ManagedSecret | None:
|
||||||
|
"""Get managed secret."""
|
||||||
|
if not self._import_has_run:
|
||||||
|
await self._create_missing_entries()
|
||||||
|
self._import_has_run = True
|
||||||
|
statement = (
|
||||||
|
select(ManagedSecret)
|
||||||
|
.options(selectinload(ManagedSecret.group))
|
||||||
|
.where(ManagedSecret.name == name)
|
||||||
|
)
|
||||||
|
if not include_deleted:
|
||||||
|
statement = statement.where(ManagedSecret.is_deleted.is_not(True))
|
||||||
|
|
||||||
|
result = await self.session.scalars(statement)
|
||||||
|
return result.first()
|
||||||
|
|
||||||
|
async def add_entry(
|
||||||
|
self,
|
||||||
|
entry_name: str,
|
||||||
|
secret: str,
|
||||||
|
overwrite: bool = False,
|
||||||
|
group_path: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Add entry."""
|
||||||
|
existing_entry = await self._get_entry(entry_name)
|
||||||
|
if existing_entry and not overwrite:
|
||||||
|
raise InvalidSecretNameError(
|
||||||
|
"Another secret with this name is already defined."
|
||||||
|
)
|
||||||
|
|
||||||
|
encrypted = encrypt_string(secret, self.public_key)
|
||||||
|
client_secret = await self.backend.create_client_secret(
|
||||||
|
self._id, entry_name, encrypted
|
||||||
|
)
|
||||||
|
group_id: uuid.UUID | None = None
|
||||||
|
if group_path:
|
||||||
|
elements = parse_path(group_path)
|
||||||
|
group = await self._get_group(elements.item, elements.parent, True)
|
||||||
|
if not group:
|
||||||
|
raise InvalidGroupNameError("Invalid group name")
|
||||||
|
group_id = group.id
|
||||||
|
|
||||||
|
if existing_entry:
|
||||||
|
existing_entry.updated_at = datetime.now(timezone.utc)
|
||||||
|
if group_id:
|
||||||
|
existing_entry.group_id = group_id
|
||||||
|
self.session.add(existing_entry)
|
||||||
|
await self.session.commit()
|
||||||
|
await self.write_audit(
|
||||||
|
Operation.UPDATE,
|
||||||
|
"Updated secret value",
|
||||||
|
group_name=group_path,
|
||||||
|
client_secret=client_secret,
|
||||||
|
managed_secret=existing_entry,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
managed_secret = ManagedSecret(
|
||||||
|
name=entry_name,
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
self.session.add(managed_secret)
|
||||||
|
|
||||||
|
await self.session.commit()
|
||||||
|
await self.write_audit(
|
||||||
|
Operation.CREATE,
|
||||||
|
"Created managed client secret",
|
||||||
|
group_path,
|
||||||
|
client_secret=client_secret,
|
||||||
|
managed_secret=managed_secret,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_secret(self, entry_name: str) -> str | None:
|
||||||
|
"""Get secret."""
|
||||||
|
client_secret = await self.backend.get_client_secret(
|
||||||
|
self._id, ("name", entry_name)
|
||||||
|
)
|
||||||
|
if not client_secret:
|
||||||
|
return None
|
||||||
|
decrypted = decode_string(client_secret, self._private_key)
|
||||||
|
await self.write_audit(
|
||||||
|
Operation.READ,
|
||||||
|
"Secret was viewed from secret manager",
|
||||||
|
secret_name=entry_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
return decrypted
|
||||||
|
|
||||||
|
async def get_available_secrets(self, group_path: str | None = None) -> list[str]:
|
||||||
|
"""Get the names of all secrets in the db."""
|
||||||
|
if not self._import_has_run:
|
||||||
|
await self._create_missing_entries()
|
||||||
|
if group_path:
|
||||||
|
elements = parse_path(group_path)
|
||||||
|
group = await self._get_group(elements.item, elements.parent)
|
||||||
|
if not group:
|
||||||
|
raise InvalidGroupNameError("Invalid or nonexisting group name.")
|
||||||
|
entries = group.secrets
|
||||||
|
else:
|
||||||
|
result = await self.session.scalars(
|
||||||
|
select(ManagedSecret)
|
||||||
|
.options(selectinload(ManagedSecret.group))
|
||||||
|
.where(ManagedSecret.is_deleted.is_not(True))
|
||||||
|
)
|
||||||
|
|
||||||
|
entries = list(result.all())
|
||||||
|
|
||||||
|
return [entry.name for entry in entries]
|
||||||
|
|
||||||
|
async def delete_entry(self, entry_name: str) -> None:
|
||||||
|
"""Delete a secret."""
|
||||||
|
entry = await self._get_entry(entry_name)
|
||||||
|
if not entry:
|
||||||
|
return
|
||||||
|
entry.is_deleted = True
|
||||||
|
entry.deleted_at = datetime.now(timezone.utc)
|
||||||
|
self.session.add(entry)
|
||||||
|
await self.session.commit()
|
||||||
|
await self.backend.delete_client_secret(
|
||||||
|
("id", str(self._manager_client.id)), ("name", entry_name)
|
||||||
|
)
|
||||||
|
await self.write_audit(
|
||||||
|
Operation.DELETE,
|
||||||
|
"Managed secret entry deleted",
|
||||||
|
secret_name=entry_name,
|
||||||
|
managed_secret=entry,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_entry_group(self, entry_name: str) -> str | None:
|
||||||
|
"""Get group of entry."""
|
||||||
|
entry = await self._get_entry(entry_name)
|
||||||
|
if not entry:
|
||||||
|
raise InvalidSecretNameError("Invalid secret name or secret not found.")
|
||||||
|
if entry.group:
|
||||||
|
return entry.group.name
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _get_groups(
|
||||||
|
self, pattern: str | None = None, regex: bool = True, root_groups: bool = False
|
||||||
|
) -> list[Group]:
|
||||||
|
"""Get groups."""
|
||||||
|
statement = select(Group).options(
|
||||||
|
selectinload(Group.children), selectinload(Group.parent)
|
||||||
|
)
|
||||||
|
if pattern and regex:
|
||||||
|
statement = statement.where(Group.name.regexp_match(pattern))
|
||||||
|
elif pattern:
|
||||||
|
statement = statement.where(Group.name.contains(pattern))
|
||||||
|
if root_groups:
|
||||||
|
statement = statement.where(Group.parent_id == None)
|
||||||
|
results = await self.session.scalars(statement)
|
||||||
|
return list(results.all())
|
||||||
|
|
||||||
|
async def get_secret_groups(
|
||||||
|
self, pattern: str | None = None, regex: bool = True
|
||||||
|
) -> list[SecretGroup]:
|
||||||
|
"""Get secret groups, as a hierarcy."""
|
||||||
|
if pattern:
|
||||||
|
groups = await self._get_groups(pattern, regex)
|
||||||
|
else:
|
||||||
|
groups = await self._get_groups(root_groups=True)
|
||||||
|
|
||||||
|
secret_groups: list[SecretGroup] = []
|
||||||
|
for group in groups:
|
||||||
|
secret_group = await self._build_group_tree(group)
|
||||||
|
secret_groups.append(secret_group)
|
||||||
|
|
||||||
|
return secret_groups
|
||||||
|
|
||||||
|
async def get_secret_group_list(
|
||||||
|
self, pattern: str | None = None, regex: bool = True
|
||||||
|
) -> list[SecretGroup]:
|
||||||
|
"""Get secret group list."""
|
||||||
|
groups = await self._get_groups(pattern, regex)
|
||||||
|
return [(await self._build_group_tree(group)) for group in groups]
|
||||||
|
|
||||||
|
async def _get_group_by_id(self, id: uuid.UUID) -> Group:
|
||||||
|
"""Get group by ID."""
|
||||||
|
statement = (
|
||||||
|
select(Group)
|
||||||
|
.options(
|
||||||
|
selectinload(Group.parent),
|
||||||
|
selectinload(Group.children),
|
||||||
|
selectinload(Group.secrets),
|
||||||
|
)
|
||||||
|
.where(Group.id == id)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await self.session.scalars(statement)
|
||||||
|
return result.one()
|
||||||
|
|
||||||
|
async def _get_group(
|
||||||
|
self, name: str, parent: str | None = None, exact_match: bool = False
|
||||||
|
) -> Group | None:
|
||||||
|
"""Get a group."""
|
||||||
|
statement = (
|
||||||
|
select(Group)
|
||||||
|
.options(
|
||||||
|
selectinload(Group.parent),
|
||||||
|
selectinload(Group.children),
|
||||||
|
selectinload(Group.secrets),
|
||||||
|
)
|
||||||
|
.where(Group.name == name)
|
||||||
|
)
|
||||||
|
if parent:
|
||||||
|
ParentGroup = aliased(Group)
|
||||||
|
statement = statement.join(ParentGroup, Group.parent).where(
|
||||||
|
ParentGroup.name == parent
|
||||||
|
)
|
||||||
|
elif exact_match:
|
||||||
|
statement = statement.where(Group.parent_id == None)
|
||||||
|
result = await self.session.scalars(statement)
|
||||||
|
return result.first()
|
||||||
|
|
||||||
|
async def get_secret_group(self, path: str) -> SecretGroup | None:
|
||||||
|
"""Get a secret group by path."""
|
||||||
|
elements = parse_path(path)
|
||||||
|
|
||||||
|
group_name = elements.item
|
||||||
|
parent_group = elements.parent
|
||||||
|
|
||||||
|
group = await self._get_group(group_name, parent_group)
|
||||||
|
if not group:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return await self._build_group_tree(group)
|
||||||
|
|
||||||
|
async def get_ungrouped_secrets(self) -> list[str]:
|
||||||
|
"""Get ungrouped secrets."""
|
||||||
|
statement = (
|
||||||
|
select(ManagedSecret)
|
||||||
|
.where(ManagedSecret.is_deleted.is_not(True))
|
||||||
|
.where(ManagedSecret.group_id == None)
|
||||||
|
)
|
||||||
|
result = await self.session.scalars(statement)
|
||||||
|
secrets = result.all()
|
||||||
|
return [secret.name for secret in secrets]
|
||||||
|
|
||||||
|
async def add_group(
|
||||||
|
self,
|
||||||
|
name_or_path: str,
|
||||||
|
description: str | None = None,
|
||||||
|
parent_group: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Add a group."""
|
||||||
|
parent_id: uuid.UUID | None = None
|
||||||
|
group_name = name_or_path
|
||||||
|
if parent_group and name_or_path.startswith("/"):
|
||||||
|
raise InvalidGroupNameError(
|
||||||
|
"Path as name cannot be used if parent is also specified."
|
||||||
|
)
|
||||||
|
if name_or_path.startswith("/"):
|
||||||
|
elements = parse_path(name_or_path)
|
||||||
|
group_name = elements.item
|
||||||
|
parent_group = elements.parent
|
||||||
|
|
||||||
|
if parent_group:
|
||||||
|
if parent := (await self._get_group(parent_group)):
|
||||||
|
child_names = [child.name for child in parent.children]
|
||||||
|
if group_name in child_names:
|
||||||
|
raise InvalidGroupNameError(
|
||||||
|
"Parent group already has a group with this name."
|
||||||
|
)
|
||||||
|
parent_id = parent.id
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise InvalidGroupNameError(
|
||||||
|
"Invalid or non-existing parent group name."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
existing_group = await self._get_group(group_name)
|
||||||
|
if existing_group:
|
||||||
|
raise InvalidGroupNameError("A group with this name already exists.")
|
||||||
|
|
||||||
|
group = Group(
|
||||||
|
name=group_name,
|
||||||
|
description=description,
|
||||||
|
parent_id=parent_id,
|
||||||
|
)
|
||||||
|
self.session.add(group)
|
||||||
|
# We don't audit-log this operation.
|
||||||
|
await self.session.commit()
|
||||||
|
|
||||||
|
async def set_group_description(self, path: str, description: str) -> None:
|
||||||
|
"""Set group description."""
|
||||||
|
elements = parse_path(path)
|
||||||
|
group = await self._get_group(elements.item, elements.parent, True)
|
||||||
|
if not group:
|
||||||
|
raise InvalidGroupNameError("Invalid or non-existing group name.")
|
||||||
|
group.description = description
|
||||||
|
self.session.add(group)
|
||||||
|
await self.session.commit()
|
||||||
|
|
||||||
|
async def set_secret_group(self, entry_name: str, group_name: str | None) -> None:
|
||||||
|
"""Move a secret to a group.
|
||||||
|
|
||||||
|
If group_name is None, the secret will be moved out of any group it may exist in.
|
||||||
|
"""
|
||||||
|
entry = await self._get_entry(entry_name)
|
||||||
|
if not entry:
|
||||||
|
raise InvalidSecretNameError("Invalid or non-existing secret.")
|
||||||
|
if group_name:
|
||||||
|
elements = parse_path(group_name)
|
||||||
|
group = await self._get_group(elements.item, elements.parent, True)
|
||||||
|
if not group:
|
||||||
|
raise InvalidGroupNameError("Invalid or non-existing group name.")
|
||||||
|
entry.group_id = group.id
|
||||||
|
else:
|
||||||
|
entry.group_id = None
|
||||||
|
|
||||||
|
self.session.add(entry)
|
||||||
|
await self.session.commit()
|
||||||
|
await self.write_audit(
|
||||||
|
Operation.UPDATE,
|
||||||
|
"Secret group updated",
|
||||||
|
group_name=group_name or "ROOT",
|
||||||
|
secret_name=entry_name,
|
||||||
|
managed_secret=entry,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def move_group(self, path: str, parent_group: str | None) -> None:
|
||||||
|
"""Move group.
|
||||||
|
|
||||||
|
If parent_group is None, it will be moved to the root.
|
||||||
|
"""
|
||||||
|
elements = parse_path(path)
|
||||||
|
group = await self._get_group(elements.item, elements.parent, True)
|
||||||
|
if not group:
|
||||||
|
raise InvalidGroupNameError("Invalid or non-existing group name.")
|
||||||
|
|
||||||
|
parent_group_id: uuid.UUID | None = None
|
||||||
|
if parent_group:
|
||||||
|
db_parent_group = await self._get_group(parent_group)
|
||||||
|
if not db_parent_group:
|
||||||
|
raise InvalidGroupNameError("Invalid or non-existing parent group.")
|
||||||
|
parent_group_id = db_parent_group.id
|
||||||
|
|
||||||
|
group.parent_id = parent_group_id
|
||||||
|
|
||||||
|
self.session.add(group)
|
||||||
|
await self.session.commit()
|
||||||
|
|
||||||
|
async def delete_group(self, path: str) -> None:
|
||||||
|
"""Delete a group."""
|
||||||
|
elements = parse_path(path)
|
||||||
|
group = await self._get_group(elements.item, elements.parent, True)
|
||||||
|
if not group:
|
||||||
|
return
|
||||||
|
await self.session.delete(group)
|
||||||
|
|
||||||
|
await self.session.commit()
|
||||||
|
# We don't audit-log this operation currently, even though it indirectly
|
||||||
|
# may affect secrets.
|
||||||
|
|
||||||
|
async def _export_entries(self) -> list[SecretDataEntryExport]:
|
||||||
|
"""Export entries as a pydantic object."""
|
||||||
|
statement = (
|
||||||
|
select(ManagedSecret)
|
||||||
|
.options(selectinload(ManagedSecret.group))
|
||||||
|
.where(ManagedSecret.is_deleted.is_(False))
|
||||||
|
)
|
||||||
|
results = await self.session.scalars(statement)
|
||||||
|
entries: list[SecretDataEntryExport] = []
|
||||||
|
for entry in results.all():
|
||||||
|
group: str | None = None
|
||||||
|
if entry.group:
|
||||||
|
group = await self._get_group_path(entry.group)
|
||||||
|
secret = await self.get_secret(entry.name)
|
||||||
|
if not secret:
|
||||||
|
continue
|
||||||
|
data = SecretDataEntryExport(name=entry.name, secret=secret, group=group)
|
||||||
|
entries.append(data)
|
||||||
|
return entries
|
||||||
|
|
||||||
|
async def _export_groups(self) -> list[SecretDataGroupExport]:
|
||||||
|
"""Export groups as pydantic objects."""
|
||||||
|
groups = await self.get_secret_group_list()
|
||||||
|
entries = [
|
||||||
|
SecretDataGroupExport(
|
||||||
|
name=group.name,
|
||||||
|
path=group.path,
|
||||||
|
description=group.description,
|
||||||
|
)
|
||||||
|
for group in groups
|
||||||
|
]
|
||||||
|
return entries
|
||||||
|
|
||||||
|
async def export_secrets(self) -> SecretDataExport:
|
||||||
|
"""Export the managed secrets as a pydantic object."""
|
||||||
|
entries = await self._export_entries()
|
||||||
|
groups = await self._export_groups()
|
||||||
|
return SecretDataExport(entries=entries, groups=groups)
|
||||||
|
|
||||||
|
async def export_secrets_json(self) -> str:
|
||||||
|
"""Export secrets as JSON."""
|
||||||
|
export = await self.export_secrets()
|
||||||
|
return export.model_dump_json(indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def get_managed_private_key(
|
||||||
|
settings: AdminServerSettings,
|
||||||
|
filename: str = KEY_FILENAME,
|
||||||
|
regenerate: bool = False,
|
||||||
|
) -> rsa.RSAPrivateKey:
|
||||||
|
"""Load our private key."""
|
||||||
|
keyfile = Path(filename)
|
||||||
|
if settings.password_manager_directory:
|
||||||
|
keyfile = settings.password_manager_directory / filename
|
||||||
|
if not keyfile.exists():
|
||||||
|
_initial_key_setup(settings, keyfile)
|
||||||
|
setup_password_manager(settings, keyfile, regenerate)
|
||||||
|
return load_private_key(str(keyfile.absolute()), password=settings.secret_key)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_password_manager(
|
||||||
|
settings: AdminServerSettings, filename: Path, regenerate: bool = False
|
||||||
|
) -> bool:
|
||||||
|
"""Setup password manager."""
|
||||||
|
if filename.exists() and not regenerate:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not settings.secret_key:
|
||||||
|
raise RuntimeError("Error: Could not load secret key from environment.")
|
||||||
|
create_private_rsa_key(filename, password=settings.secret_key)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def create_manager_client(
|
||||||
|
backend: SshecretBackend, public_key: rsa.RSAPublicKey
|
||||||
|
) -> Client:
|
||||||
|
"""Create the manager client."""
|
||||||
|
public_key_string = generate_public_key_string(public_key)
|
||||||
|
new_client = await backend.create_system_client(
|
||||||
|
"AdminPasswordManager",
|
||||||
|
public_key_string,
|
||||||
|
)
|
||||||
|
return new_client
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def password_manager_context(
|
||||||
|
settings: AdminServerSettings, username: str, origin: str
|
||||||
|
) -> AsyncIterator[AsyncSecretContext]:
|
||||||
|
"""Start a context for the password manager."""
|
||||||
|
audit_context_data = ClientAuditData(username=username, origin=origin)
|
||||||
|
session_manager = DatabaseSessionManager(settings.async_db_url)
|
||||||
|
backend = SshecretBackend(str(settings.backend_url), settings.backend_token)
|
||||||
|
private_key = get_managed_private_key(settings)
|
||||||
|
async with session_manager.session() as session:
|
||||||
|
# Check if there is a client_id stored already.
|
||||||
|
query = select(PasswordDB).where(PasswordDB.id == 1)
|
||||||
|
result = await session.scalars(query)
|
||||||
|
password_db = result.first()
|
||||||
|
if not password_db:
|
||||||
|
password_db = PasswordDB(id=1)
|
||||||
|
session.add(password_db)
|
||||||
|
await session.flush()
|
||||||
|
if not password_db.client_id:
|
||||||
|
manager_client = await create_manager_client(
|
||||||
|
backend, private_key.public_key()
|
||||||
|
)
|
||||||
|
password_db.client_id = manager_client.id
|
||||||
|
session.add(password_db)
|
||||||
|
await session.commit()
|
||||||
|
else:
|
||||||
|
manager_client = await backend.get_client(
|
||||||
|
("id", str(password_db.client_id))
|
||||||
|
)
|
||||||
|
if not manager_client:
|
||||||
|
raise SecretManagerError("Error: Could not fetch system client.")
|
||||||
|
|
||||||
|
context = AsyncSecretContext(
|
||||||
|
private_key, manager_client, session, backend, audit_context_data
|
||||||
|
)
|
||||||
|
yield context
|
||||||
|
|
||||||
|
|
||||||
|
def setup_private_key(
|
||||||
|
settings: AdminServerSettings,
|
||||||
|
filename: str = KEY_FILENAME,
|
||||||
|
regenerate: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Setup secret manager private key."""
|
||||||
|
keyfile = Path(filename)
|
||||||
|
if settings.password_manager_directory:
|
||||||
|
keyfile = settings.password_manager_directory / filename
|
||||||
|
_initial_key_setup(settings, keyfile, regenerate)
|
||||||
|
|
||||||
|
|
||||||
|
def _initial_key_setup(
|
||||||
|
settings: AdminServerSettings,
|
||||||
|
keyfile: Path,
|
||||||
|
regenerate: bool = False,
|
||||||
|
) -> bool:
|
||||||
|
"""Set up initial keys."""
|
||||||
|
if keyfile.exists() and not regenerate:
|
||||||
|
return False
|
||||||
|
|
||||||
|
assert (
|
||||||
|
settings.secret_key is not None
|
||||||
|
), "Error: Could not load a secret key from environment."
|
||||||
|
create_private_rsa_key(keyfile, password=settings.secret_key)
|
||||||
|
return True
|
||||||
@ -0,0 +1,33 @@
|
|||||||
|
"""Implement db structures for internal password manager
|
||||||
|
|
||||||
|
Revision ID: 1657c5d25d2c
|
||||||
|
Revises: b4e135ff347a
|
||||||
|
Create Date: 2025-06-21 07:22:17.792528
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '1657c5d25d2c'
|
||||||
|
down_revision: Union[str, None] = 'b4e135ff347a'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.add_column('client', sa.Column('is_system', sa.Boolean(), nullable=False, default=False, server_default="0"))
|
||||||
|
op.add_column('client_secret', sa.Column('is_system', sa.Boolean(), nullable=False, default=False, server_default="0"))
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_column('client_secret', 'is_system')
|
||||||
|
op.drop_column('client', 'is_system')
|
||||||
@ -0,0 +1,44 @@
|
|||||||
|
"""Remove secret key from password database
|
||||||
|
|
||||||
|
Revision ID: 71f7272a6ee1
|
||||||
|
Revises: 1657c5d25d2c
|
||||||
|
Create Date: 2025-06-22 18:42:53.207334
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '71f7272a6ee1'
|
||||||
|
down_revision: Union[str, None] = '1657c5d25d2c'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_table('managed_secret')
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('managed_secret',
|
||||||
|
sa.Column('id', sa.CHAR(length=32), nullable=False),
|
||||||
|
sa.Column('name', sa.VARCHAR(), nullable=False),
|
||||||
|
sa.Column('description', sa.VARCHAR(), nullable=True),
|
||||||
|
sa.Column('secret', sa.VARCHAR(), nullable=False),
|
||||||
|
sa.Column('client_id', sa.CHAR(length=32), nullable=True),
|
||||||
|
sa.Column('deleted', sa.BOOLEAN(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DATETIME(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DATETIME(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
|
||||||
|
sa.Column('deleted_at', sa.DATETIME(), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(['client_id'], ['client.id'], ondelete='CASCADE'),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
@ -107,8 +107,7 @@ class ClientOperations:
|
|||||||
return ClientView.from_client(db_client)
|
return ClientView.from_client(db_client)
|
||||||
|
|
||||||
async def create_client(
|
async def create_client(
|
||||||
self,
|
self, create_model: ClientCreate, system_client: bool = False
|
||||||
create_model: ClientCreate,
|
|
||||||
) -> ClientView:
|
) -> ClientView:
|
||||||
"""Create a new client."""
|
"""Create a new client."""
|
||||||
existing_id = await self.get_client_id(FlexID.name(create_model.name))
|
existing_id = await self.get_client_id(FlexID.name(create_model.name))
|
||||||
@ -117,6 +116,15 @@ class ClientOperations:
|
|||||||
status_code=400, detail="Error: A client already exists with this name."
|
status_code=400, detail="Error: A client already exists with this name."
|
||||||
)
|
)
|
||||||
client = create_model.to_client()
|
client = create_model.to_client()
|
||||||
|
if system_client:
|
||||||
|
statement = query_active_clients().where(Client.is_system.is_(True))
|
||||||
|
results = await self.session.scalars(statement)
|
||||||
|
other_system_clients = results.all()
|
||||||
|
if other_system_clients:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail="Only one system client may exist"
|
||||||
|
)
|
||||||
|
client.is_system = True
|
||||||
self.session.add(client)
|
self.session.add(client)
|
||||||
await self.session.flush()
|
await self.session.flush()
|
||||||
await self.session.commit()
|
await self.session.commit()
|
||||||
@ -246,6 +254,15 @@ class ClientOperations:
|
|||||||
|
|
||||||
return ClientPolicyView.from_client(db_client)
|
return ClientPolicyView.from_client(db_client)
|
||||||
|
|
||||||
|
async def get_system_client(self) -> ClientView:
|
||||||
|
"""Get the system client, if it exists."""
|
||||||
|
statement = query_active_clients().where(Client.is_system.is_(True))
|
||||||
|
result = await self.session.scalars(statement)
|
||||||
|
client = result.first()
|
||||||
|
if not client:
|
||||||
|
raise HTTPException(status_code=404, detail="No system client registered")
|
||||||
|
return ClientView.from_client(client)
|
||||||
|
|
||||||
|
|
||||||
def resolve_order(statement: Select[Any], order_by: str, reversed: bool) -> Select[Any]:
|
def resolve_order(statement: Select[Any], order_by: str, reversed: bool) -> Select[Any]:
|
||||||
"""Resolve ordering."""
|
"""Resolve ordering."""
|
||||||
@ -267,6 +284,7 @@ def resolve_order(statement: Select[Any], order_by: str, reversed: bool) -> Sele
|
|||||||
LOG.warning("Unsupported order field: %s", order_by)
|
LOG.warning("Unsupported order field: %s", order_by)
|
||||||
return statement
|
return statement
|
||||||
|
|
||||||
|
|
||||||
def filter_client_statement(
|
def filter_client_statement(
|
||||||
statement: Select[Any], params: ClientListParams, ignore_limits: bool = False
|
statement: Select[Any], params: ClientListParams, ignore_limits: bool = False
|
||||||
) -> Select[Any]:
|
) -> Select[Any]:
|
||||||
@ -299,6 +317,7 @@ async def get_clients(
|
|||||||
.select_from(Client)
|
.select_from(Client)
|
||||||
.where(Client.is_deleted.is_not(True))
|
.where(Client.is_deleted.is_not(True))
|
||||||
.where(Client.is_active.is_not(False))
|
.where(Client.is_active.is_not(False))
|
||||||
|
.where(Client.is_system.is_not(True))
|
||||||
)
|
)
|
||||||
count_statement = cast(
|
count_statement = cast(
|
||||||
Select[tuple[int]],
|
Select[tuple[int]],
|
||||||
@ -307,7 +326,8 @@ async def get_clients(
|
|||||||
|
|
||||||
total_results = (await session.scalars(count_statement)).one()
|
total_results = (await session.scalars(count_statement)).one()
|
||||||
|
|
||||||
statement = filter_client_statement(query_active_clients(), filter_query, False)
|
statement = query_active_clients().where(Client.is_system.is_not(True))
|
||||||
|
statement = filter_client_statement(statement, filter_query, False)
|
||||||
|
|
||||||
results = await session.scalars(statement)
|
results = await session.scalars(statement)
|
||||||
remainder = total_results - filter_query.offset - filter_query.limit
|
remainder = total_results - filter_query.offset - filter_query.limit
|
||||||
|
|||||||
@ -46,6 +46,25 @@ def create_client_router(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
|||||||
client_op = ClientOperations(session, request)
|
client_op = ClientOperations(session, request)
|
||||||
return await client_op.create_client(client)
|
return await client_op.create_client(client)
|
||||||
|
|
||||||
|
@router.get("/internal/system_client/", include_in_schema=False)
|
||||||
|
async def get_system_client(
|
||||||
|
request: Request,
|
||||||
|
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||||
|
) -> ClientView:
|
||||||
|
"""Get the system client."""
|
||||||
|
client_op = ClientOperations(session, request)
|
||||||
|
return await client_op.get_system_client()
|
||||||
|
|
||||||
|
@router.post("/internal/system_client/", include_in_schema=False)
|
||||||
|
async def create_system_client(
|
||||||
|
request: Request,
|
||||||
|
client: ClientCreate,
|
||||||
|
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||||
|
) -> ClientView:
|
||||||
|
"""Create system client."""
|
||||||
|
client_op = ClientOperations(session, request)
|
||||||
|
return await client_op.create_client(client, system_client=True)
|
||||||
|
|
||||||
@router.get("/clients/{client_identifier}")
|
@router.get("/clients/{client_identifier}")
|
||||||
async def fetch_client_by_name(
|
async def fetch_client_by_name(
|
||||||
request: Request,
|
request: Request,
|
||||||
|
|||||||
@ -242,7 +242,7 @@ async def resolve_client_secret_clients(
|
|||||||
# Ensure we don't create the object before we have at least one client.
|
# Ensure we don't create the object before we have at least one client.
|
||||||
clients = ClientSecretDetailList(name=name)
|
clients = ClientSecretDetailList(name=name)
|
||||||
clients.ids.append(str(client_secret.id))
|
clients.ids.append(str(client_secret.id))
|
||||||
if client_secret.client:
|
if client_secret.client and not client_secret.client.is_system:
|
||||||
clients.clients.append(
|
clients.clients.append(
|
||||||
ClientReference(
|
ClientReference(
|
||||||
id=str(client_secret.client.id), name=client_secret.client.name
|
id=str(client_secret.client.id), name=client_secret.client.name
|
||||||
|
|||||||
@ -110,7 +110,6 @@ def get_async_engine(url: URL, echo: bool = False, **engine_kwargs: str) -> Asyn
|
|||||||
"""Get an async engine."""
|
"""Get an async engine."""
|
||||||
engine = create_async_engine(url, echo=echo, **engine_kwargs)
|
engine = create_async_engine(url, echo=echo, **engine_kwargs)
|
||||||
if url.drivername.startswith("sqlite+"):
|
if url.drivername.startswith("sqlite+"):
|
||||||
|
|
||||||
@event.listens_for(engine.sync_engine, "connect")
|
@event.listens_for(engine.sync_engine, "connect")
|
||||||
def set_sqlite_pragma(
|
def set_sqlite_pragma(
|
||||||
dbapi_connection: sqlite3.Connection, _connection_record: object
|
dbapi_connection: sqlite3.Connection, _connection_record: object
|
||||||
|
|||||||
@ -67,6 +67,7 @@ class Client(Base):
|
|||||||
|
|
||||||
is_active: Mapped[bool] = mapped_column(sa.Boolean, default=True)
|
is_active: Mapped[bool] = mapped_column(sa.Boolean, default=True)
|
||||||
is_deleted: Mapped[bool] = mapped_column(sa.Boolean, default=False)
|
is_deleted: Mapped[bool] = mapped_column(sa.Boolean, default=False)
|
||||||
|
is_system: Mapped[bool] = mapped_column(sa.Boolean, default=False)
|
||||||
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||||
@ -141,6 +142,8 @@ class ClientSecret(Base):
|
|||||||
client: Mapped[Client] = relationship(back_populates="secrets")
|
client: Mapped[Client] = relationship(back_populates="secrets")
|
||||||
deleted: Mapped[bool] = mapped_column(default=False)
|
deleted: Mapped[bool] = mapped_column(default=False)
|
||||||
|
|
||||||
|
is_system: Mapped[bool] = mapped_column(sa.Boolean, default=False)
|
||||||
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||||
)
|
)
|
||||||
|
|||||||
@ -140,11 +140,23 @@ class ShellStoreSecret(CommandDispatcher):
|
|||||||
secret=secret_name,
|
secret=secret_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await self.store_managed_secret(secret_name, secret_data)
|
||||||
|
|
||||||
def encrypt_secret(self, value: str) -> str:
|
def encrypt_secret(self, value: str) -> str:
|
||||||
"""Encrypt a secret."""
|
"""Encrypt a secret."""
|
||||||
public_key = load_public_key(self.client.public_key.encode())
|
public_key = load_public_key(self.client.public_key.encode())
|
||||||
return encrypt_string(value, public_key)
|
return encrypt_string(value, public_key)
|
||||||
|
|
||||||
|
async def store_managed_secret(self, secret_name: str, secret_data: str) -> None:
|
||||||
|
"""Store managed secret."""
|
||||||
|
system_client = await self.backend.get_system_client()
|
||||||
|
if not system_client:
|
||||||
|
return
|
||||||
|
public_key = load_public_key(system_client.public_key.encode())
|
||||||
|
encrypted = encrypt_string(secret_data, public_key)
|
||||||
|
await self.backend.create_client_secret(("id", str(system_client.id)), secret_name, encrypted)
|
||||||
|
await self.audit(operation=Operation.CREATE, message="Managed secret entry created.", secret=secret_name)
|
||||||
|
|
||||||
async def get_secret_on_stdin(self) -> str | None:
|
async def get_secret_on_stdin(self) -> str | None:
|
||||||
"""Get secret from stdin."""
|
"""Get secret from stdin."""
|
||||||
secret_data = ""
|
secret_data = ""
|
||||||
|
|||||||
@ -6,6 +6,7 @@ admin and sshd library do not need to implement the same
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Literal, Self, override
|
from typing import Any, Literal, Self, override
|
||||||
|
import uuid
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from pydantic import TypeAdapter
|
from pydantic import TypeAdapter
|
||||||
@ -325,6 +326,28 @@ class SshecretBackend(BaseBackend):
|
|||||||
path = "/api/v1/clients/"
|
path = "/api/v1/clients/"
|
||||||
response = await self._post(path, json=data)
|
response = await self._post(path, json=data)
|
||||||
|
|
||||||
|
async def create_system_client(self, name: str, public_key: str) -> Client:
|
||||||
|
"""Create system client."""
|
||||||
|
if not validate_public_key(public_key):
|
||||||
|
raise BackendValidationError("Error: Invalid public key format.")
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"name": name,
|
||||||
|
"public_key": public_key,
|
||||||
|
"description": "Internal system client",
|
||||||
|
}
|
||||||
|
path = "/api/v1/internal/system_client/"
|
||||||
|
response = await self._post(path, json=data)
|
||||||
|
return Client.model_validate(response.json())
|
||||||
|
|
||||||
|
async def get_system_client(self) -> Client | None:
|
||||||
|
"""Get the system client."""
|
||||||
|
path = "/api/v1/internal/system_client/"
|
||||||
|
response = await self._get(path)
|
||||||
|
if response.status_code == 404:
|
||||||
|
return None
|
||||||
|
return Client.model_validate(response.json())
|
||||||
|
|
||||||
async def get_clients(self, filter: ClientFilter | None = None) -> list[Client]:
|
async def get_clients(self, filter: ClientFilter | None = None) -> list[Client]:
|
||||||
"""Get all clients."""
|
"""Get all clients."""
|
||||||
clients: list[Client] = []
|
clients: list[Client] = []
|
||||||
@ -375,7 +398,7 @@ class SshecretBackend(BaseBackend):
|
|||||||
|
|
||||||
async def create_client_secret(
|
async def create_client_secret(
|
||||||
self, client_idname: KeySpec, secret_name: str, encrypted_secret: str
|
self, client_idname: KeySpec, secret_name: str, encrypted_secret: str
|
||||||
) -> None:
|
) -> ClientSecret:
|
||||||
"""Create a secret.
|
"""Create a secret.
|
||||||
|
|
||||||
This will overwrite any existing secret with that name.
|
This will overwrite any existing secret with that name.
|
||||||
@ -383,6 +406,8 @@ class SshecretBackend(BaseBackend):
|
|||||||
client_key = _key(client_idname)
|
client_key = _key(client_idname)
|
||||||
path = f"api/v1/clients/{client_key}/secrets/{secret_name}"
|
path = f"api/v1/clients/{client_key}/secrets/{secret_name}"
|
||||||
response = await self._put(path, json={"value": encrypted_secret})
|
response = await self._put(path, json={"value": encrypted_secret})
|
||||||
|
secret = ClientSecret.model_validate(response.json())
|
||||||
|
return secret
|
||||||
|
|
||||||
async def get_client_secret(
|
async def get_client_secret(
|
||||||
self, client_idname: KeySpec, secret_idname: KeySpec
|
self, client_idname: KeySpec, secret_idname: KeySpec
|
||||||
|
|||||||
@ -8,14 +8,14 @@ from pathlib import Path
|
|||||||
from sqlmodel import Session, create_engine
|
from sqlmodel import Session, create_engine
|
||||||
from sshecret.crypto import generate_private_key, write_private_key
|
from sshecret.crypto import generate_private_key, write_private_key
|
||||||
from sshecret_admin.auth.authentication import hash_password
|
from sshecret_admin.auth.authentication import hash_password
|
||||||
from sshecret_admin.auth.models import AuthProvider, User, init_db
|
from sshecret_admin.auth.models import AuthProvider, User, Base
|
||||||
from sshecret_admin.core.settings import AdminServerSettings
|
from sshecret_admin.core.settings import AdminServerSettings
|
||||||
|
|
||||||
def create_test_admin_user(settings: AdminServerSettings, username: str, password: str) -> None:
|
def create_test_admin_user(settings: AdminServerSettings, username: str, password: str) -> None:
|
||||||
"""Create a test admin user."""
|
"""Create a test admin user."""
|
||||||
hashed_password = hash_password(password)
|
hashed_password = hash_password(password)
|
||||||
engine = create_engine(settings.admin_db)
|
engine = create_engine(settings.admin_db)
|
||||||
init_db(engine)
|
Base.metadata.create_all(engine)
|
||||||
with Session(engine) as session:
|
with Session(engine) as session:
|
||||||
user = User(username=username, hashed_password=hashed_password, provider=AuthProvider.LOCAL, email="test@test.com")
|
user = User(username=username, hashed_password=hashed_password, provider=AuthProvider.LOCAL, email="test@test.com")
|
||||||
session.add(user)
|
session.add(user)
|
||||||
|
|||||||
1
tests/integration/admin/__init__.py
Normal file
1
tests/integration/admin/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
|
||||||
67
tests/integration/admin/base.py
Normal file
67
tests/integration/admin/base.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from sshecret.backend import Client
|
||||||
|
|
||||||
|
from sshecret.crypto import generate_private_key, generate_public_key_string
|
||||||
|
|
||||||
|
from ..types import AdminServer
|
||||||
|
|
||||||
|
|
||||||
|
def make_test_key() -> str:
|
||||||
|
"""Generate a test key."""
|
||||||
|
private_key = generate_private_key()
|
||||||
|
return generate_public_key_string(private_key.public_key())
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAdminTests:
|
||||||
|
"""Base admin test class."""
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def http_client(
|
||||||
|
self, admin_server: AdminServer, authenticate: bool = True
|
||||||
|
) -> AsyncIterator[httpx.AsyncClient]:
|
||||||
|
"""Run a client towards the admin rest api."""
|
||||||
|
admin_url, credentials = admin_server
|
||||||
|
username, password = credentials
|
||||||
|
headers: dict[str, str] | None = None
|
||||||
|
if authenticate:
|
||||||
|
async with httpx.AsyncClient(base_url=admin_url) as client:
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"api/v1/token", data={"username": username, "password": password}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "access_token" in data
|
||||||
|
token = data["access_token"]
|
||||||
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(base_url=admin_url, headers=headers) as client:
|
||||||
|
yield client
|
||||||
|
|
||||||
|
async def create_client(
|
||||||
|
self,
|
||||||
|
admin_server: AdminServer,
|
||||||
|
name: str,
|
||||||
|
public_key: str | None = None,
|
||||||
|
) -> Client:
|
||||||
|
"""Create a client."""
|
||||||
|
if not public_key:
|
||||||
|
public_key = make_test_key()
|
||||||
|
|
||||||
|
new_client = {
|
||||||
|
"name": name,
|
||||||
|
"public_key": public_key,
|
||||||
|
"sources": ["192.0.2.0/24"],
|
||||||
|
}
|
||||||
|
|
||||||
|
async with self.http_client(admin_server, True) as http_client:
|
||||||
|
response = await http_client.post("api/v1/clients/", json=new_client)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
client = Client.model_validate(data)
|
||||||
|
|
||||||
|
return client
|
||||||
@ -1,76 +1,12 @@
|
|||||||
"""Tests of the admin interface."""
|
"""Tests of the admin interface."""
|
||||||
|
|
||||||
from collections.abc import AsyncIterator
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
import allure
|
import allure
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
from allure_commons.types import Severity
|
from allure_commons.types import Severity
|
||||||
|
|
||||||
from sshecret.backend import Client
|
from ..types import AdminServer
|
||||||
|
from .base import BaseAdminTests
|
||||||
from sshecret.crypto import generate_private_key, generate_public_key_string
|
|
||||||
|
|
||||||
from .types import AdminServer
|
|
||||||
|
|
||||||
|
|
||||||
def make_test_key() -> str:
|
|
||||||
"""Generate a test key."""
|
|
||||||
private_key = generate_private_key()
|
|
||||||
return generate_public_key_string(private_key.public_key())
|
|
||||||
|
|
||||||
|
|
||||||
class BaseAdminTests:
|
|
||||||
"""Base admin test class."""
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def http_client(
|
|
||||||
self, admin_server: AdminServer, authenticate: bool = True
|
|
||||||
) -> AsyncIterator[httpx.AsyncClient]:
|
|
||||||
"""Run a client towards the admin rest api."""
|
|
||||||
admin_url, credentials = admin_server
|
|
||||||
username, password = credentials
|
|
||||||
headers: dict[str, str] | None = None
|
|
||||||
if authenticate:
|
|
||||||
async with httpx.AsyncClient(base_url=admin_url) as client:
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"api/v1/token", data={"username": username, "password": password}
|
|
||||||
)
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert "access_token" in data
|
|
||||||
token = data["access_token"]
|
|
||||||
headers = {"Authorization": f"Bearer {token}"}
|
|
||||||
|
|
||||||
async with httpx.AsyncClient(base_url=admin_url, headers=headers) as client:
|
|
||||||
yield client
|
|
||||||
|
|
||||||
async def create_client(
|
|
||||||
self,
|
|
||||||
admin_server: AdminServer,
|
|
||||||
name: str,
|
|
||||||
public_key: str | None = None,
|
|
||||||
) -> Client:
|
|
||||||
"""Create a client."""
|
|
||||||
if not public_key:
|
|
||||||
public_key = make_test_key()
|
|
||||||
|
|
||||||
new_client = {
|
|
||||||
"name": name,
|
|
||||||
"public_key": public_key,
|
|
||||||
"sources": ["192.0.2.0/24"],
|
|
||||||
}
|
|
||||||
|
|
||||||
async with self.http_client(admin_server, True) as http_client:
|
|
||||||
response = await http_client.post("api/v1/clients/", json=new_client)
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
client = Client.model_validate(data)
|
|
||||||
|
|
||||||
return client
|
|
||||||
|
|
||||||
|
|
||||||
@allure.title("Admin API")
|
@allure.title("Admin API")
|
||||||
@ -196,7 +132,9 @@ class TestAdminApiSecrets(BaseAdminTests):
|
|||||||
assert "testclient" in data["clients"]
|
assert "testclient" in data["clients"]
|
||||||
|
|
||||||
@allure.title("Test adding a secret with automatic value")
|
@allure.title("Test adding a secret with automatic value")
|
||||||
@allure.description("Test that we can add a secret where we let the system come up with the value of a given length.")
|
@allure.description(
|
||||||
|
"Test that we can add a secret where we let the system come up with the value of a given length."
|
||||||
|
)
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_add_secret_auto(self, admin_server: AdminServer) -> None:
|
async def test_add_secret_auto(self, admin_server: AdminServer) -> None:
|
||||||
"""Test adding a secret with an auto-generated value."""
|
"""Test adding a secret with an auto-generated value."""
|
||||||
634
tests/integration/admin/test_secret_manager.py
Normal file
634
tests/integration/admin/test_secret_manager.py
Normal file
@ -0,0 +1,634 @@
|
|||||||
|
"""Test secret manager.
|
||||||
|
|
||||||
|
This package tests the rewritten secret manager system.
|
||||||
|
|
||||||
|
This is technically an integration test, as it requires the other subsystems to
|
||||||
|
run, but it uses the internal API rather than the exposed routes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import allure
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from sshecret_admin.core.settings import AdminServerSettings
|
||||||
|
from sshecret_admin.services.models import SecretGroup
|
||||||
|
from sshecret_admin.services.secret_manager import (
|
||||||
|
password_manager_context,
|
||||||
|
AsyncSecretContext,
|
||||||
|
InvalidSecretNameError,
|
||||||
|
InvalidGroupNameError,
|
||||||
|
)
|
||||||
|
from sshecret_admin.auth.models import Base, PasswordDB
|
||||||
|
from sshecret_admin.services.master_password import setup_master_password
|
||||||
|
|
||||||
|
# -------- global parameter sets start here -------- #
|
||||||
|
|
||||||
|
|
||||||
|
# -------- Fixtures start here -------- #
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(autouse=True)
|
||||||
|
async def create_admin_db(admin_server_settings: AdminServerSettings) -> None:
|
||||||
|
"""Create the database."""
|
||||||
|
engine = create_engine(admin_server_settings.admin_db)
|
||||||
|
Base.metadata.create_all(engine)
|
||||||
|
encr_master_password = setup_master_password(
|
||||||
|
settings=admin_server_settings, regenerate=True
|
||||||
|
)
|
||||||
|
|
||||||
|
with Session(engine) as session:
|
||||||
|
pwdb = PasswordDB(id=1, encrypted_password=encr_master_password)
|
||||||
|
session.add(pwdb)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture()
|
||||||
|
async def secrets_manager(admin_server_settings: AdminServerSettings):
|
||||||
|
"""Test that the context manager can be created."""
|
||||||
|
async with password_manager_context(
|
||||||
|
admin_server_settings, "TEST", "127.0.0.1"
|
||||||
|
) as manager:
|
||||||
|
yield manager
|
||||||
|
|
||||||
|
|
||||||
|
# -------- Tests start here -------- #
|
||||||
|
|
||||||
|
|
||||||
|
@allure.title("Adding entries")
|
||||||
|
@pytest.mark.parametrize("name,secret", [("testentry", "testsecret")])
|
||||||
|
class TestSecretsAddEntry:
|
||||||
|
"""Tests for the add_entry method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_entry(
|
||||||
|
self,
|
||||||
|
secrets_manager: AsyncSecretContext,
|
||||||
|
name: str,
|
||||||
|
secret: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test add entry.
|
||||||
|
|
||||||
|
This tests add_entry and get_secret
|
||||||
|
"""
|
||||||
|
await secrets_manager.add_entry(name, secret)
|
||||||
|
stored_secret = await secrets_manager.get_secret(name)
|
||||||
|
assert stored_secret == secret
|
||||||
|
|
||||||
|
async def test_add_entry_duplicate(
|
||||||
|
self,
|
||||||
|
secrets_manager: AsyncSecretContext,
|
||||||
|
name: str,
|
||||||
|
secret: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test adding an entry twice."""
|
||||||
|
await secrets_manager.add_entry(name, secret)
|
||||||
|
stored_secret = await secrets_manager.get_secret(name)
|
||||||
|
assert stored_secret == secret
|
||||||
|
|
||||||
|
with pytest.raises(InvalidSecretNameError):
|
||||||
|
await secrets_manager.add_entry(name, secret)
|
||||||
|
|
||||||
|
async def test_add_entry_with_group(
|
||||||
|
self,
|
||||||
|
secrets_manager: AsyncSecretContext,
|
||||||
|
name: str,
|
||||||
|
secret: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test adding a secret with a group."""
|
||||||
|
group = "testgroup"
|
||||||
|
await secrets_manager.add_group(group)
|
||||||
|
await secrets_manager.add_entry(name, secret, group_path=group)
|
||||||
|
result = await secrets_manager.get_entry_group(name)
|
||||||
|
assert result == group
|
||||||
|
|
||||||
|
async def test_add_entry_with_nonexisting_group(
|
||||||
|
self,
|
||||||
|
secrets_manager: AsyncSecretContext,
|
||||||
|
name: str,
|
||||||
|
secret: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test adding a secret where the group does not exist."""
|
||||||
|
group = "testgroup"
|
||||||
|
with pytest.raises(InvalidGroupNameError):
|
||||||
|
await secrets_manager.add_entry(name, secret, group_path=group)
|
||||||
|
|
||||||
|
async def test_add_entry_with_deep_path(
|
||||||
|
self,
|
||||||
|
secrets_manager: AsyncSecretContext,
|
||||||
|
name: str,
|
||||||
|
secret: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test adding a secret to a nested group with a path specification."""
|
||||||
|
await secrets_manager.add_group("root")
|
||||||
|
await secrets_manager.add_group("nested", parent_group="root")
|
||||||
|
await secrets_manager.add_entry(name, secret, group_path="/root/nested")
|
||||||
|
group = await secrets_manager.get_secret_group("/root/nested")
|
||||||
|
assert group is not None
|
||||||
|
assert name in group.entries
|
||||||
|
|
||||||
|
async def test_overwrite_secret(
|
||||||
|
self, secrets_manager: AsyncSecretContext, name: str, secret: str
|
||||||
|
) -> None:
|
||||||
|
"""Test overwriting a secret."""
|
||||||
|
await secrets_manager.add_entry(name, secret)
|
||||||
|
stored_secret = await secrets_manager.get_secret(name)
|
||||||
|
assert stored_secret == secret
|
||||||
|
|
||||||
|
new_secret = "newsecret"
|
||||||
|
await secrets_manager.add_entry(name, new_secret, overwrite=True)
|
||||||
|
|
||||||
|
stored_secret = await secrets_manager.get_secret(name)
|
||||||
|
assert stored_secret == new_secret
|
||||||
|
|
||||||
|
|
||||||
|
@allure.title("Creating groups")
|
||||||
|
class TestSecretGroupCreation:
|
||||||
|
"""Test secret groups."""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"group_name", ["testgroup", "long group name with spaces", "blåbærgrød"]
|
||||||
|
)
|
||||||
|
@allure.title("Add a group name {group_name}")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_group(
|
||||||
|
self,
|
||||||
|
secrets_manager: AsyncSecretContext,
|
||||||
|
group_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""Get adding a group."""
|
||||||
|
await secrets_manager.add_group(group_name)
|
||||||
|
groups = await secrets_manager.get_secret_groups()
|
||||||
|
assert len(groups) == 1
|
||||||
|
assert groups[0].name == group_name
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_group_with_parent(
|
||||||
|
self, secrets_manager: AsyncSecretContext
|
||||||
|
) -> None:
|
||||||
|
"""Test add a group with a parent group."""
|
||||||
|
parent_name = "parent"
|
||||||
|
child_name = "child"
|
||||||
|
await secrets_manager.add_group(parent_name)
|
||||||
|
await secrets_manager.add_group(child_name, parent_group=parent_name)
|
||||||
|
parent_group = await secrets_manager.get_secret_group(f"/parent")
|
||||||
|
assert parent_group is not None
|
||||||
|
assert len(parent_group.children) == 1
|
||||||
|
assert parent_group.children[0].name == child_name
|
||||||
|
|
||||||
|
child_group = await secrets_manager.get_secret_group("/parent/child")
|
||||||
|
assert child_group is not None
|
||||||
|
assert child_group.name == child_name
|
||||||
|
assert child_group.parent_group is not None
|
||||||
|
assert child_group.parent_group.name == parent_name
|
||||||
|
assert len(child_group.children) == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_group_as_path(self, secrets_manager: AsyncSecretContext) -> None:
|
||||||
|
"""Add a nested group with path annotation."""
|
||||||
|
parent_name = "parent"
|
||||||
|
child_path = "/parent/child"
|
||||||
|
await secrets_manager.add_group(parent_name)
|
||||||
|
await secrets_manager.add_group(child_path)
|
||||||
|
parent_group = await secrets_manager.get_secret_group(f"/parent")
|
||||||
|
assert parent_group is not None
|
||||||
|
assert len(parent_group.children) == 1
|
||||||
|
assert parent_group.children[0].name == "child"
|
||||||
|
|
||||||
|
child_group = await secrets_manager.get_secret_group(child_path)
|
||||||
|
assert child_group is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_overlapping_names(self, secrets_manager: AsyncSecretContext) -> None:
|
||||||
|
"""Test having overlapping names in different groups."""
|
||||||
|
await secrets_manager.add_group("root")
|
||||||
|
with pytest.raises(InvalidGroupNameError):
|
||||||
|
await secrets_manager.add_group("/root")
|
||||||
|
|
||||||
|
await secrets_manager.add_group("/root/root")
|
||||||
|
|
||||||
|
group = await secrets_manager.get_secret_group("/root/root")
|
||||||
|
assert group is not None
|
||||||
|
assert group.name == "root"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_group_with_nonexisting_parent(
|
||||||
|
self, secrets_manager: AsyncSecretContext
|
||||||
|
) -> None:
|
||||||
|
"""Test adding a group with a nonexisting parent."""
|
||||||
|
with pytest.raises(InvalidGroupNameError):
|
||||||
|
await secrets_manager.add_group("orphan", parent_group="unknown")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_duplicate_group(
|
||||||
|
self, secrets_manager: AsyncSecretContext
|
||||||
|
) -> None:
|
||||||
|
"""Test adding the same group twice."""
|
||||||
|
await secrets_manager.add_group("snowflake")
|
||||||
|
with pytest.raises(InvalidGroupNameError):
|
||||||
|
await secrets_manager.add_group("snowflake")
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"group_name,description", [("testgroup", "test description")]
|
||||||
|
)
|
||||||
|
@allure.title("Add a group name {group_name} with description {description}")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_group_with_description(
|
||||||
|
self,
|
||||||
|
secrets_manager: AsyncSecretContext,
|
||||||
|
group_name: str,
|
||||||
|
description: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test adding a group with description."""
|
||||||
|
await secrets_manager.add_group(group_name, description)
|
||||||
|
result = await secrets_manager.get_secret_group(group_name)
|
||||||
|
assert result is not None
|
||||||
|
assert result.description == description
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"groups",
|
||||||
|
[
|
||||||
|
[
|
||||||
|
("root", None, "root"),
|
||||||
|
("level1", "root", "/root/level1"),
|
||||||
|
("level2", "level1", "/root/level1/level2"),
|
||||||
|
],
|
||||||
|
[("flat1", None, "flat1"), ("flat2", None, "flat2")],
|
||||||
|
[
|
||||||
|
("stub", None, "stub"),
|
||||||
|
("root", None, "root"),
|
||||||
|
("nested", "root", "/root/nested"),
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@allure.title("Listing groups")
|
||||||
|
class TestSecretGroupListing:
|
||||||
|
"""Tests for listing groups."""
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(autouse=True)
|
||||||
|
async def create_groups(
|
||||||
|
self,
|
||||||
|
secrets_manager: AsyncSecretContext,
|
||||||
|
groups: list[tuple[str, str | None, str]],
|
||||||
|
) -> None:
|
||||||
|
"""Pre-create groups."""
|
||||||
|
for name, parent_group, _path in groups:
|
||||||
|
await secrets_manager.add_group(name, parent_group=parent_group)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_secret_groups_list(
|
||||||
|
self,
|
||||||
|
secrets_manager: AsyncSecretContext,
|
||||||
|
groups: list[tuple[str, str | None, str]],
|
||||||
|
) -> None:
|
||||||
|
"""Test the flat get_secret_groups_list."""
|
||||||
|
# Create three levels of content
|
||||||
|
group_list = await secrets_manager.get_secret_group_list()
|
||||||
|
assert len(group_list) == len(groups)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_secret_groups(
|
||||||
|
self,
|
||||||
|
secrets_manager: AsyncSecretContext,
|
||||||
|
groups: list[tuple[str, str | None, str]],
|
||||||
|
) -> None:
|
||||||
|
"""Test the tree-oriented get_secretsgroups."""
|
||||||
|
group_map = dict([(group[0], group[1]) for group in sorted(groups)])
|
||||||
|
group_tree = await secrets_manager.get_secret_groups()
|
||||||
|
root_groups = [key for key, value in group_map.items() if value is None]
|
||||||
|
assert len(group_tree) == len(root_groups)
|
||||||
|
reconstructed_groups: list[tuple[str, str | None]] = []
|
||||||
|
|
||||||
|
def crawl_tree(item: SecretGroup) -> list[tuple[str, str | None]]:
|
||||||
|
"""Crawl a tree recursively."""
|
||||||
|
parent_group_name = None
|
||||||
|
if item.parent_group:
|
||||||
|
parent_group_name = item.parent_group.name
|
||||||
|
items: list[tuple[str, str | None]] = [(item.name, parent_group_name)]
|
||||||
|
for child in item.children:
|
||||||
|
items.extend(crawl_tree(child))
|
||||||
|
|
||||||
|
return items
|
||||||
|
|
||||||
|
for item in group_tree:
|
||||||
|
reconstructed_groups.extend(crawl_tree(item))
|
||||||
|
|
||||||
|
assert dict(sorted(reconstructed_groups)) == group_map
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_secret_groups_with_secrets(
|
||||||
|
self,
|
||||||
|
secrets_manager: AsyncSecretContext,
|
||||||
|
groups: list[tuple[str, str | None, str]],
|
||||||
|
) -> None:
|
||||||
|
"""Test fetching groups where there are secrets in all groups."""
|
||||||
|
# We will create exactly two secrets in each group.
|
||||||
|
for group_name, _parent, path in groups:
|
||||||
|
await secrets_manager.add_entry(
|
||||||
|
f"{group_name}_1", f"{group_name}_secret_1", group_path=path
|
||||||
|
)
|
||||||
|
await secrets_manager.add_entry(
|
||||||
|
f"{group_name}_2", f"{group_name}_secret_2", group_path=path
|
||||||
|
)
|
||||||
|
|
||||||
|
group_list = await secrets_manager.get_secret_group_list()
|
||||||
|
for group in group_list:
|
||||||
|
assert len(group.entries) == 2
|
||||||
|
assert group.entries[0] == f"{group.name}_1"
|
||||||
|
assert group.entries[1] == f"{group.name}_2"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"query,expected,groups,children",
|
||||||
|
[
|
||||||
|
("MATCH", 1, [("MATCH", None), ("SOMETHINGELSE", None)], 0),
|
||||||
|
("MATCH", 1, [("root", None), ("MATCH", "root"), ("SOMETHINGELSE", "root")], 0),
|
||||||
|
("MATCH", 3, [("MATCH1", None), ("MATCH2", None), ("MATCH3", None)], 0),
|
||||||
|
("MATCH", 1, [("root", None), ("MATCH", "root"), ("CHILD", "MATCH")], 1),
|
||||||
|
(
|
||||||
|
"NOMATCH",
|
||||||
|
0,
|
||||||
|
[("foo", None), ("bar", None), ("foobar", "foo"), ("barfoo", "bar")],
|
||||||
|
0,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@allure.title("Searching in groups using patterns")
|
||||||
|
class TestGroupSearchPattern:
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_group_list_pattern(
|
||||||
|
self,
|
||||||
|
secrets_manager: AsyncSecretContext,
|
||||||
|
query: str,
|
||||||
|
expected: int,
|
||||||
|
groups: list[tuple[str, str | None]],
|
||||||
|
children: int,
|
||||||
|
) -> None:
|
||||||
|
"""Test matching a pattern."""
|
||||||
|
for name, parent_group in groups:
|
||||||
|
await secrets_manager.add_group(name, parent_group=parent_group)
|
||||||
|
|
||||||
|
result = await secrets_manager.get_secret_group_list(pattern=query, regex=False)
|
||||||
|
assert len(result) == expected
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_group_tree_pattern(
|
||||||
|
self,
|
||||||
|
secrets_manager: AsyncSecretContext,
|
||||||
|
query: str,
|
||||||
|
expected: int,
|
||||||
|
groups: list[tuple[str, str | None]],
|
||||||
|
children: int,
|
||||||
|
) -> None:
|
||||||
|
"""Test matching a pattern with a tree result."""
|
||||||
|
for name, parent_group in groups:
|
||||||
|
await secrets_manager.add_group(name, parent_group=parent_group)
|
||||||
|
|
||||||
|
result = await secrets_manager.get_secret_groups(pattern=query, regex=False)
|
||||||
|
assert len(result) == expected
|
||||||
|
if expected == 1 and children > 0:
|
||||||
|
assert len(result[0].children) == children
|
||||||
|
|
||||||
|
|
||||||
|
@allure.title("Modifying groups")
|
||||||
|
class TestGroupModification:
|
||||||
|
"""Test modifying groups."""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"group_name,parent,description",
|
||||||
|
[("test", None, "test description"), ("test", "root", "test_description")],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_group_description(
|
||||||
|
self,
|
||||||
|
secrets_manager: AsyncSecretContext,
|
||||||
|
group_name: str,
|
||||||
|
parent: str | None,
|
||||||
|
description: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test setting a description on a group."""
|
||||||
|
if parent:
|
||||||
|
await secrets_manager.add_group(parent)
|
||||||
|
await secrets_manager.add_group(group_name, parent_group=parent)
|
||||||
|
path = group_name
|
||||||
|
if parent:
|
||||||
|
path = f"/{parent}/{group_name}"
|
||||||
|
group = await secrets_manager.get_secret_group(path)
|
||||||
|
assert group is not None
|
||||||
|
assert group.description is None
|
||||||
|
|
||||||
|
await secrets_manager.set_group_description(path, description)
|
||||||
|
group = await secrets_manager.get_secret_group(path)
|
||||||
|
assert group is not None
|
||||||
|
assert group.description == description
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"groups,target_group, expected_path",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
[("root", None), ("test", None)],
|
||||||
|
("test", "root"),
|
||||||
|
"/root/test",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
[("root", None), ("test", "root")],
|
||||||
|
("/root/test", None),
|
||||||
|
"test",
|
||||||
|
),
|
||||||
|
([("test", None)], ("test", None), "test"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_move_group(
|
||||||
|
self,
|
||||||
|
secrets_manager: AsyncSecretContext,
|
||||||
|
groups: list[tuple[str, str | None]],
|
||||||
|
target_group: tuple[str, str | None],
|
||||||
|
expected_path: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test moving groups around."""
|
||||||
|
for group_name, parent_name in groups:
|
||||||
|
await secrets_manager.add_group(group_name, parent_group=parent_name)
|
||||||
|
|
||||||
|
group_name, target = target_group
|
||||||
|
await secrets_manager.move_group(group_name, target)
|
||||||
|
group = await secrets_manager.get_secret_group(expected_path)
|
||||||
|
assert group is not None
|
||||||
|
|
||||||
|
|
||||||
|
@allure.title("Deleting items")
|
||||||
|
class TestSecretManagerDeletions:
|
||||||
|
"""Test secret manager deletions."""
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(autouse=True)
|
||||||
|
async def create_test_data(self, secrets_manager: AsyncSecretContext) -> None:
|
||||||
|
"""Create some test data."""
|
||||||
|
groups = [
|
||||||
|
("root", None, "root"),
|
||||||
|
("level1", "root", "/root/level1"),
|
||||||
|
("level2", "level1", "/root/level1/level2"),
|
||||||
|
]
|
||||||
|
for n in range(2):
|
||||||
|
await secrets_manager.add_entry(f"ungrouped_{n}", "secret")
|
||||||
|
for group_name, parent_name, path in groups:
|
||||||
|
await secrets_manager.add_group(group_name, parent_group=parent_name)
|
||||||
|
for n in range(2):
|
||||||
|
await secrets_manager.add_entry(
|
||||||
|
f"{group_name}_{n}", "secret", group_path=path
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"name,group_name",
|
||||||
|
[("root_1", "root"), ("level1_0", "/root/level1"), ("ungrouped_1", None)],
|
||||||
|
)
|
||||||
|
@allure.title("Delete secret {name in group {group_name}}")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_secret_deletion(
|
||||||
|
self, secrets_manager: AsyncSecretContext, name: str, group_name: str | None
|
||||||
|
) -> None:
|
||||||
|
"""Test secret deletion."""
|
||||||
|
if group_name:
|
||||||
|
group = await secrets_manager.get_secret_group(group_name)
|
||||||
|
assert group is not None
|
||||||
|
assert name in group.entries
|
||||||
|
|
||||||
|
secret = await secrets_manager.get_secret(name)
|
||||||
|
assert secret is not None
|
||||||
|
|
||||||
|
await secrets_manager.delete_entry(name)
|
||||||
|
|
||||||
|
secret = await secrets_manager.get_secret(name)
|
||||||
|
assert secret is None
|
||||||
|
|
||||||
|
if group_name:
|
||||||
|
group = await secrets_manager.get_secret_group(group_name)
|
||||||
|
assert group is not None
|
||||||
|
assert name not in group.entries
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("name", ["NONEXISTING"])
|
||||||
|
@allure.title("Deleting non-existing entry {name}")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_nonexisting_entry(
|
||||||
|
self, secrets_manager: AsyncSecretContext, name: str
|
||||||
|
) -> None:
|
||||||
|
"""Test deleting something that doesn't exist."""
|
||||||
|
secret = await secrets_manager.get_secret(name)
|
||||||
|
assert secret is None
|
||||||
|
# Deleting something that is already deleted returns None
|
||||||
|
await secrets_manager.delete_entry(name)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("path", ["/root/level1"])
|
||||||
|
@allure.title("Deleting group {path}")
|
||||||
|
async def test_group_delete(
|
||||||
|
self, secrets_manager: AsyncSecretContext, path: str
|
||||||
|
) -> None:
|
||||||
|
"""Test deleting a group."""
|
||||||
|
group = await secrets_manager.get_secret_group(path)
|
||||||
|
assert group is not None
|
||||||
|
entries = list(group.entries)
|
||||||
|
await secrets_manager.delete_group(path)
|
||||||
|
group = await secrets_manager.get_secret_group(path)
|
||||||
|
assert group is None
|
||||||
|
|
||||||
|
for name in entries:
|
||||||
|
new_grouping = await secrets_manager.get_entry_group(name)
|
||||||
|
assert new_grouping is None
|
||||||
|
|
||||||
|
|
||||||
|
@allure.title("Other tests")
|
||||||
|
class TestSecretManagerOther:
|
||||||
|
"""Uncategorized tests to standardize module."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_secret_nonexisting(
|
||||||
|
self, secrets_manager: AsyncSecretContext
|
||||||
|
) -> None:
|
||||||
|
"""Test get_secret with invalid name."""
|
||||||
|
result = await secrets_manager.get_secret("NOMATCH")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"group_name,num_grouped,num_ungrouped",
|
||||||
|
[("GROUP", 3, 3), ("GROUP", 3, 0), ("GROUP", 0, 0)],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_ungrouped_secrets(
|
||||||
|
self,
|
||||||
|
secrets_manager: AsyncSecretContext,
|
||||||
|
group_name: str,
|
||||||
|
num_grouped: int,
|
||||||
|
num_ungrouped: int,
|
||||||
|
) -> None:
|
||||||
|
"""Test get_ungrouped_secrets."""
|
||||||
|
await secrets_manager.add_group(group_name)
|
||||||
|
for n in range(num_ungrouped):
|
||||||
|
await secrets_manager.add_entry(f"ungrouped_{n}", "secret")
|
||||||
|
|
||||||
|
for n in range(num_grouped):
|
||||||
|
await secrets_manager.add_entry(
|
||||||
|
f"grouped_{n}", "secret", group_path="GROUP"
|
||||||
|
)
|
||||||
|
|
||||||
|
ungrouped = await secrets_manager.get_ungrouped_secrets()
|
||||||
|
assert len(ungrouped) == num_ungrouped
|
||||||
|
matching = [entry for entry in ungrouped if entry.startswith("ungrouped_")]
|
||||||
|
assert len(matching) == num_ungrouped
|
||||||
|
|
||||||
|
grouped_secrets = await secrets_manager.get_available_secrets(
|
||||||
|
group_path=group_name
|
||||||
|
)
|
||||||
|
assert len(grouped_secrets) == num_grouped
|
||||||
|
all_secrets = await secrets_manager.get_available_secrets()
|
||||||
|
assert len(all_secrets) == (num_ungrouped + num_grouped)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"entries", [[("test1", "secret1"), ("test2", "secret2")], []]
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_available_secrets(
|
||||||
|
self, secrets_manager: AsyncSecretContext, entries: list[tuple[str, str]]
|
||||||
|
) -> None:
|
||||||
|
"""Test the get_available_secrets method."""
|
||||||
|
for name, secret in entries:
|
||||||
|
await secrets_manager.add_entry(name, secret)
|
||||||
|
|
||||||
|
entry_names = [entry[0] for entry in entries]
|
||||||
|
response = await secrets_manager.get_available_secrets()
|
||||||
|
assert len(response) == len(entries)
|
||||||
|
|
||||||
|
assert sorted(response) == sorted(entry_names)
|
||||||
|
|
||||||
|
async def test_get_secret_groups_none(
|
||||||
|
self, secrets_manager: AsyncSecretContext
|
||||||
|
) -> None:
|
||||||
|
"""Test get_secret_groups with no groups created."""
|
||||||
|
result = await secrets_manager.get_secret_groups()
|
||||||
|
assert len(result) == 0
|
||||||
|
result_flat = await secrets_manager.get_secret_group_list()
|
||||||
|
assert len(result_flat) == 0
|
||||||
|
|
||||||
|
@allure.title("Search for a group using regular expression")
|
||||||
|
async def test_group_regex_search(
|
||||||
|
self, secrets_manager: AsyncSecretContext
|
||||||
|
) -> None:
|
||||||
|
"""Search for entries with regular expressions."""
|
||||||
|
groups = [
|
||||||
|
"test1",
|
||||||
|
"test2",
|
||||||
|
"other",
|
||||||
|
"somethingelse",
|
||||||
|
]
|
||||||
|
|
||||||
|
for group in groups:
|
||||||
|
await secrets_manager.add_group(group)
|
||||||
|
|
||||||
|
results = await secrets_manager.get_secret_group_list(
|
||||||
|
pattern="^test", regex=True
|
||||||
|
)
|
||||||
|
assert len(results) == 2
|
||||||
|
for group in results:
|
||||||
|
assert group.name.startswith("test")
|
||||||
@ -79,6 +79,32 @@ async def run_backend_server(test_ports: TestPorts):
|
|||||||
await server_task
|
await server_task
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(
|
||||||
|
scope=TEST_SCOPE, name="admin_server_settings", loop_scope=LOOP_SCOPE
|
||||||
|
)
|
||||||
|
async def get_admin_server_settings(
|
||||||
|
test_ports: TestPorts, backend_server: tuple[str, str]
|
||||||
|
):
|
||||||
|
"""Get admin server settings."""
|
||||||
|
backend_url, backend_token = backend_server
|
||||||
|
port = test_ports.admin
|
||||||
|
secret_key = secrets.token_urlsafe(32)
|
||||||
|
with in_tempdir() as admin_work_path:
|
||||||
|
admin_db = admin_work_path / "ssh_admin.db"
|
||||||
|
admin_settings = AdminServerSettings.model_validate(
|
||||||
|
{
|
||||||
|
"sshecret_backend_url": backend_url,
|
||||||
|
"backend_token": backend_token,
|
||||||
|
"secret_key": secret_key,
|
||||||
|
"listen_address": "0.0.0.0",
|
||||||
|
"port": port,
|
||||||
|
"database": str(admin_db.absolute()),
|
||||||
|
"password_manager_directory": str(admin_work_path.absolute()),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
yield admin_settings
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope=TEST_SCOPE, name="admin_server", loop_scope=LOOP_SCOPE)
|
@pytest_asyncio.fixture(scope=TEST_SCOPE, name="admin_server", loop_scope=LOOP_SCOPE)
|
||||||
async def run_admin_server(test_ports: TestPorts, backend_server: tuple[str, str]):
|
async def run_admin_server(test_ports: TestPorts, backend_server: tuple[str, str]):
|
||||||
"""Run admin server."""
|
"""Run admin server."""
|
||||||
@ -98,7 +124,7 @@ async def run_admin_server(test_ports: TestPorts, backend_server: tuple[str, str
|
|||||||
"password_manager_directory": str(admin_work_path.absolute()),
|
"password_manager_directory": str(admin_work_path.absolute()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
admin_app = create_admin_app(admin_settings)
|
admin_app = create_admin_app(admin_settings, create_db=True)
|
||||||
config = uvicorn.Config(app=admin_app, port=port, loop="asyncio")
|
config = uvicorn.Config(app=admin_app, port=port, loop="asyncio")
|
||||||
server = uvicorn.Server(config=config)
|
server = uvicorn.Server(config=config)
|
||||||
server_task = asyncio.create_task(server.serve())
|
server_task = asyncio.create_task(server.serve())
|
||||||
|
|||||||
Reference in New Issue
Block a user