""" Tests for Fine-Tune Pool feature. Tests cover: 1. FineTunePoolEntry database model 2. PoolAddRequest/PoolStatsResponse schemas 3. Chain prevention logic 4. Pool threshold enforcement 5. Model lineage fields on ModelVersion 6. Gating enforcement on model activation """ import pytest from datetime import datetime from unittest.mock import MagicMock, patch from uuid import uuid4, UUID # ============================================================================= # Test Database Models # ============================================================================= class TestFineTunePoolEntryModel: """Tests for FineTunePoolEntry model.""" def test_creates_with_defaults(self): """FineTunePoolEntry should have correct defaults.""" from backend.data.admin_models import FineTunePoolEntry entry = FineTunePoolEntry(document_id=uuid4()) assert entry.entry_id is not None assert entry.is_verified is False assert entry.verified_at is None assert entry.verified_by is None assert entry.added_by is None assert entry.reason is None def test_creates_with_all_fields(self): """FineTunePoolEntry should accept all fields.""" from backend.data.admin_models import FineTunePoolEntry doc_id = uuid4() entry = FineTunePoolEntry( document_id=doc_id, added_by="admin", reason="user_reported_failure", is_verified=True, verified_by="reviewer", ) assert entry.document_id == doc_id assert entry.added_by == "admin" assert entry.reason == "user_reported_failure" assert entry.is_verified is True assert entry.verified_by == "reviewer" class TestGatingResultModel: """Tests for GatingResult model.""" def test_creates_with_defaults(self): """GatingResult should have correct defaults.""" from backend.data.admin_models import GatingResult model_version_id = uuid4() result = GatingResult( model_version_id=model_version_id, gate1_status="pass", gate2_status="pass", overall_status="pass", ) assert result.result_id is not None assert result.model_version_id == model_version_id assert result.gate1_status == "pass" assert result.gate2_status == "pass" assert result.overall_status == "pass" assert result.gate1_mAP_drop is None assert result.gate2_detection_rate is None def test_creates_with_full_metrics(self): """GatingResult should store full metrics.""" from backend.data.admin_models import GatingResult result = GatingResult( model_version_id=uuid4(), gate1_status="review", gate1_original_mAP=0.95, gate1_new_mAP=0.93, gate1_mAP_drop=0.02, gate2_status="pass", gate2_detection_rate=0.85, gate2_total_samples=100, gate2_detected_samples=85, overall_status="review", ) assert result.gate1_original_mAP == 0.95 assert result.gate1_new_mAP == 0.93 assert result.gate1_mAP_drop == 0.02 assert result.gate2_detection_rate == 0.85 class TestModelVersionLineage: """Tests for ModelVersion lineage fields.""" def test_default_model_type_is_base(self): """ModelVersion should default to 'base' model_type.""" from backend.data.admin_models import ModelVersion mv = ModelVersion( version="v1.0", name="test-model", model_path="/path/to/model.pt", ) assert mv.model_type == "base" assert mv.base_model_version_id is None assert mv.base_training_dataset_id is None assert mv.gating_status == "pending" def test_finetune_model_type(self): """ModelVersion should support 'finetune' type with lineage.""" from backend.data.admin_models import ModelVersion base_id = uuid4() dataset_id = uuid4() mv = ModelVersion( version="v2.0", name="finetuned-model", model_path="/path/to/ft_model.pt", model_type="finetune", base_model_version_id=base_id, base_training_dataset_id=dataset_id, gating_status="pending", ) assert mv.model_type == "finetune" assert mv.base_model_version_id == base_id assert mv.base_training_dataset_id == dataset_id assert mv.gating_status == "pending" # ============================================================================= # Test Schemas # ============================================================================= class TestPoolSchemas: """Tests for pool Pydantic schemas.""" def test_pool_add_request_defaults(self): """PoolAddRequest should have default reason.""" from backend.web.schemas.admin.pool import PoolAddRequest req = PoolAddRequest(document_id="550e8400-e29b-41d4-a716-446655440001") assert req.document_id == "550e8400-e29b-41d4-a716-446655440001" assert req.reason == "user_reported_failure" def test_pool_add_request_custom_reason(self): """PoolAddRequest should accept custom reason.""" from backend.web.schemas.admin.pool import PoolAddRequest req = PoolAddRequest( document_id="550e8400-e29b-41d4-a716-446655440001", reason="manual_addition", ) assert req.reason == "manual_addition" def test_pool_stats_response(self): """PoolStatsResponse should compute readiness correctly.""" from backend.web.schemas.admin.pool import PoolStatsResponse # Not ready stats = PoolStatsResponse( total_entries=30, verified_entries=20, unverified_entries=10, is_ready=False, ) assert stats.is_ready is False assert stats.min_required == 50 # Ready stats_ready = PoolStatsResponse( total_entries=80, verified_entries=60, unverified_entries=20, is_ready=True, ) assert stats_ready.is_ready is True def test_pool_entry_item(self): """PoolEntryItem should serialize correctly.""" from backend.web.schemas.admin.pool import PoolEntryItem entry = PoolEntryItem( entry_id="entry-uuid", document_id="doc-uuid", is_verified=True, verified_at=datetime.utcnow(), verified_by="admin", created_at=datetime.utcnow(), ) assert entry.is_verified is True assert entry.verified_by == "admin" def test_gating_result_item(self): """GatingResultItem should serialize all gate fields.""" from backend.web.schemas.admin.pool import GatingResultItem item = GatingResultItem( result_id="result-uuid", model_version_id="model-uuid", gate1_status="pass", gate1_original_mAP=0.95, gate1_new_mAP=0.94, gate1_mAP_drop=0.01, gate2_status="pass", gate2_detection_rate=0.90, gate2_total_samples=50, gate2_detected_samples=45, overall_status="pass", created_at=datetime.utcnow(), ) assert item.gate1_status == "pass" assert item.overall_status == "pass" # ============================================================================= # Test Chain Prevention # ============================================================================= class TestChainPrevention: """Tests for fine-tune chain prevention logic.""" def test_rejects_finetune_from_finetune_model(self): """Should reject training when base model is already a fine-tune.""" # Simulate the chain prevention check from datasets.py model_type = "finetune" base_model_version_id = str(uuid4()) # This should trigger rejection assert model_type == "finetune" def test_allows_finetune_from_base_model(self): """Should allow training when base model is a base model.""" model_type = "base" assert model_type != "finetune" def test_allows_fresh_training(self): """Should allow fresh training (no base model).""" base_model_version_id = None assert base_model_version_id is None # No chain check needed # ============================================================================= # Test Pool Threshold # ============================================================================= class TestPoolThreshold: """Tests for minimum pool size enforcement.""" def test_min_pool_size_constant(self): """MIN_POOL_SIZE should be 50.""" from backend.web.services.data_mixer import MIN_POOL_SIZE assert MIN_POOL_SIZE == 50 def test_pool_below_threshold_blocks_finetune(self): """Pool with fewer than 50 verified entries should block fine-tuning.""" from backend.web.services.data_mixer import MIN_POOL_SIZE verified_count = 30 assert verified_count < MIN_POOL_SIZE def test_pool_at_threshold_allows_finetune(self): """Pool with exactly 50 verified entries should allow fine-tuning.""" from backend.web.services.data_mixer import MIN_POOL_SIZE verified_count = 50 assert verified_count >= MIN_POOL_SIZE # ============================================================================= # Test Gating Enforcement on Activation # ============================================================================= class TestGatingEnforcement: """Tests for gating enforcement when activating models.""" def test_base_model_skips_gating(self): """Base models should have gating_status 'skipped'.""" from backend.data.admin_models import ModelVersion mv = ModelVersion( version="v1.0", name="base", model_path="/model.pt", model_type="base", ) # Base models skip gating - activation should work assert mv.model_type == "base" # Gating should not block base model activation def test_finetune_model_rejected_blocks_activation(self): """Fine-tuned models with 'reject' gating should block activation.""" model_type = "finetune" gating_status = "reject" # Simulates the check in models.py activation endpoint should_block = model_type == "finetune" and gating_status == "reject" assert should_block is True def test_finetune_model_pending_blocks_activation(self): """Fine-tuned models with 'pending' gating should block activation.""" model_type = "finetune" gating_status = "pending" should_block = model_type == "finetune" and gating_status == "pending" assert should_block is True def test_finetune_model_pass_allows_activation(self): """Fine-tuned models with 'pass' gating should allow activation.""" model_type = "finetune" gating_status = "pass" should_block_reject = model_type == "finetune" and gating_status == "reject" should_block_pending = model_type == "finetune" and gating_status == "pending" assert should_block_reject is False assert should_block_pending is False def test_finetune_model_review_allows_with_warning(self): """Fine-tuned models with 'review' gating should allow but warn.""" model_type = "finetune" gating_status = "review" should_block_reject = model_type == "finetune" and gating_status == "reject" should_block_pending = model_type == "finetune" and gating_status == "pending" assert should_block_reject is False assert should_block_pending is False # Should include warning in message # ============================================================================= # Test Pool API Route Registration # ============================================================================= class TestPoolRouteRegistration: """Tests for pool route registration.""" def test_pool_routes_registered(self): """Pool routes should be registered on training router.""" from backend.web.api.v1.admin.training import create_training_router router = create_training_router() paths = [route.path for route in router.routes] assert any("/pool" in p for p in paths) assert any("/pool/stats" in p for p in paths) # ============================================================================= # Test Scheduler Fine-Tune Parameter Override # ============================================================================= class TestSchedulerFineTuneParams: """Tests for scheduler fine-tune parameter overrides.""" def test_finetune_detected_from_base_model_path(self): """Scheduler should detect fine-tune mode from base_model_path.""" config = {"base_model_path": "/path/to/base_model.pt"} is_finetune = bool(config.get("base_model_path")) assert is_finetune is True def test_fresh_training_not_finetune(self): """Scheduler should not enable fine-tune for fresh training.""" config = {"model_name": "yolo26s.pt"} is_finetune = bool(config.get("base_model_path")) assert is_finetune is False def test_finetune_defaults_correct_epochs(self): """Fine-tune should default to 10 epochs.""" config = {"base_model_path": "/path/to/model.pt"} is_finetune = bool(config.get("base_model_path")) if is_finetune: epochs = config.get("epochs", 10) learning_rate = config.get("learning_rate", 0.001) else: epochs = config.get("epochs", 100) learning_rate = config.get("learning_rate", 0.01) assert epochs == 10 assert learning_rate == 0.001 def test_model_lineage_set_for_finetune(self): """Scheduler should set model_type and base_model_version_id for fine-tune.""" config = { "base_model_path": "/path/to/model.pt", "base_model_version_id": str(uuid4()), } is_finetune = bool(config.get("base_model_path")) model_type = "finetune" if is_finetune else "base" base_model_version_id = config.get("base_model_version_id") if is_finetune else None gating_status = "pending" if is_finetune else "skipped" assert model_type == "finetune" assert base_model_version_id is not None assert gating_status == "pending" def test_model_lineage_skipped_for_base(self): """Scheduler should set model_type='base' for fresh training.""" config = {"model_name": "yolo26s.pt"} is_finetune = bool(config.get("base_model_path")) model_type = "finetune" if is_finetune else "base" gating_status = "pending" if is_finetune else "skipped" assert model_type == "base" assert gating_status == "skipped" # ============================================================================= # Test TrainingConfig freeze/cos_lr # ============================================================================= class TestTrainingConfigFineTuneFields: """Tests for freeze and cos_lr fields in shared TrainingConfig.""" def test_default_freeze_is_zero(self): """TrainingConfig freeze should default to 0.""" from shared.training import TrainingConfig config = TrainingConfig( model_path="test.pt", data_yaml="data.yaml", ) assert config.freeze == 0 def test_default_cos_lr_is_false(self): """TrainingConfig cos_lr should default to False.""" from shared.training import TrainingConfig config = TrainingConfig( model_path="test.pt", data_yaml="data.yaml", ) assert config.cos_lr is False def test_finetune_config(self): """TrainingConfig should accept fine-tune parameters.""" from shared.training import TrainingConfig config = TrainingConfig( model_path="base_model.pt", data_yaml="data.yaml", epochs=10, learning_rate=0.001, freeze=10, cos_lr=True, ) assert config.freeze == 10 assert config.cos_lr is True assert config.epochs == 10 assert config.learning_rate == 0.001