Refactor command handling

This now supports usage/help texts
This commit is contained in:
2025-05-18 17:56:53 +02:00
parent 26ef9b45d4
commit dcf0b4274c
15 changed files with 337 additions and 431 deletions

View File

@ -1,19 +1,18 @@
"""SSH Server implementation."""
from asyncio import _register_task
import logging
import asyncssh
import ipaddress
from collections.abc import Awaitable
from functools import partial
from pathlib import Path
from typing import Callable, cast, override
from typing import Callable, override
from pydantic import IPvAnyNetwork
from . import constants
from sshecret_sshd import constants
from sshecret_sshd.commands import dispatch_command
from sshecret.backend import SshecretBackend, Client, Operation, SubSystem
from .settings import ServerSettings, ClientRegistrationSettings
@ -29,37 +28,6 @@ PeernameV6 = tuple[str, int, int, int]
Peername = PeernameV4 | PeernameV6
class CommandError(Exception):
"""Error class for errors during command processing."""
async def audit_process(
backend: SshecretBackend,
process: asyncssh.SSHServerProcess[str],
operation: Operation,
message: str,
secret: str | None = None,
**data: str,
) -> None:
"""Add an audit event from process."""
command = get_process_command(process)
client = get_info_client(process)
username = get_info_username(process)
remote_ip = get_info_remote_ip(process) or "UNKNOWN"
if username:
data["username"] = username
if command and not secret:
cmd, cmd_args = command
if cmd:
data["command"] = cmd
data["args"] = " ".join(cmd_args)
await backend.audit(SubSystem.SSHD).write_async(
operation, message, remote_ip, client, secret=None, secret_name=secret, **data
)
async def audit_event(
backend: SshecretBackend,
message: str,
@ -87,250 +55,6 @@ def verify_key_input(public_key: str) -> str | None:
return None
def get_process_command(
process: asyncssh.SSHServerProcess[str],
) -> tuple[str | None, list[str]]:
"""Extract the process command."""
if not process.command:
return (None, [])
argv = process.command.split(" ")
LOG.debug("Args: %r", argv)
return (argv[0], argv[1:])
def get_info_backend(process: asyncssh.SSHServerProcess[str]) -> SshecretBackend | None:
"""Get backend from process."""
backend = cast("SshecretBackend | None", process.get_extra_info("backend", None))
return backend
def get_info_client(process: asyncssh.SSHServerProcess[str]) -> Client | None:
"""Get info from process."""
client = cast("Client | None", process.get_extra_info("client", None))
return client
def get_info_username(process: asyncssh.SSHServerProcess[str]) -> str | None:
"""Get username from process."""
username = cast("str | None", process.get_extra_info("provided_username", None))
return username
def get_info_remote_ip(process: asyncssh.SSHServerProcess[str]) -> str | None:
"""Get remote IP."""
peername = cast("Peername | None", process.get_extra_info("peername", None))
remote_ip: str | None = None
if peername:
remote_ip = peername[0]
return remote_ip
def get_info_allowed_registration(
process: asyncssh.SSHServerProcess[str],
) -> list[IPvAnyNetwork] | None:
"""Get allowed networks to allow registration from."""
allowed_registration = cast(
list[IPvAnyNetwork] | None,
process.get_extra_info("allow_registration_from", None),
)
return allowed_registration
def get_optional_commands(process: asyncssh.SSHServerProcess[str]) -> dict[str, bool]:
"""Get optional command state."""
with_registration = cast(
bool, process.get_extra_info("registration_enabled", False)
)
with_ping = cast(bool, process.get_extra_info("ping_enabled", False))
return {
"registration": with_registration,
"ping": with_ping,
}
async def get_stdin_public_key(process: asyncssh.SSHServerProcess[str]) -> str | None:
"""Get public key from stdin."""
process.stdout.write("Enter public key:\n")
public_key: str | None = None
try:
async for line in process.stdin:
public_key = verify_key_input(line.rstrip("\n"))
if public_key:
break
raise CommandError(constants.ERROR_INVALID_KEY_TYPE)
except asyncssh.BreakReceived:
pass
else:
process.stdout.write("OK\n")
return public_key
async def register_client(
process: asyncssh.SSHServerProcess[str],
backend: SshecretBackend,
username: str,
) -> None:
"""Register a new client."""
public_key = await get_stdin_public_key(process)
if not public_key:
raise CommandError(constants.ERROR_NO_PUBLIC_KEY)
key = asyncssh.import_public_key(public_key)
if key.algorithm.decode() != "ssh-rsa":
raise CommandError(constants.ERROR_INVALID_KEY_TYPE)
await audit_process(backend, process, Operation.CREATE, "Registering new client")
LOG.debug("Registering client %s with public key %s", username, public_key)
await backend.create_client(username, public_key)
async def get_secret(
backend: SshecretBackend,
client: Client,
secret_name: str,
origin: str,
) -> str:
"""Handle get secret requests from client."""
LOG.debug("Recieved command: get_secret %r", secret_name)
if not secret_name:
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
if secret_name not in client.secrets:
raise CommandError(constants.ERROR_NO_SECRET_FOUND)
await audit_event(
backend,
"Client requested secret",
operation=Operation.READ,
client=client,
origin=origin,
secret=secret_name,
)
# Look up secret
try:
secret = await backend.get_client_secret(client.name, secret_name)
if not secret:
raise CommandError(constants.ERROR_NO_SECRET_FOUND)
return secret
except Exception as exc:
LOG.debug(exc, exc_info=True)
raise CommandError(constants.ERROR_BACKEND_ERROR) from exc
async def dispatch_no_cmd(process: asyncssh.SSHServerProcess[str]) -> None:
"""Dispatch for no command."""
raise CommandError(constants.ERROR_NO_COMMAND_RECEIVED)
async def dispatch_cmd_ping(process: asyncssh.SSHServerProcess[str]) -> None:
"""Dispatch the ping command."""
process.stdout.write("PONG\n")
async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None:
"""Dispatch the register command."""
backend = get_info_backend(process)
if not backend:
raise CommandError(constants.ERROR_INFO_BACKEND_GONE)
username = get_info_username(process)
if not username:
raise CommandError(constants.ERROR_INFO_USERNAME_GONE)
allowed_networks = get_info_allowed_registration(process)
if not allowed_networks:
process.stdout.write("Unauthorized.\n")
await audit_process(
backend,
process,
Operation.DENY,
"Received registration command, but no subnets are allowed.",
)
return
remote_ip = get_info_remote_ip(process)
if not remote_ip:
raise CommandError(constants.ERROR_INFO_REMOTE_IP_GONE)
client_address = ipaddress.ip_address(remote_ip)
for network in allowed_networks:
if client_address in network:
break
else:
await audit_process(
backend,
process,
Operation.DENY,
"Received registration command from unauthorized subnet.",
)
process.stdout.write("Unauthorized.\n")
return
await register_client(process, backend, username)
process.stdout.write("Client registered\n.")
async def dispatch_cmd_get_secret(process: asyncssh.SSHServerProcess[str]) -> None:
"""Dispatch the get_secret command."""
backend = get_info_backend(process)
if not backend:
raise CommandError(constants.ERROR_INFO_BACKEND_GONE)
client = get_info_client(process)
if not client:
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
_cmd, args = get_process_command(process)
if not args:
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
secret_name = args[0]
origin = get_info_remote_ip(process) or "Unknown"
secret = await get_secret(backend, client, secret_name, origin)
process.stdout.write(secret)
async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
"""Dispatch command."""
command, _args = get_process_command(process)
if not command:
process.stderr.write(constants.ERROR_NO_COMMAND_RECEIVED)
process.exit(1)
return
cmdmap: dict[str, CommandDispatch] = {
"get_secret": dispatch_cmd_get_secret,
}
extra_commands = get_optional_commands(process)
if "registration" in extra_commands:
cmdmap["register"] = dispatch_cmd_register
if "ping" in extra_commands:
cmdmap["ping"] = dispatch_cmd_ping
if command not in cmdmap:
process.stderr.write(constants.ERROR_UNKNOWN_COMMAND)
process.exit(1)
return
exit_code = 0
try:
dispatcher = cmdmap[command]
await dispatcher(process)
except CommandError as e:
process.stderr.write(str(e))
exit_code = 1
except Exception as e:
LOG.debug(e, exc_info=True)
process.stderr.write("Unexpected exception:\n")
process.stderr.write(str(e))
exit_code = 1
LOG.debug("Command processing finished.")
process.exit(exit_code)
class AsshyncServer(asyncssh.SSHServer):
"""Asynchronous SSH server implementation."""