Check in common module in working state

This commit is contained in:
2025-04-30 08:24:20 +02:00
parent 20f1ee707a
commit f4dd5e8e0c
8 changed files with 829 additions and 0 deletions

View File

@ -0,0 +1,28 @@
"""SSHecret backend."""
from .api import SshecretBackend
from .models import (
AuditLog,
Client,
ClientFilter,
ClientReference,
ClientSecret,
DetailedSecrets,
Policy,
Secret,
FilterType,
)
__all__ = [
"AuditLog",
"Client",
"ClientFilter",
"ClientReference",
"ClientSecret",
"DetailedSecrets",
"FilterType",
"Policy",
"Secret",
"SshecretBackend",
]

350
src/sshecret/backend/api.py Normal file
View File

@ -0,0 +1,350 @@
"""Backend client.
"""
import logging
from typing import Any, Self
import httpx
from pydantic import TypeAdapter
from .models import (
AuditLog,
Client,
ClientSecret,
ClientQueryResult,
ClientFilter,
DetailedSecrets,
Secret,
)
from .exceptions import BackendValidationError, BackendConnectionError
from .utils import validate_public_key
LOG = logging.getLogger(__name__)
class ClientQueryIterator:
"""Asynchronous query iterator."""
def __init__(
self,
client: httpx.AsyncClient,
filter_params: ClientFilter | None = None,
) -> None:
"""Create a query iterator."""
self.client: httpx.AsyncClient = client
if not filter_params:
filter_params = ClientFilter()
self.filter_params: ClientFilter = filter_params
self._result: ClientQueryResult | None = None
self._clients: list[Client] = []
self.offset: int = 0
async def _get_batch(self) -> None:
"""Get batch."""
params: dict[str, str | int] = {
**self.filter_params.get_params(),
"offset": self.offset,
}
try:
results = await self.client.get("/api/v1/clients/", params=params)
except httpx.TransportError as e:
raise BackendConnectionError() from e
if results.status_code != 200:
raise BackendConnectionError()
self._result = ClientQueryResult.model_validate(results.json())
self._clients = self._result.clients
async def get_next_client(self) -> Client | None:
"""Get next client.
When do we know when to stop?
"""
if not self._result:
await self._get_batch()
assert self._result is not None
if self._result.total_results == self.offset:
return None
if not self._clients:
await self._get_batch()
client = self._clients.pop(0)
self.offset += 1
return client
def __aiter__(self) -> Self:
"""Iterate async."""
return self
async def __anext__(self) -> Client:
"""Get next client."""
if client := await self.get_next_client():
return client
raise StopAsyncIteration
class SshecretBackend:
"""Backend interface."""
def __init__(self, backend_url: str, api_token: str) -> None:
"""Initialize backend client."""
url = httpx.URL(backend_url)
self.http_client: httpx.AsyncClient = httpx.AsyncClient(
headers={"X-Api-Token": api_token},
base_url=url,
)
self.sync_client: httpx.Client = httpx.Client(
headers={"X-Api-Token": api_token},
base_url=url,
)
async def _get(self, path: str) -> httpx.Response:
"""Perform a get request."""
try:
return await self.http_client.get(path)
except httpx.ConnectError as e:
raise BackendConnectionError() from e
async def _delete(self, path: str) -> httpx.Response:
"""Perform a delete request."""
try:
return await self.http_client.delete(path)
except httpx.ConnectError as e:
raise BackendConnectionError() from e
async def _post(self, path: str, json: Any | None = None) -> httpx.Response:
"""Perform a POST request."""
try:
return await self.http_client.post(path, json=json)
except httpx.ConnectError as e:
raise BackendConnectionError() from e
async def _put(self, path: str, json: Any | None = None) -> httpx.Response:
"""Perform a PUT request."""
try:
return await self.http_client.put(path, json=json)
except httpx.ConnectError as e:
raise BackendConnectionError() from e
async def request(self, path: str) -> httpx.Response:
"""Send a simple GET request."""
response = await self._get(path)
return response
async def create_client(
self, name: str, public_key: str, description: str | None = None
) -> None:
"""Register a new client."""
if not validate_public_key(public_key):
raise BackendValidationError("Error: Invalid public key format.")
data = {
"name": name,
"public_key": public_key,
}
if description:
data["description"] = description
path = "/api/v1/clients/"
response = await self._post(path, json=data)
response.raise_for_status()
async def get_clients(self, filter: ClientFilter | None = None) -> list[Client]:
"""Get all clients."""
clients: list[Client] = []
async for client in ClientQueryIterator(self.http_client, filter):
clients.append(client)
return clients
async def get_client(self, name: str) -> Client | None:
"""Lookup a client on username."""
path = f"/api/v1/clients/{name}"
response = await self.request(path)
if response.status_code == 404:
return None
response.raise_for_status()
client = Client.model_validate(response.json())
return client
async def get_client_by_id(self, id: str) -> Client | None:
"""Lookup a client on username."""
path = f"/api/v1/clients/id/{id}"
response = await self.request(path)
if response.status_code == 404:
return None
response.raise_for_status()
client = Client.model_validate(response.json())
return client
async def delete_client(self, client_name: str) -> None:
"""Delete a client."""
path = f"/api/v1/clients/{client_name}"
response = await self._delete(path)
response.raise_for_status()
async def delete_client_by_id(self, id: str) -> None:
"""Delete a client."""
path = f"/api/v1/clients/id/{id}"
response = await self._delete(path)
response.raise_for_status()
async def create_client_secret(
self, client_name: str, secret_name: str, encrypted_secret: str
) -> None:
"""Create a secret.
This will overwrite any existing secret with that name.
"""
path = f"api/v1/clients/{client_name}/secrets/{secret_name}"
response = await self._put(path, json={"value": encrypted_secret})
response.raise_for_status()
async def get_client_secret(self, name: str, secret_name: str) -> str:
"""Fetch a secret."""
path = f"/api/v1/clients/{name}/secrets/{secret_name}"
response = await self.request(path)
response.raise_for_status()
secret = ClientSecret.model_validate(response.json())
return secret.secret
async def delete_client_secret(self, client_name: str, secret_name: str) -> None:
"""Delete a secret from a client."""
path = f"api/v1/clients/{client_name}/secrets/{secret_name}"
response = await self._delete(path)
response.raise_for_status()
async def update_client(self, client: Client) -> Client:
"""Update the client."""
path = f"/api/v1/clients/{client.name}"
client_update = {
"name": client.name,
"description": client.description,
"public_key": client.public_key,
}
response = await self._put(path, json=client_update)
LOG.info("Response %s", response.text)
response.raise_for_status()
if client.policies:
await self.update_client_sources(
str(client.id), [str(source) for source in client.policies]
)
return client
async def update_client_key(self, client_name: str, public_key: str) -> None:
"""Update the client key."""
path = f"/api/v1/clients/{client_name}/public-key"
response = await self._post(path, json={"public_key": public_key})
response.raise_for_status()
async def update_client_sources(
self, client_name: str, addresses: list[str] | None
) -> None:
"""Update client source addresses.
Pass None to sources to allow from all.
"""
if not addresses:
addresses = []
path = f"/api/v1/clients/{client_name}/policies/"
response = await self._put(path, json={"sources": addresses})
response.raise_for_status()
async def get_detailed_secrets(self) -> list[DetailedSecrets]:
"""Get detailed list of secrets."""
path = "/api/v1/secrets/detailed/"
response = await self._get(path)
response.raise_for_status()
secret_list = TypeAdapter(list[DetailedSecrets])
return secret_list.validate_python(response.json())
async def get_secrets(self) -> list[Secret]:
"""Get Secrets.
This provides a list of secret names and which clients have them.
"""
path = "/api/v1/secrets/"
response = await self._get(path)
response.raise_for_status()
secret_list = TypeAdapter(list[Secret])
return secret_list.validate_python(response.json())
async def get_secret(self, name: str) -> Secret | None:
"""Get clients mapped to a single secret."""
path = f"/api/v1/secrets/{name}"
response = await self._get(path)
if response.status_code == 404:
return None
response.raise_for_status()
return Secret.model_validate(response.json())
async def get_detailed_secret(self, name: str) -> DetailedSecrets | None:
"""Get clients mapped to a single secret."""
path = f"/api/v1/secrets/{name}/detailed"
response = await self._get(path)
if response.status_code == 404:
return None
response.raise_for_status()
return DetailedSecrets.model_validate(response.json())
async def get_audit_log(
self,
offset: int = 0,
limit: int = 100,
client_name: str | None = None,
subsystem: str | None = None,
) -> list[AuditLog]:
"""Get audit log."""
path = f"/api/v1/audit/"
params: dict[str, str] = {
"offset": str(offset),
"limit": str(limit),
}
if client_name:
params["filter_client"] = client_name
if subsystem:
params["filter_subsystem"] = subsystem
response = await self.http_client.get(path, params=params)
response.raise_for_status()
audit_log_adapter = TypeAdapter(list[AuditLog])
return audit_log_adapter.validate_python(response.json())
async def add_audit_log(self, entry: AuditLog) -> None:
"""Add audit log entry."""
path = f"/api/v1/audit/"
response = await self.http_client.post(path, json=entry.model_dump())
response.raise_for_status()
async def get_audit_log_count(self) -> int:
"""Get amount of messages in the audit log."""
path = f"/api/v1/audit/info"
response = await self._get(path)
response.raise_for_status()
data = response.json()
return int(data["entries"])
def add_audit_log_sync(self, entry: AuditLog) -> None:
"""Add audit log entry."""
path = f"/api/v1/audit/"
LOG.info("AUDIT LOG SYNC %r", entry)
response = self.sync_client.post(path, json=entry.model_dump())
response.raise_for_status()

View File

@ -0,0 +1,8 @@
"""Exceptions."""
class BackendValidationError(Exception):
"""Validation error."""
class BackendConnectionError(Exception):
"""Could not connect to backend server."""

View File

@ -0,0 +1,115 @@
"""Backend models."""
import enum
import uuid
from datetime import datetime
from typing import Annotated
from pydantic import AfterValidator, BaseModel, IPvAnyAddress, IPvAnyNetwork
from sshecret.crypto import public_key_validator
class FilterType(enum.StrEnum):
"""Type of filter."""
LIKE = "like"
CONTAINS = "contains"
class Client(BaseModel):
"""Implementation of the backend class ClientView."""
id: uuid.UUID
name: str
description: str | None
public_key: Annotated[str, AfterValidator(public_key_validator)]
secrets: list[str]
policies: list[IPvAnyNetwork | IPvAnyAddress]
created_at: datetime
updated_at: datetime | None
class ClientQueryResult(BaseModel):
"""Implementation of the backend ClientQueryResult class."""
clients: list[Client]
total_results: int
remaining_results: int
class ClientSecret(BaseModel):
"""Implementation of the backend class ClientSecretResponse."""
name: str
secret: str
description: str | None
created_at: datetime
updated_at: datetime | None
class Secret(BaseModel):
"""Implementation of the backend class ClientSecretList."""
name: str
clients: list[str]
class ClientReference(BaseModel):
"""Implementation of the backend class ClientReference."""
id: str
name: str
class DetailedSecrets(BaseModel):
"""Implementation of the backend class ClientSecretDetailList."""
name: str
ids: list[str]
clients: list[ClientReference]
class Policy(BaseModel):
"""Implementation of the backend class ClientPolicyView."""
sources: list[IPvAnyNetwork | IPvAnyAddress]
class ClientFilter(BaseModel):
"""Client filter."""
id: str | None = None
name: str | None = None
filter_name: FilterType | None = None
def get_params(self) -> dict[str, str]:
"""Render query parameters."""
params: dict[str, str] = {}
if not self.id and not self.name:
return params
if self.id:
params["id"] = self.id
if self.name and not self.filter_name:
params["name"] = self.name
elif self.name and self.filter_name is FilterType.LIKE:
params["name__like"] = self.name
elif self.name and self.filter_name is FilterType.CONTAINS:
params["name__contains"] = self.name
return params
class AuditLog(BaseModel):
"""Implementation of the backend class AuditLog."""
id: str | None = None
subsystem: str | None = None
object: str | None = None
object_id: str | None = None
operation: str
client_id: str | None = None
client_name: str | None = None
message: str
origin: str | None = None
timestamp: datetime | None = None

View File

@ -0,0 +1,37 @@
"""Utility functions."""
import logging
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
LOG = logging.getLogger(__name__)
def load_public_key(key: bytes | str) -> rsa.RSAPublicKey:
if isinstance(key, str):
key = key.encode()
public_key = serialization.load_ssh_public_key(key)
if not isinstance(public_key, rsa.RSAPublicKey):
raise ValueError("Only RSA keys are supported.")
return public_key
def validate_public_key(key: str) -> bool:
"""Check if key provided in a string is valid."""
valid = False
public_key: rsa.RSAPublicKey | None = None
try:
keybytes = key.encode()
public_key = load_public_key(keybytes)
except Exception as e:
LOG.debug("Validation of public key failed: %s", e, exc_info=True)
else:
valid = True
if not isinstance(public_key, rsa.RSAPublicKey):
valid = False
if not valid:
raise ValueError("Invalid public key. Must an OpenSSH RSA public key.")
return valid

137
src/sshecret/crypto.py Normal file
View File

@ -0,0 +1,137 @@
"""Encryption related functions.
Note! Encryption uses the less secure PKCS1v15 padding. This is to allow
decryption via openssl on the command line.
"""
import base64
import logging
from pathlib import Path
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric import padding
RSA_PUBLIC_EXPONENT = 65537
RSA_KEY_SIZE = 2048
LOG = logging.getLogger(__name__)
def load_public_key(keybytes: bytes) -> rsa.RSAPublicKey:
public_key = serialization.load_ssh_public_key(keybytes)
if not isinstance(public_key, rsa.RSAPublicKey):
raise RuntimeError("Only RSA keys are supported.")
return public_key
def public_key_validator(public_key: str) -> str:
"""Validator for public key."""
if validate_public_key(public_key):
return public_key
raise ValueError("Error: Public key is not a valid SSH RSA Public Key.")
def validate_public_key(key: str) -> bool:
"""Check if key provided in a string is valid."""
valid = False
public_key: rsa.RSAPublicKey | None = None
try:
keybytes = key.encode()
public_key = load_public_key(keybytes)
except Exception as e:
LOG.debug("Validation of public key failed: %s", e, exc_info=True)
else:
valid = True
if not isinstance(public_key, rsa.RSAPublicKey):
valid = False
return valid
def load_private_key(filename: str, password: str | None = None) -> rsa.RSAPrivateKey:
"""Load a private key."""
password_bytes: bytes | None = None
if password:
password_bytes = password.encode()
with open(filename, "rb") as f:
private_key = serialization.load_ssh_private_key(
f.read(), password=password_bytes
)
if not isinstance(private_key, rsa.RSAPrivateKey):
raise RuntimeError("Only RSA keys are supported.")
return private_key
def encrypt_string(string: str, public_key: rsa.RSAPublicKey) -> str:
"""Encrypt string, end return it base64 encoded."""
message = string.encode()
ciphertext = public_key.encrypt(
message,
padding.PKCS1v15(),
)
return base64.b64encode(ciphertext).decode()
def decode_string(ciphertext: str, private_key: rsa.RSAPrivateKey) -> str:
"""Decode a string. String must be base64 encoded."""
decoded = base64.b64decode(ciphertext)
decrypted = private_key.decrypt(
decoded,
padding.PKCS1v15(),
)
return decrypted.decode()
def generate_private_key() -> rsa.RSAPrivateKey:
"""Generate private RSA key."""
private_key = rsa.generate_private_key(
public_exponent=RSA_PUBLIC_EXPONENT, key_size=RSA_KEY_SIZE
)
return private_key
def generate_pem(private_key: rsa.RSAPrivateKey) -> str:
"""Generate PEM."""
pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.OpenSSH,
encryption_algorithm=serialization.NoEncryption(),
)
return pem.decode()
def create_private_rsa_key(filename: Path, password: str | None = None) -> None:
"""Create an RSA Private key at the given path.
A password may be provided for secure storage.
"""
if filename.exists():
raise RuntimeError("Error: private key file already exists.")
LOG.debug("Generating private RSA key at %s", filename)
private_key = generate_private_key()
encryption_algorithm = serialization.NoEncryption()
if password:
password_bytes = password.encode()
encryption_algorithm = serialization.BestAvailableEncryption(password_bytes)
with open(filename, "wb") as f:
pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.OpenSSH,
encryption_algorithm=encryption_algorithm,
)
lines = f.write(pem)
LOG.debug("Wrote %s lines", lines)
f.flush()
def generate_public_key_string(public_key: rsa.RSAPublicKey) -> str:
"""Generate public key string."""
keybytes = public_key.public_bytes(
encoding=serialization.Encoding.OpenSSH,
format=serialization.PublicFormat.OpenSSH,
)
return keybytes.decode()

57
src/sshecret/settings.py Normal file
View File

@ -0,0 +1,57 @@
"""Settings management."""
from typing import Literal
from pydantic import AnyHttpUrl, BaseModel, Field
from pydantic_settings import BaseSettings, SettingsConfigDict
DEFAULT_SSH_PORT = 2222
DEFAULT_BACKEND_PORT = 8822
DEFAULT_ADMIN_PORT = 8022
class SshdSettings(BaseModel):
"""Settings for the SSHd server."""
name: Literal["sshd"]
listen_address: str = ""
backend_token: str
port: int = DEFAULT_SSH_PORT
class AdminSettings(BaseModel):
"""Settings for the Admin module."""
name: Literal["admin"]
listen_address: str = ""
backend_token: str
port: int = DEFAULT_ADMIN_PORT
class BackendSettings(BaseModel):
"""Settings for the backend server."""
name: Literal["backend"]
listen_address: str = ""
backend_token: str
port: int = DEFAULT_BACKEND_PORT
class SshecretSettings(BaseSettings):
"""General settings model.
Should probably be subclassed.
"""
model_config = SettingsConfigDict(
env_prefix="sshecret_",
env_nested_delimiter="__",
nested_model_default_partial_update=True,
)
backend_url: AnyHttpUrl
# This is set up deliberately so only one app can run at the time.
application: SshdSettings | AdminSettings | BackendSettings = Field(
discriminator="name"
)

97
src/sshecret/testing.py Normal file
View File

@ -0,0 +1,97 @@
"""Testing classes."""
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from pathlib import Path
from typing import cast
import asyncssh
import httpx
@dataclass
class BackendContext:
"""Backend context."""
url: str
token: str
@property
def headers(self) -> dict[str, str]:
"""Get headers."""
return {
"X-API-Token": self.token,
}
def test_client(self) -> httpx.AsyncClient:
"""Create a test client."""
return httpx.AsyncClient(base_url=self.url, headers=self.headers)
@dataclass
class AdminContext:
"""Admin context."""
url: str
username: str = "test"
password: str = "test"
jwt_token: str | None = field(init=False)
async def login(self) -> None:
"""Login to the application."""
client = httpx.AsyncClient(base_url=self.url)
response = await client.post(
"/api/v1/token", data={"username": self.username, "password": self.password}
)
assert response.status_code == 200
token_data = response.json()
access_token = cast(str, token_data["access_token"])
self.jwt_token = str(access_token)
@property
def headers(self) -> dict[str, str]:
"""Return headers."""
headers: dict[str, str] = {}
if self.jwt_token:
headers = {"Authorization": f"Bearer {self.jwt_token}"}
return headers
def test_client(self) -> httpx.AsyncClient:
"""Create a test client."""
return httpx.AsyncClient(base_url=self.url, headers=self.headers)
@dataclass
class SshServerContext:
"""SSH server context."""
host: str
port: int
async def run_command(
self, username: str, private_key_file: Path | None, command: str
) -> str:
"""Run command."""
async with self.connect(username, private_key_file) as conn:
result = await conn.run(command)
assert result.stdout is not None
return str(result.stdout.rstrip())
@asynccontextmanager
async def connect(
self, username: str, private_key_file: Path | None
) -> AsyncIterator[asyncssh.SSHClientConnection]:
"""Connect to the server and yield a connection."""
private_keys: list[str] = []
if private_key_file:
private_keys.append(str(private_key_file.absolute()))
async with asyncssh.connect(
self.host,
port=self.port,
username=username,
known_hosts=None,
client_keys=private_keys,
) as conn:
yield conn