369 lines
12 KiB
Python
369 lines
12 KiB
Python
"""
Tests for Batch Upload Routes
"""
|
|
|
|
import io
|
|
import zipfile
|
|
from datetime import datetime
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
from fastapi import FastAPI
|
|
from fastapi.testclient import TestClient
|
|
|
|
from inference.web.api.v1.batch.routes import router
|
|
from inference.web.core.auth import validate_admin_token, get_admin_db
|
|
from inference.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
|
|
from inference.web.services.batch_upload import BatchUploadService
|
|
|
|
|
|
class MockAdminDB:
    """In-memory stand-in for the admin database used by the route tests.

    ``batches`` maps a batch id to its batch record; ``batch_files`` maps a
    batch id to the list of file records created for that batch.
    """

    def __init__(self):
        self.batches = {}
        self.batch_files = {}

    @staticmethod
    def _make_record(type_name, attrs):
        """Return an instance of a throwaway class whose attributes are ``attrs``."""
        return type(type_name, (), attrs)()

    def create_batch_upload(self, admin_token, filename, file_size, upload_source):
        """Create, store and return a new batch record in 'processing' state."""
        new_id = uuid4()
        record = self._make_record('BatchUpload', dict(
            batch_id=new_id,
            admin_token=admin_token,
            filename=filename,
            file_size=file_size,
            upload_source=upload_source,
            status='processing',
            total_files=0,
            processed_files=0,
            successful_files=0,
            failed_files=0,
            csv_filename=None,
            csv_row_count=None,
            error_message=None,
            created_at=datetime.utcnow(),
            completed_at=None,
        ))
        self.batches[new_id] = record
        return record

    def update_batch_upload(self, batch_id, **kwargs):
        """Set the given attributes on a stored batch; unknown ids are ignored."""
        if batch_id not in self.batches:
            return
        record = self.batches[batch_id]
        for attr_name, attr_value in kwargs.items():
            setattr(record, attr_name, attr_value)

    def create_batch_upload_file(self, batch_id, filename, **kwargs):
        """Create a file record for ``batch_id``; ``kwargs`` override the defaults."""
        new_file_id = uuid4()
        attrs = {
            'file_id': new_file_id,
            'batch_id': batch_id,
            'filename': filename,
            'status': 'pending',
            'error_message': None,
            'annotation_count': 0,
            'csv_row_data': None,
            **kwargs,
        }
        record = self._make_record('BatchUploadFile', attrs)
        self.batch_files.setdefault(batch_id, []).append(record)
        return record

    def update_batch_upload_file(self, file_id, **kwargs):
        """Set the given attributes on the first file record matching ``file_id``."""
        target = next(
            (
                candidate
                for records in self.batch_files.values()
                for candidate in records
                if candidate.file_id == file_id
            ),
            None,
        )
        if target is None:
            return
        for attr_name, attr_value in kwargs.items():
            setattr(target, attr_name, attr_value)

    def get_batch_upload(self, batch_id):
        """Return the stored batch, or a fabricated 'completed' batch if unknown.

        The fabricated record is not stored in ``self.batches``.
        """
        if batch_id in self.batches:
            return self.batches[batch_id]
        return self._make_record('BatchUpload', dict(
            batch_id=batch_id,
            admin_token='test-token',
            filename='test.zip',
            status='completed',
            total_files=2,
            processed_files=2,
            successful_files=2,
            failed_files=0,
            csv_filename=None,
            csv_row_count=None,
            error_message=None,
            created_at=datetime.utcnow(),
            completed_at=datetime.utcnow(),
        ))

    def get_batch_upload_files(self, batch_id):
        """Return the file records for ``batch_id`` (empty list if none)."""
        return self.batch_files.get(batch_id, [])

    def get_batch_uploads_by_token(self, admin_token, limit=50, offset=0):
        """Return ``(page, total)`` of the batches owned by ``admin_token``."""
        owned = list(
            filter(lambda rec: rec.admin_token == admin_token, self.batches.values())
        )
        page = owned[offset:offset + limit]
        return page, len(owned)
|
|
|
|
|
|
@pytest.fixture(scope="class")
def app():
    """Yield a FastAPI app wired to a mock admin DB, shared per test class.

    Auth and DB dependencies are overridden, the batch queue is initialised
    with a service backed by the same mock DB, and the queue is shut down
    once the requesting test class has finished.
    """
    application = FastAPI()
    fake_db = MockAdminDB()

    def _fake_token():
        return "test-token"

    def _fake_admin_db():
        return fake_db

    # Swap the real auth/DB dependencies for the in-memory fakes.
    application.dependency_overrides.update({
        validate_admin_token: _fake_token,
        get_admin_db: _fake_admin_db,
    })

    # The batch queue must be running before the routes are exercised.
    init_batch_queue(BatchUploadService(fake_db))

    application.include_router(router)

    yield application

    # Teardown: stop the batch queue after all tests in the class have run.
    shutdown_batch_queue()
|
|
|
|
|
|
@pytest.fixture
def client(app):
    """Return a fresh ``TestClient`` bound to the ``app`` fixture."""
    test_client = TestClient(app)
    return test_client
|
|
|
|
|
|
def create_test_zip(files):
    """Build an in-memory deflated ZIP archive from a name -> content mapping.

    Returns the ``io.BytesIO`` holding the archive, rewound to position 0 so
    it can be handed straight to an upload call.
    """
    buffer = io.BytesIO()
    archive = zipfile.ZipFile(buffer, mode='w', compression=zipfile.ZIP_DEFLATED)
    with archive:
        for member_name, payload in files.items():
            archive.writestr(member_name, payload)
    # Rewind so readers start at the beginning of the archive.
    buffer.seek(0)
    return buffer
|
|
|
|
|
|
class TestBatchUploadRoutes:
    """Exercise the batch upload, status and list endpoints of the admin API."""

    UPLOAD_URL = "/api/v1/admin/batch/upload"
    # Minimal single-document archive content reused by most upload tests.
    SINGLE_PDF = {"INV001.pdf": b"%PDF-1.4 test content"}

    @staticmethod
    def _upload(client, members, form, archive_name="test.zip"):
        """Zip ``members`` and POST the archive with ``form`` as form fields."""
        archive = create_test_zip(members)
        return client.post(
            TestBatchUploadRoutes.UPLOAD_URL,
            files={"file": (archive_name, archive, "application/zip")},
            data=form,
        )

    def test_upload_batch_success(self, client):
        """A two-PDF archive is accepted; async mode is the default."""
        members = {
            "INV001.pdf": b"%PDF-1.4 test content",
            "INV002.pdf": b"%PDF-1.4 test content 2",
        }

        response = self._upload(client, members, {"upload_source": "ui"})

        # Default (async) mode answers 202 Accepted.
        assert response.status_code == 202
        body = response.json()
        assert "batch_id" in body
        assert body["status"] == "accepted"

    def test_upload_batch_non_zip_file(self, client):
        """Uploading something that is not a ZIP is rejected with 400."""
        pdf_upload = ("test.pdf", io.BytesIO(b"test"), "application/pdf")

        response = client.post(
            self.UPLOAD_URL,
            files={"file": pdf_upload},
            data={"upload_source": "ui"},
        )

        assert response.status_code == 400
        assert "Only ZIP files" in response.json()["detail"]

    def test_upload_batch_with_csv(self, client):
        """An archive carrying a metadata CSV is accepted (async by default)."""
        csv_text = (
            "DocumentId,InvoiceNumber,Amount\n"
            "INV001,F2024-001,1500.00\n"
            "INV002,F2024-002,2500.00\n"
        )
        members = {
            "INV001.pdf": b"%PDF-1.4 test",
            "INV002.pdf": b"%PDF-1.4 test 2",
            "metadata.csv": csv_text.encode('utf-8'),
        }

        response = self._upload(
            client, members, {"upload_source": "api"}, archive_name="batch.zip"
        )

        # Default (async) mode answers 202 Accepted.
        assert response.status_code == 202
        body = response.json()
        assert "batch_id" in body
        assert body["status"] == "accepted"

    def test_get_batch_status(self, client):
        """The status endpoint echoes the batch id and reports progress fields."""
        requested_id = str(uuid4())

        response = client.get(f"/api/v1/admin/batch/status/{requested_id}")

        assert response.status_code == 200
        body = response.json()
        assert body["batch_id"] == requested_id
        assert "status" in body
        assert "total_files" in body

    def test_list_batch_uploads(self, client):
        """The list endpoint returns batches together with paging metadata."""
        response = client.get("/api/v1/admin/batch/list")

        assert response.status_code == 200
        body = response.json()
        for expected_key in ("batches", "total", "limit", "offset"):
            assert expected_key in body

    def test_upload_batch_async_mode_default(self, client):
        """Omitting async_mode queues the batch and returns queue details."""
        response = self._upload(client, self.SINGLE_PDF, {"upload_source": "ui"})

        # Async processing is signalled with 202 Accepted.
        assert response.status_code == 202
        body = response.json()
        assert body["status"] == "accepted"
        for expected_key in ("batch_id", "status_url", "queue_depth"):
            assert expected_key in body
        assert body["message"] == "Batch upload queued for processing"

    def test_upload_batch_async_mode_explicit(self, client):
        """Passing async_mode=true explicitly also queues the batch."""
        form = {"upload_source": "ui", "async_mode": "true"}

        response = self._upload(client, self.SINGLE_PDF, form)

        assert response.status_code == 202
        body = response.json()
        assert body["status"] == "accepted"
        assert "batch_id" in body
        assert "status_url" in body

    def test_upload_batch_sync_mode(self, client):
        """Passing async_mode=false processes inline and returns full results."""
        form = {"upload_source": "ui", "async_mode": "false"}

        response = self._upload(client, self.SINGLE_PDF, form)

        # Synchronous processing answers 200 OK with the outcome.
        assert response.status_code == 200
        body = response.json()
        assert "batch_id" in body
        assert body["status"] in ["completed", "partial", "failed"]
        assert "successful_files" in body

    def test_upload_batch_async_with_auto_label(self, client):
        """Async upload with auto_label enabled is accepted."""
        form = {
            "upload_source": "ui",
            "async_mode": "true",
            "auto_label": "true",
        }

        response = self._upload(client, self.SINGLE_PDF, form)

        assert response.status_code == 202
        body = response.json()
        assert body["status"] == "accepted"
        assert "batch_id" in body

    def test_upload_batch_async_without_auto_label(self, client):
        """Async upload with auto_label disabled is accepted."""
        form = {
            "upload_source": "ui",
            "async_mode": "true",
            "auto_label": "false",
        }

        response = self._upload(client, self.SINGLE_PDF, form)

        assert response.status_code == 202
        body = response.json()
        assert body["status"] == "accepted"

    def test_upload_batch_queue_full(self, client):
        """Smoke-check the async path that may report a full queue.

        Forcing a full queue would require mocking the queue's submit call;
        for now this only verifies the endpoint responds with an expected
        status code.
        """
        form = {"upload_source": "ui", "async_mode": "true"}

        response = self._upload(client, self.SINGLE_PDF, form)

        # Either accepted (202) or rejected because the queue is full (503).
        assert response.status_code in [202, 503]

    def test_async_status_url_format(self, client):
        """The async response's status_url points at the batch status route."""
        response = self._upload(client, self.SINGLE_PDF, {"async_mode": "true"})

        assert response.status_code == 202
        body = response.json()
        returned_id = body["batch_id"]
        assert body["status_url"] == f"/api/v1/admin/batch/status/{returned_id}"