"""
Tests for augmentation configuration module.

TDD Phase 1: RED - Write tests first, then implement to pass.
"""

from typing import Any

import pytest


class TestAugmentationParams:
    """Tests for AugmentationParams dataclass."""

    def test_default_values(self) -> None:
        """Test default parameter values."""
        from shared.augmentation.config import AugmentationParams

        params = AugmentationParams()

        assert params.enabled is False
        assert params.probability == 0.5
        assert params.params == {}

    def test_custom_values(self) -> None:
        """Test creating params with custom values."""
        from shared.augmentation.config import AugmentationParams

        params = AugmentationParams(
            enabled=True,
            probability=0.8,
            params={"intensity": 0.5, "num_wrinkles": (2, 5)},
        )

        assert params.enabled is True
        assert params.probability == 0.8
        assert params.params["intensity"] == 0.5
        assert params.params["num_wrinkles"] == (2, 5)

    def test_immutability_params_dict(self) -> None:
        """Test that params dict is independent between instances."""
        from shared.augmentation.config import AugmentationParams

        params1 = AugmentationParams()
        params2 = AugmentationParams()

        # Modifying one should not affect the other (guards against a shared
        # mutable default for ``params``).
        params1.params["test"] = "value"

        assert "test" not in params2.params

    def test_to_dict(self) -> None:
        """Test conversion to dictionary."""
        from shared.augmentation.config import AugmentationParams

        params = AugmentationParams(
            enabled=True,
            probability=0.7,
            params={"key": "value"},
        )

        result = params.to_dict()

        assert result == {
            "enabled": True,
            "probability": 0.7,
            "params": {"key": "value"},
        }

    def test_from_dict(self) -> None:
        """Test creation from dictionary."""
        from shared.augmentation.config import AugmentationParams

        data = {
            "enabled": True,
            "probability": 0.6,
            "params": {"intensity": 0.3},
        }

        params = AugmentationParams.from_dict(data)

        assert params.enabled is True
        assert params.probability == 0.6
        assert params.params == {"intensity": 0.3}

    def test_from_dict_with_defaults(self) -> None:
        """Test creation from partial dictionary uses defaults."""
        from shared.augmentation.config import AugmentationParams

        data: dict[str, Any] = {"enabled": True}

        params = AugmentationParams.from_dict(data)

        assert params.enabled is True
        assert params.probability == 0.5  # default
        assert params.params == {}  # default


class TestAugmentationConfig:
    """Tests for AugmentationConfig dataclass."""

    def test_default_values(self) -> None:
        """Test that all augmentation types have defaults."""
        from shared.augmentation.config import AugmentationConfig

        config = AugmentationConfig()

        # All augmentation types should exist
        augmentation_types = [
            "perspective_warp",
            "wrinkle",
            "edge_damage",
            "stain",
            "lighting_variation",
            "shadow",
            "gaussian_blur",
            "motion_blur",
            "gaussian_noise",
            "salt_pepper",
            "paper_texture",
            "scanner_artifacts",
        ]

        for aug_type in augmentation_types:
            assert hasattr(config, aug_type), f"Missing augmentation type: {aug_type}"
            params = getattr(config, aug_type)
            assert hasattr(params, "enabled")
            assert hasattr(params, "probability")
            assert hasattr(params, "params")

    def test_global_settings_defaults(self) -> None:
        """Test global settings default values."""
        from shared.augmentation.config import AugmentationConfig

        config = AugmentationConfig()

        assert config.preserve_bboxes is True
        assert config.seed is None

    def test_custom_seed(self) -> None:
        """Test setting custom seed for reproducibility."""
        from shared.augmentation.config import AugmentationConfig

        config = AugmentationConfig(seed=42)

        assert config.seed == 42

    def test_to_dict(self) -> None:
        """Test conversion to dictionary."""
        from shared.augmentation.config import AugmentationConfig

        config = AugmentationConfig(seed=123, preserve_bboxes=False)

        result = config.to_dict()

        assert isinstance(result, dict)
        assert result["seed"] == 123
        assert result["preserve_bboxes"] is False
        assert "perspective_warp" in result
        assert "wrinkle" in result

    def test_from_dict(self) -> None:
        """Test creation from dictionary."""
        from shared.augmentation.config import AugmentationConfig

        data = {
            "seed": 456,
            "preserve_bboxes": False,
            "wrinkle": {
                "enabled": True,
                "probability": 0.8,
                "params": {"intensity": 0.5},
            },
        }

        config = AugmentationConfig.from_dict(data)

        assert config.seed == 456
        assert config.preserve_bboxes is False
        assert config.wrinkle.enabled is True
        assert config.wrinkle.probability == 0.8
        assert config.wrinkle.params["intensity"] == 0.5

    def test_from_dict_with_partial_data(self) -> None:
        """Test creation from partial dictionary uses defaults."""
        from shared.augmentation.config import AugmentationConfig

        data: dict[str, Any] = {
            "wrinkle": {"enabled": True},
        }

        config = AugmentationConfig.from_dict(data)

        # Explicitly set value
        assert config.wrinkle.enabled is True

        # Default values
        assert config.preserve_bboxes is True
        assert config.seed is None
        assert config.gaussian_blur.enabled is False

    def test_get_enabled_augmentations(self) -> None:
        """Test getting list of enabled augmentations."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams

        config = AugmentationConfig(
            wrinkle=AugmentationParams(enabled=True),
            stain=AugmentationParams(enabled=True),
            gaussian_blur=AugmentationParams(enabled=False),
        )

        enabled = config.get_enabled_augmentations()

        assert "wrinkle" in enabled
        assert "stain" in enabled
        assert "gaussian_blur" not in enabled

    def test_document_safe_defaults(self) -> None:
        """Test that default params are document-safe (conservative)."""
        from shared.augmentation.config import AugmentationConfig

        config = AugmentationConfig()

        # Perspective warp should be very conservative. A missing key falls
        # back to 0.02, so implementations without an explicit default pass.
        assert config.perspective_warp.params.get("max_warp", 0.02) <= 0.05

        # Noise should be subtle. ``std`` may be a range stored as a tuple, a
        # list (JSON/YAML arrays decode to lists), or a single scalar. Check
        # the largest value in every case rather than silently skipping the
        # assertion for non-tuple forms.
        noise_std = config.gaussian_noise.params.get("std", (5, 15))
        if isinstance(noise_std, (tuple, list)):
            assert max(noise_std) <= 20  # Max std should be low
        else:
            assert noise_std <= 20

    def test_immutability_between_instances(self) -> None:
        """Test that config instances are independent."""
        from shared.augmentation.config import AugmentationConfig

        config1 = AugmentationConfig()
        config2 = AugmentationConfig()

        # Modifying one should not affect the other
        config1.wrinkle.params["test"] = "value"

        assert "test" not in config2.wrinkle.params


class TestAugmentationConfigValidation:
    """Tests for configuration validation."""

    def test_probability_range_validation(self) -> None:
        """Test that in-range and boundary probabilities are accepted at construction.

        Construction itself does not reject out-of-range values; that is the
        job of ``AugmentationConfig.validate()``, exercised by the tests below.
        """
        from shared.augmentation.config import AugmentationParams

        # Valid range
        params = AugmentationParams(probability=0.5)
        assert params.probability == 0.5

        # Edge cases
        params_zero = AugmentationParams(probability=0.0)
        assert params_zero.probability == 0.0

        params_one = AugmentationParams(probability=1.0)
        assert params_one.probability == 1.0

    def test_config_validate_method(self) -> None:
        """Test the validate method catches invalid configurations."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams

        # Invalid probability
        config = AugmentationConfig(
            wrinkle=AugmentationParams(probability=1.5),  # Invalid
        )

        with pytest.raises(ValueError, match="probability"):
            config.validate()

    def test_config_validate_negative_probability(self) -> None:
        """Test validation catches negative probability."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams

        config = AugmentationConfig(
            wrinkle=AugmentationParams(probability=-0.1),
        )

        with pytest.raises(ValueError, match="probability"):
            config.validate()

    def test_config_validate_accepts_boundary_probabilities(self) -> None:
        """Test that validate() accepts the default config and inclusive bounds."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams

        # The default configuration must always be valid.
        AugmentationConfig().validate()

        # 0.0 and 1.0 are legal probabilities and must not be rejected.
        config = AugmentationConfig(
            wrinkle=AugmentationParams(probability=0.0),
            stain=AugmentationParams(probability=1.0),
        )
        config.validate()