Complete sshd
This commit is contained in:
344
packages/sshecret-sshd/src/sshecret_sshd/ssh_server.py
Normal file
344
packages/sshecret-sshd/src/sshecret_sshd/ssh_server.py
Normal file
@ -0,0 +1,344 @@
|
||||
"""SSH Server implementation."""
|
||||
|
||||
import logging
|
||||
|
||||
import asyncssh
|
||||
import ipaddress
|
||||
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Awaitable, Callable, cast, override
|
||||
|
||||
from . import constants
|
||||
|
||||
from .backend_client import BackendClient
|
||||
from .types import Client
|
||||
from .settings import ServerSettings
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
CommandDispatch = Callable[[asyncssh.SSHServerProcess[str]], Awaitable[None]]
|
||||
|
||||
|
||||
class CommandError(Exception):
|
||||
"""Error class for errors during command processing."""
|
||||
|
||||
|
||||
def verify_key_input(public_key: str) -> str | None:
|
||||
"""Verify key input."""
|
||||
try:
|
||||
key = asyncssh.import_public_key(public_key)
|
||||
if key.algorithm.decode() == "ssh-rsa":
|
||||
return public_key
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def get_process_command(
|
||||
process: asyncssh.SSHServerProcess[str],
|
||||
) -> tuple[str | None, list[str]]:
|
||||
"""Extract the process command."""
|
||||
if not process.command:
|
||||
return (None, [])
|
||||
argv = process.command.split(" ")
|
||||
return (argv[0], argv[1:])
|
||||
|
||||
|
||||
def get_info_backend(process: asyncssh.SSHServerProcess[str]) -> BackendClient | None:
|
||||
"""Get backend from process."""
|
||||
backend = cast("BackendClient | None", process.get_extra_info("backend", None))
|
||||
return backend
|
||||
|
||||
|
||||
def get_info_client(process: asyncssh.SSHServerProcess[str]) -> Client | None:
|
||||
"""Get info from process."""
|
||||
client = cast("Client | None", process.get_extra_info("client", None))
|
||||
return client
|
||||
|
||||
|
||||
def get_info_username(process: asyncssh.SSHServerProcess[str]) -> str | None:
|
||||
"""Get username from process."""
|
||||
username = cast("str | None", process.get_extra_info("provided_username", None))
|
||||
return username
|
||||
|
||||
|
||||
async def get_stdin_public_key(process: asyncssh.SSHServerProcess[str]) -> str | None:
|
||||
"""Get public key from stdin."""
|
||||
process.stdout.write("Enter public key:\n")
|
||||
public_key: str | None = None
|
||||
try:
|
||||
async for line in process.stdin:
|
||||
public_key = verify_key_input(line.rstrip("\n"))
|
||||
if public_key:
|
||||
break
|
||||
process.stdout.write("Invalid key. Must be RSA Public Key.\n")
|
||||
except asyncssh.BreakReceived:
|
||||
pass
|
||||
return public_key
|
||||
|
||||
|
||||
def get_info_user_and_public_key(
|
||||
process: asyncssh.SSHServerProcess[str],
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""Get username and public_key from process."""
|
||||
username = cast("str | None", process.get_extra_info("provided_username", None))
|
||||
public_key = cast("str | None", process.get_extra_info("provided_key", None))
|
||||
return (username, public_key)
|
||||
|
||||
|
||||
async def register_client(
|
||||
process: asyncssh.SSHServerProcess[str],
|
||||
backend: BackendClient,
|
||||
username: str,
|
||||
) -> None:
|
||||
"""Register a new client."""
|
||||
public_key = await get_stdin_public_key(process)
|
||||
if not public_key:
|
||||
raise CommandError("Aborted. No valid public key received.")
|
||||
|
||||
key = asyncssh.import_public_key(public_key)
|
||||
if key.algorithm.decode() != "ssh-rsa":
|
||||
raise CommandError("Error: Only RSA keys are supported!")
|
||||
LOG.debug("Registering client %s with public key %s", username, public_key)
|
||||
await backend.register_client(username, public_key)
|
||||
|
||||
|
||||
async def get_secret(
|
||||
backend: BackendClient,
|
||||
client: Client,
|
||||
secret_name: str,
|
||||
) -> str:
|
||||
"""Handle get secret requests from client."""
|
||||
LOG.debug("Recieved command: %r", secret_name)
|
||||
if not secret_name or secret_name not in client.secrets:
|
||||
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
|
||||
|
||||
# Look up secret
|
||||
try:
|
||||
secret = await backend.lookup_secret(client.name, secret_name)
|
||||
except Exception as exc:
|
||||
LOG.debug(exc, exc_info=True)
|
||||
raise CommandError("Unexpected error from backend") from exc
|
||||
|
||||
return secret
|
||||
|
||||
|
||||
async def dispatch_no_cmd(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||
"""Dispatch for no command."""
|
||||
raise CommandError(constants.ERROR_NO_COMMAND_RECEIVED)
|
||||
|
||||
|
||||
async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||
"""Dispatch the register command."""
|
||||
backend = get_info_backend(process)
|
||||
if not backend:
|
||||
raise CommandError("Unexpected error: Backend disappeared.")
|
||||
username = get_info_username(process)
|
||||
if not username:
|
||||
raise CommandError("Unexpected error: Username was lost.")
|
||||
await register_client(process, backend, username)
|
||||
|
||||
process.stdout.write("Client registered\n.")
|
||||
|
||||
|
||||
async def dispatch_cmd_get_secret(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||
"""Dispatch the get_secret command."""
|
||||
backend = get_info_backend(process)
|
||||
if not backend:
|
||||
raise CommandError("Unexpected error: Backend disappeared.")
|
||||
|
||||
client = get_info_client(process)
|
||||
if not client:
|
||||
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
|
||||
_cmd, args = get_process_command(process)
|
||||
if not args:
|
||||
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
|
||||
secret_name = args[0]
|
||||
|
||||
secret = await get_secret(backend, client, secret_name)
|
||||
process.stdout.write(secret)
|
||||
|
||||
|
||||
async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||
"""Dispatch command."""
|
||||
command, _args = get_process_command(process)
|
||||
if not command:
|
||||
process.stderr.write(constants.ERROR_NO_COMMAND_RECEIVED)
|
||||
process.exit(1)
|
||||
return
|
||||
cmdmap: dict[str, CommandDispatch] = {
|
||||
"register": dispatch_cmd_register,
|
||||
"get_secret": dispatch_cmd_get_secret,
|
||||
}
|
||||
if command not in cmdmap:
|
||||
process.stderr.write(constants.ERROR_UNKNOWN_COMMAND)
|
||||
process.exit(1)
|
||||
return
|
||||
exit_code = 0
|
||||
try:
|
||||
dispatcher = cmdmap[command]
|
||||
await dispatcher(process)
|
||||
except CommandError as e:
|
||||
process.stderr.write(str(e))
|
||||
exit_code = 1
|
||||
|
||||
except Exception as e:
|
||||
LOG.debug(e, exc_info=True)
|
||||
process.stderr.write("Unexpected exception:\n")
|
||||
process.stderr.write(str(e))
|
||||
exit_code = 1
|
||||
|
||||
process.exit(exit_code)
|
||||
|
||||
|
||||
async def handle_secret(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||
"""Handle get secret requests from client."""
|
||||
backend = process.get_extra_info("backend")
|
||||
if not backend:
|
||||
process.stderr.write("Unexpected Error: Lost connection with backend object.")
|
||||
process.exit(1)
|
||||
return
|
||||
|
||||
assert isinstance(backend, BackendClient)
|
||||
|
||||
client = process.get_extra_info("client")
|
||||
if not client:
|
||||
process.stderr.write(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
|
||||
process.exit(1)
|
||||
return
|
||||
assert isinstance(client, Client), "Error: Unexpected client type received"
|
||||
secret_name = process.command
|
||||
LOG.debug("Recieved command: %r", secret_name)
|
||||
if not secret_name or secret_name not in client.secrets:
|
||||
process.stderr.write(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
|
||||
process.exit(1)
|
||||
return
|
||||
|
||||
# Look up secret
|
||||
try:
|
||||
secret = await backend.lookup_secret(client.name, secret_name)
|
||||
except Exception as exc:
|
||||
process.stderr.write("Unexpected error from backend:\n")
|
||||
process.stderr.write(str(exc))
|
||||
LOG.debug(exc, exc_info=True)
|
||||
process.exit(1)
|
||||
return
|
||||
process.stdout.write(secret)
|
||||
process.exit(0)
|
||||
|
||||
|
||||
class AsshyncServer(asyncssh.SSHServer):
|
||||
"""Asynchronous SSH server implementation."""
|
||||
|
||||
def __init__(self, settings: ServerSettings | None = None) -> None:
|
||||
"""Initialize server."""
|
||||
self.backend: BackendClient = BackendClient(settings)
|
||||
self._conn: asyncssh.SSHServerConnection | None = None
|
||||
|
||||
@override
|
||||
def connection_made(self, conn: asyncssh.SSHServerConnection) -> None:
|
||||
"""Handle incoming connection."""
|
||||
peername = conn.get_extra_info("peername")
|
||||
LOG.debug("Connection established from %r", peername)
|
||||
self._conn = conn
|
||||
self._conn.set_extra_info(backend=self.backend)
|
||||
|
||||
@override
|
||||
def password_auth_supported(self) -> bool:
|
||||
"""Deny password authentication."""
|
||||
return False
|
||||
|
||||
@override
|
||||
async def begin_auth(self, username: str) -> bool:
|
||||
"""Begin authentication.
|
||||
|
||||
Note we always return True here. False bypasses the whole authentication
|
||||
flow.
|
||||
|
||||
"""
|
||||
LOG.debug("Started authentication flow for user %s", username)
|
||||
if not self._conn:
|
||||
return True
|
||||
if client := await self.backend.lookup_client(username):
|
||||
LOG.debug("Client lookup sucessful.")
|
||||
if key := self.resolve_client_key(client):
|
||||
LOG.debug("Loaded public key for client %s", client.name)
|
||||
self._conn.set_extra_info(client=client)
|
||||
self._conn.set_authorized_keys(key)
|
||||
else:
|
||||
LOG.warning("Client connection denied due to policy.")
|
||||
else:
|
||||
self._conn.set_extra_info(provided_username=username)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@override
|
||||
def validate_public_key(self, username: str, key: asyncssh.SSHKey) -> bool:
|
||||
"""Intercept public key validation."""
|
||||
if not self._conn:
|
||||
return False
|
||||
|
||||
# get an export of the provided public key.
|
||||
keystring = key.export_public_key().decode()
|
||||
self._conn.set_extra_info(provided_username=username)
|
||||
self._conn.set_extra_info(provided_key=keystring)
|
||||
LOG.debug("Intercepting user public key")
|
||||
return False
|
||||
|
||||
def resolve_client_key(self, client: Client) -> asyncssh.SSHAuthorizedKeys | None:
|
||||
"""Resolve the client key.
|
||||
|
||||
Returns the key object only if the client is allowed to connect
|
||||
according to its policy.
|
||||
"""
|
||||
if not self._conn:
|
||||
return None
|
||||
remote_ip = str(self._conn.get_extra_info("peername")[0])
|
||||
LOG.debug("Validating client %s connection from %s", client.name, remote_ip)
|
||||
if self.check_connection_allowed(client, remote_ip):
|
||||
return asyncssh.import_authorized_keys(client.public_key)
|
||||
return None
|
||||
|
||||
def check_connection_allowed(self, client: Client, source: str) -> bool:
|
||||
"""Check if the client is allowed to connect."""
|
||||
source_ip = ipaddress.ip_address(source)
|
||||
policies = [ipaddress.ip_network(policy) for policy in client.policies]
|
||||
|
||||
valid_source = [source_ip in policy for policy in policies]
|
||||
return any(valid_source)
|
||||
|
||||
|
||||
def get_server_key() -> str:
|
||||
"""Resolve server key.
|
||||
|
||||
TODO: Is one key enough? Should we generate more keys?
|
||||
"""
|
||||
filename = f"ssh_host_{constants.SERVER_KEY_TYPE}_key"
|
||||
if Path(filename).exists():
|
||||
return filename
|
||||
# FIXME: There's a weird typing warning here that I need to investigate.
|
||||
private_key = asyncssh.generate_private_key("ssh-ed25519", "sshecret-sshd")
|
||||
with open(filename, "wb") as f:
|
||||
f.write(private_key.export_private_key())
|
||||
|
||||
return filename
|
||||
|
||||
|
||||
async def start_server(settings: ServerSettings | None = None) -> None:
|
||||
"""Start the server."""
|
||||
server_key = get_server_key()
|
||||
server = partial(AsshyncServer, settings=settings)
|
||||
|
||||
if not settings:
|
||||
settings = ServerSettings() # pyright: ignore[reportCallIssue]
|
||||
|
||||
await asyncssh.create_server(
|
||||
server,
|
||||
settings.listen_address,
|
||||
settings.port,
|
||||
server_host_keys=[server_key],
|
||||
process_factory=dispatch_command,
|
||||
)
|
||||
Reference in New Issue
Block a user