Yaojia Wang
2026-01-27 00:47:10 +01:00
parent e83a0cae36
commit 58bf75db68
141 changed files with 24814 additions and 3884 deletions

tests/web/__init__.py Normal file

@@ -0,0 +1 @@
"""Tests for web API components."""

tests/web/conftest.py Normal file

@@ -0,0 +1,132 @@
"""
Test fixtures for web API tests.
"""
import tempfile
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import MagicMock, patch
from uuid import UUID
import pytest
from src.data.async_request_db import ApiKeyConfig, AsyncRequestDB
from src.data.models import AsyncRequest
from src.web.workers.async_queue import AsyncTask, AsyncTaskQueue
from src.web.services.async_processing import AsyncProcessingService
from src.web.config import AsyncConfig, StorageConfig
from src.web.core.rate_limiter import RateLimiter
@pytest.fixture
def mock_db():
"""Create a mock AsyncRequestDB."""
db = MagicMock(spec=AsyncRequestDB)
# Default return values
db.is_valid_api_key.return_value = True
db.get_api_key_config.return_value = ApiKeyConfig(
api_key="test-api-key",
name="Test Key",
is_active=True,
requests_per_minute=10,
max_concurrent_jobs=3,
max_file_size_mb=50,
)
db.count_active_jobs.return_value = 0
db.get_queue_position.return_value = 1
return db
@pytest.fixture
def rate_limiter(mock_db):
"""Create a RateLimiter with mock database."""
return RateLimiter(mock_db)
@pytest.fixture
def task_queue():
"""Create an AsyncTaskQueue."""
return AsyncTaskQueue(max_size=10, worker_count=1)
@pytest.fixture
def async_config():
"""Create an AsyncConfig for testing."""
with tempfile.TemporaryDirectory() as tmpdir:
yield AsyncConfig(
queue_max_size=10,
worker_count=1,
task_timeout_seconds=30,
result_retention_days=7,
temp_upload_dir=Path(tmpdir) / "async",
max_file_size_mb=10,
)
@pytest.fixture
def storage_config():
"""Create a StorageConfig for testing."""
with tempfile.TemporaryDirectory() as tmpdir:
yield StorageConfig(
upload_dir=Path(tmpdir) / "uploads",
result_dir=Path(tmpdir) / "results",
max_file_size_mb=50,
)
@pytest.fixture
def mock_inference_service():
"""Create a mock InferenceService."""
service = MagicMock()
service.is_initialized = True
service.gpu_available = False
# Mock process_pdf to return a successful result
mock_result = MagicMock()
mock_result.document_id = "test-doc"
mock_result.success = True
mock_result.document_type = "invoice"
mock_result.fields = {"InvoiceNumber": "12345", "Amount": "1000.00"}
mock_result.confidence = {"InvoiceNumber": 0.95, "Amount": 0.92}
mock_result.detections = []
mock_result.errors = []
mock_result.visualization_path = None
service.process_pdf.return_value = mock_result
service.process_image.return_value = mock_result
return service
# Valid UUID for testing
TEST_REQUEST_UUID = "550e8400-e29b-41d4-a716-446655440000"
@pytest.fixture
def sample_async_request():
"""Create a sample AsyncRequest."""
return AsyncRequest(
request_id=UUID(TEST_REQUEST_UUID),
api_key="test-api-key",
status="pending",
filename="test.pdf",
file_size=1024,
content_type="application/pdf",
expires_at=datetime.utcnow() + timedelta(days=7),
)
@pytest.fixture
def sample_task():
"""Create a sample AsyncTask."""
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
f.write(b"fake pdf content")
return AsyncTask(
request_id=TEST_REQUEST_UUID,
api_key="test-api-key",
file_path=Path(f.name),
filename="test.pdf",
created_at=datetime.utcnow(),
)
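
A minimal usage sketch (not part of the commit) showing how these fixtures compose in a test. It assumes RateLimiter exposes a check_submit_limit(api_key) method returning an object with an `allowed` flag, as the async route tests later in this commit suggest:

def test_first_request_allowed(rate_limiter, mock_db):
    """Sketch: mock_db reports 10 requests/minute and 0 active jobs."""
    status = rate_limiter.check_submit_limit("test-api-key")  # assumed method name
    assert status.allowed is True  # first request within a fresh window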


@@ -0,0 +1,197 @@
"""
Tests for Admin Annotation Routes.
"""
import pytest
from datetime import datetime
from unittest.mock import MagicMock, patch
from uuid import UUID
from fastapi import HTTPException
from src.data.admin_models import AdminAnnotation, AdminDocument, FIELD_CLASSES
from src.web.api.v1.admin.annotations import _validate_uuid, create_annotation_router
from src.web.schemas.admin import (
AnnotationCreate,
AnnotationUpdate,
AutoLabelRequest,
BoundingBox,
)
# Test UUIDs
TEST_DOC_UUID = "550e8400-e29b-41d4-a716-446655440000"
TEST_ANN_UUID = "660e8400-e29b-41d4-a716-446655440001"
TEST_TOKEN = "test-admin-token-12345"
class TestAnnotationRouterCreation:
"""Tests for annotation router creation."""
def test_creates_router_with_endpoints(self):
"""Test router is created with expected endpoints."""
router = create_annotation_router()
# Get route paths (includes prefix)
paths = [route.path for route in router.routes]
# Paths include the /admin/documents prefix
assert any("{document_id}/annotations" in p for p in paths)
assert any("{annotation_id}" in p for p in paths)
assert any("auto-label" in p for p in paths)
assert any("images" in p for p in paths)
class TestAnnotationCreateSchema:
"""Tests for AnnotationCreate schema."""
def test_valid_annotation(self):
"""Test valid annotation creation."""
ann = AnnotationCreate(
page_number=1,
class_id=0,
bbox=BoundingBox(x=100, y=100, width=200, height=50),
text_value="12345",
)
assert ann.page_number == 1
assert ann.class_id == 0
assert ann.bbox.x == 100
assert ann.text_value == "12345"
def test_class_id_range(self):
"""Test class_id must be 0-9."""
# Valid class IDs
for class_id in range(10):
ann = AnnotationCreate(
page_number=1,
class_id=class_id,
bbox=BoundingBox(x=0, y=0, width=100, height=50),
)
assert ann.class_id == class_id
def test_bbox_validation(self):
"""Test bounding box validation."""
bbox = BoundingBox(x=0, y=0, width=100, height=50)
assert bbox.width >= 1
assert bbox.height >= 1
class TestAnnotationUpdateSchema:
"""Tests for AnnotationUpdate schema."""
def test_partial_update(self):
"""Test partial update with only some fields."""
update = AnnotationUpdate(
text_value="new value",
)
assert update.text_value == "new value"
assert update.class_id is None
assert update.bbox is None
def test_bbox_update(self):
"""Test bounding box update."""
update = AnnotationUpdate(
bbox=BoundingBox(x=50, y=50, width=150, height=75),
)
assert update.bbox.x == 50
assert update.bbox.width == 150
class TestAutoLabelRequestSchema:
"""Tests for AutoLabelRequest schema."""
def test_valid_request(self):
"""Test valid auto-label request."""
request = AutoLabelRequest(
field_values={
"InvoiceNumber": "12345",
"Amount": "1000.00",
},
replace_existing=True,
)
assert len(request.field_values) == 2
assert request.field_values["InvoiceNumber"] == "12345"
assert request.replace_existing is True
def test_requires_field_values(self):
"""Test that field_values is required."""
with pytest.raises(Exception):
AutoLabelRequest(replace_existing=True)
class TestFieldClasses:
"""Tests for field class mapping."""
def test_all_classes_defined(self):
"""Test all 10 field classes are defined."""
assert len(FIELD_CLASSES) == 10
def test_class_ids_sequential(self):
"""Test class IDs are 0-9."""
assert set(FIELD_CLASSES.keys()) == set(range(10))
def test_known_field_names(self):
"""Test known field names are present."""
names = list(FIELD_CLASSES.values())
assert "invoice_number" in names
assert "invoice_date" in names
assert "amount" in names
assert "bankgiro" in names
assert "ocr_number" in names
class TestAnnotationModel:
"""Tests for AdminAnnotation model."""
def test_annotation_creation(self):
"""Test annotation model creation."""
ann = AdminAnnotation(
document_id=UUID(TEST_DOC_UUID),
page_number=1,
class_id=0,
class_name="invoice_number",
x_center=0.5,
y_center=0.5,
width=0.2,
height=0.05,
bbox_x=100,
bbox_y=100,
bbox_width=200,
bbox_height=50,
text_value="12345",
confidence=0.95,
source="manual",
)
assert str(ann.document_id) == TEST_DOC_UUID
assert ann.class_id == 0
assert ann.x_center == 0.5
assert ann.source == "manual"
def test_normalized_coordinates(self):
"""Test normalized coordinates are 0-1 range."""
# Valid normalized coords
ann = AdminAnnotation(
document_id=UUID(TEST_DOC_UUID),
page_number=1,
class_id=0,
class_name="test",
x_center=0.5,
y_center=0.5,
width=0.2,
height=0.05,
bbox_x=0,
bbox_y=0,
bbox_width=100,
bbox_height=50,
)
assert 0 <= ann.x_center <= 1
assert 0 <= ann.y_center <= 1
assert 0 <= ann.width <= 1
assert 0 <= ann.height <= 1
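
Worked example (illustrative, not from the commit) of the pixel-to-normalized conversion these assertions rely on, assuming a 1000x1000 px page for round numbers:

page_w, page_h = 1000, 1000
bbox_x, bbox_y, bbox_w, bbox_h = 100, 100, 200, 50
x_center = (bbox_x + bbox_w / 2) / page_w  # (100 + 100) / 1000 = 0.2
y_center = (bbox_y + bbox_h / 2) / page_h  # (100 + 25) / 1000 = 0.125
width = bbox_w / page_w                    # 200 / 1000 = 0.2
height = bbox_h / page_h                   # 50 / 1000 = 0.05, matching the model above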


@@ -0,0 +1,162 @@
"""
Tests for Admin Authentication.
"""
import pytest
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch
from fastapi import HTTPException
from src.data.admin_db import AdminDB
from src.data.admin_models import AdminToken
from src.web.core.auth import (
get_admin_db,
reset_admin_db,
validate_admin_token,
)
@pytest.fixture
def mock_admin_db():
"""Create a mock AdminDB."""
db = MagicMock(spec=AdminDB)
db.is_valid_admin_token.return_value = True
return db
@pytest.fixture(autouse=True)
def reset_db():
"""Reset admin DB after each test."""
yield
reset_admin_db()
class TestValidateAdminToken:
"""Tests for validate_admin_token dependency."""
def test_missing_token_raises_401(self, mock_admin_db):
"""Test that missing token raises 401."""
import asyncio
with pytest.raises(HTTPException) as exc_info:
asyncio.run(
validate_admin_token(None, mock_admin_db)
)
assert exc_info.value.status_code == 401
assert "Admin token required" in exc_info.value.detail
def test_invalid_token_raises_401(self, mock_admin_db):
"""Test that invalid token raises 401."""
import asyncio
mock_admin_db.is_valid_admin_token.return_value = False
with pytest.raises(HTTPException) as exc_info:
asyncio.run(
validate_admin_token("invalid-token", mock_admin_db)
)
assert exc_info.value.status_code == 401
assert "Invalid or expired" in exc_info.value.detail
def test_valid_token_returns_token(self, mock_admin_db):
"""Test that valid token is returned."""
import asyncio
token = "valid-test-token"
mock_admin_db.is_valid_admin_token.return_value = True
result = asyncio.run(
validate_admin_token(token, mock_admin_db)
)
assert result == token
mock_admin_db.update_admin_token_usage.assert_called_once_with(token)
class TestAdminDB:
"""Tests for AdminDB operations."""
def test_is_valid_admin_token_active(self):
"""Test valid active token."""
with patch("src.data.admin_db.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__.return_value = mock_session
mock_token = AdminToken(
token="test-token",
name="Test",
is_active=True,
expires_at=None,
)
mock_session.get.return_value = mock_token
db = AdminDB()
assert db.is_valid_admin_token("test-token") is True
def test_is_valid_admin_token_inactive(self):
"""Test inactive token."""
with patch("src.data.admin_db.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__.return_value = mock_session
mock_token = AdminToken(
token="test-token",
name="Test",
is_active=False,
expires_at=None,
)
mock_session.get.return_value = mock_token
db = AdminDB()
assert db.is_valid_admin_token("test-token") is False
def test_is_valid_admin_token_expired(self):
"""Test expired token."""
with patch("src.data.admin_db.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__.return_value = mock_session
mock_token = AdminToken(
token="test-token",
name="Test",
is_active=True,
expires_at=datetime.utcnow() - timedelta(days=1),
)
mock_session.get.return_value = mock_token
db = AdminDB()
assert db.is_valid_admin_token("test-token") is False
def test_is_valid_admin_token_not_found(self):
"""Test token not found."""
with patch("src.data.admin_db.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__.return_value = mock_session
mock_session.get.return_value = None
db = AdminDB()
assert db.is_valid_admin_token("nonexistent") is False
class TestGetAdminDb:
"""Tests for get_admin_db function."""
def test_returns_singleton(self):
"""Test that get_admin_db returns singleton."""
reset_admin_db()
db1 = get_admin_db()
db2 = get_admin_db()
assert db1 is db2
def test_reset_clears_singleton(self):
"""Test that reset clears singleton."""
db1 = get_admin_db()
reset_admin_db()
db2 = get_admin_db()
assert db1 is not db2
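
For context, a minimal sketch of the dependency these tests exercise, assuming a header-based token scheme; the real implementation lives in src.web.core.auth and may differ:

from fastapi import Header, HTTPException

async def validate_admin_token_sketch(
    x_admin_token: str | None = Header(default=None),  # header name is an assumption
    db=None,  # the real dependency injects an AdminDB instance
) -> str:
    if x_admin_token is None:
        raise HTTPException(status_code=401, detail="Admin token required")
    if not db.is_valid_admin_token(x_admin_token):
        raise HTTPException(status_code=401, detail="Invalid or expired admin token")
    db.update_admin_token_usage(x_admin_token)  # matches the assertion in the tests above
    return x_admin_token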


@@ -0,0 +1,164 @@
"""
Tests for Admin Document Routes.
"""
import pytest
from datetime import datetime
from io import BytesIO
from pathlib import Path
from unittest.mock import MagicMock, patch
from uuid import UUID
from fastapi import HTTPException
from fastapi.testclient import TestClient
from src.data.admin_models import AdminDocument, AdminToken
from src.web.api.v1.admin.documents import _validate_uuid, create_admin_router
# Test UUID
TEST_DOC_UUID = "550e8400-e29b-41d4-a716-446655440000"
TEST_TOKEN = "test-admin-token-12345"
class TestValidateUUID:
"""Tests for UUID validation."""
def test_valid_uuid(self):
"""Test valid UUID passes validation."""
_validate_uuid(TEST_DOC_UUID, "test") # Should not raise
def test_invalid_uuid_raises_400(self):
"""Test invalid UUID raises 400."""
with pytest.raises(HTTPException) as exc_info:
_validate_uuid("not-a-uuid", "document_id")
assert exc_info.value.status_code == 400
assert "Invalid document_id format" in exc_info.value.detail
class TestAdminRouter:
"""Tests for admin router creation."""
def test_creates_router_with_endpoints(self):
"""Test router is created with expected endpoints."""
router = create_admin_router((".pdf", ".png", ".jpg"))
# Get route paths (include prefix from router)
paths = [route.path for route in router.routes]
# Paths include the /admin prefix
assert any("/auth/token" in p for p in paths)
assert any("/documents" in p for p in paths)
assert any("/documents/stats" in p for p in paths)
assert any("{document_id}" in p for p in paths)
class TestCreateTokenEndpoint:
"""Tests for POST /admin/auth/token endpoint."""
@pytest.fixture
def mock_db(self):
"""Create mock AdminDB."""
db = MagicMock()
db.is_valid_admin_token.return_value = True
return db
def test_create_token_success(self, mock_db):
"""Test successful token creation."""
from src.web.schemas.admin import AdminTokenCreate
request = AdminTokenCreate(name="Test Token", expires_in_days=30)
# The actual endpoint would generate a token
# This tests the schema validation
assert request.name == "Test Token"
assert request.expires_in_days == 30
class TestDocumentUploadEndpoint:
"""Tests for POST /admin/documents endpoint."""
@pytest.fixture
def sample_pdf_bytes(self):
"""Create sample PDF-like bytes."""
# Minimal PDF header
return b"%PDF-1.4\n%\xe2\xe3\xcf\xd3\n"
@pytest.fixture
def mock_admin_db(self):
"""Create mock AdminDB."""
db = MagicMock()
db.is_valid_admin_token.return_value = True
db.create_document.return_value = TEST_DOC_UUID
return db
def test_rejects_invalid_extension(self):
"""Test that invalid file extensions are rejected."""
# Schema validation would happen at the route level
allowed = (".pdf", ".png", ".jpg")
file_ext = ".exe"
assert file_ext not in allowed
class TestDocumentListEndpoint:
"""Tests for GET /admin/documents endpoint."""
@pytest.fixture
def sample_documents(self):
"""Create sample documents."""
return [
AdminDocument(
document_id=UUID(TEST_DOC_UUID),
admin_token=TEST_TOKEN,
filename="test.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/tmp/test.pdf",
page_count=1,
status="pending",
),
]
def test_validates_status_filter(self):
"""Test that invalid status filter is rejected."""
valid_statuses = ("pending", "auto_labeling", "labeled", "exported")
assert "invalid_status" not in valid_statuses
assert "pending" in valid_statuses
class TestDocumentDetailEndpoint:
"""Tests for GET /admin/documents/{document_id} endpoint."""
def test_requires_valid_uuid(self):
"""Test that invalid UUID is rejected."""
with pytest.raises(HTTPException) as exc_info:
_validate_uuid("invalid", "document_id")
assert exc_info.value.status_code == 400
class TestDocumentDeleteEndpoint:
"""Tests for DELETE /admin/documents/{document_id} endpoint."""
def test_validates_document_id(self):
"""Test that document_id is validated."""
# Valid UUID should not raise
_validate_uuid(TEST_DOC_UUID, "document_id")
# Invalid should raise
with pytest.raises(HTTPException):
_validate_uuid("bad-id", "document_id")
class TestDocumentStatusUpdateEndpoint:
"""Tests for PATCH /admin/documents/{document_id}/status endpoint."""
def test_validates_status_values(self):
"""Test that only valid statuses are accepted."""
valid_statuses = ("pending", "labeled", "exported")
assert "pending" in valid_statuses
assert "invalid" not in valid_statuses


@@ -0,0 +1,351 @@
"""
Tests for Enhanced Admin Document Routes (Phase 3).
"""
import pytest
from datetime import datetime
from uuid import uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
from src.web.api.v1.admin.documents import create_admin_router
from src.web.core.auth import validate_admin_token, get_admin_db
class MockAdminDocument:
"""Mock AdminDocument for testing."""
def __init__(self, **kwargs):
self.document_id = kwargs.get('document_id', uuid4())
self.admin_token = kwargs.get('admin_token', 'test-token')
self.filename = kwargs.get('filename', 'test.pdf')
self.file_size = kwargs.get('file_size', 100000)
self.content_type = kwargs.get('content_type', 'application/pdf')
self.page_count = kwargs.get('page_count', 1)
self.status = kwargs.get('status', 'pending')
self.auto_label_status = kwargs.get('auto_label_status', None)
self.auto_label_error = kwargs.get('auto_label_error', None)
self.upload_source = kwargs.get('upload_source', 'ui')
self.batch_id = kwargs.get('batch_id', None)
self.csv_field_values = kwargs.get('csv_field_values', None)
self.annotation_lock_until = kwargs.get('annotation_lock_until', None)
self.created_at = kwargs.get('created_at', datetime.utcnow())
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
class MockAnnotation:
"""Mock AdminAnnotation for testing."""
def __init__(self, **kwargs):
self.annotation_id = kwargs.get('annotation_id', uuid4())
self.document_id = kwargs.get('document_id')
self.page_number = kwargs.get('page_number', 1)
self.class_id = kwargs.get('class_id', 0)
self.class_name = kwargs.get('class_name', 'invoice_number')
self.bbox_x = kwargs.get('bbox_x', 100.0)
self.bbox_y = kwargs.get('bbox_y', 100.0)
self.bbox_width = kwargs.get('bbox_width', 200.0)
self.bbox_height = kwargs.get('bbox_height', 50.0)
self.x_center = kwargs.get('x_center', 0.5)
self.y_center = kwargs.get('y_center', 0.5)
self.width = kwargs.get('width', 0.3)
self.height = kwargs.get('height', 0.1)
self.text_value = kwargs.get('text_value', 'INV-001')
self.confidence = kwargs.get('confidence', 0.95)
self.source = kwargs.get('source', 'manual')
self.created_at = kwargs.get('created_at', datetime.utcnow())
class MockAdminDB:
"""Mock AdminDB for testing enhanced features."""
def __init__(self):
self.documents = {}
self.annotations = {}
def get_documents_by_token(
self,
admin_token,
status=None,
upload_source=None,
has_annotations=None,
auto_label_status=None,
batch_id=None,
limit=20,
offset=0
):
"""Get filtered documents."""
docs = list(self.documents.values())
# Apply filters
if status:
docs = [d for d in docs if d.status == status]
if upload_source:
docs = [d for d in docs if d.upload_source == upload_source]
if has_annotations is not None:
for d in docs[:]:
ann_count = len(self.annotations.get(str(d.document_id), []))
if has_annotations and ann_count == 0:
docs.remove(d)
elif not has_annotations and ann_count > 0:
docs.remove(d)
if auto_label_status:
docs = [d for d in docs if d.auto_label_status == auto_label_status]
if batch_id:
docs = [d for d in docs if str(d.batch_id) == str(batch_id)]
total = len(docs)
return docs[offset:offset+limit], total
def get_annotations_for_document(self, document_id):
"""Get annotations for document."""
return self.annotations.get(str(document_id), [])
def count_documents_by_status(self, admin_token):
"""Count documents by status."""
counts = {}
for doc in self.documents.values():
if doc.admin_token == admin_token:
counts[doc.status] = counts.get(doc.status, 0) + 1
return counts
def get_document_by_token(self, document_id, admin_token):
"""Get single document by ID and token."""
doc = self.documents.get(document_id)
if doc and doc.admin_token == admin_token:
return doc
return None
def get_document_training_tasks(self, document_id):
"""Get training tasks that used this document."""
return [] # No training history in this test
def get_training_task(self, task_id):
"""Get training task by ID."""
return None # No training tasks in this test
@pytest.fixture
def app():
"""Create test FastAPI app."""
app = FastAPI()
# Create mock DB
mock_db = MockAdminDB()
# Add test documents
doc1 = MockAdminDocument(
filename="INV001.pdf",
status="labeled",
upload_source="ui",
auto_label_status=None,
batch_id=None
)
doc2 = MockAdminDocument(
filename="INV002.pdf",
status="labeled",
upload_source="api",
auto_label_status="completed",
batch_id=uuid4()
)
doc3 = MockAdminDocument(
filename="INV003.pdf",
status="pending",
upload_source="ui",
auto_label_status=None, # Not auto-labeled yet
batch_id=None
)
mock_db.documents[str(doc1.document_id)] = doc1
mock_db.documents[str(doc2.document_id)] = doc2
mock_db.documents[str(doc3.document_id)] = doc3
# Add annotations to doc1 and doc2
mock_db.annotations[str(doc1.document_id)] = [
MockAnnotation(
document_id=doc1.document_id,
class_name="invoice_number",
text_value="INV-001"
)
]
mock_db.annotations[str(doc2.document_id)] = [
MockAnnotation(
document_id=doc2.document_id,
class_id=6,
class_name="amount",
text_value="1500.00"
),
MockAnnotation(
document_id=doc2.document_id,
class_id=1,
class_name="invoice_date",
text_value="2024-01-15"
)
]
# Override dependencies
app.dependency_overrides[validate_admin_token] = lambda: "test-token"
app.dependency_overrides[get_admin_db] = lambda: mock_db
# Include router
router = create_admin_router((".pdf", ".png", ".jpg"))
app.include_router(router)
return app
@pytest.fixture
def client(app):
"""Create test client."""
return TestClient(app)
class TestEnhancedDocumentList:
"""Tests for enhanced document list endpoint."""
def test_list_documents_filter_by_upload_source_ui(self, client):
"""Test filtering documents by upload_source=ui."""
response = client.get("/admin/documents?upload_source=ui")
assert response.status_code == 200
data = response.json()
assert data["total"] == 2
assert all(doc["filename"].startswith("INV") for doc in data["documents"])
def test_list_documents_filter_by_upload_source_api(self, client):
"""Test filtering documents by upload_source=api."""
response = client.get("/admin/documents?upload_source=api")
assert response.status_code == 200
data = response.json()
assert data["total"] == 1
assert data["documents"][0]["filename"] == "INV002.pdf"
def test_list_documents_filter_by_has_annotations_true(self, client):
"""Test filtering documents with annotations."""
response = client.get("/admin/documents?has_annotations=true")
assert response.status_code == 200
data = response.json()
assert data["total"] == 2
def test_list_documents_filter_by_has_annotations_false(self, client):
"""Test filtering documents without annotations."""
response = client.get("/admin/documents?has_annotations=false")
assert response.status_code == 200
data = response.json()
assert data["total"] == 1
def test_list_documents_filter_by_auto_label_status(self, client):
"""Test filtering by auto_label_status."""
response = client.get("/admin/documents?auto_label_status=completed")
assert response.status_code == 200
data = response.json()
assert data["total"] == 1
assert data["documents"][0]["filename"] == "INV002.pdf"
def test_list_documents_filter_by_batch_id(self, client):
"""Test filtering by batch_id."""
# Get a batch_id from the test data
response_all = client.get("/admin/documents?upload_source=api")
batch_id = response_all.json()["documents"][0]["batch_id"]
response = client.get(f"/admin/documents?batch_id={batch_id}")
assert response.status_code == 200
data = response.json()
assert data["total"] == 1
def test_list_documents_combined_filters(self, client):
"""Test combining multiple filters."""
response = client.get(
"/admin/documents?status=labeled&upload_source=api"
)
assert response.status_code == 200
data = response.json()
assert data["total"] == 1
assert data["documents"][0]["filename"] == "INV002.pdf"
def test_document_item_includes_new_fields(self, client):
"""Test DocumentItem includes new Phase 2/3 fields."""
response = client.get("/admin/documents?upload_source=api")
assert response.status_code == 200
data = response.json()
doc = data["documents"][0]
# Check new fields exist
assert "upload_source" in doc
assert doc["upload_source"] == "api"
assert "batch_id" in doc
assert doc["batch_id"] is not None
assert "can_annotate" in doc
assert isinstance(doc["can_annotate"], bool)
class TestEnhancedDocumentDetail:
"""Tests for enhanced document detail endpoint."""
def test_document_detail_includes_new_fields(self, client, app):
"""Test DocumentDetailResponse includes new Phase 2/3 fields."""
# Get a document ID from list
response = client.get("/admin/documents?upload_source=api")
assert response.status_code == 200
doc_list = response.json()
document_id = doc_list["documents"][0]["document_id"]
# Get document detail
response = client.get(f"/admin/documents/{document_id}")
assert response.status_code == 200
doc = response.json()
# Check new fields exist
assert "upload_source" in doc
assert doc["upload_source"] == "api"
assert "batch_id" in doc
assert doc["batch_id"] is not None
assert "can_annotate" in doc
assert isinstance(doc["can_annotate"], bool)
assert "csv_field_values" in doc
assert "annotation_lock_until" in doc
def test_document_detail_ui_upload_defaults(self, client, app):
"""Test UI-uploaded document has correct defaults."""
# Get a UI-uploaded document
response = client.get("/admin/documents?upload_source=ui")
assert response.status_code == 200
doc_list = response.json()
document_id = doc_list["documents"][0]["document_id"]
# Get document detail
response = client.get(f"/admin/documents/{document_id}")
assert response.status_code == 200
doc = response.json()
# UI uploads should have these defaults
assert doc["upload_source"] == "ui"
assert doc["batch_id"] is None
assert doc["csv_field_values"] is None
assert doc["can_annotate"] is True
assert doc["annotation_lock_until"] is None
def test_document_detail_with_annotations(self, client, app):
"""Test document detail includes annotations."""
# Get a document with annotations
response = client.get("/admin/documents?has_annotations=true")
assert response.status_code == 200
doc_list = response.json()
document_id = doc_list["documents"][0]["document_id"]
# Get document detail
response = client.get(f"/admin/documents/{document_id}")
assert response.status_code == 200
doc = response.json()
# Should have annotations
assert "annotations" in doc
assert len(doc["annotations"]) > 0


@@ -0,0 +1,247 @@
"""
Tests for Admin Training Routes and Scheduler.
"""
import pytest
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch
from uuid import UUID
from src.data.admin_models import TrainingTask, TrainingLog
from src.web.api.v1.admin.training import _validate_uuid, create_training_router
from src.web.core.scheduler import (
TrainingScheduler,
get_training_scheduler,
start_scheduler,
stop_scheduler,
)
from src.web.schemas.admin import (
TrainingConfig,
TrainingStatus,
TrainingTaskCreate,
TrainingType,
)
# Test UUIDs
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002"
TEST_TOKEN = "test-admin-token-12345"
class TestTrainingRouterCreation:
"""Tests for training router creation."""
def test_creates_router_with_endpoints(self):
"""Test router is created with expected endpoints."""
router = create_training_router()
# Get route paths (include prefix)
paths = [route.path for route in router.routes]
# Paths include the /admin/training prefix
assert any("/tasks" in p for p in paths)
assert any("{task_id}" in p for p in paths)
assert any("cancel" in p for p in paths)
assert any("logs" in p for p in paths)
assert any("export" in p for p in paths)
class TestTrainingConfigSchema:
"""Tests for TrainingConfig schema."""
def test_default_config(self):
"""Test default training configuration."""
config = TrainingConfig()
assert config.model_name == "yolo11n.pt"
assert config.epochs == 100
assert config.batch_size == 16
assert config.image_size == 640
assert config.learning_rate == 0.01
assert config.device == "0"
def test_custom_config(self):
"""Test custom training configuration."""
config = TrainingConfig(
model_name="yolo11s.pt",
epochs=50,
batch_size=8,
image_size=416,
learning_rate=0.001,
device="cpu",
)
assert config.model_name == "yolo11s.pt"
assert config.epochs == 50
assert config.batch_size == 8
def test_config_validation(self):
"""Test config validation constraints."""
# Epochs must be 1-1000
config = TrainingConfig(epochs=1)
assert config.epochs == 1
config = TrainingConfig(epochs=1000)
assert config.epochs == 1000
class TestTrainingTaskCreateSchema:
"""Tests for TrainingTaskCreate schema."""
def test_minimal_task(self):
"""Test minimal task creation."""
task = TrainingTaskCreate(name="Test Training")
assert task.name == "Test Training"
assert task.task_type == TrainingType.TRAIN
assert task.description is None
assert task.scheduled_at is None
def test_scheduled_task(self):
"""Test scheduled task creation."""
scheduled_time = datetime.utcnow() + timedelta(hours=1)
task = TrainingTaskCreate(
name="Scheduled Training",
scheduled_at=scheduled_time,
)
assert task.scheduled_at == scheduled_time
def test_recurring_task(self):
"""Test recurring task with cron expression."""
task = TrainingTaskCreate(
name="Recurring Training",
cron_expression="0 0 * * 0", # Every Sunday at midnight
)
assert task.cron_expression == "0 0 * * 0"
class TestTrainingTaskModel:
"""Tests for TrainingTask model."""
def test_task_creation(self):
"""Test training task model creation."""
task = TrainingTask(
admin_token=TEST_TOKEN,
name="Test Task",
task_type="train",
status="pending",
)
assert task.name == "Test Task"
assert task.task_type == "train"
assert task.status == "pending"
def test_task_with_config(self):
"""Test task with configuration."""
config = {
"model_name": "yolo11n.pt",
"epochs": 100,
}
task = TrainingTask(
admin_token=TEST_TOKEN,
name="Configured Task",
task_type="train",
config=config,
)
assert task.config == config
assert task.config["epochs"] == 100
class TestTrainingLogModel:
"""Tests for TrainingLog model."""
def test_log_creation(self):
"""Test training log creation."""
log = TrainingLog(
task_id=UUID(TEST_TASK_UUID),
level="INFO",
message="Training started",
)
assert str(log.task_id) == TEST_TASK_UUID
assert log.level == "INFO"
assert log.message == "Training started"
def test_log_with_details(self):
"""Test log with additional details."""
details = {
"epoch": 10,
"loss": 0.5,
"mAP": 0.85,
}
log = TrainingLog(
task_id=UUID(TEST_TASK_UUID),
level="INFO",
message="Epoch completed",
details=details,
)
assert log.details == details
assert log.details["epoch"] == 10
class TestTrainingScheduler:
"""Tests for TrainingScheduler."""
@pytest.fixture
def scheduler(self):
"""Create a scheduler for testing."""
return TrainingScheduler(check_interval_seconds=1)
def test_scheduler_creation(self, scheduler):
"""Test scheduler creation."""
assert scheduler._check_interval == 1
assert scheduler._running is False
assert scheduler._thread is None
def test_scheduler_start_stop(self, scheduler):
"""Test scheduler start and stop."""
with patch.object(scheduler, "_check_pending_tasks"):
scheduler.start()
assert scheduler._running is True
assert scheduler._thread is not None
scheduler.stop()
assert scheduler._running is False
def test_scheduler_singleton(self):
"""Test get_training_scheduler returns singleton."""
# Reset any existing scheduler
stop_scheduler()
s1 = get_training_scheduler()
s2 = get_training_scheduler()
assert s1 is s2
# Cleanup
stop_scheduler()
class TestTrainingStatusEnum:
"""Tests for TrainingStatus enum."""
def test_all_statuses(self):
"""Test all training statuses are defined."""
statuses = [s.value for s in TrainingStatus]
assert "pending" in statuses
assert "scheduled" in statuses
assert "running" in statuses
assert "completed" in statuses
assert "failed" in statuses
assert "cancelled" in statuses
class TestTrainingTypeEnum:
"""Tests for TrainingType enum."""
def test_all_types(self):
"""Test all training types are defined."""
types = [t.value for t in TrainingType]
assert "train" in types
assert "finetune" in types


@@ -0,0 +1,276 @@
"""
Tests for Annotation Lock Mechanism (Phase 3.3).
"""
import pytest
from datetime import datetime, timedelta, timezone
from uuid import uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
from src.web.api.v1.admin.documents import create_admin_router
from src.web.core.auth import validate_admin_token, get_admin_db
class MockAdminDocument:
"""Mock AdminDocument for testing."""
def __init__(self, **kwargs):
self.document_id = kwargs.get('document_id', uuid4())
self.admin_token = kwargs.get('admin_token', 'test-token')
self.filename = kwargs.get('filename', 'test.pdf')
self.file_size = kwargs.get('file_size', 100000)
self.content_type = kwargs.get('content_type', 'application/pdf')
self.page_count = kwargs.get('page_count', 1)
self.status = kwargs.get('status', 'pending')
self.auto_label_status = kwargs.get('auto_label_status', None)
self.auto_label_error = kwargs.get('auto_label_error', None)
self.upload_source = kwargs.get('upload_source', 'ui')
self.batch_id = kwargs.get('batch_id', None)
self.csv_field_values = kwargs.get('csv_field_values', None)
self.annotation_lock_until = kwargs.get('annotation_lock_until', None)
self.created_at = kwargs.get('created_at', datetime.utcnow())
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
class MockAdminDB:
"""Mock AdminDB for testing annotation locks."""
def __init__(self):
self.documents = {}
def get_document_by_token(self, document_id, admin_token):
"""Get single document by ID and token."""
doc = self.documents.get(document_id)
if doc and doc.admin_token == admin_token:
return doc
return None
def acquire_annotation_lock(self, document_id, admin_token, duration_seconds=300):
"""Acquire annotation lock for a document."""
doc = self.documents.get(document_id)
if not doc or doc.admin_token != admin_token:
return None
# Check if already locked
now = datetime.now(timezone.utc)
if doc.annotation_lock_until and doc.annotation_lock_until > now:
return None
# Acquire lock
doc.annotation_lock_until = now + timedelta(seconds=duration_seconds)
return doc
def release_annotation_lock(self, document_id, admin_token, force=False):
"""Release annotation lock for a document."""
doc = self.documents.get(document_id)
if not doc or doc.admin_token != admin_token:
return None
# Release lock
doc.annotation_lock_until = None
return doc
def extend_annotation_lock(self, document_id, admin_token, additional_seconds=300):
"""Extend an existing annotation lock."""
doc = self.documents.get(document_id)
if not doc or doc.admin_token != admin_token:
return None
# Check if lock exists and is still valid
now = datetime.now(timezone.utc)
if not doc.annotation_lock_until or doc.annotation_lock_until <= now:
return None
# Extend lock
doc.annotation_lock_until = doc.annotation_lock_until + timedelta(seconds=additional_seconds)
return doc
@pytest.fixture
def app():
"""Create test FastAPI app."""
app = FastAPI()
# Create mock DB
mock_db = MockAdminDB()
# Add test document
doc1 = MockAdminDocument(
filename="INV001.pdf",
status="pending",
upload_source="ui",
)
mock_db.documents[str(doc1.document_id)] = doc1
# Override dependencies
app.dependency_overrides[validate_admin_token] = lambda: "test-token"
app.dependency_overrides[get_admin_db] = lambda: mock_db
# Include router
router = create_admin_router((".pdf", ".png", ".jpg"))
app.include_router(router)
return app
@pytest.fixture
def client(app):
"""Create test client."""
return TestClient(app)
@pytest.fixture
def document_id(app):
"""Get document ID from the mock DB."""
mock_db = app.dependency_overrides[get_admin_db]()
return str(list(mock_db.documents.keys())[0])
class TestAnnotationLocks:
"""Tests for annotation lock endpoints."""
def test_acquire_lock_success(self, client, document_id):
"""Test successfully acquiring an annotation lock."""
response = client.post(
f"/admin/documents/{document_id}/lock",
json={"duration_seconds": 300}
)
assert response.status_code == 200
data = response.json()
assert data["document_id"] == document_id
assert data["locked"] is True
assert data["lock_expires_at"] is not None
assert "Lock acquired for 300 seconds" in data["message"]
def test_acquire_lock_already_locked(self, client, document_id):
"""Test acquiring lock on already locked document."""
# First lock
response1 = client.post(
f"/admin/documents/{document_id}/lock",
json={"duration_seconds": 300}
)
assert response1.status_code == 200
# Try to lock again
response2 = client.post(
f"/admin/documents/{document_id}/lock",
json={"duration_seconds": 300}
)
assert response2.status_code == 409
assert "already locked" in response2.json()["detail"]
def test_release_lock_success(self, client, document_id):
"""Test successfully releasing an annotation lock."""
# First acquire lock
client.post(
f"/admin/documents/{document_id}/lock",
json={"duration_seconds": 300}
)
# Then release it
response = client.delete(f"/admin/documents/{document_id}/lock")
assert response.status_code == 200
data = response.json()
assert data["document_id"] == document_id
assert data["locked"] is False
assert data["lock_expires_at"] is None
assert "released successfully" in data["message"]
def test_release_lock_not_locked(self, client, document_id):
"""Test releasing lock on unlocked document."""
response = client.delete(f"/admin/documents/{document_id}/lock")
# Should succeed even if not locked
assert response.status_code == 200
data = response.json()
assert data["locked"] is False
def test_extend_lock_success(self, client, document_id):
"""Test successfully extending an annotation lock."""
# First acquire lock
response1 = client.post(
f"/admin/documents/{document_id}/lock",
json={"duration_seconds": 300}
)
original_expiry = response1.json()["lock_expires_at"]
# Extend lock
response2 = client.patch(
f"/admin/documents/{document_id}/lock",
json={"duration_seconds": 300}
)
assert response2.status_code == 200
data = response2.json()
assert data["document_id"] == document_id
assert data["locked"] is True
assert data["lock_expires_at"] != original_expiry
assert "extended by 300 seconds" in data["message"]
def test_extend_lock_not_locked(self, client, document_id):
"""Test extending lock on unlocked document."""
response = client.patch(
f"/admin/documents/{document_id}/lock",
json={"duration_seconds": 300}
)
assert response.status_code == 409
assert "doesn't exist or has expired" in response.json()["detail"]
def test_acquire_lock_custom_duration(self, client, document_id):
"""Test acquiring lock with custom duration."""
response = client.post(
f"/admin/documents/{document_id}/lock",
json={"duration_seconds": 600}
)
assert response.status_code == 200
data = response.json()
assert "Lock acquired for 600 seconds" in data["message"]
def test_acquire_lock_invalid_document(self, client):
"""Test acquiring lock on non-existent document."""
fake_id = str(uuid4())
response = client.post(
f"/admin/documents/{fake_id}/lock",
json={"duration_seconds": 300}
)
assert response.status_code == 404
assert "not found" in response.json()["detail"]
def test_lock_lifecycle(self, client, document_id):
"""Test complete lock lifecycle: acquire -> extend -> release."""
# Acquire
response1 = client.post(
f"/admin/documents/{document_id}/lock",
json={"duration_seconds": 300}
)
assert response1.status_code == 200
assert response1.json()["locked"] is True
# Extend
response2 = client.patch(
f"/admin/documents/{document_id}/lock",
json={"duration_seconds": 300}
)
assert response2.status_code == 200
assert response2.json()["locked"] is True
# Release
response3 = client.delete(f"/admin/documents/{document_id}/lock")
assert response3.status_code == 200
assert response3.json()["locked"] is False
# Verify can acquire again after release
response4 = client.post(
f"/admin/documents/{document_id}/lock",
json={"duration_seconds": 300}
)
assert response4.status_code == 200
assert response4.json()["locked"] is True


@@ -0,0 +1,420 @@
"""
Tests for Phase 5: Annotation Enhancement (Verification and Override)
"""
import pytest
from datetime import datetime
from uuid import uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
from src.web.api.v1.admin.annotations import create_annotation_router
from src.web.core.auth import validate_admin_token, get_admin_db
class MockAdminDocument:
"""Mock AdminDocument for testing."""
def __init__(self, **kwargs):
self.document_id = kwargs.get('document_id', uuid4())
self.admin_token = kwargs.get('admin_token', 'test-token')
self.filename = kwargs.get('filename', 'test.pdf')
self.file_size = kwargs.get('file_size', 100000)
self.content_type = kwargs.get('content_type', 'application/pdf')
self.page_count = kwargs.get('page_count', 1)
self.status = kwargs.get('status', 'labeled')
self.auto_label_status = kwargs.get('auto_label_status', None)
self.created_at = kwargs.get('created_at', datetime.utcnow())
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
class MockAnnotation:
"""Mock AdminAnnotation for testing."""
def __init__(self, **kwargs):
self.annotation_id = kwargs.get('annotation_id', uuid4())
self.document_id = kwargs.get('document_id')
self.page_number = kwargs.get('page_number', 1)
self.class_id = kwargs.get('class_id', 0)
self.class_name = kwargs.get('class_name', 'invoice_number')
self.bbox_x = kwargs.get('bbox_x', 100)
self.bbox_y = kwargs.get('bbox_y', 100)
self.bbox_width = kwargs.get('bbox_width', 200)
self.bbox_height = kwargs.get('bbox_height', 50)
self.x_center = kwargs.get('x_center', 0.5)
self.y_center = kwargs.get('y_center', 0.5)
self.width = kwargs.get('width', 0.3)
self.height = kwargs.get('height', 0.1)
self.text_value = kwargs.get('text_value', 'INV-001')
self.confidence = kwargs.get('confidence', 0.95)
self.source = kwargs.get('source', 'auto')
self.is_verified = kwargs.get('is_verified', False)
self.verified_at = kwargs.get('verified_at', None)
self.verified_by = kwargs.get('verified_by', None)
self.override_source = kwargs.get('override_source', None)
self.original_annotation_id = kwargs.get('original_annotation_id', None)
self.created_at = kwargs.get('created_at', datetime.utcnow())
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
class MockAnnotationHistory:
"""Mock AnnotationHistory for testing."""
def __init__(self, **kwargs):
self.history_id = kwargs.get('history_id', uuid4())
self.annotation_id = kwargs.get('annotation_id')
self.document_id = kwargs.get('document_id')
self.action = kwargs.get('action', 'override')
self.previous_value = kwargs.get('previous_value', {})
self.new_value = kwargs.get('new_value', {})
self.changed_by = kwargs.get('changed_by', 'test-token')
self.change_reason = kwargs.get('change_reason', None)
self.created_at = kwargs.get('created_at', datetime.utcnow())
class MockAdminDB:
"""Mock AdminDB for testing Phase 5."""
def __init__(self):
self.documents = {}
self.annotations = {}
self.annotation_history = {}
def get_document_by_token(self, document_id, admin_token):
"""Get document by ID and token."""
doc = self.documents.get(str(document_id))
if doc and doc.admin_token == admin_token:
return doc
return None
def verify_annotation(self, annotation_id, admin_token):
"""Mark annotation as verified."""
annotation = self.annotations.get(str(annotation_id))
if annotation:
annotation.is_verified = True
annotation.verified_at = datetime.utcnow()
annotation.verified_by = admin_token
return annotation
return None
def override_annotation(
self,
annotation_id,
admin_token,
change_reason=None,
**updates,
):
"""Override an annotation."""
annotation = self.annotations.get(str(annotation_id))
if annotation:
# Apply updates
for key, value in updates.items():
if hasattr(annotation, key):
setattr(annotation, key, value)
# Mark as overridden if was auto-generated
if annotation.source == "auto":
annotation.override_source = "auto"
annotation.source = "manual"
# Create history record
history = MockAnnotationHistory(
annotation_id=annotation.annotation_id,  # reference the overridden annotation, not a fresh UUID
document_id=annotation.document_id,
action="override",
changed_by=admin_token,
change_reason=change_reason,
)
self.annotation_history[str(annotation.annotation_id)] = [history]
return annotation
return None
def get_annotation_history(self, annotation_id):
"""Get annotation history."""
return self.annotation_history.get(str(annotation_id), [])
@pytest.fixture
def app():
"""Create test FastAPI app."""
app = FastAPI()
# Create mock DB
mock_db = MockAdminDB()
# Add test document
doc1 = MockAdminDocument(
filename="TEST001.pdf",
status="labeled",
)
mock_db.documents[str(doc1.document_id)] = doc1
# Add test annotations
ann1 = MockAnnotation(
document_id=doc1.document_id,
class_id=0,
class_name="invoice_number",
text_value="INV-001",
source="auto",
confidence=0.95,
)
ann2 = MockAnnotation(
document_id=doc1.document_id,
class_id=6,
class_name="amount",
text_value="1500.00",
source="auto",
confidence=0.98,
)
mock_db.annotations[str(ann1.annotation_id)] = ann1
mock_db.annotations[str(ann2.annotation_id)] = ann2
# Store document ID and annotation IDs for tests
app.state.document_id = str(doc1.document_id)
app.state.annotation_id_1 = str(ann1.annotation_id)
app.state.annotation_id_2 = str(ann2.annotation_id)
# Override dependencies
app.dependency_overrides[validate_admin_token] = lambda: "test-token"
app.dependency_overrides[get_admin_db] = lambda: mock_db
# Include router
router = create_annotation_router()
app.include_router(router)
return app
@pytest.fixture
def client(app):
"""Create test client."""
return TestClient(app)
class TestAnnotationVerification:
"""Tests for POST /admin/documents/{document_id}/annotations/{annotation_id}/verify endpoint."""
def test_verify_annotation_success(self, client, app):
"""Test successfully verifying an annotation."""
document_id = app.state.document_id
annotation_id = app.state.annotation_id_1
response = client.post(
f"/admin/documents/{document_id}/annotations/{annotation_id}/verify"
)
assert response.status_code == 200
data = response.json()
assert data["annotation_id"] == annotation_id
assert data["is_verified"] is True
assert data["verified_at"] is not None
assert data["verified_by"] == "test-token"
assert "verified successfully" in data["message"].lower()
def test_verify_annotation_not_found(self, client, app):
"""Test verifying non-existent annotation."""
document_id = app.state.document_id
fake_annotation_id = str(uuid4())
response = client.post(
f"/admin/documents/{document_id}/annotations/{fake_annotation_id}/verify"
)
assert response.status_code == 404
assert "not found" in response.json()["detail"].lower()
def test_verify_annotation_document_not_found(self, client):
"""Test verifying annotation with non-existent document."""
fake_document_id = str(uuid4())
fake_annotation_id = str(uuid4())
response = client.post(
f"/admin/documents/{fake_document_id}/annotations/{fake_annotation_id}/verify"
)
assert response.status_code == 404
assert "not found" in response.json()["detail"].lower()
def test_verify_annotation_invalid_uuid(self, client, app):
"""Test verifying annotation with invalid UUID format."""
document_id = app.state.document_id
response = client.post(
f"/admin/documents/{document_id}/annotations/invalid-uuid/verify"
)
assert response.status_code == 400
assert "invalid" in response.json()["detail"].lower()
class TestAnnotationOverride:
"""Tests for PATCH /admin/documents/{document_id}/annotations/{annotation_id}/override endpoint."""
def test_override_annotation_text_value(self, client, app):
"""Test overriding annotation text value."""
document_id = app.state.document_id
annotation_id = app.state.annotation_id_1
response = client.patch(
f"/admin/documents/{document_id}/annotations/{annotation_id}/override",
json={
"text_value": "INV-001-CORRECTED",
"reason": "OCR error correction"
}
)
assert response.status_code == 200
data = response.json()
assert data["annotation_id"] == annotation_id
assert data["source"] == "manual"
assert data["override_source"] == "auto"
assert "successfully" in data["message"].lower()
assert "history_id" in data
def test_override_annotation_bbox(self, client, app):
"""Test overriding annotation bounding box."""
document_id = app.state.document_id
annotation_id = app.state.annotation_id_1
response = client.patch(
f"/admin/documents/{document_id}/annotations/{annotation_id}/override",
json={
"bbox": {
"x": 110,
"y": 205,
"width": 195,
"height": 48
},
"reason": "Bbox adjustment"
}
)
assert response.status_code == 200
data = response.json()
assert data["annotation_id"] == annotation_id
assert data["source"] == "manual"
def test_override_annotation_class(self, client, app):
"""Test overriding annotation class."""
document_id = app.state.document_id
annotation_id = app.state.annotation_id_1
response = client.patch(
f"/admin/documents/{document_id}/annotations/{annotation_id}/override",
json={
"class_id": 1,
"class_name": "invoice_date",
"reason": "Wrong field classification"
}
)
assert response.status_code == 200
data = response.json()
assert data["annotation_id"] == annotation_id
def test_override_annotation_multiple_fields(self, client, app):
"""Test overriding multiple annotation fields at once."""
document_id = app.state.document_id
annotation_id = app.state.annotation_id_2
response = client.patch(
f"/admin/documents/{document_id}/annotations/{annotation_id}/override",
json={
"text_value": "1550.00",
"bbox": {
"x": 120,
"y": 210,
"width": 180,
"height": 45
},
"reason": "Multiple corrections"
}
)
assert response.status_code == 200
data = response.json()
assert data["annotation_id"] == annotation_id
def test_override_annotation_no_updates(self, client, app):
"""Test overriding annotation without providing any updates."""
document_id = app.state.document_id
annotation_id = app.state.annotation_id_1
response = client.patch(
f"/admin/documents/{document_id}/annotations/{annotation_id}/override",
json={}
)
assert response.status_code == 400
assert "no updates" in response.json()["detail"].lower()
def test_override_annotation_not_found(self, client, app):
"""Test overriding non-existent annotation."""
document_id = app.state.document_id
fake_annotation_id = str(uuid4())
response = client.patch(
f"/admin/documents/{document_id}/annotations/{fake_annotation_id}/override",
json={
"text_value": "TEST"
}
)
assert response.status_code == 404
assert "not found" in response.json()["detail"].lower()
def test_override_annotation_document_not_found(self, client):
"""Test overriding annotation with non-existent document."""
fake_document_id = str(uuid4())
fake_annotation_id = str(uuid4())
response = client.patch(
f"/admin/documents/{fake_document_id}/annotations/{fake_annotation_id}/override",
json={
"text_value": "TEST"
}
)
assert response.status_code == 404
assert "not found" in response.json()["detail"].lower()
def test_override_annotation_creates_history(self, client, app):
"""Test that overriding annotation creates history record."""
document_id = app.state.document_id
annotation_id = app.state.annotation_id_1
response = client.patch(
f"/admin/documents/{document_id}/annotations/{annotation_id}/override",
json={
"text_value": "INV-CORRECTED",
"reason": "Test history creation"
}
)
assert response.status_code == 200
data = response.json()
# History ID should be present and valid
assert "history_id" in data
assert data["history_id"] != ""
def test_override_annotation_with_reason(self, client, app):
"""Test overriding annotation with change reason."""
document_id = app.state.document_id
annotation_id = app.state.annotation_id_1
change_reason = "Correcting OCR misread"
response = client.patch(
f"/admin/documents/{document_id}/annotations/{annotation_id}/override",
json={
"text_value": "INV-002",
"reason": change_reason
}
)
assert response.status_code == 200
# Reason is stored in history, not returned in response
data = response.json()
assert data["annotation_id"] == annotation_id


@@ -0,0 +1,217 @@
"""
Tests for the AsyncTaskQueue class.
"""
import tempfile
import time
from datetime import datetime
from pathlib import Path
from threading import Event
from unittest.mock import MagicMock
import pytest
from src.web.workers.async_queue import AsyncTask, AsyncTaskQueue
class TestAsyncTask:
"""Tests for AsyncTask dataclass."""
def test_create_task(self):
"""Test creating an AsyncTask."""
task = AsyncTask(
request_id="test-id",
api_key="test-key",
file_path=Path("/tmp/test.pdf"),
filename="test.pdf",
)
assert task.request_id == "test-id"
assert task.api_key == "test-key"
assert task.filename == "test.pdf"
assert task.priority == 0
assert task.created_at is not None
class TestAsyncTaskQueue:
"""Tests for AsyncTaskQueue."""
def test_init(self):
"""Test queue initialization."""
queue = AsyncTaskQueue(max_size=50, worker_count=2)
assert queue._worker_count == 2
assert queue._queue.maxsize == 50
assert not queue._started
def test_submit_task(self, task_queue, sample_task):
"""Test submitting a task to the queue."""
success = task_queue.submit(sample_task)
assert success is True
assert task_queue.get_queue_depth() == 1
def test_submit_when_full(self, sample_task):
"""Test submitting to a full queue."""
queue = AsyncTaskQueue(max_size=1, worker_count=1)
# Submit first task
queue.submit(sample_task)
# Create second task
task2 = AsyncTask(
request_id="test-2",
api_key="test-key",
file_path=sample_task.file_path,
filename="test2.pdf",
)
# Queue should be full
success = queue.submit(task2)
assert success is False
def test_get_queue_depth(self, task_queue, sample_task):
"""Test getting queue depth."""
assert task_queue.get_queue_depth() == 0
task_queue.submit(sample_task)
assert task_queue.get_queue_depth() == 1
def test_start_and_stop(self, task_queue):
"""Test starting and stopping the queue."""
handler = MagicMock()
task_queue.start(handler)
assert task_queue._started is True
assert task_queue.is_running is True
assert len(task_queue._workers) == 1
task_queue.stop(timeout=5.0)
assert task_queue._started is False
assert task_queue.is_running is False
assert len(task_queue._workers) == 0
def test_worker_processes_task(self, sample_task):
"""Test that worker thread processes tasks."""
queue = AsyncTaskQueue(max_size=10, worker_count=1)
processed = Event()
def handler(task):
processed.set()
queue.start(handler)
queue.submit(sample_task)
# Wait for processing
assert processed.wait(timeout=5.0)
queue.stop()
def test_worker_handles_errors(self, sample_task):
"""Test that worker handles errors gracefully."""
queue = AsyncTaskQueue(max_size=10, worker_count=1)
error_handled = Event()
def failing_handler(task):
error_handled.set()
raise ValueError("Test error")
queue.start(failing_handler)
queue.submit(sample_task)
# Should not crash
assert error_handled.wait(timeout=5.0)
time.sleep(0.5) # Give time for error handling
assert queue.is_running
queue.stop()
def test_processing_tracking(self, task_queue, sample_task):
"""Test tracking of processing tasks."""
processed = Event()
def slow_handler(task):
processed.set()
time.sleep(0.5)
task_queue.start(slow_handler)
task_queue.submit(sample_task)
# Wait for processing to start
assert processed.wait(timeout=5.0)
# Task should be in processing set
assert task_queue.get_processing_count() == 1
assert task_queue.is_processing(sample_task.request_id)
# Wait for completion
time.sleep(1.0)
assert task_queue.get_processing_count() == 0
assert not task_queue.is_processing(sample_task.request_id)
task_queue.stop()
def test_multiple_workers(self, sample_task):
"""Test queue with multiple workers."""
queue = AsyncTaskQueue(max_size=10, worker_count=3)
processed_count = []
def handler(task):
processed_count.append(task.request_id)
time.sleep(0.1)
queue.start(handler)
# Submit multiple tasks
for i in range(5):
task = AsyncTask(
request_id=f"task-{i}",
api_key="test-key",
file_path=sample_task.file_path,
filename=f"test-{i}.pdf",
)
queue.submit(task)
# Wait for all tasks
time.sleep(2.0)
assert len(processed_count) == 5
queue.stop()
def test_graceful_shutdown(self, sample_task):
"""Test graceful shutdown waits for current task."""
queue = AsyncTaskQueue(max_size=10, worker_count=1)
started = Event()
finished = Event()
def slow_handler(task):
started.set()
time.sleep(0.5)
finished.set()
queue.start(slow_handler)
queue.submit(sample_task)
# Wait for processing to start
assert started.wait(timeout=5.0)
# Stop should wait for task to finish
queue.stop(timeout=5.0)
assert finished.is_set()
def test_double_start(self, task_queue):
"""Test that starting twice doesn't create duplicate workers."""
handler = MagicMock()
task_queue.start(handler)
assert len(task_queue._workers) == 1
# Starting again should not add more workers
task_queue.start(handler)
assert len(task_queue._workers) == 1
task_queue.stop()

View File

@@ -0,0 +1,409 @@
"""
Tests for the async API routes.
"""
import tempfile
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from src.data.async_request_db import ApiKeyConfig, AsyncRequestDB
from src.data.models import AsyncRequest
from src.web.api.v1.async_api.routes import create_async_router, set_async_service
from src.web.services.async_processing import AsyncSubmitResult
from src.web.dependencies import init_dependencies
from src.web.core.rate_limiter import RateLimiter, RateLimitStatus
from src.web.schemas.inference import AsyncStatus
# Valid UUID for testing
TEST_REQUEST_UUID = "550e8400-e29b-41d4-a716-446655440000"
INVALID_UUID = "nonexistent-id"
@pytest.fixture
def mock_async_service():
"""Create a mock AsyncProcessingService."""
service = MagicMock()
# Mock config
mock_config = MagicMock()
mock_config.max_file_size_mb = 50
service._async_config = mock_config
# Default submit result
service.submit_request.return_value = AsyncSubmitResult(
success=True,
request_id="test-request-id",
estimated_wait_seconds=30,
)
return service
@pytest.fixture
def mock_rate_limiter(mock_db):
"""Create a mock RateLimiter."""
limiter = MagicMock(spec=RateLimiter)
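    # spec=RateLimiter restricts the mock to the real interface, so a misspelled method raises instead of silently passing.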
# Default: allow all requests
limiter.check_submit_limit.return_value = RateLimitStatus(
allowed=True,
remaining_requests=9,
reset_at=datetime.utcnow() + timedelta(seconds=60),
)
limiter.check_poll_limit.return_value = RateLimitStatus(
allowed=True,
remaining_requests=999,
reset_at=datetime.utcnow(),
)
limiter.get_rate_limit_headers.return_value = {}
return limiter
@pytest.fixture
def app(mock_db, mock_rate_limiter, mock_async_service):
"""Create a test FastAPI app with async routes."""
app = FastAPI()
# Initialize dependencies
init_dependencies(mock_db, mock_rate_limiter)
set_async_service(mock_async_service)
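    # Wire the mocks into the dependency layer before building the router so handlers resolve them at request time.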
# Add routes
router = create_async_router(allowed_extensions=(".pdf", ".png", ".jpg", ".jpeg"))
app.include_router(router, prefix="/api/v1")
return app
@pytest.fixture
def client(app):
"""Create a test client."""
return TestClient(app)
class TestAsyncSubmitEndpoint:
"""Tests for POST /api/v1/async/submit."""
def test_submit_success(self, client, mock_async_service):
"""Test successful submission."""
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
f.write(b"fake pdf content")
f.seek(0)
response = client.post(
"/api/v1/async/submit",
files={"file": ("test.pdf", f, "application/pdf")},
headers={"X-API-Key": "test-api-key"},
)
assert response.status_code == 200
data = response.json()
assert data["status"] == "accepted"
assert data["request_id"] == "test-request-id"
assert "poll_url" in data
def test_submit_missing_api_key(self, client):
"""Test submission without API key."""
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
f.write(b"fake pdf content")
f.seek(0)
response = client.post(
"/api/v1/async/submit",
files={"file": ("test.pdf", f, "application/pdf")},
)
assert response.status_code == 401
assert "X-API-Key" in response.json()["detail"]
def test_submit_invalid_api_key(self, client, mock_db):
"""Test submission with invalid API key."""
mock_db.is_valid_api_key.return_value = False
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
f.write(b"fake pdf content")
f.seek(0)
response = client.post(
"/api/v1/async/submit",
files={"file": ("test.pdf", f, "application/pdf")},
headers={"X-API-Key": "invalid-key"},
)
assert response.status_code == 401
def test_submit_unsupported_file_type(self, client):
"""Test submission with unsupported file type."""
with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as f:
f.write(b"text content")
f.seek(0)
response = client.post(
"/api/v1/async/submit",
files={"file": ("test.txt", f, "text/plain")},
headers={"X-API-Key": "test-api-key"},
)
assert response.status_code == 400
assert "Unsupported file type" in response.json()["detail"]
def test_submit_rate_limited(self, client, mock_rate_limiter):
"""Test submission when rate limited."""
mock_rate_limiter.check_submit_limit.return_value = RateLimitStatus(
allowed=False,
remaining_requests=0,
reset_at=datetime.utcnow() + timedelta(seconds=30),
retry_after_seconds=30,
reason="Rate limit exceeded",
)
mock_rate_limiter.get_rate_limit_headers.return_value = {"Retry-After": "30"}
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
f.write(b"fake pdf content")
f.seek(0)
response = client.post(
"/api/v1/async/submit",
files={"file": ("test.pdf", f, "application/pdf")},
headers={"X-API-Key": "test-api-key"},
)
assert response.status_code == 429
assert "Retry-After" in response.headers
def test_submit_queue_full(self, client, mock_async_service):
"""Test submission when queue is full."""
mock_async_service.submit_request.return_value = AsyncSubmitResult(
success=False,
request_id="test-id",
error="Processing queue is full",
)
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
f.write(b"fake pdf content")
f.seek(0)
response = client.post(
"/api/v1/async/submit",
files={"file": ("test.pdf", f, "application/pdf")},
headers={"X-API-Key": "test-api-key"},
)
assert response.status_code == 503
class TestAsyncStatusEndpoint:
"""Tests for GET /api/v1/async/status/{request_id}."""
def test_get_status_pending(self, client, mock_db, sample_async_request):
"""Test getting status of pending request."""
mock_db.get_request_by_api_key.return_value = sample_async_request
mock_db.get_queue_position.return_value = 3
response = client.get(
"/api/v1/async/status/550e8400-e29b-41d4-a716-446655440000",
headers={"X-API-Key": "test-api-key"},
)
assert response.status_code == 200
data = response.json()
assert data["status"] == "pending"
assert data["position_in_queue"] == 3
assert data["result_url"] is None
def test_get_status_completed(self, client, mock_db, sample_async_request):
"""Test getting status of completed request."""
sample_async_request.status = "completed"
sample_async_request.completed_at = datetime.utcnow()
mock_db.get_request_by_api_key.return_value = sample_async_request
response = client.get(
"/api/v1/async/status/550e8400-e29b-41d4-a716-446655440000",
headers={"X-API-Key": "test-api-key"},
)
assert response.status_code == 200
data = response.json()
assert data["status"] == "completed"
assert data["result_url"] is not None
def test_get_status_not_found(self, client, mock_db):
"""Test getting status of non-existent request."""
mock_db.get_request_by_api_key.return_value = None
response = client.get(
"/api/v1/async/status/00000000-0000-0000-0000-000000000000",
headers={"X-API-Key": "test-api-key"},
)
assert response.status_code == 404
def test_get_status_wrong_api_key(self, client, mock_db, sample_async_request):
"""Test that requests are isolated by API key."""
# Request belongs to different API key
mock_db.get_request_by_api_key.return_value = None
response = client.get(
"/api/v1/async/status/550e8400-e29b-41d4-a716-446655440000",
headers={"X-API-Key": "different-api-key"},
)
assert response.status_code == 404
class TestAsyncResultEndpoint:
"""Tests for GET /api/v1/async/result/{request_id}."""
def test_get_result_completed(self, client, mock_db, sample_async_request):
"""Test getting result of completed request."""
sample_async_request.status = "completed"
sample_async_request.completed_at = datetime.utcnow()
sample_async_request.processing_time_ms = 1234.5
sample_async_request.result = {
"document_id": "test-doc",
"success": True,
"document_type": "invoice",
"fields": {"InvoiceNumber": "12345"},
"confidence": {"InvoiceNumber": 0.95},
"detections": [],
"errors": [],
}
mock_db.get_request_by_api_key.return_value = sample_async_request
response = client.get(
"/api/v1/async/result/550e8400-e29b-41d4-a716-446655440000",
headers={"X-API-Key": "test-api-key"},
)
assert response.status_code == 200
data = response.json()
assert data["status"] == "completed"
assert data["result"] is not None
assert data["result"]["fields"]["InvoiceNumber"] == "12345"
def test_get_result_not_completed(self, client, mock_db, sample_async_request):
"""Test getting result of pending request."""
mock_db.get_request_by_api_key.return_value = sample_async_request
response = client.get(
"/api/v1/async/result/550e8400-e29b-41d4-a716-446655440000",
headers={"X-API-Key": "test-api-key"},
)
assert response.status_code == 409
assert "not yet completed" in response.json()["detail"]
def test_get_result_failed(self, client, mock_db, sample_async_request):
"""Test getting result of failed request."""
sample_async_request.status = "failed"
sample_async_request.error_message = "Processing failed"
sample_async_request.processing_time_ms = 500.0
mock_db.get_request_by_api_key.return_value = sample_async_request
response = client.get(
"/api/v1/async/result/550e8400-e29b-41d4-a716-446655440000",
headers={"X-API-Key": "test-api-key"},
)
assert response.status_code == 200
data = response.json()
assert data["status"] == "failed"
class TestAsyncListEndpoint:
"""Tests for GET /api/v1/async/requests."""
def test_list_requests(self, client, mock_db, sample_async_request):
"""Test listing requests."""
mock_db.get_requests_by_api_key.return_value = ([sample_async_request], 1)
response = client.get(
"/api/v1/async/requests",
headers={"X-API-Key": "test-api-key"},
)
assert response.status_code == 200
data = response.json()
assert data["total"] == 1
assert len(data["requests"]) == 1
def test_list_requests_with_status_filter(self, client, mock_db):
"""Test listing requests with status filter."""
mock_db.get_requests_by_api_key.return_value = ([], 0)
response = client.get(
"/api/v1/async/requests?status=completed",
headers={"X-API-Key": "test-api-key"},
)
assert response.status_code == 200
mock_db.get_requests_by_api_key.assert_called_once()
call_kwargs = mock_db.get_requests_by_api_key.call_args[1]
assert call_kwargs["status"] == "completed"
def test_list_requests_pagination(self, client, mock_db):
"""Test listing requests with pagination."""
mock_db.get_requests_by_api_key.return_value = ([], 0)
response = client.get(
"/api/v1/async/requests?limit=50&offset=10",
headers={"X-API-Key": "test-api-key"},
)
assert response.status_code == 200
call_kwargs = mock_db.get_requests_by_api_key.call_args[1]
assert call_kwargs["limit"] == 50
assert call_kwargs["offset"] == 10
def test_list_requests_invalid_status(self, client, mock_db):
"""Test listing with invalid status filter."""
response = client.get(
"/api/v1/async/requests?status=invalid",
headers={"X-API-Key": "test-api-key"},
)
assert response.status_code == 400
class TestAsyncDeleteEndpoint:
"""Tests for DELETE /api/v1/async/requests/{request_id}."""
def test_delete_pending_request(self, client, mock_db, sample_async_request):
"""Test deleting a pending request."""
mock_db.get_request_by_api_key.return_value = sample_async_request
response = client.delete(
"/api/v1/async/requests/550e8400-e29b-41d4-a716-446655440000",
headers={"X-API-Key": "test-api-key"},
)
assert response.status_code == 200
assert response.json()["status"] == "deleted"
def test_delete_processing_request(self, client, mock_db, sample_async_request):
"""Test that processing requests cannot be deleted."""
sample_async_request.status = "processing"
mock_db.get_request_by_api_key.return_value = sample_async_request
response = client.delete(
"/api/v1/async/requests/550e8400-e29b-41d4-a716-446655440000",
headers={"X-API-Key": "test-api-key"},
)
assert response.status_code == 409
def test_delete_not_found(self, client, mock_db):
"""Test deleting non-existent request."""
mock_db.get_request_by_api_key.return_value = None
response = client.delete(
"/api/v1/async/requests/00000000-0000-0000-0000-000000000000",
headers={"X-API-Key": "test-api-key"},
)
assert response.status_code == 404

View File

@@ -0,0 +1,266 @@
"""
Tests for the AsyncProcessingService class.
"""
import tempfile
import time
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from src.data.models import AsyncRequest
from src.web.workers.async_queue import AsyncTask, AsyncTaskQueue
from src.web.services.async_processing import AsyncProcessingService, AsyncSubmitResult
from src.web.config import AsyncConfig, StorageConfig
from src.web.core.rate_limiter import RateLimiter
@pytest.fixture
def async_service(mock_db, mock_inference_service, rate_limiter, storage_config):
"""Create an AsyncProcessingService for testing."""
with tempfile.TemporaryDirectory() as tmpdir:
async_config = AsyncConfig(
queue_max_size=10,
worker_count=1,
task_timeout_seconds=30,
result_retention_days=7,
temp_upload_dir=Path(tmpdir) / "async",
max_file_size_mb=10,
)
queue = AsyncTaskQueue(max_size=10, worker_count=1)
service = AsyncProcessingService(
inference_service=mock_inference_service,
db=mock_db,
queue=queue,
rate_limiter=rate_limiter,
async_config=async_config,
storage_config=storage_config,
)
yield service
# Cleanup
if service._queue._started:
service.stop()
class TestAsyncProcessingService:
"""Tests for AsyncProcessingService."""
def test_submit_request_success(self, async_service, mock_db):
"""Test successful request submission."""
mock_db.create_request.return_value = "test-request-id"
result = async_service.submit_request(
api_key="test-api-key",
file_content=b"fake pdf content",
filename="test.pdf",
content_type="application/pdf",
)
assert result.success is True
assert result.request_id is not None
assert result.estimated_wait_seconds >= 0
assert result.error is None
def test_submit_request_creates_db_record(self, async_service, mock_db):
"""Test that submission creates database record."""
async_service.submit_request(
api_key="test-api-key",
file_content=b"fake pdf content",
filename="test.pdf",
content_type="application/pdf",
)
mock_db.create_request.assert_called_once()
call_kwargs = mock_db.create_request.call_args[1]
assert call_kwargs["api_key"] == "test-api-key"
assert call_kwargs["filename"] == "test.pdf"
assert call_kwargs["content_type"] == "application/pdf"
def test_submit_request_saves_file(self, async_service, mock_db):
"""Test that submission saves file to temp directory."""
content = b"fake pdf content"
result = async_service.submit_request(
api_key="test-api-key",
file_content=content,
filename="test.pdf",
content_type="application/pdf",
)
        # The upload lands in the temp directory, but the worker may consume
        # and delete the file immediately, so only assert that submission succeeded.
        temp_dir = async_service._async_config.temp_upload_dir
        assert temp_dir.exists()
        assert result.success is True
def test_submit_request_records_rate_limit(self, async_service, mock_db, rate_limiter):
"""Test that submission records rate limit event."""
async_service.submit_request(
api_key="test-api-key",
file_content=b"fake pdf content",
filename="test.pdf",
content_type="application/pdf",
)
# Rate limiter should have recorded the request
mock_db.record_rate_limit_event.assert_called()
def test_start_and_stop(self, async_service):
"""Test starting and stopping the service."""
async_service.start()
assert async_service._queue._started is True
assert async_service._cleanup_thread is not None
assert async_service._cleanup_thread.is_alive()
async_service.stop()
assert async_service._queue._started is False
def test_process_task_success(self, async_service, mock_db, mock_inference_service, sample_task):
"""Test successful task processing."""
async_service._process_task(sample_task)
# Should update status to processing
mock_db.update_status.assert_called_with(sample_task.request_id, "processing")
# Should complete the request
mock_db.complete_request.assert_called_once()
call_kwargs = mock_db.complete_request.call_args[1]
assert call_kwargs["request_id"] == sample_task.request_id
assert "document_id" in call_kwargs
def test_process_task_pdf(self, async_service, mock_db, mock_inference_service, sample_task):
"""Test processing a PDF task."""
async_service._process_task(sample_task)
# Should call process_pdf for .pdf files
mock_inference_service.process_pdf.assert_called_once()
def test_process_task_image(self, async_service, mock_db, mock_inference_service):
"""Test processing an image task."""
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
f.write(b"fake image content")
task = AsyncTask(
request_id="image-task",
api_key="test-api-key",
file_path=Path(f.name),
filename="test.png",
)
async_service._process_task(task)
# Should call process_image for image files
mock_inference_service.process_image.assert_called_once()
def test_process_task_failure(self, async_service, mock_db, mock_inference_service, sample_task):
"""Test task processing failure."""
mock_inference_service.process_pdf.side_effect = Exception("Processing failed")
async_service._process_task(sample_task)
# Should update status to failed
mock_db.update_status.assert_called()
last_call = mock_db.update_status.call_args_list[-1]
assert last_call[0][1] == "failed" # status
assert "Processing failed" in last_call[1]["error_message"]
def test_process_task_file_not_found(self, async_service, mock_db):
"""Test task processing with missing file."""
task = AsyncTask(
request_id="missing-file-task",
api_key="test-api-key",
file_path=Path("/nonexistent/file.pdf"),
filename="test.pdf",
)
async_service._process_task(task)
# Should fail with file not found
mock_db.update_status.assert_called()
last_call = mock_db.update_status.call_args_list[-1]
assert last_call[0][1] == "failed"
def test_process_task_cleans_up_file(self, async_service, mock_db, mock_inference_service):
"""Test that task processing cleans up the uploaded file."""
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
f.write(b"fake pdf content")
file_path = Path(f.name)
task = AsyncTask(
request_id="cleanup-task",
api_key="test-api-key",
file_path=file_path,
filename="test.pdf",
)
async_service._process_task(task)
# File should be deleted
assert not file_path.exists()
def test_estimate_wait(self, async_service):
"""Test wait time estimation."""
# Empty queue
wait = async_service._estimate_wait()
assert wait == 0
def test_cleanup_orphan_files(self, async_service, mock_db):
"""Test cleanup of orphan files."""
# Create an orphan file
temp_dir = async_service._async_config.temp_upload_dir
orphan_file = temp_dir / "orphan-request.pdf"
orphan_file.write_bytes(b"orphan content")
# Set file mtime to old
import os
old_time = time.time() - 7200
os.utime(orphan_file, (old_time, old_time))
# Mock database to say file doesn't exist
mock_db.get_request.return_value = None
count = async_service._cleanup_orphan_files()
assert count == 1
assert not orphan_file.exists()
def test_save_upload(self, async_service):
"""Test saving uploaded file."""
content = b"test content"
file_path = async_service._save_upload(
request_id="test-save",
filename="test.pdf",
content=content,
)
assert file_path.exists()
assert file_path.read_bytes() == content
assert file_path.suffix == ".pdf"
# Cleanup
file_path.unlink()
def test_save_upload_preserves_extension(self, async_service):
"""Test that save_upload preserves file extension."""
content = b"test content"
# Test various extensions
for ext in [".pdf", ".png", ".jpg", ".jpeg"]:
file_path = async_service._save_upload(
request_id=f"test-{ext}",
filename=f"test{ext}",
content=content,
)
assert file_path.suffix == ext
file_path.unlink()

View File

@@ -0,0 +1,250 @@
"""
Tests for Auto-Label Service with Annotation Lock Integration (Phase 3.5).
"""
import pytest
from datetime import datetime, timedelta, timezone
from pathlib import Path
from unittest.mock import Mock, MagicMock
from uuid import uuid4
from src.web.services.autolabel import AutoLabelService
from src.data.admin_db import AdminDB
class MockDocument:
"""Mock document for testing."""
def __init__(self, document_id, annotation_lock_until=None):
self.document_id = document_id
self.annotation_lock_until = annotation_lock_until
self.status = "pending"
self.auto_label_status = None
self.auto_label_error = None
class MockAdminDB:
"""Mock AdminDB for testing."""
def __init__(self):
self.documents = {}
self.annotations = []
self.status_updates = []
def get_document(self, document_id):
"""Get document by ID."""
return self.documents.get(str(document_id))
def update_document_status(
self,
document_id,
status=None,
auto_label_status=None,
auto_label_error=None,
):
"""Mock status update."""
self.status_updates.append({
"document_id": document_id,
"status": status,
"auto_label_status": auto_label_status,
"auto_label_error": auto_label_error,
})
doc = self.documents.get(str(document_id))
if doc:
if status:
doc.status = status
if auto_label_status:
doc.auto_label_status = auto_label_status
if auto_label_error:
doc.auto_label_error = auto_label_error
def delete_annotations_for_document(self, document_id, source=None):
"""Mock delete annotations."""
return 0
def create_annotations_batch(self, annotations):
"""Mock create annotations."""
self.annotations.extend(annotations)
@pytest.fixture
def mock_db():
"""Create mock admin DB."""
return MockAdminDB()
@pytest.fixture
def auto_label_service(monkeypatch):
"""Create auto-label service with mocked image processing."""
service = AutoLabelService()
# Mock the OCR engine to avoid dependencies
service._ocr_engine = Mock()
service._ocr_engine.extract_from_image = Mock(return_value=[])
# Mock the image processing methods to avoid file I/O errors
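    # (patched on the class so the replacement stays a bound method and still receives self)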
def mock_process_image(self, document_id, image_path, field_values, db, page_number=1):
return 0 # No annotations created (mocked)
monkeypatch.setattr(AutoLabelService, "_process_image", mock_process_image)
return service
class TestAutoLabelWithLocks:
"""Tests for auto-label service with lock integration."""
def test_auto_label_unlocked_document_succeeds(self, auto_label_service, mock_db, tmp_path):
"""Test auto-labeling succeeds on unlocked document."""
# Create test document (unlocked)
document_id = str(uuid4())
mock_db.documents[document_id] = MockDocument(
document_id=document_id,
annotation_lock_until=None,
)
# Create dummy file
test_file = tmp_path / "test.png"
test_file.write_text("dummy")
# Attempt auto-label
result = auto_label_service.auto_label_document(
document_id=document_id,
file_path=str(test_file),
field_values={"invoice_number": "INV-001"},
db=mock_db,
)
# Should succeed
assert result["status"] == "completed"
# Verify status was updated to running and then completed
assert len(mock_db.status_updates) >= 2
assert mock_db.status_updates[0]["auto_label_status"] == "running"
def test_auto_label_locked_document_fails(self, auto_label_service, mock_db, tmp_path):
"""Test auto-labeling fails on locked document."""
# Create test document (locked for 1 hour)
document_id = str(uuid4())
lock_until = datetime.now(timezone.utc) + timedelta(hours=1)
mock_db.documents[document_id] = MockDocument(
document_id=document_id,
annotation_lock_until=lock_until,
)
# Create dummy file
test_file = tmp_path / "test.png"
test_file.write_text("dummy")
# Attempt auto-label (should fail)
result = auto_label_service.auto_label_document(
document_id=document_id,
file_path=str(test_file),
field_values={"invoice_number": "INV-001"},
db=mock_db,
)
# Should fail
assert result["status"] == "failed"
assert "locked for annotation" in result["error"]
assert result["annotations_created"] == 0
# Verify status was updated to failed
assert any(
update["auto_label_status"] == "failed"
for update in mock_db.status_updates
)
def test_auto_label_expired_lock_succeeds(self, auto_label_service, mock_db, tmp_path):
"""Test auto-labeling succeeds when lock has expired."""
# Create test document (lock expired 1 hour ago)
document_id = str(uuid4())
lock_until = datetime.now(timezone.utc) - timedelta(hours=1)
mock_db.documents[document_id] = MockDocument(
document_id=document_id,
annotation_lock_until=lock_until,
)
# Create dummy file
test_file = tmp_path / "test.png"
test_file.write_text("dummy")
# Attempt auto-label
result = auto_label_service.auto_label_document(
document_id=document_id,
file_path=str(test_file),
field_values={"invoice_number": "INV-001"},
db=mock_db,
)
# Should succeed (lock expired)
assert result["status"] == "completed"
def test_auto_label_skip_lock_check(self, auto_label_service, mock_db, tmp_path):
"""Test auto-labeling with skip_lock_check=True bypasses lock."""
# Create test document (locked)
document_id = str(uuid4())
lock_until = datetime.now(timezone.utc) + timedelta(hours=1)
mock_db.documents[document_id] = MockDocument(
document_id=document_id,
annotation_lock_until=lock_until,
)
# Create dummy file
test_file = tmp_path / "test.png"
test_file.write_text("dummy")
# Attempt auto-label with skip_lock_check=True
result = auto_label_service.auto_label_document(
document_id=document_id,
file_path=str(test_file),
field_values={"invoice_number": "INV-001"},
db=mock_db,
skip_lock_check=True, # Bypass lock check
)
# Should succeed even though document is locked
assert result["status"] == "completed"
def test_auto_label_document_not_found(self, auto_label_service, mock_db, tmp_path):
"""Test auto-labeling fails when document doesn't exist."""
# Create dummy file
test_file = tmp_path / "test.png"
test_file.write_text("dummy")
# Attempt auto-label on non-existent document
result = auto_label_service.auto_label_document(
document_id=str(uuid4()),
file_path=str(test_file),
field_values={"invoice_number": "INV-001"},
db=mock_db,
)
# Should fail
assert result["status"] == "failed"
assert "not found" in result["error"]
def test_auto_label_respects_lock_by_default(self, auto_label_service, mock_db, tmp_path):
"""Test that lock check is enabled by default."""
# Create test document (locked)
document_id = str(uuid4())
lock_until = datetime.now(timezone.utc) + timedelta(minutes=30)
mock_db.documents[document_id] = MockDocument(
document_id=document_id,
annotation_lock_until=lock_until,
)
# Create dummy file
test_file = tmp_path / "test.png"
test_file.write_text("dummy")
# Call without explicit skip_lock_check (defaults to False)
result = auto_label_service.auto_label_document(
document_id=document_id,
file_path=str(test_file),
field_values={"invoice_number": "INV-001"},
db=mock_db,
# skip_lock_check not specified, should default to False
)
# Should fail due to lock
assert result["status"] == "failed"
assert "locked" in result["error"].lower()

View File

@@ -0,0 +1,282 @@
"""
Tests for Batch Upload Queue
"""
import time
from datetime import datetime
from threading import Event
from uuid import uuid4
import pytest
from src.web.workers.batch_queue import BatchTask, BatchTaskQueue
class MockBatchService:
"""Mock batch upload service for testing."""
def __init__(self):
self.processed_tasks = []
self.process_delay = 0.1 # Simulate processing time
self.should_fail = False
def process_zip_upload(self, admin_token, zip_filename, zip_content, upload_source):
"""Mock process_zip_upload method."""
if self.should_fail:
raise Exception("Simulated processing error")
time.sleep(self.process_delay) # Simulate work
self.processed_tasks.append({
"admin_token": admin_token,
"zip_filename": zip_filename,
"upload_source": upload_source,
})
return {
"status": "completed",
"successful_files": 1,
"failed_files": 0,
}
class TestBatchTask:
"""Tests for BatchTask dataclass."""
def test_batch_task_creation(self):
"""BatchTask can be created with required fields."""
task = BatchTask(
batch_id=uuid4(),
admin_token="test-token",
zip_content=b"test",
zip_filename="test.zip",
upload_source="ui",
auto_label=True,
created_at=datetime.utcnow(),
)
assert task.batch_id is not None
assert task.admin_token == "test-token"
assert task.zip_filename == "test.zip"
assert task.upload_source == "ui"
assert task.auto_label is True
class TestBatchTaskQueue:
"""Tests for batch task queue functionality."""
def test_queue_initialization(self):
"""Queue initializes with correct defaults."""
queue = BatchTaskQueue(max_size=10, worker_count=1)
assert queue.get_queue_depth() == 0
assert queue.is_running is False
assert queue._worker_count == 1
def test_start_queue(self):
"""Queue starts with batch service."""
queue = BatchTaskQueue(max_size=10, worker_count=1)
service = MockBatchService()
queue.start(service)
assert queue.is_running is True
assert len(queue._workers) == 1
queue.stop()
def test_stop_queue(self):
"""Queue stops gracefully."""
queue = BatchTaskQueue(max_size=10, worker_count=1)
service = MockBatchService()
queue.start(service)
assert queue.is_running is True
queue.stop(timeout=5.0)
assert queue.is_running is False
assert len(queue._workers) == 0
def test_submit_task_success(self):
"""Task is submitted to queue successfully."""
queue = BatchTaskQueue(max_size=10, worker_count=1)
task = BatchTask(
batch_id=uuid4(),
admin_token="test-token",
zip_content=b"test",
zip_filename="test.zip",
upload_source="ui",
auto_label=True,
created_at=datetime.utcnow(),
)
result = queue.submit(task)
assert result is True
assert queue.get_queue_depth() == 1
def test_submit_task_queue_full(self):
"""Returns False when queue is full."""
queue = BatchTaskQueue(max_size=2, worker_count=1)
# Fill the queue
for i in range(2):
task = BatchTask(
batch_id=uuid4(),
admin_token="test-token",
zip_content=b"test",
zip_filename=f"test{i}.zip",
upload_source="ui",
auto_label=True,
created_at=datetime.utcnow(),
)
assert queue.submit(task) is True
# Try to add one more (should fail)
extra_task = BatchTask(
batch_id=uuid4(),
admin_token="test-token",
zip_content=b"test",
zip_filename="extra.zip",
upload_source="ui",
auto_label=True,
created_at=datetime.utcnow(),
)
result = queue.submit(extra_task)
assert result is False
assert queue.get_queue_depth() == 2
def test_worker_processes_task(self):
"""Worker thread processes queued tasks."""
queue = BatchTaskQueue(max_size=10, worker_count=1)
service = MockBatchService()
queue.start(service)
task = BatchTask(
batch_id=uuid4(),
admin_token="test-token",
zip_content=b"test",
zip_filename="test.zip",
upload_source="ui",
auto_label=True,
created_at=datetime.utcnow(),
)
queue.submit(task)
# Wait for processing
time.sleep(0.5)
assert len(service.processed_tasks) == 1
assert service.processed_tasks[0]["zip_filename"] == "test.zip"
queue.stop()
def test_multiple_tasks_processed(self):
"""Multiple tasks are processed in order."""
queue = BatchTaskQueue(max_size=10, worker_count=1)
service = MockBatchService()
queue.start(service)
# Submit multiple tasks
for i in range(3):
task = BatchTask(
batch_id=uuid4(),
admin_token="test-token",
zip_content=b"test",
zip_filename=f"test{i}.zip",
upload_source="ui",
auto_label=True,
created_at=datetime.utcnow(),
)
queue.submit(task)
# Wait for all to process
time.sleep(1.0)
assert len(service.processed_tasks) == 3
queue.stop()
def test_get_queue_depth(self):
"""Returns correct queue depth."""
queue = BatchTaskQueue(max_size=10, worker_count=1)
assert queue.get_queue_depth() == 0
# Add tasks
for i in range(3):
task = BatchTask(
batch_id=uuid4(),
admin_token="test-token",
zip_content=b"test",
zip_filename=f"test{i}.zip",
upload_source="ui",
auto_label=True,
created_at=datetime.utcnow(),
)
queue.submit(task)
assert queue.get_queue_depth() == 3
def test_is_running_property(self):
"""is_running reflects queue state."""
queue = BatchTaskQueue(max_size=10, worker_count=1)
service = MockBatchService()
assert queue.is_running is False
queue.start(service)
assert queue.is_running is True
queue.stop()
assert queue.is_running is False
def test_double_start_ignored(self):
"""Starting queue twice is safely ignored."""
queue = BatchTaskQueue(max_size=10, worker_count=1)
service = MockBatchService()
queue.start(service)
worker_count_after_first_start = len(queue._workers)
queue.start(service) # Second start
worker_count_after_second_start = len(queue._workers)
assert worker_count_after_first_start == worker_count_after_second_start
queue.stop()
def test_error_handling_in_worker(self):
"""Worker handles processing errors gracefully."""
queue = BatchTaskQueue(max_size=10, worker_count=1)
service = MockBatchService()
service.should_fail = True # Cause errors
queue.start(service)
task = BatchTask(
batch_id=uuid4(),
admin_token="test-token",
zip_content=b"test",
zip_filename="test.zip",
upload_source="ui",
auto_label=True,
created_at=datetime.utcnow(),
)
queue.submit(task)
# Wait for processing attempt
time.sleep(0.5)
# Worker should still be running
assert queue.is_running is True
queue.stop()

View File

@@ -0,0 +1,368 @@
"""
Tests for Batch Upload Routes
"""
import io
import zipfile
from datetime import datetime
from uuid import uuid4
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from src.web.api.v1.batch.routes import router
from src.web.core.auth import validate_admin_token, get_admin_db
from src.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
from src.web.services.batch_upload import BatchUploadService
class MockAdminDB:
"""Mock AdminDB for testing."""
def __init__(self):
self.batches = {}
self.batch_files = {}
def create_batch_upload(self, admin_token, filename, file_size, upload_source):
batch_id = uuid4()
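        # type() builds a throwaway class whose attributes mimic the ORM row; calling it yields an instance.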
batch = type('BatchUpload', (), {
'batch_id': batch_id,
'admin_token': admin_token,
'filename': filename,
'file_size': file_size,
'upload_source': upload_source,
'status': 'processing',
'total_files': 0,
'processed_files': 0,
'successful_files': 0,
'failed_files': 0,
'csv_filename': None,
'csv_row_count': None,
'error_message': None,
'created_at': datetime.utcnow(),
'completed_at': None,
})()
self.batches[batch_id] = batch
return batch
def update_batch_upload(self, batch_id, **kwargs):
if batch_id in self.batches:
batch = self.batches[batch_id]
for key, value in kwargs.items():
setattr(batch, key, value)
def create_batch_upload_file(self, batch_id, filename, **kwargs):
file_id = uuid4()
defaults = {
'file_id': file_id,
'batch_id': batch_id,
'filename': filename,
'status': 'pending',
'error_message': None,
'annotation_count': 0,
'csv_row_data': None,
}
defaults.update(kwargs)
file_record = type('BatchUploadFile', (), defaults)()
if batch_id not in self.batch_files:
self.batch_files[batch_id] = []
self.batch_files[batch_id].append(file_record)
return file_record
def update_batch_upload_file(self, file_id, **kwargs):
for files in self.batch_files.values():
for file_record in files:
if file_record.file_id == file_id:
for key, value in kwargs.items():
setattr(file_record, key, value)
return
def get_batch_upload(self, batch_id):
return self.batches.get(batch_id, type('BatchUpload', (), {
'batch_id': batch_id,
'admin_token': 'test-token',
'filename': 'test.zip',
'status': 'completed',
'total_files': 2,
'processed_files': 2,
'successful_files': 2,
'failed_files': 0,
'csv_filename': None,
'csv_row_count': None,
'error_message': None,
'created_at': datetime.utcnow(),
'completed_at': datetime.utcnow(),
})())
def get_batch_upload_files(self, batch_id):
return self.batch_files.get(batch_id, [])
def get_batch_uploads_by_token(self, admin_token, limit=50, offset=0):
"""Get batches filtered by admin token with pagination."""
token_batches = [b for b in self.batches.values() if b.admin_token == admin_token]
total = len(token_batches)
return token_batches[offset:offset+limit], total
@pytest.fixture(scope="class")
def app():
"""Create test FastAPI app with mocked dependencies."""
app = FastAPI()
# Create mock admin DB
mock_admin_db = MockAdminDB()
# Override dependencies
app.dependency_overrides[validate_admin_token] = lambda: "test-token"
app.dependency_overrides[get_admin_db] = lambda: mock_admin_db
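    # dependency_overrides swaps these dependencies app-wide, bypassing real token auth and the real database.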
# Initialize batch queue with mock service
batch_service = BatchUploadService(mock_admin_db)
init_batch_queue(batch_service)
app.include_router(router)
yield app
# Cleanup: shutdown batch queue after all tests in class
shutdown_batch_queue()
@pytest.fixture
def client(app):
"""Create test client."""
return TestClient(app)
def create_test_zip(files):
"""Create a test ZIP file."""
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
for filename, content in files.items():
zip_file.writestr(filename, content)
zip_buffer.seek(0)
return zip_buffer
class TestBatchUploadRoutes:
"""Tests for batch upload API routes."""
def test_upload_batch_success(self, client):
"""Test successful batch upload (defaults to async mode)."""
files = {
"INV001.pdf": b"%PDF-1.4 test content",
"INV002.pdf": b"%PDF-1.4 test content 2",
}
zip_file = create_test_zip(files)
response = client.post(
"/api/v1/admin/batch/upload",
files={"file": ("test.zip", zip_file, "application/zip")},
data={"upload_source": "ui"},
)
# Async mode is default, should return 202
assert response.status_code == 202
result = response.json()
assert "batch_id" in result
assert result["status"] == "accepted"
def test_upload_batch_non_zip_file(self, client):
"""Test uploading non-ZIP file."""
response = client.post(
"/api/v1/admin/batch/upload",
files={"file": ("test.pdf", io.BytesIO(b"test"), "application/pdf")},
data={"upload_source": "ui"},
)
assert response.status_code == 400
assert "Only ZIP files" in response.json()["detail"]
def test_upload_batch_with_csv(self, client):
"""Test batch upload with CSV (defaults to async)."""
csv_content = """DocumentId,InvoiceNumber,Amount
INV001,F2024-001,1500.00
INV002,F2024-002,2500.00
"""
files = {
"INV001.pdf": b"%PDF-1.4 test",
"INV002.pdf": b"%PDF-1.4 test 2",
"metadata.csv": csv_content.encode('utf-8'),
}
zip_file = create_test_zip(files)
response = client.post(
"/api/v1/admin/batch/upload",
files={"file": ("batch.zip", zip_file, "application/zip")},
data={"upload_source": "api"},
)
# Async mode is default, should return 202
assert response.status_code == 202
result = response.json()
assert "batch_id" in result
assert result["status"] == "accepted"
def test_get_batch_status(self, client):
"""Test getting batch status."""
batch_id = str(uuid4())
response = client.get(f"/api/v1/admin/batch/status/{batch_id}")
assert response.status_code == 200
result = response.json()
assert result["batch_id"] == batch_id
assert "status" in result
assert "total_files" in result
def test_list_batch_uploads(self, client):
"""Test listing batch uploads."""
response = client.get("/api/v1/admin/batch/list")
assert response.status_code == 200
result = response.json()
assert "batches" in result
assert "total" in result
assert "limit" in result
assert "offset" in result
def test_upload_batch_async_mode_default(self, client):
"""Test async mode is default (async_mode=True)."""
files = {
"INV001.pdf": b"%PDF-1.4 test content",
}
zip_file = create_test_zip(files)
response = client.post(
"/api/v1/admin/batch/upload",
files={"file": ("test.zip", zip_file, "application/zip")},
data={"upload_source": "ui"},
)
# Async mode should return 202 Accepted
assert response.status_code == 202
result = response.json()
assert result["status"] == "accepted"
assert "batch_id" in result
assert "status_url" in result
assert "queue_depth" in result
assert result["message"] == "Batch upload queued for processing"
def test_upload_batch_async_mode_explicit(self, client):
"""Test explicit async mode (async_mode=True)."""
files = {
"INV001.pdf": b"%PDF-1.4 test content",
}
zip_file = create_test_zip(files)
response = client.post(
"/api/v1/admin/batch/upload",
files={"file": ("test.zip", zip_file, "application/zip")},
data={"upload_source": "ui", "async_mode": "true"},
)
assert response.status_code == 202
result = response.json()
assert result["status"] == "accepted"
assert "batch_id" in result
assert "status_url" in result
def test_upload_batch_sync_mode(self, client):
"""Test sync mode (async_mode=False)."""
files = {
"INV001.pdf": b"%PDF-1.4 test content",
}
zip_file = create_test_zip(files)
response = client.post(
"/api/v1/admin/batch/upload",
files={"file": ("test.zip", zip_file, "application/zip")},
data={"upload_source": "ui", "async_mode": "false"},
)
# Sync mode should return 200 OK with full results
assert response.status_code == 200
result = response.json()
assert "batch_id" in result
assert result["status"] in ["completed", "partial", "failed"]
assert "successful_files" in result
def test_upload_batch_async_with_auto_label(self, client):
"""Test async mode with auto_label flag."""
files = {
"INV001.pdf": b"%PDF-1.4 test content",
}
zip_file = create_test_zip(files)
response = client.post(
"/api/v1/admin/batch/upload",
files={"file": ("test.zip", zip_file, "application/zip")},
data={
"upload_source": "ui",
"async_mode": "true",
"auto_label": "true",
},
)
assert response.status_code == 202
result = response.json()
assert result["status"] == "accepted"
assert "batch_id" in result
def test_upload_batch_async_without_auto_label(self, client):
"""Test async mode with auto_label disabled."""
files = {
"INV001.pdf": b"%PDF-1.4 test content",
}
zip_file = create_test_zip(files)
response = client.post(
"/api/v1/admin/batch/upload",
files={"file": ("test.zip", zip_file, "application/zip")},
data={
"upload_source": "ui",
"async_mode": "true",
"auto_label": "false",
},
)
assert response.status_code == 202
result = response.json()
assert result["status"] == "accepted"
def test_upload_batch_queue_full(self, client):
"""Test handling queue full scenario."""
# This test would require mocking the queue to return False on submit
# For now, we verify the endpoint accepts the request
files = {
"INV001.pdf": b"%PDF-1.4 test content",
}
zip_file = create_test_zip(files)
response = client.post(
"/api/v1/admin/batch/upload",
files={"file": ("test.zip", zip_file, "application/zip")},
data={"upload_source": "ui", "async_mode": "true"},
)
# Should either accept (202) or reject if queue full (503)
assert response.status_code in [202, 503]
def test_async_status_url_format(self, client):
"""Test async response contains correctly formatted status URL."""
files = {
"INV001.pdf": b"%PDF-1.4 test content",
}
zip_file = create_test_zip(files)
response = client.post(
"/api/v1/admin/batch/upload",
files={"file": ("test.zip", zip_file, "application/zip")},
data={"async_mode": "true"},
)
assert response.status_code == 202
result = response.json()
batch_id = result["batch_id"]
expected_url = f"/api/v1/admin/batch/status/{batch_id}"
assert result["status_url"] == expected_url

View File

@@ -0,0 +1,221 @@
"""
Tests for Batch Upload Service
"""
import io
import zipfile
from pathlib import Path
from uuid import uuid4
import pytest
from src.data.admin_db import AdminDB
from src.web.services.batch_upload import BatchUploadService
@pytest.fixture
def admin_db():
"""Mock admin database for testing."""
class MockAdminDB:
def __init__(self):
self.batches = {}
self.batch_files = {}
def create_batch_upload(self, admin_token, filename, file_size, upload_source):
batch_id = uuid4()
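            # Throwaway type() instance standing in for the ORM BatchUpload row.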
batch = type('BatchUpload', (), {
'batch_id': batch_id,
'admin_token': admin_token,
'filename': filename,
'file_size': file_size,
'upload_source': upload_source,
'status': 'processing',
'total_files': 0,
'processed_files': 0,
'successful_files': 0,
'failed_files': 0,
'csv_filename': None,
'csv_row_count': None,
'error_message': None,
'created_at': None,
'completed_at': None,
})()
self.batches[batch_id] = batch
return batch
def update_batch_upload(self, batch_id, **kwargs):
if batch_id in self.batches:
batch = self.batches[batch_id]
for key, value in kwargs.items():
setattr(batch, key, value)
def create_batch_upload_file(self, batch_id, filename, **kwargs):
file_id = uuid4()
# Set defaults for attributes
defaults = {
'file_id': file_id,
'batch_id': batch_id,
'filename': filename,
'status': 'pending',
'error_message': None,
'annotation_count': 0,
'csv_row_data': None,
}
defaults.update(kwargs)
file_record = type('BatchUploadFile', (), defaults)()
if batch_id not in self.batch_files:
self.batch_files[batch_id] = []
self.batch_files[batch_id].append(file_record)
return file_record
def update_batch_upload_file(self, file_id, **kwargs):
for files in self.batch_files.values():
for file_record in files:
if file_record.file_id == file_id:
for key, value in kwargs.items():
setattr(file_record, key, value)
return
def get_batch_upload(self, batch_id):
return self.batches.get(batch_id)
def get_batch_upload_files(self, batch_id):
return self.batch_files.get(batch_id, [])
return MockAdminDB()
@pytest.fixture
def batch_service(admin_db):
"""Batch upload service instance."""
return BatchUploadService(admin_db)
def create_test_zip(files):
"""Create a test ZIP file with given files.
Args:
files: Dictionary mapping filenames to content bytes
Returns:
ZIP file content as bytes
"""
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
for filename, content in files.items():
zip_file.writestr(filename, content)
return zip_buffer.getvalue()
class TestBatchUploadService:
"""Tests for BatchUploadService."""
def test_process_empty_zip(self, batch_service):
"""Test processing an empty ZIP file."""
zip_content = create_test_zip({})
result = batch_service.process_zip_upload(
admin_token="test-token",
zip_filename="empty.zip",
zip_content=zip_content,
)
assert result["status"] == "failed"
assert "No PDF files" in result.get("error", "")
def test_process_zip_with_pdfs_only(self, batch_service):
"""Test processing ZIP with PDFs but no CSV."""
files = {
"INV001.pdf": b"%PDF-1.4 test content",
"INV002.pdf": b"%PDF-1.4 test content 2",
}
zip_content = create_test_zip(files)
result = batch_service.process_zip_upload(
admin_token="test-token",
zip_filename="invoices.zip",
zip_content=zip_content,
)
assert result["status"] == "completed"
assert result["total_files"] == 2
assert result["successful_files"] == 2
assert result["failed_files"] == 0
def test_process_zip_with_csv(self, batch_service):
"""Test processing ZIP with PDFs and CSV."""
csv_content = """DocumentId,InvoiceNumber,Amount,OCR
INV001,F2024-001,1500.00,7350012345678
INV002,F2024-002,2500.00,7350087654321
"""
files = {
"INV001.pdf": b"%PDF-1.4 test content",
"INV002.pdf": b"%PDF-1.4 test content 2",
"metadata.csv": csv_content.encode('utf-8'),
}
zip_content = create_test_zip(files)
result = batch_service.process_zip_upload(
admin_token="test-token",
zip_filename="invoices.zip",
zip_content=zip_content,
)
assert result["status"] == "completed"
assert result["total_files"] == 2
assert result["csv_filename"] == "metadata.csv"
assert result["csv_row_count"] == 2
def test_process_invalid_zip(self, batch_service):
"""Test processing invalid ZIP file."""
result = batch_service.process_zip_upload(
admin_token="test-token",
zip_filename="invalid.zip",
zip_content=b"not a zip file",
)
assert result["status"] == "failed"
assert "Invalid ZIP file" in result.get("error", "")
def test_csv_parsing(self, batch_service):
"""Test CSV field parsing."""
csv_content = """DocumentId,InvoiceNumber,InvoiceDate,Amount,OCR,Bankgiro,customer_number
INV001,F2024-001,2024-01-15,1500.00,7350012345678,123-4567,C123
INV002,F2024-002,2024-01-16,2500.00,7350087654321,123-4567,C124
"""
zip_file_content = create_test_zip({"metadata.csv": csv_content.encode('utf-8')})
with zipfile.ZipFile(io.BytesIO(zip_file_content)) as zip_file:
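            # Grab the CSV entry's ZipInfo so the parser can read it straight from the archive.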
csv_file_info = [f for f in zip_file.filelist if f.filename.endswith('.csv')][0]
csv_data = batch_service._parse_csv_file(zip_file, csv_file_info)
assert len(csv_data) == 2
assert "INV001" in csv_data
assert csv_data["INV001"]["InvoiceNumber"] == "F2024-001"
assert csv_data["INV001"]["Amount"] == "1500.00"
assert csv_data["INV001"]["customer_number"] == "C123"
def test_get_batch_status(self, batch_service, admin_db):
"""Test getting batch upload status."""
# Create a batch
zip_content = create_test_zip({"INV001.pdf": b"%PDF-1.4 test"})
result = batch_service.process_zip_upload(
admin_token="test-token",
zip_filename="test.zip",
zip_content=zip_content,
)
batch_id = result["batch_id"]
# Get status
status = batch_service.get_batch_status(batch_id)
assert status["batch_id"] == batch_id
assert status["filename"] == "test.zip"
assert status["status"] == "completed"
assert status["total_files"] == 1
assert len(status["files"]) == 1
def test_get_batch_status_not_found(self, batch_service):
"""Test getting status for non-existent batch."""
status = batch_service.get_batch_status(str(uuid4()))
assert "error" in status

View File

@@ -0,0 +1,298 @@
"""
Integration tests for inference API endpoints.
Tests the /api/v1/infer endpoint to ensure it works end-to-end.
"""
import pytest
from pathlib import Path
from unittest.mock import Mock, patch
from fastapi.testclient import TestClient
from PIL import Image
import io
from src.web.app import create_app
from src.web.config import ModelConfig, StorageConfig, AppConfig
@pytest.fixture
def test_app(tmp_path):
"""Create test FastAPI application."""
# Setup test directories
upload_dir = tmp_path / "uploads"
result_dir = tmp_path / "results"
upload_dir.mkdir()
result_dir.mkdir()
# Create test config
app_config = AppConfig(
model=ModelConfig(
model_path=Path("runs/train/invoice_fields/weights/best.pt"),
confidence_threshold=0.5,
use_gpu=False,
dpi=150,
),
storage=StorageConfig(
upload_dir=upload_dir,
result_dir=result_dir,
allowed_extensions={".pdf", ".png", ".jpg", ".jpeg"},
max_file_size_mb=50,
),
)
# Create app
app = create_app(app_config)
return app
@pytest.fixture
def client(test_app):
"""Create test client."""
return TestClient(test_app)
@pytest.fixture
def sample_png_bytes():
"""Create sample PNG image bytes."""
img = Image.new('RGB', (800, 1200), color='white')
img_bytes = io.BytesIO()
img.save(img_bytes, format='PNG')
img_bytes.seek(0)
return img_bytes
class TestHealthEndpoint:
"""Test /api/v1/health endpoint."""
def test_health_check_returns_200(self, client):
"""Test health check returns 200 OK."""
response = client.get("/api/v1/health")
assert response.status_code == 200
def test_health_check_response_structure(self, client):
"""Test health check response has correct structure."""
response = client.get("/api/v1/health")
data = response.json()
assert "status" in data
assert "model_loaded" in data
assert "gpu_available" in data
assert "version" in data
assert data["status"] == "healthy"
assert isinstance(data["model_loaded"], bool)
assert isinstance(data["gpu_available"], bool)
class TestInferEndpoint:
"""Test /api/v1/infer endpoint."""
@patch('src.inference.pipeline.InferencePipeline')
@patch('src.inference.yolo_detector.YOLODetector')
def test_infer_accepts_png_file(
self,
mock_yolo_detector,
mock_pipeline,
client,
sample_png_bytes,
):
"""Test that /infer endpoint accepts PNG files."""
# Setup mocks
mock_detector_instance = Mock()
mock_pipeline_instance = Mock()
mock_yolo_detector.return_value = mock_detector_instance
mock_pipeline.return_value = mock_pipeline_instance
# Mock pipeline result
mock_result = Mock()
mock_result.fields = {"InvoiceNumber": "12345"}
mock_result.confidence = {"InvoiceNumber": 0.95}
mock_result.success = True
mock_result.errors = []
mock_result.raw_detections = []
mock_result.document_id = "test123"
mock_result.document_type = "invoice"
mock_result.processing_time_ms = 100.0
mock_result.visualization_path = None
mock_result.detections = []
mock_pipeline_instance.process_image.return_value = mock_result
# Make request
response = client.post(
"/api/v1/infer",
files={"file": ("test.png", sample_png_bytes, "image/png")},
)
# Verify response
assert response.status_code == 200
data = response.json()
assert data["status"] == "success"
assert "result" in data
assert data["result"]["fields"]["InvoiceNumber"] == "12345"
assert data["result"]["confidence"]["InvoiceNumber"] == 0.95
def test_infer_rejects_invalid_file_type(self, client):
"""Test that /infer rejects unsupported file types."""
invalid_file = io.BytesIO(b"fake txt content")
response = client.post(
"/api/v1/infer",
files={"file": ("test.txt", invalid_file, "text/plain")},
)
assert response.status_code == 400
assert "Unsupported file type" in response.json()["detail"]
def test_infer_requires_file(self, client):
"""Test that /infer requires a file parameter."""
response = client.post("/api/v1/infer")
assert response.status_code == 422 # Unprocessable Entity
@patch('src.inference.pipeline.InferencePipeline')
@patch('src.inference.yolo_detector.YOLODetector')
def test_infer_returns_cross_validation_if_available(
self,
mock_yolo_detector,
mock_pipeline,
client,
sample_png_bytes,
):
"""Test that cross-validation results are included if available."""
# Setup mocks
mock_detector_instance = Mock()
mock_pipeline_instance = Mock()
mock_yolo_detector.return_value = mock_detector_instance
mock_pipeline.return_value = mock_pipeline_instance
# Mock pipeline result with cross-validation
mock_result = Mock()
mock_result.fields = {
"InvoiceNumber": "12345",
"OCR": "1234567",
"Amount": "100.00",
}
mock_result.confidence = {
"InvoiceNumber": 0.95,
"OCR": 0.90,
"Amount": 0.88,
}
mock_result.success = True
mock_result.errors = []
mock_result.raw_detections = []
mock_result.document_id = "test123"
mock_result.document_type = "invoice"
mock_result.processing_time_ms = 100.0
mock_result.visualization_path = None
mock_result.detections = []
# Add cross-validation result
mock_cv = Mock()
mock_cv.is_valid = True
mock_cv.payment_line_ocr = "1234567"
mock_cv.ocr_match = True
mock_result.cross_validation = mock_cv
mock_pipeline_instance.process_image.return_value = mock_result
# Make request
response = client.post(
"/api/v1/infer",
files={"file": ("test.png", sample_png_bytes, "image/png")},
)
# Verify response includes cross-validation
assert response.status_code == 200
data = response.json()
# Note: cross_validation is not currently in the response schema
# This test documents that it should be added
@patch('src.inference.pipeline.InferencePipeline')
@patch('src.inference.yolo_detector.YOLODetector')
def test_infer_handles_processing_errors_gracefully(
self,
mock_yolo_detector,
mock_pipeline,
client,
sample_png_bytes,
):
"""Test that processing errors are handled gracefully."""
# Setup mocks
mock_detector_instance = Mock()
mock_pipeline_instance = Mock()
mock_yolo_detector.return_value = mock_detector_instance
mock_pipeline.return_value = mock_pipeline_instance
# Make pipeline raise an error
mock_pipeline_instance.process_image.side_effect = Exception("Model inference failed")
# Make request
response = client.post(
"/api/v1/infer",
files={"file": ("test.png", sample_png_bytes, "image/png")},
)
# Verify error handling - service catches exceptions and returns partial results
assert response.status_code == 200
data = response.json()
assert data["status"] == "partial"
assert data["result"]["success"] is False
assert len(data["result"]["errors"]) > 0
assert "Model inference failed" in data["result"]["errors"][0]
class TestResultsEndpoint:
"""Test /api/v1/results/{filename} endpoint."""
def test_get_result_image_returns_404_if_not_found(self, client):
"""Test that getting non-existent result returns 404."""
response = client.get("/api/v1/results/nonexistent.png")
assert response.status_code == 404
def test_get_result_image_returns_file_if_exists(self, client, test_app, tmp_path):
"""Test that existing result file is returned."""
# Get storage config from app
storage_config = test_app.extra.get("storage_config")
if not storage_config:
pytest.skip("Storage config not available in test app")
# Create a test result file
result_file = storage_config.result_dir / "test_result.png"
img = Image.new('RGB', (100, 100), color='red')
img.save(result_file)
# Request the file
response = client.get("/api/v1/results/test_result.png")
assert response.status_code == 200
assert response.headers["content-type"] == "image/png"
class TestInferenceServiceImports:
"""Critical test to catch import errors."""
def test_inference_service_can_import_modules(self):
"""
Test that InferenceService can import its dependencies.
This test will fail if there are ImportError issues like:
- from ..inference.pipeline (wrong relative import)
- from src.web.inference (non-existent module)
It ensures the imports are correct before runtime.
"""
from src.web.services.inference import InferenceService
# Import the modules that InferenceService tries to import
from src.inference.pipeline import InferencePipeline
from src.inference.yolo_detector import YOLODetector
from src.pdf.renderer import render_pdf_to_images
# If we got here, all imports work correctly
assert InferencePipeline is not None
assert YOLODetector is not None
assert render_pdf_to_images is not None
assert InferenceService is not None

View File

@@ -0,0 +1,297 @@
"""
Integration tests for inference service.
Tests the full initialization and processing flow to catch import errors.
"""
import pytest
from pathlib import Path
from unittest.mock import Mock, patch
from PIL import Image
import io
from src.web.services.inference import InferenceService
from src.web.config import ModelConfig, StorageConfig
@pytest.fixture
def model_config(tmp_path):
"""Create model configuration for testing."""
return ModelConfig(
model_path=Path("runs/train/invoice_fields/weights/best.pt"),
confidence_threshold=0.5,
use_gpu=False, # Use CPU for tests
dpi=150,
)
@pytest.fixture
def storage_config(tmp_path):
"""Create storage configuration for testing."""
upload_dir = tmp_path / "uploads"
result_dir = tmp_path / "results"
upload_dir.mkdir()
result_dir.mkdir()
return StorageConfig(
upload_dir=upload_dir,
result_dir=result_dir,
allowed_extensions={".pdf", ".png", ".jpg", ".jpeg"},
max_file_size_mb=50,
)
@pytest.fixture
def sample_image(tmp_path):
"""Create a sample test image."""
image_path = tmp_path / "test_invoice.png"
img = Image.new('RGB', (800, 1200), color='white')
img.save(image_path)
return image_path
@pytest.fixture
def inference_service(model_config, storage_config):
"""Create inference service instance."""
return InferenceService(
model_config=model_config,
storage_config=storage_config,
)
class TestInferenceServiceInitialization:
"""Test inference service initialization to catch import errors."""
def test_service_creation(self, inference_service):
"""Test that service can be created without errors."""
assert inference_service is not None
assert not inference_service.is_initialized
def test_gpu_available_check(self, inference_service):
"""Test GPU availability check (should not crash)."""
gpu_available = inference_service.gpu_available
assert isinstance(gpu_available, bool)
@patch('src.inference.pipeline.InferencePipeline')
@patch('src.inference.yolo_detector.YOLODetector')
def test_initialize_imports_correctly(
self,
mock_yolo_detector,
mock_pipeline,
inference_service,
):
"""
Test that initialize() imports modules correctly.
This test ensures that the import statements in initialize()
use correct paths and don't fail with ImportError.
"""
# Mock the constructors to avoid actually loading models
mock_detector_instance = Mock()
mock_pipeline_instance = Mock()
mock_yolo_detector.return_value = mock_detector_instance
mock_pipeline.return_value = mock_pipeline_instance
# Initialize should not raise ImportError
inference_service.initialize()
# Verify initialization succeeded
assert inference_service.is_initialized
# Verify imports were called with correct parameters
mock_yolo_detector.assert_called_once()
mock_pipeline.assert_called_once()
@patch('src.inference.pipeline.InferencePipeline')
@patch('src.inference.yolo_detector.YOLODetector')
def test_initialize_sets_up_pipeline(
self,
mock_yolo_detector,
mock_pipeline,
inference_service,
model_config,
):
"""Test that initialize sets up pipeline with correct config."""
mock_detector_instance = Mock()
mock_pipeline_instance = Mock()
mock_yolo_detector.return_value = mock_detector_instance
mock_pipeline.return_value = mock_pipeline_instance
inference_service.initialize()
# Check YOLO detector was initialized correctly
mock_yolo_detector.assert_called_once_with(
str(model_config.model_path),
confidence_threshold=model_config.confidence_threshold,
device="cpu", # use_gpu=False in fixture
)
# Check pipeline was initialized correctly
mock_pipeline.assert_called_once_with(
model_path=str(model_config.model_path),
confidence_threshold=model_config.confidence_threshold,
use_gpu=False,
dpi=150,
enable_fallback=True,
)
@patch('src.inference.pipeline.InferencePipeline')
@patch('src.inference.yolo_detector.YOLODetector')
def test_initialize_idempotent(
self,
mock_yolo_detector,
mock_pipeline,
inference_service,
):
"""Test that calling initialize() multiple times is safe."""
mock_detector_instance = Mock()
mock_pipeline_instance = Mock()
mock_yolo_detector.return_value = mock_detector_instance
mock_pipeline.return_value = mock_pipeline_instance
# Call initialize twice
inference_service.initialize()
inference_service.initialize()
# Should only be called once due to is_initialized check
assert mock_yolo_detector.call_count == 1
assert mock_pipeline.call_count == 1
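# A minimal sketch, under assumptions, of the guarded initialization pattern the
# tests above verify: an is_initialized flag makes repeated initialize() calls
# no-ops, so the detector and pipeline are constructed exactly once. Names are
# illustrative stand-ins, not the actual InferenceService implementation.
class LazyInitSketch:
    def __init__(self):
        self.is_initialized = False
        self._detector = None
        self._pipeline = None

    def initialize(self):
        if self.is_initialized:
            return  # second and later calls reload nothing
        self._detector = object()  # stand-in for YOLODetector(...)
        self._pipeline = object()  # stand-in for InferencePipeline(...)
        self.is_initialized = True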
class TestInferenceServiceProcessing:
"""Test inference processing methods."""
@patch('src.inference.pipeline.InferencePipeline')
@patch('src.inference.yolo_detector.YOLODetector')
@patch('ultralytics.YOLO')
def test_process_image_basic_flow(
self,
mock_yolo_class,
mock_yolo_detector,
mock_pipeline,
inference_service,
sample_image,
):
"""Test basic image processing flow."""
# Setup mocks
mock_detector_instance = Mock()
mock_pipeline_instance = Mock()
mock_yolo_detector.return_value = mock_detector_instance
mock_pipeline.return_value = mock_pipeline_instance
# Mock pipeline result
mock_result = Mock()
mock_result.fields = {"InvoiceNumber": "12345"}
mock_result.confidence = {"InvoiceNumber": 0.95}
mock_result.success = True
mock_result.errors = []
mock_result.raw_detections = []
mock_pipeline_instance.process_image.return_value = mock_result
# Process image
result = inference_service.process_image(sample_image)
# Verify result
assert result.success
assert result.fields == {"InvoiceNumber": "12345"}
assert result.confidence == {"InvoiceNumber": 0.95}
assert result.processing_time_ms > 0
@patch('src.inference.pipeline.InferencePipeline')
@patch('src.inference.yolo_detector.YOLODetector')
def test_process_image_handles_errors(
self,
mock_yolo_detector,
mock_pipeline,
inference_service,
sample_image,
):
"""Test that processing errors are handled gracefully."""
# Setup mocks
mock_detector_instance = Mock()
mock_pipeline_instance = Mock()
mock_yolo_detector.return_value = mock_detector_instance
mock_pipeline.return_value = mock_pipeline_instance
# Make pipeline raise an error
mock_pipeline_instance.process_image.side_effect = Exception("Test error")
# Process should not crash
result = inference_service.process_image(sample_image)
# Verify error handling
assert not result.success
assert len(result.errors) > 0
assert "Test error" in result.errors[0]
class TestInferenceServicePDFRendering:
"""Test PDF rendering imports."""
@patch('src.inference.pipeline.InferencePipeline')
@patch('src.inference.yolo_detector.YOLODetector')
@patch('src.pdf.renderer.render_pdf_to_images')
@patch('ultralytics.YOLO')
def test_pdf_visualization_imports_correctly(
self,
mock_yolo_class,
mock_render_pdf,
mock_yolo_detector,
mock_pipeline,
inference_service,
tmp_path,
):
"""
Test that _save_pdf_visualization imports render_pdf_to_images correctly.
This catches the import error we had with:
from ..pdf.renderer (wrong) vs from src.pdf.renderer (correct)
"""
# Setup mocks
mock_detector_instance = Mock()
mock_pipeline_instance = Mock()
mock_yolo_detector.return_value = mock_detector_instance
mock_pipeline.return_value = mock_pipeline_instance
# Create a fake PDF path
pdf_path = tmp_path / "test.pdf"
pdf_path.touch()
# Mock render_pdf_to_images to return an image
image_bytes = io.BytesIO()
img = Image.new('RGB', (800, 1200), color='white')
img.save(image_bytes, format='PNG')
mock_render_pdf.return_value = [(1, image_bytes.getvalue())]
# Mock YOLO
mock_model_instance = Mock()
mock_result = Mock()
mock_result.save = Mock()
mock_model_instance.predict.return_value = [mock_result]
mock_yolo_class.return_value = mock_model_instance
# This should not raise ImportError
result_path = inference_service._save_pdf_visualization(pdf_path, "test123")
# Verify import was successful
mock_render_pdf.assert_called_once()
assert result_path is not None
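# A minimal sketch, under assumptions, of the _save_pdf_visualization flow this
# test exercises: render the first PDF page to PNG bytes via the absolute
# src.pdf.renderer import, run the model on it, and save the annotated page.
# Only the (page_number, png_bytes) return shape is taken from the mock above;
# the rest is illustrative, and it reuses io, Path, and Image imported at the
# top of this module.
def save_pdf_visualization_sketch(pdf_path, doc_id, result_dir, model):
    from src.pdf.renderer import render_pdf_to_images  # absolute import, per the test

    pages = render_pdf_to_images(pdf_path)  # assumed: [(page_number, png_bytes), ...]
    if not pages:
        return None
    _, png_bytes = pages[0]
    image = Image.open(io.BytesIO(png_bytes))
    out_path = Path(result_dir) / f"{doc_id}_result.png"
    predictions = model.predict(image)  # assumed ultralytics-style predict()
    predictions[0].save(str(out_path))  # results objects expose save(), per the mock
    return out_path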
@pytest.mark.skipif(
not Path("runs/train/invoice_fields/weights/best.pt").exists(),
reason="Model file not available"
)
class TestInferenceServiceRealModel:
"""Integration tests with real model (skip if model not available)."""
def test_real_initialization(self, model_config, storage_config):
"""Test real initialization with actual model."""
service = InferenceService(model_config, storage_config)
# This should work with the real imports
service.initialize()
assert service.is_initialized
assert service._pipeline is not None
assert service._detector is not None

View File

@@ -0,0 +1,154 @@
"""
Tests for the RateLimiter class.
"""
import time
from datetime import datetime, timedelta
from unittest.mock import MagicMock
import pytest
from src.data.async_request_db import ApiKeyConfig
from src.web.core.rate_limiter import RateLimiter, RateLimitStatus
class TestRateLimiter:
"""Tests for RateLimiter."""
def test_check_submit_limit_allowed(self, rate_limiter, mock_db):
"""Test that requests are allowed under the limit."""
status = rate_limiter.check_submit_limit("test-api-key")
assert status.allowed is True
assert status.remaining_requests >= 0
assert status.retry_after_seconds is None
def test_check_submit_limit_rate_exceeded(self, rate_limiter, mock_db):
"""Test rate limit exceeded when too many requests."""
# Record 10 requests (the default limit)
for _ in range(10):
rate_limiter.record_request("test-api-key")
status = rate_limiter.check_submit_limit("test-api-key")
assert status.allowed is False
assert status.remaining_requests == 0
assert status.retry_after_seconds is not None
assert status.retry_after_seconds > 0
assert "rate limit" in status.reason.lower()
def test_check_submit_limit_concurrent_jobs_exceeded(self, rate_limiter, mock_db):
"""Test rejection when max concurrent jobs reached."""
# Mock active jobs at the limit
mock_db.count_active_jobs.return_value = 3 # Max is 3
status = rate_limiter.check_submit_limit("test-api-key")
assert status.allowed is False
assert "concurrent" in status.reason.lower()
def test_record_request(self, rate_limiter, mock_db):
"""Test that recording a request works."""
rate_limiter.record_request("test-api-key")
# Should have called the database
mock_db.record_rate_limit_event.assert_called_once_with("test-api-key", "request")
def test_check_poll_limit_allowed(self, rate_limiter, mock_db):
"""Test that polling is allowed initially."""
status = rate_limiter.check_poll_limit("test-api-key", "request-123")
assert status.allowed is True
def test_check_poll_limit_too_frequent(self, rate_limiter, mock_db):
"""Test that rapid polling is rejected."""
# First poll should succeed
status1 = rate_limiter.check_poll_limit("test-api-key", "request-123")
assert status1.allowed is True
# Immediate second poll should fail
status2 = rate_limiter.check_poll_limit("test-api-key", "request-123")
assert status2.allowed is False
assert "polling" in status2.reason.lower()
assert status2.retry_after_seconds is not None
def test_check_poll_limit_different_requests(self, rate_limiter, mock_db):
"""Test that different request_ids have separate poll limits."""
# Poll request 1
status1 = rate_limiter.check_poll_limit("test-api-key", "request-1")
assert status1.allowed is True
# Poll request 2 should also be allowed
status2 = rate_limiter.check_poll_limit("test-api-key", "request-2")
assert status2.allowed is True
def test_sliding_window_expires(self, rate_limiter, mock_db):
"""Test that requests expire from the sliding window."""
# Record requests
for _ in range(5):
rate_limiter.record_request("test-api-key")
# Check status - should have 5 remaining
status1 = rate_limiter.check_submit_limit("test-api-key")
assert status1.allowed is True
assert status1.remaining_requests == 4 # 10 - 5 - 1 (for this check)
def test_get_rate_limit_headers(self, rate_limiter):
"""Test rate limit header generation."""
status = RateLimitStatus(
allowed=False,
remaining_requests=0,
reset_at=datetime.utcnow() + timedelta(seconds=30),
retry_after_seconds=30,
)
headers = rate_limiter.get_rate_limit_headers(status)
assert "X-RateLimit-Remaining" in headers
assert headers["X-RateLimit-Remaining"] == "0"
assert "Retry-After" in headers
assert headers["Retry-After"] == "30"
def test_cleanup_poll_timestamps(self, rate_limiter, mock_db):
"""Test cleanup of old poll timestamps."""
# Add some poll timestamps
rate_limiter.check_poll_limit("test-api-key", "old-request")
# Manually age the timestamp
rate_limiter._poll_timestamps[("test-api-key", "old-request")] = time.time() - 7200
# Run cleanup with 1 hour max age
cleaned = rate_limiter.cleanup_poll_timestamps(max_age_seconds=3600)
assert cleaned == 1
assert ("test-api-key", "old-request") not in rate_limiter._poll_timestamps
def test_cleanup_request_windows(self, rate_limiter, mock_db):
"""Test cleanup of empty request windows."""
# Add some old requests
rate_limiter._request_windows["old-key"] = [time.time() - 120]
# Run cleanup
rate_limiter.cleanup_request_windows()
# Old entries should be removed
assert "old-key" not in rate_limiter._request_windows
def test_config_caching(self, rate_limiter, mock_db):
"""Test that API key configs are cached."""
# First call should query database
rate_limiter._get_config("test-api-key")
assert mock_db.get_api_key_config.call_count == 1
# Second call should use cache
rate_limiter._get_config("test-api-key")
assert mock_db.get_api_key_config.call_count == 1 # Still 1
def test_default_config_for_unknown_key(self, rate_limiter, mock_db):
"""Test that unknown API keys get default config."""
mock_db.get_api_key_config.return_value = None
config = rate_limiter._get_config("unknown-key")
assert config.requests_per_minute == 10 # Default
assert config.max_concurrent_jobs == 3 # Default
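# A minimal sketch, under assumptions, of the per-key sliding-window check the
# tests above describe: timestamps older than 60 seconds are pruned, the count
# left in the window is compared to the key's per-minute limit, and retry-after
# is derived from the oldest timestamp still in the window. Illustrative only;
# the real RateLimiter also consults the database for concurrent-job limits.
from collections import defaultdict


class SlidingWindowSketch:
    def __init__(self, requests_per_minute=10):
        self.limit = requests_per_minute
        self._windows = defaultdict(list)  # api_key -> [request timestamps]

    def check(self, api_key):
        now = time.time()
        window = [t for t in self._windows[api_key] if now - t < 60]
        self._windows[api_key] = window
        if len(window) >= self.limit:
            retry_after = int(60 - (now - window[0])) + 1  # until oldest expires
            return False, 0, retry_after
        # remaining counts this check as one slot: 10 - 5 - 1 in the test above
        return True, self.limit - len(window) - 1, None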

View File

@@ -0,0 +1,384 @@
"""
Tests for Phase 4: Training Data Management
"""
import pytest
from datetime import datetime
from uuid import uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
from src.web.api.v1.admin.training import create_training_router
from src.web.core.auth import validate_admin_token, get_admin_db
class MockTrainingTask:
"""Mock TrainingTask for testing."""
def __init__(self, **kwargs):
self.task_id = kwargs.get('task_id', uuid4())
self.admin_token = kwargs.get('admin_token', 'test-token')
self.name = kwargs.get('name', 'Test Training')
self.description = kwargs.get('description', None)
self.status = kwargs.get('status', 'completed')
self.task_type = kwargs.get('task_type', 'train')
self.config = kwargs.get('config', {})
self.scheduled_at = kwargs.get('scheduled_at', None)
self.cron_expression = kwargs.get('cron_expression', None)
self.is_recurring = kwargs.get('is_recurring', False)
self.started_at = kwargs.get('started_at', datetime.utcnow())
self.completed_at = kwargs.get('completed_at', datetime.utcnow())
self.error_message = kwargs.get('error_message', None)
self.result_metrics = kwargs.get('result_metrics', {})
self.model_path = kwargs.get('model_path', 'runs/train/test/weights/best.pt')
self.document_count = kwargs.get('document_count', 0)
self.metrics_mAP = kwargs.get('metrics_mAP', 0.935)
self.metrics_precision = kwargs.get('metrics_precision', 0.92)
self.metrics_recall = kwargs.get('metrics_recall', 0.88)
self.created_at = kwargs.get('created_at', datetime.utcnow())
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
class MockTrainingDocumentLink:
"""Mock TrainingDocumentLink for testing."""
def __init__(self, **kwargs):
self.link_id = kwargs.get('link_id', uuid4())
self.task_id = kwargs.get('task_id')
self.document_id = kwargs.get('document_id')
self.annotation_snapshot = kwargs.get('annotation_snapshot', None)
self.created_at = kwargs.get('created_at', datetime.utcnow())
class MockAdminDocument:
"""Mock AdminDocument for testing."""
def __init__(self, **kwargs):
self.document_id = kwargs.get('document_id', uuid4())
self.admin_token = kwargs.get('admin_token', 'test-token')
self.filename = kwargs.get('filename', 'test.pdf')
self.file_size = kwargs.get('file_size', 100000)
self.content_type = kwargs.get('content_type', 'application/pdf')
self.file_path = kwargs.get('file_path', 'data/admin_docs/test.pdf')
self.page_count = kwargs.get('page_count', 1)
self.status = kwargs.get('status', 'labeled')
self.auto_label_status = kwargs.get('auto_label_status', None)
self.auto_label_error = kwargs.get('auto_label_error', None)
self.upload_source = kwargs.get('upload_source', 'ui')
self.batch_id = kwargs.get('batch_id', None)
self.csv_field_values = kwargs.get('csv_field_values', None)
self.auto_label_queued_at = kwargs.get('auto_label_queued_at', None)
self.annotation_lock_until = kwargs.get('annotation_lock_until', None)
self.created_at = kwargs.get('created_at', datetime.utcnow())
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
class MockAnnotation:
"""Mock AdminAnnotation for testing."""
def __init__(self, **kwargs):
self.annotation_id = kwargs.get('annotation_id', uuid4())
self.document_id = kwargs.get('document_id')
self.page_number = kwargs.get('page_number', 1)
self.class_id = kwargs.get('class_id', 0)
self.class_name = kwargs.get('class_name', 'invoice_number')
self.bbox_x = kwargs.get('bbox_x', 100)
self.bbox_y = kwargs.get('bbox_y', 100)
self.bbox_width = kwargs.get('bbox_width', 200)
self.bbox_height = kwargs.get('bbox_height', 50)
self.x_center = kwargs.get('x_center', 0.5)
self.y_center = kwargs.get('y_center', 0.5)
self.width = kwargs.get('width', 0.3)
self.height = kwargs.get('height', 0.1)
self.text_value = kwargs.get('text_value', 'INV-001')
self.confidence = kwargs.get('confidence', 0.95)
self.source = kwargs.get('source', 'manual')
self.is_verified = kwargs.get('is_verified', False)
self.verified_at = kwargs.get('verified_at', None)
self.verified_by = kwargs.get('verified_by', None)
self.override_source = kwargs.get('override_source', None)
self.original_annotation_id = kwargs.get('original_annotation_id', None)
self.created_at = kwargs.get('created_at', datetime.utcnow())
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
class MockAdminDB:
"""Mock AdminDB for testing Phase 4."""
def __init__(self):
self.documents = {}
self.annotations = {}
self.training_tasks = {}
self.training_links = {}
def get_documents_for_training(
self,
admin_token,
status="labeled",
has_annotations=True,
min_annotation_count=None,
exclude_used_in_training=False,
limit=100,
offset=0,
):
"""Get documents for training."""
# Filter documents by criteria
filtered = []
for doc in self.documents.values():
if doc.admin_token != admin_token or doc.status != status:
continue
# Check annotations
annotations = self.annotations.get(str(doc.document_id), [])
if has_annotations and len(annotations) == 0:
continue
if min_annotation_count and len(annotations) < min_annotation_count:
continue
# Check if used in training
if exclude_used_in_training:
links = self.training_links.get(str(doc.document_id), [])
if links:
continue
filtered.append(doc)
total = len(filtered)
return filtered[offset:offset+limit], total
def get_annotations_for_document(self, document_id):
"""Get annotations for document."""
return self.annotations.get(str(document_id), [])
def get_document_training_tasks(self, document_id):
"""Get training tasks that used this document."""
return self.training_links.get(str(document_id), [])
def get_training_tasks_by_token(
self,
admin_token,
status=None,
limit=20,
offset=0,
):
"""Get training tasks filtered by token."""
tasks = [t for t in self.training_tasks.values() if t.admin_token == admin_token]
if status:
tasks = [t for t in tasks if t.status == status]
total = len(tasks)
return tasks[offset:offset+limit], total
def get_training_task(self, task_id):
"""Get training task by ID."""
return self.training_tasks.get(str(task_id))
@pytest.fixture
def app():
"""Create test FastAPI app."""
app = FastAPI()
# Create mock DB
mock_db = MockAdminDB()
# Add test documents
doc1 = MockAdminDocument(
filename="DOC001.pdf",
status="labeled",
)
doc2 = MockAdminDocument(
filename="DOC002.pdf",
status="labeled",
)
doc3 = MockAdminDocument(
filename="DOC003.pdf",
status="labeled",
)
mock_db.documents[str(doc1.document_id)] = doc1
mock_db.documents[str(doc2.document_id)] = doc2
mock_db.documents[str(doc3.document_id)] = doc3
# Add annotations
mock_db.annotations[str(doc1.document_id)] = [
MockAnnotation(document_id=doc1.document_id, source="manual"),
MockAnnotation(document_id=doc1.document_id, source="auto"),
]
mock_db.annotations[str(doc2.document_id)] = [
MockAnnotation(document_id=doc2.document_id, source="auto"),
MockAnnotation(document_id=doc2.document_id, source="auto"),
MockAnnotation(document_id=doc2.document_id, source="auto"),
]
# doc3 has no annotations
# Add training tasks
task1 = MockTrainingTask(
name="Training Run 2024-01",
status="completed",
document_count=500,
metrics_mAP=0.935,
metrics_precision=0.92,
metrics_recall=0.88,
)
task2 = MockTrainingTask(
name="Training Run 2024-02",
status="completed",
document_count=600,
metrics_mAP=0.951,
metrics_precision=0.94,
metrics_recall=0.92,
)
mock_db.training_tasks[str(task1.task_id)] = task1
mock_db.training_tasks[str(task2.task_id)] = task2
# Add training links (doc1 used in task1)
link1 = MockTrainingDocumentLink(
task_id=task1.task_id,
document_id=doc1.document_id,
)
mock_db.training_links[str(doc1.document_id)] = [link1]
# Override dependencies
app.dependency_overrides[validate_admin_token] = lambda: "test-token"
app.dependency_overrides[get_admin_db] = lambda: mock_db
# Include router
router = create_training_router()
app.include_router(router)
return app
@pytest.fixture
def client(app):
"""Create test client."""
return TestClient(app)
class TestTrainingDocuments:
"""Tests for GET /admin/training/documents endpoint."""
def test_get_training_documents_success(self, client):
"""Test getting documents for training."""
response = client.get("/admin/training/documents")
assert response.status_code == 200
data = response.json()
assert "total" in data
assert "documents" in data
assert data["total"] >= 0
assert isinstance(data["documents"], list)
def test_get_training_documents_with_annotations(self, client):
"""Test filtering documents with annotations."""
response = client.get("/admin/training/documents?has_annotations=true")
assert response.status_code == 200
data = response.json()
# Should return doc1 and doc2 (both have annotations)
assert data["total"] == 2
def test_get_training_documents_min_annotation_count(self, client):
"""Test filtering by minimum annotation count."""
response = client.get("/admin/training/documents?min_annotation_count=3")
assert response.status_code == 200
data = response.json()
# Should return only doc2 (has 3 annotations)
assert data["total"] == 1
def test_get_training_documents_exclude_used(self, client):
"""Test excluding documents already used in training."""
response = client.get("/admin/training/documents?exclude_used_in_training=true")
assert response.status_code == 200
data = response.json()
# Should exclude doc1 (used in training)
assert data["total"] == 1 # Only doc2 (doc3 has no annotations)
def test_get_training_documents_annotation_sources(self, client):
"""Test that annotation sources are included."""
response = client.get("/admin/training/documents?has_annotations=true")
assert response.status_code == 200
data = response.json()
# Check that documents have annotation_sources field
for doc in data["documents"]:
assert "annotation_sources" in doc
assert isinstance(doc["annotation_sources"], dict)
assert "manual" in doc["annotation_sources"]
assert "auto" in doc["annotation_sources"]
def test_get_training_documents_pagination(self, client):
"""Test pagination parameters."""
response = client.get("/admin/training/documents?limit=1&offset=0")
assert response.status_code == 200
data = response.json()
assert data["limit"] == 1
assert data["offset"] == 0
assert len(data["documents"]) <= 1
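# A minimal sketch, under assumptions, of how the documents endpoint exercised
# above could wire its query parameters through to the DB layer. The real
# router comes from create_training_router(); the route path, parameter names,
# and response shape here are inferred only from what the tests assert, and
# assembling per-document annotation_sources counts is omitted.
from typing import Optional

from fastapi import APIRouter, Depends, Query

sketch_router = APIRouter()


@sketch_router.get("/admin/training/documents-sketch")
def list_training_documents_sketch(
    has_annotations: bool = Query(True),
    min_annotation_count: Optional[int] = Query(None),
    exclude_used_in_training: bool = Query(False),
    limit: int = Query(100),
    offset: int = Query(0),
    admin_token: str = Depends(validate_admin_token),
    db=Depends(get_admin_db),
):
    documents, total = db.get_documents_for_training(
        admin_token=admin_token,
        has_annotations=has_annotations,
        min_annotation_count=min_annotation_count,
        exclude_used_in_training=exclude_used_in_training,
        limit=limit,
        offset=offset,
    )
    # Each document would be serialized with its annotation_sources breakdown.
    return {"total": total, "documents": documents, "limit": limit, "offset": offset}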
class TestTrainingModels:
"""Tests for GET /admin/training/models endpoint."""
def test_get_training_models_success(self, client):
"""Test getting trained models list."""
response = client.get("/admin/training/models")
assert response.status_code == 200
data = response.json()
assert "total" in data
assert "models" in data
assert data["total"] == 2
assert len(data["models"]) == 2
def test_get_training_models_includes_metrics(self, client):
"""Test that models include metrics."""
response = client.get("/admin/training/models")
assert response.status_code == 200
data = response.json()
# Check first model has metrics
model = data["models"][0]
assert "metrics" in model
assert "mAP" in model["metrics"]
assert model["metrics"]["mAP"] is not None
assert "precision" in model["metrics"]
assert "recall" in model["metrics"]
def test_get_training_models_includes_download_url(self, client):
"""Test that completed models have download URLs."""
response = client.get("/admin/training/models")
assert response.status_code == 200
data = response.json()
# Check completed models have download URLs
for model in data["models"]:
if model["status"] == "completed":
assert "download_url" in model
assert model["download_url"] is not None
def test_get_training_models_filter_by_status(self, client):
"""Test filtering models by status."""
response = client.get("/admin/training/models?status=completed")
assert response.status_code == 200
data = response.json()
# All returned models should be completed
for model in data["models"]:
assert model["status"] == "completed"
def test_get_training_models_pagination(self, client):
"""Test pagination for models."""
response = client.get("/admin/training/models?limit=1&offset=0")
assert response.status_code == 200
data = response.json()
assert data["limit"] == 1
assert data["offset"] == 0
assert len(data["models"]) == 1