466 lines
14 KiB
Python
466 lines
14 KiB
Python
"""
|
|
Integration Test Fixtures
|
|
|
|
Provides shared fixtures for integration tests using PostgreSQL.
|
|
|
|
IMPORTANT: Integration tests MUST use Docker testcontainers for database isolation.
|
|
This ensures tests never touch the real production/development database.
|
|
|
|
Supported modes:
|
|
1. Docker testcontainers (default): Automatically starts a PostgreSQL container
|
|
2. TEST_DB_URL environment variable: Use a dedicated test database (NOT production!)
|
|
|
|
To use an external test database, set:
|
|
TEST_DB_URL=postgresql://user:password@host:port/test_dbname
|
|
"""
|
|
|
|
import os
|
|
import tempfile
|
|
from contextlib import contextmanager, ExitStack
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Generator
|
|
from unittest.mock import patch
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
from sqlmodel import Session, SQLModel, create_engine
|
|
|
|
from backend.data.admin_models import (
|
|
AdminAnnotation,
|
|
AdminDocument,
|
|
AdminToken,
|
|
AnnotationHistory,
|
|
BatchUpload,
|
|
BatchUploadFile,
|
|
DatasetDocument,
|
|
ModelVersion,
|
|
TrainingDataset,
|
|
TrainingDocumentLink,
|
|
TrainingLog,
|
|
TrainingTask,
|
|
)
|
|
|
|
|
|
# =============================================================================
|
|
# Database Fixtures
|
|
# =============================================================================
|
|
|
|
|
|
def _is_docker_available() -> bool:
    """Return True when a reachable Docker daemon is present, else False.

    Every failure mode (docker package missing, daemon not running,
    connection refused) is treated as "Docker unavailable" rather than
    being raised to the caller.
    """
    try:
        import docker
    except Exception:
        return False
    try:
        docker.from_env().ping()
    except Exception:
        return False
    return True
|
|
|
|
|
|
def _get_test_db_url() -> str | None:
    """Read the dedicated test-database URL from the environment.

    Returns None when TEST_DB_URL is unset, which callers interpret
    as "no external test database configured".
    """
    url = os.environ.get("TEST_DB_URL")
    return url
|
|
|
|
|
|
@pytest.fixture(scope="session")
def test_engine():
    """Create a SQLAlchemy engine for testing.

    Uses one of:
    1. TEST_DB_URL environment variable (dedicated test database)
    2. Docker testcontainers (if Docker is available)

    IMPORTANT: Will NOT fall back to production database. If Docker is not
    available and TEST_DB_URL is not set, tests will fail with a clear error.

    The engine is shared across all tests in a session for efficiency.

    Yields:
        A SQLAlchemy Engine with all SQLModel tables created; tables are
        dropped and the engine disposed at session teardown.
    """
    # Container handle kept in a local (not stashed on the function object,
    # which would mutate module-level state) so cleanup can always find it.
    postgres = None

    # Try to get URL from environment first
    connection_url = _get_test_db_url()

    if connection_url:
        # Use external test database from environment.
        # Warn if it looks like a production database.
        if "docmaster" in connection_url and "_test" not in connection_url:
            import warnings
            warnings.warn(
                "TEST_DB_URL appears to point to a production database. "
                "Please use a dedicated test database (e.g., docmaster_test).",
                UserWarning,
            )
    elif _is_docker_available():
        # Use testcontainers - this is the recommended approach
        from testcontainers.postgres import PostgresContainer
        postgres = PostgresContainer("postgres:15-alpine")
        postgres.start()

        # testcontainers may emit a psycopg2 dialect URL; normalize it.
        connection_url = postgres.get_connection_url()
        if "psycopg2" in connection_url:
            connection_url = connection_url.replace("postgresql+psycopg2://", "postgresql://")
    else:
        # No Docker and no TEST_DB_URL - fail with clear instructions
        pytest.fail(
            "Integration tests require Docker or a TEST_DB_URL environment variable.\n\n"
            "Option 1 (Recommended): Install Docker Desktop and ensure it's running.\n"
            "  - Windows: https://docs.docker.com/desktop/install/windows-install/\n"
            "  - The testcontainers library will automatically create a PostgreSQL container.\n\n"
            "Option 2: Set TEST_DB_URL to a dedicated test database:\n"
            "  - export TEST_DB_URL=postgresql://user:password@host:port/test_dbname\n"
            "  - NEVER use your production database for tests!\n\n"
            "Integration tests will NOT fall back to the production database."
        )

    # try/finally guarantees the container is stopped even if engine
    # creation or schema setup raises (the original leaked it on failure).
    try:
        engine = create_engine(
            connection_url,
            echo=False,
            # Validate pooled connections before use; containers can drop idle ones.
            pool_pre_ping=True,
        )

        # Create all tables
        SQLModel.metadata.create_all(engine)

        yield engine

        # Cleanup
        SQLModel.metadata.drop_all(engine)
        engine.dispose()
    finally:
        # Stop container if we started one
        if postgres is not None:
            postgres.stop()
|
|
|
|
|
|
@pytest.fixture(scope="function")
def db_session(test_engine) -> Generator[Session, None, None]:
    """Yield a per-test Session bound to an outer transaction.

    The transaction is rolled back after the test, so nothing a test
    writes is ever visible to other tests — full isolation.
    """
    conn = test_engine.connect()
    outer_txn = conn.begin()
    sess = Session(bind=conn)

    yield sess

    # Teardown: discard everything the test wrote, then release resources.
    sess.close()
    outer_txn.rollback()
    conn.close()
|
|
|
|
|
|
@pytest.fixture(scope="function")
def patched_session(db_session):
    """Redirect get_session_context everywhere to the test session.

    Repository and service modules each import get_session_context
    directly, so a single patch on the defining module is not enough —
    every importing module must be patched individually.

    Yields the same session object as db_session.
    """

    @contextmanager
    def mock_session_context() -> Generator[Session, None, None]:
        yield db_session

    # Every module known to import get_session_context by name.
    patch_targets = [
        "backend.data.database.get_session_context",
        "backend.data.repositories.document_repository.get_session_context",
        "backend.data.repositories.annotation_repository.get_session_context",
        "backend.data.repositories.dataset_repository.get_session_context",
        "backend.data.repositories.training_task_repository.get_session_context",
        "backend.data.repositories.model_version_repository.get_session_context",
        "backend.data.repositories.batch_upload_repository.get_session_context",
        "backend.data.repositories.token_repository.get_session_context",
        "backend.web.services.dashboard_service.get_session_context",
    ]

    with ExitStack() as patches:
        for dotted_path in patch_targets:
            try:
                patches.enter_context(patch(dotted_path, mock_session_context))
            except (ModuleNotFoundError, AttributeError):
                # Target module absent (or renamed) in this checkout — skip it.
                pass
        yield db_session
|
|
|
|
|
|
# =============================================================================
|
|
# Test Data Fixtures
|
|
# =============================================================================
|
|
|
|
|
|
@pytest.fixture
def admin_token(db_session) -> AdminToken:
    """Persist and return an active admin token for authenticated tests."""
    record = AdminToken(
        token="test-admin-token-12345",
        name="Test Admin",
        is_active=True,
        created_at=datetime.now(timezone.utc),
    )
    db_session.add(record)
    db_session.commit()
    # Refresh so DB-generated fields are populated on the returned object.
    db_session.refresh(record)
    return record
|
|
|
|
|
|
@pytest.fixture
def sample_document(db_session, admin_token) -> AdminDocument:
    """Persist and return one pending invoice document owned by admin_token."""
    record = AdminDocument(
        document_id=uuid4(),
        admin_token=admin_token.token,
        filename="test_invoice.pdf",
        file_size=1024,
        content_type="application/pdf",
        file_path="/uploads/test_invoice.pdf",
        page_count=1,
        status="pending",
        upload_source="ui",
        category="invoice",
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(record)
    db_session.commit()
    db_session.refresh(record)
    return record
|
|
|
|
|
|
@pytest.fixture
def sample_annotation(db_session, sample_document) -> AdminAnnotation:
    """Persist one auto-generated invoice_number annotation on sample_document.

    Carries both normalized YOLO-style coordinates (x_center/y_center/
    width/height) and absolute pixel bbox values.
    """
    record = AdminAnnotation(
        annotation_id=uuid4(),
        document_id=sample_document.document_id,
        page_number=1,
        class_id=0,
        class_name="invoice_number",
        # Normalized (0..1) coordinates
        x_center=0.5,
        y_center=0.3,
        width=0.2,
        height=0.05,
        # Absolute pixel bounding box
        bbox_x=400,
        bbox_y=240,
        bbox_width=160,
        bbox_height=40,
        text_value="INV-2024-001",
        confidence=0.95,
        source="auto",
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(record)
    db_session.commit()
    db_session.refresh(record)
    return record
|
|
|
|
|
|
@pytest.fixture
def sample_dataset(db_session) -> TrainingDataset:
    """Persist and return an empty training dataset in 'building' state."""
    record = TrainingDataset(
        dataset_id=uuid4(),
        name="Test Dataset",
        description="Dataset for integration testing",
        status="building",
        # 80/10 train/val split (remainder implicitly test), fixed seed.
        train_ratio=0.8,
        val_ratio=0.1,
        seed=42,
        total_documents=0,
        total_images=0,
        total_annotations=0,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(record)
    db_session.commit()
    db_session.refresh(record)
    return record
|
|
|
|
|
|
@pytest.fixture
def sample_training_task(db_session, admin_token, sample_dataset) -> TrainingTask:
    """Persist and return a pending 'train' task bound to sample_dataset."""
    record = TrainingTask(
        task_id=uuid4(),
        admin_token=admin_token.token,
        name="Test Training Task",
        description="Training task for integration testing",
        status="pending",
        task_type="train",
        dataset_id=sample_dataset.dataset_id,
        config={"epochs": 10, "batch_size": 16},
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(record)
    db_session.commit()
    db_session.refresh(record)
    return record
|
|
|
|
|
|
@pytest.fixture
def sample_model_version(db_session, sample_training_task, sample_dataset) -> ModelVersion:
    """Persist and return an inactive model version linked to task and dataset."""
    record = ModelVersion(
        version_id=uuid4(),
        version="1.0.0",
        name="Test Model v1",
        description="Model version for integration testing",
        model_path="/models/test_model.pt",
        status="inactive",
        is_active=False,
        task_id=sample_training_task.task_id,
        dataset_id=sample_dataset.dataset_id,
        # Canned evaluation metrics for assertion targets in tests.
        metrics_mAP=0.85,
        metrics_precision=0.88,
        metrics_recall=0.82,
        document_count=100,
        file_size=50000000,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(record)
    db_session.commit()
    db_session.refresh(record)
    return record
|
|
|
|
|
|
@pytest.fixture
def sample_batch_upload(db_session, admin_token) -> BatchUpload:
    """Persist and return an in-progress batch upload with 5 queued files."""
    record = BatchUpload(
        batch_id=uuid4(),
        admin_token=admin_token.token,
        filename="test_batch.zip",
        file_size=10240,
        upload_source="api",
        status="processing",
        # Counters start with all 5 files unprocessed.
        total_files=5,
        processed_files=0,
        successful_files=0,
        failed_files=0,
        created_at=datetime.now(timezone.utc),
    )
    db_session.add(record)
    db_session.commit()
    db_session.refresh(record)
    return record
|
|
|
|
|
|
# =============================================================================
|
|
# Multiple Documents Fixture
|
|
# =============================================================================
|
|
|
|
|
|
@pytest.fixture
def multiple_documents(db_session, admin_token) -> list[AdminDocument]:
    """Persist five documents with varied status/category for list-view tests.

    Mix: 2x pending, 2x labeled, 1x exported; 4x invoice, 1x letter —
    enough variety for pagination and filter assertions.
    """
    status_category_pairs = [
        ("pending", "invoice"),
        ("pending", "invoice"),
        ("labeled", "invoice"),
        ("labeled", "letter"),
        ("exported", "invoice"),
    ]

    docs: list[AdminDocument] = []
    for idx, (status, category) in enumerate(status_category_pairs):
        docs.append(
            AdminDocument(
                document_id=uuid4(),
                admin_token=admin_token.token,
                filename=f"test_doc_{idx}.pdf",
                file_size=1024 + idx * 100,
                content_type="application/pdf",
                file_path=f"/uploads/test_doc_{idx}.pdf",
                page_count=1,
                status=status,
                upload_source="ui",
                category=category,
                created_at=datetime.now(timezone.utc),
                updated_at=datetime.now(timezone.utc),
            )
        )
        db_session.add(docs[-1])

    # Single commit for the batch, then refresh each for DB-side defaults.
    db_session.commit()
    for doc in docs:
        db_session.refresh(doc)

    return docs
|
|
|
|
|
|
# =============================================================================
|
|
# Temporary File Fixtures
|
|
# =============================================================================
|
|
|
|
|
|
@pytest.fixture
def temp_upload_dir() -> Generator[Path, None, None]:
    """Yield an 'uploads' directory inside a self-cleaning temp directory."""
    with tempfile.TemporaryDirectory() as root:
        target = Path(root) / "uploads"
        target.mkdir(parents=True, exist_ok=True)
        yield target
|
|
|
|
|
|
@pytest.fixture
def temp_model_dir() -> Generator[Path, None, None]:
    """Yield a 'models' directory inside a self-cleaning temp directory."""
    with tempfile.TemporaryDirectory() as root:
        target = Path(root) / "models"
        target.mkdir(parents=True, exist_ok=True)
        yield target
|
|
|
|
|
|
@pytest.fixture
def temp_dataset_dir() -> Generator[Path, None, None]:
    """Yield a 'datasets' directory inside a self-cleaning temp directory."""
    with tempfile.TemporaryDirectory() as root:
        target = Path(root) / "datasets"
        target.mkdir(parents=True, exist_ok=True)
        yield target
|
|
|
|
|
|
# =============================================================================
|
|
# Sample PDF Fixture
|
|
# =============================================================================
|
|
|
|
|
|
@pytest.fixture
def sample_pdf_bytes() -> bytes:
    """Return minimal valid PDF bytes for testing.

    A hand-written one-page PDF: catalog -> page tree -> single page
    with a 612x792 pt (US Letter) MediaBox, plus xref table and trailer.
    No content stream, so the page renders blank.
    """
    # Minimal valid PDF structure.
    # NOTE(review): the xref byte offsets and the startxref value look
    # approximate rather than exact; lenient parsers rebuild the xref,
    # but a strict parser may reject this — confirm if a test depends
    # on full conformance.
    return b"""%PDF-1.4
1 0 obj
<< /Type /Catalog /Pages 2 0 R >>
endobj
2 0 obj
<< /Type /Pages /Kids [3 0 R] /Count 1 >>
endobj
3 0 obj
<< /Type /Page /Parent 2 0 R /MediaBox [0 0 612 792] >>
endobj
xref
0 4
0000000000 65535 f
0000000009 00000 n
0000000058 00000 n
0000000115 00000 n
trailer
<< /Size 4 /Root 1 0 R >>
startxref
196
%%EOF"""
|
|
|
|
|
|
@pytest.fixture
def sample_pdf_file(temp_upload_dir, sample_pdf_bytes) -> Path:
    """Write the sample PDF bytes into the upload dir and return its path."""
    target = temp_upload_dir / "test_invoice.pdf"
    target.write_bytes(sample_pdf_bytes)
    return target
|