This commit is contained in:
Yaojia Wang
2026-02-11 23:40:38 +01:00
parent f1a7bfe6b7
commit ad5ed46b4c
117 changed files with 5741 additions and 7669 deletions

View File

@@ -0,0 +1 @@
"""Tests for backend services."""

View File

@@ -0,0 +1,344 @@
"""
Tests for Data Mixing Service.
Tests cover:
1. get_mixing_ratio boundary values
2. build_mixed_dataset with temp filesystem
3. _find_pool_images matching logic
4. _image_to_label_path conversion
5. Edge cases (empty pool, no old data, cap)
"""
import pytest
from pathlib import Path
from uuid import uuid4
from backend.web.services.data_mixer import (
get_mixing_ratio,
build_mixed_dataset,
_collect_images,
_image_to_label_path,
_find_pool_images,
MIXING_RATIOS,
DEFAULT_MULTIPLIER,
MAX_OLD_SAMPLES,
MIN_POOL_SIZE,
)
# =============================================================================
# Test Constants
# =============================================================================
class TestConstants:
    """Verify the data-mixer module-level constants keep their documented values."""

    def test_mixing_ratios_defined(self):
        """MIXING_RATIOS should contain the four documented (threshold, multiplier) pairs."""
        expected = [(10, 50), (50, 20), (200, 10), (500, 5)]
        assert len(MIXING_RATIOS) == len(expected)
        for index, pair in enumerate(expected):
            assert MIXING_RATIOS[index] == pair

    def test_default_multiplier(self):
        """DEFAULT_MULTIPLIER should be 5."""
        assert DEFAULT_MULTIPLIER == 5

    def test_max_old_samples(self):
        """MAX_OLD_SAMPLES should be 3000."""
        assert MAX_OLD_SAMPLES == 3000

    def test_min_pool_size(self):
        """MIN_POOL_SIZE should be 50."""
        assert MIN_POOL_SIZE == 50
# =============================================================================
# Test get_mixing_ratio
# =============================================================================
class TestGetMixingRatio:
    """Boundary-value tests for get_mixing_ratio bucket selection."""

    def _check(self, sample_count: int, expected_multiplier: int) -> None:
        """Assert that *sample_count* new samples map to *expected_multiplier*."""
        assert get_mixing_ratio(sample_count) == expected_multiplier

    def test_1_sample_returns_50x(self):
        """1 new sample should get 50x old data."""
        self._check(1, 50)

    def test_10_samples_returns_50x(self):
        """10 new samples (upper boundary of first bucket) should get 50x."""
        self._check(10, 50)

    def test_11_samples_returns_20x(self):
        """11 new samples should get 20x."""
        self._check(11, 20)

    def test_50_samples_returns_20x(self):
        """50 new samples (upper boundary of second bucket) should get 20x."""
        self._check(50, 20)

    def test_51_samples_returns_10x(self):
        """51 new samples should get 10x."""
        self._check(51, 10)

    def test_200_samples_returns_10x(self):
        """200 new samples (upper boundary of third bucket) should get 10x."""
        self._check(200, 10)

    def test_201_samples_returns_5x(self):
        """201 new samples should get 5x."""
        self._check(201, 5)

    def test_500_samples_returns_5x(self):
        """500 new samples (upper boundary of fourth bucket) should get 5x."""
        self._check(500, 5)

    def test_1000_samples_returns_default(self):
        """Counts beyond every bucket fall back to DEFAULT_MULTIPLIER (5x)."""
        self._check(1000, DEFAULT_MULTIPLIER)
# =============================================================================
# Test _collect_images
# =============================================================================
class TestCollectImages:
    """Tests for the _collect_images directory-scanning helper."""

    def test_collects_png_files(self, tmp_path: Path):
        """Two .png files in the directory should both be returned."""
        for filename in ("img1.png", "img2.png"):
            (tmp_path / filename).write_bytes(b"fake png")
        found = _collect_images(tmp_path)
        assert len(found) == 2

    def test_collects_jpg_files(self, tmp_path: Path):
        """A single .jpg file should be returned."""
        (tmp_path / "img1.jpg").write_bytes(b"fake jpg")
        found = _collect_images(tmp_path)
        assert len(found) == 1

    def test_collects_both_types(self, tmp_path: Path):
        """A mix of .png and .jpg files should all be returned."""
        (tmp_path / "img1.png").write_bytes(b"fake png")
        (tmp_path / "img2.jpg").write_bytes(b"fake jpg")
        found = _collect_images(tmp_path)
        assert len(found) == 2

    def test_ignores_other_files(self, tmp_path: Path):
        """Non-image extensions (.txt, .yaml) should be skipped."""
        (tmp_path / "data.txt").write_text("not an image")
        (tmp_path / "data.yaml").write_text("yaml")
        (tmp_path / "img.png").write_bytes(b"png")
        found = _collect_images(tmp_path)
        assert len(found) == 1

    def test_returns_empty_for_nonexistent_dir(self, tmp_path: Path):
        """A directory that does not exist should yield an empty list."""
        found = _collect_images(tmp_path / "nonexistent")
        assert found == []
# =============================================================================
# Test _image_to_label_path
# =============================================================================
class TestImageToLabelPath:
    """Tests for the _image_to_label_path image->label path conversion."""

    def test_converts_train_image_to_label(self, tmp_path: Path):
        """images/train/*.png should map to labels/train/*.txt."""
        source = tmp_path / "dataset" / "images" / "train" / "doc1_page1.png"
        result = _image_to_label_path(source)
        result_str = str(result)
        assert result.name == "doc1_page1.txt"
        assert "labels" in result_str
        assert "train" in result_str

    def test_converts_val_image_to_label(self, tmp_path: Path):
        """images/val/*.jpg should map to labels/val/*.txt."""
        source = tmp_path / "dataset" / "images" / "val" / "doc2_page3.jpg"
        result = _image_to_label_path(source)
        result_str = str(result)
        assert result.name == "doc2_page3.txt"
        assert "labels" in result_str
        assert "val" in result_str
# =============================================================================
# Test _find_pool_images
# =============================================================================
class TestFindPoolImages:
    """Tests for _find_pool_images document-ID matching."""

    def _create_dataset(self, base_path: Path, doc_ids: list[str], split: str = "train") -> None:
        """Write two placeholder page images per document into images/<split>."""
        images_dir = base_path / "images" / split
        images_dir.mkdir(parents=True, exist_ok=True)
        for doc_id in doc_ids:
            for page in (1, 2):
                (images_dir / f"{doc_id}_page{page}.png").write_bytes(b"img")

    def test_finds_matching_images(self, tmp_path: Path):
        """Both pages of a pooled document should be returned."""
        in_pool = str(uuid4())
        out_of_pool = str(uuid4())
        self._create_dataset(tmp_path, [in_pool, out_of_pool])
        found = _find_pool_images(tmp_path, {in_pool})
        assert len(found) == 2  # two pages for the pooled document
        assert all(in_pool in str(path) for path in found)

    def test_ignores_non_pool_images(self, tmp_path: Path):
        """Images belonging to documents outside the pool must not appear."""
        in_pool = str(uuid4())
        out_of_pool = str(uuid4())
        self._create_dataset(tmp_path, [in_pool, out_of_pool])
        for path in _find_pool_images(tmp_path, {in_pool}):
            as_text = str(path)
            assert in_pool in as_text
            assert out_of_pool not in as_text

    def test_searches_all_splits(self, tmp_path: Path):
        """The train, val and test splits are all scanned."""
        doc_id = str(uuid4())
        for split in ("train", "val", "test"):
            self._create_dataset(tmp_path, [doc_id], split=split)
        found = _find_pool_images(tmp_path, {doc_id})
        assert len(found) == 6  # 2 pages x 3 splits

    def test_empty_pool_returns_empty(self, tmp_path: Path):
        """An empty pool-ID set yields an empty result."""
        self._create_dataset(tmp_path, [str(uuid4())])
        assert _find_pool_images(tmp_path, set()) == []
# =============================================================================
# Test build_mixed_dataset
# =============================================================================
class TestBuildMixedDataset:
    """Tests for build_mixed_dataset using a temporary filesystem.

    Fixes over the previous revision: the unused ``result`` binding in
    ``test_builds_output_structure`` is dropped, the unused loop index in
    ``_setup_base_dataset`` is replaced with ``_``, and the 80/20 split
    computation is hoisted out of the loop.
    """

    def _setup_base_dataset(self, base_path: Path, num_old: int = 20) -> None:
        """Create a base dataset with old training images.

        Roughly 80% of *num_old* images go to train and the remainder to
        val; every image gets a matching one-box YOLO label file.
        """
        train_count = int(num_old * 0.8)  # hoisted: same value for both splits
        for split in ("train", "val"):
            img_dir = base_path / "images" / split
            lbl_dir = base_path / "labels" / split
            img_dir.mkdir(parents=True, exist_ok=True)
            lbl_dir.mkdir(parents=True, exist_ok=True)
            count = train_count if split == "train" else num_old - train_count
            for _ in range(count):
                doc_id = str(uuid4())
                (img_dir / f"{doc_id}_page1.png").write_bytes(b"fake image data")
                (lbl_dir / f"{doc_id}_page1.txt").write_text("0 0.5 0.5 0.1 0.1\n")

    def _setup_pool_images(self, base_path: Path, doc_ids: list[str]) -> None:
        """Add one pool image (with label) per document ID to images/train."""
        img_dir = base_path / "images" / "train"
        lbl_dir = base_path / "labels" / "train"
        img_dir.mkdir(parents=True, exist_ok=True)
        lbl_dir.mkdir(parents=True, exist_ok=True)
        for doc_id in doc_ids:
            (img_dir / f"{doc_id}_page1.png").write_bytes(b"pool image data")
            (lbl_dir / f"{doc_id}_page1.txt").write_text("0 0.5 0.5 0.2 0.2\n")

    @pytest.fixture
    def base_dataset(self, tmp_path: Path) -> Path:
        """Create a 20-image base dataset for testing."""
        base_path = tmp_path / "base_dataset"
        self._setup_base_dataset(base_path, num_old=20)
        return base_path

    def test_builds_output_structure(self, base_dataset: Path, tmp_path: Path):
        """Should create the YOLO directory layout plus data.yaml."""
        pool_ids = [uuid4() for _ in range(5)]
        self._setup_pool_images(base_dataset, [str(pid) for pid in pool_ids])
        output_dir = tmp_path / "mixed_output"
        # Return value is asserted in test_returns_correct_metadata; only the
        # on-disk layout matters here.
        build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base_dataset,
            output_dir=output_dir,
        )
        assert (output_dir / "images" / "train").exists()
        assert (output_dir / "images" / "val").exists()
        assert (output_dir / "labels" / "train").exists()
        assert (output_dir / "labels" / "val").exists()
        assert (output_dir / "data.yaml").exists()

    def test_returns_correct_metadata(self, base_dataset: Path, tmp_path: Path):
        """Should return counts that are internally consistent."""
        pool_ids = [uuid4() for _ in range(5)]
        self._setup_pool_images(base_dataset, [str(pid) for pid in pool_ids])
        output_dir = tmp_path / "mixed_output"
        result = build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base_dataset,
            output_dir=output_dir,
        )
        assert "data_yaml" in result
        assert "total_images" in result
        assert "old_images" in result
        assert "new_images" in result
        assert "mixing_ratio" in result
        assert result["total_images"] == result["old_images"] + result["new_images"]

    def test_mixing_ratio_applied(self, base_dataset: Path, tmp_path: Path):
        """Should pick the mixing ratio matching the pool size."""
        pool_ids = [uuid4() for _ in range(5)]
        self._setup_pool_images(base_dataset, [str(pid) for pid in pool_ids])
        output_dir = tmp_path / "mixed_output"
        result = build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base_dataset,
            output_dir=output_dir,
        )
        # 5 new samples -> smallest bucket -> 50x multiplier
        assert result["mixing_ratio"] == 50

    def test_seed_reproducibility(self, base_dataset: Path, tmp_path: Path):
        """The same seed must produce identical image counts."""
        pool_ids = [uuid4() for _ in range(3)]
        self._setup_pool_images(base_dataset, [str(pid) for pid in pool_ids])
        out1 = tmp_path / "out1"
        out2 = tmp_path / "out2"
        r1 = build_mixed_dataset(pool_ids, base_dataset, out1, seed=42)
        r2 = build_mixed_dataset(pool_ids, base_dataset, out2, seed=42)
        assert r1["old_images"] == r2["old_images"]
        assert r1["new_images"] == r2["new_images"]
        assert r1["total_images"] == r2["total_images"]

View File

@@ -0,0 +1,540 @@
"""
Unit tests for gating validation service.
Tests the quality gate validation logic for model deployment:
- Gate 1: mAP regression validation
- Gate 2: detection rate validation
- Overall status computation
- Full validation workflow with mocked dependencies
"""
import pytest
from unittest.mock import MagicMock, Mock, patch
from uuid import UUID, uuid4
from backend.web.services.gating_validator import (
GATE1_PASS_THRESHOLD,
GATE1_REVIEW_THRESHOLD,
GATE2_PASS_THRESHOLD,
classify_gate1,
classify_gate2,
compute_overall_status,
run_gating_validation,
)
from backend.data.admin_models import GatingResult
class TestClassifyGate1:
    """Gate 1 classification: mAP-drop threshold behavior."""

    def test_pass_below_threshold(self):
        """Drops below 0.01 (including improvements) classify as pass."""
        for drop in (0.009, 0.005, 0.0, -0.01):  # -0.01 = improvement
            assert classify_gate1(drop) == "pass"

    def test_pass_boundary(self):
        """A drop exactly at the pass threshold falls into review (strict <)."""
        assert classify_gate1(GATE1_PASS_THRESHOLD) == "review"

    def test_review_in_range(self):
        """Drops in [0.01, 0.03) classify as review."""
        for drop in (0.01, 0.015, 0.02, 0.029):
            assert classify_gate1(drop) == "review"

    def test_review_boundary(self):
        """A drop exactly at the review threshold falls into reject (strict <)."""
        assert classify_gate1(GATE1_REVIEW_THRESHOLD) == "reject"

    def test_reject_above_threshold(self):
        """Drops of 0.03 or more classify as reject."""
        for drop in (0.03, 0.05, 0.10, 1.0):
            assert classify_gate1(drop) == "reject"
class TestClassifyGate2:
    """Gate 2 classification: detection-rate threshold behavior."""

    def test_pass_above_threshold(self):
        """Detection rates of 0.80 and above classify as pass."""
        for rate in (0.80, 0.85, 0.99, 1.0):
            assert classify_gate2(rate) == "pass"

    def test_pass_boundary(self):
        """A rate exactly at the pass threshold still passes (inclusive >=)."""
        assert classify_gate2(GATE2_PASS_THRESHOLD) == "pass"

    def test_review_below_threshold(self):
        """Detection rates below 0.80 classify as review."""
        for rate in (0.79, 0.75, 0.50, 0.0):
            assert classify_gate2(rate) == "review"
class TestComputeOverallStatus:
    """Overall status derivation from the two gate results."""

    def _check(self, gate1: str, gate2: str, expected: str) -> None:
        """Assert the combined status for a (gate1, gate2) pair."""
        assert compute_overall_status(gate1, gate2) == expected

    def test_both_pass(self):
        """pass + pass -> pass."""
        self._check("pass", "pass", "pass")

    def test_gate1_reject_gate2_pass(self):
        """A reject on gate 1 forces overall reject."""
        self._check("reject", "pass", "reject")

    def test_gate1_pass_gate2_reject(self):
        """A reject on gate 2 forces overall reject."""
        self._check("pass", "reject", "reject")

    def test_both_reject(self):
        """reject + reject -> reject."""
        self._check("reject", "reject", "reject")

    def test_gate1_review_gate2_pass(self):
        """A review with no reject yields overall review."""
        self._check("review", "pass", "review")

    def test_gate1_pass_gate2_review(self):
        """A review with no reject yields overall review."""
        self._check("pass", "review", "review")

    def test_both_review(self):
        """review + review -> review."""
        self._check("review", "review", "review")

    def test_gate1_reject_gate2_review(self):
        """Reject takes precedence over review (gate 1 side)."""
        self._check("reject", "review", "reject")

    def test_gate1_review_gate2_reject(self):
        """Reject takes precedence over review (gate 2 side)."""
        self._check("review", "reject", "reject")
class TestRunGatingValidation:
    """Test the full gating validation workflow with mocked dependencies.

    Every test patches the repository, the DB session context, and (where a
    validation run is expected) the YOLO trainer, then asserts on the
    GatingResult-like object returned by run_gating_validation.
    """

    @pytest.fixture
    def mock_model_version_id(self):
        """Generate a model version ID for testing."""
        return uuid4()

    @pytest.fixture
    def mock_base_model_version_id(self):
        """Generate a base model version ID for testing."""
        return uuid4()

    @pytest.fixture
    def mock_task_id(self):
        """Generate a task ID for testing."""
        return uuid4()

    @pytest.fixture
    def mock_base_model(self):
        """Create a mock base model with metrics."""
        model = Mock()
        model.metrics_mAP = 0.85
        return model

    @pytest.fixture
    def mock_new_model(self):
        """Create a mock new model with metrics."""
        model = Mock()
        model.metrics_mAP = 0.82
        return model

    def test_gate1_pass_gate2_pass(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Test validation with both gates passing."""
        # Setup: base mAP=0.85, new mAP=0.84 -> drop=0.01 (review)
        # But new model mAP=0.82 -> gate2 pass
        mock_base_model.metrics_mAP = 0.85
        mock_new_model.metrics_mAP = 0.82
        mock_val_metrics = {"mAP50": 0.84}
        with patch("backend.web.services.gating_validator.ModelVersionRepository") as MockRepo, \
                patch("backend.web.services.gating_validator.get_session_context") as mock_session_ctx, \
                patch("shared.training.YOLOTrainer") as MockTrainer, \
                patch("backend.web.services.gating_validator._update_model_gating_status") as mock_update:
            # Mock repository: base ID -> base model, anything else -> new model
            mock_repo = MockRepo.return_value
            mock_repo.get.side_effect = lambda id: mock_base_model if str(id) == str(mock_base_model_version_id) else mock_new_model
            # Mock session context (used as `with get_session_context() as session`)
            mock_session = MagicMock()
            mock_session_ctx.return_value.__enter__.return_value = mock_session
            # Mock YOLO trainer validation result
            mock_trainer = MockTrainer.return_value
            mock_trainer.validate.return_value = mock_val_metrics
            # Execute
            result = run_gating_validation(
                model_version_id=mock_model_version_id,
                new_model_path="/path/to/model.pt",
                base_model_version_id=mock_base_model_version_id,
                data_yaml="/path/to/data.yaml",
                task_id=mock_task_id,
            )
            # Verify
            assert result.gate1_status == "review"  # 0.85 - 0.84 = 0.01
            assert result.gate1_original_mAP == 0.85
            assert result.gate1_new_mAP == 0.84
            assert result.gate1_mAP_drop == pytest.approx(0.01, abs=1e-6)
            assert result.gate2_status == "pass"  # 0.82 >= 0.80
            assert result.gate2_detection_rate == 0.82
            assert result.overall_status == "review"  # Any review -> review
            # Verify DB operations
            mock_session.add.assert_called()
            mock_session.commit.assert_called()
            mock_update.assert_called_once_with(str(mock_model_version_id), "review")

    def test_gate1_reject_due_to_large_drop(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Test Gate 1 reject when mAP drop >= 0.03."""
        mock_base_model.metrics_mAP = 0.85
        mock_new_model.metrics_mAP = 0.82
        mock_val_metrics = {"mAP50": 0.81}  # 0.85 - 0.81 = 0.04 (reject)
        with patch("backend.web.services.gating_validator.ModelVersionRepository") as MockRepo, \
                patch("backend.web.services.gating_validator.get_session_context") as mock_session_ctx, \
                patch("shared.training.YOLOTrainer") as MockTrainer, \
                patch("backend.web.services.gating_validator._update_model_gating_status") as mock_update:
            mock_repo = MockRepo.return_value
            mock_repo.get.side_effect = lambda id: mock_base_model if str(id) == str(mock_base_model_version_id) else mock_new_model
            mock_session = MagicMock()
            mock_session_ctx.return_value.__enter__.return_value = mock_session
            mock_trainer = MockTrainer.return_value
            mock_trainer.validate.return_value = mock_val_metrics
            result = run_gating_validation(
                model_version_id=mock_model_version_id,
                new_model_path="/path/to/model.pt",
                base_model_version_id=mock_base_model_version_id,
                data_yaml="/path/to/data.yaml",
                task_id=mock_task_id,
            )
            assert result.gate1_status == "reject"
            assert result.gate1_mAP_drop == pytest.approx(0.04, abs=1e-6)
            assert result.overall_status == "reject"  # Any reject -> reject
            mock_update.assert_called_once_with(str(mock_model_version_id), "reject")

    def test_gate2_review_due_to_low_detection_rate(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Test Gate 2 review when detection rate < 0.80."""
        mock_base_model.metrics_mAP = 0.85
        mock_new_model.metrics_mAP = 0.75  # Below 0.80 threshold
        mock_val_metrics = {"mAP50": 0.845}  # Gate 1: 0.85 - 0.845 = 0.005 (pass)
        with patch("backend.web.services.gating_validator.ModelVersionRepository") as MockRepo, \
                patch("backend.web.services.gating_validator.get_session_context") as mock_session_ctx, \
                patch("shared.training.YOLOTrainer") as MockTrainer, \
                patch("backend.web.services.gating_validator._update_model_gating_status") as mock_update:
            mock_repo = MockRepo.return_value
            mock_repo.get.side_effect = lambda id: mock_base_model if str(id) == str(mock_base_model_version_id) else mock_new_model
            mock_session = MagicMock()
            mock_session_ctx.return_value.__enter__.return_value = mock_session
            mock_trainer = MockTrainer.return_value
            mock_trainer.validate.return_value = mock_val_metrics
            result = run_gating_validation(
                model_version_id=mock_model_version_id,
                new_model_path="/path/to/model.pt",
                base_model_version_id=mock_base_model_version_id,
                data_yaml="/path/to/data.yaml",
                task_id=mock_task_id,
            )
            assert result.gate1_status == "pass"
            assert result.gate2_status == "review"  # 0.75 < 0.80
            assert result.gate2_detection_rate == 0.75
            assert result.overall_status == "review"
            mock_update.assert_called_once_with(str(mock_model_version_id), "review")

    def test_no_base_model_skips_gate1(
        self,
        mock_model_version_id,
        mock_task_id,
        mock_new_model,
    ):
        """Test Gate 1 passes when no base model is provided."""
        mock_new_model.metrics_mAP = 0.85
        # Note: YOLOTrainer is deliberately NOT patched here — with no base
        # model there should be no validation run at all.
        with patch("backend.web.services.gating_validator.ModelVersionRepository") as MockRepo, \
                patch("backend.web.services.gating_validator.get_session_context") as mock_session_ctx, \
                patch("backend.web.services.gating_validator._update_model_gating_status") as mock_update:
            mock_repo = MockRepo.return_value
            mock_repo.get.return_value = mock_new_model
            mock_session = MagicMock()
            mock_session_ctx.return_value.__enter__.return_value = mock_session
            result = run_gating_validation(
                model_version_id=mock_model_version_id,
                new_model_path="/path/to/model.pt",
                base_model_version_id=None,
                data_yaml="/path/to/data.yaml",
                task_id=mock_task_id,
            )
            assert result.gate1_status == "pass"  # Skipped
            assert result.gate1_original_mAP is None
            assert result.gate1_new_mAP is None
            assert result.gate1_mAP_drop is None
            assert result.gate2_status == "pass"  # 0.85 >= 0.80
            assert result.overall_status == "pass"
            mock_update.assert_called_once_with(str(mock_model_version_id), "pass")

    def test_base_model_without_metrics_skips_gate1(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Test Gate 1 passes when base model has no metrics."""
        mock_base_model.metrics_mAP = None
        mock_new_model.metrics_mAP = 0.85
        with patch("backend.web.services.gating_validator.ModelVersionRepository") as MockRepo, \
                patch("backend.web.services.gating_validator.get_session_context") as mock_session_ctx, \
                patch("backend.web.services.gating_validator._update_model_gating_status") as mock_update:
            mock_repo = MockRepo.return_value
            mock_repo.get.side_effect = lambda id: mock_base_model if str(id) == str(mock_base_model_version_id) else mock_new_model
            mock_session = MagicMock()
            mock_session_ctx.return_value.__enter__.return_value = mock_session
            result = run_gating_validation(
                model_version_id=mock_model_version_id,
                new_model_path="/path/to/model.pt",
                base_model_version_id=mock_base_model_version_id,
                data_yaml="/path/to/data.yaml",
                task_id=mock_task_id,
            )
            assert result.gate1_status == "pass"  # Skipped due to no base metrics
            assert result.gate2_status == "pass"
            assert result.overall_status == "pass"

    def test_validation_failure_marks_gate1_review(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Test Gate 1 review when validation raises exception."""
        mock_base_model.metrics_mAP = 0.85
        mock_new_model.metrics_mAP = 0.82
        with patch("backend.web.services.gating_validator.ModelVersionRepository") as MockRepo, \
                patch("backend.web.services.gating_validator.get_session_context") as mock_session_ctx, \
                patch("shared.training.YOLOTrainer") as MockTrainer, \
                patch("backend.web.services.gating_validator._update_model_gating_status") as mock_update:
            mock_repo = MockRepo.return_value
            mock_repo.get.side_effect = lambda id: mock_base_model if str(id) == str(mock_base_model_version_id) else mock_new_model
            mock_session = MagicMock()
            mock_session_ctx.return_value.__enter__.return_value = mock_session
            # Mock trainer to raise exception
            mock_trainer = MockTrainer.return_value
            mock_trainer.validate.side_effect = RuntimeError("Validation failed")
            result = run_gating_validation(
                model_version_id=mock_model_version_id,
                new_model_path="/path/to/model.pt",
                base_model_version_id=mock_base_model_version_id,
                data_yaml="/path/to/data.yaml",
                task_id=mock_task_id,
            )
            assert result.gate1_status == "review"  # Exception -> review
            assert result.gate2_status == "pass"
            assert result.overall_status == "review"
            mock_update.assert_called_once_with(str(mock_model_version_id), "review")

    def test_validation_returns_none_mAP_marks_gate1_review(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Test Gate 1 review when validation returns None mAP."""
        mock_base_model.metrics_mAP = 0.85
        mock_new_model.metrics_mAP = 0.82
        mock_val_metrics = {"mAP50": None}  # No mAP returned
        with patch("backend.web.services.gating_validator.ModelVersionRepository") as MockRepo, \
                patch("backend.web.services.gating_validator.get_session_context") as mock_session_ctx, \
                patch("shared.training.YOLOTrainer") as MockTrainer, \
                patch("backend.web.services.gating_validator._update_model_gating_status") as mock_update:
            mock_repo = MockRepo.return_value
            mock_repo.get.side_effect = lambda id: mock_base_model if str(id) == str(mock_base_model_version_id) else mock_new_model
            mock_session = MagicMock()
            mock_session_ctx.return_value.__enter__.return_value = mock_session
            mock_trainer = MockTrainer.return_value
            mock_trainer.validate.return_value = mock_val_metrics
            result = run_gating_validation(
                model_version_id=mock_model_version_id,
                new_model_path="/path/to/model.pt",
                base_model_version_id=mock_base_model_version_id,
                data_yaml="/path/to/data.yaml",
                task_id=mock_task_id,
            )
            assert result.gate1_status == "review"  # None mAP -> review
            assert result.gate1_new_mAP is None
            assert result.gate2_status == "pass"
            assert result.overall_status == "review"

    def test_gate2_exception_marks_gate2_review(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Test Gate 2 review when accessing new model metrics raises exception."""
        mock_base_model.metrics_mAP = 0.85
        mock_new_model.metrics_mAP = 0.82
        mock_val_metrics = {"mAP50": 0.84}
        with patch("backend.web.services.gating_validator.ModelVersionRepository") as MockRepo, \
                patch("backend.web.services.gating_validator.get_session_context") as mock_session_ctx, \
                patch("shared.training.YOLOTrainer") as MockTrainer, \
                patch("backend.web.services.gating_validator._update_model_gating_status") as mock_update:
            mock_repo = MockRepo.return_value
            # Mock to raise exception for new model on second call
            def get_side_effect(id):
                if str(id) == str(mock_base_model_version_id):
                    return mock_base_model
                elif str(id) == str(mock_model_version_id):
                    raise RuntimeError("Cannot fetch new model")
                return None
            mock_repo.get.side_effect = get_side_effect
            mock_session = MagicMock()
            mock_session_ctx.return_value.__enter__.return_value = mock_session
            mock_trainer = MockTrainer.return_value
            mock_trainer.validate.return_value = mock_val_metrics
            result = run_gating_validation(
                model_version_id=mock_model_version_id,
                new_model_path="/path/to/model.pt",
                base_model_version_id=mock_base_model_version_id,
                data_yaml="/path/to/data.yaml",
                task_id=mock_task_id,
            )
            assert result.gate1_status == "review"  # 0.85 - 0.84 = 0.01
            assert result.gate2_status == "review"  # Exception -> review
            assert result.overall_status == "review"

    def test_string_uuids_accepted(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Test that string UUIDs are accepted and converted properly."""
        mock_base_model.metrics_mAP = 0.85
        mock_new_model.metrics_mAP = 0.85
        mock_val_metrics = {"mAP50": 0.85}
        with patch("backend.web.services.gating_validator.ModelVersionRepository") as MockRepo, \
                patch("backend.web.services.gating_validator.get_session_context") as mock_session_ctx, \
                patch("shared.training.YOLOTrainer") as MockTrainer, \
                patch("backend.web.services.gating_validator._update_model_gating_status") as mock_update:
            mock_repo = MockRepo.return_value
            mock_repo.get.side_effect = lambda id: mock_base_model if str(id) == str(mock_base_model_version_id) else mock_new_model
            mock_session = MagicMock()
            mock_session_ctx.return_value.__enter__.return_value = mock_session
            mock_trainer = MockTrainer.return_value
            mock_trainer.validate.return_value = mock_val_metrics
            # Pass string UUIDs
            result = run_gating_validation(
                model_version_id=str(mock_model_version_id),
                new_model_path="/path/to/model.pt",
                base_model_version_id=str(mock_base_model_version_id),
                data_yaml="/path/to/data.yaml",
                task_id=str(mock_task_id),
            )
            # Result IDs should be the UUID objects, proving str -> UUID conversion
            assert result.model_version_id == mock_model_version_id
            assert result.task_id == mock_task_id
            assert result.overall_status == "pass"