Add more tests

Yaojia Wang
2026-02-01 22:40:41 +01:00
parent a564ac9d70
commit 400b12a967
55 changed files with 9306 additions and 267 deletions


@@ -0,0 +1 @@
"""Integration tests for invoice-master-poc-v2."""


@@ -0,0 +1 @@
"""API integration tests."""


@@ -0,0 +1,389 @@
"""
API Integration Tests
Tests FastAPI endpoints with mocked services.
These tests verify the API layer works correctly with the service layer.
"""
import io
import tempfile
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
@dataclass
class MockServiceResult:
"""Mock result from inference service."""
document_id: str = "test-doc-123"
success: bool = True
document_type: str = "invoice"
fields: dict[str, str] = field(default_factory=lambda: {
"InvoiceNumber": "INV-2024-001",
"Amount": "1500.00",
"InvoiceDate": "2024-01-15",
"OCR": "12345678901234",
"Bankgiro": "1234-5678",
})
confidence: dict[str, float] = field(default_factory=lambda: {
"InvoiceNumber": 0.95,
"Amount": 0.92,
"InvoiceDate": 0.88,
"OCR": 0.95,
"Bankgiro": 0.90,
})
detections: list[dict[str, Any]] = field(default_factory=list)
processing_time_ms: float = 150.5
visualization_path: Path | None = None
errors: list[str] = field(default_factory=list)
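
# Illustrative: individual defaults can be overridden per test, e.g.
#   MockServiceResult(success=False, errors=["OCR timeout"])
# to exercise the API's error paths without redefining the whole mock.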


@pytest.fixture
def temp_storage_dir():
    """Create temporary storage directories."""
    with tempfile.TemporaryDirectory() as tmpdir:
        base = Path(tmpdir)
        uploads_dir = base / "uploads" / "inference"
        results_dir = base / "results"
        uploads_dir.mkdir(parents=True, exist_ok=True)
        results_dir.mkdir(parents=True, exist_ok=True)
        yield {
            "base": base,
            "uploads": uploads_dir,
            "results": results_dir,
        }


@pytest.fixture
def mock_inference_service():
    """Create a mock inference service."""
    service = MagicMock()
    service.is_initialized = True
    service.gpu_available = False
    # Create a realistic mock result
    mock_result = MockServiceResult()
    service.process_pdf.return_value = mock_result
    service.process_image.return_value = mock_result
    service.initialize.return_value = None
    return service


@pytest.fixture
def mock_storage_config(temp_storage_dir):
    """Create mock storage configuration."""
    from inference.web.config import StorageConfig

    return StorageConfig(
        upload_dir=temp_storage_dir["uploads"],
        result_dir=temp_storage_dir["results"],
        max_file_size_mb=50,
    )


@pytest.fixture
def mock_storage_helper(temp_storage_dir):
    """Create a mock storage helper."""
    helper = MagicMock()
    helper.get_uploads_base_path.return_value = temp_storage_dir["uploads"]
    helper.get_result_local_path.return_value = None
    helper.result_exists.return_value = False
    return helper


@pytest.fixture
def test_app(mock_inference_service, mock_storage_config, mock_storage_helper):
    """Create a test FastAPI application with mocked storage."""
    from inference.web.api.v1.public.inference import create_inference_router

    app = FastAPI()
    # Patch get_storage_helper to return our mock
    with patch(
        "inference.web.api.v1.public.inference.get_storage_helper",
        return_value=mock_storage_helper,
    ):
        inference_router = create_inference_router(mock_inference_service, mock_storage_config)
        app.include_router(inference_router)
    return app


@pytest.fixture
def client(test_app, mock_storage_helper):
    """Create a test client with storage helper patched."""
    with patch(
        "inference.web.api.v1.public.inference.get_storage_helper",
        return_value=mock_storage_helper,
    ):
        yield TestClient(test_app)
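
# Note: get_storage_helper is patched at router creation, in the client fixture,
# and again around each request below. The repetition is presumably defensive,
# assuming the getter may be resolved at request time rather than import time.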


class TestHealthEndpoint:
    """Tests for health check endpoint."""

    def test_health_check(self, client, mock_inference_service):
        """Test health check returns status."""
        response = client.get("/api/v1/health")
        assert response.status_code == 200
        data = response.json()
        assert "status" in data
        assert "model_loaded" in data


class TestInferenceEndpoint:
    """Tests for inference endpoint."""

    def test_infer_pdf(self, client, mock_inference_service, mock_storage_helper, temp_storage_dir):
        """Test PDF inference endpoint."""
        # Create a minimal PDF content
        pdf_content = b"%PDF-1.4\n%test\n"
        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
            )
        assert response.status_code == 200
        data = response.json()
        assert "result" in data
        assert data["result"]["success"] is True
        assert "InvoiceNumber" in data["result"]["fields"]

    def test_infer_image(self, client, mock_inference_service, mock_storage_helper):
        """Test image inference endpoint."""
        # Create minimal PNG header
        png_header = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100
        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.png", io.BytesIO(png_header), "image/png")},
            )
        assert response.status_code == 200
        data = response.json()
        assert "result" in data

    def test_infer_invalid_file_type(self, client, mock_storage_helper):
        """Test rejection of invalid file types."""
        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.txt", io.BytesIO(b"hello"), "text/plain")},
            )
        assert response.status_code == 400

    def test_infer_no_file(self, client, mock_storage_helper):
        """Test rejection when no file provided."""
        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post("/api/v1/infer")
        assert response.status_code == 422  # Validation error

    def test_infer_result_structure(self, client, mock_inference_service, mock_storage_helper):
        """Test that result has expected structure."""
        pdf_content = b"%PDF-1.4\n%test\n"
        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
            )
        data = response.json()
        result = data["result"]
        # Check required fields
        assert "document_id" in result
        assert "success" in result
        assert "fields" in result
        assert "confidence" in result
        assert "processing_time_ms" in result


class TestInferenceResultFormat:
    """Tests for inference result formatting."""

    def test_result_fields_mapped_correctly(self, client, mock_inference_service, mock_storage_helper):
        """Test that fields are mapped to API response format."""
        pdf_content = b"%PDF-1.4\n%test\n"
        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
            )
        data = response.json()
        fields = data["result"]["fields"]
        assert fields["InvoiceNumber"] == "INV-2024-001"
        assert fields["Amount"] == "1500.00"
        assert fields["InvoiceDate"] == "2024-01-15"

    def test_confidence_values_included(self, client, mock_inference_service, mock_storage_helper):
        """Test that confidence values are included."""
        pdf_content = b"%PDF-1.4\n%test\n"
        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
            )
        data = response.json()
        confidence = data["result"]["confidence"]
        assert "InvoiceNumber" in confidence
        assert confidence["InvoiceNumber"] == 0.95


class TestErrorHandling:
    """Tests for error handling in API."""

    def test_service_error_handling(self, client, mock_inference_service, mock_storage_helper):
        """Test handling of service errors."""
        mock_inference_service.process_pdf.side_effect = Exception("Processing failed")
        pdf_content = b"%PDF-1.4\n%test\n"
        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
            )
        # Should return error response
        assert response.status_code >= 400

    def test_empty_file_handling(self, client, mock_storage_helper):
        """Test handling of empty files."""
        # Empty file still has valid content type
        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", io.BytesIO(b""), "application/pdf")},
            )
        # Empty file may be processed or rejected depending on implementation
        # Just verify we get a response
        assert response.status_code in [200, 400, 422, 500]


class TestResponseFormat:
    """Tests for API response format consistency."""

    def test_success_response_format(self, client, mock_inference_service, mock_storage_helper):
        """Test successful response format."""
        pdf_content = b"%PDF-1.4\n%test\n"
        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
            )
        assert response.status_code == 200
        assert response.headers["content-type"] == "application/json"
        data = response.json()
        assert isinstance(data, dict)
        assert "result" in data

    def test_json_serialization(self, client, mock_inference_service, mock_storage_helper):
        """Test that all result fields are JSON serializable."""
        pdf_content = b"%PDF-1.4\n%test\n"
        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
            )
        # If this doesn't raise, JSON is valid
        data = response.json()
        assert data is not None


class TestDocumentIdGeneration:
    """Tests for document ID handling."""

    def test_document_id_generated(self, client, mock_inference_service, mock_storage_helper):
        """Test that document ID is generated."""
        pdf_content = b"%PDF-1.4\n%test\n"
        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
            )
        data = response.json()
        assert "document_id" in data["result"]
        assert data["result"]["document_id"] is not None

    def test_document_id_from_filename(self, client, mock_inference_service, mock_storage_helper):
        """Test document ID derived from filename."""
        pdf_content = b"%PDF-1.4\n%test\n"
        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("my_invoice_123.pdf", io.BytesIO(pdf_content), "application/pdf")},
            )
        data = response.json()
        # Document ID should be set (either from filename or generated)
        assert data["result"]["document_id"] is not None


@@ -0,0 +1,400 @@
"""
Dashboard API Integration Tests
Tests Dashboard API endpoints with real database operations via TestClient.
"""
from datetime import datetime, timezone
from uuid import uuid4
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from inference.data.admin_models import (
AdminAnnotation,
AdminDocument,
AdminToken,
AnnotationHistory,
ModelVersion,
TrainingDataset,
TrainingTask,
)
from inference.web.api.v1.admin.dashboard import create_dashboard_router
from inference.web.core.auth import get_admin_token_dep
def create_test_app(override_token_dep):
"""Create a FastAPI test application with dashboard router."""
app = FastAPI()
router = create_dashboard_router()
app.include_router(router)
# Override auth dependency
app.dependency_overrides[get_admin_token_dep] = lambda: override_token_dep
return app
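
# Note: overriding get_admin_token_dep with a lambda that returns the raw token
# string bypasses real authentication; this assumes the dashboard routes only
# read the resolved token value, not a full AdminToken row.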


class TestDashboardStatsEndpoint:
    """Tests for GET /admin/dashboard/stats endpoint."""

    def test_stats_empty_database(self, patched_session, admin_token):
        """Test stats endpoint with empty database."""
        app = create_test_app(admin_token.token)
        client = TestClient(app)
        response = client.get("/admin/dashboard/stats")
        assert response.status_code == 200
        data = response.json()
        assert data["total_documents"] == 0
        assert data["annotation_complete"] == 0
        assert data["annotation_incomplete"] == 0
        assert data["pending"] == 0
        assert data["completeness_rate"] == 0.0

    def test_stats_with_pending_documents(self, patched_session, admin_token):
        """Test stats with pending documents."""
        session = patched_session
        # Create pending documents
        for i in range(3):
            doc = AdminDocument(
                document_id=uuid4(),
                admin_token=admin_token.token,
                filename=f"pending_{i}.pdf",
                file_size=1024,
                content_type="application/pdf",
                file_path=f"/uploads/pending_{i}.pdf",
                page_count=1,
                status="pending",
                upload_source="ui",
                category="invoice",
                created_at=datetime.now(timezone.utc),
                updated_at=datetime.now(timezone.utc),
            )
            session.add(doc)
        session.commit()
        app = create_test_app(admin_token.token)
        client = TestClient(app)
        response = client.get("/admin/dashboard/stats")
        assert response.status_code == 200
        data = response.json()
        assert data["total_documents"] == 3
        assert data["pending"] == 3

    def test_stats_with_complete_annotations(self, patched_session, admin_token):
        """Test stats with complete annotations."""
        session = patched_session
        # Create labeled document with complete annotations
        doc = AdminDocument(
            document_id=uuid4(),
            admin_token=admin_token.token,
            filename="complete.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/uploads/complete.pdf",
            page_count=1,
            status="labeled",
            upload_source="ui",
            category="invoice",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(doc)
        session.commit()
        # Add identifier and payment annotations
        session.add(AdminAnnotation(
            annotation_id=uuid4(),
            document_id=doc.document_id,
            page_number=1,
            class_id=0,  # invoice_number
            class_name="invoice_number",
            x_center=0.5, y_center=0.1, width=0.2, height=0.05,
            bbox_x=400, bbox_y=80, bbox_width=160, bbox_height=40,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        ))
        session.add(AdminAnnotation(
            annotation_id=uuid4(),
            document_id=doc.document_id,
            page_number=1,
            class_id=4,  # bankgiro
            class_name="bankgiro",
            x_center=0.5, y_center=0.2, width=0.2, height=0.05,
            bbox_x=400, bbox_y=160, bbox_width=160, bbox_height=40,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        ))
        session.commit()
        app = create_test_app(admin_token.token)
        client = TestClient(app)
        response = client.get("/admin/dashboard/stats")
        assert response.status_code == 200
        data = response.json()
        assert data["annotation_complete"] == 1
        assert data["completeness_rate"] == 100.0


class TestActiveModelEndpoint:
    """Tests for GET /admin/dashboard/active-model endpoint."""

    def test_active_model_none(self, patched_session, admin_token):
        """Test active-model endpoint with no active model."""
        app = create_test_app(admin_token.token)
        client = TestClient(app)
        response = client.get("/admin/dashboard/active-model")
        assert response.status_code == 200
        data = response.json()
        assert data["model"] is None
        assert data["running_training"] is None

    def test_active_model_with_model(self, patched_session, admin_token, sample_dataset):
        """Test active-model endpoint with active model."""
        session = patched_session
        # Create training task
        task = TrainingTask(
            task_id=uuid4(),
            admin_token=admin_token.token,
            name="Test Task",
            status="completed",
            task_type="train",
            dataset_id=sample_dataset.dataset_id,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(task)
        session.commit()
        # Create active model
        model = ModelVersion(
            version_id=uuid4(),
            version="1.0.0",
            name="Test Model",
            model_path="/models/test.pt",
            status="active",
            is_active=True,
            task_id=task.task_id,
            dataset_id=sample_dataset.dataset_id,
            metrics_mAP=0.90,
            metrics_precision=0.88,
            metrics_recall=0.85,
            document_count=100,
            file_size=50000000,
            activated_at=datetime.now(timezone.utc),
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(model)
        session.commit()
        app = create_test_app(admin_token.token)
        client = TestClient(app)
        response = client.get("/admin/dashboard/active-model")
        assert response.status_code == 200
        data = response.json()
        assert data["model"] is not None
        assert data["model"]["version"] == "1.0.0"
        assert data["model"]["name"] == "Test Model"
        assert data["model"]["metrics_mAP"] == 0.90

    def test_active_model_with_running_training(self, patched_session, admin_token, sample_dataset):
        """Test active-model endpoint with running training."""
        session = patched_session
        # Create running training task
        task = TrainingTask(
            task_id=uuid4(),
            admin_token=admin_token.token,
            name="Running Task",
            status="running",
            task_type="train",
            dataset_id=sample_dataset.dataset_id,
            started_at=datetime.now(timezone.utc),
            progress=50,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(task)
        session.commit()
        app = create_test_app(admin_token.token)
        client = TestClient(app)
        response = client.get("/admin/dashboard/active-model")
        assert response.status_code == 200
        data = response.json()
        assert data["running_training"] is not None
        assert data["running_training"]["name"] == "Running Task"
        assert data["running_training"]["status"] == "running"
        assert data["running_training"]["progress"] == 50


class TestRecentActivityEndpoint:
    """Tests for GET /admin/dashboard/activity endpoint."""

    def test_activity_empty(self, patched_session, admin_token):
        """Test activity endpoint with no activities."""
        app = create_test_app(admin_token.token)
        client = TestClient(app)
        response = client.get("/admin/dashboard/activity")
        assert response.status_code == 200
        data = response.json()
        assert data["activities"] == []

    def test_activity_with_uploads(self, patched_session, admin_token):
        """Test activity includes document uploads."""
        session = patched_session
        # Create documents
        for i in range(3):
            doc = AdminDocument(
                document_id=uuid4(),
                admin_token=admin_token.token,
                filename=f"activity_{i}.pdf",
                file_size=1024,
                content_type="application/pdf",
                file_path=f"/uploads/activity_{i}.pdf",
                page_count=1,
                status="pending",
                upload_source="ui",
                category="invoice",
                created_at=datetime.now(timezone.utc),
                updated_at=datetime.now(timezone.utc),
            )
            session.add(doc)
        session.commit()
        app = create_test_app(admin_token.token)
        client = TestClient(app)
        response = client.get("/admin/dashboard/activity")
        assert response.status_code == 200
        data = response.json()
        upload_activities = [a for a in data["activities"] if a["type"] == "document_uploaded"]
        assert len(upload_activities) == 3

    def test_activity_limit_parameter(self, patched_session, admin_token):
        """Test activity limit parameter."""
        session = patched_session
        # Create many documents
        for i in range(15):
            doc = AdminDocument(
                document_id=uuid4(),
                admin_token=admin_token.token,
                filename=f"limit_{i}.pdf",
                file_size=1024,
                content_type="application/pdf",
                file_path=f"/uploads/limit_{i}.pdf",
                page_count=1,
                status="pending",
                upload_source="ui",
                category="invoice",
                created_at=datetime.now(timezone.utc),
                updated_at=datetime.now(timezone.utc),
            )
            session.add(doc)
        session.commit()
        app = create_test_app(admin_token.token)
        client = TestClient(app)
        response = client.get("/admin/dashboard/activity?limit=5")
        assert response.status_code == 200
        data = response.json()
        assert len(data["activities"]) <= 5

    def test_activity_invalid_limit(self, patched_session, admin_token):
        """Test activity with invalid limit parameter."""
        app = create_test_app(admin_token.token)
        client = TestClient(app)
        # Limit too high
        response = client.get("/admin/dashboard/activity?limit=100")
        assert response.status_code == 422
        # Limit too low
        response = client.get("/admin/dashboard/activity?limit=0")
        assert response.status_code == 422

    def test_activity_with_training_completion(self, patched_session, admin_token, sample_dataset):
        """Test activity includes training completions."""
        session = patched_session
        # Create completed training task
        task = TrainingTask(
            task_id=uuid4(),
            admin_token=admin_token.token,
            name="Completed Task",
            status="completed",
            task_type="train",
            dataset_id=sample_dataset.dataset_id,
            metrics_mAP=0.95,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(task)
        session.commit()
        app = create_test_app(admin_token.token)
        client = TestClient(app)
        response = client.get("/admin/dashboard/activity")
        assert response.status_code == 200
        data = response.json()
        training_activities = [a for a in data["activities"] if a["type"] == "training_completed"]
        assert len(training_activities) >= 1

    def test_activity_sorted_by_timestamp(self, patched_session, admin_token):
        """Test activities are sorted by timestamp descending."""
        session = patched_session
        # Create documents
        for i in range(5):
            doc = AdminDocument(
                document_id=uuid4(),
                admin_token=admin_token.token,
                filename=f"sorted_{i}.pdf",
                file_size=1024,
                content_type="application/pdf",
                file_path=f"/uploads/sorted_{i}.pdf",
                page_count=1,
                status="pending",
                upload_source="ui",
                category="invoice",
                created_at=datetime.now(timezone.utc),
                updated_at=datetime.now(timezone.utc),
            )
            session.add(doc)
        session.commit()
        app = create_test_app(admin_token.token)
        client = TestClient(app)
        response = client.get("/admin/dashboard/activity")
        assert response.status_code == 200
        data = response.json()
        timestamps = [a["timestamp"] for a in data["activities"]]
        assert timestamps == sorted(timestamps, reverse=True)


@@ -0,0 +1,465 @@
"""
Integration Test Fixtures
Provides shared fixtures for integration tests using PostgreSQL.
IMPORTANT: Integration tests MUST use Docker testcontainers for database isolation.
This ensures tests never touch the real production/development database.
Supported modes:
1. Docker testcontainers (default): Automatically starts a PostgreSQL container
2. TEST_DB_URL environment variable: Use a dedicated test database (NOT production!)
To use an external test database, set:
TEST_DB_URL=postgresql://user:password@host:port/test_dbname
"""
import os
import tempfile
from contextlib import ExitStack, contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Generator
from unittest.mock import patch
from uuid import uuid4

import pytest
from sqlmodel import Session, SQLModel, create_engine

from inference.data.admin_models import (
    AdminAnnotation,
    AdminDocument,
    AdminToken,
    AnnotationHistory,
    BatchUpload,
    BatchUploadFile,
    DatasetDocument,
    ModelVersion,
    TrainingDataset,
    TrainingDocumentLink,
    TrainingLog,
    TrainingTask,
)

# =============================================================================
# Database Fixtures
# =============================================================================


def _is_docker_available() -> bool:
    """Check if Docker is available."""
    try:
        import docker

        client = docker.from_env()
        client.ping()
        return True
    except Exception:
        return False
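
# Note: this check relies on the `docker` Python SDK, which should already be
# present as a dependency of testcontainers; an unreachable daemon is treated
# the same as a missing package.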


def _get_test_db_url() -> str | None:
    """Get test database URL from environment."""
    return os.environ.get("TEST_DB_URL")


@pytest.fixture(scope="session")
def test_engine():
    """Create a SQLAlchemy engine for testing.

    Uses one of:
    1. TEST_DB_URL environment variable (dedicated test database)
    2. Docker testcontainers (if Docker is available)

    IMPORTANT: Will NOT fall back to production database. If Docker is not
    available and TEST_DB_URL is not set, tests will fail with a clear error.

    The engine is shared across all tests in a session for efficiency.
    """
    # Try to get URL from environment first
    connection_url = _get_test_db_url()
    if connection_url:
        # Use external test database from environment
        # Warn if it looks like a production database
        if "docmaster" in connection_url and "_test" not in connection_url:
            import warnings

            warnings.warn(
                "TEST_DB_URL appears to point to a production database. "
                "Please use a dedicated test database (e.g., docmaster_test).",
                UserWarning,
            )
    elif _is_docker_available():
        # Use testcontainers - this is the recommended approach
        from testcontainers.postgres import PostgresContainer

        postgres = PostgresContainer("postgres:15-alpine")
        postgres.start()
        connection_url = postgres.get_connection_url()
        if "psycopg2" in connection_url:
            connection_url = connection_url.replace("postgresql+psycopg2://", "postgresql://")
        # Store container for cleanup
        test_engine._postgres_container = postgres
    else:
        # No Docker and no TEST_DB_URL - fail with clear instructions
        pytest.fail(
            "Integration tests require Docker or a TEST_DB_URL environment variable.\n\n"
            "Option 1 (Recommended): Install Docker Desktop and ensure it's running.\n"
            "  - Windows: https://docs.docker.com/desktop/install/windows-install/\n"
            "  - The testcontainers library will automatically create a PostgreSQL container.\n\n"
            "Option 2: Set TEST_DB_URL to a dedicated test database:\n"
            "  - export TEST_DB_URL=postgresql://user:password@host:port/test_dbname\n"
            "  - NEVER use your production database for tests!\n\n"
            "Integration tests will NOT fall back to the production database."
        )
    engine = create_engine(
        connection_url,
        echo=False,
        pool_pre_ping=True,
    )
    # Create all tables
    SQLModel.metadata.create_all(engine)
    yield engine
    # Cleanup
    SQLModel.metadata.drop_all(engine)
    engine.dispose()
    # Stop container if we started one
    if hasattr(test_engine, "_postgres_container"):
        test_engine._postgres_container.stop()


@pytest.fixture(scope="function")
def db_session(test_engine) -> Generator[Session, None, None]:
    """Provide a database session for each test function.

    Each test gets a fresh session that rolls back after the test,
    ensuring test isolation.
    """
    connection = test_engine.connect()
    transaction = connection.begin()
    session = Session(bind=connection)
    yield session
    # Rollback and cleanup
    session.close()
    transaction.rollback()
    connection.close()
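
# Isolation note (assuming SQLAlchemy 2.x join-transaction semantics): because
# the session is bound to a connection with an open outer transaction, commits
# inside tests land on a savepoint, and the rollback above discards them all.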


@pytest.fixture(scope="function")
def patched_session(db_session):
    """Patch get_session_context to use the test session.

    This allows repository classes to use the test database session
    instead of creating their own connections.

    We need to patch in multiple locations because each repository module
    imports get_session_context directly.
    """

    @contextmanager
    def mock_session_context() -> Generator[Session, None, None]:
        yield db_session

    # All modules that import get_session_context
    patch_targets = [
        "inference.data.database.get_session_context",
        "inference.data.repositories.document_repository.get_session_context",
        "inference.data.repositories.annotation_repository.get_session_context",
        "inference.data.repositories.dataset_repository.get_session_context",
        "inference.data.repositories.training_task_repository.get_session_context",
        "inference.data.repositories.model_version_repository.get_session_context",
        "inference.data.repositories.batch_upload_repository.get_session_context",
        "inference.data.repositories.token_repository.get_session_context",
        "inference.web.services.dashboard_service.get_session_context",
    ]
    with ExitStack() as stack:
        for target in patch_targets:
            try:
                stack.enter_context(patch(target, mock_session_context))
            except (ModuleNotFoundError, AttributeError):
                # Skip if module doesn't exist or doesn't have the attribute
                pass
        yield db_session
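
# Maintenance note: any new module that does
#   from inference.data.database import get_session_context
# must be added to patch_targets above, or its queries will bypass the test session.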

# =============================================================================
# Test Data Fixtures
# =============================================================================


@pytest.fixture
def admin_token(db_session) -> AdminToken:
    """Create a test admin token."""
    token = AdminToken(
        token="test-admin-token-12345",
        name="Test Admin",
        is_active=True,
        created_at=datetime.now(timezone.utc),
    )
    db_session.add(token)
    db_session.commit()
    db_session.refresh(token)
    return token


@pytest.fixture
def sample_document(db_session, admin_token) -> AdminDocument:
    """Create a sample document for testing."""
    doc = AdminDocument(
        document_id=uuid4(),
        admin_token=admin_token.token,
        filename="test_invoice.pdf",
        file_size=1024,
        content_type="application/pdf",
        file_path="/uploads/test_invoice.pdf",
        page_count=1,
        status="pending",
        upload_source="ui",
        category="invoice",
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(doc)
    db_session.commit()
    db_session.refresh(doc)
    return doc


@pytest.fixture
def sample_annotation(db_session, sample_document) -> AdminAnnotation:
    """Create a sample annotation for testing."""
    annotation = AdminAnnotation(
        annotation_id=uuid4(),
        document_id=sample_document.document_id,
        page_number=1,
        class_id=0,
        class_name="invoice_number",
        x_center=0.5,
        y_center=0.3,
        width=0.2,
        height=0.05,
        bbox_x=400,
        bbox_y=240,
        bbox_width=160,
        bbox_height=40,
        text_value="INV-2024-001",
        confidence=0.95,
        source="auto",
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(annotation)
    db_session.commit()
    db_session.refresh(annotation)
    return annotation


@pytest.fixture
def sample_dataset(db_session) -> TrainingDataset:
    """Create a sample training dataset for testing."""
    dataset = TrainingDataset(
        dataset_id=uuid4(),
        name="Test Dataset",
        description="Dataset for integration testing",
        status="building",
        train_ratio=0.8,
        val_ratio=0.1,
        seed=42,
        total_documents=0,
        total_images=0,
        total_annotations=0,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(dataset)
    db_session.commit()
    db_session.refresh(dataset)
    return dataset


@pytest.fixture
def sample_training_task(db_session, admin_token, sample_dataset) -> TrainingTask:
    """Create a sample training task for testing."""
    task = TrainingTask(
        task_id=uuid4(),
        admin_token=admin_token.token,
        name="Test Training Task",
        description="Training task for integration testing",
        status="pending",
        task_type="train",
        dataset_id=sample_dataset.dataset_id,
        config={"epochs": 10, "batch_size": 16},
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(task)
    db_session.commit()
    db_session.refresh(task)
    return task


@pytest.fixture
def sample_model_version(db_session, sample_training_task, sample_dataset) -> ModelVersion:
    """Create a sample model version for testing."""
    version = ModelVersion(
        version_id=uuid4(),
        version="1.0.0",
        name="Test Model v1",
        description="Model version for integration testing",
        model_path="/models/test_model.pt",
        status="inactive",
        is_active=False,
        task_id=sample_training_task.task_id,
        dataset_id=sample_dataset.dataset_id,
        metrics_mAP=0.85,
        metrics_precision=0.88,
        metrics_recall=0.82,
        document_count=100,
        file_size=50000000,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(version)
    db_session.commit()
    db_session.refresh(version)
    return version


@pytest.fixture
def sample_batch_upload(db_session, admin_token) -> BatchUpload:
    """Create a sample batch upload for testing."""
    batch = BatchUpload(
        batch_id=uuid4(),
        admin_token=admin_token.token,
        filename="test_batch.zip",
        file_size=10240,
        upload_source="api",
        status="processing",
        total_files=5,
        processed_files=0,
        successful_files=0,
        failed_files=0,
        created_at=datetime.now(timezone.utc),
    )
    db_session.add(batch)
    db_session.commit()
    db_session.refresh(batch)
    return batch

# =============================================================================
# Multiple Documents Fixture
# =============================================================================


@pytest.fixture
def multiple_documents(db_session, admin_token) -> list[AdminDocument]:
    """Create multiple documents for pagination/filtering tests."""
    documents = []
    statuses = ["pending", "pending", "labeled", "labeled", "exported"]
    categories = ["invoice", "invoice", "invoice", "letter", "invoice"]
    for i, (status, category) in enumerate(zip(statuses, categories)):
        doc = AdminDocument(
            document_id=uuid4(),
            admin_token=admin_token.token,
            filename=f"test_doc_{i}.pdf",
            file_size=1024 + i * 100,
            content_type="application/pdf",
            file_path=f"/uploads/test_doc_{i}.pdf",
            page_count=1,
            status=status,
            upload_source="ui",
            category=category,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        db_session.add(doc)
        documents.append(doc)
    db_session.commit()
    for doc in documents:
        db_session.refresh(doc)
    return documents

# =============================================================================
# Temporary File Fixtures
# =============================================================================


@pytest.fixture
def temp_upload_dir() -> Generator[Path, None, None]:
    """Create a temporary directory for file uploads."""
    with tempfile.TemporaryDirectory() as tmpdir:
        upload_dir = Path(tmpdir) / "uploads"
        upload_dir.mkdir(parents=True, exist_ok=True)
        yield upload_dir


@pytest.fixture
def temp_model_dir() -> Generator[Path, None, None]:
    """Create a temporary directory for model files."""
    with tempfile.TemporaryDirectory() as tmpdir:
        model_dir = Path(tmpdir) / "models"
        model_dir.mkdir(parents=True, exist_ok=True)
        yield model_dir


@pytest.fixture
def temp_dataset_dir() -> Generator[Path, None, None]:
    """Create a temporary directory for dataset files."""
    with tempfile.TemporaryDirectory() as tmpdir:
        dataset_dir = Path(tmpdir) / "datasets"
        dataset_dir.mkdir(parents=True, exist_ok=True)
        yield dataset_dir

# =============================================================================
# Sample PDF Fixture
# =============================================================================


@pytest.fixture
def sample_pdf_bytes() -> bytes:
    """Return minimal valid PDF bytes for testing."""
    # Minimal valid PDF structure
    return b"""%PDF-1.4
1 0 obj
<< /Type /Catalog /Pages 2 0 R >>
endobj
2 0 obj
<< /Type /Pages /Kids [3 0 R] /Count 1 >>
endobj
3 0 obj
<< /Type /Page /Parent 2 0 R /MediaBox [0 0 612 792] >>
endobj
xref
0 4
0000000000 65535 f
0000000009 00000 n
0000000058 00000 n
0000000115 00000 n
trailer
<< /Size 4 /Root 1 0 R >>
startxref
196
%%EOF"""
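
# Caveat: the xref offsets above are not guaranteed to be byte-exact. Lenient
# parsers accept this, but a strict consumer may need a library-generated PDF.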


@pytest.fixture
def sample_pdf_file(temp_upload_dir, sample_pdf_bytes) -> Path:
    """Create a sample PDF file for testing."""
    pdf_path = temp_upload_dir / "test_invoice.pdf"
    pdf_path.write_bytes(sample_pdf_bytes)
    return pdf_path


@@ -0,0 +1 @@
"""Pipeline integration tests."""


@@ -0,0 +1,456 @@
"""
Inference Pipeline Integration Tests
Tests the complete pipeline from input to output.
Note: These tests use mocks for YOLO and OCR to avoid requiring actual models,
but test the integration of pipeline components.
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
import numpy as np
from inference.pipeline.pipeline import (
InferencePipeline,
InferenceResult,
CrossValidationResult,
)
from inference.pipeline.yolo_detector import Detection
from inference.pipeline.field_extractor import ExtractedField
@pytest.fixture
def mock_detection():
"""Create a mock detection."""
return Detection(
class_id=0,
class_name="invoice_number",
confidence=0.95,
bbox=(100, 50, 200, 30),
page_no=0,
)
@pytest.fixture
def mock_extracted_field():
"""Create a mock extracted field."""
return ExtractedField(
field_name="InvoiceNumber",
raw_text="INV-2024-001",
normalized_value="INV-2024-001",
confidence=0.95,
bbox=(100, 50, 200, 30),
page_no=0,
is_valid=True,
)
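
# Assumed bbox convention in these fixtures: (x, y, width, height) in pixels,
# i.e. (100, 50, 200, 30) is a 200x30 box with its top-left corner at (100, 50).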


class TestInferenceResultConstruction:
    """Tests for InferenceResult construction and methods."""

    def test_default_result(self):
        """Test default InferenceResult values."""
        result = InferenceResult()
        assert result.document_id is None
        assert result.success is False
        assert result.fields == {}
        assert result.confidence == {}
        assert result.raw_detections == []
        assert result.extracted_fields == []
        assert result.errors == []
        assert result.fallback_used is False
        assert result.cross_validation is None

    def test_result_to_json(self):
        """Test JSON serialization of result."""
        result = InferenceResult(
            document_id="test-doc",
            success=True,
            fields={
                "InvoiceNumber": "INV-001",
                "Amount": "1500.00",
            },
            confidence={
                "InvoiceNumber": 0.95,
                "Amount": 0.92,
            },
            bboxes={
                "InvoiceNumber": (100, 50, 200, 30),
            },
        )
        json_data = result.to_json()
        assert json_data["DocumentId"] == "test-doc"
        assert json_data["success"] is True
        assert json_data["InvoiceNumber"] == "INV-001"
        assert json_data["Amount"] == "1500.00"
        assert json_data["confidence"]["InvoiceNumber"] == 0.95
        assert "bboxes" in json_data

    def test_result_get_field(self):
        """Test getting field value and confidence."""
        result = InferenceResult(
            fields={"InvoiceNumber": "INV-001"},
            confidence={"InvoiceNumber": 0.95},
        )
        value, conf = result.get_field("InvoiceNumber")
        assert value == "INV-001"
        assert conf == 0.95
        value, conf = result.get_field("Amount")
        assert value is None
        assert conf == 0.0


class TestCrossValidation:
    """Tests for cross-validation logic."""

    def test_cross_validation_default(self):
        """Test default CrossValidationResult values."""
        cv = CrossValidationResult()
        assert cv.is_valid is False
        assert cv.ocr_match is None
        assert cv.amount_match is None
        assert cv.bankgiro_match is None
        assert cv.plusgiro_match is None
        assert cv.payment_line_ocr is None
        assert cv.payment_line_amount is None
        assert cv.details == []

    def test_cross_validation_with_matches(self):
        """Test CrossValidationResult with matches."""
        cv = CrossValidationResult(
            is_valid=True,
            ocr_match=True,
            amount_match=True,
            bankgiro_match=True,
            payment_line_ocr="12345678901234",
            payment_line_amount="1500.00",
            payment_line_account="1234-5678",
            payment_line_account_type="bankgiro",
            details=["OCR match", "Amount match", "Bankgiro match"],
        )
        assert cv.is_valid is True
        assert cv.ocr_match is True
        assert cv.amount_match is True
        assert len(cv.details) == 3


class TestPipelineMergeFields:
    """Tests for field merging logic."""

    def test_merge_selects_highest_confidence(self):
        """Test that merge selects highest confidence for duplicate fields."""
        # Create mock pipeline with minimal mocking
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)
            pipeline.payment_line_parser = MagicMock()
            pipeline.payment_line_parser.parse.return_value = MagicMock(is_valid=False)
            result = InferenceResult()
            result.extracted_fields = [
                ExtractedField(
                    field_name="InvoiceNumber",
                    raw_text="INV-001",
                    normalized_value="INV-001",
                    confidence=0.85,
                    detection_confidence=0.90,
                    ocr_confidence=0.85,
                    bbox=(100, 50, 200, 30),
                    page_no=0,
                    is_valid=True,
                ),
                ExtractedField(
                    field_name="InvoiceNumber",
                    raw_text="INV-001",
                    normalized_value="INV-001",
                    confidence=0.95,  # Higher confidence
                    detection_confidence=0.95,
                    ocr_confidence=0.95,
                    bbox=(105, 52, 198, 28),
                    page_no=0,
                    is_valid=True,
                ),
            ]
            pipeline._merge_fields(result)
            assert result.fields["InvoiceNumber"] == "INV-001"
            assert result.confidence["InvoiceNumber"] == 0.95

    def test_merge_skips_invalid_fields(self):
        """Test that merge skips invalid extracted fields."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)
            pipeline.payment_line_parser = MagicMock()
            pipeline.payment_line_parser.parse.return_value = MagicMock(is_valid=False)
            result = InferenceResult()
            result.extracted_fields = [
                ExtractedField(
                    field_name="InvoiceNumber",
                    raw_text="",
                    normalized_value=None,
                    confidence=0.95,
                    detection_confidence=0.95,
                    ocr_confidence=0.95,
                    bbox=(100, 50, 200, 30),
                    page_no=0,
                    is_valid=False,  # Invalid
                ),
                ExtractedField(
                    field_name="Amount",
                    raw_text="1500.00",
                    normalized_value="1500.00",
                    confidence=0.92,
                    detection_confidence=0.92,
                    ocr_confidence=0.92,
                    bbox=(200, 100, 100, 25),
                    page_no=0,
                    is_valid=True,
                ),
            ]
            pipeline._merge_fields(result)
            assert "InvoiceNumber" not in result.fields
            assert result.fields["Amount"] == "1500.00"


class TestPaymentLineValidation:
    """Tests for payment line cross-validation."""

    def test_payment_line_overrides_ocr(self):
        """Test that payment line OCR overrides detected OCR."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)
            # Mock payment line parser
            mock_parsed = MagicMock()
            mock_parsed.is_valid = True
            mock_parsed.ocr_number = "12345678901234"
            mock_parsed.amount = "1500.00"
            mock_parsed.account_number = "12345678"
            pipeline.payment_line_parser = MagicMock()
            pipeline.payment_line_parser.parse.return_value = mock_parsed
            result = InferenceResult(
                fields={
                    "payment_line": "# 12345678901234 # 1500 00 5 > 12345678#41#",
                    "OCR": "99999999999999",  # Different OCR
                },
                confidence={"OCR": 0.85},
            )
            pipeline._cross_validate_payment_line(result)
            # Payment line OCR should override
            assert result.fields["OCR"] == "12345678901234"
            assert result.confidence["OCR"] == 0.95

    def test_payment_line_overrides_amount(self):
        """Test that payment line amount overrides detected amount."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)
            mock_parsed = MagicMock()
            mock_parsed.is_valid = True
            mock_parsed.ocr_number = None
            mock_parsed.amount = "2500.50"
            mock_parsed.account_number = None
            pipeline.payment_line_parser = MagicMock()
            pipeline.payment_line_parser.parse.return_value = mock_parsed
            result = InferenceResult(
                fields={
                    "payment_line": "# ... # 2500 50 5 > ...",
                    "Amount": "2500.00",  # Slightly different
                },
                confidence={"Amount": 0.80},
            )
            pipeline._cross_validate_payment_line(result)
            assert result.fields["Amount"] == "2500.50"
            assert result.confidence["Amount"] == 0.95

    def test_cross_validation_records_matches(self):
        """Test that cross-validation records match status."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)
            mock_parsed = MagicMock()
            mock_parsed.is_valid = True
            mock_parsed.ocr_number = "12345678901234"
            mock_parsed.amount = "1500.00"
            mock_parsed.account_number = "12345678"
            pipeline.payment_line_parser = MagicMock()
            pipeline.payment_line_parser.parse.return_value = mock_parsed
            result = InferenceResult(
                fields={
                    "payment_line": "# 12345678901234 # 1500 00 5 > 12345678#41#",
                    "OCR": "12345678901234",  # Matching
                    "Amount": "1500.00",  # Matching
                    "Bankgiro": "1234-5678",  # Matching
                },
                confidence={},
            )
            pipeline._cross_validate_payment_line(result)
            assert result.cross_validation is not None
            assert result.cross_validation.ocr_match is True
            assert result.cross_validation.amount_match is True
            assert result.cross_validation.is_valid is True


class TestFallbackLogic:
    """Tests for fallback detection logic."""

    def test_needs_fallback_when_key_fields_missing(self):
        """Test fallback is triggered when key fields missing."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)
            # Only one key field present
            result = InferenceResult(fields={"Amount": "1500.00"})
            assert pipeline._needs_fallback(result) is True

    def test_no_fallback_when_fields_present(self):
        """Test no fallback when key fields present."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)
            # All key fields present
            result = InferenceResult(
                fields={
                    "Amount": "1500.00",
                    "InvoiceNumber": "INV-001",
                    "OCR": "12345678901234",
                }
            )
            assert pipeline._needs_fallback(result) is False


class TestPatternExtraction:
    """Tests for fallback pattern extraction."""

    def test_extract_amount_pattern(self):
        """Test amount extraction with regex."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)
            text = "Att betala: 1 500,00 SEK"
            result = InferenceResult()
            pipeline._extract_with_patterns(text, result)
            assert "Amount" in result.fields
            assert result.confidence["Amount"] == 0.5

    def test_extract_bankgiro_pattern(self):
        """Test bankgiro extraction with regex."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)
            text = "Bankgiro: 1234-5678"
            result = InferenceResult()
            pipeline._extract_with_patterns(text, result)
            assert "Bankgiro" in result.fields
            assert result.fields["Bankgiro"] == "1234-5678"

    def test_extract_ocr_pattern(self):
        """Test OCR extraction with regex."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)
            text = "OCR: 12345678901234567890"
            result = InferenceResult()
            pipeline._extract_with_patterns(text, result)
            assert "OCR" in result.fields
            assert result.fields["OCR"] == "12345678901234567890"

    def test_does_not_override_existing_fields(self):
        """Test pattern extraction doesn't override existing fields."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)
            text = "Fakturanr: 999"
            result = InferenceResult(fields={"InvoiceNumber": "INV-001"})
            pipeline._extract_with_patterns(text, result)
            # Should keep existing value
            assert result.fields["InvoiceNumber"] == "INV-001"


class TestAmountNormalization:
    """Tests for amount normalization."""

    def test_normalize_swedish_format(self):
        """Test normalizing Swedish amount format."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)
            # Swedish format: space as thousands separator, comma as decimal
            assert pipeline._normalize_amount_for_compare("1 500,00") == 1500.00
            # Standard format: dot as decimal
            assert pipeline._normalize_amount_for_compare("1500.00") == 1500.00
            # Swedish format with comma as decimal
            assert pipeline._normalize_amount_for_compare("1500,00") == 1500.00

    def test_normalize_invalid_amount(self):
        """Test normalizing invalid amount returns None."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)
            assert pipeline._normalize_amount_for_compare("invalid") is None
            assert pipeline._normalize_amount_for_compare("") is None
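
# For reference (illustrative, not the production implementation), the expected
# normalization is roughly: "1 500,00" -> strip spaces -> "1500,00" -> comma to
# dot -> "1500.00" -> 1500.0.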


class TestResultSerialization:
    """Tests for result serialization with cross-validation."""

    def test_to_json_with_cross_validation(self):
        """Test JSON serialization includes cross-validation."""
        cv = CrossValidationResult(
            is_valid=True,
            ocr_match=True,
            amount_match=True,
            payment_line_ocr="12345678901234",
            payment_line_amount="1500.00",
            details=["OCR match", "Amount match"],
        )
        result = InferenceResult(
            document_id="test-doc",
            success=True,
            fields={"InvoiceNumber": "INV-001"},
            cross_validation=cv,
        )
        json_data = result.to_json()
        assert "cross_validation" in json_data
        assert json_data["cross_validation"]["is_valid"] is True
        assert json_data["cross_validation"]["ocr_match"] is True
        assert json_data["cross_validation"]["payment_line_ocr"] == "12345678901234"


@@ -0,0 +1 @@
"""Repository integration tests."""


@@ -0,0 +1,464 @@
"""
Annotation Repository Integration Tests
Tests AnnotationRepository with real database operations.
"""
from uuid import uuid4
import pytest
from inference.data.repositories.annotation_repository import AnnotationRepository
class TestAnnotationRepositoryCreate:
"""Tests for annotation creation."""
def test_create_annotation(self, patched_session, sample_document):
"""Test creating a single annotation."""
repo = AnnotationRepository()
ann_id = repo.create(
document_id=str(sample_document.document_id),
page_number=1,
class_id=0,
class_name="invoice_number",
x_center=0.5,
y_center=0.3,
width=0.2,
height=0.05,
bbox_x=400,
bbox_y=240,
bbox_width=160,
bbox_height=40,
text_value="INV-2024-001",
confidence=0.95,
source="auto",
)
assert ann_id is not None
ann = repo.get(ann_id)
assert ann is not None
assert ann.class_name == "invoice_number"
assert ann.text_value == "INV-2024-001"
assert ann.confidence == 0.95
assert ann.source == "auto"
def test_create_batch_annotations(self, patched_session, sample_document):
"""Test batch creation of annotations."""
repo = AnnotationRepository()
annotations_data = [
{
"document_id": str(sample_document.document_id),
"page_number": 1,
"class_id": 0,
"class_name": "invoice_number",
"x_center": 0.5,
"y_center": 0.1,
"width": 0.2,
"height": 0.05,
"bbox_x": 400,
"bbox_y": 80,
"bbox_width": 160,
"bbox_height": 40,
"text_value": "INV-001",
"confidence": 0.95,
},
{
"document_id": str(sample_document.document_id),
"page_number": 1,
"class_id": 1,
"class_name": "invoice_date",
"x_center": 0.5,
"y_center": 0.2,
"width": 0.15,
"height": 0.04,
"bbox_x": 400,
"bbox_y": 160,
"bbox_width": 120,
"bbox_height": 32,
"text_value": "2024-01-15",
"confidence": 0.92,
},
{
"document_id": str(sample_document.document_id),
"page_number": 1,
"class_id": 6,
"class_name": "amount",
"x_center": 0.7,
"y_center": 0.8,
"width": 0.1,
"height": 0.04,
"bbox_x": 560,
"bbox_y": 640,
"bbox_width": 80,
"bbox_height": 32,
"text_value": "1500.00",
"confidence": 0.98,
},
]
ids = repo.create_batch(annotations_data)
assert len(ids) == 3
# Verify all annotations exist
for ann_id in ids:
ann = repo.get(ann_id)
assert ann is not None
class TestAnnotationRepositoryRead:
"""Tests for annotation retrieval."""
def test_get_nonexistent_annotation(self, patched_session):
"""Test getting an annotation that doesn't exist."""
repo = AnnotationRepository()
ann = repo.get(str(uuid4()))
assert ann is None
def test_get_annotations_for_document(self, patched_session, sample_document, sample_annotation):
"""Test getting all annotations for a document."""
repo = AnnotationRepository()
# Add another annotation
repo.create(
document_id=str(sample_document.document_id),
page_number=1,
class_id=1,
class_name="invoice_date",
x_center=0.5,
y_center=0.4,
width=0.15,
height=0.04,
bbox_x=400,
bbox_y=320,
bbox_width=120,
bbox_height=32,
text_value="2024-01-15",
)
annotations = repo.get_for_document(str(sample_document.document_id))
assert len(annotations) == 2
# Should be ordered by class_id
assert annotations[0].class_id == 0
assert annotations[1].class_id == 1
def test_get_annotations_for_specific_page(self, patched_session, sample_document):
"""Test getting annotations for a specific page."""
repo = AnnotationRepository()
# Create annotations on different pages
repo.create(
document_id=str(sample_document.document_id),
page_number=1,
class_id=0,
class_name="invoice_number",
x_center=0.5,
y_center=0.1,
width=0.2,
height=0.05,
bbox_x=400,
bbox_y=80,
bbox_width=160,
bbox_height=40,
)
repo.create(
document_id=str(sample_document.document_id),
page_number=2,
class_id=6,
class_name="amount",
x_center=0.7,
y_center=0.8,
width=0.1,
height=0.04,
bbox_x=560,
bbox_y=640,
bbox_width=80,
bbox_height=32,
)
page1_annotations = repo.get_for_document(
str(sample_document.document_id),
page_number=1,
)
page2_annotations = repo.get_for_document(
str(sample_document.document_id),
page_number=2,
)
assert len(page1_annotations) == 1
assert len(page2_annotations) == 1
assert page1_annotations[0].page_number == 1
assert page2_annotations[0].page_number == 2
class TestAnnotationRepositoryUpdate:
"""Tests for annotation updates."""
def test_update_annotation_bbox(self, patched_session, sample_annotation):
"""Test updating annotation bounding box."""
repo = AnnotationRepository()
result = repo.update(
str(sample_annotation.annotation_id),
x_center=0.6,
y_center=0.4,
width=0.25,
height=0.06,
bbox_x=480,
bbox_y=320,
bbox_width=200,
bbox_height=48,
)
assert result is True
ann = repo.get(str(sample_annotation.annotation_id))
assert ann is not None
assert ann.x_center == 0.6
assert ann.y_center == 0.4
assert ann.bbox_x == 480
assert ann.bbox_width == 200
def test_update_annotation_text(self, patched_session, sample_annotation):
"""Test updating annotation text value."""
repo = AnnotationRepository()
result = repo.update(
str(sample_annotation.annotation_id),
text_value="INV-2024-002",
)
assert result is True
ann = repo.get(str(sample_annotation.annotation_id))
assert ann is not None
assert ann.text_value == "INV-2024-002"
def test_update_annotation_class(self, patched_session, sample_annotation):
"""Test updating annotation class."""
repo = AnnotationRepository()
result = repo.update(
str(sample_annotation.annotation_id),
class_id=1,
class_name="invoice_date",
)
assert result is True
ann = repo.get(str(sample_annotation.annotation_id))
assert ann is not None
assert ann.class_id == 1
assert ann.class_name == "invoice_date"
def test_update_nonexistent_annotation(self, patched_session):
"""Test updating annotation that doesn't exist."""
repo = AnnotationRepository()
result = repo.update(
str(uuid4()),
text_value="new value",
)
assert result is False
class TestAnnotationRepositoryDelete:
"""Tests for annotation deletion."""
def test_delete_annotation(self, patched_session, sample_annotation):
"""Test deleting a single annotation."""
repo = AnnotationRepository()
result = repo.delete(str(sample_annotation.annotation_id))
assert result is True
ann = repo.get(str(sample_annotation.annotation_id))
assert ann is None
def test_delete_nonexistent_annotation(self, patched_session):
"""Test deleting annotation that doesn't exist."""
repo = AnnotationRepository()
result = repo.delete(str(uuid4()))
assert result is False
def test_delete_annotations_for_document(self, patched_session, sample_document):
"""Test deleting all annotations for a document."""
repo = AnnotationRepository()
# Create multiple annotations
for i in range(3):
repo.create(
document_id=str(sample_document.document_id),
page_number=1,
class_id=i,
class_name=f"field_{i}",
x_center=0.5,
y_center=0.1 + i * 0.2,
width=0.2,
height=0.05,
bbox_x=400,
bbox_y=80 + i * 160,
bbox_width=160,
bbox_height=40,
)
# Delete all
count = repo.delete_for_document(str(sample_document.document_id))
assert count == 3
annotations = repo.get_for_document(str(sample_document.document_id))
assert len(annotations) == 0
def test_delete_annotations_by_source(self, patched_session, sample_document):
"""Test deleting annotations by source type."""
repo = AnnotationRepository()
# Create auto and manual annotations
repo.create(
document_id=str(sample_document.document_id),
page_number=1,
class_id=0,
class_name="invoice_number",
x_center=0.5,
y_center=0.1,
width=0.2,
height=0.05,
bbox_x=400,
bbox_y=80,
bbox_width=160,
bbox_height=40,
source="auto",
)
repo.create(
document_id=str(sample_document.document_id),
page_number=1,
class_id=1,
class_name="invoice_date",
x_center=0.5,
y_center=0.2,
width=0.15,
height=0.04,
bbox_x=400,
bbox_y=160,
bbox_width=120,
bbox_height=32,
source="manual",
)
# Delete only auto annotations
count = repo.delete_for_document(str(sample_document.document_id), source="auto")
assert count == 1
remaining = repo.get_for_document(str(sample_document.document_id))
assert len(remaining) == 1
assert remaining[0].source == "manual"
class TestAnnotationVerification:
"""Tests for annotation verification."""
def test_verify_annotation(self, patched_session, admin_token, sample_annotation):
"""Test marking annotation as verified."""
repo = AnnotationRepository()
ann = repo.verify(str(sample_annotation.annotation_id), admin_token.token)
assert ann is not None
assert ann.is_verified is True
assert ann.verified_by == admin_token.token
assert ann.verified_at is not None
class TestAnnotationOverride:
"""Tests for annotation override functionality."""
def test_override_auto_annotation(self, patched_session, admin_token, sample_annotation):
"""Test overriding an auto-generated annotation."""
repo = AnnotationRepository()
# Override the annotation
ann = repo.override(
str(sample_annotation.annotation_id),
admin_token.token,
change_reason="Correcting OCR error",
text_value="INV-2024-CORRECTED",
x_center=0.55,
)
assert ann is not None
assert ann.text_value == "INV-2024-CORRECTED"
assert ann.x_center == 0.55
assert ann.source == "manual" # Changed from auto to manual
assert ann.override_source == "auto"
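
# Design note, as inferred from the assertions above: override flips source to
# "manual" but preserves the original origin in override_source, so the audit
# trail of machine-generated annotations survives manual correction.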
class TestAnnotationHistory:
"""Tests for annotation history tracking."""
def test_create_history_record(self, patched_session, sample_annotation):
"""Test creating annotation history record."""
repo = AnnotationRepository()
history = repo.create_history(
annotation_id=sample_annotation.annotation_id,
document_id=sample_annotation.document_id,
action="created",
new_value={"text_value": "INV-001"},
changed_by="test-user",
)
assert history is not None
assert history.action == "created"
assert history.changed_by == "test-user"
def test_get_annotation_history(self, patched_session, sample_annotation):
"""Test getting history for an annotation."""
repo = AnnotationRepository()
# Create history records
repo.create_history(
annotation_id=sample_annotation.annotation_id,
document_id=sample_annotation.document_id,
action="created",
new_value={"text_value": "INV-001"},
)
repo.create_history(
annotation_id=sample_annotation.annotation_id,
document_id=sample_annotation.document_id,
action="updated",
previous_value={"text_value": "INV-001"},
new_value={"text_value": "INV-002"},
)
history = repo.get_history(sample_annotation.annotation_id)
assert len(history) == 2
# Should be ordered by created_at desc
assert history[0].action == "updated"
assert history[1].action == "created"
def test_get_document_history(self, patched_session, sample_document, sample_annotation):
"""Test getting all annotation history for a document."""
repo = AnnotationRepository()
repo.create_history(
annotation_id=sample_annotation.annotation_id,
document_id=sample_document.document_id,
action="created",
new_value={"class_name": "invoice_number"},
)
history = repo.get_document_history(sample_document.document_id)
assert len(history) >= 1
assert all(h.document_id == sample_document.document_id for h in history)
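# Illustrative sketch (not part of the tests above): walking an annotation's
# audit trail. get_history returns entries newest-first, as the ordering test
# verifies; the function name is a placeholder.
def print_audit_trail(annotation_id) -> None:
    repo = AnnotationRepository()
    for entry in repo.get_history(annotation_id):
        print(entry.created_at, entry.action, entry.changed_by)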

View File

@@ -0,0 +1,355 @@
"""
Batch Upload Repository Integration Tests
Tests BatchUploadRepository with real database operations.
"""
from datetime import datetime, timezone
from uuid import uuid4
import pytest
from inference.data.repositories.batch_upload_repository import BatchUploadRepository
class TestBatchUploadCreate:
"""Tests for batch upload creation."""
def test_create_batch_upload(self, patched_session, admin_token):
"""Test creating a batch upload."""
repo = BatchUploadRepository()
batch = repo.create(
admin_token=admin_token.token,
filename="test_batch.zip",
file_size=10240,
upload_source="api",
)
assert batch is not None
assert batch.batch_id is not None
assert batch.filename == "test_batch.zip"
assert batch.file_size == 10240
assert batch.upload_source == "api"
assert batch.status == "processing"
assert batch.total_files == 0
assert batch.processed_files == 0
def test_create_batch_upload_default_source(self, patched_session, admin_token):
"""Test creating batch upload with default source."""
repo = BatchUploadRepository()
batch = repo.create(
admin_token=admin_token.token,
filename="ui_batch.zip",
file_size=5120,
)
assert batch.upload_source == "ui"
class TestBatchUploadRead:
"""Tests for batch upload retrieval."""
def test_get_batch_upload(self, patched_session, sample_batch_upload):
"""Test getting a batch upload by ID."""
repo = BatchUploadRepository()
batch = repo.get(sample_batch_upload.batch_id)
assert batch is not None
assert batch.batch_id == sample_batch_upload.batch_id
assert batch.filename == sample_batch_upload.filename
def test_get_nonexistent_batch_upload(self, patched_session):
"""Test getting a batch upload that doesn't exist."""
repo = BatchUploadRepository()
batch = repo.get(uuid4())
assert batch is None
def test_get_paginated_batch_uploads(self, patched_session, admin_token):
"""Test paginated batch upload listing."""
repo = BatchUploadRepository()
# Create multiple batches
for i in range(5):
repo.create(
admin_token=admin_token.token,
filename=f"batch_{i}.zip",
file_size=1024 * (i + 1),
)
batches, total = repo.get_paginated(limit=3, offset=0)
assert total == 5
assert len(batches) == 3
def test_get_paginated_with_offset(self, patched_session, admin_token):
"""Test pagination offset."""
repo = BatchUploadRepository()
for i in range(5):
repo.create(
admin_token=admin_token.token,
filename=f"batch_{i}.zip",
file_size=1024,
)
page1, _ = repo.get_paginated(limit=2, offset=0)
page2, _ = repo.get_paginated(limit=2, offset=2)
ids_page1 = {b.batch_id for b in page1}
ids_page2 = {b.batch_id for b in page2}
assert len(ids_page1 & ids_page2) == 0
class TestBatchUploadUpdate:
"""Tests for batch upload updates."""
def test_update_batch_status(self, patched_session, sample_batch_upload):
"""Test updating batch upload status."""
repo = BatchUploadRepository()
repo.update(
sample_batch_upload.batch_id,
status="completed",
total_files=10,
processed_files=10,
successful_files=8,
failed_files=2,
)
# Need to commit to see changes
patched_session.commit()
batch = repo.get(sample_batch_upload.batch_id)
assert batch.status == "completed"
assert batch.total_files == 10
assert batch.successful_files == 8
assert batch.failed_files == 2
def test_update_batch_with_error(self, patched_session, sample_batch_upload):
"""Test updating batch upload with error message."""
repo = BatchUploadRepository()
repo.update(
sample_batch_upload.batch_id,
status="failed",
error_message="ZIP extraction failed",
)
patched_session.commit()
batch = repo.get(sample_batch_upload.batch_id)
assert batch.status == "failed"
assert batch.error_message == "ZIP extraction failed"
def test_update_batch_csv_info(self, patched_session, sample_batch_upload):
"""Test updating batch with CSV information."""
repo = BatchUploadRepository()
repo.update(
sample_batch_upload.batch_id,
csv_filename="manifest.csv",
csv_row_count=100,
)
patched_session.commit()
batch = repo.get(sample_batch_upload.batch_id)
assert batch.csv_filename == "manifest.csv"
assert batch.csv_row_count == 100
class TestBatchUploadFiles:
"""Tests for batch upload file management."""
def test_create_batch_file(self, patched_session, sample_batch_upload):
"""Test creating a batch upload file record."""
repo = BatchUploadRepository()
file_record = repo.create_file(
batch_id=sample_batch_upload.batch_id,
filename="invoice_001.pdf",
status="pending",
)
assert file_record is not None
assert file_record.file_id is not None
assert file_record.filename == "invoice_001.pdf"
assert file_record.batch_id == sample_batch_upload.batch_id
assert file_record.status == "pending"
def test_create_batch_file_with_document_link(self, patched_session, sample_batch_upload, sample_document):
"""Test creating batch file linked to a document."""
repo = BatchUploadRepository()
file_record = repo.create_file(
batch_id=sample_batch_upload.batch_id,
filename="invoice_linked.pdf",
document_id=sample_document.document_id,
status="completed",
annotation_count=5,
)
assert file_record.document_id == sample_document.document_id
assert file_record.status == "completed"
assert file_record.annotation_count == 5
def test_get_batch_files(self, patched_session, sample_batch_upload):
"""Test getting all files for a batch."""
repo = BatchUploadRepository()
# Create multiple files
for i in range(3):
repo.create_file(
batch_id=sample_batch_upload.batch_id,
filename=f"file_{i}.pdf",
)
files = repo.get_files(sample_batch_upload.batch_id)
assert len(files) == 3
assert all(f.batch_id == sample_batch_upload.batch_id for f in files)
def test_get_batch_files_empty(self, patched_session, sample_batch_upload):
"""Test getting files for batch with no files."""
repo = BatchUploadRepository()
files = repo.get_files(sample_batch_upload.batch_id)
assert files == []
def test_update_batch_file_status(self, patched_session, sample_batch_upload):
"""Test updating batch file status."""
repo = BatchUploadRepository()
file_record = repo.create_file(
batch_id=sample_batch_upload.batch_id,
filename="test.pdf",
)
repo.update_file(
file_record.file_id,
status="completed",
annotation_count=10,
)
patched_session.commit()
files = repo.get_files(sample_batch_upload.batch_id)
updated_file = files[0]
assert updated_file.status == "completed"
assert updated_file.annotation_count == 10
def test_update_batch_file_with_error(self, patched_session, sample_batch_upload):
"""Test updating batch file with error."""
repo = BatchUploadRepository()
file_record = repo.create_file(
batch_id=sample_batch_upload.batch_id,
filename="corrupt.pdf",
)
repo.update_file(
file_record.file_id,
status="failed",
error_message="Invalid PDF format",
)
patched_session.commit()
files = repo.get_files(sample_batch_upload.batch_id)
updated_file = files[0]
assert updated_file.status == "failed"
assert updated_file.error_message == "Invalid PDF format"
def test_update_batch_file_with_csv_data(self, patched_session, sample_batch_upload):
"""Test updating batch file with CSV row data."""
repo = BatchUploadRepository()
file_record = repo.create_file(
batch_id=sample_batch_upload.batch_id,
filename="invoice_with_csv.pdf",
)
csv_data = {
"invoice_number": "INV-001",
"amount": "1500.00",
"supplier": "Test Corp",
}
repo.update_file(
file_record.file_id,
csv_row_data=csv_data,
)
patched_session.commit()
files = repo.get_files(sample_batch_upload.batch_id)
updated_file = files[0]
assert updated_file.csv_row_data == csv_data
class TestBatchUploadWorkflow:
"""Tests for complete batch upload workflows."""
def test_complete_batch_workflow(self, patched_session, admin_token):
"""Test complete batch upload workflow."""
repo = BatchUploadRepository()
# 1. Create batch
batch = repo.create(
admin_token=admin_token.token,
filename="full_workflow.zip",
file_size=50000,
)
# 2. Update with file count
repo.update(batch.batch_id, total_files=3)
patched_session.commit()
# 3. Create file records
file_ids = []
for i in range(3):
file_record = repo.create_file(
batch_id=batch.batch_id,
filename=f"doc_{i}.pdf",
)
file_ids.append(file_record.file_id)
# 4. Process files one by one
for i, file_id in enumerate(file_ids):
status = "completed" if i < 2 else "failed"
repo.update_file(
file_id,
status=status,
annotation_count=5 if status == "completed" else 0,
)
# 5. Update batch progress
repo.update(
batch.batch_id,
processed_files=3,
successful_files=2,
failed_files=1,
status="partial",
)
patched_session.commit()
# Verify final state
final_batch = repo.get(batch.batch_id)
assert final_batch.status == "partial"
assert final_batch.total_files == 3
assert final_batch.processed_files == 3
assert final_batch.successful_files == 2
assert final_batch.failed_files == 1
files = repo.get_files(batch.batch_id)
assert len(files) == 3
completed = [f for f in files if f.status == "completed"]
failed = [f for f in files if f.status == "failed"]
assert len(completed) == 2
assert len(failed) == 1
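# Illustrative sketch (not part of the tests above): deriving a batch-level
# status from per-file outcomes, mirroring the workflow test. The helper name
# is a placeholder; the status values are the ones asserted in this file.
def finalize_batch(repo: BatchUploadRepository, batch_id) -> None:
    files = repo.get_files(batch_id)
    successful = sum(1 for f in files if f.status == "completed")
    failed = sum(1 for f in files if f.status == "failed")
    if failed == 0:
        status = "completed"
    elif successful == 0:
        status = "failed"
    else:
        status = "partial"
    repo.update(
        batch_id,
        processed_files=len(files),
        successful_files=successful,
        failed_files=failed,
        status=status,
    )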

View File

@@ -0,0 +1,321 @@
"""
Dataset Repository Integration Tests
Tests DatasetRepository with real database operations.
"""
from uuid import uuid4
import pytest
from inference.data.repositories.dataset_repository import DatasetRepository
class TestDatasetRepositoryCreate:
"""Tests for dataset creation."""
def test_create_dataset(self, patched_session):
"""Test creating a training dataset."""
repo = DatasetRepository()
dataset = repo.create(
name="Test Dataset",
description="Dataset for integration testing",
train_ratio=0.8,
val_ratio=0.1,
seed=42,
)
assert dataset is not None
assert dataset.name == "Test Dataset"
assert dataset.description == "Dataset for integration testing"
assert dataset.train_ratio == 0.8
assert dataset.val_ratio == 0.1
assert dataset.seed == 42
assert dataset.status == "building"
def test_create_dataset_with_defaults(self, patched_session):
"""Test creating dataset with default values."""
repo = DatasetRepository()
dataset = repo.create(name="Minimal Dataset")
assert dataset is not None
assert dataset.train_ratio == 0.8
assert dataset.val_ratio == 0.1
assert dataset.seed == 42
class TestDatasetRepositoryRead:
"""Tests for dataset retrieval."""
def test_get_dataset_by_id(self, patched_session, sample_dataset):
"""Test getting dataset by ID."""
repo = DatasetRepository()
dataset = repo.get(str(sample_dataset.dataset_id))
assert dataset is not None
assert dataset.dataset_id == sample_dataset.dataset_id
assert dataset.name == sample_dataset.name
def test_get_nonexistent_dataset(self, patched_session):
"""Test getting dataset that doesn't exist."""
repo = DatasetRepository()
dataset = repo.get(str(uuid4()))
assert dataset is None
def test_get_paginated_datasets(self, patched_session):
"""Test paginated dataset listing."""
repo = DatasetRepository()
# Create multiple datasets
for i in range(5):
repo.create(name=f"Dataset {i}")
datasets, total = repo.get_paginated(limit=2, offset=0)
assert total == 5
assert len(datasets) == 2
def test_get_paginated_with_status_filter(self, patched_session):
"""Test filtering datasets by status."""
repo = DatasetRepository()
# Create datasets with different statuses
d1 = repo.create(name="Building Dataset")
repo.update_status(str(d1.dataset_id), "ready")
repo.create(name="Another Building Dataset")  # stays as "building"
datasets, total = repo.get_paginated(status="ready")
assert total == 1
assert datasets[0].status == "ready"
class TestDatasetRepositoryUpdate:
"""Tests for dataset updates."""
def test_update_status(self, patched_session, sample_dataset):
"""Test updating dataset status."""
repo = DatasetRepository()
repo.update_status(
str(sample_dataset.dataset_id),
status="ready",
total_documents=100,
total_images=150,
total_annotations=500,
)
dataset = repo.get(str(sample_dataset.dataset_id))
assert dataset is not None
assert dataset.status == "ready"
assert dataset.total_documents == 100
assert dataset.total_images == 150
assert dataset.total_annotations == 500
def test_update_status_with_error(self, patched_session, sample_dataset):
"""Test updating dataset status with error message."""
repo = DatasetRepository()
repo.update_status(
str(sample_dataset.dataset_id),
status="failed",
error_message="Failed to build dataset: insufficient documents",
)
dataset = repo.get(str(sample_dataset.dataset_id))
assert dataset is not None
assert dataset.status == "failed"
assert "insufficient documents" in dataset.error_message
def test_update_status_with_path(self, patched_session, sample_dataset):
"""Test updating dataset path."""
repo = DatasetRepository()
repo.update_status(
str(sample_dataset.dataset_id),
status="ready",
dataset_path="/datasets/test_dataset_2024",
)
dataset = repo.get(str(sample_dataset.dataset_id))
assert dataset is not None
assert dataset.dataset_path == "/datasets/test_dataset_2024"
def test_update_training_status(self, patched_session, sample_dataset, sample_training_task):
"""Test updating dataset training status."""
repo = DatasetRepository()
repo.update_training_status(
str(sample_dataset.dataset_id),
training_status="running",
active_training_task_id=str(sample_training_task.task_id),
)
dataset = repo.get(str(sample_dataset.dataset_id))
assert dataset is not None
assert dataset.training_status == "running"
assert dataset.active_training_task_id == sample_training_task.task_id
def test_update_training_status_completed(self, patched_session, sample_dataset):
"""Test updating training status to completed updates main status."""
repo = DatasetRepository()
# First set to ready
repo.update_status(str(sample_dataset.dataset_id), status="ready")
# Then complete training
repo.update_training_status(
str(sample_dataset.dataset_id),
training_status="completed",
update_main_status=True,
)
dataset = repo.get(str(sample_dataset.dataset_id))
assert dataset is not None
assert dataset.training_status == "completed"
assert dataset.status == "trained"
class TestDatasetDocuments:
"""Tests for dataset document management."""
def test_add_documents_to_dataset(self, patched_session, sample_dataset, multiple_documents):
"""Test adding documents to a dataset."""
repo = DatasetRepository()
documents_data = [
{
"document_id": str(multiple_documents[0].document_id),
"split": "train",
"page_count": 1,
"annotation_count": 5,
},
{
"document_id": str(multiple_documents[1].document_id),
"split": "train",
"page_count": 2,
"annotation_count": 8,
},
{
"document_id": str(multiple_documents[2].document_id),
"split": "val",
"page_count": 1,
"annotation_count": 3,
},
]
repo.add_documents(str(sample_dataset.dataset_id), documents_data)
# Verify documents were added
docs = repo.get_documents(str(sample_dataset.dataset_id))
assert len(docs) == 3
train_docs = [d for d in docs if d.split == "train"]
val_docs = [d for d in docs if d.split == "val"]
assert len(train_docs) == 2
assert len(val_docs) == 1
def test_get_dataset_documents(self, patched_session, sample_dataset, sample_document):
"""Test getting documents from a dataset."""
repo = DatasetRepository()
repo.add_documents(
str(sample_dataset.dataset_id),
[
{
"document_id": str(sample_document.document_id),
"split": "train",
"page_count": 1,
"annotation_count": 5,
}
],
)
docs = repo.get_documents(str(sample_dataset.dataset_id))
assert len(docs) == 1
assert docs[0].document_id == sample_document.document_id
assert docs[0].split == "train"
assert docs[0].page_count == 1
assert docs[0].annotation_count == 5
class TestDatasetRepositoryDelete:
"""Tests for dataset deletion."""
def test_delete_dataset(self, patched_session, sample_dataset):
"""Test deleting a dataset."""
repo = DatasetRepository()
result = repo.delete(str(sample_dataset.dataset_id))
assert result is True
dataset = repo.get(str(sample_dataset.dataset_id))
assert dataset is None
def test_delete_nonexistent_dataset(self, patched_session):
"""Test deleting dataset that doesn't exist."""
repo = DatasetRepository()
result = repo.delete(str(uuid4()))
assert result is False
def test_delete_dataset_cascades_documents(self, patched_session, sample_dataset, sample_document):
"""Test deleting dataset also removes document links."""
repo = DatasetRepository()
# Add document to dataset
repo.add_documents(
str(sample_dataset.dataset_id),
[
{
"document_id": str(sample_document.document_id),
"split": "train",
"page_count": 1,
"annotation_count": 5,
}
],
)
# Delete dataset
repo.delete(str(sample_dataset.dataset_id))
# Document links should be gone
docs = repo.get_documents(str(sample_dataset.dataset_id))
assert len(docs) == 0
class TestActiveTrainingTasks:
"""Tests for active training task queries."""
def test_get_active_training_tasks(self, patched_session, sample_dataset, sample_training_task):
"""Test getting active training tasks for datasets."""
repo = DatasetRepository()
# Update task to running
from inference.data.repositories.training_task_repository import TrainingTaskRepository
task_repo = TrainingTaskRepository()
task_repo.update_status(str(sample_training_task.task_id), "running")
result = repo.get_active_training_tasks([str(sample_dataset.dataset_id)])
assert str(sample_dataset.dataset_id) in result
assert result[str(sample_dataset.dataset_id)]["status"] == "running"
def test_get_active_training_tasks_empty(self, patched_session, sample_dataset):
"""Test getting active training tasks returns empty when no tasks exist."""
repo = DatasetRepository()
result = repo.get_active_training_tasks([str(sample_dataset.dataset_id)])
# No training task exists for this dataset, so result should be empty
assert str(sample_dataset.dataset_id) not in result
assert result == {}
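# Illustrative sketch (not part of the tests above): the build flow these
# tests exercise end to end -- create in "building", attach documents, then
# mark "ready" or "failed". The function name is a placeholder.
def build_dataset(repo: DatasetRepository, name: str, documents_data: list[dict]) -> str:
    dataset = repo.create(name=name)  # starts with status "building"
    try:
        repo.add_documents(str(dataset.dataset_id), documents_data)
        repo.update_status(
            str(dataset.dataset_id),
            status="ready",
            total_documents=len(documents_data),
        )
    except Exception as exc:
        repo.update_status(str(dataset.dataset_id), status="failed", error_message=str(exc))
    return str(dataset.dataset_id)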

View File

@@ -0,0 +1,350 @@
"""
Document Repository Integration Tests
Tests DocumentRepository with real database operations.
"""
from datetime import datetime, timezone, timedelta
from uuid import uuid4
import pytest
from sqlmodel import select
from inference.data.admin_models import AdminAnnotation, AdminDocument
from inference.data.repositories.document_repository import DocumentRepository
def ensure_utc(dt: datetime | None) -> datetime | None:
"""Ensure datetime is timezone-aware (UTC).
PostgreSQL may return offset-naive datetimes. This helper
converts them to UTC for proper comparison.
"""
if dt is None:
return None
if dt.tzinfo is None:
return dt.replace(tzinfo=timezone.utc)
return dt
class TestDocumentRepositoryCreate:
"""Tests for document creation."""
def test_create_document(self, patched_session):
"""Test creating a document and retrieving it."""
repo = DocumentRepository()
doc_id = repo.create(
filename="test_invoice.pdf",
file_size=2048,
content_type="application/pdf",
file_path="/uploads/test_invoice.pdf",
page_count=2,
upload_source="api",
category="invoice",
)
assert doc_id is not None
doc = repo.get(doc_id)
assert doc is not None
assert doc.filename == "test_invoice.pdf"
assert doc.file_size == 2048
assert doc.page_count == 2
assert doc.upload_source == "api"
assert doc.category == "invoice"
assert doc.status == "pending"
def test_create_document_with_csv_values(self, patched_session):
"""Test creating document with CSV field values."""
repo = DocumentRepository()
csv_values = {
"invoice_number": "INV-001",
"amount": "1500.00",
"supplier_name": "Test Supplier AB",
}
doc_id = repo.create(
filename="invoice_with_csv.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/uploads/invoice_with_csv.pdf",
csv_field_values=csv_values,
)
doc = repo.get(doc_id)
assert doc is not None
assert doc.csv_field_values == csv_values
def test_create_document_with_group_key(self, patched_session):
"""Test creating document with group key."""
repo = DocumentRepository()
doc_id = repo.create(
filename="grouped_doc.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/uploads/grouped_doc.pdf",
group_key="batch-2024-01",
)
doc = repo.get(doc_id)
assert doc is not None
assert doc.group_key == "batch-2024-01"
class TestDocumentRepositoryRead:
"""Tests for document retrieval."""
def test_get_nonexistent_document(self, patched_session):
"""Test getting a document that doesn't exist."""
repo = DocumentRepository()
doc = repo.get(str(uuid4()))
assert doc is None
def test_get_paginated_documents(self, patched_session, multiple_documents):
"""Test paginated document listing."""
repo = DocumentRepository()
docs, total = repo.get_paginated(limit=2, offset=0)
assert total == 5
assert len(docs) == 2
def test_get_paginated_with_status_filter(self, patched_session, multiple_documents):
"""Test filtering documents by status."""
repo = DocumentRepository()
docs, total = repo.get_paginated(status="labeled")
assert total == 2
for doc in docs:
assert doc.status == "labeled"
def test_get_paginated_with_category_filter(self, patched_session, multiple_documents):
"""Test filtering documents by category."""
repo = DocumentRepository()
docs, total = repo.get_paginated(category="letter")
assert total == 1
assert docs[0].category == "letter"
def test_get_paginated_with_offset(self, patched_session, multiple_documents):
"""Test pagination offset."""
repo = DocumentRepository()
docs_page1, _ = repo.get_paginated(limit=2, offset=0)
docs_page2, _ = repo.get_paginated(limit=2, offset=2)
doc_ids_page1 = {str(d.document_id) for d in docs_page1}
doc_ids_page2 = {str(d.document_id) for d in docs_page2}
assert len(doc_ids_page1 & doc_ids_page2) == 0
def test_get_by_ids(self, patched_session, multiple_documents):
"""Test getting multiple documents by IDs."""
repo = DocumentRepository()
ids_to_fetch = [
    str(multiple_documents[0].document_id),
    str(multiple_documents[2].document_id),
]
docs = repo.get_by_ids(ids_to_fetch)
assert len(docs) == 2
fetched_ids = {str(d.document_id) for d in docs}
assert fetched_ids == set(ids_to_fetch)
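# Illustrative sketch (not part of the tests above): paging through every
# document with the (items, total) contract that get_paginated exposes.
def iter_all_documents(repo: DocumentRepository, page_size: int = 100):
    offset = 0
    while True:
        docs, total = repo.get_paginated(limit=page_size, offset=offset)
        yield from docs
        offset += len(docs)
        if not docs or offset >= total:
            break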
class TestDocumentRepositoryUpdate:
"""Tests for document updates."""
def test_update_status(self, patched_session, sample_document):
"""Test updating document status."""
repo = DocumentRepository()
repo.update_status(
str(sample_document.document_id),
status="labeled",
auto_label_status="completed",
)
doc = repo.get(str(sample_document.document_id))
assert doc is not None
assert doc.status == "labeled"
assert doc.auto_label_status == "completed"
def test_update_status_with_error(self, patched_session, sample_document):
"""Test updating document status with error message."""
repo = DocumentRepository()
repo.update_status(
str(sample_document.document_id),
status="pending",
auto_label_status="failed",
auto_label_error="OCR extraction failed",
)
doc = repo.get(str(sample_document.document_id))
assert doc is not None
assert doc.auto_label_status == "failed"
assert doc.auto_label_error == "OCR extraction failed"
def test_update_file_path(self, patched_session, sample_document):
"""Test updating document file path."""
repo = DocumentRepository()
new_path = "/archive/2024/test_invoice.pdf"
repo.update_file_path(str(sample_document.document_id), new_path)
doc = repo.get(str(sample_document.document_id))
assert doc is not None
assert doc.file_path == new_path
def test_update_group_key(self, patched_session, sample_document):
"""Test updating document group key."""
repo = DocumentRepository()
result = repo.update_group_key(str(sample_document.document_id), "new-group-key")
assert result is True
doc = repo.get(str(sample_document.document_id))
assert doc is not None
assert doc.group_key == "new-group-key"
def test_update_category(self, patched_session, sample_document):
"""Test updating document category."""
repo = DocumentRepository()
doc = repo.update_category(str(sample_document.document_id), "letter")
assert doc is not None
assert doc.category == "letter"
class TestDocumentRepositoryDelete:
"""Tests for document deletion."""
def test_delete_document(self, patched_session, sample_document):
"""Test deleting a document."""
repo = DocumentRepository()
result = repo.delete(str(sample_document.document_id))
assert result is True
doc = repo.get(str(sample_document.document_id))
assert doc is None
def test_delete_document_with_annotations(self, patched_session, sample_document, sample_annotation):
"""Test deleting document also deletes its annotations."""
repo = DocumentRepository()
result = repo.delete(str(sample_document.document_id))
assert result is True
# Verify annotation is also deleted
from inference.data.repositories.annotation_repository import AnnotationRepository
ann_repo = AnnotationRepository()
annotations = ann_repo.get_for_document(str(sample_document.document_id))
assert len(annotations) == 0
def test_delete_nonexistent_document(self, patched_session):
"""Test deleting a document that doesn't exist."""
repo = DocumentRepository()
result = repo.delete(str(uuid4()))
assert result is False
class TestDocumentRepositoryQueries:
"""Tests for complex document queries."""
def test_count_by_status(self, patched_session, multiple_documents):
"""Test counting documents by status."""
repo = DocumentRepository()
counts = repo.count_by_status()
assert counts.get("pending") == 2
assert counts.get("labeled") == 2
assert counts.get("exported") == 1
def test_get_categories(self, patched_session, multiple_documents):
"""Test getting unique categories."""
repo = DocumentRepository()
categories = repo.get_categories()
assert "invoice" in categories
assert "letter" in categories
def test_get_labeled_for_export(self, patched_session, multiple_documents):
"""Test getting labeled documents for export."""
repo = DocumentRepository()
docs = repo.get_labeled_for_export()
assert len(docs) == 2
for doc in docs:
assert doc.status == "labeled"
class TestDocumentAnnotationLocking:
"""Tests for annotation locking mechanism."""
def test_acquire_annotation_lock(self, patched_session, sample_document):
"""Test acquiring annotation lock."""
repo = DocumentRepository()
doc = repo.acquire_annotation_lock(
str(sample_document.document_id),
duration_seconds=300,
)
assert doc is not None
assert doc.annotation_lock_until is not None
lock_until = ensure_utc(doc.annotation_lock_until)
assert lock_until > datetime.now(timezone.utc)
def test_acquire_lock_when_already_locked(self, patched_session, sample_document):
"""Test acquiring lock fails when already locked."""
repo = DocumentRepository()
# First lock
repo.acquire_annotation_lock(str(sample_document.document_id), duration_seconds=300)
# Second lock attempt should fail
result = repo.acquire_annotation_lock(str(sample_document.document_id))
assert result is None
def test_release_annotation_lock(self, patched_session, sample_document):
"""Test releasing annotation lock."""
repo = DocumentRepository()
repo.acquire_annotation_lock(str(sample_document.document_id), duration_seconds=300)
doc = repo.release_annotation_lock(str(sample_document.document_id))
assert doc is not None
assert doc.annotation_lock_until is None
def test_extend_annotation_lock(self, patched_session, sample_document):
"""Test extending annotation lock."""
repo = DocumentRepository()
# Acquire initial lock
initial_doc = repo.acquire_annotation_lock(
str(sample_document.document_id),
duration_seconds=300,
)
initial_expiry = ensure_utc(initial_doc.annotation_lock_until)
# Extend lock
extended_doc = repo.extend_annotation_lock(
str(sample_document.document_id),
additional_seconds=300,
)
assert extended_doc is not None
extended_expiry = ensure_utc(extended_doc.annotation_lock_until)
assert extended_expiry > initial_expiry
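# Illustrative sketch (not part of the tests above): holding the annotation
# lock for the duration of an editing session. acquire_annotation_lock
# returns None when the document is already locked, as the contention test
# shows; the function name is a placeholder.
def edit_with_lock(repo: DocumentRepository, document_id: str) -> bool:
    doc = repo.acquire_annotation_lock(document_id, duration_seconds=300)
    if doc is None:
        return False  # another session holds the lock
    try:
        ...  # perform edits; call extend_annotation_lock for long sessions
    finally:
        repo.release_annotation_lock(document_id)
    return True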

View File

@@ -0,0 +1,310 @@
"""
Model Version Repository Integration Tests
Tests ModelVersionRepository with real database operations.
"""
from datetime import datetime, timezone
from uuid import uuid4
import pytest
from inference.data.repositories.model_version_repository import ModelVersionRepository
class TestModelVersionCreate:
"""Tests for model version creation."""
def test_create_model_version(self, patched_session):
"""Test creating a model version."""
repo = ModelVersionRepository()
model = repo.create(
version="1.0.0",
name="Invoice Extractor v1",
model_path="/models/invoice_v1.pt",
description="Initial production model",
metrics_mAP=0.92,
metrics_precision=0.89,
metrics_recall=0.85,
document_count=1000,
file_size=50000000,
)
assert model is not None
assert model.version == "1.0.0"
assert model.name == "Invoice Extractor v1"
assert model.model_path == "/models/invoice_v1.pt"
assert model.metrics_mAP == 0.92
assert model.is_active is False
assert model.status == "inactive"
def test_create_model_version_with_training_info(
self, patched_session, sample_training_task, sample_dataset
):
"""Test creating model version linked to training task and dataset."""
repo = ModelVersionRepository()
model = repo.create(
version="1.1.0",
name="Invoice Extractor v1.1",
model_path="/models/invoice_v1.1.pt",
task_id=sample_training_task.task_id,
dataset_id=sample_dataset.dataset_id,
training_config={"epochs": 100, "batch_size": 16},
trained_at=datetime.now(timezone.utc),
)
assert model is not None
assert model.task_id == sample_training_task.task_id
assert model.dataset_id == sample_dataset.dataset_id
assert model.training_config["epochs"] == 100
class TestModelVersionRead:
"""Tests for model version retrieval."""
def test_get_model_version_by_id(self, patched_session, sample_model_version):
"""Test getting model version by ID."""
repo = ModelVersionRepository()
model = repo.get(str(sample_model_version.version_id))
assert model is not None
assert model.version_id == sample_model_version.version_id
def test_get_nonexistent_model_version(self, patched_session):
"""Test getting model version that doesn't exist."""
repo = ModelVersionRepository()
model = repo.get(str(uuid4()))
assert model is None
def test_get_paginated_model_versions(self, patched_session):
"""Test paginated model version listing."""
repo = ModelVersionRepository()
# Create multiple versions
for i in range(5):
repo.create(
version=f"1.{i}.0",
name=f"Model v1.{i}",
model_path=f"/models/model_v1.{i}.pt",
)
models, total = repo.get_paginated(limit=2, offset=0)
assert total == 5
assert len(models) == 2
def test_get_paginated_with_status_filter(self, patched_session):
"""Test filtering model versions by status."""
repo = ModelVersionRepository()
# Create active and inactive models
m1 = repo.create(version="1.0.0", name="Active Model", model_path="/models/active.pt")
repo.activate(str(m1.version_id))
repo.create(version="2.0.0", name="Inactive Model", model_path="/models/inactive.pt")
active_models, active_total = repo.get_paginated(status="active")
inactive_models, inactive_total = repo.get_paginated(status="inactive")
assert active_total == 1
assert inactive_total == 1
class TestModelVersionActivation:
"""Tests for model version activation."""
def test_activate_model_version(self, patched_session, sample_model_version):
"""Test activating a model version."""
repo = ModelVersionRepository()
model = repo.activate(str(sample_model_version.version_id))
assert model is not None
assert model.is_active is True
assert model.status == "active"
assert model.activated_at is not None
def test_activate_deactivates_others(self, patched_session):
"""Test that activating one version deactivates others."""
repo = ModelVersionRepository()
# Create and activate first model
m1 = repo.create(version="1.0.0", name="Model 1", model_path="/models/m1.pt")
repo.activate(str(m1.version_id))
# Create and activate second model
m2 = repo.create(version="2.0.0", name="Model 2", model_path="/models/m2.pt")
repo.activate(str(m2.version_id))
# Check first model is now inactive
m1_after = repo.get(str(m1.version_id))
assert m1_after.is_active is False
assert m1_after.status == "inactive"
# Check second model is active
m2_after = repo.get(str(m2.version_id))
assert m2_after.is_active is True
def test_get_active_model(self, patched_session, sample_model_version):
"""Test getting the currently active model."""
repo = ModelVersionRepository()
# Initially no active model
active = repo.get_active()
assert active is None
# Activate model
repo.activate(str(sample_model_version.version_id))
# Now should return active model
active = repo.get_active()
assert active is not None
assert active.version_id == sample_model_version.version_id
def test_deactivate_model_version(self, patched_session, sample_model_version):
"""Test deactivating a model version."""
repo = ModelVersionRepository()
# First activate
repo.activate(str(sample_model_version.version_id))
# Then deactivate
model = repo.deactivate(str(sample_model_version.version_id))
assert model is not None
assert model.is_active is False
assert model.status == "inactive"
class TestModelVersionUpdate:
"""Tests for model version updates."""
def test_update_model_metadata(self, patched_session, sample_model_version):
"""Test updating model version metadata."""
repo = ModelVersionRepository()
model = repo.update(
str(sample_model_version.version_id),
name="Updated Model Name",
description="Updated description",
)
assert model is not None
assert model.name == "Updated Model Name"
assert model.description == "Updated description"
def test_update_model_status(self, patched_session, sample_model_version):
"""Test updating model version status."""
repo = ModelVersionRepository()
model = repo.update(str(sample_model_version.version_id), status="deprecated")
assert model is not None
assert model.status == "deprecated"
def test_update_nonexistent_model(self, patched_session):
"""Test updating model that doesn't exist."""
repo = ModelVersionRepository()
model = repo.update(str(uuid4()), name="New Name")
assert model is None
class TestModelVersionArchive:
"""Tests for model version archiving."""
def test_archive_model_version(self, patched_session, sample_model_version):
"""Test archiving an inactive model version."""
repo = ModelVersionRepository()
model = repo.archive(str(sample_model_version.version_id))
assert model is not None
assert model.status == "archived"
def test_cannot_archive_active_model(self, patched_session, sample_model_version):
"""Test that active model cannot be archived."""
repo = ModelVersionRepository()
# Activate the model
repo.activate(str(sample_model_version.version_id))
# Try to archive
model = repo.archive(str(sample_model_version.version_id))
assert model is None
# Verify model is still active
current = repo.get(str(sample_model_version.version_id))
assert current.status == "active"
class TestModelVersionDelete:
"""Tests for model version deletion."""
def test_delete_inactive_model(self, patched_session, sample_model_version):
"""Test deleting an inactive model version."""
repo = ModelVersionRepository()
result = repo.delete(str(sample_model_version.version_id))
assert result is True
model = repo.get(str(sample_model_version.version_id))
assert model is None
def test_cannot_delete_active_model(self, patched_session, sample_model_version):
"""Test that active model cannot be deleted."""
repo = ModelVersionRepository()
# Activate the model
repo.activate(str(sample_model_version.version_id))
# Try to delete
result = repo.delete(str(sample_model_version.version_id))
assert result is False
# Verify model still exists
model = repo.get(str(sample_model_version.version_id))
assert model is not None
def test_delete_nonexistent_model(self, patched_session):
"""Test deleting model that doesn't exist."""
repo = ModelVersionRepository()
result = repo.delete(str(uuid4()))
assert result is False
class TestOnlyOneActiveModel:
"""Tests to verify only one model can be active at a time."""
def test_single_active_model_constraint(self, patched_session):
"""Test that only one model can be active at any time."""
repo = ModelVersionRepository()
# Create multiple models
models = []
for i in range(3):
m = repo.create(
version=f"1.{i}.0",
name=f"Model {i}",
model_path=f"/models/model_{i}.pt",
)
models.append(m)
# Activate each model in sequence
for model in models:
repo.activate(str(model.version_id))
# Count active models
all_models, _ = repo.get_paginated(status="active")
assert len(all_models) == 1
# Verify it's the last one activated
assert all_models[0].version_id == models[-1].version_id
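# Illustrative sketch (not part of the tests above): promoting a trained
# model. activate() is the single switch point, since it deactivates the
# previously active version; archiving the replaced version is optional and
# only allowed once it is inactive, per the tests above.
def promote_model(repo: ModelVersionRepository, version_id: str) -> None:
    previous = repo.get_active()
    repo.activate(version_id)
    if previous is not None and str(previous.version_id) != version_id:
        repo.archive(str(previous.version_id))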

View File

@@ -0,0 +1,274 @@
"""
Token Repository Integration Tests
Tests TokenRepository with real database operations.
"""
from datetime import datetime, timezone, timedelta
import pytest
from inference.data.repositories.token_repository import TokenRepository
class TestTokenCreate:
"""Tests for token creation."""
def test_create_new_token(self, patched_session):
"""Test creating a new admin token."""
repo = TokenRepository()
repo.create(
token="new-test-token-abc123",
name="New Test Admin",
)
token = repo.get("new-test-token-abc123")
assert token is not None
assert token.token == "new-test-token-abc123"
assert token.name == "New Test Admin"
assert token.is_active is True
assert token.expires_at is None
def test_create_token_with_expiration(self, patched_session):
"""Test creating token with expiration date."""
repo = TokenRepository()
expiry = datetime.now(timezone.utc) + timedelta(days=30)
repo.create(
token="expiring-token-xyz789",
name="Expiring Token",
expires_at=expiry,
)
token = repo.get("expiring-token-xyz789")
assert token is not None
assert token.expires_at is not None
def test_create_updates_existing_token(self, patched_session, admin_token):
"""Test creating with existing token updates it."""
repo = TokenRepository()
new_expiry = datetime.now(timezone.utc) + timedelta(days=60)
repo.create(
token=admin_token.token,
name="Updated Admin Name",
expires_at=new_expiry,
)
token = repo.get(admin_token.token)
assert token is not None
assert token.name == "Updated Admin Name"
assert token.is_active is True
class TestTokenValidation:
"""Tests for token validation."""
def test_is_valid_active_token(self, patched_session, admin_token):
"""Test that active token is valid."""
repo = TokenRepository()
result = repo.is_valid(admin_token.token)
assert result is True
def test_is_valid_nonexistent_token(self, patched_session):
"""Test that nonexistent token is invalid."""
repo = TokenRepository()
result = repo.is_valid("nonexistent-token-12345")
assert result is False
def test_is_valid_deactivated_token(self, patched_session, admin_token):
"""Test that deactivated token is invalid."""
repo = TokenRepository()
repo.deactivate(admin_token.token)
result = repo.is_valid(admin_token.token)
assert result is False
def test_is_valid_expired_token(self, patched_session):
"""Test that expired token is invalid."""
repo = TokenRepository()
past_expiry = datetime.now(timezone.utc) - timedelta(days=1)
repo.create(
token="expired-token-test",
name="Expired Token",
expires_at=past_expiry,
)
result = repo.is_valid("expired-token-test")
assert result is False
def test_is_valid_not_yet_expired_token(self, patched_session):
"""Test that not-yet-expired token is valid."""
repo = TokenRepository()
future_expiry = datetime.now(timezone.utc) + timedelta(days=7)
repo.create(
token="valid-expiring-token",
name="Valid Expiring Token",
expires_at=future_expiry,
)
result = repo.is_valid("valid-expiring-token")
assert result is True
class TestTokenGet:
"""Tests for token retrieval."""
def test_get_existing_token(self, patched_session, admin_token):
"""Test getting an existing token."""
repo = TokenRepository()
token = repo.get(admin_token.token)
assert token is not None
assert token.token == admin_token.token
assert token.name == admin_token.name
def test_get_nonexistent_token(self, patched_session):
"""Test getting a token that doesn't exist."""
repo = TokenRepository()
token = repo.get("nonexistent-token-xyz")
assert token is None
class TestTokenDeactivate:
"""Tests for token deactivation."""
def test_deactivate_existing_token(self, patched_session, admin_token):
"""Test deactivating an existing token."""
repo = TokenRepository()
result = repo.deactivate(admin_token.token)
assert result is True
token = repo.get(admin_token.token)
assert token is not None
assert token.is_active is False
def test_deactivate_nonexistent_token(self, patched_session):
"""Test deactivating a token that doesn't exist."""
repo = TokenRepository()
result = repo.deactivate("nonexistent-token-abc")
assert result is False
def test_reactivate_deactivated_token(self, patched_session, admin_token):
"""Test reactivating a deactivated token via create."""
repo = TokenRepository()
# Deactivate first
repo.deactivate(admin_token.token)
assert repo.is_valid(admin_token.token) is False
# Reactivate via create
repo.create(
token=admin_token.token,
name="Reactivated Admin",
)
assert repo.is_valid(admin_token.token) is True
class TestTokenUsageTracking:
"""Tests for token usage tracking."""
def test_update_usage(self, patched_session, admin_token):
"""Test updating token last used timestamp."""
repo = TokenRepository()
# Initially last_used_at might be None
initial_token = repo.get(admin_token.token)
initial_last_used = initial_token.last_used_at
repo.update_usage(admin_token.token)
updated_token = repo.get(admin_token.token)
assert updated_token.last_used_at is not None
if initial_last_used:
assert updated_token.last_used_at >= initial_last_used
def test_update_usage_nonexistent_token(self, patched_session):
"""Test updating usage for nonexistent token does nothing."""
repo = TokenRepository()
# Should not raise, just does nothing
repo.update_usage("nonexistent-token-usage")
token = repo.get("nonexistent-token-usage")
assert token is None
class TestTokenWorkflow:
"""Tests for complete token workflows."""
def test_full_token_lifecycle(self, patched_session):
"""Test complete token lifecycle: create, validate, use, deactivate."""
repo = TokenRepository()
token_str = "lifecycle-test-token"
# 1. Create token
repo.create(token=token_str, name="Lifecycle Token")
assert repo.is_valid(token_str) is True
# 2. Use token
repo.update_usage(token_str)
token = repo.get(token_str)
assert token.last_used_at is not None
# 3. Update token info
new_expiry = datetime.now(timezone.utc) + timedelta(days=90)
repo.create(
token=token_str,
name="Updated Lifecycle Token",
expires_at=new_expiry,
)
token = repo.get(token_str)
assert token.name == "Updated Lifecycle Token"
# 4. Deactivate token
result = repo.deactivate(token_str)
assert result is True
assert repo.is_valid(token_str) is False
# 5. Reactivate token
repo.create(token=token_str, name="Reactivated Token")
assert repo.is_valid(token_str) is True
def test_multiple_tokens(self, patched_session):
"""Test managing multiple tokens."""
repo = TokenRepository()
# Create multiple tokens
tokens = [
("token-a", "Admin A"),
("token-b", "Admin B"),
("token-c", "Admin C"),
]
for token_str, name in tokens:
repo.create(token=token_str, name=name)
# Verify all are valid
for token_str, _ in tokens:
assert repo.is_valid(token_str) is True
# Deactivate one
repo.deactivate("token-b")
# Verify states
assert repo.is_valid("token-a") is True
assert repo.is_valid("token-b") is False
assert repo.is_valid("token-c") is True

View File

@@ -0,0 +1,364 @@
"""
Training Task Repository Integration Tests
Tests TrainingTaskRepository with real database operations.
"""
from datetime import datetime, timezone, timedelta
from uuid import uuid4
import pytest
from inference.data.repositories.training_task_repository import TrainingTaskRepository
class TestTrainingTaskCreate:
"""Tests for training task creation."""
def test_create_training_task(self, patched_session, admin_token):
"""Test creating a training task."""
repo = TrainingTaskRepository()
task_id = repo.create(
admin_token=admin_token.token,
name="Test Training Task",
task_type="train",
description="Integration test training task",
config={"epochs": 100, "batch_size": 16},
)
assert task_id is not None
task = repo.get(task_id)
assert task is not None
assert task.name == "Test Training Task"
assert task.task_type == "train"
assert task.status == "pending"
assert task.config["epochs"] == 100
def test_create_scheduled_task(self, patched_session, admin_token):
"""Test creating a scheduled training task."""
repo = TrainingTaskRepository()
scheduled_time = datetime.now(timezone.utc) + timedelta(hours=1)
task_id = repo.create(
admin_token=admin_token.token,
name="Scheduled Task",
scheduled_at=scheduled_time,
)
task = repo.get(task_id)
assert task is not None
assert task.status == "scheduled"
assert task.scheduled_at is not None
def test_create_recurring_task(self, patched_session, admin_token):
"""Test creating a recurring training task."""
repo = TrainingTaskRepository()
task_id = repo.create(
admin_token=admin_token.token,
name="Recurring Task",
cron_expression="0 2 * * *",
is_recurring=True,
)
task = repo.get(task_id)
assert task is not None
assert task.is_recurring is True
assert task.cron_expression == "0 2 * * *"
def test_create_task_with_dataset(self, patched_session, admin_token, sample_dataset):
"""Test creating task linked to a dataset."""
repo = TrainingTaskRepository()
task_id = repo.create(
admin_token=admin_token.token,
name="Dataset Training Task",
dataset_id=str(sample_dataset.dataset_id),
)
task = repo.get(task_id)
assert task is not None
assert task.dataset_id == sample_dataset.dataset_id
class TestTrainingTaskRead:
"""Tests for training task retrieval."""
def test_get_task_by_id(self, patched_session, sample_training_task):
"""Test getting task by ID."""
repo = TrainingTaskRepository()
task = repo.get(str(sample_training_task.task_id))
assert task is not None
assert task.task_id == sample_training_task.task_id
def test_get_nonexistent_task(self, patched_session):
"""Test getting task that doesn't exist."""
repo = TrainingTaskRepository()
task = repo.get(str(uuid4()))
assert task is None
def test_get_paginated_tasks(self, patched_session, admin_token):
"""Test paginated task listing."""
repo = TrainingTaskRepository()
# Create multiple tasks
for i in range(5):
repo.create(admin_token=admin_token.token, name=f"Task {i}")
tasks, total = repo.get_paginated(limit=2, offset=0)
assert total == 5
assert len(tasks) == 2
def test_get_paginated_with_status_filter(self, patched_session, admin_token):
"""Test filtering tasks by status."""
repo = TrainingTaskRepository()
# Create tasks with different statuses
task_id = repo.create(admin_token=admin_token.token, name="Running Task")
repo.update_status(task_id, "running")
repo.create(admin_token=admin_token.token, name="Pending Task")
tasks, total = repo.get_paginated(status="running")
assert total == 1
assert tasks[0].status == "running"
def test_get_pending_tasks(self, patched_session, admin_token):
"""Test getting pending tasks ready to run."""
repo = TrainingTaskRepository()
# Create pending task
repo.create(admin_token=admin_token.token, name="Ready Task")
# Create scheduled task in the past (should be included)
past_time = datetime.now(timezone.utc) - timedelta(hours=1)
repo.create(
admin_token=admin_token.token,
name="Past Scheduled Task",
scheduled_at=past_time,
)
# Create scheduled task in the future (should not be included)
future_time = datetime.now(timezone.utc) + timedelta(hours=1)
repo.create(
admin_token=admin_token.token,
name="Future Scheduled Task",
scheduled_at=future_time,
)
pending = repo.get_pending()
# Should include pending and past scheduled, not future scheduled
assert len(pending) >= 2
names = [t.name for t in pending]
assert "Ready Task" in names
assert "Past Scheduled Task" in names
def test_get_running_task(self, patched_session, admin_token):
"""Test getting currently running task."""
repo = TrainingTaskRepository()
task_id = repo.create(admin_token=admin_token.token, name="Running Task")
repo.update_status(task_id, "running")
running = repo.get_running()
assert running is not None
assert running.status == "running"
def test_get_running_task_none(self, patched_session, admin_token):
"""Test getting running task when none is running."""
repo = TrainingTaskRepository()
repo.create(admin_token=admin_token.token, name="Pending Task")
running = repo.get_running()
assert running is None
class TestTrainingTaskUpdate:
"""Tests for training task updates."""
def test_update_status_to_running(self, patched_session, sample_training_task):
"""Test updating task status to running."""
repo = TrainingTaskRepository()
repo.update_status(str(sample_training_task.task_id), "running")
task = repo.get(str(sample_training_task.task_id))
assert task is not None
assert task.status == "running"
assert task.started_at is not None
def test_update_status_to_completed(self, patched_session, sample_training_task):
"""Test updating task status to completed."""
repo = TrainingTaskRepository()
metrics = {"mAP": 0.92, "precision": 0.89, "recall": 0.85}
repo.update_status(
str(sample_training_task.task_id),
"completed",
result_metrics=metrics,
model_path="/models/trained_model.pt",
)
task = repo.get(str(sample_training_task.task_id))
assert task is not None
assert task.status == "completed"
assert task.completed_at is not None
assert task.result_metrics["mAP"] == 0.92
assert task.model_path == "/models/trained_model.pt"
def test_update_status_to_failed(self, patched_session, sample_training_task):
"""Test updating task status to failed with error message."""
repo = TrainingTaskRepository()
repo.update_status(
str(sample_training_task.task_id),
"failed",
error_message="CUDA out of memory",
)
task = repo.get(str(sample_training_task.task_id))
assert task is not None
assert task.status == "failed"
assert task.completed_at is not None
assert "CUDA out of memory" in task.error_message
def test_cancel_pending_task(self, patched_session, sample_training_task):
"""Test cancelling a pending task."""
repo = TrainingTaskRepository()
result = repo.cancel(str(sample_training_task.task_id))
assert result is True
task = repo.get(str(sample_training_task.task_id))
assert task is not None
assert task.status == "cancelled"
def test_cannot_cancel_running_task(self, patched_session, sample_training_task):
"""Test that running task cannot be cancelled."""
repo = TrainingTaskRepository()
repo.update_status(str(sample_training_task.task_id), "running")
result = repo.cancel(str(sample_training_task.task_id))
assert result is False
task = repo.get(str(sample_training_task.task_id))
assert task.status == "running"
class TestTrainingLogs:
"""Tests for training log management."""
def test_add_log_entry(self, patched_session, sample_training_task):
"""Test adding a training log entry."""
repo = TrainingTaskRepository()
repo.add_log(
str(sample_training_task.task_id),
level="INFO",
message="Starting training...",
details={"epoch": 1, "batch": 0},
)
logs = repo.get_logs(str(sample_training_task.task_id))
assert len(logs) == 1
assert logs[0].level == "INFO"
assert logs[0].message == "Starting training..."
def test_add_multiple_log_entries(self, patched_session, sample_training_task):
"""Test adding multiple log entries."""
repo = TrainingTaskRepository()
for i in range(5):
repo.add_log(
str(sample_training_task.task_id),
level="INFO",
message=f"Epoch {i} completed",
details={"epoch": i, "loss": 0.5 - i * 0.1},
)
logs = repo.get_logs(str(sample_training_task.task_id))
assert len(logs) == 5
def test_get_logs_pagination(self, patched_session, sample_training_task):
"""Test paginated log retrieval."""
repo = TrainingTaskRepository()
for i in range(10):
repo.add_log(
str(sample_training_task.task_id),
level="INFO",
message=f"Log entry {i}",
)
logs = repo.get_logs(str(sample_training_task.task_id), limit=5, offset=0)
assert len(logs) == 5
logs_page2 = repo.get_logs(str(sample_training_task.task_id), limit=5, offset=5)
assert len(logs_page2) == 5
class TestDocumentLinks:
"""Tests for training document link management."""
def test_create_document_link(self, patched_session, sample_training_task, sample_document):
"""Test creating a document link."""
repo = TrainingTaskRepository()
link = repo.create_document_link(
task_id=sample_training_task.task_id,
document_id=sample_document.document_id,
annotation_snapshot={"count": 5, "verified": 3},
)
assert link is not None
assert link.task_id == sample_training_task.task_id
assert link.document_id == sample_document.document_id
assert link.annotation_snapshot["count"] == 5
def test_get_document_links(self, patched_session, sample_training_task, multiple_documents):
"""Test getting all document links for a task."""
repo = TrainingTaskRepository()
for doc in multiple_documents[:3]:
repo.create_document_link(
task_id=sample_training_task.task_id,
document_id=doc.document_id,
)
links = repo.get_document_links(sample_training_task.task_id)
assert len(links) == 3
def test_get_document_training_tasks(self, patched_session, admin_token, sample_document):
"""Test getting training tasks that used a document."""
repo = TrainingTaskRepository()
# Create multiple tasks using the same document
task1_id = repo.create(admin_token=admin_token.token, name="Task 1")
task2_id = repo.create(admin_token=admin_token.token, name="Task 2")
repo.create_document_link(
task_id=repo.get(task1_id).task_id,
document_id=sample_document.document_id,
)
repo.create_document_link(
task_id=repo.get(task2_id).task_id,
document_id=sample_document.document_id,
)
links = repo.get_document_training_tasks(sample_document.document_id)
assert len(links) == 2
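# Illustrative sketch (not part of the tests above): the lifecycle a worker
# would drive through this repository. The metrics and model path are
# placeholders produced by the real training run.
def run_task(repo: TrainingTaskRepository, task_id: str) -> None:
    repo.update_status(task_id, "running")
    try:
        repo.add_log(task_id, level="INFO", message="Training started")
        metrics = {"mAP": 0.0}  # filled in by the actual training run
        repo.update_status(
            task_id,
            "completed",
            result_metrics=metrics,
            model_path="/models/out.pt",
        )
    except Exception as exc:
        repo.update_status(task_id, "failed", error_message=str(exc))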

View File

@@ -0,0 +1 @@
"""Service integration tests."""

View File

@@ -0,0 +1,497 @@
"""
Dashboard Service Integration Tests
Tests DashboardStatsService and DashboardActivityService with real database operations.
"""
from datetime import datetime, timezone
from uuid import uuid4
import pytest
from inference.data.admin_models import (
AdminAnnotation,
AdminDocument,
AnnotationHistory,
ModelVersion,
TrainingDataset,
TrainingTask,
)
from inference.web.services.dashboard_service import (
DashboardStatsService,
DashboardActivityService,
is_annotation_complete,
IDENTIFIER_CLASS_IDS,
PAYMENT_CLASS_IDS,
)
class TestIsAnnotationComplete:
"""Tests for is_annotation_complete function."""
def test_complete_with_invoice_number_and_bankgiro(self):
"""Test complete with invoice_number (0) and bankgiro (4)."""
annotations = [
{"class_id": 0}, # invoice_number
{"class_id": 4}, # bankgiro
]
assert is_annotation_complete(annotations) is True
def test_complete_with_ocr_number_and_plusgiro(self):
"""Test complete with ocr_number (3) and plusgiro (5)."""
annotations = [
{"class_id": 3}, # ocr_number
{"class_id": 5}, # plusgiro
]
assert is_annotation_complete(annotations) is True
def test_incomplete_missing_identifier(self):
"""Test incomplete when missing identifier."""
annotations = [
{"class_id": 4}, # bankgiro only
]
assert is_annotation_complete(annotations) is False
def test_incomplete_missing_payment(self):
"""Test incomplete when missing payment."""
annotations = [
{"class_id": 0}, # invoice_number only
]
assert is_annotation_complete(annotations) is False
def test_incomplete_empty_annotations(self):
"""Test incomplete with empty annotations."""
assert is_annotation_complete([]) is False
def test_complete_with_multiple_fields(self):
"""Test complete with multiple fields."""
annotations = [
{"class_id": 0}, # invoice_number
{"class_id": 1}, # invoice_date
{"class_id": 3}, # ocr_number
{"class_id": 4}, # bankgiro
{"class_id": 5}, # plusgiro
{"class_id": 6}, # amount
]
assert is_annotation_complete(annotations) is True
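# Illustrative sketch (not part of the tests above): a reference
# implementation consistent with these cases, assuming IDENTIFIER_CLASS_IDS
# is a set covering invoice_number (0) and ocr_number (3), and
# PAYMENT_CLASS_IDS covers bankgiro (4) and plusgiro (5). The real
# implementation lives in inference.web.services.dashboard_service.
def _is_complete(annotations: list[dict]) -> bool:
    class_ids = {a["class_id"] for a in annotations}
    return bool(class_ids & IDENTIFIER_CLASS_IDS) and bool(class_ids & PAYMENT_CLASS_IDS)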
class TestDashboardStatsService:
"""Tests for DashboardStatsService."""
def test_get_stats_empty_database(self, patched_session):
"""Test stats with empty database."""
service = DashboardStatsService()
stats = service.get_stats()
assert stats["total_documents"] == 0
assert stats["annotation_complete"] == 0
assert stats["annotation_incomplete"] == 0
assert stats["pending"] == 0
assert stats["completeness_rate"] == 0.0
def test_get_stats_with_documents(self, patched_session, admin_token):
"""Test stats with various document states."""
service = DashboardStatsService()
session = patched_session
# Create documents with different statuses
for i, status in enumerate(["pending", "auto_labeling", "labeled", "labeled", "exported"]):
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename=f"doc_{i}.pdf",
file_size=1024,
content_type="application/pdf",
file_path=f"/uploads/doc_{i}.pdf",
page_count=1,
status=status,
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
session.commit()
stats = service.get_stats()
assert stats["total_documents"] == 5
assert stats["pending"] == 2 # pending + auto_labeling
def test_get_stats_complete_annotations(self, patched_session, admin_token):
"""Test completeness calculation with proper annotations."""
service = DashboardStatsService()
session = patched_session
# Create a labeled document with complete annotations
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename="complete_doc.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/uploads/complete_doc.pdf",
page_count=1,
status="labeled",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
session.commit()
# Add identifier annotation (invoice_number = class_id 0)
ann1 = AdminAnnotation(
annotation_id=uuid4(),
document_id=doc.document_id,
page_number=1,
class_id=0,
class_name="invoice_number",
x_center=0.5,
y_center=0.1,
width=0.2,
height=0.05,
bbox_x=400,
bbox_y=80,
bbox_width=160,
bbox_height=40,
text_value="INV-001",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(ann1)
# Add payment annotation (bankgiro = class_id 4)
ann2 = AdminAnnotation(
annotation_id=uuid4(),
document_id=doc.document_id,
page_number=1,
class_id=4,
class_name="bankgiro",
x_center=0.5,
y_center=0.2,
width=0.2,
height=0.05,
bbox_x=400,
bbox_y=160,
bbox_width=160,
bbox_height=40,
text_value="123-4567",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(ann2)
session.commit()
stats = service.get_stats()
assert stats["annotation_complete"] == 1
assert stats["annotation_incomplete"] == 0
assert stats["completeness_rate"] == 100.0
def test_get_stats_incomplete_annotations(self, patched_session, admin_token):
"""Test completeness with incomplete annotations."""
service = DashboardStatsService()
session = patched_session
# Create a labeled document missing payment annotation
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename="incomplete_doc.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/uploads/incomplete_doc.pdf",
page_count=1,
status="labeled",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
session.commit()
# Add only identifier annotation (missing payment)
ann = AdminAnnotation(
annotation_id=uuid4(),
document_id=doc.document_id,
page_number=1,
class_id=0,
class_name="invoice_number",
x_center=0.5,
y_center=0.1,
width=0.2,
height=0.05,
bbox_x=400,
bbox_y=80,
bbox_width=160,
bbox_height=40,
text_value="INV-001",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(ann)
session.commit()
stats = service.get_stats()
assert stats["annotation_complete"] == 0
assert stats["annotation_incomplete"] == 1
assert stats["completeness_rate"] == 0.0
def test_get_stats_mixed_completeness(self, patched_session, admin_token):
"""Test stats with mix of complete and incomplete documents."""
service = DashboardStatsService()
session = patched_session
# Create 2 labeled documents
docs = []
for i in range(2):
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename=f"mixed_doc_{i}.pdf",
file_size=1024,
content_type="application/pdf",
file_path=f"/uploads/mixed_doc_{i}.pdf",
page_count=1,
status="labeled",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
docs.append(doc)
session.commit()
# First document: complete (has identifier + payment)
session.add(AdminAnnotation(
annotation_id=uuid4(),
document_id=docs[0].document_id,
page_number=1,
class_id=0, # invoice_number
class_name="invoice_number",
x_center=0.5, y_center=0.1, width=0.2, height=0.05,
bbox_x=400, bbox_y=80, bbox_width=160, bbox_height=40,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
))
session.add(AdminAnnotation(
annotation_id=uuid4(),
document_id=docs[0].document_id,
page_number=1,
class_id=4, # bankgiro
class_name="bankgiro",
x_center=0.5, y_center=0.2, width=0.2, height=0.05,
bbox_x=400, bbox_y=160, bbox_width=160, bbox_height=40,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
))
# Second document: incomplete (missing payment)
session.add(AdminAnnotation(
annotation_id=uuid4(),
document_id=docs[1].document_id,
page_number=1,
class_id=0, # invoice_number only
class_name="invoice_number",
x_center=0.5, y_center=0.1, width=0.2, height=0.05,
bbox_x=400, bbox_y=80, bbox_width=160, bbox_height=40,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
))
session.commit()
stats = service.get_stats()
assert stats["annotation_complete"] == 1
assert stats["annotation_incomplete"] == 1
assert stats["completeness_rate"] == 50.0
class TestDashboardActivityService:
"""Tests for DashboardActivityService."""
def test_get_recent_activities_empty(self, patched_session):
"""Test activities with empty database."""
service = DashboardActivityService()
activities = service.get_recent_activities()
assert activities == []
def test_get_recent_activities_document_uploads(self, patched_session, admin_token):
"""Test activities include document uploads."""
service = DashboardActivityService()
session = patched_session
# Create documents
for i in range(3):
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename=f"activity_doc_{i}.pdf",
file_size=1024,
content_type="application/pdf",
file_path=f"/uploads/activity_doc_{i}.pdf",
page_count=1,
status="pending",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
session.commit()
activities = service.get_recent_activities()
upload_activities = [a for a in activities if a["type"] == "document_uploaded"]
assert len(upload_activities) == 3
def test_get_recent_activities_annotation_overrides(self, patched_session, sample_document, sample_annotation):
"""Test activities include annotation overrides."""
service = DashboardActivityService()
session = patched_session
# Create annotation history with override
history = AnnotationHistory(
history_id=uuid4(),
annotation_id=sample_annotation.annotation_id,
document_id=sample_document.document_id,
action="override",
previous_value={"text_value": "OLD-001"},
new_value={"text_value": "NEW-001", "class_name": "invoice_number"},
changed_by="test-admin",
created_at=datetime.now(timezone.utc),
)
session.add(history)
session.commit()
activities = service.get_recent_activities()
override_activities = [a for a in activities if a["type"] == "annotation_modified"]
assert len(override_activities) >= 1
def test_get_recent_activities_training_completed(self, patched_session, sample_training_task):
"""Test activities include training completions."""
service = DashboardActivityService()
session = patched_session
# Update training task to completed
sample_training_task.status = "completed"
sample_training_task.metrics_mAP = 0.85
sample_training_task.updated_at = datetime.now(timezone.utc)
session.add(sample_training_task)
session.commit()
activities = service.get_recent_activities()
training_activities = [a for a in activities if a["type"] == "training_completed"]
assert len(training_activities) >= 1
assert "mAP" in training_activities[0]["metadata"]
def test_get_recent_activities_training_failed(self, patched_session, sample_training_task):
"""Test activities include training failures."""
service = DashboardActivityService()
session = patched_session
# Update training task to failed
sample_training_task.status = "failed"
sample_training_task.error_message = "CUDA out of memory"
sample_training_task.updated_at = datetime.now(timezone.utc)
session.add(sample_training_task)
session.commit()
activities = service.get_recent_activities()
failed_activities = [a for a in activities if a["type"] == "training_failed"]
assert len(failed_activities) >= 1
assert failed_activities[0]["metadata"]["error"] == "CUDA out of memory"
def test_get_recent_activities_model_activated(self, patched_session, sample_model_version):
"""Test activities include model activations."""
service = DashboardActivityService()
session = patched_session
# Activate model
sample_model_version.is_active = True
sample_model_version.activated_at = datetime.now(timezone.utc)
session.add(sample_model_version)
session.commit()
activities = service.get_recent_activities()
activation_activities = [a for a in activities if a["type"] == "model_activated"]
assert len(activation_activities) >= 1
assert activation_activities[0]["metadata"]["version"] == sample_model_version.version
def test_get_recent_activities_limit(self, patched_session, admin_token):
"""Test activity limit parameter."""
service = DashboardActivityService()
session = patched_session
# Create many documents
for i in range(20):
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename=f"limit_doc_{i}.pdf",
file_size=1024,
content_type="application/pdf",
file_path=f"/uploads/limit_doc_{i}.pdf",
page_count=1,
status="pending",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
session.commit()
activities = service.get_recent_activities(limit=5)
assert len(activities) <= 5
def test_get_recent_activities_sorted_by_timestamp(self, patched_session, admin_token, sample_training_task):
"""Test activities are sorted by timestamp descending."""
service = DashboardActivityService()
session = patched_session
# Create document
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename="sorted_doc.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/uploads/sorted_doc.pdf",
page_count=1,
status="pending",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
# Complete training task
sample_training_task.status = "completed"
sample_training_task.metrics_mAP = 0.90
sample_training_task.updated_at = datetime.now(timezone.utc)
session.add(sample_training_task)
session.commit()
activities = service.get_recent_activities()
# Verify sorted by timestamp DESC
timestamps = [a["timestamp"] for a in activities]
assert timestamps == sorted(timestamps, reverse=True)
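# Across these tests the activity feed is a flat list of dicts with a common
# shape. A representative item with illustrative values (the tests only pin
# the "type" values plus the "mAP", "error", and "version" metadata keys):
_example_activity = {
    "type": "training_completed",  # or document_uploaded, annotation_modified, ...
    "timestamp": datetime.now(timezone.utc),  # sort key, newest first
    "metadata": {"mAP": 0.85},
}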

View File

@@ -0,0 +1,453 @@
"""
Dataset Builder Service Integration Tests
Tests DatasetBuilder with real file operations and repository interactions.
"""
from uuid import uuid4
import pytest
import yaml
from inference.data.repositories.annotation_repository import AnnotationRepository
from inference.data.repositories.dataset_repository import DatasetRepository
from inference.data.repositories.document_repository import DocumentRepository
from inference.web.services.dataset_builder import DatasetBuilder
@pytest.fixture
def dataset_builder(patched_session, temp_dataset_dir):
"""Create a DatasetBuilder with real repositories."""
return DatasetBuilder(
datasets_repo=DatasetRepository(),
documents_repo=DocumentRepository(),
annotations_repo=AnnotationRepository(),
base_dir=temp_dataset_dir,
)
@pytest.fixture
def admin_images_dir(temp_upload_dir):
"""Create a directory for admin images."""
images_dir = temp_upload_dir / "admin_images"
images_dir.mkdir(parents=True, exist_ok=True)
return images_dir
@pytest.fixture
def documents_with_annotations(patched_session, db_session, admin_token, admin_images_dir):
"""Create documents with annotations and corresponding image files."""
documents = []
doc_repo = DocumentRepository()
ann_repo = AnnotationRepository()
for i in range(5):
# Create document
doc_id = doc_repo.create(
filename=f"invoice_{i}.pdf",
file_size=1024,
content_type="application/pdf",
file_path=f"/uploads/invoice_{i}.pdf",
page_count=2,
category="invoice",
group_key=f"group_{i % 2}", # Two groups
)
# Create image files for each page
doc_dir = admin_images_dir / doc_id
doc_dir.mkdir(parents=True, exist_ok=True)
for page in range(1, 3):
image_path = doc_dir / f"page_{page}.png"
# Create a minimal fake PNG
image_path.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100)
# Create annotations
for j in range(3):
ann_repo.create(
document_id=doc_id,
page_number=1,
class_id=j,
class_name=f"field_{j}",
x_center=0.5,
y_center=0.1 + j * 0.2,
width=0.2,
height=0.05,
bbox_x=400,
bbox_y=80 + j * 160,
bbox_width=160,
bbox_height=40,
text_value=f"value_{j}",
confidence=0.95,
source="auto",
)
doc = doc_repo.get(doc_id)
documents.append(doc)
return documents
class TestDatasetBuilderBasic:
"""Tests for basic dataset building operations."""
def test_build_dataset_creates_directory_structure(
self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
):
"""Test that building creates proper directory structure."""
dataset_repo = DatasetRepository()
dataset = dataset_repo.create(name="Test Dataset")
doc_ids = [str(d.document_id) for d in documents_with_annotations]
dataset_builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=doc_ids,
train_ratio=0.8,
val_ratio=0.1,
seed=42,
admin_images_dir=admin_images_dir,
)
dataset_dir = temp_dataset_dir / str(dataset.dataset_id)
# Check directory structure
assert (dataset_dir / "images" / "train").exists()
assert (dataset_dir / "images" / "val").exists()
assert (dataset_dir / "images" / "test").exists()
assert (dataset_dir / "labels" / "train").exists()
assert (dataset_dir / "labels" / "val").exists()
assert (dataset_dir / "labels" / "test").exists()
def test_build_dataset_copies_images(
self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
):
"""Test that images are copied to dataset directory."""
dataset_repo = DatasetRepository()
dataset = dataset_repo.create(name="Image Copy Test")
doc_ids = [str(d.document_id) for d in documents_with_annotations]
result = dataset_builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=doc_ids,
train_ratio=0.8,
val_ratio=0.1,
seed=42,
admin_images_dir=admin_images_dir,
)
dataset_dir = temp_dataset_dir / str(dataset.dataset_id)
# Count total images across all splits
total_images = 0
for split in ["train", "val", "test"]:
images = list((dataset_dir / "images" / split).glob("*.png"))
total_images += len(images)
# 5 docs * 2 pages = 10 images
assert total_images == 10
assert result["total_images"] == 10
def test_build_dataset_generates_labels(
self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
):
"""Test that YOLO label files are generated."""
dataset_repo = DatasetRepository()
dataset = dataset_repo.create(name="Label Generation Test")
doc_ids = [str(d.document_id) for d in documents_with_annotations]
dataset_builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=doc_ids,
train_ratio=0.8,
val_ratio=0.1,
seed=42,
admin_images_dir=admin_images_dir,
)
dataset_dir = temp_dataset_dir / str(dataset.dataset_id)
# Count total label files
total_labels = 0
for split in ["train", "val", "test"]:
labels = list((dataset_dir / "labels" / split).glob("*.txt"))
total_labels += len(labels)
# Same count as images
assert total_labels == 10
def test_build_dataset_generates_data_yaml(
self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
):
"""Test that data.yaml is generated correctly."""
dataset_repo = DatasetRepository()
dataset = dataset_repo.create(name="YAML Generation Test")
doc_ids = [str(d.document_id) for d in documents_with_annotations]
dataset_builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=doc_ids,
train_ratio=0.8,
val_ratio=0.1,
seed=42,
admin_images_dir=admin_images_dir,
)
dataset_dir = temp_dataset_dir / str(dataset.dataset_id)
yaml_path = dataset_dir / "data.yaml"
assert yaml_path.exists()
with open(yaml_path) as f:
data = yaml.safe_load(f)
assert data["train"] == "images/train"
assert data["val"] == "images/val"
assert data["test"] == "images/test"
assert "nc" in data
assert "names" in data
class TestDatasetBuilderSplits:
"""Tests for train/val/test split assignment."""
def test_split_ratio_respected(
self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
):
"""Test that split ratios are approximately respected."""
dataset_repo = DatasetRepository()
dataset = dataset_repo.create(name="Split Ratio Test")
doc_ids = [str(d.document_id) for d in documents_with_annotations]
dataset_builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=doc_ids,
train_ratio=0.6,
val_ratio=0.2,
seed=42,
admin_images_dir=admin_images_dir,
)
# Check document assignments in database
dataset_docs = dataset_repo.get_documents(str(dataset.dataset_id))
splits = {"train": 0, "val": 0, "test": 0}
for doc in dataset_docs:
splits[doc.split] += 1
# With 5 docs and ratios 0.6/0.2/0.2, expect ~3/1/1
# Due to rounding and group constraints, allow some variation
assert splits["train"] >= 2
assert splits["val"] >= 1 or splits["test"] >= 1
def test_same_seed_same_split(
self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
):
"""Test that same seed produces same split."""
dataset_repo = DatasetRepository()
doc_ids = [str(d.document_id) for d in documents_with_annotations]
# Build first dataset
dataset1 = dataset_repo.create(name="Seed Test 1")
dataset_builder.build_dataset(
dataset_id=str(dataset1.dataset_id),
document_ids=doc_ids,
train_ratio=0.8,
val_ratio=0.1,
seed=12345,
admin_images_dir=admin_images_dir,
)
# Build second dataset with same seed
dataset2 = dataset_repo.create(name="Seed Test 2")
dataset_builder.build_dataset(
dataset_id=str(dataset2.dataset_id),
document_ids=doc_ids,
train_ratio=0.8,
val_ratio=0.1,
seed=12345,
admin_images_dir=admin_images_dir,
)
# Compare splits
docs1 = {str(d.document_id): d.split for d in dataset_repo.get_documents(str(dataset1.dataset_id))}
docs2 = {str(d.document_id): d.split for d in dataset_repo.get_documents(str(dataset2.dataset_id))}
assert docs1 == docs2
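# Determinism in the test above comes from seeding. A minimal sketch of
# seeded, group-aware split assignment consistent with these tests
# (illustrative only; the builder's real algorithm may differ in detail):
import random

def _assign_splits(group_keys: list[str], train_ratio: float, val_ratio: float, seed: int) -> dict[str, str]:
    # Shuffle whole groups, not individual documents, so grouped pages share a split.
    groups = sorted(set(group_keys))
    random.Random(seed).shuffle(groups)
    n_train = round(len(groups) * train_ratio)
    n_val = round(len(groups) * val_ratio)
    return {
        g: "train" if i < n_train else "val" if i < n_train + n_val else "test"
        for i, g in enumerate(groups)
    }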
class TestDatasetBuilderDatabase:
"""Tests for database interactions."""
def test_updates_dataset_status(
self, dataset_builder, documents_with_annotations, admin_images_dir, patched_session
):
"""Test that dataset status is updated after build."""
dataset_repo = DatasetRepository()
dataset = dataset_repo.create(name="Status Update Test")
doc_ids = [str(d.document_id) for d in documents_with_annotations]
dataset_builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=doc_ids,
train_ratio=0.8,
val_ratio=0.1,
seed=42,
admin_images_dir=admin_images_dir,
)
updated = dataset_repo.get(str(dataset.dataset_id))
assert updated.status == "ready"
assert updated.total_documents == 5
assert updated.total_images == 10
assert updated.total_annotations > 0
assert updated.dataset_path is not None
def test_records_document_assignments(
self, dataset_builder, documents_with_annotations, admin_images_dir, patched_session
):
"""Test that document assignments are recorded in database."""
dataset_repo = DatasetRepository()
dataset = dataset_repo.create(name="Assignment Recording Test")
doc_ids = [str(d.document_id) for d in documents_with_annotations]
dataset_builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=doc_ids,
train_ratio=0.8,
val_ratio=0.1,
seed=42,
admin_images_dir=admin_images_dir,
)
dataset_docs = dataset_repo.get_documents(str(dataset.dataset_id))
assert len(dataset_docs) == 5
for doc in dataset_docs:
assert doc.split in ["train", "val", "test"]
assert doc.page_count > 0
class TestDatasetBuilderErrors:
"""Tests for error handling."""
def test_fails_with_no_documents(self, dataset_builder, admin_images_dir, patched_session):
"""Test that building fails with empty document list."""
dataset_repo = DatasetRepository()
dataset = dataset_repo.create(name="Empty Docs Test")
with pytest.raises(ValueError, match="No valid documents"):
dataset_builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=[],
train_ratio=0.8,
val_ratio=0.1,
seed=42,
admin_images_dir=admin_images_dir,
)
def test_fails_with_invalid_doc_ids(self, dataset_builder, admin_images_dir, patched_session):
"""Test that building fails with nonexistent document IDs."""
dataset_repo = DatasetRepository()
dataset = dataset_repo.create(name="Invalid IDs Test")
fake_ids = [str(uuid4()) for _ in range(3)]
with pytest.raises(ValueError, match="No valid documents"):
dataset_builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=fake_ids,
train_ratio=0.8,
val_ratio=0.1,
seed=42,
admin_images_dir=admin_images_dir,
)
def test_updates_status_on_failure(self, dataset_builder, admin_images_dir, patched_session):
"""Test that dataset status is set to failed on error."""
dataset_repo = DatasetRepository()
dataset = dataset_repo.create(name="Failure Status Test")
try:
dataset_builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=[],
train_ratio=0.8,
val_ratio=0.1,
seed=42,
admin_images_dir=admin_images_dir,
)
except ValueError:
pass
updated = dataset_repo.get(str(dataset.dataset_id))
assert updated.status == "failed"
assert updated.error_message is not None
class TestLabelFileFormat:
"""Tests for YOLO label file format."""
def test_label_file_format(
self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
):
"""Test that label files are in correct YOLO format."""
dataset_repo = DatasetRepository()
dataset = dataset_repo.create(name="Label Format Test")
doc_ids = [str(d.document_id) for d in documents_with_annotations]
dataset_builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=doc_ids,
train_ratio=0.8,
val_ratio=0.1,
seed=42,
admin_images_dir=admin_images_dir,
)
dataset_dir = temp_dataset_dir / str(dataset.dataset_id)
# Find a label file with content
label_files = []
for split in ["train", "val", "test"]:
label_files.extend(list((dataset_dir / "labels" / split).glob("*.txt")))
# Check at least one label file has correct format
found_valid_label = False
for label_file in label_files:
content = label_file.read_text().strip()
if content:
lines = content.split("\n")
for line in lines:
parts = line.split()
assert len(parts) == 5, f"Expected 5 parts, got {len(parts)}: {line}"
class_id = int(parts[0])
x_center = float(parts[1])
y_center = float(parts[2])
width = float(parts[3])
height = float(parts[4])
assert 0 <= class_id < 10
assert 0 <= x_center <= 1
assert 0 <= y_center <= 1
assert 0 <= width <= 1
assert 0 <= height <= 1
found_valid_label = True
break
assert found_valid_label, "No valid label files found"
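# Each line parsed above is a standard YOLO annotation: class_id followed by
# x_center, y_center, width, height, all normalized to [0, 1]. The fixture's
# first annotation (class_id 0, centered at x=0.5, y=0.1) would serialize to
# something like this (decimal formatting is illustrative):
_line = "0 0.500000 0.100000 0.200000 0.050000"
_class_id, _x, _y, _w, _h = _line.split()
assert int(_class_id) == 0 and float(_x) == 0.5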

View File

@@ -0,0 +1,283 @@
"""
Document Service Integration Tests
Tests DocumentService against an in-memory mock storage backend.
"""
import pytest
from inference.web.services.document_service import DocumentService
class MockStorageBackend:
"""Simple in-memory storage backend for testing."""
def __init__(self):
self._files: dict[str, bytes] = {}
def upload_bytes(self, content: bytes, remote_path: str, overwrite: bool = False) -> None:
if not overwrite and remote_path in self._files:
raise FileExistsError(f"File already exists: {remote_path}")
self._files[remote_path] = content
def download_bytes(self, remote_path: str) -> bytes:
if remote_path not in self._files:
raise FileNotFoundError(f"File not found: {remote_path}")
return self._files[remote_path]
def get_presigned_url(self, remote_path: str, expires_in_seconds: int = 3600) -> str:
return f"https://storage.example.com/{remote_path}?expires={expires_in_seconds}"
def exists(self, remote_path: str) -> bool:
return remote_path in self._files
def delete(self, remote_path: str) -> bool:
if remote_path in self._files:
del self._files[remote_path]
return True
return False
def list_files(self, prefix: str) -> list[str]:
return [path for path in self._files.keys() if path.startswith(prefix)]
@pytest.fixture
def mock_storage():
"""Create a mock storage backend."""
return MockStorageBackend()
@pytest.fixture
def document_service(mock_storage):
"""Create a DocumentService with mock storage."""
return DocumentService(storage_backend=mock_storage)
class TestDocumentUpload:
"""Tests for document upload operations."""
def test_upload_document(self, document_service):
"""Test uploading a document."""
content = b"%PDF-1.4 test content"
filename = "test_invoice.pdf"
result = document_service.upload_document(content, filename)
assert result is not None
assert result.id is not None
assert result.filename == filename
assert result.file_path.startswith("documents/")
assert result.file_path.endswith(".pdf")
def test_upload_document_with_custom_id(self, document_service):
"""Test uploading with custom document ID."""
content = b"%PDF-1.4 test content"
filename = "invoice.pdf"
custom_id = "custom-doc-12345"
result = document_service.upload_document(
content, filename, document_id=custom_id
)
assert result.id == custom_id
assert custom_id in result.file_path
def test_upload_preserves_extension(self, document_service):
"""Test that file extension is preserved."""
cases = [
("document.pdf", ".pdf"),
("image.PNG", ".png"),
("file.JPEG", ".jpeg"),
("noextension", ""),
]
for filename, expected_ext in cases:
result = document_service.upload_document(b"content", filename)
if expected_ext:
assert result.file_path.endswith(expected_ext)
def test_upload_document_overwrite(self, document_service, mock_storage):
"""Test that upload overwrites existing file."""
content1 = b"original content"
content2 = b"new content"
doc_id = "overwrite-test"
document_service.upload_document(content1, "doc.pdf", document_id=doc_id)
document_service.upload_document(content2, "doc.pdf", document_id=doc_id)
# Should have new content
remote_path = f"documents/{doc_id}.pdf"
stored_content = mock_storage.download_bytes(remote_path)
assert stored_content == content2
class TestDocumentDownload:
"""Tests for document download operations."""
def test_download_document(self, document_service, mock_storage):
"""Test downloading a document."""
content = b"test document content"
remote_path = "documents/test-doc.pdf"
mock_storage.upload_bytes(content, remote_path)
downloaded = document_service.download_document(remote_path)
assert downloaded == content
def test_download_nonexistent_document(self, document_service):
"""Test downloading document that doesn't exist."""
with pytest.raises(FileNotFoundError):
document_service.download_document("documents/nonexistent.pdf")
class TestDocumentUrl:
"""Tests for document URL generation."""
def test_get_document_url(self, document_service, mock_storage):
"""Test getting presigned URL for document."""
remote_path = "documents/test-doc.pdf"
mock_storage.upload_bytes(b"content", remote_path)
url = document_service.get_document_url(remote_path, expires_in_seconds=7200)
assert url.startswith("https://")
assert remote_path in url
assert "7200" in url
def test_get_document_url_default_expiry(self, document_service):
"""Test default URL expiry."""
url = document_service.get_document_url("documents/doc.pdf")
assert "3600" in url
class TestDocumentExists:
"""Tests for document existence check."""
def test_document_exists(self, document_service, mock_storage):
"""Test checking if document exists."""
remote_path = "documents/existing.pdf"
mock_storage.upload_bytes(b"content", remote_path)
assert document_service.document_exists(remote_path) is True
def test_document_not_exists(self, document_service):
"""Test checking if nonexistent document exists."""
assert document_service.document_exists("documents/nonexistent.pdf") is False
class TestDocumentDelete:
"""Tests for document deletion."""
def test_delete_document(self, document_service, mock_storage):
"""Test deleting a document."""
remote_path = "documents/to-delete.pdf"
mock_storage.upload_bytes(b"content", remote_path)
result = document_service.delete_document_files(remote_path)
assert result is True
assert document_service.document_exists(remote_path) is False
def test_delete_nonexistent_document(self, document_service):
"""Test deleting document that doesn't exist."""
result = document_service.delete_document_files("documents/nonexistent.pdf")
assert result is False
class TestPageImages:
"""Tests for page image operations."""
def test_save_page_image(self, document_service, mock_storage):
"""Test saving a page image."""
doc_id = "test-doc-123"
page_num = 1
image_content = b"\x89PNG\r\n\x1a\n fake png"
remote_path = document_service.save_page_image(doc_id, page_num, image_content)
assert remote_path == f"images/{doc_id}/page_{page_num}.png"
assert mock_storage.exists(remote_path)
def test_save_multiple_page_images(self, document_service, mock_storage):
"""Test saving images for multiple pages."""
doc_id = "multi-page-doc"
for page_num in range(1, 4):
content = f"page {page_num} content".encode()
document_service.save_page_image(doc_id, page_num, content)
images = document_service.list_document_images(doc_id)
assert len(images) == 3
def test_get_page_image(self, document_service, mock_storage):
"""Test downloading a page image."""
doc_id = "test-doc"
page_num = 2
image_content = b"image data"
document_service.save_page_image(doc_id, page_num, image_content)
downloaded = document_service.get_page_image(doc_id, page_num)
assert downloaded == image_content
def test_get_page_image_url(self, document_service):
"""Test getting URL for page image."""
doc_id = "test-doc"
page_num = 1
url = document_service.get_page_image_url(doc_id, page_num)
assert f"images/{doc_id}/page_{page_num}.png" in url
def test_list_document_images(self, document_service, mock_storage):
"""Test listing all images for a document."""
doc_id = "list-test-doc"
for i in range(5):
document_service.save_page_image(doc_id, i + 1, f"page {i}".encode())
images = document_service.list_document_images(doc_id)
assert len(images) == 5
def test_delete_document_images(self, document_service, mock_storage):
"""Test deleting all images for a document."""
doc_id = "delete-images-doc"
for i in range(3):
document_service.save_page_image(doc_id, i + 1, b"content")
deleted_count = document_service.delete_document_images(doc_id)
assert deleted_count == 3
assert len(document_service.list_document_images(doc_id)) == 0
class TestRoundTrip:
"""Tests for complete upload-download cycles."""
def test_document_round_trip(self, document_service):
"""Test uploading and downloading document."""
original_content = b"%PDF-1.4 complete document content here"
filename = "roundtrip.pdf"
result = document_service.upload_document(original_content, filename)
downloaded = document_service.download_document(result.file_path)
assert downloaded == original_content
def test_image_round_trip(self, document_service):
"""Test saving and retrieving page image."""
doc_id = "roundtrip-doc"
page_num = 1
original_image = b"\x89PNG fake image data"
document_service.save_page_image(doc_id, page_num, original_image)
retrieved = document_service.get_page_image(doc_id, page_num)
assert retrieved == original_image
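# Putting the pieces together, a minimal usage sketch of the service under
# test with the in-memory backend above; DocumentService relies only on the
# duck-typed upload/download/exists/delete/list interface:
_storage = MockStorageBackend()
_service = DocumentService(storage_backend=_storage)
_doc = _service.upload_document(b"%PDF-1.4 minimal", "invoice.pdf")
assert _service.document_exists(_doc.file_path)
assert _service.download_document(_doc.file_path) == b"%PDF-1.4 minimal"
_url = _service.get_document_url(_doc.file_path)  # presigned-style URL from the mock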

View File

@@ -0,0 +1,258 @@
"""
Database Setup Integration Tests
Tests for database connection, session management, and basic operations.
"""
from sqlmodel import Session, select
from inference.data.admin_models import AdminDocument, AdminToken
class TestDatabaseConnection:
"""Tests for database engine and connection."""
def test_engine_connection(self, test_engine):
"""Verify database engine can establish connection."""
with test_engine.connect() as conn:
result = conn.execute(select(1))
assert result.scalar() == 1
def test_tables_created(self, test_engine):
"""Verify all expected tables are created."""
from sqlmodel import SQLModel
table_names = SQLModel.metadata.tables.keys()
expected_tables = [
"admin_tokens",
"admin_documents",
"admin_annotations",
"training_tasks",
"training_logs",
"batch_uploads",
"batch_upload_files",
"training_datasets",
"dataset_documents",
"training_document_links",
"model_versions",
]
for table in expected_tables:
assert table in table_names, f"Table '{table}' not found"
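# These tables exist because the model imports register them on SQLModel's
# metadata and the test engine creates them up front. The canonical setup,
# presumably what conftest does, is sketched here:
from sqlmodel import SQLModel, create_engine

_sketch_engine = create_engine("sqlite://")  # in-memory, as the fixtures suggest
SQLModel.metadata.create_all(_sketch_engine)  # creates every table listed above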
class TestSessionManagement:
"""Tests for database session context manager."""
def test_session_commit(self, db_session):
"""Verify session commits changes successfully."""
token = AdminToken(
token="commit-test-token",
name="Commit Test",
is_active=True,
)
db_session.add(token)
db_session.commit()
result = db_session.exec(
select(AdminToken).where(AdminToken.token == "commit-test-token")
).first()
assert result is not None
assert result.name == "Commit Test"
def test_session_rollback_on_error(self, test_engine):
"""Verify session rollback on exception."""
session = Session(test_engine)
try:
token = AdminToken(
token="rollback-test-token",
name="Rollback Test",
is_active=True,
)
session.add(token)
session.commit()
# Try to insert duplicate (should fail)
duplicate = AdminToken(
token="rollback-test-token", # Same primary key
name="Duplicate",
is_active=True,
)
session.add(duplicate)
session.commit()
except Exception:
session.rollback()
finally:
session.close()
# Verify original record exists
with Session(test_engine) as verify_session:
result = verify_session.exec(
select(AdminToken).where(AdminToken.token == "rollback-test-token")
).first()
assert result is not None
assert result.name == "Rollback Test"
def test_session_isolation(self, test_engine):
"""Verify sessions are isolated from each other."""
session1 = Session(test_engine)
session2 = Session(test_engine)
try:
# Insert in session1, don't commit
token = AdminToken(
token="isolation-test-token",
name="Isolation Test",
is_active=True,
)
session1.add(token)
session1.flush()
# Under proper isolation session2 would not see session1's uncommitted
# row, but SQLite in-memory engines differ here, so only post-commit
# visibility is asserted below.
session1.commit()
result = session2.exec(
select(AdminToken).where(AdminToken.token == "isolation-test-token")
).first()
# After commit, session2 should see the data
assert result is not None
finally:
session1.close()
session2.close()
class TestBasicCRUDOperations:
"""Tests for basic CRUD operations on database."""
def test_create_and_read_token(self, db_session):
"""Test creating and reading admin token."""
token = AdminToken(
token="crud-test-token",
name="CRUD Test",
is_active=True,
)
db_session.add(token)
db_session.commit()
result = db_session.get(AdminToken, "crud-test-token")
assert result is not None
assert result.name == "CRUD Test"
assert result.is_active is True
def test_update_entity(self, db_session, admin_token):
"""Test updating an entity."""
admin_token.name = "Updated Name"
db_session.add(admin_token)
db_session.commit()
result = db_session.get(AdminToken, admin_token.token)
assert result is not None
assert result.name == "Updated Name"
def test_delete_entity(self, db_session):
"""Test deleting an entity."""
token = AdminToken(
token="delete-test-token",
name="Delete Test",
is_active=True,
)
db_session.add(token)
db_session.commit()
db_session.delete(token)
db_session.commit()
result = db_session.get(AdminToken, "delete-test-token")
assert result is None
def test_foreign_key_constraint(self, db_session, admin_token):
"""Test foreign key constraints are enforced."""
from uuid import uuid4
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename="fk_test.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/test/fk_test.pdf",
page_count=1,
status="pending",
)
db_session.add(doc)
db_session.commit()
# Document should reference valid token
result = db_session.get(AdminDocument, doc.document_id)
assert result is not None
assert result.admin_token == admin_token.token
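# One caveat behind this test: SQLite only enforces foreign keys when the
# pragma is enabled per connection, so an engine that should raise real FK
# errors is typically wired like this (a sketch; the project's conftest may
# already do the equivalent):
from sqlalchemy import event
from sqlmodel import create_engine as _create_engine

_fk_engine = _create_engine("sqlite://")

@event.listens_for(_fk_engine, "connect")
def _enable_sqlite_fks(dbapi_connection, connection_record):
    cursor = dbapi_connection.cursor()
    cursor.execute("PRAGMA foreign_keys=ON")
    cursor.close()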
class TestQueryOperations:
"""Tests for various query operations."""
def test_select_with_filter(self, db_session, multiple_documents):
"""Test SELECT with WHERE clause."""
results = db_session.exec(
select(AdminDocument).where(AdminDocument.status == "labeled")
).all()
assert len(results) == 2
for doc in results:
assert doc.status == "labeled"
def test_select_with_order(self, db_session, multiple_documents):
"""Test SELECT with ORDER BY clause."""
results = db_session.exec(
select(AdminDocument).order_by(AdminDocument.file_size.desc())
).all()
file_sizes = [doc.file_size for doc in results]
assert file_sizes == sorted(file_sizes, reverse=True)
def test_select_with_limit_offset(self, db_session, multiple_documents):
"""Test SELECT with LIMIT and OFFSET."""
results = db_session.exec(
select(AdminDocument)
.order_by(AdminDocument.filename)
.offset(2)
.limit(2)
).all()
assert len(results) == 2
def test_count_query(self, db_session, multiple_documents):
"""Test COUNT aggregation."""
from sqlalchemy import func
count = db_session.exec(
select(func.count()).select_from(AdminDocument)
).one()
assert count == 5
def test_group_by_query(self, db_session, multiple_documents):
"""Test GROUP BY aggregation."""
from sqlalchemy import func
results = db_session.exec(
select(
AdminDocument.status,
func.count(AdminDocument.document_id).label("count"),
).group_by(AdminDocument.status)
).all()
status_counts = {row[0]: row[1] for row in results}
assert status_counts.get("pending") == 2
assert status_counts.get("labeled") == 2
assert status_counts.get("exported") == 1