Implement oidc login

This commit is contained in:
2025-05-30 10:57:59 +02:00
parent b491dff4b1
commit 391e310b91
39 changed files with 938 additions and 308 deletions

View File

@ -5,21 +5,26 @@ from datetime import datetime, timezone, timedelta
from typing import cast, Any
import bcrypt
import jwt
from joserfc import jwt
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from joserfc.jwk import OctKey
from joserfc.errors import JoseError
from sshecret_admin.core.settings import AdminServerSettings
from .models import User, TokenData
from .models import AuthProvider, LocalUserInfo, User, IdentityClaims
from .exceptions import AuthenticationFailedError
JWT_ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
# I know refresh tokens are supposed to be long-lived, but 6 hours for a
# sensitive application, seems reasonable.
REFRESH_TOKEN_EXPIRE_HOURS = 6
from .constants import (
JWT_ALGORITHM,
ACCESS_TOKEN_EXPIRE_MINUTES,
REFRESH_TOKEN_EXPIRE_HOURS,
LOCAL_ISSUER,
)
LOG = logging.getLogger(__name__)
@ -28,12 +33,14 @@ def create_token(
settings: AdminServerSettings,
data: dict[str, Any],
expires_delta: timedelta,
provider: str,
) -> str:
"""Create access token."""
to_encode = data.copy()
expire = datetime.now(timezone.utc) + expires_delta
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=JWT_ALGORITHM)
to_encode.update({"exp": expire, "iss": provider})
key = OctKey.import_key(settings.secret_key)
encoded_jwt = jwt.encode({"alg": JWT_ALGORITHM}, to_encode, key)
return str(encoded_jwt)
@ -41,22 +48,24 @@ def create_access_token(
settings: AdminServerSettings,
data: dict[str, Any],
expires_delta: timedelta | None = None,
provider: str = LOCAL_ISSUER,
) -> str:
"""Create access token."""
if not expires_delta:
expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
return create_token(settings, data, expires_delta)
return create_token(settings, data, expires_delta, provider)
def create_refresh_token(
settings: AdminServerSettings,
data: dict[str, Any],
expires_delta: timedelta | None = None,
provider: str = LOCAL_ISSUER,
) -> str:
"""Create access token."""
if not expires_delta:
expires_delta = timedelta(hours=REFRESH_TOKEN_EXPIRE_HOURS)
return create_token(settings, data, expires_delta)
return create_token(settings, data, expires_delta, provider)
def verify_password(plain_password: str, hashed_password: str) -> bool:
@ -73,9 +82,13 @@ def check_password(plain_password: str, hashed_password: str) -> None:
raise AuthenticationFailedError()
async def authenticate_user_async(session: AsyncSession, username: str, password: str) -> User | None:
async def authenticate_user_async(
session: AsyncSession, username: str, password: str
) -> User | None:
"""Authenticate user async."""
user = (await session.scalars(select(User).where(User.username == username))).first()
user = (
await session.scalars(select(User).where(User.username == username))
).first()
if not user:
return None
if not verify_password(password, user.hashed_password):
@ -83,6 +96,44 @@ async def authenticate_user_async(session: AsyncSession, username: str, password
return user
async def handle_oidc_claim(session: AsyncSession, claim: IdentityClaims) -> User:
"""Handle OIDC claim.
Either return an existing user, or create a new one.
"""
LOG.debug("Looking up OIDC token claim %r", claim)
if claim.provider == LOCAL_ISSUER:
raise ValueError("IdentityClaims do not originate from OIDC.")
query = (
select(User)
.where(User.oidc_sub == claim.sub)
.where(User.oidc_issuer == claim.provider)
)
result = await session.execute(query)
if user := result.scalar_one_or_none():
LOG.debug("Found existing user %s", user.id)
return user
LOG.debug("User not found in local database. Creating a new user")
user = User(
username=claim.username,
email=claim.email,
disabled=False,
oidc_sub=claim.sub,
oidc_issuer=claim.provider,
provider=AuthProvider.OIDC,
)
session.add(user)
await session.commit()
query = (
select(User)
.where(User.oidc_sub == claim.sub)
.where(User.oidc_issuer == claim.provider)
)
result = await session.execute(query)
return result.scalar_one()
def authenticate_user(session: Session, username: str, password: str) -> User | None:
"""Authenticate user."""
user = session.scalars(select(User).where(User.username == username)).first()
@ -93,22 +144,48 @@ def authenticate_user(session: Session, username: str, password: str) -> User |
return user
def decode_token(settings: AdminServerSettings, token: str) -> TokenData | None:
def decode_token(settings: AdminServerSettings, token: str) -> IdentityClaims | None:
"""Decode token."""
key = OctKey.import_key(settings.secret_key)
try:
decoded = jwt.decode(token, key)
claims_requests = jwt.JWTClaimsRegistry(
exp={"essential": True},
sub={"essential": True},
)
claims_requests.validate(decoded.claims)
payload = jwt.decode(token, settings.secret_key, algorithms=[JWT_ALGORITHM])
username = cast("str | None", payload.get("sub"))
if not username:
sub = cast("str | None", payload.claims.get("sub"))
if not sub:
return None
token_data = TokenData(username=username)
return token_data
except jwt.InvalidTokenError as e:
issuer = payload.claims.get("iss") or LOCAL_ISSUER
identity_claims = IdentityClaims(sub=sub, provider=issuer)
if issuer == LOCAL_ISSUER:
identity_claims.username = sub
return identity_claims
except JoseError as e:
LOG.debug("Could not decode token: %s", e, exc_info=True)
return None
def hash_password(password: str) -> str:
"""Hash password."""
salt = bcrypt.gensalt()
hashed_password = bcrypt.hashpw(password.encode(), salt)
return hashed_password.decode()
def generate_user_info(user: User) -> LocalUserInfo:
"""Generate user info object from a user entry."""
is_local = user.provider == AuthProvider.LOCAL
if user.username:
LOG.info("User has a username: %s", user.username)
return LocalUserInfo(id=user.id, display_name=user.username, local=is_local)
assert user.email is not None
LOG.info("User has no username")
return LocalUserInfo(id=user.id, display_name=user.email, local=is_local)