280 lines
9.1 KiB
Python
280 lines
9.1 KiB
Python
"""Tests of the backend api using pytest."""
|
|
|
|
import logging
|
|
from httpx import Response
|
|
import pytest
|
|
|
|
from fastapi.testclient import TestClient
|
|
from sqlmodel import Session, SQLModel, create_engine
|
|
from sqlmodel.pool import StaticPool
|
|
|
|
from sshecret_backend.router import app
|
|
from sshecret_backend.app import get_session
|
|
from sshecret_backend.testing import create_test_token
|
|
from sshecret_backend.models import AuditLog
|
|
|
|
|
|
# Root logger configuration for the test module: the level is set to DEBUG
# and a StreamHandler with a timestamped format is attached.
LOG = logging.getLogger()
LOG.setLevel(logging.DEBUG)

formatter = logging.Formatter("'%(asctime)s - %(levelname)s - %(message)s'")
handler = logging.StreamHandler()
handler.setFormatter(formatter)
LOG.addHandler(handler)

# Default fingerprint used by ``create_client`` and the tests below.
TEST_FINGERPRINT = "00:11:22:33:44:55:66:77:88:99:aa:bb:cc:dd:ee:ff"
|
|
|
|
|
|
def create_client(
    test_client: TestClient,
    headers: dict[str, str],
    name: str,
    fingerprint: str = TEST_FINGERPRINT,
) -> Response:
    """Create a client through the API and return the raw response.

    The name and fingerprint are sent as a JSON body in a POST request to
    ``/api/v1/clients`` using the supplied headers. The status code is not
    checked here; callers assert on the returned response themselves.
    """
    payload = {"name": name, "fingerprint": fingerprint}
    return test_client.post("/api/v1/clients", headers=headers, json=payload)
|
|
|
|
|
|
@pytest.fixture(name="session")
def session_fixture():
    """Yield a database session backed by a ``sqlite://`` engine.

    The engine uses ``StaticPool`` with ``check_same_thread`` disabled, and
    all tables registered on ``SQLModel.metadata`` are created before the
    session is handed to the test.
    """
    engine = create_engine(
        "sqlite://",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    SQLModel.metadata.create_all(engine)
    with Session(engine) as db_session:
        yield db_session
|
|
|
|
|
|
@pytest.fixture(name="token")
def token_fixture(session: Session):
    """Return a token produced by ``create_test_token`` for the session."""
    return create_test_token(session)
|
|
|
|
|
|
@pytest.fixture(name="headers")
def headers_fixture(token: str) -> dict[str, str]:
    """Build request headers that carry the token as ``X-API-Token``."""
    request_headers = {"X-API-Token": token}
    return request_headers
|
|
|
|
|
|
@pytest.fixture(name="test_client")
def test_client_fixture(session: Session):
    """Yield a ``TestClient`` for the app that uses the fixture session.

    The ``get_session`` dependency is overridden to return the ``session``
    fixture. All dependency overrides are cleared again once the test that
    uses this fixture has finished.
    """

    def _use_fixture_session():
        return session

    app.dependency_overrides[get_session] = _use_fixture_session
    api_client = TestClient(app)
    yield api_client
    app.dependency_overrides.clear()
|
|
|
|
|
|
def test_missing_token(test_client: TestClient) -> None:
    """Request the client list without any headers and expect HTTP 422."""
    status = test_client.get("/api/v1/clients/").status_code
    assert status == 422
|
|
|
|
|
|
def test_incorrect_token(test_client: TestClient) -> None:
    """Request the client list with a wrong token and expect HTTP 401."""
    bad_headers = {"X-API-Token": "WRONG"}
    response = test_client.get("/api/v1/clients/", headers=bad_headers)
    assert response.status_code == 401
|
|
|
|
|
|
def test_with_token(test_client: TestClient, token: str) -> None:
    """Request the client list with a valid token.

    Expects HTTP 200 and an empty collection, since no client has been
    created in this test.
    """
    auth_headers = {"X-API-Token": token}
    response = test_client.get("/api/v1/clients/", headers=auth_headers)
    assert response.status_code == 200
    listed = response.json()
    assert len(listed) == 0
|
|
|
|
|
|
def test_create_client(test_client: TestClient, headers: dict[str, str]) -> None:
    """Create a client and verify that it shows up in the client list.

    The first entry of the listing must be a dict carrying the name and
    fingerprint that were submitted, plus a non-null ``created_at``.
    """
    expected_name = "test"
    expected_fingerprint = TEST_FINGERPRINT
    created = create_client(
        test_client, headers, expected_name, expected_fingerprint
    )
    assert created.status_code == 200

    listing = test_client.get("/api/v1/clients/", headers=headers)
    assert listing.status_code == 200

    listed_clients = listing.json()
    assert isinstance(listed_clients, list)

    first = listed_clients[0]
    assert isinstance(first, dict)
    assert first.get("name") == expected_name
    assert first.get("fingerprint") == expected_fingerprint
    assert first.get("created_at") is not None
|
|
|
|
|
|
def test_add_secret(test_client: TestClient, headers: dict[str, str]) -> None:
    """Test adding a secret to a client and reading it back.

    A client is created with the shared ``TEST_FINGERPRINT``, a secret is
    POSTed to the client's secrets collection, and the secret is then
    fetched by name. The fetched body must carry the same name and value.
    """
    client_name = "test"
    create_response = create_client(
        test_client, headers, client_name, TEST_FINGERPRINT
    )
    assert create_response.status_code == 200

    secret_name = "mysecret"
    secret_value = "shhhh"
    data = {"name": secret_name, "secret": secret_value}

    # Build the URLs from the names above so they cannot drift apart.
    secrets_url = f"/api/v1/clients/{client_name}/secrets/"
    response = test_client.post(secrets_url, headers=headers, json=data)
    assert response.status_code == 200

    # Get it back
    get_response = test_client.get(f"{secrets_url}{secret_name}", headers=headers)
    assert get_response.status_code == 200

    secret_body = get_response.json()
    assert secret_body["name"] == data["name"]
    assert secret_body["secret"] == data["secret"]
|
|
|
|
|
|
def test_put_add_secret(test_client: TestClient, headers: dict[str, str]) -> None:
    """Test that a PUT on a secret's URL adds the secret.

    After removing the ``created_at`` and ``updated_at`` fields, the
    response body must equal the expected name/secret pair.
    """
    # Use the test_create_client function to create a client.
    test_create_client(test_client, headers)

    secret_name = "mysecret"
    secret_value = "shhhh"
    expected = {"name": secret_name, "secret": secret_value}

    put_response = test_client.put(
        f"/api/v1/clients/test/secrets/{secret_name}",
        headers=headers,
        json={"value": secret_value},
    )
    assert put_response.status_code == 200

    body = put_response.json()
    # Both timestamp fields must be present; pop() without a default raises
    # KeyError if either one is missing.
    body.pop("created_at")
    body.pop("updated_at")
    assert body == expected
|
|
|
|
|
|
def test_put_update_secret(test_client: TestClient, headers: dict[str, str]) -> None:
    """Test that a PUT on an existing secret replaces its value.

    ``test_add_secret`` is run first to create the client and the secret;
    the secret is then overwritten and the response body is checked.
    """
    test_add_secret(test_client, headers)

    new_value = "itsasecret"
    expected = {"name": "mysecret", "secret": new_value}

    update_response = test_client.put(
        "/api/v1/clients/test/secrets/mysecret",
        headers=headers,
        json={"value": new_value},
    )
    assert update_response.status_code == 200

    body = update_response.json()
    observed = {key: body[key] for key in ("name", "secret")}
    assert observed == expected

    # Ensure that the updated_at has been set.
    assert "updated_at" in body
|
|
|
|
|
|
def test_audit_logging(test_client: TestClient, headers: dict[str, str]) -> None:
    """Test that API activity yields audit log entries.

    A client is created, three secrets are added, and the client is
    fetched. The audit endpoint must then return a non-empty list in which
    every entry validates as an ``AuditLog``.
    """
    created = create_client(test_client, headers, "test")
    assert created.status_code == 200

    secret_values = {"secret1": "foo", "secret2": "bar", "secret3": "baz"}
    for secret_name, secret_value in secret_values.items():
        payload = {"name": secret_name, "secret": secret_value}
        added = test_client.post(
            "/api/v1/clients/test/secrets/", headers=headers, json=payload
        )
        assert added.status_code == 200

    # Fetch the entire client.
    fetched = test_client.get("/api/v1/clients/test", headers=headers)
    assert fetched.status_code == 200

    # Fetch the audit log
    audit_resp = test_client.get("/api/v1/audit/", headers=headers)
    assert audit_resp.status_code == 200

    raw_entries = audit_resp.json()
    assert len(raw_entries) > 0

    for raw_entry in raw_entries:
        # Let's try to reassemble the objects
        parsed = AuditLog.model_validate(raw_entry)
        assert parsed is not None
|
|
|
|
|
|
def test_audit_log_filtering(
    session: Session, test_client: TestClient, headers: dict[str, str]
) -> None:
    """Test audit log pagination and filtering by client name.

    Audit log rows are inserted directly through the session, spread evenly
    over a fixed number of client names. They are then read back through
    the API three ways: the first page, a second page via ``offset``, and a
    ``filter_client`` query. Each response is checked for HTTP 200 before
    its body is inspected.
    """
    # Create a lot of test data, but just manually.
    audit_log_amount = 150
    client_count = 5
    page_size = 100  # We get 100 at a time

    seeded: list[AuditLog] = [
        AuditLog(
            operation="TEST",
            object_id=str(i),
            client_name=f"client-{i % client_count}",
            message="Test Message",
        )
        for i in range(audit_log_amount)
    ]
    session.add_all(seeded)
    session.commit()

    audit_path = "/api/v1/audit/"

    first_page_resp = test_client.get(audit_path, headers=headers)
    assert first_page_resp.status_code == 200
    assert len(first_page_resp.json()) == page_size

    second_page_resp = test_client.get(
        audit_path, headers=headers, params={"offset": page_size}
    )
    assert second_page_resp.status_code == 200
    # There should be the remaining 50 + the two requests we made
    assert len(second_page_resp.json()) == audit_log_amount - page_size + 2

    # Try to get a specific client
    # The seeded entries are spread evenly, so each client has 30 of them.
    filtered_resp = test_client.get(
        audit_path, headers=headers, params={"filter_client": "client-1"}
    )
    assert filtered_resp.status_code == 200
    assert len(filtered_resp.json()) == audit_log_amount // client_count
|
|
|
|
|
|
def test_secret_invalidation(test_client: TestClient, headers: dict[str, str]) -> None:
    """Test that changing a client's fingerprint drops its secrets.

    A client is created and given three secrets, then its fingerprint is
    updated. When the client is fetched afterwards, its ``secrets`` field
    must be empty or absent.
    """
    created = create_client(test_client, headers, "test")
    assert created.status_code == 200

    initial_secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"}
    for secret_name, secret_value in initial_secrets.items():
        payload = {"name": secret_name, "secret": secret_value}
        added = test_client.post(
            "/api/v1/clients/test/secrets/", headers=headers, json=payload
        )
        assert added.status_code == 200

    # Update the fingerprint. This should cause all secrets to be invalidated
    # and no longer associated with a client.
    updated = test_client.post(
        "/api/v1/clients/test/update_fingerprint",
        headers=headers,
        json={"fingerprint": "foobar"},
    )
    assert updated.status_code == 200

    # Fetch the client. The list of secrets should be empty.
    fetched = test_client.get("/api/v1/clients/test", headers=headers)
    assert fetched.status_code == 200

    remaining_secrets = fetched.json().get("secrets")
    assert not remaining_secrets
|