Reformat and lint

This commit is contained in:
2025-05-10 08:29:58 +02:00
parent 0a427b6a91
commit d866553ac1
9 changed files with 120 additions and 44 deletions

View File

@ -1,4 +1,5 @@
"""CLI app."""
import logging
import asyncio
import sys
@ -12,7 +13,9 @@ from .ssh_server import start_server
LOG = logging.getLogger()
handler = logging.StreamHandler()
formatter = logging.Formatter("%(created)f:%(levelname)s:%(name)s:%(module)s:%(message)s")
formatter = logging.Formatter(
"%(created)f:%(levelname)s:%(name)s:%(module)s:%(message)s"
)
handler.setFormatter(formatter)
LOG.addHandler(handler)

View File

@ -11,4 +11,6 @@ SERVER_KEY_TYPE = "ed25519"
ERROR_BACKEND_ERROR = "Error: Unexpected response or error from backend"
ERROR_INFO_BACKEND_GONE = "Unexpected error: Backend connection lost."
ERROR_INFO_USERNAME_GONE = "Unexpected error: Username lost in transit."
ERROR_INFO_REMOTE_IP_GONE = "Unexpected error: Client connection details lost in transit."
ERROR_INFO_REMOTE_IP_GONE = (
"Unexpected error: Client connection details lost in transit."
)

View File

@ -13,9 +13,11 @@ class ClientRegistrationSettings(BaseModel):
"""Client registration settings."""
enabled: bool = False
allow_from: Annotated[list[IPvAnyNetwork], ForceDecode] = Field(default_factory=list)
allow_from: Annotated[list[IPvAnyNetwork], ForceDecode] = Field(
default_factory=list
)
@field_validator('allow_from', mode="before")
@field_validator("allow_from", mode="before")
@classmethod
def ensure_allow_from_list(cls, value: Any) -> list[IPvAnyNetwork]:
"""Convert allow_from to a list."""
@ -34,15 +36,20 @@ class ClientRegistrationSettings(BaseModel):
allow_from.append(entry)
return allow_from
class ServerSettings(BaseSettings):
"""Server Settings."""
model_config = SettingsConfigDict(env_file=".sshd.env", env_prefix="sshecret_sshd_", env_nested_delimiter='_')
model_config = SettingsConfigDict(
env_file=".sshd.env", env_prefix="sshecret_sshd_", env_nested_delimiter="_"
)
backend_url: AnyHttpUrl = Field(alias="sshecret_backend_url")
backend_token: str
listen_address: str = Field(default="127.0.0.1")
port: int = DEFAULT_LISTEN_PORT
registration: ClientRegistrationSettings = Field(default_factory=ClientRegistrationSettings)
registration: ClientRegistrationSettings = Field(
default_factory=ClientRegistrationSettings
)
debug: bool = False
enable_ping_command: bool = False

View File

@ -54,7 +54,10 @@ def audit_process(
data["command"] = cmd
data["args"] = " ".join(cmd_args)
backend.audit(SubSystem.SSHD).write(operation, message, remote_ip, client, secret=None, secret_name=secret, **data)
backend.audit(SubSystem.SSHD).write(
operation, message, remote_ip, client, secret=None, secret_name=secret, **data
)
def audit_event(
backend: SshecretBackend,
@ -67,7 +70,10 @@ def audit_event(
"""Add an audit event."""
if not origin:
origin = "UNKNOWN"
backend.audit(SubSystem.SSHD).write(operation, message, origin, client, secret=None, secret_name=secret)
backend.audit(SubSystem.SSHD).write(
operation, message, origin, client, secret=None, secret_name=secret
)
def verify_key_input(public_key: str) -> str | None:
"""Verify key input."""
@ -118,14 +124,19 @@ def get_info_remote_ip(process: asyncssh.SSHServerProcess[str]) -> str | None:
return remote_ip
def get_info_allowed_registration(process: asyncssh.SSHServerProcess[str]) -> list[IPvAnyNetwork] | None:
def get_info_allowed_registration(
process: asyncssh.SSHServerProcess[str],
) -> list[IPvAnyNetwork] | None:
"""Get allowed networks to allow registration from."""
allowed_registration = cast(list[IPvAnyNetwork] | None, process.get_extra_info("allow_registration_from", None))
allowed_registration = cast(
list[IPvAnyNetwork] | None,
process.get_extra_info("allow_registration_from", None),
)
return allowed_registration
def get_optional_commands(process: asyncssh.SSHServerProcess[str]) -> dict[str, bool]:
"""Get optional command state."""
with_registration = cast(
@ -236,7 +247,12 @@ async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None
allowed_networks = get_info_allowed_registration(process)
if not allowed_networks:
process.stdout.write("Unauthorized.\n")
audit_process(backend, process, Operation.DENY, "Received registration command, but no subnets are allowed.")
audit_process(
backend,
process,
Operation.DENY,
"Received registration command, but no subnets are allowed.",
)
return
remote_ip = get_info_remote_ip(process)
@ -250,7 +266,12 @@ async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None
if client_address in network:
break
else:
audit_process(backend, process, Operation.DENY, "Received registration command from unauthorized subnet.")
audit_process(
backend,
process,
Operation.DENY,
"Received registration command from unauthorized subnet.",
)
process.stdout.write("Unauthorized.\n")
return
@ -369,7 +390,6 @@ class AsshyncServer(asyncssh.SSHServer):
self._conn.set_extra_info(client=client)
self._conn.set_authorized_keys(key)
else:
audit_event(
self.backend,
"Client denied due to policy",
@ -380,8 +400,12 @@ class AsshyncServer(asyncssh.SSHServer):
LOG.warning("Client connection denied due to policy.")
elif self.registration_enabled:
self._conn.set_extra_info(provided_username=username)
self._conn.set_extra_info(allow_registration_from=self.allow_registration_from)
LOG.warning("Registration enabled, and client is not recognized. Bypassing authentication.")
self._conn.set_extra_info(
allow_registration_from=self.allow_registration_from
)
LOG.warning(
"Registration enabled, and client is not recognized. Bypassing authentication."
)
return False
LOG.debug("Continuing to regular authentication")