Re-format files

This commit is contained in:
2025-03-23 16:06:53 +01:00
parent de46a47acc
commit 08d8031d09
12 changed files with 107 additions and 39 deletions

View File

@ -85,9 +85,7 @@ class FileTableBackend(BaseClientBackend):
encrypted: bool = False,
) -> None:
"""Add secret."""
client: ClientSpecification = self.table.by.name[
client_name
] # pyright: ignore[reportAssignmentType]
client: ClientSpecification = self.table.by.name[client_name] # pyright: ignore[reportAssignmentType]
if not encrypted:
public_key = load_client_key(client)
secret_value = encrypt_string(secret_value, public_key)

View File

@ -5,6 +5,7 @@ import click
from sshecret.crypto import decode_string, load_private_key
def decrypt_secret(encoded: str, client_key: str) -> str:
"""Decrypt secret."""
private_key = load_private_key(client_key)
@ -19,5 +20,6 @@ def cli_decrypt(keyfile: str, encrypted_input: TextIO) -> None:
decrypted = decrypt_secret(encrypted_input.read(), keyfile)
click.echo(decrypted)
if __name__ == "__main__":
cli_decrypt()

View File

@ -7,7 +7,9 @@ VAR_PREFIX = "SSHECRET"
ERROR_NO_SECRET_FOUND = "Error: No secret available with the given name."
ERROR_UKNOWN_CLIENT_OR_SECRET = "Error: Invalid client or secret name."
ERROR_NO_COMMAND_RECEIVED = "Error: No command was received from the client."
ERROR_SOURCE_IP_NOT_ALLOWED = "Error: Client not authorized to connect from the given host."
ERROR_SOURCE_IP_NOT_ALLOWED = (
"Error: Client not authorized to connect from the given host."
)
RSA_PUBLIC_EXPONENT = 65537
RSA_KEY_SIZE = 2048

View File

@ -12,17 +12,20 @@ from . import constants
LOG = logging.getLogger(__name__)
def load_client_key(client: ClientSpecification) -> rsa.RSAPublicKey:
"""Load public key."""
keybytes = client.public_key.encode()
return load_public_key(keybytes)
def load_public_key(keybytes: bytes) -> rsa.RSAPublicKey:
public_key = serialization.load_ssh_public_key(keybytes)
if not isinstance(public_key, rsa.RSAPublicKey):
raise RuntimeError("Only RSA keys are supported.")
return public_key
def load_private_key(filename: str) -> rsa.RSAPrivateKey:
"""Load a private key."""
with open(filename, "rb") as f:
@ -31,6 +34,7 @@ def load_private_key(filename: str) -> rsa.RSAPrivateKey:
raise RuntimeError("Only RSA keys are supported.")
return private_key
def encrypt_string(string: str, public_key: rsa.RSAPublicKey) -> str:
"""Encrypt string, end return it base64 encoded."""
message = string.encode()
@ -40,10 +44,11 @@ def encrypt_string(string: str, public_key: rsa.RSAPublicKey) -> str:
mgf=padding.MGF1(algorithm=hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None,
)
),
)
return base64.b64encode(ciphertext).decode()
def decode_string(ciphertext: str, private_key: rsa.RSAPrivateKey) -> str:
"""Decode a string. String must be base64 encoded."""
decoded = base64.b64decode(ciphertext)
@ -52,18 +57,27 @@ def decode_string(ciphertext: str, private_key: rsa.RSAPrivateKey) -> str:
padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None
))
label=None,
),
)
return decrypted.decode()
def generate_private_key() -> rsa.RSAPrivateKey:
"""Generate private RSA key."""
private_key = rsa.generate_private_key(public_exponent=constants.RSA_PUBLIC_EXPONENT, key_size=constants.RSA_KEY_SIZE)
private_key = rsa.generate_private_key(
public_exponent=constants.RSA_PUBLIC_EXPONENT, key_size=constants.RSA_KEY_SIZE
)
return private_key
def generate_pem(private_key: rsa.RSAPrivateKey) -> str:
"""Generate PEM."""
pem = private_key.private_bytes(encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.OpenSSH, encryption_algorithm=serialization.NoEncryption())
pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.OpenSSH,
encryption_algorithm=serialization.NoEncryption(),
)
return pem.decode()
@ -74,7 +88,11 @@ def create_private_rsa_key(filename: Path) -> None:
LOG.debug("Generating private RSA key at %s", filename)
private_key = generate_private_key()
with open(filename, "wb") as f:
pem = private_key.private_bytes(encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.OpenSSH, encryption_algorithm=serialization.NoEncryption())
pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.OpenSSH,
encryption_algorithm=serialization.NoEncryption(),
)
lines = f.write(pem)
LOG.debug("Wrote %s lines", lines)
f.flush()
@ -82,5 +100,8 @@ def create_private_rsa_key(filename: Path) -> None:
def generate_public_key_string(public_key: rsa.RSAPublicKey) -> str:
"""Generate public key string."""
keybytes = public_key.public_bytes(encoding=serialization.Encoding.OpenSSH, format=serialization.PublicFormat.OpenSSH)
keybytes = public_key.public_bytes(
encoding=serialization.Encoding.OpenSSH,
format=serialization.PublicFormat.OpenSSH,
)
return keybytes.decode()

View File

@ -20,29 +20,39 @@ def thread_id_filter(record: logging.LogRecord) -> logging.LogRecord:
record.thread_id = threading.get_native_id()
return record
LOG = logging.getLogger()
handler = logging.StreamHandler()
handler.addFilter(thread_id_filter)
formatter = logging.Formatter("%(thread_id)d:%(created)f:%(levelname)s:%(name)s:%(module)s:%(message)s")
formatter = logging.Formatter(
"%(thread_id)d:%(created)f:%(levelname)s:%(name)s:%(module)s:%(message)s"
)
handler.setFormatter(formatter)
LOG.addHandler(handler)
LOG.setLevel(logging.DEBUG)
@click.group()
def cli() -> None:
"""Run commands for testing."""
@cli.command("create-client")
@click.argument("name")
@click.argument("filename", type=click.Path(file_okay=True, dir_okay=False, writable=True))
@click.argument(
"filename", type=click.Path(file_okay=True, dir_okay=False, writable=True)
)
@click.option("--public-key", type=click.Path(file_okay=True))
def create_client(name: str, filename: str, public_key: str | None) -> None:
"""Create a client."""
create_client_file(name, filename, keyfile=public_key)
click.echo(f"Wrote client config to {filename}")
@cli.command("add-secret")
@click.argument("filename", type=click.Path(file_okay=True, dir_okay=False, writable=True))
@click.argument(
"filename", type=click.Path(file_okay=True, dir_okay=False, writable=True)
)
@click.argument("secret-name")
@click.argument("secret-value")
def add_secret(filename: str, secret_name: str, secret_value: str) -> None:
@ -50,6 +60,7 @@ def add_secret(filename: str, secret_name: str, secret_value: str) -> None:
add_secret_to_client_file(filename, secret_name, secret_value)
click.echo(f"Wrote secret to {filename}")
@cli.command("server")
@click.argument("directory", type=click.Path(file_okay=False, dir_okay=True))
@click.argument("port", type=click.INT)
@ -70,6 +81,5 @@ def run_async_server(directory: str, port: int) -> None:
loop.run_forever()
if __name__ == "__main__":
cli()

View File

@ -1,4 +1,5 @@
"""Keepass integration."""
import logging
from pathlib import Path
from typing import final, override, Self
@ -10,6 +11,7 @@ from .utils import generate_password
LOG = logging.getLogger(__name__)
@final
class KeepassManager(BasePasswordManager):
"""KeepassXC compatible password manager."""
@ -35,7 +37,9 @@ class KeepassManager(BasePasswordManager):
@override
@classmethod
def create_database(cls, location: str, reader_context: PasswordContext, overwrite: bool = False) -> Self:
def create_database(
cls, location: str, reader_context: PasswordContext, overwrite: bool = False
) -> Self:
"""Create database."""
if Path(location).exists() and not overwrite:
raise RuntimeError("Error: Database exists.")
@ -74,7 +78,9 @@ class KeepassManager(BasePasswordManager):
"""Generate password."""
# Generate a password.
password = generate_password()
_entry = self.keepass.add_entry(self.keepass.root_group, identifier, constants.NO_USERNAME, password)
_entry = self.keepass.add_entry(
self.keepass.root_group, identifier, constants.NO_USERNAME, password
)
self.keepass.save()
LOG.debug("Created Entry %r", _entry)
return password

View File

@ -1,6 +1,5 @@
"""Sshecret server module."""
from .server import SshKeyServer
__all__ = ["SshKeyServer"]

View File

@ -41,7 +41,9 @@ def handle_client(process: asyncssh.SSHServerProcess[str]) -> None:
process.exit(1)
return
LOG.debug("Client %s successfully connected. Fetching secret %s", client.name, secret_name)
LOG.debug(
"Client %s successfully connected. Fetching secret %s", client.name, secret_name
)
secret = client.secrets.get(secret_name)
if not secret:
@ -52,6 +54,7 @@ def handle_client(process: asyncssh.SSHServerProcess[str]) -> None:
process.stdout.write(secret)
process.exit(0)
class AsshyncServer(asyncssh.SSHServer):
"""Asynchronous SSH server implementation."""
@ -92,12 +95,11 @@ class AsshyncServer(asyncssh.SSHServer):
"""Deny password authentication."""
return False
def check_connection_allowed(self, client: ClientSpecification, source: str) -> bool:
def check_connection_allowed(
self, client: ClientSpecification, source: str
) -> bool:
"""Check if client is allowed to request secrets."""
LOG.debug(
"Checking if client is allowed to log in from %s", source
)
LOG.debug("Checking if client is allowed to log in from %s", source)
if isinstance(client.allowed_ips, str) and client.allowed_ips == "*":
LOG.debug("Client has no restrictions on source IP address. Permitting.")
return True
@ -109,14 +111,19 @@ class AsshyncServer(asyncssh.SSHServer):
LOG.warning(
"Connection for client %s received from IP address %s that is not permitted.",
client.name,
source
source,
)
return False
async def start_server(port: int, backend: BaseClientBackend, host_key: str, create_key: bool = False) -> None:
async def start_server(
port: int, backend: BaseClientBackend, host_key: str, create_key: bool = False
) -> None:
"""Start server."""
server = partial(AsshyncServer, backend=backend)
if create_key:
create_private_rsa_key(Path(host_key))
await asyncssh.create_server(server, '', port, server_host_keys=[host_key], process_factory=handle_client)
await asyncssh.create_server(
server, "", port, server_host_keys=[host_key], process_factory=handle_client
)

View File

@ -1,5 +1,6 @@
"""Server errors."""
class BaseSshecretServerError(Exception):
"""Base Sshecret Server Error."""

View File

@ -1,6 +1,5 @@
"""Testing utilities and classes."""
import tempfile
from dataclasses import dataclass, field
from contextlib import contextmanager

View File

@ -97,7 +97,13 @@ class BaseClientBackend(abc.ABC):
"""Lookup a client specification by name."""
@abc.abstractmethod
def add_secret(self, client_name: str, secret_name: str, secret_value: str, encrypted: bool = False) -> None:
def add_secret(
self,
client_name: str,
secret_name: str,
secret_value: str,
encrypted: bool = False,
) -> None:
"""Add a secret to a client."""
@abc.abstractmethod

View File

@ -4,16 +4,24 @@ import secrets
from pathlib import Path
from .crypto import load_client_key, encrypt_string, generate_private_key, generate_pem, generate_public_key_string
from .crypto import (
load_client_key,
encrypt_string,
generate_private_key,
generate_pem,
generate_public_key_string,
)
from .types import ClientSpecification
def generate_password() -> str:
"""Generate a password.
"""
"""Generate a password."""
return secrets.token_urlsafe(32)
def generate_client_object(name: str, secrets: dict[str, str] | None = None, keyfile: str | None = None) -> ClientSpecification:
def generate_client_object(
name: str, secrets: dict[str, str] | None = None, keyfile: str | None = None
) -> ClientSpecification:
"""Generate a client object."""
private_key = generate_private_key()
if keyfile:
@ -22,21 +30,28 @@ def generate_client_object(name: str, secrets: dict[str, str] | None = None, key
if not contents.startswith("ssh-rsa "):
raise RuntimeError("Error: Key must be an RSA key.")
client = ClientSpecification(name=name, public_key=contents.strip())
public_key = load_client_key(client)
else:
pem = generate_pem(private_key)
public_key = private_key.public_key()
pubkey_str = generate_public_key_string(public_key)
client = ClientSpecification(name=name, public_key=pubkey_str, testing_private_key=pem)
client = ClientSpecification(
name=name, public_key=pubkey_str, testing_private_key=pem
)
if secrets:
for secret_name, secret_value in secrets.items():
client.secrets[secret_name] = encrypt_string(secret_value, public_key)
return client
def create_client_file(name: str, filename: Path | str, secrets: dict[str, str] | None = None, keyfile: str | None = None) -> None:
def create_client_file(
name: str,
filename: Path | str,
secrets: dict[str, str] | None = None,
keyfile: str | None = None,
) -> None:
"""Create client file."""
client = generate_client_object(name, secrets, keyfile)
@ -45,7 +60,9 @@ def create_client_file(name: str, filename: Path | str, secrets: dict[str, str]
f.flush()
def add_secret_to_client_file(filename: str | Path, secret_name: str, secret_value: str) -> None:
def add_secret_to_client_file(
filename: str | Path, secret_name: str, secret_value: str
) -> None:
"""Add secret to client file."""
with open(filename, "r") as f:
client = ClientSpecification.model_validate_json(f.read())