Files
invoice-master-poc-v2/tests/integration/repositories/test_document_repo_integration.py
Yaojia Wang b602d0a340 re-structure
2026-02-01 22:55:31 +01:00

351 lines
11 KiB
Python

"""
Document Repository Integration Tests
Tests DocumentRepository with real database operations.
"""
from datetime import datetime, timezone, timedelta
from uuid import uuid4
import pytest
from sqlmodel import select
from backend.data.admin_models import AdminAnnotation, AdminDocument
from backend.data.repositories.document_repository import DocumentRepository
def ensure_utc(dt: datetime | None) -> datetime | None:
"""Ensure datetime is timezone-aware (UTC).
PostgreSQL may return offset-naive datetimes. This helper
converts them to UTC for proper comparison.
"""
if dt is None:
return None
if dt.tzinfo is None:
return dt.replace(tzinfo=timezone.utc)
return dt
class TestDocumentRepositoryCreate:
    """Creation scenarios for DocumentRepository."""

    def test_create_document(self, patched_session):
        """A created document can be read back with the values it was given."""
        repository = DocumentRepository()

        new_id = repository.create(
            filename="test_invoice.pdf",
            file_size=2048,
            content_type="application/pdf",
            file_path="/uploads/test_invoice.pdf",
            page_count=2,
            upload_source="api",
            category="invoice",
        )
        assert new_id is not None

        stored = repository.get(new_id)
        assert stored is not None

        # Fields passed to create() must round-trip unchanged.
        assert stored.filename == "test_invoice.pdf"
        assert stored.file_size == 2048
        assert stored.page_count == 2
        assert stored.upload_source == "api"
        assert stored.category == "invoice"
        # No status was passed to create(); the stored one is "pending".
        assert stored.status == "pending"

    def test_create_document_with_csv_values(self, patched_session):
        """CSV field values supplied at creation are stored on the document."""
        repository = DocumentRepository()
        expected_values = {
            "invoice_number": "INV-001",
            "amount": "1500.00",
            "supplier_name": "Test Supplier AB",
        }

        new_id = repository.create(
            filename="invoice_with_csv.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/uploads/invoice_with_csv.pdf",
            csv_field_values=expected_values,
        )

        stored = repository.get(new_id)
        assert stored is not None
        assert stored.csv_field_values == expected_values

    def test_create_document_with_group_key(self, patched_session):
        """A group key supplied at creation is stored on the document."""
        repository = DocumentRepository()

        new_id = repository.create(
            filename="grouped_doc.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/uploads/grouped_doc.pdf",
            group_key="batch-2024-01",
        )

        stored = repository.get(new_id)
        assert stored is not None
        assert stored.group_key == "batch-2024-01"
class TestDocumentRepositoryRead:
    """Tests for document retrieval."""

    def test_get_nonexistent_document(self, patched_session):
        """Looking up an unknown ID yields None."""
        repo = DocumentRepository()

        doc = repo.get(str(uuid4()))

        assert doc is None

    def test_get_paginated_documents(self, patched_session, multiple_documents):
        """The page is capped at ``limit`` while ``total`` counts every match."""
        repo = DocumentRepository()

        docs, total = repo.get_paginated(limit=2, offset=0)

        assert total == 5
        assert len(docs) == 2

    def test_get_paginated_with_status_filter(self, patched_session, multiple_documents):
        """Filtering by status returns only documents with that status."""
        repo = DocumentRepository()

        docs, total = repo.get_paginated(status="labeled")

        assert total == 2
        for doc in docs:
            assert doc.status == "labeled"

    def test_get_paginated_with_category_filter(self, patched_session, multiple_documents):
        """Filtering by category returns only documents in that category."""
        repo = DocumentRepository()

        docs, total = repo.get_paginated(category="letter")

        assert total == 1
        # Check the page size first so an empty page fails as an assertion
        # rather than as an IndexError on docs[0].
        assert len(docs) == 1
        assert docs[0].category == "letter"

    def test_get_paginated_with_offset(self, patched_session, multiple_documents):
        """Consecutive pages do not share any documents."""
        repo = DocumentRepository()

        docs_page1, _ = repo.get_paginated(limit=2, offset=0)
        docs_page2, _ = repo.get_paginated(limit=2, offset=2)

        # Two empty pages would trivially be disjoint, so require both pages
        # to be full before checking for overlap.
        assert len(docs_page1) == 2
        assert len(docs_page2) == 2

        doc_ids_page1 = {str(d.document_id) for d in docs_page1}
        doc_ids_page2 = {str(d.document_id) for d in docs_page2}
        assert doc_ids_page1.isdisjoint(doc_ids_page2)

    def test_get_by_ids(self, patched_session, multiple_documents):
        """Fetching by a list of IDs returns exactly those documents."""
        repo = DocumentRepository()
        ids_to_fetch = [
            str(multiple_documents[0].document_id),
            str(multiple_documents[2].document_id),
        ]

        docs = repo.get_by_ids(ids_to_fetch)

        assert len(docs) == 2
        fetched_ids = {str(d.document_id) for d in docs}
        assert fetched_ids == set(ids_to_fetch)
class TestDocumentRepositoryUpdate:
    """Update scenarios for DocumentRepository."""

    def test_update_status(self, patched_session, sample_document):
        """Status and auto-label status are both persisted by update_status."""
        repository = DocumentRepository()
        target_id = str(sample_document.document_id)

        repository.update_status(
            target_id,
            status="labeled",
            auto_label_status="completed",
        )

        refreshed = repository.get(target_id)
        assert refreshed is not None
        assert refreshed.status == "labeled"
        assert refreshed.auto_label_status == "completed"

    def test_update_status_with_error(self, patched_session, sample_document):
        """An auto-label error message is stored next to the failed status."""
        repository = DocumentRepository()
        target_id = str(sample_document.document_id)

        repository.update_status(
            target_id,
            status="pending",
            auto_label_status="failed",
            auto_label_error="OCR extraction failed",
        )

        refreshed = repository.get(target_id)
        assert refreshed is not None
        assert refreshed.auto_label_status == "failed"
        assert refreshed.auto_label_error == "OCR extraction failed"

    def test_update_file_path(self, patched_session, sample_document):
        """The new file path is visible when the document is re-read."""
        repository = DocumentRepository()
        target_id = str(sample_document.document_id)
        relocated_path = "/archive/2024/test_invoice.pdf"

        repository.update_file_path(target_id, relocated_path)

        refreshed = repository.get(target_id)
        assert refreshed is not None
        assert refreshed.file_path == relocated_path

    def test_update_group_key(self, patched_session, sample_document):
        """update_group_key reports success and the new key is persisted."""
        repository = DocumentRepository()
        target_id = str(sample_document.document_id)

        was_updated = repository.update_group_key(target_id, "new-group-key")
        assert was_updated is True

        refreshed = repository.get(target_id)
        assert refreshed is not None
        assert refreshed.group_key == "new-group-key"

    def test_update_category(self, patched_session, sample_document):
        """update_category returns the document carrying the new category."""
        repository = DocumentRepository()
        target_id = str(sample_document.document_id)

        updated = repository.update_category(target_id, "letter")

        assert updated is not None
        assert updated.category == "letter"
class TestDocumentRepositoryDelete:
    """Deletion scenarios for DocumentRepository."""

    def test_delete_document(self, patched_session, sample_document):
        """A deleted document can no longer be fetched."""
        repository = DocumentRepository()
        target_id = str(sample_document.document_id)

        was_deleted = repository.delete(target_id)
        assert was_deleted is True

        assert repository.get(target_id) is None

    def test_delete_document_with_annotations(self, patched_session, sample_document, sample_annotation):
        """Deleting a document leaves no annotations behind for it."""
        repository = DocumentRepository()
        target_id = str(sample_document.document_id)

        was_deleted = repository.delete(target_id)
        assert was_deleted is True

        # Look the annotations up through their own repository to confirm
        # none remain for the deleted document.
        from backend.data.repositories.annotation_repository import AnnotationRepository

        remaining = AnnotationRepository().get_for_document(target_id)
        assert len(remaining) == 0

    def test_delete_nonexistent_document(self, patched_session):
        """Deleting an unknown ID reports False."""
        repository = DocumentRepository()
        unknown_id = str(uuid4())

        was_deleted = repository.delete(unknown_id)

        assert was_deleted is False
class TestDocumentRepositoryQueries:
    """Aggregate and export-oriented queries on DocumentRepository."""

    def test_count_by_status(self, patched_session, multiple_documents):
        """count_by_status reports the expected count for each status."""
        repository = DocumentRepository()

        status_counts = repository.count_by_status()

        assert status_counts.get("pending") == 2
        assert status_counts.get("labeled") == 2
        assert status_counts.get("exported") == 1

    def test_get_categories(self, patched_session, multiple_documents):
        """get_categories includes both the invoice and letter categories."""
        repository = DocumentRepository()

        known_categories = repository.get_categories()

        assert "invoice" in known_categories
        assert "letter" in known_categories

    def test_get_labeled_for_export(self, patched_session, multiple_documents):
        """Only documents whose status is "labeled" are offered for export."""
        repository = DocumentRepository()

        exportable = repository.get_labeled_for_export()

        assert len(exportable) == 2
        for candidate in exportable:
            assert candidate.status == "labeled"
class TestDocumentAnnotationLocking:
    """Tests for annotation locking mechanism."""

    def test_acquire_annotation_lock(self, patched_session, sample_document):
        """Acquiring a lock sets an expiry that lies in the future."""
        repo = DocumentRepository()

        doc = repo.acquire_annotation_lock(
            str(sample_document.document_id),
            duration_seconds=300,
        )

        assert doc is not None
        assert doc.annotation_lock_until is not None
        # Normalize first: the stored value may be offset-naive.
        lock_until = ensure_utc(doc.annotation_lock_until)
        assert lock_until > datetime.now(timezone.utc)

    def test_acquire_lock_when_already_locked(self, patched_session, sample_document):
        """A second acquire on a locked document returns None."""
        repo = DocumentRepository()
        doc_id = str(sample_document.document_id)

        # The first lock must succeed; otherwise the None below would not
        # prove the document was refused because it was already locked.
        first = repo.acquire_annotation_lock(doc_id, duration_seconds=300)
        assert first is not None

        # Second lock attempt should fail
        result = repo.acquire_annotation_lock(doc_id)
        assert result is None

    def test_release_annotation_lock(self, patched_session, sample_document):
        """Releasing a held lock clears the lock expiry."""
        repo = DocumentRepository()
        doc_id = str(sample_document.document_id)

        # Confirm the lock is really held so the release is not checked
        # against a document that was never locked.
        locked = repo.acquire_annotation_lock(doc_id, duration_seconds=300)
        assert locked is not None
        assert locked.annotation_lock_until is not None

        doc = repo.release_annotation_lock(doc_id)

        assert doc is not None
        assert doc.annotation_lock_until is None

    def test_extend_annotation_lock(self, patched_session, sample_document):
        """Extending a held lock moves its expiry later."""
        repo = DocumentRepository()
        doc_id = str(sample_document.document_id)

        # Acquire initial lock
        initial_doc = repo.acquire_annotation_lock(doc_id, duration_seconds=300)
        # Fail with a clear assertion instead of an AttributeError below.
        assert initial_doc is not None
        # Capture the expiry now, before the extend call can change it.
        initial_expiry = ensure_utc(initial_doc.annotation_lock_until)
        assert initial_expiry is not None

        # Extend lock
        extended_doc = repo.extend_annotation_lock(doc_id, additional_seconds=300)

        assert extended_doc is not None
        extended_expiry = ensure_utc(extended_doc.annotation_lock_until)
        assert extended_expiry is not None
        assert extended_expiry > initial_expiry