This commit is contained in:
Yaojia Wang
2026-02-12 23:06:00 +01:00
parent ad5ed46b4c
commit 58d36c8927
26 changed files with 3903 additions and 2551 deletions

View File

@@ -124,7 +124,7 @@ class AmountNormalizer(BaseNormalizer):
if not match:
continue
amount = self._parse_amount_str(match)
if amount is not None and amount > 0:
if amount is not None and 0 < amount < 10_000_000:
all_amounts.append(amount)
# Return the last amount found (usually the total)
@@ -134,7 +134,7 @@ class AmountNormalizer(BaseNormalizer):
# Fallback: try shared validator on cleaned text
cleaned = TextCleaner.normalize_amount_text(text)
amount = FieldValidators.parse_amount(cleaned)
if amount is not None and amount > 0:
if amount is not None and 0 < amount < 10_000_000:
return NormalizationResult.success(f"{amount:.2f}")
# Try to find any decimal number
@@ -144,7 +144,7 @@ class AmountNormalizer(BaseNormalizer):
amount_str = matches[-1].replace(",", ".")
try:
amount = float(amount_str)
if amount > 0:
if 0 < amount < 10_000_000:
return NormalizationResult.success(f"{amount:.2f}")
except ValueError:
pass
@@ -156,7 +156,7 @@ class AmountNormalizer(BaseNormalizer):
if match:
try:
amount = float(match.group(1))
if amount > 0:
if 0 < amount < 10_000_000:
return NormalizationResult.success(f"{amount:.2f}")
except ValueError:
pass
@@ -168,7 +168,7 @@ class AmountNormalizer(BaseNormalizer):
# Take the last/largest number
try:
amount = float(matches[-1])
if amount > 0:
if 0 < amount < 10_000_000:
return NormalizationResult.success(f"{amount:.2f}")
except ValueError:
pass

View File

@@ -62,14 +62,25 @@ class InvoiceNumberNormalizer(BaseNormalizer):
# Skip if it looks like a date (YYYYMMDD)
if len(seq) == 8 and seq.startswith("20"):
continue
# Skip year-only values (2024, 2025, 2026, etc.)
if len(seq) == 4 and seq.startswith("20"):
continue
# Skip if too long (likely OCR number)
if len(seq) > 10:
continue
valid_sequences.append(seq)
if valid_sequences:
# Return shortest valid sequence
return NormalizationResult.success(min(valid_sequences, key=len))
# Prefer 4-8 digit sequences (typical invoice numbers),
# then closest to 6 digits within that range.
# This avoids picking short fragments like "775" from amounts.
def _score(seq: str) -> tuple[int, int]:
    """Rank a candidate sequence for selection with ``max``.

    Sequences whose length is between 4 and 8 (inclusive) rank above all
    others and, among themselves, by how close their length is to 6.
    Any other length ranks lower, with shorter sequences ahead of longer
    ones.
    """
    n = len(seq)
    in_preferred_range = 4 <= n <= 8
    return (1, -abs(n - 6)) if in_preferred_range else (0, -n)
return NormalizationResult.success(max(valid_sequences, key=_score))
# Fallback: extract all digits if nothing else works
digits = re.sub(r"\D", "", text)

View File

@@ -14,7 +14,7 @@ class OcrNumberNormalizer(BaseNormalizer):
Normalizes OCR (Optical Character Recognition) reference numbers.
OCR numbers in Swedish payment systems:
- Minimum 5 digits
- Minimum 2 digits
- Used for automated payment matching
"""
@@ -29,7 +29,7 @@ class OcrNumberNormalizer(BaseNormalizer):
digits = re.sub(r"\D", "", text)
if len(digits) < 5:
if len(digits) < 2:
return NormalizationResult.failure(
f"Too few digits for OCR: {len(digits)}"
)

View File

@@ -234,7 +234,7 @@ class InferencePipeline:
confidence_threshold=confidence_threshold,
device='cuda' if use_gpu else 'cpu'
)
self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu)
self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu, dpi=dpi)
self.payment_line_parser = PaymentLineParser()
self.dpi = dpi
self.enable_fallback = enable_fallback
@@ -361,6 +361,7 @@ class InferencePipeline:
# Fallback if key fields are missing
if self.enable_fallback and self._needs_fallback(result):
self._run_fallback(pdf_path, result)
self._dedup_invoice_number(result)
# Extract business invoice features if enabled
if use_business_features:
@@ -477,9 +478,48 @@ class InferencePipeline:
# Store bbox for each field (useful for payment_line and other fields)
result.bboxes[field_name] = best.bbox
# Validate date consistency
self._validate_dates(result)
# Perform cross-validation if payment_line is detected
self._cross_validate_payment_line(result)
# Remove InvoiceNumber if it duplicates OCR or Bankgiro
self._dedup_invoice_number(result)
def _validate_dates(self, result: InferenceResult) -> None:
    """Drop ``InvoiceDueDate`` when it compares less than ``InvoiceDate``.

    Both values are read from ``result.fields`` and compared with ``<``
    exactly as stored. Nothing happens when either value is missing or
    empty. On removal, the matching entries in ``result.confidence`` and
    ``result.bboxes`` are discarded as well.
    """
    fields = result.fields
    issued = fields.get('InvoiceDate')
    due = fields.get('InvoiceDueDate')
    if not issued or not due:
        return
    if due < issued:
        # Clear the field together with its per-field metadata.
        for store in (fields, result.confidence, result.bboxes):
            store.pop('InvoiceDueDate', None)
def _dedup_invoice_number(self, result: InferenceResult) -> None:
    """Remove InvoiceNumber if its digits duplicate the OCR or Bankgiro digits.

    Only the digit characters of each value are compared. InvoiceNumber is
    removed from ``result.fields`` (along with its entries in
    ``result.confidence`` and ``result.bboxes``) when its digits are equal
    to the OCR digits, or are equal to / contained in the Bankgiro digits.

    Nothing is removed when InvoiceNumber is missing, empty, or contains no
    digits at all.
    """
    inv_num = result.fields.get('InvoiceNumber')
    if not inv_num:
        return

    inv_digits = re.sub(r'\D', '', str(inv_num))
    if not inv_digits:
        # The empty string is a substring of every string, so without this
        # guard a digit-less invoice number would always "match" Bankgiro.
        return

    ocr = result.fields.get('OCR')
    bg = result.fields.get('Bankgiro')

    # OCR must match exactly; Bankgiro matches on equality or containment
    # (equality is already covered by the substring test).
    duplicates_ocr = bool(ocr) and inv_digits == re.sub(r'\D', '', str(ocr))
    duplicates_bg = bool(bg) and inv_digits in re.sub(r'\D', '', str(bg))

    if duplicates_ocr or duplicates_bg:
        for store in (result.fields, result.confidence, result.bboxes):
            store.pop('InvoiceNumber', None)
def _parse_machine_readable_payment_line(self, payment_line: str) -> tuple[str | None, str | None, str | None]:
"""
Parse machine-readable Swedish payment line format using unified PaymentLineParser.
@@ -638,10 +678,14 @@ class InferencePipeline:
def _needs_fallback(self, result: InferenceResult) -> bool:
"""Check if fallback OCR is needed."""
# Check for key fields
key_fields = ['Amount', 'InvoiceNumber', 'OCR']
missing = sum(1 for f in key_fields if f not in result.fields)
return missing >= 2 # Fallback if 2+ key fields missing
important_fields = ['InvoiceDate', 'InvoiceDueDate', 'supplier_organisation_number']
key_missing = sum(1 for f in key_fields if f not in result.fields)
important_missing = sum(1 for f in important_fields if f not in result.fields)
# Fallback if any key field missing OR 2+ important fields missing
return key_missing >= 1 or important_missing >= 2
def _run_fallback(self, pdf_path: str | Path, result: InferenceResult) -> None:
"""Run full-page OCR fallback."""
@@ -673,12 +717,13 @@ class InferencePipeline:
"""Extract fields using regex patterns (fallback)."""
patterns = {
'Amount': [
r'(?:att\s*betala|summa|total|belopp)\s*[:.]?\s*([\d\s,\.]+)\s*(?:SEK|kr)?',
r'([\d\s,\.]+)\s*(?:SEK|kr)\s*$',
r'(?:att\s+betala)\s*[:.]?\s*([\d\s\.]*\d+[,\.]\d{2})\s*(?:SEK|kr)?',
r'(?:summa|total|belopp)\s*[:.]?\s*([\d\s\.]*\d+[,\.]\d{2})\s*(?:SEK|kr)?',
r'([\d\s\.]*\d+[,\.]\d{2})\s*(?:SEK|kr)\s*$',
],
'Bankgiro': [
r'(?:bankgiro|bg)\s*[:.]?\s*(\d{3,4}[-\s]?\d{4})',
r'(\d{4}[-\s]\d{4})\s*(?=\s|$)',
r'(?<!\d)(\d{3,4}[-\s]\d{4})(?!\d)',
],
'OCR': [
r'(?:ocr|referens)\s*[:.]?\s*(\d{10,25})',
@@ -686,6 +731,20 @@ class InferencePipeline:
'InvoiceNumber': [
r'(?:fakturanr|fakturanummer|invoice)\s*[:.]?\s*(\d+)',
],
'InvoiceDate': [
r'(?:fakturadatum|invoice\s*date)\s*[:.]?\s*(\d{4}[-/]\d{2}[-/]\d{2})',
r'(?:fakturadatum|invoice\s*date)\s*[:.]?\s*(\d{2}[-/]\d{2}[-/]\d{4})',
],
'InvoiceDueDate': [
r'(?:f[oö]rfallo(?:dag|datum)?|due\s*date|betala\s*senast)\s*[:.]?\s*(\d{4}[-/]\d{2}[-/]\d{2})',
r'(?:f[oö]rfallo(?:dag|datum)?|due\s*date|betala\s*senast)\s*[:.]?\s*(\d{2}[-/]\d{2}[-/]\d{4})',
],
'supplier_organisation_number': [
r'(?:org\.?\s*n[ru]|organisationsnummer)\s*[:.]?\s*(\d{6}[-\s]?\d{4})',
],
'Plusgiro': [
r'(?:plusgiro|pg)\s*[:.]?\s*(\d[\d\s-]{4,12}\d)',
],
}
for field_name, field_patterns in patterns.items():
@@ -708,6 +767,22 @@ class InferencePipeline:
digits = re.sub(r'\D', '', value)
if len(digits) == 8:
value = f"{digits[:4]}-{digits[4:]}"
elif field_name in ('InvoiceDate', 'InvoiceDueDate'):
# Normalize DD/MM/YYYY to YYYY-MM-DD
date_match = re.match(r'(\d{2})[-/](\d{2})[-/](\d{4})', value)
if date_match:
value = f"{date_match.group(3)}-{date_match.group(2)}-{date_match.group(1)}"
# Replace / with -
value = value.replace('/', '-')
elif field_name == 'InvoiceNumber':
# Skip year-like values (2024, 2025, 2026, etc.)
if re.match(r'^20\d{2}$', value):
continue
elif field_name == 'supplier_organisation_number':
# Ensure NNNNNN-NNNN format
digits = re.sub(r'\D', '', value)
if len(digits) == 10:
value = f"{digits[:6]}-{digits[6:]}"
result.fields[field_name] = value
result.confidence[field_name] = 0.5 # Lower confidence for regex

View File

@@ -123,12 +123,12 @@ class ValueSelector:
@staticmethod
def _select_ocr_number(tokens: list[OCRToken]) -> list[OCRToken]:
"""Select token with the longest digit sequence (min 5 digits)."""
"""Select token with the longest digit sequence (min 2 digits)."""
best: OCRToken | None = None
best_count = 0
for token in tokens:
digit_count = _count_digits(token.text)
if digit_count >= 5 and digit_count > best_count:
if digit_count >= 2 and digit_count > best_count:
best = token
best_count = digit_count
return [best] if best else []