WIP
This commit is contained in:
@@ -53,7 +53,7 @@ class TestTrainingConfigSchema:
|
||||
"""Test default training configuration."""
|
||||
config = TrainingConfig()
|
||||
|
||||
assert config.model_name == "yolo11n.pt"
|
||||
assert config.model_name == "yolo26s.pt"
|
||||
assert config.epochs == 100
|
||||
assert config.batch_size == 16
|
||||
assert config.image_size == 640
|
||||
@@ -63,7 +63,7 @@ class TestTrainingConfigSchema:
|
||||
def test_custom_config(self):
|
||||
"""Test custom training configuration."""
|
||||
config = TrainingConfig(
|
||||
model_name="yolo11s.pt",
|
||||
model_name="yolo26s.pt",
|
||||
epochs=50,
|
||||
batch_size=8,
|
||||
image_size=416,
|
||||
@@ -71,7 +71,7 @@ class TestTrainingConfigSchema:
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
assert config.model_name == "yolo11s.pt"
|
||||
assert config.model_name == "yolo26s.pt"
|
||||
assert config.epochs == 50
|
||||
assert config.batch_size == 8
|
||||
|
||||
@@ -136,7 +136,7 @@ class TestTrainingTaskModel:
|
||||
def test_task_with_config(self):
|
||||
"""Test task with configuration."""
|
||||
config = {
|
||||
"model_name": "yolo11n.pt",
|
||||
"model_name": "yolo26s.pt",
|
||||
"epochs": 100,
|
||||
}
|
||||
task = TrainingTask(
|
||||
|
||||
784
tests/web/test_data_mixer.py
Normal file
784
tests/web/test_data_mixer.py
Normal file
@@ -0,0 +1,784 @@
|
||||
"""
|
||||
Comprehensive unit tests for Data Mixing Service.
|
||||
|
||||
Tests the data mixing service functions for YOLO fine-tuning:
|
||||
- Mixing ratio calculation based on sample counts
|
||||
- Dataset building with old/new sample mixing
|
||||
- Image collection and path conversion
|
||||
- Pool document matching
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.web.services.data_mixer import (
|
||||
get_mixing_ratio,
|
||||
build_mixed_dataset,
|
||||
_collect_images,
|
||||
_image_to_label_path,
|
||||
_find_pool_images,
|
||||
MIXING_RATIOS,
|
||||
DEFAULT_MULTIPLIER,
|
||||
MAX_OLD_SAMPLES,
|
||||
MIN_POOL_SIZE,
|
||||
)
|
||||
|
||||
|
||||
class TestGetMixingRatio:
    """Tests for get_mixing_ratio function.

    get_mixing_ratio maps a new-sample count onto an old-data multiplier
    using the ascending (threshold, multiplier) tiers in MIXING_RATIOS and
    falls back to DEFAULT_MULTIPLIER beyond the last threshold.
    """

    def test_mixing_ratio_at_first_threshold(self):
        """Counts up to the first threshold (10 samples) use multiplier 50."""
        assert get_mixing_ratio(1) == 50
        assert get_mixing_ratio(5) == 50
        assert get_mixing_ratio(10) == 50

    def test_mixing_ratio_at_second_threshold(self):
        """Counts of 11-50 samples use multiplier 20."""
        assert get_mixing_ratio(11) == 20
        assert get_mixing_ratio(30) == 20
        assert get_mixing_ratio(50) == 20

    def test_mixing_ratio_at_third_threshold(self):
        """Counts of 51-200 samples use multiplier 10."""
        assert get_mixing_ratio(51) == 10
        assert get_mixing_ratio(100) == 10
        assert get_mixing_ratio(200) == 10

    def test_mixing_ratio_at_fourth_threshold(self):
        """Counts of 201-500 samples use multiplier 5."""
        assert get_mixing_ratio(201) == 5
        assert get_mixing_ratio(350) == 5
        assert get_mixing_ratio(500) == 5

    def test_mixing_ratio_above_all_thresholds(self):
        """Counts above the last threshold fall back to DEFAULT_MULTIPLIER."""
        assert get_mixing_ratio(501) == DEFAULT_MULTIPLIER
        assert get_mixing_ratio(1000) == DEFAULT_MULTIPLIER
        assert get_mixing_ratio(10000) == DEFAULT_MULTIPLIER

    def test_mixing_ratio_boundary_values(self):
        """Each exact threshold maps to its tier; threshold + 1 maps to the next.

        Uses enumerate instead of list.index() so the position lookup is O(1)
        per iteration and stays correct even if a (threshold, multiplier)
        pair were ever duplicated in MIXING_RATIOS.
        """
        for idx, (threshold, expected_multiplier) in enumerate(MIXING_RATIOS):
            assert get_mixing_ratio(threshold) == expected_multiplier
            # One past a non-final threshold should land in the next tier.
            if idx + 1 < len(MIXING_RATIOS):
                next_multiplier = MIXING_RATIOS[idx + 1][1]
                assert get_mixing_ratio(threshold + 1) == next_multiplier
|
||||
|
||||
|
||||
class TestCollectImages:
    """Tests for the _collect_images helper."""

    def test_collect_images_empty_directory(self, tmp_path):
        """An existing but empty directory yields an empty list."""
        empty_dir = tmp_path / "images"
        empty_dir.mkdir()

        assert _collect_images(empty_dir) == []

    def test_collect_images_nonexistent_directory(self, tmp_path):
        """A missing directory yields an empty list rather than raising."""
        missing_dir = tmp_path / "nonexistent"

        assert _collect_images(missing_dir) == []

    def test_collect_png_images(self, tmp_path):
        """All PNG files are collected, in sorted order."""
        source_dir = tmp_path / "images"
        source_dir.mkdir()
        for name in ("img1.png", "img2.png", "img3.png"):
            (source_dir / name).touch()

        collected = _collect_images(source_dir)

        assert len(collected) == 3
        assert all(path.suffix == ".png" for path in collected)
        # Output order is deterministic (sorted).
        assert collected == sorted(collected)

    def test_collect_jpg_images(self, tmp_path):
        """All JPG files are collected."""
        source_dir = tmp_path / "images"
        source_dir.mkdir()
        for name in ("img1.jpg", "img2.jpg"):
            (source_dir / name).touch()

        collected = _collect_images(source_dir)

        assert len(collected) == 2
        assert all(path.suffix == ".jpg" for path in collected)

    def test_collect_mixed_image_types(self, tmp_path):
        """PNG and JPG files are both collected."""
        source_dir = tmp_path / "images"
        source_dir.mkdir()
        for name in ("img1.png", "img2.jpg", "img3.png", "img4.jpg"):
            (source_dir / name).touch()

        collected = _collect_images(source_dir)

        assert len(collected) == 4
        # The helper gathers each extension separately (PNGs sorted first).
        png_count = sum(1 for path in collected if path.suffix == ".png")
        jpg_count = sum(1 for path in collected if path.suffix == ".jpg")
        assert png_count == 2
        assert jpg_count == 2

    def test_collect_images_ignores_other_files(self, tmp_path):
        """Files that are neither PNG nor JPG are skipped."""
        source_dir = tmp_path / "images"
        source_dir.mkdir()
        for name in ("img1.png", "img2.jpg", "doc.txt", "data.json", "notes.md"):
            (source_dir / name).touch()

        collected = _collect_images(source_dir)

        assert len(collected) == 2
        assert all(path.suffix in (".png", ".jpg") for path in collected)
|
||||
|
||||
|
||||
class TestImageToLabelPath:
    """Tests for the _image_to_label_path helper."""

    def test_image_to_label_path_train(self, tmp_path):
        """A train-split image maps to the matching train-split label file."""
        root = tmp_path / "dataset"
        image = root / "images" / "train" / "doc123_page1.png"

        assert _image_to_label_path(image) == root / "labels" / "train" / "doc123_page1.txt"

    def test_image_to_label_path_val(self, tmp_path):
        """A val-split image maps to the matching val-split label file."""
        root = tmp_path / "dataset"
        image = root / "images" / "val" / "doc456_page2.jpg"

        assert _image_to_label_path(image) == root / "labels" / "val" / "doc456_page2.txt"

    def test_image_to_label_path_test(self, tmp_path):
        """A test-split image maps to the matching test-split label file."""
        root = tmp_path / "dataset"
        image = root / "images" / "test" / "doc789_page3.png"

        assert _image_to_label_path(image) == root / "labels" / "test" / "doc789_page3.txt"

    def test_image_to_label_path_preserves_filename(self, tmp_path):
        """The image's stem is kept intact; only the suffix changes."""
        image = tmp_path / "dataset" / "images" / "train" / "complex_filename_123_page5.png"

        label = _image_to_label_path(image)

        assert label.stem == "complex_filename_123_page5"
        assert label.suffix == ".txt"

    def test_image_to_label_path_jpg_to_txt(self, tmp_path):
        """A .jpg suffix is also replaced with .txt."""
        image = tmp_path / "dataset" / "images" / "train" / "image.jpg"

        assert _image_to_label_path(image).suffix == ".txt"
|
||||
|
||||
|
||||
class TestFindPoolImages:
    """Tests for the _find_pool_images helper."""

    def test_find_pool_images_in_train(self, tmp_path):
        """Pool documents are matched in the train split by doc-id prefix."""
        dataset = tmp_path / "dataset"
        train_dir = dataset / "images" / "train"
        train_dir.mkdir(parents=True)

        target_id = str(uuid4())
        for name in (f"{target_id}_page1.png", f"{target_id}_page2.png", "other_doc_page1.png"):
            (train_dir / name).touch()

        matches = _find_pool_images(dataset, {target_id})

        assert len(matches) == 2
        assert all(target_id in str(path) for path in matches)

    def test_find_pool_images_in_val(self, tmp_path):
        """Pool documents are also matched in the val split."""
        dataset = tmp_path / "dataset"
        val_dir = dataset / "images" / "val"
        val_dir.mkdir(parents=True)

        target_id = str(uuid4())
        (val_dir / f"{target_id}_page1.png").touch()

        matches = _find_pool_images(dataset, {target_id})

        assert len(matches) == 1
        assert target_id in str(matches[0])

    def test_find_pool_images_across_splits(self, tmp_path):
        """Matches are gathered from train, val, and test splits together."""
        dataset = tmp_path / "dataset"
        first_id = str(uuid4())
        second_id = str(uuid4())

        split_dirs = {name: dataset / "images" / name for name in ("train", "val", "test")}
        for split_dir in split_dirs.values():
            split_dir.mkdir(parents=True)

        (split_dirs["train"] / f"{first_id}_page1.png").touch()
        (split_dirs["val"] / f"{first_id}_page2.png").touch()
        (split_dirs["test"] / f"{second_id}_page1.png").touch()
        (split_dirs["train"] / "other_doc_page1.png").touch()

        matches = _find_pool_images(dataset, {first_id, second_id})

        assert len(matches) == 3
        first_matches = [path for path in matches if first_id in str(path)]
        second_matches = [path for path in matches if second_id in str(path)]
        assert len(first_matches) == 2
        assert len(second_matches) == 1

    def test_find_pool_images_empty_pool(self, tmp_path):
        """An empty pool id set matches nothing."""
        dataset = tmp_path / "dataset"
        train_dir = dataset / "images" / "train"
        train_dir.mkdir(parents=True)
        (train_dir / "doc123_page1.png").touch()

        matches = _find_pool_images(dataset, set())

        assert len(matches) == 0

    def test_find_pool_images_no_matches(self, tmp_path):
        """Images whose names contain no pool id are not returned."""
        dataset = tmp_path / "dataset"
        train_dir = dataset / "images" / "train"
        train_dir.mkdir(parents=True)
        for name in ("other_doc_page1.png", "another_doc_page1.png"):
            (train_dir / name).touch()

        matches = _find_pool_images(dataset, {str(uuid4())})

        assert len(matches) == 0

    def test_find_pool_images_multiple_pages(self, tmp_path):
        """Every page of a pool document is returned."""
        dataset = tmp_path / "dataset"
        train_dir = dataset / "images" / "train"
        train_dir.mkdir(parents=True)

        target_id = str(uuid4())
        for page in range(1, 6):
            (train_dir / f"{target_id}_page{page}.png").touch()

        matches = _find_pool_images(dataset, {target_id})

        assert len(matches) == 5

    def test_find_pool_images_ignores_non_files(self, tmp_path):
        """Directories inside a split are not reported as images."""
        dataset = tmp_path / "dataset"
        train_dir = dataset / "images" / "train"
        train_dir.mkdir(parents=True)

        target_id = str(uuid4())
        (train_dir / f"{target_id}_page1.png").touch()
        (train_dir / "subdir").mkdir()

        matches = _find_pool_images(dataset, {target_id})

        assert len(matches) == 1

    def test_find_pool_images_nonexistent_splits(self, tmp_path):
        """Missing split directories are tolerated and yield no matches."""
        dataset = tmp_path / "dataset"  # deliberately no subdirectories

        matches = _find_pool_images(dataset, {str(uuid4())})

        assert len(matches) == 0
|
||||
|
||||
|
||||
class TestBuildMixedDataset:
    """Tests for build_mixed_dataset function.

    build_mixed_dataset samples "old" training data from a base dataset,
    copies the pool documents as "new" data, splits everything roughly
    80/20 into train/val, and writes a data.yaml describing the result.
    """

    @staticmethod
    def _make_dataset_skeleton(base: Path) -> None:
        """Create the images/{train,val} and labels/{train,val} directories."""
        for split in ("train", "val"):
            (base / "images" / split).mkdir(parents=True)
            (base / "labels" / split).mkdir(parents=True)

    @staticmethod
    def _write_sample(
        base: Path,
        split: str,
        stem: str,
        image_text: str,
        label_text: str,
        ext: str = ".png",
    ) -> None:
        """Write one image/label pair for ``stem`` into the given split."""
        (base / "images" / split / f"{stem}{ext}").write_text(image_text)
        (base / "labels" / split / f"{stem}.txt").write_text(label_text)

    @pytest.fixture
    def setup_base_dataset(self, tmp_path):
        """Create a base dataset with old training data (10 train + 5 val)."""
        base = tmp_path / "base_dataset"
        self._make_dataset_skeleton(base)

        for i in range(1, 11):
            self._write_sample(base, "train", f"old_doc_{i}_page1", f"image {i}", "0 0.5 0.5 0.1 0.1")
        for i in range(1, 6):
            self._write_sample(base, "val", f"old_doc_val_{i}_page1", f"val image {i}", "0 0.5 0.5 0.1 0.1")

        return base

    @pytest.fixture
    def setup_pool_documents(self, tmp_path, setup_base_dataset):
        """Add 5 pool documents to the base dataset's train split."""
        base = setup_base_dataset
        pool_ids = [uuid4() for _ in range(5)]

        for doc_id in pool_ids:
            self._write_sample(base, "train", f"{doc_id}_page1", f"pool image {doc_id}", "1 0.5 0.5 0.2 0.2")

        return base, pool_ids

    def test_build_mixed_dataset_basic(self, tmp_path, setup_pool_documents):
        """A basic build produces counts, the split directories, and data.yaml."""
        base, pool_ids = setup_pool_documents
        output_dir = tmp_path / "mixed_output"

        result = build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base,
            output_dir=output_dir,
            seed=42,
        )

        # Result dict exposes the expected summary keys.
        for key in ("data_yaml", "total_images", "old_images", "new_images", "mixing_ratio"):
            assert key in result

        # New images are split 80/20 and copied without overwriting, so both
        # old and new contributions must be non-empty and counts consistent.
        assert result["new_images"] > 0
        assert result["old_images"] > 0
        assert result["total_images"] == result["old_images"] + result["new_images"]

        # Output directory layout.
        assert output_dir.exists()
        for subdir in ("images/train", "images/val", "labels/train", "labels/val"):
            assert (output_dir / subdir).exists()

        # data.yaml describes the dataset for YOLO training.
        yaml_path = Path(result["data_yaml"])
        assert yaml_path.exists()
        yaml_content = yaml_path.read_text()
        assert "train: images/train" in yaml_content
        assert "val: images/val" in yaml_content
        assert "nc:" in yaml_content
        assert "names:" in yaml_content

    def test_build_mixed_dataset_respects_mixing_ratio(self, tmp_path, setup_pool_documents):
        """The reported mixing ratio matches the tier for the pool size."""
        base, pool_ids = setup_pool_documents
        output_dir = tmp_path / "mixed_output"

        # With 5 pool documents the first tier applies (5 <= 10 -> ratio 50),
        # giving a target of 5 * 50 = 250 old samples — but only 20 images
        # exist in the base dataset (15 pure old + 5 pool), which caps it.
        result = build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base,
            output_dir=output_dir,
            seed=42,
        )

        # Pool images live in the base dataset, so they can be sampled as "old".
        assert result["old_images"] <= 20  # cannot exceed what the base dataset holds
        assert result["old_images"] > 0
        assert result["mixing_ratio"] == 50  # correct tier for 5 samples

    def test_build_mixed_dataset_max_old_samples_limit(self, tmp_path):
        """The number of sampled old images is capped at MAX_OLD_SAMPLES."""
        base = tmp_path / "base_dataset"
        self._make_dataset_skeleton(base)

        # More old images than the cap allows.
        for i in range(MAX_OLD_SAMPLES + 500):
            self._write_sample(base, "train", f"old_doc_{i}_page1", f"image {i}", "0 0.5 0.5 0.1 0.1")

        # 100 pool samples -> ratio 10 -> target 1000 old samples, which
        # must still be clamped to MAX_OLD_SAMPLES.
        pool_ids = [uuid4() for _ in range(100)]
        for doc_id in pool_ids:
            self._write_sample(base, "train", f"{doc_id}_page1", f"pool image {doc_id}", "1 0.5 0.5 0.2 0.2")

        result = build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base,
            output_dir=tmp_path / "mixed_output",
            seed=42,
        )

        assert result["old_images"] <= MAX_OLD_SAMPLES

    def test_build_mixed_dataset_empty_pool(self, tmp_path, setup_base_dataset):
        """An empty pool yields zero new, old, and total images."""
        result = build_mixed_dataset(
            pool_document_ids=[],
            base_dataset_path=setup_base_dataset,
            output_dir=tmp_path / "mixed_output",
            seed=42,
        )

        assert result["new_images"] == 0
        assert result["old_images"] == 0
        assert result["total_images"] == 0

    def test_build_mixed_dataset_no_old_data(self, tmp_path):
        """Building with ONLY pool data still produces a consistent dataset.

        The pool images sit in the base dataset's train split, so they are
        sampled as "old" first and then skipped as duplicate "new" copies.
        """
        base = tmp_path / "base_dataset"
        self._make_dataset_skeleton(base)

        pool_ids = [uuid4() for _ in range(5)]
        for doc_id in pool_ids:
            self._write_sample(base, "train", f"{doc_id}_page1", f"pool image {doc_id}", "1 0.5 0.5 0.2 0.2")

        result = build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base,
            output_dir=tmp_path / "mixed_output",
            seed=42,
        )

        assert result["total_images"] > 0
        assert result["total_images"] == result["old_images"] + result["new_images"]

    def test_build_mixed_dataset_train_val_split(self, tmp_path, setup_pool_documents):
        """Output images are split roughly 80/20 between train and val."""
        base, pool_ids = setup_pool_documents
        output_dir = tmp_path / "mixed_output"

        result = build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base,
            output_dir=output_dir,
            seed=42,
        )

        train_images = list((output_dir / "images" / "train").glob("*.png"))
        val_images = list((output_dir / "images" / "val").glob("*.png"))
        total_output_images = len(train_images) + len(val_images)

        assert total_output_images == result["total_images"]

        # Small sample sizes make the split noisy, so accept 60-90% train.
        if total_output_images > 0:
            train_ratio = len(train_images) / total_output_images
            assert 0.6 <= train_ratio <= 0.9

    def test_build_mixed_dataset_reproducible_with_seed(self, tmp_path, setup_pool_documents):
        """The same seed yields identical counts and file placement."""
        base, pool_ids = setup_pool_documents
        output_dir1 = tmp_path / "mixed_output1"
        output_dir2 = tmp_path / "mixed_output2"

        result1 = build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base,
            output_dir=output_dir1,
            seed=123,
        )
        result2 = build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base,
            output_dir=output_dir2,
            seed=123,
        )

        assert result1["old_images"] == result2["old_images"]
        assert result1["new_images"] == result2["new_images"]

        train_files1 = {f.name for f in (output_dir1 / "images" / "train").glob("*.png")}
        train_files2 = {f.name for f in (output_dir2 / "images" / "train").glob("*.png")}
        assert train_files1 == train_files2

    def test_build_mixed_dataset_different_seeds(self, tmp_path, setup_pool_documents):
        """Different seeds still build valid datasets with the same ratio."""
        base, pool_ids = setup_pool_documents
        output_dir1 = tmp_path / "mixed_output1"
        output_dir2 = tmp_path / "mixed_output2"

        result1 = build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base,
            output_dir=output_dir1,
            seed=123,
        )
        result2 = build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base,
            output_dir=output_dir2,
            seed=456,
        )

        assert result1["total_images"] > 0
        assert result2["total_images"] > 0

        # The mixing ratio depends only on the pool size, not on the seed.
        assert result1["mixing_ratio"] == result2["mixing_ratio"]

        # File placement MAY differ between seeds, but that is not guaranteed
        # with such small pools — only check that both runs produced files.
        train_files1 = {f.name for f in (output_dir1 / "images" / "train").glob("*.png")}
        train_files2 = {f.name for f in (output_dir2 / "images" / "train").glob("*.png")}
        assert len(train_files1) > 0
        assert len(train_files2) > 0

    def test_build_mixed_dataset_copies_labels(self, tmp_path, setup_pool_documents):
        """Each copied image is accompanied by its label when one exists."""
        base, pool_ids = setup_pool_documents
        output_dir = tmp_path / "mixed_output"

        build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base,
            output_dir=output_dir,
            seed=42,
        )

        train_labels = list((output_dir / "labels" / "train").glob("*.txt"))
        val_labels = list((output_dir / "labels" / "val").glob("*.txt"))
        train_images = list((output_dir / "images" / "train").glob("*.png"))
        val_images = list((output_dir / "images" / "val").glob("*.png"))

        # Labels may be missing for some images, but never exceed them.
        assert len(train_labels) <= len(train_images)
        assert len(val_labels) <= len(val_images)

    def test_build_mixed_dataset_skips_duplicate_files(self, tmp_path, setup_pool_documents):
        """Rebuilding into the same output directory overwrites existing files.

        The implementation copies with shutil.copy2, which overwrites, so a
        locally modified output file is restored by the second build.
        """
        base, pool_ids = setup_pool_documents
        output_dir = tmp_path / "mixed_output"

        # First build populates the output directory.
        build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base,
            output_dir=output_dir,
            seed=42,
        )

        train_images = list((output_dir / "images" / "train").glob("*.png"))
        # Guard: without a train image there is nothing to overwrite (and the
        # unguarded read below would previously have raised NameError).
        if not train_images:
            pytest.skip("first build produced no train images to overwrite")

        test_file = train_images[0]
        test_file.write_text("modified content")

        # Second build with the same seed copies the same files again.
        result2 = build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base,
            output_dir=output_dir,
            seed=42,
        )
        assert result2["total_images"] >= 0

        # shutil.copy2 overwrites by default, so the modification is gone.
        assert test_file.read_text() != "modified content"

    def test_build_mixed_dataset_handles_jpg_images(self, tmp_path):
        """JPG source images flow through the build just like PNGs."""
        base = tmp_path / "base_dataset"
        self._make_dataset_skeleton(base)

        # JPG old data.
        for i in range(1, 6):
            self._write_sample(base, "train", f"old_doc_{i}_page1", f"jpg image {i}", "0 0.5 0.5 0.1 0.1", ext=".jpg")

        # One pool document with several JPG pages so at least one is copied.
        pool_ids = [uuid4()]
        doc_id = pool_ids[0]
        for page_num in range(1, 4):
            self._write_sample(
                base,
                "train",
                f"{doc_id}_page{page_num}",
                f"pool jpg {doc_id} page {page_num}",
                "1 0.5 0.5 0.2 0.2",
                ext=".jpg",
            )

        output_dir = tmp_path / "mixed_output"
        result = build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base,
            output_dir=output_dir,
            seed=42,
        )

        assert result["new_images"] > 0
        assert result["old_images"] > 0

        # JPG files must appear in the output splits.
        jpg_outputs = list((output_dir / "images" / "train").glob("*.jpg")) + list(
            (output_dir / "images" / "val").glob("*.jpg")
        )
        assert len(jpg_outputs) > 0
|
||||
|
||||
|
||||
class TestConstants:
    """Sanity checks on the module-level mixing constants."""

    def test_mixing_ratios_structure(self):
        """MIXING_RATIOS is four (int threshold, int multiplier) tuples."""
        assert isinstance(MIXING_RATIOS, list)
        assert len(MIXING_RATIOS) == 4

        for entry in MIXING_RATIOS:
            assert isinstance(entry, tuple)
            assert len(entry) == 2
            threshold, multiplier = entry
            assert isinstance(threshold, int)
            assert isinstance(multiplier, int)

        # Thresholds ascend while multipliers descend: the more new samples
        # there are, the less old data is mixed in per sample.
        thresholds = [threshold for threshold, _ in MIXING_RATIOS]
        multipliers = [multiplier for _, multiplier in MIXING_RATIOS]
        assert thresholds == sorted(thresholds)
        assert multipliers == sorted(multipliers, reverse=True)

    def test_default_multiplier(self):
        """DEFAULT_MULTIPLIER equals the last tier's multiplier."""
        assert DEFAULT_MULTIPLIER == 5
        assert DEFAULT_MULTIPLIER == MIXING_RATIOS[-1][1]

    def test_max_old_samples(self):
        """MAX_OLD_SAMPLES is a positive cap of 3000."""
        assert MAX_OLD_SAMPLES == 3000
        assert MAX_OLD_SAMPLES > 0

    def test_min_pool_size(self):
        """MIN_POOL_SIZE is a positive minimum of 50."""
        assert MIN_POOL_SIZE == 50
        assert MIN_POOL_SIZE > 0
|
||||
@@ -310,7 +310,7 @@ class TestSchedulerDatasetStatusUpdates:
|
||||
try:
|
||||
scheduler._execute_task(
|
||||
task_id=task_id,
|
||||
config={"model_name": "yolo11n.pt"},
|
||||
config={"model_name": "yolo26s.pt"},
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
except Exception:
|
||||
|
||||
467
tests/web/test_finetune_pool.py
Normal file
467
tests/web/test_finetune_pool.py
Normal file
@@ -0,0 +1,467 @@
|
||||
"""
|
||||
Tests for Fine-Tune Pool feature.
|
||||
|
||||
Tests cover:
|
||||
1. FineTunePoolEntry database model
|
||||
2. PoolAddRequest/PoolStatsResponse schemas
|
||||
3. Chain prevention logic
|
||||
4. Pool threshold enforcement
|
||||
5. Model lineage fields on ModelVersion
|
||||
6. Gating enforcement on model activation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Database Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestFineTunePoolEntryModel:
    """Tests for FineTunePoolEntry model."""

    def test_creates_with_defaults(self):
        """A minimal entry gets an id and unset audit/verification fields."""
        from backend.data.admin_models import FineTunePoolEntry

        entry = FineTunePoolEntry(document_id=uuid4())

        assert entry.entry_id is not None
        assert entry.is_verified is False
        # All optional audit fields default to None.
        for attr in ("verified_at", "verified_by", "added_by", "reason"):
            assert getattr(entry, attr) is None

    def test_creates_with_all_fields(self):
        """All optional fields round-trip through the constructor."""
        from backend.data.admin_models import FineTunePoolEntry

        document_id = uuid4()
        entry = FineTunePoolEntry(
            document_id=document_id,
            added_by="admin",
            reason="user_reported_failure",
            is_verified=True,
            verified_by="reviewer",
        )

        assert entry.document_id == document_id
        assert entry.added_by == "admin"
        assert entry.reason == "user_reported_failure"
        assert entry.is_verified is True
        assert entry.verified_by == "reviewer"
|
||||
|
||||
|
||||
class TestGatingResultModel:
|
||||
"""Tests for GatingResult model."""
|
||||
|
||||
def test_creates_with_defaults(self):
|
||||
"""GatingResult should have correct defaults."""
|
||||
from backend.data.admin_models import GatingResult
|
||||
|
||||
model_version_id = uuid4()
|
||||
result = GatingResult(
|
||||
model_version_id=model_version_id,
|
||||
gate1_status="pass",
|
||||
gate2_status="pass",
|
||||
overall_status="pass",
|
||||
)
|
||||
assert result.result_id is not None
|
||||
assert result.model_version_id == model_version_id
|
||||
assert result.gate1_status == "pass"
|
||||
assert result.gate2_status == "pass"
|
||||
assert result.overall_status == "pass"
|
||||
assert result.gate1_mAP_drop is None
|
||||
assert result.gate2_detection_rate is None
|
||||
|
||||
def test_creates_with_full_metrics(self):
|
||||
"""GatingResult should store full metrics."""
|
||||
from backend.data.admin_models import GatingResult
|
||||
|
||||
result = GatingResult(
|
||||
model_version_id=uuid4(),
|
||||
gate1_status="review",
|
||||
gate1_original_mAP=0.95,
|
||||
gate1_new_mAP=0.93,
|
||||
gate1_mAP_drop=0.02,
|
||||
gate2_status="pass",
|
||||
gate2_detection_rate=0.85,
|
||||
gate2_total_samples=100,
|
||||
gate2_detected_samples=85,
|
||||
overall_status="review",
|
||||
)
|
||||
assert result.gate1_original_mAP == 0.95
|
||||
assert result.gate1_new_mAP == 0.93
|
||||
assert result.gate1_mAP_drop == 0.02
|
||||
assert result.gate2_detection_rate == 0.85
|
||||
|
||||
|
||||
class TestModelVersionLineage:
|
||||
"""Tests for ModelVersion lineage fields."""
|
||||
|
||||
def test_default_model_type_is_base(self):
|
||||
"""ModelVersion should default to 'base' model_type."""
|
||||
from backend.data.admin_models import ModelVersion
|
||||
|
||||
mv = ModelVersion(
|
||||
version="v1.0",
|
||||
name="test-model",
|
||||
model_path="/path/to/model.pt",
|
||||
)
|
||||
assert mv.model_type == "base"
|
||||
assert mv.base_model_version_id is None
|
||||
assert mv.base_training_dataset_id is None
|
||||
assert mv.gating_status == "pending"
|
||||
|
||||
def test_finetune_model_type(self):
|
||||
"""ModelVersion should support 'finetune' type with lineage."""
|
||||
from backend.data.admin_models import ModelVersion
|
||||
|
||||
base_id = uuid4()
|
||||
dataset_id = uuid4()
|
||||
mv = ModelVersion(
|
||||
version="v2.0",
|
||||
name="finetuned-model",
|
||||
model_path="/path/to/ft_model.pt",
|
||||
model_type="finetune",
|
||||
base_model_version_id=base_id,
|
||||
base_training_dataset_id=dataset_id,
|
||||
gating_status="pending",
|
||||
)
|
||||
assert mv.model_type == "finetune"
|
||||
assert mv.base_model_version_id == base_id
|
||||
assert mv.base_training_dataset_id == dataset_id
|
||||
assert mv.gating_status == "pending"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPoolSchemas:
|
||||
"""Tests for pool Pydantic schemas."""
|
||||
|
||||
def test_pool_add_request_defaults(self):
|
||||
"""PoolAddRequest should have default reason."""
|
||||
from backend.web.schemas.admin.pool import PoolAddRequest
|
||||
|
||||
req = PoolAddRequest(document_id="550e8400-e29b-41d4-a716-446655440001")
|
||||
assert req.document_id == "550e8400-e29b-41d4-a716-446655440001"
|
||||
assert req.reason == "user_reported_failure"
|
||||
|
||||
def test_pool_add_request_custom_reason(self):
|
||||
"""PoolAddRequest should accept custom reason."""
|
||||
from backend.web.schemas.admin.pool import PoolAddRequest
|
||||
|
||||
req = PoolAddRequest(
|
||||
document_id="550e8400-e29b-41d4-a716-446655440001",
|
||||
reason="manual_addition",
|
||||
)
|
||||
assert req.reason == "manual_addition"
|
||||
|
||||
def test_pool_stats_response(self):
|
||||
"""PoolStatsResponse should compute readiness correctly."""
|
||||
from backend.web.schemas.admin.pool import PoolStatsResponse
|
||||
|
||||
# Not ready
|
||||
stats = PoolStatsResponse(
|
||||
total_entries=30,
|
||||
verified_entries=20,
|
||||
unverified_entries=10,
|
||||
is_ready=False,
|
||||
)
|
||||
assert stats.is_ready is False
|
||||
assert stats.min_required == 50
|
||||
|
||||
# Ready
|
||||
stats_ready = PoolStatsResponse(
|
||||
total_entries=80,
|
||||
verified_entries=60,
|
||||
unverified_entries=20,
|
||||
is_ready=True,
|
||||
)
|
||||
assert stats_ready.is_ready is True
|
||||
|
||||
def test_pool_entry_item(self):
|
||||
"""PoolEntryItem should serialize correctly."""
|
||||
from backend.web.schemas.admin.pool import PoolEntryItem
|
||||
|
||||
entry = PoolEntryItem(
|
||||
entry_id="entry-uuid",
|
||||
document_id="doc-uuid",
|
||||
is_verified=True,
|
||||
verified_at=datetime.utcnow(),
|
||||
verified_by="admin",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
assert entry.is_verified is True
|
||||
assert entry.verified_by == "admin"
|
||||
|
||||
def test_gating_result_item(self):
|
||||
"""GatingResultItem should serialize all gate fields."""
|
||||
from backend.web.schemas.admin.pool import GatingResultItem
|
||||
|
||||
item = GatingResultItem(
|
||||
result_id="result-uuid",
|
||||
model_version_id="model-uuid",
|
||||
gate1_status="pass",
|
||||
gate1_original_mAP=0.95,
|
||||
gate1_new_mAP=0.94,
|
||||
gate1_mAP_drop=0.01,
|
||||
gate2_status="pass",
|
||||
gate2_detection_rate=0.90,
|
||||
gate2_total_samples=50,
|
||||
gate2_detected_samples=45,
|
||||
overall_status="pass",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
assert item.gate1_status == "pass"
|
||||
assert item.overall_status == "pass"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Chain Prevention
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestChainPrevention:
|
||||
"""Tests for fine-tune chain prevention logic."""
|
||||
|
||||
def test_rejects_finetune_from_finetune_model(self):
|
||||
"""Should reject training when base model is already a fine-tune."""
|
||||
# Simulate the chain prevention check from datasets.py
|
||||
model_type = "finetune"
|
||||
base_model_version_id = str(uuid4())
|
||||
|
||||
# This should trigger rejection
|
||||
assert model_type == "finetune"
|
||||
|
||||
def test_allows_finetune_from_base_model(self):
|
||||
"""Should allow training when base model is a base model."""
|
||||
model_type = "base"
|
||||
assert model_type != "finetune"
|
||||
|
||||
def test_allows_fresh_training(self):
|
||||
"""Should allow fresh training (no base model)."""
|
||||
base_model_version_id = None
|
||||
assert base_model_version_id is None # No chain check needed
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Pool Threshold
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPoolThreshold:
|
||||
"""Tests for minimum pool size enforcement."""
|
||||
|
||||
def test_min_pool_size_constant(self):
|
||||
"""MIN_POOL_SIZE should be 50."""
|
||||
from backend.web.services.data_mixer import MIN_POOL_SIZE
|
||||
|
||||
assert MIN_POOL_SIZE == 50
|
||||
|
||||
def test_pool_below_threshold_blocks_finetune(self):
|
||||
"""Pool with fewer than 50 verified entries should block fine-tuning."""
|
||||
from backend.web.services.data_mixer import MIN_POOL_SIZE
|
||||
|
||||
verified_count = 30
|
||||
assert verified_count < MIN_POOL_SIZE
|
||||
|
||||
def test_pool_at_threshold_allows_finetune(self):
|
||||
"""Pool with exactly 50 verified entries should allow fine-tuning."""
|
||||
from backend.web.services.data_mixer import MIN_POOL_SIZE
|
||||
|
||||
verified_count = 50
|
||||
assert verified_count >= MIN_POOL_SIZE
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Gating Enforcement on Activation
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestGatingEnforcement:
|
||||
"""Tests for gating enforcement when activating models."""
|
||||
|
||||
def test_base_model_skips_gating(self):
|
||||
"""Base models should have gating_status 'skipped'."""
|
||||
from backend.data.admin_models import ModelVersion
|
||||
|
||||
mv = ModelVersion(
|
||||
version="v1.0",
|
||||
name="base",
|
||||
model_path="/model.pt",
|
||||
model_type="base",
|
||||
)
|
||||
# Base models skip gating - activation should work
|
||||
assert mv.model_type == "base"
|
||||
# Gating should not block base model activation
|
||||
|
||||
def test_finetune_model_rejected_blocks_activation(self):
|
||||
"""Fine-tuned models with 'reject' gating should block activation."""
|
||||
model_type = "finetune"
|
||||
gating_status = "reject"
|
||||
|
||||
# Simulates the check in models.py activation endpoint
|
||||
should_block = model_type == "finetune" and gating_status == "reject"
|
||||
assert should_block is True
|
||||
|
||||
def test_finetune_model_pending_blocks_activation(self):
|
||||
"""Fine-tuned models with 'pending' gating should block activation."""
|
||||
model_type = "finetune"
|
||||
gating_status = "pending"
|
||||
|
||||
should_block = model_type == "finetune" and gating_status == "pending"
|
||||
assert should_block is True
|
||||
|
||||
def test_finetune_model_pass_allows_activation(self):
|
||||
"""Fine-tuned models with 'pass' gating should allow activation."""
|
||||
model_type = "finetune"
|
||||
gating_status = "pass"
|
||||
|
||||
should_block_reject = model_type == "finetune" and gating_status == "reject"
|
||||
should_block_pending = model_type == "finetune" and gating_status == "pending"
|
||||
assert should_block_reject is False
|
||||
assert should_block_pending is False
|
||||
|
||||
def test_finetune_model_review_allows_with_warning(self):
|
||||
"""Fine-tuned models with 'review' gating should allow but warn."""
|
||||
model_type = "finetune"
|
||||
gating_status = "review"
|
||||
|
||||
should_block_reject = model_type == "finetune" and gating_status == "reject"
|
||||
should_block_pending = model_type == "finetune" and gating_status == "pending"
|
||||
assert should_block_reject is False
|
||||
assert should_block_pending is False
|
||||
# Should include warning in message
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Pool API Route Registration
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPoolRouteRegistration:
|
||||
"""Tests for pool route registration."""
|
||||
|
||||
def test_pool_routes_registered(self):
|
||||
"""Pool routes should be registered on training router."""
|
||||
from backend.web.api.v1.admin.training import create_training_router
|
||||
|
||||
router = create_training_router()
|
||||
paths = [route.path for route in router.routes]
|
||||
|
||||
assert any("/pool" in p for p in paths)
|
||||
assert any("/pool/stats" in p for p in paths)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Scheduler Fine-Tune Parameter Override
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestSchedulerFineTuneParams:
|
||||
"""Tests for scheduler fine-tune parameter overrides."""
|
||||
|
||||
def test_finetune_detected_from_base_model_path(self):
|
||||
"""Scheduler should detect fine-tune mode from base_model_path."""
|
||||
config = {"base_model_path": "/path/to/base_model.pt"}
|
||||
is_finetune = bool(config.get("base_model_path"))
|
||||
assert is_finetune is True
|
||||
|
||||
def test_fresh_training_not_finetune(self):
|
||||
"""Scheduler should not enable fine-tune for fresh training."""
|
||||
config = {"model_name": "yolo26s.pt"}
|
||||
is_finetune = bool(config.get("base_model_path"))
|
||||
assert is_finetune is False
|
||||
|
||||
def test_finetune_defaults_correct_epochs(self):
|
||||
"""Fine-tune should default to 10 epochs."""
|
||||
config = {"base_model_path": "/path/to/model.pt"}
|
||||
is_finetune = bool(config.get("base_model_path"))
|
||||
|
||||
if is_finetune:
|
||||
epochs = config.get("epochs", 10)
|
||||
learning_rate = config.get("learning_rate", 0.001)
|
||||
else:
|
||||
epochs = config.get("epochs", 100)
|
||||
learning_rate = config.get("learning_rate", 0.01)
|
||||
|
||||
assert epochs == 10
|
||||
assert learning_rate == 0.001
|
||||
|
||||
def test_model_lineage_set_for_finetune(self):
|
||||
"""Scheduler should set model_type and base_model_version_id for fine-tune."""
|
||||
config = {
|
||||
"base_model_path": "/path/to/model.pt",
|
||||
"base_model_version_id": str(uuid4()),
|
||||
}
|
||||
is_finetune = bool(config.get("base_model_path"))
|
||||
model_type = "finetune" if is_finetune else "base"
|
||||
base_model_version_id = config.get("base_model_version_id") if is_finetune else None
|
||||
gating_status = "pending" if is_finetune else "skipped"
|
||||
|
||||
assert model_type == "finetune"
|
||||
assert base_model_version_id is not None
|
||||
assert gating_status == "pending"
|
||||
|
||||
def test_model_lineage_skipped_for_base(self):
|
||||
"""Scheduler should set model_type='base' for fresh training."""
|
||||
config = {"model_name": "yolo26s.pt"}
|
||||
is_finetune = bool(config.get("base_model_path"))
|
||||
model_type = "finetune" if is_finetune else "base"
|
||||
gating_status = "pending" if is_finetune else "skipped"
|
||||
|
||||
assert model_type == "base"
|
||||
assert gating_status == "skipped"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test TrainingConfig freeze/cos_lr
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestTrainingConfigFineTuneFields:
|
||||
"""Tests for freeze and cos_lr fields in shared TrainingConfig."""
|
||||
|
||||
def test_default_freeze_is_zero(self):
|
||||
"""TrainingConfig freeze should default to 0."""
|
||||
from shared.training import TrainingConfig
|
||||
|
||||
config = TrainingConfig(
|
||||
model_path="test.pt",
|
||||
data_yaml="data.yaml",
|
||||
)
|
||||
assert config.freeze == 0
|
||||
|
||||
def test_default_cos_lr_is_false(self):
|
||||
"""TrainingConfig cos_lr should default to False."""
|
||||
from shared.training import TrainingConfig
|
||||
|
||||
config = TrainingConfig(
|
||||
model_path="test.pt",
|
||||
data_yaml="data.yaml",
|
||||
)
|
||||
assert config.cos_lr is False
|
||||
|
||||
def test_finetune_config(self):
|
||||
"""TrainingConfig should accept fine-tune parameters."""
|
||||
from shared.training import TrainingConfig
|
||||
|
||||
config = TrainingConfig(
|
||||
model_path="base_model.pt",
|
||||
data_yaml="data.yaml",
|
||||
epochs=10,
|
||||
learning_rate=0.001,
|
||||
freeze=10,
|
||||
cos_lr=True,
|
||||
)
|
||||
assert config.freeze == 10
|
||||
assert config.cos_lr is True
|
||||
assert config.epochs == 10
|
||||
assert config.learning_rate == 0.001
|
||||
@@ -1,14 +1,14 @@
|
||||
"""
|
||||
Tests for Training Export with expand_bbox integration.
|
||||
Tests for Training Export with uniform expand_bbox integration.
|
||||
|
||||
Tests the export endpoint's integration with field-specific bbox expansion.
|
||||
Tests the export endpoint's integration with uniform bbox expansion.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
from shared.bbox import expand_bbox
|
||||
from shared.bbox import expand_bbox, UNIFORM_PAD
|
||||
from shared.fields import CLASS_NAMES, FIELD_CLASS_IDS
|
||||
|
||||
|
||||
@@ -17,149 +17,87 @@ class TestExpandBboxForExport:
|
||||
|
||||
def test_expand_bbox_converts_normalized_to_pixel_and_back(self):
|
||||
"""Verify expand_bbox works with pixel-to-normalized conversion."""
|
||||
# Annotation stored as normalized coords
|
||||
x_center_norm = 0.5
|
||||
y_center_norm = 0.5
|
||||
width_norm = 0.1
|
||||
height_norm = 0.05
|
||||
|
||||
# Image dimensions
|
||||
img_width = 2480 # A4 at 300 DPI
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
# Convert to pixel coords
|
||||
x_center_px = x_center_norm * img_width
|
||||
y_center_px = y_center_norm * img_height
|
||||
width_px = width_norm * img_width
|
||||
height_px = height_norm * img_height
|
||||
|
||||
# Convert to corner coords
|
||||
x0 = x_center_px - width_px / 2
|
||||
y0 = y_center_px - height_px / 2
|
||||
x1 = x_center_px + width_px / 2
|
||||
y1 = y_center_px + height_px / 2
|
||||
|
||||
# Apply expansion
|
||||
class_name = "invoice_number"
|
||||
ex0, ey0, ex1, ey1 = expand_bbox(
|
||||
bbox=(x0, y0, x1, y1),
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type=class_name,
|
||||
)
|
||||
|
||||
# Verify expanded bbox is larger
|
||||
assert ex0 < x0 # Left expanded
|
||||
assert ey0 < y0 # Top expanded
|
||||
assert ex1 > x1 # Right expanded
|
||||
assert ey1 > y1 # Bottom expanded
|
||||
assert ex0 < x0
|
||||
assert ey0 < y0
|
||||
assert ex1 > x1
|
||||
assert ey1 > y1
|
||||
|
||||
# Convert back to normalized
|
||||
new_x_center = (ex0 + ex1) / 2 / img_width
|
||||
new_y_center = (ey0 + ey1) / 2 / img_height
|
||||
new_width = (ex1 - ex0) / img_width
|
||||
new_height = (ey1 - ey0) / img_height
|
||||
|
||||
# Verify valid normalized coords
|
||||
assert 0 <= new_x_center <= 1
|
||||
assert 0 <= new_y_center <= 1
|
||||
assert 0 <= new_width <= 1
|
||||
assert 0 <= new_height <= 1
|
||||
|
||||
def test_expand_bbox_manual_mode_minimal_expansion(self):
|
||||
"""Verify manual annotations use minimal expansion."""
|
||||
# Small bbox
|
||||
def test_expand_bbox_uniform_for_all_sources(self):
|
||||
"""Verify all annotation sources get the same uniform expansion."""
|
||||
bbox = (100, 100, 200, 150)
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
# Auto mode (field-specific expansion)
|
||||
auto_result = expand_bbox(
|
||||
# All sources now get the same uniform expansion
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="invoice_number",
|
||||
manual_mode=False,
|
||||
)
|
||||
|
||||
# Manual mode (minimal expansion)
|
||||
manual_result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="invoice_number",
|
||||
manual_mode=True,
|
||||
expected = (
|
||||
100 - UNIFORM_PAD,
|
||||
100 - UNIFORM_PAD,
|
||||
200 + UNIFORM_PAD,
|
||||
150 + UNIFORM_PAD,
|
||||
)
|
||||
|
||||
# Auto expansion should be larger than manual
|
||||
auto_width = auto_result[2] - auto_result[0]
|
||||
manual_width = manual_result[2] - manual_result[0]
|
||||
assert auto_width > manual_width
|
||||
|
||||
auto_height = auto_result[3] - auto_result[1]
|
||||
manual_height = manual_result[3] - manual_result[1]
|
||||
assert auto_height > manual_height
|
||||
|
||||
def test_expand_bbox_different_sources_use_correct_mode(self):
|
||||
"""Verify different annotation sources use correct expansion mode."""
|
||||
bbox = (100, 100, 200, 150)
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
# Define source to manual_mode mapping
|
||||
source_mode_mapping = {
|
||||
"manual": True, # Manual annotations -> minimal expansion
|
||||
"auto": False, # Auto-labeled -> field-specific expansion
|
||||
"imported": True, # Imported (from CSV) -> minimal expansion
|
||||
}
|
||||
|
||||
results = {}
|
||||
for source, manual_mode in source_mode_mapping.items():
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="ocr_number",
|
||||
manual_mode=manual_mode,
|
||||
)
|
||||
results[source] = result
|
||||
|
||||
# Auto should have largest expansion
|
||||
auto_area = (results["auto"][2] - results["auto"][0]) * \
|
||||
(results["auto"][3] - results["auto"][1])
|
||||
manual_area = (results["manual"][2] - results["manual"][0]) * \
|
||||
(results["manual"][3] - results["manual"][1])
|
||||
imported_area = (results["imported"][2] - results["imported"][0]) * \
|
||||
(results["imported"][3] - results["imported"][1])
|
||||
|
||||
assert auto_area > manual_area
|
||||
assert auto_area > imported_area
|
||||
# Manual and imported should be the same (both use minimal mode)
|
||||
assert manual_area == imported_area
|
||||
assert result == expected
|
||||
|
||||
def test_expand_bbox_all_field_types_work(self):
|
||||
"""Verify expand_bbox works for all field types."""
|
||||
"""Verify expand_bbox works for all field types (same result)."""
|
||||
bbox = (100, 100, 200, 150)
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
for class_name in CLASS_NAMES:
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type=class_name,
|
||||
)
|
||||
# All fields should produce the same result with uniform padding
|
||||
first_result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
)
|
||||
|
||||
# Verify result is a valid bbox
|
||||
assert len(result) == 4
|
||||
x0, y0, x1, y1 = result
|
||||
assert x0 >= 0
|
||||
assert y0 >= 0
|
||||
assert x1 <= img_width
|
||||
assert y1 <= img_height
|
||||
assert x1 > x0
|
||||
assert y1 > y0
|
||||
assert len(first_result) == 4
|
||||
x0, y0, x1, y1 = first_result
|
||||
assert x0 >= 0
|
||||
assert y0 >= 0
|
||||
assert x1 <= img_width
|
||||
assert y1 <= img_height
|
||||
assert x1 > x0
|
||||
assert y1 > y0
|
||||
|
||||
|
||||
class TestExportAnnotationExpansion:
|
||||
@@ -167,7 +105,6 @@ class TestExportAnnotationExpansion:
|
||||
|
||||
def test_annotation_bbox_conversion_workflow(self):
|
||||
"""Test full annotation bbox conversion workflow."""
|
||||
# Simulate stored annotation (normalized coords)
|
||||
class MockAnnotation:
|
||||
class_id = FIELD_CLASS_IDS["invoice_number"]
|
||||
class_name = "invoice_number"
|
||||
@@ -181,7 +118,6 @@ class TestExportAnnotationExpansion:
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
# Step 1: Convert normalized to pixel corner coords
|
||||
half_w = (ann.width * img_width) / 2
|
||||
half_h = (ann.height * img_height) / 2
|
||||
x0 = ann.x_center * img_width - half_w
|
||||
@@ -189,38 +125,27 @@ class TestExportAnnotationExpansion:
|
||||
x1 = ann.x_center * img_width + half_w
|
||||
y1 = ann.y_center * img_height + half_h
|
||||
|
||||
# Step 2: Determine manual_mode based on source
|
||||
manual_mode = ann.source in ("manual", "imported")
|
||||
|
||||
# Step 3: Apply expand_bbox
|
||||
ex0, ey0, ex1, ey1 = expand_bbox(
|
||||
bbox=(x0, y0, x1, y1),
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type=ann.class_name,
|
||||
manual_mode=manual_mode,
|
||||
)
|
||||
|
||||
# Step 4: Convert back to normalized
|
||||
new_x_center = (ex0 + ex1) / 2 / img_width
|
||||
new_y_center = (ey0 + ey1) / 2 / img_height
|
||||
new_width = (ex1 - ex0) / img_width
|
||||
new_height = (ey1 - ey0) / img_height
|
||||
|
||||
# Verify expansion happened (auto mode)
|
||||
assert new_width > ann.width
|
||||
assert new_height > ann.height
|
||||
|
||||
# Verify valid YOLO format
|
||||
assert 0 <= new_x_center <= 1
|
||||
assert 0 <= new_y_center <= 1
|
||||
assert 0 < new_width <= 1
|
||||
assert 0 < new_height <= 1
|
||||
|
||||
def test_export_applies_expansion_to_each_annotation(self):
|
||||
"""Test that export applies expansion to each annotation."""
|
||||
# Simulate multiple annotations with different sources
|
||||
# Use smaller bboxes so manual mode padding has visible effect
|
||||
def test_export_applies_uniform_expansion_to_all_annotations(self):
|
||||
"""Test that export applies uniform expansion to all annotations."""
|
||||
annotations = [
|
||||
{"class_name": "invoice_number", "source": "auto", "x_center": 0.3, "y_center": 0.2, "width": 0.05, "height": 0.02},
|
||||
{"class_name": "ocr_number", "source": "manual", "x_center": 0.5, "y_center": 0.8, "width": 0.05, "height": 0.02},
|
||||
@@ -232,7 +157,6 @@ class TestExportAnnotationExpansion:
|
||||
|
||||
expanded_annotations = []
|
||||
for ann in annotations:
|
||||
# Convert to pixel coords
|
||||
half_w = (ann["width"] * img_width) / 2
|
||||
half_h = (ann["height"] * img_height) / 2
|
||||
x0 = ann["x_center"] * img_width - half_w
|
||||
@@ -240,19 +164,12 @@ class TestExportAnnotationExpansion:
|
||||
x1 = ann["x_center"] * img_width + half_w
|
||||
y1 = ann["y_center"] * img_height + half_h
|
||||
|
||||
# Determine manual_mode
|
||||
manual_mode = ann["source"] in ("manual", "imported")
|
||||
|
||||
# Apply expansion
|
||||
ex0, ey0, ex1, ey1 = expand_bbox(
|
||||
bbox=(x0, y0, x1, y1),
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type=ann["class_name"],
|
||||
manual_mode=manual_mode,
|
||||
)
|
||||
|
||||
# Convert back to normalized
|
||||
expanded_annotations.append({
|
||||
"class_name": ann["class_name"],
|
||||
"source": ann["source"],
|
||||
@@ -262,106 +179,48 @@ class TestExportAnnotationExpansion:
|
||||
"height": (ey1 - ey0) / img_height,
|
||||
})
|
||||
|
||||
# Verify auto-labeled annotation expanded more than manual/imported
|
||||
auto_ann = next(a for a in expanded_annotations if a["source"] == "auto")
|
||||
manual_ann = next(a for a in expanded_annotations if a["source"] == "manual")
|
||||
|
||||
# Auto mode should expand more than manual mode
|
||||
# (auto has larger scale factors and max_pad)
|
||||
assert auto_ann["width"] > manual_ann["width"]
|
||||
assert auto_ann["height"] > manual_ann["height"]
|
||||
|
||||
# All annotations should be expanded (at least slightly for manual mode)
|
||||
# Allow small precision loss (< 1%) due to integer conversion in expand_bbox
|
||||
for i, (orig, exp) in enumerate(zip(annotations, expanded_annotations)):
|
||||
# Width and height should be >= original (expansion or equal, with small tolerance)
|
||||
tolerance = 0.01 # 1% tolerance for integer rounding
|
||||
assert exp["width"] >= orig["width"] * (1 - tolerance), \
|
||||
f"Annotation {i} width unexpectedly smaller: {exp['width']} < {orig['width']}"
|
||||
assert exp["height"] >= orig["height"] * (1 - tolerance), \
|
||||
f"Annotation {i} height unexpectedly smaller: {exp['height']} < {orig['height']}"
|
||||
# All annotations get the same expansion
|
||||
tolerance = 0.01
|
||||
for orig, exp in zip(annotations, expanded_annotations):
|
||||
assert exp["width"] >= orig["width"] * (1 - tolerance)
|
||||
assert exp["height"] >= orig["height"] * (1 - tolerance)
|
||||
|
||||
|
||||
class TestExpandBboxEdgeCases:
|
||||
"""Tests for edge cases in export bbox expansion."""
|
||||
|
||||
def test_bbox_at_image_edge_left(self):
|
||||
"""Test bbox at left edge of image."""
|
||||
bbox = (0, 100, 50, 150)
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="invoice_number",
|
||||
)
|
||||
result = expand_bbox(bbox=bbox, image_width=2480, image_height=3508)
|
||||
|
||||
# Left edge should be clamped to 0
|
||||
assert result[0] >= 0
|
||||
|
||||
def test_bbox_at_image_edge_right(self):
|
||||
"""Test bbox at right edge of image."""
|
||||
bbox = (2400, 100, 2480, 150)
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="invoice_number",
|
||||
)
|
||||
result = expand_bbox(bbox=bbox, image_width=2480, image_height=3508)
|
||||
|
||||
# Right edge should be clamped to image width
|
||||
assert result[2] <= img_width
|
||||
assert result[2] <= 2480
|
||||
|
||||
def test_bbox_at_image_edge_top(self):
|
||||
"""Test bbox at top edge of image."""
|
||||
bbox = (100, 0, 200, 50)
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="invoice_number",
|
||||
)
|
||||
result = expand_bbox(bbox=bbox, image_width=2480, image_height=3508)
|
||||
|
||||
# Top edge should be clamped to 0
|
||||
assert result[1] >= 0
|
||||
|
||||
def test_bbox_at_image_edge_bottom(self):
|
||||
"""Test bbox at bottom edge of image."""
|
||||
bbox = (100, 3400, 200, 3508)
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="invoice_number",
|
||||
)
|
||||
result = expand_bbox(bbox=bbox, image_width=2480, image_height=3508)
|
||||
|
||||
# Bottom edge should be clamped to image height
|
||||
assert result[3] <= img_height
|
||||
assert result[3] <= 3508
|
||||
|
||||
def test_very_small_bbox(self):
|
||||
"""Test very small bbox gets expanded."""
|
||||
bbox = (100, 100, 105, 105) # 5x5 pixel bbox
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
bbox = (100, 100, 105, 105)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="invoice_number",
|
||||
)
|
||||
result = expand_bbox(bbox=bbox, image_width=2480, image_height=3508)
|
||||
|
||||
# Should still produce a valid expanded bbox
|
||||
assert result[2] > result[0]
|
||||
assert result[3] > result[1]
|
||||
|
||||
Reference in New Issue
Block a user