From de46a47accb6b98849b57fa419fab4a80a344518 Mon Sep 17 00:00:00 2001 From: Allan Eising Date: Sun, 23 Mar 2025 16:06:10 +0100 Subject: [PATCH] Refactor --- src/sshecret/backends/__init__.py | 5 + .../file_table.py} | 4 +- src/sshecret/client.py | 23 ++ src/sshecret/dev_cli.py | 21 +- src/sshecret/server/async_server.py | 122 +++++++ src/sshecret/server/server.py | 309 ------------------ src/sshecret/server/types.py | 137 -------- src/sshecret/types.py | 28 ++ tests/test_client_backend.py | 20 +- tests/test_crypto.py | 4 +- 10 files changed, 208 insertions(+), 465 deletions(-) create mode 100644 src/sshecret/backends/__init__.py rename src/sshecret/{server/client_loader.py => backends/file_table.py} (98%) create mode 100644 src/sshecret/client.py create mode 100644 src/sshecret/server/async_server.py delete mode 100644 src/sshecret/server/server.py delete mode 100644 src/sshecret/server/types.py diff --git a/src/sshecret/backends/__init__.py b/src/sshecret/backends/__init__.py new file mode 100644 index 0000000..9e79583 --- /dev/null +++ b/src/sshecret/backends/__init__.py @@ -0,0 +1,5 @@ +"""Backend implementations""" + +from .file_table import FileTableBackend + +__all__ = ["FileTableBackend"] diff --git a/src/sshecret/server/client_loader.py b/src/sshecret/backends/file_table.py similarity index 98% rename from src/sshecret/server/client_loader.py rename to src/sshecret/backends/file_table.py index e4125d6..c376193 100644 --- a/src/sshecret/server/client_loader.py +++ b/src/sshecret/backends/file_table.py @@ -1,4 +1,4 @@ -"""Client loaders.""" +"""File table based backend.""" import logging import os @@ -9,7 +9,7 @@ import littletable as lt from sshecret.crypto import load_client_key, encrypt_string from sshecret.types import ClientSpecification -from .types import BaseClientBackend +from sshecret.types import BaseClientBackend LOG = logging.getLogger(__name__) diff --git a/src/sshecret/client.py b/src/sshecret/client.py new file mode 100644 index 0000000..4773241 --- /dev/null +++ b/src/sshecret/client.py @@ -0,0 +1,23 @@ +"""Client code""" + +from typing import TextIO +import click + +from sshecret.crypto import decode_string, load_private_key + +def decrypt_secret(encoded: str, client_key: str) -> str: + """Decrypt secret.""" + private_key = load_private_key(client_key) + return decode_string(encoded, private_key) + + +@click.command() +@click.argument("keyfile", type=click.Path(exists=True, readable=True, dir_okay=False)) +@click.argument("encrypted_input", type=click.File("r")) +def cli_decrypt(keyfile: str, encrypted_input: TextIO) -> None: + """Decrypt on command line.""" + decrypted = decrypt_secret(encrypted_input.read(), keyfile) + click.echo(decrypted) + +if __name__ == "__main__": + cli_decrypt() diff --git a/src/sshecret/dev_cli.py b/src/sshecret/dev_cli.py index 34a6f41..a5c5c7d 100644 --- a/src/sshecret/dev_cli.py +++ b/src/sshecret/dev_cli.py @@ -1,5 +1,8 @@ """Development CLI commands.""" +import sys +import asyncio +import asyncssh import click import logging @@ -7,8 +10,8 @@ import tempfile import threading from pathlib import Path -from .server import SshKeyServer -from .server.client_loader import FileTableBackend +from .server.async_server import start_server +from sshecret.backends import FileTableBackend from .utils import create_client_file, add_secret_to_client_file @@ -50,15 +53,21 @@ def add_secret(filename: str, secret_name: str, secret_value: str) -> None: @cli.command("server") @click.argument("directory", type=click.Path(file_okay=False, dir_okay=True)) @click.argument("port", type=click.INT) -def run_server(directory: str, port: int) -> None: - """Run server.""" +def run_async_server(directory: str, port: int) -> None: + """Run async server.""" + loop = asyncio.new_event_loop() with tempfile.TemporaryDirectory() as tmpdir: serverdir = Path(tmpdir) - host_key = serverdir / "hostkey" + host_key = str(serverdir / "hostkey") clientdir = Path(directory) backend = FileTableBackend(clientdir) - SshKeyServer.start_server(host_key, clients=backend, port=port, create_key=True) + try: + loop.run_until_complete(start_server(port, backend, host_key, True)) + except (OSError, asyncssh.Error) as exc: + click.echo(f"Error starting server: {exc}") + sys.exit(1) + loop.run_forever() diff --git a/src/sshecret/server/async_server.py b/src/sshecret/server/async_server.py new file mode 100644 index 0000000..ad984ed --- /dev/null +++ b/src/sshecret/server/async_server.py @@ -0,0 +1,122 @@ +"""Server implemented with asyncssh.""" + +import logging +from functools import partial +from pathlib import Path +from typing import override + +import asyncssh + +from sshecret import constants +from sshecret.types import ClientSpecification, BaseClientBackend +from sshecret.crypto import create_private_rsa_key + + +LOG = logging.getLogger(__name__) + + +def handle_client(process: asyncssh.SSHServerProcess[str]) -> None: + """Handle client.""" + client_found = process.get_extra_info("client_allowed", False) + if not client_found: + process.stderr.write(constants.ERROR_UKNOWN_CLIENT_OR_SECRET + "\n") + process.exit(1) + return + + client_allowed = process.get_extra_info("client_allowed", False) + if not client_allowed: + process.stderr.write(constants.ERROR_SOURCE_IP_NOT_ALLOWED + "\n") + process.exit(1) + return + + client = process.get_extra_info("client") + if not client: + process.stderr.write(constants.ERROR_UKNOWN_CLIENT_OR_SECRET + "\n") + process.exit(1) + return + + secret_name = process.command + if not secret_name: + process.stderr.write(constants.ERROR_UKNOWN_CLIENT_OR_SECRET + "\n") + process.exit(1) + return + + LOG.debug("Client %s successfully connected. Fetching secret %s", client.name, secret_name) + + secret = client.secrets.get(secret_name) + if not secret: + process.stderr.write(constants.ERROR_UKNOWN_CLIENT_OR_SECRET + "\n") + process.exit(1) + return + + process.stdout.write(secret) + process.exit(0) + +class AsshyncServer(asyncssh.SSHServer): + """Asynchronous SSH server implementation.""" + + def __init__(self, backend: BaseClientBackend) -> None: + """Initialize server.""" + self.backend: BaseClientBackend = backend + self._conn: asyncssh.SSHServerConnection | None = None + + @override + def connection_made(self, conn: asyncssh.SSHServerConnection) -> None: + """Handle incoming connection.""" + self._conn = conn + + @override + def begin_auth(self, username: str) -> bool: + """Begin authentication.""" + if not self._conn: + return True + client = self.backend.lookup_name(username) + if not client: + return True + self._conn.set_extra_info(client_found=True) + remote_ip = self._conn.get_extra_info("peername")[0] + LOG.debug("Remote_IP: %r", remote_ip) + assert isinstance(remote_ip, str) + if self.check_connection_allowed(client, remote_ip): + self._conn.set_extra_info(client_allowed=True) + self._conn.set_extra_info(client=client) + + # Load the key. + public_key = asyncssh.import_authorized_keys(client.public_key) + self._conn.set_authorized_keys(public_key) + + return True + + @override + def password_auth_supported(self) -> bool: + """Deny password authentication.""" + return False + + + def check_connection_allowed(self, client: ClientSpecification, source: str) -> bool: + """Check if client is allowed to request secrets.""" + LOG.debug( + "Checking if client is allowed to log in from %s", source + ) + if isinstance(client.allowed_ips, str) and client.allowed_ips == "*": + LOG.debug("Client has no restrictions on source IP address. Permitting.") + return True + if isinstance(client.allowed_ips, str) and "/" not in client.allowed_ips: + if source == client.allowed_ips: + LOG.debug("Client IP matches permitted address") + return True + + LOG.warning( + "Connection for client %s received from IP address %s that is not permitted.", + client.name, + source + ) + + return False + +async def start_server(port: int, backend: BaseClientBackend, host_key: str, create_key: bool = False) -> None: + """Start server.""" + server = partial(AsshyncServer, backend=backend) + if create_key: + create_private_rsa_key(Path(host_key)) + await asyncssh.create_server(server, '', port, server_host_keys=[host_key], process_factory=handle_client) diff --git a/src/sshecret/server/server.py b/src/sshecret/server/server.py deleted file mode 100644 index 273c3f7..0000000 --- a/src/sshecret/server/server.py +++ /dev/null @@ -1,309 +0,0 @@ -"""SSH Server implementation.""" - -import ipaddress -import logging -import threading -import socket -from contextvars import ContextVar, Context, copy_context -from pathlib import Path -from typing import Any, override - -import paramiko -from paramiko.common import ( - AUTH_SUCCESSFUL, - AUTH_FAILED, - OPEN_SUCCEEDED, - OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, -) - -from sshecret import constants -from sshecret.crypto import create_private_rsa_key -from sshecret.types import ClientSpecification - -from .types import BaseClientBackend, BaseServer - -from . import errors - -CLIENT_STATE = threading.local() - -client_secret_request_name: ContextVar[str] = ContextVar("client_secret_request_name") -client_request_name: ContextVar[str] = ContextVar("client_request_name") - -LOG = logging.getLogger(__name__) - - -class TransportContext(paramiko.Transport): - """Context-aware transport.""" - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.__context: Context | None = None - - def snapshot_context(self) -> None: - """Take a snapshot of the current context.""" - LOG.debug("Snapshot!") - self.__context = copy_context() - #object.__setattr__(self, "__context", copy_context()) - - def get_context(self) -> None: - """Get the frozen context into our current one.""" - if contextobj := self.__context: - for var, value in contextobj.items(): - var.set(value) - -class SshServerInterface(paramiko.ServerInterface): - """Define our ssh server interface.""" - - def __init__(self, clients: BaseClientBackend, client_address: str) -> None: - """Initialize server interface.""" - self.clients: BaseClientBackend = clients - self.client_address: str = client_address - self.event: threading.Event = threading.Event() - - @override - def check_auth_publickey(self, username: str, key: paramiko.PKey) -> int: - """Check if we can authenticate.""" - LOG.debug("Verifying public key of username %s", username) - if self.clients.verify_key(username, key): - LOG.debug("Key matches configured key.") - return AUTH_SUCCESSFUL - LOG.warning("Key did not match. Auth denied!") - return AUTH_FAILED - - @override - def get_allowed_auths(self, username: str) -> str: - """Get allowed auth methods.""" - return "publickey" - - @override - def check_channel_request(self, kind: str, chanid: int) -> int: - LOG.debug("Open channel request received: kind=%s, chanid=%r", kind, chanid) - if kind == "session": - LOG.debug("Session requested.") - return OPEN_SUCCEEDED - LOG.warning("Prohibited channel request received.") - return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED - - @override - def check_channel_shell_request(self, channel: paramiko.Channel) -> bool: - """Check shell request.""" - # This shouldn't be allowed. - LOG.debug("Channel request received. Channel: %r", channel) - return False - - - @override - def check_channel_exec_request( - self, channel: paramiko.Channel, command: bytes - ) -> bool: - """Check if the exec request is valid. - - This is where we send the password. The command is always the name of - the secret. - - """ - LOG.debug("Exec request received: command: %r", command) - # Documentation says command is a string, but typeshed says it's bytes... - command_str = command.decode() - transport = channel.get_transport() - if not transport.is_authenticated(): - return False - username = transport.get_username() - LOG.debug("Resolved username: %r", username) - if not username: - return False - - if not isinstance(channel.transport, TransportContext): - LOG.critical("Error: Incorrect transport class. Cannot process commands.") - self.event.set() - return False - client_secret_request_name.set(command_str) - client_request_name.set(username) - channel.transport.snapshot_context() - self.event.set() - LOG.debug("Command check completed.") - return True - - def check_allowed_client_ip(self, client: ClientSpecification) -> bool: - """Check if the client is allowed to log in based on source IP.""" - LOG.debug( - "Checking if client is allowed to log in from %s", self.client_address - ) - if isinstance(client.allowed_ips, str) and client.allowed_ips == "*": - LOG.debug("Client has no restrictions on source IP address. Permitting.") - return True - if isinstance(client.allowed_ips, str): - if self.client_address == client.allowed_ips: - LOG.debug("Client IP matches permitted address") - return True - LOG.warning( - "Connection for client %s received from IP address %s that is not permitted.", - client.name, - self.client_address, - ) - return False - client_ip = ipaddress.ip_address(self.client_address) - if client_ip in client.allowed_ips: - LOG.debug("Client IP matches permitted address") - return True - LOG.warning( - "Connection for client %s received from IP address %s that is not permitted.", - client.name, - self.client_address, - ) - return False - - -class SshKeyServer(BaseServer): - """SSH secrets server.""" - - def __init__(self, host_key: Path, clients: BaseClientBackend) -> None: - """Create server instance.""" - super().__init__() - self._host_key: paramiko.RSAKey = paramiko.RSAKey.from_private_key_file( - str(host_key), None - ) - self.clients: BaseClientBackend = clients - - def resolve_client(self) -> ClientSpecification: - """Resolve client.""" - LOG.debug("Looking up client data.") - client_name = client_request_name.get(None) - if not client_name: - LOG.debug("No context data was resolved.") - raise errors.UnknownClientError(constants.ERROR_NO_COMMAND_RECEIVED) - - client = self.clients.lookup_name(str(client_name)) - if not client: - raise errors.UnknownClientError(constants.ERROR_UKNOWN_CLIENT_OR_SECRET) - return client - - def get_secret(self, client: ClientSpecification) -> str: - """Get command.""" - LOG.debug("Looking up secret as requested.") - secret_name = client_secret_request_name.get(None) - if not secret_name: - raise errors.UnknownSecretError(constants.ERROR_UKNOWN_CLIENT_OR_SECRET) - - secret = client.secrets.get(str(secret_name)) - if not secret: - raise errors.UnknownSecretError(constants.ERROR_NO_SECRET_FOUND) - return secret - - def check_connection_allowed(self, client: ClientSpecification, source: str) -> None: - """Check if client is allowed to request secrets.""" - LOG.debug( - "Checking if client is allowed to log in from %s", source - ) - if isinstance(client.allowed_ips, str) and client.allowed_ips == "*": - LOG.debug("Client has no restrictions on source IP address. Permitting.") - return - if isinstance(client.allowed_ips, str) and "/" not in client.allowed_ips: - if source == client.allowed_ips: - LOG.debug("Client IP matches permitted address") - return - - LOG.warning( - "Connection for client %s received from IP address %s that is not permitted.", - client.name, - source - ) - - raise errors.AccessPolicyViolationError(constants.ERROR_SOURCE_IP_NOT_ALLOWED) - - source_ip = ipaddress.ip_address(source) - permitted = False - for client_ip in client.allowed_ips: - if isinstance(client_ip, (ipaddress.IPv4Network, ipaddress.IPv6Network)): - if source_ip in client_ip: - permitted = True - break - else: - if source_ip == client_ip: - permitted = True - break - if not permitted: - raise errors.AccessPolicyViolationError(constants.ERROR_SOURCE_IP_NOT_ALLOWED) - - LOG.debug("Matched client to permitted IP address statement.") - return - - - @override - def connection_function( - self, client_socket: socket.socket, client_address: str - ) -> None: - """Run on connection.""" - LOG.debug("Connection function called by %s", client_address) - try: - session = TransportContext(client_socket) - session.add_server_key(self._host_key) - server = SshServerInterface(self.clients, client_address) - try: - session.start_server(server=server) - except paramiko.SSHException: - return - - channel = session.accept(30) - if not channel: - LOG.debug("No channel opened!") - return - - LOG.debug("Got channel: %r, transport: %r, ", channel, channel.transport) - server.event.wait() - - LOG.debug("Opening channel file") - stdout = channel.makefile("rw") - - LOG.debug("Extracting context.") - session.get_context() - try: - LOG.debug("Looking up client.") - client = self.resolve_client() - LOG.debug("Checking source address policy.") - self.check_connection_allowed(client, client_address) - LOG.debug("Looking up secret.") - secret = self.get_secret(client) - except errors.BaseSshecretServerError as e: - error_message = f"{e}\n" - LOG.critical(e, exc_info=True) - stdout.write(error_message) - session.close() - return - - stdout.write(secret) - session.close() - except Exception as e: - LOG.critical(e, exc_info=True) - - @classmethod - def start_server( - cls, - host_key: str | Path, - clients: BaseClientBackend, - bind_address: str = "127.0.0.1", - port: int = 22, - timeout: int = 10, - create_key: bool = False, - ) -> None: - """Start the server. - - Args: - host_key: path to the private host key (str or Path) - clients: Client secret loader instance. - bind_address: address to bind to (default: 127.0.0.1) - port: Port to bind to (default: 22) - timeout: Socket timeout, default 1 second - create_key: Create the private key if it doesn't exist. - - """ - if isinstance(host_key, str): - host_key = Path(host_key) - - if not host_key.exists(): - if create_key: - create_private_rsa_key(host_key) - else: - raise RuntimeError("Error: provided host key does not exist.") - server = cls(host_key, clients) - server.start(bind_address, port, timeout) diff --git a/src/sshecret/server/types.py b/src/sshecret/server/types.py deleted file mode 100644 index d1a9e4b..0000000 --- a/src/sshecret/server/types.py +++ /dev/null @@ -1,137 +0,0 @@ -"""Base types and interfaces for the server.""" - -import abc -import logging -import socket -import sys -import threading - -from typing import Any, TypeGuard -from cryptography.hazmat.primitives import serialization - -import paramiko -from sshecret.types import ClientSpecification -from sshecret.crypto import load_client_key - - -LOG = logging.getLogger(__name__) - - -def validate_socket_tuple(data: Any) -> TypeGuard[tuple[str, int]]: - """Validate socket accept return data..""" - if not isinstance(data, tuple): - return False - if not len(data) == 2: - return False - ip, port = data # pyright: ignore[reportUnknownVariableType] - if not isinstance(ip, str): - return False - if not isinstance(port, int): - return False - return True - - -class BaseClientBackend(abc.ABC): - """Base client backend. - - This class is responsible for managing the list of clients and facilitate - lookups. - """ - - @abc.abstractmethod - def lookup_name(self, name: str) -> ClientSpecification | None: - """Lookup a client specification by name.""" - - def _convert_to_pkey(self, client: ClientSpecification) -> paramiko.RSAKey: - """Convert client key to paramiko key.""" - client_key = load_client_key(client) - - return paramiko.RSAKey(key=client_key) - - def verify_key(self, name: str, key: paramiko.PKey) -> ClientSpecification | None: - """Verify key.""" - client = self.lookup_name(name) - if not client: - return None - LOG.debug("Verifying key: %r", key) - expected_key = self._convert_to_pkey(client) - if key == expected_key: - return client - return None - - @abc.abstractmethod - def add_secret(self, client_name: str, secret_name: str, secret_value: str, encrypted: bool = False) -> None: - """Add a secret to a client.""" - - @abc.abstractmethod - def add_client(self, spec: ClientSpecification) -> None: - """Add a new client.""" - - @abc.abstractmethod - def update_client(self, name: str, spec: ClientSpecification) -> None: - """Update client information.""" - - @abc.abstractmethod - def remove_client(self, name: str, persistent: bool = True) -> None: - """Delete a client.""" - - -class BaseServer(abc.ABC): - """Base SSH server.""" - - def __init__(self) -> None: - """Initialize the server.""" - self._is_running: threading.Event = threading.Event() - self._socket: socket.socket | None = None - self._listen_thread: threading.Thread | None = None - - def start( - self, address: str = "127.0.0.1", port: int = 22, timeout: int = 1 - ) -> None: - """Start the server.""" - if not self._is_running.is_set(): - LOG.info("Starting SSH server on %s port %s", address, port) - self._is_running.set() - self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True) - if sys.platform == "linux" or sys.platform == "linux2": - self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, True) - LOG.debug("Setting socket timeout %s", timeout) - self._socket.settimeout(timeout) - LOG.debug("Binding to %s:%s", address, port) - self._socket.bind((address, port)) - LOG.debug("Spawning thread.") - self._listen_thread = threading.Thread(target=self._listen) - self._listen_thread.start() - - def stop(self) -> None: - """Stop the server.""" - if self._is_running.is_set(): - self._is_running.clear() - if self._listen_thread: - self._listen_thread.join() - if self._socket: - self._socket.close() - - @abc.abstractmethod - def connection_function( - self, client_socket: socket.socket, client_address: str - ) -> None: - """Run function on connect.""" - - def _listen(self) -> None: - """Connect client to function.""" - if self._socket is None: - raise RuntimeError("Received connection request without any socket") - - while self._is_running.is_set(): - try: - self._socket.listen() - client_socket, addr = self._socket.accept() # pyright: ignore[reportAny] - if not validate_socket_tuple(addr): - LOG.warning("Socket address tuple did not pass typeguard check!") - continue - LOG.debug("Received connection from %r", addr) - self.connection_function(client_socket, addr[0]) - except socket.timeout: - pass diff --git a/src/sshecret/types.py b/src/sshecret/types.py index b8be5d7..9d1754a 100644 --- a/src/sshecret/types.py +++ b/src/sshecret/types.py @@ -83,3 +83,31 @@ class ClientSpecification(BaseModel): allowed_ips: list[IPvAnyAddress | IPvAnyNetwork] | str = "*" secrets: dict[str, str] = {} testing_private_key: str | None = None # Private key only for testing purposes! + + +class BaseClientBackend(abc.ABC): + """Base client backend. + + This class is responsible for managing the list of clients and facilitate + lookups. + """ + + @abc.abstractmethod + def lookup_name(self, name: str) -> ClientSpecification | None: + """Lookup a client specification by name.""" + + @abc.abstractmethod + def add_secret(self, client_name: str, secret_name: str, secret_value: str, encrypted: bool = False) -> None: + """Add a secret to a client.""" + + @abc.abstractmethod + def add_client(self, spec: ClientSpecification) -> None: + """Add a new client.""" + + @abc.abstractmethod + def update_client(self, name: str, spec: ClientSpecification) -> None: + """Update client information.""" + + @abc.abstractmethod + def remove_client(self, name: str, persistent: bool = True) -> None: + """Delete a client.""" diff --git a/tests/test_client_backend.py b/tests/test_client_backend.py index bec9a80..23550a9 100644 --- a/tests/test_client_backend.py +++ b/tests/test_client_backend.py @@ -1,7 +1,8 @@ """Tests of client loader.""" +# pyright: reportUninitializedInstanceVariable=false, reportImplicitOverride=false import unittest -from sshecret.server import client_loader +from sshecret.backends import FileTableBackend from sshecret.utils import generate_client_object from sshecret.testing import TestClientSpec, test_context @@ -11,7 +12,7 @@ class TestFileTableBackend(unittest.TestCase): def setUp(self) -> None: """Set up tests.""" - self.test_dataset = [ + self.test_dataset: list[TestClientSpec] = [ TestClientSpec("webserver", {"SECRET_TOKEN": "mysecrettoken"}), TestClientSpec("dbserver", {"DB_ROOT_PASSWORD": "mysecretpassword"}), ] @@ -19,21 +20,22 @@ class TestFileTableBackend(unittest.TestCase): def test_init(self) -> None: """Test instance creation.""" with test_context(self.test_dataset) as testdir: - backend = client_loader.FileTableBackend(testdir) + backend = FileTableBackend(testdir) self.assertGreater(len(backend.table), 0) def test_lookup_name(self) -> None: """Test lookup name.""" with test_context(self.test_dataset) as testdir: - backend = client_loader.FileTableBackend(testdir) + backend = FileTableBackend(testdir) webserver = backend.lookup_name("webserver") self.assertIsNotNone(webserver) + assert webserver is not None self.assertEqual(webserver.name, "webserver") def test_add_client(self) -> None: """Test whether it is possible to add a client.""" with test_context(self.test_dataset) as testdir: - backend = client_loader.FileTableBackend(testdir) + backend = FileTableBackend(testdir) new_client = generate_client_object( "backupserver", {"BACKUP_KEY": "mysecretbackupkey"} ) @@ -46,7 +48,7 @@ class TestFileTableBackend(unittest.TestCase): def test_add_secret(self) -> None: """Test whether it is possible to add a secret.""" with test_context(self.test_dataset) as testdir: - backend = client_loader.FileTableBackend(testdir) + backend = FileTableBackend(testdir) backend.add_secret("webserver", "OTHER_SECRET_TOKEN", "myothersecrettoken") webserver = backend.lookup_name("webserver") assert webserver is not None @@ -65,7 +67,7 @@ class TestFileTableBackend(unittest.TestCase): def test_update_client(self) -> None: """Test update_client method.""" with test_context(self.test_dataset) as testdir: - backend = client_loader.FileTableBackend(testdir) + backend = FileTableBackend(testdir) webserver = backend.lookup_name("webserver") assert webserver is not None webserver.allowed_ips = "192.0.2.1" @@ -77,7 +79,7 @@ class TestFileTableBackend(unittest.TestCase): def test_remove_client(self) -> None: """Test removal of client.""" with test_context(self.test_dataset) as testdir: - backend = client_loader.FileTableBackend(testdir) + backend = FileTableBackend(testdir) backend.remove_client("webserver", persistent=False) webserver = backend.lookup_name("webserver") self.assertIsNone(webserver) @@ -87,7 +89,7 @@ class TestFileTableBackend(unittest.TestCase): def test_remove_client_persistent(self) -> None: """Test removal of client.""" with test_context(self.test_dataset) as testdir: - backend = client_loader.FileTableBackend(testdir) + backend = FileTableBackend(testdir) backend.remove_client("webserver", persistent=True) webserver = backend.lookup_name("webserver") self.assertIsNone(webserver) diff --git a/tests/test_crypto.py b/tests/test_crypto.py index c67a28d..5b9ebc6 100644 --- a/tests/test_crypto.py +++ b/tests/test_crypto.py @@ -35,8 +35,8 @@ class TestBasicCrypto(unittest.TestCase): def test_key_loading(self) -> None: """Test basic flow.""" - public_key = load_public_key(self.public_key) - private_key = load_private_key(self.private_key) + load_public_key(self.public_key) + load_private_key(self.private_key) self.assertEqual(True, True) def test_encrypt_decrypt(self) -> None: