Centralize testing

This commit is contained in:
2025-05-11 11:22:00 +02:00
parent b34c49d3e3
commit d3d99775d9
19 changed files with 565 additions and 4 deletions

1
tests/__init__.py Normal file
View File

@ -0,0 +1 @@

View File

View File

@ -0,0 +1,27 @@
"""Client helpers."""
from dataclasses import dataclass
from cryptography.hazmat.primitives.asymmetric import rsa
from sshecret.crypto import generate_private_key, generate_public_key_string
@dataclass
class ClientData:
"""Test client."""
name: str
private_key: rsa.RSAPrivateKey
@property
def public_key(self) -> str:
"""Return public key as string."""
return generate_public_key_string(self.private_key.public_key())
def create_test_client(name: str) -> ClientData:
"""Create test client."""
return ClientData(
name=name,
private_key=generate_private_key()
)

View File

@ -0,0 +1,209 @@
"""Test library.
Strategy:
We start by spawning the backend server, and create two test keys.
Then we spawn the sshd and the admin api.
"""
import asyncio
import asyncssh
import secrets
import tempfile
from contextlib import asynccontextmanager
from pathlib import Path
import httpx
import pytest
import pytest_asyncio
import uvicorn
from sshecret.backend import SshecretBackend
from sshecret.crypto import (
generate_private_key,
generate_public_key_string,
write_private_key,
)
from sshecret_admin.core.app import create_admin_app
from sshecret_admin.core.settings import AdminServerSettings
from sshecret_backend.app import create_backend_app
from sshecret_backend.settings import BackendSettings
from sshecret_backend.testing import create_test_token
from sshecret_sshd.settings import ServerSettings
from sshecret_sshd.ssh_server import start_sshecret_sshd
from .clients import ClientData
from .helpers import create_sshd_server_key, create_test_admin_user, in_tempdir
from .types import PortFactory, TestPorts
TEST_SCOPE = "function"
LOOP_SCOPE = "function"
def make_test_key() -> str:
"""Generate a test key."""
private_key = generate_private_key()
return generate_public_key_string(private_key.public_key())
@pytest.fixture(name="test_ports", scope="session")
def generate_test_ports(unused_tcp_port_factory: PortFactory) -> TestPorts:
"""Generate the test ports."""
test_ports = TestPorts(
backend=unused_tcp_port_factory(),
admin=unused_tcp_port_factory(),
sshd=unused_tcp_port_factory(),
)
print(f"{test_ports=!r}")
return test_ports
@pytest_asyncio.fixture(scope=TEST_SCOPE, name="backend_server", loop_scope=LOOP_SCOPE)
async def run_backend_server(test_ports: TestPorts):
"""Run the backend server."""
port = test_ports.backend
with tempfile.TemporaryDirectory() as tmp_dir:
backend_work_path = Path(tmp_dir)
db_file = backend_work_path / "backend.db"
backend_settings = BackendSettings(database=str(db_file.absolute()))
backend_app = create_backend_app(backend_settings)
token = create_test_token(backend_settings)
config = uvicorn.Config(app=backend_app, port=port, loop="asyncio")
server = uvicorn.Server(config=config)
server_task = asyncio.create_task(server.serve())
await asyncio.sleep(0.1)
backend_url = f"http://127.0.0.1:{port}"
yield (backend_url, token)
server.should_exit = True
await server_task
@pytest_asyncio.fixture(scope=TEST_SCOPE, name="admin_server", loop_scope=LOOP_SCOPE)
async def run_admin_server(test_ports: TestPorts, backend_server: tuple[str, str]):
"""Run admin server."""
backend_url, backend_token = backend_server
secret_key = secrets.token_urlsafe(32)
port = test_ports.admin
with in_tempdir() as admin_work_path:
admin_db = admin_work_path / "ssh_admin.db"
admin_settings = AdminServerSettings.model_validate(
{
"sshecret_backend_url": backend_url,
"backend_token": backend_token,
"secret_key": secret_key,
"listen_address": "127.0.0.1",
"port": port,
"database": str(admin_db.absolute()),
"password_manager_directory": str(admin_work_path.absolute()),
}
)
admin_app = create_admin_app(admin_settings)
config = uvicorn.Config(app=admin_app, port=port, loop="asyncio")
server = uvicorn.Server(config=config)
server_task = asyncio.create_task(server.serve())
await asyncio.sleep(0.1)
admin_url = f"http://127.0.0.1:{port}"
admin_password = secrets.token_urlsafe(10)
create_test_admin_user(admin_settings, "test", admin_password)
await asyncio.sleep(0.1)
yield (admin_url, ("test", admin_password))
server.should_exit = True
await server_task
@pytest_asyncio.fixture(scope=TEST_SCOPE, name="ssh_server", loop_scope=LOOP_SCOPE)
async def start_ssh_server(test_ports: TestPorts, backend_server: tuple[str, str]):
"""Run ssh server."""
backend_url, backend_token = backend_server
port = test_ports.sshd
with in_tempdir() as ssh_workdir:
create_sshd_server_key(ssh_workdir)
sshd_server_settings = ServerSettings.model_validate(
{
"sshecret_backend_url": backend_url,
"backend_token": backend_token,
"listen_address": "",
"port": port,
"registration": {"enabled": True, "allow_from": "0.0.0.0/0"},
"enable_ping_command": True,
}
)
ssh_server = await start_sshecret_sshd(sshd_server_settings)
await asyncio.sleep(0.1)
print(f"Started sshd on port {port}")
yield port
ssh_server.close()
await ssh_server.wait_closed()
@pytest_asyncio.fixture(scope=TEST_SCOPE, name="backend_client", loop_scope=LOOP_SCOPE)
async def create_backend_http_client(backend_server: tuple[str, str]):
"""Create a test client."""
backend_url, backend_token = backend_server
print(f"Creating backend client towards {backend_url}")
async with httpx.AsyncClient(
base_url=backend_url, headers={"X-API-Token": backend_token}
) as client:
yield client
@pytest_asyncio.fixture(name="backend_api")
async def get_test_backend_api(backend_server: tuple[str, str]) -> SshecretBackend:
"""Get the backend API."""
backend_url, backend_token = backend_server
return SshecretBackend(backend_url, backend_token)
@pytest.fixture(scope=TEST_SCOPE)
def ssh_command_runner(ssh_server: int, tmp_path: Path):
"""Run a single command on the ssh server."""
port = ssh_server
async def run_command_as(test_client: ClientData, command: str):
private_key_file = tmp_path / f"id_{test_client.name}"
write_private_key(test_client.private_key, private_key_file)
conn = await asyncssh.connect(
"127.0.0.1",
port=port,
username=test_client.name,
client_keys=[str(private_key_file)],
known_hosts=None,
)
try:
result = await conn.run(command)
return result
finally:
conn.close()
await conn.wait_closed()
return run_command_as
@pytest.fixture(name="ssh_session", scope=TEST_SCOPE)
def create_ssh_session(ssh_server: int, tmp_path: Path):
"""Create a ssh Session."""
port = ssh_server
@asynccontextmanager
async def run_process(test_client: ClientData, command: str):
private_key_file = tmp_path / f"id_{test_client.name}"
write_private_key(test_client.private_key, private_key_file)
conn = await asyncssh.connect(
"127.0.0.1",
port=port,
username=test_client.name,
client_keys=[str(private_key_file)],
known_hosts=None,
)
try:
async with conn.create_process(command) as process:
yield process
finally:
conn.close()
await conn.wait_closed()
return run_process

View File

@ -0,0 +1,41 @@
"""Helper functions."""
import os
import tempfile
from collections.abc import Iterator
from contextlib import contextmanager
from pathlib import Path
from sqlmodel import Session, create_engine
from sshecret.crypto import generate_private_key, write_private_key
from sshecret_admin.auth.authentication import hash_password
from sshecret_admin.auth.models import User, init_db
from sshecret_admin.core.settings import AdminServerSettings
def create_test_admin_user(settings: AdminServerSettings, username: str, password: str) -> None:
"""Create a test admin user."""
hashed_password = hash_password(password)
engine = create_engine(settings.admin_db)
init_db(engine)
with Session(engine) as session:
user = User(username=username, hashed_password=hashed_password)
session.add(user)
session.commit()
def create_sshd_server_key(sshd_path: Path) -> Path:
"""Create a ssh key at a general"""
server_file = sshd_path / "ssh_host_key"
private_key = generate_private_key()
write_private_key(private_key, server_file)
return server_file
@contextmanager
def in_tempdir() -> Iterator[Path]:
"""Run in a temporary directory."""
curdir = os.getcwd()
with tempfile.TemporaryDirectory() as temp_directory:
temp_path = Path(temp_directory)
os.chdir(temp_directory)
yield temp_path
os.chdir(curdir)

View File

@ -0,0 +1,55 @@
"""Tests of the admin interface."""
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
import pytest
import httpx
class TestAdminAPI:
"""Tests of the Admin REST API."""
@pytest.mark.asyncio
async def test_health_check(
self, admin_server: tuple[str, tuple[str, str]]
) -> None:
"""Test admin login."""
async with self.http_client(admin_server, False) as client:
resp = await client.get("/health")
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_admin_login(self, admin_server: tuple[str, tuple[str, str]]) -> None:
"""Test admin login."""
async with self.http_client(admin_server, False) as client:
resp = await client.get("api/v1/clients/")
assert resp.status_code == 401
async with self.http_client(admin_server, True) as client:
resp = await client.get("api/v1/clients/")
assert resp.status_code == 200
@asynccontextmanager
async def http_client(
self, admin_server: tuple[str, tuple[str, str]], authenticate: bool = True
) -> AsyncIterator[httpx.AsyncClient]:
"""Run a client towards the admin rest api."""
admin_url, credentials = admin_server
username, password = credentials
headers: dict[str, str] | None = None
if authenticate:
async with httpx.AsyncClient(base_url=admin_url) as client:
response = await client.post(
"api/v1/token", data={"username": username, "password": password}
)
assert response.status_code == 200
data = response.json()
assert "access_token" in data
token = data["access_token"]
headers = {"Authorization": f"Bearer {token}"}
async with httpx.AsyncClient(base_url=admin_url, headers=headers) as client:
yield client

View File

@ -0,0 +1,58 @@
"""Test backend.
These tests just ensure that the backend works well enough for us to run the
rest of the tests.
"""
import pytest
import httpx
from sshecret.backend import SshecretBackend
from .clients import create_test_client
@pytest.mark.asyncio
async def test_healthcheck(backend_client: httpx.AsyncClient) -> None:
"""Test healthcheck command."""
resp = await backend_client.get("/health")
assert resp.status_code == 200
assert resp.json() == {"status": "LIVE"}
@pytest.mark.asyncio
async def test_create_client(backend_api: SshecretBackend) -> None:
"""Test creating a client."""
test_client = create_test_client("test")
await backend_api.create_client("test", test_client.public_key, "A test client")
# fetch the list of clients.
clients = await backend_api.get_clients()
assert clients is not None
assert len(clients) == 1
assert clients[0].name == "test"
assert clients[0].public_key == test_client.public_key
async def test_create_secret(backend_api: SshecretBackend) -> None:
"""Test creating secrets."""
test_client = create_test_client("test")
await backend_api.create_client("test", test_client.public_key, "A test client")
await backend_api.create_client_secret("test", "mysecret", "encrypted_secret")
secrets = await backend_api.get_secrets()
assert len(secrets) == 1
assert secrets[0].name == "mysecret"
secret_to_client = await backend_api.get_secret("mysecret")
assert secret_to_client is not None
assert secret_to_client.name == "mysecret"
assert "test" in secret_to_client.clients
secret = await backend_api.get_client_secret("test", "mysecret")
assert secret is not None
assert secret == "encrypted_secret"

View File

@ -0,0 +1,139 @@
"""Tests where the sshd is the main consumer.
This essentially also tests parts of the admin API.
"""
from contextlib import asynccontextmanager
from typing import AsyncIterator
import os
import httpx
import pytest
from sshecret.crypto import decode_string
from sshecret.backend.api import SshecretBackend
from .clients import create_test_client, ClientData
from .types import CommandRunner, ProcessRunner
class TestSshd:
"""Class based tests.
This allows us to create small helpers.
"""
@pytest.mark.asyncio
async def test_get_secret(
self, backend_api: SshecretBackend, ssh_command_runner: CommandRunner
) -> None:
"""Test get secret flow."""
test_client = create_test_client("testclient")
await backend_api.create_client(
"testclient", test_client.public_key, "A test client"
)
await backend_api.create_client_secret("testclient", "testsecret", "bogus")
response = await ssh_command_runner(test_client, "get_secret testsecret")
assert response.exit_status == 0
assert response.stdout is not None
assert isinstance(response.stdout, str)
assert response.stdout.rstrip() == "bogus"
@pytest.mark.asyncio
async def test_register(
self, backend_api: SshecretBackend, ssh_session: ProcessRunner
) -> None:
"""Test registration."""
await self.register_client("new_client", ssh_session)
# Check that the client is created.
clients = await backend_api.get_clients()
assert len(clients) == 1
client = clients[0]
assert client.name == "new_client"
async def register_client(
self, name: str, ssh_session: ProcessRunner
) -> ClientData:
"""Register client."""
test_client = create_test_client(name)
async with ssh_session(test_client, "register") as session:
maxlines = 10
linenum = 0
found = False
while linenum < maxlines:
line = await session.stdout.readline()
if "Enter public key" in line:
found = True
break
assert found is True
session.stdin.write(test_client.public_key + "\n")
result = await session.stdout.readline()
assert "OK" in result
await session.wait()
return test_client
class TestSshdIntegration(TestSshd):
"""Integration tests."""
@pytest.mark.asyncio
async def test_end_to_end(
self,
backend_api: SshecretBackend,
admin_server: tuple[str, tuple[str, str]],
ssh_session: ProcessRunner,
ssh_command_runner: CommandRunner,
) -> None:
"""Test end to end."""
test_client = await self.register_client("myclient", ssh_session)
url, credentials = admin_server
username, password = credentials
async with self.admin_client(url, username, password) as http_client:
resp = await http_client.get("api/v1/clients/")
assert resp.status_code == 200
clients = resp.json()
assert len(clients) == 1
assert clients[0]["name"] == "myclient"
create_model = {
"name": "mysecret",
"clients": ["myclient"],
"value": "mypassword",
}
resp = await http_client.post("api/v1/secrets/", json=create_model)
assert resp.status_code == 200
# Login via ssh to fetch the decrypted value.
ssh_output = await ssh_command_runner(test_client, "get_secret mysecret")
assert ssh_output.stdout is not None
assert isinstance(ssh_output.stdout, str)
encrypted = ssh_output.stdout.rstrip()
decrypted = decode_string(encrypted, test_client.private_key)
assert decrypted == "mypassword"
async def login(self, url: str, username: str, password: str) -> str:
"""Login and get token."""
api_url = os.path.join(url, "api/v1", "token")
client = httpx.AsyncClient()
response = await client.post(
api_url, data={"username": username, "password": password}
)
assert response.status_code == 200
data = response.json()
assert "access_token" in data
assert isinstance(data["access_token"], str)
return str(data["access_token"])
@asynccontextmanager
async def admin_client(
self, url: str, username: str, password: str
) -> AsyncIterator[httpx.AsyncClient]:
"""Create an admin client."""
token = await self.login(url, username, password)
async with httpx.AsyncClient(
base_url=url, headers={"Authorization": f"Bearer {token}"}
) as client:
yield client

View File

@ -0,0 +1,29 @@
"""Typings."""
import asyncssh
from typing import Any, AsyncContextManager, Protocol
from dataclasses import dataclass
from collections.abc import Callable, Awaitable
from .clients import ClientData
PortFactory = Callable[[], int]
@dataclass
class TestPorts:
"""Test port dataclass."""
backend: int
admin: int
sshd: int
CommandRunner = Callable[[ClientData, str], Awaitable[asyncssh.SSHCompletedProcess]]
class ProcessRunner(Protocol):
"""Process runner typing."""
def __call__(self, test_client: ClientData, command: str) -> AsyncContextManager[asyncssh.SSHClientProcess[Any]]:
...

View File

View File

View File

@ -0,0 +1,538 @@
"""Tests of the backend api using pytest."""
import logging
from pathlib import Path
from httpx import Response
import pytest
from fastapi.testclient import TestClient
from sshecret.crypto import generate_private_key, generate_public_key_string
from sshecret_backend.app import create_backend_app
from sshecret_backend.testing import create_test_token
from sshecret_backend.view_models import AuditView
from sshecret_backend.settings import BackendSettings
LOG = logging.getLogger()
handler = logging.StreamHandler()
formatter = logging.Formatter("'%(asctime)s - %(levelname)s - %(message)s'")
handler.setFormatter(formatter)
LOG.addHandler(handler)
LOG.setLevel(logging.DEBUG)
def make_test_key() -> str:
"""Generate a test key."""
private_key = generate_private_key()
return generate_public_key_string(private_key.public_key())
def create_client(
test_client: TestClient,
name: str,
public_key: str | None = None,
description: str | None = None,
) -> Response:
"""Create client."""
if not public_key:
public_key = make_test_key()
data = {
"name": name,
"public_key": public_key,
}
if description:
data["description"] = description
create_response = test_client.post("/api/v1/clients", json=data)
return create_response
@pytest.fixture(name="test_client")
def create_client_fixture(tmp_path: Path):
"""Test client fixture."""
db_file = tmp_path / "backend.db"
print(f"DB File: {db_file.absolute()}")
settings = BackendSettings(database=str(db_file.absolute()))
app = create_backend_app(settings)
token = create_test_token(settings)
test_client = TestClient(app, headers={"X-API-Token": token})
yield test_client
def test_missing_token(test_client: TestClient) -> None:
"""Test logging in with missing token."""
# Save headers
old_headers = test_client.headers
test_client.headers = {}
response = test_client.get("/api/v1/clients/", headers={})
assert response.status_code == 422
test_client.headers = old_headers
def test_incorrect_token(test_client: TestClient) -> None:
"""Test logging in with missing token."""
response = test_client.get("/api/v1/clients/", headers={"X-API-Token": "WRONG"})
assert response.status_code == 401
def test_with_token(test_client: TestClient) -> None:
"""Test with a valid token."""
response = test_client.get("/api/v1/clients/")
assert response.status_code == 200
data = response.json()
assert data["total_results"] == 0
def test_create_client(test_client: TestClient) -> None:
"""Test creating a client."""
client_name = "test"
client_publickey = make_test_key()
create_response = create_client(test_client, client_name, client_publickey)
assert create_response.status_code == 200
response = test_client.get("/api/v1/clients")
assert response.status_code == 200
clients_result = response.json()
clients = clients_result["clients"]
assert isinstance(clients, list)
client = clients[0]
assert isinstance(client, dict)
assert client.get("name") == client_name
assert client.get("created_at") is not None
def test_delete_client(test_client: TestClient) -> None:
"""Test creating a client."""
client_name = "test"
create_response = create_client(
test_client,
client_name,
)
assert create_response.status_code == 200
resp = test_client.delete("/api/v1/clients/test")
assert resp.status_code == 200
resp = test_client.get("/api/v1/clients/test")
assert resp.status_code == 404
def test_add_secret(test_client: TestClient) -> None:
"""Test adding a secret to a client."""
client_name = "test"
client_publickey = make_test_key()
create_response = create_client(
test_client,
client_name,
client_publickey,
)
assert create_response.status_code == 200
secret_name = "mysecret"
secret_value = "shhhh"
data = {"name": secret_name, "secret": secret_value, "description": "A test secret"}
response = test_client.post("/api/v1/clients/test/secrets/", json=data)
assert response.status_code == 200
# Get it back
get_response = test_client.get("/api/v1/clients/test/secrets/mysecret")
assert get_response.status_code == 200
secret_body = get_response.json()
assert secret_body["name"] == data["name"]
assert secret_body["secret"] == data["secret"]
def test_delete_secret(test_client: TestClient) -> None:
"""Test deleting a secret."""
test_add_secret(test_client)
resp = test_client.delete("/api/v1/clients/test/secrets/mysecret")
assert resp.status_code == 200
get_response = test_client.get("/api/v1/clients/test/secrets/mysecret")
assert get_response.status_code == 404
def test_put_add_secret(test_client: TestClient) -> None:
"""Test adding secret via PUT."""
# Use the test_create_client function to create a client.
test_create_client(test_client)
secret_name = "mysecret"
secret_value = "shhhh"
data = {"name": secret_name, "secret": secret_value, "description": None}
response = test_client.put(
"/api/v1/clients/test/secrets/mysecret",
json={"value": secret_value},
)
assert response.status_code == 200
response_model = response.json()
del response_model["created_at"]
del response_model["updated_at"]
assert response_model == data
def test_put_update_secret(test_client: TestClient) -> None:
"""Test updating a client secret."""
test_add_secret(test_client)
new_value = "itsasecret"
update_response = test_client.put(
"/api/v1/clients/test/secrets/mysecret",
json={"value": new_value},
)
assert update_response.status_code == 200
expected = {"name": "mysecret", "secret": new_value}
response_model = update_response.json()
assert {
"name": response_model["name"],
"secret": response_model["secret"],
} == expected
# Ensure that the updated_at has been set.
assert "updated_at" in response_model
def test_audit_logging(test_client: TestClient) -> None:
"""Test audit logging."""
public_key = make_test_key()
create_client_resp = create_client(test_client, "test", public_key)
assert create_client_resp.status_code == 200
secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"}
for name, secret in secrets.items():
add_resp = test_client.post(
"/api/v1/clients/test/secrets/",
json={"name": name, "secret": secret},
)
assert add_resp.status_code == 200
# Fetch the entire client.
get_client_resp = test_client.get("/api/v1/clients/test")
assert get_client_resp.status_code == 200
# Fetch the audit log
audit_log_resp = test_client.get("/api/v1/audit/")
assert audit_log_resp.status_code == 200
audit_logs = audit_log_resp.json()
assert len(audit_logs) > 0
for entry in audit_logs:
# Let's try to reassemble the objects
audit_log = AuditView.model_validate(entry)
assert audit_log is not None
# def test_audit_log_filtering(
# session: Session, test_client: TestClient
# ) -> None:
# """Test audit log filtering."""
# # Create a lot of test data, but just manually.
# audit_log_amount = 150
# entries: list[AuditLog] = []
# for i in range(audit_log_amount):
# client_id = i % 5
# entries.append(
# AuditLog(
# operation="TEST",
# object_id=str(i),
# client_name=f"client-{client_id}",
# message="Test Message",
# )
# )
# session.add_all(entries)
# session.commit()
# # This should have generated a lot of audit messages
# audit_path = "/api/v1/audit/"
# audit_log_resp = test_client.get(audit_path)
# assert audit_log_resp.status_code == 200
# entries = audit_log_resp.json()
# assert len(entries) == 100 # We get 100 at a time
# audit_log_resp = test_client.get(
# audit_path, params={"offset": 100}
# )
# entries = audit_log_resp.json()
# assert len(entries) == 52 # There should be 50 + the two requests we made
# # Try to get a specific client
# # There should be 30 log entries for each client.
# audit_log_resp = test_client.get(
# audit_path, params={"filter_client": "client-1"}
# )
# entries = audit_log_resp.json()
# assert len(entries) == 30
def test_secret_invalidation(test_client: TestClient) -> None:
"""Test secret invalidation."""
initial_key = make_test_key()
create_client_resp = create_client(test_client, "test", initial_key)
assert create_client_resp.status_code == 200
secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"}
for name, secret in secrets.items():
add_resp = test_client.post(
"/api/v1/clients/test/secrets/",
json={"name": name, "secret": secret},
)
assert add_resp.status_code == 200
# Update the public-key. This should cause all secrets to be invalidated
# and no longer associated with a client.
new_key = make_test_key()
update_resp = test_client.post(
"/api/v1/clients/test/public-key",
json={"public_key": new_key},
)
assert update_resp.status_code == 200
# Fetch the client. The list of secrets should be empty.
get_resp = test_client.get("/api/v1/clients/test")
assert get_resp.status_code == 200
client = get_resp.json()
secrets = client.get("secrets")
assert bool(secrets) is False
def test_client_default_policies(
test_client: TestClient,
) -> None:
"""Test client policies."""
public_key = make_test_key()
resp = create_client(test_client, "test")
assert resp.status_code == 200
# Fetch policies, should return *
resp = test_client.get("/api/v1/clients/test/policies/")
assert resp.status_code == 200
policies = resp.json()
assert policies["sources"] == ["0.0.0.0/0", "::/0"]
def test_client_policy_update_one(test_client: TestClient) -> None:
"""Update client policy with single policy."""
public_key = make_test_key()
resp = create_client(test_client, "test", public_key)
assert resp.status_code == 200
policy = ["192.0.2.1"]
resp = test_client.put("/api/v1/clients/test/policies/", json={"sources": policy})
assert resp.status_code == 200
resp = test_client.get("/api/v1/clients/test/policies/")
assert resp.status_code == 200
policies = resp.json()
assert policies["sources"] == policy
def test_client_policy_update_advanced(test_client: TestClient) -> None:
"""Test other policy update scenarios."""
public_key = make_test_key()
resp = create_client(test_client, "test", public_key)
assert resp.status_code == 200
policy = ["192.0.2.1", "198.18.0.0/24"]
resp = test_client.put("/api/v1/clients/test/policies/", json={"sources": policy})
assert resp.status_code == 200
resp = test_client.get("/api/v1/clients/test/policies/")
assert resp.status_code == 200
policies = resp.json()
assert "192.0.2.1" in policies["sources"]
assert "198.18.0.0/24" in policies["sources"]
# Try to set it to something incorrect
policy = ["obviosly_wrong"]
resp = test_client.put("/api/v1/clients/test/policies/", json={"sources": policy})
assert resp.status_code == 422
# Check that the old value is still there
resp = test_client.get("/api/v1/clients/test/policies/")
assert resp.status_code == 200
policies = resp.json()
assert "192.0.2.1" in policies["sources"]
assert "198.18.0.0/24" in policies["sources"]
# Clear the policies
#
def test_client_policy_update_unset(test_client: TestClient) -> None:
"""Test clearing the client policy."""
public_key = make_test_key()
resp = create_client(test_client, "test", public_key)
assert resp.status_code == 200
policy = ["192.0.2.1", "198.18.0.0/24"]
resp = test_client.put("/api/v1/clients/test/policies/", json={"sources": policy})
assert resp.status_code == 200
policies = resp.json()
assert "192.0.2.1" in policies["sources"]
assert "198.18.0.0/24" in policies["sources"]
# Now we clear the policies
resp = test_client.put("/api/v1/clients/test/policies/", json={"sources": []})
assert resp.status_code == 200
policies = resp.json()
assert policies["sources"] == ["0.0.0.0/0", "::/0"]
def test_client_update(test_client: TestClient) -> None:
"""Test generic update of a client."""
public_key = make_test_key()
resp = create_client(test_client, "test", public_key, "PRE")
assert resp.status_code == 200
resp = test_client.get("/api/v1/clients/test")
assert resp.status_code == 200
client_data = resp.json()
assert client_data["description"] == "PRE"
# Update the description
new_client_data = {
"name": "test",
"description": "POST",
"public_key": client_data["public_key"],
}
resp = test_client.put("/api/v1/clients/test", json=new_client_data)
assert resp.status_code == 200
client_data = resp.json()
assert client_data["description"] == "POST"
resp = test_client.get("/api/v1/clients/test")
assert resp.status_code == 200
client_data = resp.json()
assert client_data["description"] == "POST"
def test_get_secret_list(test_client: TestClient) -> None:
"""Test the secret to client map view."""
# Make 4 clients
for x in range(4):
public_key = make_test_key()
create_client(test_client, f"client-{x}", public_key)
# Create a secret that only this client has.
resp = test_client.put(
f"/api/v1/clients/client-{x}/secrets/client-{x}", json={"value": "SECRET"}
)
assert resp.status_code == 200
# Create a secret that all of them have.
resp = test_client.put(
f"/api/v1/clients/client-{x}/secrets/commonsecret", json={"value": "SECRET"}
)
assert resp.status_code == 200
# Get the secret list
resp = test_client.get("/api/v1/secrets/")
assert resp.status_code == 200
data = resp.json()
assert isinstance(data, list)
assert len(data) == 5
for entry in data:
if entry["name"] == "commonsecret":
assert len(entry["clients"]) == 4
else:
assert len(entry["clients"]) == 1
assert entry["clients"][0] == entry["name"]
def test_get_secret_clients(test_client: TestClient) -> None:
"""Get the clients for a single secret."""
for x in range(4):
public_key = make_test_key()
create_client(test_client, f"client-{x}", public_key)
# Create a secret that every second of them have.
if x % 2 == 1:
continue
resp = test_client.put(
f"/api/v1/clients/client-{x}/secrets/commonsecret", json={"value": "SECRET"}
)
assert resp.status_code == 200
resp = test_client.get("/api/v1/secrets/commonsecret")
assert resp.status_code == 200
data = resp.json()
assert data["name"] == "commonsecret"
assert "client-0" in data["clients"]
assert "client-1" not in data["clients"]
assert len(data["clients"]) == 2
def test_searching(test_client: TestClient) -> None:
"""Test searching."""
for x in range(4):
# Create four clients
create_client(test_client, f"client-{x}")
# Create one with a different name.
create_client(test_client, "othername")
# Search for a specific one.
resp = test_client.get("/api/v1/clients/", params={"name": "othername"})
assert resp.status_code == 200
result = resp.json()
assert result["total_results"] == 1
assert result["clients"][0]["name"] == "othername"
client_id = result["clients"][0]["id"]
# Search by ID
resp = test_client.get("/api/v1/clients/", params={"id": client_id})
assert resp.status_code == 200
result = resp.json()
assert result["total_results"] == 1
assert result["clients"][0]["name"] == "othername"
# Search for the four similarly named ones
resp = test_client.get("/api/v1/clients/", params={"name__like": "client-%"})
assert resp.status_code == 200
result = resp.json()
assert result["total_results"] == 4
assert str(result["clients"][0]["name"]).startswith("client-")
def test_operations_with_id(test_client: TestClient) -> None:
"""Test operations using ID instead of name."""
create_client(test_client, "test")
resp = test_client.get("/api/v1/clients/")
assert resp.status_code == 200
data = resp.json()
client = data["clients"][0]
client_id = client["id"]
resp = test_client.get(f"/api/v1/clients/{client_id}")
assert resp.status_code == 200
data = resp.json()
assert data["name"] == "test"
def test_write_audit_log(test_client: TestClient) -> None:
"""Test writing to the audit log."""
params = {
"subsystem": "backend",
"operation": "read",
"message": "Test Message"
}
resp = test_client.post("/api/v1/audit", json=params)
assert resp.status_code == 200
resp = test_client.get("/api/v1/audit")
assert resp.status_code == 200
data = resp.json()
entry = data[0]
for key, value in params.items():
assert entry[key] == value

View File

@ -0,0 +1 @@

View File

@ -0,0 +1,183 @@
import asyncio
from typing import Any
import pytest
import uuid
import asyncssh
import tempfile
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock
from ipaddress import IPv4Network, IPv6Network
from sshecret_sshd.ssh_server import run_ssh_server
from sshecret_sshd.settings import ClientRegistrationSettings
from .types import (
Client,
ClientKey,
ClientRegistry,
SshServerFixtureFun,
SshServerFixture,
)
@pytest.fixture(scope="function")
def client_registry() -> ClientRegistry:
clients = {}
secrets = {}
async def add_client(
name: str,
secret_names: list[str] | None = None,
policies: list[str] | None = None,
) -> str:
private_key = asyncssh.generate_private_key("ssh-rsa")
public_key = private_key.export_public_key()
clients[name] = ClientKey(name, private_key, public_key.decode().rstrip())
secrets.update({(name, s): f"mocked-secret-{s}" for s in (secret_names or [])})
return clients[name]
return {"clients": clients, "secrets": secrets, "add_client": add_client}
@pytest.fixture(scope="function")
async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
backend = MagicMock()
clients_data = client_registry["clients"]
secrets_data = client_registry["secrets"]
async def get_client(name: str) -> Client | None:
client_key = clients_data.get(name)
if client_key:
response_model = Client(
id=uuid.uuid4(),
name=name,
description=f"Mock client {name}",
public_key=client_key.public_key,
secrets=[s for (c, s) in secrets_data if c == name],
policies=[IPv4Network("0.0.0.0/0"), IPv6Network("::/0")],
)
return response_model
return None
async def get_client_secret(name: str, secret_name: str) -> str | None:
secret = secrets_data.get((name, secret_name), None)
return secret
async def create_client(name: str, public_key: str) -> None:
"""Create client.
This only works if you register a client called template first.
Otherwise we can't test this...
"""
if "template" not in clients_data:
raise RuntimeError(
"Error, must have a client called template for this to work."
)
clients_data[name] = clients_data["template"]
for secret_key, secret in secrets_data.items():
s_client, secret_name = secret_key
if s_client != "template":
continue
secrets_data[(name, secret_name)] = secret
async def write_audit(*args, **kwargs):
"""Write audit mock."""
return None
backend.get_client = AsyncMock(side_effect=get_client)
backend.get_client_secret = AsyncMock(side_effect=get_client_secret)
backend.create_client = AsyncMock(side_effect=create_client)
audit = MagicMock()
audit.write_async = AsyncMock(side_effect=write_audit)
backend.audit = MagicMock(return_value=audit)
return backend
@pytest.fixture(scope="function")
async def ssh_server(
mock_backend: MagicMock, unused_tcp_port: int
) -> SshServerFixtureFun:
port = unused_tcp_port
private_key = asyncssh.generate_private_key("ssh-ed25519")
key_str = private_key.export_private_key()
with tempfile.NamedTemporaryFile("w+", delete=True) as key_file:
key_file.write(key_str.decode())
key_file.flush()
registration_settings = ClientRegistrationSettings(
enabled=True, allow_from=[IPv4Network("0.0.0.0/0")]
)
server = await run_ssh_server(
backend=mock_backend,
listen_address="localhost",
port=port,
keys=[key_file.name],
registration=registration_settings,
enable_ping_command=True,
)
await asyncio.sleep(0.1)
yield server, port
server.close()
await server.wait_closed()
@pytest.fixture(scope="function")
def ssh_command_runner(ssh_server: SshServerFixture, client_registry: ClientRegistry):
"""Run a single command.
Tricky typing!
"""
_, port = ssh_server
async def run_command_as(name: str, command: str):
client_key = client_registry["clients"][name]
conn = await asyncssh.connect(
"127.0.0.1",
port=port,
username=name,
client_keys=[client_key.private_key],
known_hosts=None,
)
try:
result = await conn.run(command)
return result
finally:
conn.close()
await conn.wait_closed()
return run_command_as
@pytest.fixture(scope="function")
def ssh_session(ssh_server: SshServerFixture, client_registry: ClientRegistry):
"""Yield an interactive session."""
_, port = ssh_server
@asynccontextmanager
async def run_process_as(name: str, command: str, client: str | None = None):
if not client:
client = name
client_key = client_registry["clients"][client]
conn = await asyncssh.connect(
"127.0.0.1",
port=port,
username=name,
client_keys=[client_key.private_key],
known_hosts=None,
)
try:
async with conn.create_process(command) as process:
yield process
finally:
conn.close()
await conn.wait_closed()
return run_process_as

View File

@ -0,0 +1,33 @@
"""Test get secret."""
import pytest
from .types import ClientRegistry, CommandRunner
@pytest.mark.asyncio
async def test_get_secret(
ssh_command_runner: CommandRunner, client_registry: ClientRegistry
) -> None:
"""Test that we can get a secret."""
await client_registry["add_client"]("test-client", ["mysecret"])
result = await ssh_command_runner("test-client", "get_secret mysecret")
assert result.stdout is not None
assert isinstance(result.stdout, str)
assert result.stdout.rstrip() == "mocked-secret-mysecret"
@pytest.mark.asyncio
async def test_invalid_secret_name(
ssh_command_runner: CommandRunner, client_registry: ClientRegistry
) -> None:
"""Test getting an invalid secret name."""
await client_registry["add_client"]("test-client")
result = await ssh_command_runner("test-client", "get_secret mysecret")
assert result.exit_status == 1
assert result.stderr == "Error: No secret available with the given name."

View File

@ -0,0 +1,18 @@
import pytest
from .types import ClientRegistry, CommandRunner
@pytest.mark.asyncio
async def test_ping_command(
ssh_command_runner: CommandRunner, client_registry: ClientRegistry
) -> None:
# Register a test client with default policies and no secrets
await client_registry["add_client"]("test-pinger")
result = await ssh_command_runner("test-pinger", "ping")
assert result.exit_status == 0
assert result.stdout is not None
assert isinstance(result.stdout, str)
assert result.stdout.rstrip() == "PONG"

View File

@ -0,0 +1,40 @@
"""Test registration."""
import pytest
from .types import ClientRegistry, CommandRunner, ProcessRunner
@pytest.mark.asyncio
async def test_register_client(
ssh_session: ProcessRunner,
ssh_command_runner: CommandRunner,
client_registry: ClientRegistry,
) -> None:
"""Test client registration."""
await client_registry["add_client"]("template", ["testsecret"])
public_key = client_registry["clients"]["template"].public_key.rstrip() + "\n"
async with ssh_session("newclient", "register", "template") as session:
maxlines = 10
linenum = 0
found = False
while linenum < maxlines:
line = await session.stdout.readline()
if "Enter public key" in line:
found = True
break
assert found is True
session.stdin.write(public_key)
result = await session.stdout.readline()
assert "OK" in result
# Test that we can connect
result = await ssh_command_runner("newclient", "get_secret testsecret")
assert result.stdout is not None
assert isinstance(result.stdout, str)
assert result.stdout.rstrip() == "mocked-secret-testsecret"

View File

@ -0,0 +1,63 @@
"""Types for the various test properties."""
import uuid
from datetime import datetime
from dataclasses import dataclass, field
from ipaddress import IPv4Network, IPv6Network
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any, Protocol, TypedDict, AsyncContextManager
import asyncssh
SshServerFixture = tuple[str, int]
SshServerFixtureFun = AsyncGenerator[tuple[asyncssh.SSHAcceptor, int], None]
@dataclass
class Client:
"""Mock client."""
id: uuid.UUID
name: str
description: str | None
public_key: str
secrets: list[str]
policies: list[IPv4Network | IPv6Network]
created_at: datetime = field(default_factory=datetime.now)
updated_at: datetime = field(default_factory=datetime.now)
@dataclass
class ClientKey:
name: str
private_key: asyncssh.SSHKey
public_key: str
class AddClientFun(Protocol):
"""Add client function."""
def __call__(
self,
name: str,
secret_names: list[str] | None = None,
policies: list[str] | None = None,
) -> Awaitable[str]: ...
class ProcessRunner(Protocol):
"""Process runner typing."""
def __call__(
self, name: str, command: str, client: str | None = None
) -> AsyncContextManager[asyncssh.SSHClientProcess[Any]]: ...
class ClientRegistry(TypedDict):
"""Client registry typing."""
clients: dict[str, ClientKey]
secrets: dict[tuple[str, str], str]
add_client: AddClientFun
CommandRunner = Callable[[str, str], Awaitable[asyncssh.SSHCompletedProcess]]