Refactor database layer and auditing
This commit is contained in:
@ -3,22 +3,22 @@ from logging.config import fileConfig
|
|||||||
|
|
||||||
from sqlalchemy import engine_from_config
|
from sqlalchemy import engine_from_config
|
||||||
from sqlalchemy import pool
|
from sqlalchemy import pool
|
||||||
from sqlmodel import create_engine
|
|
||||||
|
|
||||||
from alembic import context
|
from alembic import context
|
||||||
from sshecret_backend.models import *
|
from sshecret_backend.models import Base
|
||||||
|
|
||||||
def get_database_url() -> str:
|
|
||||||
"""Get database URL."""
|
|
||||||
if db_file := os.getenv("SSHECRET_BACKEND_DB"):
|
|
||||||
return f"sqlite:///{db_file}"
|
|
||||||
return "sqlite:///sshecret.db"
|
|
||||||
|
|
||||||
|
|
||||||
# this is the Alembic Config object, which provides
|
# this is the Alembic Config object, which provides
|
||||||
# access to the values within the .ini file in use.
|
# access to the values within the .ini file in use.
|
||||||
config = context.config
|
config = context.config
|
||||||
|
|
||||||
|
|
||||||
|
def get_database_url() -> str | None:
|
||||||
|
"""Get database URL."""
|
||||||
|
if db_file := os.getenv("SSHECRET_BACKEND_DB"):
|
||||||
|
return f"sqlite:///{db_file}"
|
||||||
|
return config.get_main_option("sqlalchemy.url")
|
||||||
|
|
||||||
|
|
||||||
# Interpret the config file for Python logging.
|
# Interpret the config file for Python logging.
|
||||||
# This line sets up loggers basically.
|
# This line sets up loggers basically.
|
||||||
if config.config_file_name is not None:
|
if config.config_file_name is not None:
|
||||||
@ -28,8 +28,7 @@ if config.config_file_name is not None:
|
|||||||
# for 'autogenerate' support
|
# for 'autogenerate' support
|
||||||
# from myapp import mymodel
|
# from myapp import mymodel
|
||||||
# target_metadata = mymodel.Base.metadata
|
# target_metadata = mymodel.Base.metadata
|
||||||
#target_metadata = None
|
target_metadata = Base.metadata
|
||||||
target_metadata = SQLModel.metadata
|
|
||||||
|
|
||||||
# other values from the config, defined by the needs of env.py,
|
# other values from the config, defined by the needs of env.py,
|
||||||
# can be acquired:
|
# can be acquired:
|
||||||
@ -68,7 +67,11 @@ def run_migrations_online() -> None:
|
|||||||
and associate a connection with the context.
|
and associate a connection with the context.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
connectable = create_engine(get_database_url())
|
connectable = engine_from_config(
|
||||||
|
config.get_section(config.config_ini_section, {}),
|
||||||
|
prefix="sqlalchemy.",
|
||||||
|
poolclass=pool.NullPool,
|
||||||
|
)
|
||||||
|
|
||||||
with connectable.connect() as connection:
|
with connectable.connect() as connection:
|
||||||
context.configure(
|
context.configure(
|
||||||
|
|||||||
@ -9,7 +9,6 @@ from typing import Sequence, Union
|
|||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
import sqlmodel
|
|
||||||
${imports if imports else ""}
|
${imports if imports else ""}
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
|
|||||||
@ -5,13 +5,14 @@
|
|||||||
import logging
|
import logging
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from fastapi import APIRouter, Depends, Request, Query
|
from fastapi import APIRouter, Depends, Request, Query
|
||||||
from sqlmodel import Session, col, func, select
|
from pydantic import TypeAdapter
|
||||||
from sqlalchemy import desc
|
from sqlalchemy import select, func
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from sshecret_backend.models import AuditLog
|
from sshecret_backend.models import AuditLog
|
||||||
from sshecret_backend.types import DBSessionDep
|
from sshecret_backend.types import DBSessionDep
|
||||||
from sshecret_backend.view_models import AuditInfo
|
from sshecret_backend.view_models import AuditInfo, AuditView
|
||||||
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
@ -21,7 +22,7 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
"""Construct audit sub-api."""
|
"""Construct audit sub-api."""
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@router.get("/audit/", response_model=list[AuditLog])
|
@router.get("/audit/", response_model=list[AuditView])
|
||||||
async def get_audit_logs(
|
async def get_audit_logs(
|
||||||
request: Request,
|
request: Request,
|
||||||
session: Annotated[Session, Depends(get_db_session)],
|
session: Annotated[Session, Depends(get_db_session)],
|
||||||
@ -29,35 +30,37 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
limit: Annotated[int, Query(le=100)] = 100,
|
limit: Annotated[int, Query(le=100)] = 100,
|
||||||
filter_client: Annotated[str | None, Query()] = None,
|
filter_client: Annotated[str | None, Query()] = None,
|
||||||
filter_subsystem: Annotated[str | None, Query()] = None,
|
filter_subsystem: Annotated[str | None, Query()] = None,
|
||||||
) -> Sequence[AuditLog]:
|
) -> Sequence[AuditView]:
|
||||||
"""Get audit logs."""
|
"""Get audit logs."""
|
||||||
#audit.audit_access_audit_log(session, request)
|
#audit.audit_access_audit_log(session, request)
|
||||||
statement = select(AuditLog).offset(offset).limit(limit).order_by(desc(col(AuditLog.timestamp)))
|
statement = select(AuditLog).offset(offset).limit(limit).order_by(AuditLog.timestamp.desc())
|
||||||
if filter_client:
|
if filter_client:
|
||||||
statement = statement.where(AuditLog.client_name == filter_client)
|
statement = statement.where(AuditLog.client_name == filter_client)
|
||||||
|
|
||||||
if filter_subsystem:
|
if filter_subsystem:
|
||||||
statement = statement.where(AuditLog.subsystem == filter_subsystem)
|
statement = statement.where(AuditLog.subsystem == filter_subsystem)
|
||||||
|
|
||||||
results = session.exec(statement).all()
|
LogAdapt = TypeAdapter(list[AuditView])
|
||||||
return results
|
results = session.scalars(statement).all()
|
||||||
|
return LogAdapt.validate_python(results, from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/audit/")
|
@router.post("/audit/")
|
||||||
async def add_audit_log(
|
async def add_audit_log(
|
||||||
request: Request,
|
request: Request,
|
||||||
session: Annotated[Session, Depends(get_db_session)],
|
session: Annotated[Session, Depends(get_db_session)],
|
||||||
entry: AuditLog,
|
entry: AuditView,
|
||||||
) -> AuditLog:
|
) -> AuditView:
|
||||||
"""Add entry to audit log."""
|
"""Add entry to audit log."""
|
||||||
audit_log = AuditLog.model_validate(entry.model_dump(exclude_none=True))
|
audit_log = AuditLog(**entry.model_dump(exclude_none=True))
|
||||||
session.add(audit_log)
|
session.add(audit_log)
|
||||||
session.commit()
|
session.commit()
|
||||||
return audit_log
|
return AuditView.model_validate(audit_log, from_attributes=True)
|
||||||
|
|
||||||
@router.get("/audit/info")
|
@router.get("/audit/info")
|
||||||
async def get_audit_info(request: Request, session: Annotated[Session, Depends(get_db_session)]) -> AuditInfo:
|
async def get_audit_info(request: Request, session: Annotated[Session, Depends(get_db_session)]) -> AuditInfo:
|
||||||
"""Get audit info."""
|
"""Get audit info."""
|
||||||
audit_count = session.exec(select(func.count('*')).select_from(AuditLog)).one()
|
audit_count = session.scalars(select(func.count('*')).select_from(AuditLog)).one()
|
||||||
return AuditInfo(entries=audit_count)
|
return AuditInfo(entries=audit_count)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -6,11 +6,11 @@ import uuid
|
|||||||
import logging
|
import logging
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
from sqlmodel import Session, col, select
|
from typing import Annotated, Any, Self, TypeVar, cast
|
||||||
from sqlalchemy import func
|
|
||||||
from typing import Annotated, Self, TypeVar
|
|
||||||
|
|
||||||
from sqlmodel.sql.expression import SelectOfScalar
|
from sqlalchemy import select, func
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from sqlalchemy.sql import Select
|
||||||
from sshecret_backend.types import DBSessionDep
|
from sshecret_backend.types import DBSessionDep
|
||||||
from sshecret_backend.models import Client, ClientSecret
|
from sshecret_backend.models import Client, ClientSecret
|
||||||
from sshecret_backend.view_models import (
|
from sshecret_backend.view_models import (
|
||||||
@ -55,8 +55,8 @@ T = TypeVar("T")
|
|||||||
|
|
||||||
|
|
||||||
def filter_client_statement(
|
def filter_client_statement(
|
||||||
statement: SelectOfScalar[T], params: ClientListParams, ignore_limits: bool = False
|
statement: Select[Any], params: ClientListParams, ignore_limits: bool = False
|
||||||
) -> SelectOfScalar[T]:
|
) -> Select[Any]:
|
||||||
"""Filter a statement with the provided params."""
|
"""Filter a statement with the provided params."""
|
||||||
if params.id:
|
if params.id:
|
||||||
statement = statement.where(Client.id == params.id)
|
statement = statement.where(Client.id == params.id)
|
||||||
@ -64,9 +64,9 @@ def filter_client_statement(
|
|||||||
if params.name:
|
if params.name:
|
||||||
statement = statement.where(Client.name == params.name)
|
statement = statement.where(Client.name == params.name)
|
||||||
elif params.name__like:
|
elif params.name__like:
|
||||||
statement = statement.where(col(Client.name).like(params.name__like))
|
statement = statement.where(Client.name.like(params.name__like))
|
||||||
elif params.name__contains:
|
elif params.name__contains:
|
||||||
statement = statement.where(col(Client.name).contains(params.name__contains))
|
statement = statement.where(Client.name.contains(params.name__contains))
|
||||||
|
|
||||||
if ignore_limits:
|
if ignore_limits:
|
||||||
return statement
|
return statement
|
||||||
@ -86,13 +86,13 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
"""Get clients."""
|
"""Get clients."""
|
||||||
# Get total results first
|
# Get total results first
|
||||||
count_statement = select(func.count("*")).select_from(Client)
|
count_statement = select(func.count("*")).select_from(Client)
|
||||||
count_statement = filter_client_statement(count_statement, filter_query, True)
|
count_statement = cast(Select[tuple[int]], filter_client_statement(count_statement, filter_query, True))
|
||||||
|
|
||||||
total_results = session.exec(count_statement).one()
|
total_results = session.scalars(count_statement).one()
|
||||||
|
|
||||||
statement = filter_client_statement(select(Client), filter_query, False)
|
statement = filter_client_statement(select(Client), filter_query, False)
|
||||||
|
|
||||||
results = session.exec(statement)
|
results = session.scalars(statement)
|
||||||
remainder = total_results - filter_query.offset - filter_query.limit
|
remainder = total_results - filter_query.offset - filter_query.limit
|
||||||
if remainder < 0:
|
if remainder < 0:
|
||||||
remainder = 0
|
remainder = 0
|
||||||
@ -170,13 +170,12 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
status_code=404, detail="Cannot find a client with the given name."
|
status_code=404, detail="Cannot find a client with the given name."
|
||||||
)
|
)
|
||||||
client.public_key = client_update.public_key
|
client.public_key = client_update.public_key
|
||||||
for secret in session.exec(
|
for secret in session.scalars(
|
||||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||||
).all():
|
).all():
|
||||||
LOG.debug("Invalidated secret %s", secret.id)
|
LOG.debug("Invalidated secret %s", secret.id)
|
||||||
secret.invalidated = True
|
secret.invalidated = True
|
||||||
secret.client_id = None
|
secret.client_id = None
|
||||||
secret.client = None
|
|
||||||
|
|
||||||
session.add(client)
|
session.add(client)
|
||||||
session.refresh(client)
|
session.refresh(client)
|
||||||
@ -206,13 +205,12 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
public_key_updated = False
|
public_key_updated = False
|
||||||
if client_update.public_key != client.public_key:
|
if client_update.public_key != client.public_key:
|
||||||
public_key_updated = True
|
public_key_updated = True
|
||||||
for secret in session.exec(
|
for secret in session.scalars(
|
||||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||||
).all():
|
).all():
|
||||||
LOG.debug("Invalidated secret %s", secret.id)
|
LOG.debug("Invalidated secret %s", secret.id)
|
||||||
secret.invalidated = True
|
secret.invalidated = True
|
||||||
secret.client_id = None
|
secret.client_id = None
|
||||||
secret.client = None
|
|
||||||
|
|
||||||
session.add(client)
|
session.add(client)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|||||||
@ -4,7 +4,8 @@ import re
|
|||||||
import uuid
|
import uuid
|
||||||
import bcrypt
|
import bcrypt
|
||||||
|
|
||||||
from sqlmodel import Session, select
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from sshecret_backend.models import Client
|
from sshecret_backend.models import Client
|
||||||
|
|
||||||
@ -20,13 +21,13 @@ def verify_token(token: str, stored_hash: str) -> bool:
|
|||||||
async def get_client_by_name(session: Session, name: str) -> Client | None:
|
async def get_client_by_name(session: Session, name: str) -> Client | None:
|
||||||
"""Get client by name."""
|
"""Get client by name."""
|
||||||
client_filter = select(Client).where(Client.name == name)
|
client_filter = select(Client).where(Client.name == name)
|
||||||
client_results = session.exec(client_filter)
|
client_results = session.scalars(client_filter)
|
||||||
return client_results.first()
|
return client_results.first()
|
||||||
|
|
||||||
async def get_client_by_id(session: Session, id: uuid.UUID) -> Client | None:
|
async def get_client_by_id(session: Session, id: uuid.UUID) -> Client | None:
|
||||||
"""Get client by name."""
|
"""Get client by name."""
|
||||||
client_filter = select(Client).where(Client.id == id)
|
client_filter = select(Client).where(Client.id == id)
|
||||||
client_results = session.exec(client_filter)
|
client_results = session.scalars(client_filter)
|
||||||
return client_results.first()
|
return client_results.first()
|
||||||
|
|
||||||
async def get_client_by_id_or_name(session: Session, id_or_name: str) -> Client | None:
|
async def get_client_by_id_or_name(session: Session, id_or_name: str) -> Client | None:
|
||||||
|
|||||||
@ -4,7 +4,8 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
from sqlmodel import Session, select
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from sshecret_backend.models import ClientAccessPolicy
|
from sshecret_backend.models import ClientAccessPolicy
|
||||||
@ -54,7 +55,7 @@ def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
status_code=404, detail="Cannot find a client with the given name."
|
status_code=404, detail="Cannot find a client with the given name."
|
||||||
)
|
)
|
||||||
# Remove old policies.
|
# Remove old policies.
|
||||||
policies = session.exec(
|
policies = session.scalars(
|
||||||
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
|
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
|
||||||
).all()
|
).all()
|
||||||
deleted_policies: list[ClientAccessPolicy] = []
|
deleted_policies: list[ClientAccessPolicy] = []
|
||||||
|
|||||||
@ -5,7 +5,8 @@
|
|||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
from sqlmodel import Session, select
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from sshecret_backend.models import Client, ClientSecret
|
from sshecret_backend.models import Client, ClientSecret
|
||||||
@ -34,7 +35,7 @@ async def lookup_client_secret(
|
|||||||
.where(ClientSecret.client_id == client.id)
|
.where(ClientSecret.client_id == client.id)
|
||||||
.where(ClientSecret.name == name)
|
.where(ClientSecret.name == name)
|
||||||
)
|
)
|
||||||
results = session.exec(statement)
|
results = session.scalars(statement)
|
||||||
return results.first()
|
return results.first()
|
||||||
|
|
||||||
|
|
||||||
@ -165,7 +166,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
) -> list[ClientSecretList]:
|
) -> list[ClientSecretList]:
|
||||||
"""Get a list of all secrets and which clients have them."""
|
"""Get a list of all secrets and which clients have them."""
|
||||||
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
|
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
|
||||||
for client_secret in session.exec(select(ClientSecret)).all():
|
for client_secret in session.scalars(select(ClientSecret)).all():
|
||||||
if not client_secret.client:
|
if not client_secret.client:
|
||||||
if client_secret.name not in client_secret_map:
|
if client_secret.name not in client_secret_map:
|
||||||
client_secret_map[client_secret.name] = []
|
client_secret_map[client_secret.name] = []
|
||||||
@ -182,7 +183,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
) -> list[ClientSecretDetailList]:
|
) -> list[ClientSecretDetailList]:
|
||||||
"""Get a list of all secrets and which clients have them."""
|
"""Get a list of all secrets and which clients have them."""
|
||||||
client_secrets: dict[str, ClientSecretDetailList] = {}
|
client_secrets: dict[str, ClientSecretDetailList] = {}
|
||||||
for client_secret in session.exec(select(ClientSecret)).all():
|
for client_secret in session.scalars(select(ClientSecret)).all():
|
||||||
|
|
||||||
if client_secret.name not in client_secrets:
|
if client_secret.name not in client_secrets:
|
||||||
client_secrets[client_secret.name] = ClientSecretDetailList(name=client_secret.name)
|
client_secrets[client_secret.name] = ClientSecretDetailList(name=client_secret.name)
|
||||||
@ -202,7 +203,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
) -> ClientSecretList:
|
) -> ClientSecretList:
|
||||||
"""Get a list of which clients has a named secret."""
|
"""Get a list of which clients has a named secret."""
|
||||||
clients: list[str] = []
|
clients: list[str] = []
|
||||||
for client_secret in session.exec(
|
for client_secret in session.scalars(
|
||||||
select(ClientSecret).where(ClientSecret.name == name)
|
select(ClientSecret).where(ClientSecret.name == name)
|
||||||
).all():
|
).all():
|
||||||
if not client_secret.client:
|
if not client_secret.client:
|
||||||
@ -219,7 +220,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
) -> ClientSecretDetailList:
|
) -> ClientSecretDetailList:
|
||||||
"""Get a list of which clients has a named secret."""
|
"""Get a list of which clients has a named secret."""
|
||||||
detail_list = ClientSecretDetailList(name=name)
|
detail_list = ClientSecretDetailList(name=name)
|
||||||
for client_secret in session.exec(
|
for client_secret in session.scalars(
|
||||||
select(ClientSecret).where(ClientSecret.name == name)
|
select(ClientSecret).where(ClientSecret.name == name)
|
||||||
).all():
|
).all():
|
||||||
if not client_secret.client:
|
if not client_secret.client:
|
||||||
|
|||||||
@ -2,9 +2,10 @@
|
|||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from sqlmodel import Session, select
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy
|
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy, Operation, SubSystem
|
||||||
|
|
||||||
|
|
||||||
def _get_origin(request: Request) -> str | None:
|
def _get_origin(request: Request) -> str | None:
|
||||||
@ -22,7 +23,7 @@ def _write_audit_log(
|
|||||||
"""Write the audit log."""
|
"""Write the audit log."""
|
||||||
origin = _get_origin(request)
|
origin = _get_origin(request)
|
||||||
entry.origin = origin
|
entry.origin = origin
|
||||||
entry.subsystem = "backend"
|
entry.subsystem = SubSystem.BACKEND
|
||||||
session.add(entry)
|
session.add(entry)
|
||||||
if commit:
|
if commit:
|
||||||
session.commit()
|
session.commit()
|
||||||
@ -33,7 +34,7 @@ def audit_create_client(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Log the creation of a client."""
|
"""Log the creation of a client."""
|
||||||
entry = AuditLog(
|
entry = AuditLog(
|
||||||
operation="CREATE",
|
operation=Operation.CREATE,
|
||||||
client_id=client.id,
|
client_id=client.id,
|
||||||
client_name=client.name,
|
client_name=client.name,
|
||||||
message="Client Created",
|
message="Client Created",
|
||||||
@ -46,7 +47,7 @@ def audit_delete_client(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Log the creation of a client."""
|
"""Log the creation of a client."""
|
||||||
entry = AuditLog(
|
entry = AuditLog(
|
||||||
operation="CREATE",
|
operation=Operation.CREATE,
|
||||||
client_id=client.id,
|
client_id=client.id,
|
||||||
client_name=client.name,
|
client_name=client.name,
|
||||||
message="Client deleted",
|
message="Client deleted",
|
||||||
@ -63,9 +64,9 @@ def audit_create_secret(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Audit a create secret event."""
|
"""Audit a create secret event."""
|
||||||
entry = AuditLog(
|
entry = AuditLog(
|
||||||
operation="CREATE",
|
operation=Operation.CREATE,
|
||||||
object="ClientSecret",
|
secret_id=secret.id,
|
||||||
object_id=str(secret.id),
|
secret_name=secret.name,
|
||||||
client_id=client.id,
|
client_id=client.id,
|
||||||
client_name=client.name,
|
client_name=client.name,
|
||||||
message="Added secret to client",
|
message="Added secret to client",
|
||||||
@ -81,13 +82,13 @@ def audit_remove_policy(
|
|||||||
commit: bool = True,
|
commit: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Audit removal of policy."""
|
"""Audit removal of policy."""
|
||||||
|
data = {"object": "ClientAccessPolicy", "object_id": str(policy.id)}
|
||||||
entry = AuditLog(
|
entry = AuditLog(
|
||||||
operation="DELETE",
|
operation=Operation.DELETE,
|
||||||
object="ClientAccessPolicy",
|
|
||||||
object_id=str(policy.id),
|
|
||||||
client_id=client.id,
|
client_id=client.id,
|
||||||
client_name=client.name,
|
client_name=client.name,
|
||||||
message="Deleted client policy",
|
message="Deleted client policy",
|
||||||
|
data=data,
|
||||||
)
|
)
|
||||||
_write_audit_log(session, request, entry, commit)
|
_write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
@ -100,13 +101,13 @@ def audit_update_policy(
|
|||||||
commit: bool = True,
|
commit: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Audit update of policy."""
|
"""Audit update of policy."""
|
||||||
|
data: dict[str, str] = {"object": "ClientAccessPolicy", "object_id": str(policy.id)}
|
||||||
entry = AuditLog(
|
entry = AuditLog(
|
||||||
operation="CREATE",
|
operation=Operation.CREATE,
|
||||||
object="ClientAccessPolicy",
|
|
||||||
object_id=str(policy.id),
|
|
||||||
client_id=client.id,
|
|
||||||
client_name=client.name,
|
client_name=client.name,
|
||||||
|
client_id=client.id,
|
||||||
message="Updated client policy",
|
message="Updated client policy",
|
||||||
|
data=data,
|
||||||
)
|
)
|
||||||
_write_audit_log(session, request, entry, commit)
|
_write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
@ -119,11 +120,10 @@ def audit_update_client(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Audit an update secret event."""
|
"""Audit an update secret event."""
|
||||||
entry = AuditLog(
|
entry = AuditLog(
|
||||||
operation="UPDATE",
|
operation=Operation.UPDATE,
|
||||||
object="Client",
|
|
||||||
client_id=client.id,
|
client_id=client.id,
|
||||||
client_name=client.name,
|
client_name=client.name,
|
||||||
message="Client updated",
|
message="Client data updated",
|
||||||
)
|
)
|
||||||
_write_audit_log(session, request, entry, commit)
|
_write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
@ -137,11 +137,11 @@ def audit_update_secret(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Audit an update secret event."""
|
"""Audit an update secret event."""
|
||||||
entry = AuditLog(
|
entry = AuditLog(
|
||||||
operation="UPDATE",
|
operation=Operation.UPDATE,
|
||||||
object="ClientSecret",
|
|
||||||
object_id=str(secret.id),
|
|
||||||
client_id=client.id,
|
client_id=client.id,
|
||||||
client_name=client.name,
|
client_name=client.name,
|
||||||
|
secret_name=secret.name,
|
||||||
|
secret_id=secret.id,
|
||||||
message="Secret value updated",
|
message="Secret value updated",
|
||||||
)
|
)
|
||||||
_write_audit_log(session, request, entry, commit)
|
_write_audit_log(session, request, entry, commit)
|
||||||
@ -155,8 +155,7 @@ def audit_invalidate_secrets(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Audit Invalidate client secrets."""
|
"""Audit Invalidate client secrets."""
|
||||||
entry = AuditLog(
|
entry = AuditLog(
|
||||||
operation="INVALIDATE",
|
operation=Operation.UPDATE,
|
||||||
object="ClientSecret",
|
|
||||||
client_name=client.name,
|
client_name=client.name,
|
||||||
client_id=client.id,
|
client_id=client.id,
|
||||||
message="Client public-key changed. All secrets invalidated.",
|
message="Client public-key changed. All secrets invalidated.",
|
||||||
@ -173,9 +172,9 @@ def audit_delete_secret(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Audit Delete client secrets."""
|
"""Audit Delete client secrets."""
|
||||||
entry = AuditLog(
|
entry = AuditLog(
|
||||||
operation="DELETE",
|
operation=Operation.DELETE,
|
||||||
object="ClientSecret",
|
secret_name=secret.name,
|
||||||
object_id=str(secret.id),
|
secret_id=secret.id,
|
||||||
client_name=client.name,
|
client_name=client.name,
|
||||||
client_id=client.id,
|
client_id=client.id,
|
||||||
message="Deleted secret.",
|
message="Deleted secret.",
|
||||||
@ -195,7 +194,7 @@ def audit_access_secrets(
|
|||||||
With no secrets provided, all secrets of the client will be resolved.
|
With no secrets provided, all secrets of the client will be resolved.
|
||||||
"""
|
"""
|
||||||
if not secrets:
|
if not secrets:
|
||||||
secrets = session.exec(
|
secrets = session.scalars(
|
||||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||||
).all()
|
).all()
|
||||||
|
|
||||||
@ -215,37 +214,21 @@ def audit_access_secret(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Audit that someone accessed one secrets."""
|
"""Audit that someone accessed one secrets."""
|
||||||
entry = AuditLog(
|
entry = AuditLog(
|
||||||
operation="ACCESS",
|
operation=Operation.READ,
|
||||||
message="Secret was viewed",
|
message="Secret was viewed",
|
||||||
object="ClientSecret",
|
secret_name=secret.name,
|
||||||
object_id=str(secret.id),
|
secret_id=secret.id,
|
||||||
client_id=client.id,
|
client_id=client.id,
|
||||||
client_name=client.name,
|
client_name=client.name,
|
||||||
)
|
)
|
||||||
_write_audit_log(session, request, entry, commit)
|
_write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
|
|
||||||
def audit_access_audit_log(
|
|
||||||
session: Session, request: Request, commit: bool = True
|
|
||||||
) -> None:
|
|
||||||
"""Audit access to the audit log.
|
|
||||||
|
|
||||||
Because why not...
|
|
||||||
"""
|
|
||||||
entry = AuditLog(
|
|
||||||
operation="ACCESS",
|
|
||||||
message="Audit log was viewed",
|
|
||||||
object="AuditLog",
|
|
||||||
)
|
|
||||||
_write_audit_log(session, request, entry, commit)
|
|
||||||
|
|
||||||
|
|
||||||
def audit_client_secret_list(
|
def audit_client_secret_list(
|
||||||
session: Session, request: Request, commit: bool = True
|
session: Session, request: Request, commit: bool = True
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Audit a list of all secrets."""
|
"""Audit a list of all secrets."""
|
||||||
entry = AuditLog(
|
entry = AuditLog(
|
||||||
operation="ACCESS",
|
operation=Operation.READ,
|
||||||
message="All secret names and their clients was viewed",
|
message="All secret names and their clients was viewed",
|
||||||
)
|
)
|
||||||
_write_audit_log(session, request, entry, commit)
|
_write_audit_log(session, request, entry, commit)
|
||||||
|
|||||||
@ -3,11 +3,11 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
import bcrypt
|
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException
|
from fastapi import APIRouter, Depends, Header, HTTPException
|
||||||
from sqlmodel import Session, select
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
from .api import get_audit_api, get_clients_api, get_policy_api, get_secrets_api
|
from .api import get_audit_api, get_clients_api, get_policy_api, get_secrets_api
|
||||||
|
from .auth import verify_token
|
||||||
from .models import (
|
from .models import (
|
||||||
APIClient,
|
APIClient,
|
||||||
)
|
)
|
||||||
@ -18,13 +18,6 @@ LOG = logging.getLogger(__name__)
|
|||||||
API_VERSION = "v1"
|
API_VERSION = "v1"
|
||||||
|
|
||||||
|
|
||||||
def verify_token(token: str, stored_hash: str) -> bool:
|
|
||||||
"""Verify token."""
|
|
||||||
token_bytes = token.encode("utf-8")
|
|
||||||
stored_bytes = stored_hash.encode("utf-8")
|
|
||||||
return bcrypt.checkpw(token_bytes, stored_bytes)
|
|
||||||
|
|
||||||
|
|
||||||
def get_backend_api(
|
def get_backend_api(
|
||||||
get_db_session: DBSessionDep,
|
get_db_session: DBSessionDep,
|
||||||
) -> APIRouter:
|
) -> APIRouter:
|
||||||
@ -37,7 +30,7 @@ def get_backend_api(
|
|||||||
"""Validate token."""
|
"""Validate token."""
|
||||||
LOG.debug("Validating token %s", x_api_token)
|
LOG.debug("Validating token %s", x_api_token)
|
||||||
statement = select(APIClient)
|
statement = select(APIClient)
|
||||||
results = session.exec(statement)
|
results = session.scalars(statement)
|
||||||
valid = False
|
valid = False
|
||||||
for result in results:
|
for result in results:
|
||||||
if verify_token(x_api_token, result.token):
|
if verify_token(x_api_token, result.token):
|
||||||
|
|||||||
@ -3,15 +3,24 @@
|
|||||||
import code
|
import code
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import cast
|
from typing import Literal, cast
|
||||||
from dotenv import load_dotenv
|
|
||||||
import click
|
import click
|
||||||
from sqlmodel import Session, col, func, select
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from .db import get_engine, create_api_token
|
from .db import create_api_token, get_engine, hash_token
|
||||||
|
from .models import (
|
||||||
from .models import Client, ClientSecret, ClientAccessPolicy, AuditLog, APIClient, init_db
|
APIClient,
|
||||||
|
AuditLog,
|
||||||
|
Client,
|
||||||
|
ClientAccessPolicy,
|
||||||
|
ClientSecret,
|
||||||
|
SubSystem,
|
||||||
|
init_db,
|
||||||
|
)
|
||||||
from .settings import BackendSettings
|
from .settings import BackendSettings
|
||||||
|
|
||||||
DEFAULT_LISTEN = "127.0.0.1"
|
DEFAULT_LISTEN = "127.0.0.1"
|
||||||
@ -21,22 +30,44 @@ WORKDIR = Path(os.getcwd())
|
|||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
def generate_token(settings: BackendSettings) -> str:
|
|
||||||
|
def generate_token(
|
||||||
|
settings: BackendSettings, subsystem: Literal["admin", "sshd"]
|
||||||
|
) -> str:
|
||||||
"""Generate a token."""
|
"""Generate a token."""
|
||||||
engine = get_engine(settings.db_url)
|
engine = get_engine(settings.db_url)
|
||||||
init_db(engine)
|
init_db(engine)
|
||||||
with Session(engine) as session:
|
with Session(engine) as session:
|
||||||
token = create_api_token(session, True)
|
token = create_api_token(session, subsystem)
|
||||||
return token
|
return token
|
||||||
|
|
||||||
def count_tokens(settings: BackendSettings) -> int:
|
|
||||||
"""Count the amount of tokens created."""
|
def add_system_tokens(settings: BackendSettings) -> None:
|
||||||
|
"""Add token for subsystems."""
|
||||||
|
if not settings.admin_token and not settings.sshd_token:
|
||||||
|
# Tokens should be generated manually.
|
||||||
|
return
|
||||||
|
|
||||||
engine = get_engine(settings.db_url)
|
engine = get_engine(settings.db_url)
|
||||||
init_db(engine)
|
init_db(engine)
|
||||||
|
tokens: list[tuple[str, SubSystem]] = []
|
||||||
|
if admin_token := settings.admin_token:
|
||||||
|
tokens.append((admin_token, SubSystem.ADMIN))
|
||||||
|
if sshd_token := settings.sshd_token:
|
||||||
|
tokens.append((sshd_token, SubSystem.SSHD))
|
||||||
with Session(engine) as session:
|
with Session(engine) as session:
|
||||||
count = session.exec(select(func.count("*")).select_from(APIClient)).one()
|
for token, subsystem in tokens:
|
||||||
|
hashed_token = hash_token(token)
|
||||||
|
if existing := session.scalars(
|
||||||
|
select(APIClient).where(APIClient.subsystem == subsystem)
|
||||||
|
).first():
|
||||||
|
existing.token = hashed_token
|
||||||
|
else:
|
||||||
|
new_token = APIClient(token=hashed_token, subsystem=subsystem)
|
||||||
|
session.add(new_token)
|
||||||
|
|
||||||
return count
|
session.commit()
|
||||||
|
click.echo("Generated system tokens.")
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
@ -49,27 +80,30 @@ def cli(ctx: click.Context, database: str) -> None:
|
|||||||
else:
|
else:
|
||||||
settings = BackendSettings()
|
settings = BackendSettings()
|
||||||
|
|
||||||
|
add_system_tokens(settings)
|
||||||
|
|
||||||
if settings.generate_initial_tokens:
|
# if settings.generate_initial_tokens:
|
||||||
if count_tokens(settings) == 0:
|
# if count_tokens(settings) == 0:
|
||||||
click.echo("Creating initial tokens for admin and sshd.")
|
# click.echo("Creating initial tokens for admin and sshd.")
|
||||||
admin_token = generate_token(settings)
|
# admin_token = generate_token(settings)
|
||||||
sshd_token = generate_token(settings)
|
# sshd_token = generate_token(settings)
|
||||||
click.echo(f"Admin token: {admin_token}")
|
# click.echo(f"Admin token: {admin_token}")
|
||||||
click.echo(f"SSHD token: {sshd_token}")
|
# click.echo(f"SSHD token: {sshd_token}")
|
||||||
|
|
||||||
ctx.obj = settings
|
ctx.obj = settings
|
||||||
|
|
||||||
|
|
||||||
@cli.command("generate-token")
|
@cli.command("generate-token")
|
||||||
|
@click.argument("subsystem", type=click.Choice(["sshd", "admin"]))
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def cli_generate_token(ctx: click.Context) -> None:
|
def cli_generate_token(ctx: click.Context, subsystem: Literal["sshd", "admin"]) -> None:
|
||||||
"""Generate a token."""
|
"""Generate a token for a subsystem.."""
|
||||||
settings = cast(BackendSettings, ctx.obj)
|
settings = cast(BackendSettings, ctx.obj)
|
||||||
token = generate_token(settings)
|
token = generate_token(settings, subsystem)
|
||||||
click.echo("Generated api token:")
|
click.echo("Generated api token:")
|
||||||
click.echo(token)
|
click.echo(token)
|
||||||
|
|
||||||
|
|
||||||
@cli.command("run")
|
@cli.command("run")
|
||||||
@click.option("--host", default="127.0.0.1")
|
@click.option("--host", default="127.0.0.1")
|
||||||
@click.option("--port", default=8022, type=click.INT)
|
@click.option("--port", default=8022, type=click.INT)
|
||||||
@ -77,7 +111,10 @@ def cli_generate_token(ctx: click.Context) -> None:
|
|||||||
@click.option("--workers", type=click.INT)
|
@click.option("--workers", type=click.INT)
|
||||||
def cli_run(host: str, port: int, dev: bool, workers: int | None) -> None:
|
def cli_run(host: str, port: int, dev: bool, workers: int | None) -> None:
|
||||||
"""Run the server."""
|
"""Run the server."""
|
||||||
uvicorn.run("sshecret_backend.main:app", host=host, port=port, reload=dev, workers=workers)
|
uvicorn.run(
|
||||||
|
"sshecret_backend.main:app", host=host, port=port, reload=dev, workers=workers
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@cli.command("repl")
|
@cli.command("repl")
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
|
|||||||
@ -2,56 +2,108 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
from collections.abc import Generator, Callable
|
from collections.abc import Generator, Callable
|
||||||
from pathlib import Path
|
from typing import Literal
|
||||||
from sqlalchemy import Engine
|
from sqlalchemy import create_engine, Engine, event, select
|
||||||
from sqlmodel import Session, create_engine, text
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
|
||||||
import bcrypt
|
|
||||||
|
from sqlalchemy.orm import sessionmaker, Session
|
||||||
|
|
||||||
from sqlalchemy.engine import URL
|
from sqlalchemy.engine import URL
|
||||||
|
|
||||||
|
from .auth import hash_token, verify_token
|
||||||
from .models import APIClient
|
from .models import APIClient, SubSystem
|
||||||
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def setup_database(
|
def setup_database(
|
||||||
db_url: URL | str,
|
db_url: URL,
|
||||||
) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]:
|
) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]:
|
||||||
"""Setup database."""
|
"""Setup database."""
|
||||||
|
|
||||||
engine = create_engine(db_url, echo=False)
|
engine = get_engine(db_url)
|
||||||
with engine.connect() as connection:
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, future=True)
|
||||||
connection.execute(text("PRAGMA foreign_keys=ON")) # for SQLite only
|
|
||||||
|
|
||||||
def get_db_session() -> Generator[Session, None, None]:
|
def get_db_session() -> Generator[Session, None, None]:
|
||||||
"""Get DB Session."""
|
"""Get DB Session."""
|
||||||
with Session(engine) as session:
|
session = SessionLocal(bind=engine)
|
||||||
|
try:
|
||||||
yield session
|
yield session
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
|
||||||
return engine, get_db_session
|
return engine, get_db_session
|
||||||
|
|
||||||
|
|
||||||
def get_engine(url: URL, echo: bool = False) -> Engine:
|
def get_engine(url: URL, echo: bool = False) -> Engine:
|
||||||
"""Initialize the engine."""
|
"""Initialize the engine."""
|
||||||
engine = create_engine(url, echo=echo)
|
engine = create_engine(url, echo=echo, future=True)
|
||||||
with engine.connect() as connection:
|
if url.drivername.startswith("sqlite"):
|
||||||
connection.execute(text("PRAGMA foreign_keys=ON")) # for SQLite only
|
|
||||||
|
@event.listens_for(engine, "connect")
|
||||||
|
def set_sqlite_pragma(
|
||||||
|
dbapi_connection: sqlite3.Connection, _connection_record: object
|
||||||
|
) -> None:
|
||||||
|
cursor = dbapi_connection.cursor()
|
||||||
|
cursor.execute("PRAGMA foreign_keys=ON")
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
return engine
|
return engine
|
||||||
|
|
||||||
|
|
||||||
def create_api_token(session: Session, read_write: bool) -> str:
|
def get_async_engine(url: URL, echo: bool = False) -> AsyncEngine:
|
||||||
"""Create API token."""
|
"""Get an async engine."""
|
||||||
token = secrets.token_urlsafe(32)
|
engine = create_async_engine(url, echo=echo, future=True)
|
||||||
pwbytes = token.encode("utf-8")
|
if url.drivername.startswith("sqlite+"):
|
||||||
salt = bcrypt.gensalt()
|
|
||||||
hashed_bytes = bcrypt.hashpw(password=pwbytes, salt=salt)
|
|
||||||
hashed = hashed_bytes.decode()
|
|
||||||
|
|
||||||
api_token = APIClient(token=hashed, read_write=read_write)
|
@event.listens_for(engine, "connect")
|
||||||
|
def set_sqlite_pragma(
|
||||||
|
dbapi_connection: sqlite3.Connection, _connection_record: object
|
||||||
|
) -> None:
|
||||||
|
cursor = dbapi_connection.cursor()
|
||||||
|
cursor.execute("PRAGMA foreign_keys=ON")
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
return engine
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def create_api_token_with_value(session: Session, token: str, subsystem: Literal["admin", "sshd"]) -> None:
|
||||||
|
"""Create API token with a given value."""
|
||||||
|
|
||||||
|
existing = session.scalars(select(APIClient).where(APIClient.subsystem == SubSystem(subsystem))).first()
|
||||||
|
if existing:
|
||||||
|
if verify_token(token, existing.token):
|
||||||
|
LOG.info("Token is up to date.")
|
||||||
|
return
|
||||||
|
LOG.info("Updating token value for subsystem %s", subsystem)
|
||||||
|
hashed = hash_token(token)
|
||||||
|
existing.token=hashed
|
||||||
|
session.commit()
|
||||||
|
return
|
||||||
|
|
||||||
|
LOG.info("No existing token found. Creating new")
|
||||||
|
hashed = hash_token(token)
|
||||||
|
api_token = APIClient(token=hashed, subsystem=SubSystem(subsystem))
|
||||||
|
|
||||||
|
session.add(api_token)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
def create_api_token(session: Session, subsystem: Literal["admin", "sshd", "test"], recreate: bool = False) -> str:
|
||||||
|
"""Create API token."""
|
||||||
|
subsys = SubSystem(subsystem)
|
||||||
|
token = secrets.token_urlsafe(32)
|
||||||
|
hashed = hash_token(token)
|
||||||
|
if existing := session.scalars(select(APIClient).where(APIClient.subsystem == subsys)).first():
|
||||||
|
if not recreate:
|
||||||
|
raise RuntimeError("Error: A token already exist for this subsystem.")
|
||||||
|
existing.token = hashed
|
||||||
|
else:
|
||||||
|
api_token = APIClient(token=hashed, subsystem=subsys)
|
||||||
session.add(api_token)
|
session.add(api_token)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
|||||||
@ -7,128 +7,182 @@ This might require some changes to these schemas.
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import enum
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlmodel import JSON, Column, DateTime, Field, Relationship, SQLModel
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Client(SQLModel, table=True):
|
class SubSystem(enum.StrEnum):
|
||||||
"""Client model."""
|
"""Available subsystems."""
|
||||||
|
|
||||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
ADMIN = enum.auto()
|
||||||
name: str = Field(unique=True)
|
SSHD = enum.auto()
|
||||||
description: str | None = None
|
BACKEND = enum.auto()
|
||||||
public_key: str
|
TEST = enum.auto()
|
||||||
|
|
||||||
created_at: datetime | None = Field(
|
|
||||||
default=None,
|
class Operation(enum.StrEnum):
|
||||||
sa_type=sa.DateTime(timezone=True),
|
"""Various operations for the audit logging module."""
|
||||||
sa_column_kwargs={"server_default": sa.func.now()},
|
|
||||||
nullable=False,
|
CREATE = enum.auto()
|
||||||
|
READ = enum.auto()
|
||||||
|
UPDATE = enum.auto()
|
||||||
|
DELETE = enum.auto()
|
||||||
|
DENY = enum.auto()
|
||||||
|
PERMIT = enum.auto()
|
||||||
|
LOGIN = enum.auto()
|
||||||
|
REGISTER = enum.auto()
|
||||||
|
NONE = enum.auto()
|
||||||
|
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Client(Base):
|
||||||
|
"""Clients."""
|
||||||
|
|
||||||
|
__tablename__: str = "client"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||||
|
)
|
||||||
|
name: Mapped[str] = mapped_column(sa.String, unique=True)
|
||||||
|
description: Mapped[str | None] = mapped_column(sa.String, nullable=True)
|
||||||
|
public_key: Mapped[str] = mapped_column(sa.Text)
|
||||||
|
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||||
)
|
)
|
||||||
|
|
||||||
updated_at: datetime | None = Field(
|
updated_at: Mapped[datetime | None] = mapped_column(
|
||||||
default=None,
|
sa.DateTime(timezone=True),
|
||||||
sa_type=sa.DateTime(timezone=True),
|
server_default=sa.func.now(),
|
||||||
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
|
onupdate=sa.func.now(),
|
||||||
)
|
)
|
||||||
|
|
||||||
secrets: list["ClientSecret"] = Relationship(
|
secrets: Mapped[list["ClientSecret"]] = relationship(
|
||||||
back_populates="client", passive_deletes="all"
|
back_populates="client", passive_deletes=True
|
||||||
)
|
)
|
||||||
|
|
||||||
policies: list["ClientAccessPolicy"] = Relationship(back_populates="client")
|
policies: Mapped[list["ClientAccessPolicy"]] = relationship(back_populates="client")
|
||||||
|
|
||||||
|
|
||||||
class ClientAccessPolicy(SQLModel, table=True):
|
class ClientAccessPolicy(Base):
|
||||||
"""Client access policies."""
|
"""Client access policies."""
|
||||||
|
|
||||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
__tablename__: str = "client_access_policy"
|
||||||
source: str
|
|
||||||
client_id: uuid.UUID | None = Field(foreign_key="client.id", ondelete="CASCADE")
|
|
||||||
client: Client | None = Relationship(back_populates="policies")
|
|
||||||
|
|
||||||
created_at: datetime | None = Field(
|
id: Mapped[uuid.UUID] = mapped_column(
|
||||||
default=None,
|
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||||
sa_type=sa.DateTime(timezone=True),
|
)
|
||||||
sa_column_kwargs={"server_default": sa.func.now()},
|
source: Mapped[str] = mapped_column(sa.String)
|
||||||
nullable=False,
|
client_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
|
sa.Uuid(as_uuid=True), sa.ForeignKey("client.id", ondelete="CASCADE")
|
||||||
|
)
|
||||||
|
client: Mapped[Client] = relationship(back_populates="policies")
|
||||||
|
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||||
)
|
)
|
||||||
|
|
||||||
updated_at: datetime | None = Field(
|
updated_at: Mapped[datetime | None] = mapped_column(
|
||||||
default=None,
|
sa.DateTime(timezone=True),
|
||||||
sa_type=sa.DateTime(timezone=True),
|
server_default=sa.func.now(),
|
||||||
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
|
onupdate=sa.func.now(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ClientSecret(SQLModel, table=True):
|
class ClientSecret(Base):
|
||||||
"""A client secret."""
|
"""A client secret."""
|
||||||
|
|
||||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
__tablename__: str = "client_secret"
|
||||||
name: str
|
|
||||||
description: str | None = None
|
id: Mapped[uuid.UUID] = mapped_column(
|
||||||
client_id: uuid.UUID | None = Field(foreign_key="client.id", ondelete="CASCADE")
|
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||||
client: Client | None = Relationship(back_populates="secrets")
|
)
|
||||||
secret: str
|
name: Mapped[str] = mapped_column(sa.String)
|
||||||
invalidated: bool = Field(default=False)
|
description: Mapped[str | None] = mapped_column(sa.String, nullable=True)
|
||||||
created_at: datetime | None = Field(
|
secret: Mapped[str] = mapped_column(sa.String)
|
||||||
default=None,
|
|
||||||
sa_type=sa.DateTime(timezone=True),
|
client_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
sa_column_kwargs={"server_default": sa.func.now()},
|
sa.Uuid(as_uuid=True), sa.ForeignKey("client.id", ondelete="CASCADE")
|
||||||
nullable=False,
|
)
|
||||||
|
client: Mapped[Client] = relationship(back_populates="secrets")
|
||||||
|
invalidated: Mapped[bool] = mapped_column(default=False)
|
||||||
|
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||||
)
|
)
|
||||||
|
|
||||||
updated_at: datetime | None = Field(
|
updated_at: Mapped[datetime | None] = mapped_column(
|
||||||
default=None,
|
sa.DateTime(timezone=True),
|
||||||
sa_type=sa.DateTime(timezone=True),
|
server_default=sa.func.now(),
|
||||||
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
|
onupdate=sa.func.now(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class AuditLog(SQLModel, table=True):
|
class APIClient(Base):
|
||||||
|
"""A client on the API.
|
||||||
|
|
||||||
|
This should eventually get more granular permissions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__: str = "api_client"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||||
|
)
|
||||||
|
subsystem: Mapped[SubSystem | None] = mapped_column(sa.String, nullable=True)
|
||||||
|
token: Mapped[str] = mapped_column(sa.String)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime | None] = mapped_column(
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.func.now(),
|
||||||
|
onupdate=sa.func.now(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AuditLog(Base):
|
||||||
"""Audit log.
|
"""Audit log.
|
||||||
|
|
||||||
This is implemented without any foreign keys to avoid losing data on
|
This is implemented without any foreign keys to avoid losing data on
|
||||||
deletions.
|
deletions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
__tablename__: str = "audit_log"
|
||||||
subsystem: str
|
id: Mapped[uuid.UUID] = mapped_column(
|
||||||
message: str
|
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||||
operation: str
|
|
||||||
client_id: uuid.UUID | None = None
|
|
||||||
client_name: str | None = None
|
|
||||||
origin: str | None = None
|
|
||||||
Field(default=None, sa_column=Column(JSON))
|
|
||||||
|
|
||||||
timestamp: datetime | None = Field(
|
|
||||||
default=None,
|
|
||||||
sa_type=sa.DateTime(timezone=True),
|
|
||||||
sa_column_kwargs={"server_default": sa.func.now()},
|
|
||||||
nullable=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
subsystem: Mapped[SubSystem] = mapped_column(sa.String)
|
||||||
|
message: Mapped[str] = mapped_column(sa.String)
|
||||||
|
operation: Mapped[Operation] = mapped_column(sa.String)
|
||||||
|
client_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
|
sa.Uuid(as_uuid=True), nullable=True
|
||||||
|
)
|
||||||
|
data: Mapped[dict[str, str] | None] = mapped_column(sa.JSON, nullable=True)
|
||||||
|
client_name: Mapped[str | None] = mapped_column(sa.String, nullable=True)
|
||||||
|
secret_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
|
sa.Uuid(as_uuid=True), nullable=True
|
||||||
|
)
|
||||||
|
secret_name: Mapped[str | None] = mapped_column(sa.String, nullable=True)
|
||||||
|
|
||||||
class APIClient(SQLModel, table=True):
|
origin: Mapped[str | None] = mapped_column(sa.String, nullable=True)
|
||||||
"""Stores API Keys."""
|
timestamp: Mapped[datetime] = mapped_column(
|
||||||
|
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
|
||||||
token: str
|
|
||||||
read_write: bool
|
|
||||||
created_at: datetime | None = Field(
|
|
||||||
default=None,
|
|
||||||
sa_type=sa.DateTime(timezone=True),
|
|
||||||
sa_column_kwargs={"server_default": sa.func.now()},
|
|
||||||
nullable=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def init_db(engine: sa.Engine) -> None:
|
def init_db(engine: sa.Engine) -> None:
|
||||||
"""Create database."""
|
"""Initialize database."""
|
||||||
LOG.info("Running init_db")
|
Base.metadata.create_all(engine)
|
||||||
SQLModel.metadata.create_all(engine)
|
|
||||||
|
|||||||
@ -1,12 +1,10 @@
|
|||||||
"""Settings management."""
|
"""Settings management."""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated, Any
|
from pydantic import Field
|
||||||
from pydantic import Field, field_validator
|
|
||||||
from pydantic_settings import (
|
from pydantic_settings import (
|
||||||
BaseSettings,
|
BaseSettings,
|
||||||
SettingsConfigDict,
|
SettingsConfigDict,
|
||||||
ForceDecode,
|
|
||||||
)
|
)
|
||||||
from sqlalchemy import URL
|
from sqlalchemy import URL
|
||||||
|
|
||||||
@ -22,24 +20,19 @@ class BackendSettings(BaseSettings):
|
|||||||
)
|
)
|
||||||
|
|
||||||
database: str = Field(default=DEFAULT_DATABASE)
|
database: str = Field(default=DEFAULT_DATABASE)
|
||||||
generate_initial_tokens: Annotated[bool, ForceDecode] = Field(default=False)
|
admin_token: str | None = Field(default=None, alias="sshecret_admin_backend_token")
|
||||||
|
sshd_token: str | None = Field(default=None, alias="sshecret_sshd_backend_token")
|
||||||
@field_validator("generate_initial_tokens", mode="before")
|
|
||||||
@classmethod
|
|
||||||
def cast_bool(cls, value: Any) -> bool:
|
|
||||||
"""Ensure we catch the boolean."""
|
|
||||||
if isinstance(value, str):
|
|
||||||
if value.lower() in ("1", "true", "on"):
|
|
||||||
return True
|
|
||||||
if value.lower() in ("0", "false", "off"):
|
|
||||||
return False
|
|
||||||
return bool(value)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def db_url(self) -> URL:
|
def db_url(self) -> URL:
|
||||||
"""Construct database url."""
|
"""Construct database url."""
|
||||||
return URL.create(drivername="sqlite", database=self.database)
|
return URL.create(drivername="sqlite", database=self.database)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def async_db_url(self) -> URL:
|
||||||
|
"""Construct database url with sync handling."""
|
||||||
|
return URL.create(drivername="sqlite+aiosqlite", database=self.database)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def db_exists(self) -> bool:
|
def db_exists(self) -> bool:
|
||||||
"""Check if databatase exists."""
|
"""Check if databatase exists."""
|
||||||
|
|||||||
@ -1,8 +1,9 @@
|
|||||||
"""Test helpers."""
|
"""Test helpers."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from sqlmodel import Session
|
|
||||||
from sshecret_backend.settings import BackendSettings
|
from sshecret_backend.settings import BackendSettings
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
from .models import init_db
|
from .models import init_db
|
||||||
from .db import create_api_token, setup_database
|
from .db import create_api_token, setup_database
|
||||||
|
|
||||||
@ -14,4 +15,4 @@ def create_test_token(settings: BackendSettings) -> str:
|
|||||||
engine, _setupdb = setup_database(settings.db_url)
|
engine, _setupdb = setup_database(settings.db_url)
|
||||||
with Session(engine) as session:
|
with Session(engine) as session:
|
||||||
init_db(engine)
|
init_db(engine)
|
||||||
return create_api_token(session, True)
|
return create_api_token(session, "test")
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from collections.abc import Callable, Generator
|
from collections.abc import Callable, Generator
|
||||||
|
|
||||||
from sqlmodel import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
|
||||||
DBSessionDep = Callable[[], Generator[Session, None, None]]
|
DBSessionDep = Callable[[], Generator[Session, None, None]]
|
||||||
|
|||||||
@ -4,15 +4,14 @@ import uuid
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Annotated, Self, override
|
from typing import Annotated, Self, override
|
||||||
|
|
||||||
from sqlmodel import Field, SQLModel
|
from pydantic import AfterValidator, BaseModel, Field, IPvAnyAddress, IPvAnyNetwork
|
||||||
from pydantic import AfterValidator, IPvAnyAddress, IPvAnyNetwork
|
|
||||||
|
|
||||||
from sshecret.crypto import public_key_validator
|
from sshecret.crypto import public_key_validator
|
||||||
|
|
||||||
from . import models
|
from . import models
|
||||||
|
|
||||||
|
|
||||||
class ClientView(SQLModel):
|
class ClientView(BaseModel):
|
||||||
"""View for a single client."""
|
"""View for a single client."""
|
||||||
|
|
||||||
id: uuid.UUID
|
id: uuid.UUID
|
||||||
@ -50,7 +49,7 @@ class ClientView(SQLModel):
|
|||||||
return view
|
return view
|
||||||
|
|
||||||
|
|
||||||
class ClientQueryResult(SQLModel):
|
class ClientQueryResult(BaseModel):
|
||||||
"""Result class for queries towards the client list."""
|
"""Result class for queries towards the client list."""
|
||||||
|
|
||||||
clients: list[ClientView] = Field(default_factory=list)
|
clients: list[ClientView] = Field(default_factory=list)
|
||||||
@ -58,7 +57,7 @@ class ClientQueryResult(SQLModel):
|
|||||||
remaining_results: int
|
remaining_results: int
|
||||||
|
|
||||||
|
|
||||||
class ClientCreate(SQLModel):
|
class ClientCreate(BaseModel):
|
||||||
"""Model to create a client."""
|
"""Model to create a client."""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
@ -74,19 +73,19 @@ class ClientCreate(SQLModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ClientUpdate(SQLModel):
|
class ClientUpdate(BaseModel):
|
||||||
"""Model to update the client public key."""
|
"""Model to update the client public key."""
|
||||||
|
|
||||||
public_key: Annotated[str, AfterValidator(public_key_validator)]
|
public_key: Annotated[str, AfterValidator(public_key_validator)]
|
||||||
|
|
||||||
|
|
||||||
class BodyValue(SQLModel):
|
class BodyValue(BaseModel):
|
||||||
"""A generic model with just a value parameter."""
|
"""A generic model with just a value parameter."""
|
||||||
|
|
||||||
value: str
|
value: str
|
||||||
|
|
||||||
|
|
||||||
class ClientSecretPublic(SQLModel):
|
class ClientSecretPublic(BaseModel):
|
||||||
"""Public model to manage client secrets."""
|
"""Public model to manage client secrets."""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
@ -122,7 +121,7 @@ class ClientSecretResponse(ClientSecretPublic):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ClientPolicyView(SQLModel):
|
class ClientPolicyView(BaseModel):
|
||||||
"""Update object for client policy."""
|
"""Update object for client policy."""
|
||||||
|
|
||||||
sources: list[str] = ["0.0.0.0/0", "::/0"]
|
sources: list[str] = ["0.0.0.0/0", "::/0"]
|
||||||
@ -135,27 +134,27 @@ class ClientPolicyView(SQLModel):
|
|||||||
return cls(sources=[policy.source for policy in client.policies])
|
return cls(sources=[policy.source for policy in client.policies])
|
||||||
|
|
||||||
|
|
||||||
class ClientPolicyUpdate(SQLModel):
|
class ClientPolicyUpdate(BaseModel):
|
||||||
"""Model for updating policies."""
|
"""Model for updating policies."""
|
||||||
|
|
||||||
sources: list[IPvAnyAddress | IPvAnyNetwork]
|
sources: list[IPvAnyAddress | IPvAnyNetwork]
|
||||||
|
|
||||||
|
|
||||||
class ClientSecretList(SQLModel):
|
class ClientSecretList(BaseModel):
|
||||||
"""Model for aggregating identically named secrets."""
|
"""Model for aggregating identically named secrets."""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
clients: list[str]
|
clients: list[str]
|
||||||
|
|
||||||
|
|
||||||
class ClientReference(SQLModel):
|
class ClientReference(BaseModel):
|
||||||
"""Reference to a client."""
|
"""Reference to a client."""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
|
|
||||||
class ClientSecretDetailList(SQLModel):
|
class ClientSecretDetailList(BaseModel):
|
||||||
"""A more detailed version of the ClientSecretList."""
|
"""A more detailed version of the ClientSecretList."""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
@ -163,7 +162,22 @@ class ClientSecretDetailList(SQLModel):
|
|||||||
clients: list[ClientReference] = Field(default_factory=list)
|
clients: list[ClientReference] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class AuditInfo(SQLModel):
|
class AuditView(BaseModel):
|
||||||
|
"""Audit log view."""
|
||||||
|
|
||||||
|
|
||||||
|
id: uuid.UUID | None = None
|
||||||
|
subsystem: models.SubSystem
|
||||||
|
message: str
|
||||||
|
operation: models.Operation
|
||||||
|
data: dict[str, str] | None = None
|
||||||
|
client_id: uuid.UUID | None = None
|
||||||
|
client_name: str | None = None
|
||||||
|
origin: str | None = None
|
||||||
|
timestamp: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class AuditInfo(BaseModel):
|
||||||
"""Information about audit information."""
|
"""Information about audit information."""
|
||||||
|
|
||||||
entries: int
|
entries: int
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from fastapi.testclient import TestClient
|
|||||||
from sshecret.crypto import generate_private_key, generate_public_key_string
|
from sshecret.crypto import generate_private_key, generate_public_key_string
|
||||||
from sshecret_backend.app import create_backend_app
|
from sshecret_backend.app import create_backend_app
|
||||||
from sshecret_backend.testing import create_test_token
|
from sshecret_backend.testing import create_test_token
|
||||||
from sshecret_backend.models import AuditLog
|
from sshecret_backend.view_models import AuditView
|
||||||
from sshecret_backend.settings import BackendSettings
|
from sshecret_backend.settings import BackendSettings
|
||||||
|
|
||||||
|
|
||||||
@ -53,7 +53,7 @@ def create_client_fixture(tmp_path: Path):
|
|||||||
|
|
||||||
db_file = tmp_path / "backend.db"
|
db_file = tmp_path / "backend.db"
|
||||||
print(f"DB File: {db_file.absolute()}")
|
print(f"DB File: {db_file.absolute()}")
|
||||||
settings = BackendSettings(db_url=f"sqlite:///{db_file.absolute()}")
|
settings = BackendSettings(database=str(db_file.absolute()))
|
||||||
app = create_backend_app(settings)
|
app = create_backend_app(settings)
|
||||||
|
|
||||||
token = create_test_token(settings)
|
token = create_test_token(settings)
|
||||||
@ -213,7 +213,7 @@ def test_audit_logging(test_client: TestClient) -> None:
|
|||||||
assert len(audit_logs) > 0
|
assert len(audit_logs) > 0
|
||||||
for entry in audit_logs:
|
for entry in audit_logs:
|
||||||
# Let's try to reassemble the objects
|
# Let's try to reassemble the objects
|
||||||
audit_log = AuditLog.model_validate(entry)
|
audit_log = AuditView.model_validate(entry)
|
||||||
assert audit_log is not None
|
assert audit_log is not None
|
||||||
|
|
||||||
|
|
||||||
@ -522,9 +522,8 @@ def test_operations_with_id(test_client: TestClient) -> None:
|
|||||||
def test_write_audit_log(test_client: TestClient) -> None:
|
def test_write_audit_log(test_client: TestClient) -> None:
|
||||||
"""Test writing to the audit log."""
|
"""Test writing to the audit log."""
|
||||||
params = {
|
params = {
|
||||||
"object": "Test",
|
"subsystem": "backend",
|
||||||
"operation": "TEST",
|
"operation": "read",
|
||||||
"object_id": "Something",
|
|
||||||
"message": "Test Message"
|
"message": "Test Message"
|
||||||
}
|
}
|
||||||
resp = test_client.post("/api/v1/audit", json=params)
|
resp = test_client.post("/api/v1/audit", json=params)
|
||||||
|
|||||||
@ -8,8 +8,10 @@ from .models import (
|
|||||||
ClientReference,
|
ClientReference,
|
||||||
ClientSecret,
|
ClientSecret,
|
||||||
DetailedSecrets,
|
DetailedSecrets,
|
||||||
|
Operation,
|
||||||
Policy,
|
Policy,
|
||||||
Secret,
|
Secret,
|
||||||
|
SubSystem,
|
||||||
FilterType,
|
FilterType,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -22,7 +24,9 @@ __all__ = [
|
|||||||
"ClientSecret",
|
"ClientSecret",
|
||||||
"DetailedSecrets",
|
"DetailedSecrets",
|
||||||
"FilterType",
|
"FilterType",
|
||||||
|
"Operation",
|
||||||
"Policy",
|
"Policy",
|
||||||
"Secret",
|
"Secret",
|
||||||
|
"SubSystem",
|
||||||
"SshecretBackend",
|
"SshecretBackend",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,21 +1,26 @@
|
|||||||
"""Backend client.
|
"""Backend client.
|
||||||
|
|
||||||
|
This is an API calling the HTTP API of the sshecret-backend package so that the
|
||||||
|
admin and sshd library do not need to implement the same
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Self
|
from typing import Any, Self, override
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from pydantic import TypeAdapter
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
from .models import (
|
from .models import (
|
||||||
|
AuditInfo,
|
||||||
AuditLog,
|
AuditLog,
|
||||||
Client,
|
Client,
|
||||||
ClientSecret,
|
ClientSecret,
|
||||||
ClientQueryResult,
|
ClientQueryResult,
|
||||||
ClientFilter,
|
ClientFilter,
|
||||||
DetailedSecrets,
|
DetailedSecrets,
|
||||||
|
Operation,
|
||||||
Secret,
|
Secret,
|
||||||
|
SubSystem,
|
||||||
)
|
)
|
||||||
from .exceptions import BackendValidationError, BackendConnectionError
|
from .exceptions import BackendValidationError, BackendConnectionError
|
||||||
from .utils import validate_public_key
|
from .utils import validate_public_key
|
||||||
@ -84,12 +89,14 @@ class ClientQueryIterator:
|
|||||||
raise StopAsyncIteration
|
raise StopAsyncIteration
|
||||||
|
|
||||||
|
|
||||||
class SshecretBackend:
|
class BaseBackend:
|
||||||
"""Backend interface."""
|
"""Base backend class."""
|
||||||
|
|
||||||
def __init__(self, backend_url: str, api_token: str) -> None:
|
def __init__(self, backend_url: str, api_token: str) -> None:
|
||||||
"""Initialize backend client."""
|
"""Initialize backend client."""
|
||||||
|
|
||||||
|
self._backend_url: str = backend_url
|
||||||
|
self._api_token: str = api_token
|
||||||
url = httpx.URL(backend_url)
|
url = httpx.URL(backend_url)
|
||||||
|
|
||||||
self.http_client: httpx.AsyncClient = httpx.AsyncClient(
|
self.http_client: httpx.AsyncClient = httpx.AsyncClient(
|
||||||
@ -101,31 +108,55 @@ class SshecretBackend:
|
|||||||
base_url=url,
|
base_url=url,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _get(self, path: str) -> httpx.Response:
|
def _handle_response(self, response: httpx.Response) -> httpx.Response:
|
||||||
|
"""Handle response."""
|
||||||
|
LOG.debug("Handling response with status_code %s", response.status_code)
|
||||||
|
if response.status_code == 422:
|
||||||
|
LOG.error("Validation error from backend:\n%s", response.text)
|
||||||
|
raise BackendValidationError(response.text)
|
||||||
|
if response.status_code != 404 and str(response.status_code).startswith("4"):
|
||||||
|
raise BackendConnectionError("Error from backend: %s", response.text)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def _get(
|
||||||
|
self, path: str, params: dict[str, str] | None = None
|
||||||
|
) -> httpx.Response:
|
||||||
"""Perform a get request."""
|
"""Perform a get request."""
|
||||||
try:
|
try:
|
||||||
return await self.http_client.get(path)
|
response = await self.http_client.get(path, params=params)
|
||||||
|
return self._handle_response(response)
|
||||||
except httpx.ConnectError as e:
|
except httpx.ConnectError as e:
|
||||||
raise BackendConnectionError() from e
|
raise BackendConnectionError("Could not connect to backend.") from e
|
||||||
|
|
||||||
async def _delete(self, path: str) -> httpx.Response:
|
async def _delete(self, path: str) -> httpx.Response:
|
||||||
"""Perform a delete request."""
|
"""Perform a delete request."""
|
||||||
try:
|
try:
|
||||||
return await self.http_client.delete(path)
|
response = await self.http_client.delete(path)
|
||||||
|
return self._handle_response(response)
|
||||||
except httpx.ConnectError as e:
|
except httpx.ConnectError as e:
|
||||||
raise BackendConnectionError() from e
|
raise BackendConnectionError() from e
|
||||||
|
|
||||||
async def _post(self, path: str, json: Any | None = None) -> httpx.Response:
|
async def _post(self, path: str, json: Any | None = None) -> httpx.Response:
|
||||||
"""Perform a POST request."""
|
"""Perform a POST request."""
|
||||||
try:
|
try:
|
||||||
return await self.http_client.post(path, json=json)
|
response = await self.http_client.post(path, json=json)
|
||||||
|
return self._handle_response(response)
|
||||||
|
except httpx.ConnectError as e:
|
||||||
|
raise BackendConnectionError() from e
|
||||||
|
|
||||||
|
def _post_sync(self, path: str, json: Any | None = None) -> httpx.Response:
|
||||||
|
"""Perform a synchronous post request."""
|
||||||
|
try:
|
||||||
|
return self._handle_response(self.sync_client.post(path, json=json))
|
||||||
except httpx.ConnectError as e:
|
except httpx.ConnectError as e:
|
||||||
raise BackendConnectionError() from e
|
raise BackendConnectionError() from e
|
||||||
|
|
||||||
async def _put(self, path: str, json: Any | None = None) -> httpx.Response:
|
async def _put(self, path: str, json: Any | None = None) -> httpx.Response:
|
||||||
"""Perform a PUT request."""
|
"""Perform a PUT request."""
|
||||||
try:
|
try:
|
||||||
return await self.http_client.put(path, json=json)
|
response = await self.http_client.put(path, json=json)
|
||||||
|
return self._handle_response(response)
|
||||||
except httpx.ConnectError as e:
|
except httpx.ConnectError as e:
|
||||||
raise BackendConnectionError() from e
|
raise BackendConnectionError() from e
|
||||||
|
|
||||||
@ -134,175 +165,90 @@ class SshecretBackend:
|
|||||||
response = await self._get(path)
|
response = await self._get(path)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def create_client(
|
|
||||||
self, name: str, public_key: str, description: str | None = None
|
|
||||||
) -> None:
|
|
||||||
"""Register a new client."""
|
|
||||||
if not validate_public_key(public_key):
|
|
||||||
raise BackendValidationError("Error: Invalid public key format.")
|
|
||||||
data = {
|
|
||||||
"name": name,
|
|
||||||
"public_key": public_key,
|
|
||||||
}
|
|
||||||
if description:
|
|
||||||
data["description"] = description
|
|
||||||
path = "/api/v1/clients/"
|
|
||||||
response = await self._post(path, json=data)
|
|
||||||
|
|
||||||
response.raise_for_status()
|
class AuditAPI(BaseBackend):
|
||||||
|
"""API for the audit logging."""
|
||||||
|
|
||||||
async def get_clients(self, filter: ClientFilter | None = None) -> list[Client]:
|
@override
|
||||||
"""Get all clients."""
|
def __init__(self, backend_url: str, api_token: str, subsystem: str) -> None:
|
||||||
clients: list[Client] = []
|
"""Initialize backend client."""
|
||||||
async for client in ClientQueryIterator(self.http_client, filter):
|
super().__init__(backend_url, api_token)
|
||||||
clients.append(client)
|
self.subsystem: SubSystem = SubSystem(subsystem)
|
||||||
|
|
||||||
return clients
|
def _create_model(
|
||||||
|
self,
|
||||||
async def get_client(self, name: str) -> Client | None:
|
operation: Operation,
|
||||||
"""Lookup a client on username."""
|
message: str,
|
||||||
path = f"/api/v1/clients/{name}"
|
origin: str,
|
||||||
response = await self.request(path)
|
client: Client | None = None,
|
||||||
if response.status_code == 404:
|
secret: ClientSecret | None = None,
|
||||||
return None
|
secret_name: str | None = None,
|
||||||
response.raise_for_status()
|
data: dict[str, str] | None = None,
|
||||||
client = Client.model_validate(response.json())
|
) -> AuditLog:
|
||||||
return client
|
"""Create the audit log object."""
|
||||||
|
model = AuditLog(
|
||||||
async def get_client_by_id(self, id: str) -> Client | None:
|
subsystem=self.subsystem,
|
||||||
"""Lookup a client on username."""
|
operation=operation,
|
||||||
path = f"/api/v1/clients/id/{id}"
|
message=message,
|
||||||
response = await self.request(path)
|
origin=origin,
|
||||||
if response.status_code == 404:
|
|
||||||
return None
|
|
||||||
response.raise_for_status()
|
|
||||||
client = Client.model_validate(response.json())
|
|
||||||
return client
|
|
||||||
|
|
||||||
async def delete_client(self, client_name: str) -> None:
|
|
||||||
"""Delete a client."""
|
|
||||||
path = f"/api/v1/clients/{client_name}"
|
|
||||||
response = await self._delete(path)
|
|
||||||
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
async def delete_client_by_id(self, id: str) -> None:
|
|
||||||
"""Delete a client."""
|
|
||||||
path = f"/api/v1/clients/id/{id}"
|
|
||||||
response = await self._delete(path)
|
|
||||||
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
async def create_client_secret(
|
|
||||||
self, client_name: str, secret_name: str, encrypted_secret: str
|
|
||||||
) -> None:
|
|
||||||
"""Create a secret.
|
|
||||||
|
|
||||||
This will overwrite any existing secret with that name.
|
|
||||||
"""
|
|
||||||
path = f"api/v1/clients/{client_name}/secrets/{secret_name}"
|
|
||||||
response = await self._put(path, json={"value": encrypted_secret})
|
|
||||||
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
async def get_client_secret(self, name: str, secret_name: str) -> str:
|
|
||||||
"""Fetch a secret."""
|
|
||||||
path = f"/api/v1/clients/{name}/secrets/{secret_name}"
|
|
||||||
response = await self.request(path)
|
|
||||||
response.raise_for_status()
|
|
||||||
secret = ClientSecret.model_validate(response.json())
|
|
||||||
return secret.secret
|
|
||||||
|
|
||||||
async def delete_client_secret(self, client_name: str, secret_name: str) -> None:
|
|
||||||
"""Delete a secret from a client."""
|
|
||||||
path = f"api/v1/clients/{client_name}/secrets/{secret_name}"
|
|
||||||
response = await self._delete(path)
|
|
||||||
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
async def update_client(self, client: Client) -> Client:
|
|
||||||
"""Update the client."""
|
|
||||||
path = f"/api/v1/clients/{client.name}"
|
|
||||||
client_update = {
|
|
||||||
"name": client.name,
|
|
||||||
"description": client.description,
|
|
||||||
"public_key": client.public_key,
|
|
||||||
}
|
|
||||||
response = await self._put(path, json=client_update)
|
|
||||||
LOG.info("Response %s", response.text)
|
|
||||||
|
|
||||||
response.raise_for_status()
|
|
||||||
if client.policies:
|
|
||||||
await self.update_client_sources(
|
|
||||||
str(client.id), [str(source) for source in client.policies]
|
|
||||||
)
|
)
|
||||||
return client
|
if client:
|
||||||
|
model.client_id = str(client.id) or None
|
||||||
|
model.client_name = client.name
|
||||||
|
|
||||||
async def update_client_key(self, client_name: str, public_key: str) -> None:
|
if secret:
|
||||||
"""Update the client key."""
|
model.secret_name = secret.name
|
||||||
path = f"/api/v1/clients/{client_name}/public-key"
|
elif secret_name:
|
||||||
response = await self._post(path, json={"public_key": public_key})
|
model.secret_name = secret_name
|
||||||
|
if data:
|
||||||
|
model.data = data
|
||||||
|
|
||||||
response.raise_for_status()
|
return model
|
||||||
|
|
||||||
async def update_client_sources(
|
def write_model(self, model: AuditLog) -> None:
|
||||||
self, client_name: str, addresses: list[str] | None
|
"""Write model."""
|
||||||
|
path = f"/api/v1/audit/"
|
||||||
|
self.sync_client.post(path, json=model.model_dump())
|
||||||
|
|
||||||
|
async def write_model_async(self, model: AuditLog) -> None:
|
||||||
|
"""Write model async."""
|
||||||
|
path = f"/api/v1/audit/"
|
||||||
|
await self._post(path, json=model.model_dump())
|
||||||
|
|
||||||
|
def write(
|
||||||
|
self,
|
||||||
|
operation: Operation,
|
||||||
|
message: str,
|
||||||
|
origin: str,
|
||||||
|
client: Client | None = None,
|
||||||
|
secret: ClientSecret | None = None,
|
||||||
|
secret_name: str | None = None,
|
||||||
|
**data: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update client source addresses.
|
"""Write an audit entry."""
|
||||||
|
model = self._create_model(
|
||||||
|
operation, message, origin, client, secret, secret_name, data
|
||||||
|
)
|
||||||
|
|
||||||
Pass None to sources to allow from all.
|
self.write_model(model)
|
||||||
"""
|
|
||||||
if not addresses:
|
|
||||||
addresses = []
|
|
||||||
|
|
||||||
path = f"/api/v1/clients/{client_name}/policies/"
|
async def write_async(
|
||||||
response = await self._put(path, json={"sources": addresses})
|
self,
|
||||||
|
operation: Operation,
|
||||||
|
message: str,
|
||||||
|
origin: str,
|
||||||
|
client: Client | None = None,
|
||||||
|
secret: ClientSecret | None = None,
|
||||||
|
secret_name: str | None = None,
|
||||||
|
**data: str,
|
||||||
|
) -> None:
|
||||||
|
"""Write an audit entry."""
|
||||||
|
model = self._create_model(
|
||||||
|
operation, message, origin, client, secret, secret_name, data
|
||||||
|
)
|
||||||
|
await self.write_model_async(model)
|
||||||
|
|
||||||
response.raise_for_status()
|
async def get(
|
||||||
|
|
||||||
async def get_detailed_secrets(self) -> list[DetailedSecrets]:
|
|
||||||
"""Get detailed list of secrets."""
|
|
||||||
path = "/api/v1/secrets/detailed/"
|
|
||||||
response = await self._get(path)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
secret_list = TypeAdapter(list[DetailedSecrets])
|
|
||||||
return secret_list.validate_python(response.json())
|
|
||||||
|
|
||||||
async def get_secrets(self) -> list[Secret]:
|
|
||||||
"""Get Secrets.
|
|
||||||
|
|
||||||
This provides a list of secret names and which clients have them.
|
|
||||||
"""
|
|
||||||
path = "/api/v1/secrets/"
|
|
||||||
response = await self._get(path)
|
|
||||||
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
secret_list = TypeAdapter(list[Secret])
|
|
||||||
return secret_list.validate_python(response.json())
|
|
||||||
|
|
||||||
async def get_secret(self, name: str) -> Secret | None:
|
|
||||||
"""Get clients mapped to a single secret."""
|
|
||||||
path = f"/api/v1/secrets/{name}"
|
|
||||||
response = await self._get(path)
|
|
||||||
if response.status_code == 404:
|
|
||||||
return None
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
return Secret.model_validate(response.json())
|
|
||||||
|
|
||||||
async def get_detailed_secret(self, name: str) -> DetailedSecrets | None:
|
|
||||||
"""Get clients mapped to a single secret."""
|
|
||||||
path = f"/api/v1/secrets/{name}/detailed"
|
|
||||||
response = await self._get(path)
|
|
||||||
if response.status_code == 404:
|
|
||||||
return None
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
return DetailedSecrets.model_validate(response.json())
|
|
||||||
|
|
||||||
async def get_audit_log(
|
|
||||||
self,
|
self,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
@ -321,30 +267,169 @@ class SshecretBackend:
|
|||||||
if subsystem:
|
if subsystem:
|
||||||
params["filter_subsystem"] = subsystem
|
params["filter_subsystem"] = subsystem
|
||||||
|
|
||||||
response = await self.http_client.get(path, params=params)
|
response = await self._get(path, params=params)
|
||||||
response.raise_for_status()
|
|
||||||
audit_log_adapter = TypeAdapter(list[AuditLog])
|
audit_log_adapter = TypeAdapter(list[AuditLog])
|
||||||
return audit_log_adapter.validate_python(response.json())
|
return audit_log_adapter.validate_python(response.json())
|
||||||
|
|
||||||
async def add_audit_log(self, entry: AuditLog) -> None:
|
async def count(self) -> int:
|
||||||
"""Add audit log entry."""
|
|
||||||
path = f"/api/v1/audit/"
|
|
||||||
|
|
||||||
response = await self.http_client.post(path, json=entry.model_dump())
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
async def get_audit_log_count(self) -> int:
|
|
||||||
"""Get amount of messages in the audit log."""
|
"""Get amount of messages in the audit log."""
|
||||||
path = f"/api/v1/audit/info"
|
path = f"/api/v1/audit/info"
|
||||||
response = await self._get(path)
|
response = await self._get(path)
|
||||||
response.raise_for_status()
|
audit_info = AuditInfo.model_validate(response.json())
|
||||||
data = response.json()
|
return audit_info.entries
|
||||||
return int(data["entries"])
|
|
||||||
|
|
||||||
def add_audit_log_sync(self, entry: AuditLog) -> None:
|
|
||||||
"""Add audit log entry."""
|
|
||||||
path = f"/api/v1/audit/"
|
|
||||||
LOG.info("AUDIT LOG SYNC %r", entry)
|
|
||||||
|
|
||||||
response = self.sync_client.post(path, json=entry.model_dump())
|
class SshecretBackend(BaseBackend):
|
||||||
response.raise_for_status()
|
"""Backend interface."""
|
||||||
|
|
||||||
|
async def create_client(
|
||||||
|
self, name: str, public_key: str, description: str | None = None
|
||||||
|
) -> None:
|
||||||
|
"""Register a new client."""
|
||||||
|
if not validate_public_key(public_key):
|
||||||
|
raise BackendValidationError("Error: Invalid public key format.")
|
||||||
|
data = {
|
||||||
|
"name": name,
|
||||||
|
"public_key": public_key,
|
||||||
|
}
|
||||||
|
if description:
|
||||||
|
data["description"] = description
|
||||||
|
path = "/api/v1/clients/"
|
||||||
|
response = await self._post(path, json=data)
|
||||||
|
|
||||||
|
async def get_clients(self, filter: ClientFilter | None = None) -> list[Client]:
|
||||||
|
"""Get all clients."""
|
||||||
|
clients: list[Client] = []
|
||||||
|
async for client in ClientQueryIterator(self.http_client, filter):
|
||||||
|
clients.append(client)
|
||||||
|
|
||||||
|
return clients
|
||||||
|
|
||||||
|
async def get_client(self, name: str) -> Client | None:
|
||||||
|
"""Lookup a client on username."""
|
||||||
|
path = f"/api/v1/clients/{name}"
|
||||||
|
response = await self._get(path)
|
||||||
|
if response.status_code == 404:
|
||||||
|
return None
|
||||||
|
client = Client.model_validate(response.json())
|
||||||
|
return client
|
||||||
|
|
||||||
|
async def get_client_by_id(self, id: str) -> Client | None:
|
||||||
|
"""Lookup a client on username."""
|
||||||
|
path = f"/api/v1/clients/id/{id}"
|
||||||
|
response = await self._get(path)
|
||||||
|
if response.status_code == 404:
|
||||||
|
return None
|
||||||
|
client = Client.model_validate(response.json())
|
||||||
|
return client
|
||||||
|
|
||||||
|
async def delete_client(self, client_name: str) -> None:
|
||||||
|
"""Delete a client."""
|
||||||
|
path = f"/api/v1/clients/{client_name}"
|
||||||
|
response = await self._delete(path)
|
||||||
|
|
||||||
|
async def delete_client_by_id(self, id: str) -> None:
|
||||||
|
"""Delete a client."""
|
||||||
|
path = f"/api/v1/clients/id/{id}"
|
||||||
|
response = await self._delete(path)
|
||||||
|
|
||||||
|
async def create_client_secret(
|
||||||
|
self, client_name: str, secret_name: str, encrypted_secret: str
|
||||||
|
) -> None:
|
||||||
|
"""Create a secret.
|
||||||
|
|
||||||
|
This will overwrite any existing secret with that name.
|
||||||
|
"""
|
||||||
|
path = f"api/v1/clients/{client_name}/secrets/{secret_name}"
|
||||||
|
response = await self._put(path, json={"value": encrypted_secret})
|
||||||
|
|
||||||
|
async def get_client_secret(self, name: str, secret_name: str) -> str | None:
|
||||||
|
"""Fetch a secret."""
|
||||||
|
path = f"/api/v1/clients/{name}/secrets/{secret_name}"
|
||||||
|
response = await self._get(path)
|
||||||
|
if response.status_code == 404:
|
||||||
|
return None
|
||||||
|
secret = ClientSecret.model_validate(response.json())
|
||||||
|
return secret.secret
|
||||||
|
|
||||||
|
async def delete_client_secret(self, client_name: str, secret_name: str) -> None:
|
||||||
|
"""Delete a secret from a client."""
|
||||||
|
path = f"api/v1/clients/{client_name}/secrets/{secret_name}"
|
||||||
|
await self._delete(path)
|
||||||
|
|
||||||
|
async def update_client(self, client: Client) -> Client:
|
||||||
|
"""Update the client."""
|
||||||
|
path = f"/api/v1/clients/{client.name}"
|
||||||
|
client_update = {
|
||||||
|
"name": client.name,
|
||||||
|
"description": client.description,
|
||||||
|
"public_key": client.public_key,
|
||||||
|
}
|
||||||
|
response = await self._put(path, json=client_update)
|
||||||
|
LOG.info("Response %s", response.text)
|
||||||
|
|
||||||
|
if client.policies:
|
||||||
|
await self.update_client_sources(
|
||||||
|
str(client.id), [str(source) for source in client.policies]
|
||||||
|
)
|
||||||
|
return client
|
||||||
|
|
||||||
|
async def update_client_key(self, client_name: str, public_key: str) -> None:
|
||||||
|
"""Update the client key."""
|
||||||
|
path = f"/api/v1/clients/{client_name}/public-key"
|
||||||
|
await self._post(path, json={"public_key": public_key})
|
||||||
|
|
||||||
|
async def update_client_sources(
|
||||||
|
self, client_name: str, addresses: list[str] | None
|
||||||
|
) -> None:
|
||||||
|
"""Update client source addresses.
|
||||||
|
|
||||||
|
Pass None to sources to allow from all.
|
||||||
|
"""
|
||||||
|
if not addresses:
|
||||||
|
addresses = []
|
||||||
|
|
||||||
|
path = f"/api/v1/clients/{client_name}/policies/"
|
||||||
|
await self._put(path, json={"sources": addresses})
|
||||||
|
|
||||||
|
async def get_detailed_secrets(self) -> list[DetailedSecrets]:
|
||||||
|
"""Get detailed list of secrets."""
|
||||||
|
path = "/api/v1/secrets/detailed/"
|
||||||
|
response = await self._get(path)
|
||||||
|
|
||||||
|
secret_list = TypeAdapter(list[DetailedSecrets])
|
||||||
|
return secret_list.validate_python(response.json())
|
||||||
|
|
||||||
|
async def get_secrets(self) -> list[Secret]:
|
||||||
|
"""Get Secrets.
|
||||||
|
|
||||||
|
This provides a list of secret names and which clients have them.
|
||||||
|
"""
|
||||||
|
path = "/api/v1/secrets/"
|
||||||
|
response = await self._get(path)
|
||||||
|
|
||||||
|
secret_list = TypeAdapter(list[Secret])
|
||||||
|
return secret_list.validate_python(response.json())
|
||||||
|
|
||||||
|
async def get_secret(self, name: str) -> Secret | None:
|
||||||
|
"""Get clients mapped to a single secret."""
|
||||||
|
path = f"/api/v1/secrets/{name}"
|
||||||
|
response = await self._get(path)
|
||||||
|
if response.status_code == 404:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return Secret.model_validate(response.json())
|
||||||
|
|
||||||
|
async def get_detailed_secret(self, name: str) -> DetailedSecrets | None:
|
||||||
|
"""Get clients mapped to a single secret."""
|
||||||
|
path = f"/api/v1/secrets/{name}/detailed"
|
||||||
|
response = await self._get(path)
|
||||||
|
if response.status_code == 404:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return DetailedSecrets.model_validate(response.json())
|
||||||
|
|
||||||
|
def audit(self, subsystem: SubSystem) -> AuditAPI:
|
||||||
|
"""Create the audit API."""
|
||||||
|
audit = AuditAPI(self._backend_url, self._api_token, subsystem)
|
||||||
|
return audit
|
||||||
|
|||||||
@ -17,6 +17,27 @@ class FilterType(enum.StrEnum):
|
|||||||
CONTAINS = "contains"
|
CONTAINS = "contains"
|
||||||
|
|
||||||
|
|
||||||
|
class SubSystem(enum.StrEnum):
|
||||||
|
"""Available subsystems."""
|
||||||
|
|
||||||
|
ADMIN = enum.auto()
|
||||||
|
SSHD = enum.auto()
|
||||||
|
BACKEND = enum.auto()
|
||||||
|
|
||||||
|
|
||||||
|
class Operation(enum.StrEnum):
|
||||||
|
"""Various operations for the audit logging module."""
|
||||||
|
|
||||||
|
CREATE = enum.auto()
|
||||||
|
READ = enum.auto()
|
||||||
|
UPDATE = enum.auto()
|
||||||
|
DELETE = enum.auto()
|
||||||
|
DENY = enum.auto()
|
||||||
|
PERMIT = enum.auto()
|
||||||
|
LOGIN = enum.auto()
|
||||||
|
NONE = enum.auto()
|
||||||
|
|
||||||
|
|
||||||
class Client(BaseModel):
|
class Client(BaseModel):
|
||||||
"""Implementation of the backend class ClientView."""
|
"""Implementation of the backend class ClientView."""
|
||||||
|
|
||||||
@ -101,15 +122,22 @@ class ClientFilter(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class AuditLog(BaseModel):
|
class AuditLog(BaseModel):
|
||||||
"""Implementation of the backend class AuditLog."""
|
"""Implementation of the backend class AuditView."""
|
||||||
|
|
||||||
id: str | None = None
|
id: str | None = None
|
||||||
subsystem: str | None = None
|
subsystem: SubSystem
|
||||||
object: str | None = None
|
operation: Operation
|
||||||
object_id: str | None = None
|
|
||||||
operation: str
|
|
||||||
client_id: str | None = None
|
client_id: str | None = None
|
||||||
client_name: str | None = None
|
client_name: str | None = None
|
||||||
|
secret_id: str | None = None
|
||||||
|
secret_name: str | None = None
|
||||||
|
data: dict[str, str] | None = None
|
||||||
message: str
|
message: str
|
||||||
origin: str | None = None
|
origin: str | None = None
|
||||||
timestamp: datetime | None = None
|
timestamp: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class AuditInfo(BaseModel):
|
||||||
|
"""Implementation of the backend class AuditInfo."""
|
||||||
|
|
||||||
|
entries: int
|
||||||
|
|||||||
Reference in New Issue
Block a user