WIP
This commit is contained in:
1
tests/shared/fields/__init__.py
Normal file
1
tests/shared/fields/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for shared.fields module."""
|
||||
200
tests/shared/fields/test_field_config.py
Normal file
200
tests/shared/fields/test_field_config.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""
|
||||
Tests for field configuration - Single Source of Truth.
|
||||
|
||||
These tests ensure consistency across all field definitions and prevent
|
||||
accidental changes that could break model inference.
|
||||
|
||||
CRITICAL: These tests verify that field definitions match the trained YOLO model.
|
||||
If these tests fail, it likely means someone modified field IDs incorrectly.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from shared.fields import (
|
||||
FIELD_DEFINITIONS,
|
||||
CLASS_NAMES,
|
||||
FIELD_CLASSES,
|
||||
FIELD_CLASS_IDS,
|
||||
CLASS_TO_FIELD,
|
||||
CSV_TO_CLASS_MAPPING,
|
||||
TRAINING_FIELD_CLASSES,
|
||||
NUM_CLASSES,
|
||||
FieldDefinition,
|
||||
)
|
||||
|
||||
|
||||
class TestFieldDefinitionsIntegrity:
    """Structural sanity checks on ``FIELD_DEFINITIONS``.

    Covers the number of definitions, the set of class IDs, uniqueness of
    IDs and names, and immutability of an individual definition.
    """

    @staticmethod
    def _has_no_repeats(values):
        """Return True when every item of *values* occurs exactly once."""
        items = list(values)
        return len(set(items)) == len(items)

    def test_exactly_10_field_definitions(self):
        """Both the definition list and NUM_CLASSES report 10 classes."""
        expected_count = 10
        assert len(FIELD_DEFINITIONS) == expected_count
        assert NUM_CLASSES == expected_count

    def test_class_ids_are_sequential(self):
        """The class IDs in use are exactly 0 through 9, with no gaps."""
        wanted = set(range(10))
        found = {definition.class_id for definition in FIELD_DEFINITIONS}
        assert found == wanted

    def test_class_ids_are_unique(self):
        """No two definitions share a class ID."""
        assert self._has_no_repeats(
            definition.class_id for definition in FIELD_DEFINITIONS
        )

    def test_class_names_are_unique(self):
        """No two definitions share a class name."""
        assert self._has_no_repeats(
            definition.class_name for definition in FIELD_DEFINITIONS
        )

    def test_field_definition_is_immutable(self):
        """Assigning to an attribute of a definition raises AttributeError."""
        first_definition = FIELD_DEFINITIONS[0]
        with pytest.raises(AttributeError):
            # setattr is the functional spelling of ``obj.class_id = 99``.
            setattr(first_definition, "class_id", 99)
|
||||
|
||||
|
||||
class TestModelCompatibility:
    """Pin the field definitions to the class table of the trained YOLO model.

    Per the original author, these exact values are read from
    runs/train/invoice_fields/weights/best.pt and MUST NOT be changed
    without retraining the model.
    """

    # Expected model.names from best.pt - DO NOT CHANGE
    EXPECTED_MODEL_NAMES = {
        0: "invoice_number",
        1: "invoice_date",
        2: "invoice_due_date",
        3: "ocr_number",
        4: "bankgiro",
        5: "plusgiro",
        6: "amount",
        7: "supplier_org_number",
        8: "customer_number",
        9: "payment_line",
    }

    def test_field_classes_match_model(self):
        """CRITICAL: FIELD_CLASSES equals the pinned id -> name table."""
        pinned = self.EXPECTED_MODEL_NAMES
        assert FIELD_CLASSES == pinned

    def test_class_names_order_matches_model(self):
        """CRITICAL: CLASS_NAMES lists the pinned names in class-ID order."""
        pinned = self.EXPECTED_MODEL_NAMES
        # The pinned table has keys 0..9, so walking 0..len-1 visits them all.
        names_by_id = [pinned[class_id] for class_id in range(len(pinned))]
        assert CLASS_NAMES == names_by_id

    def test_customer_number_is_class_8(self):
        """CRITICAL: customer_number must be class 8 (not 9)."""
        name, class_id = "customer_number", 8
        assert FIELD_CLASS_IDS[name] == class_id
        assert FIELD_CLASSES[class_id] == name

    def test_payment_line_is_class_9(self):
        """CRITICAL: payment_line must be class 9 (not 8)."""
        name, class_id = "payment_line", 9
        assert FIELD_CLASS_IDS[name] == class_id
        assert FIELD_CLASSES[class_id] == name
|
||||
|
||||
|
||||
class TestMappingConsistency:
    """Cross-checks that the lookup tables exported by shared.fields agree.

    Each test compares two of the exported mappings (or a mapping against
    FIELD_DEFINITIONS) so that a change to one table that is not mirrored
    in the others is reported with the offending key in the message.
    """

    def test_field_classes_and_field_class_ids_are_inverses(self):
        """Verify FIELD_CLASSES and FIELD_CLASS_IDS are proper inverses."""
        # id -> name -> id must round-trip ...
        for class_id, class_name in FIELD_CLASSES.items():
            assert FIELD_CLASS_IDS[class_name] == class_id, (
                f"FIELD_CLASS_IDS[{class_name!r}] does not map back to {class_id}"
            )

        # ... and so must name -> id -> name.
        for class_name, class_id in FIELD_CLASS_IDS.items():
            assert FIELD_CLASSES[class_id] == class_name, (
                f"FIELD_CLASSES[{class_id}] does not map back to {class_name!r}"
            )

    def test_class_names_matches_field_classes_values(self):
        """Verify CLASS_NAMES list matches FIELD_CLASSES values in order."""
        # Without the length check a truncated CLASS_NAMES would pass the
        # element-wise loop below.
        assert len(CLASS_NAMES) == len(FIELD_CLASSES)
        for i, class_name in enumerate(CLASS_NAMES):
            assert FIELD_CLASSES[i] == class_name, (
                f"CLASS_NAMES[{i}] is {class_name!r} but FIELD_CLASSES disagrees"
            )

    def test_class_to_field_has_all_classes(self):
        """Verify CLASS_TO_FIELD has mapping for all class names."""
        for class_name in CLASS_NAMES:
            assert class_name in CLASS_TO_FIELD, (
                f"{class_name} is missing from CLASS_TO_FIELD"
            )

    def test_csv_mapping_excludes_derived_fields(self):
        """Verify CSV_TO_CLASS_MAPPING holds exactly the non-derived fields."""
        # payment_line is derived, should not be in CSV mapping
        assert "payment_line" not in CSV_TO_CLASS_MAPPING

        # Check the is_derived flag in both directions: derived fields must
        # be absent, every other field must be present.
        for fd in FIELD_DEFINITIONS:
            if fd.is_derived:
                assert fd.field_name not in CSV_TO_CLASS_MAPPING, (
                    f"derived field {fd.field_name} must not be in CSV mapping"
                )
            else:
                assert fd.field_name in CSV_TO_CLASS_MAPPING, (
                    f"{fd.field_name} is missing from CSV mapping"
                )

    def test_training_field_classes_includes_all(self):
        """Verify TRAINING_FIELD_CLASSES includes all fields including derived."""
        for fd in FIELD_DEFINITIONS:
            assert fd.field_name in TRAINING_FIELD_CLASSES, (
                f"{fd.field_name} is missing from TRAINING_FIELD_CLASSES"
            )
            assert TRAINING_FIELD_CLASSES[fd.field_name] == fd.class_id, (
                f"{fd.field_name} has the wrong class id in TRAINING_FIELD_CLASSES"
            )
|
||||
|
||||
|
||||
class TestSpecificFieldDefinitions:
    """Spot checks on individual definitions to catch common mistakes."""

    # (class_id, class_name) pairs fed to the parametrized test below.
    _ID_NAME_CASES = [
        (0, "invoice_number"),
        (1, "invoice_date"),
        (2, "invoice_due_date"),
        (3, "ocr_number"),
        (4, "bankgiro"),
        (5, "plusgiro"),
        (6, "amount"),
        (7, "supplier_org_number"),
        (8, "customer_number"),
        (9, "payment_line"),
    ]

    @pytest.mark.parametrize(("class_id", "expected_class_name"), _ID_NAME_CASES)
    def test_class_id_to_name_mapping(self, class_id: int, expected_class_name: str):
        """FIELD_CLASSES resolves the given ID to the expected class name."""
        actual_name = FIELD_CLASSES[class_id]
        assert actual_name == expected_class_name

    def test_payment_line_is_derived(self):
        """The payment_line definition carries ``is_derived is True``."""
        candidates = (
            definition
            for definition in FIELD_DEFINITIONS
            if definition.class_name == "payment_line"
        )
        # Take the first match, as the original lookup did.
        payment_line = next(candidates)
        assert payment_line.is_derived is True

    def test_other_fields_are_not_derived(self):
        """Every definition other than payment_line has ``is_derived is False``."""
        for fd in FIELD_DEFINITIONS:
            if fd.class_name == "payment_line":
                continue
            assert fd.is_derived is False, f"{fd.class_name} should not be derived"
|
||||
|
||||
|
||||
class TestBackwardCompatibility:
    """Guards the field names that existing code relies on."""

    def test_csv_to_class_mapping_field_names(self):
        """CSV_TO_CLASS_MAPPING equals the expected CSV-name -> class-ID table."""
        # These are the field names used in CSV files, listed in class-ID
        # order (0..8). payment_line (9) is derived, not in CSV.
        csv_names_in_id_order = [
            "InvoiceNumber",
            "InvoiceDate",
            "InvoiceDueDate",
            "OCR",
            "Bankgiro",
            "Plusgiro",
            "Amount",
            "supplier_organisation_number",
            "customer_number",
        ]
        expected_fields = {
            csv_name: class_id
            for class_id, csv_name in enumerate(csv_names_in_id_order)
        }
        assert CSV_TO_CLASS_MAPPING == expected_fields

    def test_class_to_field_returns_field_names(self):
        """CLASS_TO_FIELD maps a sample of class names to their field names."""
        # Sample checks for key fields
        sampled_pairs = [
            ("invoice_number", "InvoiceNumber"),
            ("invoice_date", "InvoiceDate"),
            ("ocr_number", "OCR"),
            ("customer_number", "customer_number"),
            ("payment_line", "payment_line"),
        ]
        for class_name, field_name in sampled_pairs:
            assert CLASS_TO_FIELD[class_name] == field_name
|
||||
Reference in New Issue
Block a user