"""SSH Server implementation.""" import logging import asyncssh import ipaddress from collections.abc import Awaitable from functools import partial from pathlib import Path from typing import Callable, override from pydantic import IPvAnyNetwork from sshecret_sshd import constants from sshecret_sshd.commands import dispatch_command from sshecret.backend import SshecretBackend, Client, Operation, SubSystem from .settings import ServerSettings, ClientRegistrationSettings LOG = logging.getLogger(__name__) CommandDispatch = Callable[[asyncssh.SSHServerProcess[str]], Awaitable[None]] PeernameV4 = tuple[str, int] PeernameV6 = tuple[str, int, int, int] Peername = PeernameV4 | PeernameV6 async def audit_event( backend: SshecretBackend, message: str, operation: Operation, client: Client | None = None, origin: str | None = None, secret: str | None = None, **data: str, ) -> None: """Add an audit event.""" if not origin: origin = "UNKNOWN" await backend.audit(SubSystem.SSHD).write_async( operation, message, origin, client, secret=None, secret_name=secret, **data ) def verify_key_input(public_key: str) -> str | None: """Verify key input.""" try: key = asyncssh.import_public_key(public_key) if key.algorithm.decode() == "ssh-rsa": return public_key except Exception: return None class AsshyncServer(asyncssh.SSHServer): """Asynchronous SSH server implementation.""" def __init__( self, backend: SshecretBackend, registration: ClientRegistrationSettings, enable_ping_command: bool = False, ) -> None: """Initialize server.""" self.backend: SshecretBackend = backend self._conn: asyncssh.SSHServerConnection | None = None self.registration_enabled: bool = registration.enabled self.allow_registration_from: list[IPvAnyNetwork] | None = None if registration.enabled: self.allow_registration_from = registration.allow_from self.ping_enabled: bool = enable_ping_command self.client_ip: str | None = None @override def connection_made(self, conn: asyncssh.SSHServerConnection) -> None: """Handle incoming connection.""" peername = conn.get_extra_info("peername") LOG.debug("Connection established from %r", peername) self.client_ip = peername[0] self._conn = conn self._conn.set_extra_info(backend=self.backend) self._conn.set_extra_info(registration_enabled=self.registration_enabled) self._conn.set_extra_info(ping_enabled=self.ping_enabled) @override def password_auth_supported(self) -> bool: """Deny password authentication.""" return False @override async def begin_auth(self, username: str) -> bool: """Begin authentication. Note we always return True here. False bypasses the whole authentication flow. """ LOG.debug("Started authentication flow for user %s", username) allowed_registration_sources: list[IPvAnyNetwork] = [] if self.registration_enabled and not self.allow_registration_from: allowed_registration_sources.append(ipaddress.IPv4Network("0.0.0.0/0")) allowed_registration_sources.append(ipaddress.IPv6Network("::/0")) elif self.registration_enabled and self.allow_registration_from: allowed_registration_sources = self.allow_registration_from assert self._conn is not None, "Error: No connection found." if client := await self.backend.get_client(username): LOG.debug("Client lookup sucessful: %r", client) if key := self.resolve_client_key(client): LOG.debug("Loaded public key for client %s\n%s", client.name, key) self._conn.set_extra_info(client=client) self._conn.set_authorized_keys(key) else: await audit_event( self.backend, "Client denied due to policy", Operation.DENY, client, origin=self.client_ip, ) LOG.warning( "Client connection denied. Source: %s, policy: %r.", self.client_ip, client.policies, ) elif allowed_registration_sources and self.client_ip: client_ip = ipaddress.ip_address(self.client_ip) for network in allowed_registration_sources: if client_ip.version != network.version: continue if client_ip in network: self._conn.set_extra_info(provided_username=username) self._conn.set_extra_info( allow_registration_from=self.allow_registration_from ) LOG.info( "Registration enabled, and client is not recognized. Bypassing authentication." ) return False else: await audit_event( self.backend, "Received registration command from unauthorized subnet.", Operation.DENY, origin=self.client_ip, username=username, ) LOG.warning( "Registration not permitted for username=%s, origin: %s", username, self.client_ip, ) LOG.debug("Continuing to regular authentication") return True def resolve_client_key(self, client: Client) -> asyncssh.SSHAuthorizedKeys | None: """Resolve the client key. Returns the key object only if the client is allowed to connect according to its policy. """ if not self._conn: return None remote_ip = str(self._conn.get_extra_info("peername")[0]) LOG.debug("Validating client %s connection from %s", client.name, remote_ip) LOG.debug("Loading client public key %r", client.public_key) if self.check_connection_allowed(client, remote_ip): return asyncssh.import_authorized_keys(client.public_key) return None def check_connection_allowed(self, client: Client, source: str) -> bool: """Check if the client is allowed to connect.""" source_ip = ipaddress.ip_address(source) policies = [ipaddress.ip_network(policy) for policy in client.policies] valid_source = [source_ip in policy for policy in policies] LOG.debug("Valid sources %r from policies %r", valid_source, policies) return any(valid_source) def get_server_key(basedir: Path | None = None) -> str: """Resolve server key. TODO: Is one key enough? Should we generate more keys? """ filename = Path(f"ssh_host_{constants.SERVER_KEY_TYPE}_key") if basedir: filename = basedir / filename if filename.exists(): return str(filename.absolute()) # FIXME: There's a weird typing warning here that I need to investigate. private_key = asyncssh.generate_private_key("ssh-ed25519", comment="sshecret-sshd") with open(filename, "wb") as f: f.write(private_key.export_private_key()) return str(filename.absolute()) async def run_ssh_server( backend: SshecretBackend, listen_address: str, port: int, keys: list[str], registration: ClientRegistrationSettings, enable_ping_command: bool = False, ) -> asyncssh.SSHAcceptor: """Run the server.""" server = partial( AsshyncServer, backend=backend, registration=registration, enable_ping_command=enable_ping_command, ) server = await asyncssh.create_server( server, listen_address, port, server_host_keys=keys, process_factory=dispatch_command, ) return server async def start_sshecret_sshd( settings: ServerSettings | None = None, ) -> asyncssh.SSHAcceptor: """Start the server.""" server_key = get_server_key() if not settings: settings = ServerSettings() # pyright: ignore[reportCallIssue] backend = SshecretBackend(str(settings.backend_url), settings.backend_token) return await run_ssh_server( backend=backend, listen_address=settings.listen_address, port=settings.port, keys=[server_key], registration=settings.registration, enable_ping_command=settings.enable_ping_command, )