diff --git a/tests/integration/test_admin_api.py b/tests/integration/test_admin_api.py index 5227cfd..d7bc005 100644 --- a/tests/integration/test_admin_api.py +++ b/tests/integration/test_admin_api.py @@ -6,34 +6,25 @@ import pytest import httpx +from sshecret.backend import Client -class TestAdminAPI: - """Tests of the Admin REST API.""" +from sshecret.crypto import generate_private_key, generate_public_key_string - @pytest.mark.asyncio - async def test_health_check( - self, admin_server: tuple[str, tuple[str, str]] - ) -> None: - """Test admin login.""" - async with self.http_client(admin_server, False) as client: - resp = await client.get("/health") - assert resp.status_code == 200 +from .types import AdminServer - @pytest.mark.asyncio - async def test_admin_login(self, admin_server: tuple[str, tuple[str, str]]) -> None: - """Test admin login.""" - async with self.http_client(admin_server, False) as client: - resp = await client.get("api/v1/clients/") - assert resp.status_code == 401 +def make_test_key() -> str: + """Generate a test key.""" + private_key = generate_private_key() + return generate_public_key_string(private_key.public_key()) - async with self.http_client(admin_server, True) as client: - resp = await client.get("api/v1/clients/") - assert resp.status_code == 200 + +class BaseAdminTests: + """Base admin test class.""" @asynccontextmanager async def http_client( - self, admin_server: tuple[str, tuple[str, str]], authenticate: bool = True + self, admin_server: AdminServer, authenticate: bool = True ) -> AsyncIterator[httpx.AsyncClient]: """Run a client towards the admin rest api.""" admin_url, credentials = admin_server @@ -53,3 +44,179 @@ class TestAdminAPI: async with httpx.AsyncClient(base_url=admin_url, headers=headers) as client: yield client + + async def create_client( + self, + admin_server: AdminServer, + name: str, + public_key: str | None = None, + ) -> Client: + """Create a client.""" + if not public_key: + public_key = make_test_key() + + new_client = { + "name": name, + "public_key": public_key, + "sources": ["192.0.2.0/24"], + } + + async with self.http_client(admin_server, True) as http_client: + response = await http_client.post("api/v1/clients/", json=new_client) + assert response.status_code == 200 + data = response.json() + client = Client.model_validate(data) + + return client + + +class TestAdminAPI(BaseAdminTests): + """Tests of the Admin REST API.""" + + @pytest.mark.asyncio + async def test_health_check( + self, admin_server: tuple[str, tuple[str, str]] + ) -> None: + """Test admin login.""" + async with self.http_client(admin_server, False) as client: + resp = await client.get("/health") + assert resp.status_code == 200 + + @pytest.mark.asyncio + async def test_admin_login(self, admin_server: AdminServer) -> None: + """Test admin login.""" + + async with self.http_client(admin_server, False) as client: + resp = await client.get("api/v1/clients/") + assert resp.status_code == 401 + + async with self.http_client(admin_server, True) as client: + resp = await client.get("api/v1/clients/") + assert resp.status_code == 200 + + +class TestAdminApiClients(BaseAdminTests): + """Test client routes.""" + + @pytest.mark.asyncio + async def test_create_client(self, admin_server: AdminServer) -> None: + """Test create_client.""" + client = await self.create_client(admin_server, "testclient") + + assert client.id is not None + assert client.name == "testclient" + + @pytest.mark.asyncio + async def test_get_clients(self, admin_server: AdminServer) -> None: + """Test get_clients.""" + + client_names = ["test-db", "test-app", "test-www"] + for name in client_names: + await self.create_client(admin_server, name) + async with self.http_client(admin_server) as http_client: + resp = await http_client.get("api/v1/clients/") + assert resp.status_code == 200 + data = resp.json() + assert isinstance(data, list) + assert len(data) == 3 + for entry in data: + assert isinstance(entry, dict) + client_name = entry.get("name") + assert client_name in client_names + + @pytest.mark.asyncio + async def test_delete_client(self, admin_server: AdminServer) -> None: + """Test delete_client.""" + await self.create_client(admin_server, name="testclient") + async with self.http_client(admin_server) as http_client: + resp = await http_client.get("api/v1/clients/") + assert resp.status_code == 200 + data = resp.json() + assert isinstance(data, list) + assert len(data) == 1 + assert data[0]["name"] == "testclient" + + resp = await http_client.delete("/api/v1/clients/testclient") + assert resp.status_code == 200 + + resp = await http_client.get("api/v1/clients/") + assert resp.status_code == 200 + data = resp.json() + assert isinstance(data, list) + assert len(data) == 0 + + +class TestAdminApiSecrets(BaseAdminTests): + """Test secret management.""" + + @pytest.mark.asyncio + async def test_add_secret(self, admin_server: AdminServer) -> None: + """Test add_secret.""" + await self.create_client(admin_server, name="testclient") + async with self.http_client(admin_server) as http_client: + data = { + "name": "testsecret", + "clients": ["testclient"], + "value": "secretstring", + } + resp = await http_client.post("api/v1/secrets/", json=data) + assert resp.status_code == 200 + + @pytest.mark.asyncio + async def test_get_secret(self, admin_server: AdminServer) -> None: + """Test get_secret.""" + await self.test_add_secret(admin_server) + async with self.http_client(admin_server) as http_client: + resp = await http_client.get("api/v1/secrets/testsecret") + assert resp.status_code == 200 + data = resp.json() + assert isinstance(data, dict) + assert data["name"] == "testsecret" + assert data["secret"] == "secretstring" + assert "testclient" in data["clients"] + + @pytest.mark.asyncio + async def test_add_secret_auto(self, admin_server: AdminServer) -> None: + """Test adding a secret with an auto-generated value.""" + await self.create_client(admin_server, name="testclient") + async with self.http_client(admin_server) as http_client: + data = { + "name": "testsecret", + "clients": ["testclient"], + "value": {"auto_generate": True, "length": 17}, + } + resp = await http_client.post("api/v1/secrets/", json=data) + assert resp.status_code == 200 + resp = await http_client.get("api/v1/secrets/testsecret") + assert resp.status_code == 200 + data = resp.json() + assert isinstance(data, dict) + assert data["name"] == "testsecret" + assert len(data["secret"]) == 17 + assert "testclient" in data["clients"] + + @pytest.mark.asyncio + async def test_update_secret(self, admin_server: AdminServer) -> None: + """Test updating secrets.""" + await self.test_add_secret_auto(admin_server) + async with self.http_client(admin_server) as http_client: + resp = await http_client.put( + "api/v1/secrets/testsecret", + json={"value": "secret"}, + ) + assert resp.status_code == 200 + resp = await http_client.get("api/v1/secrets/testsecret") + assert resp.status_code == 200 + data = resp.json() + assert data["secret"] == "secret" + + resp = await http_client.put( + "api/v1/secrets/testsecret", + json={"value": {"auto_generate": True, "length": 16}}, + ) + assert resp.status_code == 200 + + resp = await http_client.get("api/v1/secrets/testsecret") + assert resp.status_code == 200 + data = resp.json() + assert len(data["secret"]) == 16 diff --git a/tests/integration/test_backend.py b/tests/integration/test_backend.py index b942991..9cc667d 100644 --- a/tests/integration/test_backend.py +++ b/tests/integration/test_backend.py @@ -34,6 +34,7 @@ async def test_create_client(backend_api: SshecretBackend) -> None: assert clients[0].public_key == test_client.public_key +@pytest.mark.asyncio async def test_create_secret(backend_api: SshecretBackend) -> None: """Test creating secrets.""" test_client = create_test_client("test") diff --git a/tests/integration/types.py b/tests/integration/types.py index 96998f6..311c0bd 100644 --- a/tests/integration/types.py +++ b/tests/integration/types.py @@ -10,6 +10,7 @@ from .clients import ClientData PortFactory = Callable[[], int] +AdminServer = tuple[str, tuple[str, str]] @dataclass class TestPorts: diff --git a/tests/packages/sshd/conftest.py b/tests/packages/sshd/conftest.py index c27580e..55f4a38 100644 --- a/tests/packages/sshd/conftest.py +++ b/tests/packages/sshd/conftest.py @@ -1,12 +1,15 @@ import asyncio -from typing import Any +from pydantic import IPvAnyNetwork import pytest import uuid import asyncssh import tempfile from contextlib import asynccontextmanager +import pytest_asyncio +from pytest import FixtureRequest + from unittest.mock import AsyncMock, MagicMock -from ipaddress import IPv4Network, IPv6Network +from ipaddress import IPv4Network, IPv6Network, ip_network from sshecret_sshd.ssh_server import run_ssh_server from sshecret_sshd.settings import ClientRegistrationSettings @@ -32,14 +35,16 @@ def client_registry() -> ClientRegistry: ) -> str: private_key = asyncssh.generate_private_key("ssh-rsa") public_key = private_key.export_public_key() - clients[name] = ClientKey(name, private_key, public_key.decode().rstrip()) + clients[name] = ClientKey( + name, private_key, public_key.decode().rstrip(), policies + ) secrets.update({(name, s): f"mocked-secret-{s}" for s in (secret_names or [])}) return clients[name] return {"clients": clients, "secrets": secrets, "add_client": add_client} -@pytest.fixture(scope="function") +@pytest_asyncio.fixture(scope="function") async def mock_backend(client_registry: ClientRegistry) -> MagicMock: backend = MagicMock() clients_data = client_registry["clients"] @@ -48,13 +53,16 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock: async def get_client(name: str) -> Client | None: client_key = clients_data.get(name) if client_key: + policies = [IPv4Network("0.0.0.0/0"), IPv6Network("::/0")] + if client_key.policies: + policies = [ip_network(network) for network in client_key.policies] response_model = Client( id=uuid.uuid4(), name=name, description=f"Mock client {name}", public_key=client_key.public_key, secrets=[s for (c, s) in secrets_data if c == name], - policies=[IPv4Network("0.0.0.0/0"), IPv6Network("::/0")], + policies=policies, ) return response_model return None @@ -95,20 +103,31 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock: return backend -@pytest.fixture(scope="function") +@pytest_asyncio.fixture(scope="function") async def ssh_server( - mock_backend: MagicMock, unused_tcp_port: int + request: FixtureRequest, + mock_backend: MagicMock, + unused_tcp_port: int, ) -> SshServerFixtureFun: port = unused_tcp_port - private_key = asyncssh.generate_private_key("ssh-ed25519") key_str = private_key.export_private_key() + registration_mark = request.node.get_closest_marker("enable_registration") + registration_enabled = registration_mark is not None + registration_source_mark = request.node.get_closest_marker("registration_sources") + allowed_from: list[IPvAnyNetwork] = [] + if registration_source_mark: + for network in registration_source_mark.args: + allowed_from.append(ip_network(network)) + else: + allowed_from = [IPv4Network("0.0.0.0/0")] with tempfile.NamedTemporaryFile("w+", delete=True) as key_file: key_file.write(key_str.decode()) key_file.flush() registration_settings = ClientRegistrationSettings( - enabled=True, allow_from=[IPv4Network("0.0.0.0/0")] + enabled=registration_enabled, + allow_from=allowed_from, ) server = await run_ssh_server( backend=mock_backend, diff --git a/tests/packages/sshd/test_errors.py b/tests/packages/sshd/test_errors.py new file mode 100644 index 0000000..4f7d143 --- /dev/null +++ b/tests/packages/sshd/test_errors.py @@ -0,0 +1,155 @@ +"""Test various exceptions and error conditions.""" + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +import asyncssh +import pytest + +from .types import ClientRegistry, CommandRunner, ProcessRunner, SshServerFixture + + +class BaseSshTests: + """Base test class.""" + + @asynccontextmanager + async def unregistered_client(self, username: str, port: int): + """Generate SSH session as an uregistered client.""" + private_key = asyncssh.generate_private_key("ssh-rsa") + conn = await asyncssh.connect( + "127.0.0.1", + port=port, + username=username, + client_keys=[private_key], + known_hosts=None, + ) + try: + yield conn + finally: + conn.close() + await conn.wait_closed() + + @asynccontextmanager + async def ssh_connection( + self, username: str, port: int, private_key: asyncssh.SSHKey + ): + """Generate SSH session as a client with an ed25519 key.""" + # private_key = asyncssh.generate_private_key("ssh-ed25519") + conn = await asyncssh.connect( + "127.0.0.1", + port=port, + username=username, + client_keys=[private_key], + known_hosts=None, + ) + try: + yield conn + finally: + conn.close() + await conn.wait_closed() + + +class TestRegistrationErrors(BaseSshTests): + """Test class for errors related to registartion.""" + + @pytest.mark.enable_registration(True) + @pytest.mark.registration_sources("192.0.2.0/24") + @pytest.mark.asyncio + async def test_register_client_invalid_source( + self, ssh_server: SshServerFixture + ) -> None: + """Test client registration from a network that's not permitted.""" + _, port = ssh_server + with pytest.raises(asyncssh.misc.PermissionDenied): + async with self.unregistered_client("stranger", port) as conn: + async with conn.create_process("register") as process: + stdout, stderr = process.collect_output() + print(f"{stdout=!r}\n{stderr=!r}") + if isinstance(stdout, str): + assert "Enter public key" not in stdout + result = await process.wait() + assert result.exit_status == 1 + + @pytest.mark.enable_registration(True) + @pytest.mark.registration_sources("127.0.0.1") + @pytest.mark.asyncio + async def test_invalid_key_type(self, ssh_server: SshServerFixture) -> None: + """Test registration with an unsupported key.""" + _, port = ssh_server + private_key = asyncssh.generate_private_key("ssh-ed25519") + public_key = private_key.export_public_key().decode().rstrip() + "\n" + + async with self.ssh_connection("stranger", port, private_key) as conn: + async with conn.create_process("register") as process: + output = await process.stdout.readline() + assert "Enter public key" in output + stdout, stderr = await process.communicate(public_key) + print(f"{stdout=!r}, {stderr=!r}") + assert stderr == "Error: Invalid key type: Only RSA keys are supported." + result = await process.wait() + assert result.exit_status == 1 + + @pytest.mark.enable_registration(True) + @pytest.mark.registration_sources("127.0.0.1") + @pytest.mark.asyncio + async def test_invalid_key(self, ssh_server: SshServerFixture) -> None: + """Test registration with a bogus string as key..""" + _, port = ssh_server + private_key = asyncssh.generate_private_key("ssh-ed25519") + public_key = f"ssh-test {'A' * 544}\n" + + async with self.ssh_connection("stranger", port, private_key) as conn: + async with conn.create_process("register") as process: + output = await process.stdout.readline() + assert "Enter public key" in output + stdout, stderr = await process.communicate(public_key) + print(f"{stdout=!r}, {stderr=!r}") + assert stderr == "Error: Invalid key type: Only RSA keys are supported." + result = await process.wait() + assert result.exit_status == 1 + + +class TestCommandErrors(BaseSshTests): + """Tests various errors around commands.""" + + @pytest.mark.asyncio + async def test_invalid_command( + self, ssh_command_runner: CommandRunner, client_registry: ClientRegistry + ) -> None: + """Test sending an invalid command.""" + await client_registry["add_client"]("test") + + result = await ssh_command_runner("test", "cat /etc/passwd") + + assert result.exit_status == 1 + stderr = result.stderr or "" + assert stderr == "Error: Unsupported command." + + @pytest.mark.asyncio + async def test_no_command( + self, ssh_server: SshServerFixture, client_registry: ClientRegistry + ) -> None: + """Test sending no command.""" + await client_registry["add_client"]("test") + _, port = ssh_server + client_key = client_registry["clients"]["test"] + async with self.ssh_connection("test", port, client_key.private_key) as conn: + async with conn.create_process() as process: + stdout, stderr = await process.communicate() + print(f"{stdout=!r}, {stderr=!r}") + assert stderr == "Error: No command was received from the client." + result = await process.wait() + assert result.exit_status == 1 + + @pytest.mark.asyncio + async def test_deny_client_connection( + self, ssh_command_runner: CommandRunner, client_registry: ClientRegistry + ) -> None: + """Test client that is not permitted to connect.""" + await client_registry["add_client"]( + "test-client", + ["mysecret"], + ["192.0.2.0/24"], + ) + + with pytest.raises(asyncssh.misc.PermissionDenied): + await ssh_command_runner("test-client", "get_secret mysecret") diff --git a/tests/packages/sshd/test_register.py b/tests/packages/sshd/test_register.py index 2e6c3c5..ad82e78 100644 --- a/tests/packages/sshd/test_register.py +++ b/tests/packages/sshd/test_register.py @@ -5,6 +5,7 @@ import pytest from .types import ClientRegistry, CommandRunner, ProcessRunner +@pytest.mark.enable_registration(True) @pytest.mark.asyncio async def test_register_client( ssh_session: ProcessRunner, diff --git a/tests/packages/sshd/types.py b/tests/packages/sshd/types.py index aa009cb..1c76323 100644 --- a/tests/packages/sshd/types.py +++ b/tests/packages/sshd/types.py @@ -31,6 +31,8 @@ class ClientKey: name: str private_key: asyncssh.SSHKey public_key: str + policies: list[str] | None = None + class AddClientFun(Protocol):