Create command dispatching classes
This commit is contained in:
@ -0,0 +1 @@
|
||||
|
||||
247
packages/sshecret-sshd/src/sshecret_sshd/commands/base.py
Normal file
247
packages/sshecret-sshd/src/sshecret_sshd/commands/base.py
Normal file
@ -0,0 +1,247 @@
|
||||
"""Base command class."""
|
||||
|
||||
import abc
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
import ipaddress
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
import asyncssh
|
||||
from pydantic import IPvAnyNetwork, IPvAnyAddress
|
||||
|
||||
from sshecret.backend.api import SshecretBackend
|
||||
from sshecret.backend.models import Client, Operation, SubSystem
|
||||
from sshecret_sshd import exceptions
|
||||
|
||||
from .utils import get_console
|
||||
|
||||
PeernameV4 = tuple[str, int]
|
||||
PeernameV6 = tuple[str, int, int, int]
|
||||
Peername = PeernameV4 | PeernameV6
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CmdArgs:
|
||||
"""Command and arguments."""
|
||||
|
||||
command: str
|
||||
arguments: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
class CommandDispatcher(abc.ABC):
|
||||
"""Command dispatcher."""
|
||||
|
||||
name: str
|
||||
flags: dict[str, str] | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
process: asyncssh.SSHServerProcess[str],
|
||||
) -> None:
|
||||
"""Create command dispatcher class."""
|
||||
self.process: asyncssh.SSHServerProcess[str] = process
|
||||
|
||||
def print(
|
||||
self,
|
||||
data: str,
|
||||
stderr: bool = False,
|
||||
formatted: bool = False,
|
||||
newline: bool = True,
|
||||
) -> None:
|
||||
"""Write to stdout."""
|
||||
if stderr:
|
||||
self.process.stderr.write(data + "\n")
|
||||
return
|
||||
if formatted:
|
||||
data = self.get_rich(data)
|
||||
if newline:
|
||||
data += "\n"
|
||||
self.process.stdout.write(data)
|
||||
|
||||
def get_rich(self, data: str) -> str:
|
||||
"""Print with rich formatting."""
|
||||
console = get_console(self.process)
|
||||
with console.capture() as capture:
|
||||
console.print(data)
|
||||
return capture.get()
|
||||
|
||||
def print_json(self, obj: dict[str, Any]) -> None:
|
||||
"""Print a json object."""
|
||||
console = get_console(self.process)
|
||||
data = json.dumps(obj)
|
||||
with console.capture() as capture:
|
||||
console.print_json(data)
|
||||
|
||||
self.process.stdout.write(capture.get() + "\n")
|
||||
|
||||
async def audit(
|
||||
self, operation: Operation, message: str, secret: str | None = None, **data: str
|
||||
) -> None:
|
||||
"""Log audit message."""
|
||||
client = self.get_client()
|
||||
try:
|
||||
origin = str(self.remote_ip)
|
||||
except Exception:
|
||||
origin = "UNKNOWN"
|
||||
|
||||
username = self.get_username()
|
||||
|
||||
LOG.warning(
|
||||
"Audit: %s (origin=%s, username=%s, client=%r, secret=%r, data=%r)",
|
||||
message,
|
||||
origin,
|
||||
username,
|
||||
client,
|
||||
secret,
|
||||
)
|
||||
|
||||
await self.backend.audit(SubSystem.SSHD).write_async(
|
||||
operation=operation,
|
||||
message=message,
|
||||
origin=origin,
|
||||
secret=None,
|
||||
secret_name=secret,
|
||||
client=client,
|
||||
username=username or "No username",
|
||||
**data,
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
async def exec(self) -> None:
|
||||
"""Execute main command."""
|
||||
|
||||
def parse_command(self) -> CmdArgs:
|
||||
"""Get command."""
|
||||
if not self.process.command:
|
||||
raise exceptions.NoCommandReceivedError()
|
||||
argv = self.process.command.split(" ")
|
||||
return CmdArgs(argv[0], argv[1:])
|
||||
|
||||
@property
|
||||
def options(self) -> dict[str, bool]:
|
||||
"""Get arguments."""
|
||||
if not self.flags:
|
||||
return {}
|
||||
|
||||
parsed = self.parse_command()
|
||||
args: dict[str, bool] = {}
|
||||
shortflags: defaultdict[str, list[str]] = defaultdict(list)
|
||||
for flag in self.flags.keys():
|
||||
shortflags[flag[0]].append(flag)
|
||||
for flag in self.flags.keys():
|
||||
allowshort = len(shortflags[flag[0]]) == 1
|
||||
if f"--{flag}" in parsed.arguments:
|
||||
args[flag] = True
|
||||
continue
|
||||
if allowshort and f"-{flag[0]}" in parsed.arguments:
|
||||
args[flag] = True
|
||||
continue
|
||||
args[flag] = False
|
||||
|
||||
return args
|
||||
|
||||
@property
|
||||
def arguments(self) -> list[str]:
|
||||
"""Get non-flag arguments."""
|
||||
parsed = self.parse_command()
|
||||
if not self.flags:
|
||||
return parsed.arguments
|
||||
return [
|
||||
argument for argument in parsed.arguments if not argument.startswith("-")
|
||||
]
|
||||
|
||||
@property
|
||||
def backend(self) -> SshecretBackend:
|
||||
"""Get backend from process info."""
|
||||
backend = cast(
|
||||
SshecretBackend | None, self.process.get_extra_info("backend", None)
|
||||
)
|
||||
if not backend:
|
||||
raise exceptions.NoBackendError()
|
||||
return backend
|
||||
|
||||
def get_client(self) -> Client | None:
|
||||
"""Get client."""
|
||||
return cast(Client | None, self.process.get_extra_info("client", None))
|
||||
|
||||
@property
|
||||
def client(self) -> Client:
|
||||
"""Get client from process info."""
|
||||
client = self.get_client()
|
||||
if not client:
|
||||
raise exceptions.NoClientError()
|
||||
return client
|
||||
|
||||
def get_username(self) -> str | None:
|
||||
"""Get username."""
|
||||
return cast(str | None, self.process.get_extra_info("provided_username", None))
|
||||
|
||||
@property
|
||||
def username(self) -> str:
|
||||
"""Get username from process info."""
|
||||
username = self.get_username()
|
||||
if not username:
|
||||
raise exceptions.NoUsernameError()
|
||||
return username
|
||||
|
||||
@property
|
||||
def remote_ip(self) -> IPvAnyAddress:
|
||||
"""Get remote IP."""
|
||||
peername = cast(
|
||||
"Peername | None", self.process.get_extra_info("peername", None)
|
||||
)
|
||||
remote_ip: str | None = None
|
||||
if peername:
|
||||
remote_ip = peername[0]
|
||||
return ipaddress.ip_address(remote_ip)
|
||||
|
||||
raise exceptions.NoRemoteIpError()
|
||||
|
||||
@property
|
||||
def allowed_registration_networks(self) -> list[IPvAnyNetwork]:
|
||||
"""Get networks that allow registration."""
|
||||
allowed_registration = cast(
|
||||
list[IPvAnyNetwork],
|
||||
self.process.get_extra_info("allow_registration_from", []),
|
||||
)
|
||||
return allowed_registration
|
||||
|
||||
@classmethod
|
||||
async def print_help(cls, process: asyncssh.SSHServerProcess[str]) -> None:
|
||||
"""Print help."""
|
||||
descr = cls.__doc__
|
||||
|
||||
help_text: list[str] = [
|
||||
f"[bold]{cls.name}[/bold]",
|
||||
descr or "",
|
||||
]
|
||||
flags: dict[str, str] = {}
|
||||
shortflags: defaultdict[str, list[str]] = defaultdict(list)
|
||||
if cls.flags:
|
||||
flags = cls.flags
|
||||
|
||||
arg_text: list[str] = []
|
||||
for flag in flags.keys():
|
||||
shortflags[flag[0]].append(flag)
|
||||
for flag, flag_descr in flags.items():
|
||||
flagstr = f"--{flag}"
|
||||
if len(shortflags[flag[0]]) == 1:
|
||||
flagstr = f"[ -{flag[0]} --{flag} ]"
|
||||
|
||||
arg_text.extend([f"{flagstr}:", f"{flag_descr}", ""])
|
||||
|
||||
console = get_console(process)
|
||||
with console.capture() as capture:
|
||||
for line in help_text:
|
||||
console.print(line)
|
||||
if arg_text:
|
||||
console.rule("Argument flags")
|
||||
|
||||
for line in arg_text:
|
||||
console.print(line)
|
||||
|
||||
process.stdout.write(capture.get())
|
||||
@ -0,0 +1,90 @@
|
||||
"""Command dispatcher.
|
||||
|
||||
Register arguments here.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import textwrap
|
||||
from typing import final, override
|
||||
|
||||
import asyncssh
|
||||
|
||||
from sshecret_sshd import exceptions, constants
|
||||
|
||||
from .base import CommandDispatcher
|
||||
from .get_secret import GetSecret
|
||||
from .register import Register
|
||||
from .list_secrets import ListSecrets
|
||||
from .ping import PingCommand
|
||||
|
||||
|
||||
SYNOPSIS = """[bold]Sshecret SSH Server[/bold]
|
||||
|
||||
[bold]SYNOPSIS[/bold]
|
||||
An interface to request encrypted client secrets and perform
|
||||
simple commands.
|
||||
|
||||
Secrets will be returned encrypted with the client public key,
|
||||
encoded as base64.
|
||||
"""
|
||||
|
||||
COMMANDS = [
|
||||
GetSecret,
|
||||
Register,
|
||||
ListSecrets,
|
||||
PingCommand,
|
||||
]
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class HelpCommand(CommandDispatcher):
|
||||
"""Help.
|
||||
|
||||
Returns usage instructions.
|
||||
"""
|
||||
|
||||
name = "help"
|
||||
|
||||
@override
|
||||
async def exec(self) -> None:
|
||||
"""Execute command."""
|
||||
usage_text = "[bold]AVAILABLE COMMANDS[/bold]"
|
||||
usage_text += (
|
||||
"[i]Some commands may be disabled or restricted by the administrator[/i]\n"
|
||||
)
|
||||
for command in COMMANDS:
|
||||
usage_text += f" [bold]{command.name}:[/bold]"
|
||||
usage_text += textwrap.indent(self.get_command_doc(command), " ")
|
||||
|
||||
width, _height, _pixwidth, _pixheight = self.process.term_size
|
||||
wrapped = textwrap.wrap(usage_text, width=width)
|
||||
wrapped_lines = "\n".join(wrapped)
|
||||
self.print(wrapped_lines, formatted=True)
|
||||
|
||||
def get_command_doc(self, command: type[CommandDispatcher]) -> str:
|
||||
"""Format usage string for command."""
|
||||
default_usage = f"No help available."
|
||||
if not command.__doc__:
|
||||
return default_usage
|
||||
# Skip the first two lines:
|
||||
usage_str = command.__doc__[:2:]
|
||||
return usage_str.join("\n")
|
||||
|
||||
|
||||
async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||
"""Dispatch command."""
|
||||
command = process.command
|
||||
if not command:
|
||||
command = "help"
|
||||
command_map: dict[str, type[CommandDispatcher]] = {
|
||||
cmd_disp.name: cmd_disp for cmd_disp in COMMANDS
|
||||
}
|
||||
command_map["help"] = HelpCommand
|
||||
|
||||
dispatcher = command_map.get(command, HelpCommand)
|
||||
|
||||
LOG.debug("Received command: %s", command)
|
||||
|
||||
return await dispatcher(process).exec()
|
||||
@ -0,0 +1,56 @@
|
||||
"""Get secret."""
|
||||
|
||||
import logging
|
||||
from typing import final, override
|
||||
|
||||
from sshecret.backend.models import Operation
|
||||
from sshecret_sshd import exceptions
|
||||
from .base import CommandDispatcher
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class GetSecret(CommandDispatcher):
|
||||
"""Get Secret."""
|
||||
|
||||
name = "get_secret"
|
||||
|
||||
@override
|
||||
async def exec(self) -> None:
|
||||
"""Execute command."""
|
||||
if len(self.arguments) != 1:
|
||||
raise exceptions.UnknownClientOrSecretError()
|
||||
secret_name = self.arguments[0]
|
||||
if secret_name not in self.client.secrets:
|
||||
await self.audit(
|
||||
Operation.DENY,
|
||||
message="Client requested invalid secret",
|
||||
secret=secret_name,
|
||||
)
|
||||
raise exceptions.SecretNotFoundError()
|
||||
try:
|
||||
secret = await self.backend.get_client_secret(self.client.name, secret_name)
|
||||
except Exception as exc:
|
||||
LOG.error(
|
||||
"Got exception while getting client %s secret %s: %s",
|
||||
self.client.name,
|
||||
secret_name,
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
raise exceptions.BackendError(backend_error=str(exc)) from exc
|
||||
|
||||
if not secret:
|
||||
await self.audit(
|
||||
Operation.DENY,
|
||||
message="Client requested invalid secret",
|
||||
secret=secret_name,
|
||||
)
|
||||
|
||||
raise exceptions.SecretNotFoundError()
|
||||
|
||||
await self.audit(
|
||||
Operation.READ, message="Client requested secret", secret=secret_name
|
||||
)
|
||||
self.print(secret, newline=False)
|
||||
48
packages/sshecret-sshd/src/sshecret_sshd/commands/help.py
Normal file
48
packages/sshecret-sshd/src/sshecret_sshd/commands/help.py
Normal file
@ -0,0 +1,48 @@
|
||||
"""Simple help command."""
|
||||
|
||||
|
||||
from typing import final, override
|
||||
|
||||
from .base import CommandDispatcher
|
||||
|
||||
HELP_TEXT = """
|
||||
[bold]Sshecret SSH Server[/bold]
|
||||
|
||||
|
||||
[bold]SYNOPSIS[/bold]
|
||||
An interface to request encrypted client secrets and perform
|
||||
simple commands.
|
||||
|
||||
Secrets will be returned encrypted with the client public key,
|
||||
encoded as base64.
|
||||
|
||||
|
||||
[bold]AVAILABLE COMMANDS[/bold]
|
||||
[i]Some of these commands may be disabled by the administrator.[/i]
|
||||
|
||||
[bold]ping[/bold]: Perform a ping towards the server.
|
||||
|
||||
[bold]get_secret [i]secret_name[/i][/bold]: Get a named secret
|
||||
|
||||
[bold]register[/bold]: Register a new client. Will prompt for public key.
|
||||
[i]If enabled and permitted[/i]
|
||||
|
||||
[bold]ls [ --json -j ][/bold]: List available secrets.
|
||||
|
||||
[bold]help[/bold]: Prints this message
|
||||
"""
|
||||
|
||||
|
||||
@final
|
||||
class HelpCommand(CommandDispatcher):
|
||||
"""Help.
|
||||
|
||||
Returns usage instructions.
|
||||
"""
|
||||
|
||||
name = "help"
|
||||
|
||||
@override
|
||||
async def exec(self) -> None:
|
||||
"""Execute command."""
|
||||
self.print(HELP_TEXT, formatted=True)
|
||||
@ -0,0 +1,41 @@
|
||||
"""List secrets command."""
|
||||
|
||||
from typing import final, override
|
||||
|
||||
from sshecret.backend.models import Operation
|
||||
from .base import CommandDispatcher
|
||||
|
||||
|
||||
@final
|
||||
class ListSecrets(CommandDispatcher):
|
||||
"""List secrets.
|
||||
|
||||
This command returns a list of secrets available for the connecting client
|
||||
host.
|
||||
"""
|
||||
|
||||
name = "ls"
|
||||
flags = {"json": "Output in JSON format"}
|
||||
|
||||
@override
|
||||
async def exec(self) -> None:
|
||||
"""Execute command."""
|
||||
json_mode = self.options.get("json")
|
||||
if json_mode:
|
||||
return self.list_as_json()
|
||||
return await self.list_secrets()
|
||||
|
||||
def list_as_json(self) -> None:
|
||||
"""List as json."""
|
||||
json_obj = {"secrets": self.client.secrets}
|
||||
self.print_json(json_obj)
|
||||
|
||||
async def list_secrets(self) -> None:
|
||||
"""List secrets."""
|
||||
self.print(
|
||||
f"[bold]Available secrets for client {self.client.name}[/bold]",
|
||||
formatted=True,
|
||||
)
|
||||
await self.audit(Operation.READ, "Listed available secret names")
|
||||
for secret_name in self.client.name:
|
||||
self.print(f" - {secret_name}")
|
||||
22
packages/sshecret-sshd/src/sshecret_sshd/commands/ping.py
Normal file
22
packages/sshecret-sshd/src/sshecret_sshd/commands/ping.py
Normal file
@ -0,0 +1,22 @@
|
||||
"""Ping as a healthcheck command."""
|
||||
|
||||
from typing import final, override
|
||||
|
||||
from .base import CommandDispatcher
|
||||
|
||||
|
||||
@final
|
||||
class PingCommand(CommandDispatcher):
|
||||
"""Ping.
|
||||
|
||||
This command responds with the string 'PONG'.
|
||||
|
||||
It may be used to ensure that the system works.
|
||||
"""
|
||||
|
||||
name = "ping"
|
||||
|
||||
@override
|
||||
async def exec(self) -> None:
|
||||
"""Execute command."""
|
||||
self.print("PONG")
|
||||
@ -0,0 +1,74 @@
|
||||
"""Registration command."""
|
||||
|
||||
from typing import final, override
|
||||
import asyncssh
|
||||
|
||||
from sshecret_sshd import constants, exceptions
|
||||
|
||||
from sshecret.backend.models import Operation
|
||||
from .base import CommandDispatcher
|
||||
|
||||
|
||||
def verify_key_input(public_key: str) -> str | None:
|
||||
"""Verify key input."""
|
||||
try:
|
||||
key = asyncssh.import_public_key(public_key)
|
||||
if key.algorithm.decode() == "ssh-rsa":
|
||||
return public_key
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@final
|
||||
class Register(CommandDispatcher):
|
||||
"""Register a new client.
|
||||
|
||||
After connection, you must input a ssh public key.
|
||||
|
||||
This must be an RSA key. No other types of keys are supported.
|
||||
"""
|
||||
|
||||
name = "register"
|
||||
|
||||
def verify_registration_host(self) -> bool:
|
||||
"""Check if registration command is allowed."""
|
||||
if not self.allowed_registration_networks:
|
||||
return False
|
||||
for network in self.allowed_registration_networks:
|
||||
if self.remote_ip in network:
|
||||
return True
|
||||
return False
|
||||
|
||||
@override
|
||||
async def exec(self) -> None:
|
||||
"""Register client."""
|
||||
if not self.verify_registration_host():
|
||||
self.print(f"Registration not permitted from {self.remote_ip}", stderr=True)
|
||||
await self.audit(
|
||||
Operation.DENY,
|
||||
constants.ERROR_REGISTRATION_NOT_ALLOWED,
|
||||
)
|
||||
return
|
||||
|
||||
self.print("Enter public key:")
|
||||
public_key: str | None = None
|
||||
try:
|
||||
async for line in self.process.stdin:
|
||||
public_key = verify_key_input(line.rstrip("\n"))
|
||||
if public_key:
|
||||
break
|
||||
raise exceptions.InvalidPublicKeyType()
|
||||
except asyncssh.BreakReceived:
|
||||
pass
|
||||
else:
|
||||
self.print("Key received. Validating.")
|
||||
|
||||
if not public_key:
|
||||
raise exceptions.InvalidPublicKeyType()
|
||||
|
||||
key = asyncssh.import_public_key(public_key)
|
||||
if key.algorithm.decode() != "ssh-rsa":
|
||||
raise exceptions.InvalidPublicKeyType()
|
||||
self.print("Key is valid. Registering client.")
|
||||
await self.audit(Operation.CREATE, "Registering new client.")
|
||||
await self.backend.create_client(self.username, public_key)
|
||||
12
packages/sshecret-sshd/src/sshecret_sshd/commands/utils.py
Normal file
12
packages/sshecret-sshd/src/sshecret_sshd/commands/utils.py
Normal file
@ -0,0 +1,12 @@
|
||||
"""Various utilities."""
|
||||
|
||||
import asyncssh
|
||||
from rich.console import Console
|
||||
|
||||
def get_console(process: asyncssh.SSHServerProcess[str]) -> Console:
|
||||
"""Initiate console from process."""
|
||||
_width, _height, pixwidth, pixheight = process.term_size
|
||||
console = Console(
|
||||
force_terminal=True, width=pixwidth, height=pixheight, color_system="standard"
|
||||
)
|
||||
return console
|
||||
Reference in New Issue
Block a user