"""
Integration Test Fixtures

Provides shared fixtures for integration tests using PostgreSQL.

IMPORTANT: Integration tests MUST run against an isolated database - by default
a Docker testcontainer. This ensures tests never touch the real
production/development database.

Supported modes:
1. TEST_DB_URL environment variable: Use a dedicated test database (NOT production!).
   When set, it takes precedence over Docker.
2. Docker testcontainers (default when TEST_DB_URL is unset): Automatically
   starts a PostgreSQL container.

To use an external test database, set:
    TEST_DB_URL=postgresql://user:password@host:port/test_dbname
"""

import os
import tempfile
import warnings
from contextlib import contextmanager, ExitStack
from datetime import datetime, timezone
from pathlib import Path
from typing import Generator
from unittest.mock import patch
from uuid import uuid4

import pytest
from sqlmodel import Session, SQLModel, create_engine

# NOTE(review): several of these names are unused in this file; presumably they
# are imported so every table model is registered on SQLModel.metadata before
# create_all() runs - confirm before pruning.
from inference.data.admin_models import (
    AdminAnnotation,
    AdminDocument,
    AdminToken,
    AnnotationHistory,
    BatchUpload,
    BatchUploadFile,
    DatasetDocument,
    ModelVersion,
    TrainingDataset,
    TrainingDocumentLink,
    TrainingLog,
    TrainingTask,
)


# =============================================================================
# Database Fixtures
# =============================================================================


def _is_docker_available() -> bool:
    """Return True if a Docker daemon can be reached through the docker SDK.

    Any failure (SDK not installed, daemon not running, no permission, ...)
    is reported as "not available" so the caller can show a clear message.
    The client created for the probe is always closed.
    """
    try:
        import docker

        client = docker.from_env()
    except Exception:
        return False

    try:
        client.ping()
        return True
    except Exception:
        return False
    finally:
        client.close()


def _get_test_db_url() -> str | None:
    """Get test database URL from the TEST_DB_URL environment variable."""
    return os.environ.get("TEST_DB_URL")


@pytest.fixture(scope="session")
def test_engine():
    """Create a SQLAlchemy engine for testing.

    Uses one of:
    1. TEST_DB_URL environment variable (dedicated test database)
    2. Docker testcontainers (if Docker is available)

    IMPORTANT: Will NOT fall back to production database.
    If Docker is not available and TEST_DB_URL is not set, tests will fail
    with a clear error.

    The engine is shared across all tests in a session for efficiency. All
    tables are created on setup and dropped on teardown. Cleanup runs in a
    ``finally`` block, so the engine is disposed and any container started
    here is stopped even if setup (e.g. ``create_all``) or ``drop_all`` fails.
    """
    connection_url = _get_test_db_url()
    postgres = None  # container started by this fixture, stopped in `finally`
    engine = None

    if connection_url:
        # Use external test database from environment.
        # Warn if it looks like a production database.
        if "docmaster" in connection_url and "_test" not in connection_url:
            warnings.warn(
                "TEST_DB_URL appears to point to a production database. "
                "Please use a dedicated test database (e.g., docmaster_test).",
                UserWarning,
            )
    elif not _is_docker_available():
        # No Docker and no TEST_DB_URL - fail with clear instructions.
        pytest.fail(
            "Integration tests require Docker or a TEST_DB_URL environment variable.\n\n"
            "Option 1 (Recommended): Install Docker Desktop and ensure it's running.\n"
            " - Windows: https://docs.docker.com/desktop/install/windows-install/\n"
            " - The testcontainers library will automatically create a PostgreSQL container.\n\n"
            "Option 2: Set TEST_DB_URL to a dedicated test database:\n"
            " - export TEST_DB_URL=postgresql://user:password@host:port/test_dbname\n"
            " - NEVER use your production database for tests!\n\n"
            "Integration tests will NOT fall back to the production database."
        )

    try:
        if not connection_url:
            # Use testcontainers - this is the recommended approach.
            from testcontainers.postgres import PostgresContainer

            # Assigned before start() so a container that was created but
            # failed its readiness wait is still removed in `finally`.
            postgres = PostgresContainer("postgres:15-alpine")
            postgres.start()
            # Drop the explicit driver suffix; a no-op if it is absent.
            connection_url = postgres.get_connection_url().replace(
                "postgresql+psycopg2://", "postgresql://"
            )

        engine = create_engine(
            connection_url,
            echo=False,
            pool_pre_ping=True,
        )

        # Create all tables
        SQLModel.metadata.create_all(engine)

        yield engine

        # Cleanup
        SQLModel.metadata.drop_all(engine)
    finally:
        if engine is not None:
            engine.dispose()
        if postgres is not None:
            postgres.stop()


@pytest.fixture(scope="function")
def db_session(test_engine) -> Generator[Session, None, None]:
    """Provide a database session for each test function.

    The session is bound to a connection whose outer transaction is always
    rolled back at teardown, so nothing a test writes is persisted.

    ``join_transaction_mode="create_savepoint"`` makes the session work inside
    a SAVEPOINT. ``session.commit()`` / ``session.rollback()`` calls made by
    fixtures, repositories or tests therefore only release / roll back that
    savepoint and can never commit or end the outer transaction.
    """
    connection = test_engine.connect()
    transaction = connection.begin()
    session = Session(bind=connection, join_transaction_mode="create_savepoint")

    try:
        yield session
    finally:
        # Rollback and cleanup
        session.close()
        transaction.rollback()
        connection.close()


@pytest.fixture(scope="function")
def patched_session(db_session):
    """Patch get_session_context to use the test session.

    This allows repository classes to use the test database session
    instead of creating their own connections.

    We need to patch in multiple locations because each repository module
    imports get_session_context directly. Targets whose module or attribute
    does not exist are skipped.
    """

    @contextmanager
    def mock_session_context() -> Generator[Session, None, None]:
        # Hand out the shared test session without committing or closing it.
        yield db_session

    # All modules that import get_session_context
    patch_targets = [
        "inference.data.database.get_session_context",
        "inference.data.repositories.document_repository.get_session_context",
        "inference.data.repositories.annotation_repository.get_session_context",
        "inference.data.repositories.dataset_repository.get_session_context",
        "inference.data.repositories.training_task_repository.get_session_context",
        "inference.data.repositories.model_version_repository.get_session_context",
        "inference.data.repositories.batch_upload_repository.get_session_context",
        "inference.data.repositories.token_repository.get_session_context",
        "inference.web.services.dashboard_service.get_session_context",
    ]

    with ExitStack() as stack:
        for target in patch_targets:
            try:
                stack.enter_context(patch(target, mock_session_context))
            except (ModuleNotFoundError, AttributeError):
                # Skip if module doesn't exist or doesn't have the attribute
                pass
        yield db_session


# =============================================================================
# Test Data Fixtures
# =============================================================================


@pytest.fixture
def admin_token(db_session) -> AdminToken:
    """Create a test admin token."""
    token = AdminToken(
        token="test-admin-token-12345",
        name="Test Admin",
        is_active=True,
        created_at=datetime.now(timezone.utc),
    )
    db_session.add(token)
    db_session.commit()
    db_session.refresh(token)
    return token


@pytest.fixture
def sample_document(db_session, admin_token) -> AdminDocument:
    """Create a sample document for testing (created_at == updated_at)."""
    now = datetime.now(timezone.utc)
    doc = AdminDocument(
        document_id=uuid4(),
        admin_token=admin_token.token,
        filename="test_invoice.pdf",
        file_size=1024,
        content_type="application/pdf",
        file_path="/uploads/test_invoice.pdf",
        page_count=1,
        status="pending",
        upload_source="ui",
        category="invoice",
        created_at=now,
        updated_at=now,
    )
    db_session.add(doc)
    db_session.commit()
    db_session.refresh(doc)
    return doc


@pytest.fixture
def sample_annotation(db_session, sample_document) -> AdminAnnotation:
    """Create a sample annotation on ``sample_document`` (created_at == updated_at)."""
    now = datetime.now(timezone.utc)
    annotation = AdminAnnotation(
        annotation_id=uuid4(),
        document_id=sample_document.document_id,
        page_number=1,
        class_id=0,
        class_name="invoice_number",
        x_center=0.5,
        y_center=0.3,
        width=0.2,
        height=0.05,
        bbox_x=400,
        bbox_y=240,
        bbox_width=160,
        bbox_height=40,
        text_value="INV-2024-001",
        confidence=0.95,
        source="auto",
        created_at=now,
        updated_at=now,
    )
    db_session.add(annotation)
    db_session.commit()
    db_session.refresh(annotation)
    return annotation


@pytest.fixture
def sample_dataset(db_session) -> TrainingDataset:
    """Create a sample training dataset for testing (created_at == updated_at)."""
    now = datetime.now(timezone.utc)
    dataset = TrainingDataset(
        dataset_id=uuid4(),
        name="Test Dataset",
        description="Dataset for integration testing",
        status="building",
        train_ratio=0.8,
        val_ratio=0.1,
        seed=42,
        total_documents=0,
        total_images=0,
        total_annotations=0,
        created_at=now,
        updated_at=now,
    )
    db_session.add(dataset)
    db_session.commit()
    db_session.refresh(dataset)
    return dataset


@pytest.fixture
def sample_training_task(db_session, admin_token, sample_dataset) -> TrainingTask:
    """Create a sample training task for testing (created_at == updated_at)."""
    now = datetime.now(timezone.utc)
    task = TrainingTask(
        task_id=uuid4(),
        admin_token=admin_token.token,
        name="Test Training Task",
        description="Training task for integration testing",
        status="pending",
        task_type="train",
        dataset_id=sample_dataset.dataset_id,
        config={"epochs": 10, "batch_size": 16},
        created_at=now,
        updated_at=now,
    )
    db_session.add(task)
    db_session.commit()
    db_session.refresh(task)
    return task


@pytest.fixture
def sample_model_version(db_session, sample_training_task, sample_dataset) -> ModelVersion:
    """Create an inactive sample model version for testing (created_at == updated_at)."""
    now = datetime.now(timezone.utc)
    version = ModelVersion(
        version_id=uuid4(),
        version="1.0.0",
        name="Test Model v1",
        description="Model version for integration testing",
        model_path="/models/test_model.pt",
        status="inactive",
        is_active=False,
        task_id=sample_training_task.task_id,
        dataset_id=sample_dataset.dataset_id,
        metrics_mAP=0.85,
        metrics_precision=0.88,
        metrics_recall=0.82,
        document_count=100,
        file_size=50000000,
        created_at=now,
        updated_at=now,
    )
    db_session.add(version)
    db_session.commit()
    db_session.refresh(version)
    return version


@pytest.fixture
def sample_batch_upload(db_session, admin_token) -> BatchUpload:
    """Create a sample batch upload for testing."""
    batch = BatchUpload(
        batch_id=uuid4(),
        admin_token=admin_token.token,
        filename="test_batch.zip",
        file_size=10240,
        upload_source="api",
        status="processing",
        total_files=5,
        processed_files=0,
        successful_files=0,
        failed_files=0,
        created_at=datetime.now(timezone.utc),
    )
    db_session.add(batch)
    db_session.commit()
    db_session.refresh(batch)
    return batch


# =============================================================================
# Multiple Documents Fixture
# =============================================================================


@pytest.fixture
def multiple_documents(db_session, admin_token) -> list[AdminDocument]:
    """Create multiple documents for pagination/filtering tests.

    Five documents are created in order: statuses
    pending, pending, labeled, labeled, exported and categories
    invoice, invoice, invoice, letter, invoice.
    """
    documents = []
    statuses = ["pending", "pending", "labeled", "labeled", "exported"]
    categories = ["invoice", "invoice", "invoice", "letter", "invoice"]

    for i, (status, category) in enumerate(zip(statuses, categories)):
        # Timestamp taken per document so creation order stays distinct for
        # ordering-sensitive tests, while each row has created_at == updated_at.
        now = datetime.now(timezone.utc)
        doc = AdminDocument(
            document_id=uuid4(),
            admin_token=admin_token.token,
            filename=f"test_doc_{i}.pdf",
            file_size=1024 + i * 100,
            content_type="application/pdf",
            file_path=f"/uploads/test_doc_{i}.pdf",
            page_count=1,
            status=status,
            upload_source="ui",
            category=category,
            created_at=now,
            updated_at=now,
        )
        db_session.add(doc)
        documents.append(doc)

    db_session.commit()
    for doc in documents:
        db_session.refresh(doc)

    return documents
# =============================================================================
# Temporary File Fixtures
# =============================================================================


def _temp_subdir(name: str) -> Generator[Path, None, None]:
    """Yield ``<fresh temporary directory>/<name>``; the whole tree is removed afterwards."""
    with tempfile.TemporaryDirectory() as tmpdir:
        subdir = Path(tmpdir) / name
        subdir.mkdir(parents=True, exist_ok=True)
        yield subdir


@pytest.fixture
def temp_upload_dir() -> Generator[Path, None, None]:
    """Create a temporary directory for file uploads."""
    yield from _temp_subdir("uploads")


@pytest.fixture
def temp_model_dir() -> Generator[Path, None, None]:
    """Create a temporary directory for model files."""
    yield from _temp_subdir("models")


@pytest.fixture
def temp_dataset_dir() -> Generator[Path, None, None]:
    """Create a temporary directory for dataset files."""
    yield from _temp_subdir("datasets")


# =============================================================================
# Sample PDF Fixture
# =============================================================================


@pytest.fixture
def sample_pdf_bytes() -> bytes:
    """Return minimal valid PDF bytes for testing.

    The document is a single blank 612x792 pt page. The cross-reference table
    is built from the real byte offset of each object instead of being
    hard-coded, so every ``xref`` entry and the ``startxref`` pointer land on
    the right bytes and strict PDF parsers accept the file.
    """
    # Indirect objects 1..3: catalog -> page tree -> single page.
    objects = (
        b"<< /Type /Catalog /Pages 2 0 R >>",
        b"<< /Type /Pages /Kids [3 0 R] /Count 1 >>",
        b"<< /Type /Page /Parent 2 0 R /MediaBox [0 0 612 792] >>",
    )

    pdf = bytearray(b"%PDF-1.4\n")
    offsets = []
    for number, body in enumerate(objects, start=1):
        offsets.append(len(pdf))
        pdf += b"%d 0 obj\n%s\nendobj\n" % (number, body)

    xref_offset = len(pdf)
    entry_count = len(objects) + 1  # +1 for the free-list head (object 0)
    pdf += b"xref\n0 %d\n" % entry_count
    # Each entry must be exactly 20 bytes: 10-digit offset, space, 5-digit
    # generation, space, keyword, then the 2-byte end-of-line " \n".
    pdf += b"0000000000 65535 f \n"
    for offset in offsets:
        pdf += b"%010d 00000 n \n" % offset
    pdf += b"trailer\n<< /Size %d /Root 1 0 R >>\n" % entry_count
    pdf += b"startxref\n%d\n%%%%EOF" % xref_offset  # "%%%%" renders as "%%"
    return bytes(pdf)


@pytest.fixture
def sample_pdf_file(temp_upload_dir, sample_pdf_bytes) -> Path:
    """Write ``sample_pdf_bytes`` to ``<temp_upload_dir>/test_invoice.pdf`` and return its path."""
    pdf_path = temp_upload_dir / "test_invoice.pdf"
    pdf_path.write_bytes(sample_pdf_bytes)
    return pdf_path