From 08d8031d0954971f8bc03e1501fc95224f04557b Mon Sep 17 00:00:00 2001 From: Allan Eising Date: Sun, 23 Mar 2025 16:06:53 +0100 Subject: [PATCH] Re-format files --- src/sshecret/backends/file_table.py | 4 +-- src/sshecret/client.py | 2 ++ src/sshecret/constants.py | 4 ++- src/sshecret/crypto.py | 39 ++++++++++++++++++++++------- src/sshecret/dev_cli.py | 18 ++++++++++--- src/sshecret/keepass.py | 10 ++++++-- src/sshecret/server/__init__.py | 1 - src/sshecret/server/async_server.py | 25 +++++++++++------- src/sshecret/server/errors.py | 1 + src/sshecret/testing.py | 1 - src/sshecret/types.py | 8 +++++- src/sshecret/utils.py | 33 ++++++++++++++++++------ 12 files changed, 107 insertions(+), 39 deletions(-) diff --git a/src/sshecret/backends/file_table.py b/src/sshecret/backends/file_table.py index c376193..57e7abd 100644 --- a/src/sshecret/backends/file_table.py +++ b/src/sshecret/backends/file_table.py @@ -85,9 +85,7 @@ class FileTableBackend(BaseClientBackend): encrypted: bool = False, ) -> None: """Add secret.""" - client: ClientSpecification = self.table.by.name[ - client_name - ] # pyright: ignore[reportAssignmentType] + client: ClientSpecification = self.table.by.name[client_name] # pyright: ignore[reportAssignmentType] if not encrypted: public_key = load_client_key(client) secret_value = encrypt_string(secret_value, public_key) diff --git a/src/sshecret/client.py b/src/sshecret/client.py index 4773241..1a23780 100644 --- a/src/sshecret/client.py +++ b/src/sshecret/client.py @@ -5,6 +5,7 @@ import click from sshecret.crypto import decode_string, load_private_key + def decrypt_secret(encoded: str, client_key: str) -> str: """Decrypt secret.""" private_key = load_private_key(client_key) @@ -19,5 +20,6 @@ def cli_decrypt(keyfile: str, encrypted_input: TextIO) -> None: decrypted = decrypt_secret(encrypted_input.read(), keyfile) click.echo(decrypted) + if __name__ == "__main__": cli_decrypt() diff --git a/src/sshecret/constants.py b/src/sshecret/constants.py index b6a2e88..a401528 100644 --- a/src/sshecret/constants.py +++ b/src/sshecret/constants.py @@ -7,7 +7,9 @@ VAR_PREFIX = "SSHECRET" ERROR_NO_SECRET_FOUND = "Error: No secret available with the given name." ERROR_UKNOWN_CLIENT_OR_SECRET = "Error: Invalid client or secret name." ERROR_NO_COMMAND_RECEIVED = "Error: No command was received from the client." -ERROR_SOURCE_IP_NOT_ALLOWED = "Error: Client not authorized to connect from the given host." +ERROR_SOURCE_IP_NOT_ALLOWED = ( + "Error: Client not authorized to connect from the given host." +) RSA_PUBLIC_EXPONENT = 65537 RSA_KEY_SIZE = 2048 diff --git a/src/sshecret/crypto.py b/src/sshecret/crypto.py index 59488f7..b863ebe 100644 --- a/src/sshecret/crypto.py +++ b/src/sshecret/crypto.py @@ -12,17 +12,20 @@ from . import constants LOG = logging.getLogger(__name__) + def load_client_key(client: ClientSpecification) -> rsa.RSAPublicKey: """Load public key.""" keybytes = client.public_key.encode() return load_public_key(keybytes) + def load_public_key(keybytes: bytes) -> rsa.RSAPublicKey: public_key = serialization.load_ssh_public_key(keybytes) if not isinstance(public_key, rsa.RSAPublicKey): raise RuntimeError("Only RSA keys are supported.") return public_key + def load_private_key(filename: str) -> rsa.RSAPrivateKey: """Load a private key.""" with open(filename, "rb") as f: @@ -31,6 +34,7 @@ def load_private_key(filename: str) -> rsa.RSAPrivateKey: raise RuntimeError("Only RSA keys are supported.") return private_key + def encrypt_string(string: str, public_key: rsa.RSAPublicKey) -> str: """Encrypt string, end return it base64 encoded.""" message = string.encode() @@ -40,30 +44,40 @@ def encrypt_string(string: str, public_key: rsa.RSAPublicKey) -> str: mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None, - ) + ), ) return base64.b64encode(ciphertext).decode() + def decode_string(ciphertext: str, private_key: rsa.RSAPrivateKey) -> str: """Decode a string. String must be base64 encoded.""" decoded = base64.b64decode(ciphertext) decrypted = private_key.decrypt( decoded, padding.OAEP( - mgf=padding.MGF1(algorithm=hashes.SHA256()), - algorithm=hashes.SHA256(), - label=None - )) + mgf=padding.MGF1(algorithm=hashes.SHA256()), + algorithm=hashes.SHA256(), + label=None, + ), + ) return decrypted.decode() + def generate_private_key() -> rsa.RSAPrivateKey: """Generate private RSA key.""" - private_key = rsa.generate_private_key(public_exponent=constants.RSA_PUBLIC_EXPONENT, key_size=constants.RSA_KEY_SIZE) + private_key = rsa.generate_private_key( + public_exponent=constants.RSA_PUBLIC_EXPONENT, key_size=constants.RSA_KEY_SIZE + ) return private_key + def generate_pem(private_key: rsa.RSAPrivateKey) -> str: """Generate PEM.""" - pem = private_key.private_bytes(encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.OpenSSH, encryption_algorithm=serialization.NoEncryption()) + pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.OpenSSH, + encryption_algorithm=serialization.NoEncryption(), + ) return pem.decode() @@ -74,7 +88,11 @@ def create_private_rsa_key(filename: Path) -> None: LOG.debug("Generating private RSA key at %s", filename) private_key = generate_private_key() with open(filename, "wb") as f: - pem = private_key.private_bytes(encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.OpenSSH, encryption_algorithm=serialization.NoEncryption()) + pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.OpenSSH, + encryption_algorithm=serialization.NoEncryption(), + ) lines = f.write(pem) LOG.debug("Wrote %s lines", lines) f.flush() @@ -82,5 +100,8 @@ def create_private_rsa_key(filename: Path) -> None: def generate_public_key_string(public_key: rsa.RSAPublicKey) -> str: """Generate public key string.""" - keybytes = public_key.public_bytes(encoding=serialization.Encoding.OpenSSH, format=serialization.PublicFormat.OpenSSH) + keybytes = public_key.public_bytes( + encoding=serialization.Encoding.OpenSSH, + format=serialization.PublicFormat.OpenSSH, + ) return keybytes.decode() diff --git a/src/sshecret/dev_cli.py b/src/sshecret/dev_cli.py index a5c5c7d..8cdf69e 100644 --- a/src/sshecret/dev_cli.py +++ b/src/sshecret/dev_cli.py @@ -20,29 +20,39 @@ def thread_id_filter(record: logging.LogRecord) -> logging.LogRecord: record.thread_id = threading.get_native_id() return record + LOG = logging.getLogger() handler = logging.StreamHandler() handler.addFilter(thread_id_filter) -formatter = logging.Formatter("%(thread_id)d:%(created)f:%(levelname)s:%(name)s:%(module)s:%(message)s") +formatter = logging.Formatter( + "%(thread_id)d:%(created)f:%(levelname)s:%(name)s:%(module)s:%(message)s" +) handler.setFormatter(formatter) LOG.addHandler(handler) LOG.setLevel(logging.DEBUG) + @click.group() def cli() -> None: """Run commands for testing.""" + @cli.command("create-client") @click.argument("name") -@click.argument("filename", type=click.Path(file_okay=True, dir_okay=False, writable=True)) +@click.argument( + "filename", type=click.Path(file_okay=True, dir_okay=False, writable=True) +) @click.option("--public-key", type=click.Path(file_okay=True)) def create_client(name: str, filename: str, public_key: str | None) -> None: """Create a client.""" create_client_file(name, filename, keyfile=public_key) click.echo(f"Wrote client config to {filename}") + @cli.command("add-secret") -@click.argument("filename", type=click.Path(file_okay=True, dir_okay=False, writable=True)) +@click.argument( + "filename", type=click.Path(file_okay=True, dir_okay=False, writable=True) +) @click.argument("secret-name") @click.argument("secret-value") def add_secret(filename: str, secret_name: str, secret_value: str) -> None: @@ -50,6 +60,7 @@ def add_secret(filename: str, secret_name: str, secret_value: str) -> None: add_secret_to_client_file(filename, secret_name, secret_value) click.echo(f"Wrote secret to {filename}") + @cli.command("server") @click.argument("directory", type=click.Path(file_okay=False, dir_okay=True)) @click.argument("port", type=click.INT) @@ -70,6 +81,5 @@ def run_async_server(directory: str, port: int) -> None: loop.run_forever() - if __name__ == "__main__": cli() diff --git a/src/sshecret/keepass.py b/src/sshecret/keepass.py index df4fd7f..4364d9c 100644 --- a/src/sshecret/keepass.py +++ b/src/sshecret/keepass.py @@ -1,4 +1,5 @@ """Keepass integration.""" + import logging from pathlib import Path from typing import final, override, Self @@ -10,6 +11,7 @@ from .utils import generate_password LOG = logging.getLogger(__name__) + @final class KeepassManager(BasePasswordManager): """KeepassXC compatible password manager.""" @@ -35,7 +37,9 @@ class KeepassManager(BasePasswordManager): @override @classmethod - def create_database(cls, location: str, reader_context: PasswordContext, overwrite: bool = False) -> Self: + def create_database( + cls, location: str, reader_context: PasswordContext, overwrite: bool = False + ) -> Self: """Create database.""" if Path(location).exists() and not overwrite: raise RuntimeError("Error: Database exists.") @@ -74,7 +78,9 @@ class KeepassManager(BasePasswordManager): """Generate password.""" # Generate a password. password = generate_password() - _entry = self.keepass.add_entry(self.keepass.root_group, identifier, constants.NO_USERNAME, password) + _entry = self.keepass.add_entry( + self.keepass.root_group, identifier, constants.NO_USERNAME, password + ) self.keepass.save() LOG.debug("Created Entry %r", _entry) return password diff --git a/src/sshecret/server/__init__.py b/src/sshecret/server/__init__.py index c7b092c..61b3a73 100644 --- a/src/sshecret/server/__init__.py +++ b/src/sshecret/server/__init__.py @@ -1,6 +1,5 @@ """Sshecret server module.""" - from .server import SshKeyServer __all__ = ["SshKeyServer"] diff --git a/src/sshecret/server/async_server.py b/src/sshecret/server/async_server.py index ad984ed..47fd5fd 100644 --- a/src/sshecret/server/async_server.py +++ b/src/sshecret/server/async_server.py @@ -41,7 +41,9 @@ def handle_client(process: asyncssh.SSHServerProcess[str]) -> None: process.exit(1) return - LOG.debug("Client %s successfully connected. Fetching secret %s", client.name, secret_name) + LOG.debug( + "Client %s successfully connected. Fetching secret %s", client.name, secret_name + ) secret = client.secrets.get(secret_name) if not secret: @@ -52,6 +54,7 @@ def handle_client(process: asyncssh.SSHServerProcess[str]) -> None: process.stdout.write(secret) process.exit(0) + class AsshyncServer(asyncssh.SSHServer): """Asynchronous SSH server implementation.""" @@ -92,12 +95,11 @@ class AsshyncServer(asyncssh.SSHServer): """Deny password authentication.""" return False - - def check_connection_allowed(self, client: ClientSpecification, source: str) -> bool: + def check_connection_allowed( + self, client: ClientSpecification, source: str + ) -> bool: """Check if client is allowed to request secrets.""" - LOG.debug( - "Checking if client is allowed to log in from %s", source - ) + LOG.debug("Checking if client is allowed to log in from %s", source) if isinstance(client.allowed_ips, str) and client.allowed_ips == "*": LOG.debug("Client has no restrictions on source IP address. Permitting.") return True @@ -109,14 +111,19 @@ class AsshyncServer(asyncssh.SSHServer): LOG.warning( "Connection for client %s received from IP address %s that is not permitted.", client.name, - source + source, ) return False -async def start_server(port: int, backend: BaseClientBackend, host_key: str, create_key: bool = False) -> None: + +async def start_server( + port: int, backend: BaseClientBackend, host_key: str, create_key: bool = False +) -> None: """Start server.""" server = partial(AsshyncServer, backend=backend) if create_key: create_private_rsa_key(Path(host_key)) - await asyncssh.create_server(server, '', port, server_host_keys=[host_key], process_factory=handle_client) + await asyncssh.create_server( + server, "", port, server_host_keys=[host_key], process_factory=handle_client + ) diff --git a/src/sshecret/server/errors.py b/src/sshecret/server/errors.py index 355630c..e058e9c 100644 --- a/src/sshecret/server/errors.py +++ b/src/sshecret/server/errors.py @@ -1,5 +1,6 @@ """Server errors.""" + class BaseSshecretServerError(Exception): """Base Sshecret Server Error.""" diff --git a/src/sshecret/testing.py b/src/sshecret/testing.py index df593a9..5e9bb59 100644 --- a/src/sshecret/testing.py +++ b/src/sshecret/testing.py @@ -1,6 +1,5 @@ """Testing utilities and classes.""" - import tempfile from dataclasses import dataclass, field from contextlib import contextmanager diff --git a/src/sshecret/types.py b/src/sshecret/types.py index 9d1754a..1faf786 100644 --- a/src/sshecret/types.py +++ b/src/sshecret/types.py @@ -97,7 +97,13 @@ class BaseClientBackend(abc.ABC): """Lookup a client specification by name.""" @abc.abstractmethod - def add_secret(self, client_name: str, secret_name: str, secret_value: str, encrypted: bool = False) -> None: + def add_secret( + self, + client_name: str, + secret_name: str, + secret_value: str, + encrypted: bool = False, + ) -> None: """Add a secret to a client.""" @abc.abstractmethod diff --git a/src/sshecret/utils.py b/src/sshecret/utils.py index fc6e925..75e9fa9 100644 --- a/src/sshecret/utils.py +++ b/src/sshecret/utils.py @@ -4,16 +4,24 @@ import secrets from pathlib import Path -from .crypto import load_client_key, encrypt_string, generate_private_key, generate_pem, generate_public_key_string +from .crypto import ( + load_client_key, + encrypt_string, + generate_private_key, + generate_pem, + generate_public_key_string, +) from .types import ClientSpecification + def generate_password() -> str: - """Generate a password. - """ + """Generate a password.""" return secrets.token_urlsafe(32) -def generate_client_object(name: str, secrets: dict[str, str] | None = None, keyfile: str | None = None) -> ClientSpecification: +def generate_client_object( + name: str, secrets: dict[str, str] | None = None, keyfile: str | None = None +) -> ClientSpecification: """Generate a client object.""" private_key = generate_private_key() if keyfile: @@ -22,21 +30,28 @@ def generate_client_object(name: str, secrets: dict[str, str] | None = None, key if not contents.startswith("ssh-rsa "): raise RuntimeError("Error: Key must be an RSA key.") - client = ClientSpecification(name=name, public_key=contents.strip()) public_key = load_client_key(client) else: pem = generate_pem(private_key) public_key = private_key.public_key() pubkey_str = generate_public_key_string(public_key) - client = ClientSpecification(name=name, public_key=pubkey_str, testing_private_key=pem) + client = ClientSpecification( + name=name, public_key=pubkey_str, testing_private_key=pem + ) if secrets: for secret_name, secret_value in secrets.items(): client.secrets[secret_name] = encrypt_string(secret_value, public_key) return client -def create_client_file(name: str, filename: Path | str, secrets: dict[str, str] | None = None, keyfile: str | None = None) -> None: + +def create_client_file( + name: str, + filename: Path | str, + secrets: dict[str, str] | None = None, + keyfile: str | None = None, +) -> None: """Create client file.""" client = generate_client_object(name, secrets, keyfile) @@ -45,7 +60,9 @@ def create_client_file(name: str, filename: Path | str, secrets: dict[str, str] f.flush() -def add_secret_to_client_file(filename: str | Path, secret_name: str, secret_value: str) -> None: +def add_secret_to_client_file( + filename: str | Path, secret_name: str, secret_value: str +) -> None: """Add secret to client file.""" with open(filename, "r") as f: client = ClientSpecification.model_validate_json(f.read())