WIP
This commit is contained in:
338
tests/shared/augmentation/test_pipeline.py
Normal file
338
tests/shared/augmentation/test_pipeline.py
Normal file
@@ -0,0 +1,338 @@
|
||||
"""
|
||||
Tests for augmentation pipeline module.
|
||||
|
||||
TDD Phase 2: RED - Write tests first, then implement to pass.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
class TestAugmentationPipeline:
|
||||
"""Tests for AugmentationPipeline class."""
|
||||
|
||||
def test_create_with_config(self) -> None:
|
||||
"""Test creating pipeline with config."""
|
||||
from shared.augmentation.config import AugmentationConfig
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
config = AugmentationConfig()
|
||||
pipeline = AugmentationPipeline(config)
|
||||
|
||||
assert pipeline.config is config
|
||||
|
||||
def test_create_with_seed(self) -> None:
|
||||
"""Test creating pipeline with seed for reproducibility."""
|
||||
from shared.augmentation.config import AugmentationConfig
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
config = AugmentationConfig(seed=42)
|
||||
pipeline = AugmentationPipeline(config)
|
||||
|
||||
assert pipeline.config.seed == 42
|
||||
|
||||
def test_apply_returns_augmentation_result(self) -> None:
|
||||
"""Test that apply returns AugmentationResult."""
|
||||
from shared.augmentation.base import AugmentationResult
|
||||
from shared.augmentation.config import AugmentationConfig
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
config = AugmentationConfig()
|
||||
pipeline = AugmentationPipeline(config)
|
||||
|
||||
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
result = pipeline.apply(image)
|
||||
|
||||
assert isinstance(result, AugmentationResult)
|
||||
assert result.image is not None
|
||||
assert result.image.shape == image.shape
|
||||
|
||||
def test_apply_with_bboxes(self) -> None:
|
||||
"""Test apply with bounding boxes."""
|
||||
from shared.augmentation.config import AugmentationConfig
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
config = AugmentationConfig()
|
||||
pipeline = AugmentationPipeline(config)
|
||||
|
||||
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
bboxes = np.array([[0, 0.5, 0.5, 0.1, 0.1]], dtype=np.float32)
|
||||
|
||||
result = pipeline.apply(image, bboxes)
|
||||
|
||||
# Bboxes should be preserved when preserve_bboxes=True
|
||||
assert result.bboxes is not None
|
||||
|
||||
def test_apply_no_augmentations_enabled(self) -> None:
|
||||
"""Test apply when no augmentations are enabled."""
|
||||
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
# Disable all augmentations
|
||||
config = AugmentationConfig(
|
||||
lighting_variation=AugmentationParams(enabled=False),
|
||||
)
|
||||
pipeline = AugmentationPipeline(config)
|
||||
|
||||
image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
|
||||
result = pipeline.apply(image)
|
||||
|
||||
# Image should be unchanged (or a copy)
|
||||
np.testing.assert_array_equal(result.image, image)
|
||||
|
||||
def test_apply_does_not_mutate_input(self) -> None:
|
||||
"""Test that apply does not mutate input image."""
|
||||
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
config = AugmentationConfig(
|
||||
lighting_variation=AugmentationParams(enabled=True, probability=1.0),
|
||||
)
|
||||
pipeline = AugmentationPipeline(config)
|
||||
|
||||
image = np.full((100, 100, 3), 128, dtype=np.uint8)
|
||||
original_copy = image.copy()
|
||||
|
||||
pipeline.apply(image)
|
||||
|
||||
np.testing.assert_array_equal(image, original_copy)
|
||||
|
||||
def test_reproducibility_with_seed(self) -> None:
|
||||
"""Test that same seed produces same results."""
|
||||
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
config1 = AugmentationConfig(
|
||||
seed=42,
|
||||
gaussian_noise=AugmentationParams(enabled=True, probability=1.0),
|
||||
)
|
||||
config2 = AugmentationConfig(
|
||||
seed=42,
|
||||
gaussian_noise=AugmentationParams(enabled=True, probability=1.0),
|
||||
)
|
||||
|
||||
pipeline1 = AugmentationPipeline(config1)
|
||||
pipeline2 = AugmentationPipeline(config2)
|
||||
|
||||
image = np.full((100, 100, 3), 128, dtype=np.uint8)
|
||||
|
||||
result1 = pipeline1.apply(image.copy())
|
||||
result2 = pipeline2.apply(image.copy())
|
||||
|
||||
np.testing.assert_array_equal(result1.image, result2.image)
|
||||
|
||||
def test_metadata_contains_applied_augmentations(self) -> None:
|
||||
"""Test that metadata lists applied augmentations."""
|
||||
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
config = AugmentationConfig(
|
||||
seed=42,
|
||||
gaussian_noise=AugmentationParams(enabled=True, probability=1.0),
|
||||
lighting_variation=AugmentationParams(enabled=True, probability=1.0),
|
||||
)
|
||||
pipeline = AugmentationPipeline(config)
|
||||
|
||||
image = np.full((100, 100, 3), 128, dtype=np.uint8)
|
||||
result = pipeline.apply(image)
|
||||
|
||||
assert result.metadata is not None
|
||||
assert "applied_augmentations" in result.metadata
|
||||
# Both should be applied with probability=1.0
|
||||
assert "gaussian_noise" in result.metadata["applied_augmentations"]
|
||||
assert "lighting_variation" in result.metadata["applied_augmentations"]
|
||||
|
||||
|
||||
class TestAugmentationPipelineStageOrder:
|
||||
"""Tests for pipeline stage ordering."""
|
||||
|
||||
def test_stage_order_defined(self) -> None:
|
||||
"""Test that stage order is defined."""
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
assert hasattr(AugmentationPipeline, "STAGE_ORDER")
|
||||
expected_stages = [
|
||||
"geometric",
|
||||
"degradation",
|
||||
"lighting",
|
||||
"texture",
|
||||
"blur",
|
||||
"noise",
|
||||
]
|
||||
assert AugmentationPipeline.STAGE_ORDER == expected_stages
|
||||
|
||||
def test_stage_mapping_defined(self) -> None:
|
||||
"""Test that all augmentation types are mapped to stages."""
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
assert hasattr(AugmentationPipeline, "STAGE_MAPPING")
|
||||
|
||||
expected_mappings = {
|
||||
"perspective_warp": "geometric",
|
||||
"wrinkle": "degradation",
|
||||
"edge_damage": "degradation",
|
||||
"stain": "degradation",
|
||||
"lighting_variation": "lighting",
|
||||
"shadow": "lighting",
|
||||
"paper_texture": "texture",
|
||||
"scanner_artifacts": "texture",
|
||||
"gaussian_blur": "blur",
|
||||
"motion_blur": "blur",
|
||||
"gaussian_noise": "noise",
|
||||
"salt_pepper": "noise",
|
||||
}
|
||||
|
||||
for aug_name, stage in expected_mappings.items():
|
||||
assert aug_name in AugmentationPipeline.STAGE_MAPPING
|
||||
assert AugmentationPipeline.STAGE_MAPPING[aug_name] == stage
|
||||
|
||||
def test_geometric_before_degradation(self) -> None:
|
||||
"""Test that geometric transforms are applied before degradation."""
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
stages = AugmentationPipeline.STAGE_ORDER
|
||||
geometric_idx = stages.index("geometric")
|
||||
degradation_idx = stages.index("degradation")
|
||||
|
||||
assert geometric_idx < degradation_idx
|
||||
|
||||
def test_noise_applied_last(self) -> None:
|
||||
"""Test that noise is applied last."""
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
stages = AugmentationPipeline.STAGE_ORDER
|
||||
assert stages[-1] == "noise"
|
||||
|
||||
|
||||
class TestAugmentationRegistry:
|
||||
"""Tests for augmentation registry."""
|
||||
|
||||
def test_registry_exists(self) -> None:
|
||||
"""Test that augmentation registry exists."""
|
||||
from shared.augmentation.pipeline import AUGMENTATION_REGISTRY
|
||||
|
||||
assert isinstance(AUGMENTATION_REGISTRY, dict)
|
||||
|
||||
def test_registry_contains_all_types(self) -> None:
|
||||
"""Test that registry contains all augmentation types."""
|
||||
from shared.augmentation.pipeline import AUGMENTATION_REGISTRY
|
||||
|
||||
expected_types = [
|
||||
"perspective_warp",
|
||||
"wrinkle",
|
||||
"edge_damage",
|
||||
"stain",
|
||||
"lighting_variation",
|
||||
"shadow",
|
||||
"gaussian_blur",
|
||||
"motion_blur",
|
||||
"gaussian_noise",
|
||||
"salt_pepper",
|
||||
"paper_texture",
|
||||
"scanner_artifacts",
|
||||
]
|
||||
|
||||
for aug_type in expected_types:
|
||||
assert aug_type in AUGMENTATION_REGISTRY, f"Missing: {aug_type}"
|
||||
|
||||
|
||||
class TestPipelinePreview:
|
||||
"""Tests for pipeline preview functionality."""
|
||||
|
||||
def test_preview_single_augmentation(self) -> None:
|
||||
"""Test previewing a single augmentation."""
|
||||
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
config = AugmentationConfig(
|
||||
gaussian_noise=AugmentationParams(
|
||||
enabled=True, probability=1.0, params={"std": (10, 10)}
|
||||
),
|
||||
)
|
||||
pipeline = AugmentationPipeline(config)
|
||||
|
||||
image = np.full((100, 100, 3), 128, dtype=np.uint8)
|
||||
preview = pipeline.preview(image, "gaussian_noise")
|
||||
|
||||
assert preview.shape == image.shape
|
||||
assert preview.dtype == np.uint8
|
||||
# Preview should modify the image
|
||||
assert not np.array_equal(preview, image)
|
||||
|
||||
def test_preview_unknown_augmentation_raises(self) -> None:
|
||||
"""Test that previewing unknown augmentation raises error."""
|
||||
from shared.augmentation.config import AugmentationConfig
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
config = AugmentationConfig()
|
||||
pipeline = AugmentationPipeline(config)
|
||||
|
||||
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown augmentation"):
|
||||
pipeline.preview(image, "non_existent_augmentation")
|
||||
|
||||
def test_preview_is_deterministic(self) -> None:
|
||||
"""Test that preview produces deterministic results."""
|
||||
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
config = AugmentationConfig(
|
||||
gaussian_noise=AugmentationParams(enabled=True),
|
||||
)
|
||||
pipeline = AugmentationPipeline(config)
|
||||
|
||||
image = np.full((100, 100, 3), 128, dtype=np.uint8)
|
||||
|
||||
preview1 = pipeline.preview(image, "gaussian_noise")
|
||||
preview2 = pipeline.preview(image, "gaussian_noise")
|
||||
|
||||
np.testing.assert_array_equal(preview1, preview2)
|
||||
|
||||
|
||||
class TestPipelineGetAvailableAugmentations:
|
||||
"""Tests for getting available augmentations."""
|
||||
|
||||
def test_get_available_augmentations(self) -> None:
|
||||
"""Test getting list of available augmentations."""
|
||||
from shared.augmentation.pipeline import get_available_augmentations
|
||||
|
||||
augmentations = get_available_augmentations()
|
||||
|
||||
assert isinstance(augmentations, list)
|
||||
assert len(augmentations) == 12
|
||||
|
||||
# Each item should have name, description, affects_geometry
|
||||
for aug in augmentations:
|
||||
assert "name" in aug
|
||||
assert "description" in aug
|
||||
assert "affects_geometry" in aug
|
||||
assert "stage" in aug
|
||||
|
||||
def test_get_available_augmentations_includes_all_types(self) -> None:
|
||||
"""Test that all augmentation types are included."""
|
||||
from shared.augmentation.pipeline import get_available_augmentations
|
||||
|
||||
augmentations = get_available_augmentations()
|
||||
names = [aug["name"] for aug in augmentations]
|
||||
|
||||
expected = [
|
||||
"perspective_warp",
|
||||
"wrinkle",
|
||||
"edge_damage",
|
||||
"stain",
|
||||
"lighting_variation",
|
||||
"shadow",
|
||||
"gaussian_blur",
|
||||
"motion_blur",
|
||||
"gaussian_noise",
|
||||
"salt_pepper",
|
||||
"paper_texture",
|
||||
"scanner_artifacts",
|
||||
]
|
||||
|
||||
for name in expected:
|
||||
assert name in names
|
||||
Reference in New Issue
Block a user