"""Tests of the backend api using pytest."""

import logging
import random
import string
from collections.abc import Iterator

import pytest
from fastapi.testclient import TestClient
from httpx import Response
from sqlmodel import Session, SQLModel, create_engine
from sqlmodel.pool import StaticPool

from sshecret_backend.app import app, get_session
from sshecret_backend.testing import create_test_token
from sshecret_backend.models import AuditLog


LOG = logging.getLogger()
handler = logging.StreamHandler()
formatter = logging.Formatter("'%(asctime)s - %(levelname)s - %(message)s'")
handler.setFormatter(formatter)
LOG.addHandler(handler)
LOG.setLevel(logging.DEBUG)

# Policy sources the API reports for a client without explicit restrictions
# (asserted by test_client_default_policies and test_client_policy_update_unset).
DEFAULT_POLICY_SOURCES = ["0.0.0.0/0", "::/0"]


def make_test_key() -> str:
    """Generate a test key.

    Returns a string shaped like an OpenSSH public key line
    (``"ssh-rsa <body> <comment>"``) whose body is 540 random alphanumeric
    characters, i.e. not real key material.
    """
    randomlength = 540
    key = "ssh-rsa "
    randompart = "".join(
        random.choices(string.ascii_letters + string.digits, k=randomlength)
    )
    comment = " invalid-test-key"
    return key + randompart + comment


def create_client(
    test_client: TestClient,
    headers: dict[str, str],
    name: str,
    public_key: str | None = None,
) -> Response:
    """Create a client through the API.

    Args:
        test_client: Client used to issue the request.
        headers: Request headers (including the API token).
        name: Name of the client to create.
        public_key: Public key to register; a random test key is generated
            when omitted or empty.

    Returns:
        The raw response of ``POST /api/v1/clients``.
    """
    if not public_key:
        public_key = make_test_key()
    data = {
        "name": name,
        "public_key": public_key,
    }
    create_response = test_client.post("/api/v1/clients", headers=headers, json=data)
    return create_response


@pytest.fixture(name="session")
def session_fixture() -> Iterator[Session]:
    """Provide a session bound to a fresh in-memory SQLite database.

    ``StaticPool`` keeps a single shared connection so the whole test sees one
    in-memory database; ``check_same_thread`` is disabled because the app may
    use that connection from a different thread than the one that created it.
    """
    engine = create_engine(
        "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
    )
    try:
        SQLModel.metadata.create_all(engine)
        with Session(engine) as session:
            yield session
    finally:
        # Release the pooled connection (and with it the in-memory database).
        engine.dispose()


@pytest.fixture(name="token")
def token_fixture(session: Session):
    """Generate a token."""
    token = create_test_token(session)
    return token


@pytest.fixture(name="headers")
def headers_fixture(token: str) -> dict[str, str]:
    """Generate headers carrying the API token."""
    return {"X-API-Token": token}


@pytest.fixture(name="test_client")
def test_client_fixture(session: Session) -> Iterator[TestClient]:
    """Provide a TestClient whose ``get_session`` dependency yields *session*."""

    def get_session_override() -> Session:
        return session

    app.dependency_overrides[get_session] = get_session_override
    try:
        test_client = TestClient(app)
        yield test_client
    finally:
        # Overrides live on the module-level app, so always reset them.
        app.dependency_overrides.clear()


def test_missing_token(test_client: TestClient) -> None:
    """Test logging in with missing token."""
    response = test_client.get("/api/v1/clients/")
    assert response.status_code == 422


def test_incorrect_token(test_client: TestClient) -> None:
    """Test logging in with an incorrect token."""
    response = test_client.get("/api/v1/clients/", headers={"X-API-Token": "WRONG"})
    assert response.status_code == 401


def test_with_token(test_client: TestClient, token: str) -> None:
    """Test with a valid token."""
    response = test_client.get("/api/v1/clients/", headers={"X-API-Token": token})
    assert response.status_code == 200
    assert len(response.json()) == 0


def test_create_client(test_client: TestClient, headers: dict[str, str]) -> None:
    """Test creating a client."""
    client_name = "test"
    client_publickey = make_test_key()
    create_response = create_client(test_client, headers, client_name, client_publickey)
    assert create_response.status_code == 200
    response = test_client.get("/api/v1/clients/", headers=headers)
    assert response.status_code == 200
    clients = response.json()
    assert isinstance(clients, list)
    # Each test starts from an empty database, so exactly one client exists.
    assert len(clients) == 1
    client = clients[0]
    assert isinstance(client, dict)
    assert client.get("name") == client_name
    assert client.get("created_at") is not None


def test_delete_client(test_client: TestClient, headers: dict[str, str]) -> None:
    """Test deleting a client."""
    client_name = "test"
    create_response = create_client(
        test_client,
        headers,
        client_name,
    )
    assert create_response.status_code == 200
    resp = test_client.delete("/api/v1/clients/test", headers=headers)
    assert resp.status_code == 200

    resp = test_client.get("/api/v1/clients/test", headers=headers)
    assert resp.status_code == 404


def test_add_secret(test_client: TestClient, headers: dict[str, str]) -> None:
    """Test adding a secret to a client."""
    client_name = "test"
    client_publickey = make_test_key()
    create_response = create_client(
        test_client,
        headers,
        client_name,
        client_publickey,
    )
    assert create_response.status_code == 200
    secret_name = "mysecret"
    secret_value = "shhhh"

    data = {"name": secret_name, "secret": secret_value}
    response = test_client.post(
        "/api/v1/clients/test/secrets/", headers=headers, json=data
    )
    assert response.status_code == 200

    # Get it back
    get_response = test_client.get(
        "/api/v1/clients/test/secrets/mysecret", headers=headers
    )
    assert get_response.status_code == 200
    secret_body = get_response.json()
    assert secret_body["name"] == data["name"]
    assert secret_body["secret"] == data["secret"]


def test_delete_secret(test_client: TestClient, headers: dict[str, str]) -> None:
    """Test deleting a secret."""
    # Reuse test_add_secret to create client "test" with secret "mysecret".
    test_add_secret(test_client, headers)
    resp = test_client.delete("/api/v1/clients/test/secrets/mysecret", headers=headers)
    assert resp.status_code == 200
    get_response = test_client.get(
        "/api/v1/clients/test/secrets/mysecret", headers=headers
    )
    assert get_response.status_code == 404


def test_put_add_secret(test_client: TestClient, headers: dict[str, str]) -> None:
    """Test adding secret via PUT."""
    # Use the test_create_client function to create a client.
    test_create_client(test_client, headers)
    secret_name = "mysecret"
    secret_value = "shhhh"

    data = {"name": secret_name, "secret": secret_value}
    response = test_client.put(
        "/api/v1/clients/test/secrets/mysecret",
        headers=headers,
        json={"value": secret_value},
    )
    assert response.status_code == 200
    response_model = response.json()
    # Timestamps are server-generated; compare only the stable fields.
    del response_model["created_at"]
    del response_model["updated_at"]
    assert response_model == data


def test_put_update_secret(test_client: TestClient, headers: dict[str, str]) -> None:
    """Test updating a client secret."""
    test_add_secret(test_client, headers)
    new_value = "itsasecret"
    update_response = test_client.put(
        "/api/v1/clients/test/secrets/mysecret",
        headers=headers,
        json={"value": new_value},
    )
    assert update_response.status_code == 200
    expected = {"name": "mysecret", "secret": new_value}
    response_model = update_response.json()
    assert {
        "name": response_model["name"],
        "secret": response_model["secret"],
    } == expected

    # Ensure that the updated_at has been set.
    assert "updated_at" in response_model


def test_audit_logging(test_client: TestClient, headers: dict[str, str]) -> None:
    """Test audit logging."""
    public_key = make_test_key()
    create_client_resp = create_client(test_client, headers, "test", public_key)
    assert create_client_resp.status_code == 200
    secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"}
    for name, secret in secrets.items():
        add_resp = test_client.post(
            "/api/v1/clients/test/secrets/",
            headers=headers,
            json={"name": name, "secret": secret},
        )
        assert add_resp.status_code == 200

    # Fetch the entire client.
    get_client_resp = test_client.get("/api/v1/clients/test", headers=headers)
    assert get_client_resp.status_code == 200

    # Fetch the audit log
    audit_log_resp = test_client.get("/api/v1/audit/", headers=headers)
    assert audit_log_resp.status_code == 200
    audit_logs = audit_log_resp.json()
    assert len(audit_logs) > 0
    for entry in audit_logs:
        # Every returned entry must round-trip into the AuditLog model.
        audit_log = AuditLog.model_validate(entry)
        assert audit_log is not None


def test_audit_log_filtering(
    session: Session, test_client: TestClient, headers: dict[str, str]
) -> None:
    """Test audit log pagination and filtering by client name."""
    # Seed the audit log directly through the session instead of via the API.
    audit_log_amount = 150
    entries: list[AuditLog] = [
        AuditLog(
            operation="TEST",
            object_id=str(i),
            client_name=f"client-{i % 5}",
            message="Test Message",
        )
        for i in range(audit_log_amount)
    ]
    session.add_all(entries)
    session.commit()

    audit_path = "/api/v1/audit/"
    audit_log_resp = test_client.get(audit_path, headers=headers)
    assert audit_log_resp.status_code == 200
    first_page = audit_log_resp.json()
    assert len(first_page) == 100  # We get 100 at a time

    audit_log_resp = test_client.get(
        audit_path, headers=headers, params={"offset": 100}
    )
    assert audit_log_resp.status_code == 200
    second_page = audit_log_resp.json()
    # The remaining 50 seeded entries plus two more, presumably the audit
    # entries written by the two requests to the audit endpoint above.
    assert len(second_page) == 52

    # Each of the five client names was used for 150 / 5 = 30 seeded entries.
    audit_log_resp = test_client.get(
        audit_path, headers=headers, params={"filter_client": "client-1"}
    )
    assert audit_log_resp.status_code == 200
    client_entries = audit_log_resp.json()
    assert len(client_entries) == 30


def test_secret_invalidation(test_client: TestClient, headers: dict[str, str]) -> None:
    """Test secret invalidation."""
    initial_key = make_test_key()
    create_client_resp = create_client(test_client, headers, "test", initial_key)
    assert create_client_resp.status_code == 200
    secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"}
    for name, secret in secrets.items():
        add_resp = test_client.post(
            "/api/v1/clients/test/secrets/",
            headers=headers,
            json={"name": name, "secret": secret},
        )
        assert add_resp.status_code == 200

    # Update the public-key. This should cause all secrets to be invalidated
    # and no longer associated with a client.
    new_key = make_test_key()
    update_resp = test_client.post(
        "/api/v1/clients/test/public-key",
        headers=headers,
        json={"public_key": new_key},
    )
    assert update_resp.status_code == 200

    # Fetch the client. The list of secrets should be empty.
    get_resp = test_client.get("/api/v1/clients/test", headers=headers)
    assert get_resp.status_code == 200
    client = get_resp.json()
    client_secrets = client.get("secrets")
    assert bool(client_secrets) is False


def test_client_default_policies(
    test_client: TestClient, headers: dict[str, str]
) -> None:
    """Test client policies."""
    public_key = make_test_key()
    resp = create_client(test_client, headers, "test", public_key)
    assert resp.status_code == 200

    # Fetch policies; a new client should allow any source.
    resp = test_client.get("/api/v1/clients/test/policies/", headers=headers)
    assert resp.status_code == 200

    policies = resp.json()
    assert policies["sources"] == DEFAULT_POLICY_SOURCES


def test_client_policy_update_one(
    test_client: TestClient, headers: dict[str, str]
) -> None:
    """Update client policy with single policy."""
    public_key = make_test_key()
    resp = create_client(test_client, headers, "test", public_key)
    assert resp.status_code == 200

    policy = ["192.0.2.1"]

    resp = test_client.put(
        "/api/v1/clients/test/policies/", headers=headers, json={"sources": policy}
    )
    assert resp.status_code == 200

    resp = test_client.get("/api/v1/clients/test/policies/", headers=headers)
    assert resp.status_code == 200

    policies = resp.json()
    assert policies["sources"] == policy


def test_client_policy_update_advanced(
    test_client: TestClient, headers: dict[str, str]
) -> None:
    """Test other policy update scenarios."""
    public_key = make_test_key()
    resp = create_client(test_client, headers, "test", public_key)
    assert resp.status_code == 200

    policy = ["192.0.2.1", "198.18.0.0/24"]

    resp = test_client.put(
        "/api/v1/clients/test/policies/", headers=headers, json={"sources": policy}
    )
    assert resp.status_code == 200

    resp = test_client.get("/api/v1/clients/test/policies/", headers=headers)
    assert resp.status_code == 200

    policies = resp.json()
    assert "192.0.2.1" in policies["sources"]
    assert "198.18.0.0/24" in policies["sources"]

    # Try to set it to something incorrect
    policy = ["obviosly_wrong"]
    resp = test_client.put(
        "/api/v1/clients/test/policies/", headers=headers, json={"sources": policy}
    )
    assert resp.status_code == 422

    # Check that the old value is still there
    resp = test_client.get("/api/v1/clients/test/policies/", headers=headers)
    assert resp.status_code == 200

    policies = resp.json()
    assert "192.0.2.1" in policies["sources"]
    assert "198.18.0.0/24" in policies["sources"]


def test_client_policy_update_unset(
    test_client: TestClient, headers: dict[str, str]
) -> None:
    """Test clearing the client policy.

    Setting an empty source list should restore the default allow-all policy.
    """
    public_key = make_test_key()
    resp = create_client(test_client, headers, "test", public_key)
    assert resp.status_code == 200

    policy = ["192.0.2.1", "198.18.0.0/24"]

    resp = test_client.put(
        "/api/v1/clients/test/policies/", headers=headers, json={"sources": policy}
    )
    assert resp.status_code == 200

    policies = resp.json()
    assert "192.0.2.1" in policies["sources"]
    assert "198.18.0.0/24" in policies["sources"]

    # Now we clear the policies
    resp = test_client.put(
        "/api/v1/clients/test/policies/", headers=headers, json={"sources": []}
    )
    assert resp.status_code == 200
    policies = resp.json()
    assert policies["sources"] == DEFAULT_POLICY_SOURCES