""" Tests for Admin Models v2 - Batch Upload and Training Links. Tests for new SQLModel classes: BatchUpload, BatchUploadFile, TrainingDocumentLink, AnnotationHistory. """ import pytest from datetime import datetime from uuid import UUID, uuid4 from inference.data.admin_models import ( BatchUpload, BatchUploadFile, TrainingDocumentLink, AnnotationHistory, AdminDocument, AdminAnnotation, TrainingTask, ) from shared.fields import FIELD_CLASSES, CSV_TO_CLASS_MAPPING class TestBatchUpload: """Tests for BatchUpload model.""" def test_batch_upload_creation(self): """Test basic batch upload creation.""" batch = BatchUpload( admin_token="test-token", filename="invoices.zip", file_size=1024000, upload_source="ui", ) assert batch.batch_id is not None assert isinstance(batch.batch_id, UUID) assert batch.admin_token == "test-token" assert batch.filename == "invoices.zip" assert batch.file_size == 1024000 assert batch.upload_source == "ui" assert batch.status == "processing" assert batch.total_files == 0 assert batch.processed_files == 0 assert batch.successful_files == 0 assert batch.failed_files == 0 assert batch.error_message is None assert batch.completed_at is None def test_batch_upload_api_source(self): """Test batch upload with API source.""" batch = BatchUpload( admin_token="api-token", filename="batch.zip", file_size=2048000, upload_source="api", ) assert batch.upload_source == "api" def test_batch_upload_with_progress(self): """Test batch upload with progress tracking.""" batch = BatchUpload( admin_token="test-token", filename="large_batch.zip", file_size=10240000, total_files=100, processed_files=50, successful_files=48, failed_files=2, status="processing", ) assert batch.total_files == 100 assert batch.processed_files == 50 assert batch.successful_files == 48 assert batch.failed_files == 2 def test_batch_upload_completed(self): """Test completed batch upload.""" now = datetime.utcnow() batch = BatchUpload( admin_token="test-token", filename="batch.zip", file_size=1024000, status="completed", total_files=10, processed_files=10, successful_files=10, failed_files=0, completed_at=now, ) assert batch.status == "completed" assert batch.completed_at == now def test_batch_upload_failed(self): """Test failed batch upload.""" batch = BatchUpload( admin_token="test-token", filename="bad.zip", file_size=1024, status="failed", error_message="Invalid ZIP file format", ) assert batch.status == "failed" assert batch.error_message == "Invalid ZIP file format" def test_batch_upload_partial(self): """Test partial batch upload with some failures.""" batch = BatchUpload( admin_token="test-token", filename="mixed.zip", file_size=5120000, status="partial", total_files=20, processed_files=20, successful_files=15, failed_files=5, ) assert batch.status == "partial" assert batch.failed_files == 5 class TestBatchUploadFile: """Tests for BatchUploadFile model.""" def test_batch_upload_file_creation(self): """Test basic file record creation.""" batch_id = uuid4() file_record = BatchUploadFile( batch_id=batch_id, filename="INV001.pdf", ) assert file_record.file_id is not None assert isinstance(file_record.file_id, UUID) assert file_record.batch_id == batch_id assert file_record.filename == "INV001.pdf" assert file_record.status == "pending" assert file_record.document_id is None assert file_record.error_message is None assert file_record.csv_row_data is None assert file_record.processed_at is None def test_batch_upload_file_with_document(self): """Test file record linked to document.""" batch_id = uuid4() document_id = uuid4() file_record = BatchUploadFile( batch_id=batch_id, document_id=document_id, filename="INV002.pdf", status="completed", ) assert file_record.document_id == document_id assert file_record.status == "completed" def test_batch_upload_file_with_csv_data(self): """Test file record with CSV row data.""" batch_id = uuid4() csv_data = { "DocumentId": "INV003", "InvoiceNumber": "F2024-003", "Amount": "1500.00", "OCR": "7350012345678", } file_record = BatchUploadFile( batch_id=batch_id, filename="INV003.pdf", csv_row_data=csv_data, ) assert file_record.csv_row_data == csv_data assert file_record.csv_row_data["InvoiceNumber"] == "F2024-003" def test_batch_upload_file_failed(self): """Test failed file record.""" batch_id = uuid4() file_record = BatchUploadFile( batch_id=batch_id, filename="corrupted.pdf", status="failed", error_message="Corrupted PDF file", ) assert file_record.status == "failed" assert file_record.error_message == "Corrupted PDF file" def test_batch_upload_file_skipped(self): """Test skipped file record.""" batch_id = uuid4() file_record = BatchUploadFile( batch_id=batch_id, filename="not_a_pdf.txt", status="skipped", error_message="Not a PDF file", ) assert file_record.status == "skipped" class TestTrainingDocumentLink: """Tests for TrainingDocumentLink model.""" def test_training_document_link_creation(self): """Test basic link creation.""" task_id = uuid4() document_id = uuid4() link = TrainingDocumentLink( task_id=task_id, document_id=document_id, ) assert link.link_id is not None assert isinstance(link.link_id, UUID) assert link.task_id == task_id assert link.document_id == document_id assert link.annotation_snapshot is None def test_training_document_link_with_snapshot(self): """Test link with annotation snapshot.""" task_id = uuid4() document_id = uuid4() snapshot = { "annotations": [ { "class_id": 0, "class_name": "invoice_number", "text_value": "F2024-001", "x_center": 0.5, "y_center": 0.3, }, { "class_id": 6, "class_name": "amount", "text_value": "1500.00", "x_center": 0.7, "y_center": 0.6, }, ], "total_count": 2, "snapshot_time": "2024-01-20T15:00:00", } link = TrainingDocumentLink( task_id=task_id, document_id=document_id, annotation_snapshot=snapshot, ) assert link.annotation_snapshot == snapshot assert len(link.annotation_snapshot["annotations"]) == 2 class TestAnnotationHistory: """Tests for AnnotationHistory model.""" def test_annotation_history_created(self): """Test history record for creation.""" annotation_id = uuid4() new_value = { "class_id": 0, "class_name": "invoice_number", "text_value": "F2024-001", "bbox_x": 100, "bbox_y": 200, "bbox_width": 150, "bbox_height": 30, "source": "manual", } history = AnnotationHistory( annotation_id=annotation_id, action="created", new_value=new_value, changed_by="admin-token-123", ) assert history.history_id is not None assert history.annotation_id == annotation_id assert history.action == "created" assert history.previous_value is None assert history.new_value == new_value assert history.changed_by == "admin-token-123" def test_annotation_history_updated(self): """Test history record for update.""" annotation_id = uuid4() previous_value = { "text_value": "F2024-001", "bbox_x": 100, } new_value = { "text_value": "F2024-001-A", "bbox_x": 110, } history = AnnotationHistory( annotation_id=annotation_id, action="updated", previous_value=previous_value, new_value=new_value, changed_by="admin-token-123", change_reason="Corrected OCR error", ) assert history.action == "updated" assert history.previous_value == previous_value assert history.new_value == new_value assert history.change_reason == "Corrected OCR error" def test_annotation_history_override(self): """Test history record for override.""" annotation_id = uuid4() previous_value = { "text_value": "F2024-001", "source": "auto", "confidence": 0.85, } new_value = { "text_value": "F2024-001-CORRECTED", "source": "manual", "confidence": None, } history = AnnotationHistory( annotation_id=annotation_id, action="override", previous_value=previous_value, new_value=new_value, changed_by="admin-token-123", change_reason="Manual correction of auto-label", ) assert history.action == "override" def test_annotation_history_deleted(self): """Test history record for deletion.""" annotation_id = uuid4() previous_value = { "class_id": 6, "class_name": "amount", "text_value": "1500.00", } history = AnnotationHistory( annotation_id=annotation_id, action="deleted", previous_value=previous_value, changed_by="admin-token-123", change_reason="Incorrect annotation", ) assert history.action == "deleted" assert history.new_value is None class TestAdminDocumentExtensions: """Tests for AdminDocument extensions.""" def test_document_with_upload_source(self): """Test document with upload source field.""" doc = AdminDocument( admin_token="test-token", filename="invoice.pdf", file_size=1024, content_type="application/pdf", file_path="/tmp/invoice.pdf", upload_source="api", ) assert doc.upload_source == "api" def test_document_with_batch_id(self): """Test document linked to batch upload.""" batch_id = uuid4() doc = AdminDocument( admin_token="test-token", filename="invoice.pdf", file_size=1024, content_type="application/pdf", file_path="/tmp/invoice.pdf", batch_id=batch_id, ) assert doc.batch_id == batch_id def test_document_with_csv_field_values(self): """Test document with CSV field values.""" csv_values = { "InvoiceNumber": "F2024-001", "Amount": "1500.00", "OCR": "7350012345678", } doc = AdminDocument( admin_token="test-token", filename="invoice.pdf", file_size=1024, content_type="application/pdf", file_path="/tmp/invoice.pdf", csv_field_values=csv_values, ) assert doc.csv_field_values == csv_values def test_document_with_annotation_lock(self): """Test document with annotation lock.""" lock_until = datetime.utcnow() doc = AdminDocument( admin_token="test-token", filename="invoice.pdf", file_size=1024, content_type="application/pdf", file_path="/tmp/invoice.pdf", annotation_lock_until=lock_until, ) assert doc.annotation_lock_until == lock_until class TestAdminAnnotationExtensions: """Tests for AdminAnnotation extensions.""" def test_annotation_with_verification(self): """Test annotation with verification fields.""" now = datetime.utcnow() ann = AdminAnnotation( document_id=uuid4(), class_id=0, class_name="invoice_number", x_center=0.5, y_center=0.3, width=0.2, height=0.05, bbox_x=100, bbox_y=200, bbox_width=150, bbox_height=30, is_verified=True, verified_at=now, verified_by="admin-token-123", ) assert ann.is_verified is True assert ann.verified_at == now assert ann.verified_by == "admin-token-123" def test_annotation_with_override_info(self): """Test annotation with override information.""" original_id = uuid4() ann = AdminAnnotation( document_id=uuid4(), class_id=0, class_name="invoice_number", x_center=0.5, y_center=0.3, width=0.2, height=0.05, bbox_x=100, bbox_y=200, bbox_width=150, bbox_height=30, source="manual", override_source="auto", original_annotation_id=original_id, ) assert ann.override_source == "auto" assert ann.original_annotation_id == original_id class TestTrainingTaskExtensions: """Tests for TrainingTask extensions.""" def test_training_task_with_document_count(self): """Test training task with document count.""" task = TrainingTask( admin_token="test-token", name="Training Run 2024-01", document_count=500, ) assert task.document_count == 500 def test_training_task_with_metrics(self): """Test training task with extracted metrics.""" task = TrainingTask( admin_token="test-token", name="Training Run 2024-01", status="completed", metrics_mAP=0.935, metrics_precision=0.92, metrics_recall=0.88, ) assert task.metrics_mAP == 0.935 assert task.metrics_precision == 0.92 assert task.metrics_recall == 0.88 class TestCSVToClassMapping: """Tests for CSV column to class ID mapping.""" def test_csv_mapping_exists(self): """Test that CSV mapping is defined.""" assert CSV_TO_CLASS_MAPPING is not None assert len(CSV_TO_CLASS_MAPPING) > 0 def test_csv_mapping_values(self): """Test specific CSV column mappings. Note: customer_number is class 8 (verified from trained model best.pt). """ assert CSV_TO_CLASS_MAPPING["InvoiceNumber"] == 0 assert CSV_TO_CLASS_MAPPING["InvoiceDate"] == 1 assert CSV_TO_CLASS_MAPPING["InvoiceDueDate"] == 2 assert CSV_TO_CLASS_MAPPING["OCR"] == 3 assert CSV_TO_CLASS_MAPPING["Bankgiro"] == 4 assert CSV_TO_CLASS_MAPPING["Plusgiro"] == 5 assert CSV_TO_CLASS_MAPPING["Amount"] == 6 assert CSV_TO_CLASS_MAPPING["supplier_organisation_number"] == 7 assert CSV_TO_CLASS_MAPPING["customer_number"] == 8 # Fixed: was 9, model uses 8 def test_csv_mapping_matches_field_classes(self): """Test that CSV mapping is consistent with FIELD_CLASSES.""" for csv_name, class_id in CSV_TO_CLASS_MAPPING.items(): assert class_id in FIELD_CLASSES