This commit is contained in:
Yaojia Wang
2026-02-11 23:40:38 +01:00
parent f1a7bfe6b7
commit ad5ed46b4c
117 changed files with 5741 additions and 7669 deletions

View File

@@ -0,0 +1,344 @@
"""
Tests for Data Mixing Service.
Tests cover:
1. get_mixing_ratio boundary values
2. build_mixed_dataset with temp filesystem
3. _find_pool_images matching logic
4. _image_to_label_path conversion
5. Edge cases (empty pool, no old data, cap)
"""
import pytest
from pathlib import Path
from uuid import uuid4
from backend.web.services.data_mixer import (
get_mixing_ratio,
build_mixed_dataset,
_collect_images,
_image_to_label_path,
_find_pool_images,
MIXING_RATIOS,
DEFAULT_MULTIPLIER,
MAX_OLD_SAMPLES,
MIN_POOL_SIZE,
)
# =============================================================================
# Test Constants
# =============================================================================
class TestConstants:
    """Sanity checks pinning the module-level configuration constants."""

    def test_mixing_ratios_defined(self):
        """MIXING_RATIOS should contain exactly the four expected tiers."""
        expected_tiers = [(10, 50), (50, 20), (200, 10), (500, 5)]
        assert len(MIXING_RATIOS) == len(expected_tiers)
        for actual, expected in zip(MIXING_RATIOS, expected_tiers):
            assert actual == expected

    def test_default_multiplier(self):
        """The fallback multiplier for large datasets should be 5."""
        assert DEFAULT_MULTIPLIER == 5

    def test_max_old_samples(self):
        """The cap on re-used old samples should be 3000."""
        assert MAX_OLD_SAMPLES == 3000

    def test_min_pool_size(self):
        """The minimum pool size threshold should be 50."""
        assert MIN_POOL_SIZE == 50
# =============================================================================
# Test get_mixing_ratio
# =============================================================================
class TestGetMixingRatio:
    """Boundary-value tests for get_mixing_ratio.

    The tiers under test: <=10 samples -> 50x, <=50 -> 20x,
    <=200 -> 10x, <=500 -> 5x, above that the default multiplier.
    """

    def test_1_sample_returns_50x(self):
        """A single new sample falls into the smallest tier (50x)."""
        multiplier = get_mixing_ratio(1)
        assert multiplier == 50

    def test_10_samples_returns_50x(self):
        """Upper boundary of the first tier still yields 50x."""
        multiplier = get_mixing_ratio(10)
        assert multiplier == 50

    def test_11_samples_returns_20x(self):
        """One past the first tier boundary drops to 20x."""
        multiplier = get_mixing_ratio(11)
        assert multiplier == 20

    def test_50_samples_returns_20x(self):
        """Upper boundary of the second tier still yields 20x."""
        multiplier = get_mixing_ratio(50)
        assert multiplier == 20

    def test_51_samples_returns_10x(self):
        """One past the second tier boundary drops to 10x."""
        multiplier = get_mixing_ratio(51)
        assert multiplier == 10

    def test_200_samples_returns_10x(self):
        """Upper boundary of the third tier still yields 10x."""
        multiplier = get_mixing_ratio(200)
        assert multiplier == 10

    def test_201_samples_returns_5x(self):
        """One past the third tier boundary drops to 5x."""
        multiplier = get_mixing_ratio(201)
        assert multiplier == 5

    def test_500_samples_returns_5x(self):
        """Upper boundary of the fourth tier still yields 5x."""
        multiplier = get_mixing_ratio(500)
        assert multiplier == 5

    def test_1000_samples_returns_default(self):
        """Anything beyond the last tier falls back to DEFAULT_MULTIPLIER."""
        multiplier = get_mixing_ratio(1000)
        assert multiplier == DEFAULT_MULTIPLIER
# =============================================================================
# Test _collect_images
# =============================================================================
class TestCollectImages:
    """Tests for the _collect_images directory scanner."""

    def test_collects_png_files(self, tmp_path: Path):
        """Every .png file in the directory should be returned."""
        for name in ("img1.png", "img2.png"):
            (tmp_path / name).write_bytes(b"fake png")
        assert len(_collect_images(tmp_path)) == 2

    def test_collects_jpg_files(self, tmp_path: Path):
        """A lone .jpg file should be returned."""
        (tmp_path / "img1.jpg").write_bytes(b"fake jpg")
        assert len(_collect_images(tmp_path)) == 1

    def test_collects_both_types(self, tmp_path: Path):
        """.png and .jpg files should both be picked up."""
        (tmp_path / "img1.png").write_bytes(b"fake png")
        (tmp_path / "img2.jpg").write_bytes(b"fake jpg")
        assert len(_collect_images(tmp_path)) == 2

    def test_ignores_other_files(self, tmp_path: Path):
        """Files with non-image extensions must not be collected."""
        (tmp_path / "data.txt").write_text("not an image")
        (tmp_path / "data.yaml").write_text("yaml")
        (tmp_path / "img.png").write_bytes(b"png")
        collected = _collect_images(tmp_path)
        assert len(collected) == 1

    def test_returns_empty_for_nonexistent_dir(self, tmp_path: Path):
        """A missing directory should yield an empty list, not an error."""
        missing_dir = tmp_path / "nonexistent"
        assert _collect_images(missing_dir) == []
# =============================================================================
# Test _image_to_label_path
# =============================================================================
class TestImageToLabelPath:
    """Tests for the _image_to_label_path conversion helper.

    The helper is expected to map ``images/<split>/<stem>.<ext>`` to
    ``labels/<split>/<stem>.txt`` (standard YOLO dataset layout).
    """

    def test_converts_train_image_to_label(self, tmp_path: Path):
        """Should convert images/train/img.png to labels/train/img.txt."""
        image_path = tmp_path / "dataset" / "images" / "train" / "doc1_page1.png"
        label_path = _image_to_label_path(image_path)
        assert label_path.name == "doc1_page1.txt"
        # Check path *components*, not substrings: a tmp_path whose string
        # happens to contain "labels" or "train" would make a substring
        # check pass even if the conversion were broken.
        assert "labels" in label_path.parts
        assert label_path.parent.name == "train"

    def test_converts_val_image_to_label(self, tmp_path: Path):
        """Should convert images/val/img.jpg to labels/val/img.txt."""
        image_path = tmp_path / "dataset" / "images" / "val" / "doc2_page3.jpg"
        label_path = _image_to_label_path(image_path)
        assert label_path.name == "doc2_page3.txt"
        assert "labels" in label_path.parts
        assert label_path.parent.name == "val"
# =============================================================================
# Test _find_pool_images
# =============================================================================
class TestFindPoolImages:
"""Tests for _find_pool_images function."""
def _create_dataset(self, base_path: Path, doc_ids: list[str], split: str = "train") -> None:
"""Helper to create a dataset structure with images."""
images_dir = base_path / "images" / split
images_dir.mkdir(parents=True, exist_ok=True)
for doc_id in doc_ids:
(images_dir / f"{doc_id}_page1.png").write_bytes(b"img")
(images_dir / f"{doc_id}_page2.png").write_bytes(b"img")
def test_finds_matching_images(self, tmp_path: Path):
"""Should find images matching pool document IDs."""
doc_id1 = str(uuid4())
doc_id2 = str(uuid4())
self._create_dataset(tmp_path, [doc_id1, doc_id2])
pool_ids = {doc_id1}
images = _find_pool_images(tmp_path, pool_ids)
assert len(images) == 2 # 2 pages for doc_id1
assert all(doc_id1 in str(img) for img in images)
def test_ignores_non_pool_images(self, tmp_path: Path):
"""Should not return images for documents not in pool."""
doc_id1 = str(uuid4())
doc_id2 = str(uuid4())
self._create_dataset(tmp_path, [doc_id1, doc_id2])
pool_ids = {doc_id1}
images = _find_pool_images(tmp_path, pool_ids)
# Only doc_id1 images should be found
for img in images:
assert doc_id1 in str(img)
assert doc_id2 not in str(img)
def test_searches_all_splits(self, tmp_path: Path):
"""Should search train, val, and test splits."""
doc_id = str(uuid4())
for split in ("train", "val", "test"):
self._create_dataset(tmp_path, [doc_id], split=split)
images = _find_pool_images(tmp_path, {doc_id})
assert len(images) == 6 # 2 pages * 3 splits
def test_empty_pool_returns_empty(self, tmp_path: Path):
"""Should return empty list for empty pool IDs."""
self._create_dataset(tmp_path, [str(uuid4())])
images = _find_pool_images(tmp_path, set())
assert images == []
# =============================================================================
# Test build_mixed_dataset
# =============================================================================
class TestBuildMixedDataset:
    """Tests for the build_mixed_dataset entry point."""

    def _setup_base_dataset(self, base_path: Path, num_old: int = 20) -> None:
        """Create a base dataset with old training images.

        Roughly 80% of ``num_old`` image/label pairs go into train and
        the remainder into val, mirroring a typical split.
        """
        for split in ("train", "val"):
            img_dir = base_path / "images" / split
            lbl_dir = base_path / "labels" / split
            img_dir.mkdir(parents=True, exist_ok=True)
            lbl_dir.mkdir(parents=True, exist_ok=True)
            count = int(num_old * 0.8) if split == "train" else num_old - int(num_old * 0.8)
            for i in range(count):
                doc_id = str(uuid4())
                img_file = img_dir / f"{doc_id}_page1.png"
                lbl_file = lbl_dir / f"{doc_id}_page1.txt"
                img_file.write_bytes(b"fake image data")
                lbl_file.write_text("0 0.5 0.5 0.1 0.1\n")

    def _setup_pool_images(self, base_path: Path, doc_ids: list[str]) -> None:
        """Add one train image/label pair per pool document ID."""
        img_dir = base_path / "images" / "train"
        lbl_dir = base_path / "labels" / "train"
        img_dir.mkdir(parents=True, exist_ok=True)
        lbl_dir.mkdir(parents=True, exist_ok=True)
        for doc_id in doc_ids:
            img_file = img_dir / f"{doc_id}_page1.png"
            lbl_file = lbl_dir / f"{doc_id}_page1.txt"
            img_file.write_bytes(b"pool image data")
            lbl_file.write_text("0 0.5 0.5 0.2 0.2\n")

    @pytest.fixture
    def base_dataset(self, tmp_path: Path) -> Path:
        """Create a base dataset with 20 old samples for testing."""
        base_path = tmp_path / "base_dataset"
        self._setup_base_dataset(base_path, num_old=20)
        return base_path

    def test_builds_output_structure(self, base_dataset: Path, tmp_path: Path):
        """Should create proper YOLO directory structure."""
        pool_ids = [uuid4() for _ in range(5)]
        self._setup_pool_images(base_dataset, [str(pid) for pid in pool_ids])
        output_dir = tmp_path / "mixed_output"
        # Return value intentionally unused here; this test only checks
        # the on-disk layout.
        build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base_dataset,
            output_dir=output_dir,
        )
        assert (output_dir / "images" / "train").exists()
        assert (output_dir / "images" / "val").exists()
        assert (output_dir / "labels" / "train").exists()
        assert (output_dir / "labels" / "val").exists()
        assert (output_dir / "data.yaml").exists()

    def test_returns_correct_metadata(self, base_dataset: Path, tmp_path: Path):
        """Should return correct counts and metadata."""
        pool_ids = [uuid4() for _ in range(5)]
        self._setup_pool_images(base_dataset, [str(pid) for pid in pool_ids])
        output_dir = tmp_path / "mixed_output"
        result = build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base_dataset,
            output_dir=output_dir,
        )
        for key in ("data_yaml", "total_images", "old_images", "new_images", "mixing_ratio"):
            assert key in result
        # Totals must be internally consistent.
        assert result["total_images"] == result["old_images"] + result["new_images"]

    def test_mixing_ratio_applied(self, base_dataset: Path, tmp_path: Path):
        """Should use correct mixing ratio based on pool size."""
        pool_ids = [uuid4() for _ in range(5)]
        self._setup_pool_images(base_dataset, [str(pid) for pid in pool_ids])
        output_dir = tmp_path / "mixed_output"
        result = build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base_dataset,
            output_dir=output_dir,
        )
        # 5 new samples -> 50x multiplier
        assert result["mixing_ratio"] == 50

    def test_seed_reproducibility(self, base_dataset: Path, tmp_path: Path):
        """Same seed should produce same output."""
        pool_ids = [uuid4() for _ in range(3)]
        self._setup_pool_images(base_dataset, [str(pid) for pid in pool_ids])
        out1 = tmp_path / "out1"
        out2 = tmp_path / "out2"
        r1 = build_mixed_dataset(pool_ids, base_dataset, out1, seed=42)
        r2 = build_mixed_dataset(pool_ids, base_dataset, out2, seed=42)
        assert r1["old_images"] == r2["old_images"]
        assert r1["new_images"] == r2["new_images"]
        assert r1["total_images"] == r2["total_images"]