Complete sshd

This commit is contained in:
2025-04-18 16:39:24 +02:00
parent ec90fb7680
commit d5b4ca5440
7 changed files with 513 additions and 0 deletions

View File

@ -10,8 +10,12 @@ requires-python = ">=3.13"
dependencies = [ dependencies = [
"asyncssh>=2.20.0", "asyncssh>=2.20.0",
"httpx>=0.28.1", "httpx>=0.28.1",
"python-dotenv>=1.0.1",
] ]
[project.scripts]
sshecret-sshd = "sshecret_sshd.cli:cli"
[build-system] [build-system]
requires = ["hatchling"] requires = ["hatchling"]
build-backend = "hatchling.build" build-backend = "hatchling.build"

View File

@ -0,0 +1,67 @@
"""Backend client.
This is intended to be as minimal as possible.
"""
import httpx
import urllib.parse
from .types import Client, Secret
from .settings import ServerSettings
class BackendClient:
"""Backend client."""
def __init__(self, settings: ServerSettings | None = None) -> None:
"""Initialize backend client."""
if not settings:
settings = ServerSettings() # pyright: ignore[reportCallIssue]
self.settings: ServerSettings = settings
@property
def headers(self) -> dict[str, str]:
"""Get the headers."""
return {"X-Api-Token": self.settings.backend_token}
def _format_url(self, path: str) -> str:
"""Format a URL."""
return urllib.parse.urljoin(str(self.settings.backend_url), path)
async def request(self, path: str) -> httpx.Response:
"""Send a simple GET request."""
url = self._format_url(path)
async with httpx.AsyncClient() as http_client:
response = await http_client.get(url, headers=self.headers)
return response
async def lookup_client(self, username: str) -> Client | None:
"""Lookup a client on username."""
path = f"api/v1/clients/{username}"
response = await self.request(path)
if response.status_code == 404:
return None
response.raise_for_status()
client = Client.model_validate(response.json())
return client
async def lookup_secret(self, username: str, secret_name: str) -> str:
"""Fetch a secret."""
path = f"api/v1/clients/{username}/secrets/{secret_name}"
response = await self.request(path)
response.raise_for_status()
secret = Secret.model_validate(response.json())
return secret.secret
async def register_client(self, username: str, public_key: str) -> None:
"""Register a new client."""
data = {
"name": username,
"public_key": public_key,
}
path = "api/v1/clients/"
url = self._format_url(path)
async with httpx.AsyncClient() as http_client:
response = await http_client.post(url, headers=self.headers, json=data)
response.raise_for_status()

View File

@ -0,0 +1,46 @@
"""CLI app."""
import logging
import asyncio
import sys
from pydantic_settings import CliApp
from .settings import ServerSettings
from .ssh_server import start_server
LOG = logging.getLogger()
handler = logging.StreamHandler()
formatter = logging.Formatter("%(created)f:%(levelname)s:%(name)s:%(module)s:%(message)s")
handler.setFormatter(formatter)
LOG.addHandler(handler)
LOG.setLevel(logging.INFO)
def cli(args: list[str] | None = None) -> None:
"""Run CLI app."""
try:
settings = ServerSettings()
except Exception:
print("One or more settings could not be resolved.")
CliApp.run(ServerSettings, ["--help"])
sys.exit(1)
if settings.debug:
LOG.setLevel(logging.DEBUG)
loop = asyncio.new_event_loop()
loop.run_until_complete(start_server(settings))
print(f"Starting SSH server: {settings.listen_address}:{settings.port}")
try:
loop.run_forever()
except KeyboardInterrupt:
print("\nCtrl-C received. Exiting.")
sys.exit()
if __name__ == "__main__":
"""Run CLI app."""
cli()

View File

@ -0,0 +1,8 @@
ERROR_NO_SECRET_FOUND = "Error: No secret available with the given name."
ERROR_UNKNOWN_CLIENT_OR_SECRET = "Error: Invalid client or secret name."
ERROR_NO_COMMAND_RECEIVED = "Error: No command was received from the client."
ERROR_SOURCE_IP_NOT_ALLOWED = (
"Error: Client not authorized to connect from the given host."
)
ERROR_UNKNOWN_COMMAND = "Error: The given command was not understood."
SERVER_KEY_TYPE = "ed25519"

View File

@ -0,0 +1,18 @@
"""SSH Server settings."""
from pydantic import AnyHttpUrl, Field
from pydantic_settings import BaseSettings, SettingsConfigDict
DEFAULT_LISTEN_PORT = 2222
class ServerSettings(BaseSettings, cli_parse_args=True, cli_exit_on_error=True):
"""Server Settings."""
model_config = SettingsConfigDict(env_file=".sshd.env", env_prefix="sshecret_sshd_")
backend_url: AnyHttpUrl = Field(alias="backend-url")
backend_token: str = Field(alias="backend-token")
listen_address: str = Field(default="", alias="listen")
port: int = DEFAULT_LISTEN_PORT
debug: bool = False

View File

@ -0,0 +1,344 @@
"""SSH Server implementation."""
import logging
import asyncssh
import ipaddress
from functools import partial
from pathlib import Path
from typing import Awaitable, Callable, cast, override
from . import constants
from .backend_client import BackendClient
from .types import Client
from .settings import ServerSettings
LOG = logging.getLogger(__name__)
CommandDispatch = Callable[[asyncssh.SSHServerProcess[str]], Awaitable[None]]
class CommandError(Exception):
"""Error class for errors during command processing."""
def verify_key_input(public_key: str) -> str | None:
"""Verify key input."""
try:
key = asyncssh.import_public_key(public_key)
if key.algorithm.decode() == "ssh-rsa":
return public_key
except Exception:
return None
def get_process_command(
process: asyncssh.SSHServerProcess[str],
) -> tuple[str | None, list[str]]:
"""Extract the process command."""
if not process.command:
return (None, [])
argv = process.command.split(" ")
return (argv[0], argv[1:])
def get_info_backend(process: asyncssh.SSHServerProcess[str]) -> BackendClient | None:
"""Get backend from process."""
backend = cast("BackendClient | None", process.get_extra_info("backend", None))
return backend
def get_info_client(process: asyncssh.SSHServerProcess[str]) -> Client | None:
"""Get info from process."""
client = cast("Client | None", process.get_extra_info("client", None))
return client
def get_info_username(process: asyncssh.SSHServerProcess[str]) -> str | None:
"""Get username from process."""
username = cast("str | None", process.get_extra_info("provided_username", None))
return username
async def get_stdin_public_key(process: asyncssh.SSHServerProcess[str]) -> str | None:
"""Get public key from stdin."""
process.stdout.write("Enter public key:\n")
public_key: str | None = None
try:
async for line in process.stdin:
public_key = verify_key_input(line.rstrip("\n"))
if public_key:
break
process.stdout.write("Invalid key. Must be RSA Public Key.\n")
except asyncssh.BreakReceived:
pass
return public_key
def get_info_user_and_public_key(
process: asyncssh.SSHServerProcess[str],
) -> tuple[str | None, str | None]:
"""Get username and public_key from process."""
username = cast("str | None", process.get_extra_info("provided_username", None))
public_key = cast("str | None", process.get_extra_info("provided_key", None))
return (username, public_key)
async def register_client(
process: asyncssh.SSHServerProcess[str],
backend: BackendClient,
username: str,
) -> None:
"""Register a new client."""
public_key = await get_stdin_public_key(process)
if not public_key:
raise CommandError("Aborted. No valid public key received.")
key = asyncssh.import_public_key(public_key)
if key.algorithm.decode() != "ssh-rsa":
raise CommandError("Error: Only RSA keys are supported!")
LOG.debug("Registering client %s with public key %s", username, public_key)
await backend.register_client(username, public_key)
async def get_secret(
backend: BackendClient,
client: Client,
secret_name: str,
) -> str:
"""Handle get secret requests from client."""
LOG.debug("Recieved command: %r", secret_name)
if not secret_name or secret_name not in client.secrets:
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
# Look up secret
try:
secret = await backend.lookup_secret(client.name, secret_name)
except Exception as exc:
LOG.debug(exc, exc_info=True)
raise CommandError("Unexpected error from backend") from exc
return secret
async def dispatch_no_cmd(process: asyncssh.SSHServerProcess[str]) -> None:
"""Dispatch for no command."""
raise CommandError(constants.ERROR_NO_COMMAND_RECEIVED)
async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None:
"""Dispatch the register command."""
backend = get_info_backend(process)
if not backend:
raise CommandError("Unexpected error: Backend disappeared.")
username = get_info_username(process)
if not username:
raise CommandError("Unexpected error: Username was lost.")
await register_client(process, backend, username)
process.stdout.write("Client registered\n.")
async def dispatch_cmd_get_secret(process: asyncssh.SSHServerProcess[str]) -> None:
"""Dispatch the get_secret command."""
backend = get_info_backend(process)
if not backend:
raise CommandError("Unexpected error: Backend disappeared.")
client = get_info_client(process)
if not client:
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
_cmd, args = get_process_command(process)
if not args:
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
secret_name = args[0]
secret = await get_secret(backend, client, secret_name)
process.stdout.write(secret)
async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
"""Dispatch command."""
command, _args = get_process_command(process)
if not command:
process.stderr.write(constants.ERROR_NO_COMMAND_RECEIVED)
process.exit(1)
return
cmdmap: dict[str, CommandDispatch] = {
"register": dispatch_cmd_register,
"get_secret": dispatch_cmd_get_secret,
}
if command not in cmdmap:
process.stderr.write(constants.ERROR_UNKNOWN_COMMAND)
process.exit(1)
return
exit_code = 0
try:
dispatcher = cmdmap[command]
await dispatcher(process)
except CommandError as e:
process.stderr.write(str(e))
exit_code = 1
except Exception as e:
LOG.debug(e, exc_info=True)
process.stderr.write("Unexpected exception:\n")
process.stderr.write(str(e))
exit_code = 1
process.exit(exit_code)
async def handle_secret(process: asyncssh.SSHServerProcess[str]) -> None:
"""Handle get secret requests from client."""
backend = process.get_extra_info("backend")
if not backend:
process.stderr.write("Unexpected Error: Lost connection with backend object.")
process.exit(1)
return
assert isinstance(backend, BackendClient)
client = process.get_extra_info("client")
if not client:
process.stderr.write(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
process.exit(1)
return
assert isinstance(client, Client), "Error: Unexpected client type received"
secret_name = process.command
LOG.debug("Recieved command: %r", secret_name)
if not secret_name or secret_name not in client.secrets:
process.stderr.write(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
process.exit(1)
return
# Look up secret
try:
secret = await backend.lookup_secret(client.name, secret_name)
except Exception as exc:
process.stderr.write("Unexpected error from backend:\n")
process.stderr.write(str(exc))
LOG.debug(exc, exc_info=True)
process.exit(1)
return
process.stdout.write(secret)
process.exit(0)
class AsshyncServer(asyncssh.SSHServer):
"""Asynchronous SSH server implementation."""
def __init__(self, settings: ServerSettings | None = None) -> None:
"""Initialize server."""
self.backend: BackendClient = BackendClient(settings)
self._conn: asyncssh.SSHServerConnection | None = None
@override
def connection_made(self, conn: asyncssh.SSHServerConnection) -> None:
"""Handle incoming connection."""
peername = conn.get_extra_info("peername")
LOG.debug("Connection established from %r", peername)
self._conn = conn
self._conn.set_extra_info(backend=self.backend)
@override
def password_auth_supported(self) -> bool:
"""Deny password authentication."""
return False
@override
async def begin_auth(self, username: str) -> bool:
"""Begin authentication.
Note we always return True here. False bypasses the whole authentication
flow.
"""
LOG.debug("Started authentication flow for user %s", username)
if not self._conn:
return True
if client := await self.backend.lookup_client(username):
LOG.debug("Client lookup sucessful.")
if key := self.resolve_client_key(client):
LOG.debug("Loaded public key for client %s", client.name)
self._conn.set_extra_info(client=client)
self._conn.set_authorized_keys(key)
else:
LOG.warning("Client connection denied due to policy.")
else:
self._conn.set_extra_info(provided_username=username)
return False
return True
@override
def validate_public_key(self, username: str, key: asyncssh.SSHKey) -> bool:
"""Intercept public key validation."""
if not self._conn:
return False
# get an export of the provided public key.
keystring = key.export_public_key().decode()
self._conn.set_extra_info(provided_username=username)
self._conn.set_extra_info(provided_key=keystring)
LOG.debug("Intercepting user public key")
return False
def resolve_client_key(self, client: Client) -> asyncssh.SSHAuthorizedKeys | None:
"""Resolve the client key.
Returns the key object only if the client is allowed to connect
according to its policy.
"""
if not self._conn:
return None
remote_ip = str(self._conn.get_extra_info("peername")[0])
LOG.debug("Validating client %s connection from %s", client.name, remote_ip)
if self.check_connection_allowed(client, remote_ip):
return asyncssh.import_authorized_keys(client.public_key)
return None
def check_connection_allowed(self, client: Client, source: str) -> bool:
"""Check if the client is allowed to connect."""
source_ip = ipaddress.ip_address(source)
policies = [ipaddress.ip_network(policy) for policy in client.policies]
valid_source = [source_ip in policy for policy in policies]
return any(valid_source)
def get_server_key() -> str:
"""Resolve server key.
TODO: Is one key enough? Should we generate more keys?
"""
filename = f"ssh_host_{constants.SERVER_KEY_TYPE}_key"
if Path(filename).exists():
return filename
# FIXME: There's a weird typing warning here that I need to investigate.
private_key = asyncssh.generate_private_key("ssh-ed25519", "sshecret-sshd")
with open(filename, "wb") as f:
f.write(private_key.export_private_key())
return filename
async def start_server(settings: ServerSettings | None = None) -> None:
"""Start the server."""
server_key = get_server_key()
server = partial(AsshyncServer, settings=settings)
if not settings:
settings = ServerSettings() # pyright: ignore[reportCallIssue]
await asyncssh.create_server(
server,
settings.listen_address,
settings.port,
server_host_keys=[server_key],
process_factory=dispatch_command,
)

View File

@ -0,0 +1,26 @@
"""Types."""
import uuid
from datetime import datetime
from pydantic import BaseModel, IPvAnyAddress, IPvAnyNetwork
class Client(BaseModel):
"""Implementation of the backend class ClientView."""
id: uuid.UUID
name: str
public_key: str
secrets: list[str]
policies: list[IPvAnyNetwork | IPvAnyAddress]
created_at: datetime
updated_at: datetime | None
class Secret(BaseModel):
"""Implemen tation of the backend class ClientSecretResponse."""
name: str
secret: str
created_at: datetime
updated_at: datetime | None