Initial commit
Adds OIDC token validator application with FastAPI backend and HTML/JavaScript frontend. Includes Docker configuration and Kubernetes readiness.
This commit is contained in:
366
main.py
Normal file
366
main.py
Normal file
@@ -0,0 +1,366 @@
|
||||
from fastapi import FastAPI, HTTPException, Header
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import HTMLResponse, FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
import jwt
|
||||
import requests
|
||||
import os
|
||||
from functools import lru_cache
|
||||
|
||||
app = FastAPI(title="OIDC Token Validator")
|
||||
|
||||
# Mount static files directory (if needed for CSS, JS, images, etc.)
|
||||
# app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # In production, specify exact origins
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Configuration
|
||||
ISSUER = os.getenv("OIDC_ISSUER", "https://login.infomaniak.com") # OIDC issuer URL
|
||||
CLIENT_ID = os.getenv("CLIENT_ID") # Client ID should be set as environment variable
|
||||
CLIENT_SECRET = os.getenv("CLIENT_SECRET", "") # Client secret should be set as environment variable
|
||||
WELL_KNOWN_CONFIG_URL = f"{ISSUER}/.well-known/openid-configuration"
|
||||
|
||||
# List of authorized users (in a real app, this would come from a database)
|
||||
AUTHORIZED_USERS = {
|
||||
"rene.luria@infomaniak.com": "Welcome to the secret club!",
|
||||
"admin@example.com": "Admin super secret phrase!"
|
||||
}
|
||||
|
||||
class TokenValidationRequest(BaseModel):
|
||||
id_token: str
|
||||
access_token: Optional[str] = None
|
||||
|
||||
class TokenRefreshRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
class TokenRefreshResponse(BaseModel):
|
||||
access_token: str
|
||||
id_token: Optional[str] = None
|
||||
expires_in: int
|
||||
token_type: str
|
||||
error: Optional[str] = None
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
email: str
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
|
||||
class TokenValidationResponse(BaseModel):
|
||||
valid: bool
|
||||
user: Optional[UserInfo] = None
|
||||
secret_phrase: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
class ClientConfig(BaseModel):
|
||||
client_id: str
|
||||
issuer: str
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_well_known_config():
|
||||
"""Fetch OIDC well-known configuration (cached)"""
|
||||
try:
|
||||
response = requests.get(WELL_KNOWN_CONFIG_URL, timeout=10)
|
||||
response.raise_for_status()
|
||||
if not response.text:
|
||||
raise ValueError("Empty response from well-known configuration endpoint")
|
||||
return response.json()
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise ValueError(f"Failed to fetch well-known configuration: {str(e)}")
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid well-known configuration response: {str(e)}")
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_jwks():
|
||||
"""Fetch JWKS from the issuer (cached)"""
|
||||
try:
|
||||
# Get the JWKS URL from the well-known configuration
|
||||
well_known_config = get_well_known_config()
|
||||
jwks_url = well_known_config.get("jwks_uri")
|
||||
|
||||
if not jwks_url:
|
||||
raise ValueError("JWKS URI not found in well-known configuration")
|
||||
|
||||
response = requests.get(jwks_url, timeout=10)
|
||||
response.raise_for_status()
|
||||
if not response.text:
|
||||
raise ValueError("Empty response from JWKS endpoint")
|
||||
return response.json()
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise ValueError(f"Failed to fetch JWKS: {str(e)}")
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid JWKS response: {str(e)}")
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_userinfo_endpoint():
|
||||
"""Get the userinfo endpoint URL from the well-known configuration"""
|
||||
try:
|
||||
well_known_config = get_well_known_config()
|
||||
userinfo_endpoint = well_known_config.get("userinfo_endpoint")
|
||||
|
||||
if not userinfo_endpoint:
|
||||
raise ValueError("Userinfo endpoint not found in well-known configuration")
|
||||
|
||||
return userinfo_endpoint
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to get userinfo endpoint: {str(e)}")
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_token_endpoint():
|
||||
"""Get the token endpoint URL from the well-known configuration"""
|
||||
try:
|
||||
well_known_config = get_well_known_config()
|
||||
token_endpoint = well_known_config.get("token_endpoint")
|
||||
|
||||
if not token_endpoint:
|
||||
raise ValueError("Token endpoint not found in well-known configuration")
|
||||
|
||||
return token_endpoint
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to get token endpoint: {str(e)}")
|
||||
|
||||
def get_user_info_from_endpoint(access_token: str):
|
||||
"""Fetch user information from the userinfo endpoint"""
|
||||
try:
|
||||
userinfo_endpoint = get_userinfo_endpoint()
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}"
|
||||
}
|
||||
|
||||
response = requests.get(userinfo_endpoint, headers=headers, timeout=10)
|
||||
response.raise_for_status()
|
||||
|
||||
if not response.text:
|
||||
raise ValueError("Empty response from userinfo endpoint")
|
||||
|
||||
user_info = response.json()
|
||||
|
||||
# Map the user info fields
|
||||
# Note: Field names may vary depending on the OIDC provider
|
||||
first_name = user_info.get("given_name") or user_info.get("first_name") or user_info.get("firstName")
|
||||
last_name = user_info.get("family_name") or user_info.get("last_name") or user_info.get("lastName")
|
||||
|
||||
return {
|
||||
"email": user_info.get("email", ""),
|
||||
"first_name": first_name,
|
||||
"last_name": last_name
|
||||
}
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise ValueError(f"Failed to fetch user info: {str(e)}")
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid user info response: {str(e)}")
|
||||
|
||||
def refresh_token(refresh_token: str):
|
||||
"""Refresh an access token using a refresh token"""
|
||||
try:
|
||||
# Validate input
|
||||
if not refresh_token or not isinstance(refresh_token, str):
|
||||
raise ValueError("Invalid refresh token provided")
|
||||
|
||||
if not CLIENT_SECRET:
|
||||
raise ValueError("Client secret not configured")
|
||||
|
||||
# Get token endpoint
|
||||
token_endpoint = get_token_endpoint()
|
||||
|
||||
# Prepare refresh request
|
||||
data = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"client_id": CLIENT_ID,
|
||||
"client_secret": CLIENT_SECRET
|
||||
}
|
||||
|
||||
# Make token refresh request
|
||||
response = requests.post(token_endpoint, data=data, timeout=10)
|
||||
|
||||
if not response.ok:
|
||||
raise ValueError(f"Token refresh failed: {response.status_code} - {response.text}")
|
||||
|
||||
token_response = response.json()
|
||||
|
||||
# Validate required fields in response
|
||||
if "access_token" not in token_response:
|
||||
raise ValueError("Access token not found in refresh response")
|
||||
|
||||
return {
|
||||
"access_token": token_response["access_token"],
|
||||
"id_token": token_response.get("id_token"),
|
||||
"expires_in": token_response.get("expires_in", 0),
|
||||
"token_type": token_response.get("token_type", "Bearer")
|
||||
}
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise ValueError(f"Failed to refresh token: {str(e)}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Token refresh failed: {str(e)}")
|
||||
|
||||
def verify_token(id_token: str):
|
||||
"""Verify JWT token and return payload if valid"""
|
||||
try:
|
||||
# Validate input
|
||||
if not id_token or not isinstance(id_token, str):
|
||||
raise ValueError("Invalid token provided")
|
||||
|
||||
# Get JWKS
|
||||
jwks = get_jwks()
|
||||
|
||||
# Decode token header to get kid
|
||||
header = jwt.get_unverified_header(id_token)
|
||||
kid = header.get("kid")
|
||||
|
||||
if not kid:
|
||||
raise ValueError("Token header missing 'kid' field")
|
||||
|
||||
# Find the matching key in JWKS
|
||||
key = None
|
||||
if not jwks or "keys" not in jwks:
|
||||
raise ValueError("Invalid JWKS format")
|
||||
|
||||
for jwk in jwks.get("keys", []):
|
||||
if jwk.get("kid") == kid:
|
||||
key = jwt.algorithms.RSAAlgorithm.from_jwk(jwk)
|
||||
break
|
||||
|
||||
if not key:
|
||||
raise ValueError("Unable to find matching key in JWKS")
|
||||
|
||||
# Verify and decode token
|
||||
payload = jwt.decode(
|
||||
id_token,
|
||||
key=key,
|
||||
algorithms=[header.get("alg", "RS256")],
|
||||
audience=CLIENT_ID,
|
||||
issuer=ISSUER
|
||||
)
|
||||
|
||||
return payload
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise ValueError("Token has expired")
|
||||
except jwt.InvalidAudienceError:
|
||||
raise ValueError("Invalid audience")
|
||||
except jwt.InvalidIssuerError:
|
||||
raise ValueError("Invalid issuer")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Token validation failed: {str(e)}")
|
||||
|
||||
@app.post("/validate-token", response_model=TokenValidationResponse)
|
||||
async def validate_token(request: TokenValidationRequest):
|
||||
"""Validate ID token and return secret phrase for authorized users"""
|
||||
# Log the incoming request for debugging
|
||||
print(f"Validating token, length: {len(request.id_token) if request.id_token else 0}")
|
||||
|
||||
try:
|
||||
# Verify token
|
||||
payload = verify_token(request.id_token)
|
||||
|
||||
# Extract user email
|
||||
user_email = payload.get("email")
|
||||
if not user_email:
|
||||
return TokenValidationResponse(
|
||||
valid=False,
|
||||
error="Email not found in token"
|
||||
)
|
||||
|
||||
# Initialize user info with email from token
|
||||
user_info = {
|
||||
"email": user_email,
|
||||
"first_name": None,
|
||||
"last_name": None
|
||||
}
|
||||
|
||||
# If access token is provided, fetch additional user info from userinfo endpoint
|
||||
if request.access_token:
|
||||
try:
|
||||
additional_info = get_user_info_from_endpoint(request.access_token)
|
||||
user_info.update(additional_info)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to fetch user info from userinfo endpoint: {str(e)}")
|
||||
# Continue with basic user info from token
|
||||
|
||||
# Check if user is authorized
|
||||
secret_phrase = AUTHORIZED_USERS.get(user_email)
|
||||
|
||||
return TokenValidationResponse(
|
||||
valid=True,
|
||||
user=UserInfo(**user_info),
|
||||
secret_phrase=secret_phrase
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
print(f"ValueError during token validation: {str(e)}")
|
||||
return TokenValidationResponse(
|
||||
valid=False,
|
||||
error=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Unexpected error during token validation: {str(e)}")
|
||||
return TokenValidationResponse(
|
||||
valid=False,
|
||||
error=f"Unexpected error: {str(e)}"
|
||||
)
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def read_root():
|
||||
"""Serve the frontend HTML file"""
|
||||
with open("templates/index.html", "r") as file:
|
||||
html_content = file.read()
|
||||
return HTMLResponse(content=html_content, status_code=200)
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint for Kubernetes deployment"""
|
||||
return {"status": "healthy", "timestamp": os.getenv("HOSTNAME", "unknown")}
|
||||
|
||||
@app.post("/refresh-token", response_model=TokenRefreshResponse)
|
||||
async def refresh_token_endpoint(request: TokenRefreshRequest):
|
||||
"""Refresh an access token using a refresh token"""
|
||||
try:
|
||||
# Refresh the token
|
||||
refreshed_tokens = refresh_token(request.refresh_token)
|
||||
|
||||
return TokenRefreshResponse(
|
||||
access_token=refreshed_tokens["access_token"],
|
||||
id_token=refreshed_tokens["id_token"],
|
||||
expires_in=refreshed_tokens["expires_in"],
|
||||
token_type=refreshed_tokens["token_type"]
|
||||
)
|
||||
except ValueError as e:
|
||||
return TokenRefreshResponse(
|
||||
access_token="", # Required field, but empty since there's an error
|
||||
id_token=None,
|
||||
expires_in=0,
|
||||
token_type="",
|
||||
error=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
return TokenRefreshResponse(
|
||||
access_token="", # Required field, but empty since there's an error
|
||||
id_token=None,
|
||||
expires_in=0,
|
||||
token_type="",
|
||||
error=f"Unexpected error: {str(e)}"
|
||||
)
|
||||
|
||||
@app.get("/favicon.ico")
|
||||
async def favicon():
|
||||
"""Serve the favicon"""
|
||||
return FileResponse("templates/favicon.ico")
|
||||
|
||||
@app.get("/config", response_model=ClientConfig)
|
||||
async def get_client_config():
|
||||
"""Serve client configuration to frontend"""
|
||||
return ClientConfig(client_id=CLIENT_ID, issuer=ISSUER)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
Reference in New Issue
Block a user