From a07fba9560a36574552a7d92a199521c7fb5a5e6 Mon Sep 17 00:00:00 2001 From: Allan Eising Date: Sun, 11 May 2025 11:22:20 +0200 Subject: [PATCH] Fix audit as async, function name --- .../sshecret-sshd/src/sshecret_sshd/cli.py | 4 ++-- .../src/sshecret_sshd/ssh_server.py | 22 +++++++++---------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/packages/sshecret-sshd/src/sshecret_sshd/cli.py b/packages/sshecret-sshd/src/sshecret_sshd/cli.py index 1416ead..d46283e 100644 --- a/packages/sshecret-sshd/src/sshecret_sshd/cli.py +++ b/packages/sshecret-sshd/src/sshecret_sshd/cli.py @@ -8,7 +8,7 @@ from typing import cast import click from pydantic import ValidationError from .settings import ServerSettings -from .ssh_server import start_server +from .ssh_server import start_sshecret_sshd LOG = logging.getLogger() @@ -51,7 +51,7 @@ def cli_run(ctx: click.Context, host: str | None, port: int | None) -> None: settings.port = port loop = asyncio.new_event_loop() - loop.run_until_complete(start_server(settings)) + loop.run_until_complete(start_sshecret_sshd(settings)) title = click.style("Sshecret SSH Daemon", fg="red", bold=True) click.echo(f"Starting {title}: {settings.listen_address}:{settings.port}") try: diff --git a/packages/sshecret-sshd/src/sshecret_sshd/ssh_server.py b/packages/sshecret-sshd/src/sshecret_sshd/ssh_server.py index 792945c..6ba444b 100644 --- a/packages/sshecret-sshd/src/sshecret_sshd/ssh_server.py +++ b/packages/sshecret-sshd/src/sshecret_sshd/ssh_server.py @@ -32,7 +32,7 @@ class CommandError(Exception): """Error class for errors during command processing.""" -def audit_process( +async def audit_process( backend: SshecretBackend, process: asyncssh.SSHServerProcess[str], operation: Operation, @@ -54,12 +54,12 @@ def audit_process( data["command"] = cmd data["args"] = " ".join(cmd_args) - backend.audit(SubSystem.SSHD).write( + await backend.audit(SubSystem.SSHD).write_async( operation, message, remote_ip, client, secret=None, secret_name=secret, **data ) -def audit_event( +async def audit_event( backend: SshecretBackend, message: str, operation: Operation, @@ -70,7 +70,7 @@ def audit_event( """Add an audit event.""" if not origin: origin = "UNKNOWN" - backend.audit(SubSystem.SSHD).write( + await backend.audit(SubSystem.SSHD).write_async( operation, message, origin, client, secret=None, secret_name=secret ) @@ -187,7 +187,7 @@ async def register_client( key = asyncssh.import_public_key(public_key) if key.algorithm.decode() != "ssh-rsa": raise CommandError(constants.ERROR_INVALID_KEY_TYPE) - audit_process(backend, process, Operation.CREATE, "Registering new client") + await audit_process(backend, process, Operation.CREATE, "Registering new client") LOG.debug("Registering client %s with public key %s", username, public_key) await backend.create_client(username, public_key) @@ -205,7 +205,7 @@ async def get_secret( if secret_name not in client.secrets: raise CommandError(constants.ERROR_NO_SECRET_FOUND) - audit_event( + await audit_event( backend, "Client requested secret", operation=Operation.READ, @@ -247,7 +247,7 @@ async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None allowed_networks = get_info_allowed_registration(process) if not allowed_networks: process.stdout.write("Unauthorized.\n") - audit_process( + await audit_process( backend, process, Operation.DENY, @@ -266,7 +266,7 @@ async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None if client_address in network: break else: - audit_process( + await audit_process( backend, process, Operation.DENY, @@ -390,7 +390,7 @@ class AsshyncServer(asyncssh.SSHServer): self._conn.set_extra_info(client=client) self._conn.set_authorized_keys(key) else: - audit_event( + await audit_event( self.backend, "Client denied due to policy", Operation.DENY, @@ -492,7 +492,7 @@ async def run_ssh_server( return server -async def start_server(settings: ServerSettings | None = None) -> None: +async def start_sshecret_sshd(settings: ServerSettings | None = None) -> asyncssh.SSHAcceptor: """Start the server.""" server_key = get_server_key() @@ -500,7 +500,7 @@ async def start_server(settings: ServerSettings | None = None) -> None: settings = ServerSettings() # pyright: ignore[reportCallIssue] backend = SshecretBackend(str(settings.backend_url), settings.backend_token) - await run_ssh_server( + return await run_ssh_server( backend=backend, listen_address=settings.listen_address, port=settings.port,