CORS Middleware
===============
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI()

# Accept cross-origin requests from any site. Credentials stay disabled on
# purpose. Browsers reject a wildcard origin on credentialed requests. To make
# "*" work with credentials anyway, Starlette echoes the caller's Origin back
# for cookie-bearing requests, which would let ANY website make authenticated
# calls with the user's cookies. To allow credentials, replace "*" with an
# explicit list of trusted origins and then set allow_credentials=True.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def root():
    """Return a static greeting."""
    return {"message": "Hello World"}
GZipMiddleware
===============
from fastapi import FastAPI
from fastapi.middleware.gzip import GZipMiddleware
app = FastAPI()

# Compress response bodies of at least 1000 bytes (smaller ones are sent as-is),
# and only for clients that advertise gzip in their Accept-Encoding header.
app.add_middleware(GZipMiddleware, minimum_size=1000)


@app.get("/")
async def root():
    """Return a payload large enough to cross the GZip ``minimum_size`` threshold.

    The message alone serializes to well under 1000 bytes and would be sent
    uncompressed. The ``items`` padding brings the JSON body to roughly 2 KB,
    so the middleware actually compresses it.
    """
    return {
        "message": "This is a test message that will be compressed.",
        "items": list(range(500)),
    }
HTTPSRedirect Middleware
========================
from fastapi import FastAPI
from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware
app = FastAPI()

# Requests arriving over plain http (or ws) are answered with a redirect to the
# https (or wss) URL. The route below therefore only runs for requests that
# already came in over HTTPS.
# NOTE(review): behind a TLS-terminating proxy the app sees "http" unless the
# server trusts the proxy's X-Forwarded-Proto header. Otherwise every request
# is redirected again in a loop, so confirm the deployment's proxy settings.
app.add_middleware(HTTPSRedirectMiddleware)


@app.get("/")
async def root():
    """Confirm that the request was served over HTTPS."""
    return {"message": "You are being served over HTTPS!"}
Session Middleware
==================
from fastapi import FastAPI, Request
from starlette.middleware.sessions import SessionMiddleware
import os
import secrets

app = FastAPI()

# Sessions live in a signed cookie, so anyone who knows the signing key can
# forge session contents. Read the key from the environment instead of
# hard-coding it in source. The random fallback is for local development only:
# it changes on every restart and differs between worker processes, which
# invalidates existing sessions.
_session_secret = os.environ.get("SESSION_SECRET_KEY") or secrets.token_urlsafe(32)
app.add_middleware(SessionMiddleware, secret_key=_session_secret)


@app.get("/set/")
async def set_session_data(request: Request):
    """Store a demo user name in the caller's session."""
    request.session['user'] = 'john_doe'
    return {"message": "Session data set"}


@app.get("/get/")
async def get_session_data(request: Request):
    """Return the session's user, or ``'guest'`` when none has been stored."""
    user = request.session.get('user', 'guest')
    return {"user": user}
TrustedHost Middleware
======================
from fastapi import FastAPI
from fastapi.middleware.trustedhost import TrustedHostMiddleware
app = FastAPI()
app.add_middleware(
    TrustedHostMiddleware,
    # Requests whose Host header matches none of these patterns are rejected
    # before reaching any route. "localhost"/"127.0.0.1" are not listed, so
    # requests addressed to them (e.g. during local testing) are refused too.
    allowed_hosts=[
        "example.com",
        "*.example.com",
    ],
)


@app.get("/")
async def root():
    """Respond to requests whose Host header passed the trusted-host check."""
    return {"message": "This request came from a trusted host."}
Error Handling Middleware
=========================
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import logging

_error_logger = logging.getLogger(__name__)


class ErrorHandlingMiddleware(BaseHTTPMiddleware):
    """Turn unhandled exceptions from downstream handlers into JSON 500 responses.

    The exception and its traceback are logged on the server. The client only
    receives a generic message, because echoing ``str(exc)`` back would leak
    internal details (paths, queries, configuration values) to callers.
    """

    async def dispatch(self, request: Request, call_next):
        try:
            return await call_next(request)
        except Exception:
            # Catching everything is deliberate: this middleware is a catch-all
            # error boundary. Because the error stops here, logging it is the
            # only way its traceback is recorded.
            _error_logger.exception(
                "Unhandled error while processing %s %s",
                request.method,
                request.url.path,
            )
            return JSONResponse({"error": "Internal Server Error"}, status_code=500)
app = FastAPI()
app.add_middleware(ErrorHandlingMiddleware)  # catch-all for unhandled exceptions


@app.get("/")
async def root():
    """Deliberately fail so the error-handling middleware has something to catch."""
    error_message = "This is an error!"
    raise ValueError(error_message)
Rate Limiting Middleware
==========================
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import time
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Sliding-window rate limiter keyed by client IP address.

    Each client may make at most ``max_requests`` requests in any
    ``window``-second period. Further requests get a 429 response with a
    ``Retry-After`` header. State is kept in this object's memory, so limits
    apply per worker process and reset when the process restarts.

    Args:
        app: The downstream ASGI application.
        max_requests: Requests allowed per client within one window.
        window: Length of the sliding window, in seconds.
    """

    def __init__(self, app, max_requests: int, window: int):
        super().__init__(app)
        self.max_requests = max_requests
        self.window = window
        # client key -> time.monotonic() stamps of accepted requests, oldest first.
        self.requests = {}
        self._last_sweep = time.monotonic()

    async def dispatch(self, request: Request, call_next):
        # request.client can be None when the server reports no peer address.
        # Such requests share one bucket instead of crashing with AttributeError.
        # NOTE(review): behind a reverse proxy this is the proxy's address
        # unless proxy headers are trusted.
        client_ip = request.client.host if request.client else "unknown"
        # Monotonic clock: wall-clock adjustments must not stretch or shrink the window.
        now = time.monotonic()
        cutoff = now - self.window
        self._sweep_idle_clients(now, cutoff)

        # Keep only this client's requests that are still inside the window.
        recent = [stamp for stamp in self.requests.get(client_ip, ()) if stamp > cutoff]
        self.requests[client_ip] = recent
        # There is no await between this check and the append below, so
        # concurrent requests on the event loop cannot both slip under the limit.
        if len(recent) >= self.max_requests:
            # The oldest recorded request is the first to leave the window.
            oldest = recent[0] if recent else now
            retry_after = int(oldest + self.window - now) + 1
            return JSONResponse(
                status_code=429,
                content={"error": "Too many requests"},
                headers={"Retry-After": str(retry_after)},
            )
        recent.append(now)
        return await call_next(request)

    def _sweep_idle_clients(self, now: float, cutoff: float) -> None:
        """Forget clients with no request inside the window; runs at most once per window.

        Without this, every address that ever sent a request would keep an
        entry in ``self.requests`` forever.
        """
        if now - self._last_sweep < self.window:
            return
        self._last_sweep = now
        self.requests = {
            client: stamps
            for client, stamps in self.requests.items()
            if stamps and stamps[-1] > cutoff
        }
app = FastAPI()
# At most 5 requests per client IP in any 60-second window.
app.add_middleware(RateLimitMiddleware, max_requests=5, window=60)


@app.get("/")
async def root():
    """Tell the caller it is still within its rate limit."""
    return dict(message="You haven't hit the rate limit yet!")
Authentication Middleware
==========================
from fastapi import FastAPI, Request, HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from fastapi.responses import PlainTextResponse
import secrets

# Expected Authorization header value; hard-coded for demonstration only.
# NOTE(review): load real credentials from configuration, not source code.
_EXPECTED_AUTHORIZATION = "Bearer valid-token"


class AuthMiddleware(BaseHTTPMiddleware):
    """Reject every request that lacks the expected bearer token with 401."""

    async def dispatch(self, request: Request, call_next):
        token = request.headers.get("Authorization")
        # compare_digest avoids content-based short-circuiting, so response
        # timing does not reveal how much of a guessed token was correct.
        # Comparing UTF-8 bytes sidesteps the TypeError it raises for non-ASCII
        # str arguments, which a client-supplied header may contain.
        if not token or not secrets.compare_digest(
            token.encode("utf-8"), _EXPECTED_AUTHORIZATION.encode("utf-8")
        ):
            # HTTP (RFC 9110 §15.5.2) requires a WWW-Authenticate challenge on a 401.
            return PlainTextResponse(
                status_code=401,
                content="Unauthorized",
                headers={"WWW-Authenticate": "Bearer"},
            )
        return await call_next(request)
app = FastAPI()
app.add_middleware(AuthMiddleware)  # every request to this app must carry the token


@app.get("/secure-data/")
async def secure_data():
    """Return data that only authenticated callers can reach."""
    payload = {"message": "This is secured data"}
    return payload
Custom Response Headers Middleware
==================================
from fastapi import FastAPI
from starlette.middleware.base import BaseHTTPMiddleware
class CustomHeaderMiddleware(BaseHTTPMiddleware):
    """Add caching and security-related headers to every response.

    ``Cache-Control`` is only a default. It is added to successful responses
    that did not set their own, so error responses are never cached publicly
    for an hour, and endpoints can still opt out (e.g. ``no-store`` for
    private data).
    """

    async def dispatch(self, request, call_next):
        response = await call_next(request)
        headers = response.headers
        if response.status_code < 400:
            headers.setdefault('Cache-Control', 'public, max-age=3600')
        headers["X-Content-Type-Options"] = "nosniff"
        headers["X-Frame-Options"] = "DENY"
        # "1; mode=block" is deprecated: the legacy XSS auditor it enables could
        # itself be abused, and modern browsers ignore it. "0" disables the
        # auditor explicitly, as OWASP recommends.
        headers["X-XSS-Protection"] = "0"
        # Browsers only honour HSTS when the header arrives over HTTPS.
        headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
        return response
app = FastAPI()
app.add_middleware(CustomHeaderMiddleware)


@app.get("/data/")
async def get_data():
    """Return a small payload; the middleware adds the caching/security headers."""
    body = {"message": "This response is cached for 1 hour."}
    return body
Logging Middleware
===================
from fastapi import FastAPI, Request
import logging
from starlette.middleware.base import BaseHTTPMiddleware
import time

logger = logging.getLogger("my_logger")
# Without any logging configuration the root logger sits at WARNING with no
# handler. The INFO records below would then be dropped and the demo would
# print nothing. basicConfig is a no-op if the application already configured
# root logging.
# NOTE(review): in a larger app, configure logging once at startup instead.
logging.basicConfig(level=logging.INFO)


class LoggingMiddleware(BaseHTTPMiddleware):
    """Log each request line, then the response status and elapsed time."""

    async def dispatch(self, request: Request, call_next):
        # %-style arguments defer string formatting until a record is emitted.
        logger.info("Request: %s %s", request.method, request.url)
        started = time.perf_counter()
        try:
            response = await call_next(request)
        except Exception:
            # Record failures as well, then let the exception propagate unchanged.
            logger.exception("Request failed: %s %s", request.method, request.url)
            raise
        elapsed_ms = (time.perf_counter() - started) * 1000
        logger.info("Response status: %s (%.1f ms)", response.status_code, elapsed_ms)
        return response
app = FastAPI()
app.add_middleware(LoggingMiddleware)


@app.get("/")
async def root():
    """Point the caller at the request/response log lines."""
    hint = "Check your logs for the request and response details."
    return {"message": hint}
Timeout Middleware
==================
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import PlainTextResponse
import asyncio
from starlette.middleware.base import BaseHTTPMiddleware
class TimeoutMiddleware(BaseHTTPMiddleware):
    """Answer with 504 when the downstream app exceeds ``timeout`` seconds.

    NOTE(review): ``call_next`` presumably resolves once the downstream
    response has started, so time spent streaming a body afterwards is not
    bounded by this limit. Confirm against Starlette's BaseHTTPMiddleware
    before relying on it for streaming endpoints.
    """

    def __init__(self, app, timeout: int):
        super().__init__(app)
        self.timeout = timeout  # seconds, handed straight to asyncio.wait_for

    async def dispatch(self, request: Request, call_next):
        try:
            pending = call_next(request)
            response = await asyncio.wait_for(pending, timeout=self.timeout)
        except asyncio.TimeoutError:
            return PlainTextResponse(status_code=504, content="Request timed out")
        return response
app = FastAPI()
app.add_middleware(TimeoutMiddleware, timeout=5)  # seconds


@app.get("/")
async def root():
    """Sleep past the 5-second middleware limit so the client gets a 504."""
    simulated_work_seconds = 10
    await asyncio.sleep(simulated_work_seconds)  # Simulates a long-running process
    return {"message": "This won't be reached if the timeout is less than 10 seconds."}
IP Whitelisting Middleware
===========================
from fastapi import FastAPI, Request, HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from fastapi.responses import PlainTextResponse
class IPWhitelistMiddleware(BaseHTTPMiddleware):
    """Serve only clients whose IP address is explicitly allowed; others get 403.

    Args:
        app: The downstream ASGI application.
        whitelist: One IP address string, or an iterable of them. Each is
            compared exactly against ``request.client.host``.

    NOTE(review): behind a reverse proxy ``request.client.host`` is the
    proxy's address unless proxy headers are processed first.
    """

    def __init__(self, app, whitelist):
        super().__init__(app)
        if isinstance(whitelist, str):
            # A bare string would otherwise act as a substring check / set of characters.
            whitelist = [whitelist]
        # frozenset gives O(1) lookups, and later changes to the caller's list
        # cannot silently alter the policy.
        self.whitelist = frozenset(whitelist)

    async def dispatch(self, request: Request, call_next):
        # request.client is None when the server reports no peer address.
        # Treat an unknown client as not allowed instead of failing with a 500.
        client_ip = request.client.host if request.client else None
        if client_ip not in self.whitelist:
            return PlainTextResponse(status_code=403, content="IP not allowed")
        return await call_next(request)
app = FastAPI()
app.add_middleware(
    IPWhitelistMiddleware,
    whitelist=["127.0.0.1", "192.168.1.1"],  # loopback plus one private-network address
)


@app.get("/")
async def root():
    """Respond only to clients whose IP passed the whitelist check."""
    return {"message": "Your IP is whitelisted!"}
ProxyHeadersMiddleware
=========================
from fastapi import FastAPI, Request
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
app = FastAPI()
# Rewrites the client address and scheme from X-Forwarded-For /
# X-Forwarded-Proto. This only applies to requests coming from trusted proxy
# addresses; the middleware's default trusts 127.0.0.1 only.
# NOTE(review): uvicorn can already apply this itself when proxy headers are
# enabled, so confirm this extra layer is needed for your server setup.
app.add_middleware(ProxyHeadersMiddleware)


@app.get("/")
async def root(request: Request):
    """Return the client IP as seen after proxy-header processing.

    ``request.client`` is None when the server reports no peer address. In that
    case ``None`` is returned instead of failing with an AttributeError (500).
    """
    client = request.client
    return {"client_ip": client.host if client is not None else None}
CSRF Middleware
================
from fastapi import FastAPI, Request
from starlette_csrf import CSRFMiddleware
import os
import secrets

app = FastAPI()
# The secret is used to sign CSRF tokens. A value committed to source (like the
# "__CHANGE_ME__" placeholder) lets anyone who reads the code mint valid
# tokens. Read it from the environment instead. The random fallback is for
# local development only: it changes on every restart and differs between
# worker processes.
_csrf_secret = os.environ.get("CSRF_SECRET") or secrets.token_urlsafe(32)
app.add_middleware(CSRFMiddleware, secret=_csrf_secret)


@app.get("/")
async def root(request: Request):
    """Echo back the ``csrftoken`` cookie the client sent, or ``None`` if absent.

    NOTE(review): presumably the middleware sets this cookie on a response,
    so a client's very first request carries no cookie and gets ``None``.
    """
    return {"message": request.cookies.get('csrftoken')}
GlobalsMiddleware
=================
from fastapi import FastAPI, Depends
from fastapi_g_context import GlobalsMiddleware, g
app = FastAPI()
app.add_middleware(GlobalsMiddleware)


async def set_globals() -> None:
    """Store demo values on ``g`` before the route body runs.

    NOTE(review): ``g`` is presumably reset per request by GlobalsMiddleware;
    confirm in the fastapi_g_context documentation.
    """
    for attribute, value in (
        ("username", "JohnDoe"),
        ("request_id", "123456"),
        ("is_admin", True),
    ):
        setattr(g, attribute, value)


@app.get("/", dependencies=[Depends(set_globals)])
async def info():
    """Return the values that ``set_globals`` placed on ``g``."""
    return {attribute: getattr(g, attribute) for attribute in ("username", "request_id", "is_admin")}