""" Rate limiting middleware for the OpenEventDatabase. """ import time import threading import falcon from collections import defaultdict from oedb.utils.logging import logger class RateLimitMiddleware: """ Middleware that implements rate limiting to prevent API abuse. This middleware tracks request rates by IP address and rejects requests that exceed defined limits. It helps protect the API from abuse and ensures fair usage. """ def __init__(self, window_size=60, max_requests=60): """ Initialize the middleware with rate limiting settings. Args: window_size: Time window in seconds for rate limiting. max_requests: Maximum number of requests allowed per IP in the window. """ self.window_size = window_size self.max_requests = max_requests # Store request timestamps by IP self.requests = defaultdict(list) # Lock for thread safety self.lock = threading.Lock() # Define rate limit rules for different endpoints # Format: (endpoint_prefix, method, max_requests) self.rate_limit_rules = [ # Limit POST requests to /event to 10 per minute ('/event', 'POST', 10), # Limit POST requests to /event/search to 20 per minute ('/event/search', 'POST', 20), # Limit DELETE requests to /event to 5 per minute ('/event', 'DELETE', 5), ] logger.info(f"Rate limiting initialized: {max_requests} requests per {window_size} seconds") def process_request(self, req, resp): """ Process the request and apply rate limiting. Args: req: The request object. resp: The response object. Raises: falcon.HTTPTooManyRequests: If the rate limit is exceeded. """ # Get client IP address client_ip = self._get_client_ip(req) # Skip rate limiting for local requests (for development) if client_ip in ('127.0.0.1', 'localhost', '::1'): return # Get the appropriate rate limit for this endpoint max_requests = self._get_max_requests(req) # Check if the rate limit is exceeded with self.lock: # Clean up old requests self._clean_old_requests(client_ip) # Count recent requests recent_requests = len(self.requests[client_ip]) # Check if the rate limit is exceeded if recent_requests >= max_requests: logger.warning(f"Rate limit exceeded for IP {client_ip}: {recent_requests} requests in {self.window_size} seconds") retry_after = self.window_size - (int(time.time()) - self.requests[client_ip][0]) retry_after = max(1, retry_after) # Ensure retry_after is at least 1 second # Add the request to the log for tracking abuse patterns self._log_rate_limit_exceeded(client_ip, req) # Raise an exception to reject the request raise falcon.HTTPTooManyRequests( title="Rate limit exceeded", description=f"You have exceeded the rate limit of {max_requests} requests per {self.window_size} seconds", headers={'Retry-After': str(retry_after)} ) # Add the current request timestamp self.requests[client_ip].append(int(time.time())) def _get_client_ip(self, req): """ Get the client IP address from the request. Args: req: The request object. Returns: str: The client IP address. """ # Try to get the real IP from X-Forwarded-For header (if behind a proxy) forwarded_for = req.get_header('X-Forwarded-For') if forwarded_for: # The client IP is the first address in the list return forwarded_for.split(',')[0].strip() # Fall back to the remote_addr return req.remote_addr or '0.0.0.0' def _clean_old_requests(self, client_ip): """ Remove request timestamps that are outside the current window. Args: client_ip: The client IP address. """ if client_ip not in self.requests: return current_time = int(time.time()) cutoff_time = current_time - self.window_size # Keep only requests within the current window self.requests[client_ip] = [t for t in self.requests[client_ip] if t > cutoff_time] # Remove the IP from the dictionary if there are no recent requests if not self.requests[client_ip]: del self.requests[client_ip] def _get_max_requests(self, req): """ Determine the maximum requests allowed for the current endpoint. Args: req: The request object. Returns: int: The maximum number of requests allowed. """ # Check if the request matches any rate limit rules for endpoint, method, max_requests in self.rate_limit_rules: if req.path.startswith(endpoint) and req.method == method: return max_requests # Default to the global max_requests return self.max_requests def _log_rate_limit_exceeded(self, client_ip, req): """ Log details when a rate limit is exceeded for analysis. Args: client_ip: The client IP address. req: The request object. """ logger.warning( f"Rate limit exceeded: IP={client_ip}, " f"Method={req.method}, Path={req.path}, " f"User-Agent={req.get_header('User-Agent', 'Unknown')}" )