"""
Tests for augmentation pipeline module.

TDD Phase 2: RED - Write tests first, then implement to pass.
"""

from typing import Any
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
|
|
class TestAugmentationPipeline:
    """Tests for AugmentationPipeline class."""

    # Every augmentation type this module expects the pipeline to support
    # (same names as the registry / STAGE_MAPPING expectations).
    # NOTE(review): assumes AugmentationConfig accepts one AugmentationParams
    # keyword per type, named exactly as below — confirm against
    # shared.augmentation.config.
    _ALL_AUGMENTATION_TYPES = (
        "perspective_warp",
        "wrinkle",
        "edge_damage",
        "stain",
        "lighting_variation",
        "shadow",
        "gaussian_blur",
        "motion_blur",
        "gaussian_noise",
        "salt_pepper",
        "paper_texture",
        "scanner_artifacts",
    )

    def test_create_with_config(self) -> None:
        """Test creating pipeline with config."""
        from shared.augmentation.config import AugmentationConfig
        from shared.augmentation.pipeline import AugmentationPipeline

        config = AugmentationConfig()
        pipeline = AugmentationPipeline(config)

        # The pipeline must hold the very config object it was given.
        assert pipeline.config is config

    def test_create_with_seed(self) -> None:
        """Test creating pipeline with seed for reproducibility."""
        from shared.augmentation.config import AugmentationConfig
        from shared.augmentation.pipeline import AugmentationPipeline

        config = AugmentationConfig(seed=42)
        pipeline = AugmentationPipeline(config)

        assert pipeline.config.seed == 42

    def test_apply_returns_augmentation_result(self) -> None:
        """Test that apply returns AugmentationResult."""
        from shared.augmentation.base import AugmentationResult
        from shared.augmentation.config import AugmentationConfig
        from shared.augmentation.pipeline import AugmentationPipeline

        config = AugmentationConfig()
        pipeline = AugmentationPipeline(config)

        image = np.zeros((100, 100, 3), dtype=np.uint8)
        result = pipeline.apply(image)

        assert isinstance(result, AugmentationResult)
        assert result.image is not None
        assert result.image.shape == image.shape

    def test_apply_with_bboxes(self) -> None:
        """Test apply with bounding boxes."""
        from shared.augmentation.config import AugmentationConfig
        from shared.augmentation.pipeline import AugmentationPipeline

        config = AugmentationConfig()
        pipeline = AugmentationPipeline(config)

        image = np.zeros((100, 100, 3), dtype=np.uint8)
        bboxes = np.array([[0, 0.5, 0.5, 0.1, 0.1]], dtype=np.float32)

        result = pipeline.apply(image, bboxes)

        # Bboxes passed in must come back on the result. This relies on the
        # default config preserving them (presumably a preserve_bboxes-style
        # default of True — confirm against AugmentationConfig).
        assert result.bboxes is not None

    def test_apply_no_augmentations_enabled(self) -> None:
        """Test apply when no augmentations are enabled."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams
        from shared.augmentation.pipeline import AugmentationPipeline

        # Disable every augmentation explicitly instead of relying on config
        # defaults: a type that happens to be enabled by default would alter
        # the image and make this test fail for the wrong reason.
        config = AugmentationConfig(
            **{
                name: AugmentationParams(enabled=False)
                for name in self._ALL_AUGMENTATION_TYPES
            }
        )
        pipeline = AugmentationPipeline(config)

        # Seeded generator keeps the input reproducible between runs; the
        # exclusive upper bound of 256 covers the full uint8 range 0..255.
        rng = np.random.default_rng(0)
        image = rng.integers(0, 256, size=(100, 100, 3), dtype=np.uint8)
        result = pipeline.apply(image)

        # Image should be unchanged (returning a copy is acceptable).
        np.testing.assert_array_equal(result.image, image)

    def test_apply_does_not_mutate_input(self) -> None:
        """Test that apply does not mutate input image."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams
        from shared.augmentation.pipeline import AugmentationPipeline

        config = AugmentationConfig(
            lighting_variation=AugmentationParams(enabled=True, probability=1.0),
        )
        pipeline = AugmentationPipeline(config)

        image = np.full((100, 100, 3), 128, dtype=np.uint8)
        original_copy = image.copy()

        pipeline.apply(image)

        np.testing.assert_array_equal(image, original_copy)

    def test_reproducibility_with_seed(self) -> None:
        """Test that same seed produces same results."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams
        from shared.augmentation.pipeline import AugmentationPipeline

        # Two independently built, equal configs (same seed) so the check does
        # not depend on shared pipeline state.
        config1 = AugmentationConfig(
            seed=42,
            gaussian_noise=AugmentationParams(enabled=True, probability=1.0),
        )
        config2 = AugmentationConfig(
            seed=42,
            gaussian_noise=AugmentationParams(enabled=True, probability=1.0),
        )

        pipeline1 = AugmentationPipeline(config1)
        pipeline2 = AugmentationPipeline(config2)

        image = np.full((100, 100, 3), 128, dtype=np.uint8)

        result1 = pipeline1.apply(image.copy())
        result2 = pipeline2.apply(image.copy())

        np.testing.assert_array_equal(result1.image, result2.image)

    def test_metadata_contains_applied_augmentations(self) -> None:
        """Test that metadata lists applied augmentations."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams
        from shared.augmentation.pipeline import AugmentationPipeline

        config = AugmentationConfig(
            seed=42,
            gaussian_noise=AugmentationParams(enabled=True, probability=1.0),
            lighting_variation=AugmentationParams(enabled=True, probability=1.0),
        )
        pipeline = AugmentationPipeline(config)

        image = np.full((100, 100, 3), 128, dtype=np.uint8)
        result = pipeline.apply(image)

        assert result.metadata is not None
        assert "applied_augmentations" in result.metadata
        # Both should be applied with probability=1.0
        assert "gaussian_noise" in result.metadata["applied_augmentations"]
        assert "lighting_variation" in result.metadata["applied_augmentations"]


class TestAugmentationPipelineStageOrder:
    """Checks on the order in which the pipeline runs its stages."""

    def test_stage_order_defined(self) -> None:
        """STAGE_ORDER exists and lists the six stages in pipeline order."""
        from shared.augmentation.pipeline import AugmentationPipeline

        assert hasattr(AugmentationPipeline, "STAGE_ORDER")
        assert AugmentationPipeline.STAGE_ORDER == [
            "geometric",
            "degradation",
            "lighting",
            "texture",
            "blur",
            "noise",
        ]

    def test_stage_mapping_defined(self) -> None:
        """Every augmentation type is assigned to its expected stage."""
        from shared.augmentation.pipeline import AugmentationPipeline

        assert hasattr(AugmentationPipeline, "STAGE_MAPPING")
        mapping = AugmentationPipeline.STAGE_MAPPING

        # Expectations grouped by stage; each listed type must map to it.
        members_by_stage = {
            "geometric": ("perspective_warp",),
            "degradation": ("wrinkle", "edge_damage", "stain"),
            "lighting": ("lighting_variation", "shadow"),
            "texture": ("paper_texture", "scanner_artifacts"),
            "blur": ("gaussian_blur", "motion_blur"),
            "noise": ("gaussian_noise", "salt_pepper"),
        }

        for stage_name, members in members_by_stage.items():
            for member in members:
                assert member in mapping
                assert mapping[member] == stage_name

    def test_geometric_before_degradation(self) -> None:
        """Geometric transforms run before degradation effects."""
        from shared.augmentation.pipeline import AugmentationPipeline

        order = AugmentationPipeline.STAGE_ORDER
        assert order.index("geometric") < order.index("degradation")

    def test_noise_applied_last(self) -> None:
        """Noise is the final stage."""
        from shared.augmentation.pipeline import AugmentationPipeline

        assert AugmentationPipeline.STAGE_ORDER[-1] == "noise"


class TestAugmentationRegistry:
    """Checks on the module-level AUGMENTATION_REGISTRY."""

    def test_registry_exists(self) -> None:
        """The registry is exposed as a plain dict."""
        from shared.augmentation.pipeline import AUGMENTATION_REGISTRY as registry

        assert isinstance(registry, dict)

    def test_registry_contains_all_types(self) -> None:
        """Each known augmentation type has a registry entry."""
        from shared.augmentation.pipeline import AUGMENTATION_REGISTRY as registry

        # Checked in this order; the first absent name is reported.
        required_names = (
            "perspective_warp",
            "wrinkle",
            "edge_damage",
            "stain",
            "lighting_variation",
            "shadow",
            "gaussian_blur",
            "motion_blur",
            "gaussian_noise",
            "salt_pepper",
            "paper_texture",
            "scanner_artifacts",
        )

        for name in required_names:
            assert name in registry, f"Missing: {name}"


class TestPipelinePreview:
    """Tests for pipeline preview functionality."""

    def test_preview_single_augmentation(self) -> None:
        """Test previewing a single augmentation."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams
        from shared.augmentation.pipeline import AugmentationPipeline

        config = AugmentationConfig(
            gaussian_noise=AugmentationParams(
                enabled=True, probability=1.0, params={"std": (10, 10)}
            ),
        )
        pipeline = AugmentationPipeline(config)

        image = np.full((100, 100, 3), 128, dtype=np.uint8)
        # Pristine snapshot of the input, independent of anything preview()
        # might do to `image`.
        original = image.copy()
        preview = pipeline.preview(image, "gaussian_noise")

        assert preview.shape == original.shape
        assert preview.dtype == np.uint8
        # Preview should modify the image relative to the input as given.
        # Compare against the snapshot, not `image`, so an in-place write
        # into `image` cannot change what this assertion means.
        assert not np.array_equal(preview, original)
        # Like apply(), preview() must leave the caller's array untouched.
        np.testing.assert_array_equal(image, original)

    def test_preview_unknown_augmentation_raises(self) -> None:
        """Test that previewing unknown augmentation raises error."""
        from shared.augmentation.config import AugmentationConfig
        from shared.augmentation.pipeline import AugmentationPipeline

        config = AugmentationConfig()
        pipeline = AugmentationPipeline(config)

        image = np.zeros((100, 100, 3), dtype=np.uint8)

        with pytest.raises(ValueError, match="Unknown augmentation"):
            pipeline.preview(image, "non_existent_augmentation")

    def test_preview_is_deterministic(self) -> None:
        """Test that preview produces deterministic results."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams
        from shared.augmentation.pipeline import AugmentationPipeline

        # No seed is set: determinism must come from preview() itself.
        config = AugmentationConfig(
            gaussian_noise=AugmentationParams(enabled=True),
        )
        pipeline = AugmentationPipeline(config)

        image = np.full((100, 100, 3), 128, dtype=np.uint8)

        preview1 = pipeline.preview(image, "gaussian_noise")
        preview2 = pipeline.preview(image, "gaussian_noise")

        np.testing.assert_array_equal(preview1, preview2)


class TestPipelineGetAvailableAugmentations:
    """Checks on the get_available_augmentations() listing."""

    def test_get_available_augmentations(self) -> None:
        """The listing has 12 entries, each carrying the descriptive keys."""
        from shared.augmentation.pipeline import get_available_augmentations

        listing = get_available_augmentations()

        assert isinstance(listing, list)
        assert len(listing) == 12

        # Keys every entry must expose, checked in this order.
        required_keys = ("name", "description", "affects_geometry", "stage")
        for entry in listing:
            for key in required_keys:
                assert key in entry

    def test_get_available_augmentations_includes_all_types(self) -> None:
        """Every known augmentation type appears by name in the listing."""
        from shared.augmentation.pipeline import get_available_augmentations

        listed_names = [entry["name"] for entry in get_available_augmentations()]

        for wanted in (
            "perspective_warp",
            "wrinkle",
            "edge_damage",
            "stain",
            "lighting_variation",
            "shadow",
            "gaussian_blur",
            "motion_blur",
            "gaussian_noise",
            "salt_pepper",
            "paper_texture",
            "scanner_artifacts",
        ):
            assert wanted in listed_names