Complete sshd
This commit is contained in:
67
packages/sshecret-sshd/src/sshecret_sshd/backend_client.py
Normal file
67
packages/sshecret-sshd/src/sshecret_sshd/backend_client.py
Normal file
@ -0,0 +1,67 @@
|
||||
"""Backend client.
|
||||
|
||||
This is intended to be as minimal as possible.
|
||||
"""
|
||||
|
||||
import httpx
|
||||
import urllib.parse
|
||||
|
||||
from .types import Client, Secret
|
||||
from .settings import ServerSettings
|
||||
|
||||
|
||||
class BackendClient:
|
||||
"""Backend client."""
|
||||
|
||||
def __init__(self, settings: ServerSettings | None = None) -> None:
|
||||
"""Initialize backend client."""
|
||||
if not settings:
|
||||
settings = ServerSettings() # pyright: ignore[reportCallIssue]
|
||||
self.settings: ServerSettings = settings
|
||||
|
||||
@property
|
||||
def headers(self) -> dict[str, str]:
|
||||
"""Get the headers."""
|
||||
return {"X-Api-Token": self.settings.backend_token}
|
||||
|
||||
def _format_url(self, path: str) -> str:
|
||||
"""Format a URL."""
|
||||
return urllib.parse.urljoin(str(self.settings.backend_url), path)
|
||||
|
||||
async def request(self, path: str) -> httpx.Response:
|
||||
"""Send a simple GET request."""
|
||||
url = self._format_url(path)
|
||||
async with httpx.AsyncClient() as http_client:
|
||||
response = await http_client.get(url, headers=self.headers)
|
||||
return response
|
||||
|
||||
async def lookup_client(self, username: str) -> Client | None:
|
||||
"""Lookup a client on username."""
|
||||
path = f"api/v1/clients/{username}"
|
||||
response = await self.request(path)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
response.raise_for_status()
|
||||
client = Client.model_validate(response.json())
|
||||
return client
|
||||
|
||||
async def lookup_secret(self, username: str, secret_name: str) -> str:
|
||||
"""Fetch a secret."""
|
||||
path = f"api/v1/clients/{username}/secrets/{secret_name}"
|
||||
response = await self.request(path)
|
||||
response.raise_for_status()
|
||||
secret = Secret.model_validate(response.json())
|
||||
return secret.secret
|
||||
|
||||
async def register_client(self, username: str, public_key: str) -> None:
|
||||
"""Register a new client."""
|
||||
data = {
|
||||
"name": username,
|
||||
"public_key": public_key,
|
||||
}
|
||||
path = "api/v1/clients/"
|
||||
url = self._format_url(path)
|
||||
async with httpx.AsyncClient() as http_client:
|
||||
response = await http_client.post(url, headers=self.headers, json=data)
|
||||
|
||||
response.raise_for_status()
|
||||
46
packages/sshecret-sshd/src/sshecret_sshd/cli.py
Normal file
46
packages/sshecret-sshd/src/sshecret_sshd/cli.py
Normal file
@ -0,0 +1,46 @@
|
||||
"""CLI app."""
|
||||
import logging
|
||||
import asyncio
|
||||
import sys
|
||||
from pydantic_settings import CliApp
|
||||
|
||||
from .settings import ServerSettings
|
||||
from .ssh_server import start_server
|
||||
|
||||
LOG = logging.getLogger()
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter("%(created)f:%(levelname)s:%(name)s:%(module)s:%(message)s")
|
||||
|
||||
handler.setFormatter(formatter)
|
||||
LOG.addHandler(handler)
|
||||
LOG.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def cli(args: list[str] | None = None) -> None:
|
||||
"""Run CLI app."""
|
||||
try:
|
||||
settings = ServerSettings()
|
||||
except Exception:
|
||||
print("One or more settings could not be resolved.")
|
||||
CliApp.run(ServerSettings, ["--help"])
|
||||
sys.exit(1)
|
||||
|
||||
if settings.debug:
|
||||
LOG.setLevel(logging.DEBUG)
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
loop.run_until_complete(start_server(settings))
|
||||
|
||||
print(f"Starting SSH server: {settings.listen_address}:{settings.port}")
|
||||
try:
|
||||
loop.run_forever()
|
||||
except KeyboardInterrupt:
|
||||
print("\nCtrl-C received. Exiting.")
|
||||
sys.exit()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""Run CLI app."""
|
||||
cli()
|
||||
8
packages/sshecret-sshd/src/sshecret_sshd/constants.py
Normal file
8
packages/sshecret-sshd/src/sshecret_sshd/constants.py
Normal file
@ -0,0 +1,8 @@
|
||||
ERROR_NO_SECRET_FOUND = "Error: No secret available with the given name."
|
||||
ERROR_UNKNOWN_CLIENT_OR_SECRET = "Error: Invalid client or secret name."
|
||||
ERROR_NO_COMMAND_RECEIVED = "Error: No command was received from the client."
|
||||
ERROR_SOURCE_IP_NOT_ALLOWED = (
|
||||
"Error: Client not authorized to connect from the given host."
|
||||
)
|
||||
ERROR_UNKNOWN_COMMAND = "Error: The given command was not understood."
|
||||
SERVER_KEY_TYPE = "ed25519"
|
||||
18
packages/sshecret-sshd/src/sshecret_sshd/settings.py
Normal file
18
packages/sshecret-sshd/src/sshecret_sshd/settings.py
Normal file
@ -0,0 +1,18 @@
|
||||
"""SSH Server settings."""
|
||||
|
||||
from pydantic import AnyHttpUrl, Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
DEFAULT_LISTEN_PORT = 2222
|
||||
|
||||
class ServerSettings(BaseSettings, cli_parse_args=True, cli_exit_on_error=True):
|
||||
"""Server Settings."""
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".sshd.env", env_prefix="sshecret_sshd_")
|
||||
|
||||
backend_url: AnyHttpUrl = Field(alias="backend-url")
|
||||
backend_token: str = Field(alias="backend-token")
|
||||
listen_address: str = Field(default="", alias="listen")
|
||||
port: int = DEFAULT_LISTEN_PORT
|
||||
debug: bool = False
|
||||
344
packages/sshecret-sshd/src/sshecret_sshd/ssh_server.py
Normal file
344
packages/sshecret-sshd/src/sshecret_sshd/ssh_server.py
Normal file
@ -0,0 +1,344 @@
|
||||
"""SSH Server implementation."""
|
||||
|
||||
import logging
|
||||
|
||||
import asyncssh
|
||||
import ipaddress
|
||||
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Awaitable, Callable, cast, override
|
||||
|
||||
from . import constants
|
||||
|
||||
from .backend_client import BackendClient
|
||||
from .types import Client
|
||||
from .settings import ServerSettings
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
CommandDispatch = Callable[[asyncssh.SSHServerProcess[str]], Awaitable[None]]
|
||||
|
||||
|
||||
class CommandError(Exception):
|
||||
"""Error class for errors during command processing."""
|
||||
|
||||
|
||||
def verify_key_input(public_key: str) -> str | None:
|
||||
"""Verify key input."""
|
||||
try:
|
||||
key = asyncssh.import_public_key(public_key)
|
||||
if key.algorithm.decode() == "ssh-rsa":
|
||||
return public_key
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def get_process_command(
|
||||
process: asyncssh.SSHServerProcess[str],
|
||||
) -> tuple[str | None, list[str]]:
|
||||
"""Extract the process command."""
|
||||
if not process.command:
|
||||
return (None, [])
|
||||
argv = process.command.split(" ")
|
||||
return (argv[0], argv[1:])
|
||||
|
||||
|
||||
def get_info_backend(process: asyncssh.SSHServerProcess[str]) -> BackendClient | None:
|
||||
"""Get backend from process."""
|
||||
backend = cast("BackendClient | None", process.get_extra_info("backend", None))
|
||||
return backend
|
||||
|
||||
|
||||
def get_info_client(process: asyncssh.SSHServerProcess[str]) -> Client | None:
|
||||
"""Get info from process."""
|
||||
client = cast("Client | None", process.get_extra_info("client", None))
|
||||
return client
|
||||
|
||||
|
||||
def get_info_username(process: asyncssh.SSHServerProcess[str]) -> str | None:
|
||||
"""Get username from process."""
|
||||
username = cast("str | None", process.get_extra_info("provided_username", None))
|
||||
return username
|
||||
|
||||
|
||||
async def get_stdin_public_key(process: asyncssh.SSHServerProcess[str]) -> str | None:
|
||||
"""Get public key from stdin."""
|
||||
process.stdout.write("Enter public key:\n")
|
||||
public_key: str | None = None
|
||||
try:
|
||||
async for line in process.stdin:
|
||||
public_key = verify_key_input(line.rstrip("\n"))
|
||||
if public_key:
|
||||
break
|
||||
process.stdout.write("Invalid key. Must be RSA Public Key.\n")
|
||||
except asyncssh.BreakReceived:
|
||||
pass
|
||||
return public_key
|
||||
|
||||
|
||||
def get_info_user_and_public_key(
|
||||
process: asyncssh.SSHServerProcess[str],
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""Get username and public_key from process."""
|
||||
username = cast("str | None", process.get_extra_info("provided_username", None))
|
||||
public_key = cast("str | None", process.get_extra_info("provided_key", None))
|
||||
return (username, public_key)
|
||||
|
||||
|
||||
async def register_client(
|
||||
process: asyncssh.SSHServerProcess[str],
|
||||
backend: BackendClient,
|
||||
username: str,
|
||||
) -> None:
|
||||
"""Register a new client."""
|
||||
public_key = await get_stdin_public_key(process)
|
||||
if not public_key:
|
||||
raise CommandError("Aborted. No valid public key received.")
|
||||
|
||||
key = asyncssh.import_public_key(public_key)
|
||||
if key.algorithm.decode() != "ssh-rsa":
|
||||
raise CommandError("Error: Only RSA keys are supported!")
|
||||
LOG.debug("Registering client %s with public key %s", username, public_key)
|
||||
await backend.register_client(username, public_key)
|
||||
|
||||
|
||||
async def get_secret(
|
||||
backend: BackendClient,
|
||||
client: Client,
|
||||
secret_name: str,
|
||||
) -> str:
|
||||
"""Handle get secret requests from client."""
|
||||
LOG.debug("Recieved command: %r", secret_name)
|
||||
if not secret_name or secret_name not in client.secrets:
|
||||
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
|
||||
|
||||
# Look up secret
|
||||
try:
|
||||
secret = await backend.lookup_secret(client.name, secret_name)
|
||||
except Exception as exc:
|
||||
LOG.debug(exc, exc_info=True)
|
||||
raise CommandError("Unexpected error from backend") from exc
|
||||
|
||||
return secret
|
||||
|
||||
|
||||
async def dispatch_no_cmd(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||
"""Dispatch for no command."""
|
||||
raise CommandError(constants.ERROR_NO_COMMAND_RECEIVED)
|
||||
|
||||
|
||||
async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||
"""Dispatch the register command."""
|
||||
backend = get_info_backend(process)
|
||||
if not backend:
|
||||
raise CommandError("Unexpected error: Backend disappeared.")
|
||||
username = get_info_username(process)
|
||||
if not username:
|
||||
raise CommandError("Unexpected error: Username was lost.")
|
||||
await register_client(process, backend, username)
|
||||
|
||||
process.stdout.write("Client registered\n.")
|
||||
|
||||
|
||||
async def dispatch_cmd_get_secret(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||
"""Dispatch the get_secret command."""
|
||||
backend = get_info_backend(process)
|
||||
if not backend:
|
||||
raise CommandError("Unexpected error: Backend disappeared.")
|
||||
|
||||
client = get_info_client(process)
|
||||
if not client:
|
||||
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
|
||||
_cmd, args = get_process_command(process)
|
||||
if not args:
|
||||
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
|
||||
secret_name = args[0]
|
||||
|
||||
secret = await get_secret(backend, client, secret_name)
|
||||
process.stdout.write(secret)
|
||||
|
||||
|
||||
async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||
"""Dispatch command."""
|
||||
command, _args = get_process_command(process)
|
||||
if not command:
|
||||
process.stderr.write(constants.ERROR_NO_COMMAND_RECEIVED)
|
||||
process.exit(1)
|
||||
return
|
||||
cmdmap: dict[str, CommandDispatch] = {
|
||||
"register": dispatch_cmd_register,
|
||||
"get_secret": dispatch_cmd_get_secret,
|
||||
}
|
||||
if command not in cmdmap:
|
||||
process.stderr.write(constants.ERROR_UNKNOWN_COMMAND)
|
||||
process.exit(1)
|
||||
return
|
||||
exit_code = 0
|
||||
try:
|
||||
dispatcher = cmdmap[command]
|
||||
await dispatcher(process)
|
||||
except CommandError as e:
|
||||
process.stderr.write(str(e))
|
||||
exit_code = 1
|
||||
|
||||
except Exception as e:
|
||||
LOG.debug(e, exc_info=True)
|
||||
process.stderr.write("Unexpected exception:\n")
|
||||
process.stderr.write(str(e))
|
||||
exit_code = 1
|
||||
|
||||
process.exit(exit_code)
|
||||
|
||||
|
||||
async def handle_secret(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||
"""Handle get secret requests from client."""
|
||||
backend = process.get_extra_info("backend")
|
||||
if not backend:
|
||||
process.stderr.write("Unexpected Error: Lost connection with backend object.")
|
||||
process.exit(1)
|
||||
return
|
||||
|
||||
assert isinstance(backend, BackendClient)
|
||||
|
||||
client = process.get_extra_info("client")
|
||||
if not client:
|
||||
process.stderr.write(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
|
||||
process.exit(1)
|
||||
return
|
||||
assert isinstance(client, Client), "Error: Unexpected client type received"
|
||||
secret_name = process.command
|
||||
LOG.debug("Recieved command: %r", secret_name)
|
||||
if not secret_name or secret_name not in client.secrets:
|
||||
process.stderr.write(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
|
||||
process.exit(1)
|
||||
return
|
||||
|
||||
# Look up secret
|
||||
try:
|
||||
secret = await backend.lookup_secret(client.name, secret_name)
|
||||
except Exception as exc:
|
||||
process.stderr.write("Unexpected error from backend:\n")
|
||||
process.stderr.write(str(exc))
|
||||
LOG.debug(exc, exc_info=True)
|
||||
process.exit(1)
|
||||
return
|
||||
process.stdout.write(secret)
|
||||
process.exit(0)
|
||||
|
||||
|
||||
class AsshyncServer(asyncssh.SSHServer):
|
||||
"""Asynchronous SSH server implementation."""
|
||||
|
||||
def __init__(self, settings: ServerSettings | None = None) -> None:
|
||||
"""Initialize server."""
|
||||
self.backend: BackendClient = BackendClient(settings)
|
||||
self._conn: asyncssh.SSHServerConnection | None = None
|
||||
|
||||
@override
|
||||
def connection_made(self, conn: asyncssh.SSHServerConnection) -> None:
|
||||
"""Handle incoming connection."""
|
||||
peername = conn.get_extra_info("peername")
|
||||
LOG.debug("Connection established from %r", peername)
|
||||
self._conn = conn
|
||||
self._conn.set_extra_info(backend=self.backend)
|
||||
|
||||
@override
|
||||
def password_auth_supported(self) -> bool:
|
||||
"""Deny password authentication."""
|
||||
return False
|
||||
|
||||
@override
|
||||
async def begin_auth(self, username: str) -> bool:
|
||||
"""Begin authentication.
|
||||
|
||||
Note we always return True here. False bypasses the whole authentication
|
||||
flow.
|
||||
|
||||
"""
|
||||
LOG.debug("Started authentication flow for user %s", username)
|
||||
if not self._conn:
|
||||
return True
|
||||
if client := await self.backend.lookup_client(username):
|
||||
LOG.debug("Client lookup sucessful.")
|
||||
if key := self.resolve_client_key(client):
|
||||
LOG.debug("Loaded public key for client %s", client.name)
|
||||
self._conn.set_extra_info(client=client)
|
||||
self._conn.set_authorized_keys(key)
|
||||
else:
|
||||
LOG.warning("Client connection denied due to policy.")
|
||||
else:
|
||||
self._conn.set_extra_info(provided_username=username)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@override
|
||||
def validate_public_key(self, username: str, key: asyncssh.SSHKey) -> bool:
|
||||
"""Intercept public key validation."""
|
||||
if not self._conn:
|
||||
return False
|
||||
|
||||
# get an export of the provided public key.
|
||||
keystring = key.export_public_key().decode()
|
||||
self._conn.set_extra_info(provided_username=username)
|
||||
self._conn.set_extra_info(provided_key=keystring)
|
||||
LOG.debug("Intercepting user public key")
|
||||
return False
|
||||
|
||||
def resolve_client_key(self, client: Client) -> asyncssh.SSHAuthorizedKeys | None:
|
||||
"""Resolve the client key.
|
||||
|
||||
Returns the key object only if the client is allowed to connect
|
||||
according to its policy.
|
||||
"""
|
||||
if not self._conn:
|
||||
return None
|
||||
remote_ip = str(self._conn.get_extra_info("peername")[0])
|
||||
LOG.debug("Validating client %s connection from %s", client.name, remote_ip)
|
||||
if self.check_connection_allowed(client, remote_ip):
|
||||
return asyncssh.import_authorized_keys(client.public_key)
|
||||
return None
|
||||
|
||||
def check_connection_allowed(self, client: Client, source: str) -> bool:
|
||||
"""Check if the client is allowed to connect."""
|
||||
source_ip = ipaddress.ip_address(source)
|
||||
policies = [ipaddress.ip_network(policy) for policy in client.policies]
|
||||
|
||||
valid_source = [source_ip in policy for policy in policies]
|
||||
return any(valid_source)
|
||||
|
||||
|
||||
def get_server_key() -> str:
|
||||
"""Resolve server key.
|
||||
|
||||
TODO: Is one key enough? Should we generate more keys?
|
||||
"""
|
||||
filename = f"ssh_host_{constants.SERVER_KEY_TYPE}_key"
|
||||
if Path(filename).exists():
|
||||
return filename
|
||||
# FIXME: There's a weird typing warning here that I need to investigate.
|
||||
private_key = asyncssh.generate_private_key("ssh-ed25519", "sshecret-sshd")
|
||||
with open(filename, "wb") as f:
|
||||
f.write(private_key.export_private_key())
|
||||
|
||||
return filename
|
||||
|
||||
|
||||
async def start_server(settings: ServerSettings | None = None) -> None:
|
||||
"""Start the server."""
|
||||
server_key = get_server_key()
|
||||
server = partial(AsshyncServer, settings=settings)
|
||||
|
||||
if not settings:
|
||||
settings = ServerSettings() # pyright: ignore[reportCallIssue]
|
||||
|
||||
await asyncssh.create_server(
|
||||
server,
|
||||
settings.listen_address,
|
||||
settings.port,
|
||||
server_host_keys=[server_key],
|
||||
process_factory=dispatch_command,
|
||||
)
|
||||
26
packages/sshecret-sshd/src/sshecret_sshd/types.py
Normal file
26
packages/sshecret-sshd/src/sshecret_sshd/types.py
Normal file
@ -0,0 +1,26 @@
|
||||
"""Types."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, IPvAnyAddress, IPvAnyNetwork
|
||||
|
||||
|
||||
class Client(BaseModel):
|
||||
"""Implementation of the backend class ClientView."""
|
||||
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
public_key: str
|
||||
secrets: list[str]
|
||||
policies: list[IPvAnyNetwork | IPvAnyAddress]
|
||||
created_at: datetime
|
||||
updated_at: datetime | None
|
||||
|
||||
|
||||
class Secret(BaseModel):
|
||||
"""Implemen tation of the backend class ClientSecretResponse."""
|
||||
|
||||
name: str
|
||||
secret: str
|
||||
created_at: datetime
|
||||
updated_at: datetime | None
|
||||
Reference in New Issue
Block a user