"""Tests for the Admin HTTP API"""

from ipaddress import IPv4Address
import unittest

from fastapi.testclient import TestClient

from sshecret.types import ClientSpecification
from sshecret.testing import TestClientSpec, TestContext, api_context
from sshecret.webapi.api import get_app_settings
from sshecret.webapi.router import app
from sshecret.crypto import (
    generate_private_key,
    generate_public_key_string,
    decode_string,
)


def _webserver_and_db_data() -> list[TestClientSpec]:
    """Return the two-client fixture used by most tests.

    ``webserver`` holds ``API_KEY`` and ``OTHER_API_KEY``; ``db_server`` holds
    ``DB_PASSWORD``. A new list is built on every call so tests never share
    fixture objects.
    """
    return [
        TestClientSpec(
            "webserver",
            {
                "API_KEY": "test",
                "OTHER_API_KEY": "test2",
            },
        ),
        TestClientSpec(
            "db_server",
            {
                "DB_PASSWORD": "test",
            },
        ),
    ]


def _db_server_data() -> list[TestClientSpec]:
    """Return a fixture with a single ``db_server`` client owning ``DB_PASSWORD``."""
    return [
        TestClientSpec(
            "db_server",
            {
                "DB_PASSWORD": "test",
            },
        ),
    ]


class TestLockUnlock(unittest.TestCase):
    """Tests for the admin HTTP API: auth lock/unlock, clients and secrets.

    The historical class name is kept so existing test selectors keep working.
    """

    def setUp(self) -> None:
        """Set up testing.

        Each test builds its own ``api_context``; per-test app state is
        attached (and torn down) by ``_make_testclient``.
        """

    def _make_testclient(self, context: TestContext) -> TestClient:
        """Point the shared ``app`` at *context*'s settings and return a client.

        ``app`` is a module-level object, so the settings override is removed
        again when the test finishes (``addCleanup`` runs after the test
        method returns). Without this, the override would keep referring to a
        context whose ``with api_context(...)`` block has already ended.
        """
        app.dependency_overrides[get_app_settings] = context.get_settings
        self.addCleanup(app.dependency_overrides.pop, get_app_settings, None)
        return TestClient(app)

    def test_unlock_lock(self) -> None:
        """Test unlocking, checking status, then locking again."""
        with api_context([]) as context:
            testclient = self._make_testclient(context)
            response = testclient.post(
                "/api/v1/auth/unlock", json={"password": context.master_password}
            )
            self.assertEqual(response.status_code, 200)
            body = response.json()
            self.assertIn("session_id", body)
            session_header = {"session-id": str(body["session_id"])}

            status_resp = testclient.get("/api/v1/auth/status", headers=session_header)
            self.assertEqual(status_resp.status_code, 200)
            status_body = status_resp.json()
            self.assertIn("message", status_body)
            self.assertEqual(str(status_body["message"]), "UNLOCKED")

            lock_resp = testclient.post("/api/v1/auth/lock", headers=session_header)
            self.assertEqual(lock_resp.status_code, 200)
            self.assertEqual(lock_resp.json().get("message"), "LOCKED")

    def test_get_clients(self) -> None:
        """Test listing all clients."""
        with api_context(_webserver_and_db_data()) as context:
            testclient = self._make_testclient(context)
            client_resp = testclient.get("/api/v1/clients")
            self.assertEqual(client_resp.status_code, 200)
            clients = client_resp.json()
            self.assertIsInstance(clients, list)
            self.assertEqual(len(clients), 2)
            for client in clients:
                ClientSpecification.model_validate(client)

    def test_get_client(self) -> None:
        """Test fetching a specific client."""
        with api_context(_webserver_and_db_data()) as context:
            testclient = self._make_testclient(context)
            client_resp = testclient.get("/api/v1/clients/webserver")
            self.assertEqual(client_resp.status_code, 200)
            ClientSpecification.model_validate(client_resp.json())

    def test_update_client(self) -> None:
        """Test updating a client with a trivial value (allowed_ips)."""
        with api_context(_webserver_and_db_data()) as context:
            testclient = self._make_testclient(context)
            client = self.fetch_client(testclient, "webserver")
            session_header = self.unlock(context, testclient)

            serialized_client = client.model_dump(exclude_unset=True)
            serialized_client["allowed_ips"] = ["192.0.2.1"]
            update_response = testclient.put(
                "/api/v1/clients/webserver",
                json=serialized_client,
                headers=session_header,
            )
            self.assertEqual(update_response.status_code, 200)
            updated_client = ClientSpecification.model_validate(update_response.json())
            self.assertEqual(updated_client.allowed_ips, [IPv4Address("192.0.2.1")])

    def test_update_client_sshkey(self) -> None:
        """Update a client's SSH key and check its secrets are re-encrypted."""
        with api_context(_webserver_and_db_data()) as context:
            testclient = self._make_testclient(context)
            new_private_key = generate_private_key()
            public_key = generate_public_key_string(new_private_key.public_key())

            client = self.fetch_client(testclient, "webserver")
            session_header = self.unlock(context, testclient)

            serialized_client = client.model_dump(exclude_unset=True)
            serialized_client["public_key"] = public_key
            update_response = testclient.put(
                "/api/v1/clients/webserver",
                json=serialized_client,
                headers=session_header,
            )
            self.assertEqual(update_response.status_code, 200)
            updated_client = ClientSpecification.model_validate(update_response.json())

            # Guard against the loop below passing vacuously on an empty dict.
            self.assertTrue(updated_client.secrets)
            for secret, value in updated_client.secrets.items():
                old_secret = client.secrets[secret]
                self.assertNotEqual(old_secret, value)
                cleartext = decode_string(value, new_private_key)
                self.assertTrue(cleartext.startswith("test"))

            # check that the backend is properly updated: the stored secrets
            # must match the update response exactly (no missing or extra keys).
            new_client = self.fetch_client(testclient, "webserver")
            self.assertEqual(new_client.secrets, updated_client.secrets)

    def test_delete_client(self) -> None:
        """Test the delete_client API."""
        with api_context(_webserver_and_db_data()) as context:
            testclient = self._make_testclient(context)
            client_resp = testclient.get("/api/v1/clients/webserver")
            self.assertEqual(client_resp.status_code, 200)
            delete_resp = testclient.delete("/api/v1/clients/webserver")
            self.assertEqual(delete_resp.status_code, 204)
            get_resp = testclient.get("/api/v1/clients/webserver")
            self.assertEqual(get_resp.status_code, 404)

    def test_add_client(self) -> None:
        """Test the add_client API."""
        with api_context(_webserver_and_db_data()) as context:
            testclient = self._make_testclient(context)
            private_key = generate_private_key()
            public_key = generate_public_key_string(private_key.public_key())
            new_client = ClientSpecification(
                name="webserver2",
                public_key=public_key,
            )
            add_resp = testclient.post(
                "/api/v1/clients",
                json=new_client.model_dump(exclude_unset=True, exclude_defaults=True),
            )
            self.assertEqual(add_resp.status_code, 201)
            client = ClientSpecification.model_validate(add_resp.json())
            self.assertEqual(client.public_key, public_key)

            fetched_client = self.fetch_client(testclient, "webserver2")
            self.assertEqual(fetched_client, client)

    def test_list_secrets(self) -> None:
        """Test the list_secrets API."""
        with api_context(_webserver_and_db_data()) as context:
            testclient = self._make_testclient(context)
            headers = self.unlock(context, testclient)
            resp = testclient.get("/api/v1/secrets", headers=headers)
            self.assertEqual(resp.status_code, 200)
            expected = [
                {"name": "API_KEY", "assigned_clients": ["webserver"]},
                {"name": "OTHER_API_KEY", "assigned_clients": ["webserver"]},
                {"name": "DB_PASSWORD", "assigned_clients": ["db_server"]},
            ]
            self.assertListEqual(resp.json(), expected)

    def test_get_secret(self) -> None:
        """Test the get_secret API."""
        with api_context(_db_server_data()) as context:
            testclient = self._make_testclient(context)
            headers = self.unlock(context, testclient)
            resp = testclient.get("/api/v1/secrets/DB_PASSWORD", headers=headers)
            self.assertEqual(resp.status_code, 200)
            expected = {"name": "DB_PASSWORD", "secret": "test"}
            self.assertDictEqual(resp.json(), expected)

    def test_update_secret_provided(self) -> None:
        """Test the update_secret API.

        Tests updating a secret with a provided string; the response is
        expected not to echo the secret back.
        """
        with api_context(_db_server_data()) as context:
            testclient = self._make_testclient(context)
            headers = self.unlock(context, testclient)
            request = {"secret": "not-so-secret"}
            resp = testclient.put(
                "/api/v1/secrets/DB_PASSWORD", json=request, headers=headers
            )
            self.assertEqual(resp.status_code, 200)
            expected = {"name": "DB_PASSWORD", "secret": None}
            self.assertDictEqual(resp.json(), expected)

    def test_update_secret_auto(self) -> None:
        """Test the update_secret API.

        Tests updating a secret with an auto-generated string; the generated
        value is expected in the response.
        """
        with api_context(_db_server_data()) as context:
            testclient = self._make_testclient(context)
            headers = self.unlock(context, testclient)
            request = {"secret": None, "auto_generate": True}
            resp = testclient.put(
                "/api/v1/secrets/DB_PASSWORD", json=request, headers=headers
            )
            self.assertEqual(resp.status_code, 200)
            self.assertIsNotNone(resp.json().get("secret"))

    def test_delete_secret(self) -> None:
        """Test delete_secret API."""
        test_data = [
            TestClientSpec(
                "webserver",
                {
                    "API_KEY": "test",
                    "OTHER_API_KEY": "test2",
                },
            ),
        ]
        with api_context(test_data) as context:
            testclient = self._make_testclient(context)
            headers = self.unlock(context, testclient)
            resp = testclient.delete("/api/v1/secrets/OTHER_API_KEY", headers=headers)
            self.assertEqual(resp.status_code, 204)
            get_resp = testclient.get("/api/v1/secrets/OTHER_API_KEY", headers=headers)
            self.assertEqual(get_resp.status_code, 404)

    def fetch_client(
        self, testclient: TestClient, client_name: str
    ) -> ClientSpecification:
        """Fetch a client by name, asserting the request succeeds."""
        client_resp = testclient.get(f"/api/v1/clients/{client_name}")
        self.assertEqual(client_resp.status_code, 200)
        return ClientSpecification.model_validate(client_resp.json())

    def unlock(self, context: TestContext, testclient: TestClient) -> dict[str, str]:
        """Unlock the session and return the header that carries the session id.

        Fails the test right away if the unlock request is rejected, instead of
        surfacing an unrelated ``KeyError`` on the missing ``session_id``.
        """
        response = testclient.post(
            "/api/v1/auth/unlock", json={"password": context.master_password}
        )
        self.assertEqual(response.status_code, 200)
        body = response.json()
        self.assertIn("session_id", body)
        return {"session-id": str(body["session_id"])}


if __name__ == "__main__":
    unittest.main()