""" Annotation Repository Integration Tests Tests AnnotationRepository with real database operations. """ from uuid import uuid4 import pytest from inference.data.repositories.annotation_repository import AnnotationRepository class TestAnnotationRepositoryCreate: """Tests for annotation creation.""" def test_create_annotation(self, patched_session, sample_document): """Test creating a single annotation.""" repo = AnnotationRepository() ann_id = repo.create( document_id=str(sample_document.document_id), page_number=1, class_id=0, class_name="invoice_number", x_center=0.5, y_center=0.3, width=0.2, height=0.05, bbox_x=400, bbox_y=240, bbox_width=160, bbox_height=40, text_value="INV-2024-001", confidence=0.95, source="auto", ) assert ann_id is not None ann = repo.get(ann_id) assert ann is not None assert ann.class_name == "invoice_number" assert ann.text_value == "INV-2024-001" assert ann.confidence == 0.95 assert ann.source == "auto" def test_create_batch_annotations(self, patched_session, sample_document): """Test batch creation of annotations.""" repo = AnnotationRepository() annotations_data = [ { "document_id": str(sample_document.document_id), "page_number": 1, "class_id": 0, "class_name": "invoice_number", "x_center": 0.5, "y_center": 0.1, "width": 0.2, "height": 0.05, "bbox_x": 400, "bbox_y": 80, "bbox_width": 160, "bbox_height": 40, "text_value": "INV-001", "confidence": 0.95, }, { "document_id": str(sample_document.document_id), "page_number": 1, "class_id": 1, "class_name": "invoice_date", "x_center": 0.5, "y_center": 0.2, "width": 0.15, "height": 0.04, "bbox_x": 400, "bbox_y": 160, "bbox_width": 120, "bbox_height": 32, "text_value": "2024-01-15", "confidence": 0.92, }, { "document_id": str(sample_document.document_id), "page_number": 1, "class_id": 6, "class_name": "amount", "x_center": 0.7, "y_center": 0.8, "width": 0.1, "height": 0.04, "bbox_x": 560, "bbox_y": 640, "bbox_width": 80, "bbox_height": 32, "text_value": "1500.00", "confidence": 0.98, }, ] ids = repo.create_batch(annotations_data) assert len(ids) == 3 # Verify all annotations exist for ann_id in ids: ann = repo.get(ann_id) assert ann is not None class TestAnnotationRepositoryRead: """Tests for annotation retrieval.""" def test_get_nonexistent_annotation(self, patched_session): """Test getting an annotation that doesn't exist.""" repo = AnnotationRepository() ann = repo.get(str(uuid4())) assert ann is None def test_get_annotations_for_document(self, patched_session, sample_document, sample_annotation): """Test getting all annotations for a document.""" repo = AnnotationRepository() # Add another annotation repo.create( document_id=str(sample_document.document_id), page_number=1, class_id=1, class_name="invoice_date", x_center=0.5, y_center=0.4, width=0.15, height=0.04, bbox_x=400, bbox_y=320, bbox_width=120, bbox_height=32, text_value="2024-01-15", ) annotations = repo.get_for_document(str(sample_document.document_id)) assert len(annotations) == 2 # Should be ordered by class_id assert annotations[0].class_id == 0 assert annotations[1].class_id == 1 def test_get_annotations_for_specific_page(self, patched_session, sample_document): """Test getting annotations for a specific page.""" repo = AnnotationRepository() # Create annotations on different pages repo.create( document_id=str(sample_document.document_id), page_number=1, class_id=0, class_name="invoice_number", x_center=0.5, y_center=0.1, width=0.2, height=0.05, bbox_x=400, bbox_y=80, bbox_width=160, bbox_height=40, ) repo.create( document_id=str(sample_document.document_id), page_number=2, class_id=6, class_name="amount", x_center=0.7, y_center=0.8, width=0.1, height=0.04, bbox_x=560, bbox_y=640, bbox_width=80, bbox_height=32, ) page1_annotations = repo.get_for_document( str(sample_document.document_id), page_number=1, ) page2_annotations = repo.get_for_document( str(sample_document.document_id), page_number=2, ) assert len(page1_annotations) == 1 assert len(page2_annotations) == 1 assert page1_annotations[0].page_number == 1 assert page2_annotations[0].page_number == 2 class TestAnnotationRepositoryUpdate: """Tests for annotation updates.""" def test_update_annotation_bbox(self, patched_session, sample_annotation): """Test updating annotation bounding box.""" repo = AnnotationRepository() result = repo.update( str(sample_annotation.annotation_id), x_center=0.6, y_center=0.4, width=0.25, height=0.06, bbox_x=480, bbox_y=320, bbox_width=200, bbox_height=48, ) assert result is True ann = repo.get(str(sample_annotation.annotation_id)) assert ann is not None assert ann.x_center == 0.6 assert ann.y_center == 0.4 assert ann.bbox_x == 480 assert ann.bbox_width == 200 def test_update_annotation_text(self, patched_session, sample_annotation): """Test updating annotation text value.""" repo = AnnotationRepository() result = repo.update( str(sample_annotation.annotation_id), text_value="INV-2024-002", ) assert result is True ann = repo.get(str(sample_annotation.annotation_id)) assert ann is not None assert ann.text_value == "INV-2024-002" def test_update_annotation_class(self, patched_session, sample_annotation): """Test updating annotation class.""" repo = AnnotationRepository() result = repo.update( str(sample_annotation.annotation_id), class_id=1, class_name="invoice_date", ) assert result is True ann = repo.get(str(sample_annotation.annotation_id)) assert ann is not None assert ann.class_id == 1 assert ann.class_name == "invoice_date" def test_update_nonexistent_annotation(self, patched_session): """Test updating annotation that doesn't exist.""" repo = AnnotationRepository() result = repo.update( str(uuid4()), text_value="new value", ) assert result is False class TestAnnotationRepositoryDelete: """Tests for annotation deletion.""" def test_delete_annotation(self, patched_session, sample_annotation): """Test deleting a single annotation.""" repo = AnnotationRepository() result = repo.delete(str(sample_annotation.annotation_id)) assert result is True ann = repo.get(str(sample_annotation.annotation_id)) assert ann is None def test_delete_nonexistent_annotation(self, patched_session): """Test deleting annotation that doesn't exist.""" repo = AnnotationRepository() result = repo.delete(str(uuid4())) assert result is False def test_delete_annotations_for_document(self, patched_session, sample_document): """Test deleting all annotations for a document.""" repo = AnnotationRepository() # Create multiple annotations for i in range(3): repo.create( document_id=str(sample_document.document_id), page_number=1, class_id=i, class_name=f"field_{i}", x_center=0.5, y_center=0.1 + i * 0.2, width=0.2, height=0.05, bbox_x=400, bbox_y=80 + i * 160, bbox_width=160, bbox_height=40, ) # Delete all count = repo.delete_for_document(str(sample_document.document_id)) assert count == 3 annotations = repo.get_for_document(str(sample_document.document_id)) assert len(annotations) == 0 def test_delete_annotations_by_source(self, patched_session, sample_document): """Test deleting annotations by source type.""" repo = AnnotationRepository() # Create auto and manual annotations repo.create( document_id=str(sample_document.document_id), page_number=1, class_id=0, class_name="invoice_number", x_center=0.5, y_center=0.1, width=0.2, height=0.05, bbox_x=400, bbox_y=80, bbox_width=160, bbox_height=40, source="auto", ) repo.create( document_id=str(sample_document.document_id), page_number=1, class_id=1, class_name="invoice_date", x_center=0.5, y_center=0.2, width=0.15, height=0.04, bbox_x=400, bbox_y=160, bbox_width=120, bbox_height=32, source="manual", ) # Delete only auto annotations count = repo.delete_for_document(str(sample_document.document_id), source="auto") assert count == 1 remaining = repo.get_for_document(str(sample_document.document_id)) assert len(remaining) == 1 assert remaining[0].source == "manual" class TestAnnotationVerification: """Tests for annotation verification.""" def test_verify_annotation(self, patched_session, admin_token, sample_annotation): """Test marking annotation as verified.""" repo = AnnotationRepository() ann = repo.verify(str(sample_annotation.annotation_id), admin_token.token) assert ann is not None assert ann.is_verified is True assert ann.verified_by == admin_token.token assert ann.verified_at is not None class TestAnnotationOverride: """Tests for annotation override functionality.""" def test_override_auto_annotation(self, patched_session, admin_token, sample_annotation): """Test overriding an auto-generated annotation.""" repo = AnnotationRepository() # Override the annotation ann = repo.override( str(sample_annotation.annotation_id), admin_token.token, change_reason="Correcting OCR error", text_value="INV-2024-CORRECTED", x_center=0.55, ) assert ann is not None assert ann.text_value == "INV-2024-CORRECTED" assert ann.x_center == 0.55 assert ann.source == "manual" # Changed from auto to manual assert ann.override_source == "auto" class TestAnnotationHistory: """Tests for annotation history tracking.""" def test_create_history_record(self, patched_session, sample_annotation): """Test creating annotation history record.""" repo = AnnotationRepository() history = repo.create_history( annotation_id=sample_annotation.annotation_id, document_id=sample_annotation.document_id, action="created", new_value={"text_value": "INV-001"}, changed_by="test-user", ) assert history is not None assert history.action == "created" assert history.changed_by == "test-user" def test_get_annotation_history(self, patched_session, sample_annotation): """Test getting history for an annotation.""" repo = AnnotationRepository() # Create history records repo.create_history( annotation_id=sample_annotation.annotation_id, document_id=sample_annotation.document_id, action="created", new_value={"text_value": "INV-001"}, ) repo.create_history( annotation_id=sample_annotation.annotation_id, document_id=sample_annotation.document_id, action="updated", previous_value={"text_value": "INV-001"}, new_value={"text_value": "INV-002"}, ) history = repo.get_history(sample_annotation.annotation_id) assert len(history) == 2 # Should be ordered by created_at desc assert history[0].action == "updated" assert history[1].action == "created" def test_get_document_history(self, patched_session, sample_document, sample_annotation): """Test getting all annotation history for a document.""" repo = AnnotationRepository() repo.create_history( annotation_id=sample_annotation.annotation_id, document_id=sample_document.document_id, action="created", new_value={"class_name": "invoice_number"}, ) history = repo.get_document_history(sample_document.document_id) assert len(history) >= 1 assert all(h.document_id == sample_document.document_id for h in history)