Check in common module in working state
This commit is contained in:
28
src/sshecret/backend/__init__.py
Normal file
28
src/sshecret/backend/__init__.py
Normal file
@ -0,0 +1,28 @@
|
||||
"""SSHecret backend."""
|
||||
|
||||
from .api import SshecretBackend
|
||||
from .models import (
|
||||
AuditLog,
|
||||
Client,
|
||||
ClientFilter,
|
||||
ClientReference,
|
||||
ClientSecret,
|
||||
DetailedSecrets,
|
||||
Policy,
|
||||
Secret,
|
||||
FilterType,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AuditLog",
|
||||
"Client",
|
||||
"ClientFilter",
|
||||
"ClientReference",
|
||||
"ClientSecret",
|
||||
"DetailedSecrets",
|
||||
"FilterType",
|
||||
"Policy",
|
||||
"Secret",
|
||||
"SshecretBackend",
|
||||
]
|
||||
350
src/sshecret/backend/api.py
Normal file
350
src/sshecret/backend/api.py
Normal file
@ -0,0 +1,350 @@
|
||||
"""Backend client.
|
||||
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Self
|
||||
import httpx
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from .models import (
|
||||
AuditLog,
|
||||
Client,
|
||||
ClientSecret,
|
||||
ClientQueryResult,
|
||||
ClientFilter,
|
||||
DetailedSecrets,
|
||||
Secret,
|
||||
)
|
||||
from .exceptions import BackendValidationError, BackendConnectionError
|
||||
from .utils import validate_public_key
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClientQueryIterator:
|
||||
"""Asynchronous query iterator."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
filter_params: ClientFilter | None = None,
|
||||
) -> None:
|
||||
"""Create a query iterator."""
|
||||
self.client: httpx.AsyncClient = client
|
||||
if not filter_params:
|
||||
filter_params = ClientFilter()
|
||||
self.filter_params: ClientFilter = filter_params
|
||||
|
||||
self._result: ClientQueryResult | None = None
|
||||
self._clients: list[Client] = []
|
||||
self.offset: int = 0
|
||||
|
||||
async def _get_batch(self) -> None:
|
||||
"""Get batch."""
|
||||
params: dict[str, str | int] = {
|
||||
**self.filter_params.get_params(),
|
||||
"offset": self.offset,
|
||||
}
|
||||
try:
|
||||
results = await self.client.get("/api/v1/clients/", params=params)
|
||||
except httpx.TransportError as e:
|
||||
raise BackendConnectionError() from e
|
||||
if results.status_code != 200:
|
||||
raise BackendConnectionError()
|
||||
self._result = ClientQueryResult.model_validate(results.json())
|
||||
self._clients = self._result.clients
|
||||
|
||||
async def get_next_client(self) -> Client | None:
|
||||
"""Get next client.
|
||||
|
||||
When do we know when to stop?
|
||||
"""
|
||||
if not self._result:
|
||||
await self._get_batch()
|
||||
assert self._result is not None
|
||||
if self._result.total_results == self.offset:
|
||||
return None
|
||||
if not self._clients:
|
||||
await self._get_batch()
|
||||
|
||||
client = self._clients.pop(0)
|
||||
self.offset += 1
|
||||
return client
|
||||
|
||||
def __aiter__(self) -> Self:
|
||||
"""Iterate async."""
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> Client:
|
||||
"""Get next client."""
|
||||
if client := await self.get_next_client():
|
||||
return client
|
||||
raise StopAsyncIteration
|
||||
|
||||
|
||||
class SshecretBackend:
|
||||
"""Backend interface."""
|
||||
|
||||
def __init__(self, backend_url: str, api_token: str) -> None:
|
||||
"""Initialize backend client."""
|
||||
|
||||
url = httpx.URL(backend_url)
|
||||
|
||||
self.http_client: httpx.AsyncClient = httpx.AsyncClient(
|
||||
headers={"X-Api-Token": api_token},
|
||||
base_url=url,
|
||||
)
|
||||
self.sync_client: httpx.Client = httpx.Client(
|
||||
headers={"X-Api-Token": api_token},
|
||||
base_url=url,
|
||||
)
|
||||
|
||||
async def _get(self, path: str) -> httpx.Response:
|
||||
"""Perform a get request."""
|
||||
try:
|
||||
return await self.http_client.get(path)
|
||||
except httpx.ConnectError as e:
|
||||
raise BackendConnectionError() from e
|
||||
|
||||
async def _delete(self, path: str) -> httpx.Response:
|
||||
"""Perform a delete request."""
|
||||
try:
|
||||
return await self.http_client.delete(path)
|
||||
except httpx.ConnectError as e:
|
||||
raise BackendConnectionError() from e
|
||||
|
||||
async def _post(self, path: str, json: Any | None = None) -> httpx.Response:
|
||||
"""Perform a POST request."""
|
||||
try:
|
||||
return await self.http_client.post(path, json=json)
|
||||
except httpx.ConnectError as e:
|
||||
raise BackendConnectionError() from e
|
||||
|
||||
async def _put(self, path: str, json: Any | None = None) -> httpx.Response:
|
||||
"""Perform a PUT request."""
|
||||
try:
|
||||
return await self.http_client.put(path, json=json)
|
||||
except httpx.ConnectError as e:
|
||||
raise BackendConnectionError() from e
|
||||
|
||||
async def request(self, path: str) -> httpx.Response:
|
||||
"""Send a simple GET request."""
|
||||
response = await self._get(path)
|
||||
return response
|
||||
|
||||
async def create_client(
|
||||
self, name: str, public_key: str, description: str | None = None
|
||||
) -> None:
|
||||
"""Register a new client."""
|
||||
if not validate_public_key(public_key):
|
||||
raise BackendValidationError("Error: Invalid public key format.")
|
||||
data = {
|
||||
"name": name,
|
||||
"public_key": public_key,
|
||||
}
|
||||
if description:
|
||||
data["description"] = description
|
||||
path = "/api/v1/clients/"
|
||||
response = await self._post(path, json=data)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
async def get_clients(self, filter: ClientFilter | None = None) -> list[Client]:
|
||||
"""Get all clients."""
|
||||
clients: list[Client] = []
|
||||
async for client in ClientQueryIterator(self.http_client, filter):
|
||||
clients.append(client)
|
||||
|
||||
return clients
|
||||
|
||||
async def get_client(self, name: str) -> Client | None:
|
||||
"""Lookup a client on username."""
|
||||
path = f"/api/v1/clients/{name}"
|
||||
response = await self.request(path)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
response.raise_for_status()
|
||||
client = Client.model_validate(response.json())
|
||||
return client
|
||||
|
||||
async def get_client_by_id(self, id: str) -> Client | None:
|
||||
"""Lookup a client on username."""
|
||||
path = f"/api/v1/clients/id/{id}"
|
||||
response = await self.request(path)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
response.raise_for_status()
|
||||
client = Client.model_validate(response.json())
|
||||
return client
|
||||
|
||||
async def delete_client(self, client_name: str) -> None:
|
||||
"""Delete a client."""
|
||||
path = f"/api/v1/clients/{client_name}"
|
||||
response = await self._delete(path)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
async def delete_client_by_id(self, id: str) -> None:
|
||||
"""Delete a client."""
|
||||
path = f"/api/v1/clients/id/{id}"
|
||||
response = await self._delete(path)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
async def create_client_secret(
|
||||
self, client_name: str, secret_name: str, encrypted_secret: str
|
||||
) -> None:
|
||||
"""Create a secret.
|
||||
|
||||
This will overwrite any existing secret with that name.
|
||||
"""
|
||||
path = f"api/v1/clients/{client_name}/secrets/{secret_name}"
|
||||
response = await self._put(path, json={"value": encrypted_secret})
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
async def get_client_secret(self, name: str, secret_name: str) -> str:
|
||||
"""Fetch a secret."""
|
||||
path = f"/api/v1/clients/{name}/secrets/{secret_name}"
|
||||
response = await self.request(path)
|
||||
response.raise_for_status()
|
||||
secret = ClientSecret.model_validate(response.json())
|
||||
return secret.secret
|
||||
|
||||
async def delete_client_secret(self, client_name: str, secret_name: str) -> None:
|
||||
"""Delete a secret from a client."""
|
||||
path = f"api/v1/clients/{client_name}/secrets/{secret_name}"
|
||||
response = await self._delete(path)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
async def update_client(self, client: Client) -> Client:
|
||||
"""Update the client."""
|
||||
path = f"/api/v1/clients/{client.name}"
|
||||
client_update = {
|
||||
"name": client.name,
|
||||
"description": client.description,
|
||||
"public_key": client.public_key,
|
||||
}
|
||||
response = await self._put(path, json=client_update)
|
||||
LOG.info("Response %s", response.text)
|
||||
|
||||
response.raise_for_status()
|
||||
if client.policies:
|
||||
await self.update_client_sources(
|
||||
str(client.id), [str(source) for source in client.policies]
|
||||
)
|
||||
return client
|
||||
|
||||
async def update_client_key(self, client_name: str, public_key: str) -> None:
|
||||
"""Update the client key."""
|
||||
path = f"/api/v1/clients/{client_name}/public-key"
|
||||
response = await self._post(path, json={"public_key": public_key})
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
async def update_client_sources(
|
||||
self, client_name: str, addresses: list[str] | None
|
||||
) -> None:
|
||||
"""Update client source addresses.
|
||||
|
||||
Pass None to sources to allow from all.
|
||||
"""
|
||||
if not addresses:
|
||||
addresses = []
|
||||
|
||||
path = f"/api/v1/clients/{client_name}/policies/"
|
||||
response = await self._put(path, json={"sources": addresses})
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
async def get_detailed_secrets(self) -> list[DetailedSecrets]:
|
||||
"""Get detailed list of secrets."""
|
||||
path = "/api/v1/secrets/detailed/"
|
||||
response = await self._get(path)
|
||||
response.raise_for_status()
|
||||
|
||||
secret_list = TypeAdapter(list[DetailedSecrets])
|
||||
return secret_list.validate_python(response.json())
|
||||
|
||||
async def get_secrets(self) -> list[Secret]:
|
||||
"""Get Secrets.
|
||||
|
||||
This provides a list of secret names and which clients have them.
|
||||
"""
|
||||
path = "/api/v1/secrets/"
|
||||
response = await self._get(path)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
secret_list = TypeAdapter(list[Secret])
|
||||
return secret_list.validate_python(response.json())
|
||||
|
||||
async def get_secret(self, name: str) -> Secret | None:
|
||||
"""Get clients mapped to a single secret."""
|
||||
path = f"/api/v1/secrets/{name}"
|
||||
response = await self._get(path)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
response.raise_for_status()
|
||||
|
||||
return Secret.model_validate(response.json())
|
||||
|
||||
async def get_detailed_secret(self, name: str) -> DetailedSecrets | None:
|
||||
"""Get clients mapped to a single secret."""
|
||||
path = f"/api/v1/secrets/{name}/detailed"
|
||||
response = await self._get(path)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
response.raise_for_status()
|
||||
|
||||
return DetailedSecrets.model_validate(response.json())
|
||||
|
||||
async def get_audit_log(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 100,
|
||||
client_name: str | None = None,
|
||||
subsystem: str | None = None,
|
||||
) -> list[AuditLog]:
|
||||
"""Get audit log."""
|
||||
path = f"/api/v1/audit/"
|
||||
params: dict[str, str] = {
|
||||
"offset": str(offset),
|
||||
"limit": str(limit),
|
||||
}
|
||||
if client_name:
|
||||
params["filter_client"] = client_name
|
||||
|
||||
if subsystem:
|
||||
params["filter_subsystem"] = subsystem
|
||||
|
||||
response = await self.http_client.get(path, params=params)
|
||||
response.raise_for_status()
|
||||
audit_log_adapter = TypeAdapter(list[AuditLog])
|
||||
return audit_log_adapter.validate_python(response.json())
|
||||
|
||||
async def add_audit_log(self, entry: AuditLog) -> None:
|
||||
"""Add audit log entry."""
|
||||
path = f"/api/v1/audit/"
|
||||
|
||||
response = await self.http_client.post(path, json=entry.model_dump())
|
||||
response.raise_for_status()
|
||||
|
||||
async def get_audit_log_count(self) -> int:
|
||||
"""Get amount of messages in the audit log."""
|
||||
path = f"/api/v1/audit/info"
|
||||
response = await self._get(path)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return int(data["entries"])
|
||||
|
||||
def add_audit_log_sync(self, entry: AuditLog) -> None:
|
||||
"""Add audit log entry."""
|
||||
path = f"/api/v1/audit/"
|
||||
LOG.info("AUDIT LOG SYNC %r", entry)
|
||||
|
||||
response = self.sync_client.post(path, json=entry.model_dump())
|
||||
response.raise_for_status()
|
||||
8
src/sshecret/backend/exceptions.py
Normal file
8
src/sshecret/backend/exceptions.py
Normal file
@ -0,0 +1,8 @@
|
||||
"""Exceptions."""
|
||||
|
||||
class BackendValidationError(Exception):
|
||||
"""Validation error."""
|
||||
|
||||
|
||||
class BackendConnectionError(Exception):
|
||||
"""Could not connect to backend server."""
|
||||
115
src/sshecret/backend/models.py
Normal file
115
src/sshecret/backend/models.py
Normal file
@ -0,0 +1,115 @@
|
||||
"""Backend models."""
|
||||
|
||||
import enum
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import AfterValidator, BaseModel, IPvAnyAddress, IPvAnyNetwork
|
||||
|
||||
from sshecret.crypto import public_key_validator
|
||||
|
||||
|
||||
class FilterType(enum.StrEnum):
|
||||
"""Type of filter."""
|
||||
|
||||
LIKE = "like"
|
||||
CONTAINS = "contains"
|
||||
|
||||
|
||||
class Client(BaseModel):
|
||||
"""Implementation of the backend class ClientView."""
|
||||
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
description: str | None
|
||||
public_key: Annotated[str, AfterValidator(public_key_validator)]
|
||||
secrets: list[str]
|
||||
policies: list[IPvAnyNetwork | IPvAnyAddress]
|
||||
created_at: datetime
|
||||
updated_at: datetime | None
|
||||
|
||||
|
||||
class ClientQueryResult(BaseModel):
|
||||
"""Implementation of the backend ClientQueryResult class."""
|
||||
|
||||
clients: list[Client]
|
||||
total_results: int
|
||||
remaining_results: int
|
||||
|
||||
|
||||
class ClientSecret(BaseModel):
|
||||
"""Implementation of the backend class ClientSecretResponse."""
|
||||
|
||||
name: str
|
||||
secret: str
|
||||
description: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime | None
|
||||
|
||||
|
||||
class Secret(BaseModel):
|
||||
"""Implementation of the backend class ClientSecretList."""
|
||||
|
||||
name: str
|
||||
clients: list[str]
|
||||
|
||||
|
||||
class ClientReference(BaseModel):
|
||||
"""Implementation of the backend class ClientReference."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
class DetailedSecrets(BaseModel):
|
||||
"""Implementation of the backend class ClientSecretDetailList."""
|
||||
|
||||
name: str
|
||||
ids: list[str]
|
||||
clients: list[ClientReference]
|
||||
|
||||
|
||||
class Policy(BaseModel):
|
||||
"""Implementation of the backend class ClientPolicyView."""
|
||||
|
||||
sources: list[IPvAnyNetwork | IPvAnyAddress]
|
||||
|
||||
|
||||
class ClientFilter(BaseModel):
|
||||
"""Client filter."""
|
||||
|
||||
id: str | None = None
|
||||
name: str | None = None
|
||||
filter_name: FilterType | None = None
|
||||
|
||||
def get_params(self) -> dict[str, str]:
|
||||
"""Render query parameters."""
|
||||
params: dict[str, str] = {}
|
||||
if not self.id and not self.name:
|
||||
return params
|
||||
if self.id:
|
||||
params["id"] = self.id
|
||||
if self.name and not self.filter_name:
|
||||
params["name"] = self.name
|
||||
elif self.name and self.filter_name is FilterType.LIKE:
|
||||
params["name__like"] = self.name
|
||||
elif self.name and self.filter_name is FilterType.CONTAINS:
|
||||
params["name__contains"] = self.name
|
||||
|
||||
return params
|
||||
|
||||
|
||||
class AuditLog(BaseModel):
|
||||
"""Implementation of the backend class AuditLog."""
|
||||
|
||||
id: str | None = None
|
||||
subsystem: str | None = None
|
||||
object: str | None = None
|
||||
object_id: str | None = None
|
||||
operation: str
|
||||
client_id: str | None = None
|
||||
client_name: str | None = None
|
||||
message: str
|
||||
origin: str | None = None
|
||||
timestamp: datetime | None = None
|
||||
37
src/sshecret/backend/utils.py
Normal file
37
src/sshecret/backend/utils.py
Normal file
@ -0,0 +1,37 @@
|
||||
"""Utility functions."""
|
||||
|
||||
import logging
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_public_key(key: bytes | str) -> rsa.RSAPublicKey:
|
||||
if isinstance(key, str):
|
||||
key = key.encode()
|
||||
public_key = serialization.load_ssh_public_key(key)
|
||||
if not isinstance(public_key, rsa.RSAPublicKey):
|
||||
raise ValueError("Only RSA keys are supported.")
|
||||
return public_key
|
||||
|
||||
|
||||
def validate_public_key(key: str) -> bool:
|
||||
"""Check if key provided in a string is valid."""
|
||||
valid = False
|
||||
public_key: rsa.RSAPublicKey | None = None
|
||||
try:
|
||||
keybytes = key.encode()
|
||||
public_key = load_public_key(keybytes)
|
||||
except Exception as e:
|
||||
LOG.debug("Validation of public key failed: %s", e, exc_info=True)
|
||||
else:
|
||||
valid = True
|
||||
|
||||
if not isinstance(public_key, rsa.RSAPublicKey):
|
||||
valid = False
|
||||
|
||||
if not valid:
|
||||
raise ValueError("Invalid public key. Must an OpenSSH RSA public key.")
|
||||
return valid
|
||||
137
src/sshecret/crypto.py
Normal file
137
src/sshecret/crypto.py
Normal file
@ -0,0 +1,137 @@
|
||||
"""Encryption related functions.
|
||||
|
||||
Note! Encryption uses the less secure PKCS1v15 padding. This is to allow
|
||||
decryption via openssl on the command line.
|
||||
|
||||
"""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.primitives.asymmetric import padding
|
||||
|
||||
|
||||
RSA_PUBLIC_EXPONENT = 65537
|
||||
RSA_KEY_SIZE = 2048
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_public_key(keybytes: bytes) -> rsa.RSAPublicKey:
|
||||
public_key = serialization.load_ssh_public_key(keybytes)
|
||||
if not isinstance(public_key, rsa.RSAPublicKey):
|
||||
raise RuntimeError("Only RSA keys are supported.")
|
||||
return public_key
|
||||
|
||||
|
||||
def public_key_validator(public_key: str) -> str:
|
||||
"""Validator for public key."""
|
||||
if validate_public_key(public_key):
|
||||
return public_key
|
||||
|
||||
raise ValueError("Error: Public key is not a valid SSH RSA Public Key.")
|
||||
|
||||
|
||||
def validate_public_key(key: str) -> bool:
|
||||
"""Check if key provided in a string is valid."""
|
||||
valid = False
|
||||
public_key: rsa.RSAPublicKey | None = None
|
||||
try:
|
||||
keybytes = key.encode()
|
||||
public_key = load_public_key(keybytes)
|
||||
except Exception as e:
|
||||
LOG.debug("Validation of public key failed: %s", e, exc_info=True)
|
||||
else:
|
||||
valid = True
|
||||
|
||||
if not isinstance(public_key, rsa.RSAPublicKey):
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
|
||||
def load_private_key(filename: str, password: str | None = None) -> rsa.RSAPrivateKey:
|
||||
"""Load a private key."""
|
||||
password_bytes: bytes | None = None
|
||||
if password:
|
||||
password_bytes = password.encode()
|
||||
with open(filename, "rb") as f:
|
||||
private_key = serialization.load_ssh_private_key(
|
||||
f.read(), password=password_bytes
|
||||
)
|
||||
if not isinstance(private_key, rsa.RSAPrivateKey):
|
||||
raise RuntimeError("Only RSA keys are supported.")
|
||||
return private_key
|
||||
|
||||
|
||||
def encrypt_string(string: str, public_key: rsa.RSAPublicKey) -> str:
|
||||
"""Encrypt string, end return it base64 encoded."""
|
||||
message = string.encode()
|
||||
ciphertext = public_key.encrypt(
|
||||
message,
|
||||
padding.PKCS1v15(),
|
||||
)
|
||||
return base64.b64encode(ciphertext).decode()
|
||||
|
||||
|
||||
def decode_string(ciphertext: str, private_key: rsa.RSAPrivateKey) -> str:
|
||||
"""Decode a string. String must be base64 encoded."""
|
||||
decoded = base64.b64decode(ciphertext)
|
||||
decrypted = private_key.decrypt(
|
||||
decoded,
|
||||
padding.PKCS1v15(),
|
||||
)
|
||||
return decrypted.decode()
|
||||
|
||||
|
||||
def generate_private_key() -> rsa.RSAPrivateKey:
|
||||
"""Generate private RSA key."""
|
||||
private_key = rsa.generate_private_key(
|
||||
public_exponent=RSA_PUBLIC_EXPONENT, key_size=RSA_KEY_SIZE
|
||||
)
|
||||
return private_key
|
||||
|
||||
|
||||
def generate_pem(private_key: rsa.RSAPrivateKey) -> str:
|
||||
"""Generate PEM."""
|
||||
pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.OpenSSH,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
return pem.decode()
|
||||
|
||||
|
||||
def create_private_rsa_key(filename: Path, password: str | None = None) -> None:
|
||||
"""Create an RSA Private key at the given path.
|
||||
|
||||
A password may be provided for secure storage.
|
||||
"""
|
||||
if filename.exists():
|
||||
raise RuntimeError("Error: private key file already exists.")
|
||||
LOG.debug("Generating private RSA key at %s", filename)
|
||||
private_key = generate_private_key()
|
||||
encryption_algorithm = serialization.NoEncryption()
|
||||
if password:
|
||||
password_bytes = password.encode()
|
||||
encryption_algorithm = serialization.BestAvailableEncryption(password_bytes)
|
||||
with open(filename, "wb") as f:
|
||||
pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.OpenSSH,
|
||||
encryption_algorithm=encryption_algorithm,
|
||||
)
|
||||
lines = f.write(pem)
|
||||
LOG.debug("Wrote %s lines", lines)
|
||||
f.flush()
|
||||
|
||||
|
||||
def generate_public_key_string(public_key: rsa.RSAPublicKey) -> str:
|
||||
"""Generate public key string."""
|
||||
keybytes = public_key.public_bytes(
|
||||
encoding=serialization.Encoding.OpenSSH,
|
||||
format=serialization.PublicFormat.OpenSSH,
|
||||
)
|
||||
return keybytes.decode()
|
||||
57
src/sshecret/settings.py
Normal file
57
src/sshecret/settings.py
Normal file
@ -0,0 +1,57 @@
|
||||
"""Settings management."""
|
||||
|
||||
from typing import Literal
|
||||
from pydantic import AnyHttpUrl, BaseModel, Field
|
||||
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
DEFAULT_SSH_PORT = 2222
|
||||
DEFAULT_BACKEND_PORT = 8822
|
||||
DEFAULT_ADMIN_PORT = 8022
|
||||
|
||||
|
||||
class SshdSettings(BaseModel):
|
||||
"""Settings for the SSHd server."""
|
||||
|
||||
name: Literal["sshd"]
|
||||
listen_address: str = ""
|
||||
backend_token: str
|
||||
port: int = DEFAULT_SSH_PORT
|
||||
|
||||
|
||||
class AdminSettings(BaseModel):
|
||||
"""Settings for the Admin module."""
|
||||
|
||||
name: Literal["admin"]
|
||||
listen_address: str = ""
|
||||
backend_token: str
|
||||
port: int = DEFAULT_ADMIN_PORT
|
||||
|
||||
|
||||
class BackendSettings(BaseModel):
|
||||
"""Settings for the backend server."""
|
||||
|
||||
name: Literal["backend"]
|
||||
listen_address: str = ""
|
||||
backend_token: str
|
||||
port: int = DEFAULT_BACKEND_PORT
|
||||
|
||||
|
||||
class SshecretSettings(BaseSettings):
|
||||
"""General settings model.
|
||||
|
||||
Should probably be subclassed.
|
||||
"""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_prefix="sshecret_",
|
||||
env_nested_delimiter="__",
|
||||
nested_model_default_partial_update=True,
|
||||
)
|
||||
|
||||
backend_url: AnyHttpUrl
|
||||
# This is set up deliberately so only one app can run at the time.
|
||||
application: SshdSettings | AdminSettings | BackendSettings = Field(
|
||||
discriminator="name"
|
||||
)
|
||||
97
src/sshecret/testing.py
Normal file
97
src/sshecret/testing.py
Normal file
@ -0,0 +1,97 @@
|
||||
"""Testing classes."""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import asyncssh
|
||||
import httpx
|
||||
|
||||
|
||||
@dataclass
|
||||
class BackendContext:
|
||||
"""Backend context."""
|
||||
|
||||
url: str
|
||||
token: str
|
||||
|
||||
@property
|
||||
def headers(self) -> dict[str, str]:
|
||||
"""Get headers."""
|
||||
return {
|
||||
"X-API-Token": self.token,
|
||||
}
|
||||
|
||||
def test_client(self) -> httpx.AsyncClient:
|
||||
"""Create a test client."""
|
||||
return httpx.AsyncClient(base_url=self.url, headers=self.headers)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminContext:
|
||||
"""Admin context."""
|
||||
|
||||
url: str
|
||||
username: str = "test"
|
||||
password: str = "test"
|
||||
jwt_token: str | None = field(init=False)
|
||||
|
||||
async def login(self) -> None:
|
||||
"""Login to the application."""
|
||||
client = httpx.AsyncClient(base_url=self.url)
|
||||
response = await client.post(
|
||||
"/api/v1/token", data={"username": self.username, "password": self.password}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
token_data = response.json()
|
||||
access_token = cast(str, token_data["access_token"])
|
||||
self.jwt_token = str(access_token)
|
||||
|
||||
@property
|
||||
def headers(self) -> dict[str, str]:
|
||||
"""Return headers."""
|
||||
headers: dict[str, str] = {}
|
||||
if self.jwt_token:
|
||||
headers = {"Authorization": f"Bearer {self.jwt_token}"}
|
||||
|
||||
return headers
|
||||
|
||||
def test_client(self) -> httpx.AsyncClient:
|
||||
"""Create a test client."""
|
||||
return httpx.AsyncClient(base_url=self.url, headers=self.headers)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SshServerContext:
|
||||
"""SSH server context."""
|
||||
|
||||
host: str
|
||||
port: int
|
||||
|
||||
async def run_command(
|
||||
self, username: str, private_key_file: Path | None, command: str
|
||||
) -> str:
|
||||
"""Run command."""
|
||||
async with self.connect(username, private_key_file) as conn:
|
||||
result = await conn.run(command)
|
||||
assert result.stdout is not None
|
||||
return str(result.stdout.rstrip())
|
||||
|
||||
@asynccontextmanager
|
||||
async def connect(
|
||||
self, username: str, private_key_file: Path | None
|
||||
) -> AsyncIterator[asyncssh.SSHClientConnection]:
|
||||
"""Connect to the server and yield a connection."""
|
||||
private_keys: list[str] = []
|
||||
if private_key_file:
|
||||
private_keys.append(str(private_key_file.absolute()))
|
||||
async with asyncssh.connect(
|
||||
self.host,
|
||||
port=self.port,
|
||||
username=username,
|
||||
known_hosts=None,
|
||||
client_keys=private_keys,
|
||||
) as conn:
|
||||
yield conn
|
||||
Reference in New Issue
Block a user