173 lines
5.4 KiB
Python
173 lines
5.4 KiB
Python
"""
Value Selector Module.

Selects the most likely value token(s) from OCR output per field type,
filtering out label text before the tokens are sent to the normalizer.

Stateless and pure -- easy to test, no side effects.
"""
|
|
|
|
import re
from collections.abc import Callable
from typing import Final

from shared.ocr.paddle_ocr import OCRToken
|
|
|
|
|
|
# Swedish label keywords commonly found near field values.
# All entries are lowercase; membership tests are exact string matches.
LABEL_KEYWORDS: Final[frozenset[str]] = frozenset(
    (
        # Invoice identifiers
        "faktura",
        "fakturanummer",
        "fakturanr",
        "fakturadatum",
        "invoice",
        # Dates
        "datum",
        "forfallodag",
        "forfalldatum",
        # Payment routing
        "bankgiro",
        "plusgiro",
        "bg",
        "pg",
        "ocr",
        # Amounts and currency
        "belopp",
        "summa",
        "total",
        "att",
        "betala",
        "sek",
        "kr",
        # Tax
        "moms",
        "momsreg",
        "vat",
        # Customer / organisation
        "kund",
        "kundnummer",
        "organisationsnummer",
        "org",
        "org.nr",
        # Generic number / reference words
        "nr",
        "nummer",
        "ref",
        "referens",
    )
)
|
|
|
|
# Field-specific value patterns, compiled once at import time.
# Each pattern lists its alternatives separately and joins them with "|".

# Dates. No anchors, so a date may sit anywhere inside the token text.
_DATE_PATTERN = re.compile(
    "|".join(
        (
            r"\d{4}[-./]\d{2}[-./]\d{2}",  # 2024-01-15, 2024.01.15
            r"\d{2}[-./]\d{2}[-./]\d{4}",  # 15/01/2024
            r"\d{8}",  # 20240115
        )
    )
)

# Amounts. Anchored at the end only; the optional currency suffix is
# matched case-sensitively ("kr" or "SEK").
_AMOUNT_PATTERN = re.compile(
    "|".join(
        (
            r"\d[\d\s.]*,\d{2}(?:\s*(?:kr|SEK))?$",  # European: 2.254,50 SEK
            r"\d[\d\s,]*\.\d{2}(?:\s*(?:kr|SEK))?$",  # Anglo: 1,234.56 SEK
        )
    )
)

# Bankgiro. Both alternatives are anchored at start and end.
_BANKGIRO_PATTERN = re.compile(
    "|".join(
        (
            r"^\d{3,4}-\d{4}$",  # 123-4567
            r"^\d{7,8}$",  # 1234567 or 12345678
        )
    )
)

# Plusgiro. Both alternatives are anchored at start and end.
_PLUSGIRO_PATTERN = re.compile(
    "|".join(
        (
            r"^\d+-\d$",  # 12345-6
            r"^\d{2,8}$",  # 1234567
        )
    )
)

# Organisation number. No anchors: 556123-4567 or 5561234567 anywhere in the text.
_ORG_NUMBER_PATTERN = re.compile(r"\d{6}-?\d{4}")
|
|
|
|
|
|
def _is_label(text: str) -> bool:
    """Check if token text is a known Swedish label keyword.

    The text is normalised before the lookup: surrounding whitespace is
    removed, then any trailing colons, then whitespace that sat between
    the word and the colon, and finally the result is lowercased.

    Args:
        text: Raw token text, e.g. "Fakturanr:", "Datum :" or "Belopp: ".

    Returns:
        True if the normalised text is a member of ``LABEL_KEYWORDS``.
    """
    # Whitespace must go first: with "Datum: " the colon is not the last
    # character, so rstrip(":") alone would leave it in place and the
    # lookup would miss.
    cleaned = text.strip().rstrip(":").strip().lower()
    return cleaned in LABEL_KEYWORDS
|
|
|
|
|
|
def _count_digits(text: str) -> int:
    """Return how many characters of ``text`` satisfy ``str.isdigit``."""
    return sum(1 for ch in text if ch.isdigit())
|
|
|
|
|
|
class ValueSelector:
    """Selects value token(s) from OCR output, filtering label text.

    Pure static methods -- no state, no side effects.
    Fallback: always returns all tokens if no pattern matches,
    so this can never make results worse than current behavior.
    """

    @staticmethod
    def select_value_tokens(
        tokens: list[OCRToken],
        field_name: str,
    ) -> list[OCRToken]:
        """Select the most likely value token(s) for a given field.

        Args:
            tokens: OCR tokens from the detected region.
            field_name: Field name, used as an exact-match key into the
                module-level selector table (e.g. "InvoiceDate", "Amount").
                Names not in the table use the pass-through fallback.

        Returns:
            Filtered list of value tokens. Falls back to a copy of all
            tokens if the field-specific selector keeps nothing, so the
            result is empty only when ``tokens`` is empty.
        """
        if not tokens:
            return []

        selector = _FIELD_SELECTORS.get(field_name, _fallback_selector)
        selected = selector(tokens)

        # Safety: never return empty if we had input tokens
        if not selected:
            return list(tokens)

        return selected

    @staticmethod
    def _select_date(tokens: list[OCRToken]) -> list[OCRToken]:
        """Keep tokens whose text contains a date-like substring."""
        return _select_by_pattern(tokens, _DATE_PATTERN)

    @staticmethod
    def _select_amount(tokens: list[OCRToken]) -> list[OCRToken]:
        """Keep tokens whose text ends in an amount with two decimals."""
        return _select_by_pattern(tokens, _AMOUNT_PATTERN)

    @staticmethod
    def _select_bankgiro(tokens: list[OCRToken]) -> list[OCRToken]:
        """Keep tokens whose whole text looks like a Bankgiro number."""
        return _select_by_pattern(tokens, _BANKGIRO_PATTERN)

    @staticmethod
    def _select_plusgiro(tokens: list[OCRToken]) -> list[OCRToken]:
        """Keep tokens whose whole text looks like a Plusgiro number."""
        return _select_by_pattern(tokens, _PLUSGIRO_PATTERN)

    @staticmethod
    def _select_org_number(tokens: list[OCRToken]) -> list[OCRToken]:
        """Keep tokens whose text contains an organisation number."""
        return _select_by_pattern(tokens, _ORG_NUMBER_PATTERN)

    @staticmethod
    def _select_ocr_number(tokens: list[OCRToken]) -> list[OCRToken]:
        """Select the single token containing the most digit characters.

        Digits are counted across the whole token text; they do not need
        to be contiguous. A token needs at least 2 digits to qualify, and
        on a tie the earliest token wins.

        Returns:
            A one-element list with the winning token, or an empty list
            if no token has at least 2 digits.
        """
        best: OCRToken | None = None
        best_count = 0
        for token in tokens:
            digit_count = _count_digits(token.text)
            # Strict ">" keeps the first token when counts are equal.
            if digit_count >= 2 and digit_count > best_count:
                best = token
                best_count = digit_count
        # Compare against None explicitly: "no candidate" must not be
        # confused with a token object that happens to evaluate as falsy.
        return [best] if best is not None else []

    @staticmethod
    def _select_by_label_exclusion(tokens: list[OCRToken]) -> list[OCRToken]:
        """Remove label keywords, keep remaining tokens."""
        return [t for t in tokens if not _is_label(t.text)]

    @staticmethod
    def _select_payment_line(tokens: list[OCRToken]) -> list[OCRToken]:
        """Payment line keeps all tokens (needs full text)."""
        return list(tokens)
|
|
|
|
|
|
def _select_by_pattern(
    tokens: list[OCRToken],
    pattern: re.Pattern[str],
) -> list[OCRToken]:
    """Return the tokens whose stripped text is matched by ``pattern``.

    Matching uses ``pattern.search`` on the token text with surrounding
    whitespace removed. Input order is preserved.
    """
    matched: list[OCRToken] = []
    for candidate in tokens:
        stripped = candidate.text.strip()
        if pattern.search(stripped) is not None:
            matched.append(candidate)
    return matched
|
|
|
|
|
|
def _fallback_selector(tokens: list[OCRToken]) -> list[OCRToken]:
    """Pass-through selector: return every token in a new list."""
    return [*tokens]
|
|
|
|
|
|
# Map field names to selector functions.
# Every selector takes the region's token list and returns the tokens to keep.
# Keys are matched exactly; note that they mix PascalCase and snake_case,
# so the caller's field name must use the same spelling as the key.
_FIELD_SELECTORS: Final[
    dict[str, Callable[[list[OCRToken]], list[OCRToken]]]
] = {
    "InvoiceDate": ValueSelector._select_date,
    "InvoiceDueDate": ValueSelector._select_date,
    "Amount": ValueSelector._select_amount,
    "Bankgiro": ValueSelector._select_bankgiro,
    "Plusgiro": ValueSelector._select_plusgiro,
    "OCR": ValueSelector._select_ocr_number,
    "InvoiceNumber": ValueSelector._select_by_label_exclusion,
    "supplier_org_number": ValueSelector._select_org_number,
    "customer_number": ValueSelector._select_by_label_exclusion,
    "payment_line": ValueSelector._select_payment_line,
}
|