"""
Document Repository Integration Tests

Tests DocumentRepository with real database operations.

The fixtures used here (``patched_session``, ``sample_document``,
``sample_annotation``, ``multiple_documents``) are not defined in this
module; presumably they come from a ``conftest.py`` — verify there.
"""

from datetime import datetime, timezone, timedelta
from uuid import uuid4

import pytest
from sqlmodel import select

from backend.data.admin_models import AdminAnnotation, AdminDocument
from backend.data.repositories.document_repository import DocumentRepository


def ensure_utc(dt: datetime | None) -> datetime | None:
    """Return *dt* as a timezone-aware UTC datetime.

    PostgreSQL may return offset-naive datetimes. Naive values are treated
    as already being in UTC and are tagged as such; aware values in any
    other zone are converted to UTC. Either way the result can be compared
    safely with ``datetime.now(timezone.utc)``.

    Args:
        dt: The datetime to normalize, or ``None``.

    Returns:
        ``None`` when *dt* is ``None``; otherwise an aware datetime in UTC
        representing the same instant.
    """
    if dt is None:
        return None
    if dt.tzinfo is None:
        # Naive: assume the stored value is UTC and attach tzinfo without shifting.
        return dt.replace(tzinfo=timezone.utc)
    # Aware: convert so the result really is UTC, as the name promises.
    return dt.astimezone(timezone.utc)


class TestDocumentRepositoryCreate:
    """Tests for document creation."""

    def test_create_document(self, patched_session):
        """Test creating a document and retrieving it."""
        repo = DocumentRepository()

        doc_id = repo.create(
            filename="test_invoice.pdf",
            file_size=2048,
            content_type="application/pdf",
            file_path="/uploads/test_invoice.pdf",
            page_count=2,
            upload_source="api",
            category="invoice",
        )

        assert doc_id is not None

        doc = repo.get(doc_id)
        assert doc is not None
        assert doc.filename == "test_invoice.pdf"
        assert doc.file_size == 2048
        assert doc.page_count == 2
        assert doc.upload_source == "api"
        assert doc.category == "invoice"
        # Newly created documents start in the "pending" status.
        assert doc.status == "pending"

    def test_create_document_with_csv_values(self, patched_session):
        """Test creating document with CSV field values."""
        repo = DocumentRepository()
        csv_values = {
            "invoice_number": "INV-001",
            "amount": "1500.00",
            "supplier_name": "Test Supplier AB",
        }

        doc_id = repo.create(
            filename="invoice_with_csv.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/uploads/invoice_with_csv.pdf",
            csv_field_values=csv_values,
        )

        doc = repo.get(doc_id)
        assert doc is not None
        assert doc.csv_field_values == csv_values

    def test_create_document_with_group_key(self, patched_session):
        """Test creating document with group key."""
        repo = DocumentRepository()

        doc_id = repo.create(
            filename="grouped_doc.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/uploads/grouped_doc.pdf",
            group_key="batch-2024-01",
        )

        doc = repo.get(doc_id)
        assert doc is not None
        assert doc.group_key == "batch-2024-01"


class TestDocumentRepositoryRead:
    """Tests for document retrieval.

    The assertions below expect ``multiple_documents`` to hold 5 documents,
    2 of which have status "labeled" and exactly 1 of which has category
    "letter".
    """

    def test_get_nonexistent_document(self, patched_session):
        """Test getting a document that doesn't exist."""
        repo = DocumentRepository()
        doc = repo.get(str(uuid4()))
        assert doc is None

    def test_get_paginated_documents(self, patched_session, multiple_documents):
        """Test paginated document listing."""
        repo = DocumentRepository()

        docs, total = repo.get_paginated(limit=2, offset=0)

        # ``total`` counts all matches; ``docs`` is only the requested page.
        assert total == 5
        assert len(docs) == 2

    def test_get_paginated_with_status_filter(self, patched_session, multiple_documents):
        """Test filtering documents by status."""
        repo = DocumentRepository()

        docs, total = repo.get_paginated(status="labeled")

        assert total == 2
        # Check the page size too, otherwise the loop below could pass on an empty list.
        assert len(docs) == 2
        for doc in docs:
            assert doc.status == "labeled"

    def test_get_paginated_with_category_filter(self, patched_session, multiple_documents):
        """Test filtering documents by category."""
        repo = DocumentRepository()

        docs, total = repo.get_paginated(category="letter")

        assert total == 1
        assert len(docs) == 1
        assert docs[0].category == "letter"

    def test_get_paginated_with_offset(self, patched_session, multiple_documents):
        """Test pagination offset."""
        repo = DocumentRepository()

        docs_page1, _ = repo.get_paginated(limit=2, offset=0)
        docs_page2, _ = repo.get_paginated(limit=2, offset=2)

        # With 5 documents, both pages must be full. Otherwise the
        # disjointness check below would pass trivially on an empty page.
        assert len(docs_page1) == 2
        assert len(docs_page2) == 2

        doc_ids_page1 = {str(d.document_id) for d in docs_page1}
        doc_ids_page2 = {str(d.document_id) for d in docs_page2}

        assert len(doc_ids_page1 & doc_ids_page2) == 0

    def test_get_by_ids(self, patched_session, multiple_documents):
        """Test getting multiple documents by IDs."""
        repo = DocumentRepository()

        ids_to_fetch = [
            str(multiple_documents[0].document_id),
            str(multiple_documents[2].document_id),
        ]
        docs = repo.get_by_ids(ids_to_fetch)

        assert len(docs) == 2
        fetched_ids = {str(d.document_id) for d in docs}
        assert fetched_ids == set(ids_to_fetch)


class TestDocumentRepositoryUpdate:
    """Tests for document updates."""

    def test_update_status(self, patched_session, sample_document):
        """Test updating document status."""
        repo = DocumentRepository()

        repo.update_status(
            str(sample_document.document_id),
            status="labeled",
            auto_label_status="completed",
        )

        doc = repo.get(str(sample_document.document_id))
        assert doc is not None
        assert doc.status == "labeled"
        assert doc.auto_label_status == "completed"

    def test_update_status_with_error(self, patched_session, sample_document):
        """Test updating document status with error message."""
        repo = DocumentRepository()

        repo.update_status(
            str(sample_document.document_id),
            status="pending",
            auto_label_status="failed",
            auto_label_error="OCR extraction failed",
        )

        doc = repo.get(str(sample_document.document_id))
        assert doc is not None
        assert doc.auto_label_status == "failed"
        assert doc.auto_label_error == "OCR extraction failed"

    def test_update_file_path(self, patched_session, sample_document):
        """Test updating document file path."""
        repo = DocumentRepository()

        new_path = "/archive/2024/test_invoice.pdf"
        repo.update_file_path(str(sample_document.document_id), new_path)

        doc = repo.get(str(sample_document.document_id))
        assert doc is not None
        assert doc.file_path == new_path

    def test_update_group_key(self, patched_session, sample_document):
        """Test updating document group key."""
        repo = DocumentRepository()

        result = repo.update_group_key(str(sample_document.document_id), "new-group-key")
        assert result is True

        doc = repo.get(str(sample_document.document_id))
        assert doc is not None
        assert doc.group_key == "new-group-key"

    def test_update_category(self, patched_session, sample_document):
        """Test updating document category."""
        repo = DocumentRepository()

        doc = repo.update_category(str(sample_document.document_id), "letter")

        assert doc is not None
        assert doc.category == "letter"

        # Re-read, as the other update tests do, to confirm the change was persisted
        # and not just applied to the returned object.
        persisted = repo.get(str(sample_document.document_id))
        assert persisted is not None
        assert persisted.category == "letter"


class TestDocumentRepositoryDelete:
    """Tests for document deletion."""

    def test_delete_document(self, patched_session, sample_document):
        """Test deleting a document."""
        repo = DocumentRepository()

        result = repo.delete(str(sample_document.document_id))
        assert result is True

        doc = repo.get(str(sample_document.document_id))
        assert doc is None

    def test_delete_document_with_annotations(self, patched_session, sample_document, sample_annotation):
        """Test deleting document also deletes its annotations."""
        # Local import, as in the original test; kept out of module scope.
        from backend.data.repositories.annotation_repository import AnnotationRepository

        repo = DocumentRepository()
        ann_repo = AnnotationRepository()
        doc_id = str(sample_document.document_id)

        # Precondition: the document really has an annotation. Without this,
        # the "no annotations remain" check below would pass vacuously.
        assert len(ann_repo.get_for_document(doc_id)) >= 1

        result = repo.delete(doc_id)
        assert result is True

        # Verify annotation is also deleted
        annotations = ann_repo.get_for_document(doc_id)
        assert len(annotations) == 0

    def test_delete_nonexistent_document(self, patched_session):
        """Test deleting a document that doesn't exist."""
        repo = DocumentRepository()
        result = repo.delete(str(uuid4()))
        assert result is False


class TestDocumentRepositoryQueries:
    """Tests for complex document queries.

    The assertions below expect ``multiple_documents`` to contain 2 "pending",
    2 "labeled" and 1 "exported" document, with categories that include both
    "invoice" and "letter".
    """

    def test_count_by_status(self, patched_session, multiple_documents):
        """Test counting documents by status."""
        repo = DocumentRepository()

        counts = repo.count_by_status()

        assert counts.get("pending") == 2
        assert counts.get("labeled") == 2
        assert counts.get("exported") == 1

    def test_get_categories(self, patched_session, multiple_documents):
        """Test getting unique categories."""
        repo = DocumentRepository()

        categories = repo.get_categories()

        assert "invoice" in categories
        assert "letter" in categories

    def test_get_labeled_for_export(self, patched_session, multiple_documents):
        """Test getting labeled documents for export."""
        repo = DocumentRepository()

        docs = repo.get_labeled_for_export()

        assert len(docs) == 2
        for doc in docs:
            assert doc.status == "labeled"


class TestDocumentAnnotationLocking:
    """Tests for annotation locking mechanism."""

    def test_acquire_annotation_lock(self, patched_session, sample_document):
        """Test acquiring annotation lock."""
        repo = DocumentRepository()

        doc = repo.acquire_annotation_lock(
            str(sample_document.document_id),
            duration_seconds=300,
        )

        assert doc is not None
        assert doc.annotation_lock_until is not None
        # Normalize: the DB may hand back a naive datetime.
        lock_until = ensure_utc(doc.annotation_lock_until)
        assert lock_until > datetime.now(timezone.utc)

    def test_acquire_lock_when_already_locked(self, patched_session, sample_document):
        """Test acquiring lock fails when already locked."""
        repo = DocumentRepository()

        # First lock must succeed. Otherwise the failure asserted below would
        # not prove that an existing lock blocks a second acquirer.
        first = repo.acquire_annotation_lock(str(sample_document.document_id), duration_seconds=300)
        assert first is not None

        # Second lock attempt should fail
        result = repo.acquire_annotation_lock(str(sample_document.document_id))
        assert result is None

    def test_release_annotation_lock(self, patched_session, sample_document):
        """Test releasing annotation lock."""
        repo = DocumentRepository()

        # Confirm a lock is actually held before testing that release clears it.
        locked = repo.acquire_annotation_lock(str(sample_document.document_id), duration_seconds=300)
        assert locked is not None
        assert locked.annotation_lock_until is not None

        doc = repo.release_annotation_lock(str(sample_document.document_id))

        assert doc is not None
        assert doc.annotation_lock_until is None

    def test_extend_annotation_lock(self, patched_session, sample_document):
        """Test extending annotation lock."""
        repo = DocumentRepository()

        # Acquire initial lock
        initial_doc = repo.acquire_annotation_lock(
            str(sample_document.document_id),
            duration_seconds=300,
        )
        assert initial_doc is not None
        initial_expiry = ensure_utc(initial_doc.annotation_lock_until)
        assert initial_expiry is not None

        # Extend lock
        extended_doc = repo.extend_annotation_lock(
            str(sample_document.document_id),
            additional_seconds=300,
        )

        assert extended_doc is not None
        extended_expiry = ensure_utc(extended_doc.annotation_lock_until)
        # Explicit None check: comparing None with a datetime below would raise a
        # TypeError instead of giving a clear assertion failure.
        assert extended_expiry is not None
        assert extended_expiry > initial_expiry