Refactor
This commit is contained in:
5
src/sshecret/backends/__init__.py
Normal file
5
src/sshecret/backends/__init__.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
"""Backend implementations"""
|
||||||
|
|
||||||
|
from .file_table import FileTableBackend
|
||||||
|
|
||||||
|
__all__ = ["FileTableBackend"]
|
||||||
@ -1,4 +1,4 @@
|
|||||||
"""Client loaders."""
|
"""File table based backend."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@ -9,7 +9,7 @@ import littletable as lt
|
|||||||
|
|
||||||
from sshecret.crypto import load_client_key, encrypt_string
|
from sshecret.crypto import load_client_key, encrypt_string
|
||||||
from sshecret.types import ClientSpecification
|
from sshecret.types import ClientSpecification
|
||||||
from .types import BaseClientBackend
|
from sshecret.types import BaseClientBackend
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
23
src/sshecret/client.py
Normal file
23
src/sshecret/client.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
"""Client code"""
|
||||||
|
|
||||||
|
from typing import TextIO
|
||||||
|
import click
|
||||||
|
|
||||||
|
from sshecret.crypto import decode_string, load_private_key
|
||||||
|
|
||||||
|
def decrypt_secret(encoded: str, client_key: str) -> str:
|
||||||
|
"""Decrypt secret."""
|
||||||
|
private_key = load_private_key(client_key)
|
||||||
|
return decode_string(encoded, private_key)
|
||||||
|
|
||||||
|
|
||||||
|
@click.command()
|
||||||
|
@click.argument("keyfile", type=click.Path(exists=True, readable=True, dir_okay=False))
|
||||||
|
@click.argument("encrypted_input", type=click.File("r"))
|
||||||
|
def cli_decrypt(keyfile: str, encrypted_input: TextIO) -> None:
|
||||||
|
"""Decrypt on command line."""
|
||||||
|
decrypted = decrypt_secret(encrypted_input.read(), keyfile)
|
||||||
|
click.echo(decrypted)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
cli_decrypt()
|
||||||
@ -1,5 +1,8 @@
|
|||||||
"""Development CLI commands."""
|
"""Development CLI commands."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import asyncio
|
||||||
|
import asyncssh
|
||||||
import click
|
import click
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@ -7,8 +10,8 @@ import tempfile
|
|||||||
import threading
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from .server import SshKeyServer
|
from .server.async_server import start_server
|
||||||
from .server.client_loader import FileTableBackend
|
from sshecret.backends import FileTableBackend
|
||||||
from .utils import create_client_file, add_secret_to_client_file
|
from .utils import create_client_file, add_secret_to_client_file
|
||||||
|
|
||||||
|
|
||||||
@ -50,15 +53,21 @@ def add_secret(filename: str, secret_name: str, secret_value: str) -> None:
|
|||||||
@cli.command("server")
|
@cli.command("server")
|
||||||
@click.argument("directory", type=click.Path(file_okay=False, dir_okay=True))
|
@click.argument("directory", type=click.Path(file_okay=False, dir_okay=True))
|
||||||
@click.argument("port", type=click.INT)
|
@click.argument("port", type=click.INT)
|
||||||
def run_server(directory: str, port: int) -> None:
|
def run_async_server(directory: str, port: int) -> None:
|
||||||
"""Run server."""
|
"""Run async server."""
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
serverdir = Path(tmpdir)
|
serverdir = Path(tmpdir)
|
||||||
host_key = serverdir / "hostkey"
|
host_key = str(serverdir / "hostkey")
|
||||||
clientdir = Path(directory)
|
clientdir = Path(directory)
|
||||||
backend = FileTableBackend(clientdir)
|
backend = FileTableBackend(clientdir)
|
||||||
SshKeyServer.start_server(host_key, clients=backend, port=port, create_key=True)
|
try:
|
||||||
|
loop.run_until_complete(start_server(port, backend, host_key, True))
|
||||||
|
except (OSError, asyncssh.Error) as exc:
|
||||||
|
click.echo(f"Error starting server: {exc}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
loop.run_forever()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
122
src/sshecret/server/async_server.py
Normal file
122
src/sshecret/server/async_server.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
"""Server implemented with asyncssh."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from functools import partial
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import override
|
||||||
|
|
||||||
|
import asyncssh
|
||||||
|
|
||||||
|
from sshecret import constants
|
||||||
|
from sshecret.types import ClientSpecification, BaseClientBackend
|
||||||
|
from sshecret.crypto import create_private_rsa_key
|
||||||
|
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_client(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||||
|
"""Handle client."""
|
||||||
|
client_found = process.get_extra_info("client_allowed", False)
|
||||||
|
if not client_found:
|
||||||
|
process.stderr.write(constants.ERROR_UKNOWN_CLIENT_OR_SECRET + "\n")
|
||||||
|
process.exit(1)
|
||||||
|
return
|
||||||
|
|
||||||
|
client_allowed = process.get_extra_info("client_allowed", False)
|
||||||
|
if not client_allowed:
|
||||||
|
process.stderr.write(constants.ERROR_SOURCE_IP_NOT_ALLOWED + "\n")
|
||||||
|
process.exit(1)
|
||||||
|
return
|
||||||
|
|
||||||
|
client = process.get_extra_info("client")
|
||||||
|
if not client:
|
||||||
|
process.stderr.write(constants.ERROR_UKNOWN_CLIENT_OR_SECRET + "\n")
|
||||||
|
process.exit(1)
|
||||||
|
return
|
||||||
|
|
||||||
|
secret_name = process.command
|
||||||
|
if not secret_name:
|
||||||
|
process.stderr.write(constants.ERROR_UKNOWN_CLIENT_OR_SECRET + "\n")
|
||||||
|
process.exit(1)
|
||||||
|
return
|
||||||
|
|
||||||
|
LOG.debug("Client %s successfully connected. Fetching secret %s", client.name, secret_name)
|
||||||
|
|
||||||
|
secret = client.secrets.get(secret_name)
|
||||||
|
if not secret:
|
||||||
|
process.stderr.write(constants.ERROR_UKNOWN_CLIENT_OR_SECRET + "\n")
|
||||||
|
process.exit(1)
|
||||||
|
return
|
||||||
|
|
||||||
|
process.stdout.write(secret)
|
||||||
|
process.exit(0)
|
||||||
|
|
||||||
|
class AsshyncServer(asyncssh.SSHServer):
|
||||||
|
"""Asynchronous SSH server implementation."""
|
||||||
|
|
||||||
|
def __init__(self, backend: BaseClientBackend) -> None:
|
||||||
|
"""Initialize server."""
|
||||||
|
self.backend: BaseClientBackend = backend
|
||||||
|
self._conn: asyncssh.SSHServerConnection | None = None
|
||||||
|
|
||||||
|
@override
|
||||||
|
def connection_made(self, conn: asyncssh.SSHServerConnection) -> None:
|
||||||
|
"""Handle incoming connection."""
|
||||||
|
self._conn = conn
|
||||||
|
|
||||||
|
@override
|
||||||
|
def begin_auth(self, username: str) -> bool:
|
||||||
|
"""Begin authentication."""
|
||||||
|
if not self._conn:
|
||||||
|
return True
|
||||||
|
client = self.backend.lookup_name(username)
|
||||||
|
if not client:
|
||||||
|
return True
|
||||||
|
self._conn.set_extra_info(client_found=True)
|
||||||
|
remote_ip = self._conn.get_extra_info("peername")[0]
|
||||||
|
LOG.debug("Remote_IP: %r", remote_ip)
|
||||||
|
assert isinstance(remote_ip, str)
|
||||||
|
if self.check_connection_allowed(client, remote_ip):
|
||||||
|
self._conn.set_extra_info(client_allowed=True)
|
||||||
|
self._conn.set_extra_info(client=client)
|
||||||
|
|
||||||
|
# Load the key.
|
||||||
|
public_key = asyncssh.import_authorized_keys(client.public_key)
|
||||||
|
self._conn.set_authorized_keys(public_key)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
@override
|
||||||
|
def password_auth_supported(self) -> bool:
|
||||||
|
"""Deny password authentication."""
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def check_connection_allowed(self, client: ClientSpecification, source: str) -> bool:
|
||||||
|
"""Check if client is allowed to request secrets."""
|
||||||
|
LOG.debug(
|
||||||
|
"Checking if client is allowed to log in from %s", source
|
||||||
|
)
|
||||||
|
if isinstance(client.allowed_ips, str) and client.allowed_ips == "*":
|
||||||
|
LOG.debug("Client has no restrictions on source IP address. Permitting.")
|
||||||
|
return True
|
||||||
|
if isinstance(client.allowed_ips, str) and "/" not in client.allowed_ips:
|
||||||
|
if source == client.allowed_ips:
|
||||||
|
LOG.debug("Client IP matches permitted address")
|
||||||
|
return True
|
||||||
|
|
||||||
|
LOG.warning(
|
||||||
|
"Connection for client %s received from IP address %s that is not permitted.",
|
||||||
|
client.name,
|
||||||
|
source
|
||||||
|
)
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def start_server(port: int, backend: BaseClientBackend, host_key: str, create_key: bool = False) -> None:
|
||||||
|
"""Start server."""
|
||||||
|
server = partial(AsshyncServer, backend=backend)
|
||||||
|
if create_key:
|
||||||
|
create_private_rsa_key(Path(host_key))
|
||||||
|
await asyncssh.create_server(server, '', port, server_host_keys=[host_key], process_factory=handle_client)
|
||||||
@ -1,309 +0,0 @@
|
|||||||
"""SSH Server implementation."""
|
|
||||||
|
|
||||||
import ipaddress
|
|
||||||
import logging
|
|
||||||
import threading
|
|
||||||
import socket
|
|
||||||
from contextvars import ContextVar, Context, copy_context
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, override
|
|
||||||
|
|
||||||
import paramiko
|
|
||||||
from paramiko.common import (
|
|
||||||
AUTH_SUCCESSFUL,
|
|
||||||
AUTH_FAILED,
|
|
||||||
OPEN_SUCCEEDED,
|
|
||||||
OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED,
|
|
||||||
)
|
|
||||||
|
|
||||||
from sshecret import constants
|
|
||||||
from sshecret.crypto import create_private_rsa_key
|
|
||||||
from sshecret.types import ClientSpecification
|
|
||||||
|
|
||||||
from .types import BaseClientBackend, BaseServer
|
|
||||||
|
|
||||||
from . import errors
|
|
||||||
|
|
||||||
CLIENT_STATE = threading.local()
|
|
||||||
|
|
||||||
client_secret_request_name: ContextVar[str] = ContextVar("client_secret_request_name")
|
|
||||||
client_request_name: ContextVar[str] = ContextVar("client_request_name")
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class TransportContext(paramiko.Transport):
|
|
||||||
"""Context-aware transport."""
|
|
||||||
|
|
||||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.__context: Context | None = None
|
|
||||||
|
|
||||||
def snapshot_context(self) -> None:
|
|
||||||
"""Take a snapshot of the current context."""
|
|
||||||
LOG.debug("Snapshot!")
|
|
||||||
self.__context = copy_context()
|
|
||||||
#object.__setattr__(self, "__context", copy_context())
|
|
||||||
|
|
||||||
def get_context(self) -> None:
|
|
||||||
"""Get the frozen context into our current one."""
|
|
||||||
if contextobj := self.__context:
|
|
||||||
for var, value in contextobj.items():
|
|
||||||
var.set(value)
|
|
||||||
|
|
||||||
class SshServerInterface(paramiko.ServerInterface):
|
|
||||||
"""Define our ssh server interface."""
|
|
||||||
|
|
||||||
def __init__(self, clients: BaseClientBackend, client_address: str) -> None:
|
|
||||||
"""Initialize server interface."""
|
|
||||||
self.clients: BaseClientBackend = clients
|
|
||||||
self.client_address: str = client_address
|
|
||||||
self.event: threading.Event = threading.Event()
|
|
||||||
|
|
||||||
@override
|
|
||||||
def check_auth_publickey(self, username: str, key: paramiko.PKey) -> int:
|
|
||||||
"""Check if we can authenticate."""
|
|
||||||
LOG.debug("Verifying public key of username %s", username)
|
|
||||||
if self.clients.verify_key(username, key):
|
|
||||||
LOG.debug("Key matches configured key.")
|
|
||||||
return AUTH_SUCCESSFUL
|
|
||||||
LOG.warning("Key did not match. Auth denied!")
|
|
||||||
return AUTH_FAILED
|
|
||||||
|
|
||||||
@override
|
|
||||||
def get_allowed_auths(self, username: str) -> str:
|
|
||||||
"""Get allowed auth methods."""
|
|
||||||
return "publickey"
|
|
||||||
|
|
||||||
@override
|
|
||||||
def check_channel_request(self, kind: str, chanid: int) -> int:
|
|
||||||
LOG.debug("Open channel request received: kind=%s, chanid=%r", kind, chanid)
|
|
||||||
if kind == "session":
|
|
||||||
LOG.debug("Session requested.")
|
|
||||||
return OPEN_SUCCEEDED
|
|
||||||
LOG.warning("Prohibited channel request received.")
|
|
||||||
return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
|
|
||||||
|
|
||||||
@override
|
|
||||||
def check_channel_shell_request(self, channel: paramiko.Channel) -> bool:
|
|
||||||
"""Check shell request."""
|
|
||||||
# This shouldn't be allowed.
|
|
||||||
LOG.debug("Channel request received. Channel: %r", channel)
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
@override
|
|
||||||
def check_channel_exec_request(
|
|
||||||
self, channel: paramiko.Channel, command: bytes
|
|
||||||
) -> bool:
|
|
||||||
"""Check if the exec request is valid.
|
|
||||||
|
|
||||||
This is where we send the password. The command is always the name of
|
|
||||||
the secret.
|
|
||||||
|
|
||||||
"""
|
|
||||||
LOG.debug("Exec request received: command: %r", command)
|
|
||||||
# Documentation says command is a string, but typeshed says it's bytes...
|
|
||||||
command_str = command.decode()
|
|
||||||
transport = channel.get_transport()
|
|
||||||
if not transport.is_authenticated():
|
|
||||||
return False
|
|
||||||
username = transport.get_username()
|
|
||||||
LOG.debug("Resolved username: %r", username)
|
|
||||||
if not username:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not isinstance(channel.transport, TransportContext):
|
|
||||||
LOG.critical("Error: Incorrect transport class. Cannot process commands.")
|
|
||||||
self.event.set()
|
|
||||||
return False
|
|
||||||
client_secret_request_name.set(command_str)
|
|
||||||
client_request_name.set(username)
|
|
||||||
channel.transport.snapshot_context()
|
|
||||||
self.event.set()
|
|
||||||
LOG.debug("Command check completed.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
def check_allowed_client_ip(self, client: ClientSpecification) -> bool:
|
|
||||||
"""Check if the client is allowed to log in based on source IP."""
|
|
||||||
LOG.debug(
|
|
||||||
"Checking if client is allowed to log in from %s", self.client_address
|
|
||||||
)
|
|
||||||
if isinstance(client.allowed_ips, str) and client.allowed_ips == "*":
|
|
||||||
LOG.debug("Client has no restrictions on source IP address. Permitting.")
|
|
||||||
return True
|
|
||||||
if isinstance(client.allowed_ips, str):
|
|
||||||
if self.client_address == client.allowed_ips:
|
|
||||||
LOG.debug("Client IP matches permitted address")
|
|
||||||
return True
|
|
||||||
LOG.warning(
|
|
||||||
"Connection for client %s received from IP address %s that is not permitted.",
|
|
||||||
client.name,
|
|
||||||
self.client_address,
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
client_ip = ipaddress.ip_address(self.client_address)
|
|
||||||
if client_ip in client.allowed_ips:
|
|
||||||
LOG.debug("Client IP matches permitted address")
|
|
||||||
return True
|
|
||||||
LOG.warning(
|
|
||||||
"Connection for client %s received from IP address %s that is not permitted.",
|
|
||||||
client.name,
|
|
||||||
self.client_address,
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class SshKeyServer(BaseServer):
|
|
||||||
"""SSH secrets server."""
|
|
||||||
|
|
||||||
def __init__(self, host_key: Path, clients: BaseClientBackend) -> None:
|
|
||||||
"""Create server instance."""
|
|
||||||
super().__init__()
|
|
||||||
self._host_key: paramiko.RSAKey = paramiko.RSAKey.from_private_key_file(
|
|
||||||
str(host_key), None
|
|
||||||
)
|
|
||||||
self.clients: BaseClientBackend = clients
|
|
||||||
|
|
||||||
def resolve_client(self) -> ClientSpecification:
|
|
||||||
"""Resolve client."""
|
|
||||||
LOG.debug("Looking up client data.")
|
|
||||||
client_name = client_request_name.get(None)
|
|
||||||
if not client_name:
|
|
||||||
LOG.debug("No context data was resolved.")
|
|
||||||
raise errors.UnknownClientError(constants.ERROR_NO_COMMAND_RECEIVED)
|
|
||||||
|
|
||||||
client = self.clients.lookup_name(str(client_name))
|
|
||||||
if not client:
|
|
||||||
raise errors.UnknownClientError(constants.ERROR_UKNOWN_CLIENT_OR_SECRET)
|
|
||||||
return client
|
|
||||||
|
|
||||||
def get_secret(self, client: ClientSpecification) -> str:
|
|
||||||
"""Get command."""
|
|
||||||
LOG.debug("Looking up secret as requested.")
|
|
||||||
secret_name = client_secret_request_name.get(None)
|
|
||||||
if not secret_name:
|
|
||||||
raise errors.UnknownSecretError(constants.ERROR_UKNOWN_CLIENT_OR_SECRET)
|
|
||||||
|
|
||||||
secret = client.secrets.get(str(secret_name))
|
|
||||||
if not secret:
|
|
||||||
raise errors.UnknownSecretError(constants.ERROR_NO_SECRET_FOUND)
|
|
||||||
return secret
|
|
||||||
|
|
||||||
def check_connection_allowed(self, client: ClientSpecification, source: str) -> None:
|
|
||||||
"""Check if client is allowed to request secrets."""
|
|
||||||
LOG.debug(
|
|
||||||
"Checking if client is allowed to log in from %s", source
|
|
||||||
)
|
|
||||||
if isinstance(client.allowed_ips, str) and client.allowed_ips == "*":
|
|
||||||
LOG.debug("Client has no restrictions on source IP address. Permitting.")
|
|
||||||
return
|
|
||||||
if isinstance(client.allowed_ips, str) and "/" not in client.allowed_ips:
|
|
||||||
if source == client.allowed_ips:
|
|
||||||
LOG.debug("Client IP matches permitted address")
|
|
||||||
return
|
|
||||||
|
|
||||||
LOG.warning(
|
|
||||||
"Connection for client %s received from IP address %s that is not permitted.",
|
|
||||||
client.name,
|
|
||||||
source
|
|
||||||
)
|
|
||||||
|
|
||||||
raise errors.AccessPolicyViolationError(constants.ERROR_SOURCE_IP_NOT_ALLOWED)
|
|
||||||
|
|
||||||
source_ip = ipaddress.ip_address(source)
|
|
||||||
permitted = False
|
|
||||||
for client_ip in client.allowed_ips:
|
|
||||||
if isinstance(client_ip, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
|
|
||||||
if source_ip in client_ip:
|
|
||||||
permitted = True
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
if source_ip == client_ip:
|
|
||||||
permitted = True
|
|
||||||
break
|
|
||||||
if not permitted:
|
|
||||||
raise errors.AccessPolicyViolationError(constants.ERROR_SOURCE_IP_NOT_ALLOWED)
|
|
||||||
|
|
||||||
LOG.debug("Matched client to permitted IP address statement.")
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
@override
|
|
||||||
def connection_function(
|
|
||||||
self, client_socket: socket.socket, client_address: str
|
|
||||||
) -> None:
|
|
||||||
"""Run on connection."""
|
|
||||||
LOG.debug("Connection function called by %s", client_address)
|
|
||||||
try:
|
|
||||||
session = TransportContext(client_socket)
|
|
||||||
session.add_server_key(self._host_key)
|
|
||||||
server = SshServerInterface(self.clients, client_address)
|
|
||||||
try:
|
|
||||||
session.start_server(server=server)
|
|
||||||
except paramiko.SSHException:
|
|
||||||
return
|
|
||||||
|
|
||||||
channel = session.accept(30)
|
|
||||||
if not channel:
|
|
||||||
LOG.debug("No channel opened!")
|
|
||||||
return
|
|
||||||
|
|
||||||
LOG.debug("Got channel: %r, transport: %r, ", channel, channel.transport)
|
|
||||||
server.event.wait()
|
|
||||||
|
|
||||||
LOG.debug("Opening channel file")
|
|
||||||
stdout = channel.makefile("rw")
|
|
||||||
|
|
||||||
LOG.debug("Extracting context.")
|
|
||||||
session.get_context()
|
|
||||||
try:
|
|
||||||
LOG.debug("Looking up client.")
|
|
||||||
client = self.resolve_client()
|
|
||||||
LOG.debug("Checking source address policy.")
|
|
||||||
self.check_connection_allowed(client, client_address)
|
|
||||||
LOG.debug("Looking up secret.")
|
|
||||||
secret = self.get_secret(client)
|
|
||||||
except errors.BaseSshecretServerError as e:
|
|
||||||
error_message = f"{e}\n"
|
|
||||||
LOG.critical(e, exc_info=True)
|
|
||||||
stdout.write(error_message)
|
|
||||||
session.close()
|
|
||||||
return
|
|
||||||
|
|
||||||
stdout.write(secret)
|
|
||||||
session.close()
|
|
||||||
except Exception as e:
|
|
||||||
LOG.critical(e, exc_info=True)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def start_server(
|
|
||||||
cls,
|
|
||||||
host_key: str | Path,
|
|
||||||
clients: BaseClientBackend,
|
|
||||||
bind_address: str = "127.0.0.1",
|
|
||||||
port: int = 22,
|
|
||||||
timeout: int = 10,
|
|
||||||
create_key: bool = False,
|
|
||||||
) -> None:
|
|
||||||
"""Start the server.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
host_key: path to the private host key (str or Path)
|
|
||||||
clients: Client secret loader instance.
|
|
||||||
bind_address: address to bind to (default: 127.0.0.1)
|
|
||||||
port: Port to bind to (default: 22)
|
|
||||||
timeout: Socket timeout, default 1 second
|
|
||||||
create_key: Create the private key if it doesn't exist.
|
|
||||||
|
|
||||||
"""
|
|
||||||
if isinstance(host_key, str):
|
|
||||||
host_key = Path(host_key)
|
|
||||||
|
|
||||||
if not host_key.exists():
|
|
||||||
if create_key:
|
|
||||||
create_private_rsa_key(host_key)
|
|
||||||
else:
|
|
||||||
raise RuntimeError("Error: provided host key does not exist.")
|
|
||||||
server = cls(host_key, clients)
|
|
||||||
server.start(bind_address, port, timeout)
|
|
||||||
@ -1,137 +0,0 @@
|
|||||||
"""Base types and interfaces for the server."""
|
|
||||||
|
|
||||||
import abc
|
|
||||||
import logging
|
|
||||||
import socket
|
|
||||||
import sys
|
|
||||||
import threading
|
|
||||||
|
|
||||||
from typing import Any, TypeGuard
|
|
||||||
from cryptography.hazmat.primitives import serialization
|
|
||||||
|
|
||||||
import paramiko
|
|
||||||
from sshecret.types import ClientSpecification
|
|
||||||
from sshecret.crypto import load_client_key
|
|
||||||
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def validate_socket_tuple(data: Any) -> TypeGuard[tuple[str, int]]:
|
|
||||||
"""Validate socket accept return data.."""
|
|
||||||
if not isinstance(data, tuple):
|
|
||||||
return False
|
|
||||||
if not len(data) == 2:
|
|
||||||
return False
|
|
||||||
ip, port = data # pyright: ignore[reportUnknownVariableType]
|
|
||||||
if not isinstance(ip, str):
|
|
||||||
return False
|
|
||||||
if not isinstance(port, int):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class BaseClientBackend(abc.ABC):
|
|
||||||
"""Base client backend.
|
|
||||||
|
|
||||||
This class is responsible for managing the list of clients and facilitate
|
|
||||||
lookups.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def lookup_name(self, name: str) -> ClientSpecification | None:
|
|
||||||
"""Lookup a client specification by name."""
|
|
||||||
|
|
||||||
def _convert_to_pkey(self, client: ClientSpecification) -> paramiko.RSAKey:
|
|
||||||
"""Convert client key to paramiko key."""
|
|
||||||
client_key = load_client_key(client)
|
|
||||||
|
|
||||||
return paramiko.RSAKey(key=client_key)
|
|
||||||
|
|
||||||
def verify_key(self, name: str, key: paramiko.PKey) -> ClientSpecification | None:
|
|
||||||
"""Verify key."""
|
|
||||||
client = self.lookup_name(name)
|
|
||||||
if not client:
|
|
||||||
return None
|
|
||||||
LOG.debug("Verifying key: %r", key)
|
|
||||||
expected_key = self._convert_to_pkey(client)
|
|
||||||
if key == expected_key:
|
|
||||||
return client
|
|
||||||
return None
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def add_secret(self, client_name: str, secret_name: str, secret_value: str, encrypted: bool = False) -> None:
|
|
||||||
"""Add a secret to a client."""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def add_client(self, spec: ClientSpecification) -> None:
|
|
||||||
"""Add a new client."""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def update_client(self, name: str, spec: ClientSpecification) -> None:
|
|
||||||
"""Update client information."""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def remove_client(self, name: str, persistent: bool = True) -> None:
|
|
||||||
"""Delete a client."""
|
|
||||||
|
|
||||||
|
|
||||||
class BaseServer(abc.ABC):
|
|
||||||
"""Base SSH server."""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
"""Initialize the server."""
|
|
||||||
self._is_running: threading.Event = threading.Event()
|
|
||||||
self._socket: socket.socket | None = None
|
|
||||||
self._listen_thread: threading.Thread | None = None
|
|
||||||
|
|
||||||
def start(
|
|
||||||
self, address: str = "127.0.0.1", port: int = 22, timeout: int = 1
|
|
||||||
) -> None:
|
|
||||||
"""Start the server."""
|
|
||||||
if not self._is_running.is_set():
|
|
||||||
LOG.info("Starting SSH server on %s port %s", address, port)
|
|
||||||
self._is_running.set()
|
|
||||||
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
||||||
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
|
|
||||||
if sys.platform == "linux" or sys.platform == "linux2":
|
|
||||||
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, True)
|
|
||||||
LOG.debug("Setting socket timeout %s", timeout)
|
|
||||||
self._socket.settimeout(timeout)
|
|
||||||
LOG.debug("Binding to %s:%s", address, port)
|
|
||||||
self._socket.bind((address, port))
|
|
||||||
LOG.debug("Spawning thread.")
|
|
||||||
self._listen_thread = threading.Thread(target=self._listen)
|
|
||||||
self._listen_thread.start()
|
|
||||||
|
|
||||||
def stop(self) -> None:
|
|
||||||
"""Stop the server."""
|
|
||||||
if self._is_running.is_set():
|
|
||||||
self._is_running.clear()
|
|
||||||
if self._listen_thread:
|
|
||||||
self._listen_thread.join()
|
|
||||||
if self._socket:
|
|
||||||
self._socket.close()
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def connection_function(
|
|
||||||
self, client_socket: socket.socket, client_address: str
|
|
||||||
) -> None:
|
|
||||||
"""Run function on connect."""
|
|
||||||
|
|
||||||
def _listen(self) -> None:
|
|
||||||
"""Connect client to function."""
|
|
||||||
if self._socket is None:
|
|
||||||
raise RuntimeError("Received connection request without any socket")
|
|
||||||
|
|
||||||
while self._is_running.is_set():
|
|
||||||
try:
|
|
||||||
self._socket.listen()
|
|
||||||
client_socket, addr = self._socket.accept() # pyright: ignore[reportAny]
|
|
||||||
if not validate_socket_tuple(addr):
|
|
||||||
LOG.warning("Socket address tuple did not pass typeguard check!")
|
|
||||||
continue
|
|
||||||
LOG.debug("Received connection from %r", addr)
|
|
||||||
self.connection_function(client_socket, addr[0])
|
|
||||||
except socket.timeout:
|
|
||||||
pass
|
|
||||||
@ -83,3 +83,31 @@ class ClientSpecification(BaseModel):
|
|||||||
allowed_ips: list[IPvAnyAddress | IPvAnyNetwork] | str = "*"
|
allowed_ips: list[IPvAnyAddress | IPvAnyNetwork] | str = "*"
|
||||||
secrets: dict[str, str] = {}
|
secrets: dict[str, str] = {}
|
||||||
testing_private_key: str | None = None # Private key only for testing purposes!
|
testing_private_key: str | None = None # Private key only for testing purposes!
|
||||||
|
|
||||||
|
|
||||||
|
class BaseClientBackend(abc.ABC):
|
||||||
|
"""Base client backend.
|
||||||
|
|
||||||
|
This class is responsible for managing the list of clients and facilitate
|
||||||
|
lookups.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def lookup_name(self, name: str) -> ClientSpecification | None:
|
||||||
|
"""Lookup a client specification by name."""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def add_secret(self, client_name: str, secret_name: str, secret_value: str, encrypted: bool = False) -> None:
|
||||||
|
"""Add a secret to a client."""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def add_client(self, spec: ClientSpecification) -> None:
|
||||||
|
"""Add a new client."""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def update_client(self, name: str, spec: ClientSpecification) -> None:
|
||||||
|
"""Update client information."""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def remove_client(self, name: str, persistent: bool = True) -> None:
|
||||||
|
"""Delete a client."""
|
||||||
|
|||||||
@ -1,7 +1,8 @@
|
|||||||
"""Tests of client loader."""
|
"""Tests of client loader."""
|
||||||
|
# pyright: reportUninitializedInstanceVariable=false, reportImplicitOverride=false
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from sshecret.server import client_loader
|
from sshecret.backends import FileTableBackend
|
||||||
from sshecret.utils import generate_client_object
|
from sshecret.utils import generate_client_object
|
||||||
from sshecret.testing import TestClientSpec, test_context
|
from sshecret.testing import TestClientSpec, test_context
|
||||||
|
|
||||||
@ -11,7 +12,7 @@ class TestFileTableBackend(unittest.TestCase):
|
|||||||
|
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
"""Set up tests."""
|
"""Set up tests."""
|
||||||
self.test_dataset = [
|
self.test_dataset: list[TestClientSpec] = [
|
||||||
TestClientSpec("webserver", {"SECRET_TOKEN": "mysecrettoken"}),
|
TestClientSpec("webserver", {"SECRET_TOKEN": "mysecrettoken"}),
|
||||||
TestClientSpec("dbserver", {"DB_ROOT_PASSWORD": "mysecretpassword"}),
|
TestClientSpec("dbserver", {"DB_ROOT_PASSWORD": "mysecretpassword"}),
|
||||||
]
|
]
|
||||||
@ -19,21 +20,22 @@ class TestFileTableBackend(unittest.TestCase):
|
|||||||
def test_init(self) -> None:
|
def test_init(self) -> None:
|
||||||
"""Test instance creation."""
|
"""Test instance creation."""
|
||||||
with test_context(self.test_dataset) as testdir:
|
with test_context(self.test_dataset) as testdir:
|
||||||
backend = client_loader.FileTableBackend(testdir)
|
backend = FileTableBackend(testdir)
|
||||||
self.assertGreater(len(backend.table), 0)
|
self.assertGreater(len(backend.table), 0)
|
||||||
|
|
||||||
def test_lookup_name(self) -> None:
|
def test_lookup_name(self) -> None:
|
||||||
"""Test lookup name."""
|
"""Test lookup name."""
|
||||||
with test_context(self.test_dataset) as testdir:
|
with test_context(self.test_dataset) as testdir:
|
||||||
backend = client_loader.FileTableBackend(testdir)
|
backend = FileTableBackend(testdir)
|
||||||
webserver = backend.lookup_name("webserver")
|
webserver = backend.lookup_name("webserver")
|
||||||
self.assertIsNotNone(webserver)
|
self.assertIsNotNone(webserver)
|
||||||
|
assert webserver is not None
|
||||||
self.assertEqual(webserver.name, "webserver")
|
self.assertEqual(webserver.name, "webserver")
|
||||||
|
|
||||||
def test_add_client(self) -> None:
|
def test_add_client(self) -> None:
|
||||||
"""Test whether it is possible to add a client."""
|
"""Test whether it is possible to add a client."""
|
||||||
with test_context(self.test_dataset) as testdir:
|
with test_context(self.test_dataset) as testdir:
|
||||||
backend = client_loader.FileTableBackend(testdir)
|
backend = FileTableBackend(testdir)
|
||||||
new_client = generate_client_object(
|
new_client = generate_client_object(
|
||||||
"backupserver", {"BACKUP_KEY": "mysecretbackupkey"}
|
"backupserver", {"BACKUP_KEY": "mysecretbackupkey"}
|
||||||
)
|
)
|
||||||
@ -46,7 +48,7 @@ class TestFileTableBackend(unittest.TestCase):
|
|||||||
def test_add_secret(self) -> None:
|
def test_add_secret(self) -> None:
|
||||||
"""Test whether it is possible to add a secret."""
|
"""Test whether it is possible to add a secret."""
|
||||||
with test_context(self.test_dataset) as testdir:
|
with test_context(self.test_dataset) as testdir:
|
||||||
backend = client_loader.FileTableBackend(testdir)
|
backend = FileTableBackend(testdir)
|
||||||
backend.add_secret("webserver", "OTHER_SECRET_TOKEN", "myothersecrettoken")
|
backend.add_secret("webserver", "OTHER_SECRET_TOKEN", "myothersecrettoken")
|
||||||
webserver = backend.lookup_name("webserver")
|
webserver = backend.lookup_name("webserver")
|
||||||
assert webserver is not None
|
assert webserver is not None
|
||||||
@ -65,7 +67,7 @@ class TestFileTableBackend(unittest.TestCase):
|
|||||||
def test_update_client(self) -> None:
|
def test_update_client(self) -> None:
|
||||||
"""Test update_client method."""
|
"""Test update_client method."""
|
||||||
with test_context(self.test_dataset) as testdir:
|
with test_context(self.test_dataset) as testdir:
|
||||||
backend = client_loader.FileTableBackend(testdir)
|
backend = FileTableBackend(testdir)
|
||||||
webserver = backend.lookup_name("webserver")
|
webserver = backend.lookup_name("webserver")
|
||||||
assert webserver is not None
|
assert webserver is not None
|
||||||
webserver.allowed_ips = "192.0.2.1"
|
webserver.allowed_ips = "192.0.2.1"
|
||||||
@ -77,7 +79,7 @@ class TestFileTableBackend(unittest.TestCase):
|
|||||||
def test_remove_client(self) -> None:
|
def test_remove_client(self) -> None:
|
||||||
"""Test removal of client."""
|
"""Test removal of client."""
|
||||||
with test_context(self.test_dataset) as testdir:
|
with test_context(self.test_dataset) as testdir:
|
||||||
backend = client_loader.FileTableBackend(testdir)
|
backend = FileTableBackend(testdir)
|
||||||
backend.remove_client("webserver", persistent=False)
|
backend.remove_client("webserver", persistent=False)
|
||||||
webserver = backend.lookup_name("webserver")
|
webserver = backend.lookup_name("webserver")
|
||||||
self.assertIsNone(webserver)
|
self.assertIsNone(webserver)
|
||||||
@ -87,7 +89,7 @@ class TestFileTableBackend(unittest.TestCase):
|
|||||||
def test_remove_client_persistent(self) -> None:
|
def test_remove_client_persistent(self) -> None:
|
||||||
"""Test removal of client."""
|
"""Test removal of client."""
|
||||||
with test_context(self.test_dataset) as testdir:
|
with test_context(self.test_dataset) as testdir:
|
||||||
backend = client_loader.FileTableBackend(testdir)
|
backend = FileTableBackend(testdir)
|
||||||
backend.remove_client("webserver", persistent=True)
|
backend.remove_client("webserver", persistent=True)
|
||||||
webserver = backend.lookup_name("webserver")
|
webserver = backend.lookup_name("webserver")
|
||||||
self.assertIsNone(webserver)
|
self.assertIsNone(webserver)
|
||||||
|
|||||||
@ -35,8 +35,8 @@ class TestBasicCrypto(unittest.TestCase):
|
|||||||
|
|
||||||
def test_key_loading(self) -> None:
|
def test_key_loading(self) -> None:
|
||||||
"""Test basic flow."""
|
"""Test basic flow."""
|
||||||
public_key = load_public_key(self.public_key)
|
load_public_key(self.public_key)
|
||||||
private_key = load_private_key(self.private_key)
|
load_private_key(self.private_key)
|
||||||
self.assertEqual(True, True)
|
self.assertEqual(True, True)
|
||||||
|
|
||||||
def test_encrypt_decrypt(self) -> None:
|
def test_encrypt_decrypt(self) -> None:
|
||||||
|
|||||||
Reference in New Issue
Block a user