Reformat and lint

This commit is contained in:
2025-05-10 08:29:58 +02:00
parent 0a427b6a91
commit d866553ac1
9 changed files with 120 additions and 44 deletions

View File

@ -1,4 +1,5 @@
"""CLI app."""
import logging
import asyncio
import sys
@ -12,7 +13,9 @@ from .ssh_server import start_server
LOG = logging.getLogger()
handler = logging.StreamHandler()
formatter = logging.Formatter("%(created)f:%(levelname)s:%(name)s:%(module)s:%(message)s")
formatter = logging.Formatter(
"%(created)f:%(levelname)s:%(name)s:%(module)s:%(message)s"
)
handler.setFormatter(formatter)
LOG.addHandler(handler)

View File

@ -11,4 +11,6 @@ SERVER_KEY_TYPE = "ed25519"
ERROR_BACKEND_ERROR = "Error: Unexpected response or error from backend"
ERROR_INFO_BACKEND_GONE = "Unexpected error: Backend connection lost."
ERROR_INFO_USERNAME_GONE = "Unexpected error: Username lost in transit."
ERROR_INFO_REMOTE_IP_GONE = "Unexpected error: Client connection details lost in transit."
ERROR_INFO_REMOTE_IP_GONE = (
"Unexpected error: Client connection details lost in transit."
)

View File

@ -13,9 +13,11 @@ class ClientRegistrationSettings(BaseModel):
"""Client registration settings."""
enabled: bool = False
allow_from: Annotated[list[IPvAnyNetwork], ForceDecode] = Field(default_factory=list)
allow_from: Annotated[list[IPvAnyNetwork], ForceDecode] = Field(
default_factory=list
)
@field_validator('allow_from', mode="before")
@field_validator("allow_from", mode="before")
@classmethod
def ensure_allow_from_list(cls, value: Any) -> list[IPvAnyNetwork]:
"""Convert allow_from to a list."""
@ -34,15 +36,20 @@ class ClientRegistrationSettings(BaseModel):
allow_from.append(entry)
return allow_from
class ServerSettings(BaseSettings):
"""Server Settings."""
model_config = SettingsConfigDict(env_file=".sshd.env", env_prefix="sshecret_sshd_", env_nested_delimiter='_')
model_config = SettingsConfigDict(
env_file=".sshd.env", env_prefix="sshecret_sshd_", env_nested_delimiter="_"
)
backend_url: AnyHttpUrl = Field(alias="sshecret_backend_url")
backend_token: str
listen_address: str = Field(default="127.0.0.1")
port: int = DEFAULT_LISTEN_PORT
registration: ClientRegistrationSettings = Field(default_factory=ClientRegistrationSettings)
registration: ClientRegistrationSettings = Field(
default_factory=ClientRegistrationSettings
)
debug: bool = False
enable_ping_command: bool = False

View File

@ -54,7 +54,10 @@ def audit_process(
data["command"] = cmd
data["args"] = " ".join(cmd_args)
backend.audit(SubSystem.SSHD).write(operation, message, remote_ip, client, secret=None, secret_name=secret, **data)
backend.audit(SubSystem.SSHD).write(
operation, message, remote_ip, client, secret=None, secret_name=secret, **data
)
def audit_event(
backend: SshecretBackend,
@ -67,7 +70,10 @@ def audit_event(
"""Add an audit event."""
if not origin:
origin = "UNKNOWN"
backend.audit(SubSystem.SSHD).write(operation, message, origin, client, secret=None, secret_name=secret)
backend.audit(SubSystem.SSHD).write(
operation, message, origin, client, secret=None, secret_name=secret
)
def verify_key_input(public_key: str) -> str | None:
"""Verify key input."""
@ -118,14 +124,19 @@ def get_info_remote_ip(process: asyncssh.SSHServerProcess[str]) -> str | None:
return remote_ip
def get_info_allowed_registration(process: asyncssh.SSHServerProcess[str]) -> list[IPvAnyNetwork] | None:
def get_info_allowed_registration(
process: asyncssh.SSHServerProcess[str],
) -> list[IPvAnyNetwork] | None:
"""Get allowed networks to allow registration from."""
allowed_registration = cast(list[IPvAnyNetwork] | None, process.get_extra_info("allow_registration_from", None))
allowed_registration = cast(
list[IPvAnyNetwork] | None,
process.get_extra_info("allow_registration_from", None),
)
return allowed_registration
def get_optional_commands(process: asyncssh.SSHServerProcess[str]) -> dict[str, bool]:
"""Get optional command state."""
with_registration = cast(
@ -236,7 +247,12 @@ async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None
allowed_networks = get_info_allowed_registration(process)
if not allowed_networks:
process.stdout.write("Unauthorized.\n")
audit_process(backend, process, Operation.DENY, "Received registration command, but no subnets are allowed.")
audit_process(
backend,
process,
Operation.DENY,
"Received registration command, but no subnets are allowed.",
)
return
remote_ip = get_info_remote_ip(process)
@ -250,7 +266,12 @@ async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None
if client_address in network:
break
else:
audit_process(backend, process, Operation.DENY, "Received registration command from unauthorized subnet.")
audit_process(
backend,
process,
Operation.DENY,
"Received registration command from unauthorized subnet.",
)
process.stdout.write("Unauthorized.\n")
return
@ -369,7 +390,6 @@ class AsshyncServer(asyncssh.SSHServer):
self._conn.set_extra_info(client=client)
self._conn.set_authorized_keys(key)
else:
audit_event(
self.backend,
"Client denied due to policy",
@ -380,8 +400,12 @@ class AsshyncServer(asyncssh.SSHServer):
LOG.warning("Client connection denied due to policy.")
elif self.registration_enabled:
self._conn.set_extra_info(provided_username=username)
self._conn.set_extra_info(allow_registration_from=self.allow_registration_from)
LOG.warning("Registration enabled, and client is not recognized. Bypassing authentication.")
self._conn.set_extra_info(
allow_registration_from=self.allow_registration_from
)
LOG.warning(
"Registration enabled, and client is not recognized. Bypassing authentication."
)
return False
LOG.debug("Continuing to regular authentication")

View File

@ -10,15 +10,26 @@ from ipaddress import IPv4Network, IPv6Network
from sshecret_sshd.ssh_server import run_ssh_server
from sshecret_sshd.settings import ClientRegistrationSettings
from .types import Client, ClientKey, ClientRegistry, SshServerFixtureFun, SshServerFixture
from .types import (
Client,
ClientKey,
ClientRegistry,
SshServerFixtureFun,
SshServerFixture,
)
@pytest.fixture(scope="function")
def client_registry() -> ClientRegistry:
clients = {}
secrets = {}
async def add_client(name: str, secret_names: list[str]|None=None, policies: list[str]|None=None) -> str:
private_key = asyncssh.generate_private_key('ssh-rsa')
async def add_client(
name: str,
secret_names: list[str] | None = None,
policies: list[str] | None = None,
) -> str:
private_key = asyncssh.generate_private_key("ssh-rsa")
public_key = private_key.export_public_key()
clients[name] = ClientKey(name, private_key, public_key.decode().rstrip())
secrets.update({(name, s): f"mocked-secret-{s}" for s in (secret_names or [])})
@ -26,11 +37,12 @@ def client_registry() -> ClientRegistry:
return {"clients": clients, "secrets": secrets, "add_client": add_client}
@pytest.fixture(scope="function")
async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
backend = MagicMock()
clients_data = client_registry['clients']
secrets_data = client_registry['secrets']
clients_data = client_registry["clients"]
secrets_data = client_registry["secrets"]
async def get_client(name: str) -> Client | None:
client_key = clients_data.get(name)
@ -41,7 +53,7 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
description=f"Mock client {name}",
public_key=client_key.public_key,
secrets=[s for (c, s) in secrets_data if c == name],
policies=[IPv4Network('0.0.0.0/0'), IPv6Network('::/0')],
policies=[IPv4Network("0.0.0.0/0"), IPv6Network("::/0")],
)
return response_model
return None
@ -57,7 +69,9 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
Otherwise we can't test this...
"""
if "template" not in clients_data:
raise RuntimeError("Error, must have a client called template for this to work.")
raise RuntimeError(
"Error, must have a client called template for this to work."
)
clients_data[name] = clients_data["template"]
for secret_key, secret in secrets_data.items():
s_client, secret_name = secret_key
@ -76,18 +90,22 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
return backend
@pytest.fixture(scope="function")
async def ssh_server(mock_backend: MagicMock, unused_tcp_port: int) -> SshServerFixtureFun:
async def ssh_server(
mock_backend: MagicMock, unused_tcp_port: int
) -> SshServerFixtureFun:
port = unused_tcp_port
private_key = asyncssh.generate_private_key('ssh-ed25519')
private_key = asyncssh.generate_private_key("ssh-ed25519")
key_str = private_key.export_private_key()
with tempfile.NamedTemporaryFile('w+', delete=True) as key_file:
with tempfile.NamedTemporaryFile("w+", delete=True) as key_file:
key_file.write(key_str.decode())
key_file.flush()
registration_settings = ClientRegistrationSettings(enabled=True, allow_from=[IPv4Network("0.0.0.0/0")])
registration_settings = ClientRegistrationSettings(
enabled=True, allow_from=[IPv4Network("0.0.0.0/0")]
)
server = await run_ssh_server(
backend=mock_backend,
listen_address="localhost",
@ -104,6 +122,7 @@ async def ssh_server(mock_backend: MagicMock, unused_tcp_port: int) -> SshServer
server.close()
await server.wait_closed()
@pytest.fixture(scope="function")
def ssh_command_runner(ssh_server: SshServerFixture, client_registry: ClientRegistry):
"""Run a single command.
@ -113,7 +132,7 @@ def ssh_command_runner(ssh_server: SshServerFixture, client_registry: ClientRegi
_, port = ssh_server
async def run_command_as(name: str, command: str):
client_key = client_registry['clients'][name]
client_key = client_registry["clients"][name]
conn = await asyncssh.connect(
"127.0.0.1",
port=port,

View File

@ -4,11 +4,14 @@ import pytest
from .types import ClientRegistry, CommandRunner
@pytest.mark.asyncio
async def test_get_secret(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None:
async def test_get_secret(
ssh_command_runner: CommandRunner, client_registry: ClientRegistry
) -> None:
"""Test that we can get a secret."""
await client_registry['add_client']("test-client", ["mysecret"])
await client_registry["add_client"]("test-client", ["mysecret"])
result = await ssh_command_runner("test-client", "get_secret mysecret")
@ -18,10 +21,12 @@ async def test_get_secret(ssh_command_runner: CommandRunner, client_registry: Cl
@pytest.mark.asyncio
async def test_invalid_secret_name(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None:
async def test_invalid_secret_name(
ssh_command_runner: CommandRunner, client_registry: ClientRegistry
) -> None:
"""Test getting an invalid secret name."""
await client_registry['add_client']("test-client")
await client_registry["add_client"]("test-client")
result = await ssh_command_runner("test-client", "get_secret mysecret")
assert result.exit_status == 1

View File

@ -2,10 +2,13 @@ import pytest
from .types import ClientRegistry, CommandRunner
@pytest.mark.asyncio
async def test_ping_command(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None:
async def test_ping_command(
ssh_command_runner: CommandRunner, client_registry: ClientRegistry
) -> None:
# Register a test client with default policies and no secrets
await client_registry['add_client']('test-pinger')
await client_registry["add_client"]("test-pinger")
result = await ssh_command_runner("test-pinger", "ping")

View File

@ -1,10 +1,16 @@
"""Test registration."""
import pytest
from .types import ClientRegistry, CommandRunner, ProcessRunner
@pytest.mark.asyncio
async def test_register_client(ssh_session: ProcessRunner, ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None:
async def test_register_client(
ssh_session: ProcessRunner,
ssh_command_runner: CommandRunner,
client_registry: ClientRegistry,
) -> None:
"""Test client registration."""
await client_registry["add_client"]("template", ["testsecret"])
@ -12,9 +18,9 @@ async def test_register_client(ssh_session: ProcessRunner, ssh_command_runner: C
async with ssh_session("newclient", "register", "template") as session:
maxlines = 10
l = 0
linenum = 0
found = False
while l < maxlines:
while linenum < maxlines:
line = await session.stdout.readline()
if "Enter public key" in line:
found = True

View File

@ -1,9 +1,10 @@
"""Types for the various test properties."""
import uuid
from datetime import datetime
from dataclasses import dataclass, field
from ipaddress import IPv4Network, IPv6Network
from collections.abc import AsyncGenerator, Awaitable, Callable, AsyncIterator
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any, Protocol, TypedDict, AsyncContextManager
import asyncssh
@ -11,8 +12,6 @@ SshServerFixture = tuple[str, int]
SshServerFixtureFun = AsyncGenerator[tuple[asyncssh.SSHAcceptor, int], None]
@dataclass
class Client:
"""Mock client."""
@ -37,13 +36,21 @@ class ClientKey:
class AddClientFun(Protocol):
"""Add client function."""
def __call__(self, name: str, secret_names: list[str] | None = None, policies: list[str] | None = None) -> Awaitable[str]: ...
def __call__(
self,
name: str,
secret_names: list[str] | None = None,
policies: list[str] | None = None,
) -> Awaitable[str]: ...
class ProcessRunner(Protocol):
"""Process runner typing."""
def __call__(self, name: str, command: str, client: str | None = None) -> AsyncContextManager[asyncssh.SSHClientProcess[Any]]:
...
def __call__(
self, name: str, command: str, client: str | None = None
) -> AsyncContextManager[asyncssh.SSHClientProcess[Any]]: ...
class ClientRegistry(TypedDict):
"""Client registry typing."""