"""Tests where the sshd is the main consumer.

This essentially also tests parts of the admin API.
"""

import asyncio
from contextlib import asynccontextmanager
from collections.abc import AsyncIterator
import os

import httpx
import pytest

from sshecret.crypto import decode_string
from sshecret.backend.api import SshecretBackend

from .clients import create_test_client, ClientData
from .types import CommandRunner, ProcessRunner


class BaseSSHTests:
    """Base test class providing shared helpers for the sshd tests."""

    async def register_client(
        self, name: str, ssh_session: ProcessRunner
    ) -> ClientData:
        """Register a new client through the interactive ``register`` command.

        A fresh test client is generated, the ``register`` command is started
        over an ssh session, and the client's public key is written to stdin
        once the "Enter public key" prompt has been seen.

        Args:
            name: Name of the client to generate and register.
            ssh_session: Fixture that opens an interactive ssh session for a
                client and a command.

        Returns:
            The generated client data (the same object that was registered).
        """
        test_client = create_test_client(name)
        async with ssh_session(test_client, "register") as session:
            maxlines = 10
            found = False
            # Scan a bounded number of lines for the prompt, so output that
            # lacks it fails the assertion below instead of looping forever.
            for _ in range(maxlines):
                line = await session.stdout.readline()
                if "Enter public key" in line:
                    found = True
                    break
            assert found is True
            session.stdin.write(test_client.public_key + "\n")
            result = await session.stdout.read()
            assert "Key is valid. Registering client." in result
            await session.wait()
        return test_client


class TestSshd(BaseSSHTests):
    """Class based tests.

    This allows us to create small helpers.
    """

    @pytest.mark.asyncio
    async def test_get_secret(
        self, backend_api: SshecretBackend, ssh_command_runner: CommandRunner
    ) -> None:
        """Create a client and secret via the backend, then fetch it over ssh."""
        test_client = create_test_client("testclient")
        await backend_api.create_client(
            "testclient", test_client.public_key, "A test client"
        )
        await backend_api.create_client_secret("testclient", "testsecret", "bogus")
        response = await ssh_command_runner(test_client, "get_secret testsecret")
        assert response.exit_status == 0
        assert response.stdout is not None
        assert isinstance(response.stdout, str)
        assert response.stdout.rstrip() == "bogus"

    @pytest.mark.asyncio
    async def test_register(
        self, backend_api: SshecretBackend, ssh_session: ProcessRunner
    ) -> None:
        """Register a client over ssh and check the backend now lists it."""
        await self.register_client("new_client", ssh_session)
        # Check that the client is created.
        clients = await backend_api.get_clients()
        assert len(clients) == 1
        client = clients[0]
        assert client.name == "new_client"


class TestSshdIntegration(BaseSSHTests):
    """Integration tests spanning the sshd and the admin HTTP API."""

    @pytest.mark.asyncio
    async def test_end_to_end(
        self,
        backend_api: SshecretBackend,
        admin_server: tuple[str, tuple[str, str]],
        ssh_session: ProcessRunner,
        ssh_command_runner: CommandRunner,
    ) -> None:
        """Register over ssh, create a secret via the admin API, read it over ssh.

        The value returned by ``get_secret`` is decoded with the client's
        private key and must match the value submitted to the admin API.
        """
        test_client = await self.register_client("myclient", ssh_session)
        url, credentials = admin_server
        username, password = credentials
        async with self.admin_client(url, username, password) as http_client:
            resp = await http_client.get("api/v1/clients/")
            assert resp.status_code == 200
            clients = resp.json()
            assert len(clients) == 1
            assert clients[0]["name"] == "myclient"

            create_model = {
                "name": "mysecret",
                "clients": ["myclient"],
                "value": "mypassword",
            }
            resp = await http_client.post("api/v1/secrets/", json=create_model)
            assert resp.status_code == 200

        # Login via ssh to fetch the decrypted value.
        ssh_output = await ssh_command_runner(test_client, "get_secret mysecret")
        assert ssh_output.stdout is not None
        assert isinstance(ssh_output.stdout, str)
        encrypted = ssh_output.stdout.rstrip()
        decrypted = decode_string(encrypted, test_client.private_key)
        assert decrypted == "mypassword"

    async def login(self, url: str, username: str, password: str) -> str:
        """Log in to the admin API and return the access token.

        Args:
            url: Base URL of the admin server.
            username: Admin username, sent as form data.
            password: Admin password, sent as form data.

        Returns:
            The ``access_token`` string from the token endpoint's JSON response.
        """
        # Let httpx join the path onto base_url (os.path.join is for filesystem
        # paths, not URLs), and close the client once the request is done.
        async with httpx.AsyncClient(base_url=url) as client:
            response = await client.post(
                "api/v1/token", data={"username": username, "password": password}
            )
        assert response.status_code == 200
        data = response.json()
        assert "access_token" in data
        assert isinstance(data["access_token"], str)
        return str(data["access_token"])

    @asynccontextmanager
    async def admin_client(
        self, url: str, username: str, password: str
    ) -> AsyncIterator[httpx.AsyncClient]:
        """Yield an HTTP client authenticated against the admin API.

        The client uses ``url`` as its base URL and sends the token obtained
        from :meth:`login` as a ``Bearer`` Authorization header.
        """
        token = await self.login(url, username, password)
        async with httpx.AsyncClient(
            base_url=url, headers={"Authorization": f"Bearer {token}"}
        ) as client:
            yield client


class TestShelldriverCommands(BaseSSHTests):
    """Shelldriver command tests."""

    @pytest.mark.asyncio
    async def test_store_and_lookup(
        self,
        backend_server: tuple[str, tuple[str, str]],
        ssh_session: ProcessRunner,
        ssh_command_runner: CommandRunner,
    ) -> None:
        """Store a secret given on the command line and look it up again."""
        test_client = await self.register_client("myclient", ssh_session)
        ssh_output = await ssh_command_runner(test_client, "store mysecret secretvalue")
        assert bool(ssh_output.stderr) is False
        # wait half a second before reading the secret back
        await asyncio.sleep(0.5)
        ssh_output = await ssh_command_runner(test_client, "lookup mysecret")
        assert bool(ssh_output.stderr) is False
        assert ssh_output.stdout is not None
        assert isinstance(ssh_output.stdout, str)
        encrypted = ssh_output.stdout.rstrip()
        decrypted = decode_string(encrypted, test_client.private_key)
        assert decrypted == "secretvalue"

    @pytest.mark.asyncio
    async def test_store_and_lookup_stdin(
        self,
        backend_server: tuple[str, tuple[str, str]],
        ssh_session: ProcessRunner,
        ssh_command_runner: CommandRunner,
    ) -> None:
        """Test store and lookup, with the secret value supplied on stdin."""
        test_client = await self.register_client("myclient", ssh_session)
        async with ssh_session(test_client, "store insecret") as session:
            session.stdin.write("testinput\n")
            session.stdin.write_eof()
            await session.stdin.wait_closed()

        await asyncio.sleep(0.5)
        ssh_output = await ssh_command_runner(test_client, "lookup insecret")
        assert bool(ssh_output.stderr) is False
        assert ssh_output.stdout is not None
        assert isinstance(ssh_output.stdout, str)
        encrypted = ssh_output.stdout.rstrip()
        decrypted = decode_string(encrypted, test_client.private_key)
        assert decrypted == "testinput"

    @pytest.mark.parametrize(
        "secret_name", ["nonexistant", "blåbærgrød", "../../../etc/shadow"]
    )
    @pytest.mark.asyncio
    async def test_invalid_lookup(
        self,
        secret_name: str,
        ssh_command_runner: CommandRunner,
        ssh_session: ProcessRunner,
    ) -> None:
        """Looking up an unknown or malformed secret name reports an error."""
        test_client = await self.register_client("myclient", ssh_session)
        ssh_output = await ssh_command_runner(test_client, f"lookup {secret_name}")
        assert isinstance(ssh_output.stderr, str)
        assert ssh_output.stderr.rstrip() == "no such secret"


class TestShelldriverListCommand(BaseSSHTests):
    """Tests for the list command."""

    @pytest.fixture(name="secret_names")
    def get_secret_names(self) -> list[str]:
        """Get secret names.

        Sort of like a parametrize function.
        """
        return ["foo", "bar", "abc123", "blåbærgrød"]

    # NOTE(review): marks on fixtures have no effect, so none is applied here.
    # This async fixture assumes pytest-asyncio handles async fixtures declared
    # with @pytest.fixture (e.g. asyncio_mode = auto) — confirm in config.
    @pytest.fixture(name="test_client")
    async def create_test_client(
        self,
        ssh_session: ProcessRunner,
    ) -> ClientData:
        """Register a test client."""
        return await self.register_client("listclient", ssh_session)

    @pytest.fixture(autouse=True)
    async def create_data(
        self,
        secret_names: list[str],
        test_client: ClientData,
        ssh_command_runner: CommandRunner,
    ) -> None:
        """Store every name from ``secret_names`` for the test client."""
        for name in secret_names:
            ssh_output = await ssh_command_runner(
                test_client, f"store {name} secretvalue"
            )
            assert bool(ssh_output.stderr) is False
            assert ssh_output.stdout is not None
            assert isinstance(ssh_output.stdout, str)

    @pytest.mark.asyncio
    async def test_list_secrets(
        self,
        secret_names: list[str],
        test_client: ClientData,
        ssh_command_runner: CommandRunner,
    ) -> None:
        """The list command prints exactly the stored secret names, one per line."""
        ssh_output = await ssh_command_runner(test_client, "list")
        assert bool(ssh_output.stderr) is False
        assert ssh_output.stdout is not None
        assert isinstance(ssh_output.stdout, str)

        list_output = ssh_output.stdout
        print(list_output)
        input_secret_names = list_output.splitlines()
        assert len(input_secret_names) == len(secret_names)
        assert sorted(secret_names) == sorted(input_secret_names)


@pytest.mark.parametrize("secret_name", ["simple", "blåbærgrød"])
class TestShelldriverDeleteCommand(BaseSSHTests):
    """Tests for the delete command."""

    # NOTE(review): marks on fixtures have no effect, so none is applied here.
    # This async fixture assumes pytest-asyncio handles async fixtures declared
    # with @pytest.fixture (e.g. asyncio_mode = auto) — confirm in config.
    @pytest.fixture(name="test_client")
    async def create_test_client(
        self,
        ssh_session: ProcessRunner,
    ) -> ClientData:
        """Register a test client."""
        return await self.register_client("listclient", ssh_session)

    @pytest.fixture(autouse=True)
    async def create_data(
        self,
        secret_name: str,
        test_client: ClientData,
        ssh_command_runner: CommandRunner,
    ) -> None:
        """Store the parametrized ``secret_name`` for the test client."""
        ssh_output = await ssh_command_runner(
            test_client, f"store {secret_name} secretvalue"
        )
        assert bool(ssh_output.stderr) is False
        assert ssh_output.stdout is not None
        assert isinstance(ssh_output.stdout, str)

    async def get_stored_secrets(
        self,
        test_client: ClientData,
        ssh_command_runner: CommandRunner,
    ) -> list[str]:
        """Return the sorted secret names printed by the ``list`` command."""
        ssh_output = await ssh_command_runner(test_client, "list")
        assert bool(ssh_output.stderr) is False
        assert ssh_output.stdout is not None
        assert isinstance(ssh_output.stdout, str)

        list_output = ssh_output.stdout
        print(list_output)
        input_secret_names = list_output.splitlines()
        return sorted(input_secret_names)

    @pytest.mark.asyncio
    async def test_delete_secret(
        self,
        secret_name: str,
        test_client: ClientData,
        ssh_command_runner: CommandRunner,
    ) -> None:
        """Deleting a stored secret removes it from the ``list`` output."""
        current_secrets = await self.get_stored_secrets(test_client, ssh_command_runner)
        assert secret_name in current_secrets
        ssh_output = await ssh_command_runner(test_client, f"delete {secret_name}")
        assert bool(ssh_output.stderr) is False
        current_secrets = await self.get_stored_secrets(test_client, ssh_command_runner)
        assert secret_name not in current_secrets