"""
Tests for BatchUploadRepository

100% coverage tests for batch upload management.
"""

from collections.abc import Iterator
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
from uuid import uuid4, UUID

import pytest

from inference.data.admin_models import BatchUpload, BatchUploadFile
from inference.data.repositories.batch_upload_repository import BatchUploadRepository

# Dotted path of the session factory looked up by the repository module.
# Patching it here means no real database session is ever opened.
SESSION_CONTEXT_PATH = (
    "inference.data.repositories.batch_upload_repository.get_session_context"
)


class TestBatchUploadRepository:
    """Tests for BatchUploadRepository."""

    @pytest.fixture
    def sample_batch(self) -> BatchUpload:
        """Create a sample batch upload for testing."""
        return BatchUpload(
            batch_id=uuid4(),
            admin_token="admin-token",
            filename="invoices.zip",
            file_size=1024000,
            upload_source="ui",
            status="pending",
            total_files=10,
            processed_files=0,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )

    @pytest.fixture
    def sample_file(self) -> BatchUploadFile:
        """Create a sample batch upload file for testing."""
        return BatchUploadFile(
            file_id=uuid4(),
            batch_id=uuid4(),
            filename="invoice_001.pdf",
            status="pending",
            created_at=datetime.now(timezone.utc),
        )

    @pytest.fixture
    def repo(self) -> BatchUploadRepository:
        """Create a BatchUploadRepository instance."""
        return BatchUploadRepository()

    @pytest.fixture
    def mock_session(self) -> Iterator[MagicMock]:
        """Patch ``get_session_context`` and yield the session mock it provides.

        The patched context manager's ``__enter__`` returns the yielded mock.
        Its ``__exit__`` returns False, so exceptions raised inside the
        repository's ``with`` block still propagate to the test. Tests set
        ``get``/``exec`` return values on the yielded mock before calling the
        repository. The patch stays active until the test finishes.
        """
        with patch(SESSION_CONTEXT_PATH) as mock_ctx:
            session = MagicMock()
            mock_ctx.return_value.__enter__.return_value = session
            mock_ctx.return_value.__exit__.return_value = False
            yield session

    # =========================================================================
    # create() tests
    # =========================================================================

    def test_create_returns_batch(self, repo, mock_session):
        """Test create returns created batch upload."""
        result = repo.create(
            admin_token="admin-token",
            filename="test.zip",
            file_size=1024,
        )

        mock_session.add.assert_called_once()
        mock_session.commit.assert_called_once()
        # The returned record must reflect the arguments that were passed in.
        assert result is not None
        assert result.filename == "test.zip"

    def test_create_with_upload_source(self, repo, mock_session):
        """Test create with custom upload source."""
        repo.create(
            admin_token="admin-token",
            filename="test.zip",
            file_size=1024,
            upload_source="api",
        )

        added_batch = mock_session.add.call_args[0][0]
        assert added_batch.upload_source == "api"

    def test_create_default_upload_source(self, repo, mock_session):
        """Test create uses default upload source."""
        repo.create(
            admin_token="admin-token",
            filename="test.zip",
            file_size=1024,
        )

        added_batch = mock_session.add.call_args[0][0]
        assert added_batch.upload_source == "ui"

    # =========================================================================
    # get() tests
    # =========================================================================

    def test_get_returns_batch(self, repo, mock_session, sample_batch):
        """Test get returns batch when exists."""
        mock_session.get.return_value = sample_batch

        result = repo.get(sample_batch.batch_id)

        assert result is not None
        assert result.filename == "invoices.zip"
        mock_session.expunge.assert_called_once()

    def test_get_returns_none_when_not_found(self, repo, mock_session):
        """Test get returns None when batch not found."""
        mock_session.get.return_value = None

        result = repo.get(uuid4())

        assert result is None
        mock_session.expunge.assert_not_called()

    # =========================================================================
    # update() tests
    # =========================================================================

    def test_update_updates_batch(self, repo, mock_session, sample_batch):
        """Test update updates batch fields."""
        mock_session.get.return_value = sample_batch

        repo.update(
            sample_batch.batch_id,
            status="processing",
            processed_files=5,
        )

        assert sample_batch.status == "processing"
        assert sample_batch.processed_files == 5
        mock_session.add.assert_called_once()

    def test_update_ignores_unknown_fields(self, repo, mock_session, sample_batch):
        """Test update ignores unknown fields."""
        mock_session.get.return_value = sample_batch

        repo.update(
            sample_batch.batch_id,
            unknown_field="should_be_ignored",
        )

        # "Ignored" means the attribute was never set on the model.
        assert not hasattr(sample_batch, "unknown_field")
        mock_session.add.assert_called_once()

    def test_update_not_found(self, repo, mock_session):
        """Test update does nothing when batch not found."""
        mock_session.get.return_value = None

        repo.update(uuid4(), status="processing")

        mock_session.add.assert_not_called()

    def test_update_multiple_fields(self, repo, mock_session, sample_batch):
        """Test update can update multiple fields."""
        mock_session.get.return_value = sample_batch

        repo.update(
            sample_batch.batch_id,
            status="completed",
            processed_files=10,
            total_files=10,
        )

        assert sample_batch.status == "completed"
        assert sample_batch.processed_files == 10
        assert sample_batch.total_files == 10
        mock_session.add.assert_called_once()

    # =========================================================================
    # create_file() tests
    # =========================================================================

    def test_create_file_returns_file(self, repo, mock_session):
        """Test create_file returns created file record."""
        result = repo.create_file(
            batch_id=uuid4(),
            filename="invoice_001.pdf",
        )

        mock_session.add.assert_called_once()
        mock_session.commit.assert_called_once()
        # The returned record must reflect the arguments that were passed in.
        assert result is not None
        assert result.filename == "invoice_001.pdf"

    def test_create_file_with_kwargs(self, repo, mock_session):
        """Test create_file with additional kwargs."""
        repo.create_file(
            batch_id=uuid4(),
            filename="invoice_001.pdf",
            status="processing",
            file_size=1024,
        )

        added_file = mock_session.add.call_args[0][0]
        assert added_file.filename == "invoice_001.pdf"
        # Extra keyword arguments must be applied to the new record.
        assert added_file.status == "processing"

    # =========================================================================
    # update_file() tests
    # =========================================================================

    def test_update_file_updates_file(self, repo, mock_session, sample_file):
        """Test update_file updates file fields."""
        mock_session.get.return_value = sample_file

        repo.update_file(
            sample_file.file_id,
            status="completed",
        )

        assert sample_file.status == "completed"
        mock_session.add.assert_called_once()

    def test_update_file_ignores_unknown_fields(self, repo, mock_session, sample_file):
        """Test update_file ignores unknown fields."""
        mock_session.get.return_value = sample_file

        repo.update_file(
            sample_file.file_id,
            unknown_field="should_be_ignored",
        )

        # "Ignored" means the attribute was never set on the model.
        assert not hasattr(sample_file, "unknown_field")
        mock_session.add.assert_called_once()

    def test_update_file_not_found(self, repo, mock_session):
        """Test update_file does nothing when file not found."""
        mock_session.get.return_value = None

        repo.update_file(uuid4(), status="completed")

        mock_session.add.assert_not_called()

    def test_update_file_multiple_fields(self, repo, mock_session, sample_file):
        """Test update_file can update multiple fields."""
        mock_session.get.return_value = sample_file

        # NOTE(review): despite the test name, only ``status`` is exercised.
        # Add further BatchUploadFile fields once the accepted update keys
        # are confirmed against the model/repository.
        repo.update_file(
            sample_file.file_id,
            status="failed",
        )

        assert sample_file.status == "failed"
        mock_session.add.assert_called_once()

    # =========================================================================
    # get_files() tests
    # =========================================================================

    def test_get_files_returns_list(self, repo, mock_session, sample_file):
        """Test get_files returns list of files."""
        mock_session.exec.return_value.all.return_value = [sample_file]

        result = repo.get_files(sample_file.batch_id)

        assert len(result) == 1
        assert result[0].filename == "invoice_001.pdf"

    def test_get_files_returns_empty_list(self, repo, mock_session):
        """Test get_files returns empty list when no files."""
        mock_session.exec.return_value.all.return_value = []

        result = repo.get_files(uuid4())

        assert result == []

    # =========================================================================
    # get_paginated() tests
    # =========================================================================

    def test_get_paginated_returns_batches_and_total(self, repo, mock_session, sample_batch):
        """Test get_paginated returns list of batches and total count."""
        # One exec() result mock serves both the count query (.one()) and the
        # page query (.all()).
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [sample_batch]

        batches, total = repo.get_paginated()

        assert len(batches) == 1
        assert total == 1

    def test_get_paginated_with_pagination(self, repo, mock_session, sample_batch):
        """Test get_paginated with limit and offset."""
        mock_session.exec.return_value.one.return_value = 100
        mock_session.exec.return_value.all.return_value = [sample_batch]

        batches, total = repo.get_paginated(limit=25, offset=50)

        assert len(batches) == 1
        assert total == 100

    def test_get_paginated_empty_results(self, repo, mock_session):
        """Test get_paginated with no results."""
        mock_session.exec.return_value.one.return_value = 0
        mock_session.exec.return_value.all.return_value = []

        batches, total = repo.get_paginated()

        assert batches == []
        assert total == 0

    def test_get_paginated_with_admin_token(self, repo, mock_session, sample_batch):
        """Test get_paginated with admin_token parameter (deprecated, ignored)."""
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [sample_batch]

        batches, total = repo.get_paginated(admin_token="admin-token")

        assert len(batches) == 1
        assert total == 1