"""
Tests for the RateLimiter class.
"""
|
|
|
|
import time
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock

import pytest

from src.data.async_request_db import ApiKeyConfig
from src.web.rate_limiter import RateLimiter, RateLimitConfig, RateLimitStatus
|
|
|
|
|
|
class TestRateLimiter:
    """Tests for RateLimiter.

    Each test receives ``rate_limiter`` (the object under test) and, where
    listed, ``mock_db`` as pytest fixtures; both fixtures are defined outside
    this class. The expected numbers below follow the defaults these tests
    themselves assert: 10 requests per minute and 3 concurrent jobs.
    """

    def test_check_submit_limit_allowed(self, rate_limiter, mock_db):
        """Test that requests are allowed under the limit."""
        status = rate_limiter.check_submit_limit("test-api-key")

        assert status.allowed is True
        assert status.remaining_requests >= 0
        assert status.retry_after_seconds is None

    def test_check_submit_limit_rate_exceeded(self, rate_limiter, mock_db):
        """Test rate limit exceeded when too many requests."""
        # Record 10 requests (the default limit)
        for _ in range(10):
            rate_limiter.record_request("test-api-key")

        status = rate_limiter.check_submit_limit("test-api-key")

        assert status.allowed is False
        assert status.remaining_requests == 0
        assert status.retry_after_seconds is not None
        assert status.retry_after_seconds > 0
        # Fail with a clear assertion (not AttributeError) if no reason is set.
        assert status.reason is not None
        assert "rate limit" in status.reason.lower()

    def test_check_submit_limit_concurrent_jobs_exceeded(self, rate_limiter, mock_db):
        """Test rejection when max concurrent jobs reached."""
        # Mock active jobs at the limit
        mock_db.count_active_jobs.return_value = 3  # Max is 3

        status = rate_limiter.check_submit_limit("test-api-key")

        assert status.allowed is False
        # Fail with a clear assertion (not AttributeError) if no reason is set.
        assert status.reason is not None
        assert "concurrent" in status.reason.lower()

    def test_record_request(self, rate_limiter, mock_db):
        """Test that recording a request works."""
        rate_limiter.record_request("test-api-key")

        # Should have called the database
        mock_db.record_rate_limit_event.assert_called_once_with("test-api-key", "request")

    def test_check_poll_limit_allowed(self, rate_limiter, mock_db):
        """Test that polling is allowed initially."""
        status = rate_limiter.check_poll_limit("test-api-key", "request-123")

        assert status.allowed is True

    def test_check_poll_limit_too_frequent(self, rate_limiter, mock_db):
        """Test that rapid polling is rejected."""
        # First poll should succeed
        status1 = rate_limiter.check_poll_limit("test-api-key", "request-123")
        assert status1.allowed is True

        # Immediate second poll should fail
        status2 = rate_limiter.check_poll_limit("test-api-key", "request-123")
        assert status2.allowed is False
        # Fail with a clear assertion (not AttributeError) if no reason is set.
        assert status2.reason is not None
        assert "polling" in status2.reason.lower()
        assert status2.retry_after_seconds is not None

    def test_check_poll_limit_different_requests(self, rate_limiter, mock_db):
        """Test that different request_ids have separate poll limits."""
        # Poll request 1
        status1 = rate_limiter.check_poll_limit("test-api-key", "request-1")
        assert status1.allowed is True

        # Poll request 2 should also be allowed
        status2 = rate_limiter.check_poll_limit("test-api-key", "request-2")
        assert status2.allowed is True

    def test_sliding_window_expires(self, rate_limiter, mock_db):
        """Test that recorded requests are counted against the window.

        NOTE(review): despite its name, this test never ages the recorded
        requests, so expiry from the sliding window is not exercised here.
        """
        # Record requests
        for _ in range(5):
            rate_limiter.record_request("test-api-key")

        # Half of the default limit is used up; the check itself is also
        # counted, which is why 4 (not 5) requests are expected to remain.
        status1 = rate_limiter.check_submit_limit("test-api-key")
        assert status1.allowed is True
        assert status1.remaining_requests == 4  # 10 - 5 - 1 (for this check)

    def test_get_rate_limit_headers(self, rate_limiter):
        """Test rate limit header generation."""
        # Naive UTC timestamp: same value datetime.utcnow() produced, built
        # without the API that is deprecated since Python 3.12.
        now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
        status = RateLimitStatus(
            allowed=False,
            remaining_requests=0,
            reset_at=now_utc + timedelta(seconds=30),
            retry_after_seconds=30,
        )

        headers = rate_limiter.get_rate_limit_headers(status)

        assert "X-RateLimit-Remaining" in headers
        assert headers["X-RateLimit-Remaining"] == "0"
        assert "Retry-After" in headers
        assert headers["Retry-After"] == "30"

    def test_cleanup_poll_timestamps(self, rate_limiter, mock_db):
        """Test cleanup of old poll timestamps."""
        # Add some poll timestamps
        rate_limiter.check_poll_limit("test-api-key", "old-request")

        # Manually age the timestamp (2 hours old)
        rate_limiter._poll_timestamps[("test-api-key", "old-request")] = time.time() - 7200

        # Run cleanup with 1 hour max age
        cleaned = rate_limiter.cleanup_poll_timestamps(max_age_seconds=3600)

        assert cleaned == 1
        assert ("test-api-key", "old-request") not in rate_limiter._poll_timestamps

    def test_cleanup_request_windows(self, rate_limiter, mock_db):
        """Test cleanup of empty request windows."""
        # Add some old requests (a single entry, 120 seconds in the past)
        rate_limiter._request_windows["old-key"] = [time.time() - 120]

        # Run cleanup
        rate_limiter.cleanup_request_windows()

        # Old entries should be removed
        assert "old-key" not in rate_limiter._request_windows

    def test_config_caching(self, rate_limiter, mock_db):
        """Test that API key configs are cached."""
        # First call should query database
        rate_limiter._get_config("test-api-key")
        assert mock_db.get_api_key_config.call_count == 1

        # Second call should use cache
        rate_limiter._get_config("test-api-key")
        assert mock_db.get_api_key_config.call_count == 1  # Still 1

    def test_default_config_for_unknown_key(self, rate_limiter, mock_db):
        """Test that unknown API keys get default config."""
        mock_db.get_api_key_config.return_value = None

        config = rate_limiter._get_config("unknown-key")

        assert config.requests_per_minute == 10  # Default
        assert config.max_concurrent_jobs == 3  # Default