"""Tests for the augmentation admin API routes.

Each test builds a fresh FastAPI app that mounts the router returned by
``create_augmentation_router()`` under ``/api/v1/admin`` and talks to it
through ``TestClient``. Admin-token validation and database access are
replaced via ``app.dependency_overrides`` so no real token or database is
needed.
"""

from unittest.mock import MagicMock, patch

import numpy as np
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

from inference.web.api.v1.admin.augmentation import create_augmentation_router
from inference.web.core.auth import get_admin_db, validate_admin_token


TEST_ADMIN_TOKEN = "test-admin-token-12345"
TEST_DOCUMENT_UUID = "550e8400-e29b-41d4-a716-446655440001"
TEST_DATASET_UUID = "660e8400-e29b-41d4-a716-446655440001"

# The router already carries the "/augmentation" prefix; mounting it under
# API_PREFIX yields URLs such as "/api/v1/admin/augmentation/types".
API_PREFIX = "/api/v1/admin"
AUGMENTATION_URL = f"{API_PREFIX}/augmentation"

# Patched in preview tests so the service never reads a real page from disk.
LOAD_PAGE_TARGET = (
    "inference.web.services.augmentation_service."
    "AugmentationService._load_document_page"
)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _fake_page_image(size: int = 100) -> np.ndarray:
    """Return a deterministic ``size x size`` RGB uint8 image.

    A seeded generator keeps the input identical across runs; the upper bound
    is exclusive, so 256 allows the full 0-255 byte range.
    """
    rng = np.random.default_rng(0)
    return rng.integers(0, 256, size=(size, size, 3), dtype=np.uint8)


def _register_fake_document(mock_admin_db: MagicMock) -> MagicMock:
    """Make ``mock_admin_db.get_document`` return a stub document.

    Returns:
        The stub document, whose ``images_dir`` points at a fake path.
    """
    document = MagicMock()
    document.images_dir = "/fake/path"
    mock_admin_db.get_document.return_value = document
    return document


def _build_app(mock_admin_db: MagicMock, *, authenticated: bool) -> FastAPI:
    """Create an app with the augmentation router and dependency overrides.

    Args:
        mock_admin_db: Object returned in place of the real admin DB.
        authenticated: When True, admin-token validation is overridden to
            succeed; when False, the real ``validate_admin_token`` runs.
    """
    app = FastAPI()

    def get_db_override() -> MagicMock:
        return mock_admin_db

    app.dependency_overrides[get_admin_db] = get_db_override

    if authenticated:

        def get_token_override() -> str:
            return TEST_ADMIN_TOKEN

        app.dependency_overrides[validate_admin_token] = get_token_override

    app.include_router(create_augmentation_router(), prefix=API_PREFIX)
    return app


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture
def admin_token() -> str:
    """Provide admin token for testing."""
    return TEST_ADMIN_TOKEN


@pytest.fixture
def sample_document_id() -> str:
    """Provide a sample document ID for testing."""
    return TEST_DOCUMENT_UUID


@pytest.fixture
def sample_dataset_id() -> str:
    """Provide a sample dataset ID for testing."""
    return TEST_DATASET_UUID


@pytest.fixture
def mock_admin_db() -> MagicMock:
    """Create a mock AdminDB whose lookups default to "not found".

    An unconfigured MagicMock method returns another (truthy) MagicMock, which
    would make every document lookup look successful. Tests that need a
    record present configure it explicitly.
    """
    mock = MagicMock()
    mock.get_document.return_value = None
    mock.get_document_by_token.return_value = None
    mock.get_dataset.return_value = None
    mock.get_augmented_datasets.return_value = ([], 0)
    return mock


@pytest.fixture
def admin_client(mock_admin_db: MagicMock) -> TestClient:
    """Create test client with admin authentication overridden to succeed."""
    return TestClient(_build_app(mock_admin_db, authenticated=True))


@pytest.fixture
def unauthenticated_client(mock_admin_db: MagicMock) -> TestClient:
    """Create test client WITHOUT the admin authentication override."""
    return TestClient(_build_app(mock_admin_db, authenticated=False))


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


class TestAugmentationTypesEndpoint:
    """Tests for GET /admin/augmentation/types endpoint."""

    def test_list_augmentation_types(
        self, admin_client: TestClient, admin_token: str
    ) -> None:
        """Test listing available augmentation types."""
        response = admin_client.get(
            f"{AUGMENTATION_URL}/types",
            headers={"X-Admin-Token": admin_token},
        )

        assert response.status_code == 200
        data = response.json()
        assert "augmentation_types" in data
        assert len(data["augmentation_types"]) == 12

        # Check structure
        aug_type = data["augmentation_types"][0]
        assert "name" in aug_type
        assert "description" in aug_type
        assert "affects_geometry" in aug_type
        assert "stage" in aug_type

    def test_list_augmentation_types_unauthorized(
        self, unauthenticated_client: TestClient
    ) -> None:
        """Test that unauthorized request is rejected."""
        response = unauthenticated_client.get(f"{AUGMENTATION_URL}/types")

        assert response.status_code == 401


class TestAugmentationPresetsEndpoint:
    """Tests for GET /admin/augmentation/presets endpoint."""

    def test_list_presets(self, admin_client: TestClient, admin_token: str) -> None:
        """Test listing available presets."""
        response = admin_client.get(
            f"{AUGMENTATION_URL}/presets",
            headers={"X-Admin-Token": admin_token},
        )

        assert response.status_code == 200
        data = response.json()
        assert "presets" in data
        assert len(data["presets"]) >= 4

        # Check expected presets exist
        preset_names = {p["name"] for p in data["presets"]}
        assert "conservative" in preset_names
        assert "moderate" in preset_names
        assert "aggressive" in preset_names
        assert "scanned_document" in preset_names


class TestAugmentationPreviewEndpoint:
    """Tests for POST /admin/augmentation/preview/{document_id} endpoint."""

    def test_preview_augmentation(
        self,
        admin_client: TestClient,
        admin_token: str,
        sample_document_id: str,
        mock_admin_db: MagicMock,
    ) -> None:
        """Test previewing augmentation on a document."""
        _register_fake_document(mock_admin_db)

        with patch(LOAD_PAGE_TARGET, return_value=_fake_page_image()):
            response = admin_client.post(
                f"{AUGMENTATION_URL}/preview/{sample_document_id}",
                headers={"X-Admin-Token": admin_token},
                json={
                    "augmentation_type": "gaussian_noise",
                    "params": {"std": 15},
                },
            )

        assert response.status_code == 200
        data = response.json()
        assert "preview_url" in data
        assert "original_url" in data
        assert "applied_params" in data

    def test_preview_invalid_augmentation_type(
        self,
        admin_client: TestClient,
        admin_token: str,
        sample_document_id: str,
        mock_admin_db: MagicMock,
    ) -> None:
        """Test that invalid augmentation type returns error."""
        # Provide a valid document and page so the unknown type is the only
        # possible reason for the request to fail.
        _register_fake_document(mock_admin_db)

        with patch(LOAD_PAGE_TARGET, return_value=_fake_page_image()):
            response = admin_client.post(
                f"{AUGMENTATION_URL}/preview/{sample_document_id}",
                headers={"X-Admin-Token": admin_token},
                json={
                    "augmentation_type": "nonexistent",
                    "params": {},
                },
            )

        assert response.status_code == 400

    def test_preview_nonexistent_document(
        self,
        admin_client: TestClient,
        admin_token: str,
    ) -> None:
        """Test that nonexistent document returns 404."""
        # Relies on the mock_admin_db fixture defaulting get_document to None.
        response = admin_client.post(
            f"{AUGMENTATION_URL}/preview/00000000-0000-0000-0000-000000000000",
            headers={"X-Admin-Token": admin_token},
            json={
                "augmentation_type": "gaussian_noise",
                "params": {},
            },
        )

        assert response.status_code == 404


class TestAugmentationPreviewConfigEndpoint:
    """Tests for POST /admin/augmentation/preview-config/{document_id} endpoint."""

    def test_preview_config(
        self,
        admin_client: TestClient,
        admin_token: str,
        sample_document_id: str,
        mock_admin_db: MagicMock,
    ) -> None:
        """Test previewing full config on a document."""
        _register_fake_document(mock_admin_db)

        with patch(LOAD_PAGE_TARGET, return_value=_fake_page_image()):
            response = admin_client.post(
                f"{AUGMENTATION_URL}/preview-config/{sample_document_id}",
                headers={"X-Admin-Token": admin_token},
                json={
                    "gaussian_noise": {"enabled": True, "probability": 1.0},
                    "lighting_variation": {"enabled": True, "probability": 1.0},
                    "preserve_bboxes": True,
                    "seed": 42,
                },
            )

        assert response.status_code == 200
        data = response.json()
        assert "preview_url" in data
        assert "original_url" in data


class TestAugmentationBatchEndpoint:
    """Tests for POST /admin/augmentation/batch endpoint."""

    def test_create_augmented_dataset(
        self,
        admin_client: TestClient,
        admin_token: str,
        sample_dataset_id: str,
        mock_admin_db: MagicMock,
    ) -> None:
        """Test creating augmented dataset."""
        # Mock dataset exists
        mock_dataset = MagicMock()
        mock_dataset.total_images = 100
        mock_admin_db.get_dataset.return_value = mock_dataset

        response = admin_client.post(
            f"{AUGMENTATION_URL}/batch",
            headers={"X-Admin-Token": admin_token},
            json={
                "dataset_id": sample_dataset_id,
                "config": {
                    "gaussian_noise": {"enabled": True, "probability": 0.5},
                    "preserve_bboxes": True,
                },
                "output_name": "test_augmented_dataset",
                "multiplier": 2,
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert "task_id" in data
        assert "status" in data
        assert "estimated_images" in data

    def test_create_augmented_dataset_invalid_multiplier(
        self,
        admin_client: TestClient,
        admin_token: str,
        sample_dataset_id: str,
    ) -> None:
        """Test that invalid multiplier is rejected."""
        response = admin_client.post(
            f"{AUGMENTATION_URL}/batch",
            headers={"X-Admin-Token": admin_token},
            json={
                "dataset_id": sample_dataset_id,
                "config": {},
                "output_name": "test",
                "multiplier": 100,  # Too high
            },
        )

        assert response.status_code == 422  # Validation error


class TestAugmentedDatasetsListEndpoint:
    """Tests for GET /admin/augmentation/datasets endpoint."""

    def test_list_augmented_datasets(
        self, admin_client: TestClient, admin_token: str
    ) -> None:
        """Test listing augmented datasets."""
        response = admin_client.get(
            f"{AUGMENTATION_URL}/datasets",
            headers={"X-Admin-Token": admin_token},
        )

        assert response.status_code == 200
        data = response.json()
        assert "total" in data
        assert "limit" in data
        assert "offset" in data
        assert "datasets" in data
        assert isinstance(data["datasets"], list)

    def test_list_augmented_datasets_pagination(
        self, admin_client: TestClient, admin_token: str
    ) -> None:
        """Test pagination parameters."""
        response = admin_client.get(
            f"{AUGMENTATION_URL}/datasets",
            headers={"X-Admin-Token": admin_token},
            params={"limit": 5, "offset": 0},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["limit"] == 5
        assert data["offset"] == 0