Add more tests
This commit is contained in:
350
tests/integration/repositories/test_document_repo_integration.py
Normal file
350
tests/integration/repositories/test_document_repo_integration.py
Normal file
@@ -0,0 +1,350 @@
|
||||
"""
|
||||
Document Repository Integration Tests
|
||||
|
||||
Tests DocumentRepository with real database operations.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlmodel import select
|
||||
|
||||
from inference.data.admin_models import AdminAnnotation, AdminDocument
|
||||
from inference.data.repositories.document_repository import DocumentRepository
|
||||
|
||||
|
||||
def ensure_utc(dt: datetime | None) -> datetime | None:
|
||||
"""Ensure datetime is timezone-aware (UTC).
|
||||
|
||||
PostgreSQL may return offset-naive datetimes. This helper
|
||||
converts them to UTC for proper comparison.
|
||||
"""
|
||||
if dt is None:
|
||||
return None
|
||||
if dt.tzinfo is None:
|
||||
return dt.replace(tzinfo=timezone.utc)
|
||||
return dt
|
||||
|
||||
|
||||
class TestDocumentRepositoryCreate:
    """Integration tests for creating documents through DocumentRepository."""

    def test_create_document(self, patched_session):
        """Create a document with full metadata and read every field back."""
        repository = DocumentRepository()

        new_id = repository.create(
            filename="test_invoice.pdf",
            file_size=2048,
            content_type="application/pdf",
            file_path="/uploads/test_invoice.pdf",
            page_count=2,
            upload_source="api",
            category="invoice",
        )
        assert new_id is not None

        stored = repository.get(new_id)
        assert stored is not None

        # Checked in the same order as the original one-per-line assertions.
        expected_attributes = (
            ("filename", "test_invoice.pdf"),
            ("file_size", 2048),
            ("page_count", 2),
            ("upload_source", "api"),
            ("category", "invoice"),
            ("status", "pending"),
        )
        for attribute, expected in expected_attributes:
            assert getattr(stored, attribute) == expected

    def test_create_document_with_csv_values(self, patched_session):
        """CSV field values passed to create() are returned unchanged by get()."""
        repository = DocumentRepository()
        expected_csv = {
            "invoice_number": "INV-001",
            "amount": "1500.00",
            "supplier_name": "Test Supplier AB",
        }

        new_id = repository.create(
            filename="invoice_with_csv.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/uploads/invoice_with_csv.pdf",
            csv_field_values=expected_csv,
        )

        stored = repository.get(new_id)
        assert stored is not None
        assert stored.csv_field_values == expected_csv

    def test_create_document_with_group_key(self, patched_session):
        """A group key passed to create() is returned unchanged by get()."""
        repository = DocumentRepository()

        new_id = repository.create(
            filename="grouped_doc.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/uploads/grouped_doc.pdf",
            group_key="batch-2024-01",
        )

        stored = repository.get(new_id)
        assert stored is not None
        assert stored.group_key == "batch-2024-01"
|
||||
|
||||
|
||||
class TestDocumentRepositoryRead:
    """Tests for document retrieval.

    The ``multiple_documents`` fixture is defined outside this file; the
    assertions below rely on it providing five documents, two of them with
    status "labeled" and one with category "letter".
    """

    def test_get_nonexistent_document(self, patched_session):
        """Looking up a random, unused id returns None."""
        repo = DocumentRepository()

        doc = repo.get(str(uuid4()))

        assert doc is None

    def test_get_paginated_documents(self, patched_session, multiple_documents):
        """The page holds ``limit`` rows while ``total`` counts every row."""
        repo = DocumentRepository()

        docs, total = repo.get_paginated(limit=2, offset=0)

        assert total == 5
        assert len(docs) == 2

    def test_get_paginated_with_status_filter(self, patched_session, multiple_documents):
        """Filtering by status returns only documents with that status."""
        repo = DocumentRepository()

        docs, total = repo.get_paginated(status="labeled")

        assert total == 2
        # Without this check the loop below passes trivially on an empty page.
        assert docs, "expected at least one labeled document on the page"
        for doc in docs:
            assert doc.status == "labeled"

    def test_get_paginated_with_category_filter(self, patched_session, multiple_documents):
        """Filtering by category returns the single matching document."""
        repo = DocumentRepository()

        docs, total = repo.get_paginated(category="letter")

        assert total == 1
        # Fail with a clear assertion instead of an IndexError on docs[0].
        assert len(docs) == 1
        assert docs[0].category == "letter"

    def test_get_paginated_with_offset(self, patched_session, multiple_documents):
        """Consecutive pages are both populated and share no documents."""
        repo = DocumentRepository()

        docs_page1, _ = repo.get_paginated(limit=2, offset=0)
        docs_page2, _ = repo.get_paginated(limit=2, offset=2)

        # Two empty pages would also be "disjoint", so pin the page sizes
        # first (five fixture documents -> two full pages of two).
        assert len(docs_page1) == 2
        assert len(docs_page2) == 2

        doc_ids_page1 = {str(d.document_id) for d in docs_page1}
        doc_ids_page2 = {str(d.document_id) for d in docs_page2}

        assert doc_ids_page1.isdisjoint(doc_ids_page2)

    def test_get_by_ids(self, patched_session, multiple_documents):
        """Fetching by an explicit id list returns exactly those documents."""
        repo = DocumentRepository()

        ids_to_fetch = [
            str(multiple_documents[0].document_id),
            str(multiple_documents[2].document_id),
        ]
        docs = repo.get_by_ids(ids_to_fetch)

        assert len(docs) == 2
        fetched_ids = {str(d.document_id) for d in docs}
        assert fetched_ids == set(ids_to_fetch)
|
||||
|
||||
|
||||
class TestDocumentRepositoryUpdate:
    """Integration tests for the update methods of DocumentRepository."""

    def test_update_status(self, patched_session, sample_document):
        """update_status() stores both the status and the auto-label status."""
        repository = DocumentRepository()
        document_id = str(sample_document.document_id)

        repository.update_status(
            document_id,
            status="labeled",
            auto_label_status="completed",
        )

        refreshed = repository.get(document_id)
        assert refreshed is not None
        assert refreshed.status == "labeled"
        assert refreshed.auto_label_status == "completed"

    def test_update_status_with_error(self, patched_session, sample_document):
        """update_status() stores a failed auto-label status with its error text."""
        repository = DocumentRepository()
        document_id = str(sample_document.document_id)

        repository.update_status(
            document_id,
            status="pending",
            auto_label_status="failed",
            auto_label_error="OCR extraction failed",
        )

        refreshed = repository.get(document_id)
        assert refreshed is not None
        assert refreshed.auto_label_status == "failed"
        assert refreshed.auto_label_error == "OCR extraction failed"

    def test_update_file_path(self, patched_session, sample_document):
        """update_file_path() replaces the stored file path."""
        repository = DocumentRepository()
        document_id = str(sample_document.document_id)
        archived_path = "/archive/2024/test_invoice.pdf"

        repository.update_file_path(document_id, archived_path)

        refreshed = repository.get(document_id)
        assert refreshed is not None
        assert refreshed.file_path == archived_path

    def test_update_group_key(self, patched_session, sample_document):
        """update_group_key() returns True and stores the new key."""
        repository = DocumentRepository()
        document_id = str(sample_document.document_id)

        updated = repository.update_group_key(document_id, "new-group-key")
        assert updated is True

        refreshed = repository.get(document_id)
        assert refreshed is not None
        assert refreshed.group_key == "new-group-key"

    def test_update_category(self, patched_session, sample_document):
        """update_category() returns the document carrying the new category."""
        repository = DocumentRepository()

        updated_doc = repository.update_category(
            str(sample_document.document_id),
            "letter",
        )

        assert updated_doc is not None
        assert updated_doc.category == "letter"
|
||||
|
||||
|
||||
class TestDocumentRepositoryDelete:
    """Integration tests for DocumentRepository.delete()."""

    def test_delete_document(self, patched_session, sample_document):
        """delete() returns True and the document can no longer be fetched."""
        repository = DocumentRepository()
        document_id = str(sample_document.document_id)

        was_deleted = repository.delete(document_id)
        assert was_deleted is True

        assert repository.get(document_id) is None

    def test_delete_document_with_annotations(self, patched_session, sample_document, sample_annotation):
        """After delete(), no annotations remain for the document."""
        repository = DocumentRepository()
        document_id = str(sample_document.document_id)

        was_deleted = repository.delete(document_id)
        assert was_deleted is True

        # Imported here, as in the original, so only this test depends on it.
        from inference.data.repositories.annotation_repository import AnnotationRepository

        remaining = AnnotationRepository().get_for_document(document_id)
        assert len(remaining) == 0

    def test_delete_nonexistent_document(self, patched_session):
        """delete() returns False for a random, unused id."""
        repository = DocumentRepository()
        unknown_id = str(uuid4())

        was_deleted = repository.delete(unknown_id)

        assert was_deleted is False
|
||||
|
||||
|
||||
class TestDocumentRepositoryQueries:
    """Integration tests for the aggregate and export queries."""

    def test_count_by_status(self, patched_session, multiple_documents):
        """count_by_status() reports the per-status counts of the fixture set."""
        repository = DocumentRepository()

        status_counts = repository.count_by_status()

        expected_counts = (("pending", 2), ("labeled", 2), ("exported", 1))
        for status, expected in expected_counts:
            assert status_counts.get(status) == expected

    def test_get_categories(self, patched_session, multiple_documents):
        """get_categories() includes both categories used by the fixture set."""
        repository = DocumentRepository()

        found_categories = repository.get_categories()

        for expected_category in ("invoice", "letter"):
            assert expected_category in found_categories

    def test_get_labeled_for_export(self, patched_session, multiple_documents):
        """get_labeled_for_export() returns the two labeled documents only."""
        repository = DocumentRepository()

        exportable = repository.get_labeled_for_export()

        assert len(exportable) == 2
        assert all(document.status == "labeled" for document in exportable)
|
||||
|
||||
|
||||
class TestDocumentAnnotationLocking:
    """Tests for annotation locking mechanism.

    Each test that depends on a lock being held asserts that the initial
    acquisition succeeded. Otherwise a broken ``acquire_annotation_lock``
    could make the follow-up checks pass, or fail with an unrelated
    ``AttributeError``/``TypeError``.
    """

    def test_acquire_annotation_lock(self, patched_session, sample_document):
        """Acquiring a lock sets an expiry that lies in the future."""
        repo = DocumentRepository()

        doc = repo.acquire_annotation_lock(
            str(sample_document.document_id),
            duration_seconds=300,
        )

        assert doc is not None
        assert doc.annotation_lock_until is not None
        # The stored value may be offset-naive; normalize before comparing.
        lock_until = ensure_utc(doc.annotation_lock_until)
        assert lock_until > datetime.now(timezone.utc)

    def test_acquire_lock_when_already_locked(self, patched_session, sample_document):
        """A second acquisition returns None while the first lock is held."""
        repo = DocumentRepository()
        doc_id = str(sample_document.document_id)

        # First lock must succeed, or the check below proves nothing.
        first = repo.acquire_annotation_lock(doc_id, duration_seconds=300)
        assert first is not None, "initial lock acquisition failed"

        # Second lock attempt should fail
        result = repo.acquire_annotation_lock(doc_id)
        assert result is None

    def test_release_annotation_lock(self, patched_session, sample_document):
        """Releasing a held lock clears the expiry timestamp."""
        repo = DocumentRepository()
        doc_id = str(sample_document.document_id)

        locked = repo.acquire_annotation_lock(doc_id, duration_seconds=300)
        assert locked is not None, "initial lock acquisition failed"
        assert locked.annotation_lock_until is not None

        doc = repo.release_annotation_lock(doc_id)

        assert doc is not None
        assert doc.annotation_lock_until is None

    def test_extend_annotation_lock(self, patched_session, sample_document):
        """Extending a held lock moves its expiry later."""
        repo = DocumentRepository()
        doc_id = str(sample_document.document_id)

        # Acquire initial lock
        initial_doc = repo.acquire_annotation_lock(doc_id, duration_seconds=300)
        assert initial_doc is not None, "initial lock acquisition failed"
        assert initial_doc.annotation_lock_until is not None
        initial_expiry = ensure_utc(initial_doc.annotation_lock_until)

        # Extend lock
        extended_doc = repo.extend_annotation_lock(doc_id, additional_seconds=300)

        assert extended_doc is not None
        assert extended_doc.annotation_lock_until is not None
        extended_expiry = ensure_utc(extended_doc.annotation_lock_until)
        assert extended_expiry > initial_expiry
|
||||
Reference in New Issue
Block a user