code issue fix
This commit is contained in:
@@ -10,6 +10,7 @@ import sys
|
||||
import time
|
||||
import os
|
||||
import signal
|
||||
import shutil
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
@@ -107,6 +108,7 @@ def process_single_document(args_tuple):
|
||||
Returns:
|
||||
dict with results
|
||||
"""
|
||||
import shutil
|
||||
row_dict, pdf_path_str, output_dir_str, dpi, min_confidence, skip_ocr = args_tuple
|
||||
|
||||
# Import inside worker to avoid pickling issues
|
||||
@@ -121,6 +123,11 @@ def process_single_document(args_tuple):
|
||||
output_dir = Path(output_dir_str)
|
||||
doc_id = row_dict['DocumentId']
|
||||
|
||||
# Clean up existing temp folder for this document (for re-matching)
|
||||
temp_doc_dir = output_dir / 'temp' / doc_id
|
||||
if temp_doc_dir.exists():
|
||||
shutil.rmtree(temp_doc_dir, ignore_errors=True)
|
||||
|
||||
report = AutoLabelReport(document_id=doc_id)
|
||||
report.pdf_path = str(pdf_path)
|
||||
# Store metadata fields from CSV
|
||||
@@ -602,6 +609,9 @@ def main():
|
||||
else:
|
||||
remaining_limit = float('inf')
|
||||
|
||||
# Collect doc_ids that need retry (for batch delete)
|
||||
retry_doc_ids = []
|
||||
|
||||
for row in rows:
|
||||
# Stop adding tasks if we've reached the limit
|
||||
if len(tasks) >= remaining_limit:
|
||||
@@ -622,6 +632,7 @@ def main():
|
||||
if db_status is False:
|
||||
stats['retried'] += 1
|
||||
retry_in_csv += 1
|
||||
retry_doc_ids.append(doc_id)
|
||||
|
||||
pdf_path = single_loader.get_pdf_path(row)
|
||||
if not pdf_path:
|
||||
@@ -637,12 +648,12 @@ def main():
|
||||
'Bankgiro': row.Bankgiro,
|
||||
'Plusgiro': row.Plusgiro,
|
||||
'Amount': row.Amount,
|
||||
# New fields
|
||||
# New fields for matching
|
||||
'supplier_organisation_number': row.supplier_organisation_number,
|
||||
'supplier_accounts': row.supplier_accounts,
|
||||
'customer_number': row.customer_number,
|
||||
# Metadata fields (not for matching, but for database storage)
|
||||
'split': row.split,
|
||||
'customer_number': row.customer_number,
|
||||
'supplier_name': row.supplier_name,
|
||||
}
|
||||
|
||||
@@ -658,6 +669,22 @@ def main():
|
||||
if skipped_in_csv > 0 or retry_in_csv > 0:
|
||||
print(f" Skipped {skipped_in_csv} (already in DB), retrying {retry_in_csv} failed")
|
||||
|
||||
# Clean up retry documents: delete from database and remove temp folders
|
||||
if retry_doc_ids:
|
||||
# Batch delete from database (field_results will be cascade deleted)
|
||||
with db.connect().cursor() as cursor:
|
||||
cursor.execute(
|
||||
"DELETE FROM documents WHERE document_id = ANY(%s)",
|
||||
(retry_doc_ids,)
|
||||
)
|
||||
db.connect().commit()
|
||||
# Remove temp folders
|
||||
for doc_id in retry_doc_ids:
|
||||
temp_doc_dir = output_dir / 'temp' / doc_id
|
||||
if temp_doc_dir.exists():
|
||||
shutil.rmtree(temp_doc_dir, ignore_errors=True)
|
||||
print(f" Cleaned up {len(retry_doc_ids)} retry documents (DB + temp folders)")
|
||||
|
||||
if not tasks:
|
||||
continue
|
||||
|
||||
|
||||
424
src/cli/reprocess_failed.py
Normal file
424
src/cli/reprocess_failed.py
Normal file
@@ -0,0 +1,424 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Re-process failed matches and store detailed information including OCR values,
|
||||
CSV values, and source CSV filename in a new table.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import glob
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
|
||||
from tqdm import tqdm
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from src.data.db import DocumentDB
|
||||
from src.data.csv_loader import CSVLoader
|
||||
from src.normalize.normalizer import normalize_field
|
||||
|
||||
|
||||
def create_failed_match_table(db: DocumentDB):
|
||||
"""Create the failed_match_details table."""
|
||||
conn = db.connect()
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute("""
|
||||
DROP TABLE IF EXISTS failed_match_details;
|
||||
|
||||
CREATE TABLE failed_match_details (
|
||||
id SERIAL PRIMARY KEY,
|
||||
document_id TEXT NOT NULL,
|
||||
field_name TEXT NOT NULL,
|
||||
csv_value TEXT,
|
||||
csv_value_normalized TEXT,
|
||||
ocr_value TEXT,
|
||||
ocr_value_normalized TEXT,
|
||||
all_ocr_candidates JSONB,
|
||||
matched BOOLEAN DEFAULT FALSE,
|
||||
match_score REAL,
|
||||
pdf_path TEXT,
|
||||
pdf_type TEXT,
|
||||
csv_filename TEXT,
|
||||
page_no INTEGER,
|
||||
bbox JSONB,
|
||||
error TEXT,
|
||||
reprocessed_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
|
||||
UNIQUE(document_id, field_name)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_failed_match_document_id ON failed_match_details(document_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_failed_match_field_name ON failed_match_details(field_name);
|
||||
CREATE INDEX IF NOT EXISTS idx_failed_match_csv_filename ON failed_match_details(csv_filename);
|
||||
CREATE INDEX IF NOT EXISTS idx_failed_match_matched ON failed_match_details(matched);
|
||||
""")
|
||||
conn.commit()
|
||||
print("Created table: failed_match_details")
|
||||
|
||||
|
||||
def get_failed_documents(db: DocumentDB) -> list:
|
||||
"""Get all documents that have at least one failed field match."""
|
||||
conn = db.connect()
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute("""
|
||||
SELECT DISTINCT fr.document_id, d.pdf_path, d.pdf_type
|
||||
FROM field_results fr
|
||||
JOIN documents d ON fr.document_id = d.document_id
|
||||
WHERE fr.matched = false
|
||||
ORDER BY fr.document_id
|
||||
""")
|
||||
return [{'document_id': row[0], 'pdf_path': row[1], 'pdf_type': row[2]}
|
||||
for row in cursor.fetchall()]
|
||||
|
||||
|
||||
def get_failed_fields_for_document(db: DocumentDB, doc_id: str) -> list:
|
||||
"""Get all failed field results for a document."""
|
||||
conn = db.connect()
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute("""
|
||||
SELECT field_name, csv_value, error
|
||||
FROM field_results
|
||||
WHERE document_id = %s AND matched = false
|
||||
""", (doc_id,))
|
||||
return [{'field_name': row[0], 'csv_value': row[1], 'error': row[2]}
|
||||
for row in cursor.fetchall()]
|
||||
|
||||
|
||||
# Cache for CSV data
|
||||
_csv_cache = {}
|
||||
|
||||
def build_csv_cache(csv_files: list):
|
||||
"""Build a cache of document_id to csv_filename mapping."""
|
||||
global _csv_cache
|
||||
_csv_cache = {}
|
||||
|
||||
for csv_file in csv_files:
|
||||
csv_filename = os.path.basename(csv_file)
|
||||
loader = CSVLoader(csv_file)
|
||||
for row in loader.iter_rows():
|
||||
if row.DocumentId not in _csv_cache:
|
||||
_csv_cache[row.DocumentId] = csv_filename
|
||||
|
||||
|
||||
def find_csv_filename(doc_id: str) -> str:
|
||||
"""Find which CSV file contains the document ID."""
|
||||
return _csv_cache.get(doc_id, None)
|
||||
|
||||
|
||||
def init_worker():
|
||||
"""Initialize worker process."""
|
||||
import os
|
||||
import warnings
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
os.environ["GLOG_minloglevel"] = "2"
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
def process_single_document(args):
|
||||
"""Process a single document and extract OCR values for failed fields."""
|
||||
doc_info, failed_fields, csv_filename = args
|
||||
doc_id = doc_info['document_id']
|
||||
pdf_path = doc_info['pdf_path']
|
||||
pdf_type = doc_info['pdf_type']
|
||||
|
||||
results = []
|
||||
|
||||
# Try to extract OCR from PDF
|
||||
try:
|
||||
if pdf_path and os.path.exists(pdf_path):
|
||||
from src.pdf import PDFDocument
|
||||
from src.ocr import OCREngine
|
||||
|
||||
pdf_doc = PDFDocument(pdf_path)
|
||||
is_scanned = pdf_doc.detect_type() == "scanned"
|
||||
|
||||
# Collect all OCR text blocks
|
||||
all_ocr_texts = []
|
||||
|
||||
if is_scanned:
|
||||
# Use OCR for scanned PDFs
|
||||
ocr_engine = OCREngine()
|
||||
for page_no in range(pdf_doc.page_count):
|
||||
# Render page to image
|
||||
img = pdf_doc.render_page(page_no, dpi=150)
|
||||
if img is None:
|
||||
continue
|
||||
|
||||
# OCR the image
|
||||
ocr_results = ocr_engine.extract_from_image(img)
|
||||
for block in ocr_results:
|
||||
all_ocr_texts.append({
|
||||
'text': block.get('text', ''),
|
||||
'bbox': block.get('bbox'),
|
||||
'page_no': page_no
|
||||
})
|
||||
else:
|
||||
# Use text extraction for text PDFs
|
||||
for page_no in range(pdf_doc.page_count):
|
||||
tokens = list(pdf_doc.extract_text_tokens(page_no))
|
||||
for token in tokens:
|
||||
all_ocr_texts.append({
|
||||
'text': token.text,
|
||||
'bbox': token.bbox,
|
||||
'page_no': page_no
|
||||
})
|
||||
|
||||
# For each failed field, try to find matching OCR
|
||||
for field in failed_fields:
|
||||
field_name = field['field_name']
|
||||
csv_value = field['csv_value']
|
||||
error = field['error']
|
||||
|
||||
# Normalize CSV value
|
||||
csv_normalized = normalize_field(field_name, csv_value) if csv_value else None
|
||||
|
||||
# Try to find best match in OCR
|
||||
best_score = 0
|
||||
best_ocr = None
|
||||
best_bbox = None
|
||||
best_page = None
|
||||
|
||||
for ocr_block in all_ocr_texts:
|
||||
ocr_text = ocr_block['text']
|
||||
if not ocr_text:
|
||||
continue
|
||||
ocr_normalized = normalize_field(field_name, ocr_text)
|
||||
|
||||
# Calculate similarity
|
||||
if csv_normalized and ocr_normalized:
|
||||
# Check substring match
|
||||
if csv_normalized in ocr_normalized:
|
||||
score = len(csv_normalized) / max(len(ocr_normalized), 1)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_ocr = ocr_text
|
||||
best_bbox = ocr_block['bbox']
|
||||
best_page = ocr_block['page_no']
|
||||
elif ocr_normalized in csv_normalized:
|
||||
score = len(ocr_normalized) / max(len(csv_normalized), 1)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_ocr = ocr_text
|
||||
best_bbox = ocr_block['bbox']
|
||||
best_page = ocr_block['page_no']
|
||||
# Exact match
|
||||
elif csv_normalized == ocr_normalized:
|
||||
best_score = 1.0
|
||||
best_ocr = ocr_text
|
||||
best_bbox = ocr_block['bbox']
|
||||
best_page = ocr_block['page_no']
|
||||
break
|
||||
|
||||
results.append({
|
||||
'document_id': doc_id,
|
||||
'field_name': field_name,
|
||||
'csv_value': csv_value,
|
||||
'csv_value_normalized': csv_normalized,
|
||||
'ocr_value': best_ocr,
|
||||
'ocr_value_normalized': normalize_field(field_name, best_ocr) if best_ocr else None,
|
||||
'all_ocr_candidates': [t['text'] for t in all_ocr_texts[:100]], # Limit to 100
|
||||
'matched': best_score > 0.8,
|
||||
'match_score': best_score,
|
||||
'pdf_path': pdf_path,
|
||||
'pdf_type': pdf_type,
|
||||
'csv_filename': csv_filename,
|
||||
'page_no': best_page,
|
||||
'bbox': list(best_bbox) if best_bbox else None,
|
||||
'error': error
|
||||
})
|
||||
else:
|
||||
# PDF not found
|
||||
for field in failed_fields:
|
||||
results.append({
|
||||
'document_id': doc_id,
|
||||
'field_name': field['field_name'],
|
||||
'csv_value': field['csv_value'],
|
||||
'csv_value_normalized': normalize_field(field['field_name'], field['csv_value']) if field['csv_value'] else None,
|
||||
'ocr_value': None,
|
||||
'ocr_value_normalized': None,
|
||||
'all_ocr_candidates': [],
|
||||
'matched': False,
|
||||
'match_score': 0,
|
||||
'pdf_path': pdf_path,
|
||||
'pdf_type': pdf_type,
|
||||
'csv_filename': csv_filename,
|
||||
'page_no': None,
|
||||
'bbox': None,
|
||||
'error': f"PDF not found: {pdf_path}"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
for field in failed_fields:
|
||||
results.append({
|
||||
'document_id': doc_id,
|
||||
'field_name': field['field_name'],
|
||||
'csv_value': field['csv_value'],
|
||||
'csv_value_normalized': None,
|
||||
'ocr_value': None,
|
||||
'ocr_value_normalized': None,
|
||||
'all_ocr_candidates': [],
|
||||
'matched': False,
|
||||
'match_score': 0,
|
||||
'pdf_path': pdf_path,
|
||||
'pdf_type': pdf_type,
|
||||
'csv_filename': csv_filename,
|
||||
'page_no': None,
|
||||
'bbox': None,
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def save_results_batch(db: DocumentDB, results: list):
|
||||
"""Save results to failed_match_details table."""
|
||||
if not results:
|
||||
return
|
||||
|
||||
conn = db.connect()
|
||||
with conn.cursor() as cursor:
|
||||
for r in results:
|
||||
cursor.execute("""
|
||||
INSERT INTO failed_match_details
|
||||
(document_id, field_name, csv_value, csv_value_normalized,
|
||||
ocr_value, ocr_value_normalized, all_ocr_candidates,
|
||||
matched, match_score, pdf_path, pdf_type, csv_filename,
|
||||
page_no, bbox, error)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
ON CONFLICT (document_id, field_name) DO UPDATE SET
|
||||
csv_value = EXCLUDED.csv_value,
|
||||
csv_value_normalized = EXCLUDED.csv_value_normalized,
|
||||
ocr_value = EXCLUDED.ocr_value,
|
||||
ocr_value_normalized = EXCLUDED.ocr_value_normalized,
|
||||
all_ocr_candidates = EXCLUDED.all_ocr_candidates,
|
||||
matched = EXCLUDED.matched,
|
||||
match_score = EXCLUDED.match_score,
|
||||
pdf_path = EXCLUDED.pdf_path,
|
||||
pdf_type = EXCLUDED.pdf_type,
|
||||
csv_filename = EXCLUDED.csv_filename,
|
||||
page_no = EXCLUDED.page_no,
|
||||
bbox = EXCLUDED.bbox,
|
||||
error = EXCLUDED.error,
|
||||
reprocessed_at = NOW()
|
||||
""", (
|
||||
r['document_id'],
|
||||
r['field_name'],
|
||||
r['csv_value'],
|
||||
r['csv_value_normalized'],
|
||||
r['ocr_value'],
|
||||
r['ocr_value_normalized'],
|
||||
json.dumps(r['all_ocr_candidates']),
|
||||
r['matched'],
|
||||
r['match_score'],
|
||||
r['pdf_path'],
|
||||
r['pdf_type'],
|
||||
r['csv_filename'],
|
||||
r['page_no'],
|
||||
json.dumps(r['bbox']) if r['bbox'] else None,
|
||||
r['error']
|
||||
))
|
||||
conn.commit()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Re-process failed matches')
|
||||
parser.add_argument('--csv', required=True, help='CSV files glob pattern')
|
||||
parser.add_argument('--pdf-dir', required=True, help='PDF directory')
|
||||
parser.add_argument('--workers', type=int, default=3, help='Number of workers')
|
||||
parser.add_argument('--limit', type=int, help='Limit number of documents to process')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Expand CSV glob
|
||||
csv_files = sorted(glob.glob(args.csv))
|
||||
print(f"Found {len(csv_files)} CSV files")
|
||||
|
||||
# Build CSV cache
|
||||
print("Building CSV filename cache...")
|
||||
build_csv_cache(csv_files)
|
||||
print(f"Cached {len(_csv_cache)} document IDs")
|
||||
|
||||
# Connect to database
|
||||
db = DocumentDB()
|
||||
db.connect()
|
||||
|
||||
# Create new table
|
||||
create_failed_match_table(db)
|
||||
|
||||
# Get all failed documents
|
||||
print("Fetching failed documents...")
|
||||
failed_docs = get_failed_documents(db)
|
||||
print(f"Found {len(failed_docs)} documents with failed matches")
|
||||
|
||||
if args.limit:
|
||||
failed_docs = failed_docs[:args.limit]
|
||||
print(f"Limited to {len(failed_docs)} documents")
|
||||
|
||||
# Prepare tasks
|
||||
tasks = []
|
||||
for doc in failed_docs:
|
||||
failed_fields = get_failed_fields_for_document(db, doc['document_id'])
|
||||
csv_filename = find_csv_filename(doc['document_id'])
|
||||
if failed_fields:
|
||||
tasks.append((doc, failed_fields, csv_filename))
|
||||
|
||||
print(f"Processing {len(tasks)} documents with {args.workers} workers...")
|
||||
|
||||
# Process with multiprocessing
|
||||
total_results = 0
|
||||
batch_results = []
|
||||
batch_size = 50
|
||||
|
||||
with ProcessPoolExecutor(max_workers=args.workers, initializer=init_worker) as executor:
|
||||
futures = {executor.submit(process_single_document, task): task[0]['document_id']
|
||||
for task in tasks}
|
||||
|
||||
for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"):
|
||||
doc_id = futures[future]
|
||||
try:
|
||||
results = future.result(timeout=120)
|
||||
batch_results.extend(results)
|
||||
total_results += len(results)
|
||||
|
||||
# Save in batches
|
||||
if len(batch_results) >= batch_size:
|
||||
save_results_batch(db, batch_results)
|
||||
batch_results = []
|
||||
|
||||
except TimeoutError:
|
||||
print(f"\nTimeout processing {doc_id}")
|
||||
except Exception as e:
|
||||
print(f"\nError processing {doc_id}: {e}")
|
||||
|
||||
# Save remaining results
|
||||
if batch_results:
|
||||
save_results_batch(db, batch_results)
|
||||
|
||||
print(f"\nDone! Saved {total_results} failed match records to failed_match_details table")
|
||||
|
||||
# Show summary
|
||||
conn = db.connect()
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute("""
|
||||
SELECT field_name, COUNT(*) as total,
|
||||
COUNT(*) FILTER (WHERE ocr_value IS NOT NULL) as has_ocr,
|
||||
COALESCE(AVG(match_score), 0) as avg_score
|
||||
FROM failed_match_details
|
||||
GROUP BY field_name
|
||||
ORDER BY total DESC
|
||||
""")
|
||||
print("\nSummary by field:")
|
||||
print("-" * 70)
|
||||
print(f"{'Field':<35} {'Total':>8} {'Has OCR':>10} {'Avg Score':>12}")
|
||||
print("-" * 70)
|
||||
for row in cursor.fetchall():
|
||||
print(f"{row[0]:<35} {row[1]:>8} {row[2]:>10} {row[3]:>12.2f}")
|
||||
|
||||
db.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -27,7 +27,7 @@ class InvoiceRow:
|
||||
Amount: Decimal | None = None
|
||||
# New fields
|
||||
split: str | None = None # train/test split indicator
|
||||
customer_number: str | None = None # Customer number (no matching needed)
|
||||
customer_number: str | None = None # Customer number (needs matching)
|
||||
supplier_name: str | None = None # Supplier name (no matching)
|
||||
supplier_organisation_number: str | None = None # Swedish org number (needs matching)
|
||||
supplier_accounts: str | None = None # Supplier accounts (needs matching)
|
||||
@@ -198,22 +198,30 @@ class CSVLoader:
|
||||
value = value.strip()
|
||||
return value if value else None
|
||||
|
||||
def _get_field(self, row: dict, *keys: str) -> str | None:
|
||||
"""Get field value trying multiple possible column names."""
|
||||
for key in keys:
|
||||
value = row.get(key)
|
||||
if value is not None:
|
||||
return value
|
||||
return None
|
||||
|
||||
def _parse_row(self, row: dict) -> InvoiceRow | None:
|
||||
"""Parse a single CSV row into InvoiceRow."""
|
||||
doc_id = self._parse_string(row.get('DocumentId'))
|
||||
doc_id = self._parse_string(self._get_field(row, 'DocumentId', 'document_id'))
|
||||
if not doc_id:
|
||||
return None
|
||||
|
||||
return InvoiceRow(
|
||||
DocumentId=doc_id,
|
||||
InvoiceDate=self._parse_date(row.get('InvoiceDate')),
|
||||
InvoiceNumber=self._parse_string(row.get('InvoiceNumber')),
|
||||
InvoiceDueDate=self._parse_date(row.get('InvoiceDueDate')),
|
||||
OCR=self._parse_string(row.get('OCR')),
|
||||
Message=self._parse_string(row.get('Message')),
|
||||
Bankgiro=self._parse_string(row.get('Bankgiro')),
|
||||
Plusgiro=self._parse_string(row.get('Plusgiro')),
|
||||
Amount=self._parse_amount(row.get('Amount')),
|
||||
InvoiceDate=self._parse_date(self._get_field(row, 'InvoiceDate', 'invoice_date')),
|
||||
InvoiceNumber=self._parse_string(self._get_field(row, 'InvoiceNumber', 'invoice_number')),
|
||||
InvoiceDueDate=self._parse_date(self._get_field(row, 'InvoiceDueDate', 'invoice_due_date')),
|
||||
OCR=self._parse_string(self._get_field(row, 'OCR', 'ocr')),
|
||||
Message=self._parse_string(self._get_field(row, 'Message', 'message')),
|
||||
Bankgiro=self._parse_string(self._get_field(row, 'Bankgiro', 'bankgiro')),
|
||||
Plusgiro=self._parse_string(self._get_field(row, 'Plusgiro', 'plusgiro')),
|
||||
Amount=self._parse_amount(self._get_field(row, 'Amount', 'amount', 'invoice_data_amount')),
|
||||
# New fields
|
||||
split=self._parse_string(row.get('split')),
|
||||
customer_number=self._parse_string(row.get('customer_number')),
|
||||
|
||||
@@ -219,7 +219,7 @@ class FieldMatcher:
|
||||
# Note: Amount is excluded because short numbers like "451" can incorrectly match
|
||||
# in OCR payment lines or other unrelated text
|
||||
if field_name in ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro',
|
||||
'supplier_organisation_number', 'supplier_accounts'):
|
||||
'supplier_organisation_number', 'supplier_accounts', 'customer_number'):
|
||||
substring_matches = self._find_substring_matches(page_tokens, value, field_name)
|
||||
matches.extend(substring_matches)
|
||||
|
||||
@@ -369,7 +369,7 @@ class FieldMatcher:
|
||||
|
||||
# Supported fields for substring matching
|
||||
supported_fields = ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount',
|
||||
'supplier_organisation_number', 'supplier_accounts')
|
||||
'supplier_organisation_number', 'supplier_accounts', 'customer_number')
|
||||
if field_name not in supported_fields:
|
||||
return matches
|
||||
|
||||
@@ -383,10 +383,17 @@ class FieldMatcher:
|
||||
continue
|
||||
|
||||
# Check if value appears as substring (using normalized text)
|
||||
# Try case-sensitive first, then case-insensitive
|
||||
if value in token_text_normalized:
|
||||
# Verify it's a proper boundary match (not part of a larger number)
|
||||
idx = token_text_normalized.find(value)
|
||||
case_sensitive_match = True
|
||||
elif value.lower() in token_text_normalized.lower():
|
||||
idx = token_text_normalized.lower().find(value.lower())
|
||||
case_sensitive_match = False
|
||||
else:
|
||||
continue
|
||||
|
||||
# Verify it's a proper boundary match (not part of a larger number)
|
||||
# Check character before (if exists)
|
||||
if idx > 0:
|
||||
char_before = token_text_normalized[idx - 1]
|
||||
@@ -417,12 +424,15 @@ class FieldMatcher:
|
||||
# Boost score if keyword is inline
|
||||
inline_boost = 0.1 if inline_context else 0
|
||||
|
||||
# Lower score for case-insensitive match
|
||||
base_score = 0.75 if case_sensitive_match else 0.70
|
||||
|
||||
matches.append(Match(
|
||||
field=field_name,
|
||||
value=value,
|
||||
bbox=token.bbox, # Use full token bbox
|
||||
page_no=token.page_no,
|
||||
score=min(1.0, 0.75 + context_boost + inline_boost), # Lower than exact match
|
||||
score=min(1.0, base_score + context_boost + inline_boost),
|
||||
matched_text=token_text,
|
||||
context_keywords=context_keywords + inline_context
|
||||
))
|
||||
|
||||
@@ -260,6 +260,45 @@ class FieldNormalizer:
|
||||
|
||||
return list(set(v for v in variants if v))
|
||||
|
||||
@staticmethod
|
||||
def normalize_customer_number(value: str) -> list[str]:
|
||||
"""
|
||||
Normalize customer number.
|
||||
|
||||
Customer numbers can have various formats:
|
||||
- Alphanumeric codes: 'EMM 256-6', 'ABC123', 'A-1234'
|
||||
- Pure numbers: '12345', '123-456'
|
||||
|
||||
Examples:
|
||||
'EMM 256-6' -> ['EMM 256-6', 'EMM256-6', 'EMM2566']
|
||||
'ABC 123' -> ['ABC 123', 'ABC123']
|
||||
"""
|
||||
value = FieldNormalizer.clean_text(value)
|
||||
variants = [value]
|
||||
|
||||
# Version without spaces
|
||||
no_space = value.replace(' ', '')
|
||||
if no_space != value:
|
||||
variants.append(no_space)
|
||||
|
||||
# Version without dashes
|
||||
no_dash = value.replace('-', '')
|
||||
if no_dash != value:
|
||||
variants.append(no_dash)
|
||||
|
||||
# Version without spaces and dashes
|
||||
clean = value.replace(' ', '').replace('-', '')
|
||||
if clean != value and clean not in variants:
|
||||
variants.append(clean)
|
||||
|
||||
# Uppercase and lowercase versions
|
||||
if value.upper() != value:
|
||||
variants.append(value.upper())
|
||||
if value.lower() != value:
|
||||
variants.append(value.lower())
|
||||
|
||||
return list(set(v for v in variants if v))
|
||||
|
||||
@staticmethod
|
||||
def normalize_amount(value: str) -> list[str]:
|
||||
"""
|
||||
@@ -414,7 +453,7 @@ class FieldNormalizer:
|
||||
]
|
||||
|
||||
# Ambiguous patterns - try both DD/MM and MM/DD interpretations
|
||||
ambiguous_patterns = [
|
||||
ambiguous_patterns_4digit_year = [
|
||||
# Format with / - could be DD/MM/YYYY (European) or MM/DD/YYYY (US)
|
||||
r'^(\d{1,2})/(\d{1,2})/(\d{4})$',
|
||||
# Format with . - typically European DD.MM.YYYY
|
||||
@@ -423,6 +462,16 @@ class FieldNormalizer:
|
||||
r'^(\d{1,2})-(\d{1,2})-(\d{4})$',
|
||||
]
|
||||
|
||||
# Patterns with 2-digit year (common in Swedish invoices)
|
||||
ambiguous_patterns_2digit_year = [
|
||||
# Format DD.MM.YY (e.g., 02.08.25 for 2025-08-02)
|
||||
r'^(\d{1,2})\.(\d{1,2})\.(\d{2})$',
|
||||
# Format DD/MM/YY
|
||||
r'^(\d{1,2})/(\d{1,2})/(\d{2})$',
|
||||
# Format DD-MM-YY
|
||||
r'^(\d{1,2})-(\d{1,2})-(\d{2})$',
|
||||
]
|
||||
|
||||
# Try unambiguous patterns first
|
||||
for pattern, extractor in date_patterns:
|
||||
match = re.match(pattern, value)
|
||||
@@ -434,9 +483,9 @@ class FieldNormalizer:
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# Try ambiguous patterns with both interpretations
|
||||
# Try ambiguous patterns with 4-digit year
|
||||
if not parsed_dates:
|
||||
for pattern in ambiguous_patterns:
|
||||
for pattern in ambiguous_patterns_4digit_year:
|
||||
match = re.match(pattern, value)
|
||||
if match:
|
||||
n1, n2, year = int(match[1]), int(match[2]), int(match[3])
|
||||
@@ -457,6 +506,31 @@ class FieldNormalizer:
|
||||
if parsed_dates:
|
||||
break
|
||||
|
||||
# Try ambiguous patterns with 2-digit year (e.g., 02.08.25)
|
||||
if not parsed_dates:
|
||||
for pattern in ambiguous_patterns_2digit_year:
|
||||
match = re.match(pattern, value)
|
||||
if match:
|
||||
n1, n2, yy = int(match[1]), int(match[2]), int(match[3])
|
||||
# Convert 2-digit year to 4-digit (00-49 -> 2000s, 50-99 -> 1900s)
|
||||
year = 2000 + yy if yy < 50 else 1900 + yy
|
||||
|
||||
# Try DD/MM/YY (European - day first, most common in Sweden)
|
||||
try:
|
||||
parsed_dates.append(datetime(year, n2, n1))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Try MM/DD/YY (US - month first) if different and valid
|
||||
if n1 != n2:
|
||||
try:
|
||||
parsed_dates.append(datetime(year, n1, n2))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if parsed_dates:
|
||||
break
|
||||
|
||||
# Try Swedish month names
|
||||
if not parsed_dates:
|
||||
for month_name, month_num in FieldNormalizer.SWEDISH_MONTHS.items():
|
||||
@@ -527,6 +601,7 @@ NORMALIZERS: dict[str, Callable[[str], list[str]]] = {
|
||||
'InvoiceDueDate': FieldNormalizer.normalize_date,
|
||||
'supplier_organisation_number': FieldNormalizer.normalize_organisation_number,
|
||||
'supplier_accounts': FieldNormalizer.normalize_supplier_accounts,
|
||||
'customer_number': FieldNormalizer.normalize_customer_number,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -60,7 +60,9 @@ class OCREngine:
|
||||
self,
|
||||
lang: str = "en",
|
||||
det_model_dir: str | None = None,
|
||||
rec_model_dir: str | None = None
|
||||
rec_model_dir: str | None = None,
|
||||
use_doc_orientation_classify: bool = True,
|
||||
use_doc_unwarping: bool = False
|
||||
):
|
||||
"""
|
||||
Initialize OCR engine.
|
||||
@@ -69,6 +71,13 @@ class OCREngine:
|
||||
lang: Language code ('en', 'sv', 'ch', etc.)
|
||||
det_model_dir: Custom detection model directory
|
||||
rec_model_dir: Custom recognition model directory
|
||||
use_doc_orientation_classify: Whether to auto-detect and correct document orientation.
|
||||
Default True to handle rotated documents.
|
||||
use_doc_unwarping: Whether to use UVDoc document unwarping for curved/warped documents.
|
||||
Default False to preserve original image layout,
|
||||
especially important for payment OCR lines at bottom.
|
||||
Enable for severely warped documents at the cost of potentially
|
||||
losing bottom content.
|
||||
|
||||
Note:
|
||||
PaddleOCR 3.x automatically uses GPU if available via PaddlePaddle.
|
||||
@@ -82,6 +91,12 @@ class OCREngine:
|
||||
# PaddleOCR 3.x init (use_gpu removed, device controlled by paddle.set_device)
|
||||
init_params = {
|
||||
'lang': lang,
|
||||
# Enable orientation classification to handle rotated documents
|
||||
'use_doc_orientation_classify': use_doc_orientation_classify,
|
||||
# Disable UVDoc unwarping to preserve original image layout
|
||||
# This prevents the bottom payment OCR line from being cut off
|
||||
# For severely warped documents, enable this but expect potential content loss
|
||||
'use_doc_unwarping': use_doc_unwarping,
|
||||
}
|
||||
if det_model_dir:
|
||||
init_params['text_detection_model_dir'] = det_model_dir
|
||||
@@ -95,7 +110,9 @@ class OCREngine:
|
||||
image: str | Path | np.ndarray,
|
||||
page_no: int = 0,
|
||||
max_size: int = 2000,
|
||||
scale_to_pdf_points: float | None = None
|
||||
scale_to_pdf_points: float | None = None,
|
||||
scan_bottom_region: bool = True,
|
||||
bottom_region_ratio: float = 0.15
|
||||
) -> list[OCRToken]:
|
||||
"""
|
||||
Extract text tokens from an image.
|
||||
@@ -108,19 +125,106 @@ class OCREngine:
|
||||
scale_to_pdf_points: If provided, scale bbox coordinates by this factor
|
||||
to convert from pixel to PDF point coordinates.
|
||||
Use (72 / dpi) for images rendered at a specific DPI.
|
||||
scan_bottom_region: If True, also scan the bottom region separately to catch
|
||||
OCR payment lines that may be missed in full-page scan.
|
||||
bottom_region_ratio: Ratio of page height to scan as bottom region (default 0.15 = 15%)
|
||||
|
||||
Returns:
|
||||
List of OCRToken objects with bbox in pixel coords (or PDF points if scale_to_pdf_points is set)
|
||||
"""
|
||||
result = self.extract_with_image(image, page_no, max_size, scale_to_pdf_points)
|
||||
return result.tokens
|
||||
tokens = result.tokens
|
||||
|
||||
# Optionally scan bottom region separately for Swedish OCR payment lines
|
||||
if scan_bottom_region:
|
||||
bottom_tokens = self._scan_bottom_region(
|
||||
image, page_no, max_size, scale_to_pdf_points, bottom_region_ratio
|
||||
)
|
||||
tokens = self._merge_tokens(tokens, bottom_tokens)
|
||||
|
||||
return tokens
|
||||
|
||||
def _scan_bottom_region(
|
||||
self,
|
||||
image: str | Path | np.ndarray,
|
||||
page_no: int,
|
||||
max_size: int,
|
||||
scale_to_pdf_points: float | None,
|
||||
bottom_ratio: float
|
||||
) -> list[OCRToken]:
|
||||
"""Scan the bottom region of the image separately."""
|
||||
from PIL import Image as PILImage
|
||||
|
||||
# Load image if path
|
||||
if isinstance(image, (str, Path)):
|
||||
img = PILImage.open(str(image))
|
||||
img_array = np.array(img)
|
||||
else:
|
||||
img_array = image
|
||||
|
||||
h, w = img_array.shape[:2]
|
||||
crop_y = int(h * (1 - bottom_ratio))
|
||||
|
||||
# Crop bottom region
|
||||
bottom_crop = img_array[crop_y:h, :, :] if len(img_array.shape) == 3 else img_array[crop_y:h, :]
|
||||
|
||||
# OCR the cropped region (without recursive bottom scan to avoid infinite loop)
|
||||
result = self.extract_with_image(
|
||||
bottom_crop, page_no, max_size,
|
||||
scale_to_pdf_points=None,
|
||||
scan_bottom_region=False # Important: disable to prevent recursion
|
||||
)
|
||||
|
||||
# Adjust bbox y-coordinates to full image space
|
||||
adjusted_tokens = []
|
||||
for token in result.tokens:
|
||||
# Scale factor for coordinates
|
||||
scale = scale_to_pdf_points if scale_to_pdf_points else 1.0
|
||||
|
||||
adjusted_bbox = (
|
||||
token.bbox[0] * scale,
|
||||
(token.bbox[1] + crop_y) * scale,
|
||||
token.bbox[2] * scale,
|
||||
(token.bbox[3] + crop_y) * scale
|
||||
)
|
||||
adjusted_tokens.append(OCRToken(
|
||||
text=token.text,
|
||||
bbox=adjusted_bbox,
|
||||
confidence=token.confidence,
|
||||
page_no=token.page_no
|
||||
))
|
||||
|
||||
return adjusted_tokens
|
||||
|
||||
def _merge_tokens(
|
||||
self,
|
||||
main_tokens: list[OCRToken],
|
||||
bottom_tokens: list[OCRToken]
|
||||
) -> list[OCRToken]:
|
||||
"""Merge tokens from main scan and bottom region scan, removing duplicates."""
|
||||
if not bottom_tokens:
|
||||
return main_tokens
|
||||
|
||||
# Create a set of existing token texts for deduplication
|
||||
existing_texts = {t.text.strip() for t in main_tokens}
|
||||
|
||||
# Add bottom tokens that aren't duplicates
|
||||
merged = list(main_tokens)
|
||||
for token in bottom_tokens:
|
||||
if token.text.strip() not in existing_texts:
|
||||
merged.append(token)
|
||||
existing_texts.add(token.text.strip())
|
||||
|
||||
return merged
|
||||
|
||||
def extract_with_image(
|
||||
self,
|
||||
image: str | Path | np.ndarray,
|
||||
page_no: int = 0,
|
||||
max_size: int = 2000,
|
||||
scale_to_pdf_points: float | None = None
|
||||
scale_to_pdf_points: float | None = None,
|
||||
scan_bottom_region: bool = True,
|
||||
bottom_region_ratio: float = 0.15
|
||||
) -> OCRResult:
|
||||
"""
|
||||
Extract text tokens from an image and return the preprocessed image.
|
||||
@@ -138,6 +242,9 @@ class OCREngine:
|
||||
scale_to_pdf_points: If provided, scale bbox coordinates by this factor
|
||||
to convert from pixel to PDF point coordinates.
|
||||
Use (72 / dpi) for images rendered at a specific DPI.
|
||||
scan_bottom_region: If True, also scan the bottom region separately to catch
|
||||
OCR payment lines that may be missed in full-page scan.
|
||||
bottom_region_ratio: Ratio of page height to scan as bottom region (default 0.15 = 15%)
|
||||
|
||||
Returns:
|
||||
OCRResult with tokens and output_img (preprocessed image from PaddleOCR)
|
||||
@@ -241,6 +348,13 @@ class OCREngine:
|
||||
if output_img is None:
|
||||
output_img = img_array
|
||||
|
||||
# Optionally scan bottom region separately for Swedish OCR payment lines
|
||||
if scan_bottom_region:
|
||||
bottom_tokens = self._scan_bottom_region(
|
||||
image, page_no, max_size, scale_to_pdf_points, bottom_region_ratio
|
||||
)
|
||||
tokens = self._merge_tokens(tokens, bottom_tokens)
|
||||
|
||||
return OCRResult(tokens=tokens, output_img=output_img)
|
||||
|
||||
def extract_from_pdf(
|
||||
|
||||
@@ -57,6 +57,7 @@ def get_pdf_type(pdf_path: str | Path) -> PDFType:
|
||||
return "scanned"
|
||||
|
||||
text_pages = 0
|
||||
total_pages = len(doc)
|
||||
for page in doc:
|
||||
text = page.get_text().strip()
|
||||
if len(text) > 30:
|
||||
@@ -64,7 +65,6 @@ def get_pdf_type(pdf_path: str | Path) -> PDFType:
|
||||
|
||||
doc.close()
|
||||
|
||||
total_pages = len(doc)
|
||||
if text_pages == total_pages:
|
||||
return "text"
|
||||
elif text_pages == 0:
|
||||
|
||||
@@ -85,6 +85,7 @@ def process_text_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
Returns:
|
||||
Result dictionary with success status, annotations, and report.
|
||||
"""
|
||||
import shutil
|
||||
from src.data import AutoLabelReport, FieldMatchResult
|
||||
from src.pdf import PDFDocument
|
||||
from src.matcher import FieldMatcher
|
||||
@@ -100,6 +101,11 @@ def process_text_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
start_time = time.time()
|
||||
doc_id = row_dict["DocumentId"]
|
||||
|
||||
# Clean up existing temp folder for this document (for re-matching)
|
||||
temp_doc_dir = output_dir / "temp" / doc_id
|
||||
if temp_doc_dir.exists():
|
||||
shutil.rmtree(temp_doc_dir, ignore_errors=True)
|
||||
|
||||
report = AutoLabelReport(document_id=doc_id)
|
||||
report.pdf_path = str(pdf_path)
|
||||
report.pdf_type = "text"
|
||||
@@ -218,6 +224,7 @@ def process_scanned_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
Returns:
|
||||
Result dictionary with success status, annotations, and report.
|
||||
"""
|
||||
import shutil
|
||||
from src.data import AutoLabelReport, FieldMatchResult
|
||||
from src.pdf import PDFDocument
|
||||
from src.matcher import FieldMatcher
|
||||
@@ -233,6 +240,11 @@ def process_scanned_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
start_time = time.time()
|
||||
doc_id = row_dict["DocumentId"]
|
||||
|
||||
# Clean up existing temp folder for this document (for re-matching)
|
||||
temp_doc_dir = output_dir / "temp" / doc_id
|
||||
if temp_doc_dir.exists():
|
||||
shutil.rmtree(temp_doc_dir, ignore_errors=True)
|
||||
|
||||
report = AutoLabelReport(document_id=doc_id)
|
||||
report.pdf_path = str(pdf_path)
|
||||
report.pdf_type = "scanned"
|
||||
|
||||
@@ -21,6 +21,7 @@ FIELD_CLASSES = {
|
||||
'Plusgiro': 5,
|
||||
'Amount': 6,
|
||||
'supplier_organisation_number': 7,
|
||||
'customer_number': 8,
|
||||
}
|
||||
|
||||
# Fields that need matching but map to other YOLO classes
|
||||
@@ -41,6 +42,7 @@ CLASS_NAMES = [
|
||||
'plusgiro',
|
||||
'amount',
|
||||
'supplier_org_number',
|
||||
'customer_number',
|
||||
]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user