Files
invoice-master-poc-v2/tests/web/test_augmentation_routes.py
Yaojia Wang a564ac9d70 WIP
2026-02-01 18:51:54 +01:00

383 lines
12 KiB
Python

"""
Tests for augmentation API routes.
TDD Phase 5: RED - Write tests first, then implement to pass.
"""
import pytest
from unittest.mock import MagicMock, patch
from fastapi import FastAPI
from fastapi.testclient import TestClient
import numpy as np
from inference.web.api.v1.admin.augmentation import create_augmentation_router
from inference.web.core.auth import (
validate_admin_token,
get_document_repository,
get_dataset_repository,
)
# Fixed identifiers shared by the fixtures and tests in this module.
# Token returned by the admin-token dependency override and sent in X-Admin-Token.
TEST_ADMIN_TOKEN: str = "test-admin-token-12345"
# Document ID used in preview / preview-config request paths.
TEST_DOCUMENT_UUID: str = "550e8400-e29b-41d4-a716-446655440001"
# Dataset ID sent as ``dataset_id`` in batch requests.
TEST_DATASET_UUID: str = "660e8400-e29b-41d4-a716-446655440001"
@pytest.fixture
def admin_token() -> str:
    """Return the shared admin token that tests send in the ``X-Admin-Token`` header."""
    token = TEST_ADMIN_TOKEN
    return token
@pytest.fixture
def mock_document_repo() -> MagicMock:
    """Build a DocumentRepository stand-in whose lookups return ``None`` by default.

    Tests that need an existing document override ``get.return_value``.
    """
    # Dotted keys are applied via ``configure_mock`` on construction.
    return MagicMock(
        **{
            "get.return_value": None,
            "get_by_token.return_value": None,
        }
    )
@pytest.fixture
def mock_dataset_repo() -> MagicMock:
    """Build a DatasetRepository stand-in with empty default results.

    ``get`` returns ``None`` and ``get_paginated`` returns an empty page
    ``([], 0)``; tests override these when they need data.
    """
    # Dotted keys are applied via ``configure_mock`` on construction.
    return MagicMock(
        **{
            "get.return_value": None,
            "get_paginated.return_value": ([], 0),
        }
    )
@pytest.fixture
def admin_client(mock_document_repo: MagicMock, mock_dataset_repo: MagicMock) -> TestClient:
    """Return a TestClient for the augmentation router with admin auth bypassed.

    Token validation is overridden to yield ``TEST_ADMIN_TOKEN``, and both
    repository dependencies resolve to the mock fixtures.
    """
    app = FastAPI()
    app.dependency_overrides.update(
        {
            validate_admin_token: lambda: TEST_ADMIN_TOKEN,
            get_document_repository: lambda: mock_document_repo,
            get_dataset_repository: lambda: mock_dataset_repo,
        }
    )
    # The router carries its own /augmentation prefix; mounting it under
    # /api/v1/admin exposes routes at /api/v1/admin/augmentation/...
    app.include_router(create_augmentation_router(), prefix="/api/v1/admin")
    return TestClient(app)
@pytest.fixture
def unauthenticated_client(mock_document_repo: MagicMock, mock_dataset_repo: MagicMock) -> TestClient:
    """Return a TestClient whose admin-token validation is left un-overridden.

    Only the repository dependencies are replaced with the mock fixtures, so
    the real ``validate_admin_token`` runs on every request.
    """
    app = FastAPI()
    app.dependency_overrides.update(
        {
            get_document_repository: lambda: mock_document_repo,
            get_dataset_repository: lambda: mock_dataset_repo,
        }
    )
    app.include_router(create_augmentation_router(), prefix="/api/v1/admin")
    return TestClient(app)
class TestAugmentationTypesEndpoint:
    """Tests for GET /admin/augmentation/types endpoint."""

    def test_list_augmentation_types(
        self, admin_client: TestClient, admin_token: str
    ) -> None:
        """Test listing available augmentation types.

        Every entry is checked for the required keys, not just the first one,
        so a malformed entry anywhere in the list fails the test.
        """
        response = admin_client.get(
            "/api/v1/admin/augmentation/types",
            headers={"X-Admin-Token": admin_token},
        )
        assert response.status_code == 200
        data = response.json()
        assert "augmentation_types" in data
        aug_types = data["augmentation_types"]
        assert len(aug_types) == 12
        # Check structure of every entry.
        required_keys = {"name", "description", "affects_geometry", "stage"}
        for aug_type in aug_types:
            missing = required_keys.difference(aug_type)
            assert not missing, f"augmentation type {aug_type!r} is missing {sorted(missing)}"

    def test_list_augmentation_types_unauthorized(
        self, unauthenticated_client: TestClient
    ) -> None:
        """Test that a request without a valid admin token is rejected with 401."""
        response = unauthenticated_client.get("/api/v1/admin/augmentation/types")
        assert response.status_code == 401
class TestAugmentationPresetsEndpoint:
    """Tests for GET /admin/augmentation/presets endpoint."""

    def test_list_presets(self, admin_client: TestClient, admin_token: str) -> None:
        """The presets listing contains at least the four expected named presets."""
        response = admin_client.get(
            "/api/v1/admin/augmentation/presets",
            headers={"X-Admin-Token": admin_token},
        )
        assert response.status_code == 200
        payload = response.json()
        assert "presets" in payload
        presets = payload["presets"]
        assert len(presets) >= 4
        names = {preset["name"] for preset in presets}
        for expected in ("conservative", "moderate", "aggressive", "scanned_document"):
            assert expected in names
class TestAugmentationPreviewEndpoint:
    """Tests for POST /admin/augmentation/preview/{document_id} endpoint."""

    # Patched so tests do not depend on a real page image being loadable
    # (the stub document only points at "/fake/path").
    LOAD_PAGE_TARGET = (
        "inference.web.services.augmentation_service."
        "AugmentationService._load_document_page"
    )

    @staticmethod
    def _fake_image() -> np.ndarray:
        """Return a deterministic 100x100 RGB uint8 image.

        A fixed seed keeps failures reproducible; ``high=256`` lets values
        span the whole uint8 range because the upper bound is exclusive.
        """
        rng = np.random.default_rng(0)
        return rng.integers(0, 256, size=(100, 100, 3), dtype=np.uint8)

    @staticmethod
    def _stub_existing_document(mock_document_repo: MagicMock) -> None:
        """Make ``mock_document_repo.get`` return a stub document."""
        mock_document = MagicMock()
        mock_document.images_dir = "/fake/path"
        mock_document_repo.get.return_value = mock_document

    def test_preview_augmentation(
        self,
        admin_client: TestClient,
        admin_token: str,
        sample_document_id: str,
        mock_document_repo: MagicMock,
    ) -> None:
        """Test previewing augmentation on a document."""
        self._stub_existing_document(mock_document_repo)
        with patch(self.LOAD_PAGE_TARGET, return_value=self._fake_image()):
            response = admin_client.post(
                f"/api/v1/admin/augmentation/preview/{sample_document_id}",
                headers={"X-Admin-Token": admin_token},
                json={
                    "augmentation_type": "gaussian_noise",
                    "params": {"std": 15},
                },
            )
        assert response.status_code == 200
        data = response.json()
        assert "preview_url" in data
        assert "original_url" in data
        assert "applied_params" in data

    def test_preview_invalid_augmentation_type(
        self,
        admin_client: TestClient,
        admin_token: str,
        sample_document_id: str,
        mock_document_repo: MagicMock,
    ) -> None:
        """Test that an invalid augmentation type returns 400.

        The document is made to exist and page loading is patched, so the
        unknown augmentation type is the only thing wrong with the request.
        Without this, a route that looks the document up first would answer
        404 and the 400 path would never be exercised.
        """
        self._stub_existing_document(mock_document_repo)
        with patch(self.LOAD_PAGE_TARGET, return_value=self._fake_image()):
            response = admin_client.post(
                f"/api/v1/admin/augmentation/preview/{sample_document_id}",
                headers={"X-Admin-Token": admin_token},
                json={
                    "augmentation_type": "nonexistent",
                    "params": {},
                },
            )
        assert response.status_code == 400

    def test_preview_nonexistent_document(
        self,
        admin_client: TestClient,
        admin_token: str,
        mock_document_repo: MagicMock,
    ) -> None:
        """Test that a nonexistent document returns 404."""
        # State the precondition explicitly rather than relying on the fixture default.
        mock_document_repo.get.return_value = None
        response = admin_client.post(
            "/api/v1/admin/augmentation/preview/00000000-0000-0000-0000-000000000000",
            headers={"X-Admin-Token": admin_token},
            json={
                "augmentation_type": "gaussian_noise",
                "params": {},
            },
        )
        assert response.status_code == 404
class TestAugmentationPreviewConfigEndpoint:
    """Tests for POST /admin/augmentation/preview-config/{document_id} endpoint."""

    def test_preview_config(
        self,
        admin_client: TestClient,
        admin_token: str,
        sample_document_id: str,
        mock_document_repo: MagicMock,
    ) -> None:
        """Test previewing a full augmentation config on a document."""
        # Mock document exists
        mock_document = MagicMock()
        mock_document.images_dir = "/fake/path"
        mock_document_repo.get.return_value = mock_document
        # Deterministic 100x100 RGB image. The fixed seed keeps runs reproducible,
        # and high=256 (exclusive) allows the full uint8 range, including 255.
        fake_image = np.random.default_rng(0).integers(
            0, 256, size=(100, 100, 3), dtype=np.uint8
        )
        with patch(
            "inference.web.services.augmentation_service.AugmentationService._load_document_page",
            return_value=fake_image,
        ):
            response = admin_client.post(
                f"/api/v1/admin/augmentation/preview-config/{sample_document_id}",
                headers={"X-Admin-Token": admin_token},
                json={
                    "gaussian_noise": {"enabled": True, "probability": 1.0},
                    "lighting_variation": {"enabled": True, "probability": 1.0},
                    "preserve_bboxes": True,
                    "seed": 42,
                },
            )
        assert response.status_code == 200
        data = response.json()
        assert "preview_url" in data
        assert "original_url" in data
class TestAugmentationBatchEndpoint:
    """Tests for POST /admin/augmentation/batch endpoint."""

    def test_create_augmented_dataset(
        self,
        admin_client: TestClient,
        admin_token: str,
        sample_dataset_id: str,
        mock_dataset_repo: MagicMock,
    ) -> None:
        """A batch request against an existing dataset returns task details."""
        # Stub the dataset lookup so the dataset exists with 100 images.
        dataset_stub = MagicMock()
        dataset_stub.total_images = 100
        mock_dataset_repo.get.return_value = dataset_stub

        augmentation_config = {
            "gaussian_noise": {"enabled": True, "probability": 0.5},
            "preserve_bboxes": True,
        }
        payload = {
            "dataset_id": sample_dataset_id,
            "config": augmentation_config,
            "output_name": "test_augmented_dataset",
            "multiplier": 2,
        }
        response = admin_client.post(
            "/api/v1/admin/augmentation/batch",
            headers={"X-Admin-Token": admin_token},
            json=payload,
        )

        assert response.status_code == 200
        body = response.json()
        for key in ("task_id", "status", "estimated_images"):
            assert key in body

    def test_create_augmented_dataset_invalid_multiplier(
        self,
        admin_client: TestClient,
        admin_token: str,
        sample_dataset_id: str,
    ) -> None:
        """An out-of-range multiplier is rejected by request validation (422)."""
        payload = {
            "dataset_id": sample_dataset_id,
            "config": {},
            "output_name": "test",
            "multiplier": 100,  # Too high
        }
        response = admin_client.post(
            "/api/v1/admin/augmentation/batch",
            headers={"X-Admin-Token": admin_token},
            json=payload,
        )
        assert response.status_code == 422  # Validation error
class TestAugmentedDatasetsListEndpoint:
    """Tests for GET /admin/augmentation/datasets endpoint."""

    def test_list_augmented_datasets(
        self, admin_client: TestClient, admin_token: str
    ) -> None:
        """Test that listing augmented datasets returns the paginated envelope."""
        response = admin_client.get(
            "/api/v1/admin/augmentation/datasets",
            headers={"X-Admin-Token": admin_token},
        )
        assert response.status_code == 200
        data = response.json()
        assert "total" in data
        assert "limit" in data
        assert "offset" in data
        assert "datasets" in data
        assert isinstance(data["datasets"], list)

    def test_list_augmented_datasets_pagination(
        self, admin_client: TestClient, admin_token: str
    ) -> None:
        """Test that limit/offset query parameters are reflected in the response.

        A non-zero offset is used because 0 very likely matches the route's
        default (NOTE(review): confirm), in which case asserting 0 could not
        reveal a route that ignores the ``offset`` query parameter.
        """
        response = admin_client.get(
            "/api/v1/admin/augmentation/datasets",
            headers={"X-Admin-Token": admin_token},
            params={"limit": 5, "offset": 10},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["limit"] == 5
        assert data["offset"] == 10
# Sample-ID fixtures used by the test classes above.
@pytest.fixture
def sample_document_id() -> str:
    """Return the fixed document UUID used in preview and preview-config request paths."""
    document_id = TEST_DOCUMENT_UUID
    return document_id
@pytest.fixture
def sample_dataset_id() -> str:
    """Return the fixed dataset UUID sent as ``dataset_id`` in batch requests."""
    dataset_id = TEST_DATASET_UUID
    return dataset_id