From 0eaa913e35a5cc2067bfc6486b8f677d03d875cc Mon Sep 17 00:00:00 2001 From: Allan Eising Date: Tue, 10 Jun 2025 10:28:17 +0200 Subject: [PATCH] Implement podman-compatible commands --- .../sshecret_admin/services/admin_backend.py | 2 +- .../api/clients/operations.py | 8 +- .../sshecret_backend/api/clients/schemas.py | 11 +- .../src/sshecret_backend/api/common.py | 36 +-- .../api/secrets/operations.py | 2 + .../src/sshecret_sshd/commands/dispatcher.py | 17 +- .../src/sshecret_sshd/commands/shelldriver.py | 163 ++++++++++++ tests/integration/test_sshd.py | 236 ++++++++++++++++-- 8 files changed, 414 insertions(+), 61 deletions(-) create mode 100644 packages/sshecret-sshd/src/sshecret_sshd/commands/shelldriver.py diff --git a/packages/sshecret-admin/src/sshecret_admin/services/admin_backend.py b/packages/sshecret-admin/src/sshecret_admin/services/admin_backend.py index 288da77..271eb16 100644 --- a/packages/sshecret-admin/src/sshecret_admin/services/admin_backend.py +++ b/packages/sshecret-admin/src/sshecret_admin/services/admin_backend.py @@ -16,7 +16,7 @@ from sshecret.backend import ( Operation, SubSystem, ) -from sshecret.backend.models import DetailedSecrets, Secret +from sshecret.backend.models import DetailedSecrets from sshecret.backend.api import AuditAPI, KeySpec from sshecret.crypto import encrypt_string, load_public_key diff --git a/packages/sshecret-backend/src/sshecret_backend/api/clients/operations.py b/packages/sshecret-backend/src/sshecret_backend/api/clients/operations.py index a86bb47..f02002e 100644 --- a/packages/sshecret-backend/src/sshecret_backend/api/clients/operations.py +++ b/packages/sshecret-backend/src/sshecret_backend/api/clients/operations.py @@ -130,6 +130,7 @@ class ClientOperations: if not db_client: raise HTTPException(status_code=404, detail="Client not found.") if db_client.is_deleted: + LOG.warning("Client %r was already deleted!", client) return db_client.is_deleted = True db_client.deleted_at = datetime.now(timezone.utc) @@ -271,7 +272,12 @@ async def get_clients( filter_query: ClientListParams, ) -> ClientQueryResult: """Get Clients.""" - count_statement = select(func.count("*")).select_from(Client).where(Client.is_deleted.is_not(True)).where(Client.is_active.is_not(False)) + count_statement = ( + select(func.count("*")) + .select_from(Client) + .where(Client.is_deleted.is_not(True)) + .where(Client.is_active.is_not(False)) + ) count_statement = cast( Select[tuple[int]], filter_client_statement(count_statement, filter_query, True), diff --git a/packages/sshecret-backend/src/sshecret_backend/api/clients/schemas.py b/packages/sshecret-backend/src/sshecret_backend/api/clients/schemas.py index d5fa680..0231020 100644 --- a/packages/sshecret-backend/src/sshecret_backend/api/clients/schemas.py +++ b/packages/sshecret-backend/src/sshecret_backend/api/clients/schemas.py @@ -40,7 +40,9 @@ class ClientView(BaseModel): return responses @classmethod - def from_client(cls, client: models.Client) -> Self: + def from_client( + cls, client: models.Client, include_deleted_secrets: bool = False + ) -> Self: """Instantiate from a client.""" view = cls( id=client.id, @@ -54,7 +56,12 @@ class ClientView(BaseModel): is_deleted=client.is_deleted, ) if client.secrets: - view.secrets = [secret.name for secret in client.secrets] + if include_deleted_secrets: + view.secrets = [secret.name for secret in client.secrets] + else: + view.secrets = [ + secret.name for secret in client.secrets if not secret.deleted + ] if client.policies: view.policies = [policy.source for policy in client.policies] diff --git a/packages/sshecret-backend/src/sshecret_backend/api/common.py b/packages/sshecret-backend/src/sshecret_backend/api/common.py index bab5d5e..b973fea 100644 --- a/packages/sshecret-backend/src/sshecret_backend/api/common.py +++ b/packages/sshecret-backend/src/sshecret_backend/api/common.py @@ -116,24 +116,18 @@ async def resolve_client_id( return None -async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | None: +async def get_client_by_id( + session: AsyncSession, id: uuid.UUID, include_deleted: bool = False +) -> Client | None: """Get client by ID.""" - client_filter = client_with_relationships().where(Client.id == id) + if include_deleted: + client_filter = client_with_relationships().where(Client.id == id) + else: + client_filter = query_active_clients().where(Client.id == id) client_results = await session.execute(client_filter) return client_results.scalars().first() -async def get_client_by_id_or_name( - session: AsyncSession, id_or_name: str -) -> Client | None: - """Get client either by id or name.""" - if RE_UUID.match(id_or_name): - id = uuid.UUID(id_or_name) - return await get_client_by_id(session, id) - - return await get_client_by_name(session, id_or_name) - - def query_active_clients() -> Select[tuple[Client]]: """Get all active clients.""" client_filter = ( @@ -144,22 +138,6 @@ def query_active_clients() -> Select[tuple[Client]]: return client_filter -async def get_client_by_name(session: AsyncSession, name: str) -> Client | None: - """Get client by name. - - This will get the latest client version, unless it's deleted. - """ - client_filter = ( - client_with_relationships() - .where(Client.is_active.is_(True)) - .where(Client.is_deleted.is_not(True)) - .where(Client.name == name) - .order_by(Client.version.desc()) - ) - client_result = await session.execute(client_filter) - return client_result.scalars().first() - - async def refresh_client(session: AsyncSession, client: Client) -> None: """Refresh the client and load in all relationships.""" await session.refresh( diff --git a/packages/sshecret-backend/src/sshecret_backend/api/secrets/operations.py b/packages/sshecret-backend/src/sshecret_backend/api/secrets/operations.py index 752b8e0..194bbfa 100644 --- a/packages/sshecret-backend/src/sshecret_backend/api/secrets/operations.py +++ b/packages/sshecret-backend/src/sshecret_backend/api/secrets/operations.py @@ -182,8 +182,10 @@ class ClientSecretOperations: async def delete_client_secret(self, secret_identifier: FlexID) -> None: """Delete a client secret.""" + LOG.debug("delete_client_secret called with identifier %r", secret_identifier) client_secret = await self._get_client_secret(secret_identifier) if not client_secret: + LOG.warning("Could not find any secret matching client secret.") return client_secret.deleted = True diff --git a/packages/sshecret-sshd/src/sshecret_sshd/commands/dispatcher.py b/packages/sshecret-sshd/src/sshecret_sshd/commands/dispatcher.py index 79fb766..2047451 100644 --- a/packages/sshecret-sshd/src/sshecret_sshd/commands/dispatcher.py +++ b/packages/sshecret-sshd/src/sshecret_sshd/commands/dispatcher.py @@ -7,14 +7,19 @@ import logging from typing import cast, final, override import asyncssh - -from sshecret_sshd import exceptions, constants +from sshecret_sshd import constants, exceptions from .base import CommandDispatcher from .get_secret import GetSecret -from .register import Register from .list_secrets import ListSecrets from .ping import PingCommand +from .register import Register +from .shelldriver import ( + ShellDeleteSecret, + ShellListSecrets, + ShellLookupSecret, + ShellStoreSecret, +) SYNOPSIS = """[bold]Sshecret SSH Server[/bold] @@ -29,9 +34,13 @@ encoded as base64. COMMANDS = [ GetSecret, - Register, ListSecrets, PingCommand, + Register, + ShellDeleteSecret, + ShellListSecrets, + ShellLookupSecret, + ShellStoreSecret, ] LOG = logging.getLogger(__name__) diff --git a/packages/sshecret-sshd/src/sshecret_sshd/commands/shelldriver.py b/packages/sshecret-sshd/src/sshecret_sshd/commands/shelldriver.py new file mode 100644 index 0000000..ae3bd6f --- /dev/null +++ b/packages/sshecret-sshd/src/sshecret_sshd/commands/shelldriver.py @@ -0,0 +1,163 @@ +"""Podman Shelldriver compatible commands.""" + +import logging +from typing import final, override +import asyncssh + +from sshecret.backend.models import Operation +from sshecret.crypto import encrypt_string, load_public_key + + +from .base import CommandDispatcher + +LOG = logging.getLogger(__name__) + +# These error messages are taken verbatim from podman, and while they don't seem +# to make complete sense, they will be used regardless. +ERR_SECRET_NOT_FOUND = "no such secret" +ERR_SECRET_EXISTS = "secret data with ID already exists" +ERR_INVALID_SECRET = "invalid key" + + +@final +class ShellListSecrets(CommandDispatcher): + """List secrets. + + This command lists secrets in a format compatible with podman's ShellDriver. + """ + + name = "list" + + @override + async def exec(self) -> None: + """List secrets.""" + LOG.debug("ShellListSecret called.") + await self.audit(Operation.READ, "Listed available secret names") + for secret_name in self.client.secrets: + self.print(secret_name) + + +@final +class ShellDeleteSecret(CommandDispatcher): + """Delete a secret. + + If the identifier for a secret does not exist, an error will be printed. + """ + + name = "delete" + mandatory_argument = "KEY" + + @override + async def exec(self) -> None: + """Delete a secret.""" + secret_name = self.arguments[0] + LOG.debug("ShellDeleteSecret called withg arguments %r.", self.arguments) + await self.audit( + operation=Operation.DELETE, + message="ClientSecret deleted", + secret=secret_name, + ) + await self.backend.delete_client_secret( + ("id", str(self.client.id)), ("name", secret_name) + ) + + +@final +class ShellLookupSecret(CommandDispatcher): + """Look up a secret. + + The identifier for the secret must be provided as the argument. + """ + + name = "lookup" + mandatory_argument = "KEY" + + @override + async def exec(self) -> None: + """Lookup secret.""" + LOG.debug("ShellLookupSecret called with arguments %r", self.arguments) + secret_name = self.arguments[0] + + secret = await self.backend.get_client_secret( + ("id", str(self.client.id)), secret_name + ) + if not secret: + LOG.debug( + "Secret %s not found for client %s (%s)", + secret_name, + self.client.id, + self.client.name, + ) + self.print(ERR_SECRET_NOT_FOUND, stderr=True) + return + await self.audit( + Operation.READ, message="Client requested secret", secret=secret_name + ) + + self.print(secret) + + +@final +class ShellStoreSecret(CommandDispatcher): + """Store a secret. + + Secret will be read from command argument, or via STDIN. + """ + + name = "store" + mandatory_argument = "KEY" + + @override + async def exec(self) -> None: + """Store a secret.""" + LOG.debug("ShellStoreSecret called with arguments %r", self.arguments) + secret_name = self.arguments[0] + if secret_name in self.client.secrets: + self.print(ERR_SECRET_EXISTS, stderr=True) + return + + secret_data: str | None = None + if len(self.arguments) == 2: + secret_data = self.arguments[1] + + if not secret_data: + LOG.debug("No secret set as input, trying stdin.") + secret_data = await self.get_secret_on_stdin() + + if not secret_data: + self.print(ERR_INVALID_SECRET, stderr=True) + return + + # Encrypt secret + encrypted = self.encrypt_secret(secret_data) + + await self.backend.create_client_secret( + ("id", str(self.client.id)), secret_name, encrypted + ) + await self.audit( + operation=Operation.CREATE, + message="Secret created from 'store' command", + secret=secret_name, + ) + + def encrypt_secret(self, value: str) -> str: + """Encrypt a secret.""" + public_key = load_public_key(self.client.public_key.encode()) + return encrypt_string(value, public_key) + + async def get_secret_on_stdin(self) -> str | None: + """Get secret from stdin.""" + secret_data = "" + try: + async for line in self.process.stdin: + if self.process.stdin.at_eof(): + break + if not line: + break + secret_data += line.rstrip() + except asyncssh.BreakReceived: + pass + + if not secret_data: + return None + return secret_data diff --git a/tests/integration/test_sshd.py b/tests/integration/test_sshd.py index 548278b..4fcc809 100644 --- a/tests/integration/test_sshd.py +++ b/tests/integration/test_sshd.py @@ -3,8 +3,9 @@ This essentially also tests parts of the admin API. """ +import asyncio from contextlib import asynccontextmanager -from typing import AsyncIterator +from collections.abc import AsyncIterator import os import httpx @@ -17,7 +18,33 @@ from .clients import create_test_client, ClientData from .types import CommandRunner, ProcessRunner -class TestSshd: +class BaseSSHTests: + """Base test class.""" + + async def register_client( + self, name: str, ssh_session: ProcessRunner + ) -> ClientData: + """Register client.""" + test_client = create_test_client(name) + async with ssh_session(test_client, "register") as session: + maxlines = 10 + linenum = 0 + found = False + while linenum < maxlines: + line = await session.stdout.readline() + if "Enter public key" in line: + found = True + break + assert found is True + session.stdin.write(test_client.public_key + "\n") + + result = await session.stdout.read() + assert "Key is valid. Registering client." in result + await session.wait() + return test_client + + +class TestSshd(BaseSSHTests): """Class based tests. This allows us to create small helpers. @@ -52,30 +79,9 @@ class TestSshd: client = clients[0] assert client.name == "new_client" - async def register_client( - self, name: str, ssh_session: ProcessRunner - ) -> ClientData: - """Register client.""" - test_client = create_test_client(name) - async with ssh_session(test_client, "register") as session: - maxlines = 10 - linenum = 0 - found = False - while linenum < maxlines: - line = await session.stdout.readline() - if "Enter public key" in line: - found = True - break - assert found is True - session.stdin.write(test_client.public_key + "\n") - - result = await session.stdout.read() - assert "Key is valid. Registering client." in result - await session.wait() - return test_client -class TestSshdIntegration(TestSshd): +class TestSshdIntegration(BaseSSHTests): """Integration tests.""" @pytest.mark.asyncio @@ -137,3 +143,185 @@ class TestSshdIntegration(TestSshd): base_url=url, headers={"Authorization": f"Bearer {token}"} ) as client: yield client + + +class TestShelldriverCommands(BaseSSHTests): + """Shelldriver command tests.""" + + @pytest.mark.asyncio + async def test_store_and_lookup( + self, + backend_server: tuple[str, tuple[str, str]], + ssh_session: ProcessRunner, + ssh_command_runner: CommandRunner, + ) -> None: + """Log in.""" + test_client = await self.register_client("myclient", ssh_session) + ssh_output = await ssh_command_runner(test_client, "store mysecret secretvalue") + assert bool(ssh_output.stderr) is False + + # wait half a second + await asyncio.sleep(0.5) + + ssh_output = await ssh_command_runner(test_client, "lookup mysecret") + assert bool(ssh_output.stderr) is False + assert ssh_output.stdout is not None + assert isinstance(ssh_output.stdout, str) + encrypted = ssh_output.stdout.rstrip() + decrypted = decode_string(encrypted, test_client.private_key) + assert decrypted == "secretvalue" + + @pytest.mark.asyncio + async def test_store_and_lookup_stdin( + self, + backend_server: tuple[str, tuple[str, str]], + ssh_session: ProcessRunner, + ssh_command_runner: CommandRunner, + ) -> None: + """Test store and lookup, with password specification in stdin.""" + test_client = await self.register_client("myclient", ssh_session) + async with ssh_session(test_client, "store insecret") as session: + session.stdin.write("testinput\n") + session.stdin.write_eof() + + await session.stdin.wait_closed() + + await asyncio.sleep(0.5) + ssh_output = await ssh_command_runner(test_client, "lookup insecret") + assert bool(ssh_output.stderr) is False + assert ssh_output.stdout is not None + assert isinstance(ssh_output.stdout, str) + encrypted = ssh_output.stdout.rstrip() + decrypted = decode_string(encrypted, test_client.private_key) + assert decrypted == "testinput" + + @pytest.mark.parametrize("secret_name",["nonexistant", "blåbærgrød", "../../../etc/shadow"]) + @pytest.mark.asyncio + async def test_invalid_lookup( + self, + secret_name: str, + ssh_command_runner: CommandRunner, + ssh_session: ProcessRunner, + ) -> None: + """Test lookup with invalid secret.""" + test_client = await self.register_client("myclient", ssh_session) + ssh_output = await ssh_command_runner(test_client, f"lookup {secret_name}") + assert isinstance(ssh_output.stderr, str) + assert ssh_output.stderr.rstrip() == "no such secret" + + +class TestShelldriverListCommand(BaseSSHTests): + """Tests for the list command.""" + + @pytest.fixture(name="secret_names") + def get_secret_names(self) -> list[str]: + """Get secret names. + + Sort of like a parametrize function. + """ + return ["foo", "bar", "abc123", "blåbærgrød"] + + @pytest.fixture(name="test_client") + @pytest.mark.asyncio + async def create_test_client( + self, + ssh_session: ProcessRunner, + ) -> ClientData: + """Register a test client.""" + return await self.register_client("listclient", ssh_session) + + @pytest.fixture(autouse=True) + @pytest.mark.asyncio + async def create_data( + self, + secret_names: list[str], + test_client: ClientData, + ssh_command_runner: CommandRunner, + ) -> None: + """Create data for the test.""" + for name in secret_names: + ssh_output = await ssh_command_runner(test_client, f"store {name} secretvalue") + assert bool(ssh_output.stderr) is False + assert ssh_output.stdout is not None + assert isinstance(ssh_output.stdout, str) + + @pytest.mark.asyncio + async def test_list_secrets( + self, + secret_names: list[str], + test_client: ClientData, + ssh_command_runner: CommandRunner, + ) -> None: + """Test list command.""" + ssh_output = await ssh_command_runner(test_client, "list") + assert bool(ssh_output.stderr) is False + assert ssh_output.stdout is not None + assert isinstance(ssh_output.stdout, str) + + list_output = ssh_output.stdout + print(list_output) + input_secret_names = list_output.splitlines() + assert len(input_secret_names) == len(secret_names) + + assert sorted(secret_names) == sorted(input_secret_names) + + +@pytest.mark.parametrize("secret_name", ["simple", "blåbærgrød"]) +class TestShelldriverDeleteCommand(BaseSSHTests): + """Tests for the list command.""" + + @pytest.fixture(name="test_client") + @pytest.mark.asyncio + async def create_test_client( + self, + ssh_session: ProcessRunner, + ) -> ClientData: + """Register a test client.""" + return await self.register_client("listclient", ssh_session) + + @pytest.fixture(autouse=True) + @pytest.mark.asyncio + async def create_data( + self, + secret_name: str, + test_client: ClientData, + ssh_command_runner: CommandRunner, + ) -> None: + """Create data for the test.""" + ssh_output = await ssh_command_runner(test_client, f"store {secret_name} secretvalue") + assert bool(ssh_output.stderr) is False + assert ssh_output.stdout is not None + assert isinstance(ssh_output.stdout, str) + + + async def get_stored_secrets( + self, + test_client: ClientData, + ssh_command_runner: CommandRunner, + ) -> list[str]: + """Get stored secrets.""" + ssh_output = await ssh_command_runner(test_client, "list") + assert bool(ssh_output.stderr) is False + assert ssh_output.stdout is not None + assert isinstance(ssh_output.stdout, str) + + list_output = ssh_output.stdout + print(list_output) + input_secret_names = list_output.splitlines() + return sorted(input_secret_names) + + + @pytest.mark.asyncio + async def test_delete_secret( + self, + secret_name: str, + test_client: ClientData, + ssh_command_runner: CommandRunner, + ) -> None: + """Delete secret.""" + current_secrets = await self.get_stored_secrets(test_client, ssh_command_runner) + assert secret_name in current_secrets + ssh_output = await ssh_command_runner(test_client, f"delete {secret_name}") + assert bool(ssh_output.stderr) is False + current_secrets = await self.get_stored_secrets(test_client, ssh_command_runner) + assert secret_name not in current_secrets