"""
Tests for the RateLimiter class.
"""

import time
from datetime import datetime, timedelta
from unittest.mock import MagicMock

import pytest

from inference.data.async_request_db import ApiKeyConfig
from inference.web.rate_limiter import RateLimiter, RateLimitConfig, RateLimitStatus


class TestRateLimiter:
    """Tests for RateLimiter.

    The ``rate_limiter`` and ``mock_db`` fixtures are not defined in this
    module (presumably provided by a conftest.py; verify there). Several
    tests inspect or manipulate private state (``_poll_timestamps``,
    ``_request_windows``) to simulate the passage of time without sleeping.
    """

    def test_check_submit_limit_allowed(self, rate_limiter, mock_db):
        """Test that requests are allowed under the limit."""
        status = rate_limiter.check_submit_limit("test-api-key")

        assert status.allowed is True
        assert status.remaining_requests >= 0
        assert status.retry_after_seconds is None

    def test_check_submit_limit_rate_exceeded(self, rate_limiter, mock_db):
        """Test rate limit exceeded when too many requests."""
        # Record 10 requests (the default limit)
        for _ in range(10):
            rate_limiter.record_request("test-api-key")

        status = rate_limiter.check_submit_limit("test-api-key")

        assert status.allowed is False
        assert status.remaining_requests == 0
        assert status.retry_after_seconds is not None
        assert status.retry_after_seconds > 0
        assert "rate limit" in status.reason.lower()

    def test_check_submit_limit_concurrent_jobs_exceeded(self, rate_limiter, mock_db):
        """Test rejection when max concurrent jobs reached."""
        # Mock active jobs at the limit
        mock_db.count_active_jobs.return_value = 3  # Max is 3

        status = rate_limiter.check_submit_limit("test-api-key")

        assert status.allowed is False
        assert "concurrent" in status.reason.lower()

    def test_record_request(self, rate_limiter, mock_db):
        """Test that recording a request works."""
        rate_limiter.record_request("test-api-key")

        # Should have called the database
        mock_db.record_rate_limit_event.assert_called_once_with("test-api-key", "request")

    def test_check_poll_limit_allowed(self, rate_limiter, mock_db):
        """Test that polling is allowed initially."""
        status = rate_limiter.check_poll_limit("test-api-key", "request-123")

        assert status.allowed is True

    def test_check_poll_limit_too_frequent(self, rate_limiter, mock_db):
        """Test that rapid polling is rejected."""
        # First poll should succeed
        status1 = rate_limiter.check_poll_limit("test-api-key", "request-123")
        assert status1.allowed is True

        # Immediate second poll should fail
        status2 = rate_limiter.check_poll_limit("test-api-key", "request-123")
        assert status2.allowed is False
        assert "polling" in status2.reason.lower()
        assert status2.retry_after_seconds is not None

    def test_check_poll_limit_different_requests(self, rate_limiter, mock_db):
        """Test that different request_ids have separate poll limits."""
        # Poll request 1
        status1 = rate_limiter.check_poll_limit("test-api-key", "request-1")
        assert status1.allowed is True

        # Poll request 2 should also be allowed
        status2 = rate_limiter.check_poll_limit("test-api-key", "request-2")
        assert status2.allowed is True

    def test_sliding_window_expires(self, rate_limiter, mock_db):
        """Test that requests expire from the sliding window.

        A partially filled window still admits requests. A full window
        rejects them, but once its entries are older than the window the
        key must be admitted again.
        """
        # Record requests
        for _ in range(5):
            rate_limiter.record_request("test-api-key")

        # Check status - 5 of 10 slots used; remaining also excludes this check
        status1 = rate_limiter.check_submit_limit("test-api-key")
        assert status1.allowed is True
        assert status1.remaining_requests == 4  # 10 - 5 - 1 (for this check)

        # Fill the window up to the limit so the next submit is rejected
        for _ in range(5):
            rate_limiter.record_request("test-api-key")
        assert rate_limiter.check_submit_limit("test-api-key").allowed is False

        # Back-date every entry well past the one-minute window (the same
        # list-of-timestamps layout test_cleanup_request_windows relies on).
        # Mutate in place so this works for either a list or a deque.
        window = rate_limiter._request_windows["test-api-key"]
        entry_count = len(window)
        aged = time.time() - 120
        window.clear()
        window.extend([aged] * entry_count)

        # Expired entries must no longer count against the limit
        status2 = rate_limiter.check_submit_limit("test-api-key")
        assert status2.allowed is True

    def test_get_rate_limit_headers(self, rate_limiter):
        """Test rate limit header generation."""
        status = RateLimitStatus(
            allowed=False,
            remaining_requests=0,
            reset_at=datetime.utcnow() + timedelta(seconds=30),
            retry_after_seconds=30,
        )

        headers = rate_limiter.get_rate_limit_headers(status)

        assert "X-RateLimit-Remaining" in headers
        assert headers["X-RateLimit-Remaining"] == "0"
        assert "Retry-After" in headers
        assert headers["Retry-After"] == "30"

    def test_cleanup_poll_timestamps(self, rate_limiter, mock_db):
        """Test cleanup of old poll timestamps."""
        # Add some poll timestamps
        rate_limiter.check_poll_limit("test-api-key", "old-request")

        # Manually age the timestamp (two hours old)
        rate_limiter._poll_timestamps[("test-api-key", "old-request")] = time.time() - 7200

        # Run cleanup with 1 hour max age
        cleaned = rate_limiter.cleanup_poll_timestamps(max_age_seconds=3600)

        assert cleaned == 1
        assert ("test-api-key", "old-request") not in rate_limiter._poll_timestamps

    def test_cleanup_request_windows(self, rate_limiter, mock_db):
        """Test cleanup of empty request windows."""
        # Add some old requests
        rate_limiter._request_windows["old-key"] = [time.time() - 120]

        # Run cleanup
        rate_limiter.cleanup_request_windows()

        # Old entries should be removed
        assert "old-key" not in rate_limiter._request_windows

    def test_config_caching(self, rate_limiter, mock_db):
        """Test that API key configs are cached."""
        # First call should query database
        rate_limiter._get_config("test-api-key")
        assert mock_db.get_api_key_config.call_count == 1

        # Second call should use cache
        rate_limiter._get_config("test-api-key")
        assert mock_db.get_api_key_config.call_count == 1  # Still 1

    def test_default_config_for_unknown_key(self, rate_limiter, mock_db):
        """Test that unknown API keys get default config."""
        mock_db.get_api_key_config.return_value = None

        config = rate_limiter._get_config("unknown-key")

        assert config.requests_per_minute == 10  # Default
        assert config.max_concurrent_jobs == 3  # Default