Refactor command handling
This now supports usage/help texts
This commit is contained in:
@ -1 +1,5 @@
|
|||||||
|
"""Commands module."""
|
||||||
|
|
||||||
|
from .dispatcher import dispatch_command
|
||||||
|
|
||||||
|
__all__ = ["dispatch_command"]
|
||||||
|
|||||||
@ -2,11 +2,14 @@
|
|||||||
|
|
||||||
import abc
|
import abc
|
||||||
import json
|
import json
|
||||||
|
import textwrap
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
import ipaddress
|
import ipaddress
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, cast
|
from typing import Any, cast, override
|
||||||
|
|
||||||
|
from rich.console import Console
|
||||||
|
|
||||||
import asyncssh
|
import asyncssh
|
||||||
from pydantic import IPvAnyNetwork, IPvAnyAddress
|
from pydantic import IPvAnyNetwork, IPvAnyAddress
|
||||||
@ -15,8 +18,6 @@ from sshecret.backend.api import SshecretBackend
|
|||||||
from sshecret.backend.models import Client, Operation, SubSystem
|
from sshecret.backend.models import Client, Operation, SubSystem
|
||||||
from sshecret_sshd import exceptions
|
from sshecret_sshd import exceptions
|
||||||
|
|
||||||
from .utils import get_console
|
|
||||||
|
|
||||||
PeernameV4 = tuple[str, int]
|
PeernameV4 = tuple[str, int]
|
||||||
PeernameV6 = tuple[str, int, int, int]
|
PeernameV6 = tuple[str, int, int, int]
|
||||||
Peername = PeernameV4 | PeernameV6
|
Peername = PeernameV4 | PeernameV6
|
||||||
@ -32,11 +33,45 @@ class CmdArgs:
|
|||||||
arguments: list[str] = field(default_factory=list)
|
arguments: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CommandFlag:
|
||||||
|
"""Command flag."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
supports_short: bool = False
|
||||||
|
enabled: bool = False
|
||||||
|
|
||||||
|
@override
|
||||||
|
def __str__(self) -> str:
|
||||||
|
"""Format an output for help texts."""
|
||||||
|
if self.supports_short:
|
||||||
|
return f"[-{self.name[0]} --{self.name}]"
|
||||||
|
return f"--{self.name}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_console(process: asyncssh.SSHServerProcess[str]) -> Console:
|
||||||
|
"""Initiate console from process."""
|
||||||
|
width, _height, pixwidth, pixheight = process.term_size
|
||||||
|
LOG.debug("Terminal is %sx%s", pixwidth, pixheight)
|
||||||
|
|
||||||
|
if width > 0:
|
||||||
|
console = Console(
|
||||||
|
force_terminal=True,
|
||||||
|
width=pixwidth,
|
||||||
|
height=pixheight,
|
||||||
|
color_system="standard",
|
||||||
|
)
|
||||||
|
return console
|
||||||
|
return Console(markup=False, color_system=None)
|
||||||
|
|
||||||
|
|
||||||
class CommandDispatcher(abc.ABC):
|
class CommandDispatcher(abc.ABC):
|
||||||
"""Command dispatcher."""
|
"""Command dispatcher."""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
flags: dict[str, str] | None = None
|
flags: dict[str, str] | None = None
|
||||||
|
mandatory_argument: str | None = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -47,31 +82,57 @@ class CommandDispatcher(abc.ABC):
|
|||||||
|
|
||||||
def print(
|
def print(
|
||||||
self,
|
self,
|
||||||
data: str,
|
*data: str,
|
||||||
stderr: bool = False,
|
stderr: bool = False,
|
||||||
formatted: bool = False,
|
|
||||||
newline: bool = True,
|
newline: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Write to stdout."""
|
"""Write to stdout."""
|
||||||
if stderr:
|
if stderr:
|
||||||
self.process.stderr.write(data + "\n")
|
for line in data:
|
||||||
|
self.process.stderr.write(line + "\n")
|
||||||
return
|
return
|
||||||
if formatted:
|
for line in data:
|
||||||
data = self.get_rich(data)
|
self.process.stdout.write(line)
|
||||||
if newline:
|
if newline:
|
||||||
data += "\n"
|
self.process.stdout.write("\n")
|
||||||
self.process.stdout.write(data)
|
|
||||||
|
|
||||||
def get_rich(self, data: str) -> str:
|
def rich_print_line(
|
||||||
"""Print with rich formatting."""
|
self, data: str, tags: list[str] | None = None, rule: bool = False
|
||||||
console = get_console(self.process)
|
) -> None:
|
||||||
|
"""Write formatted text to the process.
|
||||||
|
|
||||||
|
IF the client terminal does not support this, no formatting will be added.
|
||||||
|
Otherwise, the tags will be added to the string.
|
||||||
|
"""
|
||||||
|
if not tags:
|
||||||
|
tags = []
|
||||||
|
if not self.formatting_supported:
|
||||||
|
return self.print(data, newline=True)
|
||||||
|
for tag in tags:
|
||||||
|
data = f"[{tag}]{data}[/{tag}]"
|
||||||
|
|
||||||
|
console = self.get_console()
|
||||||
with console.capture() as capture:
|
with console.capture() as capture:
|
||||||
|
if rule:
|
||||||
|
console.rule(data)
|
||||||
|
else:
|
||||||
console.print(data)
|
console.print(data)
|
||||||
return capture.get()
|
|
||||||
|
self.process.stdout.write(capture.get())
|
||||||
|
|
||||||
|
def get_console(self) -> Console:
|
||||||
|
"""Initiate console from process."""
|
||||||
|
return get_console(self.process)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def formatting_supported(self) -> bool:
|
||||||
|
"""Check if the terminal supports formatting."""
|
||||||
|
term_width = self.process.term_size[0]
|
||||||
|
return term_width > 0
|
||||||
|
|
||||||
def print_json(self, obj: dict[str, Any]) -> None:
|
def print_json(self, obj: dict[str, Any]) -> None:
|
||||||
"""Print a json object."""
|
"""Print a json object."""
|
||||||
console = get_console(self.process)
|
console = self.get_console()
|
||||||
data = json.dumps(obj)
|
data = json.dumps(obj)
|
||||||
with console.capture() as capture:
|
with console.capture() as capture:
|
||||||
console.print_json(data)
|
console.print_json(data)
|
||||||
@ -97,6 +158,7 @@ class CommandDispatcher(abc.ABC):
|
|||||||
username,
|
username,
|
||||||
client,
|
client,
|
||||||
secret,
|
secret,
|
||||||
|
data,
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.backend.audit(SubSystem.SSHD).write_async(
|
await self.backend.audit(SubSystem.SSHD).write_async(
|
||||||
@ -148,6 +210,7 @@ class CommandDispatcher(abc.ABC):
|
|||||||
def arguments(self) -> list[str]:
|
def arguments(self) -> list[str]:
|
||||||
"""Get non-flag arguments."""
|
"""Get non-flag arguments."""
|
||||||
parsed = self.parse_command()
|
parsed = self.parse_command()
|
||||||
|
LOG.debug("Parsed command: %r", parsed)
|
||||||
if not self.flags:
|
if not self.flags:
|
||||||
return parsed.arguments
|
return parsed.arguments
|
||||||
return [
|
return [
|
||||||
@ -211,37 +274,69 @@ class CommandDispatcher(abc.ABC):
|
|||||||
return allowed_registration
|
return allowed_registration
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def print_help(cls, process: asyncssh.SSHServerProcess[str]) -> None:
|
def _format_command_name(cls) -> str:
|
||||||
"""Print help."""
|
"""Format command name with arguments."""
|
||||||
descr = cls.__doc__
|
name = cls.name
|
||||||
|
if cls.mandatory_argument:
|
||||||
|
name = f"{name} {cls.mandatory_argument}"
|
||||||
|
|
||||||
help_text: list[str] = [
|
if not cls.flags:
|
||||||
f"[bold]{cls.name}[/bold]",
|
return name
|
||||||
descr or "",
|
flags = cls.command_flags()
|
||||||
|
args: list[str] = [str(flag) for flag in flags.values()]
|
||||||
|
flagstr = " ".join(args)
|
||||||
|
return f"{name} {flagstr}"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def usage(cls, indent: int = 0) -> str:
|
||||||
|
"""Print command usage."""
|
||||||
|
indent_prefix = " " * indent
|
||||||
|
inner_indent = 2 + indent
|
||||||
|
inner_prefix = " " * inner_indent
|
||||||
|
usage_str: list[str] = []
|
||||||
|
default_usage = "No help available."
|
||||||
|
if not cls.__doc__:
|
||||||
|
usage_str.append(default_usage)
|
||||||
|
else:
|
||||||
|
usage_str = [
|
||||||
|
textwrap.indent(line, prefix=inner_prefix)
|
||||||
|
for line in cls.__doc__.splitlines()
|
||||||
]
|
]
|
||||||
flags: dict[str, str] = {}
|
flags = cls.command_flags()
|
||||||
|
if flags:
|
||||||
|
usage_str.append("")
|
||||||
|
usage_str.append(textwrap.indent("Arguments:", prefix=(" " * 4)))
|
||||||
|
for info in flags.values():
|
||||||
|
usage_str.append(
|
||||||
|
textwrap.indent(f"{info!s}: {info.description}", prefix=(" " * 6))
|
||||||
|
)
|
||||||
|
usage = "\n".join(usage_str)
|
||||||
|
if indent:
|
||||||
|
usage = textwrap.indent(usage, prefix=indent_prefix)
|
||||||
|
|
||||||
|
return usage
|
||||||
|
|
||||||
|
async def print_help(self) -> None:
|
||||||
|
"""Print help."""
|
||||||
|
usage = type(self).usage()
|
||||||
|
command_name = type(self)._format_command_name()
|
||||||
|
self.rich_print_line(command_name)
|
||||||
|
self.print(usage)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def command_flags(cls) -> dict[str, CommandFlag]:
|
||||||
|
"""Parse the command flags."""
|
||||||
shortflags: defaultdict[str, list[str]] = defaultdict(list)
|
shortflags: defaultdict[str, list[str]] = defaultdict(list)
|
||||||
if cls.flags:
|
if not cls.flags:
|
||||||
flags = cls.flags
|
return {}
|
||||||
|
for flag in cls.flags.keys():
|
||||||
arg_text: list[str] = []
|
|
||||||
for flag in flags.keys():
|
|
||||||
shortflags[flag[0]].append(flag)
|
shortflags[flag[0]].append(flag)
|
||||||
for flag, flag_descr in flags.items():
|
|
||||||
flagstr = f"--{flag}"
|
|
||||||
if len(shortflags[flag[0]]) == 1:
|
|
||||||
flagstr = f"[ -{flag[0]} --{flag} ]"
|
|
||||||
|
|
||||||
arg_text.extend([f"{flagstr}:", f"{flag_descr}", ""])
|
command_flags: dict[str, CommandFlag] = {}
|
||||||
|
|
||||||
console = get_console(process)
|
for flag, description in cls.flags.items():
|
||||||
with console.capture() as capture:
|
supports_shorts = True
|
||||||
for line in help_text:
|
if len(shortflags[flag[0]]) > 1:
|
||||||
console.print(line)
|
supports_shorts = False
|
||||||
if arg_text:
|
command_flags[flag] = CommandFlag(flag, description, supports_shorts)
|
||||||
console.rule("Argument flags")
|
return command_flags
|
||||||
|
|
||||||
for line in arg_text:
|
|
||||||
console.print(line)
|
|
||||||
|
|
||||||
process.stdout.write(capture.get())
|
|
||||||
|
|||||||
@ -5,7 +5,7 @@ Register arguments here.
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import textwrap
|
import textwrap
|
||||||
from typing import final, override
|
from typing import cast, final, override
|
||||||
|
|
||||||
import asyncssh
|
import asyncssh
|
||||||
|
|
||||||
@ -47,44 +47,105 @@ class HelpCommand(CommandDispatcher):
|
|||||||
|
|
||||||
name = "help"
|
name = "help"
|
||||||
|
|
||||||
|
@override
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
process: asyncssh.SSHServerProcess[str],
|
||||||
|
disabled_commands: list[str] | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Init help command."""
|
||||||
|
super().__init__(process)
|
||||||
|
self.disabled_commands: list[str] = disabled_commands or []
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def exec(self) -> None:
|
async def exec(self) -> None:
|
||||||
"""Execute command."""
|
"""Execute command."""
|
||||||
usage_text = "[bold]AVAILABLE COMMANDS[/bold]"
|
output_lines: list[tuple[str, str] | str] = []
|
||||||
usage_text += (
|
output_lines.append(("Available commands:", "bold"))
|
||||||
"[i]Some commands may be disabled or restricted by the administrator[/i]\n"
|
output_lines.append(
|
||||||
|
("Some commands may be disabled or restricted by the administrator", "i")
|
||||||
)
|
)
|
||||||
|
output_lines.append("")
|
||||||
for command in COMMANDS:
|
for command in COMMANDS:
|
||||||
usage_text += f" [bold]{command.name}:[/bold]"
|
if command.name in self.disabled_commands:
|
||||||
usage_text += textwrap.indent(self.get_command_doc(command), " ")
|
continue
|
||||||
|
command_usage = command.usage(indent=2)
|
||||||
|
output_lines.append((f" {command._format_command_name()}", "bold"))
|
||||||
|
output_lines.extend(command_usage.splitlines())
|
||||||
|
output_lines.append("")
|
||||||
|
|
||||||
width, _height, _pixwidth, _pixheight = self.process.term_size
|
for line in output_lines:
|
||||||
wrapped = textwrap.wrap(usage_text, width=width)
|
tags: list[str] = []
|
||||||
wrapped_lines = "\n".join(wrapped)
|
if isinstance(line, tuple):
|
||||||
self.print(wrapped_lines, formatted=True)
|
text, tag = line
|
||||||
|
tags.append(tag)
|
||||||
|
else:
|
||||||
|
text = line
|
||||||
|
self.rich_print_line(text, tags)
|
||||||
|
|
||||||
def get_command_doc(self, command: type[CommandDispatcher]) -> str:
|
|
||||||
"""Format usage string for command."""
|
def get_disabled_commands(process: asyncssh.SSHServerProcess[str]) -> list[str]:
|
||||||
default_usage = f"No help available."
|
"""Get optional command state."""
|
||||||
if not command.__doc__:
|
with_registration = cast(
|
||||||
return default_usage
|
bool, process.get_extra_info("registration_enabled", False)
|
||||||
# Skip the first two lines:
|
)
|
||||||
usage_str = command.__doc__[:2:]
|
with_ping = cast(bool, process.get_extra_info("ping_enabled", False))
|
||||||
return usage_str.join("\n")
|
optional_commands = {
|
||||||
|
"register": with_registration,
|
||||||
|
"ping": with_ping,
|
||||||
|
}
|
||||||
|
disabled = [key for key, value in optional_commands.items() if not value]
|
||||||
|
return disabled
|
||||||
|
|
||||||
|
|
||||||
|
async def do_dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||||
|
"""Dispatch command."""
|
||||||
|
disabled_commands = get_disabled_commands(process)
|
||||||
|
command_args = process.command
|
||||||
|
|
||||||
|
if not command_args:
|
||||||
|
raise exceptions.NoCommandReceivedError()
|
||||||
|
show_command_help = False
|
||||||
|
if "--help" in command_args:
|
||||||
|
show_command_help = True
|
||||||
|
command = command_args.split(" ")[0]
|
||||||
|
command_map: dict[str, type[CommandDispatcher]] = {
|
||||||
|
cmd_disp.name: cmd_disp
|
||||||
|
for cmd_disp in COMMANDS
|
||||||
|
if cmd_disp.name not in disabled_commands
|
||||||
|
}
|
||||||
|
|
||||||
|
command_map["help"] = HelpCommand
|
||||||
|
|
||||||
|
LOG.debug("disabled_commands: %r, command_map: %r", disabled_commands, command_map)
|
||||||
|
LOG.debug("Looking for command %s", command)
|
||||||
|
if command not in command_map:
|
||||||
|
raise exceptions.UnknownCommandError()
|
||||||
|
|
||||||
|
dispatcher = command_map[command]
|
||||||
|
|
||||||
|
LOG.debug("Received command: %s. Dispatching to %r", command, dispatcher)
|
||||||
|
if command != "help" and show_command_help:
|
||||||
|
return await dispatcher(process).print_help()
|
||||||
|
|
||||||
|
if command == "help":
|
||||||
|
return await HelpCommand(process, disabled_commands).exec()
|
||||||
|
return await dispatcher(process).exec()
|
||||||
|
|
||||||
|
|
||||||
async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
|
async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||||
"""Dispatch command."""
|
"""Dispatch command."""
|
||||||
command = process.command
|
status_code = 0
|
||||||
if not command:
|
try:
|
||||||
command = "help"
|
await do_dispatch_command(process)
|
||||||
command_map: dict[str, type[CommandDispatcher]] = {
|
except exceptions.BaseSshecretSshError as e:
|
||||||
cmd_disp.name: cmd_disp for cmd_disp in COMMANDS
|
LOG.error("Command Error: %s", e, exc_info=True)
|
||||||
}
|
process.stderr.write(f"{e}\n")
|
||||||
command_map["help"] = HelpCommand
|
status_code = 1
|
||||||
|
except Exception as e:
|
||||||
|
LOG.error("Unexpected error: %e", e, exc_info=True)
|
||||||
|
process.stderr.write(f"{constants.ERROR_GENERIC_ERROR}: {e}")
|
||||||
|
status_code = 1
|
||||||
|
|
||||||
dispatcher = command_map.get(command, HelpCommand)
|
finally:
|
||||||
|
process.exit(status_code)
|
||||||
LOG.debug("Received command: %s", command)
|
|
||||||
|
|
||||||
return await dispatcher(process).exec()
|
|
||||||
|
|||||||
@ -12,9 +12,15 @@ LOG = logging.getLogger(__name__)
|
|||||||
|
|
||||||
@final
|
@final
|
||||||
class GetSecret(CommandDispatcher):
|
class GetSecret(CommandDispatcher):
|
||||||
"""Get Secret."""
|
"""Retrieve an encrypted secret.
|
||||||
|
|
||||||
|
Returns the value of the secret provided as a mandatory argument.
|
||||||
|
The secret will be encrypted using the stored RSA public key, and returned
|
||||||
|
as a base64 encoded string.
|
||||||
|
"""
|
||||||
|
|
||||||
name = "get_secret"
|
name = "get_secret"
|
||||||
|
mandatory_argument = "SECRET"
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def exec(self) -> None:
|
async def exec(self) -> None:
|
||||||
@ -22,6 +28,7 @@ class GetSecret(CommandDispatcher):
|
|||||||
if len(self.arguments) != 1:
|
if len(self.arguments) != 1:
|
||||||
raise exceptions.UnknownClientOrSecretError()
|
raise exceptions.UnknownClientOrSecretError()
|
||||||
secret_name = self.arguments[0]
|
secret_name = self.arguments[0]
|
||||||
|
LOG.debug("get_secret called: Argument: %r", secret_name)
|
||||||
if secret_name not in self.client.secrets:
|
if secret_name not in self.client.secrets:
|
||||||
await self.audit(
|
await self.audit(
|
||||||
Operation.DENY,
|
Operation.DENY,
|
||||||
|
|||||||
@ -1,48 +0,0 @@
|
|||||||
"""Simple help command."""
|
|
||||||
|
|
||||||
|
|
||||||
from typing import final, override
|
|
||||||
|
|
||||||
from .base import CommandDispatcher
|
|
||||||
|
|
||||||
HELP_TEXT = """
|
|
||||||
[bold]Sshecret SSH Server[/bold]
|
|
||||||
|
|
||||||
|
|
||||||
[bold]SYNOPSIS[/bold]
|
|
||||||
An interface to request encrypted client secrets and perform
|
|
||||||
simple commands.
|
|
||||||
|
|
||||||
Secrets will be returned encrypted with the client public key,
|
|
||||||
encoded as base64.
|
|
||||||
|
|
||||||
|
|
||||||
[bold]AVAILABLE COMMANDS[/bold]
|
|
||||||
[i]Some of these commands may be disabled by the administrator.[/i]
|
|
||||||
|
|
||||||
[bold]ping[/bold]: Perform a ping towards the server.
|
|
||||||
|
|
||||||
[bold]get_secret [i]secret_name[/i][/bold]: Get a named secret
|
|
||||||
|
|
||||||
[bold]register[/bold]: Register a new client. Will prompt for public key.
|
|
||||||
[i]If enabled and permitted[/i]
|
|
||||||
|
|
||||||
[bold]ls [ --json -j ][/bold]: List available secrets.
|
|
||||||
|
|
||||||
[bold]help[/bold]: Prints this message
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@final
|
|
||||||
class HelpCommand(CommandDispatcher):
|
|
||||||
"""Help.
|
|
||||||
|
|
||||||
Returns usage instructions.
|
|
||||||
"""
|
|
||||||
|
|
||||||
name = "help"
|
|
||||||
|
|
||||||
@override
|
|
||||||
async def exec(self) -> None:
|
|
||||||
"""Execute command."""
|
|
||||||
self.print(HELP_TEXT, formatted=True)
|
|
||||||
@ -32,10 +32,7 @@ class ListSecrets(CommandDispatcher):
|
|||||||
|
|
||||||
async def list_secrets(self) -> None:
|
async def list_secrets(self) -> None:
|
||||||
"""List secrets."""
|
"""List secrets."""
|
||||||
self.print(
|
self.rich_print_line(f"Available secrets for client {self.client.name}", ["bold"], rule=True)
|
||||||
f"[bold]Available secrets for client {self.client.name}[/bold]",
|
|
||||||
formatted=True,
|
|
||||||
)
|
|
||||||
await self.audit(Operation.READ, "Listed available secret names")
|
await self.audit(Operation.READ, "Listed available secret names")
|
||||||
for secret_name in self.client.name:
|
for secret_name in self.client.secrets:
|
||||||
self.print(f" - {secret_name}")
|
self.print(f" - {secret_name}")
|
||||||
|
|||||||
@ -24,7 +24,6 @@ class Register(CommandDispatcher):
|
|||||||
"""Register a new client.
|
"""Register a new client.
|
||||||
|
|
||||||
After connection, you must input a ssh public key.
|
After connection, you must input a ssh public key.
|
||||||
|
|
||||||
This must be an RSA key. No other types of keys are supported.
|
This must be an RSA key. No other types of keys are supported.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@ -1,12 +0,0 @@
|
|||||||
"""Various utilities."""
|
|
||||||
|
|
||||||
import asyncssh
|
|
||||||
from rich.console import Console
|
|
||||||
|
|
||||||
def get_console(process: asyncssh.SSHServerProcess[str]) -> Console:
|
|
||||||
"""Initiate console from process."""
|
|
||||||
_width, _height, pixwidth, pixheight = process.term_size
|
|
||||||
console = Console(
|
|
||||||
force_terminal=True, width=pixwidth, height=pixheight, color_system="standard"
|
|
||||||
)
|
|
||||||
return console
|
|
||||||
@ -1,19 +1,18 @@
|
|||||||
"""SSH Server implementation."""
|
"""SSH Server implementation."""
|
||||||
|
|
||||||
from asyncio import _register_task
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import asyncssh
|
import asyncssh
|
||||||
import ipaddress
|
import ipaddress
|
||||||
|
|
||||||
from collections.abc import Awaitable
|
from collections.abc import Awaitable
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, cast, override
|
from typing import Callable, override
|
||||||
|
|
||||||
from pydantic import IPvAnyNetwork
|
from pydantic import IPvAnyNetwork
|
||||||
|
|
||||||
from . import constants
|
from sshecret_sshd import constants
|
||||||
|
from sshecret_sshd.commands import dispatch_command
|
||||||
|
|
||||||
from sshecret.backend import SshecretBackend, Client, Operation, SubSystem
|
from sshecret.backend import SshecretBackend, Client, Operation, SubSystem
|
||||||
from .settings import ServerSettings, ClientRegistrationSettings
|
from .settings import ServerSettings, ClientRegistrationSettings
|
||||||
@ -29,37 +28,6 @@ PeernameV6 = tuple[str, int, int, int]
|
|||||||
Peername = PeernameV4 | PeernameV6
|
Peername = PeernameV4 | PeernameV6
|
||||||
|
|
||||||
|
|
||||||
class CommandError(Exception):
|
|
||||||
"""Error class for errors during command processing."""
|
|
||||||
|
|
||||||
|
|
||||||
async def audit_process(
|
|
||||||
backend: SshecretBackend,
|
|
||||||
process: asyncssh.SSHServerProcess[str],
|
|
||||||
operation: Operation,
|
|
||||||
message: str,
|
|
||||||
secret: str | None = None,
|
|
||||||
**data: str,
|
|
||||||
) -> None:
|
|
||||||
"""Add an audit event from process."""
|
|
||||||
command = get_process_command(process)
|
|
||||||
client = get_info_client(process)
|
|
||||||
username = get_info_username(process)
|
|
||||||
remote_ip = get_info_remote_ip(process) or "UNKNOWN"
|
|
||||||
if username:
|
|
||||||
data["username"] = username
|
|
||||||
|
|
||||||
if command and not secret:
|
|
||||||
cmd, cmd_args = command
|
|
||||||
if cmd:
|
|
||||||
data["command"] = cmd
|
|
||||||
data["args"] = " ".join(cmd_args)
|
|
||||||
|
|
||||||
await backend.audit(SubSystem.SSHD).write_async(
|
|
||||||
operation, message, remote_ip, client, secret=None, secret_name=secret, **data
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def audit_event(
|
async def audit_event(
|
||||||
backend: SshecretBackend,
|
backend: SshecretBackend,
|
||||||
message: str,
|
message: str,
|
||||||
@ -87,250 +55,6 @@ def verify_key_input(public_key: str) -> str | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_process_command(
|
|
||||||
process: asyncssh.SSHServerProcess[str],
|
|
||||||
) -> tuple[str | None, list[str]]:
|
|
||||||
"""Extract the process command."""
|
|
||||||
if not process.command:
|
|
||||||
return (None, [])
|
|
||||||
argv = process.command.split(" ")
|
|
||||||
LOG.debug("Args: %r", argv)
|
|
||||||
return (argv[0], argv[1:])
|
|
||||||
|
|
||||||
|
|
||||||
def get_info_backend(process: asyncssh.SSHServerProcess[str]) -> SshecretBackend | None:
|
|
||||||
"""Get backend from process."""
|
|
||||||
backend = cast("SshecretBackend | None", process.get_extra_info("backend", None))
|
|
||||||
return backend
|
|
||||||
|
|
||||||
|
|
||||||
def get_info_client(process: asyncssh.SSHServerProcess[str]) -> Client | None:
|
|
||||||
"""Get info from process."""
|
|
||||||
client = cast("Client | None", process.get_extra_info("client", None))
|
|
||||||
return client
|
|
||||||
|
|
||||||
|
|
||||||
def get_info_username(process: asyncssh.SSHServerProcess[str]) -> str | None:
|
|
||||||
"""Get username from process."""
|
|
||||||
username = cast("str | None", process.get_extra_info("provided_username", None))
|
|
||||||
return username
|
|
||||||
|
|
||||||
|
|
||||||
def get_info_remote_ip(process: asyncssh.SSHServerProcess[str]) -> str | None:
|
|
||||||
"""Get remote IP."""
|
|
||||||
|
|
||||||
peername = cast("Peername | None", process.get_extra_info("peername", None))
|
|
||||||
remote_ip: str | None = None
|
|
||||||
if peername:
|
|
||||||
remote_ip = peername[0]
|
|
||||||
|
|
||||||
return remote_ip
|
|
||||||
|
|
||||||
|
|
||||||
def get_info_allowed_registration(
|
|
||||||
process: asyncssh.SSHServerProcess[str],
|
|
||||||
) -> list[IPvAnyNetwork] | None:
|
|
||||||
"""Get allowed networks to allow registration from."""
|
|
||||||
|
|
||||||
allowed_registration = cast(
|
|
||||||
list[IPvAnyNetwork] | None,
|
|
||||||
process.get_extra_info("allow_registration_from", None),
|
|
||||||
)
|
|
||||||
return allowed_registration
|
|
||||||
|
|
||||||
|
|
||||||
def get_optional_commands(process: asyncssh.SSHServerProcess[str]) -> dict[str, bool]:
|
|
||||||
"""Get optional command state."""
|
|
||||||
with_registration = cast(
|
|
||||||
bool, process.get_extra_info("registration_enabled", False)
|
|
||||||
)
|
|
||||||
with_ping = cast(bool, process.get_extra_info("ping_enabled", False))
|
|
||||||
return {
|
|
||||||
"registration": with_registration,
|
|
||||||
"ping": with_ping,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def get_stdin_public_key(process: asyncssh.SSHServerProcess[str]) -> str | None:
|
|
||||||
"""Get public key from stdin."""
|
|
||||||
process.stdout.write("Enter public key:\n")
|
|
||||||
public_key: str | None = None
|
|
||||||
try:
|
|
||||||
async for line in process.stdin:
|
|
||||||
public_key = verify_key_input(line.rstrip("\n"))
|
|
||||||
if public_key:
|
|
||||||
break
|
|
||||||
raise CommandError(constants.ERROR_INVALID_KEY_TYPE)
|
|
||||||
except asyncssh.BreakReceived:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
process.stdout.write("OK\n")
|
|
||||||
return public_key
|
|
||||||
|
|
||||||
|
|
||||||
async def register_client(
|
|
||||||
process: asyncssh.SSHServerProcess[str],
|
|
||||||
backend: SshecretBackend,
|
|
||||||
username: str,
|
|
||||||
) -> None:
|
|
||||||
"""Register a new client."""
|
|
||||||
public_key = await get_stdin_public_key(process)
|
|
||||||
if not public_key:
|
|
||||||
raise CommandError(constants.ERROR_NO_PUBLIC_KEY)
|
|
||||||
|
|
||||||
key = asyncssh.import_public_key(public_key)
|
|
||||||
if key.algorithm.decode() != "ssh-rsa":
|
|
||||||
raise CommandError(constants.ERROR_INVALID_KEY_TYPE)
|
|
||||||
await audit_process(backend, process, Operation.CREATE, "Registering new client")
|
|
||||||
LOG.debug("Registering client %s with public key %s", username, public_key)
|
|
||||||
await backend.create_client(username, public_key)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_secret(
|
|
||||||
backend: SshecretBackend,
|
|
||||||
client: Client,
|
|
||||||
secret_name: str,
|
|
||||||
origin: str,
|
|
||||||
) -> str:
|
|
||||||
"""Handle get secret requests from client."""
|
|
||||||
LOG.debug("Recieved command: get_secret %r", secret_name)
|
|
||||||
if not secret_name:
|
|
||||||
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
|
|
||||||
if secret_name not in client.secrets:
|
|
||||||
raise CommandError(constants.ERROR_NO_SECRET_FOUND)
|
|
||||||
|
|
||||||
await audit_event(
|
|
||||||
backend,
|
|
||||||
"Client requested secret",
|
|
||||||
operation=Operation.READ,
|
|
||||||
client=client,
|
|
||||||
origin=origin,
|
|
||||||
secret=secret_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Look up secret
|
|
||||||
try:
|
|
||||||
secret = await backend.get_client_secret(client.name, secret_name)
|
|
||||||
if not secret:
|
|
||||||
raise CommandError(constants.ERROR_NO_SECRET_FOUND)
|
|
||||||
return secret
|
|
||||||
except Exception as exc:
|
|
||||||
LOG.debug(exc, exc_info=True)
|
|
||||||
raise CommandError(constants.ERROR_BACKEND_ERROR) from exc
|
|
||||||
|
|
||||||
|
|
||||||
async def dispatch_no_cmd(process: asyncssh.SSHServerProcess[str]) -> None:
|
|
||||||
"""Dispatch for no command."""
|
|
||||||
raise CommandError(constants.ERROR_NO_COMMAND_RECEIVED)
|
|
||||||
|
|
||||||
|
|
||||||
async def dispatch_cmd_ping(process: asyncssh.SSHServerProcess[str]) -> None:
|
|
||||||
"""Dispatch the ping command."""
|
|
||||||
process.stdout.write("PONG\n")
|
|
||||||
|
|
||||||
|
|
||||||
async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None:
|
|
||||||
"""Dispatch the register command."""
|
|
||||||
backend = get_info_backend(process)
|
|
||||||
if not backend:
|
|
||||||
raise CommandError(constants.ERROR_INFO_BACKEND_GONE)
|
|
||||||
username = get_info_username(process)
|
|
||||||
if not username:
|
|
||||||
raise CommandError(constants.ERROR_INFO_USERNAME_GONE)
|
|
||||||
|
|
||||||
allowed_networks = get_info_allowed_registration(process)
|
|
||||||
if not allowed_networks:
|
|
||||||
process.stdout.write("Unauthorized.\n")
|
|
||||||
await audit_process(
|
|
||||||
backend,
|
|
||||||
process,
|
|
||||||
Operation.DENY,
|
|
||||||
"Received registration command, but no subnets are allowed.",
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
remote_ip = get_info_remote_ip(process)
|
|
||||||
|
|
||||||
if not remote_ip:
|
|
||||||
raise CommandError(constants.ERROR_INFO_REMOTE_IP_GONE)
|
|
||||||
|
|
||||||
client_address = ipaddress.ip_address(remote_ip)
|
|
||||||
|
|
||||||
for network in allowed_networks:
|
|
||||||
if client_address in network:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
await audit_process(
|
|
||||||
backend,
|
|
||||||
process,
|
|
||||||
Operation.DENY,
|
|
||||||
"Received registration command from unauthorized subnet.",
|
|
||||||
)
|
|
||||||
process.stdout.write("Unauthorized.\n")
|
|
||||||
return
|
|
||||||
|
|
||||||
await register_client(process, backend, username)
|
|
||||||
|
|
||||||
process.stdout.write("Client registered\n.")
|
|
||||||
|
|
||||||
|
|
||||||
async def dispatch_cmd_get_secret(process: asyncssh.SSHServerProcess[str]) -> None:
|
|
||||||
"""Dispatch the get_secret command."""
|
|
||||||
backend = get_info_backend(process)
|
|
||||||
if not backend:
|
|
||||||
raise CommandError(constants.ERROR_INFO_BACKEND_GONE)
|
|
||||||
|
|
||||||
client = get_info_client(process)
|
|
||||||
if not client:
|
|
||||||
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
|
|
||||||
_cmd, args = get_process_command(process)
|
|
||||||
if not args:
|
|
||||||
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
|
|
||||||
secret_name = args[0]
|
|
||||||
|
|
||||||
origin = get_info_remote_ip(process) or "Unknown"
|
|
||||||
secret = await get_secret(backend, client, secret_name, origin)
|
|
||||||
process.stdout.write(secret)
|
|
||||||
|
|
||||||
|
|
||||||
async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
|
|
||||||
"""Dispatch command."""
|
|
||||||
command, _args = get_process_command(process)
|
|
||||||
if not command:
|
|
||||||
process.stderr.write(constants.ERROR_NO_COMMAND_RECEIVED)
|
|
||||||
process.exit(1)
|
|
||||||
return
|
|
||||||
cmdmap: dict[str, CommandDispatch] = {
|
|
||||||
"get_secret": dispatch_cmd_get_secret,
|
|
||||||
}
|
|
||||||
extra_commands = get_optional_commands(process)
|
|
||||||
if "registration" in extra_commands:
|
|
||||||
cmdmap["register"] = dispatch_cmd_register
|
|
||||||
if "ping" in extra_commands:
|
|
||||||
cmdmap["ping"] = dispatch_cmd_ping
|
|
||||||
|
|
||||||
if command not in cmdmap:
|
|
||||||
process.stderr.write(constants.ERROR_UNKNOWN_COMMAND)
|
|
||||||
process.exit(1)
|
|
||||||
return
|
|
||||||
exit_code = 0
|
|
||||||
try:
|
|
||||||
dispatcher = cmdmap[command]
|
|
||||||
await dispatcher(process)
|
|
||||||
except CommandError as e:
|
|
||||||
process.stderr.write(str(e))
|
|
||||||
exit_code = 1
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
LOG.debug(e, exc_info=True)
|
|
||||||
process.stderr.write("Unexpected exception:\n")
|
|
||||||
process.stderr.write(str(e))
|
|
||||||
exit_code = 1
|
|
||||||
|
|
||||||
LOG.debug("Command processing finished.")
|
|
||||||
process.exit(exit_code)
|
|
||||||
|
|
||||||
|
|
||||||
class AsshyncServer(asyncssh.SSHServer):
|
class AsshyncServer(asyncssh.SSHServer):
|
||||||
"""Asynchronous SSH server implementation."""
|
"""Asynchronous SSH server implementation."""
|
||||||
|
|
||||||
|
|||||||
@ -69,8 +69,8 @@ class TestSshd:
|
|||||||
assert found is True
|
assert found is True
|
||||||
session.stdin.write(test_client.public_key + "\n")
|
session.stdin.write(test_client.public_key + "\n")
|
||||||
|
|
||||||
result = await session.stdout.readline()
|
result = await session.stdout.read()
|
||||||
assert "OK" in result
|
assert "Key is valid. Registering client." in result
|
||||||
await session.wait()
|
await session.wait()
|
||||||
return test_client
|
return test_client
|
||||||
|
|
||||||
|
|||||||
@ -82,10 +82,13 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
|
|||||||
"Error, must have a client called template for this to work."
|
"Error, must have a client called template for this to work."
|
||||||
)
|
)
|
||||||
clients_data[name] = clients_data["template"]
|
clients_data[name] = clients_data["template"]
|
||||||
|
template_secrets: dict[str, str] = {}
|
||||||
for secret_key, secret in secrets_data.items():
|
for secret_key, secret in secrets_data.items():
|
||||||
s_client, secret_name = secret_key
|
s_client, secret_name = secret_key
|
||||||
if s_client != "template":
|
if s_client != "template":
|
||||||
continue
|
continue
|
||||||
|
template_secrets[secret_name] = secret
|
||||||
|
for secret_name, secret in template_secrets.items():
|
||||||
secrets_data[(name, secret_name)] = secret
|
secrets_data[(name, secret_name)] = secret
|
||||||
|
|
||||||
async def write_audit(*args, **kwargs):
|
async def write_audit(*args, **kwargs):
|
||||||
|
|||||||
@ -83,8 +83,9 @@ class TestRegistrationErrors(BaseSshTests):
|
|||||||
output = await process.stdout.readline()
|
output = await process.stdout.readline()
|
||||||
assert "Enter public key" in output
|
assert "Enter public key" in output
|
||||||
stdout, stderr = await process.communicate(public_key)
|
stdout, stderr = await process.communicate(public_key)
|
||||||
|
assert isinstance(stderr, str)
|
||||||
print(f"{stdout=!r}, {stderr=!r}")
|
print(f"{stdout=!r}, {stderr=!r}")
|
||||||
assert stderr == "Error: Invalid key type: Only RSA keys are supported."
|
assert stderr.rstrip() == "Error: Invalid key type: Only RSA keys are supported."
|
||||||
result = await process.wait()
|
result = await process.wait()
|
||||||
assert result.exit_status == 1
|
assert result.exit_status == 1
|
||||||
|
|
||||||
@ -102,8 +103,9 @@ class TestRegistrationErrors(BaseSshTests):
|
|||||||
output = await process.stdout.readline()
|
output = await process.stdout.readline()
|
||||||
assert "Enter public key" in output
|
assert "Enter public key" in output
|
||||||
stdout, stderr = await process.communicate(public_key)
|
stdout, stderr = await process.communicate(public_key)
|
||||||
|
assert isinstance(stderr, str)
|
||||||
print(f"{stdout=!r}, {stderr=!r}")
|
print(f"{stdout=!r}, {stderr=!r}")
|
||||||
assert stderr == "Error: Invalid key type: Only RSA keys are supported."
|
assert stderr.rstrip() == "Error: Invalid key type: Only RSA keys are supported."
|
||||||
result = await process.wait()
|
result = await process.wait()
|
||||||
assert result.exit_status == 1
|
assert result.exit_status == 1
|
||||||
|
|
||||||
@ -122,7 +124,8 @@ class TestCommandErrors(BaseSshTests):
|
|||||||
|
|
||||||
assert result.exit_status == 1
|
assert result.exit_status == 1
|
||||||
stderr = result.stderr or ""
|
stderr = result.stderr or ""
|
||||||
assert stderr == "Error: Unsupported command."
|
assert isinstance(stderr, str)
|
||||||
|
assert stderr.rstrip() == "Error: Unsupported command."
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_no_command(
|
async def test_no_command(
|
||||||
@ -136,7 +139,8 @@ class TestCommandErrors(BaseSshTests):
|
|||||||
async with conn.create_process() as process:
|
async with conn.create_process() as process:
|
||||||
stdout, stderr = await process.communicate()
|
stdout, stderr = await process.communicate()
|
||||||
print(f"{stdout=!r}, {stderr=!r}")
|
print(f"{stdout=!r}, {stderr=!r}")
|
||||||
assert stderr == "Error: No command was received from the client."
|
assert isinstance(stderr, str)
|
||||||
|
assert stderr.rstrip() == "Error: No command was received from the client."
|
||||||
result = await process.wait()
|
result = await process.wait()
|
||||||
assert result.exit_status == 1
|
assert result.exit_status == 1
|
||||||
|
|
||||||
|
|||||||
@ -1,10 +1,12 @@
|
|||||||
"""Test get secret."""
|
"""Test get secret."""
|
||||||
|
|
||||||
|
import allure
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from .types import ClientRegistry, CommandRunner
|
from .types import ClientRegistry, CommandRunner
|
||||||
|
|
||||||
|
|
||||||
|
@allure.title("Test get_secret command")
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_secret(
|
async def test_get_secret(
|
||||||
ssh_command_runner: CommandRunner, client_registry: ClientRegistry
|
ssh_command_runner: CommandRunner, client_registry: ClientRegistry
|
||||||
@ -19,7 +21,7 @@ async def test_get_secret(
|
|||||||
assert isinstance(result.stdout, str)
|
assert isinstance(result.stdout, str)
|
||||||
assert result.stdout.rstrip() == "mocked-secret-mysecret"
|
assert result.stdout.rstrip() == "mocked-secret-mysecret"
|
||||||
|
|
||||||
|
@allure.title("Test with invalid secret name")
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_invalid_secret_name(
|
async def test_invalid_secret_name(
|
||||||
ssh_command_runner: CommandRunner, client_registry: ClientRegistry
|
ssh_command_runner: CommandRunner, client_registry: ClientRegistry
|
||||||
@ -30,4 +32,25 @@ async def test_invalid_secret_name(
|
|||||||
|
|
||||||
result = await ssh_command_runner("test-client", "get_secret mysecret")
|
result = await ssh_command_runner("test-client", "get_secret mysecret")
|
||||||
assert result.exit_status == 1
|
assert result.exit_status == 1
|
||||||
assert result.stderr == "Error: No secret available with the given name."
|
stderr = result.stderr
|
||||||
|
assert isinstance(stderr, str)
|
||||||
|
assert stderr.rstrip() == "Error: No secret available with the given name."
|
||||||
|
|
||||||
|
@allure.title("Test get_secret command help")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_secret_cmd_help(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None:
|
||||||
|
"""Test running get_secret --help"""
|
||||||
|
await client_registry["add_client"]("test-client", ["mysecret"])
|
||||||
|
|
||||||
|
result = await ssh_command_runner("test-client", "get_secret --help")
|
||||||
|
|
||||||
|
assert result.exit_status == 0
|
||||||
|
|
||||||
|
print(result.stdout)
|
||||||
|
assert isinstance(result.stdout, str)
|
||||||
|
|
||||||
|
lines = result.stdout.splitlines()
|
||||||
|
|
||||||
|
assert lines[0] == "get_secret SECRET"
|
||||||
|
|
||||||
|
assert len(lines) > 4
|
||||||
|
|||||||
@ -1,8 +1,11 @@
|
|||||||
|
"""Test for the ping command."""
|
||||||
|
import allure
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from .types import ClientRegistry, CommandRunner
|
from .types import ClientRegistry, CommandRunner
|
||||||
|
|
||||||
|
|
||||||
|
@allure.title("Test running the ping command")
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_ping_command(
|
async def test_ping_command(
|
||||||
ssh_command_runner: CommandRunner, client_registry: ClientRegistry
|
ssh_command_runner: CommandRunner, client_registry: ClientRegistry
|
||||||
@ -16,3 +19,21 @@ async def test_ping_command(
|
|||||||
assert result.stdout is not None
|
assert result.stdout is not None
|
||||||
assert isinstance(result.stdout, str)
|
assert isinstance(result.stdout, str)
|
||||||
assert result.stdout.rstrip() == "PONG"
|
assert result.stdout.rstrip() == "PONG"
|
||||||
|
|
||||||
|
@allure.title("Test ping help")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ping_cmd_help(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None:
|
||||||
|
"""Test running ping --help."""
|
||||||
|
await client_registry["add_client"]("test-client", ["mysecret"])
|
||||||
|
result = await ssh_command_runner("test-client", "ping --help")
|
||||||
|
|
||||||
|
assert result.exit_status == 0
|
||||||
|
|
||||||
|
print(result.stdout)
|
||||||
|
assert isinstance(result.stdout, str)
|
||||||
|
|
||||||
|
lines = result.stdout.splitlines()
|
||||||
|
|
||||||
|
assert lines[0] == "ping"
|
||||||
|
|
||||||
|
assert len(lines) > 4
|
||||||
|
|||||||
@ -1,10 +1,12 @@
|
|||||||
"""Test registration."""
|
"""Test registration."""
|
||||||
|
|
||||||
|
import allure
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from .types import ClientRegistry, CommandRunner, ProcessRunner
|
from .types import ClientRegistry, CommandRunner, ProcessRunner
|
||||||
|
|
||||||
|
|
||||||
|
@allure.title("Test client registration")
|
||||||
@pytest.mark.enable_registration(True)
|
@pytest.mark.enable_registration(True)
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_register_client(
|
async def test_register_client(
|
||||||
@ -29,8 +31,9 @@ async def test_register_client(
|
|||||||
|
|
||||||
assert found is True
|
assert found is True
|
||||||
session.stdin.write(public_key)
|
session.stdin.write(public_key)
|
||||||
result = await session.stdout.readline()
|
data = await session.stdout.read()
|
||||||
assert "OK" in result
|
assert isinstance(data, str)
|
||||||
|
assert "Key is valid. Registering client" in data
|
||||||
|
|
||||||
# Test that we can connect
|
# Test that we can connect
|
||||||
|
|
||||||
@ -39,3 +42,28 @@ async def test_register_client(
|
|||||||
assert result.stdout is not None
|
assert result.stdout is not None
|
||||||
assert isinstance(result.stdout, str)
|
assert isinstance(result.stdout, str)
|
||||||
assert result.stdout.rstrip() == "mocked-secret-testsecret"
|
assert result.stdout.rstrip() == "mocked-secret-testsecret"
|
||||||
|
|
||||||
|
@allure.title("Test register command help")
|
||||||
|
@pytest.mark.enable_registration(True)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_cmd_help(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None:
|
||||||
|
"""Test running register --help"""
|
||||||
|
await client_registry["add_client"]("test-client", ["mysecret"])
|
||||||
|
|
||||||
|
result = await ssh_command_runner("test-client", "register --help")
|
||||||
|
|
||||||
|
assert result.exit_status == 0
|
||||||
|
|
||||||
|
print(result.stdout)
|
||||||
|
assert isinstance(result.stdout, str)
|
||||||
|
|
||||||
|
lines = result.stdout.splitlines()
|
||||||
|
|
||||||
|
assert lines[0] == "register"
|
||||||
|
|
||||||
|
assert len(lines) > 4
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Test running register with an existing client.
|
||||||
|
|||||||
Reference in New Issue
Block a user