Complete sshd package with tests

This commit is contained in:
2025-05-10 08:27:16 +02:00
parent 3719a2611d
commit 4f970a3f71
12 changed files with 472 additions and 103 deletions

View File

@ -9,10 +9,16 @@ authors = [
requires-python = ">=3.13" requires-python = ">=3.13"
dependencies = [ dependencies = [
"asyncssh>=2.20.0", "asyncssh>=2.20.0",
"click>=8.1.8",
"httpx>=0.28.1", "httpx>=0.28.1",
"pydantic>=2.10.6",
"python-dotenv>=1.0.1", "python-dotenv>=1.0.1",
"sshecret",
] ]
[tool.uv.sources]
sshecret = { workspace = true }
[project.scripts] [project.scripts]
sshecret-sshd = "sshecret_sshd.cli:cli" sshecret-sshd = "sshecret_sshd.cli:cli"

View File

@ -0,0 +1,2 @@
[pytest]
asyncio_mode = auto

View File

@ -2,8 +2,10 @@
import logging import logging
import asyncio import asyncio
import sys import sys
from pydantic_settings import CliApp from typing import cast
import click
from pydantic import ValidationError
from .settings import ServerSettings from .settings import ServerSettings
from .ssh_server import start_server from .ssh_server import start_server
@ -17,22 +19,38 @@ LOG.addHandler(handler)
LOG.setLevel(logging.INFO) LOG.setLevel(logging.INFO)
def cli(args: list[str] | None = None) -> None: @click.group()
"""Run CLI app.""" @click.option("--debug", is_flag=True)
try: @click.pass_context
settings = ServerSettings() def cli(ctx: click.Context, debug: bool) -> None:
except Exception: """Sshecret Admin."""
print("One or more settings could not be resolved.") if debug:
CliApp.run(ServerSettings, ["--help"])
sys.exit(1)
if settings.debug:
LOG.setLevel(logging.DEBUG) LOG.setLevel(logging.DEBUG)
try:
settings = ServerSettings() # pyright: ignore[reportCallIssue]
except ValidationError as e:
raise click.ClickException(
"Error: One or more required environment options are missing."
) from e
ctx.obj = settings
@cli.command("run")
@click.option("--host")
@click.option("--port", type=click.INT)
@click.pass_context
def cli_run(ctx: click.Context, host: str | None, port: int | None) -> None:
"""Run the server."""
settings = cast(ServerSettings, ctx.obj)
if host:
settings.listen_address = host
if port:
settings.port = port
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
loop.run_until_complete(start_server(settings)) loop.run_until_complete(start_server(settings))
title = click.style("Sshecret SSH Daemon", fg="red", bold=True)
print(f"Starting SSH server: {settings.listen_address}:{settings.port}") click.echo(f"Starting {title}: {settings.listen_address}:{settings.port}")
try: try:
loop.run_forever() loop.run_forever()
except KeyboardInterrupt: except KeyboardInterrupt:
@ -40,7 +58,6 @@ def cli(args: list[str] | None = None) -> None:
sys.exit() sys.exit()
if __name__ == "__main__": if __name__ == "__main__":
"""Run CLI app.""" """Run CLI app."""
cli() cli()

View File

@ -4,5 +4,11 @@ ERROR_NO_COMMAND_RECEIVED = "Error: No command was received from the client."
ERROR_SOURCE_IP_NOT_ALLOWED = ( ERROR_SOURCE_IP_NOT_ALLOWED = (
"Error: Client not authorized to connect from the given host." "Error: Client not authorized to connect from the given host."
) )
ERROR_NO_PUBLIC_KEY = "Error: No valid public key received."
ERROR_INVALID_KEY_TYPE = "Error: Invalid key type: Only RSA keys are supported."
ERROR_UNKNOWN_COMMAND = "Error: The given command was not understood." ERROR_UNKNOWN_COMMAND = "Error: The given command was not understood."
SERVER_KEY_TYPE = "ed25519" SERVER_KEY_TYPE = "ed25519"
ERROR_BACKEND_ERROR = "Error: Unexpected response or error from backend"
ERROR_INFO_BACKEND_GONE = "Unexpected error: Backend connection lost."
ERROR_INFO_USERNAME_GONE = "Unexpected error: Username lost in transit."
ERROR_INFO_REMOTE_IP_GONE = "Unexpected error: Client connection details lost in transit."

View File

@ -1,18 +1,48 @@
"""SSH Server settings.""" """SSH Server settings."""
from pydantic import AnyHttpUrl, Field, AliasChoices import ipaddress
from pydantic_settings import BaseSettings, SettingsConfigDict from typing import Annotated, Any
from pydantic import AnyHttpUrl, BaseModel, Field, IPvAnyNetwork, field_validator
from pydantic_settings import BaseSettings, ForceDecode, SettingsConfigDict
DEFAULT_LISTEN_PORT = 2222 DEFAULT_LISTEN_PORT = 2222
class ServerSettings(BaseSettings, cli_parse_args=True, cli_exit_on_error=True):
class ClientRegistrationSettings(BaseModel):
"""Client registration settings."""
enabled: bool = False
allow_from: Annotated[list[IPvAnyNetwork], ForceDecode] = Field(default_factory=list)
@field_validator('allow_from', mode="before")
@classmethod
def ensure_allow_from_list(cls, value: Any) -> list[IPvAnyNetwork]:
"""Convert allow_from to a list."""
allow_from: list[IPvAnyNetwork] = []
if isinstance(value, list):
entries = value
elif isinstance(value, str):
entries = value.split(",")
else:
raise ValueError("Error: Unknown format for allowed_from.")
for entry in entries:
if isinstance(entry, str):
allow_from.append(ipaddress.ip_network(entry))
elif isinstance(entry, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
allow_from.append(entry)
return allow_from
class ServerSettings(BaseSettings):
"""Server Settings.""" """Server Settings."""
model_config = SettingsConfigDict(env_file=".sshd.env", env_prefix="sshecret_sshd_") model_config = SettingsConfigDict(env_file=".sshd.env", env_prefix="sshecret_sshd_", env_nested_delimiter='_')
backend_url: AnyHttpUrl = Field(validation_alias=AliasChoices("backend-url", "sshecret_backend_url")) backend_url: AnyHttpUrl = Field(alias="sshecret_backend_url")
backend_token: str = Field(validation_alias=AliasChoices("backend-token", "sshecret_sshd_backend_token")) backend_token: str
listen_address: str = Field(default="", alias="listen") listen_address: str = Field(default="127.0.0.1")
port: int = DEFAULT_LISTEN_PORT port: int = DEFAULT_LISTEN_PORT
registration: ClientRegistrationSettings = Field(default_factory=ClientRegistrationSettings)
debug: bool = False debug: bool = False
enable_ping_command: bool = False

View File

@ -1,21 +1,21 @@
"""SSH Server implementation.""" """SSH Server implementation."""
import logging import logging
import uuid
import asyncssh import asyncssh
import ipaddress import ipaddress
from collections.abc import Awaitable from collections.abc import Awaitable
from datetime import datetime, timezone
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from typing import Any, Callable, cast, override from typing import Callable, cast, override
from pydantic import IPvAnyNetwork
from . import constants from . import constants
from sshecret.backend import AuditLog, SshecretBackend, Client from sshecret.backend import SshecretBackend, Client, Operation, SubSystem
from .settings import ServerSettings from .settings import ServerSettings, ClientRegistrationSettings
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -35,71 +35,39 @@ class CommandError(Exception):
def audit_process( def audit_process(
backend: SshecretBackend, backend: SshecretBackend,
process: asyncssh.SSHServerProcess[str], process: asyncssh.SSHServerProcess[str],
operation: Operation,
message: str, message: str,
secret: str | None = None, secret: str | None = None,
**data: str,
) -> None: ) -> None:
"""Add an audit event from process.""" """Add an audit event from process."""
command = get_process_command(process) command = get_process_command(process)
client = get_info_client(process) client = get_info_client(process)
username = get_info_username(process) username = get_info_username(process)
remote_ip = get_info_remote_ip(process) remote_ip = get_info_remote_ip(process) or "UNKNOWN"
operation = "SSH_EVENT" if username:
obj_name: str | None = None data["username"] = username
obj_id: str | None = None
if command and not secret: if command and not secret:
cmd, cmd_args = command cmd, cmd_args = command
obj_id = " ".join(cmd_args) if cmd:
elif secret: data["command"] = cmd
obj_name = "ClientSecret" data["args"] = " ".join(cmd_args)
obj_id = secret
entry = AuditLog(
subsystem="ssh",
operation=operation,
object=obj_name,
object_id=obj_id,
message=message,
origin=remote_ip,
)
if client:
entry.client_id = str(client.id)
entry.client_name = client.name
elif username:
entry.client_name = username
backend.add_audit_log_sync(entry)
backend.audit(SubSystem.SSHD).write(operation, message, remote_ip, client, secret=None, secret_name=secret, **data)
def audit_event( def audit_event(
backend: SshecretBackend, backend: SshecretBackend,
message: str, message: str,
operation: str = "SSH_EVENT", operation: Operation,
client: Client | None = None, client: Client | None = None,
origin: str | None = None, origin: str | None = None,
secret: str | None = None, secret: str | None = None,
) -> None: ) -> None:
"""Add an audit event.""" """Add an audit event."""
entry = AuditLog( if not origin:
client_id=None, origin = "UNKNOWN"
client_name=None, backend.audit(SubSystem.SSHD).write(operation, message, origin, client, secret=None, secret_name=secret)
object=None,
object_id=None,
subsystem="ssh",
operation=operation,
message=message,
origin=origin,
)
if client:
entry.client_id = str(client.id)
entry.client_name = client.name
if secret:
entry.object = "ClientSecret"
entry.object_id = secret
backend.add_audit_log_sync(entry)
def verify_key_input(public_key: str) -> str | None: def verify_key_input(public_key: str) -> str | None:
"""Verify key input.""" """Verify key input."""
@ -118,6 +86,7 @@ def get_process_command(
if not process.command: if not process.command:
return (None, []) return (None, [])
argv = process.command.split(" ") argv = process.command.split(" ")
LOG.debug("Args: %r", argv)
return (argv[0], argv[1:]) return (argv[0], argv[1:])
@ -149,7 +118,12 @@ def get_info_remote_ip(process: asyncssh.SSHServerProcess[str]) -> str | None:
return remote_ip return remote_ip
# remote_ip = str(self._conn.get_extra_info("peername")[0]) def get_info_allowed_registration(process: asyncssh.SSHServerProcess[str]) -> list[IPvAnyNetwork] | None:
"""Get allowed networks to allow registration from."""
allowed_registration = cast(list[IPvAnyNetwork] | None, process.get_extra_info("allow_registration_from", None))
return allowed_registration
def get_optional_commands(process: asyncssh.SSHServerProcess[str]) -> dict[str, bool]: def get_optional_commands(process: asyncssh.SSHServerProcess[str]) -> dict[str, bool]:
@ -197,12 +171,12 @@ async def register_client(
"""Register a new client.""" """Register a new client."""
public_key = await get_stdin_public_key(process) public_key = await get_stdin_public_key(process)
if not public_key: if not public_key:
raise CommandError("Aborted. No valid public key received.") raise CommandError(constants.ERROR_NO_PUBLIC_KEY)
key = asyncssh.import_public_key(public_key) key = asyncssh.import_public_key(public_key)
if key.algorithm.decode() != "ssh-rsa": if key.algorithm.decode() != "ssh-rsa":
raise CommandError("Error: Only RSA keys are supported!") raise CommandError(constants.ERROR_INVALID_KEY_TYPE)
audit_process(backend, process, "Registering new client") audit_process(backend, process, Operation.CREATE, "Registering new client")
LOG.debug("Registering client %s with public key %s", username, public_key) LOG.debug("Registering client %s with public key %s", username, public_key)
await backend.create_client(username, public_key) await backend.create_client(username, public_key)
@ -214,14 +188,16 @@ async def get_secret(
origin: str, origin: str,
) -> str: ) -> str:
"""Handle get secret requests from client.""" """Handle get secret requests from client."""
LOG.debug("Recieved command: %r", secret_name) LOG.debug("Recieved command: get_secret %r", secret_name)
if not secret_name or secret_name not in client.secrets: if not secret_name:
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET) raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
if secret_name not in client.secrets:
raise CommandError(constants.ERROR_NO_SECRET_FOUND)
audit_event( audit_event(
backend, backend,
"Client requested secret", "Client requested secret",
operation="get_secret", operation=Operation.READ,
client=client, client=client,
origin=origin, origin=origin,
secret=secret_name, secret=secret_name,
@ -229,10 +205,13 @@ async def get_secret(
# Look up secret # Look up secret
try: try:
return await backend.get_client_secret(client.name, secret_name) secret = await backend.get_client_secret(client.name, secret_name)
if not secret:
raise CommandError(constants.ERROR_NO_SECRET_FOUND)
return secret
except Exception as exc: except Exception as exc:
LOG.debug(exc, exc_info=True) LOG.debug(exc, exc_info=True)
raise CommandError("Unexpected error from backend") from exc raise CommandError(constants.ERROR_BACKEND_ERROR) from exc
async def dispatch_no_cmd(process: asyncssh.SSHServerProcess[str]) -> None: async def dispatch_no_cmd(process: asyncssh.SSHServerProcess[str]) -> None:
@ -249,10 +228,32 @@ async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None
"""Dispatch the register command.""" """Dispatch the register command."""
backend = get_info_backend(process) backend = get_info_backend(process)
if not backend: if not backend:
raise CommandError("Unexpected error: Backend disappeared.") raise CommandError(constants.ERROR_INFO_BACKEND_GONE)
username = get_info_username(process) username = get_info_username(process)
if not username: if not username:
raise CommandError("Unexpected error: Username was lost.") raise CommandError(constants.ERROR_INFO_USERNAME_GONE)
allowed_networks = get_info_allowed_registration(process)
if not allowed_networks:
process.stdout.write("Unauthorized.\n")
audit_process(backend, process, Operation.DENY, "Received registration command, but no subnets are allowed.")
return
remote_ip = get_info_remote_ip(process)
if not remote_ip:
raise CommandError(constants.ERROR_INFO_REMOTE_IP_GONE)
client_address = ipaddress.ip_address(remote_ip)
for network in allowed_networks:
if client_address in network:
break
else:
audit_process(backend, process, Operation.DENY, "Received registration command from unauthorized subnet.")
process.stdout.write("Unauthorized.\n")
return
await register_client(process, backend, username) await register_client(process, backend, username)
process.stdout.write("Client registered\n.") process.stdout.write("Client registered\n.")
@ -262,7 +263,7 @@ async def dispatch_cmd_get_secret(process: asyncssh.SSHServerProcess[str]) -> No
"""Dispatch the get_secret command.""" """Dispatch the get_secret command."""
backend = get_info_backend(process) backend = get_info_backend(process)
if not backend: if not backend:
raise CommandError("Unexpected error: Backend disappeared.") raise CommandError(constants.ERROR_INFO_BACKEND_GONE)
client = get_info_client(process) client = get_info_client(process)
if not client: if not client:
@ -311,6 +312,7 @@ async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
process.stderr.write(str(e)) process.stderr.write(str(e))
exit_code = 1 exit_code = 1
LOG.debug("Command processing finished.")
process.exit(exit_code) process.exit(exit_code)
@ -319,16 +321,18 @@ class AsshyncServer(asyncssh.SSHServer):
def __init__( def __init__(
self, self,
backend_url: str, backend: SshecretBackend,
backend_token: str, registration: ClientRegistrationSettings,
with_register: bool = True, enable_ping_command: bool = False,
with_ping: bool = True,
) -> None: ) -> None:
"""Initialize server.""" """Initialize server."""
self.backend: SshecretBackend = SshecretBackend(backend_url, backend_token) self.backend: SshecretBackend = backend
self._conn: asyncssh.SSHServerConnection | None = None self._conn: asyncssh.SSHServerConnection | None = None
self.registration_enabled: bool = with_register self.registration_enabled: bool = registration.enabled
self.ping_enabled: bool = with_ping self.allow_registration_from: list[IPvAnyNetwork] | None = None
if registration.enabled:
self.allow_registration_from = registration.allow_from
self.ping_enabled: bool = enable_ping_command
self.client_ip: str | None = None self.client_ip: str | None = None
@override @override
@ -359,9 +363,9 @@ class AsshyncServer(asyncssh.SSHServer):
if not self._conn: if not self._conn:
return True return True
if client := await self.backend.get_client(username): if client := await self.backend.get_client(username):
LOG.debug("Client lookup sucessful.") LOG.debug("Client lookup sucessful: %r", client)
if key := self.resolve_client_key(client): if key := self.resolve_client_key(client):
LOG.debug("Loaded public key for client %s", client.name) LOG.debug("Loaded public key for client %s\n%s", client.name, key)
self._conn.set_extra_info(client=client) self._conn.set_extra_info(client=client)
self._conn.set_authorized_keys(key) self._conn.set_authorized_keys(key)
else: else:
@ -369,15 +373,18 @@ class AsshyncServer(asyncssh.SSHServer):
audit_event( audit_event(
self.backend, self.backend,
"Client denied due to policy", "Client denied due to policy",
"DENY", Operation.DENY,
client, client,
origin=self.client_ip, origin=self.client_ip,
) )
LOG.warning("Client connection denied due to policy.") LOG.warning("Client connection denied due to policy.")
else: elif self.registration_enabled:
self._conn.set_extra_info(provided_username=username) self._conn.set_extra_info(provided_username=username)
self._conn.set_extra_info(allow_registration_from=self.allow_registration_from)
LOG.warning("Registration enabled, and client is not recognized. Bypassing authentication.")
return False return False
LOG.debug("Continuing to regular authentication")
return True return True
@override @override
@ -403,6 +410,7 @@ class AsshyncServer(asyncssh.SSHServer):
return None return None
remote_ip = str(self._conn.get_extra_info("peername")[0]) remote_ip = str(self._conn.get_extra_info("peername")[0])
LOG.debug("Validating client %s connection from %s", client.name, remote_ip) LOG.debug("Validating client %s connection from %s", client.name, remote_ip)
LOG.debug("Loading client public key %r", client.public_key)
if self.check_connection_allowed(client, remote_ip): if self.check_connection_allowed(client, remote_ip):
return asyncssh.import_authorized_keys(client.public_key) return asyncssh.import_authorized_keys(client.public_key)
return None return None
@ -428,7 +436,7 @@ def get_server_key(basedir: Path | None = None) -> str:
if filename.exists(): if filename.exists():
return str(filename.absolute()) return str(filename.absolute())
# FIXME: There's a weird typing warning here that I need to investigate. # FIXME: There's a weird typing warning here that I need to investigate.
private_key = asyncssh.generate_private_key("ssh-ed25519", "sshecret-sshd") private_key = asyncssh.generate_private_key("ssh-ed25519", comment="sshecret-sshd")
with open(filename, "wb") as f: with open(filename, "wb") as f:
f.write(private_key.export_private_key()) f.write(private_key.export_private_key())
@ -436,15 +444,19 @@ def get_server_key(basedir: Path | None = None) -> str:
async def run_ssh_server( async def run_ssh_server(
backend_url: str, backend: SshecretBackend,
backend_token: str,
listen_address: str, listen_address: str,
port: int, port: int,
keys: list[str], keys: list[str],
registration: ClientRegistrationSettings,
enable_ping_command: bool = False,
) -> asyncssh.SSHAcceptor: ) -> asyncssh.SSHAcceptor:
"""Run the server.""" """Run the server."""
server = partial( server = partial(
AsshyncServer, backend_url=str(backend_url), backend_token=backend_token AsshyncServer,
backend=backend,
registration=registration,
enable_ping_command=enable_ping_command,
) )
server = await asyncssh.create_server( server = await asyncssh.create_server(
server, server,
@ -463,10 +475,12 @@ async def start_server(settings: ServerSettings | None = None) -> None:
if not settings: if not settings:
settings = ServerSettings() # pyright: ignore[reportCallIssue] settings = ServerSettings() # pyright: ignore[reportCallIssue]
backend = SshecretBackend(str(settings.backend_url), settings.backend_token)
await run_ssh_server( await run_ssh_server(
str(settings.backend_url), backend=backend,
settings.backend_token, listen_address=settings.listen_address,
settings.listen_address, port=settings.port,
settings.port, keys=[server_key],
[server_key], registration=settings.registration,
enable_ping_command=settings.enable_ping_command,
) )

View File

@ -0,0 +1 @@

View File

@ -0,0 +1,160 @@
import asyncio
import pytest
import uuid
import asyncssh
import tempfile
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock
from ipaddress import IPv4Network, IPv6Network
from sshecret_sshd.ssh_server import run_ssh_server
from sshecret_sshd.settings import ClientRegistrationSettings
from .types import Client, ClientKey, ClientRegistry, SshServerFixtureFun, SshServerFixture
@pytest.fixture(scope="function")
def client_registry() -> ClientRegistry:
clients = {}
secrets = {}
async def add_client(name: str, secret_names: list[str]|None=None, policies: list[str]|None=None) -> str:
private_key = asyncssh.generate_private_key('ssh-rsa')
public_key = private_key.export_public_key()
clients[name] = ClientKey(name, private_key, public_key.decode().rstrip())
secrets.update({(name, s): f"mocked-secret-{s}" for s in (secret_names or [])})
return clients[name]
return {"clients": clients, "secrets": secrets, "add_client": add_client}
@pytest.fixture(scope="function")
async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
backend = MagicMock()
clients_data = client_registry['clients']
secrets_data = client_registry['secrets']
async def get_client(name: str) -> Client | None:
client_key = clients_data.get(name)
if client_key:
response_model = Client(
id=uuid.uuid4(),
name=name,
description=f"Mock client {name}",
public_key=client_key.public_key,
secrets=[s for (c, s) in secrets_data if c == name],
policies=[IPv4Network('0.0.0.0/0'), IPv6Network('::/0')],
)
return response_model
return None
async def get_client_secret(name: str, secret_name: str) -> str | None:
secret = secrets_data.get((name, secret_name), None)
return secret
async def create_client(name: str, public_key: str) -> None:
"""Create client.
This only works if you register a client called template first.
Otherwise we can't test this...
"""
if "template" not in clients_data:
raise RuntimeError("Error, must have a client called template for this to work.")
clients_data[name] = clients_data["template"]
for secret_key, secret in secrets_data.items():
s_client, secret_name = secret_key
if s_client != "template":
continue
secrets_data[(name, secret_name)] = secret
backend.get_client = AsyncMock(side_effect=get_client)
backend.get_client_secret = AsyncMock(side_effect=get_client_secret)
backend.create_client = AsyncMock(side_effect=create_client)
# Make sure backend.audit(...) returns the audit mock
audit = MagicMock()
audit.write = MagicMock()
backend.audit = MagicMock(return_value=audit)
return backend
@pytest.fixture(scope="function")
async def ssh_server(mock_backend: MagicMock, unused_tcp_port: int) -> SshServerFixtureFun:
port = unused_tcp_port
private_key = asyncssh.generate_private_key('ssh-ed25519')
key_str = private_key.export_private_key()
with tempfile.NamedTemporaryFile('w+', delete=True) as key_file:
key_file.write(key_str.decode())
key_file.flush()
registration_settings = ClientRegistrationSettings(enabled=True, allow_from=[IPv4Network("0.0.0.0/0")])
server = await run_ssh_server(
backend=mock_backend,
listen_address="localhost",
port=port,
keys=[key_file.name],
registration=registration_settings,
enable_ping_command=True,
)
await asyncio.sleep(0.1)
yield server, port
server.close()
await server.wait_closed()
@pytest.fixture(scope="function")
def ssh_command_runner(ssh_server: SshServerFixture, client_registry: ClientRegistry):
"""Run a single command.
Tricky typing!
"""
_, port = ssh_server
async def run_command_as(name: str, command: str):
client_key = client_registry['clients'][name]
conn = await asyncssh.connect(
"127.0.0.1",
port=port,
username=name,
client_keys=[client_key.private_key],
known_hosts=None,
)
try:
result = await conn.run(command)
return result
finally:
conn.close()
await conn.wait_closed()
return run_command_as
@pytest.fixture(scope="function")
def ssh_session(ssh_server: SshServerFixture, client_registry: ClientRegistry):
"""Yield an interactive session."""
_, port = ssh_server
@asynccontextmanager
async def run_process_as(name: str, command: str, client: str | None = None):
if not client:
client = name
client_key = client_registry["clients"][client]
conn = await asyncssh.connect(
"127.0.0.1",
port=port,
username=name,
client_keys=[client_key.private_key],
known_hosts=None,
)
try:
async with conn.create_process(command) as process:
yield process
finally:
conn.close()
await conn.wait_closed()
return run_process_as

View File

@ -0,0 +1,28 @@
"""Test get secret."""
import pytest
from .types import ClientRegistry, CommandRunner
@pytest.mark.asyncio
async def test_get_secret(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None:
"""Test that we can get a secret."""
await client_registry['add_client']("test-client", ["mysecret"])
result = await ssh_command_runner("test-client", "get_secret mysecret")
assert result.stdout is not None
assert isinstance(result.stdout, str)
assert result.stdout.rstrip() == "mocked-secret-mysecret"
@pytest.mark.asyncio
async def test_invalid_secret_name(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None:
"""Test getting an invalid secret name."""
await client_registry['add_client']("test-client")
result = await ssh_command_runner("test-client", "get_secret mysecret")
assert result.exit_status == 1
assert result.stderr == "Error: No secret available with the given name."

View File

@ -0,0 +1,15 @@
import pytest
from .types import ClientRegistry, CommandRunner
@pytest.mark.asyncio
async def test_ping_command(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None:
# Register a test client with default policies and no secrets
await client_registry['add_client']('test-pinger')
result = await ssh_command_runner("test-pinger", "ping")
assert result.exit_status == 0
assert result.stdout is not None
assert isinstance(result.stdout, str)
assert result.stdout.rstrip() == "PONG"

View File

@ -0,0 +1,34 @@
"""Test registration."""
import pytest
from .types import ClientRegistry, CommandRunner, ProcessRunner
@pytest.mark.asyncio
async def test_register_client(ssh_session: ProcessRunner, ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None:
"""Test client registration."""
await client_registry["add_client"]("template", ["testsecret"])
public_key = client_registry["clients"]["template"].public_key.rstrip() + "\n"
async with ssh_session("newclient", "register", "template") as session:
maxlines = 10
l = 0
found = False
while l < maxlines:
line = await session.stdout.readline()
if "Enter public key" in line:
found = True
break
assert found is True
session.stdin.write(public_key)
result = await session.stdout.readline()
assert "OK" in result
# Test that we can connect
result = await ssh_command_runner("newclient", "get_secret testsecret")
assert result.stdout is not None
assert isinstance(result.stdout, str)
assert result.stdout.rstrip() == "mocked-secret-testsecret"

View File

@ -0,0 +1,56 @@
"""Types for the various test properties."""
import uuid
from datetime import datetime
from dataclasses import dataclass, field
from ipaddress import IPv4Network, IPv6Network
from collections.abc import AsyncGenerator, Awaitable, Callable, AsyncIterator
from typing import Any, Protocol, TypedDict, AsyncContextManager
import asyncssh
SshServerFixture = tuple[str, int]
SshServerFixtureFun = AsyncGenerator[tuple[asyncssh.SSHAcceptor, int], None]
@dataclass
class Client:
"""Mock client."""
id: uuid.UUID
name: str
description: str | None
public_key: str
secrets: list[str]
policies: list[IPv4Network | IPv6Network]
created_at: datetime = field(default_factory=datetime.now)
updated_at: datetime = field(default_factory=datetime.now)
@dataclass
class ClientKey:
name: str
private_key: asyncssh.SSHKey
public_key: str
class AddClientFun(Protocol):
"""Add client function."""
def __call__(self, name: str, secret_names: list[str] | None = None, policies: list[str] | None = None) -> Awaitable[str]: ...
class ProcessRunner(Protocol):
"""Process runner typing."""
def __call__(self, name: str, command: str, client: str | None = None) -> AsyncContextManager[asyncssh.SSHClientProcess[Any]]:
...
class ClientRegistry(TypedDict):
"""Client registry typing."""
clients: dict[str, ClientKey]
secrets: dict[tuple[str, str], str]
add_client: AddClientFun
CommandRunner = Callable[[str, str], Awaitable[asyncssh.SSHCompletedProcess]]