This commit is contained in:
Yaojia Wang
2026-01-16 23:10:01 +01:00
parent 53d1e8db25
commit 425b8fdedf
10 changed files with 653 additions and 87 deletions

View File

@@ -72,7 +72,10 @@
"Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && ls -la\")",
"Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-master && python -c \"\"\nimport sys\nsys.path.insert\\(0, ''.''\\)\nfrom src.data.db import DocumentDB\nfrom src.yolo.db_dataset import DBYOLODataset\n\n# Connect to database\ndb = DocumentDB\\(\\)\ndb.connect\\(\\)\n\n# Create dataset\ndataset = DBYOLODataset\\(\n images_dir=''data/dataset'',\n db=db,\n split=''train'',\n train_ratio=0.8,\n val_ratio=0.1,\n seed=42,\n dpi=300\n\\)\n\nprint\\(f''Dataset size: {len\\(dataset\\)}''\\)\n\nif len\\(dataset\\) > 0:\n # Check first few items\n for i in range\\(min\\(3, len\\(dataset\\)\\)\\):\n item = dataset.items[i]\n print\\(f''\\\\n--- Item {i} ---''\\)\n print\\(f''Document: {item.document_id}''\\)\n print\\(f''Is scanned: {item.is_scanned}''\\)\n print\\(f''Image: {item.image_path.name}''\\)\n \n # Get YOLO labels\n yolo_labels = dataset.get_labels_for_yolo\\(i\\)\n print\\(f''YOLO labels:''\\)\n for line in yolo_labels.split\\(''\\\\n''\\)[:3]:\n print\\(f'' {line}''\\)\n # Check if values are normalized\n parts = line.split\\(\\)\n if len\\(parts\\) == 5:\n x, y, w, h = float\\(parts[1]\\), float\\(parts[2]\\), float\\(parts[3]\\), float\\(parts[4]\\)\n if x > 1 or y > 1 or w > 1 or h > 1:\n print\\(f'' WARNING: Values not normalized!''\\)\n elif x == 1.0 or y == 1.0:\n print\\(f'' WARNING: Values clamped to 1.0!''\\)\n else:\n print\\(f'' OK: Values properly normalized''\\)\n\ndb.close\\(\\)\n\"\"\")",
"Bash(wsl -e bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/dataset/\")",
"Bash(wsl -e bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/dataset/train/\")"
"Bash(wsl -e bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/dataset/train/\")",
"Bash(wsl -e bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/structured_data/*.csv 2>/dev/null | head -20\")",
"Bash(tasklist:*)",
"Bash(findstr:*)"
],
"deny": [],
"ask": [],

View File

@@ -9,12 +9,24 @@ import argparse
import sys
import time
import os
import signal
import warnings
from pathlib import Path
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
import multiprocessing
# Global flag for graceful shutdown
_shutdown_requested = False
def _signal_handler(signum, frame):
"""Handle interrupt signals for graceful shutdown."""
global _shutdown_requested
_shutdown_requested = True
print("\n\nShutdown requested. Finishing current batch and saving progress...")
print("(Press Ctrl+C again to force quit)\n")
# Windows compatibility: use 'spawn' method for multiprocessing
# This is required on Windows and is also safer for libraries like PaddleOCR
if sys.platform == 'win32':
@@ -111,6 +123,12 @@ def process_single_document(args_tuple):
report = AutoLabelReport(document_id=doc_id)
report.pdf_path = str(pdf_path)
# Store metadata fields from CSV
report.split = row_dict.get('split')
report.customer_number = row_dict.get('customer_number')
report.supplier_name = row_dict.get('supplier_name')
report.supplier_organisation_number = row_dict.get('supplier_organisation_number')
report.supplier_accounts = row_dict.get('supplier_accounts')
result = {
'doc_id': doc_id,
@@ -204,6 +222,67 @@ def process_single_document(args_tuple):
context_keywords=best.context_keywords
))
# Match supplier_accounts and map to Bankgiro/Plusgiro
supplier_accounts_value = row_dict.get('supplier_accounts')
if supplier_accounts_value:
# Parse accounts: "BG:xxx | PG:yyy" format
accounts = [acc.strip() for acc in str(supplier_accounts_value).split('|')]
for account in accounts:
account = account.strip()
if not account:
continue
# Determine account type (BG or PG) and extract account number
account_type = None
account_number = account # Default to full value
if account.upper().startswith('BG:'):
account_type = 'Bankgiro'
account_number = account[3:].strip() # Remove "BG:" prefix
elif account.upper().startswith('BG '):
account_type = 'Bankgiro'
account_number = account[2:].strip() # Remove "BG" prefix
elif account.upper().startswith('PG:'):
account_type = 'Plusgiro'
account_number = account[3:].strip() # Remove "PG:" prefix
elif account.upper().startswith('PG '):
account_type = 'Plusgiro'
account_number = account[2:].strip() # Remove "PG" prefix
else:
# Try to guess from format - Plusgiro often has format XXXXXXX-X
digits = ''.join(c for c in account if c.isdigit())
if len(digits) == 8 and '-' in account:
account_type = 'Plusgiro'
elif len(digits) in (7, 8):
account_type = 'Bankgiro' # Default to Bankgiro
if not account_type:
continue
# Normalize and match using the account number (without prefix)
normalized = normalize_field('supplier_accounts', account_number)
field_matches = matcher.find_matches(tokens, account_type, normalized, page_no)
if field_matches:
best = field_matches[0]
# Add to matches under the target class (Bankgiro/Plusgiro)
if account_type not in matches:
matches[account_type] = []
matches[account_type].extend(field_matches)
matched_fields.add('supplier_accounts')
report.add_field_result(FieldMatchResult(
field_name=f'supplier_accounts({account_type})',
csv_value=account_number, # Store without prefix
matched=True,
score=best.score,
matched_text=best.matched_text,
candidate_used=best.value,
bbox=best.bbox,
page_no=page_no,
context_keywords=best.context_keywords
))
# Count annotations
annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi)
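
The BG/PG dispatch above amounts to a small classifier over one entry of a `"BG:xxx | PG:yyy"` account string. A condensed, self-contained sketch of the same rules (the function name and sample values are illustrative only):

```python
def classify_supplier_account(account: str) -> tuple[str | None, str]:
    """Return (target_class, account_number) for one entry of a 'BG:xxx | PG:yyy' string."""
    account = account.strip()
    upper = account.upper()
    if upper.startswith(('BG:', 'BG ')):
        return 'Bankgiro', account[3:].strip()
    if upper.startswith(('PG:', 'PG ')):
        return 'Plusgiro', account[3:].strip()
    # No prefix: guess from the digits (Plusgiro is often written XXXXXXX-X)
    digits = ''.join(c for c in account if c.isdigit())
    if len(digits) == 8 and '-' in account:
        return 'Plusgiro', account
    if len(digits) in (7, 8):
        return 'Bankgiro', account
    return None, account

for raw in "BG:5393-9484 | PG:48676043".split('|'):
    print(classify_supplier_account(raw))
# ('Bankgiro', '5393-9484')
# ('Plusgiro', '48676043')
```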
@@ -329,6 +408,10 @@ def main():
args = parser.parse_args()
# Register signal handlers for graceful shutdown
signal.signal(signal.SIGINT, _signal_handler)
signal.signal(signal.SIGTERM, _signal_handler)
# Import here to avoid slow startup
from ..data import CSVLoader, AutoLabelReport, FieldMatchResult
from ..data.autolabel_report import ReportWriter
@@ -364,6 +447,7 @@ def main():
from ..data.db import DocumentDB
db = DocumentDB()
db.connect()
db.create_tables() # Ensure tables exist
print("Connected to database for status checking")
# Global stats
@@ -458,6 +542,11 @@ def main():
try:
# Process CSV files one by one (streaming)
for csv_idx, csv_file in enumerate(csv_files):
# Check for shutdown request
if _shutdown_requested:
print("\nShutdown requested. Stopping after current batch...")
break
print(f"\n[{csv_idx + 1}/{len(csv_files)}] Processing: {csv_file.name}")
# Load only this CSV file
@@ -548,6 +637,13 @@ def main():
'Bankgiro': row.Bankgiro,
'Plusgiro': row.Plusgiro,
'Amount': row.Amount,
# New fields
'supplier_organisation_number': row.supplier_organisation_number,
'supplier_accounts': row.supplier_accounts,
# Metadata fields (not for matching, but for database storage)
'split': row.split,
'customer_number': row.customer_number,
'supplier_name': row.supplier_name,
}
tasks.append((
@@ -647,11 +743,19 @@ def main():
futures = {executor.submit(process_single_document, task): task[0]['DocumentId']
for task in tasks}
# Per-document timeout: 120 seconds (2 minutes)
# This prevents a single stuck document from blocking the entire batch
DOCUMENT_TIMEOUT = 120
for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"):
doc_id = futures[future]
try:
result = future.result()
result = future.result(timeout=DOCUMENT_TIMEOUT)
handle_result(result)
except TimeoutError:
handle_error(doc_id, f"Processing timeout after {DOCUMENT_TIMEOUT}s")
# Cancel the stuck future
future.cancel()
except Exception as e:
handle_error(doc_id, e)
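
For reference, a minimal self-contained sketch of bounding how long the main process waits on one future with `Future.result(timeout=...)`. The worker, durations and timeout here are placeholders. Two standard-library details worth keeping in mind: `cancel()` only prevents tasks that have not started yet (it cannot stop a worker process that is already executing), and futures yielded by `as_completed` are already finished by the time `result()` is called on them, so the per-result timeout mainly bounds the wait when a future may still be running, as in this sketch.

```python
import time
from concurrent.futures import ProcessPoolExecutor, TimeoutError

def _work(seconds: float) -> float:
    """Placeholder worker standing in for process_single_document."""
    time.sleep(seconds)
    return seconds

if __name__ == "__main__":  # required on Windows, where 'spawn' re-imports this module
    with ProcessPoolExecutor(max_workers=2) as executor:
        futures = {executor.submit(_work, s): s for s in (0.1, 0.2, 5.0)}
        for future, task_id in futures.items():
            try:
                result = future.result(timeout=1.0)  # wait at most 1 s for this future
                print(f"task {task_id}: done -> {result}")
            except TimeoutError:
                # cancel() only stops tasks that have not started yet;
                # a worker already running this task keeps going
                future.cancel()
                print(f"task {task_id}: timed out")
```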

View File

@@ -34,7 +34,13 @@ def create_tables(conn):
annotations_generated INTEGER,
processing_time_ms REAL,
timestamp TIMESTAMPTZ,
errors JSONB DEFAULT '[]'
errors JSONB DEFAULT '[]',
-- New fields for extended CSV format
split TEXT,
customer_number TEXT,
supplier_name TEXT,
supplier_organisation_number TEXT,
supplier_accounts TEXT
);
CREATE TABLE IF NOT EXISTS field_results (
@@ -56,6 +62,26 @@ def create_tables(conn):
CREATE INDEX IF NOT EXISTS idx_field_results_document_id ON field_results(document_id);
CREATE INDEX IF NOT EXISTS idx_field_results_field_name ON field_results(field_name);
CREATE INDEX IF NOT EXISTS idx_field_results_matched ON field_results(matched);
-- Add new columns to existing tables if they don't exist (for migration)
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='split') THEN
ALTER TABLE documents ADD COLUMN split TEXT;
END IF;
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='customer_number') THEN
ALTER TABLE documents ADD COLUMN customer_number TEXT;
END IF;
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_name') THEN
ALTER TABLE documents ADD COLUMN supplier_name TEXT;
END IF;
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_organisation_number') THEN
ALTER TABLE documents ADD COLUMN supplier_organisation_number TEXT;
END IF;
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_accounts') THEN
ALTER TABLE documents ADD COLUMN supplier_accounts TEXT;
END IF;
END $$;
""")
conn.commit()
@@ -82,7 +108,8 @@ def import_jsonl_file(conn, jsonl_path: Path, skip_existing: bool = True, batch_
INSERT INTO documents
(document_id, pdf_path, pdf_type, success, total_pages,
fields_matched, fields_total, annotations_generated,
processing_time_ms, timestamp, errors)
processing_time_ms, timestamp, errors,
split, customer_number, supplier_name, supplier_organisation_number, supplier_accounts)
VALUES %s
ON CONFLICT (document_id) DO UPDATE SET
pdf_path = EXCLUDED.pdf_path,
@@ -94,7 +121,12 @@ def import_jsonl_file(conn, jsonl_path: Path, skip_existing: bool = True, batch_
annotations_generated = EXCLUDED.annotations_generated,
processing_time_ms = EXCLUDED.processing_time_ms,
timestamp = EXCLUDED.timestamp,
errors = EXCLUDED.errors
errors = EXCLUDED.errors,
split = EXCLUDED.split,
customer_number = EXCLUDED.customer_number,
supplier_name = EXCLUDED.supplier_name,
supplier_organisation_number = EXCLUDED.supplier_organisation_number,
supplier_accounts = EXCLUDED.supplier_accounts
""", doc_batch)
doc_batch = []
@@ -150,7 +182,13 @@ def import_jsonl_file(conn, jsonl_path: Path, skip_existing: bool = True, batch_
record.get('annotations_generated'),
record.get('processing_time_ms'),
record.get('timestamp'),
json.dumps(record.get('errors', []))
json.dumps(record.get('errors', [])),
# New fields
record.get('split'),
record.get('customer_number'),
record.get('supplier_name'),
record.get('supplier_organisation_number'),
record.get('supplier_accounts'),
))
for field in record.get('field_results', []):
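
For context, the `VALUES %s` placeholder above is expanded by psycopg2's `execute_values` into a multi-row insert. A minimal standalone sketch of that upsert pattern (the DSN and the column subset are placeholders, not the script's actual schema usage):

```python
import json
import psycopg2
from psycopg2.extras import execute_values

doc_batch = [
    ("doc-001", "train", "Acme AB", json.dumps([])),
    ("doc-002", "test", "Nordic Supplies", json.dumps([])),
]

conn = psycopg2.connect("dbname=invoices")  # placeholder DSN
with conn, conn.cursor() as cursor:
    execute_values(cursor, """
        INSERT INTO documents (document_id, split, supplier_name, errors)
        VALUES %s
        ON CONFLICT (document_id) DO UPDATE SET
            split = EXCLUDED.split,
            supplier_name = EXCLUDED.supplier_name,
            errors = EXCLUDED.errors
    """, doc_batch)
conn.close()
```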

View File

@@ -63,6 +63,12 @@ class AutoLabelReport:
processing_time_ms: float = 0.0
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
errors: list[str] = field(default_factory=list)
# New metadata fields (from CSV, not for matching)
split: str | None = None
customer_number: str | None = None
supplier_name: str | None = None
supplier_organisation_number: str | None = None
supplier_accounts: str | None = None
def add_field_result(self, result: FieldMatchResult) -> None:
"""Add a field matching result."""
@@ -87,7 +93,13 @@ class AutoLabelReport:
'label_paths': self.label_paths,
'processing_time_ms': self.processing_time_ms,
'timestamp': self.timestamp,
'errors': self.errors
'errors': self.errors,
# New metadata fields
'split': self.split,
'customer_number': self.customer_number,
'supplier_name': self.supplier_name,
'supplier_organisation_number': self.supplier_organisation_number,
'supplier_accounts': self.supplier_accounts,
}
def to_json(self, indent: int | None = None) -> str:
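
A short sketch of how the new metadata fields end up in the JSONL report, using the constructor and `to_json()` shown here (the import path is assumed and all values are invented):

```python
import json
from src.data.autolabel_report import AutoLabelReport  # import path assumed

report = AutoLabelReport(document_id="doc-001")
report.split = "train"
report.supplier_name = "Acme AB"
report.supplier_organisation_number = "556123-4567"
report.supplier_accounts = "BG:5393-9484 | PG:48676043"

payload = json.loads(report.to_json())
print(payload["split"], payload["supplier_accounts"])
```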

View File

@@ -25,6 +25,12 @@ class InvoiceRow:
Bankgiro: str | None = None
Plusgiro: str | None = None
Amount: Decimal | None = None
# New fields
split: str | None = None # train/test split indicator
customer_number: str | None = None # Customer number (no matching needed)
supplier_name: str | None = None # Supplier name (no matching)
supplier_organisation_number: str | None = None # Swedish org number (needs matching)
supplier_accounts: str | None = None # Supplier accounts (needs matching)
# Raw values for reference
raw_data: dict = field(default_factory=dict)
@@ -40,6 +46,8 @@ class InvoiceRow:
'Bankgiro': self.Bankgiro,
'Plusgiro': self.Plusgiro,
'Amount': str(self.Amount) if self.Amount else None,
'supplier_organisation_number': self.supplier_organisation_number,
'supplier_accounts': self.supplier_accounts,
}
def get_field_value(self, field_name: str) -> str | None:
@@ -68,6 +76,12 @@ class CSVLoader:
'Bankgiro': 'Bankgiro',
'Plusgiro': 'Plusgiro',
'Amount': 'Amount',
# New fields
'split': 'split',
'customer_number': 'customer_number',
'supplier_name': 'supplier_name',
'supplier_organisation_number': 'supplier_organisation_number',
'supplier_accounts': 'supplier_accounts',
}
def __init__(
@@ -200,6 +214,12 @@ class CSVLoader:
Bankgiro=self._parse_string(row.get('Bankgiro')),
Plusgiro=self._parse_string(row.get('Plusgiro')),
Amount=self._parse_amount(row.get('Amount')),
# New fields
split=self._parse_string(row.get('split')),
customer_number=self._parse_string(row.get('customer_number')),
supplier_name=self._parse_string(row.get('supplier_name')),
supplier_organisation_number=self._parse_string(row.get('supplier_organisation_number')),
supplier_accounts=self._parse_string(row.get('supplier_accounts')),
raw_data=dict(row)
)
@@ -318,14 +338,16 @@ class CSVLoader:
row.OCR,
row.Bankgiro,
row.Plusgiro,
row.Amount
row.Amount,
row.supplier_organisation_number,
row.supplier_accounts,
]
if not any(matchable_fields):
issues.append({
'row': i,
'doc_id': row.DocumentId,
'field': 'All',
'issue': 'No matchable fields (InvoiceNumber/OCR/Bankgiro/Plusgiro/Amount)'
'issue': 'No matchable fields (InvoiceNumber/OCR/Bankgiro/Plusgiro/Amount/supplier_organisation_number/supplier_accounts)'
})
return issues
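
To make the expected input concrete, here is a hedged sketch of the extended CSV layout (header names follow the field mapping above; all values are invented):

```python
import csv
import io

sample = io.StringIO(
    "DocumentId,InvoiceNumber,Amount,split,customer_number,supplier_name,"
    "supplier_organisation_number,supplier_accounts\n"
    "doc-001,2465027205,451.00,train,C-1009,Acme AB,556123-4567,"
    "BG:5393-9484 | PG:48676043\n"
)
for row in csv.DictReader(sample):
    # split / customer_number / supplier_name are stored as metadata only;
    # the org number and accounts also take part in token matching
    print(row["split"], row["supplier_accounts"])
```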

View File

@@ -26,6 +26,73 @@ class DocumentDB:
self.conn = psycopg2.connect(self.connection_string)
return self.conn
def create_tables(self):
"""Create database tables if they don't exist."""
conn = self.connect()
with conn.cursor() as cursor:
cursor.execute("""
CREATE TABLE IF NOT EXISTS documents (
document_id TEXT PRIMARY KEY,
pdf_path TEXT,
pdf_type TEXT,
success BOOLEAN,
total_pages INTEGER,
fields_matched INTEGER,
fields_total INTEGER,
annotations_generated INTEGER,
processing_time_ms REAL,
timestamp TIMESTAMPTZ,
errors JSONB DEFAULT '[]',
-- Extended CSV format fields
split TEXT,
customer_number TEXT,
supplier_name TEXT,
supplier_organisation_number TEXT,
supplier_accounts TEXT
);
CREATE TABLE IF NOT EXISTS field_results (
id SERIAL PRIMARY KEY,
document_id TEXT NOT NULL REFERENCES documents(document_id) ON DELETE CASCADE,
field_name TEXT,
csv_value TEXT,
matched BOOLEAN,
score REAL,
matched_text TEXT,
candidate_used TEXT,
bbox JSONB,
page_no INTEGER,
context_keywords JSONB DEFAULT '[]',
error TEXT
);
CREATE INDEX IF NOT EXISTS idx_documents_success ON documents(success);
CREATE INDEX IF NOT EXISTS idx_field_results_document_id ON field_results(document_id);
CREATE INDEX IF NOT EXISTS idx_field_results_field_name ON field_results(field_name);
CREATE INDEX IF NOT EXISTS idx_field_results_matched ON field_results(matched);
-- Add new columns to existing tables if they don't exist (for migration)
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='split') THEN
ALTER TABLE documents ADD COLUMN split TEXT;
END IF;
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='customer_number') THEN
ALTER TABLE documents ADD COLUMN customer_number TEXT;
END IF;
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_name') THEN
ALTER TABLE documents ADD COLUMN supplier_name TEXT;
END IF;
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_organisation_number') THEN
ALTER TABLE documents ADD COLUMN supplier_organisation_number TEXT;
END IF;
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_accounts') THEN
ALTER TABLE documents ADD COLUMN supplier_accounts TEXT;
END IF;
END $$;
""")
conn.commit()
def close(self):
"""Close database connection."""
if self.conn:
@@ -110,7 +177,9 @@ class DocumentDB:
cursor.execute("""
SELECT document_id, pdf_path, pdf_type, success, total_pages,
fields_matched, fields_total, annotations_generated,
processing_time_ms, timestamp, errors
processing_time_ms, timestamp, errors,
split, customer_number, supplier_name,
supplier_organisation_number, supplier_accounts
FROM documents WHERE document_id = %s
""", (doc_id,))
row = cursor.fetchone()
@@ -129,6 +198,12 @@ class DocumentDB:
'processing_time_ms': row[8],
'timestamp': str(row[9]) if row[9] else None,
'errors': row[10] if isinstance(row[10], list) else json.loads(row[10] or '[]'),
# New fields
'split': row[11],
'customer_number': row[12],
'supplier_name': row[13],
'supplier_organisation_number': row[14],
'supplier_accounts': row[15],
'field_results': []
}
@@ -253,7 +328,9 @@ class DocumentDB:
cursor.execute("""
SELECT document_id, pdf_path, pdf_type, success, total_pages,
fields_matched, fields_total, annotations_generated,
processing_time_ms, timestamp, errors
processing_time_ms, timestamp, errors,
split, customer_number, supplier_name,
supplier_organisation_number, supplier_accounts
FROM documents WHERE document_id = ANY(%s)
""", (doc_ids,))
@@ -270,6 +347,12 @@ class DocumentDB:
'processing_time_ms': row[8],
'timestamp': str(row[9]) if row[9] else None,
'errors': row[10] if isinstance(row[10], list) else json.loads(row[10] or '[]'),
# New fields
'split': row[11],
'customer_number': row[12],
'supplier_name': row[13],
'supplier_organisation_number': row[14],
'supplier_accounts': row[15],
'field_results': []
}
@@ -315,8 +398,9 @@ class DocumentDB:
INSERT INTO documents
(document_id, pdf_path, pdf_type, success, total_pages,
fields_matched, fields_total, annotations_generated,
processing_time_ms, timestamp, errors)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
processing_time_ms, timestamp, errors,
split, customer_number, supplier_name, supplier_organisation_number, supplier_accounts)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""", (
doc_id,
report.get('pdf_path'),
@@ -328,7 +412,13 @@ class DocumentDB:
report.get('annotations_generated'),
report.get('processing_time_ms'),
report.get('timestamp'),
json.dumps(report.get('errors', []))
json.dumps(report.get('errors', [])),
# New fields
report.get('split'),
report.get('customer_number'),
report.get('supplier_name'),
report.get('supplier_organisation_number'),
report.get('supplier_accounts'),
))
# Batch insert field results using execute_values
@@ -387,7 +477,13 @@ class DocumentDB:
r.get('annotations_generated'),
r.get('processing_time_ms'),
r.get('timestamp'),
json.dumps(r.get('errors', []))
json.dumps(r.get('errors', [])),
# New fields
r.get('split'),
r.get('customer_number'),
r.get('supplier_name'),
r.get('supplier_organisation_number'),
r.get('supplier_accounts'),
)
for r in reports
]
@@ -395,7 +491,8 @@ class DocumentDB:
INSERT INTO documents
(document_id, pdf_path, pdf_type, success, total_pages,
fields_matched, fields_total, annotations_generated,
processing_time_ms, timestamp, errors)
processing_time_ms, timestamp, errors,
split, customer_number, supplier_name, supplier_organisation_number, supplier_accounts)
VALUES %s
""", doc_values)

View File

@@ -14,6 +14,12 @@ from functools import cached_property
_DATE_PATTERN = re.compile(r'(\d{4})-(\d{2})-(\d{2})')
_WHITESPACE_PATTERN = re.compile(r'\s+')
_NON_DIGIT_PATTERN = re.compile(r'\D')
_DASH_PATTERN = re.compile(r'[\u2013\u2014\u2212]') # en-dash, em-dash, minus sign
def _normalize_dashes(text: str) -> str:
"""Normalize different dash types to standard hyphen-minus (ASCII 45)."""
return _DASH_PATTERN.sub('-', text)
class TokenLike(Protocol):
@@ -143,6 +149,9 @@ CONTEXT_KEYWORDS = {
'Bankgiro': ['bankgiro', 'bg', 'bg-nr', 'bg nr'],
'Plusgiro': ['plusgiro', 'pg', 'pg-nr', 'pg nr'],
'Amount': ['att betala', 'summa', 'total', 'belopp', 'amount', 'totalt', 'att erlägga', 'sek', 'kr'],
'supplier_organisation_number': ['organisationsnummer', 'org.nr', 'org nr', 'orgnr', 'org.nummer',
'momsreg', 'momsnr', 'moms nr', 'vat', 'corporate id'],
'supplier_accounts': ['konto', 'kontonr', 'konto nr', 'account', 'klientnr', 'kundnr'],
}
@@ -207,7 +216,10 @@ class FieldMatcher:
# Strategy 4: Substring match (for values embedded in longer text)
# e.g., "Fakturanummer: 2465027205" should match OCR value "2465027205"
if field_name in ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount'):
# Note: Amount is excluded because short numbers like "451" can incorrectly match
# in OCR payment lines or other unrelated text
if field_name in ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro',
'supplier_organisation_number', 'supplier_accounts'):
substring_matches = self._find_substring_matches(page_tokens, value, field_name)
matches.extend(substring_matches)
@@ -237,7 +249,8 @@ class FieldMatcher:
"""Find tokens that exactly match the value."""
matches = []
value_lower = value.lower()
value_digits = _NON_DIGIT_PATTERN.sub('', value) if field_name in ('InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro') else None
value_digits = _NON_DIGIT_PATTERN.sub('', value) if field_name in ('InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro',
'supplier_organisation_number', 'supplier_accounts') else None
for token in tokens:
token_text = token.text.strip()
@@ -355,33 +368,36 @@ class FieldMatcher:
matches = []
# Supported fields for substring matching
supported_fields = ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount')
supported_fields = ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount',
'supplier_organisation_number', 'supplier_accounts')
if field_name not in supported_fields:
return matches
for token in tokens:
token_text = token.text.strip()
# Normalize different dash types to hyphen-minus for matching
token_text_normalized = _normalize_dashes(token_text)
# Skip if token is the same length as value (would be exact match)
if len(token_text) <= len(value):
if len(token_text_normalized) <= len(value):
continue
# Check if value appears as substring
if value in token_text:
# Check if value appears as substring (using normalized text)
if value in token_text_normalized:
# Verify it's a proper boundary match (not part of a larger number)
idx = token_text.find(value)
idx = token_text_normalized.find(value)
# Check character before (if exists)
if idx > 0:
char_before = token_text[idx - 1]
char_before = token_text_normalized[idx - 1]
# Must be non-digit (allow : space - etc)
if char_before.isdigit():
continue
# Check character after (if exists)
end_idx = idx + len(value)
if end_idx < len(token_text):
char_after = token_text[end_idx]
if end_idx < len(token_text_normalized):
char_after = token_text_normalized[end_idx]
# Must be non-digit
if char_after.isdigit():
continue
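
A self-contained sketch of the dash normalization plus digit-boundary check introduced above (the token text is an invented example containing an en-dash, and the helper names mirror the ones in this file):

```python
import re

_DASH_PATTERN = re.compile(r'[\u2013\u2014\u2212]')  # en-dash, em-dash, minus sign

def _normalize_dashes(text: str) -> str:
    return _DASH_PATTERN.sub('-', text)

def contains_value(token_text: str, value: str) -> bool:
    """Substring match with digit-boundary guards, on dash-normalized text."""
    text = _normalize_dashes(token_text.strip())
    idx = text.find(value)
    if idx < 0 or len(text) <= len(value):
        return False
    if idx > 0 and text[idx - 1].isdigit():
        return False  # value is glued to a preceding digit
    end = idx + len(value)
    if end < len(text) and text[end].isdigit():
        return False  # value is glued to a following digit
    return True

# '\u2013' is an en-dash, as often produced by PDF text extraction
print(contains_value("Org.nr: 556123\u20134567", "556123-4567"))  # True
```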

View File

@@ -39,9 +39,12 @@ class FieldNormalizer:
@staticmethod
def clean_text(text: str) -> str:
"""Remove invisible characters and normalize whitespace."""
"""Remove invisible characters and normalize whitespace and dashes."""
# Remove zero-width characters
text = re.sub(r'[\u200b\u200c\u200d\ufeff]', '', text)
# Normalize different dash types to standard hyphen-minus (ASCII 45)
# en-dash (–, U+2013), em-dash (—, U+2014), minus sign (−, U+2212)
text = re.sub(r'[\u2013\u2014\u2212]', '-', text)
# Normalize whitespace
text = ' '.join(text.split())
return text.strip()
@@ -130,6 +133,133 @@ class FieldNormalizer:
return list(set(v for v in variants if v))
@staticmethod
def normalize_organisation_number(value: str) -> list[str]:
"""
Normalize Swedish organisation number and generate VAT number variants.
Organisation number format: NNNNNN-NNNN (6 digits + hyphen + 4 digits)
Swedish VAT format: SE + org_number (10 digits) + 01
Examples:
'556123-4567' -> ['556123-4567', '5561234567', 'SE556123456701', ...]
'5561234567' -> ['5561234567', '556123-4567', 'SE556123456701', ...]
'SE556123456701' -> ['SE556123456701', '5561234567', '556123-4567', ...]
"""
value = FieldNormalizer.clean_text(value)
# Check if input is a VAT number (starts with SE, ends with 01)
org_digits = None
if value.upper().startswith('SE') and len(value) >= 12:
# Extract org number from VAT: SE + 10 digits + 01
potential_org = re.sub(r'\D', '', value[2:]) # Remove SE prefix, keep digits
if len(potential_org) == 12 and potential_org.endswith('01'):
org_digits = potential_org[:-2] # Remove trailing 01
elif len(potential_org) == 10:
org_digits = potential_org
if org_digits is None:
org_digits = re.sub(r'\D', '', value)
variants = [value]
if org_digits:
variants.append(org_digits)
# Standard format: NNNNNN-NNNN (10 digits total)
if len(org_digits) == 10:
with_dash = f"{org_digits[:6]}-{org_digits[6:]}"
variants.append(with_dash)
# Swedish VAT format: SE + org_number + 01
vat_number = f"SE{org_digits}01"
variants.append(vat_number)
variants.append(vat_number.lower()) # se556123456701
# With spaces: SE 5561234567 01
variants.append(f"SE {org_digits} 01")
variants.append(f"SE {org_digits[:6]}-{org_digits[6:]} 01")
# Without 01 suffix (some invoices show just SE + org)
variants.append(f"SE{org_digits}")
variants.append(f"SE {org_digits}")
# Some may have 12 digits (century prefix): NNNNNNNN-NNNN
elif len(org_digits) == 12:
with_dash = f"{org_digits[:8]}-{org_digits[8:]}"
variants.append(with_dash)
# Also try without century prefix
short_version = org_digits[2:]
variants.append(short_version)
variants.append(f"{short_version[:6]}-{short_version[6:]}")
# VAT with short version
vat_number = f"SE{short_version}01"
variants.append(vat_number)
return list(set(v for v in variants if v))
@staticmethod
def normalize_supplier_accounts(value: str) -> list[str]:
"""
Normalize supplier accounts field.
The field may contain multiple accounts separated by ' | '.
Format examples:
'PG:48676043 | PG:49128028 | PG:8915035'
'BG:5393-9484'
Each account is normalized separately to generate variants.
Examples:
'PG:48676043' -> ['PG:48676043', '48676043', '4867604-3']
'BG:5393-9484' -> ['BG:5393-9484', '5393-9484', '53939484']
"""
value = FieldNormalizer.clean_text(value)
variants = []
# Split by ' | ' to handle multiple accounts
accounts = [acc.strip() for acc in value.split('|')]
for account in accounts:
account = account.strip()
if not account:
continue
# Add original value
variants.append(account)
# Remove prefix (PG:, BG:, etc.)
if ':' in account:
prefix, number = account.split(':', 1)
number = number.strip()
variants.append(number) # Just the number without prefix
# Also add with different prefix formats
prefix_upper = prefix.strip().upper()
variants.append(f"{prefix_upper}:{number}")
variants.append(f"{prefix_upper}: {number}") # With space
else:
number = account
# Extract digits only
digits_only = re.sub(r'\D', '', number)
if digits_only:
variants.append(digits_only)
# Plusgiro format: XXXXXXX-X (7 digits + check digit)
if len(digits_only) == 8:
with_dash = f"{digits_only[:-1]}-{digits_only[-1]}"
variants.append(with_dash)
# Also try 4-4 format for bankgiro
variants.append(f"{digits_only[:4]}-{digits_only[4:]}")
elif len(digits_only) == 7:
with_dash = f"{digits_only[:-1]}-{digits_only[-1]}"
variants.append(with_dash)
elif len(digits_only) == 10:
# 6-4 format (like org number)
variants.append(f"{digits_only[:6]}-{digits_only[6:]}")
return list(set(v for v in variants if v))
@staticmethod
def normalize_amount(value: str) -> list[str]:
"""
@@ -264,40 +394,71 @@ class FieldNormalizer:
'2025-12-13' -> ['2025-12-13', '13/12/2025', '13.12.2025']
'13/12/2025' -> ['2025-12-13', '13/12/2025', ...]
'13 december 2025' -> ['2025-12-13', ...]
Note: For ambiguous formats like DD/MM/YYYY vs MM/DD/YYYY,
we generate variants for BOTH interpretations to maximize matching.
"""
value = FieldNormalizer.clean_text(value)
variants = [value]
parsed_date = None
parsed_dates = [] # May have multiple interpretations
# Try different date formats
date_patterns = [
# ISO format with optional time (e.g., 2026-01-09 00:00:00)
(r'^(\d{4})-(\d{1,2})-(\d{1,2})(?:\s+\d{1,2}:\d{2}:\d{2})?$', lambda m: (int(m[1]), int(m[2]), int(m[3]))),
# European format with /
(r'^(\d{1,2})/(\d{1,2})/(\d{4})$', lambda m: (int(m[3]), int(m[2]), int(m[1]))),
# European format with .
(r'^(\d{1,2})\.(\d{1,2})\.(\d{4})$', lambda m: (int(m[3]), int(m[2]), int(m[1]))),
# European format with -
(r'^(\d{1,2})-(\d{1,2})-(\d{4})$', lambda m: (int(m[3]), int(m[2]), int(m[1]))),
# Swedish format: YYMMDD
(r'^(\d{2})(\d{2})(\d{2})$', lambda m: (2000 + int(m[1]) if int(m[1]) < 50 else 1900 + int(m[1]), int(m[2]), int(m[3]))),
# Swedish format: YYYYMMDD
(r'^(\d{4})(\d{2})(\d{2})$', lambda m: (int(m[1]), int(m[2]), int(m[3]))),
]
# Ambiguous patterns - try both DD/MM and MM/DD interpretations
ambiguous_patterns = [
# Format with / - could be DD/MM/YYYY (European) or MM/DD/YYYY (US)
r'^(\d{1,2})/(\d{1,2})/(\d{4})$',
# Format with . - typically European DD.MM.YYYY
r'^(\d{1,2})\.(\d{1,2})\.(\d{4})$',
# Format with - (not ISO) - could be DD-MM-YYYY or MM-DD-YYYY
r'^(\d{1,2})-(\d{1,2})-(\d{4})$',
]
# Try unambiguous patterns first
for pattern, extractor in date_patterns:
match = re.match(pattern, value)
if match:
try:
year, month, day = extractor(match)
parsed_date = datetime(year, month, day)
parsed_dates.append(datetime(year, month, day))
break
except ValueError:
continue
# Try ambiguous patterns with both interpretations
if not parsed_dates:
for pattern in ambiguous_patterns:
match = re.match(pattern, value)
if match:
n1, n2, year = int(match[1]), int(match[2]), int(match[3])
# Try DD/MM/YYYY (European - day first)
try:
parsed_dates.append(datetime(year, n2, n1))
except ValueError:
pass
# Try MM/DD/YYYY (US - month first) if different and valid
if n1 != n2:
try:
parsed_dates.append(datetime(year, n1, n2))
except ValueError:
pass
if parsed_dates:
break
# Try Swedish month names
if not parsed_date:
if not parsed_dates:
for month_name, month_num in FieldNormalizer.SWEDISH_MONTHS.items():
if month_name in value.lower():
# Extract day and year
@@ -308,16 +469,28 @@ class FieldNormalizer:
if year < 100:
year = 2000 + year if year < 50 else 1900 + year
try:
parsed_date = datetime(year, int(month_num), day)
parsed_dates.append(datetime(year, int(month_num), day))
break
except ValueError:
continue
if parsed_date:
# Generate variants for all parsed date interpretations
swedish_months_full = [
'januari', 'februari', 'mars', 'april', 'maj', 'juni',
'juli', 'augusti', 'september', 'oktober', 'november', 'december'
]
swedish_months_abbrev = [
'jan', 'feb', 'mar', 'apr', 'maj', 'jun',
'jul', 'aug', 'sep', 'okt', 'nov', 'dec'
]
for parsed_date in parsed_dates:
# Generate different formats
iso = parsed_date.strftime('%Y-%m-%d')
eu_slash = parsed_date.strftime('%d/%m/%Y')
us_slash = parsed_date.strftime('%m/%d/%Y') # US format MM/DD/YYYY
eu_dot = parsed_date.strftime('%d.%m.%Y')
iso_dot = parsed_date.strftime('%Y.%m.%d') # ISO with dots (e.g., 2024.02.08)
compact = parsed_date.strftime('%Y%m%d') # YYYYMMDD
compact_short = parsed_date.strftime('%y%m%d') # YYMMDD (e.g., 260108)
@@ -329,21 +502,13 @@ class FieldNormalizer:
spaced_short = parsed_date.strftime('%y %m %d')
# Swedish month name formats (e.g., "9 januari 2026", "9 jan 2026")
swedish_months_full = [
'januari', 'februari', 'mars', 'april', 'maj', 'juni',
'juli', 'augusti', 'september', 'oktober', 'november', 'december'
]
swedish_months_abbrev = [
'jan', 'feb', 'mar', 'apr', 'maj', 'jun',
'jul', 'aug', 'sep', 'okt', 'nov', 'dec'
]
month_full = swedish_months_full[parsed_date.month - 1]
month_abbrev = swedish_months_abbrev[parsed_date.month - 1]
swedish_format_full = f"{parsed_date.day} {month_full} {parsed_date.year}"
swedish_format_abbrev = f"{parsed_date.day} {month_abbrev} {parsed_date.year}"
variants.extend([
iso, eu_slash, eu_dot, compact, compact_short,
iso, eu_slash, us_slash, eu_dot, iso_dot, compact, compact_short,
eu_dot_short, spaced_full, spaced_short,
swedish_format_full, swedish_format_abbrev
])
@@ -360,6 +525,8 @@ NORMALIZERS: dict[str, Callable[[str], list[str]]] = {
'Amount': FieldNormalizer.normalize_amount,
'InvoiceDate': FieldNormalizer.normalize_date,
'InvoiceDueDate': FieldNormalizer.normalize_date,
'supplier_organisation_number': FieldNormalizer.normalize_organisation_number,
'supplier_accounts': FieldNormalizer.normalize_supplier_accounts,
}
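
A quick usage sketch of the two new normalizers (the module path is assumed; the numbers are made up):

```python
from src.matching.normalize import FieldNormalizer, NORMALIZERS  # module path assumed

# Organisation number -> dashed / digits-only / Swedish VAT (SE...01) variants
variants = FieldNormalizer.normalize_organisation_number("556123-4567")
print("SE556123456701" in variants, "5561234567" in variants)  # True True

# Multi-account string -> per-account variants with and without the BG:/PG: prefix
print(sorted(NORMALIZERS["supplier_accounts"]("BG:5393-9484 | PG:48676043"))[:5])
```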

View File

@@ -11,6 +11,7 @@ import csv
# Field class mapping for YOLO
# Note: supplier_accounts is not a separate class - its matches are mapped to Bankgiro/Plusgiro
FIELD_CLASSES = {
'InvoiceNumber': 0,
'InvoiceDate': 1,
@@ -19,6 +20,16 @@ FIELD_CLASSES = {
'Bankgiro': 4,
'Plusgiro': 5,
'Amount': 6,
'supplier_organisation_number': 7,
}
# Fields that need matching but map to other YOLO classes
# supplier_accounts matches are classified as Bankgiro or Plusgiro based on account type
ACCOUNT_FIELD_MAPPING = {
'supplier_accounts': {
'BG': 'Bankgiro', # BG:xxx -> Bankgiro class
'PG': 'Plusgiro', # PG:xxx -> Plusgiro class
}
}
CLASS_NAMES = [
@@ -29,6 +40,7 @@ CLASS_NAMES = [
'bankgiro',
'plusgiro',
'amount',
'supplier_org_number',
]
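
A small sketch of how a supplier_accounts match is routed to a YOLO class id via the two tables above (only the dictionary entries visible in this diff are repeated; the helper name is illustrative):

```python
FIELD_CLASSES = {  # entries not shown in this hunk are omitted
    'InvoiceNumber': 0,
    'InvoiceDate': 1,
    'Bankgiro': 4,
    'Plusgiro': 5,
    'Amount': 6,
    'supplier_organisation_number': 7,
}
ACCOUNT_FIELD_MAPPING = {'supplier_accounts': {'BG': 'Bankgiro', 'PG': 'Plusgiro'}}

def class_id_for(field_name: str, account_prefix: str | None = None) -> int:
    """supplier_accounts has no class of its own; its BG/PG prefix picks Bankgiro or Plusgiro."""
    if field_name == 'supplier_accounts' and account_prefix:
        field_name = ACCOUNT_FIELD_MAPPING['supplier_accounts'][account_prefix]
    return FIELD_CLASSES[field_name]

print(class_id_for('supplier_accounts', 'BG'))        # 4 -> 'bankgiro'
print(class_id_for('supplier_organisation_number'))   # 7 -> 'supplier_org_number'
```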

View File

@@ -52,6 +52,7 @@ class DatasetItem:
page_no: int
labels: list[YOLOAnnotation]
is_scanned: bool = False # True if bbox is in pixel coords, False if in PDF points
csv_split: str | None = None # CSV-defined split ('train', 'test', etc.)
class DBYOLODataset:
@@ -202,7 +203,7 @@ class DBYOLODataset:
total_images += len(images)
continue
labels_by_page, is_scanned = doc_data
labels_by_page, is_scanned, csv_split = doc_data
for image_path in images:
total_images += 1
@@ -218,7 +219,8 @@ class DBYOLODataset:
image_path=image_path,
page_no=page_no,
labels=page_labels,
is_scanned=is_scanned
is_scanned=is_scanned,
csv_split=csv_split
))
else:
skipped_no_labels += 1
@@ -237,16 +239,17 @@ class DBYOLODataset:
self.items, self._doc_ids_ordered = self._split_dataset(all_items)
print(f"Split '{self.split}': {len(self.items)} items")
def _load_labels_from_db(self, doc_ids: list[str]) -> dict[str, tuple[dict[int, list[YOLOAnnotation]], bool]]:
def _load_labels_from_db(self, doc_ids: list[str]) -> dict[str, tuple[dict[int, list[YOLOAnnotation]], bool, str | None]]:
"""
Load labels from database for given document IDs using batch queries.
Returns:
Dict of doc_id -> (page_labels, is_scanned)
Dict of doc_id -> (page_labels, is_scanned, split)
where page_labels is {page_no -> list[YOLOAnnotation]}
and is_scanned indicates if bbox is in pixel coords (True) or PDF points (False)
is_scanned indicates if bbox is in pixel coords (True) or PDF points (False)
split is the CSV-defined split ('train', 'test', etc.) or None
"""
result: dict[str, tuple[dict[int, list[YOLOAnnotation]], bool]] = {}
result: dict[str, tuple[dict[int, list[YOLOAnnotation]], bool, str | None]] = {}
# Query in batches using efficient batch method
batch_size = 500
@@ -263,6 +266,9 @@ class DBYOLODataset:
# Check if scanned PDF (OCR bbox is in pixels, text PDF bbox is in PDF points)
is_scanned = doc.get('pdf_type') == 'scanned'
# Get CSV-defined split
csv_split = doc.get('split')
page_labels: dict[int, list[YOLOAnnotation]] = {}
for field_result in doc.get('field_results', []):
@@ -292,7 +298,7 @@ class DBYOLODataset:
page_labels[page_no].append(annotation)
if page_labels:
result[doc_id] = (page_labels, is_scanned)
result[doc_id] = (page_labels, is_scanned, csv_split)
return result
@@ -333,7 +339,10 @@ class DBYOLODataset:
def _split_dataset(self, items: list[DatasetItem]) -> tuple[list[DatasetItem], list[str]]:
"""
Split items into train/val/test based on document ID.
Split items into train/val/test based on CSV-defined split field.
If CSV has 'split' field, use it directly.
Otherwise, fall back to random splitting based on train_ratio/val_ratio.
Returns:
Tuple of (split_items, ordered_doc_ids) where ordered_doc_ids can be
@@ -341,13 +350,64 @@ class DBYOLODataset:
"""
# Group by document ID for proper splitting
doc_items: dict[str, list[DatasetItem]] = {}
doc_csv_split: dict[str, str | None] = {} # Track CSV split per document
for item in items:
if item.document_id not in doc_items:
doc_items[item.document_id] = []
doc_csv_split[item.document_id] = item.csv_split
doc_items[item.document_id].append(item)
# Shuffle document IDs
# Check if we have CSV-defined splits
has_csv_splits = any(split is not None for split in doc_csv_split.values())
doc_ids = list(doc_items.keys())
if has_csv_splits:
# Use CSV-defined splits
print("Using CSV-defined split field for train/val/test assignment")
# Map split values: 'train' -> train, 'test' -> test, None -> train (fallback)
# 'val' is taken from train set using val_ratio
split_doc_ids = []
if self.split == 'train':
# Get documents marked as 'train' or no split defined
train_docs = [doc_id for doc_id in doc_ids
if doc_csv_split[doc_id] in (None, 'train', 'Train', 'TRAIN')]
# Take train_ratio of train docs for actual training, rest for val
random.seed(self.seed)
random.shuffle(train_docs)
n_train = int(len(train_docs) * (self.train_ratio / (self.train_ratio + self.val_ratio)))
split_doc_ids = train_docs[:n_train]
elif self.split == 'val':
# Get documents marked as 'train' and take val portion
train_docs = [doc_id for doc_id in doc_ids
if doc_csv_split[doc_id] in (None, 'train', 'Train', 'TRAIN')]
random.seed(self.seed)
random.shuffle(train_docs)
n_train = int(len(train_docs) * (self.train_ratio / (self.train_ratio + self.val_ratio)))
split_doc_ids = train_docs[n_train:]
else: # test
# Get documents marked as 'test'
split_doc_ids = [doc_id for doc_id in doc_ids
if doc_csv_split[doc_id] in ('test', 'Test', 'TEST')]
# Apply limit if specified
if self.limit is not None and self.limit < len(split_doc_ids):
split_doc_ids = split_doc_ids[:self.limit]
print(f"Limited to {self.limit} documents")
else:
# Fall back to random splitting (original behavior)
print("No CSV split field found, using random splitting")
random.seed(self.seed)
random.shuffle(doc_ids)
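
The val split is carved out of the CSV 'train' documents by ratio rather than read from the CSV. A tiny sketch of that arithmetic with made-up counts:

```python
import random

train_ratio, val_ratio, seed = 0.8, 0.1, 42
csv_train_docs = [f"doc-{i:04d}" for i in range(1000)]  # documents whose CSV split is 'train'

random.seed(seed)                 # same seed -> the same carve-out on every run
random.shuffle(csv_train_docs)
n_train = int(len(csv_train_docs) * (train_ratio / (train_ratio + val_ratio)))
train_ids, val_ids = csv_train_docs[:n_train], csv_train_docs[n_train:]
print(len(train_ids), len(val_ids))  # 888 112
```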
@@ -381,23 +441,58 @@ class DBYOLODataset:
Split items using cached data from a shared dataset.
Uses pre-computed doc_ids order for consistent splits.
Respects CSV-defined splits if available.
"""
# Group by document ID
# Group by document ID and track CSV splits
doc_items: dict[str, list[DatasetItem]] = {}
doc_csv_split: dict[str, str | None] = {}
for item in self._all_items:
if item.document_id not in doc_items:
doc_items[item.document_id] = []
doc_csv_split[item.document_id] = item.csv_split
doc_items[item.document_id].append(item)
# Use cached doc_ids order
# Check if we have CSV-defined splits
has_csv_splits = any(split is not None for split in doc_csv_split.values())
doc_ids = self._doc_ids_ordered
# Calculate split indices
if has_csv_splits:
# Use CSV-defined splits
if self.split == 'train':
train_docs = [doc_id for doc_id in doc_ids
if doc_id in doc_csv_split and
doc_csv_split[doc_id] in (None, 'train', 'Train', 'TRAIN')]
random.seed(self.seed)
random.shuffle(train_docs)
n_train = int(len(train_docs) * (self.train_ratio / (self.train_ratio + self.val_ratio)))
split_doc_ids = train_docs[:n_train]
elif self.split == 'val':
train_docs = [doc_id for doc_id in doc_ids
if doc_id in doc_csv_split and
doc_csv_split[doc_id] in (None, 'train', 'Train', 'TRAIN')]
random.seed(self.seed)
random.shuffle(train_docs)
n_train = int(len(train_docs) * (self.train_ratio / (self.train_ratio + self.val_ratio)))
split_doc_ids = train_docs[n_train:]
else: # test
split_doc_ids = [doc_id for doc_id in doc_ids
if doc_id in doc_csv_split and
doc_csv_split[doc_id] in ('test', 'Test', 'TEST')]
else:
# Fall back to random splitting
n_total = len(doc_ids)
n_train = int(n_total * self.train_ratio)
n_val = int(n_total * self.val_ratio)
# Split document IDs based on split type
if self.split == 'train':
split_doc_ids = doc_ids[:n_train]
elif self.split == 'val':