Reformat and lint

This commit is contained in:
2025-05-10 08:29:58 +02:00
parent 0a427b6a91
commit d866553ac1
9 changed files with 120 additions and 44 deletions

View File

@ -1,4 +1,5 @@
"""CLI app.""" """CLI app."""
import logging import logging
import asyncio import asyncio
import sys import sys
@ -12,7 +13,9 @@ from .ssh_server import start_server
LOG = logging.getLogger() LOG = logging.getLogger()
handler = logging.StreamHandler() handler = logging.StreamHandler()
formatter = logging.Formatter("%(created)f:%(levelname)s:%(name)s:%(module)s:%(message)s") formatter = logging.Formatter(
"%(created)f:%(levelname)s:%(name)s:%(module)s:%(message)s"
)
handler.setFormatter(formatter) handler.setFormatter(formatter)
LOG.addHandler(handler) LOG.addHandler(handler)

View File

@ -11,4 +11,6 @@ SERVER_KEY_TYPE = "ed25519"
ERROR_BACKEND_ERROR = "Error: Unexpected response or error from backend" ERROR_BACKEND_ERROR = "Error: Unexpected response or error from backend"
ERROR_INFO_BACKEND_GONE = "Unexpected error: Backend connection lost." ERROR_INFO_BACKEND_GONE = "Unexpected error: Backend connection lost."
ERROR_INFO_USERNAME_GONE = "Unexpected error: Username lost in transit." ERROR_INFO_USERNAME_GONE = "Unexpected error: Username lost in transit."
ERROR_INFO_REMOTE_IP_GONE = "Unexpected error: Client connection details lost in transit." ERROR_INFO_REMOTE_IP_GONE = (
"Unexpected error: Client connection details lost in transit."
)

View File

@ -13,9 +13,11 @@ class ClientRegistrationSettings(BaseModel):
"""Client registration settings.""" """Client registration settings."""
enabled: bool = False enabled: bool = False
allow_from: Annotated[list[IPvAnyNetwork], ForceDecode] = Field(default_factory=list) allow_from: Annotated[list[IPvAnyNetwork], ForceDecode] = Field(
default_factory=list
)
@field_validator('allow_from', mode="before") @field_validator("allow_from", mode="before")
@classmethod @classmethod
def ensure_allow_from_list(cls, value: Any) -> list[IPvAnyNetwork]: def ensure_allow_from_list(cls, value: Any) -> list[IPvAnyNetwork]:
"""Convert allow_from to a list.""" """Convert allow_from to a list."""
@ -34,15 +36,20 @@ class ClientRegistrationSettings(BaseModel):
allow_from.append(entry) allow_from.append(entry)
return allow_from return allow_from
class ServerSettings(BaseSettings): class ServerSettings(BaseSettings):
"""Server Settings.""" """Server Settings."""
model_config = SettingsConfigDict(env_file=".sshd.env", env_prefix="sshecret_sshd_", env_nested_delimiter='_') model_config = SettingsConfigDict(
env_file=".sshd.env", env_prefix="sshecret_sshd_", env_nested_delimiter="_"
)
backend_url: AnyHttpUrl = Field(alias="sshecret_backend_url") backend_url: AnyHttpUrl = Field(alias="sshecret_backend_url")
backend_token: str backend_token: str
listen_address: str = Field(default="127.0.0.1") listen_address: str = Field(default="127.0.0.1")
port: int = DEFAULT_LISTEN_PORT port: int = DEFAULT_LISTEN_PORT
registration: ClientRegistrationSettings = Field(default_factory=ClientRegistrationSettings) registration: ClientRegistrationSettings = Field(
default_factory=ClientRegistrationSettings
)
debug: bool = False debug: bool = False
enable_ping_command: bool = False enable_ping_command: bool = False

View File

@ -54,7 +54,10 @@ def audit_process(
data["command"] = cmd data["command"] = cmd
data["args"] = " ".join(cmd_args) data["args"] = " ".join(cmd_args)
backend.audit(SubSystem.SSHD).write(operation, message, remote_ip, client, secret=None, secret_name=secret, **data) backend.audit(SubSystem.SSHD).write(
operation, message, remote_ip, client, secret=None, secret_name=secret, **data
)
def audit_event( def audit_event(
backend: SshecretBackend, backend: SshecretBackend,
@ -67,7 +70,10 @@ def audit_event(
"""Add an audit event.""" """Add an audit event."""
if not origin: if not origin:
origin = "UNKNOWN" origin = "UNKNOWN"
backend.audit(SubSystem.SSHD).write(operation, message, origin, client, secret=None, secret_name=secret) backend.audit(SubSystem.SSHD).write(
operation, message, origin, client, secret=None, secret_name=secret
)
def verify_key_input(public_key: str) -> str | None: def verify_key_input(public_key: str) -> str | None:
"""Verify key input.""" """Verify key input."""
@ -118,14 +124,19 @@ def get_info_remote_ip(process: asyncssh.SSHServerProcess[str]) -> str | None:
return remote_ip return remote_ip
def get_info_allowed_registration(process: asyncssh.SSHServerProcess[str]) -> list[IPvAnyNetwork] | None:
def get_info_allowed_registration(
process: asyncssh.SSHServerProcess[str],
) -> list[IPvAnyNetwork] | None:
"""Get allowed networks to allow registration from.""" """Get allowed networks to allow registration from."""
allowed_registration = cast(list[IPvAnyNetwork] | None, process.get_extra_info("allow_registration_from", None)) allowed_registration = cast(
list[IPvAnyNetwork] | None,
process.get_extra_info("allow_registration_from", None),
)
return allowed_registration return allowed_registration
def get_optional_commands(process: asyncssh.SSHServerProcess[str]) -> dict[str, bool]: def get_optional_commands(process: asyncssh.SSHServerProcess[str]) -> dict[str, bool]:
"""Get optional command state.""" """Get optional command state."""
with_registration = cast( with_registration = cast(
@ -236,7 +247,12 @@ async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None
allowed_networks = get_info_allowed_registration(process) allowed_networks = get_info_allowed_registration(process)
if not allowed_networks: if not allowed_networks:
process.stdout.write("Unauthorized.\n") process.stdout.write("Unauthorized.\n")
audit_process(backend, process, Operation.DENY, "Received registration command, but no subnets are allowed.") audit_process(
backend,
process,
Operation.DENY,
"Received registration command, but no subnets are allowed.",
)
return return
remote_ip = get_info_remote_ip(process) remote_ip = get_info_remote_ip(process)
@ -250,7 +266,12 @@ async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None
if client_address in network: if client_address in network:
break break
else: else:
audit_process(backend, process, Operation.DENY, "Received registration command from unauthorized subnet.") audit_process(
backend,
process,
Operation.DENY,
"Received registration command from unauthorized subnet.",
)
process.stdout.write("Unauthorized.\n") process.stdout.write("Unauthorized.\n")
return return
@ -369,7 +390,6 @@ class AsshyncServer(asyncssh.SSHServer):
self._conn.set_extra_info(client=client) self._conn.set_extra_info(client=client)
self._conn.set_authorized_keys(key) self._conn.set_authorized_keys(key)
else: else:
audit_event( audit_event(
self.backend, self.backend,
"Client denied due to policy", "Client denied due to policy",
@ -380,8 +400,12 @@ class AsshyncServer(asyncssh.SSHServer):
LOG.warning("Client connection denied due to policy.") LOG.warning("Client connection denied due to policy.")
elif self.registration_enabled: elif self.registration_enabled:
self._conn.set_extra_info(provided_username=username) self._conn.set_extra_info(provided_username=username)
self._conn.set_extra_info(allow_registration_from=self.allow_registration_from) self._conn.set_extra_info(
LOG.warning("Registration enabled, and client is not recognized. Bypassing authentication.") allow_registration_from=self.allow_registration_from
)
LOG.warning(
"Registration enabled, and client is not recognized. Bypassing authentication."
)
return False return False
LOG.debug("Continuing to regular authentication") LOG.debug("Continuing to regular authentication")

View File

@ -10,15 +10,26 @@ from ipaddress import IPv4Network, IPv6Network
from sshecret_sshd.ssh_server import run_ssh_server from sshecret_sshd.ssh_server import run_ssh_server
from sshecret_sshd.settings import ClientRegistrationSettings from sshecret_sshd.settings import ClientRegistrationSettings
from .types import Client, ClientKey, ClientRegistry, SshServerFixtureFun, SshServerFixture from .types import (
Client,
ClientKey,
ClientRegistry,
SshServerFixtureFun,
SshServerFixture,
)
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def client_registry() -> ClientRegistry: def client_registry() -> ClientRegistry:
clients = {} clients = {}
secrets = {} secrets = {}
async def add_client(name: str, secret_names: list[str]|None=None, policies: list[str]|None=None) -> str: async def add_client(
private_key = asyncssh.generate_private_key('ssh-rsa') name: str,
secret_names: list[str] | None = None,
policies: list[str] | None = None,
) -> str:
private_key = asyncssh.generate_private_key("ssh-rsa")
public_key = private_key.export_public_key() public_key = private_key.export_public_key()
clients[name] = ClientKey(name, private_key, public_key.decode().rstrip()) clients[name] = ClientKey(name, private_key, public_key.decode().rstrip())
secrets.update({(name, s): f"mocked-secret-{s}" for s in (secret_names or [])}) secrets.update({(name, s): f"mocked-secret-{s}" for s in (secret_names or [])})
@ -26,11 +37,12 @@ def client_registry() -> ClientRegistry:
return {"clients": clients, "secrets": secrets, "add_client": add_client} return {"clients": clients, "secrets": secrets, "add_client": add_client}
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
async def mock_backend(client_registry: ClientRegistry) -> MagicMock: async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
backend = MagicMock() backend = MagicMock()
clients_data = client_registry['clients'] clients_data = client_registry["clients"]
secrets_data = client_registry['secrets'] secrets_data = client_registry["secrets"]
async def get_client(name: str) -> Client | None: async def get_client(name: str) -> Client | None:
client_key = clients_data.get(name) client_key = clients_data.get(name)
@ -41,7 +53,7 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
description=f"Mock client {name}", description=f"Mock client {name}",
public_key=client_key.public_key, public_key=client_key.public_key,
secrets=[s for (c, s) in secrets_data if c == name], secrets=[s for (c, s) in secrets_data if c == name],
policies=[IPv4Network('0.0.0.0/0'), IPv6Network('::/0')], policies=[IPv4Network("0.0.0.0/0"), IPv6Network("::/0")],
) )
return response_model return response_model
return None return None
@ -57,7 +69,9 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
Otherwise we can't test this... Otherwise we can't test this...
""" """
if "template" not in clients_data: if "template" not in clients_data:
raise RuntimeError("Error, must have a client called template for this to work.") raise RuntimeError(
"Error, must have a client called template for this to work."
)
clients_data[name] = clients_data["template"] clients_data[name] = clients_data["template"]
for secret_key, secret in secrets_data.items(): for secret_key, secret in secrets_data.items():
s_client, secret_name = secret_key s_client, secret_name = secret_key
@ -76,18 +90,22 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
return backend return backend
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
async def ssh_server(mock_backend: MagicMock, unused_tcp_port: int) -> SshServerFixtureFun: async def ssh_server(
mock_backend: MagicMock, unused_tcp_port: int
) -> SshServerFixtureFun:
port = unused_tcp_port port = unused_tcp_port
private_key = asyncssh.generate_private_key('ssh-ed25519') private_key = asyncssh.generate_private_key("ssh-ed25519")
key_str = private_key.export_private_key() key_str = private_key.export_private_key()
with tempfile.NamedTemporaryFile('w+', delete=True) as key_file: with tempfile.NamedTemporaryFile("w+", delete=True) as key_file:
key_file.write(key_str.decode()) key_file.write(key_str.decode())
key_file.flush() key_file.flush()
registration_settings = ClientRegistrationSettings(enabled=True, allow_from=[IPv4Network("0.0.0.0/0")]) registration_settings = ClientRegistrationSettings(
enabled=True, allow_from=[IPv4Network("0.0.0.0/0")]
)
server = await run_ssh_server( server = await run_ssh_server(
backend=mock_backend, backend=mock_backend,
listen_address="localhost", listen_address="localhost",
@ -104,6 +122,7 @@ async def ssh_server(mock_backend: MagicMock, unused_tcp_port: int) -> SshServer
server.close() server.close()
await server.wait_closed() await server.wait_closed()
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def ssh_command_runner(ssh_server: SshServerFixture, client_registry: ClientRegistry): def ssh_command_runner(ssh_server: SshServerFixture, client_registry: ClientRegistry):
"""Run a single command. """Run a single command.
@ -113,7 +132,7 @@ def ssh_command_runner(ssh_server: SshServerFixture, client_registry: ClientRegi
_, port = ssh_server _, port = ssh_server
async def run_command_as(name: str, command: str): async def run_command_as(name: str, command: str):
client_key = client_registry['clients'][name] client_key = client_registry["clients"][name]
conn = await asyncssh.connect( conn = await asyncssh.connect(
"127.0.0.1", "127.0.0.1",
port=port, port=port,

View File

@ -4,11 +4,14 @@ import pytest
from .types import ClientRegistry, CommandRunner from .types import ClientRegistry, CommandRunner
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_secret(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None: async def test_get_secret(
ssh_command_runner: CommandRunner, client_registry: ClientRegistry
) -> None:
"""Test that we can get a secret.""" """Test that we can get a secret."""
await client_registry['add_client']("test-client", ["mysecret"]) await client_registry["add_client"]("test-client", ["mysecret"])
result = await ssh_command_runner("test-client", "get_secret mysecret") result = await ssh_command_runner("test-client", "get_secret mysecret")
@ -18,10 +21,12 @@ async def test_get_secret(ssh_command_runner: CommandRunner, client_registry: Cl
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invalid_secret_name(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None: async def test_invalid_secret_name(
ssh_command_runner: CommandRunner, client_registry: ClientRegistry
) -> None:
"""Test getting an invalid secret name.""" """Test getting an invalid secret name."""
await client_registry['add_client']("test-client") await client_registry["add_client"]("test-client")
result = await ssh_command_runner("test-client", "get_secret mysecret") result = await ssh_command_runner("test-client", "get_secret mysecret")
assert result.exit_status == 1 assert result.exit_status == 1

View File

@ -2,10 +2,13 @@ import pytest
from .types import ClientRegistry, CommandRunner from .types import ClientRegistry, CommandRunner
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_ping_command(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None: async def test_ping_command(
ssh_command_runner: CommandRunner, client_registry: ClientRegistry
) -> None:
# Register a test client with default policies and no secrets # Register a test client with default policies and no secrets
await client_registry['add_client']('test-pinger') await client_registry["add_client"]("test-pinger")
result = await ssh_command_runner("test-pinger", "ping") result = await ssh_command_runner("test-pinger", "ping")

View File

@ -1,10 +1,16 @@
"""Test registration.""" """Test registration."""
import pytest import pytest
from .types import ClientRegistry, CommandRunner, ProcessRunner from .types import ClientRegistry, CommandRunner, ProcessRunner
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_register_client(ssh_session: ProcessRunner, ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None: async def test_register_client(
ssh_session: ProcessRunner,
ssh_command_runner: CommandRunner,
client_registry: ClientRegistry,
) -> None:
"""Test client registration.""" """Test client registration."""
await client_registry["add_client"]("template", ["testsecret"]) await client_registry["add_client"]("template", ["testsecret"])
@ -12,9 +18,9 @@ async def test_register_client(ssh_session: ProcessRunner, ssh_command_runner: C
async with ssh_session("newclient", "register", "template") as session: async with ssh_session("newclient", "register", "template") as session:
maxlines = 10 maxlines = 10
l = 0 linenum = 0
found = False found = False
while l < maxlines: while linenum < maxlines:
line = await session.stdout.readline() line = await session.stdout.readline()
if "Enter public key" in line: if "Enter public key" in line:
found = True found = True

View File

@ -1,9 +1,10 @@
"""Types for the various test properties.""" """Types for the various test properties."""
import uuid import uuid
from datetime import datetime from datetime import datetime
from dataclasses import dataclass, field from dataclasses import dataclass, field
from ipaddress import IPv4Network, IPv6Network from ipaddress import IPv4Network, IPv6Network
from collections.abc import AsyncGenerator, Awaitable, Callable, AsyncIterator from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any, Protocol, TypedDict, AsyncContextManager from typing import Any, Protocol, TypedDict, AsyncContextManager
import asyncssh import asyncssh
@ -11,8 +12,6 @@ SshServerFixture = tuple[str, int]
SshServerFixtureFun = AsyncGenerator[tuple[asyncssh.SSHAcceptor, int], None] SshServerFixtureFun = AsyncGenerator[tuple[asyncssh.SSHAcceptor, int], None]
@dataclass @dataclass
class Client: class Client:
"""Mock client.""" """Mock client."""
@ -37,13 +36,21 @@ class ClientKey:
class AddClientFun(Protocol): class AddClientFun(Protocol):
"""Add client function.""" """Add client function."""
def __call__(self, name: str, secret_names: list[str] | None = None, policies: list[str] | None = None) -> Awaitable[str]: ... def __call__(
self,
name: str,
secret_names: list[str] | None = None,
policies: list[str] | None = None,
) -> Awaitable[str]: ...
class ProcessRunner(Protocol): class ProcessRunner(Protocol):
"""Process runner typing.""" """Process runner typing."""
def __call__(self, name: str, command: str, client: str | None = None) -> AsyncContextManager[asyncssh.SSHClientProcess[Any]]: def __call__(
... self, name: str, command: str, client: str | None = None
) -> AsyncContextManager[asyncssh.SSHClientProcess[Any]]: ...
class ClientRegistry(TypedDict): class ClientRegistry(TypedDict):
"""Client registry typing.""" """Client registry typing."""