Files
invoice-master-poc-v2/tests/integration/conftest.py
2026-02-01 22:40:41 +01:00

466 lines
14 KiB
Python

"""
Integration Test Fixtures
Provides shared fixtures for integration tests using PostgreSQL.
IMPORTANT: Integration tests MUST use Docker testcontainers for database isolation.
This ensures tests never touch the real production/development database.
Supported modes:
1. Docker testcontainers (default): Automatically starts a PostgreSQL container
2. TEST_DB_URL environment variable: Use a dedicated test database (NOT production!)
To use an external test database, set:
TEST_DB_URL=postgresql://user:password@host:port/test_dbname
"""
import os
import tempfile
from contextlib import contextmanager, ExitStack
from datetime import datetime, timezone
from pathlib import Path
from typing import Generator
from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlmodel import Session, SQLModel, create_engine
from inference.data.admin_models import (
AdminAnnotation,
AdminDocument,
AdminToken,
AnnotationHistory,
BatchUpload,
BatchUploadFile,
DatasetDocument,
ModelVersion,
TrainingDataset,
TrainingDocumentLink,
TrainingLog,
TrainingTask,
)
# =============================================================================
# Database Fixtures
# =============================================================================
def _is_docker_available() -> bool:
"""Check if Docker is available."""
try:
import docker
client = docker.from_env()
client.ping()
return True
except Exception:
return False
def _get_test_db_url() -> str | None:
"""Get test database URL from environment."""
return os.environ.get("TEST_DB_URL")
@pytest.fixture(scope="session")
def test_engine():
    """Create the SQLAlchemy engine shared by all tests in the session.

    The database is chosen in this order:
    1. ``TEST_DB_URL`` environment variable (dedicated test database). A
       warning is emitted when the URL contains ``docmaster`` without ``_test``.
    2. A ``postgres:15-alpine`` testcontainer, if Docker is available.

    IMPORTANT: There is no fallback to a production database. If neither
    option is available the fixture fails with setup instructions.

    All tables in ``SQLModel.metadata`` are created before the engine is
    yielded. On teardown the tables are dropped, the engine is disposed and
    a container started here is stopped. Each cleanup step is registered as
    soon as its resource exists, so a failure in a later setup or teardown
    step does not leave earlier resources behind.

    Yields:
        The engine bound to the test database.
    """
    with ExitStack() as cleanup:
        connection_url = _get_test_db_url()
        if connection_url:
            # External test database supplied through the environment.
            # Warn if it looks like a production database.
            if "docmaster" in connection_url and "_test" not in connection_url:
                import warnings

                warnings.warn(
                    "TEST_DB_URL appears to point to a production database. "
                    "Please use a dedicated test database (e.g., docmaster_test).",
                    UserWarning,
                )
        elif _is_docker_available():
            # Use testcontainers - this is the recommended approach.
            from testcontainers.postgres import PostgresContainer

            postgres = PostgresContainer("postgres:15-alpine")
            postgres.start()
            # Registered first, so the container is stopped last (LIFO).
            cleanup.callback(postgres.stop)

            connection_url = postgres.get_connection_url()
            if "psycopg2" in connection_url:
                # Drop the explicit driver suffix from the URL scheme.
                connection_url = connection_url.replace("postgresql+psycopg2://", "postgresql://")
        else:
            # No Docker and no TEST_DB_URL - fail with clear instructions.
            pytest.fail(
                "Integration tests require Docker or a TEST_DB_URL environment variable.\n\n"
                "Option 1 (Recommended): Install Docker Desktop and ensure it's running.\n"
                " - Windows: https://docs.docker.com/desktop/install/windows-install/\n"
                " - The testcontainers library will automatically create a PostgreSQL container.\n\n"
                "Option 2: Set TEST_DB_URL to a dedicated test database:\n"
                " - export TEST_DB_URL=postgresql://user:password@host:port/test_dbname\n"
                " - NEVER use your production database for tests!\n\n"
                "Integration tests will NOT fall back to the production database."
            )

        engine = create_engine(
            connection_url,
            echo=False,
            pool_pre_ping=True,
        )
        cleanup.callback(engine.dispose)

        # Create all tables; they are dropped again before the engine is disposed.
        SQLModel.metadata.create_all(engine)
        cleanup.callback(SQLModel.metadata.drop_all, engine)

        yield engine
@pytest.fixture(scope="function")
def db_session(test_engine) -> Generator[Session, None, None]:
    """Provide a database session for each test function.

    The session is bound to a connection on which an outer transaction has
    already been started. It joins that transaction through SAVEPOINTs
    (``join_transaction_mode="create_savepoint"``), so ``commit()`` and
    ``rollback()`` calls made through the session never end the outer
    transaction. After the test the outer transaction is rolled back, which
    discards everything the test wrote.

    NOTE(review): ``join_transaction_mode`` requires SQLAlchemy 2.0 - confirm
    the pinned version.

    Yields:
        A session whose changes are discarded when the test finishes.
    """
    connection = test_engine.connect()
    try:
        transaction = connection.begin()
        session = Session(bind=connection, join_transaction_mode="create_savepoint")
        try:
            yield session
        finally:
            # Rollback and cleanup.
            session.close()
            transaction.rollback()
    finally:
        # Always return the connection, even if setup above failed.
        connection.close()
@pytest.fixture(scope="function")
def patched_session(db_session):
    """Redirect ``get_session_context`` to the per-test session.

    Every listed dotted name is replaced with a context manager that yields
    ``db_session``. Each importing module has to be patched separately
    because it holds its own reference to ``get_session_context``. Targets
    whose module or attribute is missing are skipped silently.

    Yields:
        The same ``db_session`` that the patched context manager hands out.
    """

    @contextmanager
    def mock_session_context() -> Generator[Session, None, None]:
        yield db_session

    # Every location that holds a reference to get_session_context.
    dotted_targets = (
        "inference.data.database.get_session_context",
        "inference.data.repositories.document_repository.get_session_context",
        "inference.data.repositories.annotation_repository.get_session_context",
        "inference.data.repositories.dataset_repository.get_session_context",
        "inference.data.repositories.training_task_repository.get_session_context",
        "inference.data.repositories.model_version_repository.get_session_context",
        "inference.data.repositories.batch_upload_repository.get_session_context",
        "inference.web.services.dashboard_service.get_session_context",
    )

    with ExitStack() as active_patches:
        for dotted_name in dotted_targets:
            patcher = patch(dotted_name, mock_session_context)
            try:
                active_patches.enter_context(patcher)
            except (ModuleNotFoundError, AttributeError):
                # Module is absent or does not expose the attribute: skip it.
                continue
        yield db_session
# =============================================================================
# Test Data Fixtures
# =============================================================================
@pytest.fixture
def admin_token(db_session) -> AdminToken:
    """Insert an active admin token through ``db_session`` and return it.

    The row is added, committed and refreshed before being returned.
    """
    attributes = {
        "token": "test-admin-token-12345",
        "name": "Test Admin",
        "is_active": True,
        "created_at": datetime.now(timezone.utc),
    }
    record = AdminToken(**attributes)

    db_session.add(record)
    db_session.commit()
    db_session.refresh(record)
    return record
@pytest.fixture
def sample_document(db_session, admin_token) -> AdminDocument:
    """Insert one pending single-page PDF document owned by ``admin_token``.

    The row is added, committed and refreshed before being returned.
    """
    attributes = {
        "document_id": uuid4(),
        "admin_token": admin_token.token,
        "filename": "test_invoice.pdf",
        "file_size": 1024,
        "content_type": "application/pdf",
        "file_path": "/uploads/test_invoice.pdf",
        "page_count": 1,
        "status": "pending",
        "upload_source": "ui",
        "category": "invoice",
        "created_at": datetime.now(timezone.utc),
        "updated_at": datetime.now(timezone.utc),
    }
    document = AdminDocument(**attributes)

    db_session.add(document)
    db_session.commit()
    db_session.refresh(document)
    return document
@pytest.fixture
def sample_annotation(db_session, sample_document) -> AdminAnnotation:
    """Insert one ``invoice_number`` annotation on page 1 of ``sample_document``.

    The row is added, committed and refreshed before being returned.
    """
    # Box expressed as centre/size values.
    centre_box = {"x_center": 0.5, "y_center": 0.3, "width": 0.2, "height": 0.05}
    # The same annotation's bbox_* fields.
    bbox_fields = {"bbox_x": 400, "bbox_y": 240, "bbox_width": 160, "bbox_height": 40}

    record = AdminAnnotation(
        annotation_id=uuid4(),
        document_id=sample_document.document_id,
        page_number=1,
        class_id=0,
        class_name="invoice_number",
        **centre_box,
        **bbox_fields,
        text_value="INV-2024-001",
        confidence=0.95,
        source="auto",
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )

    db_session.add(record)
    db_session.commit()
    db_session.refresh(record)
    return record
@pytest.fixture
def sample_dataset(db_session) -> TrainingDataset:
    """Insert an empty training dataset in ``building`` status.

    The row is added, committed and refreshed before being returned.
    """
    # All counters start at zero.
    zero_counters = {
        "total_documents": 0,
        "total_images": 0,
        "total_annotations": 0,
    }
    record = TrainingDataset(
        dataset_id=uuid4(),
        name="Test Dataset",
        description="Dataset for integration testing",
        status="building",
        train_ratio=0.8,
        val_ratio=0.1,
        seed=42,
        **zero_counters,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )

    db_session.add(record)
    db_session.commit()
    db_session.refresh(record)
    return record
@pytest.fixture
def sample_training_task(db_session, admin_token, sample_dataset) -> TrainingTask:
    """Insert a pending ``train`` task linked to the token and dataset fixtures.

    The row is added, committed and refreshed before being returned.
    """
    attributes = {
        "task_id": uuid4(),
        "admin_token": admin_token.token,
        "name": "Test Training Task",
        "description": "Training task for integration testing",
        "status": "pending",
        "task_type": "train",
        "dataset_id": sample_dataset.dataset_id,
        "config": {"epochs": 10, "batch_size": 16},
        "created_at": datetime.now(timezone.utc),
        "updated_at": datetime.now(timezone.utc),
    }
    record = TrainingTask(**attributes)

    db_session.add(record)
    db_session.commit()
    db_session.refresh(record)
    return record
@pytest.fixture
def sample_model_version(db_session, sample_training_task, sample_dataset) -> ModelVersion:
    """Insert an inactive model version ``1.0.0`` tied to the task and dataset fixtures.

    The row is added, committed and refreshed before being returned.
    """
    metric_fields = {
        "metrics_mAP": 0.85,
        "metrics_precision": 0.88,
        "metrics_recall": 0.82,
    }
    record = ModelVersion(
        version_id=uuid4(),
        version="1.0.0",
        name="Test Model v1",
        description="Model version for integration testing",
        model_path="/models/test_model.pt",
        status="inactive",
        is_active=False,
        task_id=sample_training_task.task_id,
        dataset_id=sample_dataset.dataset_id,
        **metric_fields,
        document_count=100,
        file_size=50000000,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )

    db_session.add(record)
    db_session.commit()
    db_session.refresh(record)
    return record
@pytest.fixture
def sample_batch_upload(db_session, admin_token) -> BatchUpload:
    """Insert a ``processing`` batch upload of five files, none processed yet.

    The row is added, committed and refreshed before being returned.
    """
    progress_counters = {
        "total_files": 5,
        "processed_files": 0,
        "successful_files": 0,
        "failed_files": 0,
    }
    record = BatchUpload(
        batch_id=uuid4(),
        admin_token=admin_token.token,
        filename="test_batch.zip",
        file_size=10240,
        upload_source="api",
        status="processing",
        **progress_counters,
        created_at=datetime.now(timezone.utc),
    )

    db_session.add(record)
    db_session.commit()
    db_session.refresh(record)
    return record
# =============================================================================
# Multiple Documents Fixture
# =============================================================================
@pytest.fixture
def multiple_documents(db_session, admin_token) -> list[AdminDocument]:
    """Insert five documents with mixed statuses and categories.

    Intended for pagination/filtering tests. Document ``i`` is named
    ``test_doc_{i}.pdf`` and has ``file_size`` ``1024 + i * 100``. All rows
    are committed together and refreshed before the list is returned.
    """
    # (status, category) for documents 0..4, in creation order.
    status_category_pairs = [
        ("pending", "invoice"),
        ("pending", "invoice"),
        ("labeled", "invoice"),
        ("labeled", "letter"),
        ("exported", "invoice"),
    ]

    created = [
        AdminDocument(
            document_id=uuid4(),
            admin_token=admin_token.token,
            filename=f"test_doc_{index}.pdf",
            file_size=1024 + index * 100,
            content_type="application/pdf",
            file_path=f"/uploads/test_doc_{index}.pdf",
            page_count=1,
            status=doc_status,
            upload_source="ui",
            category=doc_category,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        for index, (doc_status, doc_category) in enumerate(status_category_pairs)
    ]

    db_session.add_all(created)
    db_session.commit()
    for document in created:
        db_session.refresh(document)
    return created
# =============================================================================
# Temporary File Fixtures
# =============================================================================
@pytest.fixture
def temp_upload_dir() -> Generator[Path, None, None]:
    """Yield an ``uploads`` directory inside a fresh temporary directory.

    The enclosing temporary directory is removed when the fixture is torn down.
    """
    with tempfile.TemporaryDirectory() as scratch_root:
        uploads = Path(scratch_root, "uploads")
        uploads.mkdir(parents=True, exist_ok=True)
        yield uploads
@pytest.fixture
def temp_model_dir() -> Generator[Path, None, None]:
    """Yield a ``models`` directory inside a fresh temporary directory.

    The enclosing temporary directory is removed when the fixture is torn down.
    """
    with tempfile.TemporaryDirectory() as scratch_root:
        models = Path(scratch_root, "models")
        models.mkdir(parents=True, exist_ok=True)
        yield models
@pytest.fixture
def temp_dataset_dir() -> Generator[Path, None, None]:
    """Yield a ``datasets`` directory inside a fresh temporary directory.

    The enclosing temporary directory is removed when the fixture is torn down.
    """
    with tempfile.TemporaryDirectory() as scratch_root:
        datasets = Path(scratch_root, "datasets")
        datasets.mkdir(parents=True, exist_ok=True)
        yield datasets
# =============================================================================
# Sample PDF Fixture
# =============================================================================
@pytest.fixture
def sample_pdf_bytes() -> bytes:
    """Return the bytes of a minimal single-page PDF for testing.

    The document consists of a catalog, a page tree and one empty
    612x792 page, followed by a cross-reference table and trailer.

    The literal is layout-sensitive: the xref entries (9, 58, 115) and the
    ``startxref`` value (186) are byte offsets into this exact text with
    LF line endings, so any edit to the body must update them.
    """
    # startxref must equal the byte offset of the "xref" keyword, which is 186.
    return b"""%PDF-1.4
1 0 obj
<< /Type /Catalog /Pages 2 0 R >>
endobj
2 0 obj
<< /Type /Pages /Kids [3 0 R] /Count 1 >>
endobj
3 0 obj
<< /Type /Page /Parent 2 0 R /MediaBox [0 0 612 792] >>
endobj
xref
0 4
0000000000 65535 f
0000000009 00000 n
0000000058 00000 n
0000000115 00000 n
trailer
<< /Size 4 /Root 1 0 R >>
startxref
186
%%EOF"""
@pytest.fixture
def sample_pdf_file(temp_upload_dir, sample_pdf_bytes) -> Path:
    """Write ``sample_pdf_bytes`` to ``test_invoice.pdf`` in the upload directory.

    Returns:
        Path of the file that was written.
    """
    destination = temp_upload_dir.joinpath("test_invoice.pdf")
    destination.write_bytes(sample_pdf_bytes)
    return destination