167 lines
No EOL
5.9 KiB
Python
167 lines
No EOL
5.9 KiB
Python
"""
|
|
Rate limiting middleware for the OpenEventDatabase.
|
|
"""
|
|
|
|
import time
|
|
import threading
|
|
import falcon
|
|
from collections import defaultdict
|
|
from oedb.utils.logging import logger
|
|
|
|
class RateLimitMiddleware:
    """
    Middleware that implements rate limiting to prevent API abuse.

    This middleware tracks request timestamps by client IP address and
    rejects requests that exceed the limit configured for the matching
    endpoint rule (or the global default when no rule matches).

    NOTE(review): timestamps are tracked per IP across *all* endpoints, while
    limits are looked up per endpoint. A client that has already made many
    requests to an unrestricted endpoint can therefore be rejected on its
    first request to a stricter one. Confirm whether this is intended.
    """

    def __init__(self, window_size=60, max_requests=60):
        """
        Initialize the middleware with rate limiting settings.

        Args:
            window_size: Time window in seconds for rate limiting.
            max_requests: Maximum number of requests allowed per IP in the
                window when no endpoint-specific rule applies.
        """
        self.window_size = window_size
        self.max_requests = max_requests

        # Request timestamps (integer epoch seconds) keyed by client IP.
        self.requests = defaultdict(list)

        # Guards every read/write of self.requests.
        self.lock = threading.Lock()

        # Rate limit rules for specific endpoints.
        # Format: (endpoint_prefix, method, max_requests)
        # When several prefixes match a request path, the longest one wins
        # (see _get_max_requests), so list order does not matter.
        self.rate_limit_rules = [
            # Limit POST requests to /event to 10 per window
            ('/event', 'POST', 10),
            # Limit POST requests to /event/search to 20 per window
            ('/event/search', 'POST', 20),
            # Limit DELETE requests to /event to 5 per window
            ('/event', 'DELETE', 5),
        ]

        logger.info(f"Rate limiting initialized: {max_requests} requests per {window_size} seconds")

    def process_request(self, req, resp):
        """
        Process the request and apply rate limiting.

        Args:
            req: The request object.
            resp: The response object (unused).

        Raises:
            falcon.HTTPTooManyRequests: If the rate limit is exceeded. The
                response carries a ``Retry-After`` header in seconds.
        """
        client_ip = self._get_client_ip(req)

        # Skip rate limiting for local requests (for development).
        # SECURITY NOTE(review): client_ip may come from the client-supplied
        # X-Forwarded-For header, so a remote caller can spoof a loopback
        # address and bypass limiting unless a trusted proxy overwrites
        # that header.
        if client_ip in ('127.0.0.1', 'localhost', '::1'):
            return

        max_requests = self._get_max_requests(req)

        with self.lock:
            # Drop timestamps that have fallen out of the window.
            self._clean_old_requests(client_ip)

            # Use .get() so that merely counting does not create an empty
            # defaultdict entry for a request that ends up rejected.
            timestamps = self.requests.get(client_ip, ())
            recent_requests = len(timestamps)

            if recent_requests >= max_requests:
                logger.warning(f"Rate limit exceeded for IP {client_ip}: {recent_requests} requests in {self.window_size} seconds")

                if timestamps:
                    # The oldest timestamp decides when a slot frees up.
                    retry_after = self.window_size - (int(time.time()) - timestamps[0])
                else:
                    # Only reachable when the effective limit is <= 0; there
                    # is no oldest request to measure from.
                    retry_after = self.window_size
                retry_after = max(1, retry_after)  # Ensure at least 1 second

                # Log request details for tracking abuse patterns.
                self._log_rate_limit_exceeded(client_ip, req)

                # Reject the request.
                raise falcon.HTTPTooManyRequests(
                    title="Rate limit exceeded",
                    description=f"You have exceeded the rate limit of {max_requests} requests per {self.window_size} seconds",
                    headers={'Retry-After': str(retry_after)}
                )

            # Record the current request.
            self.requests[client_ip].append(int(time.time()))

    def _get_client_ip(self, req):
        """
        Get the client IP address from the request.

        The first entry of ``X-Forwarded-For`` is preferred (proxy setups);
        otherwise ``req.remote_addr`` is used, and ``'0.0.0.0'`` when that
        is empty too.

        Args:
            req: The request object.

        Returns:
            str: The client IP address.
        """
        # SECURITY NOTE(review): X-Forwarded-For is client-controlled unless
        # a trusted reverse proxy sets or overwrites it.
        forwarded_for = req.get_header('X-Forwarded-For')
        if forwarded_for:
            # The client IP is the first address in the list.
            first_hop = forwarded_for.split(',')[0].strip()
            # A blank first element (e.g. " , 10.0.0.1") would otherwise put
            # every such client into one shared '' bucket.
            if first_hop:
                return first_hop

        # Fall back to the remote_addr.
        return req.remote_addr or '0.0.0.0'

    def _clean_old_requests(self, client_ip):
        """
        Remove request timestamps that are outside the current window.

        Callers are expected to hold ``self.lock``.

        Args:
            client_ip: The client IP address.
        """
        if client_ip not in self.requests:
            return

        cutoff_time = int(time.time()) - self.window_size

        # Keep only requests within the current window.
        recent = [t for t in self.requests[client_ip] if t > cutoff_time]

        if recent:
            self.requests[client_ip] = recent
        else:
            # Drop the key entirely so idle IPs do not accumulate.
            del self.requests[client_ip]

    def _get_max_requests(self, req):
        """
        Determine the maximum requests allowed for the current endpoint.

        Among the rules whose method equals the request method and whose
        prefix matches the start of the request path, the rule with the
        longest prefix wins; ties go to the rule listed first. This keeps a
        specific rule such as ``/event/search`` from being shadowed by a
        broader one such as ``/event``.

        Args:
            req: The request object.

        Returns:
            int: The maximum number of requests allowed.
        """
        limit = self.max_requests  # Global default when no rule matches
        best_prefix_len = -1

        for endpoint, method, max_requests in self.rate_limit_rules:
            if req.method != method or not req.path.startswith(endpoint):
                continue
            if len(endpoint) > best_prefix_len:
                best_prefix_len = len(endpoint)
                limit = max_requests

        return limit

    def _log_rate_limit_exceeded(self, client_ip, req):
        """
        Log details when a rate limit is exceeded for analysis.

        Args:
            client_ip: The client IP address.
            req: The request object.
        """
        logger.warning(
            f"Rate limit exceeded: IP={client_ip}, "
            f"Method={req.method}, Path={req.path}, "
            f"User-Agent={req.get_header('User-Agent', 'Unknown')}"
        )