Create command dispatching classes

This commit is contained in:
2025-05-18 09:40:09 +02:00
parent 64536b40f6
commit 26ef9b45d4
9 changed files with 591 additions and 0 deletions

View File

@ -0,0 +1,247 @@
"""Base command class."""
import abc
import json
from collections import defaultdict
from dataclasses import dataclass, field
import ipaddress
import logging
from typing import Any, cast
import asyncssh
from pydantic import IPvAnyNetwork, IPvAnyAddress
from sshecret.backend.api import SshecretBackend
from sshecret.backend.models import Client, Operation, SubSystem
from sshecret_sshd import exceptions
from .utils import get_console
PeernameV4 = tuple[str, int]
PeernameV6 = tuple[str, int, int, int]
Peername = PeernameV4 | PeernameV6
LOG = logging.getLogger(__name__)
@dataclass
class CmdArgs:
"""Command and arguments."""
command: str
arguments: list[str] = field(default_factory=list)
class CommandDispatcher(abc.ABC):
"""Command dispatcher."""
name: str
flags: dict[str, str] | None = None
def __init__(
self,
process: asyncssh.SSHServerProcess[str],
) -> None:
"""Create command dispatcher class."""
self.process: asyncssh.SSHServerProcess[str] = process
def print(
self,
data: str,
stderr: bool = False,
formatted: bool = False,
newline: bool = True,
) -> None:
"""Write to stdout."""
if stderr:
self.process.stderr.write(data + "\n")
return
if formatted:
data = self.get_rich(data)
if newline:
data += "\n"
self.process.stdout.write(data)
def get_rich(self, data: str) -> str:
"""Print with rich formatting."""
console = get_console(self.process)
with console.capture() as capture:
console.print(data)
return capture.get()
def print_json(self, obj: dict[str, Any]) -> None:
"""Print a json object."""
console = get_console(self.process)
data = json.dumps(obj)
with console.capture() as capture:
console.print_json(data)
self.process.stdout.write(capture.get() + "\n")
async def audit(
self, operation: Operation, message: str, secret: str | None = None, **data: str
) -> None:
"""Log audit message."""
client = self.get_client()
try:
origin = str(self.remote_ip)
except Exception:
origin = "UNKNOWN"
username = self.get_username()
LOG.warning(
"Audit: %s (origin=%s, username=%s, client=%r, secret=%r, data=%r)",
message,
origin,
username,
client,
secret,
)
await self.backend.audit(SubSystem.SSHD).write_async(
operation=operation,
message=message,
origin=origin,
secret=None,
secret_name=secret,
client=client,
username=username or "No username",
**data,
)
@abc.abstractmethod
async def exec(self) -> None:
"""Execute main command."""
def parse_command(self) -> CmdArgs:
"""Get command."""
if not self.process.command:
raise exceptions.NoCommandReceivedError()
argv = self.process.command.split(" ")
return CmdArgs(argv[0], argv[1:])
@property
def options(self) -> dict[str, bool]:
"""Get arguments."""
if not self.flags:
return {}
parsed = self.parse_command()
args: dict[str, bool] = {}
shortflags: defaultdict[str, list[str]] = defaultdict(list)
for flag in self.flags.keys():
shortflags[flag[0]].append(flag)
for flag in self.flags.keys():
allowshort = len(shortflags[flag[0]]) == 1
if f"--{flag}" in parsed.arguments:
args[flag] = True
continue
if allowshort and f"-{flag[0]}" in parsed.arguments:
args[flag] = True
continue
args[flag] = False
return args
@property
def arguments(self) -> list[str]:
"""Get non-flag arguments."""
parsed = self.parse_command()
if not self.flags:
return parsed.arguments
return [
argument for argument in parsed.arguments if not argument.startswith("-")
]
@property
def backend(self) -> SshecretBackend:
"""Get backend from process info."""
backend = cast(
SshecretBackend | None, self.process.get_extra_info("backend", None)
)
if not backend:
raise exceptions.NoBackendError()
return backend
def get_client(self) -> Client | None:
"""Get client."""
return cast(Client | None, self.process.get_extra_info("client", None))
@property
def client(self) -> Client:
"""Get client from process info."""
client = self.get_client()
if not client:
raise exceptions.NoClientError()
return client
def get_username(self) -> str | None:
"""Get username."""
return cast(str | None, self.process.get_extra_info("provided_username", None))
@property
def username(self) -> str:
"""Get username from process info."""
username = self.get_username()
if not username:
raise exceptions.NoUsernameError()
return username
@property
def remote_ip(self) -> IPvAnyAddress:
"""Get remote IP."""
peername = cast(
"Peername | None", self.process.get_extra_info("peername", None)
)
remote_ip: str | None = None
if peername:
remote_ip = peername[0]
return ipaddress.ip_address(remote_ip)
raise exceptions.NoRemoteIpError()
@property
def allowed_registration_networks(self) -> list[IPvAnyNetwork]:
"""Get networks that allow registration."""
allowed_registration = cast(
list[IPvAnyNetwork],
self.process.get_extra_info("allow_registration_from", []),
)
return allowed_registration
@classmethod
async def print_help(cls, process: asyncssh.SSHServerProcess[str]) -> None:
"""Print help."""
descr = cls.__doc__
help_text: list[str] = [
f"[bold]{cls.name}[/bold]",
descr or "",
]
flags: dict[str, str] = {}
shortflags: defaultdict[str, list[str]] = defaultdict(list)
if cls.flags:
flags = cls.flags
arg_text: list[str] = []
for flag in flags.keys():
shortflags[flag[0]].append(flag)
for flag, flag_descr in flags.items():
flagstr = f"--{flag}"
if len(shortflags[flag[0]]) == 1:
flagstr = f"[ -{flag[0]} --{flag} ]"
arg_text.extend([f"{flagstr}:", f"{flag_descr}", ""])
console = get_console(process)
with console.capture() as capture:
for line in help_text:
console.print(line)
if arg_text:
console.rule("Argument flags")
for line in arg_text:
console.print(line)
process.stdout.write(capture.get())

View File

@ -0,0 +1,90 @@
"""Command dispatcher.
Register arguments here.
"""
import logging
import textwrap
from typing import final, override
import asyncssh
from sshecret_sshd import exceptions, constants
from .base import CommandDispatcher
from .get_secret import GetSecret
from .register import Register
from .list_secrets import ListSecrets
from .ping import PingCommand
SYNOPSIS = """[bold]Sshecret SSH Server[/bold]
[bold]SYNOPSIS[/bold]
An interface to request encrypted client secrets and perform
simple commands.
Secrets will be returned encrypted with the client public key,
encoded as base64.
"""
COMMANDS = [
GetSecret,
Register,
ListSecrets,
PingCommand,
]
LOG = logging.getLogger(__name__)
@final
class HelpCommand(CommandDispatcher):
"""Help.
Returns usage instructions.
"""
name = "help"
@override
async def exec(self) -> None:
"""Execute command."""
usage_text = "[bold]AVAILABLE COMMANDS[/bold]"
usage_text += (
"[i]Some commands may be disabled or restricted by the administrator[/i]\n"
)
for command in COMMANDS:
usage_text += f" [bold]{command.name}:[/bold]"
usage_text += textwrap.indent(self.get_command_doc(command), " ")
width, _height, _pixwidth, _pixheight = self.process.term_size
wrapped = textwrap.wrap(usage_text, width=width)
wrapped_lines = "\n".join(wrapped)
self.print(wrapped_lines, formatted=True)
def get_command_doc(self, command: type[CommandDispatcher]) -> str:
"""Format usage string for command."""
default_usage = f"No help available."
if not command.__doc__:
return default_usage
# Skip the first two lines:
usage_str = command.__doc__[:2:]
return usage_str.join("\n")
async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
"""Dispatch command."""
command = process.command
if not command:
command = "help"
command_map: dict[str, type[CommandDispatcher]] = {
cmd_disp.name: cmd_disp for cmd_disp in COMMANDS
}
command_map["help"] = HelpCommand
dispatcher = command_map.get(command, HelpCommand)
LOG.debug("Received command: %s", command)
return await dispatcher(process).exec()

View File

@ -0,0 +1,56 @@
"""Get secret."""
import logging
from typing import final, override
from sshecret.backend.models import Operation
from sshecret_sshd import exceptions
from .base import CommandDispatcher
LOG = logging.getLogger(__name__)
@final
class GetSecret(CommandDispatcher):
"""Get Secret."""
name = "get_secret"
@override
async def exec(self) -> None:
"""Execute command."""
if len(self.arguments) != 1:
raise exceptions.UnknownClientOrSecretError()
secret_name = self.arguments[0]
if secret_name not in self.client.secrets:
await self.audit(
Operation.DENY,
message="Client requested invalid secret",
secret=secret_name,
)
raise exceptions.SecretNotFoundError()
try:
secret = await self.backend.get_client_secret(self.client.name, secret_name)
except Exception as exc:
LOG.error(
"Got exception while getting client %s secret %s: %s",
self.client.name,
secret_name,
exc,
exc_info=True,
)
raise exceptions.BackendError(backend_error=str(exc)) from exc
if not secret:
await self.audit(
Operation.DENY,
message="Client requested invalid secret",
secret=secret_name,
)
raise exceptions.SecretNotFoundError()
await self.audit(
Operation.READ, message="Client requested secret", secret=secret_name
)
self.print(secret, newline=False)

View File

@ -0,0 +1,48 @@
"""Simple help command."""
from typing import final, override
from .base import CommandDispatcher
HELP_TEXT = """
[bold]Sshecret SSH Server[/bold]
[bold]SYNOPSIS[/bold]
An interface to request encrypted client secrets and perform
simple commands.
Secrets will be returned encrypted with the client public key,
encoded as base64.
[bold]AVAILABLE COMMANDS[/bold]
[i]Some of these commands may be disabled by the administrator.[/i]
[bold]ping[/bold]: Perform a ping towards the server.
[bold]get_secret [i]secret_name[/i][/bold]: Get a named secret
[bold]register[/bold]: Register a new client. Will prompt for public key.
[i]If enabled and permitted[/i]
[bold]ls [ --json -j ][/bold]: List available secrets.
[bold]help[/bold]: Prints this message
"""
@final
class HelpCommand(CommandDispatcher):
"""Help.
Returns usage instructions.
"""
name = "help"
@override
async def exec(self) -> None:
"""Execute command."""
self.print(HELP_TEXT, formatted=True)

View File

@ -0,0 +1,41 @@
"""List secrets command."""
from typing import final, override
from sshecret.backend.models import Operation
from .base import CommandDispatcher
@final
class ListSecrets(CommandDispatcher):
"""List secrets.
This command returns a list of secrets available for the connecting client
host.
"""
name = "ls"
flags = {"json": "Output in JSON format"}
@override
async def exec(self) -> None:
"""Execute command."""
json_mode = self.options.get("json")
if json_mode:
return self.list_as_json()
return await self.list_secrets()
def list_as_json(self) -> None:
"""List as json."""
json_obj = {"secrets": self.client.secrets}
self.print_json(json_obj)
async def list_secrets(self) -> None:
"""List secrets."""
self.print(
f"[bold]Available secrets for client {self.client.name}[/bold]",
formatted=True,
)
await self.audit(Operation.READ, "Listed available secret names")
for secret_name in self.client.name:
self.print(f" - {secret_name}")

View File

@ -0,0 +1,22 @@
"""Ping as a healthcheck command."""
from typing import final, override
from .base import CommandDispatcher
@final
class PingCommand(CommandDispatcher):
"""Ping.
This command responds with the string 'PONG'.
It may be used to ensure that the system works.
"""
name = "ping"
@override
async def exec(self) -> None:
"""Execute command."""
self.print("PONG")

View File

@ -0,0 +1,74 @@
"""Registration command."""
from typing import final, override
import asyncssh
from sshecret_sshd import constants, exceptions
from sshecret.backend.models import Operation
from .base import CommandDispatcher
def verify_key_input(public_key: str) -> str | None:
"""Verify key input."""
try:
key = asyncssh.import_public_key(public_key)
if key.algorithm.decode() == "ssh-rsa":
return public_key
except Exception:
return None
@final
class Register(CommandDispatcher):
"""Register a new client.
After connection, you must input a ssh public key.
This must be an RSA key. No other types of keys are supported.
"""
name = "register"
def verify_registration_host(self) -> bool:
"""Check if registration command is allowed."""
if not self.allowed_registration_networks:
return False
for network in self.allowed_registration_networks:
if self.remote_ip in network:
return True
return False
@override
async def exec(self) -> None:
"""Register client."""
if not self.verify_registration_host():
self.print(f"Registration not permitted from {self.remote_ip}", stderr=True)
await self.audit(
Operation.DENY,
constants.ERROR_REGISTRATION_NOT_ALLOWED,
)
return
self.print("Enter public key:")
public_key: str | None = None
try:
async for line in self.process.stdin:
public_key = verify_key_input(line.rstrip("\n"))
if public_key:
break
raise exceptions.InvalidPublicKeyType()
except asyncssh.BreakReceived:
pass
else:
self.print("Key received. Validating.")
if not public_key:
raise exceptions.InvalidPublicKeyType()
key = asyncssh.import_public_key(public_key)
if key.algorithm.decode() != "ssh-rsa":
raise exceptions.InvalidPublicKeyType()
self.print("Key is valid. Registering client.")
await self.audit(Operation.CREATE, "Registering new client.")
await self.backend.create_client(self.username, public_key)

View File

@ -0,0 +1,12 @@
"""Various utilities."""
import asyncssh
from rich.console import Console
def get_console(process: asyncssh.SSHServerProcess[str]) -> Console:
"""Initiate console from process."""
_width, _height, pixwidth, pixheight = process.term_size
console = Console(
force_terminal=True, width=pixwidth, height=pixheight, color_system="standard"
)
return console