#!/usr/bin/env python3
"""
Auto-labeling CLI

Generates YOLO training data from PDFs and structured CSV data.
"""

import argparse
import sys
import time
import os
import signal
import warnings
from pathlib import Path
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
import multiprocessing

# Global flag for graceful shutdown
_shutdown_requested = False


def _signal_handler(signum, frame):
    """Handle interrupt signals for graceful shutdown."""
    global _shutdown_requested
    _shutdown_requested = True
    print("\n\nShutdown requested. Finishing current batch and saving progress...")
    print("(Press Ctrl+C again to force quit)\n")


# Windows compatibility: use 'spawn' method for multiprocessing.
# This is required on Windows and is also safer for libraries like PaddleOCR.
if sys.platform == 'win32':
    multiprocessing.set_start_method('spawn', force=True)

sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from config import get_db_connection_string, PATHS, AUTOLABEL

# Global OCR engine for worker processes (initialized once per worker)
_worker_ocr_engine = None
_worker_initialized = False
_worker_type = None  # 'cpu' or 'gpu'


def _init_cpu_worker():
    """Initialize CPU worker (no OCR engine needed)."""
    global _worker_initialized, _worker_type
    _worker_initialized = True
    _worker_type = 'cpu'


def _init_gpu_worker():
    """Initialize GPU worker with OCR engine (called once per worker)."""
    global _worker_ocr_engine, _worker_initialized, _worker_type

    # Suppress PaddlePaddle/PaddleX reinitialization warnings
    warnings.filterwarnings('ignore', message='.*PDX has already been initialized.*')
    warnings.filterwarnings('ignore', message='.*reinitialization.*')

    # Set environment variable to suppress paddle warnings
    os.environ['GLOG_minloglevel'] = '2'  # Suppress INFO and WARNING logs

    # OCR engine will be lazily initialized on first use
    _worker_ocr_engine = None
    _worker_initialized = True
    _worker_type = 'gpu'


def _init_worker():
    """Initialize worker process with OCR engine (called once per worker).

    Legacy initializer kept for backwards compatibility; it simply delegates
    to _init_gpu_worker().
    """
    _init_gpu_worker()


def _get_ocr_engine():
    """Get or create the OCR engine for the current worker."""
    global _worker_ocr_engine

    if _worker_ocr_engine is None:
        # Suppress warnings during OCR initialization
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore')
            from ..ocr import OCREngine
            _worker_ocr_engine = OCREngine()

    return _worker_ocr_engine


def _save_output_img(output_img, image_path: Path) -> None:
    """Save OCR output_img to replace the original rendered image."""
    from PIL import Image as PILImage

    # Convert numpy array to PIL Image and save
    if output_img is not None:
        img = PILImage.fromarray(output_img)
        img.save(str(image_path))
    # If output_img is None, the original image is already saved


def process_single_document(args_tuple):
    """
    Process a single document (worker function for parallel processing).

    Args:
        args_tuple: (row_dict, pdf_path, output_dir, dpi, min_confidence, skip_ocr)

    Returns:
        dict with results
    """
    row_dict, pdf_path_str, output_dir_str, dpi, min_confidence, skip_ocr = args_tuple

    # Import inside worker to avoid pickling issues
    from ..data import AutoLabelReport, FieldMatchResult
    from ..pdf import PDFDocument
    from ..matcher import FieldMatcher
    from ..normalize import normalize_field
    from ..yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES

    start_time = time.time()
    pdf_path = Path(pdf_path_str)
    output_dir = Path(output_dir_str)
    doc_id = row_dict['DocumentId']

    report = AutoLabelReport(document_id=doc_id)
    report.pdf_path = str(pdf_path)
    # Store metadata fields from CSV
    report.split = row_dict.get('split')
    report.customer_number = row_dict.get('customer_number')
    report.supplier_name = row_dict.get('supplier_name')
    report.supplier_organisation_number = row_dict.get('supplier_organisation_number')
    report.supplier_accounts = row_dict.get('supplier_accounts')

    result = {
        'doc_id': doc_id,
        'success': False,
        'pages': [],
        'report': None,
        'stats': {name: 0 for name in FIELD_CLASSES.keys()}
    }

    try:
        # Use PDFDocument context manager for efficient PDF handling:
        # opens the PDF only once, caches dimensions, handles cleanup automatically.
        with PDFDocument(pdf_path) as pdf_doc:
            # Check PDF type (uses cached document)
            use_ocr = not pdf_doc.is_text_pdf()
            report.pdf_type = "scanned" if use_ocr else "text"

            # Skip OCR if requested
            if use_ocr and skip_ocr:
                report.errors.append("Skipped (scanned PDF)")
                report.processing_time_ms = (time.time() - start_time) * 1000
                result['report'] = report.to_dict()
                return result

            # Get OCR engine from worker cache (only created once per worker)
            ocr_engine = None
            if use_ocr:
                ocr_engine = _get_ocr_engine()

            generator = AnnotationGenerator(min_confidence=min_confidence)
            matcher = FieldMatcher()

            # Process each page
            page_annotations = []
            matched_fields = set()

            # Render all pages and process (uses cached document handle)
            images_dir = output_dir / 'temp' / doc_id / 'images'
            for page_no, image_path in pdf_doc.render_all_pages(images_dir, dpi=dpi):
                report.total_pages += 1

                # Get dimensions from cache (no additional PDF open)
                img_width, img_height = pdf_doc.get_render_dimensions(page_no, dpi)

                # Extract tokens
                if use_ocr:
                    # Use extract_with_image to get both tokens and preprocessed image;
                    # PaddleOCR coordinates are relative to output_img, not original image.
                    ocr_result = ocr_engine.extract_with_image(
                        str(image_path),
                        page_no,
                        scale_to_pdf_points=72 / dpi
                    )
                    tokens = ocr_result.tokens

                    # Save output_img to replace the original rendered image.
                    # This ensures coordinates match the saved image.
                    _save_output_img(ocr_result.output_img, image_path)

                    # Update image dimensions to match output_img
                    if ocr_result.output_img is not None:
                        img_height, img_width = ocr_result.output_img.shape[:2]
                else:
                    # Use cached document for text extraction
                    tokens = list(pdf_doc.extract_text_tokens(page_no))

                # Match fields
                matches = {}
                for field_name in FIELD_CLASSES.keys():
                    value = row_dict.get(field_name)
                    if not value:
                        continue

                    normalized = normalize_field(field_name, str(value))
                    field_matches = matcher.find_matches(tokens, field_name, normalized, page_no)

                    # Record result
                    if field_matches:
                        best = field_matches[0]
                        matches[field_name] = field_matches
                        matched_fields.add(field_name)
                        report.add_field_result(FieldMatchResult(
                            field_name=field_name,
                            csv_value=str(value),
                            matched=True,
                            score=best.score,
                            matched_text=best.matched_text,
                            candidate_used=best.value,
                            bbox=best.bbox,
                            page_no=page_no,
                            context_keywords=best.context_keywords
                        ))

                # Match supplier_accounts and map to Bankgiro/Plusgiro
                supplier_accounts_value = row_dict.get('supplier_accounts')
                if supplier_accounts_value:
                    # Parse accounts: "BG:xxx | PG:yyy" format
                    accounts = [acc.strip() for acc in str(supplier_accounts_value).split('|')]
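                    # Illustrative example (values are invented): a cell such as
                    # "BG:123-4567 | PG:1234567-8" splits into ['BG:123-4567', 'PG:1234567-8'];
                    # the BG/PG prefix selects the Bankgiro/Plusgiro class and is stripped
                    # before matching below.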
                    for account in accounts:
                        account = account.strip()
                        if not account:
                            continue

                        # Determine account type (BG or PG) and extract account number
                        account_type = None
                        account_number = account  # Default to full value

                        if account.upper().startswith('BG:'):
                            account_type = 'Bankgiro'
                            account_number = account[3:].strip()  # Remove "BG:" prefix
                        elif account.upper().startswith('BG '):
                            account_type = 'Bankgiro'
                            account_number = account[2:].strip()  # Remove "BG" prefix
                        elif account.upper().startswith('PG:'):
                            account_type = 'Plusgiro'
                            account_number = account[3:].strip()  # Remove "PG:" prefix
                        elif account.upper().startswith('PG '):
                            account_type = 'Plusgiro'
                            account_number = account[2:].strip()  # Remove "PG" prefix
                        else:
                            # Try to guess from format - Plusgiro often has format XXXXXXX-X
                            digits = ''.join(c for c in account if c.isdigit())
                            if len(digits) == 8 and '-' in account:
                                account_type = 'Plusgiro'
                            elif len(digits) in (7, 8):
                                account_type = 'Bankgiro'  # Default to Bankgiro

                        if not account_type:
                            continue

                        # Normalize and match using the account number (without prefix)
                        normalized = normalize_field('supplier_accounts', account_number)
                        field_matches = matcher.find_matches(tokens, account_type, normalized, page_no)

                        if field_matches:
                            best = field_matches[0]
                            # Add to matches under the target class (Bankgiro/Plusgiro)
                            if account_type not in matches:
                                matches[account_type] = []
                            matches[account_type].extend(field_matches)
                            matched_fields.add('supplier_accounts')

                            report.add_field_result(FieldMatchResult(
                                field_name=f'supplier_accounts({account_type})',
                                csv_value=account_number,  # Store without prefix
                                matched=True,
                                score=best.score,
                                matched_text=best.matched_text,
                                candidate_used=best.value,
                                bbox=best.bbox,
                                page_no=page_no,
                                context_keywords=best.context_keywords
                            ))

                # Count annotations
                annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi)

                if annotations:
                    page_annotations.append({
                        'image_path': str(image_path),
                        'page_no': page_no,
                        'count': len(annotations)
                    })

                report.annotations_generated += len(annotations)
                for ann in annotations:
                    class_name = list(FIELD_CLASSES.keys())[ann.class_id]
                    result['stats'][class_name] += 1

            # Record unmatched fields
            for field_name in FIELD_CLASSES.keys():
                value = row_dict.get(field_name)
                if value and field_name not in matched_fields:
                    report.add_field_result(FieldMatchResult(
                        field_name=field_name,
                        csv_value=str(value),
                        matched=False,
                        page_no=-1
                    ))

            if page_annotations:
                result['pages'] = page_annotations
                result['success'] = True
                report.success = True
            else:
                report.errors.append("No annotations generated")

    except Exception as e:
        report.errors.append(str(e))

    report.processing_time_ms = (time.time() - start_time) * 1000
    result['report'] = report.to_dict()

    return result


def main():
    parser = argparse.ArgumentParser(
        description='Generate YOLO annotations from PDFs and CSV data'
    )
    parser.add_argument(
        '--csv', '-c',
        default=f"{PATHS['csv_dir']}/*.csv",
        help='Path to CSV file(s). Supports: single file, glob pattern (*.csv), or comma-separated list'
    )
    parser.add_argument(
        '--pdf-dir', '-p',
        default=PATHS['pdf_dir'],
        help='Directory containing PDF files'
    )
    parser.add_argument(
        '--output', '-o',
        default=PATHS['output_dir'],
        help='Output directory for dataset'
    )
    parser.add_argument(
        '--dpi',
        type=int,
        default=AUTOLABEL['dpi'],
        help=f"DPI for PDF rendering (default: {AUTOLABEL['dpi']})"
    )
    parser.add_argument(
        '--min-confidence',
        type=float,
        default=AUTOLABEL['min_confidence'],
        help=f"Minimum match confidence (default: {AUTOLABEL['min_confidence']})"
    )
    parser.add_argument(
        '--report',
        default=f"{PATHS['reports_dir']}/autolabel_report.jsonl",
        help='Path for auto-label report (JSONL). With --max-records, creates report_part000.jsonl, etc.'
    )
    parser.add_argument(
        '--max-records',
        type=int,
        default=10000,
        help='Max records per report file for sharding (default: 10000, 0 = single file)'
    )
    parser.add_argument(
        '--single',
        help='Process single document ID only'
    )
    parser.add_argument(
        '--verbose', '-v',
        action='store_true',
        help='Verbose output'
    )
    parser.add_argument(
        '--workers', '-w',
        type=int,
        default=4,
        help='Number of parallel workers (default: 4). Use --cpu-workers and --gpu-workers for dual-pool mode.'
    )
    parser.add_argument(
        '--cpu-workers',
        type=int,
        default=None,
        help='Number of CPU workers for text PDFs (enables dual-pool mode)'
    )
    parser.add_argument(
        '--gpu-workers',
        type=int,
        default=1,
        help='Number of GPU workers for scanned PDFs (default: 1, used with --cpu-workers)'
    )
    parser.add_argument(
        '--skip-ocr',
        action='store_true',
        help='Skip scanned PDFs (text-layer only)'
    )
    parser.add_argument(
        '--limit', '-l',
        type=int,
        default=None,
        help='Limit number of documents to process (for testing)'
    )

    args = parser.parse_args()

    # Register signal handlers for graceful shutdown
    signal.signal(signal.SIGINT, _signal_handler)
    signal.signal(signal.SIGTERM, _signal_handler)

    # Import here to avoid slow startup
    from ..data import CSVLoader, AutoLabelReport, FieldMatchResult
    from ..data.autolabel_report import ReportWriter
    from ..pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens
    from ..pdf.renderer import get_render_dimensions
    from ..ocr import OCREngine
    from ..matcher import FieldMatcher
    from ..normalize import normalize_field
    from ..yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES

    # Handle comma-separated CSV paths
    csv_input = args.csv
    if ',' in csv_input and '*' not in csv_input:
        csv_input = [p.strip() for p in csv_input.split(',')]

    # Get list of CSV files (don't load all data at once)
    temp_loader = CSVLoader(csv_input, args.pdf_dir)
    csv_files = temp_loader.csv_paths
    pdf_dir = temp_loader.pdf_dir
    print(f"Found {len(csv_files)} CSV file(s) to process")

    # Setup output directories
    output_dir = Path(args.output)
    # Only create temp directory for images (no train/val/test split during labeling)
    (output_dir / 'temp').mkdir(parents=True, exist_ok=True)

    # Report writer with optional sharding
    report_path = Path(args.report)
    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_writer = ReportWriter(args.report, max_records_per_file=args.max_records)

    # Database connection for checking existing documents
    from ..data.db import DocumentDB
    db = DocumentDB()
    db.connect()
    db.create_tables()  # Ensure tables exist
    print("Connected to database for status checking")

    # Global stats
    stats = {
        'total': 0,
        'successful': 0,
        'failed': 0,
        'skipped': 0,
        'skipped_db': 0,       # Skipped because already in DB
        'retried': 0,          # Re-processed failed ones
        'annotations': 0,
        'tasks_submitted': 0,  # Tracks tasks submitted across all CSVs for limit
        'by_field': {name: 0 for name in FIELD_CLASSES.keys()}
    }

    # Track all processed items for final split (write to temp file to save memory)
    processed_items_file = output_dir / 'temp' / 'processed_items.jsonl'
    processed_items_file.parent.mkdir(parents=True, exist_ok=True)
    processed_items_writer = open(processed_items_file, 'w', encoding='utf-8')
    processed_count = 0
    seen_doc_ids = set()

    # Batch for database updates
    db_batch = []
    DB_BATCH_SIZE = 100

    # Helper function to handle result and update database.
    # Defined outside the loop so nonlocal can properly reference db_batch.
    def handle_result(result):
        nonlocal processed_count, db_batch

        # Write report to file
        if result['report']:
            report_writer.write_dict(result['report'])
            # Add to database batch
            db_batch.append(result['report'])
            if len(db_batch) >= DB_BATCH_SIZE:
                db.save_documents_batch(db_batch)
                db_batch.clear()

        if result['success']:
            # Write to temp file instead of memory
            import json
            processed_items_writer.write(json.dumps({
                'doc_id': result['doc_id'],
                'pages': result['pages']
            }) + '\n')
            processed_items_writer.flush()
            processed_count += 1
            stats['successful'] += 1
            for field, count in result['stats'].items():
                stats['by_field'][field] += count
                stats['annotations'] += count
        elif 'Skipped' in str(result.get('report', {}).get('errors', [])):
            stats['skipped'] += 1
        else:
            stats['failed'] += 1

    def handle_error(doc_id, error):
        nonlocal db_batch
        stats['failed'] += 1
        error_report = {
            'document_id': doc_id,
            'success': False,
            'errors': [f"Worker error: {str(error)}"]
        }
        report_writer.write_dict(error_report)
        db_batch.append(error_report)
        if len(db_batch) >= DB_BATCH_SIZE:
            db.save_documents_batch(db_batch)
            db_batch.clear()
        if args.verbose:
            print(f"Error processing {doc_id}: {error}")

    # Initialize dual-pool coordinator if enabled (keeps workers alive across CSVs)
    dual_pool_coordinator = None
    use_dual_pool = args.cpu_workers is not None

    if use_dual_pool:
        from src.processing import DualPoolCoordinator
        from src.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf

        print(f"\nStarting dual-pool mode: {args.cpu_workers} CPU + {args.gpu_workers} GPU workers")
        dual_pool_coordinator = DualPoolCoordinator(
            cpu_workers=args.cpu_workers,
            gpu_workers=args.gpu_workers,
            gpu_id=0,
            task_timeout=300.0,
        )
        dual_pool_coordinator.start()

    try:
        # Process CSV files one by one (streaming)
        for csv_idx, csv_file in enumerate(csv_files):
            # Check for shutdown request
            if _shutdown_requested:
                print("\nShutdown requested. Stopping after current batch...")
                break

            print(f"\n[{csv_idx + 1}/{len(csv_files)}] Processing: {csv_file.name}")

            # Load only this CSV file
            single_loader = CSVLoader(str(csv_file), str(pdf_dir))
            rows = single_loader.load_all()

            # Filter to single document if specified
            if args.single:
                rows = [r for r in rows if r.DocumentId == args.single]
                if not rows:
                    continue

            # Deduplicate across CSV files
            rows = [r for r in rows if r.DocumentId not in seen_doc_ids]
            for r in rows:
                seen_doc_ids.add(r.DocumentId)

            if not rows:
                print(" Skipping CSV (no new documents)")
                continue

            # Batch query database for all document IDs in this CSV
            csv_doc_ids = [r.DocumentId for r in rows]
            db_status_map = db.check_documents_status_batch(csv_doc_ids)

            # Count how many are already processed successfully
            already_processed = sum(1 for doc_id in csv_doc_ids if db_status_map.get(doc_id) is True)

            # Skip entire CSV if all documents are already processed
            if already_processed == len(rows):
                print(f" Skipping CSV (all {len(rows)} documents already processed)")
                stats['skipped_db'] += len(rows)
                continue

            # Count how many new documents need processing in this CSV
            new_to_process = len(rows) - already_processed
            print(f" Found {new_to_process} new documents to process ({already_processed} already in DB)")

            stats['total'] += len(rows)

            # Prepare tasks for this CSV
            tasks = []
            skipped_in_csv = 0
            retry_in_csv = 0

            # Calculate how many more we can process if limit is set.
            # Use tasks_submitted counter which tracks across all CSVs.
            if args.limit:
                remaining_limit = args.limit - stats.get('tasks_submitted', 0)
                if remaining_limit <= 0:
                    print(f" Reached limit of {args.limit} new documents, stopping.")
                    break
            else:
                remaining_limit = float('inf')

            for row in rows:
                # Stop adding tasks if we've reached the limit
                if len(tasks) >= remaining_limit:
                    break

                doc_id = row.DocumentId

                # Check document status from batch query result
                db_status = db_status_map.get(doc_id)  # None if not in DB

                # Skip if already successful in database
                if db_status is True:
                    stats['skipped_db'] += 1
                    skipped_in_csv += 1
                    continue

                # Check if this is a retry (was failed before)
                if db_status is False:
                    stats['retried'] += 1
                    retry_in_csv += 1

                pdf_path = single_loader.get_pdf_path(row)
                if not pdf_path:
                    stats['skipped'] += 1
                    continue

                row_dict = {
                    'DocumentId': row.DocumentId,
                    'InvoiceNumber': row.InvoiceNumber,
                    'InvoiceDate': row.InvoiceDate,
                    'InvoiceDueDate': row.InvoiceDueDate,
                    'OCR': row.OCR,
                    'Bankgiro': row.Bankgiro,
                    'Plusgiro': row.Plusgiro,
                    'Amount': row.Amount,
                    # New fields
                    'supplier_organisation_number': row.supplier_organisation_number,
                    'supplier_accounts': row.supplier_accounts,
                    # Metadata fields (not for matching, but for database storage)
                    'split': row.split,
                    'customer_number': row.customer_number,
                    'supplier_name': row.supplier_name,
                }

                tasks.append((
                    row_dict,
                    str(pdf_path),
                    str(output_dir),
                    args.dpi,
                    args.min_confidence,
                    args.skip_ocr
                ))

            if skipped_in_csv > 0 or retry_in_csv > 0:
                print(f" Skipped {skipped_in_csv} (already in DB), retrying {retry_in_csv} failed")

            if not tasks:
                continue

            # Update tasks_submitted counter for limit tracking
            stats['tasks_submitted'] += len(tasks)

            if use_dual_pool:
                # Dual-pool mode using pre-initialized DualPoolCoordinator
                # (process_text_pdf, process_scanned_pdf already imported above)

                # Convert tasks to new format
                documents = []
                for task in tasks:
                    row_dict, pdf_path_str, output_dir_str, dpi, min_confidence, skip_ocr = task
                    # Pre-classify PDF type
                    try:
                        is_text = is_text_pdf(pdf_path_str)
                    except Exception:
                        is_text = False

                    documents.append({
                        "id": row_dict["DocumentId"],
                        "row_dict": row_dict,
                        "pdf_path": pdf_path_str,
                        "output_dir": output_dir_str,
                        "dpi": dpi,
                        "min_confidence": min_confidence,
                        "is_scanned": not is_text,
                        "has_text": is_text,
                        "text_length": 1000 if is_text else 0,  # Approximate
                    })

                # Count task types
                text_count = sum(1 for d in documents if not d["is_scanned"])
                scan_count = len(documents) - text_count
                print(f" Text PDFs: {text_count}, Scanned PDFs: {scan_count}")

                # Progress tracking with tqdm
                pbar = tqdm(total=len(documents), desc="Processing")

                def on_result(task_result):
                    """Handle successful result."""
                    result = task_result.data
                    handle_result(result)
                    pbar.update(1)

                def on_error(task_id, error):
                    """Handle failed task."""
                    handle_error(task_id, error)
                    pbar.update(1)

                # Process with pre-initialized coordinator (workers stay alive)
                results = dual_pool_coordinator.process_batch(
                    documents=documents,
                    cpu_task_fn=process_text_pdf,
                    gpu_task_fn=process_scanned_pdf,
                    on_result=on_result,
                    on_error=on_error,
                    id_field="id",
                )

                pbar.close()

                # Log summary
                successful = sum(1 for r in results if r.success)
                failed = len(results) - successful
                print(f" Batch complete: {successful} successful, {failed} failed")

            else:
                # Single-pool mode (original behavior)
                print(f" Processing {len(tasks)} documents with {args.workers} workers...")

                # Process documents in parallel (inside CSV loop for streaming).
                # Use single process for debugging or when workers=1.
                if args.workers == 1:
                    for task in tqdm(tasks, desc="Processing"):
                        result = process_single_document(task)
                        handle_result(result)
                else:
                    # Parallel processing with worker initialization:
                    # each worker initializes OCR engine once and reuses it.
                    with ProcessPoolExecutor(max_workers=args.workers, initializer=_init_worker) as executor:
                        futures = {executor.submit(process_single_document, task): task[0]['DocumentId']
                                   for task in tasks}

                        # Per-document timeout: 120 seconds (2 minutes).
                        # This prevents a single stuck document from blocking the entire batch.
                        DOCUMENT_TIMEOUT = 120

                        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"):
                            doc_id = futures[future]
                            try:
                                result = future.result(timeout=DOCUMENT_TIMEOUT)
                                handle_result(result)
                            except TimeoutError:
                                handle_error(doc_id, f"Processing timeout after {DOCUMENT_TIMEOUT}s")
                                # Cancel the stuck future
                                future.cancel()
                            except Exception as e:
                                handle_error(doc_id, e)

            # Flush remaining database batch after each CSV
            if db_batch:
                db.save_documents_batch(db_batch)
                db_batch.clear()

    finally:
        # Shutdown dual-pool coordinator if it was started
        if dual_pool_coordinator is not None:
            dual_pool_coordinator.shutdown()

    # Close temp file
    processed_items_writer.close()

    # Use the in-memory counter instead of re-reading the file (performance fix):
    # processed_count already tracks the number of successfully processed items.

    # Cleanup processed_items temp file (not needed anymore)
    processed_items_file.unlink(missing_ok=True)

    # Close database connection
    db.close()

    # Print summary
    print("\n" + "=" * 50)
    print("Auto-labeling Complete")
    print("=" * 50)
    print(f"Total documents: {stats['total']}")
    print(f"Successful: {stats['successful']}")
    print(f"Failed: {stats['failed']}")
    print(f"Skipped (no PDF): {stats['skipped']}")
    print(f"Skipped (in DB): {stats['skipped_db']}")
    print(f"Retried (failed): {stats['retried']}")
    print(f"Total annotations: {stats['annotations']}")
    print(f"\nImages saved to: {output_dir / 'temp'}")
    print("Labels stored in: PostgreSQL database")
    print("\nAnnotations by field:")
    for field, count in stats['by_field'].items():
        print(f" {field}: {count}")
    shard_files = report_writer.get_shard_files()
    if len(shard_files) > 1:
        print(f"\nReport files ({len(shard_files)}):")
        for sf in shard_files:
            print(f" - {sf}")
    else:
        print(f"\nReport: {shard_files[0] if shard_files else args.report}")


if __name__ == '__main__':
    main()