""" Tests for ModelVersionRepository 100% coverage tests for model version management. """ import pytest from datetime import datetime, timezone from unittest.mock import MagicMock, patch from uuid import uuid4, UUID from inference.data.admin_models import ModelVersion from inference.data.repositories.model_version_repository import ModelVersionRepository class TestModelVersionRepository: """Tests for ModelVersionRepository.""" @pytest.fixture def sample_model(self) -> ModelVersion: """Create a sample model version for testing.""" return ModelVersion( version_id=uuid4(), version="v1.0.0", name="Test Model", description="A test model", model_path="/path/to/model.pt", status="ready", is_active=False, metrics_mAP=0.95, metrics_precision=0.92, metrics_recall=0.88, document_count=100, training_config={"epochs": 100}, file_size=1024000, created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ) @pytest.fixture def active_model(self) -> ModelVersion: """Create an active model version for testing.""" return ModelVersion( version_id=uuid4(), version="v1.0.0", name="Active Model", model_path="/path/to/active_model.pt", status="active", is_active=True, created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ) @pytest.fixture def repo(self) -> ModelVersionRepository: """Create a ModelVersionRepository instance.""" return ModelVersionRepository() # ========================================================================= # create() tests # ========================================================================= def test_create_returns_model(self, repo): """Test create returns created model version.""" with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.create( version="v1.0.0", name="Test Model", model_path="/path/to/model.pt", ) mock_session.add.assert_called_once() mock_session.commit.assert_called_once() def test_create_with_all_params(self, repo): """Test create with all parameters.""" task_id = uuid4() dataset_id = uuid4() trained_at = datetime.now(timezone.utc) with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.create( version="v2.0.0", name="Full Model", model_path="/path/to/full_model.pt", description="A complete model", task_id=str(task_id), dataset_id=str(dataset_id), metrics_mAP=0.95, metrics_precision=0.92, metrics_recall=0.88, document_count=500, training_config={"epochs": 200}, file_size=2048000, trained_at=trained_at, ) added_model = mock_session.add.call_args[0][0] assert added_model.version == "v2.0.0" assert added_model.description == "A complete model" assert added_model.task_id == task_id assert added_model.dataset_id == dataset_id assert added_model.metrics_mAP == 0.95 def test_create_with_uuid_objects(self, repo): """Test create works with UUID objects.""" task_id = uuid4() dataset_id = uuid4() with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) repo.create( version="v1.0.0", name="Test Model", model_path="/path/to/model.pt", task_id=task_id, dataset_id=dataset_id, ) added_model = mock_session.add.call_args[0][0] assert added_model.task_id == task_id assert added_model.dataset_id == dataset_id def test_create_without_optional_ids(self, repo): """Test create without task_id and dataset_id.""" with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) repo.create( version="v1.0.0", name="Test Model", model_path="/path/to/model.pt", ) added_model = mock_session.add.call_args[0][0] assert added_model.task_id is None assert added_model.dataset_id is None # ========================================================================= # get() tests # ========================================================================= def test_get_returns_model(self, repo, sample_model): """Test get returns model when exists.""" with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_model mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.get(str(sample_model.version_id)) assert result is not None assert result.name == "Test Model" mock_session.expunge.assert_called_once() def test_get_with_uuid(self, repo, sample_model): """Test get works with UUID object.""" with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_model mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.get(sample_model.version_id) assert result is not None def test_get_returns_none_when_not_found(self, repo): """Test get returns None when model not found.""" with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = None mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.get(str(uuid4())) assert result is None mock_session.expunge.assert_not_called() # ========================================================================= # get_paginated() tests # ========================================================================= def test_get_paginated_returns_models_and_total(self, repo, sample_model): """Test get_paginated returns list of models and total count.""" with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.one.return_value = 1 mock_session.exec.return_value.all.return_value = [sample_model] mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) models, total = repo.get_paginated() assert len(models) == 1 assert total == 1 def test_get_paginated_with_status_filter(self, repo, sample_model): """Test get_paginated filters by status.""" with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.one.return_value = 1 mock_session.exec.return_value.all.return_value = [sample_model] mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) models, total = repo.get_paginated(status="ready") assert len(models) == 1 def test_get_paginated_with_pagination(self, repo, sample_model): """Test get_paginated with limit and offset.""" with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.one.return_value = 50 mock_session.exec.return_value.all.return_value = [sample_model] mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) models, total = repo.get_paginated(limit=10, offset=20) assert total == 50 def test_get_paginated_empty_results(self, repo): """Test get_paginated with no results.""" with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.one.return_value = 0 mock_session.exec.return_value.all.return_value = [] mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) models, total = repo.get_paginated() assert models == [] assert total == 0 # ========================================================================= # get_active() tests # ========================================================================= def test_get_active_returns_active_model(self, repo, active_model): """Test get_active returns the active model.""" with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.first.return_value = active_model mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.get_active() assert result is not None assert result.is_active is True mock_session.expunge.assert_called_once() def test_get_active_returns_none(self, repo): """Test get_active returns None when no active model.""" with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.first.return_value = None mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.get_active() assert result is None mock_session.expunge.assert_not_called() # ========================================================================= # activate() tests # ========================================================================= def test_activate_activates_model(self, repo, sample_model, active_model): """Test activate sets model as active and deactivates others.""" with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.all.return_value = [active_model] mock_session.get.return_value = sample_model mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.activate(str(sample_model.version_id)) assert result is not None assert sample_model.is_active is True assert sample_model.status == "active" assert active_model.is_active is False assert active_model.status == "inactive" def test_activate_with_uuid(self, repo, sample_model): """Test activate works with UUID object.""" with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.all.return_value = [] mock_session.get.return_value = sample_model mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.activate(sample_model.version_id) assert result is not None assert sample_model.is_active is True def test_activate_returns_none_when_not_found(self, repo): """Test activate returns None when model not found.""" with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.all.return_value = [] mock_session.get.return_value = None mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.activate(str(uuid4())) assert result is None def test_activate_sets_activated_at(self, repo, sample_model): """Test activate sets activated_at timestamp.""" sample_model.activated_at = None with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.all.return_value = [] mock_session.get.return_value = sample_model mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) repo.activate(str(sample_model.version_id)) assert sample_model.activated_at is not None # ========================================================================= # deactivate() tests # ========================================================================= def test_deactivate_deactivates_model(self, repo, active_model): """Test deactivate sets model as inactive.""" with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = active_model mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.deactivate(str(active_model.version_id)) assert result is not None assert active_model.is_active is False assert active_model.status == "inactive" mock_session.commit.assert_called_once() def test_deactivate_with_uuid(self, repo, active_model): """Test deactivate works with UUID object.""" with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = active_model mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.deactivate(active_model.version_id) assert result is not None def test_deactivate_returns_none_when_not_found(self, repo): """Test deactivate returns None when model not found.""" with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = None mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.deactivate(str(uuid4())) assert result is None # ========================================================================= # update() tests # ========================================================================= def test_update_updates_model(self, repo, sample_model): """Test update updates model metadata.""" with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_model mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.update( str(sample_model.version_id), name="Updated Model", ) assert result is not None assert sample_model.name == "Updated Model" mock_session.commit.assert_called_once() def test_update_all_fields(self, repo, sample_model): """Test update can update all fields.""" with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_model mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.update( str(sample_model.version_id), name="New Name", description="New Description", status="archived", ) assert sample_model.name == "New Name" assert sample_model.description == "New Description" assert sample_model.status == "archived" def test_update_with_uuid(self, repo, sample_model): """Test update works with UUID object.""" with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_model mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.update(sample_model.version_id, name="Updated") assert result is not None def test_update_returns_none_when_not_found(self, repo): """Test update returns None when model not found.""" with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = None mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.update(str(uuid4()), name="New Name") assert result is None def test_update_partial_fields(self, repo, sample_model): """Test update only updates provided fields.""" original_name = sample_model.name with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_model mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.update( str(sample_model.version_id), description="Only description changed", ) assert sample_model.name == original_name assert sample_model.description == "Only description changed" # ========================================================================= # archive() tests # ========================================================================= def test_archive_archives_model(self, repo, sample_model): """Test archive sets model status to archived.""" sample_model.is_active = False with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_model mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.archive(str(sample_model.version_id)) assert result is not None assert sample_model.status == "archived" mock_session.commit.assert_called_once() def test_archive_with_uuid(self, repo, sample_model): """Test archive works with UUID object.""" sample_model.is_active = False with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_model mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.archive(sample_model.version_id) assert result is not None def test_archive_returns_none_when_not_found(self, repo): """Test archive returns None when model not found.""" with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = None mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.archive(str(uuid4())) assert result is None def test_archive_returns_none_when_active(self, repo, active_model): """Test archive returns None when model is active.""" with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = active_model mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.archive(str(active_model.version_id)) assert result is None # ========================================================================= # delete() tests # ========================================================================= def test_delete_returns_true(self, repo, sample_model): """Test delete returns True when model exists and not active.""" sample_model.is_active = False with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_model mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.delete(str(sample_model.version_id)) assert result is True mock_session.delete.assert_called_once() mock_session.commit.assert_called_once() def test_delete_with_uuid(self, repo, sample_model): """Test delete works with UUID object.""" sample_model.is_active = False with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_model mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.delete(sample_model.version_id) assert result is True def test_delete_returns_false_when_not_found(self, repo): """Test delete returns False when model not found.""" with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = None mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.delete(str(uuid4())) assert result is False mock_session.delete.assert_not_called() def test_delete_returns_false_when_active(self, repo, active_model): """Test delete returns False when model is active.""" with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = active_model mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.delete(str(active_model.version_id)) assert result is False mock_session.delete.assert_not_called()