Check in backend in working state

This commit is contained in:
2025-04-30 08:23:31 +02:00
parent 76ef97d9c4
commit 20f1ee707a
26 changed files with 1505 additions and 621 deletions

View File

@ -0,0 +1,119 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
# Use forward slashes (/) also on windows to provide an os agnostic path
script_location = migrations
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to migrations/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
# version_path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
version_path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
sqlalchemy.url = sqlite:///sshecret.db
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

View File

@ -0,0 +1 @@
Generic single-database configuration.

View File

@ -0,0 +1,80 @@
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from sshecret_backend.models import *
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
#target_metadata = None
target_metadata = SQLModel.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@ -0,0 +1,29 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
"""Upgrade schema."""
${upgrades if upgrades else "pass"}
def downgrade() -> None:
"""Downgrade schema."""
${downgrades if downgrades else "pass"}

View File

@ -0,0 +1,33 @@
"""Initial model
Revision ID: a0befb5a74a0
Revises:
Create Date: 2025-04-28 21:18:59.069323
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel
# revision identifiers, used by Alembic.
revision: str = 'a0befb5a74a0'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###

View File

@ -0,0 +1,33 @@
"""Add subsystem to auditlog
Revision ID: f30e413c5757
Revises: a0befb5a74a0
Create Date: 2025-04-28 21:21:20.103423
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel
# revision identifiers, used by Alembic.
revision: str = 'f30e413c5757'
down_revision: Union[str, None] = 'a0befb5a74a0'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('auditlog', sa.Column('subsystem', sqlmodel.sql.sqltypes.AutoString(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('auditlog', 'subsystem')
# ### end Alembic commands ###

View File

@ -8,9 +8,11 @@ authors = [
]
requires-python = ">=3.13"
dependencies = [
"alembic>=1.15.2",
"passlib[bcrypt]>=1.7.4",
"pydantic>=2.10.6",
"pytest>=8.3.5",
"python-multipart>=0.0.20",
"sqlmodel>=0.0.24",
]

View File

@ -1,5 +1,2 @@
"""Sshecret backend."""
from .app import app as app
#from .router import app as app
__all__ = ["app"]
# from .router import app as app

View File

@ -0,0 +1,8 @@
"""API factory modules."""
from .audit import get_audit_api
from .clients import get_clients_api
from .policies import get_policy_api
from .secrets import get_secrets_api
__all__ = ["get_audit_api", "get_clients_api", "get_policy_api", "get_secrets_api"]

View File

@ -0,0 +1,65 @@
"""Audit sub-api factory."""
# pyright: reportUnusedFunction=false
import logging
from collections.abc import Sequence
from fastapi import APIRouter, Depends, Request, Query
from sqlmodel import Session, col, func, select
from sqlalchemy import desc
from typing import Annotated
from sshecret_backend.models import AuditLog
from sshecret_backend.types import DBSessionDep
from sshecret_backend import audit
from sshecret_backend.view_models import AuditInfo
LOG = logging.getLogger(__name__)
def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
"""Construct audit sub-api."""
router = APIRouter()
@router.get("/audit/", response_model=list[AuditLog])
async def get_audit_logs(
request: Request,
session: Annotated[Session, Depends(get_db_session)],
offset: Annotated[int, Query()] = 0,
limit: Annotated[int, Query(le=100)] = 100,
filter_client: Annotated[str | None, Query()] = None,
filter_subsystem: Annotated[str | None, Query()] = None,
) -> Sequence[AuditLog]:
"""Get audit logs."""
#audit.audit_access_audit_log(session, request)
statement = select(AuditLog).offset(offset).limit(limit).order_by(desc(col(AuditLog.timestamp)))
if filter_client:
statement = statement.where(AuditLog.client_name == filter_client)
if filter_subsystem:
statement = statement.where(AuditLog.subsystem == filter_subsystem)
results = session.exec(statement).all()
return results
@router.post("/audit/")
async def add_audit_log(
request: Request,
session: Annotated[Session, Depends(get_db_session)],
entry: AuditLog,
) -> AuditLog:
"""Add entry to audit log."""
audit_log = AuditLog.model_validate(entry.model_dump(exclude_none=True))
session.add(audit_log)
session.commit()
return audit_log
@router.get("/audit/info")
async def get_audit_info(request: Request, session: Annotated[Session, Depends(get_db_session)]) -> AuditInfo:
"""Get audit info."""
audit_count = session.exec(select(func.count('*')).select_from(AuditLog)).one()
return AuditInfo(entries=audit_count)
return router

View File

@ -0,0 +1,226 @@
"""Client sub-api factory."""
# pyright: reportUnusedFunction=false
import uuid
import logging
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from pydantic import BaseModel, Field, model_validator
from sqlmodel import Session, col, select
from sqlalchemy import func
from typing import Annotated, Self, TypeVar
from sqlmodel.sql.expression import SelectOfScalar
from sshecret_backend.types import DBSessionDep
from sshecret_backend.models import Client, ClientSecret
from sshecret_backend.view_models import (
ClientCreate,
ClientQueryResult,
ClientView,
ClientUpdate,
)
from sshecret_backend import audit
from .common import get_client_by_id_or_name
class ClientListParams(BaseModel):
"""Client list parameters."""
limit: int = Field(100, gt=0, le=100)
offset: int = Field(0, ge=0)
id: uuid.UUID | None = None
name: str | None = None
name__like: str | None = None
name__contains: str | None = None
@model_validator(mode="after")
def validate_expressions(self) -> Self:
"""Validate mutually exclusive expression."""
name_filter = False
if self.name__like or self.name__contains:
name_filter = True
if self.name__like and self.name__contains:
raise ValueError("You may only specify one name expression")
if self.name and name_filter:
raise ValueError(
"You must either specify name or one of name__like or name__contains"
)
return self
LOG = logging.getLogger(__name__)
T = TypeVar("T")
def filter_client_statement(
statement: SelectOfScalar[T], params: ClientListParams, ignore_limits: bool = False
) -> SelectOfScalar[T]:
"""Filter a statement with the provided params."""
if params.id:
statement = statement.where(Client.id == params.id)
if params.name:
statement = statement.where(Client.name == params.name)
elif params.name__like:
statement = statement.where(col(Client.name).like(params.name__like))
elif params.name__contains:
statement = statement.where(col(Client.name).contains(params.name__contains))
if ignore_limits:
return statement
return statement.limit(params.limit).offset(params.offset)
def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
"""Construct clients sub-api."""
router = APIRouter()
@router.get("/clients/")
async def get_clients(
filter_query: Annotated[ClientListParams, Query()],
session: Annotated[Session, Depends(get_db_session)],
) -> ClientQueryResult:
"""Get clients."""
# Get total results first
count_statement = select(func.count("*")).select_from(Client)
count_statement = filter_client_statement(count_statement, filter_query, True)
total_results = session.exec(count_statement).one()
statement = filter_client_statement(select(Client), filter_query, False)
results = session.exec(statement)
remainder = total_results - filter_query.offset - filter_query.limit
if remainder < 0:
remainder = 0
clients = list(results)
clients_view = ClientView.from_client_list(clients)
return ClientQueryResult(
clients=clients_view,
total_results=total_results,
remaining_results=remainder,
)
@router.get("/clients/{name}")
async def get_client(
name: str,
session: Annotated[Session, Depends(get_db_session)],
) -> ClientView:
"""Fetch a client."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
return ClientView.from_client(client)
@router.delete("/clients/{name}")
async def delete_client(
request: Request,
name: str,
session: Annotated[Session, Depends(get_db_session)],
) -> None:
"""Delete a client."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
session.delete(client)
session.commit()
audit.audit_delete_client(session, request, client)
@router.post("/clients/")
async def create_client(
request: Request,
client: ClientCreate,
session: Annotated[Session, Depends(get_db_session)],
) -> ClientView:
"""Create client."""
existing = await get_client_by_id_or_name(session, client.name)
if existing:
raise HTTPException(400, detail="Error: Already a client with that name.")
db_client = client.to_client()
session.add(db_client)
session.commit()
session.refresh(db_client)
audit.audit_create_client(session, request, db_client)
return ClientView.from_client(db_client)
@router.post("/clients/{name}/public-key")
async def update_client_public_key(
request: Request,
name: str,
client_update: ClientUpdate,
session: Annotated[Session, Depends(get_db_session)],
) -> ClientView:
"""Change the public key of a client.
This invalidates all secrets.
"""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
client.public_key = client_update.public_key
for secret in session.exec(
select(ClientSecret).where(ClientSecret.client_id == client.id)
).all():
LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True
secret.client_id = None
secret.client = None
session.add(client)
session.refresh(client)
session.commit()
audit.audit_invalidate_secrets(session, request, client)
return ClientView.from_client(client)
@router.put("/clients/{name}")
async def update_client(
request: Request,
name: str,
client_update: ClientCreate,
session: Annotated[Session, Depends(get_db_session)],
) -> ClientView:
"""Change the public key of a client.
This invalidates all secrets.
"""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
client.name = client_update.name
client.description = client_update.description
public_key_updated = False
if client_update.public_key != client.public_key:
public_key_updated = True
for secret in session.exec(
select(ClientSecret).where(ClientSecret.client_id == client.id)
).all():
LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True
secret.client_id = None
secret.client = None
session.add(client)
session.commit()
session.refresh(client)
audit.audit_update_client(session, request, client)
if public_key_updated:
audit.audit_invalidate_secrets(session, request, client)
return ClientView.from_client(client)
return router

View File

@ -0,0 +1,38 @@
"""Common helpers."""
import re
import uuid
import bcrypt
from sqlmodel import Session, select
from sshecret_backend.models import Client
RE_UUID = re.compile("^[0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}$")
def verify_token(token: str, stored_hash: str) -> bool:
"""Verify token."""
token_bytes = token.encode("utf-8")
stored_bytes = stored_hash.encode("utf-8")
return bcrypt.checkpw(token_bytes, stored_bytes)
async def get_client_by_name(session: Session, name: str) -> Client | None:
"""Get client by name."""
client_filter = select(Client).where(Client.name == name)
client_results = session.exec(client_filter)
return client_results.first()
async def get_client_by_id(session: Session, id: uuid.UUID) -> Client | None:
"""Get client by name."""
client_filter = select(Client).where(Client.id == id)
client_results = session.exec(client_filter)
return client_results.first()
async def get_client_by_id_or_name(session: Session, id_or_name: str) -> Client | None:
"""Get client either by id or name."""
if RE_UUID.match(id_or_name):
id = uuid.UUID(id_or_name)
return await get_client_by_id(session, id)
return await get_client_by_name(session, id_or_name)

View File

@ -0,0 +1,82 @@
"""Policies sub-api router factory."""
# pyright: reportUnusedFunction=false
import logging
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlmodel import Session, select
from typing import Annotated
from sshecret_backend.models import Client, ClientAccessPolicy
from sshecret_backend.view_models import (
ClientPolicyView,
ClientPolicyUpdate,
)
from sshecret_backend.types import DBSessionDep
from sshecret_backend import audit
from .common import get_client_by_id_or_name
LOG = logging.getLogger(__name__)
def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
"""Construct clients sub-api."""
router = APIRouter()
@router.get("/clients/{name}/policies/")
async def get_client_policies(
name: str, session: Annotated[Session, Depends(get_db_session)]
) -> ClientPolicyView:
"""Get client policies."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
return ClientPolicyView.from_client(client)
@router.put("/clients/{name}/policies/")
async def update_client_policies(
request: Request,
name: str,
policy_update: ClientPolicyUpdate,
session: Annotated[Session, Depends(get_db_session)],
) -> ClientPolicyView:
"""Update client policies.
This is also how you delete policies.
"""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
# Remove old policies.
policies = session.exec(
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
).all()
deleted_policies: list[ClientAccessPolicy] = []
added_policies: list[ClientAccessPolicy] = []
for policy in policies:
session.delete(policy)
deleted_policies.append(policy)
for source in policy_update.sources:
LOG.debug("Source %r", source)
policy = ClientAccessPolicy(source=str(source), client_id=client.id)
session.add(policy)
added_policies.append(policy)
session.commit()
session.refresh(client)
for policy in deleted_policies:
audit.audit_remove_policy(session, request, client, policy)
for policy in added_policies:
audit.audit_update_policy(session, request, client, policy)
return ClientPolicyView.from_client(client)
return router

View File

@ -0,0 +1,232 @@
"""Secrets sub-api factory."""
# pyright: reportUnusedFunction=false
import logging
from collections import defaultdict
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlmodel import Session, select
from typing import Annotated
from sshecret_backend.models import Client, ClientSecret
from sshecret_backend.view_models import (
ClientReference,
ClientSecretDetailList,
ClientSecretList,
ClientSecretPublic,
BodyValue,
ClientSecretResponse,
)
from sshecret_backend import audit
from sshecret_backend.types import DBSessionDep
from .common import get_client_by_id_or_name
LOG = logging.getLogger(__name__)
async def lookup_client_secret(
session: Session, client: Client, name: str
) -> ClientSecret | None:
"""Look up a secret for a client."""
statement = (
select(ClientSecret)
.where(ClientSecret.client_id == client.id)
.where(ClientSecret.name == name)
)
results = session.exec(statement)
return results.first()
def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
"""Construct clients sub-api."""
router = APIRouter()
@router.post("/clients/{name}/secrets/")
async def add_secret_to_client(
request: Request,
name: str,
client_secret: ClientSecretPublic,
session: Annotated[Session, Depends(get_db_session)],
) -> None:
"""Add secret to a client."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
existing_secret = await lookup_client_secret(
session, client, client_secret.name
)
if existing_secret:
raise HTTPException(
status_code=400,
detail="Cannot add a secret. A different secret with the same name already exists.",
)
db_secret = ClientSecret(
name=client_secret.name, client_id=client.id, secret=client_secret.secret
)
session.add(db_secret)
session.commit()
session.refresh(db_secret)
audit.audit_create_secret(session, request, client, db_secret)
@router.put("/clients/{name}/secrets/{secret_name}")
async def update_client_secret(
request: Request,
name: str,
secret_name: str,
secret_data: BodyValue,
session: Annotated[Session, Depends(get_db_session)],
) -> ClientSecretResponse:
"""Update a client secret.
This can also be used for destructive creates.
"""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
existing_secret = await lookup_client_secret(session, client, secret_name)
if existing_secret:
existing_secret.secret = secret_data.value
session.add(existing_secret)
session.commit()
session.refresh(existing_secret)
audit.audit_update_secret(session, request, client, existing_secret)
return ClientSecretResponse.from_client_secret(existing_secret)
db_secret = ClientSecret(
name=secret_name,
client_id=client.id,
secret=secret_data.value,
)
session.add(db_secret)
session.commit()
session.refresh(db_secret)
audit.audit_create_secret(session, request, client, db_secret)
return ClientSecretResponse.from_client_secret(db_secret)
@router.get("/clients/{name}/secrets/{secret_name}")
async def request_client_secret(
request: Request,
name: str,
secret_name: str,
session: Annotated[Session, Depends(get_db_session)],
) -> ClientSecretResponse:
"""Get a client secret."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
secret = await lookup_client_secret(session, client, secret_name)
if not secret:
raise HTTPException(
status_code=404, detail="Cannot find a secret with the given name."
)
response_model = ClientSecretResponse.from_client_secret(secret)
audit.audit_access_secret(session, request, client, secret)
return response_model
@router.delete("/clients/{name}/secrets/{secret_name}")
async def delete_client_secret(
request: Request,
name: str,
secret_name: str,
session: Annotated[Session, Depends(get_db_session)],
) -> None:
"""Delete a secret."""
client = await get_client_by_id_or_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
secret = await lookup_client_secret(session, client, secret_name)
if not secret:
raise HTTPException(
status_code=404, detail="Cannot find a secret with the given name."
)
session.delete(secret)
session.commit()
audit.audit_delete_secret(session, request, client, secret)
@router.get("/secrets/")
async def get_secret_map(
request: Request, session: Annotated[Session, Depends(get_db_session)]
) -> list[ClientSecretList]:
"""Get a list of all secrets and which clients have them."""
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
for client_secret in session.exec(select(ClientSecret)).all():
if not client_secret.client:
if client_secret.name not in client_secret_map:
client_secret_map[client_secret.name] = []
continue
client_secret_map[client_secret.name].append(client_secret.client.name)
audit.audit_client_secret_list(session, request)
return [
ClientSecretList(name=secret_name, clients=clients)
for secret_name, clients in client_secret_map.items()
]
@router.get("/secrets/detailed/")
async def get_detailed_secret_map(
request: Request, session: Annotated[Session, Depends(get_db_session)]
) -> list[ClientSecretDetailList]:
"""Get a list of all secrets and which clients have them."""
client_secrets: dict[str, ClientSecretDetailList] = {}
for client_secret in session.exec(select(ClientSecret)).all():
if client_secret.name not in client_secrets:
client_secrets[client_secret.name] = ClientSecretDetailList(name=client_secret.name)
client_secrets[client_secret.name].ids.append(str(client_secret.id))
if not client_secret.client:
continue
client_secrets[client_secret.name].clients.append(ClientReference(id=str(client_secret.client.id), name=client_secret.client.name))
audit.audit_client_secret_list(session, request)
return list(client_secrets.values())
@router.get("/secrets/{name}")
async def get_secret_clients(
request: Request,
name: str,
session: Annotated[Session, Depends(get_db_session)],
) -> ClientSecretList:
"""Get a list of which clients has a named secret."""
clients: list[str] = []
for client_secret in session.exec(
select(ClientSecret).where(ClientSecret.name == name)
).all():
if not client_secret.client:
continue
clients.append(client_secret.client.name)
return ClientSecretList(name=name, clients=clients)
@router.get("/secrets/{name}/detailed")
async def get_secret_clients_detailed(
request: Request,
name: str,
session: Annotated[Session, Depends(get_db_session)],
) -> ClientSecretDetailList:
"""Get a list of which clients has a named secret."""
detail_list = ClientSecretDetailList(name=name)
for client_secret in session.exec(
select(ClientSecret).where(ClientSecret.name == name)
).all():
if not client_secret.client:
continue
detail_list.ids.append(str(client_secret.id))
detail_list.clients.append(ClientReference(id=str(client_secret.client.id), name=client_secret.client.name))
return detail_list
return router

View File

@ -1,436 +1,65 @@
"""FastAPI api.
TODO: We may want to allow a consumer to generate audit log entries manually.
"""
"""FastAPI api."""
# pyright: reportUnusedFunction=false
import logging
from collections.abc import Sequence
from contextlib import asynccontextmanager
from typing import Annotated
import bcrypt
from fastapi import (
APIRouter,
Depends,
FastAPI,
Header,
HTTPException,
Query,
Request,
status,
)
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from sqlalchemy import Engine
from sqlmodel import Session, select
from . import audit
from .db import get_engine
from .models import (
APIClient,
AuditLog,
Client,
ClientAccessPolicy,
ClientSecret,
init_db,
)
from .settings import get_settings
from .view_models import (
BodyValue,
ClientCreate,
ClientSecretPublic,
ClientSecretResponse,
ClientUpdate,
ClientView,
ClientPolicyView,
ClientPolicyUpdate,
)
settings = get_settings()
engine = get_engine(settings.db_file)
from .models import init_db
from .backend_api import get_backend_api
from .db import setup_database
from .settings import BackendSettings
from .types import DBSessionDep
LOG = logging.getLogger(__name__)
API_VERSION = "v1"
def init_backend_app(engine: Engine, get_db_session: DBSessionDep) -> FastAPI:
"""Initialize backend app."""
def verify_token(token: str, stored_hash: str) -> bool:
"""Verify token."""
token_bytes = token.encode("utf-8")
stored_bytes = stored_hash.encode("utf-8")
return bcrypt.checkpw(token_bytes, stored_bytes)
@asynccontextmanager
async def lifespan(_app: FastAPI):
"""Create database before starting the server."""
LOG.debug("Running lifespan")
init_db(engine)
yield
app = FastAPI(lifespan=lifespan)
app.include_router(get_backend_api(get_db_session))
@asynccontextmanager
async def lifespan(_app: FastAPI):
"""Create database before starting the server."""
init_db(engine)
yield
async def get_session():
"""Get the session."""
with Session(engine) as session:
yield session
async def validate_token(
x_api_token: Annotated[str, Header()],
session: Annotated[Session, Depends(get_session)],
) -> str:
"""Validate token."""
LOG.debug("Validating token %s", x_api_token)
statement = select(APIClient)
results = session.exec(statement)
valid = False
for result in results:
if verify_token(x_api_token, result.token):
valid = True
LOG.debug("Token is valid")
break
if not valid:
LOG.debug("Token is not valid.")
raise HTTPException(status_code=401, detail="unauthorized. invalid api token.")
return x_api_token
async def get_client_by_name(session: Session, name: str) -> Client | None:
"""Get client by name."""
client_filter = select(Client).where(Client.name == name)
client_results = session.exec(client_filter)
return client_results.first()
async def lookup_client_secret(
session: Session, client: Client, name: str
) -> ClientSecret | None:
"""Look up a secret for a client."""
statement = (
select(ClientSecret)
.where(ClientSecret.client_id == client.id)
.where(ClientSecret.name == name)
)
results = session.exec(statement)
return results.first()
LOG.info("Initializing app.")
backend_api = APIRouter(
prefix=f"/api/{API_VERSION}",
lifespan=lifespan,
dependencies=[Depends(validate_token)],
)
@backend_api.get("/clients/")
async def get_clients(
session: Annotated[Session, Depends(get_session)]
) -> list[ClientView]:
"""Get clients."""
statement = select(Client)
results = session.exec(statement)
clients = list(results)
return ClientView.from_client_list(clients)
@backend_api.get("/clients/{name}")
async def get_client(
request: Request, name: str, session: Annotated[Session, Depends(get_session)]
) -> ClientView:
"""Fetch a client."""
statement = select(Client).where(Client.name == name)
results = session.exec(statement)
client = results.first()
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
audit.audit_access_secrets(session, request, client)
return ClientView.from_client(client)
@backend_api.delete("/clients/{name}")
async def delete_client(
request: Request, name: str, session: Annotated[Session, Depends(get_session)]
) -> None:
"""Delete a client."""
statement = select(Client).where(Client.name == name)
results = session.exec(statement)
client = results.first()
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(
request: Request, exc: RequestValidationError
):
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
)
session.delete(client)
session.commit()
audit.audit_delete_client(session, request, client)
@backend_api.get("/clients/{name}/policies/")
async def get_client_policies(
name: str, session: Annotated[Session, Depends(get_session)]
) -> ClientPolicyView:
"""Get client policies."""
statement = select(Client).where(Client.name == name)
results = session.exec(statement)
client = results.first()
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
@app.get("/health")
async def get_health() -> JSONResponse:
"""Provide simple health check."""
return JSONResponse(
status_code=status.HTTP_200_OK, content=jsonable_encoder({"status": "LIVE"})
)
return ClientPolicyView.from_client(client)
return app
@backend_api.put("/clients/{name}/policies/")
async def update_client_policies(
request: Request,
name: str,
policy_update: ClientPolicyUpdate,
session: Annotated[Session, Depends(get_session)],
) -> ClientPolicyView:
"""Update client policies.
def create_backend_app(settings: BackendSettings) -> FastAPI:
"""Create the backend app."""
This is also how you delete policies.
"""
statement = select(Client).where(Client.name == name)
results = session.exec(statement)
client = results.first()
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
# Remove old policies.
policies = session.exec(
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
).all()
deleted_policies: list[ClientAccessPolicy] = []
added_policies: list[ClientAccessPolicy] = []
for policy in policies:
session.delete(policy)
deleted_policies.append(policy)
engine, get_db_session = setup_database(settings.db_url)
for source in policy_update.sources:
LOG.debug("Source %r", source)
policy = ClientAccessPolicy(source=str(source), client_id=client.id)
session.add(policy)
added_policies.append(policy)
session.commit()
session.refresh(client)
for policy in deleted_policies:
audit.audit_remove_policy(session, request, client, policy)
for policy in added_policies:
audit.audit_update_policy(session, request, client, policy)
return ClientPolicyView.from_client(client)
@backend_api.post("/clients/{name}/public-key")
async def update_client_public_key(
request: Request,
name: str,
client_update: ClientUpdate,
session: Annotated[Session, Depends(get_session)],
) -> ClientView:
"""Change the public key of a client.
This invalidates all secrets.
"""
statement = select(Client).where(Client.name == name)
results = session.exec(statement)
client = results.first()
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
client.public_key = client_update.public_key
for secret in session.exec(
select(ClientSecret).where(ClientSecret.client_id == client.id)
).all():
LOG.debug("Invalidated secret %s", secret.id)
secret.invalidated = True
secret.client_id = None
secret.client = None
session.add(client)
session.refresh(client)
session.commit()
audit.audit_invalidate_secrets(session, request, client)
return ClientView.from_client(client)
@backend_api.post("/clients/")
async def create_client(
request: Request,
client: ClientCreate,
session: Annotated[Session, Depends(get_session)],
) -> ClientView:
"""Create client."""
existing = await get_client_by_name(session, client.name)
if existing:
raise HTTPException(400, detail="Error: Already a client with that name.")
db_client = client.to_client()
session.add(db_client)
session.commit()
session.refresh(db_client)
audit.audit_create_client(session, request, db_client)
return ClientView.from_client(db_client)
@backend_api.post("/clients/{name}/secrets/")
async def add_secret_to_client(
request: Request,
name: str,
client_secret: ClientSecretPublic,
session: Annotated[Session, Depends(get_session)],
) -> None:
"""Add secret to a client."""
client = await get_client_by_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
existing_secret = await lookup_client_secret(session, client, client_secret.name)
if existing_secret:
raise HTTPException(
status_code=400,
detail="Cannot add a secret. A different secret with the same name already exists.",
)
db_secret = ClientSecret(
name=client_secret.name, client_id=client.id, secret=client_secret.secret
)
session.add(db_secret)
session.commit()
session.refresh(db_secret)
audit.audit_create_secret(session, request, client, db_secret)
@backend_api.put("/clients/{name}/secrets/{secret_name}")
async def update_client_secret(
request: Request,
name: str,
secret_name: str,
secret_data: BodyValue,
session: Annotated[Session, Depends(get_session)],
) -> ClientSecretResponse:
"""Update a client secret.
This can also be used for destructive creates.
"""
client = await get_client_by_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
existing_secret = await lookup_client_secret(session, client, secret_name)
if existing_secret:
existing_secret.secret = secret_data.value
session.add(existing_secret)
session.commit()
session.refresh(existing_secret)
audit.audit_update_secret(session, request, client, existing_secret)
return ClientSecretResponse.from_client_secret(existing_secret)
db_secret = ClientSecret(
name=secret_name,
client_id=client.id,
secret=secret_data.value,
)
session.add(db_secret)
session.commit()
session.refresh(db_secret)
audit.audit_create_secret(session, request, client, db_secret)
return ClientSecretResponse.from_client_secret(db_secret)
@backend_api.get("/clients/{name}/secrets/{secret_name}")
async def request_client_secret(
request: Request,
name: str,
secret_name: str,
session: Annotated[Session, Depends(get_session)],
) -> ClientSecretResponse:
"""Get a client secret."""
client = await get_client_by_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
secret = await lookup_client_secret(session, client, secret_name)
if not secret:
raise HTTPException(
status_code=404, detail="Cannot find a secret with the given name."
)
response_model = ClientSecretResponse.from_client_secret(secret)
audit.audit_access_secret(session, request, client, secret)
return response_model
@backend_api.delete("/clients/{name}/secrets/{secret_name}")
async def delete_client_secret(
request: Request,
name: str,
secret_name: str,
session: Annotated[Session, Depends(get_session)],
) -> None:
"""Delete a secret."""
client = await get_client_by_name(session, name)
if not client:
raise HTTPException(
status_code=404, detail="Cannot find a client with the given name."
)
secret = await lookup_client_secret(session, client, secret_name)
if not secret:
raise HTTPException(
status_code=404, detail="Cannot find a secret with the given name."
)
session.delete(secret)
session.commit()
audit.audit_delete_secret(session, request, client, secret)
@backend_api.get("/audit/", response_model=list[AuditLog])
async def get_audit_logs(
request: Request,
session: Annotated[Session, Depends(get_session)],
offset: Annotated[int, Query()] = 0,
limit: Annotated[int, Query(le=100)] = 100,
filter_client: Annotated[str | None, Query()] = None,
) -> Sequence[AuditLog]:
"""Get audit logs."""
audit.audit_access_audit_log(session, request)
statement = select(AuditLog).offset(offset).limit(limit)
if filter_client:
statement = statement.where(AuditLog.client_name == filter_client)
results = session.exec(statement).all()
return results
app = FastAPI()
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
)
app.include_router(backend_api)
return init_backend_app(engine, get_db_session)

View File

@ -3,6 +3,7 @@
from collections.abc import Sequence
from fastapi import Request
from sqlmodel import Session, select
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy
@ -21,6 +22,7 @@ def _write_audit_log(
"""Write the audit log."""
origin = _get_origin(request)
entry.origin = origin
entry.subsystem = "backend"
session.add(entry)
if commit:
session.commit()
@ -109,6 +111,23 @@ def audit_update_policy(
_write_audit_log(session, request, entry, commit)
def audit_update_client(
session: Session,
request: Request,
client: Client,
commit: bool = True,
) -> None:
"""Audit an update secret event."""
entry = AuditLog(
operation="UPDATE",
object="Client",
client_id=client.id,
client_name=client.name,
message="Client updated",
)
_write_audit_log(session, request, entry, commit)
def audit_update_secret(
session: Session,
request: Request,
@ -219,3 +238,15 @@ def audit_access_audit_log(
object="AuditLog",
)
_write_audit_log(session, request, entry, commit)
def audit_client_secret_list(
session: Session, request: Request, commit: bool = True
) -> None:
"""Audit a list of all secrets."""
entry = AuditLog(
operation="ACCESS",
message="All secret names and their clients was viewed",
)
_write_audit_log(session, request, entry, commit)

View File

@ -0,0 +1,67 @@
"""Backend API."""
import logging
from typing import Annotated
import bcrypt
from fastapi import APIRouter, Depends, Header, HTTPException
from sqlmodel import Session, select
from .api import get_audit_api, get_clients_api, get_policy_api, get_secrets_api
from .models import (
APIClient,
)
from .types import DBSessionDep
LOG = logging.getLogger(__name__)
API_VERSION = "v1"
def verify_token(token: str, stored_hash: str) -> bool:
"""Verify token."""
token_bytes = token.encode("utf-8")
stored_bytes = stored_hash.encode("utf-8")
return bcrypt.checkpw(token_bytes, stored_bytes)
def get_backend_api(
get_db_session: DBSessionDep,
) -> APIRouter:
"""Construct backend API."""
async def validate_token(
x_api_token: Annotated[str, Header()],
session: Annotated[Session, Depends(get_db_session)],
) -> str:
"""Validate token."""
LOG.debug("Validating token %s", x_api_token)
statement = select(APIClient)
results = session.exec(statement)
valid = False
for result in results:
if verify_token(x_api_token, result.token):
valid = True
LOG.debug("Token is valid")
break
if not valid:
LOG.debug("Token is not valid.")
raise HTTPException(
status_code=401, detail="unauthorized. invalid api token."
)
return x_api_token
LOG.info("Initializing app.")
backend_api = APIRouter(
prefix=f"/api/{API_VERSION}",
dependencies=[Depends(validate_token)],
)
backend_api.include_router(get_audit_api(get_db_session))
backend_api.include_router(get_clients_api(get_db_session))
backend_api.include_router(get_policy_api(get_db_session))
backend_api.include_router(get_secrets_api(get_db_session))
return backend_api

View File

@ -1,11 +1,18 @@
"""CLI and main entry point."""
import code
import os
from pathlib import Path
from typing import cast
from dotenv import load_dotenv
import click
from sqlmodel import Session, create_engine, select
import uvicorn
from .db import generate_api_token
from .db import create_api_token
from .models import Client, ClientSecret, ClientAccessPolicy, AuditLog, APIClient
from .settings import BackendSettings
DEFAULT_LISTEN = "127.0.0.1"
DEFAULT_PORT = 8022
@ -14,18 +21,59 @@ WORKDIR = Path(os.getcwd())
load_dotenv()
@click.group()
@click.option("--database", help="Path to the sqlite database file.")
def cli(database: str) -> None:
@click.pass_context
def cli(ctx: click.Context, database: str) -> None:
"""CLI group."""
if database:
# Hopefully it's enough to set the environment variable as so.
os.environ["SSHECRET_DB_FILE"] = str(Path(database).absolute())
settings = BackendSettings(db_url=f"sqlite:///{Path(database).absolute()}")
else:
settings = BackendSettings()
ctx.obj = settings
@cli.command("generate-token")
def cli_generate_token() -> None:
@click.pass_context
def cli_generate_token(ctx: click.Context) -> None:
"""Generate a token."""
token = generate_api_token()
settings = cast(BackendSettings, ctx.obj)
engine = create_engine(settings.db_url)
with Session(engine) as session:
token = create_api_token(session, True)
click.echo("Generated api token:")
click.echo(token)
@cli.command("run")
@click.option("--host", default="127.0.0.1")
@click.option("--port", default=8022, type=click.INT)
@click.option("--dev", is_flag=True)
@click.option("--workers", type=click.INT)
def cli_run(host: str, port: int, dev: bool, workers: int | None) -> None:
"""Run the server."""
uvicorn.run("sshecret_backend.main:app", host=host, port=port, reload=dev, workers=workers)
@cli.command("repl")
@click.pass_context
def cli_repl(ctx: click.Context) -> None:
"""Run an interactive console."""
settings = cast(BackendSettings, ctx.obj)
engine = create_engine(settings.db_url)
with Session(engine) as session:
locals = {
"session": session,
"select": select,
"Client": Client,
"ClientSecret": ClientSecret,
"ClientAccessPolicy": ClientAccessPolicy,
"APIClient": APIClient,
"AuditLog": AuditLog,
}
console = code.InteractiveConsole(locals=locals, local_exit=True)
banner = "Sshecret-backend REPL.\nUse 'session' to interact with the database."
console.interact(banner=banner, exitmsg="Bye!")

View File

@ -2,6 +2,7 @@
import logging
import secrets
from collections.abc import Generator, Callable
from pathlib import Path
from sqlalchemy import Engine
from sqlmodel import Session, create_engine, text
@ -9,14 +10,30 @@ import bcrypt
from sqlalchemy.engine import URL
from .models import APIClient, init_db
from .settings import get_settings
from .models import APIClient
LOG = logging.getLogger(__name__)
def setup_database(
db_url: URL | str,
) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]:
"""Setup database."""
engine = create_engine(db_url, echo=False)
with engine.connect() as connection:
connection.execute(text("PRAGMA foreign_keys=ON")) # for SQLite only
def get_db_session() -> Generator[Session, None, None]:
"""Get DB Session."""
with Session(engine) as session:
yield session
return engine, get_db_session
def get_engine(filename: Path, echo: bool = False) -> Engine:
"""Initialize the engine."""
url = URL.create(drivername="sqlite", database=str(filename.absolute()))
@ -27,20 +44,6 @@ def get_engine(filename: Path, echo: bool = False) -> Engine:
return engine
def create_db_and_tables(filename: Path, echo: bool = False) -> bool:
"""Create database and tables.
Returns True if the database was created.
"""
created = False
if not filename.exists():
created = True
engine = get_engine(filename, echo)
init_db(engine)
return created
def create_api_token(session: Session, read_write: bool) -> str:
"""Create API token."""
token = secrets.token_urlsafe(32)
@ -54,14 +57,3 @@ def create_api_token(session: Session, read_write: bool) -> str:
session.commit()
return token
def generate_api_token() -> str:
"""Generate API token."""
settings = get_settings()
engine = get_engine(settings.db_file)
init_db(engine)
with Session(engine) as session:
token = create_api_token(session, True)
return token

View File

@ -0,0 +1,7 @@
"""Main script entrypoint."""
from .settings import BackendSettings
from .app import create_backend_app
app = create_backend_app(BackendSettings())

View File

@ -7,17 +7,21 @@ This might require some changes to these schemas.
"""
import logging
import uuid
from datetime import datetime
import sqlalchemy as sa
from sqlmodel import Field, Relationship, SQLModel
LOG = logging.getLogger(__name__)
class Client(SQLModel, table=True):
"""Client model."""
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
name: str = Field(unique=True)
description: str | None = None
public_key: str
created_at: datetime | None = Field(
@ -61,11 +65,13 @@ class ClientAccessPolicy(SQLModel, table=True):
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
)
class ClientSecret(SQLModel, table=True):
"""A client secret."""
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
name: str
description: str | None = None
client_id: uuid.UUID | None = Field(foreign_key="client.id", ondelete="CASCADE")
client: Client | None = Relationship(back_populates="secrets")
secret: str
@ -92,6 +98,7 @@ class AuditLog(SQLModel, table=True):
"""
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
subsystem: str | None = None
object: str | None = None
object_id: str | None = None
operation: str
@ -107,6 +114,7 @@ class AuditLog(SQLModel, table=True):
nullable=False,
)
class APIClient(SQLModel, table=True):
"""Stores API Keys."""
@ -120,6 +128,8 @@ class APIClient(SQLModel, table=True):
nullable=False,
)
def init_db(engine: sa.Engine) -> None:
"""Create database."""
LOG.info("Starting init_db")
SQLModel.metadata.create_all(engine)

View File

@ -1,15 +1,13 @@
"""Settings management."""
from typing import override
from pathlib import Path
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import (
BaseSettings,
SettingsConfigDict,
)
DEFAULT_DATABASE = "sshecret.db"
DEFAULT_DATABASE = "sqlite:///sshecret.db"
class BackendSettings(BaseSettings):
@ -17,7 +15,7 @@ class BackendSettings(BaseSettings):
model_config = SettingsConfigDict(env_file=".backend.env", env_prefix="sshecret_")
db_file: Path = Field(default=Path(DEFAULT_DATABASE).absolute())
db_url: str = Field(default=DEFAULT_DATABASE)
def get_settings() -> BackendSettings:

View File

@ -1,13 +1,17 @@
"""Test helpers."""
import logging
from sqlmodel import Session
from .db import get_engine, create_api_token
from sshecret_backend.settings import BackendSettings
from .models import init_db
from .settings import get_settings
from .db import create_api_token, setup_database
def create_test_token(session: Session) -> str:
LOG = logging.getLogger(__name__)
def create_test_token(settings: BackendSettings) -> str:
"""Create test token."""
settings = get_settings()
engine = get_engine(settings.db_file)
init_db(engine)
return create_api_token(session, True)
engine, _setupdb = setup_database(settings.db_url)
with Session(engine) as session:
init_db(engine)
return create_api_token(session, True)

View File

@ -0,0 +1,8 @@
"""Common type definitions."""
from collections.abc import Callable, Generator
from sqlmodel import Session
DBSessionDep = Callable[[], Generator[Session, None, None]]

View File

@ -1,24 +1,27 @@
"""Models for API views."""
import ipaddress
import uuid
from datetime import datetime
from typing import Annotated, Any, Self, override
from typing import Annotated, Self, override
from sqlmodel import Field, SQLModel
from pydantic import IPvAnyAddress, IPvAnyNetwork
from . import models
from pydantic import AfterValidator, IPvAnyAddress, IPvAnyNetwork
from sshecret.crypto import public_key_validator
from . import models
class ClientView(SQLModel):
"""View for a single client."""
id: uuid.UUID
name: str
description: str | None = None
public_key: str
policies: list[str] = ["0.0.0.0/0", "::/0"]
secrets: list[str] = Field(default_factory=list)
created_at: datetime
created_at: datetime | None
updated_at: datetime | None = None
@classmethod
@ -33,6 +36,7 @@ class ClientView(SQLModel):
view = cls(
id=client.id,
name=client.name,
description=client.description,
public_key=client.public_key,
created_at=client.created_at,
updated_at=client.updated_at or None,
@ -46,24 +50,34 @@ class ClientView(SQLModel):
return view
class ClientQueryResult(SQLModel):
"""Result class for queries towards the client list."""
clients: list[ClientView] = Field(default_factory=list)
total_results: int
remaining_results: int
class ClientCreate(SQLModel):
"""Model to create a client."""
name: str
public_key: str
description: str | None = None
public_key: Annotated[str, AfterValidator(public_key_validator)]
def to_client(self) -> models.Client:
"""Instantiate a client."""
public_key = self.public_key
return models.Client(
name=self.name, public_key=public_key
name=self.name,
public_key=self.public_key,
description=self.description,
)
class ClientUpdate(SQLModel):
"""Model to update the client public key."""
public_key: str
public_key: Annotated[str, AfterValidator(public_key_validator)]
class BodyValue(SQLModel):
@ -77,6 +91,7 @@ class ClientSecretPublic(SQLModel):
name: str
secret: str
description: str | None = None
@classmethod
def from_client_secret(cls, client_secret: models.ClientSecret) -> Self:
@ -84,13 +99,14 @@ class ClientSecretPublic(SQLModel):
return cls(
name=client_secret.name,
secret=client_secret.secret,
description=client_secret.description,
)
class ClientSecretResponse(ClientSecretPublic):
"""A secret view."""
created_at: datetime
created_at: datetime | None
updated_at: datetime | None = None
@override
@ -123,3 +139,31 @@ class ClientPolicyUpdate(SQLModel):
"""Model for updating policies."""
sources: list[IPvAnyAddress | IPvAnyNetwork]
class ClientSecretList(SQLModel):
"""Model for aggregating identically named secrets."""
name: str
clients: list[str]
class ClientReference(SQLModel):
"""Reference to a client."""
id: str
name: str
class ClientSecretDetailList(SQLModel):
"""A more detailed version of the ClientSecretList."""
name: str
ids: list[str] = Field(default_factory=list)
clients: list[ClientReference] = Field(default_factory=list)
class AuditInfo(SQLModel):
"""Information about audit information."""
entries: int

View File

@ -1,18 +1,17 @@
"""Tests of the backend api using pytest."""
import logging
import random
import string
from pathlib import Path
from httpx import Response
import pytest
from fastapi.testclient import TestClient
from sqlmodel import Session, SQLModel, create_engine
from sqlmodel.pool import StaticPool
from sshecret_backend.app import app, get_session
from sshecret.crypto import generate_private_key, generate_public_key_string
from sshecret_backend.app import create_backend_app
from sshecret_backend.testing import create_test_token
from sshecret_backend.models import AuditLog
from sshecret_backend.settings import BackendSettings
LOG = logging.getLogger()
@ -25,20 +24,15 @@ LOG.setLevel(logging.DEBUG)
def make_test_key() -> str:
"""Generate a test key."""
randomlength = 540
key = "ssh-rsa "
randompart = "".join(
random.choices(string.ascii_letters + string.digits, k=randomlength)
)
comment = " invalid-test-key"
return key + randompart + comment
private_key = generate_private_key()
return generate_public_key_string(private_key.public_key())
def create_client(
test_client: TestClient,
headers: dict[str, str],
name: str,
public_key: str | None = None,
description: str | None = None,
) -> Response:
"""Create client."""
if not public_key:
@ -47,50 +41,35 @@ def create_client(
"name": name,
"public_key": public_key,
}
create_response = test_client.post("/api/v1/clients", headers=headers, json=data)
if description:
data["description"] = description
create_response = test_client.post("/api/v1/clients", json=data)
return create_response
@pytest.fixture(name="session")
def session_fixture():
engine = create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
)
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
yield session
@pytest.fixture(name="token")
def token_fixture(session: Session):
"""Generate a token."""
token = create_test_token(session)
return token
@pytest.fixture(name="headers")
def headers_fixture(token: str) -> dict[str, str]:
"""Generate headers."""
return {"X-API-Token": token}
@pytest.fixture(name="test_client")
def test_client_fixture(session: Session):
def create_client_fixture(tmp_path: Path):
"""Test client fixture."""
def get_session_override():
return session
db_file = tmp_path / "backend.db"
print(f"DB File: {db_file.absolute()}")
settings = BackendSettings(db_url=f"sqlite:///{db_file.absolute()}")
app = create_backend_app(settings)
app.dependency_overrides[get_session] = get_session_override
test_client = TestClient(app)
token = create_test_token(settings)
test_client = TestClient(app, headers={"X-API-Token": token})
yield test_client
app.dependency_overrides.clear()
def test_missing_token(test_client: TestClient) -> None:
"""Test logging in with missing token."""
response = test_client.get("/api/v1/clients/")
# Save headers
old_headers = test_client.headers
test_client.headers = {}
response = test_client.get("/api/v1/clients/", headers={})
assert response.status_code == 422
test_client.headers = old_headers
def test_incorrect_token(test_client: TestClient) -> None:
@ -99,22 +78,24 @@ def test_incorrect_token(test_client: TestClient) -> None:
assert response.status_code == 401
def test_with_token(test_client: TestClient, token: str) -> None:
def test_with_token(test_client: TestClient) -> None:
"""Test with a valid token."""
response = test_client.get("/api/v1/clients/", headers={"X-API-Token": token})
response = test_client.get("/api/v1/clients/")
assert response.status_code == 200
assert len(response.json()) == 0
data = response.json()
assert data["total_results"] == 0
def test_create_client(test_client: TestClient, headers: dict[str, str]) -> None:
def test_create_client(test_client: TestClient) -> None:
"""Test creating a client."""
client_name = "test"
client_publickey = make_test_key()
create_response = create_client(test_client, headers, client_name, client_publickey)
create_response = create_client(test_client, client_name, client_publickey)
assert create_response.status_code == 200
response = test_client.get("/api/v1/clients/", headers=headers)
response = test_client.get("/api/v1/clients")
assert response.status_code == 200
clients = response.json()
clients_result = response.json()
clients = clients_result["clients"]
assert isinstance(clients, list)
client = clients[0]
assert isinstance(client, dict)
@ -122,72 +103,63 @@ def test_create_client(test_client: TestClient, headers: dict[str, str]) -> None
assert client.get("created_at") is not None
def test_delete_client(test_client: TestClient, headers: dict[str, str]) -> None:
def test_delete_client(test_client: TestClient) -> None:
"""Test creating a client."""
client_name = "test"
create_response = create_client(
test_client,
headers,
client_name,
)
assert create_response.status_code == 200
resp = test_client.delete("/api/v1/clients/test", headers=headers)
resp = test_client.delete("/api/v1/clients/test")
assert resp.status_code == 200
resp = test_client.get("/api/v1/clients/test", headers=headers)
resp = test_client.get("/api/v1/clients/test")
assert resp.status_code == 404
def test_add_secret(test_client: TestClient, headers: dict[str, str]) -> None:
def test_add_secret(test_client: TestClient) -> None:
"""Test adding a secret to a client."""
client_name = "test"
client_publickey = make_test_key()
create_response = create_client(
test_client,
headers,
client_name,
client_publickey,
)
assert create_response.status_code == 200
secret_name = "mysecret"
secret_value = "shhhh"
data = {"name": secret_name, "secret": secret_value}
response = test_client.post(
"/api/v1/clients/test/secrets/", headers=headers, json=data
)
data = {"name": secret_name, "secret": secret_value, "description": "A test secret"}
response = test_client.post("/api/v1/clients/test/secrets/", json=data)
assert response.status_code == 200
# Get it back
get_response = test_client.get(
"/api/v1/clients/test/secrets/mysecret", headers=headers
)
get_response = test_client.get("/api/v1/clients/test/secrets/mysecret")
assert get_response.status_code == 200
secret_body = get_response.json()
assert secret_body["name"] == data["name"]
assert secret_body["secret"] == data["secret"]
def test_delete_secret(test_client: TestClient, headers: dict[str, str]) -> None:
def test_delete_secret(test_client: TestClient) -> None:
"""Test deleting a secret."""
test_add_secret(test_client, headers)
resp = test_client.delete("/api/v1/clients/test/secrets/mysecret", headers=headers)
test_add_secret(test_client)
resp = test_client.delete("/api/v1/clients/test/secrets/mysecret")
assert resp.status_code == 200
get_response = test_client.get(
"/api/v1/clients/test/secrets/mysecret", headers=headers
)
get_response = test_client.get("/api/v1/clients/test/secrets/mysecret")
assert get_response.status_code == 404
def test_put_add_secret(test_client: TestClient, headers: dict[str, str]) -> None:
def test_put_add_secret(test_client: TestClient) -> None:
"""Test adding secret via PUT."""
# Use the test_create_client function to create a client.
test_create_client(test_client, headers)
test_create_client(test_client)
secret_name = "mysecret"
secret_value = "shhhh"
data = {"name": secret_name, "secret": secret_value}
data = {"name": secret_name, "secret": secret_value, "description": None}
response = test_client.put(
"/api/v1/clients/test/secrets/mysecret",
headers=headers,
json={"value": secret_value},
)
assert response.status_code == 200
@ -197,13 +169,12 @@ def test_put_add_secret(test_client: TestClient, headers: dict[str, str]) -> Non
assert response_model == data
def test_put_update_secret(test_client: TestClient, headers: dict[str, str]) -> None:
def test_put_update_secret(test_client: TestClient) -> None:
"""Test updating a client secret."""
test_add_secret(test_client, headers)
test_add_secret(test_client)
new_value = "itsasecret"
update_response = test_client.put(
"/api/v1/clients/test/secrets/mysecret",
headers=headers,
json={"value": new_value},
)
assert update_response.status_code == 200
@ -218,26 +189,25 @@ def test_put_update_secret(test_client: TestClient, headers: dict[str, str]) ->
assert "updated_at" in response_model
def test_audit_logging(test_client: TestClient, headers: dict[str, str]) -> None:
def test_audit_logging(test_client: TestClient) -> None:
"""Test audit logging."""
public_key = make_test_key()
create_client_resp = create_client(test_client, headers, "test", public_key)
create_client_resp = create_client(test_client, "test", public_key)
assert create_client_resp.status_code == 200
secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"}
for name, secret in secrets.items():
add_resp = test_client.post(
"/api/v1/clients/test/secrets/",
headers=headers,
json={"name": name, "secret": secret},
)
assert add_resp.status_code == 200
# Fetch the entire client.
get_client_resp = test_client.get("/api/v1/clients/test", headers=headers)
get_client_resp = test_client.get("/api/v1/clients/test")
assert get_client_resp.status_code == 200
# Fetch the audit log
audit_log_resp = test_client.get("/api/v1/audit/", headers=headers)
audit_log_resp = test_client.get("/api/v1/audit/")
assert audit_log_resp.status_code == 200
audit_logs = audit_log_resp.json()
assert len(audit_logs) > 0
@ -247,61 +217,60 @@ def test_audit_logging(test_client: TestClient, headers: dict[str, str]) -> None
assert audit_log is not None
def test_audit_log_filtering(
session: Session, test_client: TestClient, headers: dict[str, str]
) -> None:
"""Test audit log filtering."""
# Create a lot of test data, but just manually.
audit_log_amount = 150
entries: list[AuditLog] = []
for i in range(audit_log_amount):
client_id = i % 5
entries.append(
AuditLog(
operation="TEST",
object_id=str(i),
client_name=f"client-{client_id}",
message="Test Message",
)
)
# def test_audit_log_filtering(
# session: Session, test_client: TestClient
# ) -> None:
# """Test audit log filtering."""
# # Create a lot of test data, but just manually.
# audit_log_amount = 150
# entries: list[AuditLog] = []
# for i in range(audit_log_amount):
# client_id = i % 5
# entries.append(
# AuditLog(
# operation="TEST",
# object_id=str(i),
# client_name=f"client-{client_id}",
# message="Test Message",
# )
# )
session.add_all(entries)
session.commit()
# session.add_all(entries)
# session.commit()
# This should have generated a lot of audit messages
# # This should have generated a lot of audit messages
audit_path = "/api/v1/audit/"
audit_log_resp = test_client.get(audit_path, headers=headers)
assert audit_log_resp.status_code == 200
entries = audit_log_resp.json()
assert len(entries) == 100 # We get 100 at a time
# audit_path = "/api/v1/audit/"
# audit_log_resp = test_client.get(audit_path)
# assert audit_log_resp.status_code == 200
# entries = audit_log_resp.json()
# assert len(entries) == 100 # We get 100 at a time
audit_log_resp = test_client.get(
audit_path, headers=headers, params={"offset": 100}
)
entries = audit_log_resp.json()
assert len(entries) == 52 # There should be 50 + the two requests we made
# audit_log_resp = test_client.get(
# audit_path, params={"offset": 100}
# )
# entries = audit_log_resp.json()
# assert len(entries) == 52 # There should be 50 + the two requests we made
# Try to get a specific client
# There should be 30 log entries for each client.
audit_log_resp = test_client.get(
audit_path, headers=headers, params={"filter_client": "client-1"}
)
# # Try to get a specific client
# # There should be 30 log entries for each client.
# audit_log_resp = test_client.get(
# audit_path, params={"filter_client": "client-1"}
# )
entries = audit_log_resp.json()
assert len(entries) == 30
# entries = audit_log_resp.json()
# assert len(entries) == 30
def test_secret_invalidation(test_client: TestClient, headers: dict[str, str]) -> None:
def test_secret_invalidation(test_client: TestClient) -> None:
"""Test secret invalidation."""
initial_key = make_test_key()
create_client_resp = create_client(test_client, headers, "test", initial_key)
create_client_resp = create_client(test_client, "test", initial_key)
assert create_client_resp.status_code == 200
secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"}
for name, secret in secrets.items():
add_resp = test_client.post(
"/api/v1/clients/test/secrets/",
headers=headers,
json={"name": name, "secret": secret},
)
assert add_resp.status_code == 200
@ -311,13 +280,12 @@ def test_secret_invalidation(test_client: TestClient, headers: dict[str, str]) -
new_key = make_test_key()
update_resp = test_client.post(
"/api/v1/clients/test/public-key",
headers=headers,
json={"public_key": new_key},
)
assert update_resp.status_code == 200
# Fetch the client. The list of secrets should be empty.
get_resp = test_client.get("/api/v1/clients/test", headers=headers)
get_resp = test_client.get("/api/v1/clients/test")
assert get_resp.status_code == 200
client = get_resp.json()
secrets = client.get("secrets")
@ -325,14 +293,14 @@ def test_secret_invalidation(test_client: TestClient, headers: dict[str, str]) -
def test_client_default_policies(
test_client: TestClient, headers: dict[str, str]
test_client: TestClient,
) -> None:
"""Test client policies."""
public_key = make_test_key()
resp = create_client(test_client, headers, "test", public_key)
resp = create_client(test_client, "test")
assert resp.status_code == 200
# Fetch policies, should return *
resp = test_client.get("/api/v1/clients/test/policies/", headers=headers)
resp = test_client.get("/api/v1/clients/test/policies/")
assert resp.status_code == 200
policies = resp.json()
@ -340,21 +308,17 @@ def test_client_default_policies(
assert policies["sources"] == ["0.0.0.0/0", "::/0"]
def test_client_policy_update_one(
test_client: TestClient, headers: dict[str, str]
) -> None:
def test_client_policy_update_one(test_client: TestClient) -> None:
"""Update client policy with single policy."""
public_key = make_test_key()
resp = create_client(test_client, headers, "test", public_key)
resp = create_client(test_client, "test", public_key)
assert resp.status_code == 200
policy = ["192.0.2.1"]
resp = test_client.put(
"/api/v1/clients/test/policies/", headers=headers, json={"sources": policy}
)
resp = test_client.put("/api/v1/clients/test/policies/", json={"sources": policy})
assert resp.status_code == 200
resp = test_client.get("/api/v1/clients/test/policies/", headers=headers)
resp = test_client.get("/api/v1/clients/test/policies/")
assert resp.status_code == 200
policies = resp.json()
@ -362,22 +326,18 @@ def test_client_policy_update_one(
assert policies["sources"] == policy
def test_client_policy_update_advanced(
test_client: TestClient, headers: dict[str, str]
) -> None:
def test_client_policy_update_advanced(test_client: TestClient) -> None:
"""Test other policy update scenarios."""
public_key = make_test_key()
resp = create_client(test_client, headers, "test", public_key)
resp = create_client(test_client, "test", public_key)
assert resp.status_code == 200
policy = ["192.0.2.1", "198.18.0.0/24"]
resp = test_client.put(
"/api/v1/clients/test/policies/", headers=headers, json={"sources": policy}
)
resp = test_client.put("/api/v1/clients/test/policies/", json={"sources": policy})
assert resp.status_code == 200
resp = test_client.get("/api/v1/clients/test/policies/", headers=headers)
resp = test_client.get("/api/v1/clients/test/policies/")
assert resp.status_code == 200
policies = resp.json()
@ -389,13 +349,11 @@ def test_client_policy_update_advanced(
policy = ["obviosly_wrong"]
resp = test_client.put(
"/api/v1/clients/test/policies/", headers=headers, json={"sources": policy}
)
resp = test_client.put("/api/v1/clients/test/policies/", json={"sources": policy})
assert resp.status_code == 422
# Check that the old value is still there
resp = test_client.get("/api/v1/clients/test/policies/", headers=headers)
resp = test_client.get("/api/v1/clients/test/policies/")
assert resp.status_code == 200
policies = resp.json()
@ -407,18 +365,14 @@ def test_client_policy_update_advanced(
#
def test_client_policy_update_unset(
test_client: TestClient, headers: dict[str, str]
) -> None:
def test_client_policy_update_unset(test_client: TestClient) -> None:
"""Test clearing the client policy."""
public_key = make_test_key()
resp = create_client(test_client, headers, "test", public_key)
resp = create_client(test_client, "test", public_key)
assert resp.status_code == 200
policy = ["192.0.2.1", "198.18.0.0/24"]
resp = test_client.put(
"/api/v1/clients/test/policies/", headers=headers, json={"sources": policy}
)
resp = test_client.put("/api/v1/clients/test/policies/", json={"sources": policy})
assert resp.status_code == 200
policies = resp.json()
@ -428,11 +382,158 @@ def test_client_policy_update_unset(
# Now we clear the policies
resp = test_client.put(
"/api/v1/clients/test/policies/", headers=headers, json={"sources": []}
)
resp = test_client.put("/api/v1/clients/test/policies/", json={"sources": []})
assert resp.status_code == 200
policies = resp.json()
assert policies["sources"] == ["0.0.0.0/0", "::/0"]
def test_client_update(test_client: TestClient) -> None:
"""Test generic update of a client."""
public_key = make_test_key()
resp = create_client(test_client, "test", public_key, "PRE")
assert resp.status_code == 200
resp = test_client.get("/api/v1/clients/test")
assert resp.status_code == 200
client_data = resp.json()
assert client_data["description"] == "PRE"
# Update the description
new_client_data = {
"name": "test",
"description": "POST",
"public_key": client_data["public_key"],
}
resp = test_client.put("/api/v1/clients/test", json=new_client_data)
assert resp.status_code == 200
client_data = resp.json()
assert client_data["description"] == "POST"
resp = test_client.get("/api/v1/clients/test")
assert resp.status_code == 200
client_data = resp.json()
assert client_data["description"] == "POST"
def test_get_secret_list(test_client: TestClient) -> None:
"""Test the secret to client map view."""
# Make 4 clients
for x in range(4):
public_key = make_test_key()
create_client(test_client, f"client-{x}", public_key)
# Create a secret that only this client has.
resp = test_client.put(
f"/api/v1/clients/client-{x}/secrets/client-{x}", json={"value": "SECRET"}
)
assert resp.status_code == 200
# Create a secret that all of them have.
resp = test_client.put(
f"/api/v1/clients/client-{x}/secrets/commonsecret", json={"value": "SECRET"}
)
assert resp.status_code == 200
# Get the secret list
resp = test_client.get("/api/v1/secrets/")
assert resp.status_code == 200
data = resp.json()
assert isinstance(data, list)
assert len(data) == 5
for entry in data:
if entry["name"] == "commonsecret":
assert len(entry["clients"]) == 4
else:
assert len(entry["clients"]) == 1
assert entry["clients"][0] == entry["name"]
def test_get_secret_clients(test_client: TestClient) -> None:
"""Get the clients for a single secret."""
for x in range(4):
public_key = make_test_key()
create_client(test_client, f"client-{x}", public_key)
# Create a secret that every second of them have.
if x % 2 == 1:
continue
resp = test_client.put(
f"/api/v1/clients/client-{x}/secrets/commonsecret", json={"value": "SECRET"}
)
assert resp.status_code == 200
resp = test_client.get("/api/v1/secrets/commonsecret")
assert resp.status_code == 200
data = resp.json()
assert data["name"] == "commonsecret"
assert "client-0" in data["clients"]
assert "client-1" not in data["clients"]
assert len(data["clients"]) == 2
def test_searching(test_client: TestClient) -> None:
"""Test searching."""
for x in range(4):
# Create four clients
create_client(test_client, f"client-{x}")
# Create one with a different name.
create_client(test_client, "othername")
# Search for a specific one.
resp = test_client.get("/api/v1/clients/", params={"name": "othername"})
assert resp.status_code == 200
result = resp.json()
assert result["total_results"] == 1
assert result["clients"][0]["name"] == "othername"
client_id = result["clients"][0]["id"]
# Search by ID
resp = test_client.get("/api/v1/clients/", params={"id": client_id})
assert resp.status_code == 200
result = resp.json()
assert result["total_results"] == 1
assert result["clients"][0]["name"] == "othername"
# Search for the four similarly named ones
resp = test_client.get("/api/v1/clients/", params={"name__like": "client-%"})
assert resp.status_code == 200
result = resp.json()
assert result["total_results"] == 4
assert str(result["clients"][0]["name"]).startswith("client-")
def test_operations_with_id(test_client: TestClient) -> None:
"""Test operations using ID instead of name."""
create_client(test_client, "test")
resp = test_client.get("/api/v1/clients/")
assert resp.status_code == 200
data = resp.json()
client = data["clients"][0]
client_id = client["id"]
resp = test_client.get(f"/api/v1/clients/{client_id}")
assert resp.status_code == 200
data = resp.json()
assert data["name"] == "test"
def test_write_audit_log(test_client: TestClient) -> None:
"""Test writing to the audit log."""
params = {
"object": "Test",
"operation": "TEST",
"object_id": "Something",
"message": "Test Message"
}
resp = test_client.post("/api/v1/audit", json=params)
assert resp.status_code == 200
resp = test_client.get("/api/v1/audit")
assert resp.status_code == 200
data = resp.json()
entry = data[0]
for key, value in params.items():
assert entry[key] == value