This commit is contained in:
Yaojia Wang
2026-02-01 00:08:40 +01:00
parent 33ada0350d
commit a516de4320
90 changed files with 11642 additions and 398 deletions

View File

@@ -0,0 +1,200 @@
"""
Tests for field configuration - Single Source of Truth.
These tests ensure consistency across all field definitions and prevent
accidental changes that could break model inference.
CRITICAL: These tests verify that field definitions match the trained YOLO model.
If these tests fail, it likely means someone modified field IDs incorrectly.
"""
import pytest
from shared.fields import (
FIELD_DEFINITIONS,
CLASS_NAMES,
FIELD_CLASSES,
FIELD_CLASS_IDS,
CLASS_TO_FIELD,
CSV_TO_CLASS_MAPPING,
TRAINING_FIELD_CLASSES,
NUM_CLASSES,
FieldDefinition,
)
class TestFieldDefinitionsIntegrity:
    """Integrity checks: the field definition table is complete and consistent."""

    def test_exactly_10_field_definitions(self):
        """There must be exactly 10 field classes (matching the trained model)."""
        definition_count = len(FIELD_DEFINITIONS)
        assert definition_count == 10
        assert NUM_CLASSES == 10

    def test_class_ids_are_sequential(self):
        """The class IDs, taken as a set, must be exactly 0-9 (no gaps)."""
        expected_ids = set(range(10))
        seen_ids = set()
        for definition in FIELD_DEFINITIONS:
            seen_ids.add(definition.class_id)
        assert seen_ids == expected_ids

    def test_class_ids_are_unique(self):
        """No two definitions may share a class ID."""
        ids = [definition.class_id for definition in FIELD_DEFINITIONS]
        # Deduplicating must not shrink the collection.
        distinct_ids = set(ids)
        assert len(ids) == len(distinct_ids)

    def test_class_names_are_unique(self):
        """No two definitions may share a class name."""
        names = [definition.class_name for definition in FIELD_DEFINITIONS]
        # Deduplicating must not shrink the collection.
        distinct_names = set(names)
        assert len(names) == len(distinct_names)

    def test_field_definition_is_immutable(self):
        """Assigning to an attribute of a definition must raise AttributeError."""
        first_definition = FIELD_DEFINITIONS[0]
        with pytest.raises(AttributeError):
            # setattr() goes through the same __setattr__ hook as a plain
            # attribute assignment would.
            setattr(first_definition, "class_id", 99)
class TestModelCompatibility:
    """Guard the class-ID layout expected by the trained YOLO model.

    According to the original note, these exact values are read from
    runs/train/invoice_fields/weights/best.pt and MUST NOT be changed
    without retraining the model.
    """

    # Expected model.names from best.pt - DO NOT CHANGE
    EXPECTED_MODEL_NAMES = {
        0: "invoice_number",
        1: "invoice_date",
        2: "invoice_due_date",
        3: "ocr_number",
        4: "bankgiro",
        5: "plusgiro",
        6: "amount",
        7: "supplier_org_number",
        8: "customer_number",
        9: "payment_line",
    }

    def test_field_classes_match_model(self):
        """CRITICAL: FIELD_CLASSES must equal the expected id -> name table."""
        model_names = self.EXPECTED_MODEL_NAMES
        assert FIELD_CLASSES == model_names

    def test_class_names_order_matches_model(self):
        """CRITICAL: CLASS_NAMES must list the names in class-ID order."""
        names_by_id = []
        for class_id in range(10):
            names_by_id.append(self.EXPECTED_MODEL_NAMES[class_id])
        assert CLASS_NAMES == names_by_id

    def test_customer_number_is_class_8(self):
        """CRITICAL: customer_number must be class 8 (not 9)."""
        name, class_id = "customer_number", 8
        # Check both lookup directions.
        assert FIELD_CLASS_IDS[name] == class_id
        assert FIELD_CLASSES[class_id] == name

    def test_payment_line_is_class_9(self):
        """CRITICAL: payment_line must be class 9 (not 8)."""
        name, class_id = "payment_line", 9
        # Check both lookup directions.
        assert FIELD_CLASS_IDS[name] == class_id
        assert FIELD_CLASSES[class_id] == name
class TestMappingConsistency:
    """Cross-checks between the mapping tables imported from shared.fields.

    Every assertion inside a loop carries a message naming the offending
    class or field, so a failure points at the broken entry directly.
    """

    def test_field_classes_and_field_class_ids_are_inverses(self):
        """Verify FIELD_CLASSES and FIELD_CLASS_IDS are proper inverses."""
        # id -> name -> id must round-trip.
        for class_id, class_name in FIELD_CLASSES.items():
            assert FIELD_CLASS_IDS[class_name] == class_id, (
                f"FIELD_CLASS_IDS[{class_name!r}] does not map back to {class_id}"
            )
        # name -> id -> name must round-trip as well.
        for class_name, class_id in FIELD_CLASS_IDS.items():
            assert FIELD_CLASSES[class_id] == class_name, (
                f"FIELD_CLASSES[{class_id}] does not map back to {class_name!r}"
            )

    def test_class_names_matches_field_classes_values(self):
        """Verify CLASS_NAMES list matches FIELD_CLASSES values in order."""
        # Without the size check, extra FIELD_CLASSES entries beyond the end
        # of CLASS_NAMES would go unnoticed by the loop below.
        assert len(CLASS_NAMES) == len(FIELD_CLASSES)
        for i, class_name in enumerate(CLASS_NAMES):
            assert FIELD_CLASSES[i] == class_name, (
                f"CLASS_NAMES[{i}] is {class_name!r} but FIELD_CLASSES[{i}] differs"
            )

    def test_class_to_field_has_all_classes(self):
        """Verify CLASS_TO_FIELD has mapping for all class names."""
        for class_name in CLASS_NAMES:
            assert class_name in CLASS_TO_FIELD, (
                f"{class_name} is missing from CLASS_TO_FIELD"
            )

    def test_csv_mapping_excludes_derived_fields(self):
        """Verify CSV_TO_CLASS_MAPPING holds exactly the non-derived fields.

        Each definition is checked according to its own ``is_derived`` flag:
        derived fields must be absent from the CSV mapping and all other
        fields must be present.
        """
        # Explicit guard for the known derived field, kept from the original
        # test so it holds regardless of how the definition is flagged.
        assert "payment_line" not in CSV_TO_CLASS_MAPPING
        for fd in FIELD_DEFINITIONS:
            if fd.is_derived:
                assert fd.field_name not in CSV_TO_CLASS_MAPPING, (
                    f"derived field {fd.field_name} must not be in CSV mapping"
                )
            else:
                assert fd.field_name in CSV_TO_CLASS_MAPPING, (
                    f"{fd.field_name} is missing from CSV mapping"
                )

    def test_training_field_classes_includes_all(self):
        """Verify TRAINING_FIELD_CLASSES includes all fields including derived."""
        for fd in FIELD_DEFINITIONS:
            assert fd.field_name in TRAINING_FIELD_CLASSES, (
                f"{fd.field_name} is missing from TRAINING_FIELD_CLASSES"
            )
            assert TRAINING_FIELD_CLASSES[fd.field_name] == fd.class_id, (
                f"TRAINING_FIELD_CLASSES[{fd.field_name!r}] should be {fd.class_id}"
            )
class TestSpecificFieldDefinitions:
    """Spot checks on individual field definitions to catch common mistakes."""

    @pytest.mark.parametrize(
        "class_id,expected_class_name",
        [
            (0, "invoice_number"),
            (1, "invoice_date"),
            (2, "invoice_due_date"),
            (3, "ocr_number"),
            (4, "bankgiro"),
            (5, "plusgiro"),
            (6, "amount"),
            (7, "supplier_org_number"),
            (8, "customer_number"),
            (9, "payment_line"),
        ],
    )
    def test_class_id_to_name_mapping(self, class_id: int, expected_class_name: str):
        """Each class ID must resolve to its expected class name."""
        actual_class_name = FIELD_CLASSES[class_id]
        assert actual_class_name == expected_class_name

    def test_payment_line_is_derived(self):
        """The payment_line definition must be flagged as derived."""
        # Take the first definition named payment_line; next() raises
        # StopIteration if no such definition exists.
        payment_line_def = next(
            filter(
                lambda definition: definition.class_name == "payment_line",
                FIELD_DEFINITIONS,
            )
        )
        assert payment_line_def.is_derived is True

    def test_other_fields_are_not_derived(self):
        """Every definition other than payment_line must not be derived."""
        for fd in FIELD_DEFINITIONS:
            if fd.class_name == "payment_line":
                # The derived field is covered by its own test.
                continue
            assert fd.is_derived is False, f"{fd.class_name} should not be derived"
class TestBackwardCompatibility:
    """Pin the mappings that existing code relies on."""

    def test_csv_to_class_mapping_field_names(self):
        """CSV_TO_CLASS_MAPPING must equal the expected field-name -> ID table."""
        # These are the field names used in CSV files.
        # payment_line (9) is derived, not in CSV.
        expected_fields = dict(
            [
                ("InvoiceNumber", 0),
                ("InvoiceDate", 1),
                ("InvoiceDueDate", 2),
                ("OCR", 3),
                ("Bankgiro", 4),
                ("Plusgiro", 5),
                ("Amount", 6),
                ("supplier_organisation_number", 7),
                ("customer_number", 8),
            ]
        )
        assert CSV_TO_CLASS_MAPPING == expected_fields

    def test_class_to_field_returns_field_names(self):
        """CLASS_TO_FIELD must map class names to the expected field names."""
        # Sample checks for key fields, evaluated in this order.
        sampled_pairs = (
            ("invoice_number", "InvoiceNumber"),
            ("invoice_date", "InvoiceDate"),
            ("ocr_number", "OCR"),
            ("customer_number", "customer_number"),
            ("payment_line", "payment_line"),
        )
        for class_name, field_name in sampled_pairs:
            assert CLASS_TO_FIELD[class_name] == field_name