Files
invoice-master-poc-v2/packages/training/training/yolo/annotation_generator.py
Yaojia Wang 0990239e9c feat: add field-specific bbox expansion strategies for YOLO training
Implement center-point based bbox scaling with directional compensation
to capture field labels that typically appear above or to the left of
field values. This improves YOLO training data quality by including
contextual information around field values.

Key changes:
- Add shared.bbox module with ScaleStrategy dataclass and expand_bbox function
- Define field-specific strategies (ocr_number, bankgiro, invoice_date, etc.)
- Support manual_mode for minimal padding (no scaling)
- Integrate expand_bbox into AnnotationGenerator
- Add FIELD_TO_CLASS mapping for field_name to class_name lookup
- Comprehensive tests with 100% coverage (45 tests)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-04 22:56:52 +01:00

367 lines
12 KiB
Python

"""
YOLO Annotation Generator
Generates YOLO format annotations from matched fields.
Uses field-specific bbox expansion strategies for optimal training data.
"""
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import csv
# Import field mappings from single source of truth
from shared.fields import (
TRAINING_FIELD_CLASSES as FIELD_CLASSES,
CLASS_NAMES,
ACCOUNT_FIELD_MAPPING,
FIELD_TO_CLASS,
)
from shared.bbox import expand_bbox
@dataclass
class YOLOAnnotation:
    """A single YOLO-format label: a class id plus a normalized bounding box."""
    class_id: int
    x_center: float  # box center x, normalized to 0-1
    y_center: float  # box center y, normalized to 0-1
    width: float  # box width, normalized to 0-1
    height: float  # box height, normalized to 0-1
    confidence: float = 1.0  # match score; not part of the emitted label line
    def to_string(self) -> str:
        """Render the YOLO label line (confidence is intentionally omitted)."""
        coords = (self.x_center, self.y_center, self.width, self.height)
        return " ".join([str(self.class_id), *(f"{c:.6f}" for c in coords)])
class AnnotationGenerator:
"""Generates YOLO annotations from document matches."""
def __init__(
self,
min_confidence: float = 0.7,
min_bbox_height_px: int = 30, # Minimum bbox height
):
"""
Initialize annotation generator.
Args:
min_confidence: Minimum match score to include in training
min_bbox_height_px: Minimum bbox height in pixels
"""
self.min_confidence = min_confidence
self.min_bbox_height_px = min_bbox_height_px
def generate_from_matches(
self,
matches: dict[str, list[Any]], # field_name -> list of Match
image_width: float,
image_height: float,
dpi: int = 300
) -> list[YOLOAnnotation]:
"""
Generate YOLO annotations from field matches.
Uses field-specific bbox expansion strategies for optimal training data.
Each field type has customized scale factors and directional compensation
to capture field labels and context.
Args:
matches: Dict of field_name -> list of Match objects
image_width: Width of the rendered image in pixels
image_height: Height of the rendered image in pixels
dpi: DPI used for rendering (needed to convert PDF coords to pixels)
Returns:
List of YOLOAnnotation objects
"""
annotations = []
# Scale factor to convert PDF points (72 DPI) to rendered pixels
scale = dpi / 72.0
for field_name, field_matches in matches.items():
if field_name not in FIELD_CLASSES:
continue
class_id = FIELD_CLASSES[field_name]
# Get class_name for bbox expansion strategy
class_name = FIELD_TO_CLASS.get(field_name, field_name)
# Take only the best match per field
if field_matches:
best_match = field_matches[0] # Already sorted by score
if best_match.score < self.min_confidence:
continue
# best_match.bbox is in PDF points (72 DPI), convert to pixels
x0, y0, x1, y1 = best_match.bbox
x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale
# Apply field-specific bbox expansion strategy
x0, y0, x1, y1 = expand_bbox(
bbox=(x0, y0, x1, y1),
image_width=image_width,
image_height=image_height,
field_type=class_name,
)
# Ensure minimum height
current_height = y1 - y0
if current_height < self.min_bbox_height_px:
extra = (self.min_bbox_height_px - current_height) / 2
y0 = max(0, int(y0 - extra))
y1 = min(int(image_height), int(y1 + extra))
# Convert to YOLO format (normalized center + size)
x_center = (x0 + x1) / 2 / image_width
y_center = (y0 + y1) / 2 / image_height
width = (x1 - x0) / image_width
height = (y1 - y0) / image_height
# Clamp values to 0-1
x_center = max(0, min(1, x_center))
y_center = max(0, min(1, y_center))
width = max(0, min(1, width))
height = max(0, min(1, height))
annotations.append(YOLOAnnotation(
class_id=class_id,
x_center=x_center,
y_center=y_center,
width=width,
height=height,
confidence=best_match.score
))
return annotations
def add_payment_line_annotation(
self,
annotations: list[YOLOAnnotation],
payment_line_bbox: tuple[float, float, float, float],
confidence: float,
image_width: float,
image_height: float,
dpi: int = 300
) -> list[YOLOAnnotation]:
"""
Add payment_line annotation from machine code parser result.
Uses "payment_line" scale strategy for minimal expansion
(machine-readable code needs less context).
Args:
annotations: Existing list of annotations to append to
payment_line_bbox: Bounding box (x0, y0, x1, y1) in PDF coordinates
confidence: Confidence score from machine code parser
image_width: Width of the rendered image in pixels
image_height: Height of the rendered image in pixels
dpi: DPI used for rendering
Returns:
Updated annotations list with payment_line annotation added
"""
if not payment_line_bbox or confidence < self.min_confidence:
return annotations
# Scale factor to convert PDF points (72 DPI) to rendered pixels
scale = dpi / 72.0
x0, y0, x1, y1 = payment_line_bbox
x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale
# Apply field-specific bbox expansion strategy for payment_line
x0, y0, x1, y1 = expand_bbox(
bbox=(x0, y0, x1, y1),
image_width=image_width,
image_height=image_height,
field_type="payment_line",
)
# Convert to YOLO format (normalized center + size)
x_center = (x0 + x1) / 2 / image_width
y_center = (y0 + y1) / 2 / image_height
width = (x1 - x0) / image_width
height = (y1 - y0) / image_height
# Clamp values to 0-1
x_center = max(0, min(1, x_center))
y_center = max(0, min(1, y_center))
width = max(0, min(1, width))
height = max(0, min(1, height))
annotations.append(YOLOAnnotation(
class_id=FIELD_CLASSES['payment_line'],
x_center=x_center,
y_center=y_center,
width=width,
height=height,
confidence=confidence
))
return annotations
def save_annotations(
self,
annotations: list[YOLOAnnotation],
output_path: str | Path
) -> None:
"""Save annotations to a YOLO format text file."""
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w') as f:
for ann in annotations:
f.write(ann.to_string() + '\n')
@staticmethod
def generate_classes_file(output_path: str | Path) -> None:
"""Generate the classes.txt file for YOLO."""
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w') as f:
for class_name in CLASS_NAMES:
f.write(class_name + '\n')
@staticmethod
def generate_yaml_config(
output_path: str | Path,
train_path: str = 'train/images',
val_path: str = 'val/images',
test_path: str = 'test/images',
use_wsl_paths: bool | None = None
) -> None:
"""
Generate YOLO dataset YAML configuration.
Args:
output_path: Path to output YAML file
train_path: Relative path to training images
val_path: Relative path to validation images
test_path: Relative path to test images
use_wsl_paths: If True, convert Windows paths to WSL format.
If None, auto-detect based on environment.
"""
import os
import platform
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
dataset_dir = output_path.parent.absolute()
dataset_path_str = str(dataset_dir)
# Auto-detect WSL environment
if use_wsl_paths is None:
# Check if running inside WSL
is_wsl = 'microsoft' in platform.uname().release.lower() if platform.system() == 'Linux' else False
# Check WSL_DISTRO_NAME environment variable (set when running in WSL)
is_wsl = is_wsl or os.environ.get('WSL_DISTRO_NAME') is not None
use_wsl_paths = is_wsl
# Convert path format based on environment
if use_wsl_paths:
# Running in WSL: convert Windows paths to /mnt/c/... format
dataset_path_str = dataset_path_str.replace('\\', '/')
if len(dataset_path_str) > 1 and dataset_path_str[1] == ':':
drive = dataset_path_str[0].lower()
dataset_path_str = f"/mnt/{drive}{dataset_path_str[2:]}"
elif platform.system() == 'Windows':
# Running on native Windows: use forward slashes for YOLO compatibility
dataset_path_str = dataset_path_str.replace('\\', '/')
config = f"""# Invoice Field Detection Dataset
path: {dataset_path_str}
train: {train_path}
val: {val_path}
test: {test_path}
# Classes
names:
"""
for i, name in enumerate(CLASS_NAMES):
config += f" {i}: {name}\n"
with open(output_path, 'w') as f:
f.write(config)
def generate_annotations(
    pdf_path: str | Path,
    structured_data: dict[str, Any],
    output_dir: str | Path,
    dpi: int = 300
) -> list[Path]:
    """
    Generate YOLO annotations for a PDF using structured data.
    Renders each page to an image, extracts tokens (native text layer or OCR),
    matches the structured field values against the tokens, and writes one
    YOLO label file per page that produced at least one annotation.
    Args:
        pdf_path: Path to the PDF file
        structured_data: Dict with field values from CSV
        output_dir: Directory to save images and labels
        dpi: Resolution for rendering
    Returns:
        List of paths to generated annotation files
    """
    # Imported lazily to keep module import light and avoid cycles.
    from shared.pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens
    from shared.pdf.renderer import get_render_dimensions
    from shared.ocr import OCREngine
    from shared.matcher import FieldMatcher
    from shared.normalize import normalize_field
    output_dir = Path(output_dir)
    images_dir = output_dir / 'images'
    labels_dir = output_dir / 'labels'
    images_dir.mkdir(parents=True, exist_ok=True)
    labels_dir.mkdir(parents=True, exist_ok=True)
    generator = AnnotationGenerator()
    matcher = FieldMatcher()
    annotation_files: list[Path] = []
    # Scanned PDFs have no usable text layer, so fall back to OCR
    use_ocr = not is_text_pdf(pdf_path)
    # Initialize OCR engine only when needed (it is expensive to construct)
    ocr_engine = OCREngine() if use_ocr else None
    # Process each page
    for page_no, image_path in render_pdf_to_images(pdf_path, images_dir, dpi=dpi):
        # Get rendered image dimensions (pixels at the requested DPI)
        img_width, img_height = get_render_dimensions(pdf_path, page_no, dpi)
        # Extract tokens from the page
        if use_ocr:
            # NOTE: the unused `from PIL import Image` that used to live here
            # was removed; OCREngine takes the image path directly.
            tokens = ocr_engine.extract_from_image(str(image_path), page_no)
        else:
            tokens = list(extract_text_tokens(pdf_path, page_no))
        # Match each known field's value against the page tokens
        matches = {}
        for field_name in FIELD_CLASSES.keys():
            value = structured_data.get(field_name)
            # Skip fields with no value in the structured data
            if value is None or str(value).strip() == '':
                continue
            normalized = normalize_field(field_name, str(value))
            field_matches = matcher.find_matches(tokens, field_name, normalized, page_no)
            if field_matches:
                matches[field_name] = field_matches
        # Generate annotations (pass DPI for coordinate conversion)
        annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi)
        # Save annotations only for pages that matched something
        if annotations:
            label_path = labels_dir / f"{image_path.stem}.txt"
            generator.save_annotations(annotations, label_path)
            annotation_files.append(label_path)
    return annotation_files