Check in backend in working state
This commit is contained in:
119
packages/sshecret-backend/alembic.ini
Normal file
119
packages/sshecret-backend/alembic.ini
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
# A generic, single database configuration.
|
||||||
|
|
||||||
|
[alembic]
|
||||||
|
# path to migration scripts
|
||||||
|
# Use forward slashes (/) also on windows to provide an os agnostic path
|
||||||
|
script_location = migrations
|
||||||
|
|
||||||
|
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||||
|
# Uncomment the line below if you want the files to be prepended with date and time
|
||||||
|
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||||
|
# for all available tokens
|
||||||
|
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||||
|
|
||||||
|
# sys.path path, will be prepended to sys.path if present.
|
||||||
|
# defaults to the current working directory.
|
||||||
|
prepend_sys_path = .
|
||||||
|
|
||||||
|
# timezone to use when rendering the date within the migration file
|
||||||
|
# as well as the filename.
|
||||||
|
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
|
||||||
|
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
|
||||||
|
# string value is passed to ZoneInfo()
|
||||||
|
# leave blank for localtime
|
||||||
|
# timezone =
|
||||||
|
|
||||||
|
# max length of characters to apply to the "slug" field
|
||||||
|
# truncate_slug_length = 40
|
||||||
|
|
||||||
|
# set to 'true' to run the environment during
|
||||||
|
# the 'revision' command, regardless of autogenerate
|
||||||
|
# revision_environment = false
|
||||||
|
|
||||||
|
# set to 'true' to allow .pyc and .pyo files without
|
||||||
|
# a source .py file to be detected as revisions in the
|
||||||
|
# versions/ directory
|
||||||
|
# sourceless = false
|
||||||
|
|
||||||
|
# version location specification; This defaults
|
||||||
|
# to migrations/versions. When using multiple version
|
||||||
|
# directories, initial revisions must be specified with --version-path.
|
||||||
|
# The path separator used here should be the separator specified by "version_path_separator" below.
|
||||||
|
# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions
|
||||||
|
|
||||||
|
# version path separator; As mentioned above, this is the character used to split
|
||||||
|
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
|
||||||
|
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
|
||||||
|
# Valid values for version_path_separator are:
|
||||||
|
#
|
||||||
|
# version_path_separator = :
|
||||||
|
# version_path_separator = ;
|
||||||
|
# version_path_separator = space
|
||||||
|
# version_path_separator = newline
|
||||||
|
#
|
||||||
|
# Use os.pathsep. Default configuration used for new projects.
|
||||||
|
version_path_separator = os
|
||||||
|
|
||||||
|
# set to 'true' to search source files recursively
|
||||||
|
# in each "version_locations" directory
|
||||||
|
# new in Alembic version 1.10
|
||||||
|
# recursive_version_locations = false
|
||||||
|
|
||||||
|
# the output encoding used when revision files
|
||||||
|
# are written from script.py.mako
|
||||||
|
# output_encoding = utf-8
|
||||||
|
|
||||||
|
sqlalchemy.url = sqlite:///sshecret.db
|
||||||
|
|
||||||
|
|
||||||
|
[post_write_hooks]
|
||||||
|
# post_write_hooks defines scripts or Python functions that are run
|
||||||
|
# on newly generated revision scripts. See the documentation for further
|
||||||
|
# detail and examples
|
||||||
|
|
||||||
|
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||||
|
# hooks = black
|
||||||
|
# black.type = console_scripts
|
||||||
|
# black.entrypoint = black
|
||||||
|
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||||
|
|
||||||
|
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
|
||||||
|
# hooks = ruff
|
||||||
|
# ruff.type = exec
|
||||||
|
# ruff.executable = %(here)s/.venv/bin/ruff
|
||||||
|
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||||
|
|
||||||
|
# Logging configuration
|
||||||
|
[loggers]
|
||||||
|
keys = root,sqlalchemy,alembic
|
||||||
|
|
||||||
|
[handlers]
|
||||||
|
keys = console
|
||||||
|
|
||||||
|
[formatters]
|
||||||
|
keys = generic
|
||||||
|
|
||||||
|
[logger_root]
|
||||||
|
level = WARNING
|
||||||
|
handlers = console
|
||||||
|
qualname =
|
||||||
|
|
||||||
|
[logger_sqlalchemy]
|
||||||
|
level = WARNING
|
||||||
|
handlers =
|
||||||
|
qualname = sqlalchemy.engine
|
||||||
|
|
||||||
|
[logger_alembic]
|
||||||
|
level = INFO
|
||||||
|
handlers =
|
||||||
|
qualname = alembic
|
||||||
|
|
||||||
|
[handler_console]
|
||||||
|
class = StreamHandler
|
||||||
|
args = (sys.stderr,)
|
||||||
|
level = NOTSET
|
||||||
|
formatter = generic
|
||||||
|
|
||||||
|
[formatter_generic]
|
||||||
|
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||||
|
datefmt = %H:%M:%S
|
||||||
1
packages/sshecret-backend/migrations/README
Normal file
1
packages/sshecret-backend/migrations/README
Normal file
@ -0,0 +1 @@
|
|||||||
|
Generic single-database configuration.
|
||||||
80
packages/sshecret-backend/migrations/env.py
Normal file
80
packages/sshecret-backend/migrations/env.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
from logging.config import fileConfig
|
||||||
|
|
||||||
|
from sqlalchemy import engine_from_config
|
||||||
|
from sqlalchemy import pool
|
||||||
|
|
||||||
|
from alembic import context
|
||||||
|
from sshecret_backend.models import *
|
||||||
|
|
||||||
|
# this is the Alembic Config object, which provides
|
||||||
|
# access to the values within the .ini file in use.
|
||||||
|
config = context.config
|
||||||
|
|
||||||
|
# Interpret the config file for Python logging.
|
||||||
|
# This line sets up loggers basically.
|
||||||
|
if config.config_file_name is not None:
|
||||||
|
fileConfig(config.config_file_name)
|
||||||
|
|
||||||
|
# add your model's MetaData object here
|
||||||
|
# for 'autogenerate' support
|
||||||
|
# from myapp import mymodel
|
||||||
|
# target_metadata = mymodel.Base.metadata
|
||||||
|
#target_metadata = None
|
||||||
|
target_metadata = SQLModel.metadata
|
||||||
|
|
||||||
|
# other values from the config, defined by the needs of env.py,
|
||||||
|
# can be acquired:
|
||||||
|
# my_important_option = config.get_main_option("my_important_option")
|
||||||
|
# ... etc.
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_offline() -> None:
|
||||||
|
"""Run migrations in 'offline' mode.
|
||||||
|
|
||||||
|
This configures the context with just a URL
|
||||||
|
and not an Engine, though an Engine is acceptable
|
||||||
|
here as well. By skipping the Engine creation
|
||||||
|
we don't even need a DBAPI to be available.
|
||||||
|
|
||||||
|
Calls to context.execute() here emit the given string to the
|
||||||
|
script output.
|
||||||
|
|
||||||
|
"""
|
||||||
|
url = config.get_main_option("sqlalchemy.url")
|
||||||
|
context.configure(
|
||||||
|
url=url,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
literal_binds=True,
|
||||||
|
dialect_opts={"paramstyle": "named"},
|
||||||
|
)
|
||||||
|
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_online() -> None:
|
||||||
|
"""Run migrations in 'online' mode.
|
||||||
|
|
||||||
|
In this scenario we need to create an Engine
|
||||||
|
and associate a connection with the context.
|
||||||
|
|
||||||
|
"""
|
||||||
|
connectable = engine_from_config(
|
||||||
|
config.get_section(config.config_ini_section, {}),
|
||||||
|
prefix="sqlalchemy.",
|
||||||
|
poolclass=pool.NullPool,
|
||||||
|
)
|
||||||
|
|
||||||
|
with connectable.connect() as connection:
|
||||||
|
context.configure(
|
||||||
|
connection=connection, target_metadata=target_metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
if context.is_offline_mode():
|
||||||
|
run_migrations_offline()
|
||||||
|
else:
|
||||||
|
run_migrations_online()
|
||||||
29
packages/sshecret-backend/migrations/script.py.mako
Normal file
29
packages/sshecret-backend/migrations/script.py.mako
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
"""${message}
|
||||||
|
|
||||||
|
Revision ID: ${up_revision}
|
||||||
|
Revises: ${down_revision | comma,n}
|
||||||
|
Create Date: ${create_date}
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
import sqlmodel
|
||||||
|
${imports if imports else ""}
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = ${repr(up_revision)}
|
||||||
|
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||||
|
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
${upgrades if upgrades else "pass"}
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
${downgrades if downgrades else "pass"}
|
||||||
@ -0,0 +1,33 @@
|
|||||||
|
"""Initial model
|
||||||
|
|
||||||
|
Revision ID: a0befb5a74a0
|
||||||
|
Revises:
|
||||||
|
Create Date: 2025-04-28 21:18:59.069323
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
import sqlmodel
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = 'a0befb5a74a0'
|
||||||
|
down_revision: Union[str, None] = None
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
pass
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
pass
|
||||||
|
# ### end Alembic commands ###
|
||||||
@ -0,0 +1,33 @@
|
|||||||
|
"""Add subsystem to auditlog
|
||||||
|
|
||||||
|
Revision ID: f30e413c5757
|
||||||
|
Revises: a0befb5a74a0
|
||||||
|
Create Date: 2025-04-28 21:21:20.103423
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
import sqlmodel
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = 'f30e413c5757'
|
||||||
|
down_revision: Union[str, None] = 'a0befb5a74a0'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.add_column('auditlog', sa.Column('subsystem', sqlmodel.sql.sqltypes.AutoString(), nullable=True))
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_column('auditlog', 'subsystem')
|
||||||
|
# ### end Alembic commands ###
|
||||||
@ -8,9 +8,11 @@ authors = [
|
|||||||
]
|
]
|
||||||
requires-python = ">=3.13"
|
requires-python = ">=3.13"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"alembic>=1.15.2",
|
||||||
"passlib[bcrypt]>=1.7.4",
|
"passlib[bcrypt]>=1.7.4",
|
||||||
"pydantic>=2.10.6",
|
"pydantic>=2.10.6",
|
||||||
"pytest>=8.3.5",
|
"pytest>=8.3.5",
|
||||||
|
"python-multipart>=0.0.20",
|
||||||
"sqlmodel>=0.0.24",
|
"sqlmodel>=0.0.24",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,2 @@
|
|||||||
"""Sshecret backend."""
|
"""Sshecret backend."""
|
||||||
from .app import app as app
|
# from .router import app as app
|
||||||
#from .router import app as app
|
|
||||||
|
|
||||||
__all__ = ["app"]
|
|
||||||
|
|||||||
@ -0,0 +1,8 @@
|
|||||||
|
"""API factory modules."""
|
||||||
|
|
||||||
|
from .audit import get_audit_api
|
||||||
|
from .clients import get_clients_api
|
||||||
|
from .policies import get_policy_api
|
||||||
|
from .secrets import get_secrets_api
|
||||||
|
|
||||||
|
__all__ = ["get_audit_api", "get_clients_api", "get_policy_api", "get_secrets_api"]
|
||||||
65
packages/sshecret-backend/src/sshecret_backend/api/audit.py
Normal file
65
packages/sshecret-backend/src/sshecret_backend/api/audit.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
"""Audit sub-api factory."""
|
||||||
|
|
||||||
|
# pyright: reportUnusedFunction=false
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from fastapi import APIRouter, Depends, Request, Query
|
||||||
|
from sqlmodel import Session, col, func, select
|
||||||
|
from sqlalchemy import desc
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from sshecret_backend.models import AuditLog
|
||||||
|
from sshecret_backend.types import DBSessionDep
|
||||||
|
from sshecret_backend import audit
|
||||||
|
from sshecret_backend.view_models import AuditInfo
|
||||||
|
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||||
|
"""Construct audit sub-api."""
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
@router.get("/audit/", response_model=list[AuditLog])
|
||||||
|
async def get_audit_logs(
|
||||||
|
request: Request,
|
||||||
|
session: Annotated[Session, Depends(get_db_session)],
|
||||||
|
offset: Annotated[int, Query()] = 0,
|
||||||
|
limit: Annotated[int, Query(le=100)] = 100,
|
||||||
|
filter_client: Annotated[str | None, Query()] = None,
|
||||||
|
filter_subsystem: Annotated[str | None, Query()] = None,
|
||||||
|
) -> Sequence[AuditLog]:
|
||||||
|
"""Get audit logs."""
|
||||||
|
#audit.audit_access_audit_log(session, request)
|
||||||
|
statement = select(AuditLog).offset(offset).limit(limit).order_by(desc(col(AuditLog.timestamp)))
|
||||||
|
if filter_client:
|
||||||
|
statement = statement.where(AuditLog.client_name == filter_client)
|
||||||
|
|
||||||
|
if filter_subsystem:
|
||||||
|
statement = statement.where(AuditLog.subsystem == filter_subsystem)
|
||||||
|
|
||||||
|
results = session.exec(statement).all()
|
||||||
|
return results
|
||||||
|
|
||||||
|
@router.post("/audit/")
|
||||||
|
async def add_audit_log(
|
||||||
|
request: Request,
|
||||||
|
session: Annotated[Session, Depends(get_db_session)],
|
||||||
|
entry: AuditLog,
|
||||||
|
) -> AuditLog:
|
||||||
|
"""Add entry to audit log."""
|
||||||
|
audit_log = AuditLog.model_validate(entry.model_dump(exclude_none=True))
|
||||||
|
session.add(audit_log)
|
||||||
|
session.commit()
|
||||||
|
return audit_log
|
||||||
|
|
||||||
|
@router.get("/audit/info")
|
||||||
|
async def get_audit_info(request: Request, session: Annotated[Session, Depends(get_db_session)]) -> AuditInfo:
|
||||||
|
"""Get audit info."""
|
||||||
|
audit_count = session.exec(select(func.count('*')).select_from(AuditLog)).one()
|
||||||
|
return AuditInfo(entries=audit_count)
|
||||||
|
|
||||||
|
|
||||||
|
return router
|
||||||
226
packages/sshecret-backend/src/sshecret_backend/api/clients.py
Normal file
226
packages/sshecret-backend/src/sshecret_backend/api/clients.py
Normal file
@ -0,0 +1,226 @@
|
|||||||
|
"""Client sub-api factory."""
|
||||||
|
|
||||||
|
# pyright: reportUnusedFunction=false
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
import logging
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||||
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
from sqlmodel import Session, col, select
|
||||||
|
from sqlalchemy import func
|
||||||
|
from typing import Annotated, Self, TypeVar
|
||||||
|
|
||||||
|
from sqlmodel.sql.expression import SelectOfScalar
|
||||||
|
from sshecret_backend.types import DBSessionDep
|
||||||
|
from sshecret_backend.models import Client, ClientSecret
|
||||||
|
from sshecret_backend.view_models import (
|
||||||
|
ClientCreate,
|
||||||
|
ClientQueryResult,
|
||||||
|
ClientView,
|
||||||
|
ClientUpdate,
|
||||||
|
)
|
||||||
|
from sshecret_backend import audit
|
||||||
|
from .common import get_client_by_id_or_name
|
||||||
|
|
||||||
|
|
||||||
|
class ClientListParams(BaseModel):
|
||||||
|
"""Client list parameters."""
|
||||||
|
|
||||||
|
limit: int = Field(100, gt=0, le=100)
|
||||||
|
offset: int = Field(0, ge=0)
|
||||||
|
id: uuid.UUID | None = None
|
||||||
|
name: str | None = None
|
||||||
|
name__like: str | None = None
|
||||||
|
name__contains: str | None = None
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_expressions(self) -> Self:
|
||||||
|
"""Validate mutually exclusive expression."""
|
||||||
|
name_filter = False
|
||||||
|
if self.name__like or self.name__contains:
|
||||||
|
name_filter = True
|
||||||
|
if self.name__like and self.name__contains:
|
||||||
|
raise ValueError("You may only specify one name expression")
|
||||||
|
if self.name and name_filter:
|
||||||
|
raise ValueError(
|
||||||
|
"You must either specify name or one of name__like or name__contains"
|
||||||
|
)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def filter_client_statement(
|
||||||
|
statement: SelectOfScalar[T], params: ClientListParams, ignore_limits: bool = False
|
||||||
|
) -> SelectOfScalar[T]:
|
||||||
|
"""Filter a statement with the provided params."""
|
||||||
|
if params.id:
|
||||||
|
statement = statement.where(Client.id == params.id)
|
||||||
|
|
||||||
|
if params.name:
|
||||||
|
statement = statement.where(Client.name == params.name)
|
||||||
|
elif params.name__like:
|
||||||
|
statement = statement.where(col(Client.name).like(params.name__like))
|
||||||
|
elif params.name__contains:
|
||||||
|
statement = statement.where(col(Client.name).contains(params.name__contains))
|
||||||
|
|
||||||
|
if ignore_limits:
|
||||||
|
return statement
|
||||||
|
|
||||||
|
return statement.limit(params.limit).offset(params.offset)
|
||||||
|
|
||||||
|
|
||||||
|
def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||||
|
"""Construct clients sub-api."""
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
@router.get("/clients/")
|
||||||
|
async def get_clients(
|
||||||
|
filter_query: Annotated[ClientListParams, Query()],
|
||||||
|
session: Annotated[Session, Depends(get_db_session)],
|
||||||
|
) -> ClientQueryResult:
|
||||||
|
"""Get clients."""
|
||||||
|
# Get total results first
|
||||||
|
count_statement = select(func.count("*")).select_from(Client)
|
||||||
|
count_statement = filter_client_statement(count_statement, filter_query, True)
|
||||||
|
|
||||||
|
total_results = session.exec(count_statement).one()
|
||||||
|
|
||||||
|
statement = filter_client_statement(select(Client), filter_query, False)
|
||||||
|
|
||||||
|
results = session.exec(statement)
|
||||||
|
remainder = total_results - filter_query.offset - filter_query.limit
|
||||||
|
if remainder < 0:
|
||||||
|
remainder = 0
|
||||||
|
|
||||||
|
clients = list(results)
|
||||||
|
clients_view = ClientView.from_client_list(clients)
|
||||||
|
return ClientQueryResult(
|
||||||
|
clients=clients_view,
|
||||||
|
total_results=total_results,
|
||||||
|
remaining_results=remainder,
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get("/clients/{name}")
|
||||||
|
async def get_client(
|
||||||
|
name: str,
|
||||||
|
session: Annotated[Session, Depends(get_db_session)],
|
||||||
|
) -> ClientView:
|
||||||
|
"""Fetch a client."""
|
||||||
|
client = await get_client_by_id_or_name(session, name)
|
||||||
|
if not client:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404, detail="Cannot find a client with the given name."
|
||||||
|
)
|
||||||
|
return ClientView.from_client(client)
|
||||||
|
|
||||||
|
@router.delete("/clients/{name}")
|
||||||
|
async def delete_client(
|
||||||
|
request: Request,
|
||||||
|
name: str,
|
||||||
|
session: Annotated[Session, Depends(get_db_session)],
|
||||||
|
) -> None:
|
||||||
|
"""Delete a client."""
|
||||||
|
client = await get_client_by_id_or_name(session, name)
|
||||||
|
if not client:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404, detail="Cannot find a client with the given name."
|
||||||
|
)
|
||||||
|
|
||||||
|
session.delete(client)
|
||||||
|
session.commit()
|
||||||
|
audit.audit_delete_client(session, request, client)
|
||||||
|
|
||||||
|
@router.post("/clients/")
|
||||||
|
async def create_client(
|
||||||
|
request: Request,
|
||||||
|
client: ClientCreate,
|
||||||
|
session: Annotated[Session, Depends(get_db_session)],
|
||||||
|
) -> ClientView:
|
||||||
|
"""Create client."""
|
||||||
|
existing = await get_client_by_id_or_name(session, client.name)
|
||||||
|
if existing:
|
||||||
|
raise HTTPException(400, detail="Error: Already a client with that name.")
|
||||||
|
|
||||||
|
db_client = client.to_client()
|
||||||
|
session.add(db_client)
|
||||||
|
session.commit()
|
||||||
|
session.refresh(db_client)
|
||||||
|
audit.audit_create_client(session, request, db_client)
|
||||||
|
return ClientView.from_client(db_client)
|
||||||
|
|
||||||
|
@router.post("/clients/{name}/public-key")
|
||||||
|
async def update_client_public_key(
|
||||||
|
request: Request,
|
||||||
|
name: str,
|
||||||
|
client_update: ClientUpdate,
|
||||||
|
session: Annotated[Session, Depends(get_db_session)],
|
||||||
|
) -> ClientView:
|
||||||
|
"""Change the public key of a client.
|
||||||
|
|
||||||
|
This invalidates all secrets.
|
||||||
|
"""
|
||||||
|
client = await get_client_by_id_or_name(session, name)
|
||||||
|
if not client:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404, detail="Cannot find a client with the given name."
|
||||||
|
)
|
||||||
|
client.public_key = client_update.public_key
|
||||||
|
for secret in session.exec(
|
||||||
|
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||||
|
).all():
|
||||||
|
LOG.debug("Invalidated secret %s", secret.id)
|
||||||
|
secret.invalidated = True
|
||||||
|
secret.client_id = None
|
||||||
|
secret.client = None
|
||||||
|
|
||||||
|
session.add(client)
|
||||||
|
session.refresh(client)
|
||||||
|
session.commit()
|
||||||
|
audit.audit_invalidate_secrets(session, request, client)
|
||||||
|
|
||||||
|
return ClientView.from_client(client)
|
||||||
|
|
||||||
|
@router.put("/clients/{name}")
|
||||||
|
async def update_client(
|
||||||
|
request: Request,
|
||||||
|
name: str,
|
||||||
|
client_update: ClientCreate,
|
||||||
|
session: Annotated[Session, Depends(get_db_session)],
|
||||||
|
) -> ClientView:
|
||||||
|
"""Change the public key of a client.
|
||||||
|
|
||||||
|
This invalidates all secrets.
|
||||||
|
"""
|
||||||
|
client = await get_client_by_id_or_name(session, name)
|
||||||
|
if not client:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404, detail="Cannot find a client with the given name."
|
||||||
|
)
|
||||||
|
client.name = client_update.name
|
||||||
|
client.description = client_update.description
|
||||||
|
public_key_updated = False
|
||||||
|
if client_update.public_key != client.public_key:
|
||||||
|
public_key_updated = True
|
||||||
|
for secret in session.exec(
|
||||||
|
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||||
|
).all():
|
||||||
|
LOG.debug("Invalidated secret %s", secret.id)
|
||||||
|
secret.invalidated = True
|
||||||
|
secret.client_id = None
|
||||||
|
secret.client = None
|
||||||
|
|
||||||
|
session.add(client)
|
||||||
|
session.commit()
|
||||||
|
session.refresh(client)
|
||||||
|
audit.audit_update_client(session, request, client)
|
||||||
|
if public_key_updated:
|
||||||
|
audit.audit_invalidate_secrets(session, request, client)
|
||||||
|
|
||||||
|
return ClientView.from_client(client)
|
||||||
|
|
||||||
|
return router
|
||||||
38
packages/sshecret-backend/src/sshecret_backend/api/common.py
Normal file
38
packages/sshecret-backend/src/sshecret_backend/api/common.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
"""Common helpers."""
|
||||||
|
|
||||||
|
import re
|
||||||
|
import uuid
|
||||||
|
import bcrypt
|
||||||
|
|
||||||
|
from sqlmodel import Session, select
|
||||||
|
|
||||||
|
from sshecret_backend.models import Client
|
||||||
|
|
||||||
|
RE_UUID = re.compile("^[0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}$")
|
||||||
|
|
||||||
|
def verify_token(token: str, stored_hash: str) -> bool:
|
||||||
|
"""Verify token."""
|
||||||
|
token_bytes = token.encode("utf-8")
|
||||||
|
stored_bytes = stored_hash.encode("utf-8")
|
||||||
|
return bcrypt.checkpw(token_bytes, stored_bytes)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_client_by_name(session: Session, name: str) -> Client | None:
|
||||||
|
"""Get client by name."""
|
||||||
|
client_filter = select(Client).where(Client.name == name)
|
||||||
|
client_results = session.exec(client_filter)
|
||||||
|
return client_results.first()
|
||||||
|
|
||||||
|
async def get_client_by_id(session: Session, id: uuid.UUID) -> Client | None:
|
||||||
|
"""Get client by name."""
|
||||||
|
client_filter = select(Client).where(Client.id == id)
|
||||||
|
client_results = session.exec(client_filter)
|
||||||
|
return client_results.first()
|
||||||
|
|
||||||
|
async def get_client_by_id_or_name(session: Session, id_or_name: str) -> Client | None:
|
||||||
|
"""Get client either by id or name."""
|
||||||
|
if RE_UUID.match(id_or_name):
|
||||||
|
id = uuid.UUID(id_or_name)
|
||||||
|
return await get_client_by_id(session, id)
|
||||||
|
|
||||||
|
return await get_client_by_name(session, id_or_name)
|
||||||
@ -0,0 +1,82 @@
|
|||||||
|
"""Policies sub-api router factory."""
|
||||||
|
|
||||||
|
# pyright: reportUnusedFunction=false
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
|
from sqlmodel import Session, select
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from sshecret_backend.models import Client, ClientAccessPolicy
|
||||||
|
from sshecret_backend.view_models import (
|
||||||
|
ClientPolicyView,
|
||||||
|
ClientPolicyUpdate,
|
||||||
|
)
|
||||||
|
from sshecret_backend.types import DBSessionDep
|
||||||
|
from sshecret_backend import audit
|
||||||
|
from .common import get_client_by_id_or_name
|
||||||
|
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||||
|
"""Construct clients sub-api."""
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
@router.get("/clients/{name}/policies/")
|
||||||
|
async def get_client_policies(
|
||||||
|
name: str, session: Annotated[Session, Depends(get_db_session)]
|
||||||
|
) -> ClientPolicyView:
|
||||||
|
"""Get client policies."""
|
||||||
|
client = await get_client_by_id_or_name(session, name)
|
||||||
|
if not client:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404, detail="Cannot find a client with the given name."
|
||||||
|
)
|
||||||
|
|
||||||
|
return ClientPolicyView.from_client(client)
|
||||||
|
|
||||||
|
@router.put("/clients/{name}/policies/")
|
||||||
|
async def update_client_policies(
|
||||||
|
request: Request,
|
||||||
|
name: str,
|
||||||
|
policy_update: ClientPolicyUpdate,
|
||||||
|
session: Annotated[Session, Depends(get_db_session)],
|
||||||
|
) -> ClientPolicyView:
|
||||||
|
"""Update client policies.
|
||||||
|
|
||||||
|
This is also how you delete policies.
|
||||||
|
"""
|
||||||
|
client = await get_client_by_id_or_name(session, name)
|
||||||
|
if not client:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404, detail="Cannot find a client with the given name."
|
||||||
|
)
|
||||||
|
# Remove old policies.
|
||||||
|
policies = session.exec(
|
||||||
|
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
|
||||||
|
).all()
|
||||||
|
deleted_policies: list[ClientAccessPolicy] = []
|
||||||
|
added_policies: list[ClientAccessPolicy] = []
|
||||||
|
for policy in policies:
|
||||||
|
session.delete(policy)
|
||||||
|
deleted_policies.append(policy)
|
||||||
|
|
||||||
|
for source in policy_update.sources:
|
||||||
|
LOG.debug("Source %r", source)
|
||||||
|
policy = ClientAccessPolicy(source=str(source), client_id=client.id)
|
||||||
|
session.add(policy)
|
||||||
|
added_policies.append(policy)
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
session.refresh(client)
|
||||||
|
for policy in deleted_policies:
|
||||||
|
audit.audit_remove_policy(session, request, client, policy)
|
||||||
|
|
||||||
|
for policy in added_policies:
|
||||||
|
audit.audit_update_policy(session, request, client, policy)
|
||||||
|
|
||||||
|
return ClientPolicyView.from_client(client)
|
||||||
|
|
||||||
|
return router
|
||||||
232
packages/sshecret-backend/src/sshecret_backend/api/secrets.py
Normal file
232
packages/sshecret-backend/src/sshecret_backend/api/secrets.py
Normal file
@ -0,0 +1,232 @@
|
|||||||
|
"""Secrets sub-api factory."""
|
||||||
|
|
||||||
|
# pyright: reportUnusedFunction=false
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
|
from sqlmodel import Session, select
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from sshecret_backend.models import Client, ClientSecret
|
||||||
|
from sshecret_backend.view_models import (
|
||||||
|
ClientReference,
|
||||||
|
ClientSecretDetailList,
|
||||||
|
ClientSecretList,
|
||||||
|
ClientSecretPublic,
|
||||||
|
BodyValue,
|
||||||
|
ClientSecretResponse,
|
||||||
|
)
|
||||||
|
from sshecret_backend import audit
|
||||||
|
from sshecret_backend.types import DBSessionDep
|
||||||
|
from .common import get_client_by_id_or_name
|
||||||
|
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def lookup_client_secret(
|
||||||
|
session: Session, client: Client, name: str
|
||||||
|
) -> ClientSecret | None:
|
||||||
|
"""Look up a secret for a client."""
|
||||||
|
statement = (
|
||||||
|
select(ClientSecret)
|
||||||
|
.where(ClientSecret.client_id == client.id)
|
||||||
|
.where(ClientSecret.name == name)
|
||||||
|
)
|
||||||
|
results = session.exec(statement)
|
||||||
|
return results.first()
|
||||||
|
|
||||||
|
|
||||||
|
def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||||
|
"""Construct clients sub-api."""
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
@router.post("/clients/{name}/secrets/")
|
||||||
|
async def add_secret_to_client(
|
||||||
|
request: Request,
|
||||||
|
name: str,
|
||||||
|
client_secret: ClientSecretPublic,
|
||||||
|
session: Annotated[Session, Depends(get_db_session)],
|
||||||
|
) -> None:
|
||||||
|
"""Add secret to a client."""
|
||||||
|
client = await get_client_by_id_or_name(session, name)
|
||||||
|
if not client:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404, detail="Cannot find a client with the given name."
|
||||||
|
)
|
||||||
|
|
||||||
|
existing_secret = await lookup_client_secret(
|
||||||
|
session, client, client_secret.name
|
||||||
|
)
|
||||||
|
if existing_secret:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Cannot add a secret. A different secret with the same name already exists.",
|
||||||
|
)
|
||||||
|
db_secret = ClientSecret(
|
||||||
|
name=client_secret.name, client_id=client.id, secret=client_secret.secret
|
||||||
|
)
|
||||||
|
session.add(db_secret)
|
||||||
|
session.commit()
|
||||||
|
session.refresh(db_secret)
|
||||||
|
audit.audit_create_secret(session, request, client, db_secret)
|
||||||
|
|
||||||
|
@router.put("/clients/{name}/secrets/{secret_name}")
|
||||||
|
async def update_client_secret(
|
||||||
|
request: Request,
|
||||||
|
name: str,
|
||||||
|
secret_name: str,
|
||||||
|
secret_data: BodyValue,
|
||||||
|
session: Annotated[Session, Depends(get_db_session)],
|
||||||
|
) -> ClientSecretResponse:
|
||||||
|
"""Update a client secret.
|
||||||
|
|
||||||
|
This can also be used for destructive creates.
|
||||||
|
"""
|
||||||
|
client = await get_client_by_id_or_name(session, name)
|
||||||
|
if not client:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404, detail="Cannot find a client with the given name."
|
||||||
|
)
|
||||||
|
|
||||||
|
existing_secret = await lookup_client_secret(session, client, secret_name)
|
||||||
|
if existing_secret:
|
||||||
|
existing_secret.secret = secret_data.value
|
||||||
|
|
||||||
|
session.add(existing_secret)
|
||||||
|
session.commit()
|
||||||
|
session.refresh(existing_secret)
|
||||||
|
audit.audit_update_secret(session, request, client, existing_secret)
|
||||||
|
return ClientSecretResponse.from_client_secret(existing_secret)
|
||||||
|
|
||||||
|
db_secret = ClientSecret(
|
||||||
|
name=secret_name,
|
||||||
|
client_id=client.id,
|
||||||
|
secret=secret_data.value,
|
||||||
|
)
|
||||||
|
session.add(db_secret)
|
||||||
|
session.commit()
|
||||||
|
session.refresh(db_secret)
|
||||||
|
audit.audit_create_secret(session, request, client, db_secret)
|
||||||
|
return ClientSecretResponse.from_client_secret(db_secret)
|
||||||
|
|
||||||
|
@router.get("/clients/{name}/secrets/{secret_name}")
|
||||||
|
async def request_client_secret(
|
||||||
|
request: Request,
|
||||||
|
name: str,
|
||||||
|
secret_name: str,
|
||||||
|
session: Annotated[Session, Depends(get_db_session)],
|
||||||
|
) -> ClientSecretResponse:
|
||||||
|
"""Get a client secret."""
|
||||||
|
client = await get_client_by_id_or_name(session, name)
|
||||||
|
if not client:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404, detail="Cannot find a client with the given name."
|
||||||
|
)
|
||||||
|
|
||||||
|
secret = await lookup_client_secret(session, client, secret_name)
|
||||||
|
if not secret:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404, detail="Cannot find a secret with the given name."
|
||||||
|
)
|
||||||
|
|
||||||
|
response_model = ClientSecretResponse.from_client_secret(secret)
|
||||||
|
audit.audit_access_secret(session, request, client, secret)
|
||||||
|
return response_model
|
||||||
|
|
||||||
|
@router.delete("/clients/{name}/secrets/{secret_name}")
|
||||||
|
async def delete_client_secret(
|
||||||
|
request: Request,
|
||||||
|
name: str,
|
||||||
|
secret_name: str,
|
||||||
|
session: Annotated[Session, Depends(get_db_session)],
|
||||||
|
) -> None:
|
||||||
|
"""Delete a secret."""
|
||||||
|
client = await get_client_by_id_or_name(session, name)
|
||||||
|
if not client:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404, detail="Cannot find a client with the given name."
|
||||||
|
)
|
||||||
|
|
||||||
|
secret = await lookup_client_secret(session, client, secret_name)
|
||||||
|
if not secret:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404, detail="Cannot find a secret with the given name."
|
||||||
|
)
|
||||||
|
|
||||||
|
session.delete(secret)
|
||||||
|
session.commit()
|
||||||
|
audit.audit_delete_secret(session, request, client, secret)
|
||||||
|
|
||||||
|
@router.get("/secrets/")
|
||||||
|
async def get_secret_map(
|
||||||
|
request: Request, session: Annotated[Session, Depends(get_db_session)]
|
||||||
|
) -> list[ClientSecretList]:
|
||||||
|
"""Get a list of all secrets and which clients have them."""
|
||||||
|
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
|
||||||
|
for client_secret in session.exec(select(ClientSecret)).all():
|
||||||
|
if not client_secret.client:
|
||||||
|
if client_secret.name not in client_secret_map:
|
||||||
|
client_secret_map[client_secret.name] = []
|
||||||
|
continue
|
||||||
|
client_secret_map[client_secret.name].append(client_secret.client.name)
|
||||||
|
audit.audit_client_secret_list(session, request)
|
||||||
|
return [
|
||||||
|
ClientSecretList(name=secret_name, clients=clients)
|
||||||
|
for secret_name, clients in client_secret_map.items()
|
||||||
|
]
|
||||||
|
@router.get("/secrets/detailed/")
|
||||||
|
async def get_detailed_secret_map(
|
||||||
|
request: Request, session: Annotated[Session, Depends(get_db_session)]
|
||||||
|
) -> list[ClientSecretDetailList]:
|
||||||
|
"""Get a list of all secrets and which clients have them."""
|
||||||
|
client_secrets: dict[str, ClientSecretDetailList] = {}
|
||||||
|
for client_secret in session.exec(select(ClientSecret)).all():
|
||||||
|
|
||||||
|
if client_secret.name not in client_secrets:
|
||||||
|
client_secrets[client_secret.name] = ClientSecretDetailList(name=client_secret.name)
|
||||||
|
client_secrets[client_secret.name].ids.append(str(client_secret.id))
|
||||||
|
if not client_secret.client:
|
||||||
|
continue
|
||||||
|
client_secrets[client_secret.name].clients.append(ClientReference(id=str(client_secret.client.id), name=client_secret.client.name))
|
||||||
|
audit.audit_client_secret_list(session, request)
|
||||||
|
return list(client_secrets.values())
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/secrets/{name}")
|
||||||
|
async def get_secret_clients(
|
||||||
|
request: Request,
|
||||||
|
name: str,
|
||||||
|
session: Annotated[Session, Depends(get_db_session)],
|
||||||
|
) -> ClientSecretList:
|
||||||
|
"""Get a list of which clients has a named secret."""
|
||||||
|
clients: list[str] = []
|
||||||
|
for client_secret in session.exec(
|
||||||
|
select(ClientSecret).where(ClientSecret.name == name)
|
||||||
|
).all():
|
||||||
|
if not client_secret.client:
|
||||||
|
continue
|
||||||
|
clients.append(client_secret.client.name)
|
||||||
|
|
||||||
|
return ClientSecretList(name=name, clients=clients)
|
||||||
|
|
||||||
|
@router.get("/secrets/{name}/detailed")
|
||||||
|
async def get_secret_clients_detailed(
|
||||||
|
request: Request,
|
||||||
|
name: str,
|
||||||
|
session: Annotated[Session, Depends(get_db_session)],
|
||||||
|
) -> ClientSecretDetailList:
|
||||||
|
"""Get a list of which clients has a named secret."""
|
||||||
|
detail_list = ClientSecretDetailList(name=name)
|
||||||
|
for client_secret in session.exec(
|
||||||
|
select(ClientSecret).where(ClientSecret.name == name)
|
||||||
|
).all():
|
||||||
|
if not client_secret.client:
|
||||||
|
continue
|
||||||
|
detail_list.ids.append(str(client_secret.id))
|
||||||
|
detail_list.clients.append(ClientReference(id=str(client_secret.client.id), name=client_secret.client.name))
|
||||||
|
|
||||||
|
return detail_list
|
||||||
|
|
||||||
|
return router
|
||||||
@ -1,436 +1,65 @@
|
|||||||
"""FastAPI api.
|
"""FastAPI api."""
|
||||||
|
|
||||||
TODO: We may want to allow a consumer to generate audit log entries manually.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
# pyright: reportUnusedFunction=false
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Sequence
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Annotated
|
|
||||||
|
|
||||||
import bcrypt
|
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
APIRouter,
|
|
||||||
Depends,
|
|
||||||
FastAPI,
|
FastAPI,
|
||||||
Header,
|
|
||||||
HTTPException,
|
|
||||||
Query,
|
|
||||||
Request,
|
Request,
|
||||||
status,
|
status,
|
||||||
)
|
)
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
from sqlalchemy import Engine
|
||||||
|
|
||||||
|
|
||||||
from sqlmodel import Session, select
|
from .models import init_db
|
||||||
|
from .backend_api import get_backend_api
|
||||||
from . import audit
|
from .db import setup_database
|
||||||
from .db import get_engine
|
|
||||||
from .models import (
|
|
||||||
APIClient,
|
|
||||||
AuditLog,
|
|
||||||
Client,
|
|
||||||
ClientAccessPolicy,
|
|
||||||
ClientSecret,
|
|
||||||
init_db,
|
|
||||||
)
|
|
||||||
from .settings import get_settings
|
|
||||||
from .view_models import (
|
|
||||||
BodyValue,
|
|
||||||
ClientCreate,
|
|
||||||
ClientSecretPublic,
|
|
||||||
ClientSecretResponse,
|
|
||||||
ClientUpdate,
|
|
||||||
ClientView,
|
|
||||||
ClientPolicyView,
|
|
||||||
ClientPolicyUpdate,
|
|
||||||
)
|
|
||||||
|
|
||||||
settings = get_settings()
|
|
||||||
engine = get_engine(settings.db_file)
|
|
||||||
|
|
||||||
|
from .settings import BackendSettings
|
||||||
|
from .types import DBSessionDep
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
API_VERSION = "v1"
|
|
||||||
|
|
||||||
|
def init_backend_app(engine: Engine, get_db_session: DBSessionDep) -> FastAPI:
|
||||||
|
"""Initialize backend app."""
|
||||||
|
|
||||||
def verify_token(token: str, stored_hash: str) -> bool:
|
@asynccontextmanager
|
||||||
"""Verify token."""
|
async def lifespan(_app: FastAPI):
|
||||||
token_bytes = token.encode("utf-8")
|
"""Create database before starting the server."""
|
||||||
stored_bytes = stored_hash.encode("utf-8")
|
LOG.debug("Running lifespan")
|
||||||
return bcrypt.checkpw(token_bytes, stored_bytes)
|
init_db(engine)
|
||||||
|
yield
|
||||||
|
|
||||||
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
app.include_router(get_backend_api(get_db_session))
|
||||||
|
|
||||||
@asynccontextmanager
|
@app.exception_handler(RequestValidationError)
|
||||||
async def lifespan(_app: FastAPI):
|
async def validation_exception_handler(
|
||||||
"""Create database before starting the server."""
|
request: Request, exc: RequestValidationError
|
||||||
init_db(engine)
|
):
|
||||||
yield
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
|
content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
|
||||||
async def get_session():
|
|
||||||
"""Get the session."""
|
|
||||||
with Session(engine) as session:
|
|
||||||
yield session
|
|
||||||
|
|
||||||
|
|
||||||
async def validate_token(
|
|
||||||
x_api_token: Annotated[str, Header()],
|
|
||||||
session: Annotated[Session, Depends(get_session)],
|
|
||||||
) -> str:
|
|
||||||
"""Validate token."""
|
|
||||||
LOG.debug("Validating token %s", x_api_token)
|
|
||||||
statement = select(APIClient)
|
|
||||||
results = session.exec(statement)
|
|
||||||
valid = False
|
|
||||||
for result in results:
|
|
||||||
if verify_token(x_api_token, result.token):
|
|
||||||
valid = True
|
|
||||||
LOG.debug("Token is valid")
|
|
||||||
break
|
|
||||||
|
|
||||||
if not valid:
|
|
||||||
LOG.debug("Token is not valid.")
|
|
||||||
raise HTTPException(status_code=401, detail="unauthorized. invalid api token.")
|
|
||||||
return x_api_token
|
|
||||||
|
|
||||||
|
|
||||||
async def get_client_by_name(session: Session, name: str) -> Client | None:
|
|
||||||
"""Get client by name."""
|
|
||||||
client_filter = select(Client).where(Client.name == name)
|
|
||||||
client_results = session.exec(client_filter)
|
|
||||||
return client_results.first()
|
|
||||||
|
|
||||||
|
|
||||||
async def lookup_client_secret(
|
|
||||||
session: Session, client: Client, name: str
|
|
||||||
) -> ClientSecret | None:
|
|
||||||
"""Look up a secret for a client."""
|
|
||||||
statement = (
|
|
||||||
select(ClientSecret)
|
|
||||||
.where(ClientSecret.client_id == client.id)
|
|
||||||
.where(ClientSecret.name == name)
|
|
||||||
)
|
|
||||||
results = session.exec(statement)
|
|
||||||
return results.first()
|
|
||||||
|
|
||||||
|
|
||||||
LOG.info("Initializing app.")
|
|
||||||
backend_api = APIRouter(
|
|
||||||
prefix=f"/api/{API_VERSION}",
|
|
||||||
lifespan=lifespan,
|
|
||||||
dependencies=[Depends(validate_token)],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@backend_api.get("/clients/")
|
|
||||||
async def get_clients(
|
|
||||||
session: Annotated[Session, Depends(get_session)]
|
|
||||||
) -> list[ClientView]:
|
|
||||||
"""Get clients."""
|
|
||||||
statement = select(Client)
|
|
||||||
results = session.exec(statement)
|
|
||||||
clients = list(results)
|
|
||||||
return ClientView.from_client_list(clients)
|
|
||||||
|
|
||||||
|
|
||||||
@backend_api.get("/clients/{name}")
|
|
||||||
async def get_client(
|
|
||||||
request: Request, name: str, session: Annotated[Session, Depends(get_session)]
|
|
||||||
) -> ClientView:
|
|
||||||
"""Fetch a client."""
|
|
||||||
statement = select(Client).where(Client.name == name)
|
|
||||||
results = session.exec(statement)
|
|
||||||
client = results.first()
|
|
||||||
if not client:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404, detail="Cannot find a client with the given name."
|
|
||||||
)
|
|
||||||
audit.audit_access_secrets(session, request, client)
|
|
||||||
return ClientView.from_client(client)
|
|
||||||
|
|
||||||
|
|
||||||
@backend_api.delete("/clients/{name}")
|
|
||||||
async def delete_client(
|
|
||||||
request: Request, name: str, session: Annotated[Session, Depends(get_session)]
|
|
||||||
) -> None:
|
|
||||||
"""Delete a client."""
|
|
||||||
statement = select(Client).where(Client.name == name)
|
|
||||||
results = session.exec(statement)
|
|
||||||
client = results.first()
|
|
||||||
if not client:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404, detail="Cannot find a client with the given name."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
session.delete(client)
|
@app.get("/health")
|
||||||
session.commit()
|
async def get_health() -> JSONResponse:
|
||||||
audit.audit_delete_client(session, request, client)
|
"""Provide simple health check."""
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_200_OK, content=jsonable_encoder({"status": "LIVE"})
|
||||||
@backend_api.get("/clients/{name}/policies/")
|
|
||||||
async def get_client_policies(
|
|
||||||
name: str, session: Annotated[Session, Depends(get_session)]
|
|
||||||
) -> ClientPolicyView:
|
|
||||||
"""Get client policies."""
|
|
||||||
statement = select(Client).where(Client.name == name)
|
|
||||||
results = session.exec(statement)
|
|
||||||
client = results.first()
|
|
||||||
if not client:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404, detail="Cannot find a client with the given name."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return ClientPolicyView.from_client(client)
|
return app
|
||||||
|
|
||||||
|
|
||||||
@backend_api.put("/clients/{name}/policies/")
|
def create_backend_app(settings: BackendSettings) -> FastAPI:
|
||||||
async def update_client_policies(
|
"""Create the backend app."""
|
||||||
request: Request,
|
|
||||||
name: str,
|
|
||||||
policy_update: ClientPolicyUpdate,
|
|
||||||
session: Annotated[Session, Depends(get_session)],
|
|
||||||
) -> ClientPolicyView:
|
|
||||||
"""Update client policies.
|
|
||||||
|
|
||||||
This is also how you delete policies.
|
engine, get_db_session = setup_database(settings.db_url)
|
||||||
"""
|
|
||||||
statement = select(Client).where(Client.name == name)
|
|
||||||
results = session.exec(statement)
|
|
||||||
client = results.first()
|
|
||||||
if not client:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404, detail="Cannot find a client with the given name."
|
|
||||||
)
|
|
||||||
# Remove old policies.
|
|
||||||
policies = session.exec(
|
|
||||||
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
|
|
||||||
).all()
|
|
||||||
deleted_policies: list[ClientAccessPolicy] = []
|
|
||||||
added_policies: list[ClientAccessPolicy] = []
|
|
||||||
for policy in policies:
|
|
||||||
session.delete(policy)
|
|
||||||
deleted_policies.append(policy)
|
|
||||||
|
|
||||||
for source in policy_update.sources:
|
return init_backend_app(engine, get_db_session)
|
||||||
LOG.debug("Source %r", source)
|
|
||||||
policy = ClientAccessPolicy(source=str(source), client_id=client.id)
|
|
||||||
session.add(policy)
|
|
||||||
added_policies.append(policy)
|
|
||||||
|
|
||||||
session.commit()
|
|
||||||
session.refresh(client)
|
|
||||||
for policy in deleted_policies:
|
|
||||||
audit.audit_remove_policy(session, request, client, policy)
|
|
||||||
|
|
||||||
for policy in added_policies:
|
|
||||||
audit.audit_update_policy(session, request, client, policy)
|
|
||||||
|
|
||||||
return ClientPolicyView.from_client(client)
|
|
||||||
|
|
||||||
|
|
||||||
@backend_api.post("/clients/{name}/public-key")
|
|
||||||
async def update_client_public_key(
|
|
||||||
request: Request,
|
|
||||||
name: str,
|
|
||||||
client_update: ClientUpdate,
|
|
||||||
session: Annotated[Session, Depends(get_session)],
|
|
||||||
) -> ClientView:
|
|
||||||
"""Change the public key of a client.
|
|
||||||
|
|
||||||
This invalidates all secrets.
|
|
||||||
"""
|
|
||||||
statement = select(Client).where(Client.name == name)
|
|
||||||
results = session.exec(statement)
|
|
||||||
client = results.first()
|
|
||||||
if not client:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404, detail="Cannot find a client with the given name."
|
|
||||||
)
|
|
||||||
client.public_key = client_update.public_key
|
|
||||||
for secret in session.exec(
|
|
||||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
|
||||||
).all():
|
|
||||||
LOG.debug("Invalidated secret %s", secret.id)
|
|
||||||
secret.invalidated = True
|
|
||||||
secret.client_id = None
|
|
||||||
secret.client = None
|
|
||||||
|
|
||||||
session.add(client)
|
|
||||||
session.refresh(client)
|
|
||||||
session.commit()
|
|
||||||
audit.audit_invalidate_secrets(session, request, client)
|
|
||||||
|
|
||||||
return ClientView.from_client(client)
|
|
||||||
|
|
||||||
|
|
||||||
@backend_api.post("/clients/")
|
|
||||||
async def create_client(
|
|
||||||
request: Request,
|
|
||||||
client: ClientCreate,
|
|
||||||
session: Annotated[Session, Depends(get_session)],
|
|
||||||
) -> ClientView:
|
|
||||||
"""Create client."""
|
|
||||||
existing = await get_client_by_name(session, client.name)
|
|
||||||
if existing:
|
|
||||||
raise HTTPException(400, detail="Error: Already a client with that name.")
|
|
||||||
|
|
||||||
db_client = client.to_client()
|
|
||||||
session.add(db_client)
|
|
||||||
session.commit()
|
|
||||||
session.refresh(db_client)
|
|
||||||
audit.audit_create_client(session, request, db_client)
|
|
||||||
return ClientView.from_client(db_client)
|
|
||||||
|
|
||||||
|
|
||||||
@backend_api.post("/clients/{name}/secrets/")
|
|
||||||
async def add_secret_to_client(
|
|
||||||
request: Request,
|
|
||||||
name: str,
|
|
||||||
client_secret: ClientSecretPublic,
|
|
||||||
session: Annotated[Session, Depends(get_session)],
|
|
||||||
) -> None:
|
|
||||||
"""Add secret to a client."""
|
|
||||||
client = await get_client_by_name(session, name)
|
|
||||||
if not client:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404, detail="Cannot find a client with the given name."
|
|
||||||
)
|
|
||||||
|
|
||||||
existing_secret = await lookup_client_secret(session, client, client_secret.name)
|
|
||||||
if existing_secret:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="Cannot add a secret. A different secret with the same name already exists.",
|
|
||||||
)
|
|
||||||
db_secret = ClientSecret(
|
|
||||||
name=client_secret.name, client_id=client.id, secret=client_secret.secret
|
|
||||||
)
|
|
||||||
session.add(db_secret)
|
|
||||||
session.commit()
|
|
||||||
session.refresh(db_secret)
|
|
||||||
audit.audit_create_secret(session, request, client, db_secret)
|
|
||||||
|
|
||||||
|
|
||||||
@backend_api.put("/clients/{name}/secrets/{secret_name}")
|
|
||||||
async def update_client_secret(
|
|
||||||
request: Request,
|
|
||||||
name: str,
|
|
||||||
secret_name: str,
|
|
||||||
secret_data: BodyValue,
|
|
||||||
session: Annotated[Session, Depends(get_session)],
|
|
||||||
) -> ClientSecretResponse:
|
|
||||||
"""Update a client secret.
|
|
||||||
|
|
||||||
This can also be used for destructive creates.
|
|
||||||
"""
|
|
||||||
client = await get_client_by_name(session, name)
|
|
||||||
if not client:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404, detail="Cannot find a client with the given name."
|
|
||||||
)
|
|
||||||
|
|
||||||
existing_secret = await lookup_client_secret(session, client, secret_name)
|
|
||||||
if existing_secret:
|
|
||||||
existing_secret.secret = secret_data.value
|
|
||||||
session.add(existing_secret)
|
|
||||||
session.commit()
|
|
||||||
session.refresh(existing_secret)
|
|
||||||
audit.audit_update_secret(session, request, client, existing_secret)
|
|
||||||
return ClientSecretResponse.from_client_secret(existing_secret)
|
|
||||||
|
|
||||||
db_secret = ClientSecret(
|
|
||||||
name=secret_name,
|
|
||||||
client_id=client.id,
|
|
||||||
secret=secret_data.value,
|
|
||||||
)
|
|
||||||
session.add(db_secret)
|
|
||||||
session.commit()
|
|
||||||
session.refresh(db_secret)
|
|
||||||
audit.audit_create_secret(session, request, client, db_secret)
|
|
||||||
return ClientSecretResponse.from_client_secret(db_secret)
|
|
||||||
|
|
||||||
|
|
||||||
@backend_api.get("/clients/{name}/secrets/{secret_name}")
|
|
||||||
async def request_client_secret(
|
|
||||||
request: Request,
|
|
||||||
name: str,
|
|
||||||
secret_name: str,
|
|
||||||
session: Annotated[Session, Depends(get_session)],
|
|
||||||
) -> ClientSecretResponse:
|
|
||||||
"""Get a client secret."""
|
|
||||||
client = await get_client_by_name(session, name)
|
|
||||||
if not client:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404, detail="Cannot find a client with the given name."
|
|
||||||
)
|
|
||||||
|
|
||||||
secret = await lookup_client_secret(session, client, secret_name)
|
|
||||||
if not secret:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404, detail="Cannot find a secret with the given name."
|
|
||||||
)
|
|
||||||
|
|
||||||
response_model = ClientSecretResponse.from_client_secret(secret)
|
|
||||||
audit.audit_access_secret(session, request, client, secret)
|
|
||||||
return response_model
|
|
||||||
|
|
||||||
@backend_api.delete("/clients/{name}/secrets/{secret_name}")
|
|
||||||
async def delete_client_secret(
|
|
||||||
request: Request,
|
|
||||||
name: str,
|
|
||||||
secret_name: str,
|
|
||||||
session: Annotated[Session, Depends(get_session)],
|
|
||||||
) -> None:
|
|
||||||
"""Delete a secret."""
|
|
||||||
client = await get_client_by_name(session, name)
|
|
||||||
if not client:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404, detail="Cannot find a client with the given name."
|
|
||||||
)
|
|
||||||
|
|
||||||
secret = await lookup_client_secret(session, client, secret_name)
|
|
||||||
if not secret:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404, detail="Cannot find a secret with the given name."
|
|
||||||
)
|
|
||||||
|
|
||||||
session.delete(secret)
|
|
||||||
session.commit()
|
|
||||||
audit.audit_delete_secret(session, request, client, secret)
|
|
||||||
|
|
||||||
|
|
||||||
@backend_api.get("/audit/", response_model=list[AuditLog])
|
|
||||||
async def get_audit_logs(
|
|
||||||
request: Request,
|
|
||||||
session: Annotated[Session, Depends(get_session)],
|
|
||||||
offset: Annotated[int, Query()] = 0,
|
|
||||||
limit: Annotated[int, Query(le=100)] = 100,
|
|
||||||
filter_client: Annotated[str | None, Query()] = None,
|
|
||||||
) -> Sequence[AuditLog]:
|
|
||||||
"""Get audit logs."""
|
|
||||||
audit.audit_access_audit_log(session, request)
|
|
||||||
statement = select(AuditLog).offset(offset).limit(limit)
|
|
||||||
if filter_client:
|
|
||||||
statement = statement.where(AuditLog.client_name == filter_client)
|
|
||||||
|
|
||||||
results = session.exec(statement).all()
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
|
|
||||||
|
|
||||||
@app.exception_handler(RequestValidationError)
|
|
||||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
|
||||||
content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
app.include_router(backend_api)
|
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from sqlmodel import Session, select
|
from sqlmodel import Session, select
|
||||||
|
|
||||||
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy
|
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy
|
||||||
|
|
||||||
|
|
||||||
@ -21,6 +22,7 @@ def _write_audit_log(
|
|||||||
"""Write the audit log."""
|
"""Write the audit log."""
|
||||||
origin = _get_origin(request)
|
origin = _get_origin(request)
|
||||||
entry.origin = origin
|
entry.origin = origin
|
||||||
|
entry.subsystem = "backend"
|
||||||
session.add(entry)
|
session.add(entry)
|
||||||
if commit:
|
if commit:
|
||||||
session.commit()
|
session.commit()
|
||||||
@ -109,6 +111,23 @@ def audit_update_policy(
|
|||||||
_write_audit_log(session, request, entry, commit)
|
_write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
|
|
||||||
|
def audit_update_client(
|
||||||
|
session: Session,
|
||||||
|
request: Request,
|
||||||
|
client: Client,
|
||||||
|
commit: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""Audit an update secret event."""
|
||||||
|
entry = AuditLog(
|
||||||
|
operation="UPDATE",
|
||||||
|
object="Client",
|
||||||
|
client_id=client.id,
|
||||||
|
client_name=client.name,
|
||||||
|
message="Client updated",
|
||||||
|
)
|
||||||
|
_write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
|
|
||||||
def audit_update_secret(
|
def audit_update_secret(
|
||||||
session: Session,
|
session: Session,
|
||||||
request: Request,
|
request: Request,
|
||||||
@ -219,3 +238,15 @@ def audit_access_audit_log(
|
|||||||
object="AuditLog",
|
object="AuditLog",
|
||||||
)
|
)
|
||||||
_write_audit_log(session, request, entry, commit)
|
_write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
|
|
||||||
|
def audit_client_secret_list(
|
||||||
|
session: Session, request: Request, commit: bool = True
|
||||||
|
) -> None:
|
||||||
|
"""Audit a list of all secrets."""
|
||||||
|
entry = AuditLog(
|
||||||
|
operation="ACCESS",
|
||||||
|
message="All secret names and their clients was viewed",
|
||||||
|
)
|
||||||
|
_write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
|
|||||||
@ -0,0 +1,67 @@
|
|||||||
|
"""Backend API."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
import bcrypt
|
||||||
|
from fastapi import APIRouter, Depends, Header, HTTPException
|
||||||
|
from sqlmodel import Session, select
|
||||||
|
|
||||||
|
from .api import get_audit_api, get_clients_api, get_policy_api, get_secrets_api
|
||||||
|
from .models import (
|
||||||
|
APIClient,
|
||||||
|
)
|
||||||
|
from .types import DBSessionDep
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
API_VERSION = "v1"
|
||||||
|
|
||||||
|
|
||||||
|
def verify_token(token: str, stored_hash: str) -> bool:
|
||||||
|
"""Verify token."""
|
||||||
|
token_bytes = token.encode("utf-8")
|
||||||
|
stored_bytes = stored_hash.encode("utf-8")
|
||||||
|
return bcrypt.checkpw(token_bytes, stored_bytes)
|
||||||
|
|
||||||
|
|
||||||
|
def get_backend_api(
|
||||||
|
get_db_session: DBSessionDep,
|
||||||
|
) -> APIRouter:
|
||||||
|
"""Construct backend API."""
|
||||||
|
|
||||||
|
async def validate_token(
|
||||||
|
x_api_token: Annotated[str, Header()],
|
||||||
|
session: Annotated[Session, Depends(get_db_session)],
|
||||||
|
) -> str:
|
||||||
|
"""Validate token."""
|
||||||
|
LOG.debug("Validating token %s", x_api_token)
|
||||||
|
statement = select(APIClient)
|
||||||
|
results = session.exec(statement)
|
||||||
|
valid = False
|
||||||
|
for result in results:
|
||||||
|
if verify_token(x_api_token, result.token):
|
||||||
|
valid = True
|
||||||
|
LOG.debug("Token is valid")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not valid:
|
||||||
|
LOG.debug("Token is not valid.")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401, detail="unauthorized. invalid api token."
|
||||||
|
)
|
||||||
|
return x_api_token
|
||||||
|
|
||||||
|
LOG.info("Initializing app.")
|
||||||
|
|
||||||
|
backend_api = APIRouter(
|
||||||
|
prefix=f"/api/{API_VERSION}",
|
||||||
|
dependencies=[Depends(validate_token)],
|
||||||
|
)
|
||||||
|
|
||||||
|
backend_api.include_router(get_audit_api(get_db_session))
|
||||||
|
backend_api.include_router(get_clients_api(get_db_session))
|
||||||
|
backend_api.include_router(get_policy_api(get_db_session))
|
||||||
|
backend_api.include_router(get_secrets_api(get_db_session))
|
||||||
|
|
||||||
|
return backend_api
|
||||||
@ -1,11 +1,18 @@
|
|||||||
"""CLI and main entry point."""
|
"""CLI and main entry point."""
|
||||||
|
|
||||||
|
import code
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import cast
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
import click
|
import click
|
||||||
|
from sqlmodel import Session, create_engine, select
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
from .db import generate_api_token
|
from .db import create_api_token
|
||||||
|
|
||||||
|
from .models import Client, ClientSecret, ClientAccessPolicy, AuditLog, APIClient
|
||||||
|
from .settings import BackendSettings
|
||||||
|
|
||||||
DEFAULT_LISTEN = "127.0.0.1"
|
DEFAULT_LISTEN = "127.0.0.1"
|
||||||
DEFAULT_PORT = 8022
|
DEFAULT_PORT = 8022
|
||||||
@ -14,18 +21,59 @@ WORKDIR = Path(os.getcwd())
|
|||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
@click.option("--database", help="Path to the sqlite database file.")
|
@click.option("--database", help="Path to the sqlite database file.")
|
||||||
def cli(database: str) -> None:
|
@click.pass_context
|
||||||
|
def cli(ctx: click.Context, database: str) -> None:
|
||||||
"""CLI group."""
|
"""CLI group."""
|
||||||
if database:
|
if database:
|
||||||
# Hopefully it's enough to set the environment variable as so.
|
# Hopefully it's enough to set the environment variable as so.
|
||||||
os.environ["SSHECRET_DB_FILE"] = str(Path(database).absolute())
|
settings = BackendSettings(db_url=f"sqlite:///{Path(database).absolute()}")
|
||||||
|
else:
|
||||||
|
settings = BackendSettings()
|
||||||
|
|
||||||
|
ctx.obj = settings
|
||||||
|
|
||||||
|
|
||||||
@cli.command("generate-token")
|
@cli.command("generate-token")
|
||||||
def cli_generate_token() -> None:
|
@click.pass_context
|
||||||
|
def cli_generate_token(ctx: click.Context) -> None:
|
||||||
"""Generate a token."""
|
"""Generate a token."""
|
||||||
token = generate_api_token()
|
settings = cast(BackendSettings, ctx.obj)
|
||||||
|
engine = create_engine(settings.db_url)
|
||||||
|
with Session(engine) as session:
|
||||||
|
token = create_api_token(session, True)
|
||||||
click.echo("Generated api token:")
|
click.echo("Generated api token:")
|
||||||
click.echo(token)
|
click.echo(token)
|
||||||
|
|
||||||
|
@cli.command("run")
|
||||||
|
@click.option("--host", default="127.0.0.1")
|
||||||
|
@click.option("--port", default=8022, type=click.INT)
|
||||||
|
@click.option("--dev", is_flag=True)
|
||||||
|
@click.option("--workers", type=click.INT)
|
||||||
|
def cli_run(host: str, port: int, dev: bool, workers: int | None) -> None:
|
||||||
|
"""Run the server."""
|
||||||
|
uvicorn.run("sshecret_backend.main:app", host=host, port=port, reload=dev, workers=workers)
|
||||||
|
|
||||||
|
@cli.command("repl")
|
||||||
|
@click.pass_context
|
||||||
|
def cli_repl(ctx: click.Context) -> None:
|
||||||
|
"""Run an interactive console."""
|
||||||
|
settings = cast(BackendSettings, ctx.obj)
|
||||||
|
engine = create_engine(settings.db_url)
|
||||||
|
|
||||||
|
with Session(engine) as session:
|
||||||
|
locals = {
|
||||||
|
"session": session,
|
||||||
|
"select": select,
|
||||||
|
"Client": Client,
|
||||||
|
"ClientSecret": ClientSecret,
|
||||||
|
"ClientAccessPolicy": ClientAccessPolicy,
|
||||||
|
"APIClient": APIClient,
|
||||||
|
"AuditLog": AuditLog,
|
||||||
|
}
|
||||||
|
|
||||||
|
console = code.InteractiveConsole(locals=locals, local_exit=True)
|
||||||
|
banner = "Sshecret-backend REPL.\nUse 'session' to interact with the database."
|
||||||
|
console.interact(banner=banner, exitmsg="Bye!")
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
|
from collections.abc import Generator, Callable
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from sqlalchemy import Engine
|
from sqlalchemy import Engine
|
||||||
from sqlmodel import Session, create_engine, text
|
from sqlmodel import Session, create_engine, text
|
||||||
@ -9,14 +10,30 @@ import bcrypt
|
|||||||
|
|
||||||
from sqlalchemy.engine import URL
|
from sqlalchemy.engine import URL
|
||||||
|
|
||||||
from .models import APIClient, init_db
|
|
||||||
|
|
||||||
from .settings import get_settings
|
from .models import APIClient
|
||||||
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_database(
|
||||||
|
db_url: URL | str,
|
||||||
|
) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]:
|
||||||
|
"""Setup database."""
|
||||||
|
|
||||||
|
engine = create_engine(db_url, echo=False)
|
||||||
|
with engine.connect() as connection:
|
||||||
|
connection.execute(text("PRAGMA foreign_keys=ON")) # for SQLite only
|
||||||
|
|
||||||
|
def get_db_session() -> Generator[Session, None, None]:
|
||||||
|
"""Get DB Session."""
|
||||||
|
with Session(engine) as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
return engine, get_db_session
|
||||||
|
|
||||||
|
|
||||||
def get_engine(filename: Path, echo: bool = False) -> Engine:
|
def get_engine(filename: Path, echo: bool = False) -> Engine:
|
||||||
"""Initialize the engine."""
|
"""Initialize the engine."""
|
||||||
url = URL.create(drivername="sqlite", database=str(filename.absolute()))
|
url = URL.create(drivername="sqlite", database=str(filename.absolute()))
|
||||||
@ -27,20 +44,6 @@ def get_engine(filename: Path, echo: bool = False) -> Engine:
|
|||||||
return engine
|
return engine
|
||||||
|
|
||||||
|
|
||||||
def create_db_and_tables(filename: Path, echo: bool = False) -> bool:
|
|
||||||
"""Create database and tables.
|
|
||||||
|
|
||||||
Returns True if the database was created.
|
|
||||||
"""
|
|
||||||
created = False
|
|
||||||
if not filename.exists():
|
|
||||||
created = True
|
|
||||||
engine = get_engine(filename, echo)
|
|
||||||
|
|
||||||
init_db(engine)
|
|
||||||
return created
|
|
||||||
|
|
||||||
|
|
||||||
def create_api_token(session: Session, read_write: bool) -> str:
|
def create_api_token(session: Session, read_write: bool) -> str:
|
||||||
"""Create API token."""
|
"""Create API token."""
|
||||||
token = secrets.token_urlsafe(32)
|
token = secrets.token_urlsafe(32)
|
||||||
@ -54,14 +57,3 @@ def create_api_token(session: Session, read_write: bool) -> str:
|
|||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
return token
|
return token
|
||||||
|
|
||||||
|
|
||||||
def generate_api_token() -> str:
|
|
||||||
"""Generate API token."""
|
|
||||||
settings = get_settings()
|
|
||||||
engine = get_engine(settings.db_file)
|
|
||||||
init_db(engine)
|
|
||||||
with Session(engine) as session:
|
|
||||||
token = create_api_token(session, True)
|
|
||||||
|
|
||||||
return token
|
|
||||||
|
|||||||
7
packages/sshecret-backend/src/sshecret_backend/main.py
Normal file
7
packages/sshecret-backend/src/sshecret_backend/main.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
"""Main script entrypoint."""
|
||||||
|
|
||||||
|
from .settings import BackendSettings
|
||||||
|
|
||||||
|
from .app import create_backend_app
|
||||||
|
|
||||||
|
app = create_backend_app(BackendSettings())
|
||||||
@ -7,17 +7,21 @@ This might require some changes to these schemas.
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlmodel import Field, Relationship, SQLModel
|
from sqlmodel import Field, Relationship, SQLModel
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Client(SQLModel, table=True):
|
class Client(SQLModel, table=True):
|
||||||
"""Client model."""
|
"""Client model."""
|
||||||
|
|
||||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||||
name: str = Field(unique=True)
|
name: str = Field(unique=True)
|
||||||
|
description: str | None = None
|
||||||
public_key: str
|
public_key: str
|
||||||
|
|
||||||
created_at: datetime | None = Field(
|
created_at: datetime | None = Field(
|
||||||
@ -61,11 +65,13 @@ class ClientAccessPolicy(SQLModel, table=True):
|
|||||||
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
|
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ClientSecret(SQLModel, table=True):
|
class ClientSecret(SQLModel, table=True):
|
||||||
"""A client secret."""
|
"""A client secret."""
|
||||||
|
|
||||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||||
name: str
|
name: str
|
||||||
|
description: str | None = None
|
||||||
client_id: uuid.UUID | None = Field(foreign_key="client.id", ondelete="CASCADE")
|
client_id: uuid.UUID | None = Field(foreign_key="client.id", ondelete="CASCADE")
|
||||||
client: Client | None = Relationship(back_populates="secrets")
|
client: Client | None = Relationship(back_populates="secrets")
|
||||||
secret: str
|
secret: str
|
||||||
@ -92,6 +98,7 @@ class AuditLog(SQLModel, table=True):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||||
|
subsystem: str | None = None
|
||||||
object: str | None = None
|
object: str | None = None
|
||||||
object_id: str | None = None
|
object_id: str | None = None
|
||||||
operation: str
|
operation: str
|
||||||
@ -107,6 +114,7 @@ class AuditLog(SQLModel, table=True):
|
|||||||
nullable=False,
|
nullable=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class APIClient(SQLModel, table=True):
|
class APIClient(SQLModel, table=True):
|
||||||
"""Stores API Keys."""
|
"""Stores API Keys."""
|
||||||
|
|
||||||
@ -120,6 +128,8 @@ class APIClient(SQLModel, table=True):
|
|||||||
nullable=False,
|
nullable=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def init_db(engine: sa.Engine) -> None:
|
def init_db(engine: sa.Engine) -> None:
|
||||||
"""Create database."""
|
"""Create database."""
|
||||||
|
LOG.info("Starting init_db")
|
||||||
SQLModel.metadata.create_all(engine)
|
SQLModel.metadata.create_all(engine)
|
||||||
|
|||||||
@ -1,15 +1,13 @@
|
|||||||
"""Settings management."""
|
"""Settings management."""
|
||||||
|
|
||||||
from typing import override
|
from pydantic import Field
|
||||||
from pathlib import Path
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from pydantic_settings import (
|
from pydantic_settings import (
|
||||||
BaseSettings,
|
BaseSettings,
|
||||||
SettingsConfigDict,
|
SettingsConfigDict,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_DATABASE = "sshecret.db"
|
DEFAULT_DATABASE = "sqlite:///sshecret.db"
|
||||||
|
|
||||||
|
|
||||||
class BackendSettings(BaseSettings):
|
class BackendSettings(BaseSettings):
|
||||||
@ -17,7 +15,7 @@ class BackendSettings(BaseSettings):
|
|||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".backend.env", env_prefix="sshecret_")
|
model_config = SettingsConfigDict(env_file=".backend.env", env_prefix="sshecret_")
|
||||||
|
|
||||||
db_file: Path = Field(default=Path(DEFAULT_DATABASE).absolute())
|
db_url: str = Field(default=DEFAULT_DATABASE)
|
||||||
|
|
||||||
|
|
||||||
def get_settings() -> BackendSettings:
|
def get_settings() -> BackendSettings:
|
||||||
|
|||||||
@ -1,13 +1,17 @@
|
|||||||
"""Test helpers."""
|
"""Test helpers."""
|
||||||
|
|
||||||
|
import logging
|
||||||
from sqlmodel import Session
|
from sqlmodel import Session
|
||||||
from .db import get_engine, create_api_token
|
from sshecret_backend.settings import BackendSettings
|
||||||
from .models import init_db
|
from .models import init_db
|
||||||
from .settings import get_settings
|
from .db import create_api_token, setup_database
|
||||||
|
|
||||||
def create_test_token(session: Session) -> str:
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_token(settings: BackendSettings) -> str:
|
||||||
"""Create test token."""
|
"""Create test token."""
|
||||||
settings = get_settings()
|
engine, _setupdb = setup_database(settings.db_url)
|
||||||
engine = get_engine(settings.db_file)
|
with Session(engine) as session:
|
||||||
init_db(engine)
|
init_db(engine)
|
||||||
return create_api_token(session, True)
|
return create_api_token(session, True)
|
||||||
|
|||||||
8
packages/sshecret-backend/src/sshecret_backend/types.py
Normal file
8
packages/sshecret-backend/src/sshecret_backend/types.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
"""Common type definitions."""
|
||||||
|
|
||||||
|
from collections.abc import Callable, Generator
|
||||||
|
|
||||||
|
from sqlmodel import Session
|
||||||
|
|
||||||
|
|
||||||
|
DBSessionDep = Callable[[], Generator[Session, None, None]]
|
||||||
@ -1,24 +1,27 @@
|
|||||||
"""Models for API views."""
|
"""Models for API views."""
|
||||||
|
|
||||||
import ipaddress
|
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Annotated, Any, Self, override
|
from typing import Annotated, Self, override
|
||||||
|
|
||||||
from sqlmodel import Field, SQLModel
|
from sqlmodel import Field, SQLModel
|
||||||
from pydantic import IPvAnyAddress, IPvAnyNetwork
|
from pydantic import AfterValidator, IPvAnyAddress, IPvAnyNetwork
|
||||||
from . import models
|
|
||||||
|
|
||||||
|
from sshecret.crypto import public_key_validator
|
||||||
|
|
||||||
|
from . import models
|
||||||
|
|
||||||
|
|
||||||
class ClientView(SQLModel):
|
class ClientView(SQLModel):
|
||||||
"""View for a single client."""
|
"""View for a single client."""
|
||||||
|
|
||||||
id: uuid.UUID
|
id: uuid.UUID
|
||||||
name: str
|
name: str
|
||||||
|
description: str | None = None
|
||||||
public_key: str
|
public_key: str
|
||||||
policies: list[str] = ["0.0.0.0/0", "::/0"]
|
policies: list[str] = ["0.0.0.0/0", "::/0"]
|
||||||
secrets: list[str] = Field(default_factory=list)
|
secrets: list[str] = Field(default_factory=list)
|
||||||
created_at: datetime
|
created_at: datetime | None
|
||||||
updated_at: datetime | None = None
|
updated_at: datetime | None = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -33,6 +36,7 @@ class ClientView(SQLModel):
|
|||||||
view = cls(
|
view = cls(
|
||||||
id=client.id,
|
id=client.id,
|
||||||
name=client.name,
|
name=client.name,
|
||||||
|
description=client.description,
|
||||||
public_key=client.public_key,
|
public_key=client.public_key,
|
||||||
created_at=client.created_at,
|
created_at=client.created_at,
|
||||||
updated_at=client.updated_at or None,
|
updated_at=client.updated_at or None,
|
||||||
@ -46,24 +50,34 @@ class ClientView(SQLModel):
|
|||||||
return view
|
return view
|
||||||
|
|
||||||
|
|
||||||
|
class ClientQueryResult(SQLModel):
|
||||||
|
"""Result class for queries towards the client list."""
|
||||||
|
|
||||||
|
clients: list[ClientView] = Field(default_factory=list)
|
||||||
|
total_results: int
|
||||||
|
remaining_results: int
|
||||||
|
|
||||||
|
|
||||||
class ClientCreate(SQLModel):
|
class ClientCreate(SQLModel):
|
||||||
"""Model to create a client."""
|
"""Model to create a client."""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
public_key: str
|
description: str | None = None
|
||||||
|
public_key: Annotated[str, AfterValidator(public_key_validator)]
|
||||||
|
|
||||||
def to_client(self) -> models.Client:
|
def to_client(self) -> models.Client:
|
||||||
"""Instantiate a client."""
|
"""Instantiate a client."""
|
||||||
public_key = self.public_key
|
|
||||||
return models.Client(
|
return models.Client(
|
||||||
name=self.name, public_key=public_key
|
name=self.name,
|
||||||
|
public_key=self.public_key,
|
||||||
|
description=self.description,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ClientUpdate(SQLModel):
|
class ClientUpdate(SQLModel):
|
||||||
"""Model to update the client public key."""
|
"""Model to update the client public key."""
|
||||||
|
|
||||||
public_key: str
|
public_key: Annotated[str, AfterValidator(public_key_validator)]
|
||||||
|
|
||||||
|
|
||||||
class BodyValue(SQLModel):
|
class BodyValue(SQLModel):
|
||||||
@ -77,6 +91,7 @@ class ClientSecretPublic(SQLModel):
|
|||||||
|
|
||||||
name: str
|
name: str
|
||||||
secret: str
|
secret: str
|
||||||
|
description: str | None = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_client_secret(cls, client_secret: models.ClientSecret) -> Self:
|
def from_client_secret(cls, client_secret: models.ClientSecret) -> Self:
|
||||||
@ -84,13 +99,14 @@ class ClientSecretPublic(SQLModel):
|
|||||||
return cls(
|
return cls(
|
||||||
name=client_secret.name,
|
name=client_secret.name,
|
||||||
secret=client_secret.secret,
|
secret=client_secret.secret,
|
||||||
|
description=client_secret.description,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ClientSecretResponse(ClientSecretPublic):
|
class ClientSecretResponse(ClientSecretPublic):
|
||||||
"""A secret view."""
|
"""A secret view."""
|
||||||
|
|
||||||
created_at: datetime
|
created_at: datetime | None
|
||||||
updated_at: datetime | None = None
|
updated_at: datetime | None = None
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@ -123,3 +139,31 @@ class ClientPolicyUpdate(SQLModel):
|
|||||||
"""Model for updating policies."""
|
"""Model for updating policies."""
|
||||||
|
|
||||||
sources: list[IPvAnyAddress | IPvAnyNetwork]
|
sources: list[IPvAnyAddress | IPvAnyNetwork]
|
||||||
|
|
||||||
|
|
||||||
|
class ClientSecretList(SQLModel):
|
||||||
|
"""Model for aggregating identically named secrets."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
clients: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class ClientReference(SQLModel):
|
||||||
|
"""Reference to a client."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class ClientSecretDetailList(SQLModel):
|
||||||
|
"""A more detailed version of the ClientSecretList."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
ids: list[str] = Field(default_factory=list)
|
||||||
|
clients: list[ClientReference] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class AuditInfo(SQLModel):
|
||||||
|
"""Information about audit information."""
|
||||||
|
|
||||||
|
entries: int
|
||||||
|
|||||||
@ -1,18 +1,17 @@
|
|||||||
"""Tests of the backend api using pytest."""
|
"""Tests of the backend api using pytest."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import random
|
from pathlib import Path
|
||||||
import string
|
|
||||||
from httpx import Response
|
from httpx import Response
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from sqlmodel import Session, SQLModel, create_engine
|
|
||||||
from sqlmodel.pool import StaticPool
|
|
||||||
|
|
||||||
from sshecret_backend.app import app, get_session
|
from sshecret.crypto import generate_private_key, generate_public_key_string
|
||||||
|
from sshecret_backend.app import create_backend_app
|
||||||
from sshecret_backend.testing import create_test_token
|
from sshecret_backend.testing import create_test_token
|
||||||
from sshecret_backend.models import AuditLog
|
from sshecret_backend.models import AuditLog
|
||||||
|
from sshecret_backend.settings import BackendSettings
|
||||||
|
|
||||||
|
|
||||||
LOG = logging.getLogger()
|
LOG = logging.getLogger()
|
||||||
@ -25,20 +24,15 @@ LOG.setLevel(logging.DEBUG)
|
|||||||
|
|
||||||
def make_test_key() -> str:
|
def make_test_key() -> str:
|
||||||
"""Generate a test key."""
|
"""Generate a test key."""
|
||||||
randomlength = 540
|
private_key = generate_private_key()
|
||||||
key = "ssh-rsa "
|
return generate_public_key_string(private_key.public_key())
|
||||||
randompart = "".join(
|
|
||||||
random.choices(string.ascii_letters + string.digits, k=randomlength)
|
|
||||||
)
|
|
||||||
comment = " invalid-test-key"
|
|
||||||
return key + randompart + comment
|
|
||||||
|
|
||||||
|
|
||||||
def create_client(
|
def create_client(
|
||||||
test_client: TestClient,
|
test_client: TestClient,
|
||||||
headers: dict[str, str],
|
|
||||||
name: str,
|
name: str,
|
||||||
public_key: str | None = None,
|
public_key: str | None = None,
|
||||||
|
description: str | None = None,
|
||||||
) -> Response:
|
) -> Response:
|
||||||
"""Create client."""
|
"""Create client."""
|
||||||
if not public_key:
|
if not public_key:
|
||||||
@ -47,50 +41,35 @@ def create_client(
|
|||||||
"name": name,
|
"name": name,
|
||||||
"public_key": public_key,
|
"public_key": public_key,
|
||||||
}
|
}
|
||||||
create_response = test_client.post("/api/v1/clients", headers=headers, json=data)
|
if description:
|
||||||
|
data["description"] = description
|
||||||
|
create_response = test_client.post("/api/v1/clients", json=data)
|
||||||
return create_response
|
return create_response
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="session")
|
|
||||||
def session_fixture():
|
|
||||||
engine = create_engine(
|
|
||||||
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
|
|
||||||
)
|
|
||||||
SQLModel.metadata.create_all(engine)
|
|
||||||
with Session(engine) as session:
|
|
||||||
yield session
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="token")
|
|
||||||
def token_fixture(session: Session):
|
|
||||||
"""Generate a token."""
|
|
||||||
token = create_test_token(session)
|
|
||||||
return token
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="headers")
|
|
||||||
def headers_fixture(token: str) -> dict[str, str]:
|
|
||||||
"""Generate headers."""
|
|
||||||
return {"X-API-Token": token}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="test_client")
|
@pytest.fixture(name="test_client")
|
||||||
def test_client_fixture(session: Session):
|
def create_client_fixture(tmp_path: Path):
|
||||||
"""Test client fixture."""
|
"""Test client fixture."""
|
||||||
|
|
||||||
def get_session_override():
|
db_file = tmp_path / "backend.db"
|
||||||
return session
|
print(f"DB File: {db_file.absolute()}")
|
||||||
|
settings = BackendSettings(db_url=f"sqlite:///{db_file.absolute()}")
|
||||||
|
app = create_backend_app(settings)
|
||||||
|
|
||||||
app.dependency_overrides[get_session] = get_session_override
|
token = create_test_token(settings)
|
||||||
test_client = TestClient(app)
|
|
||||||
|
test_client = TestClient(app, headers={"X-API-Token": token})
|
||||||
yield test_client
|
yield test_client
|
||||||
app.dependency_overrides.clear()
|
|
||||||
|
|
||||||
|
|
||||||
def test_missing_token(test_client: TestClient) -> None:
|
def test_missing_token(test_client: TestClient) -> None:
|
||||||
"""Test logging in with missing token."""
|
"""Test logging in with missing token."""
|
||||||
response = test_client.get("/api/v1/clients/")
|
# Save headers
|
||||||
|
old_headers = test_client.headers
|
||||||
|
test_client.headers = {}
|
||||||
|
response = test_client.get("/api/v1/clients/", headers={})
|
||||||
assert response.status_code == 422
|
assert response.status_code == 422
|
||||||
|
test_client.headers = old_headers
|
||||||
|
|
||||||
|
|
||||||
def test_incorrect_token(test_client: TestClient) -> None:
|
def test_incorrect_token(test_client: TestClient) -> None:
|
||||||
@ -99,22 +78,24 @@ def test_incorrect_token(test_client: TestClient) -> None:
|
|||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
def test_with_token(test_client: TestClient, token: str) -> None:
|
def test_with_token(test_client: TestClient) -> None:
|
||||||
"""Test with a valid token."""
|
"""Test with a valid token."""
|
||||||
response = test_client.get("/api/v1/clients/", headers={"X-API-Token": token})
|
response = test_client.get("/api/v1/clients/")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert len(response.json()) == 0
|
data = response.json()
|
||||||
|
assert data["total_results"] == 0
|
||||||
|
|
||||||
|
|
||||||
def test_create_client(test_client: TestClient, headers: dict[str, str]) -> None:
|
def test_create_client(test_client: TestClient) -> None:
|
||||||
"""Test creating a client."""
|
"""Test creating a client."""
|
||||||
client_name = "test"
|
client_name = "test"
|
||||||
client_publickey = make_test_key()
|
client_publickey = make_test_key()
|
||||||
create_response = create_client(test_client, headers, client_name, client_publickey)
|
create_response = create_client(test_client, client_name, client_publickey)
|
||||||
assert create_response.status_code == 200
|
assert create_response.status_code == 200
|
||||||
response = test_client.get("/api/v1/clients/", headers=headers)
|
response = test_client.get("/api/v1/clients")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
clients = response.json()
|
clients_result = response.json()
|
||||||
|
clients = clients_result["clients"]
|
||||||
assert isinstance(clients, list)
|
assert isinstance(clients, list)
|
||||||
client = clients[0]
|
client = clients[0]
|
||||||
assert isinstance(client, dict)
|
assert isinstance(client, dict)
|
||||||
@ -122,72 +103,63 @@ def test_create_client(test_client: TestClient, headers: dict[str, str]) -> None
|
|||||||
assert client.get("created_at") is not None
|
assert client.get("created_at") is not None
|
||||||
|
|
||||||
|
|
||||||
def test_delete_client(test_client: TestClient, headers: dict[str, str]) -> None:
|
def test_delete_client(test_client: TestClient) -> None:
|
||||||
"""Test creating a client."""
|
"""Test creating a client."""
|
||||||
client_name = "test"
|
client_name = "test"
|
||||||
create_response = create_client(
|
create_response = create_client(
|
||||||
test_client,
|
test_client,
|
||||||
headers,
|
|
||||||
client_name,
|
client_name,
|
||||||
)
|
)
|
||||||
assert create_response.status_code == 200
|
assert create_response.status_code == 200
|
||||||
|
|
||||||
resp = test_client.delete("/api/v1/clients/test", headers=headers)
|
resp = test_client.delete("/api/v1/clients/test")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
resp = test_client.get("/api/v1/clients/test", headers=headers)
|
resp = test_client.get("/api/v1/clients/test")
|
||||||
assert resp.status_code == 404
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
def test_add_secret(test_client: TestClient, headers: dict[str, str]) -> None:
|
def test_add_secret(test_client: TestClient) -> None:
|
||||||
"""Test adding a secret to a client."""
|
"""Test adding a secret to a client."""
|
||||||
client_name = "test"
|
client_name = "test"
|
||||||
client_publickey = make_test_key()
|
client_publickey = make_test_key()
|
||||||
create_response = create_client(
|
create_response = create_client(
|
||||||
test_client,
|
test_client,
|
||||||
headers,
|
|
||||||
client_name,
|
client_name,
|
||||||
client_publickey,
|
client_publickey,
|
||||||
)
|
)
|
||||||
assert create_response.status_code == 200
|
assert create_response.status_code == 200
|
||||||
secret_name = "mysecret"
|
secret_name = "mysecret"
|
||||||
secret_value = "shhhh"
|
secret_value = "shhhh"
|
||||||
data = {"name": secret_name, "secret": secret_value}
|
data = {"name": secret_name, "secret": secret_value, "description": "A test secret"}
|
||||||
response = test_client.post(
|
response = test_client.post("/api/v1/clients/test/secrets/", json=data)
|
||||||
"/api/v1/clients/test/secrets/", headers=headers, json=data
|
|
||||||
)
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
# Get it back
|
# Get it back
|
||||||
get_response = test_client.get(
|
get_response = test_client.get("/api/v1/clients/test/secrets/mysecret")
|
||||||
"/api/v1/clients/test/secrets/mysecret", headers=headers
|
|
||||||
)
|
|
||||||
assert get_response.status_code == 200
|
assert get_response.status_code == 200
|
||||||
secret_body = get_response.json()
|
secret_body = get_response.json()
|
||||||
assert secret_body["name"] == data["name"]
|
assert secret_body["name"] == data["name"]
|
||||||
assert secret_body["secret"] == data["secret"]
|
assert secret_body["secret"] == data["secret"]
|
||||||
|
|
||||||
|
|
||||||
def test_delete_secret(test_client: TestClient, headers: dict[str, str]) -> None:
|
def test_delete_secret(test_client: TestClient) -> None:
|
||||||
"""Test deleting a secret."""
|
"""Test deleting a secret."""
|
||||||
test_add_secret(test_client, headers)
|
test_add_secret(test_client)
|
||||||
resp = test_client.delete("/api/v1/clients/test/secrets/mysecret", headers=headers)
|
resp = test_client.delete("/api/v1/clients/test/secrets/mysecret")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
get_response = test_client.get(
|
get_response = test_client.get("/api/v1/clients/test/secrets/mysecret")
|
||||||
"/api/v1/clients/test/secrets/mysecret", headers=headers
|
|
||||||
)
|
|
||||||
assert get_response.status_code == 404
|
assert get_response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
def test_put_add_secret(test_client: TestClient, headers: dict[str, str]) -> None:
|
def test_put_add_secret(test_client: TestClient) -> None:
|
||||||
"""Test adding secret via PUT."""
|
"""Test adding secret via PUT."""
|
||||||
# Use the test_create_client function to create a client.
|
# Use the test_create_client function to create a client.
|
||||||
test_create_client(test_client, headers)
|
test_create_client(test_client)
|
||||||
secret_name = "mysecret"
|
secret_name = "mysecret"
|
||||||
secret_value = "shhhh"
|
secret_value = "shhhh"
|
||||||
data = {"name": secret_name, "secret": secret_value}
|
data = {"name": secret_name, "secret": secret_value, "description": None}
|
||||||
response = test_client.put(
|
response = test_client.put(
|
||||||
"/api/v1/clients/test/secrets/mysecret",
|
"/api/v1/clients/test/secrets/mysecret",
|
||||||
headers=headers,
|
|
||||||
json={"value": secret_value},
|
json={"value": secret_value},
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@ -197,13 +169,12 @@ def test_put_add_secret(test_client: TestClient, headers: dict[str, str]) -> Non
|
|||||||
assert response_model == data
|
assert response_model == data
|
||||||
|
|
||||||
|
|
||||||
def test_put_update_secret(test_client: TestClient, headers: dict[str, str]) -> None:
|
def test_put_update_secret(test_client: TestClient) -> None:
|
||||||
"""Test updating a client secret."""
|
"""Test updating a client secret."""
|
||||||
test_add_secret(test_client, headers)
|
test_add_secret(test_client)
|
||||||
new_value = "itsasecret"
|
new_value = "itsasecret"
|
||||||
update_response = test_client.put(
|
update_response = test_client.put(
|
||||||
"/api/v1/clients/test/secrets/mysecret",
|
"/api/v1/clients/test/secrets/mysecret",
|
||||||
headers=headers,
|
|
||||||
json={"value": new_value},
|
json={"value": new_value},
|
||||||
)
|
)
|
||||||
assert update_response.status_code == 200
|
assert update_response.status_code == 200
|
||||||
@ -218,26 +189,25 @@ def test_put_update_secret(test_client: TestClient, headers: dict[str, str]) ->
|
|||||||
assert "updated_at" in response_model
|
assert "updated_at" in response_model
|
||||||
|
|
||||||
|
|
||||||
def test_audit_logging(test_client: TestClient, headers: dict[str, str]) -> None:
|
def test_audit_logging(test_client: TestClient) -> None:
|
||||||
"""Test audit logging."""
|
"""Test audit logging."""
|
||||||
public_key = make_test_key()
|
public_key = make_test_key()
|
||||||
create_client_resp = create_client(test_client, headers, "test", public_key)
|
create_client_resp = create_client(test_client, "test", public_key)
|
||||||
assert create_client_resp.status_code == 200
|
assert create_client_resp.status_code == 200
|
||||||
secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"}
|
secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"}
|
||||||
for name, secret in secrets.items():
|
for name, secret in secrets.items():
|
||||||
add_resp = test_client.post(
|
add_resp = test_client.post(
|
||||||
"/api/v1/clients/test/secrets/",
|
"/api/v1/clients/test/secrets/",
|
||||||
headers=headers,
|
|
||||||
json={"name": name, "secret": secret},
|
json={"name": name, "secret": secret},
|
||||||
)
|
)
|
||||||
assert add_resp.status_code == 200
|
assert add_resp.status_code == 200
|
||||||
|
|
||||||
# Fetch the entire client.
|
# Fetch the entire client.
|
||||||
get_client_resp = test_client.get("/api/v1/clients/test", headers=headers)
|
get_client_resp = test_client.get("/api/v1/clients/test")
|
||||||
assert get_client_resp.status_code == 200
|
assert get_client_resp.status_code == 200
|
||||||
|
|
||||||
# Fetch the audit log
|
# Fetch the audit log
|
||||||
audit_log_resp = test_client.get("/api/v1/audit/", headers=headers)
|
audit_log_resp = test_client.get("/api/v1/audit/")
|
||||||
assert audit_log_resp.status_code == 200
|
assert audit_log_resp.status_code == 200
|
||||||
audit_logs = audit_log_resp.json()
|
audit_logs = audit_log_resp.json()
|
||||||
assert len(audit_logs) > 0
|
assert len(audit_logs) > 0
|
||||||
@ -247,61 +217,60 @@ def test_audit_logging(test_client: TestClient, headers: dict[str, str]) -> None
|
|||||||
assert audit_log is not None
|
assert audit_log is not None
|
||||||
|
|
||||||
|
|
||||||
def test_audit_log_filtering(
|
# def test_audit_log_filtering(
|
||||||
session: Session, test_client: TestClient, headers: dict[str, str]
|
# session: Session, test_client: TestClient
|
||||||
) -> None:
|
# ) -> None:
|
||||||
"""Test audit log filtering."""
|
# """Test audit log filtering."""
|
||||||
# Create a lot of test data, but just manually.
|
# # Create a lot of test data, but just manually.
|
||||||
audit_log_amount = 150
|
# audit_log_amount = 150
|
||||||
entries: list[AuditLog] = []
|
# entries: list[AuditLog] = []
|
||||||
for i in range(audit_log_amount):
|
# for i in range(audit_log_amount):
|
||||||
client_id = i % 5
|
# client_id = i % 5
|
||||||
entries.append(
|
# entries.append(
|
||||||
AuditLog(
|
# AuditLog(
|
||||||
operation="TEST",
|
# operation="TEST",
|
||||||
object_id=str(i),
|
# object_id=str(i),
|
||||||
client_name=f"client-{client_id}",
|
# client_name=f"client-{client_id}",
|
||||||
message="Test Message",
|
# message="Test Message",
|
||||||
)
|
# )
|
||||||
)
|
# )
|
||||||
|
|
||||||
session.add_all(entries)
|
# session.add_all(entries)
|
||||||
session.commit()
|
# session.commit()
|
||||||
|
|
||||||
# This should have generated a lot of audit messages
|
# # This should have generated a lot of audit messages
|
||||||
|
|
||||||
audit_path = "/api/v1/audit/"
|
# audit_path = "/api/v1/audit/"
|
||||||
audit_log_resp = test_client.get(audit_path, headers=headers)
|
# audit_log_resp = test_client.get(audit_path)
|
||||||
assert audit_log_resp.status_code == 200
|
# assert audit_log_resp.status_code == 200
|
||||||
entries = audit_log_resp.json()
|
# entries = audit_log_resp.json()
|
||||||
assert len(entries) == 100 # We get 100 at a time
|
# assert len(entries) == 100 # We get 100 at a time
|
||||||
|
|
||||||
audit_log_resp = test_client.get(
|
# audit_log_resp = test_client.get(
|
||||||
audit_path, headers=headers, params={"offset": 100}
|
# audit_path, params={"offset": 100}
|
||||||
)
|
# )
|
||||||
entries = audit_log_resp.json()
|
# entries = audit_log_resp.json()
|
||||||
assert len(entries) == 52 # There should be 50 + the two requests we made
|
# assert len(entries) == 52 # There should be 50 + the two requests we made
|
||||||
|
|
||||||
# Try to get a specific client
|
# # Try to get a specific client
|
||||||
# There should be 30 log entries for each client.
|
# # There should be 30 log entries for each client.
|
||||||
audit_log_resp = test_client.get(
|
# audit_log_resp = test_client.get(
|
||||||
audit_path, headers=headers, params={"filter_client": "client-1"}
|
# audit_path, params={"filter_client": "client-1"}
|
||||||
)
|
# )
|
||||||
|
|
||||||
entries = audit_log_resp.json()
|
# entries = audit_log_resp.json()
|
||||||
assert len(entries) == 30
|
# assert len(entries) == 30
|
||||||
|
|
||||||
|
|
||||||
def test_secret_invalidation(test_client: TestClient, headers: dict[str, str]) -> None:
|
def test_secret_invalidation(test_client: TestClient) -> None:
|
||||||
"""Test secret invalidation."""
|
"""Test secret invalidation."""
|
||||||
initial_key = make_test_key()
|
initial_key = make_test_key()
|
||||||
create_client_resp = create_client(test_client, headers, "test", initial_key)
|
create_client_resp = create_client(test_client, "test", initial_key)
|
||||||
assert create_client_resp.status_code == 200
|
assert create_client_resp.status_code == 200
|
||||||
secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"}
|
secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"}
|
||||||
for name, secret in secrets.items():
|
for name, secret in secrets.items():
|
||||||
add_resp = test_client.post(
|
add_resp = test_client.post(
|
||||||
"/api/v1/clients/test/secrets/",
|
"/api/v1/clients/test/secrets/",
|
||||||
headers=headers,
|
|
||||||
json={"name": name, "secret": secret},
|
json={"name": name, "secret": secret},
|
||||||
)
|
)
|
||||||
assert add_resp.status_code == 200
|
assert add_resp.status_code == 200
|
||||||
@ -311,13 +280,12 @@ def test_secret_invalidation(test_client: TestClient, headers: dict[str, str]) -
|
|||||||
new_key = make_test_key()
|
new_key = make_test_key()
|
||||||
update_resp = test_client.post(
|
update_resp = test_client.post(
|
||||||
"/api/v1/clients/test/public-key",
|
"/api/v1/clients/test/public-key",
|
||||||
headers=headers,
|
|
||||||
json={"public_key": new_key},
|
json={"public_key": new_key},
|
||||||
)
|
)
|
||||||
assert update_resp.status_code == 200
|
assert update_resp.status_code == 200
|
||||||
|
|
||||||
# Fetch the client. The list of secrets should be empty.
|
# Fetch the client. The list of secrets should be empty.
|
||||||
get_resp = test_client.get("/api/v1/clients/test", headers=headers)
|
get_resp = test_client.get("/api/v1/clients/test")
|
||||||
assert get_resp.status_code == 200
|
assert get_resp.status_code == 200
|
||||||
client = get_resp.json()
|
client = get_resp.json()
|
||||||
secrets = client.get("secrets")
|
secrets = client.get("secrets")
|
||||||
@ -325,14 +293,14 @@ def test_secret_invalidation(test_client: TestClient, headers: dict[str, str]) -
|
|||||||
|
|
||||||
|
|
||||||
def test_client_default_policies(
|
def test_client_default_policies(
|
||||||
test_client: TestClient, headers: dict[str, str]
|
test_client: TestClient,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test client policies."""
|
"""Test client policies."""
|
||||||
public_key = make_test_key()
|
public_key = make_test_key()
|
||||||
resp = create_client(test_client, headers, "test", public_key)
|
resp = create_client(test_client, "test")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
# Fetch policies, should return *
|
# Fetch policies, should return *
|
||||||
resp = test_client.get("/api/v1/clients/test/policies/", headers=headers)
|
resp = test_client.get("/api/v1/clients/test/policies/")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
policies = resp.json()
|
policies = resp.json()
|
||||||
@ -340,21 +308,17 @@ def test_client_default_policies(
|
|||||||
assert policies["sources"] == ["0.0.0.0/0", "::/0"]
|
assert policies["sources"] == ["0.0.0.0/0", "::/0"]
|
||||||
|
|
||||||
|
|
||||||
def test_client_policy_update_one(
|
def test_client_policy_update_one(test_client: TestClient) -> None:
|
||||||
test_client: TestClient, headers: dict[str, str]
|
|
||||||
) -> None:
|
|
||||||
"""Update client policy with single policy."""
|
"""Update client policy with single policy."""
|
||||||
public_key = make_test_key()
|
public_key = make_test_key()
|
||||||
resp = create_client(test_client, headers, "test", public_key)
|
resp = create_client(test_client, "test", public_key)
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
policy = ["192.0.2.1"]
|
policy = ["192.0.2.1"]
|
||||||
resp = test_client.put(
|
resp = test_client.put("/api/v1/clients/test/policies/", json={"sources": policy})
|
||||||
"/api/v1/clients/test/policies/", headers=headers, json={"sources": policy}
|
|
||||||
)
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
resp = test_client.get("/api/v1/clients/test/policies/", headers=headers)
|
resp = test_client.get("/api/v1/clients/test/policies/")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
policies = resp.json()
|
policies = resp.json()
|
||||||
@ -362,22 +326,18 @@ def test_client_policy_update_one(
|
|||||||
assert policies["sources"] == policy
|
assert policies["sources"] == policy
|
||||||
|
|
||||||
|
|
||||||
def test_client_policy_update_advanced(
|
def test_client_policy_update_advanced(test_client: TestClient) -> None:
|
||||||
test_client: TestClient, headers: dict[str, str]
|
|
||||||
) -> None:
|
|
||||||
"""Test other policy update scenarios."""
|
"""Test other policy update scenarios."""
|
||||||
public_key = make_test_key()
|
public_key = make_test_key()
|
||||||
resp = create_client(test_client, headers, "test", public_key)
|
resp = create_client(test_client, "test", public_key)
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
policy = ["192.0.2.1", "198.18.0.0/24"]
|
policy = ["192.0.2.1", "198.18.0.0/24"]
|
||||||
|
|
||||||
resp = test_client.put(
|
resp = test_client.put("/api/v1/clients/test/policies/", json={"sources": policy})
|
||||||
"/api/v1/clients/test/policies/", headers=headers, json={"sources": policy}
|
|
||||||
)
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
resp = test_client.get("/api/v1/clients/test/policies/", headers=headers)
|
resp = test_client.get("/api/v1/clients/test/policies/")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
policies = resp.json()
|
policies = resp.json()
|
||||||
@ -389,13 +349,11 @@ def test_client_policy_update_advanced(
|
|||||||
|
|
||||||
policy = ["obviosly_wrong"]
|
policy = ["obviosly_wrong"]
|
||||||
|
|
||||||
resp = test_client.put(
|
resp = test_client.put("/api/v1/clients/test/policies/", json={"sources": policy})
|
||||||
"/api/v1/clients/test/policies/", headers=headers, json={"sources": policy}
|
|
||||||
)
|
|
||||||
assert resp.status_code == 422
|
assert resp.status_code == 422
|
||||||
# Check that the old value is still there
|
# Check that the old value is still there
|
||||||
|
|
||||||
resp = test_client.get("/api/v1/clients/test/policies/", headers=headers)
|
resp = test_client.get("/api/v1/clients/test/policies/")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
policies = resp.json()
|
policies = resp.json()
|
||||||
@ -407,18 +365,14 @@ def test_client_policy_update_advanced(
|
|||||||
#
|
#
|
||||||
|
|
||||||
|
|
||||||
def test_client_policy_update_unset(
|
def test_client_policy_update_unset(test_client: TestClient) -> None:
|
||||||
test_client: TestClient, headers: dict[str, str]
|
|
||||||
) -> None:
|
|
||||||
"""Test clearing the client policy."""
|
"""Test clearing the client policy."""
|
||||||
public_key = make_test_key()
|
public_key = make_test_key()
|
||||||
resp = create_client(test_client, headers, "test", public_key)
|
resp = create_client(test_client, "test", public_key)
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
policy = ["192.0.2.1", "198.18.0.0/24"]
|
policy = ["192.0.2.1", "198.18.0.0/24"]
|
||||||
|
|
||||||
resp = test_client.put(
|
resp = test_client.put("/api/v1/clients/test/policies/", json={"sources": policy})
|
||||||
"/api/v1/clients/test/policies/", headers=headers, json={"sources": policy}
|
|
||||||
)
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
policies = resp.json()
|
policies = resp.json()
|
||||||
@ -428,11 +382,158 @@ def test_client_policy_update_unset(
|
|||||||
|
|
||||||
# Now we clear the policies
|
# Now we clear the policies
|
||||||
|
|
||||||
resp = test_client.put(
|
resp = test_client.put("/api/v1/clients/test/policies/", json={"sources": []})
|
||||||
"/api/v1/clients/test/policies/", headers=headers, json={"sources": []}
|
|
||||||
)
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
policies = resp.json()
|
policies = resp.json()
|
||||||
|
|
||||||
assert policies["sources"] == ["0.0.0.0/0", "::/0"]
|
assert policies["sources"] == ["0.0.0.0/0", "::/0"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_update(test_client: TestClient) -> None:
|
||||||
|
"""Test generic update of a client."""
|
||||||
|
|
||||||
|
public_key = make_test_key()
|
||||||
|
resp = create_client(test_client, "test", public_key, "PRE")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
resp = test_client.get("/api/v1/clients/test")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
client_data = resp.json()
|
||||||
|
assert client_data["description"] == "PRE"
|
||||||
|
|
||||||
|
# Update the description
|
||||||
|
new_client_data = {
|
||||||
|
"name": "test",
|
||||||
|
"description": "POST",
|
||||||
|
"public_key": client_data["public_key"],
|
||||||
|
}
|
||||||
|
resp = test_client.put("/api/v1/clients/test", json=new_client_data)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
client_data = resp.json()
|
||||||
|
assert client_data["description"] == "POST"
|
||||||
|
|
||||||
|
resp = test_client.get("/api/v1/clients/test")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
client_data = resp.json()
|
||||||
|
assert client_data["description"] == "POST"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_secret_list(test_client: TestClient) -> None:
|
||||||
|
"""Test the secret to client map view."""
|
||||||
|
# Make 4 clients
|
||||||
|
for x in range(4):
|
||||||
|
public_key = make_test_key()
|
||||||
|
create_client(test_client, f"client-{x}", public_key)
|
||||||
|
# Create a secret that only this client has.
|
||||||
|
resp = test_client.put(
|
||||||
|
f"/api/v1/clients/client-{x}/secrets/client-{x}", json={"value": "SECRET"}
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
# Create a secret that all of them have.
|
||||||
|
resp = test_client.put(
|
||||||
|
f"/api/v1/clients/client-{x}/secrets/commonsecret", json={"value": "SECRET"}
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
# Get the secret list
|
||||||
|
resp = test_client.get("/api/v1/secrets/")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert isinstance(data, list)
|
||||||
|
assert len(data) == 5
|
||||||
|
for entry in data:
|
||||||
|
if entry["name"] == "commonsecret":
|
||||||
|
assert len(entry["clients"]) == 4
|
||||||
|
else:
|
||||||
|
assert len(entry["clients"]) == 1
|
||||||
|
assert entry["clients"][0] == entry["name"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_secret_clients(test_client: TestClient) -> None:
|
||||||
|
"""Get the clients for a single secret."""
|
||||||
|
for x in range(4):
|
||||||
|
public_key = make_test_key()
|
||||||
|
create_client(test_client, f"client-{x}", public_key)
|
||||||
|
# Create a secret that every second of them have.
|
||||||
|
if x % 2 == 1:
|
||||||
|
continue
|
||||||
|
resp = test_client.put(
|
||||||
|
f"/api/v1/clients/client-{x}/secrets/commonsecret", json={"value": "SECRET"}
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
resp = test_client.get("/api/v1/secrets/commonsecret")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
data = resp.json()
|
||||||
|
assert data["name"] == "commonsecret"
|
||||||
|
assert "client-0" in data["clients"]
|
||||||
|
assert "client-1" not in data["clients"]
|
||||||
|
assert len(data["clients"]) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_searching(test_client: TestClient) -> None:
|
||||||
|
"""Test searching."""
|
||||||
|
for x in range(4):
|
||||||
|
# Create four clients
|
||||||
|
create_client(test_client, f"client-{x}")
|
||||||
|
|
||||||
|
# Create one with a different name.
|
||||||
|
create_client(test_client, "othername")
|
||||||
|
|
||||||
|
# Search for a specific one.
|
||||||
|
resp = test_client.get("/api/v1/clients/", params={"name": "othername"})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
result = resp.json()
|
||||||
|
assert result["total_results"] == 1
|
||||||
|
assert result["clients"][0]["name"] == "othername"
|
||||||
|
client_id = result["clients"][0]["id"]
|
||||||
|
|
||||||
|
# Search by ID
|
||||||
|
resp = test_client.get("/api/v1/clients/", params={"id": client_id})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
result = resp.json()
|
||||||
|
assert result["total_results"] == 1
|
||||||
|
assert result["clients"][0]["name"] == "othername"
|
||||||
|
|
||||||
|
# Search for the four similarly named ones
|
||||||
|
resp = test_client.get("/api/v1/clients/", params={"name__like": "client-%"})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
result = resp.json()
|
||||||
|
assert result["total_results"] == 4
|
||||||
|
assert str(result["clients"][0]["name"]).startswith("client-")
|
||||||
|
|
||||||
|
|
||||||
|
def test_operations_with_id(test_client: TestClient) -> None:
|
||||||
|
"""Test operations using ID instead of name."""
|
||||||
|
create_client(test_client, "test")
|
||||||
|
resp = test_client.get("/api/v1/clients/")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
client = data["clients"][0]
|
||||||
|
client_id = client["id"]
|
||||||
|
resp = test_client.get(f"/api/v1/clients/{client_id}")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["name"] == "test"
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_audit_log(test_client: TestClient) -> None:
|
||||||
|
"""Test writing to the audit log."""
|
||||||
|
params = {
|
||||||
|
"object": "Test",
|
||||||
|
"operation": "TEST",
|
||||||
|
"object_id": "Something",
|
||||||
|
"message": "Test Message"
|
||||||
|
}
|
||||||
|
resp = test_client.post("/api/v1/audit", json=params)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
resp = test_client.get("/api/v1/audit")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
entry = data[0]
|
||||||
|
for key, value in params.items():
|
||||||
|
assert entry[key] == value
|
||||||
|
|||||||
Reference in New Issue
Block a user