diff --git a/src/sshecret/backend/__init__.py b/src/sshecret/backend/__init__.py new file mode 100644 index 0000000..be1b7d5 --- /dev/null +++ b/src/sshecret/backend/__init__.py @@ -0,0 +1,28 @@ +"""SSHecret backend.""" + +from .api import SshecretBackend +from .models import ( + AuditLog, + Client, + ClientFilter, + ClientReference, + ClientSecret, + DetailedSecrets, + Policy, + Secret, + FilterType, +) + + +__all__ = [ + "AuditLog", + "Client", + "ClientFilter", + "ClientReference", + "ClientSecret", + "DetailedSecrets", + "FilterType", + "Policy", + "Secret", + "SshecretBackend", +] diff --git a/src/sshecret/backend/api.py b/src/sshecret/backend/api.py new file mode 100644 index 0000000..c35db51 --- /dev/null +++ b/src/sshecret/backend/api.py @@ -0,0 +1,350 @@ +"""Backend client. + +""" + +import logging +from typing import Any, Self +import httpx + +from pydantic import TypeAdapter + +from .models import ( + AuditLog, + Client, + ClientSecret, + ClientQueryResult, + ClientFilter, + DetailedSecrets, + Secret, +) +from .exceptions import BackendValidationError, BackendConnectionError +from .utils import validate_public_key + +LOG = logging.getLogger(__name__) + + +class ClientQueryIterator: + """Asynchronous query iterator.""" + + def __init__( + self, + client: httpx.AsyncClient, + filter_params: ClientFilter | None = None, + ) -> None: + """Create a query iterator.""" + self.client: httpx.AsyncClient = client + if not filter_params: + filter_params = ClientFilter() + self.filter_params: ClientFilter = filter_params + + self._result: ClientQueryResult | None = None + self._clients: list[Client] = [] + self.offset: int = 0 + + async def _get_batch(self) -> None: + """Get batch.""" + params: dict[str, str | int] = { + **self.filter_params.get_params(), + "offset": self.offset, + } + try: + results = await self.client.get("/api/v1/clients/", params=params) + except httpx.TransportError as e: + raise BackendConnectionError() from e + if results.status_code != 200: + raise BackendConnectionError() + self._result = ClientQueryResult.model_validate(results.json()) + self._clients = self._result.clients + + async def get_next_client(self) -> Client | None: + """Get next client. + + When do we know when to stop? + """ + if not self._result: + await self._get_batch() + assert self._result is not None + if self._result.total_results == self.offset: + return None + if not self._clients: + await self._get_batch() + + client = self._clients.pop(0) + self.offset += 1 + return client + + def __aiter__(self) -> Self: + """Iterate async.""" + return self + + async def __anext__(self) -> Client: + """Get next client.""" + if client := await self.get_next_client(): + return client + raise StopAsyncIteration + + +class SshecretBackend: + """Backend interface.""" + + def __init__(self, backend_url: str, api_token: str) -> None: + """Initialize backend client.""" + + url = httpx.URL(backend_url) + + self.http_client: httpx.AsyncClient = httpx.AsyncClient( + headers={"X-Api-Token": api_token}, + base_url=url, + ) + self.sync_client: httpx.Client = httpx.Client( + headers={"X-Api-Token": api_token}, + base_url=url, + ) + + async def _get(self, path: str) -> httpx.Response: + """Perform a get request.""" + try: + return await self.http_client.get(path) + except httpx.ConnectError as e: + raise BackendConnectionError() from e + + async def _delete(self, path: str) -> httpx.Response: + """Perform a delete request.""" + try: + return await self.http_client.delete(path) + except httpx.ConnectError as e: + raise BackendConnectionError() from e + + async def _post(self, path: str, json: Any | None = None) -> httpx.Response: + """Perform a POST request.""" + try: + return await self.http_client.post(path, json=json) + except httpx.ConnectError as e: + raise BackendConnectionError() from e + + async def _put(self, path: str, json: Any | None = None) -> httpx.Response: + """Perform a PUT request.""" + try: + return await self.http_client.put(path, json=json) + except httpx.ConnectError as e: + raise BackendConnectionError() from e + + async def request(self, path: str) -> httpx.Response: + """Send a simple GET request.""" + response = await self._get(path) + return response + + async def create_client( + self, name: str, public_key: str, description: str | None = None + ) -> None: + """Register a new client.""" + if not validate_public_key(public_key): + raise BackendValidationError("Error: Invalid public key format.") + data = { + "name": name, + "public_key": public_key, + } + if description: + data["description"] = description + path = "/api/v1/clients/" + response = await self._post(path, json=data) + + response.raise_for_status() + + async def get_clients(self, filter: ClientFilter | None = None) -> list[Client]: + """Get all clients.""" + clients: list[Client] = [] + async for client in ClientQueryIterator(self.http_client, filter): + clients.append(client) + + return clients + + async def get_client(self, name: str) -> Client | None: + """Lookup a client on username.""" + path = f"/api/v1/clients/{name}" + response = await self.request(path) + if response.status_code == 404: + return None + response.raise_for_status() + client = Client.model_validate(response.json()) + return client + + async def get_client_by_id(self, id: str) -> Client | None: + """Lookup a client on username.""" + path = f"/api/v1/clients/id/{id}" + response = await self.request(path) + if response.status_code == 404: + return None + response.raise_for_status() + client = Client.model_validate(response.json()) + return client + + async def delete_client(self, client_name: str) -> None: + """Delete a client.""" + path = f"/api/v1/clients/{client_name}" + response = await self._delete(path) + + response.raise_for_status() + + async def delete_client_by_id(self, id: str) -> None: + """Delete a client.""" + path = f"/api/v1/clients/id/{id}" + response = await self._delete(path) + + response.raise_for_status() + + async def create_client_secret( + self, client_name: str, secret_name: str, encrypted_secret: str + ) -> None: + """Create a secret. + + This will overwrite any existing secret with that name. + """ + path = f"api/v1/clients/{client_name}/secrets/{secret_name}" + response = await self._put(path, json={"value": encrypted_secret}) + + response.raise_for_status() + + async def get_client_secret(self, name: str, secret_name: str) -> str: + """Fetch a secret.""" + path = f"/api/v1/clients/{name}/secrets/{secret_name}" + response = await self.request(path) + response.raise_for_status() + secret = ClientSecret.model_validate(response.json()) + return secret.secret + + async def delete_client_secret(self, client_name: str, secret_name: str) -> None: + """Delete a secret from a client.""" + path = f"api/v1/clients/{client_name}/secrets/{secret_name}" + response = await self._delete(path) + + response.raise_for_status() + + async def update_client(self, client: Client) -> Client: + """Update the client.""" + path = f"/api/v1/clients/{client.name}" + client_update = { + "name": client.name, + "description": client.description, + "public_key": client.public_key, + } + response = await self._put(path, json=client_update) + LOG.info("Response %s", response.text) + + response.raise_for_status() + if client.policies: + await self.update_client_sources( + str(client.id), [str(source) for source in client.policies] + ) + return client + + async def update_client_key(self, client_name: str, public_key: str) -> None: + """Update the client key.""" + path = f"/api/v1/clients/{client_name}/public-key" + response = await self._post(path, json={"public_key": public_key}) + + response.raise_for_status() + + async def update_client_sources( + self, client_name: str, addresses: list[str] | None + ) -> None: + """Update client source addresses. + + Pass None to sources to allow from all. + """ + if not addresses: + addresses = [] + + path = f"/api/v1/clients/{client_name}/policies/" + response = await self._put(path, json={"sources": addresses}) + + response.raise_for_status() + + async def get_detailed_secrets(self) -> list[DetailedSecrets]: + """Get detailed list of secrets.""" + path = "/api/v1/secrets/detailed/" + response = await self._get(path) + response.raise_for_status() + + secret_list = TypeAdapter(list[DetailedSecrets]) + return secret_list.validate_python(response.json()) + + async def get_secrets(self) -> list[Secret]: + """Get Secrets. + + This provides a list of secret names and which clients have them. + """ + path = "/api/v1/secrets/" + response = await self._get(path) + + response.raise_for_status() + + secret_list = TypeAdapter(list[Secret]) + return secret_list.validate_python(response.json()) + + async def get_secret(self, name: str) -> Secret | None: + """Get clients mapped to a single secret.""" + path = f"/api/v1/secrets/{name}" + response = await self._get(path) + if response.status_code == 404: + return None + response.raise_for_status() + + return Secret.model_validate(response.json()) + + async def get_detailed_secret(self, name: str) -> DetailedSecrets | None: + """Get clients mapped to a single secret.""" + path = f"/api/v1/secrets/{name}/detailed" + response = await self._get(path) + if response.status_code == 404: + return None + response.raise_for_status() + + return DetailedSecrets.model_validate(response.json()) + + async def get_audit_log( + self, + offset: int = 0, + limit: int = 100, + client_name: str | None = None, + subsystem: str | None = None, + ) -> list[AuditLog]: + """Get audit log.""" + path = f"/api/v1/audit/" + params: dict[str, str] = { + "offset": str(offset), + "limit": str(limit), + } + if client_name: + params["filter_client"] = client_name + + if subsystem: + params["filter_subsystem"] = subsystem + + response = await self.http_client.get(path, params=params) + response.raise_for_status() + audit_log_adapter = TypeAdapter(list[AuditLog]) + return audit_log_adapter.validate_python(response.json()) + + async def add_audit_log(self, entry: AuditLog) -> None: + """Add audit log entry.""" + path = f"/api/v1/audit/" + + response = await self.http_client.post(path, json=entry.model_dump()) + response.raise_for_status() + + async def get_audit_log_count(self) -> int: + """Get amount of messages in the audit log.""" + path = f"/api/v1/audit/info" + response = await self._get(path) + response.raise_for_status() + data = response.json() + return int(data["entries"]) + + def add_audit_log_sync(self, entry: AuditLog) -> None: + """Add audit log entry.""" + path = f"/api/v1/audit/" + LOG.info("AUDIT LOG SYNC %r", entry) + + response = self.sync_client.post(path, json=entry.model_dump()) + response.raise_for_status() diff --git a/src/sshecret/backend/exceptions.py b/src/sshecret/backend/exceptions.py new file mode 100644 index 0000000..f2e8691 --- /dev/null +++ b/src/sshecret/backend/exceptions.py @@ -0,0 +1,8 @@ +"""Exceptions.""" + +class BackendValidationError(Exception): + """Validation error.""" + + +class BackendConnectionError(Exception): + """Could not connect to backend server.""" diff --git a/src/sshecret/backend/models.py b/src/sshecret/backend/models.py new file mode 100644 index 0000000..9ca4523 --- /dev/null +++ b/src/sshecret/backend/models.py @@ -0,0 +1,115 @@ +"""Backend models.""" + +import enum +import uuid +from datetime import datetime +from typing import Annotated + +from pydantic import AfterValidator, BaseModel, IPvAnyAddress, IPvAnyNetwork + +from sshecret.crypto import public_key_validator + + +class FilterType(enum.StrEnum): + """Type of filter.""" + + LIKE = "like" + CONTAINS = "contains" + + +class Client(BaseModel): + """Implementation of the backend class ClientView.""" + + id: uuid.UUID + name: str + description: str | None + public_key: Annotated[str, AfterValidator(public_key_validator)] + secrets: list[str] + policies: list[IPvAnyNetwork | IPvAnyAddress] + created_at: datetime + updated_at: datetime | None + + +class ClientQueryResult(BaseModel): + """Implementation of the backend ClientQueryResult class.""" + + clients: list[Client] + total_results: int + remaining_results: int + + +class ClientSecret(BaseModel): + """Implementation of the backend class ClientSecretResponse.""" + + name: str + secret: str + description: str | None + created_at: datetime + updated_at: datetime | None + + +class Secret(BaseModel): + """Implementation of the backend class ClientSecretList.""" + + name: str + clients: list[str] + + +class ClientReference(BaseModel): + """Implementation of the backend class ClientReference.""" + + id: str + name: str + + +class DetailedSecrets(BaseModel): + """Implementation of the backend class ClientSecretDetailList.""" + + name: str + ids: list[str] + clients: list[ClientReference] + + +class Policy(BaseModel): + """Implementation of the backend class ClientPolicyView.""" + + sources: list[IPvAnyNetwork | IPvAnyAddress] + + +class ClientFilter(BaseModel): + """Client filter.""" + + id: str | None = None + name: str | None = None + filter_name: FilterType | None = None + + def get_params(self) -> dict[str, str]: + """Render query parameters.""" + params: dict[str, str] = {} + if not self.id and not self.name: + return params + if self.id: + params["id"] = self.id + if self.name and not self.filter_name: + params["name"] = self.name + elif self.name and self.filter_name is FilterType.LIKE: + params["name__like"] = self.name + elif self.name and self.filter_name is FilterType.CONTAINS: + params["name__contains"] = self.name + + return params + + +class AuditLog(BaseModel): + """Implementation of the backend class AuditLog.""" + + id: str | None = None + subsystem: str | None = None + object: str | None = None + object_id: str | None = None + operation: str + client_id: str | None = None + client_name: str | None = None + message: str + origin: str | None = None + timestamp: datetime | None = None diff --git a/src/sshecret/backend/utils.py b/src/sshecret/backend/utils.py new file mode 100644 index 0000000..df7a4f9 --- /dev/null +++ b/src/sshecret/backend/utils.py @@ -0,0 +1,37 @@ +"""Utility functions.""" + +import logging +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa + + +LOG = logging.getLogger(__name__) + + +def load_public_key(key: bytes | str) -> rsa.RSAPublicKey: + if isinstance(key, str): + key = key.encode() + public_key = serialization.load_ssh_public_key(key) + if not isinstance(public_key, rsa.RSAPublicKey): + raise ValueError("Only RSA keys are supported.") + return public_key + + +def validate_public_key(key: str) -> bool: + """Check if key provided in a string is valid.""" + valid = False + public_key: rsa.RSAPublicKey | None = None + try: + keybytes = key.encode() + public_key = load_public_key(keybytes) + except Exception as e: + LOG.debug("Validation of public key failed: %s", e, exc_info=True) + else: + valid = True + + if not isinstance(public_key, rsa.RSAPublicKey): + valid = False + + if not valid: + raise ValueError("Invalid public key. Must an OpenSSH RSA public key.") + return valid diff --git a/src/sshecret/crypto.py b/src/sshecret/crypto.py new file mode 100644 index 0000000..76f602d --- /dev/null +++ b/src/sshecret/crypto.py @@ -0,0 +1,137 @@ +"""Encryption related functions. + +Note! Encryption uses the less secure PKCS1v15 padding. This is to allow +decryption via openssl on the command line. + +""" + +import base64 +import logging +from pathlib import Path +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.asymmetric import padding + + +RSA_PUBLIC_EXPONENT = 65537 +RSA_KEY_SIZE = 2048 + +LOG = logging.getLogger(__name__) + + +def load_public_key(keybytes: bytes) -> rsa.RSAPublicKey: + public_key = serialization.load_ssh_public_key(keybytes) + if not isinstance(public_key, rsa.RSAPublicKey): + raise RuntimeError("Only RSA keys are supported.") + return public_key + + +def public_key_validator(public_key: str) -> str: + """Validator for public key.""" + if validate_public_key(public_key): + return public_key + + raise ValueError("Error: Public key is not a valid SSH RSA Public Key.") + + +def validate_public_key(key: str) -> bool: + """Check if key provided in a string is valid.""" + valid = False + public_key: rsa.RSAPublicKey | None = None + try: + keybytes = key.encode() + public_key = load_public_key(keybytes) + except Exception as e: + LOG.debug("Validation of public key failed: %s", e, exc_info=True) + else: + valid = True + + if not isinstance(public_key, rsa.RSAPublicKey): + valid = False + + return valid + + +def load_private_key(filename: str, password: str | None = None) -> rsa.RSAPrivateKey: + """Load a private key.""" + password_bytes: bytes | None = None + if password: + password_bytes = password.encode() + with open(filename, "rb") as f: + private_key = serialization.load_ssh_private_key( + f.read(), password=password_bytes + ) + if not isinstance(private_key, rsa.RSAPrivateKey): + raise RuntimeError("Only RSA keys are supported.") + return private_key + + +def encrypt_string(string: str, public_key: rsa.RSAPublicKey) -> str: + """Encrypt string, end return it base64 encoded.""" + message = string.encode() + ciphertext = public_key.encrypt( + message, + padding.PKCS1v15(), + ) + return base64.b64encode(ciphertext).decode() + + +def decode_string(ciphertext: str, private_key: rsa.RSAPrivateKey) -> str: + """Decode a string. String must be base64 encoded.""" + decoded = base64.b64decode(ciphertext) + decrypted = private_key.decrypt( + decoded, + padding.PKCS1v15(), + ) + return decrypted.decode() + + +def generate_private_key() -> rsa.RSAPrivateKey: + """Generate private RSA key.""" + private_key = rsa.generate_private_key( + public_exponent=RSA_PUBLIC_EXPONENT, key_size=RSA_KEY_SIZE + ) + return private_key + + +def generate_pem(private_key: rsa.RSAPrivateKey) -> str: + """Generate PEM.""" + pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.OpenSSH, + encryption_algorithm=serialization.NoEncryption(), + ) + return pem.decode() + + +def create_private_rsa_key(filename: Path, password: str | None = None) -> None: + """Create an RSA Private key at the given path. + + A password may be provided for secure storage. + """ + if filename.exists(): + raise RuntimeError("Error: private key file already exists.") + LOG.debug("Generating private RSA key at %s", filename) + private_key = generate_private_key() + encryption_algorithm = serialization.NoEncryption() + if password: + password_bytes = password.encode() + encryption_algorithm = serialization.BestAvailableEncryption(password_bytes) + with open(filename, "wb") as f: + pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.OpenSSH, + encryption_algorithm=encryption_algorithm, + ) + lines = f.write(pem) + LOG.debug("Wrote %s lines", lines) + f.flush() + + +def generate_public_key_string(public_key: rsa.RSAPublicKey) -> str: + """Generate public key string.""" + keybytes = public_key.public_bytes( + encoding=serialization.Encoding.OpenSSH, + format=serialization.PublicFormat.OpenSSH, + ) + return keybytes.decode() diff --git a/src/sshecret/settings.py b/src/sshecret/settings.py new file mode 100644 index 0000000..ef21445 --- /dev/null +++ b/src/sshecret/settings.py @@ -0,0 +1,57 @@ +"""Settings management.""" + +from typing import Literal +from pydantic import AnyHttpUrl, BaseModel, Field + +from pydantic_settings import BaseSettings, SettingsConfigDict + + +DEFAULT_SSH_PORT = 2222 +DEFAULT_BACKEND_PORT = 8822 +DEFAULT_ADMIN_PORT = 8022 + + +class SshdSettings(BaseModel): + """Settings for the SSHd server.""" + + name: Literal["sshd"] + listen_address: str = "" + backend_token: str + port: int = DEFAULT_SSH_PORT + + +class AdminSettings(BaseModel): + """Settings for the Admin module.""" + + name: Literal["admin"] + listen_address: str = "" + backend_token: str + port: int = DEFAULT_ADMIN_PORT + + +class BackendSettings(BaseModel): + """Settings for the backend server.""" + + name: Literal["backend"] + listen_address: str = "" + backend_token: str + port: int = DEFAULT_BACKEND_PORT + + +class SshecretSettings(BaseSettings): + """General settings model. + + Should probably be subclassed. + """ + + model_config = SettingsConfigDict( + env_prefix="sshecret_", + env_nested_delimiter="__", + nested_model_default_partial_update=True, + ) + + backend_url: AnyHttpUrl + # This is set up deliberately so only one app can run at the time. + application: SshdSettings | AdminSettings | BackendSettings = Field( + discriminator="name" + ) diff --git a/src/sshecret/testing.py b/src/sshecret/testing.py new file mode 100644 index 0000000..4eadc35 --- /dev/null +++ b/src/sshecret/testing.py @@ -0,0 +1,97 @@ +"""Testing classes.""" + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from pathlib import Path +from typing import cast + +import asyncssh +import httpx + + +@dataclass +class BackendContext: + """Backend context.""" + + url: str + token: str + + @property + def headers(self) -> dict[str, str]: + """Get headers.""" + return { + "X-API-Token": self.token, + } + + def test_client(self) -> httpx.AsyncClient: + """Create a test client.""" + return httpx.AsyncClient(base_url=self.url, headers=self.headers) + + +@dataclass +class AdminContext: + """Admin context.""" + + url: str + username: str = "test" + password: str = "test" + jwt_token: str | None = field(init=False) + + async def login(self) -> None: + """Login to the application.""" + client = httpx.AsyncClient(base_url=self.url) + response = await client.post( + "/api/v1/token", data={"username": self.username, "password": self.password} + ) + assert response.status_code == 200 + token_data = response.json() + access_token = cast(str, token_data["access_token"]) + self.jwt_token = str(access_token) + + @property + def headers(self) -> dict[str, str]: + """Return headers.""" + headers: dict[str, str] = {} + if self.jwt_token: + headers = {"Authorization": f"Bearer {self.jwt_token}"} + + return headers + + def test_client(self) -> httpx.AsyncClient: + """Create a test client.""" + return httpx.AsyncClient(base_url=self.url, headers=self.headers) + + +@dataclass +class SshServerContext: + """SSH server context.""" + + host: str + port: int + + async def run_command( + self, username: str, private_key_file: Path | None, command: str + ) -> str: + """Run command.""" + async with self.connect(username, private_key_file) as conn: + result = await conn.run(command) + assert result.stdout is not None + return str(result.stdout.rstrip()) + + @asynccontextmanager + async def connect( + self, username: str, private_key_file: Path | None + ) -> AsyncIterator[asyncssh.SSHClientConnection]: + """Connect to the server and yield a connection.""" + private_keys: list[str] = [] + if private_key_file: + private_keys.append(str(private_key_file.absolute())) + async with asyncssh.connect( + self.host, + port=self.port, + username=username, + known_hosts=None, + client_keys=private_keys, + ) as conn: + yield conn