Fix audit as async, function name

This commit is contained in:
2025-05-11 11:22:20 +02:00
parent d3d99775d9
commit a07fba9560
2 changed files with 13 additions and 13 deletions

View File

@ -8,7 +8,7 @@ from typing import cast
import click import click
from pydantic import ValidationError from pydantic import ValidationError
from .settings import ServerSettings from .settings import ServerSettings
from .ssh_server import start_server from .ssh_server import start_sshecret_sshd
LOG = logging.getLogger() LOG = logging.getLogger()
@ -51,7 +51,7 @@ def cli_run(ctx: click.Context, host: str | None, port: int | None) -> None:
settings.port = port settings.port = port
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
loop.run_until_complete(start_server(settings)) loop.run_until_complete(start_sshecret_sshd(settings))
title = click.style("Sshecret SSH Daemon", fg="red", bold=True) title = click.style("Sshecret SSH Daemon", fg="red", bold=True)
click.echo(f"Starting {title}: {settings.listen_address}:{settings.port}") click.echo(f"Starting {title}: {settings.listen_address}:{settings.port}")
try: try:

View File

@ -32,7 +32,7 @@ class CommandError(Exception):
"""Error class for errors during command processing.""" """Error class for errors during command processing."""
def audit_process( async def audit_process(
backend: SshecretBackend, backend: SshecretBackend,
process: asyncssh.SSHServerProcess[str], process: asyncssh.SSHServerProcess[str],
operation: Operation, operation: Operation,
@ -54,12 +54,12 @@ def audit_process(
data["command"] = cmd data["command"] = cmd
data["args"] = " ".join(cmd_args) data["args"] = " ".join(cmd_args)
backend.audit(SubSystem.SSHD).write( await backend.audit(SubSystem.SSHD).write_async(
operation, message, remote_ip, client, secret=None, secret_name=secret, **data operation, message, remote_ip, client, secret=None, secret_name=secret, **data
) )
def audit_event( async def audit_event(
backend: SshecretBackend, backend: SshecretBackend,
message: str, message: str,
operation: Operation, operation: Operation,
@ -70,7 +70,7 @@ def audit_event(
"""Add an audit event.""" """Add an audit event."""
if not origin: if not origin:
origin = "UNKNOWN" origin = "UNKNOWN"
backend.audit(SubSystem.SSHD).write( await backend.audit(SubSystem.SSHD).write_async(
operation, message, origin, client, secret=None, secret_name=secret operation, message, origin, client, secret=None, secret_name=secret
) )
@ -187,7 +187,7 @@ async def register_client(
key = asyncssh.import_public_key(public_key) key = asyncssh.import_public_key(public_key)
if key.algorithm.decode() != "ssh-rsa": if key.algorithm.decode() != "ssh-rsa":
raise CommandError(constants.ERROR_INVALID_KEY_TYPE) raise CommandError(constants.ERROR_INVALID_KEY_TYPE)
audit_process(backend, process, Operation.CREATE, "Registering new client") await audit_process(backend, process, Operation.CREATE, "Registering new client")
LOG.debug("Registering client %s with public key %s", username, public_key) LOG.debug("Registering client %s with public key %s", username, public_key)
await backend.create_client(username, public_key) await backend.create_client(username, public_key)
@ -205,7 +205,7 @@ async def get_secret(
if secret_name not in client.secrets: if secret_name not in client.secrets:
raise CommandError(constants.ERROR_NO_SECRET_FOUND) raise CommandError(constants.ERROR_NO_SECRET_FOUND)
audit_event( await audit_event(
backend, backend,
"Client requested secret", "Client requested secret",
operation=Operation.READ, operation=Operation.READ,
@ -247,7 +247,7 @@ async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None
allowed_networks = get_info_allowed_registration(process) allowed_networks = get_info_allowed_registration(process)
if not allowed_networks: if not allowed_networks:
process.stdout.write("Unauthorized.\n") process.stdout.write("Unauthorized.\n")
audit_process( await audit_process(
backend, backend,
process, process,
Operation.DENY, Operation.DENY,
@ -266,7 +266,7 @@ async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None
if client_address in network: if client_address in network:
break break
else: else:
audit_process( await audit_process(
backend, backend,
process, process,
Operation.DENY, Operation.DENY,
@ -390,7 +390,7 @@ class AsshyncServer(asyncssh.SSHServer):
self._conn.set_extra_info(client=client) self._conn.set_extra_info(client=client)
self._conn.set_authorized_keys(key) self._conn.set_authorized_keys(key)
else: else:
audit_event( await audit_event(
self.backend, self.backend,
"Client denied due to policy", "Client denied due to policy",
Operation.DENY, Operation.DENY,
@ -492,7 +492,7 @@ async def run_ssh_server(
return server return server
async def start_server(settings: ServerSettings | None = None) -> None: async def start_sshecret_sshd(settings: ServerSettings | None = None) -> asyncssh.SSHAcceptor:
"""Start the server.""" """Start the server."""
server_key = get_server_key() server_key = get_server_key()
@ -500,7 +500,7 @@ async def start_server(settings: ServerSettings | None = None) -> None:
settings = ServerSettings() # pyright: ignore[reportCallIssue] settings = ServerSettings() # pyright: ignore[reportCallIssue]
backend = SshecretBackend(str(settings.backend_url), settings.backend_token) backend = SshecretBackend(str(settings.backend_url), settings.backend_token)
await run_ssh_server( return await run_ssh_server(
backend=backend, backend=backend,
listen_address=settings.listen_address, listen_address=settings.listen_address,
port=settings.port, port=settings.port,