"""Authentication related endpoints factory.""" # pyright: reportUnusedFunction=false import os from datetime import datetime import logging from typing import Annotated, Literal from fastapi import ( APIRouter, Depends, Form, HTTPException, Request, Security, status, ) from fastapi.responses import RedirectResponse from fastapi.security import OAuth2PasswordRequestForm from pydantic import BaseModel, ValidationError from sqlalchemy.ext.asyncio import AsyncSession from sshecret_admin.auth import ( LocalUserInfo, Token, User, authenticate_user_async, create_access_token, create_refresh_token, decode_token, ) from sshecret_admin.auth.authentication import handle_oidc_claim, hash_password from sshecret_admin.auth.exceptions import AuthenticationFailedError from sshecret_admin.auth.models import AuthProvider, LoginInfo from sshecret_admin.auth.oidc import AdminOidc from sshecret_admin.core.dependencies import AdminDependencies from sshecret_admin.core.settings import AdminServerSettings from sshecret_admin.services import AdminBackend from sshecret_admin.services.models import UserPasswordChange from sshecret.backend.models import Operation LOG = logging.getLogger(__name__) class RefreshTokenForm(BaseModel): """The refresh token form data.""" grant_type: Literal["refresh_token"] refresh_token: str def create_router(dependencies: AdminDependencies) -> APIRouter: """Create auth router.""" app = APIRouter() def get_oidc_client() -> AdminOidc: """Get OIDC client dependency.""" if not dependencies.settings.oidc: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="OIDC authentication not available.", ) oidc = AdminOidc(dependencies.settings.oidc) return oidc @app.post("/token") async def login_for_access_token( session: Annotated[AsyncSession, Depends(dependencies.get_async_session)], form_data: Annotated[OAuth2PasswordRequestForm, Depends()], ) -> Token: """Login user and generate token.""" user = await authenticate_user_async( session, form_data.username, form_data.password ) if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect username or password", headers={"WWW-Authenticate": "Bearer"}, ) token_data: dict[str, str] = {"sub": user.username} access_token = create_access_token( dependencies.settings, data=token_data, ) refresh_token = create_refresh_token(dependencies.settings, data=token_data) return Token( access_token=access_token, refresh_token=refresh_token, token_type="bearer" ) @app.post("/refresh") async def refresh_token( form_data: Annotated[RefreshTokenForm, Form()], ) -> Token: """Refresh access token.""" LOG.info("Refresh token data: %r", form_data) claims = decode_token(dependencies.settings, form_data.refresh_token) if not claims: LOG.info("Could not decode claims") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token", headers={"WWW-Authenticate": "Bearer"}, ) token_data: dict[str, str] = {"sub": claims.sub} access_token = create_access_token( dependencies.settings, data=token_data, ) refresh_token = create_refresh_token(dependencies.settings, data=token_data) return Token( access_token=access_token, refresh_token=refresh_token, token_type="bearer" ) @app.post("/password") async def change_password( request: Request, current_user: Annotated[User, Security(dependencies.get_current_active_user)], admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)], session: Annotated[AsyncSession, Depends(dependencies.get_async_session)], password_form: UserPasswordChange, ) -> None: """Change user password""" user = await authenticate_user_async( session, current_user.username, password_form.current_password ) if not user: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid current password", ) new_password_hash = hash_password(password_form.new_password) user.hashed_password = new_password_hash session.add(user) await session.commit() origin = "UNKNOWN" if request.client: origin = request.client.host await admin.write_audit_message( Operation.UPDATE, message="User changed their password", origin=origin, username=user.username, ) @app.get("/oidc/login") async def start_oidc_login( request: Request, oidc: Annotated[AdminOidc, Depends(get_oidc_client)] ) -> RedirectResponse: """Redirect for OIDC login.""" redirect_url = request.url_for("oidc_callback") return await oidc.start_auth(request, redirect_url) @app.get("/oidc/callback") async def oidc_callback( request: Request, session: Annotated[AsyncSession, Depends(dependencies.get_async_session)], admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)], oidc: Annotated[AdminOidc, Depends(get_oidc_client)], ): """Callback for OIDC auth.""" try: claims = await oidc.handle_auth_callback(request) except AuthenticationFailedError as error: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=error) except ValidationError as error: LOG.error("Validation error: %s", error, exc_info=True) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=error) # We now have a IdentityClaims object. # We need to check if this matches an existing user, or we need to create a new one. user = await handle_oidc_claim(session, claims) user.last_login = datetime.now() session.add(user) await session.commit() # Set cookies token_data: dict[str, str] = {"sub": claims.sub} access_token = create_access_token( dependencies.settings, data=token_data, provider=claims.provider ) refresh_token = create_refresh_token( dependencies.settings, data=token_data, provider=claims.provider ) callback_url = f"/auth_cb#access_token={access_token}&refresh_token={refresh_token}" if dependencies.settings.frontend_test_url: callback_url = os.path.join(dependencies.settings.frontend_test_url, callback_url) origin = "UNKNOWN" if request.client: origin = request.client.host await admin.write_audit_message( operation=Operation.LOGIN, message="Logged in to admin frontend", origin=origin, username=user.username, oidc=claims.provider, ) return RedirectResponse(callback_url) @app.get("/users/me") async def get_current_user( current_user: Annotated[User, Security(dependencies.get_current_active_user)], ) -> LocalUserInfo: """Get information about the user currently logged in.""" is_local = current_user.provider is AuthProvider.LOCAL return LocalUserInfo( id=current_user.id, display_name=current_user.username, local=is_local ) @app.get("/oidc/status") async def get_auth_info() -> LoginInfo: """Check if OIDC login is available.""" if dependencies.settings.oidc: return LoginInfo( enabled=True, oidc_provider=dependencies.settings.oidc.name ) return LoginInfo(enabled=False) return app