#!/usr/bin/env python3
"""
Auto-labeling CLI
Generates YOLO training data from PDFs and structured CSV data.
"""
import argparse
import sys
import time
import os
import signal
import warnings
from pathlib import Path
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
import multiprocessing
# Global flag for graceful shutdown
_shutdown_requested = False

def _signal_handler(signum, frame):
    """Handle interrupt signals for graceful shutdown."""
    global _shutdown_requested
    _shutdown_requested = True
    print("\n\nShutdown requested. Finishing current batch and saving progress...")
    print("(Press Ctrl+C again to force quit)\n")
    # Restore the default handler so a second Ctrl+C actually force quits
    signal.signal(signal.SIGINT, signal.SIG_DFL)

# Windows compatibility: use 'spawn' method for multiprocessing
# This is required on Windows and is also safer for libraries like PaddleOCR
if sys.platform == 'win32':
    multiprocessing.set_start_method('spawn', force=True)

sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from config import get_db_connection_string, PATHS, AUTOLABEL
# Global OCR engine for worker processes (initialized once per worker)
_worker_ocr_engine = None
_worker_initialized = False
_worker_type = None # 'cpu' or 'gpu'

def _init_cpu_worker():
    """Initialize CPU worker (no OCR engine needed)."""
    global _worker_initialized, _worker_type
    _worker_initialized = True
    _worker_type = 'cpu'


def _init_gpu_worker():
    """Initialize GPU worker with OCR engine (called once per worker)."""
    global _worker_ocr_engine, _worker_initialized, _worker_type
    # Suppress PaddlePaddle/PaddleX reinitialization warnings
    warnings.filterwarnings('ignore', message='.*PDX has already been initialized.*')
    warnings.filterwarnings('ignore', message='.*reinitialization.*')
    # Set environment variable to suppress paddle warnings
    os.environ['GLOG_minloglevel'] = '2'  # Suppress INFO and WARNING logs
    # OCR engine will be lazily initialized on first use
    _worker_ocr_engine = None
    _worker_initialized = True
    _worker_type = 'gpu'


def _init_worker():
    """Initialize worker process with OCR engine (called once per worker).

    Legacy function for backwards compatibility.
    """
    _init_gpu_worker()


def _get_ocr_engine():
    """Get or create OCR engine for current worker."""
    global _worker_ocr_engine
    if _worker_ocr_engine is None:
        # Suppress warnings during OCR initialization
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore')
            from ..ocr import OCREngine
            _worker_ocr_engine = OCREngine()
    return _worker_ocr_engine
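
# Note on the worker pattern above: each pool worker caches a single OCR engine in
# module globals, so the PaddleOCR models are loaded at most once per process rather
# than once per document. The single-pool path in main() wires this up roughly as
#
#     ProcessPoolExecutor(max_workers=args.workers, initializer=_init_worker)
#
# and _get_ocr_engine() is then called lazily on the first scanned page a worker sees.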

def _save_output_img(output_img, image_path: Path) -> None:
    """Save OCR output_img to replace the original rendered image."""
    from PIL import Image as PILImage
    # Convert numpy array to PIL Image and save
    if output_img is not None:
        img = PILImage.fromarray(output_img)
        img.save(str(image_path))
    # If output_img is None, the original image is already saved

def process_single_document(args_tuple):
    """
    Process a single document (worker function for parallel processing).

    Args:
        args_tuple: (row_dict, pdf_path, output_dir, dpi, min_confidence, skip_ocr)

    Returns:
        dict with results
    """
    row_dict, pdf_path_str, output_dir_str, dpi, min_confidence, skip_ocr = args_tuple

    # Import inside worker to avoid pickling issues
    from ..data import AutoLabelReport, FieldMatchResult
    from ..pdf import PDFDocument
    from ..matcher import FieldMatcher
    from ..normalize import normalize_field
    from ..yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES

    start_time = time.time()
    pdf_path = Path(pdf_path_str)
    output_dir = Path(output_dir_str)
    doc_id = row_dict['DocumentId']

    report = AutoLabelReport(document_id=doc_id)
    report.pdf_path = str(pdf_path)
    # Store metadata fields from CSV
    report.split = row_dict.get('split')
    report.customer_number = row_dict.get('customer_number')
    report.supplier_name = row_dict.get('supplier_name')
    report.supplier_organisation_number = row_dict.get('supplier_organisation_number')
    report.supplier_accounts = row_dict.get('supplier_accounts')

    result = {
        'doc_id': doc_id,
        'success': False,
        'pages': [],
        'report': None,
        'stats': {name: 0 for name in FIELD_CLASSES.keys()}
    }
    try:
        # Use PDFDocument context manager for efficient PDF handling
        # Opens PDF only once, caches dimensions, handles cleanup automatically
        with PDFDocument(pdf_path) as pdf_doc:
            # Check PDF type (uses cached document)
            use_ocr = not pdf_doc.is_text_pdf()
            report.pdf_type = "scanned" if use_ocr else "text"

            # Skip OCR if requested
            if use_ocr and skip_ocr:
                report.errors.append("Skipped (scanned PDF)")
                report.processing_time_ms = (time.time() - start_time) * 1000
                result['report'] = report.to_dict()
                return result

            # Get OCR engine from worker cache (only created once per worker)
            ocr_engine = None
            if use_ocr:
                ocr_engine = _get_ocr_engine()

            generator = AnnotationGenerator(min_confidence=min_confidence)
            matcher = FieldMatcher()

            # Process each page
            page_annotations = []
            matched_fields = set()

            # Render all pages and process (uses cached document handle)
            images_dir = output_dir / 'temp' / doc_id / 'images'
            for page_no, image_path in pdf_doc.render_all_pages(images_dir, dpi=dpi):
                report.total_pages += 1

                # Get dimensions from cache (no additional PDF open)
                img_width, img_height = pdf_doc.get_render_dimensions(page_no, dpi)

                # Extract tokens
                if use_ocr:
                    # Use extract_with_image to get both tokens and preprocessed image
                    # PaddleOCR coordinates are relative to output_img, not original image
                    ocr_result = ocr_engine.extract_with_image(
                        str(image_path),
                        page_no,
                        scale_to_pdf_points=72 / dpi
                    )
                    tokens = ocr_result.tokens
                    # Save output_img to replace the original rendered image
                    # This ensures coordinates match the saved image
                    _save_output_img(ocr_result.output_img, image_path)
                    # Update image dimensions to match output_img
                    if ocr_result.output_img is not None:
                        img_height, img_width = ocr_result.output_img.shape[:2]
                else:
                    # Use cached document for text extraction
                    tokens = list(pdf_doc.extract_text_tokens(page_no))
                # Match fields
                matches = {}
                for field_name in FIELD_CLASSES.keys():
                    value = row_dict.get(field_name)
                    if not value:
                        continue
                    normalized = normalize_field(field_name, str(value))
                    field_matches = matcher.find_matches(tokens, field_name, normalized, page_no)

                    # Record result
                    if field_matches:
                        best = field_matches[0]
                        matches[field_name] = field_matches
                        matched_fields.add(field_name)
                        report.add_field_result(FieldMatchResult(
                            field_name=field_name,
                            csv_value=str(value),
                            matched=True,
                            score=best.score,
                            matched_text=best.matched_text,
                            candidate_used=best.value,
                            bbox=best.bbox,
                            page_no=page_no,
                            context_keywords=best.context_keywords
                        ))

                # Match supplier_accounts and map to Bankgiro/Plusgiro
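                # Illustrative example (hypothetical values, following the parser
                # below): a CSV value such as "BG:5050-1055 | PG:4789603-2" should
                # produce a Bankgiro lookup for "5050-1055" and a Plusgiro lookup
                # for "4789603-2".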
                supplier_accounts_value = row_dict.get('supplier_accounts')
                if supplier_accounts_value:
                    # Parse accounts: "BG:xxx | PG:yyy" format
                    accounts = [acc.strip() for acc in str(supplier_accounts_value).split('|')]
                    for account in accounts:
                        account = account.strip()
                        if not account:
                            continue

                        # Determine account type (BG or PG) and extract account number
                        account_type = None
                        account_number = account  # Default to full value
                        if account.upper().startswith('BG:'):
                            account_type = 'Bankgiro'
                            account_number = account[3:].strip()  # Remove "BG:" prefix
                        elif account.upper().startswith('BG '):
                            account_type = 'Bankgiro'
                            account_number = account[2:].strip()  # Remove "BG" prefix
                        elif account.upper().startswith('PG:'):
                            account_type = 'Plusgiro'
                            account_number = account[3:].strip()  # Remove "PG:" prefix
                        elif account.upper().startswith('PG '):
                            account_type = 'Plusgiro'
                            account_number = account[2:].strip()  # Remove "PG" prefix
                        else:
                            # Try to guess from format - Plusgiro often has format XXXXXXX-X
                            digits = ''.join(c for c in account if c.isdigit())
                            if len(digits) == 8 and '-' in account:
                                account_type = 'Plusgiro'
                            elif len(digits) in (7, 8):
                                account_type = 'Bankgiro'  # Default to Bankgiro

                        if not account_type:
                            continue

                        # Normalize and match using the account number (without prefix)
                        normalized = normalize_field('supplier_accounts', account_number)
                        field_matches = matcher.find_matches(tokens, account_type, normalized, page_no)
                        if field_matches:
                            best = field_matches[0]
                            # Add to matches under the target class (Bankgiro/Plusgiro)
                            if account_type not in matches:
                                matches[account_type] = []
                            matches[account_type].extend(field_matches)
                            matched_fields.add('supplier_accounts')
                            report.add_field_result(FieldMatchResult(
                                field_name=f'supplier_accounts({account_type})',
                                csv_value=account_number,  # Store without prefix
                                matched=True,
                                score=best.score,
                                matched_text=best.matched_text,
                                candidate_used=best.value,
                                bbox=best.bbox,
                                page_no=page_no,
                                context_keywords=best.context_keywords
                            ))
                # Count annotations
                annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi)
                if annotations:
                    page_annotations.append({
                        'image_path': str(image_path),
                        'page_no': page_no,
                        'count': len(annotations)
                    })
                    report.annotations_generated += len(annotations)
                    for ann in annotations:
                        class_name = list(FIELD_CLASSES.keys())[ann.class_id]
                        result['stats'][class_name] += 1

            # Record unmatched fields
            for field_name in FIELD_CLASSES.keys():
                value = row_dict.get(field_name)
                if value and field_name not in matched_fields:
                    report.add_field_result(FieldMatchResult(
                        field_name=field_name,
                        csv_value=str(value),
                        matched=False,
                        page_no=-1
                    ))

            if page_annotations:
                result['pages'] = page_annotations
                result['success'] = True
                report.success = True
            else:
                report.errors.append("No annotations generated")

    except Exception as e:
        report.errors.append(str(e))

    report.processing_time_ms = (time.time() - start_time) * 1000
    result['report'] = report.to_dict()
    return result

def main():
    parser = argparse.ArgumentParser(
        description='Generate YOLO annotations from PDFs and CSV data'
    )
    parser.add_argument(
        '--csv', '-c',
        default=f"{PATHS['csv_dir']}/*.csv",
        help='Path to CSV file(s). Supports: single file, glob pattern (*.csv), or comma-separated list'
    )
    parser.add_argument(
        '--pdf-dir', '-p',
        default=PATHS['pdf_dir'],
        help='Directory containing PDF files'
    )
    parser.add_argument(
        '--output', '-o',
        default=PATHS['output_dir'],
        help='Output directory for dataset'
    )
    parser.add_argument(
        '--dpi',
        type=int,
        default=AUTOLABEL['dpi'],
        help=f"DPI for PDF rendering (default: {AUTOLABEL['dpi']})"
    )
    parser.add_argument(
        '--min-confidence',
        type=float,
        default=AUTOLABEL['min_confidence'],
        help=f"Minimum match confidence (default: {AUTOLABEL['min_confidence']})"
    )
    parser.add_argument(
        '--report',
        default=f"{PATHS['reports_dir']}/autolabel_report.jsonl",
        help='Path for auto-label report (JSONL). With --max-records, creates report_part000.jsonl, etc.'
    )
    parser.add_argument(
        '--max-records',
        type=int,
        default=10000,
        help='Max records per report file for sharding (default: 10000, 0 = single file)'
    )
    parser.add_argument(
        '--single',
        help='Process single document ID only'
    )
    parser.add_argument(
        '--verbose', '-v',
        action='store_true',
        help='Verbose output'
    )
    parser.add_argument(
        '--workers', '-w',
        type=int,
        default=4,
        help='Number of parallel workers (default: 4). Use --cpu-workers and --gpu-workers for dual-pool mode.'
    )
    parser.add_argument(
        '--cpu-workers',
        type=int,
        default=None,
        help='Number of CPU workers for text PDFs (enables dual-pool mode)'
    )
    parser.add_argument(
        '--gpu-workers',
        type=int,
        default=1,
        help='Number of GPU workers for scanned PDFs (default: 1, used with --cpu-workers)'
    )
    parser.add_argument(
        '--skip-ocr',
        action='store_true',
        help='Skip scanned PDFs (text-layer only)'
    )
    parser.add_argument(
        '--limit', '-l',
        type=int,
        default=None,
        help='Limit number of documents to process (for testing)'
    )
    args = parser.parse_args()

    # Register signal handlers for graceful shutdown
    signal.signal(signal.SIGINT, _signal_handler)
    signal.signal(signal.SIGTERM, _signal_handler)

    # Import here to avoid slow startup
    from ..data import CSVLoader, AutoLabelReport, FieldMatchResult
    from ..data.autolabel_report import ReportWriter
    from ..pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens
    from ..pdf.renderer import get_render_dimensions
    from ..ocr import OCREngine
    from ..matcher import FieldMatcher
    from ..normalize import normalize_field
    from ..yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES

    # Handle comma-separated CSV paths
    csv_input = args.csv
    if ',' in csv_input and '*' not in csv_input:
        csv_input = [p.strip() for p in csv_input.split(',')]

    # Get list of CSV files (don't load all data at once)
    temp_loader = CSVLoader(csv_input, args.pdf_dir)
    csv_files = temp_loader.csv_paths
    pdf_dir = temp_loader.pdf_dir
    print(f"Found {len(csv_files)} CSV file(s) to process")

    # Setup output directories
    output_dir = Path(args.output)
    # Only create temp directory for images (no train/val/test split during labeling)
    (output_dir / 'temp').mkdir(parents=True, exist_ok=True)

    # Report writer with optional sharding
    report_path = Path(args.report)
    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_writer = ReportWriter(args.report, max_records_per_file=args.max_records)

    # Database connection for checking existing documents
    from ..data.db import DocumentDB
    db = DocumentDB()
    db.connect()
    db.create_tables()  # Ensure tables exist
    print("Connected to database for status checking")

    # Global stats
    stats = {
        'total': 0,
        'successful': 0,
        'failed': 0,
        'skipped': 0,
        'skipped_db': 0,       # Skipped because already in DB
        'retried': 0,          # Re-processed failed ones
        'annotations': 0,
        'tasks_submitted': 0,  # Tracks tasks submitted across all CSVs for limit
        'by_field': {name: 0 for name in FIELD_CLASSES.keys()}
    }

    # Track all processed items for final split (write to temp file to save memory)
    processed_items_file = output_dir / 'temp' / 'processed_items.jsonl'
    processed_items_file.parent.mkdir(parents=True, exist_ok=True)
    processed_items_writer = open(processed_items_file, 'w', encoding='utf-8')
    processed_count = 0
    seen_doc_ids = set()

    # Batch for database updates
    db_batch = []
    DB_BATCH_SIZE = 100
    # Helper functions to handle results and update the database.
    # Defined outside the loop so nonlocal can properly reference db_batch.
    def handle_result(result):
        nonlocal processed_count, db_batch
        # Write report to file
        if result['report']:
            report_writer.write_dict(result['report'])
            # Add to database batch
            db_batch.append(result['report'])
            if len(db_batch) >= DB_BATCH_SIZE:
                db.save_documents_batch(db_batch)
                db_batch.clear()
        if result['success']:
            # Write to temp file instead of memory
            import json
            processed_items_writer.write(json.dumps({
                'doc_id': result['doc_id'],
                'pages': result['pages']
            }) + '\n')
            processed_items_writer.flush()
            processed_count += 1
            stats['successful'] += 1
            for field, count in result['stats'].items():
                stats['by_field'][field] += count
                stats['annotations'] += count
        elif 'Skipped' in str(result.get('report', {}).get('errors', [])):
            stats['skipped'] += 1
        else:
            stats['failed'] += 1

    def handle_error(doc_id, error):
        nonlocal db_batch
        stats['failed'] += 1
        error_report = {
            'document_id': doc_id,
            'success': False,
            'errors': [f"Worker error: {str(error)}"]
        }
        report_writer.write_dict(error_report)
        db_batch.append(error_report)
        if len(db_batch) >= DB_BATCH_SIZE:
            db.save_documents_batch(db_batch)
            db_batch.clear()
        if args.verbose:
            print(f"Error processing {doc_id}: {error}")

    # Initialize dual-pool coordinator if enabled (keeps workers alive across CSVs)
    dual_pool_coordinator = None
    use_dual_pool = args.cpu_workers is not None
    if use_dual_pool:
        from src.processing import DualPoolCoordinator
        from src.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf
        print(f"\nStarting dual-pool mode: {args.cpu_workers} CPU + {args.gpu_workers} GPU workers")
        dual_pool_coordinator = DualPoolCoordinator(
            cpu_workers=args.cpu_workers,
            gpu_workers=args.gpu_workers,
            gpu_id=0,
            task_timeout=300.0,
        )
        dual_pool_coordinator.start()
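
    # In dual-pool mode, documents are pre-classified with is_text_pdf() inside the
    # loop below and routed to the CPU pool (text-layer PDFs, process_text_pdf) or
    # the GPU pool (scanned PDFs, process_scanned_pdf) via process_batch().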
    try:
        # Process CSV files one by one (streaming)
        for csv_idx, csv_file in enumerate(csv_files):
            # Check for shutdown request
            if _shutdown_requested:
                print("\nShutdown requested. Stopping after current batch...")
                break

            print(f"\n[{csv_idx + 1}/{len(csv_files)}] Processing: {csv_file.name}")

            # Load only this CSV file
            single_loader = CSVLoader(str(csv_file), str(pdf_dir))
            rows = single_loader.load_all()

            # Filter to single document if specified
            if args.single:
                rows = [r for r in rows if r.DocumentId == args.single]
                if not rows:
                    continue

            # Deduplicate across CSV files
            rows = [r for r in rows if r.DocumentId not in seen_doc_ids]
            for r in rows:
                seen_doc_ids.add(r.DocumentId)
            if not rows:
                print(" Skipping CSV (no new documents)")
                continue

            # Batch query database for all document IDs in this CSV
            csv_doc_ids = [r.DocumentId for r in rows]
            db_status_map = db.check_documents_status_batch(csv_doc_ids)

            # Count how many are already processed successfully
            already_processed = sum(1 for doc_id in csv_doc_ids if db_status_map.get(doc_id) is True)

            # Skip entire CSV if all documents are already processed
            if already_processed == len(rows):
                print(f" Skipping CSV (all {len(rows)} documents already processed)")
                stats['skipped_db'] += len(rows)
                continue

            # Count how many new documents need processing in this CSV
            new_to_process = len(rows) - already_processed
            print(f" Found {new_to_process} new documents to process ({already_processed} already in DB)")
            stats['total'] += len(rows)

            # Prepare tasks for this CSV
            tasks = []
            skipped_in_csv = 0
            retry_in_csv = 0

            # Calculate how many more we can process if limit is set.
            # Use tasks_submitted counter which tracks across all CSVs.
            if args.limit:
                remaining_limit = args.limit - stats.get('tasks_submitted', 0)
                if remaining_limit <= 0:
                    print(f" Reached limit of {args.limit} new documents, stopping.")
                    break
            else:
                remaining_limit = float('inf')
            for row in rows:
                # Stop adding tasks if we've reached the limit
                if len(tasks) >= remaining_limit:
                    break

                doc_id = row.DocumentId

                # Check document status from batch query result
                db_status = db_status_map.get(doc_id)  # None if not in DB

                # Skip if already successful in database
                if db_status is True:
                    stats['skipped_db'] += 1
                    skipped_in_csv += 1
                    continue

                # Check if this is a retry (was failed before)
                if db_status is False:
                    stats['retried'] += 1
                    retry_in_csv += 1

                pdf_path = single_loader.get_pdf_path(row)
                if not pdf_path:
                    stats['skipped'] += 1
                    continue

                row_dict = {
                    'DocumentId': row.DocumentId,
                    'InvoiceNumber': row.InvoiceNumber,
                    'InvoiceDate': row.InvoiceDate,
                    'InvoiceDueDate': row.InvoiceDueDate,
                    'OCR': row.OCR,
                    'Bankgiro': row.Bankgiro,
                    'Plusgiro': row.Plusgiro,
                    'Amount': row.Amount,
                    # New fields
                    'supplier_organisation_number': row.supplier_organisation_number,
                    'supplier_accounts': row.supplier_accounts,
                    # Metadata fields (not for matching, but for database storage)
                    'split': row.split,
                    'customer_number': row.customer_number,
                    'supplier_name': row.supplier_name,
                }
                tasks.append((
                    row_dict,
                    str(pdf_path),
                    str(output_dir),
                    args.dpi,
                    args.min_confidence,
                    args.skip_ocr
                ))

            if skipped_in_csv > 0 or retry_in_csv > 0:
                print(f" Skipped {skipped_in_csv} (already in DB), retrying {retry_in_csv} failed")

            if not tasks:
                continue

            # Update tasks_submitted counter for limit tracking
            stats['tasks_submitted'] += len(tasks)
            if use_dual_pool:
                # Dual-pool mode using pre-initialized DualPoolCoordinator
                # (process_text_pdf, process_scanned_pdf already imported above)

                # Convert tasks to new format
                documents = []
                for task in tasks:
                    row_dict, pdf_path_str, output_dir_str, dpi, min_confidence, skip_ocr = task
                    # Pre-classify PDF type
                    try:
                        is_text = is_text_pdf(pdf_path_str)
                    except Exception:
                        is_text = False
                    documents.append({
                        "id": row_dict["DocumentId"],
                        "row_dict": row_dict,
                        "pdf_path": pdf_path_str,
                        "output_dir": output_dir_str,
                        "dpi": dpi,
                        "min_confidence": min_confidence,
                        "is_scanned": not is_text,
                        "has_text": is_text,
                        "text_length": 1000 if is_text else 0,  # Approximate
                    })

                # Count task types
                text_count = sum(1 for d in documents if not d["is_scanned"])
                scan_count = len(documents) - text_count
                print(f" Text PDFs: {text_count}, Scanned PDFs: {scan_count}")

                # Progress tracking with tqdm
                pbar = tqdm(total=len(documents), desc="Processing")

                def on_result(task_result):
                    """Handle successful result."""
                    result = task_result.data
                    handle_result(result)
                    pbar.update(1)

                def on_error(task_id, error):
                    """Handle failed task."""
                    handle_error(task_id, error)
                    pbar.update(1)

                # Process with pre-initialized coordinator (workers stay alive)
                results = dual_pool_coordinator.process_batch(
                    documents=documents,
                    cpu_task_fn=process_text_pdf,
                    gpu_task_fn=process_scanned_pdf,
                    on_result=on_result,
                    on_error=on_error,
                    id_field="id",
                )
                pbar.close()

                # Log summary
                successful = sum(1 for r in results if r.success)
                failed = len(results) - successful
                print(f" Batch complete: {successful} successful, {failed} failed")
            else:
                # Single-pool mode (original behavior)
                print(f" Processing {len(tasks)} documents with {args.workers} workers...")

                # Process documents in parallel (inside CSV loop for streaming)
                # Use single process for debugging or when workers=1
                if args.workers == 1:
                    for task in tqdm(tasks, desc="Processing"):
                        result = process_single_document(task)
                        handle_result(result)
                else:
                    # Parallel processing with worker initialization
                    # Each worker initializes OCR engine once and reuses it
                    with ProcessPoolExecutor(max_workers=args.workers, initializer=_init_worker) as executor:
                        futures = {executor.submit(process_single_document, task): task[0]['DocumentId']
                                   for task in tasks}

                        # Per-document timeout: 120 seconds (2 minutes)
                        # This prevents a single stuck document from blocking the entire batch
                        DOCUMENT_TIMEOUT = 120

                        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"):
                            doc_id = futures[future]
                            try:
                                result = future.result(timeout=DOCUMENT_TIMEOUT)
                                handle_result(result)
                            except TimeoutError:
                                handle_error(doc_id, f"Processing timeout after {DOCUMENT_TIMEOUT}s")
                                # Cancel the stuck future
                                future.cancel()
                            except Exception as e:
                                handle_error(doc_id, e)

            # Flush remaining database batch after each CSV
            if db_batch:
                db.save_documents_batch(db_batch)
                db_batch.clear()
    finally:
        # Shutdown dual-pool coordinator if it was started
        if dual_pool_coordinator is not None:
            dual_pool_coordinator.shutdown()

        # Close temp file
        processed_items_writer.close()

        # Use the in-memory counter instead of re-reading the file (performance fix);
        # processed_count already tracks the number of successfully processed items.

        # Cleanup processed_items temp file (not needed anymore)
        processed_items_file.unlink(missing_ok=True)

        # Close database connection
        db.close()

    # Print summary
    print("\n" + "=" * 50)
    print("Auto-labeling Complete")
    print("=" * 50)
    print(f"Total documents: {stats['total']}")
    print(f"Successful: {stats['successful']}")
    print(f"Failed: {stats['failed']}")
    print(f"Skipped (no PDF): {stats['skipped']}")
    print(f"Skipped (in DB): {stats['skipped_db']}")
    print(f"Retried (failed): {stats['retried']}")
    print(f"Total annotations: {stats['annotations']}")
    print(f"\nImages saved to: {output_dir / 'temp'}")
    print("Labels stored in: PostgreSQL database")
    print("\nAnnotations by field:")
    for field, count in stats['by_field'].items():
        print(f" {field}: {count}")

    shard_files = report_writer.get_shard_files()
    if len(shard_files) > 1:
        print(f"\nReport files ({len(shard_files)}):")
        for sf in shard_files:
            print(f" - {sf}")
    else:
        print(f"\nReport: {shard_files[0] if shard_files else args.report}")


if __name__ == '__main__':
    main()