"""
Tests for field configuration - Single Source of Truth.

These tests ensure consistency across all field definitions and prevent
accidental changes that could break model inference.

CRITICAL: These tests verify that field definitions match the trained YOLO model.
If these tests fail, it likely means someone modified field IDs incorrectly.
"""

import pytest

from shared.fields import (
    FIELD_DEFINITIONS,
    CLASS_NAMES,
    FIELD_CLASSES,
    FIELD_CLASS_IDS,
    CLASS_TO_FIELD,
    CSV_TO_CLASS_MAPPING,
    TRAINING_FIELD_CLASSES,
    NUM_CLASSES,
    FieldDefinition,
)


class TestFieldDefinitionsIntegrity:
    """Tests to ensure field definitions are complete and consistent."""

    def test_exactly_10_field_definitions(self):
        """Verify we have exactly 10 field classes (matching trained model)."""
        assert len(FIELD_DEFINITIONS) == 10
        assert NUM_CLASSES == 10

    def test_class_ids_are_sequential(self):
        """Verify class IDs are 0-9 without gaps."""
        class_ids = {fd.class_id for fd in FIELD_DEFINITIONS}
        assert class_ids == set(range(10))

    def test_class_ids_are_unique(self):
        """Verify no duplicate class IDs."""
        class_ids = [fd.class_id for fd in FIELD_DEFINITIONS]
        assert len(class_ids) == len(set(class_ids))

    def test_class_names_are_unique(self):
        """Verify no duplicate class names."""
        class_names = [fd.class_name for fd in FIELD_DEFINITIONS]
        assert len(class_names) == len(set(class_names))

    def test_field_definition_is_immutable(self):
        """Verify FieldDefinition is frozen (immutable)."""
        fd = FIELD_DEFINITIONS[0]
        assert isinstance(fd, FieldDefinition)
        # A frozen dataclass raises FrozenInstanceError, which is a
        # subclass of AttributeError, so this catches it.
        with pytest.raises(AttributeError):
            fd.class_id = 99  # type: ignore


class TestModelCompatibility:
    """Tests to verify field definitions match the trained YOLO model.

    These exact values are read from runs/train/invoice_fields/weights/best.pt
    and MUST NOT be changed without retraining the model.
    """

    # Expected model.names from best.pt - DO NOT CHANGE
    EXPECTED_MODEL_NAMES = {
        0: "invoice_number",
        1: "invoice_date",
        2: "invoice_due_date",
        3: "ocr_number",
        4: "bankgiro",
        5: "plusgiro",
        6: "amount",
        7: "supplier_org_number",
        8: "customer_number",
        9: "payment_line",
    }

    def test_field_classes_match_model(self):
        """CRITICAL: Verify FIELD_CLASSES matches trained model exactly."""
        assert FIELD_CLASSES == self.EXPECTED_MODEL_NAMES

    def test_class_names_order_matches_model(self):
        """CRITICAL: Verify CLASS_NAMES order matches model class IDs."""
        expected_order = [self.EXPECTED_MODEL_NAMES[i] for i in range(10)]
        assert CLASS_NAMES == expected_order

    def test_customer_number_is_class_8(self):
        """CRITICAL: customer_number must be class 8 (not 9)."""
        assert FIELD_CLASS_IDS["customer_number"] == 8
        assert FIELD_CLASSES[8] == "customer_number"

    def test_payment_line_is_class_9(self):
        """CRITICAL: payment_line must be class 9 (not 8)."""
        assert FIELD_CLASS_IDS["payment_line"] == 9
        assert FIELD_CLASSES[9] == "payment_line"


class TestMappingConsistency:
    """Tests to verify all mappings are consistent with each other."""

    def test_field_classes_and_field_class_ids_are_inverses(self):
        """Verify FIELD_CLASSES and FIELD_CLASS_IDS are proper inverses."""
        # Checking both directions rules out many-to-one entries on either side.
        for class_id, class_name in FIELD_CLASSES.items():
            assert FIELD_CLASS_IDS[class_name] == class_id, class_name
        for class_name, class_id in FIELD_CLASS_IDS.items():
            assert FIELD_CLASSES[class_id] == class_name, class_id

    def test_class_names_matches_field_classes_values(self):
        """Verify CLASS_NAMES list matches FIELD_CLASSES values in order."""
        # The per-index loop alone would not notice extra entries in
        # FIELD_CLASSES that CLASS_NAMES lacks, so compare sizes first.
        assert len(CLASS_NAMES) == len(FIELD_CLASSES)
        for i, class_name in enumerate(CLASS_NAMES):
            assert FIELD_CLASSES[i] == class_name, f"mismatch at class ID {i}"

    def test_class_to_field_has_all_classes(self):
        """Verify CLASS_TO_FIELD has mapping for all class names."""
        for class_name in CLASS_NAMES:
            assert class_name in CLASS_TO_FIELD, f"{class_name} missing"

    def test_csv_mapping_excludes_derived_fields(self):
        """Verify CSV_TO_CLASS_MAPPING excludes derived fields like payment_line.

        Derived fields are not present in CSV files, so they must be absent
        from the CSV mapping. Every non-derived field must be present and
        map to its own class ID.
        """
        # payment_line is derived, so it should not be in the CSV mapping.
        assert "payment_line" not in CSV_TO_CLASS_MAPPING

        for fd in FIELD_DEFINITIONS:
            if fd.is_derived:
                assert fd.field_name not in CSV_TO_CLASS_MAPPING, (
                    f"derived field {fd.field_name} must not be in CSV mapping"
                )
            else:
                assert fd.field_name in CSV_TO_CLASS_MAPPING, (
                    f"{fd.field_name} missing from CSV mapping"
                )
                assert CSV_TO_CLASS_MAPPING[fd.field_name] == fd.class_id, (
                    f"{fd.field_name} maps to the wrong class ID"
                )

    def test_training_field_classes_includes_all(self):
        """Verify TRAINING_FIELD_CLASSES includes all fields including derived."""
        for fd in FIELD_DEFINITIONS:
            assert fd.field_name in TRAINING_FIELD_CLASSES, fd.field_name
            assert TRAINING_FIELD_CLASSES[fd.field_name] == fd.class_id


class TestSpecificFieldDefinitions:
    """Tests for specific field definitions to catch common mistakes."""

    @pytest.mark.parametrize(
        "class_id,expected_class_name",
        [
            (0, "invoice_number"),
            (1, "invoice_date"),
            (2, "invoice_due_date"),
            (3, "ocr_number"),
            (4, "bankgiro"),
            (5, "plusgiro"),
            (6, "amount"),
            (7, "supplier_org_number"),
            (8, "customer_number"),
            (9, "payment_line"),
        ],
    )
    def test_class_id_to_name_mapping(self, class_id: int, expected_class_name: str):
        """Verify each class ID maps to the correct class name."""
        assert FIELD_CLASSES[class_id] == expected_class_name

    def test_payment_line_is_derived(self):
        """Verify payment_line is marked as derived."""
        # Use a default so a missing definition fails with a clear message
        # instead of an opaque StopIteration.
        payment_line_def = next(
            (fd for fd in FIELD_DEFINITIONS if fd.class_name == "payment_line"),
            None,
        )
        assert payment_line_def is not None, "payment_line definition missing"
        assert payment_line_def.is_derived is True

    def test_other_fields_are_not_derived(self):
        """Verify all fields except payment_line are not derived."""
        for fd in FIELD_DEFINITIONS:
            if fd.class_name != "payment_line":
                assert fd.is_derived is False, f"{fd.class_name} should not be derived"


class TestBackwardCompatibility:
    """Tests to ensure backward compatibility with existing code."""

    def test_csv_to_class_mapping_field_names(self):
        """Verify CSV_TO_CLASS_MAPPING uses correct field names."""
        # These are the field names used in CSV files
        expected_fields = {
            "InvoiceNumber": 0,
            "InvoiceDate": 1,
            "InvoiceDueDate": 2,
            "OCR": 3,
            "Bankgiro": 4,
            "Plusgiro": 5,
            "Amount": 6,
            "supplier_organisation_number": 7,
            "customer_number": 8,
            # payment_line (9) is derived, not in CSV
        }
        assert CSV_TO_CLASS_MAPPING == expected_fields

    def test_class_to_field_returns_field_names(self):
        """Verify CLASS_TO_FIELD maps class names to field names correctly."""
        # Sample checks for key fields
        assert CLASS_TO_FIELD["invoice_number"] == "InvoiceNumber"
        assert CLASS_TO_FIELD["invoice_date"] == "InvoiceDate"
        assert CLASS_TO_FIELD["ocr_number"] == "OCR"
        assert CLASS_TO_FIELD["customer_number"] == "customer_number"
        assert CLASS_TO_FIELD["payment_line"] == "payment_line"