""" Tests for DatasetAugmenter. TDD Phase 1: RED - Write tests first, then implement to pass. """ import tempfile from pathlib import Path import numpy as np import pytest from PIL import Image class TestDatasetAugmenter: """Tests for DatasetAugmenter class.""" @pytest.fixture def sample_dataset(self, tmp_path: Path) -> Path: """Create a sample YOLO dataset structure.""" dataset_dir = tmp_path / "dataset" # Create directory structure for split in ["train", "val", "test"]: (dataset_dir / "images" / split).mkdir(parents=True) (dataset_dir / "labels" / split).mkdir(parents=True) # Create sample images and labels for i in range(3): # Create 100x100 white image img = Image.new("RGB", (100, 100), color="white") img_path = dataset_dir / "images" / "train" / f"doc_{i}.png" img.save(img_path) # Create label with 2 bboxes # Format: class_id x_center y_center width height label_content = "0 0.5 0.3 0.2 0.1\n1 0.7 0.6 0.15 0.2\n" label_path = dataset_dir / "labels" / "train" / f"doc_{i}.txt" label_path.write_text(label_content) # Create data.yaml data_yaml = dataset_dir / "data.yaml" data_yaml.write_text( "path: .\n" "train: images/train\n" "val: images/val\n" "test: images/test\n" "nc: 10\n" "names: [class0, class1, class2, class3, class4, class5, class6, class7, class8, class9]\n" ) return dataset_dir @pytest.fixture def augmentation_config(self) -> dict: """Create a sample augmentation config.""" return { "gaussian_noise": { "enabled": True, "probability": 1.0, "params": {"std": 10}, }, "gaussian_blur": { "enabled": True, "probability": 1.0, "params": {"kernel_size": 3}, }, } def test_augmenter_creates_additional_images( self, sample_dataset: Path, augmentation_config: dict ): """Test that augmenter creates new augmented images.""" from shared.augmentation.dataset_augmenter import DatasetAugmenter augmenter = DatasetAugmenter(augmentation_config) # Count original images original_count = len(list((sample_dataset / "images" / "train").glob("*.png"))) assert original_count == 3 # Apply augmentation with multiplier=2 result = augmenter.augment_dataset(sample_dataset, multiplier=2) # Should now have original + 2x augmented = 3 + 6 = 9 images new_count = len(list((sample_dataset / "images" / "train").glob("*.png"))) assert new_count == 9 assert result["augmented_images"] == 6 def test_augmenter_creates_matching_labels( self, sample_dataset: Path, augmentation_config: dict ): """Test that augmenter creates label files for each augmented image.""" from shared.augmentation.dataset_augmenter import DatasetAugmenter augmenter = DatasetAugmenter(augmentation_config) augmenter.augment_dataset(sample_dataset, multiplier=2) # Check that each image has a matching label file images = list((sample_dataset / "images" / "train").glob("*.png")) labels = list((sample_dataset / "labels" / "train").glob("*.txt")) assert len(images) == len(labels) # Check that augmented images have corresponding labels for img_path in images: label_path = sample_dataset / "labels" / "train" / f"{img_path.stem}.txt" assert label_path.exists(), f"Missing label for {img_path.name}" def test_augmented_labels_have_valid_format( self, sample_dataset: Path, augmentation_config: dict ): """Test that augmented label files have valid YOLO format.""" from shared.augmentation.dataset_augmenter import DatasetAugmenter augmenter = DatasetAugmenter(augmentation_config) augmenter.augment_dataset(sample_dataset, multiplier=1) # Check all label files for label_path in (sample_dataset / "labels" / "train").glob("*.txt"): content = 


class TestDatasetAugmenter:
    """Tests for DatasetAugmenter class."""

    @pytest.fixture
    def sample_dataset(self, tmp_path: Path) -> Path:
        """Create a sample YOLO dataset structure."""
        dataset_dir = tmp_path / "dataset"

        # Create directory structure
        for split in ["train", "val", "test"]:
            (dataset_dir / "images" / split).mkdir(parents=True)
            (dataset_dir / "labels" / split).mkdir(parents=True)

        # Create sample images and labels
        for i in range(3):
            # Create 100x100 white image
            img = Image.new("RGB", (100, 100), color="white")
            img_path = dataset_dir / "images" / "train" / f"doc_{i}.png"
            img.save(img_path)

            # Create label with 2 bboxes
            # Format: class_id x_center y_center width height
            label_content = "0 0.5 0.3 0.2 0.1\n1 0.7 0.6 0.15 0.2\n"
            label_path = dataset_dir / "labels" / "train" / f"doc_{i}.txt"
            label_path.write_text(label_content)

        # Create data.yaml
        data_yaml = dataset_dir / "data.yaml"
        data_yaml.write_text(
            "path: .\n"
            "train: images/train\n"
            "val: images/val\n"
            "test: images/test\n"
            "nc: 10\n"
            "names: [class0, class1, class2, class3, class4, class5, class6, class7, class8, class9]\n"
        )

        return dataset_dir

    @pytest.fixture
    def augmentation_config(self) -> dict:
        """Create a sample augmentation config."""
        return {
            "gaussian_noise": {
                "enabled": True,
                "probability": 1.0,
                "params": {"std": 10},
            },
            "gaussian_blur": {
                "enabled": True,
                "probability": 1.0,
                "params": {"kernel_size": 3},
            },
        }

    def test_augmenter_creates_additional_images(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that augmenter creates new augmented images."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        augmenter = DatasetAugmenter(augmentation_config)

        # Count original images
        original_count = len(list((sample_dataset / "images" / "train").glob("*.png")))
        assert original_count == 3

        # Apply augmentation with multiplier=2
        result = augmenter.augment_dataset(sample_dataset, multiplier=2)

        # Should now have original + 2x augmented = 3 + 6 = 9 images
        new_count = len(list((sample_dataset / "images" / "train").glob("*.png")))
        assert new_count == 9
        assert result["augmented_images"] == 6

    def test_augmenter_creates_matching_labels(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that augmenter creates label files for each augmented image."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        augmenter = DatasetAugmenter(augmentation_config)
        augmenter.augment_dataset(sample_dataset, multiplier=2)

        # Check that each image has a matching label file
        images = list((sample_dataset / "images" / "train").glob("*.png"))
        labels = list((sample_dataset / "labels" / "train").glob("*.txt"))
        assert len(images) == len(labels)

        # Check that augmented images have corresponding labels
        for img_path in images:
            label_path = sample_dataset / "labels" / "train" / f"{img_path.stem}.txt"
            assert label_path.exists(), f"Missing label for {img_path.name}"

    def test_augmented_labels_have_valid_format(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that augmented label files have valid YOLO format."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        augmenter = DatasetAugmenter(augmentation_config)
        augmenter.augment_dataset(sample_dataset, multiplier=1)

        # Check all label files
        for label_path in (sample_dataset / "labels" / "train").glob("*.txt"):
            content = label_path.read_text().strip()
            if not content:
                continue  # Empty labels are valid (background images)

            for line in content.split("\n"):
                parts = line.split()
                assert len(parts) == 5, f"Invalid label format in {label_path.name}"

                class_id = int(parts[0])
                x_center = float(parts[1])
                y_center = float(parts[2])
                width = float(parts[3])
                height = float(parts[4])

                # Check values are in valid range
                assert 0 <= class_id < 100, f"Invalid class_id: {class_id}"
                assert 0 <= x_center <= 1, f"Invalid x_center: {x_center}"
                assert 0 <= y_center <= 1, f"Invalid y_center: {y_center}"
                assert 0 <= width <= 1, f"Invalid width: {width}"
                assert 0 <= height <= 1, f"Invalid height: {height}"

    def test_augmented_images_are_different(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that augmented images are actually different from originals."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        # Load original image
        original_path = sample_dataset / "images" / "train" / "doc_0.png"
        original_img = np.array(Image.open(original_path))

        augmenter = DatasetAugmenter(augmentation_config)
        augmenter.augment_dataset(sample_dataset, multiplier=1)

        # Find augmented version
        aug_path = sample_dataset / "images" / "train" / "doc_0_aug0.png"
        assert aug_path.exists()
        aug_img = np.array(Image.open(aug_path))

        # Images should be different (due to noise/blur)
        assert not np.array_equal(original_img, aug_img)

    def test_augmented_images_same_size(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that augmented images have same size as originals."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        # Get original size
        original_path = sample_dataset / "images" / "train" / "doc_0.png"
        original_img = Image.open(original_path)
        original_size = original_img.size

        augmenter = DatasetAugmenter(augmentation_config)
        augmenter.augment_dataset(sample_dataset, multiplier=1)

        # Check all augmented images have same size
        for img_path in (sample_dataset / "images" / "train").glob("*_aug*.png"):
            img = Image.open(img_path)
            assert img.size == original_size, f"{img_path.name} has wrong size"

    def test_perspective_warp_updates_bboxes(self, sample_dataset: Path):
        """Test that perspective_warp augmentation updates bbox coordinates."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        config = {
            "perspective_warp": {
                "enabled": True,
                "probability": 1.0,
                "params": {"max_warp": 0.05},  # Use larger warp for visible difference
            },
        }

        # Read original label
        original_label = (sample_dataset / "labels" / "train" / "doc_0.txt").read_text()
        original_bboxes = [line.split() for line in original_label.strip().split("\n")]

        augmenter = DatasetAugmenter(config)
        augmenter.augment_dataset(sample_dataset, multiplier=1)

        # Read augmented label
        aug_label = (sample_dataset / "labels" / "train" / "doc_0_aug0.txt").read_text()
        aug_bboxes = [line.split() for line in aug_label.strip().split("\n")]

        # Same number of bboxes
        assert len(original_bboxes) == len(aug_bboxes)

        # At least one bbox should have different coordinates
        # (perspective warp changes geometry)
        differences_found = False
        for orig, aug in zip(original_bboxes, aug_bboxes):
            # Class ID should be same
            assert orig[0] == aug[0]
            # Coordinates might differ
            if orig[1:] != aug[1:]:
                differences_found = True

        assert differences_found, "Perspective warp should change bbox coordinates"

    def test_augmenter_only_processes_train_split(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that augmenter only processes train split by default."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        # Add a val image
        val_img = Image.new("RGB", (100, 100), color="white")
        val_img.save(sample_dataset / "images" / "val" / "val_doc.png")
        (sample_dataset / "labels" / "val" / "val_doc.txt").write_text("0 0.5 0.5 0.1 0.1\n")

        augmenter = DatasetAugmenter(augmentation_config)
        augmenter.augment_dataset(sample_dataset, multiplier=2)

        # Val should still have only 1 image
        val_count = len(list((sample_dataset / "images" / "val").glob("*.png")))
        assert val_count == 1

    def test_augmenter_with_multiplier_zero_does_nothing(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that multiplier=0 creates no augmented images."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        original_count = len(list((sample_dataset / "images" / "train").glob("*.png")))

        augmenter = DatasetAugmenter(augmentation_config)
        result = augmenter.augment_dataset(sample_dataset, multiplier=0)

        new_count = len(list((sample_dataset / "images" / "train").glob("*.png")))
        assert new_count == original_count
        assert result["augmented_images"] == 0

    def test_augmenter_with_seed_is_reproducible(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that same seed produces same augmentation results."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        # Create two separate datasets
        import shutil

        dataset1 = sample_dataset
        dataset2 = sample_dataset.parent / "dataset2"
        shutil.copytree(dataset1, dataset2)

        # Augment both with same seed
        augmenter1 = DatasetAugmenter(augmentation_config, seed=42)
        augmenter1.augment_dataset(dataset1, multiplier=1)

        augmenter2 = DatasetAugmenter(augmentation_config, seed=42)
        augmenter2.augment_dataset(dataset2, multiplier=1)

        # Compare augmented images
        aug1 = np.array(Image.open(dataset1 / "images" / "train" / "doc_0_aug0.png"))
        aug2 = np.array(Image.open(dataset2 / "images" / "train" / "doc_0_aug0.png"))

        assert np.array_equal(aug1, aug2), "Same seed should produce same augmentation"

    def test_augmenter_returns_summary(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that augmenter returns a summary of what was done."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        augmenter = DatasetAugmenter(augmentation_config)
        result = augmenter.augment_dataset(sample_dataset, multiplier=2)

        assert "original_images" in result
        assert "augmented_images" in result
        assert "total_images" in result
        assert result["original_images"] == 3
        assert result["augmented_images"] == 6
        assert result["total_images"] == 9