import asyncio
from typing import Any
import pytest
import uuid
import asyncssh
import tempfile
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock
from ipaddress import IPv4Network, IPv6Network

from sshecret_sshd.ssh_server import run_ssh_server
from sshecret_sshd.settings import ClientRegistrationSettings

from .types import (
    Client,
    ClientKey,
    ClientRegistry,
    SshServerFixtureFun,
    SshServerFixture,
)


@pytest.fixture(scope="function")
def client_registry() -> ClientRegistry:
    """Provide an in-memory registry of test clients and their secrets.

    Returns a dict with three entries:
      - ``"clients"``: maps client name -> ``ClientKey``.
      - ``"secrets"``: maps ``(client name, secret name)`` -> secret value.
      - ``"add_client"``: async helper that registers a client with a freshly
        generated RSA key pair and optional mocked secrets.
    """
    clients = {}
    secrets = {}

    async def add_client(
        name: str,
        secret_names: list[str] | None = None,
        policies: list[str] | None = None,
    ) -> ClientKey:
        """Register ``name`` with a new RSA key pair and mocked secrets.

        Every secret ``s`` in ``secret_names`` is stored with the value
        ``"mocked-secret-{s}"``. ``policies`` is accepted but currently unused.

        Returns:
            The ``ClientKey`` stored for ``name``.
        """
        private_key = asyncssh.generate_private_key("ssh-rsa")
        public_key = private_key.export_public_key()
        clients[name] = ClientKey(name, private_key, public_key.decode().rstrip())
        secrets.update({(name, s): f"mocked-secret-{s}" for s in (secret_names or [])})
        return clients[name]

    return {"clients": clients, "secrets": secrets, "add_client": add_client}


@pytest.fixture(scope="function")
async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
    """Build a mocked backend whose lookups read from ``client_registry``.

    ``get_client``, ``get_client_secret`` and ``create_client`` are
    ``AsyncMock`` objects that delegate to the local implementations below,
    so tests get working behavior and can still assert on the calls.
    ``backend.audit(...)`` returns a mock whose ``write_async`` does nothing.
    """
    backend = MagicMock()
    clients_data = client_registry["clients"]
    secrets_data = client_registry["secrets"]

    async def get_client(name: str) -> Client | None:
        """Return a ``Client`` model for ``name``, or ``None`` if unregistered.

        The model gets a fresh random id on every call, lists every secret
        registered under ``name``, and has policies covering all IPv4 and
        IPv6 addresses.
        """
        client_key = clients_data.get(name)
        if client_key:
            response_model = Client(
                id=uuid.uuid4(),
                name=name,
                description=f"Mock client {name}",
                public_key=client_key.public_key,
                secrets=[s for (c, s) in secrets_data if c == name],
                policies=[IPv4Network("0.0.0.0/0"), IPv6Network("::/0")],
            )
            return response_model
        return None

    async def get_client_secret(name: str, secret_name: str) -> str | None:
        """Return the secret ``secret_name`` of client ``name``, or ``None``."""
        secret = secrets_data.get((name, secret_name), None)
        return secret

    async def create_client(name: str, public_key: str) -> None:
        """Mock client registration by cloning the ``template`` client.

        The new client reuses the template's ``ClientKey`` object (the given
        ``public_key`` is ignored) and receives copies of all of the
        template's secrets. Tests must therefore register a client called
        ``template`` first; otherwise there is nothing to clone.

        Raises:
            RuntimeError: If no client called ``template`` is registered.
        """
        if "template" not in clients_data:
            raise RuntimeError(
                "Error, must have a client called template for this to work."
            )
        clients_data[name] = clients_data["template"]
        # Snapshot the template's secrets before writing: inserting new keys
        # into secrets_data while iterating over it raises
        # "RuntimeError: dictionary changed size during iteration".
        template_secrets = {
            secret_name: secret
            for (s_client, secret_name), secret in secrets_data.items()
            if s_client == "template"
        }
        for secret_name, secret in template_secrets.items():
            secrets_data[(name, secret_name)] = secret

    async def write_audit(*args, **kwargs):
        """Write audit mock."""
        return None

    backend.get_client = AsyncMock(side_effect=get_client)
    backend.get_client_secret = AsyncMock(side_effect=get_client_secret)
    backend.create_client = AsyncMock(side_effect=create_client)
    audit = MagicMock()
    audit.write_async = AsyncMock(side_effect=write_audit)
    backend.audit = MagicMock(return_value=audit)
    return backend


@pytest.fixture(scope="function")
async def ssh_server(
    mock_backend: MagicMock, unused_tcp_port: int
) -> SshServerFixtureFun:
    """Run the SSH server on localhost for the duration of one test.

    A throwaway ed25519 host key is written to a temporary file. The file
    stays open, and therefore on disk, until the server has been closed.
    Client registration is enabled for all IPv4 addresses, and the ping
    command is enabled.

    Yields:
        ``(server, port)``: the running server and the TCP port it listens on.
    """
    port = unused_tcp_port
    private_key = asyncssh.generate_private_key("ssh-ed25519")
    key_str = private_key.export_private_key()
    with tempfile.NamedTemporaryFile("w+", delete=True) as key_file:
        key_file.write(key_str.decode())
        # Flush so the key is on disk before the server reads it by path.
        key_file.flush()
        registration_settings = ClientRegistrationSettings(
            enabled=True, allow_from=[IPv4Network("0.0.0.0/0")]
        )
        server = await run_ssh_server(
            backend=mock_backend,
            listen_address="localhost",
            port=port,
            keys=[key_file.name],
            registration=registration_settings,
            enable_ping_command=True,
        )
        # NOTE(review): presumably gives the server time to start accepting
        # connections; confirm whether this is still needed.
        await asyncio.sleep(0.1)
        yield server, port
        server.close()
        await server.wait_closed()


@pytest.fixture(scope="function")
def ssh_command_runner(ssh_server: SshServerFixture, client_registry: ClientRegistry):
    """Return an async helper that runs one command over SSH as a client.

    ``await run_command_as(name, command)`` connects to the test server as
    ``name`` using that client's registered private key. It runs ``command``
    and returns the asyncssh result, always closing the connection afterwards.

    The return type is left unannotated because typing the nested coroutine
    helper is awkward.
    """
    _, port = ssh_server

    async def run_command_as(name: str, command: str):
        client_key = client_registry["clients"][name]
        conn = await asyncssh.connect(
            "127.0.0.1",
            port=port,
            username=name,
            client_keys=[client_key.private_key],
            known_hosts=None,
        )
        try:
            result = await conn.run(command)
            return result
        finally:
            conn.close()
            await conn.wait_closed()

    return run_command_as


@pytest.fixture(scope="function")
def ssh_session(ssh_server: SshServerFixture, client_registry: ClientRegistry):
    """Return an async context manager that opens an interactive SSH process.

    ``async with run_process_as(name, command, client=None) as process`` logs
    in as ``name``. It authenticates with the private key of ``client``
    (defaulting to ``name``), which lets tests log in with another client's
    key. It then yields the process created for ``command``. The connection
    is closed when the context exits.
    """
    _, port = ssh_server

    @asynccontextmanager
    async def run_process_as(name: str, command: str, client: str | None = None):
        if not client:
            client = name
        client_key = client_registry["clients"][client]
        conn = await asyncssh.connect(
            "127.0.0.1",
            port=port,
            username=name,
            client_keys=[client_key.private_key],
            known_hosts=None,
        )
        try:
            async with conn.create_process(command) as process:
                yield process
        finally:
            conn.close()
            await conn.wait_closed()

    return run_process_as