160 lines
6.2 KiB
Python
160 lines
6.2 KiB
Python
"""Test various exceptions and error conditions."""
|
|
|
|
from collections.abc import AsyncIterator
|
|
from contextlib import asynccontextmanager
|
|
import asyncssh
|
|
import pytest
|
|
|
|
from .types import ClientRegistry, CommandRunner, ProcessRunner, SshServerFixture
|
|
|
|
|
|
class BaseSshTests:
    """Base test class providing SSH client connection helpers.

    Both helpers connect to ``127.0.0.1`` on the given port and close the
    connection (waiting until it has finished closing) when the
    ``async with`` block exits, whether normally or via an exception.
    """

    @asynccontextmanager
    async def unregistered_client(
        self, username: str, port: int
    ) -> AsyncIterator[asyncssh.SSHClientConnection]:
        """Open an SSH session as an unregistered client.

        A fresh RSA key is generated on every call, so it is presumably not
        registered with the server.

        Args:
            username: Login name to present to the server.
            port: Local port the test SSH server listens on.

        Yields:
            The open client connection.
        """
        private_key = asyncssh.generate_private_key("ssh-rsa")
        async with self.ssh_connection(username, port, private_key) as conn:
            yield conn

    @asynccontextmanager
    async def ssh_connection(
        self, username: str, port: int, private_key: asyncssh.SSHKey
    ) -> AsyncIterator[asyncssh.SSHClientConnection]:
        """Open an SSH session authenticating with ``private_key``.

        Args:
            username: Login name to present to the server.
            port: Local port the test SSH server listens on.
            private_key: Client key offered for public-key authentication.

        Yields:
            The open client connection.

        Raises:
            asyncssh.misc.PermissionDenied: The tests in this module expect
                this from ``asyncssh.connect`` when the server refuses the
                client.
        """
        conn = await asyncssh.connect(
            "127.0.0.1",
            port=port,
            username=username,
            client_keys=[private_key],
            # Skip server host-key verification; the test server's host key
            # is presumably not in any known_hosts file.
            known_hosts=None,
        )
        try:
            yield conn
        finally:
            conn.close()
            await conn.wait_closed()
|
|
|
|
|
|
class TestRegistrationErrors(BaseSshTests):
    """Test class for errors related to registration."""

    # stderr the server is expected to emit when a non-RSA or unparseable
    # public key is submitted to ``register``.
    _RSA_ONLY_ERROR = "Error: Invalid key type: Only RSA keys are supported."

    async def _assert_registration_rejected(
        self,
        port: int,
        private_key: asyncssh.SSHKey,
        public_key: str,
        expected_error: str,
    ) -> None:
        """Run ``register`` as "stranger", submit ``public_key``, expect failure.

        Connects with ``private_key``. Asserts that the server prompts for a
        public key, that stderr (trailing whitespace stripped) equals
        ``expected_error``, and that the command exits with status 1.
        """
        async with self.ssh_connection("stranger", port, private_key) as conn:
            async with conn.create_process("register") as process:
                prompt = await process.stdout.readline()
                assert "Enter public key" in prompt
                stdout, stderr = await process.communicate(public_key)
                print(f"{stdout=!r}, {stderr=!r}")
                assert isinstance(stderr, str)
                assert stderr.rstrip() == expected_error
                result = await process.wait()
                assert result.exit_status == 1

    @pytest.mark.enable_registration(True)
    @pytest.mark.registration_sources("192.0.2.0/24")
    @pytest.mark.asyncio
    async def test_register_client_invalid_source(
        self, ssh_server: SshServerFixture
    ) -> None:
        """Test client registration from a network that's not permitted.

        Only 192.0.2.0/24 may register while the client connects from
        127.0.0.1, so the connection is expected to be refused with
        PermissionDenied.
        """
        _, port = ssh_server
        with pytest.raises(asyncssh.misc.PermissionDenied):
            async with self.unregistered_client("stranger", port) as conn:
                # Only reached if the server wrongly accepts the connection.
                async with conn.create_process("register") as process:
                    # communicate() sends EOF and waits for the process to
                    # finish, so all output is captured. collect_output()
                    # alone returns only what has arrived so far, which
                    # could make the check below pass vacuously.
                    stdout, stderr = await process.communicate()
                    print(f"{stdout=!r}\n{stderr=!r}")
                    if isinstance(stdout, str):
                        assert "Enter public key" not in stdout
                    result = await process.wait()
                    assert result.exit_status == 1

    @pytest.mark.enable_registration(True)
    @pytest.mark.registration_sources("127.0.0.1")
    @pytest.mark.asyncio
    async def test_invalid_key_type(self, ssh_server: SshServerFixture) -> None:
        """Test registration with an unsupported (ed25519) key."""
        _, port = ssh_server
        private_key = asyncssh.generate_private_key("ssh-ed25519")
        # Submit the connecting key's own public half as a single
        # newline-terminated line.
        public_key = private_key.export_public_key().decode().rstrip() + "\n"
        await self._assert_registration_rejected(
            port, private_key, public_key, self._RSA_ONLY_ERROR
        )

    @pytest.mark.enable_registration(True)
    @pytest.mark.registration_sources("127.0.0.1")
    @pytest.mark.asyncio
    async def test_invalid_key(self, ssh_server: SshServerFixture) -> None:
        """Test registration with a bogus string as key."""
        _, port = ssh_server
        private_key = asyncssh.generate_private_key("ssh-ed25519")
        public_key = f"ssh-test {'A' * 544}\n"
        await self._assert_registration_rejected(
            port, private_key, public_key, self._RSA_ONLY_ERROR
        )
|
|
|
|
|
|
class TestCommandErrors(BaseSshTests):
    """Check the server's error responses for bad, missing, or denied commands."""

    @pytest.mark.asyncio
    async def test_invalid_command(
        self, ssh_command_runner: CommandRunner, client_registry: ClientRegistry
    ) -> None:
        """Test sending an invalid command."""
        add_client = client_registry["add_client"]
        await add_client("test")

        outcome = await ssh_command_runner("test", "cat /etc/passwd")
        assert outcome.exit_status == 1

        # stderr may be None; normalise to "" before comparing.
        error_text = outcome.stderr or ""
        assert isinstance(error_text, str)
        assert error_text.rstrip() == "Error: Unsupported command."

    @pytest.mark.asyncio
    async def test_no_command(
        self, ssh_server: SshServerFixture, client_registry: ClientRegistry
    ) -> None:
        """Test sending no command."""
        _, port = ssh_server
        await client_registry["add_client"]("test")
        registered_key = client_registry["clients"]["test"].private_key

        async with self.ssh_connection("test", port, registered_key) as session:
            # create_process() is called without a command argument.
            async with session.create_process() as process:
                stdout, stderr = await process.communicate()
                print(f"{stdout=!r}, {stderr=!r}")
                assert isinstance(stderr, str)
                assert stderr.rstrip() == "Error: No command was received from the client."
                completed = await process.wait()
                assert completed.exit_status == 1

    @pytest.mark.asyncio
    async def test_deny_client_connection(
        self, ssh_command_runner: CommandRunner, client_registry: ClientRegistry
    ) -> None:
        """Test client that is not permitted to connect."""
        # Presumably (secret names, permitted source networks); the client
        # connects from outside 192.0.2.0/24 and should be refused.
        secret_names = ["mysecret"]
        source_networks = ["192.0.2.0/24"]
        await client_registry["add_client"]("test-client", secret_names, source_networks)

        with pytest.raises(asyncssh.misc.PermissionDenied):
            await ssh_command_runner("test-client", "get_secret mysecret")
|