417 lines
16 KiB
Python
417 lines
16 KiB
Python
"""Tests for the Admin HTTP API"""
|
|
|
|
from ipaddress import IPv4Address
|
|
import unittest
|
|
|
|
from fastapi.testclient import TestClient
|
|
|
|
from sshecret.types import ClientSpecification
|
|
from sshecret.testing import TestClientSpec, TestContext, api_context
|
|
from sshecret.webapi.api import get_app_settings
|
|
from sshecret.webapi.router import app
|
|
from sshecret.crypto import (
|
|
generate_private_key,
|
|
generate_public_key_string,
|
|
decode_string,
|
|
)
|
|
|
|
|
|
class TestLockUnlock(unittest.TestCase):
    """Tests for the admin HTTP API: auth (unlock/lock), clients and secrets."""

    def setUp(self) -> None:
        """Set up testing.

        Each test installs a ``get_app_settings`` override on the shared,
        module-level ``app``. Register a cleanup that removes it again so one
        test's settings never leak into the next test (cleanups run even when
        the test fails).
        """
        self.addCleanup(app.dependency_overrides.pop, get_app_settings, None)

    # -- fixtures / helpers -------------------------------------------------

    @staticmethod
    def _webserver_spec() -> TestClientSpec:
        """Return a fresh ``webserver`` client spec holding two secrets."""
        return TestClientSpec(
            "webserver",
            {
                "API_KEY": "test",
                "OTHER_API_KEY": "test2",
            },
        )

    @staticmethod
    def _db_server_spec() -> TestClientSpec:
        """Return a fresh ``db_server`` client spec holding one secret."""
        return TestClientSpec(
            "db_server",
            {
                "DB_PASSWORD": "test",
            },
        )

    @staticmethod
    def _make_client(context: TestContext) -> TestClient:
        """Point ``app`` at *context*'s settings and return a test client for it."""
        app.dependency_overrides[get_app_settings] = context.get_settings
        return TestClient(app)

    def fetch_client(
        self, testclient: TestClient, client_name: str
    ) -> ClientSpecification:
        """Fetch a client by name, asserting the request succeeds.

        Args:
            testclient: Client bound to the app under test.
            client_name: Name used in the ``/api/v1/clients/{name}`` path.

        Returns:
            The response body validated as a ``ClientSpecification``.
        """
        client_resp = testclient.get(f"/api/v1/clients/{client_name}")
        self.assertEqual(client_resp.status_code, 200)
        return ClientSpecification.model_validate(client_resp.json())

    def unlock(self, context: TestContext, testclient: TestClient) -> dict[str, str]:
        """Unlock the session and return the headers that authenticate it.

        Posts the context's master password to the unlock endpoint, asserts
        that it succeeded, and wraps the returned ``session_id`` in a
        ``session-id`` header.
        """
        response = testclient.post(
            "/api/v1/auth/unlock", json={"password": context.master_password}
        )
        # Check the status first so a failed unlock reports the status code
        # rather than a confusing KeyError/JSON error below.
        self.assertEqual(response.status_code, 200)
        body = response.json()
        self.assertIn("session_id", body)
        return {"session-id": str(body["session_id"])}

    # -- auth ---------------------------------------------------------------

    def test_unlock_lock(self) -> None:
        """Unlock with the master password, check the status, then lock again."""
        with api_context([]) as context:
            testclient = self._make_client(context)
            session_header = self.unlock(context, testclient)

            status_resp = testclient.get("/api/v1/auth/status", headers=session_header)
            self.assertEqual(status_resp.status_code, 200)
            status_body = status_resp.json()
            self.assertIn("message", status_body)
            self.assertEqual(str(status_body["message"]), "UNLOCKED")

            lock_resp = testclient.post("/api/v1/auth/lock", headers=session_header)
            self.assertEqual(lock_resp.status_code, 200)
            self.assertEqual(lock_resp.json().get("message"), "LOCKED")

    # -- clients ------------------------------------------------------------

    def test_get_clients(self) -> None:
        """Listing clients returns every configured client."""
        with api_context([self._webserver_spec(), self._db_server_spec()]) as context:
            testclient = self._make_client(context)
            client_resp = testclient.get("/api/v1/clients")
            self.assertEqual(client_resp.status_code, 200)
            clients = client_resp.json()
            self.assertIsInstance(clients, list)
            self.assertEqual(len(clients), 2)

            names = [ClientSpecification.model_validate(client).name for client in clients]
            self.assertCountEqual(names, ["webserver", "db_server"])

    def test_get_client(self) -> None:
        """Fetching a specific client returns a valid specification for it."""
        with api_context([self._webserver_spec(), self._db_server_spec()]) as context:
            testclient = self._make_client(context)
            client = self.fetch_client(testclient, "webserver")
            self.assertEqual(client.name, "webserver")

    def test_update_client(self) -> None:
        """Test update client with trivial value (``allowed_ips``)."""
        with api_context([self._webserver_spec(), self._db_server_spec()]) as context:
            testclient = self._make_client(context)
            client = self.fetch_client(testclient, "webserver")
            headers = self.unlock(context, testclient)

            serialized_client = client.model_dump(exclude_unset=True)
            serialized_client["allowed_ips"] = ["192.0.2.1"]
            update_response = testclient.put(
                "/api/v1/clients/webserver",
                json=serialized_client,
                headers=headers,
            )
            self.assertEqual(update_response.status_code, 200)

            updated_client = ClientSpecification.model_validate(update_response.json())
            self.assertEqual(updated_client.allowed_ips, [IPv4Address("192.0.2.1")])

    def test_update_client_sshkey(self) -> None:
        """Replace a client's public key; its secrets must be re-encrypted for it."""
        with api_context([self._webserver_spec(), self._db_server_spec()]) as context:
            testclient = self._make_client(context)
            new_private_key = generate_private_key()
            public_key = generate_public_key_string(new_private_key.public_key())

            client = self.fetch_client(testclient, "webserver")
            headers = self.unlock(context, testclient)

            serialized_client = client.model_dump(exclude_unset=True)
            serialized_client["public_key"] = public_key
            update_response = testclient.put(
                "/api/v1/clients/webserver",
                json=serialized_client,
                headers=headers,
            )
            self.assertEqual(update_response.status_code, 200)
            updated_client = ClientSpecification.model_validate(update_response.json())

            # Guard against the loop below passing vacuously: the client has
            # secrets, and none were dropped by the key change.
            self.assertTrue(updated_client.secrets)
            self.assertEqual(set(updated_client.secrets), set(client.secrets))
            for secret, value in updated_client.secrets.items():
                self.assertNotEqual(client.secrets[secret], value)
                cleartext = decode_string(value, new_private_key)
                self.assertTrue(cleartext.startswith("test"))

            # Check that the backend is properly updated.
            new_client = self.fetch_client(testclient, "webserver")
            self.assertEqual(new_client.secrets, updated_client.secrets)

    def test_delete_client(self) -> None:
        """Test the delete_client API."""
        with api_context([self._webserver_spec(), self._db_server_spec()]) as context:
            testclient = self._make_client(context)
            self.fetch_client(testclient, "webserver")

            delete_resp = testclient.delete("/api/v1/clients/webserver")
            self.assertEqual(delete_resp.status_code, 204)

            get_resp = testclient.get("/api/v1/clients/webserver")
            self.assertEqual(get_resp.status_code, 404)

    def test_add_client(self) -> None:
        """Test the add_client API."""
        with api_context([self._webserver_spec(), self._db_server_spec()]) as context:
            testclient = self._make_client(context)
            private_key = generate_private_key()
            public_key = generate_public_key_string(private_key.public_key())
            new_client = ClientSpecification(
                name="webserver2",
                public_key=public_key,
            )

            add_resp = testclient.post(
                "/api/v1/clients",
                json=new_client.model_dump(exclude_unset=True, exclude_defaults=True),
            )
            self.assertEqual(add_resp.status_code, 201)
            client = ClientSpecification.model_validate(add_resp.json())
            self.assertEqual(client.public_key, public_key)

            fetched_client = self.fetch_client(testclient, "webserver2")
            self.assertEqual(fetched_client, client)

    # -- secrets ------------------------------------------------------------

    def test_list_secrets(self) -> None:
        """Test the list_secrets API."""
        with api_context([self._webserver_spec(), self._db_server_spec()]) as context:
            testclient = self._make_client(context)
            headers = self.unlock(context, testclient)

            resp = testclient.get("/api/v1/secrets", headers=headers)
            self.assertEqual(resp.status_code, 200)
            expected = [
                {"name": "API_KEY", "assigned_clients": ["webserver"]},
                {"name": "OTHER_API_KEY", "assigned_clients": ["webserver"]},
                {"name": "DB_PASSWORD", "assigned_clients": ["db_server"]},
            ]
            self.assertListEqual(resp.json(), expected)

    def test_get_secret(self) -> None:
        """Test the get_secret API."""
        with api_context([self._db_server_spec()]) as context:
            testclient = self._make_client(context)
            headers = self.unlock(context, testclient)

            resp = testclient.get("/api/v1/secrets/DB_PASSWORD", headers=headers)
            self.assertEqual(resp.status_code, 200)
            self.assertDictEqual(resp.json(), {"name": "DB_PASSWORD", "secret": "test"})

    def test_update_secret_provided(self) -> None:
        """Test the update_secret API.

        Tests updating a secret with a provided string.
        """
        with api_context([self._db_server_spec()]) as context:
            testclient = self._make_client(context)
            headers = self.unlock(context, testclient)

            request = {"secret": "not-so-secret"}
            resp = testclient.put(
                "/api/v1/secrets/DB_PASSWORD", json=request, headers=headers
            )
            self.assertEqual(resp.status_code, 200)
            # A provided secret is not echoed back in the response.
            self.assertDictEqual(resp.json(), {"name": "DB_PASSWORD", "secret": None})

    def test_update_secret_auto(self) -> None:
        """Test the update_secret API.

        Tests updating a secret with auto-generated string.
        """
        with api_context([self._db_server_spec()]) as context:
            testclient = self._make_client(context)
            headers = self.unlock(context, testclient)

            request = {"secret": None, "auto_generate": True}
            resp = testclient.put(
                "/api/v1/secrets/DB_PASSWORD", json=request, headers=headers
            )
            self.assertEqual(resp.status_code, 200)
            body = resp.json()
            self.assertEqual(body.get("name"), "DB_PASSWORD")
            secret = body.get("secret")
            # The generated value is returned and replaces the original "test".
            self.assertIsNotNone(secret)
            self.assertNotEqual(secret, "test")

    def test_delete_secret(self) -> None:
        """Test delete_secret API."""
        with api_context([self._webserver_spec()]) as context:
            testclient = self._make_client(context)
            headers = self.unlock(context, testclient)

            resp = testclient.delete("/api/v1/secrets/OTHER_API_KEY", headers=headers)
            self.assertEqual(resp.status_code, 204)

            get_resp = testclient.get("/api/v1/secrets/OTHER_API_KEY", headers=headers)
            self.assertEqual(get_resp.status_code, 404)
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed directly as a script: discover and run this module's tests.
    unittest.main(module=__name__)
|