Complete sshd package with tests

This commit is contained in:
2025-05-10 08:27:16 +02:00
parent 3719a2611d
commit 4f970a3f71
12 changed files with 472 additions and 103 deletions

View File

@ -2,8 +2,10 @@
import logging
import asyncio
import sys
from pydantic_settings import CliApp
from typing import cast
import click
from pydantic import ValidationError
from .settings import ServerSettings
from .ssh_server import start_server
@ -17,22 +19,38 @@ LOG.addHandler(handler)
LOG.setLevel(logging.INFO)
def cli(args: list[str] | None = None) -> None:
"""Run CLI app."""
try:
settings = ServerSettings()
except Exception:
print("One or more settings could not be resolved.")
CliApp.run(ServerSettings, ["--help"])
sys.exit(1)
if settings.debug:
@click.group()
@click.option("--debug", is_flag=True)
@click.pass_context
def cli(ctx: click.Context, debug: bool) -> None:
"""Sshecret Admin."""
if debug:
LOG.setLevel(logging.DEBUG)
try:
settings = ServerSettings() # pyright: ignore[reportCallIssue]
except ValidationError as e:
raise click.ClickException(
"Error: One or more required environment options are missing."
) from e
ctx.obj = settings
@cli.command("run")
@click.option("--host")
@click.option("--port", type=click.INT)
@click.pass_context
def cli_run(ctx: click.Context, host: str | None, port: int | None) -> None:
"""Run the server."""
settings = cast(ServerSettings, ctx.obj)
if host:
settings.listen_address = host
if port:
settings.port = port
loop = asyncio.new_event_loop()
loop.run_until_complete(start_server(settings))
print(f"Starting SSH server: {settings.listen_address}:{settings.port}")
title = click.style("Sshecret SSH Daemon", fg="red", bold=True)
click.echo(f"Starting {title}: {settings.listen_address}:{settings.port}")
try:
loop.run_forever()
except KeyboardInterrupt:
@ -40,7 +58,6 @@ def cli(args: list[str] | None = None) -> None:
sys.exit()
if __name__ == "__main__":
"""Run CLI app."""
cli()

View File

@ -4,5 +4,11 @@ ERROR_NO_COMMAND_RECEIVED = "Error: No command was received from the client."
ERROR_SOURCE_IP_NOT_ALLOWED = (
"Error: Client not authorized to connect from the given host."
)
ERROR_NO_PUBLIC_KEY = "Error: No valid public key received."
ERROR_INVALID_KEY_TYPE = "Error: Invalid key type: Only RSA keys are supported."
ERROR_UNKNOWN_COMMAND = "Error: The given command was not understood."
SERVER_KEY_TYPE = "ed25519"
ERROR_BACKEND_ERROR = "Error: Unexpected response or error from backend"
ERROR_INFO_BACKEND_GONE = "Unexpected error: Backend connection lost."
ERROR_INFO_USERNAME_GONE = "Unexpected error: Username lost in transit."
ERROR_INFO_REMOTE_IP_GONE = "Unexpected error: Client connection details lost in transit."

View File

@ -1,18 +1,48 @@
"""SSH Server settings."""
from pydantic import AnyHttpUrl, Field, AliasChoices
from pydantic_settings import BaseSettings, SettingsConfigDict
import ipaddress
from typing import Annotated, Any
from pydantic import AnyHttpUrl, BaseModel, Field, IPvAnyNetwork, field_validator
from pydantic_settings import BaseSettings, ForceDecode, SettingsConfigDict
DEFAULT_LISTEN_PORT = 2222
class ServerSettings(BaseSettings, cli_parse_args=True, cli_exit_on_error=True):
class ClientRegistrationSettings(BaseModel):
"""Client registration settings."""
enabled: bool = False
allow_from: Annotated[list[IPvAnyNetwork], ForceDecode] = Field(default_factory=list)
@field_validator('allow_from', mode="before")
@classmethod
def ensure_allow_from_list(cls, value: Any) -> list[IPvAnyNetwork]:
"""Convert allow_from to a list."""
allow_from: list[IPvAnyNetwork] = []
if isinstance(value, list):
entries = value
elif isinstance(value, str):
entries = value.split(",")
else:
raise ValueError("Error: Unknown format for allowed_from.")
for entry in entries:
if isinstance(entry, str):
allow_from.append(ipaddress.ip_network(entry))
elif isinstance(entry, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
allow_from.append(entry)
return allow_from
class ServerSettings(BaseSettings):
"""Server Settings."""
model_config = SettingsConfigDict(env_file=".sshd.env", env_prefix="sshecret_sshd_")
model_config = SettingsConfigDict(env_file=".sshd.env", env_prefix="sshecret_sshd_", env_nested_delimiter='_')
backend_url: AnyHttpUrl = Field(validation_alias=AliasChoices("backend-url", "sshecret_backend_url"))
backend_token: str = Field(validation_alias=AliasChoices("backend-token", "sshecret_sshd_backend_token"))
listen_address: str = Field(default="", alias="listen")
backend_url: AnyHttpUrl = Field(alias="sshecret_backend_url")
backend_token: str
listen_address: str = Field(default="127.0.0.1")
port: int = DEFAULT_LISTEN_PORT
registration: ClientRegistrationSettings = Field(default_factory=ClientRegistrationSettings)
debug: bool = False
enable_ping_command: bool = False

View File

@ -1,21 +1,21 @@
"""SSH Server implementation."""
import logging
import uuid
import asyncssh
import ipaddress
from collections.abc import Awaitable
from datetime import datetime, timezone
from functools import partial
from pathlib import Path
from typing import Any, Callable, cast, override
from typing import Callable, cast, override
from pydantic import IPvAnyNetwork
from . import constants
from sshecret.backend import AuditLog, SshecretBackend, Client
from .settings import ServerSettings
from sshecret.backend import SshecretBackend, Client, Operation, SubSystem
from .settings import ServerSettings, ClientRegistrationSettings
LOG = logging.getLogger(__name__)
@ -35,71 +35,39 @@ class CommandError(Exception):
def audit_process(
backend: SshecretBackend,
process: asyncssh.SSHServerProcess[str],
operation: Operation,
message: str,
secret: str | None = None,
**data: str,
) -> None:
"""Add an audit event from process."""
command = get_process_command(process)
client = get_info_client(process)
username = get_info_username(process)
remote_ip = get_info_remote_ip(process)
operation = "SSH_EVENT"
obj_name: str | None = None
obj_id: str | None = None
remote_ip = get_info_remote_ip(process) or "UNKNOWN"
if username:
data["username"] = username
if command and not secret:
cmd, cmd_args = command
obj_id = " ".join(cmd_args)
elif secret:
obj_name = "ClientSecret"
obj_id = secret
entry = AuditLog(
subsystem="ssh",
operation=operation,
object=obj_name,
object_id=obj_id,
message=message,
origin=remote_ip,
)
if client:
entry.client_id = str(client.id)
entry.client_name = client.name
elif username:
entry.client_name = username
backend.add_audit_log_sync(entry)
if cmd:
data["command"] = cmd
data["args"] = " ".join(cmd_args)
backend.audit(SubSystem.SSHD).write(operation, message, remote_ip, client, secret=None, secret_name=secret, **data)
def audit_event(
backend: SshecretBackend,
message: str,
operation: str = "SSH_EVENT",
operation: Operation,
client: Client | None = None,
origin: str | None = None,
secret: str | None = None,
) -> None:
"""Add an audit event."""
entry = AuditLog(
client_id=None,
client_name=None,
object=None,
object_id=None,
subsystem="ssh",
operation=operation,
message=message,
origin=origin,
)
if client:
entry.client_id = str(client.id)
entry.client_name = client.name
if secret:
entry.object = "ClientSecret"
entry.object_id = secret
backend.add_audit_log_sync(entry)
if not origin:
origin = "UNKNOWN"
backend.audit(SubSystem.SSHD).write(operation, message, origin, client, secret=None, secret_name=secret)
def verify_key_input(public_key: str) -> str | None:
"""Verify key input."""
@ -118,6 +86,7 @@ def get_process_command(
if not process.command:
return (None, [])
argv = process.command.split(" ")
LOG.debug("Args: %r", argv)
return (argv[0], argv[1:])
@ -149,7 +118,12 @@ def get_info_remote_ip(process: asyncssh.SSHServerProcess[str]) -> str | None:
return remote_ip
# remote_ip = str(self._conn.get_extra_info("peername")[0])
def get_info_allowed_registration(process: asyncssh.SSHServerProcess[str]) -> list[IPvAnyNetwork] | None:
"""Get allowed networks to allow registration from."""
allowed_registration = cast(list[IPvAnyNetwork] | None, process.get_extra_info("allow_registration_from", None))
return allowed_registration
def get_optional_commands(process: asyncssh.SSHServerProcess[str]) -> dict[str, bool]:
@ -197,12 +171,12 @@ async def register_client(
"""Register a new client."""
public_key = await get_stdin_public_key(process)
if not public_key:
raise CommandError("Aborted. No valid public key received.")
raise CommandError(constants.ERROR_NO_PUBLIC_KEY)
key = asyncssh.import_public_key(public_key)
if key.algorithm.decode() != "ssh-rsa":
raise CommandError("Error: Only RSA keys are supported!")
audit_process(backend, process, "Registering new client")
raise CommandError(constants.ERROR_INVALID_KEY_TYPE)
audit_process(backend, process, Operation.CREATE, "Registering new client")
LOG.debug("Registering client %s with public key %s", username, public_key)
await backend.create_client(username, public_key)
@ -214,14 +188,16 @@ async def get_secret(
origin: str,
) -> str:
"""Handle get secret requests from client."""
LOG.debug("Recieved command: %r", secret_name)
if not secret_name or secret_name not in client.secrets:
LOG.debug("Recieved command: get_secret %r", secret_name)
if not secret_name:
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
if secret_name not in client.secrets:
raise CommandError(constants.ERROR_NO_SECRET_FOUND)
audit_event(
backend,
"Client requested secret",
operation="get_secret",
operation=Operation.READ,
client=client,
origin=origin,
secret=secret_name,
@ -229,10 +205,13 @@ async def get_secret(
# Look up secret
try:
return await backend.get_client_secret(client.name, secret_name)
secret = await backend.get_client_secret(client.name, secret_name)
if not secret:
raise CommandError(constants.ERROR_NO_SECRET_FOUND)
return secret
except Exception as exc:
LOG.debug(exc, exc_info=True)
raise CommandError("Unexpected error from backend") from exc
raise CommandError(constants.ERROR_BACKEND_ERROR) from exc
async def dispatch_no_cmd(process: asyncssh.SSHServerProcess[str]) -> None:
@ -249,10 +228,32 @@ async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None
"""Dispatch the register command."""
backend = get_info_backend(process)
if not backend:
raise CommandError("Unexpected error: Backend disappeared.")
raise CommandError(constants.ERROR_INFO_BACKEND_GONE)
username = get_info_username(process)
if not username:
raise CommandError("Unexpected error: Username was lost.")
raise CommandError(constants.ERROR_INFO_USERNAME_GONE)
allowed_networks = get_info_allowed_registration(process)
if not allowed_networks:
process.stdout.write("Unauthorized.\n")
audit_process(backend, process, Operation.DENY, "Received registration command, but no subnets are allowed.")
return
remote_ip = get_info_remote_ip(process)
if not remote_ip:
raise CommandError(constants.ERROR_INFO_REMOTE_IP_GONE)
client_address = ipaddress.ip_address(remote_ip)
for network in allowed_networks:
if client_address in network:
break
else:
audit_process(backend, process, Operation.DENY, "Received registration command from unauthorized subnet.")
process.stdout.write("Unauthorized.\n")
return
await register_client(process, backend, username)
process.stdout.write("Client registered\n.")
@ -262,7 +263,7 @@ async def dispatch_cmd_get_secret(process: asyncssh.SSHServerProcess[str]) -> No
"""Dispatch the get_secret command."""
backend = get_info_backend(process)
if not backend:
raise CommandError("Unexpected error: Backend disappeared.")
raise CommandError(constants.ERROR_INFO_BACKEND_GONE)
client = get_info_client(process)
if not client:
@ -311,6 +312,7 @@ async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
process.stderr.write(str(e))
exit_code = 1
LOG.debug("Command processing finished.")
process.exit(exit_code)
@ -319,16 +321,18 @@ class AsshyncServer(asyncssh.SSHServer):
def __init__(
self,
backend_url: str,
backend_token: str,
with_register: bool = True,
with_ping: bool = True,
backend: SshecretBackend,
registration: ClientRegistrationSettings,
enable_ping_command: bool = False,
) -> None:
"""Initialize server."""
self.backend: SshecretBackend = SshecretBackend(backend_url, backend_token)
self.backend: SshecretBackend = backend
self._conn: asyncssh.SSHServerConnection | None = None
self.registration_enabled: bool = with_register
self.ping_enabled: bool = with_ping
self.registration_enabled: bool = registration.enabled
self.allow_registration_from: list[IPvAnyNetwork] | None = None
if registration.enabled:
self.allow_registration_from = registration.allow_from
self.ping_enabled: bool = enable_ping_command
self.client_ip: str | None = None
@override
@ -359,9 +363,9 @@ class AsshyncServer(asyncssh.SSHServer):
if not self._conn:
return True
if client := await self.backend.get_client(username):
LOG.debug("Client lookup sucessful.")
LOG.debug("Client lookup sucessful: %r", client)
if key := self.resolve_client_key(client):
LOG.debug("Loaded public key for client %s", client.name)
LOG.debug("Loaded public key for client %s\n%s", client.name, key)
self._conn.set_extra_info(client=client)
self._conn.set_authorized_keys(key)
else:
@ -369,15 +373,18 @@ class AsshyncServer(asyncssh.SSHServer):
audit_event(
self.backend,
"Client denied due to policy",
"DENY",
Operation.DENY,
client,
origin=self.client_ip,
)
LOG.warning("Client connection denied due to policy.")
else:
elif self.registration_enabled:
self._conn.set_extra_info(provided_username=username)
self._conn.set_extra_info(allow_registration_from=self.allow_registration_from)
LOG.warning("Registration enabled, and client is not recognized. Bypassing authentication.")
return False
LOG.debug("Continuing to regular authentication")
return True
@override
@ -403,6 +410,7 @@ class AsshyncServer(asyncssh.SSHServer):
return None
remote_ip = str(self._conn.get_extra_info("peername")[0])
LOG.debug("Validating client %s connection from %s", client.name, remote_ip)
LOG.debug("Loading client public key %r", client.public_key)
if self.check_connection_allowed(client, remote_ip):
return asyncssh.import_authorized_keys(client.public_key)
return None
@ -428,7 +436,7 @@ def get_server_key(basedir: Path | None = None) -> str:
if filename.exists():
return str(filename.absolute())
# FIXME: There's a weird typing warning here that I need to investigate.
private_key = asyncssh.generate_private_key("ssh-ed25519", "sshecret-sshd")
private_key = asyncssh.generate_private_key("ssh-ed25519", comment="sshecret-sshd")
with open(filename, "wb") as f:
f.write(private_key.export_private_key())
@ -436,15 +444,19 @@ def get_server_key(basedir: Path | None = None) -> str:
async def run_ssh_server(
backend_url: str,
backend_token: str,
backend: SshecretBackend,
listen_address: str,
port: int,
keys: list[str],
registration: ClientRegistrationSettings,
enable_ping_command: bool = False,
) -> asyncssh.SSHAcceptor:
"""Run the server."""
server = partial(
AsshyncServer, backend_url=str(backend_url), backend_token=backend_token
AsshyncServer,
backend=backend,
registration=registration,
enable_ping_command=enable_ping_command,
)
server = await asyncssh.create_server(
server,
@ -463,10 +475,12 @@ async def start_server(settings: ServerSettings | None = None) -> None:
if not settings:
settings = ServerSettings() # pyright: ignore[reportCallIssue]
backend = SshecretBackend(str(settings.backend_url), settings.backend_token)
await run_ssh_server(
str(settings.backend_url),
settings.backend_token,
settings.listen_address,
settings.port,
[server_key],
backend=backend,
listen_address=settings.listen_address,
port=settings.port,
keys=[server_key],
registration=settings.registration,
enable_ping_command=settings.enable_ping_command,
)