"""
Tests for AnnotationGenerator with field-specific bbox expansion.

Tests verify that annotations are generated correctly using
field-specific scale strategies.

NOTE(review): TestGenerateFromMatches.test_applies_uniform_expansion expects
identical expansion for different fields; confirm whether "field-specific"
still describes the current generator.
"""

from dataclasses import dataclass

import pytest

from training.yolo.annotation_generator import (
    AnnotationGenerator,
    YOLOAnnotation,
)
from shared.fields import TRAINING_FIELD_CLASSES, CLASS_NAMES


@dataclass
class MockMatch:
    """Mock Match object for testing.

    Models only the two attributes these tests pass into
    ``AnnotationGenerator.generate_from_matches``: ``bbox`` and ``score``.
    """

    # Bounding box in PDF points (72 DPI); presumably (x0, y0, x1, y1) -- the
    # "very small bbox" case uses (100, 200, 200, 210), i.e. 10 pt tall.
    bbox: tuple[float, float, float, float]
    # Match score; compared against min_confidence and copied to the
    # annotation's confidence by the generator.
    score: float


class TestYOLOAnnotation:
    """Tests for YOLOAnnotation dataclass."""

    def test_to_string_format(self):
        """Verify YOLO format string output.

        Coordinates are written with six decimal places, and the confidence
        set on the annotation does not appear in the emitted line.
        """
        ann = YOLOAnnotation(
            class_id=0,
            x_center=0.5,
            y_center=0.5,
            width=0.1,
            height=0.05,
            confidence=0.9,
        )

        result = ann.to_string()

        # Layout: "<class_id> <x_center> <y_center> <width> <height>"
        assert result == "0 0.500000 0.500000 0.100000 0.050000"

    def test_default_confidence(self):
        """Verify confidence defaults to 1.0 when not supplied."""
        ann = YOLOAnnotation(
            class_id=0,
            x_center=0.5,
            y_center=0.5,
            width=0.1,
            height=0.05,
        )

        assert ann.confidence == 1.0


class TestAnnotationGeneratorInit:
    """Tests for AnnotationGenerator initialization."""

    def test_default_values(self):
        """Verify default initialization values."""
        gen = AnnotationGenerator()

        assert gen.min_confidence == 0.7
        assert gen.min_bbox_height_px == 30

    def test_custom_values(self):
        """Verify constructor keyword arguments override the defaults."""
        gen = AnnotationGenerator(
            min_confidence=0.8,
            min_bbox_height_px=40,
        )

        assert gen.min_confidence == 0.8
        assert gen.min_bbox_height_px == 40


class TestGenerateFromMatches:
    """Tests for generate_from_matches method."""

    @staticmethod
    def _generate(gen, matches, *, image_width=1000, image_height=1000, dpi=150):
        """Run ``gen.generate_from_matches``; defaults to a 1000x1000 px page at 150 DPI."""
        return gen.generate_from_matches(
            matches=matches,
            image_width=image_width,
            image_height=image_height,
            dpi=dpi,
        )

    def test_generates_annotation_for_valid_match(self):
        """Verify annotation is generated for valid match."""
        gen = AnnotationGenerator(min_confidence=0.5)

        # Mock match in PDF points (72 DPI).
        # At 150 DPI, coords multiply by 150/72 ~= 2.083.
        matches = {
            "InvoiceNumber": [MockMatch(bbox=(100, 200, 200, 230), score=0.8)]
        }

        annotations = self._generate(gen, matches)

        assert len(annotations) == 1
        ann = annotations[0]
        assert ann.class_id == TRAINING_FIELD_CLASSES["InvoiceNumber"]
        assert ann.confidence == 0.8
        # Normalized values should be in 0-1 range
        assert 0 <= ann.x_center <= 1
        assert 0 <= ann.y_center <= 1
        assert 0 < ann.width <= 1
        assert 0 < ann.height <= 1

    def test_skips_low_confidence_match(self):
        """Verify matches scoring below min_confidence are skipped."""
        gen = AnnotationGenerator(min_confidence=0.7)

        matches = {
            "InvoiceNumber": [MockMatch(bbox=(100, 200, 200, 230), score=0.5)]
        }

        annotations = self._generate(gen, matches)

        assert len(annotations) == 0

    def test_skips_unknown_field(self):
        """Verify fields without a training class are skipped."""
        gen = AnnotationGenerator(min_confidence=0.5)

        matches = {
            "UnknownField": [MockMatch(bbox=(100, 200, 200, 230), score=0.9)]
        }

        annotations = self._generate(gen, matches)

        assert len(annotations) == 0

    def test_takes_best_match_only(self):
        """Verify only the best match is used per field."""
        gen = AnnotationGenerator(min_confidence=0.5)

        matches = {
            "InvoiceNumber": [
                MockMatch(bbox=(100, 200, 200, 230), score=0.9),  # Best
                MockMatch(bbox=(300, 400, 400, 430), score=0.7),
            ]
        }

        annotations = self._generate(gen, matches)

        assert len(annotations) == 1
        assert annotations[0].confidence == 0.9

    def test_handles_empty_matches(self):
        """Verify a field with an empty match list yields no annotation."""
        gen = AnnotationGenerator()

        matches = {
            "InvoiceNumber": []
        }

        annotations = self._generate(gen, matches)

        assert len(annotations) == 0

    def test_applies_uniform_expansion(self):
        """Verify all fields get the same uniform expansion."""
        gen = AnnotationGenerator(min_confidence=0.5)

        # Same bbox, different fields
        bbox = (100, 200, 200, 230)

        ann_invoice = self._generate(
            gen, {"InvoiceNumber": [MockMatch(bbox=bbox, score=0.9)]}
        )[0]
        ann_bankgiro = self._generate(
            gen, {"Bankgiro": [MockMatch(bbox=bbox, score=0.9)]}
        )[0]

        # Each annotation carries its own field's class id ...
        assert ann_invoice.class_id == TRAINING_FIELD_CLASSES["InvoiceNumber"]
        assert ann_bankgiro.class_id == TRAINING_FIELD_CLASSES["Bankgiro"]
        # ... but uniform expansion means the same input bbox produces the
        # same normalized geometry regardless of field.
        assert ann_bankgiro.width == ann_invoice.width
        assert ann_bankgiro.height == ann_invoice.height
        assert ann_bankgiro.x_center == ann_invoice.x_center
        assert ann_bankgiro.y_center == ann_invoice.y_center

    def test_enforces_min_bbox_height(self):
        """Verify minimum bbox height is enforced."""
        gen = AnnotationGenerator(min_confidence=0.5, min_bbox_height_px=50)

        # Very small bbox: 10 pt tall, i.e. 10 px at 72 DPI -- well under 50 px.
        matches = {
            "InvoiceNumber": [MockMatch(bbox=(100, 200, 200, 210), score=0.9)]
        }

        annotations = self._generate(gen, matches, dpi=72)  # 1:1 scale

        assert len(annotations) == 1
        # The minimum is applied after bbox expansion, so the final normalized
        # height must be at least min_bbox_height_px / image_height = 50 / 1000.
        # The tiny tolerance absorbs float rounding in px -> normalized conversion.
        assert annotations[0].height >= 50 / 1000 - 1e-9


class TestAddPaymentLineAnnotation:
    """Tests for add_payment_line_annotation method."""

    @staticmethod
    def _add(gen, annotations, payment_line_bbox, confidence):
        """Call ``gen.add_payment_line_annotation`` on a 1000x1000 px page at 150 DPI."""
        return gen.add_payment_line_annotation(
            annotations=annotations,
            payment_line_bbox=payment_line_bbox,
            confidence=confidence,
            image_width=1000,
            image_height=1000,
            dpi=150,
        )

    def test_adds_payment_line_annotation(self):
        """Verify payment_line annotation is added."""
        gen = AnnotationGenerator(min_confidence=0.5)

        result = self._add(gen, [], (100, 200, 400, 230), 0.9)

        assert len(result) == 1
        ann = result[0]
        assert ann.class_id == TRAINING_FIELD_CLASSES["payment_line"]
        assert ann.confidence == 0.9

    def test_skips_none_bbox(self):
        """Verify a None bbox adds nothing."""
        gen = AnnotationGenerator(min_confidence=0.5)

        result = self._add(gen, [], None, 0.9)

        assert len(result) == 0

    def test_skips_low_confidence(self):
        """Verify confidence below min_confidence adds nothing."""
        gen = AnnotationGenerator(min_confidence=0.7)

        result = self._add(gen, [], (100, 200, 400, 230), 0.5)

        assert len(result) == 0

    def test_appends_to_existing_annotations(self):
        """Verify payment_line is appended after existing annotations."""
        gen = AnnotationGenerator(min_confidence=0.5)
        existing = [
            YOLOAnnotation(
                class_id=0,
                x_center=0.5,
                y_center=0.5,
                width=0.1,
                height=0.1,
                confidence=0.9,
            )
        ]

        result = self._add(gen, existing, (100, 200, 400, 230), 0.9)

        assert len(result) == 2
        assert result[0].class_id == 0  # Original
        assert result[1].class_id == TRAINING_FIELD_CLASSES["payment_line"]


class TestMultipleFieldsIntegration:
    """Integration tests for multiple fields."""

    def test_generates_annotations_for_all_field_types(self):
        """Verify one annotation is generated per field, each with its class id."""
        gen = AnnotationGenerator(min_confidence=0.5)

        # Create matches for each field (except payment_line which is derived)
        field_names = [
            "InvoiceNumber",
            "InvoiceDate",
            "InvoiceDueDate",
            "OCR",
            "Bankgiro",
            "Plusgiro",
            "Amount",
            "supplier_organisation_number",
            "customer_number",
        ]

        # Stagger bboxes (50 pt right, 30 pt down per field) to avoid overlap.
        matches = {
            field_name: [
                MockMatch(
                    bbox=(100 + i * 50, 100 + i * 30, 200 + i * 50, 130 + i * 30),
                    score=0.9,
                )
            ]
            for i, field_name in enumerate(field_names)
        }

        annotations = gen.generate_from_matches(
            matches=matches,
            image_width=2000,
            image_height=2000,
            dpi=150,
        )

        assert len(annotations) == len(field_names)

        # Verify all class_ids are present
        class_ids = {ann.class_id for ann in annotations}
        expected_class_ids = {TRAINING_FIELD_CLASSES[fn] for fn in field_names}
        assert class_ids == expected_class_ids