Reformat and lint

This commit is contained in:
2025-05-10 08:29:58 +02:00
parent 0a427b6a91
commit d866553ac1
9 changed files with 120 additions and 44 deletions

View File

@ -10,15 +10,26 @@ from ipaddress import IPv4Network, IPv6Network
from sshecret_sshd.ssh_server import run_ssh_server
from sshecret_sshd.settings import ClientRegistrationSettings
from .types import Client, ClientKey, ClientRegistry, SshServerFixtureFun, SshServerFixture
from .types import (
Client,
ClientKey,
ClientRegistry,
SshServerFixtureFun,
SshServerFixture,
)
@pytest.fixture(scope="function")
def client_registry() -> ClientRegistry:
clients = {}
secrets = {}
async def add_client(name: str, secret_names: list[str]|None=None, policies: list[str]|None=None) -> str:
private_key = asyncssh.generate_private_key('ssh-rsa')
async def add_client(
name: str,
secret_names: list[str] | None = None,
policies: list[str] | None = None,
) -> str:
private_key = asyncssh.generate_private_key("ssh-rsa")
public_key = private_key.export_public_key()
clients[name] = ClientKey(name, private_key, public_key.decode().rstrip())
secrets.update({(name, s): f"mocked-secret-{s}" for s in (secret_names or [])})
@ -26,11 +37,12 @@ def client_registry() -> ClientRegistry:
return {"clients": clients, "secrets": secrets, "add_client": add_client}
@pytest.fixture(scope="function")
async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
backend = MagicMock()
clients_data = client_registry['clients']
secrets_data = client_registry['secrets']
clients_data = client_registry["clients"]
secrets_data = client_registry["secrets"]
async def get_client(name: str) -> Client | None:
client_key = clients_data.get(name)
@ -41,7 +53,7 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
description=f"Mock client {name}",
public_key=client_key.public_key,
secrets=[s for (c, s) in secrets_data if c == name],
policies=[IPv4Network('0.0.0.0/0'), IPv6Network('::/0')],
policies=[IPv4Network("0.0.0.0/0"), IPv6Network("::/0")],
)
return response_model
return None
@ -57,7 +69,9 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
Otherwise we can't test this...
"""
if "template" not in clients_data:
raise RuntimeError("Error, must have a client called template for this to work.")
raise RuntimeError(
"Error, must have a client called template for this to work."
)
clients_data[name] = clients_data["template"]
for secret_key, secret in secrets_data.items():
s_client, secret_name = secret_key
@ -76,18 +90,22 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
return backend
@pytest.fixture(scope="function")
async def ssh_server(mock_backend: MagicMock, unused_tcp_port: int) -> SshServerFixtureFun:
async def ssh_server(
mock_backend: MagicMock, unused_tcp_port: int
) -> SshServerFixtureFun:
port = unused_tcp_port
private_key = asyncssh.generate_private_key('ssh-ed25519')
private_key = asyncssh.generate_private_key("ssh-ed25519")
key_str = private_key.export_private_key()
with tempfile.NamedTemporaryFile('w+', delete=True) as key_file:
with tempfile.NamedTemporaryFile("w+", delete=True) as key_file:
key_file.write(key_str.decode())
key_file.flush()
registration_settings = ClientRegistrationSettings(enabled=True, allow_from=[IPv4Network("0.0.0.0/0")])
registration_settings = ClientRegistrationSettings(
enabled=True, allow_from=[IPv4Network("0.0.0.0/0")]
)
server = await run_ssh_server(
backend=mock_backend,
listen_address="localhost",
@ -104,6 +122,7 @@ async def ssh_server(mock_backend: MagicMock, unused_tcp_port: int) -> SshServer
server.close()
await server.wait_closed()
@pytest.fixture(scope="function")
def ssh_command_runner(ssh_server: SshServerFixture, client_registry: ClientRegistry):
"""Run a single command.
@ -113,7 +132,7 @@ def ssh_command_runner(ssh_server: SshServerFixture, client_registry: ClientRegi
_, port = ssh_server
async def run_command_as(name: str, command: str):
client_key = client_registry['clients'][name]
client_key = client_registry["clients"][name]
conn = await asyncssh.connect(
"127.0.0.1",
port=port,

View File

@ -4,11 +4,14 @@ import pytest
from .types import ClientRegistry, CommandRunner
@pytest.mark.asyncio
async def test_get_secret(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None:
async def test_get_secret(
ssh_command_runner: CommandRunner, client_registry: ClientRegistry
) -> None:
"""Test that we can get a secret."""
await client_registry['add_client']("test-client", ["mysecret"])
await client_registry["add_client"]("test-client", ["mysecret"])
result = await ssh_command_runner("test-client", "get_secret mysecret")
@ -18,10 +21,12 @@ async def test_get_secret(ssh_command_runner: CommandRunner, client_registry: Cl
@pytest.mark.asyncio
async def test_invalid_secret_name(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None:
async def test_invalid_secret_name(
ssh_command_runner: CommandRunner, client_registry: ClientRegistry
) -> None:
"""Test getting an invalid secret name."""
await client_registry['add_client']("test-client")
await client_registry["add_client"]("test-client")
result = await ssh_command_runner("test-client", "get_secret mysecret")
assert result.exit_status == 1

View File

@ -2,10 +2,13 @@ import pytest
from .types import ClientRegistry, CommandRunner
@pytest.mark.asyncio
async def test_ping_command(ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None:
async def test_ping_command(
ssh_command_runner: CommandRunner, client_registry: ClientRegistry
) -> None:
# Register a test client with default policies and no secrets
await client_registry['add_client']('test-pinger')
await client_registry["add_client"]("test-pinger")
result = await ssh_command_runner("test-pinger", "ping")

View File

@ -1,10 +1,16 @@
"""Test registration."""
import pytest
from .types import ClientRegistry, CommandRunner, ProcessRunner
@pytest.mark.asyncio
async def test_register_client(ssh_session: ProcessRunner, ssh_command_runner: CommandRunner, client_registry: ClientRegistry) -> None:
async def test_register_client(
ssh_session: ProcessRunner,
ssh_command_runner: CommandRunner,
client_registry: ClientRegistry,
) -> None:
"""Test client registration."""
await client_registry["add_client"]("template", ["testsecret"])
@ -12,9 +18,9 @@ async def test_register_client(ssh_session: ProcessRunner, ssh_command_runner: C
async with ssh_session("newclient", "register", "template") as session:
maxlines = 10
l = 0
linenum = 0
found = False
while l < maxlines:
while linenum < maxlines:
line = await session.stdout.readline()
if "Enter public key" in line:
found = True

View File

@ -1,9 +1,10 @@
"""Types for the various test properties."""
import uuid
from datetime import datetime
from dataclasses import dataclass, field
from ipaddress import IPv4Network, IPv6Network
from collections.abc import AsyncGenerator, Awaitable, Callable, AsyncIterator
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any, Protocol, TypedDict, AsyncContextManager
import asyncssh
@ -11,8 +12,6 @@ SshServerFixture = tuple[str, int]
SshServerFixtureFun = AsyncGenerator[tuple[asyncssh.SSHAcceptor, int], None]
@dataclass
class Client:
"""Mock client."""
@ -37,13 +36,21 @@ class ClientKey:
class AddClientFun(Protocol):
"""Add client function."""
def __call__(self, name: str, secret_names: list[str] | None = None, policies: list[str] | None = None) -> Awaitable[str]: ...
def __call__(
self,
name: str,
secret_names: list[str] | None = None,
policies: list[str] | None = None,
) -> Awaitable[str]: ...
class ProcessRunner(Protocol):
"""Process runner typing."""
def __call__(self, name: str, command: str, client: str | None = None) -> AsyncContextManager[asyncssh.SSHClientProcess[Any]]:
...
def __call__(
self, name: str, command: str, client: str | None = None
) -> AsyncContextManager[asyncssh.SSHClientProcess[Any]]: ...
class ClientRegistry(TypedDict):
"""Client registry typing."""