WOP
@@ -8,31 +8,83 @@ Generates YOLO training data from PDFs and structured CSV data.
import argparse
import sys
import time
import os
import warnings
from pathlib import Path
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing

# Windows compatibility: use 'spawn' method for multiprocessing
# This is required on Windows and is also safer for libraries like PaddleOCR
if sys.platform == 'win32':
    multiprocessing.set_start_method('spawn', force=True)

sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from config import get_db_connection_string, PATHS, AUTOLABEL

# Global OCR engine for worker processes (initialized once per worker)
_worker_ocr_engine = None
_worker_initialized = False
_worker_type = None  # 'cpu' or 'gpu'

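# Illustrative, self-contained sketch (hypothetical helper names, not from this commit) of
# why 'spawn' is used above: every worker re-imports this module, so the per-worker globals
# start out empty in each child instead of being inherited from the parent, and force=True
# simply overrides a start method another import may already have set. get_context avoids
# touching the global start method at all.
def _pid_of_worker(_task):
    # Returns the PID of whichever worker process ran this task.
    return os.getpid()

def _spawn_demo():
    ctx = multiprocessing.get_context('spawn')
    with ctx.Pool(processes=2) as pool:
        # Two distinct PIDs => two fresh interpreters, each with its own _worker_ocr_engine.
        print(sorted(set(pool.map(_pid_of_worker, range(8)))))
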
def _init_cpu_worker():
    """Initialize CPU worker (no OCR engine needed)."""
    global _worker_initialized, _worker_type
    _worker_initialized = True
    _worker_type = 'cpu'


def _init_gpu_worker():
    """Initialize GPU worker with OCR engine (called once per worker)."""
    global _worker_ocr_engine, _worker_initialized, _worker_type

    # Suppress PaddlePaddle/PaddleX reinitialization warnings
    warnings.filterwarnings('ignore', message='.*PDX has already been initialized.*')
    warnings.filterwarnings('ignore', message='.*reinitialization.*')

    # Set environment variable to suppress paddle warnings
    os.environ['GLOG_minloglevel'] = '2'  # Suppress INFO and WARNING logs

    # OCR engine will be lazily initialized on first use
    _worker_ocr_engine = None
    _worker_initialized = True
    _worker_type = 'gpu'


def _init_worker():
    """Initialize worker process with OCR engine (called once per worker)."""
    global _worker_ocr_engine
    # OCR engine will be lazily initialized on first use
    _worker_ocr_engine = None
    """Initialize worker process with OCR engine (called once per worker).

    Legacy function for backwards compatibility.
    """
    _init_gpu_worker()


def _get_ocr_engine():
    """Get or create OCR engine for current worker."""
    global _worker_ocr_engine

    if _worker_ocr_engine is None:
        from ..ocr import OCREngine
        _worker_ocr_engine = OCREngine()
        # Suppress warnings during OCR initialization
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore')
            from ..ocr import OCREngine
            _worker_ocr_engine = OCREngine()

    return _worker_ocr_engine

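# Minimal sketch of the per-worker lazy-singleton pattern used above, with a generic
# stand-in for OCREngine (names below are placeholders, not from this commit): the
# initializer runs once per process, and the expensive object is only built the first
# time a task in that worker actually needs it, then reused for every later task.
_demo_engine = None

def _demo_init_worker():
    global _demo_engine
    _demo_engine = None          # defer construction until first use

def _demo_task(path):
    global _demo_engine
    if _demo_engine is None:
        _demo_engine = object()  # stands in for OCREngine(); built once per worker
    return path, id(_demo_engine)

def _demo_pool(paths):
    with ProcessPoolExecutor(max_workers=2, initializer=_demo_init_worker) as executor:
        # id(_demo_engine) repeats within a worker, showing the engine is reused across tasks.
        return list(executor.map(_demo_task, paths))
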
def _save_output_img(output_img, image_path: Path) -> None:
    """Save OCR output_img to replace the original rendered image."""
    from PIL import Image as PILImage

    # Convert numpy array to PIL Image and save
    if output_img is not None:
        img = PILImage.fromarray(output_img)
        img.save(str(image_path))
    # If output_img is None, the original image is already saved

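# Short usage sketch for the helper above, assuming output_img is an H x W x 3 uint8 RGB
# array (PIL.Image.fromarray interprets other dtypes/channel orders differently). Purely
# illustrative; the array and file name here are synthetic.
def _demo_save_output_img(tmp_dir: Path) -> Path:
    import numpy as np
    canvas = np.full((64, 128, 3), 255, dtype=np.uint8)  # white page, RGB
    target = tmp_dir / 'page_0.png'
    _save_output_img(canvas, target)                      # overwrites the rendered page image
    return target
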
def process_single_document(args_tuple):
    """
    Process a single document (worker function for parallel processing).
@@ -47,8 +99,7 @@ def process_single_document(args_tuple):

    # Import inside worker to avoid pickling issues
    from ..data import AutoLabelReport, FieldMatchResult
    from ..pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens
    from ..pdf.renderer import get_render_dimensions
    from ..pdf import PDFDocument
    from ..matcher import FieldMatcher
    from ..normalize import normalize_field
    from ..yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES
@@ -70,98 +121,121 @@ def process_single_document(args_tuple):
    }

    try:
        # Check PDF type
        use_ocr = not is_text_pdf(pdf_path)
        report.pdf_type = "scanned" if use_ocr else "text"
        # Use PDFDocument context manager for efficient PDF handling
        # Opens PDF only once, caches dimensions, handles cleanup automatically
        with PDFDocument(pdf_path) as pdf_doc:
            # Check PDF type (uses cached document)
            use_ocr = not pdf_doc.is_text_pdf()
            report.pdf_type = "scanned" if use_ocr else "text"

        # Skip OCR if requested
        if use_ocr and skip_ocr:
            report.errors.append("Skipped (scanned PDF)")
            report.processing_time_ms = (time.time() - start_time) * 1000
            result['report'] = report.to_dict()
            return result
            # Skip OCR if requested
            if use_ocr and skip_ocr:
                report.errors.append("Skipped (scanned PDF)")
                report.processing_time_ms = (time.time() - start_time) * 1000
                result['report'] = report.to_dict()
                return result

        # Get OCR engine from worker cache (only created once per worker)
        ocr_engine = None
        if use_ocr:
            ocr_engine = _get_ocr_engine()

        generator = AnnotationGenerator(min_confidence=min_confidence)
        matcher = FieldMatcher()

        # Process each page
        page_annotations = []

        for page_no, image_path in render_pdf_to_images(
            pdf_path,
            output_dir / 'temp' / doc_id / 'images',
            dpi=dpi
        ):
            report.total_pages += 1
            img_width, img_height = get_render_dimensions(pdf_path, page_no, dpi)

            # Extract tokens
            # Get OCR engine from worker cache (only created once per worker)
            ocr_engine = None
            if use_ocr:
                tokens = ocr_engine.extract_from_image(str(image_path), page_no)
            else:
                tokens = list(extract_text_tokens(pdf_path, page_no))
                ocr_engine = _get_ocr_engine()

            # Match fields
            matches = {}
            generator = AnnotationGenerator(min_confidence=min_confidence)
            matcher = FieldMatcher()

            # Process each page
            page_annotations = []
            matched_fields = set()

            # Render all pages and process (uses cached document handle)
            images_dir = output_dir / 'temp' / doc_id / 'images'
            for page_no, image_path in pdf_doc.render_all_pages(images_dir, dpi=dpi):
                report.total_pages += 1

                # Get dimensions from cache (no additional PDF open)
                img_width, img_height = pdf_doc.get_render_dimensions(page_no, dpi)

                # Extract tokens
                if use_ocr:
                    # Use extract_with_image to get both tokens and preprocessed image
                    # PaddleOCR coordinates are relative to output_img, not original image
                    ocr_result = ocr_engine.extract_with_image(
                        str(image_path),
                        page_no,
                        scale_to_pdf_points=72 / dpi
                    )
                    tokens = ocr_result.tokens

                    # Save output_img to replace the original rendered image
                    # This ensures coordinates match the saved image
                    _save_output_img(ocr_result.output_img, image_path)

                    # Update image dimensions to match output_img
                    if ocr_result.output_img is not None:
                        img_height, img_width = ocr_result.output_img.shape[:2]
                else:
                    # Use cached document for text extraction
                    tokens = list(pdf_doc.extract_text_tokens(page_no))

                # Match fields
                matches = {}
                for field_name in FIELD_CLASSES.keys():
                    value = row_dict.get(field_name)
                    if not value:
                        continue

                    normalized = normalize_field(field_name, str(value))
                    field_matches = matcher.find_matches(tokens, field_name, normalized, page_no)

                    # Record result
                    if field_matches:
                        best = field_matches[0]
                        matches[field_name] = field_matches
                        matched_fields.add(field_name)
                        report.add_field_result(FieldMatchResult(
                            field_name=field_name,
                            csv_value=str(value),
                            matched=True,
                            score=best.score,
                            matched_text=best.matched_text,
                            candidate_used=best.value,
                            bbox=best.bbox,
                            page_no=page_no,
                            context_keywords=best.context_keywords
                        ))

                # Count annotations
                annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi)

                if annotations:
                    page_annotations.append({
                        'image_path': str(image_path),
                        'page_no': page_no,
                        'count': len(annotations)
                    })

                    report.annotations_generated += len(annotations)
                    for ann in annotations:
                        class_name = list(FIELD_CLASSES.keys())[ann.class_id]
                        result['stats'][class_name] += 1

            # Record unmatched fields
            for field_name in FIELD_CLASSES.keys():
                value = row_dict.get(field_name)
                if not value:
                    continue

                normalized = normalize_field(field_name, str(value))
                field_matches = matcher.find_matches(tokens, field_name, normalized, page_no)

                # Record result
                if field_matches:
                    best = field_matches[0]
                    matches[field_name] = field_matches
                    report.add_field_result(FieldMatchResult(
                        field_name=field_name,
                        csv_value=str(value),
                        matched=True,
                        score=best.score,
                        matched_text=best.matched_text,
                        candidate_used=best.value,
                        bbox=best.bbox,
                        page_no=page_no,
                        context_keywords=best.context_keywords
                    ))
                else:
                    if value and field_name not in matched_fields:
                        report.add_field_result(FieldMatchResult(
                            field_name=field_name,
                            csv_value=str(value),
                            matched=False,
                            page_no=page_no
                            page_no=-1
                        ))

            # Generate annotations
            annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi)

            if annotations:
                label_path = output_dir / 'temp' / doc_id / 'labels' / f"{image_path.stem}.txt"
                generator.save_annotations(annotations, label_path)
                page_annotations.append({
                    'image_path': str(image_path),
                    'label_path': str(label_path),
                    'count': len(annotations)
                })

                report.annotations_generated += len(annotations)
                for ann in annotations:
                    class_name = list(FIELD_CLASSES.keys())[ann.class_id]
                    result['stats'][class_name] += 1

        if page_annotations:
            result['pages'] = page_annotations
            result['success'] = True
            report.success = True
        else:
            report.errors.append("No annotations generated")
            if page_annotations:
                result['pages'] = page_annotations
                result['success'] = True
                report.success = True
            else:
                report.errors.append("No annotations generated")

    except Exception as e:
        report.errors.append(str(e))
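# Illustrative sketch of the YOLO annotation format produced by generate_from_matches:
# one line per box, "class_id x_center y_center width height", all normalized to [0, 1]
# against the rendered image size. The helper below is a stand-in written for this note
# (it is not the AnnotationGenerator implementation); bbox is assumed to be
# (x0, y0, x1, y1) in rendered-image pixels. When a bbox is kept in PDF points instead,
# multiplying by dpi / 72 converts points back to pixels before normalizing.
def _bbox_to_yolo_line(class_id: int, bbox, img_width: int, img_height: int) -> str:
    x0, y0, x1, y1 = bbox
    x_center = (x0 + x1) / 2.0 / img_width
    y_center = (y0 + y1) / 2.0 / img_height
    width = (x1 - x0) / img_width
    height = (y1 - y0) / img_height
    return f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"

# Example: a field box on a 300 dpi A4 render (roughly 2480 x 3508 px)
# _bbox_to_yolo_line(1, (1540, 300, 1960, 360), 2480, 3508)
# -> "1 0.705645 0.094071 0.169355 0.017104"
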
@@ -178,47 +252,41 @@ def main():
    )
    parser.add_argument(
        '--csv', '-c',
        default='data/structured_data/document_export_20260109_212743.csv',
        help='Path to structured data CSV file'
        default=f"{PATHS['csv_dir']}/*.csv",
        help='Path to CSV file(s). Supports: single file, glob pattern (*.csv), or comma-separated list'
    )
    parser.add_argument(
        '--pdf-dir', '-p',
        default='data/raw_pdfs',
        default=PATHS['pdf_dir'],
        help='Directory containing PDF files'
    )
    parser.add_argument(
        '--output', '-o',
        default='data/dataset',
        default=PATHS['output_dir'],
        help='Output directory for dataset'
    )
    parser.add_argument(
        '--dpi',
        type=int,
        default=300,
        help='DPI for PDF rendering (default: 300)'
        default=AUTOLABEL['dpi'],
        help=f"DPI for PDF rendering (default: {AUTOLABEL['dpi']})"
    )
    parser.add_argument(
        '--min-confidence',
        type=float,
        default=0.7,
        help='Minimum match confidence (default: 0.7)'
    )
    parser.add_argument(
        '--train-ratio',
        type=float,
        default=0.8,
        help='Training set ratio (default: 0.8)'
    )
    parser.add_argument(
        '--val-ratio',
        type=float,
        default=0.1,
        help='Validation set ratio (default: 0.1)'
        default=AUTOLABEL['min_confidence'],
        help=f"Minimum match confidence (default: {AUTOLABEL['min_confidence']})"
    )
    parser.add_argument(
        '--report',
        default='reports/autolabel_report.jsonl',
        help='Path for auto-label report (JSONL)'
        default=f"{PATHS['reports_dir']}/autolabel_report.jsonl",
        help='Path for auto-label report (JSONL). With --max-records, creates report_part000.jsonl, etc.'
    )
    parser.add_argument(
        '--max-records',
        type=int,
        default=10000,
        help='Max records per report file for sharding (default: 10000, 0 = single file)'
    )
    parser.add_argument(
        '--single',
@@ -233,20 +301,37 @@ def main():
        '--workers', '-w',
        type=int,
        default=4,
        help='Number of parallel workers (default: 4)'
        help='Number of parallel workers (default: 4). Use --cpu-workers and --gpu-workers for dual-pool mode.'
    )
    parser.add_argument(
        '--cpu-workers',
        type=int,
        default=None,
        help='Number of CPU workers for text PDFs (enables dual-pool mode)'
    )
    parser.add_argument(
        '--gpu-workers',
        type=int,
        default=1,
        help='Number of GPU workers for scanned PDFs (default: 1, used with --cpu-workers)'
    )
    parser.add_argument(
        '--skip-ocr',
        action='store_true',
        help='Skip scanned PDFs (text-layer only)'
    )
    parser.add_argument(
        '--limit', '-l',
        type=int,
        default=None,
        help='Limit number of documents to process (for testing)'
    )

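    # Illustrative check of the flags added above (the values are hypothetical): passing
    # --csv "data/csv/*.csv" --cpu-workers 6 --gpu-workers 1 --limit 100 enables dual-pool
    # mode because --cpu-workers is no longer None, while --dpi and --min-confidence keep
    # their config-driven defaults from AUTOLABEL.
    _example_args = parser.parse_args([
        '--csv', 'data/csv/*.csv',
        '--cpu-workers', '6',
        '--gpu-workers', '1',
        '--limit', '100',
    ])
    assert _example_args.cpu_workers == 6 and _example_args.limit == 100
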
    args = parser.parse_args()

    # Import here to avoid slow startup
    from ..data import CSVLoader, AutoLabelReport, FieldMatchResult
    from ..data.autolabel_report import ReportWriter
    from ..yolo import DatasetBuilder
    from ..pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens
    from ..pdf.renderer import get_render_dimensions
    from ..ocr import OCREngine
@@ -254,185 +339,343 @@ def main():
    from ..normalize import normalize_field
    from ..yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES

    print(f"Loading CSV data from: {args.csv}")
    loader = CSVLoader(args.csv, args.pdf_dir)
    # Handle comma-separated CSV paths
    csv_input = args.csv
    if ',' in csv_input and '*' not in csv_input:
        csv_input = [p.strip() for p in csv_input.split(',')]

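    # Sketch of how the CSV argument is expected to resolve (illustrative helper, not the
    # CSVLoader implementation): a single path, a glob pattern, or a comma-separated list
    # all end up as a list of Path objects.
    def _resolve_csv_arg(csv_arg) -> list:
        import glob
        if isinstance(csv_arg, list):                  # already split on commas above
            patterns = csv_arg
        else:
            patterns = [csv_arg]
        paths = []
        for pattern in patterns:
            if '*' in pattern or '?' in pattern:
                paths.extend(Path(p) for p in sorted(glob.glob(pattern)))
            else:
                paths.append(Path(pattern))
        return paths
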
    # Validate data
    issues = loader.validate()
    if issues:
        print(f"Warning: Found {len(issues)} validation issues")
        if args.verbose:
            for issue in issues[:10]:
                print(f" - {issue}")

    rows = loader.load_all()
    print(f"Loaded {len(rows)} invoice records")

    # Filter to single document if specified
    if args.single:
        rows = [r for r in rows if r.DocumentId == args.single]
        if not rows:
            print(f"Error: Document {args.single} not found")
            sys.exit(1)
        print(f"Processing single document: {args.single}")
    # Get list of CSV files (don't load all data at once)
    temp_loader = CSVLoader(csv_input, args.pdf_dir)
    csv_files = temp_loader.csv_paths
    pdf_dir = temp_loader.pdf_dir
    print(f"Found {len(csv_files)} CSV file(s) to process")

    # Setup output directories
    output_dir = Path(args.output)
    for split in ['train', 'val', 'test']:
        (output_dir / split / 'images').mkdir(parents=True, exist_ok=True)
        (output_dir / split / 'labels').mkdir(parents=True, exist_ok=True)
    # Only create temp directory for images (no train/val/test split during labeling)
    (output_dir / 'temp').mkdir(parents=True, exist_ok=True)

    # Generate YOLO config files
    AnnotationGenerator.generate_classes_file(output_dir / 'classes.txt')
    AnnotationGenerator.generate_yaml_config(output_dir / 'dataset.yaml')

    # Report writer
    # Report writer with optional sharding
    report_path = Path(args.report)
    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_writer = ReportWriter(args.report)
    report_writer = ReportWriter(args.report, max_records_per_file=args.max_records)

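    # ReportWriter itself is not part of this diff; the sketch below shows only the sharding
    # behaviour its arguments imply: write_dict appends JSON lines and rolls over to
    # <name>_part000.jsonl, <name>_part001.jsonl, ... every max_records_per_file records
    # (0 means a single file), and get_shard_files lists whatever was produced.
    class _ShardedJsonlWriter:
        def __init__(self, base_path, max_records_per_file=10000):
            import json
            self._json = json
            self.base = Path(base_path)
            self.max_records = max_records_per_file
            self.count = 0
            self.shards = []
            self._fh = None

        def write_dict(self, record: dict) -> None:
            rollover = self.max_records and self.count and self.count % self.max_records == 0
            if self._fh is None or rollover:
                if self._fh:
                    self._fh.close()
                shard = self.base.with_name(f"{self.base.stem}_part{len(self.shards):03d}.jsonl")
                self._fh = open(shard, 'a', encoding='utf-8')
                self.shards.append(shard)
            self._fh.write(self._json.dumps(record) + '\n')
            self.count += 1

        def get_shard_files(self):
            return list(self.shards)
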
    # Stats
    # Database connection for checking existing documents
    from ..data.db import DocumentDB
    db = DocumentDB()
    db.connect()
    print("Connected to database for status checking")

    # Global stats
    stats = {
        'total': len(rows),
        'total': 0,
        'successful': 0,
        'failed': 0,
        'skipped': 0,
        'skipped_db': 0,  # Skipped because already in DB
        'retried': 0,  # Re-processed failed ones
        'annotations': 0,
        'tasks_submitted': 0,  # Tracks tasks submitted across all CSVs for limit
        'by_field': {name: 0 for name in FIELD_CLASSES.keys()}
    }

    # Prepare tasks
    tasks = []
    for row in rows:
        pdf_path = loader.get_pdf_path(row)
        if not pdf_path:
            # Write report for missing PDF
            report = AutoLabelReport(document_id=row.DocumentId)
            report.errors.append("PDF not found")
            report_writer.write(report)
    # Track all processed items for final split (write to temp file to save memory)
    processed_items_file = output_dir / 'temp' / 'processed_items.jsonl'
    processed_items_file.parent.mkdir(parents=True, exist_ok=True)
    processed_items_writer = open(processed_items_file, 'w', encoding='utf-8')
    processed_count = 0
    seen_doc_ids = set()

    # Batch for database updates
    db_batch = []
    DB_BATCH_SIZE = 100

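    # DocumentDB is defined elsewhere; this stand-in sketches only the contract the code
    # below relies on: check_documents_status_batch returns {doc_id: True (already
    # successful), False (previously failed), None/missing (never seen)}, and
    # save_documents_batch upserts report dicts in one round trip. The storage here is an
    # in-memory dict standing in for the real PostgreSQL table.
    class _DocumentDBSketch:
        def __init__(self, connection_string: str):
            self.connection_string = connection_string
            self._status = {}  # doc_id -> bool

        def check_documents_status_batch(self, doc_ids):
            return {doc_id: self._status.get(doc_id) for doc_id in doc_ids}

        def save_documents_batch(self, reports):
            for report in reports:
                self._status[report['document_id']] = bool(report.get('success'))
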
    # Helper function to handle result and update database
    # Defined outside the loop so nonlocal can properly reference db_batch
    def handle_result(result):
        nonlocal processed_count, db_batch

        # Write report to file
        if result['report']:
            report_writer.write_dict(result['report'])
            # Add to database batch
            db_batch.append(result['report'])
            if len(db_batch) >= DB_BATCH_SIZE:
                db.save_documents_batch(db_batch)
                db_batch.clear()

        if result['success']:
            # Write to temp file instead of memory
            import json
            processed_items_writer.write(json.dumps({
                'doc_id': result['doc_id'],
                'pages': result['pages']
            }) + '\n')
            processed_items_writer.flush()
            processed_count += 1
            stats['successful'] += 1
            for field, count in result['stats'].items():
                stats['by_field'][field] += count
                stats['annotations'] += count
        elif 'Skipped' in str(result.get('report', {}).get('errors', [])):
            stats['skipped'] += 1
        else:
            stats['failed'] += 1
            continue

        # Convert row to dict for pickling
        row_dict = {
            'DocumentId': row.DocumentId,
            'InvoiceNumber': row.InvoiceNumber,
            'InvoiceDate': row.InvoiceDate,
            'InvoiceDueDate': row.InvoiceDueDate,
            'OCR': row.OCR,
            'Bankgiro': row.Bankgiro,
            'Plusgiro': row.Plusgiro,
            'Amount': row.Amount,
    def handle_error(doc_id, error):
        nonlocal db_batch
        stats['failed'] += 1
        error_report = {
            'document_id': doc_id,
            'success': False,
            'errors': [f"Worker error: {str(error)}"]
        }
        report_writer.write_dict(error_report)
        db_batch.append(error_report)
        if len(db_batch) >= DB_BATCH_SIZE:
            db.save_documents_batch(db_batch)
            db_batch.clear()
        if args.verbose:
            print(f"Error processing {doc_id}: {error}")

        tasks.append((
            row_dict,
            str(pdf_path),
            str(output_dir),
            args.dpi,
            args.min_confidence,
            args.skip_ocr
        ))
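    # The worker result consumed by handle_result / handle_error is a plain dict; a minimal
    # example of the shape the code above and below assumes (field names taken from this
    # diff, values invented):
    _example_result = {
        'doc_id': 'DOC-0001',
        'success': True,
        'pages': [
            {'image_path': 'dataset/temp/DOC-0001/images/page_1.png', 'page_no': 1, 'count': 3},
        ],
        'stats': {'InvoiceNumber': 1, 'Amount': 1, 'OCR': 1},  # per-field annotation counts
        'report': {'document_id': 'DOC-0001', 'success': True, 'errors': []},
    }
    # handle_result(_example_result) would write the report line, queue it for the DB batch,
    # append the page entries to processed_items.jsonl and bump stats['successful'].
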
    # Initialize dual-pool coordinator if enabled (keeps workers alive across CSVs)
    dual_pool_coordinator = None
    use_dual_pool = args.cpu_workers is not None

    print(f"Processing {len(tasks)} documents with {args.workers} workers...")
    if use_dual_pool:
        from src.processing import DualPoolCoordinator
        from src.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf

    # Process documents in parallel
    processed_items = []
        print(f"\nStarting dual-pool mode: {args.cpu_workers} CPU + {args.gpu_workers} GPU workers")
        dual_pool_coordinator = DualPoolCoordinator(
            cpu_workers=args.cpu_workers,
            gpu_workers=args.gpu_workers,
            gpu_id=0,
            task_timeout=300.0,
        )
        dual_pool_coordinator.start()

    # Use single process for debugging or when workers=1
    if args.workers == 1:
        for task in tqdm(tasks, desc="Processing"):
            result = process_single_document(task)
    try:
        # Process CSV files one by one (streaming)
        for csv_idx, csv_file in enumerate(csv_files):
            print(f"\n[{csv_idx + 1}/{len(csv_files)}] Processing: {csv_file.name}")

            # Write report
            if result['report']:
                report_writer.write_dict(result['report'])
            # Load only this CSV file
            single_loader = CSVLoader(str(csv_file), str(pdf_dir))
            rows = single_loader.load_all()

            if result['success']:
                processed_items.append({
                    'doc_id': result['doc_id'],
                    'pages': result['pages']
                })
                stats['successful'] += 1
                for field, count in result['stats'].items():
                    stats['by_field'][field] += count
                    stats['annotations'] += count
            elif 'Skipped' in str(result.get('report', {}).get('errors', [])):
                stats['skipped'] += 1
            # Filter to single document if specified
            if args.single:
                rows = [r for r in rows if r.DocumentId == args.single]
                if not rows:
                    continue

            # Deduplicate across CSV files
            rows = [r for r in rows if r.DocumentId not in seen_doc_ids]
            for r in rows:
                seen_doc_ids.add(r.DocumentId)

            if not rows:
                print(f" Skipping CSV (no new documents)")
                continue

            # Batch query database for all document IDs in this CSV
            csv_doc_ids = [r.DocumentId for r in rows]
            db_status_map = db.check_documents_status_batch(csv_doc_ids)

            # Count how many are already processed successfully
            already_processed = sum(1 for doc_id in csv_doc_ids if db_status_map.get(doc_id) is True)

            # Skip entire CSV if all documents are already processed
            if already_processed == len(rows):
                print(f" Skipping CSV (all {len(rows)} documents already processed)")
                stats['skipped_db'] += len(rows)
                continue

            # Count how many new documents need processing in this CSV
            new_to_process = len(rows) - already_processed
            print(f" Found {new_to_process} new documents to process ({already_processed} already in DB)")

            stats['total'] += len(rows)

            # Prepare tasks for this CSV
            tasks = []
            skipped_in_csv = 0
            retry_in_csv = 0

            # Calculate how many more we can process if limit is set
            # Use tasks_submitted counter which tracks across all CSVs
            if args.limit:
                remaining_limit = args.limit - stats.get('tasks_submitted', 0)
                if remaining_limit <= 0:
                    print(f" Reached limit of {args.limit} new documents, stopping.")
                    break
            else:
                stats['failed'] += 1
    else:
        # Parallel processing with worker initialization
        # Each worker initializes OCR engine once and reuses it
        with ProcessPoolExecutor(max_workers=args.workers, initializer=_init_worker) as executor:
            futures = {executor.submit(process_single_document, task): task[0]['DocumentId']
                       for task in tasks}
                remaining_limit = float('inf')

            for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"):
                doc_id = futures[future]
                try:
                    result = future.result()
            for row in rows:
                # Stop adding tasks if we've reached the limit
                if len(tasks) >= remaining_limit:
                    break

                    # Write report
                    if result['report']:
                        report_writer.write_dict(result['report'])
                doc_id = row.DocumentId

                    if result['success']:
                        processed_items.append({
                            'doc_id': result['doc_id'],
                            'pages': result['pages']
                        })
                        stats['successful'] += 1
                        for field, count in result['stats'].items():
                            stats['by_field'][field] += count
                            stats['annotations'] += count
                    elif 'Skipped' in str(result.get('report', {}).get('errors', [])):
                        stats['skipped'] += 1
                    else:
                        stats['failed'] += 1
                # Check document status from batch query result
                db_status = db_status_map.get(doc_id)  # None if not in DB

                except Exception as e:
                    stats['failed'] += 1
                    # Write error report for failed documents
                    error_report = {
                        'document_id': doc_id,
                        'success': False,
                        'errors': [f"Worker error: {str(e)}"]
                    }
                    report_writer.write_dict(error_report)
                    if args.verbose:
                        print(f"Error processing {doc_id}: {e}")
                # Skip if already successful in database
                if db_status is True:
                    stats['skipped_db'] += 1
                    skipped_in_csv += 1
                    continue

    # Split and move files
    import random
    random.seed(42)
    random.shuffle(processed_items)
                # Check if this is a retry (was failed before)
                if db_status is False:
                    stats['retried'] += 1
                    retry_in_csv += 1

    n_train = int(len(processed_items) * args.train_ratio)
    n_val = int(len(processed_items) * args.val_ratio)
                pdf_path = single_loader.get_pdf_path(row)
                if not pdf_path:
                    stats['skipped'] += 1
                    continue

    splits = {
        'train': processed_items[:n_train],
        'val': processed_items[n_train:n_train + n_val],
        'test': processed_items[n_train + n_val:]
    }
                row_dict = {
                    'DocumentId': row.DocumentId,
                    'InvoiceNumber': row.InvoiceNumber,
                    'InvoiceDate': row.InvoiceDate,
                    'InvoiceDueDate': row.InvoiceDueDate,
                    'OCR': row.OCR,
                    'Bankgiro': row.Bankgiro,
                    'Plusgiro': row.Plusgiro,
                    'Amount': row.Amount,
                }

    import shutil
    for split_name, items in splits.items():
        for item in items:
            for page in item['pages']:
                # Move image
                image_path = Path(page['image_path'])
                label_path = Path(page['label_path'])
                dest_img = output_dir / split_name / 'images' / image_path.name
                shutil.move(str(image_path), str(dest_img))
                tasks.append((
                    row_dict,
                    str(pdf_path),
                    str(output_dir),
                    args.dpi,
                    args.min_confidence,
                    args.skip_ocr
                ))

                # Move label
                dest_label = output_dir / split_name / 'labels' / label_path.name
                shutil.move(str(label_path), str(dest_label))
            if skipped_in_csv > 0 or retry_in_csv > 0:
                print(f" Skipped {skipped_in_csv} (already in DB), retrying {retry_in_csv} failed")

    # Cleanup temp
    shutil.rmtree(output_dir / 'temp', ignore_errors=True)
            if not tasks:
                continue

            # Update tasks_submitted counter for limit tracking
            stats['tasks_submitted'] += len(tasks)

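            # Condensed sketch of the per-document routing decision made in the loop above
            # (db_status comes from check_documents_status_batch; the helper is illustrative):
            def _classify_document(doc_id, db_status_map, tasks, remaining_limit):
                """Return 'limit', 'skip', 'retry' or 'new' for a document id."""
                if len(tasks) >= remaining_limit:
                    return 'limit'                     # stop submitting, the --limit budget is used up
                db_status = db_status_map.get(doc_id)  # None if the document was never processed
                if db_status is True:
                    return 'skip'                      # already labelled successfully -> skipped_db
                if db_status is False:
                    return 'retry'                     # failed earlier -> counted as retried, re-queued
                return 'new'
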
            if use_dual_pool:
                # Dual-pool mode using pre-initialized DualPoolCoordinator
                # (process_text_pdf, process_scanned_pdf already imported above)

                # Convert tasks to new format
                documents = []
                for task in tasks:
                    row_dict, pdf_path_str, output_dir_str, dpi, min_confidence, skip_ocr = task
                    # Pre-classify PDF type
                    try:
                        is_text = is_text_pdf(pdf_path_str)
                    except Exception:
                        is_text = False

                    documents.append({
                        "id": row_dict["DocumentId"],
                        "row_dict": row_dict,
                        "pdf_path": pdf_path_str,
                        "output_dir": output_dir_str,
                        "dpi": dpi,
                        "min_confidence": min_confidence,
                        "is_scanned": not is_text,
                        "has_text": is_text,
                        "text_length": 1000 if is_text else 0,  # Approximate
                    })

                # Count task types
                text_count = sum(1 for d in documents if not d["is_scanned"])
                scan_count = len(documents) - text_count
                print(f" Text PDFs: {text_count}, Scanned PDFs: {scan_count}")

                # Progress tracking with tqdm
                pbar = tqdm(total=len(documents), desc="Processing")

                def on_result(task_result):
                    """Handle successful result."""
                    result = task_result.data
                    handle_result(result)
                    pbar.update(1)

                def on_error(task_id, error):
                    """Handle failed task."""
                    handle_error(task_id, error)
                    pbar.update(1)

                # Process with pre-initialized coordinator (workers stay alive)
                results = dual_pool_coordinator.process_batch(
                    documents=documents,
                    cpu_task_fn=process_text_pdf,
                    gpu_task_fn=process_scanned_pdf,
                    on_result=on_result,
                    on_error=on_error,
                    id_field="id",
                )

                pbar.close()

                # Log summary
                successful = sum(1 for r in results if r.success)
                failed = len(results) - successful
                print(f" Batch complete: {successful} successful, {failed} failed")

            else:
                # Single-pool mode (original behavior)
                print(f" Processing {len(tasks)} documents with {args.workers} workers...")

                # Process documents in parallel (inside CSV loop for streaming)
                # Use single process for debugging or when workers=1
                if args.workers == 1:
                    for task in tqdm(tasks, desc="Processing"):
                        result = process_single_document(task)
                        handle_result(result)
                else:
                    # Parallel processing with worker initialization
                    # Each worker initializes OCR engine once and reuses it
                    with ProcessPoolExecutor(max_workers=args.workers, initializer=_init_worker) as executor:
                        futures = {executor.submit(process_single_document, task): task[0]['DocumentId']
                                   for task in tasks}

                        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"):
                            doc_id = futures[future]
                            try:
                                result = future.result()
                                handle_result(result)
                            except Exception as e:
                                handle_error(doc_id, e)

            # Flush remaining database batch after each CSV
            if db_batch:
                db.save_documents_batch(db_batch)
                db_batch.clear()

    finally:
        # Shutdown dual-pool coordinator if it was started
        if dual_pool_coordinator is not None:
            dual_pool_coordinator.shutdown()

        # Close temp file
        processed_items_writer.close()

    # Use the in-memory counter instead of re-reading the file (performance fix)
    # processed_count already tracks the number of successfully processed items

    # Cleanup processed_items temp file (not needed anymore)
    processed_items_file.unlink(missing_ok=True)

    # Close database connection
    db.close()

    # Print summary
    print("\n" + "=" * 50)
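    # DualPoolCoordinator lives in src.processing and is not shown in this diff; the stand-in
    # below sketches only the routing contract used above: documents flagged is_scanned go to
    # a small GPU pool, the rest to a larger CPU pool, and each completion is reported through
    # a callback. Result objects are simplified to (task_id, success, data) tuples.
    from collections import namedtuple

    _SketchResult = namedtuple('_SketchResult', 'task_id success data')

    def _sketch_process_batch(documents, cpu_task_fn, gpu_task_fn, on_result, on_error,
                              cpu_workers=4, gpu_workers=1, id_field="id"):
        results = []
        with ProcessPoolExecutor(max_workers=cpu_workers, initializer=_init_cpu_worker) as cpu_pool, \
             ProcessPoolExecutor(max_workers=gpu_workers, initializer=_init_gpu_worker) as gpu_pool:
            futures = {}
            for doc in documents:
                pool, fn = (gpu_pool, gpu_task_fn) if doc["is_scanned"] else (cpu_pool, cpu_task_fn)
                futures[pool.submit(fn, doc)] = doc[id_field]
            for future in as_completed(futures):
                task_id = futures[future]
                try:
                    data = future.result()
                    result = _SketchResult(task_id, True, data)
                    on_result(result)
                except Exception as exc:
                    result = _SketchResult(task_id, False, None)
                    on_error(task_id, exc)
                results.append(result)
        return results
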
@@ -441,17 +684,22 @@ def main():
    print(f"Total documents: {stats['total']}")
    print(f"Successful: {stats['successful']}")
    print(f"Failed: {stats['failed']}")
    print(f"Skipped (OCR): {stats['skipped']}")
    print(f"Skipped (no PDF): {stats['skipped']}")
    print(f"Skipped (in DB): {stats['skipped_db']}")
    print(f"Retried (failed): {stats['retried']}")
    print(f"Total annotations: {stats['annotations']}")
    print(f"\nDataset split:")
    print(f" Train: {len(splits['train'])} documents")
    print(f" Val: {len(splits['val'])} documents")
    print(f" Test: {len(splits['test'])} documents")
    print(f"\nImages saved to: {output_dir / 'temp'}")
    print(f"Labels stored in: PostgreSQL database")
    print(f"\nAnnotations by field:")
    for field, count in stats['by_field'].items():
        print(f" {field}: {count}")
    print(f"\nOutput: {output_dir}")
    print(f"Report: {args.report}")
    shard_files = report_writer.get_shard_files()
    if len(shard_files) > 1:
        print(f"\nReport files ({len(shard_files)}):")
        for sf in shard_files:
            print(f" - {sf}")
    else:
        print(f"\nReport: {shard_files[0] if shard_files else args.report}")


if __name__ == '__main__':