"""
Tests for DocumentRepository

Comprehensive TDD tests for document management - targeting 100% coverage.

The repository methods exercised here obtain their database session through
``get_session_context``; the ``mock_session`` fixture patches that context
manager once per test, so each test only configures the session behaviour it
actually cares about.
"""

import pytest
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch
from uuid import uuid4

from inference.data.admin_models import AdminDocument, AdminAnnotation
from inference.data.repositories.document_repository import DocumentRepository


# Name looked up by the repository module at call time; patching it here is what
# makes the repository receive the mock session.
_SESSION_CONTEXT_PATH = "inference.data.repositories.document_repository.get_session_context"


def _make_document(
    filename: str,
    *,
    file_size: int = 1024,
    page_count: int = 1,
    status: str = "pending",
) -> AdminDocument:
    """Build an in-memory PDF ``AdminDocument`` in the "invoice" category.

    Args:
        filename: Document filename; the file path is ``/tmp/<filename>``.
        file_size: Size in bytes.
        page_count: Number of pages.
        status: Document status string.

    Returns:
        A new, unsaved ``AdminDocument`` with a fresh ``document_id``.
    """
    now = datetime.now(timezone.utc)
    return AdminDocument(
        document_id=uuid4(),
        filename=filename,
        file_size=file_size,
        content_type="application/pdf",
        file_path=f"/tmp/{filename}",
        page_count=page_count,
        status=status,
        category="invoice",
        created_at=now,
        updated_at=now,
    )


class TestDocumentRepository:
    """Tests for DocumentRepository."""

    # ==========================================================================
    # Fixtures
    # ==========================================================================

    @pytest.fixture
    def sample_document(self) -> AdminDocument:
        """Create a sample pending document for testing."""
        return _make_document("test.pdf")

    @pytest.fixture
    def labeled_document(self) -> AdminDocument:
        """Create a labeled two-page document for testing."""
        return _make_document("labeled.pdf", file_size=2048, page_count=2, status="labeled")

    @pytest.fixture
    def locked_document(self) -> AdminDocument:
        """Create a document whose annotation lock expires 5 minutes from now."""
        doc = _make_document("locked.pdf")
        doc.annotation_lock_until = datetime.now(timezone.utc) + timedelta(minutes=5)
        return doc

    @pytest.fixture
    def expired_lock_document(self) -> AdminDocument:
        """Create a document whose annotation lock expired 5 minutes ago."""
        doc = _make_document("expired_lock.pdf")
        doc.annotation_lock_until = datetime.now(timezone.utc) - timedelta(minutes=5)
        return doc

    @pytest.fixture
    def repo(self) -> DocumentRepository:
        """Create a DocumentRepository instance."""
        return DocumentRepository()

    @pytest.fixture
    def mock_session(self):
        """Patch ``get_session_context`` and yield the session it hands out.

        The patch stays active for the whole test; tests configure
        ``get``/``exec`` return values on the yielded mock before calling the
        repository and inspect its recorded calls afterwards.
        """
        with patch(_SESSION_CONTEXT_PATH) as mock_ctx:
            session = MagicMock()
            mock_ctx.return_value.__enter__ = MagicMock(return_value=session)
            # A falsy __exit__ return means exceptions raised inside the
            # repository's ``with`` block are not swallowed by the mock.
            mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
            yield session

    # ==========================================================================
    # create() tests
    # ==========================================================================

    def test_create_returns_document_id(self, repo, mock_session):
        """Test create returns document ID."""
        result = repo.create(
            filename="test.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/tmp/test.pdf",
        )

        assert result is not None
        mock_session.add.assert_called_once()
        mock_session.flush.assert_called_once()

    def test_create_with_all_parameters(self, repo, mock_session):
        """Test create with all optional parameters."""
        result = repo.create(
            filename="test.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/tmp/test.pdf",
            page_count=5,
            upload_source="api",
            csv_field_values={"InvoiceNumber": "INV-001"},
            group_key="batch-001",
            category="receipt",
            admin_token="token-123",
        )

        assert result is not None
        added_doc = mock_session.add.call_args[0][0]
        assert added_doc.page_count == 5
        assert added_doc.upload_source == "api"
        assert added_doc.csv_field_values == {"InvoiceNumber": "INV-001"}
        assert added_doc.group_key == "batch-001"
        assert added_doc.category == "receipt"

    # ==========================================================================
    # get() tests
    # ==========================================================================

    def test_get_returns_document(self, repo, sample_document, mock_session):
        """Test get returns document when exists."""
        mock_session.get.return_value = sample_document

        result = repo.get(str(sample_document.document_id))

        assert result is not None
        assert result.filename == "test.pdf"
        mock_session.expunge.assert_called_once()

    def test_get_returns_none_when_not_found(self, repo, mock_session):
        """Test get returns None when document not found."""
        mock_session.get.return_value = None

        result = repo.get(str(uuid4()))

        assert result is None

    # ==========================================================================
    # get_by_token() tests
    # ==========================================================================

    def test_get_by_token_delegates_to_get(self, repo, sample_document):
        """Test get_by_token delegates to get method."""
        with patch.object(repo, "get", return_value=sample_document) as mock_get:
            result = repo.get_by_token(str(sample_document.document_id), "token-123")

        assert result == sample_document
        mock_get.assert_called_once_with(str(sample_document.document_id))

    # ==========================================================================
    # get_paginated() tests
    # ==========================================================================

    def test_get_paginated_no_filters(self, repo, sample_document, mock_session):
        """Test get_paginated with no filters."""
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [sample_document]

        results, total = repo.get_paginated()

        assert total == 1
        assert len(results) == 1

    def test_get_paginated_with_status_filter(self, repo, sample_document, mock_session):
        """Test get_paginated with status filter."""
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [sample_document]

        _, total = repo.get_paginated(status="pending")

        assert total == 1

    def test_get_paginated_with_upload_source_filter(self, repo, sample_document, mock_session):
        """Test get_paginated with upload_source filter."""
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [sample_document]

        _, total = repo.get_paginated(upload_source="ui")

        assert total == 1

    def test_get_paginated_with_auto_label_status_filter(
        self, repo, sample_document, mock_session
    ):
        """Test get_paginated with auto_label_status filter."""
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [sample_document]

        _, total = repo.get_paginated(auto_label_status="completed")

        assert total == 1

    def test_get_paginated_with_batch_id_filter(self, repo, sample_document, mock_session):
        """Test get_paginated with batch_id filter."""
        batch_id = str(uuid4())
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [sample_document]

        _, total = repo.get_paginated(batch_id=batch_id)

        assert total == 1

    def test_get_paginated_with_category_filter(self, repo, sample_document, mock_session):
        """Test get_paginated with category filter."""
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [sample_document]

        _, total = repo.get_paginated(category="invoice")

        assert total == 1

    def test_get_paginated_with_has_annotations_true(self, repo, sample_document, mock_session):
        """Test get_paginated with has_annotations=True."""
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [sample_document]

        _, total = repo.get_paginated(has_annotations=True)

        assert total == 1

    def test_get_paginated_with_has_annotations_false(self, repo, sample_document, mock_session):
        """Test get_paginated with has_annotations=False."""
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [sample_document]

        _, total = repo.get_paginated(has_annotations=False)

        assert total == 1

    # ==========================================================================
    # update_status() tests
    # ==========================================================================

    def test_update_status(self, repo, sample_document, mock_session):
        """Test update_status updates document status."""
        mock_session.get.return_value = sample_document

        repo.update_status(str(sample_document.document_id), "labeled")

        assert sample_document.status == "labeled"
        mock_session.add.assert_called_once()

    def test_update_status_with_auto_label_status(self, repo, sample_document, mock_session):
        """Test update_status with auto_label_status."""
        mock_session.get.return_value = sample_document

        repo.update_status(
            str(sample_document.document_id),
            "labeled",
            auto_label_status="completed",
        )

        assert sample_document.auto_label_status == "completed"

    def test_update_status_with_auto_label_error(self, repo, sample_document, mock_session):
        """Test update_status with auto_label_error."""
        mock_session.get.return_value = sample_document

        repo.update_status(
            str(sample_document.document_id),
            "failed",
            auto_label_error="OCR failed",
        )

        assert sample_document.auto_label_error == "OCR failed"

    def test_update_status_document_not_found(self, repo, mock_session):
        """Test update_status when document not found."""
        mock_session.get.return_value = None

        repo.update_status(str(uuid4()), "labeled")

        mock_session.add.assert_not_called()

    # ==========================================================================
    # update_file_path() tests
    # ==========================================================================

    def test_update_file_path(self, repo, sample_document, mock_session):
        """Test update_file_path updates document file path."""
        mock_session.get.return_value = sample_document

        repo.update_file_path(str(sample_document.document_id), "/new/path.pdf")

        assert sample_document.file_path == "/new/path.pdf"
        mock_session.add.assert_called_once()

    def test_update_file_path_document_not_found(self, repo, mock_session):
        """Test update_file_path when document not found."""
        mock_session.get.return_value = None

        repo.update_file_path(str(uuid4()), "/new/path.pdf")

        mock_session.add.assert_not_called()

    # ==========================================================================
    # update_group_key() tests
    # ==========================================================================

    def test_update_group_key_returns_true(self, repo, sample_document, mock_session):
        """Test update_group_key returns True when document exists."""
        mock_session.get.return_value = sample_document

        result = repo.update_group_key(str(sample_document.document_id), "new-group")

        assert result is True
        assert sample_document.group_key == "new-group"

    def test_update_group_key_returns_false(self, repo, mock_session):
        """Test update_group_key returns False when document not found."""
        mock_session.get.return_value = None

        result = repo.update_group_key(str(uuid4()), "new-group")

        assert result is False

    # ==========================================================================
    # update_category() tests
    # ==========================================================================

    def test_update_category(self, repo, sample_document, mock_session):
        """Test update_category updates document category."""
        mock_session.get.return_value = sample_document

        repo.update_category(str(sample_document.document_id), "receipt")

        assert sample_document.category == "receipt"
        mock_session.add.assert_called()

    def test_update_category_returns_none_when_not_found(self, repo, mock_session):
        """Test update_category returns None when document not found."""
        mock_session.get.return_value = None

        result = repo.update_category(str(uuid4()), "receipt")

        assert result is None

    # ==========================================================================
    # delete() tests
    # ==========================================================================

    def test_delete_returns_true_when_exists(self, repo, sample_document, mock_session):
        """Test delete returns True when document exists."""
        mock_session.get.return_value = sample_document
        mock_session.exec.return_value.all.return_value = []

        result = repo.delete(str(sample_document.document_id))

        assert result is True
        mock_session.delete.assert_called_once_with(sample_document)

    def test_delete_with_annotations(self, repo, sample_document, mock_session):
        """Test delete removes annotations before deleting document."""
        annotation = MagicMock()
        mock_session.get.return_value = sample_document
        mock_session.exec.return_value.all.return_value = [annotation]

        result = repo.delete(str(sample_document.document_id))

        assert result is True
        # One delete for the annotation, one for the document itself.
        assert mock_session.delete.call_count == 2

    def test_delete_returns_false_when_not_exists(self, repo, mock_session):
        """Test delete returns False when document not found."""
        mock_session.get.return_value = None

        result = repo.delete(str(uuid4()))

        assert result is False

    # ==========================================================================
    # get_categories() tests
    # ==========================================================================

    def test_get_categories(self, repo, mock_session):
        """Test get_categories returns unique categories."""
        mock_session.exec.return_value.all.return_value = ["invoice", "receipt", None]

        result = repo.get_categories()

        assert result == ["invoice", "receipt"]

    # ==========================================================================
    # get_labeled_for_export() tests
    # ==========================================================================

    def test_get_labeled_for_export(self, repo, labeled_document, mock_session):
        """Test get_labeled_for_export returns labeled documents."""
        mock_session.exec.return_value.all.return_value = [labeled_document]

        result = repo.get_labeled_for_export()

        assert len(result) == 1
        assert result[0].status == "labeled"

    def test_get_labeled_for_export_with_token(self, repo, labeled_document, mock_session):
        """Test get_labeled_for_export with admin_token filter."""
        mock_session.exec.return_value.all.return_value = [labeled_document]

        result = repo.get_labeled_for_export(admin_token="token-123")

        assert len(result) == 1

    # ==========================================================================
    # count_by_status() tests
    # ==========================================================================

    def test_count_by_status(self, repo, mock_session):
        """Test count_by_status returns status counts."""
        mock_session.exec.return_value.all.return_value = [
            ("pending", 10),
            ("labeled", 5),
        ]

        result = repo.count_by_status()

        assert result == {"pending": 10, "labeled": 5}

    # ==========================================================================
    # get_by_ids() tests
    # ==========================================================================

    def test_get_by_ids(self, repo, sample_document, mock_session):
        """Test get_by_ids returns documents by IDs."""
        mock_session.exec.return_value.all.return_value = [sample_document]

        result = repo.get_by_ids([str(sample_document.document_id)])

        assert len(result) == 1

    # ==========================================================================
    # get_for_training() tests
    # ==========================================================================

    def test_get_for_training_basic(self, repo, labeled_document, mock_session):
        """Test get_for_training with default parameters."""
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [labeled_document]

        results, total = repo.get_for_training()

        assert total == 1
        assert len(results) == 1

    def test_get_for_training_with_min_annotation_count(
        self, repo, labeled_document, mock_session
    ):
        """Test get_for_training with min_annotation_count."""
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [labeled_document]

        _, total = repo.get_for_training(min_annotation_count=3)

        assert total == 1

    def test_get_for_training_exclude_used(self, repo, labeled_document, mock_session):
        """Test get_for_training with exclude_used_in_training."""
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [labeled_document]

        _, total = repo.get_for_training(exclude_used_in_training=True)

        assert total == 1

    def test_get_for_training_no_annotations(self, repo, labeled_document, mock_session):
        """Test get_for_training with has_annotations=False."""
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [labeled_document]

        _, total = repo.get_for_training(has_annotations=False)

        assert total == 1

    # ==========================================================================
    # acquire_annotation_lock() tests
    # ==========================================================================

    def test_acquire_annotation_lock_success(self, repo, sample_document, mock_session):
        """Test acquire_annotation_lock when no lock exists."""
        sample_document.annotation_lock_until = None
        mock_session.get.return_value = sample_document

        result = repo.acquire_annotation_lock(str(sample_document.document_id))

        assert result is not None
        assert sample_document.annotation_lock_until is not None

    def test_acquire_annotation_lock_fails_when_locked(self, repo, locked_document, mock_session):
        """Test acquire_annotation_lock fails when document is already locked."""
        mock_session.get.return_value = locked_document

        result = repo.acquire_annotation_lock(str(locked_document.document_id))

        assert result is None

    def test_acquire_annotation_lock_document_not_found(self, repo, mock_session):
        """Test acquire_annotation_lock when document not found."""
        mock_session.get.return_value = None

        result = repo.acquire_annotation_lock(str(uuid4()))

        assert result is None

    # ==========================================================================
    # release_annotation_lock() tests
    # ==========================================================================

    def test_release_annotation_lock_success(self, repo, locked_document, mock_session):
        """Test release_annotation_lock releases the lock."""
        mock_session.get.return_value = locked_document

        result = repo.release_annotation_lock(str(locked_document.document_id))

        assert result is not None
        assert locked_document.annotation_lock_until is None

    def test_release_annotation_lock_document_not_found(self, repo, mock_session):
        """Test release_annotation_lock when document not found."""
        mock_session.get.return_value = None

        result = repo.release_annotation_lock(str(uuid4()))

        assert result is None

    # ==========================================================================
    # extend_annotation_lock() tests
    # ==========================================================================

    def test_extend_annotation_lock_success(self, repo, locked_document, mock_session):
        """Test extend_annotation_lock extends the lock."""
        original_lock = locked_document.annotation_lock_until
        mock_session.get.return_value = locked_document

        result = repo.extend_annotation_lock(str(locked_document.document_id))

        assert result is not None
        assert locked_document.annotation_lock_until > original_lock

    def test_extend_annotation_lock_fails_when_no_lock(self, repo, sample_document, mock_session):
        """Test extend_annotation_lock fails when no lock exists."""
        sample_document.annotation_lock_until = None
        mock_session.get.return_value = sample_document

        result = repo.extend_annotation_lock(str(sample_document.document_id))

        assert result is None

    def test_extend_annotation_lock_fails_when_expired(
        self, repo, expired_lock_document, mock_session
    ):
        """Test extend_annotation_lock fails when lock is expired."""
        mock_session.get.return_value = expired_lock_document

        result = repo.extend_annotation_lock(str(expired_lock_document.document_id))

        assert result is None

    def test_extend_annotation_lock_document_not_found(self, repo, mock_session):
        """Test extend_annotation_lock when document not found."""
        mock_session.get.return_value = None

        result = repo.extend_annotation_lock(str(uuid4()))

        assert result is None