From 4f970a3f713e6c226cc453a853be71d09f6b2993 Mon Sep 17 00:00:00 2001 From: Allan Eising Date: Sat, 10 May 2025 08:27:16 +0200 Subject: [PATCH] Complete sshd package with tests --- packages/sshecret-sshd/pyproject.toml | 6 + packages/sshecret-sshd/pytest.ini | 2 + .../sshecret-sshd/src/sshecret_sshd/cli.py | 45 +++-- .../src/sshecret_sshd/constants.py | 6 + .../src/sshecret_sshd/settings.py | 44 ++++- .../src/sshecret_sshd/ssh_server.py | 178 ++++++++++-------- packages/sshecret-sshd/tests/__init__.py | 1 + packages/sshecret-sshd/tests/conftest.py | 160 ++++++++++++++++ .../sshecret-sshd/tests/test_get_secret.py | 28 +++ packages/sshecret-sshd/tests/test_ping.py | 15 ++ packages/sshecret-sshd/tests/test_register.py | 34 ++++ packages/sshecret-sshd/tests/types.py | 56 ++++++ 12 files changed, 472 insertions(+), 103 deletions(-) create mode 100644 packages/sshecret-sshd/pytest.ini create mode 100644 packages/sshecret-sshd/tests/__init__.py create mode 100644 packages/sshecret-sshd/tests/conftest.py create mode 100644 packages/sshecret-sshd/tests/test_get_secret.py create mode 100644 packages/sshecret-sshd/tests/test_ping.py create mode 100644 packages/sshecret-sshd/tests/test_register.py create mode 100644 packages/sshecret-sshd/tests/types.py diff --git a/packages/sshecret-sshd/pyproject.toml b/packages/sshecret-sshd/pyproject.toml index fbfa7d3..0ad3989 100644 --- a/packages/sshecret-sshd/pyproject.toml +++ b/packages/sshecret-sshd/pyproject.toml @@ -9,10 +9,16 @@ authors = [ requires-python = ">=3.13" dependencies = [ "asyncssh>=2.20.0", + "click>=8.1.8", "httpx>=0.28.1", + "pydantic>=2.10.6", "python-dotenv>=1.0.1", + "sshecret", ] +[tool.uv.sources] +sshecret = { workspace = true } + [project.scripts] sshecret-sshd = "sshecret_sshd.cli:cli" diff --git a/packages/sshecret-sshd/pytest.ini b/packages/sshecret-sshd/pytest.ini new file mode 100644 index 0000000..2f4c80e --- /dev/null +++ b/packages/sshecret-sshd/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +asyncio_mode = auto diff --git a/packages/sshecret-sshd/src/sshecret_sshd/cli.py b/packages/sshecret-sshd/src/sshecret_sshd/cli.py index 0948827..7f967c8 100644 --- a/packages/sshecret-sshd/src/sshecret_sshd/cli.py +++ b/packages/sshecret-sshd/src/sshecret_sshd/cli.py @@ -2,8 +2,10 @@ import logging import asyncio import sys -from pydantic_settings import CliApp +from typing import cast +import click +from pydantic import ValidationError from .settings import ServerSettings from .ssh_server import start_server @@ -17,22 +19,38 @@ LOG.addHandler(handler) LOG.setLevel(logging.INFO) -def cli(args: list[str] | None = None) -> None: - """Run CLI app.""" - try: - settings = ServerSettings() - except Exception: - print("One or more settings could not be resolved.") - CliApp.run(ServerSettings, ["--help"]) - sys.exit(1) - - if settings.debug: +@click.group() +@click.option("--debug", is_flag=True) +@click.pass_context +def cli(ctx: click.Context, debug: bool) -> None: + """Sshecret Admin.""" + if debug: LOG.setLevel(logging.DEBUG) + try: + settings = ServerSettings() # pyright: ignore[reportCallIssue] + except ValidationError as e: + raise click.ClickException( + "Error: One or more required environment options are missing." + ) from e + ctx.obj = settings + + +@cli.command("run") +@click.option("--host") +@click.option("--port", type=click.INT) +@click.pass_context +def cli_run(ctx: click.Context, host: str | None, port: int | None) -> None: + """Run the server.""" + settings = cast(ServerSettings, ctx.obj) + if host: + settings.listen_address = host + if port: + settings.port = port loop = asyncio.new_event_loop() loop.run_until_complete(start_server(settings)) - - print(f"Starting SSH server: {settings.listen_address}:{settings.port}") + title = click.style("Sshecret SSH Daemon", fg="red", bold=True) + click.echo(f"Starting {title}: {settings.listen_address}:{settings.port}") try: loop.run_forever() except KeyboardInterrupt: @@ -40,7 +58,6 @@ def cli(args: list[str] | None = None) -> None: sys.exit() - if __name__ == "__main__": """Run CLI app.""" cli() diff --git a/packages/sshecret-sshd/src/sshecret_sshd/constants.py b/packages/sshecret-sshd/src/sshecret_sshd/constants.py index 1025f1a..dfaea77 100644 --- a/packages/sshecret-sshd/src/sshecret_sshd/constants.py +++ b/packages/sshecret-sshd/src/sshecret_sshd/constants.py @@ -4,5 +4,11 @@ ERROR_NO_COMMAND_RECEIVED = "Error: No command was received from the client." ERROR_SOURCE_IP_NOT_ALLOWED = ( "Error: Client not authorized to connect from the given host." ) +ERROR_NO_PUBLIC_KEY = "Error: No valid public key received." +ERROR_INVALID_KEY_TYPE = "Error: Invalid key type: Only RSA keys are supported." ERROR_UNKNOWN_COMMAND = "Error: The given command was not understood." SERVER_KEY_TYPE = "ed25519" +ERROR_BACKEND_ERROR = "Error: Unexpected response or error from backend" +ERROR_INFO_BACKEND_GONE = "Unexpected error: Backend connection lost." +ERROR_INFO_USERNAME_GONE = "Unexpected error: Username lost in transit." +ERROR_INFO_REMOTE_IP_GONE = "Unexpected error: Client connection details lost in transit." diff --git a/packages/sshecret-sshd/src/sshecret_sshd/settings.py b/packages/sshecret-sshd/src/sshecret_sshd/settings.py index d7362f0..42843b0 100644 --- a/packages/sshecret-sshd/src/sshecret_sshd/settings.py +++ b/packages/sshecret-sshd/src/sshecret_sshd/settings.py @@ -1,18 +1,48 @@ """SSH Server settings.""" -from pydantic import AnyHttpUrl, Field, AliasChoices -from pydantic_settings import BaseSettings, SettingsConfigDict +import ipaddress +from typing import Annotated, Any +from pydantic import AnyHttpUrl, BaseModel, Field, IPvAnyNetwork, field_validator +from pydantic_settings import BaseSettings, ForceDecode, SettingsConfigDict DEFAULT_LISTEN_PORT = 2222 -class ServerSettings(BaseSettings, cli_parse_args=True, cli_exit_on_error=True): + +class ClientRegistrationSettings(BaseModel): + """Client registration settings.""" + + enabled: bool = False + allow_from: Annotated[list[IPvAnyNetwork], ForceDecode] = Field(default_factory=list) + + @field_validator('allow_from', mode="before") + @classmethod + def ensure_allow_from_list(cls, value: Any) -> list[IPvAnyNetwork]: + """Convert allow_from to a list.""" + allow_from: list[IPvAnyNetwork] = [] + if isinstance(value, list): + entries = value + elif isinstance(value, str): + entries = value.split(",") + else: + raise ValueError("Error: Unknown format for allowed_from.") + + for entry in entries: + if isinstance(entry, str): + allow_from.append(ipaddress.ip_network(entry)) + elif isinstance(entry, (ipaddress.IPv4Network, ipaddress.IPv6Network)): + allow_from.append(entry) + return allow_from + +class ServerSettings(BaseSettings): """Server Settings.""" - model_config = SettingsConfigDict(env_file=".sshd.env", env_prefix="sshecret_sshd_") + model_config = SettingsConfigDict(env_file=".sshd.env", env_prefix="sshecret_sshd_", env_nested_delimiter='_') - backend_url: AnyHttpUrl = Field(validation_alias=AliasChoices("backend-url", "sshecret_backend_url")) - backend_token: str = Field(validation_alias=AliasChoices("backend-token", "sshecret_sshd_backend_token")) - listen_address: str = Field(default="", alias="listen") + backend_url: AnyHttpUrl = Field(alias="sshecret_backend_url") + backend_token: str + listen_address: str = Field(default="127.0.0.1") port: int = DEFAULT_LISTEN_PORT + registration: ClientRegistrationSettings = Field(default_factory=ClientRegistrationSettings) debug: bool = False + enable_ping_command: bool = False diff --git a/packages/sshecret-sshd/src/sshecret_sshd/ssh_server.py b/packages/sshecret-sshd/src/sshecret_sshd/ssh_server.py index 0028d61..baccca0 100644 --- a/packages/sshecret-sshd/src/sshecret_sshd/ssh_server.py +++ b/packages/sshecret-sshd/src/sshecret_sshd/ssh_server.py @@ -1,21 +1,21 @@ """SSH Server implementation.""" import logging -import uuid import asyncssh import ipaddress from collections.abc import Awaitable -from datetime import datetime, timezone from functools import partial from pathlib import Path -from typing import Any, Callable, cast, override +from typing import Callable, cast, override + +from pydantic import IPvAnyNetwork from . import constants -from sshecret.backend import AuditLog, SshecretBackend, Client -from .settings import ServerSettings +from sshecret.backend import SshecretBackend, Client, Operation, SubSystem +from .settings import ServerSettings, ClientRegistrationSettings LOG = logging.getLogger(__name__) @@ -35,71 +35,39 @@ class CommandError(Exception): def audit_process( backend: SshecretBackend, process: asyncssh.SSHServerProcess[str], + operation: Operation, message: str, secret: str | None = None, + **data: str, ) -> None: """Add an audit event from process.""" command = get_process_command(process) client = get_info_client(process) username = get_info_username(process) - remote_ip = get_info_remote_ip(process) - operation = "SSH_EVENT" - obj_name: str | None = None - obj_id: str | None = None + remote_ip = get_info_remote_ip(process) or "UNKNOWN" + if username: + data["username"] = username if command and not secret: cmd, cmd_args = command - obj_id = " ".join(cmd_args) - elif secret: - obj_name = "ClientSecret" - obj_id = secret - - entry = AuditLog( - subsystem="ssh", - operation=operation, - object=obj_name, - object_id=obj_id, - message=message, - origin=remote_ip, - ) - if client: - entry.client_id = str(client.id) - entry.client_name = client.name - elif username: - entry.client_name = username - - backend.add_audit_log_sync(entry) + if cmd: + data["command"] = cmd + data["args"] = " ".join(cmd_args) + backend.audit(SubSystem.SSHD).write(operation, message, remote_ip, client, secret=None, secret_name=secret, **data) def audit_event( backend: SshecretBackend, message: str, - operation: str = "SSH_EVENT", + operation: Operation, client: Client | None = None, origin: str | None = None, secret: str | None = None, ) -> None: """Add an audit event.""" - entry = AuditLog( - client_id=None, - client_name=None, - object=None, - object_id=None, - subsystem="ssh", - operation=operation, - message=message, - origin=origin, - ) - if client: - entry.client_id = str(client.id) - entry.client_name = client.name - - if secret: - entry.object = "ClientSecret" - entry.object_id = secret - - backend.add_audit_log_sync(entry) - + if not origin: + origin = "UNKNOWN" + backend.audit(SubSystem.SSHD).write(operation, message, origin, client, secret=None, secret_name=secret) def verify_key_input(public_key: str) -> str | None: """Verify key input.""" @@ -118,6 +86,7 @@ def get_process_command( if not process.command: return (None, []) argv = process.command.split(" ") + LOG.debug("Args: %r", argv) return (argv[0], argv[1:]) @@ -149,7 +118,12 @@ def get_info_remote_ip(process: asyncssh.SSHServerProcess[str]) -> str | None: return remote_ip - # remote_ip = str(self._conn.get_extra_info("peername")[0]) +def get_info_allowed_registration(process: asyncssh.SSHServerProcess[str]) -> list[IPvAnyNetwork] | None: + """Get allowed networks to allow registration from.""" + + allowed_registration = cast(list[IPvAnyNetwork] | None, process.get_extra_info("allow_registration_from", None)) + return allowed_registration + def get_optional_commands(process: asyncssh.SSHServerProcess[str]) -> dict[str, bool]: @@ -197,12 +171,12 @@ async def register_client( """Register a new client.""" public_key = await get_stdin_public_key(process) if not public_key: - raise CommandError("Aborted. No valid public key received.") + raise CommandError(constants.ERROR_NO_PUBLIC_KEY) key = asyncssh.import_public_key(public_key) if key.algorithm.decode() != "ssh-rsa": - raise CommandError("Error: Only RSA keys are supported!") - audit_process(backend, process, "Registering new client") + raise CommandError(constants.ERROR_INVALID_KEY_TYPE) + audit_process(backend, process, Operation.CREATE, "Registering new client") LOG.debug("Registering client %s with public key %s", username, public_key) await backend.create_client(username, public_key) @@ -214,14 +188,16 @@ async def get_secret( origin: str, ) -> str: """Handle get secret requests from client.""" - LOG.debug("Recieved command: %r", secret_name) - if not secret_name or secret_name not in client.secrets: + LOG.debug("Recieved command: get_secret %r", secret_name) + if not secret_name: raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET) + if secret_name not in client.secrets: + raise CommandError(constants.ERROR_NO_SECRET_FOUND) audit_event( backend, "Client requested secret", - operation="get_secret", + operation=Operation.READ, client=client, origin=origin, secret=secret_name, @@ -229,10 +205,13 @@ async def get_secret( # Look up secret try: - return await backend.get_client_secret(client.name, secret_name) + secret = await backend.get_client_secret(client.name, secret_name) + if not secret: + raise CommandError(constants.ERROR_NO_SECRET_FOUND) + return secret except Exception as exc: LOG.debug(exc, exc_info=True) - raise CommandError("Unexpected error from backend") from exc + raise CommandError(constants.ERROR_BACKEND_ERROR) from exc async def dispatch_no_cmd(process: asyncssh.SSHServerProcess[str]) -> None: @@ -249,10 +228,32 @@ async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None """Dispatch the register command.""" backend = get_info_backend(process) if not backend: - raise CommandError("Unexpected error: Backend disappeared.") + raise CommandError(constants.ERROR_INFO_BACKEND_GONE) username = get_info_username(process) if not username: - raise CommandError("Unexpected error: Username was lost.") + raise CommandError(constants.ERROR_INFO_USERNAME_GONE) + + allowed_networks = get_info_allowed_registration(process) + if not allowed_networks: + process.stdout.write("Unauthorized.\n") + audit_process(backend, process, Operation.DENY, "Received registration command, but no subnets are allowed.") + return + + remote_ip = get_info_remote_ip(process) + + if not remote_ip: + raise CommandError(constants.ERROR_INFO_REMOTE_IP_GONE) + + client_address = ipaddress.ip_address(remote_ip) + + for network in allowed_networks: + if client_address in network: + break + else: + audit_process(backend, process, Operation.DENY, "Received registration command from unauthorized subnet.") + process.stdout.write("Unauthorized.\n") + return + await register_client(process, backend, username) process.stdout.write("Client registered\n.") @@ -262,7 +263,7 @@ async def dispatch_cmd_get_secret(process: asyncssh.SSHServerProcess[str]) -> No """Dispatch the get_secret command.""" backend = get_info_backend(process) if not backend: - raise CommandError("Unexpected error: Backend disappeared.") + raise CommandError(constants.ERROR_INFO_BACKEND_GONE) client = get_info_client(process) if not client: @@ -311,6 +312,7 @@ async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None: process.stderr.write(str(e)) exit_code = 1 + LOG.debug("Command processing finished.") process.exit(exit_code) @@ -319,16 +321,18 @@ class AsshyncServer(asyncssh.SSHServer): def __init__( self, - backend_url: str, - backend_token: str, - with_register: bool = True, - with_ping: bool = True, + backend: SshecretBackend, + registration: ClientRegistrationSettings, + enable_ping_command: bool = False, ) -> None: """Initialize server.""" - self.backend: SshecretBackend = SshecretBackend(backend_url, backend_token) + self.backend: SshecretBackend = backend self._conn: asyncssh.SSHServerConnection | None = None - self.registration_enabled: bool = with_register - self.ping_enabled: bool = with_ping + self.registration_enabled: bool = registration.enabled + self.allow_registration_from: list[IPvAnyNetwork] | None = None + if registration.enabled: + self.allow_registration_from = registration.allow_from + self.ping_enabled: bool = enable_ping_command self.client_ip: str | None = None @override @@ -359,9 +363,9 @@ class AsshyncServer(asyncssh.SSHServer): if not self._conn: return True if client := await self.backend.get_client(username): - LOG.debug("Client lookup sucessful.") + LOG.debug("Client lookup sucessful: %r", client) if key := self.resolve_client_key(client): - LOG.debug("Loaded public key for client %s", client.name) + LOG.debug("Loaded public key for client %s\n%s", client.name, key) self._conn.set_extra_info(client=client) self._conn.set_authorized_keys(key) else: @@ -369,15 +373,18 @@ class AsshyncServer(asyncssh.SSHServer): audit_event( self.backend, "Client denied due to policy", - "DENY", + Operation.DENY, client, origin=self.client_ip, ) LOG.warning("Client connection denied due to policy.") - else: + elif self.registration_enabled: self._conn.set_extra_info(provided_username=username) + self._conn.set_extra_info(allow_registration_from=self.allow_registration_from) + LOG.warning("Registration enabled, and client is not recognized. Bypassing authentication.") return False + LOG.debug("Continuing to regular authentication") return True @override @@ -403,6 +410,7 @@ class AsshyncServer(asyncssh.SSHServer): return None remote_ip = str(self._conn.get_extra_info("peername")[0]) LOG.debug("Validating client %s connection from %s", client.name, remote_ip) + LOG.debug("Loading client public key %r", client.public_key) if self.check_connection_allowed(client, remote_ip): return asyncssh.import_authorized_keys(client.public_key) return None @@ -428,7 +436,7 @@ def get_server_key(basedir: Path | None = None) -> str: if filename.exists(): return str(filename.absolute()) # FIXME: There's a weird typing warning here that I need to investigate. - private_key = asyncssh.generate_private_key("ssh-ed25519", "sshecret-sshd") + private_key = asyncssh.generate_private_key("ssh-ed25519", comment="sshecret-sshd") with open(filename, "wb") as f: f.write(private_key.export_private_key()) @@ -436,15 +444,19 @@ def get_server_key(basedir: Path | None = None) -> str: async def run_ssh_server( - backend_url: str, - backend_token: str, + backend: SshecretBackend, listen_address: str, port: int, keys: list[str], + registration: ClientRegistrationSettings, + enable_ping_command: bool = False, ) -> asyncssh.SSHAcceptor: """Run the server.""" server = partial( - AsshyncServer, backend_url=str(backend_url), backend_token=backend_token + AsshyncServer, + backend=backend, + registration=registration, + enable_ping_command=enable_ping_command, ) server = await asyncssh.create_server( server, @@ -463,10 +475,12 @@ async def start_server(settings: ServerSettings | None = None) -> None: if not settings: settings = ServerSettings() # pyright: ignore[reportCallIssue] + backend = SshecretBackend(str(settings.backend_url), settings.backend_token) await run_ssh_server( - str(settings.backend_url), - settings.backend_token, - settings.listen_address, - settings.port, - [server_key], + backend=backend, + listen_address=settings.listen_address, + port=settings.port, + keys=[server_key], + registration=settings.registration, + enable_ping_command=settings.enable_ping_command, ) diff --git a/packages/sshecret-sshd/tests/__init__.py b/packages/sshecret-sshd/tests/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/packages/sshecret-sshd/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/packages/sshecret-sshd/tests/conftest.py b/packages/sshecret-sshd/tests/conftest.py new file mode 100644 index 0000000..b9d809d --- /dev/null +++ b/packages/sshecret-sshd/tests/conftest.py @@ -0,0 +1,160 @@ +import asyncio +import pytest +import uuid +import asyncssh +import tempfile +from contextlib import asynccontextmanager +from unittest.mock import AsyncMock, MagicMock +from ipaddress import IPv4Network, IPv6Network + +from sshecret_sshd.ssh_server import run_ssh_server +from sshecret_sshd.settings import ClientRegistrationSettings + +from .types import Client, ClientKey, ClientRegistry, SshServerFixtureFun, SshServerFixture + +@pytest.fixture(scope="function") +def client_registry() -> ClientRegistry: + clients = {} + secrets = {} + + async def add_client(name: str, secret_names: list[str]|None=None, policies: list[str]|None=None) -> str: + private_key = asyncssh.generate_private_key('ssh-rsa') + public_key = private_key.export_public_key() + clients[name] = ClientKey(name, private_key, public_key.decode().rstrip()) + secrets.update({(name, s): f"mocked-secret-{s}" for s in (secret_names or [])}) + return clients[name] + + return {"clients": clients, "secrets": secrets, "add_client": add_client} + +@pytest.fixture(scope="function") +async def mock_backend(client_registry: ClientRegistry) -> MagicMock: + backend = MagicMock() + clients_data = client_registry['clients'] + secrets_data = client_registry['secrets'] + + async def get_client(name: str) -> Client | None: + client_key = clients_data.get(name) + if client_key: + response_model = Client( + id=uuid.uuid4(), + name=name, + description=f"Mock client {name}", + public_key=client_key.public_key, + secrets=[s for (c, s) in secrets_data if c == name], + policies=[IPv4Network('0.0.0.0/0'), IPv6Network('::/0')], + ) + return response_model + return None + + async def get_client_secret(name: str, secret_name: str) -> str | None: + secret = secrets_data.get((name, secret_name), None) + return secret + + async def create_client(name: str, public_key: str) -> None: + """Create client. + + This only works if you register a client called template first. + Otherwise we can't test this... + """ + if "template" not in clients_data: + raise RuntimeError("Error, must have a client called template for this to work.") + clients_data[name] = clients_data["template"] + for secret_key, secret in secrets_data.items(): + s_client, secret_name = secret_key + if s_client != "template": + continue + secrets_data[(name, secret_name)] = secret + + backend.get_client = AsyncMock(side_effect=get_client) + backend.get_client_secret = AsyncMock(side_effect=get_client_secret) + backend.create_client = AsyncMock(side_effect=create_client) + + # Make sure backend.audit(...) returns the audit mock + audit = MagicMock() + audit.write = MagicMock() + backend.audit = MagicMock(return_value=audit) + + return backend + +@pytest.fixture(scope="function") +async def ssh_server(mock_backend: MagicMock, unused_tcp_port: int) -> SshServerFixtureFun: + port = unused_tcp_port + + private_key = asyncssh.generate_private_key('ssh-ed25519') + key_str = private_key.export_private_key() + with tempfile.NamedTemporaryFile('w+', delete=True) as key_file: + + key_file.write(key_str.decode()) + key_file.flush() + + registration_settings = ClientRegistrationSettings(enabled=True, allow_from=[IPv4Network("0.0.0.0/0")]) + server = await run_ssh_server( + backend=mock_backend, + listen_address="localhost", + port=port, + keys=[key_file.name], + registration=registration_settings, + enable_ping_command=True, + ) + + await asyncio.sleep(0.1) + + yield server, port + + server.close() + await server.wait_closed() + +@pytest.fixture(scope="function") +def ssh_command_runner(ssh_server: SshServerFixture, client_registry: ClientRegistry): + """Run a single command. + + Tricky typing! + """ + _, port = ssh_server + + async def run_command_as(name: str, command: str): + client_key = client_registry['clients'][name] + conn = await asyncssh.connect( + "127.0.0.1", + port=port, + username=name, + client_keys=[client_key.private_key], + known_hosts=None, + ) + try: + result = await conn.run(command) + return result + finally: + conn.close() + await conn.wait_closed() + + return run_command_as + + +@pytest.fixture(scope="function") +def ssh_session(ssh_server: SshServerFixture, client_registry: ClientRegistry): + """Yield an interactive session.""" + + _, port = ssh_server + + @asynccontextmanager + async def run_process_as(name: str, command: str, client: str | None = None): + if not client: + client = name + client_key = client_registry["clients"][client] + conn = await asyncssh.connect( + "127.0.0.1", + port=port, + username=name, + client_keys=[client_key.private_key], + known_hosts=None, + ) + try: + async with conn.create_process(command) as process: + yield process + + finally: + conn.close() + await conn.wait_closed() + + return run_process_as diff --git a/packages/sshecret-sshd/tests/test_get_secret.py b/packages/sshecret-sshd/tests/test_get_secret.py new file mode 100644 index 0000000..dca1645 --- /dev/null +++ b/packages/sshecret-sshd/tests/test_get_secret.py @@ -0,0 +1,28 @@ +"""Test get secret.""" + +import pytest + +from .types import ClientRegistry, CommandRunner + +@pytest.mark.asyncio +async def test_get_secret(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None: + """Test that we can get a secret.""" + + await client_registry['add_client']("test-client", ["mysecret"]) + + result = await ssh_command_runner("test-client", "get_secret mysecret") + + assert result.stdout is not None + assert isinstance(result.stdout, str) + assert result.stdout.rstrip() == "mocked-secret-mysecret" + + +@pytest.mark.asyncio +async def test_invalid_secret_name(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None: + """Test getting an invalid secret name.""" + + await client_registry['add_client']("test-client") + + result = await ssh_command_runner("test-client", "get_secret mysecret") + assert result.exit_status == 1 + assert result.stderr == "Error: No secret available with the given name." diff --git a/packages/sshecret-sshd/tests/test_ping.py b/packages/sshecret-sshd/tests/test_ping.py new file mode 100644 index 0000000..a83ef50 --- /dev/null +++ b/packages/sshecret-sshd/tests/test_ping.py @@ -0,0 +1,15 @@ +import pytest + +from .types import ClientRegistry, CommandRunner + +@pytest.mark.asyncio +async def test_ping_command(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None: + # Register a test client with default policies and no secrets + await client_registry['add_client']('test-pinger') + + result = await ssh_command_runner("test-pinger", "ping") + + assert result.exit_status == 0 + assert result.stdout is not None + assert isinstance(result.stdout, str) + assert result.stdout.rstrip() == "PONG" diff --git a/packages/sshecret-sshd/tests/test_register.py b/packages/sshecret-sshd/tests/test_register.py new file mode 100644 index 0000000..432f57c --- /dev/null +++ b/packages/sshecret-sshd/tests/test_register.py @@ -0,0 +1,34 @@ +"""Test registration.""" +import pytest + +from .types import ClientRegistry, CommandRunner, ProcessRunner + +@pytest.mark.asyncio +async def test_register_client(ssh_session: ProcessRunner, ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None: + """Test client registration.""" + await client_registry["add_client"]("template", ["testsecret"]) + + public_key = client_registry["clients"]["template"].public_key.rstrip() + "\n" + + async with ssh_session("newclient", "register", "template") as session: + maxlines = 10 + l = 0 + found = False + while l < maxlines: + line = await session.stdout.readline() + if "Enter public key" in line: + found = True + break + + assert found is True + session.stdin.write(public_key) + result = await session.stdout.readline() + assert "OK" in result + + # Test that we can connect + + result = await ssh_command_runner("newclient", "get_secret testsecret") + + assert result.stdout is not None + assert isinstance(result.stdout, str) + assert result.stdout.rstrip() == "mocked-secret-testsecret" diff --git a/packages/sshecret-sshd/tests/types.py b/packages/sshecret-sshd/tests/types.py new file mode 100644 index 0000000..9b6187b --- /dev/null +++ b/packages/sshecret-sshd/tests/types.py @@ -0,0 +1,56 @@ +"""Types for the various test properties.""" +import uuid +from datetime import datetime +from dataclasses import dataclass, field +from ipaddress import IPv4Network, IPv6Network +from collections.abc import AsyncGenerator, Awaitable, Callable, AsyncIterator +from typing import Any, Protocol, TypedDict, AsyncContextManager +import asyncssh + +SshServerFixture = tuple[str, int] +SshServerFixtureFun = AsyncGenerator[tuple[asyncssh.SSHAcceptor, int], None] + + + + +@dataclass +class Client: + """Mock client.""" + + id: uuid.UUID + name: str + description: str | None + public_key: str + secrets: list[str] + policies: list[IPv4Network | IPv6Network] + created_at: datetime = field(default_factory=datetime.now) + updated_at: datetime = field(default_factory=datetime.now) + + +@dataclass +class ClientKey: + name: str + private_key: asyncssh.SSHKey + public_key: str + + +class AddClientFun(Protocol): + """Add client function.""" + + def __call__(self, name: str, secret_names: list[str] | None = None, policies: list[str] | None = None) -> Awaitable[str]: ... + +class ProcessRunner(Protocol): + """Process runner typing.""" + + def __call__(self, name: str, command: str, client: str | None = None) -> AsyncContextManager[asyncssh.SSHClientProcess[Any]]: + ... + +class ClientRegistry(TypedDict): + """Client registry typing.""" + + clients: dict[str, ClientKey] + secrets: dict[tuple[str, str], str] + add_client: AddClientFun + + +CommandRunner = Callable[[str, str], Awaitable[asyncssh.SSHCompletedProcess]]