Reformat and lint
This commit is contained in:
@ -1,4 +1,5 @@
|
||||
"""CLI app."""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import sys
|
||||
@ -12,7 +13,9 @@ from .ssh_server import start_server
|
||||
LOG = logging.getLogger()
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter("%(created)f:%(levelname)s:%(name)s:%(module)s:%(message)s")
|
||||
formatter = logging.Formatter(
|
||||
"%(created)f:%(levelname)s:%(name)s:%(module)s:%(message)s"
|
||||
)
|
||||
|
||||
handler.setFormatter(formatter)
|
||||
LOG.addHandler(handler)
|
||||
|
||||
@ -11,4 +11,6 @@ SERVER_KEY_TYPE = "ed25519"
|
||||
ERROR_BACKEND_ERROR = "Error: Unexpected response or error from backend"
|
||||
ERROR_INFO_BACKEND_GONE = "Unexpected error: Backend connection lost."
|
||||
ERROR_INFO_USERNAME_GONE = "Unexpected error: Username lost in transit."
|
||||
ERROR_INFO_REMOTE_IP_GONE = "Unexpected error: Client connection details lost in transit."
|
||||
ERROR_INFO_REMOTE_IP_GONE = (
|
||||
"Unexpected error: Client connection details lost in transit."
|
||||
)
|
||||
|
||||
@ -13,9 +13,11 @@ class ClientRegistrationSettings(BaseModel):
|
||||
"""Client registration settings."""
|
||||
|
||||
enabled: bool = False
|
||||
allow_from: Annotated[list[IPvAnyNetwork], ForceDecode] = Field(default_factory=list)
|
||||
allow_from: Annotated[list[IPvAnyNetwork], ForceDecode] = Field(
|
||||
default_factory=list
|
||||
)
|
||||
|
||||
@field_validator('allow_from', mode="before")
|
||||
@field_validator("allow_from", mode="before")
|
||||
@classmethod
|
||||
def ensure_allow_from_list(cls, value: Any) -> list[IPvAnyNetwork]:
|
||||
"""Convert allow_from to a list."""
|
||||
@ -34,15 +36,20 @@ class ClientRegistrationSettings(BaseModel):
|
||||
allow_from.append(entry)
|
||||
return allow_from
|
||||
|
||||
|
||||
class ServerSettings(BaseSettings):
|
||||
"""Server Settings."""
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".sshd.env", env_prefix="sshecret_sshd_", env_nested_delimiter='_')
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".sshd.env", env_prefix="sshecret_sshd_", env_nested_delimiter="_"
|
||||
)
|
||||
|
||||
backend_url: AnyHttpUrl = Field(alias="sshecret_backend_url")
|
||||
backend_token: str
|
||||
listen_address: str = Field(default="127.0.0.1")
|
||||
port: int = DEFAULT_LISTEN_PORT
|
||||
registration: ClientRegistrationSettings = Field(default_factory=ClientRegistrationSettings)
|
||||
registration: ClientRegistrationSettings = Field(
|
||||
default_factory=ClientRegistrationSettings
|
||||
)
|
||||
debug: bool = False
|
||||
enable_ping_command: bool = False
|
||||
|
||||
@ -54,7 +54,10 @@ def audit_process(
|
||||
data["command"] = cmd
|
||||
data["args"] = " ".join(cmd_args)
|
||||
|
||||
backend.audit(SubSystem.SSHD).write(operation, message, remote_ip, client, secret=None, secret_name=secret, **data)
|
||||
backend.audit(SubSystem.SSHD).write(
|
||||
operation, message, remote_ip, client, secret=None, secret_name=secret, **data
|
||||
)
|
||||
|
||||
|
||||
def audit_event(
|
||||
backend: SshecretBackend,
|
||||
@ -67,7 +70,10 @@ def audit_event(
|
||||
"""Add an audit event."""
|
||||
if not origin:
|
||||
origin = "UNKNOWN"
|
||||
backend.audit(SubSystem.SSHD).write(operation, message, origin, client, secret=None, secret_name=secret)
|
||||
backend.audit(SubSystem.SSHD).write(
|
||||
operation, message, origin, client, secret=None, secret_name=secret
|
||||
)
|
||||
|
||||
|
||||
def verify_key_input(public_key: str) -> str | None:
|
||||
"""Verify key input."""
|
||||
@ -118,14 +124,19 @@ def get_info_remote_ip(process: asyncssh.SSHServerProcess[str]) -> str | None:
|
||||
|
||||
return remote_ip
|
||||
|
||||
def get_info_allowed_registration(process: asyncssh.SSHServerProcess[str]) -> list[IPvAnyNetwork] | None:
|
||||
|
||||
def get_info_allowed_registration(
|
||||
process: asyncssh.SSHServerProcess[str],
|
||||
) -> list[IPvAnyNetwork] | None:
|
||||
"""Get allowed networks to allow registration from."""
|
||||
|
||||
allowed_registration = cast(list[IPvAnyNetwork] | None, process.get_extra_info("allow_registration_from", None))
|
||||
allowed_registration = cast(
|
||||
list[IPvAnyNetwork] | None,
|
||||
process.get_extra_info("allow_registration_from", None),
|
||||
)
|
||||
return allowed_registration
|
||||
|
||||
|
||||
|
||||
def get_optional_commands(process: asyncssh.SSHServerProcess[str]) -> dict[str, bool]:
|
||||
"""Get optional command state."""
|
||||
with_registration = cast(
|
||||
@ -236,7 +247,12 @@ async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None
|
||||
allowed_networks = get_info_allowed_registration(process)
|
||||
if not allowed_networks:
|
||||
process.stdout.write("Unauthorized.\n")
|
||||
audit_process(backend, process, Operation.DENY, "Received registration command, but no subnets are allowed.")
|
||||
audit_process(
|
||||
backend,
|
||||
process,
|
||||
Operation.DENY,
|
||||
"Received registration command, but no subnets are allowed.",
|
||||
)
|
||||
return
|
||||
|
||||
remote_ip = get_info_remote_ip(process)
|
||||
@ -250,7 +266,12 @@ async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None
|
||||
if client_address in network:
|
||||
break
|
||||
else:
|
||||
audit_process(backend, process, Operation.DENY, "Received registration command from unauthorized subnet.")
|
||||
audit_process(
|
||||
backend,
|
||||
process,
|
||||
Operation.DENY,
|
||||
"Received registration command from unauthorized subnet.",
|
||||
)
|
||||
process.stdout.write("Unauthorized.\n")
|
||||
return
|
||||
|
||||
@ -369,7 +390,6 @@ class AsshyncServer(asyncssh.SSHServer):
|
||||
self._conn.set_extra_info(client=client)
|
||||
self._conn.set_authorized_keys(key)
|
||||
else:
|
||||
|
||||
audit_event(
|
||||
self.backend,
|
||||
"Client denied due to policy",
|
||||
@ -380,8 +400,12 @@ class AsshyncServer(asyncssh.SSHServer):
|
||||
LOG.warning("Client connection denied due to policy.")
|
||||
elif self.registration_enabled:
|
||||
self._conn.set_extra_info(provided_username=username)
|
||||
self._conn.set_extra_info(allow_registration_from=self.allow_registration_from)
|
||||
LOG.warning("Registration enabled, and client is not recognized. Bypassing authentication.")
|
||||
self._conn.set_extra_info(
|
||||
allow_registration_from=self.allow_registration_from
|
||||
)
|
||||
LOG.warning(
|
||||
"Registration enabled, and client is not recognized. Bypassing authentication."
|
||||
)
|
||||
return False
|
||||
|
||||
LOG.debug("Continuing to regular authentication")
|
||||
|
||||
Reference in New Issue
Block a user