Refactor command handling
This now supports usage/help texts
This commit is contained in:
@ -1,19 +1,18 @@
|
||||
"""SSH Server implementation."""
|
||||
|
||||
from asyncio import _register_task
|
||||
import logging
|
||||
|
||||
import asyncssh
|
||||
import ipaddress
|
||||
|
||||
from collections.abc import Awaitable
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Callable, cast, override
|
||||
from typing import Callable, override
|
||||
|
||||
from pydantic import IPvAnyNetwork
|
||||
|
||||
from . import constants
|
||||
from sshecret_sshd import constants
|
||||
from sshecret_sshd.commands import dispatch_command
|
||||
|
||||
from sshecret.backend import SshecretBackend, Client, Operation, SubSystem
|
||||
from .settings import ServerSettings, ClientRegistrationSettings
|
||||
@ -29,37 +28,6 @@ PeernameV6 = tuple[str, int, int, int]
|
||||
Peername = PeernameV4 | PeernameV6
|
||||
|
||||
|
||||
class CommandError(Exception):
|
||||
"""Error class for errors during command processing."""
|
||||
|
||||
|
||||
async def audit_process(
|
||||
backend: SshecretBackend,
|
||||
process: asyncssh.SSHServerProcess[str],
|
||||
operation: Operation,
|
||||
message: str,
|
||||
secret: str | None = None,
|
||||
**data: str,
|
||||
) -> None:
|
||||
"""Add an audit event from process."""
|
||||
command = get_process_command(process)
|
||||
client = get_info_client(process)
|
||||
username = get_info_username(process)
|
||||
remote_ip = get_info_remote_ip(process) or "UNKNOWN"
|
||||
if username:
|
||||
data["username"] = username
|
||||
|
||||
if command and not secret:
|
||||
cmd, cmd_args = command
|
||||
if cmd:
|
||||
data["command"] = cmd
|
||||
data["args"] = " ".join(cmd_args)
|
||||
|
||||
await backend.audit(SubSystem.SSHD).write_async(
|
||||
operation, message, remote_ip, client, secret=None, secret_name=secret, **data
|
||||
)
|
||||
|
||||
|
||||
async def audit_event(
|
||||
backend: SshecretBackend,
|
||||
message: str,
|
||||
@ -87,250 +55,6 @@ def verify_key_input(public_key: str) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def get_process_command(
|
||||
process: asyncssh.SSHServerProcess[str],
|
||||
) -> tuple[str | None, list[str]]:
|
||||
"""Extract the process command."""
|
||||
if not process.command:
|
||||
return (None, [])
|
||||
argv = process.command.split(" ")
|
||||
LOG.debug("Args: %r", argv)
|
||||
return (argv[0], argv[1:])
|
||||
|
||||
|
||||
def get_info_backend(process: asyncssh.SSHServerProcess[str]) -> SshecretBackend | None:
|
||||
"""Get backend from process."""
|
||||
backend = cast("SshecretBackend | None", process.get_extra_info("backend", None))
|
||||
return backend
|
||||
|
||||
|
||||
def get_info_client(process: asyncssh.SSHServerProcess[str]) -> Client | None:
|
||||
"""Get info from process."""
|
||||
client = cast("Client | None", process.get_extra_info("client", None))
|
||||
return client
|
||||
|
||||
|
||||
def get_info_username(process: asyncssh.SSHServerProcess[str]) -> str | None:
|
||||
"""Get username from process."""
|
||||
username = cast("str | None", process.get_extra_info("provided_username", None))
|
||||
return username
|
||||
|
||||
|
||||
def get_info_remote_ip(process: asyncssh.SSHServerProcess[str]) -> str | None:
|
||||
"""Get remote IP."""
|
||||
|
||||
peername = cast("Peername | None", process.get_extra_info("peername", None))
|
||||
remote_ip: str | None = None
|
||||
if peername:
|
||||
remote_ip = peername[0]
|
||||
|
||||
return remote_ip
|
||||
|
||||
|
||||
def get_info_allowed_registration(
|
||||
process: asyncssh.SSHServerProcess[str],
|
||||
) -> list[IPvAnyNetwork] | None:
|
||||
"""Get allowed networks to allow registration from."""
|
||||
|
||||
allowed_registration = cast(
|
||||
list[IPvAnyNetwork] | None,
|
||||
process.get_extra_info("allow_registration_from", None),
|
||||
)
|
||||
return allowed_registration
|
||||
|
||||
|
||||
def get_optional_commands(process: asyncssh.SSHServerProcess[str]) -> dict[str, bool]:
|
||||
"""Get optional command state."""
|
||||
with_registration = cast(
|
||||
bool, process.get_extra_info("registration_enabled", False)
|
||||
)
|
||||
with_ping = cast(bool, process.get_extra_info("ping_enabled", False))
|
||||
return {
|
||||
"registration": with_registration,
|
||||
"ping": with_ping,
|
||||
}
|
||||
|
||||
|
||||
async def get_stdin_public_key(process: asyncssh.SSHServerProcess[str]) -> str | None:
|
||||
"""Get public key from stdin."""
|
||||
process.stdout.write("Enter public key:\n")
|
||||
public_key: str | None = None
|
||||
try:
|
||||
async for line in process.stdin:
|
||||
public_key = verify_key_input(line.rstrip("\n"))
|
||||
if public_key:
|
||||
break
|
||||
raise CommandError(constants.ERROR_INVALID_KEY_TYPE)
|
||||
except asyncssh.BreakReceived:
|
||||
pass
|
||||
else:
|
||||
process.stdout.write("OK\n")
|
||||
return public_key
|
||||
|
||||
|
||||
async def register_client(
|
||||
process: asyncssh.SSHServerProcess[str],
|
||||
backend: SshecretBackend,
|
||||
username: str,
|
||||
) -> None:
|
||||
"""Register a new client."""
|
||||
public_key = await get_stdin_public_key(process)
|
||||
if not public_key:
|
||||
raise CommandError(constants.ERROR_NO_PUBLIC_KEY)
|
||||
|
||||
key = asyncssh.import_public_key(public_key)
|
||||
if key.algorithm.decode() != "ssh-rsa":
|
||||
raise CommandError(constants.ERROR_INVALID_KEY_TYPE)
|
||||
await audit_process(backend, process, Operation.CREATE, "Registering new client")
|
||||
LOG.debug("Registering client %s with public key %s", username, public_key)
|
||||
await backend.create_client(username, public_key)
|
||||
|
||||
|
||||
async def get_secret(
|
||||
backend: SshecretBackend,
|
||||
client: Client,
|
||||
secret_name: str,
|
||||
origin: str,
|
||||
) -> str:
|
||||
"""Handle get secret requests from client."""
|
||||
LOG.debug("Recieved command: get_secret %r", secret_name)
|
||||
if not secret_name:
|
||||
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
|
||||
if secret_name not in client.secrets:
|
||||
raise CommandError(constants.ERROR_NO_SECRET_FOUND)
|
||||
|
||||
await audit_event(
|
||||
backend,
|
||||
"Client requested secret",
|
||||
operation=Operation.READ,
|
||||
client=client,
|
||||
origin=origin,
|
||||
secret=secret_name,
|
||||
)
|
||||
|
||||
# Look up secret
|
||||
try:
|
||||
secret = await backend.get_client_secret(client.name, secret_name)
|
||||
if not secret:
|
||||
raise CommandError(constants.ERROR_NO_SECRET_FOUND)
|
||||
return secret
|
||||
except Exception as exc:
|
||||
LOG.debug(exc, exc_info=True)
|
||||
raise CommandError(constants.ERROR_BACKEND_ERROR) from exc
|
||||
|
||||
|
||||
async def dispatch_no_cmd(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||
"""Dispatch for no command."""
|
||||
raise CommandError(constants.ERROR_NO_COMMAND_RECEIVED)
|
||||
|
||||
|
||||
async def dispatch_cmd_ping(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||
"""Dispatch the ping command."""
|
||||
process.stdout.write("PONG\n")
|
||||
|
||||
|
||||
async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||
"""Dispatch the register command."""
|
||||
backend = get_info_backend(process)
|
||||
if not backend:
|
||||
raise CommandError(constants.ERROR_INFO_BACKEND_GONE)
|
||||
username = get_info_username(process)
|
||||
if not username:
|
||||
raise CommandError(constants.ERROR_INFO_USERNAME_GONE)
|
||||
|
||||
allowed_networks = get_info_allowed_registration(process)
|
||||
if not allowed_networks:
|
||||
process.stdout.write("Unauthorized.\n")
|
||||
await audit_process(
|
||||
backend,
|
||||
process,
|
||||
Operation.DENY,
|
||||
"Received registration command, but no subnets are allowed.",
|
||||
)
|
||||
return
|
||||
|
||||
remote_ip = get_info_remote_ip(process)
|
||||
|
||||
if not remote_ip:
|
||||
raise CommandError(constants.ERROR_INFO_REMOTE_IP_GONE)
|
||||
|
||||
client_address = ipaddress.ip_address(remote_ip)
|
||||
|
||||
for network in allowed_networks:
|
||||
if client_address in network:
|
||||
break
|
||||
else:
|
||||
await audit_process(
|
||||
backend,
|
||||
process,
|
||||
Operation.DENY,
|
||||
"Received registration command from unauthorized subnet.",
|
||||
)
|
||||
process.stdout.write("Unauthorized.\n")
|
||||
return
|
||||
|
||||
await register_client(process, backend, username)
|
||||
|
||||
process.stdout.write("Client registered\n.")
|
||||
|
||||
|
||||
async def dispatch_cmd_get_secret(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||
"""Dispatch the get_secret command."""
|
||||
backend = get_info_backend(process)
|
||||
if not backend:
|
||||
raise CommandError(constants.ERROR_INFO_BACKEND_GONE)
|
||||
|
||||
client = get_info_client(process)
|
||||
if not client:
|
||||
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
|
||||
_cmd, args = get_process_command(process)
|
||||
if not args:
|
||||
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
|
||||
secret_name = args[0]
|
||||
|
||||
origin = get_info_remote_ip(process) or "Unknown"
|
||||
secret = await get_secret(backend, client, secret_name, origin)
|
||||
process.stdout.write(secret)
|
||||
|
||||
|
||||
async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||
"""Dispatch command."""
|
||||
command, _args = get_process_command(process)
|
||||
if not command:
|
||||
process.stderr.write(constants.ERROR_NO_COMMAND_RECEIVED)
|
||||
process.exit(1)
|
||||
return
|
||||
cmdmap: dict[str, CommandDispatch] = {
|
||||
"get_secret": dispatch_cmd_get_secret,
|
||||
}
|
||||
extra_commands = get_optional_commands(process)
|
||||
if "registration" in extra_commands:
|
||||
cmdmap["register"] = dispatch_cmd_register
|
||||
if "ping" in extra_commands:
|
||||
cmdmap["ping"] = dispatch_cmd_ping
|
||||
|
||||
if command not in cmdmap:
|
||||
process.stderr.write(constants.ERROR_UNKNOWN_COMMAND)
|
||||
process.exit(1)
|
||||
return
|
||||
exit_code = 0
|
||||
try:
|
||||
dispatcher = cmdmap[command]
|
||||
await dispatcher(process)
|
||||
except CommandError as e:
|
||||
process.stderr.write(str(e))
|
||||
exit_code = 1
|
||||
|
||||
except Exception as e:
|
||||
LOG.debug(e, exc_info=True)
|
||||
process.stderr.write("Unexpected exception:\n")
|
||||
process.stderr.write(str(e))
|
||||
exit_code = 1
|
||||
|
||||
LOG.debug("Command processing finished.")
|
||||
process.exit(exit_code)
|
||||
|
||||
|
||||
class AsshyncServer(asyncssh.SSHServer):
|
||||
"""Asynchronous SSH server implementation."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user