"""Tests for DatasetAugmenter.

TDD Phase 1: RED - Write tests first, then implement to pass.
"""

import tempfile
from pathlib import Path

import numpy as np
import pytest
from PIL import Image


class TestDatasetAugmenter:
    """Unit tests for the DatasetAugmenter class."""

    @pytest.fixture
    def sample_dataset(self, tmp_path: Path) -> Path:
        """Build a minimal YOLO-layout dataset: 3 train images with labels."""
        root = tmp_path / "dataset"

        # images/<split> and labels/<split> for every split.
        for split in ("train", "val", "test"):
            for kind in ("images", "labels"):
                (root / kind / split).mkdir(parents=True)

        # Three white 100x100 training images, each annotated with two
        # bboxes in YOLO format: class_id x_center y_center width height.
        for idx in range(3):
            Image.new("RGB", (100, 100), color="white").save(
                root / "images" / "train" / f"doc_{idx}.png"
            )
            (root / "labels" / "train" / f"doc_{idx}.txt").write_text(
                "0 0.5 0.3 0.2 0.1\n1 0.7 0.6 0.15 0.2\n"
            )

        # Minimal data.yaml describing the splits and the 10 classes.
        class_names = ", ".join(f"class{c}" for c in range(10))
        (root / "data.yaml").write_text(
            "path: .\n"
            "train: images/train\n"
            "val: images/val\n"
            "test: images/test\n"
            "nc: 10\n"
            f"names: [{class_names}]\n"
        )

        return root

    @pytest.fixture
    def augmentation_config(self) -> dict:
        """Config enabling noise and blur with probability 1 (always applied)."""
        return {
            "gaussian_noise": {
                "enabled": True,
                "probability": 1.0,
                "params": {"std": 10},
            },
            "gaussian_blur": {
                "enabled": True,
                "probability": 1.0,
                "params": {"kernel_size": 3},
            },
        }

    def test_augmenter_creates_additional_images(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """multiplier=2 should add two augmented copies per original image."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        train_images = sample_dataset / "images" / "train"
        assert len(list(train_images.glob("*.png"))) == 3

        summary = DatasetAugmenter(augmentation_config).augment_dataset(
            sample_dataset, multiplier=2
        )

        # 3 originals + 3 * 2 augmented = 9 images total.
        assert len(list(train_images.glob("*.png"))) == 9
        assert summary["augmented_images"] == 6

    def test_augmenter_creates_matching_labels(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Every augmented image must get a label file with the same stem."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        DatasetAugmenter(augmentation_config).augment_dataset(
            sample_dataset, multiplier=2
        )

        image_files = list((sample_dataset / "images" / "train").glob("*.png"))
        label_dir = sample_dataset / "labels" / "train"

        assert len(image_files) == len(list(label_dir.glob("*.txt")))

        for image_file in image_files:
            label_file = label_dir / f"{image_file.stem}.txt"
            assert label_file.exists(), f"Missing label for {image_file.name}"

    def test_augmented_labels_have_valid_format(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """All label files must remain valid YOLO annotations after augmenting."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        DatasetAugmenter(augmentation_config).augment_dataset(
            sample_dataset, multiplier=1
        )

        for label_file in (sample_dataset / "labels" / "train").glob("*.txt"):
            text = label_file.read_text().strip()
            if not text:
                # An empty label file is legal: it marks a background image.
                continue

            for row in text.split("\n"):
                fields = row.split()
                assert len(fields) == 5, f"Invalid label format in {label_file.name}"

                cls = int(fields[0])
                cx, cy, w, h = (float(v) for v in fields[1:])

                # Class id bounded, all coordinates normalized to [0, 1].
                assert 0 <= cls < 100, f"Invalid class_id: {cls}"
                assert 0 <= cx <= 1, f"Invalid x_center: {cx}"
                assert 0 <= cy <= 1, f"Invalid y_center: {cy}"
                assert 0 <= w <= 1, f"Invalid width: {w}"
                assert 0 <= h <= 1, f"Invalid height: {h}"

    def test_augmented_images_are_different(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Noise/blur must produce pixel data that differs from the original."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        source = sample_dataset / "images" / "train" / "doc_0.png"
        before = np.array(Image.open(source))

        DatasetAugmenter(augmentation_config).augment_dataset(
            sample_dataset, multiplier=1
        )

        augmented = sample_dataset / "images" / "train" / "doc_0_aug0.png"
        assert augmented.exists()

        after = np.array(Image.open(augmented))
        assert not np.array_equal(before, after)

    def test_augmented_images_same_size(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Augmentation must not change image dimensions."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        train_dir = sample_dataset / "images" / "train"
        expected_size = Image.open(train_dir / "doc_0.png").size

        DatasetAugmenter(augmentation_config).augment_dataset(
            sample_dataset, multiplier=1
        )

        for aug_file in train_dir.glob("*_aug*.png"):
            assert Image.open(aug_file).size == expected_size, (
                f"{aug_file.name} has wrong size"
            )

    def test_perspective_warp_updates_bboxes(self, sample_dataset: Path):
        """Geometric warps must transform bbox coordinates, not just pixels."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        warp_config = {
            "perspective_warp": {
                "enabled": True,
                "probability": 1.0,
                # Larger warp so the coordinate change is clearly visible.
                "params": {"max_warp": 0.05},
            },
        }

        labels_dir = sample_dataset / "labels" / "train"
        before = [
            row.split()
            for row in (labels_dir / "doc_0.txt").read_text().strip().split("\n")
        ]

        DatasetAugmenter(warp_config).augment_dataset(sample_dataset, multiplier=1)

        after = [
            row.split()
            for row in (labels_dir / "doc_0_aug0.txt").read_text().strip().split("\n")
        ]

        # Warping never adds or drops boxes.
        assert len(before) == len(after)

        moved = False
        for old, new in zip(before, after):
            # Class ids are never remapped by a geometric transform.
            assert old[0] == new[0]
            if old[1:] != new[1:]:
                moved = True

        assert moved, "Perspective warp should change bbox coordinates"

    def test_augmenter_only_processes_train_split(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """By default only the train split is augmented; val stays untouched."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        val_images = sample_dataset / "images" / "val"
        Image.new("RGB", (100, 100), color="white").save(val_images / "val_doc.png")
        (sample_dataset / "labels" / "val" / "val_doc.txt").write_text(
            "0 0.5 0.5 0.1 0.1\n"
        )

        DatasetAugmenter(augmentation_config).augment_dataset(
            sample_dataset, multiplier=2
        )

        assert len(list(val_images.glob("*.png"))) == 1

    def test_augmenter_with_multiplier_zero_does_nothing(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """multiplier=0 is a no-op: no new files, zero reported augmentations."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        train_images = sample_dataset / "images" / "train"
        count_before = len(list(train_images.glob("*.png")))

        summary = DatasetAugmenter(augmentation_config).augment_dataset(
            sample_dataset, multiplier=0
        )

        assert len(list(train_images.glob("*.png"))) == count_before
        assert summary["augmented_images"] == 0

    def test_augmenter_with_seed_is_reproducible(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Two runs with the same seed must produce identical pixel output."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        import shutil

        # Clone the dataset so both runs start from identical inputs.
        copy_root = sample_dataset.parent / "dataset2"
        shutil.copytree(sample_dataset, copy_root)

        DatasetAugmenter(augmentation_config, seed=42).augment_dataset(
            sample_dataset, multiplier=1
        )
        DatasetAugmenter(augmentation_config, seed=42).augment_dataset(
            copy_root, multiplier=1
        )

        first = np.array(
            Image.open(sample_dataset / "images" / "train" / "doc_0_aug0.png")
        )
        second = np.array(
            Image.open(copy_root / "images" / "train" / "doc_0_aug0.png")
        )

        assert np.array_equal(first, second), "Same seed should produce same augmentation"

    def test_augmenter_returns_summary(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """augment_dataset must report original/augmented/total image counts."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        summary = DatasetAugmenter(augmentation_config).augment_dataset(
            sample_dataset, multiplier=2
        )

        for key in ("original_images", "augmented_images", "total_images"):
            assert key in summary
        assert summary["original_images"] == 3
        assert summary["augmented_images"] == 6
        assert summary["total_images"] == 9