Add more tests

This commit is contained in:
Yaojia Wang
2026-02-01 22:40:41 +01:00
parent a564ac9d70
commit 400b12a967
55 changed files with 9306 additions and 267 deletions

View File

@@ -0,0 +1,355 @@
"""
Batch Upload Repository Integration Tests
Tests BatchUploadRepository with real database operations.
"""
from datetime import datetime, timezone
from uuid import uuid4
import pytest
from inference.data.repositories.batch_upload_repository import BatchUploadRepository
class TestBatchUploadCreate:
"""Tests for batch upload creation."""
def test_create_batch_upload(self, patched_session, admin_token):
"""Test creating a batch upload."""
repo = BatchUploadRepository()
batch = repo.create(
admin_token=admin_token.token,
filename="test_batch.zip",
file_size=10240,
upload_source="api",
)
assert batch is not None
assert batch.batch_id is not None
assert batch.filename == "test_batch.zip"
assert batch.file_size == 10240
assert batch.upload_source == "api"
assert batch.status == "processing"
assert batch.total_files == 0
assert batch.processed_files == 0
def test_create_batch_upload_default_source(self, patched_session, admin_token):
"""Test creating batch upload with default source."""
repo = BatchUploadRepository()
batch = repo.create(
admin_token=admin_token.token,
filename="ui_batch.zip",
file_size=5120,
)
assert batch.upload_source == "ui"
class TestBatchUploadRead:
"""Tests for batch upload retrieval."""
def test_get_batch_upload(self, patched_session, sample_batch_upload):
"""Test getting a batch upload by ID."""
repo = BatchUploadRepository()
batch = repo.get(sample_batch_upload.batch_id)
assert batch is not None
assert batch.batch_id == sample_batch_upload.batch_id
assert batch.filename == sample_batch_upload.filename
def test_get_nonexistent_batch_upload(self, patched_session):
"""Test getting a batch upload that doesn't exist."""
repo = BatchUploadRepository()
batch = repo.get(uuid4())
assert batch is None
def test_get_paginated_batch_uploads(self, patched_session, admin_token):
"""Test paginated batch upload listing."""
repo = BatchUploadRepository()
# Create multiple batches
for i in range(5):
repo.create(
admin_token=admin_token.token,
filename=f"batch_{i}.zip",
file_size=1024 * (i + 1),
)
batches, total = repo.get_paginated(limit=3, offset=0)
assert total == 5
assert len(batches) == 3
def test_get_paginated_with_offset(self, patched_session, admin_token):
"""Test pagination offset."""
repo = BatchUploadRepository()
for i in range(5):
repo.create(
admin_token=admin_token.token,
filename=f"batch_{i}.zip",
file_size=1024,
)
page1, _ = repo.get_paginated(limit=2, offset=0)
page2, _ = repo.get_paginated(limit=2, offset=2)
ids_page1 = {b.batch_id for b in page1}
ids_page2 = {b.batch_id for b in page2}
assert len(ids_page1 & ids_page2) == 0
class TestBatchUploadUpdate:
"""Tests for batch upload updates."""
def test_update_batch_status(self, patched_session, sample_batch_upload):
"""Test updating batch upload status."""
repo = BatchUploadRepository()
repo.update(
sample_batch_upload.batch_id,
status="completed",
total_files=10,
processed_files=10,
successful_files=8,
failed_files=2,
)
# Need to commit to see changes
patched_session.commit()
batch = repo.get(sample_batch_upload.batch_id)
assert batch.status == "completed"
assert batch.total_files == 10
assert batch.successful_files == 8
assert batch.failed_files == 2
def test_update_batch_with_error(self, patched_session, sample_batch_upload):
"""Test updating batch upload with error message."""
repo = BatchUploadRepository()
repo.update(
sample_batch_upload.batch_id,
status="failed",
error_message="ZIP extraction failed",
)
patched_session.commit()
batch = repo.get(sample_batch_upload.batch_id)
assert batch.status == "failed"
assert batch.error_message == "ZIP extraction failed"
def test_update_batch_csv_info(self, patched_session, sample_batch_upload):
"""Test updating batch with CSV information."""
repo = BatchUploadRepository()
repo.update(
sample_batch_upload.batch_id,
csv_filename="manifest.csv",
csv_row_count=100,
)
patched_session.commit()
batch = repo.get(sample_batch_upload.batch_id)
assert batch.csv_filename == "manifest.csv"
assert batch.csv_row_count == 100
class TestBatchUploadFiles:
"""Tests for batch upload file management."""
def test_create_batch_file(self, patched_session, sample_batch_upload):
"""Test creating a batch upload file record."""
repo = BatchUploadRepository()
file_record = repo.create_file(
batch_id=sample_batch_upload.batch_id,
filename="invoice_001.pdf",
status="pending",
)
assert file_record is not None
assert file_record.file_id is not None
assert file_record.filename == "invoice_001.pdf"
assert file_record.batch_id == sample_batch_upload.batch_id
assert file_record.status == "pending"
def test_create_batch_file_with_document_link(self, patched_session, sample_batch_upload, sample_document):
"""Test creating batch file linked to a document."""
repo = BatchUploadRepository()
file_record = repo.create_file(
batch_id=sample_batch_upload.batch_id,
filename="invoice_linked.pdf",
document_id=sample_document.document_id,
status="completed",
annotation_count=5,
)
assert file_record.document_id == sample_document.document_id
assert file_record.status == "completed"
assert file_record.annotation_count == 5
def test_get_batch_files(self, patched_session, sample_batch_upload):
"""Test getting all files for a batch."""
repo = BatchUploadRepository()
# Create multiple files
for i in range(3):
repo.create_file(
batch_id=sample_batch_upload.batch_id,
filename=f"file_{i}.pdf",
)
files = repo.get_files(sample_batch_upload.batch_id)
assert len(files) == 3
assert all(f.batch_id == sample_batch_upload.batch_id for f in files)
def test_get_batch_files_empty(self, patched_session, sample_batch_upload):
"""Test getting files for batch with no files."""
repo = BatchUploadRepository()
files = repo.get_files(sample_batch_upload.batch_id)
assert files == []
def test_update_batch_file_status(self, patched_session, sample_batch_upload):
"""Test updating batch file status."""
repo = BatchUploadRepository()
file_record = repo.create_file(
batch_id=sample_batch_upload.batch_id,
filename="test.pdf",
)
repo.update_file(
file_record.file_id,
status="completed",
annotation_count=10,
)
patched_session.commit()
files = repo.get_files(sample_batch_upload.batch_id)
updated_file = files[0]
assert updated_file.status == "completed"
assert updated_file.annotation_count == 10
def test_update_batch_file_with_error(self, patched_session, sample_batch_upload):
"""Test updating batch file with error."""
repo = BatchUploadRepository()
file_record = repo.create_file(
batch_id=sample_batch_upload.batch_id,
filename="corrupt.pdf",
)
repo.update_file(
file_record.file_id,
status="failed",
error_message="Invalid PDF format",
)
patched_session.commit()
files = repo.get_files(sample_batch_upload.batch_id)
updated_file = files[0]
assert updated_file.status == "failed"
assert updated_file.error_message == "Invalid PDF format"
def test_update_batch_file_with_csv_data(self, patched_session, sample_batch_upload):
"""Test updating batch file with CSV row data."""
repo = BatchUploadRepository()
file_record = repo.create_file(
batch_id=sample_batch_upload.batch_id,
filename="invoice_with_csv.pdf",
)
csv_data = {
"invoice_number": "INV-001",
"amount": "1500.00",
"supplier": "Test Corp",
}
repo.update_file(
file_record.file_id,
csv_row_data=csv_data,
)
patched_session.commit()
files = repo.get_files(sample_batch_upload.batch_id)
updated_file = files[0]
assert updated_file.csv_row_data == csv_data
class TestBatchUploadWorkflow:
"""Tests for complete batch upload workflows."""
def test_complete_batch_workflow(self, patched_session, admin_token):
"""Test complete batch upload workflow."""
repo = BatchUploadRepository()
# 1. Create batch
batch = repo.create(
admin_token=admin_token.token,
filename="full_workflow.zip",
file_size=50000,
)
# 2. Update with file count
repo.update(batch.batch_id, total_files=3)
patched_session.commit()
# 3. Create file records
file_ids = []
for i in range(3):
file_record = repo.create_file(
batch_id=batch.batch_id,
filename=f"doc_{i}.pdf",
)
file_ids.append(file_record.file_id)
# 4. Process files one by one
for i, file_id in enumerate(file_ids):
status = "completed" if i < 2 else "failed"
repo.update_file(
file_id,
status=status,
annotation_count=5 if status == "completed" else 0,
)
# 5. Update batch progress
repo.update(
batch.batch_id,
processed_files=3,
successful_files=2,
failed_files=1,
status="partial",
)
patched_session.commit()
# Verify final state
final_batch = repo.get(batch.batch_id)
assert final_batch.status == "partial"
assert final_batch.total_files == 3
assert final_batch.processed_files == 3
assert final_batch.successful_files == 2
assert final_batch.failed_files == 1
files = repo.get_files(batch.batch_id)
assert len(files) == 3
completed = [f for f in files if f.status == "completed"]
failed = [f for f in files if f.status == "failed"]
assert len(completed) == 2
assert len(failed) == 1