This commit is contained in:
Yaojia Wang
2026-01-30 00:44:21 +01:00
parent d2489a97d4
commit 33ada0350d
79 changed files with 9737 additions and 297 deletions

View File

@@ -0,0 +1 @@
# Tests for augmentation module

View File

@@ -0,0 +1,347 @@
"""
Tests for augmentation base module.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
from typing import Any
from unittest.mock import MagicMock
import numpy as np
import pytest
class TestAugmentationResult:
    """Exercise construction of the AugmentationResult dataclass."""

    def test_minimal_result(self) -> None:
        """Supplying only ``image`` leaves every other field at its default."""
        from shared.augmentation.base import AugmentationResult

        blank = np.zeros((100, 100, 3), dtype=np.uint8)
        outcome = AugmentationResult(image=blank)

        # The image is held by reference rather than copied.
        assert outcome.image is blank
        assert outcome.bboxes is None
        assert outcome.transform_matrix is None
        assert outcome.applied is True
        assert outcome.metadata is None

    def test_full_result(self) -> None:
        """Every field passed explicitly is readable back unchanged."""
        from shared.augmentation.base import AugmentationResult

        blank = np.zeros((100, 100, 3), dtype=np.uint8)
        boxes = np.array([[0, 0.5, 0.5, 0.1, 0.1]])
        identity = np.eye(3)
        fields = {
            "image": blank,
            "bboxes": boxes,
            "transform_matrix": identity,
            "applied": True,
            "metadata": {"applied_transform": "wrinkle"},
        }

        outcome = AugmentationResult(**fields)

        assert outcome.image is blank
        np.testing.assert_array_equal(outcome.bboxes, boxes)
        np.testing.assert_array_equal(outcome.transform_matrix, identity)
        assert outcome.applied is True
        assert outcome.metadata == {"applied_transform": "wrinkle"}

    def test_not_applied(self) -> None:
        """``applied=False`` is stored when the augmentation was skipped."""
        from shared.augmentation.base import AugmentationResult

        outcome = AugmentationResult(
            image=np.zeros((100, 100, 3), dtype=np.uint8),
            applied=False,
        )

        assert outcome.applied is False
class TestBaseAugmentation:
    """Contract checks for the BaseAugmentation abstract base class."""

    def test_cannot_instantiate_directly(self) -> None:
        """Instantiating the abstract base itself must raise TypeError."""
        from shared.augmentation.base import BaseAugmentation

        with pytest.raises(TypeError):
            BaseAugmentation({})  # type: ignore

    def test_subclass_must_implement_apply(self) -> None:
        """A subclass that omits ``apply`` cannot be instantiated."""
        from shared.augmentation.base import BaseAugmentation

        class WithoutApply(BaseAugmentation):
            # Only _validate_params is provided; apply is left out on purpose.
            name = "incomplete"

            def _validate_params(self) -> None:
                pass

        with pytest.raises(TypeError):
            WithoutApply({})  # type: ignore

    def test_subclass_must_implement_validate_params(self) -> None:
        """A subclass that omits ``_validate_params`` cannot be instantiated."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class WithoutValidation(BaseAugmentation):
            # Only apply is provided; _validate_params is left out on purpose.
            name = "incomplete"

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                return AugmentationResult(image=image)

        with pytest.raises(TypeError):
            WithoutValidation({})  # type: ignore

    def test_valid_subclass(self) -> None:
        """A complete subclass exposes its name, geometry flag and params."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class PassThrough(BaseAugmentation):
            name = "dummy"
            affects_geometry = False

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                return AugmentationResult(image=image, bboxes=bboxes)

        instance = PassThrough({"param1": "value1"})

        assert instance.name == "dummy"
        assert instance.affects_geometry is False
        assert instance.params == {"param1": "value1"}

    def test_apply_returns_augmentation_result(self) -> None:
        """``apply`` hands back an AugmentationResult carrying its inputs."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class PassThrough(BaseAugmentation):
            name = "dummy"

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                # Forward both inputs untouched.
                return AugmentationResult(image=image, bboxes=bboxes)

        instance = PassThrough({})
        blank = np.zeros((100, 100, 3), dtype=np.uint8)
        boxes = np.array([[0, 0.5, 0.5, 0.1, 0.1]])

        outcome = instance.apply(blank, boxes)

        assert isinstance(outcome, AugmentationResult)
        assert outcome.image is blank
        np.testing.assert_array_equal(outcome.bboxes, boxes)

    def test_affects_geometry_default(self) -> None:
        """``affects_geometry`` is False when a subclass does not set it."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class PassThrough(BaseAugmentation):
            # affects_geometry is intentionally not declared here.
            name = "dummy"

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                return AugmentationResult(image=image)

        assert PassThrough({}).affects_geometry is False

    def test_validate_params_called_on_init(self) -> None:
        """Constructing an augmentation invokes ``_validate_params``."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        # Mutable cell so the nested class can record the call.
        validation_called = {"called": False}

        class Recording(BaseAugmentation):
            name = "validating"

            def _validate_params(self) -> None:
                validation_called["called"] = True

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                return AugmentationResult(image=image)

        Recording({})

        assert validation_called["called"] is True

    def test_validate_params_raises_on_invalid(self) -> None:
        """A ValueError raised by ``_validate_params`` escapes the constructor."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class RequiresParam(BaseAugmentation):
            name = "strict"

            def _validate_params(self) -> None:
                if "required_param" not in self.params:
                    raise ValueError("required_param is required")

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                return AugmentationResult(image=image)

        with pytest.raises(ValueError, match="required_param"):
            RequiresParam({})

        # Supplying the parameter lets construction succeed.
        accepted = RequiresParam({"required_param": "value"})
        assert accepted.params["required_param"] == "value"

    def test_rng_usage(self) -> None:
        """Two generators with the same seed give the same drawn value."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class DrawsOnce(BaseAugmentation):
            name = "random"

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                generator = rng if rng is not None else np.random.default_rng()
                # Record a single draw so the caller can compare runs.
                drawn = generator.random()
                return AugmentationResult(
                    image=image,
                    metadata={"random_value": drawn},
                )

        instance = DrawsOnce({})
        blank = np.zeros((100, 100, 3), dtype=np.uint8)

        # Independent generators that share one seed.
        first_rng = np.random.default_rng(42)
        second_rng = np.random.default_rng(42)
        first = instance.apply(blank, rng=first_rng)
        second = instance.apply(blank, rng=second_rng)

        assert first.metadata is not None
        assert second.metadata is not None
        assert first.metadata["random_value"] == second.metadata["random_value"]
class TestAugmentationResultImmutability:
    """Check that a copy-then-modify augmentation leaves its inputs intact."""

    def test_image_not_modified(self) -> None:
        """The caller's image keeps its pixels after ``apply`` runs."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class WhitenCopy(BaseAugmentation):
            name = "modifying"

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                # Work on a duplicate so the input array is never written to.
                whitened = image.copy()
                whitened.fill(255)
                return AugmentationResult(image=whitened)

        instance = WhitenCopy({})
        source = np.zeros((100, 100, 3), dtype=np.uint8)
        snapshot = source.copy()

        outcome = instance.apply(source)

        # Input is untouched ...
        np.testing.assert_array_equal(source, snapshot)
        # ... while the returned image is fully white.
        assert np.all(outcome.image == 255)

    def test_bboxes_not_modified(self) -> None:
        """The caller's bboxes keep their values after ``apply`` runs."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class HalveBoxes(BaseAugmentation):
            name = "bbox_modifying"
            affects_geometry = True

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                if bboxes is None:
                    return AugmentationResult(image=image)
                # Halve every column except the first, on a duplicate array.
                halved = bboxes.copy()
                halved[:, 1:] *= 0.5
                return AugmentationResult(image=image, bboxes=halved)

        instance = HalveBoxes({})
        blank = np.zeros((100, 100, 3), dtype=np.uint8)
        source_boxes = np.array([[0, 0.5, 0.5, 0.2, 0.2]], dtype=np.float32)
        snapshot = source_boxes.copy()

        outcome = instance.apply(blank, source_boxes)

        # Input is untouched ...
        np.testing.assert_array_equal(source_boxes, snapshot)
        # ... while the returned boxes carry the halved values.
        assert outcome.bboxes is not None
        np.testing.assert_array_almost_equal(
            outcome.bboxes, np.array([[0, 0.25, 0.25, 0.1, 0.1]])
        )

View File

@@ -0,0 +1,283 @@
"""
Tests for augmentation configuration module.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
from typing import Any
import pytest
class TestAugmentationParams:
    """Exercise the AugmentationParams dataclass."""

    def test_default_values(self) -> None:
        """A bare instance is disabled, has probability 0.5 and empty params."""
        from shared.augmentation.config import AugmentationParams

        defaults = AugmentationParams()

        assert defaults.enabled is False
        assert defaults.probability == 0.5
        assert defaults.params == {}

    def test_custom_values(self) -> None:
        """Explicit constructor arguments are stored as given."""
        from shared.augmentation.config import AugmentationParams

        custom = AugmentationParams(
            enabled=True,
            probability=0.8,
            params={"intensity": 0.5, "num_wrinkles": (2, 5)},
        )

        assert custom.enabled is True
        assert custom.probability == 0.8
        assert custom.params["intensity"] == 0.5
        assert custom.params["num_wrinkles"] == (2, 5)

    def test_immutability_params_dict(self) -> None:
        """Each instance gets its own ``params`` dict."""
        from shared.augmentation.config import AugmentationParams

        first = AugmentationParams()
        second = AugmentationParams()

        # A write through one instance must not show up in the other.
        first.params["test"] = "value"

        assert "test" not in second.params

    def test_to_dict(self) -> None:
        """``to_dict`` yields a plain dict with all three fields."""
        from shared.augmentation.config import AugmentationParams

        serialised = AugmentationParams(
            enabled=True,
            probability=0.7,
            params={"key": "value"},
        ).to_dict()

        assert serialised == {
            "enabled": True,
            "probability": 0.7,
            "params": {"key": "value"},
        }

    def test_from_dict(self) -> None:
        """``from_dict`` restores every field present in the mapping."""
        from shared.augmentation.config import AugmentationParams

        restored = AugmentationParams.from_dict(
            {
                "enabled": True,
                "probability": 0.6,
                "params": {"intensity": 0.3},
            }
        )

        assert restored.enabled is True
        assert restored.probability == 0.6
        assert restored.params == {"intensity": 0.3}

    def test_from_dict_with_defaults(self) -> None:
        """Keys absent from the mapping fall back to the defaults."""
        from shared.augmentation.config import AugmentationParams

        partial: dict[str, Any] = {"enabled": True}

        restored = AugmentationParams.from_dict(partial)

        assert restored.enabled is True
        assert restored.probability == 0.5  # default
        assert restored.params == {}  # default
class TestAugmentationConfig:
    """Exercise the AugmentationConfig dataclass."""

    def test_default_values(self) -> None:
        """Every known augmentation has an entry with the three param fields."""
        from shared.augmentation.config import AugmentationConfig

        config = AugmentationConfig()

        # One attribute per augmentation type is expected on the config.
        augmentation_types = (
            "perspective_warp",
            "wrinkle",
            "edge_damage",
            "stain",
            "lighting_variation",
            "shadow",
            "gaussian_blur",
            "motion_blur",
            "gaussian_noise",
            "salt_pepper",
            "paper_texture",
            "scanner_artifacts",
        )
        for aug_type in augmentation_types:
            assert hasattr(config, aug_type), f"Missing augmentation type: {aug_type}"
            entry = getattr(config, aug_type)
            for field_name in ("enabled", "probability", "params"):
                assert hasattr(entry, field_name)

    def test_global_settings_defaults(self) -> None:
        """By default bboxes are preserved and no seed is set."""
        from shared.augmentation.config import AugmentationConfig

        config = AugmentationConfig()

        assert config.preserve_bboxes is True
        assert config.seed is None

    def test_custom_seed(self) -> None:
        """A seed given to the constructor is stored on the config."""
        from shared.augmentation.config import AugmentationConfig

        assert AugmentationConfig(seed=42).seed == 42

    def test_to_dict(self) -> None:
        """``to_dict`` includes the global settings and augmentation entries."""
        from shared.augmentation.config import AugmentationConfig

        serialised = AugmentationConfig(seed=123, preserve_bboxes=False).to_dict()

        assert isinstance(serialised, dict)
        assert serialised["seed"] == 123
        assert serialised["preserve_bboxes"] is False
        assert "perspective_warp" in serialised
        assert "wrinkle" in serialised

    def test_from_dict(self) -> None:
        """``from_dict`` restores global settings and nested params."""
        from shared.augmentation.config import AugmentationConfig

        payload = {
            "seed": 456,
            "preserve_bboxes": False,
            "wrinkle": {
                "enabled": True,
                "probability": 0.8,
                "params": {"intensity": 0.5},
            },
        }

        config = AugmentationConfig.from_dict(payload)

        assert config.seed == 456
        assert config.preserve_bboxes is False
        wrinkle = config.wrinkle
        assert wrinkle.enabled is True
        assert wrinkle.probability == 0.8
        assert wrinkle.params["intensity"] == 0.5

    def test_from_dict_with_partial_data(self) -> None:
        """Anything missing from the mapping keeps its default."""
        from shared.augmentation.config import AugmentationConfig

        payload: dict[str, Any] = {
            "wrinkle": {"enabled": True},
        }

        config = AugmentationConfig.from_dict(payload)

        # Taken from the payload.
        assert config.wrinkle.enabled is True
        # Left at defaults.
        assert config.preserve_bboxes is True
        assert config.seed is None
        assert config.gaussian_blur.enabled is False

    def test_get_enabled_augmentations(self) -> None:
        """Only augmentations flagged ``enabled`` are reported."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams

        config = AugmentationConfig(
            wrinkle=AugmentationParams(enabled=True),
            stain=AugmentationParams(enabled=True),
            gaussian_blur=AugmentationParams(enabled=False),
        )

        active = config.get_enabled_augmentations()

        assert "wrinkle" in active
        assert "stain" in active
        assert "gaussian_blur" not in active

    def test_document_safe_defaults(self) -> None:
        """Default warp and noise magnitudes stay within conservative bounds."""
        from shared.augmentation.config import AugmentationConfig

        config = AugmentationConfig()

        # Warp magnitude is capped at 0.05 (0.02 assumed when unset).
        max_warp = config.perspective_warp.params.get("max_warp", 0.02)
        assert max_warp <= 0.05

        # The upper noise std is only checked when std is given as a tuple.
        noise_std = config.gaussian_noise.params.get("std", (5, 15))
        if isinstance(noise_std, tuple):
            assert noise_std[1] <= 20  # Max std should be low

    def test_immutability_between_instances(self) -> None:
        """Two configs do not share nested ``params`` dicts."""
        from shared.augmentation.config import AugmentationConfig

        first = AugmentationConfig()
        second = AugmentationConfig()

        # A write through one config must not show up in the other.
        first.wrinkle.params["test"] = "value"

        assert "test" not in second.wrinkle.params
class TestAugmentationConfigValidation:
    """Exercise probability handling and ``AugmentationConfig.validate``."""

    def test_probability_range_validation(self) -> None:
        """In-range probabilities, including both bounds, are stored as given."""
        from shared.augmentation.config import AugmentationParams

        # Interior value.
        midpoint = AugmentationParams(probability=0.5)
        assert midpoint.probability == 0.5

        # Lower and upper bounds.
        lower = AugmentationParams(probability=0.0)
        assert lower.probability == 0.0
        upper = AugmentationParams(probability=1.0)
        assert upper.probability == 1.0

    def test_config_validate_method(self) -> None:
        """``validate`` raises ValueError for a probability above 1."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams

        too_high = AugmentationConfig(
            wrinkle=AugmentationParams(probability=1.5),  # Invalid
        )

        with pytest.raises(ValueError, match="probability"):
            too_high.validate()

    def test_config_validate_negative_probability(self) -> None:
        """``validate`` raises ValueError for a probability below 0."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams

        negative = AugmentationConfig(
            wrinkle=AugmentationParams(probability=-0.1),
        )

        with pytest.raises(ValueError, match="probability"):
            negative.validate()

View File

@@ -0,0 +1,338 @@
"""
Tests for augmentation pipeline module.
TDD Phase 2: RED - Write tests first, then implement to pass.
"""
from typing import Any
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
class TestAugmentationPipeline:
    """Tests for AugmentationPipeline class."""

    def test_create_with_config(self) -> None:
        """The pipeline keeps a reference to the config it was built with."""
        from shared.augmentation.config import AugmentationConfig
        from shared.augmentation.pipeline import AugmentationPipeline

        config = AugmentationConfig()
        pipeline = AugmentationPipeline(config)

        assert pipeline.config is config

    def test_create_with_seed(self) -> None:
        """A seed set on the config is visible through the pipeline."""
        from shared.augmentation.config import AugmentationConfig
        from shared.augmentation.pipeline import AugmentationPipeline

        config = AugmentationConfig(seed=42)
        pipeline = AugmentationPipeline(config)

        assert pipeline.config.seed == 42

    def test_apply_returns_augmentation_result(self) -> None:
        """``apply`` returns an AugmentationResult with a same-shaped image."""
        from shared.augmentation.base import AugmentationResult
        from shared.augmentation.config import AugmentationConfig
        from shared.augmentation.pipeline import AugmentationPipeline

        config = AugmentationConfig()
        pipeline = AugmentationPipeline(config)
        image = np.zeros((100, 100, 3), dtype=np.uint8)

        result = pipeline.apply(image)

        assert isinstance(result, AugmentationResult)
        assert result.image is not None
        assert result.image.shape == image.shape

    def test_apply_with_bboxes(self) -> None:
        """Bounding boxes passed to ``apply`` are present on the result."""
        from shared.augmentation.config import AugmentationConfig
        from shared.augmentation.pipeline import AugmentationPipeline

        config = AugmentationConfig()
        pipeline = AugmentationPipeline(config)
        image = np.zeros((100, 100, 3), dtype=np.uint8)
        bboxes = np.array([[0, 0.5, 0.5, 0.1, 0.1]], dtype=np.float32)

        result = pipeline.apply(image, bboxes)

        # Bboxes should be preserved when preserve_bboxes=True
        assert result.bboxes is not None

    def test_apply_no_augmentations_enabled(self) -> None:
        """With nothing enabled, the output pixels equal the input pixels."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams
        from shared.augmentation.pipeline import AugmentationPipeline

        # Only lighting_variation is switched off explicitly; every other
        # augmentation is left at its config default.
        # NOTE(review): this relies on those defaults being disabled - confirm.
        config = AugmentationConfig(
            lighting_variation=AugmentationParams(enabled=False),
        )
        pipeline = AugmentationPipeline(config)

        # Seeded generator so the input is identical on every run, covering
        # the full uint8 range (the upper bound of integers() is exclusive).
        rng = np.random.default_rng(0)
        image = rng.integers(0, 256, size=(100, 100, 3), dtype=np.uint8)

        result = pipeline.apply(image)

        # Image should be unchanged (or a copy)
        np.testing.assert_array_equal(result.image, image)

    def test_apply_does_not_mutate_input(self) -> None:
        """The caller's image is unchanged after an augmentation is applied."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams
        from shared.augmentation.pipeline import AugmentationPipeline

        config = AugmentationConfig(
            lighting_variation=AugmentationParams(enabled=True, probability=1.0),
        )
        pipeline = AugmentationPipeline(config)
        image = np.full((100, 100, 3), 128, dtype=np.uint8)
        original_copy = image.copy()

        pipeline.apply(image)

        np.testing.assert_array_equal(image, original_copy)

    def test_reproducibility_with_seed(self) -> None:
        """Two pipelines configured with the same seed give equal images."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams
        from shared.augmentation.pipeline import AugmentationPipeline

        config1 = AugmentationConfig(
            seed=42,
            gaussian_noise=AugmentationParams(enabled=True, probability=1.0),
        )
        config2 = AugmentationConfig(
            seed=42,
            gaussian_noise=AugmentationParams(enabled=True, probability=1.0),
        )
        pipeline1 = AugmentationPipeline(config1)
        pipeline2 = AugmentationPipeline(config2)
        image = np.full((100, 100, 3), 128, dtype=np.uint8)

        # Each pipeline gets its own copy of the same input.
        result1 = pipeline1.apply(image.copy())
        result2 = pipeline2.apply(image.copy())

        np.testing.assert_array_equal(result1.image, result2.image)

    def test_metadata_contains_applied_augmentations(self) -> None:
        """Result metadata names every augmentation that was applied."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams
        from shared.augmentation.pipeline import AugmentationPipeline

        config = AugmentationConfig(
            seed=42,
            gaussian_noise=AugmentationParams(enabled=True, probability=1.0),
            lighting_variation=AugmentationParams(enabled=True, probability=1.0),
        )
        pipeline = AugmentationPipeline(config)
        image = np.full((100, 100, 3), 128, dtype=np.uint8)

        result = pipeline.apply(image)

        assert result.metadata is not None
        assert "applied_augmentations" in result.metadata
        # Both should be applied with probability=1.0
        assert "gaussian_noise" in result.metadata["applied_augmentations"]
        assert "lighting_variation" in result.metadata["applied_augmentations"]
class TestAugmentationPipelineStageOrder:
    """Check the stage sequence and stage mapping declared on the pipeline."""

    def test_stage_order_defined(self) -> None:
        """``STAGE_ORDER`` exists and lists the six stages in order."""
        from shared.augmentation.pipeline import AugmentationPipeline

        assert hasattr(AugmentationPipeline, "STAGE_ORDER")

        assert AugmentationPipeline.STAGE_ORDER == [
            "geometric",
            "degradation",
            "lighting",
            "texture",
            "blur",
            "noise",
        ]

    def test_stage_mapping_defined(self) -> None:
        """``STAGE_MAPPING`` assigns each augmentation to its expected stage."""
        from shared.augmentation.pipeline import AugmentationPipeline

        assert hasattr(AugmentationPipeline, "STAGE_MAPPING")

        stage_by_augmentation = {
            "perspective_warp": "geometric",
            "wrinkle": "degradation",
            "edge_damage": "degradation",
            "stain": "degradation",
            "lighting_variation": "lighting",
            "shadow": "lighting",
            "paper_texture": "texture",
            "scanner_artifacts": "texture",
            "gaussian_blur": "blur",
            "motion_blur": "blur",
            "gaussian_noise": "noise",
            "salt_pepper": "noise",
        }
        for aug_name in stage_by_augmentation:
            assert aug_name in AugmentationPipeline.STAGE_MAPPING
            assert (
                AugmentationPipeline.STAGE_MAPPING[aug_name]
                == stage_by_augmentation[aug_name]
            )

    def test_geometric_before_degradation(self) -> None:
        """The geometric stage comes earlier than the degradation stage."""
        from shared.augmentation.pipeline import AugmentationPipeline

        order = AugmentationPipeline.STAGE_ORDER

        assert order.index("geometric") < order.index("degradation")

    def test_noise_applied_last(self) -> None:
        """The final entry of ``STAGE_ORDER`` is the noise stage."""
        from shared.augmentation.pipeline import AugmentationPipeline

        order = AugmentationPipeline.STAGE_ORDER

        assert order[-1] == "noise"
class TestAugmentationRegistry:
    """Check the module-level ``AUGMENTATION_REGISTRY`` mapping."""

    def test_registry_exists(self) -> None:
        """The registry is importable and is a dict."""
        from shared.augmentation.pipeline import AUGMENTATION_REGISTRY

        assert isinstance(AUGMENTATION_REGISTRY, dict)

    def test_registry_contains_all_types(self) -> None:
        """Each of the twelve augmentation names is a key in the registry."""
        from shared.augmentation.pipeline import AUGMENTATION_REGISTRY

        required_names = (
            "perspective_warp",
            "wrinkle",
            "edge_damage",
            "stain",
            "lighting_variation",
            "shadow",
            "gaussian_blur",
            "motion_blur",
            "gaussian_noise",
            "salt_pepper",
            "paper_texture",
            "scanner_artifacts",
        )
        for aug_type in required_names:
            assert aug_type in AUGMENTATION_REGISTRY, f"Missing: {aug_type}"
class TestPipelinePreview:
    """Exercise ``AugmentationPipeline.preview``."""

    def test_preview_single_augmentation(self) -> None:
        """Previewing gaussian noise returns a changed uint8 image of equal shape."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams
        from shared.augmentation.pipeline import AugmentationPipeline

        noise_params = AugmentationParams(
            enabled=True, probability=1.0, params={"std": (10, 10)}
        )
        pipeline = AugmentationPipeline(AugmentationConfig(gaussian_noise=noise_params))
        grey = np.full((100, 100, 3), 128, dtype=np.uint8)

        rendered = pipeline.preview(grey, "gaussian_noise")

        assert rendered.shape == grey.shape
        assert rendered.dtype == np.uint8
        # The preview must differ from the input somewhere.
        assert not np.array_equal(rendered, grey)

    def test_preview_unknown_augmentation_raises(self) -> None:
        """An unrecognised augmentation name raises ValueError."""
        from shared.augmentation.config import AugmentationConfig
        from shared.augmentation.pipeline import AugmentationPipeline

        pipeline = AugmentationPipeline(AugmentationConfig())
        blank = np.zeros((100, 100, 3), dtype=np.uint8)

        with pytest.raises(ValueError, match="Unknown augmentation"):
            pipeline.preview(blank, "non_existent_augmentation")

    def test_preview_is_deterministic(self) -> None:
        """Two previews of the same image and augmentation are identical."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams
        from shared.augmentation.pipeline import AugmentationPipeline

        pipeline = AugmentationPipeline(
            AugmentationConfig(
                gaussian_noise=AugmentationParams(enabled=True),
            )
        )
        grey = np.full((100, 100, 3), 128, dtype=np.uint8)

        first = pipeline.preview(grey, "gaussian_noise")
        second = pipeline.preview(grey, "gaussian_noise")

        np.testing.assert_array_equal(first, second)
class TestPipelineGetAvailableAugmentations:
    """Exercise the ``get_available_augmentations`` helper."""

    def test_get_available_augmentations(self) -> None:
        """The helper returns a list of twelve fully described entries."""
        from shared.augmentation.pipeline import get_available_augmentations

        catalogue = get_available_augmentations()

        assert isinstance(catalogue, list)
        assert len(catalogue) == 12
        # Every entry carries name, description, affects_geometry and stage.
        for entry in catalogue:
            assert "name" in entry
            assert "description" in entry
            assert "affects_geometry" in entry
            assert "stage" in entry

    def test_get_available_augmentations_includes_all_types(self) -> None:
        """Each of the twelve augmentation names appears among the entries."""
        from shared.augmentation.pipeline import get_available_augmentations

        listed_names = [entry["name"] for entry in get_available_augmentations()]

        required_names = (
            "perspective_warp",
            "wrinkle",
            "edge_damage",
            "stain",
            "lighting_variation",
            "shadow",
            "gaussian_blur",
            "motion_blur",
            "gaussian_noise",
            "salt_pepper",
            "paper_texture",
            "scanner_artifacts",
        )
        for required in required_names:
            assert required in listed_names

View File

@@ -0,0 +1,102 @@
"""
Tests for augmentation presets module.
TDD Phase 4: RED - Write tests first, then implement to pass.
"""
import pytest
class TestPresets:
    """Exercise the preset catalogue and its helper functions."""

    def test_presets_dict_exists(self) -> None:
        """``PRESETS`` is a non-empty dict."""
        from shared.augmentation.presets import PRESETS

        assert isinstance(PRESETS, dict)
        assert len(PRESETS) > 0

    def test_expected_presets_exist(self) -> None:
        """The four named presets are all keys of ``PRESETS``."""
        from shared.augmentation.presets import PRESETS

        required = ("conservative", "moderate", "aggressive", "scanned_document")
        for preset_name in required:
            assert preset_name in PRESETS, f"Missing preset: {preset_name}"

    def test_preset_structure(self) -> None:
        """Each preset has a string description and a dict config."""
        from shared.augmentation.presets import PRESETS

        for name in PRESETS:
            preset = PRESETS[name]
            assert "description" in preset, f"Preset {name} missing description"
            assert "config" in preset, f"Preset {name} missing config"
            assert isinstance(preset["description"], str)
            assert isinstance(preset["config"], dict)

    def test_get_preset_config(self) -> None:
        """``get_preset_config`` returns a non-empty config for a known name."""
        from shared.augmentation.presets import get_preset_config

        conservative = get_preset_config("conservative")

        assert conservative is not None
        # At least one entry must be defined.
        assert len(conservative) > 0

    def test_get_preset_config_unknown_raises(self) -> None:
        """An unknown preset name raises ValueError."""
        from shared.augmentation.presets import get_preset_config

        with pytest.raises(ValueError, match="Unknown preset"):
            get_preset_config("nonexistent_preset")

    def test_create_config_from_preset(self) -> None:
        """``create_config_from_preset`` builds an AugmentationConfig."""
        from shared.augmentation.config import AugmentationConfig
        from shared.augmentation.presets import create_config_from_preset

        built = create_config_from_preset("moderate")

        assert isinstance(built, AugmentationConfig)

    def test_conservative_preset_is_safe(self) -> None:
        """The conservative preset leaves warp and degradation disabled."""
        from shared.augmentation.presets import create_config_from_preset

        built = create_config_from_preset("conservative")

        # Geometric transform stays off.
        assert built.perspective_warp.enabled is False
        # Degradation effects stay off.
        assert built.wrinkle.enabled is False
        assert built.edge_damage.enabled is False
        assert built.stain.enabled is False

    def test_aggressive_preset_enables_more(self) -> None:
        """The aggressive preset turns on at least six augmentations."""
        from shared.augmentation.presets import create_config_from_preset

        active = create_config_from_preset("aggressive").get_enabled_augmentations()

        assert len(active) >= 6

    def test_list_presets(self) -> None:
        """``list_presets`` returns at least four named, described entries."""
        from shared.augmentation.presets import list_presets

        catalogue = list_presets()

        assert isinstance(catalogue, list)
        assert len(catalogue) >= 4
        # Every entry carries a name and a description.
        for entry in catalogue:
            assert "name" in entry
            assert "description" in entry

View File

@@ -0,0 +1 @@
# Tests for augmentation transforms

View File

@@ -0,0 +1,293 @@
"""
Tests for DatasetAugmenter.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
import tempfile
from pathlib import Path
import numpy as np
import pytest
from PIL import Image
class TestDatasetAugmenter:
"""Tests for DatasetAugmenter class."""
@pytest.fixture
def sample_dataset(self, tmp_path: Path) -> Path:
    """Build a small YOLO-layout dataset under ``tmp_path`` and return its root."""
    root = tmp_path / "dataset"

    # images/<split> and labels/<split> for each of the three splits.
    for split in ("train", "val", "test"):
        for kind in ("images", "labels"):
            (root / kind / split).mkdir(parents=True)

    # Each label line is: class_id x_center y_center width height.
    two_boxes = "0 0.5 0.3 0.2 0.1\n1 0.7 0.6 0.15 0.2\n"
    train_images = root / "images" / "train"
    train_labels = root / "labels" / "train"
    for index in range(3):
        # 100x100 white image paired with a two-box label file.
        Image.new("RGB", (100, 100), color="white").save(
            train_images / f"doc_{index}.png"
        )
        (train_labels / f"doc_{index}.txt").write_text(two_boxes)

    # Dataset descriptor listing the split directories and ten class names.
    yaml_lines = [
        "path: .",
        "train: images/train",
        "val: images/val",
        "test: images/test",
        "nc: 10",
        "names: [class0, class1, class2, class3, class4, class5, class6, class7, class8, class9]",
    ]
    (root / "data.yaml").write_text("\n".join(yaml_lines) + "\n")

    return root
@pytest.fixture
def augmentation_config(self) -> dict:
    """Return a config dict enabling gaussian noise and blur at probability 1.0."""
    per_augmentation_params = (
        ("gaussian_noise", {"std": 10}),
        ("gaussian_blur", {"kernel_size": 3}),
    )
    # Both entries share the same enabled/probability settings.
    return {
        aug_name: {"enabled": True, "probability": 1.0, "params": aug_params}
        for aug_name, aug_params in per_augmentation_params
    }
def test_augmenter_creates_additional_images(
    self, sample_dataset: Path, augmentation_config: dict
):
    """Augmenting with multiplier=2 adds two new images per original."""
    from shared.augmentation.dataset_augmenter import DatasetAugmenter

    train_images = sample_dataset / "images" / "train"
    augmenter = DatasetAugmenter(augmentation_config)

    # The fixture starts with three training images.
    images_before = list(train_images.glob("*.png"))
    assert len(images_before) == 3

    summary = augmenter.augment_dataset(sample_dataset, multiplier=2)

    # 3 originals + 2 augmented copies of each = 9 files on disk.
    images_after = list(train_images.glob("*.png"))
    assert len(images_after) == 9
    assert summary["augmented_images"] == 6
def test_augmenter_creates_matching_labels(
    self, sample_dataset: Path, augmentation_config: dict
):
    """After augmentation every train PNG must have a same-stem ``.txt`` label."""
    from shared.augmentation.dataset_augmenter import DatasetAugmenter

    DatasetAugmenter(augmentation_config).augment_dataset(
        sample_dataset, multiplier=2
    )

    image_dir = sample_dataset / "images" / "train"
    label_dir = sample_dataset / "labels" / "train"
    image_files = list(image_dir.glob("*.png"))
    label_files = list(label_dir.glob("*.txt"))

    # Counts must agree ...
    assert len(image_files) == len(label_files)

    # ... and each image must pair with a label of the same stem.
    for image_file in image_files:
        expected_label = label_dir / f"{image_file.stem}.txt"
        assert expected_label.exists(), f"Missing label for {image_file.name}"
def test_augmented_labels_have_valid_format(
    self, sample_dataset: Path, augmentation_config: dict
):
    """Check that every train label line is a valid YOLO record after augmentation.

    A valid line has five whitespace-separated fields:
    ``class_id x_center y_center width height``. The class id must be an
    integer below the class count declared by the dataset (``nc: 10`` in the
    ``sample_dataset`` fixture), and the four geometry values must lie in
    ``[0, 1]``. Empty label files are accepted as background images.
    """
    from shared.augmentation.dataset_augmenter import DatasetAugmenter

    # Upper bound for class ids; mirrors ``nc: 10`` in the fixture's data.yaml
    # rather than an arbitrary large constant.
    num_classes = 10

    augmenter = DatasetAugmenter(augmentation_config)
    augmenter.augment_dataset(sample_dataset, multiplier=1)

    labels_dir = sample_dataset / "labels" / "train"

    # Guard against a vacuous pass: without this, the loop below would only
    # re-validate the original labels if no augmented ones were written.
    assert any(labels_dir.glob("*_aug*.txt")), "No augmented label files were written"

    for label_path in labels_dir.glob("*.txt"):
        content = label_path.read_text().strip()
        if not content:
            continue  # Empty labels are valid (background images)

        for line in content.splitlines():
            parts = line.split()
            assert len(parts) == 5, f"Invalid label format in {label_path.name}"

            class_id = int(parts[0])
            x_center, y_center, width, height = (float(p) for p in parts[1:])

            # Check values are in valid range
            assert 0 <= class_id < num_classes, f"Invalid class_id: {class_id}"
            assert 0 <= x_center <= 1, f"Invalid x_center: {x_center}"
            assert 0 <= y_center <= 1, f"Invalid y_center: {y_center}"
            assert 0 <= width <= 1, f"Invalid width: {width}"
            assert 0 <= height <= 1, f"Invalid height: {height}"
def test_augmented_images_are_different(
    self, sample_dataset: Path, augmentation_config: dict
):
    """The augmented copy of ``doc_0`` must differ pixel-wise from its source."""
    from shared.augmentation.dataset_augmenter import DatasetAugmenter

    train_images = sample_dataset / "images" / "train"

    # Snapshot the source pixels before anything touches the dataset.
    pixels_before = np.array(Image.open(train_images / "doc_0.png"))

    DatasetAugmenter(augmentation_config).augment_dataset(
        sample_dataset, multiplier=1
    )

    # The first augmented copy is expected at ``doc_0_aug0.png``.
    augmented_file = train_images / "doc_0_aug0.png"
    assert augmented_file.exists()
    pixels_after = np.array(Image.open(augmented_file))

    # Noise/blur at probability 1.0 should have altered the image.
    assert not np.array_equal(pixels_before, pixels_after)
def test_augmented_images_same_size(
    self, sample_dataset: Path, augmentation_config: dict
):
    """Check that every augmented train image keeps the size of ``doc_0.png``.

    Images are opened through context managers: ``Image.open`` is lazy, and
    reading only ``.size`` never loads the pixel data, so without an explicit
    close the underlying file handles would stay open.
    """
    from shared.augmentation.dataset_augmenter import DatasetAugmenter

    train_dir = sample_dataset / "images" / "train"

    # Get original size
    with Image.open(train_dir / "doc_0.png") as original_img:
        original_size = original_img.size

    augmenter = DatasetAugmenter(augmentation_config)
    augmenter.augment_dataset(sample_dataset, multiplier=1)

    # Guard against a vacuous pass: an empty glob would skip the loop entirely.
    augmented_paths = list(train_dir.glob("*_aug*.png"))
    assert augmented_paths, "No augmented images were produced"

    # Check all augmented images have same size
    for img_path in augmented_paths:
        with Image.open(img_path) as img:
            assert img.size == original_size, f"{img_path.name} has wrong size"
def test_perspective_warp_updates_bboxes(self, sample_dataset: Path):
    """Check that perspective_warp changes bbox geometry but keeps class ids.

    Coordinates are parsed to floats and compared with a small tolerance.
    Comparing the raw text tokens would report a "difference" whenever the
    label is merely re-formatted (e.g. ``0.5`` vs ``0.500000``), letting the
    test pass even if the boxes were never transformed.
    """
    from shared.augmentation.dataset_augmenter import DatasetAugmenter

    config = {
        "perspective_warp": {
            "enabled": True,
            "probability": 1.0,
            "params": {"max_warp": 0.05},  # Use larger warp for visible difference
        },
    }

    labels_dir = sample_dataset / "labels" / "train"

    # Read original label
    original_label = (labels_dir / "doc_0.txt").read_text()
    original_bboxes = [line.split() for line in original_label.strip().split("\n")]

    augmenter = DatasetAugmenter(config)
    augmenter.augment_dataset(sample_dataset, multiplier=1)

    # Read augmented label
    aug_label = (labels_dir / "doc_0_aug0.txt").read_text()
    aug_bboxes = [line.split() for line in aug_label.strip().split("\n")]

    # Same number of bboxes
    assert len(original_bboxes) == len(aug_bboxes)

    # At least one bbox should have numerically different coordinates
    # (perspective warp changes geometry)
    differences_found = False
    for orig, aug in zip(original_bboxes, aug_bboxes):
        # Class ID should be same
        assert orig[0] == aug[0]

        orig_coords = [float(value) for value in orig[1:]]
        aug_coords = [float(value) for value in aug[1:]]
        assert len(orig_coords) == len(aug_coords)

        # Numeric comparison: formatting or round-off alone is not a change.
        if not np.allclose(orig_coords, aug_coords, rtol=0.0, atol=1e-6):
            differences_found = True

    assert differences_found, "Perspective warp should change bbox coordinates"
def test_augmenter_only_processes_train_split(
    self, sample_dataset: Path, augmentation_config: dict
):
    """A default ``augment_dataset`` call must not add images to the val split."""
    from shared.augmentation.dataset_augmenter import DatasetAugmenter

    val_images = sample_dataset / "images" / "val"
    val_labels = sample_dataset / "labels" / "val"

    # Seed the val split with one image/label pair.
    Image.new("RGB", (100, 100), color="white").save(val_images / "val_doc.png")
    (val_labels / "val_doc.txt").write_text("0 0.5 0.5 0.1 0.1\n")

    DatasetAugmenter(augmentation_config).augment_dataset(
        sample_dataset, multiplier=2
    )

    # The val split must still hold exactly the one image added above.
    remaining = list(val_images.glob("*.png"))
    assert len(remaining) == 1
def test_augmenter_with_multiplier_zero_does_nothing(
    self, sample_dataset: Path, augmentation_config: dict
):
    """``multiplier=0`` must leave the train image count unchanged.

    The returned summary must also report zero augmented images.
    """
    from shared.augmentation.dataset_augmenter import DatasetAugmenter

    train_images = sample_dataset / "images" / "train"

    count_before = len(list(train_images.glob("*.png")))

    engine = DatasetAugmenter(augmentation_config)
    summary = engine.augment_dataset(sample_dataset, multiplier=0)

    count_after = len(list(train_images.glob("*.png")))

    assert count_after == count_before
    assert summary["augmented_images"] == 0
def test_augmenter_with_seed_is_reproducible(
    self, sample_dataset: Path, augmentation_config: dict
):
    """Two augmenters with ``seed=42`` must produce the same ``doc_0_aug0.png``.

    The fixture dataset is copied to a sibling ``dataset2`` directory and each
    copy is augmented by its own augmenter instance.
    """
    from shared.augmentation.dataset_augmenter import DatasetAugmenter
    import shutil

    # Clone the pristine dataset before either copy is modified.
    twin_dataset = sample_dataset.parent / "dataset2"
    shutil.copytree(sample_dataset, twin_dataset)
    dataset_roots = (sample_dataset, twin_dataset)

    # Fresh, identically seeded augmenter for each copy.
    for root in dataset_roots:
        DatasetAugmenter(augmentation_config, seed=42).augment_dataset(
            root, multiplier=1
        )

    # Load the first augmented image from each copy and compare pixels.
    first, second = (
        np.array(Image.open(root / "images" / "train" / "doc_0_aug0.png"))
        for root in dataset_roots
    )
    assert np.array_equal(first, second), "Same seed should produce same augmentation"
def test_augmenter_returns_summary(
    self, sample_dataset: Path, augmentation_config: dict
):
    """``augment_dataset`` must return original/augmented/total image counts.

    For three source images and multiplier=2 the expected values are 3, 6, 9.
    """
    from shared.augmentation.dataset_augmenter import DatasetAugmenter

    summary = DatasetAugmenter(augmentation_config).augment_dataset(
        sample_dataset, multiplier=2
    )

    expected_counts = {
        "original_images": 3,
        "augmented_images": 6,
        "total_images": 9,
    }

    # First make sure every key is present ...
    for key in expected_counts:
        assert key in summary

    # ... then check the reported numbers.
    for key, expected in expected_counts.items():
        assert summary[key] == expected