"""Authentication utilities."""

import logging
from datetime import datetime, timezone, timedelta
from typing import cast, Any

import bcrypt
from joserfc import jwt
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from joserfc.jwk import OctKey
from joserfc.errors import JoseError

from sshecret_admin.core.settings import AdminServerSettings

from .models import AuthProvider, LocalUserInfo, User, IdentityClaims
from .exceptions import AuthenticationFailedError
from .constants import (
    JWT_ALGORITHM,
    ACCESS_TOKEN_EXPIRE_MINUTES,
    REFRESH_TOKEN_EXPIRE_HOURS,
    LOCAL_ISSUER,
)

LOG = logging.getLogger(__name__)


def _signing_key(settings: AdminServerSettings) -> OctKey:
    """Build the symmetric key used to sign and verify tokens."""
    return OctKey.import_key(settings.secret_key)


def create_token(
    settings: AdminServerSettings,
    data: dict[str, Any],
    expires_delta: timedelta,
    provider: str,
) -> str:
    """Create a signed JWT.

    ``data`` is copied; ``exp`` (now in UTC + ``expires_delta``) and ``iss``
    (``provider``) are then set on the copy, overriding any values already
    present, and the result is signed with ``JWT_ALGORITHM``.

    Args:
        settings: Server settings; ``secret_key`` is the signing secret.
        data: Claims to embed in the token. Not mutated.
        expires_delta: Lifetime of the token.
        provider: Value written to the ``iss`` claim.

    Returns:
        The encoded token as a string.
    """
    to_encode = data.copy()
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire, "iss": provider})
    encoded_jwt = jwt.encode({"alg": JWT_ALGORITHM}, to_encode, _signing_key(settings))
    return str(encoded_jwt)


def create_access_token(
    settings: AdminServerSettings,
    data: dict[str, Any],
    expires_delta: timedelta | None = None,
    provider: str = LOCAL_ISSUER,
) -> str:
    """Create an access token.

    When ``expires_delta`` is omitted (``None``) the lifetime defaults to
    ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes. An explicit ``timedelta`` is
    always honoured, including ``timedelta(0)``.
    """
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    return create_token(settings, data, expires_delta, provider)


def create_refresh_token(
    settings: AdminServerSettings,
    data: dict[str, Any],
    expires_delta: timedelta | None = None,
    provider: str = LOCAL_ISSUER,
) -> str:
    """Create a refresh token.

    When ``expires_delta`` is omitted (``None``) the lifetime defaults to
    ``REFRESH_TOKEN_EXPIRE_HOURS`` hours. An explicit ``timedelta`` is
    always honoured, including ``timedelta(0)``.
    """
    if expires_delta is None:
        expires_delta = timedelta(hours=REFRESH_TOKEN_EXPIRE_HOURS)
    return create_token(settings, data, expires_delta, provider)


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True if ``plain_password`` matches the bcrypt ``hashed_password``.

    A stored hash that bcrypt rejects as malformed (``ValueError``) is
    treated as a non-match rather than propagated to the caller.
    """
    try:
        return bcrypt.checkpw(plain_password.encode(), hashed_password.encode())
    except ValueError as e:
        LOG.warning("Password verification failed: %s", e)
        return False


def check_password(plain_password: str, hashed_password: str) -> None:
    """Check password.

    If password doesn't match, throw AuthenticationFailedError.
    """
    if not verify_password(plain_password, hashed_password):
        raise AuthenticationFailedError()


def _check_user_password(user: User | None, password: str) -> User | None:
    """Return ``user`` if it has a stored password hash that ``password`` matches."""
    if user is None or not user.hashed_password:
        # Users without a local hash (e.g. those created by handle_oidc_claim)
        # cannot authenticate with a password.
        return None
    if not verify_password(password, user.hashed_password):
        return None
    return user


async def authenticate_user_async(
    session: AsyncSession, username: str, password: str
) -> User | None:
    """Authenticate a user by username and password (async session).

    Returns:
        The matching User, or None if the user does not exist, has no local
        password, or the password is wrong.
    """
    user = (
        await session.scalars(select(User).where(User.username == username))
    ).first()
    return _check_user_password(user, password)


def _oidc_user_query(claim: IdentityClaims):
    """Build a SELECT for the user matching the claim's OIDC subject and issuer."""
    return (
        select(User)
        .where(User.oidc_sub == claim.sub)
        .where(User.oidc_issuer == claim.provider)
    )


async def handle_oidc_claim(session: AsyncSession, claim: IdentityClaims) -> User:
    """Handle OIDC claim.

    Either return an existing user, or create a new one. Users are matched on
    ``(oidc_sub, oidc_issuer) == (claim.sub, claim.provider)``; a newly
    created user is committed before being returned.

    Raises:
        ValueError: If the claim was issued locally rather than by OIDC.
    """
    LOG.debug("Looking up OIDC token claim %r", claim)
    if claim.provider == LOCAL_ISSUER:
        raise ValueError("IdentityClaims do not originate from OIDC.")

    result = await session.execute(_oidc_user_query(claim))
    if user := result.scalar_one_or_none():
        LOG.debug("Found existing user %s", user.id)
        return user

    LOG.debug("User not found in local database. Creating a new user")
    user = User(
        username=claim.username,
        email=claim.email,
        disabled=False,
        oidc_sub=claim.sub,
        oidc_issuer=claim.provider,
        provider=AuthProvider.OIDC,
    )
    session.add(user)
    await session.commit()
    # Re-select instead of reading the committed instance directly: its
    # attributes may be expired after commit (SQLAlchemy's expire_on_commit
    # default), and implicit lazy loads are not possible on an AsyncSession.
    result = await session.execute(_oidc_user_query(claim))
    return result.scalar_one()


def authenticate_user(session: Session, username: str, password: str) -> User | None:
    """Authenticate a user by username and password (sync session).

    Returns:
        The matching User, or None if the user does not exist, has no local
        password, or the password is wrong.
    """
    user = session.scalars(select(User).where(User.username == username)).first()
    return _check_user_password(user, password)


def decode_token(settings: AdminServerSettings, token: str) -> IdentityClaims | None:
    """Decode and validate a token produced by create_token.

    The signature is verified with ``settings.secret_key`` and only
    ``JWT_ALGORITHM`` is accepted. The ``exp`` and ``sub`` claims must be
    present, and ``exp`` must not have passed.

    Returns:
        IdentityClaims built from ``sub`` and ``iss`` (``iss`` defaults to
        LOCAL_ISSUER when absent); for local tokens ``username`` is set to
        ``sub``. None if the token is malformed, has a bad signature, is
        expired, or lacks a subject.
    """
    key = _signing_key(settings)
    claims_registry = jwt.JWTClaimsRegistry(
        exp={"essential": True},
        sub={"essential": True},
    )
    try:
        # Single decode, restricted to the one algorithm this module signs with.
        decoded = jwt.decode(token, key, algorithms=[JWT_ALGORITHM])
        claims_registry.validate(decoded.claims)
    except JoseError as e:
        LOG.debug("Could not decode token: %s", e, exc_info=True)
        return None

    sub = cast("str | None", decoded.claims.get("sub"))
    if not sub:
        return None
    issuer = decoded.claims.get("iss") or LOCAL_ISSUER
    identity_claims = IdentityClaims(sub=sub, provider=issuer)
    if issuer == LOCAL_ISSUER:
        identity_claims.username = sub
    return identity_claims


def hash_password(password: str) -> str:
    """Hash ``password`` with bcrypt using a freshly generated salt."""
    salt = bcrypt.gensalt()
    hashed_password = bcrypt.hashpw(password.encode(), salt)
    return hashed_password.decode()


def generate_user_info(user: User) -> LocalUserInfo:
    """Generate user info object from a user entry.

    The display name is the username when set, otherwise the email address.

    Raises:
        ValueError: If the user has neither a username nor an email address.
    """
    is_local = user.provider == AuthProvider.LOCAL
    if user.username:
        LOG.debug("User has a username: %s", user.username)
        return LocalUserInfo(id=user.id, display_name=user.username, local=is_local)
    if user.email is None:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        raise ValueError(f"User {user.id} has neither a username nor an email address.")
    LOG.debug("User has no username")
    return LocalUserInfo(id=user.id, display_name=user.email, local=is_local)