Fix various small bugs

This commit is contained in:
2025-05-12 07:47:11 +02:00
parent a2ec2173ac
commit 458863de3d
6 changed files with 61 additions and 71 deletions

View File

@ -43,7 +43,10 @@ class PasswordContext:
) )
if entry and overwrite: if entry and overwrite:
entry.password = secret entry.password = secret
elif entry: self.keepass.save()
return
if entry:
raise ValueError("Error: A secret with this name already exists.") raise ValueError("Error: A secret with this name already exists.")
LOG.debug("Add secret entry to keepass: %s", entry_name) LOG.debug("Add secret entry to keepass: %s", entry_name)
entry = self.keepass.add_entry( entry = self.keepass.add_entry(

View File

@ -85,7 +85,7 @@ class SecretUpdate(BaseModel):
""" """
if isinstance(self.value, str): if isinstance(self.value, str):
return self.value return self.value
secret = secrets.token_urlsafe(self.value.length) secret = secrets.token_urlsafe(32)[:self.value.length]
return secret return secret

View File

@ -1,40 +1,17 @@
"""Testing helper functions.""" """Testing helper functions.
This allows creation of a user from within tests.
"""
import os import os
import bcrypt import bcrypt
from sqlmodel import Session
from sshecret_admin.auth.models import User from sshecret_admin.auth.models import User
def get_test_user_details() -> tuple[str, str]:
"""Resolve testing user."""
test_user = os.getenv("SSHECRET_TEST_USERNAME") or "test"
test_password = os.getenv("SSHECRET_TEST_PASSWORD") or "test"
if test_user and test_password:
return (test_user, test_password)
raise RuntimeError(
"Error: No testing username and password registered in environment."
)
def is_testing_mode() -> bool:
"""Check if we're running in test mode.
We will determine this by looking for the environment variable SSHECRET_TEST_MODE=1
"""
if os.environ.get("PYTEST_VERSION") is not None:
return True
return False
def create_test_user(session: Session, username: str, password: str) -> User: def create_test_user(session: Session, username: str, password: str) -> User:
"""Create test user. """Create test user."""
We create a user with whatever username and password is supplied.
"""
salt = bcrypt.gensalt() salt = bcrypt.gensalt()
hashed_password = bcrypt.hashpw(password.encode(), salt) hashed_password = bcrypt.hashpw(password.encode(), salt)
user = User(username=username, hashed_password=hashed_password.decode()) user = User(username=username, hashed_password=hashed_password.decode())

View File

@ -1,2 +0,0 @@
def hello() -> str:
return "Hello from sshecret-sshd!"

View File

@ -6,7 +6,7 @@ ERROR_SOURCE_IP_NOT_ALLOWED = (
) )
ERROR_NO_PUBLIC_KEY = "Error: No valid public key received." ERROR_NO_PUBLIC_KEY = "Error: No valid public key received."
ERROR_INVALID_KEY_TYPE = "Error: Invalid key type: Only RSA keys are supported." ERROR_INVALID_KEY_TYPE = "Error: Invalid key type: Only RSA keys are supported."
ERROR_UNKNOWN_COMMAND = "Error: The given command was not understood." ERROR_UNKNOWN_COMMAND = "Error: Unsupported command."
SERVER_KEY_TYPE = "ed25519" SERVER_KEY_TYPE = "ed25519"
ERROR_BACKEND_ERROR = "Error: Unexpected response or error from backend" ERROR_BACKEND_ERROR = "Error: Unexpected response or error from backend"
ERROR_INFO_BACKEND_GONE = "Unexpected error: Backend connection lost." ERROR_INFO_BACKEND_GONE = "Unexpected error: Backend connection lost."

View File

@ -1,5 +1,6 @@
"""SSH Server implementation.""" """SSH Server implementation."""
from asyncio import _register_task
import logging import logging
import asyncssh import asyncssh
@ -66,12 +67,13 @@ async def audit_event(
client: Client | None = None, client: Client | None = None,
origin: str | None = None, origin: str | None = None,
secret: str | None = None, secret: str | None = None,
**data: str,
) -> None: ) -> None:
"""Add an audit event.""" """Add an audit event."""
if not origin: if not origin:
origin = "UNKNOWN" origin = "UNKNOWN"
await backend.audit(SubSystem.SSHD).write_async( await backend.audit(SubSystem.SSHD).write_async(
operation, message, origin, client, secret=None, secret_name=secret operation, message, origin, client, secret=None, secret_name=secret, **data
) )
@ -158,22 +160,14 @@ async def get_stdin_public_key(process: asyncssh.SSHServerProcess[str]) -> str |
public_key = verify_key_input(line.rstrip("\n")) public_key = verify_key_input(line.rstrip("\n"))
if public_key: if public_key:
break break
process.stdout.write("Invalid key. Must be RSA Public Key.\n") raise CommandError(constants.ERROR_INVALID_KEY_TYPE)
except asyncssh.BreakReceived: except asyncssh.BreakReceived:
pass pass
else:
process.stdout.write("OK\n") process.stdout.write("OK\n")
return public_key return public_key
def get_info_user_and_public_key(
process: asyncssh.SSHServerProcess[str],
) -> tuple[str | None, str | None]:
"""Get username and public_key from process."""
username = cast("str | None", process.get_extra_info("provided_username", None))
public_key = cast("str | None", process.get_extra_info("provided_key", None))
return (username, public_key)
async def register_client( async def register_client(
process: asyncssh.SSHServerProcess[str], process: asyncssh.SSHServerProcess[str],
backend: SshecretBackend, backend: SshecretBackend,
@ -381,8 +375,14 @@ class AsshyncServer(asyncssh.SSHServer):
""" """
LOG.debug("Started authentication flow for user %s", username) LOG.debug("Started authentication flow for user %s", username)
if not self._conn: allowed_registration_sources: list[IPvAnyNetwork] = []
return True if self.registration_enabled and not self.allow_registration_from:
allowed_registration_sources.append(ipaddress.IPv4Network("0.0.0.0/0"))
allowed_registration_sources.append(ipaddress.IPv6Network("::/0"))
elif self.registration_enabled and self.allow_registration_from:
allowed_registration_sources = self.allow_registration_from
assert self._conn is not None, "Error: No connection found."
if client := await self.backend.get_client(username): if client := await self.backend.get_client(username):
LOG.debug("Client lookup sucessful: %r", client) LOG.debug("Client lookup sucessful: %r", client)
if key := self.resolve_client_key(client): if key := self.resolve_client_key(client):
@ -397,33 +397,43 @@ class AsshyncServer(asyncssh.SSHServer):
client, client,
origin=self.client_ip, origin=self.client_ip,
) )
LOG.warning("Client connection denied due to policy.") LOG.warning(
elif self.registration_enabled: "Client connection denied. Source: %s, policy: %r.",
self.client_ip,
client.policies,
)
elif allowed_registration_sources and self.client_ip:
client_ip = ipaddress.ip_address(self.client_ip)
for network in allowed_registration_sources:
if client_ip.version != network.version:
continue
if client_ip in network:
self._conn.set_extra_info(provided_username=username) self._conn.set_extra_info(provided_username=username)
self._conn.set_extra_info( self._conn.set_extra_info(
allow_registration_from=self.allow_registration_from allow_registration_from=self.allow_registration_from
) )
LOG.warning( LOG.info(
"Registration enabled, and client is not recognized. Bypassing authentication." "Registration enabled, and client is not recognized. Bypassing authentication."
) )
return False return False
else:
await audit_event(
self.backend,
"Received registration command from unauthorized subnet.",
Operation.DENY,
origin=self.client_ip,
username=username,
)
LOG.warning(
"Registration not permitted for username=%s, origin: %s",
username,
self.client_ip,
)
LOG.debug("Continuing to regular authentication") LOG.debug("Continuing to regular authentication")
return True return True
@override
def validate_public_key(self, username: str, key: asyncssh.SSHKey) -> bool:
"""Intercept public key validation."""
if not self._conn:
return False
# get an export of the provided public key.
keystring = key.export_public_key().decode()
self._conn.set_extra_info(provided_username=username)
self._conn.set_extra_info(provided_key=keystring)
LOG.debug("Intercepting user public key")
return False
def resolve_client_key(self, client: Client) -> asyncssh.SSHAuthorizedKeys | None: def resolve_client_key(self, client: Client) -> asyncssh.SSHAuthorizedKeys | None:
"""Resolve the client key. """Resolve the client key.
@ -492,7 +502,9 @@ async def run_ssh_server(
return server return server
async def start_sshecret_sshd(settings: ServerSettings | None = None) -> asyncssh.SSHAcceptor: async def start_sshecret_sshd(
settings: ServerSettings | None = None,
) -> asyncssh.SSHAcceptor:
"""Start the server.""" """Start the server."""
server_key = get_server_key() server_key = get_server_key()