Check in changes to sshd module

This commit is contained in:
2025-04-30 08:25:15 +02:00
parent 2a668059ef
commit 6d37f7d251
4 changed files with 193 additions and 158 deletions

View File

@ -1,67 +0,0 @@
"""Backend client.
This is intended to be as minimal as possible.
"""
import httpx
import urllib.parse
from .types import Client, Secret
from .settings import ServerSettings
class BackendClient:
"""Backend client."""
def __init__(self, settings: ServerSettings | None = None) -> None:
"""Initialize backend client."""
if not settings:
settings = ServerSettings() # pyright: ignore[reportCallIssue]
self.settings: ServerSettings = settings
@property
def headers(self) -> dict[str, str]:
"""Get the headers."""
return {"X-Api-Token": self.settings.backend_token}
def _format_url(self, path: str) -> str:
"""Format a URL."""
return urllib.parse.urljoin(str(self.settings.backend_url), path)
async def request(self, path: str) -> httpx.Response:
"""Send a simple GET request."""
url = self._format_url(path)
async with httpx.AsyncClient() as http_client:
response = await http_client.get(url, headers=self.headers)
return response
async def lookup_client(self, username: str) -> Client | None:
"""Lookup a client on username."""
path = f"api/v1/clients/{username}"
response = await self.request(path)
if response.status_code == 404:
return None
response.raise_for_status()
client = Client.model_validate(response.json())
return client
async def lookup_secret(self, username: str, secret_name: str) -> str:
"""Fetch a secret."""
path = f"api/v1/clients/{username}/secrets/{secret_name}"
response = await self.request(path)
response.raise_for_status()
secret = Secret.model_validate(response.json())
return secret.secret
async def register_client(self, username: str, public_key: str) -> None:
"""Register a new client."""
data = {
"name": username,
"public_key": public_key,
}
path = "api/v1/clients/"
url = self._format_url(path)
async with httpx.AsyncClient() as http_client:
response = await http_client.post(url, headers=self.headers, json=data)
response.raise_for_status()

View File

@ -1,6 +1,6 @@
"""SSH Server settings.""" """SSH Server settings."""
from pydantic import AnyHttpUrl, Field from pydantic import AnyHttpUrl, Field, AliasChoices
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
@ -11,8 +11,8 @@ class ServerSettings(BaseSettings, cli_parse_args=True, cli_exit_on_error=True):
model_config = SettingsConfigDict(env_file=".sshd.env", env_prefix="sshecret_sshd_") model_config = SettingsConfigDict(env_file=".sshd.env", env_prefix="sshecret_sshd_")
backend_url: AnyHttpUrl = Field(alias="backend-url") backend_url: AnyHttpUrl = Field(validation_alias=AliasChoices("backend-url", "sshecret_backend_url"))
backend_token: str = Field(alias="backend-token") backend_token: str = Field(validation_alias=AliasChoices("backend-token", "sshecret_sshd_backend_token"))
listen_address: str = Field(default="", alias="listen") listen_address: str = Field(default="", alias="listen")
port: int = DEFAULT_LISTEN_PORT port: int = DEFAULT_LISTEN_PORT
debug: bool = False debug: bool = False

View File

@ -1,18 +1,20 @@
"""SSH Server implementation.""" """SSH Server implementation."""
import logging import logging
import uuid
import asyncssh import asyncssh
import ipaddress import ipaddress
from collections.abc import Awaitable
from datetime import datetime, timezone
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from typing import Awaitable, Callable, cast, override from typing import Any, Callable, cast, override
from . import constants from . import constants
from .backend_client import BackendClient from sshecret.backend import AuditLog, SshecretBackend, Client
from .types import Client
from .settings import ServerSettings from .settings import ServerSettings
@ -21,11 +23,84 @@ LOG = logging.getLogger(__name__)
CommandDispatch = Callable[[asyncssh.SSHServerProcess[str]], Awaitable[None]] CommandDispatch = Callable[[asyncssh.SSHServerProcess[str]], Awaitable[None]]
PeernameV4 = tuple[str, int]
PeernameV6 = tuple[str, int, int, int]
Peername = PeernameV4 | PeernameV6
class CommandError(Exception): class CommandError(Exception):
"""Error class for errors during command processing.""" """Error class for errors during command processing."""
def audit_process(
backend: SshecretBackend,
process: asyncssh.SSHServerProcess[str],
message: str,
secret: str | None = None,
) -> None:
"""Add an audit event from process."""
command = get_process_command(process)
client = get_info_client(process)
username = get_info_username(process)
remote_ip = get_info_remote_ip(process)
operation = "SSH_EVENT"
obj_name: str | None = None
obj_id: str | None = None
if command and not secret:
cmd, cmd_args = command
obj_id = " ".join(cmd_args)
elif secret:
obj_name = "ClientSecret"
obj_id = secret
entry = AuditLog(
subsystem="ssh",
operation=operation,
object=obj_name,
object_id=obj_id,
message=message,
origin=remote_ip,
)
if client:
entry.client_id = str(client.id)
entry.client_name = client.name
elif username:
entry.client_name = username
backend.add_audit_log_sync(entry)
def audit_event(
backend: SshecretBackend,
message: str,
operation: str = "SSH_EVENT",
client: Client | None = None,
origin: str | None = None,
secret: str | None = None,
) -> None:
"""Add an audit event."""
entry = AuditLog(
client_id=None,
client_name=None,
object=None,
object_id=None,
subsystem="ssh",
operation=operation,
message=message,
origin=origin,
)
if client:
entry.client_id = str(client.id)
entry.client_name = client.name
if secret:
entry.object = "ClientSecret"
entry.object_id = secret
backend.add_audit_log_sync(entry)
def verify_key_input(public_key: str) -> str | None: def verify_key_input(public_key: str) -> str | None:
"""Verify key input.""" """Verify key input."""
try: try:
@ -46,9 +121,9 @@ def get_process_command(
return (argv[0], argv[1:]) return (argv[0], argv[1:])
def get_info_backend(process: asyncssh.SSHServerProcess[str]) -> BackendClient | None: def get_info_backend(process: asyncssh.SSHServerProcess[str]) -> SshecretBackend | None:
"""Get backend from process.""" """Get backend from process."""
backend = cast("BackendClient | None", process.get_extra_info("backend", None)) backend = cast("SshecretBackend | None", process.get_extra_info("backend", None))
return backend return backend
@ -64,6 +139,31 @@ def get_info_username(process: asyncssh.SSHServerProcess[str]) -> str | None:
return username return username
def get_info_remote_ip(process: asyncssh.SSHServerProcess[str]) -> str | None:
"""Get remote IP."""
peername = cast("Peername | None", process.get_extra_info("peername", None))
remote_ip: str | None = None
if peername:
remote_ip = peername[0]
return remote_ip
# remote_ip = str(self._conn.get_extra_info("peername")[0])
def get_optional_commands(process: asyncssh.SSHServerProcess[str]) -> dict[str, bool]:
"""Get optional command state."""
with_registration = cast(
bool, process.get_extra_info("registration_enabled", False)
)
with_ping = cast(bool, process.get_extra_info("ping_enabled", False))
return {
"registration": with_registration,
"ping": with_ping,
}
async def get_stdin_public_key(process: asyncssh.SSHServerProcess[str]) -> str | None: async def get_stdin_public_key(process: asyncssh.SSHServerProcess[str]) -> str | None:
"""Get public key from stdin.""" """Get public key from stdin."""
process.stdout.write("Enter public key:\n") process.stdout.write("Enter public key:\n")
@ -76,6 +176,7 @@ async def get_stdin_public_key(process: asyncssh.SSHServerProcess[str]) -> str |
process.stdout.write("Invalid key. Must be RSA Public Key.\n") process.stdout.write("Invalid key. Must be RSA Public Key.\n")
except asyncssh.BreakReceived: except asyncssh.BreakReceived:
pass pass
process.stdout.write("OK\n")
return public_key return public_key
@ -90,7 +191,7 @@ def get_info_user_and_public_key(
async def register_client( async def register_client(
process: asyncssh.SSHServerProcess[str], process: asyncssh.SSHServerProcess[str],
backend: BackendClient, backend: SshecretBackend,
username: str, username: str,
) -> None: ) -> None:
"""Register a new client.""" """Register a new client."""
@ -101,35 +202,49 @@ async def register_client(
key = asyncssh.import_public_key(public_key) key = asyncssh.import_public_key(public_key)
if key.algorithm.decode() != "ssh-rsa": if key.algorithm.decode() != "ssh-rsa":
raise CommandError("Error: Only RSA keys are supported!") raise CommandError("Error: Only RSA keys are supported!")
audit_process(backend, process, "Registering new client")
LOG.debug("Registering client %s with public key %s", username, public_key) LOG.debug("Registering client %s with public key %s", username, public_key)
await backend.register_client(username, public_key) await backend.create_client(username, public_key)
async def get_secret( async def get_secret(
backend: BackendClient, backend: SshecretBackend,
client: Client, client: Client,
secret_name: str, secret_name: str,
origin: str,
) -> str: ) -> str:
"""Handle get secret requests from client.""" """Handle get secret requests from client."""
LOG.debug("Recieved command: %r", secret_name) LOG.debug("Recieved command: %r", secret_name)
if not secret_name or secret_name not in client.secrets: if not secret_name or secret_name not in client.secrets:
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET) raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
audit_event(
backend,
"Client requested secret",
operation="get_secret",
client=client,
origin=origin,
secret=secret_name,
)
# Look up secret # Look up secret
try: try:
secret = await backend.lookup_secret(client.name, secret_name) return await backend.get_client_secret(client.name, secret_name)
except Exception as exc: except Exception as exc:
LOG.debug(exc, exc_info=True) LOG.debug(exc, exc_info=True)
raise CommandError("Unexpected error from backend") from exc raise CommandError("Unexpected error from backend") from exc
return secret
async def dispatch_no_cmd(process: asyncssh.SSHServerProcess[str]) -> None: async def dispatch_no_cmd(process: asyncssh.SSHServerProcess[str]) -> None:
"""Dispatch for no command.""" """Dispatch for no command."""
raise CommandError(constants.ERROR_NO_COMMAND_RECEIVED) raise CommandError(constants.ERROR_NO_COMMAND_RECEIVED)
async def dispatch_cmd_ping(process: asyncssh.SSHServerProcess[str]) -> None:
"""Dispatch the ping command."""
process.stdout.write("PONG\n")
async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None: async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None:
"""Dispatch the register command.""" """Dispatch the register command."""
backend = get_info_backend(process) backend = get_info_backend(process)
@ -157,7 +272,8 @@ async def dispatch_cmd_get_secret(process: asyncssh.SSHServerProcess[str]) -> No
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET) raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
secret_name = args[0] secret_name = args[0]
secret = await get_secret(backend, client, secret_name) origin = get_info_remote_ip(process) or "Unknown"
secret = await get_secret(backend, client, secret_name, origin)
process.stdout.write(secret) process.stdout.write(secret)
@ -169,9 +285,14 @@ async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
process.exit(1) process.exit(1)
return return
cmdmap: dict[str, CommandDispatch] = { cmdmap: dict[str, CommandDispatch] = {
"register": dispatch_cmd_register,
"get_secret": dispatch_cmd_get_secret, "get_secret": dispatch_cmd_get_secret,
} }
extra_commands = get_optional_commands(process)
if "registration" in extra_commands:
cmdmap["register"] = dispatch_cmd_register
if "ping" in extra_commands:
cmdmap["ping"] = dispatch_cmd_ping
if command not in cmdmap: if command not in cmdmap:
process.stderr.write(constants.ERROR_UNKNOWN_COMMAND) process.stderr.write(constants.ERROR_UNKNOWN_COMMAND)
process.exit(1) process.exit(1)
@ -193,57 +314,33 @@ async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
process.exit(exit_code) process.exit(exit_code)
async def handle_secret(process: asyncssh.SSHServerProcess[str]) -> None:
"""Handle get secret requests from client."""
backend = process.get_extra_info("backend")
if not backend:
process.stderr.write("Unexpected Error: Lost connection with backend object.")
process.exit(1)
return
assert isinstance(backend, BackendClient)
client = process.get_extra_info("client")
if not client:
process.stderr.write(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
process.exit(1)
return
assert isinstance(client, Client), "Error: Unexpected client type received"
secret_name = process.command
LOG.debug("Recieved command: %r", secret_name)
if not secret_name or secret_name not in client.secrets:
process.stderr.write(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
process.exit(1)
return
# Look up secret
try:
secret = await backend.lookup_secret(client.name, secret_name)
except Exception as exc:
process.stderr.write("Unexpected error from backend:\n")
process.stderr.write(str(exc))
LOG.debug(exc, exc_info=True)
process.exit(1)
return
process.stdout.write(secret)
process.exit(0)
class AsshyncServer(asyncssh.SSHServer): class AsshyncServer(asyncssh.SSHServer):
"""Asynchronous SSH server implementation.""" """Asynchronous SSH server implementation."""
def __init__(self, settings: ServerSettings | None = None) -> None: def __init__(
self,
backend_url: str,
backend_token: str,
with_register: bool = True,
with_ping: bool = True,
) -> None:
"""Initialize server.""" """Initialize server."""
self.backend: BackendClient = BackendClient(settings) self.backend: SshecretBackend = SshecretBackend(backend_url, backend_token)
self._conn: asyncssh.SSHServerConnection | None = None self._conn: asyncssh.SSHServerConnection | None = None
self.registration_enabled: bool = with_register
self.ping_enabled: bool = with_ping
self.client_ip: str | None = None
@override @override
def connection_made(self, conn: asyncssh.SSHServerConnection) -> None: def connection_made(self, conn: asyncssh.SSHServerConnection) -> None:
"""Handle incoming connection.""" """Handle incoming connection."""
peername = conn.get_extra_info("peername") peername = conn.get_extra_info("peername")
LOG.debug("Connection established from %r", peername) LOG.debug("Connection established from %r", peername)
self.client_ip = peername[0]
self._conn = conn self._conn = conn
self._conn.set_extra_info(backend=self.backend) self._conn.set_extra_info(backend=self.backend)
self._conn.set_extra_info(registration_enabled=self.registration_enabled)
self._conn.set_extra_info(ping_enabled=self.ping_enabled)
@override @override
def password_auth_supported(self) -> bool: def password_auth_supported(self) -> bool:
@ -261,13 +358,21 @@ class AsshyncServer(asyncssh.SSHServer):
LOG.debug("Started authentication flow for user %s", username) LOG.debug("Started authentication flow for user %s", username)
if not self._conn: if not self._conn:
return True return True
if client := await self.backend.lookup_client(username): if client := await self.backend.get_client(username):
LOG.debug("Client lookup sucessful.") LOG.debug("Client lookup sucessful.")
if key := self.resolve_client_key(client): if key := self.resolve_client_key(client):
LOG.debug("Loaded public key for client %s", client.name) LOG.debug("Loaded public key for client %s", client.name)
self._conn.set_extra_info(client=client) self._conn.set_extra_info(client=client)
self._conn.set_authorized_keys(key) self._conn.set_authorized_keys(key)
else: else:
audit_event(
self.backend,
"Client denied due to policy",
"DENY",
client,
origin=self.client_ip,
)
LOG.warning("Client connection denied due to policy.") LOG.warning("Client connection denied due to policy.")
else: else:
self._conn.set_extra_info(provided_username=username) self._conn.set_extra_info(provided_username=username)
@ -308,37 +413,60 @@ class AsshyncServer(asyncssh.SSHServer):
policies = [ipaddress.ip_network(policy) for policy in client.policies] policies = [ipaddress.ip_network(policy) for policy in client.policies]
valid_source = [source_ip in policy for policy in policies] valid_source = [source_ip in policy for policy in policies]
LOG.debug("Valid sources %r from policies %r", valid_source, policies)
return any(valid_source) return any(valid_source)
def get_server_key() -> str: def get_server_key(basedir: Path | None = None) -> str:
"""Resolve server key. """Resolve server key.
TODO: Is one key enough? Should we generate more keys? TODO: Is one key enough? Should we generate more keys?
""" """
filename = f"ssh_host_{constants.SERVER_KEY_TYPE}_key" filename = Path(f"ssh_host_{constants.SERVER_KEY_TYPE}_key")
if Path(filename).exists(): if basedir:
return filename filename = basedir / filename
if filename.exists():
return str(filename.absolute())
# FIXME: There's a weird typing warning here that I need to investigate. # FIXME: There's a weird typing warning here that I need to investigate.
private_key = asyncssh.generate_private_key("ssh-ed25519", "sshecret-sshd") private_key = asyncssh.generate_private_key("ssh-ed25519", "sshecret-sshd")
with open(filename, "wb") as f: with open(filename, "wb") as f:
f.write(private_key.export_private_key()) f.write(private_key.export_private_key())
return filename return str(filename.absolute())
async def run_ssh_server(
backend_url: str,
backend_token: str,
listen_address: str,
port: int,
keys: list[str],
) -> asyncssh.SSHAcceptor:
"""Run the server."""
server = partial(
AsshyncServer, backend_url=str(backend_url), backend_token=backend_token
)
server = await asyncssh.create_server(
server,
listen_address,
port,
server_host_keys=keys,
process_factory=dispatch_command,
)
return server
async def start_server(settings: ServerSettings | None = None) -> None: async def start_server(settings: ServerSettings | None = None) -> None:
"""Start the server.""" """Start the server."""
server_key = get_server_key() server_key = get_server_key()
server = partial(AsshyncServer, settings=settings)
if not settings: if not settings:
settings = ServerSettings() # pyright: ignore[reportCallIssue] settings = ServerSettings() # pyright: ignore[reportCallIssue]
await asyncssh.create_server( await run_ssh_server(
server, str(settings.backend_url),
settings.backend_token,
settings.listen_address, settings.listen_address,
settings.port, settings.port,
server_host_keys=[server_key], [server_key],
process_factory=dispatch_command,
) )

View File

@ -1,26 +0,0 @@
"""Types."""
import uuid
from datetime import datetime
from pydantic import BaseModel, IPvAnyAddress, IPvAnyNetwork
class Client(BaseModel):
"""Implementation of the backend class ClientView."""
id: uuid.UUID
name: str
public_key: str
secrets: list[str]
policies: list[IPvAnyNetwork | IPvAnyAddress]
created_at: datetime
updated_at: datetime | None
class Secret(BaseModel):
"""Implemen tation of the backend class ClientSecretResponse."""
name: str
secret: str
created_at: datetime
updated_at: datetime | None