WIP
@@ -9,12 +9,24 @@ import argparse
 import sys
 import time
 import os
+import signal
+import warnings
 from pathlib import Path
 from tqdm import tqdm
-from concurrent.futures import ProcessPoolExecutor, as_completed
+from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
 import multiprocessing
 
+# Global flag for graceful shutdown
+_shutdown_requested = False
+
+
+def _signal_handler(signum, frame):
+    """Handle interrupt signals for graceful shutdown."""
+    global _shutdown_requested
+    _shutdown_requested = True
+    print("\n\nShutdown requested. Finishing current batch and saving progress...")
+    print("(Press Ctrl+C again to force quit)\n")
 
 # Windows compatibility: use 'spawn' method for multiprocessing
 # This is required on Windows and is also safer for libraries like PaddleOCR
 if sys.platform == 'win32':
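Standalone, this flag-based shutdown pattern looks roughly like the sketch below. Names are illustrative; the SIG_DFL restore is an assumption not present in this commit, but without it a second Ctrl+C re-enters the same handler rather than force-quitting as the printed message promises:

import signal
import time

_shutdown_requested = False

def _signal_handler(signum, frame):
    """First Ctrl+C sets a flag; the work loop exits at the next batch boundary."""
    global _shutdown_requested
    _shutdown_requested = True
    # Assumption: restore the default handler so a second Ctrl+C force-quits.
    signal.signal(signal.SIGINT, signal.SIG_DFL)
    print("\nShutdown requested. Finishing current batch...")

signal.signal(signal.SIGINT, _signal_handler)

for batch in range(100):   # stand-in for the real batch loop
    if _shutdown_requested:
        print("Stopping after current batch.")
        break
    time.sleep(0.1)         # stand-in for real work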
@@ -111,6 +123,12 @@ def process_single_document(args_tuple):
 
     report = AutoLabelReport(document_id=doc_id)
     report.pdf_path = str(pdf_path)
+    # Store metadata fields from CSV
+    report.split = row_dict.get('split')
+    report.customer_number = row_dict.get('customer_number')
+    report.supplier_name = row_dict.get('supplier_name')
+    report.supplier_organisation_number = row_dict.get('supplier_organisation_number')
+    report.supplier_accounts = row_dict.get('supplier_accounts')
 
     result = {
         'doc_id': doc_id,
@@ -204,6 +222,67 @@ def process_single_document(args_tuple):
                 context_keywords=best.context_keywords
             ))
 
+    # Match supplier_accounts and map to Bankgiro/Plusgiro
+    supplier_accounts_value = row_dict.get('supplier_accounts')
+    if supplier_accounts_value:
+        # Parse accounts: "BG:xxx | PG:yyy" format
+        accounts = [acc.strip() for acc in str(supplier_accounts_value).split('|')]
+        for account in accounts:
+            account = account.strip()
+            if not account:
+                continue
+
+            # Determine account type (BG or PG) and extract account number
+            account_type = None
+            account_number = account  # Default to full value
+
+            if account.upper().startswith('BG:'):
+                account_type = 'Bankgiro'
+                account_number = account[3:].strip()  # Remove "BG:" prefix
+            elif account.upper().startswith('BG '):
+                account_type = 'Bankgiro'
+                account_number = account[2:].strip()  # Remove "BG" prefix
+            elif account.upper().startswith('PG:'):
+                account_type = 'Plusgiro'
+                account_number = account[3:].strip()  # Remove "PG:" prefix
+            elif account.upper().startswith('PG '):
+                account_type = 'Plusgiro'
+                account_number = account[2:].strip()  # Remove "PG" prefix
+            else:
+                # Try to guess from format - Plusgiro often has format XXXXXXX-X
+                digits = ''.join(c for c in account if c.isdigit())
+                if len(digits) == 8 and '-' in account:
+                    account_type = 'Plusgiro'
+                elif len(digits) in (7, 8):
+                    account_type = 'Bankgiro'  # Default to Bankgiro
+
+            if not account_type:
+                continue
+
+            # Normalize and match using the account number (without prefix)
+            normalized = normalize_field('supplier_accounts', account_number)
+            field_matches = matcher.find_matches(tokens, account_type, normalized, page_no)
+
+            if field_matches:
+                best = field_matches[0]
+                # Add to matches under the target class (Bankgiro/Plusgiro)
+                if account_type not in matches:
+                    matches[account_type] = []
+                matches[account_type].extend(field_matches)
+                matched_fields.add('supplier_accounts')
+
+                report.add_field_result(FieldMatchResult(
+                    field_name=f'supplier_accounts({account_type})',
+                    csv_value=account_number,  # Store without prefix
+                    matched=True,
+                    score=best.score,
+                    matched_text=best.matched_text,
+                    candidate_used=best.value,
+                    bbox=best.bbox,
+                    page_no=page_no,
+                    context_keywords=best.context_keywords
+                ))
+
     # Count annotations
     annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi)
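The branchy parsing above condenses to a small pure function, sketched here for clarity. The function name and the sample account numbers are illustrative, not from the codebase:

def parse_supplier_accounts(value):
    """Parse 'BG:xxx | PG:yyy' strings into (account_type, account_number) pairs."""
    results = []
    for account in (part.strip() for part in str(value).split('|')):
        if not account:
            continue
        upper = account.upper()
        if upper.startswith(('BG:', 'BG ')):
            results.append(('Bankgiro', account[3:].strip()))
        elif upper.startswith(('PG:', 'PG ')):
            results.append(('Plusgiro', account[3:].strip()))
        else:
            # No prefix: guess from the digits (Plusgiro is often XXXXXXX-X)
            digits = ''.join(c for c in account if c.isdigit())
            if len(digits) == 8 and '-' in account:
                results.append(('Plusgiro', account))
            elif len(digits) in (7, 8):
                results.append(('Bankgiro', account))
    return results

assert parse_supplier_accounts('BG:5050-1055 | PG:4789-2') == [
    ('Bankgiro', '5050-1055'), ('Plusgiro', '4789-2')]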
@@ -329,6 +408,10 @@ def main():
 
     args = parser.parse_args()
 
+    # Register signal handlers for graceful shutdown
+    signal.signal(signal.SIGINT, _signal_handler)
+    signal.signal(signal.SIGTERM, _signal_handler)
+
     # Import here to avoid slow startup
     from ..data import CSVLoader, AutoLabelReport, FieldMatchResult
     from ..data.autolabel_report import ReportWriter
@@ -364,6 +447,7 @@ def main():
         from ..data.db import DocumentDB
         db = DocumentDB()
         db.connect()
+        db.create_tables()  # Ensure tables exist
         print("Connected to database for status checking")
 
     # Global stats
@@ -458,6 +542,11 @@
     try:
         # Process CSV files one by one (streaming)
        for csv_idx, csv_file in enumerate(csv_files):
+            # Check for shutdown request
+            if _shutdown_requested:
+                print("\nShutdown requested. Stopping after current batch...")
+                break
+
             print(f"\n[{csv_idx + 1}/{len(csv_files)}] Processing: {csv_file.name}")
 
             # Load only this CSV file
@@ -548,6 +637,13 @@ def main():
                 'Bankgiro': row.Bankgiro,
                 'Plusgiro': row.Plusgiro,
                 'Amount': row.Amount,
+                # New fields
+                'supplier_organisation_number': row.supplier_organisation_number,
+                'supplier_accounts': row.supplier_accounts,
+                # Metadata fields (not for matching, but for database storage)
+                'split': row.split,
+                'customer_number': row.customer_number,
+                'supplier_name': row.supplier_name,
             }
 
             tasks.append((
@@ -647,11 +743,19 @@ def main():
             futures = {executor.submit(process_single_document, task): task[0]['DocumentId']
                        for task in tasks}
 
+            # Per-document timeout: 120 seconds (2 minutes)
+            # This prevents a single stuck document from blocking the entire batch
+            DOCUMENT_TIMEOUT = 120
+
             for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"):
                 doc_id = futures[future]
                 try:
-                    result = future.result()
+                    result = future.result(timeout=DOCUMENT_TIMEOUT)
                     handle_result(result)
+                except TimeoutError:
+                    handle_error(doc_id, f"Processing timeout after {DOCUMENT_TIMEOUT}s")
+                    # Cancel the stuck future
+                    future.cancel()
                 except Exception as e:
                     handle_error(doc_id, e)
 
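One caveat about this hunk's pattern: as_completed() only yields futures that have already finished, so future.result(timeout=DOCUMENT_TIMEOUT) on a yielded future returns immediately and the per-document timeout effectively never fires; likewise, Future.cancel() cannot stop a task already running in a worker process. A deadline on the wait itself belongs on as_completed(). A minimal sketch of that alternative, assuming hypothetical process_fn/handle_result/handle_error stand-ins:

from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError

def run_batch(tasks, process_fn, handle_result, handle_error, per_doc_timeout=120):
    with ProcessPoolExecutor() as executor:
        futures = {executor.submit(process_fn, task): i for i, task in enumerate(tasks)}
        try:
            # as_completed raises TimeoutError once the overall deadline passes
            for future in as_completed(futures, timeout=per_doc_timeout * len(futures)):
                doc_id = futures[future]
                try:
                    handle_result(future.result())  # future already done; returns at once
                except Exception as e:
                    handle_error(doc_id, e)
        except TimeoutError:
            for future, doc_id in futures.items():
                if not future.done():
                    future.cancel()  # only prevents tasks that have not started yet
                    handle_error(doc_id, "timed out waiting for worker")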