Add more tests

This commit is contained in:
Yaojia Wang
2026-02-01 22:40:41 +01:00
parent a564ac9d70
commit 400b12a967
55 changed files with 9306 additions and 267 deletions

View File

@@ -0,0 +1 @@
"""API integration tests."""

View File

@@ -0,0 +1,389 @@
"""
API Integration Tests
Tests FastAPI endpoints with mocked services.
These tests verify the API layer works correctly with the service layer.
"""
import io
import tempfile
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
@dataclass
class MockServiceResult:
"""Mock result from inference service."""
document_id: str = "test-doc-123"
success: bool = True
document_type: str = "invoice"
fields: dict[str, str] = field(default_factory=lambda: {
"InvoiceNumber": "INV-2024-001",
"Amount": "1500.00",
"InvoiceDate": "2024-01-15",
"OCR": "12345678901234",
"Bankgiro": "1234-5678",
})
confidence: dict[str, float] = field(default_factory=lambda: {
"InvoiceNumber": 0.95,
"Amount": 0.92,
"InvoiceDate": 0.88,
"OCR": 0.95,
"Bankgiro": 0.90,
})
detections: list[dict[str, Any]] = field(default_factory=list)
processing_time_ms: float = 150.5
visualization_path: Path | None = None
errors: list[str] = field(default_factory=list)
@pytest.fixture
def temp_storage_dir():
"""Create temporary storage directories."""
with tempfile.TemporaryDirectory() as tmpdir:
base = Path(tmpdir)
uploads_dir = base / "uploads" / "inference"
results_dir = base / "results"
uploads_dir.mkdir(parents=True, exist_ok=True)
results_dir.mkdir(parents=True, exist_ok=True)
yield {
"base": base,
"uploads": uploads_dir,
"results": results_dir,
}
@pytest.fixture
def mock_inference_service():
"""Create a mock inference service."""
service = MagicMock()
service.is_initialized = True
service.gpu_available = False
# Create a realistic mock result
mock_result = MockServiceResult()
service.process_pdf.return_value = mock_result
service.process_image.return_value = mock_result
service.initialize.return_value = None
return service
@pytest.fixture
def mock_storage_config(temp_storage_dir):
"""Create mock storage configuration."""
from inference.web.config import StorageConfig
return StorageConfig(
upload_dir=temp_storage_dir["uploads"],
result_dir=temp_storage_dir["results"],
max_file_size_mb=50,
)
@pytest.fixture
def mock_storage_helper(temp_storage_dir):
"""Create a mock storage helper."""
helper = MagicMock()
helper.get_uploads_base_path.return_value = temp_storage_dir["uploads"]
helper.get_result_local_path.return_value = None
helper.result_exists.return_value = False
return helper
@pytest.fixture
def test_app(mock_inference_service, mock_storage_config, mock_storage_helper):
"""Create a test FastAPI application with mocked storage."""
from inference.web.api.v1.public.inference import create_inference_router
app = FastAPI()
# Patch get_storage_helper to return our mock
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
inference_router = create_inference_router(mock_inference_service, mock_storage_config)
app.include_router(inference_router)
return app
@pytest.fixture
def client(test_app, mock_storage_helper):
"""Create a test client with storage helper patched."""
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
yield TestClient(test_app)
class TestHealthEndpoint:
"""Tests for health check endpoint."""
def test_health_check(self, client, mock_inference_service):
"""Test health check returns status."""
response = client.get("/api/v1/health")
assert response.status_code == 200
data = response.json()
assert "status" in data
assert "model_loaded" in data
class TestInferenceEndpoint:
"""Tests for inference endpoint."""
def test_infer_pdf(self, client, mock_inference_service, mock_storage_helper, temp_storage_dir):
"""Test PDF inference endpoint."""
# Create a minimal PDF content
pdf_content = b"%PDF-1.4\n%test\n"
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
"/api/v1/infer",
files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
)
assert response.status_code == 200
data = response.json()
assert "result" in data
assert data["result"]["success"] is True
assert "InvoiceNumber" in data["result"]["fields"]
def test_infer_image(self, client, mock_inference_service, mock_storage_helper):
"""Test image inference endpoint."""
# Create minimal PNG header
png_header = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
"/api/v1/infer",
files={"file": ("test.png", io.BytesIO(png_header), "image/png")},
)
assert response.status_code == 200
data = response.json()
assert "result" in data
def test_infer_invalid_file_type(self, client, mock_storage_helper):
"""Test rejection of invalid file types."""
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
"/api/v1/infer",
files={"file": ("test.txt", io.BytesIO(b"hello"), "text/plain")},
)
assert response.status_code == 400
def test_infer_no_file(self, client, mock_storage_helper):
"""Test rejection when no file provided."""
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post("/api/v1/infer")
assert response.status_code == 422 # Validation error
def test_infer_result_structure(self, client, mock_inference_service, mock_storage_helper):
"""Test that result has expected structure."""
pdf_content = b"%PDF-1.4\n%test\n"
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
"/api/v1/infer",
files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
)
data = response.json()
result = data["result"]
# Check required fields
assert "document_id" in result
assert "success" in result
assert "fields" in result
assert "confidence" in result
assert "processing_time_ms" in result
class TestInferenceResultFormat:
"""Tests for inference result formatting."""
def test_result_fields_mapped_correctly(self, client, mock_inference_service, mock_storage_helper):
"""Test that fields are mapped to API response format."""
pdf_content = b"%PDF-1.4\n%test\n"
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
"/api/v1/infer",
files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
)
data = response.json()
fields = data["result"]["fields"]
assert fields["InvoiceNumber"] == "INV-2024-001"
assert fields["Amount"] == "1500.00"
assert fields["InvoiceDate"] == "2024-01-15"
def test_confidence_values_included(self, client, mock_inference_service, mock_storage_helper):
"""Test that confidence values are included."""
pdf_content = b"%PDF-1.4\n%test\n"
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
"/api/v1/infer",
files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
)
data = response.json()
confidence = data["result"]["confidence"]
assert "InvoiceNumber" in confidence
assert confidence["InvoiceNumber"] == 0.95
class TestErrorHandling:
"""Tests for error handling in API."""
def test_service_error_handling(self, client, mock_inference_service, mock_storage_helper):
"""Test handling of service errors."""
mock_inference_service.process_pdf.side_effect = Exception("Processing failed")
pdf_content = b"%PDF-1.4\n%test\n"
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
"/api/v1/infer",
files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
)
# Should return error response
assert response.status_code >= 400
def test_empty_file_handling(self, client, mock_storage_helper):
"""Test handling of empty files."""
# Empty file still has valid content type
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
"/api/v1/infer",
files={"file": ("test.pdf", io.BytesIO(b""), "application/pdf")},
)
# Empty file may be processed or rejected depending on implementation
# Just verify we get a response
assert response.status_code in [200, 400, 422, 500]
class TestResponseFormat:
"""Tests for API response format consistency."""
def test_success_response_format(self, client, mock_inference_service, mock_storage_helper):
"""Test successful response format."""
pdf_content = b"%PDF-1.4\n%test\n"
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
"/api/v1/infer",
files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
)
assert response.status_code == 200
assert response.headers["content-type"] == "application/json"
data = response.json()
assert isinstance(data, dict)
assert "result" in data
def test_json_serialization(self, client, mock_inference_service, mock_storage_helper):
"""Test that all result fields are JSON serializable."""
pdf_content = b"%PDF-1.4\n%test\n"
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
"/api/v1/infer",
files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
)
# If this doesn't raise, JSON is valid
data = response.json()
assert data is not None
class TestDocumentIdGeneration:
"""Tests for document ID handling."""
def test_document_id_generated(self, client, mock_inference_service, mock_storage_helper):
"""Test that document ID is generated."""
pdf_content = b"%PDF-1.4\n%test\n"
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
"/api/v1/infer",
files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
)
data = response.json()
assert "document_id" in data["result"]
assert data["result"]["document_id"] is not None
def test_document_id_from_filename(self, client, mock_inference_service, mock_storage_helper):
"""Test document ID derived from filename."""
pdf_content = b"%PDF-1.4\n%test\n"
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
"/api/v1/infer",
files={"file": ("my_invoice_123.pdf", io.BytesIO(pdf_content), "application/pdf")},
)
data = response.json()
# Document ID should be set (either from filename or generated)
assert data["result"]["document_id"] is not None

View File

@@ -0,0 +1,400 @@
"""
Dashboard API Integration Tests
Tests Dashboard API endpoints with real database operations via TestClient.
"""
from datetime import datetime, timezone
from uuid import uuid4
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from inference.data.admin_models import (
AdminAnnotation,
AdminDocument,
AdminToken,
AnnotationHistory,
ModelVersion,
TrainingDataset,
TrainingTask,
)
from inference.web.api.v1.admin.dashboard import create_dashboard_router
from inference.web.core.auth import get_admin_token_dep
def create_test_app(override_token_dep):
"""Create a FastAPI test application with dashboard router."""
app = FastAPI()
router = create_dashboard_router()
app.include_router(router)
# Override auth dependency
app.dependency_overrides[get_admin_token_dep] = lambda: override_token_dep
return app
class TestDashboardStatsEndpoint:
"""Tests for GET /admin/dashboard/stats endpoint."""
def test_stats_empty_database(self, patched_session, admin_token):
"""Test stats endpoint with empty database."""
app = create_test_app(admin_token.token)
client = TestClient(app)
response = client.get("/admin/dashboard/stats")
assert response.status_code == 200
data = response.json()
assert data["total_documents"] == 0
assert data["annotation_complete"] == 0
assert data["annotation_incomplete"] == 0
assert data["pending"] == 0
assert data["completeness_rate"] == 0.0
def test_stats_with_pending_documents(self, patched_session, admin_token):
"""Test stats with pending documents."""
session = patched_session
# Create pending documents
for i in range(3):
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename=f"pending_{i}.pdf",
file_size=1024,
content_type="application/pdf",
file_path=f"/uploads/pending_{i}.pdf",
page_count=1,
status="pending",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
session.commit()
app = create_test_app(admin_token.token)
client = TestClient(app)
response = client.get("/admin/dashboard/stats")
assert response.status_code == 200
data = response.json()
assert data["total_documents"] == 3
assert data["pending"] == 3
def test_stats_with_complete_annotations(self, patched_session, admin_token):
"""Test stats with complete annotations."""
session = patched_session
# Create labeled document with complete annotations
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename="complete.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/uploads/complete.pdf",
page_count=1,
status="labeled",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
session.commit()
# Add identifier and payment annotations
session.add(AdminAnnotation(
annotation_id=uuid4(),
document_id=doc.document_id,
page_number=1,
class_id=0, # invoice_number
class_name="invoice_number",
x_center=0.5, y_center=0.1, width=0.2, height=0.05,
bbox_x=400, bbox_y=80, bbox_width=160, bbox_height=40,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
))
session.add(AdminAnnotation(
annotation_id=uuid4(),
document_id=doc.document_id,
page_number=1,
class_id=4, # bankgiro
class_name="bankgiro",
x_center=0.5, y_center=0.2, width=0.2, height=0.05,
bbox_x=400, bbox_y=160, bbox_width=160, bbox_height=40,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
))
session.commit()
app = create_test_app(admin_token.token)
client = TestClient(app)
response = client.get("/admin/dashboard/stats")
assert response.status_code == 200
data = response.json()
assert data["annotation_complete"] == 1
assert data["completeness_rate"] == 100.0
class TestActiveModelEndpoint:
"""Tests for GET /admin/dashboard/active-model endpoint."""
def test_active_model_none(self, patched_session, admin_token):
"""Test active-model endpoint with no active model."""
app = create_test_app(admin_token.token)
client = TestClient(app)
response = client.get("/admin/dashboard/active-model")
assert response.status_code == 200
data = response.json()
assert data["model"] is None
assert data["running_training"] is None
def test_active_model_with_model(self, patched_session, admin_token, sample_dataset):
"""Test active-model endpoint with active model."""
session = patched_session
# Create training task
task = TrainingTask(
task_id=uuid4(),
admin_token=admin_token.token,
name="Test Task",
status="completed",
task_type="train",
dataset_id=sample_dataset.dataset_id,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(task)
session.commit()
# Create active model
model = ModelVersion(
version_id=uuid4(),
version="1.0.0",
name="Test Model",
model_path="/models/test.pt",
status="active",
is_active=True,
task_id=task.task_id,
dataset_id=sample_dataset.dataset_id,
metrics_mAP=0.90,
metrics_precision=0.88,
metrics_recall=0.85,
document_count=100,
file_size=50000000,
activated_at=datetime.now(timezone.utc),
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(model)
session.commit()
app = create_test_app(admin_token.token)
client = TestClient(app)
response = client.get("/admin/dashboard/active-model")
assert response.status_code == 200
data = response.json()
assert data["model"] is not None
assert data["model"]["version"] == "1.0.0"
assert data["model"]["name"] == "Test Model"
assert data["model"]["metrics_mAP"] == 0.90
def test_active_model_with_running_training(self, patched_session, admin_token, sample_dataset):
"""Test active-model endpoint with running training."""
session = patched_session
# Create running training task
task = TrainingTask(
task_id=uuid4(),
admin_token=admin_token.token,
name="Running Task",
status="running",
task_type="train",
dataset_id=sample_dataset.dataset_id,
started_at=datetime.now(timezone.utc),
progress=50,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(task)
session.commit()
app = create_test_app(admin_token.token)
client = TestClient(app)
response = client.get("/admin/dashboard/active-model")
assert response.status_code == 200
data = response.json()
assert data["running_training"] is not None
assert data["running_training"]["name"] == "Running Task"
assert data["running_training"]["status"] == "running"
assert data["running_training"]["progress"] == 50
class TestRecentActivityEndpoint:
"""Tests for GET /admin/dashboard/activity endpoint."""
def test_activity_empty(self, patched_session, admin_token):
"""Test activity endpoint with no activities."""
app = create_test_app(admin_token.token)
client = TestClient(app)
response = client.get("/admin/dashboard/activity")
assert response.status_code == 200
data = response.json()
assert data["activities"] == []
def test_activity_with_uploads(self, patched_session, admin_token):
"""Test activity includes document uploads."""
session = patched_session
# Create documents
for i in range(3):
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename=f"activity_{i}.pdf",
file_size=1024,
content_type="application/pdf",
file_path=f"/uploads/activity_{i}.pdf",
page_count=1,
status="pending",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
session.commit()
app = create_test_app(admin_token.token)
client = TestClient(app)
response = client.get("/admin/dashboard/activity")
assert response.status_code == 200
data = response.json()
upload_activities = [a for a in data["activities"] if a["type"] == "document_uploaded"]
assert len(upload_activities) == 3
def test_activity_limit_parameter(self, patched_session, admin_token):
"""Test activity limit parameter."""
session = patched_session
# Create many documents
for i in range(15):
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename=f"limit_{i}.pdf",
file_size=1024,
content_type="application/pdf",
file_path=f"/uploads/limit_{i}.pdf",
page_count=1,
status="pending",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
session.commit()
app = create_test_app(admin_token.token)
client = TestClient(app)
response = client.get("/admin/dashboard/activity?limit=5")
assert response.status_code == 200
data = response.json()
assert len(data["activities"]) <= 5
def test_activity_invalid_limit(self, patched_session, admin_token):
"""Test activity with invalid limit parameter."""
app = create_test_app(admin_token.token)
client = TestClient(app)
# Limit too high
response = client.get("/admin/dashboard/activity?limit=100")
assert response.status_code == 422
# Limit too low
response = client.get("/admin/dashboard/activity?limit=0")
assert response.status_code == 422
def test_activity_with_training_completion(self, patched_session, admin_token, sample_dataset):
"""Test activity includes training completions."""
session = patched_session
# Create completed training task
task = TrainingTask(
task_id=uuid4(),
admin_token=admin_token.token,
name="Completed Task",
status="completed",
task_type="train",
dataset_id=sample_dataset.dataset_id,
metrics_mAP=0.95,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(task)
session.commit()
app = create_test_app(admin_token.token)
client = TestClient(app)
response = client.get("/admin/dashboard/activity")
assert response.status_code == 200
data = response.json()
training_activities = [a for a in data["activities"] if a["type"] == "training_completed"]
assert len(training_activities) >= 1
def test_activity_sorted_by_timestamp(self, patched_session, admin_token):
"""Test activities are sorted by timestamp descending."""
session = patched_session
# Create documents
for i in range(5):
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename=f"sorted_{i}.pdf",
file_size=1024,
content_type="application/pdf",
file_path=f"/uploads/sorted_{i}.pdf",
page_count=1,
status="pending",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
session.commit()
app = create_test_app(admin_token.token)
client = TestClient(app)
response = client.get("/admin/dashboard/activity")
assert response.status_code == 200
data = response.json()
timestamps = [a["timestamp"] for a in data["activities"]]
assert timestamps == sorted(timestamps, reverse=True)