Centralize testing

This commit is contained in:
2025-05-11 11:22:00 +02:00
parent b34c49d3e3
commit d3d99775d9
19 changed files with 565 additions and 4 deletions

View File

@ -1,538 +0,0 @@
"""Tests of the backend api using pytest."""
import logging
from pathlib import Path
from httpx import Response
import pytest
from fastapi.testclient import TestClient
from sshecret.crypto import generate_private_key, generate_public_key_string
from sshecret_backend.app import create_backend_app
from sshecret_backend.testing import create_test_token
from sshecret_backend.view_models import AuditView
from sshecret_backend.settings import BackendSettings
LOG = logging.getLogger()
handler = logging.StreamHandler()
formatter = logging.Formatter("'%(asctime)s - %(levelname)s - %(message)s'")
handler.setFormatter(formatter)
LOG.addHandler(handler)
LOG.setLevel(logging.DEBUG)
def make_test_key() -> str:
"""Generate a test key."""
private_key = generate_private_key()
return generate_public_key_string(private_key.public_key())
def create_client(
test_client: TestClient,
name: str,
public_key: str | None = None,
description: str | None = None,
) -> Response:
"""Create client."""
if not public_key:
public_key = make_test_key()
data = {
"name": name,
"public_key": public_key,
}
if description:
data["description"] = description
create_response = test_client.post("/api/v1/clients", json=data)
return create_response
@pytest.fixture(name="test_client")
def create_client_fixture(tmp_path: Path):
"""Test client fixture."""
db_file = tmp_path / "backend.db"
print(f"DB File: {db_file.absolute()}")
settings = BackendSettings(database=str(db_file.absolute()))
app = create_backend_app(settings)
token = create_test_token(settings)
test_client = TestClient(app, headers={"X-API-Token": token})
yield test_client
def test_missing_token(test_client: TestClient) -> None:
"""Test logging in with missing token."""
# Save headers
old_headers = test_client.headers
test_client.headers = {}
response = test_client.get("/api/v1/clients/", headers={})
assert response.status_code == 422
test_client.headers = old_headers
def test_incorrect_token(test_client: TestClient) -> None:
"""Test logging in with missing token."""
response = test_client.get("/api/v1/clients/", headers={"X-API-Token": "WRONG"})
assert response.status_code == 401
def test_with_token(test_client: TestClient) -> None:
"""Test with a valid token."""
response = test_client.get("/api/v1/clients/")
assert response.status_code == 200
data = response.json()
assert data["total_results"] == 0
def test_create_client(test_client: TestClient) -> None:
"""Test creating a client."""
client_name = "test"
client_publickey = make_test_key()
create_response = create_client(test_client, client_name, client_publickey)
assert create_response.status_code == 200
response = test_client.get("/api/v1/clients")
assert response.status_code == 200
clients_result = response.json()
clients = clients_result["clients"]
assert isinstance(clients, list)
client = clients[0]
assert isinstance(client, dict)
assert client.get("name") == client_name
assert client.get("created_at") is not None
def test_delete_client(test_client: TestClient) -> None:
"""Test creating a client."""
client_name = "test"
create_response = create_client(
test_client,
client_name,
)
assert create_response.status_code == 200
resp = test_client.delete("/api/v1/clients/test")
assert resp.status_code == 200
resp = test_client.get("/api/v1/clients/test")
assert resp.status_code == 404
def test_add_secret(test_client: TestClient) -> None:
"""Test adding a secret to a client."""
client_name = "test"
client_publickey = make_test_key()
create_response = create_client(
test_client,
client_name,
client_publickey,
)
assert create_response.status_code == 200
secret_name = "mysecret"
secret_value = "shhhh"
data = {"name": secret_name, "secret": secret_value, "description": "A test secret"}
response = test_client.post("/api/v1/clients/test/secrets/", json=data)
assert response.status_code == 200
# Get it back
get_response = test_client.get("/api/v1/clients/test/secrets/mysecret")
assert get_response.status_code == 200
secret_body = get_response.json()
assert secret_body["name"] == data["name"]
assert secret_body["secret"] == data["secret"]
def test_delete_secret(test_client: TestClient) -> None:
"""Test deleting a secret."""
test_add_secret(test_client)
resp = test_client.delete("/api/v1/clients/test/secrets/mysecret")
assert resp.status_code == 200
get_response = test_client.get("/api/v1/clients/test/secrets/mysecret")
assert get_response.status_code == 404
def test_put_add_secret(test_client: TestClient) -> None:
"""Test adding secret via PUT."""
# Use the test_create_client function to create a client.
test_create_client(test_client)
secret_name = "mysecret"
secret_value = "shhhh"
data = {"name": secret_name, "secret": secret_value, "description": None}
response = test_client.put(
"/api/v1/clients/test/secrets/mysecret",
json={"value": secret_value},
)
assert response.status_code == 200
response_model = response.json()
del response_model["created_at"]
del response_model["updated_at"]
assert response_model == data
def test_put_update_secret(test_client: TestClient) -> None:
"""Test updating a client secret."""
test_add_secret(test_client)
new_value = "itsasecret"
update_response = test_client.put(
"/api/v1/clients/test/secrets/mysecret",
json={"value": new_value},
)
assert update_response.status_code == 200
expected = {"name": "mysecret", "secret": new_value}
response_model = update_response.json()
assert {
"name": response_model["name"],
"secret": response_model["secret"],
} == expected
# Ensure that the updated_at has been set.
assert "updated_at" in response_model
def test_audit_logging(test_client: TestClient) -> None:
"""Test audit logging."""
public_key = make_test_key()
create_client_resp = create_client(test_client, "test", public_key)
assert create_client_resp.status_code == 200
secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"}
for name, secret in secrets.items():
add_resp = test_client.post(
"/api/v1/clients/test/secrets/",
json={"name": name, "secret": secret},
)
assert add_resp.status_code == 200
# Fetch the entire client.
get_client_resp = test_client.get("/api/v1/clients/test")
assert get_client_resp.status_code == 200
# Fetch the audit log
audit_log_resp = test_client.get("/api/v1/audit/")
assert audit_log_resp.status_code == 200
audit_logs = audit_log_resp.json()
assert len(audit_logs) > 0
for entry in audit_logs:
# Let's try to reassemble the objects
audit_log = AuditView.model_validate(entry)
assert audit_log is not None
# def test_audit_log_filtering(
# session: Session, test_client: TestClient
# ) -> None:
# """Test audit log filtering."""
# # Create a lot of test data, but just manually.
# audit_log_amount = 150
# entries: list[AuditLog] = []
# for i in range(audit_log_amount):
# client_id = i % 5
# entries.append(
# AuditLog(
# operation="TEST",
# object_id=str(i),
# client_name=f"client-{client_id}",
# message="Test Message",
# )
# )
# session.add_all(entries)
# session.commit()
# # This should have generated a lot of audit messages
# audit_path = "/api/v1/audit/"
# audit_log_resp = test_client.get(audit_path)
# assert audit_log_resp.status_code == 200
# entries = audit_log_resp.json()
# assert len(entries) == 100 # We get 100 at a time
# audit_log_resp = test_client.get(
# audit_path, params={"offset": 100}
# )
# entries = audit_log_resp.json()
# assert len(entries) == 52 # There should be 50 + the two requests we made
# # Try to get a specific client
# # There should be 30 log entries for each client.
# audit_log_resp = test_client.get(
# audit_path, params={"filter_client": "client-1"}
# )
# entries = audit_log_resp.json()
# assert len(entries) == 30
def test_secret_invalidation(test_client: TestClient) -> None:
"""Test secret invalidation."""
initial_key = make_test_key()
create_client_resp = create_client(test_client, "test", initial_key)
assert create_client_resp.status_code == 200
secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"}
for name, secret in secrets.items():
add_resp = test_client.post(
"/api/v1/clients/test/secrets/",
json={"name": name, "secret": secret},
)
assert add_resp.status_code == 200
# Update the public-key. This should cause all secrets to be invalidated
# and no longer associated with a client.
new_key = make_test_key()
update_resp = test_client.post(
"/api/v1/clients/test/public-key",
json={"public_key": new_key},
)
assert update_resp.status_code == 200
# Fetch the client. The list of secrets should be empty.
get_resp = test_client.get("/api/v1/clients/test")
assert get_resp.status_code == 200
client = get_resp.json()
secrets = client.get("secrets")
assert bool(secrets) is False
def test_client_default_policies(
test_client: TestClient,
) -> None:
"""Test client policies."""
public_key = make_test_key()
resp = create_client(test_client, "test")
assert resp.status_code == 200
# Fetch policies, should return *
resp = test_client.get("/api/v1/clients/test/policies/")
assert resp.status_code == 200
policies = resp.json()
assert policies["sources"] == ["0.0.0.0/0", "::/0"]
def test_client_policy_update_one(test_client: TestClient) -> None:
"""Update client policy with single policy."""
public_key = make_test_key()
resp = create_client(test_client, "test", public_key)
assert resp.status_code == 200
policy = ["192.0.2.1"]
resp = test_client.put("/api/v1/clients/test/policies/", json={"sources": policy})
assert resp.status_code == 200
resp = test_client.get("/api/v1/clients/test/policies/")
assert resp.status_code == 200
policies = resp.json()
assert policies["sources"] == policy
def test_client_policy_update_advanced(test_client: TestClient) -> None:
"""Test other policy update scenarios."""
public_key = make_test_key()
resp = create_client(test_client, "test", public_key)
assert resp.status_code == 200
policy = ["192.0.2.1", "198.18.0.0/24"]
resp = test_client.put("/api/v1/clients/test/policies/", json={"sources": policy})
assert resp.status_code == 200
resp = test_client.get("/api/v1/clients/test/policies/")
assert resp.status_code == 200
policies = resp.json()
assert "192.0.2.1" in policies["sources"]
assert "198.18.0.0/24" in policies["sources"]
# Try to set it to something incorrect
policy = ["obviosly_wrong"]
resp = test_client.put("/api/v1/clients/test/policies/", json={"sources": policy})
assert resp.status_code == 422
# Check that the old value is still there
resp = test_client.get("/api/v1/clients/test/policies/")
assert resp.status_code == 200
policies = resp.json()
assert "192.0.2.1" in policies["sources"]
assert "198.18.0.0/24" in policies["sources"]
# Clear the policies
#
def test_client_policy_update_unset(test_client: TestClient) -> None:
"""Test clearing the client policy."""
public_key = make_test_key()
resp = create_client(test_client, "test", public_key)
assert resp.status_code == 200
policy = ["192.0.2.1", "198.18.0.0/24"]
resp = test_client.put("/api/v1/clients/test/policies/", json={"sources": policy})
assert resp.status_code == 200
policies = resp.json()
assert "192.0.2.1" in policies["sources"]
assert "198.18.0.0/24" in policies["sources"]
# Now we clear the policies
resp = test_client.put("/api/v1/clients/test/policies/", json={"sources": []})
assert resp.status_code == 200
policies = resp.json()
assert policies["sources"] == ["0.0.0.0/0", "::/0"]
def test_client_update(test_client: TestClient) -> None:
"""Test generic update of a client."""
public_key = make_test_key()
resp = create_client(test_client, "test", public_key, "PRE")
assert resp.status_code == 200
resp = test_client.get("/api/v1/clients/test")
assert resp.status_code == 200
client_data = resp.json()
assert client_data["description"] == "PRE"
# Update the description
new_client_data = {
"name": "test",
"description": "POST",
"public_key": client_data["public_key"],
}
resp = test_client.put("/api/v1/clients/test", json=new_client_data)
assert resp.status_code == 200
client_data = resp.json()
assert client_data["description"] == "POST"
resp = test_client.get("/api/v1/clients/test")
assert resp.status_code == 200
client_data = resp.json()
assert client_data["description"] == "POST"
def test_get_secret_list(test_client: TestClient) -> None:
"""Test the secret to client map view."""
# Make 4 clients
for x in range(4):
public_key = make_test_key()
create_client(test_client, f"client-{x}", public_key)
# Create a secret that only this client has.
resp = test_client.put(
f"/api/v1/clients/client-{x}/secrets/client-{x}", json={"value": "SECRET"}
)
assert resp.status_code == 200
# Create a secret that all of them have.
resp = test_client.put(
f"/api/v1/clients/client-{x}/secrets/commonsecret", json={"value": "SECRET"}
)
assert resp.status_code == 200
# Get the secret list
resp = test_client.get("/api/v1/secrets/")
assert resp.status_code == 200
data = resp.json()
assert isinstance(data, list)
assert len(data) == 5
for entry in data:
if entry["name"] == "commonsecret":
assert len(entry["clients"]) == 4
else:
assert len(entry["clients"]) == 1
assert entry["clients"][0] == entry["name"]
def test_get_secret_clients(test_client: TestClient) -> None:
"""Get the clients for a single secret."""
for x in range(4):
public_key = make_test_key()
create_client(test_client, f"client-{x}", public_key)
# Create a secret that every second of them have.
if x % 2 == 1:
continue
resp = test_client.put(
f"/api/v1/clients/client-{x}/secrets/commonsecret", json={"value": "SECRET"}
)
assert resp.status_code == 200
resp = test_client.get("/api/v1/secrets/commonsecret")
assert resp.status_code == 200
data = resp.json()
assert data["name"] == "commonsecret"
assert "client-0" in data["clients"]
assert "client-1" not in data["clients"]
assert len(data["clients"]) == 2
def test_searching(test_client: TestClient) -> None:
"""Test searching."""
for x in range(4):
# Create four clients
create_client(test_client, f"client-{x}")
# Create one with a different name.
create_client(test_client, "othername")
# Search for a specific one.
resp = test_client.get("/api/v1/clients/", params={"name": "othername"})
assert resp.status_code == 200
result = resp.json()
assert result["total_results"] == 1
assert result["clients"][0]["name"] == "othername"
client_id = result["clients"][0]["id"]
# Search by ID
resp = test_client.get("/api/v1/clients/", params={"id": client_id})
assert resp.status_code == 200
result = resp.json()
assert result["total_results"] == 1
assert result["clients"][0]["name"] == "othername"
# Search for the four similarly named ones
resp = test_client.get("/api/v1/clients/", params={"name__like": "client-%"})
assert resp.status_code == 200
result = resp.json()
assert result["total_results"] == 4
assert str(result["clients"][0]["name"]).startswith("client-")
def test_operations_with_id(test_client: TestClient) -> None:
"""Test operations using ID instead of name."""
create_client(test_client, "test")
resp = test_client.get("/api/v1/clients/")
assert resp.status_code == 200
data = resp.json()
client = data["clients"][0]
client_id = client["id"]
resp = test_client.get(f"/api/v1/clients/{client_id}")
assert resp.status_code == 200
data = resp.json()
assert data["name"] == "test"
def test_write_audit_log(test_client: TestClient) -> None:
"""Test writing to the audit log."""
params = {
"subsystem": "backend",
"operation": "read",
"message": "Test Message"
}
resp = test_client.post("/api/v1/audit", json=params)
assert resp.status_code == 200
resp = test_client.get("/api/v1/audit")
assert resp.status_code == 200
data = resp.json()
entry = data[0]
for key, value in params.items():
assert entry[key] == value

View File

@ -1,2 +0,0 @@
[pytest]
asyncio_mode = auto

View File

@ -1 +0,0 @@

View File

@ -1,179 +0,0 @@
import asyncio
import pytest
import uuid
import asyncssh
import tempfile
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock
from ipaddress import IPv4Network, IPv6Network
from sshecret_sshd.ssh_server import run_ssh_server
from sshecret_sshd.settings import ClientRegistrationSettings
from .types import (
Client,
ClientKey,
ClientRegistry,
SshServerFixtureFun,
SshServerFixture,
)
@pytest.fixture(scope="function")
def client_registry() -> ClientRegistry:
clients = {}
secrets = {}
async def add_client(
name: str,
secret_names: list[str] | None = None,
policies: list[str] | None = None,
) -> str:
private_key = asyncssh.generate_private_key("ssh-rsa")
public_key = private_key.export_public_key()
clients[name] = ClientKey(name, private_key, public_key.decode().rstrip())
secrets.update({(name, s): f"mocked-secret-{s}" for s in (secret_names or [])})
return clients[name]
return {"clients": clients, "secrets": secrets, "add_client": add_client}
@pytest.fixture(scope="function")
async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
backend = MagicMock()
clients_data = client_registry["clients"]
secrets_data = client_registry["secrets"]
async def get_client(name: str) -> Client | None:
client_key = clients_data.get(name)
if client_key:
response_model = Client(
id=uuid.uuid4(),
name=name,
description=f"Mock client {name}",
public_key=client_key.public_key,
secrets=[s for (c, s) in secrets_data if c == name],
policies=[IPv4Network("0.0.0.0/0"), IPv6Network("::/0")],
)
return response_model
return None
async def get_client_secret(name: str, secret_name: str) -> str | None:
secret = secrets_data.get((name, secret_name), None)
return secret
async def create_client(name: str, public_key: str) -> None:
"""Create client.
This only works if you register a client called template first.
Otherwise we can't test this...
"""
if "template" not in clients_data:
raise RuntimeError(
"Error, must have a client called template for this to work."
)
clients_data[name] = clients_data["template"]
for secret_key, secret in secrets_data.items():
s_client, secret_name = secret_key
if s_client != "template":
continue
secrets_data[(name, secret_name)] = secret
backend.get_client = AsyncMock(side_effect=get_client)
backend.get_client_secret = AsyncMock(side_effect=get_client_secret)
backend.create_client = AsyncMock(side_effect=create_client)
# Make sure backend.audit(...) returns the audit mock
audit = MagicMock()
audit.write = MagicMock()
backend.audit = MagicMock(return_value=audit)
return backend
@pytest.fixture(scope="function")
async def ssh_server(
mock_backend: MagicMock, unused_tcp_port: int
) -> SshServerFixtureFun:
port = unused_tcp_port
private_key = asyncssh.generate_private_key("ssh-ed25519")
key_str = private_key.export_private_key()
with tempfile.NamedTemporaryFile("w+", delete=True) as key_file:
key_file.write(key_str.decode())
key_file.flush()
registration_settings = ClientRegistrationSettings(
enabled=True, allow_from=[IPv4Network("0.0.0.0/0")]
)
server = await run_ssh_server(
backend=mock_backend,
listen_address="localhost",
port=port,
keys=[key_file.name],
registration=registration_settings,
enable_ping_command=True,
)
await asyncio.sleep(0.1)
yield server, port
server.close()
await server.wait_closed()
@pytest.fixture(scope="function")
def ssh_command_runner(ssh_server: SshServerFixture, client_registry: ClientRegistry):
"""Run a single command.
Tricky typing!
"""
_, port = ssh_server
async def run_command_as(name: str, command: str):
client_key = client_registry["clients"][name]
conn = await asyncssh.connect(
"127.0.0.1",
port=port,
username=name,
client_keys=[client_key.private_key],
known_hosts=None,
)
try:
result = await conn.run(command)
return result
finally:
conn.close()
await conn.wait_closed()
return run_command_as
@pytest.fixture(scope="function")
def ssh_session(ssh_server: SshServerFixture, client_registry: ClientRegistry):
"""Yield an interactive session."""
_, port = ssh_server
@asynccontextmanager
async def run_process_as(name: str, command: str, client: str | None = None):
if not client:
client = name
client_key = client_registry["clients"][client]
conn = await asyncssh.connect(
"127.0.0.1",
port=port,
username=name,
client_keys=[client_key.private_key],
known_hosts=None,
)
try:
async with conn.create_process(command) as process:
yield process
finally:
conn.close()
await conn.wait_closed()
return run_process_as

View File

@ -1,33 +0,0 @@
"""Test get secret."""
import pytest
from .types import ClientRegistry, CommandRunner
@pytest.mark.asyncio
async def test_get_secret(
ssh_command_runner: CommandRunner, client_registry: ClientRegistry
) -> None:
"""Test that we can get a secret."""
await client_registry["add_client"]("test-client", ["mysecret"])
result = await ssh_command_runner("test-client", "get_secret mysecret")
assert result.stdout is not None
assert isinstance(result.stdout, str)
assert result.stdout.rstrip() == "mocked-secret-mysecret"
@pytest.mark.asyncio
async def test_invalid_secret_name(
ssh_command_runner: CommandRunner, client_registry: ClientRegistry
) -> None:
"""Test getting an invalid secret name."""
await client_registry["add_client"]("test-client")
result = await ssh_command_runner("test-client", "get_secret mysecret")
assert result.exit_status == 1
assert result.stderr == "Error: No secret available with the given name."

View File

@ -1,18 +0,0 @@
import pytest
from .types import ClientRegistry, CommandRunner
@pytest.mark.asyncio
async def test_ping_command(
ssh_command_runner: CommandRunner, client_registry: ClientRegistry
) -> None:
# Register a test client with default policies and no secrets
await client_registry["add_client"]("test-pinger")
result = await ssh_command_runner("test-pinger", "ping")
assert result.exit_status == 0
assert result.stdout is not None
assert isinstance(result.stdout, str)
assert result.stdout.rstrip() == "PONG"

View File

@ -1,40 +0,0 @@
"""Test registration."""
import pytest
from .types import ClientRegistry, CommandRunner, ProcessRunner
@pytest.mark.asyncio
async def test_register_client(
ssh_session: ProcessRunner,
ssh_command_runner: CommandRunner,
client_registry: ClientRegistry,
) -> None:
"""Test client registration."""
await client_registry["add_client"]("template", ["testsecret"])
public_key = client_registry["clients"]["template"].public_key.rstrip() + "\n"
async with ssh_session("newclient", "register", "template") as session:
maxlines = 10
linenum = 0
found = False
while linenum < maxlines:
line = await session.stdout.readline()
if "Enter public key" in line:
found = True
break
assert found is True
session.stdin.write(public_key)
result = await session.stdout.readline()
assert "OK" in result
# Test that we can connect
result = await ssh_command_runner("newclient", "get_secret testsecret")
assert result.stdout is not None
assert isinstance(result.stdout, str)
assert result.stdout.rstrip() == "mocked-secret-testsecret"

View File

@ -1,63 +0,0 @@
"""Types for the various test properties."""
import uuid
from datetime import datetime
from dataclasses import dataclass, field
from ipaddress import IPv4Network, IPv6Network
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any, Protocol, TypedDict, AsyncContextManager
import asyncssh
SshServerFixture = tuple[str, int]
SshServerFixtureFun = AsyncGenerator[tuple[asyncssh.SSHAcceptor, int], None]
@dataclass
class Client:
"""Mock client."""
id: uuid.UUID
name: str
description: str | None
public_key: str
secrets: list[str]
policies: list[IPv4Network | IPv6Network]
created_at: datetime = field(default_factory=datetime.now)
updated_at: datetime = field(default_factory=datetime.now)
@dataclass
class ClientKey:
name: str
private_key: asyncssh.SSHKey
public_key: str
class AddClientFun(Protocol):
"""Add client function."""
def __call__(
self,
name: str,
secret_names: list[str] | None = None,
policies: list[str] | None = None,
) -> Awaitable[str]: ...
class ProcessRunner(Protocol):
"""Process runner typing."""
def __call__(
self, name: str, command: str, client: str | None = None
) -> AsyncContextManager[asyncssh.SSHClientProcess[Any]]: ...
class ClientRegistry(TypedDict):
"""Client registry typing."""
clients: dict[str, ClientKey]
secrets: dict[tuple[str, str], str]
add_client: AddClientFun
CommandRunner = Callable[[str, str], Awaitable[asyncssh.SSHCompletedProcess]]