"""Test various exceptions and error conditions."""

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager

import asyncssh
import pytest

from .types import ClientRegistry, CommandRunner, ProcessRunner, SshServerFixture


class BaseSshTests:
    """Base test class providing SSH connection helpers."""

    @asynccontextmanager
    async def unregistered_client(
        self, username: str, port: int
    ) -> AsyncIterator[asyncssh.SSHClientConnection]:
        """Open an SSH session as an unregistered client.

        A new RSA key is generated on every call, so the client presents a
        key that has not previously been registered with the server.

        Args:
            username: Username to authenticate as.
            port: Port of the SSH server on 127.0.0.1.

        Yields:
            The open client connection; it is closed when the context exits.
        """
        private_key = asyncssh.generate_private_key("ssh-rsa")
        async with self.ssh_connection(username, port, private_key) as conn:
            yield conn

    @asynccontextmanager
    async def ssh_connection(
        self, username: str, port: int, private_key: asyncssh.SSHKey
    ) -> AsyncIterator[asyncssh.SSHClientConnection]:
        """Open an SSH session authenticating with the given private key.

        Host key checking is disabled (``known_hosts=None``) because the
        server under test uses an ad-hoc host key.

        Args:
            username: Username to authenticate as.
            port: Port of the SSH server on 127.0.0.1.
            private_key: Client key offered for public-key authentication.

        Yields:
            The open client connection; it is closed, and its shutdown awaited,
            when the context exits, even if the body raised.
        """
        conn = await asyncssh.connect(
            "127.0.0.1",
            port=port,
            username=username,
            client_keys=[private_key],
            known_hosts=None,
        )
        try:
            yield conn
        finally:
            conn.close()
            await conn.wait_closed()


class TestRegistrationErrors(BaseSshTests):
    """Test class for errors related to registration."""

    @pytest.mark.enable_registration(True)
    @pytest.mark.registration_sources("192.0.2.0/24")
    @pytest.mark.asyncio
    async def test_register_client_invalid_source(
        self, ssh_server: SshServerFixture
    ) -> None:
        """Test client registration from a network that's not permitted."""
        _, port = ssh_server
        # Registration is only allowed from 192.0.2.0/24 (an RFC 5737
        # documentation range), so the connection from 127.0.0.1 is expected
        # to be refused. If it is, nothing below the connect is executed.
        with pytest.raises(asyncssh.misc.PermissionDenied):
            async with self.unregistered_client("stranger", port) as conn:
                async with conn.create_process("register") as process:
                    stdout, stderr = process.collect_output()
                    print(f"{stdout=!r}\n{stderr=!r}")
                    if isinstance(stdout, str):
                        assert "Enter public key" not in stdout
                    result = await process.wait()
                    assert result.exit_status == 1

    @pytest.mark.enable_registration(True)
    @pytest.mark.registration_sources("127.0.0.1")
    @pytest.mark.asyncio
    async def test_invalid_key_type(self, ssh_server: SshServerFixture) -> None:
        """Test registration with an unsupported (ed25519) key."""
        _, port = ssh_server
        private_key = asyncssh.generate_private_key("ssh-ed25519")
        # Submit the matching ed25519 public key as the one to register;
        # the server is expected to reject it because only RSA is accepted.
        public_key = private_key.export_public_key().decode().rstrip() + "\n"
        async with self.ssh_connection("stranger", port, private_key) as conn:
            async with conn.create_process("register") as process:
                output = await process.stdout.readline()
                assert "Enter public key" in output
                stdout, stderr = await process.communicate(public_key)
                print(f"{stdout=!r}, {stderr=!r}")
                assert stderr == "Error: Invalid key type: Only RSA keys are supported."
                result = await process.wait()
                assert result.exit_status == 1

    @pytest.mark.enable_registration(True)
    @pytest.mark.registration_sources("127.0.0.1")
    @pytest.mark.asyncio
    async def test_invalid_key(self, ssh_server: SshServerFixture) -> None:
        """Test registration with a bogus string as key."""
        _, port = ssh_server
        private_key = asyncssh.generate_private_key("ssh-ed25519")
        # Not a real key: an unknown "ssh-test" type followed by filler data.
        public_key = f"ssh-test {'A' * 544}\n"
        async with self.ssh_connection("stranger", port, private_key) as conn:
            async with conn.create_process("register") as process:
                output = await process.stdout.readline()
                assert "Enter public key" in output
                stdout, stderr = await process.communicate(public_key)
                print(f"{stdout=!r}, {stderr=!r}")
                assert stderr == "Error: Invalid key type: Only RSA keys are supported."
                result = await process.wait()
                assert result.exit_status == 1


class TestCommandErrors(BaseSshTests):
    """Tests various errors around commands."""

    @pytest.mark.asyncio
    async def test_invalid_command(
        self, ssh_command_runner: CommandRunner, client_registry: ClientRegistry
    ) -> None:
        """Test sending an invalid command."""
        await client_registry["add_client"]("test")
        result = await ssh_command_runner("test", "cat /etc/passwd")
        assert result.exit_status == 1
        stderr = result.stderr or ""
        assert stderr == "Error: Unsupported command."

    @pytest.mark.asyncio
    async def test_no_command(
        self, ssh_server: SshServerFixture, client_registry: ClientRegistry
    ) -> None:
        """Test sending no command."""
        await client_registry["add_client"]("test")
        _, port = ssh_server
        client_key = client_registry["clients"]["test"]
        async with self.ssh_connection("test", port, client_key.private_key) as conn:
            # create_process() without a command requests a session with no
            # exec command.
            async with conn.create_process() as process:
                stdout, stderr = await process.communicate()
                print(f"{stdout=!r}, {stderr=!r}")
                assert stderr == "Error: No command was received from the client."
                result = await process.wait()
                assert result.exit_status == 1

    @pytest.mark.asyncio
    async def test_deny_client_connection(
        self, ssh_command_runner: CommandRunner, client_registry: ClientRegistry
    ) -> None:
        """Test client that is not permitted to connect."""
        # NOTE(review): the extra arguments appear to be the client's allowed
        # secrets and permitted source networks; verify against the fixture.
        # 192.0.2.0/24 is an RFC 5737 documentation range, which does not
        # include 127.0.0.1.
        await client_registry["add_client"](
            "test-client",
            ["mysecret"],
            ["192.0.2.0/24"],
        )
        with pytest.raises(asyncssh.misc.PermissionDenied):
            await ssh_command_runner("test-client", "get_secret mysecret")