This commit is contained in:
2025-03-23 16:06:10 +01:00
parent bbf2d0b280
commit de46a47acc
10 changed files with 208 additions and 465 deletions

View File

@ -0,0 +1,5 @@
"""Backend implementations"""
from .file_table import FileTableBackend
__all__ = ["FileTableBackend"]

View File

@ -1,4 +1,4 @@
"""Client loaders.""" """File table based backend."""
import logging import logging
import os import os
@ -9,7 +9,7 @@ import littletable as lt
from sshecret.crypto import load_client_key, encrypt_string from sshecret.crypto import load_client_key, encrypt_string
from sshecret.types import ClientSpecification from sshecret.types import ClientSpecification
from .types import BaseClientBackend from sshecret.types import BaseClientBackend
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)

23
src/sshecret/client.py Normal file
View File

@ -0,0 +1,23 @@
"""Client code"""
from typing import TextIO
import click
from sshecret.crypto import decode_string, load_private_key
def decrypt_secret(encoded: str, client_key: str) -> str:
"""Decrypt secret."""
private_key = load_private_key(client_key)
return decode_string(encoded, private_key)
@click.command()
@click.argument("keyfile", type=click.Path(exists=True, readable=True, dir_okay=False))
@click.argument("encrypted_input", type=click.File("r"))
def cli_decrypt(keyfile: str, encrypted_input: TextIO) -> None:
"""Decrypt on command line."""
decrypted = decrypt_secret(encrypted_input.read(), keyfile)
click.echo(decrypted)
if __name__ == "__main__":
cli_decrypt()

View File

@ -1,5 +1,8 @@
"""Development CLI commands.""" """Development CLI commands."""
import sys
import asyncio
import asyncssh
import click import click
import logging import logging
@ -7,8 +10,8 @@ import tempfile
import threading import threading
from pathlib import Path from pathlib import Path
from .server import SshKeyServer from .server.async_server import start_server
from .server.client_loader import FileTableBackend from sshecret.backends import FileTableBackend
from .utils import create_client_file, add_secret_to_client_file from .utils import create_client_file, add_secret_to_client_file
@ -50,15 +53,21 @@ def add_secret(filename: str, secret_name: str, secret_value: str) -> None:
@cli.command("server") @cli.command("server")
@click.argument("directory", type=click.Path(file_okay=False, dir_okay=True)) @click.argument("directory", type=click.Path(file_okay=False, dir_okay=True))
@click.argument("port", type=click.INT) @click.argument("port", type=click.INT)
def run_server(directory: str, port: int) -> None: def run_async_server(directory: str, port: int) -> None:
"""Run server.""" """Run async server."""
loop = asyncio.new_event_loop()
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
serverdir = Path(tmpdir) serverdir = Path(tmpdir)
host_key = serverdir / "hostkey" host_key = str(serverdir / "hostkey")
clientdir = Path(directory) clientdir = Path(directory)
backend = FileTableBackend(clientdir) backend = FileTableBackend(clientdir)
SshKeyServer.start_server(host_key, clients=backend, port=port, create_key=True) try:
loop.run_until_complete(start_server(port, backend, host_key, True))
except (OSError, asyncssh.Error) as exc:
click.echo(f"Error starting server: {exc}")
sys.exit(1)
loop.run_forever()

View File

@ -0,0 +1,122 @@
"""Server implemented with asyncssh."""
import logging
from functools import partial
from pathlib import Path
from typing import override
import asyncssh
from sshecret import constants
from sshecret.types import ClientSpecification, BaseClientBackend
from sshecret.crypto import create_private_rsa_key
LOG = logging.getLogger(__name__)
def handle_client(process: asyncssh.SSHServerProcess[str]) -> None:
"""Handle client."""
client_found = process.get_extra_info("client_allowed", False)
if not client_found:
process.stderr.write(constants.ERROR_UKNOWN_CLIENT_OR_SECRET + "\n")
process.exit(1)
return
client_allowed = process.get_extra_info("client_allowed", False)
if not client_allowed:
process.stderr.write(constants.ERROR_SOURCE_IP_NOT_ALLOWED + "\n")
process.exit(1)
return
client = process.get_extra_info("client")
if not client:
process.stderr.write(constants.ERROR_UKNOWN_CLIENT_OR_SECRET + "\n")
process.exit(1)
return
secret_name = process.command
if not secret_name:
process.stderr.write(constants.ERROR_UKNOWN_CLIENT_OR_SECRET + "\n")
process.exit(1)
return
LOG.debug("Client %s successfully connected. Fetching secret %s", client.name, secret_name)
secret = client.secrets.get(secret_name)
if not secret:
process.stderr.write(constants.ERROR_UKNOWN_CLIENT_OR_SECRET + "\n")
process.exit(1)
return
process.stdout.write(secret)
process.exit(0)
class AsshyncServer(asyncssh.SSHServer):
"""Asynchronous SSH server implementation."""
def __init__(self, backend: BaseClientBackend) -> None:
"""Initialize server."""
self.backend: BaseClientBackend = backend
self._conn: asyncssh.SSHServerConnection | None = None
@override
def connection_made(self, conn: asyncssh.SSHServerConnection) -> None:
"""Handle incoming connection."""
self._conn = conn
@override
def begin_auth(self, username: str) -> bool:
"""Begin authentication."""
if not self._conn:
return True
client = self.backend.lookup_name(username)
if not client:
return True
self._conn.set_extra_info(client_found=True)
remote_ip = self._conn.get_extra_info("peername")[0]
LOG.debug("Remote_IP: %r", remote_ip)
assert isinstance(remote_ip, str)
if self.check_connection_allowed(client, remote_ip):
self._conn.set_extra_info(client_allowed=True)
self._conn.set_extra_info(client=client)
# Load the key.
public_key = asyncssh.import_authorized_keys(client.public_key)
self._conn.set_authorized_keys(public_key)
return True
@override
def password_auth_supported(self) -> bool:
"""Deny password authentication."""
return False
def check_connection_allowed(self, client: ClientSpecification, source: str) -> bool:
"""Check if client is allowed to request secrets."""
LOG.debug(
"Checking if client is allowed to log in from %s", source
)
if isinstance(client.allowed_ips, str) and client.allowed_ips == "*":
LOG.debug("Client has no restrictions on source IP address. Permitting.")
return True
if isinstance(client.allowed_ips, str) and "/" not in client.allowed_ips:
if source == client.allowed_ips:
LOG.debug("Client IP matches permitted address")
return True
LOG.warning(
"Connection for client %s received from IP address %s that is not permitted.",
client.name,
source
)
return False
async def start_server(port: int, backend: BaseClientBackend, host_key: str, create_key: bool = False) -> None:
"""Start server."""
server = partial(AsshyncServer, backend=backend)
if create_key:
create_private_rsa_key(Path(host_key))
await asyncssh.create_server(server, '', port, server_host_keys=[host_key], process_factory=handle_client)

View File

@ -1,309 +0,0 @@
"""SSH Server implementation."""
import ipaddress
import logging
import threading
import socket
from contextvars import ContextVar, Context, copy_context
from pathlib import Path
from typing import Any, override
import paramiko
from paramiko.common import (
AUTH_SUCCESSFUL,
AUTH_FAILED,
OPEN_SUCCEEDED,
OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED,
)
from sshecret import constants
from sshecret.crypto import create_private_rsa_key
from sshecret.types import ClientSpecification
from .types import BaseClientBackend, BaseServer
from . import errors
CLIENT_STATE = threading.local()
client_secret_request_name: ContextVar[str] = ContextVar("client_secret_request_name")
client_request_name: ContextVar[str] = ContextVar("client_request_name")
LOG = logging.getLogger(__name__)
class TransportContext(paramiko.Transport):
"""Context-aware transport."""
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.__context: Context | None = None
def snapshot_context(self) -> None:
"""Take a snapshot of the current context."""
LOG.debug("Snapshot!")
self.__context = copy_context()
#object.__setattr__(self, "__context", copy_context())
def get_context(self) -> None:
"""Get the frozen context into our current one."""
if contextobj := self.__context:
for var, value in contextobj.items():
var.set(value)
class SshServerInterface(paramiko.ServerInterface):
"""Define our ssh server interface."""
def __init__(self, clients: BaseClientBackend, client_address: str) -> None:
"""Initialize server interface."""
self.clients: BaseClientBackend = clients
self.client_address: str = client_address
self.event: threading.Event = threading.Event()
@override
def check_auth_publickey(self, username: str, key: paramiko.PKey) -> int:
"""Check if we can authenticate."""
LOG.debug("Verifying public key of username %s", username)
if self.clients.verify_key(username, key):
LOG.debug("Key matches configured key.")
return AUTH_SUCCESSFUL
LOG.warning("Key did not match. Auth denied!")
return AUTH_FAILED
@override
def get_allowed_auths(self, username: str) -> str:
"""Get allowed auth methods."""
return "publickey"
@override
def check_channel_request(self, kind: str, chanid: int) -> int:
LOG.debug("Open channel request received: kind=%s, chanid=%r", kind, chanid)
if kind == "session":
LOG.debug("Session requested.")
return OPEN_SUCCEEDED
LOG.warning("Prohibited channel request received.")
return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
@override
def check_channel_shell_request(self, channel: paramiko.Channel) -> bool:
"""Check shell request."""
# This shouldn't be allowed.
LOG.debug("Channel request received. Channel: %r", channel)
return False
@override
def check_channel_exec_request(
self, channel: paramiko.Channel, command: bytes
) -> bool:
"""Check if the exec request is valid.
This is where we send the password. The command is always the name of
the secret.
"""
LOG.debug("Exec request received: command: %r", command)
# Documentation says command is a string, but typeshed says it's bytes...
command_str = command.decode()
transport = channel.get_transport()
if not transport.is_authenticated():
return False
username = transport.get_username()
LOG.debug("Resolved username: %r", username)
if not username:
return False
if not isinstance(channel.transport, TransportContext):
LOG.critical("Error: Incorrect transport class. Cannot process commands.")
self.event.set()
return False
client_secret_request_name.set(command_str)
client_request_name.set(username)
channel.transport.snapshot_context()
self.event.set()
LOG.debug("Command check completed.")
return True
def check_allowed_client_ip(self, client: ClientSpecification) -> bool:
"""Check if the client is allowed to log in based on source IP."""
LOG.debug(
"Checking if client is allowed to log in from %s", self.client_address
)
if isinstance(client.allowed_ips, str) and client.allowed_ips == "*":
LOG.debug("Client has no restrictions on source IP address. Permitting.")
return True
if isinstance(client.allowed_ips, str):
if self.client_address == client.allowed_ips:
LOG.debug("Client IP matches permitted address")
return True
LOG.warning(
"Connection for client %s received from IP address %s that is not permitted.",
client.name,
self.client_address,
)
return False
client_ip = ipaddress.ip_address(self.client_address)
if client_ip in client.allowed_ips:
LOG.debug("Client IP matches permitted address")
return True
LOG.warning(
"Connection for client %s received from IP address %s that is not permitted.",
client.name,
self.client_address,
)
return False
class SshKeyServer(BaseServer):
"""SSH secrets server."""
def __init__(self, host_key: Path, clients: BaseClientBackend) -> None:
"""Create server instance."""
super().__init__()
self._host_key: paramiko.RSAKey = paramiko.RSAKey.from_private_key_file(
str(host_key), None
)
self.clients: BaseClientBackend = clients
def resolve_client(self) -> ClientSpecification:
"""Resolve client."""
LOG.debug("Looking up client data.")
client_name = client_request_name.get(None)
if not client_name:
LOG.debug("No context data was resolved.")
raise errors.UnknownClientError(constants.ERROR_NO_COMMAND_RECEIVED)
client = self.clients.lookup_name(str(client_name))
if not client:
raise errors.UnknownClientError(constants.ERROR_UKNOWN_CLIENT_OR_SECRET)
return client
def get_secret(self, client: ClientSpecification) -> str:
"""Get command."""
LOG.debug("Looking up secret as requested.")
secret_name = client_secret_request_name.get(None)
if not secret_name:
raise errors.UnknownSecretError(constants.ERROR_UKNOWN_CLIENT_OR_SECRET)
secret = client.secrets.get(str(secret_name))
if not secret:
raise errors.UnknownSecretError(constants.ERROR_NO_SECRET_FOUND)
return secret
def check_connection_allowed(self, client: ClientSpecification, source: str) -> None:
"""Check if client is allowed to request secrets."""
LOG.debug(
"Checking if client is allowed to log in from %s", source
)
if isinstance(client.allowed_ips, str) and client.allowed_ips == "*":
LOG.debug("Client has no restrictions on source IP address. Permitting.")
return
if isinstance(client.allowed_ips, str) and "/" not in client.allowed_ips:
if source == client.allowed_ips:
LOG.debug("Client IP matches permitted address")
return
LOG.warning(
"Connection for client %s received from IP address %s that is not permitted.",
client.name,
source
)
raise errors.AccessPolicyViolationError(constants.ERROR_SOURCE_IP_NOT_ALLOWED)
source_ip = ipaddress.ip_address(source)
permitted = False
for client_ip in client.allowed_ips:
if isinstance(client_ip, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
if source_ip in client_ip:
permitted = True
break
else:
if source_ip == client_ip:
permitted = True
break
if not permitted:
raise errors.AccessPolicyViolationError(constants.ERROR_SOURCE_IP_NOT_ALLOWED)
LOG.debug("Matched client to permitted IP address statement.")
return
@override
def connection_function(
self, client_socket: socket.socket, client_address: str
) -> None:
"""Run on connection."""
LOG.debug("Connection function called by %s", client_address)
try:
session = TransportContext(client_socket)
session.add_server_key(self._host_key)
server = SshServerInterface(self.clients, client_address)
try:
session.start_server(server=server)
except paramiko.SSHException:
return
channel = session.accept(30)
if not channel:
LOG.debug("No channel opened!")
return
LOG.debug("Got channel: %r, transport: %r, ", channel, channel.transport)
server.event.wait()
LOG.debug("Opening channel file")
stdout = channel.makefile("rw")
LOG.debug("Extracting context.")
session.get_context()
try:
LOG.debug("Looking up client.")
client = self.resolve_client()
LOG.debug("Checking source address policy.")
self.check_connection_allowed(client, client_address)
LOG.debug("Looking up secret.")
secret = self.get_secret(client)
except errors.BaseSshecretServerError as e:
error_message = f"{e}\n"
LOG.critical(e, exc_info=True)
stdout.write(error_message)
session.close()
return
stdout.write(secret)
session.close()
except Exception as e:
LOG.critical(e, exc_info=True)
@classmethod
def start_server(
cls,
host_key: str | Path,
clients: BaseClientBackend,
bind_address: str = "127.0.0.1",
port: int = 22,
timeout: int = 10,
create_key: bool = False,
) -> None:
"""Start the server.
Args:
host_key: path to the private host key (str or Path)
clients: Client secret loader instance.
bind_address: address to bind to (default: 127.0.0.1)
port: Port to bind to (default: 22)
timeout: Socket timeout, default 1 second
create_key: Create the private key if it doesn't exist.
"""
if isinstance(host_key, str):
host_key = Path(host_key)
if not host_key.exists():
if create_key:
create_private_rsa_key(host_key)
else:
raise RuntimeError("Error: provided host key does not exist.")
server = cls(host_key, clients)
server.start(bind_address, port, timeout)

View File

@ -1,137 +0,0 @@
"""Base types and interfaces for the server."""
import abc
import logging
import socket
import sys
import threading
from typing import Any, TypeGuard
from cryptography.hazmat.primitives import serialization
import paramiko
from sshecret.types import ClientSpecification
from sshecret.crypto import load_client_key
LOG = logging.getLogger(__name__)
def validate_socket_tuple(data: Any) -> TypeGuard[tuple[str, int]]:
"""Validate socket accept return data.."""
if not isinstance(data, tuple):
return False
if not len(data) == 2:
return False
ip, port = data # pyright: ignore[reportUnknownVariableType]
if not isinstance(ip, str):
return False
if not isinstance(port, int):
return False
return True
class BaseClientBackend(abc.ABC):
"""Base client backend.
This class is responsible for managing the list of clients and facilitate
lookups.
"""
@abc.abstractmethod
def lookup_name(self, name: str) -> ClientSpecification | None:
"""Lookup a client specification by name."""
def _convert_to_pkey(self, client: ClientSpecification) -> paramiko.RSAKey:
"""Convert client key to paramiko key."""
client_key = load_client_key(client)
return paramiko.RSAKey(key=client_key)
def verify_key(self, name: str, key: paramiko.PKey) -> ClientSpecification | None:
"""Verify key."""
client = self.lookup_name(name)
if not client:
return None
LOG.debug("Verifying key: %r", key)
expected_key = self._convert_to_pkey(client)
if key == expected_key:
return client
return None
@abc.abstractmethod
def add_secret(self, client_name: str, secret_name: str, secret_value: str, encrypted: bool = False) -> None:
"""Add a secret to a client."""
@abc.abstractmethod
def add_client(self, spec: ClientSpecification) -> None:
"""Add a new client."""
@abc.abstractmethod
def update_client(self, name: str, spec: ClientSpecification) -> None:
"""Update client information."""
@abc.abstractmethod
def remove_client(self, name: str, persistent: bool = True) -> None:
"""Delete a client."""
class BaseServer(abc.ABC):
"""Base SSH server."""
def __init__(self) -> None:
"""Initialize the server."""
self._is_running: threading.Event = threading.Event()
self._socket: socket.socket | None = None
self._listen_thread: threading.Thread | None = None
def start(
self, address: str = "127.0.0.1", port: int = 22, timeout: int = 1
) -> None:
"""Start the server."""
if not self._is_running.is_set():
LOG.info("Starting SSH server on %s port %s", address, port)
self._is_running.set()
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
if sys.platform == "linux" or sys.platform == "linux2":
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, True)
LOG.debug("Setting socket timeout %s", timeout)
self._socket.settimeout(timeout)
LOG.debug("Binding to %s:%s", address, port)
self._socket.bind((address, port))
LOG.debug("Spawning thread.")
self._listen_thread = threading.Thread(target=self._listen)
self._listen_thread.start()
def stop(self) -> None:
"""Stop the server."""
if self._is_running.is_set():
self._is_running.clear()
if self._listen_thread:
self._listen_thread.join()
if self._socket:
self._socket.close()
@abc.abstractmethod
def connection_function(
self, client_socket: socket.socket, client_address: str
) -> None:
"""Run function on connect."""
def _listen(self) -> None:
"""Connect client to function."""
if self._socket is None:
raise RuntimeError("Received connection request without any socket")
while self._is_running.is_set():
try:
self._socket.listen()
client_socket, addr = self._socket.accept() # pyright: ignore[reportAny]
if not validate_socket_tuple(addr):
LOG.warning("Socket address tuple did not pass typeguard check!")
continue
LOG.debug("Received connection from %r", addr)
self.connection_function(client_socket, addr[0])
except socket.timeout:
pass

View File

@ -83,3 +83,31 @@ class ClientSpecification(BaseModel):
allowed_ips: list[IPvAnyAddress | IPvAnyNetwork] | str = "*" allowed_ips: list[IPvAnyAddress | IPvAnyNetwork] | str = "*"
secrets: dict[str, str] = {} secrets: dict[str, str] = {}
testing_private_key: str | None = None # Private key only for testing purposes! testing_private_key: str | None = None # Private key only for testing purposes!
class BaseClientBackend(abc.ABC):
"""Base client backend.
This class is responsible for managing the list of clients and facilitate
lookups.
"""
@abc.abstractmethod
def lookup_name(self, name: str) -> ClientSpecification | None:
"""Lookup a client specification by name."""
@abc.abstractmethod
def add_secret(self, client_name: str, secret_name: str, secret_value: str, encrypted: bool = False) -> None:
"""Add a secret to a client."""
@abc.abstractmethod
def add_client(self, spec: ClientSpecification) -> None:
"""Add a new client."""
@abc.abstractmethod
def update_client(self, name: str, spec: ClientSpecification) -> None:
"""Update client information."""
@abc.abstractmethod
def remove_client(self, name: str, persistent: bool = True) -> None:
"""Delete a client."""

View File

@ -1,7 +1,8 @@
"""Tests of client loader.""" """Tests of client loader."""
# pyright: reportUninitializedInstanceVariable=false, reportImplicitOverride=false
import unittest import unittest
from sshecret.server import client_loader from sshecret.backends import FileTableBackend
from sshecret.utils import generate_client_object from sshecret.utils import generate_client_object
from sshecret.testing import TestClientSpec, test_context from sshecret.testing import TestClientSpec, test_context
@ -11,7 +12,7 @@ class TestFileTableBackend(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
"""Set up tests.""" """Set up tests."""
self.test_dataset = [ self.test_dataset: list[TestClientSpec] = [
TestClientSpec("webserver", {"SECRET_TOKEN": "mysecrettoken"}), TestClientSpec("webserver", {"SECRET_TOKEN": "mysecrettoken"}),
TestClientSpec("dbserver", {"DB_ROOT_PASSWORD": "mysecretpassword"}), TestClientSpec("dbserver", {"DB_ROOT_PASSWORD": "mysecretpassword"}),
] ]
@ -19,21 +20,22 @@ class TestFileTableBackend(unittest.TestCase):
def test_init(self) -> None: def test_init(self) -> None:
"""Test instance creation.""" """Test instance creation."""
with test_context(self.test_dataset) as testdir: with test_context(self.test_dataset) as testdir:
backend = client_loader.FileTableBackend(testdir) backend = FileTableBackend(testdir)
self.assertGreater(len(backend.table), 0) self.assertGreater(len(backend.table), 0)
def test_lookup_name(self) -> None: def test_lookup_name(self) -> None:
"""Test lookup name.""" """Test lookup name."""
with test_context(self.test_dataset) as testdir: with test_context(self.test_dataset) as testdir:
backend = client_loader.FileTableBackend(testdir) backend = FileTableBackend(testdir)
webserver = backend.lookup_name("webserver") webserver = backend.lookup_name("webserver")
self.assertIsNotNone(webserver) self.assertIsNotNone(webserver)
assert webserver is not None
self.assertEqual(webserver.name, "webserver") self.assertEqual(webserver.name, "webserver")
def test_add_client(self) -> None: def test_add_client(self) -> None:
"""Test whether it is possible to add a client.""" """Test whether it is possible to add a client."""
with test_context(self.test_dataset) as testdir: with test_context(self.test_dataset) as testdir:
backend = client_loader.FileTableBackend(testdir) backend = FileTableBackend(testdir)
new_client = generate_client_object( new_client = generate_client_object(
"backupserver", {"BACKUP_KEY": "mysecretbackupkey"} "backupserver", {"BACKUP_KEY": "mysecretbackupkey"}
) )
@ -46,7 +48,7 @@ class TestFileTableBackend(unittest.TestCase):
def test_add_secret(self) -> None: def test_add_secret(self) -> None:
"""Test whether it is possible to add a secret.""" """Test whether it is possible to add a secret."""
with test_context(self.test_dataset) as testdir: with test_context(self.test_dataset) as testdir:
backend = client_loader.FileTableBackend(testdir) backend = FileTableBackend(testdir)
backend.add_secret("webserver", "OTHER_SECRET_TOKEN", "myothersecrettoken") backend.add_secret("webserver", "OTHER_SECRET_TOKEN", "myothersecrettoken")
webserver = backend.lookup_name("webserver") webserver = backend.lookup_name("webserver")
assert webserver is not None assert webserver is not None
@ -65,7 +67,7 @@ class TestFileTableBackend(unittest.TestCase):
def test_update_client(self) -> None: def test_update_client(self) -> None:
"""Test update_client method.""" """Test update_client method."""
with test_context(self.test_dataset) as testdir: with test_context(self.test_dataset) as testdir:
backend = client_loader.FileTableBackend(testdir) backend = FileTableBackend(testdir)
webserver = backend.lookup_name("webserver") webserver = backend.lookup_name("webserver")
assert webserver is not None assert webserver is not None
webserver.allowed_ips = "192.0.2.1" webserver.allowed_ips = "192.0.2.1"
@ -77,7 +79,7 @@ class TestFileTableBackend(unittest.TestCase):
def test_remove_client(self) -> None: def test_remove_client(self) -> None:
"""Test removal of client.""" """Test removal of client."""
with test_context(self.test_dataset) as testdir: with test_context(self.test_dataset) as testdir:
backend = client_loader.FileTableBackend(testdir) backend = FileTableBackend(testdir)
backend.remove_client("webserver", persistent=False) backend.remove_client("webserver", persistent=False)
webserver = backend.lookup_name("webserver") webserver = backend.lookup_name("webserver")
self.assertIsNone(webserver) self.assertIsNone(webserver)
@ -87,7 +89,7 @@ class TestFileTableBackend(unittest.TestCase):
def test_remove_client_persistent(self) -> None: def test_remove_client_persistent(self) -> None:
"""Test removal of client.""" """Test removal of client."""
with test_context(self.test_dataset) as testdir: with test_context(self.test_dataset) as testdir:
backend = client_loader.FileTableBackend(testdir) backend = FileTableBackend(testdir)
backend.remove_client("webserver", persistent=True) backend.remove_client("webserver", persistent=True)
webserver = backend.lookup_name("webserver") webserver = backend.lookup_name("webserver")
self.assertIsNone(webserver) self.assertIsNone(webserver)

View File

@ -35,8 +35,8 @@ class TestBasicCrypto(unittest.TestCase):
def test_key_loading(self) -> None: def test_key_loading(self) -> None:
"""Test basic flow.""" """Test basic flow."""
public_key = load_public_key(self.public_key) load_public_key(self.public_key)
private_key = load_private_key(self.private_key) load_private_key(self.private_key)
self.assertEqual(True, True) self.assertEqual(True, True)
def test_encrypt_decrypt(self) -> None: def test_encrypt_decrypt(self) -> None: