Initial commit: Invoice field extraction system using YOLO + OCR
Features: - Auto-labeling pipeline: CSV values -> PDF search -> YOLO annotations - Flexible date matching: year-month match, nearby date tolerance - PDF text extraction with PyMuPDF - OCR support for scanned documents (PaddleOCR) - YOLO training and inference pipeline - 7 field types: InvoiceNumber, InvoiceDate, InvoiceDueDate, OCR, Bankgiro, Plusgiro, Amount Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
281
src/yolo/annotation_generator.py
Normal file
281
src/yolo/annotation_generator.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""
|
||||
YOLO Annotation Generator
|
||||
|
||||
Generates YOLO format annotations from matched fields.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
import csv
|
||||
|
||||
|
||||
# Ordered (CSV field name, YOLO class name) pairs; list position is the class id.
# Keeping both names in one spec guarantees FIELD_CLASSES and CLASS_NAMES stay aligned.
_FIELD_SPECS = [
    ('InvoiceNumber', 'invoice_number'),
    ('InvoiceDate', 'invoice_date'),
    ('InvoiceDueDate', 'invoice_due_date'),
    ('OCR', 'ocr_number'),
    ('Bankgiro', 'bankgiro'),
    ('Plusgiro', 'plusgiro'),
    ('Amount', 'amount'),
]

# CSV field name -> YOLO class id.
FIELD_CLASSES = {csv_name: class_id for class_id, (csv_name, _) in enumerate(_FIELD_SPECS)}

# YOLO class names, indexed by class id.
CLASS_NAMES = [yolo_name for _, yolo_name in _FIELD_SPECS]
|
||||
|
||||
|
||||
@dataclass
class YOLOAnnotation:
    """One object label in YOLO format.

    All geometry is expressed as center/size normalized to [0, 1] relative
    to the image dimensions.
    """
    class_id: int
    x_center: float
    y_center: float
    width: float
    height: float
    # Match score carried along for filtering; not written to the label file.
    confidence: float = 1.0

    def to_string(self) -> str:
        """Serialize as a YOLO label line: '<class_id> <xc> <yc> <w> <h>'."""
        coords = (self.x_center, self.y_center, self.width, self.height)
        fields = [str(self.class_id)]
        fields.extend(f"{value:.6f}" for value in coords)
        return " ".join(fields)
|
||||
|
||||
|
||||
class AnnotationGenerator:
    """Generates YOLO annotations from document field matches.

    Converts matched field bounding boxes (in PDF points, 72 DPI) into
    normalized YOLO annotations for a rendered page image, and writes the
    label / classes / dataset-config files YOLO training expects.
    """

    def __init__(
        self,
        min_confidence: float = 0.7,
        bbox_padding_px: int = 20,  # Absolute padding in pixels
        min_bbox_height_px: int = 30  # Minimum bbox height
    ):
        """
        Initialize annotation generator.

        Args:
            min_confidence: Minimum match score to include in training
            bbox_padding_px: Absolute padding in pixels to add around bboxes
            min_bbox_height_px: Minimum bbox height in pixels
        """
        self.min_confidence = min_confidence
        self.bbox_padding_px = bbox_padding_px
        self.min_bbox_height_px = min_bbox_height_px

    def generate_from_matches(
        self,
        matches: dict[str, list[Any]],  # field_name -> list of Match
        image_width: float,
        image_height: float,
        dpi: int = 300
    ) -> list[YOLOAnnotation]:
        """
        Generate YOLO annotations from field matches.

        Each Match object is expected to expose `.score` (float) and `.bbox`
        (x0, y0, x1, y1 in PDF points) — TODO confirm against the matcher module.

        Args:
            matches: Dict of field_name -> list of Match objects
            image_width: Width of the rendered image in pixels
            image_height: Height of the rendered image in pixels
            dpi: DPI used for rendering (needed to convert PDF coords to pixels)

        Returns:
            List of YOLOAnnotation objects (at most one per field)
        """
        annotations = []

        # Scale factor to convert PDF points (72 DPI) to rendered pixels
        scale = dpi / 72.0

        for field_name, field_matches in matches.items():
            # Skip fields the model is not trained on.
            if field_name not in FIELD_CLASSES:
                continue

            class_id = FIELD_CLASSES[field_name]

            # Take only the best match per field
            if field_matches:
                best_match = field_matches[0]  # Already sorted by score

                # Low-scoring matches would add label noise to training.
                if best_match.score < self.min_confidence:
                    continue

                # best_match.bbox is in PDF points (72 DPI), convert to pixels
                x0, y0, x1, y1 = best_match.bbox
                x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale

                # Add absolute padding, clipped to the image bounds.
                pad = self.bbox_padding_px
                x0 = max(0, x0 - pad)
                y0 = max(0, y0 - pad)
                x1 = min(image_width, x1 + pad)
                y1 = min(image_height, y1 + pad)

                # Ensure minimum height by expanding symmetrically around the center.
                current_height = y1 - y0
                if current_height < self.min_bbox_height_px:
                    extra = (self.min_bbox_height_px - current_height) / 2
                    y0 = max(0, y0 - extra)
                    y1 = min(image_height, y1 + extra)

                # Convert to YOLO format (normalized center + size)
                x_center = (x0 + x1) / 2 / image_width
                y_center = (y0 + y1) / 2 / image_height
                width = (x1 - x0) / image_width
                height = (y1 - y0) / image_height

                # Clamp values to 0-1 (guards against bboxes extending past the page)
                x_center = max(0, min(1, x_center))
                y_center = max(0, min(1, y_center))
                width = max(0, min(1, width))
                height = max(0, min(1, height))

                annotations.append(YOLOAnnotation(
                    class_id=class_id,
                    x_center=x_center,
                    y_center=y_center,
                    width=width,
                    height=height,
                    confidence=best_match.score
                ))

        return annotations

    def save_annotations(
        self,
        annotations: list[YOLOAnnotation],
        output_path: str | Path
    ) -> None:
        """Save annotations to a YOLO format text file (one label per line)."""
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Explicit encoding keeps label files identical across platforms.
        with open(output_path, 'w', encoding='utf-8') as f:
            for ann in annotations:
                f.write(ann.to_string() + '\n')

    @staticmethod
    def generate_classes_file(output_path: str | Path) -> None:
        """Generate the classes.txt file for YOLO (one class name per line)."""
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, 'w', encoding='utf-8') as f:
            for class_name in CLASS_NAMES:
                f.write(class_name + '\n')

    @staticmethod
    def generate_yaml_config(
        output_path: str | Path,
        train_path: str = 'train/images',
        val_path: str = 'val/images',
        test_path: str = 'test/images'
    ) -> None:
        """Generate YOLO dataset YAML configuration.

        Args:
            output_path: Where to write the YAML file; the dataset root is
                taken to be its parent directory.
            train_path: Train images dir, relative to the dataset root
            val_path: Validation images dir, relative to the dataset root
            test_path: Test images dir, relative to the dataset root
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Use absolute path for WSL compatibility
        dataset_dir = output_path.parent.absolute()
        # Convert Windows path to WSL path if needed
        dataset_path_str = str(dataset_dir).replace('\\', '/')
        # Length guard: a POSIX path like '/' has no second character and
        # previously raised IndexError here.
        if len(dataset_path_str) > 1 and dataset_path_str[1] == ':':
            # Windows path like C:/... -> /mnt/c/...
            drive = dataset_path_str[0].lower()
            dataset_path_str = f"/mnt/{drive}{dataset_path_str[2:]}"

        config = f"""# Invoice Field Detection Dataset
path: {dataset_path_str}
train: {train_path}
val: {val_path}
test: {test_path}

# Classes
names:
"""
        for i, name in enumerate(CLASS_NAMES):
            config += f"  {i}: {name}\n"

        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(config)
|
||||
|
||||
|
||||
def generate_annotations(
    pdf_path: str | Path,
    structured_data: dict[str, Any],
    output_dir: str | Path,
    dpi: int = 300
) -> list[Path]:
    """
    Generate YOLO annotations for a PDF using structured data.

    Renders every page to an image, extracts text tokens (native PDF text
    when available, OCR for scanned documents), matches the CSV field
    values against the tokens, and writes one YOLO label file per page
    that produced at least one annotation.

    Args:
        pdf_path: Path to the PDF file
        structured_data: Dict with field values from CSV
        output_dir: Directory to save images and labels
        dpi: Resolution for rendering

    Returns:
        List of paths to generated annotation files
    """
    # Imported lazily so this module can be loaded without pulling in the
    # heavy PDF/OCR dependencies.
    from ..pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens
    from ..pdf.renderer import get_render_dimensions
    from ..ocr import OCREngine
    from ..matcher import FieldMatcher
    from ..normalize import normalize_field

    output_dir = Path(output_dir)
    images_dir = output_dir / 'images'
    labels_dir = output_dir / 'labels'
    images_dir.mkdir(parents=True, exist_ok=True)
    labels_dir.mkdir(parents=True, exist_ok=True)

    generator = AnnotationGenerator()
    matcher = FieldMatcher()
    annotation_files = []

    # Scanned PDFs carry no text layer, so fall back to OCR for them.
    use_ocr = not is_text_pdf(pdf_path)

    # Initialize OCR only when needed (engine startup is expensive).
    ocr_engine = OCREngine() if use_ocr else None

    # Process each page
    for page_no, image_path in render_pdf_to_images(pdf_path, images_dir, dpi=dpi):
        # Image dimensions are required to normalize YOLO coordinates.
        img_width, img_height = get_render_dimensions(pdf_path, page_no, dpi)

        # Extract tokens from the rendered image (OCR) or the PDF text layer.
        if use_ocr:
            tokens = ocr_engine.extract_from_image(str(image_path), page_no)
        else:
            tokens = list(extract_text_tokens(pdf_path, page_no))

        # Match each trainable field's CSV value against the page tokens.
        matches = {}
        for field_name in FIELD_CLASSES:
            value = structured_data.get(field_name)
            if value is None or str(value).strip() == '':
                continue

            normalized = normalize_field(field_name, str(value))
            field_matches = matcher.find_matches(tokens, field_name, normalized, page_no)
            if field_matches:
                matches[field_name] = field_matches

        # Generate annotations (pass DPI for PDF-point -> pixel conversion)
        annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi)

        # Only write a label file when the page produced annotations.
        if annotations:
            label_path = labels_dir / f"{image_path.stem}.txt"
            generator.save_annotations(annotations, label_path)
            annotation_files.append(label_path)

    return annotation_files
|
||||
Reference in New Issue
Block a user