WIP
This commit is contained in:
154
tests/web/test_rate_limiter.py
Normal file
154
tests/web/test_rate_limiter.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""
|
||||
Tests for the RateLimiter class.
|
||||
"""
|
||||
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from src.data.async_request_db import ApiKeyConfig
|
||||
from src.web.rate_limiter import RateLimiter, RateLimitConfig, RateLimitStatus
|
||||
|
||||
|
||||
class TestRateLimiter:
|
||||
"""Tests for RateLimiter."""
|
||||
|
||||
def test_check_submit_limit_allowed(self, rate_limiter, mock_db):
|
||||
"""Test that requests are allowed under the limit."""
|
||||
status = rate_limiter.check_submit_limit("test-api-key")
|
||||
|
||||
assert status.allowed is True
|
||||
assert status.remaining_requests >= 0
|
||||
assert status.retry_after_seconds is None
|
||||
|
||||
def test_check_submit_limit_rate_exceeded(self, rate_limiter, mock_db):
|
||||
"""Test rate limit exceeded when too many requests."""
|
||||
# Record 10 requests (the default limit)
|
||||
for _ in range(10):
|
||||
rate_limiter.record_request("test-api-key")
|
||||
|
||||
status = rate_limiter.check_submit_limit("test-api-key")
|
||||
|
||||
assert status.allowed is False
|
||||
assert status.remaining_requests == 0
|
||||
assert status.retry_after_seconds is not None
|
||||
assert status.retry_after_seconds > 0
|
||||
assert "rate limit" in status.reason.lower()
|
||||
|
||||
def test_check_submit_limit_concurrent_jobs_exceeded(self, rate_limiter, mock_db):
|
||||
"""Test rejection when max concurrent jobs reached."""
|
||||
# Mock active jobs at the limit
|
||||
mock_db.count_active_jobs.return_value = 3 # Max is 3
|
||||
|
||||
status = rate_limiter.check_submit_limit("test-api-key")
|
||||
|
||||
assert status.allowed is False
|
||||
assert "concurrent" in status.reason.lower()
|
||||
|
||||
def test_record_request(self, rate_limiter, mock_db):
|
||||
"""Test that recording a request works."""
|
||||
rate_limiter.record_request("test-api-key")
|
||||
|
||||
# Should have called the database
|
||||
mock_db.record_rate_limit_event.assert_called_once_with("test-api-key", "request")
|
||||
|
||||
def test_check_poll_limit_allowed(self, rate_limiter, mock_db):
|
||||
"""Test that polling is allowed initially."""
|
||||
status = rate_limiter.check_poll_limit("test-api-key", "request-123")
|
||||
|
||||
assert status.allowed is True
|
||||
|
||||
def test_check_poll_limit_too_frequent(self, rate_limiter, mock_db):
|
||||
"""Test that rapid polling is rejected."""
|
||||
# First poll should succeed
|
||||
status1 = rate_limiter.check_poll_limit("test-api-key", "request-123")
|
||||
assert status1.allowed is True
|
||||
|
||||
# Immediate second poll should fail
|
||||
status2 = rate_limiter.check_poll_limit("test-api-key", "request-123")
|
||||
assert status2.allowed is False
|
||||
assert "polling" in status2.reason.lower()
|
||||
assert status2.retry_after_seconds is not None
|
||||
|
||||
def test_check_poll_limit_different_requests(self, rate_limiter, mock_db):
|
||||
"""Test that different request_ids have separate poll limits."""
|
||||
# Poll request 1
|
||||
status1 = rate_limiter.check_poll_limit("test-api-key", "request-1")
|
||||
assert status1.allowed is True
|
||||
|
||||
# Poll request 2 should also be allowed
|
||||
status2 = rate_limiter.check_poll_limit("test-api-key", "request-2")
|
||||
assert status2.allowed is True
|
||||
|
||||
def test_sliding_window_expires(self, rate_limiter, mock_db):
|
||||
"""Test that requests expire from the sliding window."""
|
||||
# Record requests
|
||||
for _ in range(5):
|
||||
rate_limiter.record_request("test-api-key")
|
||||
|
||||
# Check status - should have 5 remaining
|
||||
status1 = rate_limiter.check_submit_limit("test-api-key")
|
||||
assert status1.allowed is True
|
||||
assert status1.remaining_requests == 4 # 10 - 5 - 1 (for this check)
|
||||
|
||||
def test_get_rate_limit_headers(self, rate_limiter):
|
||||
"""Test rate limit header generation."""
|
||||
status = RateLimitStatus(
|
||||
allowed=False,
|
||||
remaining_requests=0,
|
||||
reset_at=datetime.utcnow() + timedelta(seconds=30),
|
||||
retry_after_seconds=30,
|
||||
)
|
||||
|
||||
headers = rate_limiter.get_rate_limit_headers(status)
|
||||
|
||||
assert "X-RateLimit-Remaining" in headers
|
||||
assert headers["X-RateLimit-Remaining"] == "0"
|
||||
assert "Retry-After" in headers
|
||||
assert headers["Retry-After"] == "30"
|
||||
|
||||
def test_cleanup_poll_timestamps(self, rate_limiter, mock_db):
|
||||
"""Test cleanup of old poll timestamps."""
|
||||
# Add some poll timestamps
|
||||
rate_limiter.check_poll_limit("test-api-key", "old-request")
|
||||
|
||||
# Manually age the timestamp
|
||||
rate_limiter._poll_timestamps[("test-api-key", "old-request")] = time.time() - 7200
|
||||
|
||||
# Run cleanup with 1 hour max age
|
||||
cleaned = rate_limiter.cleanup_poll_timestamps(max_age_seconds=3600)
|
||||
|
||||
assert cleaned == 1
|
||||
assert ("test-api-key", "old-request") not in rate_limiter._poll_timestamps
|
||||
|
||||
def test_cleanup_request_windows(self, rate_limiter, mock_db):
|
||||
"""Test cleanup of empty request windows."""
|
||||
# Add some old requests
|
||||
rate_limiter._request_windows["old-key"] = [time.time() - 120]
|
||||
|
||||
# Run cleanup
|
||||
rate_limiter.cleanup_request_windows()
|
||||
|
||||
# Old entries should be removed
|
||||
assert "old-key" not in rate_limiter._request_windows
|
||||
|
||||
def test_config_caching(self, rate_limiter, mock_db):
|
||||
"""Test that API key configs are cached."""
|
||||
# First call should query database
|
||||
rate_limiter._get_config("test-api-key")
|
||||
assert mock_db.get_api_key_config.call_count == 1
|
||||
|
||||
# Second call should use cache
|
||||
rate_limiter._get_config("test-api-key")
|
||||
assert mock_db.get_api_key_config.call_count == 1 # Still 1
|
||||
|
||||
def test_default_config_for_unknown_key(self, rate_limiter, mock_db):
|
||||
"""Test that unknown API keys get default config."""
|
||||
mock_db.get_api_key_config.return_value = None
|
||||
|
||||
config = rate_limiter._get_config("unknown-key")
|
||||
|
||||
assert config.requests_per_minute == 10 # Default
|
||||
assert config.max_concurrent_jobs == 3 # Default
|
||||
Reference in New Issue
Block a user