Complete sshd package with tests

This commit is contained in:
2025-05-10 08:27:16 +02:00
parent 3719a2611d
commit 4f970a3f71
12 changed files with 472 additions and 103 deletions

View File

@ -0,0 +1 @@

View File

@ -0,0 +1,160 @@
import asyncio
import pytest
import uuid
import asyncssh
import tempfile
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock
from ipaddress import IPv4Network, IPv6Network
from sshecret_sshd.ssh_server import run_ssh_server
from sshecret_sshd.settings import ClientRegistrationSettings
from .types import Client, ClientKey, ClientRegistry, SshServerFixtureFun, SshServerFixture
@pytest.fixture(scope="function")
def client_registry() -> ClientRegistry:
clients = {}
secrets = {}
async def add_client(name: str, secret_names: list[str]|None=None, policies: list[str]|None=None) -> str:
private_key = asyncssh.generate_private_key('ssh-rsa')
public_key = private_key.export_public_key()
clients[name] = ClientKey(name, private_key, public_key.decode().rstrip())
secrets.update({(name, s): f"mocked-secret-{s}" for s in (secret_names or [])})
return clients[name]
return {"clients": clients, "secrets": secrets, "add_client": add_client}
@pytest.fixture(scope="function")
async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
backend = MagicMock()
clients_data = client_registry['clients']
secrets_data = client_registry['secrets']
async def get_client(name: str) -> Client | None:
client_key = clients_data.get(name)
if client_key:
response_model = Client(
id=uuid.uuid4(),
name=name,
description=f"Mock client {name}",
public_key=client_key.public_key,
secrets=[s for (c, s) in secrets_data if c == name],
policies=[IPv4Network('0.0.0.0/0'), IPv6Network('::/0')],
)
return response_model
return None
async def get_client_secret(name: str, secret_name: str) -> str | None:
secret = secrets_data.get((name, secret_name), None)
return secret
async def create_client(name: str, public_key: str) -> None:
"""Create client.
This only works if you register a client called template first.
Otherwise we can't test this...
"""
if "template" not in clients_data:
raise RuntimeError("Error, must have a client called template for this to work.")
clients_data[name] = clients_data["template"]
for secret_key, secret in secrets_data.items():
s_client, secret_name = secret_key
if s_client != "template":
continue
secrets_data[(name, secret_name)] = secret
backend.get_client = AsyncMock(side_effect=get_client)
backend.get_client_secret = AsyncMock(side_effect=get_client_secret)
backend.create_client = AsyncMock(side_effect=create_client)
# Make sure backend.audit(...) returns the audit mock
audit = MagicMock()
audit.write = MagicMock()
backend.audit = MagicMock(return_value=audit)
return backend
@pytest.fixture(scope="function")
async def ssh_server(mock_backend: MagicMock, unused_tcp_port: int) -> SshServerFixtureFun:
port = unused_tcp_port
private_key = asyncssh.generate_private_key('ssh-ed25519')
key_str = private_key.export_private_key()
with tempfile.NamedTemporaryFile('w+', delete=True) as key_file:
key_file.write(key_str.decode())
key_file.flush()
registration_settings = ClientRegistrationSettings(enabled=True, allow_from=[IPv4Network("0.0.0.0/0")])
server = await run_ssh_server(
backend=mock_backend,
listen_address="localhost",
port=port,
keys=[key_file.name],
registration=registration_settings,
enable_ping_command=True,
)
await asyncio.sleep(0.1)
yield server, port
server.close()
await server.wait_closed()
@pytest.fixture(scope="function")
def ssh_command_runner(ssh_server: SshServerFixture, client_registry: ClientRegistry):
"""Run a single command.
Tricky typing!
"""
_, port = ssh_server
async def run_command_as(name: str, command: str):
client_key = client_registry['clients'][name]
conn = await asyncssh.connect(
"127.0.0.1",
port=port,
username=name,
client_keys=[client_key.private_key],
known_hosts=None,
)
try:
result = await conn.run(command)
return result
finally:
conn.close()
await conn.wait_closed()
return run_command_as
@pytest.fixture(scope="function")
def ssh_session(ssh_server: SshServerFixture, client_registry: ClientRegistry):
"""Yield an interactive session."""
_, port = ssh_server
@asynccontextmanager
async def run_process_as(name: str, command: str, client: str | None = None):
if not client:
client = name
client_key = client_registry["clients"][client]
conn = await asyncssh.connect(
"127.0.0.1",
port=port,
username=name,
client_keys=[client_key.private_key],
known_hosts=None,
)
try:
async with conn.create_process(command) as process:
yield process
finally:
conn.close()
await conn.wait_closed()
return run_process_as

View File

@ -0,0 +1,28 @@
"""Test get secret."""
import pytest
from .types import ClientRegistry, CommandRunner
@pytest.mark.asyncio
async def test_get_secret(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None:
"""Test that we can get a secret."""
await client_registry['add_client']("test-client", ["mysecret"])
result = await ssh_command_runner("test-client", "get_secret mysecret")
assert result.stdout is not None
assert isinstance(result.stdout, str)
assert result.stdout.rstrip() == "mocked-secret-mysecret"
@pytest.mark.asyncio
async def test_invalid_secret_name(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None:
"""Test getting an invalid secret name."""
await client_registry['add_client']("test-client")
result = await ssh_command_runner("test-client", "get_secret mysecret")
assert result.exit_status == 1
assert result.stderr == "Error: No secret available with the given name."

View File

@ -0,0 +1,15 @@
import pytest
from .types import ClientRegistry, CommandRunner
@pytest.mark.asyncio
async def test_ping_command(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None:
# Register a test client with default policies and no secrets
await client_registry['add_client']('test-pinger')
result = await ssh_command_runner("test-pinger", "ping")
assert result.exit_status == 0
assert result.stdout is not None
assert isinstance(result.stdout, str)
assert result.stdout.rstrip() == "PONG"

View File

@ -0,0 +1,34 @@
"""Test registration."""
import pytest
from .types import ClientRegistry, CommandRunner, ProcessRunner
@pytest.mark.asyncio
async def test_register_client(ssh_session: ProcessRunner, ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None:
"""Test client registration."""
await client_registry["add_client"]("template", ["testsecret"])
public_key = client_registry["clients"]["template"].public_key.rstrip() + "\n"
async with ssh_session("newclient", "register", "template") as session:
maxlines = 10
l = 0
found = False
while l < maxlines:
line = await session.stdout.readline()
if "Enter public key" in line:
found = True
break
assert found is True
session.stdin.write(public_key)
result = await session.stdout.readline()
assert "OK" in result
# Test that we can connect
result = await ssh_command_runner("newclient", "get_secret testsecret")
assert result.stdout is not None
assert isinstance(result.stdout, str)
assert result.stdout.rstrip() == "mocked-secret-testsecret"

View File

@ -0,0 +1,56 @@
"""Types for the various test properties."""
import uuid
from datetime import datetime
from dataclasses import dataclass, field
from ipaddress import IPv4Network, IPv6Network
from collections.abc import AsyncGenerator, Awaitable, Callable, AsyncIterator
from typing import Any, Protocol, TypedDict, AsyncContextManager
import asyncssh
SshServerFixture = tuple[str, int]
SshServerFixtureFun = AsyncGenerator[tuple[asyncssh.SSHAcceptor, int], None]
@dataclass
class Client:
"""Mock client."""
id: uuid.UUID
name: str
description: str | None
public_key: str
secrets: list[str]
policies: list[IPv4Network | IPv6Network]
created_at: datetime = field(default_factory=datetime.now)
updated_at: datetime = field(default_factory=datetime.now)
@dataclass
class ClientKey:
name: str
private_key: asyncssh.SSHKey
public_key: str
class AddClientFun(Protocol):
"""Add client function."""
def __call__(self, name: str, secret_names: list[str] | None = None, policies: list[str] | None = None) -> Awaitable[str]: ...
class ProcessRunner(Protocol):
"""Process runner typing."""
def __call__(self, name: str, command: str, client: str | None = None) -> AsyncContextManager[asyncssh.SSHClientProcess[Any]]:
...
class ClientRegistry(TypedDict):
"""Client registry typing."""
clients: dict[str, ClientKey]
secrets: dict[tuple[str, str], str]
add_client: AddClientFun
CommandRunner = Callable[[str, str], Awaitable[asyncssh.SSHCompletedProcess]]