Files
sshecret/packages/sshecret-admin/src/sshecret_admin/auth/authentication.py
2025-05-30 10:59:09 +02:00

192 lines
5.7 KiB
Python

"""Authentication utilities."""
import logging
from datetime import datetime, timezone, timedelta
from typing import cast, Any
import bcrypt
from joserfc import jwt
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from joserfc.jwk import OctKey
from joserfc.errors import JoseError
from sshecret_admin.core.settings import AdminServerSettings
from .models import AuthProvider, LocalUserInfo, User, IdentityClaims
from .exceptions import AuthenticationFailedError
from .constants import (
JWT_ALGORITHM,
ACCESS_TOKEN_EXPIRE_MINUTES,
REFRESH_TOKEN_EXPIRE_HOURS,
LOCAL_ISSUER,
)
# Logger for this module, named after the module via __name__.
LOG = logging.getLogger(__name__)
def create_token(
    settings: AdminServerSettings,
    data: dict[str, Any],
    expires_delta: timedelta,
    provider: str,
) -> str:
    """Sign a JWT carrying ``data`` plus expiry and issuer claims.

    The caller's ``data`` dict is left untouched; a new dict is built that
    contains every entry of ``data`` with ``exp`` and ``iss`` set (overriding
    any values ``data`` already had for those keys).

    Args:
        settings: Server settings; ``secret_key`` is imported as the
            symmetric (oct) signing key.
        data: Claims to embed in the token.
        expires_delta: Token lifetime, added to the current UTC time to
            produce the ``exp`` claim.
        provider: Value written to the ``iss`` claim.

    Returns:
        The encoded JWT as a string.
    """
    expires_at = datetime.now(timezone.utc) + expires_delta
    claims = {**data, "exp": expires_at, "iss": provider}
    signing_key = OctKey.import_key(settings.secret_key)
    token = jwt.encode({"alg": JWT_ALGORITHM}, claims, signing_key)
    return str(token)
def create_access_token(
    settings: AdminServerSettings,
    data: dict[str, Any],
    expires_delta: timedelta | None = None,
    provider: str = LOCAL_ISSUER,
) -> str:
    """Create a signed access token.

    Args:
        settings: Server settings passed through to ``create_token``.
        data: Claims to embed in the token.
        expires_delta: Token lifetime. ``None`` (the default) selects
            ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes. Any explicit value,
            including ``timedelta(0)``, is used as given.
        provider: Issuer written to the ``iss`` claim; defaults to
            ``LOCAL_ISSUER``.

    Returns:
        The encoded JWT as a string.
    """
    # Compare against None explicitly: timedelta(0) is falsy, so a plain
    # truthiness test would silently swap a zero lifetime for the default.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    return create_token(settings, data, expires_delta, provider)
def create_refresh_token(
    settings: AdminServerSettings,
    data: dict[str, Any],
    expires_delta: timedelta | None = None,
    provider: str = LOCAL_ISSUER,
) -> str:
    """Create a signed refresh token.

    Args:
        settings: Server settings passed through to ``create_token``.
        data: Claims to embed in the token.
        expires_delta: Token lifetime. ``None`` (the default) selects
            ``REFRESH_TOKEN_EXPIRE_HOURS`` hours. Any explicit value,
            including ``timedelta(0)``, is used as given.
        provider: Issuer written to the ``iss`` claim; defaults to
            ``LOCAL_ISSUER``.

    Returns:
        The encoded JWT as a string.
    """
    # Compare against None explicitly: timedelta(0) is falsy, so a plain
    # truthiness test would silently swap a zero lifetime for the default.
    if expires_delta is None:
        expires_delta = timedelta(hours=REFRESH_TOKEN_EXPIRE_HOURS)
    return create_token(settings, data, expires_delta, provider)
def verify_password(plain_password: str, hashed_password: str | None) -> bool:
    """Verify ``plain_password`` against a stored bcrypt hash.

    Args:
        plain_password: Password supplied by the user.
        hashed_password: Stored bcrypt hash. It may be missing for accounts
            that were provisioned without a password.

    Returns:
        True if the password matches the hash. False if it does not, if
        there is no stored hash, or if bcrypt rejects the input as unusable.
    """
    if not hashed_password:
        # Users created from OIDC claims are inserted without a password
        # hash, so the stored value may be empty or None. Such accounts can
        # never authenticate with a password.
        return False
    try:
        return bcrypt.checkpw(plain_password.encode(), hashed_password.encode())
    except ValueError as exc:
        # bcrypt raises ValueError for input it cannot process (e.g. a stored
        # value that is not a valid bcrypt hash). Treat that as a failed
        # match so a login attempt fails cleanly instead of erroring out.
        LOG.debug("Password verification could not be performed: %s", exc)
        return False
def check_password(plain_password: str, hashed_password: str) -> None:
    """Require that ``plain_password`` matches ``hashed_password``.

    Returns normally on a match.

    Raises:
        AuthenticationFailedError: if ``verify_password`` reports no match.
    """
    password_ok = verify_password(plain_password, hashed_password)
    if password_ok:
        return
    raise AuthenticationFailedError()
async def authenticate_user_async(
    session: AsyncSession, username: str, password: str
) -> User | None:
    """Authenticate by username and password using an async session.

    Returns:
        The first User whose ``username`` matches, provided ``password``
        verifies against its stored hash; otherwise None.
    """
    statement = select(User).where(User.username == username)
    candidates = await session.scalars(statement)
    found = candidates.first()
    if found and verify_password(password, found.hashed_password):
        return found
    return None
async def handle_oidc_claim(session: AsyncSession, claim: IdentityClaims) -> User:
    """Resolve an OIDC identity to a local User, creating one if necessary.

    Users are matched on ``oidc_sub == claim.sub`` and
    ``oidc_issuer == claim.provider``. If no match exists, a new user is
    created with ``disabled=False`` and ``provider=AuthProvider.OIDC``. It is
    committed and then read back from the database with the same lookup.

    Raises:
        ValueError: if ``claim.provider`` is the local issuer, i.e. the
            claims did not come from OIDC.
    """
    LOG.debug("Looking up OIDC token claim %r", claim)
    if claim.provider == LOCAL_ISSUER:
        raise ValueError("IdentityClaims do not originate from OIDC.")

    # The same lookup runs before provisioning and again after the commit.
    by_oidc_identity = (
        select(User)
        .where(User.oidc_sub == claim.sub)
        .where(User.oidc_issuer == claim.provider)
    )

    lookup = await session.execute(by_oidc_identity)
    existing = lookup.scalar_one_or_none()
    if existing:
        LOG.debug("Found existing user %s", existing.id)
        return existing

    LOG.debug("User not found in local database. Creating a new user")
    session.add(
        User(
            username=claim.username,
            email=claim.email,
            disabled=False,
            oidc_sub=claim.sub,
            oidc_issuer=claim.provider,
            provider=AuthProvider.OIDC,
        )
    )
    await session.commit()

    reloaded = await session.execute(by_oidc_identity)
    return reloaded.scalar_one()
def authenticate_user(session: Session, username: str, password: str) -> User | None:
    """Authenticate by username and password using a synchronous session.

    Returns:
        The first User whose ``username`` matches, provided ``password``
        verifies against its stored hash; otherwise None.
    """
    statement = select(User).where(User.username == username)
    found = session.scalars(statement).first()
    if found and verify_password(password, found.hashed_password):
        return found
    return None
def decode_token(settings: AdminServerSettings, token: str) -> IdentityClaims | None:
    """Decode and validate a JWT signed with ``settings.secret_key``.

    The token must verify under ``JWT_ALGORITHM`` and carry both ``exp`` and
    ``sub`` claims. A missing ``iss`` claim is treated as ``LOCAL_ISSUER``.
    For locally issued tokens the ``sub`` value is also used as the username.

    Args:
        settings: Server settings; ``secret_key`` is the verification key.
        token: The encoded JWT.

    Returns:
        The identity claims, or None if the token cannot be decoded, fails
        signature or claim validation, or has an empty ``sub``.
    """
    key = OctKey.import_key(settings.secret_key)
    try:
        # Decode exactly once, pinned to the algorithm tokens are signed
        # with, so the claims that are validated are the claims that are used.
        decoded = jwt.decode(token, key, algorithms=[JWT_ALGORITHM])
        claims_requests = jwt.JWTClaimsRegistry(
            exp={"essential": True},
            sub={"essential": True},
        )
        claims_requests.validate(decoded.claims)
    except JoseError as e:
        LOG.debug("Could not decode token: %s", e, exc_info=True)
        return None

    sub = cast("str | None", decoded.claims.get("sub"))
    if not sub:
        return None
    issuer = decoded.claims.get("iss") or LOCAL_ISSUER
    identity_claims = IdentityClaims(sub=sub, provider=issuer)
    if issuer == LOCAL_ISSUER:
        identity_claims.username = sub
    return identity_claims
def hash_password(password: str) -> str:
    """Return a bcrypt hash of ``password`` using a freshly generated salt.

    The hash bytes are decoded to ``str`` so the result can be stored as text.
    """
    digest = bcrypt.hashpw(password.encode(), bcrypt.gensalt())
    return digest.decode()
def generate_user_info(user: User) -> LocalUserInfo:
    """Build a LocalUserInfo summary from a user entry.

    The display name is the username when one is set, otherwise the email
    address. ``local`` is True only when the user's provider is
    ``AuthProvider.LOCAL``.

    Raises:
        ValueError: if the user has no username and ``email`` is None.
    """
    is_local = user.provider == AuthProvider.LOCAL
    if user.username:
        # Debug level: this runs for every user-info lookup and would
        # otherwise flood INFO logs with usernames.
        LOG.debug("User has a username: %s", user.username)
        return LocalUserInfo(id=user.id, display_name=user.username, local=is_local)
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would let a None display name through.
    if user.email is None:
        raise ValueError(
            f"User {user.id} has neither a username nor an email address."
        )
    LOG.debug("User has no username")
    return LocalUserInfo(id=user.id, display_name=user.email, local=is_local)