Files
invoice-master-poc-v2/tests/data/repositories/test_document_repository.py
Yaojia Wang b602d0a340 re-structure
2026-02-01 22:55:31 +01:00

749 lines
35 KiB
Python

"""
Tests for DocumentRepository
Comprehensive TDD tests for document management - targeting 100% coverage.
"""
import pytest
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch
from uuid import uuid4
from backend.data.admin_models import AdminDocument, AdminAnnotation
from backend.data.repositories.document_repository import DocumentRepository
class TestDocumentRepository:
    """Tests for DocumentRepository.

    No database is touched: the ``mock_session`` fixture patches
    ``get_session_context`` in the repository module, so every repository call
    receives a ``MagicMock`` session whose return values each test configures
    before invoking the repository.
    """

    # Patch target for the session factory looked up inside the repository module.
    _SESSION_CONTEXT = "backend.data.repositories.document_repository.get_session_context"

    # ==========================================================================
    # Helpers & fixtures
    # ==========================================================================

    @staticmethod
    def _make_document(**overrides) -> AdminDocument:
        """Build an AdminDocument with test defaults; keyword args override fields.

        ``created_at`` and ``updated_at`` share a single timestamp so the two
        fields are consistent with each other.
        """
        now = datetime.now(timezone.utc)
        defaults = {
            "document_id": uuid4(),
            "filename": "test.pdf",
            "file_size": 1024,
            "content_type": "application/pdf",
            "file_path": "/tmp/test.pdf",
            "page_count": 1,
            "status": "pending",
            "category": "invoice",
            "created_at": now,
            "updated_at": now,
        }
        return AdminDocument(**{**defaults, **overrides})

    @pytest.fixture
    def mock_session(self):
        """Patch get_session_context and yield the MagicMock session it provides.

        The patch stays active for the rest of the test, so configure the
        session's return values before calling the repository.
        """
        with patch(self._SESSION_CONTEXT) as mock_ctx:
            session = MagicMock()
            mock_ctx.return_value.__enter__ = MagicMock(return_value=session)
            # A falsy __exit__ result means exceptions raised inside the
            # repository propagate to the test instead of being suppressed.
            mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
            yield session

    @pytest.fixture
    def sample_document(self) -> AdminDocument:
        """Create a sample pending invoice document for testing."""
        return self._make_document()

    @pytest.fixture
    def labeled_document(self) -> AdminDocument:
        """Create a labeled document for testing."""
        return self._make_document(
            filename="labeled.pdf",
            file_size=2048,
            file_path="/tmp/labeled.pdf",
            page_count=2,
            status="labeled",
        )

    @pytest.fixture
    def locked_document(self) -> AdminDocument:
        """Create a document whose annotation lock expires 5 minutes from now."""
        doc = self._make_document(filename="locked.pdf", file_path="/tmp/locked.pdf")
        doc.annotation_lock_until = datetime.now(timezone.utc) + timedelta(minutes=5)
        return doc

    @pytest.fixture
    def expired_lock_document(self) -> AdminDocument:
        """Create a document whose annotation lock expired 5 minutes ago."""
        doc = self._make_document(
            filename="expired_lock.pdf",
            file_path="/tmp/expired_lock.pdf",
        )
        doc.annotation_lock_until = datetime.now(timezone.utc) - timedelta(minutes=5)
        return doc

    @pytest.fixture
    def repo(self) -> DocumentRepository:
        """Create a DocumentRepository instance."""
        return DocumentRepository()

    # ==========================================================================
    # create() tests
    # ==========================================================================

    def test_create_returns_document_id(self, repo, mock_session):
        """Test create returns document ID."""
        result = repo.create(
            filename="test.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/tmp/test.pdf",
        )

        assert result is not None
        mock_session.add.assert_called_once()
        mock_session.flush.assert_called_once()

    def test_create_with_all_parameters(self, repo, mock_session):
        """Test create with all optional parameters."""
        result = repo.create(
            filename="test.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/tmp/test.pdf",
            page_count=5,
            upload_source="api",
            csv_field_values={"InvoiceNumber": "INV-001"},
            group_key="batch-001",
            category="receipt",
            admin_token="token-123",
        )

        assert result is not None
        added_doc = mock_session.add.call_args[0][0]
        assert added_doc.page_count == 5
        assert added_doc.upload_source == "api"
        assert added_doc.csv_field_values == {"InvoiceNumber": "INV-001"}
        assert added_doc.group_key == "batch-001"
        assert added_doc.category == "receipt"

    # ==========================================================================
    # get() tests
    # ==========================================================================

    def test_get_returns_document(self, repo, sample_document, mock_session):
        """Test get returns document when exists."""
        mock_session.get.return_value = sample_document

        result = repo.get(str(sample_document.document_id))

        assert result is not None
        assert result.filename == "test.pdf"
        mock_session.expunge.assert_called_once()

    def test_get_returns_none_when_not_found(self, repo, mock_session):
        """Test get returns None when document not found."""
        mock_session.get.return_value = None

        result = repo.get(str(uuid4()))

        assert result is None

    # ==========================================================================
    # get_by_token() tests
    # ==========================================================================

    def test_get_by_token_delegates_to_get(self, repo, sample_document):
        """Test get_by_token delegates to get method."""
        with patch.object(repo, "get", return_value=sample_document) as mock_get:
            result = repo.get_by_token(str(sample_document.document_id), "token-123")

        assert result == sample_document
        mock_get.assert_called_once_with(str(sample_document.document_id))

    # ==========================================================================
    # get_paginated() tests
    # NOTE: session.exec() is mocked, so the filter tests below confirm each
    # filter argument is accepted and the (rows, total) plumbing works; they
    # do not inspect the generated SQL.
    # ==========================================================================

    def test_get_paginated_no_filters(self, repo, sample_document, mock_session):
        """Test get_paginated with no filters."""
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [sample_document]

        results, total = repo.get_paginated()

        assert total == 1
        assert len(results) == 1

    def test_get_paginated_with_status_filter(self, repo, sample_document, mock_session):
        """Test get_paginated with status filter."""
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [sample_document]

        results, total = repo.get_paginated(status="pending")

        assert total == 1
        assert len(results) == 1

    def test_get_paginated_with_upload_source_filter(self, repo, sample_document, mock_session):
        """Test get_paginated with upload_source filter."""
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [sample_document]

        results, total = repo.get_paginated(upload_source="ui")

        assert total == 1
        assert len(results) == 1

    def test_get_paginated_with_auto_label_status_filter(self, repo, sample_document, mock_session):
        """Test get_paginated with auto_label_status filter."""
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [sample_document]

        results, total = repo.get_paginated(auto_label_status="completed")

        assert total == 1
        assert len(results) == 1

    def test_get_paginated_with_batch_id_filter(self, repo, sample_document, mock_session):
        """Test get_paginated with batch_id filter."""
        batch_id = str(uuid4())
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [sample_document]

        results, total = repo.get_paginated(batch_id=batch_id)

        assert total == 1
        assert len(results) == 1

    def test_get_paginated_with_category_filter(self, repo, sample_document, mock_session):
        """Test get_paginated with category filter."""
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [sample_document]

        results, total = repo.get_paginated(category="invoice")

        assert total == 1
        assert len(results) == 1

    def test_get_paginated_with_has_annotations_true(self, repo, sample_document, mock_session):
        """Test get_paginated with has_annotations=True."""
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [sample_document]

        results, total = repo.get_paginated(has_annotations=True)

        assert total == 1
        assert len(results) == 1

    def test_get_paginated_with_has_annotations_false(self, repo, sample_document, mock_session):
        """Test get_paginated with has_annotations=False."""
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [sample_document]

        results, total = repo.get_paginated(has_annotations=False)

        assert total == 1
        assert len(results) == 1

    # ==========================================================================
    # update_status() tests
    # ==========================================================================

    def test_update_status(self, repo, sample_document, mock_session):
        """Test update_status updates document status."""
        mock_session.get.return_value = sample_document

        repo.update_status(str(sample_document.document_id), "labeled")

        assert sample_document.status == "labeled"
        mock_session.add.assert_called_once()

    def test_update_status_with_auto_label_status(self, repo, sample_document, mock_session):
        """Test update_status with auto_label_status."""
        mock_session.get.return_value = sample_document

        repo.update_status(
            str(sample_document.document_id),
            "labeled",
            auto_label_status="completed",
        )

        assert sample_document.auto_label_status == "completed"

    def test_update_status_with_auto_label_error(self, repo, sample_document, mock_session):
        """Test update_status with auto_label_error."""
        mock_session.get.return_value = sample_document

        repo.update_status(
            str(sample_document.document_id),
            "failed",
            auto_label_error="OCR failed",
        )

        assert sample_document.auto_label_error == "OCR failed"

    def test_update_status_document_not_found(self, repo, mock_session):
        """Test update_status when document not found."""
        mock_session.get.return_value = None

        repo.update_status(str(uuid4()), "labeled")

        mock_session.add.assert_not_called()

    # ==========================================================================
    # update_file_path() tests
    # ==========================================================================

    def test_update_file_path(self, repo, sample_document, mock_session):
        """Test update_file_path updates document file path."""
        mock_session.get.return_value = sample_document

        repo.update_file_path(str(sample_document.document_id), "/new/path.pdf")

        assert sample_document.file_path == "/new/path.pdf"
        mock_session.add.assert_called_once()

    def test_update_file_path_document_not_found(self, repo, mock_session):
        """Test update_file_path when document not found."""
        mock_session.get.return_value = None

        repo.update_file_path(str(uuid4()), "/new/path.pdf")

        mock_session.add.assert_not_called()

    # ==========================================================================
    # update_group_key() tests
    # ==========================================================================

    def test_update_group_key_returns_true(self, repo, sample_document, mock_session):
        """Test update_group_key returns True when document exists."""
        mock_session.get.return_value = sample_document

        result = repo.update_group_key(str(sample_document.document_id), "new-group")

        assert result is True
        assert sample_document.group_key == "new-group"

    def test_update_group_key_returns_false(self, repo, mock_session):
        """Test update_group_key returns False when document not found."""
        mock_session.get.return_value = None

        result = repo.update_group_key(str(uuid4()), "new-group")

        assert result is False

    # ==========================================================================
    # update_category() tests
    # ==========================================================================

    def test_update_category(self, repo, sample_document, mock_session):
        """Test update_category updates document category."""
        mock_session.get.return_value = sample_document

        result = repo.update_category(str(sample_document.document_id), "receipt")

        # Counterpart of the not-found case, which expects None.
        assert result is not None
        assert sample_document.category == "receipt"
        mock_session.add.assert_called()

    def test_update_category_returns_none_when_not_found(self, repo, mock_session):
        """Test update_category returns None when document not found."""
        mock_session.get.return_value = None

        result = repo.update_category(str(uuid4()), "receipt")

        assert result is None

    # ==========================================================================
    # delete() tests
    # ==========================================================================

    def test_delete_returns_true_when_exists(self, repo, sample_document, mock_session):
        """Test delete returns True when document exists."""
        mock_session.get.return_value = sample_document
        mock_session.exec.return_value.all.return_value = []

        result = repo.delete(str(sample_document.document_id))

        assert result is True
        mock_session.delete.assert_called_once_with(sample_document)

    def test_delete_with_annotations(self, repo, sample_document, mock_session):
        """Test delete removes annotations before deleting document."""
        annotation = MagicMock()
        mock_session.get.return_value = sample_document
        mock_session.exec.return_value.all.return_value = [annotation]

        result = repo.delete(str(sample_document.document_id))

        assert result is True
        # One delete for the annotation, one for the document itself.
        assert mock_session.delete.call_count == 2

    def test_delete_returns_false_when_not_exists(self, repo, mock_session):
        """Test delete returns False when document not found."""
        mock_session.get.return_value = None

        result = repo.delete(str(uuid4()))

        assert result is False

    # ==========================================================================
    # get_categories() tests
    # ==========================================================================

    def test_get_categories(self, repo, mock_session):
        """Test get_categories returns unique categories, dropping None."""
        mock_session.exec.return_value.all.return_value = ["invoice", "receipt", None]

        result = repo.get_categories()

        assert result == ["invoice", "receipt"]

    # ==========================================================================
    # get_labeled_for_export() tests
    # ==========================================================================

    def test_get_labeled_for_export(self, repo, labeled_document, mock_session):
        """Test get_labeled_for_export returns labeled documents."""
        mock_session.exec.return_value.all.return_value = [labeled_document]

        result = repo.get_labeled_for_export()

        assert len(result) == 1
        assert result[0].status == "labeled"

    def test_get_labeled_for_export_with_token(self, repo, labeled_document, mock_session):
        """Test get_labeled_for_export with admin_token filter."""
        mock_session.exec.return_value.all.return_value = [labeled_document]

        result = repo.get_labeled_for_export(admin_token="token-123")

        assert len(result) == 1

    # ==========================================================================
    # count_by_status() tests
    # ==========================================================================

    def test_count_by_status(self, repo, mock_session):
        """Test count_by_status returns status counts."""
        mock_session.exec.return_value.all.return_value = [
            ("pending", 10),
            ("labeled", 5),
        ]

        result = repo.count_by_status()

        assert result == {"pending": 10, "labeled": 5}

    # ==========================================================================
    # get_by_ids() tests
    # ==========================================================================

    def test_get_by_ids(self, repo, sample_document, mock_session):
        """Test get_by_ids returns documents by IDs."""
        mock_session.exec.return_value.all.return_value = [sample_document]

        result = repo.get_by_ids([str(sample_document.document_id)])

        assert len(result) == 1

    # ==========================================================================
    # get_for_training() tests
    # ==========================================================================

    def test_get_for_training_basic(self, repo, labeled_document, mock_session):
        """Test get_for_training with default parameters."""
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [labeled_document]

        results, total = repo.get_for_training()

        assert total == 1
        assert len(results) == 1

    def test_get_for_training_with_min_annotation_count(self, repo, labeled_document, mock_session):
        """Test get_for_training with min_annotation_count."""
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [labeled_document]

        results, total = repo.get_for_training(min_annotation_count=3)

        assert total == 1
        assert len(results) == 1

    def test_get_for_training_exclude_used(self, repo, labeled_document, mock_session):
        """Test get_for_training with exclude_used_in_training."""
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [labeled_document]

        results, total = repo.get_for_training(exclude_used_in_training=True)

        assert total == 1
        assert len(results) == 1

    def test_get_for_training_no_annotations(self, repo, labeled_document, mock_session):
        """Test get_for_training with has_annotations=False."""
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [labeled_document]

        results, total = repo.get_for_training(has_annotations=False)

        assert total == 1
        assert len(results) == 1

    # ==========================================================================
    # acquire_annotation_lock() tests
    # ==========================================================================

    def test_acquire_annotation_lock_success(self, repo, sample_document, mock_session):
        """Test acquire_annotation_lock when no lock exists."""
        sample_document.annotation_lock_until = None
        mock_session.get.return_value = sample_document

        result = repo.acquire_annotation_lock(str(sample_document.document_id))

        assert result is not None
        assert sample_document.annotation_lock_until is not None

    def test_acquire_annotation_lock_fails_when_locked(self, repo, locked_document, mock_session):
        """Test acquire_annotation_lock fails when document is already locked."""
        mock_session.get.return_value = locked_document

        result = repo.acquire_annotation_lock(str(locked_document.document_id))

        assert result is None

    def test_acquire_annotation_lock_document_not_found(self, repo, mock_session):
        """Test acquire_annotation_lock when document not found."""
        mock_session.get.return_value = None

        result = repo.acquire_annotation_lock(str(uuid4()))

        assert result is None

    # ==========================================================================
    # release_annotation_lock() tests
    # ==========================================================================

    def test_release_annotation_lock_success(self, repo, locked_document, mock_session):
        """Test release_annotation_lock releases the lock."""
        mock_session.get.return_value = locked_document

        result = repo.release_annotation_lock(str(locked_document.document_id))

        assert result is not None
        assert locked_document.annotation_lock_until is None

    def test_release_annotation_lock_document_not_found(self, repo, mock_session):
        """Test release_annotation_lock when document not found."""
        mock_session.get.return_value = None

        result = repo.release_annotation_lock(str(uuid4()))

        assert result is None

    # ==========================================================================
    # extend_annotation_lock() tests
    # ==========================================================================

    def test_extend_annotation_lock_success(self, repo, locked_document, mock_session):
        """Test extend_annotation_lock extends the lock."""
        original_lock = locked_document.annotation_lock_until
        mock_session.get.return_value = locked_document

        result = repo.extend_annotation_lock(str(locked_document.document_id))

        assert result is not None
        assert locked_document.annotation_lock_until > original_lock

    def test_extend_annotation_lock_fails_when_no_lock(self, repo, sample_document, mock_session):
        """Test extend_annotation_lock fails when no lock exists."""
        sample_document.annotation_lock_until = None
        mock_session.get.return_value = sample_document

        result = repo.extend_annotation_lock(str(sample_document.document_id))

        assert result is None

    def test_extend_annotation_lock_fails_when_expired(self, repo, expired_lock_document, mock_session):
        """Test extend_annotation_lock fails when lock is expired."""
        mock_session.get.return_value = expired_lock_document

        result = repo.extend_annotation_lock(str(expired_lock_document.document_id))

        assert result is None

    def test_extend_annotation_lock_document_not_found(self, repo, mock_session):
        """Test extend_annotation_lock when document not found."""
        mock_session.get.return_value = None

        result = repo.extend_annotation_lock(str(uuid4()))

        assert result is None