WIP
This commit is contained in:
1
tests/shared/augmentation/__init__.py
Normal file
1
tests/shared/augmentation/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Tests for augmentation module
|
||||
347
tests/shared/augmentation/test_base.py
Normal file
347
tests/shared/augmentation/test_base.py
Normal file
@@ -0,0 +1,347 @@
|
||||
"""
|
||||
Tests for augmentation base module.
|
||||
|
||||
TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
class TestAugmentationResult:
|
||||
"""Tests for AugmentationResult dataclass."""
|
||||
|
||||
def test_minimal_result(self) -> None:
|
||||
"""Test creating result with only required field."""
|
||||
from shared.augmentation.base import AugmentationResult
|
||||
|
||||
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
result = AugmentationResult(image=image)
|
||||
|
||||
assert result.image is image
|
||||
assert result.bboxes is None
|
||||
assert result.transform_matrix is None
|
||||
assert result.applied is True
|
||||
assert result.metadata is None
|
||||
|
||||
def test_full_result(self) -> None:
|
||||
"""Test creating result with all fields."""
|
||||
from shared.augmentation.base import AugmentationResult
|
||||
|
||||
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
bboxes = np.array([[0, 0.5, 0.5, 0.1, 0.1]])
|
||||
transform = np.eye(3)
|
||||
metadata = {"applied_transform": "wrinkle"}
|
||||
|
||||
result = AugmentationResult(
|
||||
image=image,
|
||||
bboxes=bboxes,
|
||||
transform_matrix=transform,
|
||||
applied=True,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
assert result.image is image
|
||||
np.testing.assert_array_equal(result.bboxes, bboxes)
|
||||
np.testing.assert_array_equal(result.transform_matrix, transform)
|
||||
assert result.applied is True
|
||||
assert result.metadata == {"applied_transform": "wrinkle"}
|
||||
|
||||
def test_not_applied(self) -> None:
|
||||
"""Test result when augmentation was not applied."""
|
||||
from shared.augmentation.base import AugmentationResult
|
||||
|
||||
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
result = AugmentationResult(image=image, applied=False)
|
||||
|
||||
assert result.applied is False
|
||||
|
||||
|
||||
class TestBaseAugmentation:
|
||||
"""Tests for BaseAugmentation abstract class."""
|
||||
|
||||
def test_cannot_instantiate_directly(self) -> None:
|
||||
"""Test that BaseAugmentation cannot be instantiated."""
|
||||
from shared.augmentation.base import BaseAugmentation
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
BaseAugmentation({}) # type: ignore
|
||||
|
||||
def test_subclass_must_implement_apply(self) -> None:
|
||||
"""Test that subclass must implement apply method."""
|
||||
from shared.augmentation.base import BaseAugmentation
|
||||
|
||||
class IncompleteAugmentation(BaseAugmentation):
|
||||
name = "incomplete"
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
pass
|
||||
|
||||
# Missing apply method
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
IncompleteAugmentation({}) # type: ignore
|
||||
|
||||
def test_subclass_must_implement_validate_params(self) -> None:
|
||||
"""Test that subclass must implement _validate_params."""
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
class IncompleteAugmentation(BaseAugmentation):
|
||||
name = "incomplete"
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
return AugmentationResult(image=image)
|
||||
|
||||
# Missing _validate_params method
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
IncompleteAugmentation({}) # type: ignore
|
||||
|
||||
def test_valid_subclass(self) -> None:
|
||||
"""Test creating a valid subclass."""
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
class DummyAugmentation(BaseAugmentation):
|
||||
name = "dummy"
|
||||
affects_geometry = False
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
pass
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
return AugmentationResult(image=image, bboxes=bboxes)
|
||||
|
||||
aug = DummyAugmentation({"param1": "value1"})
|
||||
|
||||
assert aug.name == "dummy"
|
||||
assert aug.affects_geometry is False
|
||||
assert aug.params == {"param1": "value1"}
|
||||
|
||||
def test_apply_returns_augmentation_result(self) -> None:
|
||||
"""Test that apply returns AugmentationResult."""
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
class DummyAugmentation(BaseAugmentation):
|
||||
name = "dummy"
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
pass
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
# Simple pass-through
|
||||
return AugmentationResult(image=image, bboxes=bboxes)
|
||||
|
||||
aug = DummyAugmentation({})
|
||||
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
bboxes = np.array([[0, 0.5, 0.5, 0.1, 0.1]])
|
||||
|
||||
result = aug.apply(image, bboxes)
|
||||
|
||||
assert isinstance(result, AugmentationResult)
|
||||
assert result.image is image
|
||||
np.testing.assert_array_equal(result.bboxes, bboxes)
|
||||
|
||||
def test_affects_geometry_default(self) -> None:
|
||||
"""Test that affects_geometry defaults to False."""
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
class DummyAugmentation(BaseAugmentation):
|
||||
name = "dummy"
|
||||
# Not setting affects_geometry
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
pass
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
return AugmentationResult(image=image)
|
||||
|
||||
aug = DummyAugmentation({})
|
||||
|
||||
assert aug.affects_geometry is False
|
||||
|
||||
def test_validate_params_called_on_init(self) -> None:
|
||||
"""Test that _validate_params is called during initialization."""
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
validation_called = {"called": False}
|
||||
|
||||
class ValidatingAugmentation(BaseAugmentation):
|
||||
name = "validating"
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
validation_called["called"] = True
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
return AugmentationResult(image=image)
|
||||
|
||||
ValidatingAugmentation({})
|
||||
|
||||
assert validation_called["called"] is True
|
||||
|
||||
def test_validate_params_raises_on_invalid(self) -> None:
|
||||
"""Test that _validate_params can raise ValueError."""
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
class StrictAugmentation(BaseAugmentation):
|
||||
name = "strict"
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
if "required_param" not in self.params:
|
||||
raise ValueError("required_param is required")
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
return AugmentationResult(image=image)
|
||||
|
||||
with pytest.raises(ValueError, match="required_param"):
|
||||
StrictAugmentation({})
|
||||
|
||||
# Should work with required param
|
||||
aug = StrictAugmentation({"required_param": "value"})
|
||||
assert aug.params["required_param"] == "value"
|
||||
|
||||
def test_rng_usage(self) -> None:
|
||||
"""Test that random generator can be passed and used."""
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
class RandomAugmentation(BaseAugmentation):
|
||||
name = "random"
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
pass
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
if rng is None:
|
||||
rng = np.random.default_rng()
|
||||
# Use rng to generate a random value
|
||||
random_value = rng.random()
|
||||
return AugmentationResult(
|
||||
image=image,
|
||||
metadata={"random_value": random_value},
|
||||
)
|
||||
|
||||
aug = RandomAugmentation({})
|
||||
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
|
||||
# With same seed, should get same result
|
||||
rng1 = np.random.default_rng(42)
|
||||
rng2 = np.random.default_rng(42)
|
||||
|
||||
result1 = aug.apply(image, rng=rng1)
|
||||
result2 = aug.apply(image, rng=rng2)
|
||||
|
||||
assert result1.metadata is not None
|
||||
assert result2.metadata is not None
|
||||
assert result1.metadata["random_value"] == result2.metadata["random_value"]
|
||||
|
||||
|
||||
class TestAugmentationResultImmutability:
|
||||
"""Tests for ensuring result doesn't mutate input."""
|
||||
|
||||
def test_image_not_modified(self) -> None:
|
||||
"""Test that original image is not modified."""
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
class ModifyingAugmentation(BaseAugmentation):
|
||||
name = "modifying"
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
pass
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
# Should copy before modifying
|
||||
modified = image.copy()
|
||||
modified[:] = 255
|
||||
return AugmentationResult(image=modified)
|
||||
|
||||
aug = ModifyingAugmentation({})
|
||||
original = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
original_copy = original.copy()
|
||||
|
||||
result = aug.apply(original)
|
||||
|
||||
# Original should be unchanged
|
||||
np.testing.assert_array_equal(original, original_copy)
|
||||
# Result should be modified
|
||||
assert np.all(result.image == 255)
|
||||
|
||||
def test_bboxes_not_modified(self) -> None:
|
||||
"""Test that original bboxes are not modified."""
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
class BboxModifyingAugmentation(BaseAugmentation):
|
||||
name = "bbox_modifying"
|
||||
affects_geometry = True
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
pass
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
if bboxes is not None:
|
||||
# Should copy before modifying
|
||||
modified_bboxes = bboxes.copy()
|
||||
modified_bboxes[:, 1:] *= 0.5 # Scale down
|
||||
return AugmentationResult(image=image, bboxes=modified_bboxes)
|
||||
return AugmentationResult(image=image)
|
||||
|
||||
aug = BboxModifyingAugmentation({})
|
||||
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
original_bboxes = np.array([[0, 0.5, 0.5, 0.2, 0.2]], dtype=np.float32)
|
||||
original_bboxes_copy = original_bboxes.copy()
|
||||
|
||||
result = aug.apply(image, original_bboxes)
|
||||
|
||||
# Original should be unchanged
|
||||
np.testing.assert_array_equal(original_bboxes, original_bboxes_copy)
|
||||
# Result should be modified
|
||||
assert result.bboxes is not None
|
||||
np.testing.assert_array_almost_equal(
|
||||
result.bboxes, np.array([[0, 0.25, 0.25, 0.1, 0.1]])
|
||||
)
|
||||
283
tests/shared/augmentation/test_config.py
Normal file
283
tests/shared/augmentation/test_config.py
Normal file
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
Tests for augmentation configuration module.
|
||||
|
||||
TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestAugmentationParams:
|
||||
"""Tests for AugmentationParams dataclass."""
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Test default parameter values."""
|
||||
from shared.augmentation.config import AugmentationParams
|
||||
|
||||
params = AugmentationParams()
|
||||
|
||||
assert params.enabled is False
|
||||
assert params.probability == 0.5
|
||||
assert params.params == {}
|
||||
|
||||
def test_custom_values(self) -> None:
|
||||
"""Test creating params with custom values."""
|
||||
from shared.augmentation.config import AugmentationParams
|
||||
|
||||
params = AugmentationParams(
|
||||
enabled=True,
|
||||
probability=0.8,
|
||||
params={"intensity": 0.5, "num_wrinkles": (2, 5)},
|
||||
)
|
||||
|
||||
assert params.enabled is True
|
||||
assert params.probability == 0.8
|
||||
assert params.params["intensity"] == 0.5
|
||||
assert params.params["num_wrinkles"] == (2, 5)
|
||||
|
||||
def test_immutability_params_dict(self) -> None:
|
||||
"""Test that params dict is independent between instances."""
|
||||
from shared.augmentation.config import AugmentationParams
|
||||
|
||||
params1 = AugmentationParams()
|
||||
params2 = AugmentationParams()
|
||||
|
||||
# Modifying one should not affect the other
|
||||
params1.params["test"] = "value"
|
||||
|
||||
assert "test" not in params2.params
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""Test conversion to dictionary."""
|
||||
from shared.augmentation.config import AugmentationParams
|
||||
|
||||
params = AugmentationParams(
|
||||
enabled=True,
|
||||
probability=0.7,
|
||||
params={"key": "value"},
|
||||
)
|
||||
|
||||
result = params.to_dict()
|
||||
|
||||
assert result == {
|
||||
"enabled": True,
|
||||
"probability": 0.7,
|
||||
"params": {"key": "value"},
|
||||
}
|
||||
|
||||
def test_from_dict(self) -> None:
|
||||
"""Test creation from dictionary."""
|
||||
from shared.augmentation.config import AugmentationParams
|
||||
|
||||
data = {
|
||||
"enabled": True,
|
||||
"probability": 0.6,
|
||||
"params": {"intensity": 0.3},
|
||||
}
|
||||
|
||||
params = AugmentationParams.from_dict(data)
|
||||
|
||||
assert params.enabled is True
|
||||
assert params.probability == 0.6
|
||||
assert params.params == {"intensity": 0.3}
|
||||
|
||||
def test_from_dict_with_defaults(self) -> None:
|
||||
"""Test creation from partial dictionary uses defaults."""
|
||||
from shared.augmentation.config import AugmentationParams
|
||||
|
||||
data: dict[str, Any] = {"enabled": True}
|
||||
|
||||
params = AugmentationParams.from_dict(data)
|
||||
|
||||
assert params.enabled is True
|
||||
assert params.probability == 0.5 # default
|
||||
assert params.params == {} # default
|
||||
|
||||
|
||||
class TestAugmentationConfig:
|
||||
"""Tests for AugmentationConfig dataclass."""
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Test that all augmentation types have defaults."""
|
||||
from shared.augmentation.config import AugmentationConfig
|
||||
|
||||
config = AugmentationConfig()
|
||||
|
||||
# All augmentation types should exist
|
||||
augmentation_types = [
|
||||
"perspective_warp",
|
||||
"wrinkle",
|
||||
"edge_damage",
|
||||
"stain",
|
||||
"lighting_variation",
|
||||
"shadow",
|
||||
"gaussian_blur",
|
||||
"motion_blur",
|
||||
"gaussian_noise",
|
||||
"salt_pepper",
|
||||
"paper_texture",
|
||||
"scanner_artifacts",
|
||||
]
|
||||
|
||||
for aug_type in augmentation_types:
|
||||
assert hasattr(config, aug_type), f"Missing augmentation type: {aug_type}"
|
||||
params = getattr(config, aug_type)
|
||||
assert hasattr(params, "enabled")
|
||||
assert hasattr(params, "probability")
|
||||
assert hasattr(params, "params")
|
||||
|
||||
def test_global_settings_defaults(self) -> None:
|
||||
"""Test global settings default values."""
|
||||
from shared.augmentation.config import AugmentationConfig
|
||||
|
||||
config = AugmentationConfig()
|
||||
|
||||
assert config.preserve_bboxes is True
|
||||
assert config.seed is None
|
||||
|
||||
def test_custom_seed(self) -> None:
|
||||
"""Test setting custom seed for reproducibility."""
|
||||
from shared.augmentation.config import AugmentationConfig
|
||||
|
||||
config = AugmentationConfig(seed=42)
|
||||
|
||||
assert config.seed == 42
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""Test conversion to dictionary."""
|
||||
from shared.augmentation.config import AugmentationConfig
|
||||
|
||||
config = AugmentationConfig(seed=123, preserve_bboxes=False)
|
||||
|
||||
result = config.to_dict()
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["seed"] == 123
|
||||
assert result["preserve_bboxes"] is False
|
||||
assert "perspective_warp" in result
|
||||
assert "wrinkle" in result
|
||||
|
||||
def test_from_dict(self) -> None:
|
||||
"""Test creation from dictionary."""
|
||||
from shared.augmentation.config import AugmentationConfig
|
||||
|
||||
data = {
|
||||
"seed": 456,
|
||||
"preserve_bboxes": False,
|
||||
"wrinkle": {
|
||||
"enabled": True,
|
||||
"probability": 0.8,
|
||||
"params": {"intensity": 0.5},
|
||||
},
|
||||
}
|
||||
|
||||
config = AugmentationConfig.from_dict(data)
|
||||
|
||||
assert config.seed == 456
|
||||
assert config.preserve_bboxes is False
|
||||
assert config.wrinkle.enabled is True
|
||||
assert config.wrinkle.probability == 0.8
|
||||
assert config.wrinkle.params["intensity"] == 0.5
|
||||
|
||||
def test_from_dict_with_partial_data(self) -> None:
|
||||
"""Test creation from partial dictionary uses defaults."""
|
||||
from shared.augmentation.config import AugmentationConfig
|
||||
|
||||
data: dict[str, Any] = {
|
||||
"wrinkle": {"enabled": True},
|
||||
}
|
||||
|
||||
config = AugmentationConfig.from_dict(data)
|
||||
|
||||
# Explicitly set value
|
||||
assert config.wrinkle.enabled is True
|
||||
# Default values
|
||||
assert config.preserve_bboxes is True
|
||||
assert config.seed is None
|
||||
assert config.gaussian_blur.enabled is False
|
||||
|
||||
def test_get_enabled_augmentations(self) -> None:
|
||||
"""Test getting list of enabled augmentations."""
|
||||
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||
|
||||
config = AugmentationConfig(
|
||||
wrinkle=AugmentationParams(enabled=True),
|
||||
stain=AugmentationParams(enabled=True),
|
||||
gaussian_blur=AugmentationParams(enabled=False),
|
||||
)
|
||||
|
||||
enabled = config.get_enabled_augmentations()
|
||||
|
||||
assert "wrinkle" in enabled
|
||||
assert "stain" in enabled
|
||||
assert "gaussian_blur" not in enabled
|
||||
|
||||
def test_document_safe_defaults(self) -> None:
|
||||
"""Test that default params are document-safe (conservative)."""
|
||||
from shared.augmentation.config import AugmentationConfig
|
||||
|
||||
config = AugmentationConfig()
|
||||
|
||||
# Perspective warp should be very conservative
|
||||
assert config.perspective_warp.params.get("max_warp", 0.02) <= 0.05
|
||||
|
||||
# Noise should be subtle
|
||||
noise_std = config.gaussian_noise.params.get("std", (5, 15))
|
||||
if isinstance(noise_std, tuple):
|
||||
assert noise_std[1] <= 20 # Max std should be low
|
||||
|
||||
def test_immutability_between_instances(self) -> None:
|
||||
"""Test that config instances are independent."""
|
||||
from shared.augmentation.config import AugmentationConfig
|
||||
|
||||
config1 = AugmentationConfig()
|
||||
config2 = AugmentationConfig()
|
||||
|
||||
# Modifying one should not affect the other
|
||||
config1.wrinkle.params["test"] = "value"
|
||||
|
||||
assert "test" not in config2.wrinkle.params
|
||||
|
||||
|
||||
class TestAugmentationConfigValidation:
|
||||
"""Tests for configuration validation."""
|
||||
|
||||
def test_probability_range_validation(self) -> None:
|
||||
"""Test that probability values are validated."""
|
||||
from shared.augmentation.config import AugmentationParams
|
||||
|
||||
# Valid range
|
||||
params = AugmentationParams(probability=0.5)
|
||||
assert params.probability == 0.5
|
||||
|
||||
# Edge cases
|
||||
params_zero = AugmentationParams(probability=0.0)
|
||||
assert params_zero.probability == 0.0
|
||||
|
||||
params_one = AugmentationParams(probability=1.0)
|
||||
assert params_one.probability == 1.0
|
||||
|
||||
def test_config_validate_method(self) -> None:
|
||||
"""Test the validate method catches invalid configurations."""
|
||||
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||
|
||||
# Invalid probability
|
||||
config = AugmentationConfig(
|
||||
wrinkle=AugmentationParams(probability=1.5), # Invalid
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="probability"):
|
||||
config.validate()
|
||||
|
||||
def test_config_validate_negative_probability(self) -> None:
|
||||
"""Test validation catches negative probability."""
|
||||
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||
|
||||
config = AugmentationConfig(
|
||||
wrinkle=AugmentationParams(probability=-0.1),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="probability"):
|
||||
config.validate()
|
||||
338
tests/shared/augmentation/test_pipeline.py
Normal file
338
tests/shared/augmentation/test_pipeline.py
Normal file
@@ -0,0 +1,338 @@
|
||||
"""
|
||||
Tests for augmentation pipeline module.
|
||||
|
||||
TDD Phase 2: RED - Write tests first, then implement to pass.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
class TestAugmentationPipeline:
|
||||
"""Tests for AugmentationPipeline class."""
|
||||
|
||||
def test_create_with_config(self) -> None:
|
||||
"""Test creating pipeline with config."""
|
||||
from shared.augmentation.config import AugmentationConfig
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
config = AugmentationConfig()
|
||||
pipeline = AugmentationPipeline(config)
|
||||
|
||||
assert pipeline.config is config
|
||||
|
||||
def test_create_with_seed(self) -> None:
|
||||
"""Test creating pipeline with seed for reproducibility."""
|
||||
from shared.augmentation.config import AugmentationConfig
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
config = AugmentationConfig(seed=42)
|
||||
pipeline = AugmentationPipeline(config)
|
||||
|
||||
assert pipeline.config.seed == 42
|
||||
|
||||
def test_apply_returns_augmentation_result(self) -> None:
|
||||
"""Test that apply returns AugmentationResult."""
|
||||
from shared.augmentation.base import AugmentationResult
|
||||
from shared.augmentation.config import AugmentationConfig
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
config = AugmentationConfig()
|
||||
pipeline = AugmentationPipeline(config)
|
||||
|
||||
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
result = pipeline.apply(image)
|
||||
|
||||
assert isinstance(result, AugmentationResult)
|
||||
assert result.image is not None
|
||||
assert result.image.shape == image.shape
|
||||
|
||||
def test_apply_with_bboxes(self) -> None:
|
||||
"""Test apply with bounding boxes."""
|
||||
from shared.augmentation.config import AugmentationConfig
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
config = AugmentationConfig()
|
||||
pipeline = AugmentationPipeline(config)
|
||||
|
||||
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
bboxes = np.array([[0, 0.5, 0.5, 0.1, 0.1]], dtype=np.float32)
|
||||
|
||||
result = pipeline.apply(image, bboxes)
|
||||
|
||||
# Bboxes should be preserved when preserve_bboxes=True
|
||||
assert result.bboxes is not None
|
||||
|
||||
def test_apply_no_augmentations_enabled(self) -> None:
|
||||
"""Test apply when no augmentations are enabled."""
|
||||
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
# Disable all augmentations
|
||||
config = AugmentationConfig(
|
||||
lighting_variation=AugmentationParams(enabled=False),
|
||||
)
|
||||
pipeline = AugmentationPipeline(config)
|
||||
|
||||
image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
|
||||
result = pipeline.apply(image)
|
||||
|
||||
# Image should be unchanged (or a copy)
|
||||
np.testing.assert_array_equal(result.image, image)
|
||||
|
||||
def test_apply_does_not_mutate_input(self) -> None:
|
||||
"""Test that apply does not mutate input image."""
|
||||
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
config = AugmentationConfig(
|
||||
lighting_variation=AugmentationParams(enabled=True, probability=1.0),
|
||||
)
|
||||
pipeline = AugmentationPipeline(config)
|
||||
|
||||
image = np.full((100, 100, 3), 128, dtype=np.uint8)
|
||||
original_copy = image.copy()
|
||||
|
||||
pipeline.apply(image)
|
||||
|
||||
np.testing.assert_array_equal(image, original_copy)
|
||||
|
||||
def test_reproducibility_with_seed(self) -> None:
|
||||
"""Test that same seed produces same results."""
|
||||
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
config1 = AugmentationConfig(
|
||||
seed=42,
|
||||
gaussian_noise=AugmentationParams(enabled=True, probability=1.0),
|
||||
)
|
||||
config2 = AugmentationConfig(
|
||||
seed=42,
|
||||
gaussian_noise=AugmentationParams(enabled=True, probability=1.0),
|
||||
)
|
||||
|
||||
pipeline1 = AugmentationPipeline(config1)
|
||||
pipeline2 = AugmentationPipeline(config2)
|
||||
|
||||
image = np.full((100, 100, 3), 128, dtype=np.uint8)
|
||||
|
||||
result1 = pipeline1.apply(image.copy())
|
||||
result2 = pipeline2.apply(image.copy())
|
||||
|
||||
np.testing.assert_array_equal(result1.image, result2.image)
|
||||
|
||||
def test_metadata_contains_applied_augmentations(self) -> None:
|
||||
"""Test that metadata lists applied augmentations."""
|
||||
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
config = AugmentationConfig(
|
||||
seed=42,
|
||||
gaussian_noise=AugmentationParams(enabled=True, probability=1.0),
|
||||
lighting_variation=AugmentationParams(enabled=True, probability=1.0),
|
||||
)
|
||||
pipeline = AugmentationPipeline(config)
|
||||
|
||||
image = np.full((100, 100, 3), 128, dtype=np.uint8)
|
||||
result = pipeline.apply(image)
|
||||
|
||||
assert result.metadata is not None
|
||||
assert "applied_augmentations" in result.metadata
|
||||
# Both should be applied with probability=1.0
|
||||
assert "gaussian_noise" in result.metadata["applied_augmentations"]
|
||||
assert "lighting_variation" in result.metadata["applied_augmentations"]
|
||||
|
||||
|
||||
class TestAugmentationPipelineStageOrder:
|
||||
"""Tests for pipeline stage ordering."""
|
||||
|
||||
def test_stage_order_defined(self) -> None:
|
||||
"""Test that stage order is defined."""
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
assert hasattr(AugmentationPipeline, "STAGE_ORDER")
|
||||
expected_stages = [
|
||||
"geometric",
|
||||
"degradation",
|
||||
"lighting",
|
||||
"texture",
|
||||
"blur",
|
||||
"noise",
|
||||
]
|
||||
assert AugmentationPipeline.STAGE_ORDER == expected_stages
|
||||
|
||||
def test_stage_mapping_defined(self) -> None:
|
||||
"""Test that all augmentation types are mapped to stages."""
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
assert hasattr(AugmentationPipeline, "STAGE_MAPPING")
|
||||
|
||||
expected_mappings = {
|
||||
"perspective_warp": "geometric",
|
||||
"wrinkle": "degradation",
|
||||
"edge_damage": "degradation",
|
||||
"stain": "degradation",
|
||||
"lighting_variation": "lighting",
|
||||
"shadow": "lighting",
|
||||
"paper_texture": "texture",
|
||||
"scanner_artifacts": "texture",
|
||||
"gaussian_blur": "blur",
|
||||
"motion_blur": "blur",
|
||||
"gaussian_noise": "noise",
|
||||
"salt_pepper": "noise",
|
||||
}
|
||||
|
||||
for aug_name, stage in expected_mappings.items():
|
||||
assert aug_name in AugmentationPipeline.STAGE_MAPPING
|
||||
assert AugmentationPipeline.STAGE_MAPPING[aug_name] == stage
|
||||
|
||||
def test_geometric_before_degradation(self) -> None:
|
||||
"""Test that geometric transforms are applied before degradation."""
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
stages = AugmentationPipeline.STAGE_ORDER
|
||||
geometric_idx = stages.index("geometric")
|
||||
degradation_idx = stages.index("degradation")
|
||||
|
||||
assert geometric_idx < degradation_idx
|
||||
|
||||
def test_noise_applied_last(self) -> None:
|
||||
"""Test that noise is applied last."""
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
stages = AugmentationPipeline.STAGE_ORDER
|
||||
assert stages[-1] == "noise"
|
||||
|
||||
|
||||
class TestAugmentationRegistry:
|
||||
"""Tests for augmentation registry."""
|
||||
|
||||
def test_registry_exists(self) -> None:
|
||||
"""Test that augmentation registry exists."""
|
||||
from shared.augmentation.pipeline import AUGMENTATION_REGISTRY
|
||||
|
||||
assert isinstance(AUGMENTATION_REGISTRY, dict)
|
||||
|
||||
def test_registry_contains_all_types(self) -> None:
|
||||
"""Test that registry contains all augmentation types."""
|
||||
from shared.augmentation.pipeline import AUGMENTATION_REGISTRY
|
||||
|
||||
expected_types = [
|
||||
"perspective_warp",
|
||||
"wrinkle",
|
||||
"edge_damage",
|
||||
"stain",
|
||||
"lighting_variation",
|
||||
"shadow",
|
||||
"gaussian_blur",
|
||||
"motion_blur",
|
||||
"gaussian_noise",
|
||||
"salt_pepper",
|
||||
"paper_texture",
|
||||
"scanner_artifacts",
|
||||
]
|
||||
|
||||
for aug_type in expected_types:
|
||||
assert aug_type in AUGMENTATION_REGISTRY, f"Missing: {aug_type}"
|
||||
|
||||
|
||||
class TestPipelinePreview:
|
||||
"""Tests for pipeline preview functionality."""
|
||||
|
||||
def test_preview_single_augmentation(self) -> None:
|
||||
"""Test previewing a single augmentation."""
|
||||
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
config = AugmentationConfig(
|
||||
gaussian_noise=AugmentationParams(
|
||||
enabled=True, probability=1.0, params={"std": (10, 10)}
|
||||
),
|
||||
)
|
||||
pipeline = AugmentationPipeline(config)
|
||||
|
||||
image = np.full((100, 100, 3), 128, dtype=np.uint8)
|
||||
preview = pipeline.preview(image, "gaussian_noise")
|
||||
|
||||
assert preview.shape == image.shape
|
||||
assert preview.dtype == np.uint8
|
||||
# Preview should modify the image
|
||||
assert not np.array_equal(preview, image)
|
||||
|
||||
def test_preview_unknown_augmentation_raises(self) -> None:
|
||||
"""Test that previewing unknown augmentation raises error."""
|
||||
from shared.augmentation.config import AugmentationConfig
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
config = AugmentationConfig()
|
||||
pipeline = AugmentationPipeline(config)
|
||||
|
||||
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown augmentation"):
|
||||
pipeline.preview(image, "non_existent_augmentation")
|
||||
|
||||
def test_preview_is_deterministic(self) -> None:
|
||||
"""Test that preview produces deterministic results."""
|
||||
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
config = AugmentationConfig(
|
||||
gaussian_noise=AugmentationParams(enabled=True),
|
||||
)
|
||||
pipeline = AugmentationPipeline(config)
|
||||
|
||||
image = np.full((100, 100, 3), 128, dtype=np.uint8)
|
||||
|
||||
preview1 = pipeline.preview(image, "gaussian_noise")
|
||||
preview2 = pipeline.preview(image, "gaussian_noise")
|
||||
|
||||
np.testing.assert_array_equal(preview1, preview2)
|
||||
|
||||
|
||||
class TestPipelineGetAvailableAugmentations:
|
||||
"""Tests for getting available augmentations."""
|
||||
|
||||
def test_get_available_augmentations(self) -> None:
|
||||
"""Test getting list of available augmentations."""
|
||||
from shared.augmentation.pipeline import get_available_augmentations
|
||||
|
||||
augmentations = get_available_augmentations()
|
||||
|
||||
assert isinstance(augmentations, list)
|
||||
assert len(augmentations) == 12
|
||||
|
||||
# Each item should have name, description, affects_geometry
|
||||
for aug in augmentations:
|
||||
assert "name" in aug
|
||||
assert "description" in aug
|
||||
assert "affects_geometry" in aug
|
||||
assert "stage" in aug
|
||||
|
||||
def test_get_available_augmentations_includes_all_types(self) -> None:
|
||||
"""Test that all augmentation types are included."""
|
||||
from shared.augmentation.pipeline import get_available_augmentations
|
||||
|
||||
augmentations = get_available_augmentations()
|
||||
names = [aug["name"] for aug in augmentations]
|
||||
|
||||
expected = [
|
||||
"perspective_warp",
|
||||
"wrinkle",
|
||||
"edge_damage",
|
||||
"stain",
|
||||
"lighting_variation",
|
||||
"shadow",
|
||||
"gaussian_blur",
|
||||
"motion_blur",
|
||||
"gaussian_noise",
|
||||
"salt_pepper",
|
||||
"paper_texture",
|
||||
"scanner_artifacts",
|
||||
]
|
||||
|
||||
for name in expected:
|
||||
assert name in names
|
||||
102
tests/shared/augmentation/test_presets.py
Normal file
102
tests/shared/augmentation/test_presets.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""
|
||||
Tests for augmentation presets module.
|
||||
|
||||
TDD Phase 4: RED - Write tests first, then implement to pass.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestPresets:
|
||||
"""Tests for augmentation presets."""
|
||||
|
||||
def test_presets_dict_exists(self) -> None:
|
||||
"""Test that PRESETS dictionary exists."""
|
||||
from shared.augmentation.presets import PRESETS
|
||||
|
||||
assert isinstance(PRESETS, dict)
|
||||
assert len(PRESETS) > 0
|
||||
|
||||
def test_expected_presets_exist(self) -> None:
|
||||
"""Test that expected presets are defined."""
|
||||
from shared.augmentation.presets import PRESETS
|
||||
|
||||
expected_presets = ["conservative", "moderate", "aggressive", "scanned_document"]
|
||||
|
||||
for preset_name in expected_presets:
|
||||
assert preset_name in PRESETS, f"Missing preset: {preset_name}"
|
||||
|
||||
def test_preset_structure(self) -> None:
|
||||
"""Test that each preset has required structure."""
|
||||
from shared.augmentation.presets import PRESETS
|
||||
|
||||
for name, preset in PRESETS.items():
|
||||
assert "description" in preset, f"Preset {name} missing description"
|
||||
assert "config" in preset, f"Preset {name} missing config"
|
||||
assert isinstance(preset["description"], str)
|
||||
assert isinstance(preset["config"], dict)
|
||||
|
||||
def test_get_preset_config(self) -> None:
|
||||
"""Test getting config from preset."""
|
||||
from shared.augmentation.presets import get_preset_config
|
||||
|
||||
config = get_preset_config("conservative")
|
||||
|
||||
assert config is not None
|
||||
# Should have at least some augmentations defined
|
||||
assert len(config) > 0
|
||||
|
||||
def test_get_preset_config_unknown_raises(self) -> None:
|
||||
"""Test that getting unknown preset raises error."""
|
||||
from shared.augmentation.presets import get_preset_config
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown preset"):
|
||||
get_preset_config("nonexistent_preset")
|
||||
|
||||
def test_create_config_from_preset(self) -> None:
|
||||
"""Test creating AugmentationConfig from preset."""
|
||||
from shared.augmentation.config import AugmentationConfig
|
||||
from shared.augmentation.presets import create_config_from_preset
|
||||
|
||||
config = create_config_from_preset("moderate")
|
||||
|
||||
assert isinstance(config, AugmentationConfig)
|
||||
|
||||
def test_conservative_preset_is_safe(self) -> None:
|
||||
"""Test that conservative preset only enables safe augmentations."""
|
||||
from shared.augmentation.presets import create_config_from_preset
|
||||
|
||||
config = create_config_from_preset("conservative")
|
||||
|
||||
# Should NOT enable geometric transforms
|
||||
assert config.perspective_warp.enabled is False
|
||||
|
||||
# Should NOT enable heavy degradation
|
||||
assert config.wrinkle.enabled is False
|
||||
assert config.edge_damage.enabled is False
|
||||
assert config.stain.enabled is False
|
||||
|
||||
def test_aggressive_preset_enables_more(self) -> None:
|
||||
"""Test that aggressive preset enables more augmentations."""
|
||||
from shared.augmentation.presets import create_config_from_preset
|
||||
|
||||
config = create_config_from_preset("aggressive")
|
||||
|
||||
enabled = config.get_enabled_augmentations()
|
||||
|
||||
# Should enable multiple augmentation types
|
||||
assert len(enabled) >= 6
|
||||
|
||||
def test_list_presets(self) -> None:
|
||||
"""Test listing available presets."""
|
||||
from shared.augmentation.presets import list_presets
|
||||
|
||||
presets = list_presets()
|
||||
|
||||
assert isinstance(presets, list)
|
||||
assert len(presets) >= 4
|
||||
|
||||
# Each item should have name and description
|
||||
for preset in presets:
|
||||
assert "name" in preset
|
||||
assert "description" in preset
|
||||
1
tests/shared/augmentation/transforms/__init__.py
Normal file
1
tests/shared/augmentation/transforms/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Tests for augmentation transforms
|
||||
Reference in New Issue
Block a user