Refactor command handling

This now supports usage/help texts
This commit is contained in:
2025-05-18 17:56:53 +02:00
parent 26ef9b45d4
commit dcf0b4274c
15 changed files with 337 additions and 431 deletions

View File

@ -1 +1,5 @@
"""Commands module."""
from .dispatcher import dispatch_command
__all__ = ["dispatch_command"]

View File

@ -2,11 +2,14 @@
import abc import abc
import json import json
import textwrap
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
import ipaddress import ipaddress
import logging import logging
from typing import Any, cast from typing import Any, cast, override
from rich.console import Console
import asyncssh import asyncssh
from pydantic import IPvAnyNetwork, IPvAnyAddress from pydantic import IPvAnyNetwork, IPvAnyAddress
@ -15,8 +18,6 @@ from sshecret.backend.api import SshecretBackend
from sshecret.backend.models import Client, Operation, SubSystem from sshecret.backend.models import Client, Operation, SubSystem
from sshecret_sshd import exceptions from sshecret_sshd import exceptions
from .utils import get_console
PeernameV4 = tuple[str, int] PeernameV4 = tuple[str, int]
PeernameV6 = tuple[str, int, int, int] PeernameV6 = tuple[str, int, int, int]
Peername = PeernameV4 | PeernameV6 Peername = PeernameV4 | PeernameV6
@ -32,11 +33,45 @@ class CmdArgs:
arguments: list[str] = field(default_factory=list) arguments: list[str] = field(default_factory=list)
@dataclass
class CommandFlag:
"""Command flag."""
name: str
description: str
supports_short: bool = False
enabled: bool = False
@override
def __str__(self) -> str:
"""Format an output for help texts."""
if self.supports_short:
return f"[-{self.name[0]} --{self.name}]"
return f"--{self.name}"
def get_console(process: asyncssh.SSHServerProcess[str]) -> Console:
"""Initiate console from process."""
width, _height, pixwidth, pixheight = process.term_size
LOG.debug("Terminal is %sx%s", pixwidth, pixheight)
if width > 0:
console = Console(
force_terminal=True,
width=pixwidth,
height=pixheight,
color_system="standard",
)
return console
return Console(markup=False, color_system=None)
class CommandDispatcher(abc.ABC): class CommandDispatcher(abc.ABC):
"""Command dispatcher.""" """Command dispatcher."""
name: str name: str
flags: dict[str, str] | None = None flags: dict[str, str] | None = None
mandatory_argument: str | None = None
def __init__( def __init__(
self, self,
@ -47,31 +82,57 @@ class CommandDispatcher(abc.ABC):
def print( def print(
self, self,
data: str, *data: str,
stderr: bool = False, stderr: bool = False,
formatted: bool = False,
newline: bool = True, newline: bool = True,
) -> None: ) -> None:
"""Write to stdout.""" """Write to stdout."""
if stderr: if stderr:
self.process.stderr.write(data + "\n") for line in data:
self.process.stderr.write(line + "\n")
return return
if formatted: for line in data:
data = self.get_rich(data) self.process.stdout.write(line)
if newline: if newline:
data += "\n" self.process.stdout.write("\n")
self.process.stdout.write(data)
def get_rich(self, data: str) -> str: def rich_print_line(
"""Print with rich formatting.""" self, data: str, tags: list[str] | None = None, rule: bool = False
console = get_console(self.process) ) -> None:
"""Write formatted text to the process.
IF the client terminal does not support this, no formatting will be added.
Otherwise, the tags will be added to the string.
"""
if not tags:
tags = []
if not self.formatting_supported:
return self.print(data, newline=True)
for tag in tags:
data = f"[{tag}]{data}[/{tag}]"
console = self.get_console()
with console.capture() as capture: with console.capture() as capture:
console.print(data) if rule:
return capture.get() console.rule(data)
else:
console.print(data)
self.process.stdout.write(capture.get())
def get_console(self) -> Console:
"""Initiate console from process."""
return get_console(self.process)
@property
def formatting_supported(self) -> bool:
"""Check if the terminal supports formatting."""
term_width = self.process.term_size[0]
return term_width > 0
def print_json(self, obj: dict[str, Any]) -> None: def print_json(self, obj: dict[str, Any]) -> None:
"""Print a json object.""" """Print a json object."""
console = get_console(self.process) console = self.get_console()
data = json.dumps(obj) data = json.dumps(obj)
with console.capture() as capture: with console.capture() as capture:
console.print_json(data) console.print_json(data)
@ -97,6 +158,7 @@ class CommandDispatcher(abc.ABC):
username, username,
client, client,
secret, secret,
data,
) )
await self.backend.audit(SubSystem.SSHD).write_async( await self.backend.audit(SubSystem.SSHD).write_async(
@ -148,6 +210,7 @@ class CommandDispatcher(abc.ABC):
def arguments(self) -> list[str]: def arguments(self) -> list[str]:
"""Get non-flag arguments.""" """Get non-flag arguments."""
parsed = self.parse_command() parsed = self.parse_command()
LOG.debug("Parsed command: %r", parsed)
if not self.flags: if not self.flags:
return parsed.arguments return parsed.arguments
return [ return [
@ -211,37 +274,69 @@ class CommandDispatcher(abc.ABC):
return allowed_registration return allowed_registration
@classmethod @classmethod
async def print_help(cls, process: asyncssh.SSHServerProcess[str]) -> None: def _format_command_name(cls) -> str:
"""Format command name with arguments."""
name = cls.name
if cls.mandatory_argument:
name = f"{name} {cls.mandatory_argument}"
if not cls.flags:
return name
flags = cls.command_flags()
args: list[str] = [str(flag) for flag in flags.values()]
flagstr = " ".join(args)
return f"{name} {flagstr}"
@classmethod
def usage(cls, indent: int = 0) -> str:
"""Print command usage."""
indent_prefix = " " * indent
inner_indent = 2 + indent
inner_prefix = " " * inner_indent
usage_str: list[str] = []
default_usage = "No help available."
if not cls.__doc__:
usage_str.append(default_usage)
else:
usage_str = [
textwrap.indent(line, prefix=inner_prefix)
for line in cls.__doc__.splitlines()
]
flags = cls.command_flags()
if flags:
usage_str.append("")
usage_str.append(textwrap.indent("Arguments:", prefix=(" " * 4)))
for info in flags.values():
usage_str.append(
textwrap.indent(f"{info!s}: {info.description}", prefix=(" " * 6))
)
usage = "\n".join(usage_str)
if indent:
usage = textwrap.indent(usage, prefix=indent_prefix)
return usage
async def print_help(self) -> None:
"""Print help.""" """Print help."""
descr = cls.__doc__ usage = type(self).usage()
command_name = type(self)._format_command_name()
self.rich_print_line(command_name)
self.print(usage)
help_text: list[str] = [ @classmethod
f"[bold]{cls.name}[/bold]", def command_flags(cls) -> dict[str, CommandFlag]:
descr or "", """Parse the command flags."""
]
flags: dict[str, str] = {}
shortflags: defaultdict[str, list[str]] = defaultdict(list) shortflags: defaultdict[str, list[str]] = defaultdict(list)
if cls.flags: if not cls.flags:
flags = cls.flags return {}
for flag in cls.flags.keys():
arg_text: list[str] = []
for flag in flags.keys():
shortflags[flag[0]].append(flag) shortflags[flag[0]].append(flag)
for flag, flag_descr in flags.items():
flagstr = f"--{flag}"
if len(shortflags[flag[0]]) == 1:
flagstr = f"[ -{flag[0]} --{flag} ]"
arg_text.extend([f"{flagstr}:", f"{flag_descr}", ""]) command_flags: dict[str, CommandFlag] = {}
console = get_console(process) for flag, description in cls.flags.items():
with console.capture() as capture: supports_shorts = True
for line in help_text: if len(shortflags[flag[0]]) > 1:
console.print(line) supports_shorts = False
if arg_text: command_flags[flag] = CommandFlag(flag, description, supports_shorts)
console.rule("Argument flags") return command_flags
for line in arg_text:
console.print(line)
process.stdout.write(capture.get())

View File

@ -5,7 +5,7 @@ Register arguments here.
import logging import logging
import textwrap import textwrap
from typing import final, override from typing import cast, final, override
import asyncssh import asyncssh
@ -47,44 +47,105 @@ class HelpCommand(CommandDispatcher):
name = "help" name = "help"
@override
def __init__(
self,
process: asyncssh.SSHServerProcess[str],
disabled_commands: list[str] | None = None,
) -> None:
"""Init help command."""
super().__init__(process)
self.disabled_commands: list[str] = disabled_commands or []
@override @override
async def exec(self) -> None: async def exec(self) -> None:
"""Execute command.""" """Execute command."""
usage_text = "[bold]AVAILABLE COMMANDS[/bold]" output_lines: list[tuple[str, str] | str] = []
usage_text += ( output_lines.append(("Available commands:", "bold"))
"[i]Some commands may be disabled or restricted by the administrator[/i]\n" output_lines.append(
("Some commands may be disabled or restricted by the administrator", "i")
) )
output_lines.append("")
for command in COMMANDS: for command in COMMANDS:
usage_text += f" [bold]{command.name}:[/bold]" if command.name in self.disabled_commands:
usage_text += textwrap.indent(self.get_command_doc(command), " ") continue
command_usage = command.usage(indent=2)
output_lines.append((f" {command._format_command_name()}", "bold"))
output_lines.extend(command_usage.splitlines())
output_lines.append("")
width, _height, _pixwidth, _pixheight = self.process.term_size for line in output_lines:
wrapped = textwrap.wrap(usage_text, width=width) tags: list[str] = []
wrapped_lines = "\n".join(wrapped) if isinstance(line, tuple):
self.print(wrapped_lines, formatted=True) text, tag = line
tags.append(tag)
else:
text = line
self.rich_print_line(text, tags)
def get_command_doc(self, command: type[CommandDispatcher]) -> str:
"""Format usage string for command.""" def get_disabled_commands(process: asyncssh.SSHServerProcess[str]) -> list[str]:
default_usage = f"No help available." """Get optional command state."""
if not command.__doc__: with_registration = cast(
return default_usage bool, process.get_extra_info("registration_enabled", False)
# Skip the first two lines: )
usage_str = command.__doc__[:2:] with_ping = cast(bool, process.get_extra_info("ping_enabled", False))
return usage_str.join("\n") optional_commands = {
"register": with_registration,
"ping": with_ping,
}
disabled = [key for key, value in optional_commands.items() if not value]
return disabled
async def do_dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
"""Dispatch command."""
disabled_commands = get_disabled_commands(process)
command_args = process.command
if not command_args:
raise exceptions.NoCommandReceivedError()
show_command_help = False
if "--help" in command_args:
show_command_help = True
command = command_args.split(" ")[0]
command_map: dict[str, type[CommandDispatcher]] = {
cmd_disp.name: cmd_disp
for cmd_disp in COMMANDS
if cmd_disp.name not in disabled_commands
}
command_map["help"] = HelpCommand
LOG.debug("disabled_commands: %r, command_map: %r", disabled_commands, command_map)
LOG.debug("Looking for command %s", command)
if command not in command_map:
raise exceptions.UnknownCommandError()
dispatcher = command_map[command]
LOG.debug("Received command: %s. Dispatching to %r", command, dispatcher)
if command != "help" and show_command_help:
return await dispatcher(process).print_help()
if command == "help":
return await HelpCommand(process, disabled_commands).exec()
return await dispatcher(process).exec()
async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None: async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
"""Dispatch command.""" """Dispatch command."""
command = process.command status_code = 0
if not command: try:
command = "help" await do_dispatch_command(process)
command_map: dict[str, type[CommandDispatcher]] = { except exceptions.BaseSshecretSshError as e:
cmd_disp.name: cmd_disp for cmd_disp in COMMANDS LOG.error("Command Error: %s", e, exc_info=True)
} process.stderr.write(f"{e}\n")
command_map["help"] = HelpCommand status_code = 1
except Exception as e:
LOG.error("Unexpected error: %e", e, exc_info=True)
process.stderr.write(f"{constants.ERROR_GENERIC_ERROR}: {e}")
status_code = 1
dispatcher = command_map.get(command, HelpCommand) finally:
process.exit(status_code)
LOG.debug("Received command: %s", command)
return await dispatcher(process).exec()

View File

@ -12,9 +12,15 @@ LOG = logging.getLogger(__name__)
@final @final
class GetSecret(CommandDispatcher): class GetSecret(CommandDispatcher):
"""Get Secret.""" """Retrieve an encrypted secret.
Returns the value of the secret provided as a mandatory argument.
The secret will be encrypted using the stored RSA public key, and returned
as a base64 encoded string.
"""
name = "get_secret" name = "get_secret"
mandatory_argument = "SECRET"
@override @override
async def exec(self) -> None: async def exec(self) -> None:
@ -22,6 +28,7 @@ class GetSecret(CommandDispatcher):
if len(self.arguments) != 1: if len(self.arguments) != 1:
raise exceptions.UnknownClientOrSecretError() raise exceptions.UnknownClientOrSecretError()
secret_name = self.arguments[0] secret_name = self.arguments[0]
LOG.debug("get_secret called: Argument: %r", secret_name)
if secret_name not in self.client.secrets: if secret_name not in self.client.secrets:
await self.audit( await self.audit(
Operation.DENY, Operation.DENY,

View File

@ -1,48 +0,0 @@
"""Simple help command."""
from typing import final, override
from .base import CommandDispatcher
HELP_TEXT = """
[bold]Sshecret SSH Server[/bold]
[bold]SYNOPSIS[/bold]
An interface to request encrypted client secrets and perform
simple commands.
Secrets will be returned encrypted with the client public key,
encoded as base64.
[bold]AVAILABLE COMMANDS[/bold]
[i]Some of these commands may be disabled by the administrator.[/i]
[bold]ping[/bold]: Perform a ping towards the server.
[bold]get_secret [i]secret_name[/i][/bold]: Get a named secret
[bold]register[/bold]: Register a new client. Will prompt for public key.
[i]If enabled and permitted[/i]
[bold]ls [ --json -j ][/bold]: List available secrets.
[bold]help[/bold]: Prints this message
"""
@final
class HelpCommand(CommandDispatcher):
"""Help.
Returns usage instructions.
"""
name = "help"
@override
async def exec(self) -> None:
"""Execute command."""
self.print(HELP_TEXT, formatted=True)

View File

@ -32,10 +32,7 @@ class ListSecrets(CommandDispatcher):
async def list_secrets(self) -> None: async def list_secrets(self) -> None:
"""List secrets.""" """List secrets."""
self.print( self.rich_print_line(f"Available secrets for client {self.client.name}", ["bold"], rule=True)
f"[bold]Available secrets for client {self.client.name}[/bold]",
formatted=True,
)
await self.audit(Operation.READ, "Listed available secret names") await self.audit(Operation.READ, "Listed available secret names")
for secret_name in self.client.name: for secret_name in self.client.secrets:
self.print(f" - {secret_name}") self.print(f" - {secret_name}")

View File

@ -24,7 +24,6 @@ class Register(CommandDispatcher):
"""Register a new client. """Register a new client.
After connection, you must input a ssh public key. After connection, you must input a ssh public key.
This must be an RSA key. No other types of keys are supported. This must be an RSA key. No other types of keys are supported.
""" """

View File

@ -1,12 +0,0 @@
"""Various utilities."""
import asyncssh
from rich.console import Console
def get_console(process: asyncssh.SSHServerProcess[str]) -> Console:
"""Initiate console from process."""
_width, _height, pixwidth, pixheight = process.term_size
console = Console(
force_terminal=True, width=pixwidth, height=pixheight, color_system="standard"
)
return console

View File

@ -1,19 +1,18 @@
"""SSH Server implementation.""" """SSH Server implementation."""
from asyncio import _register_task
import logging import logging
import asyncssh import asyncssh
import ipaddress import ipaddress
from collections.abc import Awaitable from collections.abc import Awaitable
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from typing import Callable, cast, override from typing import Callable, override
from pydantic import IPvAnyNetwork from pydantic import IPvAnyNetwork
from . import constants from sshecret_sshd import constants
from sshecret_sshd.commands import dispatch_command
from sshecret.backend import SshecretBackend, Client, Operation, SubSystem from sshecret.backend import SshecretBackend, Client, Operation, SubSystem
from .settings import ServerSettings, ClientRegistrationSettings from .settings import ServerSettings, ClientRegistrationSettings
@ -29,37 +28,6 @@ PeernameV6 = tuple[str, int, int, int]
Peername = PeernameV4 | PeernameV6 Peername = PeernameV4 | PeernameV6
class CommandError(Exception):
"""Error class for errors during command processing."""
async def audit_process(
backend: SshecretBackend,
process: asyncssh.SSHServerProcess[str],
operation: Operation,
message: str,
secret: str | None = None,
**data: str,
) -> None:
"""Add an audit event from process."""
command = get_process_command(process)
client = get_info_client(process)
username = get_info_username(process)
remote_ip = get_info_remote_ip(process) or "UNKNOWN"
if username:
data["username"] = username
if command and not secret:
cmd, cmd_args = command
if cmd:
data["command"] = cmd
data["args"] = " ".join(cmd_args)
await backend.audit(SubSystem.SSHD).write_async(
operation, message, remote_ip, client, secret=None, secret_name=secret, **data
)
async def audit_event( async def audit_event(
backend: SshecretBackend, backend: SshecretBackend,
message: str, message: str,
@ -87,250 +55,6 @@ def verify_key_input(public_key: str) -> str | None:
return None return None
def get_process_command(
process: asyncssh.SSHServerProcess[str],
) -> tuple[str | None, list[str]]:
"""Extract the process command."""
if not process.command:
return (None, [])
argv = process.command.split(" ")
LOG.debug("Args: %r", argv)
return (argv[0], argv[1:])
def get_info_backend(process: asyncssh.SSHServerProcess[str]) -> SshecretBackend | None:
"""Get backend from process."""
backend = cast("SshecretBackend | None", process.get_extra_info("backend", None))
return backend
def get_info_client(process: asyncssh.SSHServerProcess[str]) -> Client | None:
"""Get info from process."""
client = cast("Client | None", process.get_extra_info("client", None))
return client
def get_info_username(process: asyncssh.SSHServerProcess[str]) -> str | None:
"""Get username from process."""
username = cast("str | None", process.get_extra_info("provided_username", None))
return username
def get_info_remote_ip(process: asyncssh.SSHServerProcess[str]) -> str | None:
"""Get remote IP."""
peername = cast("Peername | None", process.get_extra_info("peername", None))
remote_ip: str | None = None
if peername:
remote_ip = peername[0]
return remote_ip
def get_info_allowed_registration(
process: asyncssh.SSHServerProcess[str],
) -> list[IPvAnyNetwork] | None:
"""Get allowed networks to allow registration from."""
allowed_registration = cast(
list[IPvAnyNetwork] | None,
process.get_extra_info("allow_registration_from", None),
)
return allowed_registration
def get_optional_commands(process: asyncssh.SSHServerProcess[str]) -> dict[str, bool]:
"""Get optional command state."""
with_registration = cast(
bool, process.get_extra_info("registration_enabled", False)
)
with_ping = cast(bool, process.get_extra_info("ping_enabled", False))
return {
"registration": with_registration,
"ping": with_ping,
}
async def get_stdin_public_key(process: asyncssh.SSHServerProcess[str]) -> str | None:
"""Get public key from stdin."""
process.stdout.write("Enter public key:\n")
public_key: str | None = None
try:
async for line in process.stdin:
public_key = verify_key_input(line.rstrip("\n"))
if public_key:
break
raise CommandError(constants.ERROR_INVALID_KEY_TYPE)
except asyncssh.BreakReceived:
pass
else:
process.stdout.write("OK\n")
return public_key
async def register_client(
process: asyncssh.SSHServerProcess[str],
backend: SshecretBackend,
username: str,
) -> None:
"""Register a new client."""
public_key = await get_stdin_public_key(process)
if not public_key:
raise CommandError(constants.ERROR_NO_PUBLIC_KEY)
key = asyncssh.import_public_key(public_key)
if key.algorithm.decode() != "ssh-rsa":
raise CommandError(constants.ERROR_INVALID_KEY_TYPE)
await audit_process(backend, process, Operation.CREATE, "Registering new client")
LOG.debug("Registering client %s with public key %s", username, public_key)
await backend.create_client(username, public_key)
async def get_secret(
backend: SshecretBackend,
client: Client,
secret_name: str,
origin: str,
) -> str:
"""Handle get secret requests from client."""
LOG.debug("Recieved command: get_secret %r", secret_name)
if not secret_name:
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
if secret_name not in client.secrets:
raise CommandError(constants.ERROR_NO_SECRET_FOUND)
await audit_event(
backend,
"Client requested secret",
operation=Operation.READ,
client=client,
origin=origin,
secret=secret_name,
)
# Look up secret
try:
secret = await backend.get_client_secret(client.name, secret_name)
if not secret:
raise CommandError(constants.ERROR_NO_SECRET_FOUND)
return secret
except Exception as exc:
LOG.debug(exc, exc_info=True)
raise CommandError(constants.ERROR_BACKEND_ERROR) from exc
async def dispatch_no_cmd(process: asyncssh.SSHServerProcess[str]) -> None:
"""Dispatch for no command."""
raise CommandError(constants.ERROR_NO_COMMAND_RECEIVED)
async def dispatch_cmd_ping(process: asyncssh.SSHServerProcess[str]) -> None:
"""Dispatch the ping command."""
process.stdout.write("PONG\n")
async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None:
"""Dispatch the register command."""
backend = get_info_backend(process)
if not backend:
raise CommandError(constants.ERROR_INFO_BACKEND_GONE)
username = get_info_username(process)
if not username:
raise CommandError(constants.ERROR_INFO_USERNAME_GONE)
allowed_networks = get_info_allowed_registration(process)
if not allowed_networks:
process.stdout.write("Unauthorized.\n")
await audit_process(
backend,
process,
Operation.DENY,
"Received registration command, but no subnets are allowed.",
)
return
remote_ip = get_info_remote_ip(process)
if not remote_ip:
raise CommandError(constants.ERROR_INFO_REMOTE_IP_GONE)
client_address = ipaddress.ip_address(remote_ip)
for network in allowed_networks:
if client_address in network:
break
else:
await audit_process(
backend,
process,
Operation.DENY,
"Received registration command from unauthorized subnet.",
)
process.stdout.write("Unauthorized.\n")
return
await register_client(process, backend, username)
process.stdout.write("Client registered\n.")
async def dispatch_cmd_get_secret(process: asyncssh.SSHServerProcess[str]) -> None:
"""Dispatch the get_secret command."""
backend = get_info_backend(process)
if not backend:
raise CommandError(constants.ERROR_INFO_BACKEND_GONE)
client = get_info_client(process)
if not client:
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
_cmd, args = get_process_command(process)
if not args:
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
secret_name = args[0]
origin = get_info_remote_ip(process) or "Unknown"
secret = await get_secret(backend, client, secret_name, origin)
process.stdout.write(secret)
async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
"""Dispatch command."""
command, _args = get_process_command(process)
if not command:
process.stderr.write(constants.ERROR_NO_COMMAND_RECEIVED)
process.exit(1)
return
cmdmap: dict[str, CommandDispatch] = {
"get_secret": dispatch_cmd_get_secret,
}
extra_commands = get_optional_commands(process)
if "registration" in extra_commands:
cmdmap["register"] = dispatch_cmd_register
if "ping" in extra_commands:
cmdmap["ping"] = dispatch_cmd_ping
if command not in cmdmap:
process.stderr.write(constants.ERROR_UNKNOWN_COMMAND)
process.exit(1)
return
exit_code = 0
try:
dispatcher = cmdmap[command]
await dispatcher(process)
except CommandError as e:
process.stderr.write(str(e))
exit_code = 1
except Exception as e:
LOG.debug(e, exc_info=True)
process.stderr.write("Unexpected exception:\n")
process.stderr.write(str(e))
exit_code = 1
LOG.debug("Command processing finished.")
process.exit(exit_code)
class AsshyncServer(asyncssh.SSHServer): class AsshyncServer(asyncssh.SSHServer):
"""Asynchronous SSH server implementation.""" """Asynchronous SSH server implementation."""

View File

@ -69,8 +69,8 @@ class TestSshd:
assert found is True assert found is True
session.stdin.write(test_client.public_key + "\n") session.stdin.write(test_client.public_key + "\n")
result = await session.stdout.readline() result = await session.stdout.read()
assert "OK" in result assert "Key is valid. Registering client." in result
await session.wait() await session.wait()
return test_client return test_client

View File

@ -82,10 +82,13 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
"Error, must have a client called template for this to work." "Error, must have a client called template for this to work."
) )
clients_data[name] = clients_data["template"] clients_data[name] = clients_data["template"]
template_secrets: dict[str, str] = {}
for secret_key, secret in secrets_data.items(): for secret_key, secret in secrets_data.items():
s_client, secret_name = secret_key s_client, secret_name = secret_key
if s_client != "template": if s_client != "template":
continue continue
template_secrets[secret_name] = secret
for secret_name, secret in template_secrets.items():
secrets_data[(name, secret_name)] = secret secrets_data[(name, secret_name)] = secret
async def write_audit(*args, **kwargs): async def write_audit(*args, **kwargs):

View File

@ -83,8 +83,9 @@ class TestRegistrationErrors(BaseSshTests):
output = await process.stdout.readline() output = await process.stdout.readline()
assert "Enter public key" in output assert "Enter public key" in output
stdout, stderr = await process.communicate(public_key) stdout, stderr = await process.communicate(public_key)
assert isinstance(stderr, str)
print(f"{stdout=!r}, {stderr=!r}") print(f"{stdout=!r}, {stderr=!r}")
assert stderr == "Error: Invalid key type: Only RSA keys are supported." assert stderr.rstrip() == "Error: Invalid key type: Only RSA keys are supported."
result = await process.wait() result = await process.wait()
assert result.exit_status == 1 assert result.exit_status == 1
@ -102,8 +103,9 @@ class TestRegistrationErrors(BaseSshTests):
output = await process.stdout.readline() output = await process.stdout.readline()
assert "Enter public key" in output assert "Enter public key" in output
stdout, stderr = await process.communicate(public_key) stdout, stderr = await process.communicate(public_key)
assert isinstance(stderr, str)
print(f"{stdout=!r}, {stderr=!r}") print(f"{stdout=!r}, {stderr=!r}")
assert stderr == "Error: Invalid key type: Only RSA keys are supported." assert stderr.rstrip() == "Error: Invalid key type: Only RSA keys are supported."
result = await process.wait() result = await process.wait()
assert result.exit_status == 1 assert result.exit_status == 1
@ -122,7 +124,8 @@ class TestCommandErrors(BaseSshTests):
assert result.exit_status == 1 assert result.exit_status == 1
stderr = result.stderr or "" stderr = result.stderr or ""
assert stderr == "Error: Unsupported command." assert isinstance(stderr, str)
assert stderr.rstrip() == "Error: Unsupported command."
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_no_command( async def test_no_command(
@ -136,7 +139,8 @@ class TestCommandErrors(BaseSshTests):
async with conn.create_process() as process: async with conn.create_process() as process:
stdout, stderr = await process.communicate() stdout, stderr = await process.communicate()
print(f"{stdout=!r}, {stderr=!r}") print(f"{stdout=!r}, {stderr=!r}")
assert stderr == "Error: No command was received from the client." assert isinstance(stderr, str)
assert stderr.rstrip() == "Error: No command was received from the client."
result = await process.wait() result = await process.wait()
assert result.exit_status == 1 assert result.exit_status == 1

View File

@ -1,10 +1,12 @@
"""Test get secret.""" """Test get secret."""
import allure
import pytest import pytest
from .types import ClientRegistry, CommandRunner from .types import ClientRegistry, CommandRunner
@allure.title("Test get_secret command")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_secret( async def test_get_secret(
ssh_command_runner: CommandRunner, client_registry: ClientRegistry ssh_command_runner: CommandRunner, client_registry: ClientRegistry
@ -19,7 +21,7 @@ async def test_get_secret(
assert isinstance(result.stdout, str) assert isinstance(result.stdout, str)
assert result.stdout.rstrip() == "mocked-secret-mysecret" assert result.stdout.rstrip() == "mocked-secret-mysecret"
@allure.title("Test with invalid secret name")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invalid_secret_name( async def test_invalid_secret_name(
ssh_command_runner: CommandRunner, client_registry: ClientRegistry ssh_command_runner: CommandRunner, client_registry: ClientRegistry
@ -30,4 +32,25 @@ async def test_invalid_secret_name(
result = await ssh_command_runner("test-client", "get_secret mysecret") result = await ssh_command_runner("test-client", "get_secret mysecret")
assert result.exit_status == 1 assert result.exit_status == 1
assert result.stderr == "Error: No secret available with the given name." stderr = result.stderr
assert isinstance(stderr, str)
assert stderr.rstrip() == "Error: No secret available with the given name."
@allure.title("Test get_secret command help")
@pytest.mark.asyncio
async def test_get_secret_cmd_help(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None:
"""Test running get_secret --help"""
await client_registry["add_client"]("test-client", ["mysecret"])
result = await ssh_command_runner("test-client", "get_secret --help")
assert result.exit_status == 0
print(result.stdout)
assert isinstance(result.stdout, str)
lines = result.stdout.splitlines()
assert lines[0] == "get_secret SECRET"
assert len(lines) > 4

View File

@ -1,8 +1,11 @@
"""Test for the ping command."""
import allure
import pytest import pytest
from .types import ClientRegistry, CommandRunner from .types import ClientRegistry, CommandRunner
@allure.title("Test running the ping command")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_ping_command( async def test_ping_command(
ssh_command_runner: CommandRunner, client_registry: ClientRegistry ssh_command_runner: CommandRunner, client_registry: ClientRegistry
@ -16,3 +19,21 @@ async def test_ping_command(
assert result.stdout is not None assert result.stdout is not None
assert isinstance(result.stdout, str) assert isinstance(result.stdout, str)
assert result.stdout.rstrip() == "PONG" assert result.stdout.rstrip() == "PONG"
@allure.title("Test ping help")
@pytest.mark.asyncio
async def test_ping_cmd_help(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None:
"""Test running ping --help."""
await client_registry["add_client"]("test-client", ["mysecret"])
result = await ssh_command_runner("test-client", "ping --help")
assert result.exit_status == 0
print(result.stdout)
assert isinstance(result.stdout, str)
lines = result.stdout.splitlines()
assert lines[0] == "ping"
assert len(lines) > 4

View File

@ -1,10 +1,12 @@
"""Test registration.""" """Test registration."""
import allure
import pytest import pytest
from .types import ClientRegistry, CommandRunner, ProcessRunner from .types import ClientRegistry, CommandRunner, ProcessRunner
@allure.title("Test client registration")
@pytest.mark.enable_registration(True) @pytest.mark.enable_registration(True)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_register_client( async def test_register_client(
@ -29,8 +31,9 @@ async def test_register_client(
assert found is True assert found is True
session.stdin.write(public_key) session.stdin.write(public_key)
result = await session.stdout.readline() data = await session.stdout.read()
assert "OK" in result assert isinstance(data, str)
assert "Key is valid. Registering client" in data
# Test that we can connect # Test that we can connect
@ -39,3 +42,28 @@ async def test_register_client(
assert result.stdout is not None assert result.stdout is not None
assert isinstance(result.stdout, str) assert isinstance(result.stdout, str)
assert result.stdout.rstrip() == "mocked-secret-testsecret" assert result.stdout.rstrip() == "mocked-secret-testsecret"
@allure.title("Test register command help")
@pytest.mark.enable_registration(True)
@pytest.mark.asyncio
async def test_register_cmd_help(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None:
"""Test running register --help"""
await client_registry["add_client"]("test-client", ["mysecret"])
result = await ssh_command_runner("test-client", "register --help")
assert result.exit_status == 0
print(result.stdout)
assert isinstance(result.stdout, str)
lines = result.stdout.splitlines()
assert lines[0] == "register"
assert len(lines) > 4
# TODO: Test running register with an existing client.