383 lines
12 KiB
Python
383 lines
12 KiB
Python
"""
|
|
Tests for augmentation API routes.
|
|
|
|
TDD Phase 5: RED - Write tests first, then implement to pass.
|
|
"""
|
|
|
|
import pytest
|
|
from unittest.mock import MagicMock, patch
|
|
from fastapi import FastAPI
|
|
from fastapi.testclient import TestClient
|
|
import numpy as np
|
|
|
|
from inference.web.api.v1.admin.augmentation import create_augmentation_router
|
|
from inference.web.core.auth import (
|
|
validate_admin_token,
|
|
get_document_repository,
|
|
get_dataset_repository,
|
|
)
|
|
|
|
|
|
# Fixed identifiers handed out by the fixtures in this module.
TEST_ADMIN_TOKEN = "test-admin-token-12345"
TEST_DOCUMENT_UUID = "550e8400-e29b-41d4-a716-446655440001"
TEST_DATASET_UUID = "660e8400-e29b-41d4-a716-446655440001"
|
|
|
|
|
|
@pytest.fixture
def admin_token() -> str:
    """Return the admin token that tests send in the ``X-Admin-Token`` header."""
    token: str = TEST_ADMIN_TOKEN
    return token
|
|
|
|
|
|
@pytest.fixture
def mock_document_repo() -> MagicMock:
    """Build a stand-in DocumentRepository whose lookups return ``None`` by default.

    Tests that need an existing document overwrite ``get.return_value``.
    """
    lookup_defaults = {
        "get.return_value": None,
        "get_by_token.return_value": None,
    }
    # Dotted keys configure the child mocks' return values at construction.
    return MagicMock(**lookup_defaults)
|
|
|
|
|
|
@pytest.fixture
def mock_dataset_repo() -> MagicMock:
    """Build a stand-in DatasetRepository with empty default results.

    ``get`` returns ``None`` and ``get_paginated`` returns ``([], 0)`` until
    a test overrides them.
    """
    repo = MagicMock()
    repo.configure_mock(
        **{
            "get.return_value": None,
            "get_paginated.return_value": ([], 0),
        }
    )
    return repo
|
|
|
|
|
|
@pytest.fixture
def admin_client(mock_document_repo: MagicMock, mock_dataset_repo: MagicMock) -> TestClient:
    """Return a TestClient whose token check and repositories are all overridden.

    The token dependency is replaced so it always yields the test admin token,
    and both repository dependencies resolve to the supplied mocks.
    """
    application = FastAPI()

    # Swap every dependency in one go; each override takes no parameters.
    application.dependency_overrides.update(
        {
            validate_admin_token: lambda: TEST_ADMIN_TOKEN,
            get_document_repository: lambda: mock_document_repo,
            get_dataset_repository: lambda: mock_dataset_repo,
        }
    )

    # The router already has the /augmentation prefix, so mounting it under
    # /api/v1/admin yields /api/v1/admin/augmentation.
    application.include_router(create_augmentation_router(), prefix="/api/v1/admin")

    return TestClient(application)
|
|
|
|
|
|
@pytest.fixture
def unauthenticated_client(mock_document_repo: MagicMock, mock_dataset_repo: MagicMock) -> TestClient:
    """Return a TestClient that leaves admin token validation in place.

    Only the repository dependencies are replaced with the supplied mocks;
    the token validation dependency is deliberately NOT overridden.
    """
    application = FastAPI()

    # Repositories only - the token validation dependency stays untouched.
    application.dependency_overrides.update(
        {
            get_document_repository: lambda: mock_document_repo,
            get_dataset_repository: lambda: mock_dataset_repo,
        }
    )

    application.include_router(create_augmentation_router(), prefix="/api/v1/admin")

    return TestClient(application)
|
|
|
|
|
|
class TestAugmentationTypesEndpoint:
    """Tests for GET /admin/augmentation/types endpoint."""

    def test_list_augmentation_types(
        self, admin_client: TestClient, admin_token: str
    ) -> None:
        """The types listing returns 200 with 12 entries of the expected shape."""
        auth_headers = {"X-Admin-Token": admin_token}
        resp = admin_client.get(
            "/api/v1/admin/augmentation/types",
            headers=auth_headers,
        )
        assert resp.status_code == 200

        payload = resp.json()
        assert "augmentation_types" in payload
        listed_types = payload["augmentation_types"]
        assert len(listed_types) == 12

        # Only the first entry is inspected for its structure.
        first_entry = listed_types[0]
        for field in ("name", "description", "affects_geometry", "stage"):
            assert field in first_entry

    def test_list_augmentation_types_unauthorized(
        self, unauthenticated_client: TestClient
    ) -> None:
        """A request without the admin token header is answered with 401."""
        resp = unauthenticated_client.get("/api/v1/admin/augmentation/types")

        assert resp.status_code == 401
|
|
|
|
|
|
class TestAugmentationPresetsEndpoint:
    """Tests for GET /admin/augmentation/presets endpoint."""

    def test_list_presets(self, admin_client: TestClient, admin_token: str) -> None:
        """The presets listing returns 200 and includes the four named presets."""
        auth_headers = {"X-Admin-Token": admin_token}
        resp = admin_client.get(
            "/api/v1/admin/augmentation/presets",
            headers=auth_headers,
        )
        assert resp.status_code == 200

        payload = resp.json()
        assert "presets" in payload
        assert len(payload["presets"]) >= 4

        # Every expected preset name must be present in the listing.
        available = [entry["name"] for entry in payload["presets"]]
        for expected in ("conservative", "moderate", "aggressive", "scanned_document"):
            assert expected in available
|
|
|
|
|
|
class TestAugmentationPreviewEndpoint:
    """Tests for POST /admin/augmentation/preview/{document_id} endpoint."""

    def test_preview_augmentation(
        self,
        admin_client: TestClient,
        admin_token: str,
        sample_document_id: str,
        mock_document_repo: MagicMock,
    ) -> None:
        """Test previewing augmentation on a document.

        The repository mock is primed to return a document, and the service's
        page loader is patched to hand back an in-memory image so no file is
        read from disk.
        """
        # Mock document exists
        mock_document = MagicMock()
        mock_document.images_dir = "/fake/path"
        mock_document_repo.get.return_value = mock_document

        # Create a fake image (100x100 RGB). A seeded generator keeps the input
        # identical across runs, and the exclusive upper bound of 256 lets the
        # data span the whole uint8 range (0-255).
        fake_image = np.random.default_rng(0).integers(
            0, 256, size=(100, 100, 3), dtype=np.uint8
        )

        with patch(
            "inference.web.services.augmentation_service.AugmentationService._load_document_page",
            return_value=fake_image,
        ):
            response = admin_client.post(
                f"/api/v1/admin/augmentation/preview/{sample_document_id}",
                headers={"X-Admin-Token": admin_token},
                json={
                    "augmentation_type": "gaussian_noise",
                    "params": {"std": 15},
                },
            )

        assert response.status_code == 200
        data = response.json()

        assert "preview_url" in data
        assert "original_url" in data
        assert "applied_params" in data

    def test_preview_invalid_augmentation_type(
        self,
        admin_client: TestClient,
        admin_token: str,
        sample_document_id: str,
    ) -> None:
        """Test that invalid augmentation type returns error."""
        response = admin_client.post(
            f"/api/v1/admin/augmentation/preview/{sample_document_id}",
            headers={"X-Admin-Token": admin_token},
            json={
                "augmentation_type": "nonexistent",
                "params": {},
            },
        )

        assert response.status_code == 400

    def test_preview_nonexistent_document(
        self,
        admin_client: TestClient,
        admin_token: str,
    ) -> None:
        """Test that nonexistent document returns 404.

        The document repository mock returns ``None`` from ``get`` by default,
        so no document is found for the all-zero UUID.
        """
        response = admin_client.post(
            "/api/v1/admin/augmentation/preview/00000000-0000-0000-0000-000000000000",
            headers={"X-Admin-Token": admin_token},
            json={
                "augmentation_type": "gaussian_noise",
                "params": {},
            },
        )

        assert response.status_code == 404
|
|
|
|
|
|
class TestAugmentationPreviewConfigEndpoint:
    """Tests for POST /admin/augmentation/preview-config/{document_id} endpoint."""

    def test_preview_config(
        self,
        admin_client: TestClient,
        admin_token: str,
        sample_document_id: str,
        mock_document_repo: MagicMock,
    ) -> None:
        """Test previewing full config on a document.

        The repository mock is primed to return a document, and the service's
        page loader is patched to hand back an in-memory image so no file is
        read from disk.
        """
        # Mock document exists
        mock_document = MagicMock()
        mock_document.images_dir = "/fake/path"
        mock_document_repo.get.return_value = mock_document

        # Create a fake image (100x100 RGB). A seeded generator keeps the input
        # identical across runs, and the exclusive upper bound of 256 lets the
        # data span the whole uint8 range (0-255).
        fake_image = np.random.default_rng(0).integers(
            0, 256, size=(100, 100, 3), dtype=np.uint8
        )

        with patch(
            "inference.web.services.augmentation_service.AugmentationService._load_document_page",
            return_value=fake_image,
        ):
            response = admin_client.post(
                f"/api/v1/admin/augmentation/preview-config/{sample_document_id}",
                headers={"X-Admin-Token": admin_token},
                json={
                    "gaussian_noise": {"enabled": True, "probability": 1.0},
                    "lighting_variation": {"enabled": True, "probability": 1.0},
                    "preserve_bboxes": True,
                    "seed": 42,
                },
            )

        assert response.status_code == 200
        data = response.json()

        assert "preview_url" in data
        assert "original_url" in data
|
|
|
|
|
|
class TestAugmentationBatchEndpoint:
    """Tests for POST /admin/augmentation/batch endpoint."""

    def test_create_augmented_dataset(
        self,
        admin_client: TestClient,
        admin_token: str,
        sample_dataset_id: str,
        mock_dataset_repo: MagicMock,
    ) -> None:
        """A valid batch request returns 200 with task details."""
        # Make the repository report an existing dataset with 100 images.
        existing_dataset = MagicMock()
        existing_dataset.total_images = 100
        mock_dataset_repo.get.return_value = existing_dataset

        request_body = {
            "dataset_id": sample_dataset_id,
            "config": {
                "gaussian_noise": {"enabled": True, "probability": 0.5},
                "preserve_bboxes": True,
            },
            "output_name": "test_augmented_dataset",
            "multiplier": 2,
        }
        resp = admin_client.post(
            "/api/v1/admin/augmentation/batch",
            headers={"X-Admin-Token": admin_token},
            json=request_body,
        )
        assert resp.status_code == 200

        payload = resp.json()
        for field in ("task_id", "status", "estimated_images"):
            assert field in payload

    def test_create_augmented_dataset_invalid_multiplier(
        self,
        admin_client: TestClient,
        admin_token: str,
        sample_dataset_id: str,
    ) -> None:
        """A multiplier of 100 is rejected with 422."""
        request_body = {
            "dataset_id": sample_dataset_id,
            "config": {},
            "output_name": "test",
            "multiplier": 100,  # Too high
        }
        resp = admin_client.post(
            "/api/v1/admin/augmentation/batch",
            headers={"X-Admin-Token": admin_token},
            json=request_body,
        )

        assert resp.status_code == 422  # Validation error
|
|
|
|
|
|
class TestAugmentedDatasetsListEndpoint:
    """Tests for GET /admin/augmentation/datasets endpoint."""

    def test_list_augmented_datasets(
        self, admin_client: TestClient, admin_token: str
    ) -> None:
        """The listing returns 200 with paging fields and a list of datasets."""
        auth_headers = {"X-Admin-Token": admin_token}
        resp = admin_client.get(
            "/api/v1/admin/augmentation/datasets",
            headers=auth_headers,
        )
        assert resp.status_code == 200

        payload = resp.json()
        for field in ("total", "limit", "offset", "datasets"):
            assert field in payload
        assert isinstance(payload["datasets"], list)

    def test_list_augmented_datasets_pagination(
        self, admin_client: TestClient, admin_token: str
    ) -> None:
        """The requested limit and offset are echoed back in the response."""
        paging = {"limit": 5, "offset": 0}
        resp = admin_client.get(
            "/api/v1/admin/augmentation/datasets",
            headers={"X-Admin-Token": admin_token},
            params=paging,
        )
        assert resp.status_code == 200

        payload = resp.json()
        assert payload["limit"] == 5
        assert payload["offset"] == 0
|
|
|
|
|
|
# Fixtures for tests
|
|
@pytest.fixture
def sample_document_id() -> str:
    """Return the fixed document UUID used in preview request URLs."""
    document_id: str = TEST_DOCUMENT_UUID
    return document_id
|
|
|
|
|
|
@pytest.fixture
def sample_dataset_id() -> str:
    """Return the fixed dataset UUID used in batch request payloads."""
    dataset_id: str = TEST_DATASET_UUID
    return dataset_id
|