"""
Tests for AnnotationRepository

100% coverage tests for annotation management.
"""

import pytest
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
from uuid import uuid4, UUID

from inference.data.admin_models import AdminAnnotation, AnnotationHistory
from inference.data.repositories.annotation_repository import AnnotationRepository

# Dotted path of the module under test. ``get_session_context`` is patched in
# the namespace the repository looks it up from.
_REPO_MODULE = "inference.data.repositories.annotation_repository"


def _invoice_number_fields(**overrides):
    """Return a fresh kwargs dict describing an ``invoice_number`` box.

    A new dict (with a new random ``document_id``) is built on every call, so
    tests never share mutable state. ``overrides`` replace or add keys.
    """
    fields = {
        "document_id": str(uuid4()),
        "class_id": 0,
        "class_name": "invoice_number",
        "x_center": 0.5,
        "y_center": 0.3,
        "width": 0.2,
        "height": 0.05,
        "bbox_x": 100,
        "bbox_y": 200,
        "bbox_width": 150,
        "bbox_height": 30,
    }
    fields.update(overrides)
    return fields


class TestAnnotationRepository:
    """Tests for AnnotationRepository."""

    @pytest.fixture
    def sample_annotation(self) -> AdminAnnotation:
        """Create a sample annotation for testing."""
        return AdminAnnotation(
            annotation_id=uuid4(),
            document_id=uuid4(),
            page_number=1,
            class_id=0,
            class_name="invoice_number",
            x_center=0.5,
            y_center=0.3,
            width=0.2,
            height=0.05,
            bbox_x=100,
            bbox_y=200,
            bbox_width=150,
            bbox_height=30,
            text_value="INV-001",
            confidence=0.95,
            source="auto",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )

    @pytest.fixture
    def sample_history(self) -> AnnotationHistory:
        """Create a sample annotation history for testing."""
        return AnnotationHistory(
            history_id=uuid4(),
            annotation_id=uuid4(),
            document_id=uuid4(),
            action="override",
            previous_value={"class_name": "old_class"},
            new_value={"class_name": "new_class"},
            changed_by="admin-token",
            change_reason="Correction",
            created_at=datetime.now(timezone.utc),
        )

    @pytest.fixture
    def repo(self) -> AnnotationRepository:
        """Create an AnnotationRepository instance."""
        return AnnotationRepository()

    @pytest.fixture
    def mock_session(self) -> MagicMock:
        """Patch ``get_session_context`` and yield the session it provides.

        The patch stays active for the whole test. ``__exit__`` returns False,
        so exceptions raised inside the repository's ``with`` block propagate
        to the test instead of being swallowed by the mocked context manager.
        """
        with patch(f"{_REPO_MODULE}.get_session_context") as mock_ctx:
            session = MagicMock()
            mock_ctx.return_value.__enter__.return_value = session
            mock_ctx.return_value.__exit__.return_value = False
            yield session

    # =========================================================================
    # create() tests
    # =========================================================================

    def test_create_returns_annotation_id(self, repo, mock_session):
        """Test create returns annotation ID."""
        result = repo.create(**_invoice_number_fields(page_number=1))

        assert result is not None
        mock_session.add.assert_called_once()
        mock_session.flush.assert_called_once()

    def test_create_with_optional_params(self, repo, mock_session):
        """Test create with optional text_value and confidence."""
        result = repo.create(
            document_id=str(uuid4()),
            page_number=2,
            class_id=1,
            class_name="invoice_date",
            x_center=0.6,
            y_center=0.4,
            width=0.15,
            height=0.04,
            bbox_x=200,
            bbox_y=300,
            bbox_width=100,
            bbox_height=25,
            text_value="2024-01-15",
            confidence=0.88,
            source="auto",
        )

        assert result is not None
        mock_session.add.assert_called_once()
        added_annotation = mock_session.add.call_args[0][0]
        assert added_annotation.text_value == "2024-01-15"
        assert added_annotation.confidence == 0.88
        assert added_annotation.source == "auto"

    def test_create_default_source_is_manual(self, repo, mock_session):
        """Test create uses manual as default source."""
        repo.create(**_invoice_number_fields(page_number=1))

        added_annotation = mock_session.add.call_args[0][0]
        assert added_annotation.source == "manual"

    # =========================================================================
    # create_batch() tests
    # =========================================================================

    def test_create_batch_returns_ids(self, repo, mock_session):
        """Test create_batch returns list of annotation IDs."""
        annotations = [
            _invoice_number_fields(),
            {
                "document_id": str(uuid4()),
                "class_id": 1,
                "class_name": "invoice_date",
                "x_center": 0.6,
                "y_center": 0.4,
                "width": 0.15,
                "height": 0.04,
                "bbox_x": 200,
                "bbox_y": 300,
                "bbox_width": 100,
                "bbox_height": 25,
            },
        ]

        result = repo.create_batch(annotations)

        assert len(result) == 2
        assert mock_session.add.call_count == 2
        assert mock_session.flush.call_count == 2

    def test_create_batch_default_page_number(self, repo, mock_session):
        """Test create_batch uses page_number=1 by default."""
        # The helper deliberately omits "page_number".
        repo.create_batch([_invoice_number_fields()])

        added_annotation = mock_session.add.call_args[0][0]
        assert added_annotation.page_number == 1

    def test_create_batch_with_all_optional_params(self, repo, mock_session):
        """Test create_batch with all optional parameters."""
        annotations = [
            _invoice_number_fields(
                page_number=3,
                text_value="INV-123",
                confidence=0.92,
                source="ocr",
            ),
        ]

        repo.create_batch(annotations)

        added_annotation = mock_session.add.call_args[0][0]
        assert added_annotation.page_number == 3
        assert added_annotation.text_value == "INV-123"
        assert added_annotation.confidence == 0.92
        assert added_annotation.source == "ocr"

    def test_create_batch_empty_list(self, repo, mock_session):
        """Test create_batch with empty list returns empty."""
        result = repo.create_batch([])

        assert result == []
        mock_session.add.assert_not_called()

    # =========================================================================
    # get() tests
    # =========================================================================

    def test_get_returns_annotation(self, repo, mock_session, sample_annotation):
        """Test get returns annotation when exists."""
        mock_session.get.return_value = sample_annotation

        result = repo.get(str(sample_annotation.annotation_id))

        assert result is not None
        assert result.class_name == "invoice_number"
        mock_session.expunge.assert_called_once()

    def test_get_returns_none_when_not_found(self, repo, mock_session):
        """Test get returns None when annotation not found."""
        mock_session.get.return_value = None

        result = repo.get(str(uuid4()))

        assert result is None
        mock_session.expunge.assert_not_called()

    # =========================================================================
    # get_for_document() tests
    # =========================================================================

    def test_get_for_document_returns_all_annotations(self, repo, mock_session, sample_annotation):
        """Test get_for_document returns all annotations for document."""
        mock_session.exec.return_value.all.return_value = [sample_annotation]

        result = repo.get_for_document(str(sample_annotation.document_id))

        assert len(result) == 1
        assert result[0].class_name == "invoice_number"

    def test_get_for_document_with_page_filter(self, repo, mock_session, sample_annotation):
        """Test get_for_document filters by page number."""
        mock_session.exec.return_value.all.return_value = [sample_annotation]

        result = repo.get_for_document(str(sample_annotation.document_id), page_number=1)

        assert len(result) == 1

    def test_get_for_document_returns_empty_list(self, repo, mock_session):
        """Test get_for_document returns empty list when no annotations."""
        mock_session.exec.return_value.all.return_value = []

        result = repo.get_for_document(str(uuid4()))

        assert result == []

    # =========================================================================
    # update() tests
    # =========================================================================

    def test_update_returns_true(self, repo, mock_session, sample_annotation):
        """Test update returns True when annotation exists."""
        mock_session.get.return_value = sample_annotation

        result = repo.update(
            str(sample_annotation.annotation_id),
            text_value="INV-002",
        )

        assert result is True
        assert sample_annotation.text_value == "INV-002"

    def test_update_returns_false_when_not_found(self, repo, mock_session):
        """Test update returns False when annotation not found."""
        mock_session.get.return_value = None

        result = repo.update(str(uuid4()), text_value="INV-002")

        assert result is False

    def test_update_all_fields(self, repo, mock_session, sample_annotation):
        """Test update can update all fields."""
        mock_session.get.return_value = sample_annotation
        new_values = {
            "x_center": 0.6,
            "y_center": 0.4,
            "width": 0.25,
            "height": 0.06,
            "bbox_x": 150,
            "bbox_y": 250,
            "bbox_width": 175,
            "bbox_height": 35,
            "text_value": "NEW-VALUE",
            "class_id": 5,
            "class_name": "new_class",
        }

        result = repo.update(str(sample_annotation.annotation_id), **new_values)

        assert result is True
        for field_name, expected in new_values.items():
            assert getattr(sample_annotation, field_name) == expected, field_name

    def test_update_partial_fields(self, repo, mock_session, sample_annotation):
        """Test update only updates provided fields."""
        original_x = sample_annotation.x_center
        mock_session.get.return_value = sample_annotation

        result = repo.update(
            str(sample_annotation.annotation_id),
            text_value="UPDATED",
        )

        assert result is True
        assert sample_annotation.text_value == "UPDATED"
        assert sample_annotation.x_center == original_x  # unchanged

    # =========================================================================
    # delete() tests
    # =========================================================================

    def test_delete_returns_true(self, repo, mock_session, sample_annotation):
        """Test delete returns True when annotation exists."""
        mock_session.get.return_value = sample_annotation

        result = repo.delete(str(sample_annotation.annotation_id))

        assert result is True
        mock_session.delete.assert_called_once()

    def test_delete_returns_false_when_not_found(self, repo, mock_session):
        """Test delete returns False when annotation not found."""
        mock_session.get.return_value = None

        result = repo.delete(str(uuid4()))

        assert result is False
        mock_session.delete.assert_not_called()

    # =========================================================================
    # delete_for_document() tests
    # =========================================================================

    def test_delete_for_document_returns_count(self, repo, mock_session, sample_annotation):
        """Test delete_for_document returns count of deleted annotations."""
        mock_session.exec.return_value.all.return_value = [sample_annotation]

        result = repo.delete_for_document(str(sample_annotation.document_id))

        assert result == 1
        mock_session.delete.assert_called_once()

    def test_delete_for_document_with_source_filter(self, repo, mock_session, sample_annotation):
        """Test delete_for_document filters by source."""
        mock_session.exec.return_value.all.return_value = [sample_annotation]

        result = repo.delete_for_document(str(sample_annotation.document_id), source="auto")

        assert result == 1

    def test_delete_for_document_returns_zero(self, repo, mock_session):
        """Test delete_for_document returns 0 when no annotations."""
        mock_session.exec.return_value.all.return_value = []

        result = repo.delete_for_document(str(uuid4()))

        assert result == 0
        mock_session.delete.assert_not_called()

    # =========================================================================
    # verify() tests
    # =========================================================================

    def test_verify_marks_annotation_verified(self, repo, mock_session, sample_annotation):
        """Test verify marks annotation as verified."""
        sample_annotation.is_verified = False
        mock_session.get.return_value = sample_annotation

        result = repo.verify(str(sample_annotation.annotation_id), "admin-token")

        assert result is not None
        assert sample_annotation.is_verified is True
        assert sample_annotation.verified_by == "admin-token"
        mock_session.commit.assert_called_once()

    def test_verify_returns_none_when_not_found(self, repo, mock_session):
        """Test verify returns None when annotation not found."""
        mock_session.get.return_value = None

        result = repo.verify(str(uuid4()), "admin-token")

        assert result is None

    # =========================================================================
    # override() tests
    # =========================================================================

    def test_override_updates_annotation(self, repo, mock_session, sample_annotation):
        """Test override updates annotation and creates history."""
        sample_annotation.source = "auto"
        mock_session.get.return_value = sample_annotation

        result = repo.override(
            str(sample_annotation.annotation_id),
            "admin-token",
            change_reason="Correction",
            text_value="NEW-VALUE",
        )

        assert result is not None
        assert sample_annotation.text_value == "NEW-VALUE"
        assert sample_annotation.source == "manual"
        assert sample_annotation.override_source == "auto"
        assert mock_session.add.call_count >= 2  # annotation + history

    def test_override_returns_none_when_not_found(self, repo, mock_session):
        """Test override returns None when annotation not found."""
        mock_session.get.return_value = None

        result = repo.override(str(uuid4()), "admin-token", text_value="NEW")

        assert result is None

    def test_override_does_not_change_source_if_already_manual(self, repo, mock_session, sample_annotation):
        """Test override does not change override_source if already manual."""
        sample_annotation.source = "manual"
        sample_annotation.override_source = None
        mock_session.get.return_value = sample_annotation

        repo.override(
            str(sample_annotation.annotation_id),
            "admin-token",
            text_value="NEW-VALUE",
        )

        assert sample_annotation.source == "manual"
        assert sample_annotation.override_source is None

    def test_override_skips_unknown_attributes(self, repo, mock_session, sample_annotation):
        """Test override ignores unknown attributes."""
        sample_annotation.source = "auto"
        mock_session.get.return_value = sample_annotation

        result = repo.override(
            str(sample_annotation.annotation_id),
            "admin-token",
            unknown_field="should_be_ignored",
            text_value="VALID",
        )

        assert result is not None
        assert sample_annotation.text_value == "VALID"
        # Either the attribute was never set, or it was not set to the ignored value.
        assert getattr(sample_annotation, "unknown_field", None) != "should_be_ignored"

    # =========================================================================
    # create_history() tests
    # =========================================================================

    def test_create_history_returns_history(self, repo, mock_session):
        """Test create_history returns created history record."""
        annotation_id = uuid4()
        document_id = uuid4()

        result = repo.create_history(
            annotation_id=annotation_id,
            document_id=document_id,
            action="create",
            previous_value=None,
            new_value={"class_name": "invoice_number"},
            changed_by="admin-token",
            change_reason="Initial creation",
        )

        assert result is not None
        mock_session.add.assert_called_once()
        mock_session.commit.assert_called_once()
        # The record handed to the session must carry the supplied values.
        added_history = mock_session.add.call_args[0][0]
        assert added_history.action == "create"
        assert added_history.previous_value is None
        assert added_history.new_value == {"class_name": "invoice_number"}
        assert added_history.changed_by == "admin-token"
        assert added_history.change_reason == "Initial creation"

    def test_create_history_with_minimal_params(self, repo, mock_session):
        """Test create_history with minimal parameters."""
        repo.create_history(
            annotation_id=uuid4(),
            document_id=uuid4(),
            action="delete",
        )

        mock_session.add.assert_called_once()
        added_history = mock_session.add.call_args[0][0]
        assert added_history.action == "delete"
        assert added_history.previous_value is None
        assert added_history.new_value is None

    # =========================================================================
    # get_history() tests
    # =========================================================================

    def test_get_history_returns_list(self, repo, mock_session, sample_history):
        """Test get_history returns list of history records."""
        mock_session.exec.return_value.all.return_value = [sample_history]

        result = repo.get_history(sample_history.annotation_id)

        assert len(result) == 1
        assert result[0].action == "override"

    def test_get_history_returns_empty_list(self, repo, mock_session):
        """Test get_history returns empty list when no history."""
        mock_session.exec.return_value.all.return_value = []

        result = repo.get_history(uuid4())

        assert result == []

    # =========================================================================
    # get_document_history() tests
    # =========================================================================

    def test_get_document_history_returns_list(self, repo, mock_session, sample_history):
        """Test get_document_history returns list of history records."""
        mock_session.exec.return_value.all.return_value = [sample_history]

        result = repo.get_document_history(sample_history.document_id)

        assert len(result) == 1

    def test_get_document_history_returns_empty_list(self, repo, mock_session):
        """Test get_document_history returns empty list when no history."""
        mock_session.exec.return_value.all.return_value = []

        result = repo.get_document_history(uuid4())

        assert result == []