Files
invoice-master-poc-v2/tests/shared/augmentation/test_base.py
Yaojia Wang 33ada0350d WIP
2026-01-30 00:44:21 +01:00

348 lines
12 KiB
Python

"""
Tests for augmentation base module.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
from typing import Any
from unittest.mock import MagicMock
import numpy as np
import pytest
class TestAugmentationResult:
    """Behavioural checks for the ``AugmentationResult`` dataclass.

    The module under test is imported inside each test (as throughout this
    file), so collection still succeeds if the import itself is broken.
    """

    def test_minimal_result(self) -> None:
        """Only ``image`` is required; every other field takes its default."""
        from shared.augmentation.base import AugmentationResult

        blank = np.zeros((100, 100, 3), dtype=np.uint8)

        res = AugmentationResult(image=blank)

        # The image is stored by reference, not copied.
        assert res.image is blank
        # Optional payloads default to None ...
        for attr in ("bboxes", "transform_matrix", "metadata"):
            assert getattr(res, attr) is None
        # ... and an augmentation counts as applied unless told otherwise.
        assert res.applied is True

    def test_full_result(self) -> None:
        """Every field passed to the constructor is exposed unchanged."""
        from shared.augmentation.base import AugmentationResult

        blank = np.zeros((100, 100, 3), dtype=np.uint8)
        boxes = np.array([[0, 0.5, 0.5, 0.1, 0.1]])
        identity = np.eye(3)
        info = {"applied_transform": "wrinkle"}

        res = AugmentationResult(
            image=blank,
            bboxes=boxes,
            transform_matrix=identity,
            applied=True,
            metadata=info,
        )

        assert res.image is blank
        np.testing.assert_array_equal(res.bboxes, boxes)
        np.testing.assert_array_equal(res.transform_matrix, identity)
        assert res.applied is True
        assert res.metadata == {"applied_transform": "wrinkle"}

    def test_not_applied(self) -> None:
        """An explicit ``applied=False`` is preserved on the result."""
        from shared.augmentation.base import AugmentationResult

        res = AugmentationResult(
            image=np.zeros((100, 100, 3), dtype=np.uint8),
            applied=False,
        )

        assert res.applied is False
class TestBaseAugmentation:
    """Contract tests for the ``BaseAugmentation`` abstract base class.

    Each test declares a small throwaway subclass to probe what the base
    class enforces:
    - abstract ``apply`` / ``_validate_params``;
    - storage of ``params``;
    - validation at construction time;
    - default class attributes.
    """

    def test_cannot_instantiate_directly(self) -> None:
        """Instantiating the base class itself raises ``TypeError``."""
        from shared.augmentation.base import BaseAugmentation

        with pytest.raises(TypeError):
            BaseAugmentation({})  # type: ignore

    def test_subclass_must_implement_apply(self) -> None:
        """A subclass that omits ``apply`` cannot be instantiated."""
        from shared.augmentation.base import BaseAugmentation

        class NoApply(BaseAugmentation):
            name = "incomplete"

            def _validate_params(self) -> None:
                pass

            # ``apply`` is intentionally left undefined.

        with pytest.raises(TypeError):
            NoApply({})  # type: ignore

    def test_subclass_must_implement_validate_params(self) -> None:
        """A subclass that omits ``_validate_params`` cannot be instantiated."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class NoValidation(BaseAugmentation):
            name = "incomplete"

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                return AugmentationResult(image=image)

            # ``_validate_params`` is intentionally left undefined.

        with pytest.raises(TypeError):
            NoValidation({})  # type: ignore

    def test_valid_subclass(self) -> None:
        """A subclass implementing both hooks builds and exposes its config."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class PassThrough(BaseAugmentation):
            name = "dummy"
            affects_geometry = False

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                return AugmentationResult(image=image, bboxes=bboxes)

        config = {"param1": "value1"}
        instance = PassThrough(config)

        assert instance.name == "dummy"
        assert instance.affects_geometry is False
        assert instance.params == {"param1": "value1"}

    def test_apply_returns_augmentation_result(self) -> None:
        """``apply`` hands back an ``AugmentationResult`` carrying its outputs."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class PassThrough(BaseAugmentation):
            name = "dummy"

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                # Echo the inputs back untouched.
                return AugmentationResult(image=image, bboxes=bboxes)

        blank = np.zeros((100, 100, 3), dtype=np.uint8)
        boxes = np.array([[0, 0.5, 0.5, 0.1, 0.1]])

        outcome = PassThrough({}).apply(blank, boxes)

        assert isinstance(outcome, AugmentationResult)
        assert outcome.image is blank
        np.testing.assert_array_equal(outcome.bboxes, boxes)

    def test_affects_geometry_default(self) -> None:
        """``affects_geometry`` is False when a subclass does not override it."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class NoGeometryFlag(BaseAugmentation):
            name = "dummy"
            # ``affects_geometry`` is deliberately not overridden here.

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                return AugmentationResult(image=image)

        assert NoGeometryFlag({}).affects_geometry is False

    def test_validate_params_called_on_init(self) -> None:
        """Constructing an augmentation runs ``_validate_params``."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        call_count = 0

        class Recording(BaseAugmentation):
            name = "validating"

            def _validate_params(self) -> None:
                # Record the call on the enclosing test's counter.
                nonlocal call_count
                call_count += 1

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                return AugmentationResult(image=image)

        Recording({})

        # Construction alone must have triggered validation at least once.
        assert call_count > 0

    def test_validate_params_raises_on_invalid(self) -> None:
        """A ``ValueError`` raised by ``_validate_params`` aborts construction."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class Strict(BaseAugmentation):
            name = "strict"

            def _validate_params(self) -> None:
                if "required_param" in self.params:
                    return
                raise ValueError("required_param is required")

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                return AugmentationResult(image=image)

        with pytest.raises(ValueError, match="required_param"):
            Strict({})

        # Supplying the key lets construction succeed.
        accepted = Strict({"required_param": "value"})
        assert accepted.params["required_param"] == "value"

    def test_rng_usage(self) -> None:
        """A caller-supplied generator drives ``apply`` deterministically."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class Seeded(BaseAugmentation):
            name = "random"

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                generator = rng if rng is not None else np.random.default_rng()
                # Expose a single draw so callers can compare runs.
                return AugmentationResult(
                    image=image,
                    metadata={"random_value": generator.random()},
                )

        augmentation = Seeded({})
        blank = np.zeros((100, 100, 3), dtype=np.uint8)

        # Two independently created generators with the same seed must agree.
        first, second = (
            augmentation.apply(blank, rng=np.random.default_rng(seed))
            for seed in (42, 42)
        )

        assert first.metadata is not None
        assert second.metadata is not None
        assert first.metadata["random_value"] == second.metadata["random_value"]
class TestAugmentationResultImmutability:
    """Checks that augmentations leave the caller's input arrays untouched.

    Despite the class name, these tests concern the inputs handed to
    ``apply``, not immutability of the ``AugmentationResult`` object.
    """

    def test_image_not_modified(self) -> None:
        """Writing to a copy of the image leaves the caller's array unchanged."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class FillWhite(BaseAugmentation):
            name = "modifying"

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                # Copy first so the caller's array is never written to.
                out = image.copy()
                out.fill(255)
                return AugmentationResult(image=out)

        source = np.zeros((100, 100, 3), dtype=np.uint8)
        snapshot = source.copy()

        outcome = FillWhite({}).apply(source)

        # The input is untouched ...
        np.testing.assert_array_equal(source, snapshot)
        # ... while the returned image carries the modification.
        assert (outcome.image == 255).all()

    def test_bboxes_not_modified(self) -> None:
        """Scaling a copy of the boxes leaves the caller's boxes unchanged."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class HalveBoxes(BaseAugmentation):
            name = "bbox_modifying"
            affects_geometry = True

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                if bboxes is None:
                    return AugmentationResult(image=image)
                # Copy, then halve columns 1..4.
                # Column 0 (presumably the class id) is left alone.
                scaled = bboxes.copy()
                scaled[:, 1:] *= 0.5
                return AugmentationResult(image=image, bboxes=scaled)

        blank = np.zeros((100, 100, 3), dtype=np.uint8)
        boxes = np.array([[0, 0.5, 0.5, 0.2, 0.2]], dtype=np.float32)
        boxes_before = boxes.copy()

        outcome = HalveBoxes({}).apply(blank, boxes)

        # The input boxes are untouched ...
        np.testing.assert_array_equal(boxes, boxes_before)
        # ... while the returned boxes are scaled.
        assert outcome.bboxes is not None
        np.testing.assert_array_almost_equal(
            outcome.bboxes, np.array([[0, 0.25, 0.25, 0.1, 0.1]])
        )