diff --git a/packages/sshecret-sshd/pyproject.toml b/packages/sshecret-sshd/pyproject.toml index 42991b9..fbfa7d3 100644 --- a/packages/sshecret-sshd/pyproject.toml +++ b/packages/sshecret-sshd/pyproject.toml @@ -10,8 +10,12 @@ requires-python = ">=3.13" dependencies = [ "asyncssh>=2.20.0", "httpx>=0.28.1", + "python-dotenv>=1.0.1", ] +[project.scripts] +sshecret-sshd = "sshecret_sshd.cli:cli" + [build-system] requires = ["hatchling"] build-backend = "hatchling.build" diff --git a/packages/sshecret-sshd/src/sshecret_sshd/backend_client.py b/packages/sshecret-sshd/src/sshecret_sshd/backend_client.py new file mode 100644 index 0000000..6033fae --- /dev/null +++ b/packages/sshecret-sshd/src/sshecret_sshd/backend_client.py @@ -0,0 +1,67 @@ +"""Backend client. + +This is intended to be as minimal as possible. +""" + +import httpx +import urllib.parse + +from .types import Client, Secret +from .settings import ServerSettings + + +class BackendClient: + """Backend client.""" + + def __init__(self, settings: ServerSettings | None = None) -> None: + """Initialize backend client.""" + if not settings: + settings = ServerSettings() # pyright: ignore[reportCallIssue] + self.settings: ServerSettings = settings + + @property + def headers(self) -> dict[str, str]: + """Get the headers.""" + return {"X-Api-Token": self.settings.backend_token} + + def _format_url(self, path: str) -> str: + """Format a URL.""" + return urllib.parse.urljoin(str(self.settings.backend_url), path) + + async def request(self, path: str) -> httpx.Response: + """Send a simple GET request.""" + url = self._format_url(path) + async with httpx.AsyncClient() as http_client: + response = await http_client.get(url, headers=self.headers) + return response + + async def lookup_client(self, username: str) -> Client | None: + """Lookup a client on username.""" + path = f"api/v1/clients/{username}" + response = await self.request(path) + if response.status_code == 404: + return None + response.raise_for_status() + client = Client.model_validate(response.json()) + return client + + async def lookup_secret(self, username: str, secret_name: str) -> str: + """Fetch a secret.""" + path = f"api/v1/clients/{username}/secrets/{secret_name}" + response = await self.request(path) + response.raise_for_status() + secret = Secret.model_validate(response.json()) + return secret.secret + + async def register_client(self, username: str, public_key: str) -> None: + """Register a new client.""" + data = { + "name": username, + "public_key": public_key, + } + path = "api/v1/clients/" + url = self._format_url(path) + async with httpx.AsyncClient() as http_client: + response = await http_client.post(url, headers=self.headers, json=data) + + response.raise_for_status() diff --git a/packages/sshecret-sshd/src/sshecret_sshd/cli.py b/packages/sshecret-sshd/src/sshecret_sshd/cli.py new file mode 100644 index 0000000..0948827 --- /dev/null +++ b/packages/sshecret-sshd/src/sshecret_sshd/cli.py @@ -0,0 +1,46 @@ +"""CLI app.""" +import logging +import asyncio +import sys +from pydantic_settings import CliApp + +from .settings import ServerSettings +from .ssh_server import start_server + +LOG = logging.getLogger() + +handler = logging.StreamHandler() +formatter = logging.Formatter("%(created)f:%(levelname)s:%(name)s:%(module)s:%(message)s") + +handler.setFormatter(formatter) +LOG.addHandler(handler) +LOG.setLevel(logging.INFO) + + +def cli(args: list[str] | None = None) -> None: + """Run CLI app.""" + try: + settings = ServerSettings() + except Exception: + print("One or more settings could not be resolved.") + CliApp.run(ServerSettings, ["--help"]) + sys.exit(1) + + if settings.debug: + LOG.setLevel(logging.DEBUG) + + loop = asyncio.new_event_loop() + loop.run_until_complete(start_server(settings)) + + print(f"Starting SSH server: {settings.listen_address}:{settings.port}") + try: + loop.run_forever() + except KeyboardInterrupt: + print("\nCtrl-C received. Exiting.") + sys.exit() + + + +if __name__ == "__main__": + """Run CLI app.""" + cli() diff --git a/packages/sshecret-sshd/src/sshecret_sshd/constants.py b/packages/sshecret-sshd/src/sshecret_sshd/constants.py new file mode 100644 index 0000000..1025f1a --- /dev/null +++ b/packages/sshecret-sshd/src/sshecret_sshd/constants.py @@ -0,0 +1,8 @@ +ERROR_NO_SECRET_FOUND = "Error: No secret available with the given name." +ERROR_UNKNOWN_CLIENT_OR_SECRET = "Error: Invalid client or secret name." +ERROR_NO_COMMAND_RECEIVED = "Error: No command was received from the client." +ERROR_SOURCE_IP_NOT_ALLOWED = ( + "Error: Client not authorized to connect from the given host." +) +ERROR_UNKNOWN_COMMAND = "Error: The given command was not understood." +SERVER_KEY_TYPE = "ed25519" diff --git a/packages/sshecret-sshd/src/sshecret_sshd/settings.py b/packages/sshecret-sshd/src/sshecret_sshd/settings.py new file mode 100644 index 0000000..603b24d --- /dev/null +++ b/packages/sshecret-sshd/src/sshecret_sshd/settings.py @@ -0,0 +1,18 @@ +"""SSH Server settings.""" + +from pydantic import AnyHttpUrl, Field +from pydantic_settings import BaseSettings, SettingsConfigDict + + +DEFAULT_LISTEN_PORT = 2222 + +class ServerSettings(BaseSettings, cli_parse_args=True, cli_exit_on_error=True): + """Server Settings.""" + + model_config = SettingsConfigDict(env_file=".sshd.env", env_prefix="sshecret_sshd_") + + backend_url: AnyHttpUrl = Field(alias="backend-url") + backend_token: str = Field(alias="backend-token") + listen_address: str = Field(default="", alias="listen") + port: int = DEFAULT_LISTEN_PORT + debug: bool = False diff --git a/packages/sshecret-sshd/src/sshecret_sshd/ssh_server.py b/packages/sshecret-sshd/src/sshecret_sshd/ssh_server.py new file mode 100644 index 0000000..c77b36d --- /dev/null +++ b/packages/sshecret-sshd/src/sshecret_sshd/ssh_server.py @@ -0,0 +1,344 @@ +"""SSH Server implementation.""" + +import logging + +import asyncssh +import ipaddress + +from functools import partial +from pathlib import Path +from typing import Awaitable, Callable, cast, override + +from . import constants + +from .backend_client import BackendClient +from .types import Client +from .settings import ServerSettings + + +LOG = logging.getLogger(__name__) + + +CommandDispatch = Callable[[asyncssh.SSHServerProcess[str]], Awaitable[None]] + + +class CommandError(Exception): + """Error class for errors during command processing.""" + + +def verify_key_input(public_key: str) -> str | None: + """Verify key input.""" + try: + key = asyncssh.import_public_key(public_key) + if key.algorithm.decode() == "ssh-rsa": + return public_key + except Exception: + return None + + +def get_process_command( + process: asyncssh.SSHServerProcess[str], +) -> tuple[str | None, list[str]]: + """Extract the process command.""" + if not process.command: + return (None, []) + argv = process.command.split(" ") + return (argv[0], argv[1:]) + + +def get_info_backend(process: asyncssh.SSHServerProcess[str]) -> BackendClient | None: + """Get backend from process.""" + backend = cast("BackendClient | None", process.get_extra_info("backend", None)) + return backend + + +def get_info_client(process: asyncssh.SSHServerProcess[str]) -> Client | None: + """Get info from process.""" + client = cast("Client | None", process.get_extra_info("client", None)) + return client + + +def get_info_username(process: asyncssh.SSHServerProcess[str]) -> str | None: + """Get username from process.""" + username = cast("str | None", process.get_extra_info("provided_username", None)) + return username + + +async def get_stdin_public_key(process: asyncssh.SSHServerProcess[str]) -> str | None: + """Get public key from stdin.""" + process.stdout.write("Enter public key:\n") + public_key: str | None = None + try: + async for line in process.stdin: + public_key = verify_key_input(line.rstrip("\n")) + if public_key: + break + process.stdout.write("Invalid key. Must be RSA Public Key.\n") + except asyncssh.BreakReceived: + pass + return public_key + + +def get_info_user_and_public_key( + process: asyncssh.SSHServerProcess[str], +) -> tuple[str | None, str | None]: + """Get username and public_key from process.""" + username = cast("str | None", process.get_extra_info("provided_username", None)) + public_key = cast("str | None", process.get_extra_info("provided_key", None)) + return (username, public_key) + + +async def register_client( + process: asyncssh.SSHServerProcess[str], + backend: BackendClient, + username: str, +) -> None: + """Register a new client.""" + public_key = await get_stdin_public_key(process) + if not public_key: + raise CommandError("Aborted. No valid public key received.") + + key = asyncssh.import_public_key(public_key) + if key.algorithm.decode() != "ssh-rsa": + raise CommandError("Error: Only RSA keys are supported!") + LOG.debug("Registering client %s with public key %s", username, public_key) + await backend.register_client(username, public_key) + + +async def get_secret( + backend: BackendClient, + client: Client, + secret_name: str, +) -> str: + """Handle get secret requests from client.""" + LOG.debug("Recieved command: %r", secret_name) + if not secret_name or secret_name not in client.secrets: + raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET) + + # Look up secret + try: + secret = await backend.lookup_secret(client.name, secret_name) + except Exception as exc: + LOG.debug(exc, exc_info=True) + raise CommandError("Unexpected error from backend") from exc + + return secret + + +async def dispatch_no_cmd(process: asyncssh.SSHServerProcess[str]) -> None: + """Dispatch for no command.""" + raise CommandError(constants.ERROR_NO_COMMAND_RECEIVED) + + +async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None: + """Dispatch the register command.""" + backend = get_info_backend(process) + if not backend: + raise CommandError("Unexpected error: Backend disappeared.") + username = get_info_username(process) + if not username: + raise CommandError("Unexpected error: Username was lost.") + await register_client(process, backend, username) + + process.stdout.write("Client registered\n.") + + +async def dispatch_cmd_get_secret(process: asyncssh.SSHServerProcess[str]) -> None: + """Dispatch the get_secret command.""" + backend = get_info_backend(process) + if not backend: + raise CommandError("Unexpected error: Backend disappeared.") + + client = get_info_client(process) + if not client: + raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET) + _cmd, args = get_process_command(process) + if not args: + raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET) + secret_name = args[0] + + secret = await get_secret(backend, client, secret_name) + process.stdout.write(secret) + + +async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None: + """Dispatch command.""" + command, _args = get_process_command(process) + if not command: + process.stderr.write(constants.ERROR_NO_COMMAND_RECEIVED) + process.exit(1) + return + cmdmap: dict[str, CommandDispatch] = { + "register": dispatch_cmd_register, + "get_secret": dispatch_cmd_get_secret, + } + if command not in cmdmap: + process.stderr.write(constants.ERROR_UNKNOWN_COMMAND) + process.exit(1) + return + exit_code = 0 + try: + dispatcher = cmdmap[command] + await dispatcher(process) + except CommandError as e: + process.stderr.write(str(e)) + exit_code = 1 + + except Exception as e: + LOG.debug(e, exc_info=True) + process.stderr.write("Unexpected exception:\n") + process.stderr.write(str(e)) + exit_code = 1 + + process.exit(exit_code) + + +async def handle_secret(process: asyncssh.SSHServerProcess[str]) -> None: + """Handle get secret requests from client.""" + backend = process.get_extra_info("backend") + if not backend: + process.stderr.write("Unexpected Error: Lost connection with backend object.") + process.exit(1) + return + + assert isinstance(backend, BackendClient) + + client = process.get_extra_info("client") + if not client: + process.stderr.write(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET) + process.exit(1) + return + assert isinstance(client, Client), "Error: Unexpected client type received" + secret_name = process.command + LOG.debug("Recieved command: %r", secret_name) + if not secret_name or secret_name not in client.secrets: + process.stderr.write(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET) + process.exit(1) + return + + # Look up secret + try: + secret = await backend.lookup_secret(client.name, secret_name) + except Exception as exc: + process.stderr.write("Unexpected error from backend:\n") + process.stderr.write(str(exc)) + LOG.debug(exc, exc_info=True) + process.exit(1) + return + process.stdout.write(secret) + process.exit(0) + + +class AsshyncServer(asyncssh.SSHServer): + """Asynchronous SSH server implementation.""" + + def __init__(self, settings: ServerSettings | None = None) -> None: + """Initialize server.""" + self.backend: BackendClient = BackendClient(settings) + self._conn: asyncssh.SSHServerConnection | None = None + + @override + def connection_made(self, conn: asyncssh.SSHServerConnection) -> None: + """Handle incoming connection.""" + peername = conn.get_extra_info("peername") + LOG.debug("Connection established from %r", peername) + self._conn = conn + self._conn.set_extra_info(backend=self.backend) + + @override + def password_auth_supported(self) -> bool: + """Deny password authentication.""" + return False + + @override + async def begin_auth(self, username: str) -> bool: + """Begin authentication. + + Note we always return True here. False bypasses the whole authentication + flow. + + """ + LOG.debug("Started authentication flow for user %s", username) + if not self._conn: + return True + if client := await self.backend.lookup_client(username): + LOG.debug("Client lookup sucessful.") + if key := self.resolve_client_key(client): + LOG.debug("Loaded public key for client %s", client.name) + self._conn.set_extra_info(client=client) + self._conn.set_authorized_keys(key) + else: + LOG.warning("Client connection denied due to policy.") + else: + self._conn.set_extra_info(provided_username=username) + return False + + return True + + @override + def validate_public_key(self, username: str, key: asyncssh.SSHKey) -> bool: + """Intercept public key validation.""" + if not self._conn: + return False + + # get an export of the provided public key. + keystring = key.export_public_key().decode() + self._conn.set_extra_info(provided_username=username) + self._conn.set_extra_info(provided_key=keystring) + LOG.debug("Intercepting user public key") + return False + + def resolve_client_key(self, client: Client) -> asyncssh.SSHAuthorizedKeys | None: + """Resolve the client key. + + Returns the key object only if the client is allowed to connect + according to its policy. + """ + if not self._conn: + return None + remote_ip = str(self._conn.get_extra_info("peername")[0]) + LOG.debug("Validating client %s connection from %s", client.name, remote_ip) + if self.check_connection_allowed(client, remote_ip): + return asyncssh.import_authorized_keys(client.public_key) + return None + + def check_connection_allowed(self, client: Client, source: str) -> bool: + """Check if the client is allowed to connect.""" + source_ip = ipaddress.ip_address(source) + policies = [ipaddress.ip_network(policy) for policy in client.policies] + + valid_source = [source_ip in policy for policy in policies] + return any(valid_source) + + +def get_server_key() -> str: + """Resolve server key. + + TODO: Is one key enough? Should we generate more keys? + """ + filename = f"ssh_host_{constants.SERVER_KEY_TYPE}_key" + if Path(filename).exists(): + return filename + # FIXME: There's a weird typing warning here that I need to investigate. + private_key = asyncssh.generate_private_key("ssh-ed25519", "sshecret-sshd") + with open(filename, "wb") as f: + f.write(private_key.export_private_key()) + + return filename + + +async def start_server(settings: ServerSettings | None = None) -> None: + """Start the server.""" + server_key = get_server_key() + server = partial(AsshyncServer, settings=settings) + + if not settings: + settings = ServerSettings() # pyright: ignore[reportCallIssue] + + await asyncssh.create_server( + server, + settings.listen_address, + settings.port, + server_host_keys=[server_key], + process_factory=dispatch_command, + ) diff --git a/packages/sshecret-sshd/src/sshecret_sshd/types.py b/packages/sshecret-sshd/src/sshecret_sshd/types.py new file mode 100644 index 0000000..09d59fe --- /dev/null +++ b/packages/sshecret-sshd/src/sshecret_sshd/types.py @@ -0,0 +1,26 @@ +"""Types.""" + +import uuid +from datetime import datetime +from pydantic import BaseModel, IPvAnyAddress, IPvAnyNetwork + + +class Client(BaseModel): + """Implementation of the backend class ClientView.""" + + id: uuid.UUID + name: str + public_key: str + secrets: list[str] + policies: list[IPvAnyNetwork | IPvAnyAddress] + created_at: datetime + updated_at: datetime | None + + +class Secret(BaseModel): + """Implemen tation of the backend class ClientSecretResponse.""" + + name: str + secret: str + created_at: datetime + updated_at: datetime | None