Check in backend in working state
This commit is contained in:
@ -1,5 +1,2 @@
|
||||
"""Sshecret backend."""
|
||||
from .app import app as app
|
||||
#from .router import app as app
|
||||
|
||||
__all__ = ["app"]
|
||||
# from .router import app as app
|
||||
|
||||
@ -0,0 +1,8 @@
|
||||
"""API factory modules."""
|
||||
|
||||
from .audit import get_audit_api
|
||||
from .clients import get_clients_api
|
||||
from .policies import get_policy_api
|
||||
from .secrets import get_secrets_api
|
||||
|
||||
__all__ = ["get_audit_api", "get_clients_api", "get_policy_api", "get_secrets_api"]
|
||||
65
packages/sshecret-backend/src/sshecret_backend/api/audit.py
Normal file
65
packages/sshecret-backend/src/sshecret_backend/api/audit.py
Normal file
@ -0,0 +1,65 @@
|
||||
"""Audit sub-api factory."""
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from fastapi import APIRouter, Depends, Request, Query
|
||||
from sqlmodel import Session, col, func, select
|
||||
from sqlalchemy import desc
|
||||
from typing import Annotated
|
||||
|
||||
from sshecret_backend.models import AuditLog
|
||||
from sshecret_backend.types import DBSessionDep
|
||||
from sshecret_backend import audit
|
||||
from sshecret_backend.view_models import AuditInfo
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
"""Construct audit sub-api."""
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/audit/", response_model=list[AuditLog])
|
||||
async def get_audit_logs(
|
||||
request: Request,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
offset: Annotated[int, Query()] = 0,
|
||||
limit: Annotated[int, Query(le=100)] = 100,
|
||||
filter_client: Annotated[str | None, Query()] = None,
|
||||
filter_subsystem: Annotated[str | None, Query()] = None,
|
||||
) -> Sequence[AuditLog]:
|
||||
"""Get audit logs."""
|
||||
#audit.audit_access_audit_log(session, request)
|
||||
statement = select(AuditLog).offset(offset).limit(limit).order_by(desc(col(AuditLog.timestamp)))
|
||||
if filter_client:
|
||||
statement = statement.where(AuditLog.client_name == filter_client)
|
||||
|
||||
if filter_subsystem:
|
||||
statement = statement.where(AuditLog.subsystem == filter_subsystem)
|
||||
|
||||
results = session.exec(statement).all()
|
||||
return results
|
||||
|
||||
@router.post("/audit/")
|
||||
async def add_audit_log(
|
||||
request: Request,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
entry: AuditLog,
|
||||
) -> AuditLog:
|
||||
"""Add entry to audit log."""
|
||||
audit_log = AuditLog.model_validate(entry.model_dump(exclude_none=True))
|
||||
session.add(audit_log)
|
||||
session.commit()
|
||||
return audit_log
|
||||
|
||||
@router.get("/audit/info")
|
||||
async def get_audit_info(request: Request, session: Annotated[Session, Depends(get_db_session)]) -> AuditInfo:
|
||||
"""Get audit info."""
|
||||
audit_count = session.exec(select(func.count('*')).select_from(AuditLog)).one()
|
||||
return AuditInfo(entries=audit_count)
|
||||
|
||||
|
||||
return router
|
||||
226
packages/sshecret-backend/src/sshecret_backend/api/clients.py
Normal file
226
packages/sshecret-backend/src/sshecret_backend/api/clients.py
Normal file
@ -0,0 +1,226 @@
|
||||
"""Client sub-api factory."""
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
|
||||
import uuid
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from sqlmodel import Session, col, select
|
||||
from sqlalchemy import func
|
||||
from typing import Annotated, Self, TypeVar
|
||||
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
from sshecret_backend.types import DBSessionDep
|
||||
from sshecret_backend.models import Client, ClientSecret
|
||||
from sshecret_backend.view_models import (
|
||||
ClientCreate,
|
||||
ClientQueryResult,
|
||||
ClientView,
|
||||
ClientUpdate,
|
||||
)
|
||||
from sshecret_backend import audit
|
||||
from .common import get_client_by_id_or_name
|
||||
|
||||
|
||||
class ClientListParams(BaseModel):
|
||||
"""Client list parameters."""
|
||||
|
||||
limit: int = Field(100, gt=0, le=100)
|
||||
offset: int = Field(0, ge=0)
|
||||
id: uuid.UUID | None = None
|
||||
name: str | None = None
|
||||
name__like: str | None = None
|
||||
name__contains: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_expressions(self) -> Self:
|
||||
"""Validate mutually exclusive expression."""
|
||||
name_filter = False
|
||||
if self.name__like or self.name__contains:
|
||||
name_filter = True
|
||||
if self.name__like and self.name__contains:
|
||||
raise ValueError("You may only specify one name expression")
|
||||
if self.name and name_filter:
|
||||
raise ValueError(
|
||||
"You must either specify name or one of name__like or name__contains"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def filter_client_statement(
|
||||
statement: SelectOfScalar[T], params: ClientListParams, ignore_limits: bool = False
|
||||
) -> SelectOfScalar[T]:
|
||||
"""Filter a statement with the provided params."""
|
||||
if params.id:
|
||||
statement = statement.where(Client.id == params.id)
|
||||
|
||||
if params.name:
|
||||
statement = statement.where(Client.name == params.name)
|
||||
elif params.name__like:
|
||||
statement = statement.where(col(Client.name).like(params.name__like))
|
||||
elif params.name__contains:
|
||||
statement = statement.where(col(Client.name).contains(params.name__contains))
|
||||
|
||||
if ignore_limits:
|
||||
return statement
|
||||
|
||||
return statement.limit(params.limit).offset(params.offset)
|
||||
|
||||
|
||||
def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
"""Construct clients sub-api."""
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/clients/")
|
||||
async def get_clients(
|
||||
filter_query: Annotated[ClientListParams, Query()],
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
) -> ClientQueryResult:
|
||||
"""Get clients."""
|
||||
# Get total results first
|
||||
count_statement = select(func.count("*")).select_from(Client)
|
||||
count_statement = filter_client_statement(count_statement, filter_query, True)
|
||||
|
||||
total_results = session.exec(count_statement).one()
|
||||
|
||||
statement = filter_client_statement(select(Client), filter_query, False)
|
||||
|
||||
results = session.exec(statement)
|
||||
remainder = total_results - filter_query.offset - filter_query.limit
|
||||
if remainder < 0:
|
||||
remainder = 0
|
||||
|
||||
clients = list(results)
|
||||
clients_view = ClientView.from_client_list(clients)
|
||||
return ClientQueryResult(
|
||||
clients=clients_view,
|
||||
total_results=total_results,
|
||||
remaining_results=remainder,
|
||||
)
|
||||
|
||||
@router.get("/clients/{name}")
|
||||
async def get_client(
|
||||
name: str,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
) -> ClientView:
|
||||
"""Fetch a client."""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
return ClientView.from_client(client)
|
||||
|
||||
@router.delete("/clients/{name}")
|
||||
async def delete_client(
|
||||
request: Request,
|
||||
name: str,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
) -> None:
|
||||
"""Delete a client."""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
|
||||
session.delete(client)
|
||||
session.commit()
|
||||
audit.audit_delete_client(session, request, client)
|
||||
|
||||
@router.post("/clients/")
|
||||
async def create_client(
|
||||
request: Request,
|
||||
client: ClientCreate,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
) -> ClientView:
|
||||
"""Create client."""
|
||||
existing = await get_client_by_id_or_name(session, client.name)
|
||||
if existing:
|
||||
raise HTTPException(400, detail="Error: Already a client with that name.")
|
||||
|
||||
db_client = client.to_client()
|
||||
session.add(db_client)
|
||||
session.commit()
|
||||
session.refresh(db_client)
|
||||
audit.audit_create_client(session, request, db_client)
|
||||
return ClientView.from_client(db_client)
|
||||
|
||||
@router.post("/clients/{name}/public-key")
|
||||
async def update_client_public_key(
|
||||
request: Request,
|
||||
name: str,
|
||||
client_update: ClientUpdate,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
) -> ClientView:
|
||||
"""Change the public key of a client.
|
||||
|
||||
This invalidates all secrets.
|
||||
"""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
client.public_key = client_update.public_key
|
||||
for secret in session.exec(
|
||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||
).all():
|
||||
LOG.debug("Invalidated secret %s", secret.id)
|
||||
secret.invalidated = True
|
||||
secret.client_id = None
|
||||
secret.client = None
|
||||
|
||||
session.add(client)
|
||||
session.refresh(client)
|
||||
session.commit()
|
||||
audit.audit_invalidate_secrets(session, request, client)
|
||||
|
||||
return ClientView.from_client(client)
|
||||
|
||||
@router.put("/clients/{name}")
|
||||
async def update_client(
|
||||
request: Request,
|
||||
name: str,
|
||||
client_update: ClientCreate,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
) -> ClientView:
|
||||
"""Change the public key of a client.
|
||||
|
||||
This invalidates all secrets.
|
||||
"""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
client.name = client_update.name
|
||||
client.description = client_update.description
|
||||
public_key_updated = False
|
||||
if client_update.public_key != client.public_key:
|
||||
public_key_updated = True
|
||||
for secret in session.exec(
|
||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||
).all():
|
||||
LOG.debug("Invalidated secret %s", secret.id)
|
||||
secret.invalidated = True
|
||||
secret.client_id = None
|
||||
secret.client = None
|
||||
|
||||
session.add(client)
|
||||
session.commit()
|
||||
session.refresh(client)
|
||||
audit.audit_update_client(session, request, client)
|
||||
if public_key_updated:
|
||||
audit.audit_invalidate_secrets(session, request, client)
|
||||
|
||||
return ClientView.from_client(client)
|
||||
|
||||
return router
|
||||
38
packages/sshecret-backend/src/sshecret_backend/api/common.py
Normal file
38
packages/sshecret-backend/src/sshecret_backend/api/common.py
Normal file
@ -0,0 +1,38 @@
|
||||
"""Common helpers."""
|
||||
|
||||
import re
|
||||
import uuid
|
||||
import bcrypt
|
||||
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from sshecret_backend.models import Client
|
||||
|
||||
RE_UUID = re.compile("^[0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}$")
|
||||
|
||||
def verify_token(token: str, stored_hash: str) -> bool:
|
||||
"""Verify token."""
|
||||
token_bytes = token.encode("utf-8")
|
||||
stored_bytes = stored_hash.encode("utf-8")
|
||||
return bcrypt.checkpw(token_bytes, stored_bytes)
|
||||
|
||||
|
||||
async def get_client_by_name(session: Session, name: str) -> Client | None:
|
||||
"""Get client by name."""
|
||||
client_filter = select(Client).where(Client.name == name)
|
||||
client_results = session.exec(client_filter)
|
||||
return client_results.first()
|
||||
|
||||
async def get_client_by_id(session: Session, id: uuid.UUID) -> Client | None:
|
||||
"""Get client by name."""
|
||||
client_filter = select(Client).where(Client.id == id)
|
||||
client_results = session.exec(client_filter)
|
||||
return client_results.first()
|
||||
|
||||
async def get_client_by_id_or_name(session: Session, id_or_name: str) -> Client | None:
|
||||
"""Get client either by id or name."""
|
||||
if RE_UUID.match(id_or_name):
|
||||
id = uuid.UUID(id_or_name)
|
||||
return await get_client_by_id(session, id)
|
||||
|
||||
return await get_client_by_name(session, id_or_name)
|
||||
@ -0,0 +1,82 @@
|
||||
"""Policies sub-api router factory."""
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlmodel import Session, select
|
||||
from typing import Annotated
|
||||
|
||||
from sshecret_backend.models import Client, ClientAccessPolicy
|
||||
from sshecret_backend.view_models import (
|
||||
ClientPolicyView,
|
||||
ClientPolicyUpdate,
|
||||
)
|
||||
from sshecret_backend.types import DBSessionDep
|
||||
from sshecret_backend import audit
|
||||
from .common import get_client_by_id_or_name
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
"""Construct clients sub-api."""
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/clients/{name}/policies/")
|
||||
async def get_client_policies(
|
||||
name: str, session: Annotated[Session, Depends(get_db_session)]
|
||||
) -> ClientPolicyView:
|
||||
"""Get client policies."""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
|
||||
return ClientPolicyView.from_client(client)
|
||||
|
||||
@router.put("/clients/{name}/policies/")
|
||||
async def update_client_policies(
|
||||
request: Request,
|
||||
name: str,
|
||||
policy_update: ClientPolicyUpdate,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
) -> ClientPolicyView:
|
||||
"""Update client policies.
|
||||
|
||||
This is also how you delete policies.
|
||||
"""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
# Remove old policies.
|
||||
policies = session.exec(
|
||||
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
|
||||
).all()
|
||||
deleted_policies: list[ClientAccessPolicy] = []
|
||||
added_policies: list[ClientAccessPolicy] = []
|
||||
for policy in policies:
|
||||
session.delete(policy)
|
||||
deleted_policies.append(policy)
|
||||
|
||||
for source in policy_update.sources:
|
||||
LOG.debug("Source %r", source)
|
||||
policy = ClientAccessPolicy(source=str(source), client_id=client.id)
|
||||
session.add(policy)
|
||||
added_policies.append(policy)
|
||||
|
||||
session.commit()
|
||||
session.refresh(client)
|
||||
for policy in deleted_policies:
|
||||
audit.audit_remove_policy(session, request, client, policy)
|
||||
|
||||
for policy in added_policies:
|
||||
audit.audit_update_policy(session, request, client, policy)
|
||||
|
||||
return ClientPolicyView.from_client(client)
|
||||
|
||||
return router
|
||||
232
packages/sshecret-backend/src/sshecret_backend/api/secrets.py
Normal file
232
packages/sshecret-backend/src/sshecret_backend/api/secrets.py
Normal file
@ -0,0 +1,232 @@
|
||||
"""Secrets sub-api factory."""
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlmodel import Session, select
|
||||
from typing import Annotated
|
||||
|
||||
from sshecret_backend.models import Client, ClientSecret
|
||||
from sshecret_backend.view_models import (
|
||||
ClientReference,
|
||||
ClientSecretDetailList,
|
||||
ClientSecretList,
|
||||
ClientSecretPublic,
|
||||
BodyValue,
|
||||
ClientSecretResponse,
|
||||
)
|
||||
from sshecret_backend import audit
|
||||
from sshecret_backend.types import DBSessionDep
|
||||
from .common import get_client_by_id_or_name
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def lookup_client_secret(
|
||||
session: Session, client: Client, name: str
|
||||
) -> ClientSecret | None:
|
||||
"""Look up a secret for a client."""
|
||||
statement = (
|
||||
select(ClientSecret)
|
||||
.where(ClientSecret.client_id == client.id)
|
||||
.where(ClientSecret.name == name)
|
||||
)
|
||||
results = session.exec(statement)
|
||||
return results.first()
|
||||
|
||||
|
||||
def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
"""Construct clients sub-api."""
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/clients/{name}/secrets/")
|
||||
async def add_secret_to_client(
|
||||
request: Request,
|
||||
name: str,
|
||||
client_secret: ClientSecretPublic,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
) -> None:
|
||||
"""Add secret to a client."""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
|
||||
existing_secret = await lookup_client_secret(
|
||||
session, client, client_secret.name
|
||||
)
|
||||
if existing_secret:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot add a secret. A different secret with the same name already exists.",
|
||||
)
|
||||
db_secret = ClientSecret(
|
||||
name=client_secret.name, client_id=client.id, secret=client_secret.secret
|
||||
)
|
||||
session.add(db_secret)
|
||||
session.commit()
|
||||
session.refresh(db_secret)
|
||||
audit.audit_create_secret(session, request, client, db_secret)
|
||||
|
||||
@router.put("/clients/{name}/secrets/{secret_name}")
|
||||
async def update_client_secret(
|
||||
request: Request,
|
||||
name: str,
|
||||
secret_name: str,
|
||||
secret_data: BodyValue,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
) -> ClientSecretResponse:
|
||||
"""Update a client secret.
|
||||
|
||||
This can also be used for destructive creates.
|
||||
"""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
|
||||
existing_secret = await lookup_client_secret(session, client, secret_name)
|
||||
if existing_secret:
|
||||
existing_secret.secret = secret_data.value
|
||||
|
||||
session.add(existing_secret)
|
||||
session.commit()
|
||||
session.refresh(existing_secret)
|
||||
audit.audit_update_secret(session, request, client, existing_secret)
|
||||
return ClientSecretResponse.from_client_secret(existing_secret)
|
||||
|
||||
db_secret = ClientSecret(
|
||||
name=secret_name,
|
||||
client_id=client.id,
|
||||
secret=secret_data.value,
|
||||
)
|
||||
session.add(db_secret)
|
||||
session.commit()
|
||||
session.refresh(db_secret)
|
||||
audit.audit_create_secret(session, request, client, db_secret)
|
||||
return ClientSecretResponse.from_client_secret(db_secret)
|
||||
|
||||
@router.get("/clients/{name}/secrets/{secret_name}")
|
||||
async def request_client_secret(
|
||||
request: Request,
|
||||
name: str,
|
||||
secret_name: str,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
) -> ClientSecretResponse:
|
||||
"""Get a client secret."""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
|
||||
secret = await lookup_client_secret(session, client, secret_name)
|
||||
if not secret:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a secret with the given name."
|
||||
)
|
||||
|
||||
response_model = ClientSecretResponse.from_client_secret(secret)
|
||||
audit.audit_access_secret(session, request, client, secret)
|
||||
return response_model
|
||||
|
||||
@router.delete("/clients/{name}/secrets/{secret_name}")
|
||||
async def delete_client_secret(
|
||||
request: Request,
|
||||
name: str,
|
||||
secret_name: str,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
) -> None:
|
||||
"""Delete a secret."""
|
||||
client = await get_client_by_id_or_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
|
||||
secret = await lookup_client_secret(session, client, secret_name)
|
||||
if not secret:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a secret with the given name."
|
||||
)
|
||||
|
||||
session.delete(secret)
|
||||
session.commit()
|
||||
audit.audit_delete_secret(session, request, client, secret)
|
||||
|
||||
@router.get("/secrets/")
|
||||
async def get_secret_map(
|
||||
request: Request, session: Annotated[Session, Depends(get_db_session)]
|
||||
) -> list[ClientSecretList]:
|
||||
"""Get a list of all secrets and which clients have them."""
|
||||
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
|
||||
for client_secret in session.exec(select(ClientSecret)).all():
|
||||
if not client_secret.client:
|
||||
if client_secret.name not in client_secret_map:
|
||||
client_secret_map[client_secret.name] = []
|
||||
continue
|
||||
client_secret_map[client_secret.name].append(client_secret.client.name)
|
||||
audit.audit_client_secret_list(session, request)
|
||||
return [
|
||||
ClientSecretList(name=secret_name, clients=clients)
|
||||
for secret_name, clients in client_secret_map.items()
|
||||
]
|
||||
@router.get("/secrets/detailed/")
|
||||
async def get_detailed_secret_map(
|
||||
request: Request, session: Annotated[Session, Depends(get_db_session)]
|
||||
) -> list[ClientSecretDetailList]:
|
||||
"""Get a list of all secrets and which clients have them."""
|
||||
client_secrets: dict[str, ClientSecretDetailList] = {}
|
||||
for client_secret in session.exec(select(ClientSecret)).all():
|
||||
|
||||
if client_secret.name not in client_secrets:
|
||||
client_secrets[client_secret.name] = ClientSecretDetailList(name=client_secret.name)
|
||||
client_secrets[client_secret.name].ids.append(str(client_secret.id))
|
||||
if not client_secret.client:
|
||||
continue
|
||||
client_secrets[client_secret.name].clients.append(ClientReference(id=str(client_secret.client.id), name=client_secret.client.name))
|
||||
audit.audit_client_secret_list(session, request)
|
||||
return list(client_secrets.values())
|
||||
|
||||
|
||||
@router.get("/secrets/{name}")
|
||||
async def get_secret_clients(
|
||||
request: Request,
|
||||
name: str,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
) -> ClientSecretList:
|
||||
"""Get a list of which clients has a named secret."""
|
||||
clients: list[str] = []
|
||||
for client_secret in session.exec(
|
||||
select(ClientSecret).where(ClientSecret.name == name)
|
||||
).all():
|
||||
if not client_secret.client:
|
||||
continue
|
||||
clients.append(client_secret.client.name)
|
||||
|
||||
return ClientSecretList(name=name, clients=clients)
|
||||
|
||||
@router.get("/secrets/{name}/detailed")
|
||||
async def get_secret_clients_detailed(
|
||||
request: Request,
|
||||
name: str,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
) -> ClientSecretDetailList:
|
||||
"""Get a list of which clients has a named secret."""
|
||||
detail_list = ClientSecretDetailList(name=name)
|
||||
for client_secret in session.exec(
|
||||
select(ClientSecret).where(ClientSecret.name == name)
|
||||
).all():
|
||||
if not client_secret.client:
|
||||
continue
|
||||
detail_list.ids.append(str(client_secret.id))
|
||||
detail_list.clients.append(ClientReference(id=str(client_secret.client.id), name=client_secret.client.name))
|
||||
|
||||
return detail_list
|
||||
|
||||
return router
|
||||
@ -1,436 +1,65 @@
|
||||
"""FastAPI api.
|
||||
|
||||
TODO: We may want to allow a consumer to generate audit log entries manually.
|
||||
|
||||
"""
|
||||
"""FastAPI api."""
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Annotated
|
||||
|
||||
import bcrypt
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
FastAPI,
|
||||
Header,
|
||||
HTTPException,
|
||||
Query,
|
||||
Request,
|
||||
status,
|
||||
)
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy import Engine
|
||||
|
||||
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from . import audit
|
||||
from .db import get_engine
|
||||
from .models import (
|
||||
APIClient,
|
||||
AuditLog,
|
||||
Client,
|
||||
ClientAccessPolicy,
|
||||
ClientSecret,
|
||||
init_db,
|
||||
)
|
||||
from .settings import get_settings
|
||||
from .view_models import (
|
||||
BodyValue,
|
||||
ClientCreate,
|
||||
ClientSecretPublic,
|
||||
ClientSecretResponse,
|
||||
ClientUpdate,
|
||||
ClientView,
|
||||
ClientPolicyView,
|
||||
ClientPolicyUpdate,
|
||||
)
|
||||
|
||||
settings = get_settings()
|
||||
engine = get_engine(settings.db_file)
|
||||
from .models import init_db
|
||||
from .backend_api import get_backend_api
|
||||
from .db import setup_database
|
||||
|
||||
from .settings import BackendSettings
|
||||
from .types import DBSessionDep
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
API_VERSION = "v1"
|
||||
|
||||
def init_backend_app(engine: Engine, get_db_session: DBSessionDep) -> FastAPI:
|
||||
"""Initialize backend app."""
|
||||
|
||||
def verify_token(token: str, stored_hash: str) -> bool:
|
||||
"""Verify token."""
|
||||
token_bytes = token.encode("utf-8")
|
||||
stored_bytes = stored_hash.encode("utf-8")
|
||||
return bcrypt.checkpw(token_bytes, stored_bytes)
|
||||
@asynccontextmanager
|
||||
async def lifespan(_app: FastAPI):
|
||||
"""Create database before starting the server."""
|
||||
LOG.debug("Running lifespan")
|
||||
init_db(engine)
|
||||
yield
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.include_router(get_backend_api(get_db_session))
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_app: FastAPI):
|
||||
"""Create database before starting the server."""
|
||||
init_db(engine)
|
||||
yield
|
||||
|
||||
|
||||
async def get_session():
|
||||
"""Get the session."""
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def validate_token(
|
||||
x_api_token: Annotated[str, Header()],
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> str:
|
||||
"""Validate token."""
|
||||
LOG.debug("Validating token %s", x_api_token)
|
||||
statement = select(APIClient)
|
||||
results = session.exec(statement)
|
||||
valid = False
|
||||
for result in results:
|
||||
if verify_token(x_api_token, result.token):
|
||||
valid = True
|
||||
LOG.debug("Token is valid")
|
||||
break
|
||||
|
||||
if not valid:
|
||||
LOG.debug("Token is not valid.")
|
||||
raise HTTPException(status_code=401, detail="unauthorized. invalid api token.")
|
||||
return x_api_token
|
||||
|
||||
|
||||
async def get_client_by_name(session: Session, name: str) -> Client | None:
|
||||
"""Get client by name."""
|
||||
client_filter = select(Client).where(Client.name == name)
|
||||
client_results = session.exec(client_filter)
|
||||
return client_results.first()
|
||||
|
||||
|
||||
async def lookup_client_secret(
|
||||
session: Session, client: Client, name: str
|
||||
) -> ClientSecret | None:
|
||||
"""Look up a secret for a client."""
|
||||
statement = (
|
||||
select(ClientSecret)
|
||||
.where(ClientSecret.client_id == client.id)
|
||||
.where(ClientSecret.name == name)
|
||||
)
|
||||
results = session.exec(statement)
|
||||
return results.first()
|
||||
|
||||
|
||||
LOG.info("Initializing app.")
|
||||
backend_api = APIRouter(
|
||||
prefix=f"/api/{API_VERSION}",
|
||||
lifespan=lifespan,
|
||||
dependencies=[Depends(validate_token)],
|
||||
)
|
||||
|
||||
|
||||
@backend_api.get("/clients/")
|
||||
async def get_clients(
|
||||
session: Annotated[Session, Depends(get_session)]
|
||||
) -> list[ClientView]:
|
||||
"""Get clients."""
|
||||
statement = select(Client)
|
||||
results = session.exec(statement)
|
||||
clients = list(results)
|
||||
return ClientView.from_client_list(clients)
|
||||
|
||||
|
||||
@backend_api.get("/clients/{name}")
|
||||
async def get_client(
|
||||
request: Request, name: str, session: Annotated[Session, Depends(get_session)]
|
||||
) -> ClientView:
|
||||
"""Fetch a client."""
|
||||
statement = select(Client).where(Client.name == name)
|
||||
results = session.exec(statement)
|
||||
client = results.first()
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
audit.audit_access_secrets(session, request, client)
|
||||
return ClientView.from_client(client)
|
||||
|
||||
|
||||
@backend_api.delete("/clients/{name}")
|
||||
async def delete_client(
|
||||
request: Request, name: str, session: Annotated[Session, Depends(get_session)]
|
||||
) -> None:
|
||||
"""Delete a client."""
|
||||
statement = select(Client).where(Client.name == name)
|
||||
results = session.exec(statement)
|
||||
client = results.first()
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(
|
||||
request: Request, exc: RequestValidationError
|
||||
):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
|
||||
)
|
||||
|
||||
session.delete(client)
|
||||
session.commit()
|
||||
audit.audit_delete_client(session, request, client)
|
||||
|
||||
|
||||
@backend_api.get("/clients/{name}/policies/")
|
||||
async def get_client_policies(
|
||||
name: str, session: Annotated[Session, Depends(get_session)]
|
||||
) -> ClientPolicyView:
|
||||
"""Get client policies."""
|
||||
statement = select(Client).where(Client.name == name)
|
||||
results = session.exec(statement)
|
||||
client = results.first()
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
@app.get("/health")
|
||||
async def get_health() -> JSONResponse:
|
||||
"""Provide simple health check."""
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK, content=jsonable_encoder({"status": "LIVE"})
|
||||
)
|
||||
|
||||
return ClientPolicyView.from_client(client)
|
||||
return app
|
||||
|
||||
|
||||
@backend_api.put("/clients/{name}/policies/")
|
||||
async def update_client_policies(
|
||||
request: Request,
|
||||
name: str,
|
||||
policy_update: ClientPolicyUpdate,
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> ClientPolicyView:
|
||||
"""Update client policies.
|
||||
def create_backend_app(settings: BackendSettings) -> FastAPI:
|
||||
"""Create the backend app."""
|
||||
|
||||
This is also how you delete policies.
|
||||
"""
|
||||
statement = select(Client).where(Client.name == name)
|
||||
results = session.exec(statement)
|
||||
client = results.first()
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
# Remove old policies.
|
||||
policies = session.exec(
|
||||
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
|
||||
).all()
|
||||
deleted_policies: list[ClientAccessPolicy] = []
|
||||
added_policies: list[ClientAccessPolicy] = []
|
||||
for policy in policies:
|
||||
session.delete(policy)
|
||||
deleted_policies.append(policy)
|
||||
engine, get_db_session = setup_database(settings.db_url)
|
||||
|
||||
for source in policy_update.sources:
|
||||
LOG.debug("Source %r", source)
|
||||
policy = ClientAccessPolicy(source=str(source), client_id=client.id)
|
||||
session.add(policy)
|
||||
added_policies.append(policy)
|
||||
|
||||
session.commit()
|
||||
session.refresh(client)
|
||||
for policy in deleted_policies:
|
||||
audit.audit_remove_policy(session, request, client, policy)
|
||||
|
||||
for policy in added_policies:
|
||||
audit.audit_update_policy(session, request, client, policy)
|
||||
|
||||
return ClientPolicyView.from_client(client)
|
||||
|
||||
|
||||
@backend_api.post("/clients/{name}/public-key")
|
||||
async def update_client_public_key(
|
||||
request: Request,
|
||||
name: str,
|
||||
client_update: ClientUpdate,
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> ClientView:
|
||||
"""Change the public key of a client.
|
||||
|
||||
This invalidates all secrets.
|
||||
"""
|
||||
statement = select(Client).where(Client.name == name)
|
||||
results = session.exec(statement)
|
||||
client = results.first()
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
client.public_key = client_update.public_key
|
||||
for secret in session.exec(
|
||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||
).all():
|
||||
LOG.debug("Invalidated secret %s", secret.id)
|
||||
secret.invalidated = True
|
||||
secret.client_id = None
|
||||
secret.client = None
|
||||
|
||||
session.add(client)
|
||||
session.refresh(client)
|
||||
session.commit()
|
||||
audit.audit_invalidate_secrets(session, request, client)
|
||||
|
||||
return ClientView.from_client(client)
|
||||
|
||||
|
||||
@backend_api.post("/clients/")
|
||||
async def create_client(
|
||||
request: Request,
|
||||
client: ClientCreate,
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> ClientView:
|
||||
"""Create client."""
|
||||
existing = await get_client_by_name(session, client.name)
|
||||
if existing:
|
||||
raise HTTPException(400, detail="Error: Already a client with that name.")
|
||||
|
||||
db_client = client.to_client()
|
||||
session.add(db_client)
|
||||
session.commit()
|
||||
session.refresh(db_client)
|
||||
audit.audit_create_client(session, request, db_client)
|
||||
return ClientView.from_client(db_client)
|
||||
|
||||
|
||||
@backend_api.post("/clients/{name}/secrets/")
|
||||
async def add_secret_to_client(
|
||||
request: Request,
|
||||
name: str,
|
||||
client_secret: ClientSecretPublic,
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> None:
|
||||
"""Add secret to a client."""
|
||||
client = await get_client_by_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
|
||||
existing_secret = await lookup_client_secret(session, client, client_secret.name)
|
||||
if existing_secret:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot add a secret. A different secret with the same name already exists.",
|
||||
)
|
||||
db_secret = ClientSecret(
|
||||
name=client_secret.name, client_id=client.id, secret=client_secret.secret
|
||||
)
|
||||
session.add(db_secret)
|
||||
session.commit()
|
||||
session.refresh(db_secret)
|
||||
audit.audit_create_secret(session, request, client, db_secret)
|
||||
|
||||
|
||||
@backend_api.put("/clients/{name}/secrets/{secret_name}")
|
||||
async def update_client_secret(
|
||||
request: Request,
|
||||
name: str,
|
||||
secret_name: str,
|
||||
secret_data: BodyValue,
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> ClientSecretResponse:
|
||||
"""Update a client secret.
|
||||
|
||||
This can also be used for destructive creates.
|
||||
"""
|
||||
client = await get_client_by_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
|
||||
existing_secret = await lookup_client_secret(session, client, secret_name)
|
||||
if existing_secret:
|
||||
existing_secret.secret = secret_data.value
|
||||
session.add(existing_secret)
|
||||
session.commit()
|
||||
session.refresh(existing_secret)
|
||||
audit.audit_update_secret(session, request, client, existing_secret)
|
||||
return ClientSecretResponse.from_client_secret(existing_secret)
|
||||
|
||||
db_secret = ClientSecret(
|
||||
name=secret_name,
|
||||
client_id=client.id,
|
||||
secret=secret_data.value,
|
||||
)
|
||||
session.add(db_secret)
|
||||
session.commit()
|
||||
session.refresh(db_secret)
|
||||
audit.audit_create_secret(session, request, client, db_secret)
|
||||
return ClientSecretResponse.from_client_secret(db_secret)
|
||||
|
||||
|
||||
@backend_api.get("/clients/{name}/secrets/{secret_name}")
|
||||
async def request_client_secret(
|
||||
request: Request,
|
||||
name: str,
|
||||
secret_name: str,
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> ClientSecretResponse:
|
||||
"""Get a client secret."""
|
||||
client = await get_client_by_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
|
||||
secret = await lookup_client_secret(session, client, secret_name)
|
||||
if not secret:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a secret with the given name."
|
||||
)
|
||||
|
||||
response_model = ClientSecretResponse.from_client_secret(secret)
|
||||
audit.audit_access_secret(session, request, client, secret)
|
||||
return response_model
|
||||
|
||||
@backend_api.delete("/clients/{name}/secrets/{secret_name}")
|
||||
async def delete_client_secret(
|
||||
request: Request,
|
||||
name: str,
|
||||
secret_name: str,
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> None:
|
||||
"""Delete a secret."""
|
||||
client = await get_client_by_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
|
||||
secret = await lookup_client_secret(session, client, secret_name)
|
||||
if not secret:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a secret with the given name."
|
||||
)
|
||||
|
||||
session.delete(secret)
|
||||
session.commit()
|
||||
audit.audit_delete_secret(session, request, client, secret)
|
||||
|
||||
|
||||
@backend_api.get("/audit/", response_model=list[AuditLog])
|
||||
async def get_audit_logs(
|
||||
request: Request,
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
offset: Annotated[int, Query()] = 0,
|
||||
limit: Annotated[int, Query(le=100)] = 100,
|
||||
filter_client: Annotated[str | None, Query()] = None,
|
||||
) -> Sequence[AuditLog]:
|
||||
"""Get audit logs."""
|
||||
audit.audit_access_audit_log(session, request)
|
||||
statement = select(AuditLog).offset(offset).limit(limit)
|
||||
if filter_client:
|
||||
statement = statement.where(AuditLog.client_name == filter_client)
|
||||
|
||||
results = session.exec(statement).all()
|
||||
return results
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
|
||||
)
|
||||
|
||||
|
||||
app.include_router(backend_api)
|
||||
return init_backend_app(engine, get_db_session)
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
from collections.abc import Sequence
|
||||
from fastapi import Request
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy
|
||||
|
||||
|
||||
@ -21,6 +22,7 @@ def _write_audit_log(
|
||||
"""Write the audit log."""
|
||||
origin = _get_origin(request)
|
||||
entry.origin = origin
|
||||
entry.subsystem = "backend"
|
||||
session.add(entry)
|
||||
if commit:
|
||||
session.commit()
|
||||
@ -109,6 +111,23 @@ def audit_update_policy(
|
||||
_write_audit_log(session, request, entry, commit)
|
||||
|
||||
|
||||
def audit_update_client(
|
||||
session: Session,
|
||||
request: Request,
|
||||
client: Client,
|
||||
commit: bool = True,
|
||||
) -> None:
|
||||
"""Audit an update secret event."""
|
||||
entry = AuditLog(
|
||||
operation="UPDATE",
|
||||
object="Client",
|
||||
client_id=client.id,
|
||||
client_name=client.name,
|
||||
message="Client updated",
|
||||
)
|
||||
_write_audit_log(session, request, entry, commit)
|
||||
|
||||
|
||||
def audit_update_secret(
|
||||
session: Session,
|
||||
request: Request,
|
||||
@ -219,3 +238,15 @@ def audit_access_audit_log(
|
||||
object="AuditLog",
|
||||
)
|
||||
_write_audit_log(session, request, entry, commit)
|
||||
|
||||
|
||||
def audit_client_secret_list(
|
||||
session: Session, request: Request, commit: bool = True
|
||||
) -> None:
|
||||
"""Audit a list of all secrets."""
|
||||
entry = AuditLog(
|
||||
operation="ACCESS",
|
||||
message="All secret names and their clients was viewed",
|
||||
)
|
||||
_write_audit_log(session, request, entry, commit)
|
||||
|
||||
|
||||
@ -0,0 +1,67 @@
|
||||
"""Backend API."""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
import bcrypt
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from .api import get_audit_api, get_clients_api, get_policy_api, get_secrets_api
|
||||
from .models import (
|
||||
APIClient,
|
||||
)
|
||||
from .types import DBSessionDep
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
API_VERSION = "v1"
|
||||
|
||||
|
||||
def verify_token(token: str, stored_hash: str) -> bool:
|
||||
"""Verify token."""
|
||||
token_bytes = token.encode("utf-8")
|
||||
stored_bytes = stored_hash.encode("utf-8")
|
||||
return bcrypt.checkpw(token_bytes, stored_bytes)
|
||||
|
||||
|
||||
def get_backend_api(
|
||||
get_db_session: DBSessionDep,
|
||||
) -> APIRouter:
|
||||
"""Construct backend API."""
|
||||
|
||||
async def validate_token(
|
||||
x_api_token: Annotated[str, Header()],
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
) -> str:
|
||||
"""Validate token."""
|
||||
LOG.debug("Validating token %s", x_api_token)
|
||||
statement = select(APIClient)
|
||||
results = session.exec(statement)
|
||||
valid = False
|
||||
for result in results:
|
||||
if verify_token(x_api_token, result.token):
|
||||
valid = True
|
||||
LOG.debug("Token is valid")
|
||||
break
|
||||
|
||||
if not valid:
|
||||
LOG.debug("Token is not valid.")
|
||||
raise HTTPException(
|
||||
status_code=401, detail="unauthorized. invalid api token."
|
||||
)
|
||||
return x_api_token
|
||||
|
||||
LOG.info("Initializing app.")
|
||||
|
||||
backend_api = APIRouter(
|
||||
prefix=f"/api/{API_VERSION}",
|
||||
dependencies=[Depends(validate_token)],
|
||||
)
|
||||
|
||||
backend_api.include_router(get_audit_api(get_db_session))
|
||||
backend_api.include_router(get_clients_api(get_db_session))
|
||||
backend_api.include_router(get_policy_api(get_db_session))
|
||||
backend_api.include_router(get_secrets_api(get_db_session))
|
||||
|
||||
return backend_api
|
||||
@ -1,11 +1,18 @@
|
||||
"""CLI and main entry point."""
|
||||
|
||||
import code
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
from dotenv import load_dotenv
|
||||
import click
|
||||
from sqlmodel import Session, create_engine, select
|
||||
import uvicorn
|
||||
|
||||
from .db import generate_api_token
|
||||
from .db import create_api_token
|
||||
|
||||
from .models import Client, ClientSecret, ClientAccessPolicy, AuditLog, APIClient
|
||||
from .settings import BackendSettings
|
||||
|
||||
DEFAULT_LISTEN = "127.0.0.1"
|
||||
DEFAULT_PORT = 8022
|
||||
@ -14,18 +21,59 @@ WORKDIR = Path(os.getcwd())
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.option("--database", help="Path to the sqlite database file.")
|
||||
def cli(database: str) -> None:
|
||||
@click.pass_context
|
||||
def cli(ctx: click.Context, database: str) -> None:
|
||||
"""CLI group."""
|
||||
if database:
|
||||
# Hopefully it's enough to set the environment variable as so.
|
||||
os.environ["SSHECRET_DB_FILE"] = str(Path(database).absolute())
|
||||
settings = BackendSettings(db_url=f"sqlite:///{Path(database).absolute()}")
|
||||
else:
|
||||
settings = BackendSettings()
|
||||
|
||||
ctx.obj = settings
|
||||
|
||||
|
||||
@cli.command("generate-token")
|
||||
def cli_generate_token() -> None:
|
||||
@click.pass_context
|
||||
def cli_generate_token(ctx: click.Context) -> None:
|
||||
"""Generate a token."""
|
||||
token = generate_api_token()
|
||||
settings = cast(BackendSettings, ctx.obj)
|
||||
engine = create_engine(settings.db_url)
|
||||
with Session(engine) as session:
|
||||
token = create_api_token(session, True)
|
||||
click.echo("Generated api token:")
|
||||
click.echo(token)
|
||||
|
||||
@cli.command("run")
|
||||
@click.option("--host", default="127.0.0.1")
|
||||
@click.option("--port", default=8022, type=click.INT)
|
||||
@click.option("--dev", is_flag=True)
|
||||
@click.option("--workers", type=click.INT)
|
||||
def cli_run(host: str, port: int, dev: bool, workers: int | None) -> None:
|
||||
"""Run the server."""
|
||||
uvicorn.run("sshecret_backend.main:app", host=host, port=port, reload=dev, workers=workers)
|
||||
|
||||
@cli.command("repl")
|
||||
@click.pass_context
|
||||
def cli_repl(ctx: click.Context) -> None:
|
||||
"""Run an interactive console."""
|
||||
settings = cast(BackendSettings, ctx.obj)
|
||||
engine = create_engine(settings.db_url)
|
||||
|
||||
with Session(engine) as session:
|
||||
locals = {
|
||||
"session": session,
|
||||
"select": select,
|
||||
"Client": Client,
|
||||
"ClientSecret": ClientSecret,
|
||||
"ClientAccessPolicy": ClientAccessPolicy,
|
||||
"APIClient": APIClient,
|
||||
"AuditLog": AuditLog,
|
||||
}
|
||||
|
||||
console = code.InteractiveConsole(locals=locals, local_exit=True)
|
||||
banner = "Sshecret-backend REPL.\nUse 'session' to interact with the database."
|
||||
console.interact(banner=banner, exitmsg="Bye!")
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from collections.abc import Generator, Callable
|
||||
from pathlib import Path
|
||||
from sqlalchemy import Engine
|
||||
from sqlmodel import Session, create_engine, text
|
||||
@ -9,14 +10,30 @@ import bcrypt
|
||||
|
||||
from sqlalchemy.engine import URL
|
||||
|
||||
from .models import APIClient, init_db
|
||||
|
||||
from .settings import get_settings
|
||||
from .models import APIClient
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_database(
|
||||
db_url: URL | str,
|
||||
) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]:
|
||||
"""Setup database."""
|
||||
|
||||
engine = create_engine(db_url, echo=False)
|
||||
with engine.connect() as connection:
|
||||
connection.execute(text("PRAGMA foreign_keys=ON")) # for SQLite only
|
||||
|
||||
def get_db_session() -> Generator[Session, None, None]:
|
||||
"""Get DB Session."""
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
|
||||
return engine, get_db_session
|
||||
|
||||
|
||||
def get_engine(filename: Path, echo: bool = False) -> Engine:
|
||||
"""Initialize the engine."""
|
||||
url = URL.create(drivername="sqlite", database=str(filename.absolute()))
|
||||
@ -27,20 +44,6 @@ def get_engine(filename: Path, echo: bool = False) -> Engine:
|
||||
return engine
|
||||
|
||||
|
||||
def create_db_and_tables(filename: Path, echo: bool = False) -> bool:
|
||||
"""Create database and tables.
|
||||
|
||||
Returns True if the database was created.
|
||||
"""
|
||||
created = False
|
||||
if not filename.exists():
|
||||
created = True
|
||||
engine = get_engine(filename, echo)
|
||||
|
||||
init_db(engine)
|
||||
return created
|
||||
|
||||
|
||||
def create_api_token(session: Session, read_write: bool) -> str:
|
||||
"""Create API token."""
|
||||
token = secrets.token_urlsafe(32)
|
||||
@ -54,14 +57,3 @@ def create_api_token(session: Session, read_write: bool) -> str:
|
||||
session.commit()
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def generate_api_token() -> str:
|
||||
"""Generate API token."""
|
||||
settings = get_settings()
|
||||
engine = get_engine(settings.db_file)
|
||||
init_db(engine)
|
||||
with Session(engine) as session:
|
||||
token = create_api_token(session, True)
|
||||
|
||||
return token
|
||||
|
||||
7
packages/sshecret-backend/src/sshecret_backend/main.py
Normal file
7
packages/sshecret-backend/src/sshecret_backend/main.py
Normal file
@ -0,0 +1,7 @@
|
||||
"""Main script entrypoint."""
|
||||
|
||||
from .settings import BackendSettings
|
||||
|
||||
from .app import create_backend_app
|
||||
|
||||
app = create_backend_app(BackendSettings())
|
||||
@ -7,17 +7,21 @@ This might require some changes to these schemas.
|
||||
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
import sqlalchemy as sa
|
||||
from sqlmodel import Field, Relationship, SQLModel
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Client(SQLModel, table=True):
|
||||
"""Client model."""
|
||||
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
name: str = Field(unique=True)
|
||||
description: str | None = None
|
||||
public_key: str
|
||||
|
||||
created_at: datetime | None = Field(
|
||||
@ -61,11 +65,13 @@ class ClientAccessPolicy(SQLModel, table=True):
|
||||
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
|
||||
)
|
||||
|
||||
|
||||
class ClientSecret(SQLModel, table=True):
|
||||
"""A client secret."""
|
||||
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
name: str
|
||||
description: str | None = None
|
||||
client_id: uuid.UUID | None = Field(foreign_key="client.id", ondelete="CASCADE")
|
||||
client: Client | None = Relationship(back_populates="secrets")
|
||||
secret: str
|
||||
@ -92,6 +98,7 @@ class AuditLog(SQLModel, table=True):
|
||||
"""
|
||||
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
subsystem: str | None = None
|
||||
object: str | None = None
|
||||
object_id: str | None = None
|
||||
operation: str
|
||||
@ -107,6 +114,7 @@ class AuditLog(SQLModel, table=True):
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
|
||||
class APIClient(SQLModel, table=True):
|
||||
"""Stores API Keys."""
|
||||
|
||||
@ -120,6 +128,8 @@ class APIClient(SQLModel, table=True):
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
|
||||
def init_db(engine: sa.Engine) -> None:
|
||||
"""Create database."""
|
||||
LOG.info("Starting init_db")
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
||||
@ -1,15 +1,13 @@
|
||||
"""Settings management."""
|
||||
|
||||
from typing import override
|
||||
from pathlib import Path
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
from pydantic_settings import (
|
||||
BaseSettings,
|
||||
SettingsConfigDict,
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_DATABASE = "sshecret.db"
|
||||
DEFAULT_DATABASE = "sqlite:///sshecret.db"
|
||||
|
||||
|
||||
class BackendSettings(BaseSettings):
|
||||
@ -17,7 +15,7 @@ class BackendSettings(BaseSettings):
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".backend.env", env_prefix="sshecret_")
|
||||
|
||||
db_file: Path = Field(default=Path(DEFAULT_DATABASE).absolute())
|
||||
db_url: str = Field(default=DEFAULT_DATABASE)
|
||||
|
||||
|
||||
def get_settings() -> BackendSettings:
|
||||
|
||||
@ -1,13 +1,17 @@
|
||||
"""Test helpers."""
|
||||
|
||||
import logging
|
||||
from sqlmodel import Session
|
||||
from .db import get_engine, create_api_token
|
||||
from sshecret_backend.settings import BackendSettings
|
||||
from .models import init_db
|
||||
from .settings import get_settings
|
||||
from .db import create_api_token, setup_database
|
||||
|
||||
def create_test_token(session: Session) -> str:
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_test_token(settings: BackendSettings) -> str:
|
||||
"""Create test token."""
|
||||
settings = get_settings()
|
||||
engine = get_engine(settings.db_file)
|
||||
init_db(engine)
|
||||
return create_api_token(session, True)
|
||||
engine, _setupdb = setup_database(settings.db_url)
|
||||
with Session(engine) as session:
|
||||
init_db(engine)
|
||||
return create_api_token(session, True)
|
||||
|
||||
8
packages/sshecret-backend/src/sshecret_backend/types.py
Normal file
8
packages/sshecret-backend/src/sshecret_backend/types.py
Normal file
@ -0,0 +1,8 @@
|
||||
"""Common type definitions."""
|
||||
|
||||
from collections.abc import Callable, Generator
|
||||
|
||||
from sqlmodel import Session
|
||||
|
||||
|
||||
DBSessionDep = Callable[[], Generator[Session, None, None]]
|
||||
@ -1,24 +1,27 @@
|
||||
"""Models for API views."""
|
||||
|
||||
import ipaddress
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Self, override
|
||||
from typing import Annotated, Self, override
|
||||
|
||||
from sqlmodel import Field, SQLModel
|
||||
from pydantic import IPvAnyAddress, IPvAnyNetwork
|
||||
from . import models
|
||||
from pydantic import AfterValidator, IPvAnyAddress, IPvAnyNetwork
|
||||
|
||||
from sshecret.crypto import public_key_validator
|
||||
|
||||
from . import models
|
||||
|
||||
|
||||
class ClientView(SQLModel):
|
||||
"""View for a single client."""
|
||||
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
description: str | None = None
|
||||
public_key: str
|
||||
policies: list[str] = ["0.0.0.0/0", "::/0"]
|
||||
secrets: list[str] = Field(default_factory=list)
|
||||
created_at: datetime
|
||||
created_at: datetime | None
|
||||
updated_at: datetime | None = None
|
||||
|
||||
@classmethod
|
||||
@ -33,6 +36,7 @@ class ClientView(SQLModel):
|
||||
view = cls(
|
||||
id=client.id,
|
||||
name=client.name,
|
||||
description=client.description,
|
||||
public_key=client.public_key,
|
||||
created_at=client.created_at,
|
||||
updated_at=client.updated_at or None,
|
||||
@ -46,24 +50,34 @@ class ClientView(SQLModel):
|
||||
return view
|
||||
|
||||
|
||||
class ClientQueryResult(SQLModel):
|
||||
"""Result class for queries towards the client list."""
|
||||
|
||||
clients: list[ClientView] = Field(default_factory=list)
|
||||
total_results: int
|
||||
remaining_results: int
|
||||
|
||||
|
||||
class ClientCreate(SQLModel):
|
||||
"""Model to create a client."""
|
||||
|
||||
name: str
|
||||
public_key: str
|
||||
description: str | None = None
|
||||
public_key: Annotated[str, AfterValidator(public_key_validator)]
|
||||
|
||||
def to_client(self) -> models.Client:
|
||||
"""Instantiate a client."""
|
||||
public_key = self.public_key
|
||||
return models.Client(
|
||||
name=self.name, public_key=public_key
|
||||
name=self.name,
|
||||
public_key=self.public_key,
|
||||
description=self.description,
|
||||
)
|
||||
|
||||
|
||||
class ClientUpdate(SQLModel):
|
||||
"""Model to update the client public key."""
|
||||
|
||||
public_key: str
|
||||
public_key: Annotated[str, AfterValidator(public_key_validator)]
|
||||
|
||||
|
||||
class BodyValue(SQLModel):
|
||||
@ -77,6 +91,7 @@ class ClientSecretPublic(SQLModel):
|
||||
|
||||
name: str
|
||||
secret: str
|
||||
description: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_client_secret(cls, client_secret: models.ClientSecret) -> Self:
|
||||
@ -84,13 +99,14 @@ class ClientSecretPublic(SQLModel):
|
||||
return cls(
|
||||
name=client_secret.name,
|
||||
secret=client_secret.secret,
|
||||
description=client_secret.description,
|
||||
)
|
||||
|
||||
|
||||
class ClientSecretResponse(ClientSecretPublic):
|
||||
"""A secret view."""
|
||||
|
||||
created_at: datetime
|
||||
created_at: datetime | None
|
||||
updated_at: datetime | None = None
|
||||
|
||||
@override
|
||||
@ -123,3 +139,31 @@ class ClientPolicyUpdate(SQLModel):
|
||||
"""Model for updating policies."""
|
||||
|
||||
sources: list[IPvAnyAddress | IPvAnyNetwork]
|
||||
|
||||
|
||||
class ClientSecretList(SQLModel):
|
||||
"""Model for aggregating identically named secrets."""
|
||||
|
||||
name: str
|
||||
clients: list[str]
|
||||
|
||||
|
||||
class ClientReference(SQLModel):
|
||||
"""Reference to a client."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
class ClientSecretDetailList(SQLModel):
|
||||
"""A more detailed version of the ClientSecretList."""
|
||||
|
||||
name: str
|
||||
ids: list[str] = Field(default_factory=list)
|
||||
clients: list[ClientReference] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AuditInfo(SQLModel):
|
||||
"""Information about audit information."""
|
||||
|
||||
entries: int
|
||||
|
||||
Reference in New Issue
Block a user