Add sub-projects
This commit is contained in:
289
packages/sshecret-backend/src/sshecret_backend/app.py
Normal file
289
packages/sshecret-backend/src/sshecret_backend/app.py
Normal file
@ -0,0 +1,289 @@
|
||||
"""FastAPI api."""
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Annotated
|
||||
from collections.abc import Sequence
|
||||
|
||||
import bcrypt
|
||||
from fastapi import APIRouter, Depends, FastAPI, Header, HTTPException, Query, Request
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from . import audit
|
||||
from .db import get_engine
|
||||
from .models import APIClient, AuditLog, Client, ClientSecret, init_db
|
||||
from .settings import get_settings
|
||||
from .view_models import (
|
||||
BodyValue,
|
||||
ClientCreate,
|
||||
ClientListResponse,
|
||||
ClientSecretPublic,
|
||||
ClientSecretResponse,
|
||||
ClientUpdate,
|
||||
ClientView,
|
||||
)
|
||||
|
||||
settings = get_settings()
|
||||
engine = get_engine(settings.db_file)
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
API_VERSION = "v1"
|
||||
|
||||
|
||||
def verify_token(token: str, stored_hash: str) -> bool:
|
||||
"""Verify token."""
|
||||
token_bytes = token.encode("utf-8")
|
||||
stored_bytes = stored_hash.encode("utf-8")
|
||||
return bcrypt.checkpw(token_bytes, stored_bytes)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_app: FastAPI):
|
||||
"""Create database before starting the server."""
|
||||
init_db(engine)
|
||||
yield
|
||||
|
||||
|
||||
async def get_session():
|
||||
"""Get the session."""
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def validate_token(
|
||||
x_api_token: Annotated[str, Header()],
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> str:
|
||||
"""Validate token."""
|
||||
LOG.debug("Validating token %s", x_api_token)
|
||||
statement = select(APIClient)
|
||||
results = session.exec(statement)
|
||||
valid = False
|
||||
for result in results:
|
||||
if verify_token(x_api_token, result.token):
|
||||
valid = True
|
||||
LOG.debug("Token is valid")
|
||||
break
|
||||
|
||||
if not valid:
|
||||
LOG.debug("Token is not valid.")
|
||||
raise HTTPException(status_code=401, detail="unauthorized. invalid api token.")
|
||||
return x_api_token
|
||||
|
||||
|
||||
async def get_client_by_name(session: Session, name: str) -> Client | None:
|
||||
"""Get client by name."""
|
||||
client_filter = select(Client).where(Client.name == name)
|
||||
client_results = session.exec(client_filter)
|
||||
return client_results.first()
|
||||
|
||||
|
||||
async def lookup_client_secret(
|
||||
session: Session, client: Client, name: str
|
||||
) -> ClientSecret | None:
|
||||
"""Look up a secret for a client."""
|
||||
statement = (
|
||||
select(ClientSecret)
|
||||
.where(ClientSecret.client_id == client.id)
|
||||
.where(ClientSecret.name == name)
|
||||
)
|
||||
results = session.exec(statement)
|
||||
return results.first()
|
||||
|
||||
|
||||
LOG.info("Initializing app.")
|
||||
backend_api = APIRouter(
|
||||
prefix=f"/api/{API_VERSION}",
|
||||
lifespan=lifespan,
|
||||
dependencies=[Depends(validate_token)],
|
||||
)
|
||||
|
||||
|
||||
@backend_api.get("/clients/")
|
||||
async def get_clients(
|
||||
session: Annotated[Session, Depends(get_session)]
|
||||
) -> list[ClientListResponse]:
|
||||
"""Get clients."""
|
||||
statement = select(Client)
|
||||
results = session.exec(statement)
|
||||
clients = list(results)
|
||||
return ClientListResponse.from_clients(clients)
|
||||
|
||||
|
||||
@backend_api.get("/clients/{name}")
|
||||
async def get_client(
|
||||
request: Request, name: str, session: Annotated[Session, Depends(get_session)]
|
||||
) -> ClientView:
|
||||
"""Fetch a client."""
|
||||
statement = select(Client).where(Client.name == name)
|
||||
results = session.exec(statement)
|
||||
client = results.first()
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
audit.audit_access_secrets(session, request, client)
|
||||
return ClientView.from_client(client)
|
||||
|
||||
|
||||
@backend_api.post("/clients/{name}/update_fingerprint")
|
||||
async def update_client_fingerprint(
|
||||
request: Request,
|
||||
name: str,
|
||||
client_update: ClientUpdate,
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> ClientView:
|
||||
"""Update the client fingerprint.
|
||||
|
||||
This invalidates all secrets.
|
||||
"""
|
||||
statement = select(Client).where(Client.name == name)
|
||||
results = session.exec(statement)
|
||||
client = results.first()
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
client.fingerprint = client_update.fingerprint
|
||||
for secret in session.exec(
|
||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||
).all():
|
||||
LOG.debug("Invalidated secret %s", secret.id)
|
||||
secret.invalidated = True
|
||||
secret.client_id = None
|
||||
secret.client = None
|
||||
|
||||
session.add(client)
|
||||
session.refresh(client)
|
||||
session.commit()
|
||||
audit.audit_invalidate_secrets(session, request, client)
|
||||
|
||||
return ClientView.from_client(client)
|
||||
|
||||
|
||||
@backend_api.post("/clients/")
|
||||
async def create_client(
|
||||
request: Request,
|
||||
client: ClientCreate,
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> ClientView:
|
||||
"""Create client."""
|
||||
db_client = Client.model_validate(client)
|
||||
session.add(db_client)
|
||||
session.commit()
|
||||
session.refresh(db_client)
|
||||
audit.audit_create_client(session, request, db_client)
|
||||
return ClientView.from_client(db_client)
|
||||
|
||||
|
||||
@backend_api.post("/clients/{name}/secrets/")
|
||||
async def add_secret_to_client(
|
||||
request: Request,
|
||||
name: str,
|
||||
client_secret: ClientSecretPublic,
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> None:
|
||||
"""Add secret to a client."""
|
||||
client = await get_client_by_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
|
||||
existing_secret = await lookup_client_secret(session, client, client_secret.name)
|
||||
if existing_secret:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot add a secret. A different secret with the same name already exists.",
|
||||
)
|
||||
db_secret = ClientSecret(
|
||||
name=client_secret.name, client_id=client.id, secret=client_secret.secret
|
||||
)
|
||||
session.add(db_secret)
|
||||
session.commit()
|
||||
session.refresh(db_secret)
|
||||
audit.audit_create_secret(session, request, client, db_secret)
|
||||
|
||||
|
||||
@backend_api.put("/clients/{name}/secrets/{secret_name}")
|
||||
async def update_client_secret(
|
||||
request: Request,
|
||||
name: str,
|
||||
secret_name: str,
|
||||
secret_data: BodyValue,
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> ClientSecretResponse:
|
||||
"""Update a client secret.
|
||||
|
||||
This can also be used for destructive creates.
|
||||
"""
|
||||
client = await get_client_by_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
|
||||
existing_secret = await lookup_client_secret(session, client, secret_name)
|
||||
if existing_secret:
|
||||
existing_secret.secret = secret_data.value
|
||||
session.add(existing_secret)
|
||||
session.commit()
|
||||
session.refresh(existing_secret)
|
||||
audit.audit_update_secret(session, request, client, existing_secret)
|
||||
return ClientSecretResponse.from_client_secret(existing_secret)
|
||||
|
||||
db_secret = ClientSecret(
|
||||
name=secret_name,
|
||||
client_id=client.id,
|
||||
secret=secret_data.value,
|
||||
)
|
||||
session.add(db_secret)
|
||||
session.commit()
|
||||
session.refresh(db_secret)
|
||||
audit.audit_create_secret(session, request, client, db_secret)
|
||||
return ClientSecretResponse.from_client_secret(db_secret)
|
||||
|
||||
|
||||
@backend_api.get("/clients/{name}/secrets/{secret_name}")
|
||||
async def request_client_secret(
|
||||
request: Request,
|
||||
name: str,
|
||||
secret_name: str,
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> ClientSecretResponse:
|
||||
"""Get a client secret."""
|
||||
client = await get_client_by_name(session, name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
|
||||
secret = await lookup_client_secret(session, client, secret_name)
|
||||
if not secret:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find a secret with the given name."
|
||||
)
|
||||
|
||||
response_model = ClientSecretResponse.from_client_secret(secret)
|
||||
audit.audit_access_secret(session, request, client, secret)
|
||||
return response_model
|
||||
|
||||
|
||||
@backend_api.get("/audit/", response_model=list[AuditLog])
|
||||
async def get_audit_logs(
|
||||
request: Request,
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
offset: Annotated[int, Query()] = 0,
|
||||
limit: Annotated[int, Query(le=100)] = 100,
|
||||
filter_client: Annotated[str | None, Query()] = None,
|
||||
) -> Sequence[AuditLog]:
|
||||
"""Get audit logs."""
|
||||
audit.audit_access_audit_log(session, request)
|
||||
statement = select(AuditLog).offset(offset).limit(limit)
|
||||
if filter_client:
|
||||
statement = statement.where(AuditLog.client_name == filter_client)
|
||||
|
||||
results = session.exec(statement).all()
|
||||
return results
|
||||
Reference in New Issue
Block a user