""" Comprehensive unit tests for Data Mixing Service. Tests the data mixing service functions for YOLO fine-tuning: - Mixing ratio calculation based on sample counts - Dataset building with old/new sample mixing - Image collection and path conversion - Pool document matching """ from pathlib import Path from uuid import UUID, uuid4 import pytest from backend.web.services.data_mixer import ( get_mixing_ratio, build_mixed_dataset, _collect_images, _image_to_label_path, _find_pool_images, MIXING_RATIOS, DEFAULT_MULTIPLIER, MAX_OLD_SAMPLES, MIN_POOL_SIZE, ) class TestGetMixingRatio: """Tests for get_mixing_ratio function.""" def test_mixing_ratio_at_first_threshold(self): """Test mixing ratio at first threshold boundary (10 samples).""" assert get_mixing_ratio(1) == 50 assert get_mixing_ratio(5) == 50 assert get_mixing_ratio(10) == 50 def test_mixing_ratio_at_second_threshold(self): """Test mixing ratio at second threshold boundary (50 samples).""" assert get_mixing_ratio(11) == 20 assert get_mixing_ratio(30) == 20 assert get_mixing_ratio(50) == 20 def test_mixing_ratio_at_third_threshold(self): """Test mixing ratio at third threshold boundary (200 samples).""" assert get_mixing_ratio(51) == 10 assert get_mixing_ratio(100) == 10 assert get_mixing_ratio(200) == 10 def test_mixing_ratio_at_fourth_threshold(self): """Test mixing ratio at fourth threshold boundary (500 samples).""" assert get_mixing_ratio(201) == 5 assert get_mixing_ratio(350) == 5 assert get_mixing_ratio(500) == 5 def test_mixing_ratio_above_all_thresholds(self): """Test mixing ratio for samples above all thresholds.""" assert get_mixing_ratio(501) == DEFAULT_MULTIPLIER assert get_mixing_ratio(1000) == DEFAULT_MULTIPLIER assert get_mixing_ratio(10000) == DEFAULT_MULTIPLIER def test_mixing_ratio_boundary_values(self): """Test exact threshold boundaries match expected ratios.""" # Verify threshold boundaries from MIXING_RATIOS for threshold, expected_multiplier in MIXING_RATIOS: assert get_mixing_ratio(threshold) == expected_multiplier # One above threshold should give next ratio if threshold < MIXING_RATIOS[-1][0]: next_idx = MIXING_RATIOS.index((threshold, expected_multiplier)) + 1 next_multiplier = MIXING_RATIOS[next_idx][1] assert get_mixing_ratio(threshold + 1) == next_multiplier class TestCollectImages: """Tests for _collect_images function.""" def test_collect_images_empty_directory(self, tmp_path): """Test collecting images from empty directory.""" images_dir = tmp_path / "images" images_dir.mkdir() result = _collect_images(images_dir) assert result == [] def test_collect_images_nonexistent_directory(self, tmp_path): """Test collecting images from non-existent directory.""" images_dir = tmp_path / "nonexistent" result = _collect_images(images_dir) assert result == [] def test_collect_png_images(self, tmp_path): """Test collecting PNG images.""" images_dir = tmp_path / "images" images_dir.mkdir() # Create PNG files (images_dir / "img1.png").touch() (images_dir / "img2.png").touch() (images_dir / "img3.png").touch() result = _collect_images(images_dir) assert len(result) == 3 assert all(img.suffix == ".png" for img in result) # Verify sorted order assert result == sorted(result) def test_collect_jpg_images(self, tmp_path): """Test collecting JPG images.""" images_dir = tmp_path / "images" images_dir.mkdir() # Create JPG files (images_dir / "img1.jpg").touch() (images_dir / "img2.jpg").touch() result = _collect_images(images_dir) assert len(result) == 2 assert all(img.suffix == ".jpg" for img in result) def test_collect_mixed_image_types(self, tmp_path): """Test collecting both PNG and JPG images.""" images_dir = tmp_path / "images" images_dir.mkdir() # Create mixed files (images_dir / "img1.png").touch() (images_dir / "img2.jpg").touch() (images_dir / "img3.png").touch() (images_dir / "img4.jpg").touch() result = _collect_images(images_dir) assert len(result) == 4 # PNG files should come first (sorted separately) png_files = [r for r in result if r.suffix == ".png"] jpg_files = [r for r in result if r.suffix == ".jpg"] assert len(png_files) == 2 assert len(jpg_files) == 2 def test_collect_images_ignores_other_files(self, tmp_path): """Test that non-image files are ignored.""" images_dir = tmp_path / "images" images_dir.mkdir() # Create various files (images_dir / "img1.png").touch() (images_dir / "img2.jpg").touch() (images_dir / "doc.txt").touch() (images_dir / "data.json").touch() (images_dir / "notes.md").touch() result = _collect_images(images_dir) assert len(result) == 2 assert all(img.suffix in [".png", ".jpg"] for img in result) class TestImageToLabelPath: """Tests for _image_to_label_path function.""" def test_image_to_label_path_train(self, tmp_path): """Test converting train image path to label path.""" base = tmp_path / "dataset" image_path = base / "images" / "train" / "doc123_page1.png" label_path = _image_to_label_path(image_path) expected = base / "labels" / "train" / "doc123_page1.txt" assert label_path == expected def test_image_to_label_path_val(self, tmp_path): """Test converting val image path to label path.""" base = tmp_path / "dataset" image_path = base / "images" / "val" / "doc456_page2.jpg" label_path = _image_to_label_path(image_path) expected = base / "labels" / "val" / "doc456_page2.txt" assert label_path == expected def test_image_to_label_path_test(self, tmp_path): """Test converting test image path to label path.""" base = tmp_path / "dataset" image_path = base / "images" / "test" / "doc789_page3.png" label_path = _image_to_label_path(image_path) expected = base / "labels" / "test" / "doc789_page3.txt" assert label_path == expected def test_image_to_label_path_preserves_filename(self, tmp_path): """Test that filename (without extension) is preserved.""" base = tmp_path / "dataset" image_path = base / "images" / "train" / "complex_filename_123_page5.png" label_path = _image_to_label_path(image_path) assert label_path.stem == "complex_filename_123_page5" assert label_path.suffix == ".txt" def test_image_to_label_path_jpg_to_txt(self, tmp_path): """Test that JPG extension is converted to TXT.""" base = tmp_path / "dataset" image_path = base / "images" / "train" / "image.jpg" label_path = _image_to_label_path(image_path) assert label_path.suffix == ".txt" class TestFindPoolImages: """Tests for _find_pool_images function.""" def test_find_pool_images_in_train(self, tmp_path): """Test finding pool images in train split.""" base = tmp_path / "dataset" train_dir = base / "images" / "train" train_dir.mkdir(parents=True) doc_id = str(uuid4()) pool_doc_ids = {doc_id} # Create images (train_dir / f"{doc_id}_page1.png").touch() (train_dir / f"{doc_id}_page2.png").touch() (train_dir / "other_doc_page1.png").touch() result = _find_pool_images(base, pool_doc_ids) assert len(result) == 2 assert all(doc_id in str(img) for img in result) def test_find_pool_images_in_val(self, tmp_path): """Test finding pool images in val split.""" base = tmp_path / "dataset" val_dir = base / "images" / "val" val_dir.mkdir(parents=True) doc_id = str(uuid4()) pool_doc_ids = {doc_id} # Create images (val_dir / f"{doc_id}_page1.png").touch() result = _find_pool_images(base, pool_doc_ids) assert len(result) == 1 assert doc_id in str(result[0]) def test_find_pool_images_across_splits(self, tmp_path): """Test finding pool images across train, val, and test splits.""" base = tmp_path / "dataset" doc_id1 = str(uuid4()) doc_id2 = str(uuid4()) pool_doc_ids = {doc_id1, doc_id2} # Create images in different splits train_dir = base / "images" / "train" val_dir = base / "images" / "val" test_dir = base / "images" / "test" train_dir.mkdir(parents=True) val_dir.mkdir(parents=True) test_dir.mkdir(parents=True) (train_dir / f"{doc_id1}_page1.png").touch() (val_dir / f"{doc_id1}_page2.png").touch() (test_dir / f"{doc_id2}_page1.png").touch() (train_dir / "other_doc_page1.png").touch() result = _find_pool_images(base, pool_doc_ids) assert len(result) == 3 doc1_images = [img for img in result if doc_id1 in str(img)] doc2_images = [img for img in result if doc_id2 in str(img)] assert len(doc1_images) == 2 assert len(doc2_images) == 1 def test_find_pool_images_empty_pool(self, tmp_path): """Test finding images with empty pool.""" base = tmp_path / "dataset" train_dir = base / "images" / "train" train_dir.mkdir(parents=True) (train_dir / "doc123_page1.png").touch() result = _find_pool_images(base, set()) assert len(result) == 0 def test_find_pool_images_no_matches(self, tmp_path): """Test finding images when no documents match pool.""" base = tmp_path / "dataset" train_dir = base / "images" / "train" train_dir.mkdir(parents=True) pool_doc_ids = {str(uuid4())} (train_dir / "other_doc_page1.png").touch() (train_dir / "another_doc_page1.png").touch() result = _find_pool_images(base, pool_doc_ids) assert len(result) == 0 def test_find_pool_images_multiple_pages(self, tmp_path): """Test finding multiple pages for same document.""" base = tmp_path / "dataset" train_dir = base / "images" / "train" train_dir.mkdir(parents=True) doc_id = str(uuid4()) pool_doc_ids = {doc_id} # Create multiple pages for i in range(1, 6): (train_dir / f"{doc_id}_page{i}.png").touch() result = _find_pool_images(base, pool_doc_ids) assert len(result) == 5 def test_find_pool_images_ignores_non_files(self, tmp_path): """Test that directories are ignored.""" base = tmp_path / "dataset" train_dir = base / "images" / "train" train_dir.mkdir(parents=True) doc_id = str(uuid4()) pool_doc_ids = {doc_id} (train_dir / f"{doc_id}_page1.png").touch() (train_dir / "subdir").mkdir() result = _find_pool_images(base, pool_doc_ids) assert len(result) == 1 def test_find_pool_images_nonexistent_splits(self, tmp_path): """Test handling non-existent split directories.""" base = tmp_path / "dataset" # Don't create any directories pool_doc_ids = {str(uuid4())} result = _find_pool_images(base, pool_doc_ids) assert len(result) == 0 class TestBuildMixedDataset: """Tests for build_mixed_dataset function.""" @pytest.fixture def setup_base_dataset(self, tmp_path): """Create a base dataset with old training data.""" base = tmp_path / "base_dataset" # Create directory structure for split in ("train", "val"): (base / "images" / split).mkdir(parents=True) (base / "labels" / split).mkdir(parents=True) # Create old training images and labels for i in range(1, 11): img_path = base / "images" / "train" / f"old_doc_{i}_page1.png" label_path = base / "labels" / "train" / f"old_doc_{i}_page1.txt" img_path.write_text(f"image {i}") label_path.write_text(f"0 0.5 0.5 0.1 0.1") for i in range(1, 6): img_path = base / "images" / "val" / f"old_doc_val_{i}_page1.png" label_path = base / "labels" / "val" / f"old_doc_val_{i}_page1.txt" img_path.write_text(f"val image {i}") label_path.write_text(f"0 0.5 0.5 0.1 0.1") return base @pytest.fixture def setup_pool_documents(self, tmp_path, setup_base_dataset): """Create pool documents in base dataset.""" base = setup_base_dataset pool_ids = [uuid4() for _ in range(5)] # Add pool documents to train split for doc_id in pool_ids: img_path = base / "images" / "train" / f"{doc_id}_page1.png" label_path = base / "labels" / "train" / f"{doc_id}_page1.txt" img_path.write_text(f"pool image {doc_id}") label_path.write_text(f"1 0.5 0.5 0.2 0.2") return base, pool_ids def test_build_mixed_dataset_basic(self, tmp_path, setup_pool_documents): """Test basic mixed dataset building.""" base, pool_ids = setup_pool_documents output_dir = tmp_path / "mixed_output" result = build_mixed_dataset( pool_document_ids=pool_ids, base_dataset_path=base, output_dir=output_dir, seed=42, ) # Verify result structure assert "data_yaml" in result assert "total_images" in result assert "old_images" in result assert "new_images" in result assert "mixing_ratio" in result # Verify counts - new images should be > 0 (at least some were copied) # Note: new images are split 80/20 and copied without overwriting assert result["new_images"] > 0 assert result["old_images"] > 0 assert result["total_images"] == result["old_images"] + result["new_images"] # Verify output structure assert output_dir.exists() assert (output_dir / "images" / "train").exists() assert (output_dir / "images" / "val").exists() assert (output_dir / "labels" / "train").exists() assert (output_dir / "labels" / "val").exists() # Verify data.yaml exists yaml_path = Path(result["data_yaml"]) assert yaml_path.exists() yaml_content = yaml_path.read_text() assert "train: images/train" in yaml_content assert "val: images/val" in yaml_content assert "nc:" in yaml_content assert "names:" in yaml_content def test_build_mixed_dataset_respects_mixing_ratio(self, tmp_path, setup_pool_documents): """Test that mixing ratio is correctly applied.""" base, pool_ids = setup_pool_documents output_dir = tmp_path / "mixed_output" # With 5 pool documents, get_mixing_ratio(5) returns 50 # (because 5 <= 10, first threshold) # So target old_samples = 5 * 50 = 250 # But limited by available data: 10 old train + 5 old val + 5 pool = 20 total result = build_mixed_dataset( pool_document_ids=pool_ids, base_dataset_path=base, output_dir=output_dir, seed=42, ) # Pool images are in the base dataset, so they can be sampled as "old" # Total available: 20 images (15 pure old + 5 pool images) assert result["old_images"] <= 20 # Can't exceed available in base dataset assert result["old_images"] > 0 # Should have some old data assert result["mixing_ratio"] == 50 # Correct ratio for 5 samples def test_build_mixed_dataset_max_old_samples_limit(self, tmp_path): """Test that MAX_OLD_SAMPLES limit is applied.""" base = tmp_path / "base_dataset" # Create directory structure for split in ("train", "val"): (base / "images" / split).mkdir(parents=True) (base / "labels" / split).mkdir(parents=True) # Create MORE than MAX_OLD_SAMPLES old images for i in range(MAX_OLD_SAMPLES + 500): img_path = base / "images" / "train" / f"old_doc_{i}_page1.png" label_path = base / "labels" / "train" / f"old_doc_{i}_page1.txt" img_path.write_text(f"image {i}") label_path.write_text(f"0 0.5 0.5 0.1 0.1") # Create pool documents (100 samples, ratio=10, so target=1000) # But should be capped at MAX_OLD_SAMPLES (3000) pool_ids = [uuid4() for _ in range(100)] for doc_id in pool_ids: img_path = base / "images" / "train" / f"{doc_id}_page1.png" label_path = base / "labels" / "train" / f"{doc_id}_page1.txt" img_path.write_text(f"pool image {doc_id}") label_path.write_text(f"1 0.5 0.5 0.2 0.2") output_dir = tmp_path / "mixed_output" result = build_mixed_dataset( pool_document_ids=pool_ids, base_dataset_path=base, output_dir=output_dir, seed=42, ) # Should be capped at MAX_OLD_SAMPLES assert result["old_images"] <= MAX_OLD_SAMPLES def test_build_mixed_dataset_empty_pool(self, tmp_path, setup_base_dataset): """Test building dataset with empty pool.""" base = setup_base_dataset output_dir = tmp_path / "mixed_output" result = build_mixed_dataset( pool_document_ids=[], base_dataset_path=base, output_dir=output_dir, seed=42, ) # With 0 new samples, all counts should be 0 assert result["new_images"] == 0 assert result["old_images"] == 0 assert result["total_images"] == 0 def test_build_mixed_dataset_no_old_data(self, tmp_path): """Test building dataset with ONLY pool data (no separate old data).""" base = tmp_path / "base_dataset" # Create empty directory structure for split in ("train", "val"): (base / "images" / split).mkdir(parents=True) (base / "labels" / split).mkdir(parents=True) # Create only pool documents # NOTE: These are placed in base dataset train split # So they will be sampled as "old" data first, then skipped as "new" pool_ids = [uuid4() for _ in range(5)] for doc_id in pool_ids: img_path = base / "images" / "train" / f"{doc_id}_page1.png" label_path = base / "labels" / "train" / f"{doc_id}_page1.txt" img_path.write_text(f"pool image {doc_id}") label_path.write_text(f"1 0.5 0.5 0.2 0.2") output_dir = tmp_path / "mixed_output" result = build_mixed_dataset( pool_document_ids=pool_ids, base_dataset_path=base, output_dir=output_dir, seed=42, ) # Pool images are in base dataset, so they get sampled as "old" images # Then when copying "new" images, they're skipped because they already exist # So we expect: old_images > 0, new_images may be 0, total >= 0 assert result["total_images"] > 0 assert result["total_images"] == result["old_images"] + result["new_images"] def test_build_mixed_dataset_train_val_split(self, tmp_path, setup_pool_documents): """Test that images are split into train/val (80/20).""" base, pool_ids = setup_pool_documents output_dir = tmp_path / "mixed_output" result = build_mixed_dataset( pool_document_ids=pool_ids, base_dataset_path=base, output_dir=output_dir, seed=42, ) # Count images in train and val train_images = list((output_dir / "images" / "train").glob("*.png")) val_images = list((output_dir / "images" / "val").glob("*.png")) total_output_images = len(train_images) + len(val_images) # Should match total_images count assert total_output_images == result["total_images"] # Check approximate 80/20 split (allow some variance due to small sample size) if total_output_images > 0: train_ratio = len(train_images) / total_output_images assert 0.6 <= train_ratio <= 0.9 # Allow some variance def test_build_mixed_dataset_reproducible_with_seed(self, tmp_path, setup_pool_documents): """Test that same seed produces same results.""" base, pool_ids = setup_pool_documents output_dir1 = tmp_path / "mixed_output1" output_dir2 = tmp_path / "mixed_output2" result1 = build_mixed_dataset( pool_document_ids=pool_ids, base_dataset_path=base, output_dir=output_dir1, seed=123, ) result2 = build_mixed_dataset( pool_document_ids=pool_ids, base_dataset_path=base, output_dir=output_dir2, seed=123, ) # Same counts assert result1["old_images"] == result2["old_images"] assert result1["new_images"] == result2["new_images"] # Same files in train/val train_files1 = {f.name for f in (output_dir1 / "images" / "train").glob("*.png")} train_files2 = {f.name for f in (output_dir2 / "images" / "train").glob("*.png")} assert train_files1 == train_files2 def test_build_mixed_dataset_different_seeds(self, tmp_path, setup_pool_documents): """Test that different seeds produce different sampling.""" base, pool_ids = setup_pool_documents output_dir1 = tmp_path / "mixed_output1" output_dir2 = tmp_path / "mixed_output2" result1 = build_mixed_dataset( pool_document_ids=pool_ids, base_dataset_path=base, output_dir=output_dir1, seed=123, ) result2 = build_mixed_dataset( pool_document_ids=pool_ids, base_dataset_path=base, output_dir=output_dir2, seed=456, ) # Both should have processed images assert result1["total_images"] > 0 assert result2["total_images"] > 0 # Both should have the same mixing ratio (based on pool size) assert result1["mixing_ratio"] == result2["mixing_ratio"] # File distribution in train/val may differ due to different shuffling train_files1 = {f.name for f in (output_dir1 / "images" / "train").glob("*.png")} train_files2 = {f.name for f in (output_dir2 / "images" / "train").glob("*.png")} # With different seeds, we expect some difference in file distribution # But this is not strictly guaranteed, so we just verify both have files assert len(train_files1) > 0 assert len(train_files2) > 0 def test_build_mixed_dataset_copies_labels(self, tmp_path, setup_pool_documents): """Test that corresponding label files are copied.""" base, pool_ids = setup_pool_documents output_dir = tmp_path / "mixed_output" result = build_mixed_dataset( pool_document_ids=pool_ids, base_dataset_path=base, output_dir=output_dir, seed=42, ) # Count labels train_labels = list((output_dir / "labels" / "train").glob("*.txt")) val_labels = list((output_dir / "labels" / "val").glob("*.txt")) # Each image should have a corresponding label train_images = list((output_dir / "images" / "train").glob("*.png")) val_images = list((output_dir / "images" / "val").glob("*.png")) # Allow label count to be <= image count (in case some labels are missing) assert len(train_labels) <= len(train_images) assert len(val_labels) <= len(val_images) def test_build_mixed_dataset_skips_duplicate_files(self, tmp_path, setup_pool_documents): """Test behavior when running build_mixed_dataset multiple times.""" base, pool_ids = setup_pool_documents output_dir = tmp_path / "mixed_output" # First build result1 = build_mixed_dataset( pool_document_ids=pool_ids, base_dataset_path=base, output_dir=output_dir, seed=42, ) initial_count = result1["total_images"] # Find a file in output and modify it train_images = list((output_dir / "images" / "train").glob("*.png")) if len(train_images) > 0: test_file = train_images[0] test_file.write_text("modified content") # Second build with same seed result2 = build_mixed_dataset( pool_document_ids=pool_ids, base_dataset_path=base, output_dir=output_dir, seed=42, ) # The implementation uses shutil.copy2 which WILL overwrite # So the file will be restored to original content # Just verify the build completed successfully assert result2["total_images"] >= 0 # Verify the file was overwritten (shutil.copy2 overwrites by default) content = test_file.read_text() assert content != "modified content" # Should be restored def test_build_mixed_dataset_handles_jpg_images(self, tmp_path): """Test that JPG images are handled correctly.""" base = tmp_path / "base_dataset" # Create directory structure for split in ("train", "val"): (base / "images" / split).mkdir(parents=True) (base / "labels" / split).mkdir(parents=True) # Create JPG images as old data for i in range(1, 6): img_path = base / "images" / "train" / f"old_doc_{i}_page1.jpg" label_path = base / "labels" / "train" / f"old_doc_{i}_page1.txt" img_path.write_text(f"jpg image {i}") label_path.write_text(f"0 0.5 0.5 0.1 0.1") # Create pool with JPG - use multiple pages to ensure at least one gets copied pool_ids = [uuid4()] doc_id = pool_ids[0] for page_num in range(1, 4): img_path = base / "images" / "train" / f"{doc_id}_page{page_num}.jpg" label_path = base / "labels" / "train" / f"{doc_id}_page{page_num}.txt" img_path.write_text(f"pool jpg {doc_id} page {page_num}") label_path.write_text(f"1 0.5 0.5 0.2 0.2") output_dir = tmp_path / "mixed_output" result = build_mixed_dataset( pool_document_ids=pool_ids, base_dataset_path=base, output_dir=output_dir, seed=42, ) # Should have some new JPG images (at least 1 from the pool) assert result["new_images"] > 0 assert result["old_images"] > 0 # Verify JPG files exist in output all_images = list((output_dir / "images" / "train").glob("*.jpg")) + \ list((output_dir / "images" / "val").glob("*.jpg")) assert len(all_images) > 0 class TestConstants: """Tests for module constants.""" def test_mixing_ratios_structure(self): """Test MIXING_RATIOS constant structure.""" assert isinstance(MIXING_RATIOS, list) assert len(MIXING_RATIOS) == 4 # Verify format: (threshold, multiplier) for item in MIXING_RATIOS: assert isinstance(item, tuple) assert len(item) == 2 assert isinstance(item[0], int) assert isinstance(item[1], int) # Verify thresholds are ascending thresholds = [t for t, _ in MIXING_RATIOS] assert thresholds == sorted(thresholds) # Verify multipliers are descending multipliers = [m for _, m in MIXING_RATIOS] assert multipliers == sorted(multipliers, reverse=True) def test_default_multiplier(self): """Test DEFAULT_MULTIPLIER constant.""" assert DEFAULT_MULTIPLIER == 5 assert DEFAULT_MULTIPLIER == MIXING_RATIOS[-1][1] def test_max_old_samples(self): """Test MAX_OLD_SAMPLES constant.""" assert MAX_OLD_SAMPLES == 3000 assert MAX_OLD_SAMPLES > 0 def test_min_pool_size(self): """Test MIN_POOL_SIZE constant.""" assert MIN_POOL_SIZE == 50 assert MIN_POOL_SIZE > 0