Complete sshd package with tests
This commit is contained in:
@ -9,10 +9,16 @@ authors = [
|
|||||||
requires-python = ">=3.13"
|
requires-python = ">=3.13"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"asyncssh>=2.20.0",
|
"asyncssh>=2.20.0",
|
||||||
|
"click>=8.1.8",
|
||||||
"httpx>=0.28.1",
|
"httpx>=0.28.1",
|
||||||
|
"pydantic>=2.10.6",
|
||||||
"python-dotenv>=1.0.1",
|
"python-dotenv>=1.0.1",
|
||||||
|
"sshecret",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.uv.sources]
|
||||||
|
sshecret = { workspace = true }
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
sshecret-sshd = "sshecret_sshd.cli:cli"
|
sshecret-sshd = "sshecret_sshd.cli:cli"
|
||||||
|
|
||||||
|
|||||||
2
packages/sshecret-sshd/pytest.ini
Normal file
2
packages/sshecret-sshd/pytest.ini
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
[pytest]
|
||||||
|
asyncio_mode = auto
|
||||||
@ -2,8 +2,10 @@
|
|||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
from pydantic_settings import CliApp
|
from typing import cast
|
||||||
|
|
||||||
|
import click
|
||||||
|
from pydantic import ValidationError
|
||||||
from .settings import ServerSettings
|
from .settings import ServerSettings
|
||||||
from .ssh_server import start_server
|
from .ssh_server import start_server
|
||||||
|
|
||||||
@ -17,22 +19,38 @@ LOG.addHandler(handler)
|
|||||||
LOG.setLevel(logging.INFO)
|
LOG.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
def cli(args: list[str] | None = None) -> None:
|
@click.group()
|
||||||
"""Run CLI app."""
|
@click.option("--debug", is_flag=True)
|
||||||
try:
|
@click.pass_context
|
||||||
settings = ServerSettings()
|
def cli(ctx: click.Context, debug: bool) -> None:
|
||||||
except Exception:
|
"""Sshecret Admin."""
|
||||||
print("One or more settings could not be resolved.")
|
if debug:
|
||||||
CliApp.run(ServerSettings, ["--help"])
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
if settings.debug:
|
|
||||||
LOG.setLevel(logging.DEBUG)
|
LOG.setLevel(logging.DEBUG)
|
||||||
|
try:
|
||||||
|
settings = ServerSettings() # pyright: ignore[reportCallIssue]
|
||||||
|
except ValidationError as e:
|
||||||
|
raise click.ClickException(
|
||||||
|
"Error: One or more required environment options are missing."
|
||||||
|
) from e
|
||||||
|
ctx.obj = settings
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command("run")
|
||||||
|
@click.option("--host")
|
||||||
|
@click.option("--port", type=click.INT)
|
||||||
|
@click.pass_context
|
||||||
|
def cli_run(ctx: click.Context, host: str | None, port: int | None) -> None:
|
||||||
|
"""Run the server."""
|
||||||
|
settings = cast(ServerSettings, ctx.obj)
|
||||||
|
if host:
|
||||||
|
settings.listen_address = host
|
||||||
|
if port:
|
||||||
|
settings.port = port
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
loop.run_until_complete(start_server(settings))
|
loop.run_until_complete(start_server(settings))
|
||||||
|
title = click.style("Sshecret SSH Daemon", fg="red", bold=True)
|
||||||
print(f"Starting SSH server: {settings.listen_address}:{settings.port}")
|
click.echo(f"Starting {title}: {settings.listen_address}:{settings.port}")
|
||||||
try:
|
try:
|
||||||
loop.run_forever()
|
loop.run_forever()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
@ -40,7 +58,6 @@ def cli(args: list[str] | None = None) -> None:
|
|||||||
sys.exit()
|
sys.exit()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
"""Run CLI app."""
|
"""Run CLI app."""
|
||||||
cli()
|
cli()
|
||||||
|
|||||||
@ -4,5 +4,11 @@ ERROR_NO_COMMAND_RECEIVED = "Error: No command was received from the client."
|
|||||||
ERROR_SOURCE_IP_NOT_ALLOWED = (
|
ERROR_SOURCE_IP_NOT_ALLOWED = (
|
||||||
"Error: Client not authorized to connect from the given host."
|
"Error: Client not authorized to connect from the given host."
|
||||||
)
|
)
|
||||||
|
ERROR_NO_PUBLIC_KEY = "Error: No valid public key received."
|
||||||
|
ERROR_INVALID_KEY_TYPE = "Error: Invalid key type: Only RSA keys are supported."
|
||||||
ERROR_UNKNOWN_COMMAND = "Error: The given command was not understood."
|
ERROR_UNKNOWN_COMMAND = "Error: The given command was not understood."
|
||||||
SERVER_KEY_TYPE = "ed25519"
|
SERVER_KEY_TYPE = "ed25519"
|
||||||
|
ERROR_BACKEND_ERROR = "Error: Unexpected response or error from backend"
|
||||||
|
ERROR_INFO_BACKEND_GONE = "Unexpected error: Backend connection lost."
|
||||||
|
ERROR_INFO_USERNAME_GONE = "Unexpected error: Username lost in transit."
|
||||||
|
ERROR_INFO_REMOTE_IP_GONE = "Unexpected error: Client connection details lost in transit."
|
||||||
|
|||||||
@ -1,18 +1,48 @@
|
|||||||
"""SSH Server settings."""
|
"""SSH Server settings."""
|
||||||
|
|
||||||
from pydantic import AnyHttpUrl, Field, AliasChoices
|
import ipaddress
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from typing import Annotated, Any
|
||||||
|
from pydantic import AnyHttpUrl, BaseModel, Field, IPvAnyNetwork, field_validator
|
||||||
|
from pydantic_settings import BaseSettings, ForceDecode, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_LISTEN_PORT = 2222
|
DEFAULT_LISTEN_PORT = 2222
|
||||||
|
|
||||||
class ServerSettings(BaseSettings, cli_parse_args=True, cli_exit_on_error=True):
|
|
||||||
|
class ClientRegistrationSettings(BaseModel):
|
||||||
|
"""Client registration settings."""
|
||||||
|
|
||||||
|
enabled: bool = False
|
||||||
|
allow_from: Annotated[list[IPvAnyNetwork], ForceDecode] = Field(default_factory=list)
|
||||||
|
|
||||||
|
@field_validator('allow_from', mode="before")
|
||||||
|
@classmethod
|
||||||
|
def ensure_allow_from_list(cls, value: Any) -> list[IPvAnyNetwork]:
|
||||||
|
"""Convert allow_from to a list."""
|
||||||
|
allow_from: list[IPvAnyNetwork] = []
|
||||||
|
if isinstance(value, list):
|
||||||
|
entries = value
|
||||||
|
elif isinstance(value, str):
|
||||||
|
entries = value.split(",")
|
||||||
|
else:
|
||||||
|
raise ValueError("Error: Unknown format for allowed_from.")
|
||||||
|
|
||||||
|
for entry in entries:
|
||||||
|
if isinstance(entry, str):
|
||||||
|
allow_from.append(ipaddress.ip_network(entry))
|
||||||
|
elif isinstance(entry, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
|
||||||
|
allow_from.append(entry)
|
||||||
|
return allow_from
|
||||||
|
|
||||||
|
class ServerSettings(BaseSettings):
|
||||||
"""Server Settings."""
|
"""Server Settings."""
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".sshd.env", env_prefix="sshecret_sshd_")
|
model_config = SettingsConfigDict(env_file=".sshd.env", env_prefix="sshecret_sshd_", env_nested_delimiter='_')
|
||||||
|
|
||||||
backend_url: AnyHttpUrl = Field(validation_alias=AliasChoices("backend-url", "sshecret_backend_url"))
|
backend_url: AnyHttpUrl = Field(alias="sshecret_backend_url")
|
||||||
backend_token: str = Field(validation_alias=AliasChoices("backend-token", "sshecret_sshd_backend_token"))
|
backend_token: str
|
||||||
listen_address: str = Field(default="", alias="listen")
|
listen_address: str = Field(default="127.0.0.1")
|
||||||
port: int = DEFAULT_LISTEN_PORT
|
port: int = DEFAULT_LISTEN_PORT
|
||||||
|
registration: ClientRegistrationSettings = Field(default_factory=ClientRegistrationSettings)
|
||||||
debug: bool = False
|
debug: bool = False
|
||||||
|
enable_ping_command: bool = False
|
||||||
|
|||||||
@ -1,21 +1,21 @@
|
|||||||
"""SSH Server implementation."""
|
"""SSH Server implementation."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
|
||||||
|
|
||||||
import asyncssh
|
import asyncssh
|
||||||
import ipaddress
|
import ipaddress
|
||||||
|
|
||||||
from collections.abc import Awaitable
|
from collections.abc import Awaitable
|
||||||
from datetime import datetime, timezone
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, cast, override
|
from typing import Callable, cast, override
|
||||||
|
|
||||||
|
from pydantic import IPvAnyNetwork
|
||||||
|
|
||||||
from . import constants
|
from . import constants
|
||||||
|
|
||||||
from sshecret.backend import AuditLog, SshecretBackend, Client
|
from sshecret.backend import SshecretBackend, Client, Operation, SubSystem
|
||||||
from .settings import ServerSettings
|
from .settings import ServerSettings, ClientRegistrationSettings
|
||||||
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
@ -35,71 +35,39 @@ class CommandError(Exception):
|
|||||||
def audit_process(
|
def audit_process(
|
||||||
backend: SshecretBackend,
|
backend: SshecretBackend,
|
||||||
process: asyncssh.SSHServerProcess[str],
|
process: asyncssh.SSHServerProcess[str],
|
||||||
|
operation: Operation,
|
||||||
message: str,
|
message: str,
|
||||||
secret: str | None = None,
|
secret: str | None = None,
|
||||||
|
**data: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add an audit event from process."""
|
"""Add an audit event from process."""
|
||||||
command = get_process_command(process)
|
command = get_process_command(process)
|
||||||
client = get_info_client(process)
|
client = get_info_client(process)
|
||||||
username = get_info_username(process)
|
username = get_info_username(process)
|
||||||
remote_ip = get_info_remote_ip(process)
|
remote_ip = get_info_remote_ip(process) or "UNKNOWN"
|
||||||
operation = "SSH_EVENT"
|
if username:
|
||||||
obj_name: str | None = None
|
data["username"] = username
|
||||||
obj_id: str | None = None
|
|
||||||
|
|
||||||
if command and not secret:
|
if command and not secret:
|
||||||
cmd, cmd_args = command
|
cmd, cmd_args = command
|
||||||
obj_id = " ".join(cmd_args)
|
if cmd:
|
||||||
elif secret:
|
data["command"] = cmd
|
||||||
obj_name = "ClientSecret"
|
data["args"] = " ".join(cmd_args)
|
||||||
obj_id = secret
|
|
||||||
|
|
||||||
entry = AuditLog(
|
|
||||||
subsystem="ssh",
|
|
||||||
operation=operation,
|
|
||||||
object=obj_name,
|
|
||||||
object_id=obj_id,
|
|
||||||
message=message,
|
|
||||||
origin=remote_ip,
|
|
||||||
)
|
|
||||||
if client:
|
|
||||||
entry.client_id = str(client.id)
|
|
||||||
entry.client_name = client.name
|
|
||||||
elif username:
|
|
||||||
entry.client_name = username
|
|
||||||
|
|
||||||
backend.add_audit_log_sync(entry)
|
|
||||||
|
|
||||||
|
backend.audit(SubSystem.SSHD).write(operation, message, remote_ip, client, secret=None, secret_name=secret, **data)
|
||||||
|
|
||||||
def audit_event(
|
def audit_event(
|
||||||
backend: SshecretBackend,
|
backend: SshecretBackend,
|
||||||
message: str,
|
message: str,
|
||||||
operation: str = "SSH_EVENT",
|
operation: Operation,
|
||||||
client: Client | None = None,
|
client: Client | None = None,
|
||||||
origin: str | None = None,
|
origin: str | None = None,
|
||||||
secret: str | None = None,
|
secret: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add an audit event."""
|
"""Add an audit event."""
|
||||||
entry = AuditLog(
|
if not origin:
|
||||||
client_id=None,
|
origin = "UNKNOWN"
|
||||||
client_name=None,
|
backend.audit(SubSystem.SSHD).write(operation, message, origin, client, secret=None, secret_name=secret)
|
||||||
object=None,
|
|
||||||
object_id=None,
|
|
||||||
subsystem="ssh",
|
|
||||||
operation=operation,
|
|
||||||
message=message,
|
|
||||||
origin=origin,
|
|
||||||
)
|
|
||||||
if client:
|
|
||||||
entry.client_id = str(client.id)
|
|
||||||
entry.client_name = client.name
|
|
||||||
|
|
||||||
if secret:
|
|
||||||
entry.object = "ClientSecret"
|
|
||||||
entry.object_id = secret
|
|
||||||
|
|
||||||
backend.add_audit_log_sync(entry)
|
|
||||||
|
|
||||||
|
|
||||||
def verify_key_input(public_key: str) -> str | None:
|
def verify_key_input(public_key: str) -> str | None:
|
||||||
"""Verify key input."""
|
"""Verify key input."""
|
||||||
@ -118,6 +86,7 @@ def get_process_command(
|
|||||||
if not process.command:
|
if not process.command:
|
||||||
return (None, [])
|
return (None, [])
|
||||||
argv = process.command.split(" ")
|
argv = process.command.split(" ")
|
||||||
|
LOG.debug("Args: %r", argv)
|
||||||
return (argv[0], argv[1:])
|
return (argv[0], argv[1:])
|
||||||
|
|
||||||
|
|
||||||
@ -149,7 +118,12 @@ def get_info_remote_ip(process: asyncssh.SSHServerProcess[str]) -> str | None:
|
|||||||
|
|
||||||
return remote_ip
|
return remote_ip
|
||||||
|
|
||||||
# remote_ip = str(self._conn.get_extra_info("peername")[0])
|
def get_info_allowed_registration(process: asyncssh.SSHServerProcess[str]) -> list[IPvAnyNetwork] | None:
|
||||||
|
"""Get allowed networks to allow registration from."""
|
||||||
|
|
||||||
|
allowed_registration = cast(list[IPvAnyNetwork] | None, process.get_extra_info("allow_registration_from", None))
|
||||||
|
return allowed_registration
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_optional_commands(process: asyncssh.SSHServerProcess[str]) -> dict[str, bool]:
|
def get_optional_commands(process: asyncssh.SSHServerProcess[str]) -> dict[str, bool]:
|
||||||
@ -197,12 +171,12 @@ async def register_client(
|
|||||||
"""Register a new client."""
|
"""Register a new client."""
|
||||||
public_key = await get_stdin_public_key(process)
|
public_key = await get_stdin_public_key(process)
|
||||||
if not public_key:
|
if not public_key:
|
||||||
raise CommandError("Aborted. No valid public key received.")
|
raise CommandError(constants.ERROR_NO_PUBLIC_KEY)
|
||||||
|
|
||||||
key = asyncssh.import_public_key(public_key)
|
key = asyncssh.import_public_key(public_key)
|
||||||
if key.algorithm.decode() != "ssh-rsa":
|
if key.algorithm.decode() != "ssh-rsa":
|
||||||
raise CommandError("Error: Only RSA keys are supported!")
|
raise CommandError(constants.ERROR_INVALID_KEY_TYPE)
|
||||||
audit_process(backend, process, "Registering new client")
|
audit_process(backend, process, Operation.CREATE, "Registering new client")
|
||||||
LOG.debug("Registering client %s with public key %s", username, public_key)
|
LOG.debug("Registering client %s with public key %s", username, public_key)
|
||||||
await backend.create_client(username, public_key)
|
await backend.create_client(username, public_key)
|
||||||
|
|
||||||
@ -214,14 +188,16 @@ async def get_secret(
|
|||||||
origin: str,
|
origin: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Handle get secret requests from client."""
|
"""Handle get secret requests from client."""
|
||||||
LOG.debug("Recieved command: %r", secret_name)
|
LOG.debug("Recieved command: get_secret %r", secret_name)
|
||||||
if not secret_name or secret_name not in client.secrets:
|
if not secret_name:
|
||||||
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
|
raise CommandError(constants.ERROR_UNKNOWN_CLIENT_OR_SECRET)
|
||||||
|
if secret_name not in client.secrets:
|
||||||
|
raise CommandError(constants.ERROR_NO_SECRET_FOUND)
|
||||||
|
|
||||||
audit_event(
|
audit_event(
|
||||||
backend,
|
backend,
|
||||||
"Client requested secret",
|
"Client requested secret",
|
||||||
operation="get_secret",
|
operation=Operation.READ,
|
||||||
client=client,
|
client=client,
|
||||||
origin=origin,
|
origin=origin,
|
||||||
secret=secret_name,
|
secret=secret_name,
|
||||||
@ -229,10 +205,13 @@ async def get_secret(
|
|||||||
|
|
||||||
# Look up secret
|
# Look up secret
|
||||||
try:
|
try:
|
||||||
return await backend.get_client_secret(client.name, secret_name)
|
secret = await backend.get_client_secret(client.name, secret_name)
|
||||||
|
if not secret:
|
||||||
|
raise CommandError(constants.ERROR_NO_SECRET_FOUND)
|
||||||
|
return secret
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
LOG.debug(exc, exc_info=True)
|
LOG.debug(exc, exc_info=True)
|
||||||
raise CommandError("Unexpected error from backend") from exc
|
raise CommandError(constants.ERROR_BACKEND_ERROR) from exc
|
||||||
|
|
||||||
|
|
||||||
async def dispatch_no_cmd(process: asyncssh.SSHServerProcess[str]) -> None:
|
async def dispatch_no_cmd(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||||
@ -249,10 +228,32 @@ async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None
|
|||||||
"""Dispatch the register command."""
|
"""Dispatch the register command."""
|
||||||
backend = get_info_backend(process)
|
backend = get_info_backend(process)
|
||||||
if not backend:
|
if not backend:
|
||||||
raise CommandError("Unexpected error: Backend disappeared.")
|
raise CommandError(constants.ERROR_INFO_BACKEND_GONE)
|
||||||
username = get_info_username(process)
|
username = get_info_username(process)
|
||||||
if not username:
|
if not username:
|
||||||
raise CommandError("Unexpected error: Username was lost.")
|
raise CommandError(constants.ERROR_INFO_USERNAME_GONE)
|
||||||
|
|
||||||
|
allowed_networks = get_info_allowed_registration(process)
|
||||||
|
if not allowed_networks:
|
||||||
|
process.stdout.write("Unauthorized.\n")
|
||||||
|
audit_process(backend, process, Operation.DENY, "Received registration command, but no subnets are allowed.")
|
||||||
|
return
|
||||||
|
|
||||||
|
remote_ip = get_info_remote_ip(process)
|
||||||
|
|
||||||
|
if not remote_ip:
|
||||||
|
raise CommandError(constants.ERROR_INFO_REMOTE_IP_GONE)
|
||||||
|
|
||||||
|
client_address = ipaddress.ip_address(remote_ip)
|
||||||
|
|
||||||
|
for network in allowed_networks:
|
||||||
|
if client_address in network:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
audit_process(backend, process, Operation.DENY, "Received registration command from unauthorized subnet.")
|
||||||
|
process.stdout.write("Unauthorized.\n")
|
||||||
|
return
|
||||||
|
|
||||||
await register_client(process, backend, username)
|
await register_client(process, backend, username)
|
||||||
|
|
||||||
process.stdout.write("Client registered\n.")
|
process.stdout.write("Client registered\n.")
|
||||||
@ -262,7 +263,7 @@ async def dispatch_cmd_get_secret(process: asyncssh.SSHServerProcess[str]) -> No
|
|||||||
"""Dispatch the get_secret command."""
|
"""Dispatch the get_secret command."""
|
||||||
backend = get_info_backend(process)
|
backend = get_info_backend(process)
|
||||||
if not backend:
|
if not backend:
|
||||||
raise CommandError("Unexpected error: Backend disappeared.")
|
raise CommandError(constants.ERROR_INFO_BACKEND_GONE)
|
||||||
|
|
||||||
client = get_info_client(process)
|
client = get_info_client(process)
|
||||||
if not client:
|
if not client:
|
||||||
@ -311,6 +312,7 @@ async def dispatch_command(process: asyncssh.SSHServerProcess[str]) -> None:
|
|||||||
process.stderr.write(str(e))
|
process.stderr.write(str(e))
|
||||||
exit_code = 1
|
exit_code = 1
|
||||||
|
|
||||||
|
LOG.debug("Command processing finished.")
|
||||||
process.exit(exit_code)
|
process.exit(exit_code)
|
||||||
|
|
||||||
|
|
||||||
@ -319,16 +321,18 @@ class AsshyncServer(asyncssh.SSHServer):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
backend_url: str,
|
backend: SshecretBackend,
|
||||||
backend_token: str,
|
registration: ClientRegistrationSettings,
|
||||||
with_register: bool = True,
|
enable_ping_command: bool = False,
|
||||||
with_ping: bool = True,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize server."""
|
"""Initialize server."""
|
||||||
self.backend: SshecretBackend = SshecretBackend(backend_url, backend_token)
|
self.backend: SshecretBackend = backend
|
||||||
self._conn: asyncssh.SSHServerConnection | None = None
|
self._conn: asyncssh.SSHServerConnection | None = None
|
||||||
self.registration_enabled: bool = with_register
|
self.registration_enabled: bool = registration.enabled
|
||||||
self.ping_enabled: bool = with_ping
|
self.allow_registration_from: list[IPvAnyNetwork] | None = None
|
||||||
|
if registration.enabled:
|
||||||
|
self.allow_registration_from = registration.allow_from
|
||||||
|
self.ping_enabled: bool = enable_ping_command
|
||||||
self.client_ip: str | None = None
|
self.client_ip: str | None = None
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@ -359,9 +363,9 @@ class AsshyncServer(asyncssh.SSHServer):
|
|||||||
if not self._conn:
|
if not self._conn:
|
||||||
return True
|
return True
|
||||||
if client := await self.backend.get_client(username):
|
if client := await self.backend.get_client(username):
|
||||||
LOG.debug("Client lookup sucessful.")
|
LOG.debug("Client lookup sucessful: %r", client)
|
||||||
if key := self.resolve_client_key(client):
|
if key := self.resolve_client_key(client):
|
||||||
LOG.debug("Loaded public key for client %s", client.name)
|
LOG.debug("Loaded public key for client %s\n%s", client.name, key)
|
||||||
self._conn.set_extra_info(client=client)
|
self._conn.set_extra_info(client=client)
|
||||||
self._conn.set_authorized_keys(key)
|
self._conn.set_authorized_keys(key)
|
||||||
else:
|
else:
|
||||||
@ -369,15 +373,18 @@ class AsshyncServer(asyncssh.SSHServer):
|
|||||||
audit_event(
|
audit_event(
|
||||||
self.backend,
|
self.backend,
|
||||||
"Client denied due to policy",
|
"Client denied due to policy",
|
||||||
"DENY",
|
Operation.DENY,
|
||||||
client,
|
client,
|
||||||
origin=self.client_ip,
|
origin=self.client_ip,
|
||||||
)
|
)
|
||||||
LOG.warning("Client connection denied due to policy.")
|
LOG.warning("Client connection denied due to policy.")
|
||||||
else:
|
elif self.registration_enabled:
|
||||||
self._conn.set_extra_info(provided_username=username)
|
self._conn.set_extra_info(provided_username=username)
|
||||||
|
self._conn.set_extra_info(allow_registration_from=self.allow_registration_from)
|
||||||
|
LOG.warning("Registration enabled, and client is not recognized. Bypassing authentication.")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
LOG.debug("Continuing to regular authentication")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@ -403,6 +410,7 @@ class AsshyncServer(asyncssh.SSHServer):
|
|||||||
return None
|
return None
|
||||||
remote_ip = str(self._conn.get_extra_info("peername")[0])
|
remote_ip = str(self._conn.get_extra_info("peername")[0])
|
||||||
LOG.debug("Validating client %s connection from %s", client.name, remote_ip)
|
LOG.debug("Validating client %s connection from %s", client.name, remote_ip)
|
||||||
|
LOG.debug("Loading client public key %r", client.public_key)
|
||||||
if self.check_connection_allowed(client, remote_ip):
|
if self.check_connection_allowed(client, remote_ip):
|
||||||
return asyncssh.import_authorized_keys(client.public_key)
|
return asyncssh.import_authorized_keys(client.public_key)
|
||||||
return None
|
return None
|
||||||
@ -428,7 +436,7 @@ def get_server_key(basedir: Path | None = None) -> str:
|
|||||||
if filename.exists():
|
if filename.exists():
|
||||||
return str(filename.absolute())
|
return str(filename.absolute())
|
||||||
# FIXME: There's a weird typing warning here that I need to investigate.
|
# FIXME: There's a weird typing warning here that I need to investigate.
|
||||||
private_key = asyncssh.generate_private_key("ssh-ed25519", "sshecret-sshd")
|
private_key = asyncssh.generate_private_key("ssh-ed25519", comment="sshecret-sshd")
|
||||||
with open(filename, "wb") as f:
|
with open(filename, "wb") as f:
|
||||||
f.write(private_key.export_private_key())
|
f.write(private_key.export_private_key())
|
||||||
|
|
||||||
@ -436,15 +444,19 @@ def get_server_key(basedir: Path | None = None) -> str:
|
|||||||
|
|
||||||
|
|
||||||
async def run_ssh_server(
|
async def run_ssh_server(
|
||||||
backend_url: str,
|
backend: SshecretBackend,
|
||||||
backend_token: str,
|
|
||||||
listen_address: str,
|
listen_address: str,
|
||||||
port: int,
|
port: int,
|
||||||
keys: list[str],
|
keys: list[str],
|
||||||
|
registration: ClientRegistrationSettings,
|
||||||
|
enable_ping_command: bool = False,
|
||||||
) -> asyncssh.SSHAcceptor:
|
) -> asyncssh.SSHAcceptor:
|
||||||
"""Run the server."""
|
"""Run the server."""
|
||||||
server = partial(
|
server = partial(
|
||||||
AsshyncServer, backend_url=str(backend_url), backend_token=backend_token
|
AsshyncServer,
|
||||||
|
backend=backend,
|
||||||
|
registration=registration,
|
||||||
|
enable_ping_command=enable_ping_command,
|
||||||
)
|
)
|
||||||
server = await asyncssh.create_server(
|
server = await asyncssh.create_server(
|
||||||
server,
|
server,
|
||||||
@ -463,10 +475,12 @@ async def start_server(settings: ServerSettings | None = None) -> None:
|
|||||||
if not settings:
|
if not settings:
|
||||||
settings = ServerSettings() # pyright: ignore[reportCallIssue]
|
settings = ServerSettings() # pyright: ignore[reportCallIssue]
|
||||||
|
|
||||||
|
backend = SshecretBackend(str(settings.backend_url), settings.backend_token)
|
||||||
await run_ssh_server(
|
await run_ssh_server(
|
||||||
str(settings.backend_url),
|
backend=backend,
|
||||||
settings.backend_token,
|
listen_address=settings.listen_address,
|
||||||
settings.listen_address,
|
port=settings.port,
|
||||||
settings.port,
|
keys=[server_key],
|
||||||
[server_key],
|
registration=settings.registration,
|
||||||
|
enable_ping_command=settings.enable_ping_command,
|
||||||
)
|
)
|
||||||
|
|||||||
1
packages/sshecret-sshd/tests/__init__.py
Normal file
1
packages/sshecret-sshd/tests/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
|
||||||
160
packages/sshecret-sshd/tests/conftest.py
Normal file
160
packages/sshecret-sshd/tests/conftest.py
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
import asyncio
|
||||||
|
import pytest
|
||||||
|
import uuid
|
||||||
|
import asyncssh
|
||||||
|
import tempfile
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
from ipaddress import IPv4Network, IPv6Network
|
||||||
|
|
||||||
|
from sshecret_sshd.ssh_server import run_ssh_server
|
||||||
|
from sshecret_sshd.settings import ClientRegistrationSettings
|
||||||
|
|
||||||
|
from .types import Client, ClientKey, ClientRegistry, SshServerFixtureFun, SshServerFixture
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def client_registry() -> ClientRegistry:
|
||||||
|
clients = {}
|
||||||
|
secrets = {}
|
||||||
|
|
||||||
|
async def add_client(name: str, secret_names: list[str]|None=None, policies: list[str]|None=None) -> str:
|
||||||
|
private_key = asyncssh.generate_private_key('ssh-rsa')
|
||||||
|
public_key = private_key.export_public_key()
|
||||||
|
clients[name] = ClientKey(name, private_key, public_key.decode().rstrip())
|
||||||
|
secrets.update({(name, s): f"mocked-secret-{s}" for s in (secret_names or [])})
|
||||||
|
return clients[name]
|
||||||
|
|
||||||
|
return {"clients": clients, "secrets": secrets, "add_client": add_client}
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
|
||||||
|
backend = MagicMock()
|
||||||
|
clients_data = client_registry['clients']
|
||||||
|
secrets_data = client_registry['secrets']
|
||||||
|
|
||||||
|
async def get_client(name: str) -> Client | None:
|
||||||
|
client_key = clients_data.get(name)
|
||||||
|
if client_key:
|
||||||
|
response_model = Client(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name=name,
|
||||||
|
description=f"Mock client {name}",
|
||||||
|
public_key=client_key.public_key,
|
||||||
|
secrets=[s for (c, s) in secrets_data if c == name],
|
||||||
|
policies=[IPv4Network('0.0.0.0/0'), IPv6Network('::/0')],
|
||||||
|
)
|
||||||
|
return response_model
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_client_secret(name: str, secret_name: str) -> str | None:
|
||||||
|
secret = secrets_data.get((name, secret_name), None)
|
||||||
|
return secret
|
||||||
|
|
||||||
|
async def create_client(name: str, public_key: str) -> None:
|
||||||
|
"""Create client.
|
||||||
|
|
||||||
|
This only works if you register a client called template first.
|
||||||
|
Otherwise we can't test this...
|
||||||
|
"""
|
||||||
|
if "template" not in clients_data:
|
||||||
|
raise RuntimeError("Error, must have a client called template for this to work.")
|
||||||
|
clients_data[name] = clients_data["template"]
|
||||||
|
for secret_key, secret in secrets_data.items():
|
||||||
|
s_client, secret_name = secret_key
|
||||||
|
if s_client != "template":
|
||||||
|
continue
|
||||||
|
secrets_data[(name, secret_name)] = secret
|
||||||
|
|
||||||
|
backend.get_client = AsyncMock(side_effect=get_client)
|
||||||
|
backend.get_client_secret = AsyncMock(side_effect=get_client_secret)
|
||||||
|
backend.create_client = AsyncMock(side_effect=create_client)
|
||||||
|
|
||||||
|
# Make sure backend.audit(...) returns the audit mock
|
||||||
|
audit = MagicMock()
|
||||||
|
audit.write = MagicMock()
|
||||||
|
backend.audit = MagicMock(return_value=audit)
|
||||||
|
|
||||||
|
return backend
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
async def ssh_server(mock_backend: MagicMock, unused_tcp_port: int) -> SshServerFixtureFun:
|
||||||
|
port = unused_tcp_port
|
||||||
|
|
||||||
|
private_key = asyncssh.generate_private_key('ssh-ed25519')
|
||||||
|
key_str = private_key.export_private_key()
|
||||||
|
with tempfile.NamedTemporaryFile('w+', delete=True) as key_file:
|
||||||
|
|
||||||
|
key_file.write(key_str.decode())
|
||||||
|
key_file.flush()
|
||||||
|
|
||||||
|
registration_settings = ClientRegistrationSettings(enabled=True, allow_from=[IPv4Network("0.0.0.0/0")])
|
||||||
|
server = await run_ssh_server(
|
||||||
|
backend=mock_backend,
|
||||||
|
listen_address="localhost",
|
||||||
|
port=port,
|
||||||
|
keys=[key_file.name],
|
||||||
|
registration=registration_settings,
|
||||||
|
enable_ping_command=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
yield server, port
|
||||||
|
|
||||||
|
server.close()
|
||||||
|
await server.wait_closed()
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def ssh_command_runner(ssh_server: SshServerFixture, client_registry: ClientRegistry):
|
||||||
|
"""Run a single command.
|
||||||
|
|
||||||
|
Tricky typing!
|
||||||
|
"""
|
||||||
|
_, port = ssh_server
|
||||||
|
|
||||||
|
async def run_command_as(name: str, command: str):
|
||||||
|
client_key = client_registry['clients'][name]
|
||||||
|
conn = await asyncssh.connect(
|
||||||
|
"127.0.0.1",
|
||||||
|
port=port,
|
||||||
|
username=name,
|
||||||
|
client_keys=[client_key.private_key],
|
||||||
|
known_hosts=None,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
result = await conn.run(command)
|
||||||
|
return result
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
await conn.wait_closed()
|
||||||
|
|
||||||
|
return run_command_as
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def ssh_session(ssh_server: SshServerFixture, client_registry: ClientRegistry):
|
||||||
|
"""Yield an interactive session."""
|
||||||
|
|
||||||
|
_, port = ssh_server
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def run_process_as(name: str, command: str, client: str | None = None):
|
||||||
|
if not client:
|
||||||
|
client = name
|
||||||
|
client_key = client_registry["clients"][client]
|
||||||
|
conn = await asyncssh.connect(
|
||||||
|
"127.0.0.1",
|
||||||
|
port=port,
|
||||||
|
username=name,
|
||||||
|
client_keys=[client_key.private_key],
|
||||||
|
known_hosts=None,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
async with conn.create_process(command) as process:
|
||||||
|
yield process
|
||||||
|
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
await conn.wait_closed()
|
||||||
|
|
||||||
|
return run_process_as
|
||||||
28
packages/sshecret-sshd/tests/test_get_secret.py
Normal file
28
packages/sshecret-sshd/tests/test_get_secret.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
"""Test get secret."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from .types import ClientRegistry, CommandRunner
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_secret(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None:
|
||||||
|
"""Test that we can get a secret."""
|
||||||
|
|
||||||
|
await client_registry['add_client']("test-client", ["mysecret"])
|
||||||
|
|
||||||
|
result = await ssh_command_runner("test-client", "get_secret mysecret")
|
||||||
|
|
||||||
|
assert result.stdout is not None
|
||||||
|
assert isinstance(result.stdout, str)
|
||||||
|
assert result.stdout.rstrip() == "mocked-secret-mysecret"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invalid_secret_name(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None:
|
||||||
|
"""Test getting an invalid secret name."""
|
||||||
|
|
||||||
|
await client_registry['add_client']("test-client")
|
||||||
|
|
||||||
|
result = await ssh_command_runner("test-client", "get_secret mysecret")
|
||||||
|
assert result.exit_status == 1
|
||||||
|
assert result.stderr == "Error: No secret available with the given name."
|
||||||
15
packages/sshecret-sshd/tests/test_ping.py
Normal file
15
packages/sshecret-sshd/tests/test_ping.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from .types import ClientRegistry, CommandRunner
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ping_command(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None:
|
||||||
|
# Register a test client with default policies and no secrets
|
||||||
|
await client_registry['add_client']('test-pinger')
|
||||||
|
|
||||||
|
result = await ssh_command_runner("test-pinger", "ping")
|
||||||
|
|
||||||
|
assert result.exit_status == 0
|
||||||
|
assert result.stdout is not None
|
||||||
|
assert isinstance(result.stdout, str)
|
||||||
|
assert result.stdout.rstrip() == "PONG"
|
||||||
34
packages/sshecret-sshd/tests/test_register.py
Normal file
34
packages/sshecret-sshd/tests/test_register.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
"""Test registration."""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from .types import ClientRegistry, CommandRunner, ProcessRunner
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_client(ssh_session: ProcessRunner, ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None:
|
||||||
|
"""Test client registration."""
|
||||||
|
await client_registry["add_client"]("template", ["testsecret"])
|
||||||
|
|
||||||
|
public_key = client_registry["clients"]["template"].public_key.rstrip() + "\n"
|
||||||
|
|
||||||
|
async with ssh_session("newclient", "register", "template") as session:
|
||||||
|
maxlines = 10
|
||||||
|
l = 0
|
||||||
|
found = False
|
||||||
|
while l < maxlines:
|
||||||
|
line = await session.stdout.readline()
|
||||||
|
if "Enter public key" in line:
|
||||||
|
found = True
|
||||||
|
break
|
||||||
|
|
||||||
|
assert found is True
|
||||||
|
session.stdin.write(public_key)
|
||||||
|
result = await session.stdout.readline()
|
||||||
|
assert "OK" in result
|
||||||
|
|
||||||
|
# Test that we can connect
|
||||||
|
|
||||||
|
result = await ssh_command_runner("newclient", "get_secret testsecret")
|
||||||
|
|
||||||
|
assert result.stdout is not None
|
||||||
|
assert isinstance(result.stdout, str)
|
||||||
|
assert result.stdout.rstrip() == "mocked-secret-testsecret"
|
||||||
56
packages/sshecret-sshd/tests/types.py
Normal file
56
packages/sshecret-sshd/tests/types.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
"""Types for the various test properties."""
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from ipaddress import IPv4Network, IPv6Network
|
||||||
|
from collections.abc import AsyncGenerator, Awaitable, Callable, AsyncIterator
|
||||||
|
from typing import Any, Protocol, TypedDict, AsyncContextManager
|
||||||
|
import asyncssh
|
||||||
|
|
||||||
|
SshServerFixture = tuple[str, int]
|
||||||
|
SshServerFixtureFun = AsyncGenerator[tuple[asyncssh.SSHAcceptor, int], None]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Client:
|
||||||
|
"""Mock client."""
|
||||||
|
|
||||||
|
id: uuid.UUID
|
||||||
|
name: str
|
||||||
|
description: str | None
|
||||||
|
public_key: str
|
||||||
|
secrets: list[str]
|
||||||
|
policies: list[IPv4Network | IPv6Network]
|
||||||
|
created_at: datetime = field(default_factory=datetime.now)
|
||||||
|
updated_at: datetime = field(default_factory=datetime.now)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ClientKey:
|
||||||
|
name: str
|
||||||
|
private_key: asyncssh.SSHKey
|
||||||
|
public_key: str
|
||||||
|
|
||||||
|
|
||||||
|
class AddClientFun(Protocol):
|
||||||
|
"""Add client function."""
|
||||||
|
|
||||||
|
def __call__(self, name: str, secret_names: list[str] | None = None, policies: list[str] | None = None) -> Awaitable[str]: ...
|
||||||
|
|
||||||
|
class ProcessRunner(Protocol):
|
||||||
|
"""Process runner typing."""
|
||||||
|
|
||||||
|
def __call__(self, name: str, command: str, client: str | None = None) -> AsyncContextManager[asyncssh.SSHClientProcess[Any]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
class ClientRegistry(TypedDict):
|
||||||
|
"""Client registry typing."""
|
||||||
|
|
||||||
|
clients: dict[str, ClientKey]
|
||||||
|
secrets: dict[tuple[str, str], str]
|
||||||
|
add_client: AddClientFun
|
||||||
|
|
||||||
|
|
||||||
|
CommandRunner = Callable[[str, str], Awaitable[asyncssh.SSHCompletedProcess]]
|
||||||
Reference in New Issue
Block a user