"""
Tests for the augmentation configuration module (``shared.augmentation.config``).

TDD Phase 1: RED - these tests are written first; the implementation is then
written to make them pass.
"""
|
|
|
|
from typing import Any
|
|
|
|
import pytest
|
|
|
|
|
|
class TestAugmentationParams:
    """Behavioural spec for the ``AugmentationParams`` dataclass.

    ``AugmentationParams`` is imported at call time (through ``_params_cls``)
    rather than at module level, so a missing implementation surfaces as
    individual test failures instead of an import error for the whole module.
    """

    @staticmethod
    def _params_cls() -> Any:
        """Import ``AugmentationParams`` lazily and hand back the class object."""
        from shared.augmentation.config import AugmentationParams

        return AugmentationParams

    def test_default_values(self) -> None:
        """A bare instance is disabled, has probability 0.5 and an empty params dict."""
        params_cls = self._params_cls()
        defaults = params_cls()

        assert defaults.enabled is False
        assert defaults.probability == 0.5
        assert defaults.params == {}

    def test_custom_values(self) -> None:
        """Keyword arguments given to the constructor are stored unchanged."""
        params_cls = self._params_cls()
        custom = params_cls(
            enabled=True,
            probability=0.8,
            params={"intensity": 0.5, "num_wrinkles": (2, 5)},
        )

        assert custom.enabled is True
        assert custom.probability == 0.8
        stored = custom.params
        assert stored["intensity"] == 0.5
        assert stored["num_wrinkles"] == (2, 5)

    def test_immutability_params_dict(self) -> None:
        """Two default instances must not share one ``params`` dict object."""
        params_cls = self._params_cls()
        first, second = params_cls(), params_cls()

        # A write through one instance must stay invisible to the other.
        first.params["test"] = "value"

        assert "test" not in second.params

    def test_to_dict(self) -> None:
        """``to_dict`` yields a plain mapping of enabled / probability / params."""
        params_cls = self._params_cls()
        source = params_cls(
            enabled=True,
            probability=0.7,
            params={"key": "value"},
        )

        serialized = source.to_dict()

        expected = {
            "enabled": True,
            "probability": 0.7,
            "params": {"key": "value"},
        }
        assert serialized == expected

    def test_from_dict(self) -> None:
        """``from_dict`` rebuilds every field from a complete mapping."""
        params_cls = self._params_cls()
        payload = {
            "enabled": True,
            "probability": 0.6,
            "params": {"intensity": 0.3},
        }

        restored = params_cls.from_dict(payload)

        assert restored.enabled is True
        assert restored.probability == 0.6
        assert restored.params == {"intensity": 0.3}

    def test_from_dict_with_defaults(self) -> None:
        """Keys absent from the mapping fall back to the constructor defaults."""
        params_cls = self._params_cls()
        partial: dict[str, Any] = {"enabled": True}

        restored = params_cls.from_dict(partial)

        assert restored.enabled is True
        assert restored.probability == 0.5  # default
        assert restored.params == {}  # default
|
|
|
|
|
|
class TestAugmentationConfig:
    """Tests for AugmentationConfig dataclass."""

    # Every augmentation type AugmentationConfig is expected to expose as an
    # attribute and to serialize in to_dict(). Kept in one place so the tests
    # that enumerate the types cannot drift apart.
    AUGMENTATION_TYPES: tuple[str, ...] = (
        "perspective_warp",
        "wrinkle",
        "edge_damage",
        "stain",
        "lighting_variation",
        "shadow",
        "gaussian_blur",
        "motion_blur",
        "gaussian_noise",
        "salt_pepper",
        "paper_texture",
        "scanner_artifacts",
    )

    def test_default_values(self) -> None:
        """Test that all augmentation types have defaults."""
        from shared.augmentation.config import AugmentationConfig

        config = AugmentationConfig()

        # All augmentation types should exist and look like AugmentationParams.
        for aug_type in self.AUGMENTATION_TYPES:
            assert hasattr(config, aug_type), f"Missing augmentation type: {aug_type}"
            params = getattr(config, aug_type)
            assert hasattr(params, "enabled")
            assert hasattr(params, "probability")
            assert hasattr(params, "params")

    def test_global_settings_defaults(self) -> None:
        """Test global settings default values."""
        from shared.augmentation.config import AugmentationConfig

        config = AugmentationConfig()

        assert config.preserve_bboxes is True
        assert config.seed is None

    def test_custom_seed(self) -> None:
        """Test setting custom seed for reproducibility."""
        from shared.augmentation.config import AugmentationConfig

        config = AugmentationConfig(seed=42)

        assert config.seed == 42

    def test_to_dict(self) -> None:
        """Test conversion to dictionary covers global settings and every augmentation type."""
        from shared.augmentation.config import AugmentationConfig

        config = AugmentationConfig(seed=123, preserve_bboxes=False)

        result = config.to_dict()

        assert isinstance(result, dict)
        assert result["seed"] == 123
        assert result["preserve_bboxes"] is False
        # Check every type, not a sample, so a type that exists as an attribute
        # but is forgotten during serialization is caught.
        for aug_type in self.AUGMENTATION_TYPES:
            assert aug_type in result, (
                f"Missing augmentation type in to_dict(): {aug_type}"
            )

    def test_from_dict(self) -> None:
        """Test creation from dictionary."""
        from shared.augmentation.config import AugmentationConfig

        data = {
            "seed": 456,
            "preserve_bboxes": False,
            "wrinkle": {
                "enabled": True,
                "probability": 0.8,
                "params": {"intensity": 0.5},
            },
        }

        config = AugmentationConfig.from_dict(data)

        assert config.seed == 456
        assert config.preserve_bboxes is False
        assert config.wrinkle.enabled is True
        assert config.wrinkle.probability == 0.8
        assert config.wrinkle.params["intensity"] == 0.5

    def test_from_dict_with_partial_data(self) -> None:
        """Test creation from partial dictionary uses defaults."""
        from shared.augmentation.config import AugmentationConfig

        data: dict[str, Any] = {
            "wrinkle": {"enabled": True},
        }

        config = AugmentationConfig.from_dict(data)

        # Explicitly set value
        assert config.wrinkle.enabled is True
        # Default values
        assert config.preserve_bboxes is True
        assert config.seed is None
        assert config.gaussian_blur.enabled is False

    def test_get_enabled_augmentations(self) -> None:
        """Test getting list of enabled augmentations."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams

        config = AugmentationConfig(
            wrinkle=AugmentationParams(enabled=True),
            stain=AugmentationParams(enabled=True),
            gaussian_blur=AugmentationParams(enabled=False),
        )

        enabled = config.get_enabled_augmentations()

        assert "wrinkle" in enabled
        assert "stain" in enabled
        assert "gaussian_blur" not in enabled

    def test_document_safe_defaults(self) -> None:
        """Test that default params are document-safe (conservative)."""
        from shared.augmentation.config import AugmentationConfig

        config = AugmentationConfig()

        # Perspective warp should be very conservative
        assert config.perspective_warp.params.get("max_warp", 0.02) <= 0.05

        # Noise should be subtle. "std" may be a (min, max) range or a single
        # value; reduce it to its largest std so that a list range or a scalar
        # is checked too, instead of only tuples.
        noise_std = config.gaussian_noise.params.get("std", (5, 15))
        max_std = max(noise_std) if isinstance(noise_std, (tuple, list)) else noise_std
        assert max_std <= 20, f"gaussian_noise std too aggressive: {noise_std!r}"

    def test_immutability_between_instances(self) -> None:
        """Test that config instances are independent."""
        from shared.augmentation.config import AugmentationConfig

        config1 = AugmentationConfig()
        config2 = AugmentationConfig()

        # Modifying one should not affect the other
        config1.wrinkle.params["test"] = "value"

        assert "test" not in config2.wrinkle.params
|
|
|
|
|
|
class TestAugmentationConfigValidation:
    """Tests for configuration validation via ``AugmentationConfig.validate``."""

    def test_probability_range_validation(self) -> None:
        """Test that in-range probabilities, including both bounds, are stored as given.

        Constructing AugmentationParams is not expected to reject values (the
        invalid-probability tests below construct out-of-range params freely);
        range checking is exercised through ``AugmentationConfig.validate``.
        """
        from shared.augmentation.config import AugmentationParams

        # Valid range
        params = AugmentationParams(probability=0.5)
        assert params.probability == 0.5

        # Edge cases: both ends of [0.0, 1.0] are valid
        params_zero = AugmentationParams(probability=0.0)
        assert params_zero.probability == 0.0

        params_one = AugmentationParams(probability=1.0)
        assert params_one.probability == 1.0

    def test_config_validate_method(self) -> None:
        """Test the validate method catches invalid configurations."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams

        # Invalid probability
        config = AugmentationConfig(
            wrinkle=AugmentationParams(probability=1.5),  # Invalid
        )

        with pytest.raises(ValueError, match="probability"):
            config.validate()

    def test_config_validate_negative_probability(self) -> None:
        """Test validation catches negative probability."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams

        config = AugmentationConfig(
            wrinkle=AugmentationParams(probability=-0.1),
        )

        with pytest.raises(ValueError, match="probability"):
            config.validate()

    def test_config_validate_accepts_valid_probabilities(self) -> None:
        """Test that validate() does not raise for defaults or boundary probabilities.

        Without this, a validate() that rejects every configuration would
        still satisfy the two rejection tests above.
        """
        from shared.augmentation.config import AugmentationConfig, AugmentationParams

        # The default configuration must be valid as-is.
        AugmentationConfig().validate()

        # Both inclusive bounds and an interior value must pass.
        for probability in (0.0, 0.5, 1.0):
            AugmentationConfig(
                wrinkle=AugmentationParams(probability=probability),
            ).validate()
|