Re-format files

This commit is contained in:
2025-03-23 16:06:53 +01:00
parent de46a47acc
commit 08d8031d09
12 changed files with 107 additions and 39 deletions

View File

@ -85,9 +85,7 @@ class FileTableBackend(BaseClientBackend):
encrypted: bool = False, encrypted: bool = False,
) -> None: ) -> None:
"""Add secret.""" """Add secret."""
client: ClientSpecification = self.table.by.name[ client: ClientSpecification = self.table.by.name[client_name] # pyright: ignore[reportAssignmentType]
client_name
] # pyright: ignore[reportAssignmentType]
if not encrypted: if not encrypted:
public_key = load_client_key(client) public_key = load_client_key(client)
secret_value = encrypt_string(secret_value, public_key) secret_value = encrypt_string(secret_value, public_key)

View File

@ -5,6 +5,7 @@ import click
from sshecret.crypto import decode_string, load_private_key from sshecret.crypto import decode_string, load_private_key
def decrypt_secret(encoded: str, client_key: str) -> str: def decrypt_secret(encoded: str, client_key: str) -> str:
"""Decrypt secret.""" """Decrypt secret."""
private_key = load_private_key(client_key) private_key = load_private_key(client_key)
@ -19,5 +20,6 @@ def cli_decrypt(keyfile: str, encrypted_input: TextIO) -> None:
decrypted = decrypt_secret(encrypted_input.read(), keyfile) decrypted = decrypt_secret(encrypted_input.read(), keyfile)
click.echo(decrypted) click.echo(decrypted)
if __name__ == "__main__": if __name__ == "__main__":
cli_decrypt() cli_decrypt()

View File

@ -7,7 +7,9 @@ VAR_PREFIX = "SSHECRET"
ERROR_NO_SECRET_FOUND = "Error: No secret available with the given name." ERROR_NO_SECRET_FOUND = "Error: No secret available with the given name."
ERROR_UKNOWN_CLIENT_OR_SECRET = "Error: Invalid client or secret name." ERROR_UKNOWN_CLIENT_OR_SECRET = "Error: Invalid client or secret name."
ERROR_NO_COMMAND_RECEIVED = "Error: No command was received from the client." ERROR_NO_COMMAND_RECEIVED = "Error: No command was received from the client."
ERROR_SOURCE_IP_NOT_ALLOWED = "Error: Client not authorized to connect from the given host." ERROR_SOURCE_IP_NOT_ALLOWED = (
"Error: Client not authorized to connect from the given host."
)
RSA_PUBLIC_EXPONENT = 65537 RSA_PUBLIC_EXPONENT = 65537
RSA_KEY_SIZE = 2048 RSA_KEY_SIZE = 2048

View File

@ -12,17 +12,20 @@ from . import constants
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
def load_client_key(client: ClientSpecification) -> rsa.RSAPublicKey: def load_client_key(client: ClientSpecification) -> rsa.RSAPublicKey:
"""Load public key.""" """Load public key."""
keybytes = client.public_key.encode() keybytes = client.public_key.encode()
return load_public_key(keybytes) return load_public_key(keybytes)
def load_public_key(keybytes: bytes) -> rsa.RSAPublicKey: def load_public_key(keybytes: bytes) -> rsa.RSAPublicKey:
public_key = serialization.load_ssh_public_key(keybytes) public_key = serialization.load_ssh_public_key(keybytes)
if not isinstance(public_key, rsa.RSAPublicKey): if not isinstance(public_key, rsa.RSAPublicKey):
raise RuntimeError("Only RSA keys are supported.") raise RuntimeError("Only RSA keys are supported.")
return public_key return public_key
def load_private_key(filename: str) -> rsa.RSAPrivateKey: def load_private_key(filename: str) -> rsa.RSAPrivateKey:
"""Load a private key.""" """Load a private key."""
with open(filename, "rb") as f: with open(filename, "rb") as f:
@ -31,6 +34,7 @@ def load_private_key(filename: str) -> rsa.RSAPrivateKey:
raise RuntimeError("Only RSA keys are supported.") raise RuntimeError("Only RSA keys are supported.")
return private_key return private_key
def encrypt_string(string: str, public_key: rsa.RSAPublicKey) -> str: def encrypt_string(string: str, public_key: rsa.RSAPublicKey) -> str:
"""Encrypt string, end return it base64 encoded.""" """Encrypt string, end return it base64 encoded."""
message = string.encode() message = string.encode()
@ -40,10 +44,11 @@ def encrypt_string(string: str, public_key: rsa.RSAPublicKey) -> str:
mgf=padding.MGF1(algorithm=hashes.SHA256()), mgf=padding.MGF1(algorithm=hashes.SHA256()),
algorithm=hashes.SHA256(), algorithm=hashes.SHA256(),
label=None, label=None,
) ),
) )
return base64.b64encode(ciphertext).decode() return base64.b64encode(ciphertext).decode()
def decode_string(ciphertext: str, private_key: rsa.RSAPrivateKey) -> str: def decode_string(ciphertext: str, private_key: rsa.RSAPrivateKey) -> str:
"""Decode a string. String must be base64 encoded.""" """Decode a string. String must be base64 encoded."""
decoded = base64.b64decode(ciphertext) decoded = base64.b64decode(ciphertext)
@ -52,18 +57,27 @@ def decode_string(ciphertext: str, private_key: rsa.RSAPrivateKey) -> str:
padding.OAEP( padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA256()), mgf=padding.MGF1(algorithm=hashes.SHA256()),
algorithm=hashes.SHA256(), algorithm=hashes.SHA256(),
label=None label=None,
)) ),
)
return decrypted.decode() return decrypted.decode()
def generate_private_key() -> rsa.RSAPrivateKey: def generate_private_key() -> rsa.RSAPrivateKey:
"""Generate private RSA key.""" """Generate private RSA key."""
private_key = rsa.generate_private_key(public_exponent=constants.RSA_PUBLIC_EXPONENT, key_size=constants.RSA_KEY_SIZE) private_key = rsa.generate_private_key(
public_exponent=constants.RSA_PUBLIC_EXPONENT, key_size=constants.RSA_KEY_SIZE
)
return private_key return private_key
def generate_pem(private_key: rsa.RSAPrivateKey) -> str: def generate_pem(private_key: rsa.RSAPrivateKey) -> str:
"""Generate PEM.""" """Generate PEM."""
pem = private_key.private_bytes(encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.OpenSSH, encryption_algorithm=serialization.NoEncryption()) pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.OpenSSH,
encryption_algorithm=serialization.NoEncryption(),
)
return pem.decode() return pem.decode()
@ -74,7 +88,11 @@ def create_private_rsa_key(filename: Path) -> None:
LOG.debug("Generating private RSA key at %s", filename) LOG.debug("Generating private RSA key at %s", filename)
private_key = generate_private_key() private_key = generate_private_key()
with open(filename, "wb") as f: with open(filename, "wb") as f:
pem = private_key.private_bytes(encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.OpenSSH, encryption_algorithm=serialization.NoEncryption()) pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.OpenSSH,
encryption_algorithm=serialization.NoEncryption(),
)
lines = f.write(pem) lines = f.write(pem)
LOG.debug("Wrote %s lines", lines) LOG.debug("Wrote %s lines", lines)
f.flush() f.flush()
@ -82,5 +100,8 @@ def create_private_rsa_key(filename: Path) -> None:
def generate_public_key_string(public_key: rsa.RSAPublicKey) -> str: def generate_public_key_string(public_key: rsa.RSAPublicKey) -> str:
"""Generate public key string.""" """Generate public key string."""
keybytes = public_key.public_bytes(encoding=serialization.Encoding.OpenSSH, format=serialization.PublicFormat.OpenSSH) keybytes = public_key.public_bytes(
encoding=serialization.Encoding.OpenSSH,
format=serialization.PublicFormat.OpenSSH,
)
return keybytes.decode() return keybytes.decode()

View File

@ -20,29 +20,39 @@ def thread_id_filter(record: logging.LogRecord) -> logging.LogRecord:
record.thread_id = threading.get_native_id() record.thread_id = threading.get_native_id()
return record return record
LOG = logging.getLogger() LOG = logging.getLogger()
handler = logging.StreamHandler() handler = logging.StreamHandler()
handler.addFilter(thread_id_filter) handler.addFilter(thread_id_filter)
formatter = logging.Formatter("%(thread_id)d:%(created)f:%(levelname)s:%(name)s:%(module)s:%(message)s") formatter = logging.Formatter(
"%(thread_id)d:%(created)f:%(levelname)s:%(name)s:%(module)s:%(message)s"
)
handler.setFormatter(formatter) handler.setFormatter(formatter)
LOG.addHandler(handler) LOG.addHandler(handler)
LOG.setLevel(logging.DEBUG) LOG.setLevel(logging.DEBUG)
@click.group() @click.group()
def cli() -> None: def cli() -> None:
"""Run commands for testing.""" """Run commands for testing."""
@cli.command("create-client") @cli.command("create-client")
@click.argument("name") @click.argument("name")
@click.argument("filename", type=click.Path(file_okay=True, dir_okay=False, writable=True)) @click.argument(
"filename", type=click.Path(file_okay=True, dir_okay=False, writable=True)
)
@click.option("--public-key", type=click.Path(file_okay=True)) @click.option("--public-key", type=click.Path(file_okay=True))
def create_client(name: str, filename: str, public_key: str | None) -> None: def create_client(name: str, filename: str, public_key: str | None) -> None:
"""Create a client.""" """Create a client."""
create_client_file(name, filename, keyfile=public_key) create_client_file(name, filename, keyfile=public_key)
click.echo(f"Wrote client config to {filename}") click.echo(f"Wrote client config to {filename}")
@cli.command("add-secret") @cli.command("add-secret")
@click.argument("filename", type=click.Path(file_okay=True, dir_okay=False, writable=True)) @click.argument(
"filename", type=click.Path(file_okay=True, dir_okay=False, writable=True)
)
@click.argument("secret-name") @click.argument("secret-name")
@click.argument("secret-value") @click.argument("secret-value")
def add_secret(filename: str, secret_name: str, secret_value: str) -> None: def add_secret(filename: str, secret_name: str, secret_value: str) -> None:
@ -50,6 +60,7 @@ def add_secret(filename: str, secret_name: str, secret_value: str) -> None:
add_secret_to_client_file(filename, secret_name, secret_value) add_secret_to_client_file(filename, secret_name, secret_value)
click.echo(f"Wrote secret to {filename}") click.echo(f"Wrote secret to {filename}")
@cli.command("server") @cli.command("server")
@click.argument("directory", type=click.Path(file_okay=False, dir_okay=True)) @click.argument("directory", type=click.Path(file_okay=False, dir_okay=True))
@click.argument("port", type=click.INT) @click.argument("port", type=click.INT)
@ -70,6 +81,5 @@ def run_async_server(directory: str, port: int) -> None:
loop.run_forever() loop.run_forever()
if __name__ == "__main__": if __name__ == "__main__":
cli() cli()

View File

@ -1,4 +1,5 @@
"""Keepass integration.""" """Keepass integration."""
import logging import logging
from pathlib import Path from pathlib import Path
from typing import final, override, Self from typing import final, override, Self
@ -10,6 +11,7 @@ from .utils import generate_password
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@final @final
class KeepassManager(BasePasswordManager): class KeepassManager(BasePasswordManager):
"""KeepassXC compatible password manager.""" """KeepassXC compatible password manager."""
@ -35,7 +37,9 @@ class KeepassManager(BasePasswordManager):
@override @override
@classmethod @classmethod
def create_database(cls, location: str, reader_context: PasswordContext, overwrite: bool = False) -> Self: def create_database(
cls, location: str, reader_context: PasswordContext, overwrite: bool = False
) -> Self:
"""Create database.""" """Create database."""
if Path(location).exists() and not overwrite: if Path(location).exists() and not overwrite:
raise RuntimeError("Error: Database exists.") raise RuntimeError("Error: Database exists.")
@ -74,7 +78,9 @@ class KeepassManager(BasePasswordManager):
"""Generate password.""" """Generate password."""
# Generate a password. # Generate a password.
password = generate_password() password = generate_password()
_entry = self.keepass.add_entry(self.keepass.root_group, identifier, constants.NO_USERNAME, password) _entry = self.keepass.add_entry(
self.keepass.root_group, identifier, constants.NO_USERNAME, password
)
self.keepass.save() self.keepass.save()
LOG.debug("Created Entry %r", _entry) LOG.debug("Created Entry %r", _entry)
return password return password

View File

@ -1,6 +1,5 @@
"""Sshecret server module.""" """Sshecret server module."""
from .server import SshKeyServer from .server import SshKeyServer
__all__ = ["SshKeyServer"] __all__ = ["SshKeyServer"]

View File

@ -41,7 +41,9 @@ def handle_client(process: asyncssh.SSHServerProcess[str]) -> None:
process.exit(1) process.exit(1)
return return
LOG.debug("Client %s successfully connected. Fetching secret %s", client.name, secret_name) LOG.debug(
"Client %s successfully connected. Fetching secret %s", client.name, secret_name
)
secret = client.secrets.get(secret_name) secret = client.secrets.get(secret_name)
if not secret: if not secret:
@ -52,6 +54,7 @@ def handle_client(process: asyncssh.SSHServerProcess[str]) -> None:
process.stdout.write(secret) process.stdout.write(secret)
process.exit(0) process.exit(0)
class AsshyncServer(asyncssh.SSHServer): class AsshyncServer(asyncssh.SSHServer):
"""Asynchronous SSH server implementation.""" """Asynchronous SSH server implementation."""
@ -92,12 +95,11 @@ class AsshyncServer(asyncssh.SSHServer):
"""Deny password authentication.""" """Deny password authentication."""
return False return False
def check_connection_allowed(
def check_connection_allowed(self, client: ClientSpecification, source: str) -> bool: self, client: ClientSpecification, source: str
) -> bool:
"""Check if client is allowed to request secrets.""" """Check if client is allowed to request secrets."""
LOG.debug( LOG.debug("Checking if client is allowed to log in from %s", source)
"Checking if client is allowed to log in from %s", source
)
if isinstance(client.allowed_ips, str) and client.allowed_ips == "*": if isinstance(client.allowed_ips, str) and client.allowed_ips == "*":
LOG.debug("Client has no restrictions on source IP address. Permitting.") LOG.debug("Client has no restrictions on source IP address. Permitting.")
return True return True
@ -109,14 +111,19 @@ class AsshyncServer(asyncssh.SSHServer):
LOG.warning( LOG.warning(
"Connection for client %s received from IP address %s that is not permitted.", "Connection for client %s received from IP address %s that is not permitted.",
client.name, client.name,
source source,
) )
return False return False
async def start_server(port: int, backend: BaseClientBackend, host_key: str, create_key: bool = False) -> None:
async def start_server(
port: int, backend: BaseClientBackend, host_key: str, create_key: bool = False
) -> None:
"""Start server.""" """Start server."""
server = partial(AsshyncServer, backend=backend) server = partial(AsshyncServer, backend=backend)
if create_key: if create_key:
create_private_rsa_key(Path(host_key)) create_private_rsa_key(Path(host_key))
await asyncssh.create_server(server, '', port, server_host_keys=[host_key], process_factory=handle_client) await asyncssh.create_server(
server, "", port, server_host_keys=[host_key], process_factory=handle_client
)

View File

@ -1,5 +1,6 @@
"""Server errors.""" """Server errors."""
class BaseSshecretServerError(Exception): class BaseSshecretServerError(Exception):
"""Base Sshecret Server Error.""" """Base Sshecret Server Error."""

View File

@ -1,6 +1,5 @@
"""Testing utilities and classes.""" """Testing utilities and classes."""
import tempfile import tempfile
from dataclasses import dataclass, field from dataclasses import dataclass, field
from contextlib import contextmanager from contextlib import contextmanager

View File

@ -97,7 +97,13 @@ class BaseClientBackend(abc.ABC):
"""Lookup a client specification by name.""" """Lookup a client specification by name."""
@abc.abstractmethod @abc.abstractmethod
def add_secret(self, client_name: str, secret_name: str, secret_value: str, encrypted: bool = False) -> None: def add_secret(
self,
client_name: str,
secret_name: str,
secret_value: str,
encrypted: bool = False,
) -> None:
"""Add a secret to a client.""" """Add a secret to a client."""
@abc.abstractmethod @abc.abstractmethod

View File

@ -4,16 +4,24 @@ import secrets
from pathlib import Path from pathlib import Path
from .crypto import load_client_key, encrypt_string, generate_private_key, generate_pem, generate_public_key_string from .crypto import (
load_client_key,
encrypt_string,
generate_private_key,
generate_pem,
generate_public_key_string,
)
from .types import ClientSpecification from .types import ClientSpecification
def generate_password() -> str: def generate_password() -> str:
"""Generate a password. """Generate a password."""
"""
return secrets.token_urlsafe(32) return secrets.token_urlsafe(32)
def generate_client_object(name: str, secrets: dict[str, str] | None = None, keyfile: str | None = None) -> ClientSpecification: def generate_client_object(
name: str, secrets: dict[str, str] | None = None, keyfile: str | None = None
) -> ClientSpecification:
"""Generate a client object.""" """Generate a client object."""
private_key = generate_private_key() private_key = generate_private_key()
if keyfile: if keyfile:
@ -22,21 +30,28 @@ def generate_client_object(name: str, secrets: dict[str, str] | None = None, key
if not contents.startswith("ssh-rsa "): if not contents.startswith("ssh-rsa "):
raise RuntimeError("Error: Key must be an RSA key.") raise RuntimeError("Error: Key must be an RSA key.")
client = ClientSpecification(name=name, public_key=contents.strip()) client = ClientSpecification(name=name, public_key=contents.strip())
public_key = load_client_key(client) public_key = load_client_key(client)
else: else:
pem = generate_pem(private_key) pem = generate_pem(private_key)
public_key = private_key.public_key() public_key = private_key.public_key()
pubkey_str = generate_public_key_string(public_key) pubkey_str = generate_public_key_string(public_key)
client = ClientSpecification(name=name, public_key=pubkey_str, testing_private_key=pem) client = ClientSpecification(
name=name, public_key=pubkey_str, testing_private_key=pem
)
if secrets: if secrets:
for secret_name, secret_value in secrets.items(): for secret_name, secret_value in secrets.items():
client.secrets[secret_name] = encrypt_string(secret_value, public_key) client.secrets[secret_name] = encrypt_string(secret_value, public_key)
return client return client
def create_client_file(name: str, filename: Path | str, secrets: dict[str, str] | None = None, keyfile: str | None = None) -> None:
def create_client_file(
name: str,
filename: Path | str,
secrets: dict[str, str] | None = None,
keyfile: str | None = None,
) -> None:
"""Create client file.""" """Create client file."""
client = generate_client_object(name, secrets, keyfile) client = generate_client_object(name, secrets, keyfile)
@ -45,7 +60,9 @@ def create_client_file(name: str, filename: Path | str, secrets: dict[str, str]
f.flush() f.flush()
def add_secret_to_client_file(filename: str | Path, secret_name: str, secret_value: str) -> None: def add_secret_to_client_file(
filename: str | Path, secret_name: str, secret_value: str
) -> None:
"""Add secret to client file.""" """Add secret to client file."""
with open(filename, "r") as f: with open(filename, "r") as f:
client = ClientSpecification.model_validate_json(f.read()) client = ClientSpecification.model_validate_json(f.read())