Implement center-point based bbox scaling with directional compensation to
capture field labels that typically appear above or to the left of field
values. This improves YOLO training data quality by including contextual
information around field values.

Key changes:
- Add shared.bbox module with ScaleStrategy dataclass and expand_bbox function
- Define field-specific strategies (ocr_number, bankgiro, invoice_date, etc.)
- Support manual_mode for minimal padding (no scaling)
- Integrate expand_bbox into AnnotationGenerator
- Add FIELD_TO_CLASS mapping for field_name to class_name lookup
- Comprehensive tests with 100% coverage (45 tests)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
367 lines · 12 KiB · Python
"""
YOLO Annotation Generator

Generates YOLO format annotations from matched fields.
Uses field-specific bbox expansion strategies for optimal training data.
"""
|
|
|
|
import csv
from dataclasses import dataclass
from pathlib import Path
from typing import Any

# Import field mappings from single source of truth
from shared.fields import (
    TRAINING_FIELD_CLASSES as FIELD_CLASSES,
    CLASS_NAMES,
    ACCOUNT_FIELD_MAPPING,
    FIELD_TO_CLASS,
)
from shared.bbox import expand_bbox
|
|
|
|
|
|
@dataclass
class YOLOAnnotation:
    """A single detection label in YOLO text format.

    All geometry values are normalized to the 0-1 range relative to the
    image dimensions; confidence is carried alongside but is not emitted
    by to_string() (YOLO label files hold only class and geometry).
    """
    class_id: int
    x_center: float  # normalized 0-1
    y_center: float  # normalized 0-1
    width: float  # normalized 0-1
    height: float  # normalized 0-1
    confidence: float = 1.0

    def to_string(self) -> str:
        """Render as a YOLO label line: 'class x_center y_center width height'."""
        geometry = (self.x_center, self.y_center, self.width, self.height)
        coords = " ".join(f"{value:.6f}" for value in geometry)
        return f"{self.class_id} {coords}"
|
|
|
|
|
|
class AnnotationGenerator:
    """Generates YOLO annotations from document matches.

    Match bboxes arrive in PDF points (72 DPI); they are scaled to rendered
    pixels, expanded with the field-specific strategy from
    shared.bbox.expand_bbox, then normalized to YOLO center/size format.
    """

    def __init__(
        self,
        min_confidence: float = 0.7,
        min_bbox_height_px: int = 30,  # Minimum bbox height
    ):
        """
        Initialize annotation generator.

        Args:
            min_confidence: Minimum match score to include in training
            min_bbox_height_px: Minimum bbox height in pixels
        """
        self.min_confidence = min_confidence
        self.min_bbox_height_px = min_bbox_height_px

    @staticmethod
    def _to_normalized_yolo(
        x0: float,
        y0: float,
        x1: float,
        y1: float,
        image_width: float,
        image_height: float,
    ) -> tuple[float, float, float, float]:
        """Convert a pixel-space corner bbox to normalized YOLO center/size.

        Shared by generate_from_matches and add_payment_line_annotation so
        the normalization and clamping logic lives in exactly one place.

        Returns:
            (x_center, y_center, width, height), each clamped to [0, 1].
        """
        x_center = (x0 + x1) / 2 / image_width
        y_center = (y0 + y1) / 2 / image_height
        width = (x1 - x0) / image_width
        height = (y1 - y0) / image_height

        # Clamp to 0-1 so bbox expansion past the image edge never yields
        # invalid YOLO coordinates.
        return (
            max(0, min(1, x_center)),
            max(0, min(1, y_center)),
            max(0, min(1, width)),
            max(0, min(1, height)),
        )

    def generate_from_matches(
        self,
        matches: dict[str, list[Any]],  # field_name -> list of Match
        image_width: float,
        image_height: float,
        dpi: int = 300
    ) -> list[YOLOAnnotation]:
        """
        Generate YOLO annotations from field matches.

        Uses field-specific bbox expansion strategies for optimal training data.
        Each field type has customized scale factors and directional compensation
        to capture field labels and context.

        Args:
            matches: Dict of field_name -> list of Match objects
            image_width: Width of the rendered image in pixels
            image_height: Height of the rendered image in pixels
            dpi: DPI used for rendering (needed to convert PDF coords to pixels)

        Returns:
            List of YOLOAnnotation objects
        """
        annotations = []

        # Scale factor to convert PDF points (72 DPI) to rendered pixels
        scale = dpi / 72.0

        for field_name, field_matches in matches.items():
            # Only fields with a training class are annotated
            if field_name not in FIELD_CLASSES:
                continue

            class_id = FIELD_CLASSES[field_name]
            # Get class_name for bbox expansion strategy
            class_name = FIELD_TO_CLASS.get(field_name, field_name)

            # Take only the best match per field
            if not field_matches:
                continue
            best_match = field_matches[0]  # Already sorted by score

            if best_match.score < self.min_confidence:
                continue

            # best_match.bbox is in PDF points (72 DPI), convert to pixels
            x0, y0, x1, y1 = best_match.bbox
            x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale

            # Apply field-specific bbox expansion strategy
            x0, y0, x1, y1 = expand_bbox(
                bbox=(x0, y0, x1, y1),
                image_width=image_width,
                image_height=image_height,
                field_type=class_name,
            )

            # Ensure minimum height, growing symmetrically around the bbox
            # but never past the image bounds.
            current_height = y1 - y0
            if current_height < self.min_bbox_height_px:
                extra = (self.min_bbox_height_px - current_height) / 2
                y0 = max(0, int(y0 - extra))
                y1 = min(int(image_height), int(y1 + extra))

            # Convert to YOLO format (normalized center + size)
            x_center, y_center, width, height = self._to_normalized_yolo(
                x0, y0, x1, y1, image_width, image_height
            )

            annotations.append(YOLOAnnotation(
                class_id=class_id,
                x_center=x_center,
                y_center=y_center,
                width=width,
                height=height,
                confidence=best_match.score
            ))

        return annotations

    def add_payment_line_annotation(
        self,
        annotations: list[YOLOAnnotation],
        payment_line_bbox: tuple[float, float, float, float],
        confidence: float,
        image_width: float,
        image_height: float,
        dpi: int = 300
    ) -> list[YOLOAnnotation]:
        """
        Add payment_line annotation from machine code parser result.

        Uses "payment_line" scale strategy for minimal expansion
        (machine-readable code needs less context). Note: unlike
        generate_from_matches, no minimum-height enforcement is applied here.

        Args:
            annotations: Existing list of annotations to append to
            payment_line_bbox: Bounding box (x0, y0, x1, y1) in PDF coordinates
            confidence: Confidence score from machine code parser
            image_width: Width of the rendered image in pixels
            image_height: Height of the rendered image in pixels
            dpi: DPI used for rendering

        Returns:
            Updated annotations list with payment_line annotation added
        """
        if not payment_line_bbox or confidence < self.min_confidence:
            return annotations

        # Scale factor to convert PDF points (72 DPI) to rendered pixels
        scale = dpi / 72.0

        x0, y0, x1, y1 = payment_line_bbox
        x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale

        # Apply field-specific bbox expansion strategy for payment_line
        x0, y0, x1, y1 = expand_bbox(
            bbox=(x0, y0, x1, y1),
            image_width=image_width,
            image_height=image_height,
            field_type="payment_line",
        )

        # Convert to YOLO format (normalized center + size)
        x_center, y_center, width, height = self._to_normalized_yolo(
            x0, y0, x1, y1, image_width, image_height
        )

        annotations.append(YOLOAnnotation(
            class_id=FIELD_CLASSES['payment_line'],
            x_center=x_center,
            y_center=y_center,
            width=width,
            height=height,
            confidence=confidence
        ))

        return annotations

    def save_annotations(
        self,
        annotations: list[YOLOAnnotation],
        output_path: str | Path
    ) -> None:
        """Save annotations to a YOLO format text file (one label per line)."""
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, 'w') as f:
            for ann in annotations:
                f.write(ann.to_string() + '\n')

    @staticmethod
    def generate_classes_file(output_path: str | Path) -> None:
        """Generate the classes.txt file for YOLO (one class name per line)."""
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, 'w') as f:
            for class_name in CLASS_NAMES:
                f.write(class_name + '\n')

    @staticmethod
    def generate_yaml_config(
        output_path: str | Path,
        train_path: str = 'train/images',
        val_path: str = 'val/images',
        test_path: str = 'test/images',
        use_wsl_paths: bool | None = None
    ) -> None:
        """
        Generate YOLO dataset YAML configuration.

        Args:
            output_path: Path to output YAML file
            train_path: Relative path to training images
            val_path: Relative path to validation images
            test_path: Relative path to test images
            use_wsl_paths: If True, convert Windows paths to WSL format.
                          If None, auto-detect based on environment.
        """
        import os
        import platform

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        dataset_dir = output_path.parent.absolute()
        dataset_path_str = str(dataset_dir)

        # Auto-detect WSL environment
        if use_wsl_paths is None:
            # Check if running inside WSL (kernel release contains "microsoft")
            is_wsl = 'microsoft' in platform.uname().release.lower() if platform.system() == 'Linux' else False
            # Check WSL_DISTRO_NAME environment variable (set when running in WSL)
            is_wsl = is_wsl or os.environ.get('WSL_DISTRO_NAME') is not None
            use_wsl_paths = is_wsl

        # Convert path format based on environment
        if use_wsl_paths:
            # Running in WSL: convert Windows paths to /mnt/c/... format
            dataset_path_str = dataset_path_str.replace('\\', '/')
            if len(dataset_path_str) > 1 and dataset_path_str[1] == ':':
                drive = dataset_path_str[0].lower()
                dataset_path_str = f"/mnt/{drive}{dataset_path_str[2:]}"
        elif platform.system() == 'Windows':
            # Running on native Windows: use forward slashes for YOLO compatibility
            dataset_path_str = dataset_path_str.replace('\\', '/')

        config = f"""# Invoice Field Detection Dataset
path: {dataset_path_str}
train: {train_path}
val: {val_path}
test: {test_path}

# Classes
names:
"""
        for i, name in enumerate(CLASS_NAMES):
            config += f"  {i}: {name}\n"

        with open(output_path, 'w') as f:
            f.write(config)
|
|
|
|
|
def generate_annotations(
    pdf_path: str | Path,
    structured_data: dict[str, Any],
    output_dir: str | Path,
    dpi: int = 300
) -> list[Path]:
    """
    Generate YOLO annotations for a PDF using structured data.

    Renders each page to an image, extracts tokens (PDF text layer when
    available, OCR otherwise), matches field values from structured_data
    against the tokens, and writes one YOLO label file per page that
    produced at least one annotation.

    Args:
        pdf_path: Path to the PDF file
        structured_data: Dict with field values from CSV
        output_dir: Directory to save images and labels
        dpi: Resolution for rendering

    Returns:
        List of paths to generated annotation files
    """
    from shared.pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens
    from shared.pdf.renderer import get_render_dimensions
    from shared.ocr import OCREngine
    from shared.matcher import FieldMatcher
    from shared.normalize import normalize_field

    output_dir = Path(output_dir)
    images_dir = output_dir / 'images'
    labels_dir = output_dir / 'labels'
    images_dir.mkdir(parents=True, exist_ok=True)
    labels_dir.mkdir(parents=True, exist_ok=True)

    generator = AnnotationGenerator()
    matcher = FieldMatcher()
    annotation_files = []

    # Check PDF type: text-layer PDFs are tokenized directly, scanned PDFs
    # need OCR.
    use_ocr = not is_text_pdf(pdf_path)

    # Initialize OCR only if needed
    ocr_engine = OCREngine() if use_ocr else None

    # Process each page
    for page_no, image_path in render_pdf_to_images(pdf_path, images_dir, dpi=dpi):
        # Get image dimensions
        img_width, img_height = get_render_dimensions(pdf_path, page_no, dpi)

        # Extract tokens
        if use_ocr:
            tokens = ocr_engine.extract_from_image(str(image_path), page_no)
        else:
            tokens = list(extract_text_tokens(pdf_path, page_no))

        # Match fields; skip fields with no value in the structured data
        matches = {}
        for field_name in FIELD_CLASSES:
            value = structured_data.get(field_name)
            if value is None or str(value).strip() == '':
                continue

            normalized = normalize_field(field_name, str(value))
            field_matches = matcher.find_matches(tokens, field_name, normalized, page_no)
            if field_matches:
                matches[field_name] = field_matches

        # Generate annotations (pass DPI for coordinate conversion)
        annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi)

        # Save annotations only for pages that matched something
        if annotations:
            label_path = labels_dir / f"{image_path.stem}.txt"
            generator.save_annotations(annotations, label_path)
            annotation_files.append(label_path)

    return annotation_files
|