201 lines
7.3 KiB
Python
201 lines
7.3 KiB
Python
"""
Tests for field configuration - Single Source of Truth.

These tests ensure consistency across all field definitions and prevent
accidental changes that could break model inference.

CRITICAL: These tests verify that field definitions match the trained YOLO model.
If these tests fail, it likely means someone modified field IDs incorrectly.
"""
|
|
|
|
import pytest
|
|
|
|
from shared.fields import (
|
|
FIELD_DEFINITIONS,
|
|
CLASS_NAMES,
|
|
FIELD_CLASSES,
|
|
FIELD_CLASS_IDS,
|
|
CLASS_TO_FIELD,
|
|
CSV_TO_CLASS_MAPPING,
|
|
TRAINING_FIELD_CLASSES,
|
|
NUM_CLASSES,
|
|
FieldDefinition,
|
|
)
|
|
|
|
|
|
class TestFieldDefinitionsIntegrity:
    """Sanity checks on the FIELD_DEFINITIONS table itself.

    Covers the number of entries, the range of class IDs, uniqueness of IDs
    and names, and immutability of an individual definition.
    """

    def test_exactly_10_field_definitions(self):
        """There are 10 definitions and NUM_CLASSES agrees (trained model size)."""
        definition_count = len(FIELD_DEFINITIONS)
        assert definition_count == 10
        assert NUM_CLASSES == 10

    def test_class_ids_are_sequential(self):
        """The class IDs in use are exactly 0 through 9, with no gaps."""
        seen_ids = set()
        for definition in FIELD_DEFINITIONS:
            seen_ids.add(definition.class_id)
        assert seen_ids == set(range(10))

    def test_class_ids_are_unique(self):
        """No two definitions share a class ID."""
        all_ids = [definition.class_id for definition in FIELD_DEFINITIONS]
        distinct_ids = set(all_ids)
        assert len(all_ids) == len(distinct_ids)

    def test_class_names_are_unique(self):
        """No two definitions share a class name."""
        all_names = [definition.class_name for definition in FIELD_DEFINITIONS]
        distinct_names = set(all_names)
        assert len(all_names) == len(distinct_names)

    def test_field_definition_is_immutable(self):
        """Assigning to an attribute of a definition raises AttributeError."""
        first_definition = FIELD_DEFINITIONS[0]
        with pytest.raises(AttributeError):
            first_definition.class_id = 99  # type: ignore
|
|
|
|
|
|
class TestModelCompatibility:
    """Pin the field definitions to the class table of the trained YOLO model.

    The expected values below are read from
    runs/train/invoice_fields/weights/best.pt and MUST NOT be changed
    without retraining the model.
    """

    # Expected model.names from best.pt - DO NOT CHANGE.
    # The position of each name in the tuple is its class ID (0-9).
    EXPECTED_MODEL_NAMES = dict(
        enumerate(
            (
                "invoice_number",
                "invoice_date",
                "invoice_due_date",
                "ocr_number",
                "bankgiro",
                "plusgiro",
                "amount",
                "supplier_org_number",
                "customer_number",
                "payment_line",
            )
        )
    )

    def test_field_classes_match_model(self):
        """CRITICAL: FIELD_CLASSES equals the trained model's class table."""
        model_names = self.EXPECTED_MODEL_NAMES
        assert FIELD_CLASSES == model_names

    def test_class_names_order_matches_model(self):
        """CRITICAL: CLASS_NAMES lists the names in ascending class-ID order."""
        model_names = self.EXPECTED_MODEL_NAMES
        # Keys are exactly 0..9, so walking them sorted yields IDs 0 through 9.
        expected_order = [model_names[class_id] for class_id in sorted(model_names)]
        assert CLASS_NAMES == expected_order

    def test_customer_number_is_class_8(self):
        """CRITICAL: customer_number must be class 8 (not 9)."""
        customer_id = FIELD_CLASS_IDS["customer_number"]
        assert customer_id == 8
        assert FIELD_CLASSES[8] == "customer_number"

    def test_payment_line_is_class_9(self):
        """CRITICAL: payment_line must be class 9 (not 8)."""
        payment_line_id = FIELD_CLASS_IDS["payment_line"]
        assert payment_line_id == 9
        assert FIELD_CLASSES[9] == "payment_line"
|
|
|
|
|
|
class TestMappingConsistency:
    """Cross-checks between the mappings imported from ``shared.fields``.

    Each test compares two of the exported tables against each other (or
    against FIELD_DEFINITIONS) so that a change to one of them that is not
    mirrored in the others is caught.
    """

    def test_field_classes_and_field_class_ids_are_inverses(self):
        """Verify FIELD_CLASSES and FIELD_CLASS_IDS are proper inverses."""
        # Checking both directions guarantees neither mapping has an entry
        # the other one lacks.
        for class_id, class_name in FIELD_CLASSES.items():
            assert FIELD_CLASS_IDS[class_name] == class_id, (
                f"FIELD_CLASS_IDS[{class_name!r}] should be {class_id}"
            )

        for class_name, class_id in FIELD_CLASS_IDS.items():
            assert FIELD_CLASSES[class_id] == class_name, (
                f"FIELD_CLASSES[{class_id}] should be {class_name!r}"
            )

    def test_class_names_matches_field_classes_values(self):
        """Verify CLASS_NAMES list matches FIELD_CLASSES values in order."""
        # Without the length check, extra entries in FIELD_CLASSES beyond
        # len(CLASS_NAMES) would go unnoticed by the loop below.
        assert len(CLASS_NAMES) == len(FIELD_CLASSES), (
            "CLASS_NAMES and FIELD_CLASSES must have the same number of entries"
        )
        for i, class_name in enumerate(CLASS_NAMES):
            assert FIELD_CLASSES[i] == class_name, (
                f"FIELD_CLASSES[{i}] should be {class_name!r}"
            )

    def test_class_to_field_has_all_classes(self):
        """Verify CLASS_TO_FIELD has mapping for all class names."""
        # Collect every missing name so a failure reports all of them at once.
        missing = [name for name in CLASS_NAMES if name not in CLASS_TO_FIELD]
        assert not missing, f"CLASS_TO_FIELD is missing class names: {missing}"

    def test_csv_mapping_excludes_derived_fields(self):
        """Verify CSV_TO_CLASS_MAPPING excludes derived fields like payment_line."""
        # payment_line is derived, should not be in CSV mapping
        assert "payment_line" not in CSV_TO_CLASS_MAPPING

        for fd in FIELD_DEFINITIONS:
            if fd.is_derived:
                # Any field flagged as derived must be absent, not only the
                # one named explicitly above.
                assert fd.field_name not in CSV_TO_CLASS_MAPPING, (
                    f"derived field {fd.field_name!r} must not be in CSV mapping"
                )
            else:
                # All non-derived fields should be in CSV mapping
                assert fd.field_name in CSV_TO_CLASS_MAPPING, (
                    f"non-derived field {fd.field_name!r} missing from CSV mapping"
                )

    def test_training_field_classes_includes_all(self):
        """Verify TRAINING_FIELD_CLASSES includes all fields including derived."""
        for fd in FIELD_DEFINITIONS:
            assert fd.field_name in TRAINING_FIELD_CLASSES, (
                f"{fd.field_name!r} missing from TRAINING_FIELD_CLASSES"
            )
            assert TRAINING_FIELD_CLASSES[fd.field_name] == fd.class_id, (
                f"TRAINING_FIELD_CLASSES[{fd.field_name!r}] should be {fd.class_id}"
            )
|
|
|
|
|
|
class TestSpecificFieldDefinitions:
    """Tests for specific field definitions to catch common mistakes."""

    @pytest.mark.parametrize(
        "class_id,expected_class_name",
        [
            (0, "invoice_number"),
            (1, "invoice_date"),
            (2, "invoice_due_date"),
            (3, "ocr_number"),
            (4, "bankgiro"),
            (5, "plusgiro"),
            (6, "amount"),
            (7, "supplier_org_number"),
            (8, "customer_number"),
            (9, "payment_line"),
        ],
    )
    def test_class_id_to_name_mapping(self, class_id: int, expected_class_name: str):
        """Verify each class ID maps to the correct class name."""
        assert FIELD_CLASSES[class_id] == expected_class_name

    def test_payment_line_is_derived(self):
        """Verify payment_line is marked as derived."""
        # Collect matches instead of calling next() on a generator: a missing
        # definition then fails with a readable assertion message rather than
        # a bare StopIteration.
        matches = [fd for fd in FIELD_DEFINITIONS if fd.class_name == "payment_line"]
        assert len(matches) == 1, (
            f"expected exactly one payment_line definition, found {len(matches)}"
        )
        assert matches[0].is_derived is True

    def test_other_fields_are_not_derived(self):
        """Verify all fields except payment_line are not derived."""
        for fd in FIELD_DEFINITIONS:
            if fd.class_name != "payment_line":
                assert fd.is_derived is False, f"{fd.class_name} should not be derived"
|
|
|
|
|
|
class TestBackwardCompatibility:
    """Guard the field-name mappings that existing code relies on."""

    def test_csv_to_class_mapping_field_names(self):
        """CSV_TO_CLASS_MAPPING maps the CSV field names to their class IDs."""
        # Field names as used in CSV files, listed in class-ID order (0-8).
        # payment_line (9) is derived, not in CSV.
        csv_field_names = (
            "InvoiceNumber",
            "InvoiceDate",
            "InvoiceDueDate",
            "OCR",
            "Bankgiro",
            "Plusgiro",
            "Amount",
            "supplier_organisation_number",
            "customer_number",
        )
        expected_fields = {
            field_name: class_id for class_id, field_name in enumerate(csv_field_names)
        }
        assert CSV_TO_CLASS_MAPPING == expected_fields

    def test_class_to_field_returns_field_names(self):
        """CLASS_TO_FIELD translates class names to field names (spot checks)."""
        # Sample (class name, field name) pairs for key fields.
        spot_checks = (
            ("invoice_number", "InvoiceNumber"),
            ("invoice_date", "InvoiceDate"),
            ("ocr_number", "OCR"),
            ("customer_number", "customer_number"),
            ("payment_line", "payment_line"),
        )
        for class_name, field_name in spot_checks:
            assert CLASS_TO_FIELD[class_name] == field_name
|