WIP
This commit is contained in:
347
tests/shared/augmentation/test_base.py
Normal file
347
tests/shared/augmentation/test_base.py
Normal file
@@ -0,0 +1,347 @@
|
||||
"""
|
||||
Tests for augmentation base module.
|
||||
|
||||
TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
class TestAugmentationResult:
|
||||
"""Tests for AugmentationResult dataclass."""
|
||||
|
||||
def test_minimal_result(self) -> None:
|
||||
"""Test creating result with only required field."""
|
||||
from shared.augmentation.base import AugmentationResult
|
||||
|
||||
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
result = AugmentationResult(image=image)
|
||||
|
||||
assert result.image is image
|
||||
assert result.bboxes is None
|
||||
assert result.transform_matrix is None
|
||||
assert result.applied is True
|
||||
assert result.metadata is None
|
||||
|
||||
def test_full_result(self) -> None:
|
||||
"""Test creating result with all fields."""
|
||||
from shared.augmentation.base import AugmentationResult
|
||||
|
||||
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
bboxes = np.array([[0, 0.5, 0.5, 0.1, 0.1]])
|
||||
transform = np.eye(3)
|
||||
metadata = {"applied_transform": "wrinkle"}
|
||||
|
||||
result = AugmentationResult(
|
||||
image=image,
|
||||
bboxes=bboxes,
|
||||
transform_matrix=transform,
|
||||
applied=True,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
assert result.image is image
|
||||
np.testing.assert_array_equal(result.bboxes, bboxes)
|
||||
np.testing.assert_array_equal(result.transform_matrix, transform)
|
||||
assert result.applied is True
|
||||
assert result.metadata == {"applied_transform": "wrinkle"}
|
||||
|
||||
def test_not_applied(self) -> None:
|
||||
"""Test result when augmentation was not applied."""
|
||||
from shared.augmentation.base import AugmentationResult
|
||||
|
||||
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
result = AugmentationResult(image=image, applied=False)
|
||||
|
||||
assert result.applied is False
|
||||
|
||||
|
||||
class TestBaseAugmentation:
|
||||
"""Tests for BaseAugmentation abstract class."""
|
||||
|
||||
def test_cannot_instantiate_directly(self) -> None:
|
||||
"""Test that BaseAugmentation cannot be instantiated."""
|
||||
from shared.augmentation.base import BaseAugmentation
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
BaseAugmentation({}) # type: ignore
|
||||
|
||||
def test_subclass_must_implement_apply(self) -> None:
|
||||
"""Test that subclass must implement apply method."""
|
||||
from shared.augmentation.base import BaseAugmentation
|
||||
|
||||
class IncompleteAugmentation(BaseAugmentation):
|
||||
name = "incomplete"
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
pass
|
||||
|
||||
# Missing apply method
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
IncompleteAugmentation({}) # type: ignore
|
||||
|
||||
def test_subclass_must_implement_validate_params(self) -> None:
|
||||
"""Test that subclass must implement _validate_params."""
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
class IncompleteAugmentation(BaseAugmentation):
|
||||
name = "incomplete"
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
return AugmentationResult(image=image)
|
||||
|
||||
# Missing _validate_params method
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
IncompleteAugmentation({}) # type: ignore
|
||||
|
||||
def test_valid_subclass(self) -> None:
|
||||
"""Test creating a valid subclass."""
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
class DummyAugmentation(BaseAugmentation):
|
||||
name = "dummy"
|
||||
affects_geometry = False
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
pass
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
return AugmentationResult(image=image, bboxes=bboxes)
|
||||
|
||||
aug = DummyAugmentation({"param1": "value1"})
|
||||
|
||||
assert aug.name == "dummy"
|
||||
assert aug.affects_geometry is False
|
||||
assert aug.params == {"param1": "value1"}
|
||||
|
||||
def test_apply_returns_augmentation_result(self) -> None:
|
||||
"""Test that apply returns AugmentationResult."""
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
class DummyAugmentation(BaseAugmentation):
|
||||
name = "dummy"
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
pass
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
# Simple pass-through
|
||||
return AugmentationResult(image=image, bboxes=bboxes)
|
||||
|
||||
aug = DummyAugmentation({})
|
||||
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
bboxes = np.array([[0, 0.5, 0.5, 0.1, 0.1]])
|
||||
|
||||
result = aug.apply(image, bboxes)
|
||||
|
||||
assert isinstance(result, AugmentationResult)
|
||||
assert result.image is image
|
||||
np.testing.assert_array_equal(result.bboxes, bboxes)
|
||||
|
||||
def test_affects_geometry_default(self) -> None:
|
||||
"""Test that affects_geometry defaults to False."""
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
class DummyAugmentation(BaseAugmentation):
|
||||
name = "dummy"
|
||||
# Not setting affects_geometry
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
pass
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
return AugmentationResult(image=image)
|
||||
|
||||
aug = DummyAugmentation({})
|
||||
|
||||
assert aug.affects_geometry is False
|
||||
|
||||
def test_validate_params_called_on_init(self) -> None:
|
||||
"""Test that _validate_params is called during initialization."""
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
validation_called = {"called": False}
|
||||
|
||||
class ValidatingAugmentation(BaseAugmentation):
|
||||
name = "validating"
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
validation_called["called"] = True
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
return AugmentationResult(image=image)
|
||||
|
||||
ValidatingAugmentation({})
|
||||
|
||||
assert validation_called["called"] is True
|
||||
|
||||
def test_validate_params_raises_on_invalid(self) -> None:
|
||||
"""Test that _validate_params can raise ValueError."""
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
class StrictAugmentation(BaseAugmentation):
|
||||
name = "strict"
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
if "required_param" not in self.params:
|
||||
raise ValueError("required_param is required")
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
return AugmentationResult(image=image)
|
||||
|
||||
with pytest.raises(ValueError, match="required_param"):
|
||||
StrictAugmentation({})
|
||||
|
||||
# Should work with required param
|
||||
aug = StrictAugmentation({"required_param": "value"})
|
||||
assert aug.params["required_param"] == "value"
|
||||
|
||||
def test_rng_usage(self) -> None:
|
||||
"""Test that random generator can be passed and used."""
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
class RandomAugmentation(BaseAugmentation):
|
||||
name = "random"
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
pass
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
if rng is None:
|
||||
rng = np.random.default_rng()
|
||||
# Use rng to generate a random value
|
||||
random_value = rng.random()
|
||||
return AugmentationResult(
|
||||
image=image,
|
||||
metadata={"random_value": random_value},
|
||||
)
|
||||
|
||||
aug = RandomAugmentation({})
|
||||
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
|
||||
# With same seed, should get same result
|
||||
rng1 = np.random.default_rng(42)
|
||||
rng2 = np.random.default_rng(42)
|
||||
|
||||
result1 = aug.apply(image, rng=rng1)
|
||||
result2 = aug.apply(image, rng=rng2)
|
||||
|
||||
assert result1.metadata is not None
|
||||
assert result2.metadata is not None
|
||||
assert result1.metadata["random_value"] == result2.metadata["random_value"]
|
||||
|
||||
|
||||
class TestAugmentationResultImmutability:
|
||||
"""Tests for ensuring result doesn't mutate input."""
|
||||
|
||||
def test_image_not_modified(self) -> None:
|
||||
"""Test that original image is not modified."""
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
class ModifyingAugmentation(BaseAugmentation):
|
||||
name = "modifying"
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
pass
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
# Should copy before modifying
|
||||
modified = image.copy()
|
||||
modified[:] = 255
|
||||
return AugmentationResult(image=modified)
|
||||
|
||||
aug = ModifyingAugmentation({})
|
||||
original = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
original_copy = original.copy()
|
||||
|
||||
result = aug.apply(original)
|
||||
|
||||
# Original should be unchanged
|
||||
np.testing.assert_array_equal(original, original_copy)
|
||||
# Result should be modified
|
||||
assert np.all(result.image == 255)
|
||||
|
||||
def test_bboxes_not_modified(self) -> None:
|
||||
"""Test that original bboxes are not modified."""
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
class BboxModifyingAugmentation(BaseAugmentation):
|
||||
name = "bbox_modifying"
|
||||
affects_geometry = True
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
pass
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
if bboxes is not None:
|
||||
# Should copy before modifying
|
||||
modified_bboxes = bboxes.copy()
|
||||
modified_bboxes[:, 1:] *= 0.5 # Scale down
|
||||
return AugmentationResult(image=image, bboxes=modified_bboxes)
|
||||
return AugmentationResult(image=image)
|
||||
|
||||
aug = BboxModifyingAugmentation({})
|
||||
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
original_bboxes = np.array([[0, 0.5, 0.5, 0.2, 0.2]], dtype=np.float32)
|
||||
original_bboxes_copy = original_bboxes.copy()
|
||||
|
||||
result = aug.apply(image, original_bboxes)
|
||||
|
||||
# Original should be unchanged
|
||||
np.testing.assert_array_equal(original_bboxes, original_bboxes_copy)
|
||||
# Result should be modified
|
||||
assert result.bboxes is not None
|
||||
np.testing.assert_array_almost_equal(
|
||||
result.bboxes, np.array([[0, 0.25, 0.25, 0.1, 0.1]])
|
||||
)
|
||||
Reference in New Issue
Block a user