247 lines
8.5 KiB
Python
247 lines
8.5 KiB
Python
"""SSH Server implementation."""
|
|
|
|
import logging
|
|
import asyncssh
|
|
import ipaddress
|
|
|
|
from collections.abc import Awaitable
|
|
from functools import partial
|
|
from pathlib import Path
|
|
from typing import Callable, override
|
|
|
|
from pydantic import IPvAnyNetwork
|
|
|
|
from sshecret_sshd import constants
|
|
from sshecret_sshd.commands import dispatch_command
|
|
|
|
from sshecret.backend import SshecretBackend, Client, Operation, SubSystem
|
|
from .settings import ServerSettings, ClientRegistrationSettings
|
|
|
|
|
|
# Module-level logger.
LOG: logging.Logger = logging.getLogger(__name__)

# A coroutine function that takes an SSH server process and returns nothing.
CommandDispatch = Callable[[asyncssh.SSHServerProcess[str]], Awaitable[None]]

# Possible shapes of the "peername" connection info: (host, port) for IPv4,
# and a 4-tuple for IPv6 (presumably socket's (host, port, flowinfo, scope_id)).
PeernameV4 = tuple[str, int]
PeernameV6 = tuple[str, int, int, int]
Peername = PeernameV4 | PeernameV6
|
|
|
|
|
|
async def audit_event(
    backend: SshecretBackend,
    message: str,
    operation: Operation,
    client: Client | None = None,
    origin: str | None = None,
    secret: str | None = None,
    **data: str,
) -> None:
    """Write an audit entry for the SSHD subsystem.

    Args:
        backend: Backend whose audit log receives the entry.
        message: Human-readable description of the event.
        operation: The operation being audited.
        client: Client the event relates to, if known.
        origin: Source address of the event; any falsy value is recorded
            as ``"UNKNOWN"``.
        secret: Name of the secret involved, forwarded as ``secret_name``
            (the ``secret`` field itself is always sent as None).
        **data: Extra string fields passed straight through to the writer.
    """
    auditor = backend.audit(SubSystem.SSHD)
    await auditor.write_async(
        operation,
        message,
        origin or "UNKNOWN",
        client,
        secret=None,
        secret_name=secret,
        **data,
    )
|
|
|
|
|
|
def verify_key_input(public_key: str) -> str | None:
|
|
"""Verify key input."""
|
|
try:
|
|
key = asyncssh.import_public_key(public_key)
|
|
if key.algorithm.decode() == "ssh-rsa":
|
|
return public_key
|
|
except Exception:
|
|
return None
|
|
|
|
|
|
class AsshyncServer(asyncssh.SSHServer):
    """Asynchronous SSH server implementation.

    Authenticates clients by public key, using the key stored in the sshecret
    backend, and enforces each client's source-network policies. When client
    registration is enabled, an unknown user connecting from an allowed
    network may skip authentication so that it can register.
    """

    def __init__(
        self,
        backend: SshecretBackend,
        registration: ClientRegistrationSettings,
        enable_ping_command: bool = False,
    ) -> None:
        """Initialize server.

        Args:
            backend: Backend used to look up clients and write audit events.
            registration: Client registration settings; ``allow_from`` is only
                consulted when ``enabled`` is true.
            enable_ping_command: Exposed on the connection as ``ping_enabled``.
        """
        self.backend: SshecretBackend = backend
        self._conn: asyncssh.SSHServerConnection | None = None
        self.registration_enabled: bool = registration.enabled
        self.allow_registration_from: list[IPvAnyNetwork] | None = None
        if registration.enabled:
            self.allow_registration_from = registration.allow_from
        self.ping_enabled: bool = enable_ping_command
        # Peer address of the connection; filled in by connection_made().
        self.client_ip: str | None = None

    @override
    def connection_made(self, conn: asyncssh.SSHServerConnection) -> None:
        """Handle incoming connection.

        Records the peer address and stores the backend and feature flags as
        extra info on the connection.
        """
        peername = conn.get_extra_info("peername")
        LOG.debug("Connection established from %r", peername)
        # peername can be None when the transport has no peer address. Leave
        # client_ip unset rather than crash; the registration and policy
        # checks below then refuse the connection.
        if peername:
            self.client_ip = peername[0]
        self._conn = conn
        self._conn.set_extra_info(backend=self.backend)
        self._conn.set_extra_info(registration_enabled=self.registration_enabled)
        self._conn.set_extra_info(ping_enabled=self.ping_enabled)

    @override
    def password_auth_supported(self) -> bool:
        """Deny password authentication."""
        return False

    @override
    async def begin_auth(self, username: str) -> bool:
        """Begin authentication.

        Returns True so that regular authentication continues. For a known
        client, its public key becomes the connection's authorized keys, but
        only if the client's source-network policy allows the peer address.

        Returns False, which bypasses the whole authentication flow, only when
        the user is unknown, registration is enabled, and the peer address lies
        inside an allowed registration network.

        Raises:
            RuntimeError: If called before ``connection_made``.
        """
        LOG.debug("Started authentication flow for user %s", username)
        allowed_registration_sources: list[IPvAnyNetwork] = []
        if self.registration_enabled and not self.allow_registration_from:
            # Registration enabled without an explicit allow-list: any source.
            allowed_registration_sources.append(ipaddress.IPv4Network("0.0.0.0/0"))
            allowed_registration_sources.append(ipaddress.IPv6Network("::/0"))
        elif self.registration_enabled and self.allow_registration_from:
            allowed_registration_sources = self.allow_registration_from

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if self._conn is None:
            raise RuntimeError("Error: No connection found.")

        if client := await self.backend.get_client(username):
            LOG.debug("Client lookup sucessful: %r", client)
            if key := self.resolve_client_key(client):
                LOG.debug("Loaded public key for client %s\n%s", client.name, key)
                self._conn.set_extra_info(client=client)
                self._conn.set_authorized_keys(key)
            else:
                # No authorized keys are installed, so public-key auth will fail.
                await audit_event(
                    self.backend,
                    "Client denied due to policy",
                    Operation.DENY,
                    client,
                    origin=self.client_ip,
                )
                LOG.warning(
                    "Client connection denied. Source: %s, policy: %r.",
                    self.client_ip,
                    client.policies,
                )
        elif allowed_registration_sources and self.client_ip:
            client_ip = ipaddress.ip_address(self.client_ip)
            for network in allowed_registration_sources:
                if client_ip.version != network.version:
                    continue
                if client_ip in network:
                    self._conn.set_extra_info(provided_username=username)
                    self._conn.set_extra_info(
                        allow_registration_from=self.allow_registration_from
                    )
                    LOG.info(
                        "Registration enabled, and client is not recognized. Bypassing authentication."
                    )
                    return False

            # Reached only when no allowed network matched the peer address.
            await audit_event(
                self.backend,
                "Received registration command from unauthorized subnet.",
                Operation.DENY,
                origin=self.client_ip,
                username=username,
            )
            LOG.warning(
                "Registration not permitted for username=%s, origin: %s",
                username,
                self.client_ip,
            )

        LOG.debug("Continuing to regular authentication")
        return True

    def resolve_client_key(self, client: Client) -> asyncssh.SSHAuthorizedKeys | None:
        """Resolve the client key.

        Returns the key object only if the client is allowed to connect
        according to its policy. Returns None when there is no connection,
        when the peer address is unknown, or when no policy matches.
        """
        if not self._conn:
            return None
        peername = self._conn.get_extra_info("peername")
        if not peername:
            # Nothing to check the policy against: fail closed.
            return None
        remote_ip = str(peername[0])
        LOG.debug("Validating client %s connection from %s", client.name, remote_ip)
        LOG.debug("Loading client public key %r", client.public_key)
        if self.check_connection_allowed(client, remote_ip):
            return asyncssh.import_authorized_keys(client.public_key)
        return None

    def check_connection_allowed(self, client: Client, source: str) -> bool:
        """Check if the client is allowed to connect.

        Returns True if ``source`` lies inside any network listed in
        ``client.policies``. An empty policy list denies every source.

        Raises:
            ValueError: If ``source`` is not a valid IP address, or a policy
                entry is not a valid network. ``ip_network`` parses strictly,
                so host bits must be zero.
        """
        source_ip = ipaddress.ip_address(source)
        policies = [ipaddress.ip_network(policy) for policy in client.policies]

        # Address/network version mismatches simply evaluate to False.
        valid_source = [source_ip in policy for policy in policies]
        LOG.debug("Valid sources %r from policies %r", valid_source, policies)
        return any(valid_source)
|
|
|
|
|
|
def get_server_key(basedir: Path | None = None) -> str:
    """Resolve server key.

    Returns the absolute path of the host key file
    ``ssh_host_<SERVER_KEY_TYPE>_key``. The file lives in ``basedir`` when one
    is given, and is otherwise resolved relative to the current working
    directory. If the file does not exist, a new ed25519 private key is
    generated and written there first, readable by the owner only.

    TODO: Is one key enough? Should we generate more keys?
    """
    import os

    filename = Path(f"ssh_host_{constants.SERVER_KEY_TYPE}_key")
    if basedir:
        filename = basedir / filename
    if filename.exists():
        return str(filename.absolute())

    # NOTE(review): the file name uses constants.SERVER_KEY_TYPE, but the key
    # algorithm below is hard-coded; confirm the two always agree.
    # FIXME: There's a weird typing warning here that I need to investigate.
    private_key = asyncssh.generate_private_key("ssh-ed25519", comment="sshecret-sshd")
    # The file holds an unencrypted private host key. A plain open() would
    # create it with umask-derived permissions (typically world-readable), so
    # create it with mode 0o600 instead.
    fd = os.open(filename, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "wb") as f:
        f.write(private_key.export_private_key())

    return str(filename.absolute())
|
|
|
|
|
|
async def run_ssh_server(
    backend: SshecretBackend,
    listen_address: str,
    port: int,
    keys: list[str],
    registration: ClientRegistrationSettings,
    enable_ping_command: bool = False,
) -> asyncssh.SSHAcceptor:
    """Run the server.

    Args:
        backend: Backend handed to every ``AsshyncServer`` instance.
        listen_address: Address to bind the listener to.
        port: TCP port to listen on.
        keys: Paths of the server host keys.
        registration: Client registration settings.
        enable_ping_command: Whether the ping command is enabled.

    Returns:
        The acceptor returned by ``asyncssh.create_server``.
    """
    # Bind the per-server configuration; asyncssh invokes this factory to
    # obtain server objects.
    server_factory = partial(
        AsshyncServer,
        backend=backend,
        registration=registration,
        enable_ping_command=enable_ping_command,
    )
    acceptor = await asyncssh.create_server(
        server_factory,
        listen_address,
        port,
        server_host_keys=keys,
        process_factory=dispatch_command,
    )
    return acceptor
|
|
|
|
|
|
async def start_sshecret_sshd(
    settings: ServerSettings | None = None,
) -> asyncssh.SSHAcceptor:
    """Start the server.

    Resolves the host key first, generating it in the current working
    directory if missing. When ``settings`` is not supplied, a default
    ``ServerSettings()`` is constructed (presumably populated from the
    environment). Then builds the backend client and starts listening.
    """
    host_key_path = get_server_key()

    settings = settings or ServerSettings()  # pyright: ignore[reportCallIssue]

    sshecret_backend = SshecretBackend(str(settings.backend_url), settings.backend_token)
    return await run_ssh_server(
        backend=sshecret_backend,
        listen_address=settings.listen_address,
        port=settings.port,
        keys=[host_key_path],
        registration=settings.registration,
        enable_ping_command=settings.enable_ping_command,
    )
|