code issue fix: page-count bug in get_pdf_type, customer_number matching, retry-document cleanup, CSV column aliases, OCR bottom-region scanning, and a new reprocess_failed CLI
@@ -10,6 +10,7 @@ import sys
 import time
 import os
 import signal
+import shutil
 import warnings
 from pathlib import Path
 from tqdm import tqdm
@@ -107,6 +108,7 @@ def process_single_document(args_tuple):
     Returns:
         dict with results
     """
+    import shutil
     row_dict, pdf_path_str, output_dir_str, dpi, min_confidence, skip_ocr = args_tuple
 
     # Import inside worker to avoid pickling issues
@@ -121,6 +123,11 @@ def process_single_document(args_tuple):
     output_dir = Path(output_dir_str)
     doc_id = row_dict['DocumentId']
 
+    # Clean up existing temp folder for this document (for re-matching)
+    temp_doc_dir = output_dir / 'temp' / doc_id
+    if temp_doc_dir.exists():
+        shutil.rmtree(temp_doc_dir, ignore_errors=True)
+
     report = AutoLabelReport(document_id=doc_id)
     report.pdf_path = str(pdf_path)
     # Store metadata fields from CSV
@@ -602,6 +609,9 @@ def main():
         else:
             remaining_limit = float('inf')
 
+        # Collect doc_ids that need retry (for batch delete)
+        retry_doc_ids = []
+
         for row in rows:
             # Stop adding tasks if we've reached the limit
             if len(tasks) >= remaining_limit:
@@ -622,6 +632,7 @@ def main():
             if db_status is False:
                 stats['retried'] += 1
                 retry_in_csv += 1
+                retry_doc_ids.append(doc_id)
 
             pdf_path = single_loader.get_pdf_path(row)
             if not pdf_path:
@@ -637,12 +648,12 @@ def main():
                 'Bankgiro': row.Bankgiro,
                 'Plusgiro': row.Plusgiro,
                 'Amount': row.Amount,
-                # New fields
+                # New fields for matching
                 'supplier_organisation_number': row.supplier_organisation_number,
                 'supplier_accounts': row.supplier_accounts,
+                'customer_number': row.customer_number,
                 # Metadata fields (not for matching, but for database storage)
                 'split': row.split,
-                'customer_number': row.customer_number,
                 'supplier_name': row.supplier_name,
             }
 
@@ -658,6 +669,22 @@ def main():
         if skipped_in_csv > 0 or retry_in_csv > 0:
             print(f" Skipped {skipped_in_csv} (already in DB), retrying {retry_in_csv} failed")
 
+        # Clean up retry documents: delete from database and remove temp folders
+        if retry_doc_ids:
+            # Batch delete from database (field_results will be cascade deleted)
+            with db.connect().cursor() as cursor:
+                cursor.execute(
+                    "DELETE FROM documents WHERE document_id = ANY(%s)",
+                    (retry_doc_ids,)
+                )
+            db.connect().commit()
+            # Remove temp folders
+            for doc_id in retry_doc_ids:
+                temp_doc_dir = output_dir / 'temp' / doc_id
+                if temp_doc_dir.exists():
+                    shutil.rmtree(temp_doc_dir, ignore_errors=True)
+            print(f" Cleaned up {len(retry_doc_ids)} retry documents (DB + temp folders)")
+
         if not tasks:
             continue
 
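
The batch delete above sends the whole retry list in one statement: psycopg2 adapts a Python list to a PostgreSQL array, so document_id = ANY(%s) removes every listed row, and the related field_results rows disappear only if that foreign key is declared ON DELETE CASCADE (as the comment assumes). A minimal standalone sketch of the same pattern; the DSN and ids below are illustrative assumptions, not values from this repository:

    import psycopg2

    retry_doc_ids = ["doc-001", "doc-002"]  # hypothetical document ids

    conn = psycopg2.connect("dbname=documents")  # assumed DSN
    with conn.cursor() as cursor:
        # The list parameter is adapted to a PostgreSQL array for ANY(%s)
        cursor.execute(
            "DELETE FROM documents WHERE document_id = ANY(%s)",
            (retry_doc_ids,),
        )
    conn.commit()
    conn.close()
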
src/cli/reprocess_failed.py (new file, 424 lines)
@@ -0,0 +1,424 @@
+#!/usr/bin/env python3
+"""
+Re-process failed matches and store detailed information including OCR values,
+CSV values, and source CSV filename in a new table.
+"""
+
+import argparse
+import json
+import glob
+import os
+import sys
+import time
+from pathlib import Path
+from datetime import datetime
+from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
+from tqdm import tqdm
+
+sys.path.insert(0, str(Path(__file__).parent.parent.parent))
+
+from src.data.db import DocumentDB
+from src.data.csv_loader import CSVLoader
+from src.normalize.normalizer import normalize_field
+
+
+def create_failed_match_table(db: DocumentDB):
+    """Create the failed_match_details table."""
+    conn = db.connect()
+    with conn.cursor() as cursor:
+        cursor.execute("""
+            DROP TABLE IF EXISTS failed_match_details;
+
+            CREATE TABLE failed_match_details (
+                id SERIAL PRIMARY KEY,
+                document_id TEXT NOT NULL,
+                field_name TEXT NOT NULL,
+                csv_value TEXT,
+                csv_value_normalized TEXT,
+                ocr_value TEXT,
+                ocr_value_normalized TEXT,
+                all_ocr_candidates JSONB,
+                matched BOOLEAN DEFAULT FALSE,
+                match_score REAL,
+                pdf_path TEXT,
+                pdf_type TEXT,
+                csv_filename TEXT,
+                page_no INTEGER,
+                bbox JSONB,
+                error TEXT,
+                reprocessed_at TIMESTAMPTZ DEFAULT NOW(),
+
+                UNIQUE(document_id, field_name)
+            );
+
+            CREATE INDEX IF NOT EXISTS idx_failed_match_document_id ON failed_match_details(document_id);
+            CREATE INDEX IF NOT EXISTS idx_failed_match_field_name ON failed_match_details(field_name);
+            CREATE INDEX IF NOT EXISTS idx_failed_match_csv_filename ON failed_match_details(csv_filename);
+            CREATE INDEX IF NOT EXISTS idx_failed_match_matched ON failed_match_details(matched);
+        """)
+    conn.commit()
+    print("Created table: failed_match_details")
+
+
+def get_failed_documents(db: DocumentDB) -> list:
+    """Get all documents that have at least one failed field match."""
+    conn = db.connect()
+    with conn.cursor() as cursor:
+        cursor.execute("""
+            SELECT DISTINCT fr.document_id, d.pdf_path, d.pdf_type
+            FROM field_results fr
+            JOIN documents d ON fr.document_id = d.document_id
+            WHERE fr.matched = false
+            ORDER BY fr.document_id
+        """)
+        return [{'document_id': row[0], 'pdf_path': row[1], 'pdf_type': row[2]}
+                for row in cursor.fetchall()]
+
+
+def get_failed_fields_for_document(db: DocumentDB, doc_id: str) -> list:
+    """Get all failed field results for a document."""
+    conn = db.connect()
+    with conn.cursor() as cursor:
+        cursor.execute("""
+            SELECT field_name, csv_value, error
+            FROM field_results
+            WHERE document_id = %s AND matched = false
+        """, (doc_id,))
+        return [{'field_name': row[0], 'csv_value': row[1], 'error': row[2]}
+                for row in cursor.fetchall()]
+
+
+# Cache for CSV data
+_csv_cache = {}
+
+def build_csv_cache(csv_files: list):
+    """Build a cache of document_id to csv_filename mapping."""
+    global _csv_cache
+    _csv_cache = {}
+
+    for csv_file in csv_files:
+        csv_filename = os.path.basename(csv_file)
+        loader = CSVLoader(csv_file)
+        for row in loader.iter_rows():
+            if row.DocumentId not in _csv_cache:
+                _csv_cache[row.DocumentId] = csv_filename
+
+
+def find_csv_filename(doc_id: str) -> str:
+    """Find which CSV file contains the document ID."""
+    return _csv_cache.get(doc_id, None)
+
+
+def init_worker():
+    """Initialize worker process."""
+    import os
+    import warnings
+    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+    os.environ["GLOG_minloglevel"] = "2"
+    warnings.filterwarnings("ignore")
+
+
+def process_single_document(args):
+    """Process a single document and extract OCR values for failed fields."""
+    doc_info, failed_fields, csv_filename = args
+    doc_id = doc_info['document_id']
+    pdf_path = doc_info['pdf_path']
+    pdf_type = doc_info['pdf_type']
+
+    results = []
+
+    # Try to extract OCR from PDF
+    try:
+        if pdf_path and os.path.exists(pdf_path):
+            from src.pdf import PDFDocument
+            from src.ocr import OCREngine
+
+            pdf_doc = PDFDocument(pdf_path)
+            is_scanned = pdf_doc.detect_type() == "scanned"
+
+            # Collect all OCR text blocks
+            all_ocr_texts = []
+
+            if is_scanned:
+                # Use OCR for scanned PDFs
+                ocr_engine = OCREngine()
+                for page_no in range(pdf_doc.page_count):
+                    # Render page to image
+                    img = pdf_doc.render_page(page_no, dpi=150)
+                    if img is None:
+                        continue
+
+                    # OCR the image
+                    ocr_results = ocr_engine.extract_from_image(img)
+                    for block in ocr_results:
+                        all_ocr_texts.append({
+                            'text': block.get('text', ''),
+                            'bbox': block.get('bbox'),
+                            'page_no': page_no
+                        })
+            else:
+                # Use text extraction for text PDFs
+                for page_no in range(pdf_doc.page_count):
+                    tokens = list(pdf_doc.extract_text_tokens(page_no))
+                    for token in tokens:
+                        all_ocr_texts.append({
+                            'text': token.text,
+                            'bbox': token.bbox,
+                            'page_no': page_no
+                        })
+
+            # For each failed field, try to find matching OCR
+            for field in failed_fields:
+                field_name = field['field_name']
+                csv_value = field['csv_value']
+                error = field['error']
+
+                # Normalize CSV value
+                csv_normalized = normalize_field(field_name, csv_value) if csv_value else None
+
+                # Try to find best match in OCR
+                best_score = 0
+                best_ocr = None
+                best_bbox = None
+                best_page = None
+
+                for ocr_block in all_ocr_texts:
+                    ocr_text = ocr_block['text']
+                    if not ocr_text:
+                        continue
+                    ocr_normalized = normalize_field(field_name, ocr_text)
+
+                    # Calculate similarity
+                    if csv_normalized and ocr_normalized:
+                        # Check substring match
+                        if csv_normalized in ocr_normalized:
+                            score = len(csv_normalized) / max(len(ocr_normalized), 1)
+                            if score > best_score:
+                                best_score = score
+                                best_ocr = ocr_text
+                                best_bbox = ocr_block['bbox']
+                                best_page = ocr_block['page_no']
+                        elif ocr_normalized in csv_normalized:
+                            score = len(ocr_normalized) / max(len(csv_normalized), 1)
+                            if score > best_score:
+                                best_score = score
+                                best_ocr = ocr_text
+                                best_bbox = ocr_block['bbox']
+                                best_page = ocr_block['page_no']
+                        # Exact match
+                        elif csv_normalized == ocr_normalized:
+                            best_score = 1.0
+                            best_ocr = ocr_text
+                            best_bbox = ocr_block['bbox']
+                            best_page = ocr_block['page_no']
+                            break
+
+                results.append({
+                    'document_id': doc_id,
+                    'field_name': field_name,
+                    'csv_value': csv_value,
+                    'csv_value_normalized': csv_normalized,
+                    'ocr_value': best_ocr,
+                    'ocr_value_normalized': normalize_field(field_name, best_ocr) if best_ocr else None,
+                    'all_ocr_candidates': [t['text'] for t in all_ocr_texts[:100]],  # Limit to 100
+                    'matched': best_score > 0.8,
+                    'match_score': best_score,
+                    'pdf_path': pdf_path,
+                    'pdf_type': pdf_type,
+                    'csv_filename': csv_filename,
+                    'page_no': best_page,
+                    'bbox': list(best_bbox) if best_bbox else None,
+                    'error': error
+                })
+        else:
+            # PDF not found
+            for field in failed_fields:
+                results.append({
+                    'document_id': doc_id,
+                    'field_name': field['field_name'],
+                    'csv_value': field['csv_value'],
+                    'csv_value_normalized': normalize_field(field['field_name'], field['csv_value']) if field['csv_value'] else None,
+                    'ocr_value': None,
+                    'ocr_value_normalized': None,
+                    'all_ocr_candidates': [],
+                    'matched': False,
+                    'match_score': 0,
+                    'pdf_path': pdf_path,
+                    'pdf_type': pdf_type,
+                    'csv_filename': csv_filename,
+                    'page_no': None,
+                    'bbox': None,
+                    'error': f"PDF not found: {pdf_path}"
+                })
+
+    except Exception as e:
+        for field in failed_fields:
+            results.append({
+                'document_id': doc_id,
+                'field_name': field['field_name'],
+                'csv_value': field['csv_value'],
+                'csv_value_normalized': None,
+                'ocr_value': None,
+                'ocr_value_normalized': None,
+                'all_ocr_candidates': [],
+                'matched': False,
+                'match_score': 0,
+                'pdf_path': pdf_path,
+                'pdf_type': pdf_type,
+                'csv_filename': csv_filename,
+                'page_no': None,
+                'bbox': None,
+                'error': str(e)
+            })
+
+    return results
+
+
+def save_results_batch(db: DocumentDB, results: list):
+    """Save results to failed_match_details table."""
+    if not results:
+        return
+
+    conn = db.connect()
+    with conn.cursor() as cursor:
+        for r in results:
+            cursor.execute("""
+                INSERT INTO failed_match_details
+                (document_id, field_name, csv_value, csv_value_normalized,
+                 ocr_value, ocr_value_normalized, all_ocr_candidates,
+                 matched, match_score, pdf_path, pdf_type, csv_filename,
+                 page_no, bbox, error)
+                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
+                ON CONFLICT (document_id, field_name) DO UPDATE SET
+                    csv_value = EXCLUDED.csv_value,
+                    csv_value_normalized = EXCLUDED.csv_value_normalized,
+                    ocr_value = EXCLUDED.ocr_value,
+                    ocr_value_normalized = EXCLUDED.ocr_value_normalized,
+                    all_ocr_candidates = EXCLUDED.all_ocr_candidates,
+                    matched = EXCLUDED.matched,
+                    match_score = EXCLUDED.match_score,
+                    pdf_path = EXCLUDED.pdf_path,
+                    pdf_type = EXCLUDED.pdf_type,
+                    csv_filename = EXCLUDED.csv_filename,
+                    page_no = EXCLUDED.page_no,
+                    bbox = EXCLUDED.bbox,
+                    error = EXCLUDED.error,
+                    reprocessed_at = NOW()
+            """, (
+                r['document_id'],
+                r['field_name'],
+                r['csv_value'],
+                r['csv_value_normalized'],
+                r['ocr_value'],
+                r['ocr_value_normalized'],
+                json.dumps(r['all_ocr_candidates']),
+                r['matched'],
+                r['match_score'],
+                r['pdf_path'],
+                r['pdf_type'],
+                r['csv_filename'],
+                r['page_no'],
+                json.dumps(r['bbox']) if r['bbox'] else None,
+                r['error']
+            ))
+    conn.commit()
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Re-process failed matches')
+    parser.add_argument('--csv', required=True, help='CSV files glob pattern')
+    parser.add_argument('--pdf-dir', required=True, help='PDF directory')
+    parser.add_argument('--workers', type=int, default=3, help='Number of workers')
+    parser.add_argument('--limit', type=int, help='Limit number of documents to process')
+    args = parser.parse_args()
+
+    # Expand CSV glob
+    csv_files = sorted(glob.glob(args.csv))
+    print(f"Found {len(csv_files)} CSV files")
+
+    # Build CSV cache
+    print("Building CSV filename cache...")
+    build_csv_cache(csv_files)
+    print(f"Cached {len(_csv_cache)} document IDs")
+
+    # Connect to database
+    db = DocumentDB()
+    db.connect()
+
+    # Create new table
+    create_failed_match_table(db)
+
+    # Get all failed documents
+    print("Fetching failed documents...")
+    failed_docs = get_failed_documents(db)
+    print(f"Found {len(failed_docs)} documents with failed matches")
+
+    if args.limit:
+        failed_docs = failed_docs[:args.limit]
+        print(f"Limited to {len(failed_docs)} documents")
+
+    # Prepare tasks
+    tasks = []
+    for doc in failed_docs:
+        failed_fields = get_failed_fields_for_document(db, doc['document_id'])
+        csv_filename = find_csv_filename(doc['document_id'])
+        if failed_fields:
+            tasks.append((doc, failed_fields, csv_filename))
+
+    print(f"Processing {len(tasks)} documents with {args.workers} workers...")
+
+    # Process with multiprocessing
+    total_results = 0
+    batch_results = []
+    batch_size = 50
+
+    with ProcessPoolExecutor(max_workers=args.workers, initializer=init_worker) as executor:
+        futures = {executor.submit(process_single_document, task): task[0]['document_id']
+                   for task in tasks}
+
+        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"):
+            doc_id = futures[future]
+            try:
+                results = future.result(timeout=120)
+                batch_results.extend(results)
+                total_results += len(results)
+
+                # Save in batches
+                if len(batch_results) >= batch_size:
+                    save_results_batch(db, batch_results)
+                    batch_results = []
+
+            except TimeoutError:
+                print(f"\nTimeout processing {doc_id}")
+            except Exception as e:
+                print(f"\nError processing {doc_id}: {e}")
+
+    # Save remaining results
+    if batch_results:
+        save_results_batch(db, batch_results)
+
+    print(f"\nDone! Saved {total_results} failed match records to failed_match_details table")
+
+    # Show summary
+    conn = db.connect()
+    with conn.cursor() as cursor:
+        cursor.execute("""
+            SELECT field_name, COUNT(*) as total,
+                   COUNT(*) FILTER (WHERE ocr_value IS NOT NULL) as has_ocr,
+                   COALESCE(AVG(match_score), 0) as avg_score
+            FROM failed_match_details
+            GROUP BY field_name
+            ORDER BY total DESC
+        """)
+        print("\nSummary by field:")
+        print("-" * 70)
+        print(f"{'Field':<35} {'Total':>8} {'Has OCR':>10} {'Avg Score':>12}")
+        print("-" * 70)
+        for row in cursor.fetchall():
+            print(f"{row[0]:<35} {row[1]:>8} {row[2]:>10} {row[3]:>12.2f}")
+
+    db.close()
+
+
+if __name__ == '__main__':
+    main()
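
The script's matching heuristic scores each OCR block by substring containment after normalization: the shorter normalized string must appear inside the longer one, the score is len(shorter)/len(longer), and anything above 0.8 is recorded as matched. A self-contained sketch of just that rule; normalize() below is a stand-in for normalize_field() and only lowercases and strips spaces, which is an assumption for illustration:

    def normalize(value: str) -> str:
        # Stand-in for normalize_field(field_name, value)
        return value.replace(" ", "").lower()

    def containment_score(csv_value: str, ocr_text: str) -> float:
        a, b = normalize(csv_value), normalize(ocr_text)
        if not a or not b:
            return 0.0
        if a in b:
            return len(a) / max(len(b), 1)
        if b in a:
            return len(b) / max(len(a), 1)
        return 0.0

    print(containment_score("1234567", "OCR: 1234567"))  # 7/11, below the 0.8 threshold
    print(containment_score("1234567", "1234567"))       # 1.0, recorded as matched
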
@@ -27,7 +27,7 @@ class InvoiceRow:
     Amount: Decimal | None = None
     # New fields
     split: str | None = None  # train/test split indicator
-    customer_number: str | None = None  # Customer number (no matching needed)
+    customer_number: str | None = None  # Customer number (needs matching)
     supplier_name: str | None = None  # Supplier name (no matching)
     supplier_organisation_number: str | None = None  # Swedish org number (needs matching)
     supplier_accounts: str | None = None  # Supplier accounts (needs matching)
@@ -198,22 +198,30 @@ class CSVLoader:
         value = value.strip()
         return value if value else None
 
+    def _get_field(self, row: dict, *keys: str) -> str | None:
+        """Get field value trying multiple possible column names."""
+        for key in keys:
+            value = row.get(key)
+            if value is not None:
+                return value
+        return None
+
     def _parse_row(self, row: dict) -> InvoiceRow | None:
         """Parse a single CSV row into InvoiceRow."""
-        doc_id = self._parse_string(row.get('DocumentId'))
+        doc_id = self._parse_string(self._get_field(row, 'DocumentId', 'document_id'))
         if not doc_id:
             return None
 
         return InvoiceRow(
             DocumentId=doc_id,
-            InvoiceDate=self._parse_date(row.get('InvoiceDate')),
+            InvoiceDate=self._parse_date(self._get_field(row, 'InvoiceDate', 'invoice_date')),
-            InvoiceNumber=self._parse_string(row.get('InvoiceNumber')),
+            InvoiceNumber=self._parse_string(self._get_field(row, 'InvoiceNumber', 'invoice_number')),
-            InvoiceDueDate=self._parse_date(row.get('InvoiceDueDate')),
+            InvoiceDueDate=self._parse_date(self._get_field(row, 'InvoiceDueDate', 'invoice_due_date')),
-            OCR=self._parse_string(row.get('OCR')),
+            OCR=self._parse_string(self._get_field(row, 'OCR', 'ocr')),
-            Message=self._parse_string(row.get('Message')),
+            Message=self._parse_string(self._get_field(row, 'Message', 'message')),
-            Bankgiro=self._parse_string(row.get('Bankgiro')),
+            Bankgiro=self._parse_string(self._get_field(row, 'Bankgiro', 'bankgiro')),
-            Plusgiro=self._parse_string(row.get('Plusgiro')),
+            Plusgiro=self._parse_string(self._get_field(row, 'Plusgiro', 'plusgiro')),
-            Amount=self._parse_amount(row.get('Amount')),
+            Amount=self._parse_amount(self._get_field(row, 'Amount', 'amount', 'invoice_data_amount')),
             # New fields
             split=self._parse_string(row.get('split')),
             customer_number=self._parse_string(row.get('customer_number')),
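
The new _get_field helper lets the loader accept either the original column names or the snake_case variants (for example Amount, amount, or invoice_data_amount) without duplicating the parsing calls; it simply returns the first key present in the row. A standalone rendition with a hypothetical row dict:

    def get_field(row: dict, *keys: str) -> str | None:
        # First key whose value is not None wins
        for key in keys:
            value = row.get(key)
            if value is not None:
                return value
        return None

    row = {"document_id": "abc-123", "invoice_data_amount": "1 234,50"}  # hypothetical row
    print(get_field(row, "DocumentId", "document_id"))                # "abc-123"
    print(get_field(row, "Amount", "amount", "invoice_data_amount"))  # "1 234,50"
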
@@ -219,7 +219,7 @@ class FieldMatcher:
         # Note: Amount is excluded because short numbers like "451" can incorrectly match
         # in OCR payment lines or other unrelated text
         if field_name in ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro',
-                          'supplier_organisation_number', 'supplier_accounts'):
+                          'supplier_organisation_number', 'supplier_accounts', 'customer_number'):
             substring_matches = self._find_substring_matches(page_tokens, value, field_name)
             matches.extend(substring_matches)
 
@@ -369,7 +369,7 @@ class FieldMatcher:
 
         # Supported fields for substring matching
         supported_fields = ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount',
-                            'supplier_organisation_number', 'supplier_accounts')
+                            'supplier_organisation_number', 'supplier_accounts', 'customer_number')
         if field_name not in supported_fields:
             return matches
 
@@ -383,10 +383,17 @@ class FieldMatcher:
                 continue
 
             # Check if value appears as substring (using normalized text)
+            # Try case-sensitive first, then case-insensitive
             if value in token_text_normalized:
-                # Verify it's a proper boundary match (not part of a larger number)
                 idx = token_text_normalized.find(value)
+                case_sensitive_match = True
+            elif value.lower() in token_text_normalized.lower():
+                idx = token_text_normalized.lower().find(value.lower())
+                case_sensitive_match = False
+            else:
+                continue
 
+            # Verify it's a proper boundary match (not part of a larger number)
             # Check character before (if exists)
             if idx > 0:
                 char_before = token_text_normalized[idx - 1]
@@ -417,12 +424,15 @@ class FieldMatcher:
             # Boost score if keyword is inline
             inline_boost = 0.1 if inline_context else 0
 
+            # Lower score for case-insensitive match
+            base_score = 0.75 if case_sensitive_match else 0.70
+
             matches.append(Match(
                 field=field_name,
                 value=value,
                 bbox=token.bbox,  # Use full token bbox
                 page_no=token.page_no,
-                score=min(1.0, 0.75 + context_boost + inline_boost),  # Lower than exact match
+                score=min(1.0, base_score + context_boost + inline_boost),
                 matched_text=token_text,
                 context_keywords=context_keywords + inline_context
             ))
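
With this change a substring hit starts from a base score of 0.75 when the case matches and 0.70 when only a case-insensitive match was found; the keyword boosts are then added and the result is capped at 1.0. A sketch of the arithmetic only; the inline_boost of 0.1 comes from the code above, while the context_boost value is an illustrative assumption:

    def substring_score(case_sensitive: bool, context_boost: float, inline_boost: float) -> float:
        base_score = 0.75 if case_sensitive else 0.70
        return min(1.0, base_score + context_boost + inline_boost)

    print(substring_score(True, 0.1, 0.1))   # 0.95
    print(substring_score(False, 0.0, 0.1))  # 0.80, penalised for the case-insensitive hit
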
@@ -260,6 +260,45 @@ class FieldNormalizer:
 
         return list(set(v for v in variants if v))
 
+    @staticmethod
+    def normalize_customer_number(value: str) -> list[str]:
+        """
+        Normalize customer number.
+
+        Customer numbers can have various formats:
+        - Alphanumeric codes: 'EMM 256-6', 'ABC123', 'A-1234'
+        - Pure numbers: '12345', '123-456'
+
+        Examples:
+            'EMM 256-6' -> ['EMM 256-6', 'EMM256-6', 'EMM2566']
+            'ABC 123' -> ['ABC 123', 'ABC123']
+        """
+        value = FieldNormalizer.clean_text(value)
+        variants = [value]
+
+        # Version without spaces
+        no_space = value.replace(' ', '')
+        if no_space != value:
+            variants.append(no_space)
+
+        # Version without dashes
+        no_dash = value.replace('-', '')
+        if no_dash != value:
+            variants.append(no_dash)
+
+        # Version without spaces and dashes
+        clean = value.replace(' ', '').replace('-', '')
+        if clean != value and clean not in variants:
+            variants.append(clean)
+
+        # Uppercase and lowercase versions
+        if value.upper() != value:
+            variants.append(value.upper())
+        if value.lower() != value:
+            variants.append(value.lower())
+
+        return list(set(v for v in variants if v))
+
     @staticmethod
     def normalize_amount(value: str) -> list[str]:
         """
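
For the docstring example, the new normalizer emits the raw value plus space-, dash- and case-variants, then deduplicates. A standalone sketch; FieldNormalizer.clean_text() is assumed to only strip surrounding whitespace here, which is an illustrative simplification:

    def customer_number_variants(value: str) -> list[str]:
        value = value.strip()  # stand-in for FieldNormalizer.clean_text()
        variants = [
            value,
            value.replace(" ", ""),
            value.replace("-", ""),
            value.replace(" ", "").replace("-", ""),
            value.upper(),
            value.lower(),
        ]
        return sorted(set(v for v in variants if v))

    print(customer_number_variants("EMM 256-6"))
    # includes 'EMM256-6', 'EMM 2566' and 'EMM2566' alongside the original value
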
@@ -414,7 +453,7 @@ class FieldNormalizer:
         ]
 
         # Ambiguous patterns - try both DD/MM and MM/DD interpretations
-        ambiguous_patterns = [
+        ambiguous_patterns_4digit_year = [
             # Format with / - could be DD/MM/YYYY (European) or MM/DD/YYYY (US)
             r'^(\d{1,2})/(\d{1,2})/(\d{4})$',
             # Format with . - typically European DD.MM.YYYY
@@ -423,6 +462,16 @@ class FieldNormalizer:
             r'^(\d{1,2})-(\d{1,2})-(\d{4})$',
         ]
 
+        # Patterns with 2-digit year (common in Swedish invoices)
+        ambiguous_patterns_2digit_year = [
+            # Format DD.MM.YY (e.g., 02.08.25 for 2025-08-02)
+            r'^(\d{1,2})\.(\d{1,2})\.(\d{2})$',
+            # Format DD/MM/YY
+            r'^(\d{1,2})/(\d{1,2})/(\d{2})$',
+            # Format DD-MM-YY
+            r'^(\d{1,2})-(\d{1,2})-(\d{2})$',
+        ]
+
         # Try unambiguous patterns first
         for pattern, extractor in date_patterns:
             match = re.match(pattern, value)
@@ -434,9 +483,9 @@ class FieldNormalizer:
                 except ValueError:
                     continue
 
-        # Try ambiguous patterns with both interpretations
+        # Try ambiguous patterns with 4-digit year
         if not parsed_dates:
-            for pattern in ambiguous_patterns:
+            for pattern in ambiguous_patterns_4digit_year:
                 match = re.match(pattern, value)
                 if match:
                     n1, n2, year = int(match[1]), int(match[2]), int(match[3])
@@ -457,6 +506,31 @@ class FieldNormalizer:
                     if parsed_dates:
                         break
 
+        # Try ambiguous patterns with 2-digit year (e.g., 02.08.25)
+        if not parsed_dates:
+            for pattern in ambiguous_patterns_2digit_year:
+                match = re.match(pattern, value)
+                if match:
+                    n1, n2, yy = int(match[1]), int(match[2]), int(match[3])
+                    # Convert 2-digit year to 4-digit (00-49 -> 2000s, 50-99 -> 1900s)
+                    year = 2000 + yy if yy < 50 else 1900 + yy
+
+                    # Try DD/MM/YY (European - day first, most common in Sweden)
+                    try:
+                        parsed_dates.append(datetime(year, n2, n1))
+                    except ValueError:
+                        pass
+
+                    # Try MM/DD/YY (US - month first) if different and valid
+                    if n1 != n2:
+                        try:
+                            parsed_dates.append(datetime(year, n1, n2))
+                        except ValueError:
+                            pass
+
+                    if parsed_dates:
+                        break
+
         # Try Swedish month names
         if not parsed_dates:
             for month_name, month_num in FieldNormalizer.SWEDISH_MONTHS.items():
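
The two-digit-year branch pivots at 50 (00-49 becomes 2000-2049, 50-99 becomes 1950-1999), tries the European day-first reading before the US month-first one, and skips whichever interpretation is not a valid calendar date. A compact sketch of that logic:

    from datetime import datetime

    def parse_two_digit_year(n1: int, n2: int, yy: int) -> list[datetime]:
        # Pivot: 00-49 -> 2000s, 50-99 -> 1900s
        year = 2000 + yy if yy < 50 else 1900 + yy
        orderings = [(n2, n1)] if n1 == n2 else [(n2, n1), (n1, n2)]  # (month, day)
        parsed = []
        for month, day in orderings:
            try:
                parsed.append(datetime(year, month, day))
            except ValueError:
                pass  # e.g. a "month" of 31 is impossible
        return parsed

    print(parse_two_digit_year(2, 8, 25))    # 2025-08-02 (DD.MM.YY) and 2025-02-08 (MM.DD.YY)
    print(parse_two_digit_year(31, 12, 99))  # only 1999-12-31 survives
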
@@ -527,6 +601,7 @@ NORMALIZERS: dict[str, Callable[[str], list[str]]] = {
     'InvoiceDueDate': FieldNormalizer.normalize_date,
     'supplier_organisation_number': FieldNormalizer.normalize_organisation_number,
     'supplier_accounts': FieldNormalizer.normalize_supplier_accounts,
+    'customer_number': FieldNormalizer.normalize_customer_number,
 }
 
 
@@ -60,7 +60,9 @@ class OCREngine:
         self,
         lang: str = "en",
         det_model_dir: str | None = None,
-        rec_model_dir: str | None = None
+        rec_model_dir: str | None = None,
+        use_doc_orientation_classify: bool = True,
+        use_doc_unwarping: bool = False
     ):
         """
         Initialize OCR engine.
@@ -69,6 +71,13 @@ class OCREngine:
             lang: Language code ('en', 'sv', 'ch', etc.)
             det_model_dir: Custom detection model directory
             rec_model_dir: Custom recognition model directory
+            use_doc_orientation_classify: Whether to auto-detect and correct document orientation.
+                Default True to handle rotated documents.
+            use_doc_unwarping: Whether to use UVDoc document unwarping for curved/warped documents.
+                Default False to preserve original image layout,
+                especially important for payment OCR lines at bottom.
+                Enable for severely warped documents at the cost of potentially
+                losing bottom content.
 
         Note:
             PaddleOCR 3.x automatically uses GPU if available via PaddlePaddle.
@@ -82,6 +91,12 @@ class OCREngine:
         # PaddleOCR 3.x init (use_gpu removed, device controlled by paddle.set_device)
         init_params = {
             'lang': lang,
+            # Enable orientation classification to handle rotated documents
+            'use_doc_orientation_classify': use_doc_orientation_classify,
+            # Disable UVDoc unwarping to preserve original image layout
+            # This prevents the bottom payment OCR line from being cut off
+            # For severely warped documents, enable this but expect potential content loss
+            'use_doc_unwarping': use_doc_unwarping,
         }
         if det_model_dir:
             init_params['text_detection_model_dir'] = det_model_dir
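
Callers can now choose both behaviours at construction time instead of editing the engine. A usage sketch based only on the constructor signature shown in this diff; the import path matches the one used elsewhere in the commit:

    from src.ocr import OCREngine

    # Defaults: correct rotated pages, keep the original layout (no UVDoc unwarping)
    engine = OCREngine(lang="en")

    # For severely warped scans, unwarping can be enabled at the cost of possibly
    # losing content near the bottom edge, where the payment OCR line sits
    warped_engine = OCREngine(lang="en", use_doc_unwarping=True)
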
@@ -95,7 +110,9 @@ class OCREngine:
         image: str | Path | np.ndarray,
         page_no: int = 0,
         max_size: int = 2000,
-        scale_to_pdf_points: float | None = None
+        scale_to_pdf_points: float | None = None,
+        scan_bottom_region: bool = True,
+        bottom_region_ratio: float = 0.15
     ) -> list[OCRToken]:
         """
         Extract text tokens from an image.
@@ -108,19 +125,106 @@ class OCREngine:
             scale_to_pdf_points: If provided, scale bbox coordinates by this factor
                 to convert from pixel to PDF point coordinates.
                 Use (72 / dpi) for images rendered at a specific DPI.
+            scan_bottom_region: If True, also scan the bottom region separately to catch
+                OCR payment lines that may be missed in full-page scan.
+            bottom_region_ratio: Ratio of page height to scan as bottom region (default 0.15 = 15%)
 
         Returns:
             List of OCRToken objects with bbox in pixel coords (or PDF points if scale_to_pdf_points is set)
         """
         result = self.extract_with_image(image, page_no, max_size, scale_to_pdf_points)
-        return result.tokens
+        tokens = result.tokens
+
+        # Optionally scan bottom region separately for Swedish OCR payment lines
+        if scan_bottom_region:
+            bottom_tokens = self._scan_bottom_region(
+                image, page_no, max_size, scale_to_pdf_points, bottom_region_ratio
+            )
+            tokens = self._merge_tokens(tokens, bottom_tokens)
+
+        return tokens
+
+    def _scan_bottom_region(
+        self,
+        image: str | Path | np.ndarray,
+        page_no: int,
+        max_size: int,
+        scale_to_pdf_points: float | None,
+        bottom_ratio: float
+    ) -> list[OCRToken]:
+        """Scan the bottom region of the image separately."""
+        from PIL import Image as PILImage
+
+        # Load image if path
+        if isinstance(image, (str, Path)):
+            img = PILImage.open(str(image))
+            img_array = np.array(img)
+        else:
+            img_array = image
+
+        h, w = img_array.shape[:2]
+        crop_y = int(h * (1 - bottom_ratio))
+
+        # Crop bottom region
+        bottom_crop = img_array[crop_y:h, :, :] if len(img_array.shape) == 3 else img_array[crop_y:h, :]
+
+        # OCR the cropped region (without recursive bottom scan to avoid infinite loop)
+        result = self.extract_with_image(
+            bottom_crop, page_no, max_size,
+            scale_to_pdf_points=None,
+            scan_bottom_region=False  # Important: disable to prevent recursion
+        )
+
+        # Adjust bbox y-coordinates to full image space
+        adjusted_tokens = []
+        for token in result.tokens:
+            # Scale factor for coordinates
+            scale = scale_to_pdf_points if scale_to_pdf_points else 1.0
+
+            adjusted_bbox = (
+                token.bbox[0] * scale,
+                (token.bbox[1] + crop_y) * scale,
+                token.bbox[2] * scale,
+                (token.bbox[3] + crop_y) * scale
+            )
+            adjusted_tokens.append(OCRToken(
+                text=token.text,
+                bbox=adjusted_bbox,
+                confidence=token.confidence,
+                page_no=token.page_no
+            ))
+
+        return adjusted_tokens
+
+    def _merge_tokens(
+        self,
+        main_tokens: list[OCRToken],
+        bottom_tokens: list[OCRToken]
+    ) -> list[OCRToken]:
+        """Merge tokens from main scan and bottom region scan, removing duplicates."""
+        if not bottom_tokens:
+            return main_tokens
+
+        # Create a set of existing token texts for deduplication
+        existing_texts = {t.text.strip() for t in main_tokens}
+
+        # Add bottom tokens that aren't duplicates
+        merged = list(main_tokens)
+        for token in bottom_tokens:
+            if token.text.strip() not in existing_texts:
+                merged.append(token)
+                existing_texts.add(token.text.strip())
+
+        return merged
 
     def extract_with_image(
         self,
         image: str | Path | np.ndarray,
         page_no: int = 0,
         max_size: int = 2000,
-        scale_to_pdf_points: float | None = None
+        scale_to_pdf_points: float | None = None,
+        scan_bottom_region: bool = True,
+        bottom_region_ratio: float = 0.15
     ) -> OCRResult:
         """
         Extract text tokens from an image and return the preprocessed image.
@@ -138,6 +242,9 @@ class OCREngine:
             scale_to_pdf_points: If provided, scale bbox coordinates by this factor
                 to convert from pixel to PDF point coordinates.
                 Use (72 / dpi) for images rendered at a specific DPI.
+            scan_bottom_region: If True, also scan the bottom region separately to catch
+                OCR payment lines that may be missed in full-page scan.
+            bottom_region_ratio: Ratio of page height to scan as bottom region (default 0.15 = 15%)
 
         Returns:
             OCRResult with tokens and output_img (preprocessed image from PaddleOCR)
@@ -241,6 +348,13 @@ class OCREngine:
         if output_img is None:
             output_img = img_array
 
+        # Optionally scan bottom region separately for Swedish OCR payment lines
+        if scan_bottom_region:
+            bottom_tokens = self._scan_bottom_region(
+                image, page_no, max_size, scale_to_pdf_points, bottom_region_ratio
+            )
+            tokens = self._merge_tokens(tokens, bottom_tokens)
+
         return OCRResult(tokens=tokens, output_img=output_img)
 
     def extract_from_pdf(
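
_scan_bottom_region crops the lowest bottom_region_ratio of the page, OCRs just that strip, and then shifts the resulting y-coordinates back by the crop offset before merging, so the boxes land in full-page space. A numpy-only sketch of the coordinate bookkeeping with hypothetical numbers:

    import numpy as np

    h, w = 3300, 2550                     # hypothetical page rendered around 300 dpi
    bottom_ratio = 0.15
    crop_y = int(h * (1 - bottom_ratio))  # 2805: top edge of the re-scanned strip

    page = np.zeros((h, w), dtype=np.uint8)
    bottom_crop = page[crop_y:h, :]       # only the bottom 15% is OCRed again

    x0, y0, x1, y1 = 120.0, 10.0, 900.0, 42.0        # bbox found inside the crop
    bbox_on_page = (x0, y0 + crop_y, x1, y1 + crop_y)
    print(bbox_on_page)                              # (120.0, 2815.0, 900.0, 2847.0)
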
@@ -57,6 +57,7 @@ def get_pdf_type(pdf_path: str | Path) -> PDFType:
         return "scanned"
 
     text_pages = 0
+    total_pages = len(doc)
     for page in doc:
         text = page.get_text().strip()
         if len(text) > 30:
@@ -64,7 +65,6 @@ def get_pdf_type(pdf_path: str | Path) -> PDFType:
 
     doc.close()
 
-    total_pages = len(doc)
     if text_pages == total_pages:
         return "text"
     elif text_pages == 0:
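
These two hunks fix an ordering bug: len(doc) was previously read after doc.close(), so total_pages is now captured while the document is still open. A minimal PyMuPDF sketch of the corrected flow; the path is hypothetical and the label for partially-text documents is not shown in this diff:

    import fitz  # PyMuPDF

    doc = fitz.open("invoice.pdf")  # hypothetical path
    total_pages = len(doc)          # must happen before doc.close()
    text_pages = 0
    for page in doc:
        if len(page.get_text().strip()) > 30:
            text_pages += 1
    doc.close()

    if text_pages == total_pages:
        pdf_type = "text"
    elif text_pages == 0:
        pdf_type = "scanned"
    else:
        pdf_type = "mixed"  # placeholder; the real branch is outside this hunk
    print(pdf_type)
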
@@ -85,6 +85,7 @@ def process_text_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
     Returns:
         Result dictionary with success status, annotations, and report.
     """
+    import shutil
     from src.data import AutoLabelReport, FieldMatchResult
     from src.pdf import PDFDocument
     from src.matcher import FieldMatcher
@@ -100,6 +101,11 @@ def process_text_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
     start_time = time.time()
     doc_id = row_dict["DocumentId"]
 
+    # Clean up existing temp folder for this document (for re-matching)
+    temp_doc_dir = output_dir / "temp" / doc_id
+    if temp_doc_dir.exists():
+        shutil.rmtree(temp_doc_dir, ignore_errors=True)
+
     report = AutoLabelReport(document_id=doc_id)
     report.pdf_path = str(pdf_path)
     report.pdf_type = "text"
@@ -218,6 +224,7 @@ def process_scanned_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
     Returns:
         Result dictionary with success status, annotations, and report.
     """
+    import shutil
     from src.data import AutoLabelReport, FieldMatchResult
     from src.pdf import PDFDocument
     from src.matcher import FieldMatcher
@@ -233,6 +240,11 @@ def process_scanned_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
     start_time = time.time()
     doc_id = row_dict["DocumentId"]
 
+    # Clean up existing temp folder for this document (for re-matching)
+    temp_doc_dir = output_dir / "temp" / doc_id
+    if temp_doc_dir.exists():
+        shutil.rmtree(temp_doc_dir, ignore_errors=True)
+
     report = AutoLabelReport(document_id=doc_id)
     report.pdf_path = str(pdf_path)
     report.pdf_type = "scanned"
@@ -21,6 +21,7 @@ FIELD_CLASSES = {
     'Plusgiro': 5,
     'Amount': 6,
     'supplier_organisation_number': 7,
+    'customer_number': 8,
 }
 
 # Fields that need matching but map to other YOLO classes
@@ -41,6 +42,7 @@ CLASS_NAMES = [
     'plusgiro',
     'amount',
     'supplier_org_number',
+    'customer_number',
 ]
 
 
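
The new YOLO class id has to stay aligned with the name list: FIELD_CLASSES['customer_number'] is 8 and 'customer_number' is the ninth entry of CLASS_NAMES. A small consistency check one could keep next to the constants; the import path below is an assumption, not a module shown in this diff:

    from src.yolo.constants import FIELD_CLASSES, CLASS_NAMES  # assumed module path

    assert FIELD_CLASSES['customer_number'] == 8
    assert CLASS_NAMES[FIELD_CLASSES['customer_number']] == 'customer_number'
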