525 lines
16 KiB
Python
525 lines
16 KiB
Python
"""
|
|
Tests for Admin Models v2 - Batch Upload and Training Links.
|
|
|
|
Tests for new SQLModel classes: BatchUpload, BatchUploadFile,
|
|
TrainingDocumentLink, AnnotationHistory.
|
|
"""
|
|
|
|
import pytest
|
|
from datetime import datetime
|
|
from uuid import UUID, uuid4
|
|
|
|
from inference.data.admin_models import (
|
|
BatchUpload,
|
|
BatchUploadFile,
|
|
TrainingDocumentLink,
|
|
AnnotationHistory,
|
|
AdminDocument,
|
|
AdminAnnotation,
|
|
TrainingTask,
|
|
FIELD_CLASSES,
|
|
CSV_TO_CLASS_MAPPING,
|
|
)
|
|
|
|
|
|
class TestBatchUpload:
    """Construction checks for the BatchUpload model."""

    def test_batch_upload_creation(self):
        """A minimally specified upload keeps its inputs and starts with zeroed progress."""
        upload = BatchUpload(
            admin_token="test-token",
            filename="invoices.zip",
            file_size=1024000,
            upload_source="ui",
        )

        # The identifier is expected to be populated without being passed in.
        assert upload.batch_id is not None
        assert isinstance(upload.batch_id, UUID)

        # Explicitly supplied values come back unchanged.
        assert upload.admin_token == "test-token"
        assert upload.filename == "invoices.zip"
        assert upload.file_size == 1024000
        assert upload.upload_source == "ui"

        # Everything not supplied should sit at its initial value.
        assert upload.status == "processing"
        for counter in (
            "total_files",
            "processed_files",
            "successful_files",
            "failed_files",
        ):
            assert getattr(upload, counter) == 0
        assert upload.error_message is None
        assert upload.completed_at is None

    def test_batch_upload_api_source(self):
        """The upload source can be set to "api"."""
        upload = BatchUpload(
            admin_token="api-token",
            filename="batch.zip",
            file_size=2048000,
            upload_source="api",
        )

        assert upload.upload_source == "api"

    def test_batch_upload_with_progress(self):
        """Progress counters passed to the constructor are stored as given."""
        counters = {
            "total_files": 100,
            "processed_files": 50,
            "successful_files": 48,
            "failed_files": 2,
        }
        upload = BatchUpload(
            admin_token="test-token",
            filename="large_batch.zip",
            file_size=10240000,
            status="processing",
            **counters,
        )

        for field_name, expected in counters.items():
            assert getattr(upload, field_name) == expected

    def test_batch_upload_completed(self):
        """A completed upload keeps its status and completion timestamp."""
        finished_at = datetime.utcnow()
        upload = BatchUpload(
            admin_token="test-token",
            filename="batch.zip",
            file_size=1024000,
            status="completed",
            total_files=10,
            processed_files=10,
            successful_files=10,
            failed_files=0,
            completed_at=finished_at,
        )

        assert upload.status == "completed"
        assert upload.completed_at == finished_at

    def test_batch_upload_failed(self):
        """A failed upload keeps its status and error message."""
        upload = BatchUpload(
            admin_token="test-token",
            filename="bad.zip",
            file_size=1024,
            status="failed",
            error_message="Invalid ZIP file format",
        )

        assert upload.status == "failed"
        assert upload.error_message == "Invalid ZIP file format"

    def test_batch_upload_partial(self):
        """A partially successful upload keeps its status and failure count."""
        upload = BatchUpload(
            admin_token="test-token",
            filename="mixed.zip",
            file_size=5120000,
            status="partial",
            total_files=20,
            processed_files=20,
            successful_files=15,
            failed_files=5,
        )

        assert upload.status == "partial"
        assert upload.failed_files == 5
|
|
|
|
|
|
class TestBatchUploadFile:
    """Tests for BatchUploadFile model."""

    def test_batch_upload_file_creation(self):
        """Test basic file record creation.

        Only the batch id and filename are supplied; every other field is
        expected to hold its initial value.
        """
        batch_id = uuid4()
        file_record = BatchUploadFile(
            batch_id=batch_id,
            filename="INV001.pdf",
        )

        assert file_record.file_id is not None
        assert isinstance(file_record.file_id, UUID)
        assert file_record.batch_id == batch_id
        assert file_record.filename == "INV001.pdf"
        assert file_record.status == "pending"
        assert file_record.document_id is None
        assert file_record.error_message is None
        assert file_record.csv_row_data is None
        assert file_record.processed_at is None

    def test_batch_upload_file_with_document(self):
        """Test file record linked to document."""
        batch_id = uuid4()
        document_id = uuid4()
        file_record = BatchUploadFile(
            batch_id=batch_id,
            document_id=document_id,
            filename="INV002.pdf",
            status="completed",
        )

        assert file_record.document_id == document_id
        assert file_record.status == "completed"

    def test_batch_upload_file_with_csv_data(self):
        """Test file record with CSV row data."""
        batch_id = uuid4()
        csv_data = {
            "DocumentId": "INV003",
            "InvoiceNumber": "F2024-003",
            "Amount": "1500.00",
            "OCR": "7350012345678",
        }
        file_record = BatchUploadFile(
            batch_id=batch_id,
            filename="INV003.pdf",
            csv_row_data=csv_data,
        )

        assert file_record.csv_row_data == csv_data
        # Check every key individually so a failure names the offending column.
        for column, value in csv_data.items():
            assert file_record.csv_row_data[column] == value, column

    def test_batch_upload_file_failed(self):
        """Test failed file record."""
        batch_id = uuid4()
        file_record = BatchUploadFile(
            batch_id=batch_id,
            filename="corrupted.pdf",
            status="failed",
            error_message="Corrupted PDF file",
        )

        assert file_record.status == "failed"
        assert file_record.error_message == "Corrupted PDF file"

    def test_batch_upload_file_skipped(self):
        """Test skipped file record."""
        batch_id = uuid4()
        file_record = BatchUploadFile(
            batch_id=batch_id,
            filename="not_a_pdf.txt",
            status="skipped",
            error_message="Not a PDF file",
        )

        assert file_record.status == "skipped"
        # The skip reason is supplied above, so verify it round-trips just as
        # the failed-record test does for its error message.
        assert file_record.error_message == "Not a PDF file"
|
|
|
|
|
|
class TestTrainingDocumentLink:
    """Construction checks for the TrainingDocumentLink model."""

    def test_training_document_link_creation(self):
        """A link built from task and document ids has an id and no snapshot."""
        task_ref, document_ref = uuid4(), uuid4()

        link = TrainingDocumentLink(task_id=task_ref, document_id=document_ref)

        assert link.link_id is not None
        assert isinstance(link.link_id, UUID)
        assert link.task_id == task_ref
        assert link.document_id == document_ref
        assert link.annotation_snapshot is None

    def test_training_document_link_with_snapshot(self):
        """A supplied annotation snapshot is stored on the link unchanged."""
        task_ref, document_ref = uuid4(), uuid4()
        invoice_number_entry = {
            "class_id": 0,
            "class_name": "invoice_number",
            "text_value": "F2024-001",
            "x_center": 0.5,
            "y_center": 0.3,
        }
        amount_entry = {
            "class_id": 6,
            "class_name": "amount",
            "text_value": "1500.00",
            "x_center": 0.7,
            "y_center": 0.6,
        }
        snapshot_payload = {
            "annotations": [invoice_number_entry, amount_entry],
            "total_count": 2,
            "snapshot_time": "2024-01-20T15:00:00",
        }

        link = TrainingDocumentLink(
            task_id=task_ref,
            document_id=document_ref,
            annotation_snapshot=snapshot_payload,
        )

        assert link.annotation_snapshot == snapshot_payload
        assert len(link.annotation_snapshot["annotations"]) == 2
|
|
|
|
|
|
class TestAnnotationHistory:
    """Tests for AnnotationHistory model."""

    def test_annotation_history_created(self):
        """Test history record for creation.

        A "created" record carries only a new value; no previous value is
        supplied, so it is expected to be None.
        """
        annotation_id = uuid4()
        new_value = {
            "class_id": 0,
            "class_name": "invoice_number",
            "text_value": "F2024-001",
            "bbox_x": 100,
            "bbox_y": 200,
            "bbox_width": 150,
            "bbox_height": 30,
            "source": "manual",
        }
        history = AnnotationHistory(
            annotation_id=annotation_id,
            action="created",
            new_value=new_value,
            changed_by="admin-token-123",
        )

        assert history.history_id is not None
        assert history.annotation_id == annotation_id
        assert history.action == "created"
        assert history.previous_value is None
        assert history.new_value == new_value
        assert history.changed_by == "admin-token-123"

    def test_annotation_history_updated(self):
        """Test history record for update."""
        annotation_id = uuid4()
        previous_value = {
            "text_value": "F2024-001",
            "bbox_x": 100,
        }
        new_value = {
            "text_value": "F2024-001-A",
            "bbox_x": 110,
        }
        history = AnnotationHistory(
            annotation_id=annotation_id,
            action="updated",
            previous_value=previous_value,
            new_value=new_value,
            changed_by="admin-token-123",
            change_reason="Corrected OCR error",
        )

        assert history.action == "updated"
        assert history.previous_value == previous_value
        assert history.new_value == new_value
        assert history.change_reason == "Corrected OCR error"

    def test_annotation_history_override(self):
        """Test history record for override."""
        annotation_id = uuid4()
        previous_value = {
            "text_value": "F2024-001",
            "source": "auto",
            "confidence": 0.85,
        }
        new_value = {
            "text_value": "F2024-001-CORRECTED",
            "source": "manual",
            "confidence": None,
        }
        history = AnnotationHistory(
            annotation_id=annotation_id,
            action="override",
            previous_value=previous_value,
            new_value=new_value,
            changed_by="admin-token-123",
            change_reason="Manual correction of auto-label",
        )

        assert history.action == "override"
        # Verify the supplied payloads too, mirroring the "updated" test, so the
        # values built above are actually exercised.
        assert history.annotation_id == annotation_id
        assert history.previous_value == previous_value
        assert history.new_value == new_value
        assert history.changed_by == "admin-token-123"
        assert history.change_reason == "Manual correction of auto-label"

    def test_annotation_history_deleted(self):
        """Test history record for deletion.

        A "deleted" record carries only the previous value; no new value is
        supplied, so it is expected to be None.
        """
        annotation_id = uuid4()
        previous_value = {
            "class_id": 6,
            "class_name": "amount",
            "text_value": "1500.00",
        }
        history = AnnotationHistory(
            annotation_id=annotation_id,
            action="deleted",
            previous_value=previous_value,
            changed_by="admin-token-123",
            change_reason="Incorrect annotation",
        )

        assert history.action == "deleted"
        assert history.new_value is None
        # The removed annotation's state and the audit details must round-trip.
        assert history.previous_value == previous_value
        assert history.changed_by == "admin-token-123"
        assert history.change_reason == "Incorrect annotation"
|
|
|
|
|
|
class TestAdminDocumentExtensions:
    """Checks for the extension fields accepted by AdminDocument."""

    # Constructor arguments shared by every test in this class; each test adds
    # the single extension field it is interested in.
    _COMMON_FIELDS = {
        "admin_token": "test-token",
        "filename": "invoice.pdf",
        "file_size": 1024,
        "content_type": "application/pdf",
        "file_path": "/tmp/invoice.pdf",
    }

    def test_document_with_upload_source(self):
        """The upload_source value passed in is stored on the document."""
        document = AdminDocument(**self._COMMON_FIELDS, upload_source="api")

        assert document.upload_source == "api"

    def test_document_with_batch_id(self):
        """The batch_id value passed in is stored on the document."""
        owning_batch = uuid4()

        document = AdminDocument(**self._COMMON_FIELDS, batch_id=owning_batch)

        assert document.batch_id == owning_batch

    def test_document_with_csv_field_values(self):
        """The csv_field_values mapping passed in is stored on the document."""
        expected_values = {
            "InvoiceNumber": "F2024-001",
            "Amount": "1500.00",
            "OCR": "7350012345678",
        }

        document = AdminDocument(
            **self._COMMON_FIELDS,
            csv_field_values=expected_values,
        )

        assert document.csv_field_values == expected_values

    def test_document_with_annotation_lock(self):
        """The annotation_lock_until timestamp passed in is stored on the document."""
        locked_until = datetime.utcnow()

        document = AdminDocument(
            **self._COMMON_FIELDS,
            annotation_lock_until=locked_until,
        )

        assert document.annotation_lock_until == locked_until
|
|
|
|
|
|
class TestAdminAnnotationExtensions:
    """Checks for the extension fields accepted by AdminAnnotation."""

    # Class and geometry arguments shared by both tests; the document id is
    # generated per test so each annotation gets a fresh one.
    _SHARED_FIELDS = {
        "class_id": 0,
        "class_name": "invoice_number",
        "x_center": 0.5,
        "y_center": 0.3,
        "width": 0.2,
        "height": 0.05,
        "bbox_x": 100,
        "bbox_y": 200,
        "bbox_width": 150,
        "bbox_height": 30,
    }

    def test_annotation_with_verification(self):
        """Verification flag, timestamp and verifier are stored as given."""
        verified_when = datetime.utcnow()

        annotation = AdminAnnotation(
            document_id=uuid4(),
            **self._SHARED_FIELDS,
            is_verified=True,
            verified_at=verified_when,
            verified_by="admin-token-123",
        )

        assert annotation.is_verified is True
        assert annotation.verified_at == verified_when
        assert annotation.verified_by == "admin-token-123"

    def test_annotation_with_override_info(self):
        """Override source and original annotation id are stored as given."""
        replaced_annotation = uuid4()

        annotation = AdminAnnotation(
            document_id=uuid4(),
            **self._SHARED_FIELDS,
            source="manual",
            override_source="auto",
            original_annotation_id=replaced_annotation,
        )

        assert annotation.override_source == "auto"
        assert annotation.original_annotation_id == replaced_annotation
|
|
|
|
|
|
class TestTrainingTaskExtensions:
    """Checks for the extension fields accepted by TrainingTask."""

    def test_training_task_with_document_count(self):
        """The document_count passed in is stored on the task."""
        training_run = TrainingTask(
            admin_token="test-token",
            name="Training Run 2024-01",
            document_count=500,
        )

        assert training_run.document_count == 500

    def test_training_task_with_metrics(self):
        """Each metric value passed in is stored on the task unchanged."""
        reported_metrics = {
            "metrics_mAP": 0.935,
            "metrics_precision": 0.92,
            "metrics_recall": 0.88,
        }

        training_run = TrainingTask(
            admin_token="test-token",
            name="Training Run 2024-01",
            status="completed",
            **reported_metrics,
        )

        # Exact equality is intentional: the same literals are read back.
        for metric_name, expected in reported_metrics.items():
            assert getattr(training_run, metric_name) == expected
|
|
|
|
|
|
class TestCSVToClassMapping:
    """Tests for CSV column to class ID mapping."""

    def test_csv_mapping_exists(self):
        """Test that CSV mapping is defined and non-empty."""
        assert CSV_TO_CLASS_MAPPING is not None
        assert len(CSV_TO_CLASS_MAPPING) > 0

    def test_csv_mapping_values(self):
        """Test specific CSV column mappings."""
        expected_class_ids = {
            "InvoiceNumber": 0,
            "InvoiceDate": 1,
            "InvoiceDueDate": 2,
            "OCR": 3,
            "Bankgiro": 4,
            "Plusgiro": 5,
            "Amount": 6,
            "supplier_organisation_number": 7,
            "customer_number": 9,
        }
        for csv_name, class_id in expected_class_ids.items():
            # The message names the column so a mismatch is easy to locate.
            assert CSV_TO_CLASS_MAPPING[csv_name] == class_id, (
                f"unexpected class id for CSV column {csv_name!r}"
            )

    def test_csv_mapping_matches_field_classes(self):
        """Test that CSV mapping is consistent with FIELD_CLASSES."""
        for csv_name, class_id in CSV_TO_CLASS_MAPPING.items():
            # Report the column and id on failure; a bare membership assert
            # would not say which entry is inconsistent.
            assert class_id in FIELD_CLASSES, (
                f"CSV column {csv_name!r} maps to class id {class_id!r}, "
                "which is not in FIELD_CLASSES"
            )
|