This commit is contained in:
Yaojia Wang
2026-01-30 00:44:21 +01:00
parent d2489a97d4
commit 33ada0350d
79 changed files with 9737 additions and 297 deletions

View File

@@ -0,0 +1,293 @@
"""
Tests for DatasetAugmenter.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
import tempfile
from pathlib import Path
import numpy as np
import pytest
from PIL import Image
class TestDatasetAugmenter:
    """Tests for DatasetAugmenter class.

    Each test imports ``DatasetAugmenter`` inside the test body so that
    test *collection* succeeds even before the implementation exists
    (TDD RED phase — see module docstring).
    """

    @staticmethod
    def _count_images(split_dir: Path) -> int:
        """Count PNG images directly inside *split_dir* (non-recursive)."""
        return len(list(split_dir.glob("*.png")))

    @pytest.fixture
    def sample_dataset(self, tmp_path: Path) -> Path:
        """Create a minimal YOLO dataset: 3 train images, 2 bboxes each.

        Layout::

            dataset/
                images/{train,val,test}/   # only train is populated
                labels/{train,val,test}/
                data.yaml                  # nc: 10, names class0..class9
        """
        dataset_dir = tmp_path / "dataset"
        # Create directory structure for all three splits.
        for split in ["train", "val", "test"]:
            (dataset_dir / "images" / split).mkdir(parents=True)
            (dataset_dir / "labels" / split).mkdir(parents=True)
        # Create sample images and matching labels in the train split.
        for i in range(3):
            # 100x100 solid-white image — augmentations like noise/blur
            # are easy to detect against a uniform background.
            img = Image.new("RGB", (100, 100), color="white")
            img_path = dataset_dir / "images" / "train" / f"doc_{i}.png"
            img.save(img_path)
            # Two bboxes per label file.
            # YOLO format: class_id x_center y_center width height (normalized)
            label_content = "0 0.5 0.3 0.2 0.1\n1 0.7 0.6 0.15 0.2\n"
            label_path = dataset_dir / "labels" / "train" / f"doc_{i}.txt"
            label_path.write_text(label_content)
        # data.yaml declares 10 classes even though only ids 0/1 are used.
        data_yaml = dataset_dir / "data.yaml"
        data_yaml.write_text(
            "path: .\n"
            "train: images/train\n"
            "val: images/val\n"
            "test: images/test\n"
            "nc: 10\n"
            "names: [class0, class1, class2, class3, class4, class5, class6, class7, class8, class9]\n"
        )
        return dataset_dir

    @pytest.fixture
    def augmentation_config(self) -> dict:
        """Config with two pixel-level augmentations, always applied.

        probability=1.0 makes every test deterministic in *whether* an
        augmentation fires (the noise itself may still be random).
        """
        return {
            "gaussian_noise": {
                "enabled": True,
                "probability": 1.0,
                "params": {"std": 10},
            },
            "gaussian_blur": {
                "enabled": True,
                "probability": 1.0,
                "params": {"kernel_size": 3},
            },
        }

    def test_augmenter_creates_additional_images(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that augmenter creates new augmented images."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        augmenter = DatasetAugmenter(augmentation_config)
        # Count original images
        train_images = sample_dataset / "images" / "train"
        original_count = self._count_images(train_images)
        assert original_count == 3
        # Apply augmentation with multiplier=2
        result = augmenter.augment_dataset(sample_dataset, multiplier=2)
        # Should now have original + 2x augmented = 3 + 6 = 9 images
        assert self._count_images(train_images) == 9
        assert result["augmented_images"] == 6

    def test_augmenter_creates_matching_labels(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that augmenter creates label files for each augmented image."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        augmenter = DatasetAugmenter(augmentation_config)
        augmenter.augment_dataset(sample_dataset, multiplier=2)
        # Image/label counts must stay 1:1 after augmentation.
        images = list((sample_dataset / "images" / "train").glob("*.png"))
        labels = list((sample_dataset / "labels" / "train").glob("*.txt"))
        assert len(images) == len(labels)
        # Every image (original or augmented) must have a same-stem label.
        for img_path in images:
            label_path = sample_dataset / "labels" / "train" / f"{img_path.stem}.txt"
            assert label_path.exists(), f"Missing label for {img_path.name}"

    def test_augmented_labels_have_valid_format(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that augmented label files have valid YOLO format."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        augmenter = DatasetAugmenter(augmentation_config)
        augmenter.augment_dataset(sample_dataset, multiplier=1)
        # Check all label files
        for label_path in (sample_dataset / "labels" / "train").glob("*.txt"):
            content = label_path.read_text().strip()
            if not content:
                continue  # Empty labels are valid (background images)
            for line in content.split("\n"):
                parts = line.split()
                assert len(parts) == 5, f"Invalid label format in {label_path.name}"
                class_id = int(parts[0])
                x_center = float(parts[1])
                y_center = float(parts[2])
                width = float(parts[3])
                height = float(parts[4])
                # class_id must be within nc (=10 in the fixture's data.yaml);
                # augmentation must never invent new class ids.
                assert 0 <= class_id < 10, f"Invalid class_id: {class_id}"
                # Normalized coordinates must stay within [0, 1].
                assert 0 <= x_center <= 1, f"Invalid x_center: {x_center}"
                assert 0 <= y_center <= 1, f"Invalid y_center: {y_center}"
                assert 0 <= width <= 1, f"Invalid width: {width}"
                assert 0 <= height <= 1, f"Invalid height: {height}"

    def test_augmented_images_are_different(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that augmented images are actually different from originals."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        # Load original image before augmentation runs.
        original_path = sample_dataset / "images" / "train" / "doc_0.png"
        original_img = np.array(Image.open(original_path))
        augmenter = DatasetAugmenter(augmentation_config)
        augmenter.augment_dataset(sample_dataset, multiplier=1)
        # Augmented files are expected to follow the "<stem>_aug<k>" naming
        # convention — this pins that contract for the implementation.
        aug_path = sample_dataset / "images" / "train" / "doc_0_aug0.png"
        assert aug_path.exists()
        aug_img = np.array(Image.open(aug_path))
        # Images should be different (due to noise/blur)
        assert not np.array_equal(original_img, aug_img)

    def test_augmented_images_same_size(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that augmented images have same size as originals."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        # Get original size
        original_path = sample_dataset / "images" / "train" / "doc_0.png"
        original_size = Image.open(original_path).size
        augmenter = DatasetAugmenter(augmentation_config)
        augmenter.augment_dataset(sample_dataset, multiplier=1)
        # Pixel-level augmentations must not resize the image.
        for img_path in (sample_dataset / "images" / "train").glob("*_aug*.png"):
            img = Image.open(img_path)
            assert img.size == original_size, f"{img_path.name} has wrong size"

    def test_perspective_warp_updates_bboxes(self, sample_dataset: Path):
        """Test that perspective_warp augmentation updates bbox coordinates."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        config = {
            "perspective_warp": {
                "enabled": True,
                "probability": 1.0,
                "params": {"max_warp": 0.05},  # Use larger warp for visible difference
            },
        }
        # Read original label
        original_label = (sample_dataset / "labels" / "train" / "doc_0.txt").read_text()
        original_bboxes = [line.split() for line in original_label.strip().split("\n")]
        augmenter = DatasetAugmenter(config)
        augmenter.augment_dataset(sample_dataset, multiplier=1)
        # Read augmented label
        aug_label = (sample_dataset / "labels" / "train" / "doc_0_aug0.txt").read_text()
        aug_bboxes = [line.split() for line in aug_label.strip().split("\n")]
        # A geometric warp must keep the bbox count (no drops/duplicates here,
        # since both boxes sit well inside the image).
        assert len(original_bboxes) == len(aug_bboxes)
        # At least one bbox should have different coordinates
        # (perspective warp changes geometry)
        differences_found = False
        for orig, aug in zip(original_bboxes, aug_bboxes):
            # Class ID should be same
            assert orig[0] == aug[0]
            # Coordinates might differ
            if orig[1:] != aug[1:]:
                differences_found = True
        assert differences_found, "Perspective warp should change bbox coordinates"

    def test_augmenter_only_processes_train_split(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that augmenter only processes train split by default."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        # Add a val image — it must be left untouched by augmentation.
        val_img = Image.new("RGB", (100, 100), color="white")
        val_img.save(sample_dataset / "images" / "val" / "val_doc.png")
        (sample_dataset / "labels" / "val" / "val_doc.txt").write_text("0 0.5 0.5 0.1 0.1\n")
        augmenter = DatasetAugmenter(augmentation_config)
        augmenter.augment_dataset(sample_dataset, multiplier=2)
        # Val should still have only 1 image
        assert self._count_images(sample_dataset / "images" / "val") == 1

    def test_augmenter_with_multiplier_zero_does_nothing(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that multiplier=0 creates no augmented images."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        train_images = sample_dataset / "images" / "train"
        original_count = self._count_images(train_images)
        augmenter = DatasetAugmenter(augmentation_config)
        result = augmenter.augment_dataset(sample_dataset, multiplier=0)
        # multiplier=0 must be a no-op, not an error.
        assert self._count_images(train_images) == original_count
        assert result["augmented_images"] == 0

    def test_augmenter_with_seed_is_reproducible(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that same seed produces same augmentation results."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        # Create two separate copies of the same dataset so each augmenter
        # starts from identical inputs.
        import shutil

        dataset1 = sample_dataset
        dataset2 = sample_dataset.parent / "dataset2"
        shutil.copytree(dataset1, dataset2)
        # Augment both with same seed
        augmenter1 = DatasetAugmenter(augmentation_config, seed=42)
        augmenter1.augment_dataset(dataset1, multiplier=1)
        augmenter2 = DatasetAugmenter(augmentation_config, seed=42)
        augmenter2.augment_dataset(dataset2, multiplier=1)
        # Pixel-for-pixel identical outputs expected for identical seeds.
        aug1 = np.array(Image.open(dataset1 / "images" / "train" / "doc_0_aug0.png"))
        aug2 = np.array(Image.open(dataset2 / "images" / "train" / "doc_0_aug0.png"))
        assert np.array_equal(aug1, aug2), "Same seed should produce same augmentation"

    def test_augmenter_returns_summary(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that augmenter returns a summary of what was done."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        augmenter = DatasetAugmenter(augmentation_config)
        result = augmenter.augment_dataset(sample_dataset, multiplier=2)
        # Summary dict contract: counts before/after augmentation.
        assert "original_images" in result
        assert "augmented_images" in result
        assert "total_images" in result
        assert result["original_images"] == 3
        assert result["augmented_images"] == 6
        assert result["total_images"] == 9