"""
Tests for augmentation base module.

TDD Phase 1 (RED): these tests were written first; the implementation in
``shared.augmentation.base`` is developed to make them pass.
"""
|
|
|
|
from typing import Any
|
|
from unittest.mock import MagicMock
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
|
|
class TestAugmentationResult:
    """Unit tests for the ``AugmentationResult`` dataclass."""

    def test_minimal_result(self) -> None:
        """Only ``image`` is required; every other field keeps its default."""
        from shared.augmentation.base import AugmentationResult

        img = np.zeros((100, 100, 3), dtype=np.uint8)
        res = AugmentationResult(image=img)

        # The image is stored by reference, not copied.
        assert res.image is img
        assert res.bboxes is None
        assert res.transform_matrix is None
        assert res.applied is True
        assert res.metadata is None

    def test_full_result(self) -> None:
        """Every field handed to the constructor is stored on the result."""
        from shared.augmentation.base import AugmentationResult

        img = np.zeros((100, 100, 3), dtype=np.uint8)
        boxes = np.array([[0, 0.5, 0.5, 0.1, 0.1]])
        matrix = np.eye(3)
        meta = {"applied_transform": "wrinkle"}

        res = AugmentationResult(
            image=img,
            bboxes=boxes,
            transform_matrix=matrix,
            applied=True,
            metadata=meta,
        )

        assert res.image is img
        np.testing.assert_array_equal(res.bboxes, boxes)
        np.testing.assert_array_equal(res.transform_matrix, matrix)
        assert res.applied is True
        assert res.metadata == {"applied_transform": "wrinkle"}

    def test_not_applied(self) -> None:
        """An explicit ``applied=False`` is preserved on the result."""
        from shared.augmentation.base import AugmentationResult

        res = AugmentationResult(
            image=np.zeros((100, 100, 3), dtype=np.uint8),
            applied=False,
        )

        assert res.applied is False
|
|
|
|
|
|
class TestBaseAugmentation:
    """Tests for the ``BaseAugmentation`` abstract base class contract."""

    def test_cannot_instantiate_directly(self) -> None:
        """Instantiating ``BaseAugmentation`` itself raises ``TypeError``."""
        from shared.augmentation.base import BaseAugmentation

        # ``match`` stops this from passing on an unrelated TypeError (e.g. a
        # constructor-signature mismatch). Assumes abstractness is enforced via
        # ``abc``, whose instantiation error message mentions "abstract".
        with pytest.raises(TypeError, match="abstract"):
            BaseAugmentation({})  # type: ignore

    def test_subclass_must_implement_apply(self) -> None:
        """A subclass without ``apply`` cannot be instantiated."""
        from shared.augmentation.base import BaseAugmentation

        class IncompleteAugmentation(BaseAugmentation):
            name = "incomplete"

            def _validate_params(self) -> None:
                pass

            # Missing apply method

        # ``abc`` names the missing abstract method in its error message, so
        # matching on it proves the failure is for the right reason.
        with pytest.raises(TypeError, match="apply"):
            IncompleteAugmentation({})  # type: ignore

    def test_subclass_must_implement_validate_params(self) -> None:
        """A subclass without ``_validate_params`` cannot be instantiated."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class IncompleteAugmentation(BaseAugmentation):
            name = "incomplete"

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                return AugmentationResult(image=image)

            # Missing _validate_params method

        # Match the missing method's name so an unrelated TypeError fails.
        with pytest.raises(TypeError, match="_validate_params"):
            IncompleteAugmentation({})  # type: ignore

    def test_valid_subclass(self) -> None:
        """A complete subclass exposes ``name``, ``affects_geometry``, ``params``."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class DummyAugmentation(BaseAugmentation):
            name = "dummy"
            affects_geometry = False

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                return AugmentationResult(image=image, bboxes=bboxes)

        aug = DummyAugmentation({"param1": "value1"})

        assert aug.name == "dummy"
        assert aug.affects_geometry is False
        assert aug.params == {"param1": "value1"}

    def test_apply_returns_augmentation_result(self) -> None:
        """``apply`` returns an ``AugmentationResult`` carrying its outputs."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class DummyAugmentation(BaseAugmentation):
            name = "dummy"

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                # Simple pass-through
                return AugmentationResult(image=image, bboxes=bboxes)

        aug = DummyAugmentation({})
        image = np.zeros((100, 100, 3), dtype=np.uint8)
        bboxes = np.array([[0, 0.5, 0.5, 0.1, 0.1]])

        result = aug.apply(image, bboxes)

        assert isinstance(result, AugmentationResult)
        assert result.image is image
        np.testing.assert_array_equal(result.bboxes, bboxes)

    def test_affects_geometry_default(self) -> None:
        """``affects_geometry`` defaults to ``False`` when a subclass omits it."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class DummyAugmentation(BaseAugmentation):
            name = "dummy"
            # Not setting affects_geometry

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                return AugmentationResult(image=image)

        aug = DummyAugmentation({})

        assert aug.affects_geometry is False

    def test_validate_params_called_on_init(self) -> None:
        """``_validate_params`` runs as part of construction."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        # A mutable container lets the nested method record the call without
        # ``nonlocal``.
        validation_called = {"called": False}

        class ValidatingAugmentation(BaseAugmentation):
            name = "validating"

            def _validate_params(self) -> None:
                validation_called["called"] = True

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                return AugmentationResult(image=image)

        ValidatingAugmentation({})

        assert validation_called["called"] is True

    def test_validate_params_raises_on_invalid(self) -> None:
        """A ``ValueError`` from ``_validate_params`` propagates out of ``__init__``."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class StrictAugmentation(BaseAugmentation):
            name = "strict"

            def _validate_params(self) -> None:
                if "required_param" not in self.params:
                    raise ValueError("required_param is required")

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                return AugmentationResult(image=image)

        with pytest.raises(ValueError, match="required_param"):
            StrictAugmentation({})

        # Should work with required param
        aug = StrictAugmentation({"required_param": "value"})
        assert aug.params["required_param"] == "value"

    def test_rng_usage(self) -> None:
        """A caller-supplied ``np.random.Generator`` drives ``apply``'s randomness."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class RandomAugmentation(BaseAugmentation):
            name = "random"

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                if rng is None:
                    rng = np.random.default_rng()
                # Use rng to generate a random value
                random_value = rng.random()
                return AugmentationResult(
                    image=image,
                    metadata={"random_value": random_value},
                )

        aug = RandomAugmentation({})
        image = np.zeros((100, 100, 3), dtype=np.uint8)

        # With same seed, should get same result
        rng1 = np.random.default_rng(42)
        rng2 = np.random.default_rng(42)

        result1 = aug.apply(image, rng=rng1)
        result2 = aug.apply(image, rng=rng2)

        assert result1.metadata is not None
        assert result2.metadata is not None
        assert result1.metadata["random_value"] == result2.metadata["random_value"]

        # The supplied generator must actually be consumed: the value equals
        # the first draw of an independent generator seeded identically...
        expected = np.random.default_rng(42).random()
        assert result1.metadata["random_value"] == expected

        # ...and a different seed yields a different value.
        result3 = aug.apply(image, rng=np.random.default_rng(7))
        assert result3.metadata is not None
        assert result3.metadata["random_value"] != result1.metadata["random_value"]

        # Without an explicit rng, the fallback generator still produces a
        # value in ``Generator.random``'s half-open range [0, 1).
        result4 = aug.apply(image)
        assert result4.metadata is not None
        assert 0.0 <= result4.metadata["random_value"] < 1.0
|
|
|
|
|
|
class TestAugmentationResultImmutability:
    """Checks that augmentations leave their input arrays untouched."""

    def test_image_not_modified(self) -> None:
        """An augmentation that rewrites pixels must work on a copy."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class ModifyingAugmentation(BaseAugmentation):
            name = "modifying"

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                # Copy first so the caller's array is never written to.
                out = image.copy()
                out[:] = 255
                return AugmentationResult(image=out)

        augmenter = ModifyingAugmentation({})
        source = np.zeros((100, 100, 3), dtype=np.uint8)
        snapshot = source.copy()

        outcome = augmenter.apply(source)

        # The input is unchanged...
        np.testing.assert_array_equal(source, snapshot)
        # ...while the returned image carries the modification.
        assert np.all(outcome.image == 255)

    def test_bboxes_not_modified(self) -> None:
        """An augmentation that rescales boxes must work on a copy."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class BboxModifyingAugmentation(BaseAugmentation):
            name = "bbox_modifying"
            affects_geometry = True

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                if bboxes is None:
                    return AugmentationResult(image=image)
                # Copy first, then halve every column after the class id.
                scaled = bboxes.copy()
                scaled[:, 1:] *= 0.5
                return AugmentationResult(image=image, bboxes=scaled)

        augmenter = BboxModifyingAugmentation({})
        frame = np.zeros((100, 100, 3), dtype=np.uint8)
        source_boxes = np.array([[0, 0.5, 0.5, 0.2, 0.2]], dtype=np.float32)
        snapshot = source_boxes.copy()

        outcome = augmenter.apply(frame, source_boxes)

        # The input boxes are unchanged...
        np.testing.assert_array_equal(source_boxes, snapshot)
        # ...while the returned boxes are scaled.
        assert outcome.bboxes is not None
        np.testing.assert_array_almost_equal(
            outcome.bboxes, np.array([[0, 0.25, 0.25, 0.1, 0.1]])
        )
|