Yaojia Wang
2026-01-16 23:10:01 +01:00
parent 53d1e8db25
commit 425b8fdedf
10 changed files with 653 additions and 87 deletions


@@ -9,12 +9,24 @@ import argparse
import sys
import time
import os
import signal
import warnings
from pathlib import Path
from tqdm import tqdm
-from concurrent.futures import ProcessPoolExecutor, as_completed
+from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
import multiprocessing
# Global flag for graceful shutdown
_shutdown_requested = False


def _signal_handler(signum, frame):
    """Handle interrupt signals for graceful shutdown."""
    global _shutdown_requested
    if _shutdown_requested:
        # A second Ctrl+C force-quits, as the message below promises
        sys.exit(1)
    _shutdown_requested = True
    print("\n\nShutdown requested. Finishing current batch and saving progress...")
    print("(Press Ctrl+C again to force quit)\n")

# Windows compatibility: use 'spawn' method for multiprocessing
# This is required on Windows and is also safer for libraries like PaddleOCR
if sys.platform == 'win32':
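The hunk cuts off inside the win32 guard; its body presumably selects the start method along these lines (a sketch, not shown in this diff):

    if sys.platform == 'win32':
        # 'spawn' gives each worker a fresh interpreter; it is the only start
        # method available on Windows and avoids fork-related issues in native
        # libraries such as PaddleOCR
        multiprocessing.set_start_method('spawn', force=True)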
@@ -111,6 +123,12 @@ def process_single_document(args_tuple):
    report = AutoLabelReport(document_id=doc_id)
    report.pdf_path = str(pdf_path)

    # Store metadata fields from CSV
    report.split = row_dict.get('split')
    report.customer_number = row_dict.get('customer_number')
    report.supplier_name = row_dict.get('supplier_name')
    report.supplier_organisation_number = row_dict.get('supplier_organisation_number')
    report.supplier_accounts = row_dict.get('supplier_accounts')

    result = {
        'doc_id': doc_id,
@@ -204,6 +222,67 @@ def process_single_document(args_tuple):
                context_keywords=best.context_keywords
            ))

    # Match supplier_accounts and map to Bankgiro/Plusgiro
    supplier_accounts_value = row_dict.get('supplier_accounts')
    if supplier_accounts_value:
        # Parse accounts: "BG:xxx | PG:yyy" format
        accounts = [acc.strip() for acc in str(supplier_accounts_value).split('|')]
        for account in accounts:
            account = account.strip()
            if not account:
                continue

            # Determine account type (BG or PG) and extract account number
            account_type = None
            account_number = account  # Default to full value
            if account.upper().startswith('BG:'):
                account_type = 'Bankgiro'
                account_number = account[3:].strip()  # Remove "BG:" prefix
            elif account.upper().startswith('BG '):
                account_type = 'Bankgiro'
                account_number = account[2:].strip()  # Remove "BG" prefix
            elif account.upper().startswith('PG:'):
                account_type = 'Plusgiro'
                account_number = account[3:].strip()  # Remove "PG:" prefix
            elif account.upper().startswith('PG '):
                account_type = 'Plusgiro'
                account_number = account[2:].strip()  # Remove "PG" prefix
            else:
                # Try to guess from format - Plusgiro often has format XXXXXXX-X
                digits = ''.join(c for c in account if c.isdigit())
                if len(digits) == 8 and '-' in account:
                    account_type = 'Plusgiro'
                elif len(digits) in (7, 8):
                    account_type = 'Bankgiro'  # Default to Bankgiro

            if not account_type:
                continue

            # Normalize and match using the account number (without prefix)
            normalized = normalize_field('supplier_accounts', account_number)
            field_matches = matcher.find_matches(tokens, account_type, normalized, page_no)
            if field_matches:
                best = field_matches[0]
                # Add to matches under the target class (Bankgiro/Plusgiro)
                if account_type not in matches:
                    matches[account_type] = []
                matches[account_type].extend(field_matches)
                matched_fields.add('supplier_accounts')
                report.add_field_result(FieldMatchResult(
                    field_name=f'supplier_accounts({account_type})',
                    csv_value=account_number,  # Store without prefix
                    matched=True,
                    score=best.score,
                    matched_text=best.matched_text,
                    candidate_used=best.value,
                    bbox=best.bbox,
                    page_no=page_no,
                    context_keywords=best.context_keywords
                ))
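For reference, each supplier_accounts cell is split on '|' and every entry goes through the prefix rules above. Condensed into a standalone helper (hypothetical, not code from this commit), the classification logic is:

    def parse_supplier_account(account):
        """Hypothetical helper mirroring the prefix rules above."""
        account = account.strip()
        upper = account.upper()
        if upper.startswith(('BG:', 'BG ')):
            return 'Bankgiro', account[3:].strip()  # same result as the slices above
        if upper.startswith(('PG:', 'PG ')):
            return 'Plusgiro', account[3:].strip()
        digits = ''.join(c for c in account if c.isdigit())
        if len(digits) == 8 and '-' in account:
            return 'Plusgiro', account  # XXXXXXX-X shape
        if len(digits) in (7, 8):
            return 'Bankgiro', account  # default for 7-8 digit accounts
        return None, account  # unrecognized; the loop above skips these

    assert parse_supplier_account('BG: 5393-9484') == ('Bankgiro', '5393-9484')
    assert parse_supplier_account('pg 4869863-8') == ('Plusgiro', '4869863-8')
    assert parse_supplier_account('1234567-8') == ('Plusgiro', '1234567-8')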
    # Count annotations
    annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi)
@@ -329,6 +408,10 @@ def main():
    args = parser.parse_args()

    # Register signal handlers for graceful shutdown
    signal.signal(signal.SIGINT, _signal_handler)
    signal.signal(signal.SIGTERM, _signal_handler)

    # Import here to avoid slow startup
    from ..data import CSVLoader, AutoLabelReport, FieldMatchResult
    from ..data.autolabel_report import ReportWriter
@@ -364,6 +447,7 @@ def main():
        from ..data.db import DocumentDB
        db = DocumentDB()
        db.connect()
        db.create_tables()  # Ensure tables exist
        print("Connected to database for status checking")

    # Global stats
@@ -458,6 +542,11 @@ def main():
    try:
        # Process CSV files one by one (streaming)
        for csv_idx, csv_file in enumerate(csv_files):
            # Check for shutdown request
            if _shutdown_requested:
                print("\nShutdown requested. Stopping after current batch...")
                break

            print(f"\n[{csv_idx + 1}/{len(csv_files)}] Processing: {csv_file.name}")

            # Load only this CSV file
@@ -548,6 +637,13 @@ def main():
                'Bankgiro': row.Bankgiro,
                'Plusgiro': row.Plusgiro,
                'Amount': row.Amount,
                # New fields
                'supplier_organisation_number': row.supplier_organisation_number,
                'supplier_accounts': row.supplier_accounts,
                # Metadata fields (not for matching, but for database storage)
                'split': row.split,
                'customer_number': row.customer_number,
                'supplier_name': row.supplier_name,
            }

            tasks.append((
@@ -647,11 +743,19 @@ def main():
            futures = {executor.submit(process_single_document, task): task[0]['DocumentId']
                       for task in tasks}

            # Per-document timeout: 120 seconds (2 minutes).
            # Caveat: as_completed() only yields futures that are already done,
            # so the timeout passed to result() below never actually fires, and
            # cancel() has no effect on a task that is already running; a
            # genuinely stuck document still blocks the batch (see the sketch
            # after this hunk).
            DOCUMENT_TIMEOUT = 120

            for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"):
                doc_id = futures[future]
                try:
-                   result = future.result()
+                   result = future.result(timeout=DOCUMENT_TIMEOUT)
                    handle_result(result)
                except TimeoutError:
                    handle_error(doc_id, f"Processing timeout after {DOCUMENT_TIMEOUT}s")
                    # cancel() only succeeds for futures that have not started yet
                    future.cancel()
                except Exception as e:
                    handle_error(doc_id, e)
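Since the stdlib ProcessPoolExecutor cannot interrupt a worker once it is running, a hard per-document timeout needs the worker process to be killed from outside. A minimal sketch, assuming the third-party pebble library (not used by this commit; args.workers stands in for however the pool size is configured):

    from concurrent.futures import TimeoutError
    from pebble import ProcessPool

    with ProcessPool(max_workers=args.workers) as pool:
        # schedule() attaches a timeout to each task; when it expires, pebble
        # terminates the worker process and sets TimeoutError on the future
        futures = {pool.schedule(process_single_document, args=[task],
                                 timeout=DOCUMENT_TIMEOUT): task[0]['DocumentId']
                   for task in tasks}
        for future, doc_id in futures.items():
            try:
                handle_result(future.result())
            except TimeoutError:
                handle_error(doc_id, f"Processing timeout after {DOCUMENT_TIMEOUT}s")
            except Exception as e:
                handle_error(doc_id, e)

Iterating futures.items() in submission order keeps the sketch short; each result() call blocks until its task finishes, fails, or is killed at its deadline, so no single document can stall the batch indefinitely.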