Clear POC first draft
This commit is contained in:
@ -1,2 +0,0 @@
|
||||
def hello() -> str:
|
||||
return "Hello from sshecret!"
|
||||
|
||||
@ -1,447 +0,0 @@
|
||||
"""API.
|
||||
|
||||
This module is an attempt to create some sort of meaningfull API around the
|
||||
actions exposed here.
|
||||
"""
|
||||
|
||||
import abc
|
||||
|
||||
from contextlib import contextmanager
|
||||
from collections.abc import Iterator
|
||||
|
||||
from pydantic.networks import IPvAnyAddress, IPvAnyNetwork
|
||||
|
||||
from .audit import audit_message
|
||||
|
||||
from .crypto import load_client_key, load_public_key, encrypt_string
|
||||
|
||||
from .types import (
|
||||
BaseAPIClient,
|
||||
BaseClientBackend,
|
||||
BasePasswordManager,
|
||||
BasePasswordReader,
|
||||
ClientSpecification,
|
||||
PasswordContext,
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def password_manager_session(
|
||||
password_manager: BasePasswordManager,
|
||||
password_context: PasswordContext | str,
|
||||
api_client: BaseAPIClient,
|
||||
) -> Iterator[BasePasswordManager]:
|
||||
"""Open password manager for read/write in a context."""
|
||||
audit_message(
|
||||
"Opening password manager session",
|
||||
"SECURITY",
|
||||
source_address=api_client.source,
|
||||
)
|
||||
password_manager.open_database(password_context)
|
||||
yield password_manager
|
||||
|
||||
audit_message(
|
||||
"Closing password manager session",
|
||||
"SECURITY",
|
||||
source_address=api_client.source,
|
||||
)
|
||||
password_manager.close_database()
|
||||
|
||||
|
||||
class BaseSshecretAPI(abc.ABC):
|
||||
"""Base API class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backend: BaseClientBackend,
|
||||
api_client: BaseAPIClient,
|
||||
manager_options: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
"""Initialize API."""
|
||||
|
||||
self.backend: BaseClientBackend = backend
|
||||
self.api_client: BaseAPIClient = api_client
|
||||
self.manager_options: dict[str, str] | None = manager_options
|
||||
|
||||
def _log_audit(
|
||||
self,
|
||||
message: str,
|
||||
audit_type: str,
|
||||
client_name: str | None = None,
|
||||
**details: str,
|
||||
) -> None:
|
||||
"""Log an audit message."""
|
||||
audit_message(
|
||||
message,
|
||||
audit_type,
|
||||
client_name,
|
||||
source_address=self.api_client.source,
|
||||
**details,
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def password_session(
|
||||
self, reader: BasePasswordReader | None = None, password: str | None = None
|
||||
) -> Iterator[BasePasswordManager]:
|
||||
"""Open a password session."""
|
||||
if password:
|
||||
context = password
|
||||
else:
|
||||
if not reader:
|
||||
reader = self.api_client.get_reader()
|
||||
context = self.api_client.get_context(reader)
|
||||
|
||||
password_manager = self.api_client.password_manager(self.manager_options)
|
||||
with password_manager_session(
|
||||
password_manager, context, self.api_client
|
||||
) as session:
|
||||
yield session
|
||||
|
||||
|
||||
class ClientManagementAPI(BaseSshecretAPI):
|
||||
"""API for managing clients."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backend: BaseClientBackend,
|
||||
client: ClientSpecification,
|
||||
api_client: BaseAPIClient,
|
||||
manager_options: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
"""Create client management API instance."""
|
||||
super().__init__(backend, api_client, manager_options)
|
||||
self.client: ClientSpecification = client
|
||||
self.__password_manager: BasePasswordManager | None = None
|
||||
|
||||
def log_security(self, message: str, **details: str) -> None:
|
||||
"""Log a security related message."""
|
||||
self._log_audit(message, "SECURITY", self.client.name, **details)
|
||||
|
||||
def log_info(self, message: str) -> None:
|
||||
"""Log an informational message."""
|
||||
self._log_audit(message, "INFORMATIONAL", self.client.name)
|
||||
|
||||
@property
|
||||
def password_manager(self) -> BasePasswordManager:
|
||||
"""Get password manager."""
|
||||
if self.__password_manager:
|
||||
self.log_security("Accessed password manager")
|
||||
return self.__password_manager
|
||||
raise RuntimeError("Password manager not initialized.")
|
||||
|
||||
@password_manager.setter
|
||||
def password_manager(self, instance: BasePasswordManager) -> None:
|
||||
"""Set password manager instance."""
|
||||
self.log_security("Opened password manager.")
|
||||
self.__password_manager = instance
|
||||
|
||||
def get_secrets(self) -> list[str]:
|
||||
"""Get names of the secrets that the client has access to.."""
|
||||
self.log_security("Listing secret names.")
|
||||
return list(self.client.secrets.keys())
|
||||
|
||||
def update_client(
|
||||
self,
|
||||
client: ClientSpecification,
|
||||
reader: BasePasswordReader | None = None,
|
||||
password: str | None = None,
|
||||
) -> ClientSpecification:
|
||||
"""Update client."""
|
||||
self.log_info("Updating client")
|
||||
update_data = client.model_dump(exclude_unset=True)
|
||||
if client.public_key != self.client.public_key:
|
||||
self.log_security("Client public key has changed.")
|
||||
if client.secrets != self.client.secrets:
|
||||
raise RuntimeError(
|
||||
"Error: Cannot update public key and secrets in the same operation."
|
||||
)
|
||||
del update_data["secrets"]
|
||||
secrets = self._re_encrypt(client.public_key, reader, password)
|
||||
update_data["secrets"] = secrets
|
||||
|
||||
updated_client = self.client.model_copy(update=update_data)
|
||||
self.backend.update_client(self.client.name, updated_client)
|
||||
self.client = updated_client
|
||||
return updated_client
|
||||
|
||||
def update_secret(self, name: str, password: str) -> None:
|
||||
"""Update a secret.
|
||||
|
||||
If secret is not already a part of the client, it will be added.
|
||||
"""
|
||||
if name in self.client.secrets:
|
||||
self.log_security("Updating secret", secret_name=name)
|
||||
else:
|
||||
self.log_security("Adding secret", secret_name=name)
|
||||
public_key = load_client_key(self.client)
|
||||
encrypted = encrypt_string(password, public_key)
|
||||
client_secrets = {**self.client.secrets, name: encrypted}
|
||||
updated_client = self.client.model_copy(update={"secrets": client_secrets})
|
||||
self.backend.update_client(self.client.name, updated_client)
|
||||
self.client = updated_client
|
||||
|
||||
def _re_encrypt(
|
||||
self,
|
||||
new_key: str,
|
||||
reader: BasePasswordReader | None = None,
|
||||
password: str | None = None,
|
||||
) -> dict[str, str]:
|
||||
"""Update public key on given client."""
|
||||
audit_message(
|
||||
"Updating public key",
|
||||
"INFORMATIONAL",
|
||||
self.client.name,
|
||||
source_address=self.api_client.source,
|
||||
)
|
||||
if password:
|
||||
context = password
|
||||
else:
|
||||
if not reader:
|
||||
reader = self.api_client.get_reader()
|
||||
context = self.api_client.get_context(reader)
|
||||
|
||||
password_manager = self.api_client.password_manager(self.manager_options)
|
||||
self.password_manager = password_manager
|
||||
|
||||
client_key = load_public_key(new_key.encode())
|
||||
secrets: dict[str, str] = {}
|
||||
with password_manager_session(
|
||||
password_manager, context, self.api_client
|
||||
) as password_session:
|
||||
for name in self.get_secrets():
|
||||
secret = password_session.get_password(name)
|
||||
audit_message(
|
||||
"Updating encrypted value",
|
||||
"SECURITY",
|
||||
self.client.name,
|
||||
secret_name=name,
|
||||
source_address=self.api_client.source,
|
||||
)
|
||||
if not secret:
|
||||
raise RuntimeError("Could not fetch a new secret value.")
|
||||
new_value = encrypt_string(secret, client_key)
|
||||
secrets[name] = new_value
|
||||
|
||||
return secrets
|
||||
|
||||
def update_public_key(
|
||||
self,
|
||||
new_key: str,
|
||||
reader: BasePasswordReader | None = None,
|
||||
password: str | None = None,
|
||||
) -> None:
|
||||
"""Update the public key."""
|
||||
audit_message(
|
||||
"Updating public key",
|
||||
"INFORMATIONAL",
|
||||
self.client.name,
|
||||
source_address=self.api_client.source,
|
||||
)
|
||||
secrets = self._re_encrypt(new_key, reader, password)
|
||||
|
||||
updated = self.client.model_copy(update={"secrets": secrets})
|
||||
self.backend.update_client(self.client.name, updated)
|
||||
|
||||
@classmethod
|
||||
def get_client(
|
||||
cls,
|
||||
backend: BaseClientBackend,
|
||||
name: str,
|
||||
api_client: BaseAPIClient,
|
||||
) -> "ClientManagementAPI | None":
|
||||
"""Get client."""
|
||||
client = backend.lookup_name(name)
|
||||
if not client:
|
||||
return None
|
||||
return cls(backend, client, api_client)
|
||||
|
||||
@classmethod
|
||||
def create_client(
|
||||
cls,
|
||||
backend: BaseClientBackend,
|
||||
name: str,
|
||||
api_client: BaseAPIClient,
|
||||
public_key: str,
|
||||
allowed_ips: list[IPvAnyAddress | IPvAnyNetwork] | str = "*",
|
||||
) -> "ClientManagementAPI":
|
||||
"""Create a client."""
|
||||
client = ClientSpecification(
|
||||
name=name, public_key=public_key, allowed_ips=allowed_ips
|
||||
)
|
||||
backend.add_client(client)
|
||||
return cls(backend, client, api_client)
|
||||
|
||||
|
||||
class ManagementApi(BaseSshecretAPI):
|
||||
"""Api for general management."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backend: BaseClientBackend,
|
||||
api_client: BaseAPIClient,
|
||||
manager_options: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
"""Initialize API."""
|
||||
super().__init__(backend, api_client, manager_options)
|
||||
|
||||
def log_security(
|
||||
self, message: str, client_name: str | None = None, **details: str
|
||||
) -> None:
|
||||
"""Log a security related message."""
|
||||
self._log_audit(message, "SECURITY", client_name, **details)
|
||||
|
||||
def log_info(self, message: str, client_name: str | None = None) -> None:
|
||||
"""Log an informational message."""
|
||||
self._log_audit(message, "INFORMATIONAL", client_name)
|
||||
|
||||
def get_client(self, name: str) -> ClientManagementAPI | None:
|
||||
"""Get a client."""
|
||||
client = self.backend.lookup_name(name)
|
||||
if not client:
|
||||
return None
|
||||
return ClientManagementAPI(
|
||||
self.backend, client, self.api_client, self.manager_options
|
||||
)
|
||||
|
||||
def get_clients(self) -> list[ClientSpecification]:
|
||||
"""Get clients."""
|
||||
self.log_info("Fetched all clients")
|
||||
return self.backend.get_all()
|
||||
|
||||
def _get_clients(self) -> list[ClientSpecification]:
|
||||
"""Get clients."""
|
||||
return self.backend.get_all()
|
||||
|
||||
def delete_client(self, name: str) -> None:
|
||||
"""Delete client."""
|
||||
client = self.backend.lookup_name(name)
|
||||
if not client:
|
||||
self.log_info("Attempted to delete a non-existing client.", name)
|
||||
return
|
||||
self.log_security("Deleting client", name)
|
||||
self.backend.remove_client(name)
|
||||
|
||||
def create_client(
|
||||
self,
|
||||
name: str,
|
||||
public_key: str,
|
||||
allowed_ips: list[IPvAnyAddress | IPvAnyNetwork] | str = "*",
|
||||
) -> ClientManagementAPI:
|
||||
"""Create a client."""
|
||||
self.log_info("Creating new client", name)
|
||||
client = ClientSpecification(
|
||||
name=name, public_key=public_key, allowed_ips=allowed_ips
|
||||
)
|
||||
self.backend.add_client(client)
|
||||
return ClientManagementAPI(self.backend, client, self.api_client)
|
||||
|
||||
def get_secret_names(
|
||||
self, reader: BasePasswordReader | None = None, password: str | None = None
|
||||
) -> dict[str, list[str]]:
|
||||
"""Get secret names and which clients have these.."""
|
||||
self.log_security("Listing all secret names.")
|
||||
with self.password_session(reader=reader, password=password) as session:
|
||||
secret_names = session.get_entries()
|
||||
|
||||
secret_mapping: dict[str, list[str]] = {}
|
||||
for name in secret_names:
|
||||
secret_mapping[name] = [
|
||||
client.name for client in self.backend.lookup_by_secret(name)
|
||||
]
|
||||
|
||||
return secret_mapping
|
||||
|
||||
def add_secret(
|
||||
self,
|
||||
name: str,
|
||||
secret_value: str | None,
|
||||
reader: BasePasswordReader | None = None,
|
||||
password: str | None = None,
|
||||
) -> str:
|
||||
"""Add a secret."""
|
||||
self.log_security("Adding new secret", secret_name=name)
|
||||
with self.password_session(reader=reader, password=password) as session:
|
||||
if not secret_value:
|
||||
self.log_security("Auto-generating a secret value", secret_name=name)
|
||||
secret_value = session.generate_password(name)
|
||||
else:
|
||||
session.add_password(name, secret_value)
|
||||
return secret_value
|
||||
|
||||
def get_secret(
|
||||
self,
|
||||
name: str,
|
||||
reader: BasePasswordReader | None = None,
|
||||
password: str | None = None,
|
||||
) -> str | None:
|
||||
"""Get the clear-text value of a secret."""
|
||||
self.log_security("Client requested secret value", secret_name=name)
|
||||
with self.password_session(reader=reader, password=password) as session:
|
||||
secret = session.get_password(name)
|
||||
|
||||
return secret
|
||||
|
||||
def update_secret(
|
||||
self,
|
||||
name: str,
|
||||
new_value: str,
|
||||
reader: BasePasswordReader | None = None,
|
||||
password: str | None = None,
|
||||
) -> None:
|
||||
"""Update a secret with a given name."""
|
||||
self.log_security("Changing secret", secret_name=name)
|
||||
with self.password_session(reader=reader, password=password) as session:
|
||||
session.change_password(name, new_value)
|
||||
|
||||
clients = self.backend.lookup_by_secret(name)
|
||||
for client in clients:
|
||||
client_api = self.get_client(client.name)
|
||||
if not client_api:
|
||||
continue
|
||||
client_api.update_secret(name, new_value)
|
||||
|
||||
def regenerate_secret(
|
||||
self,
|
||||
name: str,
|
||||
reader: BasePasswordReader | None = None,
|
||||
password: str | None = None,
|
||||
) -> str:
|
||||
"""Regenerate a secret."""
|
||||
self.log_security("Generating a new secret value", secret_name=name)
|
||||
with self.password_session(reader=reader, password=password) as session:
|
||||
new_value = session.change_password(name, None)
|
||||
|
||||
clients = self.backend.lookup_by_secret(name)
|
||||
for client in clients:
|
||||
client_api = self.get_client(client.name)
|
||||
if not client_api:
|
||||
continue
|
||||
client_api.update_secret(name, new_value)
|
||||
|
||||
return new_value
|
||||
|
||||
def delete_secret(
|
||||
self,
|
||||
name: str,
|
||||
reader: BasePasswordReader | None = None,
|
||||
password: str | None = None,
|
||||
) -> None:
|
||||
"""Delete secret."""
|
||||
clients = self.backend.lookup_by_secret(name)
|
||||
self.log_security("Deleting secret", secret_name=name)
|
||||
with self.password_session(reader=reader, password=password) as session:
|
||||
session.delete_password(name)
|
||||
|
||||
for client in clients:
|
||||
secrets = {**client.secrets}
|
||||
del secrets[name]
|
||||
new_client = client.model_copy(update={"secrets": secrets})
|
||||
client_api = self.get_client(client.name)
|
||||
if not client_api:
|
||||
continue
|
||||
self.log_security(
|
||||
"Removing secret from client.",
|
||||
client_name=client.name,
|
||||
secret_name=name,
|
||||
)
|
||||
client_api.update_client(new_client, password=password)
|
||||
@ -1,66 +0,0 @@
|
||||
"""Audit setup."""
|
||||
|
||||
import enum
|
||||
import json
|
||||
import logging
|
||||
|
||||
from typing import Any
|
||||
from pythonjsonlogger.json import JsonFormatter
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from .constants import AUDIT_LOG_NAME
|
||||
|
||||
AUDIT_LOG = logging.getLogger(AUDIT_LOG_NAME)
|
||||
|
||||
|
||||
class AuditMessageType(enum.StrEnum):
|
||||
"""Audit Message Type."""
|
||||
|
||||
ACCESS = enum.auto() # Someone accessed something
|
||||
SECURITY = enum.auto() # A message related to security
|
||||
INFORMATIONAL = enum.auto() # other informational messages
|
||||
|
||||
|
||||
class AuditMessage(BaseModel):
|
||||
"""Audit message."""
|
||||
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
type: AuditMessageType
|
||||
message: str
|
||||
client_name: str | None = None
|
||||
source_address: str | None = None
|
||||
secret_name: str | None = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Stringify object as JSON."""
|
||||
return self.model_dump_json()
|
||||
|
||||
|
||||
def audit_message(
|
||||
message: str,
|
||||
audit_type: AuditMessageType | str | None = None,
|
||||
client_name: str | None = None,
|
||||
secret_name: str | None = None,
|
||||
source_address: str | None = None,
|
||||
**details: str
|
||||
) -> None:
|
||||
"""Create an audit message."""
|
||||
if not audit_type:
|
||||
audit_type = AuditMessageType.INFORMATIONAL
|
||||
|
||||
if audit_type not in list(AuditMessageType):
|
||||
audit_type = AuditMessageType.INFORMATIONAL
|
||||
|
||||
audit_message = AuditMessage(
|
||||
type=audit_type,
|
||||
message=message,
|
||||
client_name=client_name,
|
||||
source_address=source_address,
|
||||
secret_name=secret_name,
|
||||
)
|
||||
|
||||
audit_dict = audit_message.model_dump(exclude_none=True)
|
||||
|
||||
AUDIT_LOG.info({**audit_dict, **details})
|
||||
@ -1,5 +0,0 @@
|
||||
"""Backend implementations"""
|
||||
|
||||
from .file_table import FileTableBackend
|
||||
|
||||
__all__ = ["FileTableBackend"]
|
||||
@ -1,133 +0,0 @@
|
||||
"""File table based backend."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import override
|
||||
|
||||
import littletable as lt
|
||||
|
||||
from sshecret.crypto import load_client_key, encrypt_string
|
||||
from sshecret.types import ClientSpecification
|
||||
from sshecret.types import BaseClientBackend
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_clients_from_dir(directory: Path) -> dict[Path, ClientSpecification]:
|
||||
"""Load clients from a directory."""
|
||||
if not directory.exists() or not directory.is_dir():
|
||||
raise ValueError("Invalid directory specified.")
|
||||
|
||||
clients: dict[Path, ClientSpecification] = {}
|
||||
for client_file in directory.glob("*.json"):
|
||||
with open(client_file, "r") as f:
|
||||
client = ClientSpecification.model_validate_json(f.read())
|
||||
if client_file.name != f"{client.name}.json":
|
||||
raise RuntimeError(
|
||||
"Filename scheme of clients does not conform to expected format. Aborting import!"
|
||||
)
|
||||
clients[client_file] = client
|
||||
|
||||
return clients
|
||||
|
||||
|
||||
|
||||
class FileTableBackend(BaseClientBackend):
|
||||
"""In-memory littletable based backend."""
|
||||
|
||||
def __init__(self, directory: Path) -> None:
|
||||
"""Create backend instance."""
|
||||
LOG.debug("Creating in-memory table to hold clients.")
|
||||
self._directory: Path = directory
|
||||
self.table: lt.Table[ClientSpecification] = lt.Table()
|
||||
self._setup_table()
|
||||
client_files = load_clients_from_dir(directory)
|
||||
client_count = len(client_files)
|
||||
LOG.debug("Loaded %s clients from disk.", client_count)
|
||||
# self.client_file_map: dict[str, Path] = {client.name: filepath for filepath, client in client_files.items()}
|
||||
LOG.debug("Inserting clients into table.")
|
||||
self.table.insert_many(list(client_files.values()))
|
||||
|
||||
def _setup_table(self) -> None:
|
||||
"""Set up the table."""
|
||||
self.table.create_index("name", unique=True)
|
||||
|
||||
@override
|
||||
def lookup_name(self, name: str) -> ClientSpecification | None:
|
||||
"""Lookup client by name."""
|
||||
if result := self.table.by.name.get(name):
|
||||
if isinstance(result, ClientSpecification):
|
||||
return result
|
||||
return None
|
||||
|
||||
@override
|
||||
def add_client(self, spec: ClientSpecification) -> None:
|
||||
"""Add client."""
|
||||
self.table.insert(spec)
|
||||
self._write_spec_file(spec)
|
||||
|
||||
def _write_spec_file(self, spec: ClientSpecification) -> None:
|
||||
"""Write spec file to disk."""
|
||||
dest_file_name = f"{spec.name}.json"
|
||||
dest_file = self._directory / dest_file_name
|
||||
with open(dest_file.absolute(), "w") as f:
|
||||
f.write(
|
||||
spec.model_dump_json(exclude_none=True, exclude_unset=True, indent=2)
|
||||
)
|
||||
f.flush()
|
||||
|
||||
@override
|
||||
def add_secret(
|
||||
self,
|
||||
client_name: str,
|
||||
secret_name: str,
|
||||
secret_value: str,
|
||||
encrypted: bool = False,
|
||||
) -> None:
|
||||
"""Add secret."""
|
||||
client: ClientSpecification = self.table.by.name[client_name] # pyright: ignore[reportAssignmentType]
|
||||
if not encrypted:
|
||||
public_key = load_client_key(client)
|
||||
secret_value = encrypt_string(secret_value, public_key)
|
||||
client.secrets[secret_name] = secret_value
|
||||
self._update_client_data(client)
|
||||
self._write_spec_file(client)
|
||||
|
||||
@override
|
||||
def remove_client(self, name: str, persistent: bool = True) -> None:
|
||||
"""Delete client."""
|
||||
client = self.lookup_name(name)
|
||||
if not client:
|
||||
raise ValueError("Client does not exist!")
|
||||
self.table.remove(client)
|
||||
if persistent:
|
||||
filename = f"{client.name}.json"
|
||||
filepath = self._directory / filename
|
||||
filepath.unlink()
|
||||
|
||||
@override
|
||||
def update_client(self, name: str, spec: ClientSpecification) -> None:
|
||||
"""Update client."""
|
||||
if not self.lookup_name(name):
|
||||
raise ValueError("Client does not exist!")
|
||||
self._update_client_data(spec)
|
||||
self._write_spec_file(spec)
|
||||
|
||||
def _update_client_data(self, spec: ClientSpecification) -> None:
|
||||
"""Update client data."""
|
||||
existing = self.lookup_name(spec.name)
|
||||
if existing:
|
||||
self.table.remove(existing)
|
||||
self.add_client(spec)
|
||||
|
||||
@override
|
||||
def get_all(self) -> list[ClientSpecification]:
|
||||
"""Get all clients."""
|
||||
return list(self.table)
|
||||
|
||||
@override
|
||||
def lookup_by_secret(self, secret_name: str) -> list[ClientSpecification]:
|
||||
"""Lookup by secret name."""
|
||||
results = self.table.where(lambda client: secret_name in client.secrets)
|
||||
return list(results)
|
||||
@ -1,3 +0,0 @@
|
||||
"""Command Line Interface"""
|
||||
|
||||
import click
|
||||
@ -1,25 +0,0 @@
|
||||
"""Client code"""
|
||||
|
||||
from typing import TextIO
|
||||
import click
|
||||
|
||||
from sshecret.crypto import decode_string, load_private_key
|
||||
|
||||
|
||||
def decrypt_secret(encoded: str, client_key: str) -> str:
|
||||
"""Decrypt secret."""
|
||||
private_key = load_private_key(client_key)
|
||||
return decode_string(encoded, private_key)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("keyfile", type=click.Path(exists=True, readable=True, dir_okay=False))
|
||||
@click.argument("encrypted_input", type=click.File("r"))
|
||||
def cli_decrypt(keyfile: str, encrypted_input: TextIO) -> None:
|
||||
"""Decrypt on command line."""
|
||||
decrypted = decrypt_secret(encrypted_input.read(), keyfile)
|
||||
click.echo(decrypted)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli_decrypt()
|
||||
@ -1,18 +0,0 @@
|
||||
"""Config file."""
|
||||
from pathlib import Path
|
||||
from pydantic import SecretStr
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class KeepassSettings(BaseSettings):
|
||||
"""Settings for Keepasss password database."""
|
||||
|
||||
database_path: Path
|
||||
|
||||
|
||||
class SshecretSettings(BaseSettings):
|
||||
"""Settings model."""
|
||||
|
||||
admin_password: SecretStr
|
||||
admin_ssh_key: str | None = None
|
||||
keepass: KeepassSettings
|
||||
@ -1,17 +0,0 @@
|
||||
"""Constants."""
|
||||
|
||||
MASTER_PASSWORD = "MASTER_PASSWORD"
|
||||
NO_USERNAME = "NO_USERNAME"
|
||||
VAR_PREFIX = "SSHECRET"
|
||||
|
||||
ERROR_NO_SECRET_FOUND = "Error: No secret available with the given name."
|
||||
ERROR_UKNOWN_CLIENT_OR_SECRET = "Error: Invalid client or secret name."
|
||||
ERROR_NO_COMMAND_RECEIVED = "Error: No command was received from the client."
|
||||
ERROR_SOURCE_IP_NOT_ALLOWED = (
|
||||
"Error: Client not authorized to connect from the given host."
|
||||
)
|
||||
|
||||
RSA_PUBLIC_EXPONENT = 65537
|
||||
RSA_KEY_SIZE = 2048
|
||||
|
||||
AUDIT_LOG_NAME = "AUDIT"
|
||||
@ -1,106 +0,0 @@
|
||||
"""Encryption related functions.
|
||||
|
||||
Note! Encryption uses the less secure PKCS1v15 padding. This is to allow
|
||||
decryption via openssl on the command line.
|
||||
|
||||
"""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.primitives.asymmetric import padding
|
||||
|
||||
from .types import ClientSpecification
|
||||
from . import constants
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_client_key(client: ClientSpecification) -> rsa.RSAPublicKey:
|
||||
"""Load public key."""
|
||||
keybytes = client.public_key.encode()
|
||||
return load_public_key(keybytes)
|
||||
|
||||
|
||||
def load_public_key(keybytes: bytes) -> rsa.RSAPublicKey:
|
||||
public_key = serialization.load_ssh_public_key(keybytes)
|
||||
if not isinstance(public_key, rsa.RSAPublicKey):
|
||||
raise RuntimeError("Only RSA keys are supported.")
|
||||
pem_public_key = public_key.public_bytes(encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo)
|
||||
LOG.info("pem:\n%s", pem_public_key)
|
||||
return public_key
|
||||
|
||||
|
||||
def load_private_key(filename: str) -> rsa.RSAPrivateKey:
|
||||
"""Load a private key."""
|
||||
with open(filename, "rb") as f:
|
||||
private_key = serialization.load_ssh_private_key(f.read(), password=None)
|
||||
if not isinstance(private_key, rsa.RSAPrivateKey):
|
||||
raise RuntimeError("Only RSA keys are supported.")
|
||||
return private_key
|
||||
|
||||
|
||||
def encrypt_string(string: str, public_key: rsa.RSAPublicKey) -> str:
|
||||
"""Encrypt string, end return it base64 encoded."""
|
||||
message = string.encode()
|
||||
ciphertext = public_key.encrypt(
|
||||
message,
|
||||
padding.PKCS1v15(),
|
||||
)
|
||||
return base64.b64encode(ciphertext).decode()
|
||||
|
||||
|
||||
def decode_string(ciphertext: str, private_key: rsa.RSAPrivateKey) -> str:
|
||||
"""Decode a string. String must be base64 encoded."""
|
||||
decoded = base64.b64decode(ciphertext)
|
||||
decrypted = private_key.decrypt(
|
||||
decoded,
|
||||
padding.PKCS1v15(),
|
||||
)
|
||||
return decrypted.decode()
|
||||
|
||||
|
||||
def generate_private_key() -> rsa.RSAPrivateKey:
|
||||
"""Generate private RSA key."""
|
||||
private_key = rsa.generate_private_key(
|
||||
public_exponent=constants.RSA_PUBLIC_EXPONENT, key_size=constants.RSA_KEY_SIZE
|
||||
)
|
||||
return private_key
|
||||
|
||||
|
||||
def generate_pem(private_key: rsa.RSAPrivateKey) -> str:
|
||||
"""Generate PEM."""
|
||||
pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.OpenSSH,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
return pem.decode()
|
||||
|
||||
|
||||
def create_private_rsa_key(filename: Path) -> None:
|
||||
"""Create an RSA Private key at the given path."""
|
||||
if filename.exists():
|
||||
raise RuntimeError("Error: private key file already exists.")
|
||||
LOG.debug("Generating private RSA key at %s", filename)
|
||||
private_key = generate_private_key()
|
||||
with open(filename, "wb") as f:
|
||||
pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.OpenSSH,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
lines = f.write(pem)
|
||||
LOG.debug("Wrote %s lines", lines)
|
||||
f.flush()
|
||||
|
||||
|
||||
def generate_public_key_string(public_key: rsa.RSAPublicKey) -> str:
|
||||
"""Generate public key string."""
|
||||
keybytes = public_key.public_bytes(
|
||||
encoding=serialization.Encoding.OpenSSH,
|
||||
format=serialization.PublicFormat.OpenSSH,
|
||||
)
|
||||
return keybytes.decode()
|
||||
@ -1,93 +0,0 @@
|
||||
"""Development CLI commands."""
|
||||
|
||||
import sys
|
||||
import asyncio
|
||||
import asyncssh
|
||||
import click
|
||||
|
||||
import logging
|
||||
import tempfile
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
from pythonjsonlogger.json import JsonFormatter
|
||||
|
||||
from .server import start_server
|
||||
from sshecret.backends import FileTableBackend
|
||||
from .utils import create_client_file, add_secret_to_client_file
|
||||
from .constants import AUDIT_LOG_NAME
|
||||
|
||||
|
||||
def thread_id_filter(record: logging.LogRecord) -> logging.LogRecord:
|
||||
"""Resolve thread id."""
|
||||
record.thread_id = threading.get_native_id()
|
||||
return record
|
||||
|
||||
|
||||
LOG = logging.getLogger()
|
||||
handler = logging.StreamHandler()
|
||||
handler.addFilter(thread_id_filter)
|
||||
formatter = logging.Formatter(
|
||||
"%(thread_id)d:%(created)f:%(levelname)s:%(name)s:%(module)s:%(message)s"
|
||||
)
|
||||
handler.setFormatter(formatter)
|
||||
LOG.addHandler(handler)
|
||||
LOG.setLevel(logging.DEBUG)
|
||||
|
||||
AUDIT_LOG = logging.getLogger(AUDIT_LOG_NAME)
|
||||
audit_formatter = JsonFormatter()
|
||||
audit_handler = logging.StreamHandler()
|
||||
audit_handler.setFormatter(audit_formatter)
|
||||
AUDIT_LOG.addHandler(audit_handler)
|
||||
|
||||
|
||||
@click.group()
|
||||
def cli() -> None:
|
||||
"""Run commands for testing."""
|
||||
|
||||
|
||||
@cli.command("create-client")
|
||||
@click.argument("name")
|
||||
@click.argument(
|
||||
"filename", type=click.Path(file_okay=True, dir_okay=False, writable=True)
|
||||
)
|
||||
@click.option("--public-key", type=click.Path(file_okay=True))
|
||||
def create_client(name: str, filename: str, public_key: str | None) -> None:
|
||||
"""Create a client."""
|
||||
create_client_file(name, filename, keyfile=public_key)
|
||||
click.echo(f"Wrote client config to {filename}")
|
||||
|
||||
|
||||
@cli.command("add-secret")
|
||||
@click.argument(
|
||||
"filename", type=click.Path(file_okay=True, dir_okay=False, writable=True)
|
||||
)
|
||||
@click.argument("secret-name")
|
||||
@click.argument("secret-value")
|
||||
def add_secret(filename: str, secret_name: str, secret_value: str) -> None:
|
||||
"""Add secret to client file."""
|
||||
add_secret_to_client_file(filename, secret_name, secret_value)
|
||||
click.echo(f"Wrote secret to {filename}")
|
||||
|
||||
|
||||
@cli.command("server")
|
||||
@click.argument("directory", type=click.Path(file_okay=False, dir_okay=True))
|
||||
@click.argument("port", type=click.INT)
|
||||
def run_async_server(directory: str, port: int) -> None:
|
||||
"""Run async server."""
|
||||
loop = asyncio.new_event_loop()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
serverdir = Path(tmpdir)
|
||||
host_key = str(serverdir / "hostkey")
|
||||
clientdir = Path(directory)
|
||||
backend = FileTableBackend(clientdir)
|
||||
try:
|
||||
loop.run_until_complete(start_server(port, backend, host_key, True))
|
||||
except (OSError, asyncssh.Error) as exc:
|
||||
click.echo(f"Error starting server: {exc}")
|
||||
sys.exit(1)
|
||||
|
||||
loop.run_forever()
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
@ -1,180 +0,0 @@
|
||||
"""Keepass integration."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import cast, final, overload, override, Self
|
||||
|
||||
import pykeepass
|
||||
from . import constants
|
||||
from .types import BasePasswordManager, PasswordContext
|
||||
from .utils import generate_password
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class KeepassManager(BasePasswordManager):
|
||||
"""KeepassXC compatible password manager."""
|
||||
|
||||
master_password_identifier = constants.MASTER_PASSWORD
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize password manager."""
|
||||
self._location: Path | None = None
|
||||
self._keepass: pykeepass.PyKeePass | None = None
|
||||
|
||||
@property
|
||||
def location(self) -> Path:
|
||||
"""Get location."""
|
||||
if not self._location:
|
||||
raise RuntimeError("No location has been specified.")
|
||||
return self._location
|
||||
|
||||
@location.setter
|
||||
def location(self, location: Path) -> None:
|
||||
"""Set location."""
|
||||
if not location.exists() or not location.is_file():
|
||||
raise RuntimeError("Unable to read provided password file.")
|
||||
self._location = location
|
||||
|
||||
@override
|
||||
def set_manager_options(self, options: dict[str, str]) -> None:
|
||||
"""Set manager options."""
|
||||
if "location" in options:
|
||||
location = Path(str(options["location"]))
|
||||
self.location = location
|
||||
|
||||
@property
|
||||
def keepass(self) -> pykeepass.PyKeePass:
|
||||
"""Return keepass instance."""
|
||||
if self._keepass:
|
||||
return self._keepass
|
||||
raise RuntimeError("Error: Database has not been opened.")
|
||||
|
||||
@keepass.setter
|
||||
def keepass(self, instance: pykeepass.PyKeePass) -> None:
|
||||
"""Set the keepass instance."""
|
||||
self._keepass = instance
|
||||
|
||||
@override
|
||||
def get_entries(self) -> list[str]:
|
||||
"""Get all entries."""
|
||||
entries = self.keepass.entries
|
||||
if not entries:
|
||||
return []
|
||||
return [
|
||||
str(entry.title) for entry in entries
|
||||
]
|
||||
|
||||
|
||||
@override
|
||||
@classmethod
|
||||
def create_database(
|
||||
cls, location: str, password_context: PasswordContext | str, overwrite: bool = False
|
||||
) -> Self:
|
||||
"""Create database."""
|
||||
if Path(location).exists() and not overwrite:
|
||||
raise RuntimeError("Error: Database exists.")
|
||||
|
||||
if isinstance(password_context, PasswordContext):
|
||||
master_password = password_context.get_password(cls.master_password_identifier, True)
|
||||
else:
|
||||
master_password = password_context
|
||||
|
||||
# TODO: should we delete if overwrite is set?
|
||||
keepass = pykeepass.create_database(location, password=master_password)
|
||||
instance = cls()
|
||||
instance.set_manager_options({"location": str(location)})
|
||||
instance.keepass = keepass
|
||||
return instance
|
||||
|
||||
@override
|
||||
def open_database(self, password_context: PasswordContext | str) -> None:
|
||||
"""Open the database"""
|
||||
if isinstance(password_context, PasswordContext):
|
||||
password = password_context.get_password(self.master_password_identifier)
|
||||
else:
|
||||
password = password_context
|
||||
instance = pykeepass.PyKeePass(str(self.location.absolute()), password=password)
|
||||
self.keepass = instance
|
||||
|
||||
@override
|
||||
def close_database(self) -> None:
|
||||
"""Close the database."""
|
||||
self._keepass = None
|
||||
|
||||
@override
|
||||
def get_password(self, identifier: str) -> str | None:
|
||||
"""Get password."""
|
||||
entry = cast("pykeepass.entry.Entry | None", self.keepass.find_entries(title=identifier, first=True))
|
||||
if not entry:
|
||||
return None
|
||||
|
||||
if password := cast(str, entry.password):
|
||||
return str(password)
|
||||
raise RuntimeError(f"Cannot get password for entry {identifier}")
|
||||
|
||||
|
||||
@override
|
||||
def generate_password(self, identifier: str) -> str:
|
||||
"""Generate password."""
|
||||
# Generate a password.
|
||||
password = generate_password()
|
||||
_entry = self.keepass.add_entry(
|
||||
self.keepass.root_group, identifier, constants.NO_USERNAME, password
|
||||
)
|
||||
self.keepass.save()
|
||||
LOG.debug("Created Entry %r", _entry)
|
||||
return password
|
||||
|
||||
@override
|
||||
def add_password(self, identifier: str, password: str) -> None:
|
||||
"""Add a password."""
|
||||
entry = cast("pykeepass.entry.Entry | None", self.keepass.find_entries(title=identifier, first=True))
|
||||
if not entry:
|
||||
_entry = self.keepass.add_entry(self.keepass.root_group, identifier, constants.NO_USERNAME, password)
|
||||
self.keepass.save()
|
||||
LOG.debug("Created entry %r", _entry)
|
||||
return
|
||||
self.change_password(identifier, password)
|
||||
LOG.debug("Updated password on entry %r", entry)
|
||||
|
||||
|
||||
@overload
|
||||
def change_password(self, identifier: str, password: None) -> str: ...
|
||||
|
||||
@overload
|
||||
def change_password(self, identifier: str, password: str) -> None: ...
|
||||
|
||||
@override
|
||||
def change_password(self, identifier: str, password: str | None) -> str | None:
|
||||
"""Change a password."""
|
||||
generated_password = False
|
||||
if not password:
|
||||
password = generate_password()
|
||||
generated_password = True
|
||||
|
||||
entry = cast("pykeepass.entry.Entry | None", self.keepass.find_entries(title=identifier, first=True))
|
||||
|
||||
if not entry:
|
||||
raise ValueError("Error: Entry not found!")
|
||||
|
||||
entry.password = password
|
||||
|
||||
self.keepass.save()
|
||||
if generated_password:
|
||||
return password
|
||||
|
||||
return None
|
||||
|
||||
@override
|
||||
def delete_password(self, identifier: str) -> None:
|
||||
"""Delete password."""
|
||||
entry = cast("pykeepass.entry.Entry | None", self.keepass.find_entries(title=identifier, first=True))
|
||||
if not entry:
|
||||
return
|
||||
LOG.info("Deleting entry %s for keepass.", entry.uuid)
|
||||
|
||||
self.keepass.delete_entry(entry)
|
||||
|
||||
self.keepass.save()
|
||||
@ -1,61 +0,0 @@
|
||||
"""Password reader classes.
|
||||
|
||||
This implements two interfaces to read passwords:
|
||||
|
||||
InputPasswordReader and EnvironmentPasswordReader.
|
||||
"""
|
||||
|
||||
import re
|
||||
import os
|
||||
import sys
|
||||
from typing import TextIO, override
|
||||
import click
|
||||
|
||||
from .types import BasePasswordReader
|
||||
from . import constants
|
||||
|
||||
RE_VARNAME = re.compile(r"^[a-zA-Z_]+[a-zA-Z0-9_]*$")
|
||||
|
||||
|
||||
class InputPasswordReader(BasePasswordReader):
|
||||
"""Read a password from stdin."""
|
||||
|
||||
@override
|
||||
def get_password(self, identifier: str, repeated: bool = False) -> str:
|
||||
"""Get password."""
|
||||
if password := click.prompt(
|
||||
f"Enter password for {identifier}", hide_input=True, type=str, confirmation_prompt=repeated
|
||||
):
|
||||
return str(password)
|
||||
raise ValueError("No password received.")
|
||||
|
||||
|
||||
class EnvironmentPasswordReader(BasePasswordReader):
|
||||
"""Read a password from the environment.
|
||||
|
||||
The environemnt variable will be constructured based on the identifier and the prefix.
|
||||
Final environemnt variable will be validated according to the regex `[a-zA-Z_]+[a-zA-Z0-9_]*`
|
||||
"""
|
||||
|
||||
def _resolve_var_name(self, identifier: str) -> str:
|
||||
"""Resolve variable name."""
|
||||
identifier = identifier.replace("-", "_")
|
||||
fields = [constants.VAR_PREFIX, identifier]
|
||||
varname = "_".join(fields)
|
||||
if not RE_VARNAME.fullmatch(varname):
|
||||
raise ValueError(
|
||||
f"Cannot generate encode password identifier in variable name. {varname} is not a valid identifier."
|
||||
)
|
||||
return varname
|
||||
|
||||
def get_password_from_env(self, identifier: str) -> str:
|
||||
"""Get password from environment."""
|
||||
varname = self._resolve_var_name(identifier)
|
||||
if password := os.getenv(varname, None):
|
||||
return password
|
||||
raise ValueError(f"Error: No variable named {varname} resolved.")
|
||||
|
||||
@override
|
||||
def get_password(self, identifier: str, repeated: bool = False) -> str:
|
||||
"""Get password."""
|
||||
return self.get_password_from_env(identifier)
|
||||
@ -1,5 +0,0 @@
|
||||
"""Sshecret server module."""
|
||||
|
||||
from .async_server import AsshyncServer, start_server
|
||||
|
||||
__all__ = ["AsshyncServer", "start_server"]
|
||||
@ -1,143 +0,0 @@
|
||||
"""Server implemented with asyncssh."""
|
||||
|
||||
import logging
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import override
|
||||
|
||||
import asyncssh
|
||||
|
||||
from sshecret import constants
|
||||
from sshecret.audit import audit_message
|
||||
from sshecret.types import ClientSpecification, BaseClientBackend
|
||||
from sshecret.crypto import create_private_rsa_key
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def handle_client(process: asyncssh.SSHServerProcess[str]) -> None:
|
||||
"""Handle client."""
|
||||
remote_ip = process.get_extra_info("peername")[0]
|
||||
client_found = process.get_extra_info("client_allowed", False)
|
||||
if not client_found:
|
||||
process.stderr.write(constants.ERROR_UKNOWN_CLIENT_OR_SECRET + "\n")
|
||||
audit_message("Unknown connection", source_address=remote_ip)
|
||||
process.exit(1)
|
||||
return
|
||||
|
||||
client_allowed = process.get_extra_info("client_allowed", False)
|
||||
if not client_allowed:
|
||||
audit_message("Not permitted", "SECURITY", source_address=remote_ip)
|
||||
process.stderr.write(constants.ERROR_SOURCE_IP_NOT_ALLOWED + "\n")
|
||||
process.exit(1)
|
||||
return
|
||||
|
||||
client = process.get_extra_info("client")
|
||||
if not client:
|
||||
audit_message("Unknown client", source_address=remote_ip)
|
||||
process.stderr.write(constants.ERROR_UKNOWN_CLIENT_OR_SECRET + "\n")
|
||||
process.exit(1)
|
||||
return
|
||||
|
||||
secret_name = process.command
|
||||
if not secret_name:
|
||||
audit_message("No secret specified", source_address=remote_ip, client_name=client.name)
|
||||
process.stderr.write(constants.ERROR_UKNOWN_CLIENT_OR_SECRET + "\n")
|
||||
process.exit(1)
|
||||
return
|
||||
|
||||
LOG.debug(
|
||||
"Client %s successfully connected. Fetching secret %s", client.name, secret_name
|
||||
)
|
||||
|
||||
audit_message(f"Requested secret", client_name=client.name, secret_name=secret_name, source_address=remote_ip)
|
||||
secret = client.secrets.get(secret_name)
|
||||
if not secret:
|
||||
process.stderr.write(constants.ERROR_UKNOWN_CLIENT_OR_SECRET + "\n")
|
||||
process.exit(1)
|
||||
return
|
||||
|
||||
audit_message("Accessed secret", client.name, secret_name, source_address=remote_ip)
|
||||
process.stdout.write(secret)
|
||||
process.exit(0)
|
||||
|
||||
|
||||
class AsshyncServer(asyncssh.SSHServer):
|
||||
"""Asynchronous SSH server implementation."""
|
||||
|
||||
def __init__(self, backend: BaseClientBackend) -> None:
|
||||
"""Initialize server."""
|
||||
self.backend: BaseClientBackend = backend
|
||||
self._conn: asyncssh.SSHServerConnection | None = None
|
||||
|
||||
@override
|
||||
def connection_made(self, conn: asyncssh.SSHServerConnection) -> None:
|
||||
"""Handle incoming connection."""
|
||||
peername = conn.get_extra_info("peername")
|
||||
LOG.debug("Connection established from %r", peername)
|
||||
self._conn = conn
|
||||
|
||||
@override
|
||||
def begin_auth(self, username: str) -> bool:
|
||||
"""Begin authentication."""
|
||||
if not self._conn:
|
||||
return True
|
||||
client = self.backend.lookup_name(username)
|
||||
if not client:
|
||||
return True
|
||||
self._conn.set_extra_info(client_found=True)
|
||||
remote_ip = self._conn.get_extra_info("peername")[0]
|
||||
LOG.debug("Remote_IP: %r", remote_ip)
|
||||
assert isinstance(remote_ip, str)
|
||||
if self.check_connection_allowed(client, remote_ip):
|
||||
audit_message("Authentication requested", "ACCESS", client_name=client.name, source_address=remote_ip)
|
||||
self._conn.set_extra_info(client_allowed=True)
|
||||
self._conn.set_extra_info(client=client)
|
||||
|
||||
# Load the key.
|
||||
public_key = asyncssh.import_authorized_keys(client.public_key)
|
||||
self._conn.set_authorized_keys(public_key)
|
||||
|
||||
return True
|
||||
|
||||
@override
|
||||
def password_auth_supported(self) -> bool:
|
||||
"""Deny password authentication."""
|
||||
return False
|
||||
|
||||
def check_connection_allowed(
|
||||
self, client: ClientSpecification, source: str
|
||||
) -> bool:
|
||||
"""Check if client is allowed to request secrets."""
|
||||
LOG.debug("Checking if client is allowed to log in from %s", source)
|
||||
if isinstance(client.allowed_ips, str) and client.allowed_ips == "*":
|
||||
audit_message("Permitting login", "SECURITY", client_name=client.name, source_address=source)
|
||||
LOG.debug("Client has no restrictions on source IP address. Permitting.")
|
||||
return True
|
||||
if isinstance(client.allowed_ips, str) and "/" not in client.allowed_ips:
|
||||
if source == client.allowed_ips:
|
||||
audit_message("Permitting login", "SECURITY", client_name=client.name, source_address=source)
|
||||
LOG.debug("Client IP matches permitted address")
|
||||
return True
|
||||
|
||||
LOG.warning(
|
||||
"Connection for client %s received from IP address %s that is not permitted.",
|
||||
client.name,
|
||||
source,
|
||||
)
|
||||
audit_message("REJECTED. Invalid address", "SECURITY", client_name=client.name, source_address=source)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def start_server(
|
||||
port: int, backend: BaseClientBackend, host_key: str, create_key: bool = False
|
||||
) -> None:
|
||||
"""Start server."""
|
||||
server = partial(AsshyncServer, backend=backend)
|
||||
if create_key:
|
||||
create_private_rsa_key(Path(host_key))
|
||||
await asyncssh.create_server(
|
||||
server, "", port, server_host_keys=[host_key], process_factory=handle_client
|
||||
)
|
||||
@ -1,21 +0,0 @@
|
||||
"""Server errors."""
|
||||
|
||||
|
||||
class BaseSshecretServerError(Exception):
|
||||
"""Base Sshecret Server Error."""
|
||||
|
||||
|
||||
class UnknownClientError(BaseSshecretServerError):
|
||||
"""Client was not recognized."""
|
||||
|
||||
|
||||
class AccessDeniedError(BaseSshecretServerError):
|
||||
"""Client was not authorized to access the resource."""
|
||||
|
||||
|
||||
class AccessPolicyViolationError(BaseSshecretServerError):
|
||||
"""Client was not authorized to access the secret."""
|
||||
|
||||
|
||||
class UnknownSecretError(BaseSshecretServerError):
|
||||
"""Error when resolving the secret."""
|
||||
@ -1,36 +0,0 @@
|
||||
"""Password reader for use with the SSH server."""
|
||||
|
||||
from typing import override, TextIO
|
||||
import asyncssh
|
||||
from sshecret.types import BasePasswordReader
|
||||
|
||||
|
||||
class SSHPasswordReader(BasePasswordReader):
|
||||
"""SSH Password reader."""
|
||||
|
||||
def __init__(self, channel: asyncssh.SSHLineEditorChannel, stdin: asyncssh.SSHReader[str], stdout: asyncssh.SSHWriter[str]) -> None:
|
||||
"""Initialize password reader."""
|
||||
self.channel: asyncssh.SSHLineEditorChannel = channel
|
||||
self.stdin: asyncssh.SSHReader[str] = stdin
|
||||
self.stdout: asyncssh.SSHWriter[str] = stdout
|
||||
|
||||
@override
|
||||
def get_password(self, identifier: str, repeated: bool = False) -> str:
|
||||
"""Get password."""
|
||||
raise RuntimeError("Use get_password_async!")
|
||||
|
||||
async def get_password_async(self, identifier: str, repeated: bool = False) -> str:
|
||||
"""Get password async."""
|
||||
self.stdout.write(f"Enter password for {identifier}: ")
|
||||
self.channel.set_echo(False)
|
||||
while True:
|
||||
password = await self.stdin.readline()
|
||||
if not repeated:
|
||||
break
|
||||
self.stdout.write(f"\nRe-enter password for {identifier}: ")
|
||||
password2 = await self.stdin.readline()
|
||||
if password == password2:
|
||||
break
|
||||
self.stdout.write(f"Passwords did not match. Try again.\n")
|
||||
self.channel.set_echo(True)
|
||||
return password.strip()
|
||||
@ -1,82 +0,0 @@
|
||||
"""Get settings."""
|
||||
|
||||
import abc
|
||||
import enum
|
||||
import os
|
||||
import tomllib
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from pydantic import BaseModel, DirectoryPath, Field, FilePath
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from sshecret.keepass import KeepassManager
|
||||
|
||||
SETTINGS_FILE = "sshecret.toml"
|
||||
|
||||
|
||||
class Backend(enum.StrEnum):
|
||||
"""Supported backends."""
|
||||
|
||||
FILES = "FILES"
|
||||
|
||||
|
||||
class PasswordManager(enum.StrEnum):
|
||||
"""Supported password managers."""
|
||||
|
||||
KEEPASS = "KeePass"
|
||||
|
||||
|
||||
class SSHServerSettings(BaseModel):
|
||||
"""SSH Server settings."""
|
||||
|
||||
port: int = 22
|
||||
private_key: FilePath | None = None
|
||||
|
||||
|
||||
class AdminApiSettings(BaseModel):
|
||||
"""Admin API settings."""
|
||||
|
||||
port: int = 8022
|
||||
|
||||
|
||||
class FileBackendSettings(BaseModel):
|
||||
"""File backend settings.
|
||||
|
||||
This will eventually have the Discriminator pattern described in pydantic.
|
||||
"""
|
||||
|
||||
type: Literal["Files"]
|
||||
|
||||
location: DirectoryPath
|
||||
|
||||
|
||||
class KeepassPDBSettings(BaseModel):
|
||||
"""Keepass backend settings."""
|
||||
|
||||
type: Literal["KeePass"]
|
||||
location: FilePath
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Sshecret settings."""
|
||||
|
||||
model_config = SettingsConfigDict(env_prefix="sshecret_", env_nested_delimiter="__")
|
||||
|
||||
backend: FileBackendSettings
|
||||
password_manager: KeepassPDBSettings
|
||||
admin_api: AdminApiSettings = Field(default_factory=AdminApiSettings)
|
||||
ssh_server: SSHServerSettings = Field(default_factory=SSHServerSettings)
|
||||
|
||||
|
||||
def get_settings() -> Settings:
|
||||
"""Get settings."""
|
||||
cwd = Path(os.getcwd())
|
||||
settings_file = cwd / SETTINGS_FILE
|
||||
if not settings_file.exists():
|
||||
# This should fail if the current env variables don't exist.
|
||||
return Settings() # pyright: ignore[reportCallIssue]
|
||||
with open(settings_file, "rb") as f:
|
||||
settings_data = tomllib.load(f)
|
||||
|
||||
return Settings.model_validate(settings_data)
|
||||
@ -1 +0,0 @@
|
||||
"""Shell interface."""
|
||||
@ -1,55 +0,0 @@
|
||||
"""Admin shell."""
|
||||
|
||||
import os
|
||||
import click
|
||||
from click_repl import register_repl
|
||||
|
||||
from sshecret.api import ClientManagementAPI
|
||||
|
||||
from sshecret import constants
|
||||
from sshecret.password_readers import InputPasswordReader
|
||||
from sshecret.keepass import KeepassManager
|
||||
from sshecret.types import PasswordContext
|
||||
from .shell_client import ShellClient
|
||||
|
||||
DB_PATH = os.path.join(os.getcwd(), "sshecrets.kdbx")
|
||||
|
||||
api_client: ShellClient | None = None
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.pass_context
|
||||
def cli(ctx: click.Context) -> None:
|
||||
"""General CLI."""
|
||||
if api_client is None:
|
||||
raise RuntimeError("No client object defined.")
|
||||
|
||||
@cli.group(name="clients")
|
||||
def cmd_clients() -> None:
|
||||
"""Client context."""
|
||||
|
||||
@cmd_clients.command(name="show")
|
||||
def show_clients() -> None:
|
||||
"""Show clients."""
|
||||
example_set = ["client1", "client2", "client3"]
|
||||
for client in example_set:
|
||||
click.echo(f"- {client}")
|
||||
|
||||
@cmd_clients.command(name="add")
|
||||
@click.argument("name")
|
||||
def add_client(name: str) -> None:
|
||||
"""Add a client."""
|
||||
public_key = click.prompt("Please paste RSA public key")
|
||||
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("--overwrite", is_flag=True, help="Overwrite password database.")
|
||||
def create_database(overwrite: bool) -> None:
|
||||
"""Create database."""
|
||||
context = PasswordContext(InputPasswordReader)
|
||||
KeepassManager.create_database(DB_PATH, context, overwrite)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
api_client = ShellClient("127.0.0.1", KeepassManager)
|
||||
@ -1,14 +0,0 @@
|
||||
"""Shell commands.
|
||||
|
||||
The shell needs to implement the following shell commands:
|
||||
|
||||
- Client management
|
||||
client create/read/update/delete
|
||||
secret create/read/update/delete
|
||||
client permit secret
|
||||
client revoke secret
|
||||
client key rotate
|
||||
|
||||
audit show
|
||||
|
||||
"""
|
||||
@ -1,29 +0,0 @@
|
||||
"""Shell API Client object for auditing."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import override
|
||||
from sshecret.password_readers import InputPasswordReader
|
||||
from sshecret.types import BaseAPIClient, BasePasswordManager, BasePasswordReader, PasswordContext
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ShellClient(BaseAPIClient):
|
||||
"""Client connecting from local host."""
|
||||
|
||||
source: str
|
||||
password_manager_type: type[BasePasswordManager]
|
||||
method: str = field(init=False, default="shell")
|
||||
|
||||
@override
|
||||
def get_reader(self) -> type[BasePasswordReader]:
|
||||
"""Get reader."""
|
||||
return InputPasswordReader
|
||||
|
||||
@override
|
||||
def password_manager(self, manager_options: dict[str, str] | None = None) -> BasePasswordManager:
|
||||
"""Instantiate password manager."""
|
||||
manager_instance = self.password_manager_type()
|
||||
if manager_options:
|
||||
manager_instance.set_manager_options(manager_options)
|
||||
|
||||
return manager_instance
|
||||
@ -1,45 +0,0 @@
|
||||
"""Shell context manager."""
|
||||
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import Iterator, TextIO
|
||||
from click_shell.core import Shell
|
||||
|
||||
from sshecret.api import ManagementApi
|
||||
from sshecret.password_readers import InputPasswordReader
|
||||
from sshecret.types import BaseClientBackend, BasePasswordManager
|
||||
|
||||
from .shell_client import ShellClient
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ShellContext:
|
||||
"""Shell context."""
|
||||
|
||||
api: ManagementApi
|
||||
shell: Shell
|
||||
streams: tuple[TextIO, TextIO] | None = None
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@contextmanager
|
||||
def shell_session(
|
||||
shell: Shell,
|
||||
backend: BaseClientBackend,
|
||||
password_manager: type[BasePasswordManager],
|
||||
source_address: str,
|
||||
manager_options: dict[str, str] | None = None,
|
||||
) -> Iterator[ShellContext]:
|
||||
"""Start a shell session.
|
||||
|
||||
The idea here is to collect the context, store it in an instance variable,
|
||||
and run the shell.
|
||||
"""
|
||||
reader = InputPasswordReader
|
||||
client = ShellClient(source_address, password_manager)
|
||||
api = ManagementApi(backend, client, manager_options)
|
||||
@ -1,96 +0,0 @@
|
||||
"""Testing utilities and classes."""
|
||||
|
||||
from io import StringIO
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from collections.abc import Iterator
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from .utils import create_client_file, generate_password
|
||||
from . import settings as app_settings
|
||||
from .keepass import KeepassManager
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestClientSpec:
|
||||
"""Specification of a test client."""
|
||||
|
||||
name: str
|
||||
secrets: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestContext:
|
||||
"""Test context."""
|
||||
|
||||
path: Path
|
||||
master_password: str
|
||||
|
||||
@property
|
||||
def password_database(self) -> Path:
|
||||
"""Return password database location."""
|
||||
return self.path / "test.kdbx"
|
||||
|
||||
def get_settings(self) -> app_settings.Settings:
|
||||
"""Get settings."""
|
||||
return app_settings.Settings(
|
||||
backend=app_settings.BackendSettings(
|
||||
backend=app_settings.FileBackendSettings(
|
||||
type="Files", location=self.path
|
||||
),
|
||||
),
|
||||
password_manager=app_settings.PasswordManagerSettings(
|
||||
manager=app_settings.KeepassPDBSettings(
|
||||
type="KeePass", location=self.password_database
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def set_environment(context: TestContext) -> None:
|
||||
"""Set environment."""
|
||||
password_path = str(context.password_database)
|
||||
env: list[str] = [
|
||||
f"sshecret_backend__backend_location={str(context.path)}",
|
||||
"sshecret_backend__password_manager__manager_type=KeePass",
|
||||
f"sshecret_backend__password_manager__manager_location={password_path}",
|
||||
]
|
||||
env_str = StringIO("\n".join(env))
|
||||
load_dotenv(stream=env_str)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def test_context(clients: list[TestClientSpec]) -> Iterator[Path]:
|
||||
"""Create a test context."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
dirpath = Path(tmpdir)
|
||||
for client in clients:
|
||||
filename = dirpath / f"{client.name}.json"
|
||||
create_client_file(client.name, filename, client.secrets)
|
||||
|
||||
yield dirpath
|
||||
|
||||
|
||||
@contextmanager
|
||||
def api_context(clients: list[TestClientSpec]) -> Iterator[TestContext]:
|
||||
"""Create a context for testing the full API."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
dirpath = Path(tmpdir)
|
||||
master_password = generate_password()
|
||||
context = TestContext(dirpath, master_password)
|
||||
keepass = KeepassManager.create_database(
|
||||
str(context.password_database), master_password
|
||||
)
|
||||
seen_secrets: list[str] = []
|
||||
for client in clients:
|
||||
filename = dirpath / f"{client.name}.json"
|
||||
create_client_file(client.name, filename, client.secrets)
|
||||
for secret, value in client.secrets.items():
|
||||
if secret in seen_secrets:
|
||||
continue
|
||||
keepass.add_password(secret, value)
|
||||
seen_secrets.append(secret)
|
||||
|
||||
yield context
|
||||
@ -1,179 +0,0 @@
|
||||
"""Interfaces and types."""
|
||||
|
||||
import abc
|
||||
from types import NotImplementedType
|
||||
from typing import Self, overload
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.networks import IPvAnyAddress, IPvAnyNetwork
|
||||
|
||||
|
||||
class BasePasswordReader(abc.ABC):
|
||||
"""Abstract strategy class to read a passwords."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_password(self, identifier: str, repeated: bool = False) -> str:
|
||||
"""Resolve the password, e.g., via input."""
|
||||
|
||||
|
||||
class PasswordContext:
|
||||
"""Context class for resolving a password."""
|
||||
|
||||
def __init__(self, reader: BasePasswordReader) -> None:
|
||||
"""Initialize password context."""
|
||||
self._reader: BasePasswordReader = reader
|
||||
|
||||
@property
|
||||
def reader(self) -> BasePasswordReader:
|
||||
"""Return reader."""
|
||||
return self._reader
|
||||
|
||||
@reader.setter
|
||||
def reader(self, reader: BasePasswordReader) -> None:
|
||||
"""Set the reader instance."""
|
||||
self._reader = reader
|
||||
|
||||
def get_password(self, identifier: str, repeated: bool = False) -> str:
|
||||
"""Get the password."""
|
||||
return self.reader.get_password(identifier, repeated)
|
||||
|
||||
|
||||
class BasePasswordManager(abc.ABC):
|
||||
"""Abstract base class for password managers."""
|
||||
|
||||
master_password_identifier: str
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def create_database(
|
||||
cls,
|
||||
location: str,
|
||||
password_context: PasswordContext | str,
|
||||
overwrite: bool = False,
|
||||
) -> Self:
|
||||
"""Create database.
|
||||
|
||||
Location can be a file, a url or something else.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def open_database(self, password_context: PasswordContext | str) -> None:
|
||||
"""Open database."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def close_database(self) -> None:
|
||||
"""Close database."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_password(self, identifier: str) -> str | None:
|
||||
"""Get a password from the manager."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def generate_password(self, identifier: str) -> str:
|
||||
"""Generate a password using unspecified default rules.
|
||||
|
||||
May be expanded later.
|
||||
|
||||
Returns the generated password.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def add_password(self, identifier: str, password: str) -> None:
|
||||
"""Add a pre-defined password."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_entries(self) -> list[str]:
|
||||
"""Get names of all entries."""
|
||||
|
||||
def set_manager_options(self, options: dict[str, str]) -> None:
|
||||
"""Set manager options."""
|
||||
pass
|
||||
|
||||
@overload
|
||||
def change_password(self, identifier: str, password: None) -> str: ...
|
||||
|
||||
@overload
|
||||
def change_password(self, identifier: str, password: str) -> None: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def change_password(self, identifier: str, password: str | None) -> str | None:
|
||||
"""Change password."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete_password(self, identifier: str) -> None:
|
||||
"""Delete a password."""
|
||||
|
||||
|
||||
class ClientSpecification(BaseModel):
|
||||
"""Specification of client."""
|
||||
|
||||
name: str
|
||||
public_key: str
|
||||
allowed_ips: list[IPvAnyAddress | IPvAnyNetwork] | str = "*"
|
||||
secrets: dict[str, str] = {}
|
||||
testing_private_key: str | None = None # Private key only for testing purposes!
|
||||
|
||||
|
||||
class BaseClientBackend(abc.ABC):
|
||||
"""Base client backend.
|
||||
|
||||
This class is responsible for managing the list of clients and facilitate
|
||||
lookups.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def lookup_name(self, name: str) -> ClientSpecification | None:
|
||||
"""Lookup a client specification by name."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def add_secret(
|
||||
self,
|
||||
client_name: str,
|
||||
secret_name: str,
|
||||
secret_value: str,
|
||||
encrypted: bool = False,
|
||||
) -> None:
|
||||
"""Add a secret to a client."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def add_client(self, spec: ClientSpecification) -> None:
|
||||
"""Add a new client."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_client(self, name: str, spec: ClientSpecification) -> None:
|
||||
"""Update client information."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def remove_client(self, name: str, persistent: bool = True) -> None:
|
||||
"""Delete a client."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_all(self) -> list[ClientSpecification]:
|
||||
"""Get all clients."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def lookup_by_secret(self, secret_name: str) -> list[ClientSpecification]:
|
||||
"""Lookup by the name of a secret."""
|
||||
|
||||
|
||||
class BaseAPIClient(abc.ABC):
|
||||
"""Base API Client."""
|
||||
|
||||
source: str
|
||||
method: str
|
||||
|
||||
@abc.abstractmethod
|
||||
def password_manager(
|
||||
self, manager_options: dict[str, str] | None = None
|
||||
) -> BasePasswordManager:
|
||||
"""Instantiate password manager."""
|
||||
|
||||
def get_reader(self) -> BasePasswordReader:
|
||||
"""Get the reader."""
|
||||
raise NotImplementedError("Class-based password reading not implemented.")
|
||||
|
||||
def get_context(self, reader: BasePasswordReader | None = None) -> PasswordContext:
|
||||
"""Get password context."""
|
||||
if not reader:
|
||||
reader = self.get_reader()
|
||||
return PasswordContext(reader)
|
||||
@ -1,77 +0,0 @@
|
||||
"""Various utilities."""
|
||||
|
||||
import secrets
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from .crypto import (
|
||||
load_client_key,
|
||||
encrypt_string,
|
||||
generate_private_key,
|
||||
generate_pem,
|
||||
generate_public_key_string,
|
||||
)
|
||||
from .types import ClientSpecification
|
||||
|
||||
|
||||
def generate_password() -> str:
|
||||
"""Generate a password."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
def generate_client_object(
|
||||
name: str, secrets: dict[str, str] | None = None, keyfile: str | None = None
|
||||
) -> ClientSpecification:
|
||||
"""Generate a client object."""
|
||||
private_key = generate_private_key()
|
||||
if keyfile:
|
||||
with open(keyfile, "r") as f:
|
||||
contents = f.read()
|
||||
if not contents.startswith("ssh-rsa "):
|
||||
raise RuntimeError("Error: Key must be an RSA key.")
|
||||
|
||||
client = ClientSpecification(name=name, public_key=contents.strip())
|
||||
public_key = load_client_key(client)
|
||||
else:
|
||||
pem = generate_pem(private_key)
|
||||
public_key = private_key.public_key()
|
||||
pubkey_str = generate_public_key_string(public_key)
|
||||
client = ClientSpecification(
|
||||
name=name, public_key=pubkey_str, testing_private_key=pem
|
||||
)
|
||||
if secrets:
|
||||
for secret_name, secret_value in secrets.items():
|
||||
client.secrets[secret_name] = encrypt_string(secret_value, public_key)
|
||||
|
||||
return client
|
||||
|
||||
|
||||
def create_client_file(
|
||||
name: str,
|
||||
filename: Path | str,
|
||||
secrets: dict[str, str] | None = None,
|
||||
keyfile: str | None = None,
|
||||
) -> None:
|
||||
"""Create client file."""
|
||||
client = generate_client_object(name, secrets, keyfile)
|
||||
|
||||
with open(filename, "w") as f:
|
||||
f.write(client.model_dump_json(exclude_none=True, indent=2))
|
||||
f.flush()
|
||||
|
||||
|
||||
def add_secret_to_client_file(
|
||||
filename: str | Path, secret_name: str, secret_value: str
|
||||
) -> None:
|
||||
"""Add secret to client file."""
|
||||
with open(filename, "r") as f:
|
||||
client = ClientSpecification.model_validate_json(f.read())
|
||||
|
||||
public_key = load_client_key(client)
|
||||
encrypted = encrypt_string(secret_value, public_key)
|
||||
client.secrets[secret_name] = encrypted
|
||||
|
||||
with open(filename, "w") as f:
|
||||
json_str = client.model_dump_json(exclude_none=True, indent=2)
|
||||
f.write(json_str)
|
||||
f.flush()
|
||||
@ -1 +0,0 @@
|
||||
|
||||
@ -1,279 +0,0 @@
|
||||
"""WebAPI."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
import secrets
|
||||
import time
|
||||
|
||||
from typing import Annotated
|
||||
from fastapi import Header, HTTPException, Depends, Request, APIRouter
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from sshecret.api import ManagementApi
|
||||
from sshecret.types import (
|
||||
BaseClientBackend,
|
||||
BasePasswordManager,
|
||||
ClientSpecification,
|
||||
)
|
||||
from sshecret.keepass import KeepassManager
|
||||
from sshecret.backends.file_table import FileTableBackend
|
||||
|
||||
from sshecret.settings import Settings, get_settings
|
||||
from sshecret.webapi.api_client import WebManagementAPIClient
|
||||
from . import models
|
||||
|
||||
API_VERSION = "v1"
|
||||
|
||||
admin_router = APIRouter(prefix=f"/api/{API_VERSION}")
|
||||
|
||||
encryption_key = Fernet.generate_key()
|
||||
cipher = Fernet(encryption_key)
|
||||
|
||||
# We store sessions in memory.
|
||||
sessions: dict[str, tuple[str, float]] = {}
|
||||
|
||||
SESSION_TIMEOUT = 600 # 10 minutes
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
session_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def encrypt_session_password(password: str) -> str:
|
||||
"""Encrypts the master password."""
|
||||
return cipher.encrypt(password.encode()).decode()
|
||||
|
||||
|
||||
def decrypt_password(encrypted_password: str) -> str:
|
||||
"""Decrypts the master password asynchronously."""
|
||||
return cipher.decrypt(encrypted_password.encode()).decode()
|
||||
|
||||
|
||||
async def validate_session(session_id: Annotated[str | None, Header()] = None) -> str:
|
||||
"""Middleware to validate session and enforce timeout."""
|
||||
if not session_id:
|
||||
raise HTTPException(status_code=401, detail="Session ID required")
|
||||
|
||||
async with session_lock:
|
||||
if session_id not in sessions:
|
||||
raise HTTPException(status_code=401, detail="Session invalid or expired")
|
||||
|
||||
encrypted_password, last_access_time = sessions[session_id]
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
|
||||
# Check for session timeout
|
||||
if current_time - last_access_time > SESSION_TIMEOUT:
|
||||
del sessions[session_id] # Auto-lock on timeout
|
||||
raise HTTPException(status_code=401, detail="Session expired")
|
||||
|
||||
# Update last access time
|
||||
sessions[session_id] = (encrypted_password, current_time)
|
||||
|
||||
return decrypt_password(encrypted_password)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_app_settings() -> Settings:
|
||||
"""Get app settings."""
|
||||
return get_settings()
|
||||
|
||||
|
||||
def get_password_manager(
|
||||
settings: Annotated[Settings, Depends(get_app_settings)]
|
||||
) -> BasePasswordManager:
|
||||
"""Get password manager."""
|
||||
# Currently only keepass is supported.
|
||||
keepass = KeepassManager()
|
||||
keepass.location = settings.password_manager.location
|
||||
return keepass
|
||||
|
||||
|
||||
async def get_backend(
|
||||
settings: Annotated[Settings, Depends(get_app_settings)]
|
||||
) -> BaseClientBackend:
|
||||
"""Get backend."""
|
||||
location = settings.backend.location
|
||||
filetable = FileTableBackend(location)
|
||||
return filetable
|
||||
|
||||
|
||||
async def get_management_api(
|
||||
request: Request, settings: Annotated[Settings, Depends(get_app_settings)]
|
||||
) -> ManagementApi:
|
||||
"""Get management api."""
|
||||
client_ip = "unknown"
|
||||
if req_client := request.client:
|
||||
client_ip = req_client.host
|
||||
|
||||
api_client = WebManagementAPIClient(client_ip, settings)
|
||||
backend = await get_backend(settings)
|
||||
return ManagementApi(backend, api_client)
|
||||
|
||||
|
||||
BackendDependency = Annotated[BaseClientBackend, Depends(get_backend)]
|
||||
ManagementAPIDependency = Annotated[ManagementApi, Depends(get_management_api)]
|
||||
SessionPasswdDependency = Annotated[str, Depends(validate_session)]
|
||||
|
||||
|
||||
@admin_router.post("/auth/unlock")
|
||||
async def unlock_database(
|
||||
password: models.PasswordBody,
|
||||
password_manager: Annotated[BasePasswordManager, Depends(get_password_manager)],
|
||||
) -> models.SessionResponse:
|
||||
"""Unlock database with master password sent in POST body."""
|
||||
password_str = password.password.get_secret_value()
|
||||
try:
|
||||
password_manager.open_database(password_str)
|
||||
except Exception as e:
|
||||
LOG.debug("Exception: %s", e, exc_info=True)
|
||||
raise HTTPException(status_code=401, detail="Invalid password.")
|
||||
|
||||
session_id = secrets.token_urlsafe(32)
|
||||
sessions[session_id] = (encrypt_session_password(password_str), time.time())
|
||||
|
||||
return models.SessionResponse(session_id=session_id)
|
||||
|
||||
|
||||
@admin_router.post("/auth/lock")
|
||||
async def lock_database(
|
||||
session_id: Annotated[str | None, Header()] = None
|
||||
) -> dict[str, str]:
|
||||
"""Lock database."""
|
||||
if session_id and session_id in sessions:
|
||||
del sessions[session_id]
|
||||
|
||||
return {"message": "LOCKED"}
|
||||
raise HTTPException(400, detail="Missing session ID.")
|
||||
|
||||
|
||||
@admin_router.get("/auth/status")
|
||||
async def get_lock_status(
|
||||
session_id: Annotated[str | None, Header()] = None
|
||||
) -> dict[str, str]:
|
||||
"""Get current lock status."""
|
||||
if session_id and session_id in sessions:
|
||||
return {"message": "UNLOCKED"}
|
||||
return {"message": "LOCKED"}
|
||||
|
||||
|
||||
@admin_router.get("/clients")
|
||||
async def get_clients(admin_api: ManagementAPIDependency) -> list[ClientSpecification]:
|
||||
"""Get clients."""
|
||||
return admin_api.get_clients()
|
||||
|
||||
|
||||
@admin_router.get("/clients/{client_id}")
|
||||
async def get_client(
|
||||
client_id: str, admin_api: ManagementAPIDependency
|
||||
) -> ClientSpecification:
|
||||
"""Get client."""
|
||||
if client_api := admin_api.get_client(client_id):
|
||||
return client_api.client
|
||||
raise HTTPException(status_code=404, detail="Client not found.")
|
||||
|
||||
|
||||
@admin_router.put("/clients/{client_id}")
|
||||
async def update_client(
|
||||
client_id: str,
|
||||
client: ClientSpecification,
|
||||
admin_api: ManagementAPIDependency,
|
||||
master_password: SessionPasswdDependency,
|
||||
) -> ClientSpecification:
|
||||
"""Update client."""
|
||||
client_api = admin_api.get_client(client_id)
|
||||
if not client_api:
|
||||
raise HTTPException(status_code=404, detail="Client not found.")
|
||||
new_client = client_api.update_client(client, password=master_password)
|
||||
return new_client
|
||||
|
||||
|
||||
@admin_router.delete("/clients/{client_id}", status_code=204)
|
||||
async def delete_client(client_id: str, admin_api: ManagementAPIDependency) -> None:
|
||||
"""Delete client."""
|
||||
if admin_api.get_client(client_id):
|
||||
admin_api.delete_client(client_id)
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Client not found.")
|
||||
|
||||
|
||||
@admin_router.post("/clients", status_code=201)
|
||||
async def add_client(
|
||||
client: models.CreateClientModel, admin_api: ManagementAPIDependency
|
||||
) -> ClientSpecification:
|
||||
"""Add client."""
|
||||
new_client = admin_api.create_client(
|
||||
client.name, client.public_key, client.allowed_ips
|
||||
)
|
||||
return new_client.client
|
||||
|
||||
|
||||
@admin_router.get("/secrets")
|
||||
async def list_secrets(
|
||||
admin_api: ManagementAPIDependency, password: SessionPasswdDependency
|
||||
) -> list[models.SecretListResponse]:
|
||||
"""List secrets."""
|
||||
secrets = admin_api.get_secret_names(password=password)
|
||||
return [
|
||||
models.SecretListResponse(name=name, assigned_clients=assigned_clients)
|
||||
for name, assigned_clients in secrets.items()
|
||||
]
|
||||
|
||||
|
||||
@admin_router.post("/secrets")
|
||||
async def add_secret(
|
||||
secret: models.CreateSecretSpecification,
|
||||
password: SessionPasswdDependency,
|
||||
admin_api: ManagementAPIDependency,
|
||||
) -> models.RevealSecretResponse:
|
||||
"""Add secret.
|
||||
|
||||
Will generate a password if none is specified.
|
||||
"""
|
||||
secret_value: str | None = None
|
||||
if secret.secret:
|
||||
secret_value = secret.secret.get_secret_value()
|
||||
result_secret = admin_api.add_secret(secret.name, secret_value, password=password)
|
||||
return models.RevealSecretResponse(name=secret.name, secret=result_secret)
|
||||
|
||||
|
||||
@admin_router.get("/secrets/{name}")
|
||||
async def get_secret(
|
||||
name: str, admin_api: ManagementAPIDependency, password: SessionPasswdDependency
|
||||
) -> models.RevealSecretResponse:
|
||||
"""Get secret."""
|
||||
if secret_value := admin_api.get_secret(name, password=password):
|
||||
return models.RevealSecretResponse(name=name, secret=secret_value)
|
||||
raise HTTPException(status_code=404, detail="Secret not found.")
|
||||
|
||||
|
||||
@admin_router.put("/secrets/{name}")
|
||||
async def update_secret(
|
||||
name: str,
|
||||
spec: models.UpdateSecretSpecification,
|
||||
admin_api: ManagementAPIDependency,
|
||||
password: SessionPasswdDependency,
|
||||
) -> models.MaybeRevalSecretResponse:
|
||||
"""Update secret."""
|
||||
if spec.auto_generate:
|
||||
secret_value = admin_api.regenerate_secret(name, password=password)
|
||||
return models.MaybeRevalSecretResponse(name=name, secret=secret_value)
|
||||
|
||||
if not spec.secret:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Secret value must be specified if auto_generate is False",
|
||||
)
|
||||
admin_api.update_secret(name, spec.secret, password=password)
|
||||
return models.MaybeRevalSecretResponse(name=name, secret=None)
|
||||
|
||||
|
||||
@admin_router.delete("/secrets/{name}", status_code=204)
|
||||
async def delete_secret(
|
||||
name: str,
|
||||
admin_api: ManagementAPIDependency,
|
||||
password: SessionPasswdDependency,
|
||||
) -> None:
|
||||
"""Delete secret."""
|
||||
admin_api.delete_secret(name, password=password)
|
||||
@ -1,25 +0,0 @@
|
||||
"""API Client."""
|
||||
|
||||
from typing import override
|
||||
from sshecret.keepass import KeepassManager
|
||||
from sshecret.types import BaseAPIClient, BasePasswordManager
|
||||
from sshecret.settings import Settings, get_settings
|
||||
|
||||
|
||||
class WebManagementAPIClient(BaseAPIClient):
|
||||
"""Client class for the web management API."""
|
||||
|
||||
method: str = "admin-web-api"
|
||||
|
||||
def __init__(self, source: str, settings: Settings | None = None) -> None:
|
||||
"""Construct client."""
|
||||
if not settings:
|
||||
settings = get_settings()
|
||||
self.source: str = source
|
||||
self._password_manager: BasePasswordManager = KeepassManager()
|
||||
self._password_manager.location = settings.password_manager.manager.location
|
||||
|
||||
@override
|
||||
def password_manager(self, manager_options: dict[str, str] | None = None) -> BasePasswordManager:
|
||||
"""Get password manager."""
|
||||
return self._password_manager
|
||||
@ -1,33 +0,0 @@
|
||||
"""Admin frontend."""
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
|
||||
templates = Jinja2Templates(directory="templates")
|
||||
|
||||
frontend = APIRouter()
|
||||
|
||||
# I'm just making some placeholders here
|
||||
@frontend.get("/")
|
||||
async def index(request: Request) -> HTMLResponse:
|
||||
"""Get frontpage."""
|
||||
return templates.TemplateResponse(request, name="index.html")
|
||||
|
||||
|
||||
@frontend.get("/login")
|
||||
async def login(request: Request) -> HTMLResponse:
|
||||
"""Get login page."""
|
||||
return templates.TemplateResponse(request, name="login.html")
|
||||
|
||||
|
||||
@frontend.get("/clients")
|
||||
async def clients(request: Request) -> HTMLResponse:
|
||||
"""Get login page."""
|
||||
return templates.TemplateResponse(request, name="clients.html")
|
||||
|
||||
@frontend.get("/secrets")
|
||||
async def secrets(request: Request) -> HTMLResponse:
|
||||
"""Get login page."""
|
||||
return templates.TemplateResponse(request, name="secrets.html")
|
||||
@ -1,71 +0,0 @@
|
||||
"""Response models."""
|
||||
|
||||
from pydantic import BaseModel, IPvAnyAddress, IPvAnyNetwork, SecretStr
|
||||
|
||||
|
||||
class SSHKeyResponse(BaseModel):
|
||||
"""Response model for updated SSH keys."""
|
||||
|
||||
updated_secrets: list[str]
|
||||
|
||||
|
||||
class SecretListResponse(BaseModel):
|
||||
"""Response for listing secrets."""
|
||||
|
||||
name: str
|
||||
assigned_clients: list[str]
|
||||
|
||||
|
||||
class CreateSecretSpecification(BaseModel):
|
||||
"""Model for creating a secret."""
|
||||
|
||||
name: str
|
||||
secret: SecretStr | None
|
||||
|
||||
|
||||
class SecretSpecification(BaseModel):
|
||||
"""Secret specification."""
|
||||
|
||||
name: str
|
||||
secret: SecretStr
|
||||
|
||||
|
||||
class UpdateSecretSpecification(BaseModel):
|
||||
"""Model for updating a secret."""
|
||||
|
||||
secret: str | None
|
||||
auto_generate: bool | None = None
|
||||
|
||||
|
||||
class RevealSecretResponse(BaseModel):
|
||||
"""Reveal secret."""
|
||||
|
||||
name: str
|
||||
secret: str
|
||||
|
||||
|
||||
class MaybeRevalSecretResponse(BaseModel):
|
||||
"""Model where the secret may be specified."""
|
||||
|
||||
name: str
|
||||
secret: str | None
|
||||
|
||||
|
||||
class PasswordBody(BaseModel):
|
||||
"""Password body."""
|
||||
|
||||
password: SecretStr
|
||||
|
||||
|
||||
class SessionResponse(BaseModel):
|
||||
"""Session response."""
|
||||
|
||||
session_id: str
|
||||
|
||||
|
||||
class CreateClientModel(BaseModel):
|
||||
"""Model for creating a client."""
|
||||
|
||||
name: str
|
||||
public_key: str
|
||||
allowed_ips: list[IPvAnyAddress | IPvAnyNetwork] | str = "*"
|
||||
@ -1,16 +0,0 @@
|
||||
"""API router."""
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
|
||||
from .api import admin_router
|
||||
from .frontend import frontend
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
app.include_router(admin_router)
|
||||
app.include_router(frontend)
|
||||
|
||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||
Reference in New Issue
Block a user