Files
sshecret/tests/packages/sshd/test_errors.py
2025-05-12 07:47:38 +02:00

156 lines
6.0 KiB
Python

"""Test various exceptions and error conditions."""
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
import asyncssh
import pytest
from .types import ClientRegistry, CommandRunner, ProcessRunner, SshServerFixture
class BaseSshTests:
    """Base test class providing SSH client connection helpers."""

    @asynccontextmanager
    async def _open_connection(
        self, username: str, port: int, private_key: asyncssh.SSHKey
    ) -> AsyncIterator[asyncssh.SSHClientConnection]:
        """Connect to 127.0.0.1 on ``port`` and close the connection on exit.

        Shared implementation for the public context managers below, so the
        connect/teardown sequence lives in exactly one place.

        Args:
            username: User name to authenticate as.
            port: TCP port to connect to on 127.0.0.1.
            private_key: Key offered to the server for authentication.

        Yields:
            The established client connection.
        """
        conn = await asyncssh.connect(
            "127.0.0.1",
            port=port,
            username=username,
            client_keys=[private_key],
            # known_hosts=None turns off host key validation in asyncssh.
            known_hosts=None,
        )
        try:
            yield conn
        finally:
            # Always tear the connection down, even when the body raised.
            conn.close()
            await conn.wait_closed()

    @asynccontextmanager
    async def unregistered_client(
        self, username: str, port: int
    ) -> AsyncIterator[asyncssh.SSHClientConnection]:
        """Generate an SSH session as an unregistered client.

        A fresh RSA key is generated for every call and used to authenticate.

        Args:
            username: User name to authenticate as.
            port: TCP port to connect to on 127.0.0.1.

        Yields:
            The established client connection.
        """
        private_key = asyncssh.generate_private_key("ssh-rsa")
        async with self._open_connection(username, port, private_key) as conn:
            yield conn

    @asynccontextmanager
    async def ssh_connection(
        self, username: str, port: int, private_key: asyncssh.SSHKey
    ) -> AsyncIterator[asyncssh.SSHClientConnection]:
        """Generate an SSH session authenticated with the supplied key.

        Args:
            username: User name to authenticate as.
            port: TCP port to connect to on 127.0.0.1.
            private_key: Caller-provided key offered to the server; the key
                type is whatever the caller generated.

        Yields:
            The established client connection.
        """
        async with self._open_connection(username, port, private_key) as conn:
            yield conn
class TestRegistrationErrors(BaseSshTests):
    """Error scenarios related to client registration."""

    async def _assert_key_rejected(
        self, port: int, private_key: asyncssh.SSHKey, public_key: str
    ) -> None:
        """Run ``register`` as "stranger", submit ``public_key``, expect rejection.

        Checks that the prompt is shown, that stderr carries the RSA-only
        error message, and that the process exits with status 1.
        """
        async with self.ssh_connection("stranger", port, private_key) as conn:
            async with conn.create_process("register") as process:
                prompt = await process.stdout.readline()
                assert "Enter public key" in prompt
                # Names must stay `stdout`/`stderr`: the debug f-string below
                # prints them via the self-documenting `=` specifier.
                stdout, stderr = await process.communicate(public_key)
                print(f"{stdout=!r}, {stderr=!r}")
                assert stderr == "Error: Invalid key type: Only RSA keys are supported."
                outcome = await process.wait()
                assert outcome.exit_status == 1

    @pytest.mark.enable_registration(True)
    @pytest.mark.registration_sources("192.0.2.0/24")
    @pytest.mark.asyncio
    async def test_register_client_invalid_source(
        self, ssh_server: SshServerFixture
    ) -> None:
        """Registration from a source network that is not permitted is denied."""
        _server, port = ssh_server
        with pytest.raises(asyncssh.misc.PermissionDenied):
            async with self.unregistered_client("stranger", port) as conn:
                async with conn.create_process("register") as process:
                    stdout, stderr = process.collect_output()
                    print(f"{stdout=!r}\n{stderr=!r}")
                    # The key prompt must never be shown to this client.
                    assert not (
                        isinstance(stdout, str) and "Enter public key" in stdout
                    )
                    outcome = await process.wait()
                    assert outcome.exit_status == 1

    @pytest.mark.enable_registration(True)
    @pytest.mark.registration_sources("127.0.0.1")
    @pytest.mark.asyncio
    async def test_invalid_key_type(self, ssh_server: SshServerFixture) -> None:
        """Registering with a non-RSA (ed25519) public key is rejected."""
        _server, port = ssh_server
        private_key = asyncssh.generate_private_key("ssh-ed25519")
        exported = private_key.export_public_key().decode()
        public_key = exported.rstrip() + "\n"
        await self._assert_key_rejected(port, private_key, public_key)

    @pytest.mark.enable_registration(True)
    @pytest.mark.registration_sources("127.0.0.1")
    @pytest.mark.asyncio
    async def test_invalid_key(self, ssh_server: SshServerFixture) -> None:
        """Registering with a bogus string instead of a key is rejected."""
        _server, port = ssh_server
        private_key = asyncssh.generate_private_key("ssh-ed25519")
        public_key = f"ssh-test {'A' * 544}\n"
        await self._assert_key_rejected(port, private_key, public_key)
class TestCommandErrors(BaseSshTests):
    """Error scenarios around the commands a client sends."""

    @pytest.mark.asyncio
    async def test_invalid_command(
        self, ssh_command_runner: CommandRunner, client_registry: ClientRegistry
    ) -> None:
        """An unsupported command exits with status 1 and an error on stderr."""
        add_client = client_registry["add_client"]
        await add_client("test")
        outcome = await ssh_command_runner("test", "cat /etc/passwd")
        assert outcome.exit_status == 1
        # Treat a missing/empty stderr as the empty string before comparing.
        assert (outcome.stderr or "") == "Error: Unsupported command."

    @pytest.mark.asyncio
    async def test_no_command(
        self, ssh_server: SshServerFixture, client_registry: ClientRegistry
    ) -> None:
        """Opening a session without any command yields an error and status 1."""
        add_client = client_registry["add_client"]
        await add_client("test")
        _server, port = ssh_server
        # Looked up only after add_client has been awaited.
        key_entry = client_registry["clients"]["test"]
        async with self.ssh_connection("test", port, key_entry.private_key) as conn:
            async with conn.create_process() as process:
                # Names must stay `stdout`/`stderr`: the debug f-string below
                # prints them via the self-documenting `=` specifier.
                stdout, stderr = await process.communicate()
                print(f"{stdout=!r}, {stderr=!r}")
                assert stderr == "Error: No command was received from the client."
                outcome = await process.wait()
                assert outcome.exit_status == 1

    @pytest.mark.asyncio
    async def test_deny_client_connection(
        self, ssh_command_runner: CommandRunner, client_registry: ClientRegistry
    ) -> None:
        """A client registered for a different source network is denied."""
        add_client = client_registry["add_client"]
        # presumably the arguments are (name, secrets, allowed sources) —
        # confirm against the client_registry fixture.
        await add_client(
            "test-client",
            ["mysecret"],
            ["192.0.2.0/24"],
        )
        with pytest.raises(asyncssh.misc.PermissionDenied):
            await ssh_command_runner("test-client", "get_secret mysecret")