210 lines
7.1 KiB
Python
210 lines
7.1 KiB
Python
"""Test library.
|
|
|
|
Strategy:
|
|
|
|
We start by spawning the backend server, and create two test keys.
|
|
|
|
Then we spawn the sshd and the admin api.
|
|
"""
|
|
|
|
import asyncio
|
|
import asyncssh
|
|
import secrets
|
|
import tempfile
|
|
from contextlib import asynccontextmanager
|
|
from pathlib import Path
|
|
|
|
import httpx
|
|
import pytest
|
|
import pytest_asyncio
|
|
|
|
import uvicorn
|
|
from sshecret.backend import SshecretBackend
|
|
from sshecret.crypto import (
|
|
generate_private_key,
|
|
generate_public_key_string,
|
|
write_private_key,
|
|
)
|
|
from sshecret_admin.core.app import create_admin_app
|
|
from sshecret_admin.core.settings import AdminServerSettings
|
|
from sshecret_backend.app import create_backend_app
|
|
from sshecret_backend.settings import BackendSettings
|
|
from sshecret_backend.testing import create_test_token
|
|
from sshecret_sshd.settings import ServerSettings
|
|
from sshecret_sshd.ssh_server import start_sshecret_sshd
|
|
|
|
from .clients import ClientData
|
|
from .helpers import create_sshd_server_key, create_test_admin_user, in_tempdir
|
|
from .types import PortFactory, TestPorts
|
|
|
|
TEST_SCOPE = "function"
|
|
LOOP_SCOPE = "function"
|
|
|
|
|
|
def make_test_key() -> str:
    """Return the public-key string of a freshly generated private key.

    The private half is not kept; only ``generate_public_key_string`` of its
    public key is returned.
    """
    public_key = generate_private_key().public_key()
    return generate_public_key_string(public_key)
|
|
|
|
|
|
@pytest.fixture(name="test_ports", scope="session")
|
|
def generate_test_ports(unused_tcp_port_factory: PortFactory) -> TestPorts:
|
|
"""Generate the test ports."""
|
|
test_ports = TestPorts(
|
|
backend=unused_tcp_port_factory(),
|
|
admin=unused_tcp_port_factory(),
|
|
sshd=unused_tcp_port_factory(),
|
|
)
|
|
print(f"{test_ports=!r}")
|
|
return test_ports
|
|
|
|
|
|
@pytest_asyncio.fixture(scope=TEST_SCOPE, name="backend_server", loop_scope=LOOP_SCOPE)
|
|
async def run_backend_server(test_ports: TestPorts):
|
|
"""Run the backend server."""
|
|
port = test_ports.backend
|
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
backend_work_path = Path(tmp_dir)
|
|
db_file = backend_work_path / "backend.db"
|
|
backend_settings = BackendSettings(database=str(db_file.absolute()))
|
|
backend_app = create_backend_app(backend_settings)
|
|
token = create_test_token(backend_settings)
|
|
config = uvicorn.Config(app=backend_app, port=port, loop="asyncio")
|
|
server = uvicorn.Server(config=config)
|
|
server_task = asyncio.create_task(server.serve())
|
|
await asyncio.sleep(0.1)
|
|
backend_url = f"http://127.0.0.1:{port}"
|
|
yield (backend_url, token)
|
|
server.should_exit = True
|
|
await server_task
|
|
|
|
|
|
@pytest_asyncio.fixture(scope=TEST_SCOPE, name="admin_server", loop_scope=LOOP_SCOPE)
|
|
async def run_admin_server(test_ports: TestPorts, backend_server: tuple[str, str]):
|
|
"""Run admin server."""
|
|
backend_url, backend_token = backend_server
|
|
secret_key = secrets.token_urlsafe(32)
|
|
port = test_ports.admin
|
|
with in_tempdir() as admin_work_path:
|
|
admin_db = admin_work_path / "ssh_admin.db"
|
|
admin_settings = AdminServerSettings.model_validate(
|
|
{
|
|
"sshecret_backend_url": backend_url,
|
|
"backend_token": backend_token,
|
|
"secret_key": secret_key,
|
|
"listen_address": "127.0.0.1",
|
|
"port": port,
|
|
"database": str(admin_db.absolute()),
|
|
"password_manager_directory": str(admin_work_path.absolute()),
|
|
}
|
|
)
|
|
admin_app = create_admin_app(admin_settings)
|
|
config = uvicorn.Config(app=admin_app, port=port, loop="asyncio")
|
|
server = uvicorn.Server(config=config)
|
|
server_task = asyncio.create_task(server.serve())
|
|
await asyncio.sleep(0.1)
|
|
admin_url = f"http://127.0.0.1:{port}"
|
|
admin_password = secrets.token_urlsafe(10)
|
|
create_test_admin_user(admin_settings, "test", admin_password)
|
|
await asyncio.sleep(0.1)
|
|
yield (admin_url, ("test", admin_password))
|
|
server.should_exit = True
|
|
await server_task
|
|
|
|
|
|
@pytest_asyncio.fixture(scope=TEST_SCOPE, name="ssh_server", loop_scope=LOOP_SCOPE)
|
|
async def start_ssh_server(test_ports: TestPorts, backend_server: tuple[str, str]):
|
|
"""Run ssh server."""
|
|
backend_url, backend_token = backend_server
|
|
port = test_ports.sshd
|
|
with in_tempdir() as ssh_workdir:
|
|
create_sshd_server_key(ssh_workdir)
|
|
sshd_server_settings = ServerSettings.model_validate(
|
|
{
|
|
"sshecret_backend_url": backend_url,
|
|
"backend_token": backend_token,
|
|
"listen_address": "",
|
|
"port": port,
|
|
"registration": {"enabled": True, "allow_from": "0.0.0.0/0"},
|
|
"enable_ping_command": True,
|
|
}
|
|
)
|
|
|
|
ssh_server = await start_sshecret_sshd(sshd_server_settings)
|
|
await asyncio.sleep(0.1)
|
|
print(f"Started sshd on port {port}")
|
|
yield port
|
|
|
|
ssh_server.close()
|
|
await ssh_server.wait_closed()
|
|
|
|
|
|
@pytest_asyncio.fixture(scope=TEST_SCOPE, name="backend_client", loop_scope=LOOP_SCOPE)
|
|
async def create_backend_http_client(backend_server: tuple[str, str]):
|
|
"""Create a test client."""
|
|
backend_url, backend_token = backend_server
|
|
print(f"Creating backend client towards {backend_url}")
|
|
async with httpx.AsyncClient(
|
|
base_url=backend_url, headers={"X-API-Token": backend_token}
|
|
) as client:
|
|
yield client
|
|
|
|
|
|
@pytest_asyncio.fixture(name="backend_api")
|
|
async def get_test_backend_api(backend_server: tuple[str, str]) -> SshecretBackend:
|
|
"""Get the backend API."""
|
|
backend_url, backend_token = backend_server
|
|
return SshecretBackend(backend_url, backend_token)
|
|
|
|
|
|
@pytest.fixture(scope=TEST_SCOPE)
def ssh_command_runner(ssh_server: int, tmp_path: Path):
    """Return a coroutine function that runs a single command over SSH.

    ``run_command_as(test_client, command)`` does the following:

    - writes the client's private key to ``tmp_path / "id_<name>"``;
    - connects to the test sshd at 127.0.0.1 as ``test_client.name``, with
      host-key checking disabled (``known_hosts=None``);
    - runs *command* and returns the result of ``conn.run``.

    The connection is closed and awaited before returning.
    """

    async def run_command_as(test_client: ClientData, command: str):
        key_path = tmp_path / f"id_{test_client.name}"
        write_private_key(test_client.private_key, key_path)
        # Leaving the ``async with`` closes the connection and waits for it.
        async with asyncssh.connect(
            "127.0.0.1",
            port=ssh_server,
            username=test_client.name,
            client_keys=[str(key_path)],
            known_hosts=None,
        ) as conn:
            return await conn.run(command)

    return run_command_as
|
|
|
|
|
|
@pytest.fixture(name="ssh_session", scope=TEST_SCOPE)
|
|
def create_ssh_session(ssh_server: int, tmp_path: Path):
|
|
"""Create a ssh Session."""
|
|
port = ssh_server
|
|
|
|
@asynccontextmanager
|
|
async def run_process(test_client: ClientData, command: str):
|
|
private_key_file = tmp_path / f"id_{test_client.name}"
|
|
write_private_key(test_client.private_key, private_key_file)
|
|
conn = await asyncssh.connect(
|
|
"127.0.0.1",
|
|
port=port,
|
|
username=test_client.name,
|
|
client_keys=[str(private_key_file)],
|
|
known_hosts=None,
|
|
)
|
|
try:
|
|
async with conn.create_process(command) as process:
|
|
yield process
|
|
finally:
|
|
conn.close()
|
|
await conn.wait_closed()
|
|
|
|
return run_process
|