Check in changes to sshd module

This commit is contained in:
2025-04-30 08:25:15 +02:00
parent 2a668059ef
commit 6d37f7d251
4 changed files with 193 additions and 158 deletions

View File

@ -1,18 +1,20 @@
"""SSH Server implementation."""
import logging
import uuid
import asyncssh
import ipaddress
from collections.abc import Awaitable
from datetime import datetime, timezone
from functools import partial
from pathlib import Path
from typing import Awaitable, Callable, cast, override
from typing import Any, Callable, cast, override
from . import constants
from .backend_client import BackendClient
from .types import Client
from sshecret.backend import AuditLog, SshecretBackend, Client
from .settings import ServerSettings
@ -21,11 +23,84 @@ LOG = logging.getLogger(__name__)
CommandDispatch = Callable[[asyncssh.SSHServerProcess[str]], Awaitable[None]]
PeernameV4 = tuple[str, int]
PeernameV6 = tuple[str, int, int, int]
Peername = PeernameV4 | PeernameV6
class CommandError(Exception):
"""Error class for errors during command processing."""
def audit_process(
backend: SshecretBackend,
process: asyncssh.SSHServerProcess[str],
message: str,
secret: str | None = None,
) -> None:
"""Add an audit event from process."""
command = get_process_command(process)
client = get_info_client(process)
username = get_info_username(process)
remote_ip = get_info_remote_ip(process)
operation = "SSH_EVENT"
obj_name: str | None = None
obj_id: str | None = None
if command and not secret:
cmd, cmd_args = command
obj_id = " ".join(cmd_args)
elif secret:
obj_name = "ClientSecret"
obj_id = secret
entry = AuditLog(
subsystem="ssh",
operation=operation,
object=obj_name,
object_id=obj_id,
message=message,
origin=remote_ip,
)
if client:
entry.client_id = str(client.id)
entry.client_name = client.name
elif username:
entry.client_name = username
backend.add_audit_log_sync(entry)
def audit_event(
backend: SshecretBackend,
message: str,
operation: str = "SSH_EVENT",
client: Client | None = None,
origin: str | None = None,
secret: str | None = None,
) -> None:
"""Add an audit event."""
entry = AuditLog(
client_id=None,
client_name=None,
object=None,
object_id=None,
subsystem="ssh",
operation=operation,
message=message,
origin=origin,
)
if client:
entry.client_id = str(client.id)
entry.client_name = client.name
if secret:
entry.object = "ClientSecret"
entry.object_id = secret
backend.add_audit_log_sync(entry)
def verify_key_input(public_key: str) -> str | None:
"""Verify key input."""
try:
@ -46,9 +121,9 @@ def get_process_command(
return (argv[0], argv[1:])
def get_info_backend(process: asyncssh.SSHServerProcess[str]) -> BackendClient | None:
def get_info_backend(process: asyncssh.SSHServerProcess[str]) -> SshecretBackend | None:
"""Get backend from process."""
backend = cast("BackendClient | None", process.get_extra_info("backend", None))
backend = cast("SshecretBackend | None", process.get_extra_info("backend", None))
return backend
@ -64,6 +139,31 @@ def get_info_username(process: asyncssh.SSHServerProcess[str]) -> str | None:
return username
def get_info_remote_ip(process: asyncssh.SSHServerProcess[str]) -> str | None:
"""Get remote IP."""
peername = cast("Peername | None", process.get_extra_info("peername", None))
remote_ip: str | None = None
if peername:
remote_ip = peername[0]
return remote_ip
# remote_ip = str(self._conn.get_extra_info("peername")[0])
def get_optional_commands(process: asyncssh.SSHServerProcess[str]) -> dict[str, bool]:
"""Get optional command state."""
with_registration = cast(
bool, process.get_extra_info("registration_enabled", False)
)
with_ping = cast(bool, process.get_extra_info("ping_enabled", False))
return {
"registration": with_registration,
"ping": with_ping,
}
async def get_stdin_public_key(process: asyncssh.SSHServerProcess[str]) -> str | None:
"""Get public key from stdin."""
process.stdout.write("Enter public key:\n")
@ -76,6 +176,7 @@ async def get_stdin_public_key(process: asyncssh.SSHServerProcess[str]) -> str |
process.stdout.write("Invalid key. Must be RSA Public Key.\n")
except asyncssh.BreakReceived:
pass
process.stdout.write("OK\n")
return public_key
@ -90,7 +191,7 @@ def get_info_user_and_public_key(
async def register_client(
process: asyncssh.SSHServerProcess[str],
backend: BackendClient,
backend: SshecretBackend,
username: str,
) -> None:
"""Register a new client."""
@ -101,35 +202,49 @@ async def register_client(
key = asyncssh.import_public_key(public_key)
if key.algorithm.decode() != "ssh-rsa":
raise CommandError("Error: Only RSA keys are supported!")
audit_process(backend, process, "Registering new client")
LOG.debug("Registering client %s with public key %s", username, public_key)
await backend.register_client(username, public_key)
await backend.create_client(username, public_key)
async def get_secret(
backend: BackendClient,
backend: SshecretBackend,
client: Client,
secret_name: str,
origin: str,
) -> str:
"""Handle get secret requests from client."""
LOG.debug("Recieved command: %r", secret_name)
if not secret_name or secret_name not in client.secrets:
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
audit_event(
backend,
"Client requested secret",
operation="get_secret",
client=client,
origin=origin,
secret=secret_name,
)
# Look up secret
try:
secret = await backend.lookup_secret(client.name, secret_name)
return await backend.get_client_secret(client.name, secret_name)
except Exception as exc:
LOG.debug(exc, exc_info=True)
raise CommandError("Unexpected error from backend") from exc
return secret
async def dispatch_no_cmd(process: asyncssh.SSHServerProcess[str]) -> None:
"""Dispatch for no command."""
raise CommandError(constants.ERROR_NO_COMMAND_RECEIVED)
async def dispatch_cmd_ping(process: asyncssh.SSHServerProcess[str]) -> None:
"""Dispatch the ping command."""
process.stdout.write("PONG\n")
async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None:
"""Dispatch the register command."""
backend = get_info_backend(process)
@ -157,7 +272,8 @@ async def dispatch_cmd_get_secret(process: asyncssh.SSHServerProcess[str]) -> No
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
secret_name = args[0]
secret = await get_secret(backend, client, secret_name)
origin = get_info_remote_ip(process) or "Unknown"
secret = await get_secret(backend, client, secret_name, origin)
process.stdout.write(secret)
@ -169,9 +285,14 @@ async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
process.exit(1)
return
cmdmap: dict[str, CommandDispatch] = {
"register": dispatch_cmd_register,
"get_secret": dispatch_cmd_get_secret,
}
extra_commands = get_optional_commands(process)
if "registration" in extra_commands:
cmdmap["register"] = dispatch_cmd_register
if "ping" in extra_commands:
cmdmap["ping"] = dispatch_cmd_ping
if command not in cmdmap:
process.stderr.write(constants.ERROR_UNKNOWN_COMMAND)
process.exit(1)
@ -193,57 +314,33 @@ async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
process.exit(exit_code)
async def handle_secret(process: asyncssh.SSHServerProcess[str]) -> None:
"""Handle get secret requests from client."""
backend = process.get_extra_info("backend")
if not backend:
process.stderr.write("Unexpected Error: Lost connection with backend object.")
process.exit(1)
return
assert isinstance(backend, BackendClient)
client = process.get_extra_info("client")
if not client:
process.stderr.write(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
process.exit(1)
return
assert isinstance(client, Client), "Error: Unexpected client type received"
secret_name = process.command
LOG.debug("Recieved command: %r", secret_name)
if not secret_name or secret_name not in client.secrets:
process.stderr.write(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
process.exit(1)
return
# Look up secret
try:
secret = await backend.lookup_secret(client.name, secret_name)
except Exception as exc:
process.stderr.write("Unexpected error from backend:\n")
process.stderr.write(str(exc))
LOG.debug(exc, exc_info=True)
process.exit(1)
return
process.stdout.write(secret)
process.exit(0)
class AsshyncServer(asyncssh.SSHServer):
"""Asynchronous SSH server implementation."""
def __init__(self, settings: ServerSettings | None = None) -> None:
def __init__(
self,
backend_url: str,
backend_token: str,
with_register: bool = True,
with_ping: bool = True,
) -> None:
"""Initialize server."""
self.backend: BackendClient = BackendClient(settings)
self.backend: SshecretBackend = SshecretBackend(backend_url, backend_token)
self._conn: asyncssh.SSHServerConnection | None = None
self.registration_enabled: bool = with_register
self.ping_enabled: bool = with_ping
self.client_ip: str | None = None
@override
def connection_made(self, conn: asyncssh.SSHServerConnection) -> None:
"""Handle incoming connection."""
peername = conn.get_extra_info("peername")
LOG.debug("Connection established from %r", peername)
self.client_ip = peername[0]
self._conn = conn
self._conn.set_extra_info(backend=self.backend)
self._conn.set_extra_info(registration_enabled=self.registration_enabled)
self._conn.set_extra_info(ping_enabled=self.ping_enabled)
@override
def password_auth_supported(self) -> bool:
@ -261,13 +358,21 @@ class AsshyncServer(asyncssh.SSHServer):
LOG.debug("Started authentication flow for user %s", username)
if not self._conn:
return True
if client := await self.backend.lookup_client(username):
if client := await self.backend.get_client(username):
LOG.debug("Client lookup sucessful.")
if key := self.resolve_client_key(client):
LOG.debug("Loaded public key for client %s", client.name)
self._conn.set_extra_info(client=client)
self._conn.set_authorized_keys(key)
else:
audit_event(
self.backend,
"Client denied due to policy",
"DENY",
client,
origin=self.client_ip,
)
LOG.warning("Client connection denied due to policy.")
else:
self._conn.set_extra_info(provided_username=username)
@ -308,37 +413,60 @@ class AsshyncServer(asyncssh.SSHServer):
policies = [ipaddress.ip_network(policy) for policy in client.policies]
valid_source = [source_ip in policy for policy in policies]
LOG.debug("Valid sources %r from policies %r", valid_source, policies)
return any(valid_source)
def get_server_key() -> str:
def get_server_key(basedir: Path | None = None) -> str:
"""Resolve server key.
TODO: Is one key enough? Should we generate more keys?
"""
filename = f"ssh_host_{constants.SERVER_KEY_TYPE}_key"
if Path(filename).exists():
return filename
filename = Path(f"ssh_host_{constants.SERVER_KEY_TYPE}_key")
if basedir:
filename = basedir / filename
if filename.exists():
return str(filename.absolute())
# FIXME: There's a weird typing warning here that I need to investigate.
private_key = asyncssh.generate_private_key("ssh-ed25519", "sshecret-sshd")
with open(filename, "wb") as f:
f.write(private_key.export_private_key())
return filename
return str(filename.absolute())
async def run_ssh_server(
backend_url: str,
backend_token: str,
listen_address: str,
port: int,
keys: list[str],
) -> asyncssh.SSHAcceptor:
"""Run the server."""
server = partial(
AsshyncServer, backend_url=str(backend_url), backend_token=backend_token
)
server = await asyncssh.create_server(
server,
listen_address,
port,
server_host_keys=keys,
process_factory=dispatch_command,
)
return server
async def start_server(settings: ServerSettings | None = None) -> None:
"""Start the server."""
server_key = get_server_key()
server = partial(AsshyncServer, settings=settings)
if not settings:
settings = ServerSettings() # pyright: ignore[reportCallIssue]
await asyncssh.create_server(
server,
await run_ssh_server(
str(settings.backend_url),
settings.backend_token,
settings.listen_address,
settings.port,
server_host_keys=[server_key],
process_factory=dispatch_command,
[server_key],
)