"""
YOLO Annotation Generator

Generates YOLO format annotations from matched fields.
"""
|
|
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Any
|
|
import csv
|
|
|
|
# Import field mappings from single source of truth
|
|
from shared.fields import (
|
|
TRAINING_FIELD_CLASSES as FIELD_CLASSES,
|
|
CLASS_NAMES,
|
|
ACCOUNT_FIELD_MAPPING,
|
|
)
|
|
|
|
|
|
@dataclass
class YOLOAnnotation:
    """A single detection label in YOLO text format.

    All geometry values are normalized to the 0-1 range relative to the
    rendered image dimensions.
    """

    class_id: int            # index into the dataset's class list
    x_center: float          # normalized horizontal center of the box
    y_center: float          # normalized vertical center of the box
    width: float             # normalized box width
    height: float            # normalized box height
    confidence: float = 1.0  # match score carried along for filtering/debugging

    def to_string(self) -> str:
        """Render as one YOLO label line: "<class> <cx> <cy> <w> <h>"."""
        geometry = (self.x_center, self.y_center, self.width, self.height)
        parts = [str(self.class_id)]
        parts.extend(f"{value:.6f}" for value in geometry)
        return " ".join(parts)
|
|
|
|
|
|
class AnnotationGenerator:
    """Generates YOLO annotations from document matches."""

    def __init__(
        self,
        min_confidence: float = 0.7,
        bbox_padding_px: int = 20,  # Absolute padding in pixels
        min_bbox_height_px: int = 30  # Minimum bbox height
    ):
        """
        Initialize annotation generator.

        Args:
            min_confidence: Minimum match score to include in training
            bbox_padding_px: Absolute padding in pixels to add around bboxes
            min_bbox_height_px: Minimum bbox height in pixels
        """
        self.min_confidence = min_confidence
        self.bbox_padding_px = bbox_padding_px
        self.min_bbox_height_px = min_bbox_height_px

    def _pixels_to_yolo(
        self,
        bbox: tuple[float, float, float, float],
        image_width: float,
        image_height: float,
        enforce_min_height: bool = False
    ) -> tuple[float, float, float, float]:
        """
        Pad, clamp and normalize a pixel-space bbox into YOLO geometry.

        Shared by ``generate_from_matches`` and
        ``add_payment_line_annotation`` so both apply identical padding
        and clamping rules.

        Args:
            bbox: (x0, y0, x1, y1) in rendered-image pixels
            image_width: Width of the rendered image in pixels
            image_height: Height of the rendered image in pixels
            enforce_min_height: If True, grow the box symmetrically so it
                is at least ``min_bbox_height_px`` tall before normalizing

        Returns:
            (x_center, y_center, width, height), each clamped to 0-1
        """
        x0, y0, x1, y1 = bbox

        # Add absolute padding, staying inside the image bounds
        pad = self.bbox_padding_px
        x0 = max(0, x0 - pad)
        y0 = max(0, y0 - pad)
        x1 = min(image_width, x1 + pad)
        y1 = min(image_height, y1 + pad)

        # Optionally grow vertically to the configured minimum height
        if enforce_min_height:
            current_height = y1 - y0
            if current_height < self.min_bbox_height_px:
                extra = (self.min_bbox_height_px - current_height) / 2
                y0 = max(0, y0 - extra)
                y1 = min(image_height, y1 + extra)

        # Convert to YOLO format (normalized center + size), clamped 0-1
        x_center = max(0, min(1, (x0 + x1) / 2 / image_width))
        y_center = max(0, min(1, (y0 + y1) / 2 / image_height))
        width = max(0, min(1, (x1 - x0) / image_width))
        height = max(0, min(1, (y1 - y0) / image_height))
        return x_center, y_center, width, height

    def generate_from_matches(
        self,
        matches: dict[str, list[Any]],  # field_name -> list of Match
        image_width: float,
        image_height: float,
        dpi: int = 300
    ) -> list[YOLOAnnotation]:
        """
        Generate YOLO annotations from field matches.

        Args:
            matches: Dict of field_name -> list of Match objects
            image_width: Width of the rendered image in pixels
            image_height: Height of the rendered image in pixels
            dpi: DPI used for rendering (needed to convert PDF coords to pixels)

        Returns:
            List of YOLOAnnotation objects
        """
        annotations = []

        # Scale factor to convert PDF points (72 DPI) to rendered pixels
        scale = dpi / 72.0

        for field_name, field_matches in matches.items():
            # Skip fields that have no training class or no match at all
            if field_name not in FIELD_CLASSES or not field_matches:
                continue

            # Take only the best match per field (list is sorted by score)
            best_match = field_matches[0]
            if best_match.score < self.min_confidence:
                continue

            # best_match.bbox is in PDF points (72 DPI), convert to pixels
            x0, y0, x1, y1 = best_match.bbox
            bbox_px = (x0 * scale, y0 * scale, x1 * scale, y1 * scale)

            x_center, y_center, width, height = self._pixels_to_yolo(
                bbox_px, image_width, image_height, enforce_min_height=True
            )

            annotations.append(YOLOAnnotation(
                class_id=FIELD_CLASSES[field_name],
                x_center=x_center,
                y_center=y_center,
                width=width,
                height=height,
                confidence=best_match.score
            ))

        return annotations

    def add_payment_line_annotation(
        self,
        annotations: list[YOLOAnnotation],
        payment_line_bbox: tuple[float, float, float, float],
        confidence: float,
        image_width: float,
        image_height: float,
        dpi: int = 300
    ) -> list[YOLOAnnotation]:
        """
        Add payment_line annotation from machine code parser result.

        Args:
            annotations: Existing list of annotations to append to
            payment_line_bbox: Bounding box (x0, y0, x1, y1) in PDF coordinates
            confidence: Confidence score from machine code parser
            image_width: Width of the rendered image in pixels
            image_height: Height of the rendered image in pixels
            dpi: DPI used for rendering

        Returns:
            Updated annotations list with payment_line annotation added
        """
        if not payment_line_bbox or confidence < self.min_confidence:
            return annotations

        # Scale factor to convert PDF points (72 DPI) to rendered pixels
        scale = dpi / 72.0

        x0, y0, x1, y1 = payment_line_bbox
        bbox_px = (x0 * scale, y0 * scale, x1 * scale, y1 * scale)

        # No minimum-height enforcement here: the machine-readable payment
        # line is naturally short and wide.
        x_center, y_center, width, height = self._pixels_to_yolo(
            bbox_px, image_width, image_height, enforce_min_height=False
        )

        annotations.append(YOLOAnnotation(
            class_id=FIELD_CLASSES['payment_line'],
            x_center=x_center,
            y_center=y_center,
            width=width,
            height=height,
            confidence=confidence
        ))

        return annotations

    def save_annotations(
        self,
        annotations: list[YOLOAnnotation],
        output_path: str | Path
    ) -> None:
        """Save annotations to a YOLO format text file."""
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, 'w') as f:
            for ann in annotations:
                f.write(ann.to_string() + '\n')

    @staticmethod
    def generate_classes_file(output_path: str | Path) -> None:
        """Generate the classes.txt file for YOLO (one class name per line)."""
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, 'w') as f:
            for class_name in CLASS_NAMES:
                f.write(class_name + '\n')

    @staticmethod
    def generate_yaml_config(
        output_path: str | Path,
        train_path: str = 'train/images',
        val_path: str = 'val/images',
        test_path: str = 'test/images',
        use_wsl_paths: bool | None = None
    ) -> None:
        """
        Generate YOLO dataset YAML configuration.

        Args:
            output_path: Path to output YAML file
            train_path: Relative path to training images
            val_path: Relative path to validation images
            test_path: Relative path to test images
            use_wsl_paths: If True, convert Windows paths to WSL format.
                If None, auto-detect based on environment.
        """
        import os
        import platform

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # The dataset root is the directory containing the YAML file
        dataset_dir = output_path.parent.absolute()
        dataset_path_str = str(dataset_dir)

        # Auto-detect WSL environment
        if use_wsl_paths is None:
            # Check if running inside WSL (kernel release contains "microsoft")
            is_wsl = 'microsoft' in platform.uname().release.lower() if platform.system() == 'Linux' else False
            # Check WSL_DISTRO_NAME environment variable (set when running in WSL)
            is_wsl = is_wsl or os.environ.get('WSL_DISTRO_NAME') is not None
            use_wsl_paths = is_wsl

        # Convert path format based on environment
        if use_wsl_paths:
            # Running in WSL: convert Windows paths to /mnt/c/... format
            dataset_path_str = dataset_path_str.replace('\\', '/')
            if len(dataset_path_str) > 1 and dataset_path_str[1] == ':':
                drive = dataset_path_str[0].lower()
                dataset_path_str = f"/mnt/{drive}{dataset_path_str[2:]}"
        elif platform.system() == 'Windows':
            # Running on native Windows: use forward slashes for YOLO compatibility
            dataset_path_str = dataset_path_str.replace('\\', '/')

        config = f"""# Invoice Field Detection Dataset
path: {dataset_path_str}
train: {train_path}
val: {val_path}
test: {test_path}

# Classes
names:
"""
        for i, name in enumerate(CLASS_NAMES):
            config += f"  {i}: {name}\n"

        with open(output_path, 'w') as f:
            f.write(config)
|
|
|
|
|
|
def generate_annotations(
    pdf_path: str | Path,
    structured_data: dict[str, Any],
    output_dir: str | Path,
    dpi: int = 300
) -> list[Path]:
    """
    Generate YOLO annotations for a PDF using structured data.

    Renders each page to an image, extracts text tokens (native text layer
    when present, OCR otherwise), matches the known field values against
    those tokens, and writes one YOLO label file per page that produced
    at least one annotation.

    Args:
        pdf_path: Path to the PDF file
        structured_data: Dict with field values from CSV
        output_dir: Directory to save images and labels
        dpi: Resolution for rendering

    Returns:
        List of paths to generated annotation files
    """
    from shared.pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens
    from shared.pdf.renderer import get_render_dimensions
    from shared.ocr import OCREngine
    from shared.matcher import FieldMatcher
    from shared.normalize import normalize_field

    output_dir = Path(output_dir)
    images_dir = output_dir / 'images'
    labels_dir = output_dir / 'labels'
    images_dir.mkdir(parents=True, exist_ok=True)
    labels_dir.mkdir(parents=True, exist_ok=True)

    generator = AnnotationGenerator()
    matcher = FieldMatcher()
    annotation_files = []

    # Scanned PDFs have no text layer and need OCR
    use_ocr = not is_text_pdf(pdf_path)

    # Initialize OCR once, only when needed (engine startup is expensive)
    ocr_engine = OCREngine() if use_ocr else None

    # Process each page
    for page_no, image_path in render_pdf_to_images(pdf_path, images_dir, dpi=dpi):
        # Get image dimensions (needed to normalize bbox coordinates)
        img_width, img_height = get_render_dimensions(pdf_path, page_no, dpi)

        # Extract tokens from the OCR engine or the PDF text layer
        if use_ocr:
            tokens = ocr_engine.extract_from_image(str(image_path), page_no)
        else:
            tokens = list(extract_text_tokens(pdf_path, page_no))

        # Match each known field value against the page tokens
        matches = {}
        for field_name in FIELD_CLASSES.keys():
            value = structured_data.get(field_name)
            # Skip fields with no ground-truth value in the CSV
            if value is None or str(value).strip() == '':
                continue

            normalized = normalize_field(field_name, str(value))
            field_matches = matcher.find_matches(tokens, field_name, normalized, page_no)
            if field_matches:
                matches[field_name] = field_matches

        # Generate annotations (pass DPI for coordinate conversion)
        annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi)

        # Save one label file per page that produced annotations
        if annotations:
            label_path = labels_dir / f"{image_path.stem}.txt"
            generator.save_annotations(annotations, label_path)
            annotation_files.append(label_path)

    return annotation_files
|