"""
|
|
Field Extractor Module
|
|
|
|
Extracts and validates field values from detected regions.
|
|
"""
|
|
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Any
|
|
import re
|
|
import numpy as np
|
|
from PIL import Image
|
|
|
|
from .yolo_detector import Detection, CLASS_TO_FIELD
|
|
|
|
|
|


@dataclass
class ExtractedField:
    """Represents an extracted field value."""
    field_name: str
    raw_text: str
    normalized_value: str | None
    confidence: float
    detection_confidence: float
    ocr_confidence: float
    bbox: tuple[float, float, float, float]
    page_no: int
    is_valid: bool = True
    validation_error: str | None = None

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return {
            'field_name': self.field_name,
            'value': self.normalized_value,
            'raw_text': self.raw_text,
            'confidence': self.confidence,
            'bbox': list(self.bbox),
            'page_no': self.page_no,
            'is_valid': self.is_valid,
            'validation_error': self.validation_error
        }


class FieldExtractor:
    """Extracts field values from detected regions using OCR or PDF text."""

    def __init__(
        self,
        ocr_lang: str = 'en',
        use_gpu: bool = False,
        bbox_padding: float = 0.1,
        dpi: int = 300
    ):
        """
        Initialize field extractor.

        Args:
            ocr_lang: Language for OCR
            use_gpu: Whether to use GPU for OCR
            bbox_padding: Padding to add around bboxes (as fraction)
            dpi: DPI used for rendering (for coordinate conversion)
        """
        self.ocr_lang = ocr_lang
        self.use_gpu = use_gpu
        self.bbox_padding = bbox_padding
        self.dpi = dpi
        self._ocr_engine = None  # Lazy init

    @property
    def ocr_engine(self):
        """Lazy-load OCR engine only when needed."""
        if self._ocr_engine is None:
            from ..ocr import OCREngine
            self._ocr_engine = OCREngine(lang=self.ocr_lang)
        return self._ocr_engine

    def extract_from_detection_with_pdf(
        self,
        detection: Detection,
        pdf_tokens: list,
        image_width: int,
        image_height: int
    ) -> ExtractedField:
        """
        Extract field value using PDF text tokens (faster and more accurate for text PDFs).

        Args:
            detection: Detection object with bbox in pixel coordinates
            pdf_tokens: List of Token objects from PDF text extraction
            image_width: Width of rendered image in pixels
            image_height: Height of rendered image in pixels

        Returns:
            ExtractedField object
        """
        # Convert detection bbox from pixels to PDF points
        scale = 72 / self.dpi  # points per pixel
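        # Illustrative note (added, not from the original source): with the default
        # dpi=300, scale = 72 / 300 = 0.24, so a detection x-coordinate of 1250 px
        # maps to 1250 * 0.24 = 300 PDF points.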
        x0_pdf = detection.bbox[0] * scale
        y0_pdf = detection.bbox[1] * scale
        x1_pdf = detection.bbox[2] * scale
        y1_pdf = detection.bbox[3] * scale

        # Add padding in points
        pad = 3  # Small padding in points

        # Find tokens that overlap with detection bbox
        matching_tokens = []
        for token in pdf_tokens:
            if token.page_no != detection.page_no:
                continue
            tx0, ty0, tx1, ty1 = token.bbox
            # Check overlap
            if (tx0 < x1_pdf + pad and tx1 > x0_pdf - pad and
                    ty0 < y1_pdf + pad and ty1 > y0_pdf - pad):
                # Calculate overlap ratio to prioritize better matches
                overlap_x = min(tx1, x1_pdf) - max(tx0, x0_pdf)
                overlap_y = min(ty1, y1_pdf) - max(ty0, y0_pdf)
                if overlap_x > 0 and overlap_y > 0:
                    token_area = (tx1 - tx0) * (ty1 - ty0)
                    overlap_area = overlap_x * overlap_y
                    overlap_ratio = overlap_area / token_area if token_area > 0 else 0
                    matching_tokens.append((token, overlap_ratio))

        # Sort by overlap ratio and combine text
        matching_tokens.sort(key=lambda x: -x[1])
        raw_text = ' '.join(t[0].text for t in matching_tokens)

        # Get field name
        field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name)

        # Normalize and validate
        normalized_value, is_valid, validation_error = self._normalize_and_validate(
            field_name, raw_text
        )

        return ExtractedField(
            field_name=field_name,
            raw_text=raw_text,
            normalized_value=normalized_value,
            confidence=detection.confidence if normalized_value else detection.confidence * 0.5,
            detection_confidence=detection.confidence,
            ocr_confidence=1.0,  # Embedded PDF text is treated as fully reliable
            bbox=detection.bbox,
            page_no=detection.page_no,
            is_valid=is_valid,
            validation_error=validation_error
        )

    def extract_from_detection(
        self,
        detection: Detection,
        image: np.ndarray | Image.Image
    ) -> ExtractedField:
        """
        Extract field value from a detection region using OCR.

        Args:
            detection: Detection object
            image: Full page image

        Returns:
            ExtractedField object
        """
        if isinstance(image, Image.Image):
            image = np.array(image)

        # Get padded bbox
        h, w = image.shape[:2]
        bbox = detection.get_padded_bbox(self.bbox_padding, w, h)

        # Crop region
        x0, y0, x1, y1 = [int(v) for v in bbox]
        region = image[y0:y1, x0:x1]

        # Run OCR on region
        ocr_tokens = self.ocr_engine.extract_from_image(region)

        # Combine all OCR text
        raw_text = ' '.join(t.text for t in ocr_tokens)
        ocr_confidence = sum(t.confidence for t in ocr_tokens) / len(ocr_tokens) if ocr_tokens else 0.0

        # Get field name
        field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name)

        # Normalize and validate
        normalized_value, is_valid, validation_error = self._normalize_and_validate(
            field_name, raw_text
        )

        # Combined confidence
        confidence = (detection.confidence + ocr_confidence) / 2 if ocr_tokens else detection.confidence * 0.5

        return ExtractedField(
            field_name=field_name,
            raw_text=raw_text,
            normalized_value=normalized_value,
            confidence=confidence,
            detection_confidence=detection.confidence,
            ocr_confidence=ocr_confidence,
            bbox=detection.bbox,
            page_no=detection.page_no,
            is_valid=is_valid,
            validation_error=validation_error
        )

    def _normalize_and_validate(
        self,
        field_name: str,
        raw_text: str
    ) -> tuple[str | None, bool, str | None]:
        """
        Normalize and validate extracted text for a field.

        Returns:
            (normalized_value, is_valid, validation_error)
        """
        text = raw_text.strip()

        if not text:
            return None, False, "Empty text"

        if field_name == 'InvoiceNumber':
            return self._normalize_invoice_number(text)

        elif field_name == 'OCR':
            return self._normalize_ocr_number(text)

        elif field_name == 'Bankgiro':
            return self._normalize_bankgiro(text)

        elif field_name == 'Plusgiro':
            return self._normalize_plusgiro(text)

        elif field_name == 'Amount':
            return self._normalize_amount(text)

        elif field_name in ('InvoiceDate', 'InvoiceDueDate'):
            return self._normalize_date(text)

        elif field_name == 'payment_line':
            return self._normalize_payment_line(text)

        elif field_name == 'supplier_org_number':
            return self._normalize_supplier_org_number(text)

        elif field_name == 'customer_number':
            return self._normalize_customer_number(text)

        else:
            return text, True, None

    def _normalize_invoice_number(self, text: str) -> tuple[str | None, bool, str | None]:
        """
        Normalize invoice number.

        Invoice numbers can be:
        - Pure digits: 12345678
        - Alphanumeric: A3861, INV-2024-001, F12345
        - With separators: 2024/001, 2024-001

        Strategy:
        1. Look for common invoice number patterns
        2. Prefer shorter, more specific matches over long digit sequences
        """
        # Pattern 1: Alphanumeric invoice number (letter + digits or digits + letter)
        # Examples: A3861, F12345, INV001
        alpha_patterns = [
            r'\b([A-Z]{1,3}\d{3,10})\b',        # A3861, INV12345
            r'\b(\d{3,10}[A-Z]{1,3})\b',        # 12345A
            r'\b([A-Z]{2,5}[-/]?\d{3,10})\b',   # INV-12345, FAK12345
        ]

        for pattern in alpha_patterns:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                return match.group(1).upper(), True, None

        # Pattern 2: Invoice number with year prefix (2024-001, 2024/12345)
        year_pattern = r'\b(20\d{2}[-/]\d{3,8})\b'
        match = re.search(year_pattern, text)
        if match:
            return match.group(1), True, None

        # Pattern 3: Short digit sequence (3-10 digits) - prefer shorter sequences
        # This avoids capturing long OCR numbers
        digit_sequences = re.findall(r'\b(\d{3,10})\b', text)
        if digit_sequences:
            # Prefer shorter sequences (more likely to be invoice number)
            # Also filter out sequences that look like dates (8 digits starting with 20)
            valid_sequences = []
            for seq in digit_sequences:
                # Skip if it looks like a date (YYYYMMDD)
                if len(seq) == 8 and seq.startswith('20'):
                    continue
                # Skip if too long (likely OCR number)
                if len(seq) > 10:
                    continue
                valid_sequences.append(seq)

            if valid_sequences:
                # Return shortest valid sequence
                return min(valid_sequences, key=len), True, None

        # Fallback: extract all digits if nothing else works
        digits = re.sub(r'\D', '', text)
        if len(digits) >= 3:
            # Limit to first 15 digits to avoid very long sequences
            return digits[:15], True, "Fallback extraction"

        return None, False, f"Cannot extract invoice number from: {text[:50]}"

    def _normalize_ocr_number(self, text: str) -> tuple[str | None, bool, str | None]:
        """Normalize OCR reference number (the Swedish payment reference)."""
        digits = re.sub(r'\D', '', text)

        if len(digits) < 5:
            return None, False, f"Too few digits for OCR: {len(digits)}"

        return digits, True, None

    def _luhn_checksum(self, digits: str) -> bool:
        """
        Validate using Luhn (Mod10) algorithm.
        Used for Bankgiro, Plusgiro, and OCR number validation.

        The checksum is valid if the total modulo 10 equals 0.
        """
        if not digits.isdigit():
            return False

        total = 0
        for i, char in enumerate(reversed(digits)):
            digit = int(char)
            if i % 2 == 1:  # Double every second digit from right
                digit *= 2
                if digit > 9:
                    digit -= 9
            total += digit

        return total % 10 == 0
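
    # Worked example (added for clarity): for the digit string "7821713", reading
    # from the right and doubling every second digit gives the per-digit terms
    # 3 + 2 + 7 + 2 + 2 + 7 + 7 = 30, and 30 % 10 == 0, so
    # _luhn_checksum("7821713") returns True.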

    def _detect_giro_type(self, text: str) -> str | None:
        r"""
        Detect if text matches the BG or PG display format pattern.

        BG typical format: ^\d{3,4}-\d{4}$ (e.g., 123-4567, 1234-5678)
        PG typical format: ^\d{1,7}-\d$ (e.g., 1-8, 12345-6, 1234567-8)

        Returns: 'BG', 'PG', or None if the type cannot be determined.
        """
        text = text.strip()

        # BG pattern: 3-4 digits, dash, 4 digits (total 7-8 digits)
        if re.match(r'^\d{3,4}-\d{4}$', text):
            return 'BG'

        # PG pattern: 1-7 digits, dash, 1 digit (total 2-8 digits)
        if re.match(r'^\d{1,7}-\d$', text):
            return 'PG'

        return None
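
    # Illustrative examples (added, not from the original source):
    #   _detect_giro_type("123-4567")  -> 'BG'  (3-4 digits, dash, 4 digits)
    #   _detect_giro_type("1234567-8") -> 'PG'  (1-7 digits, dash, single digit)
    #   _detect_giro_type("12-345")    -> None  (matches neither display format)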

    def _normalize_bankgiro(self, text: str) -> tuple[str | None, bool, str | None]:
        r"""
        Normalize Bankgiro number.

        Bankgiro rules:
        - 7 or 8 digits only
        - Last digit is a Luhn (Mod10) check digit
        - Display format: XXX-XXXX (7 digits) or XXXX-XXXX (8 digits)

        Display pattern: ^\d{3,4}-\d{4}$
        Normalized pattern: ^\d{7,8}$

        Note: Text may contain both BG and PG numbers. We specifically look for
        the BG display format (XXX-XXXX or XXXX-XXXX) to extract the correct one.
        """
        # Look for the BG display format pattern: 3-4 digits, dash, 4 digits.
        # This distinguishes BG from PG, which uses the X-X format (digits, dash, single digit).
        bg_matches = re.findall(r'(\d{3,4})-(\d{4})', text)

        if bg_matches:
            # Try each match and find one with a valid Luhn checksum
            for match in bg_matches:
                digits = match[0] + match[1]
                if len(digits) in (7, 8) and self._luhn_checksum(digits):
                    # Valid BG found
                    if len(digits) == 8:
                        formatted = f"{digits[:4]}-{digits[4:]}"
                    else:
                        formatted = f"{digits[:3]}-{digits[3:]}"
                    return formatted, True, None

            # No valid Luhn checksum; use the first match
            digits = bg_matches[0][0] + bg_matches[0][1]
            if len(digits) in (7, 8):
                if len(digits) == 8:
                    formatted = f"{digits[:4]}-{digits[4:]}"
                else:
                    formatted = f"{digits[:3]}-{digits[3:]}"
                return formatted, True, "Luhn checksum failed (possible OCR error)"

        # Fallback: try to find 7-8 consecutive digits.
        # But first check whether the text contains the PG format (XXXXXXX-X); if so,
        # skip the fallback to avoid misinterpreting a PG number as BG.
        pg_format_present = re.search(r'(?<![0-9])\d{1,7}-\d(?!\d)', text)
        if pg_format_present:
            return None, False, "No valid Bankgiro found in text"

        digit_match = re.search(r'\b(\d{7,8})\b', text)
        if digit_match:
            digits = digit_match.group(1)
            if len(digits) in (7, 8):
                luhn_ok = self._luhn_checksum(digits)
                if len(digits) == 8:
                    formatted = f"{digits[:4]}-{digits[4:]}"
                else:
                    formatted = f"{digits[:3]}-{digits[3:]}"
                if luhn_ok:
                    return formatted, True, None
                else:
                    return formatted, True, "Luhn checksum failed (possible OCR error)"

        return None, False, "No valid Bankgiro found in text"

    def _normalize_plusgiro(self, text: str) -> tuple[str | None, bool, str | None]:
        r"""
        Normalize Plusgiro number.

        Plusgiro rules:
        - 2 to 8 digits
        - Last digit is a Luhn (Mod10) check digit
        - Display format: XXXXXXX-X (all digits except the last, dash, last digit)

        Display pattern: ^\d{1,7}-\d$
        Normalized pattern: ^\d{2,8}$

        Note: Text may contain both BG and PG numbers. We specifically look for
        the PG display format (X-X, XX-X, ..., XXXXXXX-X) to extract the correct one.
        """
        # First look for the PG display format: 1-7 digits (possibly with spaces), dash, 1 digit.
        # This is distinct from the BG format, which has 4 digits after the dash.
        # The pattern allows spaces within the number, e.g. "486 98 63-6".
        # (?<![0-9]) ensures we don't start from within another number (such as a BG).
        pg_matches = re.findall(r'(?<![0-9])(\d[\d\s]{0,10})-(\d)(?!\d)', text)

        if pg_matches:
            # Try each match and find one with a valid Luhn checksum
            for match in pg_matches:
                # Remove spaces from the first part
                digits = re.sub(r'\s', '', match[0]) + match[1]
                if 2 <= len(digits) <= 8 and self._luhn_checksum(digits):
                    # Valid PG found
                    formatted = f"{digits[:-1]}-{digits[-1]}"
                    return formatted, True, None

            # No valid Luhn checksum; use the first match with the most digits
            best_match = max(pg_matches, key=lambda m: len(re.sub(r'\s', '', m[0])))
            digits = re.sub(r'\s', '', best_match[0]) + best_match[1]
            if 2 <= len(digits) <= 8:
                formatted = f"{digits[:-1]}-{digits[-1]}"
                return formatted, True, "Luhn checksum failed (possible OCR error)"

        # If no PG format was found, extract all digits and format them as PG.
        # This handles cases where the number might be in BG format or raw digits.
        all_digits = re.sub(r'\D', '', text)

        # Use the full digit sequence if it is 2-8 digits long
        if 2 <= len(all_digits) <= 8:
            luhn_ok = self._luhn_checksum(all_digits)
            formatted = f"{all_digits[:-1]}-{all_digits[-1]}"
            if luhn_ok:
                return formatted, True, None
            else:
                return formatted, True, "Luhn checksum failed (possible OCR error)"

        # Otherwise, try to find any 2-8 digit sequence in the text
        digit_match = re.search(r'\b(\d{2,8})\b', text)
        if digit_match:
            digits = digit_match.group(1)
            luhn_ok = self._luhn_checksum(digits)
            formatted = f"{digits[:-1]}-{digits[-1]}"
            if luhn_ok:
                return formatted, True, None
            else:
                return formatted, True, "Luhn checksum failed (possible OCR error)"

        return None, False, "No valid Plusgiro found in text"

    def _normalize_amount(self, text: str) -> tuple[str | None, bool, str | None]:
        """Normalize monetary amount."""
        # Try to extract amount using regex patterns
        # Pattern 1: Number with comma as decimal (Swedish format: 1 234,56)
        # Pattern 2: Number with dot as decimal (1234.56)
        # Pattern 3: Number followed by currency (275,60 kr or 275.60 SEK)

        patterns = [
            # Swedish format with space thousand separator: 1 234,56 or 1234,56
            r'(\d[\d\s]*[,\.]\d{2})\s*(?:kr|SEK)?',
            # Simple decimal: 350.00 or 350,00
            r'(\d+[,\.]\d{2})',
            # Integer amount
            r'(\d{2,})',
        ]

        for pattern in patterns:
            matches = re.findall(pattern, text, re.IGNORECASE)
            if matches:
                # Take the last match (usually the total amount)
                amount_str = matches[-1]
                # Clean up
                amount_str = amount_str.replace(' ', '').replace('\xa0', '')
                # Handle comma as decimal separator
                if ',' in amount_str:
                    amount_str = amount_str.replace(',', '.')

                try:
                    amount = float(amount_str)
                    if amount > 0:
                        return f"{amount:.2f}", True, None
                except ValueError:
                    continue

        return None, False, f"Cannot parse amount: {text}"

    def _normalize_date(self, text: str) -> tuple[str | None, bool, str | None]:
        """
        Normalize date from text that may contain surrounding text.

        Handles various date formats found in Swedish invoices:
        - 2025-08-29 (ISO format)
        - 2025.08.29 (dot separator)
        - 29/08/2025 (European format)
        - 29.08.2025 (European with dots)
        - 20250829 (compact format)
        """
        from datetime import datetime

        # Common date patterns - order matters, most specific first
        patterns = [
            # ISO format: 2025-08-29
            (r'(\d{4})-(\d{1,2})-(\d{1,2})', lambda m: f"{m.group(1)}-{int(m.group(2)):02d}-{int(m.group(3)):02d}"),
            # Dot format: 2025.08.29 (common in Swedish)
            (r'(\d{4})\.(\d{1,2})\.(\d{1,2})', lambda m: f"{m.group(1)}-{int(m.group(2)):02d}-{int(m.group(3)):02d}"),
            # European slash format: 29/08/2025
            (r'(\d{1,2})/(\d{1,2})/(\d{4})', lambda m: f"{m.group(3)}-{int(m.group(2)):02d}-{int(m.group(1)):02d}"),
            # European dot format: 29.08.2025
            (r'(\d{1,2})\.(\d{1,2})\.(\d{4})', lambda m: f"{m.group(3)}-{int(m.group(2)):02d}-{int(m.group(1)):02d}"),
            # Compact format: 20250829
            (r'(?<!\d)(\d{4})(\d{2})(\d{2})(?!\d)', lambda m: f"{m.group(1)}-{m.group(2)}-{m.group(3)}"),
        ]

        for pattern, formatter in patterns:
            match = re.search(pattern, text)
            if match:
                try:
                    date_str = formatter(match)
                    # Validate date
                    parsed_date = datetime.strptime(date_str, '%Y-%m-%d')
                    # Sanity check: year should be reasonable (2000-2100)
                    if 2000 <= parsed_date.year <= 2100:
                        return date_str, True, None
                except ValueError:
                    continue

        return None, False, f"Cannot parse date: {text}"

    def _normalize_payment_line(self, text: str) -> tuple[str | None, bool, str | None]:
        """
        Normalize payment line region text.

        Extracts OCR, Amount, and Bankgiro from the payment line using MachineCodeParser.
        """
        from ..ocr.machine_code_parser import MachineCodeParser

        # Instantiate the parser and feed it the raw region text
        # (the parser normally works on tokens, but at inference time we only have raw text)
        parser = MachineCodeParser()

        # Try to parse the standard payment line format
        result = parser._parse_standard_payment_line(text)

        if result:
            # Format as structured output
            parts = []
            if result.get('ocr'):
                parts.append(f"OCR:{result['ocr']}")
            if result.get('amount'):
                parts.append(f"Amount:{result['amount']}")
            if result.get('bankgiro'):
                parts.append(f"BG:{result['bankgiro']}")

            if parts:
                return ' '.join(parts), True, None

        # Fallback: return raw text if no structured parsing is possible
        return text, True, None
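
    # Illustrative note (added): when the parser finds all three components, the
    # structured value looks like "OCR:12345678901 Amount:1234.56 BG:782-1713"
    # (hypothetical values); the keys 'ocr', 'amount' and 'bankgiro' follow from
    # the dict access above.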

    def _normalize_supplier_org_number(self, text: str) -> tuple[str | None, bool, str | None]:
        """
        Normalize Swedish supplier organization number.

        Extracts organization number in format: NNNNNN-NNNN (10 digits)
        Also handles VAT numbers: SE + 10 digits + 01

        Examples:
            'org.nr. 516406-1102, Filialregistret...' -> '516406-1102'
            'Momsreg.nr SE556123456701' -> '556123-4567'
        """
        # Pattern 1: Standard org number format: NNNNNN-NNNN
        org_pattern = r'\b(\d{6})-?(\d{4})\b'
        match = re.search(org_pattern, text)
        if match:
            org_num = f"{match.group(1)}-{match.group(2)}"
            return org_num, True, None

        # Pattern 2: VAT number format: SE + 10 digits + 01
        vat_pattern = r'SE\s*(\d{10})01'
        match = re.search(vat_pattern, text, re.IGNORECASE)
        if match:
            digits = match.group(1)
            org_num = f"{digits[:6]}-{digits[6:]}"
            return org_num, True, None

        # Pattern 3: Just 10 consecutive digits
        digits_pattern = r'\b(\d{10})\b'
        match = re.search(digits_pattern, text)
        if match:
            digits = match.group(1)
            # Validate: first digit should be 1-9 for Swedish org numbers
            if digits[0] in '123456789':
                org_num = f"{digits[:6]}-{digits[6:]}"
                return org_num, True, None

        return None, False, f"Cannot extract org number from: {text[:100]}"

    def _normalize_customer_number(self, text: str) -> tuple[str | None, bool, str | None]:
        """
        Normalize customer number extracted from OCR.

        Customer numbers can have various formats:
        - With separators: 'JTY 576-3', 'EMM 256-6', 'FFL 019N'
        - Compact (no separators): 'JTY5763', 'EMM2566', 'FFL019N'
        - Mixed with names: 'VIKSTRÖM, ELIAS CH FFL 01' -> extract 'FFL 01'

        Note: Spaces and dashes may be removed from the invoice display,
        so we need to match both 'JTY 576-3' and 'JTY5763' formats.
        """
        from ..normalize.normalizer import FieldNormalizer

        # Clean the text using the same logic as the matcher
        text = FieldNormalizer.clean_text(text)

        if not text:
            return None, False, "Empty text"

        # Customer number patterns - ordered by specificity.
        # Match both spaced/dashed versions and compact versions.
        customer_code_patterns = [
            # Pattern: Letters + space/dash + digits + dash + digit (EMM 256-6, JTY 576-3)
            r'\b([A-Z]{2,4}[\s\-]?\d{1,4}[\s\-]\d{1,2}[A-Z]?)\b',
            # Pattern: Letters + space/dash + digits + optional letter (FFL 019N, ABC 123X)
            r'\b([A-Z]{2,4}[\s\-]\d{2,4}[A-Z]?)\b',
            # Pattern: Compact format - letters immediately followed by digits + optional letter (JTY5763, FFL019N)
            r'\b([A-Z]{2,4}\d{3,6}[A-Z]?)\b',
            # Pattern: Single letter + digits (A12345)
            r'\b([A-Z]\d{4,6}[A-Z]?)\b',
            # Pattern: Digits + dash/space + digits (123-456)
            r'\b(\d{3,6}[\s\-]\d{1,4})\b',
        ]

        all_matches = []
        for pattern in customer_code_patterns:
            matches = re.findall(pattern, text, re.IGNORECASE)
            all_matches.extend(matches)

        if all_matches:
            # Prefer longer matches and those appearing later in text (after names).
            # Sort by position in text (later = better) and length (longer = better).
            scored_matches = []
            for match in all_matches:
                pos = text.upper().rfind(match.upper())
                # Score: position * 0.1 + length (prefer later and longer)
                score = pos * 0.1 + len(match)
                scored_matches.append((score, match))

            best_match = max(scored_matches, key=lambda x: x[0])[1]
            return best_match.strip().upper(), True, None

        # Pattern 2: Look for explicit labels
        labeled_patterns = [
            r'(?:kund(?:nr|nummer|id)?|ert?\s*(?:kund)?(?:nr|nummer)?|customer\s*(?:no|number|id)?)\s*[:\.]?\s*([A-Za-z0-9][\w\s\-]{1,20}?)(?:\s{2,}|\n|$)',
        ]

        for pattern in labeled_patterns:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                extracted = match.group(1).strip()
                extracted = re.sub(r'[\s\.\,\:]+$', '', extracted)
                if extracted and len(extracted) >= 2:
                    return extracted.upper(), True, None

        # Pattern 3: If text contains a comma (likely "NAME, NAME CODE"), extract after the last comma
        if ',' in text:
            after_comma = text.split(',')[-1].strip()
            # Look for an alphanumeric code in the part after the comma
            for pattern in customer_code_patterns[:3]:  # Use the first 3 patterns
                code_match = re.search(pattern, after_comma, re.IGNORECASE)
                if code_match:
                    return code_match.group(1).strip().upper(), True, None

        # Pattern 4: Short text - filter out name-like words
        if len(text) <= 20:
            words = text.split()
            code_parts = []
            for word in words:
                # Keep if: contains digits, or is short uppercase (likely abbreviation)
                if re.search(r'\d', word) or (len(word) <= 4 and word.isupper()):
                    code_parts.append(word)
            if code_parts:
                result = ' '.join(code_parts).upper()
                if len(result) >= 3:
                    return result, True, None

        # Fallback: return cleaned text if reasonable
        if text and 3 <= len(text) <= 15:
            return text.upper(), True, None

        return None, False, f"Cannot extract customer number from: {text[:50]}"

    def extract_all_fields(
        self,
        detections: list[Detection],
        image: np.ndarray | Image.Image
    ) -> list[ExtractedField]:
        """
        Extract fields from all detections.

        Args:
            detections: List of detections
            image: Full page image

        Returns:
            List of ExtractedField objects
        """
        fields = []

        for detection in detections:
            field = self.extract_from_detection(detection, image)
            fields.append(field)

        return fields
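
    # Typical usage sketch (added; assumes the caller already has detections and a
    # rendered page image from elsewhere in the pipeline):
    #
    #     extractor = FieldExtractor(dpi=300)
    #     fields = extractor.extract_all_fields(detections, page_image)
    #     values = {f.field_name: f.normalized_value for f in fields if f.is_valid}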

    @staticmethod
    def infer_ocr_from_invoice_number(fields: dict[str, str]) -> dict[str, str]:
        """
        Infer the OCR field from InvoiceNumber if it was not detected.

        In Swedish invoices, the OCR reference number is often identical to the
        invoice number. When OCR is not detected but InvoiceNumber is, we can
        infer the OCR value.

        Args:
            fields: Dict of field_name -> normalized_value

        Returns:
            Updated fields dict with the inferred OCR value if applicable
        """
        # If OCR already exists, there is no need to infer it
        if fields.get('OCR'):
            return fields

        # If InvoiceNumber exists and is purely numeric, use it as the OCR reference
        invoice_number = fields.get('InvoiceNumber')
        if invoice_number:
            # Check that it is all digits (a plausible OCR reference)
            digits_only = re.sub(r'\D', '', invoice_number)
            if len(digits_only) >= 5 and len(digits_only) == len(invoice_number):
                fields['OCR'] = invoice_number

        return fields
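

# Minimal smoke-test sketch (added for illustration; not part of the original module).
# It only exercises the pure-text normalizers. Because this file uses relative imports,
# it has to be run as a module from the project root, e.g.
# `python -m <package>.field_extractor` (the exact package path is an assumption).
if __name__ == '__main__':
    extractor = FieldExtractor()
    samples = [
        ('Bankgiro', 'Bankgiro: 782-1713'),
        ('Amount', 'Att betala: 1 234,56 kr'),
        ('InvoiceDate', 'Fakturadatum 2025-08-29'),
        ('InvoiceNumber', 'Fakturanummer A3861'),
    ]
    for field_name, raw in samples:
        value, ok, err = extractor._normalize_and_validate(field_name, raw)
        print(f"{field_name}: {value!r} (valid={ok}, note={err})")

    # OCR inference falls back to a purely numeric invoice number
    print(FieldExtractor.infer_ocr_from_invoice_number({'InvoiceNumber': '1234567'}))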