Complete sshd package with tests
This commit is contained in:
@ -1,21 +1,21 @@
|
||||
"""SSH Server implementation."""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
import asyncssh
|
||||
import ipaddress
|
||||
|
||||
from collections.abc import Awaitable
|
||||
from datetime import datetime, timezone
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, cast, override
|
||||
from typing import Callable, cast, override
|
||||
|
||||
from pydantic import IPvAnyNetwork
|
||||
|
||||
from . import constants
|
||||
|
||||
from sshecret.backend import AuditLog, SshecretBackend, Client
|
||||
from .settings import ServerSettings
|
||||
from sshecret.backend import SshecretBackend, Client, Operation, SubSystem
|
||||
from .settings import ServerSettings, ClientRegistrationSettings
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -35,71 +35,39 @@ class CommandError(Exception):
|
||||
def audit_process(
|
||||
backend: SshecretBackend,
|
||||
process: asyncssh.SSHServerProcess[str],
|
||||
operation: Operation,
|
||||
message: str,
|
||||
secret: str | None = None,
|
||||
**data: str,
|
||||
) -> None:
|
||||
"""Add an audit event from process."""
|
||||
command = get_process_command(process)
|
||||
client = get_info_client(process)
|
||||
username = get_info_username(process)
|
||||
remote_ip = get_info_remote_ip(process)
|
||||
operation = "SSH_EVENT"
|
||||
obj_name: str | None = None
|
||||
obj_id: str | None = None
|
||||
remote_ip = get_info_remote_ip(process) or "UNKNOWN"
|
||||
if username:
|
||||
data["username"] = username
|
||||
|
||||
if command and not secret:
|
||||
cmd, cmd_args = command
|
||||
obj_id = " ".join(cmd_args)
|
||||
elif secret:
|
||||
obj_name = "ClientSecret"
|
||||
obj_id = secret
|
||||
|
||||
entry = AuditLog(
|
||||
subsystem="ssh",
|
||||
operation=operation,
|
||||
object=obj_name,
|
||||
object_id=obj_id,
|
||||
message=message,
|
||||
origin=remote_ip,
|
||||
)
|
||||
if client:
|
||||
entry.client_id = str(client.id)
|
||||
entry.client_name = client.name
|
||||
elif username:
|
||||
entry.client_name = username
|
||||
|
||||
backend.add_audit_log_sync(entry)
|
||||
if cmd:
|
||||
data["command"] = cmd
|
||||
data["args"] = " ".join(cmd_args)
|
||||
|
||||
backend.audit(SubSystem.SSHD).write(operation, message, remote_ip, client, secret=None, secret_name=secret, **data)
|
||||
|
||||
def audit_event(
|
||||
backend: SshecretBackend,
|
||||
message: str,
|
||||
operation: str = "SSH_EVENT",
|
||||
operation: Operation,
|
||||
client: Client | None = None,
|
||||
origin: str | None = None,
|
||||
secret: str | None = None,
|
||||
) -> None:
|
||||
"""Add an audit event."""
|
||||
entry = AuditLog(
|
||||
client_id=None,
|
||||
client_name=None,
|
||||
object=None,
|
||||
object_id=None,
|
||||
subsystem="ssh",
|
||||
operation=operation,
|
||||
message=message,
|
||||
origin=origin,
|
||||
)
|
||||
if client:
|
||||
entry.client_id = str(client.id)
|
||||
entry.client_name = client.name
|
||||
|
||||
if secret:
|
||||
entry.object = "ClientSecret"
|
||||
entry.object_id = secret
|
||||
|
||||
backend.add_audit_log_sync(entry)
|
||||
|
||||
if not origin:
|
||||
origin = "UNKNOWN"
|
||||
backend.audit(SubSystem.SSHD).write(operation, message, origin, client, secret=None, secret_name=secret)
|
||||
|
||||
def verify_key_input(public_key: str) -> str | None:
|
||||
"""Verify key input."""
|
||||
@ -118,6 +86,7 @@ def get_process_command(
|
||||
if not process.command:
|
||||
return (None, [])
|
||||
argv = process.command.split(" ")
|
||||
LOG.debug("Args: %r", argv)
|
||||
return (argv[0], argv[1:])
|
||||
|
||||
|
||||
@ -149,7 +118,12 @@ def get_info_remote_ip(process: asyncssh.SSHServerProcess[str]) -> str | None:
|
||||
|
||||
return remote_ip
|
||||
|
||||
# remote_ip = str(self._conn.get_extra_info("peername")[0])
|
||||
def get_info_allowed_registration(process: asyncssh.SSHServerProcess[str]) -> list[IPvAnyNetwork] | None:
|
||||
"""Get allowed networks to allow registration from."""
|
||||
|
||||
allowed_registration = cast(list[IPvAnyNetwork] | None, process.get_extra_info("allow_registration_from", None))
|
||||
return allowed_registration
|
||||
|
||||
|
||||
|
||||
def get_optional_commands(process: asyncssh.SSHServerProcess[str]) -> dict[str, bool]:
|
||||
@ -197,12 +171,12 @@ async def register_client(
|
||||
"""Register a new client."""
|
||||
public_key = await get_stdin_public_key(process)
|
||||
if not public_key:
|
||||
raise CommandError("Aborted. No valid public key received.")
|
||||
raise CommandError(constants.ERROR_NO_PUBLIC_KEY)
|
||||
|
||||
key = asyncssh.import_public_key(public_key)
|
||||
if key.algorithm.decode() != "ssh-rsa":
|
||||
raise CommandError("Error: Only RSA keys are supported!")
|
||||
audit_process(backend, process, "Registering new client")
|
||||
raise CommandError(constants.ERROR_INVALID_KEY_TYPE)
|
||||
audit_process(backend, process, Operation.CREATE, "Registering new client")
|
||||
LOG.debug("Registering client %s with public key %s", username, public_key)
|
||||
await backend.create_client(username, public_key)
|
||||
|
||||
@ -214,14 +188,16 @@ async def get_secret(
|
||||
origin: str,
|
||||
) -> str:
|
||||
"""Handle get secret requests from client."""
|
||||
LOG.debug("Recieved command: %r", secret_name)
|
||||
if not secret_name or secret_name not in client.secrets:
|
||||
LOG.debug("Recieved command: get_secret %r", secret_name)
|
||||
if not secret_name:
|
||||
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
|
||||
if secret_name not in client.secrets:
|
||||
raise CommandError(constants.ERROR_NO_SECRET_FOUND)
|
||||
|
||||
audit_event(
|
||||
backend,
|
||||
"Client requested secret",
|
||||
operation="get_secret",
|
||||
operation=Operation.READ,
|
||||
client=client,
|
||||
origin=origin,
|
||||
secret=secret_name,
|
||||
@ -229,10 +205,13 @@ async def get_secret(
|
||||
|
||||
# Look up secret
|
||||
try:
|
||||
return await backend.get_client_secret(client.name, secret_name)
|
||||
secret = await backend.get_client_secret(client.name, secret_name)
|
||||
if not secret:
|
||||
raise CommandError(constants.ERROR_NO_SECRET_FOUND)
|
||||
return secret
|
||||
except Exception as exc:
|
||||
LOG.debug(exc, exc_info=True)
|
||||
raise CommandError("Unexpected error from backend") from exc
|
||||
raise CommandError(constants.ERROR_BACKEND_ERROR) from exc
|
||||
|
||||
|
||||
async def dispatch_no_cmd(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||
@ -249,10 +228,32 @@ async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None
|
||||
"""Dispatch the register command."""
|
||||
backend = get_info_backend(process)
|
||||
if not backend:
|
||||
raise CommandError("Unexpected error: Backend disappeared.")
|
||||
raise CommandError(constants.ERROR_INFO_BACKEND_GONE)
|
||||
username = get_info_username(process)
|
||||
if not username:
|
||||
raise CommandError("Unexpected error: Username was lost.")
|
||||
raise CommandError(constants.ERROR_INFO_USERNAME_GONE)
|
||||
|
||||
allowed_networks = get_info_allowed_registration(process)
|
||||
if not allowed_networks:
|
||||
process.stdout.write("Unauthorized.\n")
|
||||
audit_process(backend, process, Operation.DENY, "Received registration command, but no subnets are allowed.")
|
||||
return
|
||||
|
||||
remote_ip = get_info_remote_ip(process)
|
||||
|
||||
if not remote_ip:
|
||||
raise CommandError(constants.ERROR_INFO_REMOTE_IP_GONE)
|
||||
|
||||
client_address = ipaddress.ip_address(remote_ip)
|
||||
|
||||
for network in allowed_networks:
|
||||
if client_address in network:
|
||||
break
|
||||
else:
|
||||
audit_process(backend, process, Operation.DENY, "Received registration command from unauthorized subnet.")
|
||||
process.stdout.write("Unauthorized.\n")
|
||||
return
|
||||
|
||||
await register_client(process, backend, username)
|
||||
|
||||
process.stdout.write("Client registered\n.")
|
||||
@ -262,7 +263,7 @@ async def dispatch_cmd_get_secret(process: asyncssh.SSHServerProcess[str]) -> No
|
||||
"""Dispatch the get_secret command."""
|
||||
backend = get_info_backend(process)
|
||||
if not backend:
|
||||
raise CommandError("Unexpected error: Backend disappeared.")
|
||||
raise CommandError(constants.ERROR_INFO_BACKEND_GONE)
|
||||
|
||||
client = get_info_client(process)
|
||||
if not client:
|
||||
@ -311,6 +312,7 @@ async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||
process.stderr.write(str(e))
|
||||
exit_code = 1
|
||||
|
||||
LOG.debug("Command processing finished.")
|
||||
process.exit(exit_code)
|
||||
|
||||
|
||||
@ -319,16 +321,18 @@ class AsshyncServer(asyncssh.SSHServer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backend_url: str,
|
||||
backend_token: str,
|
||||
with_register: bool = True,
|
||||
with_ping: bool = True,
|
||||
backend: SshecretBackend,
|
||||
registration: ClientRegistrationSettings,
|
||||
enable_ping_command: bool = False,
|
||||
) -> None:
|
||||
"""Initialize server."""
|
||||
self.backend: SshecretBackend = SshecretBackend(backend_url, backend_token)
|
||||
self.backend: SshecretBackend = backend
|
||||
self._conn: asyncssh.SSHServerConnection | None = None
|
||||
self.registration_enabled: bool = with_register
|
||||
self.ping_enabled: bool = with_ping
|
||||
self.registration_enabled: bool = registration.enabled
|
||||
self.allow_registration_from: list[IPvAnyNetwork] | None = None
|
||||
if registration.enabled:
|
||||
self.allow_registration_from = registration.allow_from
|
||||
self.ping_enabled: bool = enable_ping_command
|
||||
self.client_ip: str | None = None
|
||||
|
||||
@override
|
||||
@ -359,9 +363,9 @@ class AsshyncServer(asyncssh.SSHServer):
|
||||
if not self._conn:
|
||||
return True
|
||||
if client := await self.backend.get_client(username):
|
||||
LOG.debug("Client lookup sucessful.")
|
||||
LOG.debug("Client lookup sucessful: %r", client)
|
||||
if key := self.resolve_client_key(client):
|
||||
LOG.debug("Loaded public key for client %s", client.name)
|
||||
LOG.debug("Loaded public key for client %s\n%s", client.name, key)
|
||||
self._conn.set_extra_info(client=client)
|
||||
self._conn.set_authorized_keys(key)
|
||||
else:
|
||||
@ -369,15 +373,18 @@ class AsshyncServer(asyncssh.SSHServer):
|
||||
audit_event(
|
||||
self.backend,
|
||||
"Client denied due to policy",
|
||||
"DENY",
|
||||
Operation.DENY,
|
||||
client,
|
||||
origin=self.client_ip,
|
||||
)
|
||||
LOG.warning("Client connection denied due to policy.")
|
||||
else:
|
||||
elif self.registration_enabled:
|
||||
self._conn.set_extra_info(provided_username=username)
|
||||
self._conn.set_extra_info(allow_registration_from=self.allow_registration_from)
|
||||
LOG.warning("Registration enabled, and client is not recognized. Bypassing authentication.")
|
||||
return False
|
||||
|
||||
LOG.debug("Continuing to regular authentication")
|
||||
return True
|
||||
|
||||
@override
|
||||
@ -403,6 +410,7 @@ class AsshyncServer(asyncssh.SSHServer):
|
||||
return None
|
||||
remote_ip = str(self._conn.get_extra_info("peername")[0])
|
||||
LOG.debug("Validating client %s connection from %s", client.name, remote_ip)
|
||||
LOG.debug("Loading client public key %r", client.public_key)
|
||||
if self.check_connection_allowed(client, remote_ip):
|
||||
return asyncssh.import_authorized_keys(client.public_key)
|
||||
return None
|
||||
@ -428,7 +436,7 @@ def get_server_key(basedir: Path | None = None) -> str:
|
||||
if filename.exists():
|
||||
return str(filename.absolute())
|
||||
# FIXME: There's a weird typing warning here that I need to investigate.
|
||||
private_key = asyncssh.generate_private_key("ssh-ed25519", "sshecret-sshd")
|
||||
private_key = asyncssh.generate_private_key("ssh-ed25519", comment="sshecret-sshd")
|
||||
with open(filename, "wb") as f:
|
||||
f.write(private_key.export_private_key())
|
||||
|
||||
@ -436,15 +444,19 @@ def get_server_key(basedir: Path | None = None) -> str:
|
||||
|
||||
|
||||
async def run_ssh_server(
|
||||
backend_url: str,
|
||||
backend_token: str,
|
||||
backend: SshecretBackend,
|
||||
listen_address: str,
|
||||
port: int,
|
||||
keys: list[str],
|
||||
registration: ClientRegistrationSettings,
|
||||
enable_ping_command: bool = False,
|
||||
) -> asyncssh.SSHAcceptor:
|
||||
"""Run the server."""
|
||||
server = partial(
|
||||
AsshyncServer, backend_url=str(backend_url), backend_token=backend_token
|
||||
AsshyncServer,
|
||||
backend=backend,
|
||||
registration=registration,
|
||||
enable_ping_command=enable_ping_command,
|
||||
)
|
||||
server = await asyncssh.create_server(
|
||||
server,
|
||||
@ -463,10 +475,12 @@ async def start_server(settings: ServerSettings | None = None) -> None:
|
||||
if not settings:
|
||||
settings = ServerSettings() # pyright: ignore[reportCallIssue]
|
||||
|
||||
backend = SshecretBackend(str(settings.backend_url), settings.backend_token)
|
||||
await run_ssh_server(
|
||||
str(settings.backend_url),
|
||||
settings.backend_token,
|
||||
settings.listen_address,
|
||||
settings.port,
|
||||
[server_key],
|
||||
backend=backend,
|
||||
listen_address=settings.listen_address,
|
||||
port=settings.port,
|
||||
keys=[server_key],
|
||||
registration=settings.registration,
|
||||
enable_ping_command=settings.enable_ping_command,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user