diff --git a/packages/sshecret-sshd/src/sshecret_sshd/commands/__init__.py b/packages/sshecret-sshd/src/sshecret_sshd/commands/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/packages/sshecret-sshd/src/sshecret_sshd/commands/__init__.py @@ -0,0 +1 @@ + diff --git a/packages/sshecret-sshd/src/sshecret_sshd/commands/base.py b/packages/sshecret-sshd/src/sshecret_sshd/commands/base.py new file mode 100644 index 0000000..6e64a0f --- /dev/null +++ b/packages/sshecret-sshd/src/sshecret_sshd/commands/base.py @@ -0,0 +1,247 @@ +"""Base command class.""" + +import abc +import json +from collections import defaultdict +from dataclasses import dataclass, field +import ipaddress +import logging +from typing import Any, cast + +import asyncssh +from pydantic import IPvAnyNetwork, IPvAnyAddress + +from sshecret.backend.api import SshecretBackend +from sshecret.backend.models import Client, Operation, SubSystem +from sshecret_sshd import exceptions + +from .utils import get_console + +PeernameV4 = tuple[str, int] +PeernameV6 = tuple[str, int, int, int] +Peername = PeernameV4 | PeernameV6 + +LOG = logging.getLogger(__name__) + + +@dataclass +class CmdArgs: + """Command and arguments.""" + + command: str + arguments: list[str] = field(default_factory=list) + + +class CommandDispatcher(abc.ABC): + """Command dispatcher.""" + + name: str + flags: dict[str, str] | None = None + + def __init__( + self, + process: asyncssh.SSHServerProcess[str], + ) -> None: + """Create command dispatcher class.""" + self.process: asyncssh.SSHServerProcess[str] = process + + def print( + self, + data: str, + stderr: bool = False, + formatted: bool = False, + newline: bool = True, + ) -> None: + """Write to stdout.""" + if stderr: + self.process.stderr.write(data + "\n") + return + if formatted: + data = self.get_rich(data) + if newline: + data += "\n" + self.process.stdout.write(data) + + def get_rich(self, data: str) -> str: + """Print with rich formatting.""" + console = get_console(self.process) + with console.capture() as capture: + console.print(data) + return capture.get() + + def print_json(self, obj: dict[str, Any]) -> None: + """Print a json object.""" + console = get_console(self.process) + data = json.dumps(obj) + with console.capture() as capture: + console.print_json(data) + + self.process.stdout.write(capture.get() + "\n") + + async def audit( + self, operation: Operation, message: str, secret: str | None = None, **data: str + ) -> None: + """Log audit message.""" + client = self.get_client() + try: + origin = str(self.remote_ip) + except Exception: + origin = "UNKNOWN" + + username = self.get_username() + + LOG.warning( + "Audit: %s (origin=%s, username=%s, client=%r, secret=%r, data=%r)", + message, + origin, + username, + client, + secret, + ) + + await self.backend.audit(SubSystem.SSHD).write_async( + operation=operation, + message=message, + origin=origin, + secret=None, + secret_name=secret, + client=client, + username=username or "No username", + **data, + ) + + @abc.abstractmethod + async def exec(self) -> None: + """Execute main command.""" + + def parse_command(self) -> CmdArgs: + """Get command.""" + if not self.process.command: + raise exceptions.NoCommandReceivedError() + argv = self.process.command.split(" ") + return CmdArgs(argv[0], argv[1:]) + + @property + def options(self) -> dict[str, bool]: + """Get arguments.""" + if not self.flags: + return {} + + parsed = self.parse_command() + args: dict[str, bool] = {} + shortflags: defaultdict[str, list[str]] = defaultdict(list) + for flag in self.flags.keys(): + shortflags[flag[0]].append(flag) + for flag in self.flags.keys(): + allowshort = len(shortflags[flag[0]]) == 1 + if f"--{flag}" in parsed.arguments: + args[flag] = True + continue + if allowshort and f"-{flag[0]}" in parsed.arguments: + args[flag] = True + continue + args[flag] = False + + return args + + @property + def arguments(self) -> list[str]: + """Get non-flag arguments.""" + parsed = self.parse_command() + if not self.flags: + return parsed.arguments + return [ + argument for argument in parsed.arguments if not argument.startswith("-") + ] + + @property + def backend(self) -> SshecretBackend: + """Get backend from process info.""" + backend = cast( + SshecretBackend | None, self.process.get_extra_info("backend", None) + ) + if not backend: + raise exceptions.NoBackendError() + return backend + + def get_client(self) -> Client | None: + """Get client.""" + return cast(Client | None, self.process.get_extra_info("client", None)) + + @property + def client(self) -> Client: + """Get client from process info.""" + client = self.get_client() + if not client: + raise exceptions.NoClientError() + return client + + def get_username(self) -> str | None: + """Get username.""" + return cast(str | None, self.process.get_extra_info("provided_username", None)) + + @property + def username(self) -> str: + """Get username from process info.""" + username = self.get_username() + if not username: + raise exceptions.NoUsernameError() + return username + + @property + def remote_ip(self) -> IPvAnyAddress: + """Get remote IP.""" + peername = cast( + "Peername | None", self.process.get_extra_info("peername", None) + ) + remote_ip: str | None = None + if peername: + remote_ip = peername[0] + return ipaddress.ip_address(remote_ip) + + raise exceptions.NoRemoteIpError() + + @property + def allowed_registration_networks(self) -> list[IPvAnyNetwork]: + """Get networks that allow registration.""" + allowed_registration = cast( + list[IPvAnyNetwork], + self.process.get_extra_info("allow_registration_from", []), + ) + return allowed_registration + + @classmethod + async def print_help(cls, process: asyncssh.SSHServerProcess[str]) -> None: + """Print help.""" + descr = cls.__doc__ + + help_text: list[str] = [ + f"[bold]{cls.name}[/bold]", + descr or "", + ] + flags: dict[str, str] = {} + shortflags: defaultdict[str, list[str]] = defaultdict(list) + if cls.flags: + flags = cls.flags + + arg_text: list[str] = [] + for flag in flags.keys(): + shortflags[flag[0]].append(flag) + for flag, flag_descr in flags.items(): + flagstr = f"--{flag}" + if len(shortflags[flag[0]]) == 1: + flagstr = f"[ -{flag[0]} --{flag} ]" + + arg_text.extend([f"{flagstr}:", f"{flag_descr}", ""]) + + console = get_console(process) + with console.capture() as capture: + for line in help_text: + console.print(line) + if arg_text: + console.rule("Argument flags") + + for line in arg_text: + console.print(line) + + process.stdout.write(capture.get()) diff --git a/packages/sshecret-sshd/src/sshecret_sshd/commands/dispatcher.py b/packages/sshecret-sshd/src/sshecret_sshd/commands/dispatcher.py new file mode 100644 index 0000000..96755fb --- /dev/null +++ b/packages/sshecret-sshd/src/sshecret_sshd/commands/dispatcher.py @@ -0,0 +1,90 @@ +"""Command dispatcher. + +Register arguments here. +""" + +import logging +import textwrap +from typing import final, override + +import asyncssh + +from sshecret_sshd import exceptions, constants + +from .base import CommandDispatcher +from .get_secret import GetSecret +from .register import Register +from .list_secrets import ListSecrets +from .ping import PingCommand + + +SYNOPSIS = """[bold]Sshecret SSH Server[/bold] + +[bold]SYNOPSIS[/bold] +An interface to request encrypted client secrets and perform +simple commands. + +Secrets will be returned encrypted with the client public key, +encoded as base64. +""" + +COMMANDS = [ + GetSecret, + Register, + ListSecrets, + PingCommand, +] + +LOG = logging.getLogger(__name__) + + +@final +class HelpCommand(CommandDispatcher): + """Help. + + Returns usage instructions. + """ + + name = "help" + + @override + async def exec(self) -> None: + """Execute command.""" + usage_text = "[bold]AVAILABLE COMMANDS[/bold]" + usage_text += ( + "[i]Some commands may be disabled or restricted by the administrator[/i]\n" + ) + for command in COMMANDS: + usage_text += f" [bold]{command.name}:[/bold]" + usage_text += textwrap.indent(self.get_command_doc(command), " ") + + width, _height, _pixwidth, _pixheight = self.process.term_size + wrapped = textwrap.wrap(usage_text, width=width) + wrapped_lines = "\n".join(wrapped) + self.print(wrapped_lines, formatted=True) + + def get_command_doc(self, command: type[CommandDispatcher]) -> str: + """Format usage string for command.""" + default_usage = f"No help available." + if not command.__doc__: + return default_usage + # Skip the first two lines: + usage_str = command.__doc__[:2:] + return usage_str.join("\n") + + +async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None: + """Dispatch command.""" + command = process.command + if not command: + command = "help" + command_map: dict[str, type[CommandDispatcher]] = { + cmd_disp.name: cmd_disp for cmd_disp in COMMANDS + } + command_map["help"] = HelpCommand + + dispatcher = command_map.get(command, HelpCommand) + + LOG.debug("Received command: %s", command) + + return await dispatcher(process).exec() diff --git a/packages/sshecret-sshd/src/sshecret_sshd/commands/get_secret.py b/packages/sshecret-sshd/src/sshecret_sshd/commands/get_secret.py new file mode 100644 index 0000000..e6937e9 --- /dev/null +++ b/packages/sshecret-sshd/src/sshecret_sshd/commands/get_secret.py @@ -0,0 +1,56 @@ +"""Get secret.""" + +import logging +from typing import final, override + +from sshecret.backend.models import Operation +from sshecret_sshd import exceptions +from .base import CommandDispatcher + +LOG = logging.getLogger(__name__) + + +@final +class GetSecret(CommandDispatcher): + """Get Secret.""" + + name = "get_secret" + + @override + async def exec(self) -> None: + """Execute command.""" + if len(self.arguments) != 1: + raise exceptions.UnknownClientOrSecretError() + secret_name = self.arguments[0] + if secret_name not in self.client.secrets: + await self.audit( + Operation.DENY, + message="Client requested invalid secret", + secret=secret_name, + ) + raise exceptions.SecretNotFoundError() + try: + secret = await self.backend.get_client_secret(self.client.name, secret_name) + except Exception as exc: + LOG.error( + "Got exception while getting client %s secret %s: %s", + self.client.name, + secret_name, + exc, + exc_info=True, + ) + raise exceptions.BackendError(backend_error=str(exc)) from exc + + if not secret: + await self.audit( + Operation.DENY, + message="Client requested invalid secret", + secret=secret_name, + ) + + raise exceptions.SecretNotFoundError() + + await self.audit( + Operation.READ, message="Client requested secret", secret=secret_name + ) + self.print(secret, newline=False) diff --git a/packages/sshecret-sshd/src/sshecret_sshd/commands/help.py b/packages/sshecret-sshd/src/sshecret_sshd/commands/help.py new file mode 100644 index 0000000..898b9c3 --- /dev/null +++ b/packages/sshecret-sshd/src/sshecret_sshd/commands/help.py @@ -0,0 +1,48 @@ +"""Simple help command.""" + + +from typing import final, override + +from .base import CommandDispatcher + +HELP_TEXT = """ +[bold]Sshecret SSH Server[/bold] + + +[bold]SYNOPSIS[/bold] +An interface to request encrypted client secrets and perform +simple commands. + +Secrets will be returned encrypted with the client public key, +encoded as base64. + + +[bold]AVAILABLE COMMANDS[/bold] +[i]Some of these commands may be disabled by the administrator.[/i] + + [bold]ping[/bold]: Perform a ping towards the server. + + [bold]get_secret [i]secret_name[/i][/bold]: Get a named secret + + [bold]register[/bold]: Register a new client. Will prompt for public key. + [i]If enabled and permitted[/i] + + [bold]ls [ --json -j ][/bold]: List available secrets. + + [bold]help[/bold]: Prints this message +""" + + +@final +class HelpCommand(CommandDispatcher): + """Help. + + Returns usage instructions. + """ + + name = "help" + + @override + async def exec(self) -> None: + """Execute command.""" + self.print(HELP_TEXT, formatted=True) diff --git a/packages/sshecret-sshd/src/sshecret_sshd/commands/list_secrets.py b/packages/sshecret-sshd/src/sshecret_sshd/commands/list_secrets.py new file mode 100644 index 0000000..8b04a10 --- /dev/null +++ b/packages/sshecret-sshd/src/sshecret_sshd/commands/list_secrets.py @@ -0,0 +1,41 @@ +"""List secrets command.""" + +from typing import final, override + +from sshecret.backend.models import Operation +from .base import CommandDispatcher + + +@final +class ListSecrets(CommandDispatcher): + """List secrets. + + This command returns a list of secrets available for the connecting client + host. + """ + + name = "ls" + flags = {"json": "Output in JSON format"} + + @override + async def exec(self) -> None: + """Execute command.""" + json_mode = self.options.get("json") + if json_mode: + return self.list_as_json() + return await self.list_secrets() + + def list_as_json(self) -> None: + """List as json.""" + json_obj = {"secrets": self.client.secrets} + self.print_json(json_obj) + + async def list_secrets(self) -> None: + """List secrets.""" + self.print( + f"[bold]Available secrets for client {self.client.name}[/bold]", + formatted=True, + ) + await self.audit(Operation.READ, "Listed available secret names") + for secret_name in self.client.name: + self.print(f" - {secret_name}") diff --git a/packages/sshecret-sshd/src/sshecret_sshd/commands/ping.py b/packages/sshecret-sshd/src/sshecret_sshd/commands/ping.py new file mode 100644 index 0000000..405bbe0 --- /dev/null +++ b/packages/sshecret-sshd/src/sshecret_sshd/commands/ping.py @@ -0,0 +1,22 @@ +"""Ping as a healthcheck command.""" + +from typing import final, override + +from .base import CommandDispatcher + + +@final +class PingCommand(CommandDispatcher): + """Ping. + + This command responds with the string 'PONG'. + + It may be used to ensure that the system works. + """ + + name = "ping" + + @override + async def exec(self) -> None: + """Execute command.""" + self.print("PONG") diff --git a/packages/sshecret-sshd/src/sshecret_sshd/commands/register.py b/packages/sshecret-sshd/src/sshecret_sshd/commands/register.py new file mode 100644 index 0000000..efa0a0b --- /dev/null +++ b/packages/sshecret-sshd/src/sshecret_sshd/commands/register.py @@ -0,0 +1,74 @@ +"""Registration command.""" + +from typing import final, override +import asyncssh + +from sshecret_sshd import constants, exceptions + +from sshecret.backend.models import Operation +from .base import CommandDispatcher + + +def verify_key_input(public_key: str) -> str | None: + """Verify key input.""" + try: + key = asyncssh.import_public_key(public_key) + if key.algorithm.decode() == "ssh-rsa": + return public_key + except Exception: + return None + + +@final +class Register(CommandDispatcher): + """Register a new client. + + After connection, you must input a ssh public key. + + This must be an RSA key. No other types of keys are supported. + """ + + name = "register" + + def verify_registration_host(self) -> bool: + """Check if registration command is allowed.""" + if not self.allowed_registration_networks: + return False + for network in self.allowed_registration_networks: + if self.remote_ip in network: + return True + return False + + @override + async def exec(self) -> None: + """Register client.""" + if not self.verify_registration_host(): + self.print(f"Registration not permitted from {self.remote_ip}", stderr=True) + await self.audit( + Operation.DENY, + constants.ERROR_REGISTRATION_NOT_ALLOWED, + ) + return + + self.print("Enter public key:") + public_key: str | None = None + try: + async for line in self.process.stdin: + public_key = verify_key_input(line.rstrip("\n")) + if public_key: + break + raise exceptions.InvalidPublicKeyType() + except asyncssh.BreakReceived: + pass + else: + self.print("Key received. Validating.") + + if not public_key: + raise exceptions.InvalidPublicKeyType() + + key = asyncssh.import_public_key(public_key) + if key.algorithm.decode() != "ssh-rsa": + raise exceptions.InvalidPublicKeyType() + self.print("Key is valid. Registering client.") + await self.audit(Operation.CREATE, "Registering new client.") + await self.backend.create_client(self.username, public_key) diff --git a/packages/sshecret-sshd/src/sshecret_sshd/commands/utils.py b/packages/sshecret-sshd/src/sshecret_sshd/commands/utils.py new file mode 100644 index 0000000..096358d --- /dev/null +++ b/packages/sshecret-sshd/src/sshecret_sshd/commands/utils.py @@ -0,0 +1,12 @@ +"""Various utilities.""" + +import asyncssh +from rich.console import Console + +def get_console(process: asyncssh.SSHServerProcess[str]) -> Console: + """Initiate console from process.""" + _width, _height, pixwidth, pixheight = process.term_size + console = Console( + force_terminal=True, width=pixwidth, height=pixheight, color_system="standard" + ) + return console