Files
invoice-master-poc-v2/tests/shared/fields/test_field_config.py
Yaojia Wang 0990239e9c feat: add field-specific bbox expansion strategies for YOLO training
Implement center-point-based bbox scaling with directional compensation
to capture field labels that typically appear above or to the left of
field values. This improves YOLO training data quality by including
contextual information around field values.

Key changes:
- Add shared.bbox module with ScaleStrategy dataclass and expand_bbox function
- Define field-specific strategies (ocr_number, bankgiro, invoice_date, etc.)
- Support manual_mode for minimal padding (no scaling)
- Integrate expand_bbox into AnnotationGenerator
- Add FIELD_TO_CLASS mapping for field_name to class_name lookup
- Comprehensive tests with 100% coverage (45 tests)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-04 22:56:52 +01:00

216 lines
8.0 KiB
Python

"""
Tests for field configuration - Single Source of Truth.
These tests ensure consistency across all field definitions and prevent
accidental changes that could break model inference.
CRITICAL: These tests verify that field definitions match the trained YOLO model.
If these tests fail, it likely means someone modified field IDs incorrectly.
"""
import pytest
from shared.fields import (
FIELD_DEFINITIONS,
CLASS_NAMES,
FIELD_CLASSES,
FIELD_CLASS_IDS,
CLASS_TO_FIELD,
FIELD_TO_CLASS,
CSV_TO_CLASS_MAPPING,
TRAINING_FIELD_CLASSES,
NUM_CLASSES,
FieldDefinition,
)
class TestFieldDefinitionsIntegrity:
    """Tests to ensure field definitions are complete and consistent."""

    def test_exactly_10_field_definitions(self):
        """Verify we have exactly 10 field classes (matching trained model)."""
        assert len(FIELD_DEFINITIONS) == 10
        assert NUM_CLASSES == 10

    def test_class_ids_are_sequential(self):
        """Verify class IDs are 0-9 without gaps."""
        class_ids = {fd.class_id for fd in FIELD_DEFINITIONS}
        assert class_ids == set(range(10))

    def test_class_ids_are_unique(self):
        """Verify no duplicate class IDs."""
        class_ids = [fd.class_id for fd in FIELD_DEFINITIONS]
        assert len(class_ids) == len(set(class_ids))

    def test_class_names_are_unique(self):
        """Verify no duplicate class names."""
        class_names = [fd.class_name for fd in FIELD_DEFINITIONS]
        assert len(class_names) == len(set(class_names))

    def test_field_names_are_unique(self):
        """Verify no duplicate field names.

        FIELD_TO_CLASS, TRAINING_FIELD_CLASSES and CSV_TO_CLASS_MAPPING are all
        looked up by field name, so a duplicate would make those lookups
        ambiguous.
        """
        field_names = [fd.field_name for fd in FIELD_DEFINITIONS]
        assert len(field_names) == len(set(field_names))

    def test_all_definitions_are_field_definition_instances(self):
        """Verify every entry in FIELD_DEFINITIONS is a FieldDefinition."""
        for fd in FIELD_DEFINITIONS:
            assert isinstance(fd, FieldDefinition), (
                f"{fd!r} is not a FieldDefinition"
            )

    def test_field_definition_is_immutable(self):
        """Verify FieldDefinition is frozen (immutable)."""
        fd = FIELD_DEFINITIONS[0]
        # AttributeError also matches dataclasses.FrozenInstanceError (a
        # subclass), in case FieldDefinition is a frozen dataclass.
        with pytest.raises(AttributeError):
            fd.class_id = 99  # type: ignore
class TestModelCompatibility:
    """Pin the field definitions to the class map of the trained YOLO model.

    The expected names were read from runs/train/invoice_fields/weights/best.pt
    and MUST NOT be changed without retraining the model.
    """

    # model.names from best.pt, keyed by class ID - DO NOT CHANGE
    EXPECTED_MODEL_NAMES = {
        0: "invoice_number",
        1: "invoice_date",
        2: "invoice_due_date",
        3: "ocr_number",
        4: "bankgiro",
        5: "plusgiro",
        6: "amount",
        7: "supplier_org_number",
        8: "customer_number",
        9: "payment_line",
    }

    def test_field_classes_match_model(self):
        """CRITICAL: FIELD_CLASSES must equal the trained model's names exactly."""
        expected = self.EXPECTED_MODEL_NAMES
        assert FIELD_CLASSES == expected

    def test_class_names_order_matches_model(self):
        """CRITICAL: CLASS_NAMES[i] must be the model's name for class ID i."""
        names_by_id = self.EXPECTED_MODEL_NAMES
        expected_order = list(map(names_by_id.__getitem__, range(10)))
        assert CLASS_NAMES == expected_order

    def test_customer_number_is_class_8(self):
        """CRITICAL: customer_number must be class 8 (not 9)."""
        class_id = FIELD_CLASS_IDS["customer_number"]
        assert class_id == 8
        assert FIELD_CLASSES[class_id] == "customer_number"

    def test_payment_line_is_class_9(self):
        """CRITICAL: payment_line must be class 9 (not 8)."""
        class_id = FIELD_CLASS_IDS["payment_line"]
        assert class_id == 9
        assert FIELD_CLASSES[class_id] == "payment_line"
class TestMappingConsistency:
    """Tests to verify all mappings are consistent with each other."""

    def test_field_classes_and_field_class_ids_are_inverses(self):
        """Verify FIELD_CLASSES and FIELD_CLASS_IDS are proper inverses."""
        for class_id, class_name in FIELD_CLASSES.items():
            assert FIELD_CLASS_IDS[class_name] == class_id
        for class_name, class_id in FIELD_CLASS_IDS.items():
            assert FIELD_CLASSES[class_id] == class_name

    def test_class_names_matches_field_classes_values(self):
        """Verify CLASS_NAMES list matches FIELD_CLASSES values in order."""
        # The loop below only walks CLASS_NAMES; without the length check,
        # extra entries in FIELD_CLASSES would go unnoticed.
        assert len(CLASS_NAMES) == len(FIELD_CLASSES)
        for i, class_name in enumerate(CLASS_NAMES):
            assert FIELD_CLASSES[i] == class_name

    def test_class_to_field_has_all_classes(self):
        """Verify CLASS_TO_FIELD has mapping for all class names."""
        for class_name in CLASS_NAMES:
            assert class_name in CLASS_TO_FIELD, (
                f"{class_name} missing from CLASS_TO_FIELD"
            )

    def test_csv_mapping_excludes_derived_fields(self):
        """Verify CSV_TO_CLASS_MAPPING excludes derived fields like payment_line.

        Every derived definition must be absent from the CSV mapping, and every
        non-derived definition must be present with its own class ID.
        """
        # Keep the explicit payment_line check so that losing its is_derived
        # flag is still caught by this test.
        assert "payment_line" not in CSV_TO_CLASS_MAPPING
        for fd in FIELD_DEFINITIONS:
            if fd.is_derived:
                assert fd.field_name not in CSV_TO_CLASS_MAPPING, (
                    f"derived field {fd.field_name} must not be in CSV mapping"
                )
            else:
                assert fd.field_name in CSV_TO_CLASS_MAPPING, (
                    f"{fd.field_name} missing from CSV mapping"
                )
                assert CSV_TO_CLASS_MAPPING[fd.field_name] == fd.class_id

    def test_training_field_classes_includes_all(self):
        """Verify TRAINING_FIELD_CLASSES includes all fields including derived."""
        for fd in FIELD_DEFINITIONS:
            assert fd.field_name in TRAINING_FIELD_CLASSES
            assert TRAINING_FIELD_CLASSES[fd.field_name] == fd.class_id

    def test_field_to_class_is_inverse_of_class_to_field(self):
        """Verify FIELD_TO_CLASS and CLASS_TO_FIELD are proper inverses."""
        for class_name, field_name in CLASS_TO_FIELD.items():
            assert FIELD_TO_CLASS[field_name] == class_name
        for field_name, class_name in FIELD_TO_CLASS.items():
            assert CLASS_TO_FIELD[class_name] == field_name

    def test_field_to_class_has_all_fields(self):
        """Verify FIELD_TO_CLASS has mapping for all field names."""
        for fd in FIELD_DEFINITIONS:
            assert fd.field_name in FIELD_TO_CLASS
            assert FIELD_TO_CLASS[fd.field_name] == fd.class_name
class TestSpecificFieldDefinitions:
    """Tests for specific field definitions to catch common mistakes."""

    @pytest.mark.parametrize(
        "class_id,expected_class_name",
        [
            (0, "invoice_number"),
            (1, "invoice_date"),
            (2, "invoice_due_date"),
            (3, "ocr_number"),
            (4, "bankgiro"),
            (5, "plusgiro"),
            (6, "amount"),
            (7, "supplier_org_number"),
            (8, "customer_number"),
            (9, "payment_line"),
        ],
    )
    def test_class_id_to_name_mapping(self, class_id: int, expected_class_name: str):
        """Verify each class ID maps to the correct class name."""
        assert FIELD_CLASSES[class_id] == expected_class_name

    def test_payment_line_is_derived(self):
        """Verify payment_line is marked as derived."""
        payment_line_def = next(
            (fd for fd in FIELD_DEFINITIONS if fd.class_name == "payment_line"),
            None,
        )
        # The None default turns a missing definition into a readable
        # assertion failure instead of a bare StopIteration escaping the test.
        assert payment_line_def is not None, "payment_line definition is missing"
        assert payment_line_def.is_derived is True

    def test_other_fields_are_not_derived(self):
        """Verify all fields except payment_line are not derived."""
        for fd in FIELD_DEFINITIONS:
            if fd.class_name != "payment_line":
                assert fd.is_derived is False, f"{fd.class_name} should not be derived"
class TestBackwardCompatibility:
    """Guard the names that existing code and CSV files already rely on."""

    def test_csv_to_class_mapping_field_names(self):
        """CSV_TO_CLASS_MAPPING must map CSV column names to class IDs exactly."""
        # Column names as used in the CSV files. payment_line (class 9) is
        # derived and therefore has no CSV column.
        csv_columns = dict(
            InvoiceNumber=0,
            InvoiceDate=1,
            InvoiceDueDate=2,
            OCR=3,
            Bankgiro=4,
            Plusgiro=5,
            Amount=6,
            supplier_organisation_number=7,
            customer_number=8,
        )
        assert CSV_TO_CLASS_MAPPING == csv_columns

    def test_class_to_field_returns_field_names(self):
        """CLASS_TO_FIELD must translate class names back to their field names."""
        # Spot-check a representative subset of class names.
        spot_checks = {
            "invoice_number": "InvoiceNumber",
            "invoice_date": "InvoiceDate",
            "ocr_number": "OCR",
            "customer_number": "customer_number",
            "payment_line": "payment_line",
        }
        for class_name, field_name in spot_checks.items():
            assert CLASS_TO_FIELD[class_name] == field_name