Centralize testing
This commit is contained in:
1
tests/packages/sshd/__init__.py
Normal file
1
tests/packages/sshd/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
|
||||
183
tests/packages/sshd/conftest.py
Normal file
183
tests/packages/sshd/conftest.py
Normal file
@ -0,0 +1,183 @@
|
||||
import asyncio
|
||||
from typing import Any
|
||||
import pytest
|
||||
import uuid
|
||||
import asyncssh
|
||||
import tempfile
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from ipaddress import IPv4Network, IPv6Network
|
||||
|
||||
from sshecret_sshd.ssh_server import run_ssh_server
|
||||
from sshecret_sshd.settings import ClientRegistrationSettings
|
||||
|
||||
from .types import (
|
||||
Client,
|
||||
ClientKey,
|
||||
ClientRegistry,
|
||||
SshServerFixtureFun,
|
||||
SshServerFixture,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client_registry() -> ClientRegistry:
|
||||
clients = {}
|
||||
secrets = {}
|
||||
|
||||
async def add_client(
|
||||
name: str,
|
||||
secret_names: list[str] | None = None,
|
||||
policies: list[str] | None = None,
|
||||
) -> str:
|
||||
private_key = asyncssh.generate_private_key("ssh-rsa")
|
||||
public_key = private_key.export_public_key()
|
||||
clients[name] = ClientKey(name, private_key, public_key.decode().rstrip())
|
||||
secrets.update({(name, s): f"mocked-secret-{s}" for s in (secret_names or [])})
|
||||
return clients[name]
|
||||
|
||||
return {"clients": clients, "secrets": secrets, "add_client": add_client}
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
|
||||
backend = MagicMock()
|
||||
clients_data = client_registry["clients"]
|
||||
secrets_data = client_registry["secrets"]
|
||||
|
||||
async def get_client(name: str) -> Client | None:
|
||||
client_key = clients_data.get(name)
|
||||
if client_key:
|
||||
response_model = Client(
|
||||
id=uuid.uuid4(),
|
||||
name=name,
|
||||
description=f"Mock client {name}",
|
||||
public_key=client_key.public_key,
|
||||
secrets=[s for (c, s) in secrets_data if c == name],
|
||||
policies=[IPv4Network("0.0.0.0/0"), IPv6Network("::/0")],
|
||||
)
|
||||
return response_model
|
||||
return None
|
||||
|
||||
async def get_client_secret(name: str, secret_name: str) -> str | None:
|
||||
secret = secrets_data.get((name, secret_name), None)
|
||||
return secret
|
||||
|
||||
async def create_client(name: str, public_key: str) -> None:
|
||||
"""Create client.
|
||||
|
||||
This only works if you register a client called template first.
|
||||
Otherwise we can't test this...
|
||||
"""
|
||||
if "template" not in clients_data:
|
||||
raise RuntimeError(
|
||||
"Error, must have a client called template for this to work."
|
||||
)
|
||||
clients_data[name] = clients_data["template"]
|
||||
for secret_key, secret in secrets_data.items():
|
||||
s_client, secret_name = secret_key
|
||||
if s_client != "template":
|
||||
continue
|
||||
secrets_data[(name, secret_name)] = secret
|
||||
|
||||
async def write_audit(*args, **kwargs):
|
||||
"""Write audit mock."""
|
||||
return None
|
||||
|
||||
backend.get_client = AsyncMock(side_effect=get_client)
|
||||
backend.get_client_secret = AsyncMock(side_effect=get_client_secret)
|
||||
backend.create_client = AsyncMock(side_effect=create_client)
|
||||
|
||||
audit = MagicMock()
|
||||
audit.write_async = AsyncMock(side_effect=write_audit)
|
||||
backend.audit = MagicMock(return_value=audit)
|
||||
|
||||
return backend
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def ssh_server(
|
||||
mock_backend: MagicMock, unused_tcp_port: int
|
||||
) -> SshServerFixtureFun:
|
||||
port = unused_tcp_port
|
||||
|
||||
private_key = asyncssh.generate_private_key("ssh-ed25519")
|
||||
key_str = private_key.export_private_key()
|
||||
with tempfile.NamedTemporaryFile("w+", delete=True) as key_file:
|
||||
key_file.write(key_str.decode())
|
||||
key_file.flush()
|
||||
|
||||
registration_settings = ClientRegistrationSettings(
|
||||
enabled=True, allow_from=[IPv4Network("0.0.0.0/0")]
|
||||
)
|
||||
server = await run_ssh_server(
|
||||
backend=mock_backend,
|
||||
listen_address="localhost",
|
||||
port=port,
|
||||
keys=[key_file.name],
|
||||
registration=registration_settings,
|
||||
enable_ping_command=True,
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
yield server, port
|
||||
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def ssh_command_runner(ssh_server: SshServerFixture, client_registry: ClientRegistry):
|
||||
"""Run a single command.
|
||||
|
||||
Tricky typing!
|
||||
"""
|
||||
_, port = ssh_server
|
||||
|
||||
async def run_command_as(name: str, command: str):
|
||||
client_key = client_registry["clients"][name]
|
||||
conn = await asyncssh.connect(
|
||||
"127.0.0.1",
|
||||
port=port,
|
||||
username=name,
|
||||
client_keys=[client_key.private_key],
|
||||
known_hosts=None,
|
||||
)
|
||||
try:
|
||||
result = await conn.run(command)
|
||||
return result
|
||||
finally:
|
||||
conn.close()
|
||||
await conn.wait_closed()
|
||||
|
||||
return run_command_as
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def ssh_session(ssh_server: SshServerFixture, client_registry: ClientRegistry):
|
||||
"""Yield an interactive session."""
|
||||
|
||||
_, port = ssh_server
|
||||
|
||||
@asynccontextmanager
|
||||
async def run_process_as(name: str, command: str, client: str | None = None):
|
||||
if not client:
|
||||
client = name
|
||||
client_key = client_registry["clients"][client]
|
||||
conn = await asyncssh.connect(
|
||||
"127.0.0.1",
|
||||
port=port,
|
||||
username=name,
|
||||
client_keys=[client_key.private_key],
|
||||
known_hosts=None,
|
||||
)
|
||||
try:
|
||||
async with conn.create_process(command) as process:
|
||||
yield process
|
||||
|
||||
finally:
|
||||
conn.close()
|
||||
await conn.wait_closed()
|
||||
|
||||
return run_process_as
|
||||
33
tests/packages/sshd/test_get_secret.py
Normal file
33
tests/packages/sshd/test_get_secret.py
Normal file
@ -0,0 +1,33 @@
|
||||
"""Test get secret."""
|
||||
|
||||
import pytest
|
||||
|
||||
from .types import ClientRegistry, CommandRunner
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_secret(
|
||||
ssh_command_runner: CommandRunner, client_registry: ClientRegistry
|
||||
) -> None:
|
||||
"""Test that we can get a secret."""
|
||||
|
||||
await client_registry["add_client"]("test-client", ["mysecret"])
|
||||
|
||||
result = await ssh_command_runner("test-client", "get_secret mysecret")
|
||||
|
||||
assert result.stdout is not None
|
||||
assert isinstance(result.stdout, str)
|
||||
assert result.stdout.rstrip() == "mocked-secret-mysecret"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_secret_name(
|
||||
ssh_command_runner: CommandRunner, client_registry: ClientRegistry
|
||||
) -> None:
|
||||
"""Test getting an invalid secret name."""
|
||||
|
||||
await client_registry["add_client"]("test-client")
|
||||
|
||||
result = await ssh_command_runner("test-client", "get_secret mysecret")
|
||||
assert result.exit_status == 1
|
||||
assert result.stderr == "Error: No secret available with the given name."
|
||||
18
tests/packages/sshd/test_ping.py
Normal file
18
tests/packages/sshd/test_ping.py
Normal file
@ -0,0 +1,18 @@
|
||||
import pytest
|
||||
|
||||
from .types import ClientRegistry, CommandRunner
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ping_command(
|
||||
ssh_command_runner: CommandRunner, client_registry: ClientRegistry
|
||||
) -> None:
|
||||
# Register a test client with default policies and no secrets
|
||||
await client_registry["add_client"]("test-pinger")
|
||||
|
||||
result = await ssh_command_runner("test-pinger", "ping")
|
||||
|
||||
assert result.exit_status == 0
|
||||
assert result.stdout is not None
|
||||
assert isinstance(result.stdout, str)
|
||||
assert result.stdout.rstrip() == "PONG"
|
||||
40
tests/packages/sshd/test_register.py
Normal file
40
tests/packages/sshd/test_register.py
Normal file
@ -0,0 +1,40 @@
|
||||
"""Test registration."""
|
||||
|
||||
import pytest
|
||||
|
||||
from .types import ClientRegistry, CommandRunner, ProcessRunner
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_client(
|
||||
ssh_session: ProcessRunner,
|
||||
ssh_command_runner: CommandRunner,
|
||||
client_registry: ClientRegistry,
|
||||
) -> None:
|
||||
"""Test client registration."""
|
||||
await client_registry["add_client"]("template", ["testsecret"])
|
||||
|
||||
public_key = client_registry["clients"]["template"].public_key.rstrip() + "\n"
|
||||
|
||||
async with ssh_session("newclient", "register", "template") as session:
|
||||
maxlines = 10
|
||||
linenum = 0
|
||||
found = False
|
||||
while linenum < maxlines:
|
||||
line = await session.stdout.readline()
|
||||
if "Enter public key" in line:
|
||||
found = True
|
||||
break
|
||||
|
||||
assert found is True
|
||||
session.stdin.write(public_key)
|
||||
result = await session.stdout.readline()
|
||||
assert "OK" in result
|
||||
|
||||
# Test that we can connect
|
||||
|
||||
result = await ssh_command_runner("newclient", "get_secret testsecret")
|
||||
|
||||
assert result.stdout is not None
|
||||
assert isinstance(result.stdout, str)
|
||||
assert result.stdout.rstrip() == "mocked-secret-testsecret"
|
||||
63
tests/packages/sshd/types.py
Normal file
63
tests/packages/sshd/types.py
Normal file
@ -0,0 +1,63 @@
|
||||
"""Types for the various test properties."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass, field
|
||||
from ipaddress import IPv4Network, IPv6Network
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from typing import Any, Protocol, TypedDict, AsyncContextManager
|
||||
import asyncssh
|
||||
|
||||
SshServerFixture = tuple[str, int]
|
||||
SshServerFixtureFun = AsyncGenerator[tuple[asyncssh.SSHAcceptor, int], None]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Client:
|
||||
"""Mock client."""
|
||||
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
description: str | None
|
||||
public_key: str
|
||||
secrets: list[str]
|
||||
policies: list[IPv4Network | IPv6Network]
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
updated_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClientKey:
|
||||
name: str
|
||||
private_key: asyncssh.SSHKey
|
||||
public_key: str
|
||||
|
||||
|
||||
class AddClientFun(Protocol):
|
||||
"""Add client function."""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
name: str,
|
||||
secret_names: list[str] | None = None,
|
||||
policies: list[str] | None = None,
|
||||
) -> Awaitable[str]: ...
|
||||
|
||||
|
||||
class ProcessRunner(Protocol):
|
||||
"""Process runner typing."""
|
||||
|
||||
def __call__(
|
||||
self, name: str, command: str, client: str | None = None
|
||||
) -> AsyncContextManager[asyncssh.SSHClientProcess[Any]]: ...
|
||||
|
||||
|
||||
class ClientRegistry(TypedDict):
|
||||
"""Client registry typing."""
|
||||
|
||||
clients: dict[str, ClientKey]
|
||||
secrets: dict[tuple[str, str], str]
|
||||
add_client: AddClientFun
|
||||
|
||||
|
||||
CommandRunner = Callable[[str, str], Awaitable[asyncssh.SSHCompletedProcess]]
|
||||
Reference in New Issue
Block a user