#!/usr/bin/env python3
"""
Re-process failed matches and store detailed information including OCR values,
CSV values, and source CSV filename in a new table.
"""

import argparse
import json
import glob
import os
import sys
import time
from pathlib import Path
from datetime import datetime
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
from tqdm import tqdm

sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from src.data.db import DocumentDB
from src.data.csv_loader import CSVLoader
from src.normalize.normalizer import normalize_field
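
# Example invocation (script name and paths are illustrative; adjust to your layout):
#   python reprocess_failed_matches.py --csv "exports/*.csv" --pdf-dir pdfs/ --workers 3 --limit 100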


def create_failed_match_table(db: DocumentDB):
    """Create the failed_match_details table."""
    conn = db.connect()
    with conn.cursor() as cursor:
        cursor.execute("""
            DROP TABLE IF EXISTS failed_match_details;

            CREATE TABLE failed_match_details (
                id SERIAL PRIMARY KEY,
                document_id TEXT NOT NULL,
                field_name TEXT NOT NULL,
                csv_value TEXT,
                csv_value_normalized TEXT,
                ocr_value TEXT,
                ocr_value_normalized TEXT,
                all_ocr_candidates JSONB,
                matched BOOLEAN DEFAULT FALSE,
                match_score REAL,
                pdf_path TEXT,
                pdf_type TEXT,
                csv_filename TEXT,
                page_no INTEGER,
                bbox JSONB,
                error TEXT,
                reprocessed_at TIMESTAMPTZ DEFAULT NOW(),

                UNIQUE(document_id, field_name)
            );

            CREATE INDEX IF NOT EXISTS idx_failed_match_document_id ON failed_match_details(document_id);
            CREATE INDEX IF NOT EXISTS idx_failed_match_field_name ON failed_match_details(field_name);
            CREATE INDEX IF NOT EXISTS idx_failed_match_csv_filename ON failed_match_details(csv_filename);
            CREATE INDEX IF NOT EXISTS idx_failed_match_matched ON failed_match_details(matched);
        """)
    conn.commit()
    print("Created table: failed_match_details")
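
# Example follow-up query once the table is populated (illustrative): surface
# near-misses where the OCR text came close to the CSV value but fell short of
# the 0.8 threshold, which usually points at normalization gaps.
#
#   SELECT document_id, field_name, csv_value, ocr_value, match_score
#   FROM failed_match_details
#   WHERE matched = false AND match_score > 0.5
#   ORDER BY match_score DESC;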


def get_failed_documents(db: DocumentDB) -> list:
    """Get all documents that have at least one failed field match."""
    conn = db.connect()
    with conn.cursor() as cursor:
        cursor.execute("""
            SELECT DISTINCT fr.document_id, d.pdf_path, d.pdf_type
            FROM field_results fr
            JOIN documents d ON fr.document_id = d.document_id
            WHERE fr.matched = false
            ORDER BY fr.document_id
        """)
        return [{'document_id': row[0], 'pdf_path': row[1], 'pdf_type': row[2]}
                for row in cursor.fetchall()]


def get_failed_fields_for_document(db: DocumentDB, doc_id: str) -> list:
    """Get all failed field results for a document."""
    conn = db.connect()
    with conn.cursor() as cursor:
        cursor.execute("""
            SELECT field_name, csv_value, error
            FROM field_results
            WHERE document_id = %s AND matched = false
        """, (doc_id,))
        return [{'field_name': row[0], 'csv_value': row[1], 'error': row[2]}
                for row in cursor.fetchall()]


# Cache for CSV data (document_id -> source CSV filename)
_csv_cache = {}


def build_csv_cache(csv_files: list):
    """Build a cache of document_id to csv_filename mapping."""
    global _csv_cache
    _csv_cache = {}

    for csv_file in csv_files:
        csv_filename = os.path.basename(csv_file)
        loader = CSVLoader(csv_file)
        for row in loader.iter_rows():
            if row.DocumentId not in _csv_cache:
                _csv_cache[row.DocumentId] = csv_filename


def find_csv_filename(doc_id: str) -> str:
    """Find which CSV file contains the document ID (None if not cached)."""
    return _csv_cache.get(doc_id)
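
# Note: _csv_cache lives only in the parent process. The worker tasks built in
# main() carry the resolved csv_filename with them, so find_csv_filename() is
# never called from inside a worker.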


def init_worker():
    """Initialize worker process."""
    import os
    import warnings
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    os.environ["GLOG_minloglevel"] = "2"
    warnings.filterwarnings("ignore")
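
# Scoring note for process_single_document below (example numbers are
# illustrative): when one normalized string contains the other, the score is
# the length ratio shorter/longer, so a 5-character CSV value found inside a
# 9-character OCR string scores 5/9, about 0.56, and an exact match scores 1.0.
# A field is only counted as matched when the best score exceeds 0.8.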


def process_single_document(args):
    """Process a single document and extract OCR values for failed fields."""
    doc_info, failed_fields, csv_filename = args
    doc_id = doc_info['document_id']
    pdf_path = doc_info['pdf_path']
    pdf_type = doc_info['pdf_type']

    results = []

    # Try to extract OCR from PDF
    try:
        if pdf_path and os.path.exists(pdf_path):
            from src.pdf import PDFDocument
            from src.ocr import OCREngine

            pdf_doc = PDFDocument(pdf_path)
            is_scanned = pdf_doc.detect_type() == "scanned"

            # Collect all OCR text blocks
            all_ocr_texts = []

            if is_scanned:
                # Use OCR for scanned PDFs
                ocr_engine = OCREngine()
                for page_no in range(pdf_doc.page_count):
                    # Render page to image
                    img = pdf_doc.render_page(page_no, dpi=150)
                    if img is None:
                        continue

                    # OCR the image
                    ocr_results = ocr_engine.extract_from_image(img)
                    for block in ocr_results:
                        all_ocr_texts.append({
                            'text': block.get('text', ''),
                            'bbox': block.get('bbox'),
                            'page_no': page_no
                        })
            else:
                # Use text extraction for text PDFs
                for page_no in range(pdf_doc.page_count):
                    tokens = list(pdf_doc.extract_text_tokens(page_no))
                    for token in tokens:
                        all_ocr_texts.append({
                            'text': token.text,
                            'bbox': token.bbox,
                            'page_no': page_no
                        })

            # For each failed field, try to find matching OCR
            for field in failed_fields:
                field_name = field['field_name']
                csv_value = field['csv_value']
                error = field['error']

                # Normalize CSV value
                csv_normalized = normalize_field(field_name, csv_value) if csv_value else None

                # Try to find best match in OCR
                best_score = 0
                best_ocr = None
                best_bbox = None
                best_page = None

                for ocr_block in all_ocr_texts:
                    ocr_text = ocr_block['text']
                    if not ocr_text:
                        continue
                    ocr_normalized = normalize_field(field_name, ocr_text)

                    # Calculate similarity
                    if csv_normalized and ocr_normalized:
                        # Exact match: equality implies containment, so test it
                        # first and stop searching.
                        if csv_normalized == ocr_normalized:
                            best_score = 1.0
                            best_ocr = ocr_text
                            best_bbox = ocr_block['bbox']
                            best_page = ocr_block['page_no']
                            break
                        # Otherwise score substring containment by length ratio
                        elif csv_normalized in ocr_normalized:
                            score = len(csv_normalized) / max(len(ocr_normalized), 1)
                            if score > best_score:
                                best_score = score
                                best_ocr = ocr_text
                                best_bbox = ocr_block['bbox']
                                best_page = ocr_block['page_no']
                        elif ocr_normalized in csv_normalized:
                            score = len(ocr_normalized) / max(len(csv_normalized), 1)
                            if score > best_score:
                                best_score = score
                                best_ocr = ocr_text
                                best_bbox = ocr_block['bbox']
                                best_page = ocr_block['page_no']

                results.append({
                    'document_id': doc_id,
                    'field_name': field_name,
                    'csv_value': csv_value,
                    'csv_value_normalized': csv_normalized,
                    'ocr_value': best_ocr,
                    'ocr_value_normalized': normalize_field(field_name, best_ocr) if best_ocr else None,
                    'all_ocr_candidates': [t['text'] for t in all_ocr_texts[:100]],  # Limit to 100
                    'matched': best_score > 0.8,
                    'match_score': best_score,
                    'pdf_path': pdf_path,
                    'pdf_type': pdf_type,
                    'csv_filename': csv_filename,
                    'page_no': best_page,
                    'bbox': list(best_bbox) if best_bbox else None,
                    'error': error
                })
        else:
            # PDF not found
            for field in failed_fields:
                results.append({
                    'document_id': doc_id,
                    'field_name': field['field_name'],
                    'csv_value': field['csv_value'],
                    'csv_value_normalized': normalize_field(field['field_name'], field['csv_value']) if field['csv_value'] else None,
                    'ocr_value': None,
                    'ocr_value_normalized': None,
                    'all_ocr_candidates': [],
                    'matched': False,
                    'match_score': 0,
                    'pdf_path': pdf_path,
                    'pdf_type': pdf_type,
                    'csv_filename': csv_filename,
                    'page_no': None,
                    'bbox': None,
                    'error': f"PDF not found: {pdf_path}"
                })

    except Exception as e:
        for field in failed_fields:
            results.append({
                'document_id': doc_id,
                'field_name': field['field_name'],
                'csv_value': field['csv_value'],
                'csv_value_normalized': None,
                'ocr_value': None,
                'ocr_value_normalized': None,
                'all_ocr_candidates': [],
                'matched': False,
                'match_score': 0,
                'pdf_path': pdf_path,
                'pdf_type': pdf_type,
                'csv_filename': csv_filename,
                'page_no': None,
                'bbox': None,
                'error': str(e)
            })

    return results
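
# Rows are upserted on (document_id, field_name), the UNIQUE key defined in
# create_failed_match_table(), so saving the same document/field twice
# overwrites the earlier row instead of duplicating it.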


def save_results_batch(db: DocumentDB, results: list):
    """Save results to failed_match_details table."""
    if not results:
        return

    conn = db.connect()
    with conn.cursor() as cursor:
        for r in results:
            cursor.execute("""
                INSERT INTO failed_match_details
                    (document_id, field_name, csv_value, csv_value_normalized,
                     ocr_value, ocr_value_normalized, all_ocr_candidates,
                     matched, match_score, pdf_path, pdf_type, csv_filename,
                     page_no, bbox, error)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                ON CONFLICT (document_id, field_name) DO UPDATE SET
                    csv_value = EXCLUDED.csv_value,
                    csv_value_normalized = EXCLUDED.csv_value_normalized,
                    ocr_value = EXCLUDED.ocr_value,
                    ocr_value_normalized = EXCLUDED.ocr_value_normalized,
                    all_ocr_candidates = EXCLUDED.all_ocr_candidates,
                    matched = EXCLUDED.matched,
                    match_score = EXCLUDED.match_score,
                    pdf_path = EXCLUDED.pdf_path,
                    pdf_type = EXCLUDED.pdf_type,
                    csv_filename = EXCLUDED.csv_filename,
                    page_no = EXCLUDED.page_no,
                    bbox = EXCLUDED.bbox,
                    error = EXCLUDED.error,
                    reprocessed_at = NOW()
            """, (
                r['document_id'],
                r['field_name'],
                r['csv_value'],
                r['csv_value_normalized'],
                r['ocr_value'],
                r['ocr_value_normalized'],
                json.dumps(r['all_ocr_candidates']),
                r['matched'],
                r['match_score'],
                r['pdf_path'],
                r['pdf_type'],
                r['csv_filename'],
                r['page_no'],
                json.dumps(r['bbox']) if r['bbox'] else None,
                r['error']
            ))
    conn.commit()
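
# Overall flow of main(): build the CSV-filename cache, recreate the
# failed_match_details table, fan the failed documents out to worker
# processes, upsert results in batches of 50, then print a per-field summary.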


def main():
    parser = argparse.ArgumentParser(description='Re-process failed matches')
    parser.add_argument('--csv', required=True, help='CSV files glob pattern')
    parser.add_argument('--pdf-dir', required=True, help='PDF directory')
    parser.add_argument('--workers', type=int, default=3, help='Number of workers')
    parser.add_argument('--limit', type=int, help='Limit number of documents to process')
    args = parser.parse_args()

    # Expand CSV glob
    csv_files = sorted(glob.glob(args.csv))
    print(f"Found {len(csv_files)} CSV files")

    # Build CSV cache
    print("Building CSV filename cache...")
    build_csv_cache(csv_files)
    print(f"Cached {len(_csv_cache)} document IDs")

    # Connect to database
    db = DocumentDB()
    db.connect()

    # Create new table
    create_failed_match_table(db)

    # Get all failed documents
    print("Fetching failed documents...")
    failed_docs = get_failed_documents(db)
    print(f"Found {len(failed_docs)} documents with failed matches")

    if args.limit:
        failed_docs = failed_docs[:args.limit]
        print(f"Limited to {len(failed_docs)} documents")

    # Prepare tasks
    tasks = []
    for doc in failed_docs:
        failed_fields = get_failed_fields_for_document(db, doc['document_id'])
        csv_filename = find_csv_filename(doc['document_id'])
        if failed_fields:
            tasks.append((doc, failed_fields, csv_filename))

    print(f"Processing {len(tasks)} documents with {args.workers} workers...")

    # Process with multiprocessing
    total_results = 0
    batch_results = []
    batch_size = 50

    with ProcessPoolExecutor(max_workers=args.workers, initializer=init_worker) as executor:
        futures = {executor.submit(process_single_document, task): task[0]['document_id']
                   for task in tasks}

        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"):
            doc_id = futures[future]
            try:
                results = future.result(timeout=120)
                batch_results.extend(results)
                total_results += len(results)

                # Save in batches
                if len(batch_results) >= batch_size:
                    save_results_batch(db, batch_results)
                    batch_results = []

            except TimeoutError:
                print(f"\nTimeout processing {doc_id}")
            except Exception as e:
                print(f"\nError processing {doc_id}: {e}")

    # Save remaining results
    if batch_results:
        save_results_batch(db, batch_results)

    print(f"\nDone! Saved {total_results} failed match records to failed_match_details table")

    # Show summary
    conn = db.connect()
    with conn.cursor() as cursor:
        cursor.execute("""
            SELECT field_name, COUNT(*) as total,
                   COUNT(*) FILTER (WHERE ocr_value IS NOT NULL) as has_ocr,
                   COALESCE(AVG(match_score), 0) as avg_score
            FROM failed_match_details
            GROUP BY field_name
            ORDER BY total DESC
        """)
        print("\nSummary by field:")
        print("-" * 70)
        print(f"{'Field':<35} {'Total':>8} {'Has OCR':>10} {'Avg Score':>12}")
        print("-" * 70)
        for row in cursor.fetchall():
            print(f"{row[0]:<35} {row[1]:>8} {row[2]:>10} {row[3]:>12.2f}")

    db.close()


if __name__ == '__main__':
    main()