WIP
This commit is contained in:
283
tests/shared/augmentation/test_config.py
Normal file
283
tests/shared/augmentation/test_config.py
Normal file
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
Tests for augmentation configuration module.
|
||||
|
||||
TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestAugmentationParams:
|
||||
"""Tests for AugmentationParams dataclass."""
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Test default parameter values."""
|
||||
from shared.augmentation.config import AugmentationParams
|
||||
|
||||
params = AugmentationParams()
|
||||
|
||||
assert params.enabled is False
|
||||
assert params.probability == 0.5
|
||||
assert params.params == {}
|
||||
|
||||
def test_custom_values(self) -> None:
|
||||
"""Test creating params with custom values."""
|
||||
from shared.augmentation.config import AugmentationParams
|
||||
|
||||
params = AugmentationParams(
|
||||
enabled=True,
|
||||
probability=0.8,
|
||||
params={"intensity": 0.5, "num_wrinkles": (2, 5)},
|
||||
)
|
||||
|
||||
assert params.enabled is True
|
||||
assert params.probability == 0.8
|
||||
assert params.params["intensity"] == 0.5
|
||||
assert params.params["num_wrinkles"] == (2, 5)
|
||||
|
||||
def test_immutability_params_dict(self) -> None:
|
||||
"""Test that params dict is independent between instances."""
|
||||
from shared.augmentation.config import AugmentationParams
|
||||
|
||||
params1 = AugmentationParams()
|
||||
params2 = AugmentationParams()
|
||||
|
||||
# Modifying one should not affect the other
|
||||
params1.params["test"] = "value"
|
||||
|
||||
assert "test" not in params2.params
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""Test conversion to dictionary."""
|
||||
from shared.augmentation.config import AugmentationParams
|
||||
|
||||
params = AugmentationParams(
|
||||
enabled=True,
|
||||
probability=0.7,
|
||||
params={"key": "value"},
|
||||
)
|
||||
|
||||
result = params.to_dict()
|
||||
|
||||
assert result == {
|
||||
"enabled": True,
|
||||
"probability": 0.7,
|
||||
"params": {"key": "value"},
|
||||
}
|
||||
|
||||
def test_from_dict(self) -> None:
|
||||
"""Test creation from dictionary."""
|
||||
from shared.augmentation.config import AugmentationParams
|
||||
|
||||
data = {
|
||||
"enabled": True,
|
||||
"probability": 0.6,
|
||||
"params": {"intensity": 0.3},
|
||||
}
|
||||
|
||||
params = AugmentationParams.from_dict(data)
|
||||
|
||||
assert params.enabled is True
|
||||
assert params.probability == 0.6
|
||||
assert params.params == {"intensity": 0.3}
|
||||
|
||||
def test_from_dict_with_defaults(self) -> None:
|
||||
"""Test creation from partial dictionary uses defaults."""
|
||||
from shared.augmentation.config import AugmentationParams
|
||||
|
||||
data: dict[str, Any] = {"enabled": True}
|
||||
|
||||
params = AugmentationParams.from_dict(data)
|
||||
|
||||
assert params.enabled is True
|
||||
assert params.probability == 0.5 # default
|
||||
assert params.params == {} # default
|
||||
|
||||
|
||||
class TestAugmentationConfig:
|
||||
"""Tests for AugmentationConfig dataclass."""
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Test that all augmentation types have defaults."""
|
||||
from shared.augmentation.config import AugmentationConfig
|
||||
|
||||
config = AugmentationConfig()
|
||||
|
||||
# All augmentation types should exist
|
||||
augmentation_types = [
|
||||
"perspective_warp",
|
||||
"wrinkle",
|
||||
"edge_damage",
|
||||
"stain",
|
||||
"lighting_variation",
|
||||
"shadow",
|
||||
"gaussian_blur",
|
||||
"motion_blur",
|
||||
"gaussian_noise",
|
||||
"salt_pepper",
|
||||
"paper_texture",
|
||||
"scanner_artifacts",
|
||||
]
|
||||
|
||||
for aug_type in augmentation_types:
|
||||
assert hasattr(config, aug_type), f"Missing augmentation type: {aug_type}"
|
||||
params = getattr(config, aug_type)
|
||||
assert hasattr(params, "enabled")
|
||||
assert hasattr(params, "probability")
|
||||
assert hasattr(params, "params")
|
||||
|
||||
def test_global_settings_defaults(self) -> None:
|
||||
"""Test global settings default values."""
|
||||
from shared.augmentation.config import AugmentationConfig
|
||||
|
||||
config = AugmentationConfig()
|
||||
|
||||
assert config.preserve_bboxes is True
|
||||
assert config.seed is None
|
||||
|
||||
def test_custom_seed(self) -> None:
|
||||
"""Test setting custom seed for reproducibility."""
|
||||
from shared.augmentation.config import AugmentationConfig
|
||||
|
||||
config = AugmentationConfig(seed=42)
|
||||
|
||||
assert config.seed == 42
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""Test conversion to dictionary."""
|
||||
from shared.augmentation.config import AugmentationConfig
|
||||
|
||||
config = AugmentationConfig(seed=123, preserve_bboxes=False)
|
||||
|
||||
result = config.to_dict()
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["seed"] == 123
|
||||
assert result["preserve_bboxes"] is False
|
||||
assert "perspective_warp" in result
|
||||
assert "wrinkle" in result
|
||||
|
||||
def test_from_dict(self) -> None:
|
||||
"""Test creation from dictionary."""
|
||||
from shared.augmentation.config import AugmentationConfig
|
||||
|
||||
data = {
|
||||
"seed": 456,
|
||||
"preserve_bboxes": False,
|
||||
"wrinkle": {
|
||||
"enabled": True,
|
||||
"probability": 0.8,
|
||||
"params": {"intensity": 0.5},
|
||||
},
|
||||
}
|
||||
|
||||
config = AugmentationConfig.from_dict(data)
|
||||
|
||||
assert config.seed == 456
|
||||
assert config.preserve_bboxes is False
|
||||
assert config.wrinkle.enabled is True
|
||||
assert config.wrinkle.probability == 0.8
|
||||
assert config.wrinkle.params["intensity"] == 0.5
|
||||
|
||||
def test_from_dict_with_partial_data(self) -> None:
|
||||
"""Test creation from partial dictionary uses defaults."""
|
||||
from shared.augmentation.config import AugmentationConfig
|
||||
|
||||
data: dict[str, Any] = {
|
||||
"wrinkle": {"enabled": True},
|
||||
}
|
||||
|
||||
config = AugmentationConfig.from_dict(data)
|
||||
|
||||
# Explicitly set value
|
||||
assert config.wrinkle.enabled is True
|
||||
# Default values
|
||||
assert config.preserve_bboxes is True
|
||||
assert config.seed is None
|
||||
assert config.gaussian_blur.enabled is False
|
||||
|
||||
def test_get_enabled_augmentations(self) -> None:
|
||||
"""Test getting list of enabled augmentations."""
|
||||
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||
|
||||
config = AugmentationConfig(
|
||||
wrinkle=AugmentationParams(enabled=True),
|
||||
stain=AugmentationParams(enabled=True),
|
||||
gaussian_blur=AugmentationParams(enabled=False),
|
||||
)
|
||||
|
||||
enabled = config.get_enabled_augmentations()
|
||||
|
||||
assert "wrinkle" in enabled
|
||||
assert "stain" in enabled
|
||||
assert "gaussian_blur" not in enabled
|
||||
|
||||
def test_document_safe_defaults(self) -> None:
|
||||
"""Test that default params are document-safe (conservative)."""
|
||||
from shared.augmentation.config import AugmentationConfig
|
||||
|
||||
config = AugmentationConfig()
|
||||
|
||||
# Perspective warp should be very conservative
|
||||
assert config.perspective_warp.params.get("max_warp", 0.02) <= 0.05
|
||||
|
||||
# Noise should be subtle
|
||||
noise_std = config.gaussian_noise.params.get("std", (5, 15))
|
||||
if isinstance(noise_std, tuple):
|
||||
assert noise_std[1] <= 20 # Max std should be low
|
||||
|
||||
def test_immutability_between_instances(self) -> None:
|
||||
"""Test that config instances are independent."""
|
||||
from shared.augmentation.config import AugmentationConfig
|
||||
|
||||
config1 = AugmentationConfig()
|
||||
config2 = AugmentationConfig()
|
||||
|
||||
# Modifying one should not affect the other
|
||||
config1.wrinkle.params["test"] = "value"
|
||||
|
||||
assert "test" not in config2.wrinkle.params
|
||||
|
||||
|
||||
class TestAugmentationConfigValidation:
|
||||
"""Tests for configuration validation."""
|
||||
|
||||
def test_probability_range_validation(self) -> None:
|
||||
"""Test that probability values are validated."""
|
||||
from shared.augmentation.config import AugmentationParams
|
||||
|
||||
# Valid range
|
||||
params = AugmentationParams(probability=0.5)
|
||||
assert params.probability == 0.5
|
||||
|
||||
# Edge cases
|
||||
params_zero = AugmentationParams(probability=0.0)
|
||||
assert params_zero.probability == 0.0
|
||||
|
||||
params_one = AugmentationParams(probability=1.0)
|
||||
assert params_one.probability == 1.0
|
||||
|
||||
def test_config_validate_method(self) -> None:
|
||||
"""Test the validate method catches invalid configurations."""
|
||||
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||
|
||||
# Invalid probability
|
||||
config = AugmentationConfig(
|
||||
wrinkle=AugmentationParams(probability=1.5), # Invalid
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="probability"):
|
||||
config.validate()
|
||||
|
||||
def test_config_validate_negative_probability(self) -> None:
|
||||
"""Test validation catches negative probability."""
|
||||
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||
|
||||
config = AugmentationConfig(
|
||||
wrinkle=AugmentationParams(probability=-0.1),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="probability"):
|
||||
config.validate()
|
||||
Reference in New Issue
Block a user