"""
Tests for AnnotationGenerator with field-specific bbox expansion.

Tests verify that annotations are generated correctly using
field-specific scale strategies.
"""

from dataclasses import dataclass

import pytest

from training.yolo.annotation_generator import (
    AnnotationGenerator,
    YOLOAnnotation,
)
from shared.fields import TRAINING_FIELD_CLASSES, CLASS_NAMES


@dataclass
class MockMatch:
    """Minimal stand-in for a Match object.

    Carries only the two attributes these tests supply to the generator:
    a bounding box and a match score.
    """

    # Bounding box as (x0, y0, x1, y1); the tests pass PDF-point coordinates.
    bbox: tuple[float, float, float, float]
    # Match score compared against the generator's ``min_confidence``.
    score: float


class TestYOLOAnnotation:
    """Tests for YOLOAnnotation dataclass."""

    def test_to_string_format(self):
        """Verify YOLO format string output."""
        ann = YOLOAnnotation(
            class_id=0,
            x_center=0.5,
            y_center=0.5,
            width=0.1,
            height=0.05,
            confidence=0.9,
        )
        result = ann.to_string()
        # The expected string holds the class id and four six-decimal values;
        # the confidence is not part of it.
        assert result == "0 0.500000 0.500000 0.100000 0.050000"

    def test_default_confidence(self):
        """Verify default confidence is 1.0."""
        ann = YOLOAnnotation(
            class_id=0,
            x_center=0.5,
            y_center=0.5,
            width=0.1,
            height=0.05,
        )
        assert ann.confidence == 1.0


class TestAnnotationGeneratorInit:
    """Tests for AnnotationGenerator initialization."""

    def test_default_values(self):
        """Verify default initialization values."""
        gen = AnnotationGenerator()
        assert gen.min_confidence == 0.7
        assert gen.min_bbox_height_px == 30

    def test_custom_values(self):
        """Verify custom initialization values."""
        gen = AnnotationGenerator(
            min_confidence=0.8,
            min_bbox_height_px=40,
        )
        assert gen.min_confidence == 0.8
        assert gen.min_bbox_height_px == 40


class TestGenerateFromMatches:
    """Tests for generate_from_matches method."""

    def test_generates_annotation_for_valid_match(self):
        """Verify annotation is generated for valid match."""
        gen = AnnotationGenerator(min_confidence=0.5)

        # Mock match in PDF points (72 DPI)
        # At 150 DPI, coords multiply by 150/72 = 2.083
        matches = {
            "InvoiceNumber": [MockMatch(bbox=(100, 200, 200, 230), score=0.8)]
        }

        annotations = gen.generate_from_matches(
            matches=matches,
            image_width=1000,
            image_height=1000,
            dpi=150,
        )

        assert len(annotations) == 1
        ann = annotations[0]
        assert ann.class_id == TRAINING_FIELD_CLASSES["InvoiceNumber"]
        assert ann.confidence == 0.8
        # Normalized values should be in 0-1 range
        assert 0 <= ann.x_center <= 1
        assert 0 <= ann.y_center <= 1
        assert 0 < ann.width <= 1
        assert 0 < ann.height <= 1

    def test_skips_low_confidence_match(self):
        """Verify low confidence matches are skipped."""
        gen = AnnotationGenerator(min_confidence=0.7)

        matches = {
            "InvoiceNumber": [MockMatch(bbox=(100, 200, 200, 230), score=0.5)]
        }

        annotations = gen.generate_from_matches(
            matches=matches,
            image_width=1000,
            image_height=1000,
            dpi=150,
        )

        assert len(annotations) == 0

    def test_skips_unknown_field(self):
        """Verify unknown fields are skipped."""
        gen = AnnotationGenerator(min_confidence=0.5)

        matches = {
            "UnknownField": [MockMatch(bbox=(100, 200, 200, 230), score=0.9)]
        }

        annotations = gen.generate_from_matches(
            matches=matches,
            image_width=1000,
            image_height=1000,
            dpi=150,
        )

        assert len(annotations) == 0

    def test_takes_best_match_only(self):
        """Verify only the best match is used per field."""
        gen = AnnotationGenerator(min_confidence=0.5)

        matches = {
            "InvoiceNumber": [
                MockMatch(bbox=(100, 200, 200, 230), score=0.9),  # Best
                MockMatch(bbox=(300, 400, 400, 430), score=0.7),
            ]
        }

        annotations = gen.generate_from_matches(
            matches=matches,
            image_width=1000,
            image_height=1000,
            dpi=150,
        )

        assert len(annotations) == 1
        assert annotations[0].confidence == 0.9

    def test_handles_empty_matches(self):
        """Verify empty matches list is handled."""
        gen = AnnotationGenerator()

        matches = {
            "InvoiceNumber": []
        }

        annotations = gen.generate_from_matches(
            matches=matches,
            image_width=1000,
            image_height=1000,
            dpi=150,
        )

        assert len(annotations) == 0

    def test_applies_field_specific_expansion(self):
        """Verify different fields get different expansion."""
        gen = AnnotationGenerator(min_confidence=0.5)

        # Same bbox, different fields
        bbox = (100, 200, 200, 230)

        matches_invoice_number = {
            "InvoiceNumber": [MockMatch(bbox=bbox, score=0.9)]
        }
        matches_bankgiro = {
            "Bankgiro": [MockMatch(bbox=bbox, score=0.9)]
        }

        ann_invoice = gen.generate_from_matches(
            matches=matches_invoice_number,
            image_width=1000,
            image_height=1000,
            dpi=150,
        )[0]

        ann_bankgiro = gen.generate_from_matches(
            matches=matches_bankgiro,
            image_width=1000,
            image_height=1000,
            dpi=150,
        )[0]

        # Bankgiro has extra_left_ratio=0.80, invoice_number has extra_top_ratio=0.40
        # They should have different widths due to different expansion
        # Bankgiro expands more to the left
        assert (
            ann_bankgiro.width != ann_invoice.width
            or ann_bankgiro.x_center != ann_invoice.x_center
        )

    def test_enforces_min_bbox_height(self):
        """Verify minimum bbox height is enforced."""
        min_height_px = 50
        image_height = 1000
        gen = AnnotationGenerator(
            min_confidence=0.5, min_bbox_height_px=min_height_px
        )

        # Very small bbox (10 units tall), well away from the image edges
        matches = {
            "InvoiceNumber": [MockMatch(bbox=(100, 200, 200, 210), score=0.9)]
        }

        annotations = gen.generate_from_matches(
            matches=matches,
            image_width=1000,
            image_height=image_height,
            dpi=72,  # 1:1 scale
        )

        assert len(annotations) == 1
        # The min_bbox_height check is expected to run AFTER expand_bbox, so
        # the final normalized height must be at least
        # min_bbox_height_px / image_height (50 / 1000 = 0.05).
        # The small tolerance only absorbs floating-point rounding.
        assert annotations[0].height >= min_height_px / image_height - 1e-9


class TestAddPaymentLineAnnotation:
    """Tests for add_payment_line_annotation method."""

    def test_adds_payment_line_annotation(self):
        """Verify payment_line annotation is added."""
        gen = AnnotationGenerator(min_confidence=0.5)
        annotations = []

        result = gen.add_payment_line_annotation(
            annotations=annotations,
            payment_line_bbox=(100, 200, 400, 230),
            confidence=0.9,
            image_width=1000,
            image_height=1000,
            dpi=150,
        )

        assert len(result) == 1
        ann = result[0]
        assert ann.class_id == TRAINING_FIELD_CLASSES["payment_line"]
        assert ann.confidence == 0.9

    def test_skips_none_bbox(self):
        """Verify None bbox is handled."""
        gen = AnnotationGenerator(min_confidence=0.5)
        annotations = []

        result = gen.add_payment_line_annotation(
            annotations=annotations,
            payment_line_bbox=None,
            confidence=0.9,
            image_width=1000,
            image_height=1000,
            dpi=150,
        )

        assert len(result) == 0

    def test_skips_low_confidence(self):
        """Verify low confidence is skipped."""
        gen = AnnotationGenerator(min_confidence=0.7)
        annotations = []

        result = gen.add_payment_line_annotation(
            annotations=annotations,
            payment_line_bbox=(100, 200, 400, 230),
            confidence=0.5,
            image_width=1000,
            image_height=1000,
            dpi=150,
        )

        assert len(result) == 0

    def test_appends_to_existing_annotations(self):
        """Verify payment_line is appended to existing list."""
        gen = AnnotationGenerator(min_confidence=0.5)
        existing = [YOLOAnnotation(0, 0.5, 0.5, 0.1, 0.1, 0.9)]

        result = gen.add_payment_line_annotation(
            annotations=existing,
            payment_line_bbox=(100, 200, 400, 230),
            confidence=0.9,
            image_width=1000,
            image_height=1000,
            dpi=150,
        )

        assert len(result) == 2
        assert result[0].class_id == 0  # Original
        assert result[1].class_id == TRAINING_FIELD_CLASSES["payment_line"]


class TestMultipleFieldsIntegration:
    """Integration tests for multiple fields."""

    def test_generates_annotations_for_all_field_types(self):
        """Verify annotations can be generated for all field types."""
        gen = AnnotationGenerator(min_confidence=0.5)

        # Create matches for each field (except payment_line which is derived)
        field_names = [
            "InvoiceNumber",
            "InvoiceDate",
            "InvoiceDueDate",
            "OCR",
            "Bankgiro",
            "Plusgiro",
            "Amount",
            "supplier_organisation_number",
            "customer_number",
        ]

        matches = {}
        for i, field_name in enumerate(field_names):
            # Stagger bboxes to avoid overlap
            matches[field_name] = [
                MockMatch(
                    bbox=(100 + i * 50, 100 + i * 30, 200 + i * 50, 130 + i * 30),
                    score=0.9,
                )
            ]

        annotations = gen.generate_from_matches(
            matches=matches,
            image_width=2000,
            image_height=2000,
            dpi=150,
        )

        assert len(annotations) == len(field_names)

        # Verify all class_ids are present
        class_ids = {ann.class_id for ann in annotations}
        expected_class_ids = {TRAINING_FIELD_CLASSES[fn] for fn in field_names}
        assert class_ids == expected_class_ids