Implement oidc login
This commit is contained in:
@ -12,18 +12,23 @@ from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from jinja2_fragments.fastapi import Jinja2Blocks
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
from sshecret_admin.auth.authentication import generate_user_info
|
||||
from sshecret_admin.auth.models import AuthProvider, IdentityClaims, LocalUserInfo
|
||||
from starlette.datastructures import URL
|
||||
|
||||
|
||||
from sshecret_admin.auth import PasswordDB, User, decode_token
|
||||
from sshecret_admin.auth.constants import LOCAL_ISSUER
|
||||
|
||||
from sshecret_admin.core.dependencies import BaseDependencies
|
||||
from sshecret_admin.services.admin_backend import AdminBackend
|
||||
from sshecret_admin.core.db import DatabaseSessionManager
|
||||
|
||||
from .dependencies import FrontendDependencies
|
||||
from .exceptions import RedirectException
|
||||
from .views import audit, auth, clients, index, secrets
|
||||
from .views import audit, auth, clients, index, secrets, oidc_auth
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -45,7 +50,7 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||
templates = Jinja2Blocks(directory=template_path)
|
||||
|
||||
async def get_admin_backend(
|
||||
session: Annotated[Session, Depends(dependencies.get_db_session)]
|
||||
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||
):
|
||||
"""Get admin backend API."""
|
||||
password_db = session.scalars(
|
||||
@ -58,66 +63,50 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||
admin = AdminBackend(dependencies.settings, password_db.encrypted_password)
|
||||
yield admin
|
||||
|
||||
async def get_user_from_token(
|
||||
token: str,
|
||||
session: Session,
|
||||
) -> User | None:
|
||||
"""Get user from a token."""
|
||||
token_data = decode_token(dependencies.settings, token)
|
||||
if not token_data:
|
||||
return None
|
||||
user = session.scalars(
|
||||
select(User).where(User.username == token_data.username)
|
||||
).first()
|
||||
if not user or user.disabled:
|
||||
return None
|
||||
return user
|
||||
|
||||
async def get_user_from_refresh_token(
|
||||
request: Request,
|
||||
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||
) -> User:
|
||||
"""Get user from refresh token."""
|
||||
next = URL("/login").include_query_params(next=request.url.path)
|
||||
credentials_error = RedirectException(to=next)
|
||||
token = request.cookies.get("refresh_token")
|
||||
if not token:
|
||||
raise credentials_error
|
||||
|
||||
user = await get_user_from_token(token, session)
|
||||
if not user:
|
||||
raise credentials_error
|
||||
return user
|
||||
|
||||
async def get_user_from_access_token(
|
||||
request: Request,
|
||||
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||
) -> User:
|
||||
"""Get user from access token."""
|
||||
def get_identity_claims(request: Request) -> IdentityClaims:
|
||||
"""Get identity claim from session."""
|
||||
token = request.cookies.get("access_token")
|
||||
next = URL("/refresh").include_query_params(next=request.url.path)
|
||||
credentials_error = RedirectException(to=next)
|
||||
if not token:
|
||||
raise credentials_error
|
||||
|
||||
user = await get_user_from_token(token, session)
|
||||
if not user:
|
||||
claims = decode_token(dependencies.settings, token)
|
||||
if not claims:
|
||||
raise credentials_error
|
||||
return user
|
||||
return claims
|
||||
|
||||
async def get_login_status(
|
||||
request: Request,
|
||||
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||
) -> bool:
|
||||
def refresh_identity_claims(request: Request) -> IdentityClaims:
|
||||
"""Get identity claim from session for refreshing the token."""
|
||||
token = request.cookies.get("refresh_token")
|
||||
next = URL("/login").include_query_params(next=request.url.path)
|
||||
credentials_error = RedirectException(to=next)
|
||||
if not token:
|
||||
raise credentials_error
|
||||
claims = decode_token(dependencies.settings, token)
|
||||
if not claims:
|
||||
raise credentials_error
|
||||
return claims
|
||||
|
||||
async def get_login_status(request: Request) -> bool:
|
||||
"""Get login status."""
|
||||
token = request.cookies.get("access_token")
|
||||
if not token:
|
||||
return False
|
||||
|
||||
user = await get_user_from_token(token, session)
|
||||
if not user:
|
||||
return False
|
||||
return True
|
||||
claims = decode_token(dependencies.settings, token)
|
||||
return claims is not None
|
||||
|
||||
async def require_login(request: Request) -> None:
|
||||
"""Enforce login requirement."""
|
||||
token = request.cookies.get("access_token")
|
||||
LOG.info("User has no cookie")
|
||||
if not token:
|
||||
url = URL("/login").include_query_params(next=request.url.path)
|
||||
raise RedirectException(to=url)
|
||||
is_logged_in = await get_login_status(request)
|
||||
if not is_logged_in:
|
||||
next = URL("/refresh").include_query_params(next=request.url.path)
|
||||
raise RedirectException(to=next)
|
||||
|
||||
async def get_async_session():
|
||||
"""Get async session."""
|
||||
@ -125,14 +114,43 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||
async with sessionmanager.session() as session:
|
||||
yield session
|
||||
|
||||
async def get_user_info(
|
||||
request: Request, session: Annotated[AsyncSession, Depends(get_async_session)]
|
||||
) -> LocalUserInfo:
|
||||
"""Get User information."""
|
||||
claims = get_identity_claims(request)
|
||||
if claims.provider == LOCAL_ISSUER:
|
||||
LOG.info("Local user, finding username %s", claims.sub)
|
||||
query = (
|
||||
select(User)
|
||||
.where(User.username == claims.sub)
|
||||
.where(User.provider == AuthProvider.LOCAL)
|
||||
)
|
||||
else:
|
||||
query = (
|
||||
select(User)
|
||||
.where(User.oidc_issuer == claims.provider)
|
||||
.where(User.oidc_sub == claims.sub)
|
||||
)
|
||||
|
||||
result = await session.scalars(query)
|
||||
if user := result.first():
|
||||
if user.disabled:
|
||||
raise RedirectException(to=URL("/logout"))
|
||||
return generate_user_info(user)
|
||||
|
||||
next = URL("/refresh").include_query_params(next=request.url.path)
|
||||
raise RedirectException(to=next)
|
||||
|
||||
view_dependencies = FrontendDependencies.create(
|
||||
dependencies,
|
||||
get_admin_backend,
|
||||
templates,
|
||||
get_user_from_access_token,
|
||||
get_user_from_refresh_token,
|
||||
refresh_identity_claims,
|
||||
get_login_status,
|
||||
get_user_info,
|
||||
get_async_session,
|
||||
require_login,
|
||||
)
|
||||
|
||||
app.include_router(audit.create_router(view_dependencies))
|
||||
@ -140,5 +158,7 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||
app.include_router(clients.create_router(view_dependencies))
|
||||
app.include_router(index.create_router(view_dependencies))
|
||||
app.include_router(secrets.create_router(view_dependencies))
|
||||
if dependencies.settings.oidc:
|
||||
app.include_router(oidc_auth.create_router(view_dependencies))
|
||||
|
||||
return app
|
||||
|
||||
Reference in New Issue
Block a user