""" Tests for Batch Upload Routes """ import io import zipfile from datetime import datetime from uuid import uuid4 import pytest from fastapi import FastAPI from fastapi.testclient import TestClient from inference.web.api.v1.batch.routes import router, get_batch_repository from inference.web.core.auth import validate_admin_token from inference.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue from inference.web.services.batch_upload import BatchUploadService class MockBatchUploadRepository: """Mock BatchUploadRepository for testing.""" def __init__(self): self.batches = {} self.batch_files = {} def create(self, admin_token, filename, file_size, upload_source="ui"): batch_id = uuid4() batch = type('BatchUpload', (), { 'batch_id': batch_id, 'admin_token': admin_token, 'filename': filename, 'file_size': file_size, 'upload_source': upload_source, 'status': 'processing', 'total_files': 0, 'processed_files': 0, 'successful_files': 0, 'failed_files': 0, 'csv_filename': None, 'csv_row_count': None, 'error_message': None, 'created_at': datetime.utcnow(), 'completed_at': None, })() self.batches[batch_id] = batch return batch def update(self, batch_id, **kwargs): if batch_id in self.batches: batch = self.batches[batch_id] for key, value in kwargs.items(): setattr(batch, key, value) def create_file(self, batch_id, filename, **kwargs): file_id = uuid4() defaults = { 'file_id': file_id, 'batch_id': batch_id, 'filename': filename, 'status': 'pending', 'error_message': None, 'annotation_count': 0, 'csv_row_data': None, } defaults.update(kwargs) file_record = type('BatchUploadFile', (), defaults)() if batch_id not in self.batch_files: self.batch_files[batch_id] = [] self.batch_files[batch_id].append(file_record) return file_record def update_file(self, file_id, **kwargs): for files in self.batch_files.values(): for file_record in files: if file_record.file_id == file_id: for key, value in kwargs.items(): setattr(file_record, key, value) return def get(self, batch_id): return self.batches.get(batch_id, type('BatchUpload', (), { 'batch_id': batch_id, 'admin_token': 'test-token', 'filename': 'test.zip', 'status': 'completed', 'total_files': 2, 'processed_files': 2, 'successful_files': 2, 'failed_files': 0, 'csv_filename': None, 'csv_row_count': None, 'error_message': None, 'created_at': datetime.utcnow(), 'completed_at': datetime.utcnow(), })()) def get_files(self, batch_id): return self.batch_files.get(batch_id, []) def get_paginated(self, admin_token=None, limit=50, offset=0): """Get batches filtered by admin token with pagination.""" if admin_token: token_batches = [b for b in self.batches.values() if b.admin_token == admin_token] else: token_batches = list(self.batches.values()) total = len(token_batches) return token_batches[offset:offset+limit], total @pytest.fixture(scope="class") def app(): """Create test FastAPI app with mocked dependencies.""" app = FastAPI() # Create mock batch upload repository mock_batch_upload_repo = MockBatchUploadRepository() # Override dependencies app.dependency_overrides[validate_admin_token] = lambda: "test-token" app.dependency_overrides[get_batch_repository] = lambda: mock_batch_upload_repo # Initialize batch queue with mock service batch_service = BatchUploadService(mock_batch_upload_repo) init_batch_queue(batch_service) app.include_router(router) yield app # Cleanup: shutdown batch queue after all tests in class shutdown_batch_queue() @pytest.fixture def client(app): """Create test client.""" return TestClient(app) def create_test_zip(files): """Create a test ZIP file.""" zip_buffer = io.BytesIO() with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file: for filename, content in files.items(): zip_file.writestr(filename, content) zip_buffer.seek(0) return zip_buffer class TestBatchUploadRoutes: """Tests for batch upload API routes.""" def test_upload_batch_success(self, client): """Test successful batch upload (defaults to async mode).""" files = { "INV001.pdf": b"%PDF-1.4 test content", "INV002.pdf": b"%PDF-1.4 test content 2", } zip_file = create_test_zip(files) response = client.post( "/api/v1/admin/batch/upload", files={"file": ("test.zip", zip_file, "application/zip")}, data={"upload_source": "ui"}, ) # Async mode is default, should return 202 assert response.status_code == 202 result = response.json() assert "batch_id" in result assert result["status"] == "accepted" def test_upload_batch_non_zip_file(self, client): """Test uploading non-ZIP file.""" response = client.post( "/api/v1/admin/batch/upload", files={"file": ("test.pdf", io.BytesIO(b"test"), "application/pdf")}, data={"upload_source": "ui"}, ) assert response.status_code == 400 assert "Only ZIP files" in response.json()["detail"] def test_upload_batch_with_csv(self, client): """Test batch upload with CSV (defaults to async).""" csv_content = """DocumentId,InvoiceNumber,Amount INV001,F2024-001,1500.00 INV002,F2024-002,2500.00 """ files = { "INV001.pdf": b"%PDF-1.4 test", "INV002.pdf": b"%PDF-1.4 test 2", "metadata.csv": csv_content.encode('utf-8'), } zip_file = create_test_zip(files) response = client.post( "/api/v1/admin/batch/upload", files={"file": ("batch.zip", zip_file, "application/zip")}, data={"upload_source": "api"}, ) # Async mode is default, should return 202 assert response.status_code == 202 result = response.json() assert "batch_id" in result assert result["status"] == "accepted" def test_get_batch_status(self, client): """Test getting batch status.""" batch_id = str(uuid4()) response = client.get(f"/api/v1/admin/batch/status/{batch_id}") assert response.status_code == 200 result = response.json() assert result["batch_id"] == batch_id assert "status" in result assert "total_files" in result def test_list_batch_uploads(self, client): """Test listing batch uploads.""" response = client.get("/api/v1/admin/batch/list") assert response.status_code == 200 result = response.json() assert "batches" in result assert "total" in result assert "limit" in result assert "offset" in result def test_upload_batch_async_mode_default(self, client): """Test async mode is default (async_mode=True).""" files = { "INV001.pdf": b"%PDF-1.4 test content", } zip_file = create_test_zip(files) response = client.post( "/api/v1/admin/batch/upload", files={"file": ("test.zip", zip_file, "application/zip")}, data={"upload_source": "ui"}, ) # Async mode should return 202 Accepted assert response.status_code == 202 result = response.json() assert result["status"] == "accepted" assert "batch_id" in result assert "status_url" in result assert "queue_depth" in result assert result["message"] == "Batch upload queued for processing" def test_upload_batch_async_mode_explicit(self, client): """Test explicit async mode (async_mode=True).""" files = { "INV001.pdf": b"%PDF-1.4 test content", } zip_file = create_test_zip(files) response = client.post( "/api/v1/admin/batch/upload", files={"file": ("test.zip", zip_file, "application/zip")}, data={"upload_source": "ui", "async_mode": "true"}, ) assert response.status_code == 202 result = response.json() assert result["status"] == "accepted" assert "batch_id" in result assert "status_url" in result def test_upload_batch_sync_mode(self, client): """Test sync mode (async_mode=False).""" files = { "INV001.pdf": b"%PDF-1.4 test content", } zip_file = create_test_zip(files) response = client.post( "/api/v1/admin/batch/upload", files={"file": ("test.zip", zip_file, "application/zip")}, data={"upload_source": "ui", "async_mode": "false"}, ) # Sync mode should return 200 OK with full results assert response.status_code == 200 result = response.json() assert "batch_id" in result assert result["status"] in ["completed", "partial", "failed"] assert "successful_files" in result def test_upload_batch_async_with_auto_label(self, client): """Test async mode with auto_label flag.""" files = { "INV001.pdf": b"%PDF-1.4 test content", } zip_file = create_test_zip(files) response = client.post( "/api/v1/admin/batch/upload", files={"file": ("test.zip", zip_file, "application/zip")}, data={ "upload_source": "ui", "async_mode": "true", "auto_label": "true", }, ) assert response.status_code == 202 result = response.json() assert result["status"] == "accepted" assert "batch_id" in result def test_upload_batch_async_without_auto_label(self, client): """Test async mode with auto_label disabled.""" files = { "INV001.pdf": b"%PDF-1.4 test content", } zip_file = create_test_zip(files) response = client.post( "/api/v1/admin/batch/upload", files={"file": ("test.zip", zip_file, "application/zip")}, data={ "upload_source": "ui", "async_mode": "true", "auto_label": "false", }, ) assert response.status_code == 202 result = response.json() assert result["status"] == "accepted" def test_upload_batch_queue_full(self, client): """Test handling queue full scenario.""" # This test would require mocking the queue to return False on submit # For now, we verify the endpoint accepts the request files = { "INV001.pdf": b"%PDF-1.4 test content", } zip_file = create_test_zip(files) response = client.post( "/api/v1/admin/batch/upload", files={"file": ("test.zip", zip_file, "application/zip")}, data={"upload_source": "ui", "async_mode": "true"}, ) # Should either accept (202) or reject if queue full (503) assert response.status_code in [202, 503] def test_async_status_url_format(self, client): """Test async response contains correctly formatted status URL.""" files = { "INV001.pdf": b"%PDF-1.4 test content", } zip_file = create_test_zip(files) response = client.post( "/api/v1/admin/batch/upload", files={"file": ("test.zip", zip_file, "application/zip")}, data={"async_mode": "true"}, ) assert response.status_code == 202 result = response.json() batch_id = result["batch_id"] expected_url = f"/api/v1/admin/batch/status/{batch_id}" assert result["status_url"] == expected_url