Files
invoice-master-poc-v2/tests/training/yolo/test_annotation_generator.py
Yaojia Wang ad5ed46b4c WIP
2026-02-11 23:40:38 +01:00

344 lines
10 KiB
Python

"""
Tests for AnnotationGenerator and YOLOAnnotation.

Tests verify that YOLO annotations are generated correctly from field
matches: YOLO string formatting, constructor defaults, confidence
filtering, best-match selection, bbox expansion (expected to be uniform
across fields), minimum bbox height enforcement, and payment_line
annotations.
"""
from dataclasses import dataclass
import pytest
from training.yolo.annotation_generator import (
AnnotationGenerator,
YOLOAnnotation,
)
from shared.fields import TRAINING_FIELD_CLASSES, CLASS_NAMES
@dataclass
class MockMatch:
    """Lightweight stand-in for a field Match: only a bbox and a score."""

    # Bounding box in PDF points (72 DPI) as used by these tests;
    # presumably ordered (x0, y0, x1, y1) — confirm against the real Match.
    bbox: tuple[float, float, float, float]
    # Match score; the tests expect scores below min_confidence to be skipped.
    score: float
class TestYOLOAnnotation:
    """Unit tests covering the YOLOAnnotation dataclass."""

    def test_to_string_format(self):
        """to_string() emits the class id and four 6-decimal coords, omitting confidence."""
        annotation = YOLOAnnotation(
            class_id=0,
            x_center=0.5,
            y_center=0.5,
            width=0.1,
            height=0.05,
            confidence=0.9,
        )
        expected = "0 0.500000 0.500000 0.100000 0.050000"
        assert annotation.to_string() == expected

    def test_default_confidence(self):
        """Leaving confidence unset yields a default of 1.0."""
        annotation = YOLOAnnotation(
            class_id=0, x_center=0.5, y_center=0.5, width=0.1, height=0.05
        )
        assert annotation.confidence == 1.0
class TestAnnotationGeneratorInit:
    """Checks the constructor defaults and overrides of AnnotationGenerator."""

    def test_default_values(self):
        """A bare AnnotationGenerator uses min_confidence=0.7 and min_bbox_height_px=30."""
        generator = AnnotationGenerator()
        assert (generator.min_confidence, generator.min_bbox_height_px) == (0.7, 30)

    def test_custom_values(self):
        """Keyword overrides passed to the constructor are stored unchanged."""
        overrides = {"min_confidence": 0.8, "min_bbox_height_px": 40}
        generator = AnnotationGenerator(**overrides)
        for attr_name, expected in overrides.items():
            assert getattr(generator, attr_name) == expected
class TestGenerateFromMatches:
    """Tests for generate_from_matches method."""

    @staticmethod
    def _generate(gen, matches, image_width=1000, image_height=1000, dpi=150):
        """Call gen.generate_from_matches with this class's default page geometry."""
        return gen.generate_from_matches(
            matches=matches,
            image_width=image_width,
            image_height=image_height,
            dpi=dpi,
        )

    def test_generates_annotation_for_valid_match(self):
        """Verify annotation is generated for valid match."""
        gen = AnnotationGenerator(min_confidence=0.5)
        # Mock match in PDF points (72 DPI).
        # At 150 DPI, coords multiply by 150/72 = 2.083.
        matches = {
            "InvoiceNumber": [MockMatch(bbox=(100, 200, 200, 230), score=0.8)]
        }
        annotations = self._generate(gen, matches)
        assert len(annotations) == 1
        ann = annotations[0]
        assert ann.class_id == TRAINING_FIELD_CLASSES["InvoiceNumber"]
        assert ann.confidence == 0.8
        # Normalized values should be in the 0-1 range
        assert 0 <= ann.x_center <= 1
        assert 0 <= ann.y_center <= 1
        assert 0 < ann.width <= 1
        assert 0 < ann.height <= 1

    def test_skips_low_confidence_match(self):
        """Verify low confidence matches are skipped."""
        gen = AnnotationGenerator(min_confidence=0.7)
        matches = {
            "InvoiceNumber": [MockMatch(bbox=(100, 200, 200, 230), score=0.5)]
        }
        annotations = self._generate(gen, matches)
        assert len(annotations) == 0

    def test_skips_unknown_field(self):
        """Verify unknown fields are skipped."""
        gen = AnnotationGenerator(min_confidence=0.5)
        matches = {
            "UnknownField": [MockMatch(bbox=(100, 200, 200, 230), score=0.9)]
        }
        annotations = self._generate(gen, matches)
        assert len(annotations) == 0

    def test_takes_best_match_only(self):
        """Verify only the best match is used per field."""
        gen = AnnotationGenerator(min_confidence=0.5)
        # NOTE(review): the best match is listed first, so this also passes if
        # the generator simply takes matches[0]; confirm whether callers
        # guarantee score-sorted match lists.
        matches = {
            "InvoiceNumber": [
                MockMatch(bbox=(100, 200, 200, 230), score=0.9),  # Best
                MockMatch(bbox=(300, 400, 400, 430), score=0.7),
            ]
        }
        annotations = self._generate(gen, matches)
        assert len(annotations) == 1
        assert annotations[0].confidence == 0.9

    def test_handles_empty_matches(self):
        """Verify empty matches list is handled."""
        gen = AnnotationGenerator()
        matches = {
            "InvoiceNumber": []
        }
        annotations = self._generate(gen, matches)
        assert len(annotations) == 0

    def test_applies_uniform_expansion(self):
        """Verify all fields get the same uniform expansion."""
        gen = AnnotationGenerator(min_confidence=0.5)
        # Same bbox, different fields
        bbox = (100, 200, 200, 230)
        ann_invoice = self._generate(
            gen, {"InvoiceNumber": [MockMatch(bbox=bbox, score=0.9)]}
        )[0]
        ann_bankgiro = self._generate(
            gen, {"Bankgiro": [MockMatch(bbox=bbox, score=0.9)]}
        )[0]
        # Uniform expansion: same bbox -> same geometry; only the class differs.
        assert ann_invoice.class_id == TRAINING_FIELD_CLASSES["InvoiceNumber"]
        assert ann_bankgiro.class_id == TRAINING_FIELD_CLASSES["Bankgiro"]
        assert ann_bankgiro.width == ann_invoice.width
        assert ann_bankgiro.height == ann_invoice.height
        assert ann_bankgiro.x_center == ann_invoice.x_center
        assert ann_bankgiro.y_center == ann_invoice.y_center

    def test_enforces_min_bbox_height(self):
        """Verify minimum bbox height is enforced."""
        min_height_px = 50
        image_height = 1000
        gen = AnnotationGenerator(
            min_confidence=0.5, min_bbox_height_px=min_height_px
        )
        # Very small bbox: 10 px tall at 72 DPI (1:1 scale), below the minimum.
        matches = {
            "InvoiceNumber": [MockMatch(bbox=(100, 200, 200, 210), score=0.9)]
        }
        annotations = self._generate(
            gen,
            matches,
            image_width=1000,
            image_height=image_height,
            dpi=72,  # 1:1 scale
        )
        assert len(annotations) == 1
        # The min-height check is expected to run after bbox expansion, so the
        # final normalized height must be at least
        # min_bbox_height_px / image_height (= 0.05). The tiny tolerance absorbs
        # float rounding in the pixel -> normalized conversion.
        assert annotations[0].height >= min_height_px / image_height - 1e-9
class TestAddPaymentLineAnnotation:
    """Behavior of AnnotationGenerator.add_payment_line_annotation."""

    @staticmethod
    def _add(generator, annotations, bbox, confidence):
        """Invoke add_payment_line_annotation on a 1000x1000 page at 150 DPI."""
        return generator.add_payment_line_annotation(
            annotations=annotations,
            payment_line_bbox=bbox,
            confidence=confidence,
            image_width=1000,
            image_height=1000,
            dpi=150,
        )

    def test_adds_payment_line_annotation(self):
        """A confident payment_line bbox yields exactly one payment_line annotation."""
        generator = AnnotationGenerator(min_confidence=0.5)
        added = self._add(generator, [], (100, 200, 400, 230), 0.9)
        assert len(added) == 1
        payment_ann = added[0]
        assert payment_ann.class_id == TRAINING_FIELD_CLASSES["payment_line"]
        assert payment_ann.confidence == 0.9

    def test_skips_none_bbox(self):
        """A missing (None) bbox adds nothing."""
        generator = AnnotationGenerator(min_confidence=0.5)
        added = self._add(generator, [], None, 0.9)
        assert len(added) == 0

    def test_skips_low_confidence(self):
        """A confidence under min_confidence adds nothing."""
        generator = AnnotationGenerator(min_confidence=0.7)
        added = self._add(generator, [], (100, 200, 400, 230), 0.5)
        assert len(added) == 0

    def test_appends_to_existing_annotations(self):
        """The payment_line annotation lands after any pre-existing annotations."""
        generator = AnnotationGenerator(min_confidence=0.5)
        prior = [YOLOAnnotation(0, 0.5, 0.5, 0.1, 0.1, 0.9)]
        combined = self._add(generator, prior, (100, 200, 400, 230), 0.9)
        assert len(combined) == 2
        assert combined[0].class_id == 0  # pre-existing entry kept in place
        assert combined[1].class_id == TRAINING_FIELD_CLASSES["payment_line"]
class TestMultipleFieldsIntegration:
    """End-to-end check that every directly matched field produces an annotation."""

    def test_generates_annotations_for_all_field_types(self):
        """One match per field type yields one annotation per field, with matching class ids."""
        generator = AnnotationGenerator(min_confidence=0.5)
        # payment_line is left out: it is derived rather than matched directly.
        field_names = (
            "InvoiceNumber",
            "InvoiceDate",
            "InvoiceDueDate",
            "OCR",
            "Bankgiro",
            "Plusgiro",
            "Amount",
            "supplier_organisation_number",
            "customer_number",
        )
        # Offset each field's bbox so the boxes are staggered down and across the page.
        matches = {
            name: [
                MockMatch(
                    bbox=(100 + 50 * idx, 100 + 30 * idx, 200 + 50 * idx, 130 + 30 * idx),
                    score=0.9,
                )
            ]
            for idx, name in enumerate(field_names)
        }
        produced = generator.generate_from_matches(
            matches=matches,
            image_width=2000,
            image_height=2000,
            dpi=150,
        )
        assert len(produced) == len(field_names)
        # Every requested field's class id must appear in the output.
        assert {a.class_id for a in produced} == {
            TRAINING_FIELD_CLASSES[name] for name in field_names
        }