#!/usr/bin/env python3
"""
Auto-labeling CLI

Generates YOLO training data from PDFs and structured CSV data.
"""
import argparse
import sys
import time
import os
import signal
import warnings
from pathlib import Path

from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
import multiprocessing

# Global flag for graceful shutdown
_shutdown_requested = False


def _signal_handler(signum, frame):
    """Handle interrupt signals for graceful shutdown."""
    global _shutdown_requested
    _shutdown_requested = True
    print("\n\nShutdown requested. Finishing current batch and saving progress...")
    print("(Press Ctrl+C again to force quit)\n")


# Windows compatibility: use 'spawn' method for multiprocessing.
# This is required on Windows and is also safer for libraries like PaddleOCR.
if sys.platform == 'win32':
    multiprocessing.set_start_method('spawn', force=True)

sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from config import get_db_connection_string, PATHS, AUTOLABEL

# Global OCR engine for worker processes (initialized once per worker)
_worker_ocr_engine = None
_worker_initialized = False
_worker_type = None  # 'cpu' or 'gpu'


def _init_cpu_worker():
    """Initialize CPU worker (no OCR engine needed)."""
    global _worker_initialized, _worker_type
    _worker_initialized = True
    _worker_type = 'cpu'


def _init_gpu_worker():
    """Initialize GPU worker with OCR engine (called once per worker)."""
    global _worker_ocr_engine, _worker_initialized, _worker_type

    # Suppress PaddlePaddle/PaddleX reinitialization warnings
    warnings.filterwarnings('ignore', message='.*PDX has already been initialized.*')
    warnings.filterwarnings('ignore', message='.*reinitialization.*')

    # Set environment variable to suppress paddle warnings
    os.environ['GLOG_minloglevel'] = '2'  # Suppress INFO and WARNING logs

    # OCR engine will be lazily initialized on first use
    _worker_ocr_engine = None
    _worker_initialized = True
    _worker_type = 'gpu'


def _init_worker():
    """Initialize worker process with OCR engine (called once per worker).

    Legacy function for backwards compatibility.
    """
    _init_gpu_worker()


def _get_ocr_engine():
    """Get or create the OCR engine for the current worker."""
    global _worker_ocr_engine
    if _worker_ocr_engine is None:
        # Suppress warnings during OCR initialization
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore')
            from ..ocr import OCREngine
            _worker_ocr_engine = OCREngine()
    return _worker_ocr_engine


def _save_output_img(output_img, image_path: Path) -> None:
    """Save OCR output_img to replace the original rendered image."""
    from PIL import Image as PILImage

    # Convert numpy array to PIL Image and save
    if output_img is not None:
        img = PILImage.fromarray(output_img)
        img.save(str(image_path))
    # If output_img is None, the original image is already saved
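

# Illustrative sketch only (not executed): the per-worker initializers above are meant to
# be wired through ProcessPoolExecutor's `initializer` hook, so each worker process pays
# the OCR model-load cost at most once and then reuses the cached engine, e.g.:
#
#     with ProcessPoolExecutor(max_workers=4, initializer=_init_gpu_worker) as executor:
#         futures = [executor.submit(process_single_document, task) for task in tasks]
#
# The actual pool wiring used by this CLI lives in main() below.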


def process_single_document(args_tuple):
    """Process a single document (worker function for parallel processing).

    Args:
        args_tuple: (row_dict, pdf_path, output_dir, dpi, min_confidence, skip_ocr)

    Returns:
        dict with results
    """
    row_dict, pdf_path_str, output_dir_str, dpi, min_confidence, skip_ocr = args_tuple

    # Import inside worker to avoid pickling issues
    from ..data import AutoLabelReport, FieldMatchResult
    from ..pdf import PDFDocument
    from ..matcher import FieldMatcher
    from ..normalize import normalize_field
    from ..yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES

    start_time = time.time()
    pdf_path = Path(pdf_path_str)
    output_dir = Path(output_dir_str)
    doc_id = row_dict['DocumentId']

    report = AutoLabelReport(document_id=doc_id)
    report.pdf_path = str(pdf_path)

    # Store metadata fields from CSV
    report.split = row_dict.get('split')
    report.customer_number = row_dict.get('customer_number')
    report.supplier_name = row_dict.get('supplier_name')
    report.supplier_organisation_number = row_dict.get('supplier_organisation_number')
    report.supplier_accounts = row_dict.get('supplier_accounts')

    result = {
        'doc_id': doc_id,
        'success': False,
        'pages': [],
        'report': None,
        'stats': {name: 0 for name in FIELD_CLASSES.keys()}
    }

    try:
        # Use PDFDocument context manager for efficient PDF handling.
        # Opens the PDF only once, caches dimensions, handles cleanup automatically.
        with PDFDocument(pdf_path) as pdf_doc:
            # Check PDF type (uses cached document)
            use_ocr = not pdf_doc.is_text_pdf()
            report.pdf_type = "scanned" if use_ocr else "text"

            # Skip OCR if requested
            if use_ocr and skip_ocr:
                report.errors.append("Skipped (scanned PDF)")
                report.processing_time_ms = (time.time() - start_time) * 1000
                result['report'] = report.to_dict()
                return result

            # Get OCR engine from worker cache (only created once per worker)
            ocr_engine = None
            if use_ocr:
                ocr_engine = _get_ocr_engine()

            generator = AnnotationGenerator(min_confidence=min_confidence)
            matcher = FieldMatcher()

            # Process each page
            page_annotations = []
            matched_fields = set()

            # Render all pages and process (uses cached document handle)
            images_dir = output_dir / 'temp' / doc_id / 'images'
            for page_no, image_path in pdf_doc.render_all_pages(images_dir, dpi=dpi):
                report.total_pages += 1

                # Get dimensions from cache (no additional PDF open)
                img_width, img_height = pdf_doc.get_render_dimensions(page_no, dpi)

                # Extract tokens
                if use_ocr:
                    # Use extract_with_image to get both tokens and the preprocessed image.
                    # PaddleOCR coordinates are relative to output_img, not the original image.
                    ocr_result = ocr_engine.extract_with_image(
                        str(image_path), page_no, scale_to_pdf_points=72 / dpi
                    )
                    tokens = ocr_result.tokens

                    # Save output_img to replace the original rendered image.
                    # This ensures coordinates match the saved image.
                    _save_output_img(ocr_result.output_img, image_path)

                    # Update image dimensions to match output_img
                    if ocr_result.output_img is not None:
                        img_height, img_width = ocr_result.output_img.shape[:2]
                else:
                    # Use cached document for text extraction
                    tokens = list(pdf_doc.extract_text_tokens(page_no))

                # Match fields
                matches = {}
                for field_name in FIELD_CLASSES.keys():
                    value = row_dict.get(field_name)
                    if not value:
                        continue

                    normalized = normalize_field(field_name, str(value))
                    field_matches = matcher.find_matches(tokens, field_name, normalized, page_no)

                    # Record result
                    if field_matches:
                        best = field_matches[0]
                        matches[field_name] = field_matches
                        matched_fields.add(field_name)

                        report.add_field_result(FieldMatchResult(
                            field_name=field_name,
                            csv_value=str(value),
                            matched=True,
                            score=best.score,
                            matched_text=best.matched_text,
                            candidate_used=best.value,
                            bbox=best.bbox,
                            page_no=page_no,
                            context_keywords=best.context_keywords
                        ))
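
                # Illustrative only (assumed sample values): supplier_accounts arrives as a
                # pipe-separated string mixing Bankgiro and Plusgiro entries, e.g.
                #     "BG:5393-9484 | PG:4789603-2"    ->  Bankgiro "5393-9484", Plusgiro "4789603-2"
                #     "1234567" (7 digits, no prefix)  ->  treated as Bankgiro by the fallback heuristic
                # Matched accounts are recorded under the Bankgiro/Plusgiro classes below, not under
                # a separate supplier_accounts class.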
                # Match supplier_accounts and map to Bankgiro/Plusgiro
                supplier_accounts_value = row_dict.get('supplier_accounts')
                if supplier_accounts_value:
                    # Parse accounts: "BG:xxx | PG:yyy" format
                    accounts = [acc.strip() for acc in str(supplier_accounts_value).split('|')]
                    for account in accounts:
                        account = account.strip()
                        if not account:
                            continue

                        # Determine account type (BG or PG) and extract the account number
                        account_type = None
                        account_number = account  # Default to full value

                        if account.upper().startswith('BG:'):
                            account_type = 'Bankgiro'
                            account_number = account[3:].strip()  # Remove "BG:" prefix
                        elif account.upper().startswith('BG '):
                            account_type = 'Bankgiro'
                            account_number = account[2:].strip()  # Remove "BG" prefix
                        elif account.upper().startswith('PG:'):
                            account_type = 'Plusgiro'
                            account_number = account[3:].strip()  # Remove "PG:" prefix
                        elif account.upper().startswith('PG '):
                            account_type = 'Plusgiro'
                            account_number = account[2:].strip()  # Remove "PG" prefix
                        else:
                            # Try to guess from the format - Plusgiro often has the format XXXXXXX-X
                            digits = ''.join(c for c in account if c.isdigit())
                            if len(digits) == 8 and '-' in account:
                                account_type = 'Plusgiro'
                            elif len(digits) in (7, 8):
                                account_type = 'Bankgiro'  # Default to Bankgiro

                        if not account_type:
                            continue

                        # Normalize and match using the account number (without prefix)
                        normalized = normalize_field('supplier_accounts', account_number)
                        field_matches = matcher.find_matches(tokens, account_type, normalized, page_no)

                        if field_matches:
                            best = field_matches[0]
                            # Add to matches under the target class (Bankgiro/Plusgiro)
                            if account_type not in matches:
                                matches[account_type] = []
                            matches[account_type].extend(field_matches)
                            matched_fields.add('supplier_accounts')

                            report.add_field_result(FieldMatchResult(
                                field_name=f'supplier_accounts({account_type})',
                                csv_value=account_number,  # Store without prefix
                                matched=True,
                                score=best.score,
                                matched_text=best.matched_text,
                                candidate_used=best.value,
                                bbox=best.bbox,
                                page_no=page_no,
                                context_keywords=best.context_keywords
                            ))

                # Count annotations
                annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi)
                if annotations:
                    page_annotations.append({
                        'image_path': str(image_path),
                        'page_no': page_no,
                        'count': len(annotations)
                    })
                    report.annotations_generated += len(annotations)

                    for ann in annotations:
                        class_name = list(FIELD_CLASSES.keys())[ann.class_id]
                        result['stats'][class_name] += 1

            # Record unmatched fields
            for field_name in FIELD_CLASSES.keys():
                value = row_dict.get(field_name)
                if value and field_name not in matched_fields:
                    report.add_field_result(FieldMatchResult(
                        field_name=field_name,
                        csv_value=str(value),
                        matched=False,
                        page_no=-1
                    ))

            if page_annotations:
                result['pages'] = page_annotations
                result['success'] = True
                report.success = True
            else:
                report.errors.append("No annotations generated")

    except Exception as e:
        report.errors.append(str(e))

    report.processing_time_ms = (time.time() - start_time) * 1000
    result['report'] = report.to_dict()

    return result
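

# Sketch of the result dict returned by process_single_document() (shape inferred from the
# code above; shown for orientation only):
#
#     {
#         'doc_id': '<DocumentId>',
#         'success': True,
#         'pages': [{'image_path': '<rendered page image>', 'page_no': <page_no>, 'count': 5}, ...],
#         'report': {...},                      # AutoLabelReport.to_dict()
#         'stats': {'InvoiceNumber': 1, ...},   # annotation counts per FIELD_CLASSES entry
#     }
#
# handle_result() in main() consumes exactly this shape.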


def main():
    parser = argparse.ArgumentParser(
        description='Generate YOLO annotations from PDFs and CSV data'
    )
    parser.add_argument(
        '--csv', '-c',
        default=f"{PATHS['csv_dir']}/*.csv",
        help='Path to CSV file(s). Supports: single file, glob pattern (*.csv), or comma-separated list'
    )
    parser.add_argument(
        '--pdf-dir', '-p',
        default=PATHS['pdf_dir'],
        help='Directory containing PDF files'
    )
    parser.add_argument(
        '--output', '-o',
        default=PATHS['output_dir'],
        help='Output directory for dataset'
    )
    parser.add_argument(
        '--dpi',
        type=int,
        default=AUTOLABEL['dpi'],
        help=f"DPI for PDF rendering (default: {AUTOLABEL['dpi']})"
    )
    parser.add_argument(
        '--min-confidence',
        type=float,
        default=AUTOLABEL['min_confidence'],
        help=f"Minimum match confidence (default: {AUTOLABEL['min_confidence']})"
    )
    parser.add_argument(
        '--report',
        default=f"{PATHS['reports_dir']}/autolabel_report.jsonl",
        help='Path for auto-label report (JSONL). With --max-records, creates report_part000.jsonl, etc.'
    )
    parser.add_argument(
        '--max-records',
        type=int,
        default=10000,
        help='Max records per report file for sharding (default: 10000, 0 = single file)'
    )
    parser.add_argument(
        '--single',
        help='Process single document ID only'
    )
    parser.add_argument(
        '--verbose', '-v',
        action='store_true',
        help='Verbose output'
    )
    parser.add_argument(
        '--workers', '-w',
        type=int,
        default=4,
        help='Number of parallel workers (default: 4). Use --cpu-workers and --gpu-workers for dual-pool mode.'
    )
    parser.add_argument(
        '--cpu-workers',
        type=int,
        default=None,
        help='Number of CPU workers for text PDFs (enables dual-pool mode)'
    )
    parser.add_argument(
        '--gpu-workers',
        type=int,
        default=1,
        help='Number of GPU workers for scanned PDFs (default: 1, used with --cpu-workers)'
    )
    parser.add_argument(
        '--skip-ocr',
        action='store_true',
        help='Skip scanned PDFs (text-layer only)'
    )
    parser.add_argument(
        '--limit', '-l',
        type=int,
        default=None,
        help='Limit number of documents to process (for testing)'
    )

    args = parser.parse_args()

    # Register signal handlers for graceful shutdown
    signal.signal(signal.SIGINT, _signal_handler)
    signal.signal(signal.SIGTERM, _signal_handler)

    # Import here to avoid slow startup
    from ..data import CSVLoader, AutoLabelReport, FieldMatchResult
    from ..data.autolabel_report import ReportWriter
    from ..pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens
    from ..pdf.renderer import get_render_dimensions
    from ..ocr import OCREngine
    from ..matcher import FieldMatcher
    from ..normalize import normalize_field
    from ..yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES

    # Handle comma-separated CSV paths
    csv_input = args.csv
    if ',' in csv_input and '*' not in csv_input:
        csv_input = [p.strip() for p in csv_input.split(',')]

    # Get the list of CSV files (don't load all data at once)
    temp_loader = CSVLoader(csv_input, args.pdf_dir)
    csv_files = temp_loader.csv_paths
    pdf_dir = temp_loader.pdf_dir
    print(f"Found {len(csv_files)} CSV file(s) to process")

    # Setup output directories
    output_dir = Path(args.output)
    # Only create the temp directory for images (no train/val/test split during labeling)
    (output_dir / 'temp').mkdir(parents=True, exist_ok=True)

    # Report writer with optional sharding
    report_path = Path(args.report)
    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_writer = ReportWriter(args.report, max_records_per_file=args.max_records)
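
    # Illustrative only (exact naming comes from ReportWriter): with the default
    # --max-records 10000, the JSONL report is expected to be sharded roughly as
    #     autolabel_report_part000.jsonl, autolabel_report_part001.jsonl, ...
    # while --max-records 0 keeps everything in a single autolabel_report.jsonl file.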

    # Database connection for checking existing documents
    from ..data.db import DocumentDB
    db = DocumentDB()
    db.connect()
    db.create_tables()  # Ensure tables exist
    print("Connected to database for status checking")

    # Global stats
    stats = {
        'total': 0,
        'successful': 0,
        'failed': 0,
        'skipped': 0,
        'skipped_db': 0,       # Skipped because already in DB
        'retried': 0,          # Re-processed previously failed documents
        'annotations': 0,
        'tasks_submitted': 0,  # Tracks tasks submitted across all CSVs for --limit
        'by_field': {name: 0 for name in FIELD_CLASSES.keys()}
    }

    # Track all processed items for the final split (write to a temp file to save memory)
    processed_items_file = output_dir / 'temp' / 'processed_items.jsonl'
    processed_items_file.parent.mkdir(parents=True, exist_ok=True)
    processed_items_writer = open(processed_items_file, 'w', encoding='utf-8')
    processed_count = 0
    seen_doc_ids = set()

    # Batch for database updates
    db_batch = []
    DB_BATCH_SIZE = 100

    # Helper function to handle a result and update the database.
    # Defined outside the loop so nonlocal can properly reference db_batch.
    def handle_result(result):
        nonlocal processed_count, db_batch

        # Write report to file
        if result['report']:
            report_writer.write_dict(result['report'])

            # Add to database batch
            db_batch.append(result['report'])
            if len(db_batch) >= DB_BATCH_SIZE:
                db.save_documents_batch(db_batch)
                db_batch.clear()

        if result['success']:
            # Write to temp file instead of memory
            import json
            processed_items_writer.write(json.dumps({
                'doc_id': result['doc_id'],
                'pages': result['pages']
            }) + '\n')
            processed_items_writer.flush()
            processed_count += 1

            stats['successful'] += 1
            for field, count in result['stats'].items():
                stats['by_field'][field] += count
                stats['annotations'] += count
        elif 'Skipped' in str(result.get('report', {}).get('errors', [])):
            stats['skipped'] += 1
        else:
            stats['failed'] += 1

    def handle_error(doc_id, error):
        nonlocal db_batch
        stats['failed'] += 1
        error_report = {
            'document_id': doc_id,
            'success': False,
            'errors': [f"Worker error: {str(error)}"]
        }
        report_writer.write_dict(error_report)
        db_batch.append(error_report)
        if len(db_batch) >= DB_BATCH_SIZE:
            db.save_documents_batch(db_batch)
            db_batch.clear()
        if args.verbose:
            print(f"Error processing {doc_id}: {error}")

    # Initialize the dual-pool coordinator if enabled (keeps workers alive across CSVs)
    dual_pool_coordinator = None
    use_dual_pool = args.cpu_workers is not None
    if use_dual_pool:
        from src.processing import DualPoolCoordinator
        from src.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf

        print(f"\nStarting dual-pool mode: {args.cpu_workers} CPU + {args.gpu_workers} GPU workers")
        dual_pool_coordinator = DualPoolCoordinator(
            cpu_workers=args.cpu_workers,
            gpu_workers=args.gpu_workers,
            gpu_id=0,
            task_timeout=300.0,
        )
        dual_pool_coordinator.start()
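
    # In dual-pool mode, documents are routed by PDF type further below: text-layer PDFs go
    # to the CPU pool (process_text_pdf) and scanned PDFs to the GPU pool (process_scanned_pdf).
    # The coordinator is started once here and reused for every CSV so workers stay warm.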
Stopping after current batch...") break print(f"\n[{csv_idx + 1}/{len(csv_files)}] Processing: {csv_file.name}") # Load only this CSV file single_loader = CSVLoader(str(csv_file), str(pdf_dir)) rows = single_loader.load_all() # Filter to single document if specified if args.single: rows = [r for r in rows if r.DocumentId == args.single] if not rows: continue # Deduplicate across CSV files rows = [r for r in rows if r.DocumentId not in seen_doc_ids] for r in rows: seen_doc_ids.add(r.DocumentId) if not rows: print(f" Skipping CSV (no new documents)") continue # Batch query database for all document IDs in this CSV csv_doc_ids = [r.DocumentId for r in rows] db_status_map = db.check_documents_status_batch(csv_doc_ids) # Count how many are already processed successfully already_processed = sum(1 for doc_id in csv_doc_ids if db_status_map.get(doc_id) is True) # Skip entire CSV if all documents are already processed if already_processed == len(rows): print(f" Skipping CSV (all {len(rows)} documents already processed)") stats['skipped_db'] += len(rows) continue # Count how many new documents need processing in this CSV new_to_process = len(rows) - already_processed print(f" Found {new_to_process} new documents to process ({already_processed} already in DB)") stats['total'] += len(rows) # Prepare tasks for this CSV tasks = [] skipped_in_csv = 0 retry_in_csv = 0 # Calculate how many more we can process if limit is set # Use tasks_submitted counter which tracks across all CSVs if args.limit: remaining_limit = args.limit - stats.get('tasks_submitted', 0) if remaining_limit <= 0: print(f" Reached limit of {args.limit} new documents, stopping.") break else: remaining_limit = float('inf') for row in rows: # Stop adding tasks if we've reached the limit if len(tasks) >= remaining_limit: break doc_id = row.DocumentId # Check document status from batch query result db_status = db_status_map.get(doc_id) # None if not in DB # Skip if already successful in database if db_status is True: stats['skipped_db'] += 1 skipped_in_csv += 1 continue # Check if this is a retry (was failed before) if db_status is False: stats['retried'] += 1 retry_in_csv += 1 pdf_path = single_loader.get_pdf_path(row) if not pdf_path: stats['skipped'] += 1 continue row_dict = { 'DocumentId': row.DocumentId, 'InvoiceNumber': row.InvoiceNumber, 'InvoiceDate': row.InvoiceDate, 'InvoiceDueDate': row.InvoiceDueDate, 'OCR': row.OCR, 'Bankgiro': row.Bankgiro, 'Plusgiro': row.Plusgiro, 'Amount': row.Amount, # New fields 'supplier_organisation_number': row.supplier_organisation_number, 'supplier_accounts': row.supplier_accounts, # Metadata fields (not for matching, but for database storage) 'split': row.split, 'customer_number': row.customer_number, 'supplier_name': row.supplier_name, } tasks.append(( row_dict, str(pdf_path), str(output_dir), args.dpi, args.min_confidence, args.skip_ocr )) if skipped_in_csv > 0 or retry_in_csv > 0: print(f" Skipped {skipped_in_csv} (already in DB), retrying {retry_in_csv} failed") if not tasks: continue # Update tasks_submitted counter for limit tracking stats['tasks_submitted'] += len(tasks) if use_dual_pool: # Dual-pool mode using pre-initialized DualPoolCoordinator # (process_text_pdf, process_scanned_pdf already imported above) # Convert tasks to new format documents = [] for task in tasks: row_dict, pdf_path_str, output_dir_str, dpi, min_confidence, skip_ocr = task # Pre-classify PDF type try: is_text = is_text_pdf(pdf_path_str) except Exception: is_text = False documents.append({ "id": row_dict["DocumentId"], 
"row_dict": row_dict, "pdf_path": pdf_path_str, "output_dir": output_dir_str, "dpi": dpi, "min_confidence": min_confidence, "is_scanned": not is_text, "has_text": is_text, "text_length": 1000 if is_text else 0, # Approximate }) # Count task types text_count = sum(1 for d in documents if not d["is_scanned"]) scan_count = len(documents) - text_count print(f" Text PDFs: {text_count}, Scanned PDFs: {scan_count}") # Progress tracking with tqdm pbar = tqdm(total=len(documents), desc="Processing") def on_result(task_result): """Handle successful result.""" result = task_result.data handle_result(result) pbar.update(1) def on_error(task_id, error): """Handle failed task.""" handle_error(task_id, error) pbar.update(1) # Process with pre-initialized coordinator (workers stay alive) results = dual_pool_coordinator.process_batch( documents=documents, cpu_task_fn=process_text_pdf, gpu_task_fn=process_scanned_pdf, on_result=on_result, on_error=on_error, id_field="id", ) pbar.close() # Log summary successful = sum(1 for r in results if r.success) failed = len(results) - successful print(f" Batch complete: {successful} successful, {failed} failed") else: # Single-pool mode (original behavior) print(f" Processing {len(tasks)} documents with {args.workers} workers...") # Process documents in parallel (inside CSV loop for streaming) # Use single process for debugging or when workers=1 if args.workers == 1: for task in tqdm(tasks, desc="Processing"): result = process_single_document(task) handle_result(result) else: # Parallel processing with worker initialization # Each worker initializes OCR engine once and reuses it with ProcessPoolExecutor(max_workers=args.workers, initializer=_init_worker) as executor: futures = {executor.submit(process_single_document, task): task[0]['DocumentId'] for task in tasks} # Per-document timeout: 120 seconds (2 minutes) # This prevents a single stuck document from blocking the entire batch DOCUMENT_TIMEOUT = 120 for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"): doc_id = futures[future] try: result = future.result(timeout=DOCUMENT_TIMEOUT) handle_result(result) except TimeoutError: handle_error(doc_id, f"Processing timeout after {DOCUMENT_TIMEOUT}s") # Cancel the stuck future future.cancel() except Exception as e: handle_error(doc_id, e) # Flush remaining database batch after each CSV if db_batch: db.save_documents_batch(db_batch) db_batch.clear() finally: # Shutdown dual-pool coordinator if it was started if dual_pool_coordinator is not None: dual_pool_coordinator.shutdown() # Close temp file processed_items_writer.close() # Use the in-memory counter instead of re-reading the file (performance fix) # processed_count already tracks the number of successfully processed items # Cleanup processed_items temp file (not needed anymore) processed_items_file.unlink(missing_ok=True) # Close database connection db.close() # Print summary print("\n" + "=" * 50) print("Auto-labeling Complete") print("=" * 50) print(f"Total documents: {stats['total']}") print(f"Successful: {stats['successful']}") print(f"Failed: {stats['failed']}") print(f"Skipped (no PDF): {stats['skipped']}") print(f"Skipped (in DB): {stats['skipped_db']}") print(f"Retried (failed): {stats['retried']}") print(f"Total annotations: {stats['annotations']}") print(f"\nImages saved to: {output_dir / 'temp'}") print(f"Labels stored in: PostgreSQL database") print(f"\nAnnotations by field:") for field, count in stats['by_field'].items(): print(f" {field}: {count}") shard_files = 

    # Print summary
    print("\n" + "=" * 50)
    print("Auto-labeling Complete")
    print("=" * 50)
    print(f"Total documents: {stats['total']}")
    print(f"Successful: {stats['successful']}")
    print(f"Failed: {stats['failed']}")
    print(f"Skipped (no PDF): {stats['skipped']}")
    print(f"Skipped (in DB): {stats['skipped_db']}")
    print(f"Retried (failed): {stats['retried']}")
    print(f"Total annotations: {stats['annotations']}")
    print(f"\nImages saved to: {output_dir / 'temp'}")
    print(f"Labels stored in: PostgreSQL database")
    print(f"\nAnnotations by field:")
    for field, count in stats['by_field'].items():
        print(f" {field}: {count}")

    shard_files = report_writer.get_shard_files()
    if len(shard_files) > 1:
        print(f"\nReport files ({len(shard_files)}):")
        for sf in shard_files:
            print(f" - {sf}")
    else:
        print(f"\nReport: {shard_files[0] if shard_files else args.report}")


if __name__ == '__main__':
    main()
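

# Example invocations (illustrative only; the module path and data paths are placeholders,
# the flags are the ones defined in main() above):
#
#     # Label everything using the defaults from config.py
#     python -m <package>.autolabel
#
#     # Dual-pool mode: 6 CPU workers for text PDFs, 1 GPU worker for scanned PDFs
#     python -m <package>.autolabel --cpu-workers 6 --gpu-workers 1
#
#     # Quick test on one CSV, text-layer only, limited to 50 documents
#     python -m <package>.autolabel --csv data/invoices.csv --skip-ocr --limit 50 --verbose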