""" Tests for Data Mixing Service. Tests cover: 1. get_mixing_ratio boundary values 2. build_mixed_dataset with temp filesystem 3. _find_pool_images matching logic 4. _image_to_label_path conversion 5. Edge cases (empty pool, no old data, cap) """ import pytest from pathlib import Path from uuid import uuid4 from backend.web.services.data_mixer import ( get_mixing_ratio, build_mixed_dataset, _collect_images, _image_to_label_path, _find_pool_images, MIXING_RATIOS, DEFAULT_MULTIPLIER, MAX_OLD_SAMPLES, MIN_POOL_SIZE, ) # ============================================================================= # Test Constants # ============================================================================= class TestConstants: """Tests for data mixer constants.""" def test_mixing_ratios_defined(self): """MIXING_RATIOS should have expected entries.""" assert len(MIXING_RATIOS) == 4 assert MIXING_RATIOS[0] == (10, 50) assert MIXING_RATIOS[1] == (50, 20) assert MIXING_RATIOS[2] == (200, 10) assert MIXING_RATIOS[3] == (500, 5) def test_default_multiplier(self): """DEFAULT_MULTIPLIER should be 5.""" assert DEFAULT_MULTIPLIER == 5 def test_max_old_samples(self): """MAX_OLD_SAMPLES should be 3000.""" assert MAX_OLD_SAMPLES == 3000 def test_min_pool_size(self): """MIN_POOL_SIZE should be 50.""" assert MIN_POOL_SIZE == 50 # ============================================================================= # Test get_mixing_ratio # ============================================================================= class TestGetMixingRatio: """Tests for get_mixing_ratio function.""" def test_1_sample_returns_50x(self): """1 new sample should get 50x old data.""" assert get_mixing_ratio(1) == 50 def test_10_samples_returns_50x(self): """10 new samples (boundary) should get 50x.""" assert get_mixing_ratio(10) == 50 def test_11_samples_returns_20x(self): """11 new samples should get 20x.""" assert get_mixing_ratio(11) == 20 def test_50_samples_returns_20x(self): """50 new samples (boundary) should get 20x.""" assert get_mixing_ratio(50) == 20 def test_51_samples_returns_10x(self): """51 new samples should get 10x.""" assert get_mixing_ratio(51) == 10 def test_200_samples_returns_10x(self): """200 new samples (boundary) should get 10x.""" assert get_mixing_ratio(200) == 10 def test_201_samples_returns_5x(self): """201 new samples should get 5x.""" assert get_mixing_ratio(201) == 5 def test_500_samples_returns_5x(self): """500 new samples (boundary) should get 5x.""" assert get_mixing_ratio(500) == 5 def test_1000_samples_returns_default(self): """1000+ samples should get default multiplier (5x).""" assert get_mixing_ratio(1000) == DEFAULT_MULTIPLIER # ============================================================================= # Test _collect_images # ============================================================================= class TestCollectImages: """Tests for _collect_images function.""" def test_collects_png_files(self, tmp_path: Path): """Should collect .png files.""" (tmp_path / "img1.png").write_bytes(b"fake png") (tmp_path / "img2.png").write_bytes(b"fake png") images = _collect_images(tmp_path) assert len(images) == 2 def test_collects_jpg_files(self, tmp_path: Path): """Should collect .jpg files.""" (tmp_path / "img1.jpg").write_bytes(b"fake jpg") images = _collect_images(tmp_path) assert len(images) == 1 def test_collects_both_types(self, tmp_path: Path): """Should collect both .png and .jpg files.""" (tmp_path / "img1.png").write_bytes(b"fake png") (tmp_path / "img2.jpg").write_bytes(b"fake jpg") images = _collect_images(tmp_path) assert len(images) == 2 def test_ignores_other_files(self, tmp_path: Path): """Should ignore non-image files.""" (tmp_path / "data.txt").write_text("not an image") (tmp_path / "data.yaml").write_text("yaml") (tmp_path / "img.png").write_bytes(b"png") images = _collect_images(tmp_path) assert len(images) == 1 def test_returns_empty_for_nonexistent_dir(self, tmp_path: Path): """Should return empty list for nonexistent directory.""" images = _collect_images(tmp_path / "nonexistent") assert images == [] # ============================================================================= # Test _image_to_label_path # ============================================================================= class TestImageToLabelPath: """Tests for _image_to_label_path function.""" def test_converts_train_image_to_label(self, tmp_path: Path): """Should convert images/train/img.png to labels/train/img.txt.""" image_path = tmp_path / "dataset" / "images" / "train" / "doc1_page1.png" label_path = _image_to_label_path(image_path) assert label_path.name == "doc1_page1.txt" assert "labels" in str(label_path) assert "train" in str(label_path) def test_converts_val_image_to_label(self, tmp_path: Path): """Should convert images/val/img.jpg to labels/val/img.txt.""" image_path = tmp_path / "dataset" / "images" / "val" / "doc2_page3.jpg" label_path = _image_to_label_path(image_path) assert label_path.name == "doc2_page3.txt" assert "labels" in str(label_path) assert "val" in str(label_path) # ============================================================================= # Test _find_pool_images # ============================================================================= class TestFindPoolImages: """Tests for _find_pool_images function.""" def _create_dataset(self, base_path: Path, doc_ids: list[str], split: str = "train") -> None: """Helper to create a dataset structure with images.""" images_dir = base_path / "images" / split images_dir.mkdir(parents=True, exist_ok=True) for doc_id in doc_ids: (images_dir / f"{doc_id}_page1.png").write_bytes(b"img") (images_dir / f"{doc_id}_page2.png").write_bytes(b"img") def test_finds_matching_images(self, tmp_path: Path): """Should find images matching pool document IDs.""" doc_id1 = str(uuid4()) doc_id2 = str(uuid4()) self._create_dataset(tmp_path, [doc_id1, doc_id2]) pool_ids = {doc_id1} images = _find_pool_images(tmp_path, pool_ids) assert len(images) == 2 # 2 pages for doc_id1 assert all(doc_id1 in str(img) for img in images) def test_ignores_non_pool_images(self, tmp_path: Path): """Should not return images for documents not in pool.""" doc_id1 = str(uuid4()) doc_id2 = str(uuid4()) self._create_dataset(tmp_path, [doc_id1, doc_id2]) pool_ids = {doc_id1} images = _find_pool_images(tmp_path, pool_ids) # Only doc_id1 images should be found for img in images: assert doc_id1 in str(img) assert doc_id2 not in str(img) def test_searches_all_splits(self, tmp_path: Path): """Should search train, val, and test splits.""" doc_id = str(uuid4()) for split in ("train", "val", "test"): self._create_dataset(tmp_path, [doc_id], split=split) images = _find_pool_images(tmp_path, {doc_id}) assert len(images) == 6 # 2 pages * 3 splits def test_empty_pool_returns_empty(self, tmp_path: Path): """Should return empty list for empty pool IDs.""" self._create_dataset(tmp_path, [str(uuid4())]) images = _find_pool_images(tmp_path, set()) assert images == [] # ============================================================================= # Test build_mixed_dataset # ============================================================================= class TestBuildMixedDataset: """Tests for build_mixed_dataset function.""" def _setup_base_dataset(self, base_path: Path, num_old: int = 20) -> None: """Create a base dataset with old training images.""" for split in ("train", "val"): img_dir = base_path / "images" / split lbl_dir = base_path / "labels" / split img_dir.mkdir(parents=True, exist_ok=True) lbl_dir.mkdir(parents=True, exist_ok=True) count = int(num_old * 0.8) if split == "train" else num_old - int(num_old * 0.8) for i in range(count): doc_id = str(uuid4()) img_file = img_dir / f"{doc_id}_page1.png" lbl_file = lbl_dir / f"{doc_id}_page1.txt" img_file.write_bytes(b"fake image data") lbl_file.write_text("0 0.5 0.5 0.1 0.1\n") def _setup_pool_images(self, base_path: Path, doc_ids: list[str]) -> None: """Add pool images to the base dataset.""" img_dir = base_path / "images" / "train" lbl_dir = base_path / "labels" / "train" img_dir.mkdir(parents=True, exist_ok=True) lbl_dir.mkdir(parents=True, exist_ok=True) for doc_id in doc_ids: img_file = img_dir / f"{doc_id}_page1.png" lbl_file = lbl_dir / f"{doc_id}_page1.txt" img_file.write_bytes(b"pool image data") lbl_file.write_text("0 0.5 0.5 0.2 0.2\n") @pytest.fixture def base_dataset(self, tmp_path: Path) -> Path: """Create a base dataset for testing.""" base_path = tmp_path / "base_dataset" self._setup_base_dataset(base_path, num_old=20) return base_path def test_builds_output_structure(self, base_dataset: Path, tmp_path: Path): """Should create proper YOLO directory structure.""" pool_ids = [uuid4() for _ in range(5)] self._setup_pool_images(base_dataset, [str(pid) for pid in pool_ids]) output_dir = tmp_path / "mixed_output" result = build_mixed_dataset( pool_document_ids=pool_ids, base_dataset_path=base_dataset, output_dir=output_dir, ) assert (output_dir / "images" / "train").exists() assert (output_dir / "images" / "val").exists() assert (output_dir / "labels" / "train").exists() assert (output_dir / "labels" / "val").exists() assert (output_dir / "data.yaml").exists() def test_returns_correct_metadata(self, base_dataset: Path, tmp_path: Path): """Should return correct counts and metadata.""" pool_ids = [uuid4() for _ in range(5)] self._setup_pool_images(base_dataset, [str(pid) for pid in pool_ids]) output_dir = tmp_path / "mixed_output" result = build_mixed_dataset( pool_document_ids=pool_ids, base_dataset_path=base_dataset, output_dir=output_dir, ) assert "data_yaml" in result assert "total_images" in result assert "old_images" in result assert "new_images" in result assert "mixing_ratio" in result assert result["total_images"] == result["old_images"] + result["new_images"] def test_mixing_ratio_applied(self, base_dataset: Path, tmp_path: Path): """Should use correct mixing ratio based on pool size.""" pool_ids = [uuid4() for _ in range(5)] self._setup_pool_images(base_dataset, [str(pid) for pid in pool_ids]) output_dir = tmp_path / "mixed_output" result = build_mixed_dataset( pool_document_ids=pool_ids, base_dataset_path=base_dataset, output_dir=output_dir, ) # 5 new samples -> 50x multiplier assert result["mixing_ratio"] == 50 def test_seed_reproducibility(self, base_dataset: Path, tmp_path: Path): """Same seed should produce same output.""" pool_ids = [uuid4() for _ in range(3)] self._setup_pool_images(base_dataset, [str(pid) for pid in pool_ids]) out1 = tmp_path / "out1" out2 = tmp_path / "out2" r1 = build_mixed_dataset(pool_ids, base_dataset, out1, seed=42) r2 = build_mixed_dataset(pool_ids, base_dataset, out2, seed=42) assert r1["old_images"] == r2["old_images"] assert r1["new_images"] == r2["new_images"] assert r1["total_images"] == r2["total_images"]