""" Tests for Phase 5: Annotation Enhancement (Verification and Override) """ import pytest from datetime import datetime from uuid import uuid4 from fastapi import FastAPI from fastapi.testclient import TestClient from inference.web.api.v1.admin.annotations import ( create_annotation_router, get_doc_repository, get_ann_repository, ) from inference.web.core.auth import validate_admin_token class MockAdminDocument: """Mock AdminDocument for testing.""" def __init__(self, **kwargs): self.document_id = kwargs.get('document_id', uuid4()) self.admin_token = kwargs.get('admin_token', 'test-token') self.filename = kwargs.get('filename', 'test.pdf') self.file_size = kwargs.get('file_size', 100000) self.content_type = kwargs.get('content_type', 'application/pdf') self.page_count = kwargs.get('page_count', 1) self.status = kwargs.get('status', 'labeled') self.auto_label_status = kwargs.get('auto_label_status', None) self.created_at = kwargs.get('created_at', datetime.utcnow()) self.updated_at = kwargs.get('updated_at', datetime.utcnow()) class MockAnnotation: """Mock AdminAnnotation for testing.""" def __init__(self, **kwargs): self.annotation_id = kwargs.get('annotation_id', uuid4()) self.document_id = kwargs.get('document_id') self.page_number = kwargs.get('page_number', 1) self.class_id = kwargs.get('class_id', 0) self.class_name = kwargs.get('class_name', 'invoice_number') self.bbox_x = kwargs.get('bbox_x', 100) self.bbox_y = kwargs.get('bbox_y', 100) self.bbox_width = kwargs.get('bbox_width', 200) self.bbox_height = kwargs.get('bbox_height', 50) self.x_center = kwargs.get('x_center', 0.5) self.y_center = kwargs.get('y_center', 0.5) self.width = kwargs.get('width', 0.3) self.height = kwargs.get('height', 0.1) self.text_value = kwargs.get('text_value', 'INV-001') self.confidence = kwargs.get('confidence', 0.95) self.source = kwargs.get('source', 'auto') self.is_verified = kwargs.get('is_verified', False) self.verified_at = kwargs.get('verified_at', None) self.verified_by = kwargs.get('verified_by', None) self.override_source = kwargs.get('override_source', None) self.original_annotation_id = kwargs.get('original_annotation_id', None) self.created_at = kwargs.get('created_at', datetime.utcnow()) self.updated_at = kwargs.get('updated_at', datetime.utcnow()) class MockAnnotationHistory: """Mock AnnotationHistory for testing.""" def __init__(self, **kwargs): self.history_id = kwargs.get('history_id', uuid4()) self.annotation_id = kwargs.get('annotation_id') self.document_id = kwargs.get('document_id') self.action = kwargs.get('action', 'override') self.previous_value = kwargs.get('previous_value', {}) self.new_value = kwargs.get('new_value', {}) self.changed_by = kwargs.get('changed_by', 'test-token') self.change_reason = kwargs.get('change_reason', None) self.created_at = kwargs.get('created_at', datetime.utcnow()) class MockDocumentRepository: """Mock DocumentRepository for testing Phase 5.""" def __init__(self): self.documents = {} def get(self, document_id): """Get document by ID.""" return self.documents.get(str(document_id)) def get_by_token(self, document_id, admin_token=None): """Get document by ID and token.""" doc = self.documents.get(str(document_id)) if doc and (admin_token is None or doc.admin_token == admin_token): return doc return None class MockAnnotationRepository: """Mock AnnotationRepository for testing Phase 5.""" def __init__(self): self.annotations = {} self.annotation_history = {} def get(self, annotation_id): """Get annotation by ID.""" return self.annotations.get(str(annotation_id)) def get_for_document(self, document_id, page_number=None): """Get annotations for a document.""" return [a for a in self.annotations.values() if str(a.document_id) == str(document_id)] def verify(self, annotation_id, admin_token): """Mark annotation as verified.""" annotation = self.annotations.get(str(annotation_id)) if annotation: annotation.is_verified = True annotation.verified_at = datetime.utcnow() annotation.verified_by = admin_token return annotation return None def override( self, annotation_id, admin_token, change_reason=None, **updates, ): """Override an annotation.""" annotation = self.annotations.get(str(annotation_id)) if annotation: # Apply updates for key, value in updates.items(): if hasattr(annotation, key): setattr(annotation, key, value) # Mark as overridden if was auto-generated if annotation.source == "auto": annotation.override_source = "auto" annotation.source = "manual" # Create history record history = MockAnnotationHistory( annotation_id=uuid4().hex if isinstance(annotation_id, str) else annotation_id, document_id=annotation.document_id, action="override", changed_by=admin_token, change_reason=change_reason, ) self.annotation_history[str(annotation.annotation_id)] = [history] return annotation return None def get_history(self, annotation_id): """Get annotation history.""" return self.annotation_history.get(str(annotation_id), []) @pytest.fixture def app(): """Create test FastAPI app.""" app = FastAPI() # Create mock repositories mock_document_repo = MockDocumentRepository() mock_annotation_repo = MockAnnotationRepository() # Add test document doc1 = MockAdminDocument( filename="TEST001.pdf", status="labeled", ) mock_document_repo.documents[str(doc1.document_id)] = doc1 # Add test annotations ann1 = MockAnnotation( document_id=doc1.document_id, class_id=0, class_name="invoice_number", text_value="INV-001", source="auto", confidence=0.95, ) ann2 = MockAnnotation( document_id=doc1.document_id, class_id=6, class_name="amount", text_value="1500.00", source="auto", confidence=0.98, ) mock_annotation_repo.annotations[str(ann1.annotation_id)] = ann1 mock_annotation_repo.annotations[str(ann2.annotation_id)] = ann2 # Store document ID and annotation IDs for tests app.state.document_id = str(doc1.document_id) app.state.annotation_id_1 = str(ann1.annotation_id) app.state.annotation_id_2 = str(ann2.annotation_id) # Override dependencies app.dependency_overrides[validate_admin_token] = lambda: "test-token" app.dependency_overrides[get_doc_repository] = lambda: mock_document_repo app.dependency_overrides[get_ann_repository] = lambda: mock_annotation_repo # Include router router = create_annotation_router() app.include_router(router) return app @pytest.fixture def client(app): """Create test client.""" return TestClient(app) class TestAnnotationVerification: """Tests for POST /admin/documents/{document_id}/annotations/{annotation_id}/verify endpoint.""" def test_verify_annotation_success(self, client, app): """Test successfully verifying an annotation.""" document_id = app.state.document_id annotation_id = app.state.annotation_id_1 response = client.post( f"/admin/documents/{document_id}/annotations/{annotation_id}/verify" ) assert response.status_code == 200 data = response.json() assert data["annotation_id"] == annotation_id assert data["is_verified"] is True assert data["verified_at"] is not None assert data["verified_by"] == "test-token" assert "verified successfully" in data["message"].lower() def test_verify_annotation_not_found(self, client, app): """Test verifying non-existent annotation.""" document_id = app.state.document_id fake_annotation_id = str(uuid4()) response = client.post( f"/admin/documents/{document_id}/annotations/{fake_annotation_id}/verify" ) assert response.status_code == 404 assert "not found" in response.json()["detail"].lower() def test_verify_annotation_document_not_found(self, client): """Test verifying annotation with non-existent document.""" fake_document_id = str(uuid4()) fake_annotation_id = str(uuid4()) response = client.post( f"/admin/documents/{fake_document_id}/annotations/{fake_annotation_id}/verify" ) assert response.status_code == 404 assert "not found" in response.json()["detail"].lower() def test_verify_annotation_invalid_uuid(self, client, app): """Test verifying annotation with invalid UUID format.""" document_id = app.state.document_id response = client.post( f"/admin/documents/{document_id}/annotations/invalid-uuid/verify" ) assert response.status_code == 400 assert "invalid" in response.json()["detail"].lower() class TestAnnotationOverride: """Tests for PATCH /admin/documents/{document_id}/annotations/{annotation_id}/override endpoint.""" def test_override_annotation_text_value(self, client, app): """Test overriding annotation text value.""" document_id = app.state.document_id annotation_id = app.state.annotation_id_1 response = client.patch( f"/admin/documents/{document_id}/annotations/{annotation_id}/override", json={ "text_value": "INV-001-CORRECTED", "reason": "OCR error correction" } ) assert response.status_code == 200 data = response.json() assert data["annotation_id"] == annotation_id assert data["source"] == "manual" assert data["override_source"] == "auto" assert "successfully" in data["message"].lower() assert "history_id" in data def test_override_annotation_bbox(self, client, app): """Test overriding annotation bounding box.""" document_id = app.state.document_id annotation_id = app.state.annotation_id_1 response = client.patch( f"/admin/documents/{document_id}/annotations/{annotation_id}/override", json={ "bbox": { "x": 110, "y": 205, "width": 195, "height": 48 }, "reason": "Bbox adjustment" } ) assert response.status_code == 200 data = response.json() assert data["annotation_id"] == annotation_id assert data["source"] == "manual" def test_override_annotation_class(self, client, app): """Test overriding annotation class.""" document_id = app.state.document_id annotation_id = app.state.annotation_id_1 response = client.patch( f"/admin/documents/{document_id}/annotations/{annotation_id}/override", json={ "class_id": 1, "class_name": "invoice_date", "reason": "Wrong field classification" } ) assert response.status_code == 200 data = response.json() assert data["annotation_id"] == annotation_id def test_override_annotation_multiple_fields(self, client, app): """Test overriding multiple annotation fields at once.""" document_id = app.state.document_id annotation_id = app.state.annotation_id_2 response = client.patch( f"/admin/documents/{document_id}/annotations/{annotation_id}/override", json={ "text_value": "1550.00", "bbox": { "x": 120, "y": 210, "width": 180, "height": 45 }, "reason": "Multiple corrections" } ) assert response.status_code == 200 data = response.json() assert data["annotation_id"] == annotation_id def test_override_annotation_no_updates(self, client, app): """Test overriding annotation without providing any updates.""" document_id = app.state.document_id annotation_id = app.state.annotation_id_1 response = client.patch( f"/admin/documents/{document_id}/annotations/{annotation_id}/override", json={} ) assert response.status_code == 400 assert "no updates" in response.json()["detail"].lower() def test_override_annotation_not_found(self, client, app): """Test overriding non-existent annotation.""" document_id = app.state.document_id fake_annotation_id = str(uuid4()) response = client.patch( f"/admin/documents/{document_id}/annotations/{fake_annotation_id}/override", json={ "text_value": "TEST" } ) assert response.status_code == 404 assert "not found" in response.json()["detail"].lower() def test_override_annotation_document_not_found(self, client): """Test overriding annotation with non-existent document.""" fake_document_id = str(uuid4()) fake_annotation_id = str(uuid4()) response = client.patch( f"/admin/documents/{fake_document_id}/annotations/{fake_annotation_id}/override", json={ "text_value": "TEST" } ) assert response.status_code == 404 assert "not found" in response.json()["detail"].lower() def test_override_annotation_creates_history(self, client, app): """Test that overriding annotation creates history record.""" document_id = app.state.document_id annotation_id = app.state.annotation_id_1 response = client.patch( f"/admin/documents/{document_id}/annotations/{annotation_id}/override", json={ "text_value": "INV-CORRECTED", "reason": "Test history creation" } ) assert response.status_code == 200 data = response.json() # History ID should be present and valid assert "history_id" in data assert data["history_id"] != "" def test_override_annotation_with_reason(self, client, app): """Test overriding annotation with change reason.""" document_id = app.state.document_id annotation_id = app.state.annotation_id_1 change_reason = "Correcting OCR misread" response = client.patch( f"/admin/documents/{document_id}/annotations/{annotation_id}/override", json={ "text_value": "INV-002", "reason": change_reason } ) assert response.status_code == 200 # Reason is stored in history, not returned in response data = response.json() assert data["annotation_id"] == annotation_id