Yaojia Wang
2026-01-13 00:10:27 +01:00
parent 1b7c61cdd8
commit b26fd61852
43 changed files with 7751 additions and 578 deletions

@@ -8,31 +8,83 @@ Generates YOLO training data from PDFs and structured CSV data.
import argparse
import sys
import time
import os
import warnings
from pathlib import Path
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing
# Windows compatibility: use 'spawn' method for multiprocessing
# This is required on Windows and is also safer for libraries like PaddleOCR
if sys.platform == 'win32':
multiprocessing.set_start_method('spawn', force=True)
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from config import get_db_connection_string, PATHS, AUTOLABEL
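For context, a minimal sketch of the spawn start-method pattern the Windows block above relies on (the worker function here is illustrative, not from this repository): with 'spawn', child processes start from a fresh interpreter and re-import the entry module, so pool creation has to sit behind the __main__ guard to avoid recursive spawning.
# Sketch of the spawn start-method pattern; _square is a hypothetical task.
import multiprocessing
import sys
from concurrent.futures import ProcessPoolExecutor

def _square(x: int) -> int:
    # Real workers would run OCR / PDF parsing instead of arithmetic.
    return x * x

if __name__ == "__main__":
    if sys.platform == "win32":
        multiprocessing.set_start_method("spawn", force=True)
    with ProcessPoolExecutor(max_workers=2) as pool:
        print(list(pool.map(_square, range(4))))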
# Global OCR engine for worker processes (initialized once per worker)
_worker_ocr_engine = None
_worker_initialized = False
_worker_type = None # 'cpu' or 'gpu'
def _init_cpu_worker():
"""Initialize CPU worker (no OCR engine needed)."""
global _worker_initialized, _worker_type
_worker_initialized = True
_worker_type = 'cpu'
def _init_gpu_worker():
"""Initialize GPU worker with OCR engine (called once per worker)."""
global _worker_ocr_engine, _worker_initialized, _worker_type
# Suppress PaddlePaddle/PaddleX reinitialization warnings
warnings.filterwarnings('ignore', message='.*PDX has already been initialized.*')
warnings.filterwarnings('ignore', message='.*reinitialization.*')
# Set environment variable to suppress paddle warnings
os.environ['GLOG_minloglevel'] = '2' # Suppress INFO and WARNING logs
# OCR engine will be lazily initialized on first use
_worker_ocr_engine = None
_worker_initialized = True
_worker_type = 'gpu'
def _init_worker():
"""Initialize worker process with OCR engine (called once per worker)."""
global _worker_ocr_engine
# OCR engine will be lazily initialized on first use
_worker_ocr_engine = None
"""Initialize worker process with OCR engine (called once per worker).
Legacy function for backwards compatibility.
"""
_init_gpu_worker()
def _get_ocr_engine():
"""Get or create OCR engine for current worker."""
global _worker_ocr_engine
if _worker_ocr_engine is None:
from ..ocr import OCREngine
_worker_ocr_engine = OCREngine()
# Suppress warnings during OCR initialization
with warnings.catch_warnings():
warnings.filterwarnings('ignore')
from ..ocr import OCREngine
_worker_ocr_engine = OCREngine()
return _worker_ocr_engine
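The initializer/getter pair above keeps one OCR engine alive per worker process and defers construction until a worker actually needs it. A stripped-down sketch of that pattern, with a hypothetical ExpensiveEngine standing in for the PaddleOCR-backed OCREngine:
# Per-worker lazy singleton sketch; ExpensiveEngine is illustrative only.
from concurrent.futures import ProcessPoolExecutor

_engine = None  # one instance per worker process

class ExpensiveEngine:
    def run(self, item):
        return f"processed {item}"

def _init_worker_sketch():
    # The initializer only resets the slot; construction is deferred to first
    # use, so workers that never see an OCR task never pay the startup cost.
    global _engine
    _engine = None

def _get_engine():
    global _engine
    if _engine is None:
        _engine = ExpensiveEngine()
    return _engine

def _task(item):
    return _get_engine().run(item)

if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=2, initializer=_init_worker_sketch) as pool:
        print(list(pool.map(_task, range(3))))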
def _save_output_img(output_img, image_path: Path) -> None:
"""Save OCR output_img to replace the original rendered image."""
from PIL import Image as PILImage
# Convert numpy array to PIL Image and save
if output_img is not None:
img = PILImage.fromarray(output_img)
img.save(str(image_path))
# If output_img is None, the original image is already saved
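The OCR path below passes scale_to_pdf_points=72 / dpi so token boxes measured in rendered-image pixels come back in PDF points (1/72 inch). A small worked example of that conversion, assuming (x0, y0, x1, y1) pixel boxes:
# Worked example of the 72 / dpi scaling; the bbox tuple format is an assumption.
def pixels_to_pdf_points(bbox, dpi=300):
    # A page rendered at `dpi` has dpi pixels per inch; PDF points are 1/72 inch,
    # so multiplying pixel coordinates by 72 / dpi yields point coordinates.
    scale = 72 / dpi
    return tuple(round(v * scale, 2) for v in bbox)

# At 300 DPI, a 600x300-pixel box maps to a 144x72-point box.
print(pixels_to_pdf_points((0, 0, 600, 300), dpi=300))  # (0.0, 0.0, 144.0, 72.0)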
def process_single_document(args_tuple):
"""
Process a single document (worker function for parallel processing).
@@ -47,8 +99,7 @@ def process_single_document(args_tuple):
# Import inside worker to avoid pickling issues
from ..data import AutoLabelReport, FieldMatchResult
from ..pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens
from ..pdf.renderer import get_render_dimensions
from ..pdf import PDFDocument
from ..matcher import FieldMatcher
from ..normalize import normalize_field
from ..yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES
@@ -70,98 +121,121 @@ def process_single_document(args_tuple):
}
try:
# Check PDF type
use_ocr = not is_text_pdf(pdf_path)
report.pdf_type = "scanned" if use_ocr else "text"
# Use PDFDocument context manager for efficient PDF handling
# Opens PDF only once, caches dimensions, handles cleanup automatically
with PDFDocument(pdf_path) as pdf_doc:
# Check PDF type (uses cached document)
use_ocr = not pdf_doc.is_text_pdf()
report.pdf_type = "scanned" if use_ocr else "text"
# Skip OCR if requested
if use_ocr and skip_ocr:
report.errors.append("Skipped (scanned PDF)")
report.processing_time_ms = (time.time() - start_time) * 1000
result['report'] = report.to_dict()
return result
# Skip OCR if requested
if use_ocr and skip_ocr:
report.errors.append("Skipped (scanned PDF)")
report.processing_time_ms = (time.time() - start_time) * 1000
result['report'] = report.to_dict()
return result
# Get OCR engine from worker cache (only created once per worker)
ocr_engine = None
if use_ocr:
ocr_engine = _get_ocr_engine()
generator = AnnotationGenerator(min_confidence=min_confidence)
matcher = FieldMatcher()
# Process each page
page_annotations = []
for page_no, image_path in render_pdf_to_images(
pdf_path,
output_dir / 'temp' / doc_id / 'images',
dpi=dpi
):
report.total_pages += 1
img_width, img_height = get_render_dimensions(pdf_path, page_no, dpi)
# Extract tokens
# Get OCR engine from worker cache (only created once per worker)
ocr_engine = None
if use_ocr:
tokens = ocr_engine.extract_from_image(str(image_path), page_no)
else:
tokens = list(extract_text_tokens(pdf_path, page_no))
ocr_engine = _get_ocr_engine()
# Match fields
matches = {}
generator = AnnotationGenerator(min_confidence=min_confidence)
matcher = FieldMatcher()
# Process each page
page_annotations = []
matched_fields = set()
# Render all pages and process (uses cached document handle)
images_dir = output_dir / 'temp' / doc_id / 'images'
for page_no, image_path in pdf_doc.render_all_pages(images_dir, dpi=dpi):
report.total_pages += 1
# Get dimensions from cache (no additional PDF open)
img_width, img_height = pdf_doc.get_render_dimensions(page_no, dpi)
# Extract tokens
if use_ocr:
# Use extract_with_image to get both tokens and preprocessed image
# PaddleOCR coordinates are relative to output_img, not original image
ocr_result = ocr_engine.extract_with_image(
str(image_path),
page_no,
scale_to_pdf_points=72 / dpi
)
tokens = ocr_result.tokens
# Save output_img to replace the original rendered image
# This ensures coordinates match the saved image
_save_output_img(ocr_result.output_img, image_path)
# Update image dimensions to match output_img
if ocr_result.output_img is not None:
img_height, img_width = ocr_result.output_img.shape[:2]
else:
# Use cached document for text extraction
tokens = list(pdf_doc.extract_text_tokens(page_no))
# Match fields
matches = {}
for field_name in FIELD_CLASSES.keys():
value = row_dict.get(field_name)
if not value:
continue
normalized = normalize_field(field_name, str(value))
field_matches = matcher.find_matches(tokens, field_name, normalized, page_no)
# Record result
if field_matches:
best = field_matches[0]
matches[field_name] = field_matches
matched_fields.add(field_name)
report.add_field_result(FieldMatchResult(
field_name=field_name,
csv_value=str(value),
matched=True,
score=best.score,
matched_text=best.matched_text,
candidate_used=best.value,
bbox=best.bbox,
page_no=page_no,
context_keywords=best.context_keywords
))
# Count annotations
annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi)
if annotations:
page_annotations.append({
'image_path': str(image_path),
'page_no': page_no,
'count': len(annotations)
})
report.annotations_generated += len(annotations)
for ann in annotations:
class_name = list(FIELD_CLASSES.keys())[ann.class_id]
result['stats'][class_name] += 1
# Record unmatched fields
for field_name in FIELD_CLASSES.keys():
value = row_dict.get(field_name)
if not value:
continue
normalized = normalize_field(field_name, str(value))
field_matches = matcher.find_matches(tokens, field_name, normalized, page_no)
# Record result
if field_matches:
best = field_matches[0]
matches[field_name] = field_matches
report.add_field_result(FieldMatchResult(
field_name=field_name,
csv_value=str(value),
matched=True,
score=best.score,
matched_text=best.matched_text,
candidate_used=best.value,
bbox=best.bbox,
page_no=page_no,
context_keywords=best.context_keywords
))
else:
if value and field_name not in matched_fields:
report.add_field_result(FieldMatchResult(
field_name=field_name,
csv_value=str(value),
matched=False,
page_no=page_no
page_no=-1
))
# Generate annotations
annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi)
if annotations:
label_path = output_dir / 'temp' / doc_id / 'labels' / f"{image_path.stem}.txt"
generator.save_annotations(annotations, label_path)
page_annotations.append({
'image_path': str(image_path),
'label_path': str(label_path),
'count': len(annotations)
})
report.annotations_generated += len(annotations)
for ann in annotations:
class_name = list(FIELD_CLASSES.keys())[ann.class_id]
result['stats'][class_name] += 1
if page_annotations:
result['pages'] = page_annotations
result['success'] = True
report.success = True
else:
report.errors.append("No annotations generated")
if page_annotations:
result['pages'] = page_annotations
result['success'] = True
report.success = True
else:
report.errors.append("No annotations generated")
except Exception as e:
report.errors.append(str(e))
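process_single_document leans on a PDFDocument context manager exposing is_text_pdf(), render_all_pages(), get_render_dimensions() and extract_text_tokens(); its implementation is not part of this hunk. Below is a minimal sketch of that interface on top of PyMuPDF, purely as an assumption about how such a wrapper could look:
# Sketch only (assumption, not the repository's implementation): open the file
# once, keep the handle, and serve the methods called above from it.
from pathlib import Path
import fitz  # PyMuPDF

class PDFDocumentSketch:
    def __init__(self, pdf_path):
        self.pdf_path = Path(pdf_path)
        self.doc = None

    def __enter__(self):
        self.doc = fitz.open(self.pdf_path)
        return self

    def __exit__(self, *exc):
        self.doc.close()

    def is_text_pdf(self, min_chars=50):
        # Heuristic: treat the PDF as text-based if the first page carries text.
        return len(self.doc[0].get_text("text").strip()) >= min_chars

    def get_render_dimensions(self, page_no, dpi):
        rect = self.doc[page_no].rect
        return int(rect.width * dpi / 72), int(rect.height * dpi / 72)

    def render_all_pages(self, images_dir, dpi=300):
        images_dir = Path(images_dir)
        images_dir.mkdir(parents=True, exist_ok=True)
        zoom = dpi / 72
        for page_no, page in enumerate(self.doc):
            image_path = images_dir / f"page_{page_no:03d}.png"
            page.get_pixmap(matrix=fitz.Matrix(zoom, zoom)).save(str(image_path))
            yield page_no, image_path

    def extract_text_tokens(self, page_no):
        # PyMuPDF "words": (x0, y0, x1, y1, text, block, line, word_no) in points.
        for x0, y0, x1, y1, text, *_ in self.doc[page_no].get_text("words"):
            yield {"text": text, "bbox": (x0, y0, x1, y1), "page_no": page_no}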
@@ -178,47 +252,41 @@ def main():
)
parser.add_argument(
'--csv', '-c',
default='data/structured_data/document_export_20260109_212743.csv',
help='Path to structured data CSV file'
default=f"{PATHS['csv_dir']}/*.csv",
help='Path to CSV file(s). Supports: single file, glob pattern (*.csv), or comma-separated list'
)
parser.add_argument(
'--pdf-dir', '-p',
default='data/raw_pdfs',
default=PATHS['pdf_dir'],
help='Directory containing PDF files'
)
parser.add_argument(
'--output', '-o',
default='data/dataset',
default=PATHS['output_dir'],
help='Output directory for dataset'
)
parser.add_argument(
'--dpi',
type=int,
default=300,
help='DPI for PDF rendering (default: 300)'
default=AUTOLABEL['dpi'],
help=f"DPI for PDF rendering (default: {AUTOLABEL['dpi']})"
)
parser.add_argument(
'--min-confidence',
type=float,
default=0.7,
help='Minimum match confidence (default: 0.7)'
)
parser.add_argument(
'--train-ratio',
type=float,
default=0.8,
help='Training set ratio (default: 0.8)'
)
parser.add_argument(
'--val-ratio',
type=float,
default=0.1,
help='Validation set ratio (default: 0.1)'
default=AUTOLABEL['min_confidence'],
help=f"Minimum match confidence (default: {AUTOLABEL['min_confidence']})"
)
parser.add_argument(
'--report',
default='reports/autolabel_report.jsonl',
help='Path for auto-label report (JSONL)'
default=f"{PATHS['reports_dir']}/autolabel_report.jsonl",
help='Path for auto-label report (JSONL). With --max-records, creates report_part000.jsonl, etc.'
)
parser.add_argument(
'--max-records',
type=int,
default=10000,
help='Max records per report file for sharding (default: 10000, 0 = single file)'
)
parser.add_argument(
'--single',
@@ -233,20 +301,37 @@ def main():
'--workers', '-w',
type=int,
default=4,
help='Number of parallel workers (default: 4)'
help='Number of parallel workers (default: 4). Use --cpu-workers and --gpu-workers for dual-pool mode.'
)
parser.add_argument(
'--cpu-workers',
type=int,
default=None,
help='Number of CPU workers for text PDFs (enables dual-pool mode)'
)
parser.add_argument(
'--gpu-workers',
type=int,
default=1,
help='Number of GPU workers for scanned PDFs (default: 1, used with --cpu-workers)'
)
parser.add_argument(
'--skip-ocr',
action='store_true',
help='Skip scanned PDFs (text-layer only)'
)
parser.add_argument(
'--limit', '-l',
type=int,
default=None,
help='Limit number of documents to process (for testing)'
)
args = parser.parse_args()
# Import here to avoid slow startup
from ..data import CSVLoader, AutoLabelReport, FieldMatchResult
from ..data.autolabel_report import ReportWriter
from ..yolo import DatasetBuilder
from ..pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens
from ..pdf.renderer import get_render_dimensions
from ..ocr import OCREngine
@@ -254,185 +339,343 @@ def main():
from ..normalize import normalize_field
from ..yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES
print(f"Loading CSV data from: {args.csv}")
loader = CSVLoader(args.csv, args.pdf_dir)
# Handle comma-separated CSV paths
csv_input = args.csv
if ',' in csv_input and '*' not in csv_input:
csv_input = [p.strip() for p in csv_input.split(',')]
# Validate data
issues = loader.validate()
if issues:
print(f"Warning: Found {len(issues)} validation issues")
if args.verbose:
for issue in issues[:10]:
print(f" - {issue}")
rows = loader.load_all()
print(f"Loaded {len(rows)} invoice records")
# Filter to single document if specified
if args.single:
rows = [r for r in rows if r.DocumentId == args.single]
if not rows:
print(f"Error: Document {args.single} not found")
sys.exit(1)
print(f"Processing single document: {args.single}")
# Get list of CSV files (don't load all data at once)
temp_loader = CSVLoader(csv_input, args.pdf_dir)
csv_files = temp_loader.csv_paths
pdf_dir = temp_loader.pdf_dir
print(f"Found {len(csv_files)} CSV file(s) to process")
# Setup output directories
output_dir = Path(args.output)
for split in ['train', 'val', 'test']:
(output_dir / split / 'images').mkdir(parents=True, exist_ok=True)
(output_dir / split / 'labels').mkdir(parents=True, exist_ok=True)
# Only create temp directory for images (no train/val/test split during labeling)
(output_dir / 'temp').mkdir(parents=True, exist_ok=True)
# Generate YOLO config files
AnnotationGenerator.generate_classes_file(output_dir / 'classes.txt')
AnnotationGenerator.generate_yaml_config(output_dir / 'dataset.yaml')
# Report writer
# Report writer with optional sharding
report_path = Path(args.report)
report_path.parent.mkdir(parents=True, exist_ok=True)
report_writer = ReportWriter(args.report)
report_writer = ReportWriter(args.report, max_records_per_file=args.max_records)
# Stats
# Database connection for checking existing documents
from ..data.db import DocumentDB
db = DocumentDB()
db.connect()
print("Connected to database for status checking")
# Global stats
stats = {
'total': len(rows),
'total': 0,
'successful': 0,
'failed': 0,
'skipped': 0,
'skipped_db': 0, # Skipped because already in DB
'retried': 0, # Re-processed failed ones
'annotations': 0,
'tasks_submitted': 0, # Tracks tasks submitted across all CSVs for limit
'by_field': {name: 0 for name in FIELD_CLASSES.keys()}
}
# Prepare tasks
tasks = []
for row in rows:
pdf_path = loader.get_pdf_path(row)
if not pdf_path:
# Write report for missing PDF
report = AutoLabelReport(document_id=row.DocumentId)
report.errors.append("PDF not found")
report_writer.write(report)
# Track all processed items for final split (write to temp file to save memory)
processed_items_file = output_dir / 'temp' / 'processed_items.jsonl'
processed_items_file.parent.mkdir(parents=True, exist_ok=True)
processed_items_writer = open(processed_items_file, 'w', encoding='utf-8')
processed_count = 0
seen_doc_ids = set()
# Batch for database updates
db_batch = []
DB_BATCH_SIZE = 100
# Helper function to handle result and update database
# Defined outside the loop so nonlocal can properly reference db_batch
def handle_result(result):
nonlocal processed_count, db_batch
# Write report to file
if result['report']:
report_writer.write_dict(result['report'])
# Add to database batch
db_batch.append(result['report'])
if len(db_batch) >= DB_BATCH_SIZE:
db.save_documents_batch(db_batch)
db_batch.clear()
if result['success']:
# Write to temp file instead of memory
import json
processed_items_writer.write(json.dumps({
'doc_id': result['doc_id'],
'pages': result['pages']
}) + '\n')
processed_items_writer.flush()
processed_count += 1
stats['successful'] += 1
for field, count in result['stats'].items():
stats['by_field'][field] += count
stats['annotations'] += count
elif 'Skipped' in str(result.get('report', {}).get('errors', [])):
stats['skipped'] += 1
else:
stats['failed'] += 1
continue
# Convert row to dict for pickling
row_dict = {
'DocumentId': row.DocumentId,
'InvoiceNumber': row.InvoiceNumber,
'InvoiceDate': row.InvoiceDate,
'InvoiceDueDate': row.InvoiceDueDate,
'OCR': row.OCR,
'Bankgiro': row.Bankgiro,
'Plusgiro': row.Plusgiro,
'Amount': row.Amount,
def handle_error(doc_id, error):
nonlocal db_batch
stats['failed'] += 1
error_report = {
'document_id': doc_id,
'success': False,
'errors': [f"Worker error: {str(error)}"]
}
report_writer.write_dict(error_report)
db_batch.append(error_report)
if len(db_batch) >= DB_BATCH_SIZE:
db.save_documents_batch(db_batch)
db_batch.clear()
if args.verbose:
print(f"Error processing {doc_id}: {error}")
tasks.append((
row_dict,
str(pdf_path),
str(output_dir),
args.dpi,
args.min_confidence,
args.skip_ocr
))
# Initialize dual-pool coordinator if enabled (keeps workers alive across CSVs)
dual_pool_coordinator = None
use_dual_pool = args.cpu_workers is not None
print(f"Processing {len(tasks)} documents with {args.workers} workers...")
if use_dual_pool:
from src.processing import DualPoolCoordinator
from src.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf
# Process documents in parallel
processed_items = []
print(f"\nStarting dual-pool mode: {args.cpu_workers} CPU + {args.gpu_workers} GPU workers")
dual_pool_coordinator = DualPoolCoordinator(
cpu_workers=args.cpu_workers,
gpu_workers=args.gpu_workers,
gpu_id=0,
task_timeout=300.0,
)
dual_pool_coordinator.start()
# Use single process for debugging or when workers=1
if args.workers == 1:
for task in tqdm(tasks, desc="Processing"):
result = process_single_document(task)
try:
# Process CSV files one by one (streaming)
for csv_idx, csv_file in enumerate(csv_files):
print(f"\n[{csv_idx + 1}/{len(csv_files)}] Processing: {csv_file.name}")
# Write report
if result['report']:
report_writer.write_dict(result['report'])
# Load only this CSV file
single_loader = CSVLoader(str(csv_file), str(pdf_dir))
rows = single_loader.load_all()
if result['success']:
processed_items.append({
'doc_id': result['doc_id'],
'pages': result['pages']
})
stats['successful'] += 1
for field, count in result['stats'].items():
stats['by_field'][field] += count
stats['annotations'] += count
elif 'Skipped' in str(result.get('report', {}).get('errors', [])):
stats['skipped'] += 1
# Filter to single document if specified
if args.single:
rows = [r for r in rows if r.DocumentId == args.single]
if not rows:
continue
# Deduplicate across CSV files
rows = [r for r in rows if r.DocumentId not in seen_doc_ids]
for r in rows:
seen_doc_ids.add(r.DocumentId)
if not rows:
print(f" Skipping CSV (no new documents)")
continue
# Batch query database for all document IDs in this CSV
csv_doc_ids = [r.DocumentId for r in rows]
db_status_map = db.check_documents_status_batch(csv_doc_ids)
# Count how many are already processed successfully
already_processed = sum(1 for doc_id in csv_doc_ids if db_status_map.get(doc_id) is True)
# Skip entire CSV if all documents are already processed
if already_processed == len(rows):
print(f" Skipping CSV (all {len(rows)} documents already processed)")
stats['skipped_db'] += len(rows)
continue
# Count how many new documents need processing in this CSV
new_to_process = len(rows) - already_processed
print(f" Found {new_to_process} new documents to process ({already_processed} already in DB)")
stats['total'] += len(rows)
# Prepare tasks for this CSV
tasks = []
skipped_in_csv = 0
retry_in_csv = 0
# Calculate how many more we can process if limit is set
# Use tasks_submitted counter which tracks across all CSVs
if args.limit:
remaining_limit = args.limit - stats.get('tasks_submitted', 0)
if remaining_limit <= 0:
print(f" Reached limit of {args.limit} new documents, stopping.")
break
else:
stats['failed'] += 1
else:
# Parallel processing with worker initialization
# Each worker initializes OCR engine once and reuses it
with ProcessPoolExecutor(max_workers=args.workers, initializer=_init_worker) as executor:
futures = {executor.submit(process_single_document, task): task[0]['DocumentId']
for task in tasks}
remaining_limit = float('inf')
for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"):
doc_id = futures[future]
try:
result = future.result()
for row in rows:
# Stop adding tasks if we've reached the limit
if len(tasks) >= remaining_limit:
break
# Write report
if result['report']:
report_writer.write_dict(result['report'])
doc_id = row.DocumentId
if result['success']:
processed_items.append({
'doc_id': result['doc_id'],
'pages': result['pages']
})
stats['successful'] += 1
for field, count in result['stats'].items():
stats['by_field'][field] += count
stats['annotations'] += count
elif 'Skipped' in str(result.get('report', {}).get('errors', [])):
stats['skipped'] += 1
else:
stats['failed'] += 1
# Check document status from batch query result
db_status = db_status_map.get(doc_id) # None if not in DB
except Exception as e:
stats['failed'] += 1
# Write error report for failed documents
error_report = {
'document_id': doc_id,
'success': False,
'errors': [f"Worker error: {str(e)}"]
}
report_writer.write_dict(error_report)
if args.verbose:
print(f"Error processing {doc_id}: {e}")
# Skip if already successful in database
if db_status is True:
stats['skipped_db'] += 1
skipped_in_csv += 1
continue
# Split and move files
import random
random.seed(42)
random.shuffle(processed_items)
# Check if this is a retry (was failed before)
if db_status is False:
stats['retried'] += 1
retry_in_csv += 1
n_train = int(len(processed_items) * args.train_ratio)
n_val = int(len(processed_items) * args.val_ratio)
pdf_path = single_loader.get_pdf_path(row)
if not pdf_path:
stats['skipped'] += 1
continue
splits = {
'train': processed_items[:n_train],
'val': processed_items[n_train:n_train + n_val],
'test': processed_items[n_train + n_val:]
}
row_dict = {
'DocumentId': row.DocumentId,
'InvoiceNumber': row.InvoiceNumber,
'InvoiceDate': row.InvoiceDate,
'InvoiceDueDate': row.InvoiceDueDate,
'OCR': row.OCR,
'Bankgiro': row.Bankgiro,
'Plusgiro': row.Plusgiro,
'Amount': row.Amount,
}
import shutil
for split_name, items in splits.items():
for item in items:
for page in item['pages']:
# Move image
image_path = Path(page['image_path'])
label_path = Path(page['label_path'])
dest_img = output_dir / split_name / 'images' / image_path.name
shutil.move(str(image_path), str(dest_img))
tasks.append((
row_dict,
str(pdf_path),
str(output_dir),
args.dpi,
args.min_confidence,
args.skip_ocr
))
# Move label
dest_label = output_dir / split_name / 'labels' / label_path.name
shutil.move(str(label_path), str(dest_label))
if skipped_in_csv > 0 or retry_in_csv > 0:
print(f" Skipped {skipped_in_csv} (already in DB), retrying {retry_in_csv} failed")
# Cleanup temp
shutil.rmtree(output_dir / 'temp', ignore_errors=True)
if not tasks:
continue
# Update tasks_submitted counter for limit tracking
stats['tasks_submitted'] += len(tasks)
if use_dual_pool:
# Dual-pool mode using pre-initialized DualPoolCoordinator
# (process_text_pdf, process_scanned_pdf already imported above)
# Convert tasks to new format
documents = []
for task in tasks:
row_dict, pdf_path_str, output_dir_str, dpi, min_confidence, skip_ocr = task
# Pre-classify PDF type
try:
is_text = is_text_pdf(pdf_path_str)
except Exception:
is_text = False
documents.append({
"id": row_dict["DocumentId"],
"row_dict": row_dict,
"pdf_path": pdf_path_str,
"output_dir": output_dir_str,
"dpi": dpi,
"min_confidence": min_confidence,
"is_scanned": not is_text,
"has_text": is_text,
"text_length": 1000 if is_text else 0, # Approximate
})
# Count task types
text_count = sum(1 for d in documents if not d["is_scanned"])
scan_count = len(documents) - text_count
print(f" Text PDFs: {text_count}, Scanned PDFs: {scan_count}")
# Progress tracking with tqdm
pbar = tqdm(total=len(documents), desc="Processing")
def on_result(task_result):
"""Handle successful result."""
result = task_result.data
handle_result(result)
pbar.update(1)
def on_error(task_id, error):
"""Handle failed task."""
handle_error(task_id, error)
pbar.update(1)
# Process with pre-initialized coordinator (workers stay alive)
results = dual_pool_coordinator.process_batch(
documents=documents,
cpu_task_fn=process_text_pdf,
gpu_task_fn=process_scanned_pdf,
on_result=on_result,
on_error=on_error,
id_field="id",
)
pbar.close()
# Log summary
successful = sum(1 for r in results if r.success)
failed = len(results) - successful
print(f" Batch complete: {successful} successful, {failed} failed")
else:
# Single-pool mode (original behavior)
print(f" Processing {len(tasks)} documents with {args.workers} workers...")
# Process documents in parallel (inside CSV loop for streaming)
# Use single process for debugging or when workers=1
if args.workers == 1:
for task in tqdm(tasks, desc="Processing"):
result = process_single_document(task)
handle_result(result)
else:
# Parallel processing with worker initialization
# Each worker initializes OCR engine once and reuses it
with ProcessPoolExecutor(max_workers=args.workers, initializer=_init_worker) as executor:
futures = {executor.submit(process_single_document, task): task[0]['DocumentId']
for task in tasks}
for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"):
doc_id = futures[future]
try:
result = future.result()
handle_result(result)
except Exception as e:
handle_error(doc_id, e)
# Flush remaining database batch after each CSV
if db_batch:
db.save_documents_batch(db_batch)
db_batch.clear()
finally:
# Shutdown dual-pool coordinator if it was started
if dual_pool_coordinator is not None:
dual_pool_coordinator.shutdown()
# Close temp file
processed_items_writer.close()
# Use the in-memory counter instead of re-reading the file (performance fix)
# processed_count already tracks the number of successfully processed items
# Cleanup processed_items temp file (not needed anymore)
processed_items_file.unlink(missing_ok=True)
# Close database connection
db.close()
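The dual-pool mode above routes text PDFs to a wide CPU pool and scanned PDFs to a small GPU pool, funnelling completions through the shared handle_result/handle_error callbacks. A simplified stand-in for that routing (the real DualPoolCoordinator additionally takes gpu_id and task_timeout and keeps its workers alive across CSV files):
# Simplified stand-in, not the repository's DualPoolCoordinator.
from concurrent.futures import ProcessPoolExecutor, as_completed

def process_batch_sketch(documents, cpu_task_fn, gpu_task_fn,
                         cpu_workers=4, gpu_workers=1,
                         on_result=None, on_error=None):
    with ProcessPoolExecutor(max_workers=cpu_workers) as cpu_pool, \
            ProcessPoolExecutor(max_workers=gpu_workers) as gpu_pool:
        futures = {}
        for doc in documents:
            # Route by the pre-classified PDF type attached to each document.
            pool, fn = (gpu_pool, gpu_task_fn) if doc["is_scanned"] else (cpu_pool, cpu_task_fn)
            futures[pool.submit(fn, doc)] = doc["id"]
        for future in as_completed(futures):
            doc_id = futures[future]
            try:
                result = future.result()
                if on_result:
                    on_result(result)
            except Exception as exc:  # mirrors on_error(task_id, error)
                if on_error:
                    on_error(doc_id, exc)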
# Print summary
print("\n" + "=" * 50)
@@ -441,17 +684,22 @@ def main():
print(f"Total documents: {stats['total']}")
print(f"Successful: {stats['successful']}")
print(f"Failed: {stats['failed']}")
print(f"Skipped (OCR): {stats['skipped']}")
print(f"Skipped (no PDF): {stats['skipped']}")
print(f"Skipped (in DB): {stats['skipped_db']}")
print(f"Retried (failed): {stats['retried']}")
print(f"Total annotations: {stats['annotations']}")
print(f"\nDataset split:")
print(f" Train: {len(splits['train'])} documents")
print(f" Val: {len(splits['val'])} documents")
print(f" Test: {len(splits['test'])} documents")
print(f"\nImages saved to: {output_dir / 'temp'}")
print(f"Labels stored in: PostgreSQL database")
print(f"\nAnnotations by field:")
for field, count in stats['by_field'].items():
print(f" {field}: {count}")
print(f"\nOutput: {output_dir}")
print(f"Report: {args.report}")
shard_files = report_writer.get_shard_files()
if len(shard_files) > 1:
print(f"\nReport files ({len(shard_files)}):")
for sf in shard_files:
print(f" - {sf}")
else:
print(f"\nReport: {shard_files[0] if shard_files else args.report}")
if __name__ == '__main__':