""" Tests for the async API routes. """ import tempfile from datetime import datetime, timedelta from pathlib import Path from unittest.mock import MagicMock, patch import pytest from fastapi import FastAPI from fastapi.testclient import TestClient from inference.data.async_request_db import ApiKeyConfig, AsyncRequest, AsyncRequestDB from inference.web.api.v1.public.async_api import create_async_router, set_async_service from inference.web.services.async_processing import AsyncSubmitResult from inference.web.dependencies import init_dependencies from inference.web.rate_limiter import RateLimiter, RateLimitStatus from inference.web.schemas.inference import AsyncStatus # Valid UUID for testing TEST_REQUEST_UUID = "550e8400-e29b-41d4-a716-446655440000" INVALID_UUID = "nonexistent-id" @pytest.fixture def mock_async_service(): """Create a mock AsyncProcessingService.""" service = MagicMock() # Mock config mock_config = MagicMock() mock_config.max_file_size_mb = 50 service._async_config = mock_config # Default submit result service.submit_request.return_value = AsyncSubmitResult( success=True, request_id="test-request-id", estimated_wait_seconds=30, ) return service @pytest.fixture def mock_rate_limiter(mock_db): """Create a mock RateLimiter.""" limiter = MagicMock(spec=RateLimiter) # Default: allow all requests limiter.check_submit_limit.return_value = RateLimitStatus( allowed=True, remaining_requests=9, reset_at=datetime.utcnow() + timedelta(seconds=60), ) limiter.check_poll_limit.return_value = RateLimitStatus( allowed=True, remaining_requests=999, reset_at=datetime.utcnow(), ) limiter.get_rate_limit_headers.return_value = {} return limiter @pytest.fixture def app(mock_db, mock_rate_limiter, mock_async_service): """Create a test FastAPI app with async routes.""" app = FastAPI() # Initialize dependencies init_dependencies(mock_db, mock_rate_limiter) set_async_service(mock_async_service) # Add routes router = create_async_router(allowed_extensions=(".pdf", ".png", ".jpg", ".jpeg")) app.include_router(router, prefix="/api/v1") return app @pytest.fixture def client(app): """Create a test client.""" return TestClient(app) class TestAsyncSubmitEndpoint: """Tests for POST /api/v1/async/submit.""" def test_submit_success(self, client, mock_async_service): """Test successful submission.""" with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f: f.write(b"fake pdf content") f.seek(0) response = client.post( "/api/v1/async/submit", files={"file": ("test.pdf", f, "application/pdf")}, headers={"X-API-Key": "test-api-key"}, ) assert response.status_code == 200 data = response.json() assert data["status"] == "accepted" assert data["request_id"] == "test-request-id" assert "poll_url" in data def test_submit_missing_api_key(self, client): """Test submission without API key.""" with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f: f.write(b"fake pdf content") f.seek(0) response = client.post( "/api/v1/async/submit", files={"file": ("test.pdf", f, "application/pdf")}, ) assert response.status_code == 401 assert "X-API-Key" in response.json()["detail"] def test_submit_invalid_api_key(self, client, mock_db): """Test submission with invalid API key.""" mock_db.is_valid_api_key.return_value = False with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f: f.write(b"fake pdf content") f.seek(0) response = client.post( "/api/v1/async/submit", files={"file": ("test.pdf", f, "application/pdf")}, headers={"X-API-Key": "invalid-key"}, ) assert response.status_code == 401 def test_submit_unsupported_file_type(self, client): """Test submission with unsupported file type.""" with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as f: f.write(b"text content") f.seek(0) response = client.post( "/api/v1/async/submit", files={"file": ("test.txt", f, "text/plain")}, headers={"X-API-Key": "test-api-key"}, ) assert response.status_code == 400 assert "Unsupported file type" in response.json()["detail"] def test_submit_rate_limited(self, client, mock_rate_limiter): """Test submission when rate limited.""" mock_rate_limiter.check_submit_limit.return_value = RateLimitStatus( allowed=False, remaining_requests=0, reset_at=datetime.utcnow() + timedelta(seconds=30), retry_after_seconds=30, reason="Rate limit exceeded", ) mock_rate_limiter.get_rate_limit_headers.return_value = {"Retry-After": "30"} with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f: f.write(b"fake pdf content") f.seek(0) response = client.post( "/api/v1/async/submit", files={"file": ("test.pdf", f, "application/pdf")}, headers={"X-API-Key": "test-api-key"}, ) assert response.status_code == 429 assert "Retry-After" in response.headers def test_submit_queue_full(self, client, mock_async_service): """Test submission when queue is full.""" mock_async_service.submit_request.return_value = AsyncSubmitResult( success=False, request_id="test-id", error="Processing queue is full", ) with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f: f.write(b"fake pdf content") f.seek(0) response = client.post( "/api/v1/async/submit", files={"file": ("test.pdf", f, "application/pdf")}, headers={"X-API-Key": "test-api-key"}, ) assert response.status_code == 503 class TestAsyncStatusEndpoint: """Tests for GET /api/v1/async/status/{request_id}.""" def test_get_status_pending(self, client, mock_db, sample_async_request): """Test getting status of pending request.""" mock_db.get_request_by_api_key.return_value = sample_async_request mock_db.get_queue_position.return_value = 3 response = client.get( "/api/v1/async/status/550e8400-e29b-41d4-a716-446655440000", headers={"X-API-Key": "test-api-key"}, ) assert response.status_code == 200 data = response.json() assert data["status"] == "pending" assert data["position_in_queue"] == 3 assert data["result_url"] is None def test_get_status_completed(self, client, mock_db, sample_async_request): """Test getting status of completed request.""" sample_async_request.status = "completed" sample_async_request.completed_at = datetime.utcnow() mock_db.get_request_by_api_key.return_value = sample_async_request response = client.get( "/api/v1/async/status/550e8400-e29b-41d4-a716-446655440000", headers={"X-API-Key": "test-api-key"}, ) assert response.status_code == 200 data = response.json() assert data["status"] == "completed" assert data["result_url"] is not None def test_get_status_not_found(self, client, mock_db): """Test getting status of non-existent request.""" mock_db.get_request_by_api_key.return_value = None response = client.get( "/api/v1/async/status/00000000-0000-0000-0000-000000000000", headers={"X-API-Key": "test-api-key"}, ) assert response.status_code == 404 def test_get_status_wrong_api_key(self, client, mock_db, sample_async_request): """Test that requests are isolated by API key.""" # Request belongs to different API key mock_db.get_request_by_api_key.return_value = None response = client.get( "/api/v1/async/status/550e8400-e29b-41d4-a716-446655440000", headers={"X-API-Key": "different-api-key"}, ) assert response.status_code == 404 class TestAsyncResultEndpoint: """Tests for GET /api/v1/async/result/{request_id}.""" def test_get_result_completed(self, client, mock_db, sample_async_request): """Test getting result of completed request.""" sample_async_request.status = "completed" sample_async_request.completed_at = datetime.utcnow() sample_async_request.processing_time_ms = 1234.5 sample_async_request.result = { "document_id": "test-doc", "success": True, "document_type": "invoice", "fields": {"InvoiceNumber": "12345"}, "confidence": {"InvoiceNumber": 0.95}, "detections": [], "errors": [], } mock_db.get_request_by_api_key.return_value = sample_async_request response = client.get( "/api/v1/async/result/550e8400-e29b-41d4-a716-446655440000", headers={"X-API-Key": "test-api-key"}, ) assert response.status_code == 200 data = response.json() assert data["status"] == "completed" assert data["result"] is not None assert data["result"]["fields"]["InvoiceNumber"] == "12345" def test_get_result_not_completed(self, client, mock_db, sample_async_request): """Test getting result of pending request.""" mock_db.get_request_by_api_key.return_value = sample_async_request response = client.get( "/api/v1/async/result/550e8400-e29b-41d4-a716-446655440000", headers={"X-API-Key": "test-api-key"}, ) assert response.status_code == 409 assert "not yet completed" in response.json()["detail"] def test_get_result_failed(self, client, mock_db, sample_async_request): """Test getting result of failed request.""" sample_async_request.status = "failed" sample_async_request.error_message = "Processing failed" sample_async_request.processing_time_ms = 500.0 mock_db.get_request_by_api_key.return_value = sample_async_request response = client.get( "/api/v1/async/result/550e8400-e29b-41d4-a716-446655440000", headers={"X-API-Key": "test-api-key"}, ) assert response.status_code == 200 data = response.json() assert data["status"] == "failed" class TestAsyncListEndpoint: """Tests for GET /api/v1/async/requests.""" def test_list_requests(self, client, mock_db, sample_async_request): """Test listing requests.""" mock_db.get_requests_by_api_key.return_value = ([sample_async_request], 1) response = client.get( "/api/v1/async/requests", headers={"X-API-Key": "test-api-key"}, ) assert response.status_code == 200 data = response.json() assert data["total"] == 1 assert len(data["requests"]) == 1 def test_list_requests_with_status_filter(self, client, mock_db): """Test listing requests with status filter.""" mock_db.get_requests_by_api_key.return_value = ([], 0) response = client.get( "/api/v1/async/requests?status=completed", headers={"X-API-Key": "test-api-key"}, ) assert response.status_code == 200 mock_db.get_requests_by_api_key.assert_called_once() call_kwargs = mock_db.get_requests_by_api_key.call_args[1] assert call_kwargs["status"] == "completed" def test_list_requests_pagination(self, client, mock_db): """Test listing requests with pagination.""" mock_db.get_requests_by_api_key.return_value = ([], 0) response = client.get( "/api/v1/async/requests?limit=50&offset=10", headers={"X-API-Key": "test-api-key"}, ) assert response.status_code == 200 call_kwargs = mock_db.get_requests_by_api_key.call_args[1] assert call_kwargs["limit"] == 50 assert call_kwargs["offset"] == 10 def test_list_requests_invalid_status(self, client, mock_db): """Test listing with invalid status filter.""" response = client.get( "/api/v1/async/requests?status=invalid", headers={"X-API-Key": "test-api-key"}, ) assert response.status_code == 400 class TestAsyncDeleteEndpoint: """Tests for DELETE /api/v1/async/requests/{request_id}.""" def test_delete_pending_request(self, client, mock_db, sample_async_request): """Test deleting a pending request.""" mock_db.get_request_by_api_key.return_value = sample_async_request response = client.delete( "/api/v1/async/requests/550e8400-e29b-41d4-a716-446655440000", headers={"X-API-Key": "test-api-key"}, ) assert response.status_code == 200 assert response.json()["status"] == "deleted" def test_delete_processing_request(self, client, mock_db, sample_async_request): """Test that processing requests cannot be deleted.""" sample_async_request.status = "processing" mock_db.get_request_by_api_key.return_value = sample_async_request response = client.delete( "/api/v1/async/requests/550e8400-e29b-41d4-a716-446655440000", headers={"X-API-Key": "test-api-key"}, ) assert response.status_code == 409 def test_delete_not_found(self, client, mock_db): """Test deleting non-existent request.""" mock_db.get_request_by_api_key.return_value = None response = client.delete( "/api/v1/async/requests/00000000-0000-0000-0000-000000000000", headers={"X-API-Key": "test-api-key"}, ) assert response.status_code == 404