Implement podman-compatible commands
All checks were successful
Build and push image / build-containers (push) Successful in 8m46s

This commit is contained in:
2025-06-10 10:28:17 +02:00
parent 782ec19137
commit 0eaa913e35
8 changed files with 414 additions and 61 deletions

View File

@ -16,7 +16,7 @@ from sshecret.backend import (
Operation, Operation,
SubSystem, SubSystem,
) )
from sshecret.backend.models import DetailedSecrets, Secret from sshecret.backend.models import DetailedSecrets
from sshecret.backend.api import AuditAPI, KeySpec from sshecret.backend.api import AuditAPI, KeySpec
from sshecret.crypto import encrypt_string, load_public_key from sshecret.crypto import encrypt_string, load_public_key

View File

@ -130,6 +130,7 @@ class ClientOperations:
if not db_client: if not db_client:
raise HTTPException(status_code=404, detail="Client not found.") raise HTTPException(status_code=404, detail="Client not found.")
if db_client.is_deleted: if db_client.is_deleted:
LOG.warning("Client %r was already deleted!", client)
return return
db_client.is_deleted = True db_client.is_deleted = True
db_client.deleted_at = datetime.now(timezone.utc) db_client.deleted_at = datetime.now(timezone.utc)
@ -271,7 +272,12 @@ async def get_clients(
filter_query: ClientListParams, filter_query: ClientListParams,
) -> ClientQueryResult: ) -> ClientQueryResult:
"""Get Clients.""" """Get Clients."""
count_statement = select(func.count("*")).select_from(Client).where(Client.is_deleted.is_not(True)).where(Client.is_active.is_not(False)) count_statement = (
select(func.count("*"))
.select_from(Client)
.where(Client.is_deleted.is_not(True))
.where(Client.is_active.is_not(False))
)
count_statement = cast( count_statement = cast(
Select[tuple[int]], Select[tuple[int]],
filter_client_statement(count_statement, filter_query, True), filter_client_statement(count_statement, filter_query, True),

View File

@ -40,7 +40,9 @@ class ClientView(BaseModel):
return responses return responses
@classmethod @classmethod
def from_client(cls, client: models.Client) -> Self: def from_client(
cls, client: models.Client, include_deleted_secrets: bool = False
) -> Self:
"""Instantiate from a client.""" """Instantiate from a client."""
view = cls( view = cls(
id=client.id, id=client.id,
@ -54,7 +56,12 @@ class ClientView(BaseModel):
is_deleted=client.is_deleted, is_deleted=client.is_deleted,
) )
if client.secrets: if client.secrets:
if include_deleted_secrets:
view.secrets = [secret.name for secret in client.secrets] view.secrets = [secret.name for secret in client.secrets]
else:
view.secrets = [
secret.name for secret in client.secrets if not secret.deleted
]
if client.policies: if client.policies:
view.policies = [policy.source for policy in client.policies] view.policies = [policy.source for policy in client.policies]

View File

@ -116,24 +116,18 @@ async def resolve_client_id(
return None return None
async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | None: async def get_client_by_id(
session: AsyncSession, id: uuid.UUID, include_deleted: bool = False
) -> Client | None:
"""Get client by ID.""" """Get client by ID."""
if include_deleted:
client_filter = client_with_relationships().where(Client.id == id) client_filter = client_with_relationships().where(Client.id == id)
else:
client_filter = query_active_clients().where(Client.id == id)
client_results = await session.execute(client_filter) client_results = await session.execute(client_filter)
return client_results.scalars().first() return client_results.scalars().first()
async def get_client_by_id_or_name(
session: AsyncSession, id_or_name: str
) -> Client | None:
"""Get client either by id or name."""
if RE_UUID.match(id_or_name):
id = uuid.UUID(id_or_name)
return await get_client_by_id(session, id)
return await get_client_by_name(session, id_or_name)
def query_active_clients() -> Select[tuple[Client]]: def query_active_clients() -> Select[tuple[Client]]:
"""Get all active clients.""" """Get all active clients."""
client_filter = ( client_filter = (
@ -144,22 +138,6 @@ def query_active_clients() -> Select[tuple[Client]]:
return client_filter return client_filter
async def get_client_by_name(session: AsyncSession, name: str) -> Client | None:
"""Get client by name.
This will get the latest client version, unless it's deleted.
"""
client_filter = (
client_with_relationships()
.where(Client.is_active.is_(True))
.where(Client.is_deleted.is_not(True))
.where(Client.name == name)
.order_by(Client.version.desc())
)
client_result = await session.execute(client_filter)
return client_result.scalars().first()
async def refresh_client(session: AsyncSession, client: Client) -> None: async def refresh_client(session: AsyncSession, client: Client) -> None:
"""Refresh the client and load in all relationships.""" """Refresh the client and load in all relationships."""
await session.refresh( await session.refresh(

View File

@ -182,8 +182,10 @@ class ClientSecretOperations:
async def delete_client_secret(self, secret_identifier: FlexID) -> None: async def delete_client_secret(self, secret_identifier: FlexID) -> None:
"""Delete a client secret.""" """Delete a client secret."""
LOG.debug("delete_client_secret called with identifier %r", secret_identifier)
client_secret = await self._get_client_secret(secret_identifier) client_secret = await self._get_client_secret(secret_identifier)
if not client_secret: if not client_secret:
LOG.warning("Could not find any secret matching client secret.")
return return
client_secret.deleted = True client_secret.deleted = True

View File

@ -7,14 +7,19 @@ import logging
from typing import cast, final, override from typing import cast, final, override
import asyncssh import asyncssh
from sshecret_sshd import constants, exceptions
from sshecret_sshd import exceptions, constants
from .base import CommandDispatcher from .base import CommandDispatcher
from .get_secret import GetSecret from .get_secret import GetSecret
from .register import Register
from .list_secrets import ListSecrets from .list_secrets import ListSecrets
from .ping import PingCommand from .ping import PingCommand
from .register import Register
from .shelldriver import (
ShellDeleteSecret,
ShellListSecrets,
ShellLookupSecret,
ShellStoreSecret,
)
SYNOPSIS = """[bold]Sshecret SSH Server[/bold] SYNOPSIS = """[bold]Sshecret SSH Server[/bold]
@ -29,9 +34,13 @@ encoded as base64.
COMMANDS = [ COMMANDS = [
GetSecret, GetSecret,
Register,
ListSecrets, ListSecrets,
PingCommand, PingCommand,
Register,
ShellDeleteSecret,
ShellListSecrets,
ShellLookupSecret,
ShellStoreSecret,
] ]
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)

View File

@ -0,0 +1,163 @@
"""Podman Shelldriver compatible commands."""
import logging
from typing import final, override
import asyncssh
from sshecret.backend.models import Operation
from sshecret.crypto import encrypt_string, load_public_key
from .base import CommandDispatcher
LOG = logging.getLogger(__name__)
# These error messages are taken verbatim from podman, and while they don't seem
# to make complete sense, they will be used regardless.
ERR_SECRET_NOT_FOUND = "no such secret"
ERR_SECRET_EXISTS = "secret data with ID already exists"
ERR_INVALID_SECRET = "invalid key"
@final
class ShellListSecrets(CommandDispatcher):
"""List secrets.
This command lists secrets in a format compatible with podman's ShellDriver.
"""
name = "list"
@override
async def exec(self) -> None:
"""List secrets."""
LOG.debug("ShellListSecret called.")
await self.audit(Operation.READ, "Listed available secret names")
for secret_name in self.client.secrets:
self.print(secret_name)
@final
class ShellDeleteSecret(CommandDispatcher):
"""Delete a secret.
If the identifier for a secret does not exist, an error will be printed.
"""
name = "delete"
mandatory_argument = "KEY"
@override
async def exec(self) -> None:
"""Delete a secret."""
secret_name = self.arguments[0]
LOG.debug("ShellDeleteSecret called withg arguments %r.", self.arguments)
await self.audit(
operation=Operation.DELETE,
message="ClientSecret deleted",
secret=secret_name,
)
await self.backend.delete_client_secret(
("id", str(self.client.id)), ("name", secret_name)
)
@final
class ShellLookupSecret(CommandDispatcher):
"""Look up a secret.
The identifier for the secret must be provided as the argument.
"""
name = "lookup"
mandatory_argument = "KEY"
@override
async def exec(self) -> None:
"""Lookup secret."""
LOG.debug("ShellLookupSecret called with arguments %r", self.arguments)
secret_name = self.arguments[0]
secret = await self.backend.get_client_secret(
("id", str(self.client.id)), secret_name
)
if not secret:
LOG.debug(
"Secret %s not found for client %s (%s)",
secret_name,
self.client.id,
self.client.name,
)
self.print(ERR_SECRET_NOT_FOUND, stderr=True)
return
await self.audit(
Operation.READ, message="Client requested secret", secret=secret_name
)
self.print(secret)
@final
class ShellStoreSecret(CommandDispatcher):
"""Store a secret.
Secret will be read from command argument, or via STDIN.
"""
name = "store"
mandatory_argument = "KEY"
@override
async def exec(self) -> None:
"""Store a secret."""
LOG.debug("ShellStoreSecret called with arguments %r", self.arguments)
secret_name = self.arguments[0]
if secret_name in self.client.secrets:
self.print(ERR_SECRET_EXISTS, stderr=True)
return
secret_data: str | None = None
if len(self.arguments) == 2:
secret_data = self.arguments[1]
if not secret_data:
LOG.debug("No secret set as input, trying stdin.")
secret_data = await self.get_secret_on_stdin()
if not secret_data:
self.print(ERR_INVALID_SECRET, stderr=True)
return
# Encrypt secret
encrypted = self.encrypt_secret(secret_data)
await self.backend.create_client_secret(
("id", str(self.client.id)), secret_name, encrypted
)
await self.audit(
operation=Operation.CREATE,
message="Secret created from 'store' command",
secret=secret_name,
)
def encrypt_secret(self, value: str) -> str:
"""Encrypt a secret."""
public_key = load_public_key(self.client.public_key.encode())
return encrypt_string(value, public_key)
async def get_secret_on_stdin(self) -> str | None:
"""Get secret from stdin."""
secret_data = ""
try:
async for line in self.process.stdin:
if self.process.stdin.at_eof():
break
if not line:
break
secret_data += line.rstrip()
except asyncssh.BreakReceived:
pass
if not secret_data:
return None
return secret_data

View File

@ -3,8 +3,9 @@
This essentially also tests parts of the admin API. This essentially also tests parts of the admin API.
""" """
import asyncio
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import AsyncIterator from collections.abc import AsyncIterator
import os import os
import httpx import httpx
@ -17,7 +18,33 @@ from .clients import create_test_client, ClientData
from .types import CommandRunner, ProcessRunner from .types import CommandRunner, ProcessRunner
class TestSshd: class BaseSSHTests:
"""Base test class."""
async def register_client(
self, name: str, ssh_session: ProcessRunner
) -> ClientData:
"""Register client."""
test_client = create_test_client(name)
async with ssh_session(test_client, "register") as session:
maxlines = 10
linenum = 0
found = False
while linenum < maxlines:
line = await session.stdout.readline()
if "Enter public key" in line:
found = True
break
assert found is True
session.stdin.write(test_client.public_key + "\n")
result = await session.stdout.read()
assert "Key is valid. Registering client." in result
await session.wait()
return test_client
class TestSshd(BaseSSHTests):
"""Class based tests. """Class based tests.
This allows us to create small helpers. This allows us to create small helpers.
@ -52,30 +79,9 @@ class TestSshd:
client = clients[0] client = clients[0]
assert client.name == "new_client" assert client.name == "new_client"
async def register_client(
self, name: str, ssh_session: ProcessRunner
) -> ClientData:
"""Register client."""
test_client = create_test_client(name)
async with ssh_session(test_client, "register") as session:
maxlines = 10
linenum = 0
found = False
while linenum < maxlines:
line = await session.stdout.readline()
if "Enter public key" in line:
found = True
break
assert found is True
session.stdin.write(test_client.public_key + "\n")
result = await session.stdout.read()
assert "Key is valid. Registering client." in result
await session.wait()
return test_client
class TestSshdIntegration(TestSshd): class TestSshdIntegration(BaseSSHTests):
"""Integration tests.""" """Integration tests."""
@pytest.mark.asyncio @pytest.mark.asyncio
@ -137,3 +143,185 @@ class TestSshdIntegration(TestSshd):
base_url=url, headers={"Authorization": f"Bearer {token}"} base_url=url, headers={"Authorization": f"Bearer {token}"}
) as client: ) as client:
yield client yield client
class TestShelldriverCommands(BaseSSHTests):
"""Shelldriver command tests."""
@pytest.mark.asyncio
async def test_store_and_lookup(
self,
backend_server: tuple[str, tuple[str, str]],
ssh_session: ProcessRunner,
ssh_command_runner: CommandRunner,
) -> None:
"""Log in."""
test_client = await self.register_client("myclient", ssh_session)
ssh_output = await ssh_command_runner(test_client, "store mysecret secretvalue")
assert bool(ssh_output.stderr) is False
# wait half a second
await asyncio.sleep(0.5)
ssh_output = await ssh_command_runner(test_client, "lookup mysecret")
assert bool(ssh_output.stderr) is False
assert ssh_output.stdout is not None
assert isinstance(ssh_output.stdout, str)
encrypted = ssh_output.stdout.rstrip()
decrypted = decode_string(encrypted, test_client.private_key)
assert decrypted == "secretvalue"
@pytest.mark.asyncio
async def test_store_and_lookup_stdin(
self,
backend_server: tuple[str, tuple[str, str]],
ssh_session: ProcessRunner,
ssh_command_runner: CommandRunner,
) -> None:
"""Test store and lookup, with password specification in stdin."""
test_client = await self.register_client("myclient", ssh_session)
async with ssh_session(test_client, "store insecret") as session:
session.stdin.write("testinput\n")
session.stdin.write_eof()
await session.stdin.wait_closed()
await asyncio.sleep(0.5)
ssh_output = await ssh_command_runner(test_client, "lookup insecret")
assert bool(ssh_output.stderr) is False
assert ssh_output.stdout is not None
assert isinstance(ssh_output.stdout, str)
encrypted = ssh_output.stdout.rstrip()
decrypted = decode_string(encrypted, test_client.private_key)
assert decrypted == "testinput"
@pytest.mark.parametrize("secret_name",["nonexistant", "blåbærgrød", "../../../etc/shadow"])
@pytest.mark.asyncio
async def test_invalid_lookup(
self,
secret_name: str,
ssh_command_runner: CommandRunner,
ssh_session: ProcessRunner,
) -> None:
"""Test lookup with invalid secret."""
test_client = await self.register_client("myclient", ssh_session)
ssh_output = await ssh_command_runner(test_client, f"lookup {secret_name}")
assert isinstance(ssh_output.stderr, str)
assert ssh_output.stderr.rstrip() == "no such secret"
class TestShelldriverListCommand(BaseSSHTests):
"""Tests for the list command."""
@pytest.fixture(name="secret_names")
def get_secret_names(self) -> list[str]:
"""Get secret names.
Sort of like a parametrize function.
"""
return ["foo", "bar", "abc123", "blåbærgrød"]
@pytest.fixture(name="test_client")
@pytest.mark.asyncio
async def create_test_client(
self,
ssh_session: ProcessRunner,
) -> ClientData:
"""Register a test client."""
return await self.register_client("listclient", ssh_session)
@pytest.fixture(autouse=True)
@pytest.mark.asyncio
async def create_data(
self,
secret_names: list[str],
test_client: ClientData,
ssh_command_runner: CommandRunner,
) -> None:
"""Create data for the test."""
for name in secret_names:
ssh_output = await ssh_command_runner(test_client, f"store {name} secretvalue")
assert bool(ssh_output.stderr) is False
assert ssh_output.stdout is not None
assert isinstance(ssh_output.stdout, str)
@pytest.mark.asyncio
async def test_list_secrets(
self,
secret_names: list[str],
test_client: ClientData,
ssh_command_runner: CommandRunner,
) -> None:
"""Test list command."""
ssh_output = await ssh_command_runner(test_client, "list")
assert bool(ssh_output.stderr) is False
assert ssh_output.stdout is not None
assert isinstance(ssh_output.stdout, str)
list_output = ssh_output.stdout
print(list_output)
input_secret_names = list_output.splitlines()
assert len(input_secret_names) == len(secret_names)
assert sorted(secret_names) == sorted(input_secret_names)
@pytest.mark.parametrize("secret_name", ["simple", "blåbærgrød"])
class TestShelldriverDeleteCommand(BaseSSHTests):
"""Tests for the list command."""
@pytest.fixture(name="test_client")
@pytest.mark.asyncio
async def create_test_client(
self,
ssh_session: ProcessRunner,
) -> ClientData:
"""Register a test client."""
return await self.register_client("listclient", ssh_session)
@pytest.fixture(autouse=True)
@pytest.mark.asyncio
async def create_data(
self,
secret_name: str,
test_client: ClientData,
ssh_command_runner: CommandRunner,
) -> None:
"""Create data for the test."""
ssh_output = await ssh_command_runner(test_client, f"store {secret_name} secretvalue")
assert bool(ssh_output.stderr) is False
assert ssh_output.stdout is not None
assert isinstance(ssh_output.stdout, str)
async def get_stored_secrets(
self,
test_client: ClientData,
ssh_command_runner: CommandRunner,
) -> list[str]:
"""Get stored secrets."""
ssh_output = await ssh_command_runner(test_client, "list")
assert bool(ssh_output.stderr) is False
assert ssh_output.stdout is not None
assert isinstance(ssh_output.stdout, str)
list_output = ssh_output.stdout
print(list_output)
input_secret_names = list_output.splitlines()
return sorted(input_secret_names)
@pytest.mark.asyncio
async def test_delete_secret(
self,
secret_name: str,
test_client: ClientData,
ssh_command_runner: CommandRunner,
) -> None:
"""Delete secret."""
current_secrets = await self.get_stored_secrets(test_client, ssh_command_runner)
assert secret_name in current_secrets
ssh_output = await ssh_command_runner(test_client, f"delete {secret_name}")
assert bool(ssh_output.stderr) is False
current_secrets = await self.get_stored_secrets(test_client, ssh_command_runner)
assert secret_name not in current_secrets