Update tests

This commit is contained in:
2025-05-12 07:47:38 +02:00
parent 458863de3d
commit 80e2c339e3
7 changed files with 375 additions and 29 deletions

View File

@ -6,34 +6,25 @@ import pytest
import httpx import httpx
from sshecret.backend import Client
class TestAdminAPI: from sshecret.crypto import generate_private_key, generate_public_key_string
"""Tests of the Admin REST API."""
@pytest.mark.asyncio from .types import AdminServer
async def test_health_check(
self, admin_server: tuple[str, tuple[str, str]]
) -> None:
"""Test admin login."""
async with self.http_client(admin_server, False) as client:
resp = await client.get("/health")
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_admin_login(self, admin_server: tuple[str, tuple[str, str]]) -> None:
"""Test admin login."""
async with self.http_client(admin_server, False) as client: def make_test_key() -> str:
resp = await client.get("api/v1/clients/") """Generate a test key."""
assert resp.status_code == 401 private_key = generate_private_key()
return generate_public_key_string(private_key.public_key())
async with self.http_client(admin_server, True) as client:
resp = await client.get("api/v1/clients/") class BaseAdminTests:
assert resp.status_code == 200 """Base admin test class."""
@asynccontextmanager @asynccontextmanager
async def http_client( async def http_client(
self, admin_server: tuple[str, tuple[str, str]], authenticate: bool = True self, admin_server: AdminServer, authenticate: bool = True
) -> AsyncIterator[httpx.AsyncClient]: ) -> AsyncIterator[httpx.AsyncClient]:
"""Run a client towards the admin rest api.""" """Run a client towards the admin rest api."""
admin_url, credentials = admin_server admin_url, credentials = admin_server
@ -53,3 +44,179 @@ class TestAdminAPI:
async with httpx.AsyncClient(base_url=admin_url, headers=headers) as client: async with httpx.AsyncClient(base_url=admin_url, headers=headers) as client:
yield client yield client
async def create_client(
self,
admin_server: AdminServer,
name: str,
public_key: str | None = None,
) -> Client:
"""Create a client."""
if not public_key:
public_key = make_test_key()
new_client = {
"name": name,
"public_key": public_key,
"sources": ["192.0.2.0/24"],
}
async with self.http_client(admin_server, True) as http_client:
response = await http_client.post("api/v1/clients/", json=new_client)
assert response.status_code == 200
data = response.json()
client = Client.model_validate(data)
return client
class TestAdminAPI(BaseAdminTests):
"""Tests of the Admin REST API."""
@pytest.mark.asyncio
async def test_health_check(
self, admin_server: tuple[str, tuple[str, str]]
) -> None:
"""Test admin login."""
async with self.http_client(admin_server, False) as client:
resp = await client.get("/health")
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_admin_login(self, admin_server: AdminServer) -> None:
"""Test admin login."""
async with self.http_client(admin_server, False) as client:
resp = await client.get("api/v1/clients/")
assert resp.status_code == 401
async with self.http_client(admin_server, True) as client:
resp = await client.get("api/v1/clients/")
assert resp.status_code == 200
class TestAdminApiClients(BaseAdminTests):
"""Test client routes."""
@pytest.mark.asyncio
async def test_create_client(self, admin_server: AdminServer) -> None:
"""Test create_client."""
client = await self.create_client(admin_server, "testclient")
assert client.id is not None
assert client.name == "testclient"
@pytest.mark.asyncio
async def test_get_clients(self, admin_server: AdminServer) -> None:
"""Test get_clients."""
client_names = ["test-db", "test-app", "test-www"]
for name in client_names:
await self.create_client(admin_server, name)
async with self.http_client(admin_server) as http_client:
resp = await http_client.get("api/v1/clients/")
assert resp.status_code == 200
data = resp.json()
assert isinstance(data, list)
assert len(data) == 3
for entry in data:
assert isinstance(entry, dict)
client_name = entry.get("name")
assert client_name in client_names
@pytest.mark.asyncio
async def test_delete_client(self, admin_server: AdminServer) -> None:
"""Test delete_client."""
await self.create_client(admin_server, name="testclient")
async with self.http_client(admin_server) as http_client:
resp = await http_client.get("api/v1/clients/")
assert resp.status_code == 200
data = resp.json()
assert isinstance(data, list)
assert len(data) == 1
assert data[0]["name"] == "testclient"
resp = await http_client.delete("/api/v1/clients/testclient")
assert resp.status_code == 200
resp = await http_client.get("api/v1/clients/")
assert resp.status_code == 200
data = resp.json()
assert isinstance(data, list)
assert len(data) == 0
class TestAdminApiSecrets(BaseAdminTests):
"""Test secret management."""
@pytest.mark.asyncio
async def test_add_secret(self, admin_server: AdminServer) -> None:
"""Test add_secret."""
await self.create_client(admin_server, name="testclient")
async with self.http_client(admin_server) as http_client:
data = {
"name": "testsecret",
"clients": ["testclient"],
"value": "secretstring",
}
resp = await http_client.post("api/v1/secrets/", json=data)
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_get_secret(self, admin_server: AdminServer) -> None:
"""Test get_secret."""
await self.test_add_secret(admin_server)
async with self.http_client(admin_server) as http_client:
resp = await http_client.get("api/v1/secrets/testsecret")
assert resp.status_code == 200
data = resp.json()
assert isinstance(data, dict)
assert data["name"] == "testsecret"
assert data["secret"] == "secretstring"
assert "testclient" in data["clients"]
@pytest.mark.asyncio
async def test_add_secret_auto(self, admin_server: AdminServer) -> None:
"""Test adding a secret with an auto-generated value."""
await self.create_client(admin_server, name="testclient")
async with self.http_client(admin_server) as http_client:
data = {
"name": "testsecret",
"clients": ["testclient"],
"value": {"auto_generate": True, "length": 17},
}
resp = await http_client.post("api/v1/secrets/", json=data)
assert resp.status_code == 200
resp = await http_client.get("api/v1/secrets/testsecret")
assert resp.status_code == 200
data = resp.json()
assert isinstance(data, dict)
assert data["name"] == "testsecret"
assert len(data["secret"]) == 17
assert "testclient" in data["clients"]
@pytest.mark.asyncio
async def test_update_secret(self, admin_server: AdminServer) -> None:
"""Test updating secrets."""
await self.test_add_secret_auto(admin_server)
async with self.http_client(admin_server) as http_client:
resp = await http_client.put(
"api/v1/secrets/testsecret",
json={"value": "secret"},
)
assert resp.status_code == 200
resp = await http_client.get("api/v1/secrets/testsecret")
assert resp.status_code == 200
data = resp.json()
assert data["secret"] == "secret"
resp = await http_client.put(
"api/v1/secrets/testsecret",
json={"value": {"auto_generate": True, "length": 16}},
)
assert resp.status_code == 200
resp = await http_client.get("api/v1/secrets/testsecret")
assert resp.status_code == 200
data = resp.json()
assert len(data["secret"]) == 16

View File

@ -34,6 +34,7 @@ async def test_create_client(backend_api: SshecretBackend) -> None:
assert clients[0].public_key == test_client.public_key assert clients[0].public_key == test_client.public_key
@pytest.mark.asyncio
async def test_create_secret(backend_api: SshecretBackend) -> None: async def test_create_secret(backend_api: SshecretBackend) -> None:
"""Test creating secrets.""" """Test creating secrets."""
test_client = create_test_client("test") test_client = create_test_client("test")

View File

@ -10,6 +10,7 @@ from .clients import ClientData
PortFactory = Callable[[], int] PortFactory = Callable[[], int]
AdminServer = tuple[str, tuple[str, str]]
@dataclass @dataclass
class TestPorts: class TestPorts:

View File

@ -1,12 +1,15 @@
import asyncio import asyncio
from typing import Any from pydantic import IPvAnyNetwork
import pytest import pytest
import uuid import uuid
import asyncssh import asyncssh
import tempfile import tempfile
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import pytest_asyncio
from pytest import FixtureRequest
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
from ipaddress import IPv4Network, IPv6Network from ipaddress import IPv4Network, IPv6Network, ip_network
from sshecret_sshd.ssh_server import run_ssh_server from sshecret_sshd.ssh_server import run_ssh_server
from sshecret_sshd.settings import ClientRegistrationSettings from sshecret_sshd.settings import ClientRegistrationSettings
@ -32,14 +35,16 @@ def client_registry() -> ClientRegistry:
) -> str: ) -> str:
private_key = asyncssh.generate_private_key("ssh-rsa") private_key = asyncssh.generate_private_key("ssh-rsa")
public_key = private_key.export_public_key() public_key = private_key.export_public_key()
clients[name] = ClientKey(name, private_key, public_key.decode().rstrip()) clients[name] = ClientKey(
name, private_key, public_key.decode().rstrip(), policies
)
secrets.update({(name, s): f"mocked-secret-{s}" for s in (secret_names or [])}) secrets.update({(name, s): f"mocked-secret-{s}" for s in (secret_names or [])})
return clients[name] return clients[name]
return {"clients": clients, "secrets": secrets, "add_client": add_client} return {"clients": clients, "secrets": secrets, "add_client": add_client}
@pytest.fixture(scope="function") @pytest_asyncio.fixture(scope="function")
async def mock_backend(client_registry: ClientRegistry) -> MagicMock: async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
backend = MagicMock() backend = MagicMock()
clients_data = client_registry["clients"] clients_data = client_registry["clients"]
@ -48,13 +53,16 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
async def get_client(name: str) -> Client | None: async def get_client(name: str) -> Client | None:
client_key = clients_data.get(name) client_key = clients_data.get(name)
if client_key: if client_key:
policies = [IPv4Network("0.0.0.0/0"), IPv6Network("::/0")]
if client_key.policies:
policies = [ip_network(network) for network in client_key.policies]
response_model = Client( response_model = Client(
id=uuid.uuid4(), id=uuid.uuid4(),
name=name, name=name,
description=f"Mock client {name}", description=f"Mock client {name}",
public_key=client_key.public_key, public_key=client_key.public_key,
secrets=[s for (c, s) in secrets_data if c == name], secrets=[s for (c, s) in secrets_data if c == name],
policies=[IPv4Network("0.0.0.0/0"), IPv6Network("::/0")], policies=policies,
) )
return response_model return response_model
return None return None
@ -95,20 +103,31 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
return backend return backend
@pytest.fixture(scope="function") @pytest_asyncio.fixture(scope="function")
async def ssh_server( async def ssh_server(
mock_backend: MagicMock, unused_tcp_port: int request: FixtureRequest,
mock_backend: MagicMock,
unused_tcp_port: int,
) -> SshServerFixtureFun: ) -> SshServerFixtureFun:
port = unused_tcp_port port = unused_tcp_port
private_key = asyncssh.generate_private_key("ssh-ed25519") private_key = asyncssh.generate_private_key("ssh-ed25519")
key_str = private_key.export_private_key() key_str = private_key.export_private_key()
registration_mark = request.node.get_closest_marker("enable_registration")
registration_enabled = registration_mark is not None
registration_source_mark = request.node.get_closest_marker("registration_sources")
allowed_from: list[IPvAnyNetwork] = []
if registration_source_mark:
for network in registration_source_mark.args:
allowed_from.append(ip_network(network))
else:
allowed_from = [IPv4Network("0.0.0.0/0")]
with tempfile.NamedTemporaryFile("w+", delete=True) as key_file: with tempfile.NamedTemporaryFile("w+", delete=True) as key_file:
key_file.write(key_str.decode()) key_file.write(key_str.decode())
key_file.flush() key_file.flush()
registration_settings = ClientRegistrationSettings( registration_settings = ClientRegistrationSettings(
enabled=True, allow_from=[IPv4Network("0.0.0.0/0")] enabled=registration_enabled,
allow_from=allowed_from,
) )
server = await run_ssh_server( server = await run_ssh_server(
backend=mock_backend, backend=mock_backend,

View File

@ -0,0 +1,155 @@
"""Test various exceptions and error conditions."""
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
import asyncssh
import pytest
from .types import ClientRegistry, CommandRunner, ProcessRunner, SshServerFixture
class BaseSshTests:
"""Base test class."""
@asynccontextmanager
async def unregistered_client(self, username: str, port: int):
"""Generate SSH session as an uregistered client."""
private_key = asyncssh.generate_private_key("ssh-rsa")
conn = await asyncssh.connect(
"127.0.0.1",
port=port,
username=username,
client_keys=[private_key],
known_hosts=None,
)
try:
yield conn
finally:
conn.close()
await conn.wait_closed()
@asynccontextmanager
async def ssh_connection(
self, username: str, port: int, private_key: asyncssh.SSHKey
):
"""Generate SSH session as a client with an ed25519 key."""
# private_key = asyncssh.generate_private_key("ssh-ed25519")
conn = await asyncssh.connect(
"127.0.0.1",
port=port,
username=username,
client_keys=[private_key],
known_hosts=None,
)
try:
yield conn
finally:
conn.close()
await conn.wait_closed()
class TestRegistrationErrors(BaseSshTests):
"""Test class for errors related to registartion."""
@pytest.mark.enable_registration(True)
@pytest.mark.registration_sources("192.0.2.0/24")
@pytest.mark.asyncio
async def test_register_client_invalid_source(
self, ssh_server: SshServerFixture
) -> None:
"""Test client registration from a network that's not permitted."""
_, port = ssh_server
with pytest.raises(asyncssh.misc.PermissionDenied):
async with self.unregistered_client("stranger", port) as conn:
async with conn.create_process("register") as process:
stdout, stderr = process.collect_output()
print(f"{stdout=!r}\n{stderr=!r}")
if isinstance(stdout, str):
assert "Enter public key" not in stdout
result = await process.wait()
assert result.exit_status == 1
@pytest.mark.enable_registration(True)
@pytest.mark.registration_sources("127.0.0.1")
@pytest.mark.asyncio
async def test_invalid_key_type(self, ssh_server: SshServerFixture) -> None:
"""Test registration with an unsupported key."""
_, port = ssh_server
private_key = asyncssh.generate_private_key("ssh-ed25519")
public_key = private_key.export_public_key().decode().rstrip() + "\n"
async with self.ssh_connection("stranger", port, private_key) as conn:
async with conn.create_process("register") as process:
output = await process.stdout.readline()
assert "Enter public key" in output
stdout, stderr = await process.communicate(public_key)
print(f"{stdout=!r}, {stderr=!r}")
assert stderr == "Error: Invalid key type: Only RSA keys are supported."
result = await process.wait()
assert result.exit_status == 1
@pytest.mark.enable_registration(True)
@pytest.mark.registration_sources("127.0.0.1")
@pytest.mark.asyncio
async def test_invalid_key(self, ssh_server: SshServerFixture) -> None:
"""Test registration with a bogus string as key.."""
_, port = ssh_server
private_key = asyncssh.generate_private_key("ssh-ed25519")
public_key = f"ssh-test {'A' * 544}\n"
async with self.ssh_connection("stranger", port, private_key) as conn:
async with conn.create_process("register") as process:
output = await process.stdout.readline()
assert "Enter public key" in output
stdout, stderr = await process.communicate(public_key)
print(f"{stdout=!r}, {stderr=!r}")
assert stderr == "Error: Invalid key type: Only RSA keys are supported."
result = await process.wait()
assert result.exit_status == 1
class TestCommandErrors(BaseSshTests):
"""Tests various errors around commands."""
@pytest.mark.asyncio
async def test_invalid_command(
self, ssh_command_runner: CommandRunner, client_registry: ClientRegistry
) -> None:
"""Test sending an invalid command."""
await client_registry["add_client"]("test")
result = await ssh_command_runner("test", "cat /etc/passwd")
assert result.exit_status == 1
stderr = result.stderr or ""
assert stderr == "Error: Unsupported command."
@pytest.mark.asyncio
async def test_no_command(
self, ssh_server: SshServerFixture, client_registry: ClientRegistry
) -> None:
"""Test sending no command."""
await client_registry["add_client"]("test")
_, port = ssh_server
client_key = client_registry["clients"]["test"]
async with self.ssh_connection("test", port, client_key.private_key) as conn:
async with conn.create_process() as process:
stdout, stderr = await process.communicate()
print(f"{stdout=!r}, {stderr=!r}")
assert stderr == "Error: No command was received from the client."
result = await process.wait()
assert result.exit_status == 1
@pytest.mark.asyncio
async def test_deny_client_connection(
self, ssh_command_runner: CommandRunner, client_registry: ClientRegistry
) -> None:
"""Test client that is not permitted to connect."""
await client_registry["add_client"](
"test-client",
["mysecret"],
["192.0.2.0/24"],
)
with pytest.raises(asyncssh.misc.PermissionDenied):
await ssh_command_runner("test-client", "get_secret mysecret")

View File

@ -5,6 +5,7 @@ import pytest
from .types import ClientRegistry, CommandRunner, ProcessRunner from .types import ClientRegistry, CommandRunner, ProcessRunner
@pytest.mark.enable_registration(True)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_register_client( async def test_register_client(
ssh_session: ProcessRunner, ssh_session: ProcessRunner,

View File

@ -31,6 +31,8 @@ class ClientKey:
name: str name: str
private_key: asyncssh.SSHKey private_key: asyncssh.SSHKey
public_key: str public_key: str
policies: list[str] | None = None
class AddClientFun(Protocol): class AddClientFun(Protocol):