Files
demo-oidc/main.py
Rene Luria a4908ac492 Initial commit
Adds OIDC token validator application with FastAPI backend and HTML/JavaScript frontend.
Includes Docker configuration and Kubernetes readiness.
2025-08-08 09:16:40 +02:00

366 lines
13 KiB
Python

import os
import time
from functools import lru_cache
from typing import Optional

import jwt
import requests
from fastapi import FastAPI, HTTPException, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
app = FastAPI(title="OIDC Token Validator")

# Mount static files directory (if needed for CSS, JS, images, etc.)
# app.mount("/static", StaticFiles(directory="static"), name="static")

# Origins allowed by CORS, read from the comma-separated CORS_ALLOW_ORIGINS
# environment variable, e.g. "https://app.example.com,https://admin.example.com".
# When the variable is unset or empty this defaults to "*" (any origin);
# set explicit origins in production.
CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in (os.getenv("CORS_ALLOW_ORIGINS") or "*").split(",")
    if origin.strip()
]

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Configuration (read from environment variables)

# OIDC issuer URL; passed as ``issuer=`` to ``jwt.decode``, so it must equal
# the tokens' ``iss`` claim exactly (trailing slash included).
ISSUER = os.getenv("OIDC_ISSUER", "https://login.infomaniak.com")
# OAuth client id; also the expected ``aud`` claim of ID tokens. None if unset.
CLIENT_ID = os.getenv("CLIENT_ID")
# OAuth client secret, sent with the refresh-token grant. Empty if unset.
CLIENT_SECRET = os.getenv("CLIENT_SECRET", "")


def discovery_url(issuer: str) -> str:
    """Return the OpenID Connect discovery document URL for *issuer*.

    OpenID Connect Discovery 1.0, section 4, requires removing any terminating
    ``/`` from the issuer before appending ``/.well-known/openid-configuration``;
    otherwise an issuer such as ``https://idp.example.com/`` would produce a
    ``...//.well-known/...`` URL.
    """
    return issuer.rstrip("/") + "/.well-known/openid-configuration"


WELL_KNOWN_CONFIG_URL = discovery_url(ISSUER)
# Demo allow-list: maps an e-mail address (compared exactly, case-sensitive)
# to the secret phrase revealed to that user after a successful token check.
# A real deployment would load this from a database or configuration store.
AUTHORIZED_USERS = {
    "rene.luria@infomaniak.com": "Welcome to the secret club!",
    "admin@example.com": "Admin super secret phrase!",
}
# Request body of POST /validate-token.
class TokenValidationRequest(BaseModel):
    id_token: str  # OIDC ID token (JWT) whose signature and claims are verified
    access_token: Optional[str] = None  # optional; only used to query the userinfo endpoint


# Request body of POST /refresh-token.
class TokenRefreshRequest(BaseModel):
    refresh_token: str


# Response of POST /refresh-token. On failure the endpoint fills the required
# fields with placeholders ("" / 0) and puts the reason in ``error``.
class TokenRefreshResponse(BaseModel):
    access_token: str
    id_token: Optional[str] = None  # only present if the issuer returned a new one
    expires_in: int  # token lifetime as reported by the issuer (OAuth 2.0: seconds); 0 when absent
    token_type: str
    error: Optional[str] = None


# Identity shown to the frontend.
class UserInfo(BaseModel):
    email: str  # taken from the verified ID token
    first_name: Optional[str] = None  # filled from the userinfo endpoint when available
    last_name: Optional[str] = None


# Response of POST /validate-token.
class TokenValidationResponse(BaseModel):
    valid: bool
    user: Optional[UserInfo] = None  # set only when ``valid`` is True
    secret_phrase: Optional[str] = None  # None unless the user is authorized
    error: Optional[str] = None  # reason for ``valid=False``


# Public client settings served by GET /config for the browser login flow.
class ClientConfig(BaseModel):
    client_id: str
    issuer: str
@lru_cache(maxsize=1)
def get_well_known_config():
    """Fetch the issuer's OpenID Connect discovery document (cached).

    The first successful result is memoized by ``lru_cache``. Exceptions are
    not cached, so a failed fetch is retried on the next call.

    Returns:
        dict: the parsed JSON object served at ``WELL_KNOWN_CONFIG_URL``.

    Raises:
        ValueError: if the request fails, the body is empty, or the body is
            not a JSON object.
    """
    try:
        response = requests.get(WELL_KNOWN_CONFIG_URL, timeout=10)
        response.raise_for_status()
        if not response.text:
            raise ValueError("Empty response from well-known configuration endpoint")
        config = response.json()
        # Callers read entries with ``.get(...)``. Reject anything other than a
        # JSON object here rather than failing later with an AttributeError.
        if not isinstance(config, dict):
            raise ValueError("Well-known configuration is not a JSON object")
        return config
    except requests.exceptions.RequestException as e:
        raise ValueError(f"Failed to fetch well-known configuration: {str(e)}") from e
    except ValueError as e:
        raise ValueError(f"Invalid well-known configuration response: {str(e)}") from e
@lru_cache(maxsize=1)
def get_jwks():
    """Fetch the issuer's JSON Web Key Set (cached).

    The key-set URL is taken from the discovery document's ``jwks_uri``. The
    result is memoized by ``lru_cache`` until ``get_jwks.cache_clear()`` is
    called.

    Returns:
        The parsed JSON body of the JWKS endpoint.

    Raises:
        ValueError: if discovery fails, ``jwks_uri`` is missing, or the JWKS
            request fails or returns an empty/invalid body.
    """
    # Resolve the JWKS location outside the try below. Otherwise its
    # ``except ValueError`` would relabel discovery problems as
    # "Invalid JWKS response".
    well_known_config = get_well_known_config()
    jwks_url = well_known_config.get("jwks_uri")
    if not jwks_url:
        raise ValueError("JWKS URI not found in well-known configuration")
    try:
        response = requests.get(jwks_url, timeout=10)
        response.raise_for_status()
        if not response.text:
            raise ValueError("Empty response from JWKS endpoint")
        return response.json()
    except requests.exceptions.RequestException as e:
        raise ValueError(f"Failed to fetch JWKS: {str(e)}") from e
    except ValueError as e:
        raise ValueError(f"Invalid JWKS response: {str(e)}") from e
@lru_cache(maxsize=1)
def get_userinfo_endpoint():
    """Return the ``userinfo_endpoint`` URL from the discovery document (cached).

    Raises:
        ValueError: with the prefix "Failed to get userinfo endpoint: " if
            discovery fails or the entry is missing or empty.
    """
    try:
        endpoint = get_well_known_config().get("userinfo_endpoint")
        if endpoint:
            return endpoint
        raise ValueError("Userinfo endpoint not found in well-known configuration")
    except Exception as err:
        raise ValueError(f"Failed to get userinfo endpoint: {err}")
@lru_cache(maxsize=1)
def get_token_endpoint():
    """Return the ``token_endpoint`` URL from the discovery document (cached).

    Raises:
        ValueError: with the prefix "Failed to get token endpoint: " if
            discovery fails or the entry is missing or empty.
    """
    try:
        endpoint = get_well_known_config().get("token_endpoint")
        if endpoint:
            return endpoint
        raise ValueError("Token endpoint not found in well-known configuration")
    except Exception as err:
        raise ValueError(f"Failed to get token endpoint: {err}")
def _first_truthy_claim(claims, *names):
    """Equivalent of ``claims.get(n1) or claims.get(n2) or ...``.

    Returns the first truthy value. If none is truthy, returns the value
    looked up for the last name, exactly as a chained ``or`` would.
    """
    value = None
    for name in names:
        value = claims.get(name)
        if value:
            break
    return value


def get_user_info_from_endpoint(access_token: str):
    """Query the issuer's userinfo endpoint with a bearer access token.

    Args:
        access_token: sent as ``Authorization: Bearer <token>``.

    Returns:
        dict with ``email`` (``""`` when absent), ``first_name`` and
        ``last_name``. Name fields are looked up under several spellings,
        because providers differ.

    Raises:
        ValueError: on transport/HTTP errors or an empty or invalid body.
    """
    try:
        endpoint_url = get_userinfo_endpoint()
        auth_headers = {"Authorization": f"Bearer {access_token}"}
        reply = requests.get(endpoint_url, headers=auth_headers, timeout=10)
        reply.raise_for_status()
        if not reply.text:
            raise ValueError("Empty response from userinfo endpoint")
        claims = reply.json()
        # Standard OIDC claim names first, then common non-standard variants.
        given = _first_truthy_claim(claims, "given_name", "first_name", "firstName")
        family = _first_truthy_claim(claims, "family_name", "last_name", "lastName")
        return {
            "email": claims.get("email", ""),
            "first_name": given,
            "last_name": family,
        }
    except requests.exceptions.RequestException as err:
        raise ValueError(f"Failed to fetch user info: {err}")
    except ValueError as err:
        raise ValueError(f"Invalid user info response: {err}")
def refresh_token(refresh_token: str):
    """Exchange a refresh token for new tokens at the issuer's token endpoint.

    Uses the OAuth 2.0 ``refresh_token`` grant, authenticating the client with
    ``CLIENT_ID`` / ``CLIENT_SECRET`` in the form body.

    Args:
        refresh_token: the refresh token to redeem. (The parameter shadows
            this function's own name inside the body; it is kept for
            keyword-argument compatibility.)

    Returns:
        dict with ``access_token``, ``id_token`` (may be ``None``),
        ``expires_in`` (``0`` when absent) and ``token_type`` (``"Bearer"``
        when absent).

    Raises:
        ValueError: for invalid input, missing client credentials, transport
            errors, non-2xx responses, or a response without ``access_token``.
    """
    try:
        if not refresh_token or not isinstance(refresh_token, str):
            raise ValueError("Invalid refresh token provided")
        if not CLIENT_ID:
            raise ValueError("Client ID not configured")
        if not CLIENT_SECRET:
            raise ValueError("Client secret not configured")

        token_endpoint = get_token_endpoint()
        data = {
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "client_id": CLIENT_ID,
            "client_secret": CLIENT_SECRET,
        }
        response = requests.post(token_endpoint, data=data, timeout=10)
        if not response.ok:
            # The generic handler below already adds "Token refresh failed: ".
            # Don't repeat the prefix here, or messages read
            # "Token refresh failed: Token refresh failed: ...".
            raise ValueError(f"HTTP {response.status_code} - {response.text}")

        token_response = response.json()
        if "access_token" not in token_response:
            raise ValueError("Access token not found in refresh response")
        return {
            "access_token": token_response["access_token"],
            "id_token": token_response.get("id_token"),
            "expires_in": token_response.get("expires_in", 0),
            "token_type": token_response.get("token_type", "Bearer"),
        }
    except requests.exceptions.RequestException as e:
        raise ValueError(f"Failed to refresh token: {str(e)}") from e
    except Exception as e:
        raise ValueError(f"Token refresh failed: {str(e)}") from e
# Signature algorithms accepted for ID tokens. Verification keys are built with
# ``RSAAlgorithm.from_jwk`` below, so only RSA-based JWS algorithms can work.
# Pinning the list server-side, instead of trusting the token's own ``alg``
# header, is the standard defence against algorithm-confusion attacks.
ALLOWED_ID_TOKEN_ALGORITHMS = ("RS256", "RS384", "RS512", "PS256", "PS384", "PS512")

# Minimum delay between JWKS re-fetches forced by an unknown ``kid``. Without
# it, tokens with made-up key ids would turn every request into a JWKS download.
JWKS_REFRESH_COOLDOWN_SECONDS = 60.0
_jwks_last_forced_refresh = float("-inf")


def _find_signing_key(jwks, kid):
    """Return the RSA public key whose ``kid`` matches, or ``None``.

    Raises:
        ValueError: if *jwks* is empty or has no ``keys`` member.
    """
    if not jwks or "keys" not in jwks:
        raise ValueError("Invalid JWKS format")
    for jwk in jwks.get("keys", []):
        if jwk.get("kid") == kid:
            return jwt.algorithms.RSAAlgorithm.from_jwk(jwk)
    return None


def _force_jwks_refresh():
    """Clear the cached JWKS unless a forced refresh happened recently.

    Returns:
        bool: True if the cache was cleared, so a second lookup is worthwhile.
    """
    global _jwks_last_forced_refresh
    now = time.monotonic()
    if now - _jwks_last_forced_refresh < JWKS_REFRESH_COOLDOWN_SECONDS:
        return False
    _jwks_last_forced_refresh = now
    get_jwks.cache_clear()
    return True


def verify_token(id_token: str):
    """Verify an OIDC ID token's signature and claims and return its payload.

    Checks, in order:
      * the input is a non-empty string and ``CLIENT_ID`` is configured;
      * the header has a ``kid`` and an allowed RSA ``alg``;
      * a matching key exists in the issuer's JWKS. An unknown ``kid``
        triggers one rate-limited re-fetch, to follow issuer key rotation;
      * ``jwt.decode`` accepts the signature, a required ``exp``,
        ``aud == CLIENT_ID`` and ``iss == ISSUER``.

    Returns:
        dict: the decoded token claims.

    Raises:
        ValueError: for every failure. Expiry and wrong audience/issuer get
            dedicated messages; everything else is prefixed with
            "Token validation failed: ".
    """
    try:
        if not id_token or not isinstance(id_token, str):
            raise ValueError("Invalid token provided")
        # The audience check cannot succeed without a client id; report the
        # misconfiguration itself rather than a generic audience error.
        if not CLIENT_ID:
            raise ValueError("CLIENT_ID is not configured")

        header = jwt.get_unverified_header(id_token)
        kid = header.get("kid")
        if not kid:
            raise ValueError("Token header missing 'kid' field")
        alg = header.get("alg", "RS256")
        if alg not in ALLOWED_ID_TOKEN_ALGORITHMS:
            raise ValueError(f"Unsupported token signing algorithm: {alg}")

        key = _find_signing_key(get_jwks(), kid)
        # get_jwks() is memoized for the life of the process. An unknown kid
        # usually means the issuer rotated its keys, so re-fetch (rate limited)
        # and look once more.
        if key is None and _force_jwks_refresh():
            key = _find_signing_key(get_jwks(), kid)
        if key is None:
            raise ValueError("Unable to find matching key in JWKS")

        return jwt.decode(
            id_token,
            key=key,
            algorithms=[alg],
            audience=CLIENT_ID,
            issuer=ISSUER,
            # ID tokens must carry ``exp``. Require it so a token lacking the
            # claim is rejected instead of being treated as never expiring.
            options={"require": ["exp"]},
        )
    except jwt.ExpiredSignatureError as e:
        raise ValueError("Token has expired") from e
    except jwt.InvalidAudienceError as e:
        raise ValueError("Invalid audience") from e
    except jwt.InvalidIssuerError as e:
        raise ValueError("Invalid issuer") from e
    except Exception as e:
        raise ValueError(f"Token validation failed: {str(e)}") from e
@app.post("/validate-token", response_model=TokenValidationResponse)
def validate_token(request: TokenValidationRequest):
    """Validate an ID token and return the secret phrase for authorized users.

    Declared with plain ``def`` rather than ``async def``. Token verification
    and the optional userinfo lookup make blocking ``requests`` calls, and
    FastAPI runs sync path operations in a worker thread pool instead of on
    the event loop. A slow issuer therefore no longer stalls every other
    request, including ``/health``.

    Failures are reported in the body as ``valid=False`` with an ``error``
    message rather than raised.
    """
    # Log only the token length, never the token itself.
    print(f"Validating token, length: {len(request.id_token) if request.id_token else 0}")
    try:
        payload = verify_token(request.id_token)

        user_email = payload.get("email")
        if not user_email:
            return TokenValidationResponse(
                valid=False,
                error="Email not found in token",
            )

        # The e-mail comes from the signature-verified ID token and is the
        # identity used for authorization below. The userinfo response may only
        # enrich the display names. Merging the whole dict would replace the
        # verified e-mail with "" (or another address) whenever userinfo omits
        # or disagrees on that claim.
        user_info = {"email": user_email, "first_name": None, "last_name": None}
        if request.access_token:
            try:
                additional_info = get_user_info_from_endpoint(request.access_token)
            except Exception as e:
                # Best effort: continue with the basic info from the token.
                print(f"Warning: Failed to fetch user info from userinfo endpoint: {str(e)}")
            else:
                for field in ("first_name", "last_name"):
                    if additional_info.get(field):
                        user_info[field] = additional_info[field]

        # Do not authorize on an address the issuer explicitly flags as
        # unverified (OIDC ``email_verified`` claim). Some providers send the
        # flag as the string "false".
        email_unverified = payload.get("email_verified") in (False, "false")
        secret_phrase = None if email_unverified else AUTHORIZED_USERS.get(user_email)

        return TokenValidationResponse(
            valid=True,
            user=UserInfo(**user_info),
            secret_phrase=secret_phrase,
        )
    except ValueError as e:
        print(f"ValueError during token validation: {str(e)}")
        return TokenValidationResponse(valid=False, error=str(e))
    except Exception as e:
        print(f"Unexpected error during token validation: {str(e)}")
        return TokenValidationResponse(valid=False, error=f"Unexpected error: {str(e)}")
@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Serve the frontend page ``templates/index.html``.

    The path is resolved relative to the process working directory.
    """
    # Decode explicitly as UTF-8. Without ``encoding`` the locale's preferred
    # encoding is used, so non-ASCII text in the page could be mis-decoded or
    # raise UnicodeDecodeError depending on the container's locale.
    with open("templates/index.html", "r", encoding="utf-8") as file:
        html_content = file.read()
    return HTMLResponse(content=html_content, status_code=200)
@app.get("/health")
async def health_check():
    """Probe target for Kubernetes; always reports healthy.

    NOTE(review): despite its name, the ``timestamp`` key carries the
    ``HOSTNAME`` environment variable (``"unknown"`` if unset), not a time.
    The key name is kept as-is because clients may already read it.
    """
    hostname = os.getenv("HOSTNAME", "unknown")
    return dict(status="healthy", timestamp=hostname)
@app.post("/refresh-token", response_model=TokenRefreshResponse)
def refresh_token_endpoint(request: TokenRefreshRequest):
    """Refresh an access token using a refresh token.

    Declared with plain ``def`` because ``refresh_token`` performs a blocking
    ``requests.post``. FastAPI runs sync path operations in its thread pool,
    which keeps the event loop free for other requests.

    Failures are reported in the body: the required fields get empty
    placeholders and ``error`` holds the reason.
    """
    try:
        refreshed_tokens = refresh_token(request.refresh_token)
        # Built inside the try so a validation error on the issuer's values
        # still becomes an error response rather than a 500.
        return TokenRefreshResponse(
            access_token=refreshed_tokens["access_token"],
            id_token=refreshed_tokens["id_token"],
            expires_in=refreshed_tokens["expires_in"],
            token_type=refreshed_tokens["token_type"],
        )
    except ValueError as e:
        error = str(e)
    except Exception as e:
        error = f"Unexpected error: {str(e)}"
    return TokenRefreshResponse(
        access_token="",  # Required field, but empty since there's an error
        id_token=None,
        expires_in=0,
        token_type="",
        error=error,
    )
@app.get("/favicon.ico")
async def favicon():
    """Return the site icon stored alongside the HTML template."""
    # Relative path: resolved against the process working directory.
    icon_path = "templates/favicon.ico"
    return FileResponse(icon_path)
@app.get("/config", response_model=ClientConfig)
async def get_client_config():
    """Expose the public OIDC settings the browser needs to start a login.

    Raises:
        HTTPException: 500 with an explanatory ``detail`` when ``CLIENT_ID``
            is not configured. Otherwise ``ClientConfig(client_id=None)``
            fails validation and the client only sees an opaque server error.
    """
    if not CLIENT_ID:
        raise HTTPException(status_code=500, detail="CLIENT_ID is not configured on the server")
    return ClientConfig(client_id=CLIENT_ID, issuer=ISSUER)
if __name__ == "__main__":
    # Development entry point: serve the app with uvicorn on all interfaces,
    # port 8000. uvicorn is imported here, so merely importing this module
    # does not require it.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8000)