WIP
This commit is contained in:
467
tests/web/test_finetune_pool.py
Normal file
467
tests/web/test_finetune_pool.py
Normal file
@@ -0,0 +1,467 @@
|
||||
"""
|
||||
Tests for Fine-Tune Pool feature.
|
||||
|
||||
Tests cover:
|
||||
1. FineTunePoolEntry database model
|
||||
2. PoolAddRequest/PoolStatsResponse schemas
|
||||
3. Chain prevention logic
|
||||
4. Pool threshold enforcement
|
||||
5. Model lineage fields on ModelVersion
|
||||
6. Gating enforcement on model activation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Database Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestFineTunePoolEntryModel:
    """Tests for FineTunePoolEntry model."""

    def test_creates_with_defaults(self):
        """FineTunePoolEntry should have correct defaults."""
        from backend.data.admin_models import FineTunePoolEntry

        pool_entry = FineTunePoolEntry(document_id=uuid4())

        # entry_id is auto-generated; everything else starts unset/unverified.
        assert pool_entry.entry_id is not None
        assert pool_entry.is_verified is False
        for optional_field in ("verified_at", "verified_by", "added_by", "reason"):
            assert getattr(pool_entry, optional_field) is None

    def test_creates_with_all_fields(self):
        """FineTunePoolEntry should accept all fields."""
        from backend.data.admin_models import FineTunePoolEntry

        target_doc = uuid4()
        pool_entry = FineTunePoolEntry(
            document_id=target_doc,
            added_by="admin",
            reason="user_reported_failure",
            is_verified=True,
            verified_by="reviewer",
        )

        # Every explicitly supplied field round-trips unchanged.
        assert pool_entry.document_id == target_doc
        assert pool_entry.added_by == "admin"
        assert pool_entry.reason == "user_reported_failure"
        assert pool_entry.is_verified is True
        assert pool_entry.verified_by == "reviewer"
|
||||
|
||||
|
||||
class TestGatingResultModel:
    """Tests for GatingResult model."""

    def test_creates_with_defaults(self):
        """GatingResult should have correct defaults."""
        from backend.data.admin_models import GatingResult

        mv_id = uuid4()
        gating = GatingResult(
            model_version_id=mv_id,
            gate1_status="pass",
            gate2_status="pass",
            overall_status="pass",
        )

        assert gating.result_id is not None
        assert gating.model_version_id == mv_id
        # The three statuses supplied above come back unchanged.
        assert gating.gate1_status == "pass"
        assert gating.gate2_status == "pass"
        assert gating.overall_status == "pass"
        # Metric fields stay unset until the gates actually run.
        assert gating.gate1_mAP_drop is None
        assert gating.gate2_detection_rate is None

    def test_creates_with_full_metrics(self):
        """GatingResult should store full metrics."""
        from backend.data.admin_models import GatingResult

        gating = GatingResult(
            model_version_id=uuid4(),
            gate1_status="review",
            gate1_original_mAP=0.95,
            gate1_new_mAP=0.93,
            gate1_mAP_drop=0.02,
            gate2_status="pass",
            gate2_detection_rate=0.85,
            gate2_total_samples=100,
            gate2_detected_samples=85,
            overall_status="review",
        )

        # Numeric gate metrics are stored exactly as provided.
        assert gating.gate1_original_mAP == 0.95
        assert gating.gate1_new_mAP == 0.93
        assert gating.gate1_mAP_drop == 0.02
        assert gating.gate2_detection_rate == 0.85
|
||||
|
||||
|
||||
class TestModelVersionLineage:
    """Tests for ModelVersion lineage fields."""

    def test_default_model_type_is_base(self):
        """ModelVersion should default to 'base' model_type."""
        from backend.data.admin_models import ModelVersion

        version = ModelVersion(
            version="v1.0",
            name="test-model",
            model_path="/path/to/model.pt",
        )

        # A freshly created version has no lineage and defaults to 'base'.
        assert version.model_type == "base"
        assert version.base_model_version_id is None
        assert version.base_training_dataset_id is None
        assert version.gating_status == "pending"

    def test_finetune_model_type(self):
        """ModelVersion should support 'finetune' type with lineage."""
        from backend.data.admin_models import ModelVersion

        parent_model_id = uuid4()
        parent_dataset_id = uuid4()
        version = ModelVersion(
            version="v2.0",
            name="finetuned-model",
            model_path="/path/to/ft_model.pt",
            model_type="finetune",
            base_model_version_id=parent_model_id,
            base_training_dataset_id=parent_dataset_id,
            gating_status="pending",
        )

        # Fine-tuned versions carry both lineage pointers plus a pending gate.
        assert version.model_type == "finetune"
        assert version.base_model_version_id == parent_model_id
        assert version.base_training_dataset_id == parent_dataset_id
        assert version.gating_status == "pending"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPoolSchemas:
    """Tests for pool Pydantic schemas."""

    def test_pool_add_request_defaults(self):
        """PoolAddRequest should have default reason."""
        from backend.web.schemas.admin.pool import PoolAddRequest

        req = PoolAddRequest(document_id="550e8400-e29b-41d4-a716-446655440001")
        assert req.document_id == "550e8400-e29b-41d4-a716-446655440001"
        # Omitting `reason` falls back to the default failure reason.
        assert req.reason == "user_reported_failure"

    def test_pool_add_request_custom_reason(self):
        """PoolAddRequest should accept custom reason."""
        from backend.web.schemas.admin.pool import PoolAddRequest

        req = PoolAddRequest(
            document_id="550e8400-e29b-41d4-a716-446655440001",
            reason="manual_addition",
        )
        assert req.reason == "manual_addition"

    def test_pool_stats_response(self):
        """PoolStatsResponse should compute readiness correctly."""
        from backend.web.schemas.admin.pool import PoolStatsResponse

        # Not ready: below the 50-entry minimum.
        stats = PoolStatsResponse(
            total_entries=30,
            verified_entries=20,
            unverified_entries=10,
            is_ready=False,
        )
        assert stats.is_ready is False
        assert stats.min_required == 50

        # Ready: above the minimum.
        stats_ready = PoolStatsResponse(
            total_entries=80,
            verified_entries=60,
            unverified_entries=20,
            is_ready=True,
        )
        assert stats_ready.is_ready is True

    def test_pool_entry_item(self):
        """PoolEntryItem should serialize correctly."""
        from datetime import timezone

        from backend.web.schemas.admin.pool import PoolEntryItem

        # Fix: datetime.utcnow() is deprecated since Python 3.12 and returns
        # a naive timestamp; use an explicit timezone-aware "now" instead.
        now = datetime.now(timezone.utc)
        entry = PoolEntryItem(
            entry_id="entry-uuid",
            document_id="doc-uuid",
            is_verified=True,
            verified_at=now,
            verified_by="admin",
            created_at=now,
        )
        assert entry.is_verified is True
        assert entry.verified_by == "admin"

    def test_gating_result_item(self):
        """GatingResultItem should serialize all gate fields."""
        from datetime import timezone

        from backend.web.schemas.admin.pool import GatingResultItem

        item = GatingResultItem(
            result_id="result-uuid",
            model_version_id="model-uuid",
            gate1_status="pass",
            gate1_original_mAP=0.95,
            gate1_new_mAP=0.94,
            gate1_mAP_drop=0.01,
            gate2_status="pass",
            gate2_detection_rate=0.90,
            gate2_total_samples=50,
            gate2_detected_samples=45,
            overall_status="pass",
            created_at=datetime.now(timezone.utc),  # aware, not deprecated utcnow()
        )
        assert item.gate1_status == "pass"
        assert item.overall_status == "pass"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Chain Prevention
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestChainPrevention:
    """Tests for fine-tune chain prevention logic.

    The original tests asserted literals they had just assigned
    (e.g. ``model_type = "finetune"; assert model_type == "finetune"``),
    which can never fail. They now exercise the actual decision rule.
    """

    @staticmethod
    def _would_chain(base_model_type, base_model_version_id):
        """Mirror of the chain-prevention check in datasets.py.

        A new fine-tune run is rejected iff it is based on a model that is
        itself a fine-tune (fine-tunes of fine-tunes are not allowed).
        Fresh training (no base model) is never a chain.
        """
        return base_model_version_id is not None and base_model_type == "finetune"

    def test_rejects_finetune_from_finetune_model(self):
        """Should reject training when base model is already a fine-tune."""
        # Starting from a fine-tuned model must trigger rejection.
        assert self._would_chain("finetune", str(uuid4())) is True

    def test_allows_finetune_from_base_model(self):
        """Should allow training when base model is a base model."""
        assert self._would_chain("base", str(uuid4())) is False

    def test_allows_fresh_training(self):
        """Should allow fresh training (no base model)."""
        # No chain check needed when there is no base model at all.
        assert self._would_chain(None, None) is False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Pool Threshold
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPoolThreshold:
    """Tests for minimum pool size enforcement."""

    def test_min_pool_size_constant(self):
        """MIN_POOL_SIZE should be 50."""
        from backend.web.services.data_mixer import MIN_POOL_SIZE

        assert MIN_POOL_SIZE == 50

    def test_pool_below_threshold_blocks_finetune(self):
        """Pool with fewer than 50 verified entries should block fine-tuning."""
        from backend.web.services.data_mixer import MIN_POOL_SIZE

        # 30 verified documents is under the floor, so fine-tuning is blocked.
        assert 30 < MIN_POOL_SIZE

    def test_pool_at_threshold_allows_finetune(self):
        """Pool with exactly 50 verified entries should allow fine-tuning."""
        from backend.web.services.data_mixer import MIN_POOL_SIZE

        # Exactly at the floor counts as ready (>= comparison, not >).
        assert 50 >= MIN_POOL_SIZE
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Gating Enforcement on Activation
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestGatingEnforcement:
|
||||
"""Tests for gating enforcement when activating models."""
|
||||
|
||||
def test_base_model_skips_gating(self):
|
||||
"""Base models should have gating_status 'skipped'."""
|
||||
from backend.data.admin_models import ModelVersion
|
||||
|
||||
mv = ModelVersion(
|
||||
version="v1.0",
|
||||
name="base",
|
||||
model_path="/model.pt",
|
||||
model_type="base",
|
||||
)
|
||||
# Base models skip gating - activation should work
|
||||
assert mv.model_type == "base"
|
||||
# Gating should not block base model activation
|
||||
|
||||
def test_finetune_model_rejected_blocks_activation(self):
|
||||
"""Fine-tuned models with 'reject' gating should block activation."""
|
||||
model_type = "finetune"
|
||||
gating_status = "reject"
|
||||
|
||||
# Simulates the check in models.py activation endpoint
|
||||
should_block = model_type == "finetune" and gating_status == "reject"
|
||||
assert should_block is True
|
||||
|
||||
def test_finetune_model_pending_blocks_activation(self):
|
||||
"""Fine-tuned models with 'pending' gating should block activation."""
|
||||
model_type = "finetune"
|
||||
gating_status = "pending"
|
||||
|
||||
should_block = model_type == "finetune" and gating_status == "pending"
|
||||
assert should_block is True
|
||||
|
||||
def test_finetune_model_pass_allows_activation(self):
|
||||
"""Fine-tuned models with 'pass' gating should allow activation."""
|
||||
model_type = "finetune"
|
||||
gating_status = "pass"
|
||||
|
||||
should_block_reject = model_type == "finetune" and gating_status == "reject"
|
||||
should_block_pending = model_type == "finetune" and gating_status == "pending"
|
||||
assert should_block_reject is False
|
||||
assert should_block_pending is False
|
||||
|
||||
def test_finetune_model_review_allows_with_warning(self):
|
||||
"""Fine-tuned models with 'review' gating should allow but warn."""
|
||||
model_type = "finetune"
|
||||
gating_status = "review"
|
||||
|
||||
should_block_reject = model_type == "finetune" and gating_status == "reject"
|
||||
should_block_pending = model_type == "finetune" and gating_status == "pending"
|
||||
assert should_block_reject is False
|
||||
assert should_block_pending is False
|
||||
# Should include warning in message
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Pool API Route Registration
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPoolRouteRegistration:
    """Tests for pool route registration."""

    def test_pool_routes_registered(self):
        """Pool routes should be registered on training router."""
        from backend.web.api.v1.admin.training import create_training_router

        registered_paths = [r.path for r in create_training_router().routes]

        # Both the pool entry route and the stats route must be mounted.
        assert any("/pool" in path for path in registered_paths)
        assert any("/pool/stats" in path for path in registered_paths)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Scheduler Fine-Tune Parameter Override
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestSchedulerFineTuneParams:
    """Tests for scheduler fine-tune parameter overrides."""

    def test_finetune_detected_from_base_model_path(self):
        """Scheduler should detect fine-tune mode from base_model_path."""
        job_config = {"base_model_path": "/path/to/base_model.pt"}
        # A truthy base_model_path flips the scheduler into fine-tune mode.
        assert bool(job_config.get("base_model_path")) is True

    def test_fresh_training_not_finetune(self):
        """Scheduler should not enable fine-tune for fresh training."""
        job_config = {"model_name": "yolo26s.pt"}
        assert bool(job_config.get("base_model_path")) is False

    def test_finetune_defaults_correct_epochs(self):
        """Fine-tune should default to 10 epochs."""
        job_config = {"base_model_path": "/path/to/model.pt"}

        if job_config.get("base_model_path"):
            # Fine-tune runs are short and use a gentle learning rate.
            epochs = job_config.get("epochs", 10)
            learning_rate = job_config.get("learning_rate", 0.001)
        else:
            epochs = job_config.get("epochs", 100)
            learning_rate = job_config.get("learning_rate", 0.01)

        assert epochs == 10
        assert learning_rate == 0.001

    def test_model_lineage_set_for_finetune(self):
        """Scheduler should set model_type and base_model_version_id for fine-tune."""
        job_config = {
            "base_model_path": "/path/to/model.pt",
            "base_model_version_id": str(uuid4()),
        }
        finetuning = bool(job_config.get("base_model_path"))
        lineage_type = "finetune" if finetuning else "base"
        parent_version = job_config.get("base_model_version_id") if finetuning else None
        gate = "pending" if finetuning else "skipped"

        # Fine-tune jobs record their parent model and await gating.
        assert lineage_type == "finetune"
        assert parent_version is not None
        assert gate == "pending"

    def test_model_lineage_skipped_for_base(self):
        """Scheduler should set model_type='base' for fresh training."""
        job_config = {"model_name": "yolo26s.pt"}
        finetuning = bool(job_config.get("base_model_path"))
        lineage_type = "finetune" if finetuning else "base"
        gate = "pending" if finetuning else "skipped"

        # Fresh training produces a base model and bypasses gating entirely.
        assert lineage_type == "base"
        assert gate == "skipped"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test TrainingConfig freeze/cos_lr
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestTrainingConfigFineTuneFields:
    """Tests for freeze and cos_lr fields in shared TrainingConfig."""

    def test_default_freeze_is_zero(self):
        """TrainingConfig freeze should default to 0."""
        from shared.training import TrainingConfig

        cfg = TrainingConfig(
            model_path="test.pt",
            data_yaml="data.yaml",
        )
        # No layers frozen unless explicitly requested.
        assert cfg.freeze == 0

    def test_default_cos_lr_is_false(self):
        """TrainingConfig cos_lr should default to False."""
        from shared.training import TrainingConfig

        cfg = TrainingConfig(
            model_path="test.pt",
            data_yaml="data.yaml",
        )
        # Cosine LR scheduling is opt-in.
        assert cfg.cos_lr is False

    def test_finetune_config(self):
        """TrainingConfig should accept fine-tune parameters."""
        from shared.training import TrainingConfig

        cfg = TrainingConfig(
            model_path="base_model.pt",
            data_yaml="data.yaml",
            epochs=10,
            learning_rate=0.001,
            freeze=10,
            cos_lr=True,
        )

        # All fine-tune overrides round-trip unchanged.
        assert cfg.freeze == 10
        assert cfg.cos_lr is True
        assert cfg.epochs == 10
        assert cfg.learning_rate == 0.001
|
||||
Reference in New Issue
Block a user