|
|
|
|
|
"""
|
|
|
Simple Rate Limiter for API Endpoints
|
|
|
"""
|
|
|
|
|
|
import logging
import time
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
|
|
|
|
|
|
# Module-level logger; SimpleRateLimiter.reset_client reports resets through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
class SimpleRateLimiter:
    """
    Simple in-memory sliding-window rate limiter.

    Every (client, endpoint type) pair has its own bucket of request
    timestamps. A request is allowed when fewer than ``limit`` requests from
    that bucket were allowed within the trailing ``window`` seconds, so
    traffic to one endpoint type never consumes another type's quota.

    NOTE(review): state is not guarded by a lock. If one instance is shared
    across threads, concurrent ``is_allowed`` calls may briefly exceed a
    limit; wrap calls in a ``threading.Lock`` if that matters.
    """

    # Built-in per-endpoint limits (requests per window). The "default"
    # entry is also used for any endpoint type not listed here.
    DEFAULT_LIMITS = {
        "default": 60,
        "sentiment": 30,
        "model_loading": 5,
        "dataset_loading": 5,
        "external_api": 100,
    }
    DEFAULT_WINDOW = 60

    def __init__(
        self,
        limits: Optional[Dict[str, int]] = None,
        window: float = DEFAULT_WINDOW,
    ):
        """
        Create a limiter.

        Args:
            limits: Optional per-endpoint limit overrides, merged on top of
                DEFAULT_LIMITS (so a "default" entry is always present).
            window: Sliding-window length in seconds; must be positive.

        Raises:
            ValueError: If ``window`` is not positive.
        """
        if window <= 0:
            raise ValueError("window must be positive")

        # client_id -> endpoint bucket key -> time.time() stamps of allowed
        # requests, in the order they were recorded.
        self.requests: Dict[str, Dict[str, List[float]]] = defaultdict(
            lambda: defaultdict(list)
        )
        self.limits: Dict[str, int] = {**self.DEFAULT_LIMITS, **(limits or {})}
        self.window = window
        # When stale buckets were last swept (see _purge_expired).
        self._last_sweep = time.time()

    def is_allowed(
        self,
        client_id: str,
        endpoint_type: str = "default"
    ) -> Tuple[bool, Dict]:
        """
        Check whether a request is allowed and, if so, record it.

        Args:
            client_id: Client identifier (IP, API key, etc.)
            endpoint_type: Type of endpoint (default, sentiment,
                model_loading, ...). Unknown types share the "default"
                bucket and limit.

        Returns:
            Tuple of (is_allowed, info_dict). info_dict always has
            "allowed", "requests_remaining", "limit", "window_seconds" and
            "reset_at" (a time.time() timestamp); denied responses also
            carry "retry_after" in seconds.
        """
        current_time = time.time()

        # Unknown endpoint types fall back to the "default" bucket so that
        # arbitrary strings cannot create an unbounded number of buckets.
        bucket_key = endpoint_type if endpoint_type in self.limits else "default"
        limit = self.limits[bucket_key]

        # Periodically forget clients that have gone quiet so memory does not
        # grow without bound. Done before reading this client's buckets so
        # we never hold a reference to a dict the sweep just removed.
        if current_time - self._last_sweep >= self.window:
            self._purge_expired(current_time)

        client_buckets = self.requests[client_id]
        bucket = [
            ts for ts in client_buckets[bucket_key]
            if current_time - ts < self.window
        ]
        client_buckets[bucket_key] = bucket
        request_count = len(bucket)

        if request_count < limit:
            bucket.append(current_time)
            return True, {
                "allowed": True,
                "requests_remaining": limit - request_count - 1,
                "limit": limit,
                "window_seconds": self.window,
                "reset_at": current_time + self.window
            }

        # Denied: a slot frees up when the oldest in-window request expires.
        # The bucket can only be empty here if the limit is <= 0.
        oldest_request = min(bucket) if bucket else current_time
        reset_at = oldest_request + self.window
        return False, {
            "allowed": False,
            "requests_remaining": 0,
            "limit": limit,
            "window_seconds": self.window,
            "reset_at": reset_at,
            "retry_after": reset_at - current_time
        }

    def _purge_expired(self, current_time: float) -> None:
        """Drop expired timestamps, empty buckets, and clients left with none."""
        for client_id in list(self.requests):
            buckets = self.requests[client_id]
            for key in list(buckets):
                fresh = [ts for ts in buckets[key] if current_time - ts < self.window]
                if fresh:
                    buckets[key] = fresh
                else:
                    del buckets[key]
            if not buckets:
                del self.requests[client_id]
        self._last_sweep = current_time

    def reset_client(self, client_id: str):
        """Reset rate limits for a specific client across all endpoint types."""
        self.requests.pop(client_id, None)
        logger.info("Rate limit reset for client: %s", client_id)

    def get_stats(self) -> Dict:
        """
        Get rate limiter statistics.

        Returns:
            Dict with "active_clients" (clients with at least one in-window
            request), "total_recent_requests" (in-window requests summed over
            all clients and endpoint types), "window_seconds", and a copy of
            "limits".
        """
        current_time = time.time()
        active_clients = 0
        total_requests = 0

        for buckets in self.requests.values():
            recent = sum(
                1
                for timestamps in buckets.values()
                for ts in timestamps
                if current_time - ts < self.window
            )
            if recent:
                active_clients += 1
                total_requests += recent

        return {
            "active_clients": active_clients,
            "total_recent_requests": total_requests,
            "window_seconds": self.window,
            # Copy so callers cannot mutate the limiter's configuration.
            "limits": dict(self.limits)
        }
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level shared limiter instance (default limits), exported via __all__.
rate_limiter = SimpleRateLimiter()

# Public API of this module.
__all__ = ["SimpleRateLimiter", "rate_limiter"]
|
|
|
|