"""
Model Version Repository Integration Tests

Tests ModelVersionRepository with real database operations.
"""

from datetime import datetime, timezone
from uuid import uuid4

import pytest

from inference.data.repositories.model_version_repository import ModelVersionRepository


# NOTE(review): the fixtures used below (patched_session, sample_model_version,
# sample_training_task, sample_dataset) are not defined in this file --
# presumably they come from a conftest.py; confirm before changing them.


class TestModelVersionCreate:
    """Tests for model version creation."""

    def test_create_model_version(self, patched_session):
        """A newly created version echoes its inputs and starts out inactive."""
        repo = ModelVersionRepository()

        model = repo.create(
            version="1.0.0",
            name="Invoice Extractor v1",
            model_path="/models/invoice_v1.pt",
            description="Initial production model",
            metrics_mAP=0.92,
            metrics_precision=0.89,
            metrics_recall=0.85,
            document_count=1000,
            file_size=50000000,
        )

        assert model is not None
        assert model.version == "1.0.0"
        assert model.name == "Invoice Extractor v1"
        assert model.model_path == "/models/invoice_v1.pt"
        # Compare the float metric with a tolerance so the check does not
        # depend on how the value is represented after a storage round-trip.
        assert model.metrics_mAP == pytest.approx(0.92)
        assert model.is_active is False
        assert model.status == "inactive"

    def test_create_model_version_with_training_info(
        self, patched_session, sample_training_task, sample_dataset
    ):
        """A version can be linked to a training task and a dataset on creation."""
        repo = ModelVersionRepository()

        model = repo.create(
            version="1.1.0",
            name="Invoice Extractor v1.1",
            model_path="/models/invoice_v1.1.pt",
            task_id=sample_training_task.task_id,
            dataset_id=sample_dataset.dataset_id,
            training_config={"epochs": 100, "batch_size": 16},
            trained_at=datetime.now(timezone.utc),
        )

        assert model is not None
        assert model.task_id == sample_training_task.task_id
        assert model.dataset_id == sample_dataset.dataset_id
        assert model.training_config["epochs"] == 100


class TestModelVersionRead:
    """Tests for model version retrieval."""

    def test_get_model_version_by_id(self, patched_session, sample_model_version):
        """Looking up an existing version by its ID returns that version."""
        repo = ModelVersionRepository()

        model = repo.get(str(sample_model_version.version_id))

        assert model is not None
        assert model.version_id == sample_model_version.version_id

    def test_get_nonexistent_model_version(self, patched_session):
        """Looking up a random, unused ID returns None."""
        repo = ModelVersionRepository()

        model = repo.get(str(uuid4()))

        assert model is None

    def test_get_paginated_model_versions(self, patched_session):
        """Pagination returns one page of rows plus the overall total."""
        repo = ModelVersionRepository()

        # Create multiple versions
        for i in range(5):
            repo.create(
                version=f"1.{i}.0",
                name=f"Model v1.{i}",
                model_path=f"/models/model_v1.{i}.pt",
            )

        models, total = repo.get_paginated(limit=2, offset=0)

        # The total counts every created version; the page holds only `limit`.
        assert total == 5
        assert len(models) == 2

    def test_get_paginated_with_status_filter(self, patched_session):
        """Filtering by status returns only the versions in that status."""
        repo = ModelVersionRepository()

        # Create one active and one inactive model
        active = repo.create(
            version="1.0.0", name="Active Model", model_path="/models/active.pt"
        )
        repo.activate(str(active.version_id))
        inactive = repo.create(
            version="2.0.0", name="Inactive Model", model_path="/models/inactive.pt"
        )

        active_models, active_total = repo.get_paginated(status="active")
        inactive_models, inactive_total = repo.get_paginated(status="inactive")

        assert active_total == 1
        assert inactive_total == 1
        # Check the returned rows as well, not just the counts, so a filter
        # that counts correctly but returns the wrong rows is caught.
        assert [m.version_id for m in active_models] == [active.version_id]
        assert [m.version_id for m in inactive_models] == [inactive.version_id]


class TestModelVersionActivation:
    """Tests for model version activation."""

    def test_activate_model_version(self, patched_session, sample_model_version):
        """Activating a version marks it active and stamps activated_at."""
        repo = ModelVersionRepository()

        model = repo.activate(str(sample_model_version.version_id))

        assert model is not None
        assert model.is_active is True
        assert model.status == "active"
        assert model.activated_at is not None

    def test_activate_deactivates_others(self, patched_session):
        """Activating a second version deactivates the previously active one."""
        repo = ModelVersionRepository()

        # Create and activate first model
        m1 = repo.create(version="1.0.0", name="Model 1", model_path="/models/m1.pt")
        repo.activate(str(m1.version_id))

        # Create and activate second model
        m2 = repo.create(version="2.0.0", name="Model 2", model_path="/models/m2.pt")
        repo.activate(str(m2.version_id))

        # Check first model is now inactive
        m1_after = repo.get(str(m1.version_id))
        assert m1_after is not None
        assert m1_after.is_active is False
        assert m1_after.status == "inactive"

        # Check second model is active
        m2_after = repo.get(str(m2.version_id))
        assert m2_after is not None
        assert m2_after.is_active is True

    def test_get_active_model(self, patched_session, sample_model_version):
        """get_active returns None until a version has been activated."""
        repo = ModelVersionRepository()

        # Initially no active model
        # NOTE(review): this relies on the sample_model_version fixture being
        # created inactive -- confirm against the fixture definition.
        active = repo.get_active()
        assert active is None

        # Activate model
        repo.activate(str(sample_model_version.version_id))

        # Now should return active model
        active = repo.get_active()
        assert active is not None
        assert active.version_id == sample_model_version.version_id

    def test_deactivate_model_version(self, patched_session, sample_model_version):
        """Deactivating an active version returns it in the inactive state."""
        repo = ModelVersionRepository()

        # First activate
        repo.activate(str(sample_model_version.version_id))

        # Then deactivate
        model = repo.deactivate(str(sample_model_version.version_id))

        assert model is not None
        assert model.is_active is False
        assert model.status == "inactive"


class TestModelVersionUpdate:
    """Tests for model version updates."""

    def test_update_model_metadata(self, patched_session, sample_model_version):
        """Updating name and description is reflected on the returned version."""
        repo = ModelVersionRepository()

        model = repo.update(
            str(sample_model_version.version_id),
            name="Updated Model Name",
            description="Updated description",
        )

        assert model is not None
        assert model.name == "Updated Model Name"
        assert model.description == "Updated description"

    def test_update_model_status(self, patched_session, sample_model_version):
        """Updating the status field is reflected on the returned version."""
        repo = ModelVersionRepository()

        model = repo.update(str(sample_model_version.version_id), status="deprecated")

        assert model is not None
        assert model.status == "deprecated"

    def test_update_nonexistent_model(self, patched_session):
        """Updating a random, unused ID returns None."""
        repo = ModelVersionRepository()

        model = repo.update(str(uuid4()), name="New Name")

        assert model is None


class TestModelVersionArchive:
    """Tests for model version archiving."""

    def test_archive_model_version(self, patched_session, sample_model_version):
        """Archiving a version that was never activated sets status to archived."""
        repo = ModelVersionRepository()

        model = repo.archive(str(sample_model_version.version_id))

        assert model is not None
        assert model.status == "archived"

    def test_cannot_archive_active_model(self, patched_session, sample_model_version):
        """Archiving an active version returns None and leaves it active."""
        repo = ModelVersionRepository()

        # Activate the model
        repo.activate(str(sample_model_version.version_id))

        # Try to archive
        model = repo.archive(str(sample_model_version.version_id))

        assert model is None

        # Verify model is still active
        current = repo.get(str(sample_model_version.version_id))
        assert current is not None
        assert current.status == "active"


class TestModelVersionDelete:
    """Tests for model version deletion."""

    def test_delete_inactive_model(self, patched_session, sample_model_version):
        """Deleting a version that was never activated returns True and removes it."""
        repo = ModelVersionRepository()

        result = repo.delete(str(sample_model_version.version_id))

        assert result is True

        model = repo.get(str(sample_model_version.version_id))
        assert model is None

    def test_cannot_delete_active_model(self, patched_session, sample_model_version):
        """Deleting an active version returns False and keeps the row."""
        repo = ModelVersionRepository()

        # Activate the model
        repo.activate(str(sample_model_version.version_id))

        # Try to delete
        result = repo.delete(str(sample_model_version.version_id))

        assert result is False

        # Verify model still exists
        model = repo.get(str(sample_model_version.version_id))
        assert model is not None

    def test_delete_nonexistent_model(self, patched_session):
        """Deleting a random, unused ID returns False."""
        repo = ModelVersionRepository()

        result = repo.delete(str(uuid4()))

        assert result is False


class TestOnlyOneActiveModel:
    """Tests to verify only one model can be active at a time."""

    def test_single_active_model_constraint(self, patched_session):
        """After activating several versions in turn, only the last stays active."""
        repo = ModelVersionRepository()

        # Create multiple models
        models = [
            repo.create(
                version=f"1.{i}.0",
                name=f"Model {i}",
                model_path=f"/models/model_{i}.pt",
            )
            for i in range(3)
        ]

        # Activate each model in sequence
        for model in models:
            repo.activate(str(model.version_id))

        # Count active models
        active_models, active_total = repo.get_paginated(status="active")
        assert active_total == 1
        assert len(active_models) == 1

        # Verify it's the last one activated
        assert active_models[0].version_id == models[-1].version_id