Initial commit: Invoice field extraction system using YOLO + OCR
Features:
- Auto-labeling pipeline: CSV values -> PDF search -> YOLO annotations
- Flexible date matching: year-month match, nearby date tolerance
- PDF text extraction with PyMuPDF
- OCR support for scanned documents (PaddleOCR)
- YOLO training and inference pipeline
- 7 field types: InvoiceNumber, InvoiceDate, InvoiceDueDate, OCR, Bankgiro, Plusgiro, Amount

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
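Example invocation (a sketch; assumes the repository root as the working directory, since the CLI uses package-relative imports):

    python -m src.cli.autolabel --pdf-dir data/raw_pdfs --output data/dataset --workers 4

All flags are optional; defaults are defined in src/cli/autolabel.py below.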
src/cli/autolabel.py (new file, 458 lines)
@@ -0,0 +1,458 @@
#!/usr/bin/env python3
"""
Auto-labeling CLI

Generates YOLO training data from PDFs and structured CSV data.
"""

import argparse
import sys
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path

from tqdm import tqdm

# Global OCR engine for worker processes (initialized once per worker)
_worker_ocr_engine = None


def _init_worker():
    """Initialize worker process with OCR engine (called once per worker)."""
    global _worker_ocr_engine
    # OCR engine will be lazily initialized on first use
    _worker_ocr_engine = None


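# Loading an OCR model (PaddleOCR, per the commit message) is expensive, so
# each worker keeps one cached engine rather than re-creating it per document.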
def _get_ocr_engine():
    """Get or create OCR engine for current worker."""
    global _worker_ocr_engine
    if _worker_ocr_engine is None:
        from ..ocr import OCREngine
        _worker_ocr_engine = OCREngine()
    return _worker_ocr_engine


def process_single_document(args_tuple):
    """
    Process a single document (worker function for parallel processing).

    Args:
        args_tuple: (row_dict, pdf_path, output_dir, dpi, min_confidence, skip_ocr)

    Returns:
        dict with results
    """
    row_dict, pdf_path_str, output_dir_str, dpi, min_confidence, skip_ocr = args_tuple

    # Import inside worker to avoid pickling issues
    from ..data import AutoLabelReport, FieldMatchResult
    from ..pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens
    from ..pdf.renderer import get_render_dimensions
    from ..matcher import FieldMatcher
    from ..normalize import normalize_field
    from ..yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES

    start_time = time.time()
    pdf_path = Path(pdf_path_str)
    output_dir = Path(output_dir_str)
    doc_id = row_dict['DocumentId']

    report = AutoLabelReport(document_id=doc_id)
    report.pdf_path = str(pdf_path)

    result = {
        'doc_id': doc_id,
        'success': False,
        'pages': [],
        'report': None,
        'stats': {name: 0 for name in FIELD_CLASSES.keys()}
    }

    try:
        # Check PDF type
        use_ocr = not is_text_pdf(pdf_path)
        report.pdf_type = "scanned" if use_ocr else "text"

        # Skip OCR if requested
        if use_ocr and skip_ocr:
            report.errors.append("Skipped (scanned PDF)")
            report.processing_time_ms = (time.time() - start_time) * 1000
            result['report'] = report.to_dict()
            return result

        # Get OCR engine from worker cache (only created once per worker)
        ocr_engine = None
        if use_ocr:
            ocr_engine = _get_ocr_engine()

        generator = AnnotationGenerator(min_confidence=min_confidence)
        matcher = FieldMatcher()

        # Process each page
        page_annotations = []

        for page_no, image_path in render_pdf_to_images(
            pdf_path,
            output_dir / 'temp' / doc_id / 'images',
            dpi=dpi
        ):
            report.total_pages += 1
            img_width, img_height = get_render_dimensions(pdf_path, page_no, dpi)

            # Extract tokens
            if use_ocr:
                tokens = ocr_engine.extract_from_image(str(image_path), page_no)
            else:
                tokens = list(extract_text_tokens(pdf_path, page_no))

            # Match fields
            matches = {}
            for field_name in FIELD_CLASSES.keys():
                value = row_dict.get(field_name)
                if not value:
                    continue

                normalized = normalize_field(field_name, str(value))
                field_matches = matcher.find_matches(tokens, field_name, normalized, page_no)

                # Record result
                if field_matches:
                    best = field_matches[0]
                    matches[field_name] = field_matches
                    report.add_field_result(FieldMatchResult(
                        field_name=field_name,
                        csv_value=str(value),
                        matched=True,
                        score=best.score,
                        matched_text=best.matched_text,
                        candidate_used=best.value,
                        bbox=best.bbox,
                        page_no=page_no,
                        context_keywords=best.context_keywords
                    ))
                else:
                    report.add_field_result(FieldMatchResult(
                        field_name=field_name,
                        csv_value=str(value),
                        matched=False,
                        page_no=page_no
                    ))

            # Generate annotations
            annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi)

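            # Each label file row is "class_id x_center y_center width height",
            # the standard YOLO txt format with coordinates normalized to
            # [0, 1] (assuming AnnotationGenerator follows that convention).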
            if annotations:
                label_path = output_dir / 'temp' / doc_id / 'labels' / f"{image_path.stem}.txt"
                generator.save_annotations(annotations, label_path)
                page_annotations.append({
                    'image_path': str(image_path),
                    'label_path': str(label_path),
                    'count': len(annotations)
                })

                report.annotations_generated += len(annotations)
                for ann in annotations:
                    class_name = list(FIELD_CLASSES.keys())[ann.class_id]
                    result['stats'][class_name] += 1

        if page_annotations:
            result['pages'] = page_annotations
            result['success'] = True
            report.success = True
        else:
            report.errors.append("No annotations generated")

    except Exception as e:
        report.errors.append(str(e))

    report.processing_time_ms = (time.time() - start_time) * 1000
    result['report'] = report.to_dict()

    return result


def main():
    parser = argparse.ArgumentParser(
        description='Generate YOLO annotations from PDFs and CSV data'
    )
    parser.add_argument(
        '--csv', '-c',
        default='data/structured_data/document_export_20260109_212743.csv',
        help='Path to structured data CSV file'
    )
    parser.add_argument(
        '--pdf-dir', '-p',
        default='data/raw_pdfs',
        help='Directory containing PDF files'
    )
    parser.add_argument(
        '--output', '-o',
        default='data/dataset',
        help='Output directory for dataset'
    )
    parser.add_argument(
        '--dpi',
        type=int,
        default=300,
        help='DPI for PDF rendering (default: 300)'
    )
    parser.add_argument(
        '--min-confidence',
        type=float,
        default=0.7,
        help='Minimum match confidence (default: 0.7)'
    )
    parser.add_argument(
        '--train-ratio',
        type=float,
        default=0.8,
        help='Training set ratio (default: 0.8)'
    )
    parser.add_argument(
        '--val-ratio',
        type=float,
        default=0.1,
        help='Validation set ratio (default: 0.1)'
    )
    parser.add_argument(
        '--report',
        default='reports/autolabel_report.jsonl',
        help='Path for auto-label report (JSONL)'
    )
    parser.add_argument(
        '--single',
        help='Process single document ID only'
    )
    parser.add_argument(
        '--verbose', '-v',
        action='store_true',
        help='Verbose output'
    )
    parser.add_argument(
        '--workers', '-w',
        type=int,
        default=4,
        help='Number of parallel workers (default: 4)'
    )
    parser.add_argument(
        '--skip-ocr',
        action='store_true',
        help='Skip scanned PDFs (text-layer only)'
    )

    args = parser.parse_args()

    # Import here to avoid slow startup; only the names used in main() are
    # needed (workers import their own dependencies).
    from ..data import CSVLoader, AutoLabelReport
    from ..data.autolabel_report import ReportWriter
    from ..yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES

    print(f"Loading CSV data from: {args.csv}")
    loader = CSVLoader(args.csv, args.pdf_dir)

    # Validate data
    issues = loader.validate()
    if issues:
        print(f"Warning: Found {len(issues)} validation issues")
        if args.verbose:
            for issue in issues[:10]:
                print(f" - {issue}")

    rows = loader.load_all()
    print(f"Loaded {len(rows)} invoice records")

    # Filter to single document if specified
    if args.single:
        rows = [r for r in rows if r.DocumentId == args.single]
        if not rows:
            print(f"Error: Document {args.single} not found")
            sys.exit(1)
        print(f"Processing single document: {args.single}")

    # Setup output directories
    output_dir = Path(args.output)
    for split in ['train', 'val', 'test']:
        (output_dir / split / 'images').mkdir(parents=True, exist_ok=True)
        (output_dir / split / 'labels').mkdir(parents=True, exist_ok=True)

    # Generate YOLO config files
    AnnotationGenerator.generate_classes_file(output_dir / 'classes.txt')
    AnnotationGenerator.generate_yaml_config(output_dir / 'dataset.yaml')
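    # dataset.yaml is the dataset config read by YOLO training tools
    # (train/val/test image paths plus class names, in the Ultralytics-style
    # layout that generate_yaml_config is assumed to emit).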

    # Report writer
    report_path = Path(args.report)
    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_writer = ReportWriter(args.report)

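    # Each document appends one JSON object per line (JSONL) to the report,
    # along the lines of (illustrative sketch of AutoLabelReport.to_dict()):
    #   {"document_id": "...", "pdf_type": "text", "success": true,
    #    "annotations_generated": 5, "errors": [], ...}
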
    # Stats
    stats = {
        'total': len(rows),
        'successful': 0,
        'failed': 0,
        'skipped': 0,
        'annotations': 0,
        'by_field': {name: 0 for name in FIELD_CLASSES.keys()}
    }

    # Prepare tasks
    tasks = []
    for row in rows:
        pdf_path = loader.get_pdf_path(row)
        if not pdf_path:
            # Write report for missing PDF
            report = AutoLabelReport(document_id=row.DocumentId)
            report.errors.append("PDF not found")
            report_writer.write(report)
            stats['failed'] += 1
            continue

        # Convert row to dict for pickling
        row_dict = {
            'DocumentId': row.DocumentId,
            'InvoiceNumber': row.InvoiceNumber,
            'InvoiceDate': row.InvoiceDate,
            'InvoiceDueDate': row.InvoiceDueDate,
            'OCR': row.OCR,
            'Bankgiro': row.Bankgiro,
            'Plusgiro': row.Plusgiro,
            'Amount': row.Amount,
        }

        tasks.append((
            row_dict,
            str(pdf_path),
            str(output_dir),
            args.dpi,
            args.min_confidence,
            args.skip_ocr
        ))

    print(f"Processing {len(tasks)} documents with {args.workers} workers...")

    # Process documents in parallel
    processed_items = []

    # Use single process for debugging or when workers=1
    if args.workers == 1:
        for task in tqdm(tasks, desc="Processing"):
            result = process_single_document(task)

            # Write report
            if result['report']:
                report_writer.write_dict(result['report'])

            if result['success']:
                processed_items.append({
                    'doc_id': result['doc_id'],
                    'pages': result['pages']
                })
                stats['successful'] += 1
                for field, count in result['stats'].items():
                    stats['by_field'][field] += count
                    stats['annotations'] += count
            elif 'Skipped' in str(result.get('report', {}).get('errors', [])):
                stats['skipped'] += 1
            else:
                stats['failed'] += 1
    else:
        # Parallel processing with worker initialization.
        # Each worker initializes its OCR engine once and reuses it.
        with ProcessPoolExecutor(max_workers=args.workers, initializer=_init_worker) as executor:
            futures = {executor.submit(process_single_document, task): task[0]['DocumentId']
                       for task in tasks}

            for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"):
                doc_id = futures[future]
                try:
                    result = future.result()

                    # Write report
                    if result['report']:
                        report_writer.write_dict(result['report'])

                    if result['success']:
                        processed_items.append({
                            'doc_id': result['doc_id'],
                            'pages': result['pages']
                        })
                        stats['successful'] += 1
                        for field, count in result['stats'].items():
                            stats['by_field'][field] += count
                            stats['annotations'] += count
                    elif 'Skipped' in str(result.get('report', {}).get('errors', [])):
                        stats['skipped'] += 1
                    else:
                        stats['failed'] += 1

                except Exception as e:
                    stats['failed'] += 1
                    # Write error report for failed documents
                    error_report = {
                        'document_id': doc_id,
                        'success': False,
                        'errors': [f"Worker error: {str(e)}"]
                    }
                    report_writer.write_dict(error_report)
                    if args.verbose:
                        print(f"Error processing {doc_id}: {e}")

    # Split and move files
    import random
    random.seed(42)
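    # Fixed seed keeps the train/val/test split reproducible between runs.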
    random.shuffle(processed_items)

    n_train = int(len(processed_items) * args.train_ratio)
    n_val = int(len(processed_items) * args.val_ratio)

    splits = {
        'train': processed_items[:n_train],
        'val': processed_items[n_train:n_train + n_val],
        'test': processed_items[n_train + n_val:]
    }

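    # Test is the remainder after train and val, so the two ratios need not
    # sum to 1.0; rounding leftovers also land in the test split.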
    import shutil
    for split_name, items in splits.items():
        for item in items:
            for page in item['pages']:
                # Move image
                image_path = Path(page['image_path'])
                label_path = Path(page['label_path'])
                dest_img = output_dir / split_name / 'images' / image_path.name
                shutil.move(str(image_path), str(dest_img))

                # Move label
                dest_label = output_dir / split_name / 'labels' / label_path.name
                shutil.move(str(label_path), str(dest_label))

    # Cleanup temp
    shutil.rmtree(output_dir / 'temp', ignore_errors=True)

    # Print summary
    print("\n" + "=" * 50)
    print("Auto-labeling Complete")
    print("=" * 50)
    print(f"Total documents: {stats['total']}")
    print(f"Successful: {stats['successful']}")
    print(f"Failed: {stats['failed']}")
    print(f"Skipped (OCR): {stats['skipped']}")
    print(f"Total annotations: {stats['annotations']}")
    print("\nDataset split:")
    print(f" Train: {len(splits['train'])} documents")
    print(f" Val: {len(splits['val'])} documents")
    print(f" Test: {len(splits['test'])} documents")
    print("\nAnnotations by field:")
    for field, count in stats['by_field'].items():
        print(f" {field}: {count}")
    print(f"\nOutput: {output_dir}")
    print(f"Report: {args.report}")


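# Note: this module uses package-relative imports, so run it as a module from
# the repository root (e.g. python -m src.cli.autolabel) rather than as a
# standalone script.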
if __name__ == '__main__':
    main()