Centralize testing
This commit is contained in:
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
|
||||
0
tests/integration/__init__.py
Normal file
0
tests/integration/__init__.py
Normal file
27
tests/integration/clients.py
Normal file
27
tests/integration/clients.py
Normal file
@ -0,0 +1,27 @@
|
||||
"""Client helpers."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
|
||||
from sshecret.crypto import generate_private_key, generate_public_key_string
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClientData:
|
||||
"""Test client."""
|
||||
|
||||
name: str
|
||||
private_key: rsa.RSAPrivateKey
|
||||
|
||||
@property
|
||||
def public_key(self) -> str:
|
||||
"""Return public key as string."""
|
||||
return generate_public_key_string(self.private_key.public_key())
|
||||
|
||||
|
||||
def create_test_client(name: str) -> ClientData:
|
||||
"""Create test client."""
|
||||
return ClientData(
|
||||
name=name,
|
||||
private_key=generate_private_key()
|
||||
)
|
||||
209
tests/integration/conftest.py
Normal file
209
tests/integration/conftest.py
Normal file
@ -0,0 +1,209 @@
|
||||
"""Test library.
|
||||
|
||||
Strategy:
|
||||
|
||||
We start by spawning the backend server, and create two test keys.
|
||||
|
||||
Then we spawn the sshd and the admin api.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import asyncssh
|
||||
import secrets
|
||||
import tempfile
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
import uvicorn
|
||||
from sshecret.backend import SshecretBackend
|
||||
from sshecret.crypto import (
|
||||
generate_private_key,
|
||||
generate_public_key_string,
|
||||
write_private_key,
|
||||
)
|
||||
from sshecret_admin.core.app import create_admin_app
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
from sshecret_backend.app import create_backend_app
|
||||
from sshecret_backend.settings import BackendSettings
|
||||
from sshecret_backend.testing import create_test_token
|
||||
from sshecret_sshd.settings import ServerSettings
|
||||
from sshecret_sshd.ssh_server import start_sshecret_sshd
|
||||
|
||||
from .clients import ClientData
|
||||
from .helpers import create_sshd_server_key, create_test_admin_user, in_tempdir
|
||||
from .types import PortFactory, TestPorts
|
||||
|
||||
TEST_SCOPE = "function"
|
||||
LOOP_SCOPE = "function"
|
||||
|
||||
|
||||
def make_test_key() -> str:
|
||||
"""Generate a test key."""
|
||||
private_key = generate_private_key()
|
||||
return generate_public_key_string(private_key.public_key())
|
||||
|
||||
|
||||
@pytest.fixture(name="test_ports", scope="session")
|
||||
def generate_test_ports(unused_tcp_port_factory: PortFactory) -> TestPorts:
|
||||
"""Generate the test ports."""
|
||||
test_ports = TestPorts(
|
||||
backend=unused_tcp_port_factory(),
|
||||
admin=unused_tcp_port_factory(),
|
||||
sshd=unused_tcp_port_factory(),
|
||||
)
|
||||
print(f"{test_ports=!r}")
|
||||
return test_ports
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope=TEST_SCOPE, name="backend_server", loop_scope=LOOP_SCOPE)
|
||||
async def run_backend_server(test_ports: TestPorts):
|
||||
"""Run the backend server."""
|
||||
port = test_ports.backend
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
backend_work_path = Path(tmp_dir)
|
||||
db_file = backend_work_path / "backend.db"
|
||||
backend_settings = BackendSettings(database=str(db_file.absolute()))
|
||||
backend_app = create_backend_app(backend_settings)
|
||||
token = create_test_token(backend_settings)
|
||||
config = uvicorn.Config(app=backend_app, port=port, loop="asyncio")
|
||||
server = uvicorn.Server(config=config)
|
||||
server_task = asyncio.create_task(server.serve())
|
||||
await asyncio.sleep(0.1)
|
||||
backend_url = f"http://127.0.0.1:{port}"
|
||||
yield (backend_url, token)
|
||||
server.should_exit = True
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope=TEST_SCOPE, name="admin_server", loop_scope=LOOP_SCOPE)
|
||||
async def run_admin_server(test_ports: TestPorts, backend_server: tuple[str, str]):
|
||||
"""Run admin server."""
|
||||
backend_url, backend_token = backend_server
|
||||
secret_key = secrets.token_urlsafe(32)
|
||||
port = test_ports.admin
|
||||
with in_tempdir() as admin_work_path:
|
||||
admin_db = admin_work_path / "ssh_admin.db"
|
||||
admin_settings = AdminServerSettings.model_validate(
|
||||
{
|
||||
"sshecret_backend_url": backend_url,
|
||||
"backend_token": backend_token,
|
||||
"secret_key": secret_key,
|
||||
"listen_address": "127.0.0.1",
|
||||
"port": port,
|
||||
"database": str(admin_db.absolute()),
|
||||
"password_manager_directory": str(admin_work_path.absolute()),
|
||||
}
|
||||
)
|
||||
admin_app = create_admin_app(admin_settings)
|
||||
config = uvicorn.Config(app=admin_app, port=port, loop="asyncio")
|
||||
server = uvicorn.Server(config=config)
|
||||
server_task = asyncio.create_task(server.serve())
|
||||
await asyncio.sleep(0.1)
|
||||
admin_url = f"http://127.0.0.1:{port}"
|
||||
admin_password = secrets.token_urlsafe(10)
|
||||
create_test_admin_user(admin_settings, "test", admin_password)
|
||||
await asyncio.sleep(0.1)
|
||||
yield (admin_url, ("test", admin_password))
|
||||
server.should_exit = True
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope=TEST_SCOPE, name="ssh_server", loop_scope=LOOP_SCOPE)
|
||||
async def start_ssh_server(test_ports: TestPorts, backend_server: tuple[str, str]):
|
||||
"""Run ssh server."""
|
||||
backend_url, backend_token = backend_server
|
||||
port = test_ports.sshd
|
||||
with in_tempdir() as ssh_workdir:
|
||||
create_sshd_server_key(ssh_workdir)
|
||||
sshd_server_settings = ServerSettings.model_validate(
|
||||
{
|
||||
"sshecret_backend_url": backend_url,
|
||||
"backend_token": backend_token,
|
||||
"listen_address": "",
|
||||
"port": port,
|
||||
"registration": {"enabled": True, "allow_from": "0.0.0.0/0"},
|
||||
"enable_ping_command": True,
|
||||
}
|
||||
)
|
||||
|
||||
ssh_server = await start_sshecret_sshd(sshd_server_settings)
|
||||
await asyncio.sleep(0.1)
|
||||
print(f"Started sshd on port {port}")
|
||||
yield port
|
||||
|
||||
ssh_server.close()
|
||||
await ssh_server.wait_closed()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope=TEST_SCOPE, name="backend_client", loop_scope=LOOP_SCOPE)
|
||||
async def create_backend_http_client(backend_server: tuple[str, str]):
|
||||
"""Create a test client."""
|
||||
backend_url, backend_token = backend_server
|
||||
print(f"Creating backend client towards {backend_url}")
|
||||
async with httpx.AsyncClient(
|
||||
base_url=backend_url, headers={"X-API-Token": backend_token}
|
||||
) as client:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(name="backend_api")
|
||||
async def get_test_backend_api(backend_server: tuple[str, str]) -> SshecretBackend:
|
||||
"""Get the backend API."""
|
||||
backend_url, backend_token = backend_server
|
||||
return SshecretBackend(backend_url, backend_token)
|
||||
|
||||
|
||||
@pytest.fixture(scope=TEST_SCOPE)
|
||||
def ssh_command_runner(ssh_server: int, tmp_path: Path):
|
||||
"""Run a single command on the ssh server."""
|
||||
port = ssh_server
|
||||
|
||||
async def run_command_as(test_client: ClientData, command: str):
|
||||
private_key_file = tmp_path / f"id_{test_client.name}"
|
||||
write_private_key(test_client.private_key, private_key_file)
|
||||
|
||||
conn = await asyncssh.connect(
|
||||
"127.0.0.1",
|
||||
port=port,
|
||||
username=test_client.name,
|
||||
client_keys=[str(private_key_file)],
|
||||
known_hosts=None,
|
||||
)
|
||||
try:
|
||||
result = await conn.run(command)
|
||||
return result
|
||||
finally:
|
||||
conn.close()
|
||||
await conn.wait_closed()
|
||||
|
||||
return run_command_as
|
||||
|
||||
|
||||
@pytest.fixture(name="ssh_session", scope=TEST_SCOPE)
|
||||
def create_ssh_session(ssh_server: int, tmp_path: Path):
|
||||
"""Create a ssh Session."""
|
||||
port = ssh_server
|
||||
|
||||
@asynccontextmanager
|
||||
async def run_process(test_client: ClientData, command: str):
|
||||
private_key_file = tmp_path / f"id_{test_client.name}"
|
||||
write_private_key(test_client.private_key, private_key_file)
|
||||
conn = await asyncssh.connect(
|
||||
"127.0.0.1",
|
||||
port=port,
|
||||
username=test_client.name,
|
||||
client_keys=[str(private_key_file)],
|
||||
known_hosts=None,
|
||||
)
|
||||
try:
|
||||
async with conn.create_process(command) as process:
|
||||
yield process
|
||||
finally:
|
||||
conn.close()
|
||||
await conn.wait_closed()
|
||||
|
||||
return run_process
|
||||
41
tests/integration/helpers.py
Normal file
41
tests/integration/helpers.py
Normal file
@ -0,0 +1,41 @@
|
||||
"""Helper functions."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from sqlmodel import Session, create_engine
|
||||
from sshecret.crypto import generate_private_key, write_private_key
|
||||
from sshecret_admin.auth.authentication import hash_password
|
||||
from sshecret_admin.auth.models import User, init_db
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
|
||||
def create_test_admin_user(settings: AdminServerSettings, username: str, password: str) -> None:
|
||||
"""Create a test admin user."""
|
||||
hashed_password = hash_password(password)
|
||||
engine = create_engine(settings.admin_db)
|
||||
init_db(engine)
|
||||
with Session(engine) as session:
|
||||
user = User(username=username, hashed_password=hashed_password)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
|
||||
def create_sshd_server_key(sshd_path: Path) -> Path:
|
||||
"""Create a ssh key at a general"""
|
||||
server_file = sshd_path / "ssh_host_key"
|
||||
private_key = generate_private_key()
|
||||
write_private_key(private_key, server_file)
|
||||
return server_file
|
||||
|
||||
|
||||
@contextmanager
|
||||
def in_tempdir() -> Iterator[Path]:
|
||||
"""Run in a temporary directory."""
|
||||
curdir = os.getcwd()
|
||||
with tempfile.TemporaryDirectory() as temp_directory:
|
||||
temp_path = Path(temp_directory)
|
||||
os.chdir(temp_directory)
|
||||
yield temp_path
|
||||
os.chdir(curdir)
|
||||
55
tests/integration/test_admin_api.py
Normal file
55
tests/integration/test_admin_api.py
Normal file
@ -0,0 +1,55 @@
|
||||
"""Tests of the admin interface."""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
import pytest
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
class TestAdminAPI:
|
||||
"""Tests of the Admin REST API."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check(
|
||||
self, admin_server: tuple[str, tuple[str, str]]
|
||||
) -> None:
|
||||
"""Test admin login."""
|
||||
async with self.http_client(admin_server, False) as client:
|
||||
resp = await client.get("/health")
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_login(self, admin_server: tuple[str, tuple[str, str]]) -> None:
|
||||
"""Test admin login."""
|
||||
|
||||
async with self.http_client(admin_server, False) as client:
|
||||
resp = await client.get("api/v1/clients/")
|
||||
assert resp.status_code == 401
|
||||
|
||||
async with self.http_client(admin_server, True) as client:
|
||||
resp = await client.get("api/v1/clients/")
|
||||
assert resp.status_code == 200
|
||||
|
||||
@asynccontextmanager
|
||||
async def http_client(
|
||||
self, admin_server: tuple[str, tuple[str, str]], authenticate: bool = True
|
||||
) -> AsyncIterator[httpx.AsyncClient]:
|
||||
"""Run a client towards the admin rest api."""
|
||||
admin_url, credentials = admin_server
|
||||
username, password = credentials
|
||||
headers: dict[str, str] | None = None
|
||||
if authenticate:
|
||||
async with httpx.AsyncClient(base_url=admin_url) as client:
|
||||
|
||||
response = await client.post(
|
||||
"api/v1/token", data={"username": username, "password": password}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "access_token" in data
|
||||
token = data["access_token"]
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
async with httpx.AsyncClient(base_url=admin_url, headers=headers) as client:
|
||||
yield client
|
||||
58
tests/integration/test_backend.py
Normal file
58
tests/integration/test_backend.py
Normal file
@ -0,0 +1,58 @@
|
||||
"""Test backend.
|
||||
|
||||
These tests just ensure that the backend works well enough for us to run the
|
||||
rest of the tests.
|
||||
|
||||
"""
|
||||
import pytest
|
||||
import httpx
|
||||
from sshecret.backend import SshecretBackend
|
||||
from .clients import create_test_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_healthcheck(backend_client: httpx.AsyncClient) -> None:
|
||||
"""Test healthcheck command."""
|
||||
resp = await backend_client.get("/health")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"status": "LIVE"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_client(backend_api: SshecretBackend) -> None:
|
||||
"""Test creating a client."""
|
||||
test_client = create_test_client("test")
|
||||
await backend_api.create_client("test", test_client.public_key, "A test client")
|
||||
|
||||
# fetch the list of clients.
|
||||
|
||||
clients = await backend_api.get_clients()
|
||||
assert clients is not None
|
||||
|
||||
assert len(clients) == 1
|
||||
|
||||
assert clients[0].name == "test"
|
||||
|
||||
assert clients[0].public_key == test_client.public_key
|
||||
|
||||
async def test_create_secret(backend_api: SshecretBackend) -> None:
|
||||
"""Test creating secrets."""
|
||||
test_client = create_test_client("test")
|
||||
await backend_api.create_client("test", test_client.public_key, "A test client")
|
||||
|
||||
await backend_api.create_client_secret("test", "mysecret", "encrypted_secret")
|
||||
|
||||
secrets = await backend_api.get_secrets()
|
||||
assert len(secrets) == 1
|
||||
assert secrets[0].name == "mysecret"
|
||||
|
||||
|
||||
secret_to_client = await backend_api.get_secret("mysecret")
|
||||
assert secret_to_client is not None
|
||||
|
||||
assert secret_to_client.name == "mysecret"
|
||||
assert "test" in secret_to_client.clients
|
||||
|
||||
secret = await backend_api.get_client_secret("test", "mysecret")
|
||||
|
||||
assert secret is not None
|
||||
assert secret == "encrypted_secret"
|
||||
139
tests/integration/test_sshd.py
Normal file
139
tests/integration/test_sshd.py
Normal file
@ -0,0 +1,139 @@
|
||||
"""Tests where the sshd is the main consumer.
|
||||
|
||||
This essentially also tests parts of the admin API.
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncIterator
|
||||
import os
|
||||
import httpx
|
||||
|
||||
import pytest
|
||||
from sshecret.crypto import decode_string
|
||||
from sshecret.backend.api import SshecretBackend
|
||||
|
||||
from .clients import create_test_client, ClientData
|
||||
|
||||
from .types import CommandRunner, ProcessRunner
|
||||
|
||||
|
||||
class TestSshd:
|
||||
"""Class based tests.
|
||||
|
||||
This allows us to create small helpers.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_secret(
|
||||
self, backend_api: SshecretBackend, ssh_command_runner: CommandRunner
|
||||
) -> None:
|
||||
"""Test get secret flow."""
|
||||
test_client = create_test_client("testclient")
|
||||
await backend_api.create_client(
|
||||
"testclient", test_client.public_key, "A test client"
|
||||
)
|
||||
await backend_api.create_client_secret("testclient", "testsecret", "bogus")
|
||||
response = await ssh_command_runner(test_client, "get_secret testsecret")
|
||||
assert response.exit_status == 0
|
||||
assert response.stdout is not None
|
||||
assert isinstance(response.stdout, str)
|
||||
assert response.stdout.rstrip() == "bogus"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register(
|
||||
self, backend_api: SshecretBackend, ssh_session: ProcessRunner
|
||||
) -> None:
|
||||
"""Test registration."""
|
||||
await self.register_client("new_client", ssh_session)
|
||||
# Check that the client is created.
|
||||
clients = await backend_api.get_clients()
|
||||
assert len(clients) == 1
|
||||
|
||||
client = clients[0]
|
||||
assert client.name == "new_client"
|
||||
|
||||
async def register_client(
|
||||
self, name: str, ssh_session: ProcessRunner
|
||||
) -> ClientData:
|
||||
"""Register client."""
|
||||
test_client = create_test_client(name)
|
||||
async with ssh_session(test_client, "register") as session:
|
||||
maxlines = 10
|
||||
linenum = 0
|
||||
found = False
|
||||
while linenum < maxlines:
|
||||
line = await session.stdout.readline()
|
||||
if "Enter public key" in line:
|
||||
found = True
|
||||
break
|
||||
assert found is True
|
||||
session.stdin.write(test_client.public_key + "\n")
|
||||
|
||||
result = await session.stdout.readline()
|
||||
assert "OK" in result
|
||||
await session.wait()
|
||||
return test_client
|
||||
|
||||
|
||||
class TestSshdIntegration(TestSshd):
|
||||
"""Integration tests."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_to_end(
|
||||
self,
|
||||
backend_api: SshecretBackend,
|
||||
admin_server: tuple[str, tuple[str, str]],
|
||||
ssh_session: ProcessRunner,
|
||||
ssh_command_runner: CommandRunner,
|
||||
) -> None:
|
||||
"""Test end to end."""
|
||||
test_client = await self.register_client("myclient", ssh_session)
|
||||
url, credentials = admin_server
|
||||
username, password = credentials
|
||||
async with self.admin_client(url, username, password) as http_client:
|
||||
resp = await http_client.get("api/v1/clients/")
|
||||
assert resp.status_code == 200
|
||||
clients = resp.json()
|
||||
assert len(clients) == 1
|
||||
assert clients[0]["name"] == "myclient"
|
||||
|
||||
create_model = {
|
||||
"name": "mysecret",
|
||||
"clients": ["myclient"],
|
||||
"value": "mypassword",
|
||||
}
|
||||
resp = await http_client.post("api/v1/secrets/", json=create_model)
|
||||
assert resp.status_code == 200
|
||||
|
||||
# Login via ssh to fetch the decrypted value.
|
||||
ssh_output = await ssh_command_runner(test_client, "get_secret mysecret")
|
||||
assert ssh_output.stdout is not None
|
||||
assert isinstance(ssh_output.stdout, str)
|
||||
encrypted = ssh_output.stdout.rstrip()
|
||||
decrypted = decode_string(encrypted, test_client.private_key)
|
||||
assert decrypted == "mypassword"
|
||||
|
||||
async def login(self, url: str, username: str, password: str) -> str:
|
||||
"""Login and get token."""
|
||||
api_url = os.path.join(url, "api/v1", "token")
|
||||
client = httpx.AsyncClient()
|
||||
|
||||
response = await client.post(
|
||||
api_url, data={"username": username, "password": password}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "access_token" in data
|
||||
assert isinstance(data["access_token"], str)
|
||||
return str(data["access_token"])
|
||||
|
||||
@asynccontextmanager
|
||||
async def admin_client(
|
||||
self, url: str, username: str, password: str
|
||||
) -> AsyncIterator[httpx.AsyncClient]:
|
||||
"""Create an admin client."""
|
||||
token = await self.login(url, username, password)
|
||||
async with httpx.AsyncClient(
|
||||
base_url=url, headers={"Authorization": f"Bearer {token}"}
|
||||
) as client:
|
||||
yield client
|
||||
29
tests/integration/types.py
Normal file
29
tests/integration/types.py
Normal file
@ -0,0 +1,29 @@
|
||||
"""Typings."""
|
||||
import asyncssh
|
||||
|
||||
from typing import Any, AsyncContextManager, Protocol
|
||||
from dataclasses import dataclass
|
||||
from collections.abc import Callable, Awaitable
|
||||
|
||||
from .clients import ClientData
|
||||
|
||||
|
||||
PortFactory = Callable[[], int]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestPorts:
|
||||
"""Test port dataclass."""
|
||||
|
||||
backend: int
|
||||
admin: int
|
||||
sshd: int
|
||||
|
||||
|
||||
CommandRunner = Callable[[ClientData, str], Awaitable[asyncssh.SSHCompletedProcess]]
|
||||
|
||||
class ProcessRunner(Protocol):
|
||||
"""Process runner typing."""
|
||||
|
||||
def __call__(self, test_client: ClientData, command: str) -> AsyncContextManager[asyncssh.SSHClientProcess[Any]]:
|
||||
...
|
||||
0
tests/packages/__init__.py
Normal file
0
tests/packages/__init__.py
Normal file
0
tests/packages/backend/__init__.py
Normal file
0
tests/packages/backend/__init__.py
Normal file
538
tests/packages/backend/test_backend.py
Normal file
538
tests/packages/backend/test_backend.py
Normal file
@ -0,0 +1,538 @@
|
||||
"""Tests of the backend api using pytest."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from httpx import Response
|
||||
import pytest
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from sshecret.crypto import generate_private_key, generate_public_key_string
|
||||
from sshecret_backend.app import create_backend_app
|
||||
from sshecret_backend.testing import create_test_token
|
||||
from sshecret_backend.view_models import AuditView
|
||||
from sshecret_backend.settings import BackendSettings
|
||||
|
||||
|
||||
LOG = logging.getLogger()
|
||||
handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter("'%(asctime)s - %(levelname)s - %(message)s'")
|
||||
handler.setFormatter(formatter)
|
||||
LOG.addHandler(handler)
|
||||
LOG.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
def make_test_key() -> str:
|
||||
"""Generate a test key."""
|
||||
private_key = generate_private_key()
|
||||
return generate_public_key_string(private_key.public_key())
|
||||
|
||||
|
||||
def create_client(
|
||||
test_client: TestClient,
|
||||
name: str,
|
||||
public_key: str | None = None,
|
||||
description: str | None = None,
|
||||
) -> Response:
|
||||
"""Create client."""
|
||||
if not public_key:
|
||||
public_key = make_test_key()
|
||||
data = {
|
||||
"name": name,
|
||||
"public_key": public_key,
|
||||
}
|
||||
if description:
|
||||
data["description"] = description
|
||||
create_response = test_client.post("/api/v1/clients", json=data)
|
||||
return create_response
|
||||
|
||||
|
||||
@pytest.fixture(name="test_client")
|
||||
def create_client_fixture(tmp_path: Path):
|
||||
"""Test client fixture."""
|
||||
|
||||
db_file = tmp_path / "backend.db"
|
||||
print(f"DB File: {db_file.absolute()}")
|
||||
settings = BackendSettings(database=str(db_file.absolute()))
|
||||
app = create_backend_app(settings)
|
||||
|
||||
token = create_test_token(settings)
|
||||
|
||||
test_client = TestClient(app, headers={"X-API-Token": token})
|
||||
yield test_client
|
||||
|
||||
|
||||
def test_missing_token(test_client: TestClient) -> None:
|
||||
"""Test logging in with missing token."""
|
||||
# Save headers
|
||||
old_headers = test_client.headers
|
||||
test_client.headers = {}
|
||||
response = test_client.get("/api/v1/clients/", headers={})
|
||||
assert response.status_code == 422
|
||||
test_client.headers = old_headers
|
||||
|
||||
|
||||
def test_incorrect_token(test_client: TestClient) -> None:
|
||||
"""Test logging in with missing token."""
|
||||
response = test_client.get("/api/v1/clients/", headers={"X-API-Token": "WRONG"})
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
def test_with_token(test_client: TestClient) -> None:
|
||||
"""Test with a valid token."""
|
||||
response = test_client.get("/api/v1/clients/")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_results"] == 0
|
||||
|
||||
|
||||
def test_create_client(test_client: TestClient) -> None:
|
||||
"""Test creating a client."""
|
||||
client_name = "test"
|
||||
client_publickey = make_test_key()
|
||||
create_response = create_client(test_client, client_name, client_publickey)
|
||||
assert create_response.status_code == 200
|
||||
response = test_client.get("/api/v1/clients")
|
||||
assert response.status_code == 200
|
||||
clients_result = response.json()
|
||||
clients = clients_result["clients"]
|
||||
assert isinstance(clients, list)
|
||||
client = clients[0]
|
||||
assert isinstance(client, dict)
|
||||
assert client.get("name") == client_name
|
||||
assert client.get("created_at") is not None
|
||||
|
||||
|
||||
def test_delete_client(test_client: TestClient) -> None:
|
||||
"""Test creating a client."""
|
||||
client_name = "test"
|
||||
create_response = create_client(
|
||||
test_client,
|
||||
client_name,
|
||||
)
|
||||
assert create_response.status_code == 200
|
||||
|
||||
resp = test_client.delete("/api/v1/clients/test")
|
||||
assert resp.status_code == 200
|
||||
|
||||
resp = test_client.get("/api/v1/clients/test")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_add_secret(test_client: TestClient) -> None:
|
||||
"""Test adding a secret to a client."""
|
||||
client_name = "test"
|
||||
client_publickey = make_test_key()
|
||||
create_response = create_client(
|
||||
test_client,
|
||||
client_name,
|
||||
client_publickey,
|
||||
)
|
||||
assert create_response.status_code == 200
|
||||
secret_name = "mysecret"
|
||||
secret_value = "shhhh"
|
||||
data = {"name": secret_name, "secret": secret_value, "description": "A test secret"}
|
||||
response = test_client.post("/api/v1/clients/test/secrets/", json=data)
|
||||
assert response.status_code == 200
|
||||
# Get it back
|
||||
get_response = test_client.get("/api/v1/clients/test/secrets/mysecret")
|
||||
assert get_response.status_code == 200
|
||||
secret_body = get_response.json()
|
||||
assert secret_body["name"] == data["name"]
|
||||
assert secret_body["secret"] == data["secret"]
|
||||
|
||||
|
||||
def test_delete_secret(test_client: TestClient) -> None:
|
||||
"""Test deleting a secret."""
|
||||
test_add_secret(test_client)
|
||||
resp = test_client.delete("/api/v1/clients/test/secrets/mysecret")
|
||||
assert resp.status_code == 200
|
||||
get_response = test_client.get("/api/v1/clients/test/secrets/mysecret")
|
||||
assert get_response.status_code == 404
|
||||
|
||||
|
||||
def test_put_add_secret(test_client: TestClient) -> None:
|
||||
"""Test adding secret via PUT."""
|
||||
# Use the test_create_client function to create a client.
|
||||
test_create_client(test_client)
|
||||
secret_name = "mysecret"
|
||||
secret_value = "shhhh"
|
||||
data = {"name": secret_name, "secret": secret_value, "description": None}
|
||||
response = test_client.put(
|
||||
"/api/v1/clients/test/secrets/mysecret",
|
||||
json={"value": secret_value},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
response_model = response.json()
|
||||
del response_model["created_at"]
|
||||
del response_model["updated_at"]
|
||||
assert response_model == data
|
||||
|
||||
|
||||
def test_put_update_secret(test_client: TestClient) -> None:
|
||||
"""Test updating a client secret."""
|
||||
test_add_secret(test_client)
|
||||
new_value = "itsasecret"
|
||||
update_response = test_client.put(
|
||||
"/api/v1/clients/test/secrets/mysecret",
|
||||
json={"value": new_value},
|
||||
)
|
||||
assert update_response.status_code == 200
|
||||
expected = {"name": "mysecret", "secret": new_value}
|
||||
response_model = update_response.json()
|
||||
|
||||
assert {
|
||||
"name": response_model["name"],
|
||||
"secret": response_model["secret"],
|
||||
} == expected
|
||||
# Ensure that the updated_at has been set.
|
||||
assert "updated_at" in response_model
|
||||
|
||||
|
||||
def test_audit_logging(test_client: TestClient) -> None:
|
||||
"""Test audit logging."""
|
||||
public_key = make_test_key()
|
||||
create_client_resp = create_client(test_client, "test", public_key)
|
||||
assert create_client_resp.status_code == 200
|
||||
secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"}
|
||||
for name, secret in secrets.items():
|
||||
add_resp = test_client.post(
|
||||
"/api/v1/clients/test/secrets/",
|
||||
json={"name": name, "secret": secret},
|
||||
)
|
||||
assert add_resp.status_code == 200
|
||||
|
||||
# Fetch the entire client.
|
||||
get_client_resp = test_client.get("/api/v1/clients/test")
|
||||
assert get_client_resp.status_code == 200
|
||||
|
||||
# Fetch the audit log
|
||||
audit_log_resp = test_client.get("/api/v1/audit/")
|
||||
assert audit_log_resp.status_code == 200
|
||||
audit_logs = audit_log_resp.json()
|
||||
assert len(audit_logs) > 0
|
||||
for entry in audit_logs:
|
||||
# Let's try to reassemble the objects
|
||||
audit_log = AuditView.model_validate(entry)
|
||||
assert audit_log is not None
|
||||
|
||||
|
||||
# def test_audit_log_filtering(
|
||||
# session: Session, test_client: TestClient
|
||||
# ) -> None:
|
||||
# """Test audit log filtering."""
|
||||
# # Create a lot of test data, but just manually.
|
||||
# audit_log_amount = 150
|
||||
# entries: list[AuditLog] = []
|
||||
# for i in range(audit_log_amount):
|
||||
# client_id = i % 5
|
||||
# entries.append(
|
||||
# AuditLog(
|
||||
# operation="TEST",
|
||||
# object_id=str(i),
|
||||
# client_name=f"client-{client_id}",
|
||||
# message="Test Message",
|
||||
# )
|
||||
# )
|
||||
|
||||
# session.add_all(entries)
|
||||
# session.commit()
|
||||
|
||||
# # This should have generated a lot of audit messages
|
||||
|
||||
# audit_path = "/api/v1/audit/"
|
||||
# audit_log_resp = test_client.get(audit_path)
|
||||
# assert audit_log_resp.status_code == 200
|
||||
# entries = audit_log_resp.json()
|
||||
# assert len(entries) == 100 # We get 100 at a time
|
||||
|
||||
# audit_log_resp = test_client.get(
|
||||
# audit_path, params={"offset": 100}
|
||||
# )
|
||||
# entries = audit_log_resp.json()
|
||||
# assert len(entries) == 52 # There should be 50 + the two requests we made
|
||||
|
||||
# # Try to get a specific client
|
||||
# # There should be 30 log entries for each client.
|
||||
# audit_log_resp = test_client.get(
|
||||
# audit_path, params={"filter_client": "client-1"}
|
||||
# )
|
||||
|
||||
# entries = audit_log_resp.json()
|
||||
# assert len(entries) == 30
|
||||
|
||||
|
||||
def test_secret_invalidation(test_client: TestClient) -> None:
|
||||
"""Test secret invalidation."""
|
||||
initial_key = make_test_key()
|
||||
create_client_resp = create_client(test_client, "test", initial_key)
|
||||
assert create_client_resp.status_code == 200
|
||||
secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"}
|
||||
for name, secret in secrets.items():
|
||||
add_resp = test_client.post(
|
||||
"/api/v1/clients/test/secrets/",
|
||||
json={"name": name, "secret": secret},
|
||||
)
|
||||
assert add_resp.status_code == 200
|
||||
|
||||
# Update the public-key. This should cause all secrets to be invalidated
|
||||
# and no longer associated with a client.
|
||||
new_key = make_test_key()
|
||||
update_resp = test_client.post(
|
||||
"/api/v1/clients/test/public-key",
|
||||
json={"public_key": new_key},
|
||||
)
|
||||
assert update_resp.status_code == 200
|
||||
|
||||
# Fetch the client. The list of secrets should be empty.
|
||||
get_resp = test_client.get("/api/v1/clients/test")
|
||||
assert get_resp.status_code == 200
|
||||
client = get_resp.json()
|
||||
secrets = client.get("secrets")
|
||||
assert bool(secrets) is False
|
||||
|
||||
|
||||
def test_client_default_policies(
|
||||
test_client: TestClient,
|
||||
) -> None:
|
||||
"""Test client policies."""
|
||||
public_key = make_test_key()
|
||||
resp = create_client(test_client, "test")
|
||||
assert resp.status_code == 200
|
||||
# Fetch policies, should return *
|
||||
resp = test_client.get("/api/v1/clients/test/policies/")
|
||||
assert resp.status_code == 200
|
||||
|
||||
policies = resp.json()
|
||||
|
||||
assert policies["sources"] == ["0.0.0.0/0", "::/0"]
|
||||
|
||||
|
||||
def test_client_policy_update_one(test_client: TestClient) -> None:
|
||||
"""Update client policy with single policy."""
|
||||
public_key = make_test_key()
|
||||
resp = create_client(test_client, "test", public_key)
|
||||
assert resp.status_code == 200
|
||||
|
||||
policy = ["192.0.2.1"]
|
||||
resp = test_client.put("/api/v1/clients/test/policies/", json={"sources": policy})
|
||||
assert resp.status_code == 200
|
||||
|
||||
resp = test_client.get("/api/v1/clients/test/policies/")
|
||||
assert resp.status_code == 200
|
||||
|
||||
policies = resp.json()
|
||||
|
||||
assert policies["sources"] == policy
|
||||
|
||||
|
||||
def test_client_policy_update_advanced(test_client: TestClient) -> None:
|
||||
"""Test other policy update scenarios."""
|
||||
public_key = make_test_key()
|
||||
resp = create_client(test_client, "test", public_key)
|
||||
assert resp.status_code == 200
|
||||
|
||||
policy = ["192.0.2.1", "198.18.0.0/24"]
|
||||
|
||||
resp = test_client.put("/api/v1/clients/test/policies/", json={"sources": policy})
|
||||
assert resp.status_code == 200
|
||||
|
||||
resp = test_client.get("/api/v1/clients/test/policies/")
|
||||
assert resp.status_code == 200
|
||||
|
||||
policies = resp.json()
|
||||
|
||||
assert "192.0.2.1" in policies["sources"]
|
||||
assert "198.18.0.0/24" in policies["sources"]
|
||||
|
||||
# Try to set it to something incorrect
|
||||
|
||||
policy = ["obviosly_wrong"]
|
||||
|
||||
resp = test_client.put("/api/v1/clients/test/policies/", json={"sources": policy})
|
||||
assert resp.status_code == 422
|
||||
# Check that the old value is still there
|
||||
|
||||
resp = test_client.get("/api/v1/clients/test/policies/")
|
||||
assert resp.status_code == 200
|
||||
|
||||
policies = resp.json()
|
||||
|
||||
assert "192.0.2.1" in policies["sources"]
|
||||
assert "198.18.0.0/24" in policies["sources"]
|
||||
|
||||
# Clear the policies
|
||||
#
|
||||
|
||||
|
||||
def test_client_policy_update_unset(test_client: TestClient) -> None:
|
||||
"""Test clearing the client policy."""
|
||||
public_key = make_test_key()
|
||||
resp = create_client(test_client, "test", public_key)
|
||||
assert resp.status_code == 200
|
||||
policy = ["192.0.2.1", "198.18.0.0/24"]
|
||||
|
||||
resp = test_client.put("/api/v1/clients/test/policies/", json={"sources": policy})
|
||||
assert resp.status_code == 200
|
||||
|
||||
policies = resp.json()
|
||||
|
||||
assert "192.0.2.1" in policies["sources"]
|
||||
assert "198.18.0.0/24" in policies["sources"]
|
||||
|
||||
# Now we clear the policies
|
||||
|
||||
resp = test_client.put("/api/v1/clients/test/policies/", json={"sources": []})
|
||||
assert resp.status_code == 200
|
||||
|
||||
policies = resp.json()
|
||||
|
||||
assert policies["sources"] == ["0.0.0.0/0", "::/0"]
|
||||
|
||||
|
||||
def test_client_update(test_client: TestClient) -> None:
|
||||
"""Test generic update of a client."""
|
||||
|
||||
public_key = make_test_key()
|
||||
resp = create_client(test_client, "test", public_key, "PRE")
|
||||
assert resp.status_code == 200
|
||||
resp = test_client.get("/api/v1/clients/test")
|
||||
assert resp.status_code == 200
|
||||
client_data = resp.json()
|
||||
assert client_data["description"] == "PRE"
|
||||
|
||||
# Update the description
|
||||
new_client_data = {
|
||||
"name": "test",
|
||||
"description": "POST",
|
||||
"public_key": client_data["public_key"],
|
||||
}
|
||||
resp = test_client.put("/api/v1/clients/test", json=new_client_data)
|
||||
assert resp.status_code == 200
|
||||
|
||||
client_data = resp.json()
|
||||
assert client_data["description"] == "POST"
|
||||
|
||||
resp = test_client.get("/api/v1/clients/test")
|
||||
assert resp.status_code == 200
|
||||
client_data = resp.json()
|
||||
assert client_data["description"] == "POST"
|
||||
|
||||
|
||||
def test_get_secret_list(test_client: TestClient) -> None:
|
||||
"""Test the secret to client map view."""
|
||||
# Make 4 clients
|
||||
for x in range(4):
|
||||
public_key = make_test_key()
|
||||
create_client(test_client, f"client-{x}", public_key)
|
||||
# Create a secret that only this client has.
|
||||
resp = test_client.put(
|
||||
f"/api/v1/clients/client-{x}/secrets/client-{x}", json={"value": "SECRET"}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
# Create a secret that all of them have.
|
||||
resp = test_client.put(
|
||||
f"/api/v1/clients/client-{x}/secrets/commonsecret", json={"value": "SECRET"}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
# Get the secret list
|
||||
resp = test_client.get("/api/v1/secrets/")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 5
|
||||
for entry in data:
|
||||
if entry["name"] == "commonsecret":
|
||||
assert len(entry["clients"]) == 4
|
||||
else:
|
||||
assert len(entry["clients"]) == 1
|
||||
assert entry["clients"][0] == entry["name"]
|
||||
|
||||
|
||||
def test_get_secret_clients(test_client: TestClient) -> None:
|
||||
"""Get the clients for a single secret."""
|
||||
for x in range(4):
|
||||
public_key = make_test_key()
|
||||
create_client(test_client, f"client-{x}", public_key)
|
||||
# Create a secret that every second of them have.
|
||||
if x % 2 == 1:
|
||||
continue
|
||||
resp = test_client.put(
|
||||
f"/api/v1/clients/client-{x}/secrets/commonsecret", json={"value": "SECRET"}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
resp = test_client.get("/api/v1/secrets/commonsecret")
|
||||
assert resp.status_code == 200
|
||||
|
||||
data = resp.json()
|
||||
assert data["name"] == "commonsecret"
|
||||
assert "client-0" in data["clients"]
|
||||
assert "client-1" not in data["clients"]
|
||||
assert len(data["clients"]) == 2
|
||||
|
||||
|
||||
def test_searching(test_client: TestClient) -> None:
|
||||
"""Test searching."""
|
||||
for x in range(4):
|
||||
# Create four clients
|
||||
create_client(test_client, f"client-{x}")
|
||||
|
||||
# Create one with a different name.
|
||||
create_client(test_client, "othername")
|
||||
|
||||
# Search for a specific one.
|
||||
resp = test_client.get("/api/v1/clients/", params={"name": "othername"})
|
||||
assert resp.status_code == 200
|
||||
result = resp.json()
|
||||
assert result["total_results"] == 1
|
||||
assert result["clients"][0]["name"] == "othername"
|
||||
client_id = result["clients"][0]["id"]
|
||||
|
||||
# Search by ID
|
||||
resp = test_client.get("/api/v1/clients/", params={"id": client_id})
|
||||
assert resp.status_code == 200
|
||||
result = resp.json()
|
||||
assert result["total_results"] == 1
|
||||
assert result["clients"][0]["name"] == "othername"
|
||||
|
||||
# Search for the four similarly named ones
|
||||
resp = test_client.get("/api/v1/clients/", params={"name__like": "client-%"})
|
||||
assert resp.status_code == 200
|
||||
result = resp.json()
|
||||
assert result["total_results"] == 4
|
||||
assert str(result["clients"][0]["name"]).startswith("client-")
|
||||
|
||||
|
||||
def test_operations_with_id(test_client: TestClient) -> None:
|
||||
"""Test operations using ID instead of name."""
|
||||
create_client(test_client, "test")
|
||||
resp = test_client.get("/api/v1/clients/")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
client = data["clients"][0]
|
||||
client_id = client["id"]
|
||||
resp = test_client.get(f"/api/v1/clients/{client_id}")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["name"] == "test"
|
||||
|
||||
|
||||
def test_write_audit_log(test_client: TestClient) -> None:
|
||||
"""Test writing to the audit log."""
|
||||
params = {
|
||||
"subsystem": "backend",
|
||||
"operation": "read",
|
||||
"message": "Test Message"
|
||||
}
|
||||
resp = test_client.post("/api/v1/audit", json=params)
|
||||
assert resp.status_code == 200
|
||||
|
||||
resp = test_client.get("/api/v1/audit")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
entry = data[0]
|
||||
for key, value in params.items():
|
||||
assert entry[key] == value
|
||||
1
tests/packages/sshd/__init__.py
Normal file
1
tests/packages/sshd/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
|
||||
183
tests/packages/sshd/conftest.py
Normal file
183
tests/packages/sshd/conftest.py
Normal file
@ -0,0 +1,183 @@
|
||||
import asyncio
|
||||
from typing import Any
|
||||
import pytest
|
||||
import uuid
|
||||
import asyncssh
|
||||
import tempfile
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from ipaddress import IPv4Network, IPv6Network
|
||||
|
||||
from sshecret_sshd.ssh_server import run_ssh_server
|
||||
from sshecret_sshd.settings import ClientRegistrationSettings
|
||||
|
||||
from .types import (
|
||||
Client,
|
||||
ClientKey,
|
||||
ClientRegistry,
|
||||
SshServerFixtureFun,
|
||||
SshServerFixture,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client_registry() -> ClientRegistry:
|
||||
clients = {}
|
||||
secrets = {}
|
||||
|
||||
async def add_client(
|
||||
name: str,
|
||||
secret_names: list[str] | None = None,
|
||||
policies: list[str] | None = None,
|
||||
) -> str:
|
||||
private_key = asyncssh.generate_private_key("ssh-rsa")
|
||||
public_key = private_key.export_public_key()
|
||||
clients[name] = ClientKey(name, private_key, public_key.decode().rstrip())
|
||||
secrets.update({(name, s): f"mocked-secret-{s}" for s in (secret_names or [])})
|
||||
return clients[name]
|
||||
|
||||
return {"clients": clients, "secrets": secrets, "add_client": add_client}
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
|
||||
backend = MagicMock()
|
||||
clients_data = client_registry["clients"]
|
||||
secrets_data = client_registry["secrets"]
|
||||
|
||||
async def get_client(name: str) -> Client | None:
|
||||
client_key = clients_data.get(name)
|
||||
if client_key:
|
||||
response_model = Client(
|
||||
id=uuid.uuid4(),
|
||||
name=name,
|
||||
description=f"Mock client {name}",
|
||||
public_key=client_key.public_key,
|
||||
secrets=[s for (c, s) in secrets_data if c == name],
|
||||
policies=[IPv4Network("0.0.0.0/0"), IPv6Network("::/0")],
|
||||
)
|
||||
return response_model
|
||||
return None
|
||||
|
||||
async def get_client_secret(name: str, secret_name: str) -> str | None:
|
||||
secret = secrets_data.get((name, secret_name), None)
|
||||
return secret
|
||||
|
||||
async def create_client(name: str, public_key: str) -> None:
|
||||
"""Create client.
|
||||
|
||||
This only works if you register a client called template first.
|
||||
Otherwise we can't test this...
|
||||
"""
|
||||
if "template" not in clients_data:
|
||||
raise RuntimeError(
|
||||
"Error, must have a client called template for this to work."
|
||||
)
|
||||
clients_data[name] = clients_data["template"]
|
||||
for secret_key, secret in secrets_data.items():
|
||||
s_client, secret_name = secret_key
|
||||
if s_client != "template":
|
||||
continue
|
||||
secrets_data[(name, secret_name)] = secret
|
||||
|
||||
async def write_audit(*args, **kwargs):
|
||||
"""Write audit mock."""
|
||||
return None
|
||||
|
||||
backend.get_client = AsyncMock(side_effect=get_client)
|
||||
backend.get_client_secret = AsyncMock(side_effect=get_client_secret)
|
||||
backend.create_client = AsyncMock(side_effect=create_client)
|
||||
|
||||
audit = MagicMock()
|
||||
audit.write_async = AsyncMock(side_effect=write_audit)
|
||||
backend.audit = MagicMock(return_value=audit)
|
||||
|
||||
return backend
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def ssh_server(
|
||||
mock_backend: MagicMock, unused_tcp_port: int
|
||||
) -> SshServerFixtureFun:
|
||||
port = unused_tcp_port
|
||||
|
||||
private_key = asyncssh.generate_private_key("ssh-ed25519")
|
||||
key_str = private_key.export_private_key()
|
||||
with tempfile.NamedTemporaryFile("w+", delete=True) as key_file:
|
||||
key_file.write(key_str.decode())
|
||||
key_file.flush()
|
||||
|
||||
registration_settings = ClientRegistrationSettings(
|
||||
enabled=True, allow_from=[IPv4Network("0.0.0.0/0")]
|
||||
)
|
||||
server = await run_ssh_server(
|
||||
backend=mock_backend,
|
||||
listen_address="localhost",
|
||||
port=port,
|
||||
keys=[key_file.name],
|
||||
registration=registration_settings,
|
||||
enable_ping_command=True,
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
yield server, port
|
||||
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def ssh_command_runner(ssh_server: SshServerFixture, client_registry: ClientRegistry):
|
||||
"""Run a single command.
|
||||
|
||||
Tricky typing!
|
||||
"""
|
||||
_, port = ssh_server
|
||||
|
||||
async def run_command_as(name: str, command: str):
|
||||
client_key = client_registry["clients"][name]
|
||||
conn = await asyncssh.connect(
|
||||
"127.0.0.1",
|
||||
port=port,
|
||||
username=name,
|
||||
client_keys=[client_key.private_key],
|
||||
known_hosts=None,
|
||||
)
|
||||
try:
|
||||
result = await conn.run(command)
|
||||
return result
|
||||
finally:
|
||||
conn.close()
|
||||
await conn.wait_closed()
|
||||
|
||||
return run_command_as
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def ssh_session(ssh_server: SshServerFixture, client_registry: ClientRegistry):
|
||||
"""Yield an interactive session."""
|
||||
|
||||
_, port = ssh_server
|
||||
|
||||
@asynccontextmanager
|
||||
async def run_process_as(name: str, command: str, client: str | None = None):
|
||||
if not client:
|
||||
client = name
|
||||
client_key = client_registry["clients"][client]
|
||||
conn = await asyncssh.connect(
|
||||
"127.0.0.1",
|
||||
port=port,
|
||||
username=name,
|
||||
client_keys=[client_key.private_key],
|
||||
known_hosts=None,
|
||||
)
|
||||
try:
|
||||
async with conn.create_process(command) as process:
|
||||
yield process
|
||||
|
||||
finally:
|
||||
conn.close()
|
||||
await conn.wait_closed()
|
||||
|
||||
return run_process_as
|
||||
33
tests/packages/sshd/test_get_secret.py
Normal file
33
tests/packages/sshd/test_get_secret.py
Normal file
@ -0,0 +1,33 @@
|
||||
"""Test get secret."""
|
||||
|
||||
import pytest
|
||||
|
||||
from .types import ClientRegistry, CommandRunner
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_secret(
|
||||
ssh_command_runner: CommandRunner, client_registry: ClientRegistry
|
||||
) -> None:
|
||||
"""Test that we can get a secret."""
|
||||
|
||||
await client_registry["add_client"]("test-client", ["mysecret"])
|
||||
|
||||
result = await ssh_command_runner("test-client", "get_secret mysecret")
|
||||
|
||||
assert result.stdout is not None
|
||||
assert isinstance(result.stdout, str)
|
||||
assert result.stdout.rstrip() == "mocked-secret-mysecret"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_secret_name(
|
||||
ssh_command_runner: CommandRunner, client_registry: ClientRegistry
|
||||
) -> None:
|
||||
"""Test getting an invalid secret name."""
|
||||
|
||||
await client_registry["add_client"]("test-client")
|
||||
|
||||
result = await ssh_command_runner("test-client", "get_secret mysecret")
|
||||
assert result.exit_status == 1
|
||||
assert result.stderr == "Error: No secret available with the given name."
|
||||
18
tests/packages/sshd/test_ping.py
Normal file
18
tests/packages/sshd/test_ping.py
Normal file
@ -0,0 +1,18 @@
|
||||
import pytest
|
||||
|
||||
from .types import ClientRegistry, CommandRunner
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ping_command(
|
||||
ssh_command_runner: CommandRunner, client_registry: ClientRegistry
|
||||
) -> None:
|
||||
# Register a test client with default policies and no secrets
|
||||
await client_registry["add_client"]("test-pinger")
|
||||
|
||||
result = await ssh_command_runner("test-pinger", "ping")
|
||||
|
||||
assert result.exit_status == 0
|
||||
assert result.stdout is not None
|
||||
assert isinstance(result.stdout, str)
|
||||
assert result.stdout.rstrip() == "PONG"
|
||||
40
tests/packages/sshd/test_register.py
Normal file
40
tests/packages/sshd/test_register.py
Normal file
@ -0,0 +1,40 @@
|
||||
"""Test registration."""
|
||||
|
||||
import pytest
|
||||
|
||||
from .types import ClientRegistry, CommandRunner, ProcessRunner
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_client(
|
||||
ssh_session: ProcessRunner,
|
||||
ssh_command_runner: CommandRunner,
|
||||
client_registry: ClientRegistry,
|
||||
) -> None:
|
||||
"""Test client registration."""
|
||||
await client_registry["add_client"]("template", ["testsecret"])
|
||||
|
||||
public_key = client_registry["clients"]["template"].public_key.rstrip() + "\n"
|
||||
|
||||
async with ssh_session("newclient", "register", "template") as session:
|
||||
maxlines = 10
|
||||
linenum = 0
|
||||
found = False
|
||||
while linenum < maxlines:
|
||||
line = await session.stdout.readline()
|
||||
if "Enter public key" in line:
|
||||
found = True
|
||||
break
|
||||
|
||||
assert found is True
|
||||
session.stdin.write(public_key)
|
||||
result = await session.stdout.readline()
|
||||
assert "OK" in result
|
||||
|
||||
# Test that we can connect
|
||||
|
||||
result = await ssh_command_runner("newclient", "get_secret testsecret")
|
||||
|
||||
assert result.stdout is not None
|
||||
assert isinstance(result.stdout, str)
|
||||
assert result.stdout.rstrip() == "mocked-secret-testsecret"
|
||||
63
tests/packages/sshd/types.py
Normal file
63
tests/packages/sshd/types.py
Normal file
@ -0,0 +1,63 @@
|
||||
"""Types for the various test properties."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass, field
|
||||
from ipaddress import IPv4Network, IPv6Network
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from typing import Any, Protocol, TypedDict, AsyncContextManager
|
||||
import asyncssh
|
||||
|
||||
SshServerFixture = tuple[str, int]
|
||||
SshServerFixtureFun = AsyncGenerator[tuple[asyncssh.SSHAcceptor, int], None]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Client:
|
||||
"""Mock client."""
|
||||
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
description: str | None
|
||||
public_key: str
|
||||
secrets: list[str]
|
||||
policies: list[IPv4Network | IPv6Network]
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
updated_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClientKey:
|
||||
name: str
|
||||
private_key: asyncssh.SSHKey
|
||||
public_key: str
|
||||
|
||||
|
||||
class AddClientFun(Protocol):
|
||||
"""Add client function."""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
name: str,
|
||||
secret_names: list[str] | None = None,
|
||||
policies: list[str] | None = None,
|
||||
) -> Awaitable[str]: ...
|
||||
|
||||
|
||||
class ProcessRunner(Protocol):
|
||||
"""Process runner typing."""
|
||||
|
||||
def __call__(
|
||||
self, name: str, command: str, client: str | None = None
|
||||
) -> AsyncContextManager[asyncssh.SSHClientProcess[Any]]: ...
|
||||
|
||||
|
||||
class ClientRegistry(TypedDict):
|
||||
"""Client registry typing."""
|
||||
|
||||
clients: dict[str, ClientKey]
|
||||
secrets: dict[tuple[str, str], str]
|
||||
add_client: AddClientFun
|
||||
|
||||
|
||||
CommandRunner = Callable[[str, str], Awaitable[asyncssh.SSHCompletedProcess]]
|
||||
Reference in New Issue
Block a user