Update tests
This commit is contained in:
@ -6,34 +6,25 @@ import pytest
|
|||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
from sshecret.backend import Client
|
||||||
|
|
||||||
class TestAdminAPI:
|
from sshecret.crypto import generate_private_key, generate_public_key_string
|
||||||
"""Tests of the Admin REST API."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
from .types import AdminServer
|
||||||
async def test_health_check(
|
|
||||||
self, admin_server: tuple[str, tuple[str, str]]
|
|
||||||
) -> None:
|
|
||||||
"""Test admin login."""
|
|
||||||
async with self.http_client(admin_server, False) as client:
|
|
||||||
resp = await client.get("/health")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_admin_login(self, admin_server: tuple[str, tuple[str, str]]) -> None:
|
|
||||||
"""Test admin login."""
|
|
||||||
|
|
||||||
async with self.http_client(admin_server, False) as client:
|
def make_test_key() -> str:
|
||||||
resp = await client.get("api/v1/clients/")
|
"""Generate a test key."""
|
||||||
assert resp.status_code == 401
|
private_key = generate_private_key()
|
||||||
|
return generate_public_key_string(private_key.public_key())
|
||||||
|
|
||||||
async with self.http_client(admin_server, True) as client:
|
|
||||||
resp = await client.get("api/v1/clients/")
|
class BaseAdminTests:
|
||||||
assert resp.status_code == 200
|
"""Base admin test class."""
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def http_client(
|
async def http_client(
|
||||||
self, admin_server: tuple[str, tuple[str, str]], authenticate: bool = True
|
self, admin_server: AdminServer, authenticate: bool = True
|
||||||
) -> AsyncIterator[httpx.AsyncClient]:
|
) -> AsyncIterator[httpx.AsyncClient]:
|
||||||
"""Run a client towards the admin rest api."""
|
"""Run a client towards the admin rest api."""
|
||||||
admin_url, credentials = admin_server
|
admin_url, credentials = admin_server
|
||||||
@ -53,3 +44,179 @@ class TestAdminAPI:
|
|||||||
|
|
||||||
async with httpx.AsyncClient(base_url=admin_url, headers=headers) as client:
|
async with httpx.AsyncClient(base_url=admin_url, headers=headers) as client:
|
||||||
yield client
|
yield client
|
||||||
|
|
||||||
|
async def create_client(
|
||||||
|
self,
|
||||||
|
admin_server: AdminServer,
|
||||||
|
name: str,
|
||||||
|
public_key: str | None = None,
|
||||||
|
) -> Client:
|
||||||
|
"""Create a client."""
|
||||||
|
if not public_key:
|
||||||
|
public_key = make_test_key()
|
||||||
|
|
||||||
|
new_client = {
|
||||||
|
"name": name,
|
||||||
|
"public_key": public_key,
|
||||||
|
"sources": ["192.0.2.0/24"],
|
||||||
|
}
|
||||||
|
|
||||||
|
async with self.http_client(admin_server, True) as http_client:
|
||||||
|
response = await http_client.post("api/v1/clients/", json=new_client)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
client = Client.model_validate(data)
|
||||||
|
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
class TestAdminAPI(BaseAdminTests):
|
||||||
|
"""Tests of the Admin REST API."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_health_check(
|
||||||
|
self, admin_server: tuple[str, tuple[str, str]]
|
||||||
|
) -> None:
|
||||||
|
"""Test admin login."""
|
||||||
|
async with self.http_client(admin_server, False) as client:
|
||||||
|
resp = await client.get("/health")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_admin_login(self, admin_server: AdminServer) -> None:
|
||||||
|
"""Test admin login."""
|
||||||
|
|
||||||
|
async with self.http_client(admin_server, False) as client:
|
||||||
|
resp = await client.get("api/v1/clients/")
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
async with self.http_client(admin_server, True) as client:
|
||||||
|
resp = await client.get("api/v1/clients/")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
class TestAdminApiClients(BaseAdminTests):
|
||||||
|
"""Test client routes."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_client(self, admin_server: AdminServer) -> None:
|
||||||
|
"""Test create_client."""
|
||||||
|
client = await self.create_client(admin_server, "testclient")
|
||||||
|
|
||||||
|
assert client.id is not None
|
||||||
|
assert client.name == "testclient"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_clients(self, admin_server: AdminServer) -> None:
|
||||||
|
"""Test get_clients."""
|
||||||
|
|
||||||
|
client_names = ["test-db", "test-app", "test-www"]
|
||||||
|
for name in client_names:
|
||||||
|
await self.create_client(admin_server, name)
|
||||||
|
async with self.http_client(admin_server) as http_client:
|
||||||
|
resp = await http_client.get("api/v1/clients/")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert isinstance(data, list)
|
||||||
|
assert len(data) == 3
|
||||||
|
for entry in data:
|
||||||
|
assert isinstance(entry, dict)
|
||||||
|
client_name = entry.get("name")
|
||||||
|
assert client_name in client_names
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_client(self, admin_server: AdminServer) -> None:
|
||||||
|
"""Test delete_client."""
|
||||||
|
await self.create_client(admin_server, name="testclient")
|
||||||
|
async with self.http_client(admin_server) as http_client:
|
||||||
|
resp = await http_client.get("api/v1/clients/")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert isinstance(data, list)
|
||||||
|
assert len(data) == 1
|
||||||
|
assert data[0]["name"] == "testclient"
|
||||||
|
|
||||||
|
resp = await http_client.delete("/api/v1/clients/testclient")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
resp = await http_client.get("api/v1/clients/")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert isinstance(data, list)
|
||||||
|
assert len(data) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestAdminApiSecrets(BaseAdminTests):
|
||||||
|
"""Test secret management."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_secret(self, admin_server: AdminServer) -> None:
|
||||||
|
"""Test add_secret."""
|
||||||
|
await self.create_client(admin_server, name="testclient")
|
||||||
|
async with self.http_client(admin_server) as http_client:
|
||||||
|
data = {
|
||||||
|
"name": "testsecret",
|
||||||
|
"clients": ["testclient"],
|
||||||
|
"value": "secretstring",
|
||||||
|
}
|
||||||
|
resp = await http_client.post("api/v1/secrets/", json=data)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_secret(self, admin_server: AdminServer) -> None:
|
||||||
|
"""Test get_secret."""
|
||||||
|
await self.test_add_secret(admin_server)
|
||||||
|
async with self.http_client(admin_server) as http_client:
|
||||||
|
resp = await http_client.get("api/v1/secrets/testsecret")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert isinstance(data, dict)
|
||||||
|
assert data["name"] == "testsecret"
|
||||||
|
assert data["secret"] == "secretstring"
|
||||||
|
assert "testclient" in data["clients"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_secret_auto(self, admin_server: AdminServer) -> None:
|
||||||
|
"""Test adding a secret with an auto-generated value."""
|
||||||
|
await self.create_client(admin_server, name="testclient")
|
||||||
|
async with self.http_client(admin_server) as http_client:
|
||||||
|
data = {
|
||||||
|
"name": "testsecret",
|
||||||
|
"clients": ["testclient"],
|
||||||
|
"value": {"auto_generate": True, "length": 17},
|
||||||
|
}
|
||||||
|
resp = await http_client.post("api/v1/secrets/", json=data)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
resp = await http_client.get("api/v1/secrets/testsecret")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert isinstance(data, dict)
|
||||||
|
assert data["name"] == "testsecret"
|
||||||
|
assert len(data["secret"]) == 17
|
||||||
|
assert "testclient" in data["clients"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_secret(self, admin_server: AdminServer) -> None:
|
||||||
|
"""Test updating secrets."""
|
||||||
|
await self.test_add_secret_auto(admin_server)
|
||||||
|
async with self.http_client(admin_server) as http_client:
|
||||||
|
resp = await http_client.put(
|
||||||
|
"api/v1/secrets/testsecret",
|
||||||
|
json={"value": "secret"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
resp = await http_client.get("api/v1/secrets/testsecret")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["secret"] == "secret"
|
||||||
|
|
||||||
|
resp = await http_client.put(
|
||||||
|
"api/v1/secrets/testsecret",
|
||||||
|
json={"value": {"auto_generate": True, "length": 16}},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
resp = await http_client.get("api/v1/secrets/testsecret")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert len(data["secret"]) == 16
|
||||||
|
|||||||
@ -34,6 +34,7 @@ async def test_create_client(backend_api: SshecretBackend) -> None:
|
|||||||
|
|
||||||
assert clients[0].public_key == test_client.public_key
|
assert clients[0].public_key == test_client.public_key
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
async def test_create_secret(backend_api: SshecretBackend) -> None:
|
async def test_create_secret(backend_api: SshecretBackend) -> None:
|
||||||
"""Test creating secrets."""
|
"""Test creating secrets."""
|
||||||
test_client = create_test_client("test")
|
test_client = create_test_client("test")
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from .clients import ClientData
|
|||||||
|
|
||||||
PortFactory = Callable[[], int]
|
PortFactory = Callable[[], int]
|
||||||
|
|
||||||
|
AdminServer = tuple[str, tuple[str, str]]
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TestPorts:
|
class TestPorts:
|
||||||
|
|||||||
@ -1,12 +1,15 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import Any
|
from pydantic import IPvAnyNetwork
|
||||||
import pytest
|
import pytest
|
||||||
import uuid
|
import uuid
|
||||||
import asyncssh
|
import asyncssh
|
||||||
import tempfile
|
import tempfile
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
import pytest_asyncio
|
||||||
|
from pytest import FixtureRequest
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
from ipaddress import IPv4Network, IPv6Network
|
from ipaddress import IPv4Network, IPv6Network, ip_network
|
||||||
|
|
||||||
from sshecret_sshd.ssh_server import run_ssh_server
|
from sshecret_sshd.ssh_server import run_ssh_server
|
||||||
from sshecret_sshd.settings import ClientRegistrationSettings
|
from sshecret_sshd.settings import ClientRegistrationSettings
|
||||||
@ -32,14 +35,16 @@ def client_registry() -> ClientRegistry:
|
|||||||
) -> str:
|
) -> str:
|
||||||
private_key = asyncssh.generate_private_key("ssh-rsa")
|
private_key = asyncssh.generate_private_key("ssh-rsa")
|
||||||
public_key = private_key.export_public_key()
|
public_key = private_key.export_public_key()
|
||||||
clients[name] = ClientKey(name, private_key, public_key.decode().rstrip())
|
clients[name] = ClientKey(
|
||||||
|
name, private_key, public_key.decode().rstrip(), policies
|
||||||
|
)
|
||||||
secrets.update({(name, s): f"mocked-secret-{s}" for s in (secret_names or [])})
|
secrets.update({(name, s): f"mocked-secret-{s}" for s in (secret_names or [])})
|
||||||
return clients[name]
|
return clients[name]
|
||||||
|
|
||||||
return {"clients": clients, "secrets": secrets, "add_client": add_client}
|
return {"clients": clients, "secrets": secrets, "add_client": add_client}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest_asyncio.fixture(scope="function")
|
||||||
async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
|
async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
|
||||||
backend = MagicMock()
|
backend = MagicMock()
|
||||||
clients_data = client_registry["clients"]
|
clients_data = client_registry["clients"]
|
||||||
@ -48,13 +53,16 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
|
|||||||
async def get_client(name: str) -> Client | None:
|
async def get_client(name: str) -> Client | None:
|
||||||
client_key = clients_data.get(name)
|
client_key = clients_data.get(name)
|
||||||
if client_key:
|
if client_key:
|
||||||
|
policies = [IPv4Network("0.0.0.0/0"), IPv6Network("::/0")]
|
||||||
|
if client_key.policies:
|
||||||
|
policies = [ip_network(network) for network in client_key.policies]
|
||||||
response_model = Client(
|
response_model = Client(
|
||||||
id=uuid.uuid4(),
|
id=uuid.uuid4(),
|
||||||
name=name,
|
name=name,
|
||||||
description=f"Mock client {name}",
|
description=f"Mock client {name}",
|
||||||
public_key=client_key.public_key,
|
public_key=client_key.public_key,
|
||||||
secrets=[s for (c, s) in secrets_data if c == name],
|
secrets=[s for (c, s) in secrets_data if c == name],
|
||||||
policies=[IPv4Network("0.0.0.0/0"), IPv6Network("::/0")],
|
policies=policies,
|
||||||
)
|
)
|
||||||
return response_model
|
return response_model
|
||||||
return None
|
return None
|
||||||
@ -95,20 +103,31 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
|
|||||||
return backend
|
return backend
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest_asyncio.fixture(scope="function")
|
||||||
async def ssh_server(
|
async def ssh_server(
|
||||||
mock_backend: MagicMock, unused_tcp_port: int
|
request: FixtureRequest,
|
||||||
|
mock_backend: MagicMock,
|
||||||
|
unused_tcp_port: int,
|
||||||
) -> SshServerFixtureFun:
|
) -> SshServerFixtureFun:
|
||||||
port = unused_tcp_port
|
port = unused_tcp_port
|
||||||
|
|
||||||
private_key = asyncssh.generate_private_key("ssh-ed25519")
|
private_key = asyncssh.generate_private_key("ssh-ed25519")
|
||||||
key_str = private_key.export_private_key()
|
key_str = private_key.export_private_key()
|
||||||
|
registration_mark = request.node.get_closest_marker("enable_registration")
|
||||||
|
registration_enabled = registration_mark is not None
|
||||||
|
registration_source_mark = request.node.get_closest_marker("registration_sources")
|
||||||
|
allowed_from: list[IPvAnyNetwork] = []
|
||||||
|
if registration_source_mark:
|
||||||
|
for network in registration_source_mark.args:
|
||||||
|
allowed_from.append(ip_network(network))
|
||||||
|
else:
|
||||||
|
allowed_from = [IPv4Network("0.0.0.0/0")]
|
||||||
with tempfile.NamedTemporaryFile("w+", delete=True) as key_file:
|
with tempfile.NamedTemporaryFile("w+", delete=True) as key_file:
|
||||||
key_file.write(key_str.decode())
|
key_file.write(key_str.decode())
|
||||||
key_file.flush()
|
key_file.flush()
|
||||||
|
|
||||||
registration_settings = ClientRegistrationSettings(
|
registration_settings = ClientRegistrationSettings(
|
||||||
enabled=True, allow_from=[IPv4Network("0.0.0.0/0")]
|
enabled=registration_enabled,
|
||||||
|
allow_from=allowed_from,
|
||||||
)
|
)
|
||||||
server = await run_ssh_server(
|
server = await run_ssh_server(
|
||||||
backend=mock_backend,
|
backend=mock_backend,
|
||||||
|
|||||||
155
tests/packages/sshd/test_errors.py
Normal file
155
tests/packages/sshd/test_errors.py
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
"""Test various exceptions and error conditions."""
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
import asyncssh
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from .types import ClientRegistry, CommandRunner, ProcessRunner, SshServerFixture
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSshTests:
|
||||||
|
"""Base test class."""
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def unregistered_client(self, username: str, port: int):
|
||||||
|
"""Generate SSH session as an uregistered client."""
|
||||||
|
private_key = asyncssh.generate_private_key("ssh-rsa")
|
||||||
|
conn = await asyncssh.connect(
|
||||||
|
"127.0.0.1",
|
||||||
|
port=port,
|
||||||
|
username=username,
|
||||||
|
client_keys=[private_key],
|
||||||
|
known_hosts=None,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
yield conn
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
await conn.wait_closed()
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def ssh_connection(
|
||||||
|
self, username: str, port: int, private_key: asyncssh.SSHKey
|
||||||
|
):
|
||||||
|
"""Generate SSH session as a client with an ed25519 key."""
|
||||||
|
# private_key = asyncssh.generate_private_key("ssh-ed25519")
|
||||||
|
conn = await asyncssh.connect(
|
||||||
|
"127.0.0.1",
|
||||||
|
port=port,
|
||||||
|
username=username,
|
||||||
|
client_keys=[private_key],
|
||||||
|
known_hosts=None,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
yield conn
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
await conn.wait_closed()
|
||||||
|
|
||||||
|
|
||||||
|
class TestRegistrationErrors(BaseSshTests):
|
||||||
|
"""Test class for errors related to registartion."""
|
||||||
|
|
||||||
|
@pytest.mark.enable_registration(True)
|
||||||
|
@pytest.mark.registration_sources("192.0.2.0/24")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_client_invalid_source(
|
||||||
|
self, ssh_server: SshServerFixture
|
||||||
|
) -> None:
|
||||||
|
"""Test client registration from a network that's not permitted."""
|
||||||
|
_, port = ssh_server
|
||||||
|
with pytest.raises(asyncssh.misc.PermissionDenied):
|
||||||
|
async with self.unregistered_client("stranger", port) as conn:
|
||||||
|
async with conn.create_process("register") as process:
|
||||||
|
stdout, stderr = process.collect_output()
|
||||||
|
print(f"{stdout=!r}\n{stderr=!r}")
|
||||||
|
if isinstance(stdout, str):
|
||||||
|
assert "Enter public key" not in stdout
|
||||||
|
result = await process.wait()
|
||||||
|
assert result.exit_status == 1
|
||||||
|
|
||||||
|
@pytest.mark.enable_registration(True)
|
||||||
|
@pytest.mark.registration_sources("127.0.0.1")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invalid_key_type(self, ssh_server: SshServerFixture) -> None:
|
||||||
|
"""Test registration with an unsupported key."""
|
||||||
|
_, port = ssh_server
|
||||||
|
private_key = asyncssh.generate_private_key("ssh-ed25519")
|
||||||
|
public_key = private_key.export_public_key().decode().rstrip() + "\n"
|
||||||
|
|
||||||
|
async with self.ssh_connection("stranger", port, private_key) as conn:
|
||||||
|
async with conn.create_process("register") as process:
|
||||||
|
output = await process.stdout.readline()
|
||||||
|
assert "Enter public key" in output
|
||||||
|
stdout, stderr = await process.communicate(public_key)
|
||||||
|
print(f"{stdout=!r}, {stderr=!r}")
|
||||||
|
assert stderr == "Error: Invalid key type: Only RSA keys are supported."
|
||||||
|
result = await process.wait()
|
||||||
|
assert result.exit_status == 1
|
||||||
|
|
||||||
|
@pytest.mark.enable_registration(True)
|
||||||
|
@pytest.mark.registration_sources("127.0.0.1")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invalid_key(self, ssh_server: SshServerFixture) -> None:
|
||||||
|
"""Test registration with a bogus string as key.."""
|
||||||
|
_, port = ssh_server
|
||||||
|
private_key = asyncssh.generate_private_key("ssh-ed25519")
|
||||||
|
public_key = f"ssh-test {'A' * 544}\n"
|
||||||
|
|
||||||
|
async with self.ssh_connection("stranger", port, private_key) as conn:
|
||||||
|
async with conn.create_process("register") as process:
|
||||||
|
output = await process.stdout.readline()
|
||||||
|
assert "Enter public key" in output
|
||||||
|
stdout, stderr = await process.communicate(public_key)
|
||||||
|
print(f"{stdout=!r}, {stderr=!r}")
|
||||||
|
assert stderr == "Error: Invalid key type: Only RSA keys are supported."
|
||||||
|
result = await process.wait()
|
||||||
|
assert result.exit_status == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestCommandErrors(BaseSshTests):
|
||||||
|
"""Tests various errors around commands."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invalid_command(
|
||||||
|
self, ssh_command_runner: CommandRunner, client_registry: ClientRegistry
|
||||||
|
) -> None:
|
||||||
|
"""Test sending an invalid command."""
|
||||||
|
await client_registry["add_client"]("test")
|
||||||
|
|
||||||
|
result = await ssh_command_runner("test", "cat /etc/passwd")
|
||||||
|
|
||||||
|
assert result.exit_status == 1
|
||||||
|
stderr = result.stderr or ""
|
||||||
|
assert stderr == "Error: Unsupported command."
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_command(
|
||||||
|
self, ssh_server: SshServerFixture, client_registry: ClientRegistry
|
||||||
|
) -> None:
|
||||||
|
"""Test sending no command."""
|
||||||
|
await client_registry["add_client"]("test")
|
||||||
|
_, port = ssh_server
|
||||||
|
client_key = client_registry["clients"]["test"]
|
||||||
|
async with self.ssh_connection("test", port, client_key.private_key) as conn:
|
||||||
|
async with conn.create_process() as process:
|
||||||
|
stdout, stderr = await process.communicate()
|
||||||
|
print(f"{stdout=!r}, {stderr=!r}")
|
||||||
|
assert stderr == "Error: No command was received from the client."
|
||||||
|
result = await process.wait()
|
||||||
|
assert result.exit_status == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deny_client_connection(
|
||||||
|
self, ssh_command_runner: CommandRunner, client_registry: ClientRegistry
|
||||||
|
) -> None:
|
||||||
|
"""Test client that is not permitted to connect."""
|
||||||
|
await client_registry["add_client"](
|
||||||
|
"test-client",
|
||||||
|
["mysecret"],
|
||||||
|
["192.0.2.0/24"],
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(asyncssh.misc.PermissionDenied):
|
||||||
|
await ssh_command_runner("test-client", "get_secret mysecret")
|
||||||
@ -5,6 +5,7 @@ import pytest
|
|||||||
from .types import ClientRegistry, CommandRunner, ProcessRunner
|
from .types import ClientRegistry, CommandRunner, ProcessRunner
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.enable_registration(True)
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_register_client(
|
async def test_register_client(
|
||||||
ssh_session: ProcessRunner,
|
ssh_session: ProcessRunner,
|
||||||
|
|||||||
@ -31,6 +31,8 @@ class ClientKey:
|
|||||||
name: str
|
name: str
|
||||||
private_key: asyncssh.SSHKey
|
private_key: asyncssh.SSHKey
|
||||||
public_key: str
|
public_key: str
|
||||||
|
policies: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class AddClientFun(Protocol):
|
class AddClientFun(Protocol):
|
||||||
|
|||||||
Reference in New Issue
Block a user