Files
sshecret/packages/sshecret-sshd/tests/conftest.py
2025-05-10 08:29:58 +02:00

180 lines
5.4 KiB
Python

import asyncio
import pytest
import uuid
import asyncssh
import tempfile
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock
from ipaddress import IPv4Network, IPv6Network
from sshecret_sshd.ssh_server import run_ssh_server
from sshecret_sshd.settings import ClientRegistrationSettings
from .types import (
Client,
ClientKey,
ClientRegistry,
SshServerFixtureFun,
SshServerFixture,
)
@pytest.fixture(scope="function")
def client_registry() -> ClientRegistry:
    """Provide an in-memory registry of test clients and their secrets.

    The returned mapping has three entries:

    * ``"clients"``: client name -> ``ClientKey`` (name, private key,
      OpenSSH public key string).
    * ``"secrets"``: ``(client name, secret name)`` -> mocked secret value.
    * ``"add_client"``: coroutine function registering a new client.
    """
    clients: dict[str, ClientKey] = {}
    secrets: dict[tuple[str, str], str] = {}

    async def add_client(
        name: str,
        secret_names: list[str] | None = None,
        policies: list[str] | None = None,
    ) -> ClientKey:
        """Register a client with a freshly generated RSA key pair.

        Args:
            name: Client name; stored as the key of the ``clients`` dict.
            secret_names: Secrets to grant the client. Each one is stored with
                the value ``"mocked-secret-<secret name>"``.
            policies: Currently ignored; accepted only for call-site
                compatibility.

        Returns:
            The ``ClientKey`` stored for ``name`` (overwriting any previous
            entry with the same name).
        """
        private_key = asyncssh.generate_private_key("ssh-rsa")
        public_key = private_key.export_public_key()
        # export_public_key() returns bytes with a trailing newline.
        clients[name] = ClientKey(name, private_key, public_key.decode().rstrip())
        secrets.update({(name, s): f"mocked-secret-{s}" for s in (secret_names or [])})
        return clients[name]

    return {"clients": clients, "secrets": secrets, "add_client": add_client}
@pytest.fixture(scope="function")
async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
    """Build a mocked backend driven by the ``client_registry`` fixture.

    The mock exposes async ``get_client``, ``get_client_secret`` and
    ``create_client`` methods. They read and write the registry's
    ``"clients"`` and ``"secrets"`` dictionaries. It also exposes a plain
    ``audit(...)`` call that returns a mock with a ``write`` attribute.
    """
    backend = MagicMock()
    clients_data = client_registry["clients"]
    secrets_data = client_registry["secrets"]

    async def get_client(name: str) -> Client | None:
        """Return a ``Client`` model for ``name``, or ``None`` if unknown.

        A new random ``id`` is generated on every call. Every client is
        allowed from any IPv4 and IPv6 address.
        """
        client_key = clients_data.get(name)
        if not client_key:
            return None
        return Client(
            id=uuid.uuid4(),
            name=name,
            description=f"Mock client {name}",
            public_key=client_key.public_key,
            secrets=[s for (c, s) in secrets_data if c == name],
            policies=[IPv4Network("0.0.0.0/0"), IPv6Network("::/0")],
        )

    async def get_client_secret(name: str, secret_name: str) -> str | None:
        """Return the mocked secret ``secret_name`` of ``name``, or ``None``."""
        return secrets_data.get((name, secret_name), None)

    async def create_client(name: str, public_key: str) -> None:
        """Create client ``name`` as a copy of the ``template`` client.

        This only works if a client called ``template`` was registered first.
        Without it there is no key material to reuse. The supplied
        ``public_key`` is ignored: the new client reuses the template's
        ``ClientKey`` and receives copies of all of the template's secrets.

        Raises:
            RuntimeError: If no ``template`` client has been registered.
        """
        if "template" not in clients_data:
            raise RuntimeError(
                "Error, must have a client called template for this to work."
            )
        clients_data[name] = clients_data["template"]
        # Collect the copies before inserting them. Adding keys to
        # secrets_data while iterating over it would raise
        # "RuntimeError: dictionary changed size during iteration".
        template_secrets = {
            (name, secret_name): secret
            for (owner, secret_name), secret in secrets_data.items()
            if owner == "template"
        }
        secrets_data.update(template_secrets)

    backend.get_client = AsyncMock(side_effect=get_client)
    backend.get_client_secret = AsyncMock(side_effect=get_client_secret)
    backend.create_client = AsyncMock(side_effect=create_client)

    # Make sure backend.audit(...) returns the audit mock
    audit = MagicMock()
    audit.write = MagicMock()
    backend.audit = MagicMock(return_value=audit)
    return backend
@pytest.fixture(scope="function")
async def ssh_server(
    mock_backend: MagicMock, unused_tcp_port: int
) -> SshServerFixtureFun:
    """Run the SSH server under test on ``unused_tcp_port``.

    Setup:

    * A throwaway ed25519 host key is written to a temporary file, whose path
      is passed to ``run_ssh_server``.
    * Client registration is enabled for any IPv4 source address.
    * The ping command is enabled.
    * ``mock_backend`` serves as the backend.

    Yields ``(server, port)``. On teardown the server is closed and awaited.
    """
    host_key = asyncssh.generate_private_key("ssh-ed25519")
    host_key_pem = host_key.export_private_key()
    with tempfile.NamedTemporaryFile("w+", delete=True) as host_key_file:
        host_key_file.write(host_key_pem.decode())
        # Flush so the key is on disk before the server reads the path.
        host_key_file.flush()
        server = await run_ssh_server(
            backend=mock_backend,
            listen_address="localhost",
            port=unused_tcp_port,
            keys=[host_key_file.name],
            registration=ClientRegistrationSettings(
                enabled=True, allow_from=[IPv4Network("0.0.0.0/0")]
            ),
            enable_ping_command=True,
        )
        # Brief pause before handing the server to the test.
        await asyncio.sleep(0.1)
        yield server, unused_tcp_port
        server.close()
        await server.wait_closed()
@pytest.fixture(scope="function")
def ssh_command_runner(ssh_server: SshServerFixture, client_registry: ClientRegistry):
    """Return a coroutine function that runs a single command over SSH.

    ``await runner(name, command)`` does the following:

    * connects to the test server on 127.0.0.1 as user ``name``;
    * authenticates with that client's registered private key, with host key
      checking disabled;
    * runs ``command`` and returns the asyncssh result.

    The connection is closed before returning. The fixture itself is left
    unannotated because typing the returned coroutine function is tricky.
    """
    _server, server_port = ssh_server

    async def run_command_as(name: str, command: str):
        login_key = client_registry["clients"][name].private_key
        # The connection's async context manager closes it and waits for
        # shutdown on exit, including when run() raises.
        async with asyncssh.connect(
            "127.0.0.1",
            port=server_port,
            username=name,
            client_keys=[login_key],
            known_hosts=None,
        ) as connection:
            return await connection.run(command)

    return run_command_as
@pytest.fixture(scope="function")
def ssh_session(ssh_server: SshServerFixture, client_registry: ClientRegistry):
    """Return an async context manager that yields an interactive SSH process.

    ``async with session(name, command, client=None) as process:`` does the
    following:

    * logs in to the test server as user ``name``;
    * authenticates with the private key registered for ``client``, which
      defaults to ``name``, so a different client's key can be used for the
      login;
    * starts ``command`` and yields the resulting process.

    On exit the process context is left and the connection is closed.
    """
    _server, server_port = ssh_server

    @asynccontextmanager
    async def run_process_as(name: str, command: str, client: str | None = None):
        key_owner = client or name
        login_key = client_registry["clients"][key_owner].private_key
        async with asyncssh.connect(
            "127.0.0.1",
            port=server_port,
            username=name,
            client_keys=[login_key],
            known_hosts=None,
        ) as connection, connection.create_process(command) as process:
            yield process

    return run_process_as