"""OIDC token validator service.

A small FastAPI application that:

* validates OIDC ID tokens against the issuer's published JWKS (``POST /validate-token``)
  and returns a secret phrase for users listed in ``AUTHORIZED_USERS``;
* refreshes access tokens via the issuer's token endpoint (``POST /refresh-token``);
* serves the frontend page, the favicon, a health probe and the public client config.

Configuration comes from the environment variables ``OIDC_ISSUER``, ``CLIENT_ID``,
``CLIENT_SECRET`` and ``CORS_ALLOW_ORIGINS``.
"""

import logging
import os
import time
from functools import lru_cache
from typing import Optional

import jwt
import requests
from fastapi import FastAPI, HTTPException, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

logger = logging.getLogger(__name__)

app = FastAPI(title="OIDC Token Validator")

# Mount static files directory (if needed for CSS, JS, images, etc.)
# app.mount("/static", StaticFiles(directory="static"), name="static")

# Allowed browser origins, comma-separated
# (e.g. "https://app.example.com,https://admin.example.com").
# Defaults to "*" (any origin); list exact origins in production.
CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    # The CORS spec forbids "Access-Control-Allow-Origin: *" on credentialed requests, so
    # a wildcard combined with credentials means either broken requests or reflecting
    # arbitrary origins. This API receives its tokens in the JSON body (no cookies), so
    # credentials are only enabled when explicit origins are configured.
    allow_credentials="*" not in CORS_ALLOW_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Configuration
# OIDC issuer URL; must match the "iss" claim of the ID tokens exactly.
ISSUER = os.getenv("OIDC_ISSUER", "https://login.infomaniak.com")
CLIENT_ID = os.getenv("CLIENT_ID")  # Client ID should be set as environment variable
CLIENT_SECRET = os.getenv("CLIENT_SECRET", "")  # Client secret should be set as environment variable
# rstrip so an issuer configured as "https://host/" does not yield "https://host//.well-known/...".
WELL_KNOWN_CONFIG_URL = f"{ISSUER.rstrip('/')}/.well-known/openid-configuration"

# Timeout (seconds) for every outbound HTTP call to the issuer.
HTTP_TIMEOUT_SECONDS = 10

# Signature algorithms accepted for ID tokens. Signing keys are loaded as RSA keys
# (RSAAlgorithm.from_jwk below), so only RSA algorithms are allowed. The token's own
# unverified "alg" header must never decide which algorithm is used for verification.
ALLOWED_ID_TOKEN_ALGORITHMS = ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512"]

# Minimum delay between forced JWKS re-fetches triggered by an unknown "kid", so tokens
# carrying bogus key ids cannot make this service hammer the issuer.
JWKS_MIN_REFRESH_INTERVAL_SECONDS = 60.0
# time.monotonic() of the last forced JWKS refresh (None = never). Updated from threadpool
# workers without a lock; a race only costs one extra JWKS fetch.
_last_jwks_forced_refresh: Optional[float] = None

# List of authorized users (in a real app, this would come from a database)
AUTHORIZED_USERS = {
    "rene.luria@infomaniak.com": "Welcome to the secret club!",
    "admin@example.com": "Admin super secret phrase!",
}


class TokenValidationRequest(BaseModel):
    """Body of POST /validate-token."""

    id_token: str
    access_token: Optional[str] = None  # optional; used to query the userinfo endpoint


class TokenRefreshRequest(BaseModel):
    """Body of POST /refresh-token."""

    refresh_token: str


class TokenRefreshResponse(BaseModel):
    """Result of POST /refresh-token; on failure the token fields are empty and `error` is set."""

    access_token: str
    id_token: Optional[str] = None
    expires_in: int
    token_type: str
    error: Optional[str] = None


class UserInfo(BaseModel):
    """User identity returned to the frontend."""

    email: str
    first_name: Optional[str] = None
    last_name: Optional[str] = None


class TokenValidationResponse(BaseModel):
    """Result of POST /validate-token; `error` is set when `valid` is False."""

    valid: bool
    user: Optional[UserInfo] = None
    secret_phrase: Optional[str] = None
    error: Optional[str] = None


class ClientConfig(BaseModel):
    """Public OIDC client settings served to the frontend."""

    client_id: str
    issuer: str


@lru_cache(maxsize=1)
def get_well_known_config():
    """Fetch the issuer's OIDC discovery document (cached once fetched successfully).

    Raises:
        ValueError: if the document cannot be fetched, is empty, or is not valid JSON.
    """
    try:
        response = requests.get(WELL_KNOWN_CONFIG_URL, timeout=HTTP_TIMEOUT_SECONDS)
        response.raise_for_status()
        if not response.text:
            raise ValueError("Empty response from well-known configuration endpoint")
        return response.json()
    except requests.exceptions.RequestException as e:
        raise ValueError(f"Failed to fetch well-known configuration: {e}") from e
    except ValueError as e:
        raise ValueError(f"Invalid well-known configuration response: {e}") from e


def _get_discovered_endpoint(field: str, label: str) -> str:
    """Return `field` from the discovery document; raise ValueError mentioning `label` if absent."""
    endpoint = get_well_known_config().get(field)
    if not endpoint:
        raise ValueError(f"{label} not found in well-known configuration")
    return endpoint


@lru_cache(maxsize=1)
def get_jwks():
    """Fetch the issuer's JWKS (cached; cleared by `_find_signing_key` on an unknown kid).

    Raises:
        ValueError: if the JWKS URI is missing or the JWKS cannot be fetched/parsed.
    """
    try:
        jwks_url = _get_discovered_endpoint("jwks_uri", "JWKS URI")
        response = requests.get(jwks_url, timeout=HTTP_TIMEOUT_SECONDS)
        response.raise_for_status()
        if not response.text:
            raise ValueError("Empty response from JWKS endpoint")
        return response.json()
    except requests.exceptions.RequestException as e:
        raise ValueError(f"Failed to fetch JWKS: {e}") from e
    except ValueError as e:
        raise ValueError(f"Invalid JWKS response: {e}") from e


@lru_cache(maxsize=1)
def get_userinfo_endpoint():
    """Return the userinfo endpoint URL from the discovery document.

    Raises:
        ValueError: if discovery fails or the endpoint is not advertised.
    """
    try:
        return _get_discovered_endpoint("userinfo_endpoint", "Userinfo endpoint")
    except Exception as e:
        raise ValueError(f"Failed to get userinfo endpoint: {e}") from e


@lru_cache(maxsize=1)
def get_token_endpoint():
    """Return the token endpoint URL from the discovery document.

    Raises:
        ValueError: if discovery fails or the endpoint is not advertised.
    """
    try:
        return _get_discovered_endpoint("token_endpoint", "Token endpoint")
    except Exception as e:
        raise ValueError(f"Failed to get token endpoint: {e}") from e


def get_user_info_from_endpoint(access_token: str):
    """Fetch user information from the userinfo endpoint using `access_token`.

    Returns:
        dict with keys "email", "first_name", "last_name" and "sub" (values may be None;
        "email" defaults to "").

    Raises:
        ValueError: if the request fails or the response is empty/not JSON.
    """
    try:
        userinfo_endpoint = get_userinfo_endpoint()
        headers = {"Authorization": f"Bearer {access_token}"}
        response = requests.get(userinfo_endpoint, headers=headers, timeout=HTTP_TIMEOUT_SECONDS)
        response.raise_for_status()
        if not response.text:
            raise ValueError("Empty response from userinfo endpoint")
        user_info = response.json()

        # Field names vary between OIDC providers; try the standard claim first.
        first_name = user_info.get("given_name") or user_info.get("first_name") or user_info.get("firstName")
        last_name = user_info.get("family_name") or user_info.get("last_name") or user_info.get("lastName")
        return {
            "email": user_info.get("email", ""),
            "first_name": first_name,
            "last_name": last_name,
            # Returned so callers can check it against the ID token's "sub".
            "sub": user_info.get("sub"),
        }
    except requests.exceptions.RequestException as e:
        raise ValueError(f"Failed to fetch user info: {e}") from e
    except ValueError as e:
        raise ValueError(f"Invalid user info response: {e}") from e


def refresh_token(refresh_token: str):
    """Exchange a refresh token for new tokens at the issuer's token endpoint.

    Returns:
        dict with "access_token", "id_token" (may be None), "expires_in" (default 0)
        and "token_type" (default "Bearer").

    Raises:
        ValueError: on invalid input, missing client secret, or any refresh failure.
    """
    try:
        if not refresh_token or not isinstance(refresh_token, str):
            raise ValueError("Invalid refresh token provided")
        if not CLIENT_SECRET:
            raise ValueError("Client secret not configured")

        token_endpoint = get_token_endpoint()
        data = {
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "client_id": CLIENT_ID,
            "client_secret": CLIENT_SECRET,
        }
        response = requests.post(token_endpoint, data=data, timeout=HTTP_TIMEOUT_SECONDS)
        if not response.ok:
            raise ValueError(f"Token refresh failed: {response.status_code} - {response.text}")

        token_response = response.json()
        if "access_token" not in token_response:
            raise ValueError("Access token not found in refresh response")

        return {
            "access_token": token_response["access_token"],
            "id_token": token_response.get("id_token"),
            "expires_in": token_response.get("expires_in", 0),
            "token_type": token_response.get("token_type", "Bearer"),
        }
    except requests.exceptions.RequestException as e:
        raise ValueError(f"Failed to refresh token: {e}") from e
    except ValueError:
        # Already carries a descriptive message; re-wrapping produced
        # "Token refresh failed: Token refresh failed: ...".
        raise
    except Exception as e:
        raise ValueError(f"Token refresh failed: {e}") from e


def _find_signing_key(kid: str):
    """Return the RSA public key whose JWK "kid" equals `kid`.

    If `kid` is not in the cached JWKS, the cache is dropped and the JWKS re-fetched once,
    at most every JWKS_MIN_REFRESH_INTERVAL_SECONDS. This lets validation survive
    signing-key rotation at the issuer without a restart.

    Raises:
        ValueError: if the JWKS is malformed/unavailable or no key matches `kid`.
    """
    global _last_jwks_forced_refresh
    for attempt in range(2):
        jwks = get_jwks()
        if not jwks or "keys" not in jwks:
            raise ValueError("Invalid JWKS format")
        for jwk in jwks["keys"]:
            if jwk.get("kid") == kid:
                return jwt.algorithms.RSAAlgorithm.from_jwk(jwk)

        now = time.monotonic()
        recently_refreshed = (
            _last_jwks_forced_refresh is not None
            and now - _last_jwks_forced_refresh < JWKS_MIN_REFRESH_INTERVAL_SECONDS
        )
        if attempt > 0 or recently_refreshed:
            break
        _last_jwks_forced_refresh = now
        get_jwks.cache_clear()
    raise ValueError("Unable to find matching key in JWKS")


def verify_token(id_token: str):
    """Verify an ID token's signature and standard claims; return its payload.

    The signature is checked with the issuer's JWKS key selected by the token's "kid".
    Only ALLOWED_ID_TOKEN_ALGORITHMS are accepted. "aud" must equal CLIENT_ID, "iss" must
    equal ISSUER, and "exp" must be present and in the future.

    Raises:
        ValueError: with a human-readable reason for any validation failure.
    """
    try:
        if not id_token or not isinstance(id_token, str):
            raise ValueError("Invalid token provided")
        if not CLIENT_ID:
            # Without it PyJWT would reject every token with a misleading "Invalid audience".
            raise ValueError("Client ID not configured")

        header = jwt.get_unverified_header(id_token)
        kid = header.get("kid")
        if not kid:
            raise ValueError("Token header missing 'kid' field")

        key = _find_signing_key(kid)

        return jwt.decode(
            id_token,
            key=key,
            algorithms=ALLOWED_ID_TOKEN_ALGORITHMS,
            audience=CLIENT_ID,
            issuer=ISSUER,
            # OIDC Core requires "exp" in ID tokens; without this a token lacking it never expires.
            options={"require": ["exp"]},
        )
    except jwt.ExpiredSignatureError as e:
        raise ValueError("Token has expired") from e
    except jwt.InvalidAudienceError as e:
        raise ValueError("Invalid audience") from e
    except jwt.InvalidIssuerError as e:
        raise ValueError("Invalid issuer") from e
    except Exception as e:
        raise ValueError(f"Token validation failed: {e}") from e


def _email_explicitly_unverified(payload) -> bool:
    """Return True when the ID token states its e-mail address is NOT verified.

    A missing "email_verified" claim counts as verified, to stay compatible with providers
    that do not emit it. The string form "false" that some providers send is handled too.
    """
    verified = payload.get("email_verified")
    return verified is False or (isinstance(verified, str) and verified.lower() == "false")


def _merge_userinfo_names(user_info: dict, additional_info: dict, token_sub) -> None:
    """Copy non-empty first/last name from a userinfo response into `user_info` in place.

    The e-mail always stays the one from the verified ID token. If the userinfo "sub" and
    the ID token "sub" are both present and differ, the response is ignored (OIDC Core
    5.3.2), e.g. when the supplied access token belongs to a different user.
    """
    userinfo_sub = additional_info.get("sub")
    if userinfo_sub and token_sub and userinfo_sub != token_sub:
        logger.warning("Ignoring userinfo response: 'sub' does not match the ID token")
        return
    for field in ("first_name", "last_name"):
        if additional_info.get(field):
            user_info[field] = additional_info[field]


# Sync `def` on purpose: the handler performs blocking `requests` calls, and FastAPI runs
# sync path operations in a threadpool instead of on the event loop.
@app.post("/validate-token", response_model=TokenValidationResponse)
def validate_token(request: TokenValidationRequest):
    """Validate an ID token and return user info plus the secret phrase for authorized users."""
    logger.info("Validating token, length: %d", len(request.id_token) if request.id_token else 0)

    try:
        payload = verify_token(request.id_token)

        user_email = payload.get("email")
        if not user_email:
            return TokenValidationResponse(valid=False, error="Email not found in token")

        # Start from the verified ID-token claims; names may be refined from userinfo below.
        user_info = {
            "email": user_email,
            "first_name": payload.get("given_name"),
            "last_name": payload.get("family_name"),
        }

        if request.access_token:
            try:
                additional_info = get_user_info_from_endpoint(request.access_token)
            except Exception as e:
                # Best effort: fall back to the ID-token claims.
                logger.warning("Failed to fetch user info from userinfo endpoint: %s", e)
            else:
                _merge_userinfo_names(user_info, additional_info, payload.get("sub"))

        if _email_explicitly_unverified(payload):
            logger.warning("Token e-mail is not verified; withholding secret phrase")
            secret_phrase = None
        else:
            secret_phrase = AUTHORIZED_USERS.get(user_email)

        return TokenValidationResponse(
            valid=True,
            user=UserInfo(**user_info),
            secret_phrase=secret_phrase,
        )
    except ValueError as e:
        logger.info("ValueError during token validation: %s", e)
        return TokenValidationResponse(valid=False, error=str(e))
    except Exception as e:
        logger.exception("Unexpected error during token validation")
        return TokenValidationResponse(valid=False, error=f"Unexpected error: {str(e)}")


# Sync `def` so the file read does not block the event loop.
@app.get("/", response_class=HTMLResponse)
def read_root():
    """Serve the frontend HTML file (path relative to the working directory)."""
    with open("templates/index.html", "r", encoding="utf-8") as file:
        html_content = file.read()
    return HTMLResponse(content=html_content, status_code=200)


@app.get("/health")
async def health_check():
    """Health check endpoint for Kubernetes deployment."""
    hostname = os.getenv("HOSTNAME", "unknown")
    # NOTE: "timestamp" has always carried the hostname; kept for existing consumers,
    # with a correctly named "hostname" key alongside it.
    return {"status": "healthy", "timestamp": hostname, "hostname": hostname}


def _refresh_error(message: str) -> TokenRefreshResponse:
    """Build the error form of TokenRefreshResponse (required token fields left empty)."""
    return TokenRefreshResponse(
        access_token="",
        id_token=None,
        expires_in=0,
        token_type="",
        error=message,
    )


# Sync `def`: refresh_token() performs a blocking HTTP call.
@app.post("/refresh-token", response_model=TokenRefreshResponse)
def refresh_token_endpoint(request: TokenRefreshRequest):
    """Refresh an access token using a refresh token; errors are reported in `error`."""
    try:
        refreshed_tokens = refresh_token(request.refresh_token)
        return TokenRefreshResponse(
            access_token=refreshed_tokens["access_token"],
            id_token=refreshed_tokens["id_token"],
            expires_in=refreshed_tokens["expires_in"],
            token_type=refreshed_tokens["token_type"],
        )
    except ValueError as e:
        logger.info("Token refresh failed: %s", e)
        return _refresh_error(str(e))
    except Exception as e:
        logger.exception("Unexpected error during token refresh")
        return _refresh_error(f"Unexpected error: {str(e)}")


@app.get("/favicon.ico")
async def favicon():
    """Serve the favicon (path relative to the working directory)."""
    return FileResponse("templates/favicon.ico")


@app.get("/config", response_model=ClientConfig)
async def get_client_config():
    """Serve the public client configuration to the frontend.

    Raises:
        HTTPException(500): if CLIENT_ID is not configured (previously an opaque
        validation error).
    """
    if not CLIENT_ID:
        raise HTTPException(status_code=500, detail="CLIENT_ID is not configured on the server")
    return ClientConfig(client_id=CLIENT_ID, issuer=ISSUER)


if __name__ == "__main__":
    import uvicorn

    logging.basicConfig(level=logging.INFO)
    uvicorn.run(app, host="0.0.0.0", port=8000)