""" Batch Upload Repository Integration Tests Tests BatchUploadRepository with real database operations. """ from datetime import datetime, timezone from uuid import uuid4 import pytest from inference.data.repositories.batch_upload_repository import BatchUploadRepository class TestBatchUploadCreate: """Tests for batch upload creation.""" def test_create_batch_upload(self, patched_session, admin_token): """Test creating a batch upload.""" repo = BatchUploadRepository() batch = repo.create( admin_token=admin_token.token, filename="test_batch.zip", file_size=10240, upload_source="api", ) assert batch is not None assert batch.batch_id is not None assert batch.filename == "test_batch.zip" assert batch.file_size == 10240 assert batch.upload_source == "api" assert batch.status == "processing" assert batch.total_files == 0 assert batch.processed_files == 0 def test_create_batch_upload_default_source(self, patched_session, admin_token): """Test creating batch upload with default source.""" repo = BatchUploadRepository() batch = repo.create( admin_token=admin_token.token, filename="ui_batch.zip", file_size=5120, ) assert batch.upload_source == "ui" class TestBatchUploadRead: """Tests for batch upload retrieval.""" def test_get_batch_upload(self, patched_session, sample_batch_upload): """Test getting a batch upload by ID.""" repo = BatchUploadRepository() batch = repo.get(sample_batch_upload.batch_id) assert batch is not None assert batch.batch_id == sample_batch_upload.batch_id assert batch.filename == sample_batch_upload.filename def test_get_nonexistent_batch_upload(self, patched_session): """Test getting a batch upload that doesn't exist.""" repo = BatchUploadRepository() batch = repo.get(uuid4()) assert batch is None def test_get_paginated_batch_uploads(self, patched_session, admin_token): """Test paginated batch upload listing.""" repo = BatchUploadRepository() # Create multiple batches for i in range(5): repo.create( admin_token=admin_token.token, filename=f"batch_{i}.zip", file_size=1024 * (i + 1), ) batches, total = repo.get_paginated(limit=3, offset=0) assert total == 5 assert len(batches) == 3 def test_get_paginated_with_offset(self, patched_session, admin_token): """Test pagination offset.""" repo = BatchUploadRepository() for i in range(5): repo.create( admin_token=admin_token.token, filename=f"batch_{i}.zip", file_size=1024, ) page1, _ = repo.get_paginated(limit=2, offset=0) page2, _ = repo.get_paginated(limit=2, offset=2) ids_page1 = {b.batch_id for b in page1} ids_page2 = {b.batch_id for b in page2} assert len(ids_page1 & ids_page2) == 0 class TestBatchUploadUpdate: """Tests for batch upload updates.""" def test_update_batch_status(self, patched_session, sample_batch_upload): """Test updating batch upload status.""" repo = BatchUploadRepository() repo.update( sample_batch_upload.batch_id, status="completed", total_files=10, processed_files=10, successful_files=8, failed_files=2, ) # Need to commit to see changes patched_session.commit() batch = repo.get(sample_batch_upload.batch_id) assert batch.status == "completed" assert batch.total_files == 10 assert batch.successful_files == 8 assert batch.failed_files == 2 def test_update_batch_with_error(self, patched_session, sample_batch_upload): """Test updating batch upload with error message.""" repo = BatchUploadRepository() repo.update( sample_batch_upload.batch_id, status="failed", error_message="ZIP extraction failed", ) patched_session.commit() batch = repo.get(sample_batch_upload.batch_id) assert batch.status == "failed" assert batch.error_message == "ZIP extraction failed" def test_update_batch_csv_info(self, patched_session, sample_batch_upload): """Test updating batch with CSV information.""" repo = BatchUploadRepository() repo.update( sample_batch_upload.batch_id, csv_filename="manifest.csv", csv_row_count=100, ) patched_session.commit() batch = repo.get(sample_batch_upload.batch_id) assert batch.csv_filename == "manifest.csv" assert batch.csv_row_count == 100 class TestBatchUploadFiles: """Tests for batch upload file management.""" def test_create_batch_file(self, patched_session, sample_batch_upload): """Test creating a batch upload file record.""" repo = BatchUploadRepository() file_record = repo.create_file( batch_id=sample_batch_upload.batch_id, filename="invoice_001.pdf", status="pending", ) assert file_record is not None assert file_record.file_id is not None assert file_record.filename == "invoice_001.pdf" assert file_record.batch_id == sample_batch_upload.batch_id assert file_record.status == "pending" def test_create_batch_file_with_document_link(self, patched_session, sample_batch_upload, sample_document): """Test creating batch file linked to a document.""" repo = BatchUploadRepository() file_record = repo.create_file( batch_id=sample_batch_upload.batch_id, filename="invoice_linked.pdf", document_id=sample_document.document_id, status="completed", annotation_count=5, ) assert file_record.document_id == sample_document.document_id assert file_record.status == "completed" assert file_record.annotation_count == 5 def test_get_batch_files(self, patched_session, sample_batch_upload): """Test getting all files for a batch.""" repo = BatchUploadRepository() # Create multiple files for i in range(3): repo.create_file( batch_id=sample_batch_upload.batch_id, filename=f"file_{i}.pdf", ) files = repo.get_files(sample_batch_upload.batch_id) assert len(files) == 3 assert all(f.batch_id == sample_batch_upload.batch_id for f in files) def test_get_batch_files_empty(self, patched_session, sample_batch_upload): """Test getting files for batch with no files.""" repo = BatchUploadRepository() files = repo.get_files(sample_batch_upload.batch_id) assert files == [] def test_update_batch_file_status(self, patched_session, sample_batch_upload): """Test updating batch file status.""" repo = BatchUploadRepository() file_record = repo.create_file( batch_id=sample_batch_upload.batch_id, filename="test.pdf", ) repo.update_file( file_record.file_id, status="completed", annotation_count=10, ) patched_session.commit() files = repo.get_files(sample_batch_upload.batch_id) updated_file = files[0] assert updated_file.status == "completed" assert updated_file.annotation_count == 10 def test_update_batch_file_with_error(self, patched_session, sample_batch_upload): """Test updating batch file with error.""" repo = BatchUploadRepository() file_record = repo.create_file( batch_id=sample_batch_upload.batch_id, filename="corrupt.pdf", ) repo.update_file( file_record.file_id, status="failed", error_message="Invalid PDF format", ) patched_session.commit() files = repo.get_files(sample_batch_upload.batch_id) updated_file = files[0] assert updated_file.status == "failed" assert updated_file.error_message == "Invalid PDF format" def test_update_batch_file_with_csv_data(self, patched_session, sample_batch_upload): """Test updating batch file with CSV row data.""" repo = BatchUploadRepository() file_record = repo.create_file( batch_id=sample_batch_upload.batch_id, filename="invoice_with_csv.pdf", ) csv_data = { "invoice_number": "INV-001", "amount": "1500.00", "supplier": "Test Corp", } repo.update_file( file_record.file_id, csv_row_data=csv_data, ) patched_session.commit() files = repo.get_files(sample_batch_upload.batch_id) updated_file = files[0] assert updated_file.csv_row_data == csv_data class TestBatchUploadWorkflow: """Tests for complete batch upload workflows.""" def test_complete_batch_workflow(self, patched_session, admin_token): """Test complete batch upload workflow.""" repo = BatchUploadRepository() # 1. Create batch batch = repo.create( admin_token=admin_token.token, filename="full_workflow.zip", file_size=50000, ) # 2. Update with file count repo.update(batch.batch_id, total_files=3) patched_session.commit() # 3. Create file records file_ids = [] for i in range(3): file_record = repo.create_file( batch_id=batch.batch_id, filename=f"doc_{i}.pdf", ) file_ids.append(file_record.file_id) # 4. Process files one by one for i, file_id in enumerate(file_ids): status = "completed" if i < 2 else "failed" repo.update_file( file_id, status=status, annotation_count=5 if status == "completed" else 0, ) # 5. Update batch progress repo.update( batch.batch_id, processed_files=3, successful_files=2, failed_files=1, status="partial", ) patched_session.commit() # Verify final state final_batch = repo.get(batch.batch_id) assert final_batch.status == "partial" assert final_batch.total_files == 3 assert final_batch.processed_files == 3 assert final_batch.successful_files == 2 assert final_batch.failed_files == 1 files = repo.get_files(batch.batch_id) assert len(files) == 3 completed = [f for f in files if f.status == "completed"] failed = [f for f in files if f.status == "failed"] assert len(completed) == 2 assert len(failed) == 1