"""
Tests for the async API routes (submit, status, result, list and delete).
"""
|
|
|
|
import tempfile
|
|
from datetime import datetime, timedelta
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
from fastapi import FastAPI
|
|
from fastapi.testclient import TestClient
|
|
|
|
from inference.data.async_request_db import ApiKeyConfig, AsyncRequest, AsyncRequestDB
|
|
from inference.web.api.v1.public.async_api import create_async_router, set_async_service
|
|
from inference.web.services.async_processing import AsyncSubmitResult
|
|
from inference.web.dependencies import init_dependencies
|
|
from inference.web.rate_limiter import RateLimiter, RateLimitStatus
|
|
from inference.web.schemas.inference import AsyncStatus
|
|
|
|
# Request identifiers for the tests in this module.
# A well-formed UUID string.
TEST_REQUEST_UUID = "550e8400-e29b-41d4-a716-446655440000"
# A string that is not a UUID at all.
INVALID_UUID = "nonexistent-id"
|
|
|
|
|
|
@pytest.fixture
def mock_async_service():
    """Return a MagicMock standing in for AsyncProcessingService.

    ``_async_config.max_file_size_mb`` is 50, and ``submit_request``
    returns a successful ``AsyncSubmitResult`` (request ID
    ``"test-request-id"``, estimated wait 30 s) unless a test overrides
    ``submit_request.return_value``.
    """
    accepted = AsyncSubmitResult(
        success=True,
        request_id="test-request-id",
        estimated_wait_seconds=30,
    )

    service = MagicMock()
    service._async_config = MagicMock(max_file_size_mb=50)
    service.submit_request.return_value = accepted
    return service
|
|
|
|
|
|
@pytest.fixture
def mock_rate_limiter():
    """Return a RateLimiter mock that allows every request by default.

    ``check_submit_limit`` reports 9 remaining requests with a reset 60 s
    ahead, ``check_poll_limit`` reports 999 remaining, and
    ``get_rate_limit_headers`` returns no headers. Tests override these
    return values to simulate throttling.
    """
    # spec= makes the mock raise AttributeError for anything RateLimiter
    # does not define, so typos in tests fail loudly.
    limiter = MagicMock(spec=RateLimiter)

    # Default: allow all requests.
    limiter.check_submit_limit.return_value = RateLimitStatus(
        allowed=True,
        remaining_requests=9,
        reset_at=datetime.utcnow() + timedelta(seconds=60),
    )
    limiter.check_poll_limit.return_value = RateLimitStatus(
        allowed=True,
        remaining_requests=999,
        reset_at=datetime.utcnow(),
    )
    limiter.get_rate_limit_headers.return_value = {}

    return limiter
|
|
|
|
|
|
@pytest.fixture
def app(mock_db, mock_rate_limiter, mock_async_service):
    """Build a FastAPI app exposing the async router under ``/api/v1``.

    The mocked DB, rate limiter and processing service are wired in via
    ``init_dependencies`` and ``set_async_service`` before the app is
    returned.
    """
    # NOTE(review): these calls appear to set module-level state that the
    # route dependencies read; it is not reset after the test.
    init_dependencies(mock_db, mock_rate_limiter)
    set_async_service(mock_async_service)

    test_app = FastAPI()
    test_app.include_router(
        create_async_router(allowed_extensions=(".pdf", ".png", ".jpg", ".jpeg")),
        prefix="/api/v1",
    )
    return test_app
|
|
|
|
|
|
@pytest.fixture
def client(app):
    """Return a synchronous TestClient bound to the ``app`` fixture."""
    test_client = TestClient(app)
    return test_client
|
|
|
|
|
|
class TestAsyncSubmitEndpoint:
    """Tests for POST /api/v1/async/submit."""

    # Uploads are sent as in-memory (filename, content, content_type)
    # tuples, so no temporary files are left behind on disk.
    PDF_UPLOAD = ("test.pdf", b"fake pdf content", "application/pdf")
    TXT_UPLOAD = ("test.txt", b"text content", "text/plain")

    @staticmethod
    def _submit(client, upload, api_key="test-api-key"):
        """POST ``upload`` as the ``file`` form field of the submit endpoint.

        The ``X-API-Key`` header is sent only when ``api_key`` is not None.
        """
        headers = {} if api_key is None else {"X-API-Key": api_key}
        return client.post(
            "/api/v1/async/submit",
            files={"file": upload},
            headers=headers,
        )

    def test_submit_success(self, client, mock_async_service):
        """Test successful submission."""
        response = self._submit(client, self.PDF_UPLOAD)

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "accepted"
        assert data["request_id"] == "test-request-id"
        assert "poll_url" in data
        mock_async_service.submit_request.assert_called_once()

    def test_submit_missing_api_key(self, client, mock_async_service):
        """Test submission without API key."""
        response = self._submit(client, self.PDF_UPLOAD, api_key=None)

        assert response.status_code == 401
        assert "X-API-Key" in response.json()["detail"]
        mock_async_service.submit_request.assert_not_called()

    def test_submit_invalid_api_key(self, client, mock_db, mock_async_service):
        """Test submission with invalid API key."""
        mock_db.is_valid_api_key.return_value = False

        response = self._submit(client, self.PDF_UPLOAD, api_key="invalid-key")

        assert response.status_code == 401
        mock_async_service.submit_request.assert_not_called()

    def test_submit_unsupported_file_type(self, client, mock_async_service):
        """Test submission with unsupported file type."""
        response = self._submit(client, self.TXT_UPLOAD)

        assert response.status_code == 400
        assert "Unsupported file type" in response.json()["detail"]
        mock_async_service.submit_request.assert_not_called()

    def test_submit_rate_limited(self, client, mock_rate_limiter, mock_async_service):
        """Test submission when rate limited."""
        mock_rate_limiter.check_submit_limit.return_value = RateLimitStatus(
            allowed=False,
            remaining_requests=0,
            reset_at=datetime.utcnow() + timedelta(seconds=30),
            retry_after_seconds=30,
            reason="Rate limit exceeded",
        )
        mock_rate_limiter.get_rate_limit_headers.return_value = {"Retry-After": "30"}

        response = self._submit(client, self.PDF_UPLOAD)

        assert response.status_code == 429
        assert response.headers.get("Retry-After") == "30"
        mock_async_service.submit_request.assert_not_called()

    def test_submit_queue_full(self, client, mock_async_service):
        """Test submission when queue is full."""
        mock_async_service.submit_request.return_value = AsyncSubmitResult(
            success=False,
            request_id="test-id",
            error="Processing queue is full",
        )

        response = self._submit(client, self.PDF_UPLOAD)

        assert response.status_code == 503
|
|
|
|
|
|
class TestAsyncStatusEndpoint:
    """Tests for GET /api/v1/async/status/{request_id}."""

    # Well-formed UUID that the tests configure the mocked DB not to find.
    UNKNOWN_UUID = "00000000-0000-0000-0000-000000000000"

    def test_get_status_pending(self, client, mock_db, sample_async_request):
        """Test getting status of pending request."""
        mock_db.get_request_by_api_key.return_value = sample_async_request
        mock_db.get_queue_position.return_value = 3

        response = client.get(
            f"/api/v1/async/status/{TEST_REQUEST_UUID}",
            headers={"X-API-Key": "test-api-key"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "pending"
        assert data["position_in_queue"] == 3
        assert data["result_url"] is None

    def test_get_status_completed(self, client, mock_db, sample_async_request):
        """Test getting status of completed request."""
        sample_async_request.status = "completed"
        sample_async_request.completed_at = datetime.utcnow()
        mock_db.get_request_by_api_key.return_value = sample_async_request

        response = client.get(
            f"/api/v1/async/status/{TEST_REQUEST_UUID}",
            headers={"X-API-Key": "test-api-key"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "completed"
        assert data["result_url"] is not None

    def test_get_status_not_found(self, client, mock_db):
        """Test getting status of non-existent request."""
        mock_db.get_request_by_api_key.return_value = None

        response = client.get(
            f"/api/v1/async/status/{self.UNKNOWN_UUID}",
            headers={"X-API-Key": "test-api-key"},
        )

        assert response.status_code == 404

    def test_get_status_wrong_api_key(self, client, mock_db):
        """Test that a request is not visible under another API key."""
        # Simulate the per-key lookup finding nothing: the request is
        # owned by a different API key.
        mock_db.get_request_by_api_key.return_value = None

        response = client.get(
            f"/api/v1/async/status/{TEST_REQUEST_UUID}",
            headers={"X-API-Key": "different-api-key"},
        )

        assert response.status_code == 404
|
|
|
|
|
|
class TestAsyncResultEndpoint:
    """Tests for GET /api/v1/async/result/{request_id}."""

    def test_get_result_completed(self, client, mock_db, sample_async_request):
        """Test getting result of completed request."""
        sample_async_request.status = "completed"
        sample_async_request.completed_at = datetime.utcnow()
        sample_async_request.processing_time_ms = 1234.5
        sample_async_request.result = {
            "document_id": "test-doc",
            "success": True,
            "document_type": "invoice",
            "fields": {"InvoiceNumber": "12345"},
            "confidence": {"InvoiceNumber": 0.95},
            "detections": [],
            "errors": [],
        }
        mock_db.get_request_by_api_key.return_value = sample_async_request

        response = client.get(
            f"/api/v1/async/result/{TEST_REQUEST_UUID}",
            headers={"X-API-Key": "test-api-key"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "completed"
        assert data["result"] is not None
        assert data["result"]["fields"]["InvoiceNumber"] == "12345"

    def test_get_result_not_completed(self, client, mock_db, sample_async_request):
        """Test getting result of pending request."""
        mock_db.get_request_by_api_key.return_value = sample_async_request

        response = client.get(
            f"/api/v1/async/result/{TEST_REQUEST_UUID}",
            headers={"X-API-Key": "test-api-key"},
        )

        assert response.status_code == 409
        assert "not yet completed" in response.json()["detail"]

    def test_get_result_failed(self, client, mock_db, sample_async_request):
        """Test getting result of failed request."""
        sample_async_request.status = "failed"
        sample_async_request.error_message = "Processing failed"
        sample_async_request.processing_time_ms = 500.0
        mock_db.get_request_by_api_key.return_value = sample_async_request

        response = client.get(
            f"/api/v1/async/result/{TEST_REQUEST_UUID}",
            headers={"X-API-Key": "test-api-key"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "failed"
|
|
|
|
|
|
class TestAsyncListEndpoint:
    """Tests for GET /api/v1/async/requests."""

    def test_list_requests(self, client, mock_db, sample_async_request):
        """Test listing requests."""
        mock_db.get_requests_by_api_key.return_value = ([sample_async_request], 1)

        response = client.get(
            "/api/v1/async/requests",
            headers={"X-API-Key": "test-api-key"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1
        assert len(data["requests"]) == 1

    def test_list_requests_with_status_filter(self, client, mock_db):
        """Test listing requests with status filter."""
        mock_db.get_requests_by_api_key.return_value = ([], 0)

        response = client.get(
            "/api/v1/async/requests?status=completed",
            headers={"X-API-Key": "test-api-key"},
        )

        assert response.status_code == 200
        mock_db.get_requests_by_api_key.assert_called_once()
        call_kwargs = mock_db.get_requests_by_api_key.call_args.kwargs
        assert call_kwargs["status"] == "completed"

    def test_list_requests_pagination(self, client, mock_db):
        """Test listing requests with pagination."""
        mock_db.get_requests_by_api_key.return_value = ([], 0)

        response = client.get(
            "/api/v1/async/requests?limit=50&offset=10",
            headers={"X-API-Key": "test-api-key"},
        )

        assert response.status_code == 200
        call_kwargs = mock_db.get_requests_by_api_key.call_args.kwargs
        assert call_kwargs["limit"] == 50
        assert call_kwargs["offset"] == 10

    def test_list_requests_invalid_status(self, client, mock_db):
        """Test listing with invalid status filter."""
        response = client.get(
            "/api/v1/async/requests?status=invalid",
            headers={"X-API-Key": "test-api-key"},
        )

        assert response.status_code == 400
        # A rejected filter must not reach the database.
        mock_db.get_requests_by_api_key.assert_not_called()
|
|
|
|
|
|
class TestAsyncDeleteEndpoint:
    """Tests for DELETE /api/v1/async/requests/{request_id}."""

    # Well-formed UUID that the tests configure the mocked DB not to find.
    UNKNOWN_UUID = "00000000-0000-0000-0000-000000000000"

    def test_delete_pending_request(self, client, mock_db, sample_async_request):
        """Test deleting a pending request."""
        mock_db.get_request_by_api_key.return_value = sample_async_request

        response = client.delete(
            f"/api/v1/async/requests/{TEST_REQUEST_UUID}",
            headers={"X-API-Key": "test-api-key"},
        )

        assert response.status_code == 200
        assert response.json()["status"] == "deleted"

    def test_delete_processing_request(self, client, mock_db, sample_async_request):
        """Test that processing requests cannot be deleted."""
        sample_async_request.status = "processing"
        mock_db.get_request_by_api_key.return_value = sample_async_request

        response = client.delete(
            f"/api/v1/async/requests/{TEST_REQUEST_UUID}",
            headers={"X-API-Key": "test-api-key"},
        )

        assert response.status_code == 409

    def test_delete_not_found(self, client, mock_db):
        """Test deleting non-existent request."""
        mock_db.get_request_by_api_key.return_value = None

        response = client.delete(
            f"/api/v1/async/requests/{self.UNKNOWN_UUID}",
            headers={"X-API-Key": "test-api-key"},
        )

        assert response.status_code == 404
|