oedb-backend/oedb/middleware/rate_limit.py

167 lines
5.9 KiB
Python
Raw Normal View History

"""
Rate limiting middleware for the OpenEventDatabase.
"""
import math
import threading
import time
from collections import defaultdict

import falcon

from oedb.utils.logging import logger
class RateLimitMiddleware:
    """
    Middleware that implements rate limiting to prevent API abuse.

    Requests are counted per client IP in a sliding window of
    ``window_size`` seconds. Each entry in ``rate_limit_rules`` has its own
    counter per IP. Requests matching no rule share a default counter capped
    at ``max_requests``. A request that would exceed the applicable limit is
    rejected with HTTP 429 and a ``Retry-After`` header.
    """

    def __init__(self, window_size=60, max_requests=60, trusted_proxy_count=1,
                 clock=time.monotonic):
        """
        Initialize the middleware with rate limiting settings.

        Args:
            window_size: Time window in seconds for rate limiting.
            max_requests: Maximum number of requests allowed per IP in the
                window for requests not covered by a specific rule.
            trusted_proxy_count: Number of reverse proxies in front of the
                application that append to X-Forwarded-For. The client IP is
                taken that many entries from the right of the header, so
                addresses a client injects into the header itself are
                ignored. Use 0 to ignore the header and use the socket peer
                address only.
            clock: Zero-argument callable returning the current time in
                seconds. Defaults to a monotonic clock so wall-clock
                adjustments cannot stretch or shrink the window.
        """
        self.window_size = window_size
        self.max_requests = max_requests
        self.trusted_proxy_count = trusted_proxy_count
        self._clock = clock
        # Request timestamps keyed by (client_ip, scope). The scope is the
        # (endpoint_prefix, method) of the matching rule, or None for the
        # default limit.
        self.requests = defaultdict(list)
        # Lock for thread safety
        self.lock = threading.Lock()
        # Time of the last full sweep of stale buckets (see _sweep_stale_buckets)
        self._last_sweep = clock()
        # Define rate limit rules for different endpoints
        # Format: (endpoint_prefix, method, max_requests)
        # When several rules match, the longest endpoint_prefix wins.
        self.rate_limit_rules = [
            # Limit POST requests to /event to 10 per minute
            ('/event', 'POST', 10),
            # Limit POST requests to /event/search to 20 per minute
            ('/event/search', 'POST', 20),
            # Limit DELETE requests to /event to 5 per minute
            ('/event', 'DELETE', 5),
        ]
        logger.info(f"Rate limiting initialized: {max_requests} requests per {window_size} seconds")

    def process_request(self, req, resp):
        """
        Process the request and apply rate limiting.

        Args:
            req: The request object.
            resp: The response object.

        Raises:
            falcon.HTTPTooManyRequests: If the rate limit is exceeded.
        """
        client_ip = self._get_client_ip(req)

        # Skip rate limiting for local requests (for development)
        if client_ip in ('127.0.0.1', 'localhost', '::1'):
            return

        rule = self._match_rule(req)
        max_requests = rule[2] if rule is not None else self.max_requests
        # Separate counters per rule so that, e.g., ordinary GETs do not use
        # up the much smaller POST /event budget.
        key = (client_ip, rule[:2] if rule is not None else None)

        with self.lock:
            now = self._clock()
            # Periodically drop buckets of clients that stopped sending
            # requests; otherwise they would accumulate forever.
            self._sweep_stale_buckets(now)
            self._clean_old_requests(key, now)

            timestamps = self.requests.get(key, [])
            recent_requests = len(timestamps)

            if recent_requests >= max_requests:
                logger.warning(f"Rate limit exceeded for IP {client_ip}: {recent_requests} requests in {self.window_size} seconds")
                if timestamps:
                    # The oldest request in the window is the first to expire.
                    retry_after = math.ceil(self.window_size - (now - timestamps[0]))
                else:
                    # Only reachable when max_requests <= 0.
                    retry_after = math.ceil(self.window_size)
                retry_after = max(1, retry_after)  # Ensure retry_after is at least 1 second

                # Add the request to the log for tracking abuse patterns
                self._log_rate_limit_exceeded(client_ip, req)

                raise falcon.HTTPTooManyRequests(
                    title="Rate limit exceeded",
                    description=f"You have exceeded the rate limit of {max_requests} requests per {self.window_size} seconds",
                    headers={'Retry-After': str(retry_after)}
                )

            # Only accepted requests are counted.
            self.requests[key].append(now)

    def _get_client_ip(self, req):
        """
        Get the client IP address from the request.

        Each proxy appends the address it received the connection from to
        X-Forwarded-For. Only the rightmost ``trusted_proxy_count`` entries
        were written by infrastructure we trust. Anything further left was
        supplied by the client and could be forged, for example to hit the
        localhost exemption.

        Args:
            req: The request object.

        Returns:
            str: The client IP address.
        """
        forwarded_for = req.get_header('X-Forwarded-For')
        if forwarded_for and self.trusted_proxy_count > 0:
            hops = [hop.strip() for hop in forwarded_for.split(',') if hop.strip()]
            if hops:
                # With fewer entries than trusted proxies, fall back to the
                # leftmost entry instead of indexing out of range.
                index = max(len(hops) - self.trusted_proxy_count, 0)
                return hops[index]

        # Fall back to the remote_addr
        return req.remote_addr or '0.0.0.0'

    def _clean_old_requests(self, key, now=None):
        """
        Remove request timestamps for one bucket that are outside the window.

        Args:
            key: The bucket key (client_ip, scope).
            now: Current time from the middleware clock. Defaults to
                reading the clock.
        """
        if key not in self.requests:
            return
        if now is None:
            now = self._clock()
        cutoff_time = now - self.window_size

        # Keep only requests within the current window
        recent = [t for t in self.requests[key] if t > cutoff_time]
        if recent:
            self.requests[key] = recent
        else:
            # Remove the bucket if there are no recent requests
            del self.requests[key]

    def _sweep_stale_buckets(self, now):
        """
        Drop every bucket whose newest timestamp has left the window.

        Runs at most once per ``window_size`` seconds, so the cost is
        amortized across requests.

        Args:
            now: Current time from the middleware clock.
        """
        if now - self._last_sweep < self.window_size:
            return
        self._last_sweep = now
        cutoff_time = now - self.window_size
        # Timestamps are appended in clock order, so ts[-1] is the newest.
        stale = [key for key, ts in self.requests.items() if not ts or ts[-1] <= cutoff_time]
        for key in stale:
            del self.requests[key]

    def _match_rule(self, req):
        """
        Find the rate limit rule that applies to the request.

        A rule matches when the method is equal and the path equals the
        rule's prefix or continues it with a ``/`` segment. When several
        rules match, the longest prefix wins, so ``/event/search`` is not
        shadowed by ``/event``.

        Args:
            req: The request object.

        Returns:
            tuple | None: The matching (endpoint_prefix, method, max_requests)
            rule, or None if no rule applies.
        """
        best = None
        for rule in self.rate_limit_rules:
            endpoint, method, _ = rule
            if req.method != method:
                continue
            base = endpoint.rstrip('/')
            if req.path == endpoint or req.path.startswith(base + '/'):
                if best is None or len(endpoint) > len(best[0]):
                    best = rule
        return best

    def _get_max_requests(self, req):
        """
        Determine the maximum requests allowed for the current endpoint.

        Args:
            req: The request object.

        Returns:
            int: The maximum number of requests allowed.
        """
        rule = self._match_rule(req)
        # Default to the global max_requests
        return rule[2] if rule is not None else self.max_requests

    def _log_rate_limit_exceeded(self, client_ip, req):
        """
        Log details when a rate limit is exceeded for analysis.

        Args:
            client_ip: The client IP address.
            req: The request object.
        """
        logger.warning(
            f"Rate limit exceeded: IP={client_ip}, "
            f"Method={req.method}, Path={req.path}, "
            f"User-Agent={req.get_header('User-Agent', 'Unknown')}"
        )