Add more tests

This commit is contained in:
Yaojia Wang
2026-02-01 22:40:41 +01:00
parent a564ac9d70
commit 400b12a967
55 changed files with 9306 additions and 267 deletions

View File

@@ -0,0 +1,350 @@
"""
Document Repository Integration Tests
Tests DocumentRepository with real database operations.
"""
from datetime import datetime, timezone, timedelta
from uuid import uuid4
import pytest
from sqlmodel import select
from inference.data.admin_models import AdminAnnotation, AdminDocument
from inference.data.repositories.document_repository import DocumentRepository
def ensure_utc(dt: datetime | None) -> datetime | None:
"""Ensure datetime is timezone-aware (UTC).
PostgreSQL may return offset-naive datetimes. This helper
converts them to UTC for proper comparison.
"""
if dt is None:
return None
if dt.tzinfo is None:
return dt.replace(tzinfo=timezone.utc)
return dt
class TestDocumentRepositoryCreate:
"""Tests for document creation."""
def test_create_document(self, patched_session):
"""Test creating a document and retrieving it."""
repo = DocumentRepository()
doc_id = repo.create(
filename="test_invoice.pdf",
file_size=2048,
content_type="application/pdf",
file_path="/uploads/test_invoice.pdf",
page_count=2,
upload_source="api",
category="invoice",
)
assert doc_id is not None
doc = repo.get(doc_id)
assert doc is not None
assert doc.filename == "test_invoice.pdf"
assert doc.file_size == 2048
assert doc.page_count == 2
assert doc.upload_source == "api"
assert doc.category == "invoice"
assert doc.status == "pending"
def test_create_document_with_csv_values(self, patched_session):
"""Test creating document with CSV field values."""
repo = DocumentRepository()
csv_values = {
"invoice_number": "INV-001",
"amount": "1500.00",
"supplier_name": "Test Supplier AB",
}
doc_id = repo.create(
filename="invoice_with_csv.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/uploads/invoice_with_csv.pdf",
csv_field_values=csv_values,
)
doc = repo.get(doc_id)
assert doc is not None
assert doc.csv_field_values == csv_values
def test_create_document_with_group_key(self, patched_session):
"""Test creating document with group key."""
repo = DocumentRepository()
doc_id = repo.create(
filename="grouped_doc.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/uploads/grouped_doc.pdf",
group_key="batch-2024-01",
)
doc = repo.get(doc_id)
assert doc is not None
assert doc.group_key == "batch-2024-01"
class TestDocumentRepositoryRead:
"""Tests for document retrieval."""
def test_get_nonexistent_document(self, patched_session):
"""Test getting a document that doesn't exist."""
repo = DocumentRepository()
doc = repo.get(str(uuid4()))
assert doc is None
def test_get_paginated_documents(self, patched_session, multiple_documents):
"""Test paginated document listing."""
repo = DocumentRepository()
docs, total = repo.get_paginated(limit=2, offset=0)
assert total == 5
assert len(docs) == 2
def test_get_paginated_with_status_filter(self, patched_session, multiple_documents):
"""Test filtering documents by status."""
repo = DocumentRepository()
docs, total = repo.get_paginated(status="labeled")
assert total == 2
for doc in docs:
assert doc.status == "labeled"
def test_get_paginated_with_category_filter(self, patched_session, multiple_documents):
"""Test filtering documents by category."""
repo = DocumentRepository()
docs, total = repo.get_paginated(category="letter")
assert total == 1
assert docs[0].category == "letter"
def test_get_paginated_with_offset(self, patched_session, multiple_documents):
"""Test pagination offset."""
repo = DocumentRepository()
docs_page1, _ = repo.get_paginated(limit=2, offset=0)
docs_page2, _ = repo.get_paginated(limit=2, offset=2)
doc_ids_page1 = {str(d.document_id) for d in docs_page1}
doc_ids_page2 = {str(d.document_id) for d in docs_page2}
assert len(doc_ids_page1 & doc_ids_page2) == 0
def test_get_by_ids(self, patched_session, multiple_documents):
"""Test getting multiple documents by IDs."""
repo = DocumentRepository()
ids_to_fetch = [str(multiple_documents[0].document_id), str(multiple_documents[2].document_id)]
docs = repo.get_by_ids(ids_to_fetch)
assert len(docs) == 2
fetched_ids = {str(d.document_id) for d in docs}
assert fetched_ids == set(ids_to_fetch)
class TestDocumentRepositoryUpdate:
"""Tests for document updates."""
def test_update_status(self, patched_session, sample_document):
"""Test updating document status."""
repo = DocumentRepository()
repo.update_status(
str(sample_document.document_id),
status="labeled",
auto_label_status="completed",
)
doc = repo.get(str(sample_document.document_id))
assert doc is not None
assert doc.status == "labeled"
assert doc.auto_label_status == "completed"
def test_update_status_with_error(self, patched_session, sample_document):
"""Test updating document status with error message."""
repo = DocumentRepository()
repo.update_status(
str(sample_document.document_id),
status="pending",
auto_label_status="failed",
auto_label_error="OCR extraction failed",
)
doc = repo.get(str(sample_document.document_id))
assert doc is not None
assert doc.auto_label_status == "failed"
assert doc.auto_label_error == "OCR extraction failed"
def test_update_file_path(self, patched_session, sample_document):
"""Test updating document file path."""
repo = DocumentRepository()
new_path = "/archive/2024/test_invoice.pdf"
repo.update_file_path(str(sample_document.document_id), new_path)
doc = repo.get(str(sample_document.document_id))
assert doc is not None
assert doc.file_path == new_path
def test_update_group_key(self, patched_session, sample_document):
"""Test updating document group key."""
repo = DocumentRepository()
result = repo.update_group_key(str(sample_document.document_id), "new-group-key")
assert result is True
doc = repo.get(str(sample_document.document_id))
assert doc is not None
assert doc.group_key == "new-group-key"
def test_update_category(self, patched_session, sample_document):
"""Test updating document category."""
repo = DocumentRepository()
doc = repo.update_category(str(sample_document.document_id), "letter")
assert doc is not None
assert doc.category == "letter"
class TestDocumentRepositoryDelete:
"""Tests for document deletion."""
def test_delete_document(self, patched_session, sample_document):
"""Test deleting a document."""
repo = DocumentRepository()
result = repo.delete(str(sample_document.document_id))
assert result is True
doc = repo.get(str(sample_document.document_id))
assert doc is None
def test_delete_document_with_annotations(self, patched_session, sample_document, sample_annotation):
"""Test deleting document also deletes its annotations."""
repo = DocumentRepository()
result = repo.delete(str(sample_document.document_id))
assert result is True
# Verify annotation is also deleted
from inference.data.repositories.annotation_repository import AnnotationRepository
ann_repo = AnnotationRepository()
annotations = ann_repo.get_for_document(str(sample_document.document_id))
assert len(annotations) == 0
def test_delete_nonexistent_document(self, patched_session):
"""Test deleting a document that doesn't exist."""
repo = DocumentRepository()
result = repo.delete(str(uuid4()))
assert result is False
class TestDocumentRepositoryQueries:
"""Tests for complex document queries."""
def test_count_by_status(self, patched_session, multiple_documents):
"""Test counting documents by status."""
repo = DocumentRepository()
counts = repo.count_by_status()
assert counts.get("pending") == 2
assert counts.get("labeled") == 2
assert counts.get("exported") == 1
def test_get_categories(self, patched_session, multiple_documents):
"""Test getting unique categories."""
repo = DocumentRepository()
categories = repo.get_categories()
assert "invoice" in categories
assert "letter" in categories
def test_get_labeled_for_export(self, patched_session, multiple_documents):
"""Test getting labeled documents for export."""
repo = DocumentRepository()
docs = repo.get_labeled_for_export()
assert len(docs) == 2
for doc in docs:
assert doc.status == "labeled"
class TestDocumentAnnotationLocking:
"""Tests for annotation locking mechanism."""
def test_acquire_annotation_lock(self, patched_session, sample_document):
"""Test acquiring annotation lock."""
repo = DocumentRepository()
doc = repo.acquire_annotation_lock(
str(sample_document.document_id),
duration_seconds=300,
)
assert doc is not None
assert doc.annotation_lock_until is not None
lock_until = ensure_utc(doc.annotation_lock_until)
assert lock_until > datetime.now(timezone.utc)
def test_acquire_lock_when_already_locked(self, patched_session, sample_document):
"""Test acquiring lock fails when already locked."""
repo = DocumentRepository()
# First lock
repo.acquire_annotation_lock(str(sample_document.document_id), duration_seconds=300)
# Second lock attempt should fail
result = repo.acquire_annotation_lock(str(sample_document.document_id))
assert result is None
def test_release_annotation_lock(self, patched_session, sample_document):
"""Test releasing annotation lock."""
repo = DocumentRepository()
repo.acquire_annotation_lock(str(sample_document.document_id), duration_seconds=300)
doc = repo.release_annotation_lock(str(sample_document.document_id))
assert doc is not None
assert doc.annotation_lock_until is None
def test_extend_annotation_lock(self, patched_session, sample_document):
"""Test extending annotation lock."""
repo = DocumentRepository()
# Acquire initial lock
initial_doc = repo.acquire_annotation_lock(
str(sample_document.document_id),
duration_seconds=300,
)
initial_expiry = ensure_utc(initial_doc.annotation_lock_until)
# Extend lock
extended_doc = repo.extend_annotation_lock(
str(sample_document.document_id),
additional_seconds=300,
)
assert extended_doc is not None
extended_expiry = ensure_utc(extended_doc.annotation_lock_until)
assert extended_expiry > initial_expiry