feat: add field-specific bbox expansion strategies for YOLO training
Implement center-point based bbox scaling with directional compensation to capture field labels that typically appear above or to the left of field values. This improves YOLO training data quality by including contextual information around field values. Key changes: - Add shared.bbox module with ScaleStrategy dataclass and expand_bbox function - Define field-specific strategies (ocr_number, bankgiro, invoice_date, etc.) - Support manual_mode for minimal padding (no scaling) - Integrate expand_bbox into AnnotationGenerator - Add FIELD_TO_CLASS mapping for field_name to class_name lookup - Comprehensive tests with 100% coverage (45 tests) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
37
packages/shared/shared/bbox/__init__.py
Normal file
37
packages/shared/shared/bbox/__init__.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
"""
|
||||||
|
BBox Scale Strategy Module.
|
||||||
|
|
||||||
|
Provides field-specific bounding box expansion strategies for YOLO training data.
|
||||||
|
Expands bboxes using center-point scaling with directional compensation to capture
|
||||||
|
field labels that typically appear above or to the left of field values.
|
||||||
|
|
||||||
|
Two modes are supported:
|
||||||
|
- Auto-label: Field-specific scale strategies with directional compensation
|
||||||
|
- Manual-label: Minimal padding only to prevent edge clipping
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from shared.bbox import expand_bbox, ScaleStrategy, FIELD_SCALE_STRATEGIES
|
||||||
|
|
||||||
|
Available exports:
|
||||||
|
- ScaleStrategy: Dataclass for scale strategy configuration
|
||||||
|
- DEFAULT_STRATEGY: Default strategy for unknown fields (auto-label)
|
||||||
|
- MANUAL_LABEL_STRATEGY: Minimal padding strategy for manual labels
|
||||||
|
- FIELD_SCALE_STRATEGIES: dict[str, ScaleStrategy] - field-specific strategies
|
||||||
|
- expand_bbox: Function to expand bbox using field-specific strategy
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .scale_strategy import (
|
||||||
|
ScaleStrategy,
|
||||||
|
DEFAULT_STRATEGY,
|
||||||
|
MANUAL_LABEL_STRATEGY,
|
||||||
|
FIELD_SCALE_STRATEGIES,
|
||||||
|
)
|
||||||
|
from .expander import expand_bbox
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ScaleStrategy",
|
||||||
|
"DEFAULT_STRATEGY",
|
||||||
|
"MANUAL_LABEL_STRATEGY",
|
||||||
|
"FIELD_SCALE_STRATEGIES",
|
||||||
|
"expand_bbox",
|
||||||
|
]
|
||||||
101
packages/shared/shared/bbox/expander.py
Normal file
101
packages/shared/shared/bbox/expander.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
"""
|
||||||
|
BBox Expander Module.
|
||||||
|
|
||||||
|
Provides functions to expand bounding boxes using field-specific strategies.
|
||||||
|
Expansion is center-point based with directional compensation.
|
||||||
|
|
||||||
|
Two modes:
|
||||||
|
- Auto-label (default): Field-specific scale strategies
|
||||||
|
- Manual-label: Minimal padding only to prevent edge clipping
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .scale_strategy import (
|
||||||
|
ScaleStrategy,
|
||||||
|
DEFAULT_STRATEGY,
|
||||||
|
MANUAL_LABEL_STRATEGY,
|
||||||
|
FIELD_SCALE_STRATEGIES,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def expand_bbox(
|
||||||
|
bbox: tuple[float, float, float, float],
|
||||||
|
image_width: float,
|
||||||
|
image_height: float,
|
||||||
|
field_type: str,
|
||||||
|
strategies: dict[str, ScaleStrategy] | None = None,
|
||||||
|
manual_mode: bool = False,
|
||||||
|
) -> tuple[int, int, int, int]:
|
||||||
|
"""
|
||||||
|
Expand bbox using field-specific scale strategy.
|
||||||
|
|
||||||
|
The expansion follows these steps:
|
||||||
|
1. Scale bbox around center point (scale_x, scale_y)
|
||||||
|
2. Apply directional compensation (extra_*_ratio)
|
||||||
|
3. Clamp expansion to max_pad limits
|
||||||
|
4. Clamp to image boundaries
|
||||||
|
|
||||||
|
Args:
|
||||||
|
bbox: (x0, y0, x1, y1) in pixels
|
||||||
|
image_width: Image width for boundary clamping
|
||||||
|
image_height: Image height for boundary clamping
|
||||||
|
field_type: Field class_name (e.g., "ocr_number")
|
||||||
|
strategies: Custom strategies dict, defaults to FIELD_SCALE_STRATEGIES
|
||||||
|
manual_mode: If True, use MANUAL_LABEL_STRATEGY (minimal padding only)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Expanded bbox (x0, y0, x1, y1) as integers, clamped to image bounds
|
||||||
|
"""
|
||||||
|
x0, y0, x1, y1 = bbox
|
||||||
|
w = x1 - x0
|
||||||
|
h = y1 - y0
|
||||||
|
|
||||||
|
# Get strategy based on mode
|
||||||
|
if manual_mode:
|
||||||
|
strategy = MANUAL_LABEL_STRATEGY
|
||||||
|
elif strategies is None:
|
||||||
|
strategy = FIELD_SCALE_STRATEGIES.get(field_type, DEFAULT_STRATEGY)
|
||||||
|
else:
|
||||||
|
strategy = strategies.get(field_type, DEFAULT_STRATEGY)
|
||||||
|
|
||||||
|
# Step 1: Scale around center point
|
||||||
|
cx = (x0 + x1) / 2
|
||||||
|
cy = (y0 + y1) / 2
|
||||||
|
|
||||||
|
new_w = w * strategy.scale_x
|
||||||
|
new_h = h * strategy.scale_y
|
||||||
|
|
||||||
|
nx0 = cx - new_w / 2
|
||||||
|
nx1 = cx + new_w / 2
|
||||||
|
ny0 = cy - new_h / 2
|
||||||
|
ny1 = cy + new_h / 2
|
||||||
|
|
||||||
|
# Step 2: Apply directional compensation
|
||||||
|
nx0 -= w * strategy.extra_left_ratio
|
||||||
|
nx1 += w * strategy.extra_right_ratio
|
||||||
|
ny0 -= h * strategy.extra_top_ratio
|
||||||
|
ny1 += h * strategy.extra_bottom_ratio
|
||||||
|
|
||||||
|
# Step 3: Clamp expansion to max_pad limits (preserve asymmetry)
|
||||||
|
left_pad = min(x0 - nx0, strategy.max_pad_x)
|
||||||
|
right_pad = min(nx1 - x1, strategy.max_pad_x)
|
||||||
|
top_pad = min(y0 - ny0, strategy.max_pad_y)
|
||||||
|
bottom_pad = min(ny1 - y1, strategy.max_pad_y)
|
||||||
|
|
||||||
|
# Ensure pads are non-negative (in case of contraction)
|
||||||
|
left_pad = max(0, left_pad)
|
||||||
|
right_pad = max(0, right_pad)
|
||||||
|
top_pad = max(0, top_pad)
|
||||||
|
bottom_pad = max(0, bottom_pad)
|
||||||
|
|
||||||
|
nx0 = x0 - left_pad
|
||||||
|
nx1 = x1 + right_pad
|
||||||
|
ny0 = y0 - top_pad
|
||||||
|
ny1 = y1 + bottom_pad
|
||||||
|
|
||||||
|
# Step 4: Clamp to image boundaries
|
||||||
|
nx0 = max(0, int(nx0))
|
||||||
|
ny0 = max(0, int(ny0))
|
||||||
|
nx1 = min(int(image_width), int(nx1))
|
||||||
|
ny1 = min(int(image_height), int(ny1))
|
||||||
|
|
||||||
|
return (nx0, ny0, nx1, ny1)
|
||||||
140
packages/shared/shared/bbox/scale_strategy.py
Normal file
140
packages/shared/shared/bbox/scale_strategy.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
"""
|
||||||
|
Scale Strategy Configuration.
|
||||||
|
|
||||||
|
Defines field-specific bbox expansion strategies for YOLO training data.
|
||||||
|
Each strategy controls how bboxes are expanded around field values to
|
||||||
|
capture contextual information like labels.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Final
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ScaleStrategy:
|
||||||
|
"""Immutable scale strategy for bbox expansion.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
scale_x: Horizontal scale factor (1.0 = no scaling)
|
||||||
|
scale_y: Vertical scale factor (1.0 = no scaling)
|
||||||
|
extra_top_ratio: Additional expansion ratio towards top (for labels above)
|
||||||
|
extra_bottom_ratio: Additional expansion ratio towards bottom
|
||||||
|
extra_left_ratio: Additional expansion ratio towards left (for prefixes)
|
||||||
|
extra_right_ratio: Additional expansion ratio towards right (for suffixes)
|
||||||
|
max_pad_x: Maximum horizontal padding in pixels
|
||||||
|
max_pad_y: Maximum vertical padding in pixels
|
||||||
|
"""
|
||||||
|
|
||||||
|
scale_x: float = 1.15
|
||||||
|
scale_y: float = 1.15
|
||||||
|
extra_top_ratio: float = 0.0
|
||||||
|
extra_bottom_ratio: float = 0.0
|
||||||
|
extra_left_ratio: float = 0.0
|
||||||
|
extra_right_ratio: float = 0.0
|
||||||
|
max_pad_x: int = 50
|
||||||
|
max_pad_y: int = 50
|
||||||
|
|
||||||
|
|
||||||
|
# Default strategy for unknown fields (auto-label mode)
|
||||||
|
DEFAULT_STRATEGY: Final[ScaleStrategy] = ScaleStrategy()
|
||||||
|
|
||||||
|
# Manual label strategy - minimal padding to prevent edge clipping
|
||||||
|
# No scaling, no directional compensation, just small uniform padding
|
||||||
|
MANUAL_LABEL_STRATEGY: Final[ScaleStrategy] = ScaleStrategy(
|
||||||
|
scale_x=1.0,
|
||||||
|
scale_y=1.0,
|
||||||
|
extra_top_ratio=0.0,
|
||||||
|
extra_bottom_ratio=0.0,
|
||||||
|
extra_left_ratio=0.0,
|
||||||
|
extra_right_ratio=0.0,
|
||||||
|
max_pad_x=10, # Small padding to prevent edge loss
|
||||||
|
max_pad_y=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Field-specific strategies based on Swedish invoice field characteristics
|
||||||
|
# Field labels typically appear above or to the left of values
|
||||||
|
FIELD_SCALE_STRATEGIES: Final[dict[str, ScaleStrategy]] = {
|
||||||
|
# OCR number - label "OCR" or "Referens" typically above
|
||||||
|
"ocr_number": ScaleStrategy(
|
||||||
|
scale_x=1.15,
|
||||||
|
scale_y=1.80,
|
||||||
|
extra_top_ratio=0.60,
|
||||||
|
max_pad_x=50,
|
||||||
|
max_pad_y=140,
|
||||||
|
),
|
||||||
|
# Bankgiro - prefix "Bankgiro:" or "BG:" typically to the left
|
||||||
|
"bankgiro": ScaleStrategy(
|
||||||
|
scale_x=1.45,
|
||||||
|
scale_y=1.35,
|
||||||
|
extra_left_ratio=0.80,
|
||||||
|
max_pad_x=160,
|
||||||
|
max_pad_y=90,
|
||||||
|
),
|
||||||
|
# Plusgiro - prefix "Plusgiro:" or "PG:" typically to the left
|
||||||
|
"plusgiro": ScaleStrategy(
|
||||||
|
scale_x=1.45,
|
||||||
|
scale_y=1.35,
|
||||||
|
extra_left_ratio=0.80,
|
||||||
|
max_pad_x=160,
|
||||||
|
max_pad_y=90,
|
||||||
|
),
|
||||||
|
# Invoice date - label "Fakturadatum" typically above
|
||||||
|
"invoice_date": ScaleStrategy(
|
||||||
|
scale_x=1.25,
|
||||||
|
scale_y=1.55,
|
||||||
|
extra_top_ratio=0.40,
|
||||||
|
max_pad_x=80,
|
||||||
|
max_pad_y=110,
|
||||||
|
),
|
||||||
|
# Due date - label "Forfalldatum" typically above, sometimes left
|
||||||
|
"invoice_due_date": ScaleStrategy(
|
||||||
|
scale_x=1.30,
|
||||||
|
scale_y=1.65,
|
||||||
|
extra_top_ratio=0.45,
|
||||||
|
extra_left_ratio=0.35,
|
||||||
|
max_pad_x=100,
|
||||||
|
max_pad_y=120,
|
||||||
|
),
|
||||||
|
# Amount - currency symbol "SEK" or "kr" may be to the right
|
||||||
|
"amount": ScaleStrategy(
|
||||||
|
scale_x=1.20,
|
||||||
|
scale_y=1.35,
|
||||||
|
extra_right_ratio=0.30,
|
||||||
|
max_pad_x=70,
|
||||||
|
max_pad_y=80,
|
||||||
|
),
|
||||||
|
# Invoice number - label "Fakturanummer" typically above
|
||||||
|
"invoice_number": ScaleStrategy(
|
||||||
|
scale_x=1.20,
|
||||||
|
scale_y=1.50,
|
||||||
|
extra_top_ratio=0.40,
|
||||||
|
max_pad_x=80,
|
||||||
|
max_pad_y=100,
|
||||||
|
),
|
||||||
|
# Supplier org number - label "Org.nr" typically above or left
|
||||||
|
"supplier_org_number": ScaleStrategy(
|
||||||
|
scale_x=1.25,
|
||||||
|
scale_y=1.40,
|
||||||
|
extra_top_ratio=0.30,
|
||||||
|
extra_left_ratio=0.20,
|
||||||
|
max_pad_x=90,
|
||||||
|
max_pad_y=90,
|
||||||
|
),
|
||||||
|
# Customer number - label "Kundnummer" typically above or left
|
||||||
|
"customer_number": ScaleStrategy(
|
||||||
|
scale_x=1.25,
|
||||||
|
scale_y=1.45,
|
||||||
|
extra_top_ratio=0.35,
|
||||||
|
extra_left_ratio=0.25,
|
||||||
|
max_pad_x=90,
|
||||||
|
max_pad_y=100,
|
||||||
|
),
|
||||||
|
# Payment line - machine-readable code, minimal expansion needed
|
||||||
|
"payment_line": ScaleStrategy(
|
||||||
|
scale_x=1.10,
|
||||||
|
scale_y=1.20,
|
||||||
|
max_pad_x=40,
|
||||||
|
max_pad_y=40,
|
||||||
|
),
|
||||||
|
}
|
||||||
@@ -16,6 +16,7 @@ Available exports:
|
|||||||
- FIELD_CLASSES: dict[int, str] - class_id to class_name
|
- FIELD_CLASSES: dict[int, str] - class_id to class_name
|
||||||
- FIELD_CLASS_IDS: dict[str, int] - class_name to class_id
|
- FIELD_CLASS_IDS: dict[str, int] - class_name to class_id
|
||||||
- CLASS_TO_FIELD: dict[str, str] - class_name to field_name
|
- CLASS_TO_FIELD: dict[str, str] - class_name to field_name
|
||||||
|
- FIELD_TO_CLASS: dict[str, str] - field_name to class_name
|
||||||
- CSV_TO_CLASS_MAPPING: dict[str, int] - field_name to class_id (excludes derived)
|
- CSV_TO_CLASS_MAPPING: dict[str, int] - field_name to class_id (excludes derived)
|
||||||
- TRAINING_FIELD_CLASSES: dict[str, int] - field_name to class_id (all fields)
|
- TRAINING_FIELD_CLASSES: dict[str, int] - field_name to class_id (all fields)
|
||||||
- ACCOUNT_FIELD_MAPPING: Mapping for supplier_accounts handling
|
- ACCOUNT_FIELD_MAPPING: Mapping for supplier_accounts handling
|
||||||
@@ -27,6 +28,7 @@ from .mappings import (
|
|||||||
FIELD_CLASSES,
|
FIELD_CLASSES,
|
||||||
FIELD_CLASS_IDS,
|
FIELD_CLASS_IDS,
|
||||||
CLASS_TO_FIELD,
|
CLASS_TO_FIELD,
|
||||||
|
FIELD_TO_CLASS,
|
||||||
CSV_TO_CLASS_MAPPING,
|
CSV_TO_CLASS_MAPPING,
|
||||||
TRAINING_FIELD_CLASSES,
|
TRAINING_FIELD_CLASSES,
|
||||||
ACCOUNT_FIELD_MAPPING,
|
ACCOUNT_FIELD_MAPPING,
|
||||||
@@ -40,6 +42,7 @@ __all__ = [
|
|||||||
"FIELD_CLASSES",
|
"FIELD_CLASSES",
|
||||||
"FIELD_CLASS_IDS",
|
"FIELD_CLASS_IDS",
|
||||||
"CLASS_TO_FIELD",
|
"CLASS_TO_FIELD",
|
||||||
|
"FIELD_TO_CLASS",
|
||||||
"CSV_TO_CLASS_MAPPING",
|
"CSV_TO_CLASS_MAPPING",
|
||||||
"TRAINING_FIELD_CLASSES",
|
"TRAINING_FIELD_CLASSES",
|
||||||
"ACCOUNT_FIELD_MAPPING",
|
"ACCOUNT_FIELD_MAPPING",
|
||||||
|
|||||||
@@ -47,6 +47,12 @@ TRAINING_FIELD_CLASSES: Final[dict[str, int]] = {
|
|||||||
fd.field_name: fd.class_id for fd in FIELD_DEFINITIONS
|
fd.field_name: fd.class_id for fd in FIELD_DEFINITIONS
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# field_name -> class_name mapping (reverse of CLASS_TO_FIELD)
|
||||||
|
# Example: {"InvoiceNumber": "invoice_number", "OCR": "ocr_number", ...}
|
||||||
|
FIELD_TO_CLASS: Final[dict[str, str]] = {
|
||||||
|
fd.field_name: fd.class_name for fd in FIELD_DEFINITIONS
|
||||||
|
}
|
||||||
|
|
||||||
# Account field mapping for supplier_accounts special handling
|
# Account field mapping for supplier_accounts special handling
|
||||||
# BG:xxx -> Bankgiro, PG:xxx -> Plusgiro
|
# BG:xxx -> Bankgiro, PG:xxx -> Plusgiro
|
||||||
ACCOUNT_FIELD_MAPPING: Final[dict[str, dict[str, str]]] = {
|
ACCOUNT_FIELD_MAPPING: Final[dict[str, dict[str, str]]] = {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
YOLO Annotation Generator
|
YOLO Annotation Generator
|
||||||
|
|
||||||
Generates YOLO format annotations from matched fields.
|
Generates YOLO format annotations from matched fields.
|
||||||
|
Uses field-specific bbox expansion strategies for optimal training data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -14,7 +15,9 @@ from shared.fields import (
|
|||||||
TRAINING_FIELD_CLASSES as FIELD_CLASSES,
|
TRAINING_FIELD_CLASSES as FIELD_CLASSES,
|
||||||
CLASS_NAMES,
|
CLASS_NAMES,
|
||||||
ACCOUNT_FIELD_MAPPING,
|
ACCOUNT_FIELD_MAPPING,
|
||||||
|
FIELD_TO_CLASS,
|
||||||
)
|
)
|
||||||
|
from shared.bbox import expand_bbox
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -38,19 +41,16 @@ class AnnotationGenerator:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
min_confidence: float = 0.7,
|
min_confidence: float = 0.7,
|
||||||
bbox_padding_px: int = 20, # Absolute padding in pixels
|
min_bbox_height_px: int = 30, # Minimum bbox height
|
||||||
min_bbox_height_px: int = 30 # Minimum bbox height
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize annotation generator.
|
Initialize annotation generator.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
min_confidence: Minimum match score to include in training
|
min_confidence: Minimum match score to include in training
|
||||||
bbox_padding_px: Absolute padding in pixels to add around bboxes
|
|
||||||
min_bbox_height_px: Minimum bbox height in pixels
|
min_bbox_height_px: Minimum bbox height in pixels
|
||||||
"""
|
"""
|
||||||
self.min_confidence = min_confidence
|
self.min_confidence = min_confidence
|
||||||
self.bbox_padding_px = bbox_padding_px
|
|
||||||
self.min_bbox_height_px = min_bbox_height_px
|
self.min_bbox_height_px = min_bbox_height_px
|
||||||
|
|
||||||
def generate_from_matches(
|
def generate_from_matches(
|
||||||
@@ -63,6 +63,10 @@ class AnnotationGenerator:
|
|||||||
"""
|
"""
|
||||||
Generate YOLO annotations from field matches.
|
Generate YOLO annotations from field matches.
|
||||||
|
|
||||||
|
Uses field-specific bbox expansion strategies for optimal training data.
|
||||||
|
Each field type has customized scale factors and directional compensation
|
||||||
|
to capture field labels and context.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
matches: Dict of field_name -> list of Match objects
|
matches: Dict of field_name -> list of Match objects
|
||||||
image_width: Width of the rendered image in pixels
|
image_width: Width of the rendered image in pixels
|
||||||
@@ -82,6 +86,8 @@ class AnnotationGenerator:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
class_id = FIELD_CLASSES[field_name]
|
class_id = FIELD_CLASSES[field_name]
|
||||||
|
# Get class_name for bbox expansion strategy
|
||||||
|
class_name = FIELD_TO_CLASS.get(field_name, field_name)
|
||||||
|
|
||||||
# Take only the best match per field
|
# Take only the best match per field
|
||||||
if field_matches:
|
if field_matches:
|
||||||
@@ -94,19 +100,20 @@ class AnnotationGenerator:
|
|||||||
x0, y0, x1, y1 = best_match.bbox
|
x0, y0, x1, y1 = best_match.bbox
|
||||||
x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale
|
x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale
|
||||||
|
|
||||||
# Add absolute padding
|
# Apply field-specific bbox expansion strategy
|
||||||
pad = self.bbox_padding_px
|
x0, y0, x1, y1 = expand_bbox(
|
||||||
x0 = max(0, x0 - pad)
|
bbox=(x0, y0, x1, y1),
|
||||||
y0 = max(0, y0 - pad)
|
image_width=image_width,
|
||||||
x1 = min(image_width, x1 + pad)
|
image_height=image_height,
|
||||||
y1 = min(image_height, y1 + pad)
|
field_type=class_name,
|
||||||
|
)
|
||||||
|
|
||||||
# Ensure minimum height
|
# Ensure minimum height
|
||||||
current_height = y1 - y0
|
current_height = y1 - y0
|
||||||
if current_height < self.min_bbox_height_px:
|
if current_height < self.min_bbox_height_px:
|
||||||
extra = (self.min_bbox_height_px - current_height) / 2
|
extra = (self.min_bbox_height_px - current_height) / 2
|
||||||
y0 = max(0, y0 - extra)
|
y0 = max(0, int(y0 - extra))
|
||||||
y1 = min(image_height, y1 + extra)
|
y1 = min(int(image_height), int(y1 + extra))
|
||||||
|
|
||||||
# Convert to YOLO format (normalized center + size)
|
# Convert to YOLO format (normalized center + size)
|
||||||
x_center = (x0 + x1) / 2 / image_width
|
x_center = (x0 + x1) / 2 / image_width
|
||||||
@@ -143,6 +150,9 @@ class AnnotationGenerator:
|
|||||||
"""
|
"""
|
||||||
Add payment_line annotation from machine code parser result.
|
Add payment_line annotation from machine code parser result.
|
||||||
|
|
||||||
|
Uses "payment_line" scale strategy for minimal expansion
|
||||||
|
(machine-readable code needs less context).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
annotations: Existing list of annotations to append to
|
annotations: Existing list of annotations to append to
|
||||||
payment_line_bbox: Bounding box (x0, y0, x1, y1) in PDF coordinates
|
payment_line_bbox: Bounding box (x0, y0, x1, y1) in PDF coordinates
|
||||||
@@ -163,12 +173,13 @@ class AnnotationGenerator:
|
|||||||
x0, y0, x1, y1 = payment_line_bbox
|
x0, y0, x1, y1 = payment_line_bbox
|
||||||
x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale
|
x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale
|
||||||
|
|
||||||
# Add absolute padding
|
# Apply field-specific bbox expansion strategy for payment_line
|
||||||
pad = self.bbox_padding_px
|
x0, y0, x1, y1 = expand_bbox(
|
||||||
x0 = max(0, x0 - pad)
|
bbox=(x0, y0, x1, y1),
|
||||||
y0 = max(0, y0 - pad)
|
image_width=image_width,
|
||||||
x1 = min(image_width, x1 + pad)
|
image_height=image_height,
|
||||||
y1 = min(image_height, y1 + pad)
|
field_type="payment_line",
|
||||||
|
)
|
||||||
|
|
||||||
# Convert to YOLO format (normalized center + size)
|
# Convert to YOLO format (normalized center + size)
|
||||||
x_center = (x0 + x1) / 2 / image_width
|
x_center = (x0 + x1) / 2 / image_width
|
||||||
|
|||||||
1
tests/shared/bbox/__init__.py
Normal file
1
tests/shared/bbox/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Tests for shared.bbox module."""
|
||||||
556
tests/shared/bbox/test_expander.py
Normal file
556
tests/shared/bbox/test_expander.py
Normal file
@@ -0,0 +1,556 @@
|
|||||||
|
"""
|
||||||
|
Tests for expand_bbox function.
|
||||||
|
|
||||||
|
Tests verify that bbox expansion works correctly with center-point scaling,
|
||||||
|
directional compensation, max padding clamping, and image boundary handling.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from shared.bbox import (
|
||||||
|
expand_bbox,
|
||||||
|
ScaleStrategy,
|
||||||
|
FIELD_SCALE_STRATEGIES,
|
||||||
|
DEFAULT_STRATEGY,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestExpandBboxCenterScaling:
|
||||||
|
"""Tests for center-point based scaling."""
|
||||||
|
|
||||||
|
def test_center_scaling_expands_symmetrically(self):
|
||||||
|
"""Verify bbox expands symmetrically around center when no extra ratios."""
|
||||||
|
# 100x50 bbox at (100, 200)
|
||||||
|
bbox = (100, 200, 200, 250)
|
||||||
|
strategy = ScaleStrategy(
|
||||||
|
scale_x=1.2, # 20% wider
|
||||||
|
scale_y=1.4, # 40% taller
|
||||||
|
max_pad_x=1000, # Large to avoid clamping
|
||||||
|
max_pad_y=1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = expand_bbox(
|
||||||
|
bbox=bbox,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
field_type="test_field",
|
||||||
|
strategies={"test_field": strategy},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Original: width=100, height=50
|
||||||
|
# New: width=120, height=70
|
||||||
|
# Center: (150, 225)
|
||||||
|
# Expected: x0=150-60=90, x1=150+60=210, y0=225-35=190, y1=225+35=260
|
||||||
|
assert result[0] == 90 # x0
|
||||||
|
assert result[1] == 190 # y0
|
||||||
|
assert result[2] == 210 # x1
|
||||||
|
assert result[3] == 260 # y1
|
||||||
|
|
||||||
|
def test_no_scaling_returns_original(self):
|
||||||
|
"""Verify scale=1.0 with no extras returns original bbox."""
|
||||||
|
bbox = (100, 200, 200, 250)
|
||||||
|
strategy = ScaleStrategy(
|
||||||
|
scale_x=1.0,
|
||||||
|
scale_y=1.0,
|
||||||
|
max_pad_x=1000,
|
||||||
|
max_pad_y=1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = expand_bbox(
|
||||||
|
bbox=bbox,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
field_type="test_field",
|
||||||
|
strategies={"test_field": strategy},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == (100, 200, 200, 250)
|
||||||
|
|
||||||
|
|
||||||
|
class TestExpandBboxDirectionalCompensation:
|
||||||
|
"""Tests for directional compensation (extra ratios)."""
|
||||||
|
|
||||||
|
def test_extra_top_expands_upward(self):
|
||||||
|
"""Verify extra_top_ratio adds expansion toward top."""
|
||||||
|
bbox = (100, 200, 200, 250) # width=100, height=50
|
||||||
|
strategy = ScaleStrategy(
|
||||||
|
scale_x=1.0,
|
||||||
|
scale_y=1.0,
|
||||||
|
extra_top_ratio=0.5, # Add 50% of height to top
|
||||||
|
max_pad_x=1000,
|
||||||
|
max_pad_y=1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = expand_bbox(
|
||||||
|
bbox=bbox,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
field_type="test_field",
|
||||||
|
strategies={"test_field": strategy},
|
||||||
|
)
|
||||||
|
|
||||||
|
# extra_top = 50 * 0.5 = 25
|
||||||
|
assert result[0] == 100 # x0 unchanged
|
||||||
|
assert result[1] == 175 # y0 = 200 - 25
|
||||||
|
assert result[2] == 200 # x1 unchanged
|
||||||
|
assert result[3] == 250 # y1 unchanged
|
||||||
|
|
||||||
|
def test_extra_left_expands_leftward(self):
|
||||||
|
"""Verify extra_left_ratio adds expansion toward left."""
|
||||||
|
bbox = (100, 200, 200, 250) # width=100
|
||||||
|
strategy = ScaleStrategy(
|
||||||
|
scale_x=1.0,
|
||||||
|
scale_y=1.0,
|
||||||
|
extra_left_ratio=0.8, # Add 80% of width to left
|
||||||
|
max_pad_x=1000,
|
||||||
|
max_pad_y=1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = expand_bbox(
|
||||||
|
bbox=bbox,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
field_type="test_field",
|
||||||
|
strategies={"test_field": strategy},
|
||||||
|
)
|
||||||
|
|
||||||
|
# extra_left = 100 * 0.8 = 80
|
||||||
|
assert result[0] == 20 # x0 = 100 - 80
|
||||||
|
assert result[1] == 200 # y0 unchanged
|
||||||
|
assert result[2] == 200 # x1 unchanged
|
||||||
|
assert result[3] == 250 # y1 unchanged
|
||||||
|
|
||||||
|
def test_extra_right_expands_rightward(self):
|
||||||
|
"""Verify extra_right_ratio adds expansion toward right."""
|
||||||
|
bbox = (100, 200, 200, 250) # width=100
|
||||||
|
strategy = ScaleStrategy(
|
||||||
|
scale_x=1.0,
|
||||||
|
scale_y=1.0,
|
||||||
|
extra_right_ratio=0.3, # Add 30% of width to right
|
||||||
|
max_pad_x=1000,
|
||||||
|
max_pad_y=1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = expand_bbox(
|
||||||
|
bbox=bbox,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
field_type="test_field",
|
||||||
|
strategies={"test_field": strategy},
|
||||||
|
)
|
||||||
|
|
||||||
|
# extra_right = 100 * 0.3 = 30
|
||||||
|
assert result[0] == 100 # x0 unchanged
|
||||||
|
assert result[1] == 200 # y0 unchanged
|
||||||
|
assert result[2] == 230 # x1 = 200 + 30
|
||||||
|
assert result[3] == 250 # y1 unchanged
|
||||||
|
|
||||||
|
def test_extra_bottom_expands_downward(self):
|
||||||
|
"""Verify extra_bottom_ratio adds expansion toward bottom."""
|
||||||
|
bbox = (100, 200, 200, 250) # height=50
|
||||||
|
strategy = ScaleStrategy(
|
||||||
|
scale_x=1.0,
|
||||||
|
scale_y=1.0,
|
||||||
|
extra_bottom_ratio=0.4, # Add 40% of height to bottom
|
||||||
|
max_pad_x=1000,
|
||||||
|
max_pad_y=1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = expand_bbox(
|
||||||
|
bbox=bbox,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
field_type="test_field",
|
||||||
|
strategies={"test_field": strategy},
|
||||||
|
)
|
||||||
|
|
||||||
|
# extra_bottom = 50 * 0.4 = 20
|
||||||
|
assert result[0] == 100 # x0 unchanged
|
||||||
|
assert result[1] == 200 # y0 unchanged
|
||||||
|
assert result[2] == 200 # x1 unchanged
|
||||||
|
assert result[3] == 270 # y1 = 250 + 20
|
||||||
|
|
||||||
|
def test_combined_scaling_and_directional(self):
|
||||||
|
"""Verify scale + directional compensation work together."""
|
||||||
|
bbox = (100, 200, 200, 250) # width=100, height=50
|
||||||
|
strategy = ScaleStrategy(
|
||||||
|
scale_x=1.2, # 20% wider -> 120 width
|
||||||
|
scale_y=1.0, # no height change
|
||||||
|
extra_left_ratio=0.5, # Add 50% of width to left
|
||||||
|
max_pad_x=1000,
|
||||||
|
max_pad_y=1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = expand_bbox(
|
||||||
|
bbox=bbox,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
field_type="test_field",
|
||||||
|
strategies={"test_field": strategy},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Center: x=150
|
||||||
|
# After scale: width=120 -> x0=150-60=90, x1=150+60=210
|
||||||
|
# After extra_left: x0 = 90 - (100 * 0.5) = 40
|
||||||
|
assert result[0] == 40 # x0
|
||||||
|
assert result[2] == 210 # x1
|
||||||
|
|
||||||
|
|
||||||
|
class TestExpandBboxMaxPadClamping:
|
||||||
|
"""Tests for max padding clamping."""
|
||||||
|
|
||||||
|
def test_max_pad_x_limits_horizontal_expansion(self):
|
||||||
|
"""Verify max_pad_x limits expansion on left and right."""
|
||||||
|
bbox = (100, 200, 200, 250) # width=100
|
||||||
|
strategy = ScaleStrategy(
|
||||||
|
scale_x=2.0, # Double width (would add 50 each side)
|
||||||
|
scale_y=1.0,
|
||||||
|
max_pad_x=30, # Limit to 30 pixels each side
|
||||||
|
max_pad_y=1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = expand_bbox(
|
||||||
|
bbox=bbox,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
field_type="test_field",
|
||||||
|
strategies={"test_field": strategy},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Scale would make: x0=100, x1=200 -> x0=50, x1=250 (50px each side)
|
||||||
|
# But max_pad_x=30 limits to: x0=70, x1=230
|
||||||
|
assert result[0] == 70 # x0 = 100 - 30
|
||||||
|
assert result[2] == 230 # x1 = 200 + 30
|
||||||
|
|
||||||
|
def test_max_pad_y_limits_vertical_expansion(self):
|
||||||
|
"""Verify max_pad_y limits expansion on top and bottom."""
|
||||||
|
bbox = (100, 200, 200, 250) # height=50
|
||||||
|
strategy = ScaleStrategy(
|
||||||
|
scale_x=1.0,
|
||||||
|
scale_y=3.0, # Triple height (would add 50 each side)
|
||||||
|
max_pad_x=1000,
|
||||||
|
max_pad_y=20, # Limit to 20 pixels each side
|
||||||
|
)
|
||||||
|
|
||||||
|
result = expand_bbox(
|
||||||
|
bbox=bbox,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
field_type="test_field",
|
||||||
|
strategies={"test_field": strategy},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Scale would make: y0=175, y1=275 (50px each side)
|
||||||
|
# But max_pad_y=20 limits to: y0=180, y1=270
|
||||||
|
assert result[1] == 180 # y0 = 200 - 20
|
||||||
|
assert result[3] == 270 # y1 = 250 + 20
|
||||||
|
|
||||||
|
def test_max_pad_preserves_asymmetry(self):
|
||||||
|
"""Verify max_pad clamping preserves asymmetric expansion."""
|
||||||
|
bbox = (100, 200, 200, 250) # width=100
|
||||||
|
strategy = ScaleStrategy(
|
||||||
|
scale_x=1.0,
|
||||||
|
scale_y=1.0,
|
||||||
|
extra_left_ratio=1.0, # 100px left expansion
|
||||||
|
extra_right_ratio=0.0, # No right expansion
|
||||||
|
max_pad_x=50, # Limit to 50 pixels
|
||||||
|
max_pad_y=1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = expand_bbox(
|
||||||
|
bbox=bbox,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
field_type="test_field",
|
||||||
|
strategies={"test_field": strategy},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Left would expand 100, clamped to 50
|
||||||
|
# Right stays at 0
|
||||||
|
assert result[0] == 50 # x0 = 100 - 50
|
||||||
|
assert result[2] == 200 # x1 unchanged
|
||||||
|
|
||||||
|
|
||||||
|
class TestExpandBboxImageBoundaryClamping:
|
||||||
|
"""Tests for image boundary clamping."""
|
||||||
|
|
||||||
|
def test_clamps_to_left_boundary(self):
|
||||||
|
"""Verify x0 is clamped to 0."""
|
||||||
|
bbox = (10, 200, 110, 250) # Close to left edge
|
||||||
|
strategy = ScaleStrategy(
|
||||||
|
scale_x=1.0,
|
||||||
|
scale_y=1.0,
|
||||||
|
extra_left_ratio=0.5, # Would push x0 below 0
|
||||||
|
max_pad_x=1000,
|
||||||
|
max_pad_y=1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = expand_bbox(
|
||||||
|
bbox=bbox,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
field_type="test_field",
|
||||||
|
strategies={"test_field": strategy},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result[0] == 0 # Clamped to 0
|
||||||
|
|
||||||
|
def test_clamps_to_top_boundary(self):
|
||||||
|
"""Verify y0 is clamped to 0."""
|
||||||
|
bbox = (100, 10, 200, 60) # Close to top edge
|
||||||
|
strategy = ScaleStrategy(
|
||||||
|
scale_x=1.0,
|
||||||
|
scale_y=1.0,
|
||||||
|
extra_top_ratio=0.5, # Would push y0 below 0
|
||||||
|
max_pad_x=1000,
|
||||||
|
max_pad_y=1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = expand_bbox(
|
||||||
|
bbox=bbox,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
field_type="test_field",
|
||||||
|
strategies={"test_field": strategy},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result[1] == 0 # Clamped to 0
|
||||||
|
|
||||||
|
def test_clamps_to_right_boundary(self):
|
||||||
|
"""Verify x1 is clamped to image_width."""
|
||||||
|
bbox = (900, 200, 990, 250) # Close to right edge
|
||||||
|
strategy = ScaleStrategy(
|
||||||
|
scale_x=1.0,
|
||||||
|
scale_y=1.0,
|
||||||
|
extra_right_ratio=0.5, # Would push x1 beyond image_width
|
||||||
|
max_pad_x=1000,
|
||||||
|
max_pad_y=1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = expand_bbox(
|
||||||
|
bbox=bbox,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
field_type="test_field",
|
||||||
|
strategies={"test_field": strategy},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result[2] == 1000 # Clamped to image_width
|
||||||
|
|
||||||
|
def test_clamps_to_bottom_boundary(self):
|
||||||
|
"""Verify y1 is clamped to image_height."""
|
||||||
|
bbox = (100, 940, 200, 990) # Close to bottom edge
|
||||||
|
strategy = ScaleStrategy(
|
||||||
|
scale_x=1.0,
|
||||||
|
scale_y=1.0,
|
||||||
|
extra_bottom_ratio=0.5, # Would push y1 beyond image_height
|
||||||
|
max_pad_x=1000,
|
||||||
|
max_pad_y=1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = expand_bbox(
|
||||||
|
bbox=bbox,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
field_type="test_field",
|
||||||
|
strategies={"test_field": strategy},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result[3] == 1000 # Clamped to image_height
|
||||||
|
|
||||||
|
|
||||||
|
class TestExpandBboxUnknownField:
|
||||||
|
"""Tests for unknown field handling."""
|
||||||
|
|
||||||
|
def test_unknown_field_uses_default_strategy(self):
|
||||||
|
"""Verify unknown field types use DEFAULT_STRATEGY."""
|
||||||
|
bbox = (100, 200, 200, 250)
|
||||||
|
|
||||||
|
result = expand_bbox(
|
||||||
|
bbox=bbox,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
field_type="unknown_field_xyz",
|
||||||
|
)
|
||||||
|
|
||||||
|
# DEFAULT_STRATEGY: scale_x=1.15, scale_y=1.15
|
||||||
|
# Original: width=100, height=50
|
||||||
|
# New: width=115, height=57.5
|
||||||
|
# Center: (150, 225)
|
||||||
|
# x0 = 150 - 57.5 = 92.5 -> 92
|
||||||
|
# x1 = 150 + 57.5 = 207.5 -> 207
|
||||||
|
# y0 = 225 - 28.75 = 196.25 -> 196
|
||||||
|
# y1 = 225 + 28.75 = 253.75 -> 253
|
||||||
|
# But max_pad_x=50 may clamp...
|
||||||
|
# Left pad = 100 - 92.5 = 7.5 (< 50, ok)
|
||||||
|
# Right pad = 207.5 - 200 = 7.5 (< 50, ok)
|
||||||
|
assert result[0] == 92
|
||||||
|
assert result[2] == 207
|
||||||
|
|
||||||
|
|
||||||
|
class TestExpandBboxWithRealStrategies:
|
||||||
|
"""Tests using actual FIELD_SCALE_STRATEGIES."""
|
||||||
|
|
||||||
|
def test_ocr_number_expands_significantly_upward(self):
|
||||||
|
"""Verify ocr_number field gets significant upward expansion."""
|
||||||
|
bbox = (100, 200, 200, 230) # Small height=30
|
||||||
|
|
||||||
|
result = expand_bbox(
|
||||||
|
bbox=bbox,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
field_type="ocr_number",
|
||||||
|
)
|
||||||
|
|
||||||
|
# extra_top_ratio=0.60 -> 30 * 0.6 = 18 extra top
|
||||||
|
# y0 should decrease significantly
|
||||||
|
assert result[1] < 200 - 10 # At least 10px upward expansion
|
||||||
|
|
||||||
|
def test_bankgiro_expands_significantly_leftward(self):
|
||||||
|
"""Verify bankgiro field gets significant leftward expansion."""
|
||||||
|
bbox = (200, 200, 300, 230) # width=100
|
||||||
|
|
||||||
|
result = expand_bbox(
|
||||||
|
bbox=bbox,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
field_type="bankgiro",
|
||||||
|
)
|
||||||
|
|
||||||
|
# extra_left_ratio=0.80 -> 100 * 0.8 = 80 extra left
|
||||||
|
# x0 should decrease significantly
|
||||||
|
assert result[0] < 200 - 30 # At least 30px leftward expansion
|
||||||
|
|
||||||
|
def test_amount_expands_rightward(self):
|
||||||
|
"""Verify amount field gets rightward expansion for currency."""
|
||||||
|
bbox = (100, 200, 200, 230) # width=100
|
||||||
|
|
||||||
|
result = expand_bbox(
|
||||||
|
bbox=bbox,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
field_type="amount",
|
||||||
|
)
|
||||||
|
|
||||||
|
# extra_right_ratio=0.30 -> 100 * 0.3 = 30 extra right
|
||||||
|
# x1 should increase
|
||||||
|
assert result[2] > 200 + 10 # At least 10px rightward expansion
|
||||||
|
|
||||||
|
|
||||||
|
class TestExpandBboxReturnType:
|
||||||
|
"""Tests for return type and value format."""
|
||||||
|
|
||||||
|
def test_returns_tuple_of_four_ints(self):
|
||||||
|
"""Verify return type is tuple of 4 integers."""
|
||||||
|
bbox = (100.5, 200.3, 200.7, 250.9)
|
||||||
|
|
||||||
|
result = expand_bbox(
|
||||||
|
bbox=bbox,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
field_type="invoice_number",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, tuple)
|
||||||
|
assert len(result) == 4
|
||||||
|
assert all(isinstance(v, int) for v in result)
|
||||||
|
|
||||||
|
def test_returns_valid_bbox_format(self):
|
||||||
|
"""Verify returned bbox has x0 < x1 and y0 < y1."""
|
||||||
|
bbox = (100, 200, 200, 250)
|
||||||
|
|
||||||
|
result = expand_bbox(
|
||||||
|
bbox=bbox,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
field_type="invoice_number",
|
||||||
|
)
|
||||||
|
|
||||||
|
x0, y0, x1, y1 = result
|
||||||
|
assert x0 < x1, "x0 should be less than x1"
|
||||||
|
assert y0 < y1, "y0 should be less than y1"
|
||||||
|
|
||||||
|
|
||||||
|
class TestManualLabelMode:
|
||||||
|
"""Tests for manual_mode parameter."""
|
||||||
|
|
||||||
|
def test_manual_mode_uses_minimal_padding(self):
|
||||||
|
"""Verify manual_mode uses MANUAL_LABEL_STRATEGY with minimal padding."""
|
||||||
|
bbox = (100, 200, 200, 250) # width=100, height=50
|
||||||
|
|
||||||
|
result = expand_bbox(
|
||||||
|
bbox=bbox,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
field_type="bankgiro", # Would normally expand left significantly
|
||||||
|
manual_mode=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# MANUAL_LABEL_STRATEGY: scale=1.0, max_pad=10
|
||||||
|
# Should only add 10px padding each side (but scale=1.0 means no scaling)
|
||||||
|
# Actually with scale=1.0, no extra ratios, we get 0 expansion from scaling
|
||||||
|
# Only max_pad=10 applies as a limit, but there's no expansion to limit
|
||||||
|
# So result should be same as original
|
||||||
|
assert result == (100, 200, 200, 250)
|
||||||
|
|
||||||
|
def test_manual_mode_ignores_field_type(self):
|
||||||
|
"""Verify manual_mode ignores field-specific strategies."""
|
||||||
|
bbox = (100, 200, 200, 250)
|
||||||
|
|
||||||
|
# Different fields should give same result in manual_mode
|
||||||
|
result_bankgiro = expand_bbox(
|
||||||
|
bbox=bbox,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
field_type="bankgiro",
|
||||||
|
manual_mode=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
result_ocr = expand_bbox(
|
||||||
|
bbox=bbox,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
field_type="ocr_number",
|
||||||
|
manual_mode=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result_bankgiro == result_ocr
|
||||||
|
|
||||||
|
def test_manual_mode_vs_auto_mode_different(self):
|
||||||
|
"""Verify manual_mode produces different results than auto mode."""
|
||||||
|
bbox = (100, 200, 200, 250)
|
||||||
|
|
||||||
|
auto_result = expand_bbox(
|
||||||
|
bbox=bbox,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
field_type="bankgiro", # Has extra_left_ratio=0.80
|
||||||
|
manual_mode=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
manual_result = expand_bbox(
|
||||||
|
bbox=bbox,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
field_type="bankgiro",
|
||||||
|
manual_mode=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Auto mode should expand more (especially to the left for bankgiro)
|
||||||
|
assert auto_result[0] < manual_result[0] # Auto x0 is more left
|
||||||
|
|
||||||
|
def test_manual_mode_clamps_to_image_bounds(self):
|
||||||
|
"""Verify manual_mode still respects image boundaries."""
|
||||||
|
bbox = (5, 5, 50, 50) # Close to top-left corner
|
||||||
|
|
||||||
|
result = expand_bbox(
|
||||||
|
bbox=bbox,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
field_type="test",
|
||||||
|
manual_mode=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should clamp to 0
|
||||||
|
assert result[0] >= 0
|
||||||
|
assert result[1] >= 0
|
||||||
192
tests/shared/bbox/test_scale_strategy.py
Normal file
192
tests/shared/bbox/test_scale_strategy.py
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
"""
|
||||||
|
Tests for ScaleStrategy configuration.
|
||||||
|
|
||||||
|
Tests verify that scale strategies are properly defined, immutable,
|
||||||
|
and cover all required fields.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from shared.bbox import (
|
||||||
|
ScaleStrategy,
|
||||||
|
DEFAULT_STRATEGY,
|
||||||
|
MANUAL_LABEL_STRATEGY,
|
||||||
|
FIELD_SCALE_STRATEGIES,
|
||||||
|
)
|
||||||
|
from shared.fields import CLASS_NAMES
|
||||||
|
|
||||||
|
|
||||||
|
class TestScaleStrategyDataclass:
|
||||||
|
"""Tests for ScaleStrategy dataclass behavior."""
|
||||||
|
|
||||||
|
def test_default_strategy_values(self):
|
||||||
|
"""Verify default strategy has expected default values."""
|
||||||
|
strategy = ScaleStrategy()
|
||||||
|
assert strategy.scale_x == 1.15
|
||||||
|
assert strategy.scale_y == 1.15
|
||||||
|
assert strategy.extra_top_ratio == 0.0
|
||||||
|
assert strategy.extra_bottom_ratio == 0.0
|
||||||
|
assert strategy.extra_left_ratio == 0.0
|
||||||
|
assert strategy.extra_right_ratio == 0.0
|
||||||
|
assert strategy.max_pad_x == 50
|
||||||
|
assert strategy.max_pad_y == 50
|
||||||
|
|
||||||
|
def test_scale_strategy_immutability(self):
|
||||||
|
"""Verify ScaleStrategy is frozen (immutable)."""
|
||||||
|
strategy = ScaleStrategy()
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
strategy.scale_x = 2.0 # type: ignore
|
||||||
|
|
||||||
|
def test_custom_strategy_values(self):
|
||||||
|
"""Verify custom values are properly set."""
|
||||||
|
strategy = ScaleStrategy(
|
||||||
|
scale_x=1.5,
|
||||||
|
scale_y=1.8,
|
||||||
|
extra_top_ratio=0.6,
|
||||||
|
extra_left_ratio=0.8,
|
||||||
|
max_pad_x=100,
|
||||||
|
max_pad_y=150,
|
||||||
|
)
|
||||||
|
assert strategy.scale_x == 1.5
|
||||||
|
assert strategy.scale_y == 1.8
|
||||||
|
assert strategy.extra_top_ratio == 0.6
|
||||||
|
assert strategy.extra_left_ratio == 0.8
|
||||||
|
assert strategy.max_pad_x == 100
|
||||||
|
assert strategy.max_pad_y == 150
|
||||||
|
|
||||||
|
|
||||||
|
class TestDefaultStrategy:
|
||||||
|
"""Tests for DEFAULT_STRATEGY constant."""
|
||||||
|
|
||||||
|
def test_default_strategy_is_scale_strategy(self):
|
||||||
|
"""Verify DEFAULT_STRATEGY is a ScaleStrategy instance."""
|
||||||
|
assert isinstance(DEFAULT_STRATEGY, ScaleStrategy)
|
||||||
|
|
||||||
|
def test_default_strategy_matches_default_values(self):
|
||||||
|
"""Verify DEFAULT_STRATEGY has same values as ScaleStrategy()."""
|
||||||
|
expected = ScaleStrategy()
|
||||||
|
assert DEFAULT_STRATEGY == expected
|
||||||
|
|
||||||
|
|
||||||
|
class TestManualLabelStrategy:
|
||||||
|
"""Tests for MANUAL_LABEL_STRATEGY constant."""
|
||||||
|
|
||||||
|
def test_manual_label_strategy_is_scale_strategy(self):
|
||||||
|
"""Verify MANUAL_LABEL_STRATEGY is a ScaleStrategy instance."""
|
||||||
|
assert isinstance(MANUAL_LABEL_STRATEGY, ScaleStrategy)
|
||||||
|
|
||||||
|
def test_manual_label_strategy_has_no_scaling(self):
|
||||||
|
"""Verify MANUAL_LABEL_STRATEGY has scale factors of 1.0."""
|
||||||
|
assert MANUAL_LABEL_STRATEGY.scale_x == 1.0
|
||||||
|
assert MANUAL_LABEL_STRATEGY.scale_y == 1.0
|
||||||
|
|
||||||
|
def test_manual_label_strategy_has_no_directional_expansion(self):
|
||||||
|
"""Verify MANUAL_LABEL_STRATEGY has no directional expansion."""
|
||||||
|
assert MANUAL_LABEL_STRATEGY.extra_top_ratio == 0.0
|
||||||
|
assert MANUAL_LABEL_STRATEGY.extra_bottom_ratio == 0.0
|
||||||
|
assert MANUAL_LABEL_STRATEGY.extra_left_ratio == 0.0
|
||||||
|
assert MANUAL_LABEL_STRATEGY.extra_right_ratio == 0.0
|
||||||
|
|
||||||
|
def test_manual_label_strategy_has_small_max_pad(self):
|
||||||
|
"""Verify MANUAL_LABEL_STRATEGY has small max padding."""
|
||||||
|
assert MANUAL_LABEL_STRATEGY.max_pad_x <= 15
|
||||||
|
assert MANUAL_LABEL_STRATEGY.max_pad_y <= 15
|
||||||
|
|
||||||
|
|
||||||
|
class TestFieldScaleStrategies:
|
||||||
|
"""Tests for FIELD_SCALE_STRATEGIES dictionary."""
|
||||||
|
|
||||||
|
def test_all_class_names_have_strategies(self):
|
||||||
|
"""Verify all field class names have defined strategies."""
|
||||||
|
for class_name in CLASS_NAMES:
|
||||||
|
assert class_name in FIELD_SCALE_STRATEGIES, (
|
||||||
|
f"Missing strategy for field: {class_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_strategies_are_scale_strategy_instances(self):
|
||||||
|
"""Verify all strategies are ScaleStrategy instances."""
|
||||||
|
for field_name, strategy in FIELD_SCALE_STRATEGIES.items():
|
||||||
|
assert isinstance(strategy, ScaleStrategy), (
|
||||||
|
f"Strategy for {field_name} is not a ScaleStrategy"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_scale_values_are_greater_than_one(self):
|
||||||
|
"""Verify all scale values are >= 1.0 (expansion, not contraction)."""
|
||||||
|
for field_name, strategy in FIELD_SCALE_STRATEGIES.items():
|
||||||
|
assert strategy.scale_x >= 1.0, (
|
||||||
|
f"{field_name} scale_x should be >= 1.0"
|
||||||
|
)
|
||||||
|
assert strategy.scale_y >= 1.0, (
|
||||||
|
f"{field_name} scale_y should be >= 1.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_extra_ratios_are_non_negative(self):
|
||||||
|
"""Verify all extra ratios are >= 0."""
|
||||||
|
for field_name, strategy in FIELD_SCALE_STRATEGIES.items():
|
||||||
|
assert strategy.extra_top_ratio >= 0, (
|
||||||
|
f"{field_name} extra_top_ratio should be >= 0"
|
||||||
|
)
|
||||||
|
assert strategy.extra_bottom_ratio >= 0, (
|
||||||
|
f"{field_name} extra_bottom_ratio should be >= 0"
|
||||||
|
)
|
||||||
|
assert strategy.extra_left_ratio >= 0, (
|
||||||
|
f"{field_name} extra_left_ratio should be >= 0"
|
||||||
|
)
|
||||||
|
assert strategy.extra_right_ratio >= 0, (
|
||||||
|
f"{field_name} extra_right_ratio should be >= 0"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_max_pad_values_are_positive(self):
|
||||||
|
"""Verify all max_pad values are > 0."""
|
||||||
|
for field_name, strategy in FIELD_SCALE_STRATEGIES.items():
|
||||||
|
assert strategy.max_pad_x > 0, (
|
||||||
|
f"{field_name} max_pad_x should be > 0"
|
||||||
|
)
|
||||||
|
assert strategy.max_pad_y > 0, (
|
||||||
|
f"{field_name} max_pad_y should be > 0"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSpecificFieldStrategies:
|
||||||
|
"""Tests for specific field strategy configurations."""
|
||||||
|
|
||||||
|
def test_ocr_number_expands_upward(self):
|
||||||
|
"""Verify ocr_number strategy expands upward to capture label."""
|
||||||
|
strategy = FIELD_SCALE_STRATEGIES["ocr_number"]
|
||||||
|
assert strategy.extra_top_ratio > 0.0
|
||||||
|
assert strategy.extra_top_ratio >= 0.5 # Significant upward expansion
|
||||||
|
|
||||||
|
def test_bankgiro_expands_leftward(self):
|
||||||
|
"""Verify bankgiro strategy expands leftward to capture prefix."""
|
||||||
|
strategy = FIELD_SCALE_STRATEGIES["bankgiro"]
|
||||||
|
assert strategy.extra_left_ratio > 0.0
|
||||||
|
assert strategy.extra_left_ratio >= 0.5 # Significant leftward expansion
|
||||||
|
|
||||||
|
def test_plusgiro_expands_leftward(self):
|
||||||
|
"""Verify plusgiro strategy expands leftward to capture prefix."""
|
||||||
|
strategy = FIELD_SCALE_STRATEGIES["plusgiro"]
|
||||||
|
assert strategy.extra_left_ratio > 0.0
|
||||||
|
assert strategy.extra_left_ratio >= 0.5
|
||||||
|
|
||||||
|
def test_amount_expands_rightward(self):
|
||||||
|
"""Verify amount strategy expands rightward for currency symbol."""
|
||||||
|
strategy = FIELD_SCALE_STRATEGIES["amount"]
|
||||||
|
assert strategy.extra_right_ratio > 0.0
|
||||||
|
|
||||||
|
def test_invoice_date_expands_upward(self):
|
||||||
|
"""Verify invoice_date strategy expands upward to capture label."""
|
||||||
|
strategy = FIELD_SCALE_STRATEGIES["invoice_date"]
|
||||||
|
assert strategy.extra_top_ratio > 0.0
|
||||||
|
|
||||||
|
def test_invoice_due_date_expands_upward_and_leftward(self):
|
||||||
|
"""Verify invoice_due_date strategy expands both up and left."""
|
||||||
|
strategy = FIELD_SCALE_STRATEGIES["invoice_due_date"]
|
||||||
|
assert strategy.extra_top_ratio > 0.0
|
||||||
|
assert strategy.extra_left_ratio > 0.0
|
||||||
|
|
||||||
|
def test_payment_line_has_minimal_expansion(self):
|
||||||
|
"""Verify payment_line has conservative expansion (machine code)."""
|
||||||
|
strategy = FIELD_SCALE_STRATEGIES["payment_line"]
|
||||||
|
# Payment line is machine-readable, needs minimal expansion
|
||||||
|
assert strategy.scale_x <= 1.2
|
||||||
|
assert strategy.scale_y <= 1.3
|
||||||
@@ -16,6 +16,7 @@ from shared.fields import (
|
|||||||
FIELD_CLASSES,
|
FIELD_CLASSES,
|
||||||
FIELD_CLASS_IDS,
|
FIELD_CLASS_IDS,
|
||||||
CLASS_TO_FIELD,
|
CLASS_TO_FIELD,
|
||||||
|
FIELD_TO_CLASS,
|
||||||
CSV_TO_CLASS_MAPPING,
|
CSV_TO_CLASS_MAPPING,
|
||||||
TRAINING_FIELD_CLASSES,
|
TRAINING_FIELD_CLASSES,
|
||||||
NUM_CLASSES,
|
NUM_CLASSES,
|
||||||
@@ -133,6 +134,20 @@ class TestMappingConsistency:
|
|||||||
assert fd.field_name in TRAINING_FIELD_CLASSES
|
assert fd.field_name in TRAINING_FIELD_CLASSES
|
||||||
assert TRAINING_FIELD_CLASSES[fd.field_name] == fd.class_id
|
assert TRAINING_FIELD_CLASSES[fd.field_name] == fd.class_id
|
||||||
|
|
||||||
|
def test_field_to_class_is_inverse_of_class_to_field(self):
|
||||||
|
"""Verify FIELD_TO_CLASS and CLASS_TO_FIELD are proper inverses."""
|
||||||
|
for class_name, field_name in CLASS_TO_FIELD.items():
|
||||||
|
assert FIELD_TO_CLASS[field_name] == class_name
|
||||||
|
|
||||||
|
for field_name, class_name in FIELD_TO_CLASS.items():
|
||||||
|
assert CLASS_TO_FIELD[class_name] == field_name
|
||||||
|
|
||||||
|
def test_field_to_class_has_all_fields(self):
|
||||||
|
"""Verify FIELD_TO_CLASS has mapping for all field names."""
|
||||||
|
for fd in FIELD_DEFINITIONS:
|
||||||
|
assert fd.field_name in FIELD_TO_CLASS
|
||||||
|
assert FIELD_TO_CLASS[fd.field_name] == fd.class_name
|
||||||
|
|
||||||
|
|
||||||
class TestSpecificFieldDefinitions:
|
class TestSpecificFieldDefinitions:
|
||||||
"""Tests for specific field definitions to catch common mistakes."""
|
"""Tests for specific field definitions to catch common mistakes."""
|
||||||
|
|||||||
1
tests/training/__init__.py
Normal file
1
tests/training/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Tests for training package."""
|
||||||
1
tests/training/yolo/__init__.py
Normal file
1
tests/training/yolo/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Tests for training.yolo module."""
|
||||||
342
tests/training/yolo/test_annotation_generator.py
Normal file
342
tests/training/yolo/test_annotation_generator.py
Normal file
@@ -0,0 +1,342 @@
|
|||||||
|
"""
|
||||||
|
Tests for AnnotationGenerator with field-specific bbox expansion.
|
||||||
|
|
||||||
|
Tests verify that annotations are generated correctly using
|
||||||
|
field-specific scale strategies.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from training.yolo.annotation_generator import (
|
||||||
|
AnnotationGenerator,
|
||||||
|
YOLOAnnotation,
|
||||||
|
)
|
||||||
|
from shared.fields import TRAINING_FIELD_CLASSES, CLASS_NAMES
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockMatch:
|
||||||
|
"""Mock Match object for testing."""
|
||||||
|
bbox: tuple[float, float, float, float]
|
||||||
|
score: float
|
||||||
|
|
||||||
|
|
||||||
|
class TestYOLOAnnotation:
|
||||||
|
"""Tests for YOLOAnnotation dataclass."""
|
||||||
|
|
||||||
|
def test_to_string_format(self):
|
||||||
|
"""Verify YOLO format string output."""
|
||||||
|
ann = YOLOAnnotation(
|
||||||
|
class_id=0,
|
||||||
|
x_center=0.5,
|
||||||
|
y_center=0.5,
|
||||||
|
width=0.1,
|
||||||
|
height=0.05,
|
||||||
|
confidence=0.9
|
||||||
|
)
|
||||||
|
result = ann.to_string()
|
||||||
|
assert result == "0 0.500000 0.500000 0.100000 0.050000"
|
||||||
|
|
||||||
|
def test_default_confidence(self):
|
||||||
|
"""Verify default confidence is 1.0."""
|
||||||
|
ann = YOLOAnnotation(
|
||||||
|
class_id=0,
|
||||||
|
x_center=0.5,
|
||||||
|
y_center=0.5,
|
||||||
|
width=0.1,
|
||||||
|
height=0.05,
|
||||||
|
)
|
||||||
|
assert ann.confidence == 1.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnnotationGeneratorInit:
|
||||||
|
"""Tests for AnnotationGenerator initialization."""
|
||||||
|
|
||||||
|
def test_default_values(self):
|
||||||
|
"""Verify default initialization values."""
|
||||||
|
gen = AnnotationGenerator()
|
||||||
|
assert gen.min_confidence == 0.7
|
||||||
|
assert gen.min_bbox_height_px == 30
|
||||||
|
|
||||||
|
def test_custom_values(self):
|
||||||
|
"""Verify custom initialization values."""
|
||||||
|
gen = AnnotationGenerator(
|
||||||
|
min_confidence=0.8,
|
||||||
|
min_bbox_height_px=40,
|
||||||
|
)
|
||||||
|
assert gen.min_confidence == 0.8
|
||||||
|
assert gen.min_bbox_height_px == 40
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenerateFromMatches:
|
||||||
|
"""Tests for generate_from_matches method."""
|
||||||
|
|
||||||
|
def test_generates_annotation_for_valid_match(self):
|
||||||
|
"""Verify annotation is generated for valid match."""
|
||||||
|
gen = AnnotationGenerator(min_confidence=0.5)
|
||||||
|
|
||||||
|
# Mock match in PDF points (72 DPI)
|
||||||
|
# At 150 DPI, coords multiply by 150/72 = 2.083
|
||||||
|
matches = {
|
||||||
|
"InvoiceNumber": [MockMatch(bbox=(100, 200, 200, 230), score=0.8)]
|
||||||
|
}
|
||||||
|
|
||||||
|
annotations = gen.generate_from_matches(
|
||||||
|
matches=matches,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
dpi=150
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(annotations) == 1
|
||||||
|
ann = annotations[0]
|
||||||
|
assert ann.class_id == TRAINING_FIELD_CLASSES["InvoiceNumber"]
|
||||||
|
assert ann.confidence == 0.8
|
||||||
|
# Normalized values should be in 0-1 range
|
||||||
|
assert 0 <= ann.x_center <= 1
|
||||||
|
assert 0 <= ann.y_center <= 1
|
||||||
|
assert 0 < ann.width <= 1
|
||||||
|
assert 0 < ann.height <= 1
|
||||||
|
|
||||||
|
def test_skips_low_confidence_match(self):
|
||||||
|
"""Verify low confidence matches are skipped."""
|
||||||
|
gen = AnnotationGenerator(min_confidence=0.7)
|
||||||
|
|
||||||
|
matches = {
|
||||||
|
"InvoiceNumber": [MockMatch(bbox=(100, 200, 200, 230), score=0.5)]
|
||||||
|
}
|
||||||
|
|
||||||
|
annotations = gen.generate_from_matches(
|
||||||
|
matches=matches,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
dpi=150
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(annotations) == 0
|
||||||
|
|
||||||
|
def test_skips_unknown_field(self):
|
||||||
|
"""Verify unknown fields are skipped."""
|
||||||
|
gen = AnnotationGenerator(min_confidence=0.5)
|
||||||
|
|
||||||
|
matches = {
|
||||||
|
"UnknownField": [MockMatch(bbox=(100, 200, 200, 230), score=0.9)]
|
||||||
|
}
|
||||||
|
|
||||||
|
annotations = gen.generate_from_matches(
|
||||||
|
matches=matches,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
dpi=150
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(annotations) == 0
|
||||||
|
|
||||||
|
def test_takes_best_match_only(self):
|
||||||
|
"""Verify only the best match is used per field."""
|
||||||
|
gen = AnnotationGenerator(min_confidence=0.5)
|
||||||
|
|
||||||
|
matches = {
|
||||||
|
"InvoiceNumber": [
|
||||||
|
MockMatch(bbox=(100, 200, 200, 230), score=0.9), # Best
|
||||||
|
MockMatch(bbox=(300, 400, 400, 430), score=0.7),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
annotations = gen.generate_from_matches(
|
||||||
|
matches=matches,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
dpi=150
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(annotations) == 1
|
||||||
|
assert annotations[0].confidence == 0.9
|
||||||
|
|
||||||
|
def test_handles_empty_matches(self):
|
||||||
|
"""Verify empty matches list is handled."""
|
||||||
|
gen = AnnotationGenerator()
|
||||||
|
|
||||||
|
matches = {
|
||||||
|
"InvoiceNumber": []
|
||||||
|
}
|
||||||
|
|
||||||
|
annotations = gen.generate_from_matches(
|
||||||
|
matches=matches,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
dpi=150
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(annotations) == 0
|
||||||
|
|
||||||
|
def test_applies_field_specific_expansion(self):
|
||||||
|
"""Verify different fields get different expansion."""
|
||||||
|
gen = AnnotationGenerator(min_confidence=0.5)
|
||||||
|
|
||||||
|
# Same bbox, different fields
|
||||||
|
bbox = (100, 200, 200, 230)
|
||||||
|
|
||||||
|
matches_invoice_number = {
|
||||||
|
"InvoiceNumber": [MockMatch(bbox=bbox, score=0.9)]
|
||||||
|
}
|
||||||
|
matches_bankgiro = {
|
||||||
|
"Bankgiro": [MockMatch(bbox=bbox, score=0.9)]
|
||||||
|
}
|
||||||
|
|
||||||
|
ann_invoice = gen.generate_from_matches(
|
||||||
|
matches=matches_invoice_number,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
dpi=150
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
ann_bankgiro = gen.generate_from_matches(
|
||||||
|
matches=matches_bankgiro,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
dpi=150
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
# Bankgiro has extra_left_ratio=0.80, invoice_number has extra_top_ratio=0.40
|
||||||
|
# They should have different widths due to different expansion
|
||||||
|
# Bankgiro expands more to the left
|
||||||
|
assert ann_bankgiro.width != ann_invoice.width or ann_bankgiro.x_center != ann_invoice.x_center
|
||||||
|
|
||||||
|
def test_enforces_min_bbox_height(self):
|
||||||
|
"""Verify minimum bbox height is enforced."""
|
||||||
|
gen = AnnotationGenerator(min_confidence=0.5, min_bbox_height_px=50)
|
||||||
|
|
||||||
|
# Very small bbox
|
||||||
|
matches = {
|
||||||
|
"InvoiceNumber": [MockMatch(bbox=(100, 200, 200, 210), score=0.9)]
|
||||||
|
}
|
||||||
|
|
||||||
|
annotations = gen.generate_from_matches(
|
||||||
|
matches=matches,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
dpi=72 # 1:1 scale
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(annotations) == 1
|
||||||
|
# Height should be at least min_bbox_height_px / image_height
|
||||||
|
# After scale strategy expansion, height should be >= 50/1000 = 0.05
|
||||||
|
# Actually the min_bbox_height check happens AFTER expand_bbox
|
||||||
|
# So the final height should meet the minimum
|
||||||
|
|
||||||
|
|
||||||
|
class TestAddPaymentLineAnnotation:
|
||||||
|
"""Tests for add_payment_line_annotation method."""
|
||||||
|
|
||||||
|
def test_adds_payment_line_annotation(self):
|
||||||
|
"""Verify payment_line annotation is added."""
|
||||||
|
gen = AnnotationGenerator(min_confidence=0.5)
|
||||||
|
annotations = []
|
||||||
|
|
||||||
|
result = gen.add_payment_line_annotation(
|
||||||
|
annotations=annotations,
|
||||||
|
payment_line_bbox=(100, 200, 400, 230),
|
||||||
|
confidence=0.9,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
dpi=150
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
ann = result[0]
|
||||||
|
assert ann.class_id == TRAINING_FIELD_CLASSES["payment_line"]
|
||||||
|
assert ann.confidence == 0.9
|
||||||
|
|
||||||
|
def test_skips_none_bbox(self):
|
||||||
|
"""Verify None bbox is handled."""
|
||||||
|
gen = AnnotationGenerator(min_confidence=0.5)
|
||||||
|
annotations = []
|
||||||
|
|
||||||
|
result = gen.add_payment_line_annotation(
|
||||||
|
annotations=annotations,
|
||||||
|
payment_line_bbox=None,
|
||||||
|
confidence=0.9,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
dpi=150
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result) == 0
|
||||||
|
|
||||||
|
def test_skips_low_confidence(self):
|
||||||
|
"""Verify low confidence is skipped."""
|
||||||
|
gen = AnnotationGenerator(min_confidence=0.7)
|
||||||
|
annotations = []
|
||||||
|
|
||||||
|
result = gen.add_payment_line_annotation(
|
||||||
|
annotations=annotations,
|
||||||
|
payment_line_bbox=(100, 200, 400, 230),
|
||||||
|
confidence=0.5,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
dpi=150
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result) == 0
|
||||||
|
|
||||||
|
def test_appends_to_existing_annotations(self):
|
||||||
|
"""Verify payment_line is appended to existing list."""
|
||||||
|
gen = AnnotationGenerator(min_confidence=0.5)
|
||||||
|
existing = [YOLOAnnotation(0, 0.5, 0.5, 0.1, 0.1, 0.9)]
|
||||||
|
|
||||||
|
result = gen.add_payment_line_annotation(
|
||||||
|
annotations=existing,
|
||||||
|
payment_line_bbox=(100, 200, 400, 230),
|
||||||
|
confidence=0.9,
|
||||||
|
image_width=1000,
|
||||||
|
image_height=1000,
|
||||||
|
dpi=150
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result) == 2
|
||||||
|
assert result[0].class_id == 0 # Original
|
||||||
|
assert result[1].class_id == TRAINING_FIELD_CLASSES["payment_line"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestMultipleFieldsIntegration:
|
||||||
|
"""Integration tests for multiple fields."""
|
||||||
|
|
||||||
|
def test_generates_annotations_for_all_field_types(self):
|
||||||
|
"""Verify annotations can be generated for all field types."""
|
||||||
|
gen = AnnotationGenerator(min_confidence=0.5)
|
||||||
|
|
||||||
|
# Create matches for each field (except payment_line which is derived)
|
||||||
|
field_names = [
|
||||||
|
"InvoiceNumber",
|
||||||
|
"InvoiceDate",
|
||||||
|
"InvoiceDueDate",
|
||||||
|
"OCR",
|
||||||
|
"Bankgiro",
|
||||||
|
"Plusgiro",
|
||||||
|
"Amount",
|
||||||
|
"supplier_organisation_number",
|
||||||
|
"customer_number",
|
||||||
|
]
|
||||||
|
|
||||||
|
matches = {}
|
||||||
|
for i, field_name in enumerate(field_names):
|
||||||
|
# Stagger bboxes to avoid overlap
|
||||||
|
matches[field_name] = [
|
||||||
|
MockMatch(bbox=(100 + i * 50, 100 + i * 30, 200 + i * 50, 130 + i * 30), score=0.9)
|
||||||
|
]
|
||||||
|
|
||||||
|
annotations = gen.generate_from_matches(
|
||||||
|
matches=matches,
|
||||||
|
image_width=2000,
|
||||||
|
image_height=2000,
|
||||||
|
dpi=150
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(annotations) == len(field_names)
|
||||||
|
|
||||||
|
# Verify all class_ids are present
|
||||||
|
class_ids = {ann.class_id for ann in annotations}
|
||||||
|
expected_class_ids = {TRAINING_FIELD_CLASSES[fn] for fn in field_names}
|
||||||
|
assert class_ids == expected_class_ids
|
||||||
Reference in New Issue
Block a user