Refactor database layer and auditing
This commit is contained in:
@ -3,22 +3,22 @@ from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
from sqlmodel import create_engine
|
||||
|
||||
from alembic import context
|
||||
from sshecret_backend.models import *
|
||||
|
||||
def get_database_url() -> str:
|
||||
"""Get database URL."""
|
||||
if db_file := os.getenv("SSHECRET_BACKEND_DB"):
|
||||
return f"sqlite:///{db_file}"
|
||||
return "sqlite:///sshecret.db"
|
||||
|
||||
from sshecret_backend.models import Base
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
|
||||
def get_database_url() -> str | None:
|
||||
"""Get database URL."""
|
||||
if db_file := os.getenv("SSHECRET_BACKEND_DB"):
|
||||
return f"sqlite:///{db_file}"
|
||||
return config.get_main_option("sqlalchemy.url")
|
||||
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
@ -28,8 +28,7 @@ if config.config_file_name is not None:
|
||||
# for 'autogenerate' support
|
||||
# from myapp import mymodel
|
||||
# target_metadata = mymodel.Base.metadata
|
||||
#target_metadata = None
|
||||
target_metadata = SQLModel.metadata
|
||||
target_metadata = Base.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
@ -68,7 +67,11 @@ def run_migrations_online() -> None:
|
||||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
connectable = create_engine(get_database_url())
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
|
||||
@ -9,7 +9,6 @@ from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
||||
@ -5,13 +5,14 @@
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from fastapi import APIRouter, Depends, Request, Query
|
||||
from sqlmodel import Session, col, func, select
|
||||
from sqlalchemy import desc
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Annotated
|
||||
|
||||
from sshecret_backend.models import AuditLog
|
||||
from sshecret_backend.types import DBSessionDep
|
||||
from sshecret_backend.view_models import AuditInfo
|
||||
from sshecret_backend.view_models import AuditInfo, AuditView
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -21,7 +22,7 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
"""Construct audit sub-api."""
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/audit/", response_model=list[AuditLog])
|
||||
@router.get("/audit/", response_model=list[AuditView])
|
||||
async def get_audit_logs(
|
||||
request: Request,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
@ -29,35 +30,37 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
limit: Annotated[int, Query(le=100)] = 100,
|
||||
filter_client: Annotated[str | None, Query()] = None,
|
||||
filter_subsystem: Annotated[str | None, Query()] = None,
|
||||
) -> Sequence[AuditLog]:
|
||||
) -> Sequence[AuditView]:
|
||||
"""Get audit logs."""
|
||||
#audit.audit_access_audit_log(session, request)
|
||||
statement = select(AuditLog).offset(offset).limit(limit).order_by(desc(col(AuditLog.timestamp)))
|
||||
statement = select(AuditLog).offset(offset).limit(limit).order_by(AuditLog.timestamp.desc())
|
||||
if filter_client:
|
||||
statement = statement.where(AuditLog.client_name == filter_client)
|
||||
|
||||
if filter_subsystem:
|
||||
statement = statement.where(AuditLog.subsystem == filter_subsystem)
|
||||
|
||||
results = session.exec(statement).all()
|
||||
return results
|
||||
LogAdapt = TypeAdapter(list[AuditView])
|
||||
results = session.scalars(statement).all()
|
||||
return LogAdapt.validate_python(results, from_attributes=True)
|
||||
|
||||
|
||||
@router.post("/audit/")
|
||||
async def add_audit_log(
|
||||
request: Request,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
entry: AuditLog,
|
||||
) -> AuditLog:
|
||||
entry: AuditView,
|
||||
) -> AuditView:
|
||||
"""Add entry to audit log."""
|
||||
audit_log = AuditLog.model_validate(entry.model_dump(exclude_none=True))
|
||||
audit_log = AuditLog(**entry.model_dump(exclude_none=True))
|
||||
session.add(audit_log)
|
||||
session.commit()
|
||||
return audit_log
|
||||
return AuditView.model_validate(audit_log, from_attributes=True)
|
||||
|
||||
@router.get("/audit/info")
|
||||
async def get_audit_info(request: Request, session: Annotated[Session, Depends(get_db_session)]) -> AuditInfo:
|
||||
"""Get audit info."""
|
||||
audit_count = session.exec(select(func.count('*')).select_from(AuditLog)).one()
|
||||
audit_count = session.scalars(select(func.count('*')).select_from(AuditLog)).one()
|
||||
return AuditInfo(entries=audit_count)
|
||||
|
||||
|
||||
|
||||
@ -6,11 +6,11 @@ import uuid
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from sqlmodel import Session, col, select
|
||||
from sqlalchemy import func
|
||||
from typing import Annotated, Self, TypeVar
|
||||
from typing import Annotated, Any, Self, TypeVar, cast
|
||||
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import Select
|
||||
from sshecret_backend.types import DBSessionDep
|
||||
from sshecret_backend.models import Client, ClientSecret
|
||||
from sshecret_backend.view_models import (
|
||||
@ -55,8 +55,8 @@ T = TypeVar("T")
|
||||
|
||||
|
||||
def filter_client_statement(
|
||||
statement: SelectOfScalar[T], params: ClientListParams, ignore_limits: bool = False
|
||||
) -> SelectOfScalar[T]:
|
||||
statement: Select[Any], params: ClientListParams, ignore_limits: bool = False
|
||||
) -> Select[Any]:
|
||||
"""Filter a statement with the provided params."""
|
||||
if params.id:
|
||||
statement = statement.where(Client.id == params.id)
|
||||
@ -64,9 +64,9 @@ def filter_client_statement(
|
||||
if params.name:
|
||||
statement = statement.where(Client.name == params.name)
|
||||
elif params.name__like:
|
||||
statement = statement.where(col(Client.name).like(params.name__like))
|
||||
statement = statement.where(Client.name.like(params.name__like))
|
||||
elif params.name__contains:
|
||||
statement = statement.where(col(Client.name).contains(params.name__contains))
|
||||
statement = statement.where(Client.name.contains(params.name__contains))
|
||||
|
||||
if ignore_limits:
|
||||
return statement
|
||||
@ -86,13 +86,13 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
"""Get clients."""
|
||||
# Get total results first
|
||||
count_statement = select(func.count("*")).select_from(Client)
|
||||
count_statement = filter_client_statement(count_statement, filter_query, True)
|
||||
count_statement = cast(Select[tuple[int]], filter_client_statement(count_statement, filter_query, True))
|
||||
|
||||
total_results = session.exec(count_statement).one()
|
||||
total_results = session.scalars(count_statement).one()
|
||||
|
||||
statement = filter_client_statement(select(Client), filter_query, False)
|
||||
|
||||
results = session.exec(statement)
|
||||
results = session.scalars(statement)
|
||||
remainder = total_results - filter_query.offset - filter_query.limit
|
||||
if remainder < 0:
|
||||
remainder = 0
|
||||
@ -170,13 +170,12 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
client.public_key = client_update.public_key
|
||||
for secret in session.exec(
|
||||
for secret in session.scalars(
|
||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||
).all():
|
||||
LOG.debug("Invalidated secret %s", secret.id)
|
||||
secret.invalidated = True
|
||||
secret.client_id = None
|
||||
secret.client = None
|
||||
|
||||
session.add(client)
|
||||
session.refresh(client)
|
||||
@ -206,13 +205,12 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
public_key_updated = False
|
||||
if client_update.public_key != client.public_key:
|
||||
public_key_updated = True
|
||||
for secret in session.exec(
|
||||
for secret in session.scalars(
|
||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||
).all():
|
||||
LOG.debug("Invalidated secret %s", secret.id)
|
||||
secret.invalidated = True
|
||||
secret.client_id = None
|
||||
secret.client = None
|
||||
|
||||
session.add(client)
|
||||
session.commit()
|
||||
|
||||
@ -4,7 +4,8 @@ import re
|
||||
import uuid
|
||||
import bcrypt
|
||||
|
||||
from sqlmodel import Session, select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from sshecret_backend.models import Client
|
||||
|
||||
@ -20,13 +21,13 @@ def verify_token(token: str, stored_hash: str) -> bool:
|
||||
async def get_client_by_name(session: Session, name: str) -> Client | None:
|
||||
"""Get client by name."""
|
||||
client_filter = select(Client).where(Client.name == name)
|
||||
client_results = session.exec(client_filter)
|
||||
client_results = session.scalars(client_filter)
|
||||
return client_results.first()
|
||||
|
||||
async def get_client_by_id(session: Session, id: uuid.UUID) -> Client | None:
|
||||
"""Get client by name."""
|
||||
client_filter = select(Client).where(Client.id == id)
|
||||
client_results = session.exec(client_filter)
|
||||
client_results = session.scalars(client_filter)
|
||||
return client_results.first()
|
||||
|
||||
async def get_client_by_id_or_name(session: Session, id_or_name: str) -> Client | None:
|
||||
|
||||
@ -4,7 +4,8 @@
|
||||
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlmodel import Session, select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Annotated
|
||||
|
||||
from sshecret_backend.models import ClientAccessPolicy
|
||||
@ -54,7 +55,7 @@ def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
# Remove old policies.
|
||||
policies = session.exec(
|
||||
policies = session.scalars(
|
||||
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
|
||||
).all()
|
||||
deleted_policies: list[ClientAccessPolicy] = []
|
||||
|
||||
@ -5,7 +5,8 @@
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlmodel import Session, select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Annotated
|
||||
|
||||
from sshecret_backend.models import Client, ClientSecret
|
||||
@ -34,7 +35,7 @@ async def lookup_client_secret(
|
||||
.where(ClientSecret.client_id == client.id)
|
||||
.where(ClientSecret.name == name)
|
||||
)
|
||||
results = session.exec(statement)
|
||||
results = session.scalars(statement)
|
||||
return results.first()
|
||||
|
||||
|
||||
@ -165,7 +166,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
) -> list[ClientSecretList]:
|
||||
"""Get a list of all secrets and which clients have them."""
|
||||
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
|
||||
for client_secret in session.exec(select(ClientSecret)).all():
|
||||
for client_secret in session.scalars(select(ClientSecret)).all():
|
||||
if not client_secret.client:
|
||||
if client_secret.name not in client_secret_map:
|
||||
client_secret_map[client_secret.name] = []
|
||||
@ -182,7 +183,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
) -> list[ClientSecretDetailList]:
|
||||
"""Get a list of all secrets and which clients have them."""
|
||||
client_secrets: dict[str, ClientSecretDetailList] = {}
|
||||
for client_secret in session.exec(select(ClientSecret)).all():
|
||||
for client_secret in session.scalars(select(ClientSecret)).all():
|
||||
|
||||
if client_secret.name not in client_secrets:
|
||||
client_secrets[client_secret.name] = ClientSecretDetailList(name=client_secret.name)
|
||||
@ -202,7 +203,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
) -> ClientSecretList:
|
||||
"""Get a list of which clients has a named secret."""
|
||||
clients: list[str] = []
|
||||
for client_secret in session.exec(
|
||||
for client_secret in session.scalars(
|
||||
select(ClientSecret).where(ClientSecret.name == name)
|
||||
).all():
|
||||
if not client_secret.client:
|
||||
@ -219,7 +220,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
) -> ClientSecretDetailList:
|
||||
"""Get a list of which clients has a named secret."""
|
||||
detail_list = ClientSecretDetailList(name=name)
|
||||
for client_secret in session.exec(
|
||||
for client_secret in session.scalars(
|
||||
select(ClientSecret).where(ClientSecret.name == name)
|
||||
).all():
|
||||
if not client_secret.client:
|
||||
|
||||
@ -2,9 +2,10 @@
|
||||
|
||||
from collections.abc import Sequence
|
||||
from fastapi import Request
|
||||
from sqlmodel import Session, select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy
|
||||
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy, Operation, SubSystem
|
||||
|
||||
|
||||
def _get_origin(request: Request) -> str | None:
|
||||
@ -22,7 +23,7 @@ def _write_audit_log(
|
||||
"""Write the audit log."""
|
||||
origin = _get_origin(request)
|
||||
entry.origin = origin
|
||||
entry.subsystem = "backend"
|
||||
entry.subsystem = SubSystem.BACKEND
|
||||
session.add(entry)
|
||||
if commit:
|
||||
session.commit()
|
||||
@ -33,7 +34,7 @@ def audit_create_client(
|
||||
) -> None:
|
||||
"""Log the creation of a client."""
|
||||
entry = AuditLog(
|
||||
operation="CREATE",
|
||||
operation=Operation.CREATE,
|
||||
client_id=client.id,
|
||||
client_name=client.name,
|
||||
message="Client Created",
|
||||
@ -46,7 +47,7 @@ def audit_delete_client(
|
||||
) -> None:
|
||||
"""Log the creation of a client."""
|
||||
entry = AuditLog(
|
||||
operation="CREATE",
|
||||
operation=Operation.CREATE,
|
||||
client_id=client.id,
|
||||
client_name=client.name,
|
||||
message="Client deleted",
|
||||
@ -63,9 +64,9 @@ def audit_create_secret(
|
||||
) -> None:
|
||||
"""Audit a create secret event."""
|
||||
entry = AuditLog(
|
||||
operation="CREATE",
|
||||
object="ClientSecret",
|
||||
object_id=str(secret.id),
|
||||
operation=Operation.CREATE,
|
||||
secret_id=secret.id,
|
||||
secret_name=secret.name,
|
||||
client_id=client.id,
|
||||
client_name=client.name,
|
||||
message="Added secret to client",
|
||||
@ -81,13 +82,13 @@ def audit_remove_policy(
|
||||
commit: bool = True,
|
||||
) -> None:
|
||||
"""Audit removal of policy."""
|
||||
data = {"object": "ClientAccessPolicy", "object_id": str(policy.id)}
|
||||
entry = AuditLog(
|
||||
operation="DELETE",
|
||||
object="ClientAccessPolicy",
|
||||
object_id=str(policy.id),
|
||||
operation=Operation.DELETE,
|
||||
client_id=client.id,
|
||||
client_name=client.name,
|
||||
message="Deleted client policy",
|
||||
data=data,
|
||||
)
|
||||
_write_audit_log(session, request, entry, commit)
|
||||
|
||||
@ -100,13 +101,13 @@ def audit_update_policy(
|
||||
commit: bool = True,
|
||||
) -> None:
|
||||
"""Audit update of policy."""
|
||||
data: dict[str, str] = {"object": "ClientAccessPolicy", "object_id": str(policy.id)}
|
||||
entry = AuditLog(
|
||||
operation="CREATE",
|
||||
object="ClientAccessPolicy",
|
||||
object_id=str(policy.id),
|
||||
client_id=client.id,
|
||||
operation=Operation.CREATE,
|
||||
client_name=client.name,
|
||||
client_id=client.id,
|
||||
message="Updated client policy",
|
||||
data=data,
|
||||
)
|
||||
_write_audit_log(session, request, entry, commit)
|
||||
|
||||
@ -119,11 +120,10 @@ def audit_update_client(
|
||||
) -> None:
|
||||
"""Audit an update secret event."""
|
||||
entry = AuditLog(
|
||||
operation="UPDATE",
|
||||
object="Client",
|
||||
operation=Operation.UPDATE,
|
||||
client_id=client.id,
|
||||
client_name=client.name,
|
||||
message="Client updated",
|
||||
message="Client data updated",
|
||||
)
|
||||
_write_audit_log(session, request, entry, commit)
|
||||
|
||||
@ -137,11 +137,11 @@ def audit_update_secret(
|
||||
) -> None:
|
||||
"""Audit an update secret event."""
|
||||
entry = AuditLog(
|
||||
operation="UPDATE",
|
||||
object="ClientSecret",
|
||||
object_id=str(secret.id),
|
||||
operation=Operation.UPDATE,
|
||||
client_id=client.id,
|
||||
client_name=client.name,
|
||||
secret_name=secret.name,
|
||||
secret_id=secret.id,
|
||||
message="Secret value updated",
|
||||
)
|
||||
_write_audit_log(session, request, entry, commit)
|
||||
@ -155,8 +155,7 @@ def audit_invalidate_secrets(
|
||||
) -> None:
|
||||
"""Audit Invalidate client secrets."""
|
||||
entry = AuditLog(
|
||||
operation="INVALIDATE",
|
||||
object="ClientSecret",
|
||||
operation=Operation.UPDATE,
|
||||
client_name=client.name,
|
||||
client_id=client.id,
|
||||
message="Client public-key changed. All secrets invalidated.",
|
||||
@ -173,9 +172,9 @@ def audit_delete_secret(
|
||||
) -> None:
|
||||
"""Audit Delete client secrets."""
|
||||
entry = AuditLog(
|
||||
operation="DELETE",
|
||||
object="ClientSecret",
|
||||
object_id=str(secret.id),
|
||||
operation=Operation.DELETE,
|
||||
secret_name=secret.name,
|
||||
secret_id=secret.id,
|
||||
client_name=client.name,
|
||||
client_id=client.id,
|
||||
message="Deleted secret.",
|
||||
@ -195,7 +194,7 @@ def audit_access_secrets(
|
||||
With no secrets provided, all secrets of the client will be resolved.
|
||||
"""
|
||||
if not secrets:
|
||||
secrets = session.exec(
|
||||
secrets = session.scalars(
|
||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||
).all()
|
||||
|
||||
@ -215,37 +214,21 @@ def audit_access_secret(
|
||||
) -> None:
|
||||
"""Audit that someone accessed one secrets."""
|
||||
entry = AuditLog(
|
||||
operation="ACCESS",
|
||||
operation=Operation.READ,
|
||||
message="Secret was viewed",
|
||||
object="ClientSecret",
|
||||
object_id=str(secret.id),
|
||||
secret_name=secret.name,
|
||||
secret_id=secret.id,
|
||||
client_id=client.id,
|
||||
client_name=client.name,
|
||||
)
|
||||
_write_audit_log(session, request, entry, commit)
|
||||
|
||||
|
||||
def audit_access_audit_log(
|
||||
session: Session, request: Request, commit: bool = True
|
||||
) -> None:
|
||||
"""Audit access to the audit log.
|
||||
|
||||
Because why not...
|
||||
"""
|
||||
entry = AuditLog(
|
||||
operation="ACCESS",
|
||||
message="Audit log was viewed",
|
||||
object="AuditLog",
|
||||
)
|
||||
_write_audit_log(session, request, entry, commit)
|
||||
|
||||
|
||||
def audit_client_secret_list(
|
||||
session: Session, request: Request, commit: bool = True
|
||||
) -> None:
|
||||
"""Audit a list of all secrets."""
|
||||
entry = AuditLog(
|
||||
operation="ACCESS",
|
||||
operation=Operation.READ,
|
||||
message="All secret names and their clients was viewed",
|
||||
)
|
||||
_write_audit_log(session, request, entry, commit)
|
||||
|
||||
@ -3,11 +3,11 @@
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
import bcrypt
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from .api import get_audit_api, get_clients_api, get_policy_api, get_secrets_api
|
||||
from .auth import verify_token
|
||||
from .models import (
|
||||
APIClient,
|
||||
)
|
||||
@ -18,13 +18,6 @@ LOG = logging.getLogger(__name__)
|
||||
API_VERSION = "v1"
|
||||
|
||||
|
||||
def verify_token(token: str, stored_hash: str) -> bool:
|
||||
"""Verify token."""
|
||||
token_bytes = token.encode("utf-8")
|
||||
stored_bytes = stored_hash.encode("utf-8")
|
||||
return bcrypt.checkpw(token_bytes, stored_bytes)
|
||||
|
||||
|
||||
def get_backend_api(
|
||||
get_db_session: DBSessionDep,
|
||||
) -> APIRouter:
|
||||
@ -37,7 +30,7 @@ def get_backend_api(
|
||||
"""Validate token."""
|
||||
LOG.debug("Validating token %s", x_api_token)
|
||||
statement = select(APIClient)
|
||||
results = session.exec(statement)
|
||||
results = session.scalars(statement)
|
||||
valid = False
|
||||
for result in results:
|
||||
if verify_token(x_api_token, result.token):
|
||||
|
||||
@ -3,15 +3,24 @@
|
||||
import code
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
from dotenv import load_dotenv
|
||||
from typing import Literal, cast
|
||||
|
||||
import click
|
||||
from sqlmodel import Session, col, func, select
|
||||
import uvicorn
|
||||
from dotenv import load_dotenv
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .db import get_engine, create_api_token
|
||||
|
||||
from .models import Client, ClientSecret, ClientAccessPolicy, AuditLog, APIClient, init_db
|
||||
from .db import create_api_token, get_engine, hash_token
|
||||
from .models import (
|
||||
APIClient,
|
||||
AuditLog,
|
||||
Client,
|
||||
ClientAccessPolicy,
|
||||
ClientSecret,
|
||||
SubSystem,
|
||||
init_db,
|
||||
)
|
||||
from .settings import BackendSettings
|
||||
|
||||
DEFAULT_LISTEN = "127.0.0.1"
|
||||
@ -21,22 +30,44 @@ WORKDIR = Path(os.getcwd())
|
||||
|
||||
load_dotenv()
|
||||
|
||||
def generate_token(settings: BackendSettings) -> str:
|
||||
|
||||
def generate_token(
|
||||
settings: BackendSettings, subsystem: Literal["admin", "sshd"]
|
||||
) -> str:
|
||||
"""Generate a token."""
|
||||
engine = get_engine(settings.db_url)
|
||||
init_db(engine)
|
||||
with Session(engine) as session:
|
||||
token = create_api_token(session, True)
|
||||
token = create_api_token(session, subsystem)
|
||||
return token
|
||||
|
||||
def count_tokens(settings: BackendSettings) -> int:
|
||||
"""Count the amount of tokens created."""
|
||||
|
||||
def add_system_tokens(settings: BackendSettings) -> None:
|
||||
"""Add token for subsystems."""
|
||||
if not settings.admin_token and not settings.sshd_token:
|
||||
# Tokens should be generated manually.
|
||||
return
|
||||
|
||||
engine = get_engine(settings.db_url)
|
||||
init_db(engine)
|
||||
tokens: list[tuple[str, SubSystem]] = []
|
||||
if admin_token := settings.admin_token:
|
||||
tokens.append((admin_token, SubSystem.ADMIN))
|
||||
if sshd_token := settings.sshd_token:
|
||||
tokens.append((sshd_token, SubSystem.SSHD))
|
||||
with Session(engine) as session:
|
||||
count = session.exec(select(func.count("*")).select_from(APIClient)).one()
|
||||
for token, subsystem in tokens:
|
||||
hashed_token = hash_token(token)
|
||||
if existing := session.scalars(
|
||||
select(APIClient).where(APIClient.subsystem == subsystem)
|
||||
).first():
|
||||
existing.token = hashed_token
|
||||
else:
|
||||
new_token = APIClient(token=hashed_token, subsystem=subsystem)
|
||||
session.add(new_token)
|
||||
|
||||
return count
|
||||
session.commit()
|
||||
click.echo("Generated system tokens.")
|
||||
|
||||
|
||||
@click.group()
|
||||
@ -49,27 +80,30 @@ def cli(ctx: click.Context, database: str) -> None:
|
||||
else:
|
||||
settings = BackendSettings()
|
||||
|
||||
add_system_tokens(settings)
|
||||
|
||||
if settings.generate_initial_tokens:
|
||||
if count_tokens(settings) == 0:
|
||||
click.echo("Creating initial tokens for admin and sshd.")
|
||||
admin_token = generate_token(settings)
|
||||
sshd_token = generate_token(settings)
|
||||
click.echo(f"Admin token: {admin_token}")
|
||||
click.echo(f"SSHD token: {sshd_token}")
|
||||
# if settings.generate_initial_tokens:
|
||||
# if count_tokens(settings) == 0:
|
||||
# click.echo("Creating initial tokens for admin and sshd.")
|
||||
# admin_token = generate_token(settings)
|
||||
# sshd_token = generate_token(settings)
|
||||
# click.echo(f"Admin token: {admin_token}")
|
||||
# click.echo(f"SSHD token: {sshd_token}")
|
||||
|
||||
ctx.obj = settings
|
||||
|
||||
|
||||
@cli.command("generate-token")
|
||||
@click.argument("subsystem", type=click.Choice(["sshd", "admin"]))
|
||||
@click.pass_context
|
||||
def cli_generate_token(ctx: click.Context) -> None:
|
||||
"""Generate a token."""
|
||||
def cli_generate_token(ctx: click.Context, subsystem: Literal["sshd", "admin"]) -> None:
|
||||
"""Generate a token for a subsystem.."""
|
||||
settings = cast(BackendSettings, ctx.obj)
|
||||
token = generate_token(settings)
|
||||
token = generate_token(settings, subsystem)
|
||||
click.echo("Generated api token:")
|
||||
click.echo(token)
|
||||
|
||||
|
||||
@cli.command("run")
|
||||
@click.option("--host", default="127.0.0.1")
|
||||
@click.option("--port", default=8022, type=click.INT)
|
||||
@ -77,7 +111,10 @@ def cli_generate_token(ctx: click.Context) -> None:
|
||||
@click.option("--workers", type=click.INT)
|
||||
def cli_run(host: str, port: int, dev: bool, workers: int | None) -> None:
|
||||
"""Run the server."""
|
||||
uvicorn.run("sshecret_backend.main:app", host=host, port=port, reload=dev, workers=workers)
|
||||
uvicorn.run(
|
||||
"sshecret_backend.main:app", host=host, port=port, reload=dev, workers=workers
|
||||
)
|
||||
|
||||
|
||||
@cli.command("repl")
|
||||
@click.pass_context
|
||||
|
||||
@ -2,57 +2,109 @@
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
import sqlite3
|
||||
|
||||
from collections.abc import Generator, Callable
|
||||
from pathlib import Path
|
||||
from sqlalchemy import Engine
|
||||
from sqlmodel import Session, create_engine, text
|
||||
import bcrypt
|
||||
from typing import Literal
|
||||
from sqlalchemy import create_engine, Engine, event, select
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
|
||||
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
|
||||
from sqlalchemy.engine import URL
|
||||
|
||||
|
||||
from .models import APIClient
|
||||
from .auth import hash_token, verify_token
|
||||
from .models import APIClient, SubSystem
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_database(
|
||||
db_url: URL | str,
|
||||
db_url: URL,
|
||||
) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]:
|
||||
"""Setup database."""
|
||||
|
||||
engine = create_engine(db_url, echo=False)
|
||||
with engine.connect() as connection:
|
||||
connection.execute(text("PRAGMA foreign_keys=ON")) # for SQLite only
|
||||
engine = get_engine(db_url)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, future=True)
|
||||
|
||||
def get_db_session() -> Generator[Session, None, None]:
|
||||
"""Get DB Session."""
|
||||
with Session(engine) as session:
|
||||
session = SessionLocal(bind=engine)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
return engine, get_db_session
|
||||
|
||||
|
||||
def get_engine(url: URL, echo: bool = False) -> Engine:
|
||||
"""Initialize the engine."""
|
||||
engine = create_engine(url, echo=echo)
|
||||
with engine.connect() as connection:
|
||||
connection.execute(text("PRAGMA foreign_keys=ON")) # for SQLite only
|
||||
engine = create_engine(url, echo=echo, future=True)
|
||||
if url.drivername.startswith("sqlite"):
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
def set_sqlite_pragma(
|
||||
dbapi_connection: sqlite3.Connection, _connection_record: object
|
||||
) -> None:
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
def create_api_token(session: Session, read_write: bool) -> str:
|
||||
"""Create API token."""
|
||||
token = secrets.token_urlsafe(32)
|
||||
pwbytes = token.encode("utf-8")
|
||||
salt = bcrypt.gensalt()
|
||||
hashed_bytes = bcrypt.hashpw(password=pwbytes, salt=salt)
|
||||
hashed = hashed_bytes.decode()
|
||||
def get_async_engine(url: URL, echo: bool = False) -> AsyncEngine:
|
||||
"""Get an async engine."""
|
||||
engine = create_async_engine(url, echo=echo, future=True)
|
||||
if url.drivername.startswith("sqlite+"):
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
def set_sqlite_pragma(
|
||||
dbapi_connection: sqlite3.Connection, _connection_record: object
|
||||
) -> None:
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
|
||||
def create_api_token_with_value(session: Session, token: str, subsystem: Literal["admin", "sshd"]) -> None:
|
||||
"""Create API token with a given value."""
|
||||
|
||||
existing = session.scalars(select(APIClient).where(APIClient.subsystem == SubSystem(subsystem))).first()
|
||||
if existing:
|
||||
if verify_token(token, existing.token):
|
||||
LOG.info("Token is up to date.")
|
||||
return
|
||||
LOG.info("Updating token value for subsystem %s", subsystem)
|
||||
hashed = hash_token(token)
|
||||
existing.token=hashed
|
||||
session.commit()
|
||||
return
|
||||
|
||||
LOG.info("No existing token found. Creating new")
|
||||
hashed = hash_token(token)
|
||||
api_token = APIClient(token=hashed, subsystem=SubSystem(subsystem))
|
||||
|
||||
api_token = APIClient(token=hashed, read_write=read_write)
|
||||
session.add(api_token)
|
||||
session.commit()
|
||||
|
||||
def create_api_token(session: Session, subsystem: Literal["admin", "sshd", "test"], recreate: bool = False) -> str:
|
||||
"""Create API token."""
|
||||
subsys = SubSystem(subsystem)
|
||||
token = secrets.token_urlsafe(32)
|
||||
hashed = hash_token(token)
|
||||
if existing := session.scalars(select(APIClient).where(APIClient.subsystem == subsys)).first():
|
||||
if not recreate:
|
||||
raise RuntimeError("Error: A token already exist for this subsystem.")
|
||||
existing.token = hashed
|
||||
else:
|
||||
api_token = APIClient(token=hashed, subsystem=subsys)
|
||||
session.add(api_token)
|
||||
session.commit()
|
||||
|
||||
return token
|
||||
|
||||
@ -7,128 +7,182 @@ This might require some changes to these schemas.
|
||||
|
||||
"""
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlmodel import JSON, Column, DateTime, Field, Relationship, SQLModel
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Client(SQLModel, table=True):
|
||||
"""Client model."""
|
||||
class SubSystem(enum.StrEnum):
|
||||
"""Available subsystems."""
|
||||
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
name: str = Field(unique=True)
|
||||
description: str | None = None
|
||||
public_key: str
|
||||
ADMIN = enum.auto()
|
||||
SSHD = enum.auto()
|
||||
BACKEND = enum.auto()
|
||||
TEST = enum.auto()
|
||||
|
||||
created_at: datetime | None = Field(
|
||||
default=None,
|
||||
sa_type=sa.DateTime(timezone=True),
|
||||
sa_column_kwargs={"server_default": sa.func.now()},
|
||||
nullable=False,
|
||||
|
||||
class Operation(enum.StrEnum):
|
||||
"""Various operations for the audit logging module."""
|
||||
|
||||
CREATE = enum.auto()
|
||||
READ = enum.auto()
|
||||
UPDATE = enum.auto()
|
||||
DELETE = enum.auto()
|
||||
DENY = enum.auto()
|
||||
PERMIT = enum.auto()
|
||||
LOGIN = enum.auto()
|
||||
REGISTER = enum.auto()
|
||||
NONE = enum.auto()
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
class Client(Base):
|
||||
"""Clients."""
|
||||
|
||||
__tablename__: str = "client"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
name: Mapped[str] = mapped_column(sa.String, unique=True)
|
||||
description: Mapped[str | None] = mapped_column(sa.String, nullable=True)
|
||||
public_key: Mapped[str] = mapped_column(sa.Text)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||
)
|
||||
|
||||
updated_at: datetime | None = Field(
|
||||
default=None,
|
||||
sa_type=sa.DateTime(timezone=True),
|
||||
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
|
||||
updated_at: Mapped[datetime | None] = mapped_column(
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
onupdate=sa.func.now(),
|
||||
)
|
||||
|
||||
secrets: list["ClientSecret"] = Relationship(
|
||||
back_populates="client", passive_deletes="all"
|
||||
secrets: Mapped[list["ClientSecret"]] = relationship(
|
||||
back_populates="client", passive_deletes=True
|
||||
)
|
||||
|
||||
policies: list["ClientAccessPolicy"] = Relationship(back_populates="client")
|
||||
policies: Mapped[list["ClientAccessPolicy"]] = relationship(back_populates="client")
|
||||
|
||||
|
||||
class ClientAccessPolicy(SQLModel, table=True):
|
||||
class ClientAccessPolicy(Base):
|
||||
"""Client access policies."""
|
||||
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
source: str
|
||||
client_id: uuid.UUID | None = Field(foreign_key="client.id", ondelete="CASCADE")
|
||||
client: Client | None = Relationship(back_populates="policies")
|
||||
__tablename__: str = "client_access_policy"
|
||||
|
||||
created_at: datetime | None = Field(
|
||||
default=None,
|
||||
sa_type=sa.DateTime(timezone=True),
|
||||
sa_column_kwargs={"server_default": sa.func.now()},
|
||||
nullable=False,
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
source: Mapped[str] = mapped_column(sa.String)
|
||||
client_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
sa.Uuid(as_uuid=True), sa.ForeignKey("client.id", ondelete="CASCADE")
|
||||
)
|
||||
client: Mapped[Client] = relationship(back_populates="policies")
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||
)
|
||||
|
||||
updated_at: datetime | None = Field(
|
||||
default=None,
|
||||
sa_type=sa.DateTime(timezone=True),
|
||||
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
|
||||
updated_at: Mapped[datetime | None] = mapped_column(
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
onupdate=sa.func.now(),
|
||||
)
|
||||
|
||||
|
||||
class ClientSecret(SQLModel, table=True):
|
||||
class ClientSecret(Base):
|
||||
"""A client secret."""
|
||||
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
name: str
|
||||
description: str | None = None
|
||||
client_id: uuid.UUID | None = Field(foreign_key="client.id", ondelete="CASCADE")
|
||||
client: Client | None = Relationship(back_populates="secrets")
|
||||
secret: str
|
||||
invalidated: bool = Field(default=False)
|
||||
created_at: datetime | None = Field(
|
||||
default=None,
|
||||
sa_type=sa.DateTime(timezone=True),
|
||||
sa_column_kwargs={"server_default": sa.func.now()},
|
||||
nullable=False,
|
||||
__tablename__: str = "client_secret"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
name: Mapped[str] = mapped_column(sa.String)
|
||||
description: Mapped[str | None] = mapped_column(sa.String, nullable=True)
|
||||
secret: Mapped[str] = mapped_column(sa.String)
|
||||
|
||||
client_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
sa.Uuid(as_uuid=True), sa.ForeignKey("client.id", ondelete="CASCADE")
|
||||
)
|
||||
client: Mapped[Client] = relationship(back_populates="secrets")
|
||||
invalidated: Mapped[bool] = mapped_column(default=False)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||
)
|
||||
|
||||
updated_at: datetime | None = Field(
|
||||
default=None,
|
||||
sa_type=sa.DateTime(timezone=True),
|
||||
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
|
||||
updated_at: Mapped[datetime | None] = mapped_column(
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
onupdate=sa.func.now(),
|
||||
)
|
||||
|
||||
|
||||
class AuditLog(SQLModel, table=True):
|
||||
class APIClient(Base):
|
||||
"""A client on the API.
|
||||
|
||||
This should eventually get more granular permissions.
|
||||
"""
|
||||
|
||||
__tablename__: str = "api_client"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
subsystem: Mapped[SubSystem | None] = mapped_column(sa.String, nullable=True)
|
||||
token: Mapped[str] = mapped_column(sa.String)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||
)
|
||||
updated_at: Mapped[datetime | None] = mapped_column(
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
onupdate=sa.func.now(),
|
||||
)
|
||||
|
||||
|
||||
class AuditLog(Base):
|
||||
"""Audit log.
|
||||
|
||||
This is implemented without any foreign keys to avoid losing data on
|
||||
deletions.
|
||||
"""
|
||||
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
subsystem: str
|
||||
message: str
|
||||
operation: str
|
||||
client_id: uuid.UUID | None = None
|
||||
client_name: str | None = None
|
||||
origin: str | None = None
|
||||
Field(default=None, sa_column=Column(JSON))
|
||||
|
||||
timestamp: datetime | None = Field(
|
||||
default=None,
|
||||
sa_type=sa.DateTime(timezone=True),
|
||||
sa_column_kwargs={"server_default": sa.func.now()},
|
||||
nullable=False,
|
||||
__tablename__: str = "audit_log"
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
|
||||
subsystem: Mapped[SubSystem] = mapped_column(sa.String)
|
||||
message: Mapped[str] = mapped_column(sa.String)
|
||||
operation: Mapped[Operation] = mapped_column(sa.String)
|
||||
client_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
sa.Uuid(as_uuid=True), nullable=True
|
||||
)
|
||||
data: Mapped[dict[str, str] | None] = mapped_column(sa.JSON, nullable=True)
|
||||
client_name: Mapped[str | None] = mapped_column(sa.String, nullable=True)
|
||||
secret_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
sa.Uuid(as_uuid=True), nullable=True
|
||||
)
|
||||
secret_name: Mapped[str | None] = mapped_column(sa.String, nullable=True)
|
||||
|
||||
class APIClient(SQLModel, table=True):
|
||||
"""Stores API Keys."""
|
||||
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
token: str
|
||||
read_write: bool
|
||||
created_at: datetime | None = Field(
|
||||
default=None,
|
||||
sa_type=sa.DateTime(timezone=True),
|
||||
sa_column_kwargs={"server_default": sa.func.now()},
|
||||
nullable=False,
|
||||
origin: Mapped[str | None] = mapped_column(sa.String, nullable=True)
|
||||
timestamp: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||
)
|
||||
|
||||
|
||||
def init_db(engine: sa.Engine) -> None:
|
||||
"""Create database."""
|
||||
LOG.info("Running init_db")
|
||||
SQLModel.metadata.create_all(engine)
|
||||
"""Initialize database."""
|
||||
Base.metadata.create_all(engine)
|
||||
|
||||
@ -1,12 +1,10 @@
|
||||
"""Settings management."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic import Field
|
||||
from pydantic_settings import (
|
||||
BaseSettings,
|
||||
SettingsConfigDict,
|
||||
ForceDecode,
|
||||
)
|
||||
from sqlalchemy import URL
|
||||
|
||||
@ -22,24 +20,19 @@ class BackendSettings(BaseSettings):
|
||||
)
|
||||
|
||||
database: str = Field(default=DEFAULT_DATABASE)
|
||||
generate_initial_tokens: Annotated[bool, ForceDecode] = Field(default=False)
|
||||
|
||||
@field_validator("generate_initial_tokens", mode="before")
|
||||
@classmethod
|
||||
def cast_bool(cls, value: Any) -> bool:
|
||||
"""Ensure we catch the boolean."""
|
||||
if isinstance(value, str):
|
||||
if value.lower() in ("1", "true", "on"):
|
||||
return True
|
||||
if value.lower() in ("0", "false", "off"):
|
||||
return False
|
||||
return bool(value)
|
||||
admin_token: str | None = Field(default=None, alias="sshecret_admin_backend_token")
|
||||
sshd_token: str | None = Field(default=None, alias="sshecret_sshd_backend_token")
|
||||
|
||||
@property
|
||||
def db_url(self) -> URL:
|
||||
"""Construct database url."""
|
||||
return URL.create(drivername="sqlite", database=self.database)
|
||||
|
||||
@property
|
||||
def async_db_url(self) -> URL:
|
||||
"""Construct database url with sync handling."""
|
||||
return URL.create(drivername="sqlite+aiosqlite", database=self.database)
|
||||
|
||||
@property
|
||||
def db_exists(self) -> bool:
|
||||
"""Check if databatase exists."""
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
"""Test helpers."""
|
||||
|
||||
import logging
|
||||
from sqlmodel import Session
|
||||
from sshecret_backend.settings import BackendSettings
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from .models import init_db
|
||||
from .db import create_api_token, setup_database
|
||||
|
||||
@ -14,4 +15,4 @@ def create_test_token(settings: BackendSettings) -> str:
|
||||
engine, _setupdb = setup_database(settings.db_url)
|
||||
with Session(engine) as session:
|
||||
init_db(engine)
|
||||
return create_api_token(session, True)
|
||||
return create_api_token(session, "test")
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
from collections.abc import Callable, Generator
|
||||
|
||||
from sqlmodel import Session
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
DBSessionDep = Callable[[], Generator[Session, None, None]]
|
||||
|
||||
@ -4,15 +4,14 @@ import uuid
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Self, override
|
||||
|
||||
from sqlmodel import Field, SQLModel
|
||||
from pydantic import AfterValidator, IPvAnyAddress, IPvAnyNetwork
|
||||
from pydantic import AfterValidator, BaseModel, Field, IPvAnyAddress, IPvAnyNetwork
|
||||
|
||||
from sshecret.crypto import public_key_validator
|
||||
|
||||
from . import models
|
||||
|
||||
|
||||
class ClientView(SQLModel):
|
||||
class ClientView(BaseModel):
|
||||
"""View for a single client."""
|
||||
|
||||
id: uuid.UUID
|
||||
@ -50,7 +49,7 @@ class ClientView(SQLModel):
|
||||
return view
|
||||
|
||||
|
||||
class ClientQueryResult(SQLModel):
|
||||
class ClientQueryResult(BaseModel):
|
||||
"""Result class for queries towards the client list."""
|
||||
|
||||
clients: list[ClientView] = Field(default_factory=list)
|
||||
@ -58,7 +57,7 @@ class ClientQueryResult(SQLModel):
|
||||
remaining_results: int
|
||||
|
||||
|
||||
class ClientCreate(SQLModel):
|
||||
class ClientCreate(BaseModel):
|
||||
"""Model to create a client."""
|
||||
|
||||
name: str
|
||||
@ -74,19 +73,19 @@ class ClientCreate(SQLModel):
|
||||
)
|
||||
|
||||
|
||||
class ClientUpdate(SQLModel):
|
||||
class ClientUpdate(BaseModel):
|
||||
"""Model to update the client public key."""
|
||||
|
||||
public_key: Annotated[str, AfterValidator(public_key_validator)]
|
||||
|
||||
|
||||
class BodyValue(SQLModel):
|
||||
class BodyValue(BaseModel):
|
||||
"""A generic model with just a value parameter."""
|
||||
|
||||
value: str
|
||||
|
||||
|
||||
class ClientSecretPublic(SQLModel):
|
||||
class ClientSecretPublic(BaseModel):
|
||||
"""Public model to manage client secrets."""
|
||||
|
||||
name: str
|
||||
@ -122,7 +121,7 @@ class ClientSecretResponse(ClientSecretPublic):
|
||||
)
|
||||
|
||||
|
||||
class ClientPolicyView(SQLModel):
|
||||
class ClientPolicyView(BaseModel):
|
||||
"""Update object for client policy."""
|
||||
|
||||
sources: list[str] = ["0.0.0.0/0", "::/0"]
|
||||
@ -135,27 +134,27 @@ class ClientPolicyView(SQLModel):
|
||||
return cls(sources=[policy.source for policy in client.policies])
|
||||
|
||||
|
||||
class ClientPolicyUpdate(SQLModel):
|
||||
class ClientPolicyUpdate(BaseModel):
|
||||
"""Model for updating policies."""
|
||||
|
||||
sources: list[IPvAnyAddress | IPvAnyNetwork]
|
||||
|
||||
|
||||
class ClientSecretList(SQLModel):
|
||||
class ClientSecretList(BaseModel):
|
||||
"""Model for aggregating identically named secrets."""
|
||||
|
||||
name: str
|
||||
clients: list[str]
|
||||
|
||||
|
||||
class ClientReference(SQLModel):
|
||||
class ClientReference(BaseModel):
|
||||
"""Reference to a client."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
class ClientSecretDetailList(SQLModel):
|
||||
class ClientSecretDetailList(BaseModel):
|
||||
"""A more detailed version of the ClientSecretList."""
|
||||
|
||||
name: str
|
||||
@ -163,7 +162,22 @@ class ClientSecretDetailList(SQLModel):
|
||||
clients: list[ClientReference] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AuditInfo(SQLModel):
|
||||
class AuditView(BaseModel):
|
||||
"""Audit log view."""
|
||||
|
||||
|
||||
id: uuid.UUID | None = None
|
||||
subsystem: models.SubSystem
|
||||
message: str
|
||||
operation: models.Operation
|
||||
data: dict[str, str] | None = None
|
||||
client_id: uuid.UUID | None = None
|
||||
client_name: str | None = None
|
||||
origin: str | None = None
|
||||
timestamp: datetime | None = None
|
||||
|
||||
|
||||
class AuditInfo(BaseModel):
|
||||
"""Information about audit information."""
|
||||
|
||||
entries: int
|
||||
|
||||
@ -10,7 +10,7 @@ from fastapi.testclient import TestClient
|
||||
from sshecret.crypto import generate_private_key, generate_public_key_string
|
||||
from sshecret_backend.app import create_backend_app
|
||||
from sshecret_backend.testing import create_test_token
|
||||
from sshecret_backend.models import AuditLog
|
||||
from sshecret_backend.view_models import AuditView
|
||||
from sshecret_backend.settings import BackendSettings
|
||||
|
||||
|
||||
@ -53,7 +53,7 @@ def create_client_fixture(tmp_path: Path):
|
||||
|
||||
db_file = tmp_path / "backend.db"
|
||||
print(f"DB File: {db_file.absolute()}")
|
||||
settings = BackendSettings(db_url=f"sqlite:///{db_file.absolute()}")
|
||||
settings = BackendSettings(database=str(db_file.absolute()))
|
||||
app = create_backend_app(settings)
|
||||
|
||||
token = create_test_token(settings)
|
||||
@ -213,7 +213,7 @@ def test_audit_logging(test_client: TestClient) -> None:
|
||||
assert len(audit_logs) > 0
|
||||
for entry in audit_logs:
|
||||
# Let's try to reassemble the objects
|
||||
audit_log = AuditLog.model_validate(entry)
|
||||
audit_log = AuditView.model_validate(entry)
|
||||
assert audit_log is not None
|
||||
|
||||
|
||||
@ -522,9 +522,8 @@ def test_operations_with_id(test_client: TestClient) -> None:
|
||||
def test_write_audit_log(test_client: TestClient) -> None:
|
||||
"""Test writing to the audit log."""
|
||||
params = {
|
||||
"object": "Test",
|
||||
"operation": "TEST",
|
||||
"object_id": "Something",
|
||||
"subsystem": "backend",
|
||||
"operation": "read",
|
||||
"message": "Test Message"
|
||||
}
|
||||
resp = test_client.post("/api/v1/audit", json=params)
|
||||
|
||||
Reference in New Issue
Block a user