|
|
|
|
@ -1,18 +1,20 @@
|
|
|
|
|
"""SSH Server implementation."""
|
|
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
import uuid
|
|
|
|
|
|
|
|
|
|
import asyncssh
|
|
|
|
|
import ipaddress
|
|
|
|
|
|
|
|
|
|
from collections.abc import Awaitable
|
|
|
|
|
from datetime import datetime, timezone
|
|
|
|
|
from functools import partial
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import Awaitable, Callable, cast, override
|
|
|
|
|
from typing import Any, Callable, cast, override
|
|
|
|
|
|
|
|
|
|
from . import constants
|
|
|
|
|
|
|
|
|
|
from .backend_client import BackendClient
|
|
|
|
|
from .types import Client
|
|
|
|
|
from sshecret.backend import AuditLog, SshecretBackend, Client
|
|
|
|
|
from .settings import ServerSettings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -21,11 +23,84 @@ LOG = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
CommandDispatch = Callable[[asyncssh.SSHServerProcess[str]], Awaitable[None]]
|
|
|
|
|
|
|
|
|
|
PeernameV4 = tuple[str, int]
|
|
|
|
|
PeernameV6 = tuple[str, int, int, int]
|
|
|
|
|
Peername = PeernameV4 | PeernameV6
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CommandError(Exception):
|
|
|
|
|
"""Error class for errors during command processing."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def audit_process(
|
|
|
|
|
backend: SshecretBackend,
|
|
|
|
|
process: asyncssh.SSHServerProcess[str],
|
|
|
|
|
message: str,
|
|
|
|
|
secret: str | None = None,
|
|
|
|
|
) -> None:
|
|
|
|
|
"""Add an audit event from process."""
|
|
|
|
|
command = get_process_command(process)
|
|
|
|
|
client = get_info_client(process)
|
|
|
|
|
username = get_info_username(process)
|
|
|
|
|
remote_ip = get_info_remote_ip(process)
|
|
|
|
|
operation = "SSH_EVENT"
|
|
|
|
|
obj_name: str | None = None
|
|
|
|
|
obj_id: str | None = None
|
|
|
|
|
|
|
|
|
|
if command and not secret:
|
|
|
|
|
cmd, cmd_args = command
|
|
|
|
|
obj_id = " ".join(cmd_args)
|
|
|
|
|
elif secret:
|
|
|
|
|
obj_name = "ClientSecret"
|
|
|
|
|
obj_id = secret
|
|
|
|
|
|
|
|
|
|
entry = AuditLog(
|
|
|
|
|
subsystem="ssh",
|
|
|
|
|
operation=operation,
|
|
|
|
|
object=obj_name,
|
|
|
|
|
object_id=obj_id,
|
|
|
|
|
message=message,
|
|
|
|
|
origin=remote_ip,
|
|
|
|
|
)
|
|
|
|
|
if client:
|
|
|
|
|
entry.client_id = str(client.id)
|
|
|
|
|
entry.client_name = client.name
|
|
|
|
|
elif username:
|
|
|
|
|
entry.client_name = username
|
|
|
|
|
|
|
|
|
|
backend.add_audit_log_sync(entry)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def audit_event(
|
|
|
|
|
backend: SshecretBackend,
|
|
|
|
|
message: str,
|
|
|
|
|
operation: str = "SSH_EVENT",
|
|
|
|
|
client: Client | None = None,
|
|
|
|
|
origin: str | None = None,
|
|
|
|
|
secret: str | None = None,
|
|
|
|
|
) -> None:
|
|
|
|
|
"""Add an audit event."""
|
|
|
|
|
entry = AuditLog(
|
|
|
|
|
client_id=None,
|
|
|
|
|
client_name=None,
|
|
|
|
|
object=None,
|
|
|
|
|
object_id=None,
|
|
|
|
|
subsystem="ssh",
|
|
|
|
|
operation=operation,
|
|
|
|
|
message=message,
|
|
|
|
|
origin=origin,
|
|
|
|
|
)
|
|
|
|
|
if client:
|
|
|
|
|
entry.client_id = str(client.id)
|
|
|
|
|
entry.client_name = client.name
|
|
|
|
|
|
|
|
|
|
if secret:
|
|
|
|
|
entry.object = "ClientSecret"
|
|
|
|
|
entry.object_id = secret
|
|
|
|
|
|
|
|
|
|
backend.add_audit_log_sync(entry)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def verify_key_input(public_key: str) -> str | None:
|
|
|
|
|
"""Verify key input."""
|
|
|
|
|
try:
|
|
|
|
|
@ -46,9 +121,9 @@ def get_process_command(
|
|
|
|
|
return (argv[0], argv[1:])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_info_backend(process: asyncssh.SSHServerProcess[str]) -> BackendClient | None:
|
|
|
|
|
def get_info_backend(process: asyncssh.SSHServerProcess[str]) -> SshecretBackend | None:
|
|
|
|
|
"""Get backend from process."""
|
|
|
|
|
backend = cast("BackendClient | None", process.get_extra_info("backend", None))
|
|
|
|
|
backend = cast("SshecretBackend | None", process.get_extra_info("backend", None))
|
|
|
|
|
return backend
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -64,6 +139,31 @@ def get_info_username(process: asyncssh.SSHServerProcess[str]) -> str | None:
|
|
|
|
|
return username
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_info_remote_ip(process: asyncssh.SSHServerProcess[str]) -> str | None:
|
|
|
|
|
"""Get remote IP."""
|
|
|
|
|
|
|
|
|
|
peername = cast("Peername | None", process.get_extra_info("peername", None))
|
|
|
|
|
remote_ip: str | None = None
|
|
|
|
|
if peername:
|
|
|
|
|
remote_ip = peername[0]
|
|
|
|
|
|
|
|
|
|
return remote_ip
|
|
|
|
|
|
|
|
|
|
# remote_ip = str(self._conn.get_extra_info("peername")[0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_optional_commands(process: asyncssh.SSHServerProcess[str]) -> dict[str, bool]:
|
|
|
|
|
"""Get optional command state."""
|
|
|
|
|
with_registration = cast(
|
|
|
|
|
bool, process.get_extra_info("registration_enabled", False)
|
|
|
|
|
)
|
|
|
|
|
with_ping = cast(bool, process.get_extra_info("ping_enabled", False))
|
|
|
|
|
return {
|
|
|
|
|
"registration": with_registration,
|
|
|
|
|
"ping": with_ping,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_stdin_public_key(process: asyncssh.SSHServerProcess[str]) -> str | None:
|
|
|
|
|
"""Get public key from stdin."""
|
|
|
|
|
process.stdout.write("Enter public key:\n")
|
|
|
|
|
@ -76,6 +176,7 @@ async def get_stdin_public_key(process: asyncssh.SSHServerProcess[str]) -> str |
|
|
|
|
|
process.stdout.write("Invalid key. Must be RSA Public Key.\n")
|
|
|
|
|
except asyncssh.BreakReceived:
|
|
|
|
|
pass
|
|
|
|
|
process.stdout.write("OK\n")
|
|
|
|
|
return public_key
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -90,7 +191,7 @@ def get_info_user_and_public_key(
|
|
|
|
|
|
|
|
|
|
async def register_client(
|
|
|
|
|
process: asyncssh.SSHServerProcess[str],
|
|
|
|
|
backend: BackendClient,
|
|
|
|
|
backend: SshecretBackend,
|
|
|
|
|
username: str,
|
|
|
|
|
) -> None:
|
|
|
|
|
"""Register a new client."""
|
|
|
|
|
@ -101,35 +202,49 @@ async def register_client(
|
|
|
|
|
key = asyncssh.import_public_key(public_key)
|
|
|
|
|
if key.algorithm.decode() != "ssh-rsa":
|
|
|
|
|
raise CommandError("Error: Only RSA keys are supported!")
|
|
|
|
|
audit_process(backend, process, "Registering new client")
|
|
|
|
|
LOG.debug("Registering client %s with public key %s", username, public_key)
|
|
|
|
|
await backend.register_client(username, public_key)
|
|
|
|
|
await backend.create_client(username, public_key)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_secret(
|
|
|
|
|
backend: BackendClient,
|
|
|
|
|
backend: SshecretBackend,
|
|
|
|
|
client: Client,
|
|
|
|
|
secret_name: str,
|
|
|
|
|
origin: str,
|
|
|
|
|
) -> str:
|
|
|
|
|
"""Handle get secret requests from client."""
|
|
|
|
|
LOG.debug("Recieved command: %r", secret_name)
|
|
|
|
|
if not secret_name or secret_name not in client.secrets:
|
|
|
|
|
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
|
|
|
|
|
|
|
|
|
|
audit_event(
|
|
|
|
|
backend,
|
|
|
|
|
"Client requested secret",
|
|
|
|
|
operation="get_secret",
|
|
|
|
|
client=client,
|
|
|
|
|
origin=origin,
|
|
|
|
|
secret=secret_name,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Look up secret
|
|
|
|
|
try:
|
|
|
|
|
secret = await backend.lookup_secret(client.name, secret_name)
|
|
|
|
|
return await backend.get_client_secret(client.name, secret_name)
|
|
|
|
|
except Exception as exc:
|
|
|
|
|
LOG.debug(exc, exc_info=True)
|
|
|
|
|
raise CommandError("Unexpected error from backend") from exc
|
|
|
|
|
|
|
|
|
|
return secret
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def dispatch_no_cmd(process: asyncssh.SSHServerProcess[str]) -> None:
|
|
|
|
|
"""Dispatch for no command."""
|
|
|
|
|
raise CommandError(constants.ERROR_NO_COMMAND_RECEIVED)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def dispatch_cmd_ping(process: asyncssh.SSHServerProcess[str]) -> None:
|
|
|
|
|
"""Dispatch the ping command."""
|
|
|
|
|
process.stdout.write("PONG\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None:
|
|
|
|
|
"""Dispatch the register command."""
|
|
|
|
|
backend = get_info_backend(process)
|
|
|
|
|
@ -157,7 +272,8 @@ async def dispatch_cmd_get_secret(process: asyncssh.SSHServerProcess[str]) -> No
|
|
|
|
|
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
|
|
|
|
|
secret_name = args[0]
|
|
|
|
|
|
|
|
|
|
secret = await get_secret(backend, client, secret_name)
|
|
|
|
|
origin = get_info_remote_ip(process) or "Unknown"
|
|
|
|
|
secret = await get_secret(backend, client, secret_name, origin)
|
|
|
|
|
process.stdout.write(secret)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -169,9 +285,14 @@ async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
|
|
|
|
|
process.exit(1)
|
|
|
|
|
return
|
|
|
|
|
cmdmap: dict[str, CommandDispatch] = {
|
|
|
|
|
"register": dispatch_cmd_register,
|
|
|
|
|
"get_secret": dispatch_cmd_get_secret,
|
|
|
|
|
}
|
|
|
|
|
extra_commands = get_optional_commands(process)
|
|
|
|
|
if "registration" in extra_commands:
|
|
|
|
|
cmdmap["register"] = dispatch_cmd_register
|
|
|
|
|
if "ping" in extra_commands:
|
|
|
|
|
cmdmap["ping"] = dispatch_cmd_ping
|
|
|
|
|
|
|
|
|
|
if command not in cmdmap:
|
|
|
|
|
process.stderr.write(constants.ERROR_UNKNOWN_COMMAND)
|
|
|
|
|
process.exit(1)
|
|
|
|
|
@ -193,57 +314,33 @@ async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
|
|
|
|
|
process.exit(exit_code)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def handle_secret(process: asyncssh.SSHServerProcess[str]) -> None:
|
|
|
|
|
"""Handle get secret requests from client."""
|
|
|
|
|
backend = process.get_extra_info("backend")
|
|
|
|
|
if not backend:
|
|
|
|
|
process.stderr.write("Unexpected Error: Lost connection with backend object.")
|
|
|
|
|
process.exit(1)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
assert isinstance(backend, BackendClient)
|
|
|
|
|
|
|
|
|
|
client = process.get_extra_info("client")
|
|
|
|
|
if not client:
|
|
|
|
|
process.stderr.write(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
|
|
|
|
|
process.exit(1)
|
|
|
|
|
return
|
|
|
|
|
assert isinstance(client, Client), "Error: Unexpected client type received"
|
|
|
|
|
secret_name = process.command
|
|
|
|
|
LOG.debug("Recieved command: %r", secret_name)
|
|
|
|
|
if not secret_name or secret_name not in client.secrets:
|
|
|
|
|
process.stderr.write(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
|
|
|
|
|
process.exit(1)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# Look up secret
|
|
|
|
|
try:
|
|
|
|
|
secret = await backend.lookup_secret(client.name, secret_name)
|
|
|
|
|
except Exception as exc:
|
|
|
|
|
process.stderr.write("Unexpected error from backend:\n")
|
|
|
|
|
process.stderr.write(str(exc))
|
|
|
|
|
LOG.debug(exc, exc_info=True)
|
|
|
|
|
process.exit(1)
|
|
|
|
|
return
|
|
|
|
|
process.stdout.write(secret)
|
|
|
|
|
process.exit(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AsshyncServer(asyncssh.SSHServer):
|
|
|
|
|
"""Asynchronous SSH server implementation."""
|
|
|
|
|
|
|
|
|
|
def __init__(self, settings: ServerSettings | None = None) -> None:
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
backend_url: str,
|
|
|
|
|
backend_token: str,
|
|
|
|
|
with_register: bool = True,
|
|
|
|
|
with_ping: bool = True,
|
|
|
|
|
) -> None:
|
|
|
|
|
"""Initialize server."""
|
|
|
|
|
self.backend: BackendClient = BackendClient(settings)
|
|
|
|
|
self.backend: SshecretBackend = SshecretBackend(backend_url, backend_token)
|
|
|
|
|
self._conn: asyncssh.SSHServerConnection | None = None
|
|
|
|
|
self.registration_enabled: bool = with_register
|
|
|
|
|
self.ping_enabled: bool = with_ping
|
|
|
|
|
self.client_ip: str | None = None
|
|
|
|
|
|
|
|
|
|
@override
|
|
|
|
|
def connection_made(self, conn: asyncssh.SSHServerConnection) -> None:
|
|
|
|
|
"""Handle incoming connection."""
|
|
|
|
|
peername = conn.get_extra_info("peername")
|
|
|
|
|
LOG.debug("Connection established from %r", peername)
|
|
|
|
|
self.client_ip = peername[0]
|
|
|
|
|
self._conn = conn
|
|
|
|
|
self._conn.set_extra_info(backend=self.backend)
|
|
|
|
|
self._conn.set_extra_info(registration_enabled=self.registration_enabled)
|
|
|
|
|
self._conn.set_extra_info(ping_enabled=self.ping_enabled)
|
|
|
|
|
|
|
|
|
|
@override
|
|
|
|
|
def password_auth_supported(self) -> bool:
|
|
|
|
|
@ -261,13 +358,21 @@ class AsshyncServer(asyncssh.SSHServer):
|
|
|
|
|
LOG.debug("Started authentication flow for user %s", username)
|
|
|
|
|
if not self._conn:
|
|
|
|
|
return True
|
|
|
|
|
if client := await self.backend.lookup_client(username):
|
|
|
|
|
if client := await self.backend.get_client(username):
|
|
|
|
|
LOG.debug("Client lookup sucessful.")
|
|
|
|
|
if key := self.resolve_client_key(client):
|
|
|
|
|
LOG.debug("Loaded public key for client %s", client.name)
|
|
|
|
|
self._conn.set_extra_info(client=client)
|
|
|
|
|
self._conn.set_authorized_keys(key)
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
|
|
audit_event(
|
|
|
|
|
self.backend,
|
|
|
|
|
"Client denied due to policy",
|
|
|
|
|
"DENY",
|
|
|
|
|
client,
|
|
|
|
|
origin=self.client_ip,
|
|
|
|
|
)
|
|
|
|
|
LOG.warning("Client connection denied due to policy.")
|
|
|
|
|
else:
|
|
|
|
|
self._conn.set_extra_info(provided_username=username)
|
|
|
|
|
@ -308,37 +413,60 @@ class AsshyncServer(asyncssh.SSHServer):
|
|
|
|
|
policies = [ipaddress.ip_network(policy) for policy in client.policies]
|
|
|
|
|
|
|
|
|
|
valid_source = [source_ip in policy for policy in policies]
|
|
|
|
|
LOG.debug("Valid sources %r from policies %r", valid_source, policies)
|
|
|
|
|
return any(valid_source)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_server_key() -> str:
|
|
|
|
|
def get_server_key(basedir: Path | None = None) -> str:
|
|
|
|
|
"""Resolve server key.
|
|
|
|
|
|
|
|
|
|
TODO: Is one key enough? Should we generate more keys?
|
|
|
|
|
"""
|
|
|
|
|
filename = f"ssh_host_{constants.SERVER_KEY_TYPE}_key"
|
|
|
|
|
if Path(filename).exists():
|
|
|
|
|
return filename
|
|
|
|
|
filename = Path(f"ssh_host_{constants.SERVER_KEY_TYPE}_key")
|
|
|
|
|
if basedir:
|
|
|
|
|
filename = basedir / filename
|
|
|
|
|
if filename.exists():
|
|
|
|
|
return str(filename.absolute())
|
|
|
|
|
# FIXME: There's a weird typing warning here that I need to investigate.
|
|
|
|
|
private_key = asyncssh.generate_private_key("ssh-ed25519", "sshecret-sshd")
|
|
|
|
|
with open(filename, "wb") as f:
|
|
|
|
|
f.write(private_key.export_private_key())
|
|
|
|
|
|
|
|
|
|
return filename
|
|
|
|
|
return str(filename.absolute())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def run_ssh_server(
|
|
|
|
|
backend_url: str,
|
|
|
|
|
backend_token: str,
|
|
|
|
|
listen_address: str,
|
|
|
|
|
port: int,
|
|
|
|
|
keys: list[str],
|
|
|
|
|
) -> asyncssh.SSHAcceptor:
|
|
|
|
|
"""Run the server."""
|
|
|
|
|
server = partial(
|
|
|
|
|
AsshyncServer, backend_url=str(backend_url), backend_token=backend_token
|
|
|
|
|
)
|
|
|
|
|
server = await asyncssh.create_server(
|
|
|
|
|
server,
|
|
|
|
|
listen_address,
|
|
|
|
|
port,
|
|
|
|
|
server_host_keys=keys,
|
|
|
|
|
process_factory=dispatch_command,
|
|
|
|
|
)
|
|
|
|
|
return server
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def start_server(settings: ServerSettings | None = None) -> None:
|
|
|
|
|
"""Start the server."""
|
|
|
|
|
server_key = get_server_key()
|
|
|
|
|
server = partial(AsshyncServer, settings=settings)
|
|
|
|
|
|
|
|
|
|
if not settings:
|
|
|
|
|
settings = ServerSettings() # pyright: ignore[reportCallIssue]
|
|
|
|
|
|
|
|
|
|
await asyncssh.create_server(
|
|
|
|
|
server,
|
|
|
|
|
await run_ssh_server(
|
|
|
|
|
str(settings.backend_url),
|
|
|
|
|
settings.backend_token,
|
|
|
|
|
settings.listen_address,
|
|
|
|
|
settings.port,
|
|
|
|
|
server_host_keys=[server_key],
|
|
|
|
|
process_factory=dispatch_command,
|
|
|
|
|
[server_key],
|
|
|
|
|
)
|
|
|
|
|
|