diff --git a/src/sshecret/__init__.py b/src/sshecret/__init__.py index 0e57d6a..e69de29 100644 --- a/src/sshecret/__init__.py +++ b/src/sshecret/__init__.py @@ -1,2 +0,0 @@ -def hello() -> str: - return "Hello from sshecret!" diff --git a/src/sshecret/api.py b/src/sshecret/api.py deleted file mode 100644 index 8f577c1..0000000 --- a/src/sshecret/api.py +++ /dev/null @@ -1,447 +0,0 @@ -"""API. - -This module is an attempt to create some sort of meaningfull API around the -actions exposed here. -""" - -import abc - -from contextlib import contextmanager -from collections.abc import Iterator - -from pydantic.networks import IPvAnyAddress, IPvAnyNetwork - -from .audit import audit_message - -from .crypto import load_client_key, load_public_key, encrypt_string - -from .types import ( - BaseAPIClient, - BaseClientBackend, - BasePasswordManager, - BasePasswordReader, - ClientSpecification, - PasswordContext, -) - - -@contextmanager -def password_manager_session( - password_manager: BasePasswordManager, - password_context: PasswordContext | str, - api_client: BaseAPIClient, -) -> Iterator[BasePasswordManager]: - """Open password manager for read/write in a context.""" - audit_message( - "Opening password manager session", - "SECURITY", - source_address=api_client.source, - ) - password_manager.open_database(password_context) - yield password_manager - - audit_message( - "Closing password manager session", - "SECURITY", - source_address=api_client.source, - ) - password_manager.close_database() - - -class BaseSshecretAPI(abc.ABC): - """Base API class.""" - - def __init__( - self, - backend: BaseClientBackend, - api_client: BaseAPIClient, - manager_options: dict[str, str] | None = None, - ) -> None: - """Initialize API.""" - - self.backend: BaseClientBackend = backend - self.api_client: BaseAPIClient = api_client - self.manager_options: dict[str, str] | None = manager_options - - def _log_audit( - self, - message: str, - audit_type: str, - client_name: str | None = None, - **details: str, - ) -> None: - """Log an audit message.""" - audit_message( - message, - audit_type, - client_name, - source_address=self.api_client.source, - **details, - ) - - @contextmanager - def password_session( - self, reader: BasePasswordReader | None = None, password: str | None = None - ) -> Iterator[BasePasswordManager]: - """Open a password session.""" - if password: - context = password - else: - if not reader: - reader = self.api_client.get_reader() - context = self.api_client.get_context(reader) - - password_manager = self.api_client.password_manager(self.manager_options) - with password_manager_session( - password_manager, context, self.api_client - ) as session: - yield session - - -class ClientManagementAPI(BaseSshecretAPI): - """API for managing clients.""" - - def __init__( - self, - backend: BaseClientBackend, - client: ClientSpecification, - api_client: BaseAPIClient, - manager_options: dict[str, str] | None = None, - ) -> None: - """Create client management API instance.""" - super().__init__(backend, api_client, manager_options) - self.client: ClientSpecification = client - self.__password_manager: BasePasswordManager | None = None - - def log_security(self, message: str, **details: str) -> None: - """Log a security related message.""" - self._log_audit(message, "SECURITY", self.client.name, **details) - - def log_info(self, message: str) -> None: - """Log an informational message.""" - self._log_audit(message, "INFORMATIONAL", self.client.name) - - @property - def password_manager(self) -> BasePasswordManager: - """Get password manager.""" - if self.__password_manager: - self.log_security("Accessed password manager") - return self.__password_manager - raise RuntimeError("Password manager not initialized.") - - @password_manager.setter - def password_manager(self, instance: BasePasswordManager) -> None: - """Set password manager instance.""" - self.log_security("Opened password manager.") - self.__password_manager = instance - - def get_secrets(self) -> list[str]: - """Get names of the secrets that the client has access to..""" - self.log_security("Listing secret names.") - return list(self.client.secrets.keys()) - - def update_client( - self, - client: ClientSpecification, - reader: BasePasswordReader | None = None, - password: str | None = None, - ) -> ClientSpecification: - """Update client.""" - self.log_info("Updating client") - update_data = client.model_dump(exclude_unset=True) - if client.public_key != self.client.public_key: - self.log_security("Client public key has changed.") - if client.secrets != self.client.secrets: - raise RuntimeError( - "Error: Cannot update public key and secrets in the same operation." - ) - del update_data["secrets"] - secrets = self._re_encrypt(client.public_key, reader, password) - update_data["secrets"] = secrets - - updated_client = self.client.model_copy(update=update_data) - self.backend.update_client(self.client.name, updated_client) - self.client = updated_client - return updated_client - - def update_secret(self, name: str, password: str) -> None: - """Update a secret. - - If secret is not already a part of the client, it will be added. - """ - if name in self.client.secrets: - self.log_security("Updating secret", secret_name=name) - else: - self.log_security("Adding secret", secret_name=name) - public_key = load_client_key(self.client) - encrypted = encrypt_string(password, public_key) - client_secrets = {**self.client.secrets, name: encrypted} - updated_client = self.client.model_copy(update={"secrets": client_secrets}) - self.backend.update_client(self.client.name, updated_client) - self.client = updated_client - - def _re_encrypt( - self, - new_key: str, - reader: BasePasswordReader | None = None, - password: str | None = None, - ) -> dict[str, str]: - """Update public key on given client.""" - audit_message( - "Updating public key", - "INFORMATIONAL", - self.client.name, - source_address=self.api_client.source, - ) - if password: - context = password - else: - if not reader: - reader = self.api_client.get_reader() - context = self.api_client.get_context(reader) - - password_manager = self.api_client.password_manager(self.manager_options) - self.password_manager = password_manager - - client_key = load_public_key(new_key.encode()) - secrets: dict[str, str] = {} - with password_manager_session( - password_manager, context, self.api_client - ) as password_session: - for name in self.get_secrets(): - secret = password_session.get_password(name) - audit_message( - "Updating encrypted value", - "SECURITY", - self.client.name, - secret_name=name, - source_address=self.api_client.source, - ) - if not secret: - raise RuntimeError("Could not fetch a new secret value.") - new_value = encrypt_string(secret, client_key) - secrets[name] = new_value - - return secrets - - def update_public_key( - self, - new_key: str, - reader: BasePasswordReader | None = None, - password: str | None = None, - ) -> None: - """Update the public key.""" - audit_message( - "Updating public key", - "INFORMATIONAL", - self.client.name, - source_address=self.api_client.source, - ) - secrets = self._re_encrypt(new_key, reader, password) - - updated = self.client.model_copy(update={"secrets": secrets}) - self.backend.update_client(self.client.name, updated) - - @classmethod - def get_client( - cls, - backend: BaseClientBackend, - name: str, - api_client: BaseAPIClient, - ) -> "ClientManagementAPI | None": - """Get client.""" - client = backend.lookup_name(name) - if not client: - return None - return cls(backend, client, api_client) - - @classmethod - def create_client( - cls, - backend: BaseClientBackend, - name: str, - api_client: BaseAPIClient, - public_key: str, - allowed_ips: list[IPvAnyAddress | IPvAnyNetwork] | str = "*", - ) -> "ClientManagementAPI": - """Create a client.""" - client = ClientSpecification( - name=name, public_key=public_key, allowed_ips=allowed_ips - ) - backend.add_client(client) - return cls(backend, client, api_client) - - -class ManagementApi(BaseSshecretAPI): - """Api for general management.""" - - def __init__( - self, - backend: BaseClientBackend, - api_client: BaseAPIClient, - manager_options: dict[str, str] | None = None, - ) -> None: - """Initialize API.""" - super().__init__(backend, api_client, manager_options) - - def log_security( - self, message: str, client_name: str | None = None, **details: str - ) -> None: - """Log a security related message.""" - self._log_audit(message, "SECURITY", client_name, **details) - - def log_info(self, message: str, client_name: str | None = None) -> None: - """Log an informational message.""" - self._log_audit(message, "INFORMATIONAL", client_name) - - def get_client(self, name: str) -> ClientManagementAPI | None: - """Get a client.""" - client = self.backend.lookup_name(name) - if not client: - return None - return ClientManagementAPI( - self.backend, client, self.api_client, self.manager_options - ) - - def get_clients(self) -> list[ClientSpecification]: - """Get clients.""" - self.log_info("Fetched all clients") - return self.backend.get_all() - - def _get_clients(self) -> list[ClientSpecification]: - """Get clients.""" - return self.backend.get_all() - - def delete_client(self, name: str) -> None: - """Delete client.""" - client = self.backend.lookup_name(name) - if not client: - self.log_info("Attempted to delete a non-existing client.", name) - return - self.log_security("Deleting client", name) - self.backend.remove_client(name) - - def create_client( - self, - name: str, - public_key: str, - allowed_ips: list[IPvAnyAddress | IPvAnyNetwork] | str = "*", - ) -> ClientManagementAPI: - """Create a client.""" - self.log_info("Creating new client", name) - client = ClientSpecification( - name=name, public_key=public_key, allowed_ips=allowed_ips - ) - self.backend.add_client(client) - return ClientManagementAPI(self.backend, client, self.api_client) - - def get_secret_names( - self, reader: BasePasswordReader | None = None, password: str | None = None - ) -> dict[str, list[str]]: - """Get secret names and which clients have these..""" - self.log_security("Listing all secret names.") - with self.password_session(reader=reader, password=password) as session: - secret_names = session.get_entries() - - secret_mapping: dict[str, list[str]] = {} - for name in secret_names: - secret_mapping[name] = [ - client.name for client in self.backend.lookup_by_secret(name) - ] - - return secret_mapping - - def add_secret( - self, - name: str, - secret_value: str | None, - reader: BasePasswordReader | None = None, - password: str | None = None, - ) -> str: - """Add a secret.""" - self.log_security("Adding new secret", secret_name=name) - with self.password_session(reader=reader, password=password) as session: - if not secret_value: - self.log_security("Auto-generating a secret value", secret_name=name) - secret_value = session.generate_password(name) - else: - session.add_password(name, secret_value) - return secret_value - - def get_secret( - self, - name: str, - reader: BasePasswordReader | None = None, - password: str | None = None, - ) -> str | None: - """Get the clear-text value of a secret.""" - self.log_security("Client requested secret value", secret_name=name) - with self.password_session(reader=reader, password=password) as session: - secret = session.get_password(name) - - return secret - - def update_secret( - self, - name: str, - new_value: str, - reader: BasePasswordReader | None = None, - password: str | None = None, - ) -> None: - """Update a secret with a given name.""" - self.log_security("Changing secret", secret_name=name) - with self.password_session(reader=reader, password=password) as session: - session.change_password(name, new_value) - - clients = self.backend.lookup_by_secret(name) - for client in clients: - client_api = self.get_client(client.name) - if not client_api: - continue - client_api.update_secret(name, new_value) - - def regenerate_secret( - self, - name: str, - reader: BasePasswordReader | None = None, - password: str | None = None, - ) -> str: - """Regenerate a secret.""" - self.log_security("Generating a new secret value", secret_name=name) - with self.password_session(reader=reader, password=password) as session: - new_value = session.change_password(name, None) - - clients = self.backend.lookup_by_secret(name) - for client in clients: - client_api = self.get_client(client.name) - if not client_api: - continue - client_api.update_secret(name, new_value) - - return new_value - - def delete_secret( - self, - name: str, - reader: BasePasswordReader | None = None, - password: str | None = None, - ) -> None: - """Delete secret.""" - clients = self.backend.lookup_by_secret(name) - self.log_security("Deleting secret", secret_name=name) - with self.password_session(reader=reader, password=password) as session: - session.delete_password(name) - - for client in clients: - secrets = {**client.secrets} - del secrets[name] - new_client = client.model_copy(update={"secrets": secrets}) - client_api = self.get_client(client.name) - if not client_api: - continue - self.log_security( - "Removing secret from client.", - client_name=client.name, - secret_name=name, - ) - client_api.update_client(new_client, password=password) diff --git a/src/sshecret/audit.py b/src/sshecret/audit.py deleted file mode 100644 index 0894d1e..0000000 --- a/src/sshecret/audit.py +++ /dev/null @@ -1,66 +0,0 @@ -"""Audit setup.""" - -import enum -import json -import logging - -from typing import Any -from pythonjsonlogger.json import JsonFormatter - -from pydantic import BaseModel, ConfigDict - -from .constants import AUDIT_LOG_NAME - -AUDIT_LOG = logging.getLogger(AUDIT_LOG_NAME) - - -class AuditMessageType(enum.StrEnum): - """Audit Message Type.""" - - ACCESS = enum.auto() # Someone accessed something - SECURITY = enum.auto() # A message related to security - INFORMATIONAL = enum.auto() # other informational messages - - -class AuditMessage(BaseModel): - """Audit message.""" - - model_config = ConfigDict(use_enum_values=True) - - type: AuditMessageType - message: str - client_name: str | None = None - source_address: str | None = None - secret_name: str | None = None - - def __str__(self) -> str: - """Stringify object as JSON.""" - return self.model_dump_json() - - -def audit_message( - message: str, - audit_type: AuditMessageType | str | None = None, - client_name: str | None = None, - secret_name: str | None = None, - source_address: str | None = None, - **details: str -) -> None: - """Create an audit message.""" - if not audit_type: - audit_type = AuditMessageType.INFORMATIONAL - - if audit_type not in list(AuditMessageType): - audit_type = AuditMessageType.INFORMATIONAL - - audit_message = AuditMessage( - type=audit_type, - message=message, - client_name=client_name, - source_address=source_address, - secret_name=secret_name, - ) - - audit_dict = audit_message.model_dump(exclude_none=True) - - AUDIT_LOG.info({**audit_dict, **details}) diff --git a/src/sshecret/backends/__init__.py b/src/sshecret/backends/__init__.py deleted file mode 100644 index 9e79583..0000000 --- a/src/sshecret/backends/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Backend implementations""" - -from .file_table import FileTableBackend - -__all__ = ["FileTableBackend"] diff --git a/src/sshecret/backends/file_table.py b/src/sshecret/backends/file_table.py deleted file mode 100644 index 4883ec1..0000000 --- a/src/sshecret/backends/file_table.py +++ /dev/null @@ -1,133 +0,0 @@ -"""File table based backend.""" - -import logging -import os -from pathlib import Path -from typing import override - -import littletable as lt - -from sshecret.crypto import load_client_key, encrypt_string -from sshecret.types import ClientSpecification -from sshecret.types import BaseClientBackend - -LOG = logging.getLogger(__name__) - - -def load_clients_from_dir(directory: Path) -> dict[Path, ClientSpecification]: - """Load clients from a directory.""" - if not directory.exists() or not directory.is_dir(): - raise ValueError("Invalid directory specified.") - - clients: dict[Path, ClientSpecification] = {} - for client_file in directory.glob("*.json"): - with open(client_file, "r") as f: - client = ClientSpecification.model_validate_json(f.read()) - if client_file.name != f"{client.name}.json": - raise RuntimeError( - "Filename scheme of clients does not conform to expected format. Aborting import!" - ) - clients[client_file] = client - - return clients - - - -class FileTableBackend(BaseClientBackend): - """In-memory littletable based backend.""" - - def __init__(self, directory: Path) -> None: - """Create backend instance.""" - LOG.debug("Creating in-memory table to hold clients.") - self._directory: Path = directory - self.table: lt.Table[ClientSpecification] = lt.Table() - self._setup_table() - client_files = load_clients_from_dir(directory) - client_count = len(client_files) - LOG.debug("Loaded %s clients from disk.", client_count) - # self.client_file_map: dict[str, Path] = {client.name: filepath for filepath, client in client_files.items()} - LOG.debug("Inserting clients into table.") - self.table.insert_many(list(client_files.values())) - - def _setup_table(self) -> None: - """Set up the table.""" - self.table.create_index("name", unique=True) - - @override - def lookup_name(self, name: str) -> ClientSpecification | None: - """Lookup client by name.""" - if result := self.table.by.name.get(name): - if isinstance(result, ClientSpecification): - return result - return None - - @override - def add_client(self, spec: ClientSpecification) -> None: - """Add client.""" - self.table.insert(spec) - self._write_spec_file(spec) - - def _write_spec_file(self, spec: ClientSpecification) -> None: - """Write spec file to disk.""" - dest_file_name = f"{spec.name}.json" - dest_file = self._directory / dest_file_name - with open(dest_file.absolute(), "w") as f: - f.write( - spec.model_dump_json(exclude_none=True, exclude_unset=True, indent=2) - ) - f.flush() - - @override - def add_secret( - self, - client_name: str, - secret_name: str, - secret_value: str, - encrypted: bool = False, - ) -> None: - """Add secret.""" - client: ClientSpecification = self.table.by.name[client_name] # pyright: ignore[reportAssignmentType] - if not encrypted: - public_key = load_client_key(client) - secret_value = encrypt_string(secret_value, public_key) - client.secrets[secret_name] = secret_value - self._update_client_data(client) - self._write_spec_file(client) - - @override - def remove_client(self, name: str, persistent: bool = True) -> None: - """Delete client.""" - client = self.lookup_name(name) - if not client: - raise ValueError("Client does not exist!") - self.table.remove(client) - if persistent: - filename = f"{client.name}.json" - filepath = self._directory / filename - filepath.unlink() - - @override - def update_client(self, name: str, spec: ClientSpecification) -> None: - """Update client.""" - if not self.lookup_name(name): - raise ValueError("Client does not exist!") - self._update_client_data(spec) - self._write_spec_file(spec) - - def _update_client_data(self, spec: ClientSpecification) -> None: - """Update client data.""" - existing = self.lookup_name(spec.name) - if existing: - self.table.remove(existing) - self.add_client(spec) - - @override - def get_all(self) -> list[ClientSpecification]: - """Get all clients.""" - return list(self.table) - - @override - def lookup_by_secret(self, secret_name: str) -> list[ClientSpecification]: - """Lookup by secret name.""" - results = self.table.where(lambda client: secret_name in client.secrets) - return list(results) diff --git a/src/sshecret/cli.py b/src/sshecret/cli.py deleted file mode 100644 index 3b43108..0000000 --- a/src/sshecret/cli.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Command Line Interface""" - -import click diff --git a/src/sshecret/client.py b/src/sshecret/client.py deleted file mode 100644 index 1a23780..0000000 --- a/src/sshecret/client.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Client code""" - -from typing import TextIO -import click - -from sshecret.crypto import decode_string, load_private_key - - -def decrypt_secret(encoded: str, client_key: str) -> str: - """Decrypt secret.""" - private_key = load_private_key(client_key) - return decode_string(encoded, private_key) - - -@click.command() -@click.argument("keyfile", type=click.Path(exists=True, readable=True, dir_okay=False)) -@click.argument("encrypted_input", type=click.File("r")) -def cli_decrypt(keyfile: str, encrypted_input: TextIO) -> None: - """Decrypt on command line.""" - decrypted = decrypt_secret(encrypted_input.read(), keyfile) - click.echo(decrypted) - - -if __name__ == "__main__": - cli_decrypt() diff --git a/src/sshecret/config.py b/src/sshecret/config.py deleted file mode 100644 index 75915a4..0000000 --- a/src/sshecret/config.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Config file.""" -from pathlib import Path -from pydantic import SecretStr -from pydantic_settings import BaseSettings - - -class KeepassSettings(BaseSettings): - """Settings for Keepasss password database.""" - - database_path: Path - - -class SshecretSettings(BaseSettings): - """Settings model.""" - - admin_password: SecretStr - admin_ssh_key: str | None = None - keepass: KeepassSettings diff --git a/src/sshecret/constants.py b/src/sshecret/constants.py deleted file mode 100644 index e4e99da..0000000 --- a/src/sshecret/constants.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Constants.""" - -MASTER_PASSWORD = "MASTER_PASSWORD" -NO_USERNAME = "NO_USERNAME" -VAR_PREFIX = "SSHECRET" - -ERROR_NO_SECRET_FOUND = "Error: No secret available with the given name." -ERROR_UKNOWN_CLIENT_OR_SECRET = "Error: Invalid client or secret name." -ERROR_NO_COMMAND_RECEIVED = "Error: No command was received from the client." -ERROR_SOURCE_IP_NOT_ALLOWED = ( - "Error: Client not authorized to connect from the given host." -) - -RSA_PUBLIC_EXPONENT = 65537 -RSA_KEY_SIZE = 2048 - -AUDIT_LOG_NAME = "AUDIT" diff --git a/src/sshecret/crypto.py b/src/sshecret/crypto.py deleted file mode 100644 index 8afd75d..0000000 --- a/src/sshecret/crypto.py +++ /dev/null @@ -1,106 +0,0 @@ -"""Encryption related functions. - -Note! Encryption uses the less secure PKCS1v15 padding. This is to allow -decryption via openssl on the command line. - -""" - -import base64 -import logging -from pathlib import Path -from cryptography.hazmat.primitives import serialization -from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography.hazmat.primitives.asymmetric import padding - -from .types import ClientSpecification -from . import constants - -LOG = logging.getLogger(__name__) - - -def load_client_key(client: ClientSpecification) -> rsa.RSAPublicKey: - """Load public key.""" - keybytes = client.public_key.encode() - return load_public_key(keybytes) - - -def load_public_key(keybytes: bytes) -> rsa.RSAPublicKey: - public_key = serialization.load_ssh_public_key(keybytes) - if not isinstance(public_key, rsa.RSAPublicKey): - raise RuntimeError("Only RSA keys are supported.") - pem_public_key = public_key.public_bytes(encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo) - LOG.info("pem:\n%s", pem_public_key) - return public_key - - -def load_private_key(filename: str) -> rsa.RSAPrivateKey: - """Load a private key.""" - with open(filename, "rb") as f: - private_key = serialization.load_ssh_private_key(f.read(), password=None) - if not isinstance(private_key, rsa.RSAPrivateKey): - raise RuntimeError("Only RSA keys are supported.") - return private_key - - -def encrypt_string(string: str, public_key: rsa.RSAPublicKey) -> str: - """Encrypt string, end return it base64 encoded.""" - message = string.encode() - ciphertext = public_key.encrypt( - message, - padding.PKCS1v15(), - ) - return base64.b64encode(ciphertext).decode() - - -def decode_string(ciphertext: str, private_key: rsa.RSAPrivateKey) -> str: - """Decode a string. String must be base64 encoded.""" - decoded = base64.b64decode(ciphertext) - decrypted = private_key.decrypt( - decoded, - padding.PKCS1v15(), - ) - return decrypted.decode() - - -def generate_private_key() -> rsa.RSAPrivateKey: - """Generate private RSA key.""" - private_key = rsa.generate_private_key( - public_exponent=constants.RSA_PUBLIC_EXPONENT, key_size=constants.RSA_KEY_SIZE - ) - return private_key - - -def generate_pem(private_key: rsa.RSAPrivateKey) -> str: - """Generate PEM.""" - pem = private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.OpenSSH, - encryption_algorithm=serialization.NoEncryption(), - ) - return pem.decode() - - -def create_private_rsa_key(filename: Path) -> None: - """Create an RSA Private key at the given path.""" - if filename.exists(): - raise RuntimeError("Error: private key file already exists.") - LOG.debug("Generating private RSA key at %s", filename) - private_key = generate_private_key() - with open(filename, "wb") as f: - pem = private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.OpenSSH, - encryption_algorithm=serialization.NoEncryption(), - ) - lines = f.write(pem) - LOG.debug("Wrote %s lines", lines) - f.flush() - - -def generate_public_key_string(public_key: rsa.RSAPublicKey) -> str: - """Generate public key string.""" - keybytes = public_key.public_bytes( - encoding=serialization.Encoding.OpenSSH, - format=serialization.PublicFormat.OpenSSH, - ) - return keybytes.decode() diff --git a/src/sshecret/dev_cli.py b/src/sshecret/dev_cli.py deleted file mode 100644 index cedcf0d..0000000 --- a/src/sshecret/dev_cli.py +++ /dev/null @@ -1,93 +0,0 @@ -"""Development CLI commands.""" - -import sys -import asyncio -import asyncssh -import click - -import logging -import tempfile -import threading -from pathlib import Path - -from pythonjsonlogger.json import JsonFormatter - -from .server import start_server -from sshecret.backends import FileTableBackend -from .utils import create_client_file, add_secret_to_client_file -from .constants import AUDIT_LOG_NAME - - -def thread_id_filter(record: logging.LogRecord) -> logging.LogRecord: - """Resolve thread id.""" - record.thread_id = threading.get_native_id() - return record - - -LOG = logging.getLogger() -handler = logging.StreamHandler() -handler.addFilter(thread_id_filter) -formatter = logging.Formatter( - "%(thread_id)d:%(created)f:%(levelname)s:%(name)s:%(module)s:%(message)s" -) -handler.setFormatter(formatter) -LOG.addHandler(handler) -LOG.setLevel(logging.DEBUG) - -AUDIT_LOG = logging.getLogger(AUDIT_LOG_NAME) -audit_formatter = JsonFormatter() -audit_handler = logging.StreamHandler() -audit_handler.setFormatter(audit_formatter) -AUDIT_LOG.addHandler(audit_handler) - - -@click.group() -def cli() -> None: - """Run commands for testing.""" - - -@cli.command("create-client") -@click.argument("name") -@click.argument( - "filename", type=click.Path(file_okay=True, dir_okay=False, writable=True) -) -@click.option("--public-key", type=click.Path(file_okay=True)) -def create_client(name: str, filename: str, public_key: str | None) -> None: - """Create a client.""" - create_client_file(name, filename, keyfile=public_key) - click.echo(f"Wrote client config to {filename}") - - -@cli.command("add-secret") -@click.argument( - "filename", type=click.Path(file_okay=True, dir_okay=False, writable=True) -) -@click.argument("secret-name") -@click.argument("secret-value") -def add_secret(filename: str, secret_name: str, secret_value: str) -> None: - """Add secret to client file.""" - add_secret_to_client_file(filename, secret_name, secret_value) - click.echo(f"Wrote secret to {filename}") - - -@cli.command("server") -@click.argument("directory", type=click.Path(file_okay=False, dir_okay=True)) -@click.argument("port", type=click.INT) -def run_async_server(directory: str, port: int) -> None: - """Run async server.""" - loop = asyncio.new_event_loop() - with tempfile.TemporaryDirectory() as tmpdir: - serverdir = Path(tmpdir) - host_key = str(serverdir / "hostkey") - clientdir = Path(directory) - backend = FileTableBackend(clientdir) - try: - loop.run_until_complete(start_server(port, backend, host_key, True)) - except (OSError, asyncssh.Error) as exc: - click.echo(f"Error starting server: {exc}") - sys.exit(1) - - loop.run_forever() - -if __name__ == "__main__": - cli() diff --git a/src/sshecret/keepass.py b/src/sshecret/keepass.py deleted file mode 100644 index 144b065..0000000 --- a/src/sshecret/keepass.py +++ /dev/null @@ -1,180 +0,0 @@ -"""Keepass integration.""" - -import logging -from pathlib import Path -from typing import cast, final, overload, override, Self - -import pykeepass -from . import constants -from .types import BasePasswordManager, PasswordContext -from .utils import generate_password - -LOG = logging.getLogger(__name__) - - -@final -class KeepassManager(BasePasswordManager): - """KeepassXC compatible password manager.""" - - master_password_identifier = constants.MASTER_PASSWORD - - def __init__(self) -> None: - """Initialize password manager.""" - self._location: Path | None = None - self._keepass: pykeepass.PyKeePass | None = None - - @property - def location(self) -> Path: - """Get location.""" - if not self._location: - raise RuntimeError("No location has been specified.") - return self._location - - @location.setter - def location(self, location: Path) -> None: - """Set location.""" - if not location.exists() or not location.is_file(): - raise RuntimeError("Unable to read provided password file.") - self._location = location - - @override - def set_manager_options(self, options: dict[str, str]) -> None: - """Set manager options.""" - if "location" in options: - location = Path(str(options["location"])) - self.location = location - - @property - def keepass(self) -> pykeepass.PyKeePass: - """Return keepass instance.""" - if self._keepass: - return self._keepass - raise RuntimeError("Error: Database has not been opened.") - - @keepass.setter - def keepass(self, instance: pykeepass.PyKeePass) -> None: - """Set the keepass instance.""" - self._keepass = instance - - @override - def get_entries(self) -> list[str]: - """Get all entries.""" - entries = self.keepass.entries - if not entries: - return [] - return [ - str(entry.title) for entry in entries - ] - - - @override - @classmethod - def create_database( - cls, location: str, password_context: PasswordContext | str, overwrite: bool = False - ) -> Self: - """Create database.""" - if Path(location).exists() and not overwrite: - raise RuntimeError("Error: Database exists.") - - if isinstance(password_context, PasswordContext): - master_password = password_context.get_password(cls.master_password_identifier, True) - else: - master_password = password_context - - # TODO: should we delete if overwrite is set? - keepass = pykeepass.create_database(location, password=master_password) - instance = cls() - instance.set_manager_options({"location": str(location)}) - instance.keepass = keepass - return instance - - @override - def open_database(self, password_context: PasswordContext | str) -> None: - """Open the database""" - if isinstance(password_context, PasswordContext): - password = password_context.get_password(self.master_password_identifier) - else: - password = password_context - instance = pykeepass.PyKeePass(str(self.location.absolute()), password=password) - self.keepass = instance - - @override - def close_database(self) -> None: - """Close the database.""" - self._keepass = None - - @override - def get_password(self, identifier: str) -> str | None: - """Get password.""" - entry = cast("pykeepass.entry.Entry | None", self.keepass.find_entries(title=identifier, first=True)) - if not entry: - return None - - if password := cast(str, entry.password): - return str(password) - raise RuntimeError(f"Cannot get password for entry {identifier}") - - - @override - def generate_password(self, identifier: str) -> str: - """Generate password.""" - # Generate a password. - password = generate_password() - _entry = self.keepass.add_entry( - self.keepass.root_group, identifier, constants.NO_USERNAME, password - ) - self.keepass.save() - LOG.debug("Created Entry %r", _entry) - return password - - @override - def add_password(self, identifier: str, password: str) -> None: - """Add a password.""" - entry = cast("pykeepass.entry.Entry | None", self.keepass.find_entries(title=identifier, first=True)) - if not entry: - _entry = self.keepass.add_entry(self.keepass.root_group, identifier, constants.NO_USERNAME, password) - self.keepass.save() - LOG.debug("Created entry %r", _entry) - return - self.change_password(identifier, password) - LOG.debug("Updated password on entry %r", entry) - - - @overload - def change_password(self, identifier: str, password: None) -> str: ... - - @overload - def change_password(self, identifier: str, password: str) -> None: ... - - @override - def change_password(self, identifier: str, password: str | None) -> str | None: - """Change a password.""" - generated_password = False - if not password: - password = generate_password() - generated_password = True - - entry = cast("pykeepass.entry.Entry | None", self.keepass.find_entries(title=identifier, first=True)) - - if not entry: - raise ValueError("Error: Entry not found!") - - entry.password = password - - self.keepass.save() - if generated_password: - return password - - return None - - @override - def delete_password(self, identifier: str) -> None: - """Delete password.""" - entry = cast("pykeepass.entry.Entry | None", self.keepass.find_entries(title=identifier, first=True)) - if not entry: - return - LOG.info("Deleting entry %s for keepass.", entry.uuid) - - self.keepass.delete_entry(entry) - - self.keepass.save() diff --git a/src/sshecret/password_readers.py b/src/sshecret/password_readers.py deleted file mode 100644 index 0cff46b..0000000 --- a/src/sshecret/password_readers.py +++ /dev/null @@ -1,61 +0,0 @@ -"""Password reader classes. - -This implements two interfaces to read passwords: - -InputPasswordReader and EnvironmentPasswordReader. -""" - -import re -import os -import sys -from typing import TextIO, override -import click - -from .types import BasePasswordReader -from . import constants - -RE_VARNAME = re.compile(r"^[a-zA-Z_]+[a-zA-Z0-9_]*$") - - -class InputPasswordReader(BasePasswordReader): - """Read a password from stdin.""" - - @override - def get_password(self, identifier: str, repeated: bool = False) -> str: - """Get password.""" - if password := click.prompt( - f"Enter password for {identifier}", hide_input=True, type=str, confirmation_prompt=repeated - ): - return str(password) - raise ValueError("No password received.") - - -class EnvironmentPasswordReader(BasePasswordReader): - """Read a password from the environment. - - The environemnt variable will be constructured based on the identifier and the prefix. - Final environemnt variable will be validated according to the regex `[a-zA-Z_]+[a-zA-Z0-9_]*` - """ - - def _resolve_var_name(self, identifier: str) -> str: - """Resolve variable name.""" - identifier = identifier.replace("-", "_") - fields = [constants.VAR_PREFIX, identifier] - varname = "_".join(fields) - if not RE_VARNAME.fullmatch(varname): - raise ValueError( - f"Cannot generate encode password identifier in variable name. {varname} is not a valid identifier." - ) - return varname - - def get_password_from_env(self, identifier: str) -> str: - """Get password from environment.""" - varname = self._resolve_var_name(identifier) - if password := os.getenv(varname, None): - return password - raise ValueError(f"Error: No variable named {varname} resolved.") - - @override - def get_password(self, identifier: str, repeated: bool = False) -> str: - """Get password.""" - return self.get_password_from_env(identifier) diff --git a/src/sshecret/py.typed b/src/sshecret/py.typed deleted file mode 100644 index e69de29..0000000 diff --git a/src/sshecret/server/__init__.py b/src/sshecret/server/__init__.py deleted file mode 100644 index 409595a..0000000 --- a/src/sshecret/server/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Sshecret server module.""" - -from .async_server import AsshyncServer, start_server - -__all__ = ["AsshyncServer", "start_server"] diff --git a/src/sshecret/server/async_server.py b/src/sshecret/server/async_server.py deleted file mode 100644 index 195e083..0000000 --- a/src/sshecret/server/async_server.py +++ /dev/null @@ -1,143 +0,0 @@ -"""Server implemented with asyncssh.""" - -import logging -from functools import partial -from pathlib import Path -from typing import override - -import asyncssh - -from sshecret import constants -from sshecret.audit import audit_message -from sshecret.types import ClientSpecification, BaseClientBackend -from sshecret.crypto import create_private_rsa_key - - -LOG = logging.getLogger(__name__) - - -def handle_client(process: asyncssh.SSHServerProcess[str]) -> None: - """Handle client.""" - remote_ip = process.get_extra_info("peername")[0] - client_found = process.get_extra_info("client_allowed", False) - if not client_found: - process.stderr.write(constants.ERROR_UKNOWN_CLIENT_OR_SECRET + "\n") - audit_message("Unknown connection", source_address=remote_ip) - process.exit(1) - return - - client_allowed = process.get_extra_info("client_allowed", False) - if not client_allowed: - audit_message("Not permitted", "SECURITY", source_address=remote_ip) - process.stderr.write(constants.ERROR_SOURCE_IP_NOT_ALLOWED + "\n") - process.exit(1) - return - - client = process.get_extra_info("client") - if not client: - audit_message("Unknown client", source_address=remote_ip) - process.stderr.write(constants.ERROR_UKNOWN_CLIENT_OR_SECRET + "\n") - process.exit(1) - return - - secret_name = process.command - if not secret_name: - audit_message("No secret specified", source_address=remote_ip, client_name=client.name) - process.stderr.write(constants.ERROR_UKNOWN_CLIENT_OR_SECRET + "\n") - process.exit(1) - return - - LOG.debug( - "Client %s successfully connected. Fetching secret %s", client.name, secret_name - ) - - audit_message(f"Requested secret", client_name=client.name, secret_name=secret_name, source_address=remote_ip) - secret = client.secrets.get(secret_name) - if not secret: - process.stderr.write(constants.ERROR_UKNOWN_CLIENT_OR_SECRET + "\n") - process.exit(1) - return - - audit_message("Accessed secret", client.name, secret_name, source_address=remote_ip) - process.stdout.write(secret) - process.exit(0) - - -class AsshyncServer(asyncssh.SSHServer): - """Asynchronous SSH server implementation.""" - - def __init__(self, backend: BaseClientBackend) -> None: - """Initialize server.""" - self.backend: BaseClientBackend = backend - self._conn: asyncssh.SSHServerConnection | None = None - - @override - def connection_made(self, conn: asyncssh.SSHServerConnection) -> None: - """Handle incoming connection.""" - peername = conn.get_extra_info("peername") - LOG.debug("Connection established from %r", peername) - self._conn = conn - - @override - def begin_auth(self, username: str) -> bool: - """Begin authentication.""" - if not self._conn: - return True - client = self.backend.lookup_name(username) - if not client: - return True - self._conn.set_extra_info(client_found=True) - remote_ip = self._conn.get_extra_info("peername")[0] - LOG.debug("Remote_IP: %r", remote_ip) - assert isinstance(remote_ip, str) - if self.check_connection_allowed(client, remote_ip): - audit_message("Authentication requested", "ACCESS", client_name=client.name, source_address=remote_ip) - self._conn.set_extra_info(client_allowed=True) - self._conn.set_extra_info(client=client) - - # Load the key. - public_key = asyncssh.import_authorized_keys(client.public_key) - self._conn.set_authorized_keys(public_key) - - return True - - @override - def password_auth_supported(self) -> bool: - """Deny password authentication.""" - return False - - def check_connection_allowed( - self, client: ClientSpecification, source: str - ) -> bool: - """Check if client is allowed to request secrets.""" - LOG.debug("Checking if client is allowed to log in from %s", source) - if isinstance(client.allowed_ips, str) and client.allowed_ips == "*": - audit_message("Permitting login", "SECURITY", client_name=client.name, source_address=source) - LOG.debug("Client has no restrictions on source IP address. Permitting.") - return True - if isinstance(client.allowed_ips, str) and "/" not in client.allowed_ips: - if source == client.allowed_ips: - audit_message("Permitting login", "SECURITY", client_name=client.name, source_address=source) - LOG.debug("Client IP matches permitted address") - return True - - LOG.warning( - "Connection for client %s received from IP address %s that is not permitted.", - client.name, - source, - ) - audit_message("REJECTED. Invalid address", "SECURITY", client_name=client.name, source_address=source) - - return False - - -async def start_server( - port: int, backend: BaseClientBackend, host_key: str, create_key: bool = False -) -> None: - """Start server.""" - server = partial(AsshyncServer, backend=backend) - if create_key: - create_private_rsa_key(Path(host_key)) - await asyncssh.create_server( - server, "", port, server_host_keys=[host_key], process_factory=handle_client - ) diff --git a/src/sshecret/server/errors.py b/src/sshecret/server/errors.py deleted file mode 100644 index e058e9c..0000000 --- a/src/sshecret/server/errors.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Server errors.""" - - -class BaseSshecretServerError(Exception): - """Base Sshecret Server Error.""" - - -class UnknownClientError(BaseSshecretServerError): - """Client was not recognized.""" - - -class AccessDeniedError(BaseSshecretServerError): - """Client was not authorized to access the resource.""" - - -class AccessPolicyViolationError(BaseSshecretServerError): - """Client was not authorized to access the secret.""" - - -class UnknownSecretError(BaseSshecretServerError): - """Error when resolving the secret.""" diff --git a/src/sshecret/server/ssh_password_reader.py b/src/sshecret/server/ssh_password_reader.py deleted file mode 100644 index 7aeb50c..0000000 --- a/src/sshecret/server/ssh_password_reader.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Password reader for use with the SSH server.""" - -from typing import override, TextIO -import asyncssh -from sshecret.types import BasePasswordReader - - -class SSHPasswordReader(BasePasswordReader): - """SSH Password reader.""" - - def __init__(self, channel: asyncssh.SSHLineEditorChannel, stdin: asyncssh.SSHReader[str], stdout: asyncssh.SSHWriter[str]) -> None: - """Initialize password reader.""" - self.channel: asyncssh.SSHLineEditorChannel = channel - self.stdin: asyncssh.SSHReader[str] = stdin - self.stdout: asyncssh.SSHWriter[str] = stdout - - @override - def get_password(self, identifier: str, repeated: bool = False) -> str: - """Get password.""" - raise RuntimeError("Use get_password_async!") - - async def get_password_async(self, identifier: str, repeated: bool = False) -> str: - """Get password async.""" - self.stdout.write(f"Enter password for {identifier}: ") - self.channel.set_echo(False) - while True: - password = await self.stdin.readline() - if not repeated: - break - self.stdout.write(f"\nRe-enter password for {identifier}: ") - password2 = await self.stdin.readline() - if password == password2: - break - self.stdout.write(f"Passwords did not match. Try again.\n") - self.channel.set_echo(True) - return password.strip() diff --git a/src/sshecret/settings.py b/src/sshecret/settings.py deleted file mode 100644 index 9e86e0f..0000000 --- a/src/sshecret/settings.py +++ /dev/null @@ -1,82 +0,0 @@ -"""Get settings.""" - -import abc -import enum -import os -import tomllib -from pathlib import Path -from typing import Literal -from dotenv import load_dotenv - -from pydantic import BaseModel, DirectoryPath, Field, FilePath -from pydantic_settings import BaseSettings, SettingsConfigDict - -from sshecret.keepass import KeepassManager - -SETTINGS_FILE = "sshecret.toml" - - -class Backend(enum.StrEnum): - """Supported backends.""" - - FILES = "FILES" - - -class PasswordManager(enum.StrEnum): - """Supported password managers.""" - - KEEPASS = "KeePass" - - -class SSHServerSettings(BaseModel): - """SSH Server settings.""" - - port: int = 22 - private_key: FilePath | None = None - - -class AdminApiSettings(BaseModel): - """Admin API settings.""" - - port: int = 8022 - - -class FileBackendSettings(BaseModel): - """File backend settings. - - This will eventually have the Discriminator pattern described in pydantic. - """ - - type: Literal["Files"] - - location: DirectoryPath - - -class KeepassPDBSettings(BaseModel): - """Keepass backend settings.""" - - type: Literal["KeePass"] - location: FilePath - -class Settings(BaseSettings): - """Sshecret settings.""" - - model_config = SettingsConfigDict(env_prefix="sshecret_", env_nested_delimiter="__") - - backend: FileBackendSettings - password_manager: KeepassPDBSettings - admin_api: AdminApiSettings = Field(default_factory=AdminApiSettings) - ssh_server: SSHServerSettings = Field(default_factory=SSHServerSettings) - - -def get_settings() -> Settings: - """Get settings.""" - cwd = Path(os.getcwd()) - settings_file = cwd / SETTINGS_FILE - if not settings_file.exists(): - # This should fail if the current env variables don't exist. - return Settings() # pyright: ignore[reportCallIssue] - with open(settings_file, "rb") as f: - settings_data = tomllib.load(f) - - return Settings.model_validate(settings_data) diff --git a/src/sshecret/shell/__init__.py b/src/sshecret/shell/__init__.py deleted file mode 100644 index a0e2394..0000000 --- a/src/sshecret/shell/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Shell interface.""" diff --git a/src/sshecret/shell/admin_shell.py b/src/sshecret/shell/admin_shell.py deleted file mode 100644 index b213fc2..0000000 --- a/src/sshecret/shell/admin_shell.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Admin shell.""" - -import os -import click -from click_repl import register_repl - -from sshecret.api import ClientManagementAPI - -from sshecret import constants -from sshecret.password_readers import InputPasswordReader -from sshecret.keepass import KeepassManager -from sshecret.types import PasswordContext -from .shell_client import ShellClient - -DB_PATH = os.path.join(os.getcwd(), "sshecrets.kdbx") - -api_client: ShellClient | None = None - - -@click.group() -@click.pass_context -def cli(ctx: click.Context) -> None: - """General CLI.""" - if api_client is None: - raise RuntimeError("No client object defined.") - -@cli.group(name="clients") -def cmd_clients() -> None: - """Client context.""" - -@cmd_clients.command(name="show") -def show_clients() -> None: - """Show clients.""" - example_set = ["client1", "client2", "client3"] - for client in example_set: - click.echo(f"- {client}") - -@cmd_clients.command(name="add") -@click.argument("name") -def add_client(name: str) -> None: - """Add a client.""" - public_key = click.prompt("Please paste RSA public key") - - - -@cli.command() -@click.option("--overwrite", is_flag=True, help="Overwrite password database.") -def create_database(overwrite: bool) -> None: - """Create database.""" - context = PasswordContext(InputPasswordReader) - KeepassManager.create_database(DB_PATH, context, overwrite) - - -if __name__ == "__main__": - api_client = ShellClient("127.0.0.1", KeepassManager) diff --git a/src/sshecret/shell/commands.py b/src/sshecret/shell/commands.py deleted file mode 100644 index 724997e..0000000 --- a/src/sshecret/shell/commands.py +++ /dev/null @@ -1,14 +0,0 @@ -"""Shell commands. - -The shell needs to implement the following shell commands: - -- Client management -client create/read/update/delete -secret create/read/update/delete -client permit secret -client revoke secret -client key rotate - -audit show - -""" diff --git a/src/sshecret/shell/shell_client.py b/src/sshecret/shell/shell_client.py deleted file mode 100644 index 6a7eff2..0000000 --- a/src/sshecret/shell/shell_client.py +++ /dev/null @@ -1,29 +0,0 @@ -"""Shell API Client object for auditing.""" - -from dataclasses import dataclass, field -from typing import override -from sshecret.password_readers import InputPasswordReader -from sshecret.types import BaseAPIClient, BasePasswordManager, BasePasswordReader, PasswordContext - - -@dataclass(frozen=True) -class ShellClient(BaseAPIClient): - """Client connecting from local host.""" - - source: str - password_manager_type: type[BasePasswordManager] - method: str = field(init=False, default="shell") - - @override - def get_reader(self) -> type[BasePasswordReader]: - """Get reader.""" - return InputPasswordReader - - @override - def password_manager(self, manager_options: dict[str, str] | None = None) -> BasePasswordManager: - """Instantiate password manager.""" - manager_instance = self.password_manager_type() - if manager_options: - manager_instance.set_manager_options(manager_options) - - return manager_instance diff --git a/src/sshecret/shell/shell_context.py b/src/sshecret/shell/shell_context.py deleted file mode 100644 index 8904a6d..0000000 --- a/src/sshecret/shell/shell_context.py +++ /dev/null @@ -1,45 +0,0 @@ -"""Shell context manager.""" - -import sys -from dataclasses import dataclass -from contextlib import contextmanager -from contextvars import ContextVar -from typing import Iterator, TextIO -from click_shell.core import Shell - -from sshecret.api import ManagementApi -from sshecret.password_readers import InputPasswordReader -from sshecret.types import BaseClientBackend, BasePasswordManager - -from .shell_client import ShellClient - - -@dataclass(frozen=True) -class ShellContext: - """Shell context.""" - - api: ManagementApi - shell: Shell - streams: tuple[TextIO, TextIO] | None = None - - - - - - -@contextmanager -def shell_session( - shell: Shell, - backend: BaseClientBackend, - password_manager: type[BasePasswordManager], - source_address: str, - manager_options: dict[str, str] | None = None, -) -> Iterator[ShellContext]: - """Start a shell session. - - The idea here is to collect the context, store it in an instance variable, - and run the shell. - """ - reader = InputPasswordReader - client = ShellClient(source_address, password_manager) - api = ManagementApi(backend, client, manager_options) diff --git a/src/sshecret/testing.py b/src/sshecret/testing.py deleted file mode 100644 index 927211c..0000000 --- a/src/sshecret/testing.py +++ /dev/null @@ -1,96 +0,0 @@ -"""Testing utilities and classes.""" - -from io import StringIO -import tempfile -from dataclasses import dataclass, field -from contextlib import contextmanager -from pathlib import Path -from collections.abc import Iterator -from dotenv import load_dotenv - -from .utils import create_client_file, generate_password -from . import settings as app_settings -from .keepass import KeepassManager - - -@dataclass -class TestClientSpec: - """Specification of a test client.""" - - name: str - secrets: dict[str, str] = field(default_factory=dict) - - -@dataclass -class TestContext: - """Test context.""" - - path: Path - master_password: str - - @property - def password_database(self) -> Path: - """Return password database location.""" - return self.path / "test.kdbx" - - def get_settings(self) -> app_settings.Settings: - """Get settings.""" - return app_settings.Settings( - backend=app_settings.BackendSettings( - backend=app_settings.FileBackendSettings( - type="Files", location=self.path - ), - ), - password_manager=app_settings.PasswordManagerSettings( - manager=app_settings.KeepassPDBSettings( - type="KeePass", location=self.password_database - ) - ), - ) - - -def set_environment(context: TestContext) -> None: - """Set environment.""" - password_path = str(context.password_database) - env: list[str] = [ - f"sshecret_backend__backend_location={str(context.path)}", - "sshecret_backend__password_manager__manager_type=KeePass", - f"sshecret_backend__password_manager__manager_location={password_path}", - ] - env_str = StringIO("\n".join(env)) - load_dotenv(stream=env_str) - - -@contextmanager -def test_context(clients: list[TestClientSpec]) -> Iterator[Path]: - """Create a test context.""" - with tempfile.TemporaryDirectory() as tmpdir: - dirpath = Path(tmpdir) - for client in clients: - filename = dirpath / f"{client.name}.json" - create_client_file(client.name, filename, client.secrets) - - yield dirpath - - -@contextmanager -def api_context(clients: list[TestClientSpec]) -> Iterator[TestContext]: - """Create a context for testing the full API.""" - with tempfile.TemporaryDirectory() as tmpdir: - dirpath = Path(tmpdir) - master_password = generate_password() - context = TestContext(dirpath, master_password) - keepass = KeepassManager.create_database( - str(context.password_database), master_password - ) - seen_secrets: list[str] = [] - for client in clients: - filename = dirpath / f"{client.name}.json" - create_client_file(client.name, filename, client.secrets) - for secret, value in client.secrets.items(): - if secret in seen_secrets: - continue - keepass.add_password(secret, value) - seen_secrets.append(secret) - - yield context diff --git a/src/sshecret/types.py b/src/sshecret/types.py deleted file mode 100644 index cd2b4ba..0000000 --- a/src/sshecret/types.py +++ /dev/null @@ -1,179 +0,0 @@ -"""Interfaces and types.""" - -import abc -from types import NotImplementedType -from typing import Self, overload - -from pydantic import BaseModel -from pydantic.networks import IPvAnyAddress, IPvAnyNetwork - - -class BasePasswordReader(abc.ABC): - """Abstract strategy class to read a passwords.""" - - @abc.abstractmethod - def get_password(self, identifier: str, repeated: bool = False) -> str: - """Resolve the password, e.g., via input.""" - - -class PasswordContext: - """Context class for resolving a password.""" - - def __init__(self, reader: BasePasswordReader) -> None: - """Initialize password context.""" - self._reader: BasePasswordReader = reader - - @property - def reader(self) -> BasePasswordReader: - """Return reader.""" - return self._reader - - @reader.setter - def reader(self, reader: BasePasswordReader) -> None: - """Set the reader instance.""" - self._reader = reader - - def get_password(self, identifier: str, repeated: bool = False) -> str: - """Get the password.""" - return self.reader.get_password(identifier, repeated) - - -class BasePasswordManager(abc.ABC): - """Abstract base class for password managers.""" - - master_password_identifier: str - - @classmethod - @abc.abstractmethod - def create_database( - cls, - location: str, - password_context: PasswordContext | str, - overwrite: bool = False, - ) -> Self: - """Create database. - - Location can be a file, a url or something else. - """ - - @abc.abstractmethod - def open_database(self, password_context: PasswordContext | str) -> None: - """Open database.""" - - @abc.abstractmethod - def close_database(self) -> None: - """Close database.""" - - @abc.abstractmethod - def get_password(self, identifier: str) -> str | None: - """Get a password from the manager.""" - - @abc.abstractmethod - def generate_password(self, identifier: str) -> str: - """Generate a password using unspecified default rules. - - May be expanded later. - - Returns the generated password. - """ - - @abc.abstractmethod - def add_password(self, identifier: str, password: str) -> None: - """Add a pre-defined password.""" - - @abc.abstractmethod - def get_entries(self) -> list[str]: - """Get names of all entries.""" - - def set_manager_options(self, options: dict[str, str]) -> None: - """Set manager options.""" - pass - - @overload - def change_password(self, identifier: str, password: None) -> str: ... - - @overload - def change_password(self, identifier: str, password: str) -> None: ... - - @abc.abstractmethod - def change_password(self, identifier: str, password: str | None) -> str | None: - """Change password.""" - - @abc.abstractmethod - def delete_password(self, identifier: str) -> None: - """Delete a password.""" - - -class ClientSpecification(BaseModel): - """Specification of client.""" - - name: str - public_key: str - allowed_ips: list[IPvAnyAddress | IPvAnyNetwork] | str = "*" - secrets: dict[str, str] = {} - testing_private_key: str | None = None # Private key only for testing purposes! - - -class BaseClientBackend(abc.ABC): - """Base client backend. - - This class is responsible for managing the list of clients and facilitate - lookups. - """ - - @abc.abstractmethod - def lookup_name(self, name: str) -> ClientSpecification | None: - """Lookup a client specification by name.""" - - @abc.abstractmethod - def add_secret( - self, - client_name: str, - secret_name: str, - secret_value: str, - encrypted: bool = False, - ) -> None: - """Add a secret to a client.""" - - @abc.abstractmethod - def add_client(self, spec: ClientSpecification) -> None: - """Add a new client.""" - - @abc.abstractmethod - def update_client(self, name: str, spec: ClientSpecification) -> None: - """Update client information.""" - - @abc.abstractmethod - def remove_client(self, name: str, persistent: bool = True) -> None: - """Delete a client.""" - - @abc.abstractmethod - def get_all(self) -> list[ClientSpecification]: - """Get all clients.""" - - @abc.abstractmethod - def lookup_by_secret(self, secret_name: str) -> list[ClientSpecification]: - """Lookup by the name of a secret.""" - - -class BaseAPIClient(abc.ABC): - """Base API Client.""" - - source: str - method: str - - @abc.abstractmethod - def password_manager( - self, manager_options: dict[str, str] | None = None - ) -> BasePasswordManager: - """Instantiate password manager.""" - - def get_reader(self) -> BasePasswordReader: - """Get the reader.""" - raise NotImplementedError("Class-based password reading not implemented.") - - def get_context(self, reader: BasePasswordReader | None = None) -> PasswordContext: - """Get password context.""" - if not reader: - reader = self.get_reader() - return PasswordContext(reader) diff --git a/src/sshecret/utils.py b/src/sshecret/utils.py deleted file mode 100644 index 75e9fa9..0000000 --- a/src/sshecret/utils.py +++ /dev/null @@ -1,77 +0,0 @@ -"""Various utilities.""" - -import secrets - -from pathlib import Path - -from .crypto import ( - load_client_key, - encrypt_string, - generate_private_key, - generate_pem, - generate_public_key_string, -) -from .types import ClientSpecification - - -def generate_password() -> str: - """Generate a password.""" - return secrets.token_urlsafe(32) - - -def generate_client_object( - name: str, secrets: dict[str, str] | None = None, keyfile: str | None = None -) -> ClientSpecification: - """Generate a client object.""" - private_key = generate_private_key() - if keyfile: - with open(keyfile, "r") as f: - contents = f.read() - if not contents.startswith("ssh-rsa "): - raise RuntimeError("Error: Key must be an RSA key.") - - client = ClientSpecification(name=name, public_key=contents.strip()) - public_key = load_client_key(client) - else: - pem = generate_pem(private_key) - public_key = private_key.public_key() - pubkey_str = generate_public_key_string(public_key) - client = ClientSpecification( - name=name, public_key=pubkey_str, testing_private_key=pem - ) - if secrets: - for secret_name, secret_value in secrets.items(): - client.secrets[secret_name] = encrypt_string(secret_value, public_key) - - return client - - -def create_client_file( - name: str, - filename: Path | str, - secrets: dict[str, str] | None = None, - keyfile: str | None = None, -) -> None: - """Create client file.""" - client = generate_client_object(name, secrets, keyfile) - - with open(filename, "w") as f: - f.write(client.model_dump_json(exclude_none=True, indent=2)) - f.flush() - - -def add_secret_to_client_file( - filename: str | Path, secret_name: str, secret_value: str -) -> None: - """Add secret to client file.""" - with open(filename, "r") as f: - client = ClientSpecification.model_validate_json(f.read()) - - public_key = load_client_key(client) - encrypted = encrypt_string(secret_value, public_key) - client.secrets[secret_name] = encrypted - - with open(filename, "w") as f: - json_str = client.model_dump_json(exclude_none=True, indent=2) - f.write(json_str) - f.flush() diff --git a/src/sshecret/webapi/__init__.py b/src/sshecret/webapi/__init__.py deleted file mode 100644 index 8b13789..0000000 --- a/src/sshecret/webapi/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/sshecret/webapi/api.py b/src/sshecret/webapi/api.py deleted file mode 100644 index 7a25d72..0000000 --- a/src/sshecret/webapi/api.py +++ /dev/null @@ -1,279 +0,0 @@ -"""WebAPI.""" - -import asyncio -import logging -from functools import lru_cache -import secrets -import time - -from typing import Annotated -from fastapi import Header, HTTPException, Depends, Request, APIRouter - -from cryptography.fernet import Fernet - -from sshecret.api import ManagementApi -from sshecret.types import ( - BaseClientBackend, - BasePasswordManager, - ClientSpecification, -) -from sshecret.keepass import KeepassManager -from sshecret.backends.file_table import FileTableBackend - -from sshecret.settings import Settings, get_settings -from sshecret.webapi.api_client import WebManagementAPIClient -from . import models - -API_VERSION = "v1" - -admin_router = APIRouter(prefix=f"/api/{API_VERSION}") - -encryption_key = Fernet.generate_key() -cipher = Fernet(encryption_key) - -# We store sessions in memory. -sessions: dict[str, tuple[str, float]] = {} - -SESSION_TIMEOUT = 600 # 10 minutes - -LOG = logging.getLogger(__name__) - -session_lock = asyncio.Lock() - - -def encrypt_session_password(password: str) -> str: - """Encrypts the master password.""" - return cipher.encrypt(password.encode()).decode() - - -def decrypt_password(encrypted_password: str) -> str: - """Decrypts the master password asynchronously.""" - return cipher.decrypt(encrypted_password.encode()).decode() - - -async def validate_session(session_id: Annotated[str | None, Header()] = None) -> str: - """Middleware to validate session and enforce timeout.""" - if not session_id: - raise HTTPException(status_code=401, detail="Session ID required") - - async with session_lock: - if session_id not in sessions: - raise HTTPException(status_code=401, detail="Session invalid or expired") - - encrypted_password, last_access_time = sessions[session_id] - current_time = asyncio.get_event_loop().time() - - # Check for session timeout - if current_time - last_access_time > SESSION_TIMEOUT: - del sessions[session_id] # Auto-lock on timeout - raise HTTPException(status_code=401, detail="Session expired") - - # Update last access time - sessions[session_id] = (encrypted_password, current_time) - - return decrypt_password(encrypted_password) - - -@lru_cache -def get_app_settings() -> Settings: - """Get app settings.""" - return get_settings() - - -def get_password_manager( - settings: Annotated[Settings, Depends(get_app_settings)] -) -> BasePasswordManager: - """Get password manager.""" - # Currently only keepass is supported. - keepass = KeepassManager() - keepass.location = settings.password_manager.location - return keepass - - -async def get_backend( - settings: Annotated[Settings, Depends(get_app_settings)] -) -> BaseClientBackend: - """Get backend.""" - location = settings.backend.location - filetable = FileTableBackend(location) - return filetable - - -async def get_management_api( - request: Request, settings: Annotated[Settings, Depends(get_app_settings)] -) -> ManagementApi: - """Get management api.""" - client_ip = "unknown" - if req_client := request.client: - client_ip = req_client.host - - api_client = WebManagementAPIClient(client_ip, settings) - backend = await get_backend(settings) - return ManagementApi(backend, api_client) - - -BackendDependency = Annotated[BaseClientBackend, Depends(get_backend)] -ManagementAPIDependency = Annotated[ManagementApi, Depends(get_management_api)] -SessionPasswdDependency = Annotated[str, Depends(validate_session)] - - -@admin_router.post("/auth/unlock") -async def unlock_database( - password: models.PasswordBody, - password_manager: Annotated[BasePasswordManager, Depends(get_password_manager)], -) -> models.SessionResponse: - """Unlock database with master password sent in POST body.""" - password_str = password.password.get_secret_value() - try: - password_manager.open_database(password_str) - except Exception as e: - LOG.debug("Exception: %s", e, exc_info=True) - raise HTTPException(status_code=401, detail="Invalid password.") - - session_id = secrets.token_urlsafe(32) - sessions[session_id] = (encrypt_session_password(password_str), time.time()) - - return models.SessionResponse(session_id=session_id) - - -@admin_router.post("/auth/lock") -async def lock_database( - session_id: Annotated[str | None, Header()] = None -) -> dict[str, str]: - """Lock database.""" - if session_id and session_id in sessions: - del sessions[session_id] - - return {"message": "LOCKED"} - raise HTTPException(400, detail="Missing session ID.") - - -@admin_router.get("/auth/status") -async def get_lock_status( - session_id: Annotated[str | None, Header()] = None -) -> dict[str, str]: - """Get current lock status.""" - if session_id and session_id in sessions: - return {"message": "UNLOCKED"} - return {"message": "LOCKED"} - - -@admin_router.get("/clients") -async def get_clients(admin_api: ManagementAPIDependency) -> list[ClientSpecification]: - """Get clients.""" - return admin_api.get_clients() - - -@admin_router.get("/clients/{client_id}") -async def get_client( - client_id: str, admin_api: ManagementAPIDependency -) -> ClientSpecification: - """Get client.""" - if client_api := admin_api.get_client(client_id): - return client_api.client - raise HTTPException(status_code=404, detail="Client not found.") - - -@admin_router.put("/clients/{client_id}") -async def update_client( - client_id: str, - client: ClientSpecification, - admin_api: ManagementAPIDependency, - master_password: SessionPasswdDependency, -) -> ClientSpecification: - """Update client.""" - client_api = admin_api.get_client(client_id) - if not client_api: - raise HTTPException(status_code=404, detail="Client not found.") - new_client = client_api.update_client(client, password=master_password) - return new_client - - -@admin_router.delete("/clients/{client_id}", status_code=204) -async def delete_client(client_id: str, admin_api: ManagementAPIDependency) -> None: - """Delete client.""" - if admin_api.get_client(client_id): - admin_api.delete_client(client_id) - else: - raise HTTPException(status_code=404, detail="Client not found.") - - -@admin_router.post("/clients", status_code=201) -async def add_client( - client: models.CreateClientModel, admin_api: ManagementAPIDependency -) -> ClientSpecification: - """Add client.""" - new_client = admin_api.create_client( - client.name, client.public_key, client.allowed_ips - ) - return new_client.client - - -@admin_router.get("/secrets") -async def list_secrets( - admin_api: ManagementAPIDependency, password: SessionPasswdDependency -) -> list[models.SecretListResponse]: - """List secrets.""" - secrets = admin_api.get_secret_names(password=password) - return [ - models.SecretListResponse(name=name, assigned_clients=assigned_clients) - for name, assigned_clients in secrets.items() - ] - - -@admin_router.post("/secrets") -async def add_secret( - secret: models.CreateSecretSpecification, - password: SessionPasswdDependency, - admin_api: ManagementAPIDependency, -) -> models.RevealSecretResponse: - """Add secret. - - Will generate a password if none is specified. - """ - secret_value: str | None = None - if secret.secret: - secret_value = secret.secret.get_secret_value() - result_secret = admin_api.add_secret(secret.name, secret_value, password=password) - return models.RevealSecretResponse(name=secret.name, secret=result_secret) - - -@admin_router.get("/secrets/{name}") -async def get_secret( - name: str, admin_api: ManagementAPIDependency, password: SessionPasswdDependency -) -> models.RevealSecretResponse: - """Get secret.""" - if secret_value := admin_api.get_secret(name, password=password): - return models.RevealSecretResponse(name=name, secret=secret_value) - raise HTTPException(status_code=404, detail="Secret not found.") - - -@admin_router.put("/secrets/{name}") -async def update_secret( - name: str, - spec: models.UpdateSecretSpecification, - admin_api: ManagementAPIDependency, - password: SessionPasswdDependency, -) -> models.MaybeRevalSecretResponse: - """Update secret.""" - if spec.auto_generate: - secret_value = admin_api.regenerate_secret(name, password=password) - return models.MaybeRevalSecretResponse(name=name, secret=secret_value) - - if not spec.secret: - raise HTTPException( - status_code=400, - detail="Secret value must be specified if auto_generate is False", - ) - admin_api.update_secret(name, spec.secret, password=password) - return models.MaybeRevalSecretResponse(name=name, secret=None) - - -@admin_router.delete("/secrets/{name}", status_code=204) -async def delete_secret( - name: str, - admin_api: ManagementAPIDependency, - password: SessionPasswdDependency, -) -> None: - """Delete secret.""" - admin_api.delete_secret(name, password=password) diff --git a/src/sshecret/webapi/api_client.py b/src/sshecret/webapi/api_client.py deleted file mode 100644 index f6cf112..0000000 --- a/src/sshecret/webapi/api_client.py +++ /dev/null @@ -1,25 +0,0 @@ -"""API Client.""" - -from typing import override -from sshecret.keepass import KeepassManager -from sshecret.types import BaseAPIClient, BasePasswordManager -from sshecret.settings import Settings, get_settings - - -class WebManagementAPIClient(BaseAPIClient): - """Client class for the web management API.""" - - method: str = "admin-web-api" - - def __init__(self, source: str, settings: Settings | None = None) -> None: - """Construct client.""" - if not settings: - settings = get_settings() - self.source: str = source - self._password_manager: BasePasswordManager = KeepassManager() - self._password_manager.location = settings.password_manager.manager.location - - @override - def password_manager(self, manager_options: dict[str, str] | None = None) -> BasePasswordManager: - """Get password manager.""" - return self._password_manager diff --git a/src/sshecret/webapi/frontend.py b/src/sshecret/webapi/frontend.py deleted file mode 100644 index 3074af8..0000000 --- a/src/sshecret/webapi/frontend.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Admin frontend.""" - -from fastapi import APIRouter, Request - -from fastapi.responses import HTMLResponse -from fastapi.templating import Jinja2Templates - -templates = Jinja2Templates(directory="templates") - -frontend = APIRouter() - -# I'm just making some placeholders here -@frontend.get("/") -async def index(request: Request) -> HTMLResponse: - """Get frontpage.""" - return templates.TemplateResponse(request, name="index.html") - - -@frontend.get("/login") -async def login(request: Request) -> HTMLResponse: - """Get login page.""" - return templates.TemplateResponse(request, name="login.html") - - -@frontend.get("/clients") -async def clients(request: Request) -> HTMLResponse: - """Get login page.""" - return templates.TemplateResponse(request, name="clients.html") - -@frontend.get("/secrets") -async def secrets(request: Request) -> HTMLResponse: - """Get login page.""" - return templates.TemplateResponse(request, name="secrets.html") diff --git a/src/sshecret/webapi/models.py b/src/sshecret/webapi/models.py deleted file mode 100644 index ee40f08..0000000 --- a/src/sshecret/webapi/models.py +++ /dev/null @@ -1,71 +0,0 @@ -"""Response models.""" - -from pydantic import BaseModel, IPvAnyAddress, IPvAnyNetwork, SecretStr - - -class SSHKeyResponse(BaseModel): - """Response model for updated SSH keys.""" - - updated_secrets: list[str] - - -class SecretListResponse(BaseModel): - """Response for listing secrets.""" - - name: str - assigned_clients: list[str] - - -class CreateSecretSpecification(BaseModel): - """Model for creating a secret.""" - - name: str - secret: SecretStr | None - - -class SecretSpecification(BaseModel): - """Secret specification.""" - - name: str - secret: SecretStr - - -class UpdateSecretSpecification(BaseModel): - """Model for updating a secret.""" - - secret: str | None - auto_generate: bool | None = None - - -class RevealSecretResponse(BaseModel): - """Reveal secret.""" - - name: str - secret: str - - -class MaybeRevalSecretResponse(BaseModel): - """Model where the secret may be specified.""" - - name: str - secret: str | None - - -class PasswordBody(BaseModel): - """Password body.""" - - password: SecretStr - - -class SessionResponse(BaseModel): - """Session response.""" - - session_id: str - - -class CreateClientModel(BaseModel): - """Model for creating a client.""" - - name: str - public_key: str - allowed_ips: list[IPvAnyAddress | IPvAnyNetwork] | str = "*" diff --git a/src/sshecret/webapi/router.py b/src/sshecret/webapi/router.py deleted file mode 100644 index f720bf7..0000000 --- a/src/sshecret/webapi/router.py +++ /dev/null @@ -1,16 +0,0 @@ -"""API router.""" - -from fastapi import FastAPI -from fastapi.staticfiles import StaticFiles - - -from .api import admin_router -from .frontend import frontend - - -app = FastAPI() - -app.include_router(admin_router) -app.include_router(frontend) - -app.mount("/static", StaticFiles(directory="static"), name="static")