WIP
tests/data/test_admin_models_v2.py (new file, +524)
@@ -0,0 +1,524 @@
"""
Tests for Admin Models v2 - Batch Upload and Training Links.

Tests for new SQLModel classes: BatchUpload, BatchUploadFile,
TrainingDocumentLink, AnnotationHistory.
"""

import pytest
from datetime import datetime
from uuid import UUID, uuid4

from src.data.admin_models import (
    BatchUpload,
    BatchUploadFile,
    TrainingDocumentLink,
    AnnotationHistory,
    AdminDocument,
    AdminAnnotation,
    TrainingTask,
    FIELD_CLASSES,
    CSV_TO_CLASS_MAPPING,
)


class TestBatchUpload:
    """Tests for BatchUpload model."""

    def test_batch_upload_creation(self):
        """Test basic batch upload creation."""
        batch = BatchUpload(
            admin_token="test-token",
            filename="invoices.zip",
            file_size=1024000,
            upload_source="ui",
        )

        assert batch.batch_id is not None
        assert isinstance(batch.batch_id, UUID)
        assert batch.admin_token == "test-token"
        assert batch.filename == "invoices.zip"
        assert batch.file_size == 1024000
        assert batch.upload_source == "ui"
        assert batch.status == "processing"
        assert batch.total_files == 0
        assert batch.processed_files == 0
        assert batch.successful_files == 0
        assert batch.failed_files == 0
        assert batch.error_message is None
        assert batch.completed_at is None

    def test_batch_upload_api_source(self):
        """Test batch upload with API source."""
        batch = BatchUpload(
            admin_token="api-token",
            filename="batch.zip",
            file_size=2048000,
            upload_source="api",
        )

        assert batch.upload_source == "api"

    def test_batch_upload_with_progress(self):
        """Test batch upload with progress tracking."""
        batch = BatchUpload(
            admin_token="test-token",
            filename="large_batch.zip",
            file_size=10240000,
            total_files=100,
            processed_files=50,
            successful_files=48,
            failed_files=2,
            status="processing",
        )

        assert batch.total_files == 100
        assert batch.processed_files == 50
        assert batch.successful_files == 48
        assert batch.failed_files == 2

    def test_batch_upload_completed(self):
        """Test completed batch upload."""
        now = datetime.utcnow()
        batch = BatchUpload(
            admin_token="test-token",
            filename="batch.zip",
            file_size=1024000,
            status="completed",
            total_files=10,
            processed_files=10,
            successful_files=10,
            failed_files=0,
            completed_at=now,
        )

        assert batch.status == "completed"
        assert batch.completed_at == now

    def test_batch_upload_failed(self):
        """Test failed batch upload."""
        batch = BatchUpload(
            admin_token="test-token",
            filename="bad.zip",
            file_size=1024,
            status="failed",
            error_message="Invalid ZIP file format",
        )

        assert batch.status == "failed"
        assert batch.error_message == "Invalid ZIP file format"

    def test_batch_upload_partial(self):
        """Test partial batch upload with some failures."""
        batch = BatchUpload(
            admin_token="test-token",
            filename="mixed.zip",
            file_size=5120000,
            status="partial",
            total_files=20,
            processed_files=20,
            successful_files=15,
            failed_files=5,
        )

        assert batch.status == "partial"
        assert batch.failed_files == 5


class TestBatchUploadFile:
    """Tests for BatchUploadFile model."""

    def test_batch_upload_file_creation(self):
        """Test basic file record creation."""
        batch_id = uuid4()
        file_record = BatchUploadFile(
            batch_id=batch_id,
            filename="INV001.pdf",
        )

        assert file_record.file_id is not None
        assert isinstance(file_record.file_id, UUID)
        assert file_record.batch_id == batch_id
        assert file_record.filename == "INV001.pdf"
        assert file_record.status == "pending"
        assert file_record.document_id is None
        assert file_record.error_message is None
        assert file_record.csv_row_data is None
        assert file_record.processed_at is None

    def test_batch_upload_file_with_document(self):
        """Test file record linked to document."""
        batch_id = uuid4()
        document_id = uuid4()
        file_record = BatchUploadFile(
            batch_id=batch_id,
            document_id=document_id,
            filename="INV002.pdf",
            status="completed",
        )

        assert file_record.document_id == document_id
        assert file_record.status == "completed"

    def test_batch_upload_file_with_csv_data(self):
        """Test file record with CSV row data."""
        batch_id = uuid4()
        csv_data = {
            "DocumentId": "INV003",
            "InvoiceNumber": "F2024-003",
            "Amount": "1500.00",
            "OCR": "7350012345678",
        }
        file_record = BatchUploadFile(
            batch_id=batch_id,
            filename="INV003.pdf",
            csv_row_data=csv_data,
        )

        assert file_record.csv_row_data == csv_data
        assert file_record.csv_row_data["InvoiceNumber"] == "F2024-003"

    def test_batch_upload_file_failed(self):
        """Test failed file record."""
        batch_id = uuid4()
        file_record = BatchUploadFile(
            batch_id=batch_id,
            filename="corrupted.pdf",
            status="failed",
            error_message="Corrupted PDF file",
        )

        assert file_record.status == "failed"
        assert file_record.error_message == "Corrupted PDF file"

    def test_batch_upload_file_skipped(self):
        """Test skipped file record."""
        batch_id = uuid4()
        file_record = BatchUploadFile(
            batch_id=batch_id,
            filename="not_a_pdf.txt",
            status="skipped",
            error_message="Not a PDF file",
        )

        assert file_record.status == "skipped"


class TestTrainingDocumentLink:
    """Tests for TrainingDocumentLink model."""

    def test_training_document_link_creation(self):
        """Test basic link creation."""
        task_id = uuid4()
        document_id = uuid4()
        link = TrainingDocumentLink(
            task_id=task_id,
            document_id=document_id,
        )

        assert link.link_id is not None
        assert isinstance(link.link_id, UUID)
        assert link.task_id == task_id
        assert link.document_id == document_id
        assert link.annotation_snapshot is None

    def test_training_document_link_with_snapshot(self):
        """Test link with annotation snapshot."""
        task_id = uuid4()
        document_id = uuid4()
        snapshot = {
            "annotations": [
                {
                    "class_id": 0,
                    "class_name": "invoice_number",
                    "text_value": "F2024-001",
                    "x_center": 0.5,
                    "y_center": 0.3,
                },
                {
                    "class_id": 6,
                    "class_name": "amount",
                    "text_value": "1500.00",
                    "x_center": 0.7,
                    "y_center": 0.6,
                },
            ],
            "total_count": 2,
            "snapshot_time": "2024-01-20T15:00:00",
        }
        link = TrainingDocumentLink(
            task_id=task_id,
            document_id=document_id,
            annotation_snapshot=snapshot,
        )

        assert link.annotation_snapshot == snapshot
        assert len(link.annotation_snapshot["annotations"]) == 2


class TestAnnotationHistory:
    """Tests for AnnotationHistory model."""

    def test_annotation_history_created(self):
        """Test history record for creation."""
        annotation_id = uuid4()
        new_value = {
            "class_id": 0,
            "class_name": "invoice_number",
            "text_value": "F2024-001",
            "bbox_x": 100,
            "bbox_y": 200,
            "bbox_width": 150,
            "bbox_height": 30,
            "source": "manual",
        }
        history = AnnotationHistory(
            annotation_id=annotation_id,
            action="created",
            new_value=new_value,
            changed_by="admin-token-123",
        )

        assert history.history_id is not None
        assert history.annotation_id == annotation_id
        assert history.action == "created"
        assert history.previous_value is None
        assert history.new_value == new_value
        assert history.changed_by == "admin-token-123"

    def test_annotation_history_updated(self):
        """Test history record for update."""
        annotation_id = uuid4()
        previous_value = {
            "text_value": "F2024-001",
            "bbox_x": 100,
        }
        new_value = {
            "text_value": "F2024-001-A",
            "bbox_x": 110,
        }
        history = AnnotationHistory(
            annotation_id=annotation_id,
            action="updated",
            previous_value=previous_value,
            new_value=new_value,
            changed_by="admin-token-123",
            change_reason="Corrected OCR error",
        )

        assert history.action == "updated"
        assert history.previous_value == previous_value
        assert history.new_value == new_value
        assert history.change_reason == "Corrected OCR error"

    def test_annotation_history_override(self):
        """Test history record for override."""
        annotation_id = uuid4()
        previous_value = {
            "text_value": "F2024-001",
            "source": "auto",
            "confidence": 0.85,
        }
        new_value = {
            "text_value": "F2024-001-CORRECTED",
            "source": "manual",
            "confidence": None,
        }
        history = AnnotationHistory(
            annotation_id=annotation_id,
            action="override",
            previous_value=previous_value,
            new_value=new_value,
            changed_by="admin-token-123",
            change_reason="Manual correction of auto-label",
        )

        assert history.action == "override"

    def test_annotation_history_deleted(self):
        """Test history record for deletion."""
        annotation_id = uuid4()
        previous_value = {
            "class_id": 6,
            "class_name": "amount",
            "text_value": "1500.00",
        }
        history = AnnotationHistory(
            annotation_id=annotation_id,
            action="deleted",
            previous_value=previous_value,
            changed_by="admin-token-123",
            change_reason="Incorrect annotation",
        )

        assert history.action == "deleted"
        assert history.new_value is None


class TestAdminDocumentExtensions:
    """Tests for AdminDocument extensions."""

    def test_document_with_upload_source(self):
        """Test document with upload source field."""
        doc = AdminDocument(
            admin_token="test-token",
            filename="invoice.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/tmp/invoice.pdf",
            upload_source="api",
        )

        assert doc.upload_source == "api"

    def test_document_with_batch_id(self):
        """Test document linked to batch upload."""
        batch_id = uuid4()
        doc = AdminDocument(
            admin_token="test-token",
            filename="invoice.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/tmp/invoice.pdf",
            batch_id=batch_id,
        )

        assert doc.batch_id == batch_id

    def test_document_with_csv_field_values(self):
        """Test document with CSV field values."""
        csv_values = {
            "InvoiceNumber": "F2024-001",
            "Amount": "1500.00",
            "OCR": "7350012345678",
        }
        doc = AdminDocument(
            admin_token="test-token",
            filename="invoice.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/tmp/invoice.pdf",
            csv_field_values=csv_values,
        )

        assert doc.csv_field_values == csv_values

    def test_document_with_annotation_lock(self):
        """Test document with annotation lock."""
        lock_until = datetime.utcnow()
        doc = AdminDocument(
            admin_token="test-token",
            filename="invoice.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/tmp/invoice.pdf",
            annotation_lock_until=lock_until,
        )

        assert doc.annotation_lock_until == lock_until


class TestAdminAnnotationExtensions:
    """Tests for AdminAnnotation extensions."""

    def test_annotation_with_verification(self):
        """Test annotation with verification fields."""
        now = datetime.utcnow()
        ann = AdminAnnotation(
            document_id=uuid4(),
            class_id=0,
            class_name="invoice_number",
            x_center=0.5,
            y_center=0.3,
            width=0.2,
            height=0.05,
            bbox_x=100,
            bbox_y=200,
            bbox_width=150,
            bbox_height=30,
            is_verified=True,
            verified_at=now,
            verified_by="admin-token-123",
        )

        assert ann.is_verified is True
        assert ann.verified_at == now
        assert ann.verified_by == "admin-token-123"

    def test_annotation_with_override_info(self):
        """Test annotation with override information."""
        original_id = uuid4()
        ann = AdminAnnotation(
            document_id=uuid4(),
            class_id=0,
            class_name="invoice_number",
            x_center=0.5,
            y_center=0.3,
            width=0.2,
            height=0.05,
            bbox_x=100,
            bbox_y=200,
            bbox_width=150,
            bbox_height=30,
            source="manual",
            override_source="auto",
            original_annotation_id=original_id,
        )

        assert ann.override_source == "auto"
        assert ann.original_annotation_id == original_id


class TestTrainingTaskExtensions:
    """Tests for TrainingTask extensions."""

    def test_training_task_with_document_count(self):
        """Test training task with document count."""
        task = TrainingTask(
            admin_token="test-token",
            name="Training Run 2024-01",
            document_count=500,
        )

        assert task.document_count == 500

    def test_training_task_with_metrics(self):
        """Test training task with extracted metrics."""
        task = TrainingTask(
            admin_token="test-token",
            name="Training Run 2024-01",
            status="completed",
            metrics_mAP=0.935,
            metrics_precision=0.92,
            metrics_recall=0.88,
        )

        assert task.metrics_mAP == 0.935
        assert task.metrics_precision == 0.92
        assert task.metrics_recall == 0.88


class TestCSVToClassMapping:
    """Tests for CSV column to class ID mapping."""

    def test_csv_mapping_exists(self):
        """Test that CSV mapping is defined."""
        assert CSV_TO_CLASS_MAPPING is not None
        assert len(CSV_TO_CLASS_MAPPING) > 0

    def test_csv_mapping_values(self):
        """Test specific CSV column mappings."""
        assert CSV_TO_CLASS_MAPPING["InvoiceNumber"] == 0
        assert CSV_TO_CLASS_MAPPING["InvoiceDate"] == 1
        assert CSV_TO_CLASS_MAPPING["InvoiceDueDate"] == 2
        assert CSV_TO_CLASS_MAPPING["OCR"] == 3
        assert CSV_TO_CLASS_MAPPING["Bankgiro"] == 4
        assert CSV_TO_CLASS_MAPPING["Plusgiro"] == 5
        assert CSV_TO_CLASS_MAPPING["Amount"] == 6
        assert CSV_TO_CLASS_MAPPING["supplier_organisation_number"] == 7
        assert CSV_TO_CLASS_MAPPING["customer_number"] == 9

    def test_csv_mapping_matches_field_classes(self):
        """Test that CSV mapping is consistent with FIELD_CLASSES."""
        for csv_name, class_id in CSV_TO_CLASS_MAPPING.items():
            assert class_id in FIELD_CLASSES
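Note: the assertions above pin down the model shape. Below is a minimal sketch of a BatchUpload definition consistent with these tests, assuming SQLModel (named in the module docstring); field names and defaults are taken from the assertions, everything else is hypothetical.

from datetime import datetime
from typing import Optional
from uuid import UUID, uuid4

from sqlmodel import Field, SQLModel


class BatchUpload(SQLModel, table=True):
    # batch_id is asserted to be a UUID before any DB insert, which implies
    # a client-side default_factory rather than a database-generated default.
    batch_id: UUID = Field(default_factory=uuid4, primary_key=True)
    admin_token: str
    filename: str
    file_size: int
    upload_source: str
    status: str = "processing"  # default asserted in test_batch_upload_creation
    total_files: int = 0
    processed_files: int = 0
    successful_files: int = 0
    failed_files: int = 0
    error_message: Optional[str] = None
    completed_at: Optional[datetime] = None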
@@ -18,7 +18,7 @@ class TestDatabaseConfig:
     def test_config_loads_from_env(self):
         """Test that config loads successfully from .env file."""
         # Import config (should load .env automatically)
-        import config
+        from src import config
 
         # Verify database config is loaded
         assert config.DATABASE is not None
@@ -30,7 +30,7 @@ class TestDatabaseConfig:
 
     def test_database_password_loaded(self):
         """Test that database password is loaded from environment."""
-        import config
+        from src import config
 
         # Password should be loaded from .env
         assert config.DATABASE['password'] is not None
@@ -38,7 +38,7 @@ class TestDatabaseConfig:
 
     def test_database_connection_string(self):
         """Test database connection string generation."""
-        import config
+        from src import config
 
         conn_str = config.get_db_connection_string()
 
@@ -71,7 +71,7 @@ class TestPathsConfig:
 
     def test_paths_config_exists(self):
         """Test that PATHS configuration exists."""
-        import config
+        from src import config
 
         assert config.PATHS is not None
         assert 'csv_dir' in config.PATHS
@@ -85,7 +85,7 @@ class TestAutolabelConfig:
 
     def test_autolabel_config_exists(self):
        """Test that AUTOLABEL configuration exists."""
-        import config
+        from src import config
 
         assert config.AUTOLABEL is not None
         assert 'workers' in config.AUTOLABEL
@@ -95,7 +95,7 @@ class TestAutolabelConfig:
 
     def test_autolabel_ratios_sum_to_one(self):
         """Test that train/val/test ratios sum to 1.0."""
-        import config
+        from src import config
 
         total = (
             config.AUTOLABEL['train_ratio'] +
tests/web/__init__.py (new file, +1)
@@ -0,0 +1 @@
"""Tests for web API components."""
tests/web/conftest.py (new file, +132)
@@ -0,0 +1,132 @@
"""
Test fixtures for web API tests.
"""

import tempfile
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import MagicMock, patch
from uuid import UUID

import pytest

from src.data.async_request_db import ApiKeyConfig, AsyncRequestDB
from src.data.models import AsyncRequest
from src.web.workers.async_queue import AsyncTask, AsyncTaskQueue
from src.web.services.async_processing import AsyncProcessingService
from src.web.config import AsyncConfig, StorageConfig
from src.web.core.rate_limiter import RateLimiter


@pytest.fixture
def mock_db():
    """Create a mock AsyncRequestDB."""
    db = MagicMock(spec=AsyncRequestDB)

    # Default return values
    db.is_valid_api_key.return_value = True
    db.get_api_key_config.return_value = ApiKeyConfig(
        api_key="test-api-key",
        name="Test Key",
        is_active=True,
        requests_per_minute=10,
        max_concurrent_jobs=3,
        max_file_size_mb=50,
    )
    db.count_active_jobs.return_value = 0
    db.get_queue_position.return_value = 1

    return db


@pytest.fixture
def rate_limiter(mock_db):
    """Create a RateLimiter with mock database."""
    return RateLimiter(mock_db)


@pytest.fixture
def task_queue():
    """Create an AsyncTaskQueue."""
    return AsyncTaskQueue(max_size=10, worker_count=1)


@pytest.fixture
def async_config():
    """Create an AsyncConfig for testing."""
    with tempfile.TemporaryDirectory() as tmpdir:
        yield AsyncConfig(
            queue_max_size=10,
            worker_count=1,
            task_timeout_seconds=30,
            result_retention_days=7,
            temp_upload_dir=Path(tmpdir) / "async",
            max_file_size_mb=10,
        )


@pytest.fixture
def storage_config():
    """Create a StorageConfig for testing."""
    with tempfile.TemporaryDirectory() as tmpdir:
        yield StorageConfig(
            upload_dir=Path(tmpdir) / "uploads",
            result_dir=Path(tmpdir) / "results",
            max_file_size_mb=50,
        )


@pytest.fixture
def mock_inference_service():
    """Create a mock InferenceService."""
    service = MagicMock()
    service.is_initialized = True
    service.gpu_available = False

    # Mock process_pdf to return a successful result
    mock_result = MagicMock()
    mock_result.document_id = "test-doc"
    mock_result.success = True
    mock_result.document_type = "invoice"
    mock_result.fields = {"InvoiceNumber": "12345", "Amount": "1000.00"}
    mock_result.confidence = {"InvoiceNumber": 0.95, "Amount": 0.92}
    mock_result.detections = []
    mock_result.errors = []
    mock_result.visualization_path = None

    service.process_pdf.return_value = mock_result
    service.process_image.return_value = mock_result

    return service


# Valid UUID for testing
TEST_REQUEST_UUID = "550e8400-e29b-41d4-a716-446655440000"


@pytest.fixture
def sample_async_request():
    """Create a sample AsyncRequest."""
    return AsyncRequest(
        request_id=UUID(TEST_REQUEST_UUID),
        api_key="test-api-key",
        status="pending",
        filename="test.pdf",
        file_size=1024,
        content_type="application/pdf",
        expires_at=datetime.utcnow() + timedelta(days=7),
    )


@pytest.fixture
def sample_task():
    """Create a sample AsyncTask."""
    with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
        f.write(b"fake pdf content")
    return AsyncTask(
        request_id=TEST_REQUEST_UUID,
        api_key="test-api-key",
        file_path=Path(f.name),
        filename="test.pdf",
        created_at=datetime.utcnow(),
    )
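Note: a hypothetical test module showing how pytest injects these fixtures by name; it relies only on the return values configured on the mock above, not on any real API.

def test_mock_db_defaults(mock_db):
    # A MagicMock returns its configured value regardless of call arguments.
    assert mock_db.is_valid_api_key("any-key") is True
    cfg = mock_db.get_api_key_config("test-api-key")
    assert cfg.requests_per_minute == 10
    assert mock_db.count_active_jobs("test-api-key") == 0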
tests/web/test_admin_annotations.py (new file, +197)
@@ -0,0 +1,197 @@
"""
Tests for Admin Annotation Routes.
"""

import pytest
from datetime import datetime
from unittest.mock import MagicMock, patch
from uuid import UUID

from fastapi import HTTPException

from src.data.admin_models import AdminAnnotation, AdminDocument, FIELD_CLASSES
from src.web.api.v1.admin.annotations import _validate_uuid, create_annotation_router
from src.web.schemas.admin import (
    AnnotationCreate,
    AnnotationUpdate,
    AutoLabelRequest,
    BoundingBox,
)


# Test UUIDs
TEST_DOC_UUID = "550e8400-e29b-41d4-a716-446655440000"
TEST_ANN_UUID = "660e8400-e29b-41d4-a716-446655440001"
TEST_TOKEN = "test-admin-token-12345"


class TestAnnotationRouterCreation:
    """Tests for annotation router creation."""

    def test_creates_router_with_endpoints(self):
        """Test router is created with expected endpoints."""
        router = create_annotation_router()

        # Get route paths (includes prefix)
        paths = [route.path for route in router.routes]

        # Paths include the /admin/documents prefix
        assert any("{document_id}/annotations" in p for p in paths)
        assert any("{annotation_id}" in p for p in paths)
        assert any("auto-label" in p for p in paths)
        assert any("images" in p for p in paths)


class TestAnnotationCreateSchema:
    """Tests for AnnotationCreate schema."""

    def test_valid_annotation(self):
        """Test valid annotation creation."""
        ann = AnnotationCreate(
            page_number=1,
            class_id=0,
            bbox=BoundingBox(x=100, y=100, width=200, height=50),
            text_value="12345",
        )

        assert ann.page_number == 1
        assert ann.class_id == 0
        assert ann.bbox.x == 100
        assert ann.text_value == "12345"

    def test_class_id_range(self):
        """Test class_id must be 0-9."""
        # Valid class IDs
        for class_id in range(10):
            ann = AnnotationCreate(
                page_number=1,
                class_id=class_id,
                bbox=BoundingBox(x=0, y=0, width=100, height=50),
            )
            assert ann.class_id == class_id

    def test_bbox_validation(self):
        """Test bounding box validation."""
        bbox = BoundingBox(x=0, y=0, width=100, height=50)
        assert bbox.width >= 1
        assert bbox.height >= 1


class TestAnnotationUpdateSchema:
    """Tests for AnnotationUpdate schema."""

    def test_partial_update(self):
        """Test partial update with only some fields."""
        update = AnnotationUpdate(
            text_value="new value",
        )

        assert update.text_value == "new value"
        assert update.class_id is None
        assert update.bbox is None

    def test_bbox_update(self):
        """Test bounding box update."""
        update = AnnotationUpdate(
            bbox=BoundingBox(x=50, y=50, width=150, height=75),
        )

        assert update.bbox.x == 50
        assert update.bbox.width == 150


class TestAutoLabelRequestSchema:
    """Tests for AutoLabelRequest schema."""

    def test_valid_request(self):
        """Test valid auto-label request."""
        request = AutoLabelRequest(
            field_values={
                "InvoiceNumber": "12345",
                "Amount": "1000.00",
            },
            replace_existing=True,
        )

        assert len(request.field_values) == 2
        assert request.field_values["InvoiceNumber"] == "12345"
        assert request.replace_existing is True

    def test_requires_field_values(self):
        """Test that field_values is required."""
        with pytest.raises(Exception):
            AutoLabelRequest(replace_existing=True)


class TestFieldClasses:
    """Tests for field class mapping."""

    def test_all_classes_defined(self):
        """Test all 10 field classes are defined."""
        assert len(FIELD_CLASSES) == 10

    def test_class_ids_sequential(self):
        """Test class IDs are 0-9."""
        assert set(FIELD_CLASSES.keys()) == set(range(10))

    def test_known_field_names(self):
        """Test known field names are present."""
        names = list(FIELD_CLASSES.values())

        assert "invoice_number" in names
        assert "invoice_date" in names
        assert "amount" in names
        assert "bankgiro" in names
        assert "ocr_number" in names


class TestAnnotationModel:
    """Tests for AdminAnnotation model."""

    def test_annotation_creation(self):
        """Test annotation model creation."""
        ann = AdminAnnotation(
            document_id=UUID(TEST_DOC_UUID),
            page_number=1,
            class_id=0,
            class_name="invoice_number",
            x_center=0.5,
            y_center=0.5,
            width=0.2,
            height=0.05,
            bbox_x=100,
            bbox_y=100,
            bbox_width=200,
            bbox_height=50,
            text_value="12345",
            confidence=0.95,
            source="manual",
        )

        assert str(ann.document_id) == TEST_DOC_UUID
        assert ann.class_id == 0
        assert ann.x_center == 0.5
        assert ann.source == "manual"

    def test_normalized_coordinates(self):
        """Test normalized coordinates are 0-1 range."""
        # Valid normalized coords
        ann = AdminAnnotation(
            document_id=UUID(TEST_DOC_UUID),
            page_number=1,
            class_id=0,
            class_name="test",
            x_center=0.5,
            y_center=0.5,
            width=0.2,
            height=0.05,
            bbox_x=0,
            bbox_y=0,
            bbox_width=100,
            bbox_height=50,
        )

        assert 0 <= ann.x_center <= 1
        assert 0 <= ann.y_center <= 1
        assert 0 <= ann.width <= 1
        assert 0 <= ann.height <= 1
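Note: a sketch (not the project's actual src/web/schemas/admin.py) of Pydantic schemas consistent with the constraints these tests exercise: class_id limited to 0-9 and bbox width/height of at least 1. The exact field types are assumptions.

from typing import Optional

from pydantic import BaseModel, Field


class BoundingBox(BaseModel):
    x: float
    y: float
    width: float = Field(ge=1)   # test_bbox_validation asserts width >= 1
    height: float = Field(ge=1)  # and height >= 1


class AnnotationCreate(BaseModel):
    page_number: int = Field(ge=1)
    class_id: int = Field(ge=0, le=9)  # test_class_id_range covers 0-9
    bbox: BoundingBox
    text_value: Optional[str] = None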
tests/web/test_admin_auth.py (new file, +162)
@@ -0,0 +1,162 @@
"""
Tests for Admin Authentication.
"""

import pytest
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch

from fastapi import HTTPException

from src.data.admin_db import AdminDB
from src.data.admin_models import AdminToken
from src.web.core.auth import (
    get_admin_db,
    reset_admin_db,
    validate_admin_token,
)


@pytest.fixture
def mock_admin_db():
    """Create a mock AdminDB."""
    db = MagicMock(spec=AdminDB)
    db.is_valid_admin_token.return_value = True
    return db


@pytest.fixture(autouse=True)
def reset_db():
    """Reset admin DB after each test."""
    yield
    reset_admin_db()


class TestValidateAdminToken:
    """Tests for validate_admin_token dependency."""

    def test_missing_token_raises_401(self, mock_admin_db):
        """Test that missing token raises 401."""
        import asyncio

        with pytest.raises(HTTPException) as exc_info:
            asyncio.get_event_loop().run_until_complete(
                validate_admin_token(None, mock_admin_db)
            )

        assert exc_info.value.status_code == 401
        assert "Admin token required" in exc_info.value.detail

    def test_invalid_token_raises_401(self, mock_admin_db):
        """Test that invalid token raises 401."""
        import asyncio

        mock_admin_db.is_valid_admin_token.return_value = False

        with pytest.raises(HTTPException) as exc_info:
            asyncio.get_event_loop().run_until_complete(
                validate_admin_token("invalid-token", mock_admin_db)
            )

        assert exc_info.value.status_code == 401
        assert "Invalid or expired" in exc_info.value.detail

    def test_valid_token_returns_token(self, mock_admin_db):
        """Test that valid token is returned."""
        import asyncio

        token = "valid-test-token"
        mock_admin_db.is_valid_admin_token.return_value = True

        result = asyncio.get_event_loop().run_until_complete(
            validate_admin_token(token, mock_admin_db)
        )

        assert result == token
        mock_admin_db.update_admin_token_usage.assert_called_once_with(token)


class TestAdminDB:
    """Tests for AdminDB operations."""

    def test_is_valid_admin_token_active(self):
        """Test valid active token."""
        with patch("src.data.admin_db.get_session_context") as mock_ctx:
            mock_session = MagicMock()
            mock_ctx.return_value.__enter__.return_value = mock_session

            mock_token = AdminToken(
                token="test-token",
                name="Test",
                is_active=True,
                expires_at=None,
            )
            mock_session.get.return_value = mock_token

            db = AdminDB()
            assert db.is_valid_admin_token("test-token") is True

    def test_is_valid_admin_token_inactive(self):
        """Test inactive token."""
        with patch("src.data.admin_db.get_session_context") as mock_ctx:
            mock_session = MagicMock()
            mock_ctx.return_value.__enter__.return_value = mock_session

            mock_token = AdminToken(
                token="test-token",
                name="Test",
                is_active=False,
                expires_at=None,
            )
            mock_session.get.return_value = mock_token

            db = AdminDB()
            assert db.is_valid_admin_token("test-token") is False

    def test_is_valid_admin_token_expired(self):
        """Test expired token."""
        with patch("src.data.admin_db.get_session_context") as mock_ctx:
            mock_session = MagicMock()
            mock_ctx.return_value.__enter__.return_value = mock_session

            mock_token = AdminToken(
                token="test-token",
                name="Test",
                is_active=True,
                expires_at=datetime.utcnow() - timedelta(days=1),
            )
            mock_session.get.return_value = mock_token

            db = AdminDB()
            assert db.is_valid_admin_token("test-token") is False

    def test_is_valid_admin_token_not_found(self):
        """Test token not found."""
        with patch("src.data.admin_db.get_session_context") as mock_ctx:
            mock_session = MagicMock()
            mock_ctx.return_value.__enter__.return_value = mock_session
            mock_session.get.return_value = None

            db = AdminDB()
            assert db.is_valid_admin_token("nonexistent") is False


class TestGetAdminDb:
    """Tests for get_admin_db function."""

    def test_returns_singleton(self):
        """Test that get_admin_db returns singleton."""
        reset_admin_db()

        db1 = get_admin_db()
        db2 = get_admin_db()

        assert db1 is db2

    def test_reset_clears_singleton(self):
        """Test that reset clears singleton."""
        db1 = get_admin_db()
        reset_admin_db()
        db2 = get_admin_db()

        assert db1 is not db2
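Note: asyncio.get_event_loop() with no running loop is deprecated on newer Python versions; asyncio.run() is the usual replacement and would shorten the three async tests above, e.g.:

import asyncio

result = asyncio.run(validate_admin_token("valid-test-token", mock_admin_db))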
tests/web/test_admin_routes.py (new file, +164)
@@ -0,0 +1,164 @@
"""
Tests for Admin Document Routes.
"""

import pytest
from datetime import datetime
from io import BytesIO
from pathlib import Path
from unittest.mock import MagicMock, patch
from uuid import UUID

from fastapi import HTTPException
from fastapi.testclient import TestClient

from src.data.admin_models import AdminDocument, AdminToken
from src.web.api.v1.admin.documents import _validate_uuid, create_admin_router


# Test UUIDs
TEST_DOC_UUID = "550e8400-e29b-41d4-a716-446655440000"
TEST_TOKEN = "test-admin-token-12345"


class TestValidateUUID:
    """Tests for UUID validation."""

    def test_valid_uuid(self):
        """Test valid UUID passes validation."""
        _validate_uuid(TEST_DOC_UUID, "test")  # Should not raise

    def test_invalid_uuid_raises_400(self):
        """Test invalid UUID raises 400."""
        with pytest.raises(HTTPException) as exc_info:
            _validate_uuid("not-a-uuid", "document_id")

        assert exc_info.value.status_code == 400
        assert "Invalid document_id format" in exc_info.value.detail


class TestAdminRouter:
    """Tests for admin router creation."""

    def test_creates_router_with_endpoints(self):
        """Test router is created with expected endpoints."""
        router = create_admin_router((".pdf", ".png", ".jpg"))

        # Get route paths (include prefix from router)
        paths = [route.path for route in router.routes]

        # Paths include the /admin prefix
        assert any("/auth/token" in p for p in paths)
        assert any("/documents" in p for p in paths)
        assert any("/documents/stats" in p for p in paths)
        assert any("{document_id}" in p for p in paths)


class TestCreateTokenEndpoint:
    """Tests for POST /admin/auth/token endpoint."""

    @pytest.fixture
    def mock_db(self):
        """Create mock AdminDB."""
        db = MagicMock()
        db.is_valid_admin_token.return_value = True
        return db

    def test_create_token_success(self, mock_db):
        """Test successful token creation."""
        from src.web.schemas.admin import AdminTokenCreate

        request = AdminTokenCreate(name="Test Token", expires_in_days=30)

        # The actual endpoint would generate a token
        # This tests the schema validation
        assert request.name == "Test Token"
        assert request.expires_in_days == 30


class TestDocumentUploadEndpoint:
    """Tests for POST /admin/documents endpoint."""

    @pytest.fixture
    def sample_pdf_bytes(self):
        """Create sample PDF-like bytes."""
        # Minimal PDF header
        return b"%PDF-1.4\n%\xe2\xe3\xcf\xd3\n"

    @pytest.fixture
    def mock_admin_db(self):
        """Create mock AdminDB."""
        db = MagicMock()
        db.is_valid_admin_token.return_value = True
        db.create_document.return_value = TEST_DOC_UUID
        return db

    def test_rejects_invalid_extension(self):
        """Test that invalid file extensions are rejected."""
        # Schema validation would happen at the route level
        allowed = (".pdf", ".png", ".jpg")
        file_ext = ".exe"

        assert file_ext not in allowed


class TestDocumentListEndpoint:
    """Tests for GET /admin/documents endpoint."""

    @pytest.fixture
    def sample_documents(self):
        """Create sample documents."""
        return [
            AdminDocument(
                document_id=UUID(TEST_DOC_UUID),
                admin_token=TEST_TOKEN,
                filename="test.pdf",
                file_size=1024,
                content_type="application/pdf",
                file_path="/tmp/test.pdf",
                page_count=1,
                status="pending",
            ),
        ]

    def test_validates_status_filter(self):
        """Test that invalid status filter is rejected."""
        valid_statuses = ("pending", "auto_labeling", "labeled", "exported")

        assert "invalid_status" not in valid_statuses
        assert "pending" in valid_statuses


class TestDocumentDetailEndpoint:
    """Tests for GET /admin/documents/{document_id} endpoint."""

    def test_requires_valid_uuid(self):
        """Test that invalid UUID is rejected."""
        with pytest.raises(HTTPException) as exc_info:
            _validate_uuid("invalid", "document_id")

        assert exc_info.value.status_code == 400


class TestDocumentDeleteEndpoint:
    """Tests for DELETE /admin/documents/{document_id} endpoint."""

    def test_validates_document_id(self):
        """Test that document_id is validated."""
        # Valid UUID should not raise
        _validate_uuid(TEST_DOC_UUID, "document_id")

        # Invalid should raise
        with pytest.raises(HTTPException):
            _validate_uuid("bad-id", "document_id")


class TestDocumentStatusUpdateEndpoint:
    """Tests for PATCH /admin/documents/{document_id}/status endpoint."""

    def test_validates_status_values(self):
        """Test that only valid statuses are accepted."""
        valid_statuses = ("pending", "labeled", "exported")

        assert "pending" in valid_statuses
        assert "invalid" not in valid_statuses
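Note: a plausible shape for the _validate_uuid helper these tests pin down (status 400, detail containing "Invalid <field> format"); the real implementation lives in src/web/api/v1/admin/documents.py and may differ.

from uuid import UUID

from fastapi import HTTPException


def _validate_uuid(value: str, field_name: str) -> None:
    """Raise a 400 if value is not a well-formed UUID."""
    try:
        UUID(value)
    except (TypeError, ValueError):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid {field_name} format: {value}",
        )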
tests/web/test_admin_routes_enhanced.py (new file, +351)
@@ -0,0 +1,351 @@
"""
Tests for Enhanced Admin Document Routes (Phase 3).
"""

import pytest
from datetime import datetime
from uuid import uuid4

from fastapi import FastAPI
from fastapi.testclient import TestClient

from src.web.api.v1.admin.documents import create_admin_router
from src.web.core.auth import validate_admin_token, get_admin_db


class MockAdminDocument:
    """Mock AdminDocument for testing."""

    def __init__(self, **kwargs):
        self.document_id = kwargs.get('document_id', uuid4())
        self.admin_token = kwargs.get('admin_token', 'test-token')
        self.filename = kwargs.get('filename', 'test.pdf')
        self.file_size = kwargs.get('file_size', 100000)
        self.content_type = kwargs.get('content_type', 'application/pdf')
        self.page_count = kwargs.get('page_count', 1)
        self.status = kwargs.get('status', 'pending')
        self.auto_label_status = kwargs.get('auto_label_status', None)
        self.auto_label_error = kwargs.get('auto_label_error', None)
        self.upload_source = kwargs.get('upload_source', 'ui')
        self.batch_id = kwargs.get('batch_id', None)
        self.csv_field_values = kwargs.get('csv_field_values', None)
        self.annotation_lock_until = kwargs.get('annotation_lock_until', None)
        self.created_at = kwargs.get('created_at', datetime.utcnow())
        self.updated_at = kwargs.get('updated_at', datetime.utcnow())


class MockAnnotation:
    """Mock AdminAnnotation for testing."""

    def __init__(self, **kwargs):
        self.annotation_id = kwargs.get('annotation_id', uuid4())
        self.document_id = kwargs.get('document_id')
        self.page_number = kwargs.get('page_number', 1)
        self.class_id = kwargs.get('class_id', 0)
        self.class_name = kwargs.get('class_name', 'invoice_number')
        self.bbox_x = kwargs.get('bbox_x', 100.0)
        self.bbox_y = kwargs.get('bbox_y', 100.0)
        self.bbox_width = kwargs.get('bbox_width', 200.0)
        self.bbox_height = kwargs.get('bbox_height', 50.0)
        self.x_center = kwargs.get('x_center', 0.5)
        self.y_center = kwargs.get('y_center', 0.5)
        self.width = kwargs.get('width', 0.3)
        self.height = kwargs.get('height', 0.1)
        self.text_value = kwargs.get('text_value', 'INV-001')
        self.confidence = kwargs.get('confidence', 0.95)
        self.source = kwargs.get('source', 'manual')
        self.created_at = kwargs.get('created_at', datetime.utcnow())


class MockAdminDB:
    """Mock AdminDB for testing enhanced features."""

    def __init__(self):
        self.documents = {}
        self.annotations = {}

    def get_documents_by_token(
        self,
        admin_token,
        status=None,
        upload_source=None,
        has_annotations=None,
        auto_label_status=None,
        batch_id=None,
        limit=20,
        offset=0
    ):
        """Get filtered documents."""
        docs = list(self.documents.values())

        # Apply filters
        if status:
            docs = [d for d in docs if d.status == status]
        if upload_source:
            docs = [d for d in docs if d.upload_source == upload_source]
        if has_annotations is not None:
            for d in docs[:]:
                ann_count = len(self.annotations.get(str(d.document_id), []))
                if has_annotations and ann_count == 0:
                    docs.remove(d)
                elif not has_annotations and ann_count > 0:
                    docs.remove(d)
        if auto_label_status:
            docs = [d for d in docs if d.auto_label_status == auto_label_status]
        if batch_id:
            docs = [d for d in docs if str(d.batch_id) == str(batch_id)]

        total = len(docs)
        return docs[offset:offset+limit], total

    def get_annotations_for_document(self, document_id):
        """Get annotations for document."""
        return self.annotations.get(str(document_id), [])

    def count_documents_by_status(self, admin_token):
        """Count documents by status."""
        counts = {}
        for doc in self.documents.values():
            if doc.admin_token == admin_token:
                counts[doc.status] = counts.get(doc.status, 0) + 1
        return counts

    def get_document_by_token(self, document_id, admin_token):
        """Get single document by ID and token."""
        doc = self.documents.get(document_id)
        if doc and doc.admin_token == admin_token:
            return doc
        return None

    def get_document_training_tasks(self, document_id):
        """Get training tasks that used this document."""
        return []  # No training history in this test

    def get_training_task(self, task_id):
        """Get training task by ID."""
        return None  # No training tasks in this test


@pytest.fixture
def app():
    """Create test FastAPI app."""
    app = FastAPI()

    # Create mock DB
    mock_db = MockAdminDB()

    # Add test documents
    doc1 = MockAdminDocument(
        filename="INV001.pdf",
        status="labeled",
        upload_source="ui",
        auto_label_status=None,
        batch_id=None
    )
    doc2 = MockAdminDocument(
        filename="INV002.pdf",
        status="labeled",
        upload_source="api",
        auto_label_status="completed",
        batch_id=uuid4()
    )
    doc3 = MockAdminDocument(
        filename="INV003.pdf",
        status="pending",
        upload_source="ui",
        auto_label_status=None,  # Not auto-labeled yet
        batch_id=None
    )

    mock_db.documents[str(doc1.document_id)] = doc1
    mock_db.documents[str(doc2.document_id)] = doc2
    mock_db.documents[str(doc3.document_id)] = doc3

    # Add annotations to doc1 and doc2
    mock_db.annotations[str(doc1.document_id)] = [
        MockAnnotation(
            document_id=doc1.document_id,
            class_name="invoice_number",
            text_value="INV-001"
        )
    ]
    mock_db.annotations[str(doc2.document_id)] = [
        MockAnnotation(
            document_id=doc2.document_id,
            class_id=6,
            class_name="amount",
            text_value="1500.00"
        ),
        MockAnnotation(
            document_id=doc2.document_id,
            class_id=1,
            class_name="invoice_date",
            text_value="2024-01-15"
        )
    ]

    # Override dependencies
    app.dependency_overrides[validate_admin_token] = lambda: "test-token"
    app.dependency_overrides[get_admin_db] = lambda: mock_db

    # Include router
    router = create_admin_router((".pdf", ".png", ".jpg"))
    app.include_router(router)

    return app


@pytest.fixture
def client(app):
    """Create test client."""
    return TestClient(app)


class TestEnhancedDocumentList:
    """Tests for enhanced document list endpoint."""

    def test_list_documents_filter_by_upload_source_ui(self, client):
        """Test filtering documents by upload_source=ui."""
        response = client.get("/admin/documents?upload_source=ui")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 2
        assert all(doc["filename"].startswith("INV") for doc in data["documents"])

    def test_list_documents_filter_by_upload_source_api(self, client):
        """Test filtering documents by upload_source=api."""
        response = client.get("/admin/documents?upload_source=api")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1
        assert data["documents"][0]["filename"] == "INV002.pdf"

    def test_list_documents_filter_by_has_annotations_true(self, client):
        """Test filtering documents with annotations."""
        response = client.get("/admin/documents?has_annotations=true")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 2

    def test_list_documents_filter_by_has_annotations_false(self, client):
        """Test filtering documents without annotations."""
        response = client.get("/admin/documents?has_annotations=false")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1

    def test_list_documents_filter_by_auto_label_status(self, client):
        """Test filtering by auto_label_status."""
        response = client.get("/admin/documents?auto_label_status=completed")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1
        assert data["documents"][0]["filename"] == "INV002.pdf"

    def test_list_documents_filter_by_batch_id(self, client):
        """Test filtering by batch_id."""
        # Get a batch_id from the test data
        response_all = client.get("/admin/documents?upload_source=api")
        batch_id = response_all.json()["documents"][0]["batch_id"]

        response = client.get(f"/admin/documents?batch_id={batch_id}")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1

    def test_list_documents_combined_filters(self, client):
        """Test combining multiple filters."""
        response = client.get(
            "/admin/documents?status=labeled&upload_source=api"
        )

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1
        assert data["documents"][0]["filename"] == "INV002.pdf"

    def test_document_item_includes_new_fields(self, client):
        """Test DocumentItem includes new Phase 2/3 fields."""
        response = client.get("/admin/documents?upload_source=api")

        assert response.status_code == 200
        data = response.json()
        doc = data["documents"][0]

        # Check new fields exist
        assert "upload_source" in doc
        assert doc["upload_source"] == "api"
        assert "batch_id" in doc
        assert doc["batch_id"] is not None
        assert "can_annotate" in doc
        assert isinstance(doc["can_annotate"], bool)


class TestEnhancedDocumentDetail:
    """Tests for enhanced document detail endpoint."""

    def test_document_detail_includes_new_fields(self, client, app):
        """Test DocumentDetailResponse includes new Phase 2/3 fields."""
        # Get a document ID from list
        response = client.get("/admin/documents?upload_source=api")
        assert response.status_code == 200
        doc_list = response.json()
        document_id = doc_list["documents"][0]["document_id"]

        # Get document detail
        response = client.get(f"/admin/documents/{document_id}")
        assert response.status_code == 200
        doc = response.json()

        # Check new fields exist
        assert "upload_source" in doc
        assert doc["upload_source"] == "api"
        assert "batch_id" in doc
        assert doc["batch_id"] is not None
        assert "can_annotate" in doc
        assert isinstance(doc["can_annotate"], bool)
        assert "csv_field_values" in doc
        assert "annotation_lock_until" in doc

    def test_document_detail_ui_upload_defaults(self, client, app):
        """Test UI-uploaded document has correct defaults."""
        # Get a UI-uploaded document
        response = client.get("/admin/documents?upload_source=ui")
        assert response.status_code == 200
        doc_list = response.json()
        document_id = doc_list["documents"][0]["document_id"]

        # Get document detail
        response = client.get(f"/admin/documents/{document_id}")
        assert response.status_code == 200
        doc = response.json()

        # UI uploads should have these defaults
        assert doc["upload_source"] == "ui"
        assert doc["batch_id"] is None
        assert doc["csv_field_values"] is None
        assert doc["can_annotate"] is True
        assert doc["annotation_lock_until"] is None

    def test_document_detail_with_annotations(self, client, app):
        """Test document detail includes annotations."""
        # Get a document with annotations
        response = client.get("/admin/documents?has_annotations=true")
        assert response.status_code == 200
        doc_list = response.json()
        document_id = doc_list["documents"][0]["document_id"]

        # Get document detail
        response = client.get(f"/admin/documents/{document_id}")
        assert response.status_code == 200
        doc = response.json()

        # Should have annotations
        assert "annotations" in doc
        assert len(doc["annotations"]) > 0
tests/web/test_admin_training.py (new file, +247)
@@ -0,0 +1,247 @@
"""
Tests for Admin Training Routes and Scheduler.
"""

import pytest
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch
from uuid import UUID

from src.data.admin_models import TrainingTask, TrainingLog
from src.web.api.v1.admin.training import _validate_uuid, create_training_router
from src.web.core.scheduler import (
    TrainingScheduler,
    get_training_scheduler,
    start_scheduler,
    stop_scheduler,
)
from src.web.schemas.admin import (
    TrainingConfig,
    TrainingStatus,
    TrainingTaskCreate,
    TrainingType,
)


# Test UUIDs
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002"
TEST_TOKEN = "test-admin-token-12345"


class TestTrainingRouterCreation:
    """Tests for training router creation."""

    def test_creates_router_with_endpoints(self):
        """Test router is created with expected endpoints."""
        router = create_training_router()

        # Get route paths (include prefix)
        paths = [route.path for route in router.routes]

        # Paths include the /admin/training prefix
        assert any("/tasks" in p for p in paths)
        assert any("{task_id}" in p for p in paths)
        assert any("cancel" in p for p in paths)
        assert any("logs" in p for p in paths)
        assert any("export" in p for p in paths)


class TestTrainingConfigSchema:
    """Tests for TrainingConfig schema."""

    def test_default_config(self):
        """Test default training configuration."""
        config = TrainingConfig()

        assert config.model_name == "yolo11n.pt"
        assert config.epochs == 100
        assert config.batch_size == 16
        assert config.image_size == 640
        assert config.learning_rate == 0.01
        assert config.device == "0"

    def test_custom_config(self):
        """Test custom training configuration."""
        config = TrainingConfig(
            model_name="yolo11s.pt",
            epochs=50,
            batch_size=8,
            image_size=416,
            learning_rate=0.001,
            device="cpu",
        )

        assert config.model_name == "yolo11s.pt"
        assert config.epochs == 50
        assert config.batch_size == 8

    def test_config_validation(self):
        """Test config validation constraints."""
        # Epochs must be 1-1000
        config = TrainingConfig(epochs=1)
        assert config.epochs == 1

        config = TrainingConfig(epochs=1000)
        assert config.epochs == 1000


class TestTrainingTaskCreateSchema:
    """Tests for TrainingTaskCreate schema."""

    def test_minimal_task(self):
        """Test minimal task creation."""
        task = TrainingTaskCreate(name="Test Training")

        assert task.name == "Test Training"
        assert task.task_type == TrainingType.TRAIN
        assert task.description is None
        assert task.scheduled_at is None

    def test_scheduled_task(self):
        """Test scheduled task creation."""
        scheduled_time = datetime.utcnow() + timedelta(hours=1)
        task = TrainingTaskCreate(
            name="Scheduled Training",
            scheduled_at=scheduled_time,
        )

        assert task.scheduled_at == scheduled_time

    def test_recurring_task(self):
        """Test recurring task with cron expression."""
        task = TrainingTaskCreate(
            name="Recurring Training",
            cron_expression="0 0 * * 0",  # Every Sunday at midnight
        )

        assert task.cron_expression == "0 0 * * 0"


class TestTrainingTaskModel:
    """Tests for TrainingTask model."""

    def test_task_creation(self):
        """Test training task model creation."""
        task = TrainingTask(
            admin_token=TEST_TOKEN,
            name="Test Task",
            task_type="train",
            status="pending",
        )

        assert task.name == "Test Task"
        assert task.task_type == "train"
        assert task.status == "pending"

    def test_task_with_config(self):
        """Test task with configuration."""
        config = {
            "model_name": "yolo11n.pt",
            "epochs": 100,
        }
        task = TrainingTask(
            admin_token=TEST_TOKEN,
            name="Configured Task",
            task_type="train",
            config=config,
        )

        assert task.config == config
        assert task.config["epochs"] == 100


class TestTrainingLogModel:
    """Tests for TrainingLog model."""

    def test_log_creation(self):
        """Test training log creation."""
        log = TrainingLog(
            task_id=UUID(TEST_TASK_UUID),
            level="INFO",
            message="Training started",
        )

        assert str(log.task_id) == TEST_TASK_UUID
        assert log.level == "INFO"
        assert log.message == "Training started"

    def test_log_with_details(self):
        """Test log with additional details."""
        details = {
            "epoch": 10,
            "loss": 0.5,
            "mAP": 0.85,
        }
        log = TrainingLog(
            task_id=UUID(TEST_TASK_UUID),
            level="INFO",
            message="Epoch completed",
            details=details,
        )

        assert log.details == details
        assert log.details["epoch"] == 10


class TestTrainingScheduler:
    """Tests for TrainingScheduler."""

    @pytest.fixture
    def scheduler(self):
        """Create a scheduler for testing."""
        return TrainingScheduler(check_interval_seconds=1)

    def test_scheduler_creation(self, scheduler):
        """Test scheduler creation."""
        assert scheduler._check_interval == 1
        assert scheduler._running is False
        assert scheduler._thread is None

    def test_scheduler_start_stop(self, scheduler):
        """Test scheduler start and stop."""
        with patch.object(scheduler, "_check_pending_tasks"):
            scheduler.start()
            assert scheduler._running is True
            assert scheduler._thread is not None

            scheduler.stop()
            assert scheduler._running is False

    def test_scheduler_singleton(self):
        """Test get_training_scheduler returns singleton."""
        # Reset any existing scheduler
        stop_scheduler()

        s1 = get_training_scheduler()
        s2 = get_training_scheduler()

        assert s1 is s2

        # Cleanup
        stop_scheduler()


class TestTrainingStatusEnum:
    """Tests for TrainingStatus enum."""

    def test_all_statuses(self):
        """Test all training statuses are defined."""
        statuses = [s.value for s in TrainingStatus]

        assert "pending" in statuses
        assert "scheduled" in statuses
        assert "running" in statuses
        assert "completed" in statuses
        assert "failed" in statuses
        assert "cancelled" in statuses


class TestTrainingTypeEnum:
    """Tests for TrainingType enum."""

    def test_all_types(self):
        """Test all training types are defined."""
        types = [t.value for t in TrainingType]

        assert "train" in types
        assert "finetune" in types
276
tests/web/test_annotation_locks.py
Normal file
@@ -0,0 +1,276 @@
"""
Tests for Annotation Lock Mechanism (Phase 3.3).
"""

import pytest
from datetime import datetime, timedelta, timezone
from uuid import uuid4

from fastapi import FastAPI
from fastapi.testclient import TestClient

from src.web.api.v1.admin.documents import create_admin_router
from src.web.core.auth import validate_admin_token, get_admin_db


class MockAdminDocument:
    """Mock AdminDocument for testing."""

    def __init__(self, **kwargs):
        self.document_id = kwargs.get('document_id', uuid4())
        self.admin_token = kwargs.get('admin_token', 'test-token')
        self.filename = kwargs.get('filename', 'test.pdf')
        self.file_size = kwargs.get('file_size', 100000)
        self.content_type = kwargs.get('content_type', 'application/pdf')
        self.page_count = kwargs.get('page_count', 1)
        self.status = kwargs.get('status', 'pending')
        self.auto_label_status = kwargs.get('auto_label_status', None)
        self.auto_label_error = kwargs.get('auto_label_error', None)
        self.upload_source = kwargs.get('upload_source', 'ui')
        self.batch_id = kwargs.get('batch_id', None)
        self.csv_field_values = kwargs.get('csv_field_values', None)
        self.annotation_lock_until = kwargs.get('annotation_lock_until', None)
        self.created_at = kwargs.get('created_at', datetime.utcnow())
        self.updated_at = kwargs.get('updated_at', datetime.utcnow())


class MockAdminDB:
    """Mock AdminDB for testing annotation locks."""

    def __init__(self):
        self.documents = {}

    def get_document_by_token(self, document_id, admin_token):
        """Get single document by ID and token."""
        doc = self.documents.get(document_id)
        if doc and doc.admin_token == admin_token:
            return doc
        return None

    def acquire_annotation_lock(self, document_id, admin_token, duration_seconds=300):
        """Acquire annotation lock for a document."""
        doc = self.documents.get(document_id)
        if not doc or doc.admin_token != admin_token:
            return None

        # Check if already locked
        now = datetime.now(timezone.utc)
        if doc.annotation_lock_until and doc.annotation_lock_until > now:
            return None

        # Acquire lock
        doc.annotation_lock_until = now + timedelta(seconds=duration_seconds)
        return doc

    def release_annotation_lock(self, document_id, admin_token, force=False):
        """Release annotation lock for a document."""
        doc = self.documents.get(document_id)
        if not doc or doc.admin_token != admin_token:
            return None

        # Release lock
        doc.annotation_lock_until = None
        return doc

    def extend_annotation_lock(self, document_id, admin_token, additional_seconds=300):
        """Extend an existing annotation lock."""
        doc = self.documents.get(document_id)
        if not doc or doc.admin_token != admin_token:
            return None

        # Check if lock exists and is still valid
        now = datetime.now(timezone.utc)
        if not doc.annotation_lock_until or doc.annotation_lock_until <= now:
            return None

        # Extend lock
        doc.annotation_lock_until = doc.annotation_lock_until + timedelta(seconds=additional_seconds)
        return doc


@pytest.fixture
def app():
    """Create test FastAPI app."""
    app = FastAPI()

    # Create mock DB
    mock_db = MockAdminDB()

    # Add test document
    doc1 = MockAdminDocument(
        filename="INV001.pdf",
        status="pending",
        upload_source="ui",
    )

    mock_db.documents[str(doc1.document_id)] = doc1

    # Override dependencies
    app.dependency_overrides[validate_admin_token] = lambda: "test-token"
    app.dependency_overrides[get_admin_db] = lambda: mock_db

    # Include router
    router = create_admin_router((".pdf", ".png", ".jpg"))
    app.include_router(router)

    return app


@pytest.fixture
def client(app):
    """Create test client."""
    return TestClient(app)


@pytest.fixture
def document_id(app):
    """Get document ID from the mock DB."""
    mock_db = app.dependency_overrides[get_admin_db]()
    return str(list(mock_db.documents.keys())[0])


class TestAnnotationLocks:
    """Tests for annotation lock endpoints."""

    def test_acquire_lock_success(self, client, document_id):
        """Test successfully acquiring an annotation lock."""
        response = client.post(
            f"/admin/documents/{document_id}/lock",
            json={"duration_seconds": 300}
        )

        assert response.status_code == 200
        data = response.json()
        assert data["document_id"] == document_id
        assert data["locked"] is True
        assert data["lock_expires_at"] is not None
        assert "Lock acquired for 300 seconds" in data["message"]

    def test_acquire_lock_already_locked(self, client, document_id):
        """Test acquiring lock on already locked document."""
        # First lock
        response1 = client.post(
            f"/admin/documents/{document_id}/lock",
            json={"duration_seconds": 300}
        )
        assert response1.status_code == 200

        # Try to lock again
        response2 = client.post(
            f"/admin/documents/{document_id}/lock",
            json={"duration_seconds": 300}
        )
        assert response2.status_code == 409
        assert "already locked" in response2.json()["detail"]

    def test_release_lock_success(self, client, document_id):
        """Test successfully releasing an annotation lock."""
        # First acquire lock
        client.post(
            f"/admin/documents/{document_id}/lock",
            json={"duration_seconds": 300}
        )

        # Then release it
        response = client.delete(f"/admin/documents/{document_id}/lock")

        assert response.status_code == 200
        data = response.json()
        assert data["document_id"] == document_id
        assert data["locked"] is False
        assert data["lock_expires_at"] is None
        assert "released successfully" in data["message"]

    def test_release_lock_not_locked(self, client, document_id):
        """Test releasing lock on unlocked document."""
        response = client.delete(f"/admin/documents/{document_id}/lock")

        # Should succeed even if not locked
        assert response.status_code == 200
        data = response.json()
        assert data["locked"] is False

    def test_extend_lock_success(self, client, document_id):
        """Test successfully extending an annotation lock."""
        # First acquire lock
        response1 = client.post(
            f"/admin/documents/{document_id}/lock",
            json={"duration_seconds": 300}
        )
        original_expiry = response1.json()["lock_expires_at"]

        # Extend lock
        response2 = client.patch(
            f"/admin/documents/{document_id}/lock",
            json={"duration_seconds": 300}
        )

        assert response2.status_code == 200
        data = response2.json()
        assert data["document_id"] == document_id
        assert data["locked"] is True
        assert data["lock_expires_at"] != original_expiry
        assert "extended by 300 seconds" in data["message"]

    def test_extend_lock_not_locked(self, client, document_id):
        """Test extending lock on unlocked document."""
        response = client.patch(
            f"/admin/documents/{document_id}/lock",
            json={"duration_seconds": 300}
        )

        assert response.status_code == 409
        assert "doesn't exist or has expired" in response.json()["detail"]

    def test_acquire_lock_custom_duration(self, client, document_id):
        """Test acquiring lock with custom duration."""
        response = client.post(
            f"/admin/documents/{document_id}/lock",
            json={"duration_seconds": 600}
        )

        assert response.status_code == 200
        data = response.json()
        assert "Lock acquired for 600 seconds" in data["message"]

    def test_acquire_lock_invalid_document(self, client):
        """Test acquiring lock on non-existent document."""
        fake_id = str(uuid4())
        response = client.post(
            f"/admin/documents/{fake_id}/lock",
            json={"duration_seconds": 300}
        )

        assert response.status_code == 404
        assert "not found" in response.json()["detail"]

    def test_lock_lifecycle(self, client, document_id):
        """Test complete lock lifecycle: acquire -> extend -> release."""
        # Acquire
        response1 = client.post(
            f"/admin/documents/{document_id}/lock",
            json={"duration_seconds": 300}
        )
        assert response1.status_code == 200
        assert response1.json()["locked"] is True

        # Extend
        response2 = client.patch(
            f"/admin/documents/{document_id}/lock",
            json={"duration_seconds": 300}
        )
        assert response2.status_code == 200
        assert response2.json()["locked"] is True

        # Release
        response3 = client.delete(f"/admin/documents/{document_id}/lock")
        assert response3.status_code == 200
        assert response3.json()["locked"] is False

        # Verify can acquire again after release
        response4 = client.post(
            f"/admin/documents/{document_id}/lock",
            json={"duration_seconds": 300}
        )
        assert response4.status_code == 200
        assert response4.json()["locked"] is True
420
tests/web/test_annotation_phase5.py
Normal file
@@ -0,0 +1,420 @@
"""
Tests for Phase 5: Annotation Enhancement (Verification and Override)
"""

import pytest
from datetime import datetime
from uuid import uuid4

from fastapi import FastAPI
from fastapi.testclient import TestClient

from src.web.api.v1.admin.annotations import create_annotation_router
from src.web.core.auth import validate_admin_token, get_admin_db


class MockAdminDocument:
    """Mock AdminDocument for testing."""

    def __init__(self, **kwargs):
        self.document_id = kwargs.get('document_id', uuid4())
        self.admin_token = kwargs.get('admin_token', 'test-token')
        self.filename = kwargs.get('filename', 'test.pdf')
        self.file_size = kwargs.get('file_size', 100000)
        self.content_type = kwargs.get('content_type', 'application/pdf')
        self.page_count = kwargs.get('page_count', 1)
        self.status = kwargs.get('status', 'labeled')
        self.auto_label_status = kwargs.get('auto_label_status', None)
        self.created_at = kwargs.get('created_at', datetime.utcnow())
        self.updated_at = kwargs.get('updated_at', datetime.utcnow())


class MockAnnotation:
    """Mock AdminAnnotation for testing."""

    def __init__(self, **kwargs):
        self.annotation_id = kwargs.get('annotation_id', uuid4())
        self.document_id = kwargs.get('document_id')
        self.page_number = kwargs.get('page_number', 1)
        self.class_id = kwargs.get('class_id', 0)
        self.class_name = kwargs.get('class_name', 'invoice_number')
        self.bbox_x = kwargs.get('bbox_x', 100)
        self.bbox_y = kwargs.get('bbox_y', 100)
        self.bbox_width = kwargs.get('bbox_width', 200)
        self.bbox_height = kwargs.get('bbox_height', 50)
        self.x_center = kwargs.get('x_center', 0.5)
        self.y_center = kwargs.get('y_center', 0.5)
        self.width = kwargs.get('width', 0.3)
        self.height = kwargs.get('height', 0.1)
        self.text_value = kwargs.get('text_value', 'INV-001')
        self.confidence = kwargs.get('confidence', 0.95)
        self.source = kwargs.get('source', 'auto')
        self.is_verified = kwargs.get('is_verified', False)
        self.verified_at = kwargs.get('verified_at', None)
        self.verified_by = kwargs.get('verified_by', None)
        self.override_source = kwargs.get('override_source', None)
        self.original_annotation_id = kwargs.get('original_annotation_id', None)
        self.created_at = kwargs.get('created_at', datetime.utcnow())
        self.updated_at = kwargs.get('updated_at', datetime.utcnow())


class MockAnnotationHistory:
    """Mock AnnotationHistory for testing."""

    def __init__(self, **kwargs):
        self.history_id = kwargs.get('history_id', uuid4())
        self.annotation_id = kwargs.get('annotation_id')
        self.document_id = kwargs.get('document_id')
        self.action = kwargs.get('action', 'override')
        self.previous_value = kwargs.get('previous_value', {})
        self.new_value = kwargs.get('new_value', {})
        self.changed_by = kwargs.get('changed_by', 'test-token')
        self.change_reason = kwargs.get('change_reason', None)
        self.created_at = kwargs.get('created_at', datetime.utcnow())


class MockAdminDB:
    """Mock AdminDB for testing Phase 5."""

    def __init__(self):
        self.documents = {}
        self.annotations = {}
        self.annotation_history = {}

    def get_document_by_token(self, document_id, admin_token):
        """Get document by ID and token."""
        doc = self.documents.get(str(document_id))
        if doc and doc.admin_token == admin_token:
            return doc
        return None

    def verify_annotation(self, annotation_id, admin_token):
        """Mark annotation as verified."""
        annotation = self.annotations.get(str(annotation_id))
        if annotation:
            annotation.is_verified = True
            annotation.verified_at = datetime.utcnow()
            annotation.verified_by = admin_token
            return annotation
        return None

    def override_annotation(
        self,
        annotation_id,
        admin_token,
        change_reason=None,
        **updates,
    ):
        """Override an annotation."""
        annotation = self.annotations.get(str(annotation_id))
        if annotation:
            # Apply updates
            for key, value in updates.items():
                if hasattr(annotation, key):
                    setattr(annotation, key, value)

            # Mark as overridden if was auto-generated
            if annotation.source == "auto":
                annotation.override_source = "auto"
                annotation.source = "manual"

            # Create history record, keyed by the annotation's own ID
            history = MockAnnotationHistory(
                annotation_id=annotation.annotation_id,
                document_id=annotation.document_id,
                action="override",
                changed_by=admin_token,
                change_reason=change_reason,
            )
            self.annotation_history[str(annotation.annotation_id)] = [history]

            return annotation
        return None

    def get_annotation_history(self, annotation_id):
        """Get annotation history."""
        return self.annotation_history.get(str(annotation_id), [])


@pytest.fixture
def app():
    """Create test FastAPI app."""
    app = FastAPI()

    # Create mock DB
    mock_db = MockAdminDB()

    # Add test document
    doc1 = MockAdminDocument(
        filename="TEST001.pdf",
        status="labeled",
    )
    mock_db.documents[str(doc1.document_id)] = doc1

    # Add test annotations
    ann1 = MockAnnotation(
        document_id=doc1.document_id,
        class_id=0,
        class_name="invoice_number",
        text_value="INV-001",
        source="auto",
        confidence=0.95,
    )
    ann2 = MockAnnotation(
        document_id=doc1.document_id,
        class_id=6,
        class_name="amount",
        text_value="1500.00",
        source="auto",
        confidence=0.98,
    )

    mock_db.annotations[str(ann1.annotation_id)] = ann1
    mock_db.annotations[str(ann2.annotation_id)] = ann2

    # Store document ID and annotation IDs for tests
    app.state.document_id = str(doc1.document_id)
    app.state.annotation_id_1 = str(ann1.annotation_id)
    app.state.annotation_id_2 = str(ann2.annotation_id)

    # Override dependencies
    app.dependency_overrides[validate_admin_token] = lambda: "test-token"
    app.dependency_overrides[get_admin_db] = lambda: mock_db

    # Include router
    router = create_annotation_router()
    app.include_router(router)

    return app


@pytest.fixture
def client(app):
    """Create test client."""
    return TestClient(app)


class TestAnnotationVerification:
    """Tests for POST /admin/documents/{document_id}/annotations/{annotation_id}/verify endpoint."""

    def test_verify_annotation_success(self, client, app):
        """Test successfully verifying an annotation."""
        document_id = app.state.document_id
        annotation_id = app.state.annotation_id_1

        response = client.post(
            f"/admin/documents/{document_id}/annotations/{annotation_id}/verify"
        )

        assert response.status_code == 200
        data = response.json()
        assert data["annotation_id"] == annotation_id
        assert data["is_verified"] is True
        assert data["verified_at"] is not None
        assert data["verified_by"] == "test-token"
        assert "verified successfully" in data["message"].lower()

    def test_verify_annotation_not_found(self, client, app):
        """Test verifying non-existent annotation."""
        document_id = app.state.document_id
        fake_annotation_id = str(uuid4())

        response = client.post(
            f"/admin/documents/{document_id}/annotations/{fake_annotation_id}/verify"
        )

        assert response.status_code == 404
        assert "not found" in response.json()["detail"].lower()

    def test_verify_annotation_document_not_found(self, client):
        """Test verifying annotation with non-existent document."""
        fake_document_id = str(uuid4())
        fake_annotation_id = str(uuid4())

        response = client.post(
            f"/admin/documents/{fake_document_id}/annotations/{fake_annotation_id}/verify"
        )

        assert response.status_code == 404
        assert "not found" in response.json()["detail"].lower()

    def test_verify_annotation_invalid_uuid(self, client, app):
        """Test verifying annotation with invalid UUID format."""
        document_id = app.state.document_id

        response = client.post(
            f"/admin/documents/{document_id}/annotations/invalid-uuid/verify"
        )

        assert response.status_code == 400
        assert "invalid" in response.json()["detail"].lower()


class TestAnnotationOverride:
    """Tests for PATCH /admin/documents/{document_id}/annotations/{annotation_id}/override endpoint."""

    def test_override_annotation_text_value(self, client, app):
        """Test overriding annotation text value."""
        document_id = app.state.document_id
        annotation_id = app.state.annotation_id_1

        response = client.patch(
            f"/admin/documents/{document_id}/annotations/{annotation_id}/override",
            json={
                "text_value": "INV-001-CORRECTED",
                "reason": "OCR error correction"
            }
        )

        assert response.status_code == 200
        data = response.json()
        assert data["annotation_id"] == annotation_id
        assert data["source"] == "manual"
        assert data["override_source"] == "auto"
        assert "successfully" in data["message"].lower()
        assert "history_id" in data

    def test_override_annotation_bbox(self, client, app):
        """Test overriding annotation bounding box."""
        document_id = app.state.document_id
        annotation_id = app.state.annotation_id_1

        response = client.patch(
            f"/admin/documents/{document_id}/annotations/{annotation_id}/override",
            json={
                "bbox": {
                    "x": 110,
                    "y": 205,
                    "width": 195,
                    "height": 48
                },
                "reason": "Bbox adjustment"
            }
        )

        assert response.status_code == 200
        data = response.json()
        assert data["annotation_id"] == annotation_id
        assert data["source"] == "manual"

    def test_override_annotation_class(self, client, app):
        """Test overriding annotation class."""
        document_id = app.state.document_id
        annotation_id = app.state.annotation_id_1

        response = client.patch(
            f"/admin/documents/{document_id}/annotations/{annotation_id}/override",
            json={
                "class_id": 1,
                "class_name": "invoice_date",
                "reason": "Wrong field classification"
            }
        )

        assert response.status_code == 200
        data = response.json()
        assert data["annotation_id"] == annotation_id

    def test_override_annotation_multiple_fields(self, client, app):
        """Test overriding multiple annotation fields at once."""
        document_id = app.state.document_id
        annotation_id = app.state.annotation_id_2

        response = client.patch(
            f"/admin/documents/{document_id}/annotations/{annotation_id}/override",
            json={
                "text_value": "1550.00",
                "bbox": {
                    "x": 120,
                    "y": 210,
                    "width": 180,
                    "height": 45
                },
                "reason": "Multiple corrections"
            }
        )

        assert response.status_code == 200
        data = response.json()
        assert data["annotation_id"] == annotation_id

    def test_override_annotation_no_updates(self, client, app):
        """Test overriding annotation without providing any updates."""
        document_id = app.state.document_id
        annotation_id = app.state.annotation_id_1

        response = client.patch(
            f"/admin/documents/{document_id}/annotations/{annotation_id}/override",
            json={}
        )

        assert response.status_code == 400
        assert "no updates" in response.json()["detail"].lower()

    def test_override_annotation_not_found(self, client, app):
        """Test overriding non-existent annotation."""
        document_id = app.state.document_id
        fake_annotation_id = str(uuid4())

        response = client.patch(
            f"/admin/documents/{document_id}/annotations/{fake_annotation_id}/override",
            json={
                "text_value": "TEST"
            }
        )

        assert response.status_code == 404
        assert "not found" in response.json()["detail"].lower()

    def test_override_annotation_document_not_found(self, client):
        """Test overriding annotation with non-existent document."""
        fake_document_id = str(uuid4())
        fake_annotation_id = str(uuid4())

        response = client.patch(
            f"/admin/documents/{fake_document_id}/annotations/{fake_annotation_id}/override",
            json={
                "text_value": "TEST"
            }
        )

        assert response.status_code == 404
        assert "not found" in response.json()["detail"].lower()

    def test_override_annotation_creates_history(self, client, app):
        """Test that overriding annotation creates history record."""
        document_id = app.state.document_id
        annotation_id = app.state.annotation_id_1

        response = client.patch(
            f"/admin/documents/{document_id}/annotations/{annotation_id}/override",
            json={
                "text_value": "INV-CORRECTED",
                "reason": "Test history creation"
            }
        )

        assert response.status_code == 200
        data = response.json()
        # History ID should be present and valid
        assert "history_id" in data
        assert data["history_id"] != ""

    def test_override_annotation_with_reason(self, client, app):
        """Test overriding annotation with change reason."""
        document_id = app.state.document_id
        annotation_id = app.state.annotation_id_1

        change_reason = "Correcting OCR misread"
        response = client.patch(
            f"/admin/documents/{document_id}/annotations/{annotation_id}/override",
            json={
                "text_value": "INV-002",
                "reason": change_reason
            }
        )

        assert response.status_code == 200
        # Reason is stored in history, not returned in response
        data = response.json()
        assert data["annotation_id"] == annotation_id
217
tests/web/test_async_queue.py
Normal file
@@ -0,0 +1,217 @@
"""
Tests for the AsyncTaskQueue class.
"""

import tempfile
import time
from datetime import datetime
from pathlib import Path
from threading import Event
from unittest.mock import MagicMock

import pytest

from src.web.workers.async_queue import AsyncTask, AsyncTaskQueue


class TestAsyncTask:
    """Tests for AsyncTask dataclass."""

    def test_create_task(self):
        """Test creating an AsyncTask."""
        task = AsyncTask(
            request_id="test-id",
            api_key="test-key",
            file_path=Path("/tmp/test.pdf"),
            filename="test.pdf",
        )

        assert task.request_id == "test-id"
        assert task.api_key == "test-key"
        assert task.filename == "test.pdf"
        assert task.priority == 0
        assert task.created_at is not None


class TestAsyncTaskQueue:
    """Tests for AsyncTaskQueue."""

    def test_init(self):
        """Test queue initialization."""
        queue = AsyncTaskQueue(max_size=50, worker_count=2)

        assert queue._worker_count == 2
        assert queue._queue.maxsize == 50
        assert not queue._started

    def test_submit_task(self, task_queue, sample_task):
        """Test submitting a task to the queue."""
        success = task_queue.submit(sample_task)

        assert success is True
        assert task_queue.get_queue_depth() == 1

    def test_submit_when_full(self, sample_task):
        """Test submitting to a full queue."""
        queue = AsyncTaskQueue(max_size=1, worker_count=1)

        # Submit first task
        queue.submit(sample_task)

        # Create second task
        task2 = AsyncTask(
            request_id="test-2",
            api_key="test-key",
            file_path=sample_task.file_path,
            filename="test2.pdf",
        )

        # Queue should be full
        success = queue.submit(task2)
        assert success is False

    def test_get_queue_depth(self, task_queue, sample_task):
        """Test getting queue depth."""
        assert task_queue.get_queue_depth() == 0

        task_queue.submit(sample_task)
        assert task_queue.get_queue_depth() == 1

    def test_start_and_stop(self, task_queue):
        """Test starting and stopping the queue."""
        handler = MagicMock()

        task_queue.start(handler)
        assert task_queue._started is True
        assert task_queue.is_running is True
        assert len(task_queue._workers) == 1

        task_queue.stop(timeout=5.0)
        assert task_queue._started is False
        assert task_queue.is_running is False
        assert len(task_queue._workers) == 0

    def test_worker_processes_task(self, sample_task):
        """Test that worker thread processes tasks."""
        queue = AsyncTaskQueue(max_size=10, worker_count=1)
        processed = Event()

        def handler(task):
            processed.set()

        queue.start(handler)
        queue.submit(sample_task)

        # Wait for processing
        assert processed.wait(timeout=5.0)

        queue.stop()

    def test_worker_handles_errors(self, sample_task):
        """Test that worker handles errors gracefully."""
        queue = AsyncTaskQueue(max_size=10, worker_count=1)
        error_handled = Event()

        def failing_handler(task):
            error_handled.set()
            raise ValueError("Test error")

        queue.start(failing_handler)
        queue.submit(sample_task)

        # Should not crash
        assert error_handled.wait(timeout=5.0)
        time.sleep(0.5)  # Give time for error handling

        assert queue.is_running

        queue.stop()

    def test_processing_tracking(self, task_queue, sample_task):
        """Test tracking of processing tasks."""
        processed = Event()

        def slow_handler(task):
            processed.set()
            time.sleep(0.5)

        task_queue.start(slow_handler)
        task_queue.submit(sample_task)

        # Wait for processing to start
        assert processed.wait(timeout=5.0)

        # Task should be in processing set
        assert task_queue.get_processing_count() == 1
        assert task_queue.is_processing(sample_task.request_id)

        # Wait for completion
        time.sleep(1.0)

        assert task_queue.get_processing_count() == 0
        assert not task_queue.is_processing(sample_task.request_id)

        task_queue.stop()

    def test_multiple_workers(self, sample_task):
        """Test queue with multiple workers."""
        queue = AsyncTaskQueue(max_size=10, worker_count=3)
        processed_count = []

        def handler(task):
            processed_count.append(task.request_id)
            time.sleep(0.1)

        queue.start(handler)

        # Submit multiple tasks
        for i in range(5):
            task = AsyncTask(
                request_id=f"task-{i}",
                api_key="test-key",
                file_path=sample_task.file_path,
                filename=f"test-{i}.pdf",
            )
            queue.submit(task)

        # Wait for all tasks
        time.sleep(2.0)

        assert len(processed_count) == 5

        queue.stop()

    def test_graceful_shutdown(self, sample_task):
        """Test graceful shutdown waits for current task."""
        queue = AsyncTaskQueue(max_size=10, worker_count=1)
        started = Event()
        finished = Event()

        def slow_handler(task):
            started.set()
            time.sleep(0.5)
            finished.set()

        queue.start(slow_handler)
        queue.submit(sample_task)

        # Wait for processing to start
        assert started.wait(timeout=5.0)

        # Stop should wait for task to finish
        queue.stop(timeout=5.0)

        assert finished.is_set()

    def test_double_start(self, task_queue):
        """Test that starting twice doesn't create duplicate workers."""
        handler = MagicMock()

        task_queue.start(handler)
        assert len(task_queue._workers) == 1

        # Starting again should not add more workers
        task_queue.start(handler)
        assert len(task_queue._workers) == 1

        task_queue.stop()
409
tests/web/test_async_routes.py
Normal file
@@ -0,0 +1,409 @@
"""
Tests for the async API routes.
"""

import tempfile
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

from src.data.async_request_db import ApiKeyConfig, AsyncRequest, AsyncRequestDB
from src.web.api.v1.async_api.routes import create_async_router, set_async_service
from src.web.services.async_processing import AsyncSubmitResult
from src.web.dependencies import init_dependencies
from src.web.rate_limiter import RateLimiter, RateLimitStatus
from src.web.schemas.inference import AsyncStatus

# Valid UUID for testing
TEST_REQUEST_UUID = "550e8400-e29b-41d4-a716-446655440000"
INVALID_UUID = "nonexistent-id"


@pytest.fixture
def mock_async_service():
    """Create a mock AsyncProcessingService."""
    service = MagicMock()

    # Mock config
    mock_config = MagicMock()
    mock_config.max_file_size_mb = 50
    service._async_config = mock_config

    # Default submit result
    service.submit_request.return_value = AsyncSubmitResult(
        success=True,
        request_id="test-request-id",
        estimated_wait_seconds=30,
    )

    return service


@pytest.fixture
def mock_rate_limiter(mock_db):
    """Create a mock RateLimiter."""
    limiter = MagicMock(spec=RateLimiter)

    # Default: allow all requests
    limiter.check_submit_limit.return_value = RateLimitStatus(
        allowed=True,
        remaining_requests=9,
        reset_at=datetime.utcnow() + timedelta(seconds=60),
    )
    limiter.check_poll_limit.return_value = RateLimitStatus(
        allowed=True,
        remaining_requests=999,
        reset_at=datetime.utcnow(),
    )
    limiter.get_rate_limit_headers.return_value = {}

    return limiter


@pytest.fixture
def app(mock_db, mock_rate_limiter, mock_async_service):
    """Create a test FastAPI app with async routes."""
    app = FastAPI()

    # Initialize dependencies
    init_dependencies(mock_db, mock_rate_limiter)
    set_async_service(mock_async_service)

    # Add routes
    router = create_async_router(allowed_extensions=(".pdf", ".png", ".jpg", ".jpeg"))
    app.include_router(router, prefix="/api/v1")

    return app


@pytest.fixture
def client(app):
    """Create a test client."""
    return TestClient(app)


class TestAsyncSubmitEndpoint:
    """Tests for POST /api/v1/async/submit."""

    def test_submit_success(self, client, mock_async_service):
        """Test successful submission."""
        with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
            f.write(b"fake pdf content")
            f.seek(0)

            response = client.post(
                "/api/v1/async/submit",
                files={"file": ("test.pdf", f, "application/pdf")},
                headers={"X-API-Key": "test-api-key"},
            )

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "accepted"
        assert data["request_id"] == "test-request-id"
        assert "poll_url" in data

    def test_submit_missing_api_key(self, client):
        """Test submission without API key."""
        with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
            f.write(b"fake pdf content")
            f.seek(0)

            response = client.post(
                "/api/v1/async/submit",
                files={"file": ("test.pdf", f, "application/pdf")},
            )

        assert response.status_code == 401
        assert "X-API-Key" in response.json()["detail"]

    def test_submit_invalid_api_key(self, client, mock_db):
        """Test submission with invalid API key."""
        mock_db.is_valid_api_key.return_value = False

        with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
            f.write(b"fake pdf content")
            f.seek(0)

            response = client.post(
                "/api/v1/async/submit",
                files={"file": ("test.pdf", f, "application/pdf")},
                headers={"X-API-Key": "invalid-key"},
            )

        assert response.status_code == 401

    def test_submit_unsupported_file_type(self, client):
        """Test submission with unsupported file type."""
        with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as f:
            f.write(b"text content")
            f.seek(0)

            response = client.post(
                "/api/v1/async/submit",
                files={"file": ("test.txt", f, "text/plain")},
                headers={"X-API-Key": "test-api-key"},
            )

        assert response.status_code == 400
        assert "Unsupported file type" in response.json()["detail"]

    def test_submit_rate_limited(self, client, mock_rate_limiter):
        """Test submission when rate limited."""
        mock_rate_limiter.check_submit_limit.return_value = RateLimitStatus(
            allowed=False,
            remaining_requests=0,
            reset_at=datetime.utcnow() + timedelta(seconds=30),
            retry_after_seconds=30,
            reason="Rate limit exceeded",
        )
        mock_rate_limiter.get_rate_limit_headers.return_value = {"Retry-After": "30"}

        with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
            f.write(b"fake pdf content")
            f.seek(0)

            response = client.post(
                "/api/v1/async/submit",
                files={"file": ("test.pdf", f, "application/pdf")},
                headers={"X-API-Key": "test-api-key"},
            )

        assert response.status_code == 429
        assert "Retry-After" in response.headers

    def test_submit_queue_full(self, client, mock_async_service):
        """Test submission when queue is full."""
        mock_async_service.submit_request.return_value = AsyncSubmitResult(
            success=False,
            request_id="test-id",
            error="Processing queue is full",
        )

        with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
            f.write(b"fake pdf content")
            f.seek(0)

            response = client.post(
                "/api/v1/async/submit",
                files={"file": ("test.pdf", f, "application/pdf")},
                headers={"X-API-Key": "test-api-key"},
            )

        assert response.status_code == 503


class TestAsyncStatusEndpoint:
    """Tests for GET /api/v1/async/status/{request_id}."""

    def test_get_status_pending(self, client, mock_db, sample_async_request):
        """Test getting status of pending request."""
        mock_db.get_request_by_api_key.return_value = sample_async_request
        mock_db.get_queue_position.return_value = 3

        response = client.get(
            "/api/v1/async/status/550e8400-e29b-41d4-a716-446655440000",
            headers={"X-API-Key": "test-api-key"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "pending"
        assert data["position_in_queue"] == 3
        assert data["result_url"] is None

    def test_get_status_completed(self, client, mock_db, sample_async_request):
        """Test getting status of completed request."""
        sample_async_request.status = "completed"
        sample_async_request.completed_at = datetime.utcnow()
        mock_db.get_request_by_api_key.return_value = sample_async_request

        response = client.get(
            "/api/v1/async/status/550e8400-e29b-41d4-a716-446655440000",
            headers={"X-API-Key": "test-api-key"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "completed"
        assert data["result_url"] is not None

    def test_get_status_not_found(self, client, mock_db):
        """Test getting status of non-existent request."""
        mock_db.get_request_by_api_key.return_value = None

        response = client.get(
            "/api/v1/async/status/00000000-0000-0000-0000-000000000000",
            headers={"X-API-Key": "test-api-key"},
        )

        assert response.status_code == 404

    def test_get_status_wrong_api_key(self, client, mock_db, sample_async_request):
        """Test that requests are isolated by API key."""
        # Request belongs to different API key
        mock_db.get_request_by_api_key.return_value = None

        response = client.get(
            "/api/v1/async/status/550e8400-e29b-41d4-a716-446655440000",
            headers={"X-API-Key": "different-api-key"},
        )

        assert response.status_code == 404


class TestAsyncResultEndpoint:
    """Tests for GET /api/v1/async/result/{request_id}."""

    def test_get_result_completed(self, client, mock_db, sample_async_request):
        """Test getting result of completed request."""
        sample_async_request.status = "completed"
        sample_async_request.completed_at = datetime.utcnow()
        sample_async_request.processing_time_ms = 1234.5
        sample_async_request.result = {
            "document_id": "test-doc",
            "success": True,
            "document_type": "invoice",
            "fields": {"InvoiceNumber": "12345"},
            "confidence": {"InvoiceNumber": 0.95},
            "detections": [],
            "errors": [],
        }
        mock_db.get_request_by_api_key.return_value = sample_async_request

        response = client.get(
            "/api/v1/async/result/550e8400-e29b-41d4-a716-446655440000",
            headers={"X-API-Key": "test-api-key"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "completed"
        assert data["result"] is not None
        assert data["result"]["fields"]["InvoiceNumber"] == "12345"

    def test_get_result_not_completed(self, client, mock_db, sample_async_request):
        """Test getting result of pending request."""
        mock_db.get_request_by_api_key.return_value = sample_async_request

        response = client.get(
            "/api/v1/async/result/550e8400-e29b-41d4-a716-446655440000",
            headers={"X-API-Key": "test-api-key"},
        )

        assert response.status_code == 409
        assert "not yet completed" in response.json()["detail"]

    def test_get_result_failed(self, client, mock_db, sample_async_request):
        """Test getting result of failed request."""
        sample_async_request.status = "failed"
        sample_async_request.error_message = "Processing failed"
        sample_async_request.processing_time_ms = 500.0
        mock_db.get_request_by_api_key.return_value = sample_async_request

        response = client.get(
            "/api/v1/async/result/550e8400-e29b-41d4-a716-446655440000",
            headers={"X-API-Key": "test-api-key"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "failed"


class TestAsyncListEndpoint:
    """Tests for GET /api/v1/async/requests."""

    def test_list_requests(self, client, mock_db, sample_async_request):
        """Test listing requests."""
        mock_db.get_requests_by_api_key.return_value = ([sample_async_request], 1)

        response = client.get(
            "/api/v1/async/requests",
            headers={"X-API-Key": "test-api-key"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1
        assert len(data["requests"]) == 1

    def test_list_requests_with_status_filter(self, client, mock_db):
        """Test listing requests with status filter."""
        mock_db.get_requests_by_api_key.return_value = ([], 0)

        response = client.get(
            "/api/v1/async/requests?status=completed",
            headers={"X-API-Key": "test-api-key"},
        )

        assert response.status_code == 200
        mock_db.get_requests_by_api_key.assert_called_once()
        call_kwargs = mock_db.get_requests_by_api_key.call_args[1]
        assert call_kwargs["status"] == "completed"

    def test_list_requests_pagination(self, client, mock_db):
        """Test listing requests with pagination."""
        mock_db.get_requests_by_api_key.return_value = ([], 0)

        response = client.get(
            "/api/v1/async/requests?limit=50&offset=10",
            headers={"X-API-Key": "test-api-key"},
        )

        assert response.status_code == 200
        call_kwargs = mock_db.get_requests_by_api_key.call_args[1]
        assert call_kwargs["limit"] == 50
        assert call_kwargs["offset"] == 10

    def test_list_requests_invalid_status(self, client, mock_db):
        """Test listing with invalid status filter."""
        response = client.get(
            "/api/v1/async/requests?status=invalid",
            headers={"X-API-Key": "test-api-key"},
        )

        assert response.status_code == 400


class TestAsyncDeleteEndpoint:
    """Tests for DELETE /api/v1/async/requests/{request_id}."""

    def test_delete_pending_request(self, client, mock_db, sample_async_request):
        """Test deleting a pending request."""
        mock_db.get_request_by_api_key.return_value = sample_async_request

        response = client.delete(
            "/api/v1/async/requests/550e8400-e29b-41d4-a716-446655440000",
            headers={"X-API-Key": "test-api-key"},
        )

        assert response.status_code == 200
        assert response.json()["status"] == "deleted"

    def test_delete_processing_request(self, client, mock_db, sample_async_request):
        """Test that processing requests cannot be deleted."""
        sample_async_request.status = "processing"
        mock_db.get_request_by_api_key.return_value = sample_async_request

        response = client.delete(
            "/api/v1/async/requests/550e8400-e29b-41d4-a716-446655440000",
            headers={"X-API-Key": "test-api-key"},
        )

        assert response.status_code == 409

    def test_delete_not_found(self, client, mock_db):
        """Test deleting non-existent request."""
        mock_db.get_request_by_api_key.return_value = None

        response = client.delete(
            "/api/v1/async/requests/00000000-0000-0000-0000-000000000000",
            headers={"X-API-Key": "test-api-key"},
        )

        assert response.status_code == 404
266
tests/web/test_async_service.py
Normal file
@@ -0,0 +1,266 @@
|
||||
"""
|
||||
Tests for the AsyncProcessingService class.
|
||||
"""
|
||||
|
||||
import tempfile
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.data.async_request_db import AsyncRequest
|
||||
from src.web.workers.async_queue import AsyncTask, AsyncTaskQueue
|
||||
from src.web.services.async_processing import AsyncProcessingService, AsyncSubmitResult
|
||||
from src.web.config import AsyncConfig, StorageConfig
|
||||
from src.web.rate_limiter import RateLimiter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def async_service(mock_db, mock_inference_service, rate_limiter, storage_config):
|
||||
"""Create an AsyncProcessingService for testing."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
async_config = AsyncConfig(
|
||||
queue_max_size=10,
|
||||
worker_count=1,
|
||||
task_timeout_seconds=30,
|
||||
result_retention_days=7,
|
||||
temp_upload_dir=Path(tmpdir) / "async",
|
||||
max_file_size_mb=10,
|
||||
)
|
||||
|
||||
queue = AsyncTaskQueue(max_size=10, worker_count=1)
|
||||
|
||||
service = AsyncProcessingService(
|
||||
inference_service=mock_inference_service,
|
||||
db=mock_db,
|
||||
queue=queue,
|
||||
rate_limiter=rate_limiter,
|
||||
async_config=async_config,
|
||||
storage_config=storage_config,
|
||||
)
|
||||
|
||||
yield service
|
||||
|
||||
# Cleanup
|
||||
if service._queue._started:
|
||||
service.stop()
|
||||
|
||||
|
||||
class TestAsyncProcessingService:
|
||||
"""Tests for AsyncProcessingService."""
|
||||
|
||||
def test_submit_request_success(self, async_service, mock_db):
|
||||
"""Test successful request submission."""
|
||||
mock_db.create_request.return_value = "test-request-id"
|
||||
|
||||
result = async_service.submit_request(
|
||||
api_key="test-api-key",
|
||||
file_content=b"fake pdf content",
|
||||
filename="test.pdf",
|
||||
content_type="application/pdf",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.request_id is not None
|
||||
assert result.estimated_wait_seconds >= 0
|
||||
assert result.error is None
|
||||
|
||||
def test_submit_request_creates_db_record(self, async_service, mock_db):
|
||||
"""Test that submission creates database record."""
|
||||
async_service.submit_request(
|
||||
api_key="test-api-key",
|
||||
file_content=b"fake pdf content",
|
||||
filename="test.pdf",
|
||||
content_type="application/pdf",
|
||||
)
|
||||
|
||||
mock_db.create_request.assert_called_once()
|
||||
call_kwargs = mock_db.create_request.call_args[1]
|
||||
assert call_kwargs["api_key"] == "test-api-key"
|
||||
assert call_kwargs["filename"] == "test.pdf"
|
||||
assert call_kwargs["content_type"] == "application/pdf"
|
||||
|
||||
def test_submit_request_saves_file(self, async_service, mock_db):
|
||||
"""Test that submission saves file to temp directory."""
|
||||
content = b"fake pdf content"
|
||||
|
||||
result = async_service.submit_request(
|
||||
api_key="test-api-key",
|
||||
file_content=content,
|
||||
filename="test.pdf",
|
||||
content_type="application/pdf",
|
||||
)
|
||||
|
||||
# File should exist in temp directory
|
||||
temp_dir = async_service._async_config.temp_upload_dir
|
||||
files = list(temp_dir.iterdir())
|
||||
|
||||
# Note: file may be cleaned up quickly if queue processes it
|
||||
# So we just check that the operation succeeded
|
||||
assert result.success is True
|
||||
|
||||
def test_submit_request_records_rate_limit(self, async_service, mock_db, rate_limiter):
|
||||
"""Test that submission records rate limit event."""
|
||||
async_service.submit_request(
|
||||
api_key="test-api-key",
|
||||
file_content=b"fake pdf content",
|
||||
filename="test.pdf",
|
||||
content_type="application/pdf",
|
||||
)
|
||||
|
||||
# Rate limiter should have recorded the request
|
||||
mock_db.record_rate_limit_event.assert_called()
|
||||
|
||||
def test_start_and_stop(self, async_service):
|
||||
"""Test starting and stopping the service."""
|
||||
async_service.start()
|
||||
|
||||
assert async_service._queue._started is True
|
||||
assert async_service._cleanup_thread is not None
|
||||
assert async_service._cleanup_thread.is_alive()
|
||||
|
||||
async_service.stop()
|
||||
|
||||
assert async_service._queue._started is False
|
||||
|
||||
def test_process_task_success(self, async_service, mock_db, mock_inference_service, sample_task):
|
||||
"""Test successful task processing."""
|
||||
async_service._process_task(sample_task)
|
||||
|
||||
# Should update status to processing
|
||||
        mock_db.update_status.assert_called_with(sample_task.request_id, "processing")

        # Should complete the request
        mock_db.complete_request.assert_called_once()
        call_kwargs = mock_db.complete_request.call_args[1]
        assert call_kwargs["request_id"] == sample_task.request_id
        assert "document_id" in call_kwargs

    def test_process_task_pdf(self, async_service, mock_db, mock_inference_service, sample_task):
        """Test processing a PDF task."""
        async_service._process_task(sample_task)

        # Should call process_pdf for .pdf files
        mock_inference_service.process_pdf.assert_called_once()

    def test_process_task_image(self, async_service, mock_db, mock_inference_service):
        """Test processing an image task."""
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
            f.write(b"fake image content")
            task = AsyncTask(
                request_id="image-task",
                api_key="test-api-key",
                file_path=Path(f.name),
                filename="test.png",
            )

        async_service._process_task(task)

        # Should call process_image for image files
        mock_inference_service.process_image.assert_called_once()

    def test_process_task_failure(self, async_service, mock_db, mock_inference_service, sample_task):
        """Test task processing failure."""
        mock_inference_service.process_pdf.side_effect = Exception("Processing failed")

        async_service._process_task(sample_task)

        # Should update status to failed
        mock_db.update_status.assert_called()
        last_call = mock_db.update_status.call_args_list[-1]
        assert last_call[0][1] == "failed"  # status
        assert "Processing failed" in last_call[1]["error_message"]

    def test_process_task_file_not_found(self, async_service, mock_db):
        """Test task processing with missing file."""
        task = AsyncTask(
            request_id="missing-file-task",
            api_key="test-api-key",
            file_path=Path("/nonexistent/file.pdf"),
            filename="test.pdf",
        )

        async_service._process_task(task)

        # Should fail with file not found
        mock_db.update_status.assert_called()
        last_call = mock_db.update_status.call_args_list[-1]
        assert last_call[0][1] == "failed"

    def test_process_task_cleans_up_file(self, async_service, mock_db, mock_inference_service):
        """Test that task processing cleans up the uploaded file."""
        with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
            f.write(b"fake pdf content")
            file_path = Path(f.name)

        task = AsyncTask(
            request_id="cleanup-task",
            api_key="test-api-key",
            file_path=file_path,
            filename="test.pdf",
        )

        async_service._process_task(task)

        # File should be deleted
        assert not file_path.exists()

    def test_estimate_wait(self, async_service):
        """Test wait time estimation."""
        # Empty queue
        wait = async_service._estimate_wait()
        assert wait == 0

    def test_cleanup_orphan_files(self, async_service, mock_db):
        """Test cleanup of orphan files."""
        # Create an orphan file
        temp_dir = async_service._async_config.temp_upload_dir
        orphan_file = temp_dir / "orphan-request.pdf"
        orphan_file.write_bytes(b"orphan content")

        # Set file mtime to old
        import os
        old_time = time.time() - 7200
        os.utime(orphan_file, (old_time, old_time))

        # Mock database so no request record exists for the file
        mock_db.get_request.return_value = None

        count = async_service._cleanup_orphan_files()

        assert count == 1
        assert not orphan_file.exists()

    def test_save_upload(self, async_service):
        """Test saving uploaded file."""
        content = b"test content"

        file_path = async_service._save_upload(
            request_id="test-save",
            filename="test.pdf",
            content=content,
        )

        assert file_path.exists()
        assert file_path.read_bytes() == content
        assert file_path.suffix == ".pdf"

        # Cleanup
        file_path.unlink()

    def test_save_upload_preserves_extension(self, async_service):
        """Test that save_upload preserves file extension."""
        content = b"test content"

        # Test various extensions
        for ext in [".pdf", ".png", ".jpg", ".jpeg"]:
            file_path = async_service._save_upload(
                request_id=f"test-{ext}",
                filename=f"test{ext}",
                content=content,
            )

            assert file_path.suffix == ext
            file_path.unlink()
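

# Illustrative sketch only (an assumption, not the service code in this commit):
# the assertions above imply `_save_upload` writes the bytes into the temp upload
# dir under the request ID while keeping the original file extension. The helper
# name `_sketch_save_upload` and its parameters are hypothetical.
def _sketch_save_upload(temp_upload_dir: Path, request_id: str, filename: str, content: bytes) -> Path:
    """Write content to temp_upload_dir/<request_id><ext>; a minimal hedged sketch."""
    dest = temp_upload_dir / f"{request_id}{Path(filename).suffix}"  # keep extension
    dest.write_bytes(content)  # persist the upload for the worker to pick up
    return dest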
250
tests/web/test_autolabel_with_locks.py
Normal file
@@ -0,0 +1,250 @@
"""
Tests for Auto-Label Service with Annotation Lock Integration (Phase 3.5).
"""

import pytest
from datetime import datetime, timedelta, timezone
from pathlib import Path
from unittest.mock import Mock, MagicMock
from uuid import uuid4

from src.web.services.autolabel import AutoLabelService
from src.data.admin_db import AdminDB


class MockDocument:
    """Mock document for testing."""

    def __init__(self, document_id, annotation_lock_until=None):
        self.document_id = document_id
        self.annotation_lock_until = annotation_lock_until
        self.status = "pending"
        self.auto_label_status = None
        self.auto_label_error = None


class MockAdminDB:
    """Mock AdminDB for testing."""

    def __init__(self):
        self.documents = {}
        self.annotations = []
        self.status_updates = []

    def get_document(self, document_id):
        """Get document by ID."""
        return self.documents.get(str(document_id))

    def update_document_status(
        self,
        document_id,
        status=None,
        auto_label_status=None,
        auto_label_error=None,
    ):
        """Mock status update."""
        self.status_updates.append({
            "document_id": document_id,
            "status": status,
            "auto_label_status": auto_label_status,
            "auto_label_error": auto_label_error,
        })
        doc = self.documents.get(str(document_id))
        if doc:
            if status:
                doc.status = status
            if auto_label_status:
                doc.auto_label_status = auto_label_status
            if auto_label_error:
                doc.auto_label_error = auto_label_error

    def delete_annotations_for_document(self, document_id, source=None):
        """Mock delete annotations."""
        return 0

    def create_annotations_batch(self, annotations):
        """Mock create annotations."""
        self.annotations.extend(annotations)


@pytest.fixture
def mock_db():
    """Create mock admin DB."""
    return MockAdminDB()


@pytest.fixture
def auto_label_service(monkeypatch):
    """Create auto-label service with mocked image processing."""
    service = AutoLabelService()
    # Mock the OCR engine to avoid dependencies
    service._ocr_engine = Mock()
    service._ocr_engine.extract_from_image = Mock(return_value=[])

    # Mock the image processing methods to avoid file I/O errors
    def mock_process_image(self, document_id, image_path, field_values, db, page_number=1):
        return 0  # No annotations created (mocked)

    monkeypatch.setattr(AutoLabelService, "_process_image", mock_process_image)

    return service


class TestAutoLabelWithLocks:
    """Tests for auto-label service with lock integration."""

    def test_auto_label_unlocked_document_succeeds(self, auto_label_service, mock_db, tmp_path):
        """Test auto-labeling succeeds on unlocked document."""
        # Create test document (unlocked)
        document_id = str(uuid4())
        mock_db.documents[document_id] = MockDocument(
            document_id=document_id,
            annotation_lock_until=None,
        )

        # Create dummy file
        test_file = tmp_path / "test.png"
        test_file.write_text("dummy")

        # Attempt auto-label
        result = auto_label_service.auto_label_document(
            document_id=document_id,
            file_path=str(test_file),
            field_values={"invoice_number": "INV-001"},
            db=mock_db,
        )

        # Should succeed
        assert result["status"] == "completed"
        # Verify status was updated to running and then completed
        assert len(mock_db.status_updates) >= 2
        assert mock_db.status_updates[0]["auto_label_status"] == "running"

    def test_auto_label_locked_document_fails(self, auto_label_service, mock_db, tmp_path):
        """Test auto-labeling fails on locked document."""
        # Create test document (locked for 1 hour)
        document_id = str(uuid4())
        lock_until = datetime.now(timezone.utc) + timedelta(hours=1)
        mock_db.documents[document_id] = MockDocument(
            document_id=document_id,
            annotation_lock_until=lock_until,
        )

        # Create dummy file
        test_file = tmp_path / "test.png"
        test_file.write_text("dummy")

        # Attempt auto-label (should fail)
        result = auto_label_service.auto_label_document(
            document_id=document_id,
            file_path=str(test_file),
            field_values={"invoice_number": "INV-001"},
            db=mock_db,
        )

        # Should fail
        assert result["status"] == "failed"
        assert "locked for annotation" in result["error"]
        assert result["annotations_created"] == 0

        # Verify status was updated to failed
        assert any(
            update["auto_label_status"] == "failed"
            for update in mock_db.status_updates
        )

    def test_auto_label_expired_lock_succeeds(self, auto_label_service, mock_db, tmp_path):
        """Test auto-labeling succeeds when lock has expired."""
        # Create test document (lock expired 1 hour ago)
        document_id = str(uuid4())
        lock_until = datetime.now(timezone.utc) - timedelta(hours=1)
        mock_db.documents[document_id] = MockDocument(
            document_id=document_id,
            annotation_lock_until=lock_until,
        )

        # Create dummy file
        test_file = tmp_path / "test.png"
        test_file.write_text("dummy")

        # Attempt auto-label
        result = auto_label_service.auto_label_document(
            document_id=document_id,
            file_path=str(test_file),
            field_values={"invoice_number": "INV-001"},
            db=mock_db,
        )

        # Should succeed (lock expired)
        assert result["status"] == "completed"

    def test_auto_label_skip_lock_check(self, auto_label_service, mock_db, tmp_path):
        """Test auto-labeling with skip_lock_check=True bypasses lock."""
        # Create test document (locked)
        document_id = str(uuid4())
        lock_until = datetime.now(timezone.utc) + timedelta(hours=1)
        mock_db.documents[document_id] = MockDocument(
            document_id=document_id,
            annotation_lock_until=lock_until,
        )

        # Create dummy file
        test_file = tmp_path / "test.png"
        test_file.write_text("dummy")

        # Attempt auto-label with skip_lock_check=True
        result = auto_label_service.auto_label_document(
            document_id=document_id,
            file_path=str(test_file),
            field_values={"invoice_number": "INV-001"},
            db=mock_db,
            skip_lock_check=True,  # Bypass lock check
        )

        # Should succeed even though document is locked
        assert result["status"] == "completed"

    def test_auto_label_document_not_found(self, auto_label_service, mock_db, tmp_path):
        """Test auto-labeling fails when document doesn't exist."""
        # Create dummy file
        test_file = tmp_path / "test.png"
        test_file.write_text("dummy")

        # Attempt auto-label on non-existent document
        result = auto_label_service.auto_label_document(
            document_id=str(uuid4()),
            file_path=str(test_file),
            field_values={"invoice_number": "INV-001"},
            db=mock_db,
        )

        # Should fail
        assert result["status"] == "failed"
        assert "not found" in result["error"]

    def test_auto_label_respects_lock_by_default(self, auto_label_service, mock_db, tmp_path):
        """Test that lock check is enabled by default."""
        # Create test document (locked)
        document_id = str(uuid4())
        lock_until = datetime.now(timezone.utc) + timedelta(minutes=30)
        mock_db.documents[document_id] = MockDocument(
            document_id=document_id,
            annotation_lock_until=lock_until,
        )

        # Create dummy file
        test_file = tmp_path / "test.png"
        test_file.write_text("dummy")

        # Call without explicit skip_lock_check (defaults to False)
        result = auto_label_service.auto_label_document(
            document_id=document_id,
            file_path=str(test_file),
            field_values={"invoice_number": "INV-001"},
            db=mock_db,
            # skip_lock_check not specified, should default to False
        )

        # Should fail due to lock
        assert result["status"] == "failed"
        assert "locked" in result["error"].lower()
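

# Illustrative sketch (an assumption, not the actual AutoLabelService code): the
# lock semantics these tests exercise reduce to a small guard that treats a
# missing or past `annotation_lock_until` as unlocked and honors the
# `skip_lock_check` override. The helper name is hypothetical.
def _sketch_is_annotation_locked(doc, skip_lock_check=False):
    """Return True when auto-labeling should be refused; a hedged sketch."""
    if skip_lock_check or doc.annotation_lock_until is None:
        return False
    # A lock in the past has expired and no longer blocks auto-labeling
    return doc.annotation_lock_until > datetime.now(timezone.utc)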
282
tests/web/test_batch_queue.py
Normal file
@@ -0,0 +1,282 @@
"""
Tests for Batch Upload Queue
"""

import time
from datetime import datetime
from threading import Event
from uuid import uuid4

import pytest

from src.web.workers.batch_queue import BatchTask, BatchTaskQueue


class MockBatchService:
    """Mock batch upload service for testing."""

    def __init__(self):
        self.processed_tasks = []
        self.process_delay = 0.1  # Simulate processing time
        self.should_fail = False

    def process_zip_upload(self, admin_token, zip_filename, zip_content, upload_source):
        """Mock process_zip_upload method."""
        if self.should_fail:
            raise Exception("Simulated processing error")

        time.sleep(self.process_delay)  # Simulate work

        self.processed_tasks.append({
            "admin_token": admin_token,
            "zip_filename": zip_filename,
            "upload_source": upload_source,
        })

        return {
            "status": "completed",
            "successful_files": 1,
            "failed_files": 0,
        }


class TestBatchTask:
    """Tests for BatchTask dataclass."""

    def test_batch_task_creation(self):
        """BatchTask can be created with required fields."""
        task = BatchTask(
            batch_id=uuid4(),
            admin_token="test-token",
            zip_content=b"test",
            zip_filename="test.zip",
            upload_source="ui",
            auto_label=True,
            created_at=datetime.utcnow(),
        )

        assert task.batch_id is not None
        assert task.admin_token == "test-token"
        assert task.zip_filename == "test.zip"
        assert task.upload_source == "ui"
        assert task.auto_label is True


class TestBatchTaskQueue:
    """Tests for batch task queue functionality."""

    def test_queue_initialization(self):
        """Queue initializes with correct defaults."""
        queue = BatchTaskQueue(max_size=10, worker_count=1)

        assert queue.get_queue_depth() == 0
        assert queue.is_running is False
        assert queue._worker_count == 1

    def test_start_queue(self):
        """Queue starts with batch service."""
        queue = BatchTaskQueue(max_size=10, worker_count=1)
        service = MockBatchService()

        queue.start(service)

        assert queue.is_running is True
        assert len(queue._workers) == 1

        queue.stop()

    def test_stop_queue(self):
        """Queue stops gracefully."""
        queue = BatchTaskQueue(max_size=10, worker_count=1)
        service = MockBatchService()

        queue.start(service)
        assert queue.is_running is True

        queue.stop(timeout=5.0)

        assert queue.is_running is False
        assert len(queue._workers) == 0

    def test_submit_task_success(self):
        """Task is submitted to queue successfully."""
        queue = BatchTaskQueue(max_size=10, worker_count=1)

        task = BatchTask(
            batch_id=uuid4(),
            admin_token="test-token",
            zip_content=b"test",
            zip_filename="test.zip",
            upload_source="ui",
            auto_label=True,
            created_at=datetime.utcnow(),
        )

        result = queue.submit(task)

        assert result is True
        assert queue.get_queue_depth() == 1

    def test_submit_task_queue_full(self):
        """Returns False when queue is full."""
        queue = BatchTaskQueue(max_size=2, worker_count=1)

        # Fill the queue
        for i in range(2):
            task = BatchTask(
                batch_id=uuid4(),
                admin_token="test-token",
                zip_content=b"test",
                zip_filename=f"test{i}.zip",
                upload_source="ui",
                auto_label=True,
                created_at=datetime.utcnow(),
            )
            assert queue.submit(task) is True

        # Try to add one more (should fail)
        extra_task = BatchTask(
            batch_id=uuid4(),
            admin_token="test-token",
            zip_content=b"test",
            zip_filename="extra.zip",
            upload_source="ui",
            auto_label=True,
            created_at=datetime.utcnow(),
        )

        result = queue.submit(extra_task)

        assert result is False
        assert queue.get_queue_depth() == 2

    def test_worker_processes_task(self):
        """Worker thread processes queued tasks."""
        queue = BatchTaskQueue(max_size=10, worker_count=1)
        service = MockBatchService()

        queue.start(service)

        task = BatchTask(
            batch_id=uuid4(),
            admin_token="test-token",
            zip_content=b"test",
            zip_filename="test.zip",
            upload_source="ui",
            auto_label=True,
            created_at=datetime.utcnow(),
        )

        queue.submit(task)

        # Wait for processing
        time.sleep(0.5)

        assert len(service.processed_tasks) == 1
        assert service.processed_tasks[0]["zip_filename"] == "test.zip"

        queue.stop()

    def test_multiple_tasks_processed(self):
        """Multiple tasks are processed in order."""
        queue = BatchTaskQueue(max_size=10, worker_count=1)
        service = MockBatchService()

        queue.start(service)

        # Submit multiple tasks
        for i in range(3):
            task = BatchTask(
                batch_id=uuid4(),
                admin_token="test-token",
                zip_content=b"test",
                zip_filename=f"test{i}.zip",
                upload_source="ui",
                auto_label=True,
                created_at=datetime.utcnow(),
            )
            queue.submit(task)

        # Wait for all to process
        time.sleep(1.0)

        assert len(service.processed_tasks) == 3

        queue.stop()

    def test_get_queue_depth(self):
        """Returns correct queue depth."""
        queue = BatchTaskQueue(max_size=10, worker_count=1)

        assert queue.get_queue_depth() == 0

        # Add tasks
        for i in range(3):
            task = BatchTask(
                batch_id=uuid4(),
                admin_token="test-token",
                zip_content=b"test",
                zip_filename=f"test{i}.zip",
                upload_source="ui",
                auto_label=True,
                created_at=datetime.utcnow(),
            )
            queue.submit(task)

        assert queue.get_queue_depth() == 3

    def test_is_running_property(self):
        """is_running reflects queue state."""
        queue = BatchTaskQueue(max_size=10, worker_count=1)
        service = MockBatchService()

        assert queue.is_running is False

        queue.start(service)
        assert queue.is_running is True

        queue.stop()
        assert queue.is_running is False

    def test_double_start_ignored(self):
        """Starting queue twice is safely ignored."""
        queue = BatchTaskQueue(max_size=10, worker_count=1)
        service = MockBatchService()

        queue.start(service)
        worker_count_after_first_start = len(queue._workers)

        queue.start(service)  # Second start
        worker_count_after_second_start = len(queue._workers)

        assert worker_count_after_first_start == worker_count_after_second_start

        queue.stop()

    def test_error_handling_in_worker(self):
        """Worker handles processing errors gracefully."""
        queue = BatchTaskQueue(max_size=10, worker_count=1)
        service = MockBatchService()
        service.should_fail = True  # Cause errors

        queue.start(service)

        task = BatchTask(
            batch_id=uuid4(),
            admin_token="test-token",
            zip_content=b"test",
            zip_filename="test.zip",
            upload_source="ui",
            auto_label=True,
            created_at=datetime.utcnow(),
        )

        queue.submit(task)

        # Wait for processing attempt
        time.sleep(0.5)

        # Worker should still be running
        assert queue.is_running is True

        queue.stop()
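

# Illustrative sketch (an assumption, not the shipped BatchTaskQueue): the worker
# behavior verified above -- pull a task, delegate to the batch service, and
# survive processing errors -- reduces to a loop like this. Names here are
# hypothetical; only stdlib queue/threading APIs are used.
def _sketch_worker_loop(task_queue, service, stop_event):
    """Drain `task_queue` until `stop_event` is set; a hedged sketch."""
    import queue  # stdlib; local import keeps the sketch self-contained
    while not stop_event.is_set():
        try:
            task = task_queue.get(timeout=0.1)  # poll so stop() can interrupt
        except queue.Empty:
            continue
        try:
            service.process_zip_upload(
                admin_token=task.admin_token,
                zip_filename=task.zip_filename,
                zip_content=task.zip_content,
                upload_source=task.upload_source,
            )
        except Exception:
            # Errors must not kill the worker (see test_error_handling_in_worker)
            pass
        finally:
            task_queue.task_done()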
368
tests/web/test_batch_upload_routes.py
Normal file
@@ -0,0 +1,368 @@
"""
Tests for Batch Upload Routes
"""

import io
import zipfile
from datetime import datetime
from uuid import uuid4

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

from src.web.api.v1.batch.routes import router
from src.web.core.auth import validate_admin_token, get_admin_db
from src.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
from src.web.services.batch_upload import BatchUploadService


class MockAdminDB:
    """Mock AdminDB for testing."""

    def __init__(self):
        self.batches = {}
        self.batch_files = {}

    def create_batch_upload(self, admin_token, filename, file_size, upload_source):
        batch_id = uuid4()
        batch = type('BatchUpload', (), {
            'batch_id': batch_id,
            'admin_token': admin_token,
            'filename': filename,
            'file_size': file_size,
            'upload_source': upload_source,
            'status': 'processing',
            'total_files': 0,
            'processed_files': 0,
            'successful_files': 0,
            'failed_files': 0,
            'csv_filename': None,
            'csv_row_count': None,
            'error_message': None,
            'created_at': datetime.utcnow(),
            'completed_at': None,
        })()
        self.batches[batch_id] = batch
        return batch

    def update_batch_upload(self, batch_id, **kwargs):
        if batch_id in self.batches:
            batch = self.batches[batch_id]
            for key, value in kwargs.items():
                setattr(batch, key, value)

    def create_batch_upload_file(self, batch_id, filename, **kwargs):
        file_id = uuid4()
        defaults = {
            'file_id': file_id,
            'batch_id': batch_id,
            'filename': filename,
            'status': 'pending',
            'error_message': None,
            'annotation_count': 0,
            'csv_row_data': None,
        }
        defaults.update(kwargs)
        file_record = type('BatchUploadFile', (), defaults)()
        if batch_id not in self.batch_files:
            self.batch_files[batch_id] = []
        self.batch_files[batch_id].append(file_record)
        return file_record

    def update_batch_upload_file(self, file_id, **kwargs):
        for files in self.batch_files.values():
            for file_record in files:
                if file_record.file_id == file_id:
                    for key, value in kwargs.items():
                        setattr(file_record, key, value)
                    return

    def get_batch_upload(self, batch_id):
        return self.batches.get(batch_id, type('BatchUpload', (), {
            'batch_id': batch_id,
            'admin_token': 'test-token',
            'filename': 'test.zip',
            'status': 'completed',
            'total_files': 2,
            'processed_files': 2,
            'successful_files': 2,
            'failed_files': 0,
            'csv_filename': None,
            'csv_row_count': None,
            'error_message': None,
            'created_at': datetime.utcnow(),
            'completed_at': datetime.utcnow(),
        })())

    def get_batch_upload_files(self, batch_id):
        return self.batch_files.get(batch_id, [])

    def get_batch_uploads_by_token(self, admin_token, limit=50, offset=0):
        """Get batches filtered by admin token with pagination."""
        token_batches = [b for b in self.batches.values() if b.admin_token == admin_token]
        total = len(token_batches)
        return token_batches[offset:offset+limit], total


@pytest.fixture(scope="class")
def app():
    """Create test FastAPI app with mocked dependencies."""
    app = FastAPI()

    # Create mock admin DB
    mock_admin_db = MockAdminDB()

    # Override dependencies
    app.dependency_overrides[validate_admin_token] = lambda: "test-token"
    app.dependency_overrides[get_admin_db] = lambda: mock_admin_db

    # Initialize batch queue with mock service
    batch_service = BatchUploadService(mock_admin_db)
    init_batch_queue(batch_service)

    app.include_router(router)

    yield app

    # Cleanup: shutdown batch queue after all tests in class
    shutdown_batch_queue()


@pytest.fixture
def client(app):
    """Create test client."""
    return TestClient(app)


def create_test_zip(files):
    """Create a test ZIP file."""
    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
        for filename, content in files.items():
            zip_file.writestr(filename, content)
    zip_buffer.seek(0)
    return zip_buffer


class TestBatchUploadRoutes:
    """Tests for batch upload API routes."""

    def test_upload_batch_success(self, client):
        """Test successful batch upload (defaults to async mode)."""
        files = {
            "INV001.pdf": b"%PDF-1.4 test content",
            "INV002.pdf": b"%PDF-1.4 test content 2",
        }
        zip_file = create_test_zip(files)

        response = client.post(
            "/api/v1/admin/batch/upload",
            files={"file": ("test.zip", zip_file, "application/zip")},
            data={"upload_source": "ui"},
        )

        # Async mode is default, should return 202
        assert response.status_code == 202
        result = response.json()
        assert "batch_id" in result
        assert result["status"] == "accepted"

    def test_upload_batch_non_zip_file(self, client):
        """Test uploading non-ZIP file."""
        response = client.post(
            "/api/v1/admin/batch/upload",
            files={"file": ("test.pdf", io.BytesIO(b"test"), "application/pdf")},
            data={"upload_source": "ui"},
        )

        assert response.status_code == 400
        assert "Only ZIP files" in response.json()["detail"]

    def test_upload_batch_with_csv(self, client):
        """Test batch upload with CSV (defaults to async)."""
        csv_content = """DocumentId,InvoiceNumber,Amount
INV001,F2024-001,1500.00
INV002,F2024-002,2500.00
"""
        files = {
            "INV001.pdf": b"%PDF-1.4 test",
            "INV002.pdf": b"%PDF-1.4 test 2",
            "metadata.csv": csv_content.encode('utf-8'),
        }
        zip_file = create_test_zip(files)

        response = client.post(
            "/api/v1/admin/batch/upload",
            files={"file": ("batch.zip", zip_file, "application/zip")},
            data={"upload_source": "api"},
        )

        # Async mode is default, should return 202
        assert response.status_code == 202
        result = response.json()
        assert "batch_id" in result
        assert result["status"] == "accepted"

    def test_get_batch_status(self, client):
        """Test getting batch status."""
        batch_id = str(uuid4())
        response = client.get(f"/api/v1/admin/batch/status/{batch_id}")

        assert response.status_code == 200
        result = response.json()
        assert result["batch_id"] == batch_id
        assert "status" in result
        assert "total_files" in result

    def test_list_batch_uploads(self, client):
        """Test listing batch uploads."""
        response = client.get("/api/v1/admin/batch/list")

        assert response.status_code == 200
        result = response.json()
        assert "batches" in result
        assert "total" in result
        assert "limit" in result
        assert "offset" in result

    def test_upload_batch_async_mode_default(self, client):
        """Test async mode is default (async_mode=True)."""
        files = {
            "INV001.pdf": b"%PDF-1.4 test content",
        }
        zip_file = create_test_zip(files)

        response = client.post(
            "/api/v1/admin/batch/upload",
            files={"file": ("test.zip", zip_file, "application/zip")},
            data={"upload_source": "ui"},
        )

        # Async mode should return 202 Accepted
        assert response.status_code == 202
        result = response.json()
        assert result["status"] == "accepted"
        assert "batch_id" in result
        assert "status_url" in result
        assert "queue_depth" in result
        assert result["message"] == "Batch upload queued for processing"

    def test_upload_batch_async_mode_explicit(self, client):
        """Test explicit async mode (async_mode=True)."""
        files = {
            "INV001.pdf": b"%PDF-1.4 test content",
        }
        zip_file = create_test_zip(files)

        response = client.post(
            "/api/v1/admin/batch/upload",
            files={"file": ("test.zip", zip_file, "application/zip")},
            data={"upload_source": "ui", "async_mode": "true"},
        )

        assert response.status_code == 202
        result = response.json()
        assert result["status"] == "accepted"
        assert "batch_id" in result
        assert "status_url" in result

    def test_upload_batch_sync_mode(self, client):
        """Test sync mode (async_mode=False)."""
        files = {
            "INV001.pdf": b"%PDF-1.4 test content",
        }
        zip_file = create_test_zip(files)

        response = client.post(
            "/api/v1/admin/batch/upload",
            files={"file": ("test.zip", zip_file, "application/zip")},
            data={"upload_source": "ui", "async_mode": "false"},
        )

        # Sync mode should return 200 OK with full results
        assert response.status_code == 200
        result = response.json()
        assert "batch_id" in result
        assert result["status"] in ["completed", "partial", "failed"]
        assert "successful_files" in result

    def test_upload_batch_async_with_auto_label(self, client):
        """Test async mode with auto_label flag."""
        files = {
            "INV001.pdf": b"%PDF-1.4 test content",
        }
        zip_file = create_test_zip(files)

        response = client.post(
            "/api/v1/admin/batch/upload",
            files={"file": ("test.zip", zip_file, "application/zip")},
            data={
                "upload_source": "ui",
                "async_mode": "true",
                "auto_label": "true",
            },
        )

        assert response.status_code == 202
        result = response.json()
        assert result["status"] == "accepted"
        assert "batch_id" in result

    def test_upload_batch_async_without_auto_label(self, client):
        """Test async mode with auto_label disabled."""
        files = {
            "INV001.pdf": b"%PDF-1.4 test content",
        }
        zip_file = create_test_zip(files)

        response = client.post(
            "/api/v1/admin/batch/upload",
            files={"file": ("test.zip", zip_file, "application/zip")},
            data={
                "upload_source": "ui",
                "async_mode": "true",
                "auto_label": "false",
            },
        )

        assert response.status_code == 202
        result = response.json()
        assert result["status"] == "accepted"

    def test_upload_batch_queue_full(self, client):
        """Test handling queue full scenario."""
        # This test would require mocking the queue to return False on submit
        # For now, we verify the endpoint accepts the request
        files = {
            "INV001.pdf": b"%PDF-1.4 test content",
        }
        zip_file = create_test_zip(files)

        response = client.post(
            "/api/v1/admin/batch/upload",
            files={"file": ("test.zip", zip_file, "application/zip")},
            data={"upload_source": "ui", "async_mode": "true"},
        )

        # Should either accept (202) or reject if queue full (503)
        assert response.status_code in [202, 503]

    def test_async_status_url_format(self, client):
        """Test async response contains correctly formatted status URL."""
        files = {
            "INV001.pdf": b"%PDF-1.4 test content",
        }
        zip_file = create_test_zip(files)

        response = client.post(
            "/api/v1/admin/batch/upload",
            files={"file": ("test.zip", zip_file, "application/zip")},
            data={"async_mode": "true"},
        )

        assert response.status_code == 202
        result = response.json()
        batch_id = result["batch_id"]
        expected_url = f"/api/v1/admin/batch/status/{batch_id}"
        assert result["status_url"] == expected_url
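

# Illustrative usage sketch (an assumption about the client-side polling loop,
# built only from the endpoints exercised above): submit a ZIP in async mode,
# then poll the returned status_url until the batch settles. The helper name
# and `attempts` parameter are hypothetical.
def _sketch_upload_and_poll(client, zip_bytes, attempts=10):
    """POST a ZIP, then poll its status endpoint; a hedged sketch."""
    response = client.post(
        "/api/v1/admin/batch/upload",
        files={"file": ("batch.zip", zip_bytes, "application/zip")},
        data={"async_mode": "true"},
    )
    assert response.status_code == 202
    status_url = response.json()["status_url"]
    status = None
    for _ in range(attempts):
        status = client.get(status_url).json()
        if status.get("status") in ("completed", "partial", "failed"):
            break  # terminal states observed in the sync-mode test above
    return status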
221
tests/web/test_batch_upload_service.py
Normal file
@@ -0,0 +1,221 @@
"""
Tests for Batch Upload Service
"""

import io
import zipfile
from pathlib import Path
from uuid import uuid4

import pytest

from src.data.admin_db import AdminDB
from src.web.services.batch_upload import BatchUploadService


@pytest.fixture
def admin_db():
    """Mock admin database for testing."""
    class MockAdminDB:
        def __init__(self):
            self.batches = {}
            self.batch_files = {}

        def create_batch_upload(self, admin_token, filename, file_size, upload_source):
            batch_id = uuid4()
            batch = type('BatchUpload', (), {
                'batch_id': batch_id,
                'admin_token': admin_token,
                'filename': filename,
                'file_size': file_size,
                'upload_source': upload_source,
                'status': 'processing',
                'total_files': 0,
                'processed_files': 0,
                'successful_files': 0,
                'failed_files': 0,
                'csv_filename': None,
                'csv_row_count': None,
                'error_message': None,
                'created_at': None,
                'completed_at': None,
            })()
            self.batches[batch_id] = batch
            return batch

        def update_batch_upload(self, batch_id, **kwargs):
            if batch_id in self.batches:
                batch = self.batches[batch_id]
                for key, value in kwargs.items():
                    setattr(batch, key, value)

        def create_batch_upload_file(self, batch_id, filename, **kwargs):
            file_id = uuid4()
            # Set defaults for attributes
            defaults = {
                'file_id': file_id,
                'batch_id': batch_id,
                'filename': filename,
                'status': 'pending',
                'error_message': None,
                'annotation_count': 0,
                'csv_row_data': None,
            }
            defaults.update(kwargs)
            file_record = type('BatchUploadFile', (), defaults)()
            if batch_id not in self.batch_files:
                self.batch_files[batch_id] = []
            self.batch_files[batch_id].append(file_record)
            return file_record

        def update_batch_upload_file(self, file_id, **kwargs):
            for files in self.batch_files.values():
                for file_record in files:
                    if file_record.file_id == file_id:
                        for key, value in kwargs.items():
                            setattr(file_record, key, value)
                        return

        def get_batch_upload(self, batch_id):
            return self.batches.get(batch_id)

        def get_batch_upload_files(self, batch_id):
            return self.batch_files.get(batch_id, [])

    return MockAdminDB()


@pytest.fixture
def batch_service(admin_db):
    """Batch upload service instance."""
    return BatchUploadService(admin_db)


def create_test_zip(files):
    """Create a test ZIP file with given files.

    Args:
        files: Dictionary mapping filenames to content bytes

    Returns:
        ZIP file content as bytes
    """
    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
        for filename, content in files.items():
            zip_file.writestr(filename, content)
    return zip_buffer.getvalue()


class TestBatchUploadService:
    """Tests for BatchUploadService."""

    def test_process_empty_zip(self, batch_service):
        """Test processing an empty ZIP file."""
        zip_content = create_test_zip({})
        result = batch_service.process_zip_upload(
            admin_token="test-token",
            zip_filename="empty.zip",
            zip_content=zip_content,
        )

        assert result["status"] == "failed"
        assert "No PDF files" in result.get("error", "")

    def test_process_zip_with_pdfs_only(self, batch_service):
        """Test processing ZIP with PDFs but no CSV."""
        files = {
            "INV001.pdf": b"%PDF-1.4 test content",
            "INV002.pdf": b"%PDF-1.4 test content 2",
        }
        zip_content = create_test_zip(files)

        result = batch_service.process_zip_upload(
            admin_token="test-token",
            zip_filename="invoices.zip",
            zip_content=zip_content,
        )

        assert result["status"] == "completed"
        assert result["total_files"] == 2
        assert result["successful_files"] == 2
        assert result["failed_files"] == 0

    def test_process_zip_with_csv(self, batch_service):
        """Test processing ZIP with PDFs and CSV."""
        csv_content = """DocumentId,InvoiceNumber,Amount,OCR
INV001,F2024-001,1500.00,7350012345678
INV002,F2024-002,2500.00,7350087654321
"""
        files = {
            "INV001.pdf": b"%PDF-1.4 test content",
            "INV002.pdf": b"%PDF-1.4 test content 2",
            "metadata.csv": csv_content.encode('utf-8'),
        }
        zip_content = create_test_zip(files)

        result = batch_service.process_zip_upload(
            admin_token="test-token",
            zip_filename="invoices.zip",
            zip_content=zip_content,
        )

        assert result["status"] == "completed"
        assert result["total_files"] == 2
        assert result["csv_filename"] == "metadata.csv"
        assert result["csv_row_count"] == 2

    def test_process_invalid_zip(self, batch_service):
        """Test processing invalid ZIP file."""
        result = batch_service.process_zip_upload(
            admin_token="test-token",
            zip_filename="invalid.zip",
            zip_content=b"not a zip file",
        )

        assert result["status"] == "failed"
        assert "Invalid ZIP file" in result.get("error", "")

    def test_csv_parsing(self, batch_service):
        """Test CSV field parsing."""
        csv_content = """DocumentId,InvoiceNumber,InvoiceDate,Amount,OCR,Bankgiro,customer_number
INV001,F2024-001,2024-01-15,1500.00,7350012345678,123-4567,C123
INV002,F2024-002,2024-01-16,2500.00,7350087654321,123-4567,C124
"""
        zip_file_content = create_test_zip({"metadata.csv": csv_content.encode('utf-8')})

        with zipfile.ZipFile(io.BytesIO(zip_file_content)) as zip_file:
            csv_file_info = [f for f in zip_file.filelist if f.filename.endswith('.csv')][0]
            csv_data = batch_service._parse_csv_file(zip_file, csv_file_info)

        assert len(csv_data) == 2
        assert "INV001" in csv_data
        assert csv_data["INV001"]["InvoiceNumber"] == "F2024-001"
        assert csv_data["INV001"]["Amount"] == "1500.00"
        assert csv_data["INV001"]["customer_number"] == "C123"

    def test_get_batch_status(self, batch_service, admin_db):
        """Test getting batch upload status."""
        # Create a batch
        zip_content = create_test_zip({"INV001.pdf": b"%PDF-1.4 test"})
        result = batch_service.process_zip_upload(
            admin_token="test-token",
            zip_filename="test.zip",
            zip_content=zip_content,
        )

        batch_id = result["batch_id"]

        # Get status
        status = batch_service.get_batch_status(batch_id)

        assert status["batch_id"] == batch_id
        assert status["filename"] == "test.zip"
        assert status["status"] == "completed"
        assert status["total_files"] == 1
        assert len(status["files"]) == 1

    def test_get_batch_status_not_found(self, batch_service):
        """Test getting status for non-existent batch."""
        status = batch_service.get_batch_status(str(uuid4()))
        assert "error" in status
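

# Illustrative sketch (an assumption, not the service's actual parser): the CSV
# handling asserted in test_csv_parsing amounts to reading the archive member
# and keying each row by its DocumentId column. The helper name is hypothetical.
def _sketch_parse_csv(zip_file, csv_info):
    """Map DocumentId -> row dict for a CSV inside `zip_file`; a hedged sketch."""
    import csv
    with zip_file.open(csv_info) as handle:
        text = io.TextIOWrapper(handle, encoding="utf-8")
        # DictReader uses the header row, so extra columns like Bankgiro pass through
        return {row["DocumentId"]: row for row in csv.DictReader(text)}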
298
tests/web/test_inference_api.py
Normal file
@@ -0,0 +1,298 @@
"""
Integration tests for inference API endpoints.

Tests the /api/v1/infer endpoint to ensure it works end-to-end.
"""

import pytest
from pathlib import Path
from unittest.mock import Mock, patch
from fastapi.testclient import TestClient
from PIL import Image
import io

from src.web.app import create_app
from src.web.config import ModelConfig, StorageConfig, AppConfig


@pytest.fixture
def test_app(tmp_path):
    """Create test FastAPI application."""
    # Setup test directories
    upload_dir = tmp_path / "uploads"
    result_dir = tmp_path / "results"
    upload_dir.mkdir()
    result_dir.mkdir()

    # Create test config
    app_config = AppConfig(
        model=ModelConfig(
            model_path=Path("runs/train/invoice_fields/weights/best.pt"),
            confidence_threshold=0.5,
            use_gpu=False,
            dpi=150,
        ),
        storage=StorageConfig(
            upload_dir=upload_dir,
            result_dir=result_dir,
            allowed_extensions={".pdf", ".png", ".jpg", ".jpeg"},
            max_file_size_mb=50,
        ),
    )

    # Create app
    app = create_app(app_config)
    return app


@pytest.fixture
def client(test_app):
    """Create test client."""
    return TestClient(test_app)


@pytest.fixture
def sample_png_bytes():
    """Create sample PNG image bytes."""
    img = Image.new('RGB', (800, 1200), color='white')
    img_bytes = io.BytesIO()
    img.save(img_bytes, format='PNG')
    img_bytes.seek(0)
    return img_bytes


class TestHealthEndpoint:
    """Test /api/v1/health endpoint."""

    def test_health_check_returns_200(self, client):
        """Test health check returns 200 OK."""
        response = client.get("/api/v1/health")
        assert response.status_code == 200

    def test_health_check_response_structure(self, client):
        """Test health check response has correct structure."""
        response = client.get("/api/v1/health")
        data = response.json()

        assert "status" in data
        assert "model_loaded" in data
        assert "gpu_available" in data
        assert "version" in data

        assert data["status"] == "healthy"
        assert isinstance(data["model_loaded"], bool)
        assert isinstance(data["gpu_available"], bool)


class TestInferEndpoint:
    """Test /api/v1/infer endpoint."""

    @patch('src.inference.pipeline.InferencePipeline')
    @patch('src.inference.yolo_detector.YOLODetector')
    def test_infer_accepts_png_file(
        self,
        mock_yolo_detector,
        mock_pipeline,
        client,
        sample_png_bytes,
    ):
        """Test that /infer endpoint accepts PNG files."""
        # Setup mocks
        mock_detector_instance = Mock()
        mock_pipeline_instance = Mock()
        mock_yolo_detector.return_value = mock_detector_instance
        mock_pipeline.return_value = mock_pipeline_instance

        # Mock pipeline result
        mock_result = Mock()
        mock_result.fields = {"InvoiceNumber": "12345"}
        mock_result.confidence = {"InvoiceNumber": 0.95}
        mock_result.success = True
        mock_result.errors = []
        mock_result.raw_detections = []
        mock_result.document_id = "test123"
        mock_result.document_type = "invoice"
        mock_result.processing_time_ms = 100.0
        mock_result.visualization_path = None
        mock_result.detections = []
        mock_pipeline_instance.process_image.return_value = mock_result

        # Make request
        response = client.post(
            "/api/v1/infer",
            files={"file": ("test.png", sample_png_bytes, "image/png")},
        )

        # Verify response
        assert response.status_code == 200
        data = response.json()

        assert data["status"] == "success"
        assert "result" in data
        assert data["result"]["fields"]["InvoiceNumber"] == "12345"
        assert data["result"]["confidence"]["InvoiceNumber"] == 0.95

    def test_infer_rejects_invalid_file_type(self, client):
        """Test that /infer rejects unsupported file types."""
        invalid_file = io.BytesIO(b"fake txt content")

        response = client.post(
            "/api/v1/infer",
            files={"file": ("test.txt", invalid_file, "text/plain")},
        )

        assert response.status_code == 400
        assert "Unsupported file type" in response.json()["detail"]

    def test_infer_requires_file(self, client):
        """Test that /infer requires a file parameter."""
        response = client.post("/api/v1/infer")

        assert response.status_code == 422  # Unprocessable Entity

    @patch('src.inference.pipeline.InferencePipeline')
    @patch('src.inference.yolo_detector.YOLODetector')
    def test_infer_returns_cross_validation_if_available(
        self,
        mock_yolo_detector,
        mock_pipeline,
        client,
        sample_png_bytes,
    ):
        """Test that cross-validation results are included if available."""
        # Setup mocks
        mock_detector_instance = Mock()
        mock_pipeline_instance = Mock()
        mock_yolo_detector.return_value = mock_detector_instance
        mock_pipeline.return_value = mock_pipeline_instance

        # Mock pipeline result with cross-validation
        mock_result = Mock()
        mock_result.fields = {
            "InvoiceNumber": "12345",
            "OCR": "1234567",
            "Amount": "100.00",
        }
        mock_result.confidence = {
            "InvoiceNumber": 0.95,
            "OCR": 0.90,
            "Amount": 0.88,
        }
        mock_result.success = True
        mock_result.errors = []
        mock_result.raw_detections = []
        mock_result.document_id = "test123"
        mock_result.document_type = "invoice"
        mock_result.processing_time_ms = 100.0
        mock_result.visualization_path = None
        mock_result.detections = []

        # Add cross-validation result
        mock_cv = Mock()
        mock_cv.is_valid = True
        mock_cv.payment_line_ocr = "1234567"
        mock_cv.ocr_match = True
        mock_result.cross_validation = mock_cv

        mock_pipeline_instance.process_image.return_value = mock_result

        # Make request
        response = client.post(
            "/api/v1/infer",
            files={"file": ("test.png", sample_png_bytes, "image/png")},
        )

        # Verify response includes cross-validation
        assert response.status_code == 200
        data = response.json()

        # Note: cross_validation is not currently in the response schema
        # This test documents that it should be added

    @patch('src.inference.pipeline.InferencePipeline')
    @patch('src.inference.yolo_detector.YOLODetector')
    def test_infer_handles_processing_errors_gracefully(
        self,
        mock_yolo_detector,
        mock_pipeline,
        client,
        sample_png_bytes,
    ):
        """Test that processing errors are handled gracefully."""
        # Setup mocks
        mock_detector_instance = Mock()
        mock_pipeline_instance = Mock()
        mock_yolo_detector.return_value = mock_detector_instance
        mock_pipeline.return_value = mock_pipeline_instance

        # Make pipeline raise an error
        mock_pipeline_instance.process_image.side_effect = Exception("Model inference failed")

        # Make request
        response = client.post(
            "/api/v1/infer",
            files={"file": ("test.png", sample_png_bytes, "image/png")},
        )

        # Verify error handling - service catches exceptions and returns partial results
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "partial"
        assert data["result"]["success"] is False
        assert len(data["result"]["errors"]) > 0
        assert "Model inference failed" in data["result"]["errors"][0]


class TestResultsEndpoint:
    """Test /api/v1/results/{filename} endpoint."""

    def test_get_result_image_returns_404_if_not_found(self, client):
        """Test that getting non-existent result returns 404."""
        response = client.get("/api/v1/results/nonexistent.png")
        assert response.status_code == 404

    def test_get_result_image_returns_file_if_exists(self, client, test_app, tmp_path):
        """Test that existing result file is returned."""
        # Get storage config from app
        storage_config = test_app.extra.get("storage_config")
        if not storage_config:
            pytest.skip("Storage config not available in test app")

        # Create a test result file
        result_file = storage_config.result_dir / "test_result.png"
        img = Image.new('RGB', (100, 100), color='red')
        img.save(result_file)

        # Request the file
        response = client.get("/api/v1/results/test_result.png")

        assert response.status_code == 200
        assert response.headers["content-type"] == "image/png"


class TestInferenceServiceImports:
    """Critical test to catch import errors."""

    def test_inference_service_can_import_modules(self):
        """
        Test that InferenceService can import its dependencies.

        This test will fail if there are ImportError issues like:
        - from ..inference.pipeline (wrong relative import)
        - from src.web.inference (non-existent module)

        It ensures the imports are correct before runtime.
        """
        from src.web.services.inference import InferenceService

        # Import the modules that InferenceService tries to import
        from src.inference.pipeline import InferencePipeline
        from src.inference.yolo_detector import YOLODetector
        from src.pdf.renderer import render_pdf_to_images

        # If we got here, all imports work correctly
        assert InferencePipeline is not None
        assert YOLODetector is not None
        assert render_pdf_to_images is not None
        assert InferenceService is not None
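

# Illustrative sketch (an assumption about the service internals, not code from
# this commit): the behavior in test_infer_handles_processing_errors_gracefully
# implies the service wraps pipeline calls and downgrades exceptions to a
# partial result rather than a 500. The helper name is hypothetical.
def _sketch_safe_process(pipeline, image_path):
    """Run the pipeline, converting exceptions to an error result; a hedged sketch."""
    try:
        return pipeline.process_image(image_path)
    except Exception as exc:
        # Surface the failure in-band so the route can return status "partial"
        return {"success": False, "fields": {}, "errors": [str(exc)]}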
297
tests/web/test_inference_service.py
Normal file
@@ -0,0 +1,297 @@
|
||||
"""
|
||||
Integration tests for inference service.
|
||||
|
||||
Tests the full initialization and processing flow to catch import errors.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
from src.web.services.inference import InferenceService
|
||||
from src.web.config import ModelConfig, StorageConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_config(tmp_path):
|
||||
"""Create model configuration for testing."""
|
||||
return ModelConfig(
|
||||
model_path=Path("runs/train/invoice_fields/weights/best.pt"),
|
||||
confidence_threshold=0.5,
|
||||
use_gpu=False, # Use CPU for tests
|
||||
dpi=150,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def storage_config(tmp_path):
|
||||
"""Create storage configuration for testing."""
|
||||
upload_dir = tmp_path / "uploads"
|
||||
result_dir = tmp_path / "results"
|
||||
upload_dir.mkdir()
|
||||
result_dir.mkdir()
|
||||
|
||||
return StorageConfig(
|
||||
upload_dir=upload_dir,
|
||||
result_dir=result_dir,
|
||||
allowed_extensions={".pdf", ".png", ".jpg", ".jpeg"},
|
||||
max_file_size_mb=50,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_image(tmp_path):
|
||||
"""Create a sample test image."""
|
||||
image_path = tmp_path / "test_invoice.png"
|
||||
img = Image.new('RGB', (800, 1200), color='white')
|
||||
img.save(image_path)
|
||||
return image_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def inference_service(model_config, storage_config):
|
||||
"""Create inference service instance."""
|
||||
return InferenceService(
|
||||
model_config=model_config,
|
||||
storage_config=storage_config,
|
||||
)
|
||||
|
||||
|
||||
class TestInferenceServiceInitialization:
|
||||
"""Test inference service initialization to catch import errors."""
|
||||
|
||||
def test_service_creation(self, inference_service):
|
||||
"""Test that service can be created without errors."""
|
||||
assert inference_service is not None
|
||||
assert not inference_service.is_initialized
|
||||
|
||||
def test_gpu_available_check(self, inference_service):
|
||||
"""Test GPU availability check (should not crash)."""
|
||||
gpu_available = inference_service.gpu_available
|
||||
assert isinstance(gpu_available, bool)
|
||||
|
||||
@patch('src.inference.pipeline.InferencePipeline')
|
||||
@patch('src.inference.yolo_detector.YOLODetector')
|
||||
def test_initialize_imports_correctly(
|
||||
self,
|
||||
mock_yolo_detector,
|
||||
mock_pipeline,
|
||||
inference_service,
|
||||
):
|
||||
"""
|
||||
Test that initialize() imports modules correctly.
|
||||
|
||||
This test ensures that the import statements in initialize()
|
||||
use correct paths and don't fail with ImportError.
|
||||
"""
|
||||
# Mock the constructors to avoid actually loading models
|
||||
mock_detector_instance = Mock()
|
||||
mock_pipeline_instance = Mock()
|
||||
mock_yolo_detector.return_value = mock_detector_instance
|
||||
mock_pipeline.return_value = mock_pipeline_instance
|
||||
|
||||
# Initialize should not raise ImportError
|
||||
inference_service.initialize()
|
||||
|
||||
# Verify initialization succeeded
|
||||
assert inference_service.is_initialized
|
||||
|
||||
# Verify imports were called with correct parameters
|
||||
mock_yolo_detector.assert_called_once()
|
||||
mock_pipeline.assert_called_once()
|
||||
|
||||
@patch('src.inference.pipeline.InferencePipeline')
|
||||
@patch('src.inference.yolo_detector.YOLODetector')
|
||||
def test_initialize_sets_up_pipeline(
|
||||
self,
|
||||
mock_yolo_detector,
|
||||
mock_pipeline,
|
||||
inference_service,
|
||||
model_config,
|
||||
):
|
||||
"""Test that initialize sets up pipeline with correct config."""
|
||||
mock_detector_instance = Mock()
|
||||
mock_pipeline_instance = Mock()
|
||||
mock_yolo_detector.return_value = mock_detector_instance
|
||||
mock_pipeline.return_value = mock_pipeline_instance
|
||||
|
||||
inference_service.initialize()
|
||||
|
||||
# Check YOLO detector was initialized correctly
|
||||
mock_yolo_detector.assert_called_once_with(
|
||||
str(model_config.model_path),
|
||||
confidence_threshold=model_config.confidence_threshold,
|
||||
device="cpu", # use_gpu=False in fixture
|
||||
)
|
||||
|
||||
# Check pipeline was initialized correctly
|
||||
mock_pipeline.assert_called_once_with(
|
||||
model_path=str(model_config.model_path),
|
||||
confidence_threshold=model_config.confidence_threshold,
|
||||
use_gpu=False,
|
||||
dpi=150,
|
||||
enable_fallback=True,
|
||||
)
|
||||
|
||||
@patch('src.inference.pipeline.InferencePipeline')
|
||||
@patch('src.inference.yolo_detector.YOLODetector')
|
||||
def test_initialize_idempotent(
|
||||
self,
|
||||
mock_yolo_detector,
|
||||
mock_pipeline,
|
||||
inference_service,
|
||||
):
|
||||
"""Test that calling initialize() multiple times is safe."""
|
||||
mock_detector_instance = Mock()
|
||||
mock_pipeline_instance = Mock()
|
||||
mock_yolo_detector.return_value = mock_detector_instance
|
||||
mock_pipeline.return_value = mock_pipeline_instance
|
||||
|
||||
# Call initialize twice
|
||||
inference_service.initialize()
|
||||
inference_service.initialize()
|
||||
|
||||
# Should only be called once due to is_initialized check
|
||||
assert mock_yolo_detector.call_count == 1
|
||||
assert mock_pipeline.call_count == 1
|
||||
|
||||
|
||||
class TestInferenceServiceProcessing:
|
||||
"""Test inference processing methods."""
|
||||
|
||||
@patch('src.inference.pipeline.InferencePipeline')
|
||||
@patch('src.inference.yolo_detector.YOLODetector')
|
||||
@patch('ultralytics.YOLO')
|
||||
def test_process_image_basic_flow(
|
||||
self,
|
||||
mock_yolo_class,
|
||||
mock_yolo_detector,
|
||||
mock_pipeline,
|
||||
inference_service,
|
||||
sample_image,
|
||||
):
|
||||
"""Test basic image processing flow."""
|
||||
# Setup mocks
|
||||
mock_detector_instance = Mock()
|
||||
mock_pipeline_instance = Mock()
|
||||
mock_yolo_detector.return_value = mock_detector_instance
|
||||
mock_pipeline.return_value = mock_pipeline_instance
|
||||
|
||||
# Mock pipeline result
|
||||
mock_result = Mock()
|
||||
mock_result.fields = {"InvoiceNumber": "12345"}
|
||||
mock_result.confidence = {"InvoiceNumber": 0.95}
|
||||
mock_result.success = True
|
||||
mock_result.errors = []
|
||||
mock_result.raw_detections = []
|
||||
mock_pipeline_instance.process_image.return_value = mock_result
|
||||
|
||||
# Process image
|
||||
result = inference_service.process_image(sample_image)
|
||||
|
||||
# Verify result
|
||||
assert result.success
|
||||
assert result.fields == {"InvoiceNumber": "12345"}
|
||||
assert result.confidence == {"InvoiceNumber": 0.95}
|
||||
assert result.processing_time_ms > 0
|
||||
|
||||
@patch('src.inference.pipeline.InferencePipeline')
|
||||
@patch('src.inference.yolo_detector.YOLODetector')
|
||||
def test_process_image_handles_errors(
|
||||
self,
|
||||
mock_yolo_detector,
|
||||
mock_pipeline,
|
||||
inference_service,
|
||||
sample_image,
|
||||
):
|
||||
"""Test that processing errors are handled gracefully."""
|
||||
# Setup mocks
|
||||
mock_detector_instance = Mock()
|
||||
mock_pipeline_instance = Mock()
|
||||
mock_yolo_detector.return_value = mock_detector_instance
|
||||
mock_pipeline.return_value = mock_pipeline_instance
|
||||
|
||||
# Make pipeline raise an error
|
||||
mock_pipeline_instance.process_image.side_effect = Exception("Test error")
|
||||
|
||||
# Process should not crash
|
||||
result = inference_service.process_image(sample_image)
|
||||
|
||||
# Verify error handling
|
||||
assert not result.success
|
||||
assert len(result.errors) > 0
|
||||
assert "Test error" in result.errors[0]
|
||||
|
||||
|
||||


class TestInferenceServicePDFRendering:
    """Test PDF rendering imports."""

    @patch('src.inference.pipeline.InferencePipeline')
    @patch('src.inference.yolo_detector.YOLODetector')
    @patch('src.pdf.renderer.render_pdf_to_images')
    @patch('ultralytics.YOLO')
    def test_pdf_visualization_imports_correctly(
        self,
        mock_yolo_class,
        mock_render_pdf,
        mock_yolo_detector,
        mock_pipeline,
        inference_service,
        tmp_path,
    ):
        """
        Test that _save_pdf_visualization imports render_pdf_to_images correctly.

        This catches the import error we had with:
        from ..pdf.renderer (wrong) vs from src.pdf.renderer (correct)
        """
        # Setup mocks
        mock_detector_instance = Mock()
        mock_pipeline_instance = Mock()
        mock_yolo_detector.return_value = mock_detector_instance
        mock_pipeline.return_value = mock_pipeline_instance

        # Create a fake PDF path
        pdf_path = tmp_path / "test.pdf"
        pdf_path.touch()

        # Mock render_pdf_to_images to return an image
        image_bytes = io.BytesIO()
        img = Image.new('RGB', (800, 1200), color='white')
        img.save(image_bytes, format='PNG')
        mock_render_pdf.return_value = [(1, image_bytes.getvalue())]

        # Mock YOLO
        mock_model_instance = Mock()
        mock_result = Mock()
        mock_result.save = Mock()
        mock_model_instance.predict.return_value = [mock_result]
        mock_yolo_class.return_value = mock_model_instance

        # This should not raise ImportError
        result_path = inference_service._save_pdf_visualization(pdf_path, "test123")

        # Verify import was successful
        mock_render_pdf.assert_called_once()
        assert result_path is not None
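
# The regression this test guards against, per its docstring: the helper must
# use the absolute import form, e.g.
#
#     from src.pdf.renderer import render_pdf_to_images    # correct
#     # from ..pdf.renderer import render_pdf_to_images    # previously broke
#
# presumably because the service module is not always loaded as part of a
# package, so the relative form fails with ImportError.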


@pytest.mark.skipif(
    not Path("runs/train/invoice_fields/weights/best.pt").exists(),
    reason="Model file not available"
)
class TestInferenceServiceRealModel:
    """Integration tests with real model (skip if model not available)."""

    def test_real_initialization(self, model_config, storage_config):
        """Test real initialization with actual model."""
        service = InferenceService(model_config, storage_config)

        # This should work with the real imports
        service.initialize()

        assert service.is_initialized
        assert service._pipeline is not None
        assert service._detector is not None

154
tests/web/test_rate_limiter.py
Normal file
@@ -0,0 +1,154 @@
"""
Tests for the RateLimiter class.
"""

import time
from datetime import datetime, timedelta
from unittest.mock import MagicMock

import pytest

from src.data.async_request_db import ApiKeyConfig
from src.web.rate_limiter import RateLimiter, RateLimitConfig, RateLimitStatus
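

# The `rate_limiter` and `mock_db` fixtures used below are defined in a
# conftest.py that is not part of this diff. A plausible sketch follows,
# assuming RateLimiter takes the database as its only constructor argument
# (the real signature is not confirmed by this diff):

@pytest.fixture
def mock_db():
    """Database stub: no stored per-key config, no active jobs."""
    db = MagicMock()
    db.get_api_key_config.return_value = None  # fall back to default limits
    db.count_active_jobs.return_value = 0
    return db


@pytest.fixture
def rate_limiter(mock_db):
    """RateLimiter under test, wired to the stub database."""
    return RateLimiter(mock_db)  # constructor signature assumed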


class TestRateLimiter:
    """Tests for RateLimiter."""

    def test_check_submit_limit_allowed(self, rate_limiter, mock_db):
        """Test that requests are allowed under the limit."""
        status = rate_limiter.check_submit_limit("test-api-key")

        assert status.allowed is True
        assert status.remaining_requests >= 0
        assert status.retry_after_seconds is None

    def test_check_submit_limit_rate_exceeded(self, rate_limiter, mock_db):
        """Test rate limit exceeded when too many requests."""
        # Record 10 requests (the default limit)
        for _ in range(10):
            rate_limiter.record_request("test-api-key")

        status = rate_limiter.check_submit_limit("test-api-key")

        assert status.allowed is False
        assert status.remaining_requests == 0
        assert status.retry_after_seconds is not None
        assert status.retry_after_seconds > 0
        assert "rate limit" in status.reason.lower()

    def test_check_submit_limit_concurrent_jobs_exceeded(self, rate_limiter, mock_db):
        """Test rejection when max concurrent jobs reached."""
        # Mock active jobs at the limit
        mock_db.count_active_jobs.return_value = 3  # Max is 3

        status = rate_limiter.check_submit_limit("test-api-key")

        assert status.allowed is False
        assert "concurrent" in status.reason.lower()

    def test_record_request(self, rate_limiter, mock_db):
        """Test that recording a request works."""
        rate_limiter.record_request("test-api-key")

        # Should have called the database
        mock_db.record_rate_limit_event.assert_called_once_with("test-api-key", "request")

    def test_check_poll_limit_allowed(self, rate_limiter, mock_db):
        """Test that polling is allowed initially."""
        status = rate_limiter.check_poll_limit("test-api-key", "request-123")

        assert status.allowed is True

    def test_check_poll_limit_too_frequent(self, rate_limiter, mock_db):
        """Test that rapid polling is rejected."""
        # First poll should succeed
        status1 = rate_limiter.check_poll_limit("test-api-key", "request-123")
        assert status1.allowed is True

        # Immediate second poll should fail
        status2 = rate_limiter.check_poll_limit("test-api-key", "request-123")
        assert status2.allowed is False
        assert "polling" in status2.reason.lower()
        assert status2.retry_after_seconds is not None

    def test_check_poll_limit_different_requests(self, rate_limiter, mock_db):
        """Test that different request_ids have separate poll limits."""
        # Poll request 1
        status1 = rate_limiter.check_poll_limit("test-api-key", "request-1")
        assert status1.allowed is True

        # Poll request 2 should also be allowed
        status2 = rate_limiter.check_poll_limit("test-api-key", "request-2")
        assert status2.allowed is True

    def test_sliding_window_expires(self, rate_limiter, mock_db):
        """Test sliding-window accounting after recording requests."""
        # Record requests
        for _ in range(5):
            rate_limiter.record_request("test-api-key")

        # Check status - 4 of 10 should remain once this check is counted
        status1 = rate_limiter.check_submit_limit("test-api-key")
        assert status1.allowed is True
        assert status1.remaining_requests == 4  # 10 - 5 - 1 (for this check)
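
    # Sliding-window accounting implied by the assertion above (a sketch, not
    # the confirmed implementation): the limiter keeps per-key timestamps,
    # drops entries older than the 60-second window, and reports
    #     remaining = requests_per_minute - len(recent_requests) - 1
    # where the trailing -1 reserves the request currently being checked, so
    # 5 recorded requests against a limit of 10 leave 4 remaining.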

    def test_get_rate_limit_headers(self, rate_limiter):
        """Test rate limit header generation."""
        status = RateLimitStatus(
            allowed=False,
            remaining_requests=0,
            reset_at=datetime.utcnow() + timedelta(seconds=30),
            retry_after_seconds=30,
        )

        headers = rate_limiter.get_rate_limit_headers(status)

        assert "X-RateLimit-Remaining" in headers
        assert headers["X-RateLimit-Remaining"] == "0"
        assert "Retry-After" in headers
        assert headers["Retry-After"] == "30"
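
    # Header generation sketch (the keys come from the assertions above; the
    # mapping itself is assumed):
    #     headers = {"X-RateLimit-Remaining": str(status.remaining_requests)}
    #     if status.retry_after_seconds is not None:
    #         headers["Retry-After"] = str(status.retry_after_seconds)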

    def test_cleanup_poll_timestamps(self, rate_limiter, mock_db):
        """Test cleanup of old poll timestamps."""
        # Add some poll timestamps
        rate_limiter.check_poll_limit("test-api-key", "old-request")

        # Manually age the timestamp
        rate_limiter._poll_timestamps[("test-api-key", "old-request")] = time.time() - 7200

        # Run cleanup with 1 hour max age
        cleaned = rate_limiter.cleanup_poll_timestamps(max_age_seconds=3600)

        assert cleaned == 1
        assert ("test-api-key", "old-request") not in rate_limiter._poll_timestamps

    def test_cleanup_request_windows(self, rate_limiter, mock_db):
        """Test cleanup of empty request windows."""
        # Add some old requests
        rate_limiter._request_windows["old-key"] = [time.time() - 120]

        # Run cleanup
        rate_limiter.cleanup_request_windows()

        # Old entries should be removed
        assert "old-key" not in rate_limiter._request_windows

    def test_config_caching(self, rate_limiter, mock_db):
        """Test that API key configs are cached."""
        # First call should query database
        rate_limiter._get_config("test-api-key")
        assert mock_db.get_api_key_config.call_count == 1

        # Second call should use cache
        rate_limiter._get_config("test-api-key")
        assert mock_db.get_api_key_config.call_count == 1  # Still 1

    def test_default_config_for_unknown_key(self, rate_limiter, mock_db):
        """Test that unknown API keys get default config."""
        mock_db.get_api_key_config.return_value = None

        config = rate_limiter._get_config("unknown-key")

        assert config.requests_per_minute == 10  # Default
        assert config.max_concurrent_jobs == 3  # Default

384
tests/web/test_training_phase4.py
Normal file
@@ -0,0 +1,384 @@
"""
Tests for Phase 4: Training Data Management
"""

import pytest
from datetime import datetime
from uuid import uuid4

from fastapi import FastAPI
from fastapi.testclient import TestClient

from src.web.api.v1.admin.training import create_training_router
from src.web.core.auth import validate_admin_token, get_admin_db


class MockTrainingTask:
    """Mock TrainingTask for testing."""

    def __init__(self, **kwargs):
        self.task_id = kwargs.get('task_id', uuid4())
        self.admin_token = kwargs.get('admin_token', 'test-token')
        self.name = kwargs.get('name', 'Test Training')
        self.description = kwargs.get('description', None)
        self.status = kwargs.get('status', 'completed')
        self.task_type = kwargs.get('task_type', 'train')
        self.config = kwargs.get('config', {})
        self.scheduled_at = kwargs.get('scheduled_at', None)
        self.cron_expression = kwargs.get('cron_expression', None)
        self.is_recurring = kwargs.get('is_recurring', False)
        self.started_at = kwargs.get('started_at', datetime.utcnow())
        self.completed_at = kwargs.get('completed_at', datetime.utcnow())
        self.error_message = kwargs.get('error_message', None)
        self.result_metrics = kwargs.get('result_metrics', {})
        self.model_path = kwargs.get('model_path', 'runs/train/test/weights/best.pt')
        self.document_count = kwargs.get('document_count', 0)
        self.metrics_mAP = kwargs.get('metrics_mAP', 0.935)
        self.metrics_precision = kwargs.get('metrics_precision', 0.92)
        self.metrics_recall = kwargs.get('metrics_recall', 0.88)
        self.created_at = kwargs.get('created_at', datetime.utcnow())
        self.updated_at = kwargs.get('updated_at', datetime.utcnow())


class MockTrainingDocumentLink:
    """Mock TrainingDocumentLink for testing."""

    def __init__(self, **kwargs):
        self.link_id = kwargs.get('link_id', uuid4())
        self.task_id = kwargs.get('task_id')
        self.document_id = kwargs.get('document_id')
        self.annotation_snapshot = kwargs.get('annotation_snapshot', None)
        self.created_at = kwargs.get('created_at', datetime.utcnow())


class MockAdminDocument:
    """Mock AdminDocument for testing."""

    def __init__(self, **kwargs):
        self.document_id = kwargs.get('document_id', uuid4())
        self.admin_token = kwargs.get('admin_token', 'test-token')
        self.filename = kwargs.get('filename', 'test.pdf')
        self.file_size = kwargs.get('file_size', 100000)
        self.content_type = kwargs.get('content_type', 'application/pdf')
        self.file_path = kwargs.get('file_path', 'data/admin_docs/test.pdf')
        self.page_count = kwargs.get('page_count', 1)
        self.status = kwargs.get('status', 'labeled')
        self.auto_label_status = kwargs.get('auto_label_status', None)
        self.auto_label_error = kwargs.get('auto_label_error', None)
        self.upload_source = kwargs.get('upload_source', 'ui')
        self.batch_id = kwargs.get('batch_id', None)
        self.csv_field_values = kwargs.get('csv_field_values', None)
        self.auto_label_queued_at = kwargs.get('auto_label_queued_at', None)
        self.annotation_lock_until = kwargs.get('annotation_lock_until', None)
        self.created_at = kwargs.get('created_at', datetime.utcnow())
        self.updated_at = kwargs.get('updated_at', datetime.utcnow())


class MockAnnotation:
    """Mock AdminAnnotation for testing."""

    def __init__(self, **kwargs):
        self.annotation_id = kwargs.get('annotation_id', uuid4())
        self.document_id = kwargs.get('document_id')
        self.page_number = kwargs.get('page_number', 1)
        self.class_id = kwargs.get('class_id', 0)
        self.class_name = kwargs.get('class_name', 'invoice_number')
        self.bbox_x = kwargs.get('bbox_x', 100)
        self.bbox_y = kwargs.get('bbox_y', 100)
        self.bbox_width = kwargs.get('bbox_width', 200)
        self.bbox_height = kwargs.get('bbox_height', 50)
        self.x_center = kwargs.get('x_center', 0.5)
        self.y_center = kwargs.get('y_center', 0.5)
        self.width = kwargs.get('width', 0.3)
        self.height = kwargs.get('height', 0.1)
        self.text_value = kwargs.get('text_value', 'INV-001')
        self.confidence = kwargs.get('confidence', 0.95)
        self.source = kwargs.get('source', 'manual')
        self.is_verified = kwargs.get('is_verified', False)
        self.verified_at = kwargs.get('verified_at', None)
        self.verified_by = kwargs.get('verified_by', None)
        self.override_source = kwargs.get('override_source', None)
        self.original_annotation_id = kwargs.get('original_annotation_id', None)
        self.created_at = kwargs.get('created_at', datetime.utcnow())
        self.updated_at = kwargs.get('updated_at', datetime.utcnow())


class MockAdminDB:
    """Mock AdminDB for testing Phase 4."""

    def __init__(self):
        self.documents = {}
        self.annotations = {}
        self.training_tasks = {}
        self.training_links = {}

    def get_documents_for_training(
        self,
        admin_token,
        status="labeled",
        has_annotations=True,
        min_annotation_count=None,
        exclude_used_in_training=False,
        limit=100,
        offset=0,
    ):
        """Get documents for training."""
        # Filter documents by criteria
        filtered = []
        for doc in self.documents.values():
            if doc.admin_token != admin_token or doc.status != status:
                continue

            # Check annotations
            annotations = self.annotations.get(str(doc.document_id), [])
            if has_annotations and len(annotations) == 0:
                continue
            if min_annotation_count and len(annotations) < min_annotation_count:
                continue

            # Check if used in training
            if exclude_used_in_training:
                links = self.training_links.get(str(doc.document_id), [])
                if links:
                    continue

            filtered.append(doc)

        total = len(filtered)
        return filtered[offset:offset + limit], total

    def get_annotations_for_document(self, document_id):
        """Get annotations for document."""
        return self.annotations.get(str(document_id), [])

    def get_document_training_tasks(self, document_id):
        """Get training tasks that used this document."""
        return self.training_links.get(str(document_id), [])

    def get_training_tasks_by_token(
        self,
        admin_token,
        status=None,
        limit=20,
        offset=0,
    ):
        """Get training tasks filtered by token."""
        tasks = [t for t in self.training_tasks.values() if t.admin_token == admin_token]
        if status:
            tasks = [t for t in tasks if t.status == status]

        total = len(tasks)
        return tasks[offset:offset + limit], total

    def get_training_task(self, task_id):
        """Get training task by ID."""
        return self.training_tasks.get(str(task_id))


@pytest.fixture
def app():
    """Create test FastAPI app."""
    app = FastAPI()

    # Create mock DB
    mock_db = MockAdminDB()

    # Add test documents
    doc1 = MockAdminDocument(
        filename="DOC001.pdf",
        status="labeled",
    )
    doc2 = MockAdminDocument(
        filename="DOC002.pdf",
        status="labeled",
    )
    doc3 = MockAdminDocument(
        filename="DOC003.pdf",
        status="labeled",
    )

    mock_db.documents[str(doc1.document_id)] = doc1
    mock_db.documents[str(doc2.document_id)] = doc2
    mock_db.documents[str(doc3.document_id)] = doc3

    # Add annotations
    mock_db.annotations[str(doc1.document_id)] = [
        MockAnnotation(document_id=doc1.document_id, source="manual"),
        MockAnnotation(document_id=doc1.document_id, source="auto"),
    ]
    mock_db.annotations[str(doc2.document_id)] = [
        MockAnnotation(document_id=doc2.document_id, source="auto"),
        MockAnnotation(document_id=doc2.document_id, source="auto"),
        MockAnnotation(document_id=doc2.document_id, source="auto"),
    ]
    # doc3 has no annotations

    # Add training tasks
    task1 = MockTrainingTask(
        name="Training Run 2024-01",
        status="completed",
        document_count=500,
        metrics_mAP=0.935,
        metrics_precision=0.92,
        metrics_recall=0.88,
    )
    task2 = MockTrainingTask(
        name="Training Run 2024-02",
        status="completed",
        document_count=600,
        metrics_mAP=0.951,
        metrics_precision=0.94,
        metrics_recall=0.92,
    )

    mock_db.training_tasks[str(task1.task_id)] = task1
    mock_db.training_tasks[str(task2.task_id)] = task2

    # Add training links (doc1 used in task1)
    link1 = MockTrainingDocumentLink(
        task_id=task1.task_id,
        document_id=doc1.document_id,
    )
    mock_db.training_links[str(doc1.document_id)] = [link1]

    # Override dependencies
    app.dependency_overrides[validate_admin_token] = lambda: "test-token"
    app.dependency_overrides[get_admin_db] = lambda: mock_db

    # Include router
    router = create_training_router()
    app.include_router(router)

    return app


@pytest.fixture
def client(app):
    """Create test client."""
    return TestClient(app)


class TestTrainingDocuments:
    """Tests for GET /admin/training/documents endpoint."""

    def test_get_training_documents_success(self, client):
        """Test getting documents for training."""
        response = client.get("/admin/training/documents")

        assert response.status_code == 200
        data = response.json()
        assert "total" in data
        assert "documents" in data
        assert data["total"] >= 0
        assert isinstance(data["documents"], list)

    def test_get_training_documents_with_annotations(self, client):
        """Test filtering documents with annotations."""
        response = client.get("/admin/training/documents?has_annotations=true")

        assert response.status_code == 200
        data = response.json()
        # Should return doc1 and doc2 (both have annotations)
        assert data["total"] == 2

    def test_get_training_documents_min_annotation_count(self, client):
        """Test filtering by minimum annotation count."""
        response = client.get("/admin/training/documents?min_annotation_count=3")

        assert response.status_code == 200
        data = response.json()
        # Should return only doc2 (has 3 annotations)
        assert data["total"] == 1

    def test_get_training_documents_exclude_used(self, client):
        """Test excluding documents already used in training."""
        response = client.get("/admin/training/documents?exclude_used_in_training=true")

        assert response.status_code == 200
        data = response.json()
        # Should exclude doc1 (used in training)
        assert data["total"] == 1  # Only doc2 (doc3 has no annotations)

    def test_get_training_documents_annotation_sources(self, client):
        """Test that annotation sources are included."""
        response = client.get("/admin/training/documents?has_annotations=true")

        assert response.status_code == 200
        data = response.json()
        # Check that documents have annotation_sources field
        for doc in data["documents"]:
            assert "annotation_sources" in doc
            assert isinstance(doc["annotation_sources"], dict)
            assert "manual" in doc["annotation_sources"]
            assert "auto" in doc["annotation_sources"]

    def test_get_training_documents_pagination(self, client):
        """Test pagination parameters."""
        response = client.get("/admin/training/documents?limit=1&offset=0")

        assert response.status_code == 200
        data = response.json()
        assert data["limit"] == 1
        assert data["offset"] == 0
        assert len(data["documents"]) <= 1


class TestTrainingModels:
    """Tests for GET /admin/training/models endpoint."""

    def test_get_training_models_success(self, client):
        """Test getting trained models list."""
        response = client.get("/admin/training/models")

        assert response.status_code == 200
        data = response.json()
        assert "total" in data
        assert "models" in data
        assert data["total"] == 2
        assert len(data["models"]) == 2

    def test_get_training_models_includes_metrics(self, client):
        """Test that models include metrics."""
        response = client.get("/admin/training/models")

        assert response.status_code == 200
        data = response.json()
        # Check first model has metrics
        model = data["models"][0]
        assert "metrics" in model
        assert "mAP" in model["metrics"]
        assert model["metrics"]["mAP"] is not None
        assert "precision" in model["metrics"]
        assert "recall" in model["metrics"]

    def test_get_training_models_includes_download_url(self, client):
        """Test that completed models have download URLs."""
        response = client.get("/admin/training/models")

        assert response.status_code == 200
        data = response.json()
        # Check completed models have download URLs
        for model in data["models"]:
            if model["status"] == "completed":
                assert "download_url" in model
                assert model["download_url"] is not None

    def test_get_training_models_filter_by_status(self, client):
        """Test filtering models by status."""
        response = client.get("/admin/training/models?status=completed")

        assert response.status_code == 200
        data = response.json()
        # All returned models should be completed
        for model in data["models"]:
            assert model["status"] == "completed"

    def test_get_training_models_pagination(self, client):
        """Test pagination for models."""
        response = client.get("/admin/training/models?limit=1&offset=0")

        assert response.status_code == 200
        data = response.json()
        assert data["limit"] == 1
        assert data["offset"] == 0
        assert len(data["models"]) == 1