Thursday, March 6, 2025

A Few Middlewares for FastAPI

CORS Middleware

=======

from fastapi import FastAPI

from fastapi.middleware.cors import CORSMiddleware


app = FastAPI()

# Browsers refuse a wildcard Access-Control-Allow-Origin on credentialed requests.
# Starlette handles allow_origins=["*"] + allow_credentials=True by echoing back
# whatever Origin the request carried. That would let *every* website make
# cookie-authenticated calls to this API. When credentials are allowed, list the
# trusted front-end origins explicitly instead.
ALLOWED_ORIGINS = [
    "http://localhost:3000",
    "https://example.com",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def root():
    return {"message": "Hello World"}


GZipMiddleware

===============


from fastapi import FastAPI

from fastapi.middleware.gzip import GZipMiddleware


app = FastAPI()

# Bodies shorter than this many bytes are passed through uncompressed.
app.add_middleware(
    GZipMiddleware,
    minimum_size=1000,
)


# NOTE: this short JSON body is far below the 1000-byte threshold above, so it is
# sent uncompressed; return a larger payload to actually observe gzip encoding.
@app.get("/")
async def root():
    payload = {"message": "This is a test message that will be compressed."}
    return payload



HTTPSRedirect Middleware

====================

from fastapi import FastAPI

from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware


app = FastAPI()

# Requests made over plain http are answered with a redirect to the https URL,
# so the endpoint below only ever runs for secure requests.
app.add_middleware(HTTPSRedirectMiddleware)


@app.get("/")
async def root():
    body = {"message": "You are being redirected to HTTPS!"}
    return body


Session Middleware

=====================

from fastapi import FastAPI, Request

from starlette.middleware.sessions import SessionMiddleware


import os
import secrets

app = FastAPI()

# The secret key signs the session cookie. Anyone who knows it can forge arbitrary
# session contents, so it must never be a hard-coded, published value.
# Read it from the environment. The random fallback keeps the demo runnable, but
# it changes on every start and differs per worker process, so sessions reset
# and are not shared. Always set SESSION_SECRET_KEY in real deployments.
# Note: the cookie is signed, not encrypted; clients can read its contents.
SESSION_SECRET_KEY = os.environ.get("SESSION_SECRET_KEY") or secrets.token_urlsafe(32)

app.add_middleware(SessionMiddleware, secret_key=SESSION_SECRET_KEY)


# Stores a demo user name in the caller's session.
@app.get("/set/")
async def set_session_data(request: Request):
    request.session['user'] = 'john_doe'
    return {"message": "Session data set"}


# Reads the user name back, defaulting to 'guest' when none was stored.
@app.get("/get/")
async def get_session_data(request: Request):
    user = request.session.get('user', 'guest')
    return {"user": user}



TrustedHost Middleware

======================

from fastapi import FastAPI

from fastapi.middleware.trustedhost import TrustedHostMiddleware


app = FastAPI()

# Only requests whose Host header is example.com or one of its subdomains are
# accepted; any other Host is rejected before reaching the endpoint.
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=["example.com", "*.example.com"],
)


@app.get("/")
async def root():
    reply = {"message": "This request came from a trusted host."}
    return reply




Error Handling Middleware

=========================

from fastapi import FastAPI, Request

from fastapi.responses import JSONResponse

from starlette.middleware.base import BaseHTTPMiddleware


import logging

_error_logger = logging.getLogger(__name__)


class ErrorHandlingMiddleware(BaseHTTPMiddleware):
    """Convert any unhandled exception from downstream into a JSON HTTP 500.

    The full exception and traceback are logged server-side. The client only
    receives a generic message, because echoing ``str(exc)`` can leak internals
    such as file paths, SQL fragments or credentials embedded in error text.
    """

    async def dispatch(self, request: Request, call_next):
        try:
            return await call_next(request)
        except Exception:
            # Deliberate top-level boundary: record the details, then answer safely.
            _error_logger.exception(
                "Unhandled error while processing %s %s",
                request.method,
                request.url.path,
            )
            return JSONResponse({"error": "Internal Server Error"}, status_code=500)


app = FastAPI()
app.add_middleware(ErrorHandlingMiddleware)


# Always fails, so the error-handling middleware has something to catch.
@app.get("/")
async def root():
    failure_text = "This is an error!"
    raise ValueError(failure_text)


Rate Limiting Middleware

==========================

from fastapi import FastAPI, Request

from fastapi.responses import JSONResponse

from starlette.middleware.base import BaseHTTPMiddleware

import time


from collections import deque


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Sliding-window rate limiter keyed by client IP, held in process memory.

    Each client IP may make at most ``max_requests`` requests in any
    ``window``-second span. Excess requests receive HTTP 429 with a
    ``Retry-After`` header. The counters live on this middleware instance, so
    every worker process enforces its own independent limit.
    """

    def __init__(self, app, max_requests: int, window: int):
        super().__init__(app)
        self.max_requests = max_requests
        self.window = window
        # client IP -> deque of admission times (time.monotonic()), oldest first.
        self.requests = {}
        # Monotonic time at which idle clients were last purged from self.requests.
        self._last_sweep = time.monotonic()

    def _forget_idle_clients(self, now: float, cutoff: float) -> None:
        """Drop clients with no request inside the window; runs at most once per window.

        Without this, every IP ever seen would stay in ``self.requests`` forever.
        """
        if now - self._last_sweep < self.window:
            return
        self._last_sweep = now
        idle = [
            ip for ip, stamps in self.requests.items()
            if not stamps or stamps[-1] <= cutoff
        ]
        for ip in idle:
            del self.requests[ip]

    async def dispatch(self, request: Request, call_next):
        # request.client is None when the server cannot report the peer address.
        # Group such requests under one key instead of raising AttributeError.
        client_ip = request.client.host if request.client else "unknown"
        # Monotonic clock: immune to wall-clock jumps (NTP, DST, manual changes).
        now = time.monotonic()
        cutoff = now - self.window

        self._forget_idle_clients(now, cutoff)

        stamps = self.requests.setdefault(client_ip, deque())
        # Timestamps are appended in order, so expired ones are always at the left.
        while stamps and stamps[0] <= cutoff:
            stamps.popleft()

        if len(stamps) >= self.max_requests:
            # Seconds until the oldest admitted request slides out of the window
            # (with max_requests <= 0 there is none; wait a full window).
            oldest = stamps[0] if stamps else now
            retry_after = int(oldest + self.window - now) + 1
            return JSONResponse(
                status_code=429,
                content={"error": "Too many requests"},
                headers={"Retry-After": str(retry_after)},
            )

        stamps.append(now)
        return await call_next(request)



app = FastAPI()

# Allow each client IP five requests per 60-second window.
app.add_middleware(RateLimitMiddleware, window=60, max_requests=5)


@app.get("/")
async def root():
    reply = {"message": "You haven't hit the rate limit yet!"}
    return reply




 Authentication Middleware

==========================

from fastapi import FastAPI, Request, HTTPException

from starlette.middleware.base import BaseHTTPMiddleware

from fastapi.responses import PlainTextResponse


import secrets


class AuthMiddleware(BaseHTTPMiddleware):
    """Reject any request whose ``Authorization`` header is not the expected bearer token.

    Rejected requests get HTTP 401 "Unauthorized" plus a ``WWW-Authenticate: Bearer``
    challenge, as the bearer-token scheme expects. The token is a hard-coded demo
    value; load a real one from configuration.
    """

    # Exact Authorization header value that is accepted.
    expected_authorization = "Bearer valid-token"

    async def dispatch(self, request: Request, call_next):
        # A missing header becomes "", which never matches the expected value.
        supplied = request.headers.get("Authorization", "")
        # compare_digest avoids content-based short-circuiting, so response timing
        # does not reveal how much of the token was guessed correctly. Compare as
        # UTF-8 bytes because compare_digest rejects non-ASCII str arguments.
        if not secrets.compare_digest(
            supplied.encode(), self.expected_authorization.encode()
        ):
            return PlainTextResponse(
                status_code=401,
                content="Unauthorized",
                headers={"WWW-Authenticate": "Bearer"},
            )
        return await call_next(request)


app = FastAPI()
app.add_middleware(AuthMiddleware)


# Reachable only with the token that AuthMiddleware accepts.
@app.get("/secure-data/")
async def secure_data():
    result = {"message": "This is secured data"}
    return result




Headers Injection Middleware

===========================

from fastapi import FastAPI

from starlette.middleware.base import BaseHTTPMiddleware


class CustomHeaderMiddleware(BaseHTTPMiddleware):
    """Add caching and security headers to every response.

    A header is only added when the response does not already carry it. An
    endpoint that deliberately sets e.g. ``Cache-Control: no-store`` on sensitive
    data therefore keeps its own value instead of having it clobbered.
    """

    # Header name -> value applied when the response lacks that header.
    DEFAULT_HEADERS = {
        "Cache-Control": "public, max-age=3600",
        "X-Content-Type-Options": "nosniff",
        "X-Frame-Options": "DENY",
        # The legacy XSS auditor is gone from modern browsers, and its
        # "1; mode=block" mode could itself be abused. "0" explicitly disables it.
        "X-XSS-Protection": "0",
        # Browsers only honour HSTS when it is received over HTTPS.
        "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
    }

    async def dispatch(self, request, call_next):
        response = await call_next(request)
        for name, value in self.DEFAULT_HEADERS.items():
            response.headers.setdefault(name, value)
        return response


app = FastAPI()
app.add_middleware(CustomHeaderMiddleware)


# The Cache-Control header on this response comes from CustomHeaderMiddleware.
@app.get("/data/")
async def get_data():
    data = {"message": "This response is cached for 1 hour."}
    return data



 Logging Middleware

===================

from fastapi import FastAPI, Request

import logging

from starlette.middleware.base import BaseHTTPMiddleware


logger = logging.getLogger("my_logger")


class LoggingMiddleware(BaseHTTPMiddleware):
    """Log every incoming request line and the status code of its response.

    If the downstream handler raises, the failure is logged with its traceback
    and the exception is re-raised unchanged, so normal error handling still runs.
    """

    async def dispatch(self, request: Request, call_next):
        # %-style arguments defer formatting until a handler actually emits the record.
        logger.info("Request: %s %s", request.method, request.url)
        try:
            response = await call_next(request)
        except Exception:
            # Without this, a failing request would leave only its "Request:" line.
            logger.exception("Request failed: %s %s", request.method, request.url)
            raise
        logger.info("Response status: %s", response.status_code)
        return response


# The middleware's logger has no level or handler of its own, so its records go
# to the root logger. With no configuration, the root logger only lets WARNING
# and above through, and the INFO request/response lines would be silently dropped.
# basicConfig installs a stderr handler and lowers the threshold to INFO.
logging.basicConfig(level=logging.INFO)

app = FastAPI()
app.add_middleware(LoggingMiddleware)


@app.get("/")
async def root():
    return {"message": "Check your logs for the request and response details."}



Timeout Middleware

==================

from fastapi import FastAPI, Request, HTTPException

from fastapi.responses import PlainTextResponse

import asyncio

from starlette.middleware.base import BaseHTTPMiddleware


class TimeoutMiddleware(BaseHTTPMiddleware):
    """Answer HTTP 504 "Request timed out" when downstream exceeds ``timeout`` seconds.

    NOTE(review): with BaseHTTPMiddleware, ``call_next`` presumably resolves once
    the response starts, so time spent streaming a body may not be covered.
    Confirm against the Starlette version in use.
    """

    def __init__(self, app, timeout: int):
        super().__init__(app)
        self.timeout = timeout

    async def dispatch(self, request: Request, call_next):
        try:
            pending = call_next(request)
            response = await asyncio.wait_for(pending, timeout=self.timeout)
        except asyncio.TimeoutError:
            response = PlainTextResponse(status_code=504, content="Request timed out")
        return response


app = FastAPI()

# Give each request at most 5 seconds.
app.add_middleware(TimeoutMiddleware, timeout=5)


# Sleeps for 10 seconds, i.e. longer than the 5-second limit configured above,
# so callers receive the middleware's 504 instead of this body.
@app.get("/")
async def root():
    simulated_work_seconds = 10
    await asyncio.sleep(simulated_work_seconds)
    return {"message": "This won't be reached if the timeout is less than 10 seconds."}



 IP Whitelisting Middleware

===========================


from fastapi import FastAPI, Request, HTTPException

from starlette.middleware.base import BaseHTTPMiddleware

from fastapi.responses import PlainTextResponse


class IPWhitelistMiddleware(BaseHTTPMiddleware):
    """Let through only requests whose client IP appears in ``whitelist``.

    All other requests get HTTP 403 "IP not allowed".
    NOTE(review): ``request.client.host`` is the direct peer address. Behind a
    reverse proxy that is the proxy's IP unless proxy headers are applied first.
    Confirm for your deployment.
    """

    def __init__(self, app, whitelist):
        super().__init__(app)
        self.whitelist = whitelist

    async def dispatch(self, request: Request, call_next):
        client = request.client
        # request.client is None when the server cannot report the peer address.
        # Deny such an unknown client instead of crashing with AttributeError (500).
        if client is None or client.host not in self.whitelist:
            return PlainTextResponse(status_code=403, content="IP not allowed")
        return await call_next(request)


app = FastAPI()

# Only loopback and one LAN address may reach the endpoints.
app.add_middleware(
    IPWhitelistMiddleware,
    whitelist=["127.0.0.1", "192.168.1.1"],
)


@app.get("/")
async def root():
    reply = {"message": "Your IP is whitelisted!"}
    return reply




ProxyHeadersMiddleware

=========================


from fastapi import FastAPI, Request

from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware



app = FastAPI()

# Takes the client address from X-Forwarded-* headers sent by a trusted proxy.
# NOTE(review): with no arguments the trusted proxy list is uvicorn's default
# (believed to be 127.0.0.1 only). Confirm for the installed uvicorn version.
app.add_middleware(ProxyHeadersMiddleware)


# Echoes the client IP as seen after proxy-header processing.
@app.get("/")
async def root(request: Request):
    client_address = request.client.host
    return {"client_ip": client_address}



CSRF Middleware

================


from fastapi import FastAPI, Request

from starlette_csrf import CSRFMiddleware


import os
import secrets

app = FastAPI()

# The secret signs the CSRF token. A publicly known value such as "__CHANGE_ME__"
# lets an attacker mint tokens the middleware will accept. Read it from the
# environment. The random fallback keeps the demo runnable, but it differs per
# process, so tokens break across restarts and between workers. Set CSRF_SECRET
# in real deployments.
CSRF_SECRET = os.environ.get("CSRF_SECRET") or secrets.token_urlsafe(32)

app.add_middleware(CSRFMiddleware, secret=CSRF_SECRET)


# Returns the csrftoken cookie the client sent, if any.
@app.get("/")
async def root(request: Request):
    return {"message": request.cookies.get('csrftoken')}



GlobalsMiddleware

=================


from fastapi import FastAPI, Depends

from fastapi_g_context import GlobalsMiddleware, g


app = FastAPI()
app.add_middleware(GlobalsMiddleware)


# Dependency that stores demo values on `g` before the endpoint runs.
# NOTE(review): `g` is presumably reset per request by GlobalsMiddleware; confirm
# in the fastapi_g_context docs.
async def set_globals() -> None:
    g.username, g.request_id, g.is_admin = "JohnDoe", "123456", True


# Reads back the values set_globals stored on `g`.
@app.get("/", dependencies=[Depends(set_globals)])
async def info():
    username = g.username
    request_id = g.request_id
    is_admin = g.is_admin
    return {"username": username, "request_id": request_id, "is_admin": is_admin}




No comments:

Post a Comment