Clear POC first draft

This commit is contained in:
2025-04-18 16:41:22 +02:00
parent 525be35fcf
commit 19a7e2a91b
33 changed files with 0 additions and 2364 deletions

View File

@ -1,2 +0,0 @@
def hello() -> str:
return "Hello from sshecret!"

View File

@ -1,447 +0,0 @@
"""API.
This module is an attempt to create some sort of meaningfull API around the
actions exposed here.
"""
import abc
from contextlib import contextmanager
from collections.abc import Iterator
from pydantic.networks import IPvAnyAddress, IPvAnyNetwork
from .audit import audit_message
from .crypto import load_client_key, load_public_key, encrypt_string
from .types import (
BaseAPIClient,
BaseClientBackend,
BasePasswordManager,
BasePasswordReader,
ClientSpecification,
PasswordContext,
)
@contextmanager
def password_manager_session(
password_manager: BasePasswordManager,
password_context: PasswordContext | str,
api_client: BaseAPIClient,
) -> Iterator[BasePasswordManager]:
"""Open password manager for read/write in a context."""
audit_message(
"Opening password manager session",
"SECURITY",
source_address=api_client.source,
)
password_manager.open_database(password_context)
yield password_manager
audit_message(
"Closing password manager session",
"SECURITY",
source_address=api_client.source,
)
password_manager.close_database()
class BaseSshecretAPI(abc.ABC):
"""Base API class."""
def __init__(
self,
backend: BaseClientBackend,
api_client: BaseAPIClient,
manager_options: dict[str, str] | None = None,
) -> None:
"""Initialize API."""
self.backend: BaseClientBackend = backend
self.api_client: BaseAPIClient = api_client
self.manager_options: dict[str, str] | None = manager_options
def _log_audit(
self,
message: str,
audit_type: str,
client_name: str | None = None,
**details: str,
) -> None:
"""Log an audit message."""
audit_message(
message,
audit_type,
client_name,
source_address=self.api_client.source,
**details,
)
@contextmanager
def password_session(
self, reader: BasePasswordReader | None = None, password: str | None = None
) -> Iterator[BasePasswordManager]:
"""Open a password session."""
if password:
context = password
else:
if not reader:
reader = self.api_client.get_reader()
context = self.api_client.get_context(reader)
password_manager = self.api_client.password_manager(self.manager_options)
with password_manager_session(
password_manager, context, self.api_client
) as session:
yield session
class ClientManagementAPI(BaseSshecretAPI):
"""API for managing clients."""
def __init__(
self,
backend: BaseClientBackend,
client: ClientSpecification,
api_client: BaseAPIClient,
manager_options: dict[str, str] | None = None,
) -> None:
"""Create client management API instance."""
super().__init__(backend, api_client, manager_options)
self.client: ClientSpecification = client
self.__password_manager: BasePasswordManager | None = None
def log_security(self, message: str, **details: str) -> None:
"""Log a security related message."""
self._log_audit(message, "SECURITY", self.client.name, **details)
def log_info(self, message: str) -> None:
"""Log an informational message."""
self._log_audit(message, "INFORMATIONAL", self.client.name)
@property
def password_manager(self) -> BasePasswordManager:
"""Get password manager."""
if self.__password_manager:
self.log_security("Accessed password manager")
return self.__password_manager
raise RuntimeError("Password manager not initialized.")
@password_manager.setter
def password_manager(self, instance: BasePasswordManager) -> None:
"""Set password manager instance."""
self.log_security("Opened password manager.")
self.__password_manager = instance
def get_secrets(self) -> list[str]:
"""Get names of the secrets that the client has access to.."""
self.log_security("Listing secret names.")
return list(self.client.secrets.keys())
def update_client(
self,
client: ClientSpecification,
reader: BasePasswordReader | None = None,
password: str | None = None,
) -> ClientSpecification:
"""Update client."""
self.log_info("Updating client")
update_data = client.model_dump(exclude_unset=True)
if client.public_key != self.client.public_key:
self.log_security("Client public key has changed.")
if client.secrets != self.client.secrets:
raise RuntimeError(
"Error: Cannot update public key and secrets in the same operation."
)
del update_data["secrets"]
secrets = self._re_encrypt(client.public_key, reader, password)
update_data["secrets"] = secrets
updated_client = self.client.model_copy(update=update_data)
self.backend.update_client(self.client.name, updated_client)
self.client = updated_client
return updated_client
def update_secret(self, name: str, password: str) -> None:
"""Update a secret.
If secret is not already a part of the client, it will be added.
"""
if name in self.client.secrets:
self.log_security("Updating secret", secret_name=name)
else:
self.log_security("Adding secret", secret_name=name)
public_key = load_client_key(self.client)
encrypted = encrypt_string(password, public_key)
client_secrets = {**self.client.secrets, name: encrypted}
updated_client = self.client.model_copy(update={"secrets": client_secrets})
self.backend.update_client(self.client.name, updated_client)
self.client = updated_client
def _re_encrypt(
self,
new_key: str,
reader: BasePasswordReader | None = None,
password: str | None = None,
) -> dict[str, str]:
"""Update public key on given client."""
audit_message(
"Updating public key",
"INFORMATIONAL",
self.client.name,
source_address=self.api_client.source,
)
if password:
context = password
else:
if not reader:
reader = self.api_client.get_reader()
context = self.api_client.get_context(reader)
password_manager = self.api_client.password_manager(self.manager_options)
self.password_manager = password_manager
client_key = load_public_key(new_key.encode())
secrets: dict[str, str] = {}
with password_manager_session(
password_manager, context, self.api_client
) as password_session:
for name in self.get_secrets():
secret = password_session.get_password(name)
audit_message(
"Updating encrypted value",
"SECURITY",
self.client.name,
secret_name=name,
source_address=self.api_client.source,
)
if not secret:
raise RuntimeError("Could not fetch a new secret value.")
new_value = encrypt_string(secret, client_key)
secrets[name] = new_value
return secrets
def update_public_key(
self,
new_key: str,
reader: BasePasswordReader | None = None,
password: str | None = None,
) -> None:
"""Update the public key."""
audit_message(
"Updating public key",
"INFORMATIONAL",
self.client.name,
source_address=self.api_client.source,
)
secrets = self._re_encrypt(new_key, reader, password)
updated = self.client.model_copy(update={"secrets": secrets})
self.backend.update_client(self.client.name, updated)
@classmethod
def get_client(
cls,
backend: BaseClientBackend,
name: str,
api_client: BaseAPIClient,
) -> "ClientManagementAPI | None":
"""Get client."""
client = backend.lookup_name(name)
if not client:
return None
return cls(backend, client, api_client)
@classmethod
def create_client(
cls,
backend: BaseClientBackend,
name: str,
api_client: BaseAPIClient,
public_key: str,
allowed_ips: list[IPvAnyAddress | IPvAnyNetwork] | str = "*",
) -> "ClientManagementAPI":
"""Create a client."""
client = ClientSpecification(
name=name, public_key=public_key, allowed_ips=allowed_ips
)
backend.add_client(client)
return cls(backend, client, api_client)
class ManagementApi(BaseSshecretAPI):
"""Api for general management."""
def __init__(
self,
backend: BaseClientBackend,
api_client: BaseAPIClient,
manager_options: dict[str, str] | None = None,
) -> None:
"""Initialize API."""
super().__init__(backend, api_client, manager_options)
def log_security(
self, message: str, client_name: str | None = None, **details: str
) -> None:
"""Log a security related message."""
self._log_audit(message, "SECURITY", client_name, **details)
def log_info(self, message: str, client_name: str | None = None) -> None:
"""Log an informational message."""
self._log_audit(message, "INFORMATIONAL", client_name)
def get_client(self, name: str) -> ClientManagementAPI | None:
"""Get a client."""
client = self.backend.lookup_name(name)
if not client:
return None
return ClientManagementAPI(
self.backend, client, self.api_client, self.manager_options
)
def get_clients(self) -> list[ClientSpecification]:
"""Get clients."""
self.log_info("Fetched all clients")
return self.backend.get_all()
def _get_clients(self) -> list[ClientSpecification]:
"""Get clients."""
return self.backend.get_all()
def delete_client(self, name: str) -> None:
"""Delete client."""
client = self.backend.lookup_name(name)
if not client:
self.log_info("Attempted to delete a non-existing client.", name)
return
self.log_security("Deleting client", name)
self.backend.remove_client(name)
def create_client(
self,
name: str,
public_key: str,
allowed_ips: list[IPvAnyAddress | IPvAnyNetwork] | str = "*",
) -> ClientManagementAPI:
"""Create a client."""
self.log_info("Creating new client", name)
client = ClientSpecification(
name=name, public_key=public_key, allowed_ips=allowed_ips
)
self.backend.add_client(client)
return ClientManagementAPI(self.backend, client, self.api_client)
def get_secret_names(
self, reader: BasePasswordReader | None = None, password: str | None = None
) -> dict[str, list[str]]:
"""Get secret names and which clients have these.."""
self.log_security("Listing all secret names.")
with self.password_session(reader=reader, password=password) as session:
secret_names = session.get_entries()
secret_mapping: dict[str, list[str]] = {}
for name in secret_names:
secret_mapping[name] = [
client.name for client in self.backend.lookup_by_secret(name)
]
return secret_mapping
def add_secret(
self,
name: str,
secret_value: str | None,
reader: BasePasswordReader | None = None,
password: str | None = None,
) -> str:
"""Add a secret."""
self.log_security("Adding new secret", secret_name=name)
with self.password_session(reader=reader, password=password) as session:
if not secret_value:
self.log_security("Auto-generating a secret value", secret_name=name)
secret_value = session.generate_password(name)
else:
session.add_password(name, secret_value)
return secret_value
def get_secret(
self,
name: str,
reader: BasePasswordReader | None = None,
password: str | None = None,
) -> str | None:
"""Get the clear-text value of a secret."""
self.log_security("Client requested secret value", secret_name=name)
with self.password_session(reader=reader, password=password) as session:
secret = session.get_password(name)
return secret
def update_secret(
self,
name: str,
new_value: str,
reader: BasePasswordReader | None = None,
password: str | None = None,
) -> None:
"""Update a secret with a given name."""
self.log_security("Changing secret", secret_name=name)
with self.password_session(reader=reader, password=password) as session:
session.change_password(name, new_value)
clients = self.backend.lookup_by_secret(name)
for client in clients:
client_api = self.get_client(client.name)
if not client_api:
continue
client_api.update_secret(name, new_value)
def regenerate_secret(
self,
name: str,
reader: BasePasswordReader | None = None,
password: str | None = None,
) -> str:
"""Regenerate a secret."""
self.log_security("Generating a new secret value", secret_name=name)
with self.password_session(reader=reader, password=password) as session:
new_value = session.change_password(name, None)
clients = self.backend.lookup_by_secret(name)
for client in clients:
client_api = self.get_client(client.name)
if not client_api:
continue
client_api.update_secret(name, new_value)
return new_value
def delete_secret(
self,
name: str,
reader: BasePasswordReader | None = None,
password: str | None = None,
) -> None:
"""Delete secret."""
clients = self.backend.lookup_by_secret(name)
self.log_security("Deleting secret", secret_name=name)
with self.password_session(reader=reader, password=password) as session:
session.delete_password(name)
for client in clients:
secrets = {**client.secrets}
del secrets[name]
new_client = client.model_copy(update={"secrets": secrets})
client_api = self.get_client(client.name)
if not client_api:
continue
self.log_security(
"Removing secret from client.",
client_name=client.name,
secret_name=name,
)
client_api.update_client(new_client, password=password)

View File

@ -1,66 +0,0 @@
"""Audit setup."""
import enum
import json
import logging
from typing import Any
from pythonjsonlogger.json import JsonFormatter
from pydantic import BaseModel, ConfigDict
from .constants import AUDIT_LOG_NAME
AUDIT_LOG = logging.getLogger(AUDIT_LOG_NAME)
class AuditMessageType(enum.StrEnum):
"""Audit Message Type."""
ACCESS = enum.auto() # Someone accessed something
SECURITY = enum.auto() # A message related to security
INFORMATIONAL = enum.auto() # other informational messages
class AuditMessage(BaseModel):
"""Audit message."""
model_config = ConfigDict(use_enum_values=True)
type: AuditMessageType
message: str
client_name: str | None = None
source_address: str | None = None
secret_name: str | None = None
def __str__(self) -> str:
"""Stringify object as JSON."""
return self.model_dump_json()
def audit_message(
message: str,
audit_type: AuditMessageType | str | None = None,
client_name: str | None = None,
secret_name: str | None = None,
source_address: str | None = None,
**details: str
) -> None:
"""Create an audit message."""
if not audit_type:
audit_type = AuditMessageType.INFORMATIONAL
if audit_type not in list(AuditMessageType):
audit_type = AuditMessageType.INFORMATIONAL
audit_message = AuditMessage(
type=audit_type,
message=message,
client_name=client_name,
source_address=source_address,
secret_name=secret_name,
)
audit_dict = audit_message.model_dump(exclude_none=True)
AUDIT_LOG.info({**audit_dict, **details})

View File

@ -1,5 +0,0 @@
"""Backend implementations"""
from .file_table import FileTableBackend
__all__ = ["FileTableBackend"]

View File

@ -1,133 +0,0 @@
"""File table based backend."""
import logging
import os
from pathlib import Path
from typing import override
import littletable as lt
from sshecret.crypto import load_client_key, encrypt_string
from sshecret.types import ClientSpecification
from sshecret.types import BaseClientBackend
LOG = logging.getLogger(__name__)
def load_clients_from_dir(directory: Path) -> dict[Path, ClientSpecification]:
"""Load clients from a directory."""
if not directory.exists() or not directory.is_dir():
raise ValueError("Invalid directory specified.")
clients: dict[Path, ClientSpecification] = {}
for client_file in directory.glob("*.json"):
with open(client_file, "r") as f:
client = ClientSpecification.model_validate_json(f.read())
if client_file.name != f"{client.name}.json":
raise RuntimeError(
"Filename scheme of clients does not conform to expected format. Aborting import!"
)
clients[client_file] = client
return clients
class FileTableBackend(BaseClientBackend):
"""In-memory littletable based backend."""
def __init__(self, directory: Path) -> None:
"""Create backend instance."""
LOG.debug("Creating in-memory table to hold clients.")
self._directory: Path = directory
self.table: lt.Table[ClientSpecification] = lt.Table()
self._setup_table()
client_files = load_clients_from_dir(directory)
client_count = len(client_files)
LOG.debug("Loaded %s clients from disk.", client_count)
# self.client_file_map: dict[str, Path] = {client.name: filepath for filepath, client in client_files.items()}
LOG.debug("Inserting clients into table.")
self.table.insert_many(list(client_files.values()))
def _setup_table(self) -> None:
"""Set up the table."""
self.table.create_index("name", unique=True)
@override
def lookup_name(self, name: str) -> ClientSpecification | None:
"""Lookup client by name."""
if result := self.table.by.name.get(name):
if isinstance(result, ClientSpecification):
return result
return None
@override
def add_client(self, spec: ClientSpecification) -> None:
"""Add client."""
self.table.insert(spec)
self._write_spec_file(spec)
def _write_spec_file(self, spec: ClientSpecification) -> None:
"""Write spec file to disk."""
dest_file_name = f"{spec.name}.json"
dest_file = self._directory / dest_file_name
with open(dest_file.absolute(), "w") as f:
f.write(
spec.model_dump_json(exclude_none=True, exclude_unset=True, indent=2)
)
f.flush()
@override
def add_secret(
self,
client_name: str,
secret_name: str,
secret_value: str,
encrypted: bool = False,
) -> None:
"""Add secret."""
client: ClientSpecification = self.table.by.name[client_name] # pyright: ignore[reportAssignmentType]
if not encrypted:
public_key = load_client_key(client)
secret_value = encrypt_string(secret_value, public_key)
client.secrets[secret_name] = secret_value
self._update_client_data(client)
self._write_spec_file(client)
@override
def remove_client(self, name: str, persistent: bool = True) -> None:
"""Delete client."""
client = self.lookup_name(name)
if not client:
raise ValueError("Client does not exist!")
self.table.remove(client)
if persistent:
filename = f"{client.name}.json"
filepath = self._directory / filename
filepath.unlink()
@override
def update_client(self, name: str, spec: ClientSpecification) -> None:
"""Update client."""
if not self.lookup_name(name):
raise ValueError("Client does not exist!")
self._update_client_data(spec)
self._write_spec_file(spec)
def _update_client_data(self, spec: ClientSpecification) -> None:
"""Update client data."""
existing = self.lookup_name(spec.name)
if existing:
self.table.remove(existing)
self.add_client(spec)
@override
def get_all(self) -> list[ClientSpecification]:
"""Get all clients."""
return list(self.table)
@override
def lookup_by_secret(self, secret_name: str) -> list[ClientSpecification]:
"""Lookup by secret name."""
results = self.table.where(lambda client: secret_name in client.secrets)
return list(results)

View File

@ -1,3 +0,0 @@
"""Command Line Interface"""
import click

View File

@ -1,25 +0,0 @@
"""Client code"""
from typing import TextIO
import click
from sshecret.crypto import decode_string, load_private_key
def decrypt_secret(encoded: str, client_key: str) -> str:
"""Decrypt secret."""
private_key = load_private_key(client_key)
return decode_string(encoded, private_key)
@click.command()
@click.argument("keyfile", type=click.Path(exists=True, readable=True, dir_okay=False))
@click.argument("encrypted_input", type=click.File("r"))
def cli_decrypt(keyfile: str, encrypted_input: TextIO) -> None:
"""Decrypt on command line."""
decrypted = decrypt_secret(encrypted_input.read(), keyfile)
click.echo(decrypted)
if __name__ == "__main__":
cli_decrypt()

View File

@ -1,18 +0,0 @@
"""Config file."""
from pathlib import Path
from pydantic import SecretStr
from pydantic_settings import BaseSettings
class KeepassSettings(BaseSettings):
"""Settings for Keepasss password database."""
database_path: Path
class SshecretSettings(BaseSettings):
"""Settings model."""
admin_password: SecretStr
admin_ssh_key: str | None = None
keepass: KeepassSettings

View File

@ -1,17 +0,0 @@
"""Constants."""
MASTER_PASSWORD = "MASTER_PASSWORD"
NO_USERNAME = "NO_USERNAME"
VAR_PREFIX = "SSHECRET"
ERROR_NO_SECRET_FOUND = "Error: No secret available with the given name."
ERROR_UKNOWN_CLIENT_OR_SECRET = "Error: Invalid client or secret name."
ERROR_NO_COMMAND_RECEIVED = "Error: No command was received from the client."
ERROR_SOURCE_IP_NOT_ALLOWED = (
"Error: Client not authorized to connect from the given host."
)
RSA_PUBLIC_EXPONENT = 65537
RSA_KEY_SIZE = 2048
AUDIT_LOG_NAME = "AUDIT"

View File

@ -1,106 +0,0 @@
"""Encryption related functions.
Note! Encryption uses the less secure PKCS1v15 padding. This is to allow
decryption via openssl on the command line.
"""
import base64
import logging
from pathlib import Path
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric import padding
from .types import ClientSpecification
from . import constants
LOG = logging.getLogger(__name__)
def load_client_key(client: ClientSpecification) -> rsa.RSAPublicKey:
"""Load public key."""
keybytes = client.public_key.encode()
return load_public_key(keybytes)
def load_public_key(keybytes: bytes) -> rsa.RSAPublicKey:
public_key = serialization.load_ssh_public_key(keybytes)
if not isinstance(public_key, rsa.RSAPublicKey):
raise RuntimeError("Only RSA keys are supported.")
pem_public_key = public_key.public_bytes(encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo)
LOG.info("pem:\n%s", pem_public_key)
return public_key
def load_private_key(filename: str) -> rsa.RSAPrivateKey:
"""Load a private key."""
with open(filename, "rb") as f:
private_key = serialization.load_ssh_private_key(f.read(), password=None)
if not isinstance(private_key, rsa.RSAPrivateKey):
raise RuntimeError("Only RSA keys are supported.")
return private_key
def encrypt_string(string: str, public_key: rsa.RSAPublicKey) -> str:
"""Encrypt string, end return it base64 encoded."""
message = string.encode()
ciphertext = public_key.encrypt(
message,
padding.PKCS1v15(),
)
return base64.b64encode(ciphertext).decode()
def decode_string(ciphertext: str, private_key: rsa.RSAPrivateKey) -> str:
"""Decode a string. String must be base64 encoded."""
decoded = base64.b64decode(ciphertext)
decrypted = private_key.decrypt(
decoded,
padding.PKCS1v15(),
)
return decrypted.decode()
def generate_private_key() -> rsa.RSAPrivateKey:
"""Generate private RSA key."""
private_key = rsa.generate_private_key(
public_exponent=constants.RSA_PUBLIC_EXPONENT, key_size=constants.RSA_KEY_SIZE
)
return private_key
def generate_pem(private_key: rsa.RSAPrivateKey) -> str:
"""Generate PEM."""
pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.OpenSSH,
encryption_algorithm=serialization.NoEncryption(),
)
return pem.decode()
def create_private_rsa_key(filename: Path) -> None:
"""Create an RSA Private key at the given path."""
if filename.exists():
raise RuntimeError("Error: private key file already exists.")
LOG.debug("Generating private RSA key at %s", filename)
private_key = generate_private_key()
with open(filename, "wb") as f:
pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.OpenSSH,
encryption_algorithm=serialization.NoEncryption(),
)
lines = f.write(pem)
LOG.debug("Wrote %s lines", lines)
f.flush()
def generate_public_key_string(public_key: rsa.RSAPublicKey) -> str:
"""Generate public key string."""
keybytes = public_key.public_bytes(
encoding=serialization.Encoding.OpenSSH,
format=serialization.PublicFormat.OpenSSH,
)
return keybytes.decode()

View File

@ -1,93 +0,0 @@
"""Development CLI commands."""
import sys
import asyncio
import asyncssh
import click
import logging
import tempfile
import threading
from pathlib import Path
from pythonjsonlogger.json import JsonFormatter
from .server import start_server
from sshecret.backends import FileTableBackend
from .utils import create_client_file, add_secret_to_client_file
from .constants import AUDIT_LOG_NAME
def thread_id_filter(record: logging.LogRecord) -> logging.LogRecord:
"""Resolve thread id."""
record.thread_id = threading.get_native_id()
return record
LOG = logging.getLogger()
handler = logging.StreamHandler()
handler.addFilter(thread_id_filter)
formatter = logging.Formatter(
"%(thread_id)d:%(created)f:%(levelname)s:%(name)s:%(module)s:%(message)s"
)
handler.setFormatter(formatter)
LOG.addHandler(handler)
LOG.setLevel(logging.DEBUG)
AUDIT_LOG = logging.getLogger(AUDIT_LOG_NAME)
audit_formatter = JsonFormatter()
audit_handler = logging.StreamHandler()
audit_handler.setFormatter(audit_formatter)
AUDIT_LOG.addHandler(audit_handler)
@click.group()
def cli() -> None:
"""Run commands for testing."""
@cli.command("create-client")
@click.argument("name")
@click.argument(
"filename", type=click.Path(file_okay=True, dir_okay=False, writable=True)
)
@click.option("--public-key", type=click.Path(file_okay=True))
def create_client(name: str, filename: str, public_key: str | None) -> None:
"""Create a client."""
create_client_file(name, filename, keyfile=public_key)
click.echo(f"Wrote client config to {filename}")
@cli.command("add-secret")
@click.argument(
"filename", type=click.Path(file_okay=True, dir_okay=False, writable=True)
)
@click.argument("secret-name")
@click.argument("secret-value")
def add_secret(filename: str, secret_name: str, secret_value: str) -> None:
"""Add secret to client file."""
add_secret_to_client_file(filename, secret_name, secret_value)
click.echo(f"Wrote secret to {filename}")
@cli.command("server")
@click.argument("directory", type=click.Path(file_okay=False, dir_okay=True))
@click.argument("port", type=click.INT)
def run_async_server(directory: str, port: int) -> None:
"""Run async server."""
loop = asyncio.new_event_loop()
with tempfile.TemporaryDirectory() as tmpdir:
serverdir = Path(tmpdir)
host_key = str(serverdir / "hostkey")
clientdir = Path(directory)
backend = FileTableBackend(clientdir)
try:
loop.run_until_complete(start_server(port, backend, host_key, True))
except (OSError, asyncssh.Error) as exc:
click.echo(f"Error starting server: {exc}")
sys.exit(1)
loop.run_forever()
if __name__ == "__main__":
cli()

View File

@ -1,180 +0,0 @@
"""Keepass integration."""
import logging
from pathlib import Path
from typing import cast, final, overload, override, Self
import pykeepass
from . import constants
from .types import BasePasswordManager, PasswordContext
from .utils import generate_password
LOG = logging.getLogger(__name__)
@final
class KeepassManager(BasePasswordManager):
"""KeepassXC compatible password manager."""
master_password_identifier = constants.MASTER_PASSWORD
def __init__(self) -> None:
"""Initialize password manager."""
self._location: Path | None = None
self._keepass: pykeepass.PyKeePass | None = None
@property
def location(self) -> Path:
"""Get location."""
if not self._location:
raise RuntimeError("No location has been specified.")
return self._location
@location.setter
def location(self, location: Path) -> None:
"""Set location."""
if not location.exists() or not location.is_file():
raise RuntimeError("Unable to read provided password file.")
self._location = location
@override
def set_manager_options(self, options: dict[str, str]) -> None:
"""Set manager options."""
if "location" in options:
location = Path(str(options["location"]))
self.location = location
@property
def keepass(self) -> pykeepass.PyKeePass:
"""Return keepass instance."""
if self._keepass:
return self._keepass
raise RuntimeError("Error: Database has not been opened.")
@keepass.setter
def keepass(self, instance: pykeepass.PyKeePass) -> None:
"""Set the keepass instance."""
self._keepass = instance
@override
def get_entries(self) -> list[str]:
"""Get all entries."""
entries = self.keepass.entries
if not entries:
return []
return [
str(entry.title) for entry in entries
]
@override
@classmethod
def create_database(
cls, location: str, password_context: PasswordContext | str, overwrite: bool = False
) -> Self:
"""Create database."""
if Path(location).exists() and not overwrite:
raise RuntimeError("Error: Database exists.")
if isinstance(password_context, PasswordContext):
master_password = password_context.get_password(cls.master_password_identifier, True)
else:
master_password = password_context
# TODO: should we delete if overwrite is set?
keepass = pykeepass.create_database(location, password=master_password)
instance = cls()
instance.set_manager_options({"location": str(location)})
instance.keepass = keepass
return instance
@override
def open_database(self, password_context: PasswordContext | str) -> None:
"""Open the database"""
if isinstance(password_context, PasswordContext):
password = password_context.get_password(self.master_password_identifier)
else:
password = password_context
instance = pykeepass.PyKeePass(str(self.location.absolute()), password=password)
self.keepass = instance
@override
def close_database(self) -> None:
"""Close the database."""
self._keepass = None
@override
def get_password(self, identifier: str) -> str | None:
"""Get password."""
entry = cast("pykeepass.entry.Entry | None", self.keepass.find_entries(title=identifier, first=True))
if not entry:
return None
if password := cast(str, entry.password):
return str(password)
raise RuntimeError(f"Cannot get password for entry {identifier}")
@override
def generate_password(self, identifier: str) -> str:
"""Generate password."""
# Generate a password.
password = generate_password()
_entry = self.keepass.add_entry(
self.keepass.root_group, identifier, constants.NO_USERNAME, password
)
self.keepass.save()
LOG.debug("Created Entry %r", _entry)
return password
@override
def add_password(self, identifier: str, password: str) -> None:
"""Add a password."""
entry = cast("pykeepass.entry.Entry | None", self.keepass.find_entries(title=identifier, first=True))
if not entry:
_entry = self.keepass.add_entry(self.keepass.root_group, identifier, constants.NO_USERNAME, password)
self.keepass.save()
LOG.debug("Created entry %r", _entry)
return
self.change_password(identifier, password)
LOG.debug("Updated password on entry %r", entry)
@overload
def change_password(self, identifier: str, password: None) -> str: ...
@overload
def change_password(self, identifier: str, password: str) -> None: ...
@override
def change_password(self, identifier: str, password: str | None) -> str | None:
"""Change a password."""
generated_password = False
if not password:
password = generate_password()
generated_password = True
entry = cast("pykeepass.entry.Entry | None", self.keepass.find_entries(title=identifier, first=True))
if not entry:
raise ValueError("Error: Entry not found!")
entry.password = password
self.keepass.save()
if generated_password:
return password
return None
@override
def delete_password(self, identifier: str) -> None:
"""Delete password."""
entry = cast("pykeepass.entry.Entry | None", self.keepass.find_entries(title=identifier, first=True))
if not entry:
return
LOG.info("Deleting entry %s for keepass.", entry.uuid)
self.keepass.delete_entry(entry)
self.keepass.save()

View File

@ -1,61 +0,0 @@
"""Password reader classes.
This implements two interfaces to read passwords:
InputPasswordReader and EnvironmentPasswordReader.
"""
import re
import os
import sys
from typing import TextIO, override
import click
from .types import BasePasswordReader
from . import constants
RE_VARNAME = re.compile(r"^[a-zA-Z_]+[a-zA-Z0-9_]*$")
class InputPasswordReader(BasePasswordReader):
"""Read a password from stdin."""
@override
def get_password(self, identifier: str, repeated: bool = False) -> str:
"""Get password."""
if password := click.prompt(
f"Enter password for {identifier}", hide_input=True, type=str, confirmation_prompt=repeated
):
return str(password)
raise ValueError("No password received.")
class EnvironmentPasswordReader(BasePasswordReader):
"""Read a password from the environment.
The environemnt variable will be constructured based on the identifier and the prefix.
Final environemnt variable will be validated according to the regex `[a-zA-Z_]+[a-zA-Z0-9_]*`
"""
def _resolve_var_name(self, identifier: str) -> str:
"""Resolve variable name."""
identifier = identifier.replace("-", "_")
fields = [constants.VAR_PREFIX, identifier]
varname = "_".join(fields)
if not RE_VARNAME.fullmatch(varname):
raise ValueError(
f"Cannot generate encode password identifier in variable name. {varname} is not a valid identifier."
)
return varname
def get_password_from_env(self, identifier: str) -> str:
"""Get password from environment."""
varname = self._resolve_var_name(identifier)
if password := os.getenv(varname, None):
return password
raise ValueError(f"Error: No variable named {varname} resolved.")
@override
def get_password(self, identifier: str, repeated: bool = False) -> str:
"""Get password."""
return self.get_password_from_env(identifier)

View File

View File

@ -1,5 +0,0 @@
"""Sshecret server module."""
from .async_server import AsshyncServer, start_server
__all__ = ["AsshyncServer", "start_server"]

View File

@ -1,143 +0,0 @@
"""Server implemented with asyncssh."""
import logging
from functools import partial
from pathlib import Path
from typing import override
import asyncssh
from sshecret import constants
from sshecret.audit import audit_message
from sshecret.types import ClientSpecification, BaseClientBackend
from sshecret.crypto import create_private_rsa_key
LOG = logging.getLogger(__name__)
def handle_client(process: asyncssh.SSHServerProcess[str]) -> None:
"""Handle client."""
remote_ip = process.get_extra_info("peername")[0]
client_found = process.get_extra_info("client_allowed", False)
if not client_found:
process.stderr.write(constants.ERROR_UKNOWN_CLIENT_OR_SECRET + "\n")
audit_message("Unknown connection", source_address=remote_ip)
process.exit(1)
return
client_allowed = process.get_extra_info("client_allowed", False)
if not client_allowed:
audit_message("Not permitted", "SECURITY", source_address=remote_ip)
process.stderr.write(constants.ERROR_SOURCE_IP_NOT_ALLOWED + "\n")
process.exit(1)
return
client = process.get_extra_info("client")
if not client:
audit_message("Unknown client", source_address=remote_ip)
process.stderr.write(constants.ERROR_UKNOWN_CLIENT_OR_SECRET + "\n")
process.exit(1)
return
secret_name = process.command
if not secret_name:
audit_message("No secret specified", source_address=remote_ip, client_name=client.name)
process.stderr.write(constants.ERROR_UKNOWN_CLIENT_OR_SECRET + "\n")
process.exit(1)
return
LOG.debug(
"Client %s successfully connected. Fetching secret %s", client.name, secret_name
)
audit_message(f"Requested secret", client_name=client.name, secret_name=secret_name, source_address=remote_ip)
secret = client.secrets.get(secret_name)
if not secret:
process.stderr.write(constants.ERROR_UKNOWN_CLIENT_OR_SECRET + "\n")
process.exit(1)
return
audit_message("Accessed secret", client.name, secret_name, source_address=remote_ip)
process.stdout.write(secret)
process.exit(0)
class AsshyncServer(asyncssh.SSHServer):
"""Asynchronous SSH server implementation."""
def __init__(self, backend: BaseClientBackend) -> None:
"""Initialize server."""
self.backend: BaseClientBackend = backend
self._conn: asyncssh.SSHServerConnection | None = None
@override
def connection_made(self, conn: asyncssh.SSHServerConnection) -> None:
"""Handle incoming connection."""
peername = conn.get_extra_info("peername")
LOG.debug("Connection established from %r", peername)
self._conn = conn
@override
def begin_auth(self, username: str) -> bool:
"""Begin authentication."""
if not self._conn:
return True
client = self.backend.lookup_name(username)
if not client:
return True
self._conn.set_extra_info(client_found=True)
remote_ip = self._conn.get_extra_info("peername")[0]
LOG.debug("Remote_IP: %r", remote_ip)
assert isinstance(remote_ip, str)
if self.check_connection_allowed(client, remote_ip):
audit_message("Authentication requested", "ACCESS", client_name=client.name, source_address=remote_ip)
self._conn.set_extra_info(client_allowed=True)
self._conn.set_extra_info(client=client)
# Load the key.
public_key = asyncssh.import_authorized_keys(client.public_key)
self._conn.set_authorized_keys(public_key)
return True
@override
def password_auth_supported(self) -> bool:
"""Deny password authentication."""
return False
def check_connection_allowed(
self, client: ClientSpecification, source: str
) -> bool:
"""Check if client is allowed to request secrets."""
LOG.debug("Checking if client is allowed to log in from %s", source)
if isinstance(client.allowed_ips, str) and client.allowed_ips == "*":
audit_message("Permitting login", "SECURITY", client_name=client.name, source_address=source)
LOG.debug("Client has no restrictions on source IP address. Permitting.")
return True
if isinstance(client.allowed_ips, str) and "/" not in client.allowed_ips:
if source == client.allowed_ips:
audit_message("Permitting login", "SECURITY", client_name=client.name, source_address=source)
LOG.debug("Client IP matches permitted address")
return True
LOG.warning(
"Connection for client %s received from IP address %s that is not permitted.",
client.name,
source,
)
audit_message("REJECTED. Invalid address", "SECURITY", client_name=client.name, source_address=source)
return False
async def start_server(
port: int, backend: BaseClientBackend, host_key: str, create_key: bool = False
) -> None:
"""Start server."""
server = partial(AsshyncServer, backend=backend)
if create_key:
create_private_rsa_key(Path(host_key))
await asyncssh.create_server(
server, "", port, server_host_keys=[host_key], process_factory=handle_client
)

View File

@ -1,21 +0,0 @@
"""Server errors."""
class BaseSshecretServerError(Exception):
"""Base Sshecret Server Error."""
class UnknownClientError(BaseSshecretServerError):
"""Client was not recognized."""
class AccessDeniedError(BaseSshecretServerError):
"""Client was not authorized to access the resource."""
class AccessPolicyViolationError(BaseSshecretServerError):
"""Client was not authorized to access the secret."""
class UnknownSecretError(BaseSshecretServerError):
"""Error when resolving the secret."""

View File

@ -1,36 +0,0 @@
"""Password reader for use with the SSH server."""
from typing import override, TextIO
import asyncssh
from sshecret.types import BasePasswordReader
class SSHPasswordReader(BasePasswordReader):
"""SSH Password reader."""
def __init__(self, channel: asyncssh.SSHLineEditorChannel, stdin: asyncssh.SSHReader[str], stdout: asyncssh.SSHWriter[str]) -> None:
"""Initialize password reader."""
self.channel: asyncssh.SSHLineEditorChannel = channel
self.stdin: asyncssh.SSHReader[str] = stdin
self.stdout: asyncssh.SSHWriter[str] = stdout
@override
def get_password(self, identifier: str, repeated: bool = False) -> str:
"""Get password."""
raise RuntimeError("Use get_password_async!")
async def get_password_async(self, identifier: str, repeated: bool = False) -> str:
"""Get password async."""
self.stdout.write(f"Enter password for {identifier}: ")
self.channel.set_echo(False)
while True:
password = await self.stdin.readline()
if not repeated:
break
self.stdout.write(f"\nRe-enter password for {identifier}: ")
password2 = await self.stdin.readline()
if password == password2:
break
self.stdout.write(f"Passwords did not match. Try again.\n")
self.channel.set_echo(True)
return password.strip()

View File

@ -1,82 +0,0 @@
"""Get settings."""
import abc
import enum
import os
import tomllib
from pathlib import Path
from typing import Literal
from dotenv import load_dotenv
from pydantic import BaseModel, DirectoryPath, Field, FilePath
from pydantic_settings import BaseSettings, SettingsConfigDict
from sshecret.keepass import KeepassManager
SETTINGS_FILE = "sshecret.toml"
class Backend(enum.StrEnum):
"""Supported backends."""
FILES = "FILES"
class PasswordManager(enum.StrEnum):
"""Supported password managers."""
KEEPASS = "KeePass"
class SSHServerSettings(BaseModel):
"""SSH Server settings."""
port: int = 22
private_key: FilePath | None = None
class AdminApiSettings(BaseModel):
"""Admin API settings."""
port: int = 8022
class FileBackendSettings(BaseModel):
"""File backend settings.
This will eventually have the Discriminator pattern described in pydantic.
"""
type: Literal["Files"]
location: DirectoryPath
class KeepassPDBSettings(BaseModel):
"""Keepass backend settings."""
type: Literal["KeePass"]
location: FilePath
class Settings(BaseSettings):
"""Sshecret settings."""
model_config = SettingsConfigDict(env_prefix="sshecret_", env_nested_delimiter="__")
backend: FileBackendSettings
password_manager: KeepassPDBSettings
admin_api: AdminApiSettings = Field(default_factory=AdminApiSettings)
ssh_server: SSHServerSettings = Field(default_factory=SSHServerSettings)
def get_settings() -> Settings:
"""Get settings."""
cwd = Path(os.getcwd())
settings_file = cwd / SETTINGS_FILE
if not settings_file.exists():
# This should fail if the current env variables don't exist.
return Settings() # pyright: ignore[reportCallIssue]
with open(settings_file, "rb") as f:
settings_data = tomllib.load(f)
return Settings.model_validate(settings_data)

View File

@ -1 +0,0 @@
"""Shell interface."""

View File

@ -1,55 +0,0 @@
"""Admin shell."""
import os
import click
from click_repl import register_repl
from sshecret.api import ClientManagementAPI
from sshecret import constants
from sshecret.password_readers import InputPasswordReader
from sshecret.keepass import KeepassManager
from sshecret.types import PasswordContext
from .shell_client import ShellClient
DB_PATH = os.path.join(os.getcwd(), "sshecrets.kdbx")
api_client: ShellClient | None = None
@click.group()
@click.pass_context
def cli(ctx: click.Context) -> None:
"""General CLI."""
if api_client is None:
raise RuntimeError("No client object defined.")
@cli.group(name="clients")
def cmd_clients() -> None:
"""Client context."""
@cmd_clients.command(name="show")
def show_clients() -> None:
"""Show clients."""
example_set = ["client1", "client2", "client3"]
for client in example_set:
click.echo(f"- {client}")
@cmd_clients.command(name="add")
@click.argument("name")
def add_client(name: str) -> None:
"""Add a client."""
public_key = click.prompt("Please paste RSA public key")
@cli.command()
@click.option("--overwrite", is_flag=True, help="Overwrite password database.")
def create_database(overwrite: bool) -> None:
"""Create database."""
context = PasswordContext(InputPasswordReader)
KeepassManager.create_database(DB_PATH, context, overwrite)
if __name__ == "__main__":
api_client = ShellClient("127.0.0.1", KeepassManager)

View File

@ -1,14 +0,0 @@
"""Shell commands.
The shell needs to implement the following shell commands:
- Client management
client create/read/update/delete
secret create/read/update/delete
client permit secret
client revoke secret
client key rotate
audit show
"""

View File

@ -1,29 +0,0 @@
"""Shell API Client object for auditing."""
from dataclasses import dataclass, field
from typing import override
from sshecret.password_readers import InputPasswordReader
from sshecret.types import BaseAPIClient, BasePasswordManager, BasePasswordReader, PasswordContext
@dataclass(frozen=True)
class ShellClient(BaseAPIClient):
"""Client connecting from local host."""
source: str
password_manager_type: type[BasePasswordManager]
method: str = field(init=False, default="shell")
@override
def get_reader(self) -> type[BasePasswordReader]:
"""Get reader."""
return InputPasswordReader
@override
def password_manager(self, manager_options: dict[str, str] | None = None) -> BasePasswordManager:
"""Instantiate password manager."""
manager_instance = self.password_manager_type()
if manager_options:
manager_instance.set_manager_options(manager_options)
return manager_instance

View File

@ -1,45 +0,0 @@
"""Shell context manager."""
import sys
from dataclasses import dataclass
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Iterator, TextIO
from click_shell.core import Shell
from sshecret.api import ManagementApi
from sshecret.password_readers import InputPasswordReader
from sshecret.types import BaseClientBackend, BasePasswordManager
from .shell_client import ShellClient
@dataclass(frozen=True)
class ShellContext:
"""Shell context."""
api: ManagementApi
shell: Shell
streams: tuple[TextIO, TextIO] | None = None
@contextmanager
def shell_session(
shell: Shell,
backend: BaseClientBackend,
password_manager: type[BasePasswordManager],
source_address: str,
manager_options: dict[str, str] | None = None,
) -> Iterator[ShellContext]:
"""Start a shell session.
The idea here is to collect the context, store it in an instance variable,
and run the shell.
"""
reader = InputPasswordReader
client = ShellClient(source_address, password_manager)
api = ManagementApi(backend, client, manager_options)

View File

@ -1,96 +0,0 @@
"""Testing utilities and classes."""
from io import StringIO
import tempfile
from dataclasses import dataclass, field
from contextlib import contextmanager
from pathlib import Path
from collections.abc import Iterator
from dotenv import load_dotenv
from .utils import create_client_file, generate_password
from . import settings as app_settings
from .keepass import KeepassManager
@dataclass
class TestClientSpec:
"""Specification of a test client."""
name: str
secrets: dict[str, str] = field(default_factory=dict)
@dataclass
class TestContext:
"""Test context."""
path: Path
master_password: str
@property
def password_database(self) -> Path:
"""Return password database location."""
return self.path / "test.kdbx"
def get_settings(self) -> app_settings.Settings:
"""Get settings."""
return app_settings.Settings(
backend=app_settings.BackendSettings(
backend=app_settings.FileBackendSettings(
type="Files", location=self.path
),
),
password_manager=app_settings.PasswordManagerSettings(
manager=app_settings.KeepassPDBSettings(
type="KeePass", location=self.password_database
)
),
)
def set_environment(context: TestContext) -> None:
"""Set environment."""
password_path = str(context.password_database)
env: list[str] = [
f"sshecret_backend__backend_location={str(context.path)}",
"sshecret_backend__password_manager__manager_type=KeePass",
f"sshecret_backend__password_manager__manager_location={password_path}",
]
env_str = StringIO("\n".join(env))
load_dotenv(stream=env_str)
@contextmanager
def test_context(clients: list[TestClientSpec]) -> Iterator[Path]:
"""Create a test context."""
with tempfile.TemporaryDirectory() as tmpdir:
dirpath = Path(tmpdir)
for client in clients:
filename = dirpath / f"{client.name}.json"
create_client_file(client.name, filename, client.secrets)
yield dirpath
@contextmanager
def api_context(clients: list[TestClientSpec]) -> Iterator[TestContext]:
"""Create a context for testing the full API."""
with tempfile.TemporaryDirectory() as tmpdir:
dirpath = Path(tmpdir)
master_password = generate_password()
context = TestContext(dirpath, master_password)
keepass = KeepassManager.create_database(
str(context.password_database), master_password
)
seen_secrets: list[str] = []
for client in clients:
filename = dirpath / f"{client.name}.json"
create_client_file(client.name, filename, client.secrets)
for secret, value in client.secrets.items():
if secret in seen_secrets:
continue
keepass.add_password(secret, value)
seen_secrets.append(secret)
yield context

View File

@ -1,179 +0,0 @@
"""Interfaces and types."""
import abc
from types import NotImplementedType
from typing import Self, overload
from pydantic import BaseModel
from pydantic.networks import IPvAnyAddress, IPvAnyNetwork
class BasePasswordReader(abc.ABC):
"""Abstract strategy class to read a passwords."""
@abc.abstractmethod
def get_password(self, identifier: str, repeated: bool = False) -> str:
"""Resolve the password, e.g., via input."""
class PasswordContext:
"""Context class for resolving a password."""
def __init__(self, reader: BasePasswordReader) -> None:
"""Initialize password context."""
self._reader: BasePasswordReader = reader
@property
def reader(self) -> BasePasswordReader:
"""Return reader."""
return self._reader
@reader.setter
def reader(self, reader: BasePasswordReader) -> None:
"""Set the reader instance."""
self._reader = reader
def get_password(self, identifier: str, repeated: bool = False) -> str:
"""Get the password."""
return self.reader.get_password(identifier, repeated)
class BasePasswordManager(abc.ABC):
"""Abstract base class for password managers."""
master_password_identifier: str
@classmethod
@abc.abstractmethod
def create_database(
cls,
location: str,
password_context: PasswordContext | str,
overwrite: bool = False,
) -> Self:
"""Create database.
Location can be a file, a url or something else.
"""
@abc.abstractmethod
def open_database(self, password_context: PasswordContext | str) -> None:
"""Open database."""
@abc.abstractmethod
def close_database(self) -> None:
"""Close database."""
@abc.abstractmethod
def get_password(self, identifier: str) -> str | None:
"""Get a password from the manager."""
@abc.abstractmethod
def generate_password(self, identifier: str) -> str:
"""Generate a password using unspecified default rules.
May be expanded later.
Returns the generated password.
"""
@abc.abstractmethod
def add_password(self, identifier: str, password: str) -> None:
"""Add a pre-defined password."""
@abc.abstractmethod
def get_entries(self) -> list[str]:
"""Get names of all entries."""
def set_manager_options(self, options: dict[str, str]) -> None:
"""Set manager options."""
pass
@overload
def change_password(self, identifier: str, password: None) -> str: ...
@overload
def change_password(self, identifier: str, password: str) -> None: ...
@abc.abstractmethod
def change_password(self, identifier: str, password: str | None) -> str | None:
"""Change password."""
@abc.abstractmethod
def delete_password(self, identifier: str) -> None:
"""Delete a password."""
class ClientSpecification(BaseModel):
"""Specification of client."""
name: str
public_key: str
allowed_ips: list[IPvAnyAddress | IPvAnyNetwork] | str = "*"
secrets: dict[str, str] = {}
testing_private_key: str | None = None # Private key only for testing purposes!
class BaseClientBackend(abc.ABC):
"""Base client backend.
This class is responsible for managing the list of clients and facilitate
lookups.
"""
@abc.abstractmethod
def lookup_name(self, name: str) -> ClientSpecification | None:
"""Lookup a client specification by name."""
@abc.abstractmethod
def add_secret(
self,
client_name: str,
secret_name: str,
secret_value: str,
encrypted: bool = False,
) -> None:
"""Add a secret to a client."""
@abc.abstractmethod
def add_client(self, spec: ClientSpecification) -> None:
"""Add a new client."""
@abc.abstractmethod
def update_client(self, name: str, spec: ClientSpecification) -> None:
"""Update client information."""
@abc.abstractmethod
def remove_client(self, name: str, persistent: bool = True) -> None:
"""Delete a client."""
@abc.abstractmethod
def get_all(self) -> list[ClientSpecification]:
"""Get all clients."""
@abc.abstractmethod
def lookup_by_secret(self, secret_name: str) -> list[ClientSpecification]:
"""Lookup by the name of a secret."""
class BaseAPIClient(abc.ABC):
"""Base API Client."""
source: str
method: str
@abc.abstractmethod
def password_manager(
self, manager_options: dict[str, str] | None = None
) -> BasePasswordManager:
"""Instantiate password manager."""
def get_reader(self) -> BasePasswordReader:
"""Get the reader."""
raise NotImplementedError("Class-based password reading not implemented.")
def get_context(self, reader: BasePasswordReader | None = None) -> PasswordContext:
"""Get password context."""
if not reader:
reader = self.get_reader()
return PasswordContext(reader)

View File

@ -1,77 +0,0 @@
"""Various utilities."""
import secrets
from pathlib import Path
from .crypto import (
load_client_key,
encrypt_string,
generate_private_key,
generate_pem,
generate_public_key_string,
)
from .types import ClientSpecification
def generate_password() -> str:
"""Generate a password."""
return secrets.token_urlsafe(32)
def generate_client_object(
name: str, secrets: dict[str, str] | None = None, keyfile: str | None = None
) -> ClientSpecification:
"""Generate a client object."""
private_key = generate_private_key()
if keyfile:
with open(keyfile, "r") as f:
contents = f.read()
if not contents.startswith("ssh-rsa "):
raise RuntimeError("Error: Key must be an RSA key.")
client = ClientSpecification(name=name, public_key=contents.strip())
public_key = load_client_key(client)
else:
pem = generate_pem(private_key)
public_key = private_key.public_key()
pubkey_str = generate_public_key_string(public_key)
client = ClientSpecification(
name=name, public_key=pubkey_str, testing_private_key=pem
)
if secrets:
for secret_name, secret_value in secrets.items():
client.secrets[secret_name] = encrypt_string(secret_value, public_key)
return client
def create_client_file(
name: str,
filename: Path | str,
secrets: dict[str, str] | None = None,
keyfile: str | None = None,
) -> None:
"""Create client file."""
client = generate_client_object(name, secrets, keyfile)
with open(filename, "w") as f:
f.write(client.model_dump_json(exclude_none=True, indent=2))
f.flush()
def add_secret_to_client_file(
filename: str | Path, secret_name: str, secret_value: str
) -> None:
"""Add secret to client file."""
with open(filename, "r") as f:
client = ClientSpecification.model_validate_json(f.read())
public_key = load_client_key(client)
encrypted = encrypt_string(secret_value, public_key)
client.secrets[secret_name] = encrypted
with open(filename, "w") as f:
json_str = client.model_dump_json(exclude_none=True, indent=2)
f.write(json_str)
f.flush()

View File

@ -1 +0,0 @@

View File

@ -1,279 +0,0 @@
"""WebAPI."""
import asyncio
import logging
from functools import lru_cache
import secrets
import time
from typing import Annotated
from fastapi import Header, HTTPException, Depends, Request, APIRouter
from cryptography.fernet import Fernet
from sshecret.api import ManagementApi
from sshecret.types import (
BaseClientBackend,
BasePasswordManager,
ClientSpecification,
)
from sshecret.keepass import KeepassManager
from sshecret.backends.file_table import FileTableBackend
from sshecret.settings import Settings, get_settings
from sshecret.webapi.api_client import WebManagementAPIClient
from . import models
API_VERSION = "v1"
admin_router = APIRouter(prefix=f"/api/{API_VERSION}")
encryption_key = Fernet.generate_key()
cipher = Fernet(encryption_key)
# We store sessions in memory.
sessions: dict[str, tuple[str, float]] = {}
SESSION_TIMEOUT = 600 # 10 minutes
LOG = logging.getLogger(__name__)
session_lock = asyncio.Lock()
def encrypt_session_password(password: str) -> str:
"""Encrypts the master password."""
return cipher.encrypt(password.encode()).decode()
def decrypt_password(encrypted_password: str) -> str:
"""Decrypts the master password asynchronously."""
return cipher.decrypt(encrypted_password.encode()).decode()
async def validate_session(session_id: Annotated[str | None, Header()] = None) -> str:
"""Middleware to validate session and enforce timeout."""
if not session_id:
raise HTTPException(status_code=401, detail="Session ID required")
async with session_lock:
if session_id not in sessions:
raise HTTPException(status_code=401, detail="Session invalid or expired")
encrypted_password, last_access_time = sessions[session_id]
current_time = asyncio.get_event_loop().time()
# Check for session timeout
if current_time - last_access_time > SESSION_TIMEOUT:
del sessions[session_id] # Auto-lock on timeout
raise HTTPException(status_code=401, detail="Session expired")
# Update last access time
sessions[session_id] = (encrypted_password, current_time)
return decrypt_password(encrypted_password)
@lru_cache
def get_app_settings() -> Settings:
"""Get app settings."""
return get_settings()
def get_password_manager(
settings: Annotated[Settings, Depends(get_app_settings)]
) -> BasePasswordManager:
"""Get password manager."""
# Currently only keepass is supported.
keepass = KeepassManager()
keepass.location = settings.password_manager.location
return keepass
async def get_backend(
settings: Annotated[Settings, Depends(get_app_settings)]
) -> BaseClientBackend:
"""Get backend."""
location = settings.backend.location
filetable = FileTableBackend(location)
return filetable
async def get_management_api(
request: Request, settings: Annotated[Settings, Depends(get_app_settings)]
) -> ManagementApi:
"""Get management api."""
client_ip = "unknown"
if req_client := request.client:
client_ip = req_client.host
api_client = WebManagementAPIClient(client_ip, settings)
backend = await get_backend(settings)
return ManagementApi(backend, api_client)
BackendDependency = Annotated[BaseClientBackend, Depends(get_backend)]
ManagementAPIDependency = Annotated[ManagementApi, Depends(get_management_api)]
SessionPasswdDependency = Annotated[str, Depends(validate_session)]
@admin_router.post("/auth/unlock")
async def unlock_database(
password: models.PasswordBody,
password_manager: Annotated[BasePasswordManager, Depends(get_password_manager)],
) -> models.SessionResponse:
"""Unlock database with master password sent in POST body."""
password_str = password.password.get_secret_value()
try:
password_manager.open_database(password_str)
except Exception as e:
LOG.debug("Exception: %s", e, exc_info=True)
raise HTTPException(status_code=401, detail="Invalid password.")
session_id = secrets.token_urlsafe(32)
sessions[session_id] = (encrypt_session_password(password_str), time.time())
return models.SessionResponse(session_id=session_id)
@admin_router.post("/auth/lock")
async def lock_database(
session_id: Annotated[str | None, Header()] = None
) -> dict[str, str]:
"""Lock database."""
if session_id and session_id in sessions:
del sessions[session_id]
return {"message": "LOCKED"}
raise HTTPException(400, detail="Missing session ID.")
@admin_router.get("/auth/status")
async def get_lock_status(
session_id: Annotated[str | None, Header()] = None
) -> dict[str, str]:
"""Get current lock status."""
if session_id and session_id in sessions:
return {"message": "UNLOCKED"}
return {"message": "LOCKED"}
@admin_router.get("/clients")
async def get_clients(admin_api: ManagementAPIDependency) -> list[ClientSpecification]:
"""Get clients."""
return admin_api.get_clients()
@admin_router.get("/clients/{client_id}")
async def get_client(
client_id: str, admin_api: ManagementAPIDependency
) -> ClientSpecification:
"""Get client."""
if client_api := admin_api.get_client(client_id):
return client_api.client
raise HTTPException(status_code=404, detail="Client not found.")
@admin_router.put("/clients/{client_id}")
async def update_client(
client_id: str,
client: ClientSpecification,
admin_api: ManagementAPIDependency,
master_password: SessionPasswdDependency,
) -> ClientSpecification:
"""Update client."""
client_api = admin_api.get_client(client_id)
if not client_api:
raise HTTPException(status_code=404, detail="Client not found.")
new_client = client_api.update_client(client, password=master_password)
return new_client
@admin_router.delete("/clients/{client_id}", status_code=204)
async def delete_client(client_id: str, admin_api: ManagementAPIDependency) -> None:
"""Delete client."""
if admin_api.get_client(client_id):
admin_api.delete_client(client_id)
else:
raise HTTPException(status_code=404, detail="Client not found.")
@admin_router.post("/clients", status_code=201)
async def add_client(
client: models.CreateClientModel, admin_api: ManagementAPIDependency
) -> ClientSpecification:
"""Add client."""
new_client = admin_api.create_client(
client.name, client.public_key, client.allowed_ips
)
return new_client.client
@admin_router.get("/secrets")
async def list_secrets(
admin_api: ManagementAPIDependency, password: SessionPasswdDependency
) -> list[models.SecretListResponse]:
"""List secrets."""
secrets = admin_api.get_secret_names(password=password)
return [
models.SecretListResponse(name=name, assigned_clients=assigned_clients)
for name, assigned_clients in secrets.items()
]
@admin_router.post("/secrets")
async def add_secret(
secret: models.CreateSecretSpecification,
password: SessionPasswdDependency,
admin_api: ManagementAPIDependency,
) -> models.RevealSecretResponse:
"""Add secret.
Will generate a password if none is specified.
"""
secret_value: str | None = None
if secret.secret:
secret_value = secret.secret.get_secret_value()
result_secret = admin_api.add_secret(secret.name, secret_value, password=password)
return models.RevealSecretResponse(name=secret.name, secret=result_secret)
@admin_router.get("/secrets/{name}")
async def get_secret(
name: str, admin_api: ManagementAPIDependency, password: SessionPasswdDependency
) -> models.RevealSecretResponse:
"""Get secret."""
if secret_value := admin_api.get_secret(name, password=password):
return models.RevealSecretResponse(name=name, secret=secret_value)
raise HTTPException(status_code=404, detail="Secret not found.")
@admin_router.put("/secrets/{name}")
async def update_secret(
name: str,
spec: models.UpdateSecretSpecification,
admin_api: ManagementAPIDependency,
password: SessionPasswdDependency,
) -> models.MaybeRevalSecretResponse:
"""Update secret."""
if spec.auto_generate:
secret_value = admin_api.regenerate_secret(name, password=password)
return models.MaybeRevalSecretResponse(name=name, secret=secret_value)
if not spec.secret:
raise HTTPException(
status_code=400,
detail="Secret value must be specified if auto_generate is False",
)
admin_api.update_secret(name, spec.secret, password=password)
return models.MaybeRevalSecretResponse(name=name, secret=None)
@admin_router.delete("/secrets/{name}", status_code=204)
async def delete_secret(
name: str,
admin_api: ManagementAPIDependency,
password: SessionPasswdDependency,
) -> None:
"""Delete secret."""
admin_api.delete_secret(name, password=password)

View File

@ -1,25 +0,0 @@
"""API Client."""
from typing import override
from sshecret.keepass import KeepassManager
from sshecret.types import BaseAPIClient, BasePasswordManager
from sshecret.settings import Settings, get_settings
class WebManagementAPIClient(BaseAPIClient):
"""Client class for the web management API."""
method: str = "admin-web-api"
def __init__(self, source: str, settings: Settings | None = None) -> None:
"""Construct client."""
if not settings:
settings = get_settings()
self.source: str = source
self._password_manager: BasePasswordManager = KeepassManager()
self._password_manager.location = settings.password_manager.manager.location
@override
def password_manager(self, manager_options: dict[str, str] | None = None) -> BasePasswordManager:
"""Get password manager."""
return self._password_manager

View File

@ -1,33 +0,0 @@
"""Admin frontend."""
from fastapi import APIRouter, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
templates = Jinja2Templates(directory="templates")
frontend = APIRouter()
# I'm just making some placeholders here
@frontend.get("/")
async def index(request: Request) -> HTMLResponse:
"""Get frontpage."""
return templates.TemplateResponse(request, name="index.html")
@frontend.get("/login")
async def login(request: Request) -> HTMLResponse:
"""Get login page."""
return templates.TemplateResponse(request, name="login.html")
@frontend.get("/clients")
async def clients(request: Request) -> HTMLResponse:
"""Get login page."""
return templates.TemplateResponse(request, name="clients.html")
@frontend.get("/secrets")
async def secrets(request: Request) -> HTMLResponse:
"""Get login page."""
return templates.TemplateResponse(request, name="secrets.html")

View File

@ -1,71 +0,0 @@
"""Response models."""
from pydantic import BaseModel, IPvAnyAddress, IPvAnyNetwork, SecretStr
class SSHKeyResponse(BaseModel):
"""Response model for updated SSH keys."""
updated_secrets: list[str]
class SecretListResponse(BaseModel):
"""Response for listing secrets."""
name: str
assigned_clients: list[str]
class CreateSecretSpecification(BaseModel):
"""Model for creating a secret."""
name: str
secret: SecretStr | None
class SecretSpecification(BaseModel):
"""Secret specification."""
name: str
secret: SecretStr
class UpdateSecretSpecification(BaseModel):
"""Model for updating a secret."""
secret: str | None
auto_generate: bool | None = None
class RevealSecretResponse(BaseModel):
"""Reveal secret."""
name: str
secret: str
class MaybeRevalSecretResponse(BaseModel):
"""Model where the secret may be specified."""
name: str
secret: str | None
class PasswordBody(BaseModel):
"""Password body."""
password: SecretStr
class SessionResponse(BaseModel):
"""Session response."""
session_id: str
class CreateClientModel(BaseModel):
"""Model for creating a client."""
name: str
public_key: str
allowed_ips: list[IPvAnyAddress | IPvAnyNetwork] | str = "*"

View File

@ -1,16 +0,0 @@
"""API router."""
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from .api import admin_router
from .frontend import frontend
app = FastAPI()
app.include_router(admin_router)
app.include_router(frontend)
app.mount("/static", StaticFiles(directory="static"), name="static")