This commit is contained in:
Yaojia Wang
2026-01-30 00:44:21 +01:00
parent d2489a97d4
commit 33ada0350d
79 changed files with 9737 additions and 297 deletions

View File

@@ -0,0 +1,283 @@
"""
Tests for augmentation configuration module.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
from typing import Any
import pytest
class TestAugmentationParams:
"""Tests for AugmentationParams dataclass."""
def test_default_values(self) -> None:
"""Test default parameter values."""
from shared.augmentation.config import AugmentationParams
params = AugmentationParams()
assert params.enabled is False
assert params.probability == 0.5
assert params.params == {}
def test_custom_values(self) -> None:
"""Test creating params with custom values."""
from shared.augmentation.config import AugmentationParams
params = AugmentationParams(
enabled=True,
probability=0.8,
params={"intensity": 0.5, "num_wrinkles": (2, 5)},
)
assert params.enabled is True
assert params.probability == 0.8
assert params.params["intensity"] == 0.5
assert params.params["num_wrinkles"] == (2, 5)
def test_immutability_params_dict(self) -> None:
"""Test that params dict is independent between instances."""
from shared.augmentation.config import AugmentationParams
params1 = AugmentationParams()
params2 = AugmentationParams()
# Modifying one should not affect the other
params1.params["test"] = "value"
assert "test" not in params2.params
def test_to_dict(self) -> None:
"""Test conversion to dictionary."""
from shared.augmentation.config import AugmentationParams
params = AugmentationParams(
enabled=True,
probability=0.7,
params={"key": "value"},
)
result = params.to_dict()
assert result == {
"enabled": True,
"probability": 0.7,
"params": {"key": "value"},
}
def test_from_dict(self) -> None:
"""Test creation from dictionary."""
from shared.augmentation.config import AugmentationParams
data = {
"enabled": True,
"probability": 0.6,
"params": {"intensity": 0.3},
}
params = AugmentationParams.from_dict(data)
assert params.enabled is True
assert params.probability == 0.6
assert params.params == {"intensity": 0.3}
def test_from_dict_with_defaults(self) -> None:
"""Test creation from partial dictionary uses defaults."""
from shared.augmentation.config import AugmentationParams
data: dict[str, Any] = {"enabled": True}
params = AugmentationParams.from_dict(data)
assert params.enabled is True
assert params.probability == 0.5 # default
assert params.params == {} # default
class TestAugmentationConfig:
"""Tests for AugmentationConfig dataclass."""
def test_default_values(self) -> None:
"""Test that all augmentation types have defaults."""
from shared.augmentation.config import AugmentationConfig
config = AugmentationConfig()
# All augmentation types should exist
augmentation_types = [
"perspective_warp",
"wrinkle",
"edge_damage",
"stain",
"lighting_variation",
"shadow",
"gaussian_blur",
"motion_blur",
"gaussian_noise",
"salt_pepper",
"paper_texture",
"scanner_artifacts",
]
for aug_type in augmentation_types:
assert hasattr(config, aug_type), f"Missing augmentation type: {aug_type}"
params = getattr(config, aug_type)
assert hasattr(params, "enabled")
assert hasattr(params, "probability")
assert hasattr(params, "params")
def test_global_settings_defaults(self) -> None:
"""Test global settings default values."""
from shared.augmentation.config import AugmentationConfig
config = AugmentationConfig()
assert config.preserve_bboxes is True
assert config.seed is None
def test_custom_seed(self) -> None:
"""Test setting custom seed for reproducibility."""
from shared.augmentation.config import AugmentationConfig
config = AugmentationConfig(seed=42)
assert config.seed == 42
def test_to_dict(self) -> None:
"""Test conversion to dictionary."""
from shared.augmentation.config import AugmentationConfig
config = AugmentationConfig(seed=123, preserve_bboxes=False)
result = config.to_dict()
assert isinstance(result, dict)
assert result["seed"] == 123
assert result["preserve_bboxes"] is False
assert "perspective_warp" in result
assert "wrinkle" in result
def test_from_dict(self) -> None:
"""Test creation from dictionary."""
from shared.augmentation.config import AugmentationConfig
data = {
"seed": 456,
"preserve_bboxes": False,
"wrinkle": {
"enabled": True,
"probability": 0.8,
"params": {"intensity": 0.5},
},
}
config = AugmentationConfig.from_dict(data)
assert config.seed == 456
assert config.preserve_bboxes is False
assert config.wrinkle.enabled is True
assert config.wrinkle.probability == 0.8
assert config.wrinkle.params["intensity"] == 0.5
def test_from_dict_with_partial_data(self) -> None:
"""Test creation from partial dictionary uses defaults."""
from shared.augmentation.config import AugmentationConfig
data: dict[str, Any] = {
"wrinkle": {"enabled": True},
}
config = AugmentationConfig.from_dict(data)
# Explicitly set value
assert config.wrinkle.enabled is True
# Default values
assert config.preserve_bboxes is True
assert config.seed is None
assert config.gaussian_blur.enabled is False
def test_get_enabled_augmentations(self) -> None:
"""Test getting list of enabled augmentations."""
from shared.augmentation.config import AugmentationConfig, AugmentationParams
config = AugmentationConfig(
wrinkle=AugmentationParams(enabled=True),
stain=AugmentationParams(enabled=True),
gaussian_blur=AugmentationParams(enabled=False),
)
enabled = config.get_enabled_augmentations()
assert "wrinkle" in enabled
assert "stain" in enabled
assert "gaussian_blur" not in enabled
def test_document_safe_defaults(self) -> None:
"""Test that default params are document-safe (conservative)."""
from shared.augmentation.config import AugmentationConfig
config = AugmentationConfig()
# Perspective warp should be very conservative
assert config.perspective_warp.params.get("max_warp", 0.02) <= 0.05
# Noise should be subtle
noise_std = config.gaussian_noise.params.get("std", (5, 15))
if isinstance(noise_std, tuple):
assert noise_std[1] <= 20 # Max std should be low
def test_immutability_between_instances(self) -> None:
"""Test that config instances are independent."""
from shared.augmentation.config import AugmentationConfig
config1 = AugmentationConfig()
config2 = AugmentationConfig()
# Modifying one should not affect the other
config1.wrinkle.params["test"] = "value"
assert "test" not in config2.wrinkle.params
class TestAugmentationConfigValidation:
"""Tests for configuration validation."""
def test_probability_range_validation(self) -> None:
"""Test that probability values are validated."""
from shared.augmentation.config import AugmentationParams
# Valid range
params = AugmentationParams(probability=0.5)
assert params.probability == 0.5
# Edge cases
params_zero = AugmentationParams(probability=0.0)
assert params_zero.probability == 0.0
params_one = AugmentationParams(probability=1.0)
assert params_one.probability == 1.0
def test_config_validate_method(self) -> None:
"""Test the validate method catches invalid configurations."""
from shared.augmentation.config import AugmentationConfig, AugmentationParams
# Invalid probability
config = AugmentationConfig(
wrinkle=AugmentationParams(probability=1.5), # Invalid
)
with pytest.raises(ValueError, match="probability"):
config.validate()
def test_config_validate_negative_probability(self) -> None:
"""Test validation catches negative probability."""
from shared.augmentation.config import AugmentationConfig, AugmentationParams
config = AugmentationConfig(
wrinkle=AugmentationParams(probability=-0.1),
)
with pytest.raises(ValueError, match="probability"):
config.validate()