Files
sshecret/packages/sshecret-sshd/src/sshecret_sshd/ssh_server.py
Allan Eising dcf0b4274c Refactor command handling
This now supports usage/help texts
2025-05-18 17:56:53 +02:00

247 lines
8.5 KiB
Python

"""SSH Server implementation."""
import logging
import asyncssh
import ipaddress
from collections.abc import Awaitable
from functools import partial
from pathlib import Path
from typing import Callable, override
from pydantic import IPvAnyNetwork
from sshecret_sshd import constants
from sshecret_sshd.commands import dispatch_command
from sshecret.backend import SshecretBackend, Client, Operation, SubSystem
from .settings import ServerSettings, ClientRegistrationSettings
# Shapes of the transport "peername" extra info: (host, port) for IPv4 and
# (host, port, flowinfo, scope_id) for IPv6.
PeernameV4 = tuple[str, int]
PeernameV6 = tuple[str, int, int, int]
Peername = PeernameV4 | PeernameV6

# Coroutine signature for something that handles one SSH server process.
CommandDispatch = Callable[
    [asyncssh.SSHServerProcess[str]],
    Awaitable[None],
]

LOG = logging.getLogger(__name__)
async def audit_event(
    backend: SshecretBackend,
    message: str,
    operation: Operation,
    client: Client | None = None,
    origin: str | None = None,
    secret: str | None = None,
    **data: str,
) -> None:
    """Write an SSHD audit record through the backend.

    Args:
        backend: Backend whose SSHD audit log receives the entry.
        message: Human-readable description of the event.
        operation: Kind of operation being audited.
        client: Client the event relates to, if any.
        origin: Source address of the event; an empty or missing value is
            recorded as "UNKNOWN".
        secret: Name of the secret involved, if any. It is forwarded as
            ``secret_name``; the ``secret`` field is always sent as None.
        **data: Extra string fields attached to the audit entry.
    """
    sshd_audit = backend.audit(SubSystem.SSHD)
    await sshd_audit.write_async(
        operation,
        message,
        origin or "UNKNOWN",
        client,
        secret=None,
        secret_name=secret,
        **data,
    )
def verify_key_input(public_key: str) -> str | None:
    """Return ``public_key`` unchanged if it is a valid RSA public key.

    If the key cannot be imported, or if its algorithm is anything other than
    ``ssh-rsa``, the function returns None.
    """
    try:
        algorithm = asyncssh.import_public_key(public_key).algorithm.decode()
    except Exception:
        # Validation is best-effort: any parse/decode failure means "invalid".
        return None
    return public_key if algorithm == "ssh-rsa" else None
class AsshyncServer(asyncssh.SSHServer):
    """Asynchronous SSH server implementation.

    On authentication the server looks up the username as a client in the
    backend. If the client's source-IP policy allows the connection, that
    client's public key is installed as the only authorized key.

    If the username is not a known client and registration is enabled, the
    server checks the source address against the registration allow-list. For
    a permitted source it skips authentication so the user can register.
    """

    def __init__(
        self,
        backend: SshecretBackend,
        registration: ClientRegistrationSettings,
        enable_ping_command: bool = False,
    ) -> None:
        """Initialize server.

        Args:
            backend: Backend used for client lookups and audit events.
            registration: Registration settings; ``allow_from`` is only used
                when ``enabled`` is true.
            enable_ping_command: Flag published to the connection as
                ``ping_enabled`` extra info.
        """
        self.backend: SshecretBackend = backend
        self._conn: asyncssh.SSHServerConnection | None = None
        self.registration_enabled: bool = registration.enabled
        self.allow_registration_from: list[IPvAnyNetwork] | None = None
        if registration.enabled:
            self.allow_registration_from = registration.allow_from
        self.ping_enabled: bool = enable_ping_command
        # Set by connection_made() from the transport's peername.
        self.client_ip: str | None = None

    @override
    def connection_made(self, conn: asyncssh.SSHServerConnection) -> None:
        """Handle incoming connection.

        Records the peer address and publishes the backend and feature flags
        as connection extra info.
        """
        peername = conn.get_extra_info("peername")
        LOG.debug("Connection established from %r", peername)
        # get_extra_info() returns None when the value is unavailable; leave
        # client_ip unset in that case so IP-based checks fail closed instead
        # of raising TypeError here.
        self.client_ip = peername[0] if peername else None
        self._conn = conn
        self._conn.set_extra_info(backend=self.backend)
        self._conn.set_extra_info(registration_enabled=self.registration_enabled)
        self._conn.set_extra_info(ping_enabled=self.ping_enabled)

    @override
    def password_auth_supported(self) -> bool:
        """Deny password authentication."""
        return False

    @override
    async def begin_auth(self, username: str) -> bool:
        """Begin authentication.

        Returns True (authentication required) in every case except one:
        an unknown username connecting from a network that is allowed to
        register. In that case it returns False, which skips authentication
        entirely.

        For a known client that passes its policy, the client's key is
        installed as the authorized key. A known client that fails its
        policy gets no authorized keys, so public-key authentication cannot
        succeed for it.
        """
        LOG.debug("Started authentication flow for user %s", username)
        allowed_registration_sources: list[IPvAnyNetwork] = []
        if self.registration_enabled:
            if self.allow_registration_from:
                allowed_registration_sources = self.allow_registration_from
            else:
                # Registration enabled without an allow-list: accept
                # registrations from any IPv4 or IPv6 address.
                allowed_registration_sources = [
                    ipaddress.IPv4Network("0.0.0.0/0"),
                    ipaddress.IPv6Network("::/0"),
                ]

        assert self._conn is not None, "Error: No connection found."

        if client := await self.backend.get_client(username):
            LOG.debug("Client lookup sucessful: %r", client)
            if key := self.resolve_client_key(client):
                LOG.debug("Loaded public key for client %s\n%s", client.name, key)
                self._conn.set_extra_info(client=client)
                self._conn.set_authorized_keys(key)
            else:
                await audit_event(
                    self.backend,
                    "Client denied due to policy",
                    Operation.DENY,
                    client,
                    origin=self.client_ip,
                )
                LOG.warning(
                    "Client connection denied. Source: %s, policy: %r.",
                    self.client_ip,
                    client.policies,
                )
        elif allowed_registration_sources and self.client_ip:
            client_ip = ipaddress.ip_address(self.client_ip)
            for network in allowed_registration_sources:
                if client_ip.version != network.version:
                    continue
                if client_ip in network:
                    self._conn.set_extra_info(provided_username=username)
                    self._conn.set_extra_info(
                        allow_registration_from=self.allow_registration_from
                    )
                    LOG.info(
                        "Registration enabled, and client is not recognized. Bypassing authentication."
                    )
                    return False
            # Reached only when no allowed network matched. Record one denial
            # here, after all networks were checked, rather than one per
            # non-matching network.
            await audit_event(
                self.backend,
                "Received registration command from unauthorized subnet.",
                Operation.DENY,
                origin=self.client_ip,
                username=username,
            )
            LOG.warning(
                "Registration not permitted for username=%s, origin: %s",
                username,
                self.client_ip,
            )

        LOG.debug("Continuing to regular authentication")
        return True

    def resolve_client_key(self, client: Client) -> asyncssh.SSHAuthorizedKeys | None:
        """Resolve the client key.

        Returns the client's key as ``SSHAuthorizedKeys`` only if the
        connection's source address is permitted by the client's policies.
        Returns None if there is no connection or no known source address.
        """
        if not self._conn or not self.client_ip:
            return None
        # client_ip was taken from this connection's peername in
        # connection_made(), so there is no need to re-read it.
        remote_ip = self.client_ip
        LOG.debug("Validating client %s connection from %s", client.name, remote_ip)
        LOG.debug("Loading client public key %r", client.public_key)
        if self.check_connection_allowed(client, remote_ip):
            return asyncssh.import_authorized_keys(client.public_key)
        return None

    def check_connection_allowed(self, client: Client, source: str) -> bool:
        """Check if the client is allowed to connect.

        Args:
            client: Client whose ``policies`` (network strings or objects
                accepted by ``ipaddress.ip_network``) are checked.
            source: Source IP address of the connection.

        Returns:
            True if ``source`` falls within at least one policy network.
            Networks of the other IP version never match.
        """
        source_ip = ipaddress.ip_address(source)
        policies = [ipaddress.ip_network(policy) for policy in client.policies]
        valid_source = [source_ip in policy for policy in policies]
        LOG.debug("Valid sources %r from policies %r", valid_source, policies)
        return any(valid_source)
def get_server_key(basedir: Path | None = None) -> str:
    """Resolve the server host key file, generating it if missing.

    The key file is named ``ssh_host_<SERVER_KEY_TYPE>_key``. It is looked up
    in ``basedir``, or relative to the current working directory when no
    ``basedir`` is given. If the file does not exist, a new private key is
    generated and written there with owner-only permissions.

    TODO: Is one key enough? Should we generate more keys?

    Args:
        basedir: Directory that holds the key file.

    Returns:
        The absolute path of the key file, as a string.
    """
    filename = Path(f"ssh_host_{constants.SERVER_KEY_TYPE}_key")
    if basedir:
        filename = basedir / filename
    if filename.exists():
        return str(filename.absolute())
    # FIXME: There's a weird typing warning here that I need to investigate.
    # NOTE(review): the key type is hard-coded to ssh-ed25519 while the file
    # name uses constants.SERVER_KEY_TYPE; confirm that these agree.
    private_key = asyncssh.generate_private_key("ssh-ed25519", comment="sshecret-sshd")
    # Create the file as 0600 before writing the private key. A plain open()
    # would apply the default umask (commonly giving 0644), which would leave
    # the host key readable by other local users.
    filename.touch(mode=0o600)
    filename.write_bytes(private_key.export_private_key())
    return str(filename.absolute())
async def run_ssh_server(
    backend: SshecretBackend,
    listen_address: str,
    port: int,
    keys: list[str],
    registration: ClientRegistrationSettings,
    enable_ping_command: bool = False,
) -> asyncssh.SSHAcceptor:
    """Start listening for SSH connections.

    Connections are served by ``AsshyncServer`` instances configured with
    ``backend``, ``registration`` and ``enable_ping_command``. Sessions are
    handled by ``dispatch_command``.

    Returns:
        The acceptor produced by ``asyncssh.create_server``.
    """
    make_server = partial(
        AsshyncServer,
        backend=backend,
        registration=registration,
        enable_ping_command=enable_ping_command,
    )
    return await asyncssh.create_server(
        make_server,
        listen_address,
        port,
        server_host_keys=keys,
        process_factory=dispatch_command,
    )
async def start_sshecret_sshd(
    settings: ServerSettings | None = None,
) -> asyncssh.SSHAcceptor:
    """Start the server.

    The host key is resolved (and generated if absent) in the current working
    directory. If no ``settings`` are passed, they are loaded by constructing
    ``ServerSettings()``.
    """
    host_key_path = get_server_key()
    active_settings = settings or ServerSettings()  # pyright: ignore[reportCallIssue]
    sshecret_backend = SshecretBackend(
        str(active_settings.backend_url), active_settings.backend_token
    )
    acceptor = await run_ssh_server(
        backend=sshecret_backend,
        listen_address=active_settings.listen_address,
        port=active_settings.port,
        keys=[host_key_path],
        registration=active_settings.registration,
        enable_ping_command=active_settings.enable_ping_command,
    )
    return acceptor