Refactor command handling

This now supports usage/help texts
This commit is contained in:
2025-05-18 17:56:53 +02:00
parent 26ef9b45d4
commit dcf0b4274c
15 changed files with 337 additions and 431 deletions

View File

@ -1 +1,5 @@
"""Commands module."""
from .dispatcher import dispatch_command
__all__ = ["dispatch_command"]

View File

@ -2,11 +2,14 @@
import abc
import json
import textwrap
from collections import defaultdict
from dataclasses import dataclass, field
import ipaddress
import logging
from typing import Any, cast
from typing import Any, cast, override
from rich.console import Console
import asyncssh
from pydantic import IPvAnyNetwork, IPvAnyAddress
@ -15,8 +18,6 @@ from sshecret.backend.api import SshecretBackend
from sshecret.backend.models import Client, Operation, SubSystem
from sshecret_sshd import exceptions
from .utils import get_console
PeernameV4 = tuple[str, int]
PeernameV6 = tuple[str, int, int, int]
Peername = PeernameV4 | PeernameV6
@ -32,11 +33,45 @@ class CmdArgs:
arguments: list[str] = field(default_factory=list)
@dataclass
class CommandFlag:
"""Command flag."""
name: str
description: str
supports_short: bool = False
enabled: bool = False
@override
def __str__(self) -> str:
"""Format an output for help texts."""
if self.supports_short:
return f"[-{self.name[0]} --{self.name}]"
return f"--{self.name}"
def get_console(process: asyncssh.SSHServerProcess[str]) -> Console:
"""Initiate console from process."""
width, _height, pixwidth, pixheight = process.term_size
LOG.debug("Terminal is %sx%s", pixwidth, pixheight)
if width > 0:
console = Console(
force_terminal=True,
width=pixwidth,
height=pixheight,
color_system="standard",
)
return console
return Console(markup=False, color_system=None)
class CommandDispatcher(abc.ABC):
"""Command dispatcher."""
name: str
flags: dict[str, str] | None = None
mandatory_argument: str | None = None
def __init__(
self,
@ -47,31 +82,57 @@ class CommandDispatcher(abc.ABC):
def print(
self,
data: str,
*data: str,
stderr: bool = False,
formatted: bool = False,
newline: bool = True,
) -> None:
"""Write to stdout."""
if stderr:
self.process.stderr.write(data + "\n")
for line in data:
self.process.stderr.write(line + "\n")
return
if formatted:
data = self.get_rich(data)
if newline:
data += "\n"
self.process.stdout.write(data)
for line in data:
self.process.stdout.write(line)
if newline:
self.process.stdout.write("\n")
def get_rich(self, data: str) -> str:
"""Print with rich formatting."""
console = get_console(self.process)
def rich_print_line(
self, data: str, tags: list[str] | None = None, rule: bool = False
) -> None:
"""Write formatted text to the process.
IF the client terminal does not support this, no formatting will be added.
Otherwise, the tags will be added to the string.
"""
if not tags:
tags = []
if not self.formatting_supported:
return self.print(data, newline=True)
for tag in tags:
data = f"[{tag}]{data}[/{tag}]"
console = self.get_console()
with console.capture() as capture:
console.print(data)
return capture.get()
if rule:
console.rule(data)
else:
console.print(data)
self.process.stdout.write(capture.get())
def get_console(self) -> Console:
"""Initiate console from process."""
return get_console(self.process)
@property
def formatting_supported(self) -> bool:
"""Check if the terminal supports formatting."""
term_width = self.process.term_size[0]
return term_width > 0
def print_json(self, obj: dict[str, Any]) -> None:
"""Print a json object."""
console = get_console(self.process)
console = self.get_console()
data = json.dumps(obj)
with console.capture() as capture:
console.print_json(data)
@ -97,6 +158,7 @@ class CommandDispatcher(abc.ABC):
username,
client,
secret,
data,
)
await self.backend.audit(SubSystem.SSHD).write_async(
@ -148,6 +210,7 @@ class CommandDispatcher(abc.ABC):
def arguments(self) -> list[str]:
"""Get non-flag arguments."""
parsed = self.parse_command()
LOG.debug("Parsed command: %r", parsed)
if not self.flags:
return parsed.arguments
return [
@ -211,37 +274,69 @@ class CommandDispatcher(abc.ABC):
return allowed_registration
@classmethod
async def print_help(cls, process: asyncssh.SSHServerProcess[str]) -> None:
def _format_command_name(cls) -> str:
"""Format command name with arguments."""
name = cls.name
if cls.mandatory_argument:
name = f"{name} {cls.mandatory_argument}"
if not cls.flags:
return name
flags = cls.command_flags()
args: list[str] = [str(flag) for flag in flags.values()]
flagstr = " ".join(args)
return f"{name} {flagstr}"
@classmethod
def usage(cls, indent: int = 0) -> str:
"""Print command usage."""
indent_prefix = " " * indent
inner_indent = 2 + indent
inner_prefix = " " * inner_indent
usage_str: list[str] = []
default_usage = "No help available."
if not cls.__doc__:
usage_str.append(default_usage)
else:
usage_str = [
textwrap.indent(line, prefix=inner_prefix)
for line in cls.__doc__.splitlines()
]
flags = cls.command_flags()
if flags:
usage_str.append("")
usage_str.append(textwrap.indent("Arguments:", prefix=(" " * 4)))
for info in flags.values():
usage_str.append(
textwrap.indent(f"{info!s}: {info.description}", prefix=(" " * 6))
)
usage = "\n".join(usage_str)
if indent:
usage = textwrap.indent(usage, prefix=indent_prefix)
return usage
async def print_help(self) -> None:
"""Print help."""
descr = cls.__doc__
usage = type(self).usage()
command_name = type(self)._format_command_name()
self.rich_print_line(command_name)
self.print(usage)
help_text: list[str] = [
f"[bold]{cls.name}[/bold]",
descr or "",
]
flags: dict[str, str] = {}
@classmethod
def command_flags(cls) -> dict[str, CommandFlag]:
"""Parse the command flags."""
shortflags: defaultdict[str, list[str]] = defaultdict(list)
if cls.flags:
flags = cls.flags
arg_text: list[str] = []
for flag in flags.keys():
if not cls.flags:
return {}
for flag in cls.flags.keys():
shortflags[flag[0]].append(flag)
for flag, flag_descr in flags.items():
flagstr = f"--{flag}"
if len(shortflags[flag[0]]) == 1:
flagstr = f"[ -{flag[0]} --{flag} ]"
arg_text.extend([f"{flagstr}:", f"{flag_descr}", ""])
command_flags: dict[str, CommandFlag] = {}
console = get_console(process)
with console.capture() as capture:
for line in help_text:
console.print(line)
if arg_text:
console.rule("Argument flags")
for line in arg_text:
console.print(line)
process.stdout.write(capture.get())
for flag, description in cls.flags.items():
supports_shorts = True
if len(shortflags[flag[0]]) > 1:
supports_shorts = False
command_flags[flag] = CommandFlag(flag, description, supports_shorts)
return command_flags

View File

@ -5,7 +5,7 @@ Register arguments here.
import logging
import textwrap
from typing import final, override
from typing import cast, final, override
import asyncssh
@ -47,44 +47,105 @@ class HelpCommand(CommandDispatcher):
name = "help"
@override
def __init__(
self,
process: asyncssh.SSHServerProcess[str],
disabled_commands: list[str] | None = None,
) -> None:
"""Init help command."""
super().__init__(process)
self.disabled_commands: list[str] = disabled_commands or []
@override
async def exec(self) -> None:
"""Execute command."""
usage_text = "[bold]AVAILABLE COMMANDS[/bold]"
usage_text += (
"[i]Some commands may be disabled or restricted by the administrator[/i]\n"
output_lines: list[tuple[str, str] | str] = []
output_lines.append(("Available commands:", "bold"))
output_lines.append(
("Some commands may be disabled or restricted by the administrator", "i")
)
output_lines.append("")
for command in COMMANDS:
usage_text += f" [bold]{command.name}:[/bold]"
usage_text += textwrap.indent(self.get_command_doc(command), " ")
if command.name in self.disabled_commands:
continue
command_usage = command.usage(indent=2)
output_lines.append((f" {command._format_command_name()}", "bold"))
output_lines.extend(command_usage.splitlines())
output_lines.append("")
width, _height, _pixwidth, _pixheight = self.process.term_size
wrapped = textwrap.wrap(usage_text, width=width)
wrapped_lines = "\n".join(wrapped)
self.print(wrapped_lines, formatted=True)
for line in output_lines:
tags: list[str] = []
if isinstance(line, tuple):
text, tag = line
tags.append(tag)
else:
text = line
self.rich_print_line(text, tags)
def get_command_doc(self, command: type[CommandDispatcher]) -> str:
"""Format usage string for command."""
default_usage = f"No help available."
if not command.__doc__:
return default_usage
# Skip the first two lines:
usage_str = command.__doc__[:2:]
return usage_str.join("\n")
def get_disabled_commands(process: asyncssh.SSHServerProcess[str]) -> list[str]:
"""Get optional command state."""
with_registration = cast(
bool, process.get_extra_info("registration_enabled", False)
)
with_ping = cast(bool, process.get_extra_info("ping_enabled", False))
optional_commands = {
"register": with_registration,
"ping": with_ping,
}
disabled = [key for key, value in optional_commands.items() if not value]
return disabled
async def do_dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
"""Dispatch command."""
disabled_commands = get_disabled_commands(process)
command_args = process.command
if not command_args:
raise exceptions.NoCommandReceivedError()
show_command_help = False
if "--help" in command_args:
show_command_help = True
command = command_args.split(" ")[0]
command_map: dict[str, type[CommandDispatcher]] = {
cmd_disp.name: cmd_disp
for cmd_disp in COMMANDS
if cmd_disp.name not in disabled_commands
}
command_map["help"] = HelpCommand
LOG.debug("disabled_commands: %r, command_map: %r", disabled_commands, command_map)
LOG.debug("Looking for command %s", command)
if command not in command_map:
raise exceptions.UnknownCommandError()
dispatcher = command_map[command]
LOG.debug("Received command: %s. Dispatching to %r", command, dispatcher)
if command != "help" and show_command_help:
return await dispatcher(process).print_help()
if command == "help":
return await HelpCommand(process, disabled_commands).exec()
return await dispatcher(process).exec()
async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
"""Dispatch command."""
command = process.command
if not command:
command = "help"
command_map: dict[str, type[CommandDispatcher]] = {
cmd_disp.name: cmd_disp for cmd_disp in COMMANDS
}
command_map["help"] = HelpCommand
status_code = 0
try:
await do_dispatch_command(process)
except exceptions.BaseSshecretSshError as e:
LOG.error("Command Error: %s", e, exc_info=True)
process.stderr.write(f"{e}\n")
status_code = 1
except Exception as e:
LOG.error("Unexpected error: %e", e, exc_info=True)
process.stderr.write(f"{constants.ERROR_GENERIC_ERROR}: {e}")
status_code = 1
dispatcher = command_map.get(command, HelpCommand)
LOG.debug("Received command: %s", command)
return await dispatcher(process).exec()
finally:
process.exit(status_code)

View File

@ -12,9 +12,15 @@ LOG = logging.getLogger(__name__)
@final
class GetSecret(CommandDispatcher):
"""Get Secret."""
"""Retrieve an encrypted secret.
Returns the value of the secret provided as a mandatory argument.
The secret will be encrypted using the stored RSA public key, and returned
as a base64 encoded string.
"""
name = "get_secret"
mandatory_argument = "SECRET"
@override
async def exec(self) -> None:
@ -22,6 +28,7 @@ class GetSecret(CommandDispatcher):
if len(self.arguments) != 1:
raise exceptions.UnknownClientOrSecretError()
secret_name = self.arguments[0]
LOG.debug("get_secret called: Argument: %r", secret_name)
if secret_name not in self.client.secrets:
await self.audit(
Operation.DENY,

View File

@ -1,48 +0,0 @@
"""Simple help command."""
from typing import final, override
from .base import CommandDispatcher
HELP_TEXT = """
[bold]Sshecret SSH Server[/bold]
[bold]SYNOPSIS[/bold]
An interface to request encrypted client secrets and perform
simple commands.
Secrets will be returned encrypted with the client public key,
encoded as base64.
[bold]AVAILABLE COMMANDS[/bold]
[i]Some of these commands may be disabled by the administrator.[/i]
[bold]ping[/bold]: Perform a ping towards the server.
[bold]get_secret [i]secret_name[/i][/bold]: Get a named secret
[bold]register[/bold]: Register a new client. Will prompt for public key.
[i]If enabled and permitted[/i]
[bold]ls [ --json -j ][/bold]: List available secrets.
[bold]help[/bold]: Prints this message
"""
@final
class HelpCommand(CommandDispatcher):
"""Help.
Returns usage instructions.
"""
name = "help"
@override
async def exec(self) -> None:
"""Execute command."""
self.print(HELP_TEXT, formatted=True)

View File

@ -32,10 +32,7 @@ class ListSecrets(CommandDispatcher):
async def list_secrets(self) -> None:
"""List secrets."""
self.print(
f"[bold]Available secrets for client {self.client.name}[/bold]",
formatted=True,
)
self.rich_print_line(f"Available secrets for client {self.client.name}", ["bold"], rule=True)
await self.audit(Operation.READ, "Listed available secret names")
for secret_name in self.client.name:
for secret_name in self.client.secrets:
self.print(f" - {secret_name}")

View File

@ -24,7 +24,6 @@ class Register(CommandDispatcher):
"""Register a new client.
After connection, you must input a ssh public key.
This must be an RSA key. No other types of keys are supported.
"""

View File

@ -1,12 +0,0 @@
"""Various utilities."""
import asyncssh
from rich.console import Console
def get_console(process: asyncssh.SSHServerProcess[str]) -> Console:
"""Initiate console from process."""
_width, _height, pixwidth, pixheight = process.term_size
console = Console(
force_terminal=True, width=pixwidth, height=pixheight, color_system="standard"
)
return console