Check in common module in working state
This commit is contained in:
28
src/sshecret/backend/__init__.py
Normal file
28
src/sshecret/backend/__init__.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
"""SSHecret backend."""
|
||||||
|
|
||||||
|
from .api import SshecretBackend
|
||||||
|
from .models import (
|
||||||
|
AuditLog,
|
||||||
|
Client,
|
||||||
|
ClientFilter,
|
||||||
|
ClientReference,
|
||||||
|
ClientSecret,
|
||||||
|
DetailedSecrets,
|
||||||
|
Policy,
|
||||||
|
Secret,
|
||||||
|
FilterType,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AuditLog",
|
||||||
|
"Client",
|
||||||
|
"ClientFilter",
|
||||||
|
"ClientReference",
|
||||||
|
"ClientSecret",
|
||||||
|
"DetailedSecrets",
|
||||||
|
"FilterType",
|
||||||
|
"Policy",
|
||||||
|
"Secret",
|
||||||
|
"SshecretBackend",
|
||||||
|
]
|
||||||
350
src/sshecret/backend/api.py
Normal file
350
src/sshecret/backend/api.py
Normal file
@ -0,0 +1,350 @@
|
|||||||
|
"""Backend client.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Self
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
|
from .models import (
|
||||||
|
AuditLog,
|
||||||
|
Client,
|
||||||
|
ClientSecret,
|
||||||
|
ClientQueryResult,
|
||||||
|
ClientFilter,
|
||||||
|
DetailedSecrets,
|
||||||
|
Secret,
|
||||||
|
)
|
||||||
|
from .exceptions import BackendValidationError, BackendConnectionError
|
||||||
|
from .utils import validate_public_key
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ClientQueryIterator:
|
||||||
|
"""Asynchronous query iterator."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
filter_params: ClientFilter | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Create a query iterator."""
|
||||||
|
self.client: httpx.AsyncClient = client
|
||||||
|
if not filter_params:
|
||||||
|
filter_params = ClientFilter()
|
||||||
|
self.filter_params: ClientFilter = filter_params
|
||||||
|
|
||||||
|
self._result: ClientQueryResult | None = None
|
||||||
|
self._clients: list[Client] = []
|
||||||
|
self.offset: int = 0
|
||||||
|
|
||||||
|
async def _get_batch(self) -> None:
|
||||||
|
"""Get batch."""
|
||||||
|
params: dict[str, str | int] = {
|
||||||
|
**self.filter_params.get_params(),
|
||||||
|
"offset": self.offset,
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
results = await self.client.get("/api/v1/clients/", params=params)
|
||||||
|
except httpx.TransportError as e:
|
||||||
|
raise BackendConnectionError() from e
|
||||||
|
if results.status_code != 200:
|
||||||
|
raise BackendConnectionError()
|
||||||
|
self._result = ClientQueryResult.model_validate(results.json())
|
||||||
|
self._clients = self._result.clients
|
||||||
|
|
||||||
|
async def get_next_client(self) -> Client | None:
|
||||||
|
"""Get next client.
|
||||||
|
|
||||||
|
When do we know when to stop?
|
||||||
|
"""
|
||||||
|
if not self._result:
|
||||||
|
await self._get_batch()
|
||||||
|
assert self._result is not None
|
||||||
|
if self._result.total_results == self.offset:
|
||||||
|
return None
|
||||||
|
if not self._clients:
|
||||||
|
await self._get_batch()
|
||||||
|
|
||||||
|
client = self._clients.pop(0)
|
||||||
|
self.offset += 1
|
||||||
|
return client
|
||||||
|
|
||||||
|
def __aiter__(self) -> Self:
|
||||||
|
"""Iterate async."""
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self) -> Client:
|
||||||
|
"""Get next client."""
|
||||||
|
if client := await self.get_next_client():
|
||||||
|
return client
|
||||||
|
raise StopAsyncIteration
|
||||||
|
|
||||||
|
|
||||||
|
class SshecretBackend:
|
||||||
|
"""Backend interface."""
|
||||||
|
|
||||||
|
def __init__(self, backend_url: str, api_token: str) -> None:
|
||||||
|
"""Initialize backend client."""
|
||||||
|
|
||||||
|
url = httpx.URL(backend_url)
|
||||||
|
|
||||||
|
self.http_client: httpx.AsyncClient = httpx.AsyncClient(
|
||||||
|
headers={"X-Api-Token": api_token},
|
||||||
|
base_url=url,
|
||||||
|
)
|
||||||
|
self.sync_client: httpx.Client = httpx.Client(
|
||||||
|
headers={"X-Api-Token": api_token},
|
||||||
|
base_url=url,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _get(self, path: str) -> httpx.Response:
|
||||||
|
"""Perform a get request."""
|
||||||
|
try:
|
||||||
|
return await self.http_client.get(path)
|
||||||
|
except httpx.ConnectError as e:
|
||||||
|
raise BackendConnectionError() from e
|
||||||
|
|
||||||
|
async def _delete(self, path: str) -> httpx.Response:
|
||||||
|
"""Perform a delete request."""
|
||||||
|
try:
|
||||||
|
return await self.http_client.delete(path)
|
||||||
|
except httpx.ConnectError as e:
|
||||||
|
raise BackendConnectionError() from e
|
||||||
|
|
||||||
|
async def _post(self, path: str, json: Any | None = None) -> httpx.Response:
|
||||||
|
"""Perform a POST request."""
|
||||||
|
try:
|
||||||
|
return await self.http_client.post(path, json=json)
|
||||||
|
except httpx.ConnectError as e:
|
||||||
|
raise BackendConnectionError() from e
|
||||||
|
|
||||||
|
async def _put(self, path: str, json: Any | None = None) -> httpx.Response:
|
||||||
|
"""Perform a PUT request."""
|
||||||
|
try:
|
||||||
|
return await self.http_client.put(path, json=json)
|
||||||
|
except httpx.ConnectError as e:
|
||||||
|
raise BackendConnectionError() from e
|
||||||
|
|
||||||
|
async def request(self, path: str) -> httpx.Response:
|
||||||
|
"""Send a simple GET request."""
|
||||||
|
response = await self._get(path)
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def create_client(
|
||||||
|
self, name: str, public_key: str, description: str | None = None
|
||||||
|
) -> None:
|
||||||
|
"""Register a new client."""
|
||||||
|
if not validate_public_key(public_key):
|
||||||
|
raise BackendValidationError("Error: Invalid public key format.")
|
||||||
|
data = {
|
||||||
|
"name": name,
|
||||||
|
"public_key": public_key,
|
||||||
|
}
|
||||||
|
if description:
|
||||||
|
data["description"] = description
|
||||||
|
path = "/api/v1/clients/"
|
||||||
|
response = await self._post(path, json=data)
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
async def get_clients(self, filter: ClientFilter | None = None) -> list[Client]:
|
||||||
|
"""Get all clients."""
|
||||||
|
clients: list[Client] = []
|
||||||
|
async for client in ClientQueryIterator(self.http_client, filter):
|
||||||
|
clients.append(client)
|
||||||
|
|
||||||
|
return clients
|
||||||
|
|
||||||
|
async def get_client(self, name: str) -> Client | None:
|
||||||
|
"""Lookup a client on username."""
|
||||||
|
path = f"/api/v1/clients/{name}"
|
||||||
|
response = await self.request(path)
|
||||||
|
if response.status_code == 404:
|
||||||
|
return None
|
||||||
|
response.raise_for_status()
|
||||||
|
client = Client.model_validate(response.json())
|
||||||
|
return client
|
||||||
|
|
||||||
|
async def get_client_by_id(self, id: str) -> Client | None:
|
||||||
|
"""Lookup a client on username."""
|
||||||
|
path = f"/api/v1/clients/id/{id}"
|
||||||
|
response = await self.request(path)
|
||||||
|
if response.status_code == 404:
|
||||||
|
return None
|
||||||
|
response.raise_for_status()
|
||||||
|
client = Client.model_validate(response.json())
|
||||||
|
return client
|
||||||
|
|
||||||
|
async def delete_client(self, client_name: str) -> None:
|
||||||
|
"""Delete a client."""
|
||||||
|
path = f"/api/v1/clients/{client_name}"
|
||||||
|
response = await self._delete(path)
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
async def delete_client_by_id(self, id: str) -> None:
|
||||||
|
"""Delete a client."""
|
||||||
|
path = f"/api/v1/clients/id/{id}"
|
||||||
|
response = await self._delete(path)
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
async def create_client_secret(
|
||||||
|
self, client_name: str, secret_name: str, encrypted_secret: str
|
||||||
|
) -> None:
|
||||||
|
"""Create a secret.
|
||||||
|
|
||||||
|
This will overwrite any existing secret with that name.
|
||||||
|
"""
|
||||||
|
path = f"api/v1/clients/{client_name}/secrets/{secret_name}"
|
||||||
|
response = await self._put(path, json={"value": encrypted_secret})
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
async def get_client_secret(self, name: str, secret_name: str) -> str:
|
||||||
|
"""Fetch a secret."""
|
||||||
|
path = f"/api/v1/clients/{name}/secrets/{secret_name}"
|
||||||
|
response = await self.request(path)
|
||||||
|
response.raise_for_status()
|
||||||
|
secret = ClientSecret.model_validate(response.json())
|
||||||
|
return secret.secret
|
||||||
|
|
||||||
|
async def delete_client_secret(self, client_name: str, secret_name: str) -> None:
|
||||||
|
"""Delete a secret from a client."""
|
||||||
|
path = f"api/v1/clients/{client_name}/secrets/{secret_name}"
|
||||||
|
response = await self._delete(path)
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
async def update_client(self, client: Client) -> Client:
|
||||||
|
"""Update the client."""
|
||||||
|
path = f"/api/v1/clients/{client.name}"
|
||||||
|
client_update = {
|
||||||
|
"name": client.name,
|
||||||
|
"description": client.description,
|
||||||
|
"public_key": client.public_key,
|
||||||
|
}
|
||||||
|
response = await self._put(path, json=client_update)
|
||||||
|
LOG.info("Response %s", response.text)
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
if client.policies:
|
||||||
|
await self.update_client_sources(
|
||||||
|
str(client.id), [str(source) for source in client.policies]
|
||||||
|
)
|
||||||
|
return client
|
||||||
|
|
||||||
|
async def update_client_key(self, client_name: str, public_key: str) -> None:
|
||||||
|
"""Update the client key."""
|
||||||
|
path = f"/api/v1/clients/{client_name}/public-key"
|
||||||
|
response = await self._post(path, json={"public_key": public_key})
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
async def update_client_sources(
|
||||||
|
self, client_name: str, addresses: list[str] | None
|
||||||
|
) -> None:
|
||||||
|
"""Update client source addresses.
|
||||||
|
|
||||||
|
Pass None to sources to allow from all.
|
||||||
|
"""
|
||||||
|
if not addresses:
|
||||||
|
addresses = []
|
||||||
|
|
||||||
|
path = f"/api/v1/clients/{client_name}/policies/"
|
||||||
|
response = await self._put(path, json={"sources": addresses})
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
async def get_detailed_secrets(self) -> list[DetailedSecrets]:
|
||||||
|
"""Get detailed list of secrets."""
|
||||||
|
path = "/api/v1/secrets/detailed/"
|
||||||
|
response = await self._get(path)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
secret_list = TypeAdapter(list[DetailedSecrets])
|
||||||
|
return secret_list.validate_python(response.json())
|
||||||
|
|
||||||
|
async def get_secrets(self) -> list[Secret]:
|
||||||
|
"""Get Secrets.
|
||||||
|
|
||||||
|
This provides a list of secret names and which clients have them.
|
||||||
|
"""
|
||||||
|
path = "/api/v1/secrets/"
|
||||||
|
response = await self._get(path)
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
secret_list = TypeAdapter(list[Secret])
|
||||||
|
return secret_list.validate_python(response.json())
|
||||||
|
|
||||||
|
async def get_secret(self, name: str) -> Secret | None:
|
||||||
|
"""Get clients mapped to a single secret."""
|
||||||
|
path = f"/api/v1/secrets/{name}"
|
||||||
|
response = await self._get(path)
|
||||||
|
if response.status_code == 404:
|
||||||
|
return None
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
return Secret.model_validate(response.json())
|
||||||
|
|
||||||
|
async def get_detailed_secret(self, name: str) -> DetailedSecrets | None:
|
||||||
|
"""Get clients mapped to a single secret."""
|
||||||
|
path = f"/api/v1/secrets/{name}/detailed"
|
||||||
|
response = await self._get(path)
|
||||||
|
if response.status_code == 404:
|
||||||
|
return None
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
return DetailedSecrets.model_validate(response.json())
|
||||||
|
|
||||||
|
async def get_audit_log(
|
||||||
|
self,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
client_name: str | None = None,
|
||||||
|
subsystem: str | None = None,
|
||||||
|
) -> list[AuditLog]:
|
||||||
|
"""Get audit log."""
|
||||||
|
path = f"/api/v1/audit/"
|
||||||
|
params: dict[str, str] = {
|
||||||
|
"offset": str(offset),
|
||||||
|
"limit": str(limit),
|
||||||
|
}
|
||||||
|
if client_name:
|
||||||
|
params["filter_client"] = client_name
|
||||||
|
|
||||||
|
if subsystem:
|
||||||
|
params["filter_subsystem"] = subsystem
|
||||||
|
|
||||||
|
response = await self.http_client.get(path, params=params)
|
||||||
|
response.raise_for_status()
|
||||||
|
audit_log_adapter = TypeAdapter(list[AuditLog])
|
||||||
|
return audit_log_adapter.validate_python(response.json())
|
||||||
|
|
||||||
|
async def add_audit_log(self, entry: AuditLog) -> None:
|
||||||
|
"""Add audit log entry."""
|
||||||
|
path = f"/api/v1/audit/"
|
||||||
|
|
||||||
|
response = await self.http_client.post(path, json=entry.model_dump())
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
async def get_audit_log_count(self) -> int:
|
||||||
|
"""Get amount of messages in the audit log."""
|
||||||
|
path = f"/api/v1/audit/info"
|
||||||
|
response = await self._get(path)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
return int(data["entries"])
|
||||||
|
|
||||||
|
def add_audit_log_sync(self, entry: AuditLog) -> None:
|
||||||
|
"""Add audit log entry."""
|
||||||
|
path = f"/api/v1/audit/"
|
||||||
|
LOG.info("AUDIT LOG SYNC %r", entry)
|
||||||
|
|
||||||
|
response = self.sync_client.post(path, json=entry.model_dump())
|
||||||
|
response.raise_for_status()
|
||||||
8
src/sshecret/backend/exceptions.py
Normal file
8
src/sshecret/backend/exceptions.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
"""Exceptions."""
|
||||||
|
|
||||||
|
class BackendValidationError(Exception):
|
||||||
|
"""Validation error."""
|
||||||
|
|
||||||
|
|
||||||
|
class BackendConnectionError(Exception):
|
||||||
|
"""Could not connect to backend server."""
|
||||||
115
src/sshecret/backend/models.py
Normal file
115
src/sshecret/backend/models.py
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
"""Backend models."""
|
||||||
|
|
||||||
|
import enum
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from pydantic import AfterValidator, BaseModel, IPvAnyAddress, IPvAnyNetwork
|
||||||
|
|
||||||
|
from sshecret.crypto import public_key_validator
|
||||||
|
|
||||||
|
|
||||||
|
class FilterType(enum.StrEnum):
|
||||||
|
"""Type of filter."""
|
||||||
|
|
||||||
|
LIKE = "like"
|
||||||
|
CONTAINS = "contains"
|
||||||
|
|
||||||
|
|
||||||
|
class Client(BaseModel):
|
||||||
|
"""Implementation of the backend class ClientView."""
|
||||||
|
|
||||||
|
id: uuid.UUID
|
||||||
|
name: str
|
||||||
|
description: str | None
|
||||||
|
public_key: Annotated[str, AfterValidator(public_key_validator)]
|
||||||
|
secrets: list[str]
|
||||||
|
policies: list[IPvAnyNetwork | IPvAnyAddress]
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime | None
|
||||||
|
|
||||||
|
|
||||||
|
class ClientQueryResult(BaseModel):
|
||||||
|
"""Implementation of the backend ClientQueryResult class."""
|
||||||
|
|
||||||
|
clients: list[Client]
|
||||||
|
total_results: int
|
||||||
|
remaining_results: int
|
||||||
|
|
||||||
|
|
||||||
|
class ClientSecret(BaseModel):
|
||||||
|
"""Implementation of the backend class ClientSecretResponse."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
secret: str
|
||||||
|
description: str | None
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime | None
|
||||||
|
|
||||||
|
|
||||||
|
class Secret(BaseModel):
|
||||||
|
"""Implementation of the backend class ClientSecretList."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
clients: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class ClientReference(BaseModel):
|
||||||
|
"""Implementation of the backend class ClientReference."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class DetailedSecrets(BaseModel):
|
||||||
|
"""Implementation of the backend class ClientSecretDetailList."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
ids: list[str]
|
||||||
|
clients: list[ClientReference]
|
||||||
|
|
||||||
|
|
||||||
|
class Policy(BaseModel):
|
||||||
|
"""Implementation of the backend class ClientPolicyView."""
|
||||||
|
|
||||||
|
sources: list[IPvAnyNetwork | IPvAnyAddress]
|
||||||
|
|
||||||
|
|
||||||
|
class ClientFilter(BaseModel):
|
||||||
|
"""Client filter."""
|
||||||
|
|
||||||
|
id: str | None = None
|
||||||
|
name: str | None = None
|
||||||
|
filter_name: FilterType | None = None
|
||||||
|
|
||||||
|
def get_params(self) -> dict[str, str]:
|
||||||
|
"""Render query parameters."""
|
||||||
|
params: dict[str, str] = {}
|
||||||
|
if not self.id and not self.name:
|
||||||
|
return params
|
||||||
|
if self.id:
|
||||||
|
params["id"] = self.id
|
||||||
|
if self.name and not self.filter_name:
|
||||||
|
params["name"] = self.name
|
||||||
|
elif self.name and self.filter_name is FilterType.LIKE:
|
||||||
|
params["name__like"] = self.name
|
||||||
|
elif self.name and self.filter_name is FilterType.CONTAINS:
|
||||||
|
params["name__contains"] = self.name
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
class AuditLog(BaseModel):
|
||||||
|
"""Implementation of the backend class AuditLog."""
|
||||||
|
|
||||||
|
id: str | None = None
|
||||||
|
subsystem: str | None = None
|
||||||
|
object: str | None = None
|
||||||
|
object_id: str | None = None
|
||||||
|
operation: str
|
||||||
|
client_id: str | None = None
|
||||||
|
client_name: str | None = None
|
||||||
|
message: str
|
||||||
|
origin: str | None = None
|
||||||
|
timestamp: datetime | None = None
|
||||||
37
src/sshecret/backend/utils.py
Normal file
37
src/sshecret/backend/utils.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
"""Utility functions."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from cryptography.hazmat.primitives import serialization
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
|
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def load_public_key(key: bytes | str) -> rsa.RSAPublicKey:
|
||||||
|
if isinstance(key, str):
|
||||||
|
key = key.encode()
|
||||||
|
public_key = serialization.load_ssh_public_key(key)
|
||||||
|
if not isinstance(public_key, rsa.RSAPublicKey):
|
||||||
|
raise ValueError("Only RSA keys are supported.")
|
||||||
|
return public_key
|
||||||
|
|
||||||
|
|
||||||
|
def validate_public_key(key: str) -> bool:
|
||||||
|
"""Check if key provided in a string is valid."""
|
||||||
|
valid = False
|
||||||
|
public_key: rsa.RSAPublicKey | None = None
|
||||||
|
try:
|
||||||
|
keybytes = key.encode()
|
||||||
|
public_key = load_public_key(keybytes)
|
||||||
|
except Exception as e:
|
||||||
|
LOG.debug("Validation of public key failed: %s", e, exc_info=True)
|
||||||
|
else:
|
||||||
|
valid = True
|
||||||
|
|
||||||
|
if not isinstance(public_key, rsa.RSAPublicKey):
|
||||||
|
valid = False
|
||||||
|
|
||||||
|
if not valid:
|
||||||
|
raise ValueError("Invalid public key. Must an OpenSSH RSA public key.")
|
||||||
|
return valid
|
||||||
137
src/sshecret/crypto.py
Normal file
137
src/sshecret/crypto.py
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
"""Encryption related functions.
|
||||||
|
|
||||||
|
Note! Encryption uses the less secure PKCS1v15 padding. This is to allow
|
||||||
|
decryption via openssl on the command line.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from cryptography.hazmat.primitives import serialization
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import padding
|
||||||
|
|
||||||
|
|
||||||
|
RSA_PUBLIC_EXPONENT = 65537
|
||||||
|
RSA_KEY_SIZE = 2048
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def load_public_key(keybytes: bytes) -> rsa.RSAPublicKey:
|
||||||
|
public_key = serialization.load_ssh_public_key(keybytes)
|
||||||
|
if not isinstance(public_key, rsa.RSAPublicKey):
|
||||||
|
raise RuntimeError("Only RSA keys are supported.")
|
||||||
|
return public_key
|
||||||
|
|
||||||
|
|
||||||
|
def public_key_validator(public_key: str) -> str:
|
||||||
|
"""Validator for public key."""
|
||||||
|
if validate_public_key(public_key):
|
||||||
|
return public_key
|
||||||
|
|
||||||
|
raise ValueError("Error: Public key is not a valid SSH RSA Public Key.")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_public_key(key: str) -> bool:
|
||||||
|
"""Check if key provided in a string is valid."""
|
||||||
|
valid = False
|
||||||
|
public_key: rsa.RSAPublicKey | None = None
|
||||||
|
try:
|
||||||
|
keybytes = key.encode()
|
||||||
|
public_key = load_public_key(keybytes)
|
||||||
|
except Exception as e:
|
||||||
|
LOG.debug("Validation of public key failed: %s", e, exc_info=True)
|
||||||
|
else:
|
||||||
|
valid = True
|
||||||
|
|
||||||
|
if not isinstance(public_key, rsa.RSAPublicKey):
|
||||||
|
valid = False
|
||||||
|
|
||||||
|
return valid
|
||||||
|
|
||||||
|
|
||||||
|
def load_private_key(filename: str, password: str | None = None) -> rsa.RSAPrivateKey:
|
||||||
|
"""Load a private key."""
|
||||||
|
password_bytes: bytes | None = None
|
||||||
|
if password:
|
||||||
|
password_bytes = password.encode()
|
||||||
|
with open(filename, "rb") as f:
|
||||||
|
private_key = serialization.load_ssh_private_key(
|
||||||
|
f.read(), password=password_bytes
|
||||||
|
)
|
||||||
|
if not isinstance(private_key, rsa.RSAPrivateKey):
|
||||||
|
raise RuntimeError("Only RSA keys are supported.")
|
||||||
|
return private_key
|
||||||
|
|
||||||
|
|
||||||
|
def encrypt_string(string: str, public_key: rsa.RSAPublicKey) -> str:
|
||||||
|
"""Encrypt string, end return it base64 encoded."""
|
||||||
|
message = string.encode()
|
||||||
|
ciphertext = public_key.encrypt(
|
||||||
|
message,
|
||||||
|
padding.PKCS1v15(),
|
||||||
|
)
|
||||||
|
return base64.b64encode(ciphertext).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def decode_string(ciphertext: str, private_key: rsa.RSAPrivateKey) -> str:
|
||||||
|
"""Decode a string. String must be base64 encoded."""
|
||||||
|
decoded = base64.b64decode(ciphertext)
|
||||||
|
decrypted = private_key.decrypt(
|
||||||
|
decoded,
|
||||||
|
padding.PKCS1v15(),
|
||||||
|
)
|
||||||
|
return decrypted.decode()
|
||||||
|
|
||||||
|
|
||||||
|
def generate_private_key() -> rsa.RSAPrivateKey:
|
||||||
|
"""Generate private RSA key."""
|
||||||
|
private_key = rsa.generate_private_key(
|
||||||
|
public_exponent=RSA_PUBLIC_EXPONENT, key_size=RSA_KEY_SIZE
|
||||||
|
)
|
||||||
|
return private_key
|
||||||
|
|
||||||
|
|
||||||
|
def generate_pem(private_key: rsa.RSAPrivateKey) -> str:
|
||||||
|
"""Generate PEM."""
|
||||||
|
pem = private_key.private_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PrivateFormat.OpenSSH,
|
||||||
|
encryption_algorithm=serialization.NoEncryption(),
|
||||||
|
)
|
||||||
|
return pem.decode()
|
||||||
|
|
||||||
|
|
||||||
|
def create_private_rsa_key(filename: Path, password: str | None = None) -> None:
|
||||||
|
"""Create an RSA Private key at the given path.
|
||||||
|
|
||||||
|
A password may be provided for secure storage.
|
||||||
|
"""
|
||||||
|
if filename.exists():
|
||||||
|
raise RuntimeError("Error: private key file already exists.")
|
||||||
|
LOG.debug("Generating private RSA key at %s", filename)
|
||||||
|
private_key = generate_private_key()
|
||||||
|
encryption_algorithm = serialization.NoEncryption()
|
||||||
|
if password:
|
||||||
|
password_bytes = password.encode()
|
||||||
|
encryption_algorithm = serialization.BestAvailableEncryption(password_bytes)
|
||||||
|
with open(filename, "wb") as f:
|
||||||
|
pem = private_key.private_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PrivateFormat.OpenSSH,
|
||||||
|
encryption_algorithm=encryption_algorithm,
|
||||||
|
)
|
||||||
|
lines = f.write(pem)
|
||||||
|
LOG.debug("Wrote %s lines", lines)
|
||||||
|
f.flush()
|
||||||
|
|
||||||
|
|
||||||
|
def generate_public_key_string(public_key: rsa.RSAPublicKey) -> str:
|
||||||
|
"""Generate public key string."""
|
||||||
|
keybytes = public_key.public_bytes(
|
||||||
|
encoding=serialization.Encoding.OpenSSH,
|
||||||
|
format=serialization.PublicFormat.OpenSSH,
|
||||||
|
)
|
||||||
|
return keybytes.decode()
|
||||||
57
src/sshecret/settings.py
Normal file
57
src/sshecret/settings.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
"""Settings management."""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
from pydantic import AnyHttpUrl, BaseModel, Field
|
||||||
|
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_SSH_PORT = 2222
|
||||||
|
DEFAULT_BACKEND_PORT = 8822
|
||||||
|
DEFAULT_ADMIN_PORT = 8022
|
||||||
|
|
||||||
|
|
||||||
|
class SshdSettings(BaseModel):
|
||||||
|
"""Settings for the SSHd server."""
|
||||||
|
|
||||||
|
name: Literal["sshd"]
|
||||||
|
listen_address: str = ""
|
||||||
|
backend_token: str
|
||||||
|
port: int = DEFAULT_SSH_PORT
|
||||||
|
|
||||||
|
|
||||||
|
class AdminSettings(BaseModel):
|
||||||
|
"""Settings for the Admin module."""
|
||||||
|
|
||||||
|
name: Literal["admin"]
|
||||||
|
listen_address: str = ""
|
||||||
|
backend_token: str
|
||||||
|
port: int = DEFAULT_ADMIN_PORT
|
||||||
|
|
||||||
|
|
||||||
|
class BackendSettings(BaseModel):
|
||||||
|
"""Settings for the backend server."""
|
||||||
|
|
||||||
|
name: Literal["backend"]
|
||||||
|
listen_address: str = ""
|
||||||
|
backend_token: str
|
||||||
|
port: int = DEFAULT_BACKEND_PORT
|
||||||
|
|
||||||
|
|
||||||
|
class SshecretSettings(BaseSettings):
|
||||||
|
"""General settings model.
|
||||||
|
|
||||||
|
Should probably be subclassed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(
|
||||||
|
env_prefix="sshecret_",
|
||||||
|
env_nested_delimiter="__",
|
||||||
|
nested_model_default_partial_update=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
backend_url: AnyHttpUrl
|
||||||
|
# This is set up deliberately so only one app can run at the time.
|
||||||
|
application: SshdSettings | AdminSettings | BackendSettings = Field(
|
||||||
|
discriminator="name"
|
||||||
|
)
|
||||||
97
src/sshecret/testing.py
Normal file
97
src/sshecret/testing.py
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
"""Testing classes."""
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
import asyncssh
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BackendContext:
|
||||||
|
"""Backend context."""
|
||||||
|
|
||||||
|
url: str
|
||||||
|
token: str
|
||||||
|
|
||||||
|
@property
|
||||||
|
def headers(self) -> dict[str, str]:
|
||||||
|
"""Get headers."""
|
||||||
|
return {
|
||||||
|
"X-API-Token": self.token,
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_client(self) -> httpx.AsyncClient:
|
||||||
|
"""Create a test client."""
|
||||||
|
return httpx.AsyncClient(base_url=self.url, headers=self.headers)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AdminContext:
|
||||||
|
"""Admin context."""
|
||||||
|
|
||||||
|
url: str
|
||||||
|
username: str = "test"
|
||||||
|
password: str = "test"
|
||||||
|
jwt_token: str | None = field(init=False)
|
||||||
|
|
||||||
|
async def login(self) -> None:
|
||||||
|
"""Login to the application."""
|
||||||
|
client = httpx.AsyncClient(base_url=self.url)
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/token", data={"username": self.username, "password": self.password}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
token_data = response.json()
|
||||||
|
access_token = cast(str, token_data["access_token"])
|
||||||
|
self.jwt_token = str(access_token)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def headers(self) -> dict[str, str]:
|
||||||
|
"""Return headers."""
|
||||||
|
headers: dict[str, str] = {}
|
||||||
|
if self.jwt_token:
|
||||||
|
headers = {"Authorization": f"Bearer {self.jwt_token}"}
|
||||||
|
|
||||||
|
return headers
|
||||||
|
|
||||||
|
def test_client(self) -> httpx.AsyncClient:
|
||||||
|
"""Create a test client."""
|
||||||
|
return httpx.AsyncClient(base_url=self.url, headers=self.headers)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SshServerContext:
|
||||||
|
"""SSH server context."""
|
||||||
|
|
||||||
|
host: str
|
||||||
|
port: int
|
||||||
|
|
||||||
|
async def run_command(
|
||||||
|
self, username: str, private_key_file: Path | None, command: str
|
||||||
|
) -> str:
|
||||||
|
"""Run command."""
|
||||||
|
async with self.connect(username, private_key_file) as conn:
|
||||||
|
result = await conn.run(command)
|
||||||
|
assert result.stdout is not None
|
||||||
|
return str(result.stdout.rstrip())
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def connect(
|
||||||
|
self, username: str, private_key_file: Path | None
|
||||||
|
) -> AsyncIterator[asyncssh.SSHClientConnection]:
|
||||||
|
"""Connect to the server and yield a connection."""
|
||||||
|
private_keys: list[str] = []
|
||||||
|
if private_key_file:
|
||||||
|
private_keys.append(str(private_key_file.absolute()))
|
||||||
|
async with asyncssh.connect(
|
||||||
|
self.host,
|
||||||
|
port=self.port,
|
||||||
|
username=username,
|
||||||
|
known_hosts=None,
|
||||||
|
client_keys=private_keys,
|
||||||
|
) as conn:
|
||||||
|
yield conn
|
||||||
Reference in New Issue
Block a user