"""Tests for the augmentation base module (``shared.augmentation.base``).

Covers:

* ``AugmentationResult``: required/optional fields and their defaults.
* ``BaseAugmentation``: cannot be instantiated directly, subclasses must
  implement both ``apply`` and ``_validate_params``, ``_validate_params`` is
  called during construction, and ``affects_geometry`` defaults to False.
* The convention that augmentations return new arrays rather than mutating
  the image / bboxes passed in.

The module under test is imported inside each test function, so an import
failure surfaces as individual test failures rather than a collection error.
"""

# NOTE(review): ``Any`` and ``MagicMock`` are not used by the tests visible
# here; they are kept in case other parts of the file rely on them.
from typing import Any
from unittest.mock import MagicMock

import numpy as np
import pytest


class TestAugmentationResult:
    """Tests for AugmentationResult dataclass."""

    def test_minimal_result(self) -> None:
        """Test creating result with only required field."""
        from shared.augmentation.base import AugmentationResult

        image = np.zeros((100, 100, 3), dtype=np.uint8)
        result = AugmentationResult(image=image)

        assert result.image is image
        assert result.bboxes is None
        assert result.transform_matrix is None
        assert result.applied is True
        assert result.metadata is None

    def test_full_result(self) -> None:
        """Test creating result with all fields."""
        from shared.augmentation.base import AugmentationResult

        image = np.zeros((100, 100, 3), dtype=np.uint8)
        # One bbox row: [class_id, x, y, w, h]
        bboxes = np.array([[0, 0.5, 0.5, 0.1, 0.1]])
        transform = np.eye(3)
        metadata = {"applied_transform": "wrinkle"}

        result = AugmentationResult(
            image=image,
            bboxes=bboxes,
            transform_matrix=transform,
            applied=True,
            metadata=metadata,
        )

        assert result.image is image
        np.testing.assert_array_equal(result.bboxes, bboxes)
        np.testing.assert_array_equal(result.transform_matrix, transform)
        assert result.applied is True
        assert result.metadata == {"applied_transform": "wrinkle"}

    def test_not_applied(self) -> None:
        """Test result when augmentation was not applied."""
        from shared.augmentation.base import AugmentationResult

        image = np.zeros((100, 100, 3), dtype=np.uint8)
        result = AugmentationResult(image=image, applied=False)

        assert result.applied is False


class TestBaseAugmentation:
    """Tests for BaseAugmentation abstract class."""

    def test_cannot_instantiate_directly(self) -> None:
        """Test that BaseAugmentation cannot be instantiated."""
        from shared.augmentation.base import BaseAugmentation

        with pytest.raises(TypeError):
            BaseAugmentation({})  # type: ignore

    def test_subclass_must_implement_apply(self) -> None:
        """Test that subclass must implement apply method."""
        from shared.augmentation.base import BaseAugmentation

        class IncompleteAugmentation(BaseAugmentation):
            name = "incomplete"

            def _validate_params(self) -> None:
                pass

            # Missing apply method

        with pytest.raises(TypeError):
            IncompleteAugmentation({})  # type: ignore

    def test_subclass_must_implement_validate_params(self) -> None:
        """Test that subclass must implement _validate_params."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class IncompleteAugmentation(BaseAugmentation):
            name = "incomplete"

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                return AugmentationResult(image=image)

            # Missing _validate_params method

        with pytest.raises(TypeError):
            IncompleteAugmentation({})  # type: ignore

    def test_valid_subclass(self) -> None:
        """Test creating a valid subclass."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class DummyAugmentation(BaseAugmentation):
            name = "dummy"
            affects_geometry = False

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                return AugmentationResult(image=image, bboxes=bboxes)

        aug = DummyAugmentation({"param1": "value1"})

        assert aug.name == "dummy"
        assert aug.affects_geometry is False
        assert aug.params == {"param1": "value1"}

    def test_apply_returns_augmentation_result(self) -> None:
        """Test that apply returns AugmentationResult."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class DummyAugmentation(BaseAugmentation):
            name = "dummy"

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                # Simple pass-through
                return AugmentationResult(image=image, bboxes=bboxes)

        aug = DummyAugmentation({})
        image = np.zeros((100, 100, 3), dtype=np.uint8)
        bboxes = np.array([[0, 0.5, 0.5, 0.1, 0.1]])

        result = aug.apply(image, bboxes)

        assert isinstance(result, AugmentationResult)
        assert result.image is image
        np.testing.assert_array_equal(result.bboxes, bboxes)

    def test_affects_geometry_default(self) -> None:
        """Test that affects_geometry defaults to False."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class DummyAugmentation(BaseAugmentation):
            name = "dummy"
            # Not setting affects_geometry

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                return AugmentationResult(image=image)

        aug = DummyAugmentation({})

        assert aug.affects_geometry is False

    def test_validate_params_called_on_init(self) -> None:
        """Test that _validate_params is called during initialization."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        # Mutable container so the nested method can record the call.
        validation_called = {"called": False}

        class ValidatingAugmentation(BaseAugmentation):
            name = "validating"

            def _validate_params(self) -> None:
                validation_called["called"] = True

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                return AugmentationResult(image=image)

        ValidatingAugmentation({})

        assert validation_called["called"] is True

    def test_validate_params_raises_on_invalid(self) -> None:
        """Test that _validate_params can raise ValueError."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class StrictAugmentation(BaseAugmentation):
            name = "strict"

            def _validate_params(self) -> None:
                if "required_param" not in self.params:
                    raise ValueError("required_param is required")

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                return AugmentationResult(image=image)

        with pytest.raises(ValueError, match="required_param"):
            StrictAugmentation({})

        # Should work with required param
        aug = StrictAugmentation({"required_param": "value"})
        assert aug.params["required_param"] == "value"

    def test_rng_usage(self) -> None:
        """Test that random generator can be passed and used.

        Equal seeds must give equal draws, and a different seed must give a
        different draw. The second check fails if ``rng`` were ignored and a
        constant were returned.
        """
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class RandomAugmentation(BaseAugmentation):
            name = "random"

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                if rng is None:
                    rng = np.random.default_rng()

                # Use rng to generate a random value
                random_value = rng.random()
                return AugmentationResult(
                    image=image,
                    metadata={"random_value": random_value},
                )

        aug = RandomAugmentation({})
        image = np.zeros((100, 100, 3), dtype=np.uint8)

        # With same seed, should get same result
        rng1 = np.random.default_rng(42)
        rng2 = np.random.default_rng(42)

        result1 = aug.apply(image, rng=rng1)
        result2 = aug.apply(image, rng=rng2)

        assert result1.metadata is not None
        assert result2.metadata is not None
        assert result1.metadata["random_value"] == result2.metadata["random_value"]

        # A different seed yields a different draw, showing the value is
        # actually taken from the supplied generator.
        result3 = aug.apply(image, rng=np.random.default_rng(43))

        assert result3.metadata is not None
        assert result3.metadata["random_value"] != result1.metadata["random_value"]


class TestAugmentationResultImmutability:
    """Tests for ensuring result doesn't mutate input."""

    def test_image_not_modified(self) -> None:
        """Test that original image is not modified."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class ModifyingAugmentation(BaseAugmentation):
            name = "modifying"

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                # Should copy before modifying
                modified = image.copy()
                modified[:] = 255
                return AugmentationResult(image=modified)

        aug = ModifyingAugmentation({})
        original = np.zeros((100, 100, 3), dtype=np.uint8)
        original_copy = original.copy()

        result = aug.apply(original)

        # Original should be unchanged
        np.testing.assert_array_equal(original, original_copy)
        # Result should be modified
        assert np.all(result.image == 255)

    def test_bboxes_not_modified(self) -> None:
        """Test that original bboxes are not modified."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class BboxModifyingAugmentation(BaseAugmentation):
            name = "bbox_modifying"
            affects_geometry = True

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                if bboxes is not None:
                    # Should copy before modifying
                    modified_bboxes = bboxes.copy()
                    # Scale every column except the first (class id) by 0.5
                    modified_bboxes[:, 1:] *= 0.5
                    return AugmentationResult(image=image, bboxes=modified_bboxes)
                return AugmentationResult(image=image)

        aug = BboxModifyingAugmentation({})
        image = np.zeros((100, 100, 3), dtype=np.uint8)
        original_bboxes = np.array([[0, 0.5, 0.5, 0.2, 0.2]], dtype=np.float32)
        original_bboxes_copy = original_bboxes.copy()

        result = aug.apply(image, original_bboxes)

        # Original should be unchanged
        np.testing.assert_array_equal(original_bboxes, original_bboxes_copy)
        # Result should be modified
        assert result.bboxes is not None
        np.testing.assert_array_almost_equal(
            result.bboxes, np.array([[0, 0.25, 0.25, 0.1, 0.1]])
        )