WIP
This commit is contained in:
1
tests/services/__init__.py
Normal file
1
tests/services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for backend services."""
|
||||
344
tests/services/test_data_mixer.py
Normal file
344
tests/services/test_data_mixer.py
Normal file
@@ -0,0 +1,344 @@
|
||||
"""
|
||||
Tests for Data Mixing Service.
|
||||
|
||||
Tests cover:
|
||||
1. get_mixing_ratio boundary values
|
||||
2. build_mixed_dataset with temp filesystem
|
||||
3. _find_pool_images matching logic
|
||||
4. _image_to_label_path conversion
|
||||
5. Edge cases (empty pool, no old data, cap)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
from backend.web.services.data_mixer import (
|
||||
get_mixing_ratio,
|
||||
build_mixed_dataset,
|
||||
_collect_images,
|
||||
_image_to_label_path,
|
||||
_find_pool_images,
|
||||
MIXING_RATIOS,
|
||||
DEFAULT_MULTIPLIER,
|
||||
MAX_OLD_SAMPLES,
|
||||
MIN_POOL_SIZE,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Constants
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestConstants:
    """Sanity checks on the data-mixer module constants."""

    def test_mixing_ratios_defined(self):
        """MIXING_RATIOS holds the four expected (threshold, multiplier) pairs."""
        expected = [(10, 50), (50, 20), (200, 10), (500, 5)]
        assert len(MIXING_RATIOS) == 4
        for index, pair in enumerate(expected):
            assert MIXING_RATIOS[index] == pair

    def test_default_multiplier(self):
        """The fallback multiplier is 5x."""
        assert DEFAULT_MULTIPLIER == 5

    def test_max_old_samples(self):
        """Old-data sampling is capped at 3000."""
        assert MAX_OLD_SAMPLES == 3000

    def test_min_pool_size(self):
        """The minimum pool size is 50."""
        assert MIN_POOL_SIZE == 50
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test get_mixing_ratio
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestGetMixingRatio:
    """Boundary-value tests for get_mixing_ratio."""

    def test_1_sample_returns_50x(self):
        """A single new sample gets the maximum 50x of old data."""
        multiplier = get_mixing_ratio(1)
        assert multiplier == 50

    def test_10_samples_returns_50x(self):
        """Ten samples sit exactly on the first boundary and still get 50x."""
        multiplier = get_mixing_ratio(10)
        assert multiplier == 50

    def test_11_samples_returns_20x(self):
        """Eleven samples cross the first boundary and drop to 20x."""
        multiplier = get_mixing_ratio(11)
        assert multiplier == 20

    def test_50_samples_returns_20x(self):
        """Fifty samples sit exactly on the second boundary: 20x."""
        multiplier = get_mixing_ratio(50)
        assert multiplier == 20

    def test_51_samples_returns_10x(self):
        """Fifty-one samples cross the second boundary: 10x."""
        multiplier = get_mixing_ratio(51)
        assert multiplier == 10

    def test_200_samples_returns_10x(self):
        """Two hundred samples sit exactly on the third boundary: 10x."""
        multiplier = get_mixing_ratio(200)
        assert multiplier == 10

    def test_201_samples_returns_5x(self):
        """201 samples cross the third boundary: 5x."""
        multiplier = get_mixing_ratio(201)
        assert multiplier == 5

    def test_500_samples_returns_5x(self):
        """Five hundred samples sit exactly on the last boundary: 5x."""
        multiplier = get_mixing_ratio(500)
        assert multiplier == 5

    def test_1000_samples_returns_default(self):
        """Anything beyond the last boundary falls back to the default (5x)."""
        multiplier = get_mixing_ratio(1000)
        assert multiplier == DEFAULT_MULTIPLIER
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test _collect_images
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestCollectImages:
    """Tests for the _collect_images helper."""

    def test_collects_png_files(self, tmp_path: Path):
        """Every .png file in the directory is collected."""
        for name in ("img1.png", "img2.png"):
            (tmp_path / name).write_bytes(b"fake png")

        assert len(_collect_images(tmp_path)) == 2

    def test_collects_jpg_files(self, tmp_path: Path):
        """.jpg files are collected as well."""
        (tmp_path / "img1.jpg").write_bytes(b"fake jpg")

        assert len(_collect_images(tmp_path)) == 1

    def test_collects_both_types(self, tmp_path: Path):
        """.png and .jpg files are collected together."""
        (tmp_path / "img1.png").write_bytes(b"fake png")
        (tmp_path / "img2.jpg").write_bytes(b"fake jpg")

        assert len(_collect_images(tmp_path)) == 2

    def test_ignores_other_files(self, tmp_path: Path):
        """Files that are not images are skipped."""
        (tmp_path / "data.txt").write_text("not an image")
        (tmp_path / "data.yaml").write_text("yaml")
        (tmp_path / "img.png").write_bytes(b"png")

        assert len(_collect_images(tmp_path)) == 1

    def test_returns_empty_for_nonexistent_dir(self, tmp_path: Path):
        """A directory that does not exist yields an empty list."""
        assert _collect_images(tmp_path / "nonexistent") == []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test _image_to_label_path
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestImageToLabelPath:
    """Tests for the _image_to_label_path converter."""

    def test_converts_train_image_to_label(self, tmp_path: Path):
        """images/train/*.png maps to labels/train/*.txt."""
        source = tmp_path / "dataset" / "images" / "train" / "doc1_page1.png"

        converted = _image_to_label_path(source)

        assert converted.name == "doc1_page1.txt"
        as_text = str(converted)
        assert "labels" in as_text
        assert "train" in as_text

    def test_converts_val_image_to_label(self, tmp_path: Path):
        """images/val/*.jpg maps to labels/val/*.txt."""
        source = tmp_path / "dataset" / "images" / "val" / "doc2_page3.jpg"

        converted = _image_to_label_path(source)

        assert converted.name == "doc2_page3.txt"
        as_text = str(converted)
        assert "labels" in as_text
        assert "val" in as_text
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test _find_pool_images
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestFindPoolImages:
    """Tests for the _find_pool_images matcher."""

    def _create_dataset(self, base_path: Path, doc_ids: list[str], split: str = "train") -> None:
        """Write two page images per document into the given split."""
        images_dir = base_path / "images" / split
        images_dir.mkdir(parents=True, exist_ok=True)
        for doc_id in doc_ids:
            for page in (1, 2):
                (images_dir / f"{doc_id}_page{page}.png").write_bytes(b"img")

    def test_finds_matching_images(self, tmp_path: Path):
        """Images whose names carry a pool document ID are returned."""
        wanted = str(uuid4())
        other = str(uuid4())
        self._create_dataset(tmp_path, [wanted, other])

        found = _find_pool_images(tmp_path, {wanted})

        assert len(found) == 2  # two pages for the wanted document
        assert all(wanted in str(path) for path in found)

    def test_ignores_non_pool_images(self, tmp_path: Path):
        """Documents outside the pool never appear in the result."""
        wanted = str(uuid4())
        other = str(uuid4())
        self._create_dataset(tmp_path, [wanted, other])

        found = _find_pool_images(tmp_path, {wanted})

        # Every hit belongs to the pooled document, never the other one.
        for path in found:
            assert wanted in str(path)
            assert other not in str(path)

    def test_searches_all_splits(self, tmp_path: Path):
        """The train, val and test splits are all scanned."""
        doc_id = str(uuid4())
        for split in ("train", "val", "test"):
            self._create_dataset(tmp_path, [doc_id], split=split)

        found = _find_pool_images(tmp_path, {doc_id})
        assert len(found) == 6  # 2 pages x 3 splits

    def test_empty_pool_returns_empty(self, tmp_path: Path):
        """An empty set of pool IDs yields an empty result."""
        self._create_dataset(tmp_path, [str(uuid4())])

        assert _find_pool_images(tmp_path, set()) == []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test build_mixed_dataset
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestBuildMixedDataset:
    """Tests for build_mixed_dataset function.

    Uses a temp-filesystem base dataset (old data) plus synthetic pool
    images (new data) and checks output layout, metadata, ratio and
    seed reproducibility.
    """

    def _setup_base_dataset(self, base_path: Path, num_old: int = 20) -> None:
        """Create a base dataset with old training images.

        Splits ``num_old`` images 80/20 between train and val; each image
        is paired with a single-box YOLO label file.
        """
        for split in ("train", "val"):
            img_dir = base_path / "images" / split
            lbl_dir = base_path / "labels" / split
            img_dir.mkdir(parents=True, exist_ok=True)
            lbl_dir.mkdir(parents=True, exist_ok=True)

            # 80% of the images go to train, the remainder to val.
            count = int(num_old * 0.8) if split == "train" else num_old - int(num_old * 0.8)
            for _ in range(count):
                doc_id = str(uuid4())
                img_file = img_dir / f"{doc_id}_page1.png"
                lbl_file = lbl_dir / f"{doc_id}_page1.txt"
                img_file.write_bytes(b"fake image data")
                lbl_file.write_text("0 0.5 0.5 0.1 0.1\n")

    def _setup_pool_images(self, base_path: Path, doc_ids: list[str]) -> None:
        """Add pool (newly labeled) images and labels to the train split."""
        img_dir = base_path / "images" / "train"
        lbl_dir = base_path / "labels" / "train"
        img_dir.mkdir(parents=True, exist_ok=True)
        lbl_dir.mkdir(parents=True, exist_ok=True)

        for doc_id in doc_ids:
            img_file = img_dir / f"{doc_id}_page1.png"
            lbl_file = lbl_dir / f"{doc_id}_page1.txt"
            img_file.write_bytes(b"pool image data")
            lbl_file.write_text("0 0.5 0.5 0.2 0.2\n")

    @pytest.fixture
    def base_dataset(self, tmp_path: Path) -> Path:
        """Create a base dataset (20 old samples) for testing."""
        base_path = tmp_path / "base_dataset"
        self._setup_base_dataset(base_path, num_old=20)
        return base_path

    def test_builds_output_structure(self, base_dataset: Path, tmp_path: Path):
        """Should create proper YOLO directory structure."""
        pool_ids = [uuid4() for _ in range(5)]
        self._setup_pool_images(base_dataset, [str(pid) for pid in pool_ids])
        output_dir = tmp_path / "mixed_output"

        # Only the on-disk layout matters here; the return value is
        # covered by test_returns_correct_metadata.
        build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base_dataset,
            output_dir=output_dir,
        )

        assert (output_dir / "images" / "train").exists()
        assert (output_dir / "images" / "val").exists()
        assert (output_dir / "labels" / "train").exists()
        assert (output_dir / "labels" / "val").exists()
        assert (output_dir / "data.yaml").exists()

    def test_returns_correct_metadata(self, base_dataset: Path, tmp_path: Path):
        """Should return correct counts and metadata keys."""
        pool_ids = [uuid4() for _ in range(5)]
        self._setup_pool_images(base_dataset, [str(pid) for pid in pool_ids])
        output_dir = tmp_path / "mixed_output"

        result = build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base_dataset,
            output_dir=output_dir,
        )

        assert "data_yaml" in result
        assert "total_images" in result
        assert "old_images" in result
        assert "new_images" in result
        assert "mixing_ratio" in result
        # The total must be exactly old + new, with nothing double-counted.
        assert result["total_images"] == result["old_images"] + result["new_images"]

    def test_mixing_ratio_applied(self, base_dataset: Path, tmp_path: Path):
        """Should use the mixing ratio implied by the pool size."""
        pool_ids = [uuid4() for _ in range(5)]
        self._setup_pool_images(base_dataset, [str(pid) for pid in pool_ids])
        output_dir = tmp_path / "mixed_output"

        result = build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base_dataset,
            output_dir=output_dir,
        )

        # 5 new samples -> 50x multiplier (first MIXING_RATIOS tier)
        assert result["mixing_ratio"] == 50

    def test_seed_reproducibility(self, base_dataset: Path, tmp_path: Path):
        """The same seed must produce identical sample counts."""
        pool_ids = [uuid4() for _ in range(3)]
        self._setup_pool_images(base_dataset, [str(pid) for pid in pool_ids])

        out1 = tmp_path / "out1"
        out2 = tmp_path / "out2"

        r1 = build_mixed_dataset(pool_ids, base_dataset, out1, seed=42)
        r2 = build_mixed_dataset(pool_ids, base_dataset, out2, seed=42)

        assert r1["old_images"] == r2["old_images"]
        assert r1["new_images"] == r2["new_images"]
        assert r1["total_images"] == r2["total_images"]
|
||||
540
tests/services/test_gating_validator.py
Normal file
540
tests/services/test_gating_validator.py
Normal file
@@ -0,0 +1,540 @@
|
||||
"""
|
||||
Unit tests for gating validation service.
|
||||
|
||||
Tests the quality gate validation logic for model deployment:
|
||||
- Gate 1: mAP regression validation
|
||||
- Gate 2: detection rate validation
|
||||
- Overall status computation
|
||||
- Full validation workflow with mocked dependencies
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from backend.web.services.gating_validator import (
|
||||
GATE1_PASS_THRESHOLD,
|
||||
GATE1_REVIEW_THRESHOLD,
|
||||
GATE2_PASS_THRESHOLD,
|
||||
classify_gate1,
|
||||
classify_gate2,
|
||||
compute_overall_status,
|
||||
run_gating_validation,
|
||||
)
|
||||
from backend.data.admin_models import GatingResult
|
||||
|
||||
|
||||
class TestClassifyGate1:
    """Test Gate 1 classification (mAP drop thresholds)."""

    def test_pass_below_threshold(self):
        """Drops under 0.01 — including improvements — classify as pass."""
        for drop in (0.009, 0.005, 0.0, -0.01):  # -0.01 is an improvement
            assert classify_gate1(drop) == "pass"

    def test_pass_boundary(self):
        """A drop exactly at the pass threshold falls into review."""
        # The pass condition is strict (< 0.01), so 0.01 itself is review.
        assert classify_gate1(GATE1_PASS_THRESHOLD) == "review"

    def test_review_in_range(self):
        """Drops inside [0.01, 0.03) classify as review."""
        for drop in (0.01, 0.015, 0.02, 0.029):
            assert classify_gate1(drop) == "review"

    def test_review_boundary(self):
        """A drop exactly at the review threshold falls into reject."""
        # The review condition is strict (< 0.03), so 0.03 itself is reject.
        assert classify_gate1(GATE1_REVIEW_THRESHOLD) == "reject"

    def test_reject_above_threshold(self):
        """Drops of 0.03 and above classify as reject."""
        for drop in (0.03, 0.05, 0.10, 1.0):
            assert classify_gate1(drop) == "reject"
|
||||
|
||||
|
||||
class TestClassifyGate2:
    """Test Gate 2 classification (detection rate thresholds)."""

    def test_pass_above_threshold(self):
        """Detection rates of 0.80 and above classify as pass."""
        for rate in (0.80, 0.85, 0.99, 1.0):
            assert classify_gate2(rate) == "pass"

    def test_pass_boundary(self):
        """A rate exactly at the threshold still passes (>= comparison)."""
        assert classify_gate2(GATE2_PASS_THRESHOLD) == "pass"

    def test_review_below_threshold(self):
        """Detection rates below 0.80 classify as review."""
        for rate in (0.79, 0.75, 0.50, 0.0):
            assert classify_gate2(rate) == "review"
|
||||
|
||||
|
||||
class TestComputeOverallStatus:
    """Test overall status computation from individual gates.

    Precedence: any reject wins, then any review, otherwise pass.
    """

    def test_both_pass(self):
        """pass + pass -> pass."""
        status = compute_overall_status("pass", "pass")
        assert status == "pass"

    def test_gate1_reject_gate2_pass(self):
        """A reject on gate 1 alone forces overall reject."""
        status = compute_overall_status("reject", "pass")
        assert status == "reject"

    def test_gate1_pass_gate2_reject(self):
        """A reject on gate 2 alone forces overall reject."""
        status = compute_overall_status("pass", "reject")
        assert status == "reject"

    def test_both_reject(self):
        """reject + reject -> reject."""
        status = compute_overall_status("reject", "reject")
        assert status == "reject"

    def test_gate1_review_gate2_pass(self):
        """A review on gate 1 (no reject anywhere) yields overall review."""
        status = compute_overall_status("review", "pass")
        assert status == "review"

    def test_gate1_pass_gate2_review(self):
        """A review on gate 2 (no reject anywhere) yields overall review."""
        status = compute_overall_status("pass", "review")
        assert status == "review"

    def test_both_review(self):
        """review + review -> review."""
        status = compute_overall_status("review", "review")
        assert status == "review"

    def test_gate1_reject_gate2_review(self):
        """reject outranks review when gate 1 rejects."""
        status = compute_overall_status("reject", "review")
        assert status == "reject"

    def test_gate1_review_gate2_reject(self):
        """reject outranks review when gate 2 rejects."""
        status = compute_overall_status("review", "reject")
        assert status == "reject"
|
||||
|
||||
|
||||
class TestRunGatingValidation:
    """Test full gating validation workflow with mocked dependencies.

    All external collaborators (model repository, DB session context,
    YOLO trainer, and the gating-status update helper) are patched out.
    Each test drives run_gating_validation through the shared
    _run_validation helper, which returns the result plus the mocks
    needed for assertions.
    """

    @pytest.fixture
    def mock_model_version_id(self):
        """Generate a model version ID for testing."""
        return uuid4()

    @pytest.fixture
    def mock_base_model_version_id(self):
        """Generate a base model version ID for testing."""
        return uuid4()

    @pytest.fixture
    def mock_task_id(self):
        """Generate a task ID for testing."""
        return uuid4()

    @pytest.fixture
    def mock_base_model(self):
        """Create a mock base model with metrics."""
        model = Mock()
        model.metrics_mAP = 0.85
        return model

    @pytest.fixture
    def mock_new_model(self):
        """Create a mock new model with metrics."""
        model = Mock()
        model.metrics_mAP = 0.82
        return model

    def _run_validation(
        self,
        *,
        model_version_id,
        task_id,
        base_model_version_id=None,
        base_model=None,
        new_model=None,
        val_metrics=None,
        validate_side_effect=None,
        repo_get=None,
    ):
        """Invoke run_gating_validation with all collaborators patched.

        Args:
            model_version_id / task_id / base_model_version_id: IDs passed
                straight through to run_gating_validation (UUIDs or strings).
            base_model / new_model: mock models the patched repository
                resolves; when base_model is None every lookup returns
                new_model (the "no base model" scenario).
            val_metrics: dict returned by the patched trainer's validate().
            validate_side_effect: exception raised by validate() instead.
            repo_get: fully custom repository.get side effect (overrides
                base_model / new_model dispatch).

        Returns:
            (result, mock_session, mock_update) so callers can assert on
            the GatingResult and on the mocked DB / status-update calls.
        """
        with patch("backend.web.services.gating_validator.ModelVersionRepository") as MockRepo, \
             patch("backend.web.services.gating_validator.get_session_context") as mock_session_ctx, \
             patch("shared.training.YOLOTrainer") as MockTrainer, \
             patch("backend.web.services.gating_validator._update_model_gating_status") as mock_update:

            mock_repo = MockRepo.return_value
            if repo_get is not None:
                mock_repo.get.side_effect = repo_get
            elif base_model is not None:
                # Dispatch on string form so both UUID and str IDs match.
                mock_repo.get.side_effect = (
                    lambda id: base_model
                    if str(id) == str(base_model_version_id)
                    else new_model
                )
            else:
                mock_repo.get.return_value = new_model

            mock_session = MagicMock()
            mock_session_ctx.return_value.__enter__.return_value = mock_session

            mock_trainer = MockTrainer.return_value
            if validate_side_effect is not None:
                mock_trainer.validate.side_effect = validate_side_effect
            else:
                mock_trainer.validate.return_value = val_metrics

            result = run_gating_validation(
                model_version_id=model_version_id,
                new_model_path="/path/to/model.pt",
                base_model_version_id=base_model_version_id,
                data_yaml="/path/to/data.yaml",
                task_id=task_id,
            )

        return result, mock_session, mock_update

    def test_gate1_pass_gate2_pass(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Gate 1 lands on the review boundary while Gate 2 passes.

        NOTE: despite the historical method name, a 0.01 mAP drop is
        classified as "review", so the overall status is "review".
        """
        mock_base_model.metrics_mAP = 0.85
        mock_new_model.metrics_mAP = 0.82  # gate 2: 0.82 >= 0.80 -> pass

        result, mock_session, mock_update = self._run_validation(
            model_version_id=mock_model_version_id,
            task_id=mock_task_id,
            base_model_version_id=mock_base_model_version_id,
            base_model=mock_base_model,
            new_model=mock_new_model,
            val_metrics={"mAP50": 0.84},  # gate 1: 0.85 - 0.84 = 0.01 -> review
        )

        assert result.gate1_status == "review"  # 0.85 - 0.84 = 0.01
        assert result.gate1_original_mAP == 0.85
        assert result.gate1_new_mAP == 0.84
        assert result.gate1_mAP_drop == pytest.approx(0.01, abs=1e-6)

        assert result.gate2_status == "pass"  # 0.82 >= 0.80
        assert result.gate2_detection_rate == 0.82

        assert result.overall_status == "review"  # any review -> review

        # DB writes and the gating-status update must both have happened.
        mock_session.add.assert_called()
        mock_session.commit.assert_called()
        mock_update.assert_called_once_with(str(mock_model_version_id), "review")

    def test_gate1_reject_due_to_large_drop(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Gate 1 rejects when the mAP drop reaches 0.03 or more."""
        mock_base_model.metrics_mAP = 0.85
        mock_new_model.metrics_mAP = 0.82

        result, _, mock_update = self._run_validation(
            model_version_id=mock_model_version_id,
            task_id=mock_task_id,
            base_model_version_id=mock_base_model_version_id,
            base_model=mock_base_model,
            new_model=mock_new_model,
            val_metrics={"mAP50": 0.81},  # 0.85 - 0.81 = 0.04 -> reject
        )

        assert result.gate1_status == "reject"
        assert result.gate1_mAP_drop == pytest.approx(0.04, abs=1e-6)
        assert result.overall_status == "reject"  # any reject -> reject

        mock_update.assert_called_once_with(str(mock_model_version_id), "reject")

    def test_gate2_review_due_to_low_detection_rate(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Gate 2 goes to review when the detection rate is below 0.80."""
        mock_base_model.metrics_mAP = 0.85
        mock_new_model.metrics_mAP = 0.75  # below the 0.80 threshold

        result, _, mock_update = self._run_validation(
            model_version_id=mock_model_version_id,
            task_id=mock_task_id,
            base_model_version_id=mock_base_model_version_id,
            base_model=mock_base_model,
            new_model=mock_new_model,
            val_metrics={"mAP50": 0.845},  # gate 1: 0.85 - 0.845 = 0.005 -> pass
        )

        assert result.gate1_status == "pass"
        assert result.gate2_status == "review"  # 0.75 < 0.80
        assert result.gate2_detection_rate == 0.75
        assert result.overall_status == "review"

        mock_update.assert_called_once_with(str(mock_model_version_id), "review")

    def test_no_base_model_skips_gate1(
        self,
        mock_model_version_id,
        mock_task_id,
        mock_new_model,
    ):
        """Gate 1 is skipped (treated as pass) when no base model is given."""
        mock_new_model.metrics_mAP = 0.85

        result, _, mock_update = self._run_validation(
            model_version_id=mock_model_version_id,
            task_id=mock_task_id,
            new_model=mock_new_model,
        )

        assert result.gate1_status == "pass"  # skipped
        assert result.gate1_original_mAP is None
        assert result.gate1_new_mAP is None
        assert result.gate1_mAP_drop is None

        assert result.gate2_status == "pass"  # 0.85 >= 0.80
        assert result.overall_status == "pass"

        mock_update.assert_called_once_with(str(mock_model_version_id), "pass")

    def test_base_model_without_metrics_skips_gate1(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Gate 1 is skipped when the base model carries no metrics."""
        mock_base_model.metrics_mAP = None
        mock_new_model.metrics_mAP = 0.85

        result, _, _ = self._run_validation(
            model_version_id=mock_model_version_id,
            task_id=mock_task_id,
            base_model_version_id=mock_base_model_version_id,
            base_model=mock_base_model,
            new_model=mock_new_model,
        )

        assert result.gate1_status == "pass"  # skipped: no base metrics
        assert result.gate2_status == "pass"
        assert result.overall_status == "pass"

    def test_validation_failure_marks_gate1_review(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Gate 1 falls back to review when validation raises."""
        mock_base_model.metrics_mAP = 0.85
        mock_new_model.metrics_mAP = 0.82

        result, _, mock_update = self._run_validation(
            model_version_id=mock_model_version_id,
            task_id=mock_task_id,
            base_model_version_id=mock_base_model_version_id,
            base_model=mock_base_model,
            new_model=mock_new_model,
            validate_side_effect=RuntimeError("Validation failed"),
        )

        assert result.gate1_status == "review"  # exception -> review
        assert result.gate2_status == "pass"
        assert result.overall_status == "review"

        mock_update.assert_called_once_with(str(mock_model_version_id), "review")

    def test_validation_returns_none_mAP_marks_gate1_review(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Gate 1 falls back to review when validation yields no mAP."""
        mock_base_model.metrics_mAP = 0.85
        mock_new_model.metrics_mAP = 0.82

        result, _, _ = self._run_validation(
            model_version_id=mock_model_version_id,
            task_id=mock_task_id,
            base_model_version_id=mock_base_model_version_id,
            base_model=mock_base_model,
            new_model=mock_new_model,
            val_metrics={"mAP50": None},  # trainer reports no mAP
        )

        assert result.gate1_status == "review"  # None mAP -> review
        assert result.gate1_new_mAP is None
        assert result.gate2_status == "pass"
        assert result.overall_status == "review"

    def test_gate2_exception_marks_gate2_review(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Gate 2 falls back to review when fetching the new model fails."""
        mock_base_model.metrics_mAP = 0.85
        mock_new_model.metrics_mAP = 0.82

        def repo_get(id):
            # The base model resolves; fetching the new model blows up.
            if str(id) == str(mock_base_model_version_id):
                return mock_base_model
            if str(id) == str(mock_model_version_id):
                raise RuntimeError("Cannot fetch new model")
            return None

        result, _, _ = self._run_validation(
            model_version_id=mock_model_version_id,
            task_id=mock_task_id,
            base_model_version_id=mock_base_model_version_id,
            repo_get=repo_get,
            val_metrics={"mAP50": 0.84},
        )

        assert result.gate1_status == "review"  # 0.85 - 0.84 = 0.01
        assert result.gate2_status == "review"  # exception -> review
        assert result.overall_status == "review"

    def test_string_uuids_accepted(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """String UUIDs are accepted and converted back to UUID objects."""
        mock_base_model.metrics_mAP = 0.85
        mock_new_model.metrics_mAP = 0.85

        # Pass string forms of every UUID.
        result, _, _ = self._run_validation(
            model_version_id=str(mock_model_version_id),
            task_id=str(mock_task_id),
            base_model_version_id=str(mock_base_model_version_id),
            base_model=mock_base_model,
            new_model=mock_new_model,
            val_metrics={"mAP50": 0.85},
        )

        assert result.model_version_id == mock_model_version_id
        assert result.task_id == mock_task_id
        assert result.overall_status == "pass"
|
||||
Reference in New Issue
Block a user