diff --git a/src/cli/autolabel.py b/src/cli/autolabel.py index 7c240cf..3675d0d 100644 --- a/src/cli/autolabel.py +++ b/src/cli/autolabel.py @@ -10,6 +10,7 @@ import sys import time import os import signal +import shutil import warnings from pathlib import Path from tqdm import tqdm @@ -107,6 +108,7 @@ def process_single_document(args_tuple): Returns: dict with results """ + import shutil row_dict, pdf_path_str, output_dir_str, dpi, min_confidence, skip_ocr = args_tuple # Import inside worker to avoid pickling issues @@ -121,6 +123,11 @@ def process_single_document(args_tuple): output_dir = Path(output_dir_str) doc_id = row_dict['DocumentId'] + # Clean up existing temp folder for this document (for re-matching) + temp_doc_dir = output_dir / 'temp' / doc_id + if temp_doc_dir.exists(): + shutil.rmtree(temp_doc_dir, ignore_errors=True) + report = AutoLabelReport(document_id=doc_id) report.pdf_path = str(pdf_path) # Store metadata fields from CSV @@ -602,6 +609,9 @@ def main(): else: remaining_limit = float('inf') + # Collect doc_ids that need retry (for batch delete) + retry_doc_ids = [] + for row in rows: # Stop adding tasks if we've reached the limit if len(tasks) >= remaining_limit: @@ -622,6 +632,7 @@ def main(): if db_status is False: stats['retried'] += 1 retry_in_csv += 1 + retry_doc_ids.append(doc_id) pdf_path = single_loader.get_pdf_path(row) if not pdf_path: @@ -637,12 +648,12 @@ def main(): 'Bankgiro': row.Bankgiro, 'Plusgiro': row.Plusgiro, 'Amount': row.Amount, - # New fields + # New fields for matching 'supplier_organisation_number': row.supplier_organisation_number, 'supplier_accounts': row.supplier_accounts, + 'customer_number': row.customer_number, # Metadata fields (not for matching, but for database storage) 'split': row.split, - 'customer_number': row.customer_number, 'supplier_name': row.supplier_name, } @@ -658,6 +669,22 @@ def main(): if skipped_in_csv > 0 or retry_in_csv > 0: print(f" Skipped {skipped_in_csv} (already in DB), retrying 
def create_failed_match_table(db: "DocumentDB"):
    """Drop and recreate the ``failed_match_details`` table and its indexes.

    NOTE: this is destructive. The statement starts with ``DROP TABLE IF
    EXISTS``, so rows written by an earlier run are discarded.

    Args:
        db: Database wrapper; only ``db.connect()`` is used here.
    """
    conn = db.connect()
    with conn.cursor() as cursor:
        cursor.execute("""
            DROP TABLE IF EXISTS failed_match_details;

            CREATE TABLE failed_match_details (
                id SERIAL PRIMARY KEY,
                document_id TEXT NOT NULL,
                field_name TEXT NOT NULL,
                csv_value TEXT,
                csv_value_normalized TEXT,
                ocr_value TEXT,
                ocr_value_normalized TEXT,
                all_ocr_candidates JSONB,
                matched BOOLEAN DEFAULT FALSE,
                match_score REAL,
                pdf_path TEXT,
                pdf_type TEXT,
                csv_filename TEXT,
                page_no INTEGER,
                bbox JSONB,
                error TEXT,
                reprocessed_at TIMESTAMPTZ DEFAULT NOW(),

                UNIQUE(document_id, field_name)
            );

            CREATE INDEX IF NOT EXISTS idx_failed_match_document_id ON failed_match_details(document_id);
            CREATE INDEX IF NOT EXISTS idx_failed_match_field_name ON failed_match_details(field_name);
            CREATE INDEX IF NOT EXISTS idx_failed_match_csv_filename ON failed_match_details(csv_filename);
            CREATE INDEX IF NOT EXISTS idx_failed_match_matched ON failed_match_details(matched);
        """)
    conn.commit()
    print("Created table: failed_match_details")


def get_failed_documents(db: "DocumentDB") -> list:
    """Return every document that has at least one unmatched field.

    Args:
        db: Database wrapper; only ``db.connect()`` is used here.

    Returns:
        A list of dicts with keys ``document_id``, ``pdf_path`` and
        ``pdf_type``, ordered by document id.
    """
    conn = db.connect()
    with conn.cursor() as cursor:
        cursor.execute("""
            SELECT DISTINCT fr.document_id, d.pdf_path, d.pdf_type
            FROM field_results fr
            JOIN documents d ON fr.document_id = d.document_id
            WHERE fr.matched = false
            ORDER BY fr.document_id
        """)
        return [{'document_id': row[0], 'pdf_path': row[1], 'pdf_type': row[2]}
                for row in cursor.fetchall()]


def get_failed_fields_for_document(db: "DocumentDB", doc_id: str) -> list:
    """Return the unmatched field results stored for one document.

    Args:
        db: Database wrapper; only ``db.connect()`` is used here.
        doc_id: Document id to look up.

    Returns:
        A list of dicts with keys ``field_name``, ``csv_value`` and ``error``.
    """
    conn = db.connect()
    with conn.cursor() as cursor:
        cursor.execute("""
            SELECT field_name, csv_value, error
            FROM field_results
            WHERE document_id = %s AND matched = false
        """, (doc_id,))
        return [{'field_name': row[0], 'csv_value': row[1], 'error': row[2]}
                for row in cursor.fetchall()]


# Maps document id -> basename of the first CSV file that lists it.
_csv_cache: dict = {}


def build_csv_cache(csv_files: list):
    """Rebuild the module-level document-id -> CSV-filename cache.

    When a document id occurs in several files, the first file in
    ``csv_files`` wins.

    Args:
        csv_files: Paths of the CSV files to index.
    """
    global _csv_cache
    cache = {}
    for csv_file in csv_files:
        csv_filename = os.path.basename(csv_file)
        for row in CSVLoader(csv_file).iter_rows():
            # setdefault keeps the first filename seen for a document id.
            cache.setdefault(row.DocumentId, csv_filename)
    _csv_cache = cache


def find_csv_filename(doc_id: str) -> "str | None":
    """Return the CSV filename cached for ``doc_id``, or ``None`` if unknown."""
    return _csv_cache.get(doc_id)


def init_worker():
    """Initialize a worker process: pin the GPU and silence noisy logging."""
    import os
    import warnings
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    os.environ["GLOG_minloglevel"] = "2"
    warnings.filterwarnings("ignore")


def _result_row(doc_info, field, csv_filename, *, csv_normalized=None,
                ocr_value=None, ocr_normalized=None, candidates=None,
                score=0, page_no=None, bbox=None, error=None):
    """Build one ``failed_match_details`` row as a plain dict."""
    return {
        'document_id': doc_info['document_id'],
        'field_name': field['field_name'],
        'csv_value': field['csv_value'],
        'csv_value_normalized': csv_normalized,
        'ocr_value': ocr_value,
        'ocr_value_normalized': ocr_normalized,
        'all_ocr_candidates': list(candidates) if candidates else [],
        'matched': score > 0.8,
        'match_score': score,
        'pdf_path': doc_info['pdf_path'],
        'pdf_type': doc_info['pdf_type'],
        'csv_filename': csv_filename,
        'page_no': page_no,
        'bbox': list(bbox) if bbox else None,
        'error': error,
    }


def _collect_text_blocks(pdf_path):
    """Return every text block of the PDF as ``{'text', 'bbox', 'page_no'}`` dicts."""
    from src.pdf import PDFDocument
    from src.ocr import OCREngine

    pdf_doc = PDFDocument(pdf_path)
    blocks = []

    if pdf_doc.detect_type() == "scanned":
        # Scanned PDFs have no text layer: render each page and OCR it.
        ocr_engine = OCREngine()
        for page_no in range(pdf_doc.page_count):
            img = pdf_doc.render_page(page_no, dpi=150)
            if img is None:
                continue
            # NOTE(review): blocks are read with .get(), i.e. assumed to be
            # dict-like - confirm against OCREngine.extract_from_image.
            for block in ocr_engine.extract_from_image(img):
                blocks.append({
                    'text': block.get('text', ''),
                    'bbox': block.get('bbox'),
                    'page_no': page_no,
                })
    else:
        # Text PDFs: read the embedded tokens directly.
        for page_no in range(pdf_doc.page_count):
            for token in pdf_doc.extract_text_tokens(page_no):
                blocks.append({
                    'text': token.text,
                    'bbox': token.bbox,
                    'page_no': page_no,
                })

    return blocks


def _containment_score(csv_normalized, ocr_normalized):
    """Score a one-way containment as ``len(shorter) / len(longer)``, else 0."""
    if csv_normalized in ocr_normalized:
        return len(csv_normalized) / max(len(ocr_normalized), 1)
    if ocr_normalized in csv_normalized:
        return len(ocr_normalized) / max(len(csv_normalized), 1)
    return 0


def _match_field(doc_info, field, csv_filename, text_blocks, candidates):
    """Find the best-scoring text block for one failed field and build its row."""
    field_name = field['field_name']
    csv_value = field['csv_value']
    csv_normalized = normalize_field(field_name, csv_value) if csv_value else None

    best_score = 0
    best_block = None
    best_normalized = None

    # Without a normalized CSV value nothing can match, so skip the scan.
    if csv_normalized:
        for block in text_blocks:
            ocr_text = block['text']
            if not ocr_text:
                continue
            ocr_normalized = normalize_field(field_name, ocr_text)
            if not ocr_normalized:
                continue

            # Equality must be tested before containment: an equal string is
            # also a substring, which previously made this early exit dead.
            if csv_normalized == ocr_normalized:
                best_score, best_block, best_normalized = 1.0, block, ocr_normalized
                break

            score = _containment_score(csv_normalized, ocr_normalized)
            if score > best_score:
                best_score, best_block, best_normalized = score, block, ocr_normalized

    return _result_row(
        doc_info, field, csv_filename,
        csv_normalized=csv_normalized,
        ocr_value=best_block['text'] if best_block else None,
        ocr_normalized=best_normalized,
        candidates=candidates,
        score=best_score,
        page_no=best_block['page_no'] if best_block else None,
        bbox=best_block['bbox'] if best_block else None,
        error=field['error'],
    )


def process_single_document(args):
    """Re-read one document and look for its failed CSV values in the text.

    Args:
        args: Tuple ``(doc_info, failed_fields, csv_filename)``. ``doc_info``
            has keys ``document_id``, ``pdf_path`` and ``pdf_type``; each
            entry of ``failed_fields`` has ``field_name``, ``csv_value`` and
            ``error``.

    Returns:
        Exactly one result dict per failed field. If the PDF is missing, or
        if anything raises while reading or matching it, every field gets an
        unmatched row carrying the error text. Partial results are never
        mixed with error rows, so ``(document_id, field_name)`` stays unique.
    """
    doc_info, failed_fields, csv_filename = args
    pdf_path = doc_info['pdf_path']

    try:
        if not (pdf_path and os.path.exists(pdf_path)):
            return [
                _result_row(
                    doc_info, field, csv_filename,
                    csv_normalized=(
                        normalize_field(field['field_name'], field['csv_value'])
                        if field['csv_value'] else None
                    ),
                    error=f"PDF not found: {pdf_path}",
                )
                for field in failed_fields
            ]

        text_blocks = _collect_text_blocks(pdf_path)
        # Only the first 100 block texts are stored, to bound row size.
        candidates = [block['text'] for block in text_blocks[:100]]
        return [
            _match_field(doc_info, field, csv_filename, text_blocks, candidates)
            for field in failed_fields
        ]

    except Exception as e:
        # Deliberate best effort: a broken PDF becomes error rows rather than
        # aborting the whole run.
        return [
            _result_row(doc_info, field, csv_filename, error=str(e))
            for field in failed_fields
        ]


def save_results_batch(db: "DocumentDB", results: list):
    """Upsert result rows into ``failed_match_details`` in one transaction.

    Args:
        db: Database wrapper; only ``db.connect()`` is used here.
        results: Row dicts as produced by ``process_single_document``.

    Raises:
        Exception: Whatever the driver raises. The transaction is rolled back
            first so the connection remains usable for later batches.
    """
    if not results:
        return

    upsert_sql = """
        INSERT INTO failed_match_details
        (document_id, field_name, csv_value, csv_value_normalized,
         ocr_value, ocr_value_normalized, all_ocr_candidates,
         matched, match_score, pdf_path, pdf_type, csv_filename,
         page_no, bbox, error)
        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        ON CONFLICT (document_id, field_name) DO UPDATE SET
            csv_value = EXCLUDED.csv_value,
            csv_value_normalized = EXCLUDED.csv_value_normalized,
            ocr_value = EXCLUDED.ocr_value,
            ocr_value_normalized = EXCLUDED.ocr_value_normalized,
            all_ocr_candidates = EXCLUDED.all_ocr_candidates,
            matched = EXCLUDED.matched,
            match_score = EXCLUDED.match_score,
            pdf_path = EXCLUDED.pdf_path,
            pdf_type = EXCLUDED.pdf_type,
            csv_filename = EXCLUDED.csv_filename,
            page_no = EXCLUDED.page_no,
            bbox = EXCLUDED.bbox,
            error = EXCLUDED.error,
            reprocessed_at = NOW()
    """

    conn = db.connect()
    try:
        with conn.cursor() as cursor:
            for r in results:
                cursor.execute(upsert_sql, (
                    r['document_id'],
                    r['field_name'],
                    r['csv_value'],
                    r['csv_value_normalized'],
                    r['ocr_value'],
                    r['ocr_value_normalized'],
                    json.dumps(r['all_ocr_candidates']),
                    r['matched'],
                    r['match_score'],
                    r['pdf_path'],
                    r['pdf_type'],
                    r['csv_filename'],
                    r['page_no'],
                    json.dumps(r['bbox']) if r['bbox'] else None,
                    r['error'],
                ))
        conn.commit()
    except Exception:
        # Leave the connection clean instead of in an aborted transaction.
        conn.rollback()
        raise


def main():
    """CLI entry point: re-process failed matches into ``failed_match_details``."""
    parser = argparse.ArgumentParser(description='Re-process failed matches')
    parser.add_argument('--csv', required=True, help='CSV files glob pattern')
    # NOTE(review): --pdf-dir is parsed but never read below; PDF paths come
    # from the documents table.
    parser.add_argument('--pdf-dir', required=True, help='PDF directory')
    parser.add_argument('--workers', type=int, default=3, help='Number of workers')
    parser.add_argument('--limit', type=int, help='Limit number of documents to process')
    args = parser.parse_args()

    # Expand CSV glob
    csv_files = sorted(glob.glob(args.csv))
    print(f"Found {len(csv_files)} CSV files")

    # Build CSV cache
    print("Building CSV filename cache...")
    build_csv_cache(csv_files)
    print(f"Cached {len(_csv_cache)} document IDs")

    # Connect to database
    db = DocumentDB()
    db.connect()

    try:
        # Create new table
        create_failed_match_table(db)

        # Get all failed documents
        print("Fetching failed documents...")
        failed_docs = get_failed_documents(db)
        print(f"Found {len(failed_docs)} documents with failed matches")

        if args.limit:
            failed_docs = failed_docs[:args.limit]
            print(f"Limited to {len(failed_docs)} documents")

        # Prepare tasks
        tasks = []
        for doc in failed_docs:
            failed_fields = get_failed_fields_for_document(db, doc['document_id'])
            csv_filename = find_csv_filename(doc['document_id'])
            if failed_fields:
                tasks.append((doc, failed_fields, csv_filename))

        print(f"Processing {len(tasks)} documents with {args.workers} workers...")

        # Process with multiprocessing
        total_results = 0
        batch_results = []
        batch_size = 50

        with ProcessPoolExecutor(max_workers=args.workers, initializer=init_worker) as executor:
            futures = {executor.submit(process_single_document, task): task[0]['document_id']
                       for task in tasks}

            for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"):
                doc_id = futures[future]
                try:
                    # NOTE(review): as_completed only yields finished futures,
                    # so this timeout does not bound per-document runtime; a
                    # hung worker still blocks the loop.
                    results = future.result(timeout=120)
                    batch_results.extend(results)
                    total_results += len(results)

                    # Save in batches
                    if len(batch_results) >= batch_size:
                        save_results_batch(db, batch_results)
                        batch_results = []

                except TimeoutError:
                    print(f"\nTimeout processing {doc_id}")
                except Exception as e:
                    print(f"\nError processing {doc_id}: {e}")

        # Save remaining results
        if batch_results:
            save_results_batch(db, batch_results)

        print(f"\nDone! Saved {total_results} failed match records to failed_match_details table")

        # Show summary
        conn = db.connect()
        with conn.cursor() as cursor:
            cursor.execute("""
                SELECT field_name, COUNT(*) as total,
                       COUNT(*) FILTER (WHERE ocr_value IS NOT NULL) as has_ocr,
                       COALESCE(AVG(match_score), 0) as avg_score
                FROM failed_match_details
                GROUP BY field_name
                ORDER BY total DESC
            """)
            print("\nSummary by field:")
            print("-" * 70)
            print(f"{'Field':<35} {'Total':>8} {'Has OCR':>10} {'Avg Score':>12}")
            print("-" * 70)
            for row in cursor.fetchall():
                print(f"{row[0]:<35} {row[1]:>8} {row[2]:>10} {row[3]:>12.2f}")
    finally:
        # Close the connection even when a step above raises.
        db.close()


if __name__ == '__main__':
    main()
doc_id = self._parse_string(row.get('DocumentId')) + doc_id = self._parse_string(self._get_field(row, 'DocumentId', 'document_id')) if not doc_id: return None return InvoiceRow( DocumentId=doc_id, - InvoiceDate=self._parse_date(row.get('InvoiceDate')), - InvoiceNumber=self._parse_string(row.get('InvoiceNumber')), - InvoiceDueDate=self._parse_date(row.get('InvoiceDueDate')), - OCR=self._parse_string(row.get('OCR')), - Message=self._parse_string(row.get('Message')), - Bankgiro=self._parse_string(row.get('Bankgiro')), - Plusgiro=self._parse_string(row.get('Plusgiro')), - Amount=self._parse_amount(row.get('Amount')), + InvoiceDate=self._parse_date(self._get_field(row, 'InvoiceDate', 'invoice_date')), + InvoiceNumber=self._parse_string(self._get_field(row, 'InvoiceNumber', 'invoice_number')), + InvoiceDueDate=self._parse_date(self._get_field(row, 'InvoiceDueDate', 'invoice_due_date')), + OCR=self._parse_string(self._get_field(row, 'OCR', 'ocr')), + Message=self._parse_string(self._get_field(row, 'Message', 'message')), + Bankgiro=self._parse_string(self._get_field(row, 'Bankgiro', 'bankgiro')), + Plusgiro=self._parse_string(self._get_field(row, 'Plusgiro', 'plusgiro')), + Amount=self._parse_amount(self._get_field(row, 'Amount', 'amount', 'invoice_data_amount')), # New fields split=self._parse_string(row.get('split')), customer_number=self._parse_string(row.get('customer_number')), diff --git a/src/matcher/field_matcher.py b/src/matcher/field_matcher.py index 4f14a9f..7e0205e 100644 --- a/src/matcher/field_matcher.py +++ b/src/matcher/field_matcher.py @@ -219,7 +219,7 @@ class FieldMatcher: # Note: Amount is excluded because short numbers like "451" can incorrectly match # in OCR payment lines or other unrelated text if field_name in ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', - 'supplier_organisation_number', 'supplier_accounts'): + 'supplier_organisation_number', 'supplier_accounts', 'customer_number'): substring_matches = 
self._find_substring_matches(page_tokens, value, field_name) matches.extend(substring_matches) @@ -369,7 +369,7 @@ class FieldMatcher: # Supported fields for substring matching supported_fields = ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount', - 'supplier_organisation_number', 'supplier_accounts') + 'supplier_organisation_number', 'supplier_accounts', 'customer_number') if field_name not in supported_fields: return matches @@ -383,49 +383,59 @@ class FieldMatcher: continue # Check if value appears as substring (using normalized text) + # Try case-sensitive first, then case-insensitive if value in token_text_normalized: - # Verify it's a proper boundary match (not part of a larger number) idx = token_text_normalized.find(value) + case_sensitive_match = True + elif value.lower() in token_text_normalized.lower(): + idx = token_text_normalized.lower().find(value.lower()) + case_sensitive_match = False + else: + continue - # Check character before (if exists) - if idx > 0: - char_before = token_text_normalized[idx - 1] - # Must be non-digit (allow : space - etc) - if char_before.isdigit(): - continue + # Verify it's a proper boundary match (not part of a larger number) + # Check character before (if exists) + if idx > 0: + char_before = token_text_normalized[idx - 1] + # Must be non-digit (allow : space - etc) + if char_before.isdigit(): + continue - # Check character after (if exists) - end_idx = idx + len(value) - if end_idx < len(token_text_normalized): - char_after = token_text_normalized[end_idx] - # Must be non-digit - if char_after.isdigit(): - continue + # Check character after (if exists) + end_idx = idx + len(value) + if end_idx < len(token_text_normalized): + char_after = token_text_normalized[end_idx] + # Must be non-digit + if char_after.isdigit(): + continue - # Found valid substring match - context_keywords, context_boost = self._find_context_keywords( - tokens, token, field_name - ) + # Found valid substring 
@staticmethod
def normalize_customer_number(value: str) -> list[str]:
    """
    Build the set of spelling variants used to match a customer number.

    Customer numbers come as alphanumeric codes ('EMM 256-6', 'ABC123',
    'A-1234') or plain digits ('12345', '123-456'). Starting from the
    cleaned value, the result contains:
    - the cleaned value itself
    - the value with spaces removed
    - the value with dashes removed
    - the value with both spaces and dashes removed
    - the upper-case and lower-case forms of the cleaned value only

    Example (assuming clean_text leaves the input unchanged):
        'EMM 256-6' -> 'EMM 256-6', 'EMM256-6', 'EMM 2566', 'EMM2566',
                       'emm 256-6'

    Returns:
        The distinct non-empty variants; the order is not guaranteed.
    """
    cleaned = FieldNormalizer.clean_text(value)
    spaceless = cleaned.replace(' ', '')

    # A set removes the duplicates that appear when a transformation is a no-op.
    candidates = {
        cleaned,
        spaceless,
        cleaned.replace('-', ''),
        spaceless.replace('-', ''),
        cleaned.upper(),
        cleaned.lower(),
    }
    candidates.discard('')

    return list(candidates)
- typically European DD.MM.YYYY @@ -423,6 +462,16 @@ class FieldNormalizer: r'^(\d{1,2})-(\d{1,2})-(\d{4})$', ] + # Patterns with 2-digit year (common in Swedish invoices) + ambiguous_patterns_2digit_year = [ + # Format DD.MM.YY (e.g., 02.08.25 for 2025-08-02) + r'^(\d{1,2})\.(\d{1,2})\.(\d{2})$', + # Format DD/MM/YY + r'^(\d{1,2})/(\d{1,2})/(\d{2})$', + # Format DD-MM-YY + r'^(\d{1,2})-(\d{1,2})-(\d{2})$', + ] + # Try unambiguous patterns first for pattern, extractor in date_patterns: match = re.match(pattern, value) @@ -434,9 +483,9 @@ class FieldNormalizer: except ValueError: continue - # Try ambiguous patterns with both interpretations + # Try ambiguous patterns with 4-digit year if not parsed_dates: - for pattern in ambiguous_patterns: + for pattern in ambiguous_patterns_4digit_year: match = re.match(pattern, value) if match: n1, n2, year = int(match[1]), int(match[2]), int(match[3]) @@ -457,6 +506,31 @@ class FieldNormalizer: if parsed_dates: break + # Try ambiguous patterns with 2-digit year (e.g., 02.08.25) + if not parsed_dates: + for pattern in ambiguous_patterns_2digit_year: + match = re.match(pattern, value) + if match: + n1, n2, yy = int(match[1]), int(match[2]), int(match[3]) + # Convert 2-digit year to 4-digit (00-49 -> 2000s, 50-99 -> 1900s) + year = 2000 + yy if yy < 50 else 1900 + yy + + # Try DD/MM/YY (European - day first, most common in Sweden) + try: + parsed_dates.append(datetime(year, n2, n1)) + except ValueError: + pass + + # Try MM/DD/YY (US - month first) if different and valid + if n1 != n2: + try: + parsed_dates.append(datetime(year, n1, n2)) + except ValueError: + pass + + if parsed_dates: + break + # Try Swedish month names if not parsed_dates: for month_name, month_num in FieldNormalizer.SWEDISH_MONTHS.items(): @@ -527,6 +601,7 @@ NORMALIZERS: dict[str, Callable[[str], list[str]]] = { 'InvoiceDueDate': FieldNormalizer.normalize_date, 'supplier_organisation_number': FieldNormalizer.normalize_organisation_number, 'supplier_accounts': 
FieldNormalizer.normalize_supplier_accounts, + 'customer_number': FieldNormalizer.normalize_customer_number, } diff --git a/src/ocr/paddle_ocr.py b/src/ocr/paddle_ocr.py index 0a9c7b4..76c5775 100644 --- a/src/ocr/paddle_ocr.py +++ b/src/ocr/paddle_ocr.py @@ -60,7 +60,9 @@ class OCREngine: self, lang: str = "en", det_model_dir: str | None = None, - rec_model_dir: str | None = None + rec_model_dir: str | None = None, + use_doc_orientation_classify: bool = True, + use_doc_unwarping: bool = False ): """ Initialize OCR engine. @@ -69,6 +71,13 @@ class OCREngine: lang: Language code ('en', 'sv', 'ch', etc.) det_model_dir: Custom detection model directory rec_model_dir: Custom recognition model directory + use_doc_orientation_classify: Whether to auto-detect and correct document orientation. + Default True to handle rotated documents. + use_doc_unwarping: Whether to use UVDoc document unwarping for curved/warped documents. + Default False to preserve original image layout, + especially important for payment OCR lines at bottom. + Enable for severely warped documents at the cost of potentially + losing bottom content. Note: PaddleOCR 3.x automatically uses GPU if available via PaddlePaddle. 
@@ -82,6 +91,12 @@ class OCREngine: # PaddleOCR 3.x init (use_gpu removed, device controlled by paddle.set_device) init_params = { 'lang': lang, + # Enable orientation classification to handle rotated documents + 'use_doc_orientation_classify': use_doc_orientation_classify, + # Disable UVDoc unwarping to preserve original image layout + # This prevents the bottom payment OCR line from being cut off + # For severely warped documents, enable this but expect potential content loss + 'use_doc_unwarping': use_doc_unwarping, } if det_model_dir: init_params['text_detection_model_dir'] = det_model_dir @@ -95,7 +110,9 @@ class OCREngine: image: str | Path | np.ndarray, page_no: int = 0, max_size: int = 2000, - scale_to_pdf_points: float | None = None + scale_to_pdf_points: float | None = None, + scan_bottom_region: bool = True, + bottom_region_ratio: float = 0.15 ) -> list[OCRToken]: """ Extract text tokens from an image. @@ -108,19 +125,106 @@ class OCREngine: scale_to_pdf_points: If provided, scale bbox coordinates by this factor to convert from pixel to PDF point coordinates. Use (72 / dpi) for images rendered at a specific DPI. + scan_bottom_region: If True, also scan the bottom region separately to catch + OCR payment lines that may be missed in full-page scan. 
+ bottom_region_ratio: Ratio of page height to scan as bottom region (default 0.15 = 15%) Returns: List of OCRToken objects with bbox in pixel coords (or PDF points if scale_to_pdf_points is set) """ result = self.extract_with_image(image, page_no, max_size, scale_to_pdf_points) - return result.tokens + tokens = result.tokens + + # Optionally scan bottom region separately for Swedish OCR payment lines + if scan_bottom_region: + bottom_tokens = self._scan_bottom_region( + image, page_no, max_size, scale_to_pdf_points, bottom_region_ratio + ) + tokens = self._merge_tokens(tokens, bottom_tokens) + + return tokens + + def _scan_bottom_region( + self, + image: str | Path | np.ndarray, + page_no: int, + max_size: int, + scale_to_pdf_points: float | None, + bottom_ratio: float + ) -> list[OCRToken]: + """Scan the bottom region of the image separately.""" + from PIL import Image as PILImage + + # Load image if path + if isinstance(image, (str, Path)): + img = PILImage.open(str(image)) + img_array = np.array(img) + else: + img_array = image + + h, w = img_array.shape[:2] + crop_y = int(h * (1 - bottom_ratio)) + + # Crop bottom region + bottom_crop = img_array[crop_y:h, :, :] if len(img_array.shape) == 3 else img_array[crop_y:h, :] + + # OCR the cropped region (without recursive bottom scan to avoid infinite loop) + result = self.extract_with_image( + bottom_crop, page_no, max_size, + scale_to_pdf_points=None, + scan_bottom_region=False # Important: disable to prevent recursion + ) + + # Adjust bbox y-coordinates to full image space + adjusted_tokens = [] + for token in result.tokens: + # Scale factor for coordinates + scale = scale_to_pdf_points if scale_to_pdf_points else 1.0 + + adjusted_bbox = ( + token.bbox[0] * scale, + (token.bbox[1] + crop_y) * scale, + token.bbox[2] * scale, + (token.bbox[3] + crop_y) * scale + ) + adjusted_tokens.append(OCRToken( + text=token.text, + bbox=adjusted_bbox, + confidence=token.confidence, + page_no=token.page_no + )) + + return 
adjusted_tokens + + def _merge_tokens( + self, + main_tokens: list[OCRToken], + bottom_tokens: list[OCRToken] + ) -> list[OCRToken]: + """Merge tokens from main scan and bottom region scan, removing duplicates.""" + if not bottom_tokens: + return main_tokens + + # Create a set of existing token texts for deduplication + existing_texts = {t.text.strip() for t in main_tokens} + + # Add bottom tokens that aren't duplicates + merged = list(main_tokens) + for token in bottom_tokens: + if token.text.strip() not in existing_texts: + merged.append(token) + existing_texts.add(token.text.strip()) + + return merged def extract_with_image( self, image: str | Path | np.ndarray, page_no: int = 0, max_size: int = 2000, - scale_to_pdf_points: float | None = None + scale_to_pdf_points: float | None = None, + scan_bottom_region: bool = True, + bottom_region_ratio: float = 0.15 ) -> OCRResult: """ Extract text tokens from an image and return the preprocessed image. @@ -138,6 +242,9 @@ class OCREngine: scale_to_pdf_points: If provided, scale bbox coordinates by this factor to convert from pixel to PDF point coordinates. Use (72 / dpi) for images rendered at a specific DPI. + scan_bottom_region: If True, also scan the bottom region separately to catch + OCR payment lines that may be missed in full-page scan. 
+ bottom_region_ratio: Ratio of page height to scan as bottom region (default 0.15 = 15%) Returns: OCRResult with tokens and output_img (preprocessed image from PaddleOCR) @@ -241,6 +348,13 @@ class OCREngine: if output_img is None: output_img = img_array + # Optionally scan bottom region separately for Swedish OCR payment lines + if scan_bottom_region: + bottom_tokens = self._scan_bottom_region( + image, page_no, max_size, scale_to_pdf_points, bottom_region_ratio + ) + tokens = self._merge_tokens(tokens, bottom_tokens) + return OCRResult(tokens=tokens, output_img=output_img) def extract_from_pdf( diff --git a/src/pdf/detector.py b/src/pdf/detector.py index 4b9ec99..0344766 100644 --- a/src/pdf/detector.py +++ b/src/pdf/detector.py @@ -57,6 +57,7 @@ def get_pdf_type(pdf_path: str | Path) -> PDFType: return "scanned" text_pages = 0 + total_pages = len(doc) for page in doc: text = page.get_text().strip() if len(text) > 30: @@ -64,7 +65,6 @@ def get_pdf_type(pdf_path: str | Path) -> PDFType: doc.close() - total_pages = len(doc) if text_pages == total_pages: return "text" elif text_pages == 0: diff --git a/src/processing/autolabel_tasks.py b/src/processing/autolabel_tasks.py index bee162d..e2e720b 100644 --- a/src/processing/autolabel_tasks.py +++ b/src/processing/autolabel_tasks.py @@ -85,6 +85,7 @@ def process_text_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]: Returns: Result dictionary with success status, annotations, and report. 
""" + import shutil from src.data import AutoLabelReport, FieldMatchResult from src.pdf import PDFDocument from src.matcher import FieldMatcher @@ -100,6 +101,11 @@ def process_text_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]: start_time = time.time() doc_id = row_dict["DocumentId"] + # Clean up existing temp folder for this document (for re-matching) + temp_doc_dir = output_dir / "temp" / doc_id + if temp_doc_dir.exists(): + shutil.rmtree(temp_doc_dir, ignore_errors=True) + report = AutoLabelReport(document_id=doc_id) report.pdf_path = str(pdf_path) report.pdf_type = "text" @@ -218,6 +224,7 @@ def process_scanned_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]: Returns: Result dictionary with success status, annotations, and report. """ + import shutil from src.data import AutoLabelReport, FieldMatchResult from src.pdf import PDFDocument from src.matcher import FieldMatcher @@ -233,6 +240,11 @@ def process_scanned_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]: start_time = time.time() doc_id = row_dict["DocumentId"] + # Clean up existing temp folder for this document (for re-matching) + temp_doc_dir = output_dir / "temp" / doc_id + if temp_doc_dir.exists(): + shutil.rmtree(temp_doc_dir, ignore_errors=True) + report = AutoLabelReport(document_id=doc_id) report.pdf_path = str(pdf_path) report.pdf_type = "scanned" diff --git a/src/yolo/annotation_generator.py b/src/yolo/annotation_generator.py index 9dc0823..785bf37 100644 --- a/src/yolo/annotation_generator.py +++ b/src/yolo/annotation_generator.py @@ -21,6 +21,7 @@ FIELD_CLASSES = { 'Plusgiro': 5, 'Amount': 6, 'supplier_organisation_number': 7, + 'customer_number': 8, } # Fields that need matching but map to other YOLO classes @@ -41,6 +42,7 @@ CLASS_NAMES = [ 'plusgiro', 'amount', 'supplier_org_number', + 'customer_number', ]