"""Document Repository Integration Tests.

Tests DocumentRepository with real database operations.
"""
from datetime import datetime, timezone, timedelta
from uuid import uuid4

import pytest
from sqlmodel import select

from inference.data.admin_models import AdminAnnotation, AdminDocument
from inference.data.repositories.document_repository import DocumentRepository

def ensure_utc(dt: datetime | None) -> datetime | None:
|
|
"""Ensure datetime is timezone-aware (UTC).
|
|
|
|
PostgreSQL may return offset-naive datetimes. This helper
|
|
converts them to UTC for proper comparison.
|
|
"""
|
|
if dt is None:
|
|
return None
|
|
if dt.tzinfo is None:
|
|
return dt.replace(tzinfo=timezone.utc)
|
|
return dt
|
|
|
|
|
|
class TestDocumentRepositoryCreate:
    """Tests for document creation."""

    def test_create_document(self, patched_session):
        """Creating a document persists it with the expected field values."""
        repository = DocumentRepository()

        new_id = repository.create(
            filename="test_invoice.pdf",
            file_size=2048,
            content_type="application/pdf",
            file_path="/uploads/test_invoice.pdf",
            page_count=2,
            upload_source="api",
            category="invoice",
        )
        assert new_id is not None

        stored = repository.get(new_id)
        assert stored is not None
        # Every persisted attribute should round-trip; status defaults to pending.
        expected = {
            "filename": "test_invoice.pdf",
            "file_size": 2048,
            "page_count": 2,
            "upload_source": "api",
            "category": "invoice",
            "status": "pending",
        }
        for attr, value in expected.items():
            assert getattr(stored, attr) == value

    def test_create_document_with_csv_values(self, patched_session):
        """CSV field values survive a create/get round trip."""
        repository = DocumentRepository()

        field_values = {
            "invoice_number": "INV-001",
            "amount": "1500.00",
            "supplier_name": "Test Supplier AB",
        }
        new_id = repository.create(
            filename="invoice_with_csv.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/uploads/invoice_with_csv.pdf",
            csv_field_values=field_values,
        )

        stored = repository.get(new_id)
        assert stored is not None
        assert stored.csv_field_values == field_values

    def test_create_document_with_group_key(self, patched_session):
        """A group key supplied at creation is stored on the document."""
        repository = DocumentRepository()

        new_id = repository.create(
            filename="grouped_doc.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/uploads/grouped_doc.pdf",
            group_key="batch-2024-01",
        )

        stored = repository.get(new_id)
        assert stored is not None
        assert stored.group_key == "batch-2024-01"
|
|
|
|
|
|
class TestDocumentRepositoryRead:
    """Tests for document retrieval."""

    def test_get_nonexistent_document(self, patched_session):
        """Looking up an unknown ID yields None rather than raising."""
        repository = DocumentRepository()
        assert repository.get(str(uuid4())) is None

    def test_get_paginated_documents(self, patched_session, multiple_documents):
        """Pagination caps the page size while reporting the full count."""
        repository = DocumentRepository()

        page, count = repository.get_paginated(limit=2, offset=0)

        assert count == 5
        assert len(page) == 2

    def test_get_paginated_with_status_filter(self, patched_session, multiple_documents):
        """The status filter restricts both results and total count."""
        repository = DocumentRepository()

        page, count = repository.get_paginated(status="labeled")

        assert count == 2
        assert all(item.status == "labeled" for item in page)

    def test_get_paginated_with_category_filter(self, patched_session, multiple_documents):
        """The category filter restricts both results and total count."""
        repository = DocumentRepository()

        page, count = repository.get_paginated(category="letter")

        assert count == 1
        assert page[0].category == "letter"

    def test_get_paginated_with_offset(self, patched_session, multiple_documents):
        """Consecutive pages never contain the same document."""
        repository = DocumentRepository()

        first_page, _ = repository.get_paginated(limit=2, offset=0)
        second_page, _ = repository.get_paginated(limit=2, offset=2)

        ids_first = {str(item.document_id) for item in first_page}
        ids_second = {str(item.document_id) for item in second_page}
        assert ids_first.isdisjoint(ids_second)

    def test_get_by_ids(self, patched_session, multiple_documents):
        """Fetching by a list of IDs returns exactly those documents."""
        repository = DocumentRepository()

        wanted = [
            str(multiple_documents[0].document_id),
            str(multiple_documents[2].document_id),
        ]
        found = repository.get_by_ids(wanted)

        assert len(found) == 2
        assert {str(item.document_id) for item in found} == set(wanted)
|
|
|
|
|
|
class TestDocumentRepositoryUpdate:
    """Tests for document updates."""

    def test_update_status(self, patched_session, sample_document):
        """Status and auto-label status can be updated together."""
        repository = DocumentRepository()
        doc_id = str(sample_document.document_id)

        repository.update_status(doc_id, status="labeled", auto_label_status="completed")

        refreshed = repository.get(doc_id)
        assert refreshed is not None
        assert refreshed.status == "labeled"
        assert refreshed.auto_label_status == "completed"

    def test_update_status_with_error(self, patched_session, sample_document):
        """An auto-label failure records its error message."""
        repository = DocumentRepository()
        doc_id = str(sample_document.document_id)

        repository.update_status(
            doc_id,
            status="pending",
            auto_label_status="failed",
            auto_label_error="OCR extraction failed",
        )

        refreshed = repository.get(doc_id)
        assert refreshed is not None
        assert refreshed.auto_label_status == "failed"
        assert refreshed.auto_label_error == "OCR extraction failed"

    def test_update_file_path(self, patched_session, sample_document):
        """The stored file path can be repointed, e.g. after archival."""
        repository = DocumentRepository()
        doc_id = str(sample_document.document_id)

        archived_path = "/archive/2024/test_invoice.pdf"
        repository.update_file_path(doc_id, archived_path)

        refreshed = repository.get(doc_id)
        assert refreshed is not None
        assert refreshed.file_path == archived_path

    def test_update_group_key(self, patched_session, sample_document):
        """Updating the group key reports success and persists the value."""
        repository = DocumentRepository()
        doc_id = str(sample_document.document_id)

        assert repository.update_group_key(doc_id, "new-group-key") is True

        refreshed = repository.get(doc_id)
        assert refreshed is not None
        assert refreshed.group_key == "new-group-key"

    def test_update_category(self, patched_session, sample_document):
        """update_category returns the modified document."""
        repository = DocumentRepository()

        updated = repository.update_category(str(sample_document.document_id), "letter")

        assert updated is not None
        assert updated.category == "letter"
|
|
|
|
|
|
class TestDocumentRepositoryDelete:
    """Tests for document deletion."""

    def test_delete_document(self, patched_session, sample_document):
        """Deleting a document makes subsequent lookups return None."""
        repository = DocumentRepository()
        doc_id = str(sample_document.document_id)

        assert repository.delete(doc_id) is True
        assert repository.get(doc_id) is None

    def test_delete_document_with_annotations(self, patched_session, sample_document, sample_annotation):
        """Deleting a document cascades to its annotations."""
        repository = DocumentRepository()
        doc_id = str(sample_document.document_id)

        assert repository.delete(doc_id) is True

        # Confirm the cascade through the annotation repository.
        from inference.data.repositories.annotation_repository import AnnotationRepository

        assert AnnotationRepository().get_for_document(doc_id) == []

    def test_delete_nonexistent_document(self, patched_session):
        """Deleting an unknown ID reports failure instead of raising."""
        repository = DocumentRepository()
        assert repository.delete(str(uuid4())) is False
|
|
|
|
|
|
class TestDocumentRepositoryQueries:
    """Tests for complex document queries."""

    def test_count_by_status(self, patched_session, multiple_documents):
        """Per-status counts match the seeded fixture distribution."""
        repository = DocumentRepository()

        counts = repository.count_by_status()

        # Fixture seeds 2 pending, 2 labeled, 1 exported document.
        for status, expected in (("pending", 2), ("labeled", 2), ("exported", 1)):
            assert counts.get(status) == expected

    def test_get_categories(self, patched_session, multiple_documents):
        """Distinct categories include all values present in the fixtures."""
        repository = DocumentRepository()

        categories = repository.get_categories()

        assert {"invoice", "letter"} <= set(categories)

    def test_get_labeled_for_export(self, patched_session, multiple_documents):
        """Export candidates are exactly the labeled documents."""
        repository = DocumentRepository()

        candidates = repository.get_labeled_for_export()

        assert len(candidates) == 2
        assert all(item.status == "labeled" for item in candidates)
|
|
|
|
|
|
class TestDocumentAnnotationLocking:
    """Tests for annotation locking mechanism."""

    def test_acquire_annotation_lock(self, patched_session, sample_document):
        """Acquiring a lock sets an expiry in the future."""
        repository = DocumentRepository()

        locked = repository.acquire_annotation_lock(
            str(sample_document.document_id),
            duration_seconds=300,
        )

        assert locked is not None
        assert locked.annotation_lock_until is not None
        # Normalize in case the database returned a naive datetime.
        assert ensure_utc(locked.annotation_lock_until) > datetime.now(timezone.utc)

    def test_acquire_lock_when_already_locked(self, patched_session, sample_document):
        """A second acquisition attempt on a held lock returns None."""
        repository = DocumentRepository()
        doc_id = str(sample_document.document_id)

        repository.acquire_annotation_lock(doc_id, duration_seconds=300)

        assert repository.acquire_annotation_lock(doc_id) is None

    def test_release_annotation_lock(self, patched_session, sample_document):
        """Releasing a held lock clears its expiry."""
        repository = DocumentRepository()
        doc_id = str(sample_document.document_id)

        repository.acquire_annotation_lock(doc_id, duration_seconds=300)
        released = repository.release_annotation_lock(doc_id)

        assert released is not None
        assert released.annotation_lock_until is None

    def test_extend_annotation_lock(self, patched_session, sample_document):
        """Extending a lock pushes its expiry further into the future."""
        repository = DocumentRepository()
        doc_id = str(sample_document.document_id)

        before = repository.acquire_annotation_lock(doc_id, duration_seconds=300)
        after = repository.extend_annotation_lock(doc_id, additional_seconds=300)

        assert after is not None
        assert ensure_utc(after.annotation_lock_until) > ensure_utc(before.annotation_lock_until)
|