Add more tests
1
tests/integration/__init__.py
Normal file
@@ -0,0 +1 @@
"""Integration tests for invoice-master-poc-v2."""
1
tests/integration/api/__init__.py
Normal file
@@ -0,0 +1 @@
"""API integration tests."""
389
tests/integration/api/test_api_integration.py
Normal file
@@ -0,0 +1,389 @@
"""
API Integration Tests

Tests FastAPI endpoints with mocked services.
These tests verify the API layer works correctly with the service layer.
"""

import io
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, patch

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient


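# MockServiceResult mirrors the attributes the API layer reads from the real
# inference result (fields, confidence, processing_time_ms, ...). The exact
# shape is an assumption for these tests; only attribute access must line up.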
@dataclass
class MockServiceResult:
    """Mock result from inference service."""

    document_id: str = "test-doc-123"
    success: bool = True
    document_type: str = "invoice"
    fields: dict[str, str] = field(default_factory=lambda: {
        "InvoiceNumber": "INV-2024-001",
        "Amount": "1500.00",
        "InvoiceDate": "2024-01-15",
        "OCR": "12345678901234",
        "Bankgiro": "1234-5678",
    })
    confidence: dict[str, float] = field(default_factory=lambda: {
        "InvoiceNumber": 0.95,
        "Amount": 0.92,
        "InvoiceDate": 0.88,
        "OCR": 0.95,
        "Bankgiro": 0.90,
    })
    detections: list[dict[str, Any]] = field(default_factory=list)
    processing_time_ms: float = 150.5
    visualization_path: Path | None = None
    errors: list[str] = field(default_factory=list)


@pytest.fixture
def temp_storage_dir():
    """Create temporary storage directories."""
    with tempfile.TemporaryDirectory() as tmpdir:
        base = Path(tmpdir)
        uploads_dir = base / "uploads" / "inference"
        results_dir = base / "results"
        uploads_dir.mkdir(parents=True, exist_ok=True)
        results_dir.mkdir(parents=True, exist_ok=True)
        yield {
            "base": base,
            "uploads": uploads_dir,
            "results": results_dir,
        }


@pytest.fixture
def mock_inference_service():
    """Create a mock inference service."""
    service = MagicMock()
    service.is_initialized = True
    service.gpu_available = False

    # Create a realistic mock result
    mock_result = MockServiceResult()

    service.process_pdf.return_value = mock_result
    service.process_image.return_value = mock_result
    service.initialize.return_value = None

    return service


@pytest.fixture
def mock_storage_config(temp_storage_dir):
    """Create mock storage configuration."""
    from inference.web.config import StorageConfig

    return StorageConfig(
        upload_dir=temp_storage_dir["uploads"],
        result_dir=temp_storage_dir["results"],
        max_file_size_mb=50,
    )


@pytest.fixture
def mock_storage_helper(temp_storage_dir):
    """Create a mock storage helper."""
    helper = MagicMock()
    helper.get_uploads_base_path.return_value = temp_storage_dir["uploads"]
    helper.get_result_local_path.return_value = None
    helper.result_exists.return_value = False
    return helper


@pytest.fixture
def test_app(mock_inference_service, mock_storage_config, mock_storage_helper):
    """Create a test FastAPI application with mocked storage."""
    from inference.web.api.v1.public.inference import create_inference_router

    app = FastAPI()

    # Patch get_storage_helper to return our mock
    with patch(
        "inference.web.api.v1.public.inference.get_storage_helper",
        return_value=mock_storage_helper,
    ):
        inference_router = create_inference_router(mock_inference_service, mock_storage_config)
        app.include_router(inference_router)

    return app


@pytest.fixture
def client(test_app, mock_storage_helper):
    """Create a test client with storage helper patched."""
    with patch(
        "inference.web.api.v1.public.inference.get_storage_helper",
        return_value=mock_storage_helper,
    ):
        yield TestClient(test_app)


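# The tests below re-apply the get_storage_helper patch around each request:
# the helper appears to be resolved at request time inside the router module
# (an assumption based on the patch target), so patching only at router
# creation in test_app would not be enough.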
class TestHealthEndpoint:
    """Tests for health check endpoint."""

    def test_health_check(self, client, mock_inference_service):
        """Test health check returns status."""
        response = client.get("/api/v1/health")

        assert response.status_code == 200
        data = response.json()
        assert "status" in data
        assert "model_loaded" in data


class TestInferenceEndpoint:
    """Tests for inference endpoint."""

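    # The byte strings used below are recognizable headers, not parseable
    # documents; the service layer is mocked, so only content-type and
    # extension validation ever sees these bytes.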
    def test_infer_pdf(self, client, mock_inference_service, mock_storage_helper, temp_storage_dir):
        """Test PDF inference endpoint."""
        # Create a minimal PDF content
        pdf_content = b"%PDF-1.4\n%test\n"

        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
            )

        assert response.status_code == 200
        data = response.json()
        assert "result" in data
        assert data["result"]["success"] is True
        assert "InvoiceNumber" in data["result"]["fields"]

    def test_infer_image(self, client, mock_inference_service, mock_storage_helper):
        """Test image inference endpoint."""
        # Create minimal PNG header
        png_header = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100

        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.png", io.BytesIO(png_header), "image/png")},
            )

        assert response.status_code == 200
        data = response.json()
        assert "result" in data

    def test_infer_invalid_file_type(self, client, mock_storage_helper):
        """Test rejection of invalid file types."""
        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.txt", io.BytesIO(b"hello"), "text/plain")},
            )

        assert response.status_code == 400

    def test_infer_no_file(self, client, mock_storage_helper):
        """Test rejection when no file provided."""
        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post("/api/v1/infer")

        assert response.status_code == 422  # Validation error

    def test_infer_result_structure(self, client, mock_inference_service, mock_storage_helper):
        """Test that result has expected structure."""
        pdf_content = b"%PDF-1.4\n%test\n"

        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
            )

        data = response.json()
        result = data["result"]

        # Check required fields
        assert "document_id" in result
        assert "success" in result
        assert "fields" in result
        assert "confidence" in result
        assert "processing_time_ms" in result


class TestInferenceResultFormat:
    """Tests for inference result formatting."""

    def test_result_fields_mapped_correctly(self, client, mock_inference_service, mock_storage_helper):
        """Test that fields are mapped to API response format."""
        pdf_content = b"%PDF-1.4\n%test\n"

        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
            )

        data = response.json()
        fields = data["result"]["fields"]

        assert fields["InvoiceNumber"] == "INV-2024-001"
        assert fields["Amount"] == "1500.00"
        assert fields["InvoiceDate"] == "2024-01-15"

    def test_confidence_values_included(self, client, mock_inference_service, mock_storage_helper):
        """Test that confidence values are included."""
        pdf_content = b"%PDF-1.4\n%test\n"

        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
            )

        data = response.json()
        confidence = data["result"]["confidence"]

        assert "InvoiceNumber" in confidence
        assert confidence["InvoiceNumber"] == 0.95


class TestErrorHandling:
    """Tests for error handling in API."""

    def test_service_error_handling(self, client, mock_inference_service, mock_storage_helper):
        """Test handling of service errors."""
        mock_inference_service.process_pdf.side_effect = Exception("Processing failed")

        pdf_content = b"%PDF-1.4\n%test\n"
        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
            )

        # Should return error response
        assert response.status_code >= 400

    def test_empty_file_handling(self, client, mock_storage_helper):
        """Test handling of empty files."""
        # Empty file still has valid content type
        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", io.BytesIO(b""), "application/pdf")},
            )

        # Empty file may be processed or rejected depending on implementation
        # Just verify we get a response
        assert response.status_code in [200, 400, 422, 500]


class TestResponseFormat:
    """Tests for API response format consistency."""

    def test_success_response_format(self, client, mock_inference_service, mock_storage_helper):
        """Test successful response format."""
        pdf_content = b"%PDF-1.4\n%test\n"

        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
            )

        assert response.status_code == 200
        assert response.headers["content-type"] == "application/json"

        data = response.json()
        assert isinstance(data, dict)
        assert "result" in data

    def test_json_serialization(self, client, mock_inference_service, mock_storage_helper):
        """Test that all result fields are JSON serializable."""
        pdf_content = b"%PDF-1.4\n%test\n"

        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
            )

        # If this doesn't raise, JSON is valid
        data = response.json()
        assert data is not None


class TestDocumentIdGeneration:
    """Tests for document ID handling."""

    def test_document_id_generated(self, client, mock_inference_service, mock_storage_helper):
        """Test that document ID is generated."""
        pdf_content = b"%PDF-1.4\n%test\n"

        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
            )

        data = response.json()
        assert "document_id" in data["result"]
        assert data["result"]["document_id"] is not None

    def test_document_id_from_filename(self, client, mock_inference_service, mock_storage_helper):
        """Test document ID derived from filename."""
        pdf_content = b"%PDF-1.4\n%test\n"

        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("my_invoice_123.pdf", io.BytesIO(pdf_content), "application/pdf")},
            )

        data = response.json()
        # Document ID should be set (either from filename or generated)
        assert data["result"]["document_id"] is not None
400
tests/integration/api/test_dashboard_api_integration.py
Normal file
@@ -0,0 +1,400 @@
"""
Dashboard API Integration Tests

Tests Dashboard API endpoints with real database operations via TestClient.
"""

from datetime import datetime, timezone
from uuid import uuid4

from fastapi import FastAPI
from fastapi.testclient import TestClient

from inference.data.admin_models import (
    AdminAnnotation,
    AdminDocument,
    ModelVersion,
    TrainingTask,
)
from inference.web.api.v1.admin.dashboard import create_dashboard_router
from inference.web.core.auth import get_admin_token_dep


def create_test_app(override_token_dep):
    """Create a FastAPI test application with dashboard router."""
    app = FastAPI()
    router = create_dashboard_router()
    app.include_router(router)

    # Override auth dependency
    app.dependency_overrides[get_admin_token_dep] = lambda: override_token_dep

    return app


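# create_test_app overrides auth with the raw token string; the dashboard
# routes are assumed to need only the token value, not the AdminToken row,
# which is why the tests pass admin_token.token rather than the model instance.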
class TestDashboardStatsEndpoint:
    """Tests for GET /admin/dashboard/stats endpoint."""

    def test_stats_empty_database(self, patched_session, admin_token):
        """Test stats endpoint with empty database."""
        app = create_test_app(admin_token.token)
        client = TestClient(app)

        response = client.get("/admin/dashboard/stats")

        assert response.status_code == 200
        data = response.json()
        assert data["total_documents"] == 0
        assert data["annotation_complete"] == 0
        assert data["annotation_incomplete"] == 0
        assert data["pending"] == 0
        assert data["completeness_rate"] == 0.0

    def test_stats_with_pending_documents(self, patched_session, admin_token):
        """Test stats with pending documents."""
        session = patched_session

        # Create pending documents
        for i in range(3):
            doc = AdminDocument(
                document_id=uuid4(),
                admin_token=admin_token.token,
                filename=f"pending_{i}.pdf",
                file_size=1024,
                content_type="application/pdf",
                file_path=f"/uploads/pending_{i}.pdf",
                page_count=1,
                status="pending",
                upload_source="ui",
                category="invoice",
                created_at=datetime.now(timezone.utc),
                updated_at=datetime.now(timezone.utc),
            )
            session.add(doc)
        session.commit()

        app = create_test_app(admin_token.token)
        client = TestClient(app)

        response = client.get("/admin/dashboard/stats")

        assert response.status_code == 200
        data = response.json()
        assert data["total_documents"] == 3
        assert data["pending"] == 3

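    # Completeness criterion (inferred from the assertions below, not from the
    # service code): a labeled document counts as annotation_complete once it
    # has both an identifier annotation (invoice_number) and a payment
    # annotation (bankgiro).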
    def test_stats_with_complete_annotations(self, patched_session, admin_token):
        """Test stats with complete annotations."""
        session = patched_session

        # Create labeled document with complete annotations
        doc = AdminDocument(
            document_id=uuid4(),
            admin_token=admin_token.token,
            filename="complete.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/uploads/complete.pdf",
            page_count=1,
            status="labeled",
            upload_source="ui",
            category="invoice",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(doc)
        session.commit()

        # Add identifier and payment annotations
        session.add(AdminAnnotation(
            annotation_id=uuid4(),
            document_id=doc.document_id,
            page_number=1,
            class_id=0,  # invoice_number
            class_name="invoice_number",
            x_center=0.5, y_center=0.1, width=0.2, height=0.05,
            bbox_x=400, bbox_y=80, bbox_width=160, bbox_height=40,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        ))
        session.add(AdminAnnotation(
            annotation_id=uuid4(),
            document_id=doc.document_id,
            page_number=1,
            class_id=4,  # bankgiro
            class_name="bankgiro",
            x_center=0.5, y_center=0.2, width=0.2, height=0.05,
            bbox_x=400, bbox_y=160, bbox_width=160, bbox_height=40,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        ))
        session.commit()

        app = create_test_app(admin_token.token)
        client = TestClient(app)

        response = client.get("/admin/dashboard/stats")

        assert response.status_code == 200
        data = response.json()
        assert data["annotation_complete"] == 1
        assert data["completeness_rate"] == 100.0


class TestActiveModelEndpoint:
    """Tests for GET /admin/dashboard/active-model endpoint."""

    def test_active_model_none(self, patched_session, admin_token):
        """Test active-model endpoint with no active model."""
        app = create_test_app(admin_token.token)
        client = TestClient(app)

        response = client.get("/admin/dashboard/active-model")

        assert response.status_code == 200
        data = response.json()
        assert data["model"] is None
        assert data["running_training"] is None

    def test_active_model_with_model(self, patched_session, admin_token, sample_dataset):
        """Test active-model endpoint with active model."""
        session = patched_session

        # Create training task
        task = TrainingTask(
            task_id=uuid4(),
            admin_token=admin_token.token,
            name="Test Task",
            status="completed",
            task_type="train",
            dataset_id=sample_dataset.dataset_id,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(task)
        session.commit()

        # Create active model
        model = ModelVersion(
            version_id=uuid4(),
            version="1.0.0",
            name="Test Model",
            model_path="/models/test.pt",
            status="active",
            is_active=True,
            task_id=task.task_id,
            dataset_id=sample_dataset.dataset_id,
            metrics_mAP=0.90,
            metrics_precision=0.88,
            metrics_recall=0.85,
            document_count=100,
            file_size=50000000,
            activated_at=datetime.now(timezone.utc),
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(model)
        session.commit()

        app = create_test_app(admin_token.token)
        client = TestClient(app)

        response = client.get("/admin/dashboard/active-model")

        assert response.status_code == 200
        data = response.json()
        assert data["model"] is not None
        assert data["model"]["version"] == "1.0.0"
        assert data["model"]["name"] == "Test Model"
        assert data["model"]["metrics_mAP"] == 0.90

    def test_active_model_with_running_training(self, patched_session, admin_token, sample_dataset):
        """Test active-model endpoint with running training."""
        session = patched_session

        # Create running training task
        task = TrainingTask(
            task_id=uuid4(),
            admin_token=admin_token.token,
            name="Running Task",
            status="running",
            task_type="train",
            dataset_id=sample_dataset.dataset_id,
            started_at=datetime.now(timezone.utc),
            progress=50,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(task)
        session.commit()

        app = create_test_app(admin_token.token)
        client = TestClient(app)

        response = client.get("/admin/dashboard/active-model")

        assert response.status_code == 200
        data = response.json()
        assert data["running_training"] is not None
        assert data["running_training"]["name"] == "Running Task"
        assert data["running_training"]["status"] == "running"
        assert data["running_training"]["progress"] == 50


class TestRecentActivityEndpoint:
    """Tests for GET /admin/dashboard/activity endpoint."""

    def test_activity_empty(self, patched_session, admin_token):
        """Test activity endpoint with no activities."""
        app = create_test_app(admin_token.token)
        client = TestClient(app)

        response = client.get("/admin/dashboard/activity")

        assert response.status_code == 200
        data = response.json()
        assert data["activities"] == []

    def test_activity_with_uploads(self, patched_session, admin_token):
        """Test activity includes document uploads."""
        session = patched_session

        # Create documents
        for i in range(3):
            doc = AdminDocument(
                document_id=uuid4(),
                admin_token=admin_token.token,
                filename=f"activity_{i}.pdf",
                file_size=1024,
                content_type="application/pdf",
                file_path=f"/uploads/activity_{i}.pdf",
                page_count=1,
                status="pending",
                upload_source="ui",
                category="invoice",
                created_at=datetime.now(timezone.utc),
                updated_at=datetime.now(timezone.utc),
            )
            session.add(doc)
        session.commit()

        app = create_test_app(admin_token.token)
        client = TestClient(app)

        response = client.get("/admin/dashboard/activity")

        assert response.status_code == 200
        data = response.json()
        upload_activities = [a for a in data["activities"] if a["type"] == "document_uploaded"]
        assert len(upload_activities) == 3

    def test_activity_limit_parameter(self, patched_session, admin_token):
        """Test activity limit parameter."""
        session = patched_session

        # Create many documents
        for i in range(15):
            doc = AdminDocument(
                document_id=uuid4(),
                admin_token=admin_token.token,
                filename=f"limit_{i}.pdf",
                file_size=1024,
                content_type="application/pdf",
                file_path=f"/uploads/limit_{i}.pdf",
                page_count=1,
                status="pending",
                upload_source="ui",
                category="invoice",
                created_at=datetime.now(timezone.utc),
                updated_at=datetime.now(timezone.utc),
            )
            session.add(doc)
        session.commit()

        app = create_test_app(admin_token.token)
        client = TestClient(app)

        response = client.get("/admin/dashboard/activity?limit=5")

        assert response.status_code == 200
        data = response.json()
        assert len(data["activities"]) <= 5

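    # The limit query parameter is validated server-side: limit=5 passes while
    # 0 and 100 are rejected, so the bound looks like Query(ge=1) with an upper
    # bound below 100 (the exact value is an assumption, not shown by these tests).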
    def test_activity_invalid_limit(self, patched_session, admin_token):
        """Test activity with invalid limit parameter."""
        app = create_test_app(admin_token.token)
        client = TestClient(app)

        # Limit too high
        response = client.get("/admin/dashboard/activity?limit=100")
        assert response.status_code == 422

        # Limit too low
        response = client.get("/admin/dashboard/activity?limit=0")
        assert response.status_code == 422

    def test_activity_with_training_completion(self, patched_session, admin_token, sample_dataset):
        """Test activity includes training completions."""
        session = patched_session

        # Create completed training task
        task = TrainingTask(
            task_id=uuid4(),
            admin_token=admin_token.token,
            name="Completed Task",
            status="completed",
            task_type="train",
            dataset_id=sample_dataset.dataset_id,
            metrics_mAP=0.95,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(task)
        session.commit()

        app = create_test_app(admin_token.token)
        client = TestClient(app)

        response = client.get("/admin/dashboard/activity")

        assert response.status_code == 200
        data = response.json()
        training_activities = [a for a in data["activities"] if a["type"] == "training_completed"]
        assert len(training_activities) >= 1

    def test_activity_sorted_by_timestamp(self, patched_session, admin_token):
        """Test activities are sorted by timestamp descending."""
        session = patched_session

        # Create documents
        for i in range(5):
            doc = AdminDocument(
                document_id=uuid4(),
                admin_token=admin_token.token,
                filename=f"sorted_{i}.pdf",
                file_size=1024,
                content_type="application/pdf",
                file_path=f"/uploads/sorted_{i}.pdf",
                page_count=1,
                status="pending",
                upload_source="ui",
                category="invoice",
                created_at=datetime.now(timezone.utc),
                updated_at=datetime.now(timezone.utc),
            )
            session.add(doc)
        session.commit()

        app = create_test_app(admin_token.token)
        client = TestClient(app)

        response = client.get("/admin/dashboard/activity")

        assert response.status_code == 200
        data = response.json()
        timestamps = [a["timestamp"] for a in data["activities"]]
        assert timestamps == sorted(timestamps, reverse=True)
465
tests/integration/conftest.py
Normal file
@@ -0,0 +1,465 @@
"""
Integration Test Fixtures

Provides shared fixtures for integration tests using PostgreSQL.

IMPORTANT: Integration tests MUST use Docker testcontainers for database isolation.
This ensures tests never touch the real production/development database.

Supported modes:
1. Docker testcontainers (default): Automatically starts a PostgreSQL container
2. TEST_DB_URL environment variable: Use a dedicated test database (NOT production!)

To use an external test database, set:
    TEST_DB_URL=postgresql://user:password@host:port/test_dbname
"""

import os
import tempfile
from contextlib import ExitStack, contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Generator
from unittest.mock import patch
from uuid import uuid4

import pytest
from sqlmodel import Session, SQLModel, create_engine

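# All admin models are imported (even those not referenced below) so that
# every table is registered on SQLModel.metadata before create_all() runs.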
from inference.data.admin_models import (
    AdminAnnotation,
    AdminDocument,
    AdminToken,
    AnnotationHistory,
    BatchUpload,
    BatchUploadFile,
    DatasetDocument,
    ModelVersion,
    TrainingDataset,
    TrainingDocumentLink,
    TrainingLog,
    TrainingTask,
)


# =============================================================================
# Database Fixtures
# =============================================================================


def _is_docker_available() -> bool:
    """Check if Docker is available."""
    try:
        import docker

        client = docker.from_env()
        client.ping()
        return True
    except Exception:
        return False


def _get_test_db_url() -> str | None:
    """Get test database URL from environment."""
    return os.environ.get("TEST_DB_URL")


@pytest.fixture(scope="session")
def test_engine():
    """Create a SQLAlchemy engine for testing.

    Uses one of:
    1. TEST_DB_URL environment variable (dedicated test database)
    2. Docker testcontainers (if Docker is available)

    IMPORTANT: Will NOT fall back to production database. If Docker is not
    available and TEST_DB_URL is not set, tests will fail with a clear error.

    The engine is shared across all tests in a session for efficiency.
    """
    # Try to get URL from environment first
    connection_url = _get_test_db_url()

    if connection_url:
        # Use external test database from environment.
        # Warn if it looks like a production database.
        if "docmaster" in connection_url and "_test" not in connection_url:
            import warnings

            warnings.warn(
                "TEST_DB_URL appears to point to a production database. "
                "Please use a dedicated test database (e.g., docmaster_test).",
                UserWarning,
            )
    elif _is_docker_available():
        # Use testcontainers - this is the recommended approach
        from testcontainers.postgres import PostgresContainer

        postgres = PostgresContainer("postgres:15-alpine")
        postgres.start()

        connection_url = postgres.get_connection_url()
        if "psycopg2" in connection_url:
            connection_url = connection_url.replace("postgresql+psycopg2://", "postgresql://")

        # Store container for cleanup
        test_engine._postgres_container = postgres
    else:
        # No Docker and no TEST_DB_URL - fail with clear instructions
        pytest.fail(
            "Integration tests require Docker or a TEST_DB_URL environment variable.\n\n"
            "Option 1 (Recommended): Install Docker Desktop and ensure it's running.\n"
            "  - Windows: https://docs.docker.com/desktop/install/windows-install/\n"
            "  - The testcontainers library will automatically create a PostgreSQL container.\n\n"
            "Option 2: Set TEST_DB_URL to a dedicated test database:\n"
            "  - export TEST_DB_URL=postgresql://user:password@host:port/test_dbname\n"
            "  - NEVER use your production database for tests!\n\n"
            "Integration tests will NOT fall back to the production database."
        )

    engine = create_engine(
        connection_url,
        echo=False,
        pool_pre_ping=True,
    )

    # Create all tables
    SQLModel.metadata.create_all(engine)

    yield engine

    # Cleanup
    SQLModel.metadata.drop_all(engine)
    engine.dispose()

    # Stop container if we started one
    if hasattr(test_engine, "_postgres_container"):
        test_engine._postgres_container.stop()


@pytest.fixture(scope="function")
def db_session(test_engine) -> Generator[Session, None, None]:
    """Provide a database session for each test function.

    Each test gets a fresh session that rolls back after the test,
    ensuring test isolation.
    """
    connection = test_engine.connect()
    transaction = connection.begin()
    session = Session(bind=connection)

    yield session

    # Rollback and cleanup
    session.close()
    transaction.rollback()
    connection.close()


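# Note on db_session isolation: session.commit() inside a test only flushes
# work within the outer transaction opened on `connection`; the rollback in
# db_session discards everything afterwards, so committed fixtures never leak
# between tests (assuming no code opens its own engine connection).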
@pytest.fixture(scope="function")
def patched_session(db_session):
    """Patch get_session_context to use the test session.

    This allows repository classes to use the test database session
    instead of creating their own connections.

    We need to patch in multiple locations because each repository module
    imports get_session_context directly.
    """

    @contextmanager
    def mock_session_context() -> Generator[Session, None, None]:
        yield db_session

    # All modules that import get_session_context
    patch_targets = [
        "inference.data.database.get_session_context",
        "inference.data.repositories.document_repository.get_session_context",
        "inference.data.repositories.annotation_repository.get_session_context",
        "inference.data.repositories.dataset_repository.get_session_context",
        "inference.data.repositories.training_task_repository.get_session_context",
        "inference.data.repositories.model_version_repository.get_session_context",
        "inference.data.repositories.batch_upload_repository.get_session_context",
        "inference.data.repositories.token_repository.get_session_context",
        "inference.web.services.dashboard_service.get_session_context",
    ]

    with ExitStack() as stack:
        for target in patch_targets:
            try:
                stack.enter_context(patch(target, mock_session_context))
            except (ModuleNotFoundError, AttributeError):
                # Skip if module doesn't exist or doesn't have the attribute
                pass
        yield db_session


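# Why patched_session patches per-module: `from inference.data.database import
# get_session_context` binds the name into each importing module's own
# namespace at import time, so patching only inference.data.database would not
# affect modules that already hold a reference to the original function.
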
# =============================================================================
# Test Data Fixtures
# =============================================================================


@pytest.fixture
def admin_token(db_session) -> AdminToken:
    """Create a test admin token."""
    token = AdminToken(
        token="test-admin-token-12345",
        name="Test Admin",
        is_active=True,
        created_at=datetime.now(timezone.utc),
    )
    db_session.add(token)
    db_session.commit()
    db_session.refresh(token)
    return token


@pytest.fixture
def sample_document(db_session, admin_token) -> AdminDocument:
    """Create a sample document for testing."""
    doc = AdminDocument(
        document_id=uuid4(),
        admin_token=admin_token.token,
        filename="test_invoice.pdf",
        file_size=1024,
        content_type="application/pdf",
        file_path="/uploads/test_invoice.pdf",
        page_count=1,
        status="pending",
        upload_source="ui",
        category="invoice",
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(doc)
    db_session.commit()
    db_session.refresh(doc)
    return doc


@pytest.fixture
def sample_annotation(db_session, sample_document) -> AdminAnnotation:
    """Create a sample annotation for testing."""
    annotation = AdminAnnotation(
        annotation_id=uuid4(),
        document_id=sample_document.document_id,
        page_number=1,
        class_id=0,
        class_name="invoice_number",
        x_center=0.5,
        y_center=0.3,
        width=0.2,
        height=0.05,
        bbox_x=400,
        bbox_y=240,
        bbox_width=160,
        bbox_height=40,
        text_value="INV-2024-001",
        confidence=0.95,
        source="auto",
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(annotation)
    db_session.commit()
    db_session.refresh(annotation)
    return annotation


@pytest.fixture
def sample_dataset(db_session) -> TrainingDataset:
    """Create a sample training dataset for testing."""
    dataset = TrainingDataset(
        dataset_id=uuid4(),
        name="Test Dataset",
        description="Dataset for integration testing",
        status="building",
        train_ratio=0.8,
        val_ratio=0.1,
        seed=42,
        total_documents=0,
        total_images=0,
        total_annotations=0,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(dataset)
    db_session.commit()
    db_session.refresh(dataset)
    return dataset


@pytest.fixture
def sample_training_task(db_session, admin_token, sample_dataset) -> TrainingTask:
    """Create a sample training task for testing."""
    task = TrainingTask(
        task_id=uuid4(),
        admin_token=admin_token.token,
        name="Test Training Task",
        description="Training task for integration testing",
        status="pending",
        task_type="train",
        dataset_id=sample_dataset.dataset_id,
        config={"epochs": 10, "batch_size": 16},
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(task)
    db_session.commit()
    db_session.refresh(task)
    return task


@pytest.fixture
def sample_model_version(db_session, sample_training_task, sample_dataset) -> ModelVersion:
    """Create a sample model version for testing."""
    version = ModelVersion(
        version_id=uuid4(),
        version="1.0.0",
        name="Test Model v1",
        description="Model version for integration testing",
        model_path="/models/test_model.pt",
        status="inactive",
        is_active=False,
        task_id=sample_training_task.task_id,
        dataset_id=sample_dataset.dataset_id,
        metrics_mAP=0.85,
        metrics_precision=0.88,
        metrics_recall=0.82,
        document_count=100,
        file_size=50000000,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(version)
    db_session.commit()
    db_session.refresh(version)
    return version


@pytest.fixture
def sample_batch_upload(db_session, admin_token) -> BatchUpload:
    """Create a sample batch upload for testing."""
    batch = BatchUpload(
        batch_id=uuid4(),
        admin_token=admin_token.token,
        filename="test_batch.zip",
        file_size=10240,
        upload_source="api",
        status="processing",
        total_files=5,
        processed_files=0,
        successful_files=0,
        failed_files=0,
        created_at=datetime.now(timezone.utc),
    )
    db_session.add(batch)
    db_session.commit()
    db_session.refresh(batch)
    return batch


# =============================================================================
# Multiple Documents Fixture
# =============================================================================


@pytest.fixture
def multiple_documents(db_session, admin_token) -> list[AdminDocument]:
    """Create multiple documents for pagination/filtering tests."""
    documents = []
    statuses = ["pending", "pending", "labeled", "labeled", "exported"]
    categories = ["invoice", "invoice", "invoice", "letter", "invoice"]

    for i, (status, category) in enumerate(zip(statuses, categories)):
        doc = AdminDocument(
            document_id=uuid4(),
            admin_token=admin_token.token,
            filename=f"test_doc_{i}.pdf",
            file_size=1024 + i * 100,
            content_type="application/pdf",
            file_path=f"/uploads/test_doc_{i}.pdf",
            page_count=1,
            status=status,
            upload_source="ui",
            category=category,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        db_session.add(doc)
        documents.append(doc)

    db_session.commit()
    for doc in documents:
        db_session.refresh(doc)

    return documents


# =============================================================================
# Temporary File Fixtures
# =============================================================================


@pytest.fixture
def temp_upload_dir() -> Generator[Path, None, None]:
    """Create a temporary directory for file uploads."""
    with tempfile.TemporaryDirectory() as tmpdir:
        upload_dir = Path(tmpdir) / "uploads"
        upload_dir.mkdir(parents=True, exist_ok=True)
        yield upload_dir


@pytest.fixture
def temp_model_dir() -> Generator[Path, None, None]:
    """Create a temporary directory for model files."""
    with tempfile.TemporaryDirectory() as tmpdir:
        model_dir = Path(tmpdir) / "models"
        model_dir.mkdir(parents=True, exist_ok=True)
        yield model_dir


@pytest.fixture
def temp_dataset_dir() -> Generator[Path, None, None]:
    """Create a temporary directory for dataset files."""
    with tempfile.TemporaryDirectory() as tmpdir:
        dataset_dir = Path(tmpdir) / "datasets"
        dataset_dir.mkdir(parents=True, exist_ok=True)
        yield dataset_dir


# =============================================================================
# Sample PDF Fixture
# =============================================================================


@pytest.fixture
def sample_pdf_bytes() -> bytes:
    """Return minimal valid PDF bytes for testing."""
    # Minimal valid PDF structure
    return b"""%PDF-1.4
1 0 obj
<< /Type /Catalog /Pages 2 0 R >>
endobj
2 0 obj
<< /Type /Pages /Kids [3 0 R] /Count 1 >>
endobj
3 0 obj
<< /Type /Page /Parent 2 0 R /MediaBox [0 0 612 792] >>
endobj
xref
0 4
0000000000 65535 f
0000000009 00000 n
0000000058 00000 n
0000000115 00000 n
trailer
<< /Size 4 /Root 1 0 R >>
startxref
196
%%EOF"""


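# Note: the xref byte offsets in sample_pdf_bytes are approximate rather than
# exact; that is assumed to be acceptable here because nothing in these tests
# parses the file strictly.
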
@pytest.fixture
def sample_pdf_file(temp_upload_dir, sample_pdf_bytes) -> Path:
    """Create a sample PDF file for testing."""
    pdf_path = temp_upload_dir / "test_invoice.pdf"
    pdf_path.write_bytes(sample_pdf_bytes)
    return pdf_path
1
tests/integration/pipeline/__init__.py
Normal file
@@ -0,0 +1 @@
"""Pipeline integration tests."""
456
tests/integration/pipeline/test_pipeline_integration.py
Normal file
@@ -0,0 +1,456 @@
"""
Inference Pipeline Integration Tests

Tests the complete pipeline from input to output.
Note: These tests use mocks for YOLO and OCR to avoid requiring actual models,
but test the integration of pipeline components.
"""

from unittest.mock import MagicMock, patch

import pytest

from inference.pipeline.field_extractor import ExtractedField
from inference.pipeline.pipeline import (
    CrossValidationResult,
    InferencePipeline,
    InferenceResult,
)
from inference.pipeline.yolo_detector import Detection


@pytest.fixture
def mock_detection():
    """Create a mock detection."""
    return Detection(
        class_id=0,
        class_name="invoice_number",
        confidence=0.95,
        bbox=(100, 50, 200, 30),
        page_no=0,
    )


@pytest.fixture
def mock_extracted_field():
    """Create a mock extracted field."""
    return ExtractedField(
        field_name="InvoiceNumber",
        raw_text="INV-2024-001",
        normalized_value="INV-2024-001",
        confidence=0.95,
        bbox=(100, 50, 200, 30),
        page_no=0,
        is_valid=True,
    )


class TestInferenceResultConstruction:
    """Tests for InferenceResult construction and methods."""

    def test_default_result(self):
        """Test default InferenceResult values."""
        result = InferenceResult()

        assert result.document_id is None
        assert result.success is False
        assert result.fields == {}
        assert result.confidence == {}
        assert result.raw_detections == []
        assert result.extracted_fields == []
        assert result.errors == []
        assert result.fallback_used is False
        assert result.cross_validation is None

    def test_result_to_json(self):
        """Test JSON serialization of result."""
        result = InferenceResult(
            document_id="test-doc",
            success=True,
            fields={
                "InvoiceNumber": "INV-001",
                "Amount": "1500.00",
            },
            confidence={
                "InvoiceNumber": 0.95,
                "Amount": 0.92,
            },
            bboxes={
                "InvoiceNumber": (100, 50, 200, 30),
            },
        )

        json_data = result.to_json()

        assert json_data["DocumentId"] == "test-doc"
        assert json_data["success"] is True
        assert json_data["InvoiceNumber"] == "INV-001"
        assert json_data["Amount"] == "1500.00"
        assert json_data["confidence"]["InvoiceNumber"] == 0.95
        assert "bboxes" in json_data

    def test_result_get_field(self):
        """Test getting field value and confidence."""
        result = InferenceResult(
            fields={"InvoiceNumber": "INV-001"},
            confidence={"InvoiceNumber": 0.95},
        )

        value, conf = result.get_field("InvoiceNumber")
        assert value == "INV-001"
        assert conf == 0.95

        value, conf = result.get_field("Amount")
        assert value is None
        assert conf == 0.0


class TestCrossValidation:
    """Tests for cross-validation logic."""

    def test_cross_validation_default(self):
        """Test default CrossValidationResult values."""
        cv = CrossValidationResult()

        assert cv.is_valid is False
        assert cv.ocr_match is None
        assert cv.amount_match is None
        assert cv.bankgiro_match is None
        assert cv.plusgiro_match is None
        assert cv.payment_line_ocr is None
        assert cv.payment_line_amount is None
        assert cv.details == []

    def test_cross_validation_with_matches(self):
        """Test CrossValidationResult with matches."""
        cv = CrossValidationResult(
            is_valid=True,
            ocr_match=True,
            amount_match=True,
            bankgiro_match=True,
            payment_line_ocr="12345678901234",
            payment_line_amount="1500.00",
            payment_line_account="1234-5678",
            payment_line_account_type="bankgiro",
            details=["OCR match", "Amount match", "Bankgiro match"],
        )

        assert cv.is_valid is True
        assert cv.ocr_match is True
        assert cv.amount_match is True
        assert len(cv.details) == 3


class TestPipelineMergeFields:
    """Tests for field merging logic."""

    def test_merge_selects_highest_confidence(self):
        """Test that merge selects highest confidence for duplicate fields."""
        # Create mock pipeline with minimal mocking
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)
            pipeline.payment_line_parser = MagicMock()
            pipeline.payment_line_parser.parse.return_value = MagicMock(is_valid=False)

            result = InferenceResult()
            result.extracted_fields = [
                ExtractedField(
                    field_name="InvoiceNumber",
                    raw_text="INV-001",
                    normalized_value="INV-001",
                    confidence=0.85,
                    detection_confidence=0.90,
                    ocr_confidence=0.85,
                    bbox=(100, 50, 200, 30),
                    page_no=0,
                    is_valid=True,
                ),
                ExtractedField(
                    field_name="InvoiceNumber",
                    raw_text="INV-001",
                    normalized_value="INV-001",
                    confidence=0.95,  # Higher confidence
                    detection_confidence=0.95,
                    ocr_confidence=0.95,
                    bbox=(105, 52, 198, 28),
                    page_no=0,
                    is_valid=True,
                ),
            ]

            pipeline._merge_fields(result)

            assert result.fields["InvoiceNumber"] == "INV-001"
            assert result.confidence["InvoiceNumber"] == 0.95

    def test_merge_skips_invalid_fields(self):
        """Test that merge skips invalid extracted fields."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)
            pipeline.payment_line_parser = MagicMock()
            pipeline.payment_line_parser.parse.return_value = MagicMock(is_valid=False)

            result = InferenceResult()
            result.extracted_fields = [
                ExtractedField(
                    field_name="InvoiceNumber",
                    raw_text="",
                    normalized_value=None,
                    confidence=0.95,
                    detection_confidence=0.95,
                    ocr_confidence=0.95,
                    bbox=(100, 50, 200, 30),
                    page_no=0,
                    is_valid=False,  # Invalid
                ),
                ExtractedField(
                    field_name="Amount",
                    raw_text="1500.00",
                    normalized_value="1500.00",
                    confidence=0.92,
                    detection_confidence=0.92,
                    ocr_confidence=0.92,
                    bbox=(200, 100, 100, 25),
                    page_no=0,
                    is_valid=True,
                ),
            ]

            pipeline._merge_fields(result)

            assert "InvoiceNumber" not in result.fields
            assert result.fields["Amount"] == "1500.00"


class TestPaymentLineValidation:
    """Tests for payment line cross-validation."""

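    # Payment-line strings such as "# 12345678901234 # 1500 00 5 > 12345678#41#"
    # follow the Swedish payment-slip machine line: OCR reference, amount in
    # kronor and öre, a check/type digit, then the bankgiro/plusgiro account.
    # (Field order here is read off the sample values, not from a spec.)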
    def test_payment_line_overrides_ocr(self):
        """Test that payment line OCR overrides detected OCR."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            # Mock payment line parser
            mock_parsed = MagicMock()
            mock_parsed.is_valid = True
            mock_parsed.ocr_number = "12345678901234"
            mock_parsed.amount = "1500.00"
            mock_parsed.account_number = "12345678"

            pipeline.payment_line_parser = MagicMock()
            pipeline.payment_line_parser.parse.return_value = mock_parsed

            result = InferenceResult(
                fields={
                    "payment_line": "# 12345678901234 # 1500 00 5 > 12345678#41#",
                    "OCR": "99999999999999",  # Different OCR
                },
                confidence={"OCR": 0.85},
            )

            pipeline._cross_validate_payment_line(result)

            # Payment line OCR should override
            assert result.fields["OCR"] == "12345678901234"
            assert result.confidence["OCR"] == 0.95

    def test_payment_line_overrides_amount(self):
        """Test that payment line amount overrides detected amount."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            mock_parsed = MagicMock()
            mock_parsed.is_valid = True
            mock_parsed.ocr_number = None
            mock_parsed.amount = "2500.50"
            mock_parsed.account_number = None

            pipeline.payment_line_parser = MagicMock()
            pipeline.payment_line_parser.parse.return_value = mock_parsed

            result = InferenceResult(
                fields={
                    "payment_line": "# ... # 2500 50 5 > ...",
                    "Amount": "2500.00",  # Slightly different
                },
                confidence={"Amount": 0.80},
            )

            pipeline._cross_validate_payment_line(result)

            assert result.fields["Amount"] == "2500.50"
            assert result.confidence["Amount"] == 0.95

    def test_cross_validation_records_matches(self):
        """Test that cross-validation records match status."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            mock_parsed = MagicMock()
            mock_parsed.is_valid = True
            mock_parsed.ocr_number = "12345678901234"
            mock_parsed.amount = "1500.00"
            mock_parsed.account_number = "12345678"

            pipeline.payment_line_parser = MagicMock()
            pipeline.payment_line_parser.parse.return_value = mock_parsed

            result = InferenceResult(
                fields={
                    "payment_line": "# 12345678901234 # 1500 00 5 > 12345678#41#",
                    "OCR": "12345678901234",  # Matching
                    "Amount": "1500.00",  # Matching
                    "Bankgiro": "1234-5678",  # Matching
                },
                confidence={},
            )

            pipeline._cross_validate_payment_line(result)

            assert result.cross_validation is not None
            assert result.cross_validation.ocr_match is True
            assert result.cross_validation.amount_match is True
            assert result.cross_validation.is_valid is True


class TestFallbackLogic:
    """Tests for fallback detection logic."""

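    # The key fields appear to be Amount, InvoiceNumber and OCR: fallback is
    # triggered with only Amount present and skipped with all three present.
    # The exact threshold in between is not pinned down by these two cases.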
    def test_needs_fallback_when_key_fields_missing(self):
        """Test fallback is triggered when key fields missing."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            # Only one key field present
            result = InferenceResult(fields={"Amount": "1500.00"})

            assert pipeline._needs_fallback(result) is True

    def test_no_fallback_when_fields_present(self):
        """Test no fallback when key fields present."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            # All key fields present
            result = InferenceResult(
                fields={
                    "Amount": "1500.00",
                    "InvoiceNumber": "INV-001",
                    "OCR": "12345678901234",
                }
            )

            assert pipeline._needs_fallback(result) is False


class TestPatternExtraction:
    """Tests for fallback pattern extraction."""

    def test_extract_amount_pattern(self):
        """Test amount extraction with regex."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            text = "Att betala: 1 500,00 SEK"
            result = InferenceResult()

            pipeline._extract_with_patterns(text, result)

            assert "Amount" in result.fields
            assert result.confidence["Amount"] == 0.5

    def test_extract_bankgiro_pattern(self):
        """Test bankgiro extraction with regex."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            text = "Bankgiro: 1234-5678"
            result = InferenceResult()

            pipeline._extract_with_patterns(text, result)

            assert "Bankgiro" in result.fields
            assert result.fields["Bankgiro"] == "1234-5678"

    def test_extract_ocr_pattern(self):
        """Test OCR extraction with regex."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            text = "OCR: 12345678901234567890"
            result = InferenceResult()

            pipeline._extract_with_patterns(text, result)

            assert "OCR" in result.fields
            assert result.fields["OCR"] == "12345678901234567890"

    def test_does_not_override_existing_fields(self):
        """Test pattern extraction doesn't override existing fields."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            text = "Fakturanr: 999"
            result = InferenceResult(fields={"InvoiceNumber": "INV-001"})

            pipeline._extract_with_patterns(text, result)

            # Should keep existing value
            assert result.fields["InvoiceNumber"] == "INV-001"


class TestAmountNormalization:
|
||||
"""Tests for amount normalization."""
|
||||
|
||||
def test_normalize_swedish_format(self):
|
||||
"""Test normalizing Swedish amount format."""
|
||||
with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
|
||||
pipeline = InferencePipeline.__new__(InferencePipeline)
|
||||
|
||||
# Swedish format: space as thousands separator, comma as decimal
|
||||
assert pipeline._normalize_amount_for_compare("1 500,00") == 1500.00
|
||||
# Standard format: dot as decimal
|
||||
assert pipeline._normalize_amount_for_compare("1500.00") == 1500.00
|
||||
# Swedish format with comma as decimal
|
||||
assert pipeline._normalize_amount_for_compare("1500,00") == 1500.00
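            # A minimal sketch of the normalization these assertions assume
            # (the real implementation may differ): drop spaces used as
            # thousands separators and treat a decimal comma as a dot, e.g.
            # float("1 500,00".replace(" ", "").replace(",", "."))  # -> 1500.0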

    def test_normalize_invalid_amount(self):
        """Test normalizing an invalid amount returns None."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            assert pipeline._normalize_amount_for_compare("invalid") is None
            assert pipeline._normalize_amount_for_compare("") is None


class TestResultSerialization:
    """Tests for result serialization with cross-validation."""

    def test_to_json_with_cross_validation(self):
        """Test JSON serialization includes cross-validation."""
        cv = CrossValidationResult(
            is_valid=True,
            ocr_match=True,
            amount_match=True,
            payment_line_ocr="12345678901234",
            payment_line_amount="1500.00",
            details=["OCR match", "Amount match"],
        )

        result = InferenceResult(
            document_id="test-doc",
            success=True,
            fields={"InvoiceNumber": "INV-001"},
            cross_validation=cv,
        )

        json_data = result.to_json()

        assert "cross_validation" in json_data
        assert json_data["cross_validation"]["is_valid"] is True
        assert json_data["cross_validation"]["ocr_match"] is True
        assert json_data["cross_validation"]["payment_line_ocr"] == "12345678901234"
1
tests/integration/repositories/__init__.py
Normal file
@@ -0,0 +1 @@
"""Repository integration tests."""
@@ -0,0 +1,464 @@
"""
Annotation Repository Integration Tests

Tests AnnotationRepository with real database operations.
"""

from uuid import uuid4

import pytest

from inference.data.repositories.annotation_repository import AnnotationRepository


class TestAnnotationRepositoryCreate:
    """Tests for annotation creation."""

    def test_create_annotation(self, patched_session, sample_document):
        """Test creating a single annotation."""
        repo = AnnotationRepository()
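        # Each annotation carries both normalized YOLO-style coordinates
        # (x_center/y_center/width/height in [0, 1]) and absolute pixel
        # bounds (bbox_*); presumably the former feed dataset export and
        # the latter the annotation UI.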

        ann_id = repo.create(
            document_id=str(sample_document.document_id),
            page_number=1,
            class_id=0,
            class_name="invoice_number",
            x_center=0.5,
            y_center=0.3,
            width=0.2,
            height=0.05,
            bbox_x=400,
            bbox_y=240,
            bbox_width=160,
            bbox_height=40,
            text_value="INV-2024-001",
            confidence=0.95,
            source="auto",
        )

        assert ann_id is not None

        ann = repo.get(ann_id)
        assert ann is not None
        assert ann.class_name == "invoice_number"
        assert ann.text_value == "INV-2024-001"
        assert ann.confidence == 0.95
        assert ann.source == "auto"

    def test_create_batch_annotations(self, patched_session, sample_document):
        """Test batch creation of annotations."""
        repo = AnnotationRepository()

        annotations_data = [
            {
                "document_id": str(sample_document.document_id),
                "page_number": 1,
                "class_id": 0,
                "class_name": "invoice_number",
                "x_center": 0.5,
                "y_center": 0.1,
                "width": 0.2,
                "height": 0.05,
                "bbox_x": 400,
                "bbox_y": 80,
                "bbox_width": 160,
                "bbox_height": 40,
                "text_value": "INV-001",
                "confidence": 0.95,
            },
            {
                "document_id": str(sample_document.document_id),
                "page_number": 1,
                "class_id": 1,
                "class_name": "invoice_date",
                "x_center": 0.5,
                "y_center": 0.2,
                "width": 0.15,
                "height": 0.04,
                "bbox_x": 400,
                "bbox_y": 160,
                "bbox_width": 120,
                "bbox_height": 32,
                "text_value": "2024-01-15",
                "confidence": 0.92,
            },
            {
                "document_id": str(sample_document.document_id),
                "page_number": 1,
                "class_id": 6,
                "class_name": "amount",
                "x_center": 0.7,
                "y_center": 0.8,
                "width": 0.1,
                "height": 0.04,
                "bbox_x": 560,
                "bbox_y": 640,
                "bbox_width": 80,
                "bbox_height": 32,
                "text_value": "1500.00",
                "confidence": 0.98,
            },
        ]

        ids = repo.create_batch(annotations_data)

        assert len(ids) == 3

        # Verify all annotations exist
        for ann_id in ids:
            ann = repo.get(ann_id)
            assert ann is not None


class TestAnnotationRepositoryRead:
    """Tests for annotation retrieval."""

    def test_get_nonexistent_annotation(self, patched_session):
        """Test getting an annotation that doesn't exist."""
        repo = AnnotationRepository()

        ann = repo.get(str(uuid4()))
        assert ann is None

    def test_get_annotations_for_document(self, patched_session, sample_document, sample_annotation):
        """Test getting all annotations for a document."""
        repo = AnnotationRepository()

        # Add another annotation
        repo.create(
            document_id=str(sample_document.document_id),
            page_number=1,
            class_id=1,
            class_name="invoice_date",
            x_center=0.5,
            y_center=0.4,
            width=0.15,
            height=0.04,
            bbox_x=400,
            bbox_y=320,
            bbox_width=120,
            bbox_height=32,
            text_value="2024-01-15",
        )

        annotations = repo.get_for_document(str(sample_document.document_id))

        assert len(annotations) == 2
        # Should be ordered by class_id
        assert annotations[0].class_id == 0
        assert annotations[1].class_id == 1

    def test_get_annotations_for_specific_page(self, patched_session, sample_document):
        """Test getting annotations for a specific page."""
        repo = AnnotationRepository()

        # Create annotations on different pages
        repo.create(
            document_id=str(sample_document.document_id),
            page_number=1,
            class_id=0,
            class_name="invoice_number",
            x_center=0.5,
            y_center=0.1,
            width=0.2,
            height=0.05,
            bbox_x=400,
            bbox_y=80,
            bbox_width=160,
            bbox_height=40,
        )
        repo.create(
            document_id=str(sample_document.document_id),
            page_number=2,
            class_id=6,
            class_name="amount",
            x_center=0.7,
            y_center=0.8,
            width=0.1,
            height=0.04,
            bbox_x=560,
            bbox_y=640,
            bbox_width=80,
            bbox_height=32,
        )

        page1_annotations = repo.get_for_document(
            str(sample_document.document_id),
            page_number=1,
        )
        page2_annotations = repo.get_for_document(
            str(sample_document.document_id),
            page_number=2,
        )

        assert len(page1_annotations) == 1
        assert len(page2_annotations) == 1
        assert page1_annotations[0].page_number == 1
        assert page2_annotations[0].page_number == 2


class TestAnnotationRepositoryUpdate:
    """Tests for annotation updates."""

    def test_update_annotation_bbox(self, patched_session, sample_annotation):
        """Test updating annotation bounding box."""
        repo = AnnotationRepository()

        result = repo.update(
            str(sample_annotation.annotation_id),
            x_center=0.6,
            y_center=0.4,
            width=0.25,
            height=0.06,
            bbox_x=480,
            bbox_y=320,
            bbox_width=200,
            bbox_height=48,
        )

        assert result is True

        ann = repo.get(str(sample_annotation.annotation_id))
        assert ann is not None
        assert ann.x_center == 0.6
        assert ann.y_center == 0.4
        assert ann.bbox_x == 480
        assert ann.bbox_width == 200

    def test_update_annotation_text(self, patched_session, sample_annotation):
        """Test updating annotation text value."""
        repo = AnnotationRepository()

        result = repo.update(
            str(sample_annotation.annotation_id),
            text_value="INV-2024-002",
        )

        assert result is True

        ann = repo.get(str(sample_annotation.annotation_id))
        assert ann is not None
        assert ann.text_value == "INV-2024-002"

    def test_update_annotation_class(self, patched_session, sample_annotation):
        """Test updating annotation class."""
        repo = AnnotationRepository()

        result = repo.update(
            str(sample_annotation.annotation_id),
            class_id=1,
            class_name="invoice_date",
        )

        assert result is True

        ann = repo.get(str(sample_annotation.annotation_id))
        assert ann is not None
        assert ann.class_id == 1
        assert ann.class_name == "invoice_date"

    def test_update_nonexistent_annotation(self, patched_session):
        """Test updating an annotation that doesn't exist."""
        repo = AnnotationRepository()

        result = repo.update(
            str(uuid4()),
            text_value="new value",
        )

        assert result is False


class TestAnnotationRepositoryDelete:
    """Tests for annotation deletion."""

    def test_delete_annotation(self, patched_session, sample_annotation):
        """Test deleting a single annotation."""
        repo = AnnotationRepository()

        result = repo.delete(str(sample_annotation.annotation_id))
        assert result is True

        ann = repo.get(str(sample_annotation.annotation_id))
        assert ann is None

    def test_delete_nonexistent_annotation(self, patched_session):
        """Test deleting an annotation that doesn't exist."""
        repo = AnnotationRepository()

        result = repo.delete(str(uuid4()))
        assert result is False

    def test_delete_annotations_for_document(self, patched_session, sample_document):
        """Test deleting all annotations for a document."""
        repo = AnnotationRepository()

        # Create multiple annotations
        for i in range(3):
            repo.create(
                document_id=str(sample_document.document_id),
                page_number=1,
                class_id=i,
                class_name=f"field_{i}",
                x_center=0.5,
                y_center=0.1 + i * 0.2,
                width=0.2,
                height=0.05,
                bbox_x=400,
                bbox_y=80 + i * 160,
                bbox_width=160,
                bbox_height=40,
            )

        # Delete all
        count = repo.delete_for_document(str(sample_document.document_id))

        assert count == 3

        annotations = repo.get_for_document(str(sample_document.document_id))
        assert len(annotations) == 0

    def test_delete_annotations_by_source(self, patched_session, sample_document):
        """Test deleting annotations by source type."""
        repo = AnnotationRepository()

        # Create auto and manual annotations
        repo.create(
            document_id=str(sample_document.document_id),
            page_number=1,
            class_id=0,
            class_name="invoice_number",
            x_center=0.5,
            y_center=0.1,
            width=0.2,
            height=0.05,
            bbox_x=400,
            bbox_y=80,
            bbox_width=160,
            bbox_height=40,
            source="auto",
        )
        repo.create(
            document_id=str(sample_document.document_id),
            page_number=1,
            class_id=1,
            class_name="invoice_date",
            x_center=0.5,
            y_center=0.2,
            width=0.15,
            height=0.04,
            bbox_x=400,
            bbox_y=160,
            bbox_width=120,
            bbox_height=32,
            source="manual",
        )

        # Delete only auto annotations
        count = repo.delete_for_document(str(sample_document.document_id), source="auto")

        assert count == 1

        remaining = repo.get_for_document(str(sample_document.document_id))
        assert len(remaining) == 1
        assert remaining[0].source == "manual"


class TestAnnotationVerification:
    """Tests for annotation verification."""

    def test_verify_annotation(self, patched_session, admin_token, sample_annotation):
        """Test marking an annotation as verified."""
        repo = AnnotationRepository()

        ann = repo.verify(str(sample_annotation.annotation_id), admin_token.token)

        assert ann is not None
        assert ann.is_verified is True
        assert ann.verified_by == admin_token.token
        assert ann.verified_at is not None


class TestAnnotationOverride:
    """Tests for annotation override functionality."""

    def test_override_auto_annotation(self, patched_session, admin_token, sample_annotation):
        """Test overriding an auto-generated annotation."""
        repo = AnnotationRepository()
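        # override() records provenance: the row is expected to flip to
        # source="manual" while override_source preserves the original
        # "auto" origin.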

        # Override the annotation
        ann = repo.override(
            str(sample_annotation.annotation_id),
            admin_token.token,
            change_reason="Correcting OCR error",
            text_value="INV-2024-CORRECTED",
            x_center=0.55,
        )

        assert ann is not None
        assert ann.text_value == "INV-2024-CORRECTED"
        assert ann.x_center == 0.55
        assert ann.source == "manual"  # Changed from auto to manual
        assert ann.override_source == "auto"


class TestAnnotationHistory:
    """Tests for annotation history tracking."""

    def test_create_history_record(self, patched_session, sample_annotation):
        """Test creating an annotation history record."""
        repo = AnnotationRepository()

        history = repo.create_history(
            annotation_id=sample_annotation.annotation_id,
            document_id=sample_annotation.document_id,
            action="created",
            new_value={"text_value": "INV-001"},
            changed_by="test-user",
        )

        assert history is not None
        assert history.action == "created"
        assert history.changed_by == "test-user"

    def test_get_annotation_history(self, patched_session, sample_annotation):
        """Test getting the history for an annotation."""
        repo = AnnotationRepository()

        # Create history records
        repo.create_history(
            annotation_id=sample_annotation.annotation_id,
            document_id=sample_annotation.document_id,
            action="created",
            new_value={"text_value": "INV-001"},
        )
        repo.create_history(
            annotation_id=sample_annotation.annotation_id,
            document_id=sample_annotation.document_id,
            action="updated",
            previous_value={"text_value": "INV-001"},
            new_value={"text_value": "INV-002"},
        )

        history = repo.get_history(sample_annotation.annotation_id)

        assert len(history) == 2
        # Should be ordered by created_at desc
        assert history[0].action == "updated"
        assert history[1].action == "created"

    def test_get_document_history(self, patched_session, sample_document, sample_annotation):
        """Test getting all annotation history for a document."""
        repo = AnnotationRepository()

        repo.create_history(
            annotation_id=sample_annotation.annotation_id,
            document_id=sample_document.document_id,
            action="created",
            new_value={"class_name": "invoice_number"},
        )

        history = repo.get_document_history(sample_document.document_id)

        assert len(history) >= 1
        assert all(h.document_id == sample_document.document_id for h in history)
@@ -0,0 +1,355 @@
"""
Batch Upload Repository Integration Tests

Tests BatchUploadRepository with real database operations.
"""

from datetime import datetime, timezone
from uuid import uuid4

import pytest

from inference.data.repositories.batch_upload_repository import BatchUploadRepository


class TestBatchUploadCreate:
    """Tests for batch upload creation."""

    def test_create_batch_upload(self, patched_session, admin_token):
        """Test creating a batch upload."""
        repo = BatchUploadRepository()
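        # New batches are expected to start in "processing" with zeroed
        # counters; later tests in this file exercise the terminal statuses
        # "completed", "failed", and "partial".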

        batch = repo.create(
            admin_token=admin_token.token,
            filename="test_batch.zip",
            file_size=10240,
            upload_source="api",
        )

        assert batch is not None
        assert batch.batch_id is not None
        assert batch.filename == "test_batch.zip"
        assert batch.file_size == 10240
        assert batch.upload_source == "api"
        assert batch.status == "processing"
        assert batch.total_files == 0
        assert batch.processed_files == 0

    def test_create_batch_upload_default_source(self, patched_session, admin_token):
        """Test creating a batch upload with the default source."""
        repo = BatchUploadRepository()

        batch = repo.create(
            admin_token=admin_token.token,
            filename="ui_batch.zip",
            file_size=5120,
        )

        assert batch.upload_source == "ui"


class TestBatchUploadRead:
    """Tests for batch upload retrieval."""

    def test_get_batch_upload(self, patched_session, sample_batch_upload):
        """Test getting a batch upload by ID."""
        repo = BatchUploadRepository()

        batch = repo.get(sample_batch_upload.batch_id)

        assert batch is not None
        assert batch.batch_id == sample_batch_upload.batch_id
        assert batch.filename == sample_batch_upload.filename

    def test_get_nonexistent_batch_upload(self, patched_session):
        """Test getting a batch upload that doesn't exist."""
        repo = BatchUploadRepository()

        batch = repo.get(uuid4())
        assert batch is None

    def test_get_paginated_batch_uploads(self, patched_session, admin_token):
        """Test paginated batch upload listing."""
        repo = BatchUploadRepository()

        # Create multiple batches
        for i in range(5):
            repo.create(
                admin_token=admin_token.token,
                filename=f"batch_{i}.zip",
                file_size=1024 * (i + 1),
            )

        batches, total = repo.get_paginated(limit=3, offset=0)

        assert total == 5
        assert len(batches) == 3

    def test_get_paginated_with_offset(self, patched_session, admin_token):
        """Test pagination offset."""
        repo = BatchUploadRepository()

        for i in range(5):
            repo.create(
                admin_token=admin_token.token,
                filename=f"batch_{i}.zip",
                file_size=1024,
            )

        page1, _ = repo.get_paginated(limit=2, offset=0)
        page2, _ = repo.get_paginated(limit=2, offset=2)

        ids_page1 = {b.batch_id for b in page1}
        ids_page2 = {b.batch_id for b in page2}

        assert len(ids_page1 & ids_page2) == 0


class TestBatchUploadUpdate:
    """Tests for batch upload updates."""

    def test_update_batch_status(self, patched_session, sample_batch_upload):
        """Test updating batch upload status."""
        repo = BatchUploadRepository()

        repo.update(
            sample_batch_upload.batch_id,
            status="completed",
            total_files=10,
            processed_files=10,
            successful_files=8,
            failed_files=2,
        )

        # Need to commit to see changes
        patched_session.commit()

        batch = repo.get(sample_batch_upload.batch_id)
        assert batch.status == "completed"
        assert batch.total_files == 10
        assert batch.successful_files == 8
        assert batch.failed_files == 2

    def test_update_batch_with_error(self, patched_session, sample_batch_upload):
        """Test updating a batch upload with an error message."""
        repo = BatchUploadRepository()

        repo.update(
            sample_batch_upload.batch_id,
            status="failed",
            error_message="ZIP extraction failed",
        )

        patched_session.commit()

        batch = repo.get(sample_batch_upload.batch_id)
        assert batch.status == "failed"
        assert batch.error_message == "ZIP extraction failed"

    def test_update_batch_csv_info(self, patched_session, sample_batch_upload):
        """Test updating a batch with CSV information."""
        repo = BatchUploadRepository()

        repo.update(
            sample_batch_upload.batch_id,
            csv_filename="manifest.csv",
            csv_row_count=100,
        )

        patched_session.commit()

        batch = repo.get(sample_batch_upload.batch_id)
        assert batch.csv_filename == "manifest.csv"
        assert batch.csv_row_count == 100


class TestBatchUploadFiles:
    """Tests for batch upload file management."""

    def test_create_batch_file(self, patched_session, sample_batch_upload):
        """Test creating a batch upload file record."""
        repo = BatchUploadRepository()

        file_record = repo.create_file(
            batch_id=sample_batch_upload.batch_id,
            filename="invoice_001.pdf",
            status="pending",
        )

        assert file_record is not None
        assert file_record.file_id is not None
        assert file_record.filename == "invoice_001.pdf"
        assert file_record.batch_id == sample_batch_upload.batch_id
        assert file_record.status == "pending"

    def test_create_batch_file_with_document_link(self, patched_session, sample_batch_upload, sample_document):
        """Test creating a batch file linked to a document."""
        repo = BatchUploadRepository()

        file_record = repo.create_file(
            batch_id=sample_batch_upload.batch_id,
            filename="invoice_linked.pdf",
            document_id=sample_document.document_id,
            status="completed",
            annotation_count=5,
        )

        assert file_record.document_id == sample_document.document_id
        assert file_record.status == "completed"
        assert file_record.annotation_count == 5

    def test_get_batch_files(self, patched_session, sample_batch_upload):
        """Test getting all files for a batch."""
        repo = BatchUploadRepository()

        # Create multiple files
        for i in range(3):
            repo.create_file(
                batch_id=sample_batch_upload.batch_id,
                filename=f"file_{i}.pdf",
            )

        files = repo.get_files(sample_batch_upload.batch_id)

        assert len(files) == 3
        assert all(f.batch_id == sample_batch_upload.batch_id for f in files)

    def test_get_batch_files_empty(self, patched_session, sample_batch_upload):
        """Test getting files for a batch with no files."""
        repo = BatchUploadRepository()

        files = repo.get_files(sample_batch_upload.batch_id)

        assert files == []

    def test_update_batch_file_status(self, patched_session, sample_batch_upload):
        """Test updating batch file status."""
        repo = BatchUploadRepository()

        file_record = repo.create_file(
            batch_id=sample_batch_upload.batch_id,
            filename="test.pdf",
        )

        repo.update_file(
            file_record.file_id,
            status="completed",
            annotation_count=10,
        )

        patched_session.commit()

        files = repo.get_files(sample_batch_upload.batch_id)
        updated_file = files[0]
        assert updated_file.status == "completed"
        assert updated_file.annotation_count == 10

    def test_update_batch_file_with_error(self, patched_session, sample_batch_upload):
        """Test updating a batch file with an error."""
        repo = BatchUploadRepository()

        file_record = repo.create_file(
            batch_id=sample_batch_upload.batch_id,
            filename="corrupt.pdf",
        )

        repo.update_file(
            file_record.file_id,
            status="failed",
            error_message="Invalid PDF format",
        )

        patched_session.commit()

        files = repo.get_files(sample_batch_upload.batch_id)
        updated_file = files[0]
        assert updated_file.status == "failed"
        assert updated_file.error_message == "Invalid PDF format"

    def test_update_batch_file_with_csv_data(self, patched_session, sample_batch_upload):
        """Test updating a batch file with CSV row data."""
        repo = BatchUploadRepository()

        file_record = repo.create_file(
            batch_id=sample_batch_upload.batch_id,
            filename="invoice_with_csv.pdf",
        )

        csv_data = {
            "invoice_number": "INV-001",
            "amount": "1500.00",
            "supplier": "Test Corp",
        }

        repo.update_file(
            file_record.file_id,
            csv_row_data=csv_data,
        )

        patched_session.commit()

        files = repo.get_files(sample_batch_upload.batch_id)
        updated_file = files[0]
        assert updated_file.csv_row_data == csv_data


class TestBatchUploadWorkflow:
    """Tests for complete batch upload workflows."""

    def test_complete_batch_workflow(self, patched_session, admin_token):
        """Test a complete batch upload workflow."""
        repo = BatchUploadRepository()

        # 1. Create batch
        batch = repo.create(
            admin_token=admin_token.token,
            filename="full_workflow.zip",
            file_size=50000,
        )

        # 2. Update with file count
        repo.update(batch.batch_id, total_files=3)
        patched_session.commit()

        # 3. Create file records
        file_ids = []
        for i in range(3):
            file_record = repo.create_file(
                batch_id=batch.batch_id,
                filename=f"doc_{i}.pdf",
            )
            file_ids.append(file_record.file_id)

        # 4. Process files one by one
        for i, file_id in enumerate(file_ids):
            status = "completed" if i < 2 else "failed"
            repo.update_file(
                file_id,
                status=status,
                annotation_count=5 if status == "completed" else 0,
            )

        # 5. Update batch progress
        repo.update(
            batch.batch_id,
            processed_files=3,
            successful_files=2,
            failed_files=1,
            status="partial",
        )
        patched_session.commit()

        # Verify final state
        final_batch = repo.get(batch.batch_id)
        assert final_batch.status == "partial"
        assert final_batch.total_files == 3
        assert final_batch.processed_files == 3
        assert final_batch.successful_files == 2
        assert final_batch.failed_files == 1

        files = repo.get_files(batch.batch_id)
        assert len(files) == 3
        completed = [f for f in files if f.status == "completed"]
        failed = [f for f in files if f.status == "failed"]
        assert len(completed) == 2
        assert len(failed) == 1
321
tests/integration/repositories/test_dataset_repo_integration.py
Normal file
@@ -0,0 +1,321 @@
"""
Dataset Repository Integration Tests

Tests DatasetRepository with real database operations.
"""

from uuid import uuid4

import pytest

from inference.data.repositories.dataset_repository import DatasetRepository


class TestDatasetRepositoryCreate:
    """Tests for dataset creation."""

    def test_create_dataset(self, patched_session):
        """Test creating a training dataset."""
        repo = DatasetRepository()

        dataset = repo.create(
            name="Test Dataset",
            description="Dataset for integration testing",
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
        )

        assert dataset is not None
        assert dataset.name == "Test Dataset"
        assert dataset.description == "Dataset for integration testing"
        assert dataset.train_ratio == 0.8
        assert dataset.val_ratio == 0.1
        assert dataset.seed == 42
        assert dataset.status == "building"

    def test_create_dataset_with_defaults(self, patched_session):
        """Test creating a dataset with default values."""
        repo = DatasetRepository()

        dataset = repo.create(name="Minimal Dataset")

        assert dataset is not None
        assert dataset.train_ratio == 0.8
        assert dataset.val_ratio == 0.1
        assert dataset.seed == 42
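        # The remaining 10% is implicitly the test split
        # (1 - train_ratio - val_ratio).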


class TestDatasetRepositoryRead:
    """Tests for dataset retrieval."""

    def test_get_dataset_by_id(self, patched_session, sample_dataset):
        """Test getting a dataset by ID."""
        repo = DatasetRepository()

        dataset = repo.get(str(sample_dataset.dataset_id))

        assert dataset is not None
        assert dataset.dataset_id == sample_dataset.dataset_id
        assert dataset.name == sample_dataset.name

    def test_get_nonexistent_dataset(self, patched_session):
        """Test getting a dataset that doesn't exist."""
        repo = DatasetRepository()

        dataset = repo.get(str(uuid4()))
        assert dataset is None

    def test_get_paginated_datasets(self, patched_session):
        """Test paginated dataset listing."""
        repo = DatasetRepository()

        # Create multiple datasets
        for i in range(5):
            repo.create(name=f"Dataset {i}")

        datasets, total = repo.get_paginated(limit=2, offset=0)

        assert total == 5
        assert len(datasets) == 2

    def test_get_paginated_with_status_filter(self, patched_session):
        """Test filtering datasets by status."""
        repo = DatasetRepository()

        # Create datasets with different statuses
        d1 = repo.create(name="Building Dataset")
        repo.update_status(str(d1.dataset_id), "ready")

        d2 = repo.create(name="Another Building Dataset")
        # stays as "building"

        datasets, total = repo.get_paginated(status="ready")

        assert total == 1
        assert datasets[0].status == "ready"


class TestDatasetRepositoryUpdate:
    """Tests for dataset updates."""

    def test_update_status(self, patched_session, sample_dataset):
        """Test updating dataset status."""
        repo = DatasetRepository()

        repo.update_status(
            str(sample_dataset.dataset_id),
            status="ready",
            total_documents=100,
            total_images=150,
            total_annotations=500,
        )

        dataset = repo.get(str(sample_dataset.dataset_id))
        assert dataset is not None
        assert dataset.status == "ready"
        assert dataset.total_documents == 100
        assert dataset.total_images == 150
        assert dataset.total_annotations == 500

    def test_update_status_with_error(self, patched_session, sample_dataset):
        """Test updating dataset status with an error message."""
        repo = DatasetRepository()

        repo.update_status(
            str(sample_dataset.dataset_id),
            status="failed",
            error_message="Failed to build dataset: insufficient documents",
        )

        dataset = repo.get(str(sample_dataset.dataset_id))
        assert dataset is not None
        assert dataset.status == "failed"
        assert "insufficient documents" in dataset.error_message

    def test_update_status_with_path(self, patched_session, sample_dataset):
        """Test updating the dataset path."""
        repo = DatasetRepository()

        repo.update_status(
            str(sample_dataset.dataset_id),
            status="ready",
            dataset_path="/datasets/test_dataset_2024",
        )

        dataset = repo.get(str(sample_dataset.dataset_id))
        assert dataset is not None
        assert dataset.dataset_path == "/datasets/test_dataset_2024"

    def test_update_training_status(self, patched_session, sample_dataset, sample_training_task):
        """Test updating dataset training status."""
        repo = DatasetRepository()

        repo.update_training_status(
            str(sample_dataset.dataset_id),
            training_status="running",
            active_training_task_id=str(sample_training_task.task_id),
        )

        dataset = repo.get(str(sample_dataset.dataset_id))
        assert dataset is not None
        assert dataset.training_status == "running"
        assert dataset.active_training_task_id == sample_training_task.task_id

    def test_update_training_status_completed(self, patched_session, sample_dataset):
        """Test that completing training also updates the main status."""
        repo = DatasetRepository()

        # First set to ready
        repo.update_status(str(sample_dataset.dataset_id), status="ready")

        # Then complete training
        repo.update_training_status(
            str(sample_dataset.dataset_id),
            training_status="completed",
            update_main_status=True,
        )

        dataset = repo.get(str(sample_dataset.dataset_id))
        assert dataset is not None
        assert dataset.training_status == "completed"
        assert dataset.status == "trained"


class TestDatasetDocuments:
    """Tests for dataset document management."""

    def test_add_documents_to_dataset(self, patched_session, sample_dataset, multiple_documents):
        """Test adding documents to a dataset."""
        repo = DatasetRepository()

        documents_data = [
            {
                "document_id": str(multiple_documents[0].document_id),
                "split": "train",
                "page_count": 1,
                "annotation_count": 5,
            },
            {
                "document_id": str(multiple_documents[1].document_id),
                "split": "train",
                "page_count": 2,
                "annotation_count": 8,
            },
            {
                "document_id": str(multiple_documents[2].document_id),
                "split": "val",
                "page_count": 1,
                "annotation_count": 3,
            },
        ]

        repo.add_documents(str(sample_dataset.dataset_id), documents_data)

        # Verify documents were added
        docs = repo.get_documents(str(sample_dataset.dataset_id))
        assert len(docs) == 3

        train_docs = [d for d in docs if d.split == "train"]
        val_docs = [d for d in docs if d.split == "val"]

        assert len(train_docs) == 2
        assert len(val_docs) == 1

    def test_get_dataset_documents(self, patched_session, sample_dataset, sample_document):
        """Test getting documents from a dataset."""
        repo = DatasetRepository()

        repo.add_documents(
            str(sample_dataset.dataset_id),
            [
                {
                    "document_id": str(sample_document.document_id),
                    "split": "train",
                    "page_count": 1,
                    "annotation_count": 5,
                }
            ],
        )

        docs = repo.get_documents(str(sample_dataset.dataset_id))

        assert len(docs) == 1
        assert docs[0].document_id == sample_document.document_id
        assert docs[0].split == "train"
        assert docs[0].page_count == 1
        assert docs[0].annotation_count == 5


class TestDatasetRepositoryDelete:
    """Tests for dataset deletion."""

    def test_delete_dataset(self, patched_session, sample_dataset):
        """Test deleting a dataset."""
        repo = DatasetRepository()

        result = repo.delete(str(sample_dataset.dataset_id))
        assert result is True

        dataset = repo.get(str(sample_dataset.dataset_id))
        assert dataset is None

    def test_delete_nonexistent_dataset(self, patched_session):
        """Test deleting a dataset that doesn't exist."""
        repo = DatasetRepository()

        result = repo.delete(str(uuid4()))
        assert result is False

    def test_delete_dataset_cascades_documents(self, patched_session, sample_dataset, sample_document):
        """Test deleting a dataset also removes document links."""
        repo = DatasetRepository()

        # Add document to dataset
        repo.add_documents(
            str(sample_dataset.dataset_id),
            [
                {
                    "document_id": str(sample_document.document_id),
                    "split": "train",
                    "page_count": 1,
                    "annotation_count": 5,
                }
            ],
        )

        # Delete dataset
        repo.delete(str(sample_dataset.dataset_id))

        # Document links should be gone
        docs = repo.get_documents(str(sample_dataset.dataset_id))
        assert len(docs) == 0


class TestActiveTrainingTasks:
    """Tests for active training task queries."""

    def test_get_active_training_tasks(self, patched_session, sample_dataset, sample_training_task):
        """Test getting active training tasks for datasets."""
        repo = DatasetRepository()

        # Update task to running
        from inference.data.repositories.training_task_repository import TrainingTaskRepository

        task_repo = TrainingTaskRepository()
        task_repo.update_status(str(sample_training_task.task_id), "running")

        result = repo.get_active_training_tasks([str(sample_dataset.dataset_id)])

        assert str(sample_dataset.dataset_id) in result
        assert result[str(sample_dataset.dataset_id)]["status"] == "running"

    def test_get_active_training_tasks_empty(self, patched_session, sample_dataset):
        """Test that the query returns empty when no tasks exist."""
        repo = DatasetRepository()

        result = repo.get_active_training_tasks([str(sample_dataset.dataset_id)])

        # No training task exists for this dataset, so result should be empty
        assert str(sample_dataset.dataset_id) not in result
        assert result == {}
350
tests/integration/repositories/test_document_repo_integration.py
Normal file
@@ -0,0 +1,350 @@
"""
Document Repository Integration Tests

Tests DocumentRepository with real database operations.
"""

from datetime import datetime, timezone, timedelta
from uuid import uuid4

import pytest
from sqlmodel import select

from inference.data.admin_models import AdminAnnotation, AdminDocument
from inference.data.repositories.document_repository import DocumentRepository


def ensure_utc(dt: datetime | None) -> datetime | None:
    """Ensure a datetime is timezone-aware (UTC).

    PostgreSQL may return offset-naive datetimes. This helper tags
    naive values as UTC so they compare correctly against aware ones.
    """
    if dt is None:
        return None
    if dt.tzinfo is None:
        return dt.replace(tzinfo=timezone.utc)
    return dt
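
# Assumed usage: ensure_utc(datetime(2024, 1, 15)) returns the same wall time
# tagged as UTC, while an already-aware value passes through unchanged.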


class TestDocumentRepositoryCreate:
    """Tests for document creation."""

    def test_create_document(self, patched_session):
        """Test creating a document and retrieving it."""
        repo = DocumentRepository()

        doc_id = repo.create(
            filename="test_invoice.pdf",
            file_size=2048,
            content_type="application/pdf",
            file_path="/uploads/test_invoice.pdf",
            page_count=2,
            upload_source="api",
            category="invoice",
        )

        assert doc_id is not None

        doc = repo.get(doc_id)
        assert doc is not None
        assert doc.filename == "test_invoice.pdf"
        assert doc.file_size == 2048
        assert doc.page_count == 2
        assert doc.upload_source == "api"
        assert doc.category == "invoice"
        assert doc.status == "pending"

    def test_create_document_with_csv_values(self, patched_session):
        """Test creating a document with CSV field values."""
        repo = DocumentRepository()

        csv_values = {
            "invoice_number": "INV-001",
            "amount": "1500.00",
            "supplier_name": "Test Supplier AB",
        }

        doc_id = repo.create(
            filename="invoice_with_csv.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/uploads/invoice_with_csv.pdf",
            csv_field_values=csv_values,
        )

        doc = repo.get(doc_id)
        assert doc is not None
        assert doc.csv_field_values == csv_values

    def test_create_document_with_group_key(self, patched_session):
        """Test creating a document with a group key."""
        repo = DocumentRepository()

        doc_id = repo.create(
            filename="grouped_doc.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/uploads/grouped_doc.pdf",
            group_key="batch-2024-01",
        )

        doc = repo.get(doc_id)
        assert doc is not None
        assert doc.group_key == "batch-2024-01"


class TestDocumentRepositoryRead:
    """Tests for document retrieval."""

    def test_get_nonexistent_document(self, patched_session):
        """Test getting a document that doesn't exist."""
        repo = DocumentRepository()

        doc = repo.get(str(uuid4()))
        assert doc is None

    def test_get_paginated_documents(self, patched_session, multiple_documents):
        """Test paginated document listing."""
        repo = DocumentRepository()
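        # multiple_documents is a conftest fixture (not shown in this diff);
        # from the assertions here it is assumed to seed five documents
        # spanning the pending/labeled/exported statuses and the
        # invoice/letter categories.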

        docs, total = repo.get_paginated(limit=2, offset=0)

        assert total == 5
        assert len(docs) == 2

    def test_get_paginated_with_status_filter(self, patched_session, multiple_documents):
        """Test filtering documents by status."""
        repo = DocumentRepository()

        docs, total = repo.get_paginated(status="labeled")

        assert total == 2
        for doc in docs:
            assert doc.status == "labeled"

    def test_get_paginated_with_category_filter(self, patched_session, multiple_documents):
        """Test filtering documents by category."""
        repo = DocumentRepository()

        docs, total = repo.get_paginated(category="letter")

        assert total == 1
        assert docs[0].category == "letter"

    def test_get_paginated_with_offset(self, patched_session, multiple_documents):
        """Test pagination offset."""
        repo = DocumentRepository()

        docs_page1, _ = repo.get_paginated(limit=2, offset=0)
        docs_page2, _ = repo.get_paginated(limit=2, offset=2)

        doc_ids_page1 = {str(d.document_id) for d in docs_page1}
        doc_ids_page2 = {str(d.document_id) for d in docs_page2}

        assert len(doc_ids_page1 & doc_ids_page2) == 0

    def test_get_by_ids(self, patched_session, multiple_documents):
        """Test getting multiple documents by IDs."""
        repo = DocumentRepository()

        ids_to_fetch = [str(multiple_documents[0].document_id), str(multiple_documents[2].document_id)]
        docs = repo.get_by_ids(ids_to_fetch)

        assert len(docs) == 2
        fetched_ids = {str(d.document_id) for d in docs}
        assert fetched_ids == set(ids_to_fetch)


class TestDocumentRepositoryUpdate:
    """Tests for document updates."""

    def test_update_status(self, patched_session, sample_document):
        """Test updating document status."""
        repo = DocumentRepository()

        repo.update_status(
            str(sample_document.document_id),
            status="labeled",
            auto_label_status="completed",
        )

        doc = repo.get(str(sample_document.document_id))
        assert doc is not None
        assert doc.status == "labeled"
        assert doc.auto_label_status == "completed"

    def test_update_status_with_error(self, patched_session, sample_document):
        """Test updating document status with an error message."""
        repo = DocumentRepository()

        repo.update_status(
            str(sample_document.document_id),
            status="pending",
            auto_label_status="failed",
            auto_label_error="OCR extraction failed",
        )

        doc = repo.get(str(sample_document.document_id))
        assert doc is not None
        assert doc.auto_label_status == "failed"
        assert doc.auto_label_error == "OCR extraction failed"

    def test_update_file_path(self, patched_session, sample_document):
        """Test updating the document file path."""
        repo = DocumentRepository()

        new_path = "/archive/2024/test_invoice.pdf"
        repo.update_file_path(str(sample_document.document_id), new_path)

        doc = repo.get(str(sample_document.document_id))
        assert doc is not None
        assert doc.file_path == new_path

    def test_update_group_key(self, patched_session, sample_document):
        """Test updating the document group key."""
        repo = DocumentRepository()

        result = repo.update_group_key(str(sample_document.document_id), "new-group-key")
        assert result is True

        doc = repo.get(str(sample_document.document_id))
        assert doc is not None
        assert doc.group_key == "new-group-key"

    def test_update_category(self, patched_session, sample_document):
        """Test updating the document category."""
        repo = DocumentRepository()

        doc = repo.update_category(str(sample_document.document_id), "letter")

        assert doc is not None
        assert doc.category == "letter"


class TestDocumentRepositoryDelete:
    """Tests for document deletion."""

    def test_delete_document(self, patched_session, sample_document):
        """Test deleting a document."""
        repo = DocumentRepository()

        result = repo.delete(str(sample_document.document_id))
        assert result is True

        doc = repo.get(str(sample_document.document_id))
        assert doc is None

    def test_delete_document_with_annotations(self, patched_session, sample_document, sample_annotation):
        """Test deleting a document also deletes its annotations."""
        repo = DocumentRepository()

        result = repo.delete(str(sample_document.document_id))
        assert result is True

        # Verify annotation is also deleted
        from inference.data.repositories.annotation_repository import AnnotationRepository

        ann_repo = AnnotationRepository()
        annotations = ann_repo.get_for_document(str(sample_document.document_id))
        assert len(annotations) == 0

    def test_delete_nonexistent_document(self, patched_session):
        """Test deleting a document that doesn't exist."""
        repo = DocumentRepository()

        result = repo.delete(str(uuid4()))
        assert result is False


class TestDocumentRepositoryQueries:
    """Tests for complex document queries."""

    def test_count_by_status(self, patched_session, multiple_documents):
        """Test counting documents by status."""
        repo = DocumentRepository()

        counts = repo.count_by_status()

        assert counts.get("pending") == 2
        assert counts.get("labeled") == 2
        assert counts.get("exported") == 1

    def test_get_categories(self, patched_session, multiple_documents):
        """Test getting unique categories."""
        repo = DocumentRepository()

        categories = repo.get_categories()

        assert "invoice" in categories
        assert "letter" in categories

    def test_get_labeled_for_export(self, patched_session, multiple_documents):
        """Test getting labeled documents for export."""
        repo = DocumentRepository()

        docs = repo.get_labeled_for_export()

        assert len(docs) == 2
        for doc in docs:
            assert doc.status == "labeled"


class TestDocumentAnnotationLocking:
    """Tests for the annotation locking mechanism."""

    def test_acquire_annotation_lock(self, patched_session, sample_document):
        """Test acquiring an annotation lock."""
        repo = DocumentRepository()

        doc = repo.acquire_annotation_lock(
            str(sample_document.document_id),
            duration_seconds=300,
        )

        assert doc is not None
        assert doc.annotation_lock_until is not None
        lock_until = ensure_utc(doc.annotation_lock_until)
        assert lock_until > datetime.now(timezone.utc)

    def test_acquire_lock_when_already_locked(self, patched_session, sample_document):
        """Test acquiring a lock fails when already locked."""
        repo = DocumentRepository()

        # First lock
        repo.acquire_annotation_lock(str(sample_document.document_id), duration_seconds=300)

        # Second lock attempt should fail
        result = repo.acquire_annotation_lock(str(sample_document.document_id))
        assert result is None

    def test_release_annotation_lock(self, patched_session, sample_document):
        """Test releasing an annotation lock."""
        repo = DocumentRepository()

        repo.acquire_annotation_lock(str(sample_document.document_id), duration_seconds=300)
        doc = repo.release_annotation_lock(str(sample_document.document_id))

        assert doc is not None
        assert doc.annotation_lock_until is None

    def test_extend_annotation_lock(self, patched_session, sample_document):
        """Test extending an annotation lock."""
        repo = DocumentRepository()

        # Acquire initial lock
        initial_doc = repo.acquire_annotation_lock(
            str(sample_document.document_id),
            duration_seconds=300,
        )
        initial_expiry = ensure_utc(initial_doc.annotation_lock_until)

        # Extend lock
        extended_doc = repo.extend_annotation_lock(
            str(sample_document.document_id),
            additional_seconds=300,
        )

        assert extended_doc is not None
        extended_expiry = ensure_utc(extended_doc.annotation_lock_until)
        assert extended_expiry > initial_expiry
@@ -0,0 +1,310 @@
"""
Model Version Repository Integration Tests

Tests ModelVersionRepository with real database operations.
"""

from datetime import datetime, timezone
from uuid import uuid4

import pytest

from inference.data.repositories.model_version_repository import ModelVersionRepository


class TestModelVersionCreate:
    """Tests for model version creation."""

    def test_create_model_version(self, patched_session):
        """Test creating a model version."""
        repo = ModelVersionRepository()

        model = repo.create(
            version="1.0.0",
            name="Invoice Extractor v1",
            model_path="/models/invoice_v1.pt",
            description="Initial production model",
            metrics_mAP=0.92,
            metrics_precision=0.89,
            metrics_recall=0.85,
            document_count=1000,
            file_size=50000000,
        )

        assert model is not None
        assert model.version == "1.0.0"
        assert model.name == "Invoice Extractor v1"
        assert model.model_path == "/models/invoice_v1.pt"
        assert model.metrics_mAP == 0.92
        assert model.is_active is False
        assert model.status == "inactive"

    def test_create_model_version_with_training_info(
        self, patched_session, sample_training_task, sample_dataset
    ):
        """Test creating a model version linked to a training task and dataset."""
        repo = ModelVersionRepository()

        model = repo.create(
            version="1.1.0",
            name="Invoice Extractor v1.1",
            model_path="/models/invoice_v1.1.pt",
            task_id=sample_training_task.task_id,
            dataset_id=sample_dataset.dataset_id,
            training_config={"epochs": 100, "batch_size": 16},
            trained_at=datetime.now(timezone.utc),
        )

        assert model is not None
        assert model.task_id == sample_training_task.task_id
        assert model.dataset_id == sample_dataset.dataset_id
        assert model.training_config["epochs"] == 100


class TestModelVersionRead:
    """Tests for model version retrieval."""

    def test_get_model_version_by_id(self, patched_session, sample_model_version):
        """Test getting a model version by ID."""
        repo = ModelVersionRepository()

        model = repo.get(str(sample_model_version.version_id))

        assert model is not None
        assert model.version_id == sample_model_version.version_id

    def test_get_nonexistent_model_version(self, patched_session):
        """Test getting a model version that doesn't exist."""
        repo = ModelVersionRepository()

        model = repo.get(str(uuid4()))
        assert model is None

    def test_get_paginated_model_versions(self, patched_session):
        """Test paginated model version listing."""
        repo = ModelVersionRepository()

        # Create multiple versions
        for i in range(5):
            repo.create(
                version=f"1.{i}.0",
                name=f"Model v1.{i}",
                model_path=f"/models/model_v1.{i}.pt",
            )

        models, total = repo.get_paginated(limit=2, offset=0)

        assert total == 5
        assert len(models) == 2

    def test_get_paginated_with_status_filter(self, patched_session):
        """Test filtering model versions by status."""
        repo = ModelVersionRepository()

        # Create active and inactive models
        m1 = repo.create(version="1.0.0", name="Active Model", model_path="/models/active.pt")
        repo.activate(str(m1.version_id))

        repo.create(version="2.0.0", name="Inactive Model", model_path="/models/inactive.pt")

        active_models, active_total = repo.get_paginated(status="active")
        inactive_models, inactive_total = repo.get_paginated(status="inactive")

        assert active_total == 1
        assert inactive_total == 1


class TestModelVersionActivation:
    """Tests for model version activation."""

    def test_activate_model_version(self, patched_session, sample_model_version):
        """Test activating a model version."""
        repo = ModelVersionRepository()

        model = repo.activate(str(sample_model_version.version_id))

        assert model is not None
        assert model.is_active is True
        assert model.status == "active"
        assert model.activated_at is not None

    def test_activate_deactivates_others(self, patched_session):
        """Test that activating one version deactivates others."""
        repo = ModelVersionRepository()
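        # activate() is expected to enforce a single-active invariant:
        # activating the second version implicitly deactivates the first.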

        # Create and activate first model
        m1 = repo.create(version="1.0.0", name="Model 1", model_path="/models/m1.pt")
        repo.activate(str(m1.version_id))

        # Create and activate second model
        m2 = repo.create(version="2.0.0", name="Model 2", model_path="/models/m2.pt")
        repo.activate(str(m2.version_id))

        # Check first model is now inactive
        m1_after = repo.get(str(m1.version_id))
        assert m1_after.is_active is False
        assert m1_after.status == "inactive"

        # Check second model is active
        m2_after = repo.get(str(m2.version_id))
        assert m2_after.is_active is True

    def test_get_active_model(self, patched_session, sample_model_version):
        """Test getting the currently active model."""
        repo = ModelVersionRepository()

        # Initially no active model
        active = repo.get_active()
        assert active is None

        # Activate model
        repo.activate(str(sample_model_version.version_id))

        # Now should return active model
        active = repo.get_active()
        assert active is not None
        assert active.version_id == sample_model_version.version_id

    def test_deactivate_model_version(self, patched_session, sample_model_version):
        """Test deactivating a model version."""
        repo = ModelVersionRepository()

        # First activate
        repo.activate(str(sample_model_version.version_id))

        # Then deactivate
        model = repo.deactivate(str(sample_model_version.version_id))

        assert model is not None
        assert model.is_active is False
        assert model.status == "inactive"


class TestModelVersionUpdate:
    """Tests for model version updates."""

    def test_update_model_metadata(self, patched_session, sample_model_version):
        """Test updating model version metadata."""
        repo = ModelVersionRepository()

        model = repo.update(
            str(sample_model_version.version_id),
            name="Updated Model Name",
            description="Updated description",
        )

        assert model is not None
        assert model.name == "Updated Model Name"
        assert model.description == "Updated description"

    def test_update_model_status(self, patched_session, sample_model_version):
        """Test updating model version status."""
        repo = ModelVersionRepository()

        model = repo.update(str(sample_model_version.version_id), status="deprecated")

        assert model is not None
        assert model.status == "deprecated"

    def test_update_nonexistent_model(self, patched_session):
        """Test updating a model that doesn't exist."""
        repo = ModelVersionRepository()

        model = repo.update(str(uuid4()), name="New Name")
        assert model is None


class TestModelVersionArchive:
    """Tests for model version archiving."""

    def test_archive_model_version(self, patched_session, sample_model_version):
        """Test archiving an inactive model version."""
        repo = ModelVersionRepository()

        model = repo.archive(str(sample_model_version.version_id))

        assert model is not None
        assert model.status == "archived"

    def test_cannot_archive_active_model(self, patched_session, sample_model_version):
        """Test that an active model cannot be archived."""
        repo = ModelVersionRepository()

        # Activate the model
        repo.activate(str(sample_model_version.version_id))

        # Try to archive
        model = repo.archive(str(sample_model_version.version_id))

        assert model is None

        # Verify model is still active
        current = repo.get(str(sample_model_version.version_id))
        assert current.status == "active"


class TestModelVersionDelete:
    """Tests for model version deletion."""

    def test_delete_inactive_model(self, patched_session, sample_model_version):
        """Test deleting an inactive model version."""
        repo = ModelVersionRepository()

        result = repo.delete(str(sample_model_version.version_id))

        assert result is True

        model = repo.get(str(sample_model_version.version_id))
        assert model is None

    def test_cannot_delete_active_model(self, patched_session, sample_model_version):
        """Test that an active model cannot be deleted."""
        repo = ModelVersionRepository()
||||
|
||||
# Activate the model
|
||||
repo.activate(str(sample_model_version.version_id))
|
||||
|
||||
# Try to delete
|
||||
result = repo.delete(str(sample_model_version.version_id))
|
||||
|
||||
assert result is False
|
||||
|
||||
# Verify model still exists
|
||||
model = repo.get(str(sample_model_version.version_id))
|
||||
assert model is not None
|
||||
|
||||
def test_delete_nonexistent_model(self, patched_session):
|
||||
"""Test deleting model that doesn't exist."""
|
||||
repo = ModelVersionRepository()
|
||||
|
||||
result = repo.delete(str(uuid4()))
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestOnlyOneActiveModel:
|
||||
"""Tests to verify only one model can be active at a time."""
|
||||
|
||||
def test_single_active_model_constraint(self, patched_session):
|
||||
"""Test that only one model can be active at any time."""
|
||||
repo = ModelVersionRepository()
|
||||
|
||||
# Create multiple models
|
||||
models = []
|
||||
for i in range(3):
|
||||
m = repo.create(
|
||||
version=f"1.{i}.0",
|
||||
name=f"Model {i}",
|
||||
model_path=f"/models/model_{i}.pt",
|
||||
)
|
||||
models.append(m)
|
||||
|
||||
# Activate each model in sequence
|
||||
for model in models:
|
||||
repo.activate(str(model.version_id))
|
||||
|
||||
# Count active models
|
||||
all_models, _ = repo.get_paginated(status="active")
|
||||
assert len(all_models) == 1
|
||||
|
||||
# Verify it's the last one activated
|
||||
assert all_models[0].version_id == models[-1].version_id
|
||||
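
# The activation tests above pin down an invariant: at most one ModelVersion row is
# active at any time, and activating one version flips every other one to "inactive".
# A minimal standalone sketch of that invariant over plain dicts; the real
# ModelVersionRepository persists rows, and the names here are illustrative assumptions:
def _single_active_sketch(versions: dict[str, dict], version_id: str) -> None:
    for vid, v in versions.items():
        v["is_active"] = vid == version_id
        v["status"] = "active" if vid == version_id else "inactive"
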
274
tests/integration/repositories/test_token_repo_integration.py
Normal file
@@ -0,0 +1,274 @@
"""
Token Repository Integration Tests

Tests TokenRepository with real database operations.
"""

from datetime import datetime, timezone, timedelta

import pytest

from inference.data.repositories.token_repository import TokenRepository


class TestTokenCreate:
    """Tests for token creation."""

    def test_create_new_token(self, patched_session):
        """Test creating a new admin token."""
        repo = TokenRepository()

        repo.create(
            token="new-test-token-abc123",
            name="New Test Admin",
        )

        token = repo.get("new-test-token-abc123")
        assert token is not None
        assert token.token == "new-test-token-abc123"
        assert token.name == "New Test Admin"
        assert token.is_active is True
        assert token.expires_at is None

    def test_create_token_with_expiration(self, patched_session):
        """Test creating token with expiration date."""
        repo = TokenRepository()
        expiry = datetime.now(timezone.utc) + timedelta(days=30)

        repo.create(
            token="expiring-token-xyz789",
            name="Expiring Token",
            expires_at=expiry,
        )

        token = repo.get("expiring-token-xyz789")
        assert token is not None
        assert token.expires_at is not None

    def test_create_updates_existing_token(self, patched_session, admin_token):
        """Test creating with existing token updates it."""
        repo = TokenRepository()
        new_expiry = datetime.now(timezone.utc) + timedelta(days=60)

        repo.create(
            token=admin_token.token,
            name="Updated Admin Name",
            expires_at=new_expiry,
        )

        token = repo.get(admin_token.token)
        assert token is not None
        assert token.name == "Updated Admin Name"
        assert token.is_active is True


class TestTokenValidation:
    """Tests for token validation."""

    def test_is_valid_active_token(self, patched_session, admin_token):
        """Test that active token is valid."""
        repo = TokenRepository()

        result = repo.is_valid(admin_token.token)

        assert result is True

    def test_is_valid_nonexistent_token(self, patched_session):
        """Test that nonexistent token is invalid."""
        repo = TokenRepository()

        result = repo.is_valid("nonexistent-token-12345")

        assert result is False

    def test_is_valid_deactivated_token(self, patched_session, admin_token):
        """Test that deactivated token is invalid."""
        repo = TokenRepository()

        repo.deactivate(admin_token.token)
        result = repo.is_valid(admin_token.token)

        assert result is False

    def test_is_valid_expired_token(self, patched_session):
        """Test that expired token is invalid."""
        repo = TokenRepository()
        past_expiry = datetime.now(timezone.utc) - timedelta(days=1)

        repo.create(
            token="expired-token-test",
            name="Expired Token",
            expires_at=past_expiry,
        )

        result = repo.is_valid("expired-token-test")

        assert result is False

    def test_is_valid_not_yet_expired_token(self, patched_session):
        """Test that not-yet-expired token is valid."""
        repo = TokenRepository()
        future_expiry = datetime.now(timezone.utc) + timedelta(days=7)

        repo.create(
            token="valid-expiring-token",
            name="Valid Expiring Token",
            expires_at=future_expiry,
        )

        result = repo.is_valid("valid-expiring-token")

        assert result is True
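
# The validation tests above imply the rule: a token is valid iff it exists, is
# active, and is either non-expiring or not yet expired. A standalone sketch of
# that check (duck-typed row; the real TokenRepository.is_valid may differ):
def _token_valid_sketch(row, now) -> bool:
    if row is None or not row.is_active:
        return False
    return row.expires_at is None or row.expires_at > now
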
class TestTokenGet:
    """Tests for token retrieval."""

    def test_get_existing_token(self, patched_session, admin_token):
        """Test getting an existing token."""
        repo = TokenRepository()

        token = repo.get(admin_token.token)

        assert token is not None
        assert token.token == admin_token.token
        assert token.name == admin_token.name

    def test_get_nonexistent_token(self, patched_session):
        """Test getting a token that doesn't exist."""
        repo = TokenRepository()

        token = repo.get("nonexistent-token-xyz")

        assert token is None


class TestTokenDeactivate:
    """Tests for token deactivation."""

    def test_deactivate_existing_token(self, patched_session, admin_token):
        """Test deactivating an existing token."""
        repo = TokenRepository()

        result = repo.deactivate(admin_token.token)

        assert result is True
        token = repo.get(admin_token.token)
        assert token is not None
        assert token.is_active is False

    def test_deactivate_nonexistent_token(self, patched_session):
        """Test deactivating a token that doesn't exist."""
        repo = TokenRepository()

        result = repo.deactivate("nonexistent-token-abc")

        assert result is False

    def test_reactivate_deactivated_token(self, patched_session, admin_token):
        """Test reactivating a deactivated token via create."""
        repo = TokenRepository()

        # Deactivate first
        repo.deactivate(admin_token.token)
        assert repo.is_valid(admin_token.token) is False

        # Reactivate via create
        repo.create(
            token=admin_token.token,
            name="Reactivated Admin",
        )

        assert repo.is_valid(admin_token.token) is True


class TestTokenUsageTracking:
    """Tests for token usage tracking."""

    def test_update_usage(self, patched_session, admin_token):
        """Test updating token last used timestamp."""
        repo = TokenRepository()

        # Initially last_used_at might be None
        initial_token = repo.get(admin_token.token)
        initial_last_used = initial_token.last_used_at

        repo.update_usage(admin_token.token)

        updated_token = repo.get(admin_token.token)
        assert updated_token.last_used_at is not None
        if initial_last_used:
            assert updated_token.last_used_at >= initial_last_used

    def test_update_usage_nonexistent_token(self, patched_session):
        """Test updating usage for nonexistent token does nothing."""
        repo = TokenRepository()

        # Should not raise, just does nothing
        repo.update_usage("nonexistent-token-usage")

        token = repo.get("nonexistent-token-usage")
        assert token is None
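
# As exercised above, update_usage() just stamps last_used_at and silently ignores
# unknown tokens. An illustrative sketch over a dict store; field names are assumed:
def _update_usage_sketch(store: dict[str, dict], token: str) -> None:
    row = store.get(token)
    if row is not None:
        row["last_used_at"] = datetime.now(timezone.utc)
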
class TestTokenWorkflow:
    """Tests for complete token workflows."""

    def test_full_token_lifecycle(self, patched_session):
        """Test complete token lifecycle: create, validate, use, deactivate."""
        repo = TokenRepository()
        token_str = "lifecycle-test-token"

        # 1. Create token
        repo.create(token=token_str, name="Lifecycle Token")
        assert repo.is_valid(token_str) is True

        # 2. Use token
        repo.update_usage(token_str)
        token = repo.get(token_str)
        assert token.last_used_at is not None

        # 3. Update token info
        new_expiry = datetime.now(timezone.utc) + timedelta(days=90)
        repo.create(
            token=token_str,
            name="Updated Lifecycle Token",
            expires_at=new_expiry,
        )
        token = repo.get(token_str)
        assert token.name == "Updated Lifecycle Token"

        # 4. Deactivate token
        result = repo.deactivate(token_str)
        assert result is True
        assert repo.is_valid(token_str) is False

        # 5. Reactivate token
        repo.create(token=token_str, name="Reactivated Token")
        assert repo.is_valid(token_str) is True

    def test_multiple_tokens(self, patched_session):
        """Test managing multiple tokens."""
        repo = TokenRepository()

        # Create multiple tokens
        tokens = [
            ("token-a", "Admin A"),
            ("token-b", "Admin B"),
            ("token-c", "Admin C"),
        ]

        for token_str, name in tokens:
            repo.create(token=token_str, name=name)

        # Verify all are valid
        for token_str, _ in tokens:
            assert repo.is_valid(token_str) is True

        # Deactivate one
        repo.deactivate("token-b")

        # Verify states
        assert repo.is_valid("token-a") is True
        assert repo.is_valid("token-b") is False
        assert repo.is_valid("token-c") is True
@@ -0,0 +1,364 @@
"""
Training Task Repository Integration Tests

Tests TrainingTaskRepository with real database operations.
"""

from datetime import datetime, timezone, timedelta
from uuid import uuid4

import pytest

from inference.data.repositories.training_task_repository import TrainingTaskRepository


class TestTrainingTaskCreate:
    """Tests for training task creation."""

    def test_create_training_task(self, patched_session, admin_token):
        """Test creating a training task."""
        repo = TrainingTaskRepository()

        task_id = repo.create(
            admin_token=admin_token.token,
            name="Test Training Task",
            task_type="train",
            description="Integration test training task",
            config={"epochs": 100, "batch_size": 16},
        )

        assert task_id is not None

        task = repo.get(task_id)
        assert task is not None
        assert task.name == "Test Training Task"
        assert task.task_type == "train"
        assert task.status == "pending"
        assert task.config["epochs"] == 100

    def test_create_scheduled_task(self, patched_session, admin_token):
        """Test creating a scheduled training task."""
        repo = TrainingTaskRepository()

        scheduled_time = datetime.now(timezone.utc) + timedelta(hours=1)

        task_id = repo.create(
            admin_token=admin_token.token,
            name="Scheduled Task",
            scheduled_at=scheduled_time,
        )

        task = repo.get(task_id)
        assert task is not None
        assert task.status == "scheduled"
        assert task.scheduled_at is not None

    def test_create_recurring_task(self, patched_session, admin_token):
        """Test creating a recurring training task."""
        repo = TrainingTaskRepository()

        task_id = repo.create(
            admin_token=admin_token.token,
            name="Recurring Task",
            cron_expression="0 2 * * *",
            is_recurring=True,
        )

        task = repo.get(task_id)
        assert task is not None
        assert task.is_recurring is True
        assert task.cron_expression == "0 2 * * *"

    def test_create_task_with_dataset(self, patched_session, admin_token, sample_dataset):
        """Test creating task linked to a dataset."""
        repo = TrainingTaskRepository()

        task_id = repo.create(
            admin_token=admin_token.token,
            name="Dataset Training Task",
            dataset_id=str(sample_dataset.dataset_id),
        )

        task = repo.get(task_id)
        assert task is not None
        assert task.dataset_id == sample_dataset.dataset_id


class TestTrainingTaskRead:
    """Tests for training task retrieval."""

    def test_get_task_by_id(self, patched_session, sample_training_task):
        """Test getting task by ID."""
        repo = TrainingTaskRepository()

        task = repo.get(str(sample_training_task.task_id))

        assert task is not None
        assert task.task_id == sample_training_task.task_id

    def test_get_nonexistent_task(self, patched_session):
        """Test getting task that doesn't exist."""
        repo = TrainingTaskRepository()

        task = repo.get(str(uuid4()))
        assert task is None

    def test_get_paginated_tasks(self, patched_session, admin_token):
        """Test paginated task listing."""
        repo = TrainingTaskRepository()

        # Create multiple tasks
        for i in range(5):
            repo.create(admin_token=admin_token.token, name=f"Task {i}")

        tasks, total = repo.get_paginated(limit=2, offset=0)

        assert total == 5
        assert len(tasks) == 2

    def test_get_paginated_with_status_filter(self, patched_session, admin_token):
        """Test filtering tasks by status."""
        repo = TrainingTaskRepository()

        # Create tasks with different statuses
        task_id = repo.create(admin_token=admin_token.token, name="Running Task")
        repo.update_status(task_id, "running")

        repo.create(admin_token=admin_token.token, name="Pending Task")

        tasks, total = repo.get_paginated(status="running")

        assert total == 1
        assert tasks[0].status == "running"

    def test_get_pending_tasks(self, patched_session, admin_token):
        """Test getting pending tasks ready to run."""
        repo = TrainingTaskRepository()

        # Create pending task
        repo.create(admin_token=admin_token.token, name="Ready Task")

        # Create scheduled task in the past (should be included)
        past_time = datetime.now(timezone.utc) - timedelta(hours=1)
        repo.create(
            admin_token=admin_token.token,
            name="Past Scheduled Task",
            scheduled_at=past_time,
        )

        # Create scheduled task in the future (should not be included)
        future_time = datetime.now(timezone.utc) + timedelta(hours=1)
        repo.create(
            admin_token=admin_token.token,
            name="Future Scheduled Task",
            scheduled_at=future_time,
        )

        pending = repo.get_pending()

        # Should include pending and past scheduled, not future scheduled
        assert len(pending) >= 2
        names = [t.name for t in pending]
        assert "Ready Task" in names
        assert "Past Scheduled Task" in names

    def test_get_running_task(self, patched_session, admin_token):
        """Test getting currently running task."""
        repo = TrainingTaskRepository()

        task_id = repo.create(admin_token=admin_token.token, name="Running Task")
        repo.update_status(task_id, "running")

        running = repo.get_running()

        assert running is not None
        assert running.status == "running"

    def test_get_running_task_none(self, patched_session, admin_token):
        """Test getting running task when none is running."""
        repo = TrainingTaskRepository()

        repo.create(admin_token=admin_token.token, name="Pending Task")

        running = repo.get_running()
        assert running is None
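
# test_get_pending_tasks pins down the selection rule: include "pending" tasks plus
# "scheduled" tasks whose scheduled_at lies in the past. A standalone sketch of that
# filter; the real get_pending() presumably runs an equivalent SQL query:
def _pending_filter_sketch(tasks, now):
    return [
        t
        for t in tasks
        if t.status == "pending"
        or (t.status == "scheduled" and t.scheduled_at is not None and t.scheduled_at <= now)
    ]
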
class TestTrainingTaskUpdate:
    """Tests for training task updates."""

    def test_update_status_to_running(self, patched_session, sample_training_task):
        """Test updating task status to running."""
        repo = TrainingTaskRepository()

        repo.update_status(str(sample_training_task.task_id), "running")

        task = repo.get(str(sample_training_task.task_id))
        assert task is not None
        assert task.status == "running"
        assert task.started_at is not None

    def test_update_status_to_completed(self, patched_session, sample_training_task):
        """Test updating task status to completed."""
        repo = TrainingTaskRepository()

        metrics = {"mAP": 0.92, "precision": 0.89, "recall": 0.85}

        repo.update_status(
            str(sample_training_task.task_id),
            "completed",
            result_metrics=metrics,
            model_path="/models/trained_model.pt",
        )

        task = repo.get(str(sample_training_task.task_id))
        assert task is not None
        assert task.status == "completed"
        assert task.completed_at is not None
        assert task.result_metrics["mAP"] == 0.92
        assert task.model_path == "/models/trained_model.pt"

    def test_update_status_to_failed(self, patched_session, sample_training_task):
        """Test updating task status to failed with error message."""
        repo = TrainingTaskRepository()

        repo.update_status(
            str(sample_training_task.task_id),
            "failed",
            error_message="CUDA out of memory",
        )

        task = repo.get(str(sample_training_task.task_id))
        assert task is not None
        assert task.status == "failed"
        assert task.completed_at is not None
        assert "CUDA out of memory" in task.error_message

    def test_cancel_pending_task(self, patched_session, sample_training_task):
        """Test cancelling a pending task."""
        repo = TrainingTaskRepository()

        result = repo.cancel(str(sample_training_task.task_id))

        assert result is True

        task = repo.get(str(sample_training_task.task_id))
        assert task is not None
        assert task.status == "cancelled"

    def test_cannot_cancel_running_task(self, patched_session, sample_training_task):
        """Test that running task cannot be cancelled."""
        repo = TrainingTaskRepository()

        repo.update_status(str(sample_training_task.task_id), "running")

        result = repo.cancel(str(sample_training_task.task_id))

        assert result is False

        task = repo.get(str(sample_training_task.task_id))
        assert task.status == "running"


class TestTrainingLogs:
    """Tests for training log management."""

    def test_add_log_entry(self, patched_session, sample_training_task):
        """Test adding a training log entry."""
        repo = TrainingTaskRepository()

        repo.add_log(
            str(sample_training_task.task_id),
            level="INFO",
            message="Starting training...",
            details={"epoch": 1, "batch": 0},
        )

        logs = repo.get_logs(str(sample_training_task.task_id))
        assert len(logs) == 1
        assert logs[0].level == "INFO"
        assert logs[0].message == "Starting training..."

    def test_add_multiple_log_entries(self, patched_session, sample_training_task):
        """Test adding multiple log entries."""
        repo = TrainingTaskRepository()

        for i in range(5):
            repo.add_log(
                str(sample_training_task.task_id),
                level="INFO",
                message=f"Epoch {i} completed",
                details={"epoch": i, "loss": 0.5 - i * 0.1},
            )

        logs = repo.get_logs(str(sample_training_task.task_id))
        assert len(logs) == 5

    def test_get_logs_pagination(self, patched_session, sample_training_task):
        """Test paginated log retrieval."""
        repo = TrainingTaskRepository()

        for i in range(10):
            repo.add_log(
                str(sample_training_task.task_id),
                level="INFO",
                message=f"Log entry {i}",
            )

        logs = repo.get_logs(str(sample_training_task.task_id), limit=5, offset=0)
        assert len(logs) == 5

        logs_page2 = repo.get_logs(str(sample_training_task.task_id), limit=5, offset=5)
        assert len(logs_page2) == 5


class TestDocumentLinks:
    """Tests for training document link management."""

    def test_create_document_link(self, patched_session, sample_training_task, sample_document):
        """Test creating a document link."""
        repo = TrainingTaskRepository()

        link = repo.create_document_link(
            task_id=sample_training_task.task_id,
            document_id=sample_document.document_id,
            annotation_snapshot={"count": 5, "verified": 3},
        )

        assert link is not None
        assert link.task_id == sample_training_task.task_id
        assert link.document_id == sample_document.document_id
        assert link.annotation_snapshot["count"] == 5

    def test_get_document_links(self, patched_session, sample_training_task, multiple_documents):
        """Test getting all document links for a task."""
        repo = TrainingTaskRepository()

        for doc in multiple_documents[:3]:
            repo.create_document_link(
                task_id=sample_training_task.task_id,
                document_id=doc.document_id,
            )

        links = repo.get_document_links(sample_training_task.task_id)
        assert len(links) == 3

    def test_get_document_training_tasks(self, patched_session, admin_token, sample_document):
        """Test getting training tasks that used a document."""
        repo = TrainingTaskRepository()

        # Create multiple tasks using the same document
        task1_id = repo.create(admin_token=admin_token.token, name="Task 1")
        task2_id = repo.create(admin_token=admin_token.token, name="Task 2")

        repo.create_document_link(
            task_id=repo.get(task1_id).task_id,
            document_id=sample_document.document_id,
        )
        repo.create_document_link(
            task_id=repo.get(task2_id).task_id,
            document_id=sample_document.document_id,
        )

        links = repo.get_document_training_tasks(sample_document.document_id)
        assert len(links) == 2
1
tests/integration/services/__init__.py
Normal file
@@ -0,0 +1 @@
"""Service integration tests."""
497
tests/integration/services/test_dashboard_service_integration.py
Normal file
@@ -0,0 +1,497 @@
"""
Dashboard Service Integration Tests

Tests DashboardStatsService and DashboardActivityService with real database operations.
"""

from datetime import datetime, timezone
from uuid import uuid4

import pytest

from inference.data.admin_models import (
    AdminAnnotation,
    AdminDocument,
    AnnotationHistory,
    ModelVersion,
    TrainingDataset,
    TrainingTask,
)
from inference.web.services.dashboard_service import (
    DashboardStatsService,
    DashboardActivityService,
    is_annotation_complete,
    IDENTIFIER_CLASS_IDS,
    PAYMENT_CLASS_IDS,
)


class TestIsAnnotationComplete:
    """Tests for is_annotation_complete function."""

    def test_complete_with_invoice_number_and_bankgiro(self):
        """Test complete with invoice_number (0) and bankgiro (4)."""
        annotations = [
            {"class_id": 0},  # invoice_number
            {"class_id": 4},  # bankgiro
        ]
        assert is_annotation_complete(annotations) is True

    def test_complete_with_ocr_number_and_plusgiro(self):
        """Test complete with ocr_number (3) and plusgiro (5)."""
        annotations = [
            {"class_id": 3},  # ocr_number
            {"class_id": 5},  # plusgiro
        ]
        assert is_annotation_complete(annotations) is True

    def test_incomplete_missing_identifier(self):
        """Test incomplete when missing identifier."""
        annotations = [
            {"class_id": 4},  # bankgiro only
        ]
        assert is_annotation_complete(annotations) is False

    def test_incomplete_missing_payment(self):
        """Test incomplete when missing payment."""
        annotations = [
            {"class_id": 0},  # invoice_number only
        ]
        assert is_annotation_complete(annotations) is False

    def test_incomplete_empty_annotations(self):
        """Test incomplete with empty annotations."""
        assert is_annotation_complete([]) is False

    def test_complete_with_multiple_fields(self):
        """Test complete with multiple fields."""
        annotations = [
            {"class_id": 0},  # invoice_number
            {"class_id": 1},  # invoice_date
            {"class_id": 3},  # ocr_number
            {"class_id": 4},  # bankgiro
            {"class_id": 5},  # plusgiro
            {"class_id": 6},  # amount
        ]
        assert is_annotation_complete(annotations) is True
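
# The cases above fully determine the rule: a document is complete when it has at
# least one identifier field (invoice_number=0 or ocr_number=3) AND at least one
# payment field (bankgiro=4 or plusgiro=5); this is presumably what the imported
# IDENTIFIER_CLASS_IDS and PAYMENT_CLASS_IDS encode. A standalone reference sketch:
def _annotation_complete_sketch(annotations: list[dict]) -> bool:
    identifier_ids = {0, 3}  # invoice_number, ocr_number
    payment_ids = {4, 5}  # bankgiro, plusgiro
    seen = {a["class_id"] for a in annotations}
    return bool(seen & identifier_ids) and bool(seen & payment_ids)
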
class TestDashboardStatsService:
    """Tests for DashboardStatsService."""

    def test_get_stats_empty_database(self, patched_session):
        """Test stats with empty database."""
        service = DashboardStatsService()

        stats = service.get_stats()

        assert stats["total_documents"] == 0
        assert stats["annotation_complete"] == 0
        assert stats["annotation_incomplete"] == 0
        assert stats["pending"] == 0
        assert stats["completeness_rate"] == 0.0

    def test_get_stats_with_documents(self, patched_session, admin_token):
        """Test stats with various document states."""
        service = DashboardStatsService()
        session = patched_session

        # Create documents with different statuses
        docs = []
        for i, status in enumerate(["pending", "auto_labeling", "labeled", "labeled", "exported"]):
            doc = AdminDocument(
                document_id=uuid4(),
                admin_token=admin_token.token,
                filename=f"doc_{i}.pdf",
                file_size=1024,
                content_type="application/pdf",
                file_path=f"/uploads/doc_{i}.pdf",
                page_count=1,
                status=status,
                upload_source="ui",
                category="invoice",
                created_at=datetime.now(timezone.utc),
                updated_at=datetime.now(timezone.utc),
            )
            session.add(doc)
            docs.append(doc)
        session.commit()

        stats = service.get_stats()

        assert stats["total_documents"] == 5
        assert stats["pending"] == 2  # pending + auto_labeling

    def test_get_stats_complete_annotations(self, patched_session, admin_token):
        """Test completeness calculation with proper annotations."""
        service = DashboardStatsService()
        session = patched_session

        # Create a labeled document with complete annotations
        doc = AdminDocument(
            document_id=uuid4(),
            admin_token=admin_token.token,
            filename="complete_doc.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/uploads/complete_doc.pdf",
            page_count=1,
            status="labeled",
            upload_source="ui",
            category="invoice",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(doc)
        session.commit()

        # Add identifier annotation (invoice_number = class_id 0)
        ann1 = AdminAnnotation(
            annotation_id=uuid4(),
            document_id=doc.document_id,
            page_number=1,
            class_id=0,
            class_name="invoice_number",
            x_center=0.5,
            y_center=0.1,
            width=0.2,
            height=0.05,
            bbox_x=400,
            bbox_y=80,
            bbox_width=160,
            bbox_height=40,
            text_value="INV-001",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(ann1)

        # Add payment annotation (bankgiro = class_id 4)
        ann2 = AdminAnnotation(
            annotation_id=uuid4(),
            document_id=doc.document_id,
            page_number=1,
            class_id=4,
            class_name="bankgiro",
            x_center=0.5,
            y_center=0.2,
            width=0.2,
            height=0.05,
            bbox_x=400,
            bbox_y=160,
            bbox_width=160,
            bbox_height=40,
            text_value="123-4567",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(ann2)
        session.commit()

        stats = service.get_stats()

        assert stats["annotation_complete"] == 1
        assert stats["annotation_incomplete"] == 0
        assert stats["completeness_rate"] == 100.0

    def test_get_stats_incomplete_annotations(self, patched_session, admin_token):
        """Test completeness with incomplete annotations."""
        service = DashboardStatsService()
        session = patched_session

        # Create a labeled document missing payment annotation
        doc = AdminDocument(
            document_id=uuid4(),
            admin_token=admin_token.token,
            filename="incomplete_doc.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/uploads/incomplete_doc.pdf",
            page_count=1,
            status="labeled",
            upload_source="ui",
            category="invoice",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(doc)
        session.commit()

        # Add only identifier annotation (missing payment)
        ann = AdminAnnotation(
            annotation_id=uuid4(),
            document_id=doc.document_id,
            page_number=1,
            class_id=0,
            class_name="invoice_number",
            x_center=0.5,
            y_center=0.1,
            width=0.2,
            height=0.05,
            bbox_x=400,
            bbox_y=80,
            bbox_width=160,
            bbox_height=40,
            text_value="INV-001",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(ann)
        session.commit()

        stats = service.get_stats()

        assert stats["annotation_complete"] == 0
        assert stats["annotation_incomplete"] == 1
        assert stats["completeness_rate"] == 0.0

    def test_get_stats_mixed_completeness(self, patched_session, admin_token):
        """Test stats with mix of complete and incomplete documents."""
        service = DashboardStatsService()
        session = patched_session

        # Create 2 labeled documents
        docs = []
        for i in range(2):
            doc = AdminDocument(
                document_id=uuid4(),
                admin_token=admin_token.token,
                filename=f"mixed_doc_{i}.pdf",
                file_size=1024,
                content_type="application/pdf",
                file_path=f"/uploads/mixed_doc_{i}.pdf",
                page_count=1,
                status="labeled",
                upload_source="ui",
                category="invoice",
                created_at=datetime.now(timezone.utc),
                updated_at=datetime.now(timezone.utc),
            )
            session.add(doc)
            docs.append(doc)
        session.commit()

        # First document: complete (has identifier + payment)
        session.add(AdminAnnotation(
            annotation_id=uuid4(),
            document_id=docs[0].document_id,
            page_number=1,
            class_id=0,  # invoice_number
            class_name="invoice_number",
            x_center=0.5, y_center=0.1, width=0.2, height=0.05,
            bbox_x=400, bbox_y=80, bbox_width=160, bbox_height=40,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        ))
        session.add(AdminAnnotation(
            annotation_id=uuid4(),
            document_id=docs[0].document_id,
            page_number=1,
            class_id=4,  # bankgiro
            class_name="bankgiro",
            x_center=0.5, y_center=0.2, width=0.2, height=0.05,
            bbox_x=400, bbox_y=160, bbox_width=160, bbox_height=40,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        ))

        # Second document: incomplete (missing payment)
        session.add(AdminAnnotation(
            annotation_id=uuid4(),
            document_id=docs[1].document_id,
            page_number=1,
            class_id=0,  # invoice_number only
            class_name="invoice_number",
            x_center=0.5, y_center=0.1, width=0.2, height=0.05,
            bbox_x=400, bbox_y=80, bbox_width=160, bbox_height=40,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        ))
        session.commit()

        stats = service.get_stats()

        assert stats["annotation_complete"] == 1
        assert stats["annotation_incomplete"] == 1
        assert stats["completeness_rate"] == 50.0
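
# The assertions above fix completeness_rate as complete / (complete + incomplete)
# expressed as a percentage, with 0.0 when nothing is labeled (empty database: 0.0;
# 1 of 1: 100.0; 1 of 2: 50.0). A standalone sketch of that arithmetic:
def _completeness_rate_sketch(complete: int, incomplete: int) -> float:
    labeled = complete + incomplete
    return 100.0 * complete / labeled if labeled else 0.0
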
class TestDashboardActivityService:
    """Tests for DashboardActivityService."""

    def test_get_recent_activities_empty(self, patched_session):
        """Test activities with empty database."""
        service = DashboardActivityService()

        activities = service.get_recent_activities()

        assert activities == []

    def test_get_recent_activities_document_uploads(self, patched_session, admin_token):
        """Test activities include document uploads."""
        service = DashboardActivityService()
        session = patched_session

        # Create documents
        for i in range(3):
            doc = AdminDocument(
                document_id=uuid4(),
                admin_token=admin_token.token,
                filename=f"activity_doc_{i}.pdf",
                file_size=1024,
                content_type="application/pdf",
                file_path=f"/uploads/activity_doc_{i}.pdf",
                page_count=1,
                status="pending",
                upload_source="ui",
                category="invoice",
                created_at=datetime.now(timezone.utc),
                updated_at=datetime.now(timezone.utc),
            )
            session.add(doc)
        session.commit()

        activities = service.get_recent_activities()

        upload_activities = [a for a in activities if a["type"] == "document_uploaded"]
        assert len(upload_activities) == 3

    def test_get_recent_activities_annotation_overrides(self, patched_session, sample_document, sample_annotation):
        """Test activities include annotation overrides."""
        service = DashboardActivityService()
        session = patched_session

        # Create annotation history with override
        history = AnnotationHistory(
            history_id=uuid4(),
            annotation_id=sample_annotation.annotation_id,
            document_id=sample_document.document_id,
            action="override",
            previous_value={"text_value": "OLD-001"},
            new_value={"text_value": "NEW-001", "class_name": "invoice_number"},
            changed_by="test-admin",
            created_at=datetime.now(timezone.utc),
        )
        session.add(history)
        session.commit()

        activities = service.get_recent_activities()

        override_activities = [a for a in activities if a["type"] == "annotation_modified"]
        assert len(override_activities) >= 1

    def test_get_recent_activities_training_completed(self, patched_session, sample_training_task):
        """Test activities include training completions."""
        service = DashboardActivityService()
        session = patched_session

        # Update training task to completed
        sample_training_task.status = "completed"
        sample_training_task.metrics_mAP = 0.85
        sample_training_task.updated_at = datetime.now(timezone.utc)
        session.add(sample_training_task)
        session.commit()

        activities = service.get_recent_activities()

        training_activities = [a for a in activities if a["type"] == "training_completed"]
        assert len(training_activities) >= 1
        assert "mAP" in training_activities[0]["metadata"]

    def test_get_recent_activities_training_failed(self, patched_session, sample_training_task):
        """Test activities include training failures."""
        service = DashboardActivityService()
        session = patched_session

        # Update training task to failed
        sample_training_task.status = "failed"
        sample_training_task.error_message = "CUDA out of memory"
        sample_training_task.updated_at = datetime.now(timezone.utc)
        session.add(sample_training_task)
        session.commit()

        activities = service.get_recent_activities()

        failed_activities = [a for a in activities if a["type"] == "training_failed"]
        assert len(failed_activities) >= 1
        assert failed_activities[0]["metadata"]["error"] == "CUDA out of memory"

    def test_get_recent_activities_model_activated(self, patched_session, sample_model_version):
        """Test activities include model activations."""
        service = DashboardActivityService()
        session = patched_session

        # Activate model
        sample_model_version.is_active = True
        sample_model_version.activated_at = datetime.now(timezone.utc)
        session.add(sample_model_version)
        session.commit()

        activities = service.get_recent_activities()

        activation_activities = [a for a in activities if a["type"] == "model_activated"]
        assert len(activation_activities) >= 1
        assert activation_activities[0]["metadata"]["version"] == sample_model_version.version

    def test_get_recent_activities_limit(self, patched_session, admin_token):
        """Test activity limit parameter."""
        service = DashboardActivityService()
        session = patched_session

        # Create many documents
        for i in range(20):
            doc = AdminDocument(
                document_id=uuid4(),
                admin_token=admin_token.token,
                filename=f"limit_doc_{i}.pdf",
                file_size=1024,
                content_type="application/pdf",
                file_path=f"/uploads/limit_doc_{i}.pdf",
                page_count=1,
                status="pending",
                upload_source="ui",
                category="invoice",
                created_at=datetime.now(timezone.utc),
                updated_at=datetime.now(timezone.utc),
            )
            session.add(doc)
        session.commit()

        activities = service.get_recent_activities(limit=5)

        assert len(activities) <= 5

    def test_get_recent_activities_sorted_by_timestamp(self, patched_session, admin_token, sample_training_task):
        """Test activities are sorted by timestamp descending."""
        service = DashboardActivityService()
        session = patched_session

        # Create document
        doc = AdminDocument(
            document_id=uuid4(),
            admin_token=admin_token.token,
            filename="sorted_doc.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/uploads/sorted_doc.pdf",
            page_count=1,
            status="pending",
            upload_source="ui",
            category="invoice",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(doc)

        # Complete training task
        sample_training_task.status = "completed"
        sample_training_task.metrics_mAP = 0.90
        sample_training_task.updated_at = datetime.now(timezone.utc)
        session.add(sample_training_task)
        session.commit()

        activities = service.get_recent_activities()

        # Verify sorted by timestamp DESC
        timestamps = [a["timestamp"] for a in activities]
        assert timestamps == sorted(timestamps, reverse=True)
453
tests/integration/services/test_dataset_builder_integration.py
Normal file
@@ -0,0 +1,453 @@
"""
Dataset Builder Service Integration Tests

Tests DatasetBuilder with real file operations and repository interactions.
"""

import shutil
from datetime import datetime, timezone
from pathlib import Path
from uuid import uuid4

import pytest
import yaml

from inference.data.admin_models import AdminAnnotation, AdminDocument
from inference.data.repositories.annotation_repository import AnnotationRepository
from inference.data.repositories.dataset_repository import DatasetRepository
from inference.data.repositories.document_repository import DocumentRepository
from inference.web.services.dataset_builder import DatasetBuilder


@pytest.fixture
def dataset_builder(patched_session, temp_dataset_dir):
    """Create a DatasetBuilder with real repositories."""
    return DatasetBuilder(
        datasets_repo=DatasetRepository(),
        documents_repo=DocumentRepository(),
        annotations_repo=AnnotationRepository(),
        base_dir=temp_dataset_dir,
    )


@pytest.fixture
def admin_images_dir(temp_upload_dir):
    """Create a directory for admin images."""
    images_dir = temp_upload_dir / "admin_images"
    images_dir.mkdir(parents=True, exist_ok=True)
    return images_dir


@pytest.fixture
def documents_with_annotations(patched_session, db_session, admin_token, admin_images_dir):
    """Create documents with annotations and corresponding image files."""
    documents = []
    doc_repo = DocumentRepository()
    ann_repo = AnnotationRepository()

    for i in range(5):
        # Create document
        doc_id = doc_repo.create(
            filename=f"invoice_{i}.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path=f"/uploads/invoice_{i}.pdf",
            page_count=2,
            category="invoice",
            group_key=f"group_{i % 2}",  # Two groups
        )

        # Create image files for each page
        doc_dir = admin_images_dir / doc_id
        doc_dir.mkdir(parents=True, exist_ok=True)

        for page in range(1, 3):
            image_path = doc_dir / f"page_{page}.png"
            # Create a minimal fake PNG
            image_path.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100)

        # Create annotations
        for j in range(3):
            ann_repo.create(
                document_id=doc_id,
                page_number=1,
                class_id=j,
                class_name=f"field_{j}",
                x_center=0.5,
                y_center=0.1 + j * 0.2,
                width=0.2,
                height=0.05,
                bbox_x=400,
                bbox_y=80 + j * 160,
                bbox_width=160,
                bbox_height=40,
                text_value=f"value_{j}",
                confidence=0.95,
                source="auto",
            )

        doc = doc_repo.get(doc_id)
        documents.append(doc)

    return documents


class TestDatasetBuilderBasic:
    """Tests for basic dataset building operations."""

    def test_build_dataset_creates_directory_structure(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that building creates proper directory structure."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Test Dataset")

        doc_ids = [str(d.document_id) for d in documents_with_annotations]

        dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=admin_images_dir,
        )

        dataset_dir = temp_dataset_dir / str(dataset.dataset_id)

        # Check directory structure
        assert (dataset_dir / "images" / "train").exists()
        assert (dataset_dir / "images" / "val").exists()
        assert (dataset_dir / "images" / "test").exists()
        assert (dataset_dir / "labels" / "train").exists()
        assert (dataset_dir / "labels" / "val").exists()
        assert (dataset_dir / "labels" / "test").exists()

    def test_build_dataset_copies_images(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that images are copied to dataset directory."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Image Copy Test")

        doc_ids = [str(d.document_id) for d in documents_with_annotations]

        result = dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=admin_images_dir,
        )

        dataset_dir = temp_dataset_dir / str(dataset.dataset_id)

        # Count total images across all splits
        total_images = 0
        for split in ["train", "val", "test"]:
            images = list((dataset_dir / "images" / split).glob("*.png"))
            total_images += len(images)

        # 5 docs * 2 pages = 10 images
        assert total_images == 10
        assert result["total_images"] == 10

    def test_build_dataset_generates_labels(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that YOLO label files are generated."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Label Generation Test")

        doc_ids = [str(d.document_id) for d in documents_with_annotations]

        dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=admin_images_dir,
        )

        dataset_dir = temp_dataset_dir / str(dataset.dataset_id)

        # Count total label files
        total_labels = 0
        for split in ["train", "val", "test"]:
            labels = list((dataset_dir / "labels" / split).glob("*.txt"))
            total_labels += len(labels)

        # Same count as images
        assert total_labels == 10

    def test_build_dataset_generates_data_yaml(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that data.yaml is generated correctly."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="YAML Generation Test")

        doc_ids = [str(d.document_id) for d in documents_with_annotations]

        dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=admin_images_dir,
        )

        dataset_dir = temp_dataset_dir / str(dataset.dataset_id)
        yaml_path = dataset_dir / "data.yaml"

        assert yaml_path.exists()

        with open(yaml_path) as f:
            data = yaml.safe_load(f)

        assert data["train"] == "images/train"
        assert data["val"] == "images/val"
        assert data["test"] == "images/test"
        assert "nc" in data
        assert "names" in data
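
# The data.yaml assertions above imply a file with split paths relative to the
# dataset root plus class metadata. A sketch of how such a file could be written;
# the helper name and class_names argument are illustrative assumptions:
def _write_data_yaml_sketch(dataset_dir: Path, class_names: list[str]) -> None:
    payload = {
        "train": "images/train",
        "val": "images/val",
        "test": "images/test",
        "nc": len(class_names),  # number of classes
        "names": class_names,
    }
    (dataset_dir / "data.yaml").write_text(yaml.safe_dump(payload))
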
class TestDatasetBuilderSplits:
    """Tests for train/val/test split assignment."""

    def test_split_ratio_respected(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that split ratios are approximately respected."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Split Ratio Test")

        doc_ids = [str(d.document_id) for d in documents_with_annotations]

        dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.6,
            val_ratio=0.2,
            seed=42,
            admin_images_dir=admin_images_dir,
        )

        # Check document assignments in database
        dataset_docs = dataset_repo.get_documents(str(dataset.dataset_id))

        splits = {"train": 0, "val": 0, "test": 0}
        for doc in dataset_docs:
            splits[doc.split] += 1

        # With 5 docs and ratios 0.6/0.2/0.2, expect ~3/1/1
        # Due to rounding and group constraints, allow some variation
        assert splits["train"] >= 2
        assert splits["val"] >= 1 or splits["test"] >= 1

    def test_same_seed_same_split(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that same seed produces same split."""
        dataset_repo = DatasetRepository()
        doc_ids = [str(d.document_id) for d in documents_with_annotations]

        # Build first dataset
        dataset1 = dataset_repo.create(name="Seed Test 1")
        dataset_builder.build_dataset(
            dataset_id=str(dataset1.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=12345,
            admin_images_dir=admin_images_dir,
        )

        # Build second dataset with same seed
        dataset2 = dataset_repo.create(name="Seed Test 2")
        dataset_builder.build_dataset(
            dataset_id=str(dataset2.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=12345,
            admin_images_dir=admin_images_dir,
        )

        # Compare splits
        docs1 = {str(d.document_id): d.split for d in dataset_repo.get_documents(str(dataset1.dataset_id))}
        docs2 = {str(d.document_id): d.split for d in dataset_repo.get_documents(str(dataset2.dataset_id))}

        assert docs1 == docs2
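
# The determinism test above (same seed, same split) together with the fixture's
# group_key suggests group-level assignment from a seeded RNG. A standalone sketch
# under that assumption; the real DatasetBuilder may partition differently:
import random

def _split_groups_sketch(
    group_keys: list[str], train_ratio: float, val_ratio: float, seed: int
) -> dict[str, str]:
    rng = random.Random(seed)  # seeded RNG makes the shuffle reproducible
    groups = sorted(set(group_keys))
    rng.shuffle(groups)
    n_train = int(len(groups) * train_ratio)
    n_val = int(len(groups) * val_ratio)
    return {
        g: "train" if i < n_train else "val" if i < n_train + n_val else "test"
        for i, g in enumerate(groups)
    }
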
class TestDatasetBuilderDatabase:
|
||||
"""Tests for database interactions."""
|
||||
|
||||
def test_updates_dataset_status(
|
||||
self, dataset_builder, documents_with_annotations, admin_images_dir, patched_session
|
||||
):
|
||||
"""Test that dataset status is updated after build."""
|
||||
dataset_repo = DatasetRepository()
|
||||
dataset = dataset_repo.create(name="Status Update Test")
|
||||
|
||||
doc_ids = [str(d.document_id) for d in documents_with_annotations]
|
||||
|
||||
dataset_builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=doc_ids,
|
||||
train_ratio=0.8,
|
||||
val_ratio=0.1,
|
||||
seed=42,
|
||||
admin_images_dir=admin_images_dir,
|
||||
)
|
||||
|
||||
updated = dataset_repo.get(str(dataset.dataset_id))
|
||||
|
||||
assert updated.status == "ready"
|
||||
assert updated.total_documents == 5
|
||||
assert updated.total_images == 10
|
||||
assert updated.total_annotations > 0
|
||||
assert updated.dataset_path is not None
|
||||
|
||||
def test_records_document_assignments(
|
||||
self, dataset_builder, documents_with_annotations, admin_images_dir, patched_session
|
||||
):
|
||||
"""Test that document assignments are recorded in database."""
|
||||
dataset_repo = DatasetRepository()
|
||||
dataset = dataset_repo.create(name="Assignment Recording Test")
|
||||
|
||||
doc_ids = [str(d.document_id) for d in documents_with_annotations]
|
||||
|
||||
dataset_builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=doc_ids,
|
||||
train_ratio=0.8,
|
||||
val_ratio=0.1,
|
||||
seed=42,
|
||||
admin_images_dir=admin_images_dir,
|
||||
)
|
||||
|
||||
dataset_docs = dataset_repo.get_documents(str(dataset.dataset_id))
|
||||
|
||||
assert len(dataset_docs) == 5
|
||||
|
||||
for doc in dataset_docs:
|
||||
assert doc.split in ["train", "val", "test"]
|
||||
assert doc.page_count > 0
|
||||
|
||||
|
||||
class TestDatasetBuilderErrors:
|
||||
"""Tests for error handling."""
|
||||
|
||||
def test_fails_with_no_documents(self, dataset_builder, admin_images_dir, patched_session):
|
||||
"""Test that building fails with empty document list."""
|
||||
dataset_repo = DatasetRepository()
|
||||
dataset = dataset_repo.create(name="Empty Docs Test")
|
||||
|
||||
with pytest.raises(ValueError, match="No valid documents"):
|
||||
dataset_builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=[],
|
||||
train_ratio=0.8,
|
||||
val_ratio=0.1,
|
||||
seed=42,
|
||||
admin_images_dir=admin_images_dir,
|
||||
)
|
||||
|
||||
def test_fails_with_invalid_doc_ids(self, dataset_builder, admin_images_dir, patched_session):
|
||||
"""Test that building fails with nonexistent document IDs."""
|
||||
dataset_repo = DatasetRepository()
|
||||
dataset = dataset_repo.create(name="Invalid IDs Test")
|
||||
|
||||
fake_ids = [str(uuid4()) for _ in range(3)]
|
||||
|
||||
with pytest.raises(ValueError, match="No valid documents"):
|
||||
dataset_builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=fake_ids,
|
||||
train_ratio=0.8,
|
||||
val_ratio=0.1,
|
||||
seed=42,
|
||||
admin_images_dir=admin_images_dir,
|
||||
)
|
||||
|
||||
def test_updates_status_on_failure(self, dataset_builder, admin_images_dir, patched_session):
|
||||
"""Test that dataset status is set to failed on error."""
|
||||
dataset_repo = DatasetRepository()
|
||||
dataset = dataset_repo.create(name="Failure Status Test")
|
||||
|
||||
try:
|
||||
dataset_builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=[],
|
||||
train_ratio=0.8,
|
||||
val_ratio=0.1,
|
||||
seed=42,
|
||||
admin_images_dir=admin_images_dir,
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
updated = dataset_repo.get(str(dataset.dataset_id))
|
||||
assert updated.status == "failed"
|
||||
assert updated.error_message is not None
|
||||
|
||||
|
||||
class TestLabelFileFormat:
|
||||
"""Tests for YOLO label file format."""
|
||||
|
||||
def test_label_file_format(
|
||||
self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
|
||||
):
|
||||
"""Test that label files are in correct YOLO format."""
|
||||
dataset_repo = DatasetRepository()
|
||||
dataset = dataset_repo.create(name="Label Format Test")
|
||||
|
||||
doc_ids = [str(d.document_id) for d in documents_with_annotations]
|
||||
|
||||
dataset_builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=doc_ids,
|
||||
train_ratio=0.8,
|
||||
val_ratio=0.1,
|
||||
seed=42,
|
||||
admin_images_dir=admin_images_dir,
|
||||
)
|
||||
|
||||
dataset_dir = temp_dataset_dir / str(dataset.dataset_id)
|
||||
|
||||
# Find a label file with content
|
||||
label_files = []
|
||||
for split in ["train", "val", "test"]:
|
||||
label_files.extend(list((dataset_dir / "labels" / split).glob("*.txt")))
|
||||
|
||||
# Check at least one label file has correct format
|
||||
found_valid_label = False
|
||||
for label_file in label_files:
|
||||
content = label_file.read_text().strip()
|
||||
if content:
|
||||
lines = content.split("\n")
|
||||
for line in lines:
|
||||
parts = line.split()
|
||||
assert len(parts) == 5, f"Expected 5 parts, got {len(parts)}: {line}"
|
||||
|
||||
class_id = int(parts[0])
|
||||
x_center = float(parts[1])
|
||||
y_center = float(parts[2])
|
||||
width = float(parts[3])
|
||||
height = float(parts[4])
|
||||
|
||||
assert 0 <= class_id < 10
|
||||
assert 0 <= x_center <= 1
|
||||
assert 0 <= y_center <= 1
|
||||
assert 0 <= width <= 1
|
||||
assert 0 <= height <= 1
|
||||
|
||||
found_valid_label = True
|
||||
break
|
||||
|
||||
assert found_valid_label, "No valid label files found"
|
||||
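For reference, each line of a YOLO label file encodes one box as "class_id x_center y_center width height", with every value except class_id normalized by the image size; parsing one line mirrors the assertions above:

# A plausible label line (values are illustrative, not from a real dataset)
sample_line = "3 0.512 0.204 0.310 0.045"

class_id, *coords = sample_line.split()
assert int(class_id) in range(10)                    # matches the 0 <= class_id < 10 check
assert all(0.0 <= float(c) <= 1.0 for c in coords)   # normalized coordinates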

283
tests/integration/services/test_document_service_integration.py
Normal file
@@ -0,0 +1,283 @@
"""
Document Service Integration Tests

Tests DocumentService with real service-layer operations against an in-memory storage backend.
"""

import pytest

from inference.web.services.document_service import DocumentService


class MockStorageBackend:
    """Simple in-memory storage backend for testing."""

    def __init__(self):
        self._files: dict[str, bytes] = {}

    def upload_bytes(self, content: bytes, remote_path: str, overwrite: bool = False) -> None:
        if not overwrite and remote_path in self._files:
            raise FileExistsError(f"File already exists: {remote_path}")
        self._files[remote_path] = content

    def download_bytes(self, remote_path: str) -> bytes:
        if remote_path not in self._files:
            raise FileNotFoundError(f"File not found: {remote_path}")
        return self._files[remote_path]

    def get_presigned_url(self, remote_path: str, expires_in_seconds: int = 3600) -> str:
        return f"https://storage.example.com/{remote_path}?expires={expires_in_seconds}"

    def exists(self, remote_path: str) -> bool:
        return remote_path in self._files

    def delete(self, remote_path: str) -> bool:
        if remote_path in self._files:
            del self._files[remote_path]
            return True
        return False

    def list_files(self, prefix: str) -> list[str]:
        return [path for path in self._files if path.startswith(prefix)]
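The backend can be sanity-checked on its own before it is handed to the fixtures below; a quick check using only the methods defined above:

backend = MockStorageBackend()
backend.upload_bytes(b"hello", "documents/a.pdf")
assert backend.exists("documents/a.pdf")
assert backend.download_bytes("documents/a.pdf") == b"hello"
assert backend.list_files("documents/") == ["documents/a.pdf"]
assert backend.delete("documents/a.pdf") is True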

@pytest.fixture
def mock_storage():
    """Create a mock storage backend."""
    return MockStorageBackend()


@pytest.fixture
def document_service(mock_storage):
    """Create a DocumentService with mock storage."""
    return DocumentService(storage_backend=mock_storage)


class TestDocumentUpload:
    """Tests for document upload operations."""

    def test_upload_document(self, document_service):
        """Test uploading a document."""
        content = b"%PDF-1.4 test content"
        filename = "test_invoice.pdf"

        result = document_service.upload_document(content, filename)

        assert result is not None
        assert result.id is not None
        assert result.filename == filename
        assert result.file_path.startswith("documents/")
        assert result.file_path.endswith(".pdf")

    def test_upload_document_with_custom_id(self, document_service):
        """Test uploading with a custom document ID."""
        content = b"%PDF-1.4 test content"
        filename = "invoice.pdf"
        custom_id = "custom-doc-12345"

        result = document_service.upload_document(
            content, filename, document_id=custom_id
        )

        assert result.id == custom_id
        assert custom_id in result.file_path

    def test_upload_preserves_extension(self, document_service):
        """Test that the file extension is preserved (and lowercased)."""
        cases = [
            ("document.pdf", ".pdf"),
            ("image.PNG", ".png"),
            ("file.JPEG", ".jpeg"),
            ("noextension", ""),
        ]

        for filename, expected_ext in cases:
            result = document_service.upload_document(b"content", filename)
            if expected_ext:
                assert result.file_path.endswith(expected_ext)
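The image.PNG -> .png case above implies the service lowercases suffixes when it builds file paths; a sketch of that normalization (hypothetical helper, not the service's actual implementation):

from pathlib import Path

def normalized_suffix(filename: str) -> str:
    """Lowercased extension, or an empty string when there is none."""
    return Path(filename).suffix.lower()

assert normalized_suffix("image.PNG") == ".png"
assert normalized_suffix("noextension") == ""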

    def test_upload_document_overwrite(self, document_service, mock_storage):
        """Test that upload overwrites an existing file."""
        content1 = b"original content"
        content2 = b"new content"
        doc_id = "overwrite-test"

        document_service.upload_document(content1, "doc.pdf", document_id=doc_id)
        document_service.upload_document(content2, "doc.pdf", document_id=doc_id)

        # The stored file should hold the new content
        remote_path = f"documents/{doc_id}.pdf"
        stored_content = mock_storage.download_bytes(remote_path)
        assert stored_content == content2


class TestDocumentDownload:
    """Tests for document download operations."""

    def test_download_document(self, document_service, mock_storage):
        """Test downloading a document."""
        content = b"test document content"
        remote_path = "documents/test-doc.pdf"
        mock_storage.upload_bytes(content, remote_path)

        downloaded = document_service.download_document(remote_path)

        assert downloaded == content

    def test_download_nonexistent_document(self, document_service):
        """Test downloading a document that doesn't exist."""
        with pytest.raises(FileNotFoundError):
            document_service.download_document("documents/nonexistent.pdf")


class TestDocumentUrl:
    """Tests for document URL generation."""

    def test_get_document_url(self, document_service, mock_storage):
        """Test getting a presigned URL for a document."""
        remote_path = "documents/test-doc.pdf"
        mock_storage.upload_bytes(b"content", remote_path)

        url = document_service.get_document_url(remote_path, expires_in_seconds=7200)

        assert url.startswith("https://")
        assert remote_path in url
        assert "7200" in url

    def test_get_document_url_default_expiry(self, document_service):
        """Test the default URL expiry of one hour."""
        url = document_service.get_document_url("documents/doc.pdf")

        assert "3600" in url


class TestDocumentExists:
    """Tests for document existence checks."""

    def test_document_exists(self, document_service, mock_storage):
        """Test checking that an uploaded document exists."""
        remote_path = "documents/existing.pdf"
        mock_storage.upload_bytes(b"content", remote_path)

        assert document_service.document_exists(remote_path) is True

    def test_document_not_exists(self, document_service):
        """Test checking a nonexistent document."""
        assert document_service.document_exists("documents/nonexistent.pdf") is False


class TestDocumentDelete:
    """Tests for document deletion."""

    def test_delete_document(self, document_service, mock_storage):
        """Test deleting a document."""
        remote_path = "documents/to-delete.pdf"
        mock_storage.upload_bytes(b"content", remote_path)

        result = document_service.delete_document_files(remote_path)

        assert result is True
        assert document_service.document_exists(remote_path) is False

    def test_delete_nonexistent_document(self, document_service):
        """Test deleting a document that doesn't exist."""
        result = document_service.delete_document_files("documents/nonexistent.pdf")

        assert result is False


class TestPageImages:
    """Tests for page image operations."""

    def test_save_page_image(self, document_service, mock_storage):
        """Test saving a page image."""
        doc_id = "test-doc-123"
        page_num = 1
        image_content = b"\x89PNG\r\n\x1a\n fake png"

        remote_path = document_service.save_page_image(doc_id, page_num, image_content)

        assert remote_path == f"images/{doc_id}/page_{page_num}.png"
        assert mock_storage.exists(remote_path)

    def test_save_multiple_page_images(self, document_service, mock_storage):
        """Test saving images for multiple pages."""
        doc_id = "multi-page-doc"

        for page_num in range(1, 4):
            content = f"page {page_num} content".encode()
            document_service.save_page_image(doc_id, page_num, content)

        images = document_service.list_document_images(doc_id)
        assert len(images) == 3

    def test_get_page_image(self, document_service, mock_storage):
        """Test downloading a page image."""
        doc_id = "test-doc"
        page_num = 2
        image_content = b"image data"

        document_service.save_page_image(doc_id, page_num, image_content)
        downloaded = document_service.get_page_image(doc_id, page_num)

        assert downloaded == image_content

    def test_get_page_image_url(self, document_service):
        """Test getting a URL for a page image."""
        doc_id = "test-doc"
        page_num = 1

        url = document_service.get_page_image_url(doc_id, page_num)

        assert f"images/{doc_id}/page_{page_num}.png" in url

    def test_list_document_images(self, document_service, mock_storage):
        """Test listing all images for a document."""
        doc_id = "list-test-doc"

        for i in range(5):
            document_service.save_page_image(doc_id, i + 1, f"page {i}".encode())

        images = document_service.list_document_images(doc_id)

        assert len(images) == 5

    def test_delete_document_images(self, document_service, mock_storage):
        """Test deleting all images for a document."""
        doc_id = "delete-images-doc"

        for i in range(3):
            document_service.save_page_image(doc_id, i + 1, b"content")

        deleted_count = document_service.delete_document_images(doc_id)

        assert deleted_count == 3
        assert len(document_service.list_document_images(doc_id)) == 0


class TestRoundTrip:
    """Tests for complete upload-download cycles."""

    def test_document_round_trip(self, document_service):
        """Test uploading and then downloading a document."""
        original_content = b"%PDF-1.4 complete document content here"
        filename = "roundtrip.pdf"

        result = document_service.upload_document(original_content, filename)
        downloaded = document_service.download_document(result.file_path)

        assert downloaded == original_content

    def test_image_round_trip(self, document_service):
        """Test saving and retrieving a page image."""
        doc_id = "roundtrip-doc"
        page_num = 1
        original_image = b"\x89PNG fake image data"

        document_service.save_page_image(doc_id, page_num, original_image)
        retrieved = document_service.get_page_image(doc_id, page_num)

        assert retrieved == original_image
258
tests/integration/test_database_setup.py
Normal file
@@ -0,0 +1,258 @@
"""
Database Setup Integration Tests

Tests for database connection, session management, and basic operations.
"""

import pytest
from sqlmodel import Session, select

from inference.data.admin_models import AdminDocument, AdminToken


class TestDatabaseConnection:
    """Tests for the database engine and connection."""

    def test_engine_connection(self, test_engine):
        """Verify the database engine can establish a connection."""
        with test_engine.connect() as conn:
            result = conn.execute(select(1))
            assert result.scalar() == 1

    def test_tables_created(self, test_engine):
        """Verify all expected tables are created."""
        from sqlmodel import SQLModel

        table_names = SQLModel.metadata.tables.keys()

        expected_tables = [
            "admin_tokens",
            "admin_documents",
            "admin_annotations",
            "training_tasks",
            "training_logs",
            "batch_uploads",
            "batch_upload_files",
            "training_datasets",
            "dataset_documents",
            "training_document_links",
            "model_versions",
        ]

        for table in expected_tables:
            assert table in table_names, f"Table '{table}' not found"
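The test_engine fixture comes from conftest and is not part of this file; a minimal sketch of the shape such a fixture usually takes (an assumption, not the project's actual conftest):

import pytest
from sqlmodel import SQLModel, create_engine

@pytest.fixture
def test_engine():
    engine = create_engine("sqlite://")   # in-memory SQLite, isolated per test
    SQLModel.metadata.create_all(engine)  # creates all of the tables listed above
    return engine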

class TestSessionManagement:
    """Tests for the database session context manager."""

    def test_session_commit(self, db_session):
        """Verify the session commits changes successfully."""
        token = AdminToken(
            token="commit-test-token",
            name="Commit Test",
            is_active=True,
        )
        db_session.add(token)
        db_session.commit()

        result = db_session.exec(
            select(AdminToken).where(AdminToken.token == "commit-test-token")
        ).first()

        assert result is not None
        assert result.name == "Commit Test"

    def test_session_rollback_on_error(self, test_engine):
        """Verify the session rolls back on an exception."""
        session = Session(test_engine)

        try:
            token = AdminToken(
                token="rollback-test-token",
                name="Rollback Test",
                is_active=True,
            )
            session.add(token)
            session.commit()

            # Inserting a duplicate primary key should fail
            duplicate = AdminToken(
                token="rollback-test-token",  # same primary key
                name="Duplicate",
                is_active=True,
            )
            session.add(duplicate)
            session.commit()
        except Exception:
            session.rollback()
        finally:
            session.close()

        # The original record must survive the rollback
        with Session(test_engine) as verify_session:
            result = verify_session.exec(
                select(AdminToken).where(AdminToken.token == "rollback-test-token")
            ).first()
            assert result is not None
            assert result.name == "Rollback Test"
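The try/except above would also pass silently if the duplicate insert never failed; a tighter variant asserts the failure explicitly (a sketch assuming the duplicate key surfaces as SQLAlchemy's IntegrityError):

from sqlalchemy.exc import IntegrityError

with Session(test_engine) as session:
    session.add(AdminToken(token="dup-token", name="First", is_active=True))
    session.commit()
    session.add(AdminToken(token="dup-token", name="Duplicate", is_active=True))
    with pytest.raises(IntegrityError):
        session.commit()
    session.rollback()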

    def test_session_isolation(self, test_engine):
        """Verify sessions are isolated from each other."""
        session1 = Session(test_engine)
        session2 = Session(test_engine)

        try:
            # Insert in session1 without committing
            token = AdminToken(
                token="isolation-test-token",
                name="Isolation Test",
                is_active=True,
            )
            session1.add(token)
            session1.flush()

            # With proper isolation, session2 should not see uncommitted data.
            # Note: in-memory SQLite may behave differently here.
            session1.commit()

            result = session2.exec(
                select(AdminToken).where(AdminToken.token == "isolation-test-token")
            ).first()

            # After the commit, session2 should see the data
            assert result is not None

        finally:
            session1.close()
            session2.close()


class TestBasicCRUDOperations:
    """Tests for basic CRUD operations on the database."""

    def test_create_and_read_token(self, db_session):
        """Test creating and reading an admin token."""
        token = AdminToken(
            token="crud-test-token",
            name="CRUD Test",
            is_active=True,
        )
        db_session.add(token)
        db_session.commit()

        result = db_session.get(AdminToken, "crud-test-token")

        assert result is not None
        assert result.name == "CRUD Test"
        assert result.is_active is True

    def test_update_entity(self, db_session, admin_token):
        """Test updating an entity."""
        admin_token.name = "Updated Name"
        db_session.add(admin_token)
        db_session.commit()

        result = db_session.get(AdminToken, admin_token.token)

        assert result is not None
        assert result.name == "Updated Name"

    def test_delete_entity(self, db_session):
        """Test deleting an entity."""
        token = AdminToken(
            token="delete-test-token",
            name="Delete Test",
            is_active=True,
        )
        db_session.add(token)
        db_session.commit()

        db_session.delete(token)
        db_session.commit()

        result = db_session.get(AdminToken, "delete-test-token")
        assert result is None

    def test_foreign_key_constraint(self, db_session, admin_token):
        """Test that a foreign key reference to a valid token resolves correctly."""
        from uuid import uuid4

        doc = AdminDocument(
            document_id=uuid4(),
            admin_token=admin_token.token,
            filename="fk_test.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/test/fk_test.pdf",
            page_count=1,
            status="pending",
        )
        db_session.add(doc)
        db_session.commit()

        # The document should reference a valid token
        result = db_session.get(AdminDocument, doc.document_id)
        assert result is not None
        assert result.admin_token == admin_token.token


class TestQueryOperations:
    """Tests for various query operations."""

    def test_select_with_filter(self, db_session, multiple_documents):
        """Test SELECT with a WHERE clause."""
        results = db_session.exec(
            select(AdminDocument).where(AdminDocument.status == "labeled")
        ).all()

        assert len(results) == 2
        for doc in results:
            assert doc.status == "labeled"

    def test_select_with_order(self, db_session, multiple_documents):
        """Test SELECT with an ORDER BY clause."""
        results = db_session.exec(
            select(AdminDocument).order_by(AdminDocument.file_size.desc())
        ).all()

        file_sizes = [doc.file_size for doc in results]
        assert file_sizes == sorted(file_sizes, reverse=True)

    def test_select_with_limit_offset(self, db_session, multiple_documents):
        """Test SELECT with LIMIT and OFFSET."""
        results = db_session.exec(
            select(AdminDocument)
            .order_by(AdminDocument.filename)
            .offset(2)
            .limit(2)
        ).all()

        assert len(results) == 2

    def test_count_query(self, db_session, multiple_documents):
        """Test COUNT aggregation."""
        from sqlalchemy import func

        count = db_session.exec(
            select(func.count()).select_from(AdminDocument)
        ).one()

        assert count == 5

    def test_group_by_query(self, db_session, multiple_documents):
        """Test GROUP BY aggregation."""
        from sqlalchemy import func

        results = db_session.exec(
            select(
                AdminDocument.status,
                func.count(AdminDocument.document_id).label("count"),
            ).group_by(AdminDocument.status)
        ).all()

        status_counts = {row[0]: row[1] for row in results}

        assert status_counts.get("pending") == 2
        assert status_counts.get("labeled") == 2
        assert status_counts.get("exported") == 1
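When an aggregation like the one above misbehaves, the statement can be inspected without touching the database: str() compiles a SQLAlchemy/SQLModel select to generic SQL. A quick sketch:

from sqlalchemy import func
from sqlmodel import select

stmt = select(
    AdminDocument.status,
    func.count(AdminDocument.document_id).label("count"),
).group_by(AdminDocument.status)

# str() renders the statement in the default dialect for inspection
print(str(stmt))  # SELECT admin_documents.status, count(...) AS count ... GROUP BY admin_documents.status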