Complete backend
This commit is contained in:
@ -14,6 +14,9 @@ dependencies = [
|
|||||||
"sqlmodel>=0.0.24",
|
"sqlmodel>=0.0.24",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
sshecret-backend = "sshecret_backend.cli:cli"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["hatchling"]
|
requires = ["hatchling"]
|
||||||
build-backend = "hatchling.build"
|
build-backend = "hatchling.build"
|
||||||
|
|||||||
@ -1 +1,5 @@
|
|||||||
"""Sshecret backend."""
|
"""Sshecret backend."""
|
||||||
|
from .app import app as app
|
||||||
|
#from .router import app as app
|
||||||
|
|
||||||
|
__all__ = ["app"]
|
||||||
|
|||||||
@ -1,26 +1,52 @@
|
|||||||
"""FastAPI api."""
|
"""FastAPI api.
|
||||||
|
|
||||||
|
TODO: We may want to allow a consumer to generate audit log entries manually.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from collections.abc import Sequence
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from collections.abc import Sequence
|
|
||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
from fastapi import APIRouter, Depends, FastAPI, Header, HTTPException, Query, Request
|
from fastapi import (
|
||||||
|
APIRouter,
|
||||||
|
Depends,
|
||||||
|
FastAPI,
|
||||||
|
Header,
|
||||||
|
HTTPException,
|
||||||
|
Query,
|
||||||
|
Request,
|
||||||
|
status,
|
||||||
|
)
|
||||||
|
from fastapi.encoders import jsonable_encoder
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
|
||||||
from sqlmodel import Session, select
|
from sqlmodel import Session, select
|
||||||
|
|
||||||
from . import audit
|
from . import audit
|
||||||
from .db import get_engine
|
from .db import get_engine
|
||||||
from .models import APIClient, AuditLog, Client, ClientSecret, init_db
|
from .models import (
|
||||||
|
APIClient,
|
||||||
|
AuditLog,
|
||||||
|
Client,
|
||||||
|
ClientAccessPolicy,
|
||||||
|
ClientSecret,
|
||||||
|
init_db,
|
||||||
|
)
|
||||||
from .settings import get_settings
|
from .settings import get_settings
|
||||||
from .view_models import (
|
from .view_models import (
|
||||||
BodyValue,
|
BodyValue,
|
||||||
ClientCreate,
|
ClientCreate,
|
||||||
ClientListResponse,
|
|
||||||
ClientSecretPublic,
|
ClientSecretPublic,
|
||||||
ClientSecretResponse,
|
ClientSecretResponse,
|
||||||
ClientUpdate,
|
ClientUpdate,
|
||||||
ClientView,
|
ClientView,
|
||||||
|
ClientPolicyView,
|
||||||
|
ClientPolicyUpdate,
|
||||||
)
|
)
|
||||||
|
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
@ -104,12 +130,12 @@ backend_api = APIRouter(
|
|||||||
@backend_api.get("/clients/")
|
@backend_api.get("/clients/")
|
||||||
async def get_clients(
|
async def get_clients(
|
||||||
session: Annotated[Session, Depends(get_session)]
|
session: Annotated[Session, Depends(get_session)]
|
||||||
) -> list[ClientListResponse]:
|
) -> list[ClientView]:
|
||||||
"""Get clients."""
|
"""Get clients."""
|
||||||
statement = select(Client)
|
statement = select(Client)
|
||||||
results = session.exec(statement)
|
results = session.exec(statement)
|
||||||
clients = list(results)
|
clients = list(results)
|
||||||
return ClientListResponse.from_clients(clients)
|
return ClientView.from_client_list(clients)
|
||||||
|
|
||||||
|
|
||||||
@backend_api.get("/clients/{name}")
|
@backend_api.get("/clients/{name}")
|
||||||
@ -128,14 +154,93 @@ async def get_client(
|
|||||||
return ClientView.from_client(client)
|
return ClientView.from_client(client)
|
||||||
|
|
||||||
|
|
||||||
@backend_api.post("/clients/{name}/update_fingerprint")
|
@backend_api.delete("/clients/{name}")
|
||||||
async def update_client_fingerprint(
|
async def delete_client(
|
||||||
|
request: Request, name: str, session: Annotated[Session, Depends(get_session)]
|
||||||
|
) -> None:
|
||||||
|
"""Delete a client."""
|
||||||
|
statement = select(Client).where(Client.name == name)
|
||||||
|
results = session.exec(statement)
|
||||||
|
client = results.first()
|
||||||
|
if not client:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404, detail="Cannot find a client with the given name."
|
||||||
|
)
|
||||||
|
|
||||||
|
session.delete(client)
|
||||||
|
session.commit()
|
||||||
|
audit.audit_delete_client(session, request, client)
|
||||||
|
|
||||||
|
|
||||||
|
@backend_api.get("/clients/{name}/policies/")
|
||||||
|
async def get_client_policies(
|
||||||
|
name: str, session: Annotated[Session, Depends(get_session)]
|
||||||
|
) -> ClientPolicyView:
|
||||||
|
"""Get client policies."""
|
||||||
|
statement = select(Client).where(Client.name == name)
|
||||||
|
results = session.exec(statement)
|
||||||
|
client = results.first()
|
||||||
|
if not client:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404, detail="Cannot find a client with the given name."
|
||||||
|
)
|
||||||
|
|
||||||
|
return ClientPolicyView.from_client(client)
|
||||||
|
|
||||||
|
|
||||||
|
@backend_api.put("/clients/{name}/policies/")
|
||||||
|
async def update_client_policies(
|
||||||
|
request: Request,
|
||||||
|
name: str,
|
||||||
|
policy_update: ClientPolicyUpdate,
|
||||||
|
session: Annotated[Session, Depends(get_session)],
|
||||||
|
) -> ClientPolicyView:
|
||||||
|
"""Update client policies.
|
||||||
|
|
||||||
|
This is also how you delete policies.
|
||||||
|
"""
|
||||||
|
statement = select(Client).where(Client.name == name)
|
||||||
|
results = session.exec(statement)
|
||||||
|
client = results.first()
|
||||||
|
if not client:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404, detail="Cannot find a client with the given name."
|
||||||
|
)
|
||||||
|
# Remove old policies.
|
||||||
|
policies = session.exec(
|
||||||
|
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
|
||||||
|
).all()
|
||||||
|
deleted_policies: list[ClientAccessPolicy] = []
|
||||||
|
added_policies: list[ClientAccessPolicy] = []
|
||||||
|
for policy in policies:
|
||||||
|
session.delete(policy)
|
||||||
|
deleted_policies.append(policy)
|
||||||
|
|
||||||
|
for source in policy_update.sources:
|
||||||
|
LOG.debug("Source %r", source)
|
||||||
|
policy = ClientAccessPolicy(source=str(source), client_id=client.id)
|
||||||
|
session.add(policy)
|
||||||
|
added_policies.append(policy)
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
session.refresh(client)
|
||||||
|
for policy in deleted_policies:
|
||||||
|
audit.audit_remove_policy(session, request, client, policy)
|
||||||
|
|
||||||
|
for policy in added_policies:
|
||||||
|
audit.audit_update_policy(session, request, client, policy)
|
||||||
|
|
||||||
|
return ClientPolicyView.from_client(client)
|
||||||
|
|
||||||
|
|
||||||
|
@backend_api.post("/clients/{name}/public-key")
|
||||||
|
async def update_client_public_key(
|
||||||
request: Request,
|
request: Request,
|
||||||
name: str,
|
name: str,
|
||||||
client_update: ClientUpdate,
|
client_update: ClientUpdate,
|
||||||
session: Annotated[Session, Depends(get_session)],
|
session: Annotated[Session, Depends(get_session)],
|
||||||
) -> ClientView:
|
) -> ClientView:
|
||||||
"""Update the client fingerprint.
|
"""Change the public key of a client.
|
||||||
|
|
||||||
This invalidates all secrets.
|
This invalidates all secrets.
|
||||||
"""
|
"""
|
||||||
@ -146,7 +251,7 @@ async def update_client_fingerprint(
|
|||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404, detail="Cannot find a client with the given name."
|
status_code=404, detail="Cannot find a client with the given name."
|
||||||
)
|
)
|
||||||
client.fingerprint = client_update.fingerprint
|
client.public_key = client_update.public_key
|
||||||
for secret in session.exec(
|
for secret in session.exec(
|
||||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||||
).all():
|
).all():
|
||||||
@ -170,7 +275,11 @@ async def create_client(
|
|||||||
session: Annotated[Session, Depends(get_session)],
|
session: Annotated[Session, Depends(get_session)],
|
||||||
) -> ClientView:
|
) -> ClientView:
|
||||||
"""Create client."""
|
"""Create client."""
|
||||||
db_client = Client.model_validate(client)
|
existing = await get_client_by_name(session, client.name)
|
||||||
|
if existing:
|
||||||
|
raise HTTPException(400, detail="Error: Already a client with that name.")
|
||||||
|
|
||||||
|
db_client = client.to_client()
|
||||||
session.add(db_client)
|
session.add(db_client)
|
||||||
session.commit()
|
session.commit()
|
||||||
session.refresh(db_client)
|
session.refresh(db_client)
|
||||||
@ -270,6 +379,30 @@ async def request_client_secret(
|
|||||||
audit.audit_access_secret(session, request, client, secret)
|
audit.audit_access_secret(session, request, client, secret)
|
||||||
return response_model
|
return response_model
|
||||||
|
|
||||||
|
@backend_api.delete("/clients/{name}/secrets/{secret_name}")
|
||||||
|
async def delete_client_secret(
|
||||||
|
request: Request,
|
||||||
|
name: str,
|
||||||
|
secret_name: str,
|
||||||
|
session: Annotated[Session, Depends(get_session)],
|
||||||
|
) -> None:
|
||||||
|
"""Delete a secret."""
|
||||||
|
client = await get_client_by_name(session, name)
|
||||||
|
if not client:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404, detail="Cannot find a client with the given name."
|
||||||
|
)
|
||||||
|
|
||||||
|
secret = await lookup_client_secret(session, client, secret_name)
|
||||||
|
if not secret:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404, detail="Cannot find a secret with the given name."
|
||||||
|
)
|
||||||
|
|
||||||
|
session.delete(secret)
|
||||||
|
session.commit()
|
||||||
|
audit.audit_delete_secret(session, request, client, secret)
|
||||||
|
|
||||||
|
|
||||||
@backend_api.get("/audit/", response_model=list[AuditLog])
|
@backend_api.get("/audit/", response_model=list[AuditLog])
|
||||||
async def get_audit_logs(
|
async def get_audit_logs(
|
||||||
@ -287,3 +420,17 @@ async def get_audit_logs(
|
|||||||
|
|
||||||
results = session.exec(statement).all()
|
results = session.exec(statement).all()
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
@app.exception_handler(RequestValidationError)
|
||||||
|
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
|
content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
app.include_router(backend_api)
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from sqlmodel import Session, select
|
from sqlmodel import Session, select
|
||||||
from .models import AuditLog, Client, ClientSecret
|
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy
|
||||||
|
|
||||||
|
|
||||||
def _get_origin(request: Request) -> str | None:
|
def _get_origin(request: Request) -> str | None:
|
||||||
@ -39,6 +39,19 @@ def audit_create_client(
|
|||||||
_write_audit_log(session, request, entry, commit)
|
_write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
|
|
||||||
|
def audit_delete_client(
|
||||||
|
session: Session, request: Request, client: Client, commit: bool = True
|
||||||
|
) -> None:
|
||||||
|
"""Log the creation of a client."""
|
||||||
|
entry = AuditLog(
|
||||||
|
operation="CREATE",
|
||||||
|
client_id=client.id,
|
||||||
|
client_name=client.name,
|
||||||
|
message="Client deleted",
|
||||||
|
)
|
||||||
|
_write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
|
|
||||||
def audit_create_secret(
|
def audit_create_secret(
|
||||||
session: Session,
|
session: Session,
|
||||||
request: Request,
|
request: Request,
|
||||||
@ -58,6 +71,44 @@ def audit_create_secret(
|
|||||||
_write_audit_log(session, request, entry, commit)
|
_write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
|
|
||||||
|
def audit_remove_policy(
|
||||||
|
session: Session,
|
||||||
|
request: Request,
|
||||||
|
client: Client,
|
||||||
|
policy: ClientAccessPolicy,
|
||||||
|
commit: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""Audit removal of policy."""
|
||||||
|
entry = AuditLog(
|
||||||
|
operation="DELETE",
|
||||||
|
object="ClientAccessPolicy",
|
||||||
|
object_id=str(policy.id),
|
||||||
|
client_id=client.id,
|
||||||
|
client_name=client.name,
|
||||||
|
message="Deleted client policy",
|
||||||
|
)
|
||||||
|
_write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
|
|
||||||
|
def audit_update_policy(
|
||||||
|
session: Session,
|
||||||
|
request: Request,
|
||||||
|
client: Client,
|
||||||
|
policy: ClientAccessPolicy,
|
||||||
|
commit: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""Audit update of policy."""
|
||||||
|
entry = AuditLog(
|
||||||
|
operation="CREATE",
|
||||||
|
object="ClientAccessPolicy",
|
||||||
|
object_id=str(policy.id),
|
||||||
|
client_id=client.id,
|
||||||
|
client_name=client.name,
|
||||||
|
message="Updated client policy",
|
||||||
|
)
|
||||||
|
_write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
|
|
||||||
def audit_update_secret(
|
def audit_update_secret(
|
||||||
session: Session,
|
session: Session,
|
||||||
request: Request,
|
request: Request,
|
||||||
@ -89,7 +140,26 @@ def audit_invalidate_secrets(
|
|||||||
object="ClientSecret",
|
object="ClientSecret",
|
||||||
client_name=client.name,
|
client_name=client.name,
|
||||||
client_id=client.id,
|
client_id=client.id,
|
||||||
message="Client fingerprint updated. All secrets invalidated.",
|
message="Client public-key changed. All secrets invalidated.",
|
||||||
|
)
|
||||||
|
_write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
|
|
||||||
|
def audit_delete_secret(
|
||||||
|
session: Session,
|
||||||
|
request: Request,
|
||||||
|
client: Client,
|
||||||
|
secret: ClientSecret,
|
||||||
|
commit: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""Audit Delete client secrets."""
|
||||||
|
entry = AuditLog(
|
||||||
|
operation="DELETE",
|
||||||
|
object="ClientSecret",
|
||||||
|
object_id=str(secret.id),
|
||||||
|
client_name=client.name,
|
||||||
|
client_id=client.id,
|
||||||
|
message="Deleted secret.",
|
||||||
)
|
)
|
||||||
_write_audit_log(session, request, entry, commit)
|
_write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
|
|||||||
@ -1,24 +1,31 @@
|
|||||||
"""CLI and main entry point."""
|
"""CLI and main entry point."""
|
||||||
|
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from dotenv import load_dotenv
|
||||||
import click
|
import click
|
||||||
from pydantic import BaseModel, FilePath
|
|
||||||
|
|
||||||
|
from .db import generate_api_token
|
||||||
|
|
||||||
|
DEFAULT_LISTEN = "127.0.0.1"
|
||||||
|
DEFAULT_PORT = 8022
|
||||||
|
|
||||||
|
WORKDIR = Path(os.getcwd())
|
||||||
|
|
||||||
class BackendSettings(BaseModel):
|
load_dotenv()
|
||||||
"""Backend Settings."""
|
|
||||||
|
|
||||||
db_file: FilePath
|
|
||||||
regenerate_tokens: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
@click.option("--db-file", envvar="sshecret_db_file", type=click.Path(path_type=Path))
|
@click.option("--database", help="Path to the sqlite database file.")
|
||||||
@click.option("--regenerate-tokens", is_flag=True, default=False)
|
def cli(database: str) -> None:
|
||||||
@click.pass_context
|
"""CLI group."""
|
||||||
def cli(ctx: click.Context, db_file: Path, regenerate_tokens: bool) -> None:
|
if database:
|
||||||
"""Sshecret database handler."""
|
# Hopefully it's enough to set the environment variable as so.
|
||||||
if not isinstance(ctx.obj, BackendSettings):
|
os.environ["SSHECRET_DB_FILE"] = str(Path(database).absolute())
|
||||||
ctx.obj = BackendSettings(db_file=db_file, regenerate_tokens=regenerate_tokens)
|
|
||||||
|
|
||||||
|
@cli.command("generate-token")
|
||||||
|
def cli_generate_token() -> None:
|
||||||
|
"""Generate a token."""
|
||||||
|
token = generate_api_token()
|
||||||
|
click.echo("Generated api token:")
|
||||||
|
click.echo(token)
|
||||||
|
|||||||
@ -1,18 +1,20 @@
|
|||||||
#!/usr/bin/env python3
|
"""Database related functions."""
|
||||||
|
|
||||||
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from sqlalchemy import Engine
|
from sqlalchemy import Engine
|
||||||
from sqlmodel import Session, create_engine, text
|
from sqlmodel import Session, create_engine, text
|
||||||
import bcrypt
|
import bcrypt
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
from sqlalchemy.engine import URL
|
from sqlalchemy.engine import URL
|
||||||
|
|
||||||
from .models import APIClient, init_db
|
from .models import APIClient, init_db
|
||||||
|
|
||||||
|
from .settings import get_settings
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_engine(filename: Path, echo: bool = False) -> Engine:
|
def get_engine(filename: Path, echo: bool = False) -> Engine:
|
||||||
@ -25,7 +27,7 @@ def get_engine(filename: Path, echo: bool = False) -> Engine:
|
|||||||
return engine
|
return engine
|
||||||
|
|
||||||
|
|
||||||
def create_db_and_tables(filename: Path, echo: bool = True) -> bool:
|
def create_db_and_tables(filename: Path, echo: bool = False) -> bool:
|
||||||
"""Create database and tables.
|
"""Create database and tables.
|
||||||
|
|
||||||
Returns True if the database was created.
|
Returns True if the database was created.
|
||||||
@ -52,3 +54,14 @@ def create_api_token(session: Session, read_write: bool) -> str:
|
|||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
return token
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
def generate_api_token() -> str:
|
||||||
|
"""Generate API token."""
|
||||||
|
settings = get_settings()
|
||||||
|
engine = get_engine(settings.db_file)
|
||||||
|
init_db(engine)
|
||||||
|
with Session(engine) as session:
|
||||||
|
token = create_api_token(session, True)
|
||||||
|
|
||||||
|
return token
|
||||||
|
|||||||
@ -1,8 +1,15 @@
|
|||||||
#!/usr/bin/env python3
|
"""Database models.
|
||||||
|
|
||||||
|
TODO:
|
||||||
|
|
||||||
|
We might want to pass on audit information from the SSH server.
|
||||||
|
This might require some changes to these schemas.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from sqlalchemy import Engine, Column, DateTime, func
|
import sqlalchemy as sa
|
||||||
from sqlmodel import Field, Relationship, SQLModel
|
from sqlmodel import Field, Relationship, SQLModel
|
||||||
|
|
||||||
|
|
||||||
@ -11,22 +18,48 @@ class Client(SQLModel, table=True):
|
|||||||
|
|
||||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||||
name: str = Field(unique=True)
|
name: str = Field(unique=True)
|
||||||
fingerprint: str
|
public_key: str
|
||||||
created_at: datetime = Field(
|
|
||||||
|
created_at: datetime | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
sa_column=Column(
|
sa_type=sa.DateTime(timezone=True),
|
||||||
DateTime(timezone=True), server_default=func.now(), nullable=True
|
sa_column_kwargs={"server_default": sa.func.now()},
|
||||||
),
|
nullable=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
updated_at: datetime | None = Field(
|
updated_at: datetime | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
sa_column=Column(DateTime(timezone=True), onupdate=func.now(), nullable=True),
|
sa_type=sa.DateTime(timezone=True),
|
||||||
|
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
|
||||||
)
|
)
|
||||||
|
|
||||||
secrets: list["ClientSecret"] = Relationship(
|
secrets: list["ClientSecret"] = Relationship(
|
||||||
back_populates="client", passive_deletes="all"
|
back_populates="client", passive_deletes="all"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
policies: list["ClientAccessPolicy"] = Relationship(back_populates="client")
|
||||||
|
|
||||||
|
|
||||||
|
class ClientAccessPolicy(SQLModel, table=True):
|
||||||
|
"""Client access policies."""
|
||||||
|
|
||||||
|
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||||
|
source: str
|
||||||
|
client_id: uuid.UUID | None = Field(foreign_key="client.id", ondelete="CASCADE")
|
||||||
|
client: Client | None = Relationship(back_populates="policies")
|
||||||
|
|
||||||
|
created_at: datetime | None = Field(
|
||||||
|
default=None,
|
||||||
|
sa_type=sa.DateTime(timezone=True),
|
||||||
|
sa_column_kwargs={"server_default": sa.func.now()},
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
updated_at: datetime | None = Field(
|
||||||
|
default=None,
|
||||||
|
sa_type=sa.DateTime(timezone=True),
|
||||||
|
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
|
||||||
|
)
|
||||||
|
|
||||||
class ClientSecret(SQLModel, table=True):
|
class ClientSecret(SQLModel, table=True):
|
||||||
"""A client secret."""
|
"""A client secret."""
|
||||||
@ -37,15 +70,17 @@ class ClientSecret(SQLModel, table=True):
|
|||||||
client: Client | None = Relationship(back_populates="secrets")
|
client: Client | None = Relationship(back_populates="secrets")
|
||||||
secret: str
|
secret: str
|
||||||
invalidated: bool = Field(default=False)
|
invalidated: bool = Field(default=False)
|
||||||
created_at: datetime = Field(
|
created_at: datetime | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
sa_column=Column(
|
sa_type=sa.DateTime(timezone=True),
|
||||||
DateTime(timezone=True), server_default=func.now(), nullable=True
|
sa_column_kwargs={"server_default": sa.func.now()},
|
||||||
),
|
nullable=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
updated_at: datetime | None = Field(
|
updated_at: datetime | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
sa_column=Column(DateTime(timezone=True), onupdate=func.now(), nullable=True),
|
sa_type=sa.DateTime(timezone=True),
|
||||||
|
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -64,28 +99,27 @@ class AuditLog(SQLModel, table=True):
|
|||||||
client_name: str | None = None
|
client_name: str | None = None
|
||||||
message: str
|
message: str
|
||||||
origin: str | None = None
|
origin: str | None = None
|
||||||
|
|
||||||
timestamp: datetime | None = Field(
|
timestamp: datetime | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
sa_column=Column(
|
sa_type=sa.DateTime(timezone=True),
|
||||||
DateTime(timezone=True), server_default=func.now(), nullable=True
|
sa_column_kwargs={"server_default": sa.func.now()},
|
||||||
),
|
nullable=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class APIClient(SQLModel, table=True):
|
class APIClient(SQLModel, table=True):
|
||||||
"""Stores API Keys."""
|
"""Stores API Keys."""
|
||||||
|
|
||||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||||
token: str
|
token: str
|
||||||
read_write: bool
|
read_write: bool
|
||||||
created_at: datetime = Field(
|
created_at: datetime | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
sa_column=Column(
|
sa_type=sa.DateTime(timezone=True),
|
||||||
DateTime(timezone=True), server_default=func.now(), nullable=True
|
sa_column_kwargs={"server_default": sa.func.now()},
|
||||||
),
|
nullable=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def init_db(engine: sa.Engine) -> None:
|
||||||
def init_db(engine: Engine) -> None:
|
|
||||||
"""Create database."""
|
"""Create database."""
|
||||||
SQLModel.metadata.create_all(engine)
|
SQLModel.metadata.create_all(engine)
|
||||||
|
|||||||
@ -1,9 +0,0 @@
|
|||||||
"""API Router."""
|
|
||||||
|
|
||||||
from fastapi import FastAPI
|
|
||||||
|
|
||||||
from .app import backend_api
|
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
app.include_router(backend_api)
|
|
||||||
@ -1,26 +1,25 @@
|
|||||||
"""Settings management."""
|
"""Settings management."""
|
||||||
|
|
||||||
import os
|
from typing import override
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
from dotenv import load_dotenv
|
from pydantic_settings import (
|
||||||
|
BaseSettings,
|
||||||
|
SettingsConfigDict,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_DATABASE = "sshecret.db"
|
DEFAULT_DATABASE = "sshecret.db"
|
||||||
|
|
||||||
|
|
||||||
load_dotenv()
|
class BackendSettings(BaseSettings):
|
||||||
|
"""Backend settings."""
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(env_file=".backend.env", env_prefix="sshecret_")
|
||||||
|
|
||||||
class BackendSettings(BaseModel):
|
db_file: Path = Field(default=Path(DEFAULT_DATABASE).absolute())
|
||||||
"""Backend Settings."""
|
|
||||||
|
|
||||||
db_file: Path
|
|
||||||
regenerate_tokens: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
def get_settings() -> BackendSettings:
|
def get_settings() -> BackendSettings:
|
||||||
"""Get settings."""
|
"""Get settings."""
|
||||||
db_filename = os.getenv("SSHECRET_DATABASE") or DEFAULT_DATABASE
|
return BackendSettings()
|
||||||
db_file = Path(db_filename).absolute()
|
|
||||||
|
|
||||||
return BackendSettings(db_file=db_file)
|
|
||||||
|
|||||||
@ -1,57 +1,48 @@
|
|||||||
"""Models for API views."""
|
"""Models for API views."""
|
||||||
|
|
||||||
|
import ipaddress
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Self, override
|
from typing import Annotated, Any, Self, override
|
||||||
|
|
||||||
from sqlmodel import Field, SQLModel
|
from sqlmodel import Field, SQLModel
|
||||||
|
from pydantic import IPvAnyAddress, IPvAnyNetwork
|
||||||
from . import models
|
from . import models
|
||||||
|
|
||||||
|
|
||||||
class ClientListResponse(SQLModel):
|
|
||||||
"""Model list responses."""
|
|
||||||
|
|
||||||
|
class ClientView(SQLModel):
|
||||||
|
"""View for a single client."""
|
||||||
id: uuid.UUID
|
id: uuid.UUID
|
||||||
name: str
|
name: str
|
||||||
fingerprint: str
|
public_key: str
|
||||||
|
policies: list[str] = ["0.0.0.0/0", "::/0"]
|
||||||
|
secrets: list[str] = Field(default_factory=list)
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
updated_at: datetime | None = None
|
updated_at: datetime | None = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_clients(cls, clients: list[models.Client]) -> list[Self]:
|
def from_client_list(cls, clients: list[models.Client]) -> list[Self]:
|
||||||
"""Generate a list of responses from a list of clients."""
|
"""Generate a list of responses from a list of clients."""
|
||||||
responses: list[Self] = []
|
responses: list[Self] = [cls.from_client(client) for client in clients]
|
||||||
for client in clients:
|
|
||||||
responses.append(
|
|
||||||
cls(
|
|
||||||
id=client.id,
|
|
||||||
name=client.name,
|
|
||||||
fingerprint=client.fingerprint,
|
|
||||||
created_at=client.created_at,
|
|
||||||
updated_at=client.updated_at or None,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return responses
|
return responses
|
||||||
|
|
||||||
|
|
||||||
class ClientView(ClientListResponse):
|
|
||||||
"""View for a single client."""
|
|
||||||
|
|
||||||
secrets: list[str] = Field(default_factory=list)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_client(cls, client: models.Client) -> Self:
|
def from_client(cls, client: models.Client) -> Self:
|
||||||
"""Instantiate from a client."""
|
"""Instantiate from a client."""
|
||||||
view = cls(
|
view = cls(
|
||||||
id=client.id,
|
id=client.id,
|
||||||
name=client.name,
|
name=client.name,
|
||||||
fingerprint=client.fingerprint,
|
public_key=client.public_key,
|
||||||
created_at=client.created_at,
|
created_at=client.created_at,
|
||||||
updated_at=client.updated_at or None,
|
updated_at=client.updated_at or None,
|
||||||
)
|
)
|
||||||
if client.secrets:
|
if client.secrets:
|
||||||
view.secrets = [secret.name for secret in client.secrets]
|
view.secrets = [secret.name for secret in client.secrets]
|
||||||
|
|
||||||
|
if client.policies:
|
||||||
|
view.policies = [policy.source for policy in client.policies]
|
||||||
|
|
||||||
return view
|
return view
|
||||||
|
|
||||||
|
|
||||||
@ -59,17 +50,20 @@ class ClientCreate(SQLModel):
|
|||||||
"""Model to create a client."""
|
"""Model to create a client."""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
fingerprint: str
|
public_key: str
|
||||||
|
|
||||||
def to_client(self) -> models.Client:
|
def to_client(self) -> models.Client:
|
||||||
"""Instantiate a client."""
|
"""Instantiate a client."""
|
||||||
return models.Client(name=self.name, fingerprint=self.fingerprint)
|
public_key = self.public_key
|
||||||
|
return models.Client(
|
||||||
|
name=self.name, public_key=public_key
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ClientUpdate(SQLModel):
|
class ClientUpdate(SQLModel):
|
||||||
"""Model to update the client fingerprint."""
|
"""Model to update the client public key."""
|
||||||
|
|
||||||
fingerprint: str
|
public_key: str
|
||||||
|
|
||||||
|
|
||||||
class BodyValue(SQLModel):
|
class BodyValue(SQLModel):
|
||||||
@ -110,3 +104,22 @@ class ClientSecretResponse(ClientSecretPublic):
|
|||||||
created_at=client_secret.created_at,
|
created_at=client_secret.created_at,
|
||||||
updated_at=client_secret.updated_at,
|
updated_at=client_secret.updated_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ClientPolicyView(SQLModel):
|
||||||
|
"""Update object for client policy."""
|
||||||
|
|
||||||
|
sources: list[str] = ["0.0.0.0/0", "::/0"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_client(cls, client: models.Client) -> Self:
|
||||||
|
"""Create from client."""
|
||||||
|
if not client.policies:
|
||||||
|
return cls()
|
||||||
|
return cls(sources=[policy.source for policy in client.policies])
|
||||||
|
|
||||||
|
|
||||||
|
class ClientPolicyUpdate(SQLModel):
|
||||||
|
"""Model for updating policies."""
|
||||||
|
|
||||||
|
sources: list[IPvAnyAddress | IPvAnyNetwork]
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
"""Tests of the backend api using pytest."""
|
"""Tests of the backend api using pytest."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import random
|
||||||
|
import string
|
||||||
from httpx import Response
|
from httpx import Response
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -8,8 +10,7 @@ from fastapi.testclient import TestClient
|
|||||||
from sqlmodel import Session, SQLModel, create_engine
|
from sqlmodel import Session, SQLModel, create_engine
|
||||||
from sqlmodel.pool import StaticPool
|
from sqlmodel.pool import StaticPool
|
||||||
|
|
||||||
from sshecret_backend.router import app
|
from sshecret_backend.app import app, get_session
|
||||||
from sshecret_backend.app import get_session
|
|
||||||
from sshecret_backend.testing import create_test_token
|
from sshecret_backend.testing import create_test_token
|
||||||
from sshecret_backend.models import AuditLog
|
from sshecret_backend.models import AuditLog
|
||||||
|
|
||||||
@ -22,17 +23,30 @@ LOG.addHandler(handler)
|
|||||||
LOG.setLevel(logging.DEBUG)
|
LOG.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
TEST_FINGERPRINT = "00:11:22:33:44:55:66:77:88:99:aa:bb:cc:dd:ee:ff"
|
def make_test_key() -> str:
|
||||||
|
"""Generate a test key."""
|
||||||
|
randomlength = 540
|
||||||
|
key = "ssh-rsa "
|
||||||
|
randompart = "".join(
|
||||||
|
random.choices(string.ascii_letters + string.digits, k=randomlength)
|
||||||
|
)
|
||||||
|
comment = " invalid-test-key"
|
||||||
|
return key + randompart + comment
|
||||||
|
|
||||||
|
|
||||||
def create_client(
|
def create_client(
|
||||||
test_client: TestClient,
|
test_client: TestClient,
|
||||||
headers: dict[str, str],
|
headers: dict[str, str],
|
||||||
name: str,
|
name: str,
|
||||||
fingerprint: str = TEST_FINGERPRINT,
|
public_key: str | None = None,
|
||||||
) -> Response:
|
) -> Response:
|
||||||
"""Create client."""
|
"""Create client."""
|
||||||
data = {"name": name, "fingerprint": fingerprint}
|
if not public_key:
|
||||||
|
public_key = make_test_key()
|
||||||
|
data = {
|
||||||
|
"name": name,
|
||||||
|
"public_key": public_key,
|
||||||
|
}
|
||||||
create_response = test_client.post("/api/v1/clients", headers=headers, json=data)
|
create_response = test_client.post("/api/v1/clients", headers=headers, json=data)
|
||||||
return create_response
|
return create_response
|
||||||
|
|
||||||
@ -95,10 +109,8 @@ def test_with_token(test_client: TestClient, token: str) -> None:
|
|||||||
def test_create_client(test_client: TestClient, headers: dict[str, str]) -> None:
|
def test_create_client(test_client: TestClient, headers: dict[str, str]) -> None:
|
||||||
"""Test creating a client."""
|
"""Test creating a client."""
|
||||||
client_name = "test"
|
client_name = "test"
|
||||||
client_fingerprint = TEST_FINGERPRINT
|
client_publickey = make_test_key()
|
||||||
create_response = create_client(
|
create_response = create_client(test_client, headers, client_name, client_publickey)
|
||||||
test_client, headers, client_name, client_fingerprint
|
|
||||||
)
|
|
||||||
assert create_response.status_code == 200
|
assert create_response.status_code == 200
|
||||||
response = test_client.get("/api/v1/clients/", headers=headers)
|
response = test_client.get("/api/v1/clients/", headers=headers)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@ -107,16 +119,35 @@ def test_create_client(test_client: TestClient, headers: dict[str, str]) -> None
|
|||||||
client = clients[0]
|
client = clients[0]
|
||||||
assert isinstance(client, dict)
|
assert isinstance(client, dict)
|
||||||
assert client.get("name") == client_name
|
assert client.get("name") == client_name
|
||||||
assert client.get("fingerprint") == client_fingerprint
|
|
||||||
assert client.get("created_at") is not None
|
assert client.get("created_at") is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_client(test_client: TestClient, headers: dict[str, str]) -> None:
|
||||||
|
"""Test creating a client."""
|
||||||
|
client_name = "test"
|
||||||
|
create_response = create_client(
|
||||||
|
test_client,
|
||||||
|
headers,
|
||||||
|
client_name,
|
||||||
|
)
|
||||||
|
assert create_response.status_code == 200
|
||||||
|
|
||||||
|
resp = test_client.delete("/api/v1/clients/test", headers=headers)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
resp = test_client.get("/api/v1/clients/test", headers=headers)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
def test_add_secret(test_client: TestClient, headers: dict[str, str]) -> None:
|
def test_add_secret(test_client: TestClient, headers: dict[str, str]) -> None:
|
||||||
"""Test adding a secret to a client."""
|
"""Test adding a secret to a client."""
|
||||||
client_name = "test"
|
client_name = "test"
|
||||||
client_fingerprint = "00:11:22:33:44:55:66:77:88:99:aa:bb:cc:dd:ee:ff"
|
client_publickey = make_test_key()
|
||||||
create_response = create_client(
|
create_response = create_client(
|
||||||
test_client, headers, client_name, client_fingerprint
|
test_client,
|
||||||
|
headers,
|
||||||
|
client_name,
|
||||||
|
client_publickey,
|
||||||
)
|
)
|
||||||
assert create_response.status_code == 200
|
assert create_response.status_code == 200
|
||||||
secret_name = "mysecret"
|
secret_name = "mysecret"
|
||||||
@ -136,6 +167,17 @@ def test_add_secret(test_client: TestClient, headers: dict[str, str]) -> None:
|
|||||||
assert secret_body["secret"] == data["secret"]
|
assert secret_body["secret"] == data["secret"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_secret(test_client: TestClient, headers: dict[str, str]) -> None:
|
||||||
|
"""Test deleting a secret."""
|
||||||
|
test_add_secret(test_client, headers)
|
||||||
|
resp = test_client.delete("/api/v1/clients/test/secrets/mysecret", headers=headers)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
get_response = test_client.get(
|
||||||
|
"/api/v1/clients/test/secrets/mysecret", headers=headers
|
||||||
|
)
|
||||||
|
assert get_response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
def test_put_add_secret(test_client: TestClient, headers: dict[str, str]) -> None:
|
def test_put_add_secret(test_client: TestClient, headers: dict[str, str]) -> None:
|
||||||
"""Test adding secret via PUT."""
|
"""Test adding secret via PUT."""
|
||||||
# Use the test_create_client function to create a client.
|
# Use the test_create_client function to create a client.
|
||||||
@ -178,7 +220,8 @@ def test_put_update_secret(test_client: TestClient, headers: dict[str, str]) ->
|
|||||||
|
|
||||||
def test_audit_logging(test_client: TestClient, headers: dict[str, str]) -> None:
|
def test_audit_logging(test_client: TestClient, headers: dict[str, str]) -> None:
|
||||||
"""Test audit logging."""
|
"""Test audit logging."""
|
||||||
create_client_resp = create_client(test_client, headers, "test")
|
public_key = make_test_key()
|
||||||
|
create_client_resp = create_client(test_client, headers, "test", public_key)
|
||||||
assert create_client_resp.status_code == 200
|
assert create_client_resp.status_code == 200
|
||||||
secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"}
|
secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"}
|
||||||
for name, secret in secrets.items():
|
for name, secret in secrets.items():
|
||||||
@ -251,7 +294,8 @@ def test_audit_log_filtering(
|
|||||||
|
|
||||||
def test_secret_invalidation(test_client: TestClient, headers: dict[str, str]) -> None:
|
def test_secret_invalidation(test_client: TestClient, headers: dict[str, str]) -> None:
|
||||||
"""Test secret invalidation."""
|
"""Test secret invalidation."""
|
||||||
create_client_resp = create_client(test_client, headers, "test")
|
initial_key = make_test_key()
|
||||||
|
create_client_resp = create_client(test_client, headers, "test", initial_key)
|
||||||
assert create_client_resp.status_code == 200
|
assert create_client_resp.status_code == 200
|
||||||
secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"}
|
secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"}
|
||||||
for name, secret in secrets.items():
|
for name, secret in secrets.items():
|
||||||
@ -262,12 +306,13 @@ def test_secret_invalidation(test_client: TestClient, headers: dict[str, str]) -
|
|||||||
)
|
)
|
||||||
assert add_resp.status_code == 200
|
assert add_resp.status_code == 200
|
||||||
|
|
||||||
# Update the fingerprint. This should cause all secrets to be invalidated
|
# Update the public-key. This should cause all secrets to be invalidated
|
||||||
# and no longer associated with a client.
|
# and no longer associated with a client.
|
||||||
|
new_key = make_test_key()
|
||||||
update_resp = test_client.post(
|
update_resp = test_client.post(
|
||||||
"/api/v1/clients/test/update_fingerprint",
|
"/api/v1/clients/test/public-key",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
json={"fingerprint": "foobar"},
|
json={"public_key": new_key},
|
||||||
)
|
)
|
||||||
assert update_resp.status_code == 200
|
assert update_resp.status_code == 200
|
||||||
|
|
||||||
@ -277,3 +322,117 @@ def test_secret_invalidation(test_client: TestClient, headers: dict[str, str]) -
|
|||||||
client = get_resp.json()
|
client = get_resp.json()
|
||||||
secrets = client.get("secrets")
|
secrets = client.get("secrets")
|
||||||
assert bool(secrets) is False
|
assert bool(secrets) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_default_policies(
|
||||||
|
test_client: TestClient, headers: dict[str, str]
|
||||||
|
) -> None:
|
||||||
|
"""Test client policies."""
|
||||||
|
public_key = make_test_key()
|
||||||
|
resp = create_client(test_client, headers, "test", public_key)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
# Fetch policies, should return *
|
||||||
|
resp = test_client.get("/api/v1/clients/test/policies/", headers=headers)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
policies = resp.json()
|
||||||
|
|
||||||
|
assert policies["sources"] == ["0.0.0.0/0", "::/0"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_policy_update_one(
|
||||||
|
test_client: TestClient, headers: dict[str, str]
|
||||||
|
) -> None:
|
||||||
|
"""Update client policy with single policy."""
|
||||||
|
public_key = make_test_key()
|
||||||
|
resp = create_client(test_client, headers, "test", public_key)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
policy = ["192.0.2.1"]
|
||||||
|
resp = test_client.put(
|
||||||
|
"/api/v1/clients/test/policies/", headers=headers, json={"sources": policy}
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
resp = test_client.get("/api/v1/clients/test/policies/", headers=headers)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
policies = resp.json()
|
||||||
|
|
||||||
|
assert policies["sources"] == policy
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_policy_update_advanced(
|
||||||
|
test_client: TestClient, headers: dict[str, str]
|
||||||
|
) -> None:
|
||||||
|
"""Test other policy update scenarios."""
|
||||||
|
public_key = make_test_key()
|
||||||
|
resp = create_client(test_client, headers, "test", public_key)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
policy = ["192.0.2.1", "198.18.0.0/24"]
|
||||||
|
|
||||||
|
resp = test_client.put(
|
||||||
|
"/api/v1/clients/test/policies/", headers=headers, json={"sources": policy}
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
resp = test_client.get("/api/v1/clients/test/policies/", headers=headers)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
policies = resp.json()
|
||||||
|
|
||||||
|
assert "192.0.2.1" in policies["sources"]
|
||||||
|
assert "198.18.0.0/24" in policies["sources"]
|
||||||
|
|
||||||
|
# Try to set it to something incorrect
|
||||||
|
|
||||||
|
policy = ["obviosly_wrong"]
|
||||||
|
|
||||||
|
resp = test_client.put(
|
||||||
|
"/api/v1/clients/test/policies/", headers=headers, json={"sources": policy}
|
||||||
|
)
|
||||||
|
assert resp.status_code == 422
|
||||||
|
# Check that the old value is still there
|
||||||
|
|
||||||
|
resp = test_client.get("/api/v1/clients/test/policies/", headers=headers)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
policies = resp.json()
|
||||||
|
|
||||||
|
assert "192.0.2.1" in policies["sources"]
|
||||||
|
assert "198.18.0.0/24" in policies["sources"]
|
||||||
|
|
||||||
|
# Clear the policies
|
||||||
|
#
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_policy_update_unset(
|
||||||
|
test_client: TestClient, headers: dict[str, str]
|
||||||
|
) -> None:
|
||||||
|
"""Test clearing the client policy."""
|
||||||
|
public_key = make_test_key()
|
||||||
|
resp = create_client(test_client, headers, "test", public_key)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
policy = ["192.0.2.1", "198.18.0.0/24"]
|
||||||
|
|
||||||
|
resp = test_client.put(
|
||||||
|
"/api/v1/clients/test/policies/", headers=headers, json={"sources": policy}
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
policies = resp.json()
|
||||||
|
|
||||||
|
assert "192.0.2.1" in policies["sources"]
|
||||||
|
assert "198.18.0.0/24" in policies["sources"]
|
||||||
|
|
||||||
|
# Now we clear the policies
|
||||||
|
|
||||||
|
resp = test_client.put(
|
||||||
|
"/api/v1/clients/test/policies/", headers=headers, json={"sources": []}
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
policies = resp.json()
|
||||||
|
|
||||||
|
assert policies["sources"] == ["0.0.0.0/0", "::/0"]
|
||||||
|
|||||||
Reference in New Issue
Block a user