203 lines
6.2 KiB
Python
203 lines
6.2 KiB
Python
import asyncio
|
|
from pydantic import IPvAnyNetwork
|
|
import pytest
|
|
import uuid
|
|
import asyncssh
|
|
import tempfile
|
|
from contextlib import asynccontextmanager
|
|
import pytest_asyncio
|
|
from pytest import FixtureRequest
|
|
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
from ipaddress import IPv4Network, IPv6Network, ip_network
|
|
|
|
from sshecret_sshd.ssh_server import run_ssh_server
|
|
from sshecret_sshd.settings import ClientRegistrationSettings
|
|
|
|
from .types import (
|
|
Client,
|
|
ClientKey,
|
|
ClientRegistry,
|
|
SshServerFixtureFun,
|
|
SshServerFixture,
|
|
)
|
|
|
|
|
|
@pytest.fixture(scope="function")
def client_registry() -> ClientRegistry:
    """Provide an in-memory registry of test clients and their secrets.

    The returned mapping has three entries:

    * ``"clients"``: client name -> ``ClientKey`` holding a freshly generated
      RSA key pair.
    * ``"secrets"``: ``(client name, secret name)`` -> mocked secret value.
    * ``"add_client"``: coroutine function that registers a new client in the
      two dicts above.
    """
    clients: dict[str, ClientKey] = {}
    secrets: dict[tuple[str, str], str] = {}

    async def add_client(
        name: str,
        secret_names: list[str] | None = None,
        policies: list[str] | None = None,
    ) -> ClientKey:
        """Register a client with a new RSA key and optional mocked secrets.

        Args:
            name: Client name. An existing client with the same name is
                replaced (its previously added secrets are kept).
            secret_names: Secret names to expose for this client. Each one is
                stored with the value ``mocked-secret-<secret name>``.
            policies: Optional list of network strings, stored as-is on the
                ``ClientKey``.

        Returns:
            The ``ClientKey`` stored under ``name``. (The previous annotation
            claimed ``str``, which did not match the returned object.)
        """
        private_key = asyncssh.generate_private_key("ssh-rsa")
        public_key = private_key.export_public_key()
        # export_public_key() returns bytes with a trailing newline; store the
        # OpenSSH text form without it.
        clients[name] = ClientKey(
            name, private_key, public_key.decode().rstrip(), policies
        )
        secrets.update({(name, s): f"mocked-secret-{s}" for s in (secret_names or [])})
        return clients[name]

    return {"clients": clients, "secrets": secrets, "add_client": add_client}
|
|
|
|
|
|
@pytest_asyncio.fixture(scope="function")
async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
    """Provide a ``MagicMock`` backend driven by the ``client_registry`` data.

    ``get_client``, ``get_client_secret`` and ``create_client`` are
    ``AsyncMock`` objects whose side effects read from and write to the
    registry's ``"clients"`` and ``"secrets"`` dicts. Clients added later
    through the registry are therefore visible to the backend.
    ``audit()`` returns a mock whose ``write_async`` is a no-op coroutine.
    """
    backend = MagicMock()
    clients_data = client_registry["clients"]
    secrets_data = client_registry["secrets"]

    async def get_client(name: str) -> Client | None:
        """Build a ``Client`` model for ``name``, or return ``None`` if unknown.

        A new random ``id`` is generated on every call. A client registered
        without policies is allowed from any IPv4 or IPv6 address.
        """
        client_key = clients_data.get(name)
        if client_key is None:
            return None

        if client_key.policies:
            policies = [ip_network(network) for network in client_key.policies]
        else:
            # No explicit policies: allow every source address.
            policies = [IPv4Network("0.0.0.0/0"), IPv6Network("::/0")]

        return Client(
            id=uuid.uuid4(),
            name=name,
            description=f"Mock client {name}",
            public_key=client_key.public_key,
            secrets=[s for (c, s) in secrets_data if c == name],
            policies=policies,
        )

    async def get_client_secret(name: str, secret_name: str) -> str | None:
        """Return the mocked secret value, or ``None`` if it does not exist."""
        return secrets_data.get((name, secret_name))

    async def create_client(name: str, public_key: str) -> None:
        """Create client.

        This only works if you register a client called template first.
        Otherwise we can't test this...

        The new client reuses the ``template`` client's ``ClientKey``, so
        ``public_key`` is ignored. It also receives a copy of every secret
        owned by ``template``.

        Raises:
            RuntimeError: if no client called ``template`` is registered.
        """
        if "template" not in clients_data:
            raise RuntimeError(
                "Error, must have a client called template for this to work."
            )
        clients_data[name] = clients_data["template"]
        # Collect first, then insert. Adding keys to ``secrets_data`` while
        # iterating over it raises "dictionary changed size during iteration".
        template_secrets = {
            (name, secret_name): secret
            for (owner, secret_name), secret in secrets_data.items()
            if owner == "template"
        }
        secrets_data.update(template_secrets)

    async def write_audit(*args, **kwargs):
        """Write audit mock."""
        return None

    backend.get_client = AsyncMock(side_effect=get_client)
    backend.get_client_secret = AsyncMock(side_effect=get_client_secret)
    backend.create_client = AsyncMock(side_effect=create_client)

    audit = MagicMock()
    audit.write_async = AsyncMock(side_effect=write_audit)
    backend.audit = MagicMock(return_value=audit)

    return backend
|
|
|
|
|
|
@pytest_asyncio.fixture(scope="function")
async def ssh_server(
    request: FixtureRequest,
    mock_backend: MagicMock,
    unused_tcp_port: int,
) -> SshServerFixtureFun:
    """Start an SSH server on a free local port and yield ``(server, port)``.

    A throwaway ed25519 host key is written to a temporary file, which is
    passed to ``run_ssh_server``. Client registration is configured from
    test markers:

    * ``enable_registration``: turns registration on when present.
    * ``registration_sources(*networks)``: networks registration is allowed
      from. The default is ``0.0.0.0/0``.

    After the test, the server is closed and its shutdown awaited.
    """
    node = request.node
    port = unused_tcp_port
    host_key_pem = asyncssh.generate_private_key("ssh-ed25519").export_private_key()

    wants_registration = node.get_closest_marker("enable_registration") is not None
    sources_mark = node.get_closest_marker("registration_sources")
    allowed_from: list[IPvAnyNetwork] = (
        [ip_network(net) for net in sources_mark.args]
        if sources_mark
        else [IPv4Network("0.0.0.0/0")]
    )

    with tempfile.NamedTemporaryFile("w+", delete=True) as key_file:
        key_file.write(host_key_pem.decode())
        key_file.flush()

        registration = ClientRegistrationSettings(
            enabled=wants_registration,
            allow_from=allowed_from,
        )
        server = await run_ssh_server(
            backend=mock_backend,
            listen_address="localhost",
            port=port,
            keys=[key_file.name],
            registration=registration,
            enable_ping_command=True,
        )

        # Short pause before handing the server to the test.
        await asyncio.sleep(0.1)

        yield server, port

        server.close()
        await server.wait_closed()
|
|
|
|
|
|
@pytest.fixture(scope="function")
def ssh_command_runner(ssh_server: SshServerFixture, client_registry: ClientRegistry):
    """Return a coroutine function that runs a single command over SSH.

    ``await ssh_command_runner(name, command)`` connects to the test server
    on 127.0.0.1 as ``name`` and authenticates with that client's private
    key. It runs ``command`` and returns the result of ``conn.run``. The
    connection is always closed before returning.

    The return type is left unannotated (tricky typing!).
    """
    _server, port = ssh_server

    async def run_command_as(name: str, command: str):
        key = client_registry["clients"][name].private_key
        connection = await asyncssh.connect(
            "127.0.0.1",
            port=port,
            username=name,
            client_keys=[key],
            known_hosts=None,
        )
        try:
            return await connection.run(command)
        finally:
            connection.close()
            await connection.wait_closed()

    return run_command_as
|
|
|
|
|
|
@pytest.fixture(scope="function")
def ssh_session(ssh_server: SshServerFixture, client_registry: ClientRegistry):
    """Return a factory of async context managers for interactive sessions.

    ``async with ssh_session(name, command, client=None) as process`` logs in
    as ``name``. It authenticates with the private key of ``client``, or of
    ``name`` when ``client`` is omitted or empty. It then starts ``command``
    and yields the process. The connection is closed when the context exits.
    """
    _server, port = ssh_server

    @asynccontextmanager
    async def run_process_as(name: str, command: str, client: str | None = None):
        # Allow logging in under one username while presenting another
        # registered client's key.
        key_owner = client or name
        key = client_registry["clients"][key_owner].private_key
        connection = await asyncssh.connect(
            "127.0.0.1",
            port=port,
            username=name,
            client_keys=[key],
            known_hosts=None,
        )
        try:
            async with connection.create_process(command) as process:
                yield process
        finally:
            connection.close()
            await connection.wait_closed()

    return run_process_as
|