Refactor database layer and auditing
This commit is contained in:
@ -3,22 +3,22 @@ from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
from sqlmodel import create_engine
|
||||
|
||||
from alembic import context
|
||||
from sshecret_backend.models import *
|
||||
|
||||
def get_database_url() -> str:
|
||||
"""Get database URL."""
|
||||
if db_file := os.getenv("SSHECRET_BACKEND_DB"):
|
||||
return f"sqlite:///{db_file}"
|
||||
return "sqlite:///sshecret.db"
|
||||
|
||||
from sshecret_backend.models import Base
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
|
||||
def get_database_url() -> str | None:
|
||||
"""Get database URL."""
|
||||
if db_file := os.getenv("SSHECRET_BACKEND_DB"):
|
||||
return f"sqlite:///{db_file}"
|
||||
return config.get_main_option("sqlalchemy.url")
|
||||
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
@ -28,8 +28,7 @@ if config.config_file_name is not None:
|
||||
# for 'autogenerate' support
|
||||
# from myapp import mymodel
|
||||
# target_metadata = mymodel.Base.metadata
|
||||
#target_metadata = None
|
||||
target_metadata = SQLModel.metadata
|
||||
target_metadata = Base.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
@ -68,7 +67,11 @@ def run_migrations_online() -> None:
|
||||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
connectable = create_engine(get_database_url())
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
|
||||
@ -9,7 +9,6 @@ from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
||||
@ -5,13 +5,14 @@
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from fastapi import APIRouter, Depends, Request, Query
|
||||
from sqlmodel import Session, col, func, select
|
||||
from sqlalchemy import desc
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Annotated
|
||||
|
||||
from sshecret_backend.models import AuditLog
|
||||
from sshecret_backend.types import DBSessionDep
|
||||
from sshecret_backend.view_models import AuditInfo
|
||||
from sshecret_backend.view_models import AuditInfo, AuditView
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -21,7 +22,7 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
"""Construct audit sub-api."""
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/audit/", response_model=list[AuditLog])
|
||||
@router.get("/audit/", response_model=list[AuditView])
|
||||
async def get_audit_logs(
|
||||
request: Request,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
@ -29,35 +30,37 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
limit: Annotated[int, Query(le=100)] = 100,
|
||||
filter_client: Annotated[str | None, Query()] = None,
|
||||
filter_subsystem: Annotated[str | None, Query()] = None,
|
||||
) -> Sequence[AuditLog]:
|
||||
) -> Sequence[AuditView]:
|
||||
"""Get audit logs."""
|
||||
#audit.audit_access_audit_log(session, request)
|
||||
statement = select(AuditLog).offset(offset).limit(limit).order_by(desc(col(AuditLog.timestamp)))
|
||||
statement = select(AuditLog).offset(offset).limit(limit).order_by(AuditLog.timestamp.desc())
|
||||
if filter_client:
|
||||
statement = statement.where(AuditLog.client_name == filter_client)
|
||||
|
||||
if filter_subsystem:
|
||||
statement = statement.where(AuditLog.subsystem == filter_subsystem)
|
||||
|
||||
results = session.exec(statement).all()
|
||||
return results
|
||||
LogAdapt = TypeAdapter(list[AuditView])
|
||||
results = session.scalars(statement).all()
|
||||
return LogAdapt.validate_python(results, from_attributes=True)
|
||||
|
||||
|
||||
@router.post("/audit/")
|
||||
async def add_audit_log(
|
||||
request: Request,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
entry: AuditLog,
|
||||
) -> AuditLog:
|
||||
entry: AuditView,
|
||||
) -> AuditView:
|
||||
"""Add entry to audit log."""
|
||||
audit_log = AuditLog.model_validate(entry.model_dump(exclude_none=True))
|
||||
audit_log = AuditLog(**entry.model_dump(exclude_none=True))
|
||||
session.add(audit_log)
|
||||
session.commit()
|
||||
return audit_log
|
||||
return AuditView.model_validate(audit_log, from_attributes=True)
|
||||
|
||||
@router.get("/audit/info")
|
||||
async def get_audit_info(request: Request, session: Annotated[Session, Depends(get_db_session)]) -> AuditInfo:
|
||||
"""Get audit info."""
|
||||
audit_count = session.exec(select(func.count('*')).select_from(AuditLog)).one()
|
||||
audit_count = session.scalars(select(func.count('*')).select_from(AuditLog)).one()
|
||||
return AuditInfo(entries=audit_count)
|
||||
|
||||
|
||||
|
||||
@ -6,11 +6,11 @@ import uuid
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from sqlmodel import Session, col, select
|
||||
from sqlalchemy import func
|
||||
from typing import Annotated, Self, TypeVar
|
||||
from typing import Annotated, Any, Self, TypeVar, cast
|
||||
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import Select
|
||||
from sshecret_backend.types import DBSessionDep
|
||||
from sshecret_backend.models import Client, ClientSecret
|
||||
from sshecret_backend.view_models import (
|
||||
@ -55,8 +55,8 @@ T = TypeVar("T")
|
||||
|
||||
|
||||
def filter_client_statement(
|
||||
statement: SelectOfScalar[T], params: ClientListParams, ignore_limits: bool = False
|
||||
) -> SelectOfScalar[T]:
|
||||
statement: Select[Any], params: ClientListParams, ignore_limits: bool = False
|
||||
) -> Select[Any]:
|
||||
"""Filter a statement with the provided params."""
|
||||
if params.id:
|
||||
statement = statement.where(Client.id == params.id)
|
||||
@ -64,9 +64,9 @@ def filter_client_statement(
|
||||
if params.name:
|
||||
statement = statement.where(Client.name == params.name)
|
||||
elif params.name__like:
|
||||
statement = statement.where(col(Client.name).like(params.name__like))
|
||||
statement = statement.where(Client.name.like(params.name__like))
|
||||
elif params.name__contains:
|
||||
statement = statement.where(col(Client.name).contains(params.name__contains))
|
||||
statement = statement.where(Client.name.contains(params.name__contains))
|
||||
|
||||
if ignore_limits:
|
||||
return statement
|
||||
@ -86,13 +86,13 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
"""Get clients."""
|
||||
# Get total results first
|
||||
count_statement = select(func.count("*")).select_from(Client)
|
||||
count_statement = filter_client_statement(count_statement, filter_query, True)
|
||||
count_statement = cast(Select[tuple[int]], filter_client_statement(count_statement, filter_query, True))
|
||||
|
||||
total_results = session.exec(count_statement).one()
|
||||
total_results = session.scalars(count_statement).one()
|
||||
|
||||
statement = filter_client_statement(select(Client), filter_query, False)
|
||||
|
||||
results = session.exec(statement)
|
||||
results = session.scalars(statement)
|
||||
remainder = total_results - filter_query.offset - filter_query.limit
|
||||
if remainder < 0:
|
||||
remainder = 0
|
||||
@ -170,13 +170,12 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
client.public_key = client_update.public_key
|
||||
for secret in session.exec(
|
||||
for secret in session.scalars(
|
||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||
).all():
|
||||
LOG.debug("Invalidated secret %s", secret.id)
|
||||
secret.invalidated = True
|
||||
secret.client_id = None
|
||||
secret.client = None
|
||||
|
||||
session.add(client)
|
||||
session.refresh(client)
|
||||
@ -206,13 +205,12 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
public_key_updated = False
|
||||
if client_update.public_key != client.public_key:
|
||||
public_key_updated = True
|
||||
for secret in session.exec(
|
||||
for secret in session.scalars(
|
||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||
).all():
|
||||
LOG.debug("Invalidated secret %s", secret.id)
|
||||
secret.invalidated = True
|
||||
secret.client_id = None
|
||||
secret.client = None
|
||||
|
||||
session.add(client)
|
||||
session.commit()
|
||||
|
||||
@ -4,7 +4,8 @@ import re
|
||||
import uuid
|
||||
import bcrypt
|
||||
|
||||
from sqlmodel import Session, select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from sshecret_backend.models import Client
|
||||
|
||||
@ -20,13 +21,13 @@ def verify_token(token: str, stored_hash: str) -> bool:
|
||||
async def get_client_by_name(session: Session, name: str) -> Client | None:
|
||||
"""Get client by name."""
|
||||
client_filter = select(Client).where(Client.name == name)
|
||||
client_results = session.exec(client_filter)
|
||||
client_results = session.scalars(client_filter)
|
||||
return client_results.first()
|
||||
|
||||
async def get_client_by_id(session: Session, id: uuid.UUID) -> Client | None:
|
||||
"""Get client by name."""
|
||||
client_filter = select(Client).where(Client.id == id)
|
||||
client_results = session.exec(client_filter)
|
||||
client_results = session.scalars(client_filter)
|
||||
return client_results.first()
|
||||
|
||||
async def get_client_by_id_or_name(session: Session, id_or_name: str) -> Client | None:
|
||||
|
||||
@ -4,7 +4,8 @@
|
||||
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlmodel import Session, select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Annotated
|
||||
|
||||
from sshecret_backend.models import ClientAccessPolicy
|
||||
@ -54,7 +55,7 @@ def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
# Remove old policies.
|
||||
policies = session.exec(
|
||||
policies = session.scalars(
|
||||
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
|
||||
).all()
|
||||
deleted_policies: list[ClientAccessPolicy] = []
|
||||
|
||||
@ -5,7 +5,8 @@
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlmodel import Session, select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Annotated
|
||||
|
||||
from sshecret_backend.models import Client, ClientSecret
|
||||
@ -34,7 +35,7 @@ async def lookup_client_secret(
|
||||
.where(ClientSecret.client_id == client.id)
|
||||
.where(ClientSecret.name == name)
|
||||
)
|
||||
results = session.exec(statement)
|
||||
results = session.scalars(statement)
|
||||
return results.first()
|
||||
|
||||
|
||||
@ -165,7 +166,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
) -> list[ClientSecretList]:
|
||||
"""Get a list of all secrets and which clients have them."""
|
||||
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
|
||||
for client_secret in session.exec(select(ClientSecret)).all():
|
||||
for client_secret in session.scalars(select(ClientSecret)).all():
|
||||
if not client_secret.client:
|
||||
if client_secret.name not in client_secret_map:
|
||||
client_secret_map[client_secret.name] = []
|
||||
@ -182,7 +183,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
) -> list[ClientSecretDetailList]:
|
||||
"""Get a list of all secrets and which clients have them."""
|
||||
client_secrets: dict[str, ClientSecretDetailList] = {}
|
||||
for client_secret in session.exec(select(ClientSecret)).all():
|
||||
for client_secret in session.scalars(select(ClientSecret)).all():
|
||||
|
||||
if client_secret.name not in client_secrets:
|
||||
client_secrets[client_secret.name] = ClientSecretDetailList(name=client_secret.name)
|
||||
@ -202,7 +203,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
) -> ClientSecretList:
|
||||
"""Get a list of which clients has a named secret."""
|
||||
clients: list[str] = []
|
||||
for client_secret in session.exec(
|
||||
for client_secret in session.scalars(
|
||||
select(ClientSecret).where(ClientSecret.name == name)
|
||||
).all():
|
||||
if not client_secret.client:
|
||||
@ -219,7 +220,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
) -> ClientSecretDetailList:
|
||||
"""Get a list of which clients has a named secret."""
|
||||
detail_list = ClientSecretDetailList(name=name)
|
||||
for client_secret in session.exec(
|
||||
for client_secret in session.scalars(
|
||||
select(ClientSecret).where(ClientSecret.name == name)
|
||||
).all():
|
||||
if not client_secret.client:
|
||||
|
||||
@ -2,9 +2,10 @@
|
||||
|
||||
from collections.abc import Sequence
|
||||
from fastapi import Request
|
||||
from sqlmodel import Session, select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy
|
||||
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy, Operation, SubSystem
|
||||
|
||||
|
||||
def _get_origin(request: Request) -> str | None:
|
||||
@ -22,7 +23,7 @@ def _write_audit_log(
|
||||
"""Write the audit log."""
|
||||
origin = _get_origin(request)
|
||||
entry.origin = origin
|
||||
entry.subsystem = "backend"
|
||||
entry.subsystem = SubSystem.BACKEND
|
||||
session.add(entry)
|
||||
if commit:
|
||||
session.commit()
|
||||
@ -33,7 +34,7 @@ def audit_create_client(
|
||||
) -> None:
|
||||
"""Log the creation of a client."""
|
||||
entry = AuditLog(
|
||||
operation="CREATE",
|
||||
operation=Operation.CREATE,
|
||||
client_id=client.id,
|
||||
client_name=client.name,
|
||||
message="Client Created",
|
||||
@ -46,7 +47,7 @@ def audit_delete_client(
|
||||
) -> None:
|
||||
"""Log the creation of a client."""
|
||||
entry = AuditLog(
|
||||
operation="CREATE",
|
||||
operation=Operation.CREATE,
|
||||
client_id=client.id,
|
||||
client_name=client.name,
|
||||
message="Client deleted",
|
||||
@ -63,9 +64,9 @@ def audit_create_secret(
|
||||
) -> None:
|
||||
"""Audit a create secret event."""
|
||||
entry = AuditLog(
|
||||
operation="CREATE",
|
||||
object="ClientSecret",
|
||||
object_id=str(secret.id),
|
||||
operation=Operation.CREATE,
|
||||
secret_id=secret.id,
|
||||
secret_name=secret.name,
|
||||
client_id=client.id,
|
||||
client_name=client.name,
|
||||
message="Added secret to client",
|
||||
@ -81,13 +82,13 @@ def audit_remove_policy(
|
||||
commit: bool = True,
|
||||
) -> None:
|
||||
"""Audit removal of policy."""
|
||||
data = {"object": "ClientAccessPolicy", "object_id": str(policy.id)}
|
||||
entry = AuditLog(
|
||||
operation="DELETE",
|
||||
object="ClientAccessPolicy",
|
||||
object_id=str(policy.id),
|
||||
operation=Operation.DELETE,
|
||||
client_id=client.id,
|
||||
client_name=client.name,
|
||||
message="Deleted client policy",
|
||||
data=data,
|
||||
)
|
||||
_write_audit_log(session, request, entry, commit)
|
||||
|
||||
@ -100,13 +101,13 @@ def audit_update_policy(
|
||||
commit: bool = True,
|
||||
) -> None:
|
||||
"""Audit update of policy."""
|
||||
data: dict[str, str] = {"object": "ClientAccessPolicy", "object_id": str(policy.id)}
|
||||
entry = AuditLog(
|
||||
operation="CREATE",
|
||||
object="ClientAccessPolicy",
|
||||
object_id=str(policy.id),
|
||||
client_id=client.id,
|
||||
operation=Operation.CREATE,
|
||||
client_name=client.name,
|
||||
client_id=client.id,
|
||||
message="Updated client policy",
|
||||
data=data,
|
||||
)
|
||||
_write_audit_log(session, request, entry, commit)
|
||||
|
||||
@ -119,11 +120,10 @@ def audit_update_client(
|
||||
) -> None:
|
||||
"""Audit an update secret event."""
|
||||
entry = AuditLog(
|
||||
operation="UPDATE",
|
||||
object="Client",
|
||||
operation=Operation.UPDATE,
|
||||
client_id=client.id,
|
||||
client_name=client.name,
|
||||
message="Client updated",
|
||||
message="Client data updated",
|
||||
)
|
||||
_write_audit_log(session, request, entry, commit)
|
||||
|
||||
@ -137,11 +137,11 @@ def audit_update_secret(
|
||||
) -> None:
|
||||
"""Audit an update secret event."""
|
||||
entry = AuditLog(
|
||||
operation="UPDATE",
|
||||
object="ClientSecret",
|
||||
object_id=str(secret.id),
|
||||
operation=Operation.UPDATE,
|
||||
client_id=client.id,
|
||||
client_name=client.name,
|
||||
secret_name=secret.name,
|
||||
secret_id=secret.id,
|
||||
message="Secret value updated",
|
||||
)
|
||||
_write_audit_log(session, request, entry, commit)
|
||||
@ -155,8 +155,7 @@ def audit_invalidate_secrets(
|
||||
) -> None:
|
||||
"""Audit Invalidate client secrets."""
|
||||
entry = AuditLog(
|
||||
operation="INVALIDATE",
|
||||
object="ClientSecret",
|
||||
operation=Operation.UPDATE,
|
||||
client_name=client.name,
|
||||
client_id=client.id,
|
||||
message="Client public-key changed. All secrets invalidated.",
|
||||
@ -173,9 +172,9 @@ def audit_delete_secret(
|
||||
) -> None:
|
||||
"""Audit Delete client secrets."""
|
||||
entry = AuditLog(
|
||||
operation="DELETE",
|
||||
object="ClientSecret",
|
||||
object_id=str(secret.id),
|
||||
operation=Operation.DELETE,
|
||||
secret_name=secret.name,
|
||||
secret_id=secret.id,
|
||||
client_name=client.name,
|
||||
client_id=client.id,
|
||||
message="Deleted secret.",
|
||||
@ -195,7 +194,7 @@ def audit_access_secrets(
|
||||
With no secrets provided, all secrets of the client will be resolved.
|
||||
"""
|
||||
if not secrets:
|
||||
secrets = session.exec(
|
||||
secrets = session.scalars(
|
||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||
).all()
|
||||
|
||||
@ -215,37 +214,21 @@ def audit_access_secret(
|
||||
) -> None:
|
||||
"""Audit that someone accessed one secrets."""
|
||||
entry = AuditLog(
|
||||
operation="ACCESS",
|
||||
operation=Operation.READ,
|
||||
message="Secret was viewed",
|
||||
object="ClientSecret",
|
||||
object_id=str(secret.id),
|
||||
secret_name=secret.name,
|
||||
secret_id=secret.id,
|
||||
client_id=client.id,
|
||||
client_name=client.name,
|
||||
)
|
||||
_write_audit_log(session, request, entry, commit)
|
||||
|
||||
|
||||
def audit_access_audit_log(
|
||||
session: Session, request: Request, commit: bool = True
|
||||
) -> None:
|
||||
"""Audit access to the audit log.
|
||||
|
||||
Because why not...
|
||||
"""
|
||||
entry = AuditLog(
|
||||
operation="ACCESS",
|
||||
message="Audit log was viewed",
|
||||
object="AuditLog",
|
||||
)
|
||||
_write_audit_log(session, request, entry, commit)
|
||||
|
||||
|
||||
def audit_client_secret_list(
|
||||
session: Session, request: Request, commit: bool = True
|
||||
) -> None:
|
||||
"""Audit a list of all secrets."""
|
||||
entry = AuditLog(
|
||||
operation="ACCESS",
|
||||
operation=Operation.READ,
|
||||
message="All secret names and their clients was viewed",
|
||||
)
|
||||
_write_audit_log(session, request, entry, commit)
|
||||
|
||||
@ -3,11 +3,11 @@
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
import bcrypt
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from .api import get_audit_api, get_clients_api, get_policy_api, get_secrets_api
|
||||
from .auth import verify_token
|
||||
from .models import (
|
||||
APIClient,
|
||||
)
|
||||
@ -18,13 +18,6 @@ LOG = logging.getLogger(__name__)
|
||||
API_VERSION = "v1"
|
||||
|
||||
|
||||
def verify_token(token: str, stored_hash: str) -> bool:
|
||||
"""Verify token."""
|
||||
token_bytes = token.encode("utf-8")
|
||||
stored_bytes = stored_hash.encode("utf-8")
|
||||
return bcrypt.checkpw(token_bytes, stored_bytes)
|
||||
|
||||
|
||||
def get_backend_api(
|
||||
get_db_session: DBSessionDep,
|
||||
) -> APIRouter:
|
||||
@ -37,7 +30,7 @@ def get_backend_api(
|
||||
"""Validate token."""
|
||||
LOG.debug("Validating token %s", x_api_token)
|
||||
statement = select(APIClient)
|
||||
results = session.exec(statement)
|
||||
results = session.scalars(statement)
|
||||
valid = False
|
||||
for result in results:
|
||||
if verify_token(x_api_token, result.token):
|
||||
|
||||
@ -3,15 +3,24 @@
|
||||
import code
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
from dotenv import load_dotenv
|
||||
from typing import Literal, cast
|
||||
|
||||
import click
|
||||
from sqlmodel import Session, col, func, select
|
||||
import uvicorn
|
||||
from dotenv import load_dotenv
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .db import get_engine, create_api_token
|
||||
|
||||
from .models import Client, ClientSecret, ClientAccessPolicy, AuditLog, APIClient, init_db
|
||||
from .db import create_api_token, get_engine, hash_token
|
||||
from .models import (
|
||||
APIClient,
|
||||
AuditLog,
|
||||
Client,
|
||||
ClientAccessPolicy,
|
||||
ClientSecret,
|
||||
SubSystem,
|
||||
init_db,
|
||||
)
|
||||
from .settings import BackendSettings
|
||||
|
||||
DEFAULT_LISTEN = "127.0.0.1"
|
||||
@ -21,22 +30,44 @@ WORKDIR = Path(os.getcwd())
|
||||
|
||||
load_dotenv()
|
||||
|
||||
def generate_token(settings: BackendSettings) -> str:
|
||||
|
||||
def generate_token(
|
||||
settings: BackendSettings, subsystem: Literal["admin", "sshd"]
|
||||
) -> str:
|
||||
"""Generate a token."""
|
||||
engine = get_engine(settings.db_url)
|
||||
init_db(engine)
|
||||
with Session(engine) as session:
|
||||
token = create_api_token(session, True)
|
||||
token = create_api_token(session, subsystem)
|
||||
return token
|
||||
|
||||
def count_tokens(settings: BackendSettings) -> int:
|
||||
"""Count the amount of tokens created."""
|
||||
|
||||
def add_system_tokens(settings: BackendSettings) -> None:
|
||||
"""Add token for subsystems."""
|
||||
if not settings.admin_token and not settings.sshd_token:
|
||||
# Tokens should be generated manually.
|
||||
return
|
||||
|
||||
engine = get_engine(settings.db_url)
|
||||
init_db(engine)
|
||||
tokens: list[tuple[str, SubSystem]] = []
|
||||
if admin_token := settings.admin_token:
|
||||
tokens.append((admin_token, SubSystem.ADMIN))
|
||||
if sshd_token := settings.sshd_token:
|
||||
tokens.append((sshd_token, SubSystem.SSHD))
|
||||
with Session(engine) as session:
|
||||
count = session.exec(select(func.count("*")).select_from(APIClient)).one()
|
||||
for token, subsystem in tokens:
|
||||
hashed_token = hash_token(token)
|
||||
if existing := session.scalars(
|
||||
select(APIClient).where(APIClient.subsystem == subsystem)
|
||||
).first():
|
||||
existing.token = hashed_token
|
||||
else:
|
||||
new_token = APIClient(token=hashed_token, subsystem=subsystem)
|
||||
session.add(new_token)
|
||||
|
||||
return count
|
||||
session.commit()
|
||||
click.echo("Generated system tokens.")
|
||||
|
||||
|
||||
@click.group()
|
||||
@ -49,27 +80,30 @@ def cli(ctx: click.Context, database: str) -> None:
|
||||
else:
|
||||
settings = BackendSettings()
|
||||
|
||||
add_system_tokens(settings)
|
||||
|
||||
if settings.generate_initial_tokens:
|
||||
if count_tokens(settings) == 0:
|
||||
click.echo("Creating initial tokens for admin and sshd.")
|
||||
admin_token = generate_token(settings)
|
||||
sshd_token = generate_token(settings)
|
||||
click.echo(f"Admin token: {admin_token}")
|
||||
click.echo(f"SSHD token: {sshd_token}")
|
||||
# if settings.generate_initial_tokens:
|
||||
# if count_tokens(settings) == 0:
|
||||
# click.echo("Creating initial tokens for admin and sshd.")
|
||||
# admin_token = generate_token(settings)
|
||||
# sshd_token = generate_token(settings)
|
||||
# click.echo(f"Admin token: {admin_token}")
|
||||
# click.echo(f"SSHD token: {sshd_token}")
|
||||
|
||||
ctx.obj = settings
|
||||
|
||||
|
||||
@cli.command("generate-token")
|
||||
@click.argument("subsystem", type=click.Choice(["sshd", "admin"]))
|
||||
@click.pass_context
|
||||
def cli_generate_token(ctx: click.Context) -> None:
|
||||
"""Generate a token."""
|
||||
def cli_generate_token(ctx: click.Context, subsystem: Literal["sshd", "admin"]) -> None:
|
||||
"""Generate a token for a subsystem.."""
|
||||
settings = cast(BackendSettings, ctx.obj)
|
||||
token = generate_token(settings)
|
||||
token = generate_token(settings, subsystem)
|
||||
click.echo("Generated api token:")
|
||||
click.echo(token)
|
||||
|
||||
|
||||
@cli.command("run")
|
||||
@click.option("--host", default="127.0.0.1")
|
||||
@click.option("--port", default=8022, type=click.INT)
|
||||
@ -77,7 +111,10 @@ def cli_generate_token(ctx: click.Context) -> None:
|
||||
@click.option("--workers", type=click.INT)
|
||||
def cli_run(host: str, port: int, dev: bool, workers: int | None) -> None:
|
||||
"""Run the server."""
|
||||
uvicorn.run("sshecret_backend.main:app", host=host, port=port, reload=dev, workers=workers)
|
||||
uvicorn.run(
|
||||
"sshecret_backend.main:app", host=host, port=port, reload=dev, workers=workers
|
||||
)
|
||||
|
||||
|
||||
@cli.command("repl")
|
||||
@click.pass_context
|
||||
|
||||
@ -2,56 +2,108 @@
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
import sqlite3
|
||||
|
||||
from collections.abc import Generator, Callable
|
||||
from pathlib import Path
|
||||
from sqlalchemy import Engine
|
||||
from sqlmodel import Session, create_engine, text
|
||||
import bcrypt
|
||||
from typing import Literal
|
||||
from sqlalchemy import create_engine, Engine, event, select
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
|
||||
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
|
||||
from sqlalchemy.engine import URL
|
||||
|
||||
|
||||
from .models import APIClient
|
||||
from .auth import hash_token, verify_token
|
||||
from .models import APIClient, SubSystem
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_database(
|
||||
db_url: URL | str,
|
||||
db_url: URL,
|
||||
) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]:
|
||||
"""Setup database."""
|
||||
|
||||
engine = create_engine(db_url, echo=False)
|
||||
with engine.connect() as connection:
|
||||
connection.execute(text("PRAGMA foreign_keys=ON")) # for SQLite only
|
||||
engine = get_engine(db_url)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, future=True)
|
||||
|
||||
def get_db_session() -> Generator[Session, None, None]:
|
||||
"""Get DB Session."""
|
||||
with Session(engine) as session:
|
||||
session = SessionLocal(bind=engine)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
return engine, get_db_session
|
||||
|
||||
|
||||
def get_engine(url: URL, echo: bool = False) -> Engine:
|
||||
"""Initialize the engine."""
|
||||
engine = create_engine(url, echo=echo)
|
||||
with engine.connect() as connection:
|
||||
connection.execute(text("PRAGMA foreign_keys=ON")) # for SQLite only
|
||||
engine = create_engine(url, echo=echo, future=True)
|
||||
if url.drivername.startswith("sqlite"):
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
def set_sqlite_pragma(
|
||||
dbapi_connection: sqlite3.Connection, _connection_record: object
|
||||
) -> None:
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
def create_api_token(session: Session, read_write: bool) -> str:
|
||||
"""Create API token."""
|
||||
token = secrets.token_urlsafe(32)
|
||||
pwbytes = token.encode("utf-8")
|
||||
salt = bcrypt.gensalt()
|
||||
hashed_bytes = bcrypt.hashpw(password=pwbytes, salt=salt)
|
||||
hashed = hashed_bytes.decode()
|
||||
def get_async_engine(url: URL, echo: bool = False) -> AsyncEngine:
|
||||
"""Get an async engine."""
|
||||
engine = create_async_engine(url, echo=echo, future=True)
|
||||
if url.drivername.startswith("sqlite+"):
|
||||
|
||||
api_token = APIClient(token=hashed, read_write=read_write)
|
||||
@event.listens_for(engine, "connect")
|
||||
def set_sqlite_pragma(
|
||||
dbapi_connection: sqlite3.Connection, _connection_record: object
|
||||
) -> None:
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
|
||||
def create_api_token_with_value(session: Session, token: str, subsystem: Literal["admin", "sshd"]) -> None:
|
||||
"""Create API token with a given value."""
|
||||
|
||||
existing = session.scalars(select(APIClient).where(APIClient.subsystem == SubSystem(subsystem))).first()
|
||||
if existing:
|
||||
if verify_token(token, existing.token):
|
||||
LOG.info("Token is up to date.")
|
||||
return
|
||||
LOG.info("Updating token value for subsystem %s", subsystem)
|
||||
hashed = hash_token(token)
|
||||
existing.token=hashed
|
||||
session.commit()
|
||||
return
|
||||
|
||||
LOG.info("No existing token found. Creating new")
|
||||
hashed = hash_token(token)
|
||||
api_token = APIClient(token=hashed, subsystem=SubSystem(subsystem))
|
||||
|
||||
session.add(api_token)
|
||||
session.commit()
|
||||
|
||||
def create_api_token(session: Session, subsystem: Literal["admin", "sshd", "test"], recreate: bool = False) -> str:
|
||||
"""Create API token."""
|
||||
subsys = SubSystem(subsystem)
|
||||
token = secrets.token_urlsafe(32)
|
||||
hashed = hash_token(token)
|
||||
if existing := session.scalars(select(APIClient).where(APIClient.subsystem == subsys)).first():
|
||||
if not recreate:
|
||||
raise RuntimeError("Error: A token already exist for this subsystem.")
|
||||
existing.token = hashed
|
||||
else:
|
||||
api_token = APIClient(token=hashed, subsystem=subsys)
|
||||
session.add(api_token)
|
||||
session.commit()
|
||||
|
||||
|
||||
@ -7,128 +7,182 @@ This might require some changes to these schemas.
|
||||
|
||||
"""
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlmodel import JSON, Column, DateTime, Field, Relationship, SQLModel
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Client(SQLModel, table=True):
|
||||
"""Client model."""
|
||||
class SubSystem(enum.StrEnum):
|
||||
"""Available subsystems."""
|
||||
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
name: str = Field(unique=True)
|
||||
description: str | None = None
|
||||
public_key: str
|
||||
ADMIN = enum.auto()
|
||||
SSHD = enum.auto()
|
||||
BACKEND = enum.auto()
|
||||
TEST = enum.auto()
|
||||
|
||||
created_at: datetime | None = Field(
|
||||
default=None,
|
||||
sa_type=sa.DateTime(timezone=True),
|
||||
sa_column_kwargs={"server_default": sa.func.now()},
|
||||
nullable=False,
|
||||
|
||||
class Operation(enum.StrEnum):
|
||||
"""Various operations for the audit logging module."""
|
||||
|
||||
CREATE = enum.auto()
|
||||
READ = enum.auto()
|
||||
UPDATE = enum.auto()
|
||||
DELETE = enum.auto()
|
||||
DENY = enum.auto()
|
||||
PERMIT = enum.auto()
|
||||
LOGIN = enum.auto()
|
||||
REGISTER = enum.auto()
|
||||
NONE = enum.auto()
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
class Client(Base):
|
||||
"""Clients."""
|
||||
|
||||
__tablename__: str = "client"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
name: Mapped[str] = mapped_column(sa.String, unique=True)
|
||||
description: Mapped[str | None] = mapped_column(sa.String, nullable=True)
|
||||
public_key: Mapped[str] = mapped_column(sa.Text)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||
)
|
||||
|
||||
updated_at: datetime | None = Field(
|
||||
default=None,
|
||||
sa_type=sa.DateTime(timezone=True),
|
||||
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
|
||||
updated_at: Mapped[datetime | None] = mapped_column(
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
onupdate=sa.func.now(),
|
||||
)
|
||||
|
||||
secrets: list["ClientSecret"] = Relationship(
|
||||
back_populates="client", passive_deletes="all"
|
||||
secrets: Mapped[list["ClientSecret"]] = relationship(
|
||||
back_populates="client", passive_deletes=True
|
||||
)
|
||||
|
||||
policies: list["ClientAccessPolicy"] = Relationship(back_populates="client")
|
||||
policies: Mapped[list["ClientAccessPolicy"]] = relationship(back_populates="client")
|
||||
|
||||
|
||||
class ClientAccessPolicy(SQLModel, table=True):
|
||||
class ClientAccessPolicy(Base):
|
||||
"""Client access policies."""
|
||||
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
source: str
|
||||
client_id: uuid.UUID | None = Field(foreign_key="client.id", ondelete="CASCADE")
|
||||
client: Client | None = Relationship(back_populates="policies")
|
||||
__tablename__: str = "client_access_policy"
|
||||
|
||||
created_at: datetime | None = Field(
|
||||
default=None,
|
||||
sa_type=sa.DateTime(timezone=True),
|
||||
sa_column_kwargs={"server_default": sa.func.now()},
|
||||
nullable=False,
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
source: Mapped[str] = mapped_column(sa.String)
|
||||
client_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
sa.Uuid(as_uuid=True), sa.ForeignKey("client.id", ondelete="CASCADE")
|
||||
)
|
||||
client: Mapped[Client] = relationship(back_populates="policies")
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||
)
|
||||
|
||||
updated_at: datetime | None = Field(
|
||||
default=None,
|
||||
sa_type=sa.DateTime(timezone=True),
|
||||
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
|
||||
updated_at: Mapped[datetime | None] = mapped_column(
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
onupdate=sa.func.now(),
|
||||
)
|
||||
|
||||
|
||||
class ClientSecret(SQLModel, table=True):
|
||||
class ClientSecret(Base):
|
||||
"""A client secret."""
|
||||
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
name: str
|
||||
description: str | None = None
|
||||
client_id: uuid.UUID | None = Field(foreign_key="client.id", ondelete="CASCADE")
|
||||
client: Client | None = Relationship(back_populates="secrets")
|
||||
secret: str
|
||||
invalidated: bool = Field(default=False)
|
||||
created_at: datetime | None = Field(
|
||||
default=None,
|
||||
sa_type=sa.DateTime(timezone=True),
|
||||
sa_column_kwargs={"server_default": sa.func.now()},
|
||||
nullable=False,
|
||||
__tablename__: str = "client_secret"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
name: Mapped[str] = mapped_column(sa.String)
|
||||
description: Mapped[str | None] = mapped_column(sa.String, nullable=True)
|
||||
secret: Mapped[str] = mapped_column(sa.String)
|
||||
|
||||
client_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
sa.Uuid(as_uuid=True), sa.ForeignKey("client.id", ondelete="CASCADE")
|
||||
)
|
||||
client: Mapped[Client] = relationship(back_populates="secrets")
|
||||
invalidated: Mapped[bool] = mapped_column(default=False)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||
)
|
||||
|
||||
updated_at: datetime | None = Field(
|
||||
default=None,
|
||||
sa_type=sa.DateTime(timezone=True),
|
||||
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
|
||||
updated_at: Mapped[datetime | None] = mapped_column(
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
onupdate=sa.func.now(),
|
||||
)
|
||||
|
||||
|
||||
class AuditLog(SQLModel, table=True):
|
||||
class APIClient(Base):
|
||||
"""A client on the API.
|
||||
|
||||
This should eventually get more granular permissions.
|
||||
"""
|
||||
|
||||
__tablename__: str = "api_client"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
subsystem: Mapped[SubSystem | None] = mapped_column(sa.String, nullable=True)
|
||||
token: Mapped[str] = mapped_column(sa.String)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||
)
|
||||
updated_at: Mapped[datetime | None] = mapped_column(
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
onupdate=sa.func.now(),
|
||||
)
|
||||
|
||||
|
||||
class AuditLog(Base):
|
||||
"""Audit log.
|
||||
|
||||
This is implemented without any foreign keys to avoid losing data on
|
||||
deletions.
|
||||
"""
|
||||
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
subsystem: str
|
||||
message: str
|
||||
operation: str
|
||||
client_id: uuid.UUID | None = None
|
||||
client_name: str | None = None
|
||||
origin: str | None = None
|
||||
Field(default=None, sa_column=Column(JSON))
|
||||
|
||||
timestamp: datetime | None = Field(
|
||||
default=None,
|
||||
sa_type=sa.DateTime(timezone=True),
|
||||
sa_column_kwargs={"server_default": sa.func.now()},
|
||||
nullable=False,
|
||||
__tablename__: str = "audit_log"
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
|
||||
subsystem: Mapped[SubSystem] = mapped_column(sa.String)
|
||||
message: Mapped[str] = mapped_column(sa.String)
|
||||
operation: Mapped[Operation] = mapped_column(sa.String)
|
||||
client_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
sa.Uuid(as_uuid=True), nullable=True
|
||||
)
|
||||
data: Mapped[dict[str, str] | None] = mapped_column(sa.JSON, nullable=True)
|
||||
client_name: Mapped[str | None] = mapped_column(sa.String, nullable=True)
|
||||
secret_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
sa.Uuid(as_uuid=True), nullable=True
|
||||
)
|
||||
secret_name: Mapped[str | None] = mapped_column(sa.String, nullable=True)
|
||||
|
||||
class APIClient(SQLModel, table=True):
|
||||
"""Stores API Keys."""
|
||||
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
token: str
|
||||
read_write: bool
|
||||
created_at: datetime | None = Field(
|
||||
default=None,
|
||||
sa_type=sa.DateTime(timezone=True),
|
||||
sa_column_kwargs={"server_default": sa.func.now()},
|
||||
nullable=False,
|
||||
origin: Mapped[str | None] = mapped_column(sa.String, nullable=True)
|
||||
timestamp: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||
)
|
||||
|
||||
|
||||
def init_db(engine: sa.Engine) -> None:
|
||||
"""Create database."""
|
||||
LOG.info("Running init_db")
|
||||
SQLModel.metadata.create_all(engine)
|
||||
"""Initialize database."""
|
||||
Base.metadata.create_all(engine)
|
||||
|
||||
@ -1,12 +1,10 @@
|
||||
"""Settings management."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic import Field
|
||||
from pydantic_settings import (
|
||||
BaseSettings,
|
||||
SettingsConfigDict,
|
||||
ForceDecode,
|
||||
)
|
||||
from sqlalchemy import URL
|
||||
|
||||
@ -22,24 +20,19 @@ class BackendSettings(BaseSettings):
|
||||
)
|
||||
|
||||
database: str = Field(default=DEFAULT_DATABASE)
|
||||
generate_initial_tokens: Annotated[bool, ForceDecode] = Field(default=False)
|
||||
|
||||
@field_validator("generate_initial_tokens", mode="before")
|
||||
@classmethod
|
||||
def cast_bool(cls, value: Any) -> bool:
|
||||
"""Ensure we catch the boolean."""
|
||||
if isinstance(value, str):
|
||||
if value.lower() in ("1", "true", "on"):
|
||||
return True
|
||||
if value.lower() in ("0", "false", "off"):
|
||||
return False
|
||||
return bool(value)
|
||||
admin_token: str | None = Field(default=None, alias="sshecret_admin_backend_token")
|
||||
sshd_token: str | None = Field(default=None, alias="sshecret_sshd_backend_token")
|
||||
|
||||
@property
|
||||
def db_url(self) -> URL:
|
||||
"""Construct database url."""
|
||||
return URL.create(drivername="sqlite", database=self.database)
|
||||
|
||||
@property
|
||||
def async_db_url(self) -> URL:
|
||||
"""Construct database url with sync handling."""
|
||||
return URL.create(drivername="sqlite+aiosqlite", database=self.database)
|
||||
|
||||
@property
|
||||
def db_exists(self) -> bool:
|
||||
"""Check if databatase exists."""
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
"""Test helpers."""
|
||||
|
||||
import logging
|
||||
from sqlmodel import Session
|
||||
from sshecret_backend.settings import BackendSettings
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from .models import init_db
|
||||
from .db import create_api_token, setup_database
|
||||
|
||||
@ -14,4 +15,4 @@ def create_test_token(settings: BackendSettings) -> str:
|
||||
engine, _setupdb = setup_database(settings.db_url)
|
||||
with Session(engine) as session:
|
||||
init_db(engine)
|
||||
return create_api_token(session, True)
|
||||
return create_api_token(session, "test")
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
from collections.abc import Callable, Generator
|
||||
|
||||
from sqlmodel import Session
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
DBSessionDep = Callable[[], Generator[Session, None, None]]
|
||||
|
||||
@ -4,15 +4,14 @@ import uuid
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Self, override
|
||||
|
||||
from sqlmodel import Field, SQLModel
|
||||
from pydantic import AfterValidator, IPvAnyAddress, IPvAnyNetwork
|
||||
from pydantic import AfterValidator, BaseModel, Field, IPvAnyAddress, IPvAnyNetwork
|
||||
|
||||
from sshecret.crypto import public_key_validator
|
||||
|
||||
from . import models
|
||||
|
||||
|
||||
class ClientView(SQLModel):
|
||||
class ClientView(BaseModel):
|
||||
"""View for a single client."""
|
||||
|
||||
id: uuid.UUID
|
||||
@ -50,7 +49,7 @@ class ClientView(SQLModel):
|
||||
return view
|
||||
|
||||
|
||||
class ClientQueryResult(SQLModel):
|
||||
class ClientQueryResult(BaseModel):
|
||||
"""Result class for queries towards the client list."""
|
||||
|
||||
clients: list[ClientView] = Field(default_factory=list)
|
||||
@ -58,7 +57,7 @@ class ClientQueryResult(SQLModel):
|
||||
remaining_results: int
|
||||
|
||||
|
||||
class ClientCreate(SQLModel):
|
||||
class ClientCreate(BaseModel):
|
||||
"""Model to create a client."""
|
||||
|
||||
name: str
|
||||
@ -74,19 +73,19 @@ class ClientCreate(SQLModel):
|
||||
)
|
||||
|
||||
|
||||
class ClientUpdate(SQLModel):
|
||||
class ClientUpdate(BaseModel):
|
||||
"""Model to update the client public key."""
|
||||
|
||||
public_key: Annotated[str, AfterValidator(public_key_validator)]
|
||||
|
||||
|
||||
class BodyValue(SQLModel):
|
||||
class BodyValue(BaseModel):
|
||||
"""A generic model with just a value parameter."""
|
||||
|
||||
value: str
|
||||
|
||||
|
||||
class ClientSecretPublic(SQLModel):
|
||||
class ClientSecretPublic(BaseModel):
|
||||
"""Public model to manage client secrets."""
|
||||
|
||||
name: str
|
||||
@ -122,7 +121,7 @@ class ClientSecretResponse(ClientSecretPublic):
|
||||
)
|
||||
|
||||
|
||||
class ClientPolicyView(SQLModel):
|
||||
class ClientPolicyView(BaseModel):
|
||||
"""Update object for client policy."""
|
||||
|
||||
sources: list[str] = ["0.0.0.0/0", "::/0"]
|
||||
@ -135,27 +134,27 @@ class ClientPolicyView(SQLModel):
|
||||
return cls(sources=[policy.source for policy in client.policies])
|
||||
|
||||
|
||||
class ClientPolicyUpdate(SQLModel):
|
||||
class ClientPolicyUpdate(BaseModel):
|
||||
"""Model for updating policies."""
|
||||
|
||||
sources: list[IPvAnyAddress | IPvAnyNetwork]
|
||||
|
||||
|
||||
class ClientSecretList(SQLModel):
|
||||
class ClientSecretList(BaseModel):
|
||||
"""Model for aggregating identically named secrets."""
|
||||
|
||||
name: str
|
||||
clients: list[str]
|
||||
|
||||
|
||||
class ClientReference(SQLModel):
|
||||
class ClientReference(BaseModel):
|
||||
"""Reference to a client."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
class ClientSecretDetailList(SQLModel):
|
||||
class ClientSecretDetailList(BaseModel):
|
||||
"""A more detailed version of the ClientSecretList."""
|
||||
|
||||
name: str
|
||||
@ -163,7 +162,22 @@ class ClientSecretDetailList(SQLModel):
|
||||
clients: list[ClientReference] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AuditInfo(SQLModel):
|
||||
class AuditView(BaseModel):
|
||||
"""Audit log view."""
|
||||
|
||||
|
||||
id: uuid.UUID | None = None
|
||||
subsystem: models.SubSystem
|
||||
message: str
|
||||
operation: models.Operation
|
||||
data: dict[str, str] | None = None
|
||||
client_id: uuid.UUID | None = None
|
||||
client_name: str | None = None
|
||||
origin: str | None = None
|
||||
timestamp: datetime | None = None
|
||||
|
||||
|
||||
class AuditInfo(BaseModel):
|
||||
"""Information about audit information."""
|
||||
|
||||
entries: int
|
||||
|
||||
@ -10,7 +10,7 @@ from fastapi.testclient import TestClient
|
||||
from sshecret.crypto import generate_private_key, generate_public_key_string
|
||||
from sshecret_backend.app import create_backend_app
|
||||
from sshecret_backend.testing import create_test_token
|
||||
from sshecret_backend.models import AuditLog
|
||||
from sshecret_backend.view_models import AuditView
|
||||
from sshecret_backend.settings import BackendSettings
|
||||
|
||||
|
||||
@ -53,7 +53,7 @@ def create_client_fixture(tmp_path: Path):
|
||||
|
||||
db_file = tmp_path / "backend.db"
|
||||
print(f"DB File: {db_file.absolute()}")
|
||||
settings = BackendSettings(db_url=f"sqlite:///{db_file.absolute()}")
|
||||
settings = BackendSettings(database=str(db_file.absolute()))
|
||||
app = create_backend_app(settings)
|
||||
|
||||
token = create_test_token(settings)
|
||||
@ -213,7 +213,7 @@ def test_audit_logging(test_client: TestClient) -> None:
|
||||
assert len(audit_logs) > 0
|
||||
for entry in audit_logs:
|
||||
# Let's try to reassemble the objects
|
||||
audit_log = AuditLog.model_validate(entry)
|
||||
audit_log = AuditView.model_validate(entry)
|
||||
assert audit_log is not None
|
||||
|
||||
|
||||
@ -522,9 +522,8 @@ def test_operations_with_id(test_client: TestClient) -> None:
|
||||
def test_write_audit_log(test_client: TestClient) -> None:
|
||||
"""Test writing to the audit log."""
|
||||
params = {
|
||||
"object": "Test",
|
||||
"operation": "TEST",
|
||||
"object_id": "Something",
|
||||
"subsystem": "backend",
|
||||
"operation": "read",
|
||||
"message": "Test Message"
|
||||
}
|
||||
resp = test_client.post("/api/v1/audit", json=params)
|
||||
|
||||
@ -8,8 +8,10 @@ from .models import (
|
||||
ClientReference,
|
||||
ClientSecret,
|
||||
DetailedSecrets,
|
||||
Operation,
|
||||
Policy,
|
||||
Secret,
|
||||
SubSystem,
|
||||
FilterType,
|
||||
)
|
||||
|
||||
@ -22,7 +24,9 @@ __all__ = [
|
||||
"ClientSecret",
|
||||
"DetailedSecrets",
|
||||
"FilterType",
|
||||
"Operation",
|
||||
"Policy",
|
||||
"Secret",
|
||||
"SubSystem",
|
||||
"SshecretBackend",
|
||||
]
|
||||
|
||||
@ -1,21 +1,26 @@
|
||||
"""Backend client.
|
||||
|
||||
This is an API calling the HTTP API of the sshecret-backend package so that the
|
||||
admin and sshd library do not need to implement the same
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Self
|
||||
from typing import Any, Self, override
|
||||
import httpx
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from .models import (
|
||||
AuditInfo,
|
||||
AuditLog,
|
||||
Client,
|
||||
ClientSecret,
|
||||
ClientQueryResult,
|
||||
ClientFilter,
|
||||
DetailedSecrets,
|
||||
Operation,
|
||||
Secret,
|
||||
SubSystem,
|
||||
)
|
||||
from .exceptions import BackendValidationError, BackendConnectionError
|
||||
from .utils import validate_public_key
|
||||
@ -84,12 +89,14 @@ class ClientQueryIterator:
|
||||
raise StopAsyncIteration
|
||||
|
||||
|
||||
class SshecretBackend:
|
||||
"""Backend interface."""
|
||||
class BaseBackend:
|
||||
"""Base backend class."""
|
||||
|
||||
def __init__(self, backend_url: str, api_token: str) -> None:
|
||||
"""Initialize backend client."""
|
||||
|
||||
self._backend_url: str = backend_url
|
||||
self._api_token: str = api_token
|
||||
url = httpx.URL(backend_url)
|
||||
|
||||
self.http_client: httpx.AsyncClient = httpx.AsyncClient(
|
||||
@ -101,31 +108,55 @@ class SshecretBackend:
|
||||
base_url=url,
|
||||
)
|
||||
|
||||
async def _get(self, path: str) -> httpx.Response:
|
||||
def _handle_response(self, response: httpx.Response) -> httpx.Response:
|
||||
"""Handle response."""
|
||||
LOG.debug("Handling response with status_code %s", response.status_code)
|
||||
if response.status_code == 422:
|
||||
LOG.error("Validation error from backend:\n%s", response.text)
|
||||
raise BackendValidationError(response.text)
|
||||
if response.status_code != 404 and str(response.status_code).startswith("4"):
|
||||
raise BackendConnectionError("Error from backend: %s", response.text)
|
||||
|
||||
return response
|
||||
|
||||
async def _get(
|
||||
self, path: str, params: dict[str, str] | None = None
|
||||
) -> httpx.Response:
|
||||
"""Perform a get request."""
|
||||
try:
|
||||
return await self.http_client.get(path)
|
||||
response = await self.http_client.get(path, params=params)
|
||||
return self._handle_response(response)
|
||||
except httpx.ConnectError as e:
|
||||
raise BackendConnectionError() from e
|
||||
raise BackendConnectionError("Could not connect to backend.") from e
|
||||
|
||||
async def _delete(self, path: str) -> httpx.Response:
|
||||
"""Perform a delete request."""
|
||||
try:
|
||||
return await self.http_client.delete(path)
|
||||
response = await self.http_client.delete(path)
|
||||
return self._handle_response(response)
|
||||
except httpx.ConnectError as e:
|
||||
raise BackendConnectionError() from e
|
||||
|
||||
async def _post(self, path: str, json: Any | None = None) -> httpx.Response:
|
||||
"""Perform a POST request."""
|
||||
try:
|
||||
return await self.http_client.post(path, json=json)
|
||||
response = await self.http_client.post(path, json=json)
|
||||
return self._handle_response(response)
|
||||
except httpx.ConnectError as e:
|
||||
raise BackendConnectionError() from e
|
||||
|
||||
def _post_sync(self, path: str, json: Any | None = None) -> httpx.Response:
|
||||
"""Perform a synchronous post request."""
|
||||
try:
|
||||
return self._handle_response(self.sync_client.post(path, json=json))
|
||||
except httpx.ConnectError as e:
|
||||
raise BackendConnectionError() from e
|
||||
|
||||
async def _put(self, path: str, json: Any | None = None) -> httpx.Response:
|
||||
"""Perform a PUT request."""
|
||||
try:
|
||||
return await self.http_client.put(path, json=json)
|
||||
response = await self.http_client.put(path, json=json)
|
||||
return self._handle_response(response)
|
||||
except httpx.ConnectError as e:
|
||||
raise BackendConnectionError() from e
|
||||
|
||||
@ -134,175 +165,90 @@ class SshecretBackend:
|
||||
response = await self._get(path)
|
||||
return response
|
||||
|
||||
async def create_client(
|
||||
self, name: str, public_key: str, description: str | None = None
|
||||
) -> None:
|
||||
"""Register a new client."""
|
||||
if not validate_public_key(public_key):
|
||||
raise BackendValidationError("Error: Invalid public key format.")
|
||||
data = {
|
||||
"name": name,
|
||||
"public_key": public_key,
|
||||
}
|
||||
if description:
|
||||
data["description"] = description
|
||||
path = "/api/v1/clients/"
|
||||
response = await self._post(path, json=data)
|
||||
|
||||
response.raise_for_status()
|
||||
class AuditAPI(BaseBackend):
|
||||
"""API for the audit logging."""
|
||||
|
||||
async def get_clients(self, filter: ClientFilter | None = None) -> list[Client]:
|
||||
"""Get all clients."""
|
||||
clients: list[Client] = []
|
||||
async for client in ClientQueryIterator(self.http_client, filter):
|
||||
clients.append(client)
|
||||
@override
|
||||
def __init__(self, backend_url: str, api_token: str, subsystem: str) -> None:
|
||||
"""Initialize backend client."""
|
||||
super().__init__(backend_url, api_token)
|
||||
self.subsystem: SubSystem = SubSystem(subsystem)
|
||||
|
||||
return clients
|
||||
|
||||
async def get_client(self, name: str) -> Client | None:
|
||||
"""Lookup a client on username."""
|
||||
path = f"/api/v1/clients/{name}"
|
||||
response = await self.request(path)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
response.raise_for_status()
|
||||
client = Client.model_validate(response.json())
|
||||
return client
|
||||
|
||||
async def get_client_by_id(self, id: str) -> Client | None:
|
||||
"""Lookup a client on username."""
|
||||
path = f"/api/v1/clients/id/{id}"
|
||||
response = await self.request(path)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
response.raise_for_status()
|
||||
client = Client.model_validate(response.json())
|
||||
return client
|
||||
|
||||
async def delete_client(self, client_name: str) -> None:
|
||||
"""Delete a client."""
|
||||
path = f"/api/v1/clients/{client_name}"
|
||||
response = await self._delete(path)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
async def delete_client_by_id(self, id: str) -> None:
|
||||
"""Delete a client."""
|
||||
path = f"/api/v1/clients/id/{id}"
|
||||
response = await self._delete(path)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
async def create_client_secret(
|
||||
self, client_name: str, secret_name: str, encrypted_secret: str
|
||||
) -> None:
|
||||
"""Create a secret.
|
||||
|
||||
This will overwrite any existing secret with that name.
|
||||
"""
|
||||
path = f"api/v1/clients/{client_name}/secrets/{secret_name}"
|
||||
response = await self._put(path, json={"value": encrypted_secret})
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
async def get_client_secret(self, name: str, secret_name: str) -> str:
|
||||
"""Fetch a secret."""
|
||||
path = f"/api/v1/clients/{name}/secrets/{secret_name}"
|
||||
response = await self.request(path)
|
||||
response.raise_for_status()
|
||||
secret = ClientSecret.model_validate(response.json())
|
||||
return secret.secret
|
||||
|
||||
async def delete_client_secret(self, client_name: str, secret_name: str) -> None:
|
||||
"""Delete a secret from a client."""
|
||||
path = f"api/v1/clients/{client_name}/secrets/{secret_name}"
|
||||
response = await self._delete(path)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
async def update_client(self, client: Client) -> Client:
|
||||
"""Update the client."""
|
||||
path = f"/api/v1/clients/{client.name}"
|
||||
client_update = {
|
||||
"name": client.name,
|
||||
"description": client.description,
|
||||
"public_key": client.public_key,
|
||||
}
|
||||
response = await self._put(path, json=client_update)
|
||||
LOG.info("Response %s", response.text)
|
||||
|
||||
response.raise_for_status()
|
||||
if client.policies:
|
||||
await self.update_client_sources(
|
||||
str(client.id), [str(source) for source in client.policies]
|
||||
def _create_model(
|
||||
self,
|
||||
operation: Operation,
|
||||
message: str,
|
||||
origin: str,
|
||||
client: Client | None = None,
|
||||
secret: ClientSecret | None = None,
|
||||
secret_name: str | None = None,
|
||||
data: dict[str, str] | None = None,
|
||||
) -> AuditLog:
|
||||
"""Create the audit log object."""
|
||||
model = AuditLog(
|
||||
subsystem=self.subsystem,
|
||||
operation=operation,
|
||||
message=message,
|
||||
origin=origin,
|
||||
)
|
||||
return client
|
||||
if client:
|
||||
model.client_id = str(client.id) or None
|
||||
model.client_name = client.name
|
||||
|
||||
async def update_client_key(self, client_name: str, public_key: str) -> None:
|
||||
"""Update the client key."""
|
||||
path = f"/api/v1/clients/{client_name}/public-key"
|
||||
response = await self._post(path, json={"public_key": public_key})
|
||||
if secret:
|
||||
model.secret_name = secret.name
|
||||
elif secret_name:
|
||||
model.secret_name = secret_name
|
||||
if data:
|
||||
model.data = data
|
||||
|
||||
response.raise_for_status()
|
||||
return model
|
||||
|
||||
async def update_client_sources(
|
||||
self, client_name: str, addresses: list[str] | None
|
||||
def write_model(self, model: AuditLog) -> None:
|
||||
"""Write model."""
|
||||
path = f"/api/v1/audit/"
|
||||
self.sync_client.post(path, json=model.model_dump())
|
||||
|
||||
async def write_model_async(self, model: AuditLog) -> None:
|
||||
"""Write model async."""
|
||||
path = f"/api/v1/audit/"
|
||||
await self._post(path, json=model.model_dump())
|
||||
|
||||
def write(
|
||||
self,
|
||||
operation: Operation,
|
||||
message: str,
|
||||
origin: str,
|
||||
client: Client | None = None,
|
||||
secret: ClientSecret | None = None,
|
||||
secret_name: str | None = None,
|
||||
**data: str,
|
||||
) -> None:
|
||||
"""Update client source addresses.
|
||||
"""Write an audit entry."""
|
||||
model = self._create_model(
|
||||
operation, message, origin, client, secret, secret_name, data
|
||||
)
|
||||
|
||||
Pass None to sources to allow from all.
|
||||
"""
|
||||
if not addresses:
|
||||
addresses = []
|
||||
self.write_model(model)
|
||||
|
||||
path = f"/api/v1/clients/{client_name}/policies/"
|
||||
response = await self._put(path, json={"sources": addresses})
|
||||
async def write_async(
|
||||
self,
|
||||
operation: Operation,
|
||||
message: str,
|
||||
origin: str,
|
||||
client: Client | None = None,
|
||||
secret: ClientSecret | None = None,
|
||||
secret_name: str | None = None,
|
||||
**data: str,
|
||||
) -> None:
|
||||
"""Write an audit entry."""
|
||||
model = self._create_model(
|
||||
operation, message, origin, client, secret, secret_name, data
|
||||
)
|
||||
await self.write_model_async(model)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
async def get_detailed_secrets(self) -> list[DetailedSecrets]:
|
||||
"""Get detailed list of secrets."""
|
||||
path = "/api/v1/secrets/detailed/"
|
||||
response = await self._get(path)
|
||||
response.raise_for_status()
|
||||
|
||||
secret_list = TypeAdapter(list[DetailedSecrets])
|
||||
return secret_list.validate_python(response.json())
|
||||
|
||||
async def get_secrets(self) -> list[Secret]:
|
||||
"""Get Secrets.
|
||||
|
||||
This provides a list of secret names and which clients have them.
|
||||
"""
|
||||
path = "/api/v1/secrets/"
|
||||
response = await self._get(path)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
secret_list = TypeAdapter(list[Secret])
|
||||
return secret_list.validate_python(response.json())
|
||||
|
||||
async def get_secret(self, name: str) -> Secret | None:
|
||||
"""Get clients mapped to a single secret."""
|
||||
path = f"/api/v1/secrets/{name}"
|
||||
response = await self._get(path)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
response.raise_for_status()
|
||||
|
||||
return Secret.model_validate(response.json())
|
||||
|
||||
async def get_detailed_secret(self, name: str) -> DetailedSecrets | None:
|
||||
"""Get clients mapped to a single secret."""
|
||||
path = f"/api/v1/secrets/{name}/detailed"
|
||||
response = await self._get(path)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
response.raise_for_status()
|
||||
|
||||
return DetailedSecrets.model_validate(response.json())
|
||||
|
||||
async def get_audit_log(
|
||||
async def get(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 100,
|
||||
@ -321,30 +267,169 @@ class SshecretBackend:
|
||||
if subsystem:
|
||||
params["filter_subsystem"] = subsystem
|
||||
|
||||
response = await self.http_client.get(path, params=params)
|
||||
response.raise_for_status()
|
||||
response = await self._get(path, params=params)
|
||||
audit_log_adapter = TypeAdapter(list[AuditLog])
|
||||
return audit_log_adapter.validate_python(response.json())
|
||||
|
||||
async def add_audit_log(self, entry: AuditLog) -> None:
|
||||
"""Add audit log entry."""
|
||||
path = f"/api/v1/audit/"
|
||||
|
||||
response = await self.http_client.post(path, json=entry.model_dump())
|
||||
response.raise_for_status()
|
||||
|
||||
async def get_audit_log_count(self) -> int:
|
||||
async def count(self) -> int:
|
||||
"""Get amount of messages in the audit log."""
|
||||
path = f"/api/v1/audit/info"
|
||||
response = await self._get(path)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return int(data["entries"])
|
||||
audit_info = AuditInfo.model_validate(response.json())
|
||||
return audit_info.entries
|
||||
|
||||
def add_audit_log_sync(self, entry: AuditLog) -> None:
|
||||
"""Add audit log entry."""
|
||||
path = f"/api/v1/audit/"
|
||||
LOG.info("AUDIT LOG SYNC %r", entry)
|
||||
|
||||
response = self.sync_client.post(path, json=entry.model_dump())
|
||||
response.raise_for_status()
|
||||
class SshecretBackend(BaseBackend):
|
||||
"""Backend interface."""
|
||||
|
||||
async def create_client(
|
||||
self, name: str, public_key: str, description: str | None = None
|
||||
) -> None:
|
||||
"""Register a new client."""
|
||||
if not validate_public_key(public_key):
|
||||
raise BackendValidationError("Error: Invalid public key format.")
|
||||
data = {
|
||||
"name": name,
|
||||
"public_key": public_key,
|
||||
}
|
||||
if description:
|
||||
data["description"] = description
|
||||
path = "/api/v1/clients/"
|
||||
response = await self._post(path, json=data)
|
||||
|
||||
async def get_clients(self, filter: ClientFilter | None = None) -> list[Client]:
|
||||
"""Get all clients."""
|
||||
clients: list[Client] = []
|
||||
async for client in ClientQueryIterator(self.http_client, filter):
|
||||
clients.append(client)
|
||||
|
||||
return clients
|
||||
|
||||
async def get_client(self, name: str) -> Client | None:
|
||||
"""Lookup a client on username."""
|
||||
path = f"/api/v1/clients/{name}"
|
||||
response = await self._get(path)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
client = Client.model_validate(response.json())
|
||||
return client
|
||||
|
||||
async def get_client_by_id(self, id: str) -> Client | None:
|
||||
"""Lookup a client on username."""
|
||||
path = f"/api/v1/clients/id/{id}"
|
||||
response = await self._get(path)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
client = Client.model_validate(response.json())
|
||||
return client
|
||||
|
||||
async def delete_client(self, client_name: str) -> None:
|
||||
"""Delete a client."""
|
||||
path = f"/api/v1/clients/{client_name}"
|
||||
response = await self._delete(path)
|
||||
|
||||
async def delete_client_by_id(self, id: str) -> None:
|
||||
"""Delete a client."""
|
||||
path = f"/api/v1/clients/id/{id}"
|
||||
response = await self._delete(path)
|
||||
|
||||
async def create_client_secret(
|
||||
self, client_name: str, secret_name: str, encrypted_secret: str
|
||||
) -> None:
|
||||
"""Create a secret.
|
||||
|
||||
This will overwrite any existing secret with that name.
|
||||
"""
|
||||
path = f"api/v1/clients/{client_name}/secrets/{secret_name}"
|
||||
response = await self._put(path, json={"value": encrypted_secret})
|
||||
|
||||
async def get_client_secret(self, name: str, secret_name: str) -> str | None:
|
||||
"""Fetch a secret."""
|
||||
path = f"/api/v1/clients/{name}/secrets/{secret_name}"
|
||||
response = await self._get(path)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
secret = ClientSecret.model_validate(response.json())
|
||||
return secret.secret
|
||||
|
||||
async def delete_client_secret(self, client_name: str, secret_name: str) -> None:
|
||||
"""Delete a secret from a client."""
|
||||
path = f"api/v1/clients/{client_name}/secrets/{secret_name}"
|
||||
await self._delete(path)
|
||||
|
||||
async def update_client(self, client: Client) -> Client:
|
||||
"""Update the client."""
|
||||
path = f"/api/v1/clients/{client.name}"
|
||||
client_update = {
|
||||
"name": client.name,
|
||||
"description": client.description,
|
||||
"public_key": client.public_key,
|
||||
}
|
||||
response = await self._put(path, json=client_update)
|
||||
LOG.info("Response %s", response.text)
|
||||
|
||||
if client.policies:
|
||||
await self.update_client_sources(
|
||||
str(client.id), [str(source) for source in client.policies]
|
||||
)
|
||||
return client
|
||||
|
||||
async def update_client_key(self, client_name: str, public_key: str) -> None:
|
||||
"""Update the client key."""
|
||||
path = f"/api/v1/clients/{client_name}/public-key"
|
||||
await self._post(path, json={"public_key": public_key})
|
||||
|
||||
async def update_client_sources(
|
||||
self, client_name: str, addresses: list[str] | None
|
||||
) -> None:
|
||||
"""Update client source addresses.
|
||||
|
||||
Pass None to sources to allow from all.
|
||||
"""
|
||||
if not addresses:
|
||||
addresses = []
|
||||
|
||||
path = f"/api/v1/clients/{client_name}/policies/"
|
||||
await self._put(path, json={"sources": addresses})
|
||||
|
||||
async def get_detailed_secrets(self) -> list[DetailedSecrets]:
|
||||
"""Get detailed list of secrets."""
|
||||
path = "/api/v1/secrets/detailed/"
|
||||
response = await self._get(path)
|
||||
|
||||
secret_list = TypeAdapter(list[DetailedSecrets])
|
||||
return secret_list.validate_python(response.json())
|
||||
|
||||
async def get_secrets(self) -> list[Secret]:
|
||||
"""Get Secrets.
|
||||
|
||||
This provides a list of secret names and which clients have them.
|
||||
"""
|
||||
path = "/api/v1/secrets/"
|
||||
response = await self._get(path)
|
||||
|
||||
secret_list = TypeAdapter(list[Secret])
|
||||
return secret_list.validate_python(response.json())
|
||||
|
||||
async def get_secret(self, name: str) -> Secret | None:
|
||||
"""Get clients mapped to a single secret."""
|
||||
path = f"/api/v1/secrets/{name}"
|
||||
response = await self._get(path)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
|
||||
return Secret.model_validate(response.json())
|
||||
|
||||
async def get_detailed_secret(self, name: str) -> DetailedSecrets | None:
|
||||
"""Get clients mapped to a single secret."""
|
||||
path = f"/api/v1/secrets/{name}/detailed"
|
||||
response = await self._get(path)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
|
||||
return DetailedSecrets.model_validate(response.json())
|
||||
|
||||
def audit(self, subsystem: SubSystem) -> AuditAPI:
|
||||
"""Create the audit API."""
|
||||
audit = AuditAPI(self._backend_url, self._api_token, subsystem)
|
||||
return audit
|
||||
|
||||
@ -17,6 +17,27 @@ class FilterType(enum.StrEnum):
|
||||
CONTAINS = "contains"
|
||||
|
||||
|
||||
class SubSystem(enum.StrEnum):
|
||||
"""Available subsystems."""
|
||||
|
||||
ADMIN = enum.auto()
|
||||
SSHD = enum.auto()
|
||||
BACKEND = enum.auto()
|
||||
|
||||
|
||||
class Operation(enum.StrEnum):
|
||||
"""Various operations for the audit logging module."""
|
||||
|
||||
CREATE = enum.auto()
|
||||
READ = enum.auto()
|
||||
UPDATE = enum.auto()
|
||||
DELETE = enum.auto()
|
||||
DENY = enum.auto()
|
||||
PERMIT = enum.auto()
|
||||
LOGIN = enum.auto()
|
||||
NONE = enum.auto()
|
||||
|
||||
|
||||
class Client(BaseModel):
|
||||
"""Implementation of the backend class ClientView."""
|
||||
|
||||
@ -101,15 +122,22 @@ class ClientFilter(BaseModel):
|
||||
|
||||
|
||||
class AuditLog(BaseModel):
|
||||
"""Implementation of the backend class AuditLog."""
|
||||
"""Implementation of the backend class AuditView."""
|
||||
|
||||
id: str | None = None
|
||||
subsystem: str | None = None
|
||||
object: str | None = None
|
||||
object_id: str | None = None
|
||||
operation: str
|
||||
subsystem: SubSystem
|
||||
operation: Operation
|
||||
client_id: str | None = None
|
||||
client_name: str | None = None
|
||||
secret_id: str | None = None
|
||||
secret_name: str | None = None
|
||||
data: dict[str, str] | None = None
|
||||
message: str
|
||||
origin: str | None = None
|
||||
timestamp: datetime | None = None
|
||||
|
||||
|
||||
class AuditInfo(BaseModel):
|
||||
"""Implementation of the backend class AuditInfo."""
|
||||
|
||||
entries: int
|
||||
|
||||
Reference in New Issue
Block a user