WIP
This commit is contained in:
@@ -124,7 +124,7 @@ class AmountNormalizer(BaseNormalizer):
|
||||
if not match:
|
||||
continue
|
||||
amount = self._parse_amount_str(match)
|
||||
if amount is not None and amount > 0:
|
||||
if amount is not None and 0 < amount < 10_000_000:
|
||||
all_amounts.append(amount)
|
||||
|
||||
# Return the last amount found (usually the total)
|
||||
@@ -134,7 +134,7 @@ class AmountNormalizer(BaseNormalizer):
|
||||
# Fallback: try shared validator on cleaned text
|
||||
cleaned = TextCleaner.normalize_amount_text(text)
|
||||
amount = FieldValidators.parse_amount(cleaned)
|
||||
if amount is not None and amount > 0:
|
||||
if amount is not None and 0 < amount < 10_000_000:
|
||||
return NormalizationResult.success(f"{amount:.2f}")
|
||||
|
||||
# Try to find any decimal number
|
||||
@@ -144,7 +144,7 @@ class AmountNormalizer(BaseNormalizer):
|
||||
amount_str = matches[-1].replace(",", ".")
|
||||
try:
|
||||
amount = float(amount_str)
|
||||
if amount > 0:
|
||||
if 0 < amount < 10_000_000:
|
||||
return NormalizationResult.success(f"{amount:.2f}")
|
||||
except ValueError:
|
||||
pass
|
||||
@@ -156,7 +156,7 @@ class AmountNormalizer(BaseNormalizer):
|
||||
if match:
|
||||
try:
|
||||
amount = float(match.group(1))
|
||||
if amount > 0:
|
||||
if 0 < amount < 10_000_000:
|
||||
return NormalizationResult.success(f"{amount:.2f}")
|
||||
except ValueError:
|
||||
pass
|
||||
@@ -168,7 +168,7 @@ class AmountNormalizer(BaseNormalizer):
|
||||
# Take the last/largest number
|
||||
try:
|
||||
amount = float(matches[-1])
|
||||
if amount > 0:
|
||||
if 0 < amount < 10_000_000:
|
||||
return NormalizationResult.success(f"{amount:.2f}")
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
@@ -62,14 +62,25 @@ class InvoiceNumberNormalizer(BaseNormalizer):
|
||||
# Skip if it looks like a date (YYYYMMDD)
|
||||
if len(seq) == 8 and seq.startswith("20"):
|
||||
continue
|
||||
# Skip year-only values (2024, 2025, 2026, etc.)
|
||||
if len(seq) == 4 and seq.startswith("20"):
|
||||
continue
|
||||
# Skip if too long (likely OCR number)
|
||||
if len(seq) > 10:
|
||||
continue
|
||||
valid_sequences.append(seq)
|
||||
|
||||
if valid_sequences:
|
||||
# Return shortest valid sequence
|
||||
return NormalizationResult.success(min(valid_sequences, key=len))
|
||||
# Prefer 4-8 digit sequences (typical invoice numbers),
|
||||
# then closest to 6 digits within that range.
|
||||
# This avoids picking short fragments like "775" from amounts.
|
||||
def _score(seq: str) -> tuple[int, int]:
|
||||
length = len(seq)
|
||||
if 4 <= length <= 8:
|
||||
return (1, -abs(length - 6))
|
||||
return (0, -length)
|
||||
|
||||
return NormalizationResult.success(max(valid_sequences, key=_score))
|
||||
|
||||
# Fallback: extract all digits if nothing else works
|
||||
digits = re.sub(r"\D", "", text)
|
||||
|
||||
@@ -14,7 +14,7 @@ class OcrNumberNormalizer(BaseNormalizer):
|
||||
Normalizes OCR (Optical Character Recognition) reference numbers.
|
||||
|
||||
OCR numbers in Swedish payment systems:
|
||||
- Minimum 5 digits
|
||||
- Minimum 2 digits
|
||||
- Used for automated payment matching
|
||||
"""
|
||||
|
||||
@@ -29,7 +29,7 @@ class OcrNumberNormalizer(BaseNormalizer):
|
||||
|
||||
digits = re.sub(r"\D", "", text)
|
||||
|
||||
if len(digits) < 5:
|
||||
if len(digits) < 2:
|
||||
return NormalizationResult.failure(
|
||||
f"Too few digits for OCR: {len(digits)}"
|
||||
)
|
||||
|
||||
@@ -234,7 +234,7 @@ class InferencePipeline:
|
||||
confidence_threshold=confidence_threshold,
|
||||
device='cuda' if use_gpu else 'cpu'
|
||||
)
|
||||
self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu)
|
||||
self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu, dpi=dpi)
|
||||
self.payment_line_parser = PaymentLineParser()
|
||||
self.dpi = dpi
|
||||
self.enable_fallback = enable_fallback
|
||||
@@ -361,6 +361,7 @@ class InferencePipeline:
|
||||
# Fallback if key fields are missing
|
||||
if self.enable_fallback and self._needs_fallback(result):
|
||||
self._run_fallback(pdf_path, result)
|
||||
self._dedup_invoice_number(result)
|
||||
|
||||
# Extract business invoice features if enabled
|
||||
if use_business_features:
|
||||
@@ -477,9 +478,48 @@ class InferencePipeline:
|
||||
# Store bbox for each field (useful for payment_line and other fields)
|
||||
result.bboxes[field_name] = best.bbox
|
||||
|
||||
# Validate date consistency
|
||||
self._validate_dates(result)
|
||||
|
||||
# Perform cross-validation if payment_line is detected
|
||||
self._cross_validate_payment_line(result)
|
||||
|
||||
# Remove InvoiceNumber if it duplicates OCR or Bankgiro
|
||||
self._dedup_invoice_number(result)
|
||||
|
||||
def _validate_dates(self, result: InferenceResult) -> None:
|
||||
"""Remove InvoiceDueDate if it is earlier than InvoiceDate."""
|
||||
invoice_date = result.fields.get('InvoiceDate')
|
||||
due_date = result.fields.get('InvoiceDueDate')
|
||||
if invoice_date and due_date and due_date < invoice_date:
|
||||
del result.fields['InvoiceDueDate']
|
||||
result.confidence.pop('InvoiceDueDate', None)
|
||||
result.bboxes.pop('InvoiceDueDate', None)
|
||||
|
||||
def _dedup_invoice_number(self, result: InferenceResult) -> None:
|
||||
"""Remove InvoiceNumber if it duplicates OCR or Bankgiro digits."""
|
||||
inv_num = result.fields.get('InvoiceNumber')
|
||||
if not inv_num:
|
||||
return
|
||||
inv_digits = re.sub(r'\D', '', str(inv_num))
|
||||
|
||||
# Check against OCR
|
||||
ocr = result.fields.get('OCR')
|
||||
if ocr and inv_digits == re.sub(r'\D', '', str(ocr)):
|
||||
del result.fields['InvoiceNumber']
|
||||
result.confidence.pop('InvoiceNumber', None)
|
||||
result.bboxes.pop('InvoiceNumber', None)
|
||||
return
|
||||
|
||||
# Check against Bankgiro (exact or substring match)
|
||||
bg = result.fields.get('Bankgiro')
|
||||
if bg:
|
||||
bg_digits = re.sub(r'\D', '', str(bg))
|
||||
if inv_digits == bg_digits or inv_digits in bg_digits:
|
||||
del result.fields['InvoiceNumber']
|
||||
result.confidence.pop('InvoiceNumber', None)
|
||||
result.bboxes.pop('InvoiceNumber', None)
|
||||
|
||||
def _parse_machine_readable_payment_line(self, payment_line: str) -> tuple[str | None, str | None, str | None]:
|
||||
"""
|
||||
Parse machine-readable Swedish payment line format using unified PaymentLineParser.
|
||||
@@ -638,10 +678,14 @@ class InferencePipeline:
|
||||
|
||||
def _needs_fallback(self, result: InferenceResult) -> bool:
|
||||
"""Check if fallback OCR is needed."""
|
||||
# Check for key fields
|
||||
key_fields = ['Amount', 'InvoiceNumber', 'OCR']
|
||||
missing = sum(1 for f in key_fields if f not in result.fields)
|
||||
return missing >= 2 # Fallback if 2+ key fields missing
|
||||
important_fields = ['InvoiceDate', 'InvoiceDueDate', 'supplier_organisation_number']
|
||||
|
||||
key_missing = sum(1 for f in key_fields if f not in result.fields)
|
||||
important_missing = sum(1 for f in important_fields if f not in result.fields)
|
||||
|
||||
# Fallback if any key field missing OR 2+ important fields missing
|
||||
return key_missing >= 1 or important_missing >= 2
|
||||
|
||||
def _run_fallback(self, pdf_path: str | Path, result: InferenceResult) -> None:
|
||||
"""Run full-page OCR fallback."""
|
||||
@@ -673,12 +717,13 @@ class InferencePipeline:
|
||||
"""Extract fields using regex patterns (fallback)."""
|
||||
patterns = {
|
||||
'Amount': [
|
||||
r'(?:att\s*betala|summa|total|belopp)\s*[:.]?\s*([\d\s,\.]+)\s*(?:SEK|kr)?',
|
||||
r'([\d\s,\.]+)\s*(?:SEK|kr)\s*$',
|
||||
r'(?:att\s+betala)\s*[:.]?\s*([\d\s\.]*\d+[,\.]\d{2})\s*(?:SEK|kr)?',
|
||||
r'(?:summa|total|belopp)\s*[:.]?\s*([\d\s\.]*\d+[,\.]\d{2})\s*(?:SEK|kr)?',
|
||||
r'([\d\s\.]*\d+[,\.]\d{2})\s*(?:SEK|kr)\s*$',
|
||||
],
|
||||
'Bankgiro': [
|
||||
r'(?:bankgiro|bg)\s*[:.]?\s*(\d{3,4}[-\s]?\d{4})',
|
||||
r'(\d{4}[-\s]\d{4})\s*(?=\s|$)',
|
||||
r'(?<!\d)(\d{3,4}[-\s]\d{4})(?!\d)',
|
||||
],
|
||||
'OCR': [
|
||||
r'(?:ocr|referens)\s*[:.]?\s*(\d{10,25})',
|
||||
@@ -686,6 +731,20 @@ class InferencePipeline:
|
||||
'InvoiceNumber': [
|
||||
r'(?:fakturanr|fakturanummer|invoice)\s*[:.]?\s*(\d+)',
|
||||
],
|
||||
'InvoiceDate': [
|
||||
r'(?:fakturadatum|invoice\s*date)\s*[:.]?\s*(\d{4}[-/]\d{2}[-/]\d{2})',
|
||||
r'(?:fakturadatum|invoice\s*date)\s*[:.]?\s*(\d{2}[-/]\d{2}[-/]\d{4})',
|
||||
],
|
||||
'InvoiceDueDate': [
|
||||
r'(?:f[oö]rfallo(?:dag|datum)?|due\s*date|betala\s*senast)\s*[:.]?\s*(\d{4}[-/]\d{2}[-/]\d{2})',
|
||||
r'(?:f[oö]rfallo(?:dag|datum)?|due\s*date|betala\s*senast)\s*[:.]?\s*(\d{2}[-/]\d{2}[-/]\d{4})',
|
||||
],
|
||||
'supplier_organisation_number': [
|
||||
r'(?:org\.?\s*n[ru]|organisationsnummer)\s*[:.]?\s*(\d{6}[-\s]?\d{4})',
|
||||
],
|
||||
'Plusgiro': [
|
||||
r'(?:plusgiro|pg)\s*[:.]?\s*(\d[\d\s-]{4,12}\d)',
|
||||
],
|
||||
}
|
||||
|
||||
for field_name, field_patterns in patterns.items():
|
||||
@@ -708,6 +767,22 @@ class InferencePipeline:
|
||||
digits = re.sub(r'\D', '', value)
|
||||
if len(digits) == 8:
|
||||
value = f"{digits[:4]}-{digits[4:]}"
|
||||
elif field_name in ('InvoiceDate', 'InvoiceDueDate'):
|
||||
# Normalize DD/MM/YYYY to YYYY-MM-DD
|
||||
date_match = re.match(r'(\d{2})[-/](\d{2})[-/](\d{4})', value)
|
||||
if date_match:
|
||||
value = f"{date_match.group(3)}-{date_match.group(2)}-{date_match.group(1)}"
|
||||
# Replace / with -
|
||||
value = value.replace('/', '-')
|
||||
elif field_name == 'InvoiceNumber':
|
||||
# Skip year-like values (2024, 2025, 2026, etc.)
|
||||
if re.match(r'^20\d{2}$', value):
|
||||
continue
|
||||
elif field_name == 'supplier_organisation_number':
|
||||
# Ensure NNNNNN-NNNN format
|
||||
digits = re.sub(r'\D', '', value)
|
||||
if len(digits) == 10:
|
||||
value = f"{digits[:6]}-{digits[6:]}"
|
||||
|
||||
result.fields[field_name] = value
|
||||
result.confidence[field_name] = 0.5 # Lower confidence for regex
|
||||
|
||||
@@ -123,12 +123,12 @@ class ValueSelector:
|
||||
|
||||
@staticmethod
|
||||
def _select_ocr_number(tokens: list[OCRToken]) -> list[OCRToken]:
|
||||
"""Select token with the longest digit sequence (min 5 digits)."""
|
||||
"""Select token with the longest digit sequence (min 2 digits)."""
|
||||
best: OCRToken | None = None
|
||||
best_count = 0
|
||||
for token in tokens:
|
||||
digit_count = _count_digits(token.text)
|
||||
if digit_count >= 5 and digit_count > best_count:
|
||||
if digit_count >= 2 and digit_count > best_count:
|
||||
best = token
|
||||
best_count = digit_count
|
||||
return [best] if best else []
|
||||
|
||||
Reference in New Issue
Block a user