Files
sshecret/tests/test_admin_api.py

417 lines
16 KiB
Python

"""Tests for the Admin HTTP API"""
from ipaddress import IPv4Address
import unittest
from fastapi.testclient import TestClient
from sshecret.types import ClientSpecification
from sshecret.testing import TestClientSpec, TestContext, api_context
from sshecret.webapi.api import get_app_settings
from sshecret.webapi.router import app
from sshecret.crypto import (
generate_private_key,
generate_public_key_string,
decode_string,
)
class TestLockUnlock(unittest.TestCase):
    """Tests for the admin HTTP API.

    Despite the class name, this covers more than lock/unlock: the auth
    endpoints, client CRUD and secret CRUD under ``/api/v1``.

    Every test builds its fixture with ``api_context`` and points the app's
    ``get_app_settings`` dependency at that context. ``app`` is a module-level
    object shared by all tests, so the override is removed again after each
    test (see ``_make_testclient``) to keep tests isolated from each other.
    """

    @staticmethod
    def _default_clients() -> list[TestClientSpec]:
        """Return the two-client fixture shared by most tests.

        ``webserver`` holds ``API_KEY``/``OTHER_API_KEY`` and ``db_server``
        holds ``DB_PASSWORD``. A fresh list is built on every call so no test
        can influence another through a shared fixture object.
        """
        return [
            TestClientSpec(
                "webserver",
                {
                    "API_KEY": "test",
                    "OTHER_API_KEY": "test2",
                },
            ),
            TestClientSpec(
                "db_server",
                {
                    "DB_PASSWORD": "test",
                },
            ),
        ]

    def _make_testclient(self, context: TestContext) -> TestClient:
        """Point the app at *context*'s settings and return a client for it.

        The dependency override is registered for removal via ``addCleanup``
        so it is dropped when the test ends, even if the test fails.
        """
        app.dependency_overrides[get_app_settings] = context.get_settings
        # Pop only our own key: other overrides (if any) are not ours to clear.
        self.addCleanup(app.dependency_overrides.pop, get_app_settings, None)
        return TestClient(app)

    def test_unlock_lock(self) -> None:
        """Unlock with the master password, check status, then lock again."""
        with api_context([]) as context:
            testclient = self._make_testclient(context)
            session_header = self.unlock(context, testclient)

            status_resp = testclient.get("/api/v1/auth/status", headers=session_header)
            self.assertEqual(status_resp.status_code, 200)
            status_body = status_resp.json()
            self.assertIn("message", status_body)
            self.assertEqual(str(status_body["message"]), "UNLOCKED")

            lock_resp = testclient.post("/api/v1/auth/lock", headers=session_header)
            self.assertEqual(lock_resp.status_code, 200)
            self.assertEqual(lock_resp.json().get("message"), "LOCKED")

    def test_get_clients(self) -> None:
        """List all clients; each entry must validate as a ClientSpecification."""
        with api_context(self._default_clients()) as context:
            testclient = self._make_testclient(context)
            client_resp = testclient.get("/api/v1/clients")
            self.assertEqual(client_resp.status_code, 200)
            clients = client_resp.json()
            self.assertIsInstance(clients, list)
            self.assertEqual(len(clients), 2)
            for client in clients:
                ClientSpecification.model_validate(client)

    def test_get_client(self) -> None:
        """Fetch a single client by name."""
        with api_context(self._default_clients()) as context:
            testclient = self._make_testclient(context)
            # fetch_client asserts the 200 and validates the payload.
            self.fetch_client(testclient, "webserver")

    def test_update_client(self) -> None:
        """Update a client with a trivial change (its allowed IPs)."""
        with api_context(self._default_clients()) as context:
            testclient = self._make_testclient(context)
            client = self.fetch_client(testclient, "webserver")
            session_header = self.unlock(context, testclient)

            # mode="json" makes every field JSON-serializable for the request body.
            serialized_client = client.model_dump(mode="json", exclude_unset=True)
            serialized_client["allowed_ips"] = ["192.0.2.1"]
            update_response = testclient.put(
                "/api/v1/clients/webserver",
                json=serialized_client,
                headers=session_header,
            )
            self.assertEqual(update_response.status_code, 200)
            updated_client = ClientSpecification.model_validate(update_response.json())
            self.assertEqual(updated_client.allowed_ips, [IPv4Address("192.0.2.1")])

    def test_update_client_sshkey(self) -> None:
        """Replace a client's SSH key and check its secrets are re-encrypted."""
        with api_context(self._default_clients()) as context:
            testclient = self._make_testclient(context)
            new_private_key = generate_private_key()
            public_key = generate_public_key_string(new_private_key.public_key())
            client = self.fetch_client(testclient, "webserver")
            session_header = self.unlock(context, testclient)

            serialized_client = client.model_dump(mode="json", exclude_unset=True)
            serialized_client["public_key"] = public_key
            update_response = testclient.put(
                "/api/v1/clients/webserver",
                json=serialized_client,
                headers=session_header,
            )
            self.assertEqual(update_response.status_code, 200)
            updated_client = ClientSpecification.model_validate(update_response.json())

            # Without secrets the loop below would check nothing and pass vacuously.
            self.assertTrue(updated_client.secrets)
            for secret_name, ciphertext in updated_client.secrets.items():
                # The ciphertext must change, and decrypt with the new key.
                self.assertNotEqual(client.secrets[secret_name], ciphertext)
                cleartext = decode_string(ciphertext, new_private_key)
                self.assertTrue(cleartext.startswith("test"))

            # Check that the backend stored what the update returned.
            new_client = self.fetch_client(testclient, "webserver")
            for secret_name, ciphertext in new_client.secrets.items():
                self.assertEqual(ciphertext, updated_client.secrets[secret_name])

    def test_delete_client(self) -> None:
        """Delete a client; fetching it afterwards must return 404."""
        with api_context(self._default_clients()) as context:
            testclient = self._make_testclient(context)
            self.fetch_client(testclient, "webserver")
            delete_resp = testclient.delete("/api/v1/clients/webserver")
            self.assertEqual(delete_resp.status_code, 204)
            get_resp = testclient.get("/api/v1/clients/webserver")
            self.assertEqual(get_resp.status_code, 404)

    def test_add_client(self) -> None:
        """Create a client; the created and re-fetched records must agree."""
        with api_context(self._default_clients()) as context:
            testclient = self._make_testclient(context)
            private_key = generate_private_key()
            public_key = generate_public_key_string(private_key.public_key())
            new_client = ClientSpecification(
                name="webserver2",
                public_key=public_key,
            )
            add_resp = testclient.post(
                "/api/v1/clients",
                json=new_client.model_dump(
                    mode="json", exclude_unset=True, exclude_defaults=True
                ),
            )
            self.assertEqual(add_resp.status_code, 201)
            client = ClientSpecification.model_validate(add_resp.json())
            self.assertEqual(client.public_key, public_key)
            fetched_client = self.fetch_client(testclient, "webserver2")
            self.assertEqual(fetched_client, client)

    def test_list_secrets(self) -> None:
        """List secret names together with the clients each is assigned to."""
        with api_context(self._default_clients()) as context:
            testclient = self._make_testclient(context)
            headers = self.unlock(context, testclient)
            resp = testclient.get("/api/v1/secrets", headers=headers)
            self.assertEqual(resp.status_code, 200)
            expected = [
                {"name": "API_KEY", "assigned_clients": ["webserver"]},
                {"name": "OTHER_API_KEY", "assigned_clients": ["webserver"]},
                {"name": "DB_PASSWORD", "assigned_clients": ["db_server"]},
            ]
            self.assertListEqual(resp.json(), expected)

    def test_get_secret(self) -> None:
        """Read a single secret's cleartext value."""
        test_data = [TestClientSpec("db_server", {"DB_PASSWORD": "test"})]
        with api_context(test_data) as context:
            testclient = self._make_testclient(context)
            headers = self.unlock(context, testclient)
            resp = testclient.get("/api/v1/secrets/DB_PASSWORD", headers=headers)
            self.assertEqual(resp.status_code, 200)
            expected = {"name": "DB_PASSWORD", "secret": "test"}
            self.assertDictEqual(resp.json(), expected)

    def test_update_secret_provided(self) -> None:
        """Update a secret with a caller-provided value.

        The response must not echo the provided value back (``secret`` is None).
        """
        test_data = [TestClientSpec("db_server", {"DB_PASSWORD": "test"})]
        with api_context(test_data) as context:
            testclient = self._make_testclient(context)
            headers = self.unlock(context, testclient)
            request = {"secret": "not-so-secret"}
            resp = testclient.put(
                "/api/v1/secrets/DB_PASSWORD", json=request, headers=headers
            )
            self.assertEqual(resp.status_code, 200)
            expected = {"name": "DB_PASSWORD", "secret": None}
            self.assertDictEqual(resp.json(), expected)

    def test_update_secret_auto(self) -> None:
        """Update a secret with an auto-generated value.

        The generated value must be returned in the response.
        """
        test_data = [TestClientSpec("db_server", {"DB_PASSWORD": "test"})]
        with api_context(test_data) as context:
            testclient = self._make_testclient(context)
            headers = self.unlock(context, testclient)
            request = {"secret": None, "auto_generate": True}
            resp = testclient.put(
                "/api/v1/secrets/DB_PASSWORD", json=request, headers=headers
            )
            self.assertEqual(resp.status_code, 200)
            self.assertIsNotNone(resp.json().get("secret"))

    def test_delete_secret(self) -> None:
        """Delete a secret; reading it afterwards must return 404."""
        test_data = [
            TestClientSpec(
                "webserver",
                {
                    "API_KEY": "test",
                    "OTHER_API_KEY": "test2",
                },
            ),
        ]
        with api_context(test_data) as context:
            testclient = self._make_testclient(context)
            headers = self.unlock(context, testclient)
            resp = testclient.delete("/api/v1/secrets/OTHER_API_KEY", headers=headers)
            self.assertEqual(resp.status_code, 204)
            get_resp = testclient.get("/api/v1/secrets/OTHER_API_KEY", headers=headers)
            self.assertEqual(get_resp.status_code, 404)

    def fetch_client(
        self, testclient: TestClient, client_name: str
    ) -> ClientSpecification:
        """Fetch *client_name*, asserting a 200, and return it as a model."""
        client_resp = testclient.get(f"/api/v1/clients/{client_name}")
        self.assertEqual(client_resp.status_code, 200)
        return ClientSpecification.model_validate(client_resp.json())

    def unlock(self, context: TestContext, testclient: TestClient) -> dict[str, str]:
        """Unlock with the master password; return the session header to send.

        Fails the test if the unlock request does not return 200 or the
        response body has no ``session_id``.
        """
        response = testclient.post(
            "/api/v1/auth/unlock", json={"password": context.master_password}
        )
        self.assertEqual(response.status_code, 200)
        body = response.json()
        self.assertIn("session_id", body)
        return {"session-id": str(body["session_id"])}
if __name__ == "__main__":
    # Allow running this module directly: `python test_admin_api.py`.
    unittest.main()