Files
sshecret/packages/sshecret-backend/tests/test_backend.py
2025-04-18 16:39:05 +02:00

439 lines
14 KiB
Python

"""Tests of the backend api using pytest."""
import logging
import random
import string
from httpx import Response
import pytest
from fastapi.testclient import TestClient
from sqlmodel import Session, SQLModel, create_engine
from sqlmodel.pool import StaticPool
from sshecret_backend.app import app, get_session
from sshecret_backend.testing import create_test_token
from sshecret_backend.models import AuditLog
# Send every log record from the root logger (DEBUG and above) to a stream
# handler, which writes to stderr by default.
LOG = logging.getLogger()
handler = logging.StreamHandler()
# Lines are rendered as "<time> - <LEVEL> - <message>". The previous format
# string accidentally wrapped each line in literal single quotes.
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
LOG.addHandler(handler)
LOG.setLevel(logging.DEBUG)
def make_test_key() -> str:
    """Return a fake OpenSSH-style RSA public key string for tests.

    The key body is 540 random ASCII letters and digits, so it only looks like
    a key; the trailing ``invalid-test-key`` comment marks it as test data.
    """
    alphabet = string.ascii_letters + string.digits
    body = "".join(random.choices(alphabet, k=540))
    return f"ssh-rsa {body} invalid-test-key"
def create_client(
    test_client: TestClient,
    headers: dict[str, str],
    name: str,
    public_key: str | None = None,
) -> Response:
    """POST a new client to the backend and return the raw response.

    Args:
        test_client: FastAPI test client used to issue the request.
        headers: Request headers (the tests pass the ``X-API-Token`` header).
        name: Name of the client to create.
        public_key: Public key to register. When falsy (``None`` or ``""``),
            a random fake key from ``make_test_key`` is used instead.

    Returns:
        The response of ``POST /api/v1/clients``. The status is not checked.
    """
    payload = {
        "name": name,
        "public_key": public_key or make_test_key(),
    }
    return test_client.post("/api/v1/clients", headers=headers, json=payload)
@pytest.fixture(name="session")
def session_fixture():
    """Yield a SQLModel session on a fresh, schema-initialised in-memory SQLite DB.

    ``StaticPool`` together with ``check_same_thread=False`` keeps one shared
    connection that may be used from other threads, so the app under test
    sees the same in-memory database as the test body. The engine is disposed
    on teardown so that connection is released deterministically rather than
    whenever the engine is garbage-collected.
    """
    engine = create_engine(
        "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
    )
    SQLModel.metadata.create_all(engine)
    try:
        with Session(engine) as session:
            yield session
    finally:
        engine.dispose()
@pytest.fixture(name="token")
def token_fixture(session: Session):
    """Return an API token produced by ``create_test_token`` on the test session."""
    return create_test_token(session)
@pytest.fixture(name="headers")
def headers_fixture(token: str) -> dict[str, str]:
    """Build request headers that authenticate with the test API token."""
    auth_headers: dict[str, str] = {}
    auth_headers["X-API-Token"] = token
    return auth_headers
@pytest.fixture(name="test_client")
def test_client_fixture(session: Session):
    """Yield a TestClient whose requests share the fixture's database session.

    The app's ``get_session`` dependency is overridden to return the test
    session. All dependency overrides are cleared again on teardown.
    """

    def _use_test_session():
        return session

    app.dependency_overrides[get_session] = _use_test_session
    yield TestClient(app)
    app.dependency_overrides.clear()
def test_missing_token(test_client: TestClient) -> None:
    """A request without an ``X-API-Token`` header is rejected with 422."""
    resp = test_client.get("/api/v1/clients/")
    assert resp.status_code == 422
def test_incorrect_token(test_client: TestClient) -> None:
    """Test that a request carrying an invalid token is rejected with 401."""
    response = test_client.get("/api/v1/clients/", headers={"X-API-Token": "WRONG"})
    assert response.status_code == 401
def test_with_token(test_client: TestClient, token: str) -> None:
    """A valid token is accepted, and a fresh database lists no clients."""
    auth = {"X-API-Token": token}
    resp = test_client.get("/api/v1/clients/", headers=auth)
    assert resp.status_code == 200
    listed = resp.json()
    assert len(listed) == 0
def test_create_client(test_client: TestClient, headers: dict[str, str]) -> None:
    """Test creating a client and reading it back from the client list.

    Other tests reuse this as setup; it leaves a single client named ``test``.
    """
    client_name = "test"
    client_publickey = make_test_key()
    create_response = create_client(test_client, headers, client_name, client_publickey)
    assert create_response.status_code == 200

    response = test_client.get("/api/v1/clients/", headers=headers)
    assert response.status_code == 200
    clients = response.json()
    assert isinstance(clients, list)
    # The database is fresh, so exactly one client must exist. Checking the
    # length first gives a clear failure instead of an IndexError.
    assert len(clients) == 1
    client = clients[0]
    assert isinstance(client, dict)
    assert client.get("name") == client_name
    assert client.get("created_at") is not None
def test_delete_client(test_client: TestClient, headers: dict[str, str]) -> None:
    """Test deleting a client; afterwards it must no longer be retrievable."""
    client_name = "test"
    create_response = create_client(test_client, headers, client_name)
    assert create_response.status_code == 200

    # Derive the URL from client_name so the two cannot drift apart.
    client_url = f"/api/v1/clients/{client_name}"
    resp = test_client.delete(client_url, headers=headers)
    assert resp.status_code == 200
    resp = test_client.get(client_url, headers=headers)
    assert resp.status_code == 404
def test_add_secret(test_client: TestClient, headers: dict[str, str]) -> None:
    """Add a secret to a client and read it back unchanged.

    Other tests call this directly as setup: it leaves a client named ``test``
    holding a secret ``mysecret`` whose value is ``shhhh``.
    """
    client_name = "test"
    resp = create_client(test_client, headers, client_name, make_test_key())
    assert resp.status_code == 200

    secret = {"name": "mysecret", "secret": "shhhh"}
    secrets_url = f"/api/v1/clients/{client_name}/secrets/"
    resp = test_client.post(secrets_url, headers=headers, json=secret)
    assert resp.status_code == 200

    # Read the secret back and compare it with what was stored.
    resp = test_client.get(secrets_url + secret["name"], headers=headers)
    assert resp.status_code == 200
    stored = resp.json()
    assert stored["name"] == secret["name"]
    assert stored["secret"] == secret["secret"]
def test_delete_secret(test_client: TestClient, headers: dict[str, str]) -> None:
    """Delete a secret and confirm a later lookup returns 404."""
    # Setup via test_add_secret: client "test" holding secret "mysecret".
    test_add_secret(test_client, headers)
    secret_url = "/api/v1/clients/test/secrets/mysecret"

    deleted = test_client.delete(secret_url, headers=headers)
    assert deleted.status_code == 200

    lookup = test_client.get(secret_url, headers=headers)
    assert lookup.status_code == 404
def test_put_add_secret(test_client: TestClient, headers: dict[str, str]) -> None:
    """PUT to a secret that does not exist yet creates it."""
    # test_create_client leaves behind a single client named "test".
    test_create_client(test_client, headers)
    secret_value = "shhhh"
    expected = {"name": "mysecret", "secret": secret_value}

    response = test_client.put(
        "/api/v1/clients/test/secrets/mysecret",
        headers=headers,
        json={"value": secret_value},
    )
    assert response.status_code == 200

    body = response.json()
    # Timestamps are not deterministic, so drop them before comparing.
    for timestamp_key in ("created_at", "updated_at"):
        del body[timestamp_key]
    assert body == expected
def test_put_update_secret(test_client: TestClient, headers: dict[str, str]) -> None:
    """Test overwriting an existing client secret via PUT."""
    # Setup via test_add_secret: client "test" holding secret "mysecret".
    test_add_secret(test_client, headers)
    new_value = "itsasecret"
    update_response = test_client.put(
        "/api/v1/clients/test/secrets/mysecret",
        headers=headers,
        json={"value": new_value},
    )
    assert update_response.status_code == 200
    expected = {"name": "mysecret", "secret": new_value}
    response_model = update_response.json()
    assert {
        "name": response_model["name"],
        "secret": response_model["secret"],
    } == expected
    # Ensure that updated_at has actually been set: a bare key-presence check
    # would also pass when the API returns an explicit null.
    assert response_model.get("updated_at") is not None
def test_audit_logging(test_client: TestClient, headers: dict[str, str]) -> None:
    """Client and secret operations produce audit entries that parse as AuditLog."""
    resp = create_client(test_client, headers, "test", make_test_key())
    assert resp.status_code == 200

    to_store = {"secret1": "foo", "secret2": "bar", "secret3": "baz"}
    for secret_name, secret_value in to_store.items():
        resp = test_client.post(
            "/api/v1/clients/test/secrets/",
            headers=headers,
            json={"name": secret_name, "secret": secret_value},
        )
        assert resp.status_code == 200

    # Read the whole client back, which is itself an audited operation.
    resp = test_client.get("/api/v1/clients/test", headers=headers)
    assert resp.status_code == 200

    resp = test_client.get("/api/v1/audit/", headers=headers)
    assert resp.status_code == 200
    log_rows = resp.json()
    assert len(log_rows) > 0
    # Every row returned by the API must round-trip into the AuditLog model.
    for row in log_rows:
        assert AuditLog.model_validate(row) is not None
def test_audit_log_filtering(
    session: Session, test_client: TestClient, headers: dict[str, str]
) -> None:
    """Test audit log pagination and filtering by client name.

    150 entries are inserted directly into the database, spread round-robin
    over five client names (``client-0`` .. ``client-4``), so each name owns 30.
    """
    audit_log_amount = 150
    client_count = 5
    seeded = [
        AuditLog(
            operation="TEST",
            object_id=str(i),
            client_name=f"client-{i % client_count}",
            message="Test Message",
        )
        for i in range(audit_log_amount)
    ]
    session.add_all(seeded)
    session.commit()

    audit_path = "/api/v1/audit/"
    resp = test_client.get(audit_path, headers=headers)
    assert resp.status_code == 200
    # The first page without an offset is expected to hold 100 entries.
    assert len(resp.json()) == 100

    resp = test_client.get(audit_path, headers=headers, params={"offset": 100})
    assert resp.status_code == 200
    # 50 remaining seeded entries, plus 2 that the original author attributes
    # to the requests made in this test. NOTE(review): confirm which requests
    # generate these two entries.
    assert len(resp.json()) == 52

    resp = test_client.get(
        audit_path, headers=headers, params={"filter_client": "client-1"}
    )
    assert resp.status_code == 200
    # Round-robin seeding gives each client 150 / 5 = 30 entries.
    assert len(resp.json()) == audit_log_amount // client_count
def test_secret_invalidation(test_client: TestClient, headers: dict[str, str]) -> None:
    """Test that replacing a client's public key detaches all of its secrets."""
    initial_key = make_test_key()
    create_client_resp = create_client(test_client, headers, "test", initial_key)
    assert create_client_resp.status_code == 200

    initial_secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"}
    for name, secret in initial_secrets.items():
        add_resp = test_client.post(
            "/api/v1/clients/test/secrets/",
            headers=headers,
            json={"name": name, "secret": secret},
        )
        assert add_resp.status_code == 200

    # Update the public key. Secrets were encrypted for the old key, so they
    # should be invalidated and no longer associated with the client.
    new_key = make_test_key()
    update_resp = test_client.post(
        "/api/v1/clients/test/public-key",
        headers=headers,
        json={"public_key": new_key},
    )
    assert update_resp.status_code == 200

    get_resp = test_client.get("/api/v1/clients/test", headers=headers)
    assert get_resp.status_code == 200
    client = get_resp.json()
    # The secrets field must be missing, null, or empty.
    assert not client.get("secrets")
def test_client_default_policies(
    test_client: TestClient, headers: dict[str, str]
) -> None:
    """A newly created client allows every IPv4 and IPv6 source by default."""
    created = create_client(test_client, headers, "test", make_test_key())
    assert created.status_code == 200

    fetched = test_client.get("/api/v1/clients/test/policies/", headers=headers)
    assert fetched.status_code == 200
    default_sources = ["0.0.0.0/0", "::/0"]
    assert fetched.json()["sources"] == default_sources
def test_client_policy_update_one(
    test_client: TestClient, headers: dict[str, str]
) -> None:
    """Setting a single-address policy replaces the default sources exactly."""
    created = create_client(test_client, headers, "test", make_test_key())
    assert created.status_code == 200

    policy_url = "/api/v1/clients/test/policies/"
    sources = ["192.0.2.1"]
    put_resp = test_client.put(policy_url, headers=headers, json={"sources": sources})
    assert put_resp.status_code == 200

    get_resp = test_client.get(policy_url, headers=headers)
    assert get_resp.status_code == 200
    assert get_resp.json()["sources"] == sources
def test_client_policy_update_advanced(
    test_client: TestClient, headers: dict[str, str]
) -> None:
    """Test a multi-source policy and that an invalid update is rejected.

    A rejected update (422) must leave the previously stored policy intact.
    Clearing a policy is covered separately by ``test_client_policy_update_unset``.
    """
    public_key = make_test_key()
    resp = create_client(test_client, headers, "test", public_key)
    assert resp.status_code == 200
    policy_url = "/api/v1/clients/test/policies/"

    policy = ["192.0.2.1", "198.18.0.0/24"]
    resp = test_client.put(policy_url, headers=headers, json={"sources": policy})
    assert resp.status_code == 200

    resp = test_client.get(policy_url, headers=headers)
    assert resp.status_code == 200
    policies = resp.json()
    assert "192.0.2.1" in policies["sources"]
    assert "198.18.0.0/24" in policies["sources"]

    # Try to set it to something that is not an address or network.
    resp = test_client.put(
        policy_url, headers=headers, json={"sources": ["obviosly_wrong"]}
    )
    assert resp.status_code == 422

    # Check that the old value is still there.
    resp = test_client.get(policy_url, headers=headers)
    assert resp.status_code == 200
    policies = resp.json()
    assert "192.0.2.1" in policies["sources"]
    assert "198.18.0.0/24" in policies["sources"]
def test_client_policy_update_unset(
    test_client: TestClient, headers: dict[str, str]
) -> None:
    """Clearing the source list restores the allow-everything default."""
    created = create_client(test_client, headers, "test", make_test_key())
    assert created.status_code == 200

    policy_url = "/api/v1/clients/test/policies/"
    restricted = ["192.0.2.1", "198.18.0.0/24"]
    set_resp = test_client.put(policy_url, headers=headers, json={"sources": restricted})
    assert set_resp.status_code == 200
    current = set_resp.json()["sources"]
    assert "192.0.2.1" in current
    assert "198.18.0.0/24" in current

    # An empty list clears the policy; the PUT response should show defaults.
    clear_resp = test_client.put(policy_url, headers=headers, json={"sources": []})
    assert clear_resp.status_code == 200
    assert clear_resp.json()["sources"] == ["0.0.0.0/0", "::/0"]