# Source: invoice-master-poc-v2/tests/web/test_finetune_pool.py
# Last commit: ad5ed46b4c "WIP" by Yaojia Wang, 2026-02-11 23:40:38 +01:00
# (468 lines, 16 KiB, Python — metadata captured from the repository web view)

"""
Tests for Fine-Tune Pool feature.
Tests cover:
1. FineTunePoolEntry database model
2. PoolAddRequest/PoolStatsResponse schemas
3. Chain prevention logic
4. Pool threshold enforcement
5. Model lineage fields on ModelVersion
6. Gating enforcement on model activation
"""
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
from uuid import uuid4, UUID

import pytest
# =============================================================================
# Test Database Models
# =============================================================================
class TestFineTunePoolEntryModel:
    """Unit tests for the FineTunePoolEntry database model."""

    def test_creates_with_defaults(self):
        """Only document_id is supplied; every optional field must start empty."""
        from backend.data.admin_models import FineTunePoolEntry

        pool_entry = FineTunePoolEntry(document_id=uuid4())

        # entry_id is expected to be populated right after construction,
        # without any database round-trip.
        assert pool_entry.entry_id is not None
        assert pool_entry.is_verified is False
        for optional_field in ("verified_at", "verified_by", "added_by", "reason"):
            assert getattr(pool_entry, optional_field) is None

    def test_creates_with_all_fields(self):
        """Every explicitly supplied field is stored unchanged on the entry."""
        from backend.data.admin_models import FineTunePoolEntry

        document_uuid = uuid4()
        pool_entry = FineTunePoolEntry(
            document_id=document_uuid,
            added_by="admin",
            reason="user_reported_failure",
            is_verified=True,
            verified_by="reviewer",
        )

        assert pool_entry.document_id == document_uuid
        assert (pool_entry.added_by, pool_entry.reason) == ("admin", "user_reported_failure")
        assert pool_entry.is_verified is True
        assert pool_entry.verified_by == "reviewer"
class TestGatingResultModel:
    """Tests for the GatingResult model."""

    def test_creates_with_defaults(self):
        """GatingResult needs only the model id and statuses; metrics default to None."""
        from backend.data.admin_models import GatingResult

        model_version_id = uuid4()
        result = GatingResult(
            model_version_id=model_version_id,
            gate1_status="pass",
            gate2_status="pass",
            overall_status="pass",
        )

        assert result.result_id is not None
        assert result.model_version_id == model_version_id
        assert result.gate1_status == "pass"
        assert result.gate2_status == "pass"
        assert result.overall_status == "pass"
        # Metric columns that were not supplied must stay unset.
        assert result.gate1_mAP_drop is None
        assert result.gate2_detection_rate is None

    def test_creates_with_full_metrics(self):
        """GatingResult should store every supplied status and metric unchanged."""
        from backend.data.admin_models import GatingResult

        model_version_id = uuid4()
        result = GatingResult(
            model_version_id=model_version_id,
            gate1_status="review",
            gate1_original_mAP=0.95,
            gate1_new_mAP=0.93,
            gate1_mAP_drop=0.02,
            gate2_status="pass",
            gate2_detection_rate=0.85,
            gate2_total_samples=100,
            gate2_detected_samples=85,
            overall_status="review",
        )

        assert result.model_version_id == model_version_id
        # Gate 1: mAP regression check. Exact float equality is fine here
        # because the values are stored as given, not computed.
        assert result.gate1_status == "review"
        assert result.gate1_original_mAP == 0.95
        assert result.gate1_new_mAP == 0.93
        assert result.gate1_mAP_drop == 0.02
        # Gate 2: detection-rate check, including the raw sample counts.
        assert result.gate2_status == "pass"
        assert result.gate2_detection_rate == 0.85
        assert result.gate2_total_samples == 100
        assert result.gate2_detected_samples == 85
        assert result.overall_status == "review"
class TestModelVersionLineage:
    """Checks the lineage columns (model_type, base ids, gating_status) on ModelVersion."""

    def test_default_model_type_is_base(self):
        """Without lineage arguments a ModelVersion is 'base' with gating_status 'pending'."""
        from backend.data.admin_models import ModelVersion

        version = ModelVersion(version="v1.0", name="test-model", model_path="/path/to/model.pt")

        assert version.model_type == "base"
        assert version.base_model_version_id is None
        assert version.base_training_dataset_id is None
        assert version.gating_status == "pending"

    def test_finetune_model_type(self):
        """Lineage fields passed for a 'finetune' model are stored exactly as given."""
        from backend.data.admin_models import ModelVersion

        lineage = {
            "model_type": "finetune",
            "base_model_version_id": uuid4(),
            "base_training_dataset_id": uuid4(),
            "gating_status": "pending",
        }
        version = ModelVersion(
            version="v2.0",
            name="finetuned-model",
            model_path="/path/to/ft_model.pt",
            **lineage,
        )

        for field_name, expected in lineage.items():
            assert getattr(version, field_name) == expected
# =============================================================================
# Test Schemas
# =============================================================================
class TestPoolSchemas:
    """Tests for pool Pydantic schemas."""

    @staticmethod
    def _now():
        """Return the current time as a timezone-aware UTC datetime.

        Used instead of ``datetime.utcnow()``, which returns a naive value and
        is deprecated since Python 3.12.
        """
        return datetime.now(timezone.utc)

    def test_pool_add_request_defaults(self):
        """PoolAddRequest should fall back to the 'user_reported_failure' reason."""
        from backend.web.schemas.admin.pool import PoolAddRequest

        req = PoolAddRequest(document_id="550e8400-e29b-41d4-a716-446655440001")

        assert req.document_id == "550e8400-e29b-41d4-a716-446655440001"
        assert req.reason == "user_reported_failure"

    def test_pool_add_request_custom_reason(self):
        """PoolAddRequest should accept a caller-supplied reason."""
        from backend.web.schemas.admin.pool import PoolAddRequest

        req = PoolAddRequest(
            document_id="550e8400-e29b-41d4-a716-446655440001",
            reason="manual_addition",
        )

        assert req.reason == "manual_addition"

    def test_pool_stats_response(self):
        """PoolStatsResponse should keep the given is_ready flag and default min_required to 50.

        is_ready is supplied explicitly, so this checks that the flag is stored,
        not that readiness is computed from the counts.
        """
        from backend.web.schemas.admin.pool import PoolStatsResponse

        # Not ready
        stats = PoolStatsResponse(
            total_entries=30,
            verified_entries=20,
            unverified_entries=10,
            is_ready=False,
        )
        assert stats.is_ready is False
        assert stats.min_required == 50

        # Ready
        stats_ready = PoolStatsResponse(
            total_entries=80,
            verified_entries=60,
            unverified_entries=20,
            is_ready=True,
        )
        assert stats_ready.is_ready is True
        assert stats_ready.min_required == 50

    def test_pool_entry_item(self):
        """PoolEntryItem should accept verification metadata and timestamps."""
        from backend.web.schemas.admin.pool import PoolEntryItem

        entry = PoolEntryItem(
            entry_id="entry-uuid",
            document_id="doc-uuid",
            is_verified=True,
            verified_at=self._now(),
            verified_by="admin",
            created_at=self._now(),
        )

        assert entry.is_verified is True
        assert entry.verified_by == "admin"

    def test_gating_result_item(self):
        """GatingResultItem should accept all gate fields."""
        from backend.web.schemas.admin.pool import GatingResultItem

        item = GatingResultItem(
            result_id="result-uuid",
            model_version_id="model-uuid",
            gate1_status="pass",
            gate1_original_mAP=0.95,
            gate1_new_mAP=0.94,
            gate1_mAP_drop=0.01,
            gate2_status="pass",
            gate2_detection_rate=0.90,
            gate2_total_samples=50,
            gate2_detected_samples=45,
            overall_status="pass",
            created_at=self._now(),
        )

        assert item.gate1_status == "pass"
        assert item.overall_status == "pass"
# =============================================================================
# Test Chain Prevention
# =============================================================================
class TestChainPrevention:
    """Tests for fine-tune chain prevention logic.

    NOTE(review): the rule is mirrored in ``_should_reject`` rather than imported
    from datasets.py, so these tests pin the intended rule but do not exercise the
    production check. TODO: call the real helper once it is importable.
    """

    @staticmethod
    def _should_reject(base_model_version_id, base_model_type):
        """Return True when training on top of the given base model must be refused.

        Fresh training (no base model id) is never checked. A base model that is
        already a fine-tune cannot be fine-tuned again, which prevents fine-tune chains.
        """
        return base_model_version_id is not None and base_model_type == "finetune"

    def test_rejects_finetune_from_finetune_model(self):
        """Should reject training when the base model is already a fine-tune."""
        assert self._should_reject(str(uuid4()), "finetune") is True

    def test_allows_finetune_from_base_model(self):
        """Should allow training when the base model is a base model."""
        assert self._should_reject(str(uuid4()), "base") is False

    def test_allows_fresh_training(self):
        """Should allow fresh training (no base model, so no chain check)."""
        assert self._should_reject(None, None) is False
# =============================================================================
# Test Pool Threshold
# =============================================================================
class TestPoolThreshold:
    """Checks the minimum verified-pool size required before fine-tuning."""

    @staticmethod
    def _threshold():
        """Import and return MIN_POOL_SIZE lazily, so import errors surface per test."""
        from backend.web.services.data_mixer import MIN_POOL_SIZE

        return MIN_POOL_SIZE

    def test_min_pool_size_constant(self):
        """The threshold constant is pinned at 50."""
        assert self._threshold() == 50

    def test_pool_below_threshold_blocks_finetune(self):
        """30 verified entries is short of the threshold, so fine-tuning is blocked."""
        verified_count = 30
        assert self._threshold() > verified_count

    def test_pool_at_threshold_allows_finetune(self):
        """Exactly 50 verified entries meets the threshold, so fine-tuning is allowed."""
        verified_count = 50
        assert self._threshold() <= verified_count
# =============================================================================
# Test Gating Enforcement on Activation
# =============================================================================
class TestGatingEnforcement:
    """Tests for gating enforcement when activating models.

    NOTE(review): ``_blocks_activation`` mirrors the check described for the
    models.py activation endpoint instead of calling it, so these tests pin the
    intended rule rather than exercise the endpoint itself.
    """

    @staticmethod
    def _blocks_activation(model_type, gating_status):
        """Return True when activating a model must be refused.

        Only fine-tuned models are gated. They are blocked while gating_status is
        'reject' or 'pending'. 'pass' and 'review' are allowed; 'review' is
        expected to carry a warning, which is not modelled here.
        """
        return model_type == "finetune" and gating_status in ("reject", "pending")

    def test_base_model_skips_gating(self):
        """Base models are never blocked by gating, whatever their gating_status."""
        from backend.data.admin_models import ModelVersion

        mv = ModelVersion(
            version="v1.0",
            name="base",
            model_path="/model.pt",
            model_type="base",
        )

        assert mv.model_type == "base"
        # A directly constructed ModelVersion keeps its default gating_status,
        # which is not necessarily 'skipped'. The guarantee under test is that a
        # base model is not blocked regardless of that value.
        assert self._blocks_activation(mv.model_type, mv.gating_status) is False

    def test_base_model_never_blocked(self):
        """Even a 'reject' or 'pending' status does not block a base model."""
        assert self._blocks_activation("base", "reject") is False
        assert self._blocks_activation("base", "pending") is False

    def test_finetune_model_rejected_blocks_activation(self):
        """Fine-tuned models with 'reject' gating should block activation."""
        assert self._blocks_activation("finetune", "reject") is True

    def test_finetune_model_pending_blocks_activation(self):
        """Fine-tuned models with 'pending' gating should block activation."""
        assert self._blocks_activation("finetune", "pending") is True

    def test_finetune_model_pass_allows_activation(self):
        """Fine-tuned models with 'pass' gating should allow activation."""
        assert self._blocks_activation("finetune", "pass") is False

    def test_finetune_model_review_allows_with_warning(self):
        """Fine-tuned models with 'review' gating should be allowed (with a warning).

        Only the allow/block decision is checked; the warning text is not.
        """
        assert self._blocks_activation("finetune", "review") is False
# =============================================================================
# Test Pool API Route Registration
# =============================================================================
class TestPoolRouteRegistration:
    """Checks that the training router exposes the pool endpoints."""

    def test_pool_routes_registered(self):
        """The router from create_training_router has pool and pool-stats paths."""
        from backend.web.api.v1.admin.training import create_training_router

        registered = [r.path for r in create_training_router().routes]

        # NOTE(review): any path containing "/pool/stats" also contains "/pool",
        # so the first check is implied by the second. Consider asserting the
        # exact collection path once its full prefix is confirmed.
        for fragment in ("/pool", "/pool/stats"):
            assert any(fragment in path for path in registered)
# =============================================================================
# Test Scheduler Fine-Tune Parameter Override
# =============================================================================
class TestSchedulerFineTuneParams:
    """Tests for scheduler fine-tune parameter overrides.

    NOTE(review): ``_resolve`` mirrors the scheduler's override rules inline
    rather than calling the scheduler, so these tests pin the intended rules
    but do not exercise production code.
    """

    @staticmethod
    def _resolve(config):
        """Derive training settings and model lineage from a scheduler *config* dict.

        A truthy ``base_model_path`` switches to fine-tune mode. In that mode the
        defaults are 10 epochs and lr 0.001 (instead of 100 and 0.01), the model is
        recorded as 'finetune' with its base_model_version_id, and gating starts as
        'pending' (instead of 'skipped'). Explicit ``epochs`` / ``learning_rate``
        keys in *config* always take precedence over the defaults.
        """
        is_finetune = bool(config.get("base_model_path"))
        return {
            "is_finetune": is_finetune,
            "epochs": config.get("epochs", 10 if is_finetune else 100),
            "learning_rate": config.get("learning_rate", 0.001 if is_finetune else 0.01),
            "model_type": "finetune" if is_finetune else "base",
            "base_model_version_id": (
                config.get("base_model_version_id") if is_finetune else None
            ),
            "gating_status": "pending" if is_finetune else "skipped",
        }

    def test_finetune_detected_from_base_model_path(self):
        """Scheduler should detect fine-tune mode from base_model_path."""
        resolved = self._resolve({"base_model_path": "/path/to/base_model.pt"})
        assert resolved["is_finetune"] is True

    def test_fresh_training_not_finetune(self):
        """Scheduler should not enable fine-tune for fresh training."""
        resolved = self._resolve({"model_name": "yolo26s.pt"})
        assert resolved["is_finetune"] is False

    def test_finetune_defaults_correct_epochs(self):
        """Fine-tune should default to 10 epochs and learning rate 0.001."""
        resolved = self._resolve({"base_model_path": "/path/to/model.pt"})
        assert resolved["epochs"] == 10
        assert resolved["learning_rate"] == 0.001

    def test_fresh_training_defaults(self):
        """Fresh training should default to 100 epochs and learning rate 0.01."""
        resolved = self._resolve({"model_name": "yolo26s.pt"})
        assert resolved["epochs"] == 100
        assert resolved["learning_rate"] == 0.01

    def test_explicit_values_override_finetune_defaults(self):
        """Explicit epochs/learning_rate in the config win over fine-tune defaults."""
        resolved = self._resolve(
            {"base_model_path": "/path/to/model.pt", "epochs": 3, "learning_rate": 0.0005}
        )
        assert resolved["epochs"] == 3
        assert resolved["learning_rate"] == 0.0005

    def test_model_lineage_set_for_finetune(self):
        """Scheduler should set model_type and base_model_version_id for fine-tune."""
        base_id = str(uuid4())
        resolved = self._resolve(
            {"base_model_path": "/path/to/model.pt", "base_model_version_id": base_id}
        )
        assert resolved["model_type"] == "finetune"
        assert resolved["base_model_version_id"] == base_id
        assert resolved["gating_status"] == "pending"

    def test_model_lineage_skipped_for_base(self):
        """Scheduler should set model_type='base' and skip gating for fresh training."""
        resolved = self._resolve({"model_name": "yolo26s.pt"})
        assert resolved["model_type"] == "base"
        assert resolved["base_model_version_id"] is None
        assert resolved["gating_status"] == "skipped"
# =============================================================================
# Test TrainingConfig freeze/cos_lr
# =============================================================================
class TestTrainingConfigFineTuneFields:
    """Covers the freeze / cos_lr options of the shared TrainingConfig."""

    @staticmethod
    def _build(**overrides):
        """Create a TrainingConfig with placeholder required paths, updated by *overrides*."""
        from shared.training import TrainingConfig

        params = {"model_path": "test.pt", "data_yaml": "data.yaml"}
        params.update(overrides)
        return TrainingConfig(**params)

    def test_default_freeze_is_zero(self):
        """Without an explicit value, freeze is 0 (no layers frozen)."""
        assert self._build().freeze == 0

    def test_default_cos_lr_is_false(self):
        """Without an explicit value, cos_lr is False."""
        assert self._build().cos_lr is False

    def test_finetune_config(self):
        """Fine-tune style parameters are accepted and stored as given."""
        cfg = self._build(
            model_path="base_model.pt",
            epochs=10,
            learning_rate=0.001,
            freeze=10,
            cos_lr=True,
        )

        assert cfg.freeze == 10
        assert cfg.cos_lr is True
        assert cfg.epochs == 10
        assert cfg.learning_rate == 0.001