From dcf0b4274c68a934956155e70af8ab9a06743761 Mon Sep 17 00:00:00 2001 From: Allan Eising Date: Sun, 18 May 2025 17:56:53 +0200 Subject: [PATCH] Refactor command handling This now supports usage/help texts --- .../src/sshecret_sshd/commands/__init__.py | 4 + .../src/sshecret_sshd/commands/base.py | 185 +++++++++--- .../src/sshecret_sshd/commands/dispatcher.py | 121 ++++++-- .../src/sshecret_sshd/commands/get_secret.py | 9 +- .../src/sshecret_sshd/commands/help.py | 48 --- .../sshecret_sshd/commands/list_secrets.py | 7 +- .../src/sshecret_sshd/commands/register.py | 1 - .../src/sshecret_sshd/commands/utils.py | 12 - .../src/sshecret_sshd/ssh_server.py | 282 +----------------- tests/integration/test_sshd.py | 4 +- tests/packages/sshd/conftest.py | 3 + tests/packages/sshd/test_errors.py | 12 +- tests/packages/sshd/test_get_secret.py | 27 +- tests/packages/sshd/test_ping.py | 21 ++ tests/packages/sshd/test_register.py | 32 +- 15 files changed, 337 insertions(+), 431 deletions(-) delete mode 100644 packages/sshecret-sshd/src/sshecret_sshd/commands/help.py delete mode 100644 packages/sshecret-sshd/src/sshecret_sshd/commands/utils.py diff --git a/packages/sshecret-sshd/src/sshecret_sshd/commands/__init__.py b/packages/sshecret-sshd/src/sshecret_sshd/commands/__init__.py index 8b13789..16514b1 100644 --- a/packages/sshecret-sshd/src/sshecret_sshd/commands/__init__.py +++ b/packages/sshecret-sshd/src/sshecret_sshd/commands/__init__.py @@ -1 +1,5 @@ +"""Commands module.""" +from .dispatcher import dispatch_command + +__all__ = ["dispatch_command"] diff --git a/packages/sshecret-sshd/src/sshecret_sshd/commands/base.py b/packages/sshecret-sshd/src/sshecret_sshd/commands/base.py index 6e64a0f..718a418 100644 --- a/packages/sshecret-sshd/src/sshecret_sshd/commands/base.py +++ b/packages/sshecret-sshd/src/sshecret_sshd/commands/base.py @@ -2,11 +2,14 @@ import abc import json +import textwrap from collections import defaultdict from dataclasses import dataclass, field import ipaddress import logging -from typing import Any, cast +from typing import Any, cast, override + +from rich.console import Console import asyncssh from pydantic import IPvAnyNetwork, IPvAnyAddress @@ -15,8 +18,6 @@ from sshecret.backend.api import SshecretBackend from sshecret.backend.models import Client, Operation, SubSystem from sshecret_sshd import exceptions -from .utils import get_console - PeernameV4 = tuple[str, int] PeernameV6 = tuple[str, int, int, int] Peername = PeernameV4 | PeernameV6 @@ -32,11 +33,45 @@ class CmdArgs: arguments: list[str] = field(default_factory=list) +@dataclass +class CommandFlag: + """Command flag.""" + + name: str + description: str + supports_short: bool = False + enabled: bool = False + + @override + def __str__(self) -> str: + """Format an output for help texts.""" + if self.supports_short: + return f"[-{self.name[0]} --{self.name}]" + return f"--{self.name}" + + +def get_console(process: asyncssh.SSHServerProcess[str]) -> Console: + """Initiate console from process.""" + width, _height, pixwidth, pixheight = process.term_size + LOG.debug("Terminal is %sx%s", pixwidth, pixheight) + + if width > 0: + console = Console( + force_terminal=True, + width=pixwidth, + height=pixheight, + color_system="standard", + ) + return console + return Console(markup=False, color_system=None) + + class CommandDispatcher(abc.ABC): """Command dispatcher.""" name: str flags: dict[str, str] | None = None + mandatory_argument: str | None = None def __init__( self, @@ -47,31 +82,57 @@ class CommandDispatcher(abc.ABC): def print( self, - data: str, + *data: str, stderr: bool = False, - formatted: bool = False, newline: bool = True, ) -> None: """Write to stdout.""" if stderr: - self.process.stderr.write(data + "\n") + for line in data: + self.process.stderr.write(line + "\n") return - if formatted: - data = self.get_rich(data) - if newline: - data += "\n" - self.process.stdout.write(data) + for line in data: + self.process.stdout.write(line) + if newline: + self.process.stdout.write("\n") - def get_rich(self, data: str) -> str: - """Print with rich formatting.""" - console = get_console(self.process) + def rich_print_line( + self, data: str, tags: list[str] | None = None, rule: bool = False + ) -> None: + """Write formatted text to the process. + + IF the client terminal does not support this, no formatting will be added. + Otherwise, the tags will be added to the string. + """ + if not tags: + tags = [] + if not self.formatting_supported: + return self.print(data, newline=True) + for tag in tags: + data = f"[{tag}]{data}[/{tag}]" + + console = self.get_console() with console.capture() as capture: - console.print(data) - return capture.get() + if rule: + console.rule(data) + else: + console.print(data) + + self.process.stdout.write(capture.get()) + + def get_console(self) -> Console: + """Initiate console from process.""" + return get_console(self.process) + + @property + def formatting_supported(self) -> bool: + """Check if the terminal supports formatting.""" + term_width = self.process.term_size[0] + return term_width > 0 def print_json(self, obj: dict[str, Any]) -> None: """Print a json object.""" - console = get_console(self.process) + console = self.get_console() data = json.dumps(obj) with console.capture() as capture: console.print_json(data) @@ -97,6 +158,7 @@ class CommandDispatcher(abc.ABC): username, client, secret, + data, ) await self.backend.audit(SubSystem.SSHD).write_async( @@ -148,6 +210,7 @@ class CommandDispatcher(abc.ABC): def arguments(self) -> list[str]: """Get non-flag arguments.""" parsed = self.parse_command() + LOG.debug("Parsed command: %r", parsed) if not self.flags: return parsed.arguments return [ @@ -211,37 +274,69 @@ class CommandDispatcher(abc.ABC): return allowed_registration @classmethod - async def print_help(cls, process: asyncssh.SSHServerProcess[str]) -> None: + def _format_command_name(cls) -> str: + """Format command name with arguments.""" + name = cls.name + if cls.mandatory_argument: + name = f"{name} {cls.mandatory_argument}" + + if not cls.flags: + return name + flags = cls.command_flags() + args: list[str] = [str(flag) for flag in flags.values()] + flagstr = " ".join(args) + return f"{name} {flagstr}" + + @classmethod + def usage(cls, indent: int = 0) -> str: + """Print command usage.""" + indent_prefix = " " * indent + inner_indent = 2 + indent + inner_prefix = " " * inner_indent + usage_str: list[str] = [] + default_usage = "No help available." + if not cls.__doc__: + usage_str.append(default_usage) + else: + usage_str = [ + textwrap.indent(line, prefix=inner_prefix) + for line in cls.__doc__.splitlines() + ] + flags = cls.command_flags() + if flags: + usage_str.append("") + usage_str.append(textwrap.indent("Arguments:", prefix=(" " * 4))) + for info in flags.values(): + usage_str.append( + textwrap.indent(f"{info!s}: {info.description}", prefix=(" " * 6)) + ) + usage = "\n".join(usage_str) + if indent: + usage = textwrap.indent(usage, prefix=indent_prefix) + + return usage + + async def print_help(self) -> None: """Print help.""" - descr = cls.__doc__ + usage = type(self).usage() + command_name = type(self)._format_command_name() + self.rich_print_line(command_name) + self.print(usage) - help_text: list[str] = [ - f"[bold]{cls.name}[/bold]", - descr or "", - ] - flags: dict[str, str] = {} + @classmethod + def command_flags(cls) -> dict[str, CommandFlag]: + """Parse the command flags.""" shortflags: defaultdict[str, list[str]] = defaultdict(list) - if cls.flags: - flags = cls.flags - - arg_text: list[str] = [] - for flag in flags.keys(): + if not cls.flags: + return {} + for flag in cls.flags.keys(): shortflags[flag[0]].append(flag) - for flag, flag_descr in flags.items(): - flagstr = f"--{flag}" - if len(shortflags[flag[0]]) == 1: - flagstr = f"[ -{flag[0]} --{flag} ]" - arg_text.extend([f"{flagstr}:", f"{flag_descr}", ""]) + command_flags: dict[str, CommandFlag] = {} - console = get_console(process) - with console.capture() as capture: - for line in help_text: - console.print(line) - if arg_text: - console.rule("Argument flags") - - for line in arg_text: - console.print(line) - - process.stdout.write(capture.get()) + for flag, description in cls.flags.items(): + supports_shorts = True + if len(shortflags[flag[0]]) > 1: + supports_shorts = False + command_flags[flag] = CommandFlag(flag, description, supports_shorts) + return command_flags diff --git a/packages/sshecret-sshd/src/sshecret_sshd/commands/dispatcher.py b/packages/sshecret-sshd/src/sshecret_sshd/commands/dispatcher.py index 96755fb..7634fc3 100644 --- a/packages/sshecret-sshd/src/sshecret_sshd/commands/dispatcher.py +++ b/packages/sshecret-sshd/src/sshecret_sshd/commands/dispatcher.py @@ -5,7 +5,7 @@ Register arguments here. import logging import textwrap -from typing import final, override +from typing import cast, final, override import asyncssh @@ -47,44 +47,105 @@ class HelpCommand(CommandDispatcher): name = "help" + @override + def __init__( + self, + process: asyncssh.SSHServerProcess[str], + disabled_commands: list[str] | None = None, + ) -> None: + """Init help command.""" + super().__init__(process) + self.disabled_commands: list[str] = disabled_commands or [] + @override async def exec(self) -> None: """Execute command.""" - usage_text = "[bold]AVAILABLE COMMANDS[/bold]" - usage_text += ( - "[i]Some commands may be disabled or restricted by the administrator[/i]\n" + output_lines: list[tuple[str, str] | str] = [] + output_lines.append(("Available commands:", "bold")) + output_lines.append( + ("Some commands may be disabled or restricted by the administrator", "i") ) + output_lines.append("") for command in COMMANDS: - usage_text += f" [bold]{command.name}:[/bold]" - usage_text += textwrap.indent(self.get_command_doc(command), " ") + if command.name in self.disabled_commands: + continue + command_usage = command.usage(indent=2) + output_lines.append((f" {command._format_command_name()}", "bold")) + output_lines.extend(command_usage.splitlines()) + output_lines.append("") - width, _height, _pixwidth, _pixheight = self.process.term_size - wrapped = textwrap.wrap(usage_text, width=width) - wrapped_lines = "\n".join(wrapped) - self.print(wrapped_lines, formatted=True) + for line in output_lines: + tags: list[str] = [] + if isinstance(line, tuple): + text, tag = line + tags.append(tag) + else: + text = line + self.rich_print_line(text, tags) - def get_command_doc(self, command: type[CommandDispatcher]) -> str: - """Format usage string for command.""" - default_usage = f"No help available." - if not command.__doc__: - return default_usage - # Skip the first two lines: - usage_str = command.__doc__[:2:] - return usage_str.join("\n") + +def get_disabled_commands(process: asyncssh.SSHServerProcess[str]) -> list[str]: + """Get optional command state.""" + with_registration = cast( + bool, process.get_extra_info("registration_enabled", False) + ) + with_ping = cast(bool, process.get_extra_info("ping_enabled", False)) + optional_commands = { + "register": with_registration, + "ping": with_ping, + } + disabled = [key for key, value in optional_commands.items() if not value] + return disabled + + +async def do_dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None: + """Dispatch command.""" + disabled_commands = get_disabled_commands(process) + command_args = process.command + + if not command_args: + raise exceptions.NoCommandReceivedError() + show_command_help = False + if "--help" in command_args: + show_command_help = True + command = command_args.split(" ")[0] + command_map: dict[str, type[CommandDispatcher]] = { + cmd_disp.name: cmd_disp + for cmd_disp in COMMANDS + if cmd_disp.name not in disabled_commands + } + + command_map["help"] = HelpCommand + + LOG.debug("disabled_commands: %r, command_map: %r", disabled_commands, command_map) + LOG.debug("Looking for command %s", command) + if command not in command_map: + raise exceptions.UnknownCommandError() + + dispatcher = command_map[command] + + LOG.debug("Received command: %s. Dispatching to %r", command, dispatcher) + if command != "help" and show_command_help: + return await dispatcher(process).print_help() + + if command == "help": + return await HelpCommand(process, disabled_commands).exec() + return await dispatcher(process).exec() async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None: """Dispatch command.""" - command = process.command - if not command: - command = "help" - command_map: dict[str, type[CommandDispatcher]] = { - cmd_disp.name: cmd_disp for cmd_disp in COMMANDS - } - command_map["help"] = HelpCommand + status_code = 0 + try: + await do_dispatch_command(process) + except exceptions.BaseSshecretSshError as e: + LOG.error("Command Error: %s", e, exc_info=True) + process.stderr.write(f"{e}\n") + status_code = 1 + except Exception as e: + LOG.error("Unexpected error: %e", e, exc_info=True) + process.stderr.write(f"{constants.ERROR_GENERIC_ERROR}: {e}") + status_code = 1 - dispatcher = command_map.get(command, HelpCommand) - - LOG.debug("Received command: %s", command) - - return await dispatcher(process).exec() + finally: + process.exit(status_code) diff --git a/packages/sshecret-sshd/src/sshecret_sshd/commands/get_secret.py b/packages/sshecret-sshd/src/sshecret_sshd/commands/get_secret.py index e6937e9..3adc096 100644 --- a/packages/sshecret-sshd/src/sshecret_sshd/commands/get_secret.py +++ b/packages/sshecret-sshd/src/sshecret_sshd/commands/get_secret.py @@ -12,9 +12,15 @@ LOG = logging.getLogger(__name__) @final class GetSecret(CommandDispatcher): - """Get Secret.""" + """Retrieve an encrypted secret. + + Returns the value of the secret provided as a mandatory argument. + The secret will be encrypted using the stored RSA public key, and returned + as a base64 encoded string. + """ name = "get_secret" + mandatory_argument = "SECRET" @override async def exec(self) -> None: @@ -22,6 +28,7 @@ class GetSecret(CommandDispatcher): if len(self.arguments) != 1: raise exceptions.UnknownClientOrSecretError() secret_name = self.arguments[0] + LOG.debug("get_secret called: Argument: %r", secret_name) if secret_name not in self.client.secrets: await self.audit( Operation.DENY, diff --git a/packages/sshecret-sshd/src/sshecret_sshd/commands/help.py b/packages/sshecret-sshd/src/sshecret_sshd/commands/help.py deleted file mode 100644 index 898b9c3..0000000 --- a/packages/sshecret-sshd/src/sshecret_sshd/commands/help.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Simple help command.""" - - -from typing import final, override - -from .base import CommandDispatcher - -HELP_TEXT = """ -[bold]Sshecret SSH Server[/bold] - - -[bold]SYNOPSIS[/bold] -An interface to request encrypted client secrets and perform -simple commands. - -Secrets will be returned encrypted with the client public key, -encoded as base64. - - -[bold]AVAILABLE COMMANDS[/bold] -[i]Some of these commands may be disabled by the administrator.[/i] - - [bold]ping[/bold]: Perform a ping towards the server. - - [bold]get_secret [i]secret_name[/i][/bold]: Get a named secret - - [bold]register[/bold]: Register a new client. Will prompt for public key. - [i]If enabled and permitted[/i] - - [bold]ls [ --json -j ][/bold]: List available secrets. - - [bold]help[/bold]: Prints this message -""" - - -@final -class HelpCommand(CommandDispatcher): - """Help. - - Returns usage instructions. - """ - - name = "help" - - @override - async def exec(self) -> None: - """Execute command.""" - self.print(HELP_TEXT, formatted=True) diff --git a/packages/sshecret-sshd/src/sshecret_sshd/commands/list_secrets.py b/packages/sshecret-sshd/src/sshecret_sshd/commands/list_secrets.py index 8b04a10..985dc51 100644 --- a/packages/sshecret-sshd/src/sshecret_sshd/commands/list_secrets.py +++ b/packages/sshecret-sshd/src/sshecret_sshd/commands/list_secrets.py @@ -32,10 +32,7 @@ class ListSecrets(CommandDispatcher): async def list_secrets(self) -> None: """List secrets.""" - self.print( - f"[bold]Available secrets for client {self.client.name}[/bold]", - formatted=True, - ) + self.rich_print_line(f"Available secrets for client {self.client.name}", ["bold"], rule=True) await self.audit(Operation.READ, "Listed available secret names") - for secret_name in self.client.name: + for secret_name in self.client.secrets: self.print(f" - {secret_name}") diff --git a/packages/sshecret-sshd/src/sshecret_sshd/commands/register.py b/packages/sshecret-sshd/src/sshecret_sshd/commands/register.py index efa0a0b..762c53f 100644 --- a/packages/sshecret-sshd/src/sshecret_sshd/commands/register.py +++ b/packages/sshecret-sshd/src/sshecret_sshd/commands/register.py @@ -24,7 +24,6 @@ class Register(CommandDispatcher): """Register a new client. After connection, you must input a ssh public key. - This must be an RSA key. No other types of keys are supported. """ diff --git a/packages/sshecret-sshd/src/sshecret_sshd/commands/utils.py b/packages/sshecret-sshd/src/sshecret_sshd/commands/utils.py deleted file mode 100644 index 096358d..0000000 --- a/packages/sshecret-sshd/src/sshecret_sshd/commands/utils.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Various utilities.""" - -import asyncssh -from rich.console import Console - -def get_console(process: asyncssh.SSHServerProcess[str]) -> Console: - """Initiate console from process.""" - _width, _height, pixwidth, pixheight = process.term_size - console = Console( - force_terminal=True, width=pixwidth, height=pixheight, color_system="standard" - ) - return console diff --git a/packages/sshecret-sshd/src/sshecret_sshd/ssh_server.py b/packages/sshecret-sshd/src/sshecret_sshd/ssh_server.py index 517fb73..4365f47 100644 --- a/packages/sshecret-sshd/src/sshecret_sshd/ssh_server.py +++ b/packages/sshecret-sshd/src/sshecret_sshd/ssh_server.py @@ -1,19 +1,18 @@ """SSH Server implementation.""" -from asyncio import _register_task import logging - import asyncssh import ipaddress from collections.abc import Awaitable from functools import partial from pathlib import Path -from typing import Callable, cast, override +from typing import Callable, override from pydantic import IPvAnyNetwork -from . import constants +from sshecret_sshd import constants +from sshecret_sshd.commands import dispatch_command from sshecret.backend import SshecretBackend, Client, Operation, SubSystem from .settings import ServerSettings, ClientRegistrationSettings @@ -29,37 +28,6 @@ PeernameV6 = tuple[str, int, int, int] Peername = PeernameV4 | PeernameV6 -class CommandError(Exception): - """Error class for errors during command processing.""" - - -async def audit_process( - backend: SshecretBackend, - process: asyncssh.SSHServerProcess[str], - operation: Operation, - message: str, - secret: str | None = None, - **data: str, -) -> None: - """Add an audit event from process.""" - command = get_process_command(process) - client = get_info_client(process) - username = get_info_username(process) - remote_ip = get_info_remote_ip(process) or "UNKNOWN" - if username: - data["username"] = username - - if command and not secret: - cmd, cmd_args = command - if cmd: - data["command"] = cmd - data["args"] = " ".join(cmd_args) - - await backend.audit(SubSystem.SSHD).write_async( - operation, message, remote_ip, client, secret=None, secret_name=secret, **data - ) - - async def audit_event( backend: SshecretBackend, message: str, @@ -87,250 +55,6 @@ def verify_key_input(public_key: str) -> str | None: return None -def get_process_command( - process: asyncssh.SSHServerProcess[str], -) -> tuple[str | None, list[str]]: - """Extract the process command.""" - if not process.command: - return (None, []) - argv = process.command.split(" ") - LOG.debug("Args: %r", argv) - return (argv[0], argv[1:]) - - -def get_info_backend(process: asyncssh.SSHServerProcess[str]) -> SshecretBackend | None: - """Get backend from process.""" - backend = cast("SshecretBackend | None", process.get_extra_info("backend", None)) - return backend - - -def get_info_client(process: asyncssh.SSHServerProcess[str]) -> Client | None: - """Get info from process.""" - client = cast("Client | None", process.get_extra_info("client", None)) - return client - - -def get_info_username(process: asyncssh.SSHServerProcess[str]) -> str | None: - """Get username from process.""" - username = cast("str | None", process.get_extra_info("provided_username", None)) - return username - - -def get_info_remote_ip(process: asyncssh.SSHServerProcess[str]) -> str | None: - """Get remote IP.""" - - peername = cast("Peername | None", process.get_extra_info("peername", None)) - remote_ip: str | None = None - if peername: - remote_ip = peername[0] - - return remote_ip - - -def get_info_allowed_registration( - process: asyncssh.SSHServerProcess[str], -) -> list[IPvAnyNetwork] | None: - """Get allowed networks to allow registration from.""" - - allowed_registration = cast( - list[IPvAnyNetwork] | None, - process.get_extra_info("allow_registration_from", None), - ) - return allowed_registration - - -def get_optional_commands(process: asyncssh.SSHServerProcess[str]) -> dict[str, bool]: - """Get optional command state.""" - with_registration = cast( - bool, process.get_extra_info("registration_enabled", False) - ) - with_ping = cast(bool, process.get_extra_info("ping_enabled", False)) - return { - "registration": with_registration, - "ping": with_ping, - } - - -async def get_stdin_public_key(process: asyncssh.SSHServerProcess[str]) -> str | None: - """Get public key from stdin.""" - process.stdout.write("Enter public key:\n") - public_key: str | None = None - try: - async for line in process.stdin: - public_key = verify_key_input(line.rstrip("\n")) - if public_key: - break - raise CommandError(constants.ERROR_INVALID_KEY_TYPE) - except asyncssh.BreakReceived: - pass - else: - process.stdout.write("OK\n") - return public_key - - -async def register_client( - process: asyncssh.SSHServerProcess[str], - backend: SshecretBackend, - username: str, -) -> None: - """Register a new client.""" - public_key = await get_stdin_public_key(process) - if not public_key: - raise CommandError(constants.ERROR_NO_PUBLIC_KEY) - - key = asyncssh.import_public_key(public_key) - if key.algorithm.decode() != "ssh-rsa": - raise CommandError(constants.ERROR_INVALID_KEY_TYPE) - await audit_process(backend, process, Operation.CREATE, "Registering new client") - LOG.debug("Registering client %s with public key %s", username, public_key) - await backend.create_client(username, public_key) - - -async def get_secret( - backend: SshecretBackend, - client: Client, - secret_name: str, - origin: str, -) -> str: - """Handle get secret requests from client.""" - LOG.debug("Recieved command: get_secret %r", secret_name) - if not secret_name: - raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET) - if secret_name not in client.secrets: - raise CommandError(constants.ERROR_NO_SECRET_FOUND) - - await audit_event( - backend, - "Client requested secret", - operation=Operation.READ, - client=client, - origin=origin, - secret=secret_name, - ) - - # Look up secret - try: - secret = await backend.get_client_secret(client.name, secret_name) - if not secret: - raise CommandError(constants.ERROR_NO_SECRET_FOUND) - return secret - except Exception as exc: - LOG.debug(exc, exc_info=True) - raise CommandError(constants.ERROR_BACKEND_ERROR) from exc - - -async def dispatch_no_cmd(process: asyncssh.SSHServerProcess[str]) -> None: - """Dispatch for no command.""" - raise CommandError(constants.ERROR_NO_COMMAND_RECEIVED) - - -async def dispatch_cmd_ping(process: asyncssh.SSHServerProcess[str]) -> None: - """Dispatch the ping command.""" - process.stdout.write("PONG\n") - - -async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None: - """Dispatch the register command.""" - backend = get_info_backend(process) - if not backend: - raise CommandError(constants.ERROR_INFO_BACKEND_GONE) - username = get_info_username(process) - if not username: - raise CommandError(constants.ERROR_INFO_USERNAME_GONE) - - allowed_networks = get_info_allowed_registration(process) - if not allowed_networks: - process.stdout.write("Unauthorized.\n") - await audit_process( - backend, - process, - Operation.DENY, - "Received registration command, but no subnets are allowed.", - ) - return - - remote_ip = get_info_remote_ip(process) - - if not remote_ip: - raise CommandError(constants.ERROR_INFO_REMOTE_IP_GONE) - - client_address = ipaddress.ip_address(remote_ip) - - for network in allowed_networks: - if client_address in network: - break - else: - await audit_process( - backend, - process, - Operation.DENY, - "Received registration command from unauthorized subnet.", - ) - process.stdout.write("Unauthorized.\n") - return - - await register_client(process, backend, username) - - process.stdout.write("Client registered\n.") - - -async def dispatch_cmd_get_secret(process: asyncssh.SSHServerProcess[str]) -> None: - """Dispatch the get_secret command.""" - backend = get_info_backend(process) - if not backend: - raise CommandError(constants.ERROR_INFO_BACKEND_GONE) - - client = get_info_client(process) - if not client: - raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET) - _cmd, args = get_process_command(process) - if not args: - raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET) - secret_name = args[0] - - origin = get_info_remote_ip(process) or "Unknown" - secret = await get_secret(backend, client, secret_name, origin) - process.stdout.write(secret) - - -async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None: - """Dispatch command.""" - command, _args = get_process_command(process) - if not command: - process.stderr.write(constants.ERROR_NO_COMMAND_RECEIVED) - process.exit(1) - return - cmdmap: dict[str, CommandDispatch] = { - "get_secret": dispatch_cmd_get_secret, - } - extra_commands = get_optional_commands(process) - if "registration" in extra_commands: - cmdmap["register"] = dispatch_cmd_register - if "ping" in extra_commands: - cmdmap["ping"] = dispatch_cmd_ping - - if command not in cmdmap: - process.stderr.write(constants.ERROR_UNKNOWN_COMMAND) - process.exit(1) - return - exit_code = 0 - try: - dispatcher = cmdmap[command] - await dispatcher(process) - except CommandError as e: - process.stderr.write(str(e)) - exit_code = 1 - - except Exception as e: - LOG.debug(e, exc_info=True) - process.stderr.write("Unexpected exception:\n") - process.stderr.write(str(e)) - exit_code = 1 - - LOG.debug("Command processing finished.") - process.exit(exit_code) - - class AsshyncServer(asyncssh.SSHServer): """Asynchronous SSH server implementation.""" diff --git a/tests/integration/test_sshd.py b/tests/integration/test_sshd.py index c4628c4..548278b 100644 --- a/tests/integration/test_sshd.py +++ b/tests/integration/test_sshd.py @@ -69,8 +69,8 @@ class TestSshd: assert found is True session.stdin.write(test_client.public_key + "\n") - result = await session.stdout.readline() - assert "OK" in result + result = await session.stdout.read() + assert "Key is valid. Registering client." in result await session.wait() return test_client diff --git a/tests/packages/sshd/conftest.py b/tests/packages/sshd/conftest.py index 55f4a38..3c26c1e 100644 --- a/tests/packages/sshd/conftest.py +++ b/tests/packages/sshd/conftest.py @@ -82,10 +82,13 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock: "Error, must have a client called template for this to work." ) clients_data[name] = clients_data["template"] + template_secrets: dict[str, str] = {} for secret_key, secret in secrets_data.items(): s_client, secret_name = secret_key if s_client != "template": continue + template_secrets[secret_name] = secret + for secret_name, secret in template_secrets.items(): secrets_data[(name, secret_name)] = secret async def write_audit(*args, **kwargs): diff --git a/tests/packages/sshd/test_errors.py b/tests/packages/sshd/test_errors.py index 4f7d143..5127d22 100644 --- a/tests/packages/sshd/test_errors.py +++ b/tests/packages/sshd/test_errors.py @@ -83,8 +83,9 @@ class TestRegistrationErrors(BaseSshTests): output = await process.stdout.readline() assert "Enter public key" in output stdout, stderr = await process.communicate(public_key) + assert isinstance(stderr, str) print(f"{stdout=!r}, {stderr=!r}") - assert stderr == "Error: Invalid key type: Only RSA keys are supported." + assert stderr.rstrip() == "Error: Invalid key type: Only RSA keys are supported." result = await process.wait() assert result.exit_status == 1 @@ -102,8 +103,9 @@ class TestRegistrationErrors(BaseSshTests): output = await process.stdout.readline() assert "Enter public key" in output stdout, stderr = await process.communicate(public_key) + assert isinstance(stderr, str) print(f"{stdout=!r}, {stderr=!r}") - assert stderr == "Error: Invalid key type: Only RSA keys are supported." + assert stderr.rstrip() == "Error: Invalid key type: Only RSA keys are supported." result = await process.wait() assert result.exit_status == 1 @@ -122,7 +124,8 @@ class TestCommandErrors(BaseSshTests): assert result.exit_status == 1 stderr = result.stderr or "" - assert stderr == "Error: Unsupported command." + assert isinstance(stderr, str) + assert stderr.rstrip() == "Error: Unsupported command." @pytest.mark.asyncio async def test_no_command( @@ -136,7 +139,8 @@ class TestCommandErrors(BaseSshTests): async with conn.create_process() as process: stdout, stderr = await process.communicate() print(f"{stdout=!r}, {stderr=!r}") - assert stderr == "Error: No command was received from the client." + assert isinstance(stderr, str) + assert stderr.rstrip() == "Error: No command was received from the client." result = await process.wait() assert result.exit_status == 1 diff --git a/tests/packages/sshd/test_get_secret.py b/tests/packages/sshd/test_get_secret.py index b129a7c..d4782de 100644 --- a/tests/packages/sshd/test_get_secret.py +++ b/tests/packages/sshd/test_get_secret.py @@ -1,10 +1,12 @@ """Test get secret.""" +import allure import pytest from .types import ClientRegistry, CommandRunner +@allure.title("Test get_secret command") @pytest.mark.asyncio async def test_get_secret( ssh_command_runner: CommandRunner, client_registry: ClientRegistry @@ -19,7 +21,7 @@ async def test_get_secret( assert isinstance(result.stdout, str) assert result.stdout.rstrip() == "mocked-secret-mysecret" - +@allure.title("Test with invalid secret name") @pytest.mark.asyncio async def test_invalid_secret_name( ssh_command_runner: CommandRunner, client_registry: ClientRegistry @@ -30,4 +32,25 @@ async def test_invalid_secret_name( result = await ssh_command_runner("test-client", "get_secret mysecret") assert result.exit_status == 1 - assert result.stderr == "Error: No secret available with the given name." + stderr = result.stderr + assert isinstance(stderr, str) + assert stderr.rstrip() == "Error: No secret available with the given name." + +@allure.title("Test get_secret command help") +@pytest.mark.asyncio +async def test_get_secret_cmd_help(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None: + """Test running get_secret --help""" + await client_registry["add_client"]("test-client", ["mysecret"]) + + result = await ssh_command_runner("test-client", "get_secret --help") + + assert result.exit_status == 0 + + print(result.stdout) + assert isinstance(result.stdout, str) + + lines = result.stdout.splitlines() + + assert lines[0] == "get_secret SECRET" + + assert len(lines) > 4 diff --git a/tests/packages/sshd/test_ping.py b/tests/packages/sshd/test_ping.py index 791c01a..f78f6ab 100644 --- a/tests/packages/sshd/test_ping.py +++ b/tests/packages/sshd/test_ping.py @@ -1,8 +1,11 @@ +"""Test for the ping command.""" +import allure import pytest from .types import ClientRegistry, CommandRunner +@allure.title("Test running the ping command") @pytest.mark.asyncio async def test_ping_command( ssh_command_runner: CommandRunner, client_registry: ClientRegistry @@ -16,3 +19,21 @@ async def test_ping_command( assert result.stdout is not None assert isinstance(result.stdout, str) assert result.stdout.rstrip() == "PONG" + +@allure.title("Test ping help") +@pytest.mark.asyncio +async def test_ping_cmd_help(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None: + """Test running ping --help.""" + await client_registry["add_client"]("test-client", ["mysecret"]) + result = await ssh_command_runner("test-client", "ping --help") + + assert result.exit_status == 0 + + print(result.stdout) + assert isinstance(result.stdout, str) + + lines = result.stdout.splitlines() + + assert lines[0] == "ping" + + assert len(lines) > 4 diff --git a/tests/packages/sshd/test_register.py b/tests/packages/sshd/test_register.py index ad82e78..ed7a6b6 100644 --- a/tests/packages/sshd/test_register.py +++ b/tests/packages/sshd/test_register.py @@ -1,10 +1,12 @@ """Test registration.""" +import allure import pytest from .types import ClientRegistry, CommandRunner, ProcessRunner +@allure.title("Test client registration") @pytest.mark.enable_registration(True) @pytest.mark.asyncio async def test_register_client( @@ -29,8 +31,9 @@ async def test_register_client( assert found is True session.stdin.write(public_key) - result = await session.stdout.readline() - assert "OK" in result + data = await session.stdout.read() + assert isinstance(data, str) + assert "Key is valid. Registering client" in data # Test that we can connect @@ -39,3 +42,28 @@ async def test_register_client( assert result.stdout is not None assert isinstance(result.stdout, str) assert result.stdout.rstrip() == "mocked-secret-testsecret" + +@allure.title("Test register command help") +@pytest.mark.enable_registration(True) +@pytest.mark.asyncio +async def test_register_cmd_help(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None: + """Test running register --help""" + await client_registry["add_client"]("test-client", ["mysecret"]) + + result = await ssh_command_runner("test-client", "register --help") + + assert result.exit_status == 0 + + print(result.stdout) + assert isinstance(result.stdout, str) + + lines = result.stdout.splitlines() + + assert lines[0] == "register" + + assert len(lines) > 4 + + + + +# TODO: Test running register with an existing client.