Check in changes to sshd module
This commit is contained in:
@ -1,18 +1,20 @@
|
||||
"""SSH Server implementation."""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
import asyncssh
|
||||
import ipaddress
|
||||
|
||||
from collections.abc import Awaitable
|
||||
from datetime import datetime, timezone
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Awaitable, Callable, cast, override
|
||||
from typing import Any, Callable, cast, override
|
||||
|
||||
from . import constants
|
||||
|
||||
from .backend_client import BackendClient
|
||||
from .types import Client
|
||||
from sshecret.backend import AuditLog, SshecretBackend, Client
|
||||
from .settings import ServerSettings
|
||||
|
||||
|
||||
@ -21,11 +23,84 @@ LOG = logging.getLogger(__name__)
|
||||
|
||||
CommandDispatch = Callable[[asyncssh.SSHServerProcess[str]], Awaitable[None]]
|
||||
|
||||
PeernameV4 = tuple[str, int]
|
||||
PeernameV6 = tuple[str, int, int, int]
|
||||
Peername = PeernameV4 | PeernameV6
|
||||
|
||||
|
||||
class CommandError(Exception):
|
||||
"""Error class for errors during command processing."""
|
||||
|
||||
|
||||
def audit_process(
|
||||
backend: SshecretBackend,
|
||||
process: asyncssh.SSHServerProcess[str],
|
||||
message: str,
|
||||
secret: str | None = None,
|
||||
) -> None:
|
||||
"""Add an audit event from process."""
|
||||
command = get_process_command(process)
|
||||
client = get_info_client(process)
|
||||
username = get_info_username(process)
|
||||
remote_ip = get_info_remote_ip(process)
|
||||
operation = "SSH_EVENT"
|
||||
obj_name: str | None = None
|
||||
obj_id: str | None = None
|
||||
|
||||
if command and not secret:
|
||||
cmd, cmd_args = command
|
||||
obj_id = " ".join(cmd_args)
|
||||
elif secret:
|
||||
obj_name = "ClientSecret"
|
||||
obj_id = secret
|
||||
|
||||
entry = AuditLog(
|
||||
subsystem="ssh",
|
||||
operation=operation,
|
||||
object=obj_name,
|
||||
object_id=obj_id,
|
||||
message=message,
|
||||
origin=remote_ip,
|
||||
)
|
||||
if client:
|
||||
entry.client_id = str(client.id)
|
||||
entry.client_name = client.name
|
||||
elif username:
|
||||
entry.client_name = username
|
||||
|
||||
backend.add_audit_log_sync(entry)
|
||||
|
||||
|
||||
def audit_event(
|
||||
backend: SshecretBackend,
|
||||
message: str,
|
||||
operation: str = "SSH_EVENT",
|
||||
client: Client | None = None,
|
||||
origin: str | None = None,
|
||||
secret: str | None = None,
|
||||
) -> None:
|
||||
"""Add an audit event."""
|
||||
entry = AuditLog(
|
||||
client_id=None,
|
||||
client_name=None,
|
||||
object=None,
|
||||
object_id=None,
|
||||
subsystem="ssh",
|
||||
operation=operation,
|
||||
message=message,
|
||||
origin=origin,
|
||||
)
|
||||
if client:
|
||||
entry.client_id = str(client.id)
|
||||
entry.client_name = client.name
|
||||
|
||||
if secret:
|
||||
entry.object = "ClientSecret"
|
||||
entry.object_id = secret
|
||||
|
||||
backend.add_audit_log_sync(entry)
|
||||
|
||||
|
||||
def verify_key_input(public_key: str) -> str | None:
|
||||
"""Verify key input."""
|
||||
try:
|
||||
@ -46,9 +121,9 @@ def get_process_command(
|
||||
return (argv[0], argv[1:])
|
||||
|
||||
|
||||
def get_info_backend(process: asyncssh.SSHServerProcess[str]) -> BackendClient | None:
|
||||
def get_info_backend(process: asyncssh.SSHServerProcess[str]) -> SshecretBackend | None:
|
||||
"""Get backend from process."""
|
||||
backend = cast("BackendClient | None", process.get_extra_info("backend", None))
|
||||
backend = cast("SshecretBackend | None", process.get_extra_info("backend", None))
|
||||
return backend
|
||||
|
||||
|
||||
@ -64,6 +139,31 @@ def get_info_username(process: asyncssh.SSHServerProcess[str]) -> str | None:
|
||||
return username
|
||||
|
||||
|
||||
def get_info_remote_ip(process: asyncssh.SSHServerProcess[str]) -> str | None:
|
||||
"""Get remote IP."""
|
||||
|
||||
peername = cast("Peername | None", process.get_extra_info("peername", None))
|
||||
remote_ip: str | None = None
|
||||
if peername:
|
||||
remote_ip = peername[0]
|
||||
|
||||
return remote_ip
|
||||
|
||||
# remote_ip = str(self._conn.get_extra_info("peername")[0])
|
||||
|
||||
|
||||
def get_optional_commands(process: asyncssh.SSHServerProcess[str]) -> dict[str, bool]:
|
||||
"""Get optional command state."""
|
||||
with_registration = cast(
|
||||
bool, process.get_extra_info("registration_enabled", False)
|
||||
)
|
||||
with_ping = cast(bool, process.get_extra_info("ping_enabled", False))
|
||||
return {
|
||||
"registration": with_registration,
|
||||
"ping": with_ping,
|
||||
}
|
||||
|
||||
|
||||
async def get_stdin_public_key(process: asyncssh.SSHServerProcess[str]) -> str | None:
|
||||
"""Get public key from stdin."""
|
||||
process.stdout.write("Enter public key:\n")
|
||||
@ -76,6 +176,7 @@ async def get_stdin_public_key(process: asyncssh.SSHServerProcess[str]) -> str |
|
||||
process.stdout.write("Invalid key. Must be RSA Public Key.\n")
|
||||
except asyncssh.BreakReceived:
|
||||
pass
|
||||
process.stdout.write("OK\n")
|
||||
return public_key
|
||||
|
||||
|
||||
@ -90,7 +191,7 @@ def get_info_user_and_public_key(
|
||||
|
||||
async def register_client(
|
||||
process: asyncssh.SSHServerProcess[str],
|
||||
backend: BackendClient,
|
||||
backend: SshecretBackend,
|
||||
username: str,
|
||||
) -> None:
|
||||
"""Register a new client."""
|
||||
@ -101,35 +202,49 @@ async def register_client(
|
||||
key = asyncssh.import_public_key(public_key)
|
||||
if key.algorithm.decode() != "ssh-rsa":
|
||||
raise CommandError("Error: Only RSA keys are supported!")
|
||||
audit_process(backend, process, "Registering new client")
|
||||
LOG.debug("Registering client %s with public key %s", username, public_key)
|
||||
await backend.register_client(username, public_key)
|
||||
await backend.create_client(username, public_key)
|
||||
|
||||
|
||||
async def get_secret(
|
||||
backend: BackendClient,
|
||||
backend: SshecretBackend,
|
||||
client: Client,
|
||||
secret_name: str,
|
||||
origin: str,
|
||||
) -> str:
|
||||
"""Handle get secret requests from client."""
|
||||
LOG.debug("Recieved command: %r", secret_name)
|
||||
if not secret_name or secret_name not in client.secrets:
|
||||
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
|
||||
|
||||
audit_event(
|
||||
backend,
|
||||
"Client requested secret",
|
||||
operation="get_secret",
|
||||
client=client,
|
||||
origin=origin,
|
||||
secret=secret_name,
|
||||
)
|
||||
|
||||
# Look up secret
|
||||
try:
|
||||
secret = await backend.lookup_secret(client.name, secret_name)
|
||||
return await backend.get_client_secret(client.name, secret_name)
|
||||
except Exception as exc:
|
||||
LOG.debug(exc, exc_info=True)
|
||||
raise CommandError("Unexpected error from backend") from exc
|
||||
|
||||
return secret
|
||||
|
||||
|
||||
async def dispatch_no_cmd(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||
"""Dispatch for no command."""
|
||||
raise CommandError(constants.ERROR_NO_COMMAND_RECEIVED)
|
||||
|
||||
|
||||
async def dispatch_cmd_ping(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||
"""Dispatch the ping command."""
|
||||
process.stdout.write("PONG\n")
|
||||
|
||||
|
||||
async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||
"""Dispatch the register command."""
|
||||
backend = get_info_backend(process)
|
||||
@ -157,7 +272,8 @@ async def dispatch_cmd_get_secret(process: asyncssh.SSHServerProcess[str]) -> No
|
||||
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
|
||||
secret_name = args[0]
|
||||
|
||||
secret = await get_secret(backend, client, secret_name)
|
||||
origin = get_info_remote_ip(process) or "Unknown"
|
||||
secret = await get_secret(backend, client, secret_name, origin)
|
||||
process.stdout.write(secret)
|
||||
|
||||
|
||||
@ -169,9 +285,14 @@ async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||
process.exit(1)
|
||||
return
|
||||
cmdmap: dict[str, CommandDispatch] = {
|
||||
"register": dispatch_cmd_register,
|
||||
"get_secret": dispatch_cmd_get_secret,
|
||||
}
|
||||
extra_commands = get_optional_commands(process)
|
||||
if "registration" in extra_commands:
|
||||
cmdmap["register"] = dispatch_cmd_register
|
||||
if "ping" in extra_commands:
|
||||
cmdmap["ping"] = dispatch_cmd_ping
|
||||
|
||||
if command not in cmdmap:
|
||||
process.stderr.write(constants.ERROR_UNKNOWN_COMMAND)
|
||||
process.exit(1)
|
||||
@ -193,57 +314,33 @@ async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||
process.exit(exit_code)
|
||||
|
||||
|
||||
async def handle_secret(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||
"""Handle get secret requests from client."""
|
||||
backend = process.get_extra_info("backend")
|
||||
if not backend:
|
||||
process.stderr.write("Unexpected Error: Lost connection with backend object.")
|
||||
process.exit(1)
|
||||
return
|
||||
|
||||
assert isinstance(backend, BackendClient)
|
||||
|
||||
client = process.get_extra_info("client")
|
||||
if not client:
|
||||
process.stderr.write(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
|
||||
process.exit(1)
|
||||
return
|
||||
assert isinstance(client, Client), "Error: Unexpected client type received"
|
||||
secret_name = process.command
|
||||
LOG.debug("Recieved command: %r", secret_name)
|
||||
if not secret_name or secret_name not in client.secrets:
|
||||
process.stderr.write(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
|
||||
process.exit(1)
|
||||
return
|
||||
|
||||
# Look up secret
|
||||
try:
|
||||
secret = await backend.lookup_secret(client.name, secret_name)
|
||||
except Exception as exc:
|
||||
process.stderr.write("Unexpected error from backend:\n")
|
||||
process.stderr.write(str(exc))
|
||||
LOG.debug(exc, exc_info=True)
|
||||
process.exit(1)
|
||||
return
|
||||
process.stdout.write(secret)
|
||||
process.exit(0)
|
||||
|
||||
|
||||
class AsshyncServer(asyncssh.SSHServer):
|
||||
"""Asynchronous SSH server implementation."""
|
||||
|
||||
def __init__(self, settings: ServerSettings | None = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
backend_url: str,
|
||||
backend_token: str,
|
||||
with_register: bool = True,
|
||||
with_ping: bool = True,
|
||||
) -> None:
|
||||
"""Initialize server."""
|
||||
self.backend: BackendClient = BackendClient(settings)
|
||||
self.backend: SshecretBackend = SshecretBackend(backend_url, backend_token)
|
||||
self._conn: asyncssh.SSHServerConnection | None = None
|
||||
self.registration_enabled: bool = with_register
|
||||
self.ping_enabled: bool = with_ping
|
||||
self.client_ip: str | None = None
|
||||
|
||||
@override
|
||||
def connection_made(self, conn: asyncssh.SSHServerConnection) -> None:
|
||||
"""Handle incoming connection."""
|
||||
peername = conn.get_extra_info("peername")
|
||||
LOG.debug("Connection established from %r", peername)
|
||||
self.client_ip = peername[0]
|
||||
self._conn = conn
|
||||
self._conn.set_extra_info(backend=self.backend)
|
||||
self._conn.set_extra_info(registration_enabled=self.registration_enabled)
|
||||
self._conn.set_extra_info(ping_enabled=self.ping_enabled)
|
||||
|
||||
@override
|
||||
def password_auth_supported(self) -> bool:
|
||||
@ -261,13 +358,21 @@ class AsshyncServer(asyncssh.SSHServer):
|
||||
LOG.debug("Started authentication flow for user %s", username)
|
||||
if not self._conn:
|
||||
return True
|
||||
if client := await self.backend.lookup_client(username):
|
||||
if client := await self.backend.get_client(username):
|
||||
LOG.debug("Client lookup sucessful.")
|
||||
if key := self.resolve_client_key(client):
|
||||
LOG.debug("Loaded public key for client %s", client.name)
|
||||
self._conn.set_extra_info(client=client)
|
||||
self._conn.set_authorized_keys(key)
|
||||
else:
|
||||
|
||||
audit_event(
|
||||
self.backend,
|
||||
"Client denied due to policy",
|
||||
"DENY",
|
||||
client,
|
||||
origin=self.client_ip,
|
||||
)
|
||||
LOG.warning("Client connection denied due to policy.")
|
||||
else:
|
||||
self._conn.set_extra_info(provided_username=username)
|
||||
@ -308,37 +413,60 @@ class AsshyncServer(asyncssh.SSHServer):
|
||||
policies = [ipaddress.ip_network(policy) for policy in client.policies]
|
||||
|
||||
valid_source = [source_ip in policy for policy in policies]
|
||||
LOG.debug("Valid sources %r from policies %r", valid_source, policies)
|
||||
return any(valid_source)
|
||||
|
||||
|
||||
def get_server_key() -> str:
|
||||
def get_server_key(basedir: Path | None = None) -> str:
|
||||
"""Resolve server key.
|
||||
|
||||
TODO: Is one key enough? Should we generate more keys?
|
||||
"""
|
||||
filename = f"ssh_host_{constants.SERVER_KEY_TYPE}_key"
|
||||
if Path(filename).exists():
|
||||
return filename
|
||||
filename = Path(f"ssh_host_{constants.SERVER_KEY_TYPE}_key")
|
||||
if basedir:
|
||||
filename = basedir / filename
|
||||
if filename.exists():
|
||||
return str(filename.absolute())
|
||||
# FIXME: There's a weird typing warning here that I need to investigate.
|
||||
private_key = asyncssh.generate_private_key("ssh-ed25519", "sshecret-sshd")
|
||||
with open(filename, "wb") as f:
|
||||
f.write(private_key.export_private_key())
|
||||
|
||||
return filename
|
||||
return str(filename.absolute())
|
||||
|
||||
|
||||
async def run_ssh_server(
|
||||
backend_url: str,
|
||||
backend_token: str,
|
||||
listen_address: str,
|
||||
port: int,
|
||||
keys: list[str],
|
||||
) -> asyncssh.SSHAcceptor:
|
||||
"""Run the server."""
|
||||
server = partial(
|
||||
AsshyncServer, backend_url=str(backend_url), backend_token=backend_token
|
||||
)
|
||||
server = await asyncssh.create_server(
|
||||
server,
|
||||
listen_address,
|
||||
port,
|
||||
server_host_keys=keys,
|
||||
process_factory=dispatch_command,
|
||||
)
|
||||
return server
|
||||
|
||||
|
||||
async def start_server(settings: ServerSettings | None = None) -> None:
|
||||
"""Start the server."""
|
||||
server_key = get_server_key()
|
||||
server = partial(AsshyncServer, settings=settings)
|
||||
|
||||
if not settings:
|
||||
settings = ServerSettings() # pyright: ignore[reportCallIssue]
|
||||
|
||||
await asyncssh.create_server(
|
||||
server,
|
||||
await run_ssh_server(
|
||||
str(settings.backend_url),
|
||||
settings.backend_token,
|
||||
settings.listen_address,
|
||||
settings.port,
|
||||
server_host_keys=[server_key],
|
||||
process_factory=dispatch_command,
|
||||
[server_key],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user