check in current project state
This commit is contained in:
416
tests/test_admin_api.py
Normal file
416
tests/test_admin_api.py
Normal file
@ -0,0 +1,416 @@
|
||||
"""Tests for the Admin HTTP API"""
|
||||
|
||||
from ipaddress import IPv4Address
|
||||
import unittest
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from sshecret.types import ClientSpecification
|
||||
from sshecret.testing import TestClientSpec, TestContext, api_context
|
||||
from sshecret.webapi.api import get_app_settings
|
||||
from sshecret.webapi.router import app
|
||||
from sshecret.crypto import (
|
||||
generate_private_key,
|
||||
generate_public_key_string,
|
||||
decode_string,
|
||||
)
|
||||
|
||||
|
||||
class TestLockUnlock(unittest.TestCase):
|
||||
"""Test lock and unlock."""
|
||||
|
||||
def setUp(self) -> None:
|
||||
"""Set up testing."""
|
||||
|
||||
def test_unlock_lock(self) -> None:
|
||||
"""Test unlocking."""
|
||||
with api_context([]) as context:
|
||||
app.dependency_overrides[get_app_settings] = context.get_settings
|
||||
testclient: TestClient = TestClient(app)
|
||||
response = testclient.post(
|
||||
"api/v1/auth/unlock", json={"password": context.master_password}
|
||||
)
|
||||
body = response.json()
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertIn("session_id", body)
|
||||
session_id = body["session_id"]
|
||||
session_header = {"session-id": str(session_id)}
|
||||
status_resp = testclient.get("/api/v1/auth/status", headers=session_header)
|
||||
self.assertEqual(status_resp.status_code, 200)
|
||||
status_body = status_resp.json()
|
||||
self.assertIn("message", status_body)
|
||||
self.assertEqual(str(status_body["message"]), "UNLOCKED")
|
||||
lock_resp = testclient.post("/api/v1/auth/lock", headers=session_header)
|
||||
self.assertEqual(lock_resp.status_code, 200)
|
||||
lock_body = lock_resp.json()
|
||||
lock_status = lock_body.get("message")
|
||||
self.assertEqual(lock_status, "LOCKED")
|
||||
|
||||
def test_get_clients(self) -> None:
|
||||
"""Test get clients."""
|
||||
test_data = [
|
||||
TestClientSpec(
|
||||
"webserver",
|
||||
{
|
||||
"API_KEY": "test",
|
||||
"OTHER_API_KEY": "test2",
|
||||
},
|
||||
),
|
||||
TestClientSpec(
|
||||
"db_server",
|
||||
{
|
||||
"DB_PASSWORD": "test",
|
||||
},
|
||||
),
|
||||
]
|
||||
with api_context(test_data) as context:
|
||||
app.dependency_overrides[get_app_settings] = context.get_settings
|
||||
testclient: TestClient = TestClient(app)
|
||||
client_resp = testclient.get("/api/v1/clients")
|
||||
clients = client_resp.json()
|
||||
self.assertIsInstance(clients, list)
|
||||
self.assertEqual(len(clients), 2)
|
||||
|
||||
for client in clients:
|
||||
ClientSpecification.model_validate(client)
|
||||
|
||||
def test_get_client(self) -> None:
|
||||
"""Test get specific client."""
|
||||
test_data = [
|
||||
TestClientSpec(
|
||||
"webserver",
|
||||
{
|
||||
"API_KEY": "test",
|
||||
"OTHER_API_KEY": "test2",
|
||||
},
|
||||
),
|
||||
TestClientSpec(
|
||||
"db_server",
|
||||
{
|
||||
"DB_PASSWORD": "test",
|
||||
},
|
||||
),
|
||||
]
|
||||
with api_context(test_data) as context:
|
||||
app.dependency_overrides[get_app_settings] = context.get_settings
|
||||
testclient: TestClient = TestClient(app)
|
||||
client_resp = testclient.get("/api/v1/clients/webserver")
|
||||
self.assertEqual(client_resp.status_code, 200)
|
||||
client_dict = client_resp.json()
|
||||
ClientSpecification.model_validate(client_dict)
|
||||
|
||||
def test_update_client(self) -> None:
|
||||
"""Test update client with trivial value."""
|
||||
test_data = [
|
||||
TestClientSpec(
|
||||
"webserver",
|
||||
{
|
||||
"API_KEY": "test",
|
||||
"OTHER_API_KEY": "test2",
|
||||
},
|
||||
),
|
||||
TestClientSpec(
|
||||
"db_server",
|
||||
{
|
||||
"DB_PASSWORD": "test",
|
||||
},
|
||||
),
|
||||
]
|
||||
with api_context(test_data) as context:
|
||||
app.dependency_overrides[get_app_settings] = context.get_settings
|
||||
testclient: TestClient = TestClient(app)
|
||||
client_resp = testclient.get("/api/v1/clients/webserver")
|
||||
self.assertEqual(client_resp.status_code, 200)
|
||||
client_dict = client_resp.json()
|
||||
client = ClientSpecification.model_validate(client_dict)
|
||||
unlock_response = testclient.post(
|
||||
"/api/v1/auth/unlock", json={"password": context.master_password}
|
||||
)
|
||||
body = unlock_response.json()
|
||||
self.assertEqual(unlock_response.status_code, 200)
|
||||
self.assertIn("session_id", body)
|
||||
session_id = body["session_id"]
|
||||
session_header = {"session-id": str(session_id)}
|
||||
serialized_client = client.model_dump(exclude_unset=True)
|
||||
serialized_client["allowed_ips"] = ["192.0.2.1"]
|
||||
update_response = testclient.put(
|
||||
"/api/v1/clients/webserver",
|
||||
json=serialized_client,
|
||||
headers=session_header,
|
||||
)
|
||||
self.assertAlmostEqual(update_response.status_code, 200)
|
||||
update_body = update_response.json()
|
||||
updated_client = ClientSpecification.model_validate(update_body)
|
||||
assert updated_client.allowed_ips == [IPv4Address("192.0.2.1")]
|
||||
|
||||
def test_update_client_sshkey(self) -> None:
|
||||
"""Update client SSH key."""
|
||||
test_data = [
|
||||
TestClientSpec(
|
||||
"webserver",
|
||||
{
|
||||
"API_KEY": "test",
|
||||
"OTHER_API_KEY": "test2",
|
||||
},
|
||||
),
|
||||
TestClientSpec(
|
||||
"db_server",
|
||||
{
|
||||
"DB_PASSWORD": "test",
|
||||
},
|
||||
),
|
||||
]
|
||||
with api_context(test_data) as context:
|
||||
app.dependency_overrides[get_app_settings] = context.get_settings
|
||||
testclient: TestClient = TestClient(app)
|
||||
new_private_key = generate_private_key()
|
||||
public_key = generate_public_key_string(new_private_key.public_key())
|
||||
client_resp = testclient.get("/api/v1/clients/webserver")
|
||||
self.assertEqual(client_resp.status_code, 200)
|
||||
client_dict = client_resp.json()
|
||||
client = ClientSpecification.model_validate(client_dict)
|
||||
unlock_response = testclient.post(
|
||||
"/api/v1/auth/unlock", json={"password": context.master_password}
|
||||
)
|
||||
body = unlock_response.json()
|
||||
self.assertEqual(unlock_response.status_code, 200)
|
||||
self.assertIn("session_id", body)
|
||||
session_id = body["session_id"]
|
||||
session_header = {"session-id": str(session_id)}
|
||||
|
||||
serialized_client = client.model_dump(exclude_unset=True)
|
||||
serialized_client["public_key"] = public_key
|
||||
update_response = testclient.put(
|
||||
"/api/v1/clients/webserver",
|
||||
json=serialized_client,
|
||||
headers=session_header,
|
||||
)
|
||||
self.assertAlmostEqual(update_response.status_code, 200)
|
||||
update_body = update_response.json()
|
||||
updated_client = ClientSpecification.model_validate(update_body)
|
||||
for secret, value in updated_client.secrets.items():
|
||||
old_secret = client.secrets[secret]
|
||||
self.assertNotEqual(old_secret, value)
|
||||
cleartext = decode_string(value, new_private_key)
|
||||
self.assertTrue(cleartext.startswith("test"))
|
||||
|
||||
# check that the backend is properly updated.
|
||||
new_client_resp = testclient.get("/api/v1/clients/webserver")
|
||||
new_client_dict = new_client_resp.json()
|
||||
self.assertEqual(new_client_resp.status_code, 200)
|
||||
new_client = ClientSpecification.model_validate(new_client_dict)
|
||||
for secret, value in new_client.secrets.items():
|
||||
matching_value = updated_client.secrets[secret]
|
||||
self.assertEqual(value, matching_value)
|
||||
|
||||
def test_delete_client(self) -> None:
|
||||
"""Test the delete_client API."""
|
||||
test_data = [
|
||||
TestClientSpec(
|
||||
"webserver",
|
||||
{
|
||||
"API_KEY": "test",
|
||||
"OTHER_API_KEY": "test2",
|
||||
},
|
||||
),
|
||||
TestClientSpec(
|
||||
"db_server",
|
||||
{
|
||||
"DB_PASSWORD": "test",
|
||||
},
|
||||
),
|
||||
]
|
||||
with api_context(test_data) as context:
|
||||
app.dependency_overrides[get_app_settings] = context.get_settings
|
||||
testclient: TestClient = TestClient(app)
|
||||
client_resp = testclient.get("/api/v1/clients/webserver")
|
||||
self.assertEqual(client_resp.status_code, 200)
|
||||
delete_resp = testclient.delete("/api/v1/clients/webserver")
|
||||
self.assertEqual(delete_resp.status_code, 204)
|
||||
get_resp = testclient.get("/api/v1/clients/webserver")
|
||||
self.assertEqual(get_resp.status_code, 404)
|
||||
|
||||
def test_add_client(self) -> None:
|
||||
"""Test the add_client API."""
|
||||
test_data = [
|
||||
TestClientSpec(
|
||||
"webserver",
|
||||
{
|
||||
"API_KEY": "test",
|
||||
"OTHER_API_KEY": "test2",
|
||||
},
|
||||
),
|
||||
TestClientSpec(
|
||||
"db_server",
|
||||
{
|
||||
"DB_PASSWORD": "test",
|
||||
},
|
||||
),
|
||||
]
|
||||
with api_context(test_data) as context:
|
||||
app.dependency_overrides[get_app_settings] = context.get_settings
|
||||
testclient: TestClient = TestClient(app)
|
||||
private_key = generate_private_key()
|
||||
public_key = generate_public_key_string(private_key.public_key())
|
||||
new_client = ClientSpecification(
|
||||
name="webserver2",
|
||||
public_key=public_key,
|
||||
)
|
||||
add_resp = testclient.post(
|
||||
"/api/v1/clients",
|
||||
json=new_client.model_dump(exclude_unset=True, exclude_defaults=True),
|
||||
)
|
||||
self.assertEqual(add_resp.status_code, 201)
|
||||
body = add_resp.json()
|
||||
client = ClientSpecification.model_validate(body)
|
||||
self.assertEqual(client.public_key, public_key)
|
||||
fetched_client = self.fetch_client(testclient, "webserver2")
|
||||
self.assertEqual(fetched_client, client)
|
||||
|
||||
def test_list_secrets(self) -> None:
|
||||
"""Test the list_secrets API."""
|
||||
test_data = [
|
||||
TestClientSpec(
|
||||
"webserver",
|
||||
{
|
||||
"API_KEY": "test",
|
||||
"OTHER_API_KEY": "test2",
|
||||
},
|
||||
),
|
||||
TestClientSpec(
|
||||
"db_server",
|
||||
{
|
||||
"DB_PASSWORD": "test",
|
||||
},
|
||||
),
|
||||
]
|
||||
with api_context(test_data) as context:
|
||||
app.dependency_overrides[get_app_settings] = context.get_settings
|
||||
testclient: TestClient = TestClient(app)
|
||||
headers = self.unlock(context, testclient)
|
||||
resp = testclient.get("/api/v1/secrets", headers=headers)
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
expected = [
|
||||
{"name": "API_KEY", "assigned_clients": ["webserver"]},
|
||||
{"name": "OTHER_API_KEY", "assigned_clients": ["webserver"]},
|
||||
{"name": "DB_PASSWORD", "assigned_clients": ["db_server"]},
|
||||
]
|
||||
body = resp.json()
|
||||
|
||||
self.assertListEqual(body, expected)
|
||||
|
||||
def test_get_secret(self) -> None:
|
||||
"""Test the get_secret API."""
|
||||
test_data = [
|
||||
TestClientSpec(
|
||||
"db_server",
|
||||
{
|
||||
"DB_PASSWORD": "test",
|
||||
},
|
||||
),
|
||||
]
|
||||
with api_context(test_data) as context:
|
||||
app.dependency_overrides[get_app_settings] = context.get_settings
|
||||
testclient: TestClient = TestClient(app)
|
||||
headers = self.unlock(context, testclient)
|
||||
resp = testclient.get("/api/v1/secrets/DB_PASSWORD", headers=headers)
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
expected = {"name": "DB_PASSWORD", "secret": "test"}
|
||||
body = resp.json()
|
||||
self.assertDictEqual(body, expected)
|
||||
|
||||
def test_update_secret_provided(self) -> None:
|
||||
"""Test the update_secret API.
|
||||
|
||||
Tests updating a secret with a provided string.
|
||||
"""
|
||||
test_data = [
|
||||
TestClientSpec(
|
||||
"db_server",
|
||||
{
|
||||
"DB_PASSWORD": "test",
|
||||
},
|
||||
),
|
||||
]
|
||||
with api_context(test_data) as context:
|
||||
app.dependency_overrides[get_app_settings] = context.get_settings
|
||||
testclient: TestClient = TestClient(app)
|
||||
headers = self.unlock(context, testclient)
|
||||
request = {"secret": "not-so-secret"}
|
||||
resp = testclient.put(
|
||||
"/api/v1/secrets/DB_PASSWORD", json=request, headers=headers
|
||||
)
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
expected = {"name": "DB_PASSWORD", "secret": None}
|
||||
body = resp.json()
|
||||
self.assertDictEqual(body, expected)
|
||||
|
||||
def test_update_secret_auto(self) -> None:
|
||||
"""Test the update_secret API.
|
||||
|
||||
Tests updating a secret with auto-generated string.
|
||||
"""
|
||||
test_data = [
|
||||
TestClientSpec(
|
||||
"db_server",
|
||||
{
|
||||
"DB_PASSWORD": "test",
|
||||
},
|
||||
),
|
||||
]
|
||||
with api_context(test_data) as context:
|
||||
app.dependency_overrides[get_app_settings] = context.get_settings
|
||||
testclient: TestClient = TestClient(app)
|
||||
headers = self.unlock(context, testclient)
|
||||
request = {"secret": None, "auto_generate": True}
|
||||
resp = testclient.put(
|
||||
"/api/v1/secrets/DB_PASSWORD", json=request, headers=headers
|
||||
)
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
body = resp.json()
|
||||
secret = body.get("secret")
|
||||
self.assertIsNotNone(secret)
|
||||
|
||||
def test_delete_secret(self) -> None:
|
||||
"""Test delete_secret API."""
|
||||
test_data = [
|
||||
TestClientSpec(
|
||||
"webserver",
|
||||
{
|
||||
"API_KEY": "test",
|
||||
"OTHER_API_KEY": "test2",
|
||||
},
|
||||
),
|
||||
]
|
||||
with api_context(test_data) as context:
|
||||
app.dependency_overrides[get_app_settings] = context.get_settings
|
||||
testclient: TestClient = TestClient(app)
|
||||
headers = self.unlock(context, testclient)
|
||||
resp = testclient.delete("/api/v1/secrets/OTHER_API_KEY", headers=headers)
|
||||
self.assertEqual(resp.status_code, 204)
|
||||
get_resp = testclient.get("/api/v1/secrets/OTHER_API_KEY", headers=headers)
|
||||
self.assertEqual(get_resp.status_code, 404)
|
||||
|
||||
def fetch_client(
|
||||
self, testclient: TestClient, client_name: str
|
||||
) -> ClientSpecification:
|
||||
"""Fetch a client."""
|
||||
client_resp = testclient.get(f"/api/v1/clients/{client_name}")
|
||||
self.assertEqual(client_resp.status_code, 200)
|
||||
client_dict = client_resp.json()
|
||||
client = ClientSpecification.model_validate(client_dict)
|
||||
return client
|
||||
|
||||
def unlock(self, context: TestContext, testclient: TestClient) -> dict[str, str]:
|
||||
"""Unlock the session."""
|
||||
response = testclient.post(
|
||||
"/api/v1/auth/unlock", json={"password": context.master_password}
|
||||
)
|
||||
body = response.json()
|
||||
session_id = body["session_id"]
|
||||
session_header = {"session-id": str(session_id)}
|
||||
return session_header
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user