Re-format files
This commit is contained in:
@ -85,9 +85,7 @@ class FileTableBackend(BaseClientBackend):
|
||||
encrypted: bool = False,
|
||||
) -> None:
|
||||
"""Add secret."""
|
||||
client: ClientSpecification = self.table.by.name[
|
||||
client_name
|
||||
] # pyright: ignore[reportAssignmentType]
|
||||
client: ClientSpecification = self.table.by.name[client_name] # pyright: ignore[reportAssignmentType]
|
||||
if not encrypted:
|
||||
public_key = load_client_key(client)
|
||||
secret_value = encrypt_string(secret_value, public_key)
|
||||
|
||||
@ -5,6 +5,7 @@ import click
|
||||
|
||||
from sshecret.crypto import decode_string, load_private_key
|
||||
|
||||
|
||||
def decrypt_secret(encoded: str, client_key: str) -> str:
|
||||
"""Decrypt secret."""
|
||||
private_key = load_private_key(client_key)
|
||||
@ -19,5 +20,6 @@ def cli_decrypt(keyfile: str, encrypted_input: TextIO) -> None:
|
||||
decrypted = decrypt_secret(encrypted_input.read(), keyfile)
|
||||
click.echo(decrypted)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli_decrypt()
|
||||
|
||||
@ -7,7 +7,9 @@ VAR_PREFIX = "SSHECRET"
|
||||
ERROR_NO_SECRET_FOUND = "Error: No secret available with the given name."
|
||||
ERROR_UKNOWN_CLIENT_OR_SECRET = "Error: Invalid client or secret name."
|
||||
ERROR_NO_COMMAND_RECEIVED = "Error: No command was received from the client."
|
||||
ERROR_SOURCE_IP_NOT_ALLOWED = "Error: Client not authorized to connect from the given host."
|
||||
ERROR_SOURCE_IP_NOT_ALLOWED = (
|
||||
"Error: Client not authorized to connect from the given host."
|
||||
)
|
||||
|
||||
RSA_PUBLIC_EXPONENT = 65537
|
||||
RSA_KEY_SIZE = 2048
|
||||
|
||||
@ -12,17 +12,20 @@ from . import constants
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_client_key(client: ClientSpecification) -> rsa.RSAPublicKey:
|
||||
"""Load public key."""
|
||||
keybytes = client.public_key.encode()
|
||||
return load_public_key(keybytes)
|
||||
|
||||
|
||||
def load_public_key(keybytes: bytes) -> rsa.RSAPublicKey:
|
||||
public_key = serialization.load_ssh_public_key(keybytes)
|
||||
if not isinstance(public_key, rsa.RSAPublicKey):
|
||||
raise RuntimeError("Only RSA keys are supported.")
|
||||
return public_key
|
||||
|
||||
|
||||
def load_private_key(filename: str) -> rsa.RSAPrivateKey:
|
||||
"""Load a private key."""
|
||||
with open(filename, "rb") as f:
|
||||
@ -31,6 +34,7 @@ def load_private_key(filename: str) -> rsa.RSAPrivateKey:
|
||||
raise RuntimeError("Only RSA keys are supported.")
|
||||
return private_key
|
||||
|
||||
|
||||
def encrypt_string(string: str, public_key: rsa.RSAPublicKey) -> str:
|
||||
"""Encrypt string, end return it base64 encoded."""
|
||||
message = string.encode()
|
||||
@ -40,10 +44,11 @@ def encrypt_string(string: str, public_key: rsa.RSAPublicKey) -> str:
|
||||
mgf=padding.MGF1(algorithm=hashes.SHA256()),
|
||||
algorithm=hashes.SHA256(),
|
||||
label=None,
|
||||
)
|
||||
),
|
||||
)
|
||||
return base64.b64encode(ciphertext).decode()
|
||||
|
||||
|
||||
def decode_string(ciphertext: str, private_key: rsa.RSAPrivateKey) -> str:
|
||||
"""Decode a string. String must be base64 encoded."""
|
||||
decoded = base64.b64decode(ciphertext)
|
||||
@ -52,18 +57,27 @@ def decode_string(ciphertext: str, private_key: rsa.RSAPrivateKey) -> str:
|
||||
padding.OAEP(
|
||||
mgf=padding.MGF1(algorithm=hashes.SHA256()),
|
||||
algorithm=hashes.SHA256(),
|
||||
label=None
|
||||
))
|
||||
label=None,
|
||||
),
|
||||
)
|
||||
return decrypted.decode()
|
||||
|
||||
|
||||
def generate_private_key() -> rsa.RSAPrivateKey:
|
||||
"""Generate private RSA key."""
|
||||
private_key = rsa.generate_private_key(public_exponent=constants.RSA_PUBLIC_EXPONENT, key_size=constants.RSA_KEY_SIZE)
|
||||
private_key = rsa.generate_private_key(
|
||||
public_exponent=constants.RSA_PUBLIC_EXPONENT, key_size=constants.RSA_KEY_SIZE
|
||||
)
|
||||
return private_key
|
||||
|
||||
|
||||
def generate_pem(private_key: rsa.RSAPrivateKey) -> str:
|
||||
"""Generate PEM."""
|
||||
pem = private_key.private_bytes(encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.OpenSSH, encryption_algorithm=serialization.NoEncryption())
|
||||
pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.OpenSSH,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
return pem.decode()
|
||||
|
||||
|
||||
@ -74,7 +88,11 @@ def create_private_rsa_key(filename: Path) -> None:
|
||||
LOG.debug("Generating private RSA key at %s", filename)
|
||||
private_key = generate_private_key()
|
||||
with open(filename, "wb") as f:
|
||||
pem = private_key.private_bytes(encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.OpenSSH, encryption_algorithm=serialization.NoEncryption())
|
||||
pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.OpenSSH,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
lines = f.write(pem)
|
||||
LOG.debug("Wrote %s lines", lines)
|
||||
f.flush()
|
||||
@ -82,5 +100,8 @@ def create_private_rsa_key(filename: Path) -> None:
|
||||
|
||||
def generate_public_key_string(public_key: rsa.RSAPublicKey) -> str:
|
||||
"""Generate public key string."""
|
||||
keybytes = public_key.public_bytes(encoding=serialization.Encoding.OpenSSH, format=serialization.PublicFormat.OpenSSH)
|
||||
keybytes = public_key.public_bytes(
|
||||
encoding=serialization.Encoding.OpenSSH,
|
||||
format=serialization.PublicFormat.OpenSSH,
|
||||
)
|
||||
return keybytes.decode()
|
||||
|
||||
@ -20,29 +20,39 @@ def thread_id_filter(record: logging.LogRecord) -> logging.LogRecord:
|
||||
record.thread_id = threading.get_native_id()
|
||||
return record
|
||||
|
||||
|
||||
LOG = logging.getLogger()
|
||||
handler = logging.StreamHandler()
|
||||
handler.addFilter(thread_id_filter)
|
||||
formatter = logging.Formatter("%(thread_id)d:%(created)f:%(levelname)s:%(name)s:%(module)s:%(message)s")
|
||||
formatter = logging.Formatter(
|
||||
"%(thread_id)d:%(created)f:%(levelname)s:%(name)s:%(module)s:%(message)s"
|
||||
)
|
||||
handler.setFormatter(formatter)
|
||||
LOG.addHandler(handler)
|
||||
LOG.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
@click.group()
|
||||
def cli() -> None:
|
||||
"""Run commands for testing."""
|
||||
|
||||
|
||||
@cli.command("create-client")
|
||||
@click.argument("name")
|
||||
@click.argument("filename", type=click.Path(file_okay=True, dir_okay=False, writable=True))
|
||||
@click.argument(
|
||||
"filename", type=click.Path(file_okay=True, dir_okay=False, writable=True)
|
||||
)
|
||||
@click.option("--public-key", type=click.Path(file_okay=True))
|
||||
def create_client(name: str, filename: str, public_key: str | None) -> None:
|
||||
"""Create a client."""
|
||||
create_client_file(name, filename, keyfile=public_key)
|
||||
click.echo(f"Wrote client config to {filename}")
|
||||
|
||||
|
||||
@cli.command("add-secret")
|
||||
@click.argument("filename", type=click.Path(file_okay=True, dir_okay=False, writable=True))
|
||||
@click.argument(
|
||||
"filename", type=click.Path(file_okay=True, dir_okay=False, writable=True)
|
||||
)
|
||||
@click.argument("secret-name")
|
||||
@click.argument("secret-value")
|
||||
def add_secret(filename: str, secret_name: str, secret_value: str) -> None:
|
||||
@ -50,6 +60,7 @@ def add_secret(filename: str, secret_name: str, secret_value: str) -> None:
|
||||
add_secret_to_client_file(filename, secret_name, secret_value)
|
||||
click.echo(f"Wrote secret to {filename}")
|
||||
|
||||
|
||||
@cli.command("server")
|
||||
@click.argument("directory", type=click.Path(file_okay=False, dir_okay=True))
|
||||
@click.argument("port", type=click.INT)
|
||||
@ -70,6 +81,5 @@ def run_async_server(directory: str, port: int) -> None:
|
||||
loop.run_forever()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""Keepass integration."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import final, override, Self
|
||||
@ -10,6 +11,7 @@ from .utils import generate_password
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class KeepassManager(BasePasswordManager):
|
||||
"""KeepassXC compatible password manager."""
|
||||
@ -35,7 +37,9 @@ class KeepassManager(BasePasswordManager):
|
||||
|
||||
@override
|
||||
@classmethod
|
||||
def create_database(cls, location: str, reader_context: PasswordContext, overwrite: bool = False) -> Self:
|
||||
def create_database(
|
||||
cls, location: str, reader_context: PasswordContext, overwrite: bool = False
|
||||
) -> Self:
|
||||
"""Create database."""
|
||||
if Path(location).exists() and not overwrite:
|
||||
raise RuntimeError("Error: Database exists.")
|
||||
@ -74,7 +78,9 @@ class KeepassManager(BasePasswordManager):
|
||||
"""Generate password."""
|
||||
# Generate a password.
|
||||
password = generate_password()
|
||||
_entry = self.keepass.add_entry(self.keepass.root_group, identifier, constants.NO_USERNAME, password)
|
||||
_entry = self.keepass.add_entry(
|
||||
self.keepass.root_group, identifier, constants.NO_USERNAME, password
|
||||
)
|
||||
self.keepass.save()
|
||||
LOG.debug("Created Entry %r", _entry)
|
||||
return password
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
"""Sshecret server module."""
|
||||
|
||||
|
||||
from .server import SshKeyServer
|
||||
|
||||
__all__ = ["SshKeyServer"]
|
||||
|
||||
@ -41,7 +41,9 @@ def handle_client(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||
process.exit(1)
|
||||
return
|
||||
|
||||
LOG.debug("Client %s successfully connected. Fetching secret %s", client.name, secret_name)
|
||||
LOG.debug(
|
||||
"Client %s successfully connected. Fetching secret %s", client.name, secret_name
|
||||
)
|
||||
|
||||
secret = client.secrets.get(secret_name)
|
||||
if not secret:
|
||||
@ -52,6 +54,7 @@ def handle_client(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||
process.stdout.write(secret)
|
||||
process.exit(0)
|
||||
|
||||
|
||||
class AsshyncServer(asyncssh.SSHServer):
|
||||
"""Asynchronous SSH server implementation."""
|
||||
|
||||
@ -92,12 +95,11 @@ class AsshyncServer(asyncssh.SSHServer):
|
||||
"""Deny password authentication."""
|
||||
return False
|
||||
|
||||
|
||||
def check_connection_allowed(self, client: ClientSpecification, source: str) -> bool:
|
||||
def check_connection_allowed(
|
||||
self, client: ClientSpecification, source: str
|
||||
) -> bool:
|
||||
"""Check if client is allowed to request secrets."""
|
||||
LOG.debug(
|
||||
"Checking if client is allowed to log in from %s", source
|
||||
)
|
||||
LOG.debug("Checking if client is allowed to log in from %s", source)
|
||||
if isinstance(client.allowed_ips, str) and client.allowed_ips == "*":
|
||||
LOG.debug("Client has no restrictions on source IP address. Permitting.")
|
||||
return True
|
||||
@ -109,14 +111,19 @@ class AsshyncServer(asyncssh.SSHServer):
|
||||
LOG.warning(
|
||||
"Connection for client %s received from IP address %s that is not permitted.",
|
||||
client.name,
|
||||
source
|
||||
source,
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
async def start_server(port: int, backend: BaseClientBackend, host_key: str, create_key: bool = False) -> None:
|
||||
|
||||
async def start_server(
|
||||
port: int, backend: BaseClientBackend, host_key: str, create_key: bool = False
|
||||
) -> None:
|
||||
"""Start server."""
|
||||
server = partial(AsshyncServer, backend=backend)
|
||||
if create_key:
|
||||
create_private_rsa_key(Path(host_key))
|
||||
await asyncssh.create_server(server, '', port, server_host_keys=[host_key], process_factory=handle_client)
|
||||
await asyncssh.create_server(
|
||||
server, "", port, server_host_keys=[host_key], process_factory=handle_client
|
||||
)
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
"""Server errors."""
|
||||
|
||||
|
||||
class BaseSshecretServerError(Exception):
|
||||
"""Base Sshecret Server Error."""
|
||||
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
"""Testing utilities and classes."""
|
||||
|
||||
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from contextlib import contextmanager
|
||||
|
||||
@ -97,7 +97,13 @@ class BaseClientBackend(abc.ABC):
|
||||
"""Lookup a client specification by name."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def add_secret(self, client_name: str, secret_name: str, secret_value: str, encrypted: bool = False) -> None:
|
||||
def add_secret(
|
||||
self,
|
||||
client_name: str,
|
||||
secret_name: str,
|
||||
secret_value: str,
|
||||
encrypted: bool = False,
|
||||
) -> None:
|
||||
"""Add a secret to a client."""
|
||||
|
||||
@abc.abstractmethod
|
||||
|
||||
@ -4,16 +4,24 @@ import secrets
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from .crypto import load_client_key, encrypt_string, generate_private_key, generate_pem, generate_public_key_string
|
||||
from .crypto import (
|
||||
load_client_key,
|
||||
encrypt_string,
|
||||
generate_private_key,
|
||||
generate_pem,
|
||||
generate_public_key_string,
|
||||
)
|
||||
from .types import ClientSpecification
|
||||
|
||||
|
||||
def generate_password() -> str:
|
||||
"""Generate a password.
|
||||
"""
|
||||
"""Generate a password."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
def generate_client_object(name: str, secrets: dict[str, str] | None = None, keyfile: str | None = None) -> ClientSpecification:
|
||||
def generate_client_object(
|
||||
name: str, secrets: dict[str, str] | None = None, keyfile: str | None = None
|
||||
) -> ClientSpecification:
|
||||
"""Generate a client object."""
|
||||
private_key = generate_private_key()
|
||||
if keyfile:
|
||||
@ -22,21 +30,28 @@ def generate_client_object(name: str, secrets: dict[str, str] | None = None, key
|
||||
if not contents.startswith("ssh-rsa "):
|
||||
raise RuntimeError("Error: Key must be an RSA key.")
|
||||
|
||||
|
||||
client = ClientSpecification(name=name, public_key=contents.strip())
|
||||
public_key = load_client_key(client)
|
||||
else:
|
||||
pem = generate_pem(private_key)
|
||||
public_key = private_key.public_key()
|
||||
pubkey_str = generate_public_key_string(public_key)
|
||||
client = ClientSpecification(name=name, public_key=pubkey_str, testing_private_key=pem)
|
||||
client = ClientSpecification(
|
||||
name=name, public_key=pubkey_str, testing_private_key=pem
|
||||
)
|
||||
if secrets:
|
||||
for secret_name, secret_value in secrets.items():
|
||||
client.secrets[secret_name] = encrypt_string(secret_value, public_key)
|
||||
|
||||
return client
|
||||
|
||||
def create_client_file(name: str, filename: Path | str, secrets: dict[str, str] | None = None, keyfile: str | None = None) -> None:
|
||||
|
||||
def create_client_file(
|
||||
name: str,
|
||||
filename: Path | str,
|
||||
secrets: dict[str, str] | None = None,
|
||||
keyfile: str | None = None,
|
||||
) -> None:
|
||||
"""Create client file."""
|
||||
client = generate_client_object(name, secrets, keyfile)
|
||||
|
||||
@ -45,7 +60,9 @@ def create_client_file(name: str, filename: Path | str, secrets: dict[str, str]
|
||||
f.flush()
|
||||
|
||||
|
||||
def add_secret_to_client_file(filename: str | Path, secret_name: str, secret_value: str) -> None:
|
||||
def add_secret_to_client_file(
|
||||
filename: str | Path, secret_name: str, secret_value: str
|
||||
) -> None:
|
||||
"""Add secret to client file."""
|
||||
with open(filename, "r") as f:
|
||||
client = ClientSpecification.model_validate_json(f.read())
|
||||
|
||||
Reference in New Issue
Block a user