Complete backend
This commit is contained in:
@ -1,6 +1,8 @@
|
||||
"""Tests of the backend api using pytest."""
|
||||
|
||||
import logging
|
||||
import random
|
||||
import string
|
||||
from httpx import Response
|
||||
import pytest
|
||||
|
||||
@ -8,8 +10,7 @@ from fastapi.testclient import TestClient
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
from sqlmodel.pool import StaticPool
|
||||
|
||||
from sshecret_backend.router import app
|
||||
from sshecret_backend.app import get_session
|
||||
from sshecret_backend.app import app, get_session
|
||||
from sshecret_backend.testing import create_test_token
|
||||
from sshecret_backend.models import AuditLog
|
||||
|
||||
@ -22,17 +23,30 @@ LOG.addHandler(handler)
|
||||
LOG.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
TEST_FINGERPRINT = "00:11:22:33:44:55:66:77:88:99:aa:bb:cc:dd:ee:ff"
|
||||
def make_test_key() -> str:
|
||||
"""Generate a test key."""
|
||||
randomlength = 540
|
||||
key = "ssh-rsa "
|
||||
randompart = "".join(
|
||||
random.choices(string.ascii_letters + string.digits, k=randomlength)
|
||||
)
|
||||
comment = " invalid-test-key"
|
||||
return key + randompart + comment
|
||||
|
||||
|
||||
def create_client(
|
||||
test_client: TestClient,
|
||||
headers: dict[str, str],
|
||||
name: str,
|
||||
fingerprint: str = TEST_FINGERPRINT,
|
||||
public_key: str | None = None,
|
||||
) -> Response:
|
||||
"""Create client."""
|
||||
data = {"name": name, "fingerprint": fingerprint}
|
||||
if not public_key:
|
||||
public_key = make_test_key()
|
||||
data = {
|
||||
"name": name,
|
||||
"public_key": public_key,
|
||||
}
|
||||
create_response = test_client.post("/api/v1/clients", headers=headers, json=data)
|
||||
return create_response
|
||||
|
||||
@ -95,10 +109,8 @@ def test_with_token(test_client: TestClient, token: str) -> None:
|
||||
def test_create_client(test_client: TestClient, headers: dict[str, str]) -> None:
|
||||
"""Test creating a client."""
|
||||
client_name = "test"
|
||||
client_fingerprint = TEST_FINGERPRINT
|
||||
create_response = create_client(
|
||||
test_client, headers, client_name, client_fingerprint
|
||||
)
|
||||
client_publickey = make_test_key()
|
||||
create_response = create_client(test_client, headers, client_name, client_publickey)
|
||||
assert create_response.status_code == 200
|
||||
response = test_client.get("/api/v1/clients/", headers=headers)
|
||||
assert response.status_code == 200
|
||||
@ -107,16 +119,35 @@ def test_create_client(test_client: TestClient, headers: dict[str, str]) -> None
|
||||
client = clients[0]
|
||||
assert isinstance(client, dict)
|
||||
assert client.get("name") == client_name
|
||||
assert client.get("fingerprint") == client_fingerprint
|
||||
assert client.get("created_at") is not None
|
||||
|
||||
|
||||
def test_delete_client(test_client: TestClient, headers: dict[str, str]) -> None:
|
||||
"""Test creating a client."""
|
||||
client_name = "test"
|
||||
create_response = create_client(
|
||||
test_client,
|
||||
headers,
|
||||
client_name,
|
||||
)
|
||||
assert create_response.status_code == 200
|
||||
|
||||
resp = test_client.delete("/api/v1/clients/test", headers=headers)
|
||||
assert resp.status_code == 200
|
||||
|
||||
resp = test_client.get("/api/v1/clients/test", headers=headers)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_add_secret(test_client: TestClient, headers: dict[str, str]) -> None:
|
||||
"""Test adding a secret to a client."""
|
||||
client_name = "test"
|
||||
client_fingerprint = "00:11:22:33:44:55:66:77:88:99:aa:bb:cc:dd:ee:ff"
|
||||
client_publickey = make_test_key()
|
||||
create_response = create_client(
|
||||
test_client, headers, client_name, client_fingerprint
|
||||
test_client,
|
||||
headers,
|
||||
client_name,
|
||||
client_publickey,
|
||||
)
|
||||
assert create_response.status_code == 200
|
||||
secret_name = "mysecret"
|
||||
@ -136,6 +167,17 @@ def test_add_secret(test_client: TestClient, headers: dict[str, str]) -> None:
|
||||
assert secret_body["secret"] == data["secret"]
|
||||
|
||||
|
||||
def test_delete_secret(test_client: TestClient, headers: dict[str, str]) -> None:
|
||||
"""Test deleting a secret."""
|
||||
test_add_secret(test_client, headers)
|
||||
resp = test_client.delete("/api/v1/clients/test/secrets/mysecret", headers=headers)
|
||||
assert resp.status_code == 200
|
||||
get_response = test_client.get(
|
||||
"/api/v1/clients/test/secrets/mysecret", headers=headers
|
||||
)
|
||||
assert get_response.status_code == 404
|
||||
|
||||
|
||||
def test_put_add_secret(test_client: TestClient, headers: dict[str, str]) -> None:
|
||||
"""Test adding secret via PUT."""
|
||||
# Use the test_create_client function to create a client.
|
||||
@ -178,7 +220,8 @@ def test_put_update_secret(test_client: TestClient, headers: dict[str, str]) ->
|
||||
|
||||
def test_audit_logging(test_client: TestClient, headers: dict[str, str]) -> None:
|
||||
"""Test audit logging."""
|
||||
create_client_resp = create_client(test_client, headers, "test")
|
||||
public_key = make_test_key()
|
||||
create_client_resp = create_client(test_client, headers, "test", public_key)
|
||||
assert create_client_resp.status_code == 200
|
||||
secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"}
|
||||
for name, secret in secrets.items():
|
||||
@ -251,7 +294,8 @@ def test_audit_log_filtering(
|
||||
|
||||
def test_secret_invalidation(test_client: TestClient, headers: dict[str, str]) -> None:
|
||||
"""Test secret invalidation."""
|
||||
create_client_resp = create_client(test_client, headers, "test")
|
||||
initial_key = make_test_key()
|
||||
create_client_resp = create_client(test_client, headers, "test", initial_key)
|
||||
assert create_client_resp.status_code == 200
|
||||
secrets = {"secret1": "foo", "secret2": "bar", "secret3": "baz"}
|
||||
for name, secret in secrets.items():
|
||||
@ -262,12 +306,13 @@ def test_secret_invalidation(test_client: TestClient, headers: dict[str, str]) -
|
||||
)
|
||||
assert add_resp.status_code == 200
|
||||
|
||||
# Update the fingerprint. This should cause all secrets to be invalidated
|
||||
# Update the public-key. This should cause all secrets to be invalidated
|
||||
# and no longer associated with a client.
|
||||
new_key = make_test_key()
|
||||
update_resp = test_client.post(
|
||||
"/api/v1/clients/test/update_fingerprint",
|
||||
"/api/v1/clients/test/public-key",
|
||||
headers=headers,
|
||||
json={"fingerprint": "foobar"},
|
||||
json={"public_key": new_key},
|
||||
)
|
||||
assert update_resp.status_code == 200
|
||||
|
||||
@ -277,3 +322,117 @@ def test_secret_invalidation(test_client: TestClient, headers: dict[str, str]) -
|
||||
client = get_resp.json()
|
||||
secrets = client.get("secrets")
|
||||
assert bool(secrets) is False
|
||||
|
||||
|
||||
def test_client_default_policies(
|
||||
test_client: TestClient, headers: dict[str, str]
|
||||
) -> None:
|
||||
"""Test client policies."""
|
||||
public_key = make_test_key()
|
||||
resp = create_client(test_client, headers, "test", public_key)
|
||||
assert resp.status_code == 200
|
||||
# Fetch policies, should return *
|
||||
resp = test_client.get("/api/v1/clients/test/policies/", headers=headers)
|
||||
assert resp.status_code == 200
|
||||
|
||||
policies = resp.json()
|
||||
|
||||
assert policies["sources"] == ["0.0.0.0/0", "::/0"]
|
||||
|
||||
|
||||
def test_client_policy_update_one(
|
||||
test_client: TestClient, headers: dict[str, str]
|
||||
) -> None:
|
||||
"""Update client policy with single policy."""
|
||||
public_key = make_test_key()
|
||||
resp = create_client(test_client, headers, "test", public_key)
|
||||
assert resp.status_code == 200
|
||||
|
||||
policy = ["192.0.2.1"]
|
||||
resp = test_client.put(
|
||||
"/api/v1/clients/test/policies/", headers=headers, json={"sources": policy}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
resp = test_client.get("/api/v1/clients/test/policies/", headers=headers)
|
||||
assert resp.status_code == 200
|
||||
|
||||
policies = resp.json()
|
||||
|
||||
assert policies["sources"] == policy
|
||||
|
||||
|
||||
def test_client_policy_update_advanced(
|
||||
test_client: TestClient, headers: dict[str, str]
|
||||
) -> None:
|
||||
"""Test other policy update scenarios."""
|
||||
public_key = make_test_key()
|
||||
resp = create_client(test_client, headers, "test", public_key)
|
||||
assert resp.status_code == 200
|
||||
|
||||
policy = ["192.0.2.1", "198.18.0.0/24"]
|
||||
|
||||
resp = test_client.put(
|
||||
"/api/v1/clients/test/policies/", headers=headers, json={"sources": policy}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
resp = test_client.get("/api/v1/clients/test/policies/", headers=headers)
|
||||
assert resp.status_code == 200
|
||||
|
||||
policies = resp.json()
|
||||
|
||||
assert "192.0.2.1" in policies["sources"]
|
||||
assert "198.18.0.0/24" in policies["sources"]
|
||||
|
||||
# Try to set it to something incorrect
|
||||
|
||||
policy = ["obviosly_wrong"]
|
||||
|
||||
resp = test_client.put(
|
||||
"/api/v1/clients/test/policies/", headers=headers, json={"sources": policy}
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
# Check that the old value is still there
|
||||
|
||||
resp = test_client.get("/api/v1/clients/test/policies/", headers=headers)
|
||||
assert resp.status_code == 200
|
||||
|
||||
policies = resp.json()
|
||||
|
||||
assert "192.0.2.1" in policies["sources"]
|
||||
assert "198.18.0.0/24" in policies["sources"]
|
||||
|
||||
# Clear the policies
|
||||
#
|
||||
|
||||
|
||||
def test_client_policy_update_unset(
|
||||
test_client: TestClient, headers: dict[str, str]
|
||||
) -> None:
|
||||
"""Test clearing the client policy."""
|
||||
public_key = make_test_key()
|
||||
resp = create_client(test_client, headers, "test", public_key)
|
||||
assert resp.status_code == 200
|
||||
policy = ["192.0.2.1", "198.18.0.0/24"]
|
||||
|
||||
resp = test_client.put(
|
||||
"/api/v1/clients/test/policies/", headers=headers, json={"sources": policy}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
policies = resp.json()
|
||||
|
||||
assert "192.0.2.1" in policies["sources"]
|
||||
assert "198.18.0.0/24" in policies["sources"]
|
||||
|
||||
# Now we clear the policies
|
||||
|
||||
resp = test_client.put(
|
||||
"/api/v1/clients/test/policies/", headers=headers, json={"sources": []}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
policies = resp.json()
|
||||
|
||||
assert policies["sources"] == ["0.0.0.0/0", "::/0"]
|
||||
|
||||
Reference in New Issue
Block a user