Update tests

This commit is contained in:
2025-05-12 07:47:38 +02:00
parent 458863de3d
commit 80e2c339e3
7 changed files with 375 additions and 29 deletions

View File

@ -1,12 +1,15 @@
import asyncio
from typing import Any
from pydantic import IPvAnyNetwork
import pytest
import uuid
import asyncssh
import tempfile
from contextlib import asynccontextmanager
import pytest_asyncio
from pytest import FixtureRequest
from unittest.mock import AsyncMock, MagicMock
from ipaddress import IPv4Network, IPv6Network
from ipaddress import IPv4Network, IPv6Network, ip_network
from sshecret_sshd.ssh_server import run_ssh_server
from sshecret_sshd.settings import ClientRegistrationSettings
@ -32,14 +35,16 @@ def client_registry() -> ClientRegistry:
) -> str:
private_key = asyncssh.generate_private_key("ssh-rsa")
public_key = private_key.export_public_key()
clients[name] = ClientKey(name, private_key, public_key.decode().rstrip())
clients[name] = ClientKey(
name, private_key, public_key.decode().rstrip(), policies
)
secrets.update({(name, s): f"mocked-secret-{s}" for s in (secret_names or [])})
return clients[name]
return {"clients": clients, "secrets": secrets, "add_client": add_client}
@pytest.fixture(scope="function")
@pytest_asyncio.fixture(scope="function")
async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
backend = MagicMock()
clients_data = client_registry["clients"]
@ -48,13 +53,16 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
async def get_client(name: str) -> Client | None:
client_key = clients_data.get(name)
if client_key:
policies = [IPv4Network("0.0.0.0/0"), IPv6Network("::/0")]
if client_key.policies:
policies = [ip_network(network) for network in client_key.policies]
response_model = Client(
id=uuid.uuid4(),
name=name,
description=f"Mock client {name}",
public_key=client_key.public_key,
secrets=[s for (c, s) in secrets_data if c == name],
policies=[IPv4Network("0.0.0.0/0"), IPv6Network("::/0")],
policies=policies,
)
return response_model
return None
@ -95,20 +103,31 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
return backend
@pytest.fixture(scope="function")
@pytest_asyncio.fixture(scope="function")
async def ssh_server(
mock_backend: MagicMock, unused_tcp_port: int
request: FixtureRequest,
mock_backend: MagicMock,
unused_tcp_port: int,
) -> SshServerFixtureFun:
port = unused_tcp_port
private_key = asyncssh.generate_private_key("ssh-ed25519")
key_str = private_key.export_private_key()
registration_mark = request.node.get_closest_marker("enable_registration")
registration_enabled = registration_mark is not None
registration_source_mark = request.node.get_closest_marker("registration_sources")
allowed_from: list[IPvAnyNetwork] = []
if registration_source_mark:
for network in registration_source_mark.args:
allowed_from.append(ip_network(network))
else:
allowed_from = [IPv4Network("0.0.0.0/0")]
with tempfile.NamedTemporaryFile("w+", delete=True) as key_file:
key_file.write(key_str.decode())
key_file.flush()
registration_settings = ClientRegistrationSettings(
enabled=True, allow_from=[IPv4Network("0.0.0.0/0")]
enabled=registration_enabled,
allow_from=allowed_from,
)
server = await run_ssh_server(
backend=mock_backend,

View File

@ -0,0 +1,155 @@
"""Test various exceptions and error conditions."""
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
import asyncssh
import pytest
from .types import ClientRegistry, CommandRunner, ProcessRunner, SshServerFixture
class BaseSshTests:
"""Base test class."""
@asynccontextmanager
async def unregistered_client(self, username: str, port: int):
"""Generate SSH session as an uregistered client."""
private_key = asyncssh.generate_private_key("ssh-rsa")
conn = await asyncssh.connect(
"127.0.0.1",
port=port,
username=username,
client_keys=[private_key],
known_hosts=None,
)
try:
yield conn
finally:
conn.close()
await conn.wait_closed()
@asynccontextmanager
async def ssh_connection(
self, username: str, port: int, private_key: asyncssh.SSHKey
):
"""Generate SSH session as a client with an ed25519 key."""
# private_key = asyncssh.generate_private_key("ssh-ed25519")
conn = await asyncssh.connect(
"127.0.0.1",
port=port,
username=username,
client_keys=[private_key],
known_hosts=None,
)
try:
yield conn
finally:
conn.close()
await conn.wait_closed()
class TestRegistrationErrors(BaseSshTests):
"""Test class for errors related to registartion."""
@pytest.mark.enable_registration(True)
@pytest.mark.registration_sources("192.0.2.0/24")
@pytest.mark.asyncio
async def test_register_client_invalid_source(
self, ssh_server: SshServerFixture
) -> None:
"""Test client registration from a network that's not permitted."""
_, port = ssh_server
with pytest.raises(asyncssh.misc.PermissionDenied):
async with self.unregistered_client("stranger", port) as conn:
async with conn.create_process("register") as process:
stdout, stderr = process.collect_output()
print(f"{stdout=!r}\n{stderr=!r}")
if isinstance(stdout, str):
assert "Enter public key" not in stdout
result = await process.wait()
assert result.exit_status == 1
@pytest.mark.enable_registration(True)
@pytest.mark.registration_sources("127.0.0.1")
@pytest.mark.asyncio
async def test_invalid_key_type(self, ssh_server: SshServerFixture) -> None:
"""Test registration with an unsupported key."""
_, port = ssh_server
private_key = asyncssh.generate_private_key("ssh-ed25519")
public_key = private_key.export_public_key().decode().rstrip() + "\n"
async with self.ssh_connection("stranger", port, private_key) as conn:
async with conn.create_process("register") as process:
output = await process.stdout.readline()
assert "Enter public key" in output
stdout, stderr = await process.communicate(public_key)
print(f"{stdout=!r}, {stderr=!r}")
assert stderr == "Error: Invalid key type: Only RSA keys are supported."
result = await process.wait()
assert result.exit_status == 1
@pytest.mark.enable_registration(True)
@pytest.mark.registration_sources("127.0.0.1")
@pytest.mark.asyncio
async def test_invalid_key(self, ssh_server: SshServerFixture) -> None:
"""Test registration with a bogus string as key.."""
_, port = ssh_server
private_key = asyncssh.generate_private_key("ssh-ed25519")
public_key = f"ssh-test {'A' * 544}\n"
async with self.ssh_connection("stranger", port, private_key) as conn:
async with conn.create_process("register") as process:
output = await process.stdout.readline()
assert "Enter public key" in output
stdout, stderr = await process.communicate(public_key)
print(f"{stdout=!r}, {stderr=!r}")
assert stderr == "Error: Invalid key type: Only RSA keys are supported."
result = await process.wait()
assert result.exit_status == 1
class TestCommandErrors(BaseSshTests):
"""Tests various errors around commands."""
@pytest.mark.asyncio
async def test_invalid_command(
self, ssh_command_runner: CommandRunner, client_registry: ClientRegistry
) -> None:
"""Test sending an invalid command."""
await client_registry["add_client"]("test")
result = await ssh_command_runner("test", "cat /etc/passwd")
assert result.exit_status == 1
stderr = result.stderr or ""
assert stderr == "Error: Unsupported command."
@pytest.mark.asyncio
async def test_no_command(
self, ssh_server: SshServerFixture, client_registry: ClientRegistry
) -> None:
"""Test sending no command."""
await client_registry["add_client"]("test")
_, port = ssh_server
client_key = client_registry["clients"]["test"]
async with self.ssh_connection("test", port, client_key.private_key) as conn:
async with conn.create_process() as process:
stdout, stderr = await process.communicate()
print(f"{stdout=!r}, {stderr=!r}")
assert stderr == "Error: No command was received from the client."
result = await process.wait()
assert result.exit_status == 1
@pytest.mark.asyncio
async def test_deny_client_connection(
self, ssh_command_runner: CommandRunner, client_registry: ClientRegistry
) -> None:
"""Test client that is not permitted to connect."""
await client_registry["add_client"](
"test-client",
["mysecret"],
["192.0.2.0/24"],
)
with pytest.raises(asyncssh.misc.PermissionDenied):
await ssh_command_runner("test-client", "get_secret mysecret")

View File

@ -5,6 +5,7 @@ import pytest
from .types import ClientRegistry, CommandRunner, ProcessRunner
@pytest.mark.enable_registration(True)
@pytest.mark.asyncio
async def test_register_client(
ssh_session: ProcessRunner,

View File

@ -31,6 +31,8 @@ class ClientKey:
name: str
private_key: asyncssh.SSHKey
public_key: str
policies: list[str] | None = None
class AddClientFun(Protocol):