code issue fix

Yaojia Wang
2026-01-17 18:55:46 +01:00
parent 510890d18c
commit e9460e9f34
9 changed files with 729 additions and 57 deletions

View File

@@ -10,6 +10,7 @@ import sys
import time
import os
import signal
+ import shutil
import warnings
from pathlib import Path
from tqdm import tqdm
@@ -107,6 +108,7 @@ def process_single_document(args_tuple):
Returns:
dict with results
"""
+ import shutil
row_dict, pdf_path_str, output_dir_str, dpi, min_confidence, skip_ocr = args_tuple
# Import inside worker to avoid pickling issues
@@ -121,6 +123,11 @@ def process_single_document(args_tuple):
output_dir = Path(output_dir_str)
doc_id = row_dict['DocumentId']
+ # Clean up existing temp folder for this document (for re-matching)
+ temp_doc_dir = output_dir / 'temp' / doc_id
+ if temp_doc_dir.exists():
+ shutil.rmtree(temp_doc_dir, ignore_errors=True)
report = AutoLabelReport(document_id=doc_id)
report.pdf_path = str(pdf_path)
# Store metadata fields from CSV
@@ -602,6 +609,9 @@ def main():
else:
remaining_limit = float('inf')
+ # Collect doc_ids that need retry (for batch delete)
+ retry_doc_ids = []
for row in rows:
# Stop adding tasks if we've reached the limit
if len(tasks) >= remaining_limit:
@@ -622,6 +632,7 @@ def main():
if db_status is False:
stats['retried'] += 1
retry_in_csv += 1
+ retry_doc_ids.append(doc_id)
pdf_path = single_loader.get_pdf_path(row)
if not pdf_path:
@@ -637,12 +648,12 @@ def main():
'Bankgiro': row.Bankgiro,
'Plusgiro': row.Plusgiro,
'Amount': row.Amount,
- # New fields
+ # New fields for matching
'supplier_organisation_number': row.supplier_organisation_number,
'supplier_accounts': row.supplier_accounts,
+ 'customer_number': row.customer_number,
# Metadata fields (not for matching, but for database storage)
'split': row.split,
- 'customer_number': row.customer_number,
'supplier_name': row.supplier_name,
}
@@ -658,6 +669,22 @@ def main():
if skipped_in_csv > 0 or retry_in_csv > 0:
print(f" Skipped {skipped_in_csv} (already in DB), retrying {retry_in_csv} failed")
+ # Clean up retry documents: delete from database and remove temp folders
+ if retry_doc_ids:
+ # Batch delete from database (field_results will be cascade deleted)
+ with db.connect().cursor() as cursor:
+ cursor.execute(
+ "DELETE FROM documents WHERE document_id = ANY(%s)",
+ (retry_doc_ids,)
+ )
+ db.connect().commit()
+ # Remove temp folders
+ for doc_id in retry_doc_ids:
+ temp_doc_dir = output_dir / 'temp' / doc_id
+ if temp_doc_dir.exists():
+ shutil.rmtree(temp_doc_dir, ignore_errors=True)
+ print(f" Cleaned up {len(retry_doc_ids)} retry documents (DB + temp folders)")
if not tasks:
continue

src/cli/reprocess_failed.py (new file, 424 lines)
View File

@@ -0,0 +1,424 @@
#!/usr/bin/env python3
"""
Re-process failed matches and store detailed information including OCR values,
CSV values, and source CSV filename in a new table.
"""
import argparse
import json
import glob
import os
import sys
import time
from pathlib import Path
from datetime import datetime
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
from tqdm import tqdm
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from src.data.db import DocumentDB
from src.data.csv_loader import CSVLoader
from src.normalize.normalizer import normalize_field
def create_failed_match_table(db: DocumentDB):
"""Create the failed_match_details table."""
conn = db.connect()
with conn.cursor() as cursor:
cursor.execute("""
DROP TABLE IF EXISTS failed_match_details;
CREATE TABLE failed_match_details (
id SERIAL PRIMARY KEY,
document_id TEXT NOT NULL,
field_name TEXT NOT NULL,
csv_value TEXT,
csv_value_normalized TEXT,
ocr_value TEXT,
ocr_value_normalized TEXT,
all_ocr_candidates JSONB,
matched BOOLEAN DEFAULT FALSE,
match_score REAL,
pdf_path TEXT,
pdf_type TEXT,
csv_filename TEXT,
page_no INTEGER,
bbox JSONB,
error TEXT,
reprocessed_at TIMESTAMPTZ DEFAULT NOW(),
UNIQUE(document_id, field_name)
);
CREATE INDEX IF NOT EXISTS idx_failed_match_document_id ON failed_match_details(document_id);
CREATE INDEX IF NOT EXISTS idx_failed_match_field_name ON failed_match_details(field_name);
CREATE INDEX IF NOT EXISTS idx_failed_match_csv_filename ON failed_match_details(csv_filename);
CREATE INDEX IF NOT EXISTS idx_failed_match_matched ON failed_match_details(matched);
""")
conn.commit()
print("Created table: failed_match_details")
def get_failed_documents(db: DocumentDB) -> list:
"""Get all documents that have at least one failed field match."""
conn = db.connect()
with conn.cursor() as cursor:
cursor.execute("""
SELECT DISTINCT fr.document_id, d.pdf_path, d.pdf_type
FROM field_results fr
JOIN documents d ON fr.document_id = d.document_id
WHERE fr.matched = false
ORDER BY fr.document_id
""")
return [{'document_id': row[0], 'pdf_path': row[1], 'pdf_type': row[2]}
for row in cursor.fetchall()]
def get_failed_fields_for_document(db: DocumentDB, doc_id: str) -> list:
"""Get all failed field results for a document."""
conn = db.connect()
with conn.cursor() as cursor:
cursor.execute("""
SELECT field_name, csv_value, error
FROM field_results
WHERE document_id = %s AND matched = false
""", (doc_id,))
return [{'field_name': row[0], 'csv_value': row[1], 'error': row[2]}
for row in cursor.fetchall()]
# Cache for CSV data
_csv_cache = {}
def build_csv_cache(csv_files: list):
"""Build a cache of document_id to csv_filename mapping."""
global _csv_cache
_csv_cache = {}
for csv_file in csv_files:
csv_filename = os.path.basename(csv_file)
loader = CSVLoader(csv_file)
for row in loader.iter_rows():
if row.DocumentId not in _csv_cache:
_csv_cache[row.DocumentId] = csv_filename
def find_csv_filename(doc_id: str) -> str:
"""Find which CSV file contains the document ID."""
return _csv_cache.get(doc_id, None)
def init_worker():
"""Initialize worker process."""
import os
import warnings
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["GLOG_minloglevel"] = "2"
warnings.filterwarnings("ignore")
def process_single_document(args):
"""Process a single document and extract OCR values for failed fields."""
doc_info, failed_fields, csv_filename = args
doc_id = doc_info['document_id']
pdf_path = doc_info['pdf_path']
pdf_type = doc_info['pdf_type']
results = []
# Try to extract OCR from PDF
try:
if pdf_path and os.path.exists(pdf_path):
from src.pdf import PDFDocument
from src.ocr import OCREngine
pdf_doc = PDFDocument(pdf_path)
is_scanned = pdf_doc.detect_type() == "scanned"
# Collect all OCR text blocks
all_ocr_texts = []
if is_scanned:
# Use OCR for scanned PDFs
ocr_engine = OCREngine()
for page_no in range(pdf_doc.page_count):
# Render page to image
img = pdf_doc.render_page(page_no, dpi=150)
if img is None:
continue
# OCR the image
ocr_results = ocr_engine.extract_from_image(img)
for block in ocr_results:
all_ocr_texts.append({
'text': block.get('text', ''),
'bbox': block.get('bbox'),
'page_no': page_no
})
else:
# Use text extraction for text PDFs
for page_no in range(pdf_doc.page_count):
tokens = list(pdf_doc.extract_text_tokens(page_no))
for token in tokens:
all_ocr_texts.append({
'text': token.text,
'bbox': token.bbox,
'page_no': page_no
})
# For each failed field, try to find matching OCR
for field in failed_fields:
field_name = field['field_name']
csv_value = field['csv_value']
error = field['error']
# Normalize CSV value
csv_normalized = normalize_field(field_name, csv_value) if csv_value else None
# Try to find best match in OCR
best_score = 0
best_ocr = None
best_bbox = None
best_page = None
for ocr_block in all_ocr_texts:
ocr_text = ocr_block['text']
if not ocr_text:
continue
ocr_normalized = normalize_field(field_name, ocr_text)
# Calculate similarity
if csv_normalized and ocr_normalized:
# Check substring match
if csv_normalized in ocr_normalized:
score = len(csv_normalized) / max(len(ocr_normalized), 1)
if score > best_score:
best_score = score
best_ocr = ocr_text
best_bbox = ocr_block['bbox']
best_page = ocr_block['page_no']
elif ocr_normalized in csv_normalized:
score = len(ocr_normalized) / max(len(csv_normalized), 1)
if score > best_score:
best_score = score
best_ocr = ocr_text
best_bbox = ocr_block['bbox']
best_page = ocr_block['page_no']
# Exact match
elif csv_normalized == ocr_normalized:
best_score = 1.0
best_ocr = ocr_text
best_bbox = ocr_block['bbox']
best_page = ocr_block['page_no']
break
results.append({
'document_id': doc_id,
'field_name': field_name,
'csv_value': csv_value,
'csv_value_normalized': csv_normalized,
'ocr_value': best_ocr,
'ocr_value_normalized': normalize_field(field_name, best_ocr) if best_ocr else None,
'all_ocr_candidates': [t['text'] for t in all_ocr_texts[:100]], # Limit to 100
'matched': best_score > 0.8,
'match_score': best_score,
'pdf_path': pdf_path,
'pdf_type': pdf_type,
'csv_filename': csv_filename,
'page_no': best_page,
'bbox': list(best_bbox) if best_bbox else None,
'error': error
})
else:
# PDF not found
for field in failed_fields:
results.append({
'document_id': doc_id,
'field_name': field['field_name'],
'csv_value': field['csv_value'],
'csv_value_normalized': normalize_field(field['field_name'], field['csv_value']) if field['csv_value'] else None,
'ocr_value': None,
'ocr_value_normalized': None,
'all_ocr_candidates': [],
'matched': False,
'match_score': 0,
'pdf_path': pdf_path,
'pdf_type': pdf_type,
'csv_filename': csv_filename,
'page_no': None,
'bbox': None,
'error': f"PDF not found: {pdf_path}"
})
except Exception as e:
for field in failed_fields:
results.append({
'document_id': doc_id,
'field_name': field['field_name'],
'csv_value': field['csv_value'],
'csv_value_normalized': None,
'ocr_value': None,
'ocr_value_normalized': None,
'all_ocr_candidates': [],
'matched': False,
'match_score': 0,
'pdf_path': pdf_path,
'pdf_type': pdf_type,
'csv_filename': csv_filename,
'page_no': None,
'bbox': None,
'error': str(e)
})
return results
def save_results_batch(db: DocumentDB, results: list):
"""Save results to failed_match_details table."""
if not results:
return
conn = db.connect()
with conn.cursor() as cursor:
for r in results:
cursor.execute("""
INSERT INTO failed_match_details
(document_id, field_name, csv_value, csv_value_normalized,
ocr_value, ocr_value_normalized, all_ocr_candidates,
matched, match_score, pdf_path, pdf_type, csv_filename,
page_no, bbox, error)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
ON CONFLICT (document_id, field_name) DO UPDATE SET
csv_value = EXCLUDED.csv_value,
csv_value_normalized = EXCLUDED.csv_value_normalized,
ocr_value = EXCLUDED.ocr_value,
ocr_value_normalized = EXCLUDED.ocr_value_normalized,
all_ocr_candidates = EXCLUDED.all_ocr_candidates,
matched = EXCLUDED.matched,
match_score = EXCLUDED.match_score,
pdf_path = EXCLUDED.pdf_path,
pdf_type = EXCLUDED.pdf_type,
csv_filename = EXCLUDED.csv_filename,
page_no = EXCLUDED.page_no,
bbox = EXCLUDED.bbox,
error = EXCLUDED.error,
reprocessed_at = NOW()
""", (
r['document_id'],
r['field_name'],
r['csv_value'],
r['csv_value_normalized'],
r['ocr_value'],
r['ocr_value_normalized'],
json.dumps(r['all_ocr_candidates']),
r['matched'],
r['match_score'],
r['pdf_path'],
r['pdf_type'],
r['csv_filename'],
r['page_no'],
json.dumps(r['bbox']) if r['bbox'] else None,
r['error']
))
conn.commit()
def main():
parser = argparse.ArgumentParser(description='Re-process failed matches')
parser.add_argument('--csv', required=True, help='CSV files glob pattern')
parser.add_argument('--pdf-dir', required=True, help='PDF directory')
parser.add_argument('--workers', type=int, default=3, help='Number of workers')
parser.add_argument('--limit', type=int, help='Limit number of documents to process')
args = parser.parse_args()
# Expand CSV glob
csv_files = sorted(glob.glob(args.csv))
print(f"Found {len(csv_files)} CSV files")
# Build CSV cache
print("Building CSV filename cache...")
build_csv_cache(csv_files)
print(f"Cached {len(_csv_cache)} document IDs")
# Connect to database
db = DocumentDB()
db.connect()
# Create new table
create_failed_match_table(db)
# Get all failed documents
print("Fetching failed documents...")
failed_docs = get_failed_documents(db)
print(f"Found {len(failed_docs)} documents with failed matches")
if args.limit:
failed_docs = failed_docs[:args.limit]
print(f"Limited to {len(failed_docs)} documents")
# Prepare tasks
tasks = []
for doc in failed_docs:
failed_fields = get_failed_fields_for_document(db, doc['document_id'])
csv_filename = find_csv_filename(doc['document_id'])
if failed_fields:
tasks.append((doc, failed_fields, csv_filename))
print(f"Processing {len(tasks)} documents with {args.workers} workers...")
# Process with multiprocessing
total_results = 0
batch_results = []
batch_size = 50
with ProcessPoolExecutor(max_workers=args.workers, initializer=init_worker) as executor:
futures = {executor.submit(process_single_document, task): task[0]['document_id']
for task in tasks}
for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"):
doc_id = futures[future]
try:
results = future.result(timeout=120)
batch_results.extend(results)
total_results += len(results)
# Save in batches
if len(batch_results) >= batch_size:
save_results_batch(db, batch_results)
batch_results = []
except TimeoutError:
print(f"\nTimeout processing {doc_id}")
except Exception as e:
print(f"\nError processing {doc_id}: {e}")
# Save remaining results
if batch_results:
save_results_batch(db, batch_results)
print(f"\nDone! Saved {total_results} failed match records to failed_match_details table")
# Show summary
conn = db.connect()
with conn.cursor() as cursor:
cursor.execute("""
SELECT field_name, COUNT(*) as total,
COUNT(*) FILTER (WHERE ocr_value IS NOT NULL) as has_ocr,
COALESCE(AVG(match_score), 0) as avg_score
FROM failed_match_details
GROUP BY field_name
ORDER BY total DESC
""")
print("\nSummary by field:")
print("-" * 70)
print(f"{'Field':<35} {'Total':>8} {'Has OCR':>10} {'Avg Score':>12}")
print("-" * 70)
for row in cursor.fetchall():
print(f"{row[0]:<35} {row[1]:>8} {row[2]:>10} {row[3]:>12.2f}")
db.close()
if __name__ == '__main__':
main()
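Once the script has run, the new failed_match_details table can be inspected directly. A minimal sketch, assuming the same DocumentDB / psycopg-style connection pattern used above and the columns defined in create_failed_match_table; the 'OCR' field name is only an illustrative filter:

from src.data.db import DocumentDB

# Hedged sketch: list the closest near-misses for one field after reprocessing.
db = DocumentDB()
conn = db.connect()
with conn.cursor() as cursor:
    cursor.execute("""
        SELECT document_id, csv_value, ocr_value, match_score, csv_filename
        FROM failed_match_details
        WHERE matched = false AND field_name = %s
        ORDER BY match_score DESC
        LIMIT 20
    """, ('OCR',))
    for row in cursor.fetchall():
        print(row)
db.close()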

View File

@@ -27,7 +27,7 @@ class InvoiceRow:
Amount: Decimal | None = None
# New fields
split: str | None = None # train/test split indicator
- customer_number: str | None = None # Customer number (no matching needed)
+ customer_number: str | None = None # Customer number (needs matching)
supplier_name: str | None = None # Supplier name (no matching)
supplier_organisation_number: str | None = None # Swedish org number (needs matching)
supplier_accounts: str | None = None # Supplier accounts (needs matching)
@@ -198,22 +198,30 @@ class CSVLoader:
value = value.strip()
return value if value else None
+ def _get_field(self, row: dict, *keys: str) -> str | None:
+ """Get field value trying multiple possible column names."""
+ for key in keys:
+ value = row.get(key)
+ if value is not None:
+ return value
+ return None
def _parse_row(self, row: dict) -> InvoiceRow | None:
"""Parse a single CSV row into InvoiceRow."""
- doc_id = self._parse_string(row.get('DocumentId'))
+ doc_id = self._parse_string(self._get_field(row, 'DocumentId', 'document_id'))
if not doc_id:
return None
return InvoiceRow(
DocumentId=doc_id,
- InvoiceDate=self._parse_date(row.get('InvoiceDate')),
- InvoiceNumber=self._parse_string(row.get('InvoiceNumber')),
- InvoiceDueDate=self._parse_date(row.get('InvoiceDueDate')),
- OCR=self._parse_string(row.get('OCR')),
- Message=self._parse_string(row.get('Message')),
- Bankgiro=self._parse_string(row.get('Bankgiro')),
- Plusgiro=self._parse_string(row.get('Plusgiro')),
- Amount=self._parse_amount(row.get('Amount')),
+ InvoiceDate=self._parse_date(self._get_field(row, 'InvoiceDate', 'invoice_date')),
+ InvoiceNumber=self._parse_string(self._get_field(row, 'InvoiceNumber', 'invoice_number')),
+ InvoiceDueDate=self._parse_date(self._get_field(row, 'InvoiceDueDate', 'invoice_due_date')),
+ OCR=self._parse_string(self._get_field(row, 'OCR', 'ocr')),
+ Message=self._parse_string(self._get_field(row, 'Message', 'message')),
+ Bankgiro=self._parse_string(self._get_field(row, 'Bankgiro', 'bankgiro')),
+ Plusgiro=self._parse_string(self._get_field(row, 'Plusgiro', 'plusgiro')),
+ Amount=self._parse_amount(self._get_field(row, 'Amount', 'amount', 'invoice_data_amount')),
# New fields
split=self._parse_string(row.get('split')),
customer_number=self._parse_string(row.get('customer_number')),
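The _get_field helper above lets the loader accept both the original PascalCase headers and alternative lowercase/snake_case headers (e.g. 'invoice_date', 'invoice_data_amount'), presumably coming from differently exported CSVs. A standalone sketch of the same fallback logic, using hypothetical row dicts:

def get_field(row: dict, *keys: str):
    # Return the value of the first candidate column name that is present and not None.
    for key in keys:
        value = row.get(key)
        if value is not None:
            return value
    return None

assert get_field({'Amount': '123.45'}, 'Amount', 'amount', 'invoice_data_amount') == '123.45'
assert get_field({'invoice_data_amount': '99.00'}, 'Amount', 'amount', 'invoice_data_amount') == '99.00'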

View File

@@ -219,7 +219,7 @@ class FieldMatcher:
# Note: Amount is excluded because short numbers like "451" can incorrectly match
# in OCR payment lines or other unrelated text
if field_name in ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro',
- 'supplier_organisation_number', 'supplier_accounts'):
+ 'supplier_organisation_number', 'supplier_accounts', 'customer_number'):
substring_matches = self._find_substring_matches(page_tokens, value, field_name)
matches.extend(substring_matches)
@@ -369,7 +369,7 @@ class FieldMatcher:
# Supported fields for substring matching
supported_fields = ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount',
- 'supplier_organisation_number', 'supplier_accounts')
+ 'supplier_organisation_number', 'supplier_accounts', 'customer_number')
if field_name not in supported_fields:
return matches
@@ -383,10 +383,17 @@ class FieldMatcher:
continue
# Check if value appears as substring (using normalized text)
+ # Try case-sensitive first, then case-insensitive
if value in token_text_normalized:
- # Verify it's a proper boundary match (not part of a larger number)
idx = token_text_normalized.find(value)
+ case_sensitive_match = True
+ elif value.lower() in token_text_normalized.lower():
+ idx = token_text_normalized.lower().find(value.lower())
+ case_sensitive_match = False
+ else:
+ continue
+ # Verify it's a proper boundary match (not part of a larger number)
# Check character before (if exists)
if idx > 0:
char_before = token_text_normalized[idx - 1]
@@ -417,12 +424,15 @@ class FieldMatcher:
# Boost score if keyword is inline
inline_boost = 0.1 if inline_context else 0
+ # Lower score for case-insensitive match
+ base_score = 0.75 if case_sensitive_match else 0.70
matches.append(Match(
field=field_name,
value=value,
bbox=token.bbox, # Use full token bbox
page_no=token.page_no,
- score=min(1.0, 0.75 + context_boost + inline_boost), # Lower than exact match
+ score=min(1.0, base_score + context_boost + inline_boost),
matched_text=token_text,
context_keywords=context_keywords + inline_context
))
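For reference, this change only lowers the starting score for case-insensitive hits; the boosts and the cap are unchanged. A tiny sketch of the arithmetic (the context_boost value is whatever the surrounding matcher code computes; 0.0 and 0.1 below are illustrative):

def substring_score(case_sensitive: bool, context_boost: float, inline_context: bool) -> float:
    # Mirrors the scoring above: lower base for case-insensitive matches, capped at 1.0.
    inline_boost = 0.1 if inline_context else 0.0
    base_score = 0.75 if case_sensitive else 0.70
    return min(1.0, base_score + context_boost + inline_boost)

print(substring_score(True, 0.0, False))   # 0.75
print(substring_score(False, 0.1, True))   # 0.90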

View File

@@ -260,6 +260,45 @@ class FieldNormalizer:
return list(set(v for v in variants if v))
+ @staticmethod
+ def normalize_customer_number(value: str) -> list[str]:
+ """
+ Normalize customer number.
+ Customer numbers can have various formats:
+ - Alphanumeric codes: 'EMM 256-6', 'ABC123', 'A-1234'
+ - Pure numbers: '12345', '123-456'
+ Examples:
+ 'EMM 256-6' -> ['EMM 256-6', 'EMM256-6', 'EMM2566']
+ 'ABC 123' -> ['ABC 123', 'ABC123']
+ """
+ value = FieldNormalizer.clean_text(value)
+ variants = [value]
+ # Version without spaces
+ no_space = value.replace(' ', '')
+ if no_space != value:
+ variants.append(no_space)
+ # Version without dashes
+ no_dash = value.replace('-', '')
+ if no_dash != value:
+ variants.append(no_dash)
+ # Version without spaces and dashes
+ clean = value.replace(' ', '').replace('-', '')
+ if clean != value and clean not in variants:
+ variants.append(clean)
+ # Uppercase and lowercase versions
+ if value.upper() != value:
+ variants.append(value.upper())
+ if value.lower() != value:
+ variants.append(value.lower())
+ return list(set(v for v in variants if v))
@staticmethod
def normalize_amount(value: str) -> list[str]:
"""
@@ -414,7 +453,7 @@ class FieldNormalizer:
]
# Ambiguous patterns - try both DD/MM and MM/DD interpretations
- ambiguous_patterns = [
+ ambiguous_patterns_4digit_year = [
# Format with / - could be DD/MM/YYYY (European) or MM/DD/YYYY (US)
r'^(\d{1,2})/(\d{1,2})/(\d{4})$',
# Format with . - typically European DD.MM.YYYY
@@ -423,6 +462,16 @@ class FieldNormalizer:
r'^(\d{1,2})-(\d{1,2})-(\d{4})$',
]
+ # Patterns with 2-digit year (common in Swedish invoices)
+ ambiguous_patterns_2digit_year = [
+ # Format DD.MM.YY (e.g., 02.08.25 for 2025-08-02)
+ r'^(\d{1,2})\.(\d{1,2})\.(\d{2})$',
+ # Format DD/MM/YY
+ r'^(\d{1,2})/(\d{1,2})/(\d{2})$',
+ # Format DD-MM-YY
+ r'^(\d{1,2})-(\d{1,2})-(\d{2})$',
+ ]
# Try unambiguous patterns first
for pattern, extractor in date_patterns:
match = re.match(pattern, value)
@@ -434,9 +483,9 @@ class FieldNormalizer:
except ValueError:
continue
- # Try ambiguous patterns with both interpretations
+ # Try ambiguous patterns with 4-digit year
if not parsed_dates:
- for pattern in ambiguous_patterns:
+ for pattern in ambiguous_patterns_4digit_year:
match = re.match(pattern, value)
if match:
n1, n2, year = int(match[1]), int(match[2]), int(match[3])
@@ -457,6 +506,31 @@ class FieldNormalizer:
if parsed_dates:
break
+ # Try ambiguous patterns with 2-digit year (e.g., 02.08.25)
+ if not parsed_dates:
+ for pattern in ambiguous_patterns_2digit_year:
+ match = re.match(pattern, value)
+ if match:
+ n1, n2, yy = int(match[1]), int(match[2]), int(match[3])
+ # Convert 2-digit year to 4-digit (00-49 -> 2000s, 50-99 -> 1900s)
+ year = 2000 + yy if yy < 50 else 1900 + yy
+ # Try DD/MM/YY (European - day first, most common in Sweden)
+ try:
+ parsed_dates.append(datetime(year, n2, n1))
+ except ValueError:
+ pass
+ # Try MM/DD/YY (US - month first) if different and valid
+ if n1 != n2:
+ try:
+ parsed_dates.append(datetime(year, n1, n2))
+ except ValueError:
+ pass
+ if parsed_dates:
+ break
# Try Swedish month names
if not parsed_dates:
for month_name, month_num in FieldNormalizer.SWEDISH_MONTHS.items():
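A worked example of the new 2-digit-year handling: '02.08.25' pivots 25 to 2025 (00-49 map to the 2000s, 50-99 to the 1900s) and, because day and month differ, yields both interpretations, with the day-first reading tried first:

from datetime import datetime

# Illustrative walk-through of the branch above for the value '02.08.25'.
n1, n2, yy = 2, 8, 25
year = 2000 + yy if yy < 50 else 1900 + yy       # 2025
day_first = datetime(year, n2, n1)               # 2025-08-02 (Swedish/European reading)
month_first = datetime(year, n1, n2)             # 2025-02-08 (US-style fallback, only tried because n1 != n2)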
@@ -527,6 +601,7 @@ NORMALIZERS: dict[str, Callable[[str], list[str]]] = {
'InvoiceDueDate': FieldNormalizer.normalize_date,
'supplier_organisation_number': FieldNormalizer.normalize_organisation_number,
'supplier_accounts': FieldNormalizer.normalize_supplier_accounts,
+ 'customer_number': FieldNormalizer.normalize_customer_number,
}

View File

@@ -60,7 +60,9 @@ class OCREngine:
self,
lang: str = "en",
det_model_dir: str | None = None,
- rec_model_dir: str | None = None
+ rec_model_dir: str | None = None,
+ use_doc_orientation_classify: bool = True,
+ use_doc_unwarping: bool = False
):
"""
Initialize OCR engine.
@@ -69,6 +71,13 @@ class OCREngine:
lang: Language code ('en', 'sv', 'ch', etc.)
det_model_dir: Custom detection model directory
rec_model_dir: Custom recognition model directory
+ use_doc_orientation_classify: Whether to auto-detect and correct document orientation.
+ Default True to handle rotated documents.
+ use_doc_unwarping: Whether to use UVDoc document unwarping for curved/warped documents.
+ Default False to preserve original image layout,
+ especially important for payment OCR lines at bottom.
+ Enable for severely warped documents at the cost of potentially
+ losing bottom content.
Note:
PaddleOCR 3.x automatically uses GPU if available via PaddlePaddle.
@@ -82,6 +91,12 @@ class OCREngine:
# PaddleOCR 3.x init (use_gpu removed, device controlled by paddle.set_device)
init_params = {
'lang': lang,
+ # Enable orientation classification to handle rotated documents
+ 'use_doc_orientation_classify': use_doc_orientation_classify,
+ # Disable UVDoc unwarping to preserve original image layout
+ # This prevents the bottom payment OCR line from being cut off
+ # For severely warped documents, enable this but expect potential content loss
+ 'use_doc_unwarping': use_doc_unwarping,
}
if det_model_dir:
init_params['text_detection_model_dir'] = det_model_dir
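A minimal construction sketch using the new parameters as they are forwarded above (whether PaddleOCR's orientation and unwarping modules behave exactly as described is assumed from the docstring, not verified here):

from src.ocr import OCREngine

# Hedged sketch: defaults shown explicitly for clarity.
engine = OCREngine(
    lang="en",
    use_doc_orientation_classify=True,   # auto-correct rotated pages
    use_doc_unwarping=False,             # keep the bottom payment OCR line intact
)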
@@ -95,7 +110,9 @@ class OCREngine:
image: str | Path | np.ndarray,
page_no: int = 0,
max_size: int = 2000,
- scale_to_pdf_points: float | None = None
+ scale_to_pdf_points: float | None = None,
+ scan_bottom_region: bool = True,
+ bottom_region_ratio: float = 0.15
) -> list[OCRToken]:
"""
Extract text tokens from an image.
@@ -108,19 +125,106 @@ class OCREngine:
scale_to_pdf_points: If provided, scale bbox coordinates by this factor
to convert from pixel to PDF point coordinates.
Use (72 / dpi) for images rendered at a specific DPI.
+ scan_bottom_region: If True, also scan the bottom region separately to catch
+ OCR payment lines that may be missed in full-page scan.
+ bottom_region_ratio: Ratio of page height to scan as bottom region (default 0.15 = 15%)
Returns:
List of OCRToken objects with bbox in pixel coords (or PDF points if scale_to_pdf_points is set)
"""
result = self.extract_with_image(image, page_no, max_size, scale_to_pdf_points)
- return result.tokens
+ tokens = result.tokens
+ # Optionally scan bottom region separately for Swedish OCR payment lines
+ if scan_bottom_region:
+ bottom_tokens = self._scan_bottom_region(
+ image, page_no, max_size, scale_to_pdf_points, bottom_region_ratio
+ )
+ tokens = self._merge_tokens(tokens, bottom_tokens)
+ return tokens
+ def _scan_bottom_region(
+ self,
+ image: str | Path | np.ndarray,
+ page_no: int,
+ max_size: int,
+ scale_to_pdf_points: float | None,
+ bottom_ratio: float
+ ) -> list[OCRToken]:
+ """Scan the bottom region of the image separately."""
+ from PIL import Image as PILImage
+ # Load image if path
+ if isinstance(image, (str, Path)):
+ img = PILImage.open(str(image))
+ img_array = np.array(img)
+ else:
+ img_array = image
+ h, w = img_array.shape[:2]
+ crop_y = int(h * (1 - bottom_ratio))
+ # Crop bottom region
+ bottom_crop = img_array[crop_y:h, :, :] if len(img_array.shape) == 3 else img_array[crop_y:h, :]
+ # OCR the cropped region (without recursive bottom scan to avoid infinite loop)
+ result = self.extract_with_image(
+ bottom_crop, page_no, max_size,
+ scale_to_pdf_points=None,
+ scan_bottom_region=False # Important: disable to prevent recursion
+ )
+ # Adjust bbox y-coordinates to full image space
+ adjusted_tokens = []
+ for token in result.tokens:
+ # Scale factor for coordinates
+ scale = scale_to_pdf_points if scale_to_pdf_points else 1.0
+ adjusted_bbox = (
+ token.bbox[0] * scale,
+ (token.bbox[1] + crop_y) * scale,
+ token.bbox[2] * scale,
+ (token.bbox[3] + crop_y) * scale
+ )
+ adjusted_tokens.append(OCRToken(
+ text=token.text,
+ bbox=adjusted_bbox,
+ confidence=token.confidence,
+ page_no=token.page_no
+ ))
+ return adjusted_tokens
+ def _merge_tokens(
+ self,
+ main_tokens: list[OCRToken],
+ bottom_tokens: list[OCRToken]
+ ) -> list[OCRToken]:
+ """Merge tokens from main scan and bottom region scan, removing duplicates."""
+ if not bottom_tokens:
+ return main_tokens
+ # Create a set of existing token texts for deduplication
+ existing_texts = {t.text.strip() for t in main_tokens}
+ # Add bottom tokens that aren't duplicates
+ merged = list(main_tokens)
+ for token in bottom_tokens:
+ if token.text.strip() not in existing_texts:
+ merged.append(token)
+ existing_texts.add(token.text.strip())
+ return merged
def extract_with_image(
self,
image: str | Path | np.ndarray,
page_no: int = 0,
max_size: int = 2000,
- scale_to_pdf_points: float | None = None
+ scale_to_pdf_points: float | None = None,
+ scan_bottom_region: bool = True,
+ bottom_region_ratio: float = 0.15
) -> OCRResult:
"""
Extract text tokens from an image and return the preprocessed image.
@@ -138,6 +242,9 @@ class OCREngine:
scale_to_pdf_points: If provided, scale bbox coordinates by this factor
to convert from pixel to PDF point coordinates.
Use (72 / dpi) for images rendered at a specific DPI.
+ scan_bottom_region: If True, also scan the bottom region separately to catch
+ OCR payment lines that may be missed in full-page scan.
+ bottom_region_ratio: Ratio of page height to scan as bottom region (default 0.15 = 15%)
Returns:
OCRResult with tokens and output_img (preprocessed image from PaddleOCR)
@@ -241,6 +348,13 @@ class OCREngine:
if output_img is None:
output_img = img_array
+ # Optionally scan bottom region separately for Swedish OCR payment lines
+ if scan_bottom_region:
+ bottom_tokens = self._scan_bottom_region(
+ image, page_no, max_size, scale_to_pdf_points, bottom_region_ratio
+ )
+ tokens = self._merge_tokens(tokens, bottom_tokens)
return OCRResult(tokens=tokens, output_img=output_img)
def extract_from_pdf(
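Continuing from an engine constructed as in the earlier sketch, a call site combining the new parameters might look like the following (the image path and 150 dpi are illustrative assumptions):

# Hedged usage sketch: full-page pass plus the extra bottom-region pass,
# with bounding boxes scaled from 150 dpi pixels to PDF points.
tokens = engine.extract_text_tokens(
    "page_0.png",
    page_no=0,
    scale_to_pdf_points=72 / 150,
    scan_bottom_region=True,     # re-scan the bottom of the page for the payment OCR line
    bottom_region_ratio=0.15,    # bottom 15% of the page height
)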

View File

@@ -57,6 +57,7 @@ def get_pdf_type(pdf_path: str | Path) -> PDFType:
return "scanned" return "scanned"
text_pages = 0 text_pages = 0
total_pages = len(doc)
for page in doc: for page in doc:
text = page.get_text().strip() text = page.get_text().strip()
if len(text) > 30: if len(text) > 30:
@@ -64,7 +65,6 @@ def get_pdf_type(pdf_path: str | Path) -> PDFType:
doc.close()
- total_pages = len(doc)
if text_pages == total_pages:
return "text"
elif text_pages == 0:

View File

@@ -85,6 +85,7 @@ def process_text_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
Returns:
Result dictionary with success status, annotations, and report.
"""
+ import shutil
from src.data import AutoLabelReport, FieldMatchResult
from src.pdf import PDFDocument
from src.matcher import FieldMatcher
@@ -100,6 +101,11 @@ def process_text_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
start_time = time.time()
doc_id = row_dict["DocumentId"]
+ # Clean up existing temp folder for this document (for re-matching)
+ temp_doc_dir = output_dir / "temp" / doc_id
+ if temp_doc_dir.exists():
+ shutil.rmtree(temp_doc_dir, ignore_errors=True)
report = AutoLabelReport(document_id=doc_id)
report.pdf_path = str(pdf_path)
report.pdf_type = "text"
@@ -218,6 +224,7 @@ def process_scanned_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
Returns:
Result dictionary with success status, annotations, and report.
"""
+ import shutil
from src.data import AutoLabelReport, FieldMatchResult
from src.pdf import PDFDocument
from src.matcher import FieldMatcher
@@ -233,6 +240,11 @@ def process_scanned_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
start_time = time.time()
doc_id = row_dict["DocumentId"]
+ # Clean up existing temp folder for this document (for re-matching)
+ temp_doc_dir = output_dir / "temp" / doc_id
+ if temp_doc_dir.exists():
+ shutil.rmtree(temp_doc_dir, ignore_errors=True)
report = AutoLabelReport(document_id=doc_id)
report.pdf_path = str(pdf_path)
report.pdf_type = "scanned"

View File

@@ -21,6 +21,7 @@ FIELD_CLASSES = {
'Plusgiro': 5,
'Amount': 6,
'supplier_organisation_number': 7,
+ 'customer_number': 8,
}
# Fields that need matching but map to other YOLO classes
@@ -41,6 +42,7 @@ CLASS_NAMES = [
'plusgiro',
'amount',
'supplier_org_number',
+ 'customer_number',
]