Update tests
This commit is contained in:
@ -1,12 +1,15 @@
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from pydantic import IPvAnyNetwork
|
||||
import pytest
|
||||
import uuid
|
||||
import asyncssh
|
||||
import tempfile
|
||||
from contextlib import asynccontextmanager
|
||||
import pytest_asyncio
|
||||
from pytest import FixtureRequest
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from ipaddress import IPv4Network, IPv6Network
|
||||
from ipaddress import IPv4Network, IPv6Network, ip_network
|
||||
|
||||
from sshecret_sshd.ssh_server import run_ssh_server
|
||||
from sshecret_sshd.settings import ClientRegistrationSettings
|
||||
@ -32,14 +35,16 @@ def client_registry() -> ClientRegistry:
|
||||
) -> str:
|
||||
private_key = asyncssh.generate_private_key("ssh-rsa")
|
||||
public_key = private_key.export_public_key()
|
||||
clients[name] = ClientKey(name, private_key, public_key.decode().rstrip())
|
||||
clients[name] = ClientKey(
|
||||
name, private_key, public_key.decode().rstrip(), policies
|
||||
)
|
||||
secrets.update({(name, s): f"mocked-secret-{s}" for s in (secret_names or [])})
|
||||
return clients[name]
|
||||
|
||||
return {"clients": clients, "secrets": secrets, "add_client": add_client}
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
|
||||
backend = MagicMock()
|
||||
clients_data = client_registry["clients"]
|
||||
@ -48,13 +53,16 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
|
||||
async def get_client(name: str) -> Client | None:
|
||||
client_key = clients_data.get(name)
|
||||
if client_key:
|
||||
policies = [IPv4Network("0.0.0.0/0"), IPv6Network("::/0")]
|
||||
if client_key.policies:
|
||||
policies = [ip_network(network) for network in client_key.policies]
|
||||
response_model = Client(
|
||||
id=uuid.uuid4(),
|
||||
name=name,
|
||||
description=f"Mock client {name}",
|
||||
public_key=client_key.public_key,
|
||||
secrets=[s for (c, s) in secrets_data if c == name],
|
||||
policies=[IPv4Network("0.0.0.0/0"), IPv6Network("::/0")],
|
||||
policies=policies,
|
||||
)
|
||||
return response_model
|
||||
return None
|
||||
@ -95,20 +103,31 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
|
||||
return backend
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def ssh_server(
|
||||
mock_backend: MagicMock, unused_tcp_port: int
|
||||
request: FixtureRequest,
|
||||
mock_backend: MagicMock,
|
||||
unused_tcp_port: int,
|
||||
) -> SshServerFixtureFun:
|
||||
port = unused_tcp_port
|
||||
|
||||
private_key = asyncssh.generate_private_key("ssh-ed25519")
|
||||
key_str = private_key.export_private_key()
|
||||
registration_mark = request.node.get_closest_marker("enable_registration")
|
||||
registration_enabled = registration_mark is not None
|
||||
registration_source_mark = request.node.get_closest_marker("registration_sources")
|
||||
allowed_from: list[IPvAnyNetwork] = []
|
||||
if registration_source_mark:
|
||||
for network in registration_source_mark.args:
|
||||
allowed_from.append(ip_network(network))
|
||||
else:
|
||||
allowed_from = [IPv4Network("0.0.0.0/0")]
|
||||
with tempfile.NamedTemporaryFile("w+", delete=True) as key_file:
|
||||
key_file.write(key_str.decode())
|
||||
key_file.flush()
|
||||
|
||||
registration_settings = ClientRegistrationSettings(
|
||||
enabled=True, allow_from=[IPv4Network("0.0.0.0/0")]
|
||||
enabled=registration_enabled,
|
||||
allow_from=allowed_from,
|
||||
)
|
||||
server = await run_ssh_server(
|
||||
backend=mock_backend,
|
||||
|
||||
Reference in New Issue
Block a user