Implement oidc login
This commit is contained in:
@ -5,21 +5,26 @@ from datetime import datetime, timezone, timedelta
|
||||
from typing import cast, Any
|
||||
|
||||
import bcrypt
|
||||
import jwt
|
||||
from joserfc import jwt
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from joserfc.jwk import OctKey
|
||||
from joserfc.errors import JoseError
|
||||
|
||||
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
|
||||
from .models import User, TokenData
|
||||
from .models import AuthProvider, LocalUserInfo, User, IdentityClaims
|
||||
from .exceptions import AuthenticationFailedError
|
||||
|
||||
JWT_ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
||||
# I know refresh tokens are supposed to be long-lived, but 6 hours for a
|
||||
# sensitive application, seems reasonable.
|
||||
REFRESH_TOKEN_EXPIRE_HOURS = 6
|
||||
from .constants import (
|
||||
JWT_ALGORITHM,
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES,
|
||||
REFRESH_TOKEN_EXPIRE_HOURS,
|
||||
LOCAL_ISSUER,
|
||||
)
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -28,12 +33,14 @@ def create_token(
|
||||
settings: AdminServerSettings,
|
||||
data: dict[str, Any],
|
||||
expires_delta: timedelta,
|
||||
provider: str,
|
||||
) -> str:
|
||||
"""Create access token."""
|
||||
to_encode = data.copy()
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=JWT_ALGORITHM)
|
||||
to_encode.update({"exp": expire, "iss": provider})
|
||||
key = OctKey.import_key(settings.secret_key)
|
||||
encoded_jwt = jwt.encode({"alg": JWT_ALGORITHM}, to_encode, key)
|
||||
return str(encoded_jwt)
|
||||
|
||||
|
||||
@ -41,22 +48,24 @@ def create_access_token(
|
||||
settings: AdminServerSettings,
|
||||
data: dict[str, Any],
|
||||
expires_delta: timedelta | None = None,
|
||||
provider: str = LOCAL_ISSUER,
|
||||
) -> str:
|
||||
"""Create access token."""
|
||||
if not expires_delta:
|
||||
expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
return create_token(settings, data, expires_delta)
|
||||
return create_token(settings, data, expires_delta, provider)
|
||||
|
||||
|
||||
def create_refresh_token(
|
||||
settings: AdminServerSettings,
|
||||
data: dict[str, Any],
|
||||
expires_delta: timedelta | None = None,
|
||||
provider: str = LOCAL_ISSUER,
|
||||
) -> str:
|
||||
"""Create access token."""
|
||||
if not expires_delta:
|
||||
expires_delta = timedelta(hours=REFRESH_TOKEN_EXPIRE_HOURS)
|
||||
return create_token(settings, data, expires_delta)
|
||||
return create_token(settings, data, expires_delta, provider)
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
@ -73,9 +82,13 @@ def check_password(plain_password: str, hashed_password: str) -> None:
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
|
||||
async def authenticate_user_async(session: AsyncSession, username: str, password: str) -> User | None:
|
||||
async def authenticate_user_async(
|
||||
session: AsyncSession, username: str, password: str
|
||||
) -> User | None:
|
||||
"""Authenticate user async."""
|
||||
user = (await session.scalars(select(User).where(User.username == username))).first()
|
||||
user = (
|
||||
await session.scalars(select(User).where(User.username == username))
|
||||
).first()
|
||||
if not user:
|
||||
return None
|
||||
if not verify_password(password, user.hashed_password):
|
||||
@ -83,6 +96,44 @@ async def authenticate_user_async(session: AsyncSession, username: str, password
|
||||
return user
|
||||
|
||||
|
||||
async def handle_oidc_claim(session: AsyncSession, claim: IdentityClaims) -> User:
|
||||
"""Handle OIDC claim.
|
||||
|
||||
Either return an existing user, or create a new one.
|
||||
"""
|
||||
LOG.debug("Looking up OIDC token claim %r", claim)
|
||||
if claim.provider == LOCAL_ISSUER:
|
||||
raise ValueError("IdentityClaims do not originate from OIDC.")
|
||||
query = (
|
||||
select(User)
|
||||
.where(User.oidc_sub == claim.sub)
|
||||
.where(User.oidc_issuer == claim.provider)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
if user := result.scalar_one_or_none():
|
||||
LOG.debug("Found existing user %s", user.id)
|
||||
return user
|
||||
|
||||
LOG.debug("User not found in local database. Creating a new user")
|
||||
user = User(
|
||||
username=claim.username,
|
||||
email=claim.email,
|
||||
disabled=False,
|
||||
oidc_sub=claim.sub,
|
||||
oidc_issuer=claim.provider,
|
||||
provider=AuthProvider.OIDC,
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
query = (
|
||||
select(User)
|
||||
.where(User.oidc_sub == claim.sub)
|
||||
.where(User.oidc_issuer == claim.provider)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one()
|
||||
|
||||
|
||||
def authenticate_user(session: Session, username: str, password: str) -> User | None:
|
||||
"""Authenticate user."""
|
||||
user = session.scalars(select(User).where(User.username == username)).first()
|
||||
@ -93,22 +144,48 @@ def authenticate_user(session: Session, username: str, password: str) -> User |
|
||||
return user
|
||||
|
||||
|
||||
def decode_token(settings: AdminServerSettings, token: str) -> TokenData | None:
|
||||
def decode_token(settings: AdminServerSettings, token: str) -> IdentityClaims | None:
|
||||
"""Decode token."""
|
||||
key = OctKey.import_key(settings.secret_key)
|
||||
try:
|
||||
decoded = jwt.decode(token, key)
|
||||
claims_requests = jwt.JWTClaimsRegistry(
|
||||
exp={"essential": True},
|
||||
sub={"essential": True},
|
||||
)
|
||||
|
||||
claims_requests.validate(decoded.claims)
|
||||
payload = jwt.decode(token, settings.secret_key, algorithms=[JWT_ALGORITHM])
|
||||
username = cast("str | None", payload.get("sub"))
|
||||
if not username:
|
||||
sub = cast("str | None", payload.claims.get("sub"))
|
||||
if not sub:
|
||||
return None
|
||||
|
||||
token_data = TokenData(username=username)
|
||||
return token_data
|
||||
except jwt.InvalidTokenError as e:
|
||||
issuer = payload.claims.get("iss") or LOCAL_ISSUER
|
||||
|
||||
identity_claims = IdentityClaims(sub=sub, provider=issuer)
|
||||
if issuer == LOCAL_ISSUER:
|
||||
identity_claims.username = sub
|
||||
|
||||
return identity_claims
|
||||
|
||||
except JoseError as e:
|
||||
LOG.debug("Could not decode token: %s", e, exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash password."""
|
||||
salt = bcrypt.gensalt()
|
||||
hashed_password = bcrypt.hashpw(password.encode(), salt)
|
||||
return hashed_password.decode()
|
||||
|
||||
|
||||
def generate_user_info(user: User) -> LocalUserInfo:
|
||||
"""Generate user info object from a user entry."""
|
||||
is_local = user.provider == AuthProvider.LOCAL
|
||||
if user.username:
|
||||
LOG.info("User has a username: %s", user.username)
|
||||
return LocalUserInfo(id=user.id, display_name=user.username, local=is_local)
|
||||
assert user.email is not None
|
||||
LOG.info("User has no username")
|
||||
return LocalUserInfo(id=user.id, display_name=user.email, local=is_local)
|
||||
|
||||
Reference in New Issue
Block a user