"""Authentication utilities."""

import logging
from datetime import datetime, timezone, timedelta
from typing import cast, Any

import bcrypt
import jwt
from sqlmodel import Session, select
from sshecret_admin.core.settings import AdminServerSettings

from .models import User, TokenData
from .exceptions import AuthenticationFailedError

JWT_ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
# Refresh tokens are normally long-lived, but 6 hours seems reasonable
# for a sensitive application.
REFRESH_TOKEN_EXPIRE_HOURS = 6

LOG = logging.getLogger(__name__)


def create_token(
    settings: AdminServerSettings,
    data: dict[str, Any],
    expires_delta: timedelta,
) -> str:
    """Create a signed JWT carrying ``data`` plus an expiry claim.

    Args:
        settings: Server settings; ``settings.secret_key`` signs the token.
        data: Claims to embed. The mapping is copied, not mutated.
        expires_delta: Lifetime of the token, counted from the current UTC time.

    Returns:
        The encoded token as a string.
    """
    to_encode = data.copy()
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=JWT_ALGORITHM)
    return str(encoded_jwt)


def create_access_token(
    settings: AdminServerSettings,
    data: dict[str, Any],
    expires_delta: timedelta | None = None,
) -> str:
    """Create an access token.

    Args:
        settings: Server settings; ``settings.secret_key`` signs the token.
        data: Claims to embed in the token.
        expires_delta: Token lifetime. When ``None``, defaults to
            ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

    Returns:
        The encoded token as a string.
    """
    # Compare against None explicitly: timedelta(0) is falsy, and an explicit
    # zero lifetime must not be silently replaced by the default.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    return create_token(settings, data, expires_delta)


def create_refresh_token(
    settings: AdminServerSettings,
    data: dict[str, Any],
    expires_delta: timedelta | None = None,
) -> str:
    """Create a refresh token.

    Args:
        settings: Server settings; ``settings.secret_key`` signs the token.
        data: Claims to embed in the token.
        expires_delta: Token lifetime. When ``None``, defaults to
            ``REFRESH_TOKEN_EXPIRE_HOURS`` hours.

    Returns:
        The encoded token as a string.
    """
    # Compare against None explicitly: timedelta(0) is falsy, and an explicit
    # zero lifetime must not be silently replaced by the default.
    if expires_delta is None:
        expires_delta = timedelta(hours=REFRESH_TOKEN_EXPIRE_HOURS)
    return create_token(settings, data, expires_delta)


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Verify password against stored hash.

    Both strings are encoded to bytes and compared with ``bcrypt.checkpw``.

    Returns:
        True if the password matches the hash, False otherwise.
    """
    return bcrypt.checkpw(plain_password.encode(), hashed_password.encode())


def check_password(plain_password: str, hashed_password: str) -> None:
    """Check password.

    Raises:
        AuthenticationFailedError: If the password doesn't match the hash.
    """
    if not verify_password(plain_password, hashed_password):
        raise AuthenticationFailedError()


def authenticate_user(session: Session, username: str, password: str) -> User | None:
    """Authenticate user.

    Looks up the first ``User`` whose ``username`` matches and verifies the
    password against its ``hashed_password``.

    Returns:
        The matching user, or ``None`` if no such user exists or the
        password does not verify.
    """
    user = session.exec(select(User).where(User.username == username)).first()
    if not user:
        return None
    if not verify_password(password, user.hashed_password):
        return None
    return user


def decode_token(settings: AdminServerSettings, token: str) -> TokenData | None:
    """Decode token.

    Decodes ``token`` with ``settings.secret_key`` and ``JWT_ALGORITHM``, then
    reads the ``sub`` claim as the username.

    Returns:
        A ``TokenData`` holding the username, or ``None`` if decoding raises
        ``jwt.InvalidTokenError`` or the ``sub`` claim is missing or empty.
    """
    try:
        payload = jwt.decode(token, settings.secret_key, algorithms=[JWT_ALGORITHM])
    except jwt.InvalidTokenError as e:
        LOG.debug("Could not decode token: %s", e, exc_info=True)
        return None

    username = cast("str | None", payload.get("sub"))
    if not username:
        return None
    return TokenData(username=username)