Complete sshd package with tests

This commit is contained in:
2025-05-10 08:27:16 +02:00
parent 3719a2611d
commit 4f970a3f71
12 changed files with 472 additions and 103 deletions

View File

@ -1,21 +1,21 @@
"""SSH Server implementation."""
import logging
import uuid
import asyncssh
import ipaddress
from collections.abc import Awaitable
from datetime import datetime, timezone
from functools import partial
from pathlib import Path
from typing import Any, Callable, cast, override
from typing import Callable, cast, override
from pydantic import IPvAnyNetwork
from . import constants
from sshecret.backend import AuditLog, SshecretBackend, Client
from .settings import ServerSettings
from sshecret.backend import SshecretBackend, Client, Operation, SubSystem
from .settings import ServerSettings, ClientRegistrationSettings
LOG = logging.getLogger(__name__)
@ -35,71 +35,39 @@ class CommandError(Exception):
def audit_process(
backend: SshecretBackend,
process: asyncssh.SSHServerProcess[str],
operation: Operation,
message: str,
secret: str | None = None,
**data: str,
) -> None:
"""Add an audit event from process."""
command = get_process_command(process)
client = get_info_client(process)
username = get_info_username(process)
remote_ip = get_info_remote_ip(process)
operation = "SSH_EVENT"
obj_name: str | None = None
obj_id: str | None = None
remote_ip = get_info_remote_ip(process) or "UNKNOWN"
if username:
data["username"] = username
if command and not secret:
cmd, cmd_args = command
obj_id = " ".join(cmd_args)
elif secret:
obj_name = "ClientSecret"
obj_id = secret
entry = AuditLog(
subsystem="ssh",
operation=operation,
object=obj_name,
object_id=obj_id,
message=message,
origin=remote_ip,
)
if client:
entry.client_id = str(client.id)
entry.client_name = client.name
elif username:
entry.client_name = username
backend.add_audit_log_sync(entry)
if cmd:
data["command"] = cmd
data["args"] = " ".join(cmd_args)
backend.audit(SubSystem.SSHD).write(operation, message, remote_ip, client, secret=None, secret_name=secret, **data)
def audit_event(
backend: SshecretBackend,
message: str,
operation: str = "SSH_EVENT",
operation: Operation,
client: Client | None = None,
origin: str | None = None,
secret: str | None = None,
) -> None:
"""Add an audit event."""
entry = AuditLog(
client_id=None,
client_name=None,
object=None,
object_id=None,
subsystem="ssh",
operation=operation,
message=message,
origin=origin,
)
if client:
entry.client_id = str(client.id)
entry.client_name = client.name
if secret:
entry.object = "ClientSecret"
entry.object_id = secret
backend.add_audit_log_sync(entry)
if not origin:
origin = "UNKNOWN"
backend.audit(SubSystem.SSHD).write(operation, message, origin, client, secret=None, secret_name=secret)
def verify_key_input(public_key: str) -> str | None:
"""Verify key input."""
@ -118,6 +86,7 @@ def get_process_command(
if not process.command:
return (None, [])
argv = process.command.split(" ")
LOG.debug("Args: %r", argv)
return (argv[0], argv[1:])
@ -149,7 +118,12 @@ def get_info_remote_ip(process: asyncssh.SSHServerProcess[str]) -> str | None:
return remote_ip
# remote_ip = str(self._conn.get_extra_info("peername")[0])
def get_info_allowed_registration(process: asyncssh.SSHServerProcess[str]) -> list[IPvAnyNetwork] | None:
"""Get allowed networks to allow registration from."""
allowed_registration = cast(list[IPvAnyNetwork] | None, process.get_extra_info("allow_registration_from", None))
return allowed_registration
def get_optional_commands(process: asyncssh.SSHServerProcess[str]) -> dict[str, bool]:
@ -197,12 +171,12 @@ async def register_client(
"""Register a new client."""
public_key = await get_stdin_public_key(process)
if not public_key:
raise CommandError("Aborted. No valid public key received.")
raise CommandError(constants.ERROR_NO_PUBLIC_KEY)
key = asyncssh.import_public_key(public_key)
if key.algorithm.decode() != "ssh-rsa":
raise CommandError("Error: Only RSA keys are supported!")
audit_process(backend, process, "Registering new client")
raise CommandError(constants.ERROR_INVALID_KEY_TYPE)
audit_process(backend, process, Operation.CREATE, "Registering new client")
LOG.debug("Registering client %s with public key %s", username, public_key)
await backend.create_client(username, public_key)
@ -214,14 +188,16 @@ async def get_secret(
origin: str,
) -> str:
"""Handle get secret requests from client."""
LOG.debug("Recieved command: %r", secret_name)
if not secret_name or secret_name not in client.secrets:
LOG.debug("Recieved command: get_secret %r", secret_name)
if not secret_name:
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
if secret_name not in client.secrets:
raise CommandError(constants.ERROR_NO_SECRET_FOUND)
audit_event(
backend,
"Client requested secret",
operation="get_secret",
operation=Operation.READ,
client=client,
origin=origin,
secret=secret_name,
@ -229,10 +205,13 @@ async def get_secret(
# Look up secret
try:
return await backend.get_client_secret(client.name, secret_name)
secret = await backend.get_client_secret(client.name, secret_name)
if not secret:
raise CommandError(constants.ERROR_NO_SECRET_FOUND)
return secret
except Exception as exc:
LOG.debug(exc, exc_info=True)
raise CommandError("Unexpected error from backend") from exc
raise CommandError(constants.ERROR_BACKEND_ERROR) from exc
async def dispatch_no_cmd(process: asyncssh.SSHServerProcess[str]) -> None:
@ -249,10 +228,32 @@ async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None
"""Dispatch the register command."""
backend = get_info_backend(process)
if not backend:
raise CommandError("Unexpected error: Backend disappeared.")
raise CommandError(constants.ERROR_INFO_BACKEND_GONE)
username = get_info_username(process)
if not username:
raise CommandError("Unexpected error: Username was lost.")
raise CommandError(constants.ERROR_INFO_USERNAME_GONE)
allowed_networks = get_info_allowed_registration(process)
if not allowed_networks:
process.stdout.write("Unauthorized.\n")
audit_process(backend, process, Operation.DENY, "Received registration command, but no subnets are allowed.")
return
remote_ip = get_info_remote_ip(process)
if not remote_ip:
raise CommandError(constants.ERROR_INFO_REMOTE_IP_GONE)
client_address = ipaddress.ip_address(remote_ip)
for network in allowed_networks:
if client_address in network:
break
else:
audit_process(backend, process, Operation.DENY, "Received registration command from unauthorized subnet.")
process.stdout.write("Unauthorized.\n")
return
await register_client(process, backend, username)
process.stdout.write("Client registered\n.")
@ -262,7 +263,7 @@ async def dispatch_cmd_get_secret(process: asyncssh.SSHServerProcess[str]) -> No
"""Dispatch the get_secret command."""
backend = get_info_backend(process)
if not backend:
raise CommandError("Unexpected error: Backend disappeared.")
raise CommandError(constants.ERROR_INFO_BACKEND_GONE)
client = get_info_client(process)
if not client:
@ -311,6 +312,7 @@ async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
process.stderr.write(str(e))
exit_code = 1
LOG.debug("Command processing finished.")
process.exit(exit_code)
@ -319,16 +321,18 @@ class AsshyncServer(asyncssh.SSHServer):
def __init__(
self,
backend_url: str,
backend_token: str,
with_register: bool = True,
with_ping: bool = True,
backend: SshecretBackend,
registration: ClientRegistrationSettings,
enable_ping_command: bool = False,
) -> None:
"""Initialize server."""
self.backend: SshecretBackend = SshecretBackend(backend_url, backend_token)
self.backend: SshecretBackend = backend
self._conn: asyncssh.SSHServerConnection | None = None
self.registration_enabled: bool = with_register
self.ping_enabled: bool = with_ping
self.registration_enabled: bool = registration.enabled
self.allow_registration_from: list[IPvAnyNetwork] | None = None
if registration.enabled:
self.allow_registration_from = registration.allow_from
self.ping_enabled: bool = enable_ping_command
self.client_ip: str | None = None
@override
@ -359,9 +363,9 @@ class AsshyncServer(asyncssh.SSHServer):
if not self._conn:
return True
if client := await self.backend.get_client(username):
LOG.debug("Client lookup sucessful.")
LOG.debug("Client lookup sucessful: %r", client)
if key := self.resolve_client_key(client):
LOG.debug("Loaded public key for client %s", client.name)
LOG.debug("Loaded public key for client %s\n%s", client.name, key)
self._conn.set_extra_info(client=client)
self._conn.set_authorized_keys(key)
else:
@ -369,15 +373,18 @@ class AsshyncServer(asyncssh.SSHServer):
audit_event(
self.backend,
"Client denied due to policy",
"DENY",
Operation.DENY,
client,
origin=self.client_ip,
)
LOG.warning("Client connection denied due to policy.")
else:
elif self.registration_enabled:
self._conn.set_extra_info(provided_username=username)
self._conn.set_extra_info(allow_registration_from=self.allow_registration_from)
LOG.warning("Registration enabled, and client is not recognized. Bypassing authentication.")
return False
LOG.debug("Continuing to regular authentication")
return True
@override
@ -403,6 +410,7 @@ class AsshyncServer(asyncssh.SSHServer):
return None
remote_ip = str(self._conn.get_extra_info("peername")[0])
LOG.debug("Validating client %s connection from %s", client.name, remote_ip)
LOG.debug("Loading client public key %r", client.public_key)
if self.check_connection_allowed(client, remote_ip):
return asyncssh.import_authorized_keys(client.public_key)
return None
@ -428,7 +436,7 @@ def get_server_key(basedir: Path | None = None) -> str:
if filename.exists():
return str(filename.absolute())
# FIXME: There's a weird typing warning here that I need to investigate.
private_key = asyncssh.generate_private_key("ssh-ed25519", "sshecret-sshd")
private_key = asyncssh.generate_private_key("ssh-ed25519", comment="sshecret-sshd")
with open(filename, "wb") as f:
f.write(private_key.export_private_key())
@ -436,15 +444,19 @@ def get_server_key(basedir: Path | None = None) -> str:
async def run_ssh_server(
backend_url: str,
backend_token: str,
backend: SshecretBackend,
listen_address: str,
port: int,
keys: list[str],
registration: ClientRegistrationSettings,
enable_ping_command: bool = False,
) -> asyncssh.SSHAcceptor:
"""Run the server."""
server = partial(
AsshyncServer, backend_url=str(backend_url), backend_token=backend_token
AsshyncServer,
backend=backend,
registration=registration,
enable_ping_command=enable_ping_command,
)
server = await asyncssh.create_server(
server,
@ -463,10 +475,12 @@ async def start_server(settings: ServerSettings | None = None) -> None:
if not settings:
settings = ServerSettings() # pyright: ignore[reportCallIssue]
backend = SshecretBackend(str(settings.backend_url), settings.backend_token)
await run_ssh_server(
str(settings.backend_url),
settings.backend_token,
settings.listen_address,
settings.port,
[server_key],
backend=backend,
listen_address=settings.listen_address,
port=settings.port,
keys=[server_key],
registration=settings.registration,
enable_ping_command=settings.enable_ping_command,
)