WIP

tests/shared/augmentation/__init__.py (Normal file, 1 line)
@@ -0,0 +1 @@
# Tests for augmentation module

tests/shared/augmentation/test_base.py (Normal file, 347 lines)
@@ -0,0 +1,347 @@
"""
Tests for augmentation base module.

TDD Phase 1: RED - Write tests first, then implement to pass.
"""

from typing import Any
from unittest.mock import MagicMock

import numpy as np
import pytest


class TestAugmentationResult:
    """Tests for AugmentationResult dataclass."""

    def test_minimal_result(self) -> None:
        """Test creating result with only required field."""
        from shared.augmentation.base import AugmentationResult

        image = np.zeros((100, 100, 3), dtype=np.uint8)
        result = AugmentationResult(image=image)

        assert result.image is image
        assert result.bboxes is None
        assert result.transform_matrix is None
        assert result.applied is True
        assert result.metadata is None

    def test_full_result(self) -> None:
        """Test creating result with all fields."""
        from shared.augmentation.base import AugmentationResult

        image = np.zeros((100, 100, 3), dtype=np.uint8)
        bboxes = np.array([[0, 0.5, 0.5, 0.1, 0.1]])
        transform = np.eye(3)
        metadata = {"applied_transform": "wrinkle"}

        result = AugmentationResult(
            image=image,
            bboxes=bboxes,
            transform_matrix=transform,
            applied=True,
            metadata=metadata,
        )

        assert result.image is image
        np.testing.assert_array_equal(result.bboxes, bboxes)
        np.testing.assert_array_equal(result.transform_matrix, transform)
        assert result.applied is True
        assert result.metadata == {"applied_transform": "wrinkle"}

    def test_not_applied(self) -> None:
        """Test result when augmentation was not applied."""
        from shared.augmentation.base import AugmentationResult

        image = np.zeros((100, 100, 3), dtype=np.uint8)
        result = AugmentationResult(image=image, applied=False)

        assert result.applied is False


class TestBaseAugmentation:
    """Tests for BaseAugmentation abstract class."""

    def test_cannot_instantiate_directly(self) -> None:
        """Test that BaseAugmentation cannot be instantiated."""
        from shared.augmentation.base import BaseAugmentation

        with pytest.raises(TypeError):
            BaseAugmentation({})  # type: ignore

    def test_subclass_must_implement_apply(self) -> None:
        """Test that subclass must implement apply method."""
        from shared.augmentation.base import BaseAugmentation

        class IncompleteAugmentation(BaseAugmentation):
            name = "incomplete"

            def _validate_params(self) -> None:
                pass

            # Missing apply method

        with pytest.raises(TypeError):
            IncompleteAugmentation({})  # type: ignore

    def test_subclass_must_implement_validate_params(self) -> None:
        """Test that subclass must implement _validate_params."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class IncompleteAugmentation(BaseAugmentation):
            name = "incomplete"

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                return AugmentationResult(image=image)

            # Missing _validate_params method

        with pytest.raises(TypeError):
            IncompleteAugmentation({})  # type: ignore

    def test_valid_subclass(self) -> None:
        """Test creating a valid subclass."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class DummyAugmentation(BaseAugmentation):
            name = "dummy"
            affects_geometry = False

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                return AugmentationResult(image=image, bboxes=bboxes)

        aug = DummyAugmentation({"param1": "value1"})

        assert aug.name == "dummy"
        assert aug.affects_geometry is False
        assert aug.params == {"param1": "value1"}

    def test_apply_returns_augmentation_result(self) -> None:
        """Test that apply returns AugmentationResult."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class DummyAugmentation(BaseAugmentation):
            name = "dummy"

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                # Simple pass-through
                return AugmentationResult(image=image, bboxes=bboxes)

        aug = DummyAugmentation({})
        image = np.zeros((100, 100, 3), dtype=np.uint8)
        bboxes = np.array([[0, 0.5, 0.5, 0.1, 0.1]])

        result = aug.apply(image, bboxes)

        assert isinstance(result, AugmentationResult)
        assert result.image is image
        np.testing.assert_array_equal(result.bboxes, bboxes)

    def test_affects_geometry_default(self) -> None:
        """Test that affects_geometry defaults to False."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class DummyAugmentation(BaseAugmentation):
            name = "dummy"
            # Not setting affects_geometry

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                return AugmentationResult(image=image)

        aug = DummyAugmentation({})

        assert aug.affects_geometry is False

    def test_validate_params_called_on_init(self) -> None:
        """Test that _validate_params is called during initialization."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        validation_called = {"called": False}

        class ValidatingAugmentation(BaseAugmentation):
            name = "validating"

            def _validate_params(self) -> None:
                validation_called["called"] = True

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                return AugmentationResult(image=image)

        ValidatingAugmentation({})

        assert validation_called["called"] is True

    def test_validate_params_raises_on_invalid(self) -> None:
        """Test that _validate_params can raise ValueError."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class StrictAugmentation(BaseAugmentation):
            name = "strict"

            def _validate_params(self) -> None:
                if "required_param" not in self.params:
                    raise ValueError("required_param is required")

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                return AugmentationResult(image=image)

        with pytest.raises(ValueError, match="required_param"):
            StrictAugmentation({})

        # Should work with required param
        aug = StrictAugmentation({"required_param": "value"})
        assert aug.params["required_param"] == "value"

    def test_rng_usage(self) -> None:
        """Test that random generator can be passed and used."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class RandomAugmentation(BaseAugmentation):
            name = "random"

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                if rng is None:
                    rng = np.random.default_rng()
                # Use rng to generate a random value
                random_value = rng.random()
                return AugmentationResult(
                    image=image,
                    metadata={"random_value": random_value},
                )

        aug = RandomAugmentation({})
        image = np.zeros((100, 100, 3), dtype=np.uint8)

        # With same seed, should get same result
        rng1 = np.random.default_rng(42)
        rng2 = np.random.default_rng(42)

        result1 = aug.apply(image, rng=rng1)
        result2 = aug.apply(image, rng=rng2)

        assert result1.metadata is not None
        assert result2.metadata is not None
        assert result1.metadata["random_value"] == result2.metadata["random_value"]


class TestAugmentationResultImmutability:
    """Tests for ensuring result doesn't mutate input."""

    def test_image_not_modified(self) -> None:
        """Test that original image is not modified."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class ModifyingAugmentation(BaseAugmentation):
            name = "modifying"

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                # Should copy before modifying
                modified = image.copy()
                modified[:] = 255
                return AugmentationResult(image=modified)

        aug = ModifyingAugmentation({})
        original = np.zeros((100, 100, 3), dtype=np.uint8)
        original_copy = original.copy()

        result = aug.apply(original)

        # Original should be unchanged
        np.testing.assert_array_equal(original, original_copy)
        # Result should be modified
        assert np.all(result.image == 255)

    def test_bboxes_not_modified(self) -> None:
        """Test that original bboxes are not modified."""
        from shared.augmentation.base import AugmentationResult, BaseAugmentation

        class BboxModifyingAugmentation(BaseAugmentation):
            name = "bbox_modifying"
            affects_geometry = True

            def _validate_params(self) -> None:
                pass

            def apply(
                self,
                image: np.ndarray,
                bboxes: np.ndarray | None = None,
                rng: np.random.Generator | None = None,
            ) -> AugmentationResult:
                if bboxes is not None:
                    # Should copy before modifying
                    modified_bboxes = bboxes.copy()
                    modified_bboxes[:, 1:] *= 0.5  # Scale down
                    return AugmentationResult(image=image, bboxes=modified_bboxes)
                return AugmentationResult(image=image)

        aug = BboxModifyingAugmentation({})
        image = np.zeros((100, 100, 3), dtype=np.uint8)
        original_bboxes = np.array([[0, 0.5, 0.5, 0.2, 0.2]], dtype=np.float32)
        original_bboxes_copy = original_bboxes.copy()

        result = aug.apply(image, original_bboxes)

        # Original should be unchanged
        np.testing.assert_array_equal(original_bboxes, original_bboxes_copy)
        # Result should be modified
        assert result.bboxes is not None
        np.testing.assert_array_almost_equal(
            result.bboxes, np.array([[0, 0.25, 0.25, 0.1, 0.1]])
        )
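
The base module itself is not part of this commit, so as a reading aid, here is a minimal sketch of what shared/augmentation/base.py would need to look like for the tests above to pass. Every name and default below is inferred from the assertions, not taken from the real implementation.

# Sketch only: fields and defaults inferred from test_base.py.
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any

import numpy as np


@dataclass
class AugmentationResult:
    """Result of applying one augmentation (field set inferred from tests)."""

    image: np.ndarray
    bboxes: np.ndarray | None = None
    transform_matrix: np.ndarray | None = None
    applied: bool = True
    metadata: dict[str, Any] | None = None


class BaseAugmentation(ABC):
    """Abstract base; ABC makes incomplete subclasses raise TypeError on instantiation."""

    name: str = "base"
    affects_geometry: bool = False  # default asserted by test_affects_geometry_default

    def __init__(self, params: dict[str, Any]) -> None:
        self.params = params
        # test_validate_params_called_on_init requires validation at construction time.
        self._validate_params()

    @abstractmethod
    def _validate_params(self) -> None:
        """Raise ValueError for invalid params."""

    @abstractmethod
    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Return a new AugmentationResult; must not mutate its inputs."""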

tests/shared/augmentation/test_config.py (Normal file, 283 lines)
@@ -0,0 +1,283 @@
"""
Tests for augmentation configuration module.

TDD Phase 1: RED - Write tests first, then implement to pass.
"""

from typing import Any

import pytest


class TestAugmentationParams:
    """Tests for AugmentationParams dataclass."""

    def test_default_values(self) -> None:
        """Test default parameter values."""
        from shared.augmentation.config import AugmentationParams

        params = AugmentationParams()

        assert params.enabled is False
        assert params.probability == 0.5
        assert params.params == {}

    def test_custom_values(self) -> None:
        """Test creating params with custom values."""
        from shared.augmentation.config import AugmentationParams

        params = AugmentationParams(
            enabled=True,
            probability=0.8,
            params={"intensity": 0.5, "num_wrinkles": (2, 5)},
        )

        assert params.enabled is True
        assert params.probability == 0.8
        assert params.params["intensity"] == 0.5
        assert params.params["num_wrinkles"] == (2, 5)

    def test_immutability_params_dict(self) -> None:
        """Test that params dict is independent between instances."""
        from shared.augmentation.config import AugmentationParams

        params1 = AugmentationParams()
        params2 = AugmentationParams()

        # Modifying one should not affect the other
        params1.params["test"] = "value"

        assert "test" not in params2.params

    def test_to_dict(self) -> None:
        """Test conversion to dictionary."""
        from shared.augmentation.config import AugmentationParams

        params = AugmentationParams(
            enabled=True,
            probability=0.7,
            params={"key": "value"},
        )

        result = params.to_dict()

        assert result == {
            "enabled": True,
            "probability": 0.7,
            "params": {"key": "value"},
        }

    def test_from_dict(self) -> None:
        """Test creation from dictionary."""
        from shared.augmentation.config import AugmentationParams

        data = {
            "enabled": True,
            "probability": 0.6,
            "params": {"intensity": 0.3},
        }

        params = AugmentationParams.from_dict(data)

        assert params.enabled is True
        assert params.probability == 0.6
        assert params.params == {"intensity": 0.3}

    def test_from_dict_with_defaults(self) -> None:
        """Test creation from partial dictionary uses defaults."""
        from shared.augmentation.config import AugmentationParams

        data: dict[str, Any] = {"enabled": True}

        params = AugmentationParams.from_dict(data)

        assert params.enabled is True
        assert params.probability == 0.5  # default
        assert params.params == {}  # default


class TestAugmentationConfig:
    """Tests for AugmentationConfig dataclass."""

    def test_default_values(self) -> None:
        """Test that all augmentation types have defaults."""
        from shared.augmentation.config import AugmentationConfig

        config = AugmentationConfig()

        # All augmentation types should exist
        augmentation_types = [
            "perspective_warp",
            "wrinkle",
            "edge_damage",
            "stain",
            "lighting_variation",
            "shadow",
            "gaussian_blur",
            "motion_blur",
            "gaussian_noise",
            "salt_pepper",
            "paper_texture",
            "scanner_artifacts",
        ]

        for aug_type in augmentation_types:
            assert hasattr(config, aug_type), f"Missing augmentation type: {aug_type}"
            params = getattr(config, aug_type)
            assert hasattr(params, "enabled")
            assert hasattr(params, "probability")
            assert hasattr(params, "params")

    def test_global_settings_defaults(self) -> None:
        """Test global settings default values."""
        from shared.augmentation.config import AugmentationConfig

        config = AugmentationConfig()

        assert config.preserve_bboxes is True
        assert config.seed is None

    def test_custom_seed(self) -> None:
        """Test setting custom seed for reproducibility."""
        from shared.augmentation.config import AugmentationConfig

        config = AugmentationConfig(seed=42)

        assert config.seed == 42

    def test_to_dict(self) -> None:
        """Test conversion to dictionary."""
        from shared.augmentation.config import AugmentationConfig

        config = AugmentationConfig(seed=123, preserve_bboxes=False)

        result = config.to_dict()

        assert isinstance(result, dict)
        assert result["seed"] == 123
        assert result["preserve_bboxes"] is False
        assert "perspective_warp" in result
        assert "wrinkle" in result

    def test_from_dict(self) -> None:
        """Test creation from dictionary."""
        from shared.augmentation.config import AugmentationConfig

        data = {
            "seed": 456,
            "preserve_bboxes": False,
            "wrinkle": {
                "enabled": True,
                "probability": 0.8,
                "params": {"intensity": 0.5},
            },
        }

        config = AugmentationConfig.from_dict(data)

        assert config.seed == 456
        assert config.preserve_bboxes is False
        assert config.wrinkle.enabled is True
        assert config.wrinkle.probability == 0.8
        assert config.wrinkle.params["intensity"] == 0.5

    def test_from_dict_with_partial_data(self) -> None:
        """Test creation from partial dictionary uses defaults."""
        from shared.augmentation.config import AugmentationConfig

        data: dict[str, Any] = {
            "wrinkle": {"enabled": True},
        }

        config = AugmentationConfig.from_dict(data)

        # Explicitly set value
        assert config.wrinkle.enabled is True
        # Default values
        assert config.preserve_bboxes is True
        assert config.seed is None
        assert config.gaussian_blur.enabled is False

    def test_get_enabled_augmentations(self) -> None:
        """Test getting list of enabled augmentations."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams

        config = AugmentationConfig(
            wrinkle=AugmentationParams(enabled=True),
            stain=AugmentationParams(enabled=True),
            gaussian_blur=AugmentationParams(enabled=False),
        )

        enabled = config.get_enabled_augmentations()

        assert "wrinkle" in enabled
        assert "stain" in enabled
        assert "gaussian_blur" not in enabled

    def test_document_safe_defaults(self) -> None:
        """Test that default params are document-safe (conservative)."""
        from shared.augmentation.config import AugmentationConfig

        config = AugmentationConfig()

        # Perspective warp should be very conservative
        assert config.perspective_warp.params.get("max_warp", 0.02) <= 0.05

        # Noise should be subtle
        noise_std = config.gaussian_noise.params.get("std", (5, 15))
        if isinstance(noise_std, tuple):
            assert noise_std[1] <= 20  # Max std should be low

    def test_immutability_between_instances(self) -> None:
        """Test that config instances are independent."""
        from shared.augmentation.config import AugmentationConfig

        config1 = AugmentationConfig()
        config2 = AugmentationConfig()

        # Modifying one should not affect the other
        config1.wrinkle.params["test"] = "value"

        assert "test" not in config2.wrinkle.params


class TestAugmentationConfigValidation:
    """Tests for configuration validation."""

    def test_probability_range_validation(self) -> None:
        """Test that probability values are validated."""
        from shared.augmentation.config import AugmentationParams

        # Valid range
        params = AugmentationParams(probability=0.5)
        assert params.probability == 0.5

        # Edge cases
        params_zero = AugmentationParams(probability=0.0)
        assert params_zero.probability == 0.0

        params_one = AugmentationParams(probability=1.0)
        assert params_one.probability == 1.0

    def test_config_validate_method(self) -> None:
        """Test the validate method catches invalid configurations."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams

        # Invalid probability
        config = AugmentationConfig(
            wrinkle=AugmentationParams(probability=1.5),  # Invalid
        )

        with pytest.raises(ValueError, match="probability"):
            config.validate()

    def test_config_validate_negative_probability(self) -> None:
        """Test validation catches negative probability."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams

        config = AugmentationConfig(
            wrinkle=AugmentationParams(probability=-0.1),
        )

        with pytest.raises(ValueError, match="probability"):
            config.validate()
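
As a reading aid, a sketch of the AugmentationParams side of shared/augmentation/config.py that these tests pin down. AugmentationConfig would hold one AugmentationParams field per augmentation type plus seed and preserve_bboxes; it is omitted here, and every detail is inferred from the tests rather than from the real module.

# Sketch only: defaults and method shapes inferred from test_config.py.
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any


@dataclass
class AugmentationParams:
    enabled: bool = False
    probability: float = 0.5
    # default_factory keeps the dict independent between instances,
    # which test_immutability_params_dict relies on.
    params: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        return {
            "enabled": self.enabled,
            "probability": self.probability,
            "params": self.params,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> AugmentationParams:
        # Missing keys fall back to the dataclass defaults.
        return cls(
            enabled=data.get("enabled", False),
            probability=data.get("probability", 0.5),
            params=data.get("params", {}),
        )

    def validate(self) -> None:
        # AugmentationConfig.validate() would presumably call this per field;
        # the "probability" substring is matched by the validation tests.
        if not 0.0 <= self.probability <= 1.0:
            raise ValueError(f"probability must be in [0, 1], got {self.probability}")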

tests/shared/augmentation/test_pipeline.py (Normal file, 338 lines)
@@ -0,0 +1,338 @@
"""
Tests for augmentation pipeline module.

TDD Phase 2: RED - Write tests first, then implement to pass.
"""

from typing import Any
from unittest.mock import MagicMock, patch

import numpy as np
import pytest


class TestAugmentationPipeline:
    """Tests for AugmentationPipeline class."""

    def test_create_with_config(self) -> None:
        """Test creating pipeline with config."""
        from shared.augmentation.config import AugmentationConfig
        from shared.augmentation.pipeline import AugmentationPipeline

        config = AugmentationConfig()
        pipeline = AugmentationPipeline(config)

        assert pipeline.config is config

    def test_create_with_seed(self) -> None:
        """Test creating pipeline with seed for reproducibility."""
        from shared.augmentation.config import AugmentationConfig
        from shared.augmentation.pipeline import AugmentationPipeline

        config = AugmentationConfig(seed=42)
        pipeline = AugmentationPipeline(config)

        assert pipeline.config.seed == 42

    def test_apply_returns_augmentation_result(self) -> None:
        """Test that apply returns AugmentationResult."""
        from shared.augmentation.base import AugmentationResult
        from shared.augmentation.config import AugmentationConfig
        from shared.augmentation.pipeline import AugmentationPipeline

        config = AugmentationConfig()
        pipeline = AugmentationPipeline(config)

        image = np.zeros((100, 100, 3), dtype=np.uint8)
        result = pipeline.apply(image)

        assert isinstance(result, AugmentationResult)
        assert result.image is not None
        assert result.image.shape == image.shape

    def test_apply_with_bboxes(self) -> None:
        """Test apply with bounding boxes."""
        from shared.augmentation.config import AugmentationConfig
        from shared.augmentation.pipeline import AugmentationPipeline

        config = AugmentationConfig()
        pipeline = AugmentationPipeline(config)

        image = np.zeros((100, 100, 3), dtype=np.uint8)
        bboxes = np.array([[0, 0.5, 0.5, 0.1, 0.1]], dtype=np.float32)

        result = pipeline.apply(image, bboxes)

        # Bboxes should be preserved when preserve_bboxes=True
        assert result.bboxes is not None

    def test_apply_no_augmentations_enabled(self) -> None:
        """Test apply when no augmentations are enabled."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams
        from shared.augmentation.pipeline import AugmentationPipeline

        # Disable all augmentations
        config = AugmentationConfig(
            lighting_variation=AugmentationParams(enabled=False),
        )
        pipeline = AugmentationPipeline(config)

        image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
        result = pipeline.apply(image)

        # Image should be unchanged (or a copy)
        np.testing.assert_array_equal(result.image, image)

    def test_apply_does_not_mutate_input(self) -> None:
        """Test that apply does not mutate input image."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams
        from shared.augmentation.pipeline import AugmentationPipeline

        config = AugmentationConfig(
            lighting_variation=AugmentationParams(enabled=True, probability=1.0),
        )
        pipeline = AugmentationPipeline(config)

        image = np.full((100, 100, 3), 128, dtype=np.uint8)
        original_copy = image.copy()

        pipeline.apply(image)

        np.testing.assert_array_equal(image, original_copy)

    def test_reproducibility_with_seed(self) -> None:
        """Test that same seed produces same results."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams
        from shared.augmentation.pipeline import AugmentationPipeline

        config1 = AugmentationConfig(
            seed=42,
            gaussian_noise=AugmentationParams(enabled=True, probability=1.0),
        )
        config2 = AugmentationConfig(
            seed=42,
            gaussian_noise=AugmentationParams(enabled=True, probability=1.0),
        )

        pipeline1 = AugmentationPipeline(config1)
        pipeline2 = AugmentationPipeline(config2)

        image = np.full((100, 100, 3), 128, dtype=np.uint8)

        result1 = pipeline1.apply(image.copy())
        result2 = pipeline2.apply(image.copy())

        np.testing.assert_array_equal(result1.image, result2.image)

    def test_metadata_contains_applied_augmentations(self) -> None:
        """Test that metadata lists applied augmentations."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams
        from shared.augmentation.pipeline import AugmentationPipeline

        config = AugmentationConfig(
            seed=42,
            gaussian_noise=AugmentationParams(enabled=True, probability=1.0),
            lighting_variation=AugmentationParams(enabled=True, probability=1.0),
        )
        pipeline = AugmentationPipeline(config)

        image = np.full((100, 100, 3), 128, dtype=np.uint8)
        result = pipeline.apply(image)

        assert result.metadata is not None
        assert "applied_augmentations" in result.metadata
        # Both should be applied with probability=1.0
        assert "gaussian_noise" in result.metadata["applied_augmentations"]
        assert "lighting_variation" in result.metadata["applied_augmentations"]


class TestAugmentationPipelineStageOrder:
    """Tests for pipeline stage ordering."""

    def test_stage_order_defined(self) -> None:
        """Test that stage order is defined."""
        from shared.augmentation.pipeline import AugmentationPipeline

        assert hasattr(AugmentationPipeline, "STAGE_ORDER")
        expected_stages = [
            "geometric",
            "degradation",
            "lighting",
            "texture",
            "blur",
            "noise",
        ]
        assert AugmentationPipeline.STAGE_ORDER == expected_stages

    def test_stage_mapping_defined(self) -> None:
        """Test that all augmentation types are mapped to stages."""
        from shared.augmentation.pipeline import AugmentationPipeline

        assert hasattr(AugmentationPipeline, "STAGE_MAPPING")

        expected_mappings = {
            "perspective_warp": "geometric",
            "wrinkle": "degradation",
            "edge_damage": "degradation",
            "stain": "degradation",
            "lighting_variation": "lighting",
            "shadow": "lighting",
            "paper_texture": "texture",
            "scanner_artifacts": "texture",
            "gaussian_blur": "blur",
            "motion_blur": "blur",
            "gaussian_noise": "noise",
            "salt_pepper": "noise",
        }

        for aug_name, stage in expected_mappings.items():
            assert aug_name in AugmentationPipeline.STAGE_MAPPING
            assert AugmentationPipeline.STAGE_MAPPING[aug_name] == stage

    def test_geometric_before_degradation(self) -> None:
        """Test that geometric transforms are applied before degradation."""
        from shared.augmentation.pipeline import AugmentationPipeline

        stages = AugmentationPipeline.STAGE_ORDER
        geometric_idx = stages.index("geometric")
        degradation_idx = stages.index("degradation")

        assert geometric_idx < degradation_idx

    def test_noise_applied_last(self) -> None:
        """Test that noise is applied last."""
        from shared.augmentation.pipeline import AugmentationPipeline

        stages = AugmentationPipeline.STAGE_ORDER
        assert stages[-1] == "noise"


class TestAugmentationRegistry:
    """Tests for augmentation registry."""

    def test_registry_exists(self) -> None:
        """Test that augmentation registry exists."""
        from shared.augmentation.pipeline import AUGMENTATION_REGISTRY

        assert isinstance(AUGMENTATION_REGISTRY, dict)

    def test_registry_contains_all_types(self) -> None:
        """Test that registry contains all augmentation types."""
        from shared.augmentation.pipeline import AUGMENTATION_REGISTRY

        expected_types = [
            "perspective_warp",
            "wrinkle",
            "edge_damage",
            "stain",
            "lighting_variation",
            "shadow",
            "gaussian_blur",
            "motion_blur",
            "gaussian_noise",
            "salt_pepper",
            "paper_texture",
            "scanner_artifacts",
        ]

        for aug_type in expected_types:
            assert aug_type in AUGMENTATION_REGISTRY, f"Missing: {aug_type}"


class TestPipelinePreview:
    """Tests for pipeline preview functionality."""

    def test_preview_single_augmentation(self) -> None:
        """Test previewing a single augmentation."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams
        from shared.augmentation.pipeline import AugmentationPipeline

        config = AugmentationConfig(
            gaussian_noise=AugmentationParams(
                enabled=True, probability=1.0, params={"std": (10, 10)}
            ),
        )
        pipeline = AugmentationPipeline(config)

        image = np.full((100, 100, 3), 128, dtype=np.uint8)
        preview = pipeline.preview(image, "gaussian_noise")

        assert preview.shape == image.shape
        assert preview.dtype == np.uint8
        # Preview should modify the image
        assert not np.array_equal(preview, image)

    def test_preview_unknown_augmentation_raises(self) -> None:
        """Test that previewing unknown augmentation raises error."""
        from shared.augmentation.config import AugmentationConfig
        from shared.augmentation.pipeline import AugmentationPipeline

        config = AugmentationConfig()
        pipeline = AugmentationPipeline(config)

        image = np.zeros((100, 100, 3), dtype=np.uint8)

        with pytest.raises(ValueError, match="Unknown augmentation"):
            pipeline.preview(image, "non_existent_augmentation")

    def test_preview_is_deterministic(self) -> None:
        """Test that preview produces deterministic results."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams
        from shared.augmentation.pipeline import AugmentationPipeline

        config = AugmentationConfig(
            gaussian_noise=AugmentationParams(enabled=True),
        )
        pipeline = AugmentationPipeline(config)

        image = np.full((100, 100, 3), 128, dtype=np.uint8)

        preview1 = pipeline.preview(image, "gaussian_noise")
        preview2 = pipeline.preview(image, "gaussian_noise")

        np.testing.assert_array_equal(preview1, preview2)


class TestPipelineGetAvailableAugmentations:
    """Tests for getting available augmentations."""

    def test_get_available_augmentations(self) -> None:
        """Test getting list of available augmentations."""
        from shared.augmentation.pipeline import get_available_augmentations

        augmentations = get_available_augmentations()

        assert isinstance(augmentations, list)
        assert len(augmentations) == 12

        # Each item should have name, description, affects_geometry
        for aug in augmentations:
            assert "name" in aug
            assert "description" in aug
            assert "affects_geometry" in aug
            assert "stage" in aug

    def test_get_available_augmentations_includes_all_types(self) -> None:
        """Test that all augmentation types are included."""
        from shared.augmentation.pipeline import get_available_augmentations

        augmentations = get_available_augmentations()
        names = [aug["name"] for aug in augmentations]

        expected = [
            "perspective_warp",
            "wrinkle",
            "edge_damage",
            "stain",
            "lighting_variation",
            "shadow",
            "gaussian_blur",
            "motion_blur",
            "gaussian_noise",
            "salt_pepper",
            "paper_texture",
            "scanner_artifacts",
        ]

        for name in expected:
            assert name in names
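
The stage constants these tests assert on are fully specified, so the corresponding slice of shared/augmentation/pipeline.py can be sketched almost verbatim; the registry contents and the pipeline's apply/preview logic are only constrained in shape, so they are stubbed, and everything here is inferred rather than taken from the real module.

# Sketch only: constants reproduced from the assertions in test_pipeline.py.
from __future__ import annotations

# In the real module this would be populated with the twelve augmentation
# classes (contents assumed, not shown in this commit).
AUGMENTATION_REGISTRY: dict[str, type] = {}


class AugmentationPipeline:
    # Geometry first so bboxes are remapped once; sensor noise last so it
    # is not warped or blurred away by later stages.
    STAGE_ORDER = ["geometric", "degradation", "lighting", "texture", "blur", "noise"]

    STAGE_MAPPING = {
        "perspective_warp": "geometric",
        "wrinkle": "degradation",
        "edge_damage": "degradation",
        "stain": "degradation",
        "lighting_variation": "lighting",
        "shadow": "lighting",
        "paper_texture": "texture",
        "scanner_artifacts": "texture",
        "gaussian_blur": "blur",
        "motion_blur": "blur",
        "gaussian_noise": "noise",
        "salt_pepper": "noise",
    }


def get_available_augmentations() -> list[dict]:
    # One entry per registered augmentation, with the four keys the tests
    # assert on; the description source is an assumption.
    return [
        {
            "name": name,
            "description": (cls.__doc__ or "").strip(),
            "affects_geometry": getattr(cls, "affects_geometry", False),
            "stage": AugmentationPipeline.STAGE_MAPPING[name],
        }
        for name, cls in AUGMENTATION_REGISTRY.items()
    ]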

tests/shared/augmentation/test_presets.py (Normal file, 102 lines)
@@ -0,0 +1,102 @@
"""
Tests for augmentation presets module.

TDD Phase 4: RED - Write tests first, then implement to pass.
"""

import pytest


class TestPresets:
    """Tests for augmentation presets."""

    def test_presets_dict_exists(self) -> None:
        """Test that PRESETS dictionary exists."""
        from shared.augmentation.presets import PRESETS

        assert isinstance(PRESETS, dict)
        assert len(PRESETS) > 0

    def test_expected_presets_exist(self) -> None:
        """Test that expected presets are defined."""
        from shared.augmentation.presets import PRESETS

        expected_presets = ["conservative", "moderate", "aggressive", "scanned_document"]

        for preset_name in expected_presets:
            assert preset_name in PRESETS, f"Missing preset: {preset_name}"

    def test_preset_structure(self) -> None:
        """Test that each preset has required structure."""
        from shared.augmentation.presets import PRESETS

        for name, preset in PRESETS.items():
            assert "description" in preset, f"Preset {name} missing description"
            assert "config" in preset, f"Preset {name} missing config"
            assert isinstance(preset["description"], str)
            assert isinstance(preset["config"], dict)

    def test_get_preset_config(self) -> None:
        """Test getting config from preset."""
        from shared.augmentation.presets import get_preset_config

        config = get_preset_config("conservative")

        assert config is not None
        # Should have at least some augmentations defined
        assert len(config) > 0

    def test_get_preset_config_unknown_raises(self) -> None:
        """Test that getting unknown preset raises error."""
        from shared.augmentation.presets import get_preset_config

        with pytest.raises(ValueError, match="Unknown preset"):
            get_preset_config("nonexistent_preset")

    def test_create_config_from_preset(self) -> None:
        """Test creating AugmentationConfig from preset."""
        from shared.augmentation.config import AugmentationConfig
        from shared.augmentation.presets import create_config_from_preset

        config = create_config_from_preset("moderate")

        assert isinstance(config, AugmentationConfig)

    def test_conservative_preset_is_safe(self) -> None:
        """Test that conservative preset only enables safe augmentations."""
        from shared.augmentation.presets import create_config_from_preset

        config = create_config_from_preset("conservative")

        # Should NOT enable geometric transforms
        assert config.perspective_warp.enabled is False

        # Should NOT enable heavy degradation
        assert config.wrinkle.enabled is False
        assert config.edge_damage.enabled is False
        assert config.stain.enabled is False

    def test_aggressive_preset_enables_more(self) -> None:
        """Test that aggressive preset enables more augmentations."""
        from shared.augmentation.presets import create_config_from_preset

        config = create_config_from_preset("aggressive")

        enabled = config.get_enabled_augmentations()

        # Should enable multiple augmentation types
        assert len(enabled) >= 6

    def test_list_presets(self) -> None:
        """Test listing available presets."""
        from shared.augmentation.presets import list_presets

        presets = list_presets()

        assert isinstance(presets, list)
        assert len(presets) >= 4

        # Each item should have name and description
        for preset in presets:
            assert "name" in preset
            assert "description" in preset
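
A sketch of the lookup helpers in shared/augmentation/presets.py that would satisfy these tests. The actual preset contents (which augmentations each preset enables, and at what strength) are assumptions beyond what the tests assert.

# Sketch only: helper shapes inferred from test_presets.py; the import of
# AugmentationConfig assumes the config module sketched earlier.
from shared.augmentation.config import AugmentationConfig

PRESETS: dict[str, dict] = {
    "conservative": {
        "description": "Photometric-only changes that keep document geometry intact",
        "config": {
            "gaussian_noise": {"enabled": True, "probability": 0.3},
            "lighting_variation": {"enabled": True, "probability": 0.3},
        },
    },
    # "moderate", "aggressive", and "scanned_document" would follow the same
    # shape; their contents are not pinned down by the tests.
}


def get_preset_config(name: str) -> dict:
    if name not in PRESETS:
        # "Unknown preset" is the substring the tests match on.
        raise ValueError(f"Unknown preset: {name}")
    return PRESETS[name]["config"]


def list_presets() -> list[dict]:
    return [{"name": name, "description": p["description"]} for name, p in PRESETS.items()]


def create_config_from_preset(name: str) -> AugmentationConfig:
    return AugmentationConfig.from_dict(get_preset_config(name))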

tests/shared/augmentation/transforms/__init__.py (Normal file, 1 line)
@@ -0,0 +1 @@
# Tests for augmentation transforms

tests/shared/test_dataset_augmenter.py (Normal file, 293 lines)
@@ -0,0 +1,293 @@
"""
Tests for DatasetAugmenter.

TDD Phase 1: RED - Write tests first, then implement to pass.
"""

import tempfile
from pathlib import Path

import numpy as np
import pytest
from PIL import Image


class TestDatasetAugmenter:
    """Tests for DatasetAugmenter class."""

    @pytest.fixture
    def sample_dataset(self, tmp_path: Path) -> Path:
        """Create a sample YOLO dataset structure."""
        dataset_dir = tmp_path / "dataset"

        # Create directory structure
        for split in ["train", "val", "test"]:
            (dataset_dir / "images" / split).mkdir(parents=True)
            (dataset_dir / "labels" / split).mkdir(parents=True)

        # Create sample images and labels
        for i in range(3):
            # Create 100x100 white image
            img = Image.new("RGB", (100, 100), color="white")
            img_path = dataset_dir / "images" / "train" / f"doc_{i}.png"
            img.save(img_path)

            # Create label with 2 bboxes
            # Format: class_id x_center y_center width height
            label_content = "0 0.5 0.3 0.2 0.1\n1 0.7 0.6 0.15 0.2\n"
            label_path = dataset_dir / "labels" / "train" / f"doc_{i}.txt"
            label_path.write_text(label_content)

        # Create data.yaml
        data_yaml = dataset_dir / "data.yaml"
        data_yaml.write_text(
            "path: .\n"
            "train: images/train\n"
            "val: images/val\n"
            "test: images/test\n"
            "nc: 10\n"
            "names: [class0, class1, class2, class3, class4, class5, class6, class7, class8, class9]\n"
        )

        return dataset_dir

    @pytest.fixture
    def augmentation_config(self) -> dict:
        """Create a sample augmentation config."""
        return {
            "gaussian_noise": {
                "enabled": True,
                "probability": 1.0,
                "params": {"std": 10},
            },
            "gaussian_blur": {
                "enabled": True,
                "probability": 1.0,
                "params": {"kernel_size": 3},
            },
        }

    def test_augmenter_creates_additional_images(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that augmenter creates new augmented images."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        augmenter = DatasetAugmenter(augmentation_config)

        # Count original images
        original_count = len(list((sample_dataset / "images" / "train").glob("*.png")))
        assert original_count == 3

        # Apply augmentation with multiplier=2
        result = augmenter.augment_dataset(sample_dataset, multiplier=2)

        # Should now have original + 2x augmented = 3 + 6 = 9 images
        new_count = len(list((sample_dataset / "images" / "train").glob("*.png")))
        assert new_count == 9
        assert result["augmented_images"] == 6

    def test_augmenter_creates_matching_labels(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that augmenter creates label files for each augmented image."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        augmenter = DatasetAugmenter(augmentation_config)
        augmenter.augment_dataset(sample_dataset, multiplier=2)

        # Check that each image has a matching label file
        images = list((sample_dataset / "images" / "train").glob("*.png"))
        labels = list((sample_dataset / "labels" / "train").glob("*.txt"))

        assert len(images) == len(labels)

        # Check that augmented images have corresponding labels
        for img_path in images:
            label_path = sample_dataset / "labels" / "train" / f"{img_path.stem}.txt"
            assert label_path.exists(), f"Missing label for {img_path.name}"

    def test_augmented_labels_have_valid_format(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that augmented label files have valid YOLO format."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        augmenter = DatasetAugmenter(augmentation_config)
        augmenter.augment_dataset(sample_dataset, multiplier=1)

        # Check all label files
        for label_path in (sample_dataset / "labels" / "train").glob("*.txt"):
            content = label_path.read_text().strip()
            if not content:
                continue  # Empty labels are valid (background images)

            for line in content.split("\n"):
                parts = line.split()
                assert len(parts) == 5, f"Invalid label format in {label_path.name}"

                class_id = int(parts[0])
                x_center = float(parts[1])
                y_center = float(parts[2])
                width = float(parts[3])
                height = float(parts[4])

                # Check values are in valid range
                assert 0 <= class_id < 100, f"Invalid class_id: {class_id}"
                assert 0 <= x_center <= 1, f"Invalid x_center: {x_center}"
                assert 0 <= y_center <= 1, f"Invalid y_center: {y_center}"
                assert 0 <= width <= 1, f"Invalid width: {width}"
                assert 0 <= height <= 1, f"Invalid height: {height}"

    def test_augmented_images_are_different(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that augmented images are actually different from originals."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        # Load original image
        original_path = sample_dataset / "images" / "train" / "doc_0.png"
        original_img = np.array(Image.open(original_path))

        augmenter = DatasetAugmenter(augmentation_config)
        augmenter.augment_dataset(sample_dataset, multiplier=1)

        # Find augmented version
        aug_path = sample_dataset / "images" / "train" / "doc_0_aug0.png"
        assert aug_path.exists()

        aug_img = np.array(Image.open(aug_path))

        # Images should be different (due to noise/blur)
        assert not np.array_equal(original_img, aug_img)

    def test_augmented_images_same_size(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that augmented images have same size as originals."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        # Get original size
        original_path = sample_dataset / "images" / "train" / "doc_0.png"
        original_img = Image.open(original_path)
        original_size = original_img.size

        augmenter = DatasetAugmenter(augmentation_config)
        augmenter.augment_dataset(sample_dataset, multiplier=1)

        # Check all augmented images have same size
        for img_path in (sample_dataset / "images" / "train").glob("*_aug*.png"):
            img = Image.open(img_path)
            assert img.size == original_size, f"{img_path.name} has wrong size"

    def test_perspective_warp_updates_bboxes(self, sample_dataset: Path):
        """Test that perspective_warp augmentation updates bbox coordinates."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        config = {
            "perspective_warp": {
                "enabled": True,
                "probability": 1.0,
                "params": {"max_warp": 0.05},  # Use larger warp for visible difference
            },
        }

        # Read original label
        original_label = (sample_dataset / "labels" / "train" / "doc_0.txt").read_text()
        original_bboxes = [line.split() for line in original_label.strip().split("\n")]

        augmenter = DatasetAugmenter(config)
        augmenter.augment_dataset(sample_dataset, multiplier=1)

        # Read augmented label
        aug_label = (sample_dataset / "labels" / "train" / "doc_0_aug0.txt").read_text()
        aug_bboxes = [line.split() for line in aug_label.strip().split("\n")]

        # Same number of bboxes
        assert len(original_bboxes) == len(aug_bboxes)

        # At least one bbox should have different coordinates
        # (perspective warp changes geometry)
        differences_found = False
        for orig, aug in zip(original_bboxes, aug_bboxes):
            # Class ID should be same
            assert orig[0] == aug[0]
            # Coordinates might differ
            if orig[1:] != aug[1:]:
                differences_found = True

        assert differences_found, "Perspective warp should change bbox coordinates"

    def test_augmenter_only_processes_train_split(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that augmenter only processes train split by default."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        # Add a val image
        val_img = Image.new("RGB", (100, 100), color="white")
        val_img.save(sample_dataset / "images" / "val" / "val_doc.png")
        (sample_dataset / "labels" / "val" / "val_doc.txt").write_text("0 0.5 0.5 0.1 0.1\n")

        augmenter = DatasetAugmenter(augmentation_config)
        augmenter.augment_dataset(sample_dataset, multiplier=2)

        # Val should still have only 1 image
        val_count = len(list((sample_dataset / "images" / "val").glob("*.png")))
        assert val_count == 1

    def test_augmenter_with_multiplier_zero_does_nothing(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that multiplier=0 creates no augmented images."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        original_count = len(list((sample_dataset / "images" / "train").glob("*.png")))

        augmenter = DatasetAugmenter(augmentation_config)
        result = augmenter.augment_dataset(sample_dataset, multiplier=0)

        new_count = len(list((sample_dataset / "images" / "train").glob("*.png")))
        assert new_count == original_count
        assert result["augmented_images"] == 0

    def test_augmenter_with_seed_is_reproducible(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that same seed produces same augmentation results."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        # Create two separate datasets
        import shutil
        dataset1 = sample_dataset
        dataset2 = sample_dataset.parent / "dataset2"
        shutil.copytree(dataset1, dataset2)

        # Augment both with same seed
        augmenter1 = DatasetAugmenter(augmentation_config, seed=42)
        augmenter1.augment_dataset(dataset1, multiplier=1)

        augmenter2 = DatasetAugmenter(augmentation_config, seed=42)
        augmenter2.augment_dataset(dataset2, multiplier=1)

        # Compare augmented images
        aug1 = np.array(Image.open(dataset1 / "images" / "train" / "doc_0_aug0.png"))
        aug2 = np.array(Image.open(dataset2 / "images" / "train" / "doc_0_aug0.png"))

        assert np.array_equal(aug1, aug2), "Same seed should produce same augmentation"

    def test_augmenter_returns_summary(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that augmenter returns a summary of what was done."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        augmenter = DatasetAugmenter(augmentation_config)
        result = augmenter.augment_dataset(sample_dataset, multiplier=2)

        assert "original_images" in result
        assert "augmented_images" in result
        assert "total_images" in result
        assert result["original_images"] == 3
        assert result["augmented_images"] == 6
        assert result["total_images"] == 9
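
DatasetAugmenter is likewise absent from this commit; the parts of its contract these tests fix are the output naming scheme ("<stem>_aug<i>"), the train-only default, and the summary dict. A sketch of just those pieces, with helper names that are illustrative rather than from the real module:

# Sketch only: naming and bookkeeping inferred from test_dataset_augmenter.py.
from pathlib import Path


def _augmented_name(original: Path, copy_index: int) -> Path:
    # doc_0.png -> doc_0_aug0.png, the path test_augmented_images_are_different expects.
    return original.with_name(f"{original.stem}_aug{copy_index}{original.suffix}")


def _summary(original_count: int, multiplier: int) -> dict:
    # multiplier=2 on 3 originals -> 6 augmented, 9 total, as the tests assert.
    augmented = original_count * multiplier
    return {
        "original_images": original_count,
        "augmented_images": augmented,
        "total_images": original_count + augmented,
    }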
261
tests/web/test_augmentation_routes.py
Normal file
261
tests/web/test_augmentation_routes.py
Normal file
@@ -0,0 +1,261 @@
|
||||
"""
|
||||
Tests for augmentation API routes.
|
||||
|
||||
TDD Phase 5: RED - Write tests first, then implement to pass.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
class TestAugmentationTypesEndpoint:
|
||||
"""Tests for GET /admin/augmentation/types endpoint."""
|
||||
|
||||
def test_list_augmentation_types(
|
||||
self, admin_client: TestClient, admin_token: str
|
||||
) -> None:
|
||||
"""Test listing available augmentation types."""
|
||||
response = admin_client.get(
|
||||
"/api/v1/admin/augmentation/types",
|
||||
headers={"X-Admin-Token": admin_token},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert "augmentation_types" in data
|
||||
assert len(data["augmentation_types"]) == 12
|
||||
|
||||
# Check structure
|
||||
aug_type = data["augmentation_types"][0]
|
||||
assert "name" in aug_type
|
||||
assert "description" in aug_type
|
||||
assert "affects_geometry" in aug_type
|
||||
assert "stage" in aug_type
|
||||
|
||||
def test_list_augmentation_types_unauthorized(
|
||||
self, admin_client: TestClient
|
||||
) -> None:
|
||||
"""Test that unauthorized request is rejected."""
|
||||
response = admin_client.get("/api/v1/admin/augmentation/types")
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
class TestAugmentationPresetsEndpoint:
|
||||
"""Tests for GET /admin/augmentation/presets endpoint."""
|
||||
|
||||
def test_list_presets(self, admin_client: TestClient, admin_token: str) -> None:
|
||||
"""Test listing available presets."""
|
||||
response = admin_client.get(
|
||||
"/api/v1/admin/augmentation/presets",
|
||||
headers={"X-Admin-Token": admin_token},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert "presets" in data
|
||||
assert len(data["presets"]) >= 4
|
||||
|
||||
# Check expected presets exist
|
||||
preset_names = [p["name"] for p in data["presets"]]
|
||||
assert "conservative" in preset_names
|
||||
assert "moderate" in preset_names
|
||||
assert "aggressive" in preset_names
|
||||
assert "scanned_document" in preset_names
|
||||
|
||||
|
||||
class TestAugmentationPreviewEndpoint:
|
||||
"""Tests for POST /admin/augmentation/preview/{document_id} endpoint."""
|
||||
|
||||
def test_preview_augmentation(
|
||||
self,
|
||||
admin_client: TestClient,
|
||||
admin_token: str,
|
||||
sample_document_id: str,
|
||||
) -> None:
|
||||
"""Test previewing augmentation on a document."""
|
||||
response = admin_client.post(
|
||||
f"/api/v1/admin/augmentation/preview/{sample_document_id}",
|
||||
headers={"X-Admin-Token": admin_token},
|
||||
json={
|
||||
"augmentation_type": "gaussian_noise",
|
||||
"params": {"std": 15},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert "preview_url" in data
|
||||
assert "original_url" in data
|
||||
assert "applied_params" in data
|
||||
|
||||
def test_preview_invalid_augmentation_type(
|
||||
self,
|
||||
admin_client: TestClient,
|
||||
admin_token: str,
|
||||
sample_document_id: str,
|
||||
) -> None:
|
||||
"""Test that invalid augmentation type returns error."""
|
||||
response = admin_client.post(
|
||||
f"/api/v1/admin/augmentation/preview/{sample_document_id}",
|
||||
headers={"X-Admin-Token": admin_token},
|
||||
json={
|
||||
"augmentation_type": "nonexistent",
|
||||
"params": {},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_preview_nonexistent_document(
|
||||
self,
|
||||
admin_client: TestClient,
|
||||
admin_token: str,
|
||||
) -> None:
|
||||
"""Test that nonexistent document returns 404."""
|
||||
response = admin_client.post(
|
||||
"/api/v1/admin/augmentation/preview/00000000-0000-0000-0000-000000000000",
|
||||
headers={"X-Admin-Token": admin_token},
|
||||
json={
|
||||
"augmentation_type": "gaussian_noise",
|
||||
"params": {},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
class TestAugmentationPreviewConfigEndpoint:
|
||||
"""Tests for POST /admin/augmentation/preview-config/{document_id} endpoint."""
|
||||
|
||||
def test_preview_config(
|
||||
self,
|
||||
admin_client: TestClient,
|
||||
admin_token: str,
|
||||
sample_document_id: str,
|
||||
) -> None:
|
||||
"""Test previewing full config on a document."""
|
||||
response = admin_client.post(
|
||||
f"/api/v1/admin/augmentation/preview-config/{sample_document_id}",
|
||||
headers={"X-Admin-Token": admin_token},
|
||||
json={
|
||||
"gaussian_noise": {"enabled": True, "probability": 1.0},
|
||||
"lighting_variation": {"enabled": True, "probability": 1.0},
|
||||
"preserve_bboxes": True,
|
||||
"seed": 42,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert "preview_url" in data
|
||||
assert "original_url" in data
|
||||
|
||||
|
||||
class TestAugmentationBatchEndpoint:
|
||||
"""Tests for POST /admin/augmentation/batch endpoint."""
|
||||
|
||||
def test_create_augmented_dataset(
|
||||
self,
|
||||
admin_client: TestClient,
|
||||
admin_token: str,
|
||||
sample_dataset_id: str,
|
||||
) -> None:
|
||||
"""Test creating augmented dataset."""
|
||||
response = admin_client.post(
|
||||
"/api/v1/admin/augmentation/batch",
|
||||
headers={"X-Admin-Token": admin_token},
|
||||
json={
|
||||
"dataset_id": sample_dataset_id,
|
||||
"config": {
|
||||
"gaussian_noise": {"enabled": True, "probability": 0.5},
|
||||
"preserve_bboxes": True,
|
||||
},
|
||||
"output_name": "test_augmented_dataset",
|
||||
"multiplier": 2,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert "task_id" in data
|
||||
assert "status" in data
|
||||
assert "estimated_images" in data
|
||||
|
||||
def test_create_augmented_dataset_invalid_multiplier(
|
||||
self,
|
||||
admin_client: TestClient,
|
||||
admin_token: str,
|
||||
sample_dataset_id: str,
|
||||
) -> None:
|
||||
"""Test that invalid multiplier is rejected."""
|
||||
response = admin_client.post(
|
||||
"/api/v1/admin/augmentation/batch",
|
||||
headers={"X-Admin-Token": admin_token},
|
||||
json={
|
||||
"dataset_id": sample_dataset_id,
|
||||
"config": {},
|
||||
"output_name": "test",
|
||||
"multiplier": 100, # Too high
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 422 # Validation error


class TestAugmentedDatasetsListEndpoint:
    """Tests for GET /admin/augmentation/datasets endpoint."""

    def test_list_augmented_datasets(
        self, admin_client: TestClient, admin_token: str
    ) -> None:
        """Test listing augmented datasets."""
        response = admin_client.get(
            "/api/v1/admin/augmentation/datasets",
            headers={"X-Admin-Token": admin_token},
        )

        assert response.status_code == 200
        data = response.json()

        assert "total" in data
        assert "limit" in data
        assert "offset" in data
        assert "datasets" in data
        assert isinstance(data["datasets"], list)

    def test_list_augmented_datasets_pagination(
        self, admin_client: TestClient, admin_token: str
    ) -> None:
        """Test pagination parameters."""
        response = admin_client.get(
            "/api/v1/admin/augmentation/datasets",
            headers={"X-Admin-Token": admin_token},
            params={"limit": 5, "offset": 0},
        )

        assert response.status_code == 200
        data = response.json()

        assert data["limit"] == 5
        assert data["offset"] == 0


# Fixtures for tests
@pytest.fixture
def sample_document_id() -> str:
    """Provide a sample document ID for testing."""
    # This would need to be created in test setup
    return "test-document-id"


@pytest.fixture
def sample_dataset_id() -> str:
    """Provide a sample dataset ID for testing."""
    # This would need to be created in test setup
    return "test-dataset-id"

@@ -329,3 +329,414 @@ class TestDatasetBuilder:
            results.append([(d["document_id"], d["split"]) for d in docs])

        assert results[0] == results[1]


class TestAssignSplitsByGroup:
    """Tests for _assign_splits_by_group method with group_key logic."""

    def _make_mock_doc(self, doc_id, group_key=None):
        """Create a mock AdminDocument with document_id and group_key."""
        doc = MagicMock(spec=AdminDocument)
        doc.document_id = doc_id
        doc.group_key = group_key
        doc.page_count = 1
        return doc

    def test_single_doc_groups_are_distributed(self, tmp_path, mock_admin_db):
        """Documents with unique group_key are distributed across splits."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")

        # 3 documents, each with unique group_key
        docs = [
            self._make_mock_doc(uuid4(), group_key="group-A"),
            self._make_mock_doc(uuid4(), group_key="group-B"),
            self._make_mock_doc(uuid4(), group_key="group-C"),
        ]

        result = builder._assign_splits_by_group(docs, train_ratio=0.7, val_ratio=0.2, seed=42)

        # With 3 groups: 70% train = 2, 20% val = 1 (at least 1)
        train_count = sum(1 for s in result.values() if s == "train")
        val_count = sum(1 for s in result.values() if s == "val")
        assert train_count >= 1
        assert val_count >= 1  # Ensure val is not empty

    def test_null_group_key_treated_as_single_doc_group(self, tmp_path, mock_admin_db):
        """Documents with null/empty group_key are each treated as independent single-doc groups."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")

        docs = [
            self._make_mock_doc(uuid4(), group_key=None),
            self._make_mock_doc(uuid4(), group_key=""),
            self._make_mock_doc(uuid4(), group_key=None),
        ]

        result = builder._assign_splits_by_group(docs, train_ratio=0.7, val_ratio=0.2, seed=42)

        # Each null/empty group_key doc is independent, distributed across splits
        # With 3 docs: ensure at least 1 in train and 1 in val
        train_count = sum(1 for s in result.values() if s == "train")
        val_count = sum(1 for s in result.values() if s == "val")
        assert train_count >= 1
        assert val_count >= 1

    def test_multi_doc_groups_stay_together(self, tmp_path, mock_admin_db):
        """Documents with same group_key should be assigned to the same split."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")

        # 6 documents in 2 groups
        docs = [
            self._make_mock_doc(uuid4(), group_key="supplier-A"),
            self._make_mock_doc(uuid4(), group_key="supplier-A"),
            self._make_mock_doc(uuid4(), group_key="supplier-A"),
            self._make_mock_doc(uuid4(), group_key="supplier-B"),
            self._make_mock_doc(uuid4(), group_key="supplier-B"),
            self._make_mock_doc(uuid4(), group_key="supplier-B"),
        ]

        result = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.5, seed=42)

        # All docs in supplier-A should have same split
        splits_a = [result[str(d.document_id)] for d in docs[:3]]
        assert len(set(splits_a)) == 1, "All docs in supplier-A should be in same split"

        # All docs in supplier-B should have same split
        splits_b = [result[str(d.document_id)] for d in docs[3:]]
        assert len(set(splits_b)) == 1, "All docs in supplier-B should be in same split"

    def test_multi_doc_groups_split_by_ratio(self, tmp_path, mock_admin_db):
        """Multi-doc groups should be split according to train/val/test ratios."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")

        # 10 groups with 2 docs each
        docs = []
        for i in range(10):
            group_key = f"group-{i}"
            docs.append(self._make_mock_doc(uuid4(), group_key=group_key))
            docs.append(self._make_mock_doc(uuid4(), group_key=group_key))

        result = builder._assign_splits_by_group(docs, train_ratio=0.7, val_ratio=0.2, seed=42)

        # Count groups per split
        group_splits = {}
        for doc in docs:
            split = result[str(doc.document_id)]
            if doc.group_key not in group_splits:
                group_splits[doc.group_key] = split
            else:
                # Verify same group has same split
                assert group_splits[doc.group_key] == split

        split_counts = {"train": 0, "val": 0, "test": 0}
        for split in group_splits.values():
            split_counts[split] += 1

        # With 10 groups, 70/20/10 -> ~7 train, ~2 val, ~1 test
        assert split_counts["train"] >= 6
        assert split_counts["train"] <= 8
        assert split_counts["val"] >= 1
        assert split_counts["val"] <= 3

    def test_mixed_single_and_multi_doc_groups(self, tmp_path, mock_admin_db):
        """Mix of single-doc and multi-doc groups should be handled correctly."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")

        docs = [
            # Single-doc groups
            self._make_mock_doc(uuid4(), group_key="single-1"),
            self._make_mock_doc(uuid4(), group_key="single-2"),
            self._make_mock_doc(uuid4(), group_key=None),
            # Multi-doc groups
            self._make_mock_doc(uuid4(), group_key="multi-A"),
            self._make_mock_doc(uuid4(), group_key="multi-A"),
            self._make_mock_doc(uuid4(), group_key="multi-B"),
            self._make_mock_doc(uuid4(), group_key="multi-B"),
        ]

        result = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.5, seed=42)

        # All groups are shuffled and distributed
        # Ensure at least 1 in train and 1 in val
        train_count = sum(1 for s in result.values() if s == "train")
        val_count = sum(1 for s in result.values() if s == "val")
        assert train_count >= 1
        assert val_count >= 1

        # Multi-doc groups stay together
        assert result[str(docs[3].document_id)] == result[str(docs[4].document_id)]
        assert result[str(docs[5].document_id)] == result[str(docs[6].document_id)]

    def test_deterministic_with_seed(self, tmp_path, mock_admin_db):
        """Same seed should produce same split assignments."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")

        docs = [
            self._make_mock_doc(uuid4(), group_key="group-A"),
            self._make_mock_doc(uuid4(), group_key="group-A"),
            self._make_mock_doc(uuid4(), group_key="group-B"),
            self._make_mock_doc(uuid4(), group_key="group-B"),
            self._make_mock_doc(uuid4(), group_key="group-C"),
            self._make_mock_doc(uuid4(), group_key="group-C"),
        ]

        result1 = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.3, seed=123)
        result2 = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.3, seed=123)

        assert result1 == result2

    def test_different_seed_may_produce_different_splits(self, tmp_path, mock_admin_db):
        """Different seeds should potentially produce different split assignments."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")

        # Many groups to increase chance of different results
        docs = []
        for i in range(20):
            group_key = f"group-{i}"
            docs.append(self._make_mock_doc(uuid4(), group_key=group_key))
            docs.append(self._make_mock_doc(uuid4(), group_key=group_key))

        result1 = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.3, seed=1)
        result2 = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.3, seed=999)

        # Results should be different (very likely with 20 groups)
        assert result1 != result2

    def test_all_docs_assigned(self, tmp_path, mock_admin_db):
        """Every document should be assigned a split."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")

        docs = [
            self._make_mock_doc(uuid4(), group_key="group-A"),
            self._make_mock_doc(uuid4(), group_key="group-A"),
            self._make_mock_doc(uuid4(), group_key=None),
            self._make_mock_doc(uuid4(), group_key="single"),
        ]

        result = builder._assign_splits_by_group(docs, train_ratio=0.7, val_ratio=0.2, seed=42)

        assert len(result) == len(docs)
        for doc in docs:
            assert str(doc.document_id) in result
            assert result[str(doc.document_id)] in ["train", "val", "test"]

    def test_empty_documents_list(self, tmp_path, mock_admin_db):
        """Empty document list should return empty result."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")

        result = builder._assign_splits_by_group([], train_ratio=0.7, val_ratio=0.2, seed=42)

        assert result == {}

    def test_only_multi_doc_groups(self, tmp_path, mock_admin_db):
        """When all groups have multiple docs, splits should follow ratios."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")

        # 5 groups with 3 docs each
        docs = []
        for i in range(5):
            group_key = f"group-{i}"
            for _ in range(3):
                docs.append(self._make_mock_doc(uuid4(), group_key=group_key))

        result = builder._assign_splits_by_group(docs, train_ratio=0.6, val_ratio=0.2, seed=42)

        # Group splits
        group_splits = {}
        for doc in docs:
            if doc.group_key not in group_splits:
                group_splits[doc.group_key] = result[str(doc.document_id)]

        split_counts = {"train": 0, "val": 0, "test": 0}
        for split in group_splits.values():
            split_counts[split] += 1

        # With 5 groups, 60/20/20 -> 3 train, 1 val, 1 test
        assert split_counts["train"] >= 2
        assert split_counts["train"] <= 4

    def test_only_single_doc_groups(self, tmp_path, mock_admin_db):
        """When all groups have single doc, they are distributed across splits."""
        from inference.web.services.dataset_builder import DatasetBuilder

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")

        docs = [
            self._make_mock_doc(uuid4(), group_key="unique-1"),
            self._make_mock_doc(uuid4(), group_key="unique-2"),
            self._make_mock_doc(uuid4(), group_key="unique-3"),
            self._make_mock_doc(uuid4(), group_key=None),
            self._make_mock_doc(uuid4(), group_key=""),
        ]

        result = builder._assign_splits_by_group(docs, train_ratio=0.6, val_ratio=0.2, seed=42)

        # With 5 groups: 60% train = 3, 20% val = 1 (at least 1)
        train_count = sum(1 for s in result.values() if s == "train")
        val_count = sum(1 for s in result.values() if s == "val")
        assert train_count >= 2
        assert val_count >= 1  # Ensure val is not empty
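
# For reference, a minimal sketch of the grouping behaviour the tests above pin
# down. This is an illustration only, not the DatasetBuilder implementation
# from inference.web.services.dataset_builder: null/empty group_key docs become
# their own single-doc groups, groups are shuffled deterministically by seed,
# and the allocation keeps train and val non-empty whenever possible.


def _assign_splits_by_group_sketch(docs, train_ratio, val_ratio, seed):
    import random
    from collections import defaultdict

    # Group documents; a falsy group_key makes the doc its own group.
    groups = defaultdict(list)
    for doc in docs:
        key = doc.group_key or f"__solo__{doc.document_id}"
        groups[key].append(doc)

    keys = sorted(groups)  # stable order before the seeded shuffle
    random.Random(seed).shuffle(keys)

    n = len(keys)
    n_train = max(1, round(n * train_ratio)) if n else 0
    n_val = max(1, round(n * val_ratio)) if n > 1 else 0

    # Assign whole groups to splits in shuffled order.
    result = {}
    for i, key in enumerate(keys):
        split = "train" if i < n_train else "val" if i < n_train + n_val else "test"
        for doc in groups[key]:
            result[str(doc.document_id)] = split
    return result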


class TestBuildDatasetWithGroupKey:
    """Integration tests for build_dataset with group_key logic."""

    @pytest.fixture
    def grouped_documents(self, tmp_path):
        """Create documents with various group_key configurations."""
        doc_ids = []
        docs = []

        # Create 4 groups: 2 multi-doc groups + 2 single-doc groups
        group_configs = [
            ("supplier-A", 3),  # Multi-doc group: 3 docs
            ("supplier-B", 2),  # Multi-doc group: 2 docs
            ("unique-1", 1),  # Single-doc group
            (None, 1),  # Null group_key
        ]

        for group_key, count in group_configs:
            for _ in range(count):
                doc_id = uuid4()
                doc_ids.append(doc_id)

                # Create image files
                doc_dir = tmp_path / "admin_images" / str(doc_id)
                doc_dir.mkdir(parents=True)
                for page in range(1, 3):
                    (doc_dir / f"page_{page}.png").write_bytes(b"fake-png")

                # Create mock document
                doc = MagicMock(spec=AdminDocument)
                doc.document_id = doc_id
                doc.filename = f"{doc_id}.pdf"
                doc.page_count = 2
                doc.group_key = group_key
                doc.file_path = str(doc_dir)
                docs.append(doc)

        return tmp_path, docs

    @pytest.fixture
    def grouped_annotations(self, grouped_documents):
        """Create annotations for grouped documents."""
        tmp_path, docs = grouped_documents
        annotations = {}
        for doc in docs:
            doc_anns = []
            for page in range(1, 3):
                ann = MagicMock(spec=AdminAnnotation)
                ann.document_id = doc.document_id
                ann.page_number = page
                ann.class_id = 0
                ann.class_name = "invoice_number"
                ann.x_center = 0.5
                ann.y_center = 0.3
                ann.width = 0.2
                ann.height = 0.05
                doc_anns.append(ann)
            annotations[str(doc.document_id)] = doc_anns
        return annotations

    def test_build_respects_group_key_splits(
        self, grouped_documents, grouped_annotations, mock_admin_db
    ):
        """build_dataset should use group_key for split assignment."""
        from inference.web.services.dataset_builder import DatasetBuilder

        tmp_path, docs = grouped_documents

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
        mock_admin_db.get_documents_by_ids.return_value = docs
        mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
            grouped_annotations.get(str(doc_id), [])
        )

        dataset = mock_admin_db.create_dataset.return_value
        builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=[str(d.document_id) for d in docs],
            train_ratio=0.5,
            val_ratio=0.5,
            seed=42,
            admin_images_dir=tmp_path / "admin_images",
        )

        # Get the document splits from add_dataset_documents call
        call_args = mock_admin_db.add_dataset_documents.call_args
        docs_added = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]

        # Build mapping of doc_id -> split
        doc_split_map = {d["document_id"]: d["split"] for d in docs_added}

        # Verify all docs are assigned a valid split
        for doc_id in doc_split_map:
            assert doc_split_map[doc_id] in ("train", "val", "test")

        # Verify multi-doc groups stay together
        supplier_a_ids = [str(d.document_id) for d in docs if d.group_key == "supplier-A"]
        supplier_a_splits = [doc_split_map[doc_id] for doc_id in supplier_a_ids]
        assert len(set(supplier_a_splits)) == 1, "supplier-A docs should be in same split"

        supplier_b_ids = [str(d.document_id) for d in docs if d.group_key == "supplier-B"]
        supplier_b_splits = [doc_split_map[doc_id] for doc_id in supplier_b_ids]
        assert len(set(supplier_b_splits)) == 1, "supplier-B docs should be in same split"

    def test_build_with_all_same_group_key(self, tmp_path, mock_admin_db):
        """All docs with same group_key should go to same split."""
        from inference.web.services.dataset_builder import DatasetBuilder

        # Create 5 docs all with same group_key
        docs = []
        for i in range(5):
            doc_id = uuid4()
            doc_dir = tmp_path / "admin_images" / str(doc_id)
            doc_dir.mkdir(parents=True)
            (doc_dir / "page_1.png").write_bytes(b"fake-png")

            doc = MagicMock(spec=AdminDocument)
            doc.document_id = doc_id
            doc.filename = f"{doc_id}.pdf"
            doc.page_count = 1
            doc.group_key = "same-group"
            docs.append(doc)

        builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
        mock_admin_db.get_documents_by_ids.return_value = docs
        mock_admin_db.get_annotations_for_document.return_value = []

        dataset = mock_admin_db.create_dataset.return_value
        builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=[str(d.document_id) for d in docs],
            train_ratio=0.6,
            val_ratio=0.2,
            seed=42,
            admin_images_dir=tmp_path / "admin_images",
        )

        call_args = mock_admin_db.add_dataset_documents.call_args
        docs_added = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]

        splits = [d["split"] for d in docs_added]
        # All should be in the same split (one group)
        assert len(set(splits)) == 1, "All docs with same group_key should be in same split"

@@ -25,6 +25,9 @@ TEST_DOC_UUID_2 = "990e8400-e29b-41d4-a716-446655440012"
TEST_TOKEN = "test-admin-token-12345"
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002"

# Generate 10 unique UUIDs for minimum document count tests
TEST_DOC_UUIDS = [f"990e8400-e29b-41d4-a716-4466554400{i:02d}" for i in range(10, 20)]


def _make_dataset(**overrides) -> MagicMock:
    defaults = dict(
@@ -83,14 +86,14 @@ class TestCreateDatasetRoute:

        mock_builder = MagicMock()
        mock_builder.build_dataset.return_value = {
-            "total_documents": 2,
-            "total_images": 4,
-            "total_annotations": 10,
+            "total_documents": 10,
+            "total_images": 20,
+            "total_annotations": 50,
        }

        request = DatasetCreateRequest(
            name="test-dataset",
-            document_ids=[TEST_DOC_UUID_1, TEST_DOC_UUID_2],
+            document_ids=TEST_DOC_UUIDS,  # Use 10 documents to meet minimum
        )

        with patch(
@@ -104,6 +107,73 @@ class TestCreateDatasetRoute:
        assert result.dataset_id == TEST_DATASET_UUID
        assert result.name == "test-dataset"

    def test_create_dataset_fails_with_less_than_10_documents(self):
        """Test that creating dataset fails if fewer than 10 documents provided."""
        fn = _find_endpoint("create_dataset")

        mock_db = MagicMock()

        # Only 2 documents - should fail
        request = DatasetCreateRequest(
            name="test-dataset",
            document_ids=[TEST_DOC_UUID_1, TEST_DOC_UUID_2],
        )

        from fastapi import HTTPException

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))

        assert exc_info.value.status_code == 400
        assert "Minimum 10 documents required" in exc_info.value.detail
        assert "got 2" in exc_info.value.detail
        # Ensure DB was never called since validation failed first
        mock_db.create_dataset.assert_not_called()
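
    # The guard exercised here and in the two boundary tests below presumably
    # looks something like this inside create_dataset (a sketch; the route body
    # itself is not part of this diff):
    #
    #     if len(request.document_ids) < 10:
    #         raise HTTPException(
    #             status_code=400,
    #             detail=f"Minimum 10 documents required, got {len(request.document_ids)}",
    #         )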

    def test_create_dataset_fails_with_9_documents(self):
        """Test boundary condition: 9 documents should fail."""
        fn = _find_endpoint("create_dataset")

        mock_db = MagicMock()

        # 9 documents - just under the limit
        request = DatasetCreateRequest(
            name="test-dataset",
            document_ids=TEST_DOC_UUIDS[:9],
        )

        from fastapi import HTTPException

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))

        assert exc_info.value.status_code == 400
        assert "Minimum 10 documents required" in exc_info.value.detail

    def test_create_dataset_succeeds_with_exactly_10_documents(self):
        """Test boundary condition: exactly 10 documents should succeed."""
        fn = _find_endpoint("create_dataset")

        mock_db = MagicMock()
        mock_db.create_dataset.return_value = _make_dataset(status="building")

        mock_builder = MagicMock()

        # Exactly 10 documents - should pass
        request = DatasetCreateRequest(
            name="test-dataset",
            document_ids=TEST_DOC_UUIDS[:10],
        )

        with patch(
            "inference.web.services.dataset_builder.DatasetBuilder",
            return_value=mock_builder,
        ):
            result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))

        mock_db.create_dataset.assert_called_once()
        assert result.dataset_id == TEST_DATASET_UUID


class TestListDatasetsRoute:
    """Tests for GET /admin/training/datasets."""
@@ -198,3 +268,53 @@ class TestTrainFromDatasetRoute:
        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
        assert exc_info.value.status_code == 400

    def test_incremental_training_with_base_model(self):
        """Test training with base_model_version_id for incremental training."""
        fn = _find_endpoint("train_from_dataset")

        mock_model_version = MagicMock()
        mock_model_version.model_path = "runs/train/invoice_fields/weights/best.pt"
        mock_model_version.version = "1.0.0"

        mock_db = MagicMock()
        mock_db.get_dataset.return_value = _make_dataset(status="ready")
        mock_db.get_model_version.return_value = mock_model_version
        mock_db.create_training_task.return_value = TEST_TASK_UUID

        base_model_uuid = "550e8400-e29b-41d4-a716-446655440099"
        config = TrainingConfig(base_model_version_id=base_model_uuid)
        request = DatasetTrainRequest(name="incremental-train", config=config)

        result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))

        # Verify model version was looked up
        mock_db.get_model_version.assert_called_once_with(base_model_uuid)

        # Verify task was created with finetune type
        call_kwargs = mock_db.create_training_task.call_args[1]
        assert call_kwargs["task_type"] == "finetune"
        assert call_kwargs["config"]["base_model_path"] == "runs/train/invoice_fields/weights/best.pt"
        assert call_kwargs["config"]["base_model_version"] == "1.0.0"

        assert result.task_id == TEST_TASK_UUID
        assert "Incremental training" in result.message

    def test_incremental_training_with_invalid_base_model_fails(self):
        """Test that training fails if base_model_version_id doesn't exist."""
        fn = _find_endpoint("train_from_dataset")

        mock_db = MagicMock()
        mock_db.get_dataset.return_value = _make_dataset(status="ready")
        mock_db.get_model_version.return_value = None

        base_model_uuid = "550e8400-e29b-41d4-a716-446655440099"
        config = TrainingConfig(base_model_version_id=base_model_uuid)
        request = DatasetTrainRequest(name="incremental-train", config=config)

        from fastapi import HTTPException

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
        assert exc_info.value.status_code == 404
        assert "Base model version not found" in exc_info.value.detail

399
tests/web/test_model_versions.py
Normal file
@@ -0,0 +1,399 @@
"""
Tests for Model Version API routes.
"""

import asyncio
from datetime import datetime, timezone
from unittest.mock import MagicMock
from uuid import UUID

import pytest

from inference.data.admin_models import ModelVersion
from inference.web.api.v1.admin.training import create_training_router
from inference.web.schemas.admin import (
    ModelVersionCreateRequest,
    ModelVersionUpdateRequest,
)


TEST_VERSION_UUID = "880e8400-e29b-41d4-a716-446655440020"
TEST_VERSION_UUID_2 = "880e8400-e29b-41d4-a716-446655440021"
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002"
TEST_DATASET_UUID = "880e8400-e29b-41d4-a716-446655440010"
TEST_TOKEN = "test-admin-token-12345"


def _make_model_version(**overrides) -> MagicMock:
    """Create a mock ModelVersion."""
    defaults = dict(
        version_id=UUID(TEST_VERSION_UUID),
        version="1.0.0",
        name="test-model-v1",
        description="Test model version",
        model_path="/models/test-model-v1.pt",
        status="inactive",
        is_active=False,
        task_id=UUID(TEST_TASK_UUID),
        dataset_id=UUID(TEST_DATASET_UUID),
        metrics_mAP=0.935,
        metrics_precision=0.92,
        metrics_recall=0.88,
        document_count=100,
        training_config={"epochs": 100, "batch_size": 16},
        file_size=52428800,
        trained_at=datetime(2025, 1, 15, tzinfo=timezone.utc),
        activated_at=None,
        created_at=datetime(2025, 1, 1, tzinfo=timezone.utc),
        updated_at=datetime(2025, 1, 15, tzinfo=timezone.utc),
    )
    defaults.update(overrides)
    model = MagicMock(spec=ModelVersion)
    for k, v in defaults.items():
        setattr(model, k, v)
    return model


def _find_endpoint(name: str):
    """Find endpoint function by name."""
    router = create_training_router()
    for route in router.routes:
        if hasattr(route, "endpoint") and route.endpoint.__name__ == name:
            return route.endpoint
    raise AssertionError(f"Endpoint {name} not found")


class TestModelVersionRouterRegistration:
    """Tests that model version endpoints are registered."""

    def test_router_has_model_endpoints(self):
        router = create_training_router()
        paths = [route.path for route in router.routes]
        assert any("models" in p for p in paths)

    def test_has_create_model_version_endpoint(self):
        endpoint = _find_endpoint("create_model_version")
        assert endpoint is not None

    def test_has_list_model_versions_endpoint(self):
        endpoint = _find_endpoint("list_model_versions")
        assert endpoint is not None

    def test_has_get_active_model_endpoint(self):
        endpoint = _find_endpoint("get_active_model")
        assert endpoint is not None

    def test_has_activate_model_version_endpoint(self):
        endpoint = _find_endpoint("activate_model_version")
        assert endpoint is not None


class TestCreateModelVersionRoute:
    """Tests for POST /admin/training/models."""

    def test_create_model_version(self):
        fn = _find_endpoint("create_model_version")

        mock_db = MagicMock()
        mock_db.create_model_version.return_value = _make_model_version()

        request = ModelVersionCreateRequest(
            version="1.0.0",
            name="test-model-v1",
            model_path="/models/test-model-v1.pt",
            description="Test model",
            metrics_mAP=0.935,
            document_count=100,
        )

        result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))

        mock_db.create_model_version.assert_called_once()
        assert result.version_id == TEST_VERSION_UUID
        assert result.status == "inactive"
        assert result.message == "Model version created successfully"

    def test_create_model_version_with_task_and_dataset(self):
        fn = _find_endpoint("create_model_version")

        mock_db = MagicMock()
        mock_db.create_model_version.return_value = _make_model_version()

        request = ModelVersionCreateRequest(
            version="1.0.0",
            name="test-model-v1",
            model_path="/models/test-model-v1.pt",
            task_id=TEST_TASK_UUID,
            dataset_id=TEST_DATASET_UUID,
        )

        result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))

        call_kwargs = mock_db.create_model_version.call_args[1]
        assert call_kwargs["task_id"] == TEST_TASK_UUID
        assert call_kwargs["dataset_id"] == TEST_DATASET_UUID


class TestListModelVersionsRoute:
    """Tests for GET /admin/training/models."""

    def test_list_model_versions(self):
        fn = _find_endpoint("list_model_versions")

        mock_db = MagicMock()
        mock_db.get_model_versions.return_value = (
            [_make_model_version(), _make_model_version(version_id=UUID(TEST_VERSION_UUID_2), version="1.1.0")],
            2,
        )

        result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0))

        assert result.total == 2
        assert len(result.models) == 2
        assert result.models[0].version == "1.0.0"

    def test_list_model_versions_with_status_filter(self):
        fn = _find_endpoint("list_model_versions")

        mock_db = MagicMock()
        mock_db.get_model_versions.return_value = ([_make_model_version(status="active", is_active=True)], 1)

        result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status="active", limit=20, offset=0))

        mock_db.get_model_versions.assert_called_once_with(status="active", limit=20, offset=0)
        assert result.total == 1
        assert result.models[0].status == "active"


class TestGetActiveModelRoute:
    """Tests for GET /admin/training/models/active."""

    def test_get_active_model_when_exists(self):
        fn = _find_endpoint("get_active_model")

        mock_db = MagicMock()
        mock_db.get_active_model_version.return_value = _make_model_version(status="active", is_active=True)

        result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db))

        assert result.has_active_model is True
        assert result.model is not None
        assert result.model.is_active is True

    def test_get_active_model_when_none(self):
        fn = _find_endpoint("get_active_model")

        mock_db = MagicMock()
        mock_db.get_active_model_version.return_value = None

        result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db))

        assert result.has_active_model is False
        assert result.model is None


class TestGetModelVersionRoute:
    """Tests for GET /admin/training/models/{version_id}."""

    def test_get_model_version(self):
        fn = _find_endpoint("get_model_version")

        mock_db = MagicMock()
        mock_db.get_model_version.return_value = _make_model_version()

        result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))

        assert result.version_id == TEST_VERSION_UUID
        assert result.version == "1.0.0"
        assert result.name == "test-model-v1"
        assert result.metrics_mAP == 0.935

    def test_get_model_version_not_found(self):
        fn = _find_endpoint("get_model_version")

        mock_db = MagicMock()
        mock_db.get_model_version.return_value = None

        from fastapi import HTTPException

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
        assert exc_info.value.status_code == 404


class TestUpdateModelVersionRoute:
    """Tests for PATCH /admin/training/models/{version_id}."""

    def test_update_model_version(self):
        fn = _find_endpoint("update_model_version")

        mock_db = MagicMock()
        mock_db.update_model_version.return_value = _make_model_version(name="updated-name")

        request = ModelVersionUpdateRequest(name="updated-name", description="Updated description")

        result = asyncio.run(fn(version_id=TEST_VERSION_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))

        mock_db.update_model_version.assert_called_once_with(
            version_id=TEST_VERSION_UUID,
            name="updated-name",
            description="Updated description",
            status=None,
        )
        assert result.message == "Model version updated successfully"

    def test_update_model_version_not_found(self):
        fn = _find_endpoint("update_model_version")

        mock_db = MagicMock()
        mock_db.update_model_version.return_value = None

        request = ModelVersionUpdateRequest(name="updated-name")

        from fastapi import HTTPException

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(fn(version_id=TEST_VERSION_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
        assert exc_info.value.status_code == 404


class TestActivateModelVersionRoute:
    """Tests for POST /admin/training/models/{version_id}/activate."""

    def test_activate_model_version(self):
        fn = _find_endpoint("activate_model_version")

        mock_db = MagicMock()
        mock_db.activate_model_version.return_value = _make_model_version(status="active", is_active=True)

        result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))

        mock_db.activate_model_version.assert_called_once_with(TEST_VERSION_UUID)
        assert result.status == "active"
        assert result.message == "Model version activated for inference"

    def test_activate_model_version_not_found(self):
        fn = _find_endpoint("activate_model_version")

        mock_db = MagicMock()
        mock_db.activate_model_version.return_value = None

        from fastapi import HTTPException

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
        assert exc_info.value.status_code == 404


class TestDeactivateModelVersionRoute:
    """Tests for POST /admin/training/models/{version_id}/deactivate."""

    def test_deactivate_model_version(self):
        fn = _find_endpoint("deactivate_model_version")

        mock_db = MagicMock()
        mock_db.deactivate_model_version.return_value = _make_model_version(status="inactive", is_active=False)

        result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))

        assert result.status == "inactive"
        assert result.message == "Model version deactivated"

    def test_deactivate_model_version_not_found(self):
        fn = _find_endpoint("deactivate_model_version")

        mock_db = MagicMock()
        mock_db.deactivate_model_version.return_value = None

        from fastapi import HTTPException

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
        assert exc_info.value.status_code == 404


class TestArchiveModelVersionRoute:
    """Tests for POST /admin/training/models/{version_id}/archive."""

    def test_archive_model_version(self):
        fn = _find_endpoint("archive_model_version")

        mock_db = MagicMock()
        mock_db.archive_model_version.return_value = _make_model_version(status="archived")

        result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))

        assert result.status == "archived"
        assert result.message == "Model version archived"

    def test_archive_active_model_fails(self):
        fn = _find_endpoint("archive_model_version")

        mock_db = MagicMock()
        mock_db.archive_model_version.return_value = None

        from fastapi import HTTPException

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
        assert exc_info.value.status_code == 400


class TestDeleteModelVersionRoute:
    """Tests for DELETE /admin/training/models/{version_id}."""

    def test_delete_model_version(self):
        fn = _find_endpoint("delete_model_version")

        mock_db = MagicMock()
        mock_db.delete_model_version.return_value = True

        result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))

        mock_db.delete_model_version.assert_called_once_with(TEST_VERSION_UUID)
        assert result["message"] == "Model version deleted"

    def test_delete_active_model_fails(self):
        fn = _find_endpoint("delete_model_version")

        mock_db = MagicMock()
        mock_db.delete_model_version.return_value = False

        from fastapi import HTTPException

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
        assert exc_info.value.status_code == 400
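
    # Both archive and delete map a falsy DB result onto HTTP 400 for an active
    # version. A sketch of the route-side guard (assumed wording):
    #
    #     if not db.delete_model_version(version_id):
    #         raise HTTPException(status_code=400, detail="Cannot delete active model version")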


class TestModelVersionSchemas:
    """Tests for model version Pydantic schemas."""

    def test_create_request_validation(self):
        request = ModelVersionCreateRequest(
            version="1.0.0",
            name="test-model",
            model_path="/models/test.pt",
        )
        assert request.version == "1.0.0"
        assert request.name == "test-model"
        assert request.document_count == 0

    def test_create_request_with_metrics(self):
        request = ModelVersionCreateRequest(
            version="2.0.0",
            name="test-model-v2",
            model_path="/models/v2.pt",
            metrics_mAP=0.95,
            metrics_precision=0.92,
            metrics_recall=0.88,
            document_count=500,
        )
        assert request.metrics_mAP == 0.95
        assert request.document_count == 500

    def test_update_request_partial(self):
        request = ModelVersionUpdateRequest(name="new-name")
        assert request.name == "new-name"
        assert request.description is None
        assert request.status is None