Adds OIDC token validator application with FastAPI backend and HTML/JavaScript frontend. Includes Docker configuration and Kubernetes readiness.
366 lines
13 KiB
Python
366 lines
13 KiB
Python
import os
import time
from functools import lru_cache
from typing import Optional

import jwt
import requests
from fastapi import FastAPI, HTTPException, Header
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
|
|
|
|
app = FastAPI(title="OIDC Token Validator")

# Mount static files directory (if needed for CSS, JS, images, etc.)
# app.mount("/static", StaticFiles(directory="static"), name="static")

# Allowed CORS origins, read from the comma-separated CORS_ALLOW_ORIGINS
# environment variable (e.g. "https://app.example.com,https://admin.example.com").
# Defaults to "*" so deployments that do not set it keep accepting any origin.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# Add CORS middleware.
# Credentials are only enabled for an explicit origin list. Browsers refuse
# "Access-Control-Allow-Origin: *" on credentialed requests, and Starlette
# works around that by reflecting the caller's Origin, so wildcard +
# credentials would effectively trust every site. This API receives tokens in
# JSON request bodies, not cookies, so the wildcard case does not need them.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
# Configuration
# OIDC issuer URL. Used verbatim as the expected "iss" claim of ID tokens.
ISSUER = os.getenv("OIDC_ISSUER", "https://login.infomaniak.com")
CLIENT_ID = os.getenv("CLIENT_ID")  # Client ID should be set as environment variable
CLIENT_SECRET = os.getenv("CLIENT_SECRET", "")  # Client secret should be set as environment variable

# OpenID Connect Discovery 1.0 (section 4): any terminating "/" of the issuer
# must be removed before appending the well-known path; otherwise an issuer
# configured as "https://idp.example.com/" yields a double slash. ISSUER itself
# is left untouched because it must match the token's "iss" claim exactly.
WELL_KNOWN_CONFIG_URL = f"{ISSUER.rstrip('/')}/.well-known/openid-configuration"

# List of authorized users (in a real app, this would come from a database).
# Maps a user's e-mail address to the secret phrase returned to that user.
AUTHORIZED_USERS = {
    "rene.luria@infomaniak.com": "Welcome to the secret club!",
    "admin@example.com": "Admin super secret phrase!"
}
|
|
|
|
class TokenValidationRequest(BaseModel):
    """Request body of ``POST /validate-token``.

    ``id_token`` is the JWT to verify. ``access_token`` is optional; when
    supplied it is used to query the issuer's userinfo endpoint.
    """

    # Raw ID token (JWT) as sent by the client.
    id_token: str
    # Optional OAuth access token, used only for the userinfo lookup.
    access_token: Optional[str] = None
|
|
|
|
class TokenRefreshRequest(BaseModel):
    """Request body of ``POST /refresh-token``."""

    # Refresh token to exchange at the issuer's token endpoint.
    refresh_token: str
|
|
|
|
class TokenRefreshResponse(BaseModel):
    """Response body of ``POST /refresh-token``.

    On success the token fields are filled and ``error`` is None. On failure
    the endpoint still returns this model, with empty/zero token fields and
    the failure reason in ``error``.
    """

    access_token: str  # new access token ("" on failure)
    id_token: Optional[str] = None  # new ID token, when the provider returned one
    expires_in: int  # "expires_in" from the provider's token response (0 on failure)
    token_type: str  # e.g. "Bearer" ("" on failure)
    error: Optional[str] = None  # failure reason; None on success
|
|
|
|
class UserInfo(BaseModel):
    """Identity details of a validated user, returned to the frontend."""

    email: str  # user's e-mail address
    first_name: Optional[str] = None  # given name, when the userinfo endpoint provides it
    last_name: Optional[str] = None  # family name, when the userinfo endpoint provides it
|
|
|
|
class TokenValidationResponse(BaseModel):
    """Response body of ``POST /validate-token``.

    ``valid`` tells whether the ID token was accepted. ``secret_phrase`` is
    only filled for users found in AUTHORIZED_USERS. ``error`` explains why
    validation failed when ``valid`` is False.
    """

    valid: bool
    user: Optional[UserInfo] = None  # identity of the validated user
    secret_phrase: Optional[str] = None  # set only for authorized users
    error: Optional[str] = None  # failure reason; None on success
|
|
|
|
class ClientConfig(BaseModel):
    """Public OIDC client settings served to the frontend by ``GET /config``."""

    client_id: str  # OAuth client ID (CLIENT_ID environment variable)
    issuer: str  # OIDC issuer URL (OIDC_ISSUER environment variable)
|
|
|
|
@lru_cache(maxsize=1)
def get_well_known_config():
    """Fetch the issuer's OpenID Connect discovery document (cached).

    The first successful result is cached for the life of the process.
    lru_cache does not store calls that raise, so a failed fetch is retried on
    the next call.

    Returns:
        dict: the parsed discovery document.

    Raises:
        ValueError: if the request fails, the body is empty, or the body is
            not a JSON object.
    """
    try:
        response = requests.get(WELL_KNOWN_CONFIG_URL, timeout=10)
        response.raise_for_status()
        if not response.text:
            raise ValueError("Empty response from well-known configuration endpoint")
        config = response.json()
        # Callers read keys with .get(); reject lists/strings/numbers here
        # instead of caching them and failing later with an AttributeError.
        if not isinstance(config, dict):
            raise ValueError("Well-known configuration is not a JSON object")
        return config
    except requests.exceptions.RequestException as e:
        raise ValueError(f"Failed to fetch well-known configuration: {str(e)}") from e
    except ValueError as e:
        raise ValueError(f"Invalid well-known configuration response: {str(e)}") from e
|
|
|
|
@lru_cache(maxsize=1)
def get_jwks():
    """Fetch the issuer's JSON Web Key Set (cached).

    The JWKS location is read from the discovery document's "jwks_uri". The
    first successful result is cached with no expiry; call
    ``get_jwks.cache_clear()`` to force a refetch.

    Raises:
        ValueError: "Failed to fetch JWKS: ..." for transport/HTTP errors, or
            "Invalid JWKS response: ..." for every other ValueError, including
            a missing jwks_uri, an empty body, and errors re-raised by
            get_well_known_config().
    """
    try:
        jwks_uri = get_well_known_config().get("jwks_uri")
        if not jwks_uri:
            raise ValueError("JWKS URI not found in well-known configuration")

        resp = requests.get(jwks_uri, timeout=10)
        resp.raise_for_status()
        if resp.text:
            return resp.json()
        raise ValueError("Empty response from JWKS endpoint")
    except requests.exceptions.RequestException as exc:
        raise ValueError(f"Failed to fetch JWKS: {str(exc)}")
    except ValueError as exc:
        raise ValueError(f"Invalid JWKS response: {str(exc)}")
|
|
|
|
@lru_cache(maxsize=1)
def get_userinfo_endpoint():
    """Return the "userinfo_endpoint" URL from the discovery document (cached).

    Raises:
        ValueError: prefixed "Failed to get userinfo endpoint: " when the key
            is missing/empty or the discovery document cannot be loaded.
    """
    try:
        endpoint = get_well_known_config().get("userinfo_endpoint")
        if endpoint:
            return endpoint
        raise ValueError("Userinfo endpoint not found in well-known configuration")
    except Exception as exc:
        raise ValueError(f"Failed to get userinfo endpoint: {str(exc)}")
|
|
|
|
@lru_cache(maxsize=1)
def get_token_endpoint():
    """Return the "token_endpoint" URL from the discovery document (cached).

    Raises:
        ValueError: prefixed "Failed to get token endpoint: " when the key is
            missing/empty or the discovery document cannot be loaded.
    """
    try:
        endpoint = get_well_known_config().get("token_endpoint")
        if endpoint:
            return endpoint
        raise ValueError("Token endpoint not found in well-known configuration")
    except Exception as exc:
        raise ValueError(f"Failed to get token endpoint: {str(exc)}")
|
|
|
|
def get_user_info_from_endpoint(access_token: str):
    """Fetch the user's profile from the issuer's userinfo endpoint.

    Args:
        access_token: OAuth access token, sent as a Bearer credential.

    Returns:
        dict with "email" (str; "" when absent or null) and "first_name" /
        "last_name" (str or None).

    Raises:
        ValueError: if the endpoint cannot be resolved or called, the body is
            empty, or the body is not a JSON object.
    """
    try:
        userinfo_endpoint = get_userinfo_endpoint()

        headers = {
            "Authorization": f"Bearer {access_token}"
        }

        response = requests.get(userinfo_endpoint, headers=headers, timeout=10)
        response.raise_for_status()

        if not response.text:
            raise ValueError("Empty response from userinfo endpoint")

        user_info = response.json()
        # The field lookups below need a mapping; a list/string body would
        # otherwise escape as an AttributeError instead of a ValueError.
        if not isinstance(user_info, dict):
            raise ValueError("Userinfo response is not a JSON object")

        # Map the user info fields.
        # Note: field names vary between OIDC providers, so the standard
        # claims are tried first, then common alternatives.
        first_name = user_info.get("given_name") or user_info.get("first_name") or user_info.get("firstName")
        last_name = user_info.get("family_name") or user_info.get("last_name") or user_info.get("lastName")

        return {
            # `or ""` also maps an explicit JSON null to "", keeping the
            # documented str type for "email".
            "email": user_info.get("email") or "",
            "first_name": first_name,
            "last_name": last_name
        }
    except requests.exceptions.RequestException as e:
        raise ValueError(f"Failed to fetch user info: {str(e)}") from e
    except ValueError as e:
        raise ValueError(f"Invalid user info response: {str(e)}") from e
|
|
|
|
def refresh_token(refresh_token: str):
    """Exchange a refresh token for new tokens at the issuer's token endpoint.

    The client authenticates with CLIENT_ID / CLIENT_SECRET sent in the form
    body.

    Args:
        refresh_token: refresh token previously issued by the OIDC provider.

    Returns:
        dict with "access_token", "id_token" (or None), "expires_in"
        (0 when absent/null) and "token_type" ("Bearer" when absent/null).

    Raises:
        ValueError: "Failed to refresh token: ..." for transport errors and
            "Token refresh failed: ..." for everything else (bad input,
            missing client secret, non-2xx response, malformed response).
    """
    try:
        # Validate input
        if not refresh_token or not isinstance(refresh_token, str):
            raise ValueError("Invalid refresh token provided")

        if not CLIENT_SECRET:
            raise ValueError("Client secret not configured")

        token_endpoint = get_token_endpoint()

        data = {
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "client_id": CLIENT_ID,
            "client_secret": CLIENT_SECRET
        }

        response = requests.post(token_endpoint, data=data, timeout=10)

        if not response.ok:
            # No prefix here: the generic handler below adds
            # "Token refresh failed: ", which used to be doubled.
            raise ValueError(f"{response.status_code} - {response.text}")

        token_response = response.json()

        # Validate required fields in response (and that it is an object at all)
        if not isinstance(token_response, dict) or "access_token" not in token_response:
            raise ValueError("Access token not found in refresh response")

        return {
            "access_token": token_response["access_token"],
            "id_token": token_response.get("id_token"),
            # `or` also replaces an explicit null, which the int/str fields of
            # TokenRefreshResponse would reject.
            "expires_in": token_response.get("expires_in") or 0,
            "token_type": token_response.get("token_type") or "Bearer"
        }

    except requests.exceptions.RequestException as e:
        raise ValueError(f"Failed to refresh token: {str(e)}") from e
    except Exception as e:
        raise ValueError(f"Token refresh failed: {str(e)}") from e
|
|
|
|
# Signature algorithms that can be verified with the RSA public key produced by
# RSAAlgorithm.from_jwk. The token header's "alg" is only honoured if it is in
# this allow-list, so the (unverified) token cannot choose its own algorithm.
_RSA_SIGNING_ALGORITHMS = frozenset({"RS256", "RS384", "RS512", "PS256", "PS384", "PS512"})

# Minimum seconds between JWKS refetches triggered by an unknown "kid". The
# refetch picks up rotated signing keys without a restart; the throttle stops
# tokens with random "kid" values from making every request hit the issuer.
_JWKS_REFRESH_MIN_INTERVAL = 60.0

# time.monotonic() of the last forced JWKS refetch (None = never).
_jwks_last_forced_refresh = None


def _find_jwk(jwks, kid):
    """Return the JWK in *jwks* whose "kid" equals *kid*, or None if absent."""
    if not jwks or "keys" not in jwks:
        raise ValueError("Invalid JWKS format")
    for jwk in jwks.get("keys", []):
        if jwk.get("kid") == kid:
            return jwk
    return None


def _refresh_jwks_if_allowed():
    """Clear the cached JWKS at most once per interval; return True if cleared."""
    global _jwks_last_forced_refresh
    now = time.monotonic()
    if (
        _jwks_last_forced_refresh is not None
        and now - _jwks_last_forced_refresh < _JWKS_REFRESH_MIN_INTERVAL
    ):
        return False
    _jwks_last_forced_refresh = now
    get_jwks.cache_clear()
    return True


def verify_token(id_token: str):
    """Verify an ID token's signature and claims and return its payload.

    The signing key is looked up in the issuer's JWKS by the header's "kid".
    If it is not found, the JWKS is refetched (throttled) in case the issuer
    rotated its keys. The token is then checked against ISSUER and CLIENT_ID
    and must carry "exp" and "iat".

    Args:
        id_token: the raw JWT.

    Returns:
        dict: the decoded claims.

    Raises:
        ValueError: "Token has expired", "Invalid audience", "Invalid issuer",
            or "Token validation failed: <reason>" for any other problem.
    """
    try:
        if not id_token or not isinstance(id_token, str):
            raise ValueError("Invalid token provided")

        # Without a client ID, jwt.decode(audience=None) rejects every token
        # carrying "aud" with a misleading "Invalid audience".
        if not CLIENT_ID:
            raise ValueError("Client ID not configured")

        jwks = get_jwks()

        header = jwt.get_unverified_header(id_token)
        kid = header.get("kid")
        if not kid:
            raise ValueError("Token header missing 'kid' field")

        alg = header.get("alg", "RS256")
        if alg not in _RSA_SIGNING_ALGORITHMS:
            raise ValueError(f"Unsupported token signing algorithm: {alg}")

        jwk = _find_jwk(jwks, kid)
        if jwk is None and _refresh_jwks_if_allowed():
            # Unknown kid: the cached JWKS may predate a key rotation.
            jwk = _find_jwk(get_jwks(), kid)
        if jwk is None:
            raise ValueError("Unable to find matching key in JWKS")

        key = jwt.algorithms.RSAAlgorithm.from_jwk(jwk)

        return jwt.decode(
            id_token,
            key=key,
            algorithms=[alg],
            audience=CLIENT_ID,
            issuer=ISSUER,
            # OIDC Core makes "exp" and "iat" mandatory in ID tokens; without
            # this PyJWT accepts a token that never expires.
            options={"require": ["exp", "iat"]},
        )

    except jwt.ExpiredSignatureError:
        raise ValueError("Token has expired")
    except jwt.InvalidAudienceError:
        raise ValueError("Invalid audience")
    except jwt.InvalidIssuerError:
        raise ValueError("Invalid issuer")
    except Exception as e:
        raise ValueError(f"Token validation failed: {str(e)}") from e
|
|
|
|
@app.post("/validate-token", response_model=TokenValidationResponse)
async def validate_token(request: TokenValidationRequest):
    """Validate an ID token and return the secret phrase for authorized users.

    The ID token is verified with verify_token(). Its "email" claim identifies
    the user and is the key looked up in AUTHORIZED_USERS. When an access
    token is supplied, the userinfo endpoint is queried best-effort for the
    user's first and last name.

    Errors never become HTTP errors: they are reported as valid=False with an
    error message in the response body.
    """
    # Log the incoming request for debugging (length only, never the token)
    print(f"Validating token, length: {len(request.id_token) if request.id_token else 0}")

    try:
        # verify_token and get_user_info_from_endpoint use the blocking
        # `requests` library; running them in the threadpool keeps this async
        # endpoint from stalling the event loop while the issuer responds.
        payload = await run_in_threadpool(verify_token, request.id_token)

        user_email = payload.get("email")
        if not user_email:
            return TokenValidationResponse(
                valid=False,
                error="Email not found in token"
            )

        # Authorization is e-mail based, so refuse an address the issuer
        # explicitly reports as unverified (bool False or the string "false").
        if str(payload.get("email_verified")).lower() == "false":
            return TokenValidationResponse(
                valid=False,
                error="Email not verified"
            )

        user_info = {
            "email": user_email,
            "first_name": None,
            "last_name": None
        }

        if request.access_token:
            try:
                additional_info = await run_in_threadpool(
                    get_user_info_from_endpoint, request.access_token
                )
                # Only enrich the names. The email stays the one from the
                # verified ID token: it is what authorization uses, userinfo
                # may report it as "", and nothing here checks that the access
                # token belongs to the same user as the ID token.
                for field in ("first_name", "last_name"):
                    if additional_info.get(field):
                        user_info[field] = additional_info[field]
            except Exception as e:
                print(f"Warning: Failed to fetch user info from userinfo endpoint: {str(e)}")
                # Continue with basic user info from token

        secret_phrase = AUTHORIZED_USERS.get(user_email)

        return TokenValidationResponse(
            valid=True,
            user=UserInfo(**user_info),
            secret_phrase=secret_phrase
        )

    except ValueError as e:
        print(f"ValueError during token validation: {str(e)}")
        return TokenValidationResponse(
            valid=False,
            error=str(e)
        )
    except Exception as e:
        print(f"Unexpected error during token validation: {str(e)}")
        return TokenValidationResponse(
            valid=False,
            error=f"Unexpected error: {str(e)}"
        )
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Serve the frontend HTML page.

    The path is relative to the process working directory, so the server must
    be started from the directory that contains ``templates/``.
    """
    # Explicit UTF-8: without it the decoding depends on the host/container
    # locale, and non-ASCII characters in the page could fail to load.
    with open("templates/index.html", "r", encoding="utf-8") as file:
        html_content = file.read()
    return HTMLResponse(content=html_content, status_code=200)
|
|
|
|
@app.get("/health")
async def health_check():
    """Health endpoint intended for the Kubernetes deployment's probes.

    Always reports "healthy". NOTE: despite its name, the "timestamp" field
    carries the HOSTNAME environment variable (typically the pod name), or
    "unknown" when it is unset.
    """
    hostname = os.getenv("HOSTNAME", "unknown")
    body = {"status": "healthy", "timestamp": hostname}
    return body
|
|
|
|
@app.post("/refresh-token", response_model=TokenRefreshResponse)
async def refresh_token_endpoint(request: TokenRefreshRequest):
    """Refresh an access token using a refresh token.

    Failures are reported in the response body (not as HTTP errors). The
    required token fields are returned empty/zero and the reason is put in
    ``error``; unexpected exceptions are prefixed "Unexpected error: ".
    """
    try:
        # refresh_token() performs a blocking HTTP POST; run it in the
        # threadpool so this async endpoint does not stall the event loop.
        refreshed_tokens = await run_in_threadpool(refresh_token, request.refresh_token)

        return TokenRefreshResponse(
            access_token=refreshed_tokens["access_token"],
            id_token=refreshed_tokens["id_token"],
            expires_in=refreshed_tokens["expires_in"],
            token_type=refreshed_tokens["token_type"]
        )
    except ValueError as e:
        error = str(e)
    except Exception as e:
        error = f"Unexpected error: {str(e)}"

    return TokenRefreshResponse(
        access_token="",  # Required field, but empty since there's an error
        id_token=None,
        expires_in=0,
        token_type="",
        error=error
    )
|
|
|
|
@app.get("/favicon.ico")
async def favicon():
    """Serve the favicon from ``templates/favicon.ico``.

    The path is relative to the process working directory.

    Raises:
        HTTPException: 404 when the file does not exist. FileResponse would
            otherwise fail while sending and surface as a 500.
    """
    favicon_path = "templates/favicon.ico"
    if not os.path.isfile(favicon_path):
        raise HTTPException(status_code=404, detail="Favicon not found")
    return FileResponse(favicon_path)
|
|
|
|
@app.get("/config", response_model=ClientConfig)
async def get_client_config():
    """Serve the public OIDC client settings (client ID and issuer) to the frontend.

    Raises:
        HTTPException: 500 with an explicit message when CLIENT_ID is not set.
            Building ClientConfig with client_id=None would otherwise fail
            validation and produce an opaque server error.
    """
    if not CLIENT_ID:
        raise HTTPException(status_code=500, detail="CLIENT_ID is not configured on the server")
    return ClientConfig(client_id=CLIENT_ID, issuer=ISSUER)
|
|
|
|
if __name__ == "__main__":
    # Development entry point: serve the app on all interfaces, port 8000.
    import uvicorn

    bind_host = "0.0.0.0"
    bind_port = 8000
    uvicorn.run(app, host=bind_host, port=bind_port)