|
|
|
|
|
""" |
|
|
Simple Rate Limiter for API Endpoints |
|
|
""" |
|
|
|
|
|
import logging
import time
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
|
|
|
|
|
# Module-level logger; used by SimpleRateLimiter.reset_client to record resets.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class SimpleRateLimiter:
    """
    Simple in-memory sliding-window rate limiter.

    Each client gets an independent request history per endpoint type, so
    traffic to one endpoint type never consumes the quota of another.
    Endpoint types that are not configured in ``limits`` share the
    ``"default"`` bucket and limit.

    State lives only in this object's memory (not shared between processes,
    lost on restart). No locking is performed.
    NOTE(review): if ``is_allowed`` can be called from several threads at
    once, concurrent updates to the same client may race — confirm the
    deployment model.
    """

    # Requests allowed per window for each endpoint type.
    DEFAULT_LIMITS: Dict[str, int] = {
        "default": 60,
        "sentiment": 30,
        "model_loading": 5,
        "dataset_loading": 5,
        "external_api": 100,
    }

    def __init__(
        self,
        limits: Optional[Dict[str, int]] = None,
        window: float = 60,
    ):
        """
        Create a rate limiter.

        Args:
            limits: Optional per-endpoint-type overrides, merged over
                ``DEFAULT_LIMITS`` (so a ``"default"`` limit always exists).
            window: Sliding-window length in seconds; must be positive.

        Raises:
            ValueError: If ``window`` is not positive.
        """
        if window <= 0:
            raise ValueError("window must be positive")

        # client_id -> endpoint bucket -> time.time() stamps of requests
        # accepted within the current window. Empty entries are removed.
        self.requests: Dict[str, Dict[str, List[float]]] = defaultdict(dict)

        # Copy so the class-level defaults are never mutated per instance.
        self.limits: Dict[str, int] = {**self.DEFAULT_LIMITS, **(limits or {})}

        self.window = window

        # When expired entries were last swept from every client.
        self._last_sweep = time.time()

    def _recent(self, timestamps, current_time: float) -> List[float]:
        """Return the timestamps that still fall inside the sliding window."""
        return [ts for ts in timestamps if current_time - ts < self.window]

    def _maybe_sweep(self, current_time: float) -> None:
        """Drop expired history for all clients, at most once per window.

        Without this, clients that never return would stay in
        ``self.requests`` forever and memory would grow without bound.
        """
        if current_time - self._last_sweep < self.window:
            return
        self._last_sweep = current_time

        # Iterate over snapshots because entries are deleted while looping.
        for client_id in list(self.requests):
            buckets = self.requests[client_id]
            for bucket in list(buckets):
                recent = self._recent(buckets[bucket], current_time)
                if recent:
                    buckets[bucket] = recent
                else:
                    del buckets[bucket]
            if not buckets:
                del self.requests[client_id]

    def is_allowed(
        self,
        client_id: str,
        endpoint_type: str = "default"
    ) -> Tuple[bool, Dict]:
        """
        Check whether a request is allowed and record it if so.

        Args:
            client_id: Client identifier (IP, API key, etc.)
            endpoint_type: Type of endpoint (default, sentiment, model_loading,
                etc.). Unknown types use the "default" limit and bucket.

        Returns:
            Tuple of (is_allowed, info_dict). ``info_dict`` contains
            ``allowed``, ``requests_remaining``, ``limit``,
            ``window_seconds`` and ``reset_at`` (a ``time.time()`` value).
            When denied it also contains ``retry_after`` (seconds).
        """
        current_time = time.time()
        self._maybe_sweep(current_time)

        # Unknown endpoint types share the "default" bucket, so arbitrary
        # type strings cannot be used to obtain fresh quotas.
        bucket = endpoint_type if endpoint_type in self.limits else "default"
        limit = self.limits[bucket]

        client_buckets = self.requests[client_id]
        recent = self._recent(client_buckets.get(bucket, ()), current_time)
        request_count = len(recent)

        allowed = request_count < limit
        if allowed:
            recent.append(current_time)

        # Store the pruned history; drop empty entries so they don't linger.
        if recent:
            client_buckets[bucket] = recent
        else:
            client_buckets.pop(bucket, None)
            if not client_buckets:
                self.requests.pop(client_id, None)

        if allowed:
            return True, {
                "allowed": True,
                "requests_remaining": limit - request_count - 1,
                "limit": limit,
                "window_seconds": self.window,
                "reset_at": current_time + self.window,
            }

        # A slot frees up when the oldest request leaves the window. With a
        # limit <= 0 there is no history, so report a full window instead
        # of crashing on min([]).
        if recent:
            reset_at = min(recent) + self.window
        else:
            reset_at = current_time + self.window

        return False, {
            "allowed": False,
            "requests_remaining": 0,
            "limit": limit,
            "window_seconds": self.window,
            "reset_at": reset_at,
            "retry_after": reset_at - current_time,
        }

    def reset_client(self, client_id: str):
        """Reset rate limits for a specific client across all endpoint types."""
        if client_id in self.requests:
            del self.requests[client_id]
            logger.info("Rate limit reset for client: %s", client_id)

    def get_stats(self) -> Dict:
        """
        Get rate limiter statistics.

        Returns:
            Dict with ``active_clients`` (clients with at least one request
            in the window), ``total_recent_requests`` (summed over all
            clients and endpoint types), ``window_seconds``, and a copy of
            ``limits``.
        """
        current_time = time.time()

        active_clients = 0
        total_requests = 0

        for buckets in self.requests.values():
            recent_count = sum(
                len(self._recent(timestamps, current_time))
                for timestamps in buckets.values()
            )
            if recent_count:
                active_clients += 1
                total_requests += recent_count

        return {
            "active_clients": active_clients,
            "total_recent_requests": total_requests,
            "window_seconds": self.window,
            # Copy so callers cannot mutate the limiter's configuration.
            "limits": dict(self.limits),
        }
|
|
|
|
|
|
|
|
|
|
|
# Module-level shared instance with the default limits and window. Every
# importer of `rate_limiter` in this process uses the same in-memory history.
rate_limiter = SimpleRateLimiter()


# Names exported by `from <module> import *`.
__all__ = ["SimpleRateLimiter", "rate_limiter"]
|
|
|