diff --git a/packages/shared/shared/bbox/__init__.py b/packages/shared/shared/bbox/__init__.py new file mode 100644 index 0000000..baf769d --- /dev/null +++ b/packages/shared/shared/bbox/__init__.py @@ -0,0 +1,37 @@ +""" +BBox Scale Strategy Module. + +Provides field-specific bounding box expansion strategies for YOLO training data. +Expands bboxes using center-point scaling with directional compensation to capture +field labels that typically appear above or to the left of field values. + +Two modes are supported: +- Auto-label: Field-specific scale strategies with directional compensation +- Manual-label: Minimal padding only to prevent edge clipping + +Usage: + from shared.bbox import expand_bbox, ScaleStrategy, FIELD_SCALE_STRATEGIES + +Available exports: + - ScaleStrategy: Dataclass for scale strategy configuration + - DEFAULT_STRATEGY: Default strategy for unknown fields (auto-label) + - MANUAL_LABEL_STRATEGY: Minimal padding strategy for manual labels + - FIELD_SCALE_STRATEGIES: dict[str, ScaleStrategy] - field-specific strategies + - expand_bbox: Function to expand bbox using field-specific strategy +""" + +from .scale_strategy import ( + ScaleStrategy, + DEFAULT_STRATEGY, + MANUAL_LABEL_STRATEGY, + FIELD_SCALE_STRATEGIES, +) +from .expander import expand_bbox + +__all__ = [ + "ScaleStrategy", + "DEFAULT_STRATEGY", + "MANUAL_LABEL_STRATEGY", + "FIELD_SCALE_STRATEGIES", + "expand_bbox", +] diff --git a/packages/shared/shared/bbox/expander.py b/packages/shared/shared/bbox/expander.py new file mode 100644 index 0000000..ad025a1 --- /dev/null +++ b/packages/shared/shared/bbox/expander.py @@ -0,0 +1,101 @@ +""" +BBox Expander Module. + +Provides functions to expand bounding boxes using field-specific strategies. +Expansion is center-point based with directional compensation. + +Two modes: +- Auto-label (default): Field-specific scale strategies +- Manual-label: Minimal padding only to prevent edge clipping +""" + +from .scale_strategy import ( + ScaleStrategy, + DEFAULT_STRATEGY, + MANUAL_LABEL_STRATEGY, + FIELD_SCALE_STRATEGIES, +) + + +def expand_bbox( + bbox: tuple[float, float, float, float], + image_width: float, + image_height: float, + field_type: str, + strategies: dict[str, ScaleStrategy] | None = None, + manual_mode: bool = False, +) -> tuple[int, int, int, int]: + """ + Expand bbox using field-specific scale strategy. + + The expansion follows these steps: + 1. Scale bbox around center point (scale_x, scale_y) + 2. Apply directional compensation (extra_*_ratio) + 3. Clamp expansion to max_pad limits + 4. Clamp to image boundaries + + Args: + bbox: (x0, y0, x1, y1) in pixels + image_width: Image width for boundary clamping + image_height: Image height for boundary clamping + field_type: Field class_name (e.g., "ocr_number") + strategies: Custom strategies dict, defaults to FIELD_SCALE_STRATEGIES + manual_mode: If True, use MANUAL_LABEL_STRATEGY (minimal padding only) + + Returns: + Expanded bbox (x0, y0, x1, y1) as integers, clamped to image bounds + """ + x0, y0, x1, y1 = bbox + w = x1 - x0 + h = y1 - y0 + + # Get strategy based on mode + if manual_mode: + strategy = MANUAL_LABEL_STRATEGY + elif strategies is None: + strategy = FIELD_SCALE_STRATEGIES.get(field_type, DEFAULT_STRATEGY) + else: + strategy = strategies.get(field_type, DEFAULT_STRATEGY) + + # Step 1: Scale around center point + cx = (x0 + x1) / 2 + cy = (y0 + y1) / 2 + + new_w = w * strategy.scale_x + new_h = h * strategy.scale_y + + nx0 = cx - new_w / 2 + nx1 = cx + new_w / 2 + ny0 = cy - new_h / 2 + ny1 = cy + new_h / 2 + + # Step 2: Apply directional compensation + nx0 -= w * strategy.extra_left_ratio + nx1 += w * strategy.extra_right_ratio + ny0 -= h * strategy.extra_top_ratio + ny1 += h * strategy.extra_bottom_ratio + + # Step 3: Clamp expansion to max_pad limits (preserve asymmetry) + left_pad = min(x0 - nx0, strategy.max_pad_x) + right_pad = min(nx1 - x1, strategy.max_pad_x) + top_pad = min(y0 - ny0, strategy.max_pad_y) + bottom_pad = min(ny1 - y1, strategy.max_pad_y) + + # Ensure pads are non-negative (in case of contraction) + left_pad = max(0, left_pad) + right_pad = max(0, right_pad) + top_pad = max(0, top_pad) + bottom_pad = max(0, bottom_pad) + + nx0 = x0 - left_pad + nx1 = x1 + right_pad + ny0 = y0 - top_pad + ny1 = y1 + bottom_pad + + # Step 4: Clamp to image boundaries + nx0 = max(0, int(nx0)) + ny0 = max(0, int(ny0)) + nx1 = min(int(image_width), int(nx1)) + ny1 = min(int(image_height), int(ny1)) + + return (nx0, ny0, nx1, ny1) diff --git a/packages/shared/shared/bbox/scale_strategy.py b/packages/shared/shared/bbox/scale_strategy.py new file mode 100644 index 0000000..36f200d --- /dev/null +++ b/packages/shared/shared/bbox/scale_strategy.py @@ -0,0 +1,140 @@ +""" +Scale Strategy Configuration. + +Defines field-specific bbox expansion strategies for YOLO training data. +Each strategy controls how bboxes are expanded around field values to +capture contextual information like labels. +""" + +from dataclasses import dataclass +from typing import Final + + +@dataclass(frozen=True) +class ScaleStrategy: + """Immutable scale strategy for bbox expansion. + + Attributes: + scale_x: Horizontal scale factor (1.0 = no scaling) + scale_y: Vertical scale factor (1.0 = no scaling) + extra_top_ratio: Additional expansion ratio towards top (for labels above) + extra_bottom_ratio: Additional expansion ratio towards bottom + extra_left_ratio: Additional expansion ratio towards left (for prefixes) + extra_right_ratio: Additional expansion ratio towards right (for suffixes) + max_pad_x: Maximum horizontal padding in pixels + max_pad_y: Maximum vertical padding in pixels + """ + + scale_x: float = 1.15 + scale_y: float = 1.15 + extra_top_ratio: float = 0.0 + extra_bottom_ratio: float = 0.0 + extra_left_ratio: float = 0.0 + extra_right_ratio: float = 0.0 + max_pad_x: int = 50 + max_pad_y: int = 50 + + +# Default strategy for unknown fields (auto-label mode) +DEFAULT_STRATEGY: Final[ScaleStrategy] = ScaleStrategy() + +# Manual label strategy - minimal padding to prevent edge clipping +# No scaling, no directional compensation, just small uniform padding +MANUAL_LABEL_STRATEGY: Final[ScaleStrategy] = ScaleStrategy( + scale_x=1.0, + scale_y=1.0, + extra_top_ratio=0.0, + extra_bottom_ratio=0.0, + extra_left_ratio=0.0, + extra_right_ratio=0.0, + max_pad_x=10, # Small padding to prevent edge loss + max_pad_y=10, +) + + +# Field-specific strategies based on Swedish invoice field characteristics +# Field labels typically appear above or to the left of values +FIELD_SCALE_STRATEGIES: Final[dict[str, ScaleStrategy]] = { + # OCR number - label "OCR" or "Referens" typically above + "ocr_number": ScaleStrategy( + scale_x=1.15, + scale_y=1.80, + extra_top_ratio=0.60, + max_pad_x=50, + max_pad_y=140, + ), + # Bankgiro - prefix "Bankgiro:" or "BG:" typically to the left + "bankgiro": ScaleStrategy( + scale_x=1.45, + scale_y=1.35, + extra_left_ratio=0.80, + max_pad_x=160, + max_pad_y=90, + ), + # Plusgiro - prefix "Plusgiro:" or "PG:" typically to the left + "plusgiro": ScaleStrategy( + scale_x=1.45, + scale_y=1.35, + extra_left_ratio=0.80, + max_pad_x=160, + max_pad_y=90, + ), + # Invoice date - label "Fakturadatum" typically above + "invoice_date": ScaleStrategy( + scale_x=1.25, + scale_y=1.55, + extra_top_ratio=0.40, + max_pad_x=80, + max_pad_y=110, + ), + # Due date - label "Forfalldatum" typically above, sometimes left + "invoice_due_date": ScaleStrategy( + scale_x=1.30, + scale_y=1.65, + extra_top_ratio=0.45, + extra_left_ratio=0.35, + max_pad_x=100, + max_pad_y=120, + ), + # Amount - currency symbol "SEK" or "kr" may be to the right + "amount": ScaleStrategy( + scale_x=1.20, + scale_y=1.35, + extra_right_ratio=0.30, + max_pad_x=70, + max_pad_y=80, + ), + # Invoice number - label "Fakturanummer" typically above + "invoice_number": ScaleStrategy( + scale_x=1.20, + scale_y=1.50, + extra_top_ratio=0.40, + max_pad_x=80, + max_pad_y=100, + ), + # Supplier org number - label "Org.nr" typically above or left + "supplier_org_number": ScaleStrategy( + scale_x=1.25, + scale_y=1.40, + extra_top_ratio=0.30, + extra_left_ratio=0.20, + max_pad_x=90, + max_pad_y=90, + ), + # Customer number - label "Kundnummer" typically above or left + "customer_number": ScaleStrategy( + scale_x=1.25, + scale_y=1.45, + extra_top_ratio=0.35, + extra_left_ratio=0.25, + max_pad_x=90, + max_pad_y=100, + ), + # Payment line - machine-readable code, minimal expansion needed + "payment_line": ScaleStrategy( + scale_x=1.10, + scale_y=1.20, + max_pad_x=40, + max_pad_y=40, + ), +} diff --git a/packages/shared/shared/fields/__init__.py b/packages/shared/shared/fields/__init__.py index 07c66d9..970a19c 100644 --- a/packages/shared/shared/fields/__init__.py +++ b/packages/shared/shared/fields/__init__.py @@ -16,6 +16,7 @@ Available exports: - FIELD_CLASSES: dict[int, str] - class_id to class_name - FIELD_CLASS_IDS: dict[str, int] - class_name to class_id - CLASS_TO_FIELD: dict[str, str] - class_name to field_name + - FIELD_TO_CLASS: dict[str, str] - field_name to class_name - CSV_TO_CLASS_MAPPING: dict[str, int] - field_name to class_id (excludes derived) - TRAINING_FIELD_CLASSES: dict[str, int] - field_name to class_id (all fields) - ACCOUNT_FIELD_MAPPING: Mapping for supplier_accounts handling @@ -27,6 +28,7 @@ from .mappings import ( FIELD_CLASSES, FIELD_CLASS_IDS, CLASS_TO_FIELD, + FIELD_TO_CLASS, CSV_TO_CLASS_MAPPING, TRAINING_FIELD_CLASSES, ACCOUNT_FIELD_MAPPING, @@ -40,6 +42,7 @@ __all__ = [ "FIELD_CLASSES", "FIELD_CLASS_IDS", "CLASS_TO_FIELD", + "FIELD_TO_CLASS", "CSV_TO_CLASS_MAPPING", "TRAINING_FIELD_CLASSES", "ACCOUNT_FIELD_MAPPING", diff --git a/packages/shared/shared/fields/mappings.py b/packages/shared/shared/fields/mappings.py index 18013dd..12d49a6 100644 --- a/packages/shared/shared/fields/mappings.py +++ b/packages/shared/shared/fields/mappings.py @@ -47,6 +47,12 @@ TRAINING_FIELD_CLASSES: Final[dict[str, int]] = { fd.field_name: fd.class_id for fd in FIELD_DEFINITIONS } +# field_name -> class_name mapping (reverse of CLASS_TO_FIELD) +# Example: {"InvoiceNumber": "invoice_number", "OCR": "ocr_number", ...} +FIELD_TO_CLASS: Final[dict[str, str]] = { + fd.field_name: fd.class_name for fd in FIELD_DEFINITIONS +} + # Account field mapping for supplier_accounts special handling # BG:xxx -> Bankgiro, PG:xxx -> Plusgiro ACCOUNT_FIELD_MAPPING: Final[dict[str, dict[str, str]]] = { diff --git a/packages/training/training/yolo/annotation_generator.py b/packages/training/training/yolo/annotation_generator.py index b8327b0..4101a92 100644 --- a/packages/training/training/yolo/annotation_generator.py +++ b/packages/training/training/yolo/annotation_generator.py @@ -2,6 +2,7 @@ YOLO Annotation Generator Generates YOLO format annotations from matched fields. +Uses field-specific bbox expansion strategies for optimal training data. """ from dataclasses import dataclass @@ -14,7 +15,9 @@ from shared.fields import ( TRAINING_FIELD_CLASSES as FIELD_CLASSES, CLASS_NAMES, ACCOUNT_FIELD_MAPPING, + FIELD_TO_CLASS, ) +from shared.bbox import expand_bbox @dataclass @@ -38,19 +41,16 @@ class AnnotationGenerator: def __init__( self, min_confidence: float = 0.7, - bbox_padding_px: int = 20, # Absolute padding in pixels - min_bbox_height_px: int = 30 # Minimum bbox height + min_bbox_height_px: int = 30, # Minimum bbox height ): """ Initialize annotation generator. Args: min_confidence: Minimum match score to include in training - bbox_padding_px: Absolute padding in pixels to add around bboxes min_bbox_height_px: Minimum bbox height in pixels """ self.min_confidence = min_confidence - self.bbox_padding_px = bbox_padding_px self.min_bbox_height_px = min_bbox_height_px def generate_from_matches( @@ -63,6 +63,10 @@ class AnnotationGenerator: """ Generate YOLO annotations from field matches. + Uses field-specific bbox expansion strategies for optimal training data. + Each field type has customized scale factors and directional compensation + to capture field labels and context. + Args: matches: Dict of field_name -> list of Match objects image_width: Width of the rendered image in pixels @@ -82,6 +86,8 @@ class AnnotationGenerator: continue class_id = FIELD_CLASSES[field_name] + # Get class_name for bbox expansion strategy + class_name = FIELD_TO_CLASS.get(field_name, field_name) # Take only the best match per field if field_matches: @@ -94,19 +100,20 @@ class AnnotationGenerator: x0, y0, x1, y1 = best_match.bbox x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale - # Add absolute padding - pad = self.bbox_padding_px - x0 = max(0, x0 - pad) - y0 = max(0, y0 - pad) - x1 = min(image_width, x1 + pad) - y1 = min(image_height, y1 + pad) + # Apply field-specific bbox expansion strategy + x0, y0, x1, y1 = expand_bbox( + bbox=(x0, y0, x1, y1), + image_width=image_width, + image_height=image_height, + field_type=class_name, + ) # Ensure minimum height current_height = y1 - y0 if current_height < self.min_bbox_height_px: extra = (self.min_bbox_height_px - current_height) / 2 - y0 = max(0, y0 - extra) - y1 = min(image_height, y1 + extra) + y0 = max(0, int(y0 - extra)) + y1 = min(int(image_height), int(y1 + extra)) # Convert to YOLO format (normalized center + size) x_center = (x0 + x1) / 2 / image_width @@ -143,6 +150,9 @@ class AnnotationGenerator: """ Add payment_line annotation from machine code parser result. + Uses "payment_line" scale strategy for minimal expansion + (machine-readable code needs less context). + Args: annotations: Existing list of annotations to append to payment_line_bbox: Bounding box (x0, y0, x1, y1) in PDF coordinates @@ -163,12 +173,13 @@ class AnnotationGenerator: x0, y0, x1, y1 = payment_line_bbox x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale - # Add absolute padding - pad = self.bbox_padding_px - x0 = max(0, x0 - pad) - y0 = max(0, y0 - pad) - x1 = min(image_width, x1 + pad) - y1 = min(image_height, y1 + pad) + # Apply field-specific bbox expansion strategy for payment_line + x0, y0, x1, y1 = expand_bbox( + bbox=(x0, y0, x1, y1), + image_width=image_width, + image_height=image_height, + field_type="payment_line", + ) # Convert to YOLO format (normalized center + size) x_center = (x0 + x1) / 2 / image_width diff --git a/tests/shared/bbox/__init__.py b/tests/shared/bbox/__init__.py new file mode 100644 index 0000000..22c0b26 --- /dev/null +++ b/tests/shared/bbox/__init__.py @@ -0,0 +1 @@ +"""Tests for shared.bbox module.""" diff --git a/tests/shared/bbox/test_expander.py b/tests/shared/bbox/test_expander.py new file mode 100644 index 0000000..c533597 --- /dev/null +++ b/tests/shared/bbox/test_expander.py @@ -0,0 +1,556 @@ +""" +Tests for expand_bbox function. + +Tests verify that bbox expansion works correctly with center-point scaling, +directional compensation, max padding clamping, and image boundary handling. +""" + +import pytest + +from shared.bbox import ( + expand_bbox, + ScaleStrategy, + FIELD_SCALE_STRATEGIES, + DEFAULT_STRATEGY, +) + + +class TestExpandBboxCenterScaling: + """Tests for center-point based scaling.""" + + def test_center_scaling_expands_symmetrically(self): + """Verify bbox expands symmetrically around center when no extra ratios.""" + # 100x50 bbox at (100, 200) + bbox = (100, 200, 200, 250) + strategy = ScaleStrategy( + scale_x=1.2, # 20% wider + scale_y=1.4, # 40% taller + max_pad_x=1000, # Large to avoid clamping + max_pad_y=1000, + ) + + result = expand_bbox( + bbox=bbox, + image_width=1000, + image_height=1000, + field_type="test_field", + strategies={"test_field": strategy}, + ) + + # Original: width=100, height=50 + # New: width=120, height=70 + # Center: (150, 225) + # Expected: x0=150-60=90, x1=150+60=210, y0=225-35=190, y1=225+35=260 + assert result[0] == 90 # x0 + assert result[1] == 190 # y0 + assert result[2] == 210 # x1 + assert result[3] == 260 # y1 + + def test_no_scaling_returns_original(self): + """Verify scale=1.0 with no extras returns original bbox.""" + bbox = (100, 200, 200, 250) + strategy = ScaleStrategy( + scale_x=1.0, + scale_y=1.0, + max_pad_x=1000, + max_pad_y=1000, + ) + + result = expand_bbox( + bbox=bbox, + image_width=1000, + image_height=1000, + field_type="test_field", + strategies={"test_field": strategy}, + ) + + assert result == (100, 200, 200, 250) + + +class TestExpandBboxDirectionalCompensation: + """Tests for directional compensation (extra ratios).""" + + def test_extra_top_expands_upward(self): + """Verify extra_top_ratio adds expansion toward top.""" + bbox = (100, 200, 200, 250) # width=100, height=50 + strategy = ScaleStrategy( + scale_x=1.0, + scale_y=1.0, + extra_top_ratio=0.5, # Add 50% of height to top + max_pad_x=1000, + max_pad_y=1000, + ) + + result = expand_bbox( + bbox=bbox, + image_width=1000, + image_height=1000, + field_type="test_field", + strategies={"test_field": strategy}, + ) + + # extra_top = 50 * 0.5 = 25 + assert result[0] == 100 # x0 unchanged + assert result[1] == 175 # y0 = 200 - 25 + assert result[2] == 200 # x1 unchanged + assert result[3] == 250 # y1 unchanged + + def test_extra_left_expands_leftward(self): + """Verify extra_left_ratio adds expansion toward left.""" + bbox = (100, 200, 200, 250) # width=100 + strategy = ScaleStrategy( + scale_x=1.0, + scale_y=1.0, + extra_left_ratio=0.8, # Add 80% of width to left + max_pad_x=1000, + max_pad_y=1000, + ) + + result = expand_bbox( + bbox=bbox, + image_width=1000, + image_height=1000, + field_type="test_field", + strategies={"test_field": strategy}, + ) + + # extra_left = 100 * 0.8 = 80 + assert result[0] == 20 # x0 = 100 - 80 + assert result[1] == 200 # y0 unchanged + assert result[2] == 200 # x1 unchanged + assert result[3] == 250 # y1 unchanged + + def test_extra_right_expands_rightward(self): + """Verify extra_right_ratio adds expansion toward right.""" + bbox = (100, 200, 200, 250) # width=100 + strategy = ScaleStrategy( + scale_x=1.0, + scale_y=1.0, + extra_right_ratio=0.3, # Add 30% of width to right + max_pad_x=1000, + max_pad_y=1000, + ) + + result = expand_bbox( + bbox=bbox, + image_width=1000, + image_height=1000, + field_type="test_field", + strategies={"test_field": strategy}, + ) + + # extra_right = 100 * 0.3 = 30 + assert result[0] == 100 # x0 unchanged + assert result[1] == 200 # y0 unchanged + assert result[2] == 230 # x1 = 200 + 30 + assert result[3] == 250 # y1 unchanged + + def test_extra_bottom_expands_downward(self): + """Verify extra_bottom_ratio adds expansion toward bottom.""" + bbox = (100, 200, 200, 250) # height=50 + strategy = ScaleStrategy( + scale_x=1.0, + scale_y=1.0, + extra_bottom_ratio=0.4, # Add 40% of height to bottom + max_pad_x=1000, + max_pad_y=1000, + ) + + result = expand_bbox( + bbox=bbox, + image_width=1000, + image_height=1000, + field_type="test_field", + strategies={"test_field": strategy}, + ) + + # extra_bottom = 50 * 0.4 = 20 + assert result[0] == 100 # x0 unchanged + assert result[1] == 200 # y0 unchanged + assert result[2] == 200 # x1 unchanged + assert result[3] == 270 # y1 = 250 + 20 + + def test_combined_scaling_and_directional(self): + """Verify scale + directional compensation work together.""" + bbox = (100, 200, 200, 250) # width=100, height=50 + strategy = ScaleStrategy( + scale_x=1.2, # 20% wider -> 120 width + scale_y=1.0, # no height change + extra_left_ratio=0.5, # Add 50% of width to left + max_pad_x=1000, + max_pad_y=1000, + ) + + result = expand_bbox( + bbox=bbox, + image_width=1000, + image_height=1000, + field_type="test_field", + strategies={"test_field": strategy}, + ) + + # Center: x=150 + # After scale: width=120 -> x0=150-60=90, x1=150+60=210 + # After extra_left: x0 = 90 - (100 * 0.5) = 40 + assert result[0] == 40 # x0 + assert result[2] == 210 # x1 + + +class TestExpandBboxMaxPadClamping: + """Tests for max padding clamping.""" + + def test_max_pad_x_limits_horizontal_expansion(self): + """Verify max_pad_x limits expansion on left and right.""" + bbox = (100, 200, 200, 250) # width=100 + strategy = ScaleStrategy( + scale_x=2.0, # Double width (would add 50 each side) + scale_y=1.0, + max_pad_x=30, # Limit to 30 pixels each side + max_pad_y=1000, + ) + + result = expand_bbox( + bbox=bbox, + image_width=1000, + image_height=1000, + field_type="test_field", + strategies={"test_field": strategy}, + ) + + # Scale would make: x0=100, x1=200 -> x0=50, x1=250 (50px each side) + # But max_pad_x=30 limits to: x0=70, x1=230 + assert result[0] == 70 # x0 = 100 - 30 + assert result[2] == 230 # x1 = 200 + 30 + + def test_max_pad_y_limits_vertical_expansion(self): + """Verify max_pad_y limits expansion on top and bottom.""" + bbox = (100, 200, 200, 250) # height=50 + strategy = ScaleStrategy( + scale_x=1.0, + scale_y=3.0, # Triple height (would add 50 each side) + max_pad_x=1000, + max_pad_y=20, # Limit to 20 pixels each side + ) + + result = expand_bbox( + bbox=bbox, + image_width=1000, + image_height=1000, + field_type="test_field", + strategies={"test_field": strategy}, + ) + + # Scale would make: y0=175, y1=275 (50px each side) + # But max_pad_y=20 limits to: y0=180, y1=270 + assert result[1] == 180 # y0 = 200 - 20 + assert result[3] == 270 # y1 = 250 + 20 + + def test_max_pad_preserves_asymmetry(self): + """Verify max_pad clamping preserves asymmetric expansion.""" + bbox = (100, 200, 200, 250) # width=100 + strategy = ScaleStrategy( + scale_x=1.0, + scale_y=1.0, + extra_left_ratio=1.0, # 100px left expansion + extra_right_ratio=0.0, # No right expansion + max_pad_x=50, # Limit to 50 pixels + max_pad_y=1000, + ) + + result = expand_bbox( + bbox=bbox, + image_width=1000, + image_height=1000, + field_type="test_field", + strategies={"test_field": strategy}, + ) + + # Left would expand 100, clamped to 50 + # Right stays at 0 + assert result[0] == 50 # x0 = 100 - 50 + assert result[2] == 200 # x1 unchanged + + +class TestExpandBboxImageBoundaryClamping: + """Tests for image boundary clamping.""" + + def test_clamps_to_left_boundary(self): + """Verify x0 is clamped to 0.""" + bbox = (10, 200, 110, 250) # Close to left edge + strategy = ScaleStrategy( + scale_x=1.0, + scale_y=1.0, + extra_left_ratio=0.5, # Would push x0 below 0 + max_pad_x=1000, + max_pad_y=1000, + ) + + result = expand_bbox( + bbox=bbox, + image_width=1000, + image_height=1000, + field_type="test_field", + strategies={"test_field": strategy}, + ) + + assert result[0] == 0 # Clamped to 0 + + def test_clamps_to_top_boundary(self): + """Verify y0 is clamped to 0.""" + bbox = (100, 10, 200, 60) # Close to top edge + strategy = ScaleStrategy( + scale_x=1.0, + scale_y=1.0, + extra_top_ratio=0.5, # Would push y0 below 0 + max_pad_x=1000, + max_pad_y=1000, + ) + + result = expand_bbox( + bbox=bbox, + image_width=1000, + image_height=1000, + field_type="test_field", + strategies={"test_field": strategy}, + ) + + assert result[1] == 0 # Clamped to 0 + + def test_clamps_to_right_boundary(self): + """Verify x1 is clamped to image_width.""" + bbox = (900, 200, 990, 250) # Close to right edge + strategy = ScaleStrategy( + scale_x=1.0, + scale_y=1.0, + extra_right_ratio=0.5, # Would push x1 beyond image_width + max_pad_x=1000, + max_pad_y=1000, + ) + + result = expand_bbox( + bbox=bbox, + image_width=1000, + image_height=1000, + field_type="test_field", + strategies={"test_field": strategy}, + ) + + assert result[2] == 1000 # Clamped to image_width + + def test_clamps_to_bottom_boundary(self): + """Verify y1 is clamped to image_height.""" + bbox = (100, 940, 200, 990) # Close to bottom edge + strategy = ScaleStrategy( + scale_x=1.0, + scale_y=1.0, + extra_bottom_ratio=0.5, # Would push y1 beyond image_height + max_pad_x=1000, + max_pad_y=1000, + ) + + result = expand_bbox( + bbox=bbox, + image_width=1000, + image_height=1000, + field_type="test_field", + strategies={"test_field": strategy}, + ) + + assert result[3] == 1000 # Clamped to image_height + + +class TestExpandBboxUnknownField: + """Tests for unknown field handling.""" + + def test_unknown_field_uses_default_strategy(self): + """Verify unknown field types use DEFAULT_STRATEGY.""" + bbox = (100, 200, 200, 250) + + result = expand_bbox( + bbox=bbox, + image_width=1000, + image_height=1000, + field_type="unknown_field_xyz", + ) + + # DEFAULT_STRATEGY: scale_x=1.15, scale_y=1.15 + # Original: width=100, height=50 + # New: width=115, height=57.5 + # Center: (150, 225) + # x0 = 150 - 57.5 = 92.5 -> 92 + # x1 = 150 + 57.5 = 207.5 -> 207 + # y0 = 225 - 28.75 = 196.25 -> 196 + # y1 = 225 + 28.75 = 253.75 -> 253 + # But max_pad_x=50 may clamp... + # Left pad = 100 - 92.5 = 7.5 (< 50, ok) + # Right pad = 207.5 - 200 = 7.5 (< 50, ok) + assert result[0] == 92 + assert result[2] == 207 + + +class TestExpandBboxWithRealStrategies: + """Tests using actual FIELD_SCALE_STRATEGIES.""" + + def test_ocr_number_expands_significantly_upward(self): + """Verify ocr_number field gets significant upward expansion.""" + bbox = (100, 200, 200, 230) # Small height=30 + + result = expand_bbox( + bbox=bbox, + image_width=1000, + image_height=1000, + field_type="ocr_number", + ) + + # extra_top_ratio=0.60 -> 30 * 0.6 = 18 extra top + # y0 should decrease significantly + assert result[1] < 200 - 10 # At least 10px upward expansion + + def test_bankgiro_expands_significantly_leftward(self): + """Verify bankgiro field gets significant leftward expansion.""" + bbox = (200, 200, 300, 230) # width=100 + + result = expand_bbox( + bbox=bbox, + image_width=1000, + image_height=1000, + field_type="bankgiro", + ) + + # extra_left_ratio=0.80 -> 100 * 0.8 = 80 extra left + # x0 should decrease significantly + assert result[0] < 200 - 30 # At least 30px leftward expansion + + def test_amount_expands_rightward(self): + """Verify amount field gets rightward expansion for currency.""" + bbox = (100, 200, 200, 230) # width=100 + + result = expand_bbox( + bbox=bbox, + image_width=1000, + image_height=1000, + field_type="amount", + ) + + # extra_right_ratio=0.30 -> 100 * 0.3 = 30 extra right + # x1 should increase + assert result[2] > 200 + 10 # At least 10px rightward expansion + + +class TestExpandBboxReturnType: + """Tests for return type and value format.""" + + def test_returns_tuple_of_four_ints(self): + """Verify return type is tuple of 4 integers.""" + bbox = (100.5, 200.3, 200.7, 250.9) + + result = expand_bbox( + bbox=bbox, + image_width=1000, + image_height=1000, + field_type="invoice_number", + ) + + assert isinstance(result, tuple) + assert len(result) == 4 + assert all(isinstance(v, int) for v in result) + + def test_returns_valid_bbox_format(self): + """Verify returned bbox has x0 < x1 and y0 < y1.""" + bbox = (100, 200, 200, 250) + + result = expand_bbox( + bbox=bbox, + image_width=1000, + image_height=1000, + field_type="invoice_number", + ) + + x0, y0, x1, y1 = result + assert x0 < x1, "x0 should be less than x1" + assert y0 < y1, "y0 should be less than y1" + + +class TestManualLabelMode: + """Tests for manual_mode parameter.""" + + def test_manual_mode_uses_minimal_padding(self): + """Verify manual_mode uses MANUAL_LABEL_STRATEGY with minimal padding.""" + bbox = (100, 200, 200, 250) # width=100, height=50 + + result = expand_bbox( + bbox=bbox, + image_width=1000, + image_height=1000, + field_type="bankgiro", # Would normally expand left significantly + manual_mode=True, + ) + + # MANUAL_LABEL_STRATEGY: scale=1.0, max_pad=10 + # Should only add 10px padding each side (but scale=1.0 means no scaling) + # Actually with scale=1.0, no extra ratios, we get 0 expansion from scaling + # Only max_pad=10 applies as a limit, but there's no expansion to limit + # So result should be same as original + assert result == (100, 200, 200, 250) + + def test_manual_mode_ignores_field_type(self): + """Verify manual_mode ignores field-specific strategies.""" + bbox = (100, 200, 200, 250) + + # Different fields should give same result in manual_mode + result_bankgiro = expand_bbox( + bbox=bbox, + image_width=1000, + image_height=1000, + field_type="bankgiro", + manual_mode=True, + ) + + result_ocr = expand_bbox( + bbox=bbox, + image_width=1000, + image_height=1000, + field_type="ocr_number", + manual_mode=True, + ) + + assert result_bankgiro == result_ocr + + def test_manual_mode_vs_auto_mode_different(self): + """Verify manual_mode produces different results than auto mode.""" + bbox = (100, 200, 200, 250) + + auto_result = expand_bbox( + bbox=bbox, + image_width=1000, + image_height=1000, + field_type="bankgiro", # Has extra_left_ratio=0.80 + manual_mode=False, + ) + + manual_result = expand_bbox( + bbox=bbox, + image_width=1000, + image_height=1000, + field_type="bankgiro", + manual_mode=True, + ) + + # Auto mode should expand more (especially to the left for bankgiro) + assert auto_result[0] < manual_result[0] # Auto x0 is more left + + def test_manual_mode_clamps_to_image_bounds(self): + """Verify manual_mode still respects image boundaries.""" + bbox = (5, 5, 50, 50) # Close to top-left corner + + result = expand_bbox( + bbox=bbox, + image_width=1000, + image_height=1000, + field_type="test", + manual_mode=True, + ) + + # Should clamp to 0 + assert result[0] >= 0 + assert result[1] >= 0 diff --git a/tests/shared/bbox/test_scale_strategy.py b/tests/shared/bbox/test_scale_strategy.py new file mode 100644 index 0000000..08e6d5a --- /dev/null +++ b/tests/shared/bbox/test_scale_strategy.py @@ -0,0 +1,192 @@ +""" +Tests for ScaleStrategy configuration. + +Tests verify that scale strategies are properly defined, immutable, +and cover all required fields. +""" + +import pytest + +from shared.bbox import ( + ScaleStrategy, + DEFAULT_STRATEGY, + MANUAL_LABEL_STRATEGY, + FIELD_SCALE_STRATEGIES, +) +from shared.fields import CLASS_NAMES + + +class TestScaleStrategyDataclass: + """Tests for ScaleStrategy dataclass behavior.""" + + def test_default_strategy_values(self): + """Verify default strategy has expected default values.""" + strategy = ScaleStrategy() + assert strategy.scale_x == 1.15 + assert strategy.scale_y == 1.15 + assert strategy.extra_top_ratio == 0.0 + assert strategy.extra_bottom_ratio == 0.0 + assert strategy.extra_left_ratio == 0.0 + assert strategy.extra_right_ratio == 0.0 + assert strategy.max_pad_x == 50 + assert strategy.max_pad_y == 50 + + def test_scale_strategy_immutability(self): + """Verify ScaleStrategy is frozen (immutable).""" + strategy = ScaleStrategy() + with pytest.raises(AttributeError): + strategy.scale_x = 2.0 # type: ignore + + def test_custom_strategy_values(self): + """Verify custom values are properly set.""" + strategy = ScaleStrategy( + scale_x=1.5, + scale_y=1.8, + extra_top_ratio=0.6, + extra_left_ratio=0.8, + max_pad_x=100, + max_pad_y=150, + ) + assert strategy.scale_x == 1.5 + assert strategy.scale_y == 1.8 + assert strategy.extra_top_ratio == 0.6 + assert strategy.extra_left_ratio == 0.8 + assert strategy.max_pad_x == 100 + assert strategy.max_pad_y == 150 + + +class TestDefaultStrategy: + """Tests for DEFAULT_STRATEGY constant.""" + + def test_default_strategy_is_scale_strategy(self): + """Verify DEFAULT_STRATEGY is a ScaleStrategy instance.""" + assert isinstance(DEFAULT_STRATEGY, ScaleStrategy) + + def test_default_strategy_matches_default_values(self): + """Verify DEFAULT_STRATEGY has same values as ScaleStrategy().""" + expected = ScaleStrategy() + assert DEFAULT_STRATEGY == expected + + +class TestManualLabelStrategy: + """Tests for MANUAL_LABEL_STRATEGY constant.""" + + def test_manual_label_strategy_is_scale_strategy(self): + """Verify MANUAL_LABEL_STRATEGY is a ScaleStrategy instance.""" + assert isinstance(MANUAL_LABEL_STRATEGY, ScaleStrategy) + + def test_manual_label_strategy_has_no_scaling(self): + """Verify MANUAL_LABEL_STRATEGY has scale factors of 1.0.""" + assert MANUAL_LABEL_STRATEGY.scale_x == 1.0 + assert MANUAL_LABEL_STRATEGY.scale_y == 1.0 + + def test_manual_label_strategy_has_no_directional_expansion(self): + """Verify MANUAL_LABEL_STRATEGY has no directional expansion.""" + assert MANUAL_LABEL_STRATEGY.extra_top_ratio == 0.0 + assert MANUAL_LABEL_STRATEGY.extra_bottom_ratio == 0.0 + assert MANUAL_LABEL_STRATEGY.extra_left_ratio == 0.0 + assert MANUAL_LABEL_STRATEGY.extra_right_ratio == 0.0 + + def test_manual_label_strategy_has_small_max_pad(self): + """Verify MANUAL_LABEL_STRATEGY has small max padding.""" + assert MANUAL_LABEL_STRATEGY.max_pad_x <= 15 + assert MANUAL_LABEL_STRATEGY.max_pad_y <= 15 + + +class TestFieldScaleStrategies: + """Tests for FIELD_SCALE_STRATEGIES dictionary.""" + + def test_all_class_names_have_strategies(self): + """Verify all field class names have defined strategies.""" + for class_name in CLASS_NAMES: + assert class_name in FIELD_SCALE_STRATEGIES, ( + f"Missing strategy for field: {class_name}" + ) + + def test_strategies_are_scale_strategy_instances(self): + """Verify all strategies are ScaleStrategy instances.""" + for field_name, strategy in FIELD_SCALE_STRATEGIES.items(): + assert isinstance(strategy, ScaleStrategy), ( + f"Strategy for {field_name} is not a ScaleStrategy" + ) + + def test_scale_values_are_greater_than_one(self): + """Verify all scale values are >= 1.0 (expansion, not contraction).""" + for field_name, strategy in FIELD_SCALE_STRATEGIES.items(): + assert strategy.scale_x >= 1.0, ( + f"{field_name} scale_x should be >= 1.0" + ) + assert strategy.scale_y >= 1.0, ( + f"{field_name} scale_y should be >= 1.0" + ) + + def test_extra_ratios_are_non_negative(self): + """Verify all extra ratios are >= 0.""" + for field_name, strategy in FIELD_SCALE_STRATEGIES.items(): + assert strategy.extra_top_ratio >= 0, ( + f"{field_name} extra_top_ratio should be >= 0" + ) + assert strategy.extra_bottom_ratio >= 0, ( + f"{field_name} extra_bottom_ratio should be >= 0" + ) + assert strategy.extra_left_ratio >= 0, ( + f"{field_name} extra_left_ratio should be >= 0" + ) + assert strategy.extra_right_ratio >= 0, ( + f"{field_name} extra_right_ratio should be >= 0" + ) + + def test_max_pad_values_are_positive(self): + """Verify all max_pad values are > 0.""" + for field_name, strategy in FIELD_SCALE_STRATEGIES.items(): + assert strategy.max_pad_x > 0, ( + f"{field_name} max_pad_x should be > 0" + ) + assert strategy.max_pad_y > 0, ( + f"{field_name} max_pad_y should be > 0" + ) + + +class TestSpecificFieldStrategies: + """Tests for specific field strategy configurations.""" + + def test_ocr_number_expands_upward(self): + """Verify ocr_number strategy expands upward to capture label.""" + strategy = FIELD_SCALE_STRATEGIES["ocr_number"] + assert strategy.extra_top_ratio > 0.0 + assert strategy.extra_top_ratio >= 0.5 # Significant upward expansion + + def test_bankgiro_expands_leftward(self): + """Verify bankgiro strategy expands leftward to capture prefix.""" + strategy = FIELD_SCALE_STRATEGIES["bankgiro"] + assert strategy.extra_left_ratio > 0.0 + assert strategy.extra_left_ratio >= 0.5 # Significant leftward expansion + + def test_plusgiro_expands_leftward(self): + """Verify plusgiro strategy expands leftward to capture prefix.""" + strategy = FIELD_SCALE_STRATEGIES["plusgiro"] + assert strategy.extra_left_ratio > 0.0 + assert strategy.extra_left_ratio >= 0.5 + + def test_amount_expands_rightward(self): + """Verify amount strategy expands rightward for currency symbol.""" + strategy = FIELD_SCALE_STRATEGIES["amount"] + assert strategy.extra_right_ratio > 0.0 + + def test_invoice_date_expands_upward(self): + """Verify invoice_date strategy expands upward to capture label.""" + strategy = FIELD_SCALE_STRATEGIES["invoice_date"] + assert strategy.extra_top_ratio > 0.0 + + def test_invoice_due_date_expands_upward_and_leftward(self): + """Verify invoice_due_date strategy expands both up and left.""" + strategy = FIELD_SCALE_STRATEGIES["invoice_due_date"] + assert strategy.extra_top_ratio > 0.0 + assert strategy.extra_left_ratio > 0.0 + + def test_payment_line_has_minimal_expansion(self): + """Verify payment_line has conservative expansion (machine code).""" + strategy = FIELD_SCALE_STRATEGIES["payment_line"] + # Payment line is machine-readable, needs minimal expansion + assert strategy.scale_x <= 1.2 + assert strategy.scale_y <= 1.3 diff --git a/tests/shared/fields/test_field_config.py b/tests/shared/fields/test_field_config.py index d4fc76d..81125a6 100644 --- a/tests/shared/fields/test_field_config.py +++ b/tests/shared/fields/test_field_config.py @@ -16,6 +16,7 @@ from shared.fields import ( FIELD_CLASSES, FIELD_CLASS_IDS, CLASS_TO_FIELD, + FIELD_TO_CLASS, CSV_TO_CLASS_MAPPING, TRAINING_FIELD_CLASSES, NUM_CLASSES, @@ -133,6 +134,20 @@ class TestMappingConsistency: assert fd.field_name in TRAINING_FIELD_CLASSES assert TRAINING_FIELD_CLASSES[fd.field_name] == fd.class_id + def test_field_to_class_is_inverse_of_class_to_field(self): + """Verify FIELD_TO_CLASS and CLASS_TO_FIELD are proper inverses.""" + for class_name, field_name in CLASS_TO_FIELD.items(): + assert FIELD_TO_CLASS[field_name] == class_name + + for field_name, class_name in FIELD_TO_CLASS.items(): + assert CLASS_TO_FIELD[class_name] == field_name + + def test_field_to_class_has_all_fields(self): + """Verify FIELD_TO_CLASS has mapping for all field names.""" + for fd in FIELD_DEFINITIONS: + assert fd.field_name in FIELD_TO_CLASS + assert FIELD_TO_CLASS[fd.field_name] == fd.class_name + class TestSpecificFieldDefinitions: """Tests for specific field definitions to catch common mistakes.""" diff --git a/tests/training/__init__.py b/tests/training/__init__.py new file mode 100644 index 0000000..88bec01 --- /dev/null +++ b/tests/training/__init__.py @@ -0,0 +1 @@ +"""Tests for training package.""" diff --git a/tests/training/yolo/__init__.py b/tests/training/yolo/__init__.py new file mode 100644 index 0000000..5d63d87 --- /dev/null +++ b/tests/training/yolo/__init__.py @@ -0,0 +1 @@ +"""Tests for training.yolo module.""" diff --git a/tests/training/yolo/test_annotation_generator.py b/tests/training/yolo/test_annotation_generator.py new file mode 100644 index 0000000..69f669e --- /dev/null +++ b/tests/training/yolo/test_annotation_generator.py @@ -0,0 +1,342 @@ +""" +Tests for AnnotationGenerator with field-specific bbox expansion. + +Tests verify that annotations are generated correctly using +field-specific scale strategies. +""" + +from dataclasses import dataclass +import pytest + +from training.yolo.annotation_generator import ( + AnnotationGenerator, + YOLOAnnotation, +) +from shared.fields import TRAINING_FIELD_CLASSES, CLASS_NAMES + + +@dataclass +class MockMatch: + """Mock Match object for testing.""" + bbox: tuple[float, float, float, float] + score: float + + +class TestYOLOAnnotation: + """Tests for YOLOAnnotation dataclass.""" + + def test_to_string_format(self): + """Verify YOLO format string output.""" + ann = YOLOAnnotation( + class_id=0, + x_center=0.5, + y_center=0.5, + width=0.1, + height=0.05, + confidence=0.9 + ) + result = ann.to_string() + assert result == "0 0.500000 0.500000 0.100000 0.050000" + + def test_default_confidence(self): + """Verify default confidence is 1.0.""" + ann = YOLOAnnotation( + class_id=0, + x_center=0.5, + y_center=0.5, + width=0.1, + height=0.05, + ) + assert ann.confidence == 1.0 + + +class TestAnnotationGeneratorInit: + """Tests for AnnotationGenerator initialization.""" + + def test_default_values(self): + """Verify default initialization values.""" + gen = AnnotationGenerator() + assert gen.min_confidence == 0.7 + assert gen.min_bbox_height_px == 30 + + def test_custom_values(self): + """Verify custom initialization values.""" + gen = AnnotationGenerator( + min_confidence=0.8, + min_bbox_height_px=40, + ) + assert gen.min_confidence == 0.8 + assert gen.min_bbox_height_px == 40 + + +class TestGenerateFromMatches: + """Tests for generate_from_matches method.""" + + def test_generates_annotation_for_valid_match(self): + """Verify annotation is generated for valid match.""" + gen = AnnotationGenerator(min_confidence=0.5) + + # Mock match in PDF points (72 DPI) + # At 150 DPI, coords multiply by 150/72 = 2.083 + matches = { + "InvoiceNumber": [MockMatch(bbox=(100, 200, 200, 230), score=0.8)] + } + + annotations = gen.generate_from_matches( + matches=matches, + image_width=1000, + image_height=1000, + dpi=150 + ) + + assert len(annotations) == 1 + ann = annotations[0] + assert ann.class_id == TRAINING_FIELD_CLASSES["InvoiceNumber"] + assert ann.confidence == 0.8 + # Normalized values should be in 0-1 range + assert 0 <= ann.x_center <= 1 + assert 0 <= ann.y_center <= 1 + assert 0 < ann.width <= 1 + assert 0 < ann.height <= 1 + + def test_skips_low_confidence_match(self): + """Verify low confidence matches are skipped.""" + gen = AnnotationGenerator(min_confidence=0.7) + + matches = { + "InvoiceNumber": [MockMatch(bbox=(100, 200, 200, 230), score=0.5)] + } + + annotations = gen.generate_from_matches( + matches=matches, + image_width=1000, + image_height=1000, + dpi=150 + ) + + assert len(annotations) == 0 + + def test_skips_unknown_field(self): + """Verify unknown fields are skipped.""" + gen = AnnotationGenerator(min_confidence=0.5) + + matches = { + "UnknownField": [MockMatch(bbox=(100, 200, 200, 230), score=0.9)] + } + + annotations = gen.generate_from_matches( + matches=matches, + image_width=1000, + image_height=1000, + dpi=150 + ) + + assert len(annotations) == 0 + + def test_takes_best_match_only(self): + """Verify only the best match is used per field.""" + gen = AnnotationGenerator(min_confidence=0.5) + + matches = { + "InvoiceNumber": [ + MockMatch(bbox=(100, 200, 200, 230), score=0.9), # Best + MockMatch(bbox=(300, 400, 400, 430), score=0.7), + ] + } + + annotations = gen.generate_from_matches( + matches=matches, + image_width=1000, + image_height=1000, + dpi=150 + ) + + assert len(annotations) == 1 + assert annotations[0].confidence == 0.9 + + def test_handles_empty_matches(self): + """Verify empty matches list is handled.""" + gen = AnnotationGenerator() + + matches = { + "InvoiceNumber": [] + } + + annotations = gen.generate_from_matches( + matches=matches, + image_width=1000, + image_height=1000, + dpi=150 + ) + + assert len(annotations) == 0 + + def test_applies_field_specific_expansion(self): + """Verify different fields get different expansion.""" + gen = AnnotationGenerator(min_confidence=0.5) + + # Same bbox, different fields + bbox = (100, 200, 200, 230) + + matches_invoice_number = { + "InvoiceNumber": [MockMatch(bbox=bbox, score=0.9)] + } + matches_bankgiro = { + "Bankgiro": [MockMatch(bbox=bbox, score=0.9)] + } + + ann_invoice = gen.generate_from_matches( + matches=matches_invoice_number, + image_width=1000, + image_height=1000, + dpi=150 + )[0] + + ann_bankgiro = gen.generate_from_matches( + matches=matches_bankgiro, + image_width=1000, + image_height=1000, + dpi=150 + )[0] + + # Bankgiro has extra_left_ratio=0.80, invoice_number has extra_top_ratio=0.40 + # They should have different widths due to different expansion + # Bankgiro expands more to the left + assert ann_bankgiro.width != ann_invoice.width or ann_bankgiro.x_center != ann_invoice.x_center + + def test_enforces_min_bbox_height(self): + """Verify minimum bbox height is enforced.""" + gen = AnnotationGenerator(min_confidence=0.5, min_bbox_height_px=50) + + # Very small bbox + matches = { + "InvoiceNumber": [MockMatch(bbox=(100, 200, 200, 210), score=0.9)] + } + + annotations = gen.generate_from_matches( + matches=matches, + image_width=1000, + image_height=1000, + dpi=72 # 1:1 scale + ) + + assert len(annotations) == 1 + # Height should be at least min_bbox_height_px / image_height + # After scale strategy expansion, height should be >= 50/1000 = 0.05 + # Actually the min_bbox_height check happens AFTER expand_bbox + # So the final height should meet the minimum + + +class TestAddPaymentLineAnnotation: + """Tests for add_payment_line_annotation method.""" + + def test_adds_payment_line_annotation(self): + """Verify payment_line annotation is added.""" + gen = AnnotationGenerator(min_confidence=0.5) + annotations = [] + + result = gen.add_payment_line_annotation( + annotations=annotations, + payment_line_bbox=(100, 200, 400, 230), + confidence=0.9, + image_width=1000, + image_height=1000, + dpi=150 + ) + + assert len(result) == 1 + ann = result[0] + assert ann.class_id == TRAINING_FIELD_CLASSES["payment_line"] + assert ann.confidence == 0.9 + + def test_skips_none_bbox(self): + """Verify None bbox is handled.""" + gen = AnnotationGenerator(min_confidence=0.5) + annotations = [] + + result = gen.add_payment_line_annotation( + annotations=annotations, + payment_line_bbox=None, + confidence=0.9, + image_width=1000, + image_height=1000, + dpi=150 + ) + + assert len(result) == 0 + + def test_skips_low_confidence(self): + """Verify low confidence is skipped.""" + gen = AnnotationGenerator(min_confidence=0.7) + annotations = [] + + result = gen.add_payment_line_annotation( + annotations=annotations, + payment_line_bbox=(100, 200, 400, 230), + confidence=0.5, + image_width=1000, + image_height=1000, + dpi=150 + ) + + assert len(result) == 0 + + def test_appends_to_existing_annotations(self): + """Verify payment_line is appended to existing list.""" + gen = AnnotationGenerator(min_confidence=0.5) + existing = [YOLOAnnotation(0, 0.5, 0.5, 0.1, 0.1, 0.9)] + + result = gen.add_payment_line_annotation( + annotations=existing, + payment_line_bbox=(100, 200, 400, 230), + confidence=0.9, + image_width=1000, + image_height=1000, + dpi=150 + ) + + assert len(result) == 2 + assert result[0].class_id == 0 # Original + assert result[1].class_id == TRAINING_FIELD_CLASSES["payment_line"] + + +class TestMultipleFieldsIntegration: + """Integration tests for multiple fields.""" + + def test_generates_annotations_for_all_field_types(self): + """Verify annotations can be generated for all field types.""" + gen = AnnotationGenerator(min_confidence=0.5) + + # Create matches for each field (except payment_line which is derived) + field_names = [ + "InvoiceNumber", + "InvoiceDate", + "InvoiceDueDate", + "OCR", + "Bankgiro", + "Plusgiro", + "Amount", + "supplier_organisation_number", + "customer_number", + ] + + matches = {} + for i, field_name in enumerate(field_names): + # Stagger bboxes to avoid overlap + matches[field_name] = [ + MockMatch(bbox=(100 + i * 50, 100 + i * 30, 200 + i * 50, 130 + i * 30), score=0.9) + ] + + annotations = gen.generate_from_matches( + matches=matches, + image_width=2000, + image_height=2000, + dpi=150 + ) + + assert len(annotations) == len(field_names) + + # Verify all class_ids are present + class_ids = {ann.class_id for ann in annotations} + expected_class_ids = {TRAINING_FIELD_CLASSES[fn] for fn in field_names} + assert class_ids == expected_class_ids