Files
invoice-master-poc-v2/src/cli/reprocess_failed.py
2026-01-17 18:55:46 +01:00

425 lines
16 KiB
Python

#!/usr/bin/env python3
"""
Re-process failed matches and store detailed information including OCR values,
CSV values, and source CSV filename in a new table.
"""
import argparse
import json
import glob
import os
import sys
import time
from pathlib import Path
from datetime import datetime
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
from tqdm import tqdm
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from src.data.db import DocumentDB
from src.data.csv_loader import CSVLoader
from src.normalize.normalizer import normalize_field
def create_failed_match_table(db: DocumentDB):
    """(Re)create the failed_match_details table from scratch.

    Any existing failed_match_details table -- and every row in it -- is
    dropped first, then recreated with a UNIQUE(document_id, field_name)
    constraint and indexes on document_id, field_name, csv_filename and
    matched. The DDL is committed and a confirmation line is printed.
    """
    ddl = """
DROP TABLE IF EXISTS failed_match_details;
CREATE TABLE failed_match_details (
id SERIAL PRIMARY KEY,
document_id TEXT NOT NULL,
field_name TEXT NOT NULL,
csv_value TEXT,
csv_value_normalized TEXT,
ocr_value TEXT,
ocr_value_normalized TEXT,
all_ocr_candidates JSONB,
matched BOOLEAN DEFAULT FALSE,
match_score REAL,
pdf_path TEXT,
pdf_type TEXT,
csv_filename TEXT,
page_no INTEGER,
bbox JSONB,
error TEXT,
reprocessed_at TIMESTAMPTZ DEFAULT NOW(),
UNIQUE(document_id, field_name)
);
CREATE INDEX IF NOT EXISTS idx_failed_match_document_id ON failed_match_details(document_id);
CREATE INDEX IF NOT EXISTS idx_failed_match_field_name ON failed_match_details(field_name);
CREATE INDEX IF NOT EXISTS idx_failed_match_csv_filename ON failed_match_details(csv_filename);
CREATE INDEX IF NOT EXISTS idx_failed_match_matched ON failed_match_details(matched);
"""
    connection = db.connect()
    with connection.cursor() as cur:
        cur.execute(ddl)
    connection.commit()
    print("Created table: failed_match_details")
def get_failed_documents(db: DocumentDB) -> list:
    """Return every document that has at least one unmatched field result.

    Each entry is a dict with 'document_id', 'pdf_path' and 'pdf_type'
    (the latter two joined from the documents table), one per document,
    ordered by document_id.
    """
    query = """
SELECT DISTINCT fr.document_id, d.pdf_path, d.pdf_type
FROM field_results fr
JOIN documents d ON fr.document_id = d.document_id
WHERE fr.matched = false
ORDER BY fr.document_id
"""
    with db.connect().cursor() as cur:
        cur.execute(query)
        fetched = cur.fetchall()
    return [
        {'document_id': document_id, 'pdf_path': path, 'pdf_type': kind}
        for document_id, path, kind in fetched
    ]
def get_failed_fields_for_document(db: DocumentDB, doc_id: str) -> list:
    """Return the unmatched field results recorded for one document.

    Each entry is a dict with 'field_name', 'csv_value' and 'error' taken
    from field_results rows where matched = false. No ORDER BY is applied,
    so the order is whatever the database returns.
    """
    sql = """
SELECT field_name, csv_value, error
FROM field_results
WHERE document_id = %s AND matched = false
"""
    with db.connect().cursor() as cur:
        cur.execute(sql, (doc_id,))
        fetched = cur.fetchall()
    return [
        dict(field_name=name, csv_value=value, error=err)
        for name, value, err in fetched
    ]
# Module-level lookup: document ID -> basename of the first CSV file that
# contains that ID. Filled by build_csv_cache(), read by find_csv_filename().
_csv_cache = dict()
def build_csv_cache(csv_files: list):
    """Rebuild the module-level document-ID -> CSV-basename index.

    The previous cache is discarded first. Files are scanned in the order
    given; when a DocumentId appears in more than one file, the first file
    seen wins.
    """
    global _csv_cache
    _csv_cache = {}
    for path in csv_files:
        basename = os.path.basename(path)
        for row in CSVLoader(path).iter_rows():
            # setdefault keeps the earliest file for a repeated DocumentId.
            _csv_cache.setdefault(row.DocumentId, basename)
def find_csv_filename(doc_id: str) -> "str | None":
    """Return the basename of the CSV file that lists *doc_id*.

    Looks the ID up in the module-level cache built by build_csv_cache().
    Returns None when the ID was not seen in any CSV (or the cache has not
    been built yet), hence the Optional return annotation.
    """
    return _csv_cache.get(doc_id)
def init_worker():
    """Initializer run once in each ProcessPoolExecutor worker process.

    Sets CUDA_VISIBLE_DEVICES=0 (every worker sees only device 0) and
    GLOG_minloglevel=2 (presumably to quiet native glog output from the
    OCR backend -- confirm), then suppresses all Python warnings.
    """
    import os
    import warnings

    os.environ.update({
        "CUDA_VISIBLE_DEVICES": "0",
        "GLOG_minloglevel": "2",
    })
    warnings.filterwarnings("ignore")
# Minimum similarity score for a recovered OCR value to count as matched.
MATCH_SCORE_THRESHOLD = 0.8
# Maximum number of extracted text blocks stored in all_ocr_candidates.
MAX_OCR_CANDIDATES = 100


def _extract_text_blocks(pdf_doc) -> list:
    """Return every text block of *pdf_doc* as {'text', 'bbox', 'page_no'} dicts.

    PDFs whose detect_type() is "scanned" are rendered at 150 DPI and run
    through OCREngine; all other PDFs use their embedded text tokens.
    """
    blocks = []
    if pdf_doc.detect_type() == "scanned":
        # Imported lazily so text-only PDFs never need the OCR stack.
        from src.ocr import OCREngine
        ocr_engine = OCREngine()
        for page_no in range(pdf_doc.page_count):
            img = pdf_doc.render_page(page_no, dpi=150)
            if img is None:
                continue
            for block in ocr_engine.extract_from_image(img):
                blocks.append({
                    'text': block.get('text', ''),
                    'bbox': block.get('bbox'),
                    'page_no': page_no,
                })
    else:
        for page_no in range(pdf_doc.page_count):
            for token in pdf_doc.extract_text_tokens(page_no):
                blocks.append({
                    'text': token.text,
                    'bbox': token.bbox,
                    'page_no': page_no,
                })
    return blocks


def _best_ocr_match(field_name: str, csv_normalized, text_blocks: list) -> tuple:
    """Find the text block whose normalized text best matches *csv_normalized*.

    Returns (score, ocr_text, bbox, page_no). When one normalized string
    contains the other, the score is len(shorter) / len(longer). An exact
    match scores 1.0 (the maximum) and ends the search. Returns
    (0, None, None, None) when nothing overlaps or csv_normalized is empty.
    """
    best = (0, None, None, None)
    if not csv_normalized:
        return best
    for block in text_blocks:
        ocr_text = block['text']
        if not ocr_text:
            continue
        ocr_normalized = normalize_field(field_name, ocr_text)
        if not ocr_normalized:
            continue
        # Check equality first: it is also a substring match, so testing it
        # after the substring branch (as before) made it unreachable.
        if csv_normalized == ocr_normalized:
            return (1.0, ocr_text, block['bbox'], block['page_no'])
        if csv_normalized in ocr_normalized:
            score = len(csv_normalized) / len(ocr_normalized)
        elif ocr_normalized in csv_normalized:
            score = len(ocr_normalized) / len(csv_normalized)
        else:
            continue
        if score > best[0]:
            best = (score, ocr_text, block['bbox'], block['page_no'])
    return best


def process_single_document(args):
    """Re-extract text for one document and look up each failed field in it.

    Intended to run in a worker process. *args* is a single
    (doc_info, failed_fields, csv_filename) tuple so it can go through
    executor.submit:
      - doc_info: dict with 'document_id', 'pdf_path', 'pdf_type'
      - failed_fields: list of dicts with 'field_name', 'csv_value', 'error'
      - csv_filename: basename of the source CSV, or None

    Returns a list with one result dict per entry of failed_fields, in the
    same order, shaped for save_results_batch. Text PDFs use embedded text;
    scanned PDFs are OCR'd. A missing PDF yields rows with
    error "PDF not found: ...". Exceptions raised while opening, reading or
    matching are caught: fields without a row yet get an error row carrying
    str(exception), and rows already produced are kept (no duplicates).
    """
    doc_info, failed_fields, csv_filename = args
    doc_id = doc_info['document_id']
    pdf_path = doc_info['pdf_path']
    pdf_type = doc_info['pdf_type']
    results = []

    def unmatched(field, csv_normalized, error):
        # Result row for a field for which no OCR lookup could be done.
        return {
            'document_id': doc_id,
            'field_name': field['field_name'],
            'csv_value': field['csv_value'],
            'csv_value_normalized': csv_normalized,
            'ocr_value': None,
            'ocr_value_normalized': None,
            'all_ocr_candidates': [],
            'matched': False,
            'match_score': 0,
            'pdf_path': pdf_path,
            'pdf_type': pdf_type,
            'csv_filename': csv_filename,
            'page_no': None,
            'bbox': None,
            'error': error,
        }

    try:
        if not (pdf_path and os.path.exists(pdf_path)):
            for field in failed_fields:
                csv_value = field['csv_value']
                csv_normalized = (normalize_field(field['field_name'], csv_value)
                                  if csv_value else None)
                results.append(unmatched(field, csv_normalized, f"PDF not found: {pdf_path}"))
            return results

        from src.pdf import PDFDocument

        # NOTE(review): pdf_doc is never explicitly closed; if PDFDocument keeps
        # a file handle open, long-lived pool workers may accumulate them -- confirm.
        pdf_doc = PDFDocument(pdf_path)
        text_blocks = _extract_text_blocks(pdf_doc)
        # Identical for every field of this document, so compute it once.
        candidates = [block['text'] for block in text_blocks[:MAX_OCR_CANDIDATES]]

        for field in failed_fields:
            field_name = field['field_name']
            csv_value = field['csv_value']
            csv_normalized = normalize_field(field_name, csv_value) if csv_value else None
            score, ocr_text, bbox, page_no = _best_ocr_match(
                field_name, csv_normalized, text_blocks)
            results.append({
                'document_id': doc_id,
                'field_name': field_name,
                'csv_value': csv_value,
                'csv_value_normalized': csv_normalized,
                'ocr_value': ocr_text,
                'ocr_value_normalized': normalize_field(field_name, ocr_text) if ocr_text else None,
                'all_ocr_candidates': list(candidates),
                'matched': score > MATCH_SCORE_THRESHOLD,
                'match_score': score,
                'pdf_path': pdf_path,
                'pdf_type': pdf_type,
                'csv_filename': csv_filename,
                'page_no': page_no,
                'bbox': list(bbox) if bbox else None,
                'error': field['error'],
            })
    except Exception as e:
        # Exactly one row is appended per field, in order, so the first
        # len(results) fields are already covered; only the rest get an error
        # row. Previously every field got one, duplicating existing rows.
        for field in failed_fields[len(results):]:
            results.append(unmatched(field, None, str(e)))
    return results
def save_results_batch(db: DocumentDB, results: list):
"""Save results to failed_match_details table."""
if not results:
return
conn = db.connect()
with conn.cursor() as cursor:
for r in results:
cursor.execute("""
INSERT INTO failed_match_details
(document_id, field_name, csv_value, csv_value_normalized,
ocr_value, ocr_value_normalized, all_ocr_candidates,
matched, match_score, pdf_path, pdf_type, csv_filename,
page_no, bbox, error)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
ON CONFLICT (document_id, field_name) DO UPDATE SET
csv_value = EXCLUDED.csv_value,
csv_value_normalized = EXCLUDED.csv_value_normalized,
ocr_value = EXCLUDED.ocr_value,
ocr_value_normalized = EXCLUDED.ocr_value_normalized,
all_ocr_candidates = EXCLUDED.all_ocr_candidates,
matched = EXCLUDED.matched,
match_score = EXCLUDED.match_score,
pdf_path = EXCLUDED.pdf_path,
pdf_type = EXCLUDED.pdf_type,
csv_filename = EXCLUDED.csv_filename,
page_no = EXCLUDED.page_no,
bbox = EXCLUDED.bbox,
error = EXCLUDED.error,
reprocessed_at = NOW()
""", (
r['document_id'],
r['field_name'],
r['csv_value'],
r['csv_value_normalized'],
r['ocr_value'],
r['ocr_value_normalized'],
json.dumps(r['all_ocr_candidates']),
r['matched'],
r['match_score'],
r['pdf_path'],
r['pdf_type'],
r['csv_filename'],
r['page_no'],
json.dumps(r['bbox']) if r['bbox'] else None,
r['error']
))
conn.commit()
def _prepare_tasks(db, limit) -> list:
    """Build (doc_info, failed_fields, csv_filename) tuples for documents with failed fields."""
    print("Fetching failed documents...")
    failed_docs = get_failed_documents(db)
    print(f"Found {len(failed_docs)} documents with failed matches")
    # `is not None` so that an explicit --limit 0 is honoured rather than ignored.
    if limit is not None:
        failed_docs = failed_docs[:limit]
        print(f"Limited to {len(failed_docs)} documents")
    tasks = []
    for doc in failed_docs:
        failed_fields = get_failed_fields_for_document(db, doc['document_id'])
        if failed_fields:
            tasks.append((doc, failed_fields, find_csv_filename(doc['document_id'])))
    return tasks


def _save_batch(db, batch: list) -> int:
    """Save *batch* and return how many rows were written (0 if the save failed)."""
    if not batch:
        return 0
    try:
        save_results_batch(db, batch)
    except Exception as e:
        # A failed statement leaves the transaction aborted; roll back so later
        # batches can still be written, and drop this batch instead of
        # re-sending it with every subsequent batch.
        db.connect().rollback()
        tqdm.write(f"Error saving batch of {len(batch)} results: {e}")
        return 0
    return len(batch)


def _process_tasks(db, tasks: list, workers: int, batch_size: int = 50) -> int:
    """Run process_single_document over *tasks* in a process pool, saving as results arrive.

    Returns the number of result rows actually written to failed_match_details.
    """
    saved = 0
    pending = []
    with ProcessPoolExecutor(max_workers=workers, initializer=init_worker) as executor:
        futures = {executor.submit(process_single_document, task): task[0]['document_id']
                   for task in tasks}
        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"):
            doc_id = futures[future]
            # as_completed only yields finished futures, so result() never
            # blocks; there is no per-document timeout.
            try:
                pending.extend(future.result())
            except Exception as e:
                tqdm.write(f"Error processing {doc_id}: {e}")
                continue
            if len(pending) >= batch_size:
                saved += _save_batch(db, pending)
                pending = []
    saved += _save_batch(db, pending)
    return saved


def _print_summary(db):
    """Print per-field row totals, OCR-hit counts and mean match score."""
    conn = db.connect()
    with conn.cursor() as cursor:
        cursor.execute("""
SELECT field_name, COUNT(*) as total,
COUNT(*) FILTER (WHERE ocr_value IS NOT NULL) as has_ocr,
COALESCE(AVG(match_score), 0) as avg_score
FROM failed_match_details
GROUP BY field_name
ORDER BY total DESC
""")
        rows = cursor.fetchall()
    print("\nSummary by field:")
    print("-" * 70)
    print(f"{'Field':<35} {'Total':>8} {'Has OCR':>10} {'Avg Score':>12}")
    print("-" * 70)
    for row in rows:
        print(f"{row[0]:<35} {row[1]:>8} {row[2]:>10} {row[3]:>12.2f}")


def main():
    """CLI entry point: rebuild failed_match_details from scratch.

    Indexes the CSV files by document ID, drops and recreates the detail
    table, re-extracts text/OCR for every document with an unmatched field
    in a process pool, saves the results in batches, and prints a per-field
    summary. The database connection is always closed on exit.
    """
    parser = argparse.ArgumentParser(description='Re-process failed matches')
    parser.add_argument('--csv', required=True, help='CSV files glob pattern')
    # NOTE(review): --pdf-dir is parsed but never used; PDF locations come
    # from documents.pdf_path in the database.
    parser.add_argument('--pdf-dir', required=True, help='PDF directory')
    parser.add_argument('--workers', type=int, default=3, help='Number of workers')
    parser.add_argument('--limit', type=int, help='Limit number of documents to process')
    args = parser.parse_args()

    csv_files = sorted(glob.glob(args.csv))
    print(f"Found {len(csv_files)} CSV files")
    print("Building CSV filename cache...")
    build_csv_cache(csv_files)
    print(f"Cached {len(_csv_cache)} document IDs")

    db = DocumentDB()
    db.connect()
    try:
        create_failed_match_table(db)
        tasks = _prepare_tasks(db, args.limit)
        print(f"Processing {len(tasks)} documents with {args.workers} workers...")
        saved = _process_tasks(db, tasks, args.workers)
        print(f"\nDone! Saved {saved} failed match records to failed_match_details table")
        _print_summary(db)
    finally:
        db.close()


if __name__ == '__main__':
    main()