192 lines
5.7 KiB
Python
192 lines
5.7 KiB
Python
"""Authentication utilities."""
|
|
|
|
import logging
|
|
from datetime import datetime, timezone, timedelta
|
|
from typing import cast, Any
|
|
|
|
import bcrypt
|
|
from joserfc import jwt
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import Session
|
|
|
|
from joserfc.jwk import OctKey
|
|
from joserfc.errors import JoseError
|
|
|
|
|
|
from sshecret_admin.core.settings import AdminServerSettings
|
|
|
|
from .models import AuthProvider, LocalUserInfo, User, IdentityClaims
|
|
from .exceptions import AuthenticationFailedError
|
|
from .constants import (
|
|
JWT_ALGORITHM,
|
|
ACCESS_TOKEN_EXPIRE_MINUTES,
|
|
REFRESH_TOKEN_EXPIRE_HOURS,
|
|
LOCAL_ISSUER,
|
|
)
|
|
|
|
# Module logger, named after this module via __name__.
LOG = logging.getLogger(__name__)
|
|
|
|
|
|
def create_token(
    settings: AdminServerSettings,
    data: dict[str, Any],
    expires_delta: timedelta,
    provider: str,
) -> str:
    """Encode ``data`` as a signed JWT.

    The payload is a shallow copy of ``data`` plus two registered claims,
    which override any values of the same name already in ``data``:

    * ``exp``: the current UTC time plus ``expires_delta``.
    * ``iss``: ``provider``.

    The token is signed with ``settings.secret_key`` using ``JWT_ALGORITHM``.

    Args:
        settings: Server settings providing the signing secret.
        data: Claims to embed in the token.
        expires_delta: Lifetime of the token.
        provider: Issuer written to the ``iss`` claim.

    Returns:
        The encoded token as a string.
    """
    expires_at = datetime.now(timezone.utc) + expires_delta
    claims = {**data, "exp": expires_at, "iss": provider}
    signing_key = OctKey.import_key(settings.secret_key)
    header = {"alg": JWT_ALGORITHM}
    return str(jwt.encode(header, claims, signing_key))
|
|
|
|
|
|
def create_access_token(
    settings: AdminServerSettings,
    data: dict[str, Any],
    expires_delta: timedelta | None = None,
    provider: str = LOCAL_ISSUER,
) -> str:
    """Create an access token.

    Args:
        settings: Server settings providing the signing secret.
        data: Claims to embed; ``exp`` and ``iss`` are overwritten by
            ``create_token``.
        expires_delta: Token lifetime. ``None`` selects the default of
            ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.
        provider: Issuer written to the ``iss`` claim.

    Returns:
        The encoded token as a string.
    """
    # Compare against None explicitly: a zero timedelta is falsy and would
    # otherwise be silently replaced by the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    return create_token(settings, data, expires_delta, provider)
|
|
|
|
|
|
def create_refresh_token(
    settings: AdminServerSettings,
    data: dict[str, Any],
    expires_delta: timedelta | None = None,
    provider: str = LOCAL_ISSUER,
) -> str:
    """Create a refresh token.

    Args:
        settings: Server settings providing the signing secret.
        data: Claims to embed; ``exp`` and ``iss`` are overwritten by
            ``create_token``.
        expires_delta: Token lifetime. ``None`` selects the default of
            ``REFRESH_TOKEN_EXPIRE_HOURS`` hours.
        provider: Issuer written to the ``iss`` claim.

    Returns:
        The encoded token as a string.
    """
    # Compare against None explicitly: a zero timedelta is falsy and would
    # otherwise be silently replaced by the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(hours=REFRESH_TOKEN_EXPIRE_HOURS)
    return create_token(settings, data, expires_delta, provider)
|
|
|
|
|
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
|
"""Verify password against stored hash."""
|
|
return bcrypt.checkpw(plain_password.encode(), hashed_password.encode())
|
|
|
|
|
|
def check_password(plain_password: str, hashed_password: str) -> None:
    """Require ``plain_password`` to match ``hashed_password``.

    Returns silently when the password verifies.

    Raises:
        AuthenticationFailedError: if ``verify_password`` reports a mismatch.
    """
    matches = verify_password(plain_password, hashed_password)
    if matches:
        return
    raise AuthenticationFailedError()
|
|
|
|
|
|
async def authenticate_user_async(
    session: AsyncSession, username: str, password: str
) -> User | None:
    """Authenticate ``username``/``password`` using an async session.

    Returns:
        The first ``User`` with that username if ``password`` verifies
        against its stored hash. ``None`` if no such user exists or the
        password does not match.
    """
    statement = select(User).where(User.username == username)
    candidates = await session.scalars(statement)
    found = candidates.first()
    if found and verify_password(password, found.hashed_password):
        return found
    return None
|
|
|
|
|
|
async def handle_oidc_claim(session: AsyncSession, claim: IdentityClaims) -> User:
    """Resolve an OIDC identity to a local user, creating one if needed.

    Users are matched on the pair ``(oidc_sub, oidc_issuer)``, taken from
    ``claim.sub`` and ``claim.provider``. If no user matches, a new enabled
    OIDC user is added and committed, then loaded back from the database.

    Raises:
        ValueError: if the claim was issued locally rather than by OIDC.
    """
    LOG.debug("Looking up OIDC token claim %r", claim)
    if claim.provider == LOCAL_ISSUER:
        raise ValueError("IdentityClaims do not originate from OIDC.")

    # Select statements are immutable, so the same lookup serves both the
    # initial search and the reload after insert.
    by_identity = (
        select(User)
        .where(User.oidc_sub == claim.sub)
        .where(User.oidc_issuer == claim.provider)
    )

    existing = (await session.execute(by_identity)).scalar_one_or_none()
    if existing:
        LOG.debug("Found existing user %s", existing.id)
        return existing

    LOG.debug("User not found in local database. Creating a new user")
    session.add(
        User(
            username=claim.username,
            email=claim.email,
            disabled=False,
            oidc_sub=claim.sub,
            oidc_issuer=claim.provider,
            provider=AuthProvider.OIDC,
        )
    )
    await session.commit()

    # Reload with a fresh query instead of reusing the new instance.
    # NOTE(review): presumably this avoids lazy-loading expired attributes
    # after commit in an async session; confirm against session config.
    reloaded = await session.execute(by_identity)
    return reloaded.scalar_one()
|
|
|
|
|
|
def authenticate_user(session: Session, username: str, password: str) -> User | None:
    """Authenticate ``username``/``password`` using a synchronous session.

    Returns:
        The first ``User`` with that username if ``password`` verifies
        against its stored hash. ``None`` if no such user exists or the
        password does not match.
    """
    statement = select(User).where(User.username == username)
    found = session.scalars(statement).first()
    if found and verify_password(password, found.hashed_password):
        return found
    return None
|
|
|
|
|
|
def decode_token(settings: AdminServerSettings, token: str) -> IdentityClaims | None:
    """Decode and validate a token signed with ``settings.secret_key``.

    The token must be signed with ``JWT_ALGORITHM`` and carry both an ``exp``
    claim (which must not have passed) and a non-empty ``sub`` claim.

    Args:
        settings: Server settings providing the verification secret.
        token: The encoded token.

    Returns:
        ``IdentityClaims`` with ``sub`` and ``provider`` set from the ``sub``
        and ``iss`` claims; a missing or empty ``iss`` is treated as
        ``LOCAL_ISSUER``. For local tokens, ``username`` is also set to
        ``sub``. Returns ``None`` if the token cannot be decoded, fails
        validation, or has an empty subject.
    """
    key = OctKey.import_key(settings.secret_key)
    claims_registry = jwt.JWTClaimsRegistry(
        exp={"essential": True},
        sub={"essential": True},
    )
    try:
        # Decode once, with the imported key and the algorithm pinned to the
        # one create_token signs with, and validate that same result.
        decoded = jwt.decode(token, key, algorithms=[JWT_ALGORITHM])
        claims_registry.validate(decoded.claims)
    except JoseError as e:
        LOG.debug("Could not decode token: %s", e, exc_info=True)
        return None

    sub = cast("str | None", decoded.claims.get("sub"))
    if not sub:
        return None

    issuer = decoded.claims.get("iss") or LOCAL_ISSUER
    identity_claims = IdentityClaims(sub=sub, provider=issuer)
    if issuer == LOCAL_ISSUER:
        identity_claims.username = sub
    return identity_claims
|
|
|
|
|
|
def hash_password(password: str) -> str:
    """Return a bcrypt hash of ``password`` using a freshly generated salt.

    The hash is decoded to ``str`` so it can be stored as text and later
    passed to ``verify_password``.
    """
    fresh_salt = bcrypt.gensalt()
    digest = bcrypt.hashpw(password.encode(), fresh_salt)
    return digest.decode()
|
|
|
|
|
|
def generate_user_info(user: User) -> LocalUserInfo:
    """Generate a user info object from a user entry.

    The display name is the username when it is set, otherwise the email
    address. ``local`` is true only when the user's provider is
    ``AuthProvider.LOCAL``.

    Raises:
        ValueError: if the user has neither a username nor an email.
    """
    is_local = user.provider == AuthProvider.LOCAL
    if user.username:
        # Debug level: per-call tracing that includes the username.
        LOG.debug("User has a username: %s", user.username)
        return LocalUserInfo(id=user.id, display_name=user.username, local=is_local)

    # An explicit check rather than ``assert`` so the guard is not stripped
    # when Python runs with -O.
    if user.email is None:
        raise ValueError(f"User {user.id} has neither a username nor an email")
    LOG.debug("User has no username")
    return LocalUserInfo(id=user.id, display_name=user.email, local=is_local)
|