167 lines
		
	
	
	
		
			5.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			167 lines
		
	
	
	
		
			5.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
"""
Rate limiting middleware for the OpenEventDatabase.
"""
|  | 
 | ||
|  | import time | ||
|  | import threading | ||
|  | import falcon | ||
|  | from collections import defaultdict | ||
|  | from oedb.utils.logging import logger | ||
|  | 
 | ||
class RateLimitMiddleware:
    """
    Middleware that implements rate limiting to prevent API abuse.

    Request timestamps are tracked per client IP in a sliding window of
    ``window_size`` seconds. A request is rejected with HTTP 429 when the
    number of requests recorded for that IP inside the window has reached
    the limit that applies to the requested endpoint.

    Note:
        All requests of one IP share a single counter; the endpoint-specific
        rules only change the threshold that counter is compared against.
    """

    # Socket peers that are exempt from rate limiting (for development).
    _LOCAL_ADDRESSES = frozenset(('127.0.0.1', 'localhost', '::1'))

    def __init__(self, window_size=60, max_requests=60):
        """
        Initialize the middleware with rate limiting settings.

        Args:
            window_size: Time window in seconds for rate limiting.
            max_requests: Maximum number of requests allowed per IP in the
                window when no endpoint-specific rule matches.
        """
        self.window_size = window_size
        self.max_requests = max_requests

        # Request timestamps (integer epoch seconds) keyed by client IP
        self.requests = defaultdict(list)

        # Guards self.requests and self._last_sweep
        self.lock = threading.Lock()

        # Time of the last full sweep of stale IP entries
        self._last_sweep = int(time.time())

        # Rate limit rules for specific endpoints
        # Format: (endpoint_prefix, method, max_requests)
        # When several prefixes match a path, the longest one applies.
        self.rate_limit_rules = [
            # Limit POST requests to /event to 10 per minute
            ('/event', 'POST', 10),
            # Limit POST requests to /event/search to 20 per minute
            ('/event/search', 'POST', 20),
            # Limit DELETE requests to /event to 5 per minute
            ('/event', 'DELETE', 5),
        ]

        logger.info(f"Rate limiting initialized: {max_requests} requests per {window_size} seconds")

    def process_request(self, req, resp):
        """
        Process the request and apply rate limiting.

        Args:
            req: The request object.
            resp: The response object (unused).

        Raises:
            falcon.HTTPTooManyRequests: If the rate limit is exceeded. The
                response carries a ``Retry-After`` header of at least 1 second.
        """
        client_ip = self._get_client_ip(req)

        # Skip rate limiting for local requests (for development). The
        # exemption only applies when the address is the socket peer: an
        # address taken from X-Forwarded-For is client-supplied, so honouring
        # it here would let anyone bypass the limiter with a forged header.
        if client_ip in self._LOCAL_ADDRESSES and not req.get_header('X-Forwarded-For'):
            return

        # Get the appropriate rate limit for this endpoint
        max_requests = self._get_max_requests(req)

        with self.lock:
            now = int(time.time())

            # Drop entries of clients that stopped sending requests, so the
            # table does not grow without bound.
            self._sweep_stale_clients(now)

            # Discard this client's timestamps that left the window
            self._clean_old_requests(client_ip)

            # Fetch the list only after cleaning, which rebinds or deletes it
            timestamps = self.requests[client_ip]
            recent_requests = len(timestamps)

            if recent_requests >= max_requests:
                logger.warning(f"Rate limit exceeded for IP {client_ip}: {recent_requests} requests in {self.window_size} seconds")

                # The client may retry once its oldest request leaves the
                # window. The list is empty only when the limit is <= 0.
                oldest = timestamps[0] if timestamps else now
                retry_after = max(1, int(self.window_size - (now - oldest)))

                # Log request details for tracking abuse patterns
                self._log_rate_limit_exceeded(client_ip, req)

                # Reject the request
                raise falcon.HTTPTooManyRequests(
                    title="Rate limit exceeded",
                    description=f"You have exceeded the rate limit of {max_requests} requests per {self.window_size} seconds",
                    headers={'Retry-After': str(retry_after)}
                )

            # Record the accepted request
            timestamps.append(now)

    def _get_client_ip(self, req):
        """
        Get the client IP address from the request.

        The first entry of the ``X-Forwarded-For`` header is preferred; the
        header value is not validated here, so it is only as trustworthy as
        whatever sits in front of the application.

        Args:
            req: The request object.

        Returns:
            str: The client IP address, or '0.0.0.0' if none is available.
        """
        # Try to get the real IP from X-Forwarded-For header (if behind a proxy)
        forwarded_for = req.get_header('X-Forwarded-For')
        if forwarded_for:
            # The client IP is the first address in the list
            return forwarded_for.split(',')[0].strip()

        # Fall back to the remote_addr
        return req.remote_addr or '0.0.0.0'

    def _clean_old_requests(self, client_ip):
        """
        Remove request timestamps that are outside the current window.

        The caller is expected to hold ``self.lock``.

        Args:
            client_ip: The client IP address.
        """
        if client_ip not in self.requests:
            return

        cutoff_time = int(time.time()) - self.window_size

        # Keep only requests within the current window
        recent = [t for t in self.requests[client_ip] if t > cutoff_time]

        if recent:
            self.requests[client_ip] = recent
        else:
            # Remove the IP from the dictionary if there are no recent requests
            del self.requests[client_ip]

    def _sweep_stale_clients(self, now):
        """
        Remove every IP whose newest request is outside the current window.

        Runs at most once per ``window_size`` seconds. The caller is expected
        to hold ``self.lock``.

        Args:
            now: Current time in integer epoch seconds.
        """
        if now - self._last_sweep < self.window_size:
            return
        self._last_sweep = now

        cutoff_time = now - self.window_size
        # Timestamps are appended in arrival order, so the last one is the
        # newest of each list.
        stale = [
            ip for ip, stamps in self.requests.items()
            if not stamps or stamps[-1] <= cutoff_time
        ]
        for ip in stale:
            del self.requests[ip]

    def _get_max_requests(self, req):
        """
        Determine the maximum requests allowed for the current endpoint.

        Among the rules whose method equals the request method and whose
        prefix starts the request path, the one with the longest prefix
        applies, so '/event/search' takes precedence over '/event'.

        Args:
            req: The request object.

        Returns:
            int: The maximum number of requests allowed; the global
            ``max_requests`` when no rule matches.
        """
        best_length = -1
        best_limit = self.max_requests

        for endpoint, method, max_requests in self.rate_limit_rules:
            if req.method != method or not req.path.startswith(endpoint):
                continue
            # Strict comparison: on equal length the earlier rule is kept
            if len(endpoint) > best_length:
                best_length = len(endpoint)
                best_limit = max_requests

        return best_limit

    def _log_rate_limit_exceeded(self, client_ip, req):
        """
        Log details when a rate limit is exceeded for analysis.

        Args:
            client_ip: The client IP address.
            req: The request object.
        """
        # 'default' must be passed by keyword: the second positional
        # parameter of falcon's Request.get_header is 'required'.
        user_agent = req.get_header('User-Agent', default='Unknown')
        logger.warning(
            f"Rate limit exceeded: IP={client_ip}, "
            f"Method={req.method}, Path={req.path}, "
            f"User-Agent={user_agent}"
        )