Clear POC first draft
This commit is contained in:
@ -1,2 +0,0 @@
|
|||||||
def hello() -> str:
|
|
||||||
return "Hello from sshecret!"
|
|
||||||
|
|||||||
@ -1,447 +0,0 @@
|
|||||||
"""API.
|
|
||||||
|
|
||||||
This module is an attempt to create some sort of meaningfull API around the
|
|
||||||
actions exposed here.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import abc
|
|
||||||
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from collections.abc import Iterator
|
|
||||||
|
|
||||||
from pydantic.networks import IPvAnyAddress, IPvAnyNetwork
|
|
||||||
|
|
||||||
from .audit import audit_message
|
|
||||||
|
|
||||||
from .crypto import load_client_key, load_public_key, encrypt_string
|
|
||||||
|
|
||||||
from .types import (
|
|
||||||
BaseAPIClient,
|
|
||||||
BaseClientBackend,
|
|
||||||
BasePasswordManager,
|
|
||||||
BasePasswordReader,
|
|
||||||
ClientSpecification,
|
|
||||||
PasswordContext,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def password_manager_session(
|
|
||||||
password_manager: BasePasswordManager,
|
|
||||||
password_context: PasswordContext | str,
|
|
||||||
api_client: BaseAPIClient,
|
|
||||||
) -> Iterator[BasePasswordManager]:
|
|
||||||
"""Open password manager for read/write in a context."""
|
|
||||||
audit_message(
|
|
||||||
"Opening password manager session",
|
|
||||||
"SECURITY",
|
|
||||||
source_address=api_client.source,
|
|
||||||
)
|
|
||||||
password_manager.open_database(password_context)
|
|
||||||
yield password_manager
|
|
||||||
|
|
||||||
audit_message(
|
|
||||||
"Closing password manager session",
|
|
||||||
"SECURITY",
|
|
||||||
source_address=api_client.source,
|
|
||||||
)
|
|
||||||
password_manager.close_database()
|
|
||||||
|
|
||||||
|
|
||||||
class BaseSshecretAPI(abc.ABC):
|
|
||||||
"""Base API class."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
backend: BaseClientBackend,
|
|
||||||
api_client: BaseAPIClient,
|
|
||||||
manager_options: dict[str, str] | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Initialize API."""
|
|
||||||
|
|
||||||
self.backend: BaseClientBackend = backend
|
|
||||||
self.api_client: BaseAPIClient = api_client
|
|
||||||
self.manager_options: dict[str, str] | None = manager_options
|
|
||||||
|
|
||||||
def _log_audit(
|
|
||||||
self,
|
|
||||||
message: str,
|
|
||||||
audit_type: str,
|
|
||||||
client_name: str | None = None,
|
|
||||||
**details: str,
|
|
||||||
) -> None:
|
|
||||||
"""Log an audit message."""
|
|
||||||
audit_message(
|
|
||||||
message,
|
|
||||||
audit_type,
|
|
||||||
client_name,
|
|
||||||
source_address=self.api_client.source,
|
|
||||||
**details,
|
|
||||||
)
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def password_session(
|
|
||||||
self, reader: BasePasswordReader | None = None, password: str | None = None
|
|
||||||
) -> Iterator[BasePasswordManager]:
|
|
||||||
"""Open a password session."""
|
|
||||||
if password:
|
|
||||||
context = password
|
|
||||||
else:
|
|
||||||
if not reader:
|
|
||||||
reader = self.api_client.get_reader()
|
|
||||||
context = self.api_client.get_context(reader)
|
|
||||||
|
|
||||||
password_manager = self.api_client.password_manager(self.manager_options)
|
|
||||||
with password_manager_session(
|
|
||||||
password_manager, context, self.api_client
|
|
||||||
) as session:
|
|
||||||
yield session
|
|
||||||
|
|
||||||
|
|
||||||
class ClientManagementAPI(BaseSshecretAPI):
|
|
||||||
"""API for managing clients."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
backend: BaseClientBackend,
|
|
||||||
client: ClientSpecification,
|
|
||||||
api_client: BaseAPIClient,
|
|
||||||
manager_options: dict[str, str] | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Create client management API instance."""
|
|
||||||
super().__init__(backend, api_client, manager_options)
|
|
||||||
self.client: ClientSpecification = client
|
|
||||||
self.__password_manager: BasePasswordManager | None = None
|
|
||||||
|
|
||||||
def log_security(self, message: str, **details: str) -> None:
|
|
||||||
"""Log a security related message."""
|
|
||||||
self._log_audit(message, "SECURITY", self.client.name, **details)
|
|
||||||
|
|
||||||
def log_info(self, message: str) -> None:
|
|
||||||
"""Log an informational message."""
|
|
||||||
self._log_audit(message, "INFORMATIONAL", self.client.name)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def password_manager(self) -> BasePasswordManager:
|
|
||||||
"""Get password manager."""
|
|
||||||
if self.__password_manager:
|
|
||||||
self.log_security("Accessed password manager")
|
|
||||||
return self.__password_manager
|
|
||||||
raise RuntimeError("Password manager not initialized.")
|
|
||||||
|
|
||||||
@password_manager.setter
|
|
||||||
def password_manager(self, instance: BasePasswordManager) -> None:
|
|
||||||
"""Set password manager instance."""
|
|
||||||
self.log_security("Opened password manager.")
|
|
||||||
self.__password_manager = instance
|
|
||||||
|
|
||||||
def get_secrets(self) -> list[str]:
|
|
||||||
"""Get names of the secrets that the client has access to.."""
|
|
||||||
self.log_security("Listing secret names.")
|
|
||||||
return list(self.client.secrets.keys())
|
|
||||||
|
|
||||||
def update_client(
|
|
||||||
self,
|
|
||||||
client: ClientSpecification,
|
|
||||||
reader: BasePasswordReader | None = None,
|
|
||||||
password: str | None = None,
|
|
||||||
) -> ClientSpecification:
|
|
||||||
"""Update client."""
|
|
||||||
self.log_info("Updating client")
|
|
||||||
update_data = client.model_dump(exclude_unset=True)
|
|
||||||
if client.public_key != self.client.public_key:
|
|
||||||
self.log_security("Client public key has changed.")
|
|
||||||
if client.secrets != self.client.secrets:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Error: Cannot update public key and secrets in the same operation."
|
|
||||||
)
|
|
||||||
del update_data["secrets"]
|
|
||||||
secrets = self._re_encrypt(client.public_key, reader, password)
|
|
||||||
update_data["secrets"] = secrets
|
|
||||||
|
|
||||||
updated_client = self.client.model_copy(update=update_data)
|
|
||||||
self.backend.update_client(self.client.name, updated_client)
|
|
||||||
self.client = updated_client
|
|
||||||
return updated_client
|
|
||||||
|
|
||||||
def update_secret(self, name: str, password: str) -> None:
|
|
||||||
"""Update a secret.
|
|
||||||
|
|
||||||
If secret is not already a part of the client, it will be added.
|
|
||||||
"""
|
|
||||||
if name in self.client.secrets:
|
|
||||||
self.log_security("Updating secret", secret_name=name)
|
|
||||||
else:
|
|
||||||
self.log_security("Adding secret", secret_name=name)
|
|
||||||
public_key = load_client_key(self.client)
|
|
||||||
encrypted = encrypt_string(password, public_key)
|
|
||||||
client_secrets = {**self.client.secrets, name: encrypted}
|
|
||||||
updated_client = self.client.model_copy(update={"secrets": client_secrets})
|
|
||||||
self.backend.update_client(self.client.name, updated_client)
|
|
||||||
self.client = updated_client
|
|
||||||
|
|
||||||
def _re_encrypt(
|
|
||||||
self,
|
|
||||||
new_key: str,
|
|
||||||
reader: BasePasswordReader | None = None,
|
|
||||||
password: str | None = None,
|
|
||||||
) -> dict[str, str]:
|
|
||||||
"""Update public key on given client."""
|
|
||||||
audit_message(
|
|
||||||
"Updating public key",
|
|
||||||
"INFORMATIONAL",
|
|
||||||
self.client.name,
|
|
||||||
source_address=self.api_client.source,
|
|
||||||
)
|
|
||||||
if password:
|
|
||||||
context = password
|
|
||||||
else:
|
|
||||||
if not reader:
|
|
||||||
reader = self.api_client.get_reader()
|
|
||||||
context = self.api_client.get_context(reader)
|
|
||||||
|
|
||||||
password_manager = self.api_client.password_manager(self.manager_options)
|
|
||||||
self.password_manager = password_manager
|
|
||||||
|
|
||||||
client_key = load_public_key(new_key.encode())
|
|
||||||
secrets: dict[str, str] = {}
|
|
||||||
with password_manager_session(
|
|
||||||
password_manager, context, self.api_client
|
|
||||||
) as password_session:
|
|
||||||
for name in self.get_secrets():
|
|
||||||
secret = password_session.get_password(name)
|
|
||||||
audit_message(
|
|
||||||
"Updating encrypted value",
|
|
||||||
"SECURITY",
|
|
||||||
self.client.name,
|
|
||||||
secret_name=name,
|
|
||||||
source_address=self.api_client.source,
|
|
||||||
)
|
|
||||||
if not secret:
|
|
||||||
raise RuntimeError("Could not fetch a new secret value.")
|
|
||||||
new_value = encrypt_string(secret, client_key)
|
|
||||||
secrets[name] = new_value
|
|
||||||
|
|
||||||
return secrets
|
|
||||||
|
|
||||||
def update_public_key(
|
|
||||||
self,
|
|
||||||
new_key: str,
|
|
||||||
reader: BasePasswordReader | None = None,
|
|
||||||
password: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Update the public key."""
|
|
||||||
audit_message(
|
|
||||||
"Updating public key",
|
|
||||||
"INFORMATIONAL",
|
|
||||||
self.client.name,
|
|
||||||
source_address=self.api_client.source,
|
|
||||||
)
|
|
||||||
secrets = self._re_encrypt(new_key, reader, password)
|
|
||||||
|
|
||||||
updated = self.client.model_copy(update={"secrets": secrets})
|
|
||||||
self.backend.update_client(self.client.name, updated)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_client(
|
|
||||||
cls,
|
|
||||||
backend: BaseClientBackend,
|
|
||||||
name: str,
|
|
||||||
api_client: BaseAPIClient,
|
|
||||||
) -> "ClientManagementAPI | None":
|
|
||||||
"""Get client."""
|
|
||||||
client = backend.lookup_name(name)
|
|
||||||
if not client:
|
|
||||||
return None
|
|
||||||
return cls(backend, client, api_client)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create_client(
|
|
||||||
cls,
|
|
||||||
backend: BaseClientBackend,
|
|
||||||
name: str,
|
|
||||||
api_client: BaseAPIClient,
|
|
||||||
public_key: str,
|
|
||||||
allowed_ips: list[IPvAnyAddress | IPvAnyNetwork] | str = "*",
|
|
||||||
) -> "ClientManagementAPI":
|
|
||||||
"""Create a client."""
|
|
||||||
client = ClientSpecification(
|
|
||||||
name=name, public_key=public_key, allowed_ips=allowed_ips
|
|
||||||
)
|
|
||||||
backend.add_client(client)
|
|
||||||
return cls(backend, client, api_client)
|
|
||||||
|
|
||||||
|
|
||||||
class ManagementApi(BaseSshecretAPI):
|
|
||||||
"""Api for general management."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
backend: BaseClientBackend,
|
|
||||||
api_client: BaseAPIClient,
|
|
||||||
manager_options: dict[str, str] | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Initialize API."""
|
|
||||||
super().__init__(backend, api_client, manager_options)
|
|
||||||
|
|
||||||
def log_security(
|
|
||||||
self, message: str, client_name: str | None = None, **details: str
|
|
||||||
) -> None:
|
|
||||||
"""Log a security related message."""
|
|
||||||
self._log_audit(message, "SECURITY", client_name, **details)
|
|
||||||
|
|
||||||
def log_info(self, message: str, client_name: str | None = None) -> None:
|
|
||||||
"""Log an informational message."""
|
|
||||||
self._log_audit(message, "INFORMATIONAL", client_name)
|
|
||||||
|
|
||||||
def get_client(self, name: str) -> ClientManagementAPI | None:
|
|
||||||
"""Get a client."""
|
|
||||||
client = self.backend.lookup_name(name)
|
|
||||||
if not client:
|
|
||||||
return None
|
|
||||||
return ClientManagementAPI(
|
|
||||||
self.backend, client, self.api_client, self.manager_options
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_clients(self) -> list[ClientSpecification]:
|
|
||||||
"""Get clients."""
|
|
||||||
self.log_info("Fetched all clients")
|
|
||||||
return self.backend.get_all()
|
|
||||||
|
|
||||||
def _get_clients(self) -> list[ClientSpecification]:
|
|
||||||
"""Get clients."""
|
|
||||||
return self.backend.get_all()
|
|
||||||
|
|
||||||
def delete_client(self, name: str) -> None:
|
|
||||||
"""Delete client."""
|
|
||||||
client = self.backend.lookup_name(name)
|
|
||||||
if not client:
|
|
||||||
self.log_info("Attempted to delete a non-existing client.", name)
|
|
||||||
return
|
|
||||||
self.log_security("Deleting client", name)
|
|
||||||
self.backend.remove_client(name)
|
|
||||||
|
|
||||||
def create_client(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
public_key: str,
|
|
||||||
allowed_ips: list[IPvAnyAddress | IPvAnyNetwork] | str = "*",
|
|
||||||
) -> ClientManagementAPI:
|
|
||||||
"""Create a client."""
|
|
||||||
self.log_info("Creating new client", name)
|
|
||||||
client = ClientSpecification(
|
|
||||||
name=name, public_key=public_key, allowed_ips=allowed_ips
|
|
||||||
)
|
|
||||||
self.backend.add_client(client)
|
|
||||||
return ClientManagementAPI(self.backend, client, self.api_client)
|
|
||||||
|
|
||||||
def get_secret_names(
|
|
||||||
self, reader: BasePasswordReader | None = None, password: str | None = None
|
|
||||||
) -> dict[str, list[str]]:
|
|
||||||
"""Get secret names and which clients have these.."""
|
|
||||||
self.log_security("Listing all secret names.")
|
|
||||||
with self.password_session(reader=reader, password=password) as session:
|
|
||||||
secret_names = session.get_entries()
|
|
||||||
|
|
||||||
secret_mapping: dict[str, list[str]] = {}
|
|
||||||
for name in secret_names:
|
|
||||||
secret_mapping[name] = [
|
|
||||||
client.name for client in self.backend.lookup_by_secret(name)
|
|
||||||
]
|
|
||||||
|
|
||||||
return secret_mapping
|
|
||||||
|
|
||||||
def add_secret(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
secret_value: str | None,
|
|
||||||
reader: BasePasswordReader | None = None,
|
|
||||||
password: str | None = None,
|
|
||||||
) -> str:
|
|
||||||
"""Add a secret."""
|
|
||||||
self.log_security("Adding new secret", secret_name=name)
|
|
||||||
with self.password_session(reader=reader, password=password) as session:
|
|
||||||
if not secret_value:
|
|
||||||
self.log_security("Auto-generating a secret value", secret_name=name)
|
|
||||||
secret_value = session.generate_password(name)
|
|
||||||
else:
|
|
||||||
session.add_password(name, secret_value)
|
|
||||||
return secret_value
|
|
||||||
|
|
||||||
def get_secret(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
reader: BasePasswordReader | None = None,
|
|
||||||
password: str | None = None,
|
|
||||||
) -> str | None:
|
|
||||||
"""Get the clear-text value of a secret."""
|
|
||||||
self.log_security("Client requested secret value", secret_name=name)
|
|
||||||
with self.password_session(reader=reader, password=password) as session:
|
|
||||||
secret = session.get_password(name)
|
|
||||||
|
|
||||||
return secret
|
|
||||||
|
|
||||||
def update_secret(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
new_value: str,
|
|
||||||
reader: BasePasswordReader | None = None,
|
|
||||||
password: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Update a secret with a given name."""
|
|
||||||
self.log_security("Changing secret", secret_name=name)
|
|
||||||
with self.password_session(reader=reader, password=password) as session:
|
|
||||||
session.change_password(name, new_value)
|
|
||||||
|
|
||||||
clients = self.backend.lookup_by_secret(name)
|
|
||||||
for client in clients:
|
|
||||||
client_api = self.get_client(client.name)
|
|
||||||
if not client_api:
|
|
||||||
continue
|
|
||||||
client_api.update_secret(name, new_value)
|
|
||||||
|
|
||||||
def regenerate_secret(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
reader: BasePasswordReader | None = None,
|
|
||||||
password: str | None = None,
|
|
||||||
) -> str:
|
|
||||||
"""Regenerate a secret."""
|
|
||||||
self.log_security("Generating a new secret value", secret_name=name)
|
|
||||||
with self.password_session(reader=reader, password=password) as session:
|
|
||||||
new_value = session.change_password(name, None)
|
|
||||||
|
|
||||||
clients = self.backend.lookup_by_secret(name)
|
|
||||||
for client in clients:
|
|
||||||
client_api = self.get_client(client.name)
|
|
||||||
if not client_api:
|
|
||||||
continue
|
|
||||||
client_api.update_secret(name, new_value)
|
|
||||||
|
|
||||||
return new_value
|
|
||||||
|
|
||||||
def delete_secret(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
reader: BasePasswordReader | None = None,
|
|
||||||
password: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Delete secret."""
|
|
||||||
clients = self.backend.lookup_by_secret(name)
|
|
||||||
self.log_security("Deleting secret", secret_name=name)
|
|
||||||
with self.password_session(reader=reader, password=password) as session:
|
|
||||||
session.delete_password(name)
|
|
||||||
|
|
||||||
for client in clients:
|
|
||||||
secrets = {**client.secrets}
|
|
||||||
del secrets[name]
|
|
||||||
new_client = client.model_copy(update={"secrets": secrets})
|
|
||||||
client_api = self.get_client(client.name)
|
|
||||||
if not client_api:
|
|
||||||
continue
|
|
||||||
self.log_security(
|
|
||||||
"Removing secret from client.",
|
|
||||||
client_name=client.name,
|
|
||||||
secret_name=name,
|
|
||||||
)
|
|
||||||
client_api.update_client(new_client, password=password)
|
|
||||||
@ -1,66 +0,0 @@
|
|||||||
"""Audit setup."""
|
|
||||||
|
|
||||||
import enum
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
from pythonjsonlogger.json import JsonFormatter
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
|
||||||
|
|
||||||
from .constants import AUDIT_LOG_NAME
|
|
||||||
|
|
||||||
AUDIT_LOG = logging.getLogger(AUDIT_LOG_NAME)
|
|
||||||
|
|
||||||
|
|
||||||
class AuditMessageType(enum.StrEnum):
|
|
||||||
"""Audit Message Type."""
|
|
||||||
|
|
||||||
ACCESS = enum.auto() # Someone accessed something
|
|
||||||
SECURITY = enum.auto() # A message related to security
|
|
||||||
INFORMATIONAL = enum.auto() # other informational messages
|
|
||||||
|
|
||||||
|
|
||||||
class AuditMessage(BaseModel):
|
|
||||||
"""Audit message."""
|
|
||||||
|
|
||||||
model_config = ConfigDict(use_enum_values=True)
|
|
||||||
|
|
||||||
type: AuditMessageType
|
|
||||||
message: str
|
|
||||||
client_name: str | None = None
|
|
||||||
source_address: str | None = None
|
|
||||||
secret_name: str | None = None
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
"""Stringify object as JSON."""
|
|
||||||
return self.model_dump_json()
|
|
||||||
|
|
||||||
|
|
||||||
def audit_message(
|
|
||||||
message: str,
|
|
||||||
audit_type: AuditMessageType | str | None = None,
|
|
||||||
client_name: str | None = None,
|
|
||||||
secret_name: str | None = None,
|
|
||||||
source_address: str | None = None,
|
|
||||||
**details: str
|
|
||||||
) -> None:
|
|
||||||
"""Create an audit message."""
|
|
||||||
if not audit_type:
|
|
||||||
audit_type = AuditMessageType.INFORMATIONAL
|
|
||||||
|
|
||||||
if audit_type not in list(AuditMessageType):
|
|
||||||
audit_type = AuditMessageType.INFORMATIONAL
|
|
||||||
|
|
||||||
audit_message = AuditMessage(
|
|
||||||
type=audit_type,
|
|
||||||
message=message,
|
|
||||||
client_name=client_name,
|
|
||||||
source_address=source_address,
|
|
||||||
secret_name=secret_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
audit_dict = audit_message.model_dump(exclude_none=True)
|
|
||||||
|
|
||||||
AUDIT_LOG.info({**audit_dict, **details})
|
|
||||||
@ -1,5 +0,0 @@
|
|||||||
"""Backend implementations"""
|
|
||||||
|
|
||||||
from .file_table import FileTableBackend
|
|
||||||
|
|
||||||
__all__ = ["FileTableBackend"]
|
|
||||||
@ -1,133 +0,0 @@
|
|||||||
"""File table based backend."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import override
|
|
||||||
|
|
||||||
import littletable as lt
|
|
||||||
|
|
||||||
from sshecret.crypto import load_client_key, encrypt_string
|
|
||||||
from sshecret.types import ClientSpecification
|
|
||||||
from sshecret.types import BaseClientBackend
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def load_clients_from_dir(directory: Path) -> dict[Path, ClientSpecification]:
|
|
||||||
"""Load clients from a directory."""
|
|
||||||
if not directory.exists() or not directory.is_dir():
|
|
||||||
raise ValueError("Invalid directory specified.")
|
|
||||||
|
|
||||||
clients: dict[Path, ClientSpecification] = {}
|
|
||||||
for client_file in directory.glob("*.json"):
|
|
||||||
with open(client_file, "r") as f:
|
|
||||||
client = ClientSpecification.model_validate_json(f.read())
|
|
||||||
if client_file.name != f"{client.name}.json":
|
|
||||||
raise RuntimeError(
|
|
||||||
"Filename scheme of clients does not conform to expected format. Aborting import!"
|
|
||||||
)
|
|
||||||
clients[client_file] = client
|
|
||||||
|
|
||||||
return clients
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FileTableBackend(BaseClientBackend):
|
|
||||||
"""In-memory littletable based backend."""
|
|
||||||
|
|
||||||
def __init__(self, directory: Path) -> None:
|
|
||||||
"""Create backend instance."""
|
|
||||||
LOG.debug("Creating in-memory table to hold clients.")
|
|
||||||
self._directory: Path = directory
|
|
||||||
self.table: lt.Table[ClientSpecification] = lt.Table()
|
|
||||||
self._setup_table()
|
|
||||||
client_files = load_clients_from_dir(directory)
|
|
||||||
client_count = len(client_files)
|
|
||||||
LOG.debug("Loaded %s clients from disk.", client_count)
|
|
||||||
# self.client_file_map: dict[str, Path] = {client.name: filepath for filepath, client in client_files.items()}
|
|
||||||
LOG.debug("Inserting clients into table.")
|
|
||||||
self.table.insert_many(list(client_files.values()))
|
|
||||||
|
|
||||||
def _setup_table(self) -> None:
|
|
||||||
"""Set up the table."""
|
|
||||||
self.table.create_index("name", unique=True)
|
|
||||||
|
|
||||||
@override
|
|
||||||
def lookup_name(self, name: str) -> ClientSpecification | None:
|
|
||||||
"""Lookup client by name."""
|
|
||||||
if result := self.table.by.name.get(name):
|
|
||||||
if isinstance(result, ClientSpecification):
|
|
||||||
return result
|
|
||||||
return None
|
|
||||||
|
|
||||||
@override
|
|
||||||
def add_client(self, spec: ClientSpecification) -> None:
|
|
||||||
"""Add client."""
|
|
||||||
self.table.insert(spec)
|
|
||||||
self._write_spec_file(spec)
|
|
||||||
|
|
||||||
def _write_spec_file(self, spec: ClientSpecification) -> None:
|
|
||||||
"""Write spec file to disk."""
|
|
||||||
dest_file_name = f"{spec.name}.json"
|
|
||||||
dest_file = self._directory / dest_file_name
|
|
||||||
with open(dest_file.absolute(), "w") as f:
|
|
||||||
f.write(
|
|
||||||
spec.model_dump_json(exclude_none=True, exclude_unset=True, indent=2)
|
|
||||||
)
|
|
||||||
f.flush()
|
|
||||||
|
|
||||||
@override
|
|
||||||
def add_secret(
|
|
||||||
self,
|
|
||||||
client_name: str,
|
|
||||||
secret_name: str,
|
|
||||||
secret_value: str,
|
|
||||||
encrypted: bool = False,
|
|
||||||
) -> None:
|
|
||||||
"""Add secret."""
|
|
||||||
client: ClientSpecification = self.table.by.name[client_name] # pyright: ignore[reportAssignmentType]
|
|
||||||
if not encrypted:
|
|
||||||
public_key = load_client_key(client)
|
|
||||||
secret_value = encrypt_string(secret_value, public_key)
|
|
||||||
client.secrets[secret_name] = secret_value
|
|
||||||
self._update_client_data(client)
|
|
||||||
self._write_spec_file(client)
|
|
||||||
|
|
||||||
@override
|
|
||||||
def remove_client(self, name: str, persistent: bool = True) -> None:
|
|
||||||
"""Delete client."""
|
|
||||||
client = self.lookup_name(name)
|
|
||||||
if not client:
|
|
||||||
raise ValueError("Client does not exist!")
|
|
||||||
self.table.remove(client)
|
|
||||||
if persistent:
|
|
||||||
filename = f"{client.name}.json"
|
|
||||||
filepath = self._directory / filename
|
|
||||||
filepath.unlink()
|
|
||||||
|
|
||||||
@override
|
|
||||||
def update_client(self, name: str, spec: ClientSpecification) -> None:
|
|
||||||
"""Update client."""
|
|
||||||
if not self.lookup_name(name):
|
|
||||||
raise ValueError("Client does not exist!")
|
|
||||||
self._update_client_data(spec)
|
|
||||||
self._write_spec_file(spec)
|
|
||||||
|
|
||||||
def _update_client_data(self, spec: ClientSpecification) -> None:
|
|
||||||
"""Update client data."""
|
|
||||||
existing = self.lookup_name(spec.name)
|
|
||||||
if existing:
|
|
||||||
self.table.remove(existing)
|
|
||||||
self.add_client(spec)
|
|
||||||
|
|
||||||
@override
|
|
||||||
def get_all(self) -> list[ClientSpecification]:
|
|
||||||
"""Get all clients."""
|
|
||||||
return list(self.table)
|
|
||||||
|
|
||||||
@override
|
|
||||||
def lookup_by_secret(self, secret_name: str) -> list[ClientSpecification]:
|
|
||||||
"""Lookup by secret name."""
|
|
||||||
results = self.table.where(lambda client: secret_name in client.secrets)
|
|
||||||
return list(results)
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
"""Command Line Interface"""
|
|
||||||
|
|
||||||
import click
|
|
||||||
@ -1,25 +0,0 @@
|
|||||||
"""Client code"""
|
|
||||||
|
|
||||||
from typing import TextIO
|
|
||||||
import click
|
|
||||||
|
|
||||||
from sshecret.crypto import decode_string, load_private_key
|
|
||||||
|
|
||||||
|
|
||||||
def decrypt_secret(encoded: str, client_key: str) -> str:
|
|
||||||
"""Decrypt secret."""
|
|
||||||
private_key = load_private_key(client_key)
|
|
||||||
return decode_string(encoded, private_key)
|
|
||||||
|
|
||||||
|
|
||||||
@click.command()
|
|
||||||
@click.argument("keyfile", type=click.Path(exists=True, readable=True, dir_okay=False))
|
|
||||||
@click.argument("encrypted_input", type=click.File("r"))
|
|
||||||
def cli_decrypt(keyfile: str, encrypted_input: TextIO) -> None:
|
|
||||||
"""Decrypt on command line."""
|
|
||||||
decrypted = decrypt_secret(encrypted_input.read(), keyfile)
|
|
||||||
click.echo(decrypted)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
cli_decrypt()
|
|
||||||
@ -1,18 +0,0 @@
|
|||||||
"""Config file."""
|
|
||||||
from pathlib import Path
|
|
||||||
from pydantic import SecretStr
|
|
||||||
from pydantic_settings import BaseSettings
|
|
||||||
|
|
||||||
|
|
||||||
class KeepassSettings(BaseSettings):
|
|
||||||
"""Settings for Keepasss password database."""
|
|
||||||
|
|
||||||
database_path: Path
|
|
||||||
|
|
||||||
|
|
||||||
class SshecretSettings(BaseSettings):
|
|
||||||
"""Settings model."""
|
|
||||||
|
|
||||||
admin_password: SecretStr
|
|
||||||
admin_ssh_key: str | None = None
|
|
||||||
keepass: KeepassSettings
|
|
||||||
@ -1,17 +0,0 @@
|
|||||||
"""Constants."""
|
|
||||||
|
|
||||||
MASTER_PASSWORD = "MASTER_PASSWORD"
|
|
||||||
NO_USERNAME = "NO_USERNAME"
|
|
||||||
VAR_PREFIX = "SSHECRET"
|
|
||||||
|
|
||||||
ERROR_NO_SECRET_FOUND = "Error: No secret available with the given name."
|
|
||||||
ERROR_UKNOWN_CLIENT_OR_SECRET = "Error: Invalid client or secret name."
|
|
||||||
ERROR_NO_COMMAND_RECEIVED = "Error: No command was received from the client."
|
|
||||||
ERROR_SOURCE_IP_NOT_ALLOWED = (
|
|
||||||
"Error: Client not authorized to connect from the given host."
|
|
||||||
)
|
|
||||||
|
|
||||||
RSA_PUBLIC_EXPONENT = 65537
|
|
||||||
RSA_KEY_SIZE = 2048
|
|
||||||
|
|
||||||
AUDIT_LOG_NAME = "AUDIT"
|
|
||||||
@ -1,106 +0,0 @@
|
|||||||
"""Encryption related functions.
|
|
||||||
|
|
||||||
Note! Encryption uses the less secure PKCS1v15 padding. This is to allow
|
|
||||||
decryption via openssl on the command line.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
import base64
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
from cryptography.hazmat.primitives import serialization
|
|
||||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
|
||||||
from cryptography.hazmat.primitives.asymmetric import padding
|
|
||||||
|
|
||||||
from .types import ClientSpecification
|
|
||||||
from . import constants
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def load_client_key(client: ClientSpecification) -> rsa.RSAPublicKey:
|
|
||||||
"""Load public key."""
|
|
||||||
keybytes = client.public_key.encode()
|
|
||||||
return load_public_key(keybytes)
|
|
||||||
|
|
||||||
|
|
||||||
def load_public_key(keybytes: bytes) -> rsa.RSAPublicKey:
|
|
||||||
public_key = serialization.load_ssh_public_key(keybytes)
|
|
||||||
if not isinstance(public_key, rsa.RSAPublicKey):
|
|
||||||
raise RuntimeError("Only RSA keys are supported.")
|
|
||||||
pem_public_key = public_key.public_bytes(encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo)
|
|
||||||
LOG.info("pem:\n%s", pem_public_key)
|
|
||||||
return public_key
|
|
||||||
|
|
||||||
|
|
||||||
def load_private_key(filename: str) -> rsa.RSAPrivateKey:
|
|
||||||
"""Load a private key."""
|
|
||||||
with open(filename, "rb") as f:
|
|
||||||
private_key = serialization.load_ssh_private_key(f.read(), password=None)
|
|
||||||
if not isinstance(private_key, rsa.RSAPrivateKey):
|
|
||||||
raise RuntimeError("Only RSA keys are supported.")
|
|
||||||
return private_key
|
|
||||||
|
|
||||||
|
|
||||||
def encrypt_string(string: str, public_key: rsa.RSAPublicKey) -> str:
|
|
||||||
"""Encrypt string, end return it base64 encoded."""
|
|
||||||
message = string.encode()
|
|
||||||
ciphertext = public_key.encrypt(
|
|
||||||
message,
|
|
||||||
padding.PKCS1v15(),
|
|
||||||
)
|
|
||||||
return base64.b64encode(ciphertext).decode()
|
|
||||||
|
|
||||||
|
|
||||||
def decode_string(ciphertext: str, private_key: rsa.RSAPrivateKey) -> str:
|
|
||||||
"""Decode a string. String must be base64 encoded."""
|
|
||||||
decoded = base64.b64decode(ciphertext)
|
|
||||||
decrypted = private_key.decrypt(
|
|
||||||
decoded,
|
|
||||||
padding.PKCS1v15(),
|
|
||||||
)
|
|
||||||
return decrypted.decode()
|
|
||||||
|
|
||||||
|
|
||||||
def generate_private_key() -> rsa.RSAPrivateKey:
|
|
||||||
"""Generate private RSA key."""
|
|
||||||
private_key = rsa.generate_private_key(
|
|
||||||
public_exponent=constants.RSA_PUBLIC_EXPONENT, key_size=constants.RSA_KEY_SIZE
|
|
||||||
)
|
|
||||||
return private_key
|
|
||||||
|
|
||||||
|
|
||||||
def generate_pem(private_key: rsa.RSAPrivateKey) -> str:
|
|
||||||
"""Generate PEM."""
|
|
||||||
pem = private_key.private_bytes(
|
|
||||||
encoding=serialization.Encoding.PEM,
|
|
||||||
format=serialization.PrivateFormat.OpenSSH,
|
|
||||||
encryption_algorithm=serialization.NoEncryption(),
|
|
||||||
)
|
|
||||||
return pem.decode()
|
|
||||||
|
|
||||||
|
|
||||||
def create_private_rsa_key(filename: Path) -> None:
|
|
||||||
"""Create an RSA Private key at the given path."""
|
|
||||||
if filename.exists():
|
|
||||||
raise RuntimeError("Error: private key file already exists.")
|
|
||||||
LOG.debug("Generating private RSA key at %s", filename)
|
|
||||||
private_key = generate_private_key()
|
|
||||||
with open(filename, "wb") as f:
|
|
||||||
pem = private_key.private_bytes(
|
|
||||||
encoding=serialization.Encoding.PEM,
|
|
||||||
format=serialization.PrivateFormat.OpenSSH,
|
|
||||||
encryption_algorithm=serialization.NoEncryption(),
|
|
||||||
)
|
|
||||||
lines = f.write(pem)
|
|
||||||
LOG.debug("Wrote %s lines", lines)
|
|
||||||
f.flush()
|
|
||||||
|
|
||||||
|
|
||||||
def generate_public_key_string(public_key: rsa.RSAPublicKey) -> str:
|
|
||||||
"""Generate public key string."""
|
|
||||||
keybytes = public_key.public_bytes(
|
|
||||||
encoding=serialization.Encoding.OpenSSH,
|
|
||||||
format=serialization.PublicFormat.OpenSSH,
|
|
||||||
)
|
|
||||||
return keybytes.decode()
|
|
||||||
@ -1,93 +0,0 @@
|
|||||||
"""Development CLI commands."""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import asyncio
|
|
||||||
import asyncssh
|
|
||||||
import click
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import tempfile
|
|
||||||
import threading
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from pythonjsonlogger.json import JsonFormatter
|
|
||||||
|
|
||||||
from .server import start_server
|
|
||||||
from sshecret.backends import FileTableBackend
|
|
||||||
from .utils import create_client_file, add_secret_to_client_file
|
|
||||||
from .constants import AUDIT_LOG_NAME
|
|
||||||
|
|
||||||
|
|
||||||
def thread_id_filter(record: logging.LogRecord) -> logging.LogRecord:
|
|
||||||
"""Resolve thread id."""
|
|
||||||
record.thread_id = threading.get_native_id()
|
|
||||||
return record
|
|
||||||
|
|
||||||
|
|
||||||
LOG = logging.getLogger()
|
|
||||||
handler = logging.StreamHandler()
|
|
||||||
handler.addFilter(thread_id_filter)
|
|
||||||
formatter = logging.Formatter(
|
|
||||||
"%(thread_id)d:%(created)f:%(levelname)s:%(name)s:%(module)s:%(message)s"
|
|
||||||
)
|
|
||||||
handler.setFormatter(formatter)
|
|
||||||
LOG.addHandler(handler)
|
|
||||||
LOG.setLevel(logging.DEBUG)
|
|
||||||
|
|
||||||
AUDIT_LOG = logging.getLogger(AUDIT_LOG_NAME)
|
|
||||||
audit_formatter = JsonFormatter()
|
|
||||||
audit_handler = logging.StreamHandler()
|
|
||||||
audit_handler.setFormatter(audit_formatter)
|
|
||||||
AUDIT_LOG.addHandler(audit_handler)
|
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
|
||||||
def cli() -> None:
|
|
||||||
"""Run commands for testing."""
|
|
||||||
|
|
||||||
|
|
||||||
@cli.command("create-client")
|
|
||||||
@click.argument("name")
|
|
||||||
@click.argument(
|
|
||||||
"filename", type=click.Path(file_okay=True, dir_okay=False, writable=True)
|
|
||||||
)
|
|
||||||
@click.option("--public-key", type=click.Path(file_okay=True))
|
|
||||||
def create_client(name: str, filename: str, public_key: str | None) -> None:
|
|
||||||
"""Create a client."""
|
|
||||||
create_client_file(name, filename, keyfile=public_key)
|
|
||||||
click.echo(f"Wrote client config to {filename}")
|
|
||||||
|
|
||||||
|
|
||||||
@cli.command("add-secret")
|
|
||||||
@click.argument(
|
|
||||||
"filename", type=click.Path(file_okay=True, dir_okay=False, writable=True)
|
|
||||||
)
|
|
||||||
@click.argument("secret-name")
|
|
||||||
@click.argument("secret-value")
|
|
||||||
def add_secret(filename: str, secret_name: str, secret_value: str) -> None:
|
|
||||||
"""Add secret to client file."""
|
|
||||||
add_secret_to_client_file(filename, secret_name, secret_value)
|
|
||||||
click.echo(f"Wrote secret to {filename}")
|
|
||||||
|
|
||||||
|
|
||||||
@cli.command("server")
|
|
||||||
@click.argument("directory", type=click.Path(file_okay=False, dir_okay=True))
|
|
||||||
@click.argument("port", type=click.INT)
|
|
||||||
def run_async_server(directory: str, port: int) -> None:
|
|
||||||
"""Run async server."""
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
|
||||||
serverdir = Path(tmpdir)
|
|
||||||
host_key = str(serverdir / "hostkey")
|
|
||||||
clientdir = Path(directory)
|
|
||||||
backend = FileTableBackend(clientdir)
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(start_server(port, backend, host_key, True))
|
|
||||||
except (OSError, asyncssh.Error) as exc:
|
|
||||||
click.echo(f"Error starting server: {exc}")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
loop.run_forever()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
cli()
|
|
||||||
@ -1,180 +0,0 @@
|
|||||||
"""Keepass integration."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import cast, final, overload, override, Self
|
|
||||||
|
|
||||||
import pykeepass
|
|
||||||
from . import constants
|
|
||||||
from .types import BasePasswordManager, PasswordContext
|
|
||||||
from .utils import generate_password
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@final
|
|
||||||
class KeepassManager(BasePasswordManager):
|
|
||||||
"""KeepassXC compatible password manager."""
|
|
||||||
|
|
||||||
master_password_identifier = constants.MASTER_PASSWORD
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
"""Initialize password manager."""
|
|
||||||
self._location: Path | None = None
|
|
||||||
self._keepass: pykeepass.PyKeePass | None = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def location(self) -> Path:
|
|
||||||
"""Get location."""
|
|
||||||
if not self._location:
|
|
||||||
raise RuntimeError("No location has been specified.")
|
|
||||||
return self._location
|
|
||||||
|
|
||||||
@location.setter
|
|
||||||
def location(self, location: Path) -> None:
|
|
||||||
"""Set location."""
|
|
||||||
if not location.exists() or not location.is_file():
|
|
||||||
raise RuntimeError("Unable to read provided password file.")
|
|
||||||
self._location = location
|
|
||||||
|
|
||||||
@override
|
|
||||||
def set_manager_options(self, options: dict[str, str]) -> None:
|
|
||||||
"""Set manager options."""
|
|
||||||
if "location" in options:
|
|
||||||
location = Path(str(options["location"]))
|
|
||||||
self.location = location
|
|
||||||
|
|
||||||
@property
|
|
||||||
def keepass(self) -> pykeepass.PyKeePass:
|
|
||||||
"""Return keepass instance."""
|
|
||||||
if self._keepass:
|
|
||||||
return self._keepass
|
|
||||||
raise RuntimeError("Error: Database has not been opened.")
|
|
||||||
|
|
||||||
@keepass.setter
|
|
||||||
def keepass(self, instance: pykeepass.PyKeePass) -> None:
|
|
||||||
"""Set the keepass instance."""
|
|
||||||
self._keepass = instance
|
|
||||||
|
|
||||||
@override
|
|
||||||
def get_entries(self) -> list[str]:
|
|
||||||
"""Get all entries."""
|
|
||||||
entries = self.keepass.entries
|
|
||||||
if not entries:
|
|
||||||
return []
|
|
||||||
return [
|
|
||||||
str(entry.title) for entry in entries
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@override
|
|
||||||
@classmethod
|
|
||||||
def create_database(
|
|
||||||
cls, location: str, password_context: PasswordContext | str, overwrite: bool = False
|
|
||||||
) -> Self:
|
|
||||||
"""Create database."""
|
|
||||||
if Path(location).exists() and not overwrite:
|
|
||||||
raise RuntimeError("Error: Database exists.")
|
|
||||||
|
|
||||||
if isinstance(password_context, PasswordContext):
|
|
||||||
master_password = password_context.get_password(cls.master_password_identifier, True)
|
|
||||||
else:
|
|
||||||
master_password = password_context
|
|
||||||
|
|
||||||
# TODO: should we delete if overwrite is set?
|
|
||||||
keepass = pykeepass.create_database(location, password=master_password)
|
|
||||||
instance = cls()
|
|
||||||
instance.set_manager_options({"location": str(location)})
|
|
||||||
instance.keepass = keepass
|
|
||||||
return instance
|
|
||||||
|
|
||||||
@override
|
|
||||||
def open_database(self, password_context: PasswordContext | str) -> None:
|
|
||||||
"""Open the database"""
|
|
||||||
if isinstance(password_context, PasswordContext):
|
|
||||||
password = password_context.get_password(self.master_password_identifier)
|
|
||||||
else:
|
|
||||||
password = password_context
|
|
||||||
instance = pykeepass.PyKeePass(str(self.location.absolute()), password=password)
|
|
||||||
self.keepass = instance
|
|
||||||
|
|
||||||
@override
|
|
||||||
def close_database(self) -> None:
|
|
||||||
"""Close the database."""
|
|
||||||
self._keepass = None
|
|
||||||
|
|
||||||
@override
|
|
||||||
def get_password(self, identifier: str) -> str | None:
|
|
||||||
"""Get password."""
|
|
||||||
entry = cast("pykeepass.entry.Entry | None", self.keepass.find_entries(title=identifier, first=True))
|
|
||||||
if not entry:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if password := cast(str, entry.password):
|
|
||||||
return str(password)
|
|
||||||
raise RuntimeError(f"Cannot get password for entry {identifier}")
|
|
||||||
|
|
||||||
|
|
||||||
@override
|
|
||||||
def generate_password(self, identifier: str) -> str:
|
|
||||||
"""Generate password."""
|
|
||||||
# Generate a password.
|
|
||||||
password = generate_password()
|
|
||||||
_entry = self.keepass.add_entry(
|
|
||||||
self.keepass.root_group, identifier, constants.NO_USERNAME, password
|
|
||||||
)
|
|
||||||
self.keepass.save()
|
|
||||||
LOG.debug("Created Entry %r", _entry)
|
|
||||||
return password
|
|
||||||
|
|
||||||
@override
|
|
||||||
def add_password(self, identifier: str, password: str) -> None:
|
|
||||||
"""Add a password."""
|
|
||||||
entry = cast("pykeepass.entry.Entry | None", self.keepass.find_entries(title=identifier, first=True))
|
|
||||||
if not entry:
|
|
||||||
_entry = self.keepass.add_entry(self.keepass.root_group, identifier, constants.NO_USERNAME, password)
|
|
||||||
self.keepass.save()
|
|
||||||
LOG.debug("Created entry %r", _entry)
|
|
||||||
return
|
|
||||||
self.change_password(identifier, password)
|
|
||||||
LOG.debug("Updated password on entry %r", entry)
|
|
||||||
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def change_password(self, identifier: str, password: None) -> str: ...
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def change_password(self, identifier: str, password: str) -> None: ...
|
|
||||||
|
|
||||||
@override
|
|
||||||
def change_password(self, identifier: str, password: str | None) -> str | None:
|
|
||||||
"""Change a password."""
|
|
||||||
generated_password = False
|
|
||||||
if not password:
|
|
||||||
password = generate_password()
|
|
||||||
generated_password = True
|
|
||||||
|
|
||||||
entry = cast("pykeepass.entry.Entry | None", self.keepass.find_entries(title=identifier, first=True))
|
|
||||||
|
|
||||||
if not entry:
|
|
||||||
raise ValueError("Error: Entry not found!")
|
|
||||||
|
|
||||||
entry.password = password
|
|
||||||
|
|
||||||
self.keepass.save()
|
|
||||||
if generated_password:
|
|
||||||
return password
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
@override
|
|
||||||
def delete_password(self, identifier: str) -> None:
|
|
||||||
"""Delete password."""
|
|
||||||
entry = cast("pykeepass.entry.Entry | None", self.keepass.find_entries(title=identifier, first=True))
|
|
||||||
if not entry:
|
|
||||||
return
|
|
||||||
LOG.info("Deleting entry %s for keepass.", entry.uuid)
|
|
||||||
|
|
||||||
self.keepass.delete_entry(entry)
|
|
||||||
|
|
||||||
self.keepass.save()
|
|
||||||
@ -1,61 +0,0 @@
|
|||||||
"""Password reader classes.
|
|
||||||
|
|
||||||
This implements two interfaces to read passwords:
|
|
||||||
|
|
||||||
InputPasswordReader and EnvironmentPasswordReader.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import re
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from typing import TextIO, override
|
|
||||||
import click
|
|
||||||
|
|
||||||
from .types import BasePasswordReader
|
|
||||||
from . import constants
|
|
||||||
|
|
||||||
RE_VARNAME = re.compile(r"^[a-zA-Z_]+[a-zA-Z0-9_]*$")
|
|
||||||
|
|
||||||
|
|
||||||
class InputPasswordReader(BasePasswordReader):
|
|
||||||
"""Read a password from stdin."""
|
|
||||||
|
|
||||||
@override
|
|
||||||
def get_password(self, identifier: str, repeated: bool = False) -> str:
|
|
||||||
"""Get password."""
|
|
||||||
if password := click.prompt(
|
|
||||||
f"Enter password for {identifier}", hide_input=True, type=str, confirmation_prompt=repeated
|
|
||||||
):
|
|
||||||
return str(password)
|
|
||||||
raise ValueError("No password received.")
|
|
||||||
|
|
||||||
|
|
||||||
class EnvironmentPasswordReader(BasePasswordReader):
|
|
||||||
"""Read a password from the environment.
|
|
||||||
|
|
||||||
The environemnt variable will be constructured based on the identifier and the prefix.
|
|
||||||
Final environemnt variable will be validated according to the regex `[a-zA-Z_]+[a-zA-Z0-9_]*`
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _resolve_var_name(self, identifier: str) -> str:
|
|
||||||
"""Resolve variable name."""
|
|
||||||
identifier = identifier.replace("-", "_")
|
|
||||||
fields = [constants.VAR_PREFIX, identifier]
|
|
||||||
varname = "_".join(fields)
|
|
||||||
if not RE_VARNAME.fullmatch(varname):
|
|
||||||
raise ValueError(
|
|
||||||
f"Cannot generate encode password identifier in variable name. {varname} is not a valid identifier."
|
|
||||||
)
|
|
||||||
return varname
|
|
||||||
|
|
||||||
def get_password_from_env(self, identifier: str) -> str:
|
|
||||||
"""Get password from environment."""
|
|
||||||
varname = self._resolve_var_name(identifier)
|
|
||||||
if password := os.getenv(varname, None):
|
|
||||||
return password
|
|
||||||
raise ValueError(f"Error: No variable named {varname} resolved.")
|
|
||||||
|
|
||||||
@override
|
|
||||||
def get_password(self, identifier: str, repeated: bool = False) -> str:
|
|
||||||
"""Get password."""
|
|
||||||
return self.get_password_from_env(identifier)
|
|
||||||
@ -1,5 +0,0 @@
|
|||||||
"""Sshecret server module."""
|
|
||||||
|
|
||||||
from .async_server import AsshyncServer, start_server
|
|
||||||
|
|
||||||
__all__ = ["AsshyncServer", "start_server"]
|
|
||||||
@ -1,143 +0,0 @@
|
|||||||
"""Server implemented with asyncssh."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from functools import partial
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import override
|
|
||||||
|
|
||||||
import asyncssh
|
|
||||||
|
|
||||||
from sshecret import constants
|
|
||||||
from sshecret.audit import audit_message
|
|
||||||
from sshecret.types import ClientSpecification, BaseClientBackend
|
|
||||||
from sshecret.crypto import create_private_rsa_key
|
|
||||||
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def handle_client(process: asyncssh.SSHServerProcess[str]) -> None:
|
|
||||||
"""Handle client."""
|
|
||||||
remote_ip = process.get_extra_info("peername")[0]
|
|
||||||
client_found = process.get_extra_info("client_allowed", False)
|
|
||||||
if not client_found:
|
|
||||||
process.stderr.write(constants.ERROR_UKNOWN_CLIENT_OR_SECRET + "\n")
|
|
||||||
audit_message("Unknown connection", source_address=remote_ip)
|
|
||||||
process.exit(1)
|
|
||||||
return
|
|
||||||
|
|
||||||
client_allowed = process.get_extra_info("client_allowed", False)
|
|
||||||
if not client_allowed:
|
|
||||||
audit_message("Not permitted", "SECURITY", source_address=remote_ip)
|
|
||||||
process.stderr.write(constants.ERROR_SOURCE_IP_NOT_ALLOWED + "\n")
|
|
||||||
process.exit(1)
|
|
||||||
return
|
|
||||||
|
|
||||||
client = process.get_extra_info("client")
|
|
||||||
if not client:
|
|
||||||
audit_message("Unknown client", source_address=remote_ip)
|
|
||||||
process.stderr.write(constants.ERROR_UKNOWN_CLIENT_OR_SECRET + "\n")
|
|
||||||
process.exit(1)
|
|
||||||
return
|
|
||||||
|
|
||||||
secret_name = process.command
|
|
||||||
if not secret_name:
|
|
||||||
audit_message("No secret specified", source_address=remote_ip, client_name=client.name)
|
|
||||||
process.stderr.write(constants.ERROR_UKNOWN_CLIENT_OR_SECRET + "\n")
|
|
||||||
process.exit(1)
|
|
||||||
return
|
|
||||||
|
|
||||||
LOG.debug(
|
|
||||||
"Client %s successfully connected. Fetching secret %s", client.name, secret_name
|
|
||||||
)
|
|
||||||
|
|
||||||
audit_message(f"Requested secret", client_name=client.name, secret_name=secret_name, source_address=remote_ip)
|
|
||||||
secret = client.secrets.get(secret_name)
|
|
||||||
if not secret:
|
|
||||||
process.stderr.write(constants.ERROR_UKNOWN_CLIENT_OR_SECRET + "\n")
|
|
||||||
process.exit(1)
|
|
||||||
return
|
|
||||||
|
|
||||||
audit_message("Accessed secret", client.name, secret_name, source_address=remote_ip)
|
|
||||||
process.stdout.write(secret)
|
|
||||||
process.exit(0)
|
|
||||||
|
|
||||||
|
|
||||||
class AsshyncServer(asyncssh.SSHServer):
|
|
||||||
"""Asynchronous SSH server implementation."""
|
|
||||||
|
|
||||||
def __init__(self, backend: BaseClientBackend) -> None:
|
|
||||||
"""Initialize server."""
|
|
||||||
self.backend: BaseClientBackend = backend
|
|
||||||
self._conn: asyncssh.SSHServerConnection | None = None
|
|
||||||
|
|
||||||
@override
|
|
||||||
def connection_made(self, conn: asyncssh.SSHServerConnection) -> None:
|
|
||||||
"""Handle incoming connection."""
|
|
||||||
peername = conn.get_extra_info("peername")
|
|
||||||
LOG.debug("Connection established from %r", peername)
|
|
||||||
self._conn = conn
|
|
||||||
|
|
||||||
@override
|
|
||||||
def begin_auth(self, username: str) -> bool:
|
|
||||||
"""Begin authentication."""
|
|
||||||
if not self._conn:
|
|
||||||
return True
|
|
||||||
client = self.backend.lookup_name(username)
|
|
||||||
if not client:
|
|
||||||
return True
|
|
||||||
self._conn.set_extra_info(client_found=True)
|
|
||||||
remote_ip = self._conn.get_extra_info("peername")[0]
|
|
||||||
LOG.debug("Remote_IP: %r", remote_ip)
|
|
||||||
assert isinstance(remote_ip, str)
|
|
||||||
if self.check_connection_allowed(client, remote_ip):
|
|
||||||
audit_message("Authentication requested", "ACCESS", client_name=client.name, source_address=remote_ip)
|
|
||||||
self._conn.set_extra_info(client_allowed=True)
|
|
||||||
self._conn.set_extra_info(client=client)
|
|
||||||
|
|
||||||
# Load the key.
|
|
||||||
public_key = asyncssh.import_authorized_keys(client.public_key)
|
|
||||||
self._conn.set_authorized_keys(public_key)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
@override
|
|
||||||
def password_auth_supported(self) -> bool:
|
|
||||||
"""Deny password authentication."""
|
|
||||||
return False
|
|
||||||
|
|
||||||
def check_connection_allowed(
|
|
||||||
self, client: ClientSpecification, source: str
|
|
||||||
) -> bool:
|
|
||||||
"""Check if client is allowed to request secrets."""
|
|
||||||
LOG.debug("Checking if client is allowed to log in from %s", source)
|
|
||||||
if isinstance(client.allowed_ips, str) and client.allowed_ips == "*":
|
|
||||||
audit_message("Permitting login", "SECURITY", client_name=client.name, source_address=source)
|
|
||||||
LOG.debug("Client has no restrictions on source IP address. Permitting.")
|
|
||||||
return True
|
|
||||||
if isinstance(client.allowed_ips, str) and "/" not in client.allowed_ips:
|
|
||||||
if source == client.allowed_ips:
|
|
||||||
audit_message("Permitting login", "SECURITY", client_name=client.name, source_address=source)
|
|
||||||
LOG.debug("Client IP matches permitted address")
|
|
||||||
return True
|
|
||||||
|
|
||||||
LOG.warning(
|
|
||||||
"Connection for client %s received from IP address %s that is not permitted.",
|
|
||||||
client.name,
|
|
||||||
source,
|
|
||||||
)
|
|
||||||
audit_message("REJECTED. Invalid address", "SECURITY", client_name=client.name, source_address=source)
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
async def start_server(
|
|
||||||
port: int, backend: BaseClientBackend, host_key: str, create_key: bool = False
|
|
||||||
) -> None:
|
|
||||||
"""Start server."""
|
|
||||||
server = partial(AsshyncServer, backend=backend)
|
|
||||||
if create_key:
|
|
||||||
create_private_rsa_key(Path(host_key))
|
|
||||||
await asyncssh.create_server(
|
|
||||||
server, "", port, server_host_keys=[host_key], process_factory=handle_client
|
|
||||||
)
|
|
||||||
@ -1,21 +0,0 @@
|
|||||||
"""Server errors."""
|
|
||||||
|
|
||||||
|
|
||||||
class BaseSshecretServerError(Exception):
|
|
||||||
"""Base Sshecret Server Error."""
|
|
||||||
|
|
||||||
|
|
||||||
class UnknownClientError(BaseSshecretServerError):
|
|
||||||
"""Client was not recognized."""
|
|
||||||
|
|
||||||
|
|
||||||
class AccessDeniedError(BaseSshecretServerError):
|
|
||||||
"""Client was not authorized to access the resource."""
|
|
||||||
|
|
||||||
|
|
||||||
class AccessPolicyViolationError(BaseSshecretServerError):
|
|
||||||
"""Client was not authorized to access the secret."""
|
|
||||||
|
|
||||||
|
|
||||||
class UnknownSecretError(BaseSshecretServerError):
|
|
||||||
"""Error when resolving the secret."""
|
|
||||||
@ -1,36 +0,0 @@
|
|||||||
"""Password reader for use with the SSH server."""
|
|
||||||
|
|
||||||
from typing import override, TextIO
|
|
||||||
import asyncssh
|
|
||||||
from sshecret.types import BasePasswordReader
|
|
||||||
|
|
||||||
|
|
||||||
class SSHPasswordReader(BasePasswordReader):
|
|
||||||
"""SSH Password reader."""
|
|
||||||
|
|
||||||
def __init__(self, channel: asyncssh.SSHLineEditorChannel, stdin: asyncssh.SSHReader[str], stdout: asyncssh.SSHWriter[str]) -> None:
|
|
||||||
"""Initialize password reader."""
|
|
||||||
self.channel: asyncssh.SSHLineEditorChannel = channel
|
|
||||||
self.stdin: asyncssh.SSHReader[str] = stdin
|
|
||||||
self.stdout: asyncssh.SSHWriter[str] = stdout
|
|
||||||
|
|
||||||
@override
|
|
||||||
def get_password(self, identifier: str, repeated: bool = False) -> str:
|
|
||||||
"""Get password."""
|
|
||||||
raise RuntimeError("Use get_password_async!")
|
|
||||||
|
|
||||||
async def get_password_async(self, identifier: str, repeated: bool = False) -> str:
|
|
||||||
"""Get password async."""
|
|
||||||
self.stdout.write(f"Enter password for {identifier}: ")
|
|
||||||
self.channel.set_echo(False)
|
|
||||||
while True:
|
|
||||||
password = await self.stdin.readline()
|
|
||||||
if not repeated:
|
|
||||||
break
|
|
||||||
self.stdout.write(f"\nRe-enter password for {identifier}: ")
|
|
||||||
password2 = await self.stdin.readline()
|
|
||||||
if password == password2:
|
|
||||||
break
|
|
||||||
self.stdout.write(f"Passwords did not match. Try again.\n")
|
|
||||||
self.channel.set_echo(True)
|
|
||||||
return password.strip()
|
|
||||||
@ -1,82 +0,0 @@
|
|||||||
"""Get settings."""
|
|
||||||
|
|
||||||
import abc
|
|
||||||
import enum
|
|
||||||
import os
|
|
||||||
import tomllib
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Literal
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
from pydantic import BaseModel, DirectoryPath, Field, FilePath
|
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
||||||
|
|
||||||
from sshecret.keepass import KeepassManager
|
|
||||||
|
|
||||||
SETTINGS_FILE = "sshecret.toml"
|
|
||||||
|
|
||||||
|
|
||||||
class Backend(enum.StrEnum):
|
|
||||||
"""Supported backends."""
|
|
||||||
|
|
||||||
FILES = "FILES"
|
|
||||||
|
|
||||||
|
|
||||||
class PasswordManager(enum.StrEnum):
|
|
||||||
"""Supported password managers."""
|
|
||||||
|
|
||||||
KEEPASS = "KeePass"
|
|
||||||
|
|
||||||
|
|
||||||
class SSHServerSettings(BaseModel):
|
|
||||||
"""SSH Server settings."""
|
|
||||||
|
|
||||||
port: int = 22
|
|
||||||
private_key: FilePath | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class AdminApiSettings(BaseModel):
|
|
||||||
"""Admin API settings."""
|
|
||||||
|
|
||||||
port: int = 8022
|
|
||||||
|
|
||||||
|
|
||||||
class FileBackendSettings(BaseModel):
|
|
||||||
"""File backend settings.
|
|
||||||
|
|
||||||
This will eventually have the Discriminator pattern described in pydantic.
|
|
||||||
"""
|
|
||||||
|
|
||||||
type: Literal["Files"]
|
|
||||||
|
|
||||||
location: DirectoryPath
|
|
||||||
|
|
||||||
|
|
||||||
class KeepassPDBSettings(BaseModel):
|
|
||||||
"""Keepass backend settings."""
|
|
||||||
|
|
||||||
type: Literal["KeePass"]
|
|
||||||
location: FilePath
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
|
||||||
"""Sshecret settings."""
|
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_prefix="sshecret_", env_nested_delimiter="__")
|
|
||||||
|
|
||||||
backend: FileBackendSettings
|
|
||||||
password_manager: KeepassPDBSettings
|
|
||||||
admin_api: AdminApiSettings = Field(default_factory=AdminApiSettings)
|
|
||||||
ssh_server: SSHServerSettings = Field(default_factory=SSHServerSettings)
|
|
||||||
|
|
||||||
|
|
||||||
def get_settings() -> Settings:
|
|
||||||
"""Get settings."""
|
|
||||||
cwd = Path(os.getcwd())
|
|
||||||
settings_file = cwd / SETTINGS_FILE
|
|
||||||
if not settings_file.exists():
|
|
||||||
# This should fail if the current env variables don't exist.
|
|
||||||
return Settings() # pyright: ignore[reportCallIssue]
|
|
||||||
with open(settings_file, "rb") as f:
|
|
||||||
settings_data = tomllib.load(f)
|
|
||||||
|
|
||||||
return Settings.model_validate(settings_data)
|
|
||||||
@ -1 +0,0 @@
|
|||||||
"""Shell interface."""
|
|
||||||
@ -1,55 +0,0 @@
|
|||||||
"""Admin shell."""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import click
|
|
||||||
from click_repl import register_repl
|
|
||||||
|
|
||||||
from sshecret.api import ClientManagementAPI
|
|
||||||
|
|
||||||
from sshecret import constants
|
|
||||||
from sshecret.password_readers import InputPasswordReader
|
|
||||||
from sshecret.keepass import KeepassManager
|
|
||||||
from sshecret.types import PasswordContext
|
|
||||||
from .shell_client import ShellClient
|
|
||||||
|
|
||||||
DB_PATH = os.path.join(os.getcwd(), "sshecrets.kdbx")
|
|
||||||
|
|
||||||
api_client: ShellClient | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
|
||||||
@click.pass_context
|
|
||||||
def cli(ctx: click.Context) -> None:
|
|
||||||
"""General CLI."""
|
|
||||||
if api_client is None:
|
|
||||||
raise RuntimeError("No client object defined.")
|
|
||||||
|
|
||||||
@cli.group(name="clients")
|
|
||||||
def cmd_clients() -> None:
|
|
||||||
"""Client context."""
|
|
||||||
|
|
||||||
@cmd_clients.command(name="show")
|
|
||||||
def show_clients() -> None:
|
|
||||||
"""Show clients."""
|
|
||||||
example_set = ["client1", "client2", "client3"]
|
|
||||||
for client in example_set:
|
|
||||||
click.echo(f"- {client}")
|
|
||||||
|
|
||||||
@cmd_clients.command(name="add")
|
|
||||||
@click.argument("name")
|
|
||||||
def add_client(name: str) -> None:
|
|
||||||
"""Add a client."""
|
|
||||||
public_key = click.prompt("Please paste RSA public key")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
|
||||||
@click.option("--overwrite", is_flag=True, help="Overwrite password database.")
|
|
||||||
def create_database(overwrite: bool) -> None:
|
|
||||||
"""Create database."""
|
|
||||||
context = PasswordContext(InputPasswordReader)
|
|
||||||
KeepassManager.create_database(DB_PATH, context, overwrite)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
api_client = ShellClient("127.0.0.1", KeepassManager)
|
|
||||||
@ -1,14 +0,0 @@
|
|||||||
"""Shell commands.
|
|
||||||
|
|
||||||
The shell needs to implement the following shell commands:
|
|
||||||
|
|
||||||
- Client management
|
|
||||||
client create/read/update/delete
|
|
||||||
secret create/read/update/delete
|
|
||||||
client permit secret
|
|
||||||
client revoke secret
|
|
||||||
client key rotate
|
|
||||||
|
|
||||||
audit show
|
|
||||||
|
|
||||||
"""
|
|
||||||
@ -1,29 +0,0 @@
|
|||||||
"""Shell API Client object for auditing."""
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import override
|
|
||||||
from sshecret.password_readers import InputPasswordReader
|
|
||||||
from sshecret.types import BaseAPIClient, BasePasswordManager, BasePasswordReader, PasswordContext
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class ShellClient(BaseAPIClient):
|
|
||||||
"""Client connecting from local host."""
|
|
||||||
|
|
||||||
source: str
|
|
||||||
password_manager_type: type[BasePasswordManager]
|
|
||||||
method: str = field(init=False, default="shell")
|
|
||||||
|
|
||||||
@override
|
|
||||||
def get_reader(self) -> type[BasePasswordReader]:
|
|
||||||
"""Get reader."""
|
|
||||||
return InputPasswordReader
|
|
||||||
|
|
||||||
@override
|
|
||||||
def password_manager(self, manager_options: dict[str, str] | None = None) -> BasePasswordManager:
|
|
||||||
"""Instantiate password manager."""
|
|
||||||
manager_instance = self.password_manager_type()
|
|
||||||
if manager_options:
|
|
||||||
manager_instance.set_manager_options(manager_options)
|
|
||||||
|
|
||||||
return manager_instance
|
|
||||||
@ -1,45 +0,0 @@
|
|||||||
"""Shell context manager."""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from contextvars import ContextVar
|
|
||||||
from typing import Iterator, TextIO
|
|
||||||
from click_shell.core import Shell
|
|
||||||
|
|
||||||
from sshecret.api import ManagementApi
|
|
||||||
from sshecret.password_readers import InputPasswordReader
|
|
||||||
from sshecret.types import BaseClientBackend, BasePasswordManager
|
|
||||||
|
|
||||||
from .shell_client import ShellClient
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class ShellContext:
|
|
||||||
"""Shell context."""
|
|
||||||
|
|
||||||
api: ManagementApi
|
|
||||||
shell: Shell
|
|
||||||
streams: tuple[TextIO, TextIO] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def shell_session(
|
|
||||||
shell: Shell,
|
|
||||||
backend: BaseClientBackend,
|
|
||||||
password_manager: type[BasePasswordManager],
|
|
||||||
source_address: str,
|
|
||||||
manager_options: dict[str, str] | None = None,
|
|
||||||
) -> Iterator[ShellContext]:
|
|
||||||
"""Start a shell session.
|
|
||||||
|
|
||||||
The idea here is to collect the context, store it in an instance variable,
|
|
||||||
and run the shell.
|
|
||||||
"""
|
|
||||||
reader = InputPasswordReader
|
|
||||||
client = ShellClient(source_address, password_manager)
|
|
||||||
api = ManagementApi(backend, client, manager_options)
|
|
||||||
@ -1,96 +0,0 @@
|
|||||||
"""Testing utilities and classes."""
|
|
||||||
|
|
||||||
from io import StringIO
|
|
||||||
import tempfile
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from pathlib import Path
|
|
||||||
from collections.abc import Iterator
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
from .utils import create_client_file, generate_password
|
|
||||||
from . import settings as app_settings
|
|
||||||
from .keepass import KeepassManager
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TestClientSpec:
|
|
||||||
"""Specification of a test client."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
secrets: dict[str, str] = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TestContext:
|
|
||||||
"""Test context."""
|
|
||||||
|
|
||||||
path: Path
|
|
||||||
master_password: str
|
|
||||||
|
|
||||||
@property
|
|
||||||
def password_database(self) -> Path:
|
|
||||||
"""Return password database location."""
|
|
||||||
return self.path / "test.kdbx"
|
|
||||||
|
|
||||||
def get_settings(self) -> app_settings.Settings:
|
|
||||||
"""Get settings."""
|
|
||||||
return app_settings.Settings(
|
|
||||||
backend=app_settings.BackendSettings(
|
|
||||||
backend=app_settings.FileBackendSettings(
|
|
||||||
type="Files", location=self.path
|
|
||||||
),
|
|
||||||
),
|
|
||||||
password_manager=app_settings.PasswordManagerSettings(
|
|
||||||
manager=app_settings.KeepassPDBSettings(
|
|
||||||
type="KeePass", location=self.password_database
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def set_environment(context: TestContext) -> None:
|
|
||||||
"""Set environment."""
|
|
||||||
password_path = str(context.password_database)
|
|
||||||
env: list[str] = [
|
|
||||||
f"sshecret_backend__backend_location={str(context.path)}",
|
|
||||||
"sshecret_backend__password_manager__manager_type=KeePass",
|
|
||||||
f"sshecret_backend__password_manager__manager_location={password_path}",
|
|
||||||
]
|
|
||||||
env_str = StringIO("\n".join(env))
|
|
||||||
load_dotenv(stream=env_str)
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def test_context(clients: list[TestClientSpec]) -> Iterator[Path]:
|
|
||||||
"""Create a test context."""
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
|
||||||
dirpath = Path(tmpdir)
|
|
||||||
for client in clients:
|
|
||||||
filename = dirpath / f"{client.name}.json"
|
|
||||||
create_client_file(client.name, filename, client.secrets)
|
|
||||||
|
|
||||||
yield dirpath
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def api_context(clients: list[TestClientSpec]) -> Iterator[TestContext]:
|
|
||||||
"""Create a context for testing the full API."""
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
|
||||||
dirpath = Path(tmpdir)
|
|
||||||
master_password = generate_password()
|
|
||||||
context = TestContext(dirpath, master_password)
|
|
||||||
keepass = KeepassManager.create_database(
|
|
||||||
str(context.password_database), master_password
|
|
||||||
)
|
|
||||||
seen_secrets: list[str] = []
|
|
||||||
for client in clients:
|
|
||||||
filename = dirpath / f"{client.name}.json"
|
|
||||||
create_client_file(client.name, filename, client.secrets)
|
|
||||||
for secret, value in client.secrets.items():
|
|
||||||
if secret in seen_secrets:
|
|
||||||
continue
|
|
||||||
keepass.add_password(secret, value)
|
|
||||||
seen_secrets.append(secret)
|
|
||||||
|
|
||||||
yield context
|
|
||||||
@ -1,179 +0,0 @@
|
|||||||
"""Interfaces and types."""
|
|
||||||
|
|
||||||
import abc
|
|
||||||
from types import NotImplementedType
|
|
||||||
from typing import Self, overload
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from pydantic.networks import IPvAnyAddress, IPvAnyNetwork
|
|
||||||
|
|
||||||
|
|
||||||
class BasePasswordReader(abc.ABC):
|
|
||||||
"""Abstract strategy class to read a passwords."""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def get_password(self, identifier: str, repeated: bool = False) -> str:
|
|
||||||
"""Resolve the password, e.g., via input."""
|
|
||||||
|
|
||||||
|
|
||||||
class PasswordContext:
|
|
||||||
"""Context class for resolving a password."""
|
|
||||||
|
|
||||||
def __init__(self, reader: BasePasswordReader) -> None:
|
|
||||||
"""Initialize password context."""
|
|
||||||
self._reader: BasePasswordReader = reader
|
|
||||||
|
|
||||||
@property
|
|
||||||
def reader(self) -> BasePasswordReader:
|
|
||||||
"""Return reader."""
|
|
||||||
return self._reader
|
|
||||||
|
|
||||||
@reader.setter
|
|
||||||
def reader(self, reader: BasePasswordReader) -> None:
|
|
||||||
"""Set the reader instance."""
|
|
||||||
self._reader = reader
|
|
||||||
|
|
||||||
def get_password(self, identifier: str, repeated: bool = False) -> str:
|
|
||||||
"""Get the password."""
|
|
||||||
return self.reader.get_password(identifier, repeated)
|
|
||||||
|
|
||||||
|
|
||||||
class BasePasswordManager(abc.ABC):
|
|
||||||
"""Abstract base class for password managers."""
|
|
||||||
|
|
||||||
master_password_identifier: str
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@abc.abstractmethod
|
|
||||||
def create_database(
|
|
||||||
cls,
|
|
||||||
location: str,
|
|
||||||
password_context: PasswordContext | str,
|
|
||||||
overwrite: bool = False,
|
|
||||||
) -> Self:
|
|
||||||
"""Create database.
|
|
||||||
|
|
||||||
Location can be a file, a url or something else.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def open_database(self, password_context: PasswordContext | str) -> None:
|
|
||||||
"""Open database."""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def close_database(self) -> None:
|
|
||||||
"""Close database."""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def get_password(self, identifier: str) -> str | None:
|
|
||||||
"""Get a password from the manager."""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def generate_password(self, identifier: str) -> str:
|
|
||||||
"""Generate a password using unspecified default rules.
|
|
||||||
|
|
||||||
May be expanded later.
|
|
||||||
|
|
||||||
Returns the generated password.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def add_password(self, identifier: str, password: str) -> None:
|
|
||||||
"""Add a pre-defined password."""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def get_entries(self) -> list[str]:
|
|
||||||
"""Get names of all entries."""
|
|
||||||
|
|
||||||
def set_manager_options(self, options: dict[str, str]) -> None:
|
|
||||||
"""Set manager options."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def change_password(self, identifier: str, password: None) -> str: ...
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def change_password(self, identifier: str, password: str) -> None: ...
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def change_password(self, identifier: str, password: str | None) -> str | None:
|
|
||||||
"""Change password."""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def delete_password(self, identifier: str) -> None:
|
|
||||||
"""Delete a password."""
|
|
||||||
|
|
||||||
|
|
||||||
class ClientSpecification(BaseModel):
|
|
||||||
"""Specification of client."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
public_key: str
|
|
||||||
allowed_ips: list[IPvAnyAddress | IPvAnyNetwork] | str = "*"
|
|
||||||
secrets: dict[str, str] = {}
|
|
||||||
testing_private_key: str | None = None # Private key only for testing purposes!
|
|
||||||
|
|
||||||
|
|
||||||
class BaseClientBackend(abc.ABC):
|
|
||||||
"""Base client backend.
|
|
||||||
|
|
||||||
This class is responsible for managing the list of clients and facilitate
|
|
||||||
lookups.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def lookup_name(self, name: str) -> ClientSpecification | None:
|
|
||||||
"""Lookup a client specification by name."""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def add_secret(
|
|
||||||
self,
|
|
||||||
client_name: str,
|
|
||||||
secret_name: str,
|
|
||||||
secret_value: str,
|
|
||||||
encrypted: bool = False,
|
|
||||||
) -> None:
|
|
||||||
"""Add a secret to a client."""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def add_client(self, spec: ClientSpecification) -> None:
|
|
||||||
"""Add a new client."""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def update_client(self, name: str, spec: ClientSpecification) -> None:
|
|
||||||
"""Update client information."""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def remove_client(self, name: str, persistent: bool = True) -> None:
|
|
||||||
"""Delete a client."""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def get_all(self) -> list[ClientSpecification]:
|
|
||||||
"""Get all clients."""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def lookup_by_secret(self, secret_name: str) -> list[ClientSpecification]:
|
|
||||||
"""Lookup by the name of a secret."""
|
|
||||||
|
|
||||||
|
|
||||||
class BaseAPIClient(abc.ABC):
|
|
||||||
"""Base API Client."""
|
|
||||||
|
|
||||||
source: str
|
|
||||||
method: str
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def password_manager(
|
|
||||||
self, manager_options: dict[str, str] | None = None
|
|
||||||
) -> BasePasswordManager:
|
|
||||||
"""Instantiate password manager."""
|
|
||||||
|
|
||||||
def get_reader(self) -> BasePasswordReader:
|
|
||||||
"""Get the reader."""
|
|
||||||
raise NotImplementedError("Class-based password reading not implemented.")
|
|
||||||
|
|
||||||
def get_context(self, reader: BasePasswordReader | None = None) -> PasswordContext:
|
|
||||||
"""Get password context."""
|
|
||||||
if not reader:
|
|
||||||
reader = self.get_reader()
|
|
||||||
return PasswordContext(reader)
|
|
||||||
@ -1,77 +0,0 @@
|
|||||||
"""Various utilities."""
|
|
||||||
|
|
||||||
import secrets
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from .crypto import (
|
|
||||||
load_client_key,
|
|
||||||
encrypt_string,
|
|
||||||
generate_private_key,
|
|
||||||
generate_pem,
|
|
||||||
generate_public_key_string,
|
|
||||||
)
|
|
||||||
from .types import ClientSpecification
|
|
||||||
|
|
||||||
|
|
||||||
def generate_password() -> str:
|
|
||||||
"""Generate a password."""
|
|
||||||
return secrets.token_urlsafe(32)
|
|
||||||
|
|
||||||
|
|
||||||
def generate_client_object(
|
|
||||||
name: str, secrets: dict[str, str] | None = None, keyfile: str | None = None
|
|
||||||
) -> ClientSpecification:
|
|
||||||
"""Generate a client object."""
|
|
||||||
private_key = generate_private_key()
|
|
||||||
if keyfile:
|
|
||||||
with open(keyfile, "r") as f:
|
|
||||||
contents = f.read()
|
|
||||||
if not contents.startswith("ssh-rsa "):
|
|
||||||
raise RuntimeError("Error: Key must be an RSA key.")
|
|
||||||
|
|
||||||
client = ClientSpecification(name=name, public_key=contents.strip())
|
|
||||||
public_key = load_client_key(client)
|
|
||||||
else:
|
|
||||||
pem = generate_pem(private_key)
|
|
||||||
public_key = private_key.public_key()
|
|
||||||
pubkey_str = generate_public_key_string(public_key)
|
|
||||||
client = ClientSpecification(
|
|
||||||
name=name, public_key=pubkey_str, testing_private_key=pem
|
|
||||||
)
|
|
||||||
if secrets:
|
|
||||||
for secret_name, secret_value in secrets.items():
|
|
||||||
client.secrets[secret_name] = encrypt_string(secret_value, public_key)
|
|
||||||
|
|
||||||
return client
|
|
||||||
|
|
||||||
|
|
||||||
def create_client_file(
|
|
||||||
name: str,
|
|
||||||
filename: Path | str,
|
|
||||||
secrets: dict[str, str] | None = None,
|
|
||||||
keyfile: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Create client file."""
|
|
||||||
client = generate_client_object(name, secrets, keyfile)
|
|
||||||
|
|
||||||
with open(filename, "w") as f:
|
|
||||||
f.write(client.model_dump_json(exclude_none=True, indent=2))
|
|
||||||
f.flush()
|
|
||||||
|
|
||||||
|
|
||||||
def add_secret_to_client_file(
|
|
||||||
filename: str | Path, secret_name: str, secret_value: str
|
|
||||||
) -> None:
|
|
||||||
"""Add secret to client file."""
|
|
||||||
with open(filename, "r") as f:
|
|
||||||
client = ClientSpecification.model_validate_json(f.read())
|
|
||||||
|
|
||||||
public_key = load_client_key(client)
|
|
||||||
encrypted = encrypt_string(secret_value, public_key)
|
|
||||||
client.secrets[secret_name] = encrypted
|
|
||||||
|
|
||||||
with open(filename, "w") as f:
|
|
||||||
json_str = client.model_dump_json(exclude_none=True, indent=2)
|
|
||||||
f.write(json_str)
|
|
||||||
f.flush()
|
|
||||||
@ -1 +0,0 @@
|
|||||||
|
|
||||||
@ -1,279 +0,0 @@
|
|||||||
"""WebAPI."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
from functools import lru_cache
|
|
||||||
import secrets
|
|
||||||
import time
|
|
||||||
|
|
||||||
from typing import Annotated
|
|
||||||
from fastapi import Header, HTTPException, Depends, Request, APIRouter
|
|
||||||
|
|
||||||
from cryptography.fernet import Fernet
|
|
||||||
|
|
||||||
from sshecret.api import ManagementApi
|
|
||||||
from sshecret.types import (
|
|
||||||
BaseClientBackend,
|
|
||||||
BasePasswordManager,
|
|
||||||
ClientSpecification,
|
|
||||||
)
|
|
||||||
from sshecret.keepass import KeepassManager
|
|
||||||
from sshecret.backends.file_table import FileTableBackend
|
|
||||||
|
|
||||||
from sshecret.settings import Settings, get_settings
|
|
||||||
from sshecret.webapi.api_client import WebManagementAPIClient
|
|
||||||
from . import models
|
|
||||||
|
|
||||||
API_VERSION = "v1"
|
|
||||||
|
|
||||||
admin_router = APIRouter(prefix=f"/api/{API_VERSION}")
|
|
||||||
|
|
||||||
encryption_key = Fernet.generate_key()
|
|
||||||
cipher = Fernet(encryption_key)
|
|
||||||
|
|
||||||
# We store sessions in memory.
|
|
||||||
sessions: dict[str, tuple[str, float]] = {}
|
|
||||||
|
|
||||||
SESSION_TIMEOUT = 600 # 10 minutes
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
session_lock = asyncio.Lock()
|
|
||||||
|
|
||||||
|
|
||||||
def encrypt_session_password(password: str) -> str:
|
|
||||||
"""Encrypts the master password."""
|
|
||||||
return cipher.encrypt(password.encode()).decode()
|
|
||||||
|
|
||||||
|
|
||||||
def decrypt_password(encrypted_password: str) -> str:
|
|
||||||
"""Decrypts the master password asynchronously."""
|
|
||||||
return cipher.decrypt(encrypted_password.encode()).decode()
|
|
||||||
|
|
||||||
|
|
||||||
async def validate_session(session_id: Annotated[str | None, Header()] = None) -> str:
|
|
||||||
"""Middleware to validate session and enforce timeout."""
|
|
||||||
if not session_id:
|
|
||||||
raise HTTPException(status_code=401, detail="Session ID required")
|
|
||||||
|
|
||||||
async with session_lock:
|
|
||||||
if session_id not in sessions:
|
|
||||||
raise HTTPException(status_code=401, detail="Session invalid or expired")
|
|
||||||
|
|
||||||
encrypted_password, last_access_time = sessions[session_id]
|
|
||||||
current_time = asyncio.get_event_loop().time()
|
|
||||||
|
|
||||||
# Check for session timeout
|
|
||||||
if current_time - last_access_time > SESSION_TIMEOUT:
|
|
||||||
del sessions[session_id] # Auto-lock on timeout
|
|
||||||
raise HTTPException(status_code=401, detail="Session expired")
|
|
||||||
|
|
||||||
# Update last access time
|
|
||||||
sessions[session_id] = (encrypted_password, current_time)
|
|
||||||
|
|
||||||
return decrypt_password(encrypted_password)
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache
|
|
||||||
def get_app_settings() -> Settings:
|
|
||||||
"""Get app settings."""
|
|
||||||
return get_settings()
|
|
||||||
|
|
||||||
|
|
||||||
def get_password_manager(
|
|
||||||
settings: Annotated[Settings, Depends(get_app_settings)]
|
|
||||||
) -> BasePasswordManager:
|
|
||||||
"""Get password manager."""
|
|
||||||
# Currently only keepass is supported.
|
|
||||||
keepass = KeepassManager()
|
|
||||||
keepass.location = settings.password_manager.location
|
|
||||||
return keepass
|
|
||||||
|
|
||||||
|
|
||||||
async def get_backend(
|
|
||||||
settings: Annotated[Settings, Depends(get_app_settings)]
|
|
||||||
) -> BaseClientBackend:
|
|
||||||
"""Get backend."""
|
|
||||||
location = settings.backend.location
|
|
||||||
filetable = FileTableBackend(location)
|
|
||||||
return filetable
|
|
||||||
|
|
||||||
|
|
||||||
async def get_management_api(
|
|
||||||
request: Request, settings: Annotated[Settings, Depends(get_app_settings)]
|
|
||||||
) -> ManagementApi:
|
|
||||||
"""Get management api."""
|
|
||||||
client_ip = "unknown"
|
|
||||||
if req_client := request.client:
|
|
||||||
client_ip = req_client.host
|
|
||||||
|
|
||||||
api_client = WebManagementAPIClient(client_ip, settings)
|
|
||||||
backend = await get_backend(settings)
|
|
||||||
return ManagementApi(backend, api_client)
|
|
||||||
|
|
||||||
|
|
||||||
BackendDependency = Annotated[BaseClientBackend, Depends(get_backend)]
|
|
||||||
ManagementAPIDependency = Annotated[ManagementApi, Depends(get_management_api)]
|
|
||||||
SessionPasswdDependency = Annotated[str, Depends(validate_session)]
|
|
||||||
|
|
||||||
|
|
||||||
@admin_router.post("/auth/unlock")
|
|
||||||
async def unlock_database(
|
|
||||||
password: models.PasswordBody,
|
|
||||||
password_manager: Annotated[BasePasswordManager, Depends(get_password_manager)],
|
|
||||||
) -> models.SessionResponse:
|
|
||||||
"""Unlock database with master password sent in POST body."""
|
|
||||||
password_str = password.password.get_secret_value()
|
|
||||||
try:
|
|
||||||
password_manager.open_database(password_str)
|
|
||||||
except Exception as e:
|
|
||||||
LOG.debug("Exception: %s", e, exc_info=True)
|
|
||||||
raise HTTPException(status_code=401, detail="Invalid password.")
|
|
||||||
|
|
||||||
session_id = secrets.token_urlsafe(32)
|
|
||||||
sessions[session_id] = (encrypt_session_password(password_str), time.time())
|
|
||||||
|
|
||||||
return models.SessionResponse(session_id=session_id)
|
|
||||||
|
|
||||||
|
|
||||||
@admin_router.post("/auth/lock")
|
|
||||||
async def lock_database(
|
|
||||||
session_id: Annotated[str | None, Header()] = None
|
|
||||||
) -> dict[str, str]:
|
|
||||||
"""Lock database."""
|
|
||||||
if session_id and session_id in sessions:
|
|
||||||
del sessions[session_id]
|
|
||||||
|
|
||||||
return {"message": "LOCKED"}
|
|
||||||
raise HTTPException(400, detail="Missing session ID.")
|
|
||||||
|
|
||||||
|
|
||||||
@admin_router.get("/auth/status")
|
|
||||||
async def get_lock_status(
|
|
||||||
session_id: Annotated[str | None, Header()] = None
|
|
||||||
) -> dict[str, str]:
|
|
||||||
"""Get current lock status."""
|
|
||||||
if session_id and session_id in sessions:
|
|
||||||
return {"message": "UNLOCKED"}
|
|
||||||
return {"message": "LOCKED"}
|
|
||||||
|
|
||||||
|
|
||||||
@admin_router.get("/clients")
|
|
||||||
async def get_clients(admin_api: ManagementAPIDependency) -> list[ClientSpecification]:
|
|
||||||
"""Get clients."""
|
|
||||||
return admin_api.get_clients()
|
|
||||||
|
|
||||||
|
|
||||||
@admin_router.get("/clients/{client_id}")
|
|
||||||
async def get_client(
|
|
||||||
client_id: str, admin_api: ManagementAPIDependency
|
|
||||||
) -> ClientSpecification:
|
|
||||||
"""Get client."""
|
|
||||||
if client_api := admin_api.get_client(client_id):
|
|
||||||
return client_api.client
|
|
||||||
raise HTTPException(status_code=404, detail="Client not found.")
|
|
||||||
|
|
||||||
|
|
||||||
@admin_router.put("/clients/{client_id}")
|
|
||||||
async def update_client(
|
|
||||||
client_id: str,
|
|
||||||
client: ClientSpecification,
|
|
||||||
admin_api: ManagementAPIDependency,
|
|
||||||
master_password: SessionPasswdDependency,
|
|
||||||
) -> ClientSpecification:
|
|
||||||
"""Update client."""
|
|
||||||
client_api = admin_api.get_client(client_id)
|
|
||||||
if not client_api:
|
|
||||||
raise HTTPException(status_code=404, detail="Client not found.")
|
|
||||||
new_client = client_api.update_client(client, password=master_password)
|
|
||||||
return new_client
|
|
||||||
|
|
||||||
|
|
||||||
@admin_router.delete("/clients/{client_id}", status_code=204)
|
|
||||||
async def delete_client(client_id: str, admin_api: ManagementAPIDependency) -> None:
|
|
||||||
"""Delete client."""
|
|
||||||
if admin_api.get_client(client_id):
|
|
||||||
admin_api.delete_client(client_id)
|
|
||||||
else:
|
|
||||||
raise HTTPException(status_code=404, detail="Client not found.")
|
|
||||||
|
|
||||||
|
|
||||||
@admin_router.post("/clients", status_code=201)
|
|
||||||
async def add_client(
|
|
||||||
client: models.CreateClientModel, admin_api: ManagementAPIDependency
|
|
||||||
) -> ClientSpecification:
|
|
||||||
"""Add client."""
|
|
||||||
new_client = admin_api.create_client(
|
|
||||||
client.name, client.public_key, client.allowed_ips
|
|
||||||
)
|
|
||||||
return new_client.client
|
|
||||||
|
|
||||||
|
|
||||||
@admin_router.get("/secrets")
|
|
||||||
async def list_secrets(
|
|
||||||
admin_api: ManagementAPIDependency, password: SessionPasswdDependency
|
|
||||||
) -> list[models.SecretListResponse]:
|
|
||||||
"""List secrets."""
|
|
||||||
secrets = admin_api.get_secret_names(password=password)
|
|
||||||
return [
|
|
||||||
models.SecretListResponse(name=name, assigned_clients=assigned_clients)
|
|
||||||
for name, assigned_clients in secrets.items()
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@admin_router.post("/secrets")
|
|
||||||
async def add_secret(
|
|
||||||
secret: models.CreateSecretSpecification,
|
|
||||||
password: SessionPasswdDependency,
|
|
||||||
admin_api: ManagementAPIDependency,
|
|
||||||
) -> models.RevealSecretResponse:
|
|
||||||
"""Add secret.
|
|
||||||
|
|
||||||
Will generate a password if none is specified.
|
|
||||||
"""
|
|
||||||
secret_value: str | None = None
|
|
||||||
if secret.secret:
|
|
||||||
secret_value = secret.secret.get_secret_value()
|
|
||||||
result_secret = admin_api.add_secret(secret.name, secret_value, password=password)
|
|
||||||
return models.RevealSecretResponse(name=secret.name, secret=result_secret)
|
|
||||||
|
|
||||||
|
|
||||||
@admin_router.get("/secrets/{name}")
|
|
||||||
async def get_secret(
|
|
||||||
name: str, admin_api: ManagementAPIDependency, password: SessionPasswdDependency
|
|
||||||
) -> models.RevealSecretResponse:
|
|
||||||
"""Get secret."""
|
|
||||||
if secret_value := admin_api.get_secret(name, password=password):
|
|
||||||
return models.RevealSecretResponse(name=name, secret=secret_value)
|
|
||||||
raise HTTPException(status_code=404, detail="Secret not found.")
|
|
||||||
|
|
||||||
|
|
||||||
@admin_router.put("/secrets/{name}")
|
|
||||||
async def update_secret(
|
|
||||||
name: str,
|
|
||||||
spec: models.UpdateSecretSpecification,
|
|
||||||
admin_api: ManagementAPIDependency,
|
|
||||||
password: SessionPasswdDependency,
|
|
||||||
) -> models.MaybeRevalSecretResponse:
|
|
||||||
"""Update secret."""
|
|
||||||
if spec.auto_generate:
|
|
||||||
secret_value = admin_api.regenerate_secret(name, password=password)
|
|
||||||
return models.MaybeRevalSecretResponse(name=name, secret=secret_value)
|
|
||||||
|
|
||||||
if not spec.secret:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="Secret value must be specified if auto_generate is False",
|
|
||||||
)
|
|
||||||
admin_api.update_secret(name, spec.secret, password=password)
|
|
||||||
return models.MaybeRevalSecretResponse(name=name, secret=None)
|
|
||||||
|
|
||||||
|
|
||||||
@admin_router.delete("/secrets/{name}", status_code=204)
|
|
||||||
async def delete_secret(
|
|
||||||
name: str,
|
|
||||||
admin_api: ManagementAPIDependency,
|
|
||||||
password: SessionPasswdDependency,
|
|
||||||
) -> None:
|
|
||||||
"""Delete secret."""
|
|
||||||
admin_api.delete_secret(name, password=password)
|
|
||||||
@ -1,25 +0,0 @@
|
|||||||
"""API Client."""
|
|
||||||
|
|
||||||
from typing import override
|
|
||||||
from sshecret.keepass import KeepassManager
|
|
||||||
from sshecret.types import BaseAPIClient, BasePasswordManager
|
|
||||||
from sshecret.settings import Settings, get_settings
|
|
||||||
|
|
||||||
|
|
||||||
class WebManagementAPIClient(BaseAPIClient):
|
|
||||||
"""Client class for the web management API."""
|
|
||||||
|
|
||||||
method: str = "admin-web-api"
|
|
||||||
|
|
||||||
def __init__(self, source: str, settings: Settings | None = None) -> None:
|
|
||||||
"""Construct client."""
|
|
||||||
if not settings:
|
|
||||||
settings = get_settings()
|
|
||||||
self.source: str = source
|
|
||||||
self._password_manager: BasePasswordManager = KeepassManager()
|
|
||||||
self._password_manager.location = settings.password_manager.manager.location
|
|
||||||
|
|
||||||
@override
|
|
||||||
def password_manager(self, manager_options: dict[str, str] | None = None) -> BasePasswordManager:
|
|
||||||
"""Get password manager."""
|
|
||||||
return self._password_manager
|
|
||||||
@ -1,33 +0,0 @@
|
|||||||
"""Admin frontend."""
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Request
|
|
||||||
|
|
||||||
from fastapi.responses import HTMLResponse
|
|
||||||
from fastapi.templating import Jinja2Templates
|
|
||||||
|
|
||||||
templates = Jinja2Templates(directory="templates")
|
|
||||||
|
|
||||||
frontend = APIRouter()
|
|
||||||
|
|
||||||
# I'm just making some placeholders here
|
|
||||||
@frontend.get("/")
|
|
||||||
async def index(request: Request) -> HTMLResponse:
|
|
||||||
"""Get frontpage."""
|
|
||||||
return templates.TemplateResponse(request, name="index.html")
|
|
||||||
|
|
||||||
|
|
||||||
@frontend.get("/login")
|
|
||||||
async def login(request: Request) -> HTMLResponse:
|
|
||||||
"""Get login page."""
|
|
||||||
return templates.TemplateResponse(request, name="login.html")
|
|
||||||
|
|
||||||
|
|
||||||
@frontend.get("/clients")
|
|
||||||
async def clients(request: Request) -> HTMLResponse:
|
|
||||||
"""Get login page."""
|
|
||||||
return templates.TemplateResponse(request, name="clients.html")
|
|
||||||
|
|
||||||
@frontend.get("/secrets")
|
|
||||||
async def secrets(request: Request) -> HTMLResponse:
|
|
||||||
"""Get login page."""
|
|
||||||
return templates.TemplateResponse(request, name="secrets.html")
|
|
||||||
@ -1,71 +0,0 @@
|
|||||||
"""Response models."""
|
|
||||||
|
|
||||||
from pydantic import BaseModel, IPvAnyAddress, IPvAnyNetwork, SecretStr
|
|
||||||
|
|
||||||
|
|
||||||
class SSHKeyResponse(BaseModel):
|
|
||||||
"""Response model for updated SSH keys."""
|
|
||||||
|
|
||||||
updated_secrets: list[str]
|
|
||||||
|
|
||||||
|
|
||||||
class SecretListResponse(BaseModel):
|
|
||||||
"""Response for listing secrets."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
assigned_clients: list[str]
|
|
||||||
|
|
||||||
|
|
||||||
class CreateSecretSpecification(BaseModel):
|
|
||||||
"""Model for creating a secret."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
secret: SecretStr | None
|
|
||||||
|
|
||||||
|
|
||||||
class SecretSpecification(BaseModel):
|
|
||||||
"""Secret specification."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
secret: SecretStr
|
|
||||||
|
|
||||||
|
|
||||||
class UpdateSecretSpecification(BaseModel):
|
|
||||||
"""Model for updating a secret."""
|
|
||||||
|
|
||||||
secret: str | None
|
|
||||||
auto_generate: bool | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class RevealSecretResponse(BaseModel):
|
|
||||||
"""Reveal secret."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
secret: str
|
|
||||||
|
|
||||||
|
|
||||||
class MaybeRevalSecretResponse(BaseModel):
|
|
||||||
"""Model where the secret may be specified."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
secret: str | None
|
|
||||||
|
|
||||||
|
|
||||||
class PasswordBody(BaseModel):
|
|
||||||
"""Password body."""
|
|
||||||
|
|
||||||
password: SecretStr
|
|
||||||
|
|
||||||
|
|
||||||
class SessionResponse(BaseModel):
|
|
||||||
"""Session response."""
|
|
||||||
|
|
||||||
session_id: str
|
|
||||||
|
|
||||||
|
|
||||||
class CreateClientModel(BaseModel):
|
|
||||||
"""Model for creating a client."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
public_key: str
|
|
||||||
allowed_ips: list[IPvAnyAddress | IPvAnyNetwork] | str = "*"
|
|
||||||
@ -1,16 +0,0 @@
|
|||||||
"""API router."""
|
|
||||||
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from fastapi.staticfiles import StaticFiles
|
|
||||||
|
|
||||||
|
|
||||||
from .api import admin_router
|
|
||||||
from .frontend import frontend
|
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
|
|
||||||
app.include_router(admin_router)
|
|
||||||
app.include_router(frontend)
|
|
||||||
|
|
||||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
||||||
Reference in New Issue
Block a user