Files
invoice-master-poc-v2/src/yolo/annotation_generator.py
Yaojia Wang 425b8fdedf WIP
2026-01-16 23:10:01 +01:00

321 lines
10 KiB
Python

"""
YOLO Annotation Generator
Generates YOLO format annotations from matched fields.
"""
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import csv
# YOLO class table: each entry pairs the structured-data field name with the
# label written to classes.txt / the dataset YAML. The entry's position in the
# tuple is its YOLO class id. FIELD_CLASSES and CLASS_NAMES are both derived
# from this single table so the two can never drift out of sync.
# Note: supplier_accounts is not a separate class - its matches are meant to be
# mapped to Bankgiro/Plusgiro (see ACCOUNT_FIELD_MAPPING).
_YOLO_CLASS_TABLE = (
    ('InvoiceNumber', 'invoice_number'),
    ('InvoiceDate', 'invoice_date'),
    ('InvoiceDueDate', 'invoice_due_date'),
    ('OCR', 'ocr_number'),
    ('Bankgiro', 'bankgiro'),
    ('Plusgiro', 'plusgiro'),
    ('Amount', 'amount'),
    ('supplier_organisation_number', 'supplier_org_number'),
)

# Structured-data field name -> YOLO class id.
FIELD_CLASSES = {
    field_name: class_id
    for class_id, (field_name, _) in enumerate(_YOLO_CLASS_TABLE)
}

# Fields that need matching but map to other YOLO classes:
# supplier_accounts matches are classified as Bankgiro or Plusgiro based on
# the account-type prefix.
# NOTE(review): nothing else in this module reads this mapping; the generator
# currently drops any field that is not a FIELD_CLASSES key - confirm whether
# another module applies it.
ACCOUNT_FIELD_MAPPING = {
    'supplier_accounts': {
        'BG': 'Bankgiro',  # BG:xxx -> Bankgiro class
        'PG': 'Plusgiro',  # PG:xxx -> Plusgiro class
    }
}

# YOLO class labels indexed by class id (classes.txt / dataset YAML order).
CLASS_NAMES = [class_name for _, class_name in _YOLO_CLASS_TABLE]
@dataclass
class YOLOAnnotation:
    """One labelled box in YOLO text format.

    The four geometry fields are fractions of the image size (expected in
    0-1). ``confidence`` records the match score for bookkeeping; it is not
    part of the serialized line.
    """

    class_id: int
    x_center: float  # normalized 0-1
    y_center: float  # normalized 0-1
    width: float  # normalized 0-1
    height: float  # normalized 0-1
    confidence: float = 1.0

    def to_string(self) -> str:
        """Serialize as ``"<class_id> <x_center> <y_center> <width> <height>"``.

        Geometry values use six decimal places; the confidence is omitted.
        """
        geometry = (self.x_center, self.y_center, self.width, self.height)
        parts = [f"{self.class_id}"]
        parts.extend(format(value, ".6f") for value in geometry)
        return " ".join(parts)
class AnnotationGenerator:
    """Generates YOLO annotations from document matches.

    Match bounding boxes arrive in PDF points (72 DPI). They are scaled to the
    rendered image's pixel grid, padded, height-adjusted and finally
    normalized into YOLO's ``class x_center y_center width height`` form.
    """

    def __init__(
        self,
        min_confidence: float = 0.7,
        bbox_padding_px: int = 20,  # Absolute padding in pixels
        min_bbox_height_px: int = 30  # Minimum bbox height
    ):
        """
        Initialize annotation generator.

        Args:
            min_confidence: Minimum match score to include in training
            bbox_padding_px: Absolute padding in pixels to add around bboxes
            min_bbox_height_px: Minimum bbox height in pixels (capped at the
                image height when the image is shorter than this)
        """
        self.min_confidence = min_confidence
        self.bbox_padding_px = bbox_padding_px
        self.min_bbox_height_px = min_bbox_height_px

    def generate_from_matches(
        self,
        matches: dict[str, list[Any]],  # field_name -> list of Match
        image_width: float,
        image_height: float,
        dpi: int = 300
    ) -> list[YOLOAnnotation]:
        """
        Generate YOLO annotations from field matches.

        At most one annotation is produced per field: the highest-scoring
        match, provided its score reaches ``min_confidence``. Fields that are
        not keys of ``FIELD_CLASSES`` are ignored, as are boxes that fall
        entirely outside the image.

        Args:
            matches: Dict of field_name -> list of Match objects. Each match
                must expose ``score`` and ``bbox`` as (x0, y0, x1, y1) in PDF
                points.
            image_width: Width of the rendered image in pixels
            image_height: Height of the rendered image in pixels
            dpi: DPI used for rendering (needed to convert PDF coords to pixels)

        Returns:
            List of YOLOAnnotation objects. The list is empty when the image
            has no area, because nothing can be normalized against it.
        """
        annotations: list[YOLOAnnotation] = []
        # Normalizing against a zero-sized image would divide by zero.
        if image_width <= 0 or image_height <= 0:
            return annotations

        # Scale factor to convert PDF points (72 DPI) to rendered pixels
        scale = dpi / 72.0
        pad = self.bbox_padding_px
        # A box can never be taller than the image itself.
        min_height = min(self.min_bbox_height_px, image_height)

        for field_name, field_matches in matches.items():
            class_id = FIELD_CLASSES.get(field_name)
            if class_id is None or not field_matches:
                continue

            # Select the best match explicitly rather than trusting the
            # caller's ordering. For a list already sorted by descending score
            # this is the first element, exactly as before.
            best_match = max(field_matches, key=lambda m: m.score)
            if best_match.score < self.min_confidence:
                continue

            # best_match.bbox is in PDF points (72 DPI), convert to pixels
            x0, y0, x1, y1 = (coord * scale for coord in best_match.bbox)

            # Add absolute padding, clipped to the image
            x0 = max(0.0, x0 - pad)
            y0 = max(0.0, y0 - pad)
            x1 = min(image_width, x1 + pad)
            y1 = min(image_height, y1 + pad)

            # The box lies completely outside the image: writing it would
            # produce a degenerate (zero-size) label.
            if x1 <= x0 or y1 < y0:
                continue

            # Ensure minimum height. Grow around the vertical center, then
            # slide the window back inside the image instead of truncating
            # it, so the minimum also holds for text near the top/bottom edge.
            if y1 - y0 < min_height:
                y_mid = (y0 + y1) / 2
                y0 = y_mid - min_height / 2
                y1 = y_mid + min_height / 2
                if y0 < 0:
                    y1 -= y0
                    y0 = 0.0
                if y1 > image_height:
                    y0 = max(0.0, y0 - (y1 - image_height))
                    y1 = image_height

            if y1 <= y0:
                continue

            # Convert to YOLO format (normalized center + size), clamped to
            # 0-1 as a final guard against floating-point overshoot.
            x_center, y_center, width, height = (
                max(0.0, min(1.0, value))
                for value in (
                    (x0 + x1) / 2 / image_width,
                    (y0 + y1) / 2 / image_height,
                    (x1 - x0) / image_width,
                    (y1 - y0) / image_height,
                )
            )

            annotations.append(YOLOAnnotation(
                class_id=class_id,
                x_center=x_center,
                y_center=y_center,
                width=width,
                height=height,
                confidence=best_match.score
            ))

        return annotations

    def save_annotations(
        self,
        annotations: list[YOLOAnnotation],
        output_path: str | Path
    ) -> None:
        """Save annotations to a YOLO format text file, one box per line.

        Parent directories are created as needed. An empty list produces an
        empty label file.
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            f.writelines(f"{ann.to_string()}\n" for ann in annotations)

    @staticmethod
    def generate_classes_file(output_path: str | Path) -> None:
        """Generate the classes.txt file for YOLO: one CLASS_NAMES entry per line, in class-id order."""
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            f.writelines(f"{class_name}\n" for class_name in CLASS_NAMES)

    @staticmethod
    def generate_yaml_config(
        output_path: str | Path,
        train_path: str = 'train/images',
        val_path: str = 'val/images',
        test_path: str = 'test/images',
        use_wsl_paths: bool | None = None
    ) -> None:
        """
        Generate YOLO dataset YAML configuration.

        The dataset root (``path:``) is the absolute directory containing the
        YAML file. The file is written as UTF-8 because that path may contain
        non-ASCII characters.

        Args:
            output_path: Path to output YAML file
            train_path: Relative path to training images
            val_path: Relative path to validation images
            test_path: Relative path to test images
            use_wsl_paths: If True, convert Windows paths to WSL format.
                If None, auto-detect based on environment.
        """
        import os
        import platform

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        dataset_path_str = str(output_path.parent.absolute())

        # Auto-detect WSL environment
        if use_wsl_paths is None:
            # Under WSL the Linux kernel release string contains "microsoft"
            is_wsl = (
                platform.system() == 'Linux'
                and 'microsoft' in platform.uname().release.lower()
            )
            # WSL_DISTRO_NAME is set when running in WSL
            use_wsl_paths = is_wsl or os.environ.get('WSL_DISTRO_NAME') is not None

        # Convert path format based on environment
        if use_wsl_paths:
            # Running in WSL: convert Windows paths to /mnt/c/... format
            dataset_path_str = dataset_path_str.replace('\\', '/')
            if len(dataset_path_str) > 1 and dataset_path_str[1] == ':':
                drive = dataset_path_str[0].lower()
                dataset_path_str = f"/mnt/{drive}{dataset_path_str[2:]}"
        elif platform.system() == 'Windows':
            # Running on native Windows: use forward slashes for YOLO compatibility
            dataset_path_str = dataset_path_str.replace('\\', '/')

        lines = [
            "# Invoice Field Detection Dataset",
            f"path: {dataset_path_str}",
            f"train: {train_path}",
            f"val: {val_path}",
            f"test: {test_path}",
            "# Classes",
            "names:",
        ]
        # One "id: name" entry per class, indented under "names:"
        lines.extend(f" {i}: {name}" for i, name in enumerate(CLASS_NAMES))
        config = "\n".join(lines) + "\n"

        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(config)
def generate_annotations(
    pdf_path: str | Path,
    structured_data: dict[str, Any],
    output_dir: str | Path,
    dpi: int = 300
) -> list[Path]:
    """
    Generate YOLO annotations for a PDF using structured data.

    Every page is rendered into ``output_dir/images``. A label file is written
    to ``output_dir/labels`` only for pages where at least one field produced
    an annotation.

    Args:
        pdf_path: Path to the PDF file
        structured_data: Dict with field values from CSV, keyed by the
            FIELD_CLASSES field names
        output_dir: Directory to save images and labels
        dpi: Resolution for rendering

    Returns:
        List of paths to generated annotation files
    """
    from ..pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens
    from ..pdf.renderer import get_render_dimensions
    from ..ocr import OCREngine
    from ..matcher import FieldMatcher
    from ..normalize import normalize_field

    output_dir = Path(output_dir)
    images_dir = output_dir / 'images'
    labels_dir = output_dir / 'labels'
    images_dir.mkdir(parents=True, exist_ok=True)
    labels_dir.mkdir(parents=True, exist_ok=True)

    generator = AnnotationGenerator()
    matcher = FieldMatcher()
    annotation_files = []

    # PDFs without a text layer fall back to OCR on the rendered page images.
    use_ocr = not is_text_pdf(pdf_path)
    ocr_engine = OCREngine() if use_ocr else None

    # The expected field values do not depend on the page, so normalize them
    # once up front. Missing or blank values are skipped.
    # NOTE(review): only FIELD_CLASSES keys are looked up here, so
    # 'supplier_accounts' values are never matched by this function.
    normalized_values = {}
    for field_name in FIELD_CLASSES:
        value = structured_data.get(field_name)
        if value is None or not str(value).strip():
            continue
        normalized_values[field_name] = normalize_field(field_name, str(value))

    # Process each page
    for page_no, image_path in render_pdf_to_images(pdf_path, images_dir, dpi=dpi):
        img_width, img_height = get_render_dimensions(pdf_path, page_no, dpi)

        # Extract tokens from OCR or from the embedded text layer
        if use_ocr:
            tokens = ocr_engine.extract_from_image(str(image_path), page_no)
        else:
            tokens = list(extract_text_tokens(pdf_path, page_no))

        # Match fields
        matches = {}
        for field_name, normalized in normalized_values.items():
            field_matches = matcher.find_matches(tokens, field_name, normalized, page_no)
            if field_matches:
                matches[field_name] = field_matches

        # Generate annotations (pass DPI for coordinate conversion)
        annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi)

        # Save annotations
        if annotations:
            label_path = labels_dir / f"{image_path.stem}.txt"
            generator.save_annotations(annotations, label_path)
            annotation_files.append(label_path)

    return annotation_files