diff --git a/packages/sshecret-sshd/src/sshecret_sshd/cli.py b/packages/sshecret-sshd/src/sshecret_sshd/cli.py index 7f967c8..1416ead 100644 --- a/packages/sshecret-sshd/src/sshecret_sshd/cli.py +++ b/packages/sshecret-sshd/src/sshecret_sshd/cli.py @@ -1,4 +1,5 @@ """CLI app.""" + import logging import asyncio import sys @@ -12,7 +13,9 @@ from .ssh_server import start_server LOG = logging.getLogger() handler = logging.StreamHandler() -formatter = logging.Formatter("%(created)f:%(levelname)s:%(name)s:%(module)s:%(message)s") +formatter = logging.Formatter( + "%(created)f:%(levelname)s:%(name)s:%(module)s:%(message)s" +) handler.setFormatter(formatter) LOG.addHandler(handler) diff --git a/packages/sshecret-sshd/src/sshecret_sshd/constants.py b/packages/sshecret-sshd/src/sshecret_sshd/constants.py index dfaea77..df1776f 100644 --- a/packages/sshecret-sshd/src/sshecret_sshd/constants.py +++ b/packages/sshecret-sshd/src/sshecret_sshd/constants.py @@ -11,4 +11,6 @@ SERVER_KEY_TYPE = "ed25519" ERROR_BACKEND_ERROR = "Error: Unexpected response or error from backend" ERROR_INFO_BACKEND_GONE = "Unexpected error: Backend connection lost." ERROR_INFO_USERNAME_GONE = "Unexpected error: Username lost in transit." -ERROR_INFO_REMOTE_IP_GONE = "Unexpected error: Client connection details lost in transit." +ERROR_INFO_REMOTE_IP_GONE = ( + "Unexpected error: Client connection details lost in transit." +) diff --git a/packages/sshecret-sshd/src/sshecret_sshd/settings.py b/packages/sshecret-sshd/src/sshecret_sshd/settings.py index 42843b0..9ea0b7e 100644 --- a/packages/sshecret-sshd/src/sshecret_sshd/settings.py +++ b/packages/sshecret-sshd/src/sshecret_sshd/settings.py @@ -13,9 +13,11 @@ class ClientRegistrationSettings(BaseModel): """Client registration settings.""" enabled: bool = False - allow_from: Annotated[list[IPvAnyNetwork], ForceDecode] = Field(default_factory=list) + allow_from: Annotated[list[IPvAnyNetwork], ForceDecode] = Field( + default_factory=list + ) - @field_validator('allow_from', mode="before") + @field_validator("allow_from", mode="before") @classmethod def ensure_allow_from_list(cls, value: Any) -> list[IPvAnyNetwork]: """Convert allow_from to a list.""" @@ -34,15 +36,20 @@ class ClientRegistrationSettings(BaseModel): allow_from.append(entry) return allow_from + class ServerSettings(BaseSettings): """Server Settings.""" - model_config = SettingsConfigDict(env_file=".sshd.env", env_prefix="sshecret_sshd_", env_nested_delimiter='_') + model_config = SettingsConfigDict( + env_file=".sshd.env", env_prefix="sshecret_sshd_", env_nested_delimiter="_" + ) backend_url: AnyHttpUrl = Field(alias="sshecret_backend_url") backend_token: str listen_address: str = Field(default="127.0.0.1") port: int = DEFAULT_LISTEN_PORT - registration: ClientRegistrationSettings = Field(default_factory=ClientRegistrationSettings) + registration: ClientRegistrationSettings = Field( + default_factory=ClientRegistrationSettings + ) debug: bool = False enable_ping_command: bool = False diff --git a/packages/sshecret-sshd/src/sshecret_sshd/ssh_server.py b/packages/sshecret-sshd/src/sshecret_sshd/ssh_server.py index baccca0..792945c 100644 --- a/packages/sshecret-sshd/src/sshecret_sshd/ssh_server.py +++ b/packages/sshecret-sshd/src/sshecret_sshd/ssh_server.py @@ -54,7 +54,10 @@ def audit_process( data["command"] = cmd data["args"] = " ".join(cmd_args) - backend.audit(SubSystem.SSHD).write(operation, message, remote_ip, client, secret=None, secret_name=secret, **data) + backend.audit(SubSystem.SSHD).write( + operation, message, remote_ip, client, secret=None, secret_name=secret, **data + ) + def audit_event( backend: SshecretBackend, @@ -67,7 +70,10 @@ def audit_event( """Add an audit event.""" if not origin: origin = "UNKNOWN" - backend.audit(SubSystem.SSHD).write(operation, message, origin, client, secret=None, secret_name=secret) + backend.audit(SubSystem.SSHD).write( + operation, message, origin, client, secret=None, secret_name=secret + ) + def verify_key_input(public_key: str) -> str | None: """Verify key input.""" @@ -118,14 +124,19 @@ def get_info_remote_ip(process: asyncssh.SSHServerProcess[str]) -> str | None: return remote_ip -def get_info_allowed_registration(process: asyncssh.SSHServerProcess[str]) -> list[IPvAnyNetwork] | None: + +def get_info_allowed_registration( + process: asyncssh.SSHServerProcess[str], +) -> list[IPvAnyNetwork] | None: """Get allowed networks to allow registration from.""" - allowed_registration = cast(list[IPvAnyNetwork] | None, process.get_extra_info("allow_registration_from", None)) + allowed_registration = cast( + list[IPvAnyNetwork] | None, + process.get_extra_info("allow_registration_from", None), + ) return allowed_registration - def get_optional_commands(process: asyncssh.SSHServerProcess[str]) -> dict[str, bool]: """Get optional command state.""" with_registration = cast( @@ -236,7 +247,12 @@ async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None allowed_networks = get_info_allowed_registration(process) if not allowed_networks: process.stdout.write("Unauthorized.\n") - audit_process(backend, process, Operation.DENY, "Received registration command, but no subnets are allowed.") + audit_process( + backend, + process, + Operation.DENY, + "Received registration command, but no subnets are allowed.", + ) return remote_ip = get_info_remote_ip(process) @@ -250,7 +266,12 @@ async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None if client_address in network: break else: - audit_process(backend, process, Operation.DENY, "Received registration command from unauthorized subnet.") + audit_process( + backend, + process, + Operation.DENY, + "Received registration command from unauthorized subnet.", + ) process.stdout.write("Unauthorized.\n") return @@ -369,7 +390,6 @@ class AsshyncServer(asyncssh.SSHServer): self._conn.set_extra_info(client=client) self._conn.set_authorized_keys(key) else: - audit_event( self.backend, "Client denied due to policy", @@ -380,8 +400,12 @@ class AsshyncServer(asyncssh.SSHServer): LOG.warning("Client connection denied due to policy.") elif self.registration_enabled: self._conn.set_extra_info(provided_username=username) - self._conn.set_extra_info(allow_registration_from=self.allow_registration_from) - LOG.warning("Registration enabled, and client is not recognized. Bypassing authentication.") + self._conn.set_extra_info( + allow_registration_from=self.allow_registration_from + ) + LOG.warning( + "Registration enabled, and client is not recognized. Bypassing authentication." + ) return False LOG.debug("Continuing to regular authentication") diff --git a/packages/sshecret-sshd/tests/conftest.py b/packages/sshecret-sshd/tests/conftest.py index b9d809d..daeec14 100644 --- a/packages/sshecret-sshd/tests/conftest.py +++ b/packages/sshecret-sshd/tests/conftest.py @@ -10,15 +10,26 @@ from ipaddress import IPv4Network, IPv6Network from sshecret_sshd.ssh_server import run_ssh_server from sshecret_sshd.settings import ClientRegistrationSettings -from .types import Client, ClientKey, ClientRegistry, SshServerFixtureFun, SshServerFixture +from .types import ( + Client, + ClientKey, + ClientRegistry, + SshServerFixtureFun, + SshServerFixture, +) + @pytest.fixture(scope="function") def client_registry() -> ClientRegistry: clients = {} secrets = {} - async def add_client(name: str, secret_names: list[str]|None=None, policies: list[str]|None=None) -> str: - private_key = asyncssh.generate_private_key('ssh-rsa') + async def add_client( + name: str, + secret_names: list[str] | None = None, + policies: list[str] | None = None, + ) -> str: + private_key = asyncssh.generate_private_key("ssh-rsa") public_key = private_key.export_public_key() clients[name] = ClientKey(name, private_key, public_key.decode().rstrip()) secrets.update({(name, s): f"mocked-secret-{s}" for s in (secret_names or [])}) @@ -26,11 +37,12 @@ def client_registry() -> ClientRegistry: return {"clients": clients, "secrets": secrets, "add_client": add_client} + @pytest.fixture(scope="function") async def mock_backend(client_registry: ClientRegistry) -> MagicMock: backend = MagicMock() - clients_data = client_registry['clients'] - secrets_data = client_registry['secrets'] + clients_data = client_registry["clients"] + secrets_data = client_registry["secrets"] async def get_client(name: str) -> Client | None: client_key = clients_data.get(name) @@ -41,7 +53,7 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock: description=f"Mock client {name}", public_key=client_key.public_key, secrets=[s for (c, s) in secrets_data if c == name], - policies=[IPv4Network('0.0.0.0/0'), IPv6Network('::/0')], + policies=[IPv4Network("0.0.0.0/0"), IPv6Network("::/0")], ) return response_model return None @@ -57,7 +69,9 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock: Otherwise we can't test this... """ if "template" not in clients_data: - raise RuntimeError("Error, must have a client called template for this to work.") + raise RuntimeError( + "Error, must have a client called template for this to work." + ) clients_data[name] = clients_data["template"] for secret_key, secret in secrets_data.items(): s_client, secret_name = secret_key @@ -76,18 +90,22 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock: return backend + @pytest.fixture(scope="function") -async def ssh_server(mock_backend: MagicMock, unused_tcp_port: int) -> SshServerFixtureFun: +async def ssh_server( + mock_backend: MagicMock, unused_tcp_port: int +) -> SshServerFixtureFun: port = unused_tcp_port - private_key = asyncssh.generate_private_key('ssh-ed25519') + private_key = asyncssh.generate_private_key("ssh-ed25519") key_str = private_key.export_private_key() - with tempfile.NamedTemporaryFile('w+', delete=True) as key_file: - + with tempfile.NamedTemporaryFile("w+", delete=True) as key_file: key_file.write(key_str.decode()) key_file.flush() - registration_settings = ClientRegistrationSettings(enabled=True, allow_from=[IPv4Network("0.0.0.0/0")]) + registration_settings = ClientRegistrationSettings( + enabled=True, allow_from=[IPv4Network("0.0.0.0/0")] + ) server = await run_ssh_server( backend=mock_backend, listen_address="localhost", @@ -104,6 +122,7 @@ async def ssh_server(mock_backend: MagicMock, unused_tcp_port: int) -> SshServer server.close() await server.wait_closed() + @pytest.fixture(scope="function") def ssh_command_runner(ssh_server: SshServerFixture, client_registry: ClientRegistry): """Run a single command. @@ -113,7 +132,7 @@ def ssh_command_runner(ssh_server: SshServerFixture, client_registry: ClientRegi _, port = ssh_server async def run_command_as(name: str, command: str): - client_key = client_registry['clients'][name] + client_key = client_registry["clients"][name] conn = await asyncssh.connect( "127.0.0.1", port=port, diff --git a/packages/sshecret-sshd/tests/test_get_secret.py b/packages/sshecret-sshd/tests/test_get_secret.py index dca1645..b129a7c 100644 --- a/packages/sshecret-sshd/tests/test_get_secret.py +++ b/packages/sshecret-sshd/tests/test_get_secret.py @@ -4,11 +4,14 @@ import pytest from .types import ClientRegistry, CommandRunner + @pytest.mark.asyncio -async def test_get_secret(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None: +async def test_get_secret( + ssh_command_runner: CommandRunner, client_registry: ClientRegistry +) -> None: """Test that we can get a secret.""" - await client_registry['add_client']("test-client", ["mysecret"]) + await client_registry["add_client"]("test-client", ["mysecret"]) result = await ssh_command_runner("test-client", "get_secret mysecret") @@ -18,10 +21,12 @@ async def test_get_secret(ssh_command_runner: CommandRunner, client_registry: Cl @pytest.mark.asyncio -async def test_invalid_secret_name(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None: +async def test_invalid_secret_name( + ssh_command_runner: CommandRunner, client_registry: ClientRegistry +) -> None: """Test getting an invalid secret name.""" - await client_registry['add_client']("test-client") + await client_registry["add_client"]("test-client") result = await ssh_command_runner("test-client", "get_secret mysecret") assert result.exit_status == 1 diff --git a/packages/sshecret-sshd/tests/test_ping.py b/packages/sshecret-sshd/tests/test_ping.py index a83ef50..791c01a 100644 --- a/packages/sshecret-sshd/tests/test_ping.py +++ b/packages/sshecret-sshd/tests/test_ping.py @@ -2,10 +2,13 @@ import pytest from .types import ClientRegistry, CommandRunner + @pytest.mark.asyncio -async def test_ping_command(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None: +async def test_ping_command( + ssh_command_runner: CommandRunner, client_registry: ClientRegistry +) -> None: # Register a test client with default policies and no secrets - await client_registry['add_client']('test-pinger') + await client_registry["add_client"]("test-pinger") result = await ssh_command_runner("test-pinger", "ping") diff --git a/packages/sshecret-sshd/tests/test_register.py b/packages/sshecret-sshd/tests/test_register.py index 432f57c..2e6c3c5 100644 --- a/packages/sshecret-sshd/tests/test_register.py +++ b/packages/sshecret-sshd/tests/test_register.py @@ -1,10 +1,16 @@ """Test registration.""" + import pytest from .types import ClientRegistry, CommandRunner, ProcessRunner + @pytest.mark.asyncio -async def test_register_client(ssh_session: ProcessRunner, ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None: +async def test_register_client( + ssh_session: ProcessRunner, + ssh_command_runner: CommandRunner, + client_registry: ClientRegistry, +) -> None: """Test client registration.""" await client_registry["add_client"]("template", ["testsecret"]) @@ -12,9 +18,9 @@ async def test_register_client(ssh_session: ProcessRunner, ssh_command_runner: C async with ssh_session("newclient", "register", "template") as session: maxlines = 10 - l = 0 + linenum = 0 found = False - while l < maxlines: + while linenum < maxlines: line = await session.stdout.readline() if "Enter public key" in line: found = True diff --git a/packages/sshecret-sshd/tests/types.py b/packages/sshecret-sshd/tests/types.py index 9b6187b..aa009cb 100644 --- a/packages/sshecret-sshd/tests/types.py +++ b/packages/sshecret-sshd/tests/types.py @@ -1,9 +1,10 @@ """Types for the various test properties.""" + import uuid from datetime import datetime from dataclasses import dataclass, field from ipaddress import IPv4Network, IPv6Network -from collections.abc import AsyncGenerator, Awaitable, Callable, AsyncIterator +from collections.abc import AsyncGenerator, Awaitable, Callable from typing import Any, Protocol, TypedDict, AsyncContextManager import asyncssh @@ -11,8 +12,6 @@ SshServerFixture = tuple[str, int] SshServerFixtureFun = AsyncGenerator[tuple[asyncssh.SSHAcceptor, int], None] - - @dataclass class Client: """Mock client.""" @@ -37,13 +36,21 @@ class ClientKey: class AddClientFun(Protocol): """Add client function.""" - def __call__(self, name: str, secret_names: list[str] | None = None, policies: list[str] | None = None) -> Awaitable[str]: ... + def __call__( + self, + name: str, + secret_names: list[str] | None = None, + policies: list[str] | None = None, + ) -> Awaitable[str]: ... + class ProcessRunner(Protocol): """Process runner typing.""" - def __call__(self, name: str, command: str, client: str | None = None) -> AsyncContextManager[asyncssh.SSHClientProcess[Any]]: - ... + def __call__( + self, name: str, command: str, client: str | None = None + ) -> AsyncContextManager[asyncssh.SSHClientProcess[Any]]: ... + class ClientRegistry(TypedDict): """Client registry typing."""